/*////////////////////////////////////////////////////////////////////////
Copyright (c) 1996 Electrotechnical Laboratry (ETL), AIST, MITI

Permission to use, copy, modify, and distribute this material for any
purpose and without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies, and
that the name of ETL not be used in advertising or publicity pertaining
to this material without the specific, prior written permission of an
authorized representative of ETL.
ETL MAKES NO REPRESENTATIONS ABOUT THE ACCURACY OR SUITABILITY OF THIS
MATERIAL FOR ANY PURPOSE.  IT IS PROVIDED "AS IS", WITHOUT ANY EXPRESS
OR IMPLIED WARRANTIES.
/////////////////////////////////////////////////////////////////////////
Content-Type:	program/C; charset=US-ASCII
Program:	socks.c
Author:		Yutaka Sato <ysato@etl.go.jp>
Description:
History:
	960219	created
//////////////////////////////////////////////////////////////////////#*/
#include <stdio.h>
#include <string.h>
#include "delegate.h"
extern char *gethostbyAddr();
extern char *gethostaddr();

#define SOCKS_VERSION	4
#define SOCKS_CONNECT	1
#define SOCKS_BIND	2
#define SOCKS_ACCEPT	8	/* pseudo command */
#define SOCKS_RESULT	90
#define SOCKS_FAIL	91
#define SOCKS_NO_IDENTD	92
#define SOCKS_BAD_ID	93

#if defined(SOCKS_DEFAULT_NS)
#define DNS	SOCKS_DEFAULT_NS
#else
#define DNS 	NULL
#endif

#if defined(SOCKS_DEFAULT_DNAME)
#define DNAME	SOCKS_DEFAULT_DNAME
#else
#define DNAME	NULL
#endif

#ifndef SOCKS_DEFAULT_SERVER
#define SOCKS_DEFAULT_SERVER NULL
#endif

#ifndef SOCKS_DEFAULT_PORT
#define SOCKS_DEFAULT_PORT 0
#endif

/* scommand -- printable name of a SOCKS command code, used in log lines */
static char *scommand(command)
	int command;
{
	char *name;

	if( command == SOCKS_CONNECT )
		name = "CONNECT";
	else if( command == SOCKS_BIND )
		name = "BIND";
	else if( command == SOCKS_ACCEPT )
		name = "ACCEPT";
	else	name = "BAD-COM";
	return name;
}

/*
 * isSocksConnect -- recognize a SOCKS v4 CONNECT request at the start of pack.
 *   pack   received bytes
 *   leng   number of valid bytes in pack
 *   ver    set to 4 on success
 *   addr   receives the destination as "a.b.c.d" (caller needs >= 16 bytes)
 *   port   receives the destination port (network order in the packet)
 *   user   set to point at the NUL-terminated user id inside pack
 * Returns the request length (8-byte header + user id + NUL), or 0 when
 * pack does not hold a complete SOCKS v4 CONNECT request within leng bytes.
 * The outputs are only written when a request is recognized.
 */
int isSocksConnect(pack,leng,ver,addr,port,user)
	unsigned char *pack;
	int leng;
	int *ver;
	char *addr;
	int *port;
	unsigned char **user;
{	unsigned char *nul;

	/* need the 8-byte header plus at least the user id terminator */
	if( leng < 9 )
		return 0;
	if( pack[0] != 4 || pack[1] != SOCKS_CONNECT )
		return 0;

	/* the user id must be NUL-terminated inside the bytes we were given,
	   otherwise scanning it would run past the end of pack */
	nul = (unsigned char *)memchr(&pack[8],0,leng-8);
	if( nul == NULL )
		return 0;

	*ver = 4;
	*port = pack[2] << 8 | pack[3];
	sprintf(addr,"%d.%d.%d.%d",pack[4],pack[5],pack[6],pack[7]);
	*user = &pack[8];
	return (int)(nul - pack) + 1;
}

