/* API function registration and execution
 *  Copyright 2001,2002 Patrick TJ McPhee. All rights reserved.
 *
 * Distributed under the terms of the Mozilla Public Licence
 * You can obtain a copy of the licence at http://www.mozilla.org/MPL
 * The Original Code is w32utils
 * The Initial Developer is Patrick TJ McPhee
 *
 * $Header: C:/ptjm/rexx/w32funcs/RCS/addfunc.c 1.2 2003/04/06 02:13:24 ptjm Rel $
 */

/* functions defined in this file:
 * w32funcadd(entrypoint, dllname, rexxname [, args [, rettype]])
 * w32funcdrop([rexxname])
 * w32funcquery(rexxname)
 */ 

#include <windows.h>
#include "rxproto.h"
#include <ctype.h>
#include <stdlib.h>
#include <string.h>

/* these routines use in-line assembler with a syntax peculiar to Microsoft
 * C, so they are not expected to work with other compilers. */
#if defined(_MSC_VER) || defined(MSC_INLINE_ASSEMBLER)

/* Bookkeeping for loaded libraries and the functions found in them. If the
 * library name is always spelled identically (no case or path variations),
 * each library is loaded only once. Unlike IBM Rexx, each function counts
 * how many times it has been added, so adding it twice and dropping it once
 * leaves it loaded.
 */

/* one loaded DLL */
typedef struct {
   int ref;              /* reference count (maintained by addFunc/w32funcdrop) */
   PRXSTRING name;       /* null-terminated copy of the name it was loaded by */
   HMODULE lh;           /* handle returned by LoadLibrary */
} *t_Libs;

/* one function registered with the Rexx interpreter */
typedef struct {
   int ref;              /* number of times it has been added */
   PRXSTRING rexxname,   /* upper-cased name registered with Rexx */
             realname,   /* entry-point name passed to GetProcAddress */
             args;       /* argument type string, or NULL */
   int argc;             /* number of arguments described by args */
   FARPROC fh;           /* the entry point itself */
   t_Libs lib;           /* liblist entry the entry point came from */
   char rctype;          /* return-type letter */
} *t_Funcs;

/* Both lists are dynamic arrays: the *count variable is the number of
 * entries in use, the *alloc variable the number allocated. funclist is
 * kept sorted by rexxname so it can be binary-searched. */
static t_Libs liblist = NULL;
static t_Funcs funclist = NULL;

static int libcount = 0;
static int liballoc = 0;
static int funccount = 0;
static int funcalloc = 0;

/* growth step, in entries, for liblist and funclist */
#define LOAD_INCREMENT 10

/* rxstrcmp: order two RXSTRINGs by length first, then by their bytes
 * (case-sensitively, as unsigned chars via memcmp). Comparing lengths first
 * is cheap and is all the sorted lists need.
 *
 * Returns <0, 0 or >0 like memcmp. The lengths are compared explicitly
 * rather than subtracted, because the difference of two unsigned lengths
 * narrowed to int is implementation-defined. Zero-length strings are
 * never passed to memcmp, so a NULL strptr (an omitted Rexx argument) is
 * safe. */
static int rxstrcmp(const PRXSTRING l, const PRXSTRING r)
{
   if (l->strlength != r->strlength)
      return l->strlength < r->strlength ? -1 : 1;

   if (l->strlength == 0)
      return 0;

   return memcmp(l->strptr, r->strptr, l->strlength);
}

/* compare two t_Funcs by the rexxname */
static int funccmp(const void * l, const void * r)
{
   return rxstrcmp(((t_Funcs)l)->rexxname, ((t_Funcs)r)->rexxname);
}

/* compare an rxstring (left) to a t_Funcs by the rexxname */
static int rxfunccmp(const void * l, const void * r)
{
   return rxstrcmp((PRXSTRING)l, ((t_Funcs)r)->rexxname);
}

static t_Funcs findFunc(const PRXSTRING rexxname)
{
   return (t_Funcs)bsearch((const void *)rexxname, (const void *)funclist,
                           funccount, sizeof(*funclist), rxfunccmp);
}


/* rxstrdupm: duplicate an RXSTRING into a single malloc'd block (header
 * followed by the bytes) and null-terminate the copy. The caller frees the
 * result with one free(). Returns NULL if the allocation fails.
 *
 * The header is only touched after the allocation is known to have
 * succeeded. An empty source (possibly with a NULL strptr, as for an
 * omitted Rexx argument) is never passed to memcpy. */
