#include "mesh.h"
#include "SDL2/SDL_log.h"
#include "cglm/affine.h"
#include "cglm/mat4.h"
#include "cglm/quat.h"
#include "gl.h"
#include "material.h"
#include "shader.h"
#include "texture.h"
#include <assert.h>
#include <stdio.h>

#define CGLTF_IMPLEMENTATION
#include "cgltf.h"

// Translates a cgltf accessor component type into the corresponding OpenGL
// data type constant. Any component type without a mapping here (including
// the invalid and max_enum sentinels) yields 0.
static GLint type_map(cgltf_component_type t) {
    GLint gl_type = 0;

    if (t == cgltf_component_type_r_8) {
        gl_type = GL_BYTE; /* BYTE */
    } else if (t == cgltf_component_type_r_8u) {
        gl_type = GL_UNSIGNED_BYTE; /* UNSIGNED_BYTE */
    } else if (t == cgltf_component_type_r_16) {
        gl_type = GL_SHORT; /* SHORT */
    } else if (t == cgltf_component_type_r_16u) {
        gl_type = GL_UNSIGNED_SHORT; /* UNSIGNED_SHORT */
    } else if (t == cgltf_component_type_r_32u) {
        gl_type = GL_UNSIGNED_INT; /* UNSIGNED_INT */
    } else if (t == cgltf_component_type_r_32f) {
        gl_type = GL_FLOAT; /* FLOAT */
    }

    return gl_type;
}

// Translates a glTF primitive topology into the OpenGL draw mode passed to
// the draw call. Unrecognised values (including the max_enum sentinel) fall
// back to 0.
// NOTE(review): GL_POINTS is conventionally 0 as well, so the fallback may be
// indistinguishable from a points primitive -- confirm against the GL headers.
static GLenum mode_map(cgltf_primitive_type t) {
    GLenum draw_mode = 0;

    switch (t) {
    case cgltf_primitive_type_points:
        draw_mode = GL_POINTS;
        break;
    case cgltf_primitive_type_lines:
        draw_mode = GL_LINES;
        break;
    case cgltf_primitive_type_line_loop:
        draw_mode = GL_LINE_LOOP;
        break;
    case cgltf_primitive_type_line_strip:
        draw_mode = GL_LINE_STRIP;
        break;
    case cgltf_primitive_type_triangles:
        draw_mode = GL_TRIANGLES;
        break;
    case cgltf_primitive_type_triangle_strip:
        draw_mode = GL_TRIANGLE_STRIP;
        break;
    case cgltf_primitive_type_triangle_fan:
        draw_mode = GL_TRIANGLE_FAN;
        break;
    case cgltf_primitive_type_max_enum:
    default:
        break;
    }

    return draw_mode;
}

// Formats `count` floats from `vec` as "{a, b, c}" for logging.
//
// Returns a pointer to a static 256-byte buffer: the result is overwritten by
// the next call and the function is not reentrant. Output that does not fit
// is truncated (always NUL-terminated). A NULL `vec` or a `count` of zero
// yields "{}".
static char *vec_to_str(float *vec, size_t count) {
    static char buf[256];
    const size_t cap = sizeof buf;

    if (vec == NULL || count == 0) {
        snprintf(buf, cap, "{}");
        return buf;
    }

    size_t offset = 0;
    buf[offset++] = '{';
    buf[offset] = '\0';

    for (size_t i = 0; i < count && offset < cap - 1; i++) {
        const int is_last = (i + 1 == count);
        int written = snprintf(buf + offset, cap - offset, "%f%s", vec[i],
                               is_last ? "}" : ", ");
        if (written < 0) {
            // Encoding error: keep what has been formatted so far.
            break;
        }
        if ((size_t)written >= cap - offset) {
            // snprintf reports the length it *wanted* to write; advancing by
            // that amount would push the offset past the buffer and make the
            // remaining-size computation wrap around. Stop at truncation.
            break;
        }
        offset += (size_t)written;
    }

    return buf;
}