/*
 * makePacket -- encode a SOCKS v4 packet into obuf:
 *   version, command, port (2 bytes, high byte first), IPv4 address of
 *   host (4 bytes), then user copied with its terminating NUL.
 * Returns the number of bytes written, or -1 when host does not resolve
 * through gethostaddr(), when a CONNECT would target 0.0.0.0, or when the
 * resolved address is not in dotted-quad form.
 * The caller must make obuf large enough for 8 + strlen(user) + 1 bytes.
 */
static int makePacket(obuf,version,command,host,port,user)
	char *obuf,*host,*user;
	int version,command,port;
{
	char *dotted;
	int octet[4];
	int i;

	if( (dotted = gethostaddr(host)) == NULL ){
		sv1log("Don't try Socks for unknown host\n");
		return -1;
	}
	/* refuse to ask for a CONNECT to the wildcard address */
	if( command == SOCKS_CONNECT && strcmp(dotted,"0.0.0.0") == 0 )
		return -1;
	if( sscanf(dotted,"%d.%d.%d.%d",
		&octet[0],&octet[1],&octet[2],&octet[3]) != 4 )
		return -1;

	obuf[0] = version;
	obuf[1] = command;
	obuf[2] = port >> 8;
	obuf[3] = port;
	for( i = 0; i < 4; i++ )
		obuf[4+i] = octet[i];

	strcpy(&obuf[8],user);
	return 8 + strlen(user) + 1;
}

/*
 *	SOCKS SERVER
 */
/*
 * sendSocksReply -- send one 8-byte SOCKS v4 reply to the client on ToC.
 * When makePacket() cannot encode host (it returns -1, e.g. for a name
 * that does not resolve), the port and address fields are sent as zeros
 * instead of whatever bytes happened to be in obuf.  A 0.0.0.0 address is
 * the value getResponse() in this file replaces with the server's address.
 */
static void sendSocksReply(Conn,obuf,code,host,port)
	Connection *Conn;
	unsigned char *obuf;
	char *host;
	int code,port;
{	int i;

	if( makePacket(obuf,SOCKS_VERSION,code,host,port,"") < 0 ){
		obuf[0] = SOCKS_VERSION;
		obuf[1] = code;
		for( i = 2; i < 8; i++ )
			obuf[i] = 0;
	}
	write(ToC,obuf,8);
}

/*
 * service_socks -- serve one SOCKS v4 client connection.
 * Reads the 8-byte request header and the NUL-terminated user id from
 * FromC, checks that "tcprelay" is permitted for the destination, then
 *  - CONNECT: connects via connect_to_serv(), replies, and relays through
 *    relay_svcl();
 *  - BIND: opens a listening socket on the interface returned by
 *    hostIFfor(), sends its address as the first reply, waits
 *    (PollIns with ACC_TIMEOUT*1000) for one inbound connection, sends
 *    the peer's address as the second reply, and relays.
 * Denied or failed requests get a SOCKS_FAIL (91) reply.
 */