static PRXSTRING rxstrdupm(PRXSTRING t)
{
   PRXSTRING s = malloc(sizeof(*s) + t->strlength + 1);

   if (s) {
      s->strptr = (void *)(s + 1);
      s->strlength = t->strlength;
      if (t->strlength)
         memcpy(s->strptr, t->strptr, t->strlength);
      s->strptr[s->strlength] = 0;
   }

   return s;
}


/* dropLib: release what one liblist entry owns -- its copied name and its
 * module handle. The reference count is not consulted and the entry is left
 * in liblist; removing it is the caller's job. */
static void dropLib(t_Libs thisl)
{
   PRXSTRING libname = thisl->name;
   HMODULE module = thisl->lh;

   free(libname);
   FreeLibrary(module);
}

/* dropFunc: deregister one funclist entry from the Rexx interpreter, then
 * free the three strings it owns. Neither the entry's library nor any
 * reference count is touched, and the entry stays in funclist; the caller
 * handles all of that. Deregistration comes first because it needs
 * rexxname. */
static void dropFunc(t_Funcs thisf)
{
   PRXSTRING owned[3];
   int n;

   RexxDeregisterFunction(thisf->rexxname->strptr);

   owned[0] = thisf->realname;
   owned[1] = thisf->rexxname;
   owned[2] = thisf->args;

   for (n = 0; n < 3; n++)
      free(owned[n]);
}


/* addLibrary: find a library in liblist by exact name, or load it and
 * append a new entry. Returns the entry's index, or -1 (with the Win32
 * last error set) on failure. The entry's reference count is not changed
 * here.
 *
 * When liblist has to grow, the entries are copied into a new array instead
 * of realloc'd in place. Every funclist entry holds a pointer into liblist,
 * and those pointers are rebased onto the new array before the old one is
 * freed. A plain realloc would leave them dangling, and would also leak
 * liblist if it failed. */
static int addLibrary(PRXSTRING name)
{
   register int i, k;

   /* we don't bother sorting liblist because it's unlikely to help */
   for (i = 0; i < libcount && rxstrcmp(name, liblist[i].name); i++)
      ;

   if (i < libcount)
      return i;

   if (name->strlength == 0) {
      SetLastError(ERROR_BAD_ARGUMENTS);
      return -1;
   }

   if (libcount >= liballoc) {
      t_Libs grown = malloc((liballoc + LOAD_INCREMENT) * sizeof(*liblist));

      if (!grown) {
         SetLastError(ERROR_NOT_ENOUGH_MEMORY);
         return -1;
      }

      if (libcount)
         memcpy(grown, liblist, libcount * sizeof(*liblist));

      /* rebase while the old array is still valid */
      for (k = 0; k < funccount; k++)
         funclist[k].lib = grown + (funclist[k].lib - liblist);

      free(liblist);
      liblist = grown;
      liballoc += LOAD_INCREMENT;
   }

   liblist[i].ref = 0;

   /* copy the name first: LoadLibrary needs it null-terminated */
   liblist[i].name = rxstrdupm(name);
   if (!liblist[i].name) {
      SetLastError(ERROR_NOT_ENOUGH_MEMORY);
      return -1;
   }

   /* the name is narrow-character data, so use the ANSI entry point even
    * in a UNICODE build */
   liblist[i].lh = LoadLibraryA(liblist[i].name->strptr);

   if (liblist[i].lh == NULL) {
      free(liblist[i].name);
      return -1;
   }

   libcount++;
   return i;
}


/* here is the routine which actually calls the DLL function. This is the
 * only bit which is compiler-dependent */
