/* route.c */
/* This is a bit horrible: the basic rule is that all the complex stuff is done in the server (outside
   the firewall - the machine inside might not be allowed incoming connections). The only reason
   for this is to keep all the admin data in one place.

   The only thing the client does need is the ACL. In future, this might be the first thing sent over
   the TCP connection. 
*/


/* We need three separate hash tables: 

   -> one for the client to do ACL checking.
   -> one for the server to map (a.b.c.d, p1, p2, cname) to (ttl, struct inaddr *, fd) for outgoing UDP.
   -> one for the server to map (svr, port) to (a.b.c.d,p1,p2, cname, ttl) for incoming UDP.
*/

#include "header.h"

/* ACL tables. route_acl_table is the one scanned by access_allowed();
   client_acl_table is only allocated (in route_init()) within this file. */
acl_table_t route_acl_table, client_acl_table;
/* Redirect configuration entries, searched linearly by getRedirectInfo(). */
svr_conf_table_t route_config_table;
/* Live routes: two bucket arrays (fwd and bwd) chaining the same
   server_route_t entries, plus an occupancy count. */
svr_route_table_t routing_table;
/* Spare server_route_t blocks; see putFreeBlock() and getFreeBlock(). */
svr_route_freelist_t freeList;
/* Client fd table; route_init() allocates it and marks every entry unused. */
client_fd_table_t client_fdtable;
/* Packet queues. NOTE(review): presumably these are what the OQ and IQ
   names used below refer to (they are not defined in this file) - confirm
   against header.h. */
svr_packet_queue_t outgoing_queue, incoming_queue;

/* Byte-sum hash of an address/port pair: the top byte of addr, the
 * second-lowest byte of addr and the low byte of port are added together.
 * Inputs in nbo. Each term is masked to 0..255, so the result lies in
 * 0..765 and can never be negative. Should be inline. */
inline int hash1(ipv4_addr_t addr, ipv4_port_t port) {
  int sum = 0;

  sum += (addr >> 24) & 0xff;
  sum += (addr >> 8) & 0xff;
  sum += port & 0xff;

  return sum;
}

/* Second byte-sum hash of an address/port pair: adds the second-highest
 * byte of addr, the lowest byte of addr and the high byte of port.
 * Inputs in nbo. Result is in 0..765.
 *
 * The port term used to be (port>>16)&0xff. Ports are 16-bit quantities
 * (this file converts them with ntohs()), so that term was always zero and
 * the port never influenced the hash; the high byte is bits 8..15. */
inline int hash2(ipv4_addr_t addr, ipv4_port_t port) {
  int t = ((addr>>16)&0xff)+((addr)&0xff)+((port>>8)&0xff);

  /* Every term is masked to 0..255, so t is never negative. */
  return t;
}

/* Functions to manage the incoming and outgoing packet queues */

/* Insert e into the incoming queue (IQ), keeping the entries ordered by
 * due_in: e is placed in front of the first queued entry it is due before,
 * or at the tail when there is no such entry.
 * Return 0 if OK, 1 if the queue is full and we should discard this packet. */
int insert_packet_for_in(server_route_t *e) {
  int pos, k;

  if (IQ.occ == IQ.length) {
    return 1;
  }

  /* Locate the insertion point (IQ.occ means "append"). */
  for (pos = 0; pos < IQ.occ; pos++) {
    if (timeval_before(e->due_in, IQ.entries[pos]->due_in)) {
      break;
    }
  }

  /* Open a gap at pos by moving everything from pos onwards up one slot. */
  for (k = IQ.occ; k > pos; k--) {
    IQ.entries[k] = IQ.entries[k-1];
  }

  IQ.entries[pos] = e;
  IQ.occ++;
  return 0;
}

/* Insert e into the outgoing queue (OQ), keeping the entries ordered: e is
 * placed in front of the first queued entry it is due before, or at the
 * tail when there is no such entry.
 * Return 0 if OK, 1 if the queue is full and the packet should be discarded.
 * NOTE(review): the ordering key is due_in, the same field the incoming
 * queue uses - confirm that is intended for the outgoing queue. */
