#include "dmeval.h"
#include <math.h>

// Capacity limit for the tokenizer's temporary text buffer (tmpStr in
// dmEvalTokenizeExpr); identifiers reaching this length are reported
// as too long.
#define DM_MAX_BUF     512


/* Operators
 *
 * One entry per operator/node kind: a textual name, an operator type
 * constant (OT_UNARY, OT_LEFT or OT_NONE) and a boolean flag.
 *
 * NOTE(review): the table is sized by OP_NOPERS and is presumably indexed
 * by the OP_* opcode values declared in dmeval.h, so the entry order has
 * to follow that enum -- confirm before reordering. The meaning of the
 * boolean flag is defined by DMEvalOper and is not visible in this file.
 */
const DMEvalOper dmEvalOpers[OP_NOPERS] =
{
    // Unary operators
    { "-"      , OT_UNARY , FALSE },
    { "~"      , OT_UNARY , TRUE },
    
    // Arithmetic
    { "+"      , OT_LEFT  , TRUE },
    { "-"      , OT_LEFT  , TRUE },
    { "*"      , OT_LEFT  , TRUE },
    { "/"      , OT_LEFT  , TRUE },
    { "%"      , OT_LEFT  , TRUE },


    // Bit shifts
    { "<<"     , OT_LEFT  , TRUE },
    { ">>"     , OT_LEFT  , TRUE },

    // Bitwise logic
    { "&"      , OT_LEFT  , TRUE },
    { "|"      , OT_LEFT  , TRUE },
    { "^"      , OT_LEFT  , TRUE },

    // Comparisons
    { ">="     , OT_LEFT  , TRUE },
    { "<="     , OT_LEFT  , TRUE },
    { ">"      , OT_LEFT  , TRUE },
    { "<"      , OT_LEFT  , TRUE },

    // Non-operator node kinds
    { "FUNC"   , OT_NONE  , FALSE },
    { "VAR"    , OT_NONE  , FALSE },
    { "CONST"  , OT_NONE  , FALSE },

    { "SUBEXPR", OT_NONE  , FALSE },
};


/* Function definitions
 */
// Built-in "clip": limit the first argument to the range [-1, 1].
// Values inside the range are returned unchanged.
static DMValue func_int_clip(DMValue *v)
{
    if (*v < -1.0f)
        return -1.0f;

    if (*v > 1.0f)
        return 1.0f;

    return *v;
}


// Built-in "sin": sine of the first (and only) argument via sin().
static DMValue func_sin(DMValue *v)
{
    const DMValue arg = v[0];

    return sin(arg);
}


// Built-in "cos": cosine of the first (and only) argument via cos().
static DMValue func_cos(DMValue *v)
{
    const DMValue arg = v[0];

    return cos(arg);
}


// Built-in "pow": raises the first argument to the power given
// by the second argument via pow().
static DMValue func_pow(DMValue *v)
{
    const DMValue base = v[0];
    const DMValue exponent = v[1];

    return pow(base, exponent);
}


/* Some basic functions
 *
 * Built-in symbols that dmEvalContextNew() copies into every new context.
 * Function entries carry an argument count and a function pointer,
 * constant entries carry a value in the last field.
 *
 * NOTE(review): the initializers are positional and appear to map to
 * name, type, nargs, func, var, cvalue -- confirm against the
 * DMEvalSymbol declaration in dmeval.h.
 */
static const DMEvalSymbol dmEvalBasicFuncs[] =
{
    { "sin",  SYM_FUNC , 1, func_sin, NULL, 0 },
    { "cos",  SYM_FUNC , 1, func_cos, NULL, 0 },
    { "clip", SYM_FUNC , 1, func_int_clip, NULL, 0 },
    { "pow",  SYM_FUNC , 2, func_pow, NULL, 0 },

    { "pi",   SYM_CONST, 0, NULL, NULL, DM_PI },
    { "e",    SYM_CONST, 0, NULL, NULL, DM_E },
};

// Number of entries in dmEvalBasicFuncs
static const int ndmEvalBasicFuncs = sizeof(dmEvalBasicFuncs) / sizeof(dmEvalBasicFuncs[0]);


/* Record an error in the evaluation context (va_list variant).
 *
 * Sets ev->err and appends the formatted message to ev->errStr, so that
 * several errors accumulate into one string. The previously stored
 * string is released once the combined string has been built.
 *
 * ev  - evaluation context receiving the error.
 * fmt - printf-style format string.
 * ap  - arguments for fmt.
 */
