#include <GL/glew.h>
#include <GL/freeglut.h>
#include <vector>
#include <random>
#include <algorithm>  // Add this for std::remove_if
#include <glm/glm.hpp>
#include <glm/gtc/matrix_transform.hpp>
#include <glm/gtc/type_ptr.hpp>
#include "mp3_player.hpp"  // Add audio analysis library

// --- Timing, grid and effect tuning ---------------------------------------
// NOTE(review): most of these values are consumed outside this chunk; the
// comments describe intent and should be confirmed against their users.

// Length of the grid phase (presumably seconds, compared against globalTime).
const float GRID_DURATION_SECONDS = 65.0f;

// Pixel-grid dimensions (GRID_WIDTH x GRID_HEIGHT cells)
const int GRID_WIDTH = 320;
const int GRID_HEIGHT = 200;
const float PIXEL_SIZE = 1.0f;
const float GRID_DEPTH = 0.0f;  // Z of the grid; the grid is currently flat
const float MAX_HEIGHT = 10.0f;  // Upper bound on a pixel's height
const float HEIGHT_INCREMENT = 0.2f;  // Height added to a pixel per update step
// Development-time multiplier for effect timing. Deliberately non-const so it
// can be tuned at runtime. NOTE(review): the previous note said 1.0f is normal
// speed and <1.0f is faster — confirm against where it is applied.
float EFFECT_TIME_SCALE = 1.0f;
const float EFFECT_DURATION = 5.0f;  // Lifetime of one effect (presumably seconds)
const float EFFECT_RADIUS = 60.0f;   // Radius of an effect's region of influence
const float EFFECT_INTENSITY = 0.5f; // Base intensity of an effect
// Snow particle limits and value ranges
const int MAX_SNOW_PARTICLES = 100;  // Maximum number of snow particles
const float SNOW_SPAWN_RADIUS = 200.0f;  // Spawn radius
const float SNOW_MIN_SIZE = 1.0f;  // Smallest flake size
const float SNOW_MAX_SIZE = 5.0f;  // Largest flake size
const float SNOW_MIN_SPEED = 30.0f;  // Slowest flake speed
const float SNOW_MAX_SPEED = 60.0f;  // Fastest flake speed
const float SNOW_SPAWN_DISTANCE = 200.0f;  // Distance at which flakes are spawned
// Cube physics / earthquake / restoration tuning
const float PHYSICS_DELAY = 666.0f; // Seconds before physics activates
const float GRAVITY = 9.8f;
const float BOUNCE_FACTOR = 0.6f;  // Presumably the fraction of velocity kept after a bounce
const float EARTHQUAKE_MAGNITUDE = 3.0f; // Strength of the earthquake shake
const float FLOOR_HEIGHT = -4.0f; // Y of the floor used by the cube physics
const float RESTORATION_FORCE = 7.5f; // Force pulling cubes back to original positions
const float RESTORATION_DAMPING = 0.9f; // Damping for restoration movement
bool restorationPhase = false; // True while the restoration phase is running
float restorationStartTime = 0.0f; // When restoration phase started
float restorationTimeout = 9.0f; // Maximum time for restoration
// Pong mini-game speeds
const float PONG_PADDLE_SPEED = 5.0f;
const float PONG_BALL_MAX_SPEED = 8.0f;
const float PONG_BALL_MIN_SPEED = 4.0f;

// Morph timing (presumably driving params.morphFactor, 0 = cube, 1 = sphere).
// drawMazeBackground() draws nothing while globalTime < morphStartTime.
float morphStartTime = 22.0f;
float morphDuration = 12.0f;

// 256-entry CRT colour palette. Not const because it is filled in at runtime
// (the previous note pointed at initPixels(), which is outside this chunk).
glm::vec3 CRT_PALETTE[256];

// Colour and brightness of one grid pixel.
struct Pixel {
    glm::vec3 color;
    float brightness;
};

// Per-pixel state for the layer simulation (layerPixels / nextLayerPixels).
struct LayerPixel {
    float height;
    float color;       // Scalar colour value; NOTE(review): presumably a palette index stored as float
    float lastUpdate;  // Time of this pixel's last update
};

// Interleaved vertex layout uploaded to GPU buffers. The maze VAO binds
// attribute 0 = position, 1 = color, 2 = brightness using offsetof(), so the
// member order and types must not change. Declared before the globals that
// store vectors of it.
struct VertexData {
    glm::vec3 position;
    glm::vec3 color;
    float brightness;
    float index;  // Extra per-vertex value; the maze code always writes 1.0f
};

// --- Global render / window state -----------------------------------------
std::vector<Pixel> pixels;
GLuint shaderProgram = 0;  // Main program; drawMazeBackground() makes it current again after drawing
GLuint simpleShaderProgram = 0;  // Separate shader program for the static grid
GLuint mazeShaderProgram = 0;  // Program used by drawMazeBackground()
GLuint VAO, VBO;
GLuint snowVAO, snowVBO;  // Vertex array and buffer objects for snow particles
float globalTime = 0.0f;  // Scene time; maze cells record it as their birthTime
float animationCompleteTime = 0.0f;  // Time when all pixels are in place
float baseCameraZ = 600.0f;
float cameraZ = baseCameraZ;  // Current camera Z, starts at baseCameraZ
float fov = glm::radians(45.0f);  // Field of view in radians
int windowWidth = 800;
int windowHeight = 600;

glm::mat4 projection;  // Shared projection matrix

// --- Audio integration ------------------------------------------------------
// NOTE(review): ownership of audio_player is not visible here; confirm who deletes it.
AudioAnalysis::MP3Player* audio_player = nullptr;
// Latest analysis results. NOTE(review): audio_features_mutex presumably guards
// current_audio_features against the audio thread — confirm at the use sites.
AudioAnalysis::AudioFeatures current_audio_features;
std::mutex audio_features_mutex;
// Scalar audio values cached for rendering
float audio_bass_intensity = 0.0f;
float audio_mid_intensity = 0.0f;
float audio_treble_intensity = 0.0f;
float audio_energy_level = 0.0f;
bool audio_beat_detected = false;
float audio_tempo = 120.0f;  // Default BPM
std::string current_mp3_file;  // Presumably the path of the MP3 being played

// Shared Mersenne Twister, seeded once from std::random_device so callers do
// not need to construct their own engine.
std::random_device rd;
std::mt19937 globalRng(rd());

// Reusable vertex storage for snowflakes, kept global so it can be reused
// between frames instead of reallocated (NOTE(review): confirm the draw code reuses it).
std::vector<VertexData> snowflakeVertexData;

// Layer simulation state: the current layer and the one being computed next.
std::vector<LayerPixel> layerPixels;  // Current layer state
std::vector<LayerPixel> nextLayerPixels;  // Next layer state
float lastLayerUpdate = 0.0f;  // Time of last update
const float LAYER_UPDATE_INTERVAL = 0.1f;  // Update every 0.1 seconds

// Snowflake effect system.
// MAX_EFFECTS: presumably the cap on simultaneous snowflake effects.
#define MAX_EFFECTS 16
// State of one snowflake-driven effect region.
struct SnowflakeEffect {
    glm::vec2 position;
    float radius;
    int effectType;
    float blendFactor;
    float lastUpdate;
};

std::vector<SnowflakeEffect> snowflakeEffects;


// Game of Life state, double-buffered: current and next generation.
std::vector<bool> gameOfLifeState;  // Current state of each cell
std::vector<bool> nextGameOfLifeState;  // Next state for simulation
float lastGameOfLifeUpdate = 0.0f;  // Time of last update
const float GAME_OF_LIFE_UPDATE_INTERVAL = 0.2f;  // Update every 0.2 seconds

// One falling snow particle.
struct SnowParticle {
    glm::vec3 position;
    float size;
    float speed;
    float rotation;
    float rotationSpeed;
    float alpha;
    int effectType;  // Presumably matches SnowflakeEffect::effectType
    bool hasLanded;  // True once the particle has come to rest
    float landingTime;  // Time at which the particle landed
};

std::vector<SnowParticle> snowParticles;
bool pongStarted = false; // Whether the Pong mini-game has been started

#include <GL/glew.h>
#include <GL/glut.h>
#include <cstdio>
#include <string>
#include <glm/glm.hpp>
#include <glm/gtc/matrix_transform.hpp>
#include <glm/gtc/type_ptr.hpp>
#include <glm/gtc/quaternion.hpp>
#include <glm/gtx/quaternion.hpp>
#include <vector>
#include <array>
#include <queue>
#include <memory>
#include <ctime>   // For seeding random

// Global camera position
glm::vec3 cameraPos(4.0f, 3.0f, 3.0f);

// Physics activation state for the cube pieces
bool physicsActive = false;
float physicsStartTime = 0.0f;


// Pong game variables
bool pongMode = false;
// One Pong paddle: placement, extent, colour and movement speed.
struct PongPaddle {
    glm::vec3 position;
    glm::vec3 size; // width, height, depth
    glm::vec3 color;
    float speed;
};
PongPaddle leftPaddle;
PongPaddle rightPaddle;
glm::vec3 pongBallVelocity;

int leftScore = 0;
int rightScore = 0;
bool showStartScreen = true;
const float PONG_FLOOR_Y = -2.0f; // Pong floor; higher than FLOOR_HEIGHT used by the cube physics
const float PONG_CEILING_Y = 7.0f;  // Pong ceiling
const float PONG_SIDE_X = 15.0f;  // Pong side boundary (presumably |x| limit)
// Game ends when a player reaches this score. NOTE(review): the scores are int
// but this is float; an int constant would avoid mixed int/float comparisons.
const float PONG_MAX_SCORE = 5;

// Auto solve/shuffle state machine: the phase the cube demo is currently in.
enum CubeState {
    IDLE,
    SHUFFLING,
    SOLVING,
    EARTHQUAKE_TRANSITION,
    RESTORING,
    PONG_PLAYING,
    PONG_RESET
};
CubeState cubeState = IDLE;
// NOTE(review): MIN == MAX, so if these bound a random range every shuffle has the same length.
const int MIN_SHUFFLE_MOVES = 10;
const int MAX_SHUFFLE_MOVES = 10;
int shuffleMovesLeft = 0;
float stateMachineDelay = 0.0f;

// One layer turn. Shuffle moves are recorded here so they can be used for solving later.
struct CubeMove {
    int axis;           // 0=X, 1=Y, 2=Z
    int layer;          // -1, 0, 1
    bool clockwise;     // Direction
};
std::vector<CubeMove> shuffleMoves;

// An in-progress or queued layer rotation animation.
// NOTE(review): layer is documented as 0-2 here but as -1..1 in CubeMove — confirm the mapping.
struct RotationAnimation {
    int layer;           // Which layer is rotating (0-2 for each axis)
    int axis;           // 0=X, 1=Y, 2=Z
    float angle;        // Current angle
    float targetAngle;  // Target angle to reach
    float speed;        // Speed of rotation
    bool clockwise;     // Rotation direction
};

// Despite the name, one element of `cubes` (one of the CUBE_COUNT small cubes):
// placement, colour, layer membership, animation and physics state.
struct CubeFace {
    glm::vec3 position;
    glm::vec3 color;
    float rotation = 0.0f;
    glm::vec3 rotationAxis = glm::vec3(0.0f, 1.0f, 0.0f);
    int layer[3];  // Layer indices for X, Y, Z axes (no default initializer: set before use)
    bool isAnimating = false;
    bool clockwise = true;  // Direction of current animation
    glm::mat4 animationTransform = glm::mat4(1.0f);
    glm::mat4 baseTransform = glm::mat4(1.0f);
    
    // Physics properties
    glm::vec3 velocity = glm::vec3(0.0f);
    glm::vec3 angularVelocity = glm::vec3(0.0f);
    bool onGround = false;
    float mass = 1.0f;
};

// Global tunable parameters
struct Parameters {
    float rotationSpeed = 90.0f;         // Degrees per second
    float animationSpeed = 180.0f;        // Degrees per second
    int selectedLayer = 0;               // Layer to rotate
    float cubeSize = 0.5f;               // Size of each individual cube
    float spacing = 0.05f;               // Spacing between cubes
    float lightIntensity = 1.0f;         // Light intensity
    float cameraDistance = 5.0f;         // Distance of camera from cube
    float morphFactor = 0.0f;            // 0 = cube, 1 = sphere
} params;

int mainWindow = 0;  // Presumably the GLUT window id

// GL object handles (presumably the cube's vertex array, vertex and index buffers)
GLuint vao = 0;
GLuint vbo = 0;
GLuint ebo = 0;

// Cube data
const int CUBE_COUNT = 27; // 3x3x3 cube
std::vector<CubeFace> cubes;
std::queue<RotationAnimation> animationQueue;  // Pending layer rotations (FIFO)
float rotation = 0.0f;
bool isAnimating = false;

// Forward declarations
void updateAnimation(float deltaTime);
void updatePhysics(float deltaTime);
void startRestorationPhase();
bool isRestorationComplete();
void completeRestoration();
void initOriginalPositions();
void startAutomaticShuffle();
void continueAutomaticShuffle();
void startAutomaticSolve();
void continueAutomaticSolve();
void startRandomMove();
void updateStateMachine(float deltaTime);
void initPongGame();
void updatePongGame(float deltaTime);
void drawPongPaddles();
void resetPongBall();
void checkPongCollisions();

// Snapshot of one cube's position and layer indices, captured by
// initOriginalPositions() as the target layout for restoration.
struct OriginalPosition {
    glm::vec3 position;
    int layer[3];
};
std::vector<OriginalPosition> originalPositions;

// Initialize the original positions for restoration
void initOriginalPositions() {
    originalPositions.clear();
    
    for (const auto& cube : cubes) {
        OriginalPosition orig;
        orig.position = cube.position;
        orig.layer[0] = cube.layer[0];
        orig.layer[1] = cube.layer[1];
        orig.layer[2] = cube.layer[2];
        originalPositions.push_back(orig);
    }
    
    printf("Debug: Stored %zu original cube positions\n", originalPositions.size());
}

// The 8 corners of a unit cube centred at the origin, stored as (x, y, z)
// triples. Indices 0-3 are the z = +0.5 (front) face and 4-7 the z = -0.5
// (back) face; presumably indexed by cubeIndices.
const std::array<GLfloat, 24> faceVertices = {
    // Front face corners, z = +0.5 (x, y, z)
    -0.5f, -0.5f,  0.5f,  // 0
     0.5f, -0.5f,  0.5f,  // 1
     0.5f,  0.5f,  0.5f,  // 2
    -0.5f,  0.5f,  0.5f,  // 3
    // Back face corners, z = -0.5
    -0.5f, -0.5f, -0.5f,  // 4
     0.5f, -0.5f, -0.5f,  // 5
     0.5f,  0.5f, -0.5f,  // 6
    -0.5f,  0.5f, -0.5f   // 7
};

// State of one maze cell.
struct MazeCell {
    bool visited = false;  // Reached by the maze generator
    bool active = false;  // Currently growing
    float birthTime = 0.0f;  // When this cell became active
    bool walls[4] = {true, true, true, true};  // top, right, bottom, left (same order as Direction)
};

// Maze background state. mazeGrid is row-major: cell (x, y) is mazeGrid[y * MAZE_WIDTH + x].
std::vector<MazeCell> mazeGrid;
const int MAZE_WIDTH = 40;  // Columns; adjust for desired density
const int MAZE_HEIGHT = 25;  // Rows
// NOTE(review): not used by updateMazeBackground(), which advances a fixed number of cells per call.
const float MAZE_GROWTH_SPEED = 2.0f;  // Cells per second
std::vector<glm::ivec2> growthPoints;  // Seed cells picked by initMazeBackground()
GLuint mazeVAO = 0, mazeVBO = 0;
std::vector<VertexData> mazeLines;  // GL_LINES vertices rebuilt by drawMazeBackground()

// Add this initialization function
void initMazeBackground() {
    mazeGrid.resize(MAZE_WIDTH * MAZE_HEIGHT);
    
    // Start with a few random growth points
    for (int i = 0; i < 5; i++) {
        int x = rand() % MAZE_WIDTH;
        int y = rand() % MAZE_HEIGHT;
        int idx = y * MAZE_WIDTH + x;
        mazeGrid[idx].active = true;
        mazeGrid[idx].birthTime = globalTime;
        growthPoints.push_back(glm::ivec2(x, y));
    }

    // Create VAO/VBO for maze lines
    glGenVertexArrays(1, &mazeVAO);
    glGenBuffers(1, &mazeVBO);
    
    glBindVertexArray(mazeVAO);
    glBindBuffer(GL_ARRAY_BUFFER, mazeVBO);
    
    // We'll dynamically update this later
    glBufferData(GL_ARRAY_BUFFER, sizeof(VertexData) * 10000, nullptr, GL_DYNAMIC_DRAW);
    
    // Set up attributes same as your other VAOs
    glVertexAttribPointer(0, 3, GL_FLOAT, GL_FALSE, sizeof(VertexData), (void*)0);
    glEnableVertexAttribArray(0);
    glVertexAttribPointer(1, 3, GL_FLOAT, GL_FALSE, sizeof(VertexData), (void*)offsetof(VertexData, color));
    glEnableVertexAttribArray(1);
    glVertexAttribPointer(2, 1, GL_FLOAT, GL_FALSE, sizeof(VertexData), (void*)offsetof(VertexData, brightness));
    glEnableVertexAttribArray(2);
    
    glBindBuffer(GL_ARRAY_BUFFER, 0);
    glBindVertexArray(0);
    
    // After VAO/VBO creation
    printf("Maze VAO: %u, VBO: %u\n", mazeVAO, mazeVBO);
    
    // Debug: Check for OpenGL errors
    GLenum err = glGetError();
    if (err != GL_NO_ERROR) {
        printf("OpenGL error after maze buffer init: %d\n", err);
    }
}
// Grid step directions in clockwise order; each value also indexes
// MazeCell::walls (0 = top, 1 = right, 2 = bottom, 3 = left).
enum Direction { UP, RIGHT, DOWN, LEFT };
// Column (dx) and row (dy) offsets for each Direction; UP is the step with dy = -1.
const int dx[] = { 0, 1, 0, -1 };
const int dy[] = { -1, 0, 1, 0 };

// Classify the step from `from` to `to` (both {x, y}). A decreasing y is UP;
// otherwise an increasing x is RIGHT; otherwise an increasing y is DOWN;
// anything else (including from == to) is LEFT.
Direction getDirection(std::pair<int,int> from, std::pair<int,int> to) {
    const auto [fromX, fromY] = from;
    const auto [toX, toY] = to;

    if (toY < fromY) {
        return UP;
    }
    if (toX > fromX) {
        return RIGHT;
    }
    return (toY > fromY) ? DOWN : LEFT;
}

// True when `next` is perpendicular to `current`. Vertical directions have even
// enum values and horizontal ones odd values, so a turn is a parity change.
bool isTurn(Direction current, Direction next) {
    const int combined = static_cast<int>(current) + static_cast<int>(next);
    return (combined & 1) != 0;
}
bool mazeInitialized = false;
// Add this update function
void updateMazeBackground() {
    static std::vector<std::pair<int, int>> stack;  // For backtracking
    static std::mt19937 rng(std::random_device{}());
    
    // Initialize if first time
    if (stack.empty() && !mazeInitialized) {
        // Start from center for more interesting pattern
        int startX = MAZE_WIDTH / 2;
        int startY = MAZE_HEIGHT / 2;
        stack.push_back({startX, startY});
        mazeGrid[startY * MAZE_WIDTH + startX].active = true;
        mazeGrid[startY * MAZE_WIDTH + startX].visited = true;
        mazeGrid[startY * MAZE_WIDTH + startX].birthTime = globalTime;
    }
    
    // Process multiple cells per frame for smoother growth
    for (int i = 0; i < 5 && !stack.empty(); i++) {
        auto [x, y] = stack.back();
        
        // Get unvisited neighbors
        std::vector<std::pair<int, Direction>> neighbors;
        if (y > 0 && !mazeGrid[(y-1) * MAZE_WIDTH + x].visited) neighbors.push_back({0, UP});
        if (x < MAZE_WIDTH-1 && !mazeGrid[y * MAZE_WIDTH + (x+1)].visited) neighbors.push_back({1, RIGHT});
        if (y < MAZE_HEIGHT-1 && !mazeGrid[(y+1) * MAZE_WIDTH + x].visited) neighbors.push_back({2, DOWN});
        if (x > 0 && !mazeGrid[y * MAZE_WIDTH + (x-1)].visited) neighbors.push_back({3, LEFT});
        
        if (!neighbors.empty()) {
            // Bias towards making turns
            Direction lastDir = stack.size() > 1 ? 
                getDirection(stack[stack.size()-2], stack.back()) : 
                static_cast<Direction>(std::uniform_int_distribution<>(0, 3)(rng));
                
            // Sort neighbors by turn preference
            std::sort(neighbors.begin(), neighbors.end(), 
                [lastDir](const auto& a, const auto& b) {
                    return isTurn(lastDir, a.second) > isTurn(lastDir, b.second);
                });
                
            // 70% chance to prefer turns
            std::uniform_real_distribution<> dist(0.0, 1.0);
            int chosen = dist(rng) < 0.7 ? 0 : 
                std::uniform_int_distribution<>(0, neighbors.size()-1)(rng);
                
            auto [idx, dir] = neighbors[chosen];
            
            // Remove walls between cells
            int newX = x + dx[dir];
            int newY = y + dy[dir];
            
            mazeGrid[y * MAZE_WIDTH + x].walls[idx] = false;
            mazeGrid[newY * MAZE_WIDTH + newX].walls[(idx + 2) % 4] = false;
            
            mazeGrid[newY * MAZE_WIDTH + newX].visited = true;
            mazeGrid[newY * MAZE_WIDTH + newX].active = true;
            mazeGrid[newY * MAZE_WIDTH + newX].birthTime = globalTime;
            
            stack.push_back({newX, newY});
        } else {
            stack.pop_back();
        }
    }
}

// Add this drawing function
void drawMazeBackground() {
    if (globalTime < morphStartTime) {
        return;
    }
    
    mazeLines.clear();
    glDisable(GL_DEPTH_TEST);    
    // Debug: Print timing info
    // Convert grid coordinates to screen space
    float cellWidth = 2.0f / MAZE_WIDTH;
    float cellHeight = 2.0f / MAZE_HEIGHT;
    
    for (int y = 0; y < MAZE_HEIGHT; y++) {
        for (int x = 0; x < MAZE_WIDTH; x++) {
            int idx = y * MAZE_WIDTH + x;
            const MazeCell& cell = mazeGrid[idx];
            
            if (!cell.visited && !cell.active) continue;
            
            float sx = x * cellWidth - 1.0f;
            float sy = y * cellHeight - 1.0f;
            //float alpha = std::min(1.0f, (globalTime - cell.birthTime) * 2.0f);
            float alpha = 1.0f;
            
            // Draw active walls
            if (cell.walls[0]) { // top
                mazeLines.push_back({{sx, sy + cellHeight, 0}, {1.0f, 1.0f, 1.0f}, alpha, 1.0f});
                mazeLines.push_back({{sx + cellWidth, sy + cellHeight, 0}, {1.0f, 1.0f, 1.0f}, alpha, 1.0f});
            }
            if (cell.walls[1]) { // right
                mazeLines.push_back({{sx + cellWidth, sy + cellHeight, 0}, {1.0f, 1.0f, 1.0f}, alpha, 1.0f});
                mazeLines.push_back({{sx + cellWidth, sy, 0}, {1.0f, 1.0f, 1.0f}, alpha, 1.0f});
            }
            if (cell.walls[2]) { // bottom
                mazeLines.push_back({{sx + cellWidth, sy, 0}, {1.0f, 1.0f, 1.0f}, alpha, 1.0f});
                mazeLines.push_back({{sx, sy, 0}, {1.0f, 1.0f, 1.0f}, alpha, 1.0f});
            }
            if (cell.walls[3]) { // left
                mazeLines.push_back({{sx, sy, 0}, {1.0f, 1.0f, 1.0f}, alpha, 1.0f});
                mazeLines.push_back({{sx, sy + cellHeight, 0}, {1.0f, 1.0f, 1.0f}, alpha, 1.0f});
            }
        }
    }
    
        // After generating lines
    printf("Generated %zu maze lines\n", mazeLines.size());

    // Draw first, then handle matrices
    if (!mazeLines.empty()) {
        glUseProgram(mazeShaderProgram);

        // Set up orthographic projection for 2D
        glm::mat4 model = glm::mat4(1.0f);
        glm::mat4 view = glm::mat4(1.0f);
        glm::mat4 proj = glm::ortho(-1.0f, 1.0f, -1.0f, 1.0f, -1.0f, 1.0f);

        glUniformMatrix4fv(glGetUniformLocation(mazeShaderProgram, "model"), 1, GL_FALSE, glm::value_ptr(model));
        glUniformMatrix4fv(glGetUniformLocation(mazeShaderProgram, "view"), 1, GL_FALSE, glm::value_ptr(view));
        glUniformMatrix4fv(glGetUniformLocation(mazeShaderProgram, "projection"), 1, GL_FALSE, glm::value_ptr(proj));

        glBindVertexArray(mazeVAO);
        glBindBuffer(GL_ARRAY_BUFFER, mazeVBO);
        glBufferSubData(GL_ARRAY_BUFFER, 0, mazeLines.size() * sizeof(VertexData), mazeLines.data());

        glLineWidth(1.0f);
        glDrawArrays(GL_LINES, 0, mazeLines.size());

        glBindBuffer(GL_ARRAY_BUFFER, 0);
        glBindVertexArray(0);

        // Restore depth test for the rest of the scene
        glEnable(GL_DEPTH_TEST);
         glUseProgram(shaderProgram);  // Add this line
    }
    glEnable(GL_DEPTH_TEST);
}

// Triangle indices (two per face) into the 8-corner layout of faceVertices.
// Every triangle is wound counter-clockwise when viewed from outside the cube,
// i.e. its normal points outward.
const std::array<GLuint, 36> cubeIndices = {
    // Front face (z = +0.5)
    0, 1, 2,
    2, 3, 0,
    // Right face (x = +0.5)
    1, 5, 6,
    6, 2, 1,
    // Back face (z = -0.5)
    5, 4, 7,
    7, 6, 5,
    // Left face (x = -0.5)
    4, 0, 3,
    3, 7, 4,
    // Top face (y = +0.5)
    3, 2, 6,
    6, 7, 3,
    // Bottom face (y = -0.5)
    4, 5, 1,
    1, 0, 4
};

// One colour per face, listed in the same face order as cubeIndices.
// NOTE(review): confirm the indexing where faceColors is used matches these labels.
const std::array<glm::vec3, 6> faceColors = {
    glm::vec3(1.0f, 0.0f, 0.0f), // Front - Red
    glm::vec3(0.0f, 1.0f, 0.0f), // Right - Green
    glm::vec3(0.0f, 0.0f, 1.0f), // Back - Blue
    glm::vec3(1.0f, 1.0f, 0.0f), // Left - Yellow
    glm::vec3(1.0f, 0.5f, 0.0f), // Top - Orange
    glm::vec3(1.0f, 1.0f, 1.0f)  // Bottom - White
};