int insert_packet_for_out(server_route_t *e) {
  int pos, k;

  if (OQ.occ == OQ.length) {
    return 1;
  }

  /* Locate the insertion point (OQ.occ means "append"). */
  for (pos = 0; pos < OQ.occ; pos++) {
    if (timeval_before(e->due_in, OQ.entries[pos]->due_in)) {
      break;
    }
  }

  /* Open a gap at pos by moving everything from pos onwards up one slot. */
  for (k = OQ.occ; k > pos; k--) {
    OQ.entries[k] = OQ.entries[k-1];
  }

  OQ.entries[pos] = e;
  OQ.occ++;
  return 0;
}

/* Remove the first occurrence of e from queue qt, shifting the later
 * entries down one slot to close the gap. Does nothing if e is not queued. */
void discard_packet_for(server_route_t *e, svr_packet_queue_t *qt) {
  int idx;

  /* Find e; idx ends up equal to (or beyond) qt->occ when it is absent. */
  for (idx = 0; idx < qt->occ; idx++) {
    if (qt->entries[idx] == e) {
      break;
    }
  }
  if (idx >= qt->occ) {
    return;
  }

  /* Slide the tail down over the removed slot. */
  for (; idx < qt->occ - 1; idx++) {
    qt->entries[idx] = qt->entries[idx+1];
  }
  qt->occ--;
}

/* Map a (client, client port, origin, origin port) tuple to a bucket index
 * in routing_table.fwd. The result is always in 0..routing_table.length-1;
 * routing_table.length must be non-zero (route_init() sets it).
 *
 * The magnitude of the sum is taken in unsigned arithmetic: negating an int
 * equal to INT_MIN overflows, and the old code could then yield a negative
 * bucket index. Every other input maps to the same bucket as before. */
inline int getHashBucketIncoming(ipv4_addr_t client, ipv4_port_t cl_port,
				      ipv4_addr_t orig, ipv4_port_t orig_port) {
  int t = client+cl_port+orig+orig_port;
  unsigned int mag = (t < 0) ? 0u - (unsigned int)t : (unsigned int)t;

  return (int)(mag % (unsigned int)routing_table.length);
}

/* Map a (server, server port, destination fd) tuple to a bucket index in
 * routing_table.bwd. The result is always in 0..routing_table.length-1;
 * routing_table.length must be non-zero (route_init() sets it).
 *
 * The magnitude of the sum is taken in unsigned arithmetic: negating an int
 * equal to INT_MIN overflows, and the old code could then yield a negative
 * bucket index. Every other input maps to the same bucket as before. */
inline int getHashBucketOutgoing(ipv4_addr_t svr, ipv4_port_t svr_port, int fd_destination) {
  int t = svr+svr_port+fd_destination;
  unsigned int mag = (t < 0) ? 0u - (unsigned int)t : (unsigned int)t;

  return (int)(mag % (unsigned int)routing_table.length);
}

/* Walk every forward bucket chain and return the first route whose
 * fd_destination equals fd, or NULL when no route uses that descriptor.
 * This is a full scan of the table, not a hashed lookup. */
inline server_route_t *searchForOutgoingConnection(int fd) {
  int bucket;

  for (bucket = 0; bucket < routing_table.length; bucket++) {
    server_route_t *rt;

    for (rt = routing_table.fwd[bucket]; rt != NULL; rt = rt->next) {
      if (rt->fd_destination == fd) {
        return rt;
      }
    }
  }

  return NULL;
}

/* Unlink rt from both of its bucket chains (forward via next/prev, backward
 * via next2/prev2), decrement the table occupancy and hand the block back
 * to the free list.
 */
