#include "../include/structs.hlsl"
#include "../include/dxr1_0_defines.hlsl"
#define COMPILER_DXC 1
#include "../../hlsl/include/NRD.hlsli"

// ---- Shader resource views -------------------------------------------------------------------
// Acceleration structure traced by ReflectionRaygen.
RaytracingAccelerationStructure             rtAS : register(t0, space0);
// Unbounded resource arrays, each placed in its own register space.
Texture2D                                   diffuseTexture[]                 : register(t1, space1);
StructuredBuffer<CompressedAttribute>       vertexBuffer[]                   : register(t2, space2);
Buffer<uint>                                indexBuffer[]                    : register(t3, space3);
// Per-instance lookup tables.
// NOTE(review): not indexed in the visible part of this file; presumably consumed by the
// included helper headers - confirm.
Buffer<uint>                                instanceIndexToMaterialMapping   : register(t4, space0);
Buffer<uint>                                instanceIndexToAttributesMapping : register(t5, space0);
Buffer<float>                               instanceNormalMatrixTransforms   : register(t6, space0);
StructuredBuffer<UniformMaterial>           uniformMaterials                 : register(t7, space0);
StructuredBuffer<AlignedHemisphereSample3D> sampleSets                       : register(t8, space0);
TextureCube                                 skyboxTexture                    : register(t9, space0);
Buffer<uint>                                instanceIndexOffsetMapping       : register(t10, space0);
Texture2D                                   spatioTemporalBlueNoise          : register(t11, space0);

// ---- Unordered access views ------------------------------------------------------------------
// Per-pixel outputs of ReflectionRaygen: packed diffuse signal, packed specular signal and the
// primary-surface albedo used for demodulation.
RWTexture2D<float4> indirectLightRaysUAV : register(u0);
RWTexture2D<float4> indirectSpecularLightRaysUAV : register(u1);
RWTexture2D<float4> diffusePrimarySurfaceModulation : register(u2);
// Primary-hit surface data written by ReflectionClosestHit on the first bounce.
RWTexture2D<float4> albedoUAV : register(u3);
RWTexture2D<float4> positionUAV : register(u4);
RWTexture2D<float4> normalUAV : register(u5);
RWTexture2D<float4> viewZUAV : register(u6);
// Element 0 is the atomic counter from which threads claim pixel indices when the ray queue
// system is enabled.
RWBuffer<uint>      rayQueue : register(u7);
// Every traced ray, stored as maxBounces consecutive entries per pixel (pixels in row-major
// order), so that the freezeRays mode can replay them.
RWStructuredBuffer<RayDesc>  rayDescBuffer : register(u8);

SamplerState bilinearWrap : register(s0);

// Passed as the sanitization argument to the NRD front-end helpers
// (NRD_GetSampleWeight / REBLUR_FrontEnd_PackRadianceAndHitDist).
#define USE_SANITIZATION 1

// Per-dispatch constants shared by all shaders in this file.
// NOTE(review): the member order and types define the constant buffer layout and must stay in
// sync with whatever fills this buffer on the host side.
cbuffer globalData : register(b0)
{
    float4x4 inverseView;
    // Used to generate the primary camera ray and to derive view-space Z of the primary hit.
    float4x4 viewTransform;

    // Per-light data; numLights entries are iterated by the closest-hit shader.
    float4 lightColors[MAX_LIGHTS];
    float4 lightPositions[MAX_LIGHTS];
    // Range and point-light flag are packed four lights per vector and addressed as
    // [lightIndex / 4][lightIndex % 4].
    float4 lightRanges[MAX_LIGHTS / 4];
    uint4  isPointLight[MAX_LIGHTS / 4];
    uint   numLights;

    // Render target size; truncated to integers for row-major pixel/ray indexing.
    float2 screenSize;
    // Scales the blue-noise sample that is used as texture UV jitter.
    float2 uvJitterScale;
    // Forwarded to GenerateCameraRayHalton for the primary ray.
    float2 haltonSeq;

    // NOTE(review): the following five values are not referenced in the visible part of this
    // file; presumably consumed by the included helper headers - confirm.
    uint seed;
    uint numSamplesPerSet;
    uint numSampleSets;
    uint numPixelsPerDimPerSet;
    uint texturesPerMaterial;

    uint resetHistoryBuffer;
    // Selects the blue-noise slice used for UV jitter.
    uint frameNumber;

    // Upper bound on the number of rays traced per pixel, and the number of rayDescBuffer
    // slots reserved per pixel.
    uint maxBounces;

    float fov;

    // enum class renderMode
    //{
    //    DIFFUSE_DENOISED = 0,
    //    SPECULAR_DENOISED = 1,
    //    BOTH_DIFFUSE_AND_SPECULAR = 2,
    //    DIFFUSE_RAW = 3,
    //    SPECULAR_RAW = 4
    //};
    //
    // The ray generation shader writes the diffuse output for modes 0, 2 and 3 and the
    // specular output for modes 1, 2 and 4.

    // Default to both specular and diffuse
    int renderMode;
    // Bounce whose radiance is accumulated by the closest-hit shader; -1 disables the filter.
    int rayBounceIndex;

    // 0 forces diffuse bounces, 1 forces specular bounces, 2 lets the shader choose per bounce.
    int diffuseOrSpecular;
    // 0 forces reflection, 1 forces refraction, 2 refracts only when the surface reports a
    // transmittance above zero.
    int reflectionOrRefraction;
    // When false, emissive colour is multiplied by zero in the closest-hit shader.
    bool enableEmissives;
    bool enableIBL;
    // When true, threads pull pixel indices from rayQueue[0] instead of using DispatchRaysIndex().
    bool enableRayQueueSystem;
    // When true, the rays stored in rayDescBuffer are replayed instead of generating new ones.
    bool freezeRays;
    // Only read when the shader is compiled with ENABLE_SER.
    bool enableSER;
    // Forwarded to RayTraversalData::enableMipCalculation in the closest-hit shader.
    bool enableMipCalculation;
    bool enableStochasticBilinear;
}

#include "../include/sunLightCommon.hlsl"
#include "../include/utils.hlsl"

// NOTE(review): neither value is referenced in the visible part of this file.
static float reflectionIndex = 0.5;
static float refractionIndex = 1.0 - reflectionIndex;

// Tiled addressing configuration used by the GenericTS* helpers below.
// TS_USE_MORTON selects Morton (Z-order) ordering of the pixels inside a tile; otherwise the
// pixels inside a tile are ordered as plain scanlines.
#define TS_USE_MORTON 1
#define TS_TILE_SIZE 8 // seems to be the sweet spot
// Mask that extracts the pixel index inside a TS_TILE_SIZE x TS_TILE_SIZE tile (63 for 8x8).
#define TS_TILE_MASK (TS_TILE_SIZE*TS_TILE_SIZE-1)

