#include "../include/structs.hlsl"
#include "../include/math.hlsl"
#include "../include/dxr1_0_defines.hlsl"
#define COMPILER_DXC 1
#include "../../hlsl/include/NRD.hlsli"

#if ENABLE_SER
    #define NV_HITOBJECT_USE_MACRO_API

    #define NV_SHADER_EXTN_SLOT           u127       // matches slot number in NvAPI_D3D12_SetNvShaderExtnSlotSpace 
    #define NV_SHADER_EXTN_REGISTER_SPACE space0   // matches space number in NvAPI_D3D12_SetNvShaderExtnSlotSpace 

    #include "../../../../libs/NVAPI/nvHLSLExtns.h"
#endif

// Top-level acceleration structure traversed by every TraceRay / NvTraceRayHitObject call in this file.
RaytracingAccelerationStructure             rtAS                             : register(t0, space0);

// Receives payload.nrdDiffuse from PrimaryRaygen when renderMode is 0, 2 or 3.
RWTexture2D<float4>         indirectLightRaysUAV : register(u0);
// Receives payload.nrdSpecular from PrimaryRaygen when renderMode is 1, 2 or 4.
RWTexture2D<float4>         indirectSpecularLightRaysUAV : register(u1);
// Per-pixel diffuse demodulation factor; PrimaryRaygen writes white (1,1,1,1) when no hit supplied one.
RWTexture2D<float4>         diffusePrimarySurfaceModulation : register(u2);
// Pre-built rays: maxBounces consecutive entries per pixel, pixels in row-major order,
// i.e. index = (x + y * uint(screenSize.x)) * maxBounces + bounce. Only read in this file.
RWStructuredBuffer<RayDesc> rayDescBuffer : register(u3);

// NOTE(review): not sampled anywhere in this file.
SamplerState bilinearWrap : register(s0);

// NOTE(review): not referenced in this file, and defined after every #include above, so any
// `#if USE_SANITIZATION` check inside those headers cannot see it — confirm it is still needed.
#define USE_SANITIZATION 1

// Per-dispatch constants bound at b0.
// NOTE(review): field order and types presumably mirror a host-side struct under HLSL cbuffer
// packing rules — keep both in sync when editing.
cbuffer globalData : register(b0)
{
    // GenerateCameraRay reads the eye position from row 3 of this matrix.
    float4x4 inverseView;
    float4x4 viewTransform;

    float4 lightColors[MAX_LIGHTS];
    float4 lightPositions[MAX_LIGHTS];
    float4 lightRanges[MAX_LIGHTS / 4];
    uint4  isPointLight[MAX_LIGHTS / 4];
    uint   numLights;

    // Render size in pixels. PrimaryRaygen truncates .x to uint and uses it as the row pitch
    // when indexing rayDescBuffer.
    float2 screenSize;

    uint seed;
    uint numSamplesPerSet;
    uint numSampleSets;
    uint numPixelsPerDimPerSet;
    uint texturesPerMaterial;

    uint resetHistoryBuffer;
    uint frameNumber;

    // Upper bound on bounces traced by PrimaryRaygen, and the per-pixel stride of rayDescBuffer.
    uint maxBounces;

    // Field of view in degrees (GenerateCameraRay converts it with PI / 180).
    float fov;

    // enum class renderMode
    //{
    //    DIFFUSE_DENOISED = 0,
    //    SPECULAR_DENOISED = 1,
    //    BOTH_DIFFUSE_AND_SPECULAR = 2,
    //    DIFFUSE_RAW = 3,
    //    SPECULAR_RAW = 4
    //};

    // Selects which outputs PrimaryRaygen writes (values per the enum above).
    // Default to both specular and diffuse — the default is chosen by the host, not here.
    int renderMode;
    int rayBounceIndex;

    int diffuseOrSpecular;
    int reflectionOrRefraction;
    bool enableEmissives;
    bool enableIBL;
    bool enableRayQueueSystem;
    bool freezeRays;
    // Only consulted when this file is compiled with ENABLE_SER.
    bool enableSER;
    bool enableMipCalculation;
}

//#include "../include/sunLightCommon.hlsl"
//#include "../include/utils.hlsl"

