#include "main.h"

#include <Winsockx.h>

#include "net.h"

#ifndef _USE_WINSOCK

// Networking is compiled out in this configuration. The same globals and
// entry points are still provided so callers link, but every operation is a
// no-op and the globals keep their initial values.
int g_socketCreated = 0;		// never set: there is no connection socket
int g_listenSocketCreated = 0;	// never set: there is no listen socket
int g_hostPlayer = 1;			// never changed by these stubs

void Net_Disconnect(void) { /* nothing to disconnect */ }

void Net_Connect(int port, const char *ip)
{
	(void)port;
	(void)ip;
}

void Net_Listen(int port)
{
	(void)port;
}

void Net_Init(void) { /* nothing to initialize */ }

void Net_SendReinit(void) { /* no peer to notify */ }

void Net_SyncSend(void) { /* no peer to sync with */ }

void Net_SendCards(void) { /* no peer */ }

void Net_SendArrowChoice(void) { /* no peer */ }

void Net_SendElementGrid(void) { /* no peer */ }

#else
// Winsock startup data, filled in by WSAStartup() in Net_Init().
WSADATA g_wsaData;
// true once Net_Init() has started both XNet and Winsock; cleared by Net_Disconnect().
bool g_netValid = false;

// The game connection: created by Net_Connect() (client side) or handed a
// socket by OnAccept() (host side).
static CTriadSocket g_tSocket;
// Nonzero while g_tSocket is in use (set by Net_Connect()/OnAccept(),
// cleared by OnClose()/Net_Disconnect()).
int g_socketCreated = 0;

// Host-side socket bound and put into listen mode by Net_Listen(); closed
// again by OnAccept() once a peer connects.
static CTriadSocket g_tListenSocket;
// Nonzero while g_tListenSocket is open (set by Net_Listen(), cleared by
// OnAccept()/Net_Disconnect()).
int g_listenSocketCreated = 0;

// Read by OnReceive() to decide which side runs the arrow (who-goes-first) timer.
int g_hostPlayer = 1; //assume host unless told otherwise


//send data with a prepended header
void CTriadSocket::PacketSend(void *data, int type, int len)
{
	packetHeader_t pHeader;
	BYTE *d = (BYTE *)malloc(sizeof(packetHeader_t)+len);

	//memcpy the packet header over the first chunk followed by the data
	pHeader.dataLen = len;
	pHeader.pType = type;
	memcpy(d, &pHeader, sizeof(packetHeader_t));

	if (len > 0)
	{
		memcpy(d+sizeof(packetHeader_t), data, len);
	}

	//send it
	Send(d, sizeof(packetHeader_t)+len);

	free(d);
}

//send notify
void CTriadSocket::OnSend(int nErrorCode)
{
	CAsyncSocket::OnSend(nErrorCode);
}

//accept notify event
void CTriadSocket::OnAccept(int nErrorCode)
{
	CAsyncSocket::OnAccept(nErrorCode);
	Accept(g_tSocket);
	g_socketCreated = 1;
	Close();
	g_listenSocketCreated = 0;

	Game_Print("Connected!\n");
	Game_Init();
}

//close notify event
void CTriadSocket::OnClose(int nErrorCode)
{
	if (this == &g_tListenSocket)
	{
		return;
	}
	CAsyncSocket::OnClose(nErrorCode);
	g_hostPlayer = 1;
	//Close();

	g_socketCreated = 0;
	memset(&g_gameRules, 0, sizeof(g_gameRules));
	Game_Init();
	Game_Print("Disconnected.");
}