// Uploads a glTF buffer view into a GL buffer object, at most once per view.
//
// `buffers` caches one GL buffer name per entry of `data->buffer_views`; a
// non-zero cache slot means the view was already uploaded and is returned
// as-is. Only tightly packed views (stride == 0) are supported.
//
// Returns the GL buffer name, or 0 if the view has no bytes to upload.
// Leaves the new buffer bound to GL_ARRAY_BUFFER on a fresh upload.
static GLuint load_buffer_view(buffer_cache_t *buffers,
                               cgltf_buffer_view *buffer_view,
                               cgltf_data *data) {
    size_t i = buffer_view - data->buffer_views;
    assert(i < buffers->buffers_count);
    if (buffers->buffers[i] != 0) {
        // Already loaded
        return buffers->buffers[i];
    }

    assert(buffer_view->stride == 0);

    // Resolve the first byte of the view. When an extension filled in
    // buffer_view->data, that pointer already addresses the view's own bytes,
    // so the view offset must only be applied to the parent buffer's data.
    const char *payload = NULL;
    if (buffer_view->data) {
        payload = buffer_view->data;
    } else if (buffer_view->buffer && buffer_view->buffer->data) {
        payload = (const char *)buffer_view->buffer->data + buffer_view->offset;
    }
    if (payload == NULL) {
        SDL_Log("Buffer view %zu has no data, skipping upload\n", i);
        return 0;
    }

    glGenBuffers(1, &buffers->buffers[i]);
    glBindBuffer(GL_ARRAY_BUFFER, buffers->buffers[i]);
    glBufferData(GL_ARRAY_BUFFER, buffer_view->size, payload, GL_STATIC_DRAW);

    SDL_Log("Loaded buffer view %zu to GL buffer %u\n", i, buffers->buffers[i]);

    return buffers->buffers[i];
}