// Geometry shader for the Rubik's cube.
// - While morphFactor < 0.01 it passes each triangle through unchanged.
// - Otherwise it works on the perspective-divided (NDC) positions of the input
//   triangle. Each point is blended towards a sphere of radius 0.5 around the NDC
//   origin, and the triangle is split into 4 sub-triangles (morphFactor <= 0.5)
//   or 6 (morphFactor > 0.5).
//
// Each emitted vertex now gets its own clip-space w and a world-space FragPos,
// both rebuilt perspective-correctly from its screen-space barycentric weights.
// Previously every vertex reused gl_in[0]'s w, which broke perspective
// interpolation. FragPos was also model * <NDC position>, which is in no
// meaningful space, so lighting jumped as soon as the morph started.
// max_vertices is 18, the most this shader ever emits (6 triangles x 3).
const char* geometryShaderSourceRubik = R"(
    #version 330 core
    layout (triangles) in;
    layout (triangle_strip, max_vertices = 18) out;
    
    in vec3 vFragPos[];
    in vec3 vNormal[];
    in vec3 vColor[];
    
    out vec3 FragPos;
    out vec3 Normal;
    out vec3 Color;
    
    uniform float morphFactor;
    uniform mat4 model;  // No longer read by this stage; kept so the program interface is unchanged
    
    // Calculate the position on a sphere
    vec3 spherifyPoint(vec3 pos) {
        // Normalize the position to get the direction
        vec3 dir = normalize(pos);
        // Use a consistent radius for the sphere
        float radius = 0.5;
        // Create the sphere position
        return dir * radius;
    }
    
    // Calculate the midpoint between two vertices and project to sphere if morphing
    vec3 getMidPoint(vec3 p1, vec3 p2) {
        vec3 mid = (p1 + p2) * 0.5;
        
        if (morphFactor > 0.0) {
            // Calculate the sphere point for the midpoint
            vec3 spherePoint = spherifyPoint(mid);
            // Blend between cube and sphere
            mid = mix(mid, spherePoint, morphFactor);
        }
        
        return mid;
    }
    
    // Calculate normal for a point (blend of cube-face normal and sphere normal)
    vec3 calculateNormal(vec3 pos) {
        // For sphere, normal is direction from center
        vec3 sphereNormal = normalize(pos);
        // For cube, normal is based on the largest component
        vec3 cubeNormal = vec3(0.0);
        float maxComp = max(abs(pos.x), max(abs(pos.y), abs(pos.z)));
        
        if (abs(pos.x) == maxComp) cubeNormal.x = sign(pos.x);
        else if (abs(pos.y) == maxComp) cubeNormal.y = sign(pos.y);
        else if (abs(pos.z) == maxComp) cubeNormal.z = sign(pos.z);
        
        return normalize(mix(cubeNormal, sphereNormal, morphFactor));
    }
    
    // Calculate color for a point
    vec3 interpolateColor(vec3 color1, vec3 color2) {
        return (color1 + color2) * 0.5;
    }
    
    // Emit one vertex.
    //   ndcPos : position after the perspective divide (and morphing)
    //   bary   : screen-space barycentric weights w.r.t. the input triangle
    // 1/w and attribute/w vary linearly in screen space, so both w and
    // FragPos are reconstructed perspective-correctly from bary.
    void emitVertex(vec3 ndcPos, vec3 normal, vec3 color, vec3 bary) {
        vec3 invW = vec3(1.0 / gl_in[0].gl_Position.w,
                         1.0 / gl_in[1].gl_Position.w,
                         1.0 / gl_in[2].gl_Position.w);
        vec3 weights = bary * invW;
        float w = 1.0 / (weights.x + weights.y + weights.z);
        
        FragPos = (weights.x * vFragPos[0] + weights.y * vFragPos[1] + weights.z * vFragPos[2]) * w;
        Normal = normalize(normal);
        Color = color;
        gl_Position = vec4(ndcPos * w, w);
        EmitVertex();
    }
    
    void main() {
        // If not morphing much, just pass through the original triangles
        if (morphFactor < 0.01) {
            for (int i = 0; i < 3; i++) {
                FragPos = vFragPos[i];
                Normal = vNormal[i];
                Color = vColor[i];
                gl_Position = gl_in[i].gl_Position;
                EmitVertex();
            }
            EndPrimitive();
            return;
        }
        
        // Perspective-divided (NDC) positions of the input vertices
        vec3 v0 = gl_in[0].gl_Position.xyz / gl_in[0].gl_Position.w;
        vec3 v1 = gl_in[1].gl_Position.xyz / gl_in[1].gl_Position.w;
        vec3 v2 = gl_in[2].gl_Position.xyz / gl_in[2].gl_Position.w;
        
        // Apply spherification to original vertices
        vec3 sv0 = mix(v0, spherifyPoint(v0), morphFactor);
        vec3 sv1 = mix(v1, spherifyPoint(v1), morphFactor);
        vec3 sv2 = mix(v2, spherifyPoint(v2), morphFactor);
        
        // Calculate midpoints on morphed shape
        vec3 m01 = getMidPoint(sv0, sv1);
        vec3 m12 = getMidPoint(sv1, sv2);
        vec3 m20 = getMidPoint(sv2, sv0);
        
        // Normals (computed from the morphed NDC positions)
        vec3 n0 = calculateNormal(sv0);
        vec3 n1 = calculateNormal(sv1);
        vec3 n2 = calculateNormal(sv2);
        vec3 n01 = calculateNormal(m01);
        vec3 n12 = calculateNormal(m12);
        vec3 n20 = calculateNormal(m20);
        
        // Calculate colors
        vec3 c0 = vColor[0];
        vec3 c1 = vColor[1];
        vec3 c2 = vColor[2];
        vec3 c01 = interpolateColor(c0, c1);
        vec3 c12 = interpolateColor(c1, c2);
        vec3 c20 = interpolateColor(c2, c0);
        
        // Screen-space barycentric weights of every emitted point
        const vec3 b0  = vec3(1.0, 0.0, 0.0);
        const vec3 b1  = vec3(0.0, 1.0, 0.0);
        const vec3 b2  = vec3(0.0, 0.0, 1.0);
        const vec3 b01 = vec3(0.5, 0.5, 0.0);
        const vec3 b12 = vec3(0.0, 0.5, 0.5);
        const vec3 b20 = vec3(0.5, 0.0, 0.5);
        
        // For higher morph factors, create more subdivisions for smoother sphere
        if (morphFactor > 0.5) {
            // Calculate center of triangle
            vec3 center = (sv0 + sv1 + sv2) / 3.0;
            center = mix(center, spherifyPoint(center), morphFactor);
            vec3 nCenter = calculateNormal(center);
            vec3 cCenter = (c0 + c1 + c2) / 3.0;
            const vec3 bCenter = vec3(1.0 / 3.0);
            
            // Emit 6 triangles fanned around the centre
            // Triangle 1: v0, m01, center
            emitVertex(sv0, n0, c0, b0);
            emitVertex(m01, n01, c01, b01);
            emitVertex(center, nCenter, cCenter, bCenter);
            EndPrimitive();
            
            // Triangle 2: m01, v1, center
            emitVertex(m01, n01, c01, b01);
            emitVertex(sv1, n1, c1, b1);
            emitVertex(center, nCenter, cCenter, bCenter);
            EndPrimitive();
            
            // Triangle 3: v1, m12, center
            emitVertex(sv1, n1, c1, b1);
            emitVertex(m12, n12, c12, b12);
            emitVertex(center, nCenter, cCenter, bCenter);
            EndPrimitive();
            
            // Triangle 4: m12, v2, center
            emitVertex(m12, n12, c12, b12);
            emitVertex(sv2, n2, c2, b2);
            emitVertex(center, nCenter, cCenter, bCenter);
            EndPrimitive();
            
            // Triangle 5: v2, m20, center
            emitVertex(sv2, n2, c2, b2);
            emitVertex(m20, n20, c20, b20);
            emitVertex(center, nCenter, cCenter, bCenter);
            EndPrimitive();
            
            // Triangle 6: m20, v0, center
            emitVertex(m20, n20, c20, b20);
            emitVertex(sv0, n0, c0, b0);
            emitVertex(center, nCenter, cCenter, bCenter);
            EndPrimitive();
        } else {
            // Emit 4 triangles (basic subdivision)
            // Triangle 1: v0, m01, m20
            emitVertex(sv0, n0, c0, b0);
            emitVertex(m01, n01, c01, b01);
            emitVertex(m20, n20, c20, b20);
            EndPrimitive();
            
            // Triangle 2: m01, v1, m12
            emitVertex(m01, n01, c01, b01);
            emitVertex(sv1, n1, c1, b1);
            emitVertex(m12, n12, c12, b12);
            EndPrimitive();
            
            // Triangle 3: m20, m12, v2
            emitVertex(m20, n20, c20, b20);
            emitVertex(m12, n12, c12, b12);
            emitVertex(sv2, n2, c2, b2);
            EndPrimitive();
            
            // Triangle 4: m01, m12, m20 (middle triangle)
            emitVertex(m01, n01, c01, b01);
            emitVertex(m12, n12, c12, b12);
            emitVertex(m20, n20, c20, b20);
            EndPrimitive();
        }
    }
)";

// Vertex shader for the Rubik's cube. It passes a world-space position, normal
// and time-cycled face colour to the geometry shader. Vertices are pushed away
// from rubikCenter in proportion to audioIntensity.
//
// vFragPos is now the pulsed position, so lighting is evaluated where the vertex
// is actually drawn. Previously it used the un-pulsed position. gl_Position is
// computed exactly as before.
const char* vertexShaderSourceRubik = R"(
    #version 330 core
    layout (location = 0) in vec3 aPos;
    layout (location = 1) in vec3 aNormal;
    layout (location = 2) in vec3 aColor;
    
    // Outputs to geometry shader
    out vec3 vFragPos;
    out vec3 vNormal;
    out vec3 vColor;
    
    uniform mat4 model;
    uniform mat4 view;
    uniform mat4 projection;
    uniform vec3 faceColor;
    uniform float morphFactor;
    uniform float time;
    uniform vec3 rubikCenter;
    uniform float audioIntensity;  // For audio reactivity
    
    // Color palette for rotation
    const vec3 palette[6] = vec3[6](
        vec3(1.0, 0.0, 0.0),  // Red
        vec3(1.0, 0.5, 0.0),  // Orange
        vec3(0.0, 0.0, 1.0),  // Blue
        vec3(0.0, 1.0, 0.0),  // Green
        vec3(1.0, 1.0, 0.0),  // Yellow
        vec3(1.0, 1.0, 1.0)   // White
    );
    
    // Snap color to the nearest palette entry (within distance 1.0, else entry 0)
    // and cycle through the palette once every 6 time units, blending neighbours.
    vec3 rotateColor(vec3 color, float t) {
        float minDist = 1.0;
        int closestIndex = 0;
        
        for(int i = 0; i < 6; i++) {
            float dist = distance(color, palette[i]);
            if(dist < minDist) {
                minDist = dist;
                closestIndex = i;
            }
        }
        
        float cycleTime = 6.0;
        float normalizedTime = mod(t, cycleTime) / cycleTime;
        float nextIndex = mod(float(closestIndex) + normalizedTime * 6.0, 6.0);
        int nextIndex1 = int(nextIndex);
        int nextIndex2 = int(mod(nextIndex + 1.0, 6.0));
        float blend = fract(nextIndex);
        
        return mix(palette[nextIndex1], palette[nextIndex2], blend);
    }
    
    void main() {
        // World-space position before the audio pulse
        vec3 restWorldPos = vec3(model * vec4(aPos, 1.0));
        
        // Vector from the cube centre to this vertex, scaled by the audio level
        vec3 fromCenter = restWorldPos - rubikCenter;
        float pulseStrength = 0.8; // How much the gap can grow
        vec3 gapPulse = fromCenter * audioIntensity * pulseStrength;
        
        // NOTE: gapPulse is a world-space vector but is added before the model
        // transform, so model's rotation/scale also act on it. Kept as-is to
        // preserve the current look.
        vec4 worldPos = model * vec4(aPos + gapPulse, 1.0);
        
        // The geometry shader receives the position that is actually rendered
        vFragPos = worldPos.xyz;
        
        // Pass normal
        vNormal = mat3(transpose(inverse(model))) * aNormal;
        
        // Calculate color with rotation effect
        vColor = rotateColor(faceColor, time);
        
        gl_Position = projection * view * worldPos;
    }
)";

// Fragment shader for the Rubik cubes. It is attached to `shaderProgram`,
// which drawPongPaddle also uses.
//
// Each fragment's incoming Color is snapped to the nearest of six palette
// colors. That index is advanced around the palette over a 6-second cycle,
// and the fragment is shaded with a blend of the two neighbouring per-color
// procedural patterns (rings, fire, waves, hex grid, checkerboard, sparkle).
//
// The result is lit by up to MAX_LIGHTS point lights:
//   - ambient, averaged over numLights (only when numLights > 0, so zero
//     lights yields black ambient instead of a 0/0 NaN);
//   - softened diffuse;
//   - Blinn-Phong specular.
// It is then quantized by a four-band cel-shading step.
//
// NOTE(review): the inputs are named FragPos/Normal/Color, while the vertex
// shader writes vFragPos/vNormal/vColor. The geometry shader presumably
// bridges the two; confirm the names match.
// NOTE(review): applyPattern starts minDist at 10.0, which is larger than any
// RGB distance, so its darkPattern fallback is never taken. Internal (gray)
// faces are therefore shaded as their nearest palette color.
const char* fragmentShaderSourceRubik = R"(
    #version 330 core
    in vec3 FragPos;
    in vec3 Normal;
    in vec3 Color;
    
    out vec4 FragColor;
    
    struct Light {
        vec3 position;
        vec3 color;
        float intensity;
    };
    
    #define MAX_LIGHTS 4
    uniform Light lights[MAX_LIGHTS];
    uniform int numLights;
    uniform vec3 viewPos;
    uniform float time;
    
    // Different pattern functions for each face color
    
    // Red face (Left): Concentric circles pattern
    vec3 redPattern(vec3 baseColor, vec3 position, float time) {
        float dist = length(position.yz * 2.0);
        float rings = sin(dist * 5.0 - time * 1.0) * 0.5 + 0.5;
        float pulse = sin(time * 0.5) * 0.1 + 0.9;
        
        return mix(baseColor * 0.8, baseColor * 1.4, rings) * pulse;
    }
    
    // Orange face (Right): Fire-like effect
    vec3 orangePattern(vec3 baseColor, vec3 position, float time) {
        float noise = sin(position.y * 6.0 + time * 1.5) * cos(position.z * 6.0 + time) * 0.5 + 0.5;
        float flicker = sin(time * 2.0) * 0.1 + 0.9;
        
        // Enhance reds and yellows for fire effect
        vec3 fireColor = mix(baseColor, vec3(1.0, 0.3, 0.0), 0.3);
        return mix(fireColor * 0.8, fireColor * 1.4, noise) * flicker;
    }
    
    // Yellow face (Back): Checkerboard pattern
    vec3 yellowPattern(vec3 baseColor, vec3 position, float time) {
        float checkSize = 3.0;
        float moving = time * 0.3;
        float checker = mod(floor(position.x * checkSize + moving) + 
                          floor(position.y * checkSize), 2.0);
        
        return mix(baseColor * 0.8, baseColor * 1.3, checker);
    }
    
    // Green face (Top): Hexagonal grid
    vec3 greenPattern(vec3 baseColor, vec3 position, float time) {
        vec2 p = position.xz;
        p = p * 3.0; // Scale
        
        // Hexagonal grid
        const float hexSize = 1.0;
        p.x *= 0.57735 * 2.0;
        p.y += mod(floor(p.x), 2.0) * 0.5;
        p = mod(p, hexSize) - 0.5 * hexSize;
        float hex = length(p);
        
        // Pulsating
        float pulse = sin(time * 0.5 + hex * 2.0) * 0.15 + 0.85;
        
        return mix(baseColor * 0.7, baseColor * 1.3, smoothstep(0.3, 0.4, hex)) * pulse;
    }
    
    // Blue face (Bottom): Wave pattern
    vec3 bluePattern(vec3 baseColor, vec3 position, float time) {
        float waves = sin(position.x * 6.0 + time) * 
                    cos(position.z * 6.0 + time * 0.7) * 0.5 + 0.5;
        
        // Add depth to blue
        return mix(baseColor * 0.7, vec3(0.0, 0.6, 1.0), waves * 0.7);
    }
    
    // White face (Front): Sparkle effect
    vec3 whitePattern(vec3 baseColor, vec3 position, float time) {
        float noise1 = sin(position.x * 10.0 + time * 0.3) * 
                      sin(position.y * 10.0 + time * 0.5) * 
                      sin(position.z * 10.0 + time * 0.7);
        float sparkle = smoothstep(0.7, 0.9, noise1);
        
        return mix(baseColor * 0.9, vec3(1.3), sparkle);
    }
    
    // Darker pattern for internal faces
    vec3 darkPattern(vec3 baseColor, vec3 position, float time) {
        float gradient = (sin(position.x * 2.0 + time * 0.5) + 
                        sin(position.y * 2.0 + time * 0.6) + 
                        sin(position.z * 2.0 + time * 0.3)) / 3.0;
        return baseColor * (0.4 + gradient * 0.1);
    }
    
    // Choose pattern based on color
    vec3 applyPattern(vec3 baseColor, vec3 position, float time) {
        // Indices of the two palette colors to blend between
        int color1Index = -1;
        int color2Index = -1;
        float blendFactor = 0.0;
        
        // For each original palette color
        const vec3 palette[6] = vec3[6](
            vec3(1.0, 0.0, 0.0),   // Red
            vec3(1.0, 0.5, 0.0),   // Orange
            vec3(0.0, 0.0, 1.0),   // Blue
            vec3(0.0, 1.0, 0.0),   // Green
            vec3(1.0, 1.0, 0.0),   // Yellow
            vec3(1.0, 1.0, 1.0)    // White
        );
        
        // Find the closest palette color
        float minDist = 10.0;
        for(int i = 0; i < 6; i++) {
            float dist = distance(baseColor, palette[i]);
            if(dist < minDist) {
                minDist = dist;
                color1Index = i;
            }
        }
        
        // Extract the cycling information from the base color
        // During color rotation, the vertex shader sends us a color that's
        // already a blend between two palette colors
        if(color1Index >= 0) {
            // Calculate the next color in the rotation
            float cycleTime = 6.0;
            float normalizedTime = mod(time, cycleTime) / cycleTime;
            float nextIndex = mod(float(color1Index) + normalizedTime * 6.0, 6.0);
            color1Index = int(nextIndex);
            color2Index = int(mod(float(color1Index) + 1.0, 6.0));
            blendFactor = fract(nextIndex);
        }
        
        // If we couldn't identify the color, use the dark pattern
        if(color1Index < 0) {
            return darkPattern(baseColor, position, time);
        }
        
        // Apply pattern for first color
        vec3 pattern1;
        switch(color1Index) {
            case 0: pattern1 = redPattern(palette[0], position, time); break;
            case 1: pattern1 = orangePattern(palette[1], position, time); break;
            case 2: pattern1 = bluePattern(palette[2], position, time); break;
            case 3: pattern1 = greenPattern(palette[3], position, time); break;
            case 4: pattern1 = yellowPattern(palette[4], position, time); break;
            case 5: pattern1 = whitePattern(palette[5], position, time); break;
            default: pattern1 = darkPattern(baseColor, position, time); break;
        }
        
        // Apply pattern for second color
        vec3 pattern2;
        switch(color2Index) {
            case 0: pattern2 = redPattern(palette[0], position, time); break;
            case 1: pattern2 = orangePattern(palette[1], position, time); break;
            case 2: pattern2 = bluePattern(palette[2], position, time); break;
            case 3: pattern2 = greenPattern(palette[3], position, time); break;
            case 4: pattern2 = yellowPattern(palette[4], position, time); break;
            case 5: pattern2 = whitePattern(palette[5], position, time); break;
            default: pattern2 = darkPattern(baseColor, position, time); break;
        }
        
        // Blend the two patterns
        return mix(pattern1, pattern2, blendFactor);
    }
    
    void main() {
        // Apply color pattern based on the face color
        vec3 texturedColor = applyPattern(Color, FragPos, time);
        
        // Higher ambient for better visibility
        float ambientStrength = 0.5;
        vec3 ambient = vec3(0.0);
        for(int i = 0; i < numLights; i++) {
            ambient += ambientStrength * lights[i].color;
        }
        // Average over the bound lights; with none bound, 0/0 would be NaN
        if (numLights > 0) {
            ambient /= float(numLights);
        }
        ambient *= texturedColor;
        
        // Combined lighting from all sources with reduced specular to emphasize patterns
        vec3 diffuse = vec3(0.0);
        vec3 specular = vec3(0.0);
        
        for(int i = 0; i < numLights; i++) {
            // Diffuse - softer
            vec3 lightDir = normalize(lights[i].position - FragPos);
            float diff = max(dot(normalize(Normal), lightDir), 0.0);
            // Smoother falloff
            diff = pow(diff, 0.8); // Make diffuse lighting more distributed
            diffuse += diff * lights[i].intensity * lights[i].color * texturedColor;
            
            // Reduced specular effect to avoid washing out patterns
            float specularStrength = 0.3;
            vec3 viewDir = normalize(viewPos - FragPos);
            vec3 halfwayDir = normalize(lightDir + viewDir);
            float spec = pow(max(dot(normalize(Normal), halfwayDir), 0.0), 16.0);
            specular += specularStrength * spec * lights[i].color;
        }
        
        // Combine all lighting with more weight on the patterns
        vec3 result = ambient + diffuse * 0.7 + specular;
        
        // Gentler cel-shading for better pattern visibility
        float brightness = dot(result, vec3(0.2126, 0.7152, 0.0722));
        
        if (brightness > 0.9) result *= 1.1;
        else if (brightness > 0.6) result *= 0.9;
        else if (brightness > 0.3) result *= 0.7;
        else result *= 0.5;
        
        FragColor = vec4(result, 1.0);
    }
)";

// OpenGL object names for the mirror floor geometry (created by
// initMirrorFloor) and the mirror shader program (created by
// initMirrorShader). Zero is GL's "no object" name, i.e. not created yet.
GLuint floorVAO{0};
GLuint floorVBO{0};
GLuint floorEBO{0};
GLuint mirrorShaderProgram{0};

// Vertex shader for the mirror floor.
// Consumes the interleaved layout built in initMirrorFloor:
//   location 0 = position, location 1 = normal, location 2 = texcoord.
// Forwards to mirrorFragmentShaderSource:
//   - the world-space position (model * aPos);
//   - the world-space normal (inverse-transpose of `model`);
//   - the texture coordinate.
// Outputs the clip-space position projection * view * model * aPos.
const char* mirrorVertexShaderSource = R"(
    #version 330 core
    layout (location = 0) in vec3 aPos;
    layout (location = 1) in vec3 aNormal;
    layout (location = 2) in vec2 aTexCoord;
    
    out vec3 FragPos;
    out vec3 Normal;
    out vec2 TexCoord;
    
    uniform mat4 model;
    uniform mat4 view;
    uniform mat4 projection;
    
    void main() {
        FragPos = vec3(model * vec4(aPos, 1.0));
        Normal = mat3(transpose(inverse(model))) * aNormal;
        TexCoord = aTexCoord;
        gl_Position = projection * view * model * vec4(aPos, 1.0);
    }
)";

// Fragment shader for the translucent mirror floor. It shades a fixed
// blue-grey base color with:
//   - a faint grid (lines where fract(TexCoord * 5) < 0.01 are darkened 5%);
//   - a slow sine shimmer driven by `time`;
//   - per-light diffuse and Blinn-Phong specular (exponent 32);
//   - a Fresnel-style brightness boost.
// Alpha is 0.7 * edgeFade, which fades toward the quad's edges. In
// displayRubik the floor is drawn with alpha blending after the reflected
// cubes, so they show through it.
// NOTE(review): `view` and `projection` are declared but unused in this stage.
// NOTE(review): the light loop indexes lights[i] for i < numLights with no
// clamp to MAX_LIGHTS (4); callers must not set numLights above 4.
const char* mirrorFragmentShaderSource = R"(
    #version 330 core
    in vec3 FragPos;
    in vec3 Normal;
    in vec2 TexCoord;
    
    out vec4 FragColor;
    
    struct Light {
        vec3 position;
        vec3 color;
        float intensity;
    };
    
    #define MAX_LIGHTS 4
    uniform Light lights[MAX_LIGHTS];
    uniform int numLights;
    uniform vec3 viewPos;
    uniform float time;
    uniform mat4 view;
    uniform mat4 projection;
    
    void main() {
        // Simplified base color
        vec3 baseColor = vec3(0.4, 0.6, 0.8);
        
        // Add subtle grid pattern
        float gridSize = 1.0;
        float gridLine = 0.01;
        float gridX = mod(TexCoord.x * 5.0, gridSize);
        float gridY = mod(TexCoord.y * 5.0, gridSize);
        
        if (gridX < gridLine || gridY < gridLine) {
            baseColor *= 0.95;
        }
        
        // Simple wave effect
        float wave = sin(TexCoord.x * 4.0 + TexCoord.y * 4.0 + time * 0.2) * 0.02 + 0.98;
        baseColor *= wave;
        
        // Basic lighting
        vec3 normal = normalize(Normal);
        vec3 viewDir = normalize(viewPos - FragPos);
        
        // Calculate Fresnel effect for mirror
        float fresnel = pow(1.0 - max(dot(normal, viewDir), 0.0), 3.0);
        fresnel = 0.2 + 0.6 * fresnel;
        
        // Ambient light
        vec3 ambient = vec3(0.1) * baseColor;
        
        // Directional light calculations
        vec3 result = ambient;
        
        // Process lights
        for(int i = 0; i < numLights; i++) {
            // Diffuse
            vec3 lightDir = normalize(lights[i].position - FragPos);
            float diff = max(dot(normal, lightDir), 0.0);
            vec3 diffuse = diff * lights[i].intensity * lights[i].color * baseColor;
            
            // Specular
            vec3 halfwayDir = normalize(lightDir + viewDir);
            float spec = pow(max(dot(normal, halfwayDir), 0.0), 32.0);
            vec3 specular = spec * lights[i].color * 0.5;
            
            result += diffuse + specular;
        }
        
        // Adjust brightness
        result = result * (0.8 + fresnel * 0.4);
        
        // Make edges fade out
        float distFromCenter = length(TexCoord - vec2(0.5, 0.5)) * 2.0;
        float edgeFade = 1.0 - smoothstep(0.7, 1.0, distFromCenter);
        
        FragColor = vec4(result, 0.7 * edgeFade);
    }
)";