service_socks(Conn)
	Connection *Conn;
{	FILE *fc;
	unsigned char ibuf[16],obuf[16];
	int pc,ver,com,ch;
	char addr[64],host[1024],user[512];
	int ux,port;

	/* NOTE(review): fc is never fclose()d because that would also close
	   FromC; presumably harmless if the process ends with the session */
	fc = fdopen(FromC,"r");
	if( fc == NULL )
		return;
	setbuf(fc,NULL);
	pc = freadTIMEOUT(ibuf,8,1,fc);
	if( pc <= 0 )
		return;
	ver = ibuf[0];
	com = ibuf[1];
	port = (ibuf[2] << 8) | ibuf[3];
	sprintf(addr,"%d.%d.%d.%d",ibuf[4],ibuf[5],ibuf[6],ibuf[7]);
	host[0] = 0; /* host is logged below even if the lookup fails */
	gethostbyAddr(addr,host);

	if( com != SOCKS_CONNECT && com != SOCKS_BIND ){
		sv1log("#### ERROR: NON SOCKS CLIENT ?[%d]\n",com);
		return;
	}

	/* user id: NUL-terminated; give up on EOF/timeout or an overlong id */
	for( ux = 0; ; ux++ ){
		if( sizeof(user)-1 <= ux )
			return;
		ch = fgetcTIMEOUT(fc);
		if( ch == EOF )
			return;
		if( ch == 0 )
			break;
		user[ux] = ch;
	}
	user[ux] = 0;

sv1log("[SOCKS-serv] %d ver[%d] com[%d/%s] port[%d] host[%s][%s] user[%s]\n",
		pc,ver,com,scommand(com),port,addr,host,user);

	set_realserver(Conn,"tcprelay",addr,port);
	if( !service_permitted(Conn,"tcprelay") )
		goto failed;

	if( com == SOCKS_BIND ){
		int bsock,asock,bport,nready,socks[2],readyv[2],aport;
		char bhost[32],ahost[32];
		extern int ACC_TIMEOUT;

		hostIFfor(addr,bhost);
		bsock = server_open("SOCKS",bhost,0,1);
		if( bsock < 0 )
			goto failed;
		bport = gethostAddr(bsock,bhost);
		gethostNAME(bsock,bhost,NULL,&bport);

		/* first reply: the address the client hands to its peer */
		sendSocksReply(Conn,obuf,SOCKS_RESULT,bhost,bport);

		socks[0] = FromC;
		socks[1] = bsock;
		nready = PollIns(ACC_TIMEOUT*1000,2,socks,readyv);

		if( 0 < nready && 0 < readyv[1] ){
			asock = ACCEPT(bsock,0,-1,1);
			close(bsock);
			if( asock < 0 )
				goto failed; /* second reply reports the failure */

			ahost[0] = 0;
			aport = 0;
			gethostNAME(asock,ahost,NULL,&aport);
			/* second reply: who connected to the bound port */
			sendSocksReply(Conn,obuf,SOCKS_RESULT,ahost,aport);
			relay_svcl(Conn,FromC,ToC,asock,asock);
			close(asock);
		}else{
			close(bsock);
		}
	}else{
		char bhost[32];
		int bport;

		if( connect_to_serv(Conn,FromC,ToC,0) < 0 )
			goto failed;
		bhost[0] = 0;
		bport = 0;
		gethostNAME(FromS,bhost,NULL,&bport);
		sendSocksReply(Conn,obuf,SOCKS_RESULT,bhost,bport);
		relay_svcl(Conn,FromC,ToC,FromS,ToS);
		close(ToS);
		close(FromS);
	}
	return;

failed:
	sendSocksReply(Conn,obuf,SOCKS_FAIL,addr,port);
}

/*
 *	SOCKS CLIENT
 */
static char username[64];
extern int myownSOCKS;

typedef struct {
	char	*s_host;
	int	 s_port;
	int	 s_ver;
} SocksServer;
SocksServer sockservs[8] = {{SOCKS_DEFAULT_SERVER,SOCKS_DEFAULT_PORT}};
static int socks_inits;

/*
 * socks_init -- configure the SOCKS client side.
 *   name      not used here
 *   hostport  "host[:port]" of a SOCKS server to record in sockservs[];
 *             the port defaults to SOCKS_DEFAULT_PORT
 *   ns        name server passed to RES_ns(), or "-5" to mark this SOCKS
 *             server as version 5; NULL means DNS (SOCKS_DEFAULT_NS)
 *   dom       domain passed to RES_domain(); NULL means DNAME
 * The first call that names a server overwrites sockservs[0] (the
 * compiled-in default); later calls append, skipping an exact host:port
 * duplicate.  The last slot is never filled so the table stays
 * NULL-terminated for the loops that scan it.
 */
