#include "../../include/structs.hlsl"
#include "../../include/dxr1_1_defines.hlsl"

// ---------------------------------------------------------------------------
// Resource bindings.
// Every register below lives in space0 (the default space when none is given),
// written out explicitly so the layout reads uniformly.
// ---------------------------------------------------------------------------

// Scene acceleration structure traced by the inline RayQuery in main().
RaytracingAccelerationStructure rtAS : register(t0, space0);

// Inputs for the full hit/miss shading path. They are disabled while
// ProcessOpaqueTriangle and the skybox lookup in main() are commented out.
//Texture2D                                   diffuseTexture[]                 : register(t1, space1);
//StructuredBuffer<CompressedAttribute>       vertexBuffer[]                   : register(t2, space2);
//Buffer<uint>                                indexBuffer[]                    : register(t3, space3);
//Buffer<uint>                                instanceIndexToMaterialMapping   : register(t4, space0);
//Buffer<uint>                                instanceIndexToAttributesMapping : register(t5, space0);
//Buffer<float>                               instanceNormalMatrixTransforms   : register(t6, space0);
//StructuredBuffer<UniformMaterial>           uniformMaterials                 : register(t7, space0);
//StructuredBuffer<AlignedHemisphereSample3D> sampleSets                       : register(t8, space0);
//TextureCube                                 skyboxTexture                    : register(t9, space0);
//Buffer<uint>                                instanceIndexOffsetMapping       : register(t10, space0);

// Per-pixel outputs of main(), indexed by SV_DispatchThreadID.xy.
RWTexture2D<float4>         indirectLightRaysUAV            : register(u0, space0);
RWTexture2D<float4>         indirectSpecularLightRaysUAV    : register(u1, space0);
RWTexture2D<float4>         diffusePrimarySurfaceModulation : register(u2, space0);
// main() stores the traced RayDesc into element 0.
RWStructuredBuffer<RayDesc> rayDescBuffer                   : register(u3, space0);

// Only referenced by the commented-out skybox lookup in main().
SamplerState bilinearWrap : register(s0, space0);

// Not referenced elsewhere in this file.
#define USE_SANITIZATION 1

// Per-dispatch constants bound at b0.
// Members are packed under HLSL constant-buffer packing rules (16-byte
// registers; bool occupies 4 bytes). The host-side struct that fills this
// buffer presumably mirrors this exact order and packing, so do not reorder
// or retype members without updating the upload code (TODO confirm).
// In this file only renderMode is read, by main().
cbuffer globalData : register(b0)
{
    float4x4 inverseView;
    float4x4 viewTransform;

    // Per-light arrays. lightRanges/isPointLight are sized MAX_LIGHTS / 4,
    // presumably packing four lights' values per float4/uint4 (TODO confirm).
    float4 lightColors[MAX_LIGHTS];
    float4 lightPositions[MAX_LIGHTS];
    float4 lightRanges[MAX_LIGHTS / 4];
    uint4  isPointLight[MAX_LIGHTS / 4];
    uint   numLights;

    float2 screenSize;

    uint seed;
    uint numSamplesPerSet;
    uint numSampleSets;
    uint numPixelsPerDimPerSet;
    uint texturesPerMaterial;

    uint resetHistoryBuffer;
    uint frameNumber;

    uint maxBounces;

    float fov;

    // Values of renderMode, mirroring the host-side enum:
    // enum class renderMode
    //{
    //    DIFFUSE_DENOISED = 0,
    //    SPECULAR_DENOISED = 1,
    //    BOTH_DIFFUSE_AND_SPECULAR = 2,
    //    DIFFUSE_RAW = 3,
    //    SPECULAR_RAW = 4
    //};

    // Selects which indirect-lighting UAVs main() writes.
    // NOTE(review): the host presumably defaults this to 2 (both specular
    // and diffuse); confirm on the CPU side.
    int renderMode;
    int rayBounceIndex;

    int diffuseOrSpecular;
    int reflectionOrRefraction;
    bool enableEmissives;
    bool enableIBL;
    bool enableRayQueueSystem;
    bool freezeRays;
    bool enableSER;
    bool enableMipCalculation;
}