// Initialize mirror shader
bool initMirrorShader() {
    // Create vertex shader
    GLuint vertexShader = glCreateShader(GL_VERTEX_SHADER);
    glShaderSource(vertexShader, 1, &mirrorVertexShaderSource, NULL);
    glCompileShader(vertexShader);
    
    // Check for errors
    GLint success;
    GLchar infoLog[512];
    glGetShaderiv(vertexShader, GL_COMPILE_STATUS, &success);
    if (!success) {
        glGetShaderInfoLog(vertexShader, 512, NULL, infoLog);
        printf("Error: Mirror vertex shader compilation failed: %s\n", infoLog);
        return false;
    }
    
    // Create fragment shader
    GLuint fragmentShader = glCreateShader(GL_FRAGMENT_SHADER);
    glShaderSource(fragmentShader, 1, &mirrorFragmentShaderSource, NULL);
    glCompileShader(fragmentShader);
    
    // Check for errors
    glGetShaderiv(fragmentShader, GL_COMPILE_STATUS, &success);
    if (!success) {
        glGetShaderInfoLog(fragmentShader, 512, NULL, infoLog);
        printf("Error: Mirror fragment shader compilation failed: %s\n", infoLog);
        return false;
    }
    
    // Create shader program
    mirrorShaderProgram = glCreateProgram();
    glAttachShader(mirrorShaderProgram, vertexShader);
    glAttachShader(mirrorShaderProgram, fragmentShader);
    glLinkProgram(mirrorShaderProgram);
    
    // Check for errors
    glGetProgramiv(mirrorShaderProgram, GL_LINK_STATUS, &success);
    if (!success) {
        glGetProgramInfoLog(mirrorShaderProgram, 512, NULL, infoLog);
        printf("Error: Mirror shader program linking failed: %s\n", infoLog);
        return false;
    }
    
    // Clean up shaders
    glDeleteShader(vertexShader);
    glDeleteShader(fragmentShader);
    
    return true;
}

// Builds the GPU geometry for the mirror floor. The floor is one quad in the
// plane y = floorY, spanning [-floorSize, floorSize] on X and Z, drawn as two
// indexed triangles.
//
// Each vertex is 8 interleaved floats, matching mirrorVertexShaderSource:
//   location 0: position (vec3)
//   location 1: normal   (vec3, always +Y)
//   location 2: texcoord (vec2, 0..1 across the quad)
//
// The created objects are stored in the globals floorVAO, floorVBO and
// floorEBO. The EBO binding is recorded in floorVAO. Always returns true;
// GL errors are not checked here.
bool initMirrorFloor() {
    const float floorSize = 12.0f;  // half-extent of the quad on X and Z
    // Height of the mirror plane. displayRubik reflects the cubes about
    // y = 0 as well, so keep the two values in sync.
    const float floorY = 0.0f;

    const float floorVertices[] = {
        // Positions                    // Normals         // Texture Coords
        -floorSize, floorY,  floorSize,  0.0f, 1.0f, 0.0f,  0.0f, 0.0f,
         floorSize, floorY,  floorSize,  0.0f, 1.0f, 0.0f,  1.0f, 0.0f,
         floorSize, floorY, -floorSize,  0.0f, 1.0f, 0.0f,  1.0f, 1.0f,
        -floorSize, floorY, -floorSize,  0.0f, 1.0f, 0.0f,  0.0f, 1.0f
    };

    // Two counter-clockwise triangles (seen from +Y) covering the quad
    const unsigned int floorIndices[] = {
        0, 1, 2,
        2, 3, 0
    };

    // Generate and bind VAO
    glGenVertexArrays(1, &floorVAO);
    glBindVertexArray(floorVAO);

    // Generate, bind and fill the VBO
    glGenBuffers(1, &floorVBO);
    glBindBuffer(GL_ARRAY_BUFFER, floorVBO);
    glBufferData(GL_ARRAY_BUFFER, sizeof(floorVertices), floorVertices, GL_STATIC_DRAW);

    // Generate, bind and fill the EBO (binding is captured by the bound VAO)
    glGenBuffers(1, &floorEBO);
    glBindBuffer(GL_ELEMENT_ARRAY_BUFFER, floorEBO);
    glBufferData(GL_ELEMENT_ARRAY_BUFFER, sizeof(floorIndices), floorIndices, GL_STATIC_DRAW);

    // Position attribute
    glVertexAttribPointer(0, 3, GL_FLOAT, GL_FALSE, 8 * sizeof(float), (void*)0);
    glEnableVertexAttribArray(0);

    // Normal attribute
    glVertexAttribPointer(1, 3, GL_FLOAT, GL_FALSE, 8 * sizeof(float), (void*)(3 * sizeof(float)));
    glEnableVertexAttribArray(1);

    // Texture coordinate attribute
    glVertexAttribPointer(2, 2, GL_FLOAT, GL_FALSE, 8 * sizeof(float), (void*)(6 * sizeof(float)));
    glEnableVertexAttribArray(2);

    // Unbind
    glBindVertexArray(0);

    return true;
}

// Update the initShaders function to also initialize the mirror shader
bool initShadersRubik() {
    printf("Debug: Creating shaders...\n");
    
    // Create regular shaders
    GLuint vertexShader = glCreateShader(GL_VERTEX_SHADER);
    GLuint geometryShader = glCreateShader(GL_GEOMETRY_SHADER);
    GLuint fragmentShader = glCreateShader(GL_FRAGMENT_SHADER);
    
    // Set shader source
    glShaderSource(vertexShader, 1, &vertexShaderSourceRubik, NULL);
    glShaderSource(geometryShader, 1, &geometryShaderSourceRubik, NULL);
    glShaderSource(fragmentShader, 1, &fragmentShaderSourceRubik, NULL);
    
    // Compile vertex shader
    glCompileShader(vertexShader);
    
    // Check for compile errors
    GLint success;
    GLchar infoLog[512];
    glGetShaderiv(vertexShader, GL_COMPILE_STATUS, &success);
    if (!success) {
        glGetShaderInfoLog(vertexShader, 512, NULL, infoLog);
        printf("Error: Vertex shader compilation failed: %s\n", infoLog);
        return false;
    }
    
    // Compile geometry shader
    glCompileShader(geometryShader);
    
    // Check for compile errors
    glGetShaderiv(geometryShader, GL_COMPILE_STATUS, &success);
    if (!success) {
        glGetShaderInfoLog(geometryShader, 512, NULL, infoLog);
        printf("Error: Geometry shader compilation failed: %s\n", infoLog);
        return false;
    }
    
    // Compile fragment shader
    glCompileShader(fragmentShader);
    
    // Check for compile errors
    glGetShaderiv(fragmentShader, GL_COMPILE_STATUS, &success);
    if (!success) {
        glGetShaderInfoLog(fragmentShader, 512, NULL, infoLog);
        printf("Error: Fragment shader compilation failed: %s\n", infoLog);
        return false;
    }
    
    // Create shader program
    shaderProgram = glCreateProgram();
    
    // Attach shaders to program
    glAttachShader(shaderProgram, vertexShader);
    glAttachShader(shaderProgram, geometryShader);
    glAttachShader(shaderProgram, fragmentShader);
    
    // Link program
    glLinkProgram(shaderProgram);
    
    // Check for linking errors
    glGetProgramiv(shaderProgram, GL_LINK_STATUS, &success);
    if (!success) {
        glGetProgramInfoLog(shaderProgram, 512, NULL, infoLog);
        printf("Error: Shader program linking failed: %s\n", infoLog);
        return false;
    }
    
    // Delete shaders as they're linked into our program now and no longer necessary
    glDeleteShader(vertexShader);
    glDeleteShader(geometryShader);
    glDeleteShader(fragmentShader);
    
    // Initialize mirror shader
    if (!initMirrorShader()) {
        printf("Error: Failed to initialize mirror shader\n");
        return false;
    }
    
    printf("Debug: Shaders created successfully\n");
    return true;
}

// Rebuilds the global `cubes` list as a 3x3x3 arrangement without the centre
// piece (26 cubes). Each cube's physics and transform state is reset, and the
// layout is then snapshotted with initOriginalPositions().
//
// Layout: the cube at grid coordinates (x, y, z), each in {-1, 0, 1}, sits at
// (x, y, z) * (1.1 + params.spacing), and its layer indices are x, y, z
// unchanged.
//
// cube.color priority: +z white, then -z yellow, then +y green, then -y blue,
// then +x orange, then -x red. A cube touching none of those faces gets gray
// 0.8.
void initCubes() {
    printf("Debug: Initializing cubes...\n");
    cubes.clear();

    // Centre-to-centre distance between neighbouring cubes
    const float spacing = 1.1f + params.spacing;

    for (int x = -1; x <= 1; ++x) {
        for (int y = -1; y <= 1; ++y) {
            for (int z = -1; z <= 1; ++z) {
                const bool isHiddenCenter = (x == 0) && (y == 0) && (z == 0);
                if (isHiddenCenter) {
                    continue;
                }

                CubeFace cube;
                cube.position = spacing * glm::vec3(x, y, z);

                // Grid layer along each axis: -1, 0 or 1
                cube.layer[0] = x;
                cube.layer[1] = y;
                cube.layer[2] = z;

                // Highest-priority outer face decides the color
                if      (z ==  1) cube.color = glm::vec3(1.0f, 1.0f, 1.0f);  // White
                else if (z == -1) cube.color = glm::vec3(1.0f, 1.0f, 0.0f);  // Yellow
                else if (y ==  1) cube.color = glm::vec3(0.0f, 1.0f, 0.0f);  // Green
                else if (y == -1) cube.color = glm::vec3(0.0f, 0.0f, 1.0f);  // Blue
                else if (x ==  1) cube.color = glm::vec3(1.0f, 0.5f, 0.0f);  // Orange
                else if (x == -1) cube.color = glm::vec3(1.0f, 0.0f, 0.0f);  // Red
                else              cube.color = glm::vec3(0.8f);              // Gray

                // Start at rest with identity transforms
                cube.velocity = glm::vec3(0.0f);
                cube.angularVelocity = glm::vec3(0.0f);
                cube.mass = 1.0f;
                cube.onGround = false;
                cube.animationTransform = glm::mat4(1.0f);
                cube.baseTransform = glm::mat4(1.0f);

                cubes.push_back(cube);
            }
        }
    }

    // Remember the starting layout for the restoration phase
    initOriginalPositions();

    printf("Debug: Created %zu cubes\n", cubes.size());
}

bool initBuffersRubik() {
    printf("Debug: Creating buffers...\n");
    
    // Create and bind VAO
    glGenVertexArrays(1, &vao);
    glBindVertexArray(vao);
    
    // Create and bind VBO
    glGenBuffers(1, &vbo);
    glBindBuffer(GL_ARRAY_BUFFER, vbo);
    
    // Create vertex data with positions, normals, and colors
    std::vector<GLfloat> vertexData;
    for (size_t i = 0; i < faceVertices.size(); i += 3) {
        // Position
        vertexData.push_back(faceVertices[i]);
        vertexData.push_back(faceVertices[i + 1]);
        vertexData.push_back(faceVertices[i + 2]);
        
        // Normal (calculate based on face)
        int faceIndex = i / 12;  // Each face has 4 vertices
        glm::vec3 normal;
        switch (faceIndex) {
            case 0: normal = glm::vec3(0.0f, 0.0f, 1.0f); break;   // Front
            case 1: normal = glm::vec3(1.0f, 0.0f, 0.0f); break;   // Right
            case 2: normal = glm::vec3(0.0f, 0.0f, -1.0f); break;  // Back
            case 3: normal = glm::vec3(-1.0f, 0.0f, 0.0f); break;  // Left
            case 4: normal = glm::vec3(0.0f, 1.0f, 0.0f); break;   // Top
            case 5: normal = glm::vec3(0.0f, -1.0f, 0.0f); break;  // Bottom
        }
        vertexData.push_back(normal.x);
        vertexData.push_back(normal.y);
        vertexData.push_back(normal.z);
        
        // Color (will be set per face in display)
        vertexData.push_back(1.0f);
        vertexData.push_back(1.0f);
        vertexData.push_back(1.0f);
    }
    
    glBufferData(GL_ARRAY_BUFFER, vertexData.size() * sizeof(GLfloat), vertexData.data(), GL_STATIC_DRAW);
    
    // Create and bind EBO
    glGenBuffers(1, &ebo);
    glBindBuffer(GL_ELEMENT_ARRAY_BUFFER, ebo);
    glBufferData(GL_ELEMENT_ARRAY_BUFFER, cubeIndices.size() * sizeof(GLuint), cubeIndices.data(), GL_STATIC_DRAW);
    
    // Set vertex attributes
    // Position
    glVertexAttribPointer(0, 3, GL_FLOAT, GL_FALSE, 9 * sizeof(float), (void*)0);
    glEnableVertexAttribArray(0);
    
    // Normal
    glVertexAttribPointer(1, 3, GL_FLOAT, GL_FALSE, 9 * sizeof(float), (void*)(3 * sizeof(float)));
    glEnableVertexAttribArray(1);
    
    // Color
    glVertexAttribPointer(2, 3, GL_FLOAT, GL_FALSE, 9 * sizeof(float), (void*)(6 * sizeof(float)));
    glEnableVertexAttribArray(2);
    
    // Initialize cube positions
    initCubes();
    
    // Unbind buffers
    glBindBuffer(GL_ARRAY_BUFFER, 0);
    glBindVertexArray(0);
    
    printf("Debug: Buffers created successfully\n");
    return true;
}

// Draw Pong paddle as a simple colored cuboid
void drawPongPaddle(const PongPaddle& paddle) {
    // Use the regular shader program
    glUseProgram(shaderProgram);
    
    // Set shader uniforms
    GLint modelLoc = glGetUniformLocation(shaderProgram, "model");
    GLint faceColorLoc = glGetUniformLocation(shaderProgram, "faceColor");
    
    // Create model matrix for the paddle
    glm::mat4 model = glm::translate(glm::mat4(1.0f), paddle.position);
    model = glm::scale(model, paddle.size);
    
    // Set model matrix
    glUniformMatrix4fv(modelLoc, 1, GL_FALSE, glm::value_ptr(model));
    
    // Set paddle color
    glUniform3fv(faceColorLoc, 1, glm::value_ptr(paddle.color));
    
    // Draw the paddle using the cube's vertex data
    glBindVertexArray(vao);
    glDrawElements(GL_TRIANGLES, 36, GL_UNSIGNED_INT, 0);
    glBindVertexArray(0);
}

// Draw both paddles for the Pong game
void drawPongPaddles() {
    drawPongPaddle(leftPaddle);
    drawPongPaddle(rightPaddle);
}

// Update the display function to draw paddles only for Pong mode
void displayRubik() {
    // Store current window
    int currentWindow = glutGetWindow();
    
    // Clear buffers
    glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT);
    
    // Set up lights
    struct Light {
        glm::vec3 position;
        glm::vec3 color;
        float intensity;
    };
    
    // Create multiple lights with more stable positions and brighter intensities
    Light lights[4] = {
        // Main white light - fixed overhead position
        {
            glm::vec3(0.0f, 8.0f, 0.0f),
            glm::vec3(1.0f, 1.0f, 1.0f),
            params.lightIntensity * 1.7f  // Increase intensity
        },
        // Subtle red light - fixed position
        {
            glm::vec3(-6.0f, 2.0f, 0.0f),
            glm::vec3(1.0f, 0.2f, 0.2f),
            params.lightIntensity * 0.6f  // Increase intensity
        },
        // Subtle blue light - fixed position
        {
            glm::vec3(0.0f, 2.0f, -6.0f),
            glm::vec3(0.2f, 0.2f, 1.0f),
            params.lightIntensity * 0.6f  // Increase intensity
        },
        // Subtle green fill light - fixed position
        {
            glm::vec3(6.0f, -2.0f, 6.0f),
            glm::vec3(0.2f, 1.0f, 0.2f),
            params.lightIntensity * 0.5f  // Increase intensity
        }
    };
    
    // For Pong mode, position camera for better view
    if (pongMode) {
        glm::vec3 rubikCenter(0.0f);
        for (const auto& cube : cubes) rubikCenter += cube.position;
        rubikCenter /= (float)cubes.size();
        GLint rubikCenterLoc = glGetUniformLocation(shaderProgram, "rubikCenter");
        glUniform3fv(rubikCenterLoc, 1, glm::value_ptr(rubikCenter));

        // Override camera position to look at the play field from a better angle
        float playfieldMidY = (PONG_CEILING_Y + PONG_FLOOR_Y) * 0.5f;
        float playfieldHeight = PONG_CEILING_Y - PONG_FLOOR_Y;
        float verticalFov = glm::radians(45.0f); // or whatever your FOV is
        // Calculate Z so playfield fits in view
        float zDist = (playfieldHeight * 0.5f) / tan(verticalFov * 0.5f);
        cameraPos = glm::vec3(0.0f, PONG_CEILING_Y, zDist + 12.0f); // +2 for margin
    } 

    // Set up view and projection matrices
    glm::vec3 cameraTarget = glm::vec3(0.0f, 0.0f, 0.0f);
    glm::vec3 cameraDirection = glm::normalize(cameraPos - cameraTarget);
    glm::vec3 cameraRight = glm::normalize(glm::cross(glm::vec3(0.0f, 1.0f, 0.0f), cameraDirection));
    glm::vec3 cameraUp = glm::cross(cameraDirection, cameraRight);
    
    // Use cameraPos directly instead of scaling it
    glm::mat4 view = glm::lookAt(
        cameraPos,  // Use actual camera position instead of scaled version
        cameraTarget,
        cameraUp
    );
    
    int w = glutGet(GLUT_WINDOW_WIDTH);
    int h = glutGet(GLUT_WINDOW_HEIGHT);
    float aspectRatio = (float)w / (float)h;
    projection = glm::perspective(glm::radians(45.0f), aspectRatio, 0.1f, 100.0f);
    //glm::mat4 projection = glm::perspective(glm::radians(45.0f), 1.0f, 0.1f, 100.0f);
    
    // First draw the reflected cubes with simplified approach (no stencil buffer)
    // Create reflection matrix
    float floorY = -0.0f; // This should match the value from initMirrorFloor
    glm::mat4 reflectionMatrix = glm::scale(
        glm::translate(glm::mat4(1.0f), glm::vec3(0.0f, floorY, 0.0f)),
        glm::vec3(1.0f, -1.0f, 1.0f)
    );
    
    // Use regular shader program
    glUseProgram(shaderProgram);
    
    // Set morph factor uniform
    GLint morphFactorLoc = glGetUniformLocation(shaderProgram, "morphFactor");
    glUniform1f(morphFactorLoc, params.morphFactor);
    
    // Set time uniform for animations
    GLint timeLoc = glGetUniformLocation(shaderProgram, "time");
    glUniform1f(timeLoc, globalTime);
    
    // Set light uniforms
    GLint numLightsLoc = glGetUniformLocation(shaderProgram, "numLights");
    glUniform1i(numLightsLoc, 4);
    
    for(int i = 0; i < 4; i++) {
        std::string base = "lights[" + std::to_string(i) + "].";
        GLint posLoc = glGetUniformLocation(shaderProgram, (base + "position").c_str());
        GLint colorLoc = glGetUniformLocation(shaderProgram, (base + "color").c_str());
        GLint intensityLoc = glGetUniformLocation(shaderProgram, (base + "intensity").c_str());
        
        glUniform3fv(posLoc, 1, glm::value_ptr(lights[i].position));
        glUniform3fv(colorLoc, 1, glm::value_ptr(lights[i].color));
        glUniform1f(intensityLoc, lights[i].intensity);
    }
    
    // Set view position for specular calculations
    GLint viewPosLoc = glGetUniformLocation(shaderProgram, "viewPos");
    glUniform3fv(viewPosLoc, 1, glm::value_ptr(cameraPos));
    
    // Set view and projection uniforms
    GLint viewLoc = glGetUniformLocation(shaderProgram, "view");
    GLint projLoc = glGetUniformLocation(shaderProgram, "projection");
    
    glUniformMatrix4fv(viewLoc, 1, GL_FALSE, glm::value_ptr(view));
    glUniformMatrix4fv(projLoc, 1, GL_FALSE, glm::value_ptr(projection));
    
    // Bind VAO for cubes
    glBindVertexArray(vao);
    
    // Set up blending for reflections
    glEnable(GL_BLEND);
    glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA);
    
    // Set up depth test for reflections
    glEnable(GL_DEPTH_TEST);
    glDepthMask(GL_TRUE);
    
    // We need to flip face culling for reflections
    glEnable(GL_CULL_FACE);
    glCullFace(GL_FRONT);
    
    // Draw each cube reflected
        GLint modelLoc = glGetUniformLocation(shaderProgram, "model");
    GLint faceColorLoc = glGetUniformLocation(shaderProgram, "faceColor");
    
    for (const auto& cube : cubes) {
        // Start with translation to cube's base position
        glm::mat4 model = glm::translate(glm::mat4(1.0f), cube.position);

        model = glm::translate(model, glm::vec3(0.0f, 2.0f, 0.0f)); // height above floor        

        // Apply accumulated animation transform
        model = model * cube.animationTransform;
       
        // Apply reflection matrix
        model = reflectionMatrix * model;
        
        glUniformMatrix4fv(modelLoc, 1, GL_FALSE, glm::value_ptr(model));
        
        // Draw each face with its appropriate color but darker for reflections
        for (size_t i = 0; i < 6; i++) {
            glm::vec3 faceColor;
            
            // Determine which face we're drawing and use the appropriate color
            if (i == 0 && cube.layer[2] == 1) faceColor = glm::vec3(1.0f, 1.0f, 1.0f);  // Front face (white)
            else if (i == 1 && cube.layer[0] == 1) faceColor = glm::vec3(1.0f, 0.5f, 0.0f);  // Right face (orange)
            else if (i == 2 && cube.layer[2] == -1) faceColor = glm::vec3(1.0f, 1.0f, 0.0f);  // Back face (yellow)
            else if (i == 3 && cube.layer[0] == -1) faceColor = glm::vec3(1.0f, 0.0f, 0.0f);  // Left face (red)
            else if (i == 4 && cube.layer[1] == 1) faceColor = glm::vec3(0.0f, 1.0f, 0.0f);  // Top face (green)
            else if (i == 5 && cube.layer[1] == -1) faceColor = glm::vec3(0.0f, 0.0f, 1.0f);  // Bottom face (blue)
            else faceColor = glm::vec3(0.2f, 0.2f, 0.2f);  // Default dark gray for internal faces
            
            // Darken and make more transparent for reflection
            faceColor *= 0.7f;
            
            // Set face color uniform
            glUniform3fv(faceColorLoc, 1, glm::value_ptr(faceColor));
            
            // Draw the face
            glDrawElements(GL_TRIANGLES, 6, GL_UNSIGNED_INT, (void*)(i * 6 * sizeof(GLuint)));
        }
    }
    
    // Reset normal face culling
    glCullFace(GL_BACK);
    
    // Draw the mirror floor
    glUseProgram(mirrorShaderProgram);
    
    // Set common mirror shader uniforms
    GLint mirrorViewLoc = glGetUniformLocation(mirrorShaderProgram, "view");
    GLint mirrorProjLoc = glGetUniformLocation(mirrorShaderProgram, "projection");
    GLint mirrorViewPosLoc = glGetUniformLocation(mirrorShaderProgram, "viewPos");
    GLint mirrorTimeLoc = glGetUniformLocation(mirrorShaderProgram, "time");
    GLint mirrorNumLightsLoc = glGetUniformLocation(mirrorShaderProgram, "numLights");
    
    glUniformMatrix4fv(mirrorViewLoc, 1, GL_FALSE, glm::value_ptr(view));
    glUniformMatrix4fv(mirrorProjLoc, 1, GL_FALSE, glm::value_ptr(projection));
    glUniform3fv(mirrorViewPosLoc, 1, glm::value_ptr(cameraPos));
    glUniform1f(mirrorTimeLoc, globalTime);
    glUniform1i(mirrorNumLightsLoc, 4);
    
    // Set light uniforms for mirror shader
    for(int i = 0; i < 4; i++) {
        std::string base = "lights[" + std::to_string(i) + "].";
        GLint posLoc = glGetUniformLocation(mirrorShaderProgram, (base + "position").c_str());
        GLint colorLoc = glGetUniformLocation(mirrorShaderProgram, (base + "color").c_str());
        GLint intensityLoc = glGetUniformLocation(mirrorShaderProgram, (base + "intensity").c_str());
        
        glUniform3fv(posLoc, 1, glm::value_ptr(lights[i].position));
        glUniform3fv(colorLoc, 1, glm::value_ptr(lights[i].color));
        glUniform1f(intensityLoc, lights[i].intensity);
    }
    
    // Set mirror model matrix (identity, as it's already positioned in vertex data)
    GLint mirrorModelLoc = glGetUniformLocation(mirrorShaderProgram, "model");
    glm::mat4 mirrorModel = glm::mat4(1.0f);
    glUniformMatrix4fv(mirrorModelLoc, 1, GL_FALSE, glm::value_ptr(mirrorModel));
    
    // Draw mirror floor
    glBindVertexArray(floorVAO);
    glDrawElements(GL_TRIANGLES, 6, GL_UNSIGNED_INT, 0);
    glBindVertexArray(0);
    
    // Now draw the actual cubes (normal, not reflected)
    glUseProgram(shaderProgram);
    
    // Reset blend function
    glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA);
    
    // Bind VAO for cubes
    glBindVertexArray(vao);
    glUniform1f(glGetUniformLocation(shaderProgram, "audioIntensity"), audio_treble_intensity*0.1);
    // Draw each cube normally
    for (const auto& cube : cubes) {
        // Start with translation to cube's base position
        glm::mat4 model = glm::translate(glm::mat4(1.0f), cube.position);
        model = glm::translate(model, glm::vec3(0.0f, 2.0f, 0.0f)); // height above floor
        model = model * cube.animationTransform;

        glUniformMatrix4fv(modelLoc, 1, GL_FALSE, glm::value_ptr(model));
        
        // Draw each face with its appropriate color based on position
        for (size_t i = 0; i < 6; i++) {
            glm::vec3 faceColor;
            
            // Determine which face we're drawing and use the appropriate color
            if (i == 0 && cube.layer[2] == 1) faceColor = glm::vec3(1.0f, 1.0f, 1.0f);  // Front face (white)
            else if (i == 1 && cube.layer[0] == 1) faceColor = glm::vec3(1.0f, 0.5f, 0.0f);  // Right face (orange)
            else if (i == 2 && cube.layer[2] == -1) faceColor = glm::vec3(1.0f, 1.0f, 0.0f);  // Back face (yellow)
            else if (i == 3 && cube.layer[0] == -1) faceColor = glm::vec3(1.0f, 0.0f, 0.0f);  // Left face (red)
            else if (i == 4 && cube.layer[1] == 1) faceColor = glm::vec3(0.0f, 1.0f, 0.0f);  // Top face (green)
            else if (i == 5 && cube.layer[1] == -1) faceColor = glm::vec3(0.0f, 0.0f, 1.0f);  // Bottom face (blue)
            else faceColor = glm::vec3(0.2f, 0.2f, 0.2f);  // Default dark gray for internal faces
            
            // Set face color uniform
            glUniform3fv(faceColorLoc, 1, glm::value_ptr(faceColor));
            
            // Draw the face
            glDrawElements(GL_TRIANGLES, 6, GL_UNSIGNED_INT, (void*)(i * 6 * sizeof(GLuint)));
        }
    }
    
    // If in Pong mode, draw paddles only (removed scoreboard reference)
    if (pongMode) {
        // Draw the pong paddles
        drawPongPaddles();
    }
    
    // Disable blending
    glDisable(GL_BLEND);
    
    // Unbind VAO and shader
    glBindVertexArray(0);
    glUseProgram(0);
    
    // Swap buffers
    glutSwapBuffers();
    
    // Restore window context
    //if (currentWindow > 0) {
    //    glutSetWindow(currentWindow);
    //}
}

