#version 450

// Interpolated input from the vertex stage. It is used directly as the
// camera-ray coordinate in cameraRay() and remapped with (x + 1) / 2 in main(),
// so it is presumably normalized device coordinates in [-1, 1] — TODO confirm
// against the vertex shader.
layout(location = 0) in vec2 FragCoord;
// Final pixel color; main() always writes alpha = 1.
layout(location = 0) out vec4 FragColor;

// Per-frame parameters. NOTE(review): despite its name this is a uniform
// buffer (set 3, binding 0), not a `push_constant` block. The host-side struct
// must match this block's memory layout (check the offset of the vec2 that
// follows the float).
//   u_Time       - demo clock compared against the timeline macros below
//                  (units not visible here — confirm with the host).
//   u_Resolution - framebuffer size; in the code shown here only its x / y
//                  ratio is read (aspectRatio()).
layout(set = 3, binding = 0) uniform PushConstants {
    float u_Time;
    vec2 u_Resolution;
};

#include <noise.glsl>
#include <rotation.glsl>
#include <sdf.glsl>

// Demo timeline markers, compared against u_Time.
// WINDOWS_START, CURSOR_WAIT, CURSOR_LAND and WINDOWS_CLICK are not referenced
// by the code in this file; presumably they are shared with other parts of the
// demo or the included headers — verify before removing.
#define WINDOWS_START 64
#define CURSOR_WAIT 40
#define CURSOR_LAND 130
#define WINDOWS_CLICK 156
// main(): the twister overlay fade-in is timed relative to DEMO_START, and the
// whole image fades to black over the last 8 time units before DEMO_END.
#define DEMO_START 164
#define DEMO_END 1056
// Small bias. In this file it keeps the floor material id (1.0 or 2.0 plus
// EPSILON) from truncating to the wrong integer in sdFloor(), and 1 / EPSILON
// seeds the "far away" distance in spheres(). The included march/brdf helpers
// may use it as well.
#define EPSILON 0.001
#define PI 3.14159265
// Vertical field of view in degrees (converted with PI / 180 in cameraRay()).
#define FOV 60.
// Sky gradient endpoints; values may exceed 1. sky() returns SKY_COLOR1 for a
// direction pointing straight up (v.y = 1) and SKY_COLOR2 for straight down
// (v.y = -1), with a quadratic blend in between.
#define SKY_COLOR1 (vec3(0.1, 0.4, 0.8) * 2.)
#define SKY_COLOR2 (vec3(1.0, 0.4, 0.1) * 4.)
// Direction the light travels (render() passes -LIGHT_DIR as the
// surface-to-light vector): downwards and towards -z.
#define LIGHT_DIR normalize(vec3(0., -1., -1.))
// Light color: white intensity 5, tinted 10% towards SKY_COLOR2.
#define LIGHT_COLOR mix(vec3(5.), SKY_COLOR2, 0.1)

// Base color per material, indexed by the material id carried in .y of the
// sdf() result: 0 = spheres, 1 and 2 = the two floor checker tiles.
const vec3 MTL_COLORS[3] = vec3[3](
    vec3(0.5, 0.8, 0.5),    // 0: spheres
    vec3(0.9),              // 1: floor tile A
    vec3(0.56, 0.57, 0.58)  // 2: floor tile B
);

// Shading parameters per material, same indexing as MTL_COLORS.
//   x = roughness, y = metalness, z = reflectance
// Only material 0 has non-zero reflectance; image() traces a mirror bounce for it.
const vec3 MTL_PARAMS[3] = vec3[3](
    vec3(0.1, 0.0, 1.0),
    vec3(0.1, 0.0, 0.0),
    vec3(0.9, 0.0, 0.0)
);

// Fixed look-at point of the camera: the world origin, raised 2 units.
vec3 cam_target() {
    const vec3 lookAt = vec3(0., 2., 0.);
    return lookAt;
}