// Returns imageWidth rounded up to a whole number of TS_TILE_SIZE-wide tiles, i.e. the width of
// one row of the tiled address space. imageHeight is accepted for signature symmetry but unused.
inline uint GenericTSComputeLineStride(const uint imageWidth, const uint imageHeight)
{
    const uint tilesAcross = (imageWidth + (TS_TILE_SIZE - 1)) / TS_TILE_SIZE;
    return TS_TILE_SIZE * tilesAcross;
}
// Returns the number of addresses in one full image plane of the tiled address space: the
// padded line stride multiplied by imageHeight rounded up to a whole number of tiles.
inline uint GenericTSComputePlaneStride(const uint imageWidth, const uint imageHeight)
{
    const uint tilesDown = (imageHeight + (TS_TILE_SIZE - 1)) / TS_TILE_SIZE;
    const uint paddedHeight = tilesDown * TS_TILE_SIZE;
    return GenericTSComputeLineStride(imageWidth, imageHeight) * paddedHeight;
}

// Splits a Morton code of at most 16 bits into its two 8-bit coordinates.
// The even bits of `morton` form the returned .x, the odd bits form .y.
inline uint2 Morton16BitDecode(uint morton)
{
    // Park the odd bits in the upper half-word so both coordinates are compacted in parallel.
    const uint evenBits = morton & 0x5555;
    const uint oddBits = (morton & 0xaaaa) << 15;

    uint lanes = evenBits | oddBits;
    lanes = (lanes ^ (lanes >> 1)) & 0x33333333;
    lanes = (lanes ^ (lanes >> 2)) & 0x0f0f0f0f;
    lanes = lanes ^ (lanes >> 4);

    // Each half-word now carries its coordinate in the low byte.
    return uint2(lanes & 0xff, (lanes >> 16) & 0xff);
}

// Spreads the 16 low bits of x so that input bit i lands on output bit 2*i and every odd
// output bit is zero. Bits of x above bit 15 are ignored.
uint Part1By1(uint x)
{
    // At every step the bits where the value and its shifted copy overlap are removed by the
    // mask, so combining the two with OR is exact.
    uint spread = x & 0x0000ffff;                   // ---- ---- ---- ---- fedc ba98 7654 3210
    spread = (spread | (spread << 8)) & 0x00ff00ff; // ---- ---- fedc ba98 ---- ---- 7654 3210
    spread = (spread | (spread << 4)) & 0x0f0f0f0f; // ---- fedc ---- ba98 ---- 7654 ---- 3210
    spread = (spread | (spread << 2)) & 0x33333333; // --fe --dc --ba --98 --76 --54 --32 --10
    spread = (spread | (spread << 1)) & 0x55555555; // -f-e -d-c -b-a -9-8 -7-6 -5-4 -3-2 -1-0
    return spread;
}

// Spreads the 10 low bits of x so that input bit i lands on output bit 3*i and all other
// output bits are zero. Bits of x above bit 9 are ignored.
uint Part1By2(uint x)
{
    // The value and its shifted copy never share a surviving bit, so OR combines them exactly.
    uint spread = x & 0x000003ff;                    // ---- ---- ---- ---- ---- --98 7654 3210
    spread = (spread | (spread << 16)) & 0xff0000ff; // ---- --98 ---- ---- ---- ---- 7654 3210
    spread = (spread | (spread << 8)) & 0x0300f00f;  // ---- --98 ---- ---- 7654 ---- ---- 3210
    spread = (spread | (spread << 4)) & 0x030c30c3;  // ---- --98 ---- 76-- --54 ---- 32-- --10
    spread = (spread | (spread << 2)) & 0x09249249;  // ---- 9--8 --7- -6-- 5--4 --3- -2-- 1--0
    return spread;
}

// Interleaves the 16 low bits of x and y into a 2D Morton code: x occupies the even bits,
// y the odd bits.
uint EncodeMorton2(uint x, uint y)
{
    const uint evenBits = Part1By1(x);
    const uint oddBits = Part1By1(y) << 1;
    // The two operands have no bit in common, so OR yields the same value as addition.
    return oddBits | evenBits;
}

// Interleaves the 10 low bits of x, y and z into a 30-bit 3D Morton code: x occupies bits
// 0, 3, 6, ..., y bits 1, 4, 7, ... and z bits 2, 5, 8, ...
uint EncodeMorton3(uint x, uint y, uint z)
{
    const uint xBits = Part1By2(x);
    const uint yBits = Part1By2(y) << 1;
    const uint zBits = Part1By2(z) << 2;
    // The three operands have no bit in common, so OR yields the same value as addition.
    return zBits | yBits | xBits;
}


// Converts a linear address of the tiled address space into a pixel coordinate.
//   address     - linear address; may span several image planes
//   lineStride  - value returned by GenericTSComputeLineStride
//   planeStride - value returned by GenericTSComputePlaneStride
// Returns (pixel x, pixel y, plane index).
inline uint3 GenericTSAddressToPixel(const uint address, const uint lineStride, const uint planeStride)
{
    const uint plane = address / planeStride;
    const uint inPlane = address % planeStride;

    // Position of the pixel inside its tile.
#if TS_USE_MORTON
    uint2 pixel = Morton16BitDecode(inPlane & TS_TILE_MASK);
#else // else simple scanline
    uint inTile = inPlane % (TS_TILE_SIZE * TS_TILE_SIZE);
    uint2 pixel = uint2(inTile % TS_TILE_SIZE, inTile / TS_TILE_SIZE); // linear
#endif

    // Address of the tile's first pixel divided by the tile size: its remainder by lineStride
    // is the tile's x origin, its quotient counts the tile rows above it.
    const uint tileBase = (inPlane & ~TS_TILE_MASK) / TS_TILE_SIZE;
    const uint tileOriginX = tileBase % lineStride;
    const uint tileOriginY = (tileBase / lineStride) * TS_TILE_SIZE;

    return uint3(pixel.x + tileOriginX, pixel.y + tileOriginY, plane);
}



// Converts a linear RGB colour to (Y, Co, Cg) using lifting steps. The returned Y is clamped
// to be non-negative.
float3 LinearToYCoCg(float3 color)
{
    const float chromaOrange = color.x - color.z;
    const float midpoint = color.z + chromaOrange * 0.5;
    const float chromaGreen = color.y - midpoint;

    // TODO: the clamp is useful, but not needed in many cases
    const float luma = max(midpoint + chromaGreen * 0.5, 0.0);

    return float3(luma, chromaOrange, chromaGreen);
}

// Converts a (Y, Co, Cg) colour back to linear RGB by undoing the lifting steps of
// LinearToYCoCg. A negative Y is treated as zero.
float3 YCoCgToLinear(float3 color)
{
    // TODO: the clamp is useful, but not needed in many cases
    const float luma = max(color.x, 0.0);

    const float midpoint = luma - color.z * 0.5;
    const float green = color.z + midpoint;
    const float blue = midpoint - color.y * 0.5;
    const float red = blue + color.y;

    return float3(red, green, blue);
}