void startLayerRotation(int axis, int layer, bool clockwise) {
    if (isAnimating) {
        printf("Debug: Already animating, queueing rotation\n");
        RotationAnimation anim;
        anim.axis = axis;
        anim.layer = layer;
        anim.angle = 0.0f;
        anim.targetAngle = 90.0f;
        anim.speed = params.animationSpeed;
        anim.clockwise = clockwise;
        animationQueue.push(anim);
        return;
    }
    
    printf("Debug: Starting layer rotation: axis=%d, layer=%d, clockwise=%d\n", 
           axis, layer, clockwise);
    
    isAnimating = true;
    for (auto& cube : cubes) {
        if (cube.layer[axis] == layer) {
            cube.isAnimating = true;
            cube.clockwise = clockwise;
            cube.rotationAxis = glm::vec3(
                axis == 0 ? 1.0f : 0.0f,
                axis == 1 ? 1.0f : 0.0f,
                axis == 2 ? 1.0f : 0.0f
            );
            // Store the current transform as the base for the new rotation
            cube.baseTransform = cube.animationTransform;
        }
    }
}

void updateAnimation(float deltaTime) {
    if (!isAnimating) return;

    bool allDone = true;
    // Use the animation speed from params
    float rotationAmount = deltaTime * params.animationSpeed;

    glm::vec3 cubeCenter(0.0f);
    for (const auto& c : cubes) cubeCenter += c.position;
    cubeCenter /= (float)cubes.size();

    for (auto& cube : cubes) {
    if (cube.isAnimating) {
        float direction = cube.clockwise ? 1.0f : -1.0f;
        cube.rotation += rotationAmount * direction;
        
        if (std::abs(cube.rotation) >= 90.0f) {
                // Calculate the exact final rotation needed
            float finalAngle = 90.0f * (cube.clockwise ? 1.0f : -1.0f);
            cube.rotation = 0.0f;
            cube.isAnimating = false;
            
                // Calculate final rotation around the center of the cube
            glm::mat4 finalRotation = glm::rotate(
                glm::mat4(1.0f),
                glm::radians(finalAngle),
                cube.rotationAxis
            );
            
                // Rotate around the current cube center
                glm::vec3 relPos = cube.position - cubeCenter;
                glm::vec3 rotatedRelPos = glm::vec3(finalRotation * glm::vec4(relPos, 1.0f));
                cube.position = rotatedRelPos + cubeCenter;
                
                // Store the final rotation, combining with previous transforms
            cube.animationTransform = finalRotation * cube.baseTransform;
            
                // Update layer indices based on rotation
            if (cube.rotationAxis.x > 0.5f) {  // X-axis rotation
                int newY = cube.clockwise ? -cube.layer[2] : cube.layer[2];
                int newZ = cube.clockwise ? cube.layer[1] : -cube.layer[1];
                cube.layer[1] = newY;
                cube.layer[2] = newZ;
            }
            else if (cube.rotationAxis.y > 0.5f) {  // Y-axis rotation
                int newX = cube.clockwise ? cube.layer[2] : -cube.layer[2];
                int newZ = cube.clockwise ? -cube.layer[0] : cube.layer[0];
                cube.layer[0] = newX;
                cube.layer[2] = newZ;
            }
            else if (cube.rotationAxis.z > 0.5f) {  // Z-axis rotation
                int newX = cube.clockwise ? -cube.layer[1] : cube.layer[1];
                int newY = cube.clockwise ? cube.layer[0] : -cube.layer[0];
                cube.layer[0] = newX;
                cube.layer[1] = newY;
            }
        } else {
                // Calculate the current rotation angle
            float currentAngle = cube.rotation;
                
                // Create rotation matrix for the current frame
            glm::mat4 rotation = glm::rotate(
                glm::mat4(1.0f),
                glm::radians(currentAngle),
                cube.rotationAxis
            );
            
                // Calculate rotated position
                //glm::vec4 rotatedPos = rotation * glm::vec4(cube.position, 1.0f);
                glm::vec3 relPos = cube.position - cubeCenter;
                glm::vec3 rotatedRelPos = glm::vec3(rotation * glm::vec4(relPos, 1.0f));
                glm::vec3 rotatedPos = rotatedRelPos + cubeCenter;
                
                // Set the animation transform for this frame, combining with previous transforms
                cube.animationTransform = glm::translate(glm::mat4(1.0f), glm::vec3(rotatedPos) - cube.position) * rotation * cube.baseTransform;
                
                allDone = false;
            }
        }
    }

    if (allDone) {
        isAnimating = false;
        if (!animationQueue.empty()) {
            auto nextAnim = animationQueue.front();
            animationQueue.pop();
            startLayerRotation(nextAnim.axis, nextAnim.layer, nextAnim.clockwise);
        }
    }
}

// Handle keyboard input
// GLUT keyboard callback for the Rubik scene.
//   w / s     - move the left Pong paddle (Pong mode) or the camera along Z
//   a/d, q/e  - move the camera along X / Y
//   space     - leave the Pong start screen and serve the ball
//   1..9 and their shifted symbols - rotate a layer clockwise / counter-
//               clockwise; ignored while physics or restoration is running
//   r         - reset camera position
//   R         - reset cube, physics state, pending rotations and Pong mode
//   P         - enter Pong mode
//   A         - start the automatic shuffle-solve-earthquake cycle
//   E         - start the earthquake immediately
//   F         - jump straight to the restoration phase
//   B         - toggle physics debug output (and dump current state)
//   ESC       - destroy the window and exit
// x, y (mouse position) are unused. A redisplay is requested after every key.
// NOTE: updateRubik() recomputes cameraPos every tick, so the camera nudges
// above are overwritten on the next frame.
void keyboardRubik(unsigned char key, int x, int y) {
    static bool debugPhysics = false;

    switch (key) {
        // Camera controls - now relative to cameraDistance
        case 'w': 
            if (pongMode) {
                // Move left paddle up
                leftPaddle.position.y += leftPaddle.speed * 0.2f;
                // Limit paddle movement to playfield bounds
                if (leftPaddle.position.y + leftPaddle.size.y/2 > PONG_CEILING_Y) {
                    leftPaddle.position.y = PONG_CEILING_Y - leftPaddle.size.y/2;
                }
            } else {
                cameraPos.z -= params.cameraDistance * 0.1f; 
            }
            break;
            
        case 's': 
            if (pongMode) {
                // Move left paddle down
                leftPaddle.position.y -= leftPaddle.speed * 0.2f;
                // Limit paddle movement to playfield bounds
                if (leftPaddle.position.y - leftPaddle.size.y/2 < PONG_FLOOR_Y) {
                    leftPaddle.position.y = PONG_FLOOR_Y + leftPaddle.size.y/2;
                }
            } else {
                cameraPos.z += params.cameraDistance * 0.1f; 
            }
            break;
            
        case 'a': cameraPos.x -= params.cameraDistance * 0.1f; break;
        case 'd': cameraPos.x += params.cameraDistance * 0.1f; break;
        case 'q': cameraPos.y += params.cameraDistance * 0.1f; break;
        case 'e': cameraPos.y -= params.cameraDistance * 0.1f; break;
        
        // Space to start Pong game or start the ball
        case ' ':
            if (pongMode && showStartScreen) {
                showStartScreen = false;
                resetPongBall();
            }
            break;
        
        // X-axis rotations (left/right faces) - disabled during physics & restoration
        case '1': if (!physicsActive && !restorationPhase ) startLayerRotation(0, 1, true); break;
        case '!': if (!physicsActive && !restorationPhase ) startLayerRotation(0, 1, false); break;
        case '2': if (!physicsActive && !restorationPhase ) startLayerRotation(0, 0, true); break;
        case '@': if (!physicsActive && !restorationPhase ) startLayerRotation(0, 0, false); break;
        case '3': if (!physicsActive && !restorationPhase ) startLayerRotation(0, -1, true); break;
        case '#': if (!physicsActive && !restorationPhase ) startLayerRotation(0, -1, false); break;
        
        // Y-axis rotations (top/bottom faces) - disabled during physics & restoration
        case '4': if (!physicsActive && !restorationPhase ) startLayerRotation(1, 1, true); break;
        case '$': if (!physicsActive && !restorationPhase ) startLayerRotation(1, 1, false); break;
        case '5': if (!physicsActive && !restorationPhase ) startLayerRotation(1, 0, true); break;
        case '%': if (!physicsActive && !restorationPhase ) startLayerRotation(1, 0, false); break;
        case '6': if (!physicsActive && !restorationPhase ) startLayerRotation(1, -1, true); break;
        case '^': if (!physicsActive && !restorationPhase ) startLayerRotation(1, -1, false); break;
        
        // Z-axis rotations (front/back faces) - disabled during physics & restoration
        case '7': if (!physicsActive && !restorationPhase ) startLayerRotation(2, 1, true); break;
        case '&': if (!physicsActive && !restorationPhase ) startLayerRotation(2, 1, false); break;
        case '8': if (!physicsActive && !restorationPhase ) startLayerRotation(2, 0, true); break;
        case '*': if (!physicsActive && !restorationPhase ) startLayerRotation(2, 0, false); break;
        case '9': if (!physicsActive && !restorationPhase ) startLayerRotation(2, -1, true); break;
        case '(': if (!physicsActive && !restorationPhase ) startLayerRotation(2, -1, false); break;
        
        // Reset camera
        case 'r':
            cameraPos = glm::vec3(4.0f, 3.0f, 3.0f);
            break;
            
        // Reset cube and physics
        case 'R':
            // Reset physics state
            physicsActive = false;
            restorationPhase = false;
            globalTime = 0.0f;
            cubeState = IDLE;  // Also reset the state machine
            pongMode = false;  // Exit Pong mode
            
            // Drop any in-flight or queued layer rotations so they do not
            // replay on the freshly initialised cube.
            isAnimating = false;
            while (!animationQueue.empty()) {
                animationQueue.pop();
            }
            
            // Reinitialize cubes
            initCubes();
            break;
        
        // Start PONG mode
        case 'P':
            if (!pongMode) {
                initPongGame();
            }
            break;
            
        // Trigger automatic shuffle-solve-earthquake cycle
        case 'A':
            if (!pongMode && cubeState == IDLE && !physicsActive && !restorationPhase && !isAnimating) {
                printf("Starting automatic shuffle-solve-earthquake cycle\n");
                startAutomaticShuffle();
            } else {
                printf("Cannot start automatic cycle - already in progress or active animation\n");
            }
            break;
            
        // Trigger earthquake immediately with higher initial velocities
        case 'E':
            if (!pongMode && !physicsActive && !restorationPhase) {
                printf("Debug: Forcing earthquake to start immediately!\n");
                physicsActive = true;
                restorationPhase = false;
                physicsStartTime = globalTime;
                
                // Initialize physical properties for all cubes
                for (auto& cube : cubes) {
                    // Direction from center for the explosion effect. A cube
                    // sitting exactly at the origin has no outward direction
                    // (normalize() of a zero vector yields NaN), so it only
                    // gets the upward kick below.
                    float distFromCenter = glm::length(cube.position);
                    glm::vec3 dirFromCenter = distFromCenter > 1e-6f
                        ? glm::normalize(cube.position)
                        : glm::vec3(0.0f);
                    
                    // Strong outward velocity (explosion effect)
                    cube.velocity = dirFromCenter * (4.0f + distFromCenter * 1.0f);
                    
                    // Add strong upward velocity
                    cube.velocity.y += 5.0f;
                    
                    // Add random rotation
                    cube.angularVelocity = glm::vec3(
                        (rand() % 200 - 100) / 100.0f * 10.0f,
                        (rand() % 200 - 100) / 100.0f * 10.0f,
                        (rand() % 200 - 100) / 100.0f * 10.0f
                    );
                    
                    cube.onGround = false;
                    
                    // Abort any half-finished layer turn on this cube so it
                    // does not resume once the cube is restored.
                    cube.isAnimating = false;
                    cube.rotation = 0.0f;
                }
                
                // Cancel layer rotations, matching the automatic earthquake start.
                isAnimating = false;
                while (!animationQueue.empty()) {
                    animationQueue.pop();
                }
                
                printf("Physics activated! Cubes are now physics-driven.\n");
            }
            break;
            
        // Start restoration phase immediately
        case 'F':
            if (!pongMode && physicsActive) {
                printf("Debug: Forcing restoration phase to start immediately!\n");
                startRestorationPhase();
            }
            break;
        
        // Toggle physics debug output (changed from 'P' to 'B' to avoid conflict)
        case 'B':
            debugPhysics = !debugPhysics;
            printf("Physics debug output %s\n", debugPhysics ? "enabled" : "disabled");
            if (debugPhysics && (physicsActive || restorationPhase)) {
                // Print current state of all cubes
                printf("Current physics state:\n");
                for (size_t i = 0; i < cubes.size(); i++) {
                    const auto& cube = cubes[i];
                    printf("Cube %zu: Pos(%.2f, %.2f, %.2f), Vel(%.2f, %.2f, %.2f), OnGround:%d\n",
                           i, cube.position.x, cube.position.y, cube.position.z,
                           cube.velocity.x, cube.velocity.y, cube.velocity.z,
                           cube.onGround);
                }
            }
            break;
        
        case 27:  // ESC key
            glutDestroyWindow(glutGetWindow());
            exit(0);
            break;
    }
    
    glutPostRedisplay();
}

// Handle special keys (arrow keys)
// GLUT special-key callback. Only active in Pong mode. UP/DOWN move the right
// paddle by speed * 0.2, clamp it so it stays between PONG_FLOOR_Y and
// PONG_CEILING_Y, and request a redraw. Outside Pong mode every key is ignored
// and no redraw is requested. x, y (mouse position) are unused.
void specialKeyboard(int key, int x, int y) {
    if (!pongMode) {
        return;
    }

    auto& paddle = rightPaddle;
    const auto halfHeight = paddle.size.y/2;

    switch (key) {
        case GLUT_KEY_UP:
            paddle.position.y += paddle.speed * 0.2f;
            if (paddle.position.y + halfHeight > PONG_CEILING_Y) {
                paddle.position.y = PONG_CEILING_Y - halfHeight;
            }
            break;

        case GLUT_KEY_DOWN:
            paddle.position.y -= paddle.speed * 0.2f;
            if (paddle.position.y - halfHeight < PONG_FLOOR_Y) {
                paddle.position.y = PONG_FLOOR_Y + halfHeight;
            }
            break;
    }

    glutPostRedisplay();
}

// GLUT timer callback driving the Rubik scene at a fixed ~60 Hz step.
// Each tick it does the following, in order:
//  - advances globalTime by a fixed 0.016 s and runs the state machine;
//  - updates the Pong game when pongMode is set;
//  - runs either the layer-rotation animation (cube intact) or the
//    earthquake/restoration physics;
//  - moves the orbiting camera from wall-clock time;
//  - requests a redraw and re-arms itself for 16 ms later.
// `value` is the GLUT timer payload and is unused.
void updateRubik(int value) {
    const float deltaTime = 0.016f;  // fixed step, ~60 FPS
    globalTime += deltaTime;

    // The automatic shuffle at start-up is disabled; the cycle is started via
    // startAutomaticShuffle() (bound to the 'A' key in keyboardRubik).

    // Always update state machine to allow background effects
    updateStateMachine(deltaTime);

    if (pongMode) {
        updatePongGame(deltaTime);
    }

    // Layer rotations run only while the cube is intact; otherwise the
    // earthquake / restoration simulation owns the cubes.
    if (!physicsActive && !restorationPhase) {
        updateAnimation(deltaTime);
    } else {
        updatePhysics(deltaTime);
    }

    // --- Orbiting camera, advanced by wall-clock time ---------------------
    static float lastTime = 0.0f;
    static float cameraAngle = 0.0f;
    const float cameraAnimationSpeed = 0.1f;  // revolutions per second
    const float cameraDistance = 8.0f;
    const float cameraHeight = 3.0f;
    const float minHeight = 2.0f;             // never dip below this height

    float currentTime = glutGet(GLUT_ELAPSED_TIME) / 1000.0f;
    cameraAngle += (currentTime - lastTime) * cameraAnimationSpeed * 2.0f * M_PI;
    lastTime = currentTime;

    // Spiral path: horizontal orbit, slow vertical bob, and a zoom that mixes
    // two sine waves of different frequency.
    float horizontalAngle = cameraAngle;
    float verticalAngle = sin(cameraAngle * 0.5f) * 0.5f * M_PI;

    float zoomFactor = cameraDistance + 
                      sin(cameraAngle * 0.3f) * 2.5f +
                      sin(cameraAngle * 1.2f) * 1.5f;

    cameraPos = glm::vec3(
        cos(horizontalAngle) * zoomFactor,
        std::max(minHeight, (float)(cameraHeight + sin(verticalAngle) * zoomFactor)),
        sin(horizontalAngle) * zoomFactor
    );

    glutPostRedisplay();

    // Schedule next update
    glutTimerFunc(16, updateRubik, 0);
}

// Releases every GL object owned by the Rubik scene and then empties the cube
// list. Objects released: cube VBO/EBO/VAO, the cube shader, the mirror-floor
// VBO/EBO/VAO and the mirror shader. Each handle is zeroed after deletion, so
// calling cleanup() twice is harmless.
void cleanup() {
    printf("Debug: Cleaning up resources...\n");

    // Small helpers: delete only live (non-zero) handles, then zero them.
    auto releaseBuffer = [](auto& handle) {
        if (handle != 0) {
            glDeleteBuffers(1, &handle);
            handle = 0;
        }
    };
    auto releaseVertexArray = [](auto& handle) {
        if (handle != 0) {
            glDeleteVertexArrays(1, &handle);
            handle = 0;
        }
    };
    auto releaseProgram = [](auto& handle) {
        if (handle != 0) {
            glDeleteProgram(handle);
            handle = 0;
        }
    };

    // Cube geometry and shader
    releaseBuffer(vbo);
    releaseBuffer(ebo);
    releaseVertexArray(vao);
    releaseProgram(shaderProgram);

    // Mirror floor geometry and shader
    releaseBuffer(floorVBO);
    releaseBuffer(floorEBO);
    releaseVertexArray(floorVAO);
    releaseProgram(mirrorShaderProgram);

    cubes.clear();
}

// Add this function before the update function
// Advances the earthquake / restoration simulation by one step of deltaTime.
//
// Three regimes, checked in order:
//  1. restorationPhase: each cube is pulled towards originalPositions[i] by a
//     damped spring whose strength ramps up over 3 s. Its orientation is
//     slerped back towards identity at the same time. The phase ends via
//     completeRestoration() once isRestorationComplete() reports true.
//  2. Not yet active and globalTime >= PHYSICS_DELAY: the earthquake starts.
//     Every cube gets an outward + upward kick, a random spin and a random
//     mass. Any layer rotation in progress or queued is cancelled.
//  3. Active: the step applies the following:
//     - gravity;
//     - a shake that decays over 5 s;
//     - floor bounce;
//     - wall clamping at +/-10 on X/Z;
//     - spin integration and air drag.
//     After 8 s the restoration phase is started.
void updatePhysics(float deltaTime) {
    // If in restoration phase, apply restoration forces
    if (restorationPhase) {
        // Check if restoration is complete
        if (isRestorationComplete()) {
            completeRestoration();
            return;
        }
        
        // Calculate restoration time for smoother transition
        float restorationTime = globalTime - restorationStartTime;
        float forceFactor = std::min(1.0f, restorationTime / 3.0f); // Ramp up force over 3 seconds
        
        // Per-step blend towards identity orientation, clamped so slerp never
        // overshoots if deltaTime is ever large.
        const float rotationBlend = std::min(1.0f, deltaTime * 5.0f);
        
        // Apply restoration forces to each cube
        for (size_t i = 0; i < cubes.size() && i < originalPositions.size(); ++i) {
            auto& cube = cubes[i];
            const auto& origPos = originalPositions[i];
            
            // Calculate direction to original position
            glm::vec3 toOriginal = origPos.position - cube.position;
            float distance = glm::length(toOriginal);
            
            if (distance > 0.001f) {
                // Calculate force based on distance (spring-like)
                glm::vec3 force = glm::normalize(toOriginal) * distance * RESTORATION_FORCE * forceFactor;
                
                // Apply force to velocity
                cube.velocity += force * deltaTime;
                
                // Apply damping for smoother approach
                cube.velocity *= RESTORATION_DAMPING;
                
                // Update position
                cube.position += cube.velocity * deltaTime;
            } else {
                // Already close, just snap to position
                cube.position = origPos.position;
                cube.velocity = glm::vec3(0.0f);
            }
            
            // Orientation is restored independently of position, so a cube
            // that reaches its slot early does not stay frozen at a tilt.
            if (glm::length(cube.angularVelocity) > 0.01f) {
                // Slow down rotation
                cube.angularVelocity *= 0.95f;
            }
            
            if (cube.animationTransform != glm::mat4(1.0f)) {
                // Extract rotation (any translation in the transform is dropped)
                glm::quat rotation = glm::quat_cast(cube.animationTransform);
                glm::quat identity = glm::quat(1.0f, 0.0f, 0.0f, 0.0f);
                
                // Blend toward identity
                rotation = glm::slerp(rotation, identity, rotationBlend);
                
                // Apply back
                cube.animationTransform = glm::mat4_cast(rotation);
            }
        }
        
        return;
    }
    
    // If not yet active but countdown reached, activate physics
    if (!physicsActive && globalTime >= PHYSICS_DELAY) {
        printf("\n\n*** EARTHQUAKE STARTING NOW at time %.2f ***\n\n", globalTime);
        physicsActive = true;
        physicsStartTime = globalTime;
        
        // Initialize physics properties with much stronger initial values
        for (auto& cube : cubes) {
            // Direction from center for the explosion. A cube sitting exactly
            // at the origin has no outward direction (normalize() of a zero
            // vector yields NaN), so it only gets the upward kick.
            float distFromCenter = glm::length(cube.position);
            glm::vec3 dirFromCenter = distFromCenter > 1e-6f
                ? glm::normalize(cube.position)
                : glm::vec3(0.0f);
            
            // Outward explosion velocity, stronger for cubes farther out
            cube.velocity = dirFromCenter * (4.0f + distFromCenter * 1.0f);
            
            // Add stronger upward initial velocity component
            cube.velocity.y += 5.0f;
            
            // Random spin in [-10, 10) on each axis
            cube.angularVelocity = glm::vec3(
                (rand() % 200 - 100) / 100.0f * 10.0f,
                (rand() % 200 - 100) / 100.0f * 10.0f,
                (rand() % 200 - 100) / 100.0f * 10.0f
            );
            
            // Randomize mass slightly
            cube.mass = 0.8f + (rand() % 40) / 100.0f; // 0.8 to 1.19
            cube.onGround = false;
            
            // Abort any half-finished layer turn on this cube so it does not
            // resume once the cube is restored.
            cube.isAnimating = false;
            cube.rotation = 0.0f;
            
            printf("Cube at position (%.2f, %.2f, %.2f) initialized with velocity (%.2f, %.2f, %.2f)\n", 
                   cube.position.x, cube.position.y, cube.position.z,
                   cube.velocity.x, cube.velocity.y, cube.velocity.z);
        }
        
        // Stop any ongoing cube rotations
        isAnimating = false;
        while (!animationQueue.empty()) {
            animationQueue.pop();
        }
    }
    
    // If physics is not active, don't continue with simulation
    if (!physicsActive) {
        return;
    }
    
    // Calculate earthquake effect - stronger at the beginning
    float earthquakeTime = globalTime - physicsStartTime;
    float earthquakeFactor = std::max(0.0f, 1.0f - earthquakeTime / 5.0f); // Shake decays to zero over 5 s
    
    // Check if earthquake phase should end and restoration phase should begin
    if (earthquakeTime > 8.0f && !restorationPhase) { // Let cubes settle for a few seconds after main quake
        startRestorationPhase();
        return;
    }
    
    // Periodic debug print (every 0.5 s of earthquake time)
    static float lastPrintTime = 0;
    // earthquakeTime restarts at 0 for every new earthquake; rewind the print
    // timer too, otherwise later earthquakes would never print.
    if (earthquakeTime < lastPrintTime) {
        lastPrintTime = 0.0f;
    }
    if (earthquakeTime - lastPrintTime >= 0.5f) {
        printf("Debug: Earthquake active! Time: %.1fs, Strength: %.2f, Cubes: %zu\n", 
               earthquakeTime, earthquakeFactor * EARTHQUAKE_MAGNITUDE, cubes.size());
        
        // Print information about a sample cube
        if (!cubes.empty()) {
            const auto& cube = cubes[0];
            printf("Sample cube: Pos(%.2f, %.2f, %.2f), Vel(%.2f, %.2f, %.2f), OnGround:%d\n",
                   cube.position.x, cube.position.y, cube.position.z,
                   cube.velocity.x, cube.velocity.y, cube.velocity.z,
                   cube.onGround);
        }
        
        lastPrintTime = earthquakeTime;
    }
    
    // Apply physics to each cube
    for (auto& cube : cubes) {
        // Gravity (scaled 1.5x), only while airborne
        if (!cube.onGround) {
            cube.velocity.y -= GRAVITY * deltaTime * 1.5f;
        }
        
        // Earthquake shake, fading out with earthquakeFactor
        if (earthquakeFactor > 0.01f) {
            // Position-dependent phases so neighbouring cubes do not move in lockstep
            float shakeX = sin(earthquakeTime * 20.0f + cube.position.x * 5.0f) * 
                           EARTHQUAKE_MAGNITUDE * earthquakeFactor;
            float shakeZ = cos(earthquakeTime * 18.0f + cube.position.z * 4.0f) * 
                           EARTHQUAKE_MAGNITUDE * earthquakeFactor;
            
            // Add vertical component to earthquake
            float shakeY = sin(earthquakeTime * 15.0f + cube.position.x * 3.0f + cube.position.z * 4.0f) * 
                           EARTHQUAKE_MAGNITUDE * earthquakeFactor * 0.8f;
            
            // Apply shake forces
            cube.velocity.x += shakeX * deltaTime * 20.0f;
            cube.velocity.y += shakeY * deltaTime * 20.0f;
            cube.velocity.z += shakeZ * deltaTime * 20.0f;
            
            // Grounded cubes have a 50% chance per step of being kicked back up
            if (cube.onGround && rand() % 100 < 50) {
                cube.velocity.y += (rand() % 100) / 100.0f * earthquakeFactor * 8.0f;
                cube.onGround = false;
            }
        }
        
        // Update position based on velocity
        cube.position += cube.velocity * deltaTime;
        
        // Floor collision
        if (cube.position.y - params.cubeSize < FLOOR_HEIGHT && cube.velocity.y < 0) {
            cube.position.y = FLOOR_HEIGHT + params.cubeSize;
            
            // Bounce with energy loss
            cube.velocity.y = -cube.velocity.y * BOUNCE_FACTOR;
            
            // Friction effect on horizontal velocity
            cube.velocity.x *= 0.95f;
            cube.velocity.z *= 0.95f;
            
            // If velocity is very low, consider it stopped
            if (std::abs(cube.velocity.y) < 0.1f) {
                cube.velocity.y = 0;
                cube.onGround = true;
            } else {
                cube.onGround = false;
            }
        }
        
        // Walls - simple boundaries to keep cubes in view
        const float WALL_LIMIT = 10.0f;
        if (std::abs(cube.position.x) > WALL_LIMIT) {
            cube.position.x = glm::sign(cube.position.x) * WALL_LIMIT;
            cube.velocity.x = -cube.velocity.x * 0.8f;
        }
        if (std::abs(cube.position.z) > WALL_LIMIT) {
            cube.position.z = glm::sign(cube.position.z) * WALL_LIMIT;
            cube.velocity.z = -cube.velocity.z * 0.8f;
        }
        
        // Apply rotation based on angular velocity
        if (glm::length(cube.angularVelocity) > 0.01f) {
            glm::quat rotation = glm::quat(1.0f, 0.0f, 0.0f, 0.0f);
            
            // Apply rotation around X axis
            if (std::abs(cube.angularVelocity.x) > 0.01f) {
                rotation = glm::rotate(rotation, cube.angularVelocity.x * deltaTime, glm::vec3(1.0f, 0.0f, 0.0f));
            }
            
            // Apply rotation around Y axis
            if (std::abs(cube.angularVelocity.y) > 0.01f) {
                rotation = glm::rotate(rotation, cube.angularVelocity.y * deltaTime, glm::vec3(0.0f, 1.0f, 0.0f));
            }
            
            // Apply rotation around Z axis
            if (std::abs(cube.angularVelocity.z) > 0.01f) {
                rotation = glm::rotate(rotation, cube.angularVelocity.z * deltaTime, glm::vec3(0.0f, 0.0f, 1.0f));
            }
            
            // Convert quaternion to matrix and apply to the cube
            cube.animationTransform = cube.animationTransform * glm::mat4_cast(rotation);
            
            // Dampen angular velocity over time
            cube.angularVelocity *= 0.99f;
            
            // If on ground, increase angular damping due to friction
            if (cube.onGround) {
                cube.angularVelocity *= 0.95f;
            }
        }
        
        // Air resistance
        cube.velocity *= 0.99f;
    }
}