// Pinhole camera: every primary ray emanates from the eye position.
// Builds the ray through the centre of pixel `index`, using the vertical field of
// view `fov` (degrees) and the aspect ratio of screenSize.
// NOTE(review): `origin` is taken from the global inverseView while `direction` is
// derived from the viewTransform argument; these agree only if both describe the same
// camera — confirm. Currently unused (its call in PrimaryRaygen is commented out).
void GenerateCameraRay(uint2 index, out float3 origin, out float3 direction, in float4x4 viewTransform)
{
    origin = float3(inverseView[3][0], inverseView[3][1], inverseView[3][2]);

    const float tanHalfFov = tan(fov / 2.0 * PI / 180.0);
    const float aspect     = screenSize.x / screenSize.y;

    // Pixel centre mapped to [-1, 1]; y is flipped so row 0 lands at the top (+1).
    const float ndcX = 2.0 * ((index.x + 0.5) / screenSize.x) - 1.0;
    const float ndcY = 1.0 - 2.0 * ((index.y + 0.5) / screenSize.y);

    // Point on the z = 1 image plane, and the eye at the view-space origin.
    const float4 imagePlanePoint = float4(ndcX * tanHalfFov * aspect, ndcY * tanHalfFov, 1.0, 1.0);
    const float4 viewSpaceEye    = float4(0, 0, 0, 1.0);

    const float3 eyeWorld    = mul(viewTransform, viewSpaceEye).xyz;
    const float3 targetWorld = mul(viewTransform, imagePlanePoint).xyz;
    direction = normalize(targetWorld - eyeWorld);
}

// Ray-generation entry point.
// For the current pixel, replays up to maxBounces pre-built rays from rayDescBuffer.
// The Payload carries state between bounces, and the loop stops early once a hit or
// miss shader sets isEarlyOut to 1. Afterwards the payload's NRD signals and diffuse
// demodulation factor are written to the UAVs selected by renderMode.
[shader("raygeneration")]
void PrimaryRaygen()
{
    const uint2 pixel = DispatchRaysIndex().xy;

    // rayDescBuffer holds maxBounces consecutive rays per pixel, pixels in row-major order.
    const uint rowPitch  = uint(screenSize.x);
    const uint pixelBase = (pixel.x + pixel.y * rowPitch) * maxBounces;

    const uint traceRayFlags = RAY_FLAG_SKIP_PROCEDURAL_PRIMITIVES | RAY_FLAG_FORCE_OPAQUE;

    // Start from an empty radiance signal, unit throughput and every other field zeroed.
    Payload payload;
    payload.nrdSpecular = REBLUR_FrontEnd_PackRadianceAndHitDist(float3(0.0, 0.0, 0.0), 0.0, 0);
    payload.nrdDiffuse  = REBLUR_FrontEnd_PackRadianceAndHitDist(float3(0.0, 0.0, 0.0), 0.0, 0);
    payload.throughput  = float3(1.0, 1.0, 1.0);

    payload.indirectPos               = float3(0.0, 0.0, 0.0);
    payload.indirectNormal            = float3(0.0, 0.0, 0.0);
    payload.previousPosition          = float3(0.0, 0.0, 0.0);
    payload.albedo                    = float3(0.0, 0.0, 0.0);
    payload.diffuseAlbedoDemodulation = float3(0.0, 0.0, 0.0);

    payload.transmittance = 0.0;
    payload.metallic      = 0.0;
    payload.roughness     = 0.0;
    payload.path          = 0.0;

    payload.isEarlyOut                       = 0;
    payload.i                                = 0;
    payload.diffuseRay                       = 0;
    payload.grabbedPrimarySurfaceDemodulator = 0;
    payload.occlusion                        = 0;
    payload.rayIndex                         = pixel;

    for (payload.i = 0; payload.i < maxBounces; payload.i++)
    {
        RayDesc bounceRay = rayDescBuffer[pixelBase + payload.i];

        payload.indirectPos      = bounceRay.Origin;
        payload.previousPosition = payload.indirectPos;

#if ENABLE_SER
        // With SER on, every bounce after the first is traced into a hit object,
        // threads are reordered by that hit, and only then are the hit shaders run.
        // The first bounce always falls through to a plain TraceRay.
        if (enableSER && payload.i != 0)
        {
            NvHitObject hitObject;
            NvTraceRayHitObject(rtAS, traceRayFlags, ~0, 0, 0, 0, bounceRay, payload, hitObject);
            NvReorderThread(hitObject, 0, 0);
            NvInvokeHitObject(rtAS, hitObject, payload);
        }
        else
#endif
        {
            TraceRay(rtAS, traceRayFlags, ~0, 0, 0, 0, bounceRay, payload);
        }

        if (payload.isEarlyOut == 1)
        {
            break;
        }
    }

    // renderMode: 0 diffuse denoised, 1 specular denoised, 2 both, 3 diffuse raw, 4 specular raw.
    const bool writeSpecular = (renderMode == 1) || (renderMode == 2) || (renderMode == 4);
    const bool writeDiffuse  = (renderMode == 0) || (renderMode == 2) || (renderMode == 3);

    if (writeSpecular)
    {
        indirectSpecularLightRaysUAV[payload.rayIndex] = payload.nrdSpecular;
    }
    if (writeDiffuse)
    {
        indirectLightRaysUAV[payload.rayIndex] = payload.nrdDiffuse;
    }

    // Neutral white unless some hit supplied a demodulation factor.
    const float4 demodulation = (payload.grabbedPrimarySurfaceDemodulator == 0)
                                    ? float4(1.0, 1.0, 1.0, 1.0)
                                    : float4(payload.diffuseAlbedoDemodulation.xyz, 1.0);
    diffusePrimarySurfaceModulation[payload.rayIndex] = demodulation;
}