//receive notify event
void CTriadSocket::OnReceive(int nErrorCode)
{
	packetHeader_t d;

	CAsyncSocket::OnReceive(nErrorCode);
	if (Receive(&d, sizeof(packetHeader_t)) == SOCKET_ERROR)
	{
		Game_Print("Socket read error!");
		return;
	}

	switch(d.pType)
	{
	case PTYPE_REQUESTSTATE:
		PacketSend(&g_gameRules, PTYPE_SENDSTATE, sizeof(g_gameRules));
		break;
	case PTYPE_SENDSTATE:
		if (d.dataLen <= 0)
		{
			Game_Print("Got PTYPE_SENDSTATE with <= 0 datalen");
			return;
		}

		if (d.dataLen != sizeof(g_gameRules))
		{
			Game_Print("Rule state with invalid size.");
			return;
		}

		if (Receive(&g_gameRules, d.dataLen) == SOCKET_ERROR)
		{
			Game_Print("Error receiving rules.");
			return;
		}
		break;
	case PTYPE_SENDCARDS:
		{
			int i = 0;
			sendCard_t s[NUM_CARDS_PER_PLAYER];
			if (d.dataLen <= 0)
			{
				Game_Print("Got PTYPE_SENDCARDS with <= 0 datalen");
				return;
			}
			if (d.dataLen != sizeof(sendCard_t)*NUM_CARDS_PER_PLAYER)
			{
				Game_Print("Cards with invalid size.");
				return;
			}
			
			while (i < NUM_CARDS_PER_PLAYER)
			{
				if (Receive(&s[i], sizeof(sendCard_t)) == SOCKET_ERROR)
				{
					Game_Print("Error receiving cards.");
					return;
				}

				if (g_pl2Cards[i].predictionTake < g_Time)
				{
					g_pl2Cards[i].altTex = s[i].altTex;
				}
				assert(s[i].altTex >= 0);
				g_pl2Cards[i].pCard = &g_playingCards[s[i].pCardIndex];
				assert(s[i].pCardIndex >= 0);
				g_pl2Cards[i].gotoGridSpot = s[i].gotoGrid;
				if (g_pl2Cards[i].gotoGridSpot != g_pl2Cards[i].gridSpot &&
					g_pl2Cards[i].gridSpot != -1 &&
					g_pl2Cards[i].gotoGridSpot > 9) //CPOS_CARDP2_5
				{
					g_pl2Cards[i].gridTravelTime = g_Time + 1000;
					extern waveFile_t *g_sound_menuUse;
					S_PlayRawData(&g_sound_menuUse->data.data, g_sound_menuUse->data.size, BUFFER1);
				}
				g_pl2Cards[i].selected = s[i].selected;
				i++;
			}
			if (!g_gameState.p2ChosenCards &&
				g_gameState.chosenCards)
			{
				if (g_hostPlayer)
				{ //hacktastic
					g_gameState.arrowDeciding = g_Time + 2000;
				}
				else
				{
					g_gameState.arrowDeciding = g_Time + 9999999;
				}
			}
			g_gameState.p2ChosenCards = 1;
			Game_TallyCards();
		}
		break;
	case PTYPE_SENDCHOICE:
		{
			sendChoice_t s;
			if (d.dataLen <= 0)
			{
				Game_Print("Got PTYPE_SENDCHOICE with <= 0 datalen");
				return;
			}
			if (d.dataLen != sizeof(sendChoice_t))
			{
				Game_Print("Sendchoice with invalid size.");
				return;
			}

			if (Receive(&s, d.dataLen) == SOCKET_ERROR)
			{
				Game_Print("Error receiving choice.");
				return;
			}

			g_gameState.arrowDecided = s.decision;
			g_gameState.arrowDeciding = g_Time-1;
		}
		break;
	case PTYPE_SENDELEM:
		if (d.dataLen != sizeof(g_gameState.elementalGrid))
		{
			Game_Print("Got PTYPE_SENDELEM with invalid size.");
			return;
		}
		if (Receive(g_gameState.elementalGrid, d.dataLen) == SOCKET_ERROR)
		{
			Game_Print("Error receiving elem grid.");
			return;
		}
		break;
	case PTYPE_SENDGENERAL:
		{
			sendGeneral_t s;

			if (d.dataLen != sizeof(sendGeneral_t))
			{
				Game_Print("Got PTYPE_SENDGENERAL with invalid size.");
				return;
			}
			if (Receive(&s, d.dataLen) == SOCKET_ERROR)
			{
				Game_Print("Error receiving general state.");
				return;
			}

			g_gameState.p2Total = s.score;
			if (g_gameState.scorePredictionSet < g_Time)
			{
				g_gameState.drawP2Total = s.score;
			}
		}
		break;
	case PTYPE_RESTART:
		Game_Init();
		g_tSocket.PacketSend(NULL, PTYPE_RESTART2, 0);
		break;
	case PTYPE_RESTART2: //terrible terrible hack, I.. just want to play with Kirt!
		Game_Init();
		break;
	default:
		break;
	}
}