socks_init(name,hostport,ns,dom)
	char *name,*hostport,*ns,*dom;
{	int sx;
	int maxsx = sizeof(sockservs)/sizeof(sockservs[0]) - 1;
	char host[256],*host1;
	int port;
	int ver = 4;

	host1 = sockservs[0].s_host;
	if( host1 && host1[0] )
		myownSOCKS = 1;

	if( ns == NULL )
		ns = DNS;
	if( ns != NULL ){
		if( strcmp(ns,"-5") == 0 )
			ver = 5;
		else	RES_ns(ns);
	}

	if( dom == NULL )
		dom = DNAME;
	if( dom != NULL )
		RES_domain(dom);

	if( hostport == NULL || hostport[0] == 0 )
		return;

	port = SOCKS_DEFAULT_PORT;
	/* width 255 keeps the host field inside host[] (plus NUL) */
	if( sscanf(hostport,"%255[^:]:%d",host,&port) < 1 )
		return;

	socks_inits++;
	if( socks_inits == 1 )
		sx = 0;
	else{
		for( sx = 0; sx < maxsx && (host1 = sockservs[sx].s_host); sx++ ){
			if( host1[0] == 0 )
				break;
			if( strcmp(host1,host) == 0 && sockservs[sx].s_port == port )
				return;
		}
		if( maxsx <= sx ){
			sv1log("SOCKS server table full, ignored %s\n",hostport);
			return;
		}
	}

	sockservs[sx].s_host = strdup(host);
	sockservs[sx].s_port = port;
	sockservs[sx].s_ver = ver;
}
socks_addservers(){
	int sx;
	char *host;

	for( sx = 0; host = sockservs[sx].s_host; sx++ )
	if( *host )
		SOCKS_addserv("*",0,gethostaddr(host),sockservs[sx].s_port);
	return sx;
}

extern int CON_TIMEOUT;

/*
 * getResponse -- read one 8-byte SOCKS v4 reply from sock.
 *   command  request this reply answers (only used for logging and for
 *            the BIND/ACCEPT address fix-up below)
 *   rhost    if non-NULL, receives the 4 address bytes of the reply
 *   rport    if non-NULL, receives the port of the reply
 * Returns the reply code byte (SOCKS_RESULT on success), or SOCKS_FAIL
 * when nothing arrives within CON_TIMEOUT or the reply is short.
 */
static getResponse(sock,command,rhost,rport)
	unsigned char *rhost;
	int *rport;
{	unsigned char ibuf[128];
	int pc,ai;

	if( PollIn(sock,CON_TIMEOUT*1000) <= 0 ){
		sv1log("SOCKS response TIMEOUT (%d)\n",CON_TIMEOUT);
		return SOCKS_FAIL;
	}
	pc = readsTO(sock,ibuf,8,CON_TIMEOUT*1000);
	if( pc < 8 ){
		/* do not log or return bytes that were never received */
		sv1log("SOCKS response too short (%d)\n",pc);
		return SOCKS_FAIL;
	}

	sv1log("[SOCKS-clnt] %s %d ver[%d] stat[%d] host[%d.%d.%d.%d]:%d\n",
		scommand(command),
		pc,
		ibuf[0],ibuf[1],
		ibuf[4],ibuf[5],ibuf[6],ibuf[7],
		ibuf[2]<<8|ibuf[3]);

	if( rhost != NULL )
		for( ai = 0; ai < 4; ai++ )
			rhost[ai] = ibuf[4+ai];

	/* a zero address in a BIND/ACCEPT reply is replaced with the peer
	   address of sock (the SOCKS server itself) */
	if( rhost != NULL && rhost[0] == 0 )
	if( command == SOCKS_BIND || command == SOCKS_ACCEPT ){
		int iaddr;
		iaddr = peerHostport(sock,NULL);
		for( ai = 0; ai < 4; ai++ )
			rhost[ai] = (iaddr >> (3-ai)*8) & 0xFF;
	}

	if( rport != NULL )
		*rport = (ibuf[2] << 8) | ibuf[3];

	return ibuf[1];
}