static rxfunc(privatecallfunction)
{
   RXSTRING funcname = { strlen(fname), fname };
   t_Funcs thisf = findFunc(&funcname);
   union {
      int intval;
      short shortval;
      double dblval;
      float flval;
      char * strval;
   } * args;
   char * type;
   struct {
      void * ptr;
      int j, len;
      char type;
   } *rcargs;
   int intasm;
   short shortasm;
   int dblasm[2];
   float flasm;
   char * strasm;
   register int i, j;
   int len, rc, rccount = 0, ac;
   char curtype;
   SHVBLOCK shv;
   char buf[21];
   FARPROC fh;

   /* this shouldn't happen -- how did we get called if we don't exist? */
   if (!thisf)
      return 43;

   /* allocate an argument list now, then push them all when we know what they are */
   if (thisf->argc) {
      args = alloca(sizeof(*args)*thisf->argc);
      type = alloca(sizeof(*type)*thisf->argc);
      rcargs = alloca(sizeof(*rcargs)*thisf->argc);
      ac = thisf->argc;

      for (i = j = 0; i < thisf->args->strlength && j < thisf->argc; i++, j++) {

            type[j] = thisf->args->strptr[i];

            if (argv[j].strptr && type[j] != 'p'
                && type[j] != 'b')
               rxstrdup(args[j].strval, argv[j]);
            else
               args[j].strval = NULL;

            switch(type[j]) {
               /* string (or pointer to passed-in memory). We simply pass the
                * null-terminated argument we just created */
               default:
               case 's':
                  break;

               case 'i':
                  args[j].intval = atoi(args[j].strval);
                  break;

               case 'h':
                  args[j].shortval = atoi(args[j].strval);
                  break;

               case 'r':
                  args[j].strval = (char *)(argv+j);
                  break;

               case 'd':
                  args[j].dblval = atof(args[j].strval);
                  break;

               case 'f':
                  args[j].flval = atof(args[j].strval);
                  break;

               /* for return values: a buffer. The size of the buffer can be put
                * in brackets. The rexx variable passed in this position is
                * evaluated and passed to the function, then set on return. */
               case 'b':
                  if (thisf->args->strptr[i+1] == '[') {
                     len = atoi(thisf->args->strptr+1);
                     i += 2 + strcspn(thisf->args->strptr+1, "]");
                  }
                  else {
                     len = DEFAULTSTRINGSIZE;
                  }

                  /* allocate either that amount or the size of *argv[j] */
                  shv.shvnext = NULL;
                  shv.shvname = argv[j];
                  shv.shvvalue.strptr = malloc(len);
                  shv.shvvalue.strlength = shv.shvvaluelen = len;
                  shv.shvcode = RXSHV_SYFET;
                  shv.shvret = 0;
                  RexxVariablePool(&shv);
                  if (shv.shvret & RXSHV_TRUNC) {
                     len = shv.shvvaluelen;
                     shv.shvvalue.strptr = realloc(shv.shvvalue.strptr, len);
                     shv.shvret = 0;
                     RexxVariablePool(&shv);
                  }

                  rcargs[rccount].j = j;
                  rcargs[rccount].type = 'b';
                  rcargs[rccount].len = len;
                  rcargs[rccount++].ptr =
                  args[j].strval = shv.shvvalue.strptr;
                  break;

               /* Another return value. For most of these, we allocate the
                * appropriate kind of thing, assign the value of the corresponding
                * argument variable, and pass a pointer to it. For ps, the actual
                * rexx variable must be a pointer value, which we pass directly
                * to the API. Thus, malloc would have args 'i' and return value ps,
                * while free() would have args 'ps'. To get at the actual contents,
                * you need to call something like memcpy(b, ps, i) for some value
                * of i. pb and pr are not allowed currently.
                */
               case 'p':
                  shv.shvnext = NULL;
                  shv.shvname = argv[j];
                  shv.shvcode = RXSHV_SYFET;
                  shv.shvret = 0;

                  if (thisf->args->strptr[i+1] != 's' &&
                      thisf->args->strptr[i+1] != 'b' &&
                      thisf->args->strptr[i+1] != 'r') {
                     /* take advantage of all types fitting in a pointer */

                     rcargs[rccount].j = j;
                     rcargs[rccount].type = thisf->args->strptr[i+1];
                     rcargs[rccount].len = sizeof(char *);
                     rcargs[rccount++].ptr =
                     args[j].strval = malloc(sizeof(char *));


                     /* none of these types can have more than 20 bytes */
                     shv.shvvalue.strptr = buf;
                     shv.shvvalue.strlength = shv.shvvaluelen = sizeof(buf);

                     RexxVariablePool(&shv);
                     buf[shv.shvvalue.strlength] = 0;

                     switch (thisf->args->strptr[i+1]) {
                        case 'i':
                           *(int *)args[j].strval = atoi(buf);
                           break;

                        case 'h':
                           *(short *)args[j].strval = atoi(buf);
                           break;

                        case 'd':
                           *(double *)args[j].strval = atof(buf);
                           break;

                        case 'f':
                           *(float *)args[j].strval = atof(buf);
                           break;
                     }

                  }
                  else {
                     shv.shvvalue.strptr = args[j].strval;
                     shv.shvvalue.strlength = shv.shvvaluelen = sizeof(args[j].strval);
                     RexxVariablePool(&shv);
                     if (shv.shvret & RXSHV_NEWV)
                        args[j].strval = NULL;
                  }

                  i++;
                  break;
            }
      }
   }

   /* all-strings short-cut */
   else if (argc) {
      args = alloca(sizeof(*args)*argc);
      type = alloca(sizeof(*type)*argc);
      ac = argc;

      for (i = 0; i < argc; i++) {
         type[i] = 's';
         if (argv[i].strptr)
            rxstrdup(args[i].strval, argv[i]);
         else
            args[i].strval = NULL;
      }
   }

   /* now push all that onto the stack in reverse order */
   for (j = ac - 1; j >= 0; j--) {
      switch (type[j]) {
         case 's':
         case 'b':
         case 'r':
         case 'p':
            strasm = args[j].strval;
            __asm push strasm;
            break;

      case'i':
         intasm = args[j].intval;
         __asm push intasm;
         break;

      case'h':
         shortasm = args[j].shortval;
         __asm push shortasm;
         break;

      case'd':
         memcpy(dblasm, &args[j].dblval, sizeof(dblasm));
         __asm {
           push dblasm[1];
           push dblasm[0];
         }
         break;

      case'f':
         flasm = args[j].flval;
         __asm push flasm;
         break;

      }
   }

   fh = thisf->fh;

   /* call it */
   __asm call fh;

   /* and get the return code */

   __asm mov rc, eax;

   /* now put the return code in result and fill in the pointers and buffer
    * values */
   switch (thisf->rctype) {
      /* null-terminated string (always) */
      case 's': {
         len = strlen((char *)rc);
         rxresize(result, len);
         memcpy(result->strptr, (char *)rc, len);
         result->strlength = len;
      }
      break;

      /* other return types are not currently supported */
      default:
      case 'i':
         result->strlength = sprintf(result->strptr, "%d", rc);
         break;

      case 'p':
         memcpy(result->strptr, &rc, sizeof(rc));
         result->strlength = sizeof(rc);
         break;
   }

   /* finally, set the filled-in values of any pointers */
   shv.shvnext = NULL;
   shv.shvcode = RXSHV_SYSET;
   for (i = 0; i < rccount; i++) {
      shv.shvname = argv[rcargs[i].j];

      switch (rcargs[i].type) {
         case 'b':
         case 's':
            shv.shvvalue.strptr = rcargs[i].ptr;
            shv.shvvalue.strlength = rcargs[i].len;
            break;

         case 'i':
            shv.shvvalue.strptr = buf;
            shv.shvvalue.strlength = sprintf(buf, "%d", *(int *)rcargs[i].ptr);
            break;

         case 'h':
            shv.shvvalue.strptr = buf;
            shv.shvvalue.strlength = sprintf(buf, "%hd", *(short *)rcargs[i].ptr);
            break;

         case 'f':
            shv.shvvalue.strptr = buf;
            shv.shvvalue.strlength = sprintf(buf, "%f", *(float *)rcargs[i].ptr);
            break;

         case 'd':
            shv.shvvalue.strptr = buf;
            shv.shvvalue.strlength = sprintf(buf, "%f", *(double *)rcargs[i].ptr);
            break;
      }

      shv.shvret = 0;
      RexxVariablePool(&shv);
      free(rcargs[i].ptr);
   }

   return 0;
}

