#include "context.h"
#include "config.h"
#include <crblib/mempool.h>

/*************

Context : just track seen symbols, counts & escapes; no coding here

the only fudgy bit here is the Update_See (See_AdjustState in ContextData_Update) : we actually do this for
	the *2nd* time on contexts we code from, since EncodeEscape does an additional Adjust for us.
	However, the states we update here may not be the same ones we coded with, since the actual coding & LOE
		use excluded states.

*************/

//-----------------------------------------
// static pools
//	Context hunks and ContextNode hunks each come from their own AutoPool;
//	the int beside each pool is handed to every AutoPool_ call along with it
//	(presumably the pool's outstanding-hunk count - confirm in crblib/mempool)

static MemPool * ContextPool          = NULL;
static int       ContextPoolCount     = 0;

static MemPool * ContextNodePool      = NULL;
static int       ContextNodePoolCount = 0;
//-----------------------------------------

/*
 Context_Create : allocate a Context from the context pool and link it under 'parent'.

	order is parent->order + 1, or 0 for a root (parent == NULL).
	The new context becomes parent->child if the parent had no children yet,
	otherwise it is LN_AddHead'ed into the ring that parent->child starts.
*/
Context * Context_Create(Context * parent,uint index)
{
Context * fresh;

	fresh = AutoPool_GetHunk(&ContextPool,&ContextPoolCount,sizeof(Context));
	assert(fresh);

	// the hunk is expected to arrive zeroed ("all is zero"), so every count and list is empty
	assert(fresh->upex.syms == NULL);

	// reset the sibling link and the LRU link
	LN_Null( fresh );
	LN_Null(&(fresh->LRU));

	fresh->parent = parent;
	fresh->index  = index;
	fresh->order  = parent ? (parent->order + 1) : 0;

	if ( ! parent )
		return fresh;

	// first child starts the ring; later ones join it
	if ( parent->child == NULL )
		parent->child = fresh;
	else
		LN_AddHead(parent->child,fresh);

return fresh;
}

/*
 Context_Find : look up the child of 'parent' whose index equals 'index'.

	Returns the child, or NULL if there is none. 'parent' must not be NULL.
	On a hit the child becomes parent->child, so it is checked first next time.
*/
Context * Context_Find(Context *parent,ulong index)
{
Context *first,*walk;

	first = parent->child;
	if ( first == NULL )
		return NULL;

	walk = first;
	do
	{
		if ( walk->index == index )
		{
			// move-to-front for free : the children form a circular list,
			//	so re-pointing the start at the hit is all it takes
			parent->child = walk;
			return walk;
		}
		walk = LN_Next(walk);
	} while ( walk != first );

return NULL;
}

/*
 Context_DestroyAllContexts : tear down both static pools in one go.

	No per-context lists are walked here; everything handed out by the
	context pool and the node pool goes with AutoPool_Destroy.
*/
void Context_DestroyAllContexts(void)
{
	// contexts first, then their symbol nodes
	AutoPool_Destroy(&ContextPool,&ContextPoolCount);
	AutoPool_Destroy(&ContextNodePool,&ContextNodePoolCount);
}

/*
 Context_Destroy : unlink one context and return it, plus all its symbol nodes, to the pools.

	A NULL 'cntx' is ignored. Both the 'full' and 'upex' symbol lists are freed.
	NOTE(review): parent->child is not re-pointed and children are not destroyed here;
	if parent->child == cntx the caller must fix it up - confirm the callers do.
*/
void Context_Destroy(Context * cntx)
{
ContextNode *heads[2];
int list;

	if ( cntx == NULL )
		return;

	// unhook from the sibling ring and from the LRU ring
	LN_Cut( cntx );
	LN_Cut(&(cntx->LRU));

	heads[0] = cntx->full.syms;
	heads[1] = cntx->upex.syms;

	for(list=0;list<2;list++)
	{
	ContextNode *node = heads[list];

		while ( node != NULL )
		{
		// grab the link before the hunk goes back to the pool
		ContextNode *after = node->next;

			AutoPool_FreeHunk(&ContextNodePool,&ContextNodePoolCount,node);
			node = after;
		}
	}

	AutoPool_FreeHunk(&ContextPool,&ContextPoolCount,cntx);
}