// Start restoration phase
void startRestorationPhase() {
    if (!restorationPhase) {
        printf("Debug: Starting restoration phase!\n");
        restorationPhase = true;
        physicsActive = false;
        restorationStartTime = globalTime;
        
        // Slow down any cube movements
        for (auto& cube : cubes) {
            cube.velocity *= 0.5f;
            cube.angularVelocity *= 0.5f;
        }
    }
}

// Check if all cubes are close enough to their original positions
// Returns true when the restoration phase may end. That happens either when
// the phase has run longer than `restorationTimeout`, or when every piece
// (paired index-wise with `originalPositions`) is within POSITION_THRESHOLD of
// its home position and slower than VELOCITY_THRESHOLD. While restoring, it
// prints a progress line roughly once per second of restoration time.
bool isRestorationComplete() {
    const float POSITION_THRESHOLD = 0.1f;
    const float VELOCITY_THRESHOLD = 0.05f;

    // Check if we've waited the maximum time
    float restorationTime = globalTime - restorationStartTime;
    if (restorationTime > restorationTimeout) {
        printf("Debug: Restoration timeout reached. Forcing completion.\n");
        return true;
    }

    // Check if all cubes are close to their target positions and nearly stopped
    bool allInPlace = true;
    float maxDistance = 0.0f;
    float maxVelocity = 0.0f;

    for (size_t i = 0; i < cubes.size() && i < originalPositions.size(); ++i) {
        const glm::vec3& targetPos = originalPositions[i].position;
        const glm::vec3& currentPos = cubes[i].position;

        float distance = glm::length(targetPos - currentPos);
        float velocity = glm::length(cubes[i].velocity);

        maxDistance = std::max(maxDistance, distance);
        maxVelocity = std::max(maxVelocity, velocity);

        if (distance > POSITION_THRESHOLD || velocity > VELOCITY_THRESHOLD) {
            allInPlace = false;
        }
    }

    // Provide feedback every second of restoration time.
    // restorationTime restarts from zero with every new restoration phase, so
    // a timestamp left over from a previous (longer) phase must be discarded.
    // Otherwise the new phase would print nothing until it outlasted the old one.
    static float lastFeedbackTime = 0;
    if (restorationTime < lastFeedbackTime) {
        lastFeedbackTime = 0;
    }
    if (restorationTime - lastFeedbackTime >= 1.0f) {
        printf("Debug: Restoration progress - Max distance: %.3f, Max velocity: %.3f\n", 
               maxDistance, maxVelocity);
        lastFeedbackTime = restorationTime;
    }

    return allInPlace;
}

// Complete the restoration process
// Finish restoration. Every piece (paired index-wise with `originalPositions`)
// is snapped exactly back to its home position and layer indices. Its linear
// and angular velocity are zeroed and both of its transforms are reset to
// identity. Finally the restoration flag is cleared.
void completeRestoration() {
    printf("Debug: Completing restoration process!\n");

    for (size_t idx = 0; idx < cubes.size() && idx < originalPositions.size(); ++idx) {
        auto& piece = cubes[idx];
        const auto& home = originalPositions[idx];

        piece.position = home.position;
        for (int axis = 0; axis < 3; ++axis) {
            piece.layer[axis] = home.layer[axis];
        }

        // Come to rest with no residual rotation
        piece.velocity = glm::vec3(0.0f);
        piece.angularVelocity = glm::vec3(0.0f);
        piece.animationTransform = glm::mat4(1.0f);
        piece.baseTransform = glm::mat4(1.0f);
    }

    // Exit restoration phase
    restorationPhase = false;

    printf("Cube has been restored. You can shuffle it again!\n");
}

// Start automatic shuffle sequence
// Begin a new automatic shuffle unless one is already in progress. This
// enters SHUFFLING, forgets previously recorded moves and picks a random move
// count in [MIN_SHUFFLE_MOVES, MAX_SHUFFLE_MOVES]. It then tries to issue the
// first move immediately; continueAutomaticShuffle defers it while a rotation
// is still animating.
void startAutomaticShuffle() {
    printf("Starting automatic shuffle sequence\n");

    const bool alreadyShuffling = (cubeState == SHUFFLING);
    if (alreadyShuffling) {
        return;
    }
    cubeState = SHUFFLING;

    // Moves recorded by the previous shuffle are no longer needed
    shuffleMoves.clear();

    const auto moveSpan = MAX_SHUFFLE_MOVES - MIN_SHUFFLE_MOVES + 1;
    shuffleMovesLeft = MIN_SHUFFLE_MOVES + rand() % moveSpan;
    printf("Will perform %d random moves\n", shuffleMovesLeft);

    continueAutomaticShuffle();
}

// Continue automatic shuffle with next move
void continueAutomaticShuffle() {
    if (shuffleMovesLeft <= 0 || isAnimating) {
        // If we're done or still animating, don't start a new move
        if (shuffleMovesLeft <= 0 && !isAnimating) {
            printf("Shuffle complete. Starting solve after delay.\n");
            cubeState = IDLE;
            stateMachineDelay = 2.0f; // Pause before solving
        }
        return;
    }

    // Generate a random move
    CubeMove move;
    move.axis = rand() % 3;                 // Random axis (0, 1, or 2)
    move.layer = (rand() % 3) - 1;          // Random layer (-1, 0, or 1)
    move.clockwise = (rand() % 2) == 0;     // Random direction
    
    // Store the move for later solving
    shuffleMoves.push_back(move);
    
    // Start the move
    printf("Shuffle move %d: Axis %d, Layer %d, Clockwise %d\n", 
           (int)shuffleMoves.size(), move.axis, move.layer, move.clockwise);
    startLayerRotation(move.axis, move.layer, move.clockwise);
    
    // Decrement counter
    shuffleMovesLeft--;
}

// Start automatic solve sequence
// Enter SOLVING and immediately undo the most recent recorded shuffle move.
// With nothing recorded it only logs and leaves cubeState at SOLVING. On a
// later update, the SOLVING branch of updateStateMachine calls
// continueAutomaticSolve, which launches the earthquake because the move list
// is empty.
void startAutomaticSolve() {
    printf("Starting automatic solve sequence\n");
    cubeState = SOLVING;

    if (!shuffleMoves.empty()) {
        // Start solving by reversing the shuffle moves
        continueAutomaticSolve();
        return;
    }

    printf("No moves to undo. Going directly to earthquake.\n");
}

// Launch the post-solve "earthquake". Enters EARTHQUAKE_TRANSITION, enables
// physics and records its start time. Every piece gets an outward velocity
// (faster the farther it is from the origin), an extra upward kick and a
// random spin.
static void launchSolvedCubeEarthquake() {
    printf("Solve complete. Starting earthquake transition.\n");
    cubeState = EARTHQUAKE_TRANSITION;
    // Trigger earthquake immediately
    physicsActive = true;
    restorationPhase = false;
    physicsStartTime = globalTime;

    for (auto& cube : cubes) {
        float distFromCenter = glm::length(cube.position);

        // glm::normalize divides by the vector's length, so a piece sitting
        // exactly at the origin would get a NaN direction (and from then on a
        // NaN velocity and position). Send such a piece straight up instead.
        glm::vec3 dirFromCenter = (distFromCenter > 1e-6f)
            ? glm::normalize(cube.position)
            : glm::vec3(0.0f, 1.0f, 0.0f);

        // Strong outward velocity (explosion effect)
        cube.velocity = dirFromCenter * (4.0f + distFromCenter * 1.0f);

        // Add strong upward velocity
        cube.velocity.y += 5.0f;

        // Add random rotation
        cube.angularVelocity = glm::vec3(
            (rand() % 200 - 100) / 100.0f * 10.0f,
            (rand() % 200 - 100) / 100.0f * 10.0f,
            (rand() % 200 - 100) / 100.0f * 10.0f
        );

        cube.onGround = false;
    }
}

// Continue automatic solve with next move
// Undo the most recently recorded shuffle move by replaying it in the
// opposite direction. Nothing happens while a rotation is animating. Once the
// move list is empty and nothing is animating, the earthquake transition is
// launched.
void continueAutomaticSolve() {
    if (shuffleMoves.empty() || isAnimating) {
        // If we're done or still animating, don't start a new move
        if (shuffleMoves.empty() && !isAnimating) {
            launchSolvedCubeEarthquake();
        }
        return;
    }

    // Get the last move and remove it from the list
    CubeMove move = shuffleMoves.back();
    shuffleMoves.pop_back();

    // Invert the move direction to undo it
    move.clockwise = !move.clockwise;

    // Start the inverted move
    printf("Solve move: Axis %d, Layer %d, Clockwise %d\n", 
           move.axis, move.layer, move.clockwise);
    startLayerRotation(move.axis, move.layer, move.clockwise);
}

// Update state machine for automatic solving/shuffling
// Per-frame driver of the automatic cube show
// (shuffle -> pause -> solve -> earthquake -> restoration -> Pong / shuffle).
// While stateMachineDelay is positive, the only work done is counting it down
// by deltaTime; when it runs out in the IDLE state, solving begins. Otherwise
// the current cubeState decides the next step.
void updateStateMachine(float deltaTime) {
    // A pending pause blocks all state handling until it expires.
    if (stateMachineDelay > 0) {
        stateMachineDelay -= deltaTime;
        const bool pauseOver = (stateMachineDelay <= 0);
        if (pauseOver) {
            stateMachineDelay = 0;
            // Only IDLE reacts to the expired timer; an active earthquake or
            // restoration keeps driving itself.
            if (cubeState == IDLE) {
                startAutomaticSolve();
            }
        }
        return;
    }

    switch (cubeState) {
        case SHUFFLING:
            // Issue the next random move once the previous rotation finishes
            if (!isAnimating) {
                continueAutomaticShuffle();
            }
            break;

        case SOLVING:
            // Undo the next recorded move once the previous rotation finishes
            if (!isAnimating) {
                continueAutomaticSolve();
            }
            break;

        case EARTHQUAKE_TRANSITION:
            if (physicsActive) {
                // Let the pieces tumble for 7 seconds, then start pulling them back
                float earthquakeTime = globalTime - physicsStartTime;
                if (!restorationPhase && earthquakeTime > 7.0f) {
                    printf("Earthquake has run long enough. Starting restoration.\n");
                    startRestorationPhase();
                    cubeState = RESTORING;
                }
            } else if (!restorationPhase) {
                // Neither physics nor restoration running: the transition is over
                printf("Earthquake and restoration complete. Starting new shuffle.\n");
                startAutomaticShuffle();
            }
            break;

        case RESTORING:
            if (restorationPhase) {
                break;  // still restoring
            }
            if (pongStarted) {
                printf("Restoration complete. Starting new shuffle.\n");
                startAutomaticShuffle();
            } else {
                // First completed restoration hands over to the Pong game
                pongStarted = true;
                printf("Restoration complete. Starting Pong game!\n");
                initPongGame();
            }
            break;

        case IDLE:
            // Waiting for the delay timer or external input
            break;
    }
}

// Initialize the Pong game
void initPongGame() {
    printf("Initializing Pong game with Rubik's cube as ball\n");
    pongMode = true;
    
    // Reset scores
    leftScore = 0;
    rightScore = 0;
    
    // Initialize left paddle
    leftPaddle.position = glm::vec3(-PONG_SIDE_X + 1.0f, 0.0f, 0.0f);
    leftPaddle.size = glm::vec3(0.5f, 3.0f, 3.0f);
    leftPaddle.color = glm::vec3(1.0f, 0.2f, 0.2f); // Red
    leftPaddle.speed = PONG_PADDLE_SPEED;
    
    // Initialize right paddle
    rightPaddle.position = glm::vec3(PONG_SIDE_X - 1.0f, 0.0f, 0.0f);
    rightPaddle.size = glm::vec3(0.5f, 3.0f, 3.0f);
    rightPaddle.color = glm::vec3(0.2f, 0.2f, 1.0f); // Blue
    rightPaddle.speed = PONG_PADDLE_SPEED;
    
    // Calculate current cube center
    glm::vec3 cubeCenter(0.0f);
    if (!cubes.empty()) {
        for (const auto& cube : cubes) {
            cubeCenter += cube.position;
        }
        cubeCenter /= (float)cubes.size();
        
        // Move all cube pieces to center (0,0,0)
        glm::vec3 offset = glm::vec3(0.0f, 0.0f, 0.0f) - cubeCenter;
        for (auto& cube : cubes) {
            // Reset transform
            cube.baseTransform = glm::mat4(1.0f);
            cube.animationTransform = glm::mat4(1.0f);
            
            // Apply offset to position cube at center (0,0,0)
            cube.position += offset;
            
            // Clear velocities
            cube.velocity = glm::vec3(0.0f);
            cube.angularVelocity = glm::vec3(0.0f);
        }
    }
    
    // Set game state
    //cubeState = PONG_PLAYING;
    showStartScreen = true;
    
    // Initialize ball velocity in resetPongBall
    resetPongBall();
    
    printf("Pong game initialized - Press SPACE to start\n");
}

// Reset the ball position and give it an initial velocity
void resetPongBall() {
    // Only proceed if there are cubes to work with
    if (cubes.empty()) {
        printf("Warning: No cube pieces available for Pong ball!\n");
        return;
    }
    
    // Calculate cube center based on average of all cube pieces
    glm::vec3 cubeCenter(0.0f);
    for (const auto& cube : cubes) {
        cubeCenter += cube.position;
    }
    cubeCenter /= (float)cubes.size();
    
    // Move all cube pieces to center (0,0,0)
    glm::vec3 offset = glm::vec3(0.0f, 0.0f, 0.0f) - cubeCenter;
    for (auto& cube : cubes) {
        cube.position += offset;
        cube.velocity = glm::vec3(0.0f);
        cube.angularVelocity = glm::vec3(0.0f);
    }
    
    // Random initial velocity direction (left or right)
    float dirX = (rand() % 2 == 0) ? 1.0f : -1.0f;
    
    // Random angle between -30 and 30 degrees
    float angle = glm::radians((rand() % 60) - 30.0f);
    
    // Calculate initial velocity
    float speed = PONG_BALL_MIN_SPEED + 
                 (rand() % 100) / 100.0f * (PONG_BALL_MAX_SPEED - PONG_BALL_MIN_SPEED);
    
    pongBallVelocity = glm::vec3(
        dirX * speed * cos(angle),
        speed * sin(angle),
        0.0f
    );
    
    // Set some random rotation to the cube
    for (auto& cube : cubes) {
        cube.angularVelocity = glm::vec3(
            (rand() % 200 - 100) / 100.0f * 2.0f,
            (rand() % 200 - 100) / 100.0f * 2.0f,
            (rand() % 200 - 100) / 100.0f * 2.0f
        );
    }
    
    //cubeState = PONG_RESET;
    showStartScreen = false;
}

// Update Pong game state
// Advance the Pong mini-game by `deltaTime` seconds. The steps are:
//   * stay paused while the earthquake/restoration runs or the start screen is up;
//   * count down the 1-second serve delay in PONG_RESET;
//   * on game over, log the winner, zero the scores and show the start screen;
//   * otherwise move the cube "ball" by pongBallVelocity, steer both paddles
//     from the latest audio features, resolve collisions and let the ball
//     speed up slowly (by about 5% per second until PONG_BALL_MAX_SPEED).
void updatePongGame(float deltaTime) {
    static float resetDelay = 0.0f;
    static float leftPaddleTargetY = 0.0f;
    static float rightPaddleTargetY = 0.0f;

    // Pause Pong if earthquake or restoration is active
    if (physicsActive || restorationPhase) {
        return;
    }
    if (showStartScreen) {
        // Wait for space key press
        return;
    }

    if (cubeState == PONG_RESET) {
        // Wait a moment before starting ball movement
        resetDelay += deltaTime;
        if (resetDelay >= 1.0f) {
            resetDelay = 0.0f;
            cubeState = PONG_PLAYING;
        }
        return;
    }

    // Check for game over
    if (leftScore >= PONG_MAX_SCORE || rightScore >= PONG_MAX_SCORE) {
        printf("Game over! %s wins with %d points!\n", 
               leftScore >= PONG_MAX_SCORE ? "Left player" : "Right player",
               leftScore >= PONG_MAX_SCORE ? leftScore : rightScore);

        // Reset scores and show the start screen again
        leftScore = 0;
        rightScore = 0;
        showStartScreen = true;
        return;
    }

    // Ball centre (centroid of all pieces) before this frame's animation step
    glm::vec3 ballCenter(0.0f);
    if (!cubes.empty()) {
        for (const auto& cube : cubes) {
            ballCenter += cube.position;
        }
        ballCenter /= (float)cubes.size();
    }

    // physicsActive is always false here (early return above), so this only
    // runs while a layer rotation is animating.
    if (isAnimating) {
        updatePhysics(deltaTime);

        // Cancel any net drift the step introduced, so that only
        // pongBallVelocity moves the ball as a whole.
        glm::vec3 newCenter(0.0f);
        if (!cubes.empty()) {
            for (const auto& cube : cubes) {
                newCenter += cube.position;
            }
            newCenter /= (float)cubes.size();
        }
        glm::vec3 centerOffset = newCenter - ballCenter;
        for (auto& cube : cubes) {
            cube.position -= centerOffset;
        }
    }

    // Now apply the pong ball movement to every piece
    for (auto& cube : cubes) {
        cube.position += pongBallVelocity * deltaTime;
    }

    // Audio-driven paddles: bass steers the left paddle, treble the right one.
    {
        std::lock_guard<std::mutex> lock(audio_features_mutex);

        // The 2x gain means half-strength bass already reaches the top of the
        // court. Clamp so stronger input cannot push a paddle out of the court.
        float maxPaddleY = PONG_CEILING_Y - leftPaddle.size.y/2;
        float minPaddleY = PONG_FLOOR_Y + leftPaddle.size.y/2;
        leftPaddleTargetY = minPaddleY + 2*audio_bass_intensity * (maxPaddleY - minPaddleY);
        leftPaddleTargetY = std::min(std::max(leftPaddleTargetY, minPaddleY), maxPaddleY);

        maxPaddleY = PONG_CEILING_Y - rightPaddle.size.y/2;
        minPaddleY = PONG_FLOOR_Y + rightPaddle.size.y/2;
        rightPaddleTargetY = minPaddleY + 2*audio_treble_intensity * (maxPaddleY - minPaddleY);
        rightPaddleTargetY = std::min(std::max(rightPaddleTargetY, minPaddleY), maxPaddleY);

        // On a beat, "kick" the paddles to opposite ends (one kick per beat)
        if (audio_beat_detected) {
            leftPaddleTargetY = PONG_CEILING_Y - leftPaddle.size.y/2;
            rightPaddleTargetY = PONG_FLOOR_Y + rightPaddle.size.y/2;
            audio_beat_detected = false;
        }
    }

    // Ease each paddle 20% of the way toward its target per call
    float paddleLerpSpeed = 0.2f;
    leftPaddle.position.y += (leftPaddleTargetY - leftPaddle.position.y) * paddleLerpSpeed;
    rightPaddle.position.y += (rightPaddleTargetY - rightPaddle.position.y) * paddleLerpSpeed;

    checkPongCollisions();

    // Gradually increase speed
    float currentSpeed = glm::length(pongBallVelocity);
    if (currentSpeed < PONG_BALL_MAX_SPEED) {
        pongBallVelocity *= (1.0f + 0.05f * deltaTime);
    }
}

// Check for and handle collisions in the Pong game
// The cube is treated as a ball centred on its pieces' centroid. Its radius is
// the farthest piece's distance from that centre plus a 1-unit margin. The
// ball bounces off the floor/ceiling and off both paddles, and its horizontal
// speed is kept at or above MIN_HORIZONTAL_SPEED. When it leaves the court, a
// point is awarded and the ball is re-served. Logs a warning and does nothing
// if there are no pieces.
void checkPongCollisions() {
    const float MIN_HORIZONTAL_SPEED = 5.0f;
    const float MAX_BOUNCE_ANGLE = glm::radians(60.0f);

    if (cubes.empty()) {
        printf("Warning: No cube pieces available for Pong collision detection!\n");
        return;
    }

    // Ball centre and approximate radius
    glm::vec3 cubeCenter(0.0f);
    float cubeSize = 0.0f;
    for (const auto& cube : cubes) {
        cubeCenter += cube.position;
    }
    cubeCenter /= (float)cubes.size();
    for (const auto& cube : cubes) {
        cubeSize = std::max(cubeSize, glm::length(cube.position - cubeCenter));
    }
    cubeSize += 1.0f;  // margin

    // Floor collision - bounce up, spin based on x velocity
    if (cubeCenter.y - cubeSize < PONG_FLOOR_Y && pongBallVelocity.y < 0) {
        pongBallVelocity.y = -pongBallVelocity.y;
        for (auto& cube : cubes) {
            cube.angularVelocity.x += pongBallVelocity.x * 0.1f;
        }
    }

    // Ceiling collision - bounce down, spin based on x velocity
    if (cubeCenter.y + cubeSize > PONG_CEILING_Y && pongBallVelocity.y > 0) {
        pongBallVelocity.y = -pongBallVelocity.y;
        for (auto& cube : cubes) {
            cube.angularVelocity.x -= pongBallVelocity.x * 0.1f;
        }
    }

    // Left paddle: moving left, ball's left edge inside the paddle's x extent,
    // and vertical overlap with the paddle
    if (pongBallVelocity.x < 0 &&
        cubeCenter.x - cubeSize < leftPaddle.position.x + leftPaddle.size.x/2 &&
        cubeCenter.x - cubeSize > leftPaddle.position.x - leftPaddle.size.x/2 &&
        cubeCenter.y + cubeSize > leftPaddle.position.y - leftPaddle.size.y/2 &&
        cubeCenter.y - cubeSize < leftPaddle.position.y + leftPaddle.size.y/2) {

        // Hit offset as a fraction of the paddle's half-height. The overlap
        // test accepts centres up to cubeSize beyond the paddle ends, so clamp
        // to [-1, 1] to honour the 60-degree maximum bounce angle.
        float relativeIntersectY = (leftPaddle.position.y - cubeCenter.y) / (leftPaddle.size.y/2);
        relativeIntersectY = std::min(std::max(relativeIntersectY, -1.0f), 1.0f);
        float bounceAngle = relativeIntersectY * MAX_BOUNCE_ANGLE;

        // Same speed, new direction, now heading right
        float speed = glm::length(pongBallVelocity);
        pongBallVelocity.x = std::abs(speed * cos(bounceAngle));
        pongBallVelocity.y = -speed * sin(bounceAngle);

        for (auto& cube : cubes) {
            cube.angularVelocity.z += relativeIntersectY * 3.0f;
        }
    }

    // Right paddle: mirror image of the left-paddle test
    if (pongBallVelocity.x > 0 &&
        cubeCenter.x + cubeSize > rightPaddle.position.x - rightPaddle.size.x/2 &&
        cubeCenter.x + cubeSize < rightPaddle.position.x + rightPaddle.size.x/2 &&
        cubeCenter.y + cubeSize > rightPaddle.position.y - rightPaddle.size.y/2 &&
        cubeCenter.y - cubeSize < rightPaddle.position.y + rightPaddle.size.y/2) {

        float relativeIntersectY = (rightPaddle.position.y - cubeCenter.y) / (rightPaddle.size.y/2);
        relativeIntersectY = std::min(std::max(relativeIntersectY, -1.0f), 1.0f);
        float bounceAngle = relativeIntersectY * MAX_BOUNCE_ANGLE;

        // Same speed, new direction, now heading left
        float speed = glm::length(pongBallVelocity);
        pongBallVelocity.x = -std::abs(speed * cos(bounceAngle));
        pongBallVelocity.y = -speed * sin(bounceAngle);

        for (auto& cube : cubes) {
            cube.angularVelocity.z -= relativeIntersectY * 3.0f;
        }
    }

    // Keep the ball from crawling vertically. Raise |x| to the minimum, then
    // re-derive |y| so the total speed stays what it was before the
    // adjustment. (Measuring speed after the change would make the
    // renormalisation a no-op.) If the ball is slower than the minimum, keeping
    // the speed is impossible, so y is left alone.
    if (std::abs(pongBallVelocity.x) < MIN_HORIZONTAL_SPEED) {
        const float originalSpeed = glm::length(pongBallVelocity);
        pongBallVelocity.x = (pongBallVelocity.x < 0 ? -1.0f : 1.0f) * MIN_HORIZONTAL_SPEED;

        const float ySquared = originalSpeed * originalSpeed
                             - pongBallVelocity.x * pongBallVelocity.x
                             - pongBallVelocity.z * pongBallVelocity.z;
        if (ySquared > 0.0f) {
            pongBallVelocity.y = (pongBallVelocity.y < 0 ? -1.0f : 1.0f) * glm::sqrt(ySquared);
        }
    }

    // Check for scoring (ball completely past a side)
    if (cubeCenter.x - cubeSize > PONG_SIDE_X) {
        // Left player scores
        leftScore++;
        printf("Left player scores! Score: %d-%d\n", leftScore, rightScore);
        resetPongBall();
    }
    else if (cubeCenter.x + cubeSize < -PONG_SIDE_X) {
        // Right player scores
        rightScore++;
        printf("Right player scores! Score: %d-%d\n", leftScore, rightScore);
        resetPongBall();
    }
}