/*
 * socks_start -- send one SOCKS request on the already-connected sock and
 * wait for its reply.  Version 5 is delegated to SOCKS_startV5(); otherwise
 * a v4 packet is built by makePacket() and the reply read by getResponse(),
 * which fills rhost/rport when they are non-NULL.
 * Returns 0 when the server answers SOCKS_RESULT, -1 otherwise (including
 * an unencodable request or a failed write).
 */
static socks_start(sock,ver,command,host,port,user,rhost,rport)
	char *host,*user,*rhost;
	int *rport;
{	unsigned char obuf[128];
	int pc,wcc,rep;

	if( ver == 5 )
		return SOCKS_startV5(sock,command,host,port,user,rhost,rport);

	/* makePacket() copies user after the 8-byte header with its NUL;
	   refuse ids that would not fit in obuf */
	if( 8 + strlen(user) + 1 > sizeof(obuf) )
		return -1;
	pc = makePacket(obuf,SOCKS_VERSION,command,host,port,user);
	if( pc < 0 )
		return -1;

	wcc = write(sock,obuf,pc);
	if( wcc != pc ){
		/* no point waiting CON_TIMEOUT for a reply to an unsent request */
		sv1log("SOCKS request write failed (%d/%d)\n",wcc,pc);
		return -1;
	}
	rep = getResponse(sock,command,rhost,rport);

	return rep == SOCKS_RESULT ? 0 : -1;
}

/* a destination that socks_connect() reached through a SOCKS server */
typedef struct {
	int	v_addr;	/* gethostintMin() of the host; 0 marks a free slot */
	int	v_port;	/* port first recorded for it; getViaSocks() ignores it */
} ViaSocks;

#define NVIAS 64	/* capacity; once full, setViaSocks() records nothing more */
static ViaSocks viaSocks[NVIAS];

/*
 * setViaSocks -- remember that host was reached via SOCKS.
 * The address goes into the first free viaSocks[] slot unless it is
 * already present; nothing is recorded once the table is full.
 */
setViaSocks(host,port)
	char *host;
{	ViaSocks *ent;
	unsigned int key = gethostintMin(host);

	for( ent = viaSocks; ent < &viaSocks[NVIAS]; ent++ ){
		if( ent->v_addr == 0 ){
			/* free slot: record this destination here */
			ent->v_addr = key;
			ent->v_port = port;
			return;
		}
		if( ent->v_addr == key )
			return;	/* already known */
	}
}
/*
 * getViaSocks -- 1 if host's address was recorded by setViaSocks(),
 * else 0.  The port argument is not consulted.
 */
getViaSocks(host,port)
	char *host;
{	ViaSocks *ent;
	unsigned int key = gethostintMin(host);

	/* slots are filled front to back, so the first free one ends the search */
	for( ent = viaSocks; ent < &viaSocks[NVIAS] && ent->v_addr != 0; ent++ )
		if( ent->v_addr == key )
			return 1;
	return 0;
}

/*
 * socks_connect -- issue a SOCKS CONNECT for host:port on sock.
 * Returns socks_start()'s result (0 on success, -1 on failure), or -1
 * straight away when host is NULL or port is 0.  Successful destinations
 * are remembered with setViaSocks().
 */
socks_connect(sock,ver,host,port,user)
	char *host,*user;
{	int rc;

	if( host == NULL || port == 0 )
		return -1;

	rc = socks_start(sock,ver,SOCKS_CONNECT,host,port,user,NULL,NULL);
	if( rc >= 0 )
		setViaSocks(host,port);
	return rc;
}

/* SOCKS version of the most recent socks_bind(); acceptViaSocks() uses it
   to read the second BIND reply in the matching format */
static int bind_ver;

/*
 * socks_bind -- issue a SOCKS BIND on sock, returning socks_start()'s
 * result (0 on success, -1 on failure); rhost/rport receive the bound
 * address from the first reply.
 */
