#define COMPILER_DXC 1
#include "../../hlsl/include/NRD.hlsli"

// ---- Shader inputs -----------------------------------------------------------------
// All three textures are read per pixel in main(). The two history buffers are decoded
// with NRD's REBLUR_BackEnd_UnpackRadiance, so they are assumed to hold REBLUR-packed
// radiance — NOTE(review): confirm the packing matches what the denoiser writes.
Texture2D indirectLightRaysHistoryBufferSRV : register(t0, space0);          // diffuse term in main()
Texture2D indirectSpecularLightRaysHistoryBufferSRV : register(t1, space0);  // specular term in main()
Texture2D diffusePrimarySurfaceModulation : register(t2, space0);            // multiplied into the diffuse term

// Output statistics written by main():
//   [(frameIndex % 8, (frameIndex / 8) % 8)].x : maximum luminance of the current frame
//                                                (an 8x8 ring of per-frame history slots)
//   [(8, 0)]                                   : mean luminance (written to all channels)
// The texture therefore needs to be at least 9 texels wide and 8 texels tall.
RWTexture2D<float4> maxLuminance : register(u0);

// Not referenced by main() in this shader.
SamplerState bilinearWrap : register(s0);

// Per-dispatch constants. NOTE(review): presumably mirrored by a CPU-side struct — keep
// member order/types in sync with it.
cbuffer globalData : register(b0)
{
    float2 screenSize;     // pixel extent iterated by main()
    int    reflectionMode; // not read by this shader
    int    shadowMode;     // not read by this shader
    int    frameIndex;     // selects the history slot for the max-luminance output
}

// Relative luminance of a linear RGB colour, weighted with the Rec. 709 luminance
// coefficients.
float luminance(float3 v)
{
    const float3 rec709Weights = float3(0.2126f, 0.7152f, 0.0722f);
    return dot(rec709Weights, v);
}

// Rescales colour c_in so that its luminance becomes l_out while keeping its chromaticity.
//
// c_in  : linear RGB colour to rescale.
// l_out : target luminance.
// Returns c_in * (l_out / luminance(c_in)).
//
// A colour with zero luminance (e.g. pure black) has no direction to scale along, and the
// plain ratio would be 0/0 = NaN, which then propagates through any tonemapping built on
// top of this. Such colours are returned unchanged instead.
float3 change_luminance(float3 c_in, float l_out)
{
    const float l_in = luminance(c_in);
    if (l_in == 0.0f)
    {
        return c_in;
    }
    return c_in * (l_out / l_in);
}

// Extended Reinhard tonemapping applied to luminance only:
//   L' = L * (1 + L / W^2) / (1 + L)
// where W (max_white_l) is the luminance that maps to 1.0. The colour is then rescaled to
// L' via change_luminance, preserving its chromaticity.
float3 reinhard_extended_luminance(float3 v, float max_white_l)
{
    const float lumIn   = luminance(v);
    const float whiteSq = max_white_l * max_white_l;
    const float lumOut  = lumIn * (1.0f + (lumIn / whiteSq)) / (1.0f + lumIn);
    return change_luminance(v, lumOut);
}

// Thread-group dimensions for main(). 32 x 32 = 1024 threads, the maximum allowed per
// group for compute shaders.
#define GROUP_SIZE_X 32
#define GROUP_SIZE_Y 32
// Parenthesized so the macro expands safely inside larger expressions
// (e.g. `i % GROUP_SIZE`, `-GROUP_SIZE`).
#define GROUP_SIZE (GROUP_SIZE_X * GROUP_SIZE_Y)

// Per-thread partial results, one slot per thread in the group, reduced in main().
groupshared float sharedMax[GROUP_SIZE];
groupshared float sharedAve[GROUP_SIZE];


// Luminance statistics pass, designed to run as a SINGLE thread group.
//
// Each thread walks the screen starting at its group-local (x, y) with a stride equal to
// the group dimensions, so one group of GROUP_SIZE_X x GROUP_SIZE_Y threads visits every
// pixel exactly once. For each pixel the luminance of (specular + modulated diffuse)
// radiance is computed. The group then reduces, over its flattened thread index:
//   * the maximum luminance -> maxLuminance[(frameIndex % 8, (frameIndex / 8) % 8)].x
//   * the mean luminance    -> maxLuminance[(8, 0)] (all channels)
// Extra groups, if dispatched, would repeat the same work and write the same values.
[numthreads(GROUP_SIZE_X, GROUP_SIZE_Y, 1)]
void main(uint3 groupThreadID : SV_GroupThreadID, uint groupIndex : SV_GroupIndex)
{
    // Number of pixels the loops below visit in total. For unsigned x, `x < s` holds for
    // exactly ceil(s) values when s > 0 and none otherwise. This is used to turn per-thread
    // sums into contributions to the screen-wide mean.
    const float2 visitedExtent = max(ceil(screenSize), float2(0.0f, 0.0f));
    const float totalPixels = visitedExtent.x * visitedExtent.y;

    // Luminance of the (non-negative) radiance is >= 0, so 0 is a valid identity for max().
    float localMax = 0.0f;
    float localSum = 0.0f;

    for (uint y = groupThreadID.y; y < screenSize.y; y += GROUP_SIZE_Y)
    {
        for (uint x = groupThreadID.x; x < screenSize.x; x += GROUP_SIZE_X)
        {
            const uint2 index = uint2(x, y);
            const float4 specularUnpacked = REBLUR_BackEnd_UnpackRadiance(indirectSpecularLightRaysHistoryBufferSRV[index]);
            float4 diffuseUnpacked = REBLUR_BackEnd_UnpackRadiance(indirectLightRaysHistoryBufferSRV[index]);
            diffuseUnpacked *= diffusePrimarySurfaceModulation[index];

            const float lum = luminance((specularUnpacked + diffuseUnpacked).xyz);

            localMax = max(localMax, lum);
            localSum += lum;
        }
    }

    sharedMax[groupIndex] = localMax;
    // Store this thread's share of the mean (its sum / total pixels) rather than its own
    // average. Threads visit different pixel counts when the screen size is not a
    // multiple of the group size, and may visit none at all (tiny screens). Averaging
    // per-thread means pairwise would therefore be biased, and 0/0 would inject NaN.
    // Summing these shares yields the exact mean.
    sharedAve[groupIndex] = (totalPixels > 0.0f) ? (localSum / totalPixels) : 0.0f;

    GroupMemoryBarrierWithGroupSync();

    // Tree reduction over the flattened group index: at each step the lower half of the
    // active range folds in the upper half. The loop bound is uniform across the group,
    // so every thread reaches every barrier.
    for (uint stride = (GROUP_SIZE) / 2; stride > 0; stride >>= 1)
    {
        if (groupIndex < stride)
        {
            sharedMax[groupIndex] = max(sharedMax[groupIndex], sharedMax[groupIndex + stride]);
            sharedAve[groupIndex] += sharedAve[groupIndex + stride];
        }

        GroupMemoryBarrierWithGroupSync();
    }

    // sharedMax[0] / sharedAve[0] now hold the screen-wide maximum and mean luminance.
    if (groupIndex == 0)
    {
        // 8x8 ring of history slots indexed by frame.
        const uint2 historyIndex = uint2(frameIndex % 8, (frameIndex / 8) % 8);

        maxLuminance[historyIndex].x = sharedMax[0];

        // Previously considered alternative: log(sharedAve[0] + 1)
        maxLuminance[uint2(8, 0)] = sharedAve[0];
    }
}