void delete_route(server_route_t *rt) {
  /* Each slot is the pointer that currently refers to rt in its chain:
   * the predecessor's link field, or the bucket head when rt is first. */
  server_route_t **fwd_slot;
  server_route_t **bwd_slot;

  fwd_slot = (rt->prev != NULL)
    ? &rt->prev->next
    : &routing_table.fwd[getHashBucketIncoming(rt->clnt_name, rt->clnt_port,
                                               rt->orig_name, rt->orig_port)];

  bwd_slot = (rt->prev2 != NULL)
    ? &rt->prev2->next2
    : &routing_table.bwd[getHashBucketOutgoing(rt->svr_name, rt->svr_port,
                                               rt->fd_destination)];

  /* Point the successors back past rt ... */
  if (rt->next != NULL) {
    rt->next->prev = rt->prev;
  }
  if (rt->next2 != NULL) {
    rt->next2->prev2 = rt->prev2;
  }
  /* ... and the predecessors (or bucket heads) forward past it. */
  *fwd_slot = rt->next;
  *bwd_slot = rt->next2;

  routing_table.occ--;

  /* The block is now free for reuse. */
  putFreeBlock(rt);
}

/* Look up the route matching (client, cl_port, orig, orig_port) in the
 * forward table. On a hit the route's ttl is reset to TTL_USED and the route
 * is returned; otherwise NULL is returned. */
inline server_route_t *getIncomingConnection(ipv4_addr_t client, ipv4_port_t cl_port, 
				      ipv4_addr_t orig, ipv4_port_t orig_port) {
  int bucket = getHashBucketIncoming(client, cl_port, orig, orig_port);
  server_route_t *rt;

  for (rt = routing_table.fwd[bucket]; rt != NULL; rt = rt->next) {
    if (rt->clnt_name != client || rt->clnt_port != cl_port) {
      continue;
    }
    if (rt->orig_name != orig || rt->orig_port != orig_port) {
      continue;
    }
    rt->ttl = TTL_USED;
    return rt;
  }
  return NULL;
}

/* Look up the route matching (svr, svr_port, fd_destination) in the backward
 * table, following the next2 chain. On a hit the route's ttl is reset to
 * TTL_USED and the route is returned; otherwise NULL is returned. */
server_route_t *getOutgoingConnection(ipv4_addr_t svr, ipv4_port_t svr_port,
				      int fd_destination) {
  int bucket = getHashBucketOutgoing(svr, svr_port, fd_destination);
  server_route_t *rt;

  for (rt = routing_table.bwd[bucket]; rt != NULL; rt = rt->next2) {
    if (rt->svr_name != svr || rt->svr_port != svr_port) {
      continue;
    }
    if (rt->fd_destination != fd_destination) {
      continue;
    }
    rt->ttl = TTL_USED;
    return rt;
  }

  return NULL;
}
/* Create a new connection: push rt onto the head of its forward bucket
 * chain (keyed on the client/origin tuple) and of its backward bucket chain
 * (keyed on server, port and destination fd), then bump the occupancy.
 */
void add_route(server_route_t *rt) {
  server_route_t **fwd_head =
    &routing_table.fwd[getHashBucketIncoming(rt->clnt_name, rt->clnt_port,
                                             rt->orig_name, rt->orig_port)];
  server_route_t **bwd_head =
    &routing_table.bwd[getHashBucketOutgoing(rt->svr_name, rt->svr_port,
                                             rt->fd_destination)];
  server_route_t *old;

  /* Forward chain. */
  old = *fwd_head;
  rt->next = old;
  if (old != NULL) {
    old->prev = rt;
  }
  rt->prev = NULL;
  *fwd_head = rt;

  /* Backward chain. */
  old = *bwd_head;
  rt->next2 = old;
  if (old != NULL) {
    old->prev2 = rt;
  }
  rt->prev2 = NULL;
  *bwd_head = rt;

  routing_table.occ++;
}


/* Manage the free list. If there are fewer than MIN_FREE_LIST_BLOCKS, 
 * allocate some, and if there are more than MAX_FREE_LIST_BLOCKS, free them.
*/
void manage_freelist() {
  int i;

  if (freeList.howmany < MIN_FREE_LIST_BLOCKS) {
    for (i=0;i<BLOCKS_AT_A_TIME;i++) {
      server_route_t *t = malloc(sizeof(server_route_t) + 2*SERVER_PACKET_BUFFER_SZ);
      
      if (t==NULL) {
	warn(0, "Can't allocate memory for routes.");
	return;
      }
      /* This should improve locality slightly : */
      t->pckt_in = &(((char *)t)[sizeof(server_route_t)]);
      t->pckt_out = &(((char *)t)[sizeof(server_route_t) + SERVER_PACKET_BUFFER_SZ]);
      putFreeBlock(t);
    }
    return;
  }

  if (freeList.howmany > MAX_FREE_LIST_BLOCKS) {
    for (i=0;i<BLOCKS_AT_A_TIME;i++) {
      server_route_t *t;

      t = getFreeBlock();
      if (t) { free(t); }
    }
  }
}

