mod fps;
pub mod resource;
mod text;

use crate::{config, render::resource::ResourceManager};
use anyhow::{Context, Result};
use crevice::std140::AsStd140;
use fps::FpsCounter;
use glam::Vec2;
use sdl3::{
    gpu::{
        ColorTargetBlendState, ColorTargetDescription, ColorTargetInfo, CommandBuffer, CompareOp,
        CullMode, DepthStencilState, Device, FillMode, Filter, FrontFace, GraphicsPipeline,
        GraphicsPipelineTargetInfo, LoadOp, PrimitiveType, RasterizerState, RenderPass,
        SampleCount, Sampler, SamplerAddressMode, SamplerCreateInfo, SamplerMipmapMode, Shader,
        ShaderFormat, StoreOp, Texture, TextureCreateInfo, TextureFormat, TextureSamplerBinding,
        TextureType, TextureUsage, VertexInputState,
    },
    pixels::Color,
    sys::{
        gpu::{
            SDL_GPU_FILTER_NEAREST, SDL_GPU_LOADOP_CLEAR, SDL_GPU_STOREOP_STORE, SDL_GPUBlitInfo,
            SDL_GPUBlitRegion, SDL_GPUColorTargetInfo, SDL_GPUViewport,
            SDL_WaitAndAcquireGPUSwapchainTexture,
        },
        pixels::SDL_FColor,
        surface::SDL_FlipMode,
    },
    video::Window,
};
use std::{
    collections::HashSet,
    mem,
    ops::{BitOr, Range},
    path::Path,
    ptr,
};
use text::Text;

/// Per-draw fragment shader uniforms.
///
/// Converted with `as_std140()` and pushed to fragment uniform slot 0 by
/// `Draw::draw` before every draw.
#[derive(AsStd140)]
struct Uniforms {
    /// Time value passed to the draw call.
    time: f32,
    /// Configured render resolution (width, height) converted to floats.
    resolution: Vec2,
}

/// Source of a texture sampled by a shader quad.
enum TextureInput {
    /// Index into the framebuffer slice handed to `ShaderQuad::draw`.
    Framebuffer(usize),
    /// Name used to look the texture up in the `ResourceManager`.
    File(String),
}

/// A shader effect drawn with a four-vertex draw call, optionally limited to
/// a time window.
struct ShaderQuad {
    /// Graphics pipeline used for this effect.
    pipeline: GraphicsPipeline,
    /// Textures bound to the fragment sampler slots, starting at slot 0, in
    /// this order.
    textures: Vec<TextureInput>,
    /// When set, the quad is skipped if `time < start` or `time >= end`.
    timing: Option<Range<f32>>,
}

impl ShaderQuad {
    /// Binds this quad's pipeline and textures and records its draw call.
    ///
    /// Nothing is recorded when `timing` is set and `time` lies before its
    /// start or at/after its end.
    ///
    /// # Panics
    ///
    /// Panics if the resource manager lookup for a file texture fails or if a
    /// framebuffer index is out of range for `framebuffers`.
    fn draw(
        &self,
        render_pass: &RenderPass,
        resource_manager: &ResourceManager,
        framebuffers: &[Texture<'static>],
        sampler: &Sampler,
        time: f32,
    ) {
        // Same comparisons as a half-open range check, spelled out explicitly.
        let inactive = matches!(
            &self.timing,
            Some(window) if time < window.start || time >= window.end
        );
        if inactive {
            return;
        }

        render_pass.bind_graphics_pipeline(&self.pipeline);

        // One binding per configured texture; every texture shares `sampler`.
        let mut bindings: Vec<TextureSamplerBinding> = Vec::with_capacity(self.textures.len());
        for input in &self.textures {
            let texture = match input {
                TextureInput::File(name) => resource_manager.texture(name).unwrap(),
                TextureInput::Framebuffer(index) => &framebuffers[*index],
            };
            bindings.push(
                TextureSamplerBinding::new()
                    .with_texture(texture)
                    .with_sampler(sampler),
            );
        }
        render_pass.bind_fragment_samplers(0, &bindings);

        // Four vertices, one instance.
        render_pass.draw_primitives(4, 1, 0, 0);
    }
}

/// A single drawable item inside a render pass or the output stage.
enum Draw {
    /// A shader effect with optional texture inputs.
    Shader(ShaderQuad),
    /// A text item rendered by the `text` module.
    Text(Text),
}

impl Draw {
    pub fn from_config(
        cfg: &config::Draw,
        renderer: &Renderer,
        vertex_shader: &Shader,
        target_formats: &[TextureFormat],
    ) -> Result<Self> {
        Ok(match cfg {
            config::Draw::Shader(config::Shader {
                name,
                textures,
                timing,
            }) => Draw::Shader(ShaderQuad {
                pipeline: Renderer::load_shader_effect(
                    &renderer.gpu,
                    vertex_shader,
                    textures.len() as u32,
                    target_formats,
                    name,
                )?,
                textures: textures
                    .iter()
                    .map(|t| match t {
                        config::Texture::Framebuffer(i) => TextureInput::Framebuffer(*i),
                        config::Texture::File(name) => TextureInput::File(name.clone()),
                    })
                    .collect(),
                timing: timing.clone(),
            }),
            config::Draw::Text(text) => Draw::Text(Text::new(
                &renderer.gpu,
                target_formats[0],
                text,
                renderer.fps.value(),
            )?),
        })
    }