//if elemental game then send the grid info
//The whole elemental grid goes out as one PTYPE_SENDELEM packet. The
//receiver rejects it unless the size matches its own grid exactly.
void Net_SendElementGrid(void)
{
	const int gridBytes = sizeof(g_gameState.elementalGrid);
	g_tSocket.PacketSend(g_gameState.elementalGrid, PTYPE_SENDELEM, gridBytes);
}


//arrow decided who goes first
void Net_SendArrowChoice(void)
{
	sendChoice_t s;

	s.decision = g_gameState.arrowDecided;

	if (s.decision == 1)
	{ //each player is 1 to himself
		s.decision = 2;
	}
	else if (s.decision == 2)
	{
		s.decision = 1;
	}

	g_tSocket.PacketSend(&s, PTYPE_SENDCHOICE, sizeof(sendChoice_t));
}


//send special card message
void Net_SendCards(void)
{
	int i = 0;
	int j;
	drawCard_t *c;
	playingCard_t *p;
	sendCard_t s[NUM_CARDS_PER_PLAYER];

	if (!g_gameState.chosenCards)
	{ //don't send them til you're done choosing
		return;
	}

	while (i < NUM_CARDS_PER_PLAYER)
	{
		j = 0;
		c = &g_pl1Cards[i];

		s[i].altTex = !c->altTex;
		assert(s[i].altTex >= 0);
		if (c->gotoGridSpot <= 4) //CPOS_CARDP1_5
		{
			s[i].gotoGrid = c->gotoGridSpot+5;
		}
		else
		{
			s[i].gotoGrid = c->gotoGridSpot;
		}
		s[i].num = i;
		s[i].pCardIndex = 0;
		s[i].selected = c->selected;
		p = c->pCard;
		while (p > &g_playingCards[0])
		{ //the index is the offset count from the base address
			s[i].pCardIndex++;
			p--;
		}
		assert(s[i].pCardIndex >= 0 && s[i].pCardIndex < NUM_CARDS);
		i++;
	}

	g_tSocket.PacketSend(&s[0], PTYPE_SENDCARDS, sizeof(sendCard_t)*NUM_CARDS_PER_PLAYER);
}


//general info to make sure clients stay in sync
void Net_SyncSend(void)
{
	sendGeneral_t s;
	Net_SendCards();

	s.score = g_gameState.p1Total;

	g_tSocket.PacketSend(&s, PTYPE_SENDGENERAL, sizeof(sendGeneral_t));
}


//new round
//Asks the peer to restart. OnReceive() on the other end re-inits and
//answers with PTYPE_RESTART2. The packet carries no payload.
void Net_SendReinit(void)
{
	void *const noPayload = NULL;
	g_tSocket.PacketSend(noPayload, PTYPE_RESTART, 0);
}


//net init
void Net_Init(void)
{
	if (g_netValid)
	{ //already valid, nothing to do
		return;
	}

	XNetStartupParams xnsp;
	memset(&xnsp, 0, sizeof(xnsp));
	xnsp.cfgSizeOfStruct = sizeof(XNetStartupParams);
	xnsp.cfgFlags = XNET_STARTUP_BYPASS_SECURITY;

	if (XNetStartup(&xnsp))
	{ //failure
		return;
	}

	if (WSAStartup(MAKEWORD(1, 1), &g_wsaData))
	{ //failure
		return;
	}

	g_netValid = true;
}

//create socket
//Creates and binds `socket` on `port`. Returns 1 on success; on failure it
//reports the port and returns 0.
int Net_CreateSocket(CTriadSocket &socket, int port)
{
	const BOOL created = socket.Create(port);
	if (created)
	{
		return 1;
	}

	Game_Print("Could not create socket on port %i.", port);
	return 0;
}