// Returns NoL / PI, clamped from below to 1e-7 so callers can safely divide by the result.
// The default NoL = 1.0 is useful when the PDF's NoL cancels against the throughput's NoL.
float GetPDF(float NoL = 1.0)
{
    return max(NoL / PI, 1e-7);
}

// Maps a 2D sample to a unit direction in the hemisphere around +Z.
//   rnd.x - selects the azimuth (scaled by 2*PI)
//   rnd.y - selects the polar angle through cos(theta) = sqrt(rnd.y); rnd.y is clamped to [0, 1]
// For uniformly distributed rnd this is the usual cosine-weighted hemisphere mapping.
float3 GetRay(float2 rnd)
{
    const float azimuth = rnd.x * 2.0 * PI;
    const float cosPolar = sqrt(saturate(rnd.y));
    const float sinPolar = sqrt(saturate(1.0 - cosPolar * cosPolar));

    return float3(sinPolar * cos(azimuth), sinPolar * sin(azimuth), cosPolar);
}

// Builds an orthonormal frame (T, B, N) around the unit vector N, returned as the rows of a
// float3x3. Derivation:
// http://marc-b-reynolds.github.io/quaternions/2016/07/06/Orthonormal.html
float3x3 GetBasis(float3 N)
{
    // The intrinsic sign() returns 0 for N.z == 0, which made the divisor below zero and
    // produced inf/NaN for every normal lying in the XY plane. Use a sign that is never zero;
    // the result is unchanged for any N.z != 0.
    float sz = (N.z >= 0.0) ? 1.0 : -1.0;
    // |sz + N.z| >= 1 for a unit-length N, so the reciprocal is always finite.
    float a = 1.0 / (sz + N.z);
    float ya = N.y * a;
    float b = N.x * ya;
    float c = N.x * sz;

    float3 T = float3(c * N.x * a - 1.0, sz * b, c);
    float3 B = float3(b, N.y * ya - sz, N.y);

    // Note: due to the quaternion formulation, the generated frame is rotated by 180 degrees,
    // s.t. if N = (0, 0, 1), then T = (-1, 0, 0) and B = (0, -1, 0).
    return float3x3(T, B, N);
}

// NOTE(review): not referenced in the visible part of this file; presumably the NRD hit
// distance parameters for the diffuse and specular signals - confirm against the NRD settings
// used on the host side.
static float4 diffHitDistParams = float4(3.0f, 0.1f, 10.0f, -25.0f);
static float4 specHitDistParams = float4(3.0f, 0.1f, 10.0f, -25.0f);