    /// Pushes the shared fragment uniforms and records this draw into
    /// `render_pass`.
    ///
    /// `framebuffers`, `resource_manager` and `sampler` are only used by
    /// shader draws; text draws receive the command buffer, render pass,
    /// resolution and time.
    pub fn draw(
        &mut self,
        command_buffer: &CommandBuffer,
        render_pass: &RenderPass,
        framebuffers: &[Texture<'static>],
        resource_manager: &ResourceManager,
        sampler: &Sampler,
        time: f32,
        resolution: config::Resolution,
    ) {
        // Fragment uniform slot 0 maps to layout (set = 3, binding = 0), see
        // https://wiki.libsdl.org/SDL3/SDL_CreateGPUShader
        let size = Vec2::new(resolution.width as f32, resolution.height as f32);
        let std140_uniforms = Uniforms {
            time,
            resolution: size,
        }
        .as_std140();
        command_buffer.push_fragment_uniform_data(0, &std140_uniforms);

        match self {
            Draw::Text(text) => {
                text.draw(command_buffer, render_pass, resolution, time);
            }
            Draw::Shader(quad) => {
                quad.draw(render_pass, resource_manager, framebuffers, sampler, time);
            }
        }
    }
}

/// One offscreen render pass.
struct Pass {
    /// Draws recorded into this pass, in order.
    draw: Vec<Draw>,
    /// Indices of the framebuffers used as color targets of this pass.
    targets: Vec<usize>,
}

/// The final, onscreen stage of a frame.
struct Output {
    /// Draws recorded into the output pass, in order.
    draw: Vec<Draw>,
}

/// Owns all GPU state needed to render a frame: the offscreen passes, the
/// output stage, their target textures and shared resources.
pub struct Renderer {
    /// Copy of the render configuration this renderer was built from.
    config: config::Render,
    /// Window whose swapchain is rendered to.
    window: Window,
    /// GPU device handle.
    gpu: Device,
    /// Frame counter advanced once per submitted frame.
    fps: FpsCounter,
    /// Offscreen passes, executed in order before the output stage.
    passes: Vec<Pass>,
    /// Draws of the final output stage.
    output: Output,
    /// One texture per configured framebuffer, at the configured resolution.
    framebuffers: Vec<Texture<'static>>,
    /// Swapchain-format texture at the configured resolution; `render` draws
    /// into it and blits it to the swapchain when it does not render to the
    /// swapchain texture directly.
    output_buffer: Texture<'static>,
    /// Provides the file textures referenced by shader draws.
    resource_manager: ResourceManager,
    /// Nearest-filtering, clamp-to-edge sampler shared by all texture inputs.
    sampler: Sampler,
}

impl Renderer {
    /// Creates a renderer for `window` from the given render configuration.
    ///
    /// Sets up the shared sampler and vertex shader, the output buffer, the
    /// resource manager (seeded with every file texture referenced by any
    /// draw), all configured passes and output draws, and one texture per
    /// configured framebuffer.
    ///
    /// # Errors
    ///
    /// Returns an error when a GPU object cannot be created, a shader cannot
    /// be loaded, a draw cannot be built, or a pass targets a framebuffer
    /// index that is not configured.
    pub fn new(gpu: &Device, window: &Window, config: &config::Render) -> Result<Self> {
        let sampler = gpu.create_sampler(
            SamplerCreateInfo::new()
                .with_min_filter(Filter::Nearest)
                .with_mag_filter(Filter::Nearest)
                .with_mipmap_mode(SamplerMipmapMode::Nearest)
                .with_address_mode_u(SamplerAddressMode::ClampToEdge)
                .with_address_mode_v(SamplerAddressMode::ClampToEdge)
                .with_address_mode_w(SamplerAddressMode::ClampToEdge),
        )?;

        // Every shader effect shares the same vertex shader.
        let (code, stage) = resource::load_shader("quad.vert")?;
        let vertex_shader = gpu
            .create_shader()
            .with_code(ShaderFormat::SpirV, &code, stage)
            .with_entrypoint(c"main")
            .build()
            .context("Can't create vertex shader")?;

        // The output buffer uses the swapchain format so that output draws
        // can target either it or the swapchain texture with one pipeline.
        let window_format = gpu.get_swapchain_texture_format(window);
        let output_buffer = gpu
            .create_texture(
                TextureCreateInfo::new()
                    .with_type(TextureType::_2D)
                    .with_format(window_format)
                    .with_usage(TextureUsage::ColorTarget.bitor(TextureUsage::Sampler))
                    .with_width(config.resolution.width)
                    .with_height(config.resolution.height)
                    .with_layer_count_or_depth(1)
                    .with_num_levels(1)
                    .with_sample_count(SampleCount::NoMultiSampling),
            )
            .context("Can't create output buffer texture")?;

        // Collect the distinct file textures referenced by any shader draw.
        let texture_files: HashSet<&str> = config
            .passes
            .iter()
            .flat_map(|pass| pass.draw.iter())
            .chain(config.output.draw.iter())
            .filter_map(|draw| match draw {
                config::Draw::Shader(shader) => Some(&shader.textures),
                config::Draw::Text(_) => None,
            })
            .flat_map(|v| {
                v.iter().filter_map(|conf| match conf {
                    config::Texture::Framebuffer(_) => None,
                    config::Texture::File(name) => Some(name.as_str()),
                })
            })
            .collect();
        let resource_manager = ResourceManager::new(gpu, texture_files);

        let mut renderer = Self {
            config: config.clone(),
            window: window.clone(),
            gpu: gpu.clone(),
            fps: FpsCounter::new(),
            passes: Vec::new(),
            output: Output { draw: Vec::new() },
            framebuffers: Vec::new(),
            output_buffer,
            sampler,
            resource_manager,
        };

        for pass in &config.passes {
            // Checked lookup: a bad target index in the configuration becomes
            // an error instead of an out-of-bounds panic.
            let target_formats = pass
                .targets
                .iter()
                .map(|idx| {
                    config
                        .framebuffers
                        .get(*idx)
                        .map(TextureFormat::from)
                        .with_context(|| {
                            format!(
                                "Pass targets framebuffer {idx}, but only {} framebuffers are configured",
                                config.framebuffers.len()
                            )
                        })
                })
                .collect::<Result<Vec<TextureFormat>>>()?;

            let draw = pass
                .draw
                .iter()
                .map(|cfg_draw| {
                    Draw::from_config(cfg_draw, &renderer, &vertex_shader, &target_formats)
                })
                .collect::<Result<Vec<_>>>()?;

            renderer.passes.push(Pass {
                draw,
                targets: pass.targets.clone(),
            });
        }

        // Output draws always render to a single swapchain-format target.
        let output_formats = [window_format];
        for cfg_draw in &config.output.draw {
            let draw = Draw::from_config(cfg_draw, &renderer, &vertex_shader, &output_formats)?;
            renderer.output.draw.push(draw);
        }

        for format in config.framebuffers.iter().map(TextureFormat::from) {
            let texture = gpu
                .create_texture(
                    TextureCreateInfo::new()
                        .with_type(TextureType::_2D)
                        .with_format(format)
                        .with_usage(TextureUsage::ColorTarget.bitor(TextureUsage::Sampler))
                        .with_width(config.resolution.width)
                        .with_height(config.resolution.height)
                        .with_layer_count_or_depth(1)
                        .with_num_levels(1)
                        .with_sample_count(SampleCount::NoMultiSampling),
                )
                .context("Can't create framebuffer texture")?;
            renderer.framebuffers.push(texture);
        }

        Ok(renderer)
    }