//#include "../../include/utils.hlsl"

// Reflection/refraction split; by construction the two values sum to 1.
// These are never written, so they are declared const: the compiler folds
// them and rejects accidental writes.
// NOTE(review): neither is referenced in this file. Confirm they are still
// needed, or remove them.
static const float reflectionIndex = 0.5;
static const float refractionIndex = 1.0 - reflectionIndex;

// Compute entry point. It traces one inline ray per dispatch thread against rtAS
// and writes the indirect-lighting outputs for pixel threadId.xy.
//
// NOTE(review): this is currently a stub. Camera-ray generation, hit shading and
// the sky lookup are commented out. Every thread therefore traces the same fixed
// ray, and every UAV output is the constant (1, 1, 1, 1).
//
// threadId            - SV_DispatchThreadID; .xy is the output pixel coordinate.
// threadGroupThreadId - SV_GroupThreadID; currently unused.
[numthreads(8, 8, 1)]
void main(int3 threadId            : SV_DispatchThreadID,
          int3 threadGroupThreadId : SV_GroupThreadID)
{
    // Values of globalData.renderMode (same numbering as the renderMode enum
    // described in the constant buffer).
    const int RENDER_MODE_DIFFUSE_DENOISED          = 0;
    const int RENDER_MODE_SPECULAR_DENOISED         = 1;
    const int RENDER_MODE_BOTH_DIFFUSE_AND_SPECULAR = 2;
    const int RENDER_MODE_DIFFUSE_RAW               = 3;
    const int RENDER_MODE_SPECULAR_RAW              = 4;

    // Placeholder ray until GenerateCameraRay is re-enabled.
    // NOTE(review): the direction is the zero vector, which is not a meaningful
    // ray. Replace it together with the camera-ray code.
    float3 rayDir = float3(0.0, 0.0, 0.0);
    float3 origin = float3(1.0, 1.0, 1.0);

    //GenerateCameraRay(threadId.xy, origin, rayDir, viewTransform);

    RayDesc ray;
    ray.Origin    = origin;
    ray.Direction = rayDir;
    ray.TMin      = MIN_RAY_LENGTH;
    ray.TMax      = MAX_RAY_LENGTH;

    // Every thread targets the same element, so a single thread publishes it.
    // This avoids redundant global writes and a write race once rays differ
    // per pixel.
    if (all(threadId == 0))
    {
        rayDescBuffer[0] = ray;
    }

    // Only opaque triangles can be reported (FORCE_OPAQUE + SKIP_PROCEDURAL_PRIMITIVES).
    // No candidate needs shader evaluation, so one Proceed() call runs
    // traversal to completion. To support non-opaque geometry, drop
    // FORCE_OPAQUE and use the loop below instead.
    RayQuery<RAY_FLAG_SKIP_PROCEDURAL_PRIMITIVES | RAY_FLAG_FORCE_OPAQUE> rayQuery;
    rayQuery.TraceRayInline(rtAS, RAY_FLAG_NONE, ~0, ray);
    rayQuery.Proceed();

    //while (rayQuery.Proceed())
    //{
    //    RayTraversalData rayData;
    //    rayData.worldRayOrigin    = rayQuery.WorldRayOrigin();
    //    rayData.currentRayT       = rayQuery.CandidateTriangleRayT();
    //    rayData.closestRayT       = rayQuery.CommittedRayT();
    //    rayData.worldRayDirection = rayQuery.WorldRayDirection();
    //    rayData.geometryIndex     = rayQuery.CandidateGeometryIndex();
    //    rayData.primitiveIndex    = rayQuery.CandidatePrimitiveIndex();
    //    rayData.instanceIndex     = rayQuery.CandidateInstanceIndex();
    //    rayData.barycentrics      = rayQuery.CandidateTriangleBarycentrics();
    //    rayData.objectToWorld     = rayQuery.CandidateObjectToWorld4x3();
    //
    //    bool isHit = ProcessTransparentTriangle(rayData);
    //    if (isHit)
    //    {
    //        rayQuery.CommitNonOpaqueTriangleHit();
    //    }
    //}

    if (rayQuery.CommittedStatus() == COMMITTED_TRIANGLE_HIT)
    {
        // Hit shading is disabled. Re-enabling it needs ProcessOpaqueTriangle
        // and the normal/position/albedo/viewZ UAVs, none of which are bound here.
        /*float3 albedo;
        float  roughness;
        float  metallic;
        float3 normal;
        float3 hitPosition;
        float  transmittance;
        float3 emissiveColor;

        RayTraversalData rayData;
        rayData.worldRayOrigin    = rayQuery.WorldRayOrigin();
        rayData.closestRayT       = rayQuery.CommittedRayT();
        rayData.worldRayDirection = rayQuery.WorldRayDirection();
        rayData.geometryIndex     = rayQuery.CommittedGeometryIndex();
        rayData.primitiveIndex    = rayQuery.CommittedPrimitiveIndex();
        rayData.instanceIndex     = rayQuery.CommittedInstanceIndex();
        rayData.barycentrics      = rayQuery.CommittedTriangleBarycentrics();
        rayData.objectToWorld     = rayQuery.CommittedObjectToWorld4x3();
        rayData.uvIsValid         = false;

        ProcessOpaqueTriangle(rayData,
                              albedo,
                              roughness,
                              metallic,
                              normal,
                              hitPosition,
                              transmittance,
                              emissiveColor);

        if (rayQuery.CommittedTriangleFrontFace() == false)
        {
            normal = -normal;
        }

        normalUAV[threadId.xy].xyz   = (normal + 1.0) / 2.0;
        positionUAV[threadId.xy].xyz = hitPosition;
        albedoUAV[threadId.xy].xyz   = albedo.xyz;

        // Denoiser can't handle roughness value of 0.0
        normalUAV[threadId.xy].w     = max(roughness, 0.05);
        positionUAV[threadId.xy].w   = rayData.instanceIndex;
        albedoUAV[threadId.xy].w     = metallic;

        viewZUAV[threadId.xy].x =  mul(float4(hitPosition, 1.0), viewTransform).z;*/
    }
    else
    {
        // Miss: sky sampling is disabled. The former placeholder locals here were
        // never read, and normalize() of the zero placeholder direction is NaN.
        //float3 sampleVector = normalize(rayDir);
        //float4 dayColor = skyboxTexture.SampleLevel(bilinearWrap, float3(sampleVector.x, sampleVector.y, sampleVector.z), 0);

        //albedoUAV[threadId.xy]   = float4(dayColor.xyz, 0.0);
        //normalUAV[threadId.xy]   = float4(0.0, 0.0, 0.0, 1.0);
        //positionUAV[threadId.xy] = float4(0.0, 0.0, 0.0, -1.0);
        //
        //viewZUAV[threadId.xy].x = 1e5;
    }

    const bool writeSpecular = renderMode == RENDER_MODE_SPECULAR_DENOISED ||
                               renderMode == RENDER_MODE_BOTH_DIFFUSE_AND_SPECULAR ||
                               renderMode == RENDER_MODE_SPECULAR_RAW;

    const bool writeDiffuse  = renderMode == RENDER_MODE_DIFFUSE_DENOISED ||
                               renderMode == RENDER_MODE_BOTH_DIFFUSE_AND_SPECULAR ||
                               renderMode == RENDER_MODE_DIFFUSE_RAW;

    if (writeSpecular)
    {
        indirectSpecularLightRaysUAV[threadId.xy] = float4(1.0, 1.0, 1.0, 1.0);
    }

    if (writeDiffuse)
    {
        indirectLightRaysUAV[threadId.xy] = float4(1.0, 1.0, 1.0, 1.0);
    }

    // Written for every render mode.
    diffusePrimarySurfaceModulation[threadId.xy] = float4(1.0, 1.0, 1.0, 1.0);
}