// Ray generation shader of the indirect lighting pass.
//
// Two modes, selected by freezeRays:
//  * freezeRays set: re-traces the rays stored in rayDescBuffer for this dispatch index
//    (maxBounces slots per pixel) instead of generating new ones.
//  * otherwise: builds a path of up to maxBounces rays per pixel. Every ray is written to
//    rayDescBuffer before it is traced. With enableRayQueueSystem set, a thread keeps claiming
//    pixel indices from the atomic counter rayQueue[0] until all pixels are taken; without it
//    the thread processes only the pixel at DispatchRaysIndex().
//
// After the bounce loop the packed diffuse / specular signals held in the payload are written
// to the output UAVs selected by renderMode, together with the primary-surface albedo that is
// used for demodulation.
[shader("raygeneration")]
void ReflectionRaygen()
{
    uint2 screenSizeInts = uint2(screenSize.x, screenSize.y);
    // Only consumed by the disabled (#if 0) tiled addressing path further down.
    uint lineStride = GenericTSComputeLineStride(screenSizeInts.x, screenSizeInts.y);
    uint planeStride = GenericTSComputePlaneStride(screenSizeInts.x, screenSizeInts.y);

    // Ray queue bookkeeping: whether this thread currently owns a batch of pixel indices, the
    // first index of that batch, the batch size and the index being processed.
    bool grabbedRays = false;
    uint baseRayThread = 0;
    uint rayCountPerThread = 1;
    uint rayQueueThread = 0;

    // NOTE(review): not read anywhere in this function.
    int enableCheckerBoarding = true;
    //int checkerBoardDiffuseOrSpecular = frameNumber % 2;

    if (freezeRays)
    {

        // Start from signals packed from zero radiance and zero hit distance, full throughput.
        Payload payload;
        payload.nrdSpecular = REBLUR_FrontEnd_PackRadianceAndHitDist(float3(0.0, 0.0, 0.0), 0.0, 0);
        payload.nrdDiffuse = REBLUR_FrontEnd_PackRadianceAndHitDist(float3(0.0, 0.0, 0.0), 0.0, 0);
        payload.indirectPos = float3(0.0, 0.0, 0.0);
        payload.indirectNormal = float3(0.0, 0.0, 0.0);
        payload.throughput = float3(1.0, 1.0, 1.0);
        payload.previousPosition = float3(0.0, 0.0, 0.0);
        payload.isEarlyOut = 0;
        payload.albedo = float3(0.0, 0.0, 0.0);
        payload.transmittance = 0.0;
        payload.metallic = 0.0;
        payload.roughness = 0.0;
        payload.i = 0;
        payload.diffuseAlbedoDemodulation = float3(0.0, 0.0, 0.0);
        payload.path = 0.0;
        payload.diffuseRay = 0;// checkerBoardDiffuseOrSpecular;
        payload.grabbedPrimarySurfaceDemodulator = 0;
        payload.occlusion = 0;
        payload.rayIndex = DispatchRaysIndex().xy;


        // Replay the stored rays of this pixel, one per bounce.
        for (payload.i = 0; payload.i < maxBounces; payload.i++)
        {

            // Slot layout: maxBounces consecutive rays per pixel, pixels in row-major order.
            uint rayDescBufferIndex = ((DispatchRaysIndex().x + (DispatchRaysIndex().y * screenSizeInts.x)) * maxBounces) + payload.i;

            // Seed both payload positions with the origin of the stored ray.
            payload.indirectPos = rayDescBuffer[rayDescBufferIndex].Origin;
            payload.previousPosition = payload.indirectPos;

#if ENABLE_SER
            if (enableSER == false)
            {
#endif
                TraceRay(rtAS, RAY_FLAG_SKIP_PROCEDURAL_PRIMITIVES | RAY_FLAG_FORCE_OPAQUE, ~0, 0, 0, 0, rayDescBuffer[rayDescBufferIndex], payload);
#if ENABLE_SER
            }
            else
            {
                // With SER the first bounce is traced normally; later bounces are traced into
                // a hit object, the threads are reordered, and only then is shading invoked.
                if (payload.i == 0)
                {
                    TraceRay(rtAS, RAY_FLAG_SKIP_PROCEDURAL_PRIMITIVES | RAY_FLAG_FORCE_OPAQUE, ~0, 0, 0, 0, rayDescBuffer[rayDescBufferIndex], payload);
                }
                else
                {
                    NvHitObject hitObject;
                    NvTraceRayHitObject(rtAS, RAY_FLAG_SKIP_PROCEDURAL_PRIMITIVES | RAY_FLAG_FORCE_OPAQUE, ~0, 0, 0, 0, rayDescBuffer[rayDescBufferIndex], payload, hitObject);
                
                    // Sort based on hit group, instance and geometry index primarily
                    NvReorderThread(hitObject, 0, 0);
                
                    NvInvokeHitObject(rtAS, hitObject, payload);
                }
            }
#endif
            // The payload flag ends the path early.
            // NOTE(review): presumably set by the hit/miss shaders - it is not set in this function.
            if (payload.isEarlyOut == 1)
            {
                break;
            }
        }

        // Specular signal is written for render modes 1, 2 and 4.
        if (renderMode == 1 || renderMode == 2 || renderMode == 4)
        {
            indirectSpecularLightRaysUAV[payload.rayIndex] = payload.nrdSpecular;
        }
        
        // Diffuse signal is written for render modes 0, 2 and 3.
        if (renderMode == 0 || renderMode == 2 || renderMode == 3)
        {
            indirectLightRaysUAV[payload.rayIndex] = payload.nrdDiffuse;
        }

        // Without a recorded primary-surface albedo the demodulation factor is the identity.
        if (payload.grabbedPrimarySurfaceDemodulator == 0)
        {
            diffusePrimarySurfaceModulation[payload.rayIndex] = float4(1.0, 1.0, 1.0, 1.0);
        }
        else
        {
            diffusePrimarySurfaceModulation[payload.rayIndex] = float4(payload.diffuseAlbedoDemodulation.xyz, 1.0);
        }
    }
    else
    {
        // One iteration per processed pixel. Without the ray queue the loop body runs once.
        while (true)
        {
            // Start from signals packed from zero radiance and zero hit distance, full throughput.
            Payload payload;
            payload.nrdSpecular = REBLUR_FrontEnd_PackRadianceAndHitDist(float3(0.0, 0.0, 0.0), 0.0, 0);
            payload.nrdDiffuse = REBLUR_FrontEnd_PackRadianceAndHitDist(float3(0.0, 0.0, 0.0), 0.0, 0);
            payload.indirectPos = float3(0.0, 0.0, 0.0);
            payload.indirectNormal = float3(0.0, 0.0, 0.0);
            payload.throughput = float3(1.0, 1.0, 1.0);
            payload.previousPosition = float3(0.0, 0.0, 0.0);
            payload.isEarlyOut = 0;
            payload.albedo = float3(0.0, 0.0, 0.0);
            payload.transmittance = 0.0;
            payload.metallic = 0.0;
            payload.roughness = 0.0;
            payload.i = 0;
            payload.diffuseAlbedoDemodulation = float3(0.0, 0.0, 0.0);
            payload.path = 0.0;
            // First ray is directional and perfect mirror from camera eye so mirror ray it is
            payload.diffuseRay = 0;
            payload.grabbedPrimarySurfaceDemodulator = 0;
            payload.occlusion = 0;
            payload.rayIndex = uint2(0, 0);

            float3 rayDir = float3(0.0, 0.0, 0.0);

            if (enableRayQueueSystem)
            {
                if (grabbedRays == false)
                {
                    // Claim rayCountPerThread consecutive pixel indices from the shared
                    // counter; rayQueueThread receives the first index of the batch.
                    InterlockedAdd(rayQueue[0], rayCountPerThread, rayQueueThread);
                    baseRayThread = rayQueueThread;
                    grabbedRays = true;
                }
                else
                {
                    // Advance inside the batch that is already owned.
                    rayQueueThread++;
                }

                // Last index of the batch reached: claim a new batch on the next iteration.
                if (grabbedRays == true && (rayQueueThread - baseRayThread) == (rayCountPerThread - 1))
                {
                    grabbedRays = false;
                }

                // Every pixel has been claimed: this thread is done.
                if (rayQueueThread >= (screenSizeInts.x * screenSizeInts.y))
                {
                    break;
                }

#if 0
                payload.rayIndex = GenericTSAddressToPixel(rayQueueThread, lineStride, planeStride).xy;
#else
                // Row-major pixel index -> 2D pixel coordinate.
                payload.rayIndex = uint2(rayQueueThread % screenSizeInts.x, rayQueueThread / screenSizeInts.x);
#endif

            }
            else
            {
                payload.rayIndex = DispatchRaysIndex().xy;
            }

            int adjustedMaxBounces = maxBounces;
            // NOTE(review): lastRay is only referenced by the commented-out code at the end of
            // the bounce loop.
            bool lastRay = false;
            // Latches to true once a bounce has been forced specular by low roughness; all
            // following bounces of the path then stay specular.
            bool continueSpecular = false;

            for (payload.i = 0; payload.i < adjustedMaxBounces; payload.i++)
            {
                if (payload.i == 0)
                {
                    // Primary ray: origin and direction come from the camera.
                    GenerateCameraRayHalton(payload.rayIndex, payload.indirectPos, rayDir, viewTransform, haltonSeq);
                }
                else
                {
                    // NOTE(review): the remap below presumably assumes GetRandomSample returns
                    // values in [-1, 1] - confirm.
                    float2 stochastic = GetRandomSample(payload.rayIndex, screenSize).xy;
                    stochastic = (stochastic + 1.0) / 2.0;

                    // Direction from the previous path vertex towards the current hit point.
                    float3 viewVector = normalize(payload.indirectPos - payload.previousPosition);
                    // Decide whether to sample diffuse or specular BRDF (based on Fresnel term)
                    float brdfProbability = getBrdfProbability(payload.albedo, payload.metallic, viewVector, payload.indirectNormal);

                    // When calculating the fresnel term we need to make the probability flipped for translucent geometry
                    // and shoot rays based on the brdf probability
                    bool isRefractiveRay = false;
                    if (payload.transmittance > 0.0 /*&& stochastic.y < brdfProbability*/)
                    {
                        isRefractiveRay = true;
                    }

                    // specular rays can only be launched 1-3 bounce index
                    if (stochastic.x < brdfProbability/* && payload.i < 3*/)
                    {
                        //throughput /= brdfProbability;
                        payload.diffuseRay = 0;
                    }
                    // diffuse ray
                    else
                    {
                        // NOTE(review): newRayDir and NdotL are computed here but never read;
                        // the only effect of this branch is setting payload.diffuseRay.
                        float3 newRayDir = float3(0.0, 0.0, 0.0);
                        if ((isRefractiveRay == true && reflectionOrRefraction == 2) ||
                            reflectionOrRefraction == 1)
                        {
                            newRayDir = -payload.indirectNormal;
                        }
                        else if ((isRefractiveRay == false && reflectionOrRefraction == 2) ||
                            reflectionOrRefraction == 0)
                        {
                            newRayDir = payload.indirectNormal;
                        }

                        float NdotL = max(0, dot(payload.indirectNormal, newRayDir));

                        //throughput /= (1.0f - brdfProbability);
                        payload.diffuseRay = 1;
                    }

                    // NOTE(review): brdfProbability is not read again after this assignment.
                    if (reflectionOrRefraction == 1 || reflectionOrRefraction == 0 ||
                        diffuseOrSpecular == 1 || diffuseOrSpecular == 0)
                    {
                        brdfProbability = 1.0;
                    }

                    // Constant-buffer overrides: 1 forces refraction, 0 forces reflection.
                    if (reflectionOrRefraction == 1)
                    {
                        isRefractiveRay = true;
                    }
                    else if (reflectionOrRefraction == 0)
                    {
                        isRefractiveRay = false;
                    }

                    // Constant-buffer overrides: 1 forces specular, 0 forces diffuse.
                    if (diffuseOrSpecular == 1)
                    {
                        payload.diffuseRay = 0;
                    }
                    else if (diffuseOrSpecular == 0)
                    {
                        payload.diffuseRay = 1;
                    }
                    
                    // A smooth surface (roughness <= 0.2) forces a specular bounce, and so does
                    // every later bounce once this has happened.
                    if (payload.roughness <= 0.2 || continueSpecular)
                    {
                        payload.diffuseRay = 0;
                        continueSpecular = true;

                    }

                        // NOTE(review): despite its indentation this statement is unconditional:
                        // the stored normal is negated for every secondary bounce.
                        payload.indirectNormal = normalize(-payload.indirectNormal);

                    if ((payload.diffuseRay == 1 && diffuseOrSpecular == 2) || diffuseOrSpecular == 0)
                    {
                        // Diffuse bounce.
                        // NOTE(review): pdf is never read.
                        float pdf = 0.0;
                        if ((isRefractiveRay == true && reflectionOrRefraction == 2) || reflectionOrRefraction == 1)
                        {
                            // Negate the normal a second time, undoing the flip above, before
                            // sampling a direction around it.
                            payload.indirectNormal = normalize(-payload.indirectNormal);
                            rayDir = normalize(GetRandomRayDirection(payload.rayIndex, payload.indirectNormal, screenSize, 0, payload.indirectPos));
                            //rayDir = normalize((indirectNormal + GetRandomRayDirection(threadId.xy, indirectNormal, screenSize, 0, indirectPos)));//indirectNormal;
                            //if (dot(indirectNormal, rayDir) < 0.0)
                            //{
                            //    rayDir = indirectNormal;
                            //}
                            //rayDir = randomDir(uv, indirectNormal, pdf);
                            //rayDir = indirectNormal;
                        }
                        else if ((isRefractiveRay == false && reflectionOrRefraction == 2) || reflectionOrRefraction == 0)
                        {
                            rayDir = normalize(GetRandomRayDirection(payload.rayIndex, payload.indirectNormal, screenSize, 0, payload.indirectPos));
                            //rayDir = normalize((indirectNormal + GetRandomRayDirection(threadId.xy, indirectNormal, screenSize, 0, indirectPos)));//indirectNormal;
                            //if (dot(indirectNormal, rayDir) < 0.0)
                            //{
                            //    rayDir = indirectNormal;
                            //}
                            //rayDir = randomDir(uv, indirectNormal, pdf);
                            //rayDir = indirectNormal;
                        }

                        // NOTE(review): this viewVector is not read in the diffuse branch.
                        float3 viewVector = normalize(payload.indirectPos - payload.previousPosition);
                        float NdotL = max(0, dot(payload.indirectNormal, rayDir));

                        float3 diffuseWeight = payload.albedo * (1.0 - payload.metallic);
                        // NdotL is for cosine weighted diffuse distribution
                        payload.throughput *= (diffuseWeight * NdotL);
                    }
                    else if ((payload.diffuseRay == 0 && diffuseOrSpecular == 2) || diffuseOrSpecular == 1)
                    {
                        // Specular
                        // NOTE(review): basis is only needed by the commented-out VNDF sampling below.
                        float3x3 basis = orthoNormalBasis(payload.indirectNormal);

                        // Sampling of normal distribution function to compute the reflected ray.
                        // See the paper "Sampling the GGX Distribution of Visible Normals" by E. Heitz,
                        // Journal of Computer Graphics Techniques Vol. 7, No. 4, 2018.
                        // http://jcgt.org/published/0007/04/01/paper.pdf

                        float3 viewVector = normalize(payload.indirectPos - payload.previousPosition);

                        float3 V = viewVector;
                        float3 R = reflect(viewVector, payload.indirectNormal);
                        // Tests perfect reflections
                        // (H is the half vector between the reversed incoming direction and its
                        // mirror reflection.)
                        float3 H = normalize(-V + R);
                        //float3 H = ImportanceSampleGGX_VNDF(stochastic, payload.roughness, V, basis);

                        if ((isRefractiveRay == true && reflectionOrRefraction == 2) || reflectionOrRefraction == 1)
                        {
                            float3 refractedRay = RefractionRay(payload.indirectNormal, V);
                            rayDir = refractedRay;
                            //rayDir = V;
                        }
                        else if ((isRefractiveRay == false && reflectionOrRefraction == 2) || reflectionOrRefraction == 0)
                        {
                            // VNDF reflection sampling
                            rayDir = reflect(V, H);
                        }

                        // For refraction the normal is negated once more, so that the weights
                        // and the ray origin offset below use the opposite side of the surface.
                        if ((isRefractiveRay == true && reflectionOrRefraction == 2) || reflectionOrRefraction == 1)
                        {
                            payload.indirectNormal = normalize(-payload.indirectNormal);
                        }

                        float3 N = payload.indirectNormal;

                        float NdotV = max(0, -dot(N, viewVector));
                        float NdotL = max(0, dot(N, rayDir));
                        float VoH = max(0, -dot(V, H));

                        // Base reflectance: 0.04 for dielectrics, blended towards albedo by metallic.
                        float3 F0 = float3(0.04f, 0.04f, 0.04f);
                        F0 = lerp(F0, payload.albedo, payload.metallic);
                        float3 F = FresnelSchlick(VoH, F0);

                        float3 specularWeight = F * Smith_G2_Over_G1_Height_Correlated(payload.roughness, payload.roughness * payload.roughness, NdotL, NdotV);
                        payload.throughput *= specularWeight;
                    }
                }

                float3 rayDirection = normalize(rayDir);

                // The current hit point becomes the previous path vertex for the next bounce.
                payload.previousPosition = payload.indirectPos;

                RayDesc ray;
                ray.TMin = MIN_RAY_LENGTH;
                ray.TMax = MAX_RAY_LENGTH;

                // Offset the origin along the (possibly flipped) normal by a fixed 0.001.
                ray.Origin = payload.indirectPos + (payload.indirectNormal * 0.001);

                ray.Direction = rayDirection;

                // Record the ray so that the freezeRays mode can replay it; same slot layout
                // as in the replay path (maxBounces consecutive rays per pixel, row-major).
                uint rayDescBufferIndex = ((payload.rayIndex.x + (payload.rayIndex.y * screenSizeInts.x)) * maxBounces) + payload.i;

                rayDescBuffer[rayDescBufferIndex] = ray;

#if ENABLE_SER
                if (enableSER == false)
                {
#endif
                    TraceRay(rtAS, RAY_FLAG_SKIP_PROCEDURAL_PRIMITIVES | RAY_FLAG_FORCE_OPAQUE, ~0, 0, 0, 0, ray, payload);
#if ENABLE_SER
                }
                else
                {
                    // With SER the first bounce is traced normally; later bounces are traced
                    // into a hit object, the threads are reordered, and only then is shading
                    // invoked.
                    if (payload.i == 0)
                    {
                        TraceRay(rtAS, RAY_FLAG_SKIP_PROCEDURAL_PRIMITIVES | RAY_FLAG_FORCE_OPAQUE, ~0, 0, 0, 0, ray, payload);
                    }
                    else
                    {
                        NvHitObject hitObject;
                        NvTraceRayHitObject(rtAS, RAY_FLAG_SKIP_PROCEDURAL_PRIMITIVES | RAY_FLAG_FORCE_OPAQUE, ~0, 0, 0, 0, ray, payload, hitObject);
                
                        // Sort based on hit group, instance and geometry index primarily
                        NvReorderThread(hitObject, 0, 0);
                
                        //float3 hitPosition = hitObject.GetWorldRayOrigin() + hitObject.GetClosestRayT() * hitObject.GetWorldRayDirection();
                
                        //Morton code based on position
                        //uint32 sortKey = EncodeMorton3(hitPosition.x, hitPosition.y, hitPosition.z);
                
                        //NvReorderThread(hitObject, 32);
                
                        //// Sort based on hit material to improve texture access
                        //int geometryIndex = hitObject.GetGeometryIndex();
                        //int primitiveIndex = hitObject.GetPrimitiveIndex();
                        //int instanceIndex = hitObject.GetInstanceIndex();
                        //
                        //int materialIndex = instanceIndexToMaterialMapping[instanceIndexOffsetMapping[instanceIndex] + geometryIndex];
                        //int attributeIndex = instanceIndexToAttributesMapping[instanceIndexOffsetMapping[instanceIndex] + geometryIndex];
                        //
                        //int sortKey = attributeIndex;
                        //NvReorderThread(sortKey, 16);
                
                        //uint geomBits = (geometryIndex << 24) & 0xFF000000;
                        //uint primBits = (primitiveIndex << 16) & 0x00FF0000;
                        //uint rayQueueThreadBits = rayQueueThread & 0xFFFF;
                        ////uint rayQueueThreadBits = (rayQueueThread << 16) & 0xFFFF0000;
                        ////uint geomBits = (geometryIndex << 8) & 0x0000FF00;
                        ////uint primBits = primitiveIndex & 0xFF;
                        //
                        //int sortKey = geomBits | primBits | rayQueueThreadBits;
                        //NvReorderThread(sortKey, 32);
                
                        NvInvokeHitObject(rtAS, hitObject, payload);
                    }
                }
#endif

                // The payload flag ends the path early.
                // NOTE(review): presumably set by the hit/miss shaders - it is not set in this function.
                if (payload.isEarlyOut == 1)
                {
                    break;
                }

                //if ((payload.transmittance > 0.0 || payload.roughness < 0.1) && payload.i == adjustedMaxBounces)
                //{
                //    adjustedMaxBounces++;
                //}
                //else if(lastRay == false)
                //{
                //    lastRay = true;
                //    // Need to sample the diffuse material once
                //    adjustedMaxBounces++;
                //}
            }

            // Specular signal is written for render modes 1, 2 and 4.
            if (renderMode == 1 || renderMode == 2 || renderMode == 4)
            {
                indirectSpecularLightRaysUAV[payload.rayIndex] = payload.nrdSpecular;
            }

            // Diffuse signal is written for render modes 0, 2 and 3.
            if (renderMode == 0 || renderMode == 2 || renderMode == 3)
            {
                indirectLightRaysUAV[payload.rayIndex] = payload.nrdDiffuse;
            }

            // Without a recorded primary-surface albedo the demodulation factor is the identity.
            if (payload.grabbedPrimarySurfaceDemodulator == 0)
            {
                diffusePrimarySurfaceModulation[payload.rayIndex] = float4(1.0, 1.0, 1.0, 1.0);
            }
            else
            {
                diffusePrimarySurfaceModulation[payload.rayIndex] = float4(payload.diffuseAlbedoDemodulation.xyz, 1.0);
            }

            // Without the ray queue each thread handles exactly one pixel.
            if (enableRayQueueSystem == false)
            {
                break;
            }
        }
    }
}