    /// Records and submits one frame for the given `time`.
    ///
    /// Runs every offscreen pass into its target framebuffers, then the
    /// output draws. When the swapchain texture has exactly the configured
    /// resolution the output draws go straight to it. Otherwise they go to the
    /// fixed-size output buffer, which is then blitted into an
    /// aspect-preserving region of the swapchain texture.
    ///
    /// The frame is skipped (returning `Ok`) when no swapchain texture is
    /// available.
    ///
    /// # Errors
    ///
    /// Returns an error when the command buffer cannot be acquired, a render
    /// pass cannot be started, or submission fails.
    pub fn render(&mut self, time: f32) -> Result<()> {
        // Record offscreen passes
        let mut command_buffer = self
            .gpu
            .acquire_command_buffer()
            .context("Can't acquire command buffer")?;

        let mut swapchain_texture_raw = ptr::null_mut();
        let mut swapchain_texture_width = 0;
        let mut swapchain_texture_height = 0;
        // SAFETY: the command buffer and window handles come from live
        // wrapper objects, and the out-pointers refer to locals that outlive
        // the call.
        unsafe {
            SDL_WaitAndAcquireGPUSwapchainTexture(
                command_buffer.raw(),
                self.window.raw(),
                &mut swapchain_texture_raw,
                &mut swapchain_texture_width,
                &mut swapchain_texture_height,
            );
        }
        // No texture was handed out: drop this frame.
        if swapchain_texture_raw.is_null() {
            command_buffer.cancel();
            return Ok(());
        }

        for pass in &mut self.passes {
            let color_infos: Vec<ColorTargetInfo> = pass
                .targets
                .iter()
                .map(|target| {
                    ColorTargetInfo::default()
                        .with_texture(&self.framebuffers[*target])
                        .with_clear_color(Color::RGBA(0, 0, 0, 0))
                        .with_load_op(LoadOp::Clear)
                        .with_store_op(StoreOp::Store)
                })
                .collect();

            let render_pass = self
                .gpu
                .begin_render_pass(&command_buffer, &color_infos, None)
                .context("Can't begin render pass")?;

            for draw in &mut pass.draw {
                draw.draw(
                    &command_buffer,
                    &render_pass,
                    &self.framebuffers,
                    &self.resource_manager,
                    &self.sampler,
                    time,
                    self.config.resolution,
                );
            }

            self.gpu.end_render_pass(render_pass);
        }

        // Record onscreen pass
        //
        // Draw directly to the swapchain only when BOTH dimensions match the
        // configured resolution. The draws are given the configured
        // resolution as a uniform, so a swapchain that matches in just one
        // dimension must go through the fixed-size output buffer and the
        // scaling blit below instead of a differently sized or offset
        // viewport.
        let resolution_matches = swapchain_texture_width == self.config.resolution.width
            && swapchain_texture_height == self.config.resolution.height;

        let (color_infos, width, height) = if resolution_matches {
            (
                // SAFETY: NOTE(review) this relies on `ColorTargetInfo` being
                // a layout-compatible wrapper around
                // `SDL_GPUColorTargetInfo`; the raw swapchain texture pointer
                // was checked for null above.
                [unsafe {
                    mem::transmute::<SDL_GPUColorTargetInfo, ColorTargetInfo>(
                        SDL_GPUColorTargetInfo {
                            texture: swapchain_texture_raw,
                            mip_level: 0,
                            layer_or_depth_plane: 0,
                            clear_color: SDL_FColor {
                                r: 0.,
                                g: 0.,
                                b: 0.,
                                a: 1.,
                            },
                            load_op: SDL_GPU_LOADOP_CLEAR,
                            store_op: SDL_GPU_STOREOP_STORE,
                            ..Default::default()
                        },
                    )
                }],
                swapchain_texture_width,
                swapchain_texture_height,
            )
        } else {
            (
                [ColorTargetInfo::default()
                    .with_texture(&self.output_buffer)
                    .with_clear_color(Color::BLACK)
                    .with_load_op(LoadOp::Clear)
                    .with_store_op(StoreOp::Store)],
                self.output_buffer.width(),
                self.output_buffer.height(),
            )
        };

        let render_pass = self
            .gpu
            .begin_render_pass(&command_buffer, &color_infos, None)
            .context("Can't begin render pass")?;

        // On the direct path the target has the configured resolution, so the
        // aspect-preserving viewport spans the whole swapchain texture.
        if resolution_matches {
            self.gpu
                .set_viewport(&render_pass, self.viewport(width, height));
        }

        for draw in &mut self.output.draw {
            draw.draw(
                &command_buffer,
                &render_pass,
                &self.framebuffers,
                &self.resource_manager,
                &self.sampler,
                time,
                self.config.resolution,
            );
        }

        self.gpu.end_render_pass(render_pass);

        // Extra victory lap if necessary: scale the output buffer into the
        // letterboxed region of the swapchain texture.
        if !resolution_matches {
            let viewport = self.viewport(swapchain_texture_width, swapchain_texture_height);

            // SAFETY: the command buffer is still open, the output buffer is
            // owned by `self`, and the swapchain texture pointer was checked
            // for null above.
            unsafe {
                sdl3::sys::gpu::SDL_BlitGPUTexture(
                    command_buffer.raw(),
                    &SDL_GPUBlitInfo {
                        source: SDL_GPUBlitRegion {
                            texture: self.output_buffer.raw(),
                            mip_level: 0,
                            layer_or_depth_plane: 0,
                            x: 0,
                            y: 0,
                            w: self.output_buffer.width(),
                            h: self.output_buffer.height(),
                        },
                        destination: SDL_GPUBlitRegion {
                            texture: swapchain_texture_raw,
                            mip_level: 0,
                            layer_or_depth_plane: 0,
                            // Round the float viewport to the nearest pixel.
                            x: (viewport.x + 0.5) as u32,
                            y: (viewport.y + 0.5) as u32,
                            w: (viewport.w + 0.5) as u32,
                            h: (viewport.h + 0.5) as u32,
                        },
                        // Clearing the destination paints the letterbox bars.
                        load_op: SDL_GPU_LOADOP_CLEAR,
                        clear_color: SDL_FColor {
                            r: 0.,
                            g: 0.,
                            b: 0.,
                            a: 1.,
                        },
                        flip_mode: SDL_FlipMode::NONE,
                        filter: SDL_GPU_FILTER_NEAREST,
                        cycle: false,
                        ..Default::default()
                    },
                );
            }
        }

        command_buffer
            .submit()
            .context("Command buffer submission failed")?;

        self.fps.frame();

        Ok(())
    }