void dmEvalErrorV(DMEvalContext *ev, const char *fmt, va_list ap)
{
    char *tmp = dm_strdup_vprintf(fmt, ap);
    char *joined;

    ev->err = TRUE;

    // Nothing to append if formatting failed; keep what we already have.
    // (Assumes dm_strdup_vprintf() reports failure by returning NULL.)
    if (tmp == NULL)
        return;

    if (ev->errStr == NULL)
    {
        // First message: take ownership of the formatted string
        ev->errStr = tmp;
        return;
    }

    joined = dm_strdup_printf("%s%s", ev->errStr, tmp);
    dmFree(tmp);

    // Only replace the old text when the concatenation succeeded,
    // and free the old buffer so it does not leak.
    if (joined != NULL)
    {
        dmFree(ev->errStr);
        ev->errStr = joined;
    }
}


/* Record an error in the evaluation context, printf-style.
 * Thin variadic front-end that forwards to dmEvalErrorV().
 */
void dmEvalError(DMEvalContext *ev, const char *fmt, ...)
{
    va_list args;

    va_start(args, fmt);
    dmEvalErrorV(ev, fmt, args);
    va_end(args);
}


/* Look up a symbol by exact (case-sensitive) name.
 * Returns a pointer into the context's symbol array, or NULL if the
 * context has no symbols or none of them matches.
 */
DMEvalSymbol *dmEvalContextFindSymbol(DMEvalContext *ev, const char *name)
{
    int idx = 0;

    if (ev->symbols == NULL)
        return NULL;

    while (idx < ev->nsymbols)
    {
        DMEvalSymbol *candidate = &(ev->symbols[idx]);

        if (strcmp(candidate->name, name) == 0)
            return candidate;

        idx++;
    }

    return NULL;
}


// Add a new symbol to the evaluation context.
// Return pointer to newly allocated symbol struct if successful.
// If the symbol already exists or there was a memory allocation
// error, NULL is returned. On allocation errors a message is also
// recorded in the context and the existing symbols are left intact.
//
// The symbol array is grown by one element per call, so the array may
// move in memory: pointers to symbols obtained before this call must
// be treated as potentially invalid afterwards.
static DMEvalSymbol * dmEvalContextAddSymbol(DMEvalContext *ev, const char *name, const int type)
{
    DMEvalSymbol *grown, *symbol;
    char *nameCopy;

    if (dmEvalContextFindSymbol(ev, name) != NULL)
        return NULL;

    // Grow via a temporary so that a failed reallocation does not lose
    // the existing array (assumes dmRealloc() has realloc() semantics,
    // i.e. the original block stays valid on failure).
    grown = dmRealloc(ev->symbols, sizeof(DMEvalSymbol) * (ev->nsymbols + 1));
    if (grown == NULL)
    {
        dmEvalError(ev,
            "Could not reallocate eval symbols array (#%d). Fatal error.\n",
            ev->nsymbols + 1);
        return NULL;
    }
    ev->symbols = grown;

    // Duplicate the name before publishing the slot, so a symbol with
    // a NULL name can never become visible to dmEvalContextFindSymbol().
    nameCopy = dm_strdup(name);
    if (nameCopy == NULL)
    {
        dmEvalError(ev,
            "Could not allocate memory for eval symbol name '%s'.\n",
            name);
        return NULL;
    }

    symbol = &(ev->symbols[ev->nsymbols]);
    memset(symbol, 0, sizeof(DMEvalSymbol));
    symbol->name = nameCopy;
    symbol->type = type;

    ev->nsymbols++;

    return symbol;
}


/* Register a variable symbol bound to caller-owned storage.
 * The symbol keeps the pointer 'var' and reads through it at evaluation
 * time. Returns the new symbol, or NULL if it could not be added.
 */
DMEvalSymbol *dmEvalContextAddVar(DMEvalContext *ev, const char *name, DMValue *var)
{
    DMEvalSymbol *entry = dmEvalContextAddSymbol(ev, name, SYM_VAR);

    if (entry != NULL)
        entry->var = var;

    return entry;
}


/* Register a named constant with the given value.
 * Returns the new symbol, or NULL if it could not be added.
 */
