#define COMPILER_DXC 1
#include "../../hlsl/include/NRD.hlsli"

// Shader resources. Register/space slots must match the host-side binding
// layout (root signature / descriptor layout), which is not visible here.

// Packed NRD REBLUR diffuse indirect radiance; decoded with
// REBLUR_BackEnd_UnpackRadiance before use.
Texture2D indirectLightRaysHistoryBufferSRV : register(t0, space0);
// Packed NRD REBLUR specular indirect radiance; decoded with
// REBLUR_BackEnd_UnpackRadiance before use.
Texture2D indirectSpecularLightRaysHistoryBufferSRV : register(t1, space0);
// Per-pixel factor multiplied into the decoded diffuse radiance.
Texture2D diffusePrimarySurfaceModulation : register(t2, space0);
// Disabled: specular modulation previously occupied t3 (now used by maxLuminance).
//Texture2D specularPrimarySurfaceModulation : register(t3, space0);
// Luminance statistics texture; only the .x channel is read in this file.
Texture2D<float4> maxLuminance : register(t3, space0);

// Output image. Each texel is read and then overwritten in place by this shader.
RWTexture2D<float4> pathTracerUAV : register(u0);

// Not referenced in this file.
SamplerState bilinearWrap : register(s0);

// Per-dispatch constants. None of these members are referenced in this file.
// NOTE(review): presumably mirrors a CPU-side struct; keep member order and
// types in sync with it, since HLSL cbuffer packing depends on both.
cbuffer globalData : register(b0)
{
    float2 screenSize;
    int    reflectionMode;
    int    shadowMode;
    int    frameIndex;
}

// Relative luminance (Y) of a linear RGB colour using the Rec. 709 / sRGB
// primary weights. The caller is responsible for the input being linear.
float luminance(float3 v)
{
    const float3 rec709Weights = float3(0.2126f, 0.7152f, 0.0722f);
    return dot(rec709Weights, v);
}

// Rescales c_in uniformly so that its luminance becomes l_out, preserving the
// colour's chromaticity.
//   c_in  - linear RGB input colour
//   l_out - target luminance
// Returns black when c_in has zero (or non-positive) luminance. A black pixel
// cannot be scaled to a non-zero luminance, and the unguarded division would
// produce NaN/Inf that then propagates through later passes.
float3 change_luminance(float3 c_in, float l_out)
{
    const float l_in = luminance(c_in);
    if (l_in <= 0.0f)
    {
        return float3(0.0f, 0.0f, 0.0f);
    }
    return c_in * (l_out / l_in);
}

// Extended Reinhard operator applied to luminance only, then re-applied to the
// colour through change_luminance so that hue is preserved:
//   L_out = L_in * (1 + L_in / W^2) / (1 + L_in)
// W (max_white_l) is the luminance that maps to exactly 1.0. Passing 0 for W
// divides by zero.
float3 reinhard_extended_luminance(float3 v, float max_white_l)
{
    const float whiteSquared = max_white_l * max_white_l;
    const float lumIn = luminance(v);
    const float boosted = lumIn * (1.0f + (lumIn / whiteSquared));
    const float lumOut = boosted / (1.0f + lumIn);
    return change_luminance(v, lumOut);
}

// Extended Reinhard operator applied independently to each colour channel:
//   f(x) = x * (1 + x / L_white^2) / (1 + x)
// An input equal to L_white maps to exactly 1.0. Passing 0 for L_white divides
// by zero.
float3 Reinhard2(float3 x, float L_white)
{
    const float whiteSquared = L_white * L_white;
    const float3 numerator = x * (1.0 + x / whiteSquared);
    const float3 denominator = 1.0 + x;
    return numerator / denominator;
}

// Mean of the .x channel over the top-left 8x8 texels of maxLuminance.
// Not used by the active tone-mapping path in main(). It is kept so that the
// Reinhard variants noted there can obtain their white point without paying
// 64 texture loads per thread when they are disabled.
float averageMaxLuminance()
{
    float sum = 0.0;
    for (uint i = 0; i < 8; i++)
    {
        for (uint j = 0; j < 8; j++)
        {
            sum += maxLuminance[uint2(i, j)].x;
        }
    }
    return sum / 64.0;
}

// Composites the denoised indirect lighting onto pathTracerUAV, tone maps the
// sum and gamma-encodes it, then writes the result back to the same texel.
// One thread handles one pixel. threadGroupThreadId is unused but kept so the
// entry-point signature is unchanged.
[numthreads(8, 8, 1)]
void main(int3 threadId : SV_DispatchThreadID,
          int3 threadGroupThreadId : SV_GroupThreadID)
{
    const int2 pixel = threadId.xy;

    float4 specularUnpacked =
        REBLUR_BackEnd_UnpackRadiance(indirectSpecularLightRaysHistoryBufferSRV[pixel]);
    float4 diffuseUnpacked =
        REBLUR_BackEnd_UnpackRadiance(indirectLightRaysHistoryBufferSRV[pixel]);

    diffuseUnpacked *= diffusePrimarySurfaceModulation[pixel];
    // Specular modulation is currently disabled (its texture slot is unbound).
    //specularUnpacked *= specularPrimarySurfaceModulation[pixel];

    // Denoised indirect lighting (diffuse + specular).
    const float3 indirect = (specularUnpacked + diffuseUnpacked).xyz;

    // Add whatever is already stored in the output texel. Clamp to
    // non-negative: a negative sum would make 1 - exp(-x) negative, and pow()
    // of a negative base is undefined (NaN) in HLSL.
    const float3 radiance = max(indirect + pathTracerUAV[pixel].xyz, 0.0);

    const float exposure = 1.0;
    const float gamma = 2.2;
    const float invGamma = 1.0 / gamma;

    // Exponential tone mapping (not Reinhard): maps [0, inf) into [0, 1).
    float3 mapped = 1.0 - exp(-radiance * exposure);

    // Alternative operators that were experimented with. Note that they map
    // only `indirect`, not `radiance`:
    //   float3 mapped = Reinhard2(indirect, averageMaxLuminance());
    //   float3 mapped = reinhard_extended_luminance(indirect, averageMaxLuminance());
    // Local-adaptation variant (key read from maxLuminance[uint2(8, 0)].x):
    //   float lum = luminance(indirect);
    //   float adapted = lum + (maxLuminance[uint2(8, 0)].x - lum) * 0.05;
    //   float3 mapped = indirect * ((adapted / (1.0 + adapted)) / lum);

    // Gamma encoding.
    mapped = pow(mapped, float3(invGamma, invGamma, invGamma));

    pathTracerUAV[pixel] = float4(mapped, 1.0);
}