/* releaseUnusedLib: if no funclist entry refers to lib, unload it and remove
 * it from liblist. Later liblist entries move down one slot, so the lib
 * pointers held in funclist are adjusted to match. */
static void releaseUnusedLib(t_Libs lib)
{
   register int k;

   for (k = 0; k < funccount; k++)
      if (funclist[k].lib == lib)
         return;

   dropLib(lib);
   libcount--;
   if ((lib - liblist) < libcount)
      memmove(lib, lib + 1, sizeof(*lib) * (libcount - (lib - liblist)));

   for (k = 0; k < funccount; k++)
      if (funclist[k].lib > lib)
         funclist[k].lib--;
}

/* countArgTypes: count the arguments described by the (lower-cased)
 * argument type string a. Unrecognised letters are rewritten to 's'. A 'p'
 * does not count by itself; the letter after it does. A 'b' may be
 * followed by a bracketed size, which is skipped. a must be
 * null-terminated, as rxstrdupm makes it. */
static int countArgTypes(PRXSTRING a)
{
   register int i;
   int count = 0;

   for (i = 0; i < (int)a->strlength; i++) {
      switch (a->strptr[i]) {
         case 's':
         case 'i':
         case 'h':
         case 'r':
         case 'd':
         case 'f':
            count++;
            break;

         case 'b':
            /* leave i on the ']' so the loop increment steps past it */
            if (a->strptr[i+1] == '[')
               i += 1 + (int)strcspn((const char *)a->strptr + i + 1, "]");
            count++;
            break;

         case 'p':
            break;

         /* a bad letter: rather than failing, assume `string' was meant */
         default:
            a->strptr[i] = 's';
            count++;
            break;
      }
   }

   return count;
}