DMEvalSymbol *dmEvalContextAddConst(DMEvalContext *ev, const char *name, DMValue value)
{
    DMEvalSymbol *entry = dmEvalContextAddSymbol(ev, name, SYM_CONST);

    if (entry != NULL)
        entry->cvalue = value;

    return entry;
}


/* Register a function symbol.
 *
 * ev    - evaluation context.
 * name  - name under which the function is callable in expressions.
 * func  - implementation; receives a pointer to its argument values.
 * nargs - number of arguments the function takes.
 *
 * Returns the new symbol, or NULL if a symbol with that name already
 * exists or memory could not be allocated.
 */
DMEvalSymbol *dmEvalContextAddFunc(DMEvalContext *ev, const char *name, DMValue (*func)(DMValue *), int nargs)
{
    // Must be SYM_FUNC: the tokenizer only creates OP_FUNC nodes (and
    // parses an argument list) for symbols of that type.
    DMEvalSymbol *symbol = dmEvalContextAddSymbol(ev, name, SYM_FUNC);
    if (symbol == NULL)
        return NULL;
    
    symbol->func  = func;
    symbol->nargs = nargs;

    // Return the symbol itself, like the other Add* functions do
    return symbol;
}


/* Allocate a new, zero-initialized evaluation context and populate it
 * with the built-in symbols from dmEvalBasicFuncs.
 * Returns NULL only if the context itself could not be allocated;
 * built-ins that fail to register are skipped.
 */
DMEvalContext *dmEvalContextNew(void)
{
    DMEvalContext *ctx = dmCalloc(1, sizeof(DMEvalContext));
    int idx;

    if (ctx == NULL)
        return NULL;

    for (idx = 0; idx < ndmEvalBasicFuncs; idx++)
    {
        const DMEvalSymbol *builtin = dmEvalBasicFuncs + idx;
        DMEvalSymbol *added = dmEvalContextAddSymbol(ctx, builtin->name, builtin->type);

        if (added == NULL)
            continue;

        // Copy the payload fields of the built-in definition
        added->nargs  = builtin->nargs;
        added->func   = builtin->func;
        added->var    = builtin->var;
        added->cvalue = builtin->cvalue;
    }

    return ctx;
}


/* Free a node list together with everything hanging off it.
 * Walks the 'next' chain iteratively; argument lists and
 * sub-expressions of each node are released recursively.
 * Passing NULL is a no-op.
 */
void dmEvalTreeFree(DMEvalNode *node)
{
    DMEvalNode *cur, *following;

    for (cur = node; cur != NULL; cur = following)
    {
        int slot;

        // Remember the successor before the node itself is released
        following = cur->next;

        for (slot = 0; slot < DM_MAX_ARGS; slot++)
        {
            dmEvalTreeFree(cur->args[slot]);
            cur->args[slot] = NULL;
        }

        dmEvalTreeFree(cur->subexpr);
        cur->subexpr = NULL;

        dmFree(cur);
    }
}


/* Reset the error state of a context: release the accumulated
 * error string and clear the error flag. NULL context is ignored.
 */
void dmEvalContextClear(DMEvalContext *ev)
{
    if (ev != NULL)
    {
        dmFree(ev->errStr);
        ev->errStr = NULL;
        ev->err = FALSE;
    }
}


/* Destroy an evaluation context: frees every symbol name, the symbol
 * array, the error string and finally the context itself.
 * NULL context is ignored.
 */
void dmEvalContextClose(DMEvalContext *ev)
{
    int idx = 0;

    if (ev == NULL)
        return;

    while (idx < ev->nsymbols)
    {
        dmFree(ev->symbols[idx].name);
        idx++;
    }

    dmFree(ev->symbols);
    dmEvalContextClear(ev);
    dmFree(ev);
}


/* Append a node to the end of a node list and return it.
 * The list head's 'prev' pointer always refers to the last node, which
 * makes appending O(1); the last node's 'next' is NULL.
 */
static DMEvalNode *dmEvalInsertNode(DMEvalNode **list, DMEvalNode *node)
{
    DMEvalNode *head = *list;

    if (head == NULL)
    {
        // Empty list: the node becomes the head and its own tail
        *list = node;
        node->prev = node;
    }
    else
    {
        DMEvalNode *tail = head->prev;

        node->prev = tail;
        tail->next = node;
        head->prev = node;
    }

    node->next = NULL;
    return node;
}