//host - listen
void Net_Listen(int port)
{
	Net_Init();

	if (!g_gameRules.set)
	{
		Game_Print("Set up rules before hosting.\n");
		return;
	}

	if (g_socketCreated)
	{
		Game_Print("Already connected.\n");
		return;
	}

	if (g_listenSocketCreated)
	{
		Game_Print("Already listening.\n");
		return;
	}

	if (!Net_CreateSocket(g_tListenSocket, port))
	{
		return;
	}
	g_listenSocketCreated = 1;
	g_hostPlayer = 1;
	g_tListenSocket.Listen();
}

//client - connect
void Net_Connect(int port, const char *ip)
{
	Net_Init();

	if (g_socketCreated)
	{ //already connected
		Game_Print("Already connected.\n");
		return;
	}

	if (!Net_CreateSocket(g_tSocket, port+1))
	{
		return;
	}
	g_socketCreated = 1;

	Sleep(500);

	if (!g_tSocket.Connect(ip, port))
	{
		int err = WSAGetLastError();
		Net_Disconnect();
		Game_Print("Connection to %s on port %i failed. Error %i.", ip, port, err);
		return;
	}

	g_hostPlayer = 1;
	g_tSocket.PacketSend(NULL, PTYPE_REQUESTSTATE, 0);
	Game_Print("Connected!\n");
	Game_Init();
}

//close connection
//Closes the game connection and listen socket (if open), then shuts down
//Winsock and XNet and returns to the default host role.
//Closing g_tSocket runs CTriadSocket::OnClose(). That already clears the
//rules, re-inits the game and prints "Disconnected.", so none of that is
//repeated here (it used to run, and print, twice).
void Net_Disconnect(void)
{
	if (!g_netValid)
	{ //not valid, nothing to do
		return;
	}

	if (g_socketCreated)
	{
		g_socketCreated = 0;
		g_tSocket.Close(); //OnClose() resets the session
	}
	if (g_listenSocketCreated)
	{
		g_listenSocketCreated = 0;
		g_tListenSocket.Close();
	}

	//shut down in the reverse order of Net_Init()
	WSACleanup();
	XNetCleanup();
	g_netValid = false;
	g_hostPlayer = 1;
}

//XBOX
//I have created some wrapper functionality to try to emulate the
//behaviour of the Windows (MFC CAsyncSocket) functionality gltriad was using.
//
//Per-socket worker thread: runs SocketLogic() every 10ms until it reports
//that the socket is finished, then exits with code 0.
DWORD WINAPI SocketThread(VOID *pParameter)
{
	CAsyncSocket *const sock = static_cast<CAsyncSocket *>(pParameter);

	while (sock->SocketLogic())
	{
		Sleep(10);
	}
	return 0;
}

bool CAsyncSocket::SocketLogic(void)
{
	if (m_listening)
	{
		//SOCKET s = accept(m_socket, (sockaddr *)m_incoming, &m_inLen);
		SOCKET s = accept(m_socket, NULL, NULL);
		if (s != INVALID_SOCKET)
		{ //try accepting and stop listening
			m_acceptSocket = s;
			OnAccept(0);
			m_listening = false;
			m_transferActive = false;
			return false;
		}
	}
	if (m_transferActive)
	{
		while (1)
		{
			if (m_sendBufferSize > 0)
			{
				if (send(m_socket, m_sendBuffer, m_sendBufferSize, 0) != -1)
				{
					m_sendBufferSize = 0;
				}
			}

			//FIXME it would be good to always read a packet header
			//however this would break compatibility with standard
			//gltriad, so for now, hax.
			int r = recv(m_socket, (char *)&m_testBuffer, sizeof(m_testBuffer), 0);
			if (r == SOCKET_ERROR)
			{ //it's all over!
				Net_Disconnect();
				return false;
			}
			if (r == sizeof(m_testBuffer))
			{
				m_hasTestBuffer = true;
				OnReceive(0);
			}
			
			//else
			{
				break;
			}
		}
	}

	return true;
}