/*
 ContextData_Update : count one occurrence of 'sym' in 'cntx'.

	Counts are halved first if the total has reached Context_CharCountScaleDown.
	A symbol already present gains Context_SymInc. If it was still at or below the novel
	count, it also gains Context_SymInc - Context_SymIncNovel and the escape count drops by
	Context_EscpInc (floored at 1).
	A new symbol is pushed on the front of the list with count Context_SymIncNovel; the escape
	count grows by Context_EscpInc while below Context_Escape_Max.

	With a 'see', the previous SEE state is adjusted with whether the symbol was new here, and a
	fresh state is fetched for next time ('order', 'index' and 'context' are only used for that
	lookup). Without a 'see', seeState is cleared.
*/
void ContextData_Update(ContextData * cntx ,int sym, int order, ulong index, See *see,Context *context)
{
ContextNode *hit;
bool escape;

	// <> could track the 'if I had coded' entropy here

	if ( cntx->totSymCount >= Context_CharCountScaleDown )
		ContextData_Halve(cntx);

	cntx->lastSym = sym;

	for(hit = cntx->syms; hit != NULL; hit = hit->next)
	{
		if ( hit->sym == sym )
			break;
	}

	if ( hit != NULL )
	{
		// seen before : coding here would not have escaped
		escape = false;

		if ( hit->count <= Context_SymIncNovel )
		{
			// a novel symbol repeats : promote it and take back its escape credit
			cntx->escapeCount -= Context_EscpInc;
			hit->count += Context_SymInc - Context_SymIncNovel;
			cntx->totSymCount += Context_SymInc - Context_SymIncNovel;
			if ( cntx->escapeCount <= 0 )
				cntx->escapeCount = 1;
		}

		hit->count += Context_SymInc;
		cntx->totSymCount += Context_SymInc;
	}
	else
	{
		// new here : coding would have needed an escape
		escape = true;

		hit = AutoPool_GetHunk(&ContextNodePool,&ContextNodePoolCount,sizeof(ContextNode));
		assert(hit);

		hit->next = cntx->syms;
		cntx->syms = hit;

		hit->sym = sym;
		hit->count = Context_SymIncNovel;
		cntx->totSymCount += Context_SymIncNovel;

		if ( cntx->escapeCount < Context_Escape_Max )
			cntx->escapeCount += Context_EscpInc;

		cntx->numSyms++;
	}

	cntx->largestCount = max(cntx->largestCount,hit->count);

	if ( ! see )
	{
		cntx->seeState = NULL;
		return;
	}

	// this may or may not be the state we coded from, because of exclusions & such
	See_AdjustState(see,cntx->seeState,escape);
	cntx->seeState = See_GetState(see,cntx->escapeCount,cntx->totSymCount,index,order,cntx->numSyms,context);
}

/*
 Context_Update : record 'sym' in a context after it has been coded.

	The 'full' stats always count the symbol and never carry a SEE state (see is passed as NULL).
	The 'upex' stats count it only when this context's order is at or above 'codedOrder',
	and they also drive the SEE state updates.
*/
void Context_Update(Context * cntx ,int sym, ulong index, See *see,int codedOrder)
{
	// trie sanity : parent one order below, child (below PPMZ2_Order) one order above,
	//	and the next sibling in the ring at our own order
	assert( cntx->parent == NULL || cntx->parent->order == cntx->order - 1 );
	assert( cntx->order == PPMZ2_Order || cntx->child == NULL || cntx->child->order == cntx->order + 1 );
	assert( ((Context *)LN_Next(cntx))->order == cntx->order );

	ContextData_Update(&(cntx->full),sym,cntx->order,index,NULL,cntx);

	if ( codedOrder <= cntx->order )
		ContextData_Update(&(cntx->upex),sym,cntx->order,index,see,cntx);

	// full normally dominates upex, but independent halvings can break that
	assert( cntx->full.totSymCount >= cntx->upex.totSymCount || cntx->full.totSymCount >= ((Context_CharCountScaleDown>>1)-256) );
	assert( cntx->full.numSyms >= cntx->upex.numSyms || cntx->full.totSymCount >= ((Context_CharCountScaleDown>>1)-256));
}