// Audio callback functions
// Continuous-analysis callback. Under audio_features_mutex it stores the full
// feature set and caches the scalar features (bass/mid/treble band levels,
// energy, beat flag, tempo) in the globals read by the render/game code.
void on_audio_event(const AudioAnalysis::AudioFeatures& features, double position) {
    (void)position;  // unused by the visuals

    std::lock_guard<std::mutex> guard(audio_features_mutex);
    current_audio_features = features;

    const auto& bands = features.freq_bands;
    audio_bass_intensity = bands.bass;
    audio_mid_intensity = bands.mid;
    audio_treble_intensity = bands.treble;

    audio_energy_level = features.energy;
    audio_beat_detected = features.beat_detected;
    audio_tempo = features.tempo;
}

// Event callback. On a BEAT it does two things:
//   * drops one ripple effect (grid-centred coordinates) whose radius and
//     blend scale with the beat intensity;
//   * spawns up to five extra snowflakes, capped at MAX_SNOW_PARTICLES. Their
//     size follows the beat intensity and their fall speed follows the
//     current audio energy level.
// Other event types are ignored.
// NOTE(review): this mutates snowflakeEffects/snowParticles and draws from
// globalRng. If the MP3 player invokes callbacks on its own thread, those
// containers also need synchronising with the render loop. Confirm threading.
void on_beat_detected(const AudioAnalysis::AudioEvent& event) {
    if (event.type != AudioAnalysis::EventType::BEAT) {
        return;
    }

    // Create visual effect on beat
    SnowflakeEffect effect;
    effect.position = glm::vec2(
        (rand() % GRID_WIDTH) - GRID_WIDTH/2.0f,
        (rand() % GRID_HEIGHT) - GRID_HEIGHT/2.0f
    );
    effect.radius = EFFECT_RADIUS * (1.0f + event.intensity);
    effect.effectType = rand() % 4;  // Random pattern
    effect.blendFactor = event.intensity;
    effect.lastUpdate = globalTime;
    snowflakeEffects.push_back(effect);

    // audio_energy_level is written by on_audio_event under
    // audio_features_mutex. Read it under the same lock (as updatePongGame
    // does for the other audio features) instead of racing the writer.
    float energyLevel = 0.0f;
    {
        std::lock_guard<std::mutex> lock(audio_features_mutex);
        energyLevel = audio_energy_level;
    }

    // Snowflakes spawn on a disc around the grid centre, at z in
    // [SNOW_SPAWN_DISTANCE, 1.5 * SNOW_SPAWN_DISTANCE]
    const float gridCenterX = GRID_WIDTH / 2.0f;
    const float gridCenterY = GRID_HEIGHT / 2.0f;
    std::uniform_real_distribution<float> angleDist(0.0f, 2.0f * M_PI);
    std::uniform_real_distribution<float> radiusDist(0.0f, SNOW_SPAWN_RADIUS);
    std::uniform_real_distribution<float> zDist(SNOW_SPAWN_DISTANCE, SNOW_SPAWN_DISTANCE * 1.5f);

    const size_t particleCap = static_cast<size_t>(MAX_SNOW_PARTICLES);
    for (int i = 0; i < 5 && snowParticles.size() < particleCap; ++i) {
        float angle = angleDist(globalRng);
        float radius = radiusDist(globalRng);
        float x = gridCenterX + cos(angle) * radius;
        float y = gridCenterY + sin(angle) * radius;
        float z = zDist(globalRng);

        SnowParticle particle;
        particle.position = glm::vec3(x, y, z);
        particle.size = SNOW_MIN_SIZE + (SNOW_MAX_SIZE - SNOW_MIN_SIZE) * event.intensity;
        particle.speed = SNOW_MIN_SPEED + (SNOW_MAX_SPEED - SNOW_MIN_SPEED) * energyLevel;
        particle.rotation = angleDist(globalRng);
        particle.rotationSpeed = SNOW_MIN_SPEED * 0.1f;
        particle.alpha = 0.8f + 0.2f * event.intensity;
        particle.effectType = rand() % 4;
        particle.hasLanded = false;
        particle.landingTime = 0.0f;
        snowParticles.push_back(particle);
    }
}

// Initialize audio system
bool initAudio(const std::string& mp3_file = "") {
    if (!mp3_file.empty() && !mp3_file.find(".mp3")) {
        printf("Warning: File might not be an MP3: %s\n", mp3_file.c_str());
    }
    
    try {
        audio_player = new AudioAnalysis::MP3Player();
        
        // Set analysis parameters for enhanced responsiveness
        AudioAnalysis::AnalysisParams params;
        params.beat_sensitivity = 0.4f;  // Medium sensitivity
        params.window_size = 1024;
        params.hop_size = 512;
        params.min_bpm = 60.0f;
        params.max_bpm = 200.0f;
        params.enable_freq_bands = true;
        params.enable_mfcc = true;
        params.enable_harmonic_analysis = true;
        params.energy_spike_threshold = 0.6f;
        params.continuous_update_rate_ms = 30;  // Smooth updates
        params.enable_continuous_monitoring = true;
        
        audio_player->set_analysis_params(params);
        
        // Register callbacks
        audio_player->add_continuous_callback(on_audio_event);
        audio_player->add_specific_event_callback(on_beat_detected);
        
        // Load MP3 file if provided
        if (!mp3_file.empty()) {
            if (audio_player->load(mp3_file)) {
                current_mp3_file = mp3_file;
                printf("Loaded MP3: %s (Duration: %.2fs)\n", 
                       mp3_file.c_str(), audio_player->get_duration());
                return true;
            } else {
                printf("Failed to load MP3 file: %s\n", mp3_file.c_str());
                return false;
            }
        }
        
        printf("Audio system initialized (no file loaded)\n");
        return true;
    } catch (const std::exception& e) {
        printf("Audio initialization failed: %s\n", e.what());
        return false;
    }
}

// Cleanup audio resources
// Stop and destroy the global MP3 player, if one exists, and null the pointer
// so repeated calls are harmless.
void cleanupAudio() {
    if (!audio_player) {
        return;  // nothing to release
    }
    audio_player->stop();
    delete audio_player;
    audio_player = nullptr;
}

// ---- Height layers and snow particle effects ----

// Add after initGameOfLife()
void initLayers() {
    layerPixels.resize(GRID_WIDTH * GRID_HEIGHT);
    nextLayerPixels.resize(GRID_WIDTH * GRID_HEIGHT);
    
    // Initialize all pixels with exactly zero height
    for (int i = 0; i < GRID_WIDTH * GRID_HEIGHT; ++i) {
        layerPixels[i] = {0.0f, 0.0f, 0.0f};
        nextLayerPixels[i] = {0.0f, 0.0f, 0.0f};
    }
}

// Add after updateGameOfLife()
void updateLayers() {
    if (globalTime - lastLayerUpdate < LAYER_UPDATE_INTERVAL) {
        return;
    }
    lastLayerUpdate = globalTime;
    
    // Update each pixel's layer
    for (int y = 0; y < GRID_HEIGHT; ++y) {
        for (int x = 0; x < GRID_WIDTH; ++x) {
            int idx = y * GRID_WIDTH + x;
            
            // Get the current height
            float currentHeight = layerPixels[idx].height;
            
            // Increment height for continuous rain effect
            float targetHeight = currentHeight + HEIGHT_INCREMENT;
            
            // Ensure height never goes below 0
            targetHeight = std::max(0.0f, targetHeight);
            
            // Smoothly interpolate towards target height
            float newHeight = glm::mix(currentHeight, targetHeight, 0.1f);
            
            // Update the next layer state
            nextLayerPixels[idx].height = newHeight;
            nextLayerPixels[idx].lastUpdate = globalTime;
        }
    }
    
    // Update state
    layerPixels.swap(nextLayerPixels);
}

// Add after initLayers()
void initSnowParticles() {
    snowParticles.clear();
    std::uniform_real_distribution<float> sizeDist(SNOW_MIN_SIZE, SNOW_MAX_SIZE);
    std::uniform_real_distribution<float> speedDist(SNOW_MIN_SPEED, SNOW_MAX_SPEED);
    std::uniform_real_distribution<float> angleDist(0.0f, 2.0f * M_PI);
    std::uniform_real_distribution<float> radiusDist(0.0f, SNOW_SPAWN_RADIUS);
    std::uniform_real_distribution<float> zDist(SNOW_SPAWN_DISTANCE, SNOW_SPAWN_DISTANCE * 1.5f);
    std::uniform_int_distribution<int> effectDist(0, 3);
    
    // Calculate grid center for proper spawning
    float gridCenterX = GRID_WIDTH / 2.0f;
    float gridCenterY = GRID_HEIGHT / 2.0f;
    
    for (int i = 0; i < MAX_SNOW_PARTICLES; ++i) {
        float angle = angleDist(globalRng);
        float radius = radiusDist(globalRng);
        float x = gridCenterX + cos(angle) * radius;  // Center around grid center
        float y = gridCenterY + sin(angle) * radius;  // Center around grid center
        float z = zDist(globalRng);
        
        SnowParticle particle;
        particle.position = glm::vec3(x, y, z);
        particle.size = sizeDist(globalRng);
        particle.speed = speedDist(globalRng);
        particle.rotation = angleDist(globalRng);
        particle.rotationSpeed = speedDist(globalRng) * 0.1f;
        particle.alpha = 0.8f;
        particle.effectType = effectDist(globalRng);
        particle.hasLanded = false;
        particle.landingTime = 0.0f;
        snowParticles.push_back(particle);
    }
}

// Add after updateLayers()
void updateSnowParticles() {
    std::uniform_real_distribution<float> angleDist(0.0f, 2.0f * M_PI);
    std::uniform_real_distribution<float> radiusDist(0.0f, SNOW_SPAWN_RADIUS);
    std::uniform_real_distribution<float> zDist(SNOW_SPAWN_DISTANCE, SNOW_SPAWN_DISTANCE * 1.5f);
    std::uniform_int_distribution<int> effectDist(0, 3);
    
    // Calculate grid center for proper respawning
    float gridCenterX = GRID_WIDTH / 2.0f;
    float gridCenterY = GRID_HEIGHT / 2.0f;
    
    for (auto& particle : snowParticles) {
        if (!particle.hasLanded) {
            // Move particle towards the grid
            particle.position.z -= particle.speed * 0.016f;
            
            // Add some swaying motion
            particle.position.x += sin(globalTime * 0.5f + particle.rotation) * 0.2f;
            particle.position.y += cos(globalTime * 0.5f + particle.rotation) * 0.2f;
            
            // Update rotation
            particle.rotation += particle.rotationSpeed * 0.016f;
            
            // Check if particle has landed
            if (particle.position.z < 0.0f) {
                particle.hasLanded = true;
                particle.landingTime = globalTime;
                
                // Create landing effect
                SnowflakeEffect effect;
                float gridCenterX = GRID_WIDTH / 2.0f;
                float gridCenterY = GRID_HEIGHT / 2.0f;
                effect.position = glm::vec2(particle.position.x - gridCenterX, particle.position.y - gridCenterY);
                //effect.position = glm::vec2(particle.position.x, particle.position.y);
                effect.radius = EFFECT_RADIUS;
                effect.effectType = particle.effectType;
                effect.blendFactor = 1.0f;
                effect.lastUpdate = globalTime;
                snowflakeEffects.push_back(effect);
            }
        } else {
            // Reset particle if effect has expired
            if (globalTime - particle.landingTime > EFFECT_DURATION) {
                float angle = angleDist(globalRng);
                float radius = radiusDist(globalRng);
                float x = gridCenterX + cos(angle) * radius;  // Center around grid center
                float y = gridCenterY + sin(angle) * radius;  // Center around grid center
                float z = zDist(globalRng);
                
                particle.position = glm::vec3(x, y, z);
                particle.rotation = angleDist(globalRng);
                particle.hasLanded = false;
                particle.effectType = effectDist(globalRng);
            }
        }
    }
}

// Add new function to handle snowflake landing effects
void updateSnowflakeEffects() {
    // Remove expired effects
    snowflakeEffects.erase(
        std::remove_if(snowflakeEffects.begin(), snowflakeEffects.end(),
            [](const SnowflakeEffect& effect) {
                return globalTime - effect.lastUpdate > EFFECT_DURATION;
            }
        ),
        snowflakeEffects.end()
    );
    
    // Update existing effects
    for (auto& effect : snowflakeEffects) {
        float age = globalTime - effect.lastUpdate;
        float progress = age / EFFECT_DURATION;
        effect.blendFactor = 1.0f - progress;  // Fade out over time
        effect.radius = EFFECT_RADIUS * (1.0f - progress * 0.5f);  // Shrink slightly
    }
}

// Vertex shader for the main grid. It combines the "rain-in" effect from negative Z,
// the background pattern cycle, snowflake landing effects, Game of Life / layer heights,
// and the late morph onto a twisted lemniscate strip followed by a camera flight.
//
// Timeline (all driven by the `time` uniform):
//   - Rain-in: every pixel falls from startZ (far negative Z) to its grid Z. A
//     pseudo-random delay (0-2 s) and duration (5-10 s) are hashed from aIndex, and a
//     whirl offset fades out as the pixel lands.
//   - Patterns: from t = 12 s the background cycles plasma -> spiral -> checker -> wave.
//     Each pattern runs for 10 s, then a 2 s cross-fade follows. From the fourth cycle
//     on it keeps the wave -> plasma pair.
//   - Snowflake quads (effectType >= 0) show the pattern selected by effectType.
//     Grid pixels (effectType < 0) blend in the landing effects from `effects`, weighted
//     by distance and by effects[i].blend (presumably SnowflakeEffect::blendFactor
//     uploaded by the host; TODO confirm).
//   - Height: +2 when isAlive, plus layerHeight, plus a pulsing pattern height from the
//     second cycle (spiral) on.
//   - Morph: from morphStartTime, vertices are blended onto the strip by morphProgress.
//     MORPH_DURATION seconds later the camera flies along the strip's centre line.
const char* vertexShaderSource = R"(
    #version 330 core
    #define MAX_EFFECTS 16

    const float PI = 3.14159;
    const float STRIP_A = 100.0;        // Scale of the lemniscate
    const float STRIP_B = 30.0;         // Height of the twist
    const float MORPH_DURATION = 16.0;  // Seconds from morphStartTime until the camera flight starts

    struct SnowflakeEffect {
        vec2 position;
        float radius;
        int effectType;
        float blend;
    };
    uniform SnowflakeEffect effects[MAX_EFFECTS];
    uniform int numEffects;

    layout (location = 0) in vec3 aPos;
    layout (location = 1) in vec3 aColor;
    layout (location = 2) in float aBrightness;
    layout (location = 3) in float aIndex;

    out vec3 Color;
    out float Brightness;
    out vec2 ScreenPos;
    out float CompletionTime;
    out float PatternValue;
    out float EffectBlend;  // Strongest landing-effect weight at this vertex (0 = none)

    uniform mat4 model;
    uniform mat4 view;
    uniform mat4 projection;
    uniform float time;
    uniform bool isAlive;
    uniform float layerHeight;
    uniform int effectType;       // >= 0: snowflake quad showing that pattern; < 0: grid pixel
    uniform float morphStartTime;
    uniform float morphProgress;
    uniform float pulseStrength;  // For audio reactivity (currently unused)
    uniform float maxGridRadius;  // Normalizes the across-strip coordinate during the morph
    uniform float beatIntensity;  // Currently unused

    // Pattern functions
    float plasmaPattern(vec2 pos, float t) {
        float v1 = sin(pos.x * 0.01 + t * 0.2 + pos.y * 0.01);
        float v2 = sin(pos.x * 0.02 - t * 0.15 - pos.y * 0.02);
        float v3 = sin(sqrt(pos.x * pos.x + pos.y * pos.y) * 0.01 + t * 0.1);
        float angle = atan(pos.y, pos.x);
        float v4 = sin(angle * 4.0 + t * 0.3);
        return (v1 + v2 + v3 + v4) * 0.25;
    }

    float spiralPattern(vec2 pos, float t) {
        float angle = atan(pos.y, pos.x);
        float radius = length(pos);
        return sin(angle * 8.0 + radius * 0.1 - t * 0.5);
    }

    float checkerPattern(vec2 pos, float t) {
        float angle = t * 0.2;
        float cos_a = cos(angle);
        float sin_a = sin(angle);
        vec2 rotated = vec2(
            pos.x * cos_a - pos.y * sin_a,
            pos.x * sin_a + pos.y * cos_a
        );
        return mod(floor(rotated.x * 0.1) + floor(rotated.y * 0.1), 2.0);
    }

    float wavePattern(vec2 pos, float t) {
        return sin(pos.x * 0.1 + t * 0.5) * sin(pos.y * 0.1 + t * 0.3);
    }

    // Point on the twisted lemniscate strip at parameter t along the strip and v across
    // it (v = 0 is the centre line), for a strip centred on center.
    vec3 stripPoint(float t, float v, vec2 center) {
        float denom = 1.0 + sin(t) * sin(t);
        float crossingAngle = atan(1.0, cos(t + PI / 4.0));
        float zOffset = 30.0 * smoothstep(0.0, 1.0, cos(t + PI / 4.0));
        vec3 p;
        p.x = center.x + STRIP_A * cos(t) / denom;
        p.y = center.y + STRIP_A * sin(t) * cos(t) / denom;
        p.z = STRIP_B * (sin(t * 0.5) * v + cos(crossingAngle) * v) + zOffset;
        p.x += STRIP_B * sin(crossingAngle) * v * cos(t);
        p.y += STRIP_B * sin(crossingAngle) * v * sin(t);
        return p;
    }

    // View matrix for the camera flying along the strip's centre line,
    // flightTime seconds after the flight started.
    mat4 flightView(float flightTime, vec2 center) {
        float cameraT = flightTime * 0.5;
        vec3 cameraPos = stripPoint(cameraT, 0.0, center);

        // Look ahead along the centre line, pulled strongly (80%) toward the strip centre
        vec3 lookAtPos = stripPoint(cameraT + 0.6, 0.0, center);
        vec3 target = mix(lookAtPos, vec3(center, lookAtPos.z), 0.8);

        // Strip normal at the camera, from two points on either side of the centre line
        float vSide = 0.3;
        vec3 normal = normalize(stripPoint(cameraT, vSide, center) - stripPoint(cameraT, -vSide, center));

        vec3 forward = normalize(target - cameraPos);
        vec3 up = normal;
        // If up and forward are nearly parallel, fall back to world up
        if (abs(dot(forward, up)) > 0.95) {
            up = vec3(0.0, 0.0, 1.0);
        }
        // Re-orthogonalize
        vec3 right = normalize(cross(forward, up));
        up = normalize(cross(right, forward));

        // lookAt: the world -> camera rotation has rows right, up and -forward. The mat4
        // constructor fills columns, so transpose it (the untransposed matrix is the
        // camera -> world rotation and does not aim the camera at the target).
        mat4 rotation = transpose(mat4(
            vec4(right, 0.0),
            vec4(up, 0.0),
            vec4(-forward, 0.0),
            vec4(0.0, 0.0, 0.0, 1.0)
        ));
        mat4 translation = mat4(
            1.0, 0.0, 0.0, 0.0,
            0.0, 1.0, 0.0, 0.0,
            0.0, 0.0, 1.0, 0.0,
            -cameraPos.x, -cameraPos.y, -cameraPos.z, 1.0
        );
        return rotation * translation;
    }

    void main() {
        float z = aPos.z;
        float maxDuration = 10.0;

        // Per-pixel pseudo-random values hashed from the pixel index
        float rand1 = fract(sin(aIndex * 78.233) * 43758.5453);
        float rand2 = fract(sin((aIndex + 1.0) * 78.233) * 43758.5453);
        float rand3 = fract(sin((aIndex + 3.0) * 78.233) * 43758.5453);

        // Starting distance scales with this pixel's fall duration
        float rainDuration = mix(5.0, maxDuration, rand2);
        float delay = rand1 * 2.0;
        float timeAvailable = maxDuration - delay;
        float startZ = -400.0 * (rainDuration / timeAvailable);

        // Local fall progress with random duration and delay
        float localT = clamp((time - delay) / rainDuration, 0.0, 1.0);
        float easedT = localT * localT;  // Quadratic easing
        float accelerationFactor = mix(0.8, 1.2, rand3);
        easedT = mix(localT, easedT, accelerationFactor);

        // Whirl: strongest at the start of the fall, fades to 0 on landing
        float whirlStrength = 1.0 - easedT;
        float whirlSpeed = 2.0;    // Speed of rotation
        float whirlRadius = 50.0;  // Maximum radius of whirl

        vec2 gridCenter = vec2(160.0, 100.0);
        vec2 pixelPos = aPos.xy - gridCenter;

        float angle = time * whirlSpeed + length(pixelPos) * 0.1;
        float whirlX = sin(angle) * whirlRadius * whirlStrength;
        float whirlY = cos(angle) * whirlRadius * whirlStrength;

        // whirlStrength is applied a second time here, so the offset decays quadratically
        vec2 whirlOffset = vec2(whirlX, whirlY) * whirlStrength;
        vec2 finalPos = aPos.xy + whirlOffset;

        // Interpolate Z position with acceleration
        float interpZ = mix(startZ, z, easedT);

        // Background pattern cycle in world space (centred on the grid)
        vec2 worldPos = finalPos - gridCenter;
        float effectTime = max(0.0, time - 12.0);
        float effectDuration = 10.0;
        float transitionDuration = 2.0;
        float cycleDuration = effectDuration + transitionDuration;
        float effectIndex = floor(effectTime / cycleDuration);
        float effectProgress = fract(effectTime / cycleDuration);

        float currentPattern = 0.0;
        float nextPattern = 0.0;

        if (effectIndex == 0.0) {
            currentPattern = plasmaPattern(worldPos, effectTime);
            nextPattern = spiralPattern(worldPos, effectTime);
        } else if (effectIndex == 1.0) {
            currentPattern = spiralPattern(worldPos, effectTime);
            nextPattern = checkerPattern(worldPos, effectTime);
        } else if (effectIndex == 2.0) {
            currentPattern = checkerPattern(worldPos, effectTime);
            nextPattern = wavePattern(worldPos, effectTime);
        } else {
            currentPattern = wavePattern(worldPos, effectTime);
            nextPattern = plasmaPattern(worldPos, effectTime);
        }

        float transitionProgress = smoothstep(0.0, 1.0,
            clamp((effectProgress * cycleDuration - effectDuration) / transitionDuration, 0.0, 1.0));

        // Cross-faded background pattern (includes the 2 s transition)
        float basePattern = mix(currentPattern, nextPattern, transitionProgress);
        PatternValue = basePattern;

        // Written on every path so the fragment shader never reads an undefined value;
        // snowflake quads carry no landing-effect blend.
        EffectBlend = 0.0;

        if (effectType >= 0) {
            // Snowflake quads show the pattern selected by their effect type
            switch(effectType) {
                case 0: PatternValue = plasmaPattern(worldPos, time); break;
                case 1: PatternValue = spiralPattern(worldPos, time); break;
                case 2: PatternValue = checkerPattern(worldPos, time); break;
                case 3: PatternValue = wavePattern(worldPos, time); break;
            }
        } else {
            // Grid pixels: blend in any active landing effects
            float totalWeight = 0.0;
            float accumulatedPattern = 0.0;

            // Never index past the declared uniform array, whatever the host uploads
            int activeEffects = min(numEffects, MAX_EFFECTS);
            for (int i = 0; i < activeEffects; ++i) {
                float dist = length(worldPos - effects[i].position);
                if (dist < effects[i].radius) {
                    float localBlend = (1.0 - dist / effects[i].radius) * effects[i].blend;
                    float effectValue = 0.0;

                    switch(effects[i].effectType) {
                        case 0: effectValue = plasmaPattern(worldPos, time); break;
                        case 1: effectValue = spiralPattern(worldPos, time); break;
                        case 2: effectValue = checkerPattern(worldPos, time); break;
                        case 3: effectValue = wavePattern(worldPos, time); break;
                    }

                    accumulatedPattern += effectValue * localBlend;
                    totalWeight += localBlend;
                    EffectBlend = max(EffectBlend, localBlend);
                }
            }

            if (totalWeight > 0.0) {
                // Weighted average of the overlapping effects, faded in over the background
                // by their combined weight. This is continuous at each effect's edge and
                // fades out as the effects' blend decays, instead of a hard switch.
                float effectPattern = accumulatedPattern / totalWeight;
                PatternValue = mix(basePattern, effectPattern, clamp(totalWeight, 0.0, 1.0));
            }
            // Otherwise keep the cross-faded background pattern computed above.
        }

        // Pulsating height from the pattern value, only from the spiral stage on
        float patternHeight = 0.0;
        if (effectIndex >= 1.0) {
            float pulse = sin(effectTime * 2.0) * 0.5 + 0.5;
            patternHeight = PatternValue * 5.0 * pulse;
        }

        // Game of Life promotion, layer height and pattern height
        if (isAlive) {
            interpZ += 2.0;
        }
        interpZ += layerHeight;
        interpZ += patternHeight;

        vec3 gridPos = vec3(finalPos, interpZ);
        if (time < morphStartTime) {
            gl_Position = projection * view * model * vec4(gridPos, 1.0);
        } else {
            // Map the grid onto strip coordinates: u in [0, 1] along x, v across y
            float u = (finalPos.x - (gridCenter.x - 160.0)) / (2.0 * 160.0);
            float v = (finalPos.y - gridCenter.y) / maxGridRadius;
            vec3 moebius = stripPoint(u * 2.0 * PI, v, gridCenter);
            vec3 morphedPos = mix(gridPos, moebius, smoothstep(0.0, 0.3, morphProgress));

            if (time > morphStartTime + MORPH_DURATION) {
                // Morph complete: render from the camera flying along the strip
                float flightTime = time - (morphStartTime + MORPH_DURATION);
                gl_Position = projection * flightView(flightTime, gridCenter) * vec4(morphedPos, 1.0);
            } else {
                gl_Position = projection * view * model * vec4(morphedPos, 1.0);
            }
        }

        Color = aColor;
        Brightness = aBrightness;
        ScreenPos = gl_Position.xy / gl_Position.w;
        CompletionTime = delay + rainDuration;
    }
)";