// Builds the GL-side representation of one glTF mesh.
//
// For every primitive of `mesh` this records a vertex array object that
// captures the index buffer binding and all vertex attribute pointers, and
// stores the draw parameters (mode, index count, index type) and the material
// index needed to draw it later.
//
// mesh_out: destination; its vaos, draw_params and materials arrays are
//           allocated here with one entry per primitive (allocation results
//           are not checked).
// mesh:     source cgltf mesh.
// data:     parsed glTF document, used to turn pointers into array indices.
// buffers:  buffer view cache shared across meshes (see load_buffer_view).
//
// Returns mesh_out. Asserts that every primitive is indexed, that no accessor
// has a matrix type and that every attribute has a valid type.
static mesh_t *mesh_init(mesh_t *mesh_out, cgltf_mesh *mesh, cgltf_data *data,
                         buffer_cache_t *buffers) {
    SDL_Log("\tMesh has %zu primitives\n", mesh->primitives_count);
    mesh_out->primitives_count = mesh->primitives_count;
    mesh_out->vaos = malloc(sizeof(GLuint) * mesh->primitives_count);
    mesh_out->draw_params =
        malloc(sizeof(draw_param_t) * mesh->primitives_count);
    mesh_out->materials = malloc(sizeof(size_t) * mesh->primitives_count);
    // One VAO per primitive, generated in a single call
    glGenVertexArrays(mesh_out->primitives_count, mesh_out->vaos);
    for (size_t i = 0; i < mesh->primitives_count; i++) {
        SDL_Log("\tPrimitive %zu:\n", i);
        cgltf_primitive *pri = mesh->primitives + i;
        assert(pri->indices != NULL); // TODO not all primitives are indexed

        // Set material index (with default for NULL)
        // A primitive without a material gets index data->materials_count,
        // i.e. one past the last material of the glTF document.
        material_log(pri->material, data);
        mesh_out->materials[i] = pri->material
                                     ? (size_t)(pri->material - data->materials)
                                     : data->materials_count;

        // Store draw parameters
        mesh_out->draw_params[i] = (draw_param_t){
            .mode = mode_map(pri->type),
            .count = pri->indices->count,
            .index_type = type_map(pri->indices->component_type),
        };
        SDL_Log("\t\tHas %zu indices, buffer offset %zu\n", pri->indices->count,
                pri->indices->buffer_view->offset);

        // Record VAO
        // Everything bound/configured until glBindVertexArray(0) below is
        // issued while this primitive's VAO is bound.
        glBindVertexArray(mesh_out->vaos[i]);

        // Load and bind index buffer
        GLuint index_buffer =
            load_buffer_view(buffers, pri->indices->buffer_view, data);
        glBindBuffer(GL_ELEMENT_ARRAY_BUFFER, index_buffer);

        // Iterate attribute buffers
        SDL_Log("\t\tHas %zu attributes\n", pri->attributes_count);
        for (size_t j = 0; j < pri->attributes_count; j++) {
            cgltf_attribute *atr = pri->attributes + j;
            SDL_Log("\t\tAttribute %zu, index %d, type %d, name %s:\n", j,
                    atr->index, atr->type, atr->name);

            // Upload (or reuse) the buffer view backing this attribute and
            // bind it so the attribute pointer below refers to it.
            cgltf_accessor *acc = atr->data;
            cgltf_buffer_view *bv = acc->buffer_view;
            GLuint buffer = load_buffer_view(buffers, bv, data);
            glBindBuffer(GL_ARRAY_BUFFER, buffer);
            // Verbose dump of the accessor fields
            SDL_Log("\t\t\taccessor name %s\n", acc->name);
            SDL_Log("\t\t\taccessor component_type %u\n", acc->component_type);
            SDL_Log("\t\t\taccessor normalized %d\n", acc->normalized);
            SDL_Log("\t\t\taccessor type %u\n", acc->type);
            SDL_Log("\t\t\taccessor offset %zu\n", acc->offset);
            SDL_Log("\t\t\taccessor count %zu\n", acc->count);
            SDL_Log("\t\t\taccessor stride %zu\n", acc->stride);
            SDL_Log("\t\t\taccessor buffer_view maps to GL buffer %u\n",
                    buffer);
            SDL_Log("\t\t\taccessor has_min %d\n", acc->has_min);
            SDL_Log("\t\t\taccessor has_max %d\n", acc->has_max);
            SDL_Log("\t\t\taccessor is_sparse %d\n", acc->is_sparse);
            SDL_Log("\t\t\taccessor extensions_count %zu\n",
                    acc->extensions_count);

            // Verbose dump of the buffer view fields
            SDL_Log("\t\t\tbuffer view name %s\n", bv->name);
            SDL_Log("\t\t\tbuffer view buffer %p\n", (void *)bv->buffer);
            SDL_Log("\t\t\tbuffer view offset %zu\n", bv->offset);
            SDL_Log("\t\t\tbuffer view size %zu\n", bv->size);
            SDL_Log("\t\t\tbuffer view stride %zu\n", bv->stride);
            SDL_Log("\t\t\tbuffer view type %u\n", bv->type);
            SDL_Log("\t\t\tbuffer view data %p\n", (void *)bv->data);
            SDL_Log("\t\t\tbuffer view has_meshopt_compression %d\n",
                    bv->has_meshopt_compression);
            SDL_Log("\t\t\tbuffer view extensions_count %zu\n",
                    bv->extensions_count);
            SDL_Log("\t\t\tbuffer view extensions %p\n",
                    (void *)bv->extensions);

            // Matrix types aren't supported
            assert(acc->type != cgltf_type_invalid);
            assert(acc->type <= cgltf_type_vec4);

            // Check validity of attribute type
            assert(atr->type != cgltf_attribute_type_invalid);

            // Simply map attribute types to attribute locations
            // TODO implement some kind of robust mapping scheme,
            // for example when multiple texture coordinates are present
            // The attribute location is the attribute type enum minus one,
            // and acc->type is passed directly as the component count.
            // NOTE(review): this assumes cgltf_type_scalar..cgltf_type_vec4
            // have the numeric values 1..4 -- confirm against cgltf.h.
            glEnableVertexAttribArray(atr->type - 1);
            glVertexAttribPointer(
                atr->type - 1, acc->type, type_map(acc->component_type),
                acc->normalized, acc->stride, (const GLvoid *)(acc->offset));
        }

        // Stop recording into this primitive's VAO
        glBindVertexArray(0);
    }

    return mesh_out;
}