/* Allocate a zeroed node with the given opcode and append it to the
 * list. Returns the new node, or NULL if allocation failed (the list
 * is left untouched in that case).
 */
static DMEvalNode *dmEvalAddNode(DMEvalNode **list, const int op)
{
    DMEvalNode *fresh = dmCalloc(1, sizeof(DMEvalNode));

    if (fresh != NULL)
    {
        fresh->op = op;
        dmEvalInsertNode(list, fresh);
    }

    return fresh;
}


// Tokenizer states and token classes.
// The values are distinct bits so that several token classes can be
// OR'ed together into the "expected next token" mask (ev->expect),
// which dmEvalSetMode() tests against the mode being entered.
enum
{
    PARSE_NONE          = 0x0000,
    PARSE_START         = 0x1000,
    PARSE_END           = 0x2000,
    PARSE_ERROR         = 0x8000,

    PARSE_IDENT         = 0x0001, // Any identifier (variable, function name)
    PARSE_CONST         = 0x0002, // Constant value (n, n.nnn, etc)
    PARSE_OPER          = 0x0004, // All operators
    PARSE_OPER_UNARY    = 0x0008, // Unary operators ~, -
    PARSE_SUBEXPR_START = 0x0010, // ( ...
    PARSE_SUBEXPR_END   = 0x0020, // )
    PARSE_ARGS          = 0x0040, // function args: (xxx[, yyy ...])

    // What may start an operand: a constant, identifier, '(' or unary operator
    PARSE_NORMAL        = PARSE_CONST | PARSE_IDENT | PARSE_SUBEXPR_START | PARSE_OPER_UNARY,
};

// Helper for dmEvalGetMode(): if flag PARSE_<x> is set in the local
// variable 'mode', append the flag's name to the local buffer 'str',
// separating multiple names with " or ".
#define DM_CHECK(x) { if (mode & PARSE_ ## x ) { if (str[0]) strcat(str, " or "); strcat(str, # x ); } }

static char *dmEvalGetMode(int mode)
{
    char str[128] = "";

    DM_CHECK(START);
    DM_CHECK(END);
    DM_CHECK(IDENT);
    DM_CHECK(CONST);
    DM_CHECK(OPER);
    DM_CHECK(OPER_UNARY);
    DM_CHECK(SUBEXPR_START);
    DM_CHECK(SUBEXPR_END);
    DM_CHECK(ARGS);
            
    return dm_strdup(str);
}


/* Switch the tokenizer to a new mode, remembering the previous one.
 *
 * Unless the new mode is PARSE_ERROR or PARSE_START, it is checked
 * against the current expectation mask (ev->expect); a mismatch is
 * reported through dmEvalError(). The mode is switched regardless of
 * whether the check passed.
 */
static void dmEvalSetMode(DMEvalContext *ev, const int mode)
{
    const BOOL unchecked = (mode == PARSE_ERROR || mode == PARSE_START);

    if (!unchecked && ev->expect != PARSE_NONE && !(mode & ev->expect))
    {
        char *expected = dmEvalGetMode(ev->expect);
        char *got = dmEvalGetMode(mode);

        dmEvalError(ev, "Expected [%s], got %s.\n", expected, got);

        dmFree(expected);
        dmFree(got);
    }

    ev->prev = ev->mode;
    ev->mode = mode;
}


/* Tokenize one (sub)expression into a linked list of DMEvalNodes.
 *
 * ev    - evaluation context; holds the tokenizer state (mode, prev,
 *         expect) and receives any error messages.
 * list  - head of the list to which the produced nodes are appended.
 * str   - in: start of the text to scan; out: position where scanning
 *         stopped. For nested calls this is just past the ')' or ','
 *         that terminated the sub-expression.
 * depth - nesting depth, 0 for the top-level expression.
 *
 * Returns TRUE if the tokenizer stopped in PARSE_ERROR, FALSE otherwise.
 * Expectation mismatches reported by dmEvalSetMode() only set ev->err;
 * they do not by themselves make this function return TRUE.
 *
 * The function is a small state machine: PARSE_START classifies the
 * next character and selects the state that consumes the token, and
 * each such state returns to PARSE_START when its token is complete.
 */