    /// Loads the fragment shader `name` and builds a graphics pipeline that
    /// pairs it with `vertex_shader`.
    ///
    /// * `texture_inputs` - number of samplers the fragment shader declares.
    /// * `target_formats` - one color target description is created per
    ///   format, each with the default blend state.
    ///
    /// The pipeline draws triangle strips without culling and without a
    /// depth/stencil target.
    ///
    /// # Errors
    ///
    /// Returns an error when the shader cannot be loaded or when the shader
    /// or pipeline cannot be created.
    fn load_shader_effect(
        gpu: &Device,
        vertex_shader: &Shader,
        texture_inputs: u32,
        target_formats: &[TextureFormat],
        name: impl AsRef<Path>,
    ) -> Result<GraphicsPipeline> {
        let (spirv, stage) = resource::load_shader(name)?;
        let fragment_shader = gpu
            .create_shader()
            .with_code(ShaderFormat::SpirV, &spirv, stage)
            .with_entrypoint(c"main")
            .with_uniform_buffers(2)
            .with_samplers(texture_inputs)
            .build()
            .context("Can't create fragment shader")?;

        let mut target_descriptions: Vec<ColorTargetDescription> =
            Vec::with_capacity(target_formats.len());
        for &format in target_formats {
            target_descriptions.push(
                ColorTargetDescription::new()
                    .with_format(format)
                    .with_blend_state(ColorTargetBlendState::default()),
            );
        }

        let rasterizer = RasterizerState::new()
            .with_fill_mode(FillMode::Fill)
            .with_cull_mode(CullMode::None)
            .with_front_face(FrontFace::Clockwise);

        // Depth and stencil tests are both switched off.
        let depth_stencil = DepthStencilState::new()
            .with_compare_op(CompareOp::Greater)
            .with_enable_depth_test(false)
            .with_enable_stencil_test(false);

        let targets = GraphicsPipelineTargetInfo::new()
            .with_color_target_descriptions(&target_descriptions)
            .with_has_depth_stencil_target(false);

        gpu.create_graphics_pipeline()
            .with_vertex_shader(vertex_shader)
            .with_fragment_shader(&fragment_shader)
            .with_vertex_input_state(VertexInputState::default())
            .with_primitive_type(PrimitiveType::TriangleStrip)
            .with_rasterizer_state(rasterizer)
            .with_depth_stencil_state(depth_stencil)
            .with_target_info(targets)
            .build()
            .context("Can't create graphics pipeline")
    }

    fn viewport(&self, width: u32, height: u32) -> SDL_GPUViewport {
        let aspect_ratio = width as f32 / height as f32;
        let target_ratio =
            self.config.resolution.width as f32 / self.config.resolution.height as f32;
        let (vp_width, vp_height) = if aspect_ratio > target_ratio {
            (height as f32 * target_ratio, height as f32)
        } else {
            (width as f32, width as f32 / target_ratio)
        };

        sdl3::sys::gpu::SDL_GPUViewport {
            x: if aspect_ratio > target_ratio {
                (width as f32 - vp_width) / 2.
            } else {
                0.
            },
            y: if aspect_ratio > target_ratio {
                0.
            } else {
                (height as f32 - vp_height) / 2.
            },
            w: vp_width,
            h: vp_height,
            min_depth: 0.,
            max_depth: 1.,
        }
    }
}