/* addFunc: register realname from lib with the Rexx interpreter as rexxname,
 * or bump the count of an identical existing registration. Returns the
 * funclist index, or -1 with the Win32 last error set.
 *
 * Parameters:
 *   realname  entry point to look up in lib
 *   lib       liblist entry from addLibrary
 *   rexxname  Rexx name; it is registered and looked up in upper case
 *   args      argument type string (see privatecallfunction), or NULL
 *   rctype    return type; only the first letter is used. Defaults to 'i'.
 *
 * The same Rexx name bound to a different library or entry point fails with
 * ERROR_DUP_NAME. On any failure, lib is unloaded again if no function uses
 * it, so a library loaded just for this call is not leaked. funclist is
 * re-sorted after each addition, which slows this down but speeds up
 * execution.
 */
static int addFunc(PRXSTRING realname, t_Libs lib, PRXSTRING rexxname, PRXSTRING args,
                   PRXSTRING rctype)
{
   t_Funcs thisf;
   PRXSTRING upname = NULL, realcopy = NULL, argcopy = NULL;
   FARPROC fh;
   int argcount = 0;

   if (rexxname->strlength == 0 || realname->strlength == 0) {
      SetLastError(ERROR_BAD_ARGUMENTS);
      goto fail;
   }

   /* names are registered in upper case, so look them up that way too */
   upname = rxstrdupm(rexxname);
   if (!upname) {
      SetLastError(ERROR_NOT_ENOUGH_MEMORY);
      goto fail;
   }
   _strupr(upname->strptr);

   thisf = findFunc(upname);
   if (thisf) {
      if (lib == thisf->lib && !rxstrcmp(realname, thisf->realname)) {
         free(upname);
         thisf->ref++;
         return thisf - funclist;
      }
      SetLastError(ERROR_DUP_NAME);
      goto fail;
   }

   /* copy the name first: GetProcAddress needs it null-terminated */
   realcopy = rxstrdupm(realname);
   if (!realcopy) {
      SetLastError(ERROR_NOT_ENOUGH_MEMORY);
      goto fail;
   }

   fh = GetProcAddress(lib->lh, realcopy->strptr);
   if (fh == NULL)
      goto fail;

   if (args) {
      argcopy = rxstrdupm(args);
      if (!argcopy) {
         SetLastError(ERROR_NOT_ENOUGH_MEMORY);
         goto fail;
      }
      /* lower-case our copy (not the caller's argument): this is the
       * string privatecallfunction parses */
      _strlwr(argcopy->strptr);
      argcount = countArgTypes(argcopy);
   }

   /* grow before registering, so a failure leaves nothing registered */
   if (funccount >= funcalloc) {
      t_Funcs grown = realloc(funclist, (funcalloc + LOAD_INCREMENT) * sizeof(*funclist));

      if (!grown) {
         SetLastError(ERROR_NOT_ENOUGH_MEMORY);
         goto fail;
      }
      funclist = grown;
      funcalloc += LOAD_INCREMENT;
   }

   if (RexxRegisterFunctionExe(upname->strptr, privatecallfunction) != 0) {
      /* most likely the name is already registered by someone else */
      SetLastError(ERROR_DUP_NAME);
      goto fail;
   }

   thisf = funclist + funccount;
   thisf->ref = 1;
   thisf->rexxname = upname;
   thisf->realname = realcopy;
   thisf->args = argcopy;
   thisf->argc = argcount;
   thisf->fh = fh;
   thisf->lib = lib;
   thisf->rctype = (rctype && rctype->strlength)
                   ? (char)tolower((unsigned char)rctype->strptr[0]) : 'i';

   /* one more function depends on this library */
   lib->ref++;
   funccount++;

   qsort(funclist, funccount, sizeof(*funclist), funccmp);

   /* sorting moved the entry; upname is still the entry's own name */
   thisf = findFunc(upname);
   return thisf ? thisf - funclist : -1;

fail:
   {
      /* cleanup must not clobber the error reported to w32geterror() */
      DWORD err = GetLastError();

      free(upname);
      free(realcopy);
      free(argcopy);
      releaseUnusedLib(lib);
      SetLastError(err);
   }
   return -1;
}