static BOOL dmEvalTokenizeExpr(DMEvalContext *ev, DMEvalNode **list, char **str, int depth)
{
    char *c = *str;
    char tmpStr[DM_MAX_BUF + 2], *tmp;
    int tmpStrLen = 0, argIndex, op;
    DMEvalNode *node = NULL, *func = NULL;
    BOOL first = FALSE, decimal = FALSE;

    ev->expect = PARSE_NORMAL;
    ev->mode = PARSE_START;
    
    while (ev->mode != PARSE_ERROR && ev->mode != PARSE_END)
    switch (ev->mode)
    {
        case PARSE_START:
            // End of input
            if (*c == 0)
                dmEvalSetMode(ev, PARSE_END);

            // Skip whitespace. The <ctype.h> functions require an
            // unsigned char value, hence the casts below.
            else if (isspace((unsigned char) *c))
                c++;

            // ')' and ',' terminate a nested expression / argument
            else if (*c == ')' || *c == ',')
            {
                if (depth > 0)
                    dmEvalSetMode(ev, PARSE_END);
                else
                {
                    dmEvalError(ev, "Invalid nesting near '%s' (depth %d).\n", c, depth);
                    dmEvalSetMode(ev, PARSE_ERROR);
                }
                c++;
            }
            
            // '(' opens an argument list if a function name preceded it
            else if (*c == '(')
                dmEvalSetMode(ev, func != NULL ? PARSE_ARGS : PARSE_SUBEXPR_START);
            
            // '-' is unary at the start of an expression or after an operator
            else if (*c == '-')
                dmEvalSetMode(ev, (ev->prev == PARSE_START || ev->prev == PARSE_OPER || ev->prev == PARSE_OPER_UNARY) ? PARSE_OPER_UNARY : PARSE_OPER);
            
            else if (*c == '~')
                dmEvalSetMode(ev, PARSE_OPER_UNARY);
            
            else if (strchr("+*/<>%&|!^", *c))
                dmEvalSetMode(ev, PARSE_OPER);

            else if (isdigit((unsigned char) *c) || *c == '.')
                dmEvalSetMode(ev, PARSE_CONST);

            else if (isalpha((unsigned char) *c) || *c == '_')
                dmEvalSetMode(ev, PARSE_IDENT);
            
            else
            {
                dmEvalError(ev, "Syntax error near '%s' (depth %d).\n", c, depth);
                dmEvalSetMode(ev, PARSE_ERROR);
            }

            // Tell the token states that they are at their first character
            first = TRUE;
            break;
        
        case PARSE_SUBEXPR_START:
            tmp = c + 1;

            ev->expect = PARSE_NORMAL;

            if ((node = dmEvalAddNode(list, OP_SUBEXPR)) == NULL)
                dmEvalSetMode(ev, PARSE_ERROR);
            else
            if (dmEvalTokenizeExpr(ev, &(node->subexpr), &tmp, depth + 1) != 0)
            {
                dmEvalError(ev, "Subexpression starting at '%s' contained errors.\n", c);
                dmEvalSetMode(ev, PARSE_ERROR);
            }

            if (ev->mode != PARSE_ERROR)
            {
                dmEvalSetMode(ev, PARSE_START);
                // The closing ')' was consumed by the nested call, so what
                // may follow is an operator or the end of this expression
                // (end of input, ')' and ',' are all entered as PARSE_END).
                ev->expect = PARSE_OPER | PARSE_END;
                c = tmp;
            }
            break;

        case PARSE_ARGS:
            tmp = c + 1;

            if (func->symbol->nargs > DM_MAX_ARGS)
            {
                // Would write past the end of func->args[]
                dmEvalError(ev, "Function '%s' declares too many arguments (%d, maximum is %d).\n",
                    func->symbol->name, func->symbol->nargs, (int) DM_MAX_ARGS);
                dmEvalSetMode(ev, PARSE_ERROR);
            }
            else
            if (func->symbol->nargs <= 0)
            {
                // No arguments to parse: just consume the closing ')'
                while (isspace((unsigned char) *tmp))
                    tmp++;

                if (*tmp == ')')
                    tmp++;
                else
                {
                    dmEvalError(ev, "Expected ')' for function '%s' near '%s'.\n",
                        func->symbol->name, c);
                    dmEvalSetMode(ev, PARSE_ERROR);
                }
            }
            else
            for (argIndex = 0; argIndex < func->symbol->nargs; argIndex++)
            {
                if (dmEvalTokenizeExpr(ev, &(func->args[argIndex]), &tmp, depth + 1) != 0)
                {
                    dmEvalError(ev, "Function argument subexpression starting at '%s' contained errors.\n", c);
                    dmEvalSetMode(ev, PARSE_ERROR);
                    // Stop here: parsing the next argument would reset
                    // ev->mode and hide this failure.
                    break;
                }
            }

            func = NULL;

            if (ev->mode != PARSE_ERROR)
            {
                dmEvalSetMode(ev, PARSE_START);
                ev->expect = PARSE_OPER | PARSE_END;
                c = tmp;
            }
            break;

        case PARSE_CONST:
            if (first)
            {
                first = FALSE;
                decimal = FALSE;
                tmpStrLen = 0;
                
                if (isdigit((unsigned char) *c) || *c == '-' || *c == '+' || *c == '.')
                {
                    if (*c == '.')
                        decimal = TRUE;
                    tmpStr[tmpStrLen++] = *c++;
                }
                else
                {
                    dmEvalError(ev, "Invalid constant expression near '%s'.\n", c);
                    dmEvalSetMode(ev, PARSE_ERROR);
                }
            }
            else
            if ((isdigit((unsigned char) *c) || *c == '.') && tmpStrLen >= DM_MAX_BUF)
            {
                // Never write past the end of tmpStr
                tmpStr[tmpStrLen] = 0;
                dmEvalError(ev, "Constant too long! ('%s') near %s\n", tmpStr, c);
                dmEvalSetMode(ev, PARSE_ERROR);
            }
            else
            if (isdigit((unsigned char) *c))
            {
                tmpStr[tmpStrLen++] = *c++;
            }
            else
            if (*c == '.')
            {
                // Only one decimal point is allowed
                if (!decimal)
                {
                    tmpStr[tmpStrLen++] = *c++;
                    decimal = TRUE;
                }
                else
                {
                    dmEvalError(ev, "Invalid constant expression near '%s'.\n", c);
                    dmEvalSetMode(ev, PARSE_ERROR);
                }
            }
            else
            {
                // First character that is not part of the number:
                // convert what has been collected into a value node.
                tmpStr[tmpStrLen] = 0;

                if ((node = dmEvalAddNode(list, OP_VALUE)) == NULL)
                {
                    dmEvalSetMode(ev, PARSE_ERROR);
                }
                else
                {
                    node->val = atof(tmpStr);
                    dmEvalSetMode(ev, PARSE_START);
                    ev->expect = PARSE_OPER | PARSE_END;
                }
            }
            break;

        case PARSE_OPER_UNARY:
            op = OP_INVALID;

            switch (*c)
            {
                case '-': op = OP_SUB_UNARY; c++; break;
                case '~': op = OP_BIT_COMPLEMENT; c++; break;
            }

            if (op == OP_INVALID)
            {
                // Not reachable via PARSE_START, but without this the
                // loop would spin forever on an unexpected character.
                dmEvalError(ev, "Unknown unary operator '%c' at %s\n", *c, c);
                dmEvalSetMode(ev, PARSE_ERROR);
            }
            else
            if ((node = dmEvalAddNode(list, op)) != NULL)
            {
                ev->expect = PARSE_NORMAL;
                dmEvalSetMode(ev, PARSE_START);
            }
            else
                dmEvalSetMode(ev, PARSE_ERROR);
            break;

        case PARSE_OPER:
            op = OP_INVALID;

            switch (*c)
            {
                case '+': op = OP_ADD; c++; break;
                case '-': op = OP_SUB; c++; break;
                case '*': op = OP_MUL; c++; break;
                case '/': op = OP_DIV; c++; break;
                case '%': op = OP_MOD; c++; break;
                case '&': op = OP_BIT_AND; c++; break;
                case '^': op = OP_BIT_XOR; c++; break;
                case '|': op = OP_BIT_OR; c++; break;

                case '>':
                    // Two-character operators must consume both characters
                    if (c[1] == '>')
                    {
                        op = OP_BIT_RSHIFT;
                        c += 2;
                    }
                    else
                    if (c[1] == '=')
                    {
                        op = OP_GT_EQ;
                        c += 2;
                    }
                    else
                    {
                        op = OP_GT;
                        c++;
                    }
                    break;
                    
                case '<':
                    if (c[1] == '<')
                    {
                        op = OP_BIT_LSHIFT;
                        c += 2;
                    }
                    else
                    if (c[1] == '=')
                    {
                        op = OP_LT_EQ;
                        c += 2;
                    }
                    else
                    {
                        op = OP_LT;
                        c++;
                    }
                    break;

                default:
                    dmEvalError(ev, "Unknown operator '%c' at %s\n", *c, c);
                    dmEvalSetMode(ev, PARSE_ERROR);
            }

            if (op != OP_INVALID)
            {
                if ((node = dmEvalAddNode(list, op)) != NULL)
                {
                    ev->expect = PARSE_NORMAL | PARSE_OPER_UNARY;
                    dmEvalSetMode(ev, PARSE_START);
                }
                else
                    dmEvalSetMode(ev, PARSE_ERROR);
            }
            break;

        case PARSE_IDENT:
            if (isalnum((unsigned char) *c) || *c == '_')
            {
                if (first)
                {
                    tmpStrLen = 0;
                    first = FALSE;
                }

                if (tmpStrLen < DM_MAX_BUF)
                    tmpStr[tmpStrLen++] = *c++;
                else
                {
                    tmpStr[tmpStrLen] = 0;
                    dmEvalError(ev, "Identifier too long! ('%s') near %s\n", tmpStr, c);
                    // Leave the state; otherwise this branch repeats
                    // forever on the same character.
                    dmEvalSetMode(ev, PARSE_ERROR);
                }
            }
            else
            {
                // End of identifier: resolve it against the symbol table
                tmpStr[tmpStrLen] = 0;
                DMEvalSymbol *symbol = dmEvalContextFindSymbol(ev, tmpStr);
                if (symbol != NULL)
                {
                    if ((node = dmEvalAddNode(list, symbol->type == SYM_FUNC ? OP_FUNC : OP_VAR)) != NULL)
                    {
                        node->symbol = symbol;
                        if (symbol->type == SYM_FUNC)
                        {
                            // Remember the node so that the following
                            // '(' is parsed as its argument list
                            func = node;
                            ev->expect = PARSE_ARGS;
                        }
                        else
                            ev->expect = PARSE_END | PARSE_OPER;

                        dmEvalSetMode(ev, PARSE_START);
                    }
                    else
                        dmEvalSetMode(ev, PARSE_ERROR);
                }
                else
                {
                    dmEvalError(ev, "No such identifier '%s'.\n", tmpStr);
                    dmEvalSetMode(ev, PARSE_ERROR);
                }
            }
            break;

        default:
            // A state without a handler would loop forever
            dmEvalError(ev, "Internal error: unhandled parser mode 0x%x.\n", (unsigned int) ev->mode);
            dmEvalSetMode(ev, PARSE_ERROR);
            break;
    }

    *str = c;

    return (ev->mode == PARSE_ERROR);
}