/* Add a block to the head of the free list and update the block count.
 * Whenever the count becomes a multiple of four, manage_freelist() is asked
 * to rebalance the list. */
void putFreeBlock(server_route_t *t) {
  server_route_t *head = freeList.lst;

  if (head == NULL) {
    /* Empty list: t becomes its only member and the count restarts at 1. */
    freeList.lst = t;
    t->prev = NULL;
    t->next = NULL;
    freeList.howmany = 1;
  } else {
    head->prev = t;
    t->next = head;
    freeList.lst = t;
    freeList.howmany++;
  }

  if ((freeList.howmany % 4) == 0) {
    manage_freelist();
  }
}

/* Get the next block off the free list.
 *
 * If the list is recorded as empty, manage_freelist() is asked to allocate
 * more blocks first. Returns the detached block with its prev and next
 * pointers cleared, or NULL if no block could be obtained.
 *
 * The block count is decremented on every successful pop. Previously it was
 * never decremented here, so the count drifted away from the real list
 * length: the "says it isn't" recovery path below fired whenever the list
 * ran dry, and manage_freelist() saw an inflated count. */
server_route_t *getFreeBlock(void) {
  server_route_t *nxt;
  
  if (freeList.howmany < 1) {
    manage_freelist();
    if (freeList.howmany < 1) { return NULL; }
  }
  
  /* Defensive: count and list disagree. Reset the count and retry, which
   * goes through the refill path above. */
  if (!freeList.lst) {
    warn(0, "Free list is empty, and says it isn't: resetting it.");
    freeList.howmany = 0;
    return getFreeBlock();
  }
  
  /* Detach the head block. */
  nxt = freeList.lst;
  freeList.lst = nxt->next;
  if (freeList.lst != NULL) {
    freeList.lst->prev = NULL;
  }
  freeList.howmany--;

  nxt->prev = NULL; nxt->next = NULL;
  return nxt;
}


/* Check (addr, port) against route_acl_table. Inputs in nbo.
 * Returns 1 as soon as an entry matches - the address agrees with the
 * entry's allowed address under the entry's mask, and the port lies within
 * [pstart, pend] - and 0 when no entry matches. This implementation never
 * returns a negative error code.
 * NOTE(review): the port bounds are converted with ntohl() while the port
 * itself uses ntohs(); confirm that pstart/pend are stored as 32-bit values. */
inline int access_allowed(ipv4_addr_t addr, ipv4_port_t port) {
  ipv4_addr_t host_addr = ntohl(addr);
  ipv4_port_t host_port = ntohs(port);
  int idx;

  for (idx = 0; idx < route_acl_table.occ; idx++) {
    ipv4_addr_t want = ntohl(route_acl_table.entries[idx].allowed);
    ipv4_addr_t mask = ntohl(route_acl_table.entries[idx].mask);
    int lo, hi;

    /* Address outside this entry's network: try the next entry. */
    if ((host_addr & mask) != (want & mask)) {
      continue;
    }

    lo = ntohl(route_acl_table.entries[idx].pstart);
    hi = ntohl(route_acl_table.entries[idx].pend);
    if (host_port >= lo && host_port <= hi) {
      return 1;
    }
  }
  return 0;
}

/* Find the configuration entry for a given client address and port.
 * Scans route_config_table in order and returns a pointer to the first
 * entry whose client_name and client_port both match, or NULL if none do. */
inline server_conf_elem_t *getRedirectInfo(ipv4_addr_t client, ipv4_port_t client_port) {
  int idx;

  for (idx = 0; idx < route_config_table.occ; idx++) {
    server_conf_elem_t *ent = &route_config_table.entries[idx];

    if (ent->client_name != client) {
      continue;
    }
    if (ent->client_port != client_port) {
      continue;
    }
    return ent;
  }
  return NULL;
}