// Camera position: a circular orbit of radius 16 at height 4 around the Y axis,
// one revolution per 256 units of u_Time, plus a noise() wobble on each axis
// scaled by 3 (the three noise lookups are offset by 5 and 11 to decorrelate them).
vec3 cam_pos() {
    float phase = u_Time / 256.;
    float angle = phase * PI * 2.;
    vec3 orbit = vec3(sin(angle) * 16., 4., cos(angle) * 16.);

    float s = 0.3 * phase;
    vec3 wobble = vec3(noise(s), noise(s + 5.), noise(s + 11.)) * 3.;

    return orbit + wobble;
}

// Vertical field of view of the camera, in degrees.
float cam_fov() {
    float degrees = FOV;
    return degrees;
}

// Width-over-height ratio of the render target.
float aspectRatio() {
    vec2 res = u_Resolution;
    return res.x / res.y;
}

// Camera-to-world rotation whose columns are (right, vertical, forward).
// Multiplying a camera-space ray by it yields a world-space direction.
// With right = normalize(cross(worldUp, forward)), the second column
// cross(right, forward) points towards world -y for a level camera, so a
// positive camera-space y goes downwards. Presumably this matches a y-down
// FragCoord convention — confirm before changing.
mat3 viewMatrix() {
    const vec3 worldUp = vec3(0., 1., 0.);
    vec3 forward = normalize(cam_target() - cam_pos());
    vec3 right = -normalize(cross(forward, worldUp));
    vec3 vertical = cross(right, forward);
    return mat3(right, vertical, forward);
}

// Unit camera-space ray through this fragment. The image plane sits at
// z = cot(fov / 2), so FragCoord.y = ±1 spans the vertical field of view, and x
// is stretched by the aspect ratio.
vec3 cameraRay() {
    const float DEG_TO_RAD = PI / 180.;
    // tan(90° - fov/2) == cot(fov/2)
    float focal = tan((90. - cam_fov() / 2.) * DEG_TO_RAD);
    vec2 plane = FragCoord * vec2(aspectRatio(), 1.);
    return normalize(vec3(plane, focal));
}

// Floor plane (sdPlaneXZ) with a checkerboard material.
// The sign of sin(x*f)*sin(z*f) is hardened into 0 or 1 by a large gain, which
// selects material 1 or 2. The EPSILON bias keeps int() truncation of the id
// from rounding 2.0 down to 1.
vec2 sdFloor(vec3 p) {
    // (1 / 2) * PI == PI / 2: each checker tile spans 2 world units.
    float freq = 0.5 * PI;
    float checker = sin(p.z * freq) * sin(p.x * freq);
    float tile = clamp(checker * 1024. * 1024., 0.0, 1.);
    float materialId = tile + 1. + EPSILON;
    return vec2(sdPlaneXZ(p), materialId);
}

// Ring of unit spheres around the Y axis.
// The spheres are evenly spaced on a circle of radius 5. Each one bobs
// vertically (height 3 ± 2) over u_Time, with a per-sphere phase offset of
// spacing * i * 1.434.
// Returns vec2(distance to the nearest sphere, material id 0).
vec2 spheres(vec3 p) {
    const int SPHERE_COUNT = 6;
    const float RING_RADIUS = 5.;
    const float SPHERE_RADIUS = 1.;

    // Angle between neighbouring spheres on the ring.
    float spacing = 2. * PI / float(SPHERE_COUNT);
    float t = u_Time / 64.;

    // Start "far away" so the first min() always takes a real sphere distance.
    float dist = 1. / EPSILON;
    for (int i = 0; i < SPHERE_COUNT; ++i) {
        float fi = float(i);
        vec3 center = vec3(
                sin(spacing * fi) * RING_RADIUS,
                sin(t + spacing * fi * 1.434) * 2. + 3.,
                cos(spacing * fi) * RING_RADIUS
            );
        dist = min(dist, sdSphere(p - center, SPHERE_RADIUS));
    }
    return vec2(dist, 0.);
}

// Scene distance field: union of the sphere ring and the checkered floor.
// .x = signed distance; .y = material id (presumably that of the nearer
// surface, as chosen by opUnion).
vec2 sdf(vec3 p) {
    vec2 ring = spheres(p);
    vec2 ground = sdFloor(p);
    return opUnion(ring, ground);
}