void mesh_draw(const mesh_t *mesh, const scene_t *scene,
               const program_t *program) {
    for (size_t i = 0; i < mesh->primitives_count; i++) {
        material_t *material = &scene->materials[mesh->materials[i]];

        // Bind base color texture
        if (scene->images[material->base_color_texture] != 0) {
            glActiveTexture(GL_TEXTURE0);
            glBindTexture(GL_TEXTURE_2D,
                          scene->images[material->base_color_texture]);
            glUniform1i(
                glGetUniformLocation(program->handle, "u_BaseColorSampler"), 0);
        }

        // Bind metallic roughness texture
        if (scene->images[material->metallic_roughness_texture] != 0) {
            glActiveTexture(GL_TEXTURE1);
            glBindTexture(GL_TEXTURE_2D,
                          scene->images[material->metallic_roughness_texture]);
            glUniform1i(glGetUniformLocation(program->handle,
                                             "u_MetallicRoughnessSampler"),
                        1);
        }

        // Bind occlusion texture
        if (scene->images[material->occlusion_texture] != 0) {
            glActiveTexture(GL_TEXTURE2);
            glBindTexture(GL_TEXTURE_2D,
                          scene->images[material->occlusion_texture]);
            glUniform1i(
                glGetUniformLocation(program->handle, "u_OcclusionSampler"), 2);
        }

        glUniform4fv(glGetUniformLocation(program->handle, "u_BaseColorFactor"),
                     1, material->base_color_factor);
        glUniform1f(glGetUniformLocation(program->handle, "u_MetallicFactor"),
                    material->metallic_factor);
        glUniform1f(glGetUniformLocation(program->handle, "u_RoughnessFactor"),
                    material->roughness_factor);

        // TODO:
        // size_t normal_texture;
        // size_t emissive_texture;
        // GLfloat emissive_factor[3];
        // GLint double_sided;
        // GLint unlit;

        glBindVertexArray(mesh->vaos[i]);
        draw_param_t *p = mesh->draw_params + i;
        glDrawElements(p->mode, p->count, p->index_type, 0);
    }
}

// Creates a GL texture from a glTF image.
//
// The image is decoded either from an embedded buffer view (via
// texture_init) or from its URI (via texture_init_file); the buffer view
// takes precedence when both are present.
//
// Returns the texture name, or 0 when `image` is NULL, has neither a buffer
// view nor a URI, or its buffer view has no data available.
static GLuint load_texture(cgltf_image *image) {
    GLuint texture = 0;
    if (image == NULL) {
        return texture;
    }

    if (image->buffer_view) {
        cgltf_buffer_view *bv = image->buffer_view;
        SDL_Log("Image has buffer view\n");
        assert(bv->stride == 0);

        // When an extension filled in bv->data, that pointer already
        // addresses the view's own bytes; the view offset applies only to
        // the parent buffer's data.
        char *payload = NULL;
        if (bv->data != NULL) {
            SDL_Log("Buffer view data overrides buffer data\n");
            payload = bv->data;
        } else if (bv->buffer != NULL && bv->buffer->data != NULL) {
            payload = (char *)bv->buffer->data + bv->offset;
        }
        if (payload == NULL) {
            SDL_Log("Image buffer view has no data, skipping texture\n");
            return texture;
        }

        SDL_Log("Loading texture from buffer offset %zu, size %zu\n",
                bv->offset, bv->size);
        texture = texture_init(payload, bv->size);
    } else if (image->uri) {
        SDL_Log("Image has URI\n");
        texture = texture_init_file(image->uri);
    }

    return texture;
}