// Any-hit shader. Currently a no-op: the transparency test below is commented out, so every
// candidate hit is accepted unchanged. Every TraceRay / NvTraceRayHitObject call in
// ReflectionRaygen passes RAY_FLAG_FORCE_OPAQUE.
[shader("anyhit")] void ReflectionAnyHit(inout Payload                               payload,
                                         in    BuiltInTriangleIntersectionAttributes attr)
{
    //RayTraversalData rayData;
    //rayData.worldRayOrigin    = WorldRayOrigin();
    //rayData.currentRayT       = RayTCurrent();
    //// Any-hit invocation dictates that the previously accepted non-opaque triangle is farther
    //// away than the current one
    //rayData.closestRayT       = RayTCurrent() + 1.0;
    //rayData.worldRayDirection = WorldRayDirection();
    //rayData.geometryIndex     = GeometryIndex();
    //rayData.primitiveIndex    = PrimitiveIndex();
    //rayData.instanceIndex     = InstanceIndex();
    //rayData.barycentrics      = attr.barycentrics;
    //rayData.objectToWorld     = ObjectToWorld4x3();

    //bool isHit = ProcessTransparentTriangle(rayData);
    //if (isHit == false)
    //{
    //    IgnoreHit();
    //}
}

[shader("closesthit")] void ReflectionClosestHit(inout Payload                               payload,
                                                 in    BuiltInTriangleIntersectionAttributes attr)
{

    if (freezeRays)
    {
        float3 color = float3(((float)payload.i) / ((float)maxBounces), ((float)payload.i) / ((float)maxBounces), ((float)payload.i) / ((float)maxBounces));
        if (payload.i == 0)
        {
            color = float3(1.0, 0.0, 0.0);
        }
        else if (payload.i == 1)
        {
            color = float3(0.0, 1.0, 0.0);
        }
        else if (payload.i == 2)
        {
            color = float3(0.0, 0.0, 1.0);
        }
        else if (payload.i == 3)
        {
            color = float3(1.0, 1.0, 1.0);
        }

        payload.nrdDiffuse = float4(color.xyz, 1.0);
        payload.grabbedPrimarySurfaceDemodulator = 1;
        payload.diffuseAlbedoDemodulation = color.xyz;
    }
    else
    {
        RayTraversalData rayData;
        rayData.worldRayOrigin    = WorldRayOrigin();
        rayData.closestRayT       = RayTCurrent();
        rayData.worldRayDirection = WorldRayDirection();
        rayData.geometryIndex     = GeometryIndex();
        rayData.primitiveIndex    = PrimitiveIndex();
        rayData.instanceIndex     = InstanceIndex();
        rayData.barycentrics      = attr.barycentrics;
        rayData.objectToWorld     = ObjectToWorld4x3();
        rayData.uvIsValid         = false;
        rayData.enableMipCalculation = enableMipCalculation;
        rayData.isFrontFace       = HitKind();

        //int    geometryIndex = rayData.geometryIndex;
        //int    primitiveIndex = rayData.primitiveIndex;
        //int    instanceIndex = rayData.instanceIndex;
        //float2 barycentrics = rayData.barycentrics;
        //
        //int materialIndex = instanceIndexToMaterialMapping[instanceIndexOffsetMapping[instanceIndex] + geometryIndex];
        //
        //uint width = 0;
        //uint height = 0;
        //uint numMips = 0;
        //diffuseTexture[NonUniformResourceIndex(materialIndex)].GetDimensions(0, width, height, numMips);
        //uint2 texDimensions = uint2(width, height);

        //rayData.uvJitter = GetRandomSample(DispatchRaysIndex().xy, screenSize).xy / (texDimensions /** 4*/);
        //rayData.uvJitter = GetRandomSample(DispatchRaysIndex().xy, screenSize).xy * uvJitterScale;
        rayData.uvJitter = GetBlueNoise(DispatchRaysIndex().xy, frameNumber) * uvJitterScale;

        float3 emissiveColor = float3(0.0, 0.0, 0.0);

        ProcessOpaqueTriangle(rayData, payload.albedo, payload.roughness, payload.metallic, payload.indirectNormal, payload.indirectPos,
            payload.transmittance, emissiveColor);

        emissiveColor *= enableEmissives ? 10.0 : 0.0;

        float3 accumulatedLightRadiance = float3(0.0, 0.0, 0.0);
        float3 accumulatedDiffuseRadiance = float3(0.0, 0.0, 0.0);
        float3 accumulatedSpecularRadiance = float3(0.0, 0.0, 0.0);

        if (rayData.isFrontFace == HIT_KIND_TRIANGLE_BACK_FACE)
        {
            payload.indirectNormal = -payload.indirectNormal;
        }

        int pointLightCount = 0;
        for (int lightIndex = 0; lightIndex < numLights; lightIndex++)
        {
            float3 lightRadiance            = float3(0.0, 0.0, 0.0);
            float3 indirectDiffuseRadiance  = float3(0.0, 0.0, 0.0);
            float3 indirectSpecularRadiance = float3(0.0, 0.0, 0.0);

            float3 indirectLighting = GetBRDFLight(payload.albedo, payload.indirectNormal, payload.indirectPos, payload.roughness, payload.metallic, payload.rayIndex, payload.previousPosition,
                                      lightPositions[lightIndex].xyz, isPointLight[lightIndex/4][lightIndex%4], lightRanges[lightIndex/4][lightIndex%4], lightColors[lightIndex].xyz,
                                      indirectDiffuseRadiance, indirectSpecularRadiance, lightRadiance);

            // bug fix for light leaking
            if (length(lightRadiance) > 0.0)
            {
                accumulatedLightRadiance += lightRadiance;
                accumulatedDiffuseRadiance += indirectDiffuseRadiance;
                accumulatedSpecularRadiance += indirectSpecularRadiance;
            }
        }

        // Primary surface recording for denoiser
        if (payload.i == 0)
        {
            payload.diffuseAlbedoDemodulation        = payload.albedo/* * (1.0 - payload.transmittance)*/;
            payload.grabbedPrimarySurfaceDemodulator = 1;

            normalUAV[payload.rayIndex.xy].xyz   = (-payload.indirectNormal + 1.0) / 2.0;
            positionUAV[payload.rayIndex.xy].xyz = payload.indirectPos;
            albedoUAV[payload.rayIndex.xy].xyz = payload.albedo;

            // Denoiser can't handle roughness value of 0.0
            normalUAV[payload.rayIndex.xy].w   = payload.roughness;
            positionUAV[payload.rayIndex.xy].w = rayData.instanceIndex;// +rayData.geometryIndex;
            albedoUAV[payload.rayIndex.xy].w   = payload.metallic;

            viewZUAV[payload.rayIndex.xy].x = mul(float4(payload.indirectPos, 1.0), viewTransform).z;
        }

        if (payload.i == 0)
        {
            float3 light = float3(0.0, 0.0, 0.0);

            light = (accumulatedSpecularRadiance + accumulatedDiffuseRadiance) *
                    accumulatedLightRadiance * payload.throughput;

            if (length(emissiveColor) > 0.0)
            {
                // Account for emissive surfaces
                light += payload.throughput * emissiveColor;
            }

            float sampleWeight = NRD_GetSampleWeight(light, USE_SANITIZATION);

            if (payload.i == rayBounceIndex || rayBounceIndex == -1)
            {
                payload.nrdDiffuse +=
                    REBLUR_FrontEnd_PackRadianceAndHitDist(light, 0, USE_SANITIZATION) *
                    sampleWeight;
            }
        }
        else
        {
            if (payload.diffuseRay == 1 || payload.roughness > 0.2)
            {
                float3 light = float3(0.0, 0.0, 0.0);

                light = (accumulatedSpecularRadiance + accumulatedDiffuseRadiance) *
                        accumulatedLightRadiance * payload.throughput;

                if (length(emissiveColor) > 0.0)
                {
                    // Account for emissive surfaces
                    light += payload.throughput * emissiveColor;
                }

                float sampleWeight = NRD_GetSampleWeight(light, USE_SANITIZATION);

                if (payload.i == rayBounceIndex || rayBounceIndex == -1)
                {

                    payload.path += NRD_GetCorrectedHitDist(rayData.closestRayT, payload.i, 1.0);

                    float normDist = REBLUR_FrontEnd_GetNormHitDist(payload.path,
                                                                    viewZUAV[payload.rayIndex.xy].x,
                                                                    diffHitDistParams);

                    payload.nrdDiffuse += REBLUR_FrontEnd_PackRadianceAndHitDist(light, normDist,
                                                                            USE_SANITIZATION) *
                            sampleWeight;
                }
            }
            else
            {
                float3 light = float3(0.0, 0.0, 0.0);
               
                light = (accumulatedSpecularRadiance + accumulatedDiffuseRadiance) *
                        accumulatedLightRadiance * payload.throughput;

                if (length(emissiveColor) > 0.0)
                {
                    // Account for emissive surfaces
                    light += payload.throughput * emissiveColor;
                }

                float sampleWeight = NRD_GetSampleWeight(light, USE_SANITIZATION);

                if (payload.i == rayBounceIndex || rayBounceIndex == -1)
                {
                    payload.path += NRD_GetCorrectedHitDist(rayData.closestRayT, payload.i, payload.roughness);
                    float normDist = REBLUR_FrontEnd_GetNormHitDist(payload.path,
                                                                    viewZUAV[payload.rayIndex.xy].x,
                                                                    specHitDistParams);

                    payload.nrdSpecular += REBLUR_FrontEnd_PackRadianceAndHitDist(light, normDist,
                                                                            USE_SANITIZATION) *
                        sampleWeight;
                        
                }
            }
        }
    }
    payload.isEarlyOut = 0;
}