/*
 Context_Halve : scale down both statistic sets of a context ('full' then 'upex').
*/
void Context_Halve(Context * cntx)
{
	ContextData_Halve( &cntx->full );
	ContextData_Halve( &cntx->upex );
}

/*
 ContextData_Halve : halve every symbol count in 'cntx'.

	Nodes whose count reaches 0 are unlinked and returned to the node pool.
	Survivors are raised to at least Context_SymIncNovel + 1, so none looks novel afterwards.
	totSymCount, numSyms and largestCount are rebuilt from the survivors, and the
	escape count becomes escapeCount/2 + 1.
*/
void ContextData_Halve(ContextData * cntx)
{
ContextNode **link;

	cntx->totSymCount = 0;
	cntx->numSyms = 0;
	cntx->largestCount = 0;

	link = &(cntx->syms);
	while ( *link != NULL )
	{
	ContextNode *cur = *link;

		cur->count >>= 1;

		if ( cur->count == 0 )
		{
			// decayed away : splice it out; 'link' stays put for the next node
			*link = cur->next;
			AutoPool_FreeHunk(&ContextNodePool,&ContextNodePoolCount,cur);
			continue;
		}

		if ( cur->count <= Context_SymIncNovel )
			cur->count = Context_SymIncNovel + 1;

		cntx->totSymCount += cur->count;
		cntx->numSyms++;
		cntx->largestCount = max(cntx->largestCount,cur->count);

		link = &(cur->next);
	}

	cntx->escapeCount = ((cntx->escapeCount)>>1) + 1;
}

/*
 ContextNode_Create : take one ContextNode hunk from the node pool.
	No fields are set here; the caller fills them in.
*/
ContextNode * ContextNode_Create(void)
{
ContextNode * node;

	node = AutoPool_GetHunk(&ContextNodePool,&ContextNodePoolCount,sizeof(ContextNode));

return node;
}

/*
 ContextNode_Destroy : give a node back to the node pool.
	'n' is passed straight through; there is no NULL check here.
*/
void ContextNode_Destroy(ContextNode * n)
{
	AutoPool_FreeHunk( &ContextNodePool, &ContextNodePoolCount, n );
}

/*
 ContextData_GetExcludedInfo : totals of 'cd' with the symbols in 'exc' removed.

	With nothing excluded, the cached totSymCount / largestCount / escapeCount are reported.
	Otherwise they are rebuilt from the symbol list. Total and largest are taken over the
	unexcluded symbols. The escape count starts at Context_Excluded_Escape_Init, gains
	Context_Excluded_Escape_Inc per unexcluded novel symbol (count <= Context_SymIncNovel)
	and Context_Excluded_Escape_ExcludedInc per excluded novel symbol, and is finally shifted
	down by Context_Excluded_Escape_Shift.
*/
void ContextData_GetExcludedInfo(ContextData *cd,Exclude *exc,int *pTotCount,int *pLargestCount,int *pEscapeCount)
{
int sum,biggest,esc;
ContextNode * node;

	if ( Exclude_IsEmpty(exc) )
	{
		sum = cd->totSymCount;
		biggest = cd->largestCount;
		esc = cd->escapeCount;
	}
	else
	{
		// tuning notes :
		//	excluded novel symbols still count toward escapes, just not as hard;
		//	the constants are rigged so that 1 exc -> 1, 2 -> 2, 3 -> 2, etc., so low-escape
		//	contexts keep the same counts while low orders get much lower escape counts.
		//	helped paper2 2.193 -> 2.188 bpc; tweaking the constants is worth ~0.001 bpc
		sum = 0;
		biggest = 0;
		esc = Context_Excluded_Escape_Init;

		for(node = cd->syms; node != NULL; node = node->next)
		{
		int novel = ( node->count <= Context_SymIncNovel );

			if ( isExcluded(exc,node->sym) )
			{
				if ( novel )
					esc += Context_Excluded_Escape_ExcludedInc;
				continue;
			}

			sum += node->count;
			if ( biggest < node->count )
				biggest = node->count;
			if ( novel )
				esc += Context_Excluded_Escape_Inc;
		}

		esc >>= Context_Excluded_Escape_Shift;
	}

	*pTotCount = sum;
	*pLargestCount = biggest;
	*pEscapeCount = esc;
}