// Any-hit shader for candidate triangle hits.
// Collects the candidate's traversal data into a RayTraversalData. The transparency test
// that would consume it (ProcessTransparentTriangle + IgnoreHit) is commented out, so as
// written every candidate is accepted and rayData is unused.
// NOTE(review): every TraceRay / NvTraceRayHitObject call in this file passes
// RAY_FLAG_FORCE_OPAQUE, so these traces never invoke this shader.
[shader("anyhit")] void PrimaryAnyHit(inout Payload                               payload,
                                      in    BuiltInTriangleIntersectionAttributes attr)
{
    RayTraversalData rayData;
    rayData.worldRayOrigin    = WorldRayOrigin();
    rayData.currentRayT       = RayTCurrent();
    // Any-hit invocation implies that the previously accepted non-opaque triangle (if any)
    // is farther away than the current candidate, so closestRayT is set past RayTCurrent().
    rayData.closestRayT       = RayTCurrent() + 1.0;
    rayData.worldRayDirection = WorldRayDirection();
    rayData.geometryIndex     = GeometryIndex();
    rayData.primitiveIndex    = PrimitiveIndex();
    rayData.instanceIndex     = InstanceIndex();
    rayData.barycentrics      = attr.barycentrics;
    rayData.objectToWorld     = ObjectToWorld4x3();

    // Disabled alpha/transparency test; re-enable to reject transparent candidates.
    /*bool isHit = ProcessTransparentTriangle(rayData);
    if (isHit == false)
    {
        IgnoreHit();
    }*/
}

// Closest-hit shader: colours the hit by bounce index (payload.i).
// Bounces 0-3 map to red, green, blue and white. Any later bounce gets a grey ramp
// of payload.i / maxBounces.
// NOTE(review): looks like a debug visualisation. nrdDiffuse is stored as a raw float4
// rather than via REBLUR_FrontEnd_PackRadianceAndHitDist. The "primary surface"
// demodulator is also overwritten on every hit, so the last bounce's colour wins.
[shader("closesthit")] void PrimaryClosestHit(inout Payload                               payload,
                                              in    BuiltInTriangleIntersectionAttributes attr)
{
    const float  ramp    = ((float)payload.i) / ((float)maxBounces);
    const float3 rampRgb = float3(ramp, ramp, ramp);

    const float3 bounceColor =
        (payload.i == 0) ? float3(1.0, 0.0, 0.0) :
        (payload.i == 1) ? float3(0.0, 1.0, 0.0) :
        (payload.i == 2) ? float3(0.0, 0.0, 1.0) :
        (payload.i == 3) ? float3(1.0, 1.0, 1.0) :
                           rampRgb;

    // Clearing isEarlyOut lets PrimaryRaygen continue to the next bounce.
    payload.isEarlyOut                       = 0;
    payload.grabbedPrimarySurfaceDemodulator = 1;
    payload.diffuseAlbedoDemodulation        = bounceColor;
    payload.nrdDiffuse                       = float4(bounceColor, 1.0);
}

// Miss shader: the ray left the scene without hitting anything. It flags the path
// as finished; PrimaryRaygen checks isEarlyOut after each bounce and stops looping.
[shader("miss")] void PrimaryMiss(inout Payload payload)
{
    payload.isEarlyOut = 1;
}