// Miss shader for the reflection/indirect rays.
//
// Samples the skybox along the ray direction and records it either as the
// primary "surface" (payload.i == 0) or as incoming radiance for a later
// bounce. When freezeRays is set nothing is recorded. payload.isEarlyOut is
// always set to 1 on exit.
[shader("miss")]
void ReflectionMiss(inout Payload payload)
{
    if (!freezeRays)
    {
        // Sky colour: texel clamped to 10.0 per channel, then scaled by 25.
        float3 skyDirection = normalize(WorldRayDirection());
        float4 skyColor     = min(skyboxTexture.SampleLevel(bilinearWrap, skyDirection, 0), 10.0);
        skyColor *= 25.0;

        // With image based lighting disabled the sky contributes nothing.
        if (enableIBL == false)
        {
            skyColor *= 0.0;
        }

        // Only the bounce selected by rayBounceIndex is accumulated; -1 selects every bounce.
        bool recordBounce = (payload.i == rayBounceIndex || rayBounceIndex == -1);

        if (payload.i == 0)
        {
            // The primary ray left the scene: the sky colour becomes the demodulation
            // albedo and the accumulated diffuse signal is reset to all ones.
            payload.diffuseAlbedoDemodulation        = skyColor.xyz;
            payload.grabbedPrimarySurfaceDemodulator = 1;

            if (recordBounce)
            {
                payload.nrdDiffuse = float4(1.0, 1.0, 1.0, 1.0);
            }

            // G-buffer values for a sky pixel. The closest-hit path stores the instance
            // index in positionUAV.w, so -1 presumably marks "no instance" - TODO confirm.
            positionUAV[payload.rayIndex.xy] = float4(0.0, 0.0, 0.0, -1.0);
            normalUAV[payload.rayIndex.xy]   = float4(0.0, 0.0, 0.0, 1.0);
            albedoUAV[payload.rayIndex.xy]   = float4(skyColor.xyz, 0.0);

            // Large constant written as the view depth of a sky pixel.
            viewZUAV[payload.rayIndex.xy].x = 1e7f;
        }
        else
        {
            // The path length grows on every secondary miss, whether or not this bounce
            // is recorded. 1e7f is the hit distance used for a miss.
            // NOTE(review): the closest-hit shader passes 1.0 as roughness for its diffuse
            // branch and scales by NRD_GetSampleWeight; neither happens here - confirm intent.
            payload.path += NRD_GetCorrectedHitDist(1e7f, payload.i, payload.roughness);

            if (recordBounce)
            {
                float  primaryViewZ = viewZUAV[payload.rayIndex.xy].x;
                float3 skyRadiance  = skyColor.xyz * payload.throughput;

                // Diffuse rays and rays with roughness above 0.2 accumulate into the
                // diffuse signal; everything else goes to the specular signal.
                bool useDiffuseSignal = (payload.diffuseRay == 1 || payload.roughness > 0.2);

                if (useDiffuseSignal)
                {
                    float normDist = REBLUR_FrontEnd_GetNormHitDist(payload.path, primaryViewZ, diffHitDistParams);

                    payload.nrdDiffuse += REBLUR_FrontEnd_PackRadianceAndHitDist(skyRadiance, normDist,
                                                                                 USE_SANITIZATION);
                }
                else
                {
                    float normDist = REBLUR_FrontEnd_GetNormHitDist(payload.path, primaryViewZ, specHitDistParams);

                    payload.nrdSpecular += REBLUR_FrontEnd_PackRadianceAndHitDist(skyRadiance, normDist,
                                                                                  USE_SANITIZATION);
                }
            }
        }
    }

    // Set on every miss, including when rays are frozen.
    payload.isEarlyOut = 1;
}