#include <march.glsl>
#include <brdf.glsl>

// Sky color seen along direction v: SKY_COLOR1 straight up, SKY_COLOR2
// straight down, with a quadratic blend between them.
vec3 sky(vec3 v) {
    // 0 when v points straight up, 1 when it points straight down.
    float downness = 1. - (v.y * 0.5 + 0.5);
    return mix(SKY_COLOR1, SKY_COLOR2, downness * downness);
}

// Shade a surface point lit by one light plus a small sky-colored ambient term.
// This is a single-light approximation of the rendering equation: outgoing
// radiance = BRDF(l, v) * incoming irradiance. No surface here is emissive.
//   pos   - surface position (origin of the shadow ray)
//   dir   - direction of the ray that hit the surface (view vector is -dir)
//   n     - surface normal
//   l     - unit vector from the surface towards the light
//   lc    - light color / intensity
//   ga    - extra per-channel multiplier applied to the irradiance
//   mtlID - index into MTL_COLORS / MTL_PARAMS
vec3 light(vec3 pos, vec3 dir, vec3 n, vec3 l, vec3 lc, vec3 ga, int mtlID) {
    vec3 baseColor = MTL_COLORS[mtlID];
    // x = roughness, y = metalness, z = reflectance
    vec3 mtl = clamp(MTL_PARAMS[mtlID], 0., 1.);

    float lambert = max(dot(l, n), 0.);
    // NOTE(review): the ambient term samples the sky along the ray direction,
    // not the normal (sky(n)), and the shadow factor below darkens it too.
    // Kept as-is because changing it alters the look — confirm intent.
    vec3 ambient = sky(dir) * 0.1;
    // Shadowed areas keep at least 20% of their light.
    float shadow = max(marchShadow(pos, l, 2.), 0.2);
    vec3 irradiance = (lambert * lc + ambient) * shadow * ga;

    vec3 reflectance = brdf(l, -dir, n, mtl.y, mtl.x, baseColor, mtl.z);
    return irradiance * reflectance;
}

// Shade a hit at distance t along `dir` and fade it towards the sky with distance.
vec3 render(vec3 pos, vec3 n, vec3 dir, float t, int mtlID) {
    // (t + 128) / 128 - 1 == t / 128: a linear distance fade. It is 0 at the ray
    // origin and pure sky for t >= 128. Missed rays only become pure sky if
    // marchBasic reports t >= 128 for them — confirm in march.glsl.
    float skyBlend = clamp((t + 128.) / 128. - 1., 0., 1.);
    vec3 surface = light(pos, dir, n, -LIGHT_DIR, LIGHT_COLOR, vec3(1.), mtlID);
    return mix(surface, sky(dir), skyBlend);
}

// Trace and shade this pixel's camera ray. Material 0 (the spheres) also gets
// one Fresnel-weighted mirror bounce scaled by its reflectance.
// `uv` is currently unused: the ray is built from FragCoord in cameraRay().
vec3 image(vec2 uv) {
    // Sphere-trace all surfaces in view.
    vec3 origin = cam_pos();
    vec3 dir = viewMatrix() * cameraRay();
    vec2 hit = marchBasic(origin, dir);
    vec3 pos = origin + dir * hit.x;
    vec3 n = normal(pos);
    int mtlID = int(hit.y);

    vec3 radiance = render(pos, n, dir, hit.x, mtlID);

    // NOTE(review): the reflection is added after render()'s distance fade, so
    // it is not faded with distance. If marchBasic reports material 0 for a miss,
    // it would also be added over sky pixels — confirm.
    if (mtlID == 0) {
        // Schlick F0 for a relative IOR of 0.5 (same as 2.0): ((n-1)/(n+1))^2 = 1/9.
        const float ior = 0.5;
        float r0 = ((ior - 1.) * (ior - 1.)) / ((ior + 1.) * (ior + 1.));
        // Keep the cosine in [0, 1]. An estimated normal can be slightly
        // back-facing at grazing angles, which would push a Schlick-style
        // (1 - cos)^5 term above 1.
        float cosTheta = clamp(dot(dir, -n), 0., 1.);
        vec3 fresnel = fresnelSchlick(cosTheta, vec3(r0));

        vec3 rfl_dir = reflect(dir, n);
        // Start the bounce slightly off the surface. `pos` is where the primary
        // march stopped, presumably within its hit tolerance, so marching from it
        // unchanged can report an immediate self-hit at t ~ 0.
        vec3 rfl_origin = pos + n * (4. * EPSILON);
        vec2 rfl_hit = marchBasic(rfl_origin, rfl_dir);
        vec3 rfl_pos = rfl_origin + rfl_dir * rfl_hit.x;
        vec3 rfl_n = normal(rfl_pos);
        int rfl_mtlID = int(rfl_hit.y);
        vec3 reflected = render(rfl_pos, rfl_n, rfl_dir, rfl_hit.x, rfl_mtlID) * fresnel;
        radiance += reflected * MTL_PARAMS[mtlID].z;
    }

    return radiance;
}