//Start with no socket, no worker thread, nothing pending and nothing queued.
CAsyncSocket::CAsyncSocket()
{
	//no OS resources yet
	m_netThread = 0;
	m_socket = INVALID_SOCKET;
	m_acceptSocket = INVALID_SOCKET;

	//idle: neither listening nor transferring
	m_transferActive = false;
	m_listening = false;

	//nothing buffered in either direction
	m_sendBufferSize = 0;
	m_hasTestBuffer = false;
}

//Release the socket and thread handle on destruction.
//Close() calls OnClose(); inside a destructor that resolves to
//CAsyncSocket::OnClose, never to a derived override.
CAsyncSocket::~CAsyncSocket()
{
	this->Close();
}

//Default event handlers: intentionally empty. CTriadSocket overrides the
//ones it needs.
void CAsyncSocket::OnSend(int nErrorCode)
{
	(void)nErrorCode;
}

void CAsyncSocket::OnAccept(int nErrorCode)
{
	(void)nErrorCode;
}

void CAsyncSocket::OnReceive(int nErrorCode)
{
	(void)nErrorCode;
}

void CAsyncSocket::OnClose(int nErrorCode)
{
	(void)nErrorCode;
}

//Hands the connection most recently accepted in SocketLogic() to
//rConnectedSocket and starts a worker thread for it.
//Returns FALSE if the thread could not be created.
//NOTE(review): m_incoming/m_inLen are never filled in within this file
//(accept() is called with NULL address arguments), so the optional address
//outputs presumably return stale data. The only caller here passes neither.
BOOL CAsyncSocket::Accept(CAsyncSocket& rConnectedSocket,
	SOCKADDR *lpSockAddr, int *lpSockAddrLen)
{
	if (lpSockAddr != NULL)
		memcpy(lpSockAddr, m_incoming, m_inLen);
	if (lpSockAddrLen != NULL)
		*lpSockAddrLen = m_inLen;

	rConnectedSocket.SetSocket(m_acceptSocket);

	HANDLE worker = CreateThread(NULL, 0, SocketThread, (void *)&rConnectedSocket, 0, NULL);
	if (worker == NULL)
		return FALSE;

	rConnectedSocket.SetNetThread(worker);
	rConnectedSocket.SetActive(true);
	m_transferActive = true;

	return TRUE;
}

//Notifies via OnClose(), then releases the socket and the worker-thread
//handle. Closing the handle does not by itself stop the thread.
void CAsyncSocket::Close()
{
	OnClose(0);

	if (m_socket != INVALID_SOCKET)
	{
		closesocket(m_socket);
	}
	m_socket = INVALID_SOCKET;

	if (m_netThread != 0)
	{
		CloseHandle(m_netThread);
	}
	m_netThread = 0;
}

//Creates a TCP socket, binds it to nSocketPort on any local address and
//starts its worker thread.
//nSocketType, lEvent and lpszSocketAddress are accepted for MFC
//compatibility but ignored: this wrapper always makes a TCP stream socket.
//On failure everything created so far is released again, so the object can
//be Create()d later (e.g. after "port in use"). Returns FALSE if a socket
//already exists.
BOOL CAsyncSocket::Create(UINT nSocketPort, int nSocketType, long lEvent,
	LPCTSTR lpszSocketAddress)
{
	if (m_socket != INVALID_SOCKET)
	{ //already exists
		return FALSE;
	}

	if (m_netThread)
	{ //drop a stale handle (this does not stop the old thread)
		CloseHandle(m_netThread);
		m_netThread = 0;
	}

	m_listening = false;
	m_transferActive = false;
	m_hasTestBuffer = false;
	m_sendBufferSize = 0;
	m_acceptSocket = INVALID_SOCKET;

	m_socket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
	if (m_socket == INVALID_SOCKET)
	{
		return FALSE;
	}

	sockaddr_in sockad;
	memset(&sockad, 0, sizeof(sockad));
	sockad.sin_family = AF_INET;
	sockad.sin_port = htons(nSocketPort);

	if (bind(m_socket, (sockaddr *)&sockad, sizeof(sockad)))
	{ //error binding - release the socket so a retry isn't "already exists"
		closesocket(m_socket);
		m_socket = INVALID_SOCKET;
		return FALSE;
	}

	//start the worker only once the socket is fully set up
	m_netThread = CreateThread(NULL, 0, SocketThread, this, 0, NULL);
	if (!m_netThread)
	{
		closesocket(m_socket);
		m_socket = INVALID_SOCKET;
		return FALSE;
	}
	return TRUE;
}