scene_t *scene_init_gltf(const char *path) {
    // Parse glTF file and load buffers to memory
    cgltf_options options = {0};
    cgltf_data *data = NULL;
    cgltf_result result = cgltf_parse_file(&options, path, &data);
    if (result != cgltf_result_success) {
        SDL_Log("Failed to load glTF file %s\n", path);
        return NULL;
    }
    result = cgltf_load_buffers(&options, data, path);
    if (result != cgltf_result_success) {
        SDL_Log("Failed to load buffers for glTF file %s\n", path);
        return NULL;
    }

    scene_t *scene_out = malloc(sizeof(scene_t));

    // Initialize GL buffer object cache (of glTF buffer views)
    SDL_Log("%s has %zu buffer views\n", path, data->buffer_views_count);
    scene_out->buffers.buffers_count = data->buffer_views_count;
    scene_out->buffers.buffers =
        calloc(scene_out->buffers.buffers_count, sizeof(GLuint));

    // Load all texture images + null texture
    SDL_Log("%s has %zu images\n", path, data->images_count);
    scene_out->images_count = data->images_count + 1;
    scene_out->images = calloc(scene_out->images_count, sizeof(GLuint));
    for (size_t i = 0; i < data->images_count; i++) {
        scene_out->images[i] = load_texture(data->images + i);
    }

    // Load all materials + generate default material
    scene_out->materials_count = data->materials_count + 1;
    scene_out->materials =
        calloc(scene_out->materials_count, sizeof(material_t));
    for (size_t i = 0; i < data->materials_count; i++) {
        cgltf_material *material = data->materials + i;
        material_init(&scene_out->materials[i], material, data);
    }
    material_default(&scene_out->materials[data->materials_count]);

    // Load all meshes
    scene_out->meshes_count = data->meshes_count;
    scene_out->meshes = calloc(data->meshes_count, sizeof(mesh_t));
    for (size_t i = 0; i < data->meshes_count; i++) {
        cgltf_mesh *mesh = data->meshes + i;
        mesh_init(&scene_out->meshes[i], mesh, data, &scene_out->buffers);
    }

    // Load all nodes
    SDL_Log("Has %zu nodes\n", data->nodes_count);
    scene_out->nodes_count = data->nodes_count;
    scene_out->nodes = malloc(sizeof(node_t) * data->nodes_count);
    for (size_t i = 0; i < data->nodes_count; i++) {
        cgltf_node *node = data->nodes + i;
        SDL_Log("\tNode %zu (%s)\n", i, node->name);
        if (!node->mesh) {
            SDL_Log("\t\tNode hasn't got reference to mesh, skipping...\n");
            continue;
        }
        scene_out->nodes[i].is_root = node->parent == NULL;
        if (node->parent) {
            scene_out->nodes[i].parent = node->parent - data->nodes;
            SDL_Log("\t\tNode has parent %zu\n", scene_out->nodes[i].parent);
        }
        scene_out->nodes[i].children_count = node->children_count;
        if (node->children_count > 0) {
            SDL_Log("\t\tNode has %zu children\n",
                    scene_out->nodes[i].children_count);
            scene_out->nodes[i].children =
                malloc(sizeof(size_t) * node->children_count);
            for (size_t j = 0; j < node->children_count; j++) {
                scene_out->nodes[i].children[j] =
                    node->children[j] - data->nodes;
            }
        }
        if (node->has_matrix) {
            SDL_Log("\t\tNode has a matrix\n");
            SDL_Log("\t\t%s\n", vec_to_str(node->matrix, 16));
            glm_mat4_make(node->matrix, scene_out->nodes[i].matrix);
            assert(!node->has_scale);
            assert(!node->has_rotation);
            assert(!node->has_translation);
        } else {
            if (node->has_scale) {
                SDL_Log("\t\tNode has a scale vector\n");
                SDL_Log("\t\t%s\n", vec_to_str(node->scale, 3));
                glm_scale_make(scene_out->nodes[i].matrix, node->scale);
            } else {
                glm_mat4_identity(scene_out->nodes[i].matrix);
            }
            if (node->has_rotation) {
                SDL_Log("\t\tNode has a rotation quat\n");
                SDL_Log("\t\t%s\n", vec_to_str(node->rotation, 4));
                mat4 tmp;
                memcpy(tmp, scene_out->nodes[i].matrix, sizeof(mat4));
                versor q;
                glm_quat_init(q, node->rotation[0], node->rotation[1],
                              node->rotation[2], node->rotation[3]);
                glm_quat_rotate(scene_out->nodes[i].matrix, q, tmp);
            }
            if (node->has_translation) {
                SDL_Log("\t\tNode has a translation vector\n");
                SDL_Log("\t\t%s\n", vec_to_str(node->translation, 3));
                glm_translate(scene_out->nodes[i].matrix, node->translation);
            }
        }
        scene_out->nodes[i].mesh = node->mesh - data->meshes;
        SDL_Log("\t\tRefers to mesh %zu\n", scene_out->nodes[i].mesh);
    }

    // Load scenes
    SDL_Log("Has %zu scenes\n", data->scenes_count);
    if (data->scenes_count == 0) {
        return scene_out;
    }

    // Select first scene or default (data->scene) if available
    cgltf_scene *scene = data->scenes;
    if (data->scene) {
        scene = data->scene;
    }

    SDL_Log("Scene %s\n", scene->name);

    SDL_Log("\tHas %zu nodes\n", scene->nodes_count);
    scene_out->parent_nodes_count = scene->nodes_count;
    if (scene->nodes_count == 0) {
        return scene_out;
    }

    scene_out->parent_nodes = malloc(sizeof(size_t) * scene->nodes_count);
    for (size_t i = 0; i < scene->nodes_count; i++) {
        cgltf_node *node = scene->nodes[i];
        scene_out->parent_nodes[i] = node - data->nodes;
        SDL_Log("\tScene's node %zu maps to index %zu\n", i,
                scene_out->parent_nodes[i]);
    }

    // Check buffer_view cache state
    for (size_t i = 0; i < scene_out->buffers.buffers_count; i++) {
        SDL_Log("Buffer view %zu is%s loaded\n", i,
                scene_out->buffers.buffers[i] ? "" : "n't");
    }

    cgltf_free(data);

    return scene_out;
}