socks_bind(sock,ver,host,port,user,rhost,rport)
	char *host,*user,*rhost;
	int *rport;
{	int rc;

	bind_ver = ver;
	rc = socks_start(sock,ver,SOCKS_BIND,host,port,user,rhost,rport);
	return rc;
}

/*
 * bindViaSocks -- try each configured SOCKS server in turn and issue a
 * BIND for dsthost:dstport on it.
 * Returns the socket to the first server that accepted the BIND (rhost and
 * rport then hold the bound address), or -1 when SOCKS is disabled
 * (myownSOCKS == 0) or no server succeeded.  The scan stops at the first
 * entry with an empty host or a zero port.
 */
bindViaSocks(dsthost,dstport,rhost,rport)
	char *dsthost,*rhost;
	int *rport;
{	int sx;
	char *shost;
	int sport;
	int ver;
	int sock;

	if( myownSOCKS == 0 )
		return -1;

	getUsernameCached(getuid(),username);
	sv1log("bindViaSocks(%s:%d)[%s]\n",dsthost,dstport,username);

	for( sx = 0; sockservs[sx].s_host; sx++ ){
		shost = sockservs[sx].s_host;
		sport = sockservs[sx].s_port;
		ver = sockservs[sx].s_ver;
		if( shost[0] == 0 || sport == 0 )
			break;
		sock = OpenServer("BindViaSocks","socks",shost,sport);
		if( sock < 0 )
			continue;	/* unreachable: nothing to close, try next */
		if( socks_bind(sock,ver,dsthost,dstport,username,rhost,rport) == 0 )
			return sock;
		close(sock);
	}
	return -1;
}

/*
 * acceptViaSocks -- wait for the second reply of an earlier socks_bind()
 * on sock, i.e. the notice that a peer connected to the bound port.
 * rhost/rport receive the peer's address.  Returns 0 on success, -1 on
 * failure; for a v5 bind the result of SOCKS_recvResponseV5() is returned.
 */
acceptViaSocks(sock,rhost,rport)
	char *rhost;
	int *rport;
{
	if( bind_ver == 5 )
		return SOCKS_recvResponseV5(sock,SOCKS_ACCEPT,rhost,rport);

	return getResponse(sock,SOCKS_ACCEPT,rhost,rport) == SOCKS_RESULT ? 0 : -1;
}

/*
 * connectViaSocks -- try each configured SOCKS server in turn and ask it
 * to CONNECT to dsthost:dstport.
 * Returns the socket to the first server whose CONNECT succeeded, or -1
 * when SOCKS is disabled (myownSOCKS == 0) or no server succeeded.  The
 * scan stops at the first entry with an empty host or a zero port.
 */
connectViaSocks(dsthost,dstport)
	char *dsthost;
{	int sx;
	char *host;
	int port;
	int ver;
	int sock;

	if( myownSOCKS == 0 )
		return -1;

	getUsernameCached(getuid(),username);
	for( sx = 0; sockservs[sx].s_host; sx++ ){
		host = sockservs[sx].s_host;
		port = sockservs[sx].s_port;
		ver = sockservs[sx].s_ver;
		if( host[0] == 0 || port == 0 )
			break;

		sock = OpenServer("ConnectViaSocks","socks",host,port);
		if( sock < 0 )
			continue;	/* unreachable: nothing to close, try next */
		if( socks_connect(sock,ver,dsthost,dstport,username) == 0 )
			return sock;
		close(sock);
	}
	return -1;
}

/*
 * ConnectViaSocks -- connect Conn's destination (DST_HOST:DST_PORT)
 * through a SOCKS server.  On success the socket is handed to
 * initConnected() together with relay_input.  Returns the socket, or a
 * negative value when no SOCKS server could make the connection.
 */
ConnectViaSocks(Conn,relay_input)
	Connection *Conn;
{	int sock;

	sock = connectViaSocks(DST_HOST,DST_PORT);
	if( sock < 0 )
		return sock;

	initConnected(Conn,sock,relay_input);
	return sock;
}