//Connects this (already created) socket to lpszHostAddress:nHostPort.
//The address is parsed with inet_addr(), so it must be a dotted-quad IPv4
//string; no name lookup is done. On success the worker thread begins
//transferring data.
BOOL CAsyncSocket::Connect(LPCTSTR lpszHostAddress, UINT nHostPort)
{
	sockaddr_in target;
	memset(&target, 0, sizeof(target));
	target.sin_family = AF_INET;
	target.sin_port = htons(nHostPort);
	target.sin_addr.s_addr = inet_addr(lpszHostAddress);

	const int rc = connect(m_socket, (sockaddr *)&target, sizeof(sockaddr));
	if (rc != 0)
	{ //error connecting..
		return FALSE;
	}

	m_transferActive = true;
	return TRUE;
}

//Puts the socket into listening mode. The worker thread then waits in
//accept() for a peer.
BOOL CAsyncSocket::Listen(int nConnectionBacklog)
{
	const bool ok = (listen(m_socket, nConnectionBacklog) == 0);
	if (ok)
	{
		m_listening = true;
	}
	return ok ? TRUE : FALSE;
}

int CAsyncSocket::Receive(void *lpBuf, int nBufLen, int nFlags)
{
	char *buf = (char *)lpBuf;

	if (m_hasTestBuffer)
	{
		m_hasTestBuffer = false;
		byte *testBuffer = (byte *)&m_testBuffer;
		int j = 0;
		while (nBufLen > 0 && j < sizeof(m_testBuffer))
		{
			nBufLen -= 1;
			*(buf+j) = *testBuffer;
			testBuffer++;
			j++;
		}
		if (nBufLen > 0)
		{
			return recv(m_socket, buf+j, nBufLen, nFlags);
		}
		return sizeof(m_testBuffer);
	}

	return recv(m_socket, buf, nBufLen, nFlags);
}

int CAsyncSocket::Send(const void *lpBuf, int nBufLen, int nFlags)
{
	//this would solve the problem of packets that are too big.
	//but it breaks compatibility with gltriad. =|
	/*
	int sent = 0;
	const char *buf = (const char *)lpBuf;

	basePacketHeader_t packets;

	packets.totalSize = nBufLen;
	packets.sequences = 0;
	while (sent < nBufLen)
	{
		packets.sequences++;
		sent += MAX_PACKET_SIZE;
	}

	if (send(m_socket, &packets, sizeof(basePacketHeader_t), 0) == -1)
	{
		return -1;
	}

	sent = 0;
	while (sent < nBufLen)
	{
		int size = MAX_PACKET_SIZE;
		if ((sent+size) > nBufLen)
		{
			size = nBufLen-sent;
		}
		if (send(m_socket, buf+sent, size, nFlags) == -1)
		{
			return -1;
		}

		sent += MAX_PACKET_SIZE;
	}
	*/

	/*
	int s = send(m_socket, (const char *)lpBuf, nBufLen, nFlags);
	int err;
	while (s == SOCKET_ERROR)
	{
		err = WSAGetLastError();
	}
	return s;
	*/
#if 1
	if (m_sendBufferSize+nBufLen > MAX_SEND_BUFFER_SIZE)
	{
		return -1;
	}

	memcpy(&m_sendBuffer[m_sendBufferSize], lpBuf, nBufLen);
	m_sendBufferSize += nBufLen;
#else
	while (send(m_socket, (const char *)lpBuf, nBufLen, nFlags) == -1)
	{
		Sleep(2);
	}
#endif

	return nBufLen;
}
#endif //_USE_WINSOCK