/* w32funcadd loads and registers the function for use later */
rxfunc(w32funcadd)
{
   int libi, funci;

   checkparam(3, 5);
   
   libi = addLibrary(argv+1);
   if (libi == -1) {
      result_one();
   }
   else {
      funci = addFunc(argv, liblist+libi, argv+2, argc > 3 ? argv+3 : NULL, argc > 4 ? argv+4 : NULL);
      if (funci == -1) {
         result_one();
      }
      else {
         result_zero();
      }
   }

   return 0;
}

/* w32funcdrop([rexxname]) decrements the add count of rexxname. When the
 * count reaches 0, the function is unregistered. Its library is then
 * unloaded if no remaining function uses it. With no argument, everything
 * is unregistered and unloaded.
 *
 * The Rexx result is 0 on success and 1 if rexxname is not loaded.
 *
 * Whether the library is still in use is decided by scanning funclist
 * rather than by trusting its reference count. Removing a liblist entry
 * shifts the later entries down, so the lib pointers of the remaining
 * functions are adjusted to match.
 */
rxfunc(w32funcdrop)
{
   register int i;
   t_Funcs thisf;
   t_Libs thisl;

   checkparam(0, 1);

   if (!argc) {
      for (i = 0; i < funccount; i++) {
         dropFunc(funclist+i);
      }
      for (i = 0; i < libcount; i++) {
         dropLib(liblist+i);
      }
      libcount = funccount = 0;
      result_zero();
      return 0;
   }

   thisf = findFunc(argv);
   if (!thisf) {
      result_one();
      return 0;
   }

   if (--thisf->ref <= 0) {
      thisl = thisf->lib;

      dropFunc(thisf);

      funccount--;
      if ((thisf - funclist) < funccount) {
         memmove(thisf, thisf+1, sizeof(*thisf)*(funccount - (thisf - liblist, funclist)));
      }

      thisl->ref--;

      /* is any remaining function still using this library? */
      for (i = 0; i < funccount && funclist[i].lib != thisl; i++)
         ;

      if (i == funccount) {
         dropLib(thisl);

         libcount--;
         if ((thisl - liblist) < libcount) {
            memmove(thisl, thisl+1, sizeof(*thisl)*(libcount - (thisl - liblist)));
         }

         /* libraries after thisl moved down one slot */
         for (i = 0; i < funccount; i++) {
            if (funclist[i].lib > thisl)
               funclist[i].lib--;
         }
      }
   }

   result_zero();
   return 0;
}

/* w32funcquery(rexxname) sets the Rexx result to 0 if rexxname is currently
 * loaded through w32funcadd, and to 1 otherwise. The handler itself always
 * returns 0. The original fell off the end of this non-void function,
 * leaving the interpreter an indeterminate status. */
rxfunc(w32funcquery)
{
   checkparam(1, 1);

   if (findFunc(argv)) {
      result_zero();
   }
   else {
      result_one();
   }

   return 0;
}

#else

/* catch-all for unsupported compilers (my intent is to support MinGW and
 * ignore the others) */

/* w32funcadd is unavailable with this compiler: the Rexx result is always 1,
 * and ERROR_NOT_SUPPORTED is recorded so that w32geterror() returns a
 * meaningful message. */
rxfunc(w32funcadd)
{
   result_one();
   SetLastError(ERROR_NOT_SUPPORTED);
   return 0;
}

/* w32funcdrop is unavailable with this compiler: the Rexx result is always
 * 1, and ERROR_NOT_SUPPORTED is recorded for w32geterror(). */
rxfunc(w32funcdrop)
{
   result_one();
   SetLastError(ERROR_NOT_SUPPORTED);
   return 0;
}

/* this is supported -- it's just that the function isn't loaded ... */
/* w32funcquery: no function can be added with this compiler, so the Rexx
 * result is always 1 ("not loaded"). Unlike the other stubs, this one does
 * not call SetLastError, because querying itself is not an unsupported
 * operation. */
rxfunc(w32funcquery)
{
   result_one();
   return 0;
}
#endif