/*
 Context_ChooseFull : decide whether to code from c->full instead of c->upex.

	Currently DISABLED : the first statement returns false unconditionally, so callers
	always get false. Everything after it is unreachable and is kept only so the heuristic
	can be switched back on (the note on the return records the 'text100k' result the
	author measured for this option).

	When enabled, the heuristic would :
		- reject 'full' if it has more than one symbol (non-deterministic) or a total below 16;
		- reject 'full' if nothing in it survives the exclusions in 'exc';
		- accept 'full' if nothing in 'upex' survives the exclusions;
		- otherwise rate each as (PPMZ2_IntProb_One - escapeP) * largest / total, using
		  'cntx' as the SEE index, and pick 'full' only if its rating is strictly higher.
*/
bool Context_ChooseFull(Context *c,Exclude *exc,See *see,ulong cntx)
{
int fLargestCount,fTotCount,fEscapeCount,fRating;
int uLargestCount,uTotCount,uEscapeCount,uRating;
SeeState * ss;
	
	return false; // 'text100k' : full helps 28975 to 28967

	// ---- unreachable while the return above is in place ----

	// don't use non-det fulls (more than one distinct symbol seen)
//	if ( fLargestCount != fTotCount )
	if ( c->full.numSyms > 1 )
		return false;
	if ( c->full.totSymCount < 16 )
		return false;

	ContextData_GetExcludedInfo(&(c->full),exc,&fTotCount,&fLargestCount,&fEscapeCount);
	if ( fTotCount == 0 )
		return false;

	ContextData_GetExcludedInfo(&(c->upex),exc,&uTotCount,&uLargestCount,&uEscapeCount);
	if ( uTotCount == 0 )
		return true;

	// rating = probability of not escaping * share of the most probable symbol
	if ( fTotCount < fEscapeCount ) ss = NULL;
	else							ss = See_GetState(see,fEscapeCount,fTotCount,cntx,c->order,c->full.numSyms,c);

	fRating = ((PPMZ2_IntProb_One - See_GetEscapeP(see,ss,fEscapeCount,fTotCount))
				* fLargestCount ) / fTotCount;
				
	if ( uTotCount < uEscapeCount ) ss = NULL;
	else							ss = See_GetState(see,uEscapeCount,uTotCount,cntx,c->order,c->upex.numSyms,c);

	uRating = ((PPMZ2_IntProb_One - See_GetEscapeP(see,ss,uEscapeCount,uTotCount))
				* uLargestCount ) / uTotCount;

	// and even some more :
//	fRating >>= 1;

	// ties go to upex
	if ( uRating >= fRating )
		return false;

return true;
}

/*
 ContextData_AddCounts : add 'count' occurrences of 'sym' to 'cntx' in one step.

	A zero count is ignored. A missing symbol gets a new node at the front of the list
	(numSyms grows). totSymCount, largestCount and lastSym are updated; escapeCount and
	seeState are not touched, and no halving is done here.
*/
void ContextData_AddCounts(ContextData *cntx,uint sym,uint count)
{
ContextNode * cn;

	if ( ! count )
		return;

	for(cn = cntx->syms; cn != NULL && cn->sym != sym; cn = cn->next)
		;

	if ( cn != NULL )
	{
		cn->count += count;
	}
	else
	{
		cn = AutoPool_GetHunk(&ContextNodePool,&ContextNodePoolCount,sizeof(ContextNode));
		assert(cn);

		cn->next = cntx->syms;
		cntx->syms = cn;

		cn->sym = sym;
		cn->count = count;
		cntx->numSyms ++;
	}

	cntx->totSymCount += count;
	cntx->largestCount = max(cntx->largestCount,cn->count);
	cntx->lastSym = sym;
}