// Cosine color palette evaluated per channel: a + b * cos(2*pi * (c * t + d)).
vec3 palette(float t, vec3 a, vec3 b, vec3 c, vec3 d) {
    vec3 phase = c * t + d;
    vec3 wave = cos(6.283185 * phase);
    return a + b * wave;
}

// "Twister" overlay: a vertical bar with four faces whose rotation angle
// varies with height and time, and whose centre line sways sideways with noise.
//   uv    - position relative to the twister's centre line
//   width - amplitude of the corner positions (half-width of the bar)
// Returns rgb = shaded color of the visible faces and a = length of the
// per-face coverage vector.
vec4 twister(vec2 uv, float width) {
    float t = u_Time / 32.;
    uv.x -= noise(uv.y * 1.6 + t * 2.) * 0.6 - 0.2;

    // Rotation angle of the cross-section at this height.
    float a = (sin(t * 1.4 + uv.y) * 3.0) * uv.y + t * 1.2;
    // Screen-space x of the four corners, 90 degrees apart. Face i spans
    // corner i to corner i+1 (wrapping around).
    vec4 edges = sin(a + vec4(0., 0.5, 1., 1.5) * PI) * width;
    vec4 nextEdges = edges.yzwx;

    // 1 where uv.x lies between a face's two corners (product of offsets < 0).
    vec4 inside = 1.0 - clamp((uv.x - edges) * (uv.x - nextEdges) * 1024. * 1024., 0.0, 1.0);
    // 1 where the face's corners run left to right, i.e. the face is front-facing.
    vec4 facing = 1.0 - clamp((edges - nextEdges) * 1024. * 1024., 0.0, 1.0);
    vec4 alpha = inside * facing;

    // Apparent (projected) face width, normalised; used as a simple lighting term.
    vec4 shade = (nextEdges - edges) / (width * 1.8);

    const vec4 hues = vec4(0.1, 0.33, 0.58, 0.4);
    vec3 color = vec3(0.0);
    for (int face = 0; face < 4; ++face) {
        vec3 tint = palette(hues[face], vec3(0.5), vec3(0.5), vec3(1.), vec3(0., 0.33, 0.67));
        color += alpha[face] * tint * shade[face];
    }

    return vec4(color * 2., length(alpha));
}

// Composite the raymarched scene with the twister overlay, then apply the
// end-of-demo fade to black.
void main() {
    vec2 uv = (FragCoord + 1.) / 2.;
    float slowTime = u_Time / 8.;

    // Overlay ramps in over 8 units of u_Time and reaches full strength 256
    // units after DEMO_START.
    float fadeIn = 1. - clamp(32. - (slowTime - DEMO_START / 8.), 0., 1.);
    // Whole image fades to black over the final 8 units before DEMO_END.
    float fadeOut = 1. - clamp(slowTime - (DEMO_END / 8. - 1.), 0., 1.);

    vec3 scene = image(uv);
    vec4 overlay = twister(FragCoord - vec2(-0.66, 0.), 0.2);
    vec3 color = mix(scene, overlay.rgb, overlay.a * 0.3 * fadeIn);

    FragColor = vec4(color * fadeOut, 1.);
}