/* Parse an expression string into a node list.
 *
 * ev     - evaluation context (symbols, parser state, error messages).
 * expr   - NUL-terminated expression text.
 * result - list head that receives the tokenized nodes.
 *
 * Returns DMERR_NULLPTR if any argument is NULL. Otherwise returns the
 * tokenizer's result: zero on success, non-zero if tokenizing stopped
 * with an error (details are then available in ev->errStr).
 */
int dmEvalParseExpr(DMEvalContext *ev, char *expr, DMEvalNode **result)
{
    int ret;

    // expr is dereferenced by the tokenizer, so it must be checked too
    if (ev == NULL || expr == NULL || result == NULL)
        return DMERR_NULLPTR;

    // Start-of-expression state, so that a leading '-' is taken as unary
    ev->prev = PARSE_START;
    ret = dmEvalTokenizeExpr(ev, result, &expr, 0);
    
    return ret;
}




/* Evaluate an expression node and store its value in *presult.
 *
 * ev      - evaluation context, used for error reporting.
 * node    - node to evaluate; NULL yields FALSE.
 * presult - receives the value when TRUE is returned.
 *
 * Returns TRUE on success. Returns FALSE (usually with a message
 * recorded in the context) on division by zero, invalid shift counts,
 * invalid function nodes, unknown opcodes, or when an operand fails to
 * evaluate. *presult is only meaningful when TRUE is returned.
 *
 * NOTE(review): comparison opcodes (OP_GT, OP_LT, ...) are produced by
 * the tokenizer but are not evaluated here; they end up in the
 * "Invalid opcode" branch.
 */