// Releases what a mesh owns: its vertex array objects and the three
// per-primitive arrays. The mesh_t storage itself is not freed. A NULL mesh
// is ignored.
static void mesh_deinit(mesh_t *mesh) {
    if (mesh == NULL) {
        return;
    }

    // Only hand the VAO names to GL when the array actually exists.
    if (mesh->vaos != NULL) {
        glDeleteVertexArrays(mesh->primitives_count, mesh->vaos);
    }

    // free(NULL) is a no-op, so no per-pointer guards are needed.
    free(mesh->vaos);
    free(mesh->draw_params);
    free(mesh->materials);
}

// Destroys a scene: frees the root-node list, every node's child list, all
// meshes, the material table, the GL textures and GL buffers, and finally
// the scene struct itself. A NULL scene is ignored.
void scene_deinit(scene_t *scene) {
    if (scene == NULL) {
        return;
    }

    // The root-node list is only released when it is reported non-empty.
    if (scene->parent_nodes_count > 0 && scene->parent_nodes != NULL) {
        free(scene->parent_nodes);
    }

    if (scene->nodes != NULL) {
        node_t *const nodes_end = scene->nodes + scene->nodes_count;
        for (node_t *it = scene->nodes; it != nodes_end; it++) {
            // A child array is only trusted when the node reports children.
            if (it->children_count > 0) {
                free(it->children);
            }
        }
        free(scene->nodes);
    }

    if (scene->meshes != NULL) {
        for (size_t m = 0; m < scene->meshes_count; m++) {
            mesh_deinit(scene->meshes + m);
        }
        free(scene->meshes);
    }

    free(scene->materials);

    if (scene->images != NULL) {
        // The last slot is the null texture, so it is excluded from deletion.
        glDeleteTextures(scene->images_count - 1, scene->images);
        free(scene->images);
    }

    if (scene->buffers.buffers != NULL) {
        glDeleteBuffers(scene->buffers.buffers_count, scene->buffers.buffers);
        free(scene->buffers.buffers);
    }

    free(scene);
}

void scene_draw(const scene_t *scene, const program_t *program) {
    // TODO traverse graph?
    for (size_t i = 0; i < scene->parent_nodes_count; i++) {
        node_t *node = scene->nodes + scene->parent_nodes[i];
        glUniformMatrix4fv(glGetUniformLocation(program->handle, "u_Model"), 1,
                           GL_FALSE, (float *)node->matrix);
        mesh_draw(&scene->meshes[node->mesh], scene, program);
    }
}