// Fragment shader shared by the rain-in program (shaderProgram) and the simple
// static-grid program (simpleShaderProgram); both link against it in initShaders().
//
// Colour = cyclic 8-stop palette (deep blue -> blue -> cyan -> green -> yellow ->
// orange -> red -> dark red -> back to deep blue) sampled at PatternValue, scaled by
// Brightness. After t = 12 s the palette lookup is shifted by 0.2 per second.
// When EffectBlend is non-zero, the colour is mixed toward the palette colour at
// EffectBlend by up to 50%. The result then gets:
//   - scanline darkening of up to 10% from sin(ScreenPos.y * 1000);
//   - a global glow factor in [0.8, 1.0] (2 rad/s);
//   - a +/-10% shimmer (3 rad/s).
// The Color and CompletionTime inputs are declared but not read here.
const char* fragmentShaderSource = R"(
    #version 330 core
    in vec3 Color;
    in float Brightness;
    in vec2 ScreenPos;
    in float CompletionTime;
    in float PatternValue;
    in float EffectBlend;
    
    out vec4 FragColor;
    
    uniform float time;
    
    // Plasma color palette
    vec3 plasmaColor(float t) {
        vec3 c1 = vec3(0.0, 0.0, 0.5);  // Deep blue
        vec3 c2 = vec3(0.0, 0.0, 1.0);  // Bright blue
        vec3 c3 = vec3(0.0, 1.0, 1.0);  // Cyan
        vec3 c4 = vec3(0.0, 1.0, 0.0);  // Green
        vec3 c5 = vec3(1.0, 1.0, 0.0);  // Yellow
        vec3 c6 = vec3(1.0, 0.5, 0.0);  // Orange
        vec3 c7 = vec3(1.0, 0.0, 0.0);  // Red
        vec3 c8 = vec3(0.5, 0.0, 0.0);  // Dark red
        
        float wobble = sin(t * 0.1) * 0.2;
        float p = fract(t * 0.5 + wobble);
        float s = p * 8.0;
        int i = int(s);
        float f = fract(s);
        
        if (i == 7) return mix(c8, c1, f);
        if (i == 0) return mix(c1, c2, f);
        if (i == 1) return mix(c2, c3, f);
        if (i == 2) return mix(c3, c4, f);
        if (i == 3) return mix(c4, c5, f);
        if (i == 4) return mix(c5, c6, f);
        if (i == 5) return mix(c6, c7, f);
        return mix(c7, c8, f);
    }
    
    void main() {
        float scanline = sin(ScreenPos.y * 1000.0) * 0.5 + 0.5;
        float scanlineIntensity = 0.1;
        float glow = sin(time * 2.0) * 0.1 + 0.9;
        
        float rotationTime = max(0.0, time - 12.0);
        vec3 finalColor = plasmaColor(PatternValue + rotationTime * 0.2) * Brightness;
        
        // Blend with snowflake effect if present
        if (EffectBlend != 0.0) {
            vec3 effectColor = plasmaColor(EffectBlend + rotationTime * 0.2);
            finalColor = mix(finalColor, effectColor, EffectBlend * 0.5);
        }
        
        finalColor *= (1.0 - scanline * scanlineIntensity) * glow;
        finalColor += finalColor * 0.1 * sin(time * 3.0);
        
        FragColor = vec4(finalColor, 1.0);
    }
)";

// Simple vertex shader for the static grid. It has no rain-in fall, no snowflake
// effects and no morph. It only computes the background pattern cycle and the
// Game of Life / layer / pattern heights.
//
// It is linked against the same fragment shader as the rain-in program. That fragment
// shader statically reads EffectBlend and declares CompletionTime, so this stage must
// output both. A fragment input with no matching vertex output is a link error on
// strict drivers, or an undefined value on lenient ones. Both are written as 0, which
// disables the fragment shader's snowflake-effect colour blend.
const char* simpleVertexShaderSource = R"(
    #version 330 core
    layout (location = 0) in vec3 aPos;
    layout (location = 1) in vec3 aColor;
    layout (location = 2) in float aBrightness;
    
    out vec3 Color;
    out float Brightness;
    out vec2 ScreenPos;
    out float PatternValue;
    out float CompletionTime;  // Not used by this program; written to match the fragment interface
    out float EffectBlend;     // The static grid has no snowflake landing effects
    
    uniform mat4 model;
    uniform mat4 view;
    uniform mat4 projection;
    uniform float time;
    uniform bool isAlive;
    uniform float layerHeight;  // Add uniform for layer height
    
    // Pattern functions
    float plasmaPattern(vec2 pos, float t) {
        float v1 = sin(pos.x * 0.01 + t * 0.2 + pos.y * 0.01);
        float v2 = sin(pos.x * 0.02 - t * 0.15 - pos.y * 0.02);
        float v3 = sin(sqrt(pos.x * pos.x + pos.y * pos.y) * 0.01 + t * 0.1);
        float angle = atan(pos.y, pos.x);
        float v4 = sin(angle * 4.0 + t * 0.3);
        return (v1 + v2 + v3 + v4) * 0.25;
    }
    
    float spiralPattern(vec2 pos, float t) {
        float angle = atan(pos.y, pos.x);
        float radius = length(pos);
        return sin(angle * 8.0 + radius * 0.1 - t * 0.5);
    }
    
    float checkerPattern(vec2 pos, float t) {
        float angle = t * 0.2;
        float cos_a = cos(angle);
        float sin_a = sin(angle);
        vec2 rotated = vec2(
            pos.x * cos_a - pos.y * sin_a,
            pos.x * sin_a + pos.y * cos_a
        );
        return mod(floor(rotated.x * 0.1) + floor(rotated.y * 0.1), 2.0);
    }
    
    float wavePattern(vec2 pos, float t) {
        return sin(pos.x * 0.1 + t * 0.5) * sin(pos.y * 0.1 + t * 0.3);
    }
    
    void main() {
        float z = aPos.z;
        
        // Calculate pattern in world space, centered on the grid
        vec2 worldPos = aPos.xy - vec2(160.0, 100.0);
        float effectTime = max(0.0, time - 12.0);
        float effectDuration = 10.0;
        float transitionDuration = 2.0;
        float cycleDuration = effectDuration + transitionDuration;
        float effectIndex = floor(effectTime / cycleDuration);
        float effectProgress = fract(effectTime / cycleDuration);
        
        float currentPattern = 0.0;
        float nextPattern = 0.0;
        
        if (effectIndex == 0.0) {
            currentPattern = plasmaPattern(worldPos, effectTime);
            nextPattern = spiralPattern(worldPos, effectTime);
        } else if (effectIndex == 1.0) {
            currentPattern = spiralPattern(worldPos, effectTime);
            nextPattern = checkerPattern(worldPos, effectTime);
        } else if (effectIndex == 2.0) {
            currentPattern = checkerPattern(worldPos, effectTime);
            nextPattern = wavePattern(worldPos, effectTime);
        } else {
            currentPattern = wavePattern(worldPos, effectTime);
            nextPattern = plasmaPattern(worldPos, effectTime);
        }
        
        float transitionProgress = smoothstep(0.0, 1.0, 
            clamp((effectProgress * cycleDuration - effectDuration) / transitionDuration, 0.0, 1.0));
        
        PatternValue = mix(currentPattern, nextPattern, transitionProgress);
        
        // Apply pulsating height based on pattern value only after spiral pattern
        float patternHeight = 0.0;
        if (effectIndex >= 1.0) {  // Start after spiral pattern
            float pulse = sin(effectTime * 2.0) * 0.5 + 0.5;  // 0 to 1 pulsation
            patternHeight = PatternValue * 5.0 * pulse;  // Scale pattern value to height range with pulsation
        }
        
        // Apply both Game of Life and layer height
        if (isAlive) {
            z += 2.0;  // Game of Life promotion
        }
        z += layerHeight;  // Add layer height
        z += patternHeight;  // Add pattern-based height
        
        vec4 pos = vec4(aPos.xy, z, 1.0);
        gl_Position = projection * view * model * pos;
        Color = aColor;
        Brightness = aBrightness;
        ScreenPos = gl_Position.xy / gl_Position.w;
        CompletionTime = 0.0;
        EffectBlend = 0.0;
    }
)";

// Game of Life runs on a coarser board than the pixel grid: each GoL cell covers a
// GOL_BLOCK_SIZE x GOL_BLOCK_SIZE square of display pixels.
const int GOL_BLOCK_SIZE = 4;  // Size of each Game of Life block (pixels per side)
const int GOL_WIDTH = GRID_WIDTH / GOL_BLOCK_SIZE;
const int GOL_HEIGHT = GRID_HEIGHT / GOL_BLOCK_SIZE;

// initGolBlockIndices() maps pixel (x, y) to cell (x / GOL_BLOCK_SIZE, y / GOL_BLOCK_SIZE).
// If the grid were not an exact multiple of the block size, the trailing partial column/row
// would map to a wrong or out-of-range cell index, so enforce divisibility at compile time.
static_assert(GRID_WIDTH % GOL_BLOCK_SIZE == 0, "GRID_WIDTH must be a multiple of GOL_BLOCK_SIZE");
static_assert(GRID_HEIGHT % GOL_BLOCK_SIZE == 0, "GRID_HEIGHT must be a multiple of GOL_BLOCK_SIZE");

// Lookup table: pixel index (y * GRID_WIDTH + x) -> GoL cell index, filled by initGolBlockIndices().
std::vector<int> golBlockIndices;

// Size both Game of Life buffers to GOL_WIDTH x GOL_HEIGHT cells and seed the current
// board so each cell is alive with 30% probability. The generator is seeded from
// std::random_device, so every run starts differently.
// Also calls initMazeBackground(), which is unrelated to GoL but needs a home in the
// init sequence.
void initGameOfLife() {
    const int cellCount = GOL_WIDTH * GOL_HEIGHT;
    gameOfLifeState.resize(cellCount, false);
    nextGameOfLifeState.resize(cellCount, false);

    std::random_device entropy;
    std::mt19937 engine(entropy());
    std::uniform_real_distribution<float> unit(0.0f, 1.0f);

    // After resize() the board holds exactly cellCount entries.
    for (auto&& alive : gameOfLifeState) {
        alive = unit(engine) < 0.3f;
    }

    initMazeBackground();  // nothing to do with GoL but have to put somewhere
}

// Advance the Game of Life by one generation, at most once every
// GAME_OF_LIFE_UPDATE_INTERVAL seconds of globalTime.
// The board wraps around both edges (torus). Conway's rules (B3/S23) apply: a live cell
// survives with 2 or 3 live neighbours, and a dead cell is born with exactly 3.
void updateGameOfLife() {
    const auto sinceLastStep = globalTime - lastGameOfLifeUpdate;
    if (sinceLastStep < GAME_OF_LIFE_UPDATE_INTERVAL) {
        return;  // Too early for the next generation.
    }
    lastGameOfLifeUpdate = globalTime;

    // Number of live cells among the 8 wrapped neighbours of (cx, cy) on the current board.
    auto liveNeighbours = [](int cx, int cy) {
        int count = 0;
        for (int oy = -1; oy <= 1; ++oy) {
            const int row = (cy + oy + GOL_HEIGHT) % GOL_HEIGHT;
            for (int ox = -1; ox <= 1; ++ox) {
                if (ox == 0 && oy == 0) {
                    continue;  // Skip the cell itself.
                }
                const int col = (cx + ox + GOL_WIDTH) % GOL_WIDTH;
                count += gameOfLifeState[row * GOL_WIDTH + col] ? 1 : 0;
            }
        }
        return count;
    };

    for (int cy = 0; cy < GOL_HEIGHT; ++cy) {
        for (int cx = 0; cx < GOL_WIDTH; ++cx) {
            const int cell = cy * GOL_WIDTH + cx;
            const int n = liveNeighbours(cx, cy);
            const bool alive = gameOfLifeState[cell];
            // Birth on exactly 3; survival on 2 (or 3, already covered).
            nextGameOfLifeState[cell] = (n == 3) || (alive && n == 2);
        }
    }

    // The freshly computed generation becomes current.
    gameOfLifeState.swap(nextGameOfLifeState);
}

// Build golBlockIndices: for every pixel (row-major index y * GRID_WIDTH + x), store the
// index of the Game of Life cell that contains it (cells are GOL_BLOCK_SIZE pixels square).
void initGolBlockIndices() {
    const int pixelCount = GRID_WIDTH * GRID_HEIGHT;
    golBlockIndices.resize(pixelCount);
    for (int pixel = 0; pixel < pixelCount; ++pixel) {
        const int px = pixel % GRID_WIDTH;
        const int py = pixel / GRID_WIDTH;
        golBlockIndices[pixel] = (py / GOL_BLOCK_SIZE) * GOL_WIDTH + px / GOL_BLOCK_SIZE;
    }
}

// Initialize OpenGL buffers
void initBuffers() {
    // Create and bind VAO
    glGenVertexArrays(1, &VAO);
    glBindVertexArray(VAO);

    // Create and bind VBO
    glGenBuffers(1, &VBO);
    glBindBuffer(GL_ARRAY_BUFFER, VBO);

    // Allocate buffer for vertex data
    std::vector<VertexData> vertexData;
    int idx = 0;
    for (int y = 0; y < GRID_HEIGHT; ++y) {
        for (int x = 0; x < GRID_WIDTH; ++x) {
            const Pixel& pixel = pixels[y * GRID_WIDTH + x];
            float px = (float)x;
            float py = (float)y;
            float index = (float)idx;
            // Top-left
            vertexData.push_back({ glm::vec3(px, py + 1.0f, GRID_DEPTH), pixel.color, pixel.brightness, index });
            // Top-right
            vertexData.push_back({ glm::vec3(px + 1.0f, py + 1.0f, GRID_DEPTH), pixel.color, pixel.brightness, index });
            // Bottom-right
            vertexData.push_back({ glm::vec3(px + 1.0f, py, GRID_DEPTH), pixel.color, pixel.brightness, index });
            // Bottom-left
            vertexData.push_back({ glm::vec3(px, py, GRID_DEPTH), pixel.color, pixel.brightness, index });
            ++idx;
        }
    }

    // Upload vertex data
    glBufferData(GL_ARRAY_BUFFER, vertexData.size() * sizeof(VertexData), vertexData.data(), GL_STATIC_DRAW);

    // Set up vertex attributes
    // Position attribute
    glVertexAttribPointer(0, 3, GL_FLOAT, GL_FALSE, sizeof(VertexData), (void*)0);
    glEnableVertexAttribArray(0);
    // Color attribute
    glVertexAttribPointer(1, 3, GL_FLOAT, GL_FALSE, sizeof(VertexData), (void*)offsetof(VertexData, color));
    glEnableVertexAttribArray(1);
    // Brightness attribute
    glVertexAttribPointer(2, 1, GL_FLOAT, GL_FALSE, sizeof(VertexData), (void*)offsetof(VertexData, brightness));
    glEnableVertexAttribArray(2);
    // Index attribute
    glVertexAttribPointer(3, 1, GL_FLOAT, GL_FALSE, sizeof(VertexData), (void*)offsetof(VertexData, index));
    glEnableVertexAttribArray(3);

    // Unbind VAO and VBO
    glBindBuffer(GL_ARRAY_BUFFER, 0);
    glBindVertexArray(0);
}

// Cached uniform locations (results of glGetUniformLocation), filled in by initShaders().
struct ShaderUniforms {
    // Transform matrices
    GLint model;
    GLint view;
    GLint projection;
    // Per-draw parameters
    GLint time;
    GLint isAlive;
    GLint layerHeight;
    GLint effectType;
    // Möbius-strip morph timing
    GLint morphStartTime;
    GLint morphProgress;
    // Location of the rain-in program's "effectType" uniform, cached separately for
    // snowflake draws.
    GLint snowflakeEffectType;
};
ShaderUniforms uniforms;  // Global instance shared by the init and draw code.

// Corners of a unit square centred on the origin in the XY plane (z = 0), listed
// counter-clockwise from the bottom-left. Precomputed once for building snowflake quads.
// Presumably scaled/offset per particle by the caller (TODO confirm).
const glm::vec3 unitCorners[4] = {
    {-0.5f, -0.5f, 0.0f},  // bottom-left
    { 0.5f, -0.5f, 0.0f},  // bottom-right
    { 0.5f,  0.5f, 0.0f},  // top-right
    {-0.5f,  0.5f, 0.0f},  // top-left
};

// Maze vertex shader (compiled in initMazeShader()).
// Transforms aPos (location 0) by projection * view * model and passes the per-vertex
// colour aColor (location 1) through unchanged.
const char* mazeVertexShaderSource = R"(
    #version 330 core
    layout (location = 0) in vec3 aPos;
    layout (location = 1) in vec3 aColor;
    uniform mat4 model;
    uniform mat4 view;
    uniform mat4 projection;
    out vec3 Color;
    void main() {
        gl_Position = projection * view * model * vec4(aPos, 1.0);
        Color = aColor;
    }
)";

// Maze fragment shader (compiled in initMazeShader()).
// Outputs the interpolated vertex colour as a fully opaque fragment.
const char* mazeFragmentShaderSource = R"(
    #version 330 core
    in vec3 Color;
    out vec4 FragColor;
    void main() {
        FragColor = vec4(Color, 1.0);
    }
)";

bool initMazeShader() {
    GLuint vertexShader = glCreateShader(GL_VERTEX_SHADER);
    glShaderSource(vertexShader, 1, &mazeVertexShaderSource, NULL);
    glCompileShader(vertexShader);

    GLint success;
    glGetShaderiv(vertexShader, GL_COMPILE_STATUS, &success);
    if (!success) {
        char infoLog[512];
        glGetShaderInfoLog(vertexShader, 512, NULL, infoLog);
        printf("Maze vertex shader compilation failed: %s\n", infoLog);
        return false;
    }

    GLuint fragmentShader = glCreateShader(GL_FRAGMENT_SHADER);
    glShaderSource(fragmentShader, 1, &mazeFragmentShaderSource, NULL);
    glCompileShader(fragmentShader);

    glGetShaderiv(fragmentShader, GL_COMPILE_STATUS, &success);
    if (!success) {
        char infoLog[512];
        glGetShaderInfoLog(fragmentShader, 512, NULL, infoLog);
        printf("Maze fragment shader compilation failed: %s\n", infoLog);
        return false;
    }

    mazeShaderProgram = glCreateProgram();
    glAttachShader(mazeShaderProgram, vertexShader);
    glAttachShader(mazeShaderProgram, fragmentShader);
    glLinkProgram(mazeShaderProgram);

    glGetProgramiv(mazeShaderProgram, GL_LINK_STATUS, &success);
    if (!success) {
        char infoLog[512];
        glGetProgramInfoLog(mazeShaderProgram, 512, NULL, infoLog);
        printf("Maze shader program linking failed: %s\n", infoLog);
        return false;
    }

    glDeleteShader(vertexShader);
    glDeleteShader(fragmentShader);

    return true;
}
// Initialize shaders
bool initShaders() {


    // Create and compile rain-in vertex shader
    GLuint vertexShader = glCreateShader(GL_VERTEX_SHADER);
    if (vertexShader == 0) {
        printf("Failed to create rain-in vertex shader\n");
        return false;
    }
    glShaderSource(vertexShader, 1, &vertexShaderSource, NULL);
    glCompileShader(vertexShader);

    // Check vertex shader compilation
    GLint success;
    GLchar infoLog[512];
    glGetShaderiv(vertexShader, GL_COMPILE_STATUS, &success);
    if (!success) {
        glGetShaderInfoLog(vertexShader, 512, NULL, infoLog);
        printf("Rain-in vertex shader compilation failed: %s\n", infoLog);
        return false;
    }

    // Create and compile simple vertex shader
    GLuint simpleVertexShader = glCreateShader(GL_VERTEX_SHADER);
    if (simpleVertexShader == 0) {
        printf("Failed to create simple vertex shader\n");
        return false;
    }
    glShaderSource(simpleVertexShader, 1, &simpleVertexShaderSource, NULL);
    glCompileShader(simpleVertexShader);

    glGetShaderiv(simpleVertexShader, GL_COMPILE_STATUS, &success);
    if (!success) {
        glGetShaderInfoLog(simpleVertexShader, 512, NULL, infoLog);
        printf("Simple vertex shader compilation failed: %s\n", infoLog);
        return false;
    }

    // Create and compile fragment shader
    GLuint fragmentShader = glCreateShader(GL_FRAGMENT_SHADER);
    if (fragmentShader == 0) {
        printf("Failed to create fragment shader\n");
        return false;
    }
    glShaderSource(fragmentShader, 1, &fragmentShaderSource, NULL);
    glCompileShader(fragmentShader);

    glGetShaderiv(fragmentShader, GL_COMPILE_STATUS, &success);
    if (!success) {
        glGetShaderInfoLog(fragmentShader, 512, NULL, infoLog);
        printf("Fragment shader compilation failed: %s\n", infoLog);
        return false;
    }

    // Create rain-in shader program
    shaderProgram = glCreateProgram();
    if (shaderProgram == 0) {
        printf("Failed to create rain-in shader program\n");
        return false;
    }
    glAttachShader(shaderProgram, vertexShader);
    glAttachShader(shaderProgram, fragmentShader);
    glLinkProgram(shaderProgram);

    glGetProgramiv(shaderProgram, GL_LINK_STATUS, &success);
    if (!success) {
        glGetProgramInfoLog(shaderProgram, 512, NULL, infoLog);
        printf("Rain-in shader program linking failed: %s\n", infoLog);
        return false;
    }

    // Create simple shader program
    simpleShaderProgram = glCreateProgram();
    if (simpleShaderProgram == 0) {
        printf("Failed to create simple shader program\n");
        return false;
    }
    glAttachShader(simpleShaderProgram, simpleVertexShader);
    glAttachShader(simpleShaderProgram, fragmentShader);
    glLinkProgram(simpleShaderProgram);

    glGetProgramiv(simpleShaderProgram, GL_LINK_STATUS, &success);
    if (!success) {
        glGetProgramInfoLog(simpleShaderProgram, 512, NULL, infoLog);
        printf("Simple shader program linking failed: %s\n", infoLog);
        return false;
    }

    // Get uniform locations for rain-in shader program
    glUseProgram(shaderProgram);
    uniforms.model = glGetUniformLocation(shaderProgram, "model");
    uniforms.view = glGetUniformLocation(shaderProgram, "view");
    uniforms.projection = glGetUniformLocation(shaderProgram, "projection");
    uniforms.time = glGetUniformLocation(shaderProgram, "time");
    uniforms.isAlive = glGetUniformLocation(shaderProgram, "isAlive");
    uniforms.layerHeight = glGetUniformLocation(shaderProgram, "layerHeight");
    uniforms.effectType = glGetUniformLocation(shaderProgram, "effectType");
    uniforms.morphStartTime = glGetUniformLocation(shaderProgram, "morphStartTime");
    uniforms.morphProgress = glGetUniformLocation(shaderProgram, "morphProgress");
    
    if (uniforms.model == -1 || uniforms.view == -1 || uniforms.projection == -1 || uniforms.time == -1 || 
        uniforms.isAlive == -1 || uniforms.layerHeight == -1 || uniforms.effectType == -1 || 
        uniforms.morphStartTime == -1 || uniforms.morphProgress == -1) {
        printf("Error: Failed to get uniform locations for rain-in shader program\n");
        printf("Model: %d, View: %d, Projection: %d, Time: %d, isAlive: %d, layerHeight: %d, effectType: %d, morphStartTime: %d, morphProgress: %d\n", 
               uniforms.model, uniforms.view, uniforms.projection, uniforms.time, uniforms.isAlive, uniforms.layerHeight, uniforms.effectType, uniforms.morphStartTime, uniforms.morphProgress);
        return false;
    }

    // Get uniform locations for simple shader program
    glUseProgram(simpleShaderProgram);
    uniforms.model = glGetUniformLocation(simpleShaderProgram, "model");
    uniforms.view = glGetUniformLocation(simpleShaderProgram, "view");
    uniforms.projection = glGetUniformLocation(simpleShaderProgram, "projection");
    uniforms.time = glGetUniformLocation(simpleShaderProgram, "time");
    
    if (uniforms.model == -1 || uniforms.view == -1 || uniforms.projection == -1 || uniforms.time == -1) {
        printf("Error: Failed to get uniform locations for simple shader program\n");
        printf("Model: %d, View: %d, Projection: %d, Time: %d\n", uniforms.model, uniforms.view, uniforms.projection, uniforms.time);
        return false;
    }

    // Clean up
    glDeleteShader(vertexShader);
    glDeleteShader(simpleVertexShader);
    glDeleteShader(fragmentShader);

    // Set initial shader program
    glUseProgram(shaderProgram);

    printf("[DEBUG] Shader compile/link status: %d\n", success);

    // Add caching for snowflake effect type
    uniforms.snowflakeEffectType = glGetUniformLocation(shaderProgram, "effectType");
    if (uniforms.snowflakeEffectType == -1) {
        printf("Warning: Failed to get snowflake effectType uniform location\n");
    }

    if (!initMazeShader()) {
    printf("Maze shader failed to initialize!\n");
    exit(1);
    }
    return true;
}

// Rebuild the perspective projection for the current window size and reset
// the camera distance so that the full grid height fits the view at Z = 0.
//
// Side effects: writes baseCameraZ, cameraZ and projection.
// NOTE(review): fov is assumed to be in radians, as glm::perspective expects.
void updateProjection() {
    // GLUT can report a zero-sized window (e.g. when minimised). Clamp both
    // dimensions so the aspect ratio never becomes 0, inf or NaN, any of which
    // would poison the projection matrix.
    const float safeWidth = windowWidth > 0 ? static_cast<float>(windowWidth) : 1.0f;
    const float safeHeight = windowHeight > 0 ? static_cast<float>(windowHeight) : 1.0f;
    const float aspectRatio = safeWidth / safeHeight;

    // Vertical fit: distance at which a vertical FOV of `fov` exactly spans
    // GRID_HEIGHT world units in the Z = 0 plane.
    baseCameraZ = (GRID_HEIGHT / 2.0f) / tan(fov / 2.0f);
    cameraZ = baseCameraZ;  // camera starts at the fitted distance
    projection = glm::perspective(fov, aspectRatio, 0.1f, 2000.0f);
}

// GLUT reshape callback (registered via glutReshapeFunc in main()).
// Stretches the viewport over the whole new window, remembers the size in
// windowWidth / windowHeight, and rebuilds the projection for it.
void reshape(int w, int h) {
    glViewport(0, 0, w, h);
    windowWidth = w;
    windowHeight = h;
    updateProjection();
}

// Utility for OpenGL error checking.
// Drains every pending glGetError() flag and prints each one in hex together
// with `label`. The body is wrapped in do/while(0) so that
// `GL_CHECK_ERROR("x");` is a single statement, which is safe inside an
// unbraced if/else. The drain is capped because glGetError() can keep
// returning an error when no context is current, which would otherwise loop
// forever.
#define GL_CHECK_ERROR(label) do { \
    for (int glCheckIter_ = 0; glCheckIter_ < 32; ++glCheckIter_) { \
        GLenum glCheckErr_ = glGetError(); \
        if (glCheckErr_ == GL_NO_ERROR) break; \
        printf("OpenGL error at %s: 0x%04X\n", (label), static_cast<unsigned>(glCheckErr_)); \
    } \
} while (0)