/* Initialise routing tables: allocate every table and queue defined in this
 * file, record its capacity and mark it empty. Each allocation failure is
 * reported through fatal(). The free list starts out empty. */
void route_init(void) {
  int nbytes;
  int slot;

  /* Client fd table, with every slot marked unused. */
  nbytes = sizeof(client_fd_entry_t)*TABLE_CLNT_FD_SZ;
  client_fdtable.entries = malloc(nbytes);
  if (client_fdtable.entries == NULL) {
    fatal(0, "Can't get %d bytes for client FD table.\n", nbytes);
  }
  client_fdtable.length = TABLE_CLNT_FD_SZ;
  client_fdtable.occ = 0;

  slot = 0;
  while (slot < client_fdtable.length) {
    client_fdtable.entries[slot].used = 0;
    slot++;
  }

  /* Route ACL table. */
  nbytes = sizeof(acl_ht_entry_t)*TABLE_CLNT_ACL_SZ;
  route_acl_table.entries = malloc(nbytes);
  route_acl_table.length = TABLE_CLNT_ACL_SZ;
  if (route_acl_table.entries == NULL) {
    fatal(0, "Can't get %d bytes for acl table.\n", nbytes);
  }
  route_acl_table.occ = 0;

  /* Client ACL table. */
  nbytes = sizeof(acl_ht_entry_t)*TABLE_CONN_ACL_SZ;
  client_acl_table.entries = malloc(nbytes);
  client_acl_table.length = TABLE_CONN_ACL_SZ;
  if (client_acl_table.entries == NULL) {
    fatal(0, "Can't get %d bytes for acl table.\n", nbytes);
  }
  client_acl_table.occ = 0;

  /* Incoming and outgoing packet queues share one size. */
  nbytes = sizeof(server_route_t *)*TABLE_QUEUE_SZ;

  IQ.entries = malloc(nbytes);
  IQ.occ = 0;
  IQ.length = TABLE_QUEUE_SZ;
  if (IQ.entries == NULL) {
    fatal(0, "Can't get %d bytes for incoming queue.\n", nbytes);
  }

  OQ.entries = malloc(nbytes);
  OQ.occ = 0;
  OQ.length = TABLE_QUEUE_SZ;
  if (OQ.entries == NULL) {
    fatal(0, "Can't get %d bytes for outgoing queue.\n", nbytes);
  }

  /* Redirect configuration table. */
  nbytes = sizeof(server_conf_elem_t)*TABLE_CLNT_INDIRECT_SZ;
  route_config_table.entries = malloc(nbytes);
  route_config_table.length = TABLE_CLNT_INDIRECT_SZ;
  if (route_config_table.entries == NULL) {
    fatal(0, "Can't get %d bytes for configuration table.\n", nbytes);
  }
  route_config_table.occ = 0;

  /* Routing table: forward and backward bucket arrays, all chains empty. */
  nbytes = sizeof(server_route_t *)*ROUTING_TABLE_SZ;
  routing_table.length = ROUTING_TABLE_SZ;
  routing_table.fwd = malloc(nbytes);
  routing_table.bwd = malloc(nbytes);
  if (routing_table.fwd == NULL || routing_table.bwd == NULL) {
    fatal(0, "Can't allocate %d + %d bytes for routing table.", nbytes, nbytes);
  }
  for (slot = 0; slot < ROUTING_TABLE_SZ; slot++) {
    routing_table.fwd[slot] = NULL;
    routing_table.bwd[slot] = NULL;
  }
  routing_table.occ = 0;

  /* No spare route blocks yet. */
  freeList.howmany = 0;
  freeList.lst = NULL;
}

/* Call close() on the fd_destination of every route reachable through the
 * forward bucket chains. The routes themselves are left in the table. */
void close_all_routes(void) {
  int bucket;

  for (bucket = 0; bucket < routing_table.length; bucket++) {
    server_route_t *rt;

    for (rt = routing_table.fwd[bucket]; rt != NULL; rt = rt->next) {
      close(rt->fd_destination);
    }
  }
}

/* End */