BOOL dmEvalTreeExecute(DMEvalContext *ev, DMEvalNode *node, DMValue *presult)
{
    DMValue val1, val2;

    if (node == NULL)
        return FALSE;

    switch (node->op)
    {
        case OP_VAR:
            switch (node->symbol->type)
            {
                case SYM_CONST: *presult = node->symbol->cvalue; return TRUE;
                case SYM_VAR  : *presult = *(node->symbol->var); return TRUE;
            }
            return FALSE;

        case OP_VALUE:
            *presult = node->val;
            return TRUE;
        
        case OP_FUNC:
            {
                DMValue fargs[DM_MAX_ARGS] = { 0 };
                int i;

                if (node->symbol == NULL || node->symbol->func == NULL ||
                    node->symbol->nargs < 0 || node->symbol->nargs > DM_MAX_ARGS)
                {
                    dmEvalError(ev, "Invalid function call in node %p.\n", (void *) node);
                    return FALSE;
                }

                // Evaluate each argument expression, then call the function
                for (i = 0; i < node->symbol->nargs; i++)
                {
                    if (!dmEvalTreeExecute(ev, node->args[i], &fargs[i]))
                        return FALSE;
                }

                *presult = node->symbol->func(fargs);
            }
            return TRUE;
        
        case OP_SUBEXPR:
            return dmEvalTreeExecute(ev, node->subexpr, presult);
    
        // Binary operators
        case OP_BIT_LSHIFT:
        case OP_BIT_RSHIFT:
        
        case OP_BIT_AND:
        case OP_BIT_XOR:
        case OP_BIT_OR:

        case OP_ADD:
        case OP_SUB:
        case OP_MUL:
        case OP_DIV:
        case OP_MOD:
            if (!dmEvalTreeExecute(ev, node->left, &val1) ||
                !dmEvalTreeExecute(ev, node->right, &val2))
                return FALSE;

            switch (node->op)
            {
                case OP_DIV:
                    if (val2 == 0)
                    {
                        dmEvalError(ev, "Division by zero.\n");
                        return FALSE;
                    }
                    *presult = val1 / val2;
                    break;


                case OP_MOD:
                    // The modulo is computed on converted values, so the
                    // zero check must look at the converted divisor
                    // (e.g. 0.5 converts to 0).
                    if (DMCONVTYPE val2 == 0)
                    {
                        dmEvalError(ev, "Division by zero.\n");
                        return FALSE;
                    }
                    *presult = DMCONVTYPE val1 % DMCONVTYPE val2;
                    break;

                case OP_BIT_LSHIFT:
                    // Shifting by a negative count or by the type width or
                    // more is undefined, so refuse instead of shifting.
                    // The negated form also rejects NaN.
                    if (!(val2 >= 0 && val2 <= 31))
                    {
                        dmEvalError(ev, "Invalid left shift count (%f << %f)\n",
                            (double) val1, (double) val2);
                        return FALSE;
                    }
                    *presult = DMCONVTYPE val1 << DMCONVTYPE val2;
                    break;

                case OP_BIT_RSHIFT:
                    if (!(val2 >= 0 && val2 <= 31))
                    {
                        dmEvalError(ev, "Invalid right shift count (%f >> %f)\n",
                            (double) val1, (double) val2);
                        return FALSE;
                    }
                    *presult = DMCONVTYPE val1 >> DMCONVTYPE val2;
                    break;

                case OP_MUL     : *presult = val1 * val2; break;
                case OP_ADD     : *presult = val1 + val2; break;
                case OP_SUB     : *presult = val1 - val2; break;
                case OP_BIT_AND : *presult = DMCONVTYPE val1 & DMCONVTYPE val2; break;
                case OP_BIT_OR  : *presult = DMCONVTYPE val1 | DMCONVTYPE val2; break;
                case OP_BIT_XOR : *presult = DMCONVTYPE val1 ^ DMCONVTYPE val2; break;
            }
            return TRUE;

        // Unary operators
        case OP_SUB_UNARY:
        case OP_BIT_COMPLEMENT:
            // TODO: not implemented. Nothing in this file establishes where
            // a unary node keeps its operand, so fail explicitly rather
            // than report success with *presult left unset.
            dmEvalError(ev, "Unary operator (opcode %d) is not implemented.\n", node->op);
            return FALSE;

        default:
            dmEvalError(ev, "Invalid opcode %d in node %p.\n", node->op, (void *) node);
            return FALSE;
    }
}