// Display function
void display() {
    glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT);
    glDisable(GL_CULL_FACE);

    if (globalTime > morphStartTime) {
    // Continuous progress without modulo
        updateMazeBackground();
        drawMazeBackground();
    }

    // Update Game of Life, layers, and snow if grid is in place
    if (globalTime >= 12.0f) {
        updateGameOfLife();  // Re-enabled
        float effectTime = (globalTime - 12.0f) * EFFECT_TIME_SCALE;
        float effectDuration = 10.0f * EFFECT_TIME_SCALE;
        float transitionDuration = 2.0f * EFFECT_TIME_SCALE;
        float cycleDuration = effectDuration + transitionDuration;
        float effectIndex = floor(effectTime / cycleDuration);
        
        // Only update layers after spiral pattern
        if (effectIndex >= 1.0) {
            updateLayers();
        }
        updateSnowParticles();
        updateSnowflakeEffects();
    }
    
    // Calculate if all pixels have completed their animation
    bool allPixelsComplete = true;
    if (globalTime < 12.0f) {  // Allow extra time for animation to complete
        // Check if any pixel would still be animating
        float maxDelay = 2.0f;  // Maximum delay from shader
        float maxDuration = 10.0f;  // Maximum duration from shader
        allPixelsComplete = (globalTime >= maxDelay + maxDuration);
        
        // Record the time when animation completes
        if (allPixelsComplete && animationCompleteTime == 0.0f) {
            animationCompleteTime = globalTime;
        }
    }
    
    // Choose shader program based on completion
    GLuint currentProgram = shaderProgram;
    
    if (currentProgram == 0) {
        printf("Error: Shader program not initialized\n");
        return;
    }
    
    glUseProgram(currentProgram);
    //printf("[DEBUG] Using shader program: %d\n", currentProgram);
    GL_CHECK_ERROR("glUseProgram");
    
    // Initialize view matrix
    glm::mat4 view;
    
    // Calculate camera animation for transitions
    float effectTime = (globalTime - 12.0f) * EFFECT_TIME_SCALE;
    float effectDuration = 10.0f * EFFECT_TIME_SCALE;
    float transitionDuration = 2.0f * EFFECT_TIME_SCALE;
    float cycleDuration = effectDuration + transitionDuration;
    
    // Calculate which effect we're in and the transition progress
    float effectIndex = std::floor(effectTime / cycleDuration);
    float effectProgress = effectTime / cycleDuration - effectIndex;
    
    // Calculate camera movement during transitions
    float transitionProgress = 0.0f;
    if (effectProgress > effectDuration / cycleDuration) {
        // We're in a transition period
        transitionProgress = (effectProgress - effectDuration / cycleDuration) / (transitionDuration / cycleDuration);
        
        // Create chaotic camera movement
        float chaosX = std::sin(transitionProgress * 30.0f) * 0.5f + std::cos(transitionProgress * 20.0f) * 0.3f;
        float chaosY = std::cos(transitionProgress * 25.0f) * 0.4f + std::sin(transitionProgress * 15.0f) * 0.2f;
        float chaosZ = std::sin(transitionProgress * 35.0f) * 0.3f + std::cos(transitionProgress * 40.0f) * 0.4f;
        
        // Smooth step function implementation
        auto smoothStep = [](float edge0, float edge1, float x) {
            float t = glm::clamp((x - edge0) / (edge1 - edge0), 0.0f, 1.0f);
            return t * t * (3.0f - 2.0f * t);
        };
        
        float zoomIn = smoothStep(0.0f, 0.5f, transitionProgress);
        float zoomOut = smoothStep(0.5f, 1.0f, transitionProgress);
        
        // Calculate camera position
        float centerX = GRID_WIDTH/2.0f;
        float centerY = GRID_HEIGHT/2.0f;
        float minZ = baseCameraZ * 0.1f;
        float maxZ = baseCameraZ * 1.5f;
        
        // Apply chaotic movement
        if (transitionProgress < 0.5f) {
            float chaosScale = zoomIn * 50.0f;
            float cameraX = centerX + chaosX * chaosScale;
            float cameraY = centerY + chaosY * chaosScale;
            float cameraZ = glm::mix(baseCameraZ, minZ, zoomIn) + chaosZ * chaosScale;
            
            glm::vec3 cameraPos = glm::vec3(cameraX, cameraY, cameraZ);
            glm::vec3 targetPos = glm::vec3(centerX, centerY, 0.0f);
            glm::vec3 up = glm::vec3(0.0f, 1.0f, 0.0f);
            view = glm::lookAt(cameraPos, targetPos, up);
        } else {
            float returnScale = (1.0f - zoomOut) * 50.0f;
            float cameraX = centerX + chaosX * returnScale;
            float cameraY = centerY + chaosY * returnScale;
            float cameraZ = glm::mix(minZ, baseCameraZ, zoomOut) + chaosZ * returnScale;
            
            glm::vec3 cameraPos = glm::vec3(cameraX, cameraY, cameraZ);
            glm::vec3 targetPos = glm::vec3(centerX, centerY, 0.0f);
            glm::vec3 up = glm::vec3(0.0f, 1.0f, 0.0f);
            view = glm::lookAt(cameraPos, targetPos, up);
        }
    } else {
        // Add camera rotation during rain-in
        float rainProgress = glm::clamp(globalTime / 12.0f, 0.0f, 1.0f);
        float rotationAngle = glm::sin(rainProgress * M_PI) * 0.2f;  // Max 0.2 radians rotation
        
        // Calculate camera position with rotation
        glm::vec3 cameraPos = glm::vec3(GRID_WIDTH/2.0f, GRID_HEIGHT/2.0f, baseCameraZ);
        glm::vec3 targetPos = glm::vec3(GRID_WIDTH/2.0f, GRID_HEIGHT/2.0f, 0.0f);
        
        // Create rotation matrix around Z axis
        glm::mat4 rotation = glm::rotate(glm::mat4(1.0f), rotationAngle, glm::vec3(0.0f, 0.0f, 1.0f));
        
        // Apply rotation to camera position
        glm::vec4 rotatedPos = rotation * glm::vec4(cameraPos - targetPos, 1.0f);
        cameraPos = targetPos + glm::vec3(rotatedPos);
        
        glm::vec3 up = glm::vec3(0.0f, 1.0f, 0.0f);
        view = glm::lookAt(cameraPos, targetPos, up);
    }
    
    // Get uniform locations for the current shader program
    GLint modelLoc = glGetUniformLocation(currentProgram, "model");
    GLint viewLoc = glGetUniformLocation(currentProgram, "view");
    GLint projLoc = glGetUniformLocation(currentProgram, "projection");
    GLint timeLoc = glGetUniformLocation(currentProgram, "time");
    GLint isAliveLoc = glGetUniformLocation(currentProgram, "isAlive");
    GLint layerHeightLoc = glGetUniformLocation(currentProgram, "layerHeight");
    
    // Check if we got all required uniform locations
    if (modelLoc == -1 || viewLoc == -1 || projLoc == -1 || timeLoc == -1 || 
        isAliveLoc == -1 || layerHeightLoc == -1) {
        printf("Error: Failed to get uniform locations for shader program %d\n", currentProgram);
        return;
    }
    
    // Set uniforms
    glm::mat4 model = glm::mat4(1.0f);
    glUniformMatrix4fv(modelLoc, 1, GL_FALSE, glm::value_ptr(model));
    glUniformMatrix4fv(viewLoc, 1, GL_FALSE, glm::value_ptr(view));
    glUniformMatrix4fv(projLoc, 1, GL_FALSE, glm::value_ptr(projection));
    glUniform1f(timeLoc, globalTime);
    
    //printf("[DEBUG] Uniform locations: model=%d, view=%d, proj=%d, time=%d, isAlive=%d, layerHeight=%d\n", modelLoc, viewLoc, projLoc, timeLoc, isAliveLoc, layerHeightLoc);
    GL_CHECK_ERROR("getUniformLocations");
    
    // Before drawing the grid, upload active snowflake effects to the shader
    GLint numEffectsLoc = glGetUniformLocation(shaderProgram, "numEffects");
    glUniform1i(numEffectsLoc, std::min((int)snowflakeEffects.size(), 16));
    for (int i = 0; i < std::min((int)snowflakeEffects.size(), 16); ++i) {
        char name[64];
        snprintf(name, sizeof(name), "effects[%d].position", i);
        glUniform2f(glGetUniformLocation(shaderProgram, name), snowflakeEffects[i].position.x, snowflakeEffects[i].position.y);
        snprintf(name, sizeof(name), "effects[%d].radius", i);
        glUniform1f(glGetUniformLocation(shaderProgram, name), snowflakeEffects[i].radius);
        snprintf(name, sizeof(name), "effects[%d].effectType", i);
        glUniform1i(glGetUniformLocation(shaderProgram, name), snowflakeEffects[i].effectType);
        snprintf(name, sizeof(name), "effects[%d].blend", i);
        glUniform1f(glGetUniformLocation(shaderProgram, name), snowflakeEffects[i].blendFactor);
    }
    
    // Draw pixels
    glBindVertexArray(VAO);
    GL_CHECK_ERROR("glBindVertexArray grid");
    for (int y = 0; y < GRID_HEIGHT; ++y) {
        for (int x = 0; x < GRID_WIDTH; ++x) {
            int pixelIdx = y * GRID_WIDTH + x;
            int golIdx = golBlockIndices[pixelIdx];
            
            glUniform1i(isAliveLoc, gameOfLifeState[golIdx] ? 1 : 0);
            glUniform1f(layerHeightLoc, layerPixels[pixelIdx].height);
            glUniform1i(glGetUniformLocation(currentProgram, "effectType"), -1);
            glDrawArrays(GL_TRIANGLE_FAN, pixelIdx * 4, 4);
        }
    }
    GL_CHECK_ERROR("glDrawArrays grid");
    glBindVertexArray(0);
    
    // Draw snow particles
    if (globalTime >= 12.0f) {
        // Set up blending for snow particles
        glEnable(GL_BLEND);
        glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA);
        
        // Set uniforms for snow particles
        glUniformMatrix4fv(modelLoc, 1, GL_FALSE, glm::value_ptr(model));
        glUniformMatrix4fv(viewLoc, 1, GL_FALSE, glm::value_ptr(view));
        glUniformMatrix4fv(projLoc, 1, GL_FALSE, glm::value_ptr(projection));
        glUniform1f(timeLoc, globalTime);
        glUniform1i(isAliveLoc, 0);  // Snowflakes are not alive cells
        glUniform1f(layerHeightLoc, 0.0f);  // No layer height for snowflakes
        
        // Bind snow VAO
        glBindVertexArray(snowVAO);
        glBindBuffer(GL_ARRAY_BUFFER, snowVBO);
        
        GL_CHECK_ERROR("before snowflake draw");
        
        // Batch all snowflake vertex data (optimization #3)
        snowflakeVertexData.clear();
        snowflakeVertexData.reserve(snowParticles.size() * 4);
        
        // Replace the quad-based vertex generation with this snowflake pattern
        // In the vertex generation loop
        // In the vertex generation loop, replace the current snowflake generation with:
        for (size_t particleIdx = 0; particleIdx < snowParticles.size(); ++particleIdx) {
            const auto& particle = snowParticles[particleIdx];
            
            // Center point
            glm::vec3 center = particle.position;
            
            // Create 6 arms of the snowflake
            for (int arm = 0; arm < 6; ++arm) {
                float angle = arm * (M_PI / 3.0f); // 60 degree spacing
                float rotatedAngle = angle + particle.rotation;
                
                // Main arm
                glm::vec3 tip = center + glm::vec3(
                    cos(rotatedAngle) * particle.size,
                    sin(rotatedAngle) * particle.size,
                    0.0f
                );
                
                // Add the triangle for this arm
                snowflakeVertexData.push_back({
                    center,
                    glm::vec3(0.8f, 0.8f, 1.0f),
                    particle.alpha,
                    static_cast<float>(particleIdx)
                });
                
                // Add two vertices to make a triangle
                float halfWidth = particle.size * 0.1f; // Width of the arm
                float perpAngle = rotatedAngle + M_PI/2.0f;
                glm::vec3 perpOffset = glm::vec3(
                    cos(perpAngle) * halfWidth,
                    sin(perpAngle) * halfWidth,
                    0.0f
                );
                
                snowflakeVertexData.push_back({
                    tip + perpOffset,
                    glm::vec3(0.8f, 0.8f, 1.0f),
                    particle.alpha,
                    static_cast<float>(particleIdx)
                });
                
                snowflakeVertexData.push_back({
                    tip - perpOffset,
                    glm::vec3(0.8f, 0.8f, 1.0f),
                    particle.alpha,
                    static_cast<float>(particleIdx)
                });
            }
        }

       
        // Single buffer update for all snowflakes (optimization #3)
        if (!snowflakeVertexData.empty()) {
            glBufferSubData(GL_ARRAY_BUFFER, 0, snowflakeVertexData.size() * sizeof(VertexData), snowflakeVertexData.data());
            
            // Draw all snowflakes at once
            for (size_t particleIdx = 0; particleIdx < snowParticles.size(); ++particleIdx) {
                const auto& particle = snowParticles[particleIdx];
                glUniform1i(uniforms.snowflakeEffectType, particle.effectType);
                
                // Draw all 6 arms of this snowflake
                // Each arm is 3 vertices, 6 arms total = 18 vertices
                int vertexOffset = particleIdx * 18;
                glDrawArrays(GL_TRIANGLES, vertexOffset, 18);
            }
        }
        
        // Unbind VAO and VBO
        glBindBuffer(GL_ARRAY_BUFFER, 0);
        glBindVertexArray(0);
        
        glDisable(GL_BLEND);
    }
    
    // Calculate morph timing for continuous back-and-forth motion
// Calculate morph timing for continuous progression

    float morphProgress = 0.0f;
    if (globalTime > morphStartTime) {
    // Continuous progress without modulo
        morphProgress = (globalTime - morphStartTime) / morphDuration;
    }
    glUniform1f(uniforms.morphStartTime, morphStartTime);
    glUniform1f(uniforms.morphProgress, morphProgress);
    glm::vec2 gridCenter(160.0f, 100.0f); // Assuming these are your grid dimensions/2
    float maxGridRadius = glm::length(gridCenter); // Calculate based on your grid size
glUniform1f(glGetUniformLocation(shaderProgram, "maxGridRadius"), maxGridRadius);
glUniform1f(glGetUniformLocation(shaderProgram, "pulseStrength"), 0.5f); 
glUniform1f(glGetUniformLocation(shaderProgram, "beatIntensity"), audio_bass_intensity);
    
    glutSwapBuffers();
}

// Update function
void update(int value) {
    globalTime += 0.016f;  // ~60 FPS

    static float lastTime = 0.0f;
    float currentTime = glutGet(GLUT_ELAPSED_TIME) / 1000.0f;
    float deltaTime = currentTime - lastTime;
    
    if (currentTime >= GRID_DURATION_SECONDS) {  // Only need this constant
        printf("Transitioning to Rubik's Cube demo\n");
        
        // Clean up pixel grid resources but keep the window and context
        glDeleteProgram(shaderProgram);
        glDeleteProgram(simpleShaderProgram);
        glDeleteVertexArrays(1, &VAO);
        glDeleteBuffers(1, &VBO);
        glDeleteVertexArrays(1, &VAO);
        glDeleteVertexArrays(1, &snowVAO);

        glDisable(GL_BLEND);
        glBindVertexArray(0);
        glBindBuffer(GL_ARRAY_BUFFER, 0);
        glBindBuffer(GL_ELEMENT_ARRAY_BUFFER, 0);

        globalTime = 0.0f;
        animationCompleteTime = 0.0f;
        
        // Initialize Rubik's cube components
        if (!initShadersRubik() || !initBuffersRubik() || !initMirrorFloor()) {
            printf("Failed to initialize Rubik's cube components\n");
            return;
        }
        
        // Set up Rubik's cube callbacks and state
        glutDisplayFunc(displayRubik);
        glutKeyboardFunc(keyboardRubik);
        glutSpecialFunc(specialKeyboard);
        glutTimerFunc(0, updateRubik, 0);
        
        // Update OpenGL state
        glClearColor(0.1f, 0.1f, 0.15f, 1.0f);
        glEnable(GL_DEPTH_TEST);
        glEnable(GL_CULL_FACE);
       //initPongGame();
        startAutomaticShuffle();
        return;  // Stop the pixel grid update cycle
    }
    
    lastTime = currentTime;
    // ... rest of existing update code ...


    // Update audio-responsive effects
    if (audio_player && audio_player->get_playback_state()) {
        std::lock_guard<std::mutex> lock(audio_features_mutex);
        
        // Modify effect parameters based on audio
        EFFECT_TIME_SCALE = 0.2f + audio_energy_level * 0.8f;  // Speed up effects with energy
        
        // Bass-driven snowflake effects
        if (audio_bass_intensity > 0.7f) {
            // Create bass-driven effects
            for (int i = 0; i < 3; ++i) {
                SnowflakeEffect effect;
                effect.position = glm::vec2(
                    (rand() % GRID_WIDTH) - GRID_WIDTH/2.0f,
                    (rand() % GRID_HEIGHT) - GRID_HEIGHT/2.0f
                );
                effect.radius = EFFECT_RADIUS * (0.5f + audio_bass_intensity);
                effect.effectType = 0;  // Plasma pattern for bass
                effect.blendFactor = audio_bass_intensity;
                effect.lastUpdate = globalTime;
                snowflakeEffects.push_back(effect);
            }
        }
        
        // Treble-driven particle effects
        if (audio_treble_intensity > 0.6f) {
            // Increase particle spawn rate and size based on treble
            for (auto& particle : snowParticles) {
                if (!particle.hasLanded) {
                    particle.size = std::min(SNOW_MAX_SIZE, 
                                           particle.size * (1.0f + audio_treble_intensity * 0.5f));
                    particle.speed = std::min(SNOW_MAX_SPEED, 
                                            particle.speed * (1.0f + audio_treble_intensity * 0.3f));
                }
            }
        }
        
        // Mid-frequency driven layer effects
        if (audio_mid_intensity > 0.5f) {
            for (int y = 0; y < GRID_HEIGHT; y += 4) {
                for (int x = 0; x < GRID_WIDTH; x += 4) {
                    int idx = y * GRID_WIDTH + x;
                    if (idx < layerPixels.size()) {
                        layerPixels[idx].height += audio_mid_intensity * 0.5f;
                        layerPixels[idx].height = std::min(layerPixels[idx].height, MAX_HEIGHT);
                    }
                }
            }
        }
    }
    
    glutPostRedisplay();
    glutTimerFunc(16, update, 0);
}

// Create the VAO/VBO used to stream snowflake geometry every frame.
// Must be called after initBuffers().
//
// Each snowflake is drawn as 6 arms x 3 vertices = 18 vertices, so the VBO is
// allocated (GL_DYNAMIC_DRAW, no initial data) for MAX_SNOW_PARTICLES * 18
// VertexData entries. display() fills it with glBufferSubData. The vertex
// layout is configured once here and captured by snowVAO:
//   location 0: 3 floats at offset 0 (vertex position)
//   location 1: VertexData::color      (3 floats)
//   location 2: VertexData::brightness (1 float)
//   location 3: VertexData::index      (1 float)
void initSnowBuffers() {
    // Vertices per snowflake: 6 arms x 3 vertices (must match display()).
    constexpr int kSnowVertsPerFlake = 18;

    glGenVertexArrays(1, &snowVAO);
    glBindVertexArray(snowVAO);

    glGenBuffers(1, &snowVBO);
    glBindBuffer(GL_ARRAY_BUFFER, snowVBO);

    // Reserve CPU-side staging capacity for a full buffer. This avoids
    // reallocations during the per-frame rebuild as long as at most
    // MAX_SNOW_PARTICLES flakes are drawn.
    snowflakeVertexData.reserve(MAX_SNOW_PARTICLES * kSnowVertsPerFlake);

    glBufferData(GL_ARRAY_BUFFER,
                 MAX_SNOW_PARTICLES * kSnowVertsPerFlake * sizeof(VertexData),
                 nullptr, GL_DYNAMIC_DRAW);

    // Position attribute
    glVertexAttribPointer(0, 3, GL_FLOAT, GL_FALSE, sizeof(VertexData), (void*)0);
    glEnableVertexAttribArray(0);
    // Color attribute
    glVertexAttribPointer(1, 3, GL_FLOAT, GL_FALSE, sizeof(VertexData), (void*)offsetof(VertexData, color));
    glEnableVertexAttribArray(1);
    // Brightness attribute
    glVertexAttribPointer(2, 1, GL_FLOAT, GL_FALSE, sizeof(VertexData), (void*)offsetof(VertexData, brightness));
    glEnableVertexAttribArray(2);
    // Index attribute
    glVertexAttribPointer(3, 1, GL_FLOAT, GL_FALSE, sizeof(VertexData), (void*)offsetof(VertexData, index));
    glEnableVertexAttribArray(3);

    glBindBuffer(GL_ARRAY_BUFFER, 0);
    glBindVertexArray(0);
}

// Initialize pixels with plasma pattern.
// Resizes `pixels` to GRID_WIDTH * GRID_HEIGHT entries (row-major) and fills
// each with a static plasma value in [0, 1]. The value is stored in the x
// component of `color` (y and z are 0) for later use as a palette index. Each
// pixel also gets a random brightness in [0.7, 1.0) drawn from globalRng, in
// row-major order.
void initPixels() {
    std::uniform_real_distribution<float> brightnessDist(0.7f, 1.0f);

    const int cellCount = GRID_WIDTH * GRID_HEIGHT;
    pixels.resize(cellCount);

    for (int idx = 0; idx < cellCount; ++idx) {
        const int col = idx % GRID_WIDTH;
        const int row = idx / GRID_WIDTH;

        // Normalised cell coordinates in [0, 1).
        const float u = static_cast<float>(col) / GRID_WIDTH;
        const float v = static_cast<float>(row) / GRID_HEIGHT;

        // Three interfering sine fields: diagonal, anti-diagonal and radial.
        const float diagonal = sin(u * 10.0f + v * 10.0f);
        const float antiDiagonal = sin(u * 20.0f - v * 20.0f);
        const float radial = sin(sqrt(u * u + v * v) * 30.0f);

        // Average in [-1, 1], then remap to [0, 1].
        float plasma = (diagonal + antiDiagonal + radial) / 3.0f;
        plasma = (plasma + 1.0f) / 2.0f;

        auto& cell = pixels[idx];
        cell.color = glm::vec3(plasma, 0.0f, 0.0f);
        cell.brightness = brightnessDist(globalRng);
    }
}

// GLUT keyboard callback for the pixel-grid demo: audio transport controls and
// diagnostics. Letter keys are accepted in either case, which matches the
// upper-case keys listed in the help text (and works with Caps Lock on).
//
// key  - ASCII code of the pressed key.
// x, y - unused.
void keyboard(unsigned char key, int x, int y) {
    (void)x;
    (void)y;
    switch (key) {
        case 27:  // ESC key
            printf("ESC pressed - exiting program\n");
            // NOTE(review): main() also registers cleanupAudio with atexit(),
            // so it runs a second time during exit(0); confirm that is safe.
            cleanupAudio();
            exit(0);
        case ' ':  // Spacebar - play/pause
            if (audio_player) {
                if (audio_player->get_playback_state()) {
                    audio_player->pause();
                    printf("Audio paused\n");
                } else {
                    audio_player->play();
                    printf("Audio playing\n");
                }
            }
            break;
        case 's':
        case 'S':  // Stop
            if (audio_player) {
                audio_player->stop();
                printf("Audio stopped\n");
            }
            break;
        case 'r':
        case 'R':  // Restart from the beginning
            if (audio_player) {
                audio_player->seek(0.0);
                audio_player->play();
                printf("Audio restarted\n");
            }
            break;
        case 'i':
        case 'I':  // Info
            if (audio_player) {
                printf("Audio Info:\n");
                printf("  File: %s\n", current_mp3_file.c_str());
                printf("  Position: %.2f / %.2f seconds\n",
                       audio_player->get_position(), audio_player->get_duration());
                printf("  Tempo: %.1f BPM\n", audio_tempo);
                printf("  Energy: %.2f\n", audio_energy_level);
                printf("  Bass: %.2f, Mid: %.2f, Treble: %.2f\n",
                       audio_bass_intensity, audio_mid_intensity, audio_treble_intensity);
            }
            break;
        case 'h':
        case 'H':  // Help
            printf("Audio Visualizer Controls:\n");
            printf("  SPACE - Play/Pause\n");
            printf("  S     - Stop\n");
            printf("  R     - Restart\n");
            printf("  I     - Show audio info\n");
            printf("  H     - Show this help\n");
            printf("  ESC   - Exit\n");
            break;
        default:
            break;
    }
}

// Main function: program entry point.
//
// Usage: <program> [file.mp3]
//   file.mp3 - optional audio file. Defaults to "atklojo25-deph.mp3" in the
//              current directory. Audio failures are non-fatal.
//
// The function does the following:
//   - creates a fullscreen GLUT window and initialises GLEW;
//   - sets up the pixel-grid shaders, the audio and all grid state;
//   - runs the GLUT main loop.
// update() switches this same window over to the Rubik's cube demo after
// GRID_DURATION_SECONDS.
//
// Returns 1 if GLEW or shader initialisation fails, and 0 when the main loop
// returns.
int main(int argc, char** argv) {
    printf("🎵 Audio-Reactive Pixel Grid Visualizer\n");
    printf("=======================================\n");

    // Seed rand() before anything can call it: update() uses rand() to place
    // bass-driven snowflake effects while the grid demo runs.
    srand(static_cast<unsigned>(time(nullptr)));

    // Optional first argument: the MP3 file to play.
    std::string mp3_file;
    if (argc > 1) {
        mp3_file = std::string(argv[1]);
        printf("MP3 file specified: %s\n", mp3_file.c_str());
    } else {
        // Look for default MP3 file in current directory
        mp3_file = "atklojo25-deph.mp3";
        printf("No MP3 file specified, trying default: %s\n", mp3_file.c_str());
    }

    glutInit(&argc, argv);
    // Have glutMainLoop() return to main() when the window is closed.
    glutSetOption(GLUT_ACTION_ON_WINDOW_CLOSE, GLUT_ACTION_GLUTMAINLOOP_RETURNS);
    glutInitDisplayMode(GLUT_DOUBLE | GLUT_RGB | GLUT_DEPTH);

    // The initial window keeps the grid's aspect ratio; it is then made fullscreen.
    windowWidth = 800;
    windowHeight = windowWidth * GRID_HEIGHT / GRID_WIDTH;
    glutInitWindowSize(windowWidth, windowHeight);
    glutCreateWindow("Audio-Reactive Pixel Grid");
    glutFullScreen();

    // GLEW needs the GL context made current by glutCreateWindow().
    GLenum err = glewInit();
    if (err != GLEW_OK) {
        printf("GLEW init failed: %s\n", reinterpret_cast<const char*>(glewGetErrorString(err)));
        return 1;
    }

    if (!initShaders()) {
        printf("Shader initialization failed\n");
        return 1;
    }

    printf("Initializing audio system...\n");
    if (!initAudio(mp3_file)) {
        printf("Audio initialization failed, continuing without audio\n");
    } else if (!current_mp3_file.empty() && audio_player) {
        // Start playback automatically if a file was loaded.
        printf("Starting audio playback...\n");
        audio_player->play();
    }

    initPixels();
    initGameOfLife();
    initLayers();
    initSnowParticles();
    initBuffers();
    initSnowBuffers();
    initGolBlockIndices();
    updateProjection();

    glClearColor(0.0f, 0.0f, 0.0f, 1.0f);
    glEnable(GL_DEPTH_TEST);

    printf("\nControls:\n");
    printf("  SPACE - Play/Pause audio\n");
    printf("  S     - Stop audio\n");
    printf("  R     - Restart audio\n");
    printf("  I     - Show audio info\n");
    printf("  H     - Show help\n");
    printf("  ESC   - Exit\n\n");

    glutDisplayFunc(display);
    glutReshapeFunc(reshape);
    glutKeyboardFunc(keyboard);
    glutTimerFunc(0, update, 0);

    // Release audio resources on any normal exit path.
    atexit(cleanupAudio);

    glutMainLoop();

    // glutMainLoop() only returns once the main loop has been left (e.g. the
    // window was closed, per GLUT_ACTION_GLUTMAINLOOP_RETURNS above). By then
    // there is no window and no GL context to render into. The Rubik's cube
    // demo is started from update() inside the same window, so there is
    // nothing left to run: re-initialising GLUT and the cube here without
    // creating a new window could only fail.
    printf("Back in main\n");
    return 0;
}