// Shadow occlusion hit group and miss shader
// Any-hit shader for shadow rays.
//
// Gathers the candidate hit into a RayTraversalData record and lets
// ProcessTransparentTriangle decide whether it counts. An accepted hit sets
// payload.occlusion to 1; a rejected one sets it to 0 and is discarded with
// IgnoreHit().
[shader("anyhit")]
void ShadowAnyHit(inout ShadowPayload payload,
                  in BuiltInTriangleIntersectionAttributes attr)
{
    RayTraversalData candidate;

    // Ray description.
    candidate.worldRayOrigin    = WorldRayOrigin();
    candidate.worldRayDirection = WorldRayDirection();

    // Hit distances. An any-hit invocation implies that any previously accepted
    // non-opaque triangle is farther away than this candidate, so the closest T
    // is set just beyond the current one.
    candidate.currentRayT = RayTCurrent();
    candidate.closestRayT = RayTCurrent() + 1.0;

    // Identification of the triangle that was hit.
    candidate.instanceIndex  = InstanceIndex();
    candidate.geometryIndex  = GeometryIndex();
    candidate.primitiveIndex = PrimitiveIndex();
    candidate.barycentrics   = attr.barycentrics;
    candidate.objectToWorld  = ObjectToWorld4x3();

    if (ProcessTransparentTriangle(candidate))
    {
        // Returning normally from an any-hit shader accepts the hit.
        payload.occlusion = 1;
        return;
    }

    // Rejected: record "not occluded" and drop the candidate hit.
    payload.occlusion = 0;
    IgnoreHit();
}

// Closest-hit shader for shadow rays: a committed hit marks the ray as occluded.
// The intersection attributes are not used.
[shader("closesthit")]
void ShadowClosestHit(inout ShadowPayload payload,
                      in BuiltInTriangleIntersectionAttributes attr)
{
    // 1 == occluded.
    payload.occlusion = 1;
}

// Miss shader for shadow rays: a ray that hits nothing is marked as not occluded.
[shader("miss")]
void ShadowMiss(inout ShadowPayload payload)
{
    // 0 == not occluded.
    payload.occlusion = 0;
}