/*
 ContextData_Normalize : halve 'cntx' until its total is below Context_CharCountScaleDown.

	ContextData_Halve never lets a surviving count fall below Context_SymIncNovel + 1.
	A context holding enough symbols can therefore reach a state that halving no longer
	changes, with a total still at or above the limit; the old loop then never terminated.
	We stop as soon as a halving leaves the total, the symbol count and the largest count
	all unchanged. With consistent stats that only happens once every count sits at
	Context_SymIncNovel + 1, where further passes cannot make progress. In that case the
	total is left over the limit.
*/
void ContextData_Normalize(ContextData *cntx)
{
	while ( cntx->totSymCount >= Context_CharCountScaleDown )
	{
	int prevTot     = cntx->totSymCount;
	int prevNumSyms = cntx->numSyms;
	int prevLargest = cntx->largestCount;

		ContextData_Halve(cntx);

		if ( cntx->totSymCount == prevTot &&
			 cntx->numSyms == prevNumSyms &&
			 cntx->largestCount == prevLargest )
			break;	// fixed point : halving again would change nothing
	}
}

#include <crblib/intmath.h>

/*
 ContextData_GetCodeLen : estimate the code length of 'sym' without coding anything.

	Starts in 'cd' (statistics belonging to 'cntx', at order 'order'). At each context it either
	charges "no escape + symbol" and returns, or charges an escape and moves to cntx->parent
	with order - 1. At order -1 the symbol is priced as uniform over the 256 byte values that
	are not excluded.

	Costs are summed as -flog2x16(p); presumably flog2x16 is log2 in fixed point scaled by 16
	- confirm against crblib/intmath.h.

	'exc' is scratch : it is cleared on entry, and every unexcluded symbol of each context we
	look through is marked excluded, so lower orders only price what remains.
	Escape probabilities always come from the current context's upex stats, even when 'cd' is
	a different ContextData (as on the first pass). This function writes no context counts
	itself.
*/
uint ContextData_GetCodeLen(ContextData *cd,uint sym,int order,ulong index,See * see,
										Context *cntx,Exclude *exc)
{
int len;

	len = 0;

	Exclude_Clear(exc);

	for(;;)
	{
	int low,high,tot,escTot;
	int largest,escape;
	ContextNode * n;
	SeeState *ss;
	int iescP;
	float escP;

		// cd != &(cntx->upex) on the first time through
		//  for lower orders, they are equal

		assert( ! isExcluded(exc,sym) );

		// empty context : nothing could be coded here, move down without charging an escape
		if ( cd->totSymCount == 0 )
			goto escaped;

		ContextData_GetExcludedInfo(cd,exc,&tot,&largest,&escape);

		if ( tot == 0 ) // no chars unexcluded
			goto escaped;

		// escape stats come from upex; 'largest' is overwritten here but not used below
		ContextData_GetExcludedInfo(&(cntx->upex),exc,&escTot,&largest,&escape);

		// find sym's [low,high) range among the unexcluded counts,
		//	excluding every symbol we pass for the lower orders
		low = high = 0;

		for(n = cd->syms;n;n=n->next)
		{
			assert( n->count > 0 );
			if ( ! isExcluded(exc,n->sym) )
			{
				if ( n->sym == sym )
				{
					// found it :
					high = low + n->count;
					assert( high > 0 );
				}
				else if ( high == 0 )
				{
					low += n->count;
				}

				setExcluded(exc,n->sym);
			}
		}

		// no SEE state when the escape count exceeds the total
		if ( escape > escTot ) ss = NULL;
		else ss = See_GetState((See *)see,escape,escTot,index,order,cntx->upex.numSyms,cntx);

		iescP = See_GetEscapeP((See *)see,ss,escape,escTot);
		escP = (float) iescP / PPMZ2_IntProb_One;

		if ( high )
		{
		float symP;
			// found it : pay for "no escape" plus the symbol's share of tot
			len -= flog2x16( 1.0f - escP );
			symP = (high - low)/(float)tot;
			len -= flog2x16( symP );
			return len;
		}
		else
		{
			// not here : pay for the escape
			len -= flog2x16( escP );
			goto escaped;
		}

	escaped:

		order --;
		cntx = cntx->parent;

		if ( order == -1 )
		{
		float symP;
		int i;

			// order -1 : uniform over the byte values not yet excluded
			// count # of syms not excluded :
			tot = 0;
			for(i=0;i<256;i++)
			{
				if ( ! isExcluded(exc,i) )
					tot ++;
			}

			assert( tot > 0 );
			// found it
			symP = 1/(float)tot;
			len -= flog2x16( symP );

			return len;
		}

		assert( order >= 0 );
		assert( cntx );

		// below the first pass we always read the update-excluded stats
		cd = &(cntx->upex);
	}
}
