/********************************************

secondary escape estimation (See)

this version does 3-order weighting by entropy

********************************************/

#include "see.h"
#include "config.h"
#include "context.h"

#include <crblib/intmath.h>
#include <crblib/mempool.h>

//-}{----------- defines ------------------------------------

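// the hash is built from the most significant field down:
// first 5 bits are esc/tot (2 bits of escC, then 3 bits of quantized totC)
// next 2 are order
// then 2 for # of chars in parent cntx
// then 1 flag bit for deterministic contexts (numSyms == 1)
// then up to four pairs of 2 from cntx (how many depends on order and escC)
// then 5 bits from order0 (the low 5 bits of cntx)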

// widths, in bits, of the index into each level of the See tables;
// order0 and order1 are direct-indexed arrays, order2 is the full hash width
#define order0_bits (9)		// <> these cutoffs & the hashes could all be tuned
#define order1_bits (16)
#define order2_bits (23)

// number of distinct index values at each level
#define order0_size (1<<order0_bits)
#define order1_size (1<<order1_bits)
#define order2_size (1<<order2_bits)

// See_GetState returns NULL (no See state) when the adjusted escape count
// is above MAX_SEE_ESCC or the adjusted total reaches MAX_SEE_TOTC
#define MAX_SEE_ESCC	(3)
#define MAX_SEE_TOTC	(64)

//-}{----------- types & internal protos ------------------

// One set of adaptive escape statistics.
// States form a chain from finest to coarsest through 'parent'
// (hashed state -> order1 slot -> order0 slot).
struct SeeState
{
	LinkNode LN;		// links the hashed states that hang off the same order1 slot (circular list, walked with LN_Next)
	SeeState * parent;	// next coarser state; See_AdjustState and See_GetStats follow this
	SeeState * child;	// for an order1 slot : current head of its list of hashed states
	uint hash,seen;		// hash : the index/hash this state was created for ; seen : number of See_AdjustState updates applied
	uint esc,tot;		// adaptive escape count and total count
};

// The whole See model : two direct-indexed tables plus a pool from which
// the finest-level (hashed) states are allocated on demand by See_GetState.
struct See
{
	SeeState order0[order0_size];	// coarsest states, indexed by the top order0_bits of the order1 index
	SeeState order1[order1_size];	// middle states, seeded and parented by See_Precondition
	MemPool * StatePool;			// allocator for hashed SeeState hunks
};

// seeds the order0/order1 tables and wires order1 -> order0 parent pointers
void See_Precondition(See *see);
// seeds ss->esc / ss->tot from the 5-bit esc/tot code (2 bits escC, 3 bits quantized totC)
void SeeState_StatsFromHash(SeeState *ss,uint FiveBits);

//-}{----------- functions ------------------------------------

/*
 * See_Create : allocate a See model, create the pool used for hashed
 *	states, and precondition the order0/order1 tables.
 *
 * returns the new model, or NULL if either allocation fails.
 * release it with See_Destroy.
 */
See *	See_Create(void)
{
See * see;

	see = new(*see);
	if ( ! see )
		return NULL;

	see->StatePool = MemPool_Create(sizeof(SeeState),4096,1024);
	if ( ! see->StatePool )
	{
		// don't hand back a model that would crash on the first See_GetState
		destroy(see);
		return NULL;
	}

	See_Precondition(see);

return see;
}

/*
 * See_Destroy : release the state pool and the model itself.
 *	passing NULL is allowed and does nothing.
 */
void	See_Destroy(See * see)
{
	if ( ! see )
		return;

	MemPool_Destroy(see->StatePool);
	destroy(see);
}

static void See_GetStats(See * see,SeeState *ss,uint * pEscC,uint * pTotC,uint inEsc,uint inTot)
{
uint e1,e2,e0,t1,t2,t0,h1,h2,h0,s2,s1,s0;
uint tot,esc;

	e2 = ss->esc; t2 = ss->tot; s2 = ss->seen; ss = ss->parent; assert(ss);
	e1 = ss->esc; t1 = ss->tot; s1 = ss->seen; ss = ss->parent; assert(ss);
	e0 = ss->esc; t0 = ss->tot; s0 = ss->seen;

	// spooky : ilog2round (accurate) gives 0.003 bpp over intlog2r (inaccurate for values > 256)

	h2 = (t2<<12)/(t2 * ilog2round(t2) - e2 * ilog2round(e2) - (t2-e2) * ilog2round(t2-e2) + 1);
	h1 = (t1<<12)/(t1 * ilog2round(t1) - e1 * ilog2round(e1) - (t1-e1) * ilog2round(t1-e1) + 1);
	h0 = (t0<<12)/(t0 * ilog2round(t0) - e0 * ilog2round(e0) - (t0-e0) * ilog2round(t0-e0) + 1);

	// give less weight to contexts with only the preconditioned stats
	// this helps a bit; *2,3, or 4 seems the best multiple :
	if ( s0 ) h0 <<= 2;
	if ( s1 ) h1 <<= 2;
	if ( s2 ) h2 <<= 2;

#if 0 // {
	tot = (h0 + h1 + h2);
	esc = (e2*h2/t2 + e1*h1/t1 + e0*h0/t0);
#else // }{
{
	uint ei,ti,hi;

	// also weight in the esc/tot from the context
	//	helps about 0.001 bpc on most files

	ei = inEsc; ti = inTot;
	hi = (ti<<12)/(ti * ilog2round(ti) - ei * ilog2round(ei) - (ti-ei) * ilog2round(ti-ei) + 1);

	tot = (h0 + h1 + h2 + hi);
	esc = (e2*h2/t2 + e1*h1/t1 + e0*h0/t0 + ei*hi/ti);
}
#endif //}

	while ( tot >= 16000 )
	{
		tot >>= 1;
		esc >>= 1;
	}
	if ( esc < 1 ) esc = 1;
	if ( esc >= tot ) tot = esc + 1;

	*pEscC = esc;
	*pTotC = tot;
}

/*
 * See_EncodeEscape : code the escape / no-escape decision.
 *
 *	with a See state : code with the blended stats from See_GetStats,
 *		then update the state chain with the outcome.
 *	without one (ss == NULL) : code directly from the context counts.
 *
 * must stay the exact mirror of See_DecodeEscape.
 */
void 	See_EncodeEscape(See *see,arithInfo * ari,SeeState * ss,uint escapeCount,uint totSymCount,bool escape)
{
uint seeEsc,seeTot;
uint sumCount;

	sumCount = escapeCount + totSymCount;

	if ( ! ss )
	{
		arithEncBit(ari,totSymCount,sumCount,escape);
		return;
	}

	See_GetStats(see,ss,&seeEsc,&seeTot,escapeCount,sumCount);
	// the bit is inverted on this path : see the matching negate in the decoder
	arithEncBit(ari,seeEsc,seeTot,!escape);
	See_AdjustState(see,ss,escape);
}

/*
 * See_DecodeEscape : decode the escape / no-escape decision written by
 *	See_EncodeEscape and return it.
 *
 *	with a See state : decode with the blended stats, undo the encoder's
 *		bit inversion, then update the state chain with the outcome.
 *	without one (ss == NULL) : decode directly from the context counts.
 */
bool 	See_DecodeEscape(See *see,arithInfo * ari,SeeState * ss,uint escapeCount,uint totSymCount)
{
uint seeEsc,seeTot;
uint sumCount;
bool bit,escape;

	sumCount = escapeCount + totSymCount;

	if ( ! ss )
		return arithDecBit(ari,totSymCount,sumCount);

	See_GetStats(see,ss,&seeEsc,&seeTot,escapeCount,sumCount);
	bit = arithDecBit(ari,seeEsc,seeTot);
	escape = ! bit;
	See_AdjustState(see,ss,escape);

return escape;
}

/*
 * See_GetEscapeP : the escape probability as a fixed-point integer,
 *	(esc << PPMZ2_IntProb_Shift) / tot.
 *
 *	esc/tot are the blended See stats when ss is given, otherwise
 *	escapeCount / (escapeCount + totSymCount).  no state is modified.
 */
uint	See_GetEscapeP(See *see,SeeState *ss,uint escapeCount,uint totSymCount)
{
uint sumCount;
uint num,den;

	sumCount = escapeCount + totSymCount;

	// default : straight from the context counts
	num = escapeCount;
	den = sumCount;

	if ( ss )
		See_GetStats(see,ss,&num,&den,escapeCount,sumCount);

return (num << PPMZ2_IntProb_Shift) / den;
}

/*
 * See_AdjustState : feed one coded outcome into a state and every
 *	ancestor reachable through ->parent.
 *
 *	escape    : esc and tot both grow (tot by an extra See_EscTot_ExtraInc)
 *	no escape : only tot grows, and a large esc is halved first
 *	afterwards both counts are halved if tot reached See_ScaleDown
 *
 *	see is not used.
 */
void	See_AdjustState(See *see,SeeState *ss,bool escape)
{
SeeState * cur;

	for(cur = ss; cur ; cur = cur->parent)
	{
		cur->seen ++;

		if ( ! escape )
		{
			// forget escapes very fast
			if ( cur->esc >= See_Esc_ScaleDown )
			{
				cur->esc = (cur->esc >> 1) + 1;
				cur->tot = (cur->tot >> 1) + 2;
			}

			cur->tot += See_Inc;
		}
		else
		{
			cur->esc += See_Inc;
			cur->tot += See_Inc + See_EscTot_ExtraInc;
		}

		// keep the counts bounded
		if ( cur->tot >= See_ScaleDown )
		{
			cur->esc = (cur->esc >> 1) + 1;
			cur->tot = (cur->tot >> 1) + 2;
			assert( cur->tot < See_ScaleDown );
		}
	}
}

//-}{----------- See_GetState ------------------------------------

/*
 * See_GetState : find (or create) the See state for a coding context.
 *
 *	escapeCount,totSymCount : counts of the context; totSymCount must be
 *		>= escapeCount and escapeCount >= 1 (asserted)
 *	cntx     : packed preceding bytes; bits are sampled from each byte
 *	order    : order of the context
 *	numSyms  : number of symbols in the context (>= 1); 1 sets the "det" bit
 *	context  : may be NULL; if it has a parent, the parent's upex.numSyms
 *		goes into the hash
 *
 * returns NULL when See is not used for this context :
 *	totSymCount == 0, counts beyond MAX_SEE_ESCC / MAX_SEE_TOTC,
 *	or the state pool could not supply a new state.
 * otherwise returns the state, whose parent chain is
 *	hashed state -> order1 slot -> order0 slot.
 *
 * the hash computation must not change : it selects which statistics
 * code each escape, so encoder and decoder (and old streams) depend on it.
 */
SeeState * See_GetState(See * see,uint escapeCount,uint totSymCount,ulong cntx,int order,int numSyms,const Context * context)
{
uint h,o1;
uint escC,totC;
uint z;
SeeState *ss,*o1ss;

	// do the hash;
	//	order
	//	escC,
	//	totC,
	//	index
	// use MPS count ?

	escC = escapeCount;
	totC = totSymCount;

	if ( totC == 0 )
		return NULL;

	assert( numSyms >= 1 );
	assert( escC >= 1 && totC >= escC );
	totC -= escC;
	escC --;

	if ( escC > MAX_SEE_ESCC || totC >= MAX_SEE_TOTC )
		return NULL;

	/***

	<> tune the see hash!

	we could make a much fancier hash; right now there are lots
	of redundant bits;

	eg. when order = 0 we dont use any bits for index, so we could
			use more for esc & tot
		conversely when esc is large, we should use fewer bits for order & cntx

	this all makes it much messier to do the precondition..

	****/

	// build the hash, most significant field first;
	// it is order2_bits wide only when all four context pairs are added.
	//
	// NOTE(review): the context pairs below are shifted in conditionally, so
	//	for order < 4 (or escC > 1 with order > 2) h is narrower than
	//	order2_bits.  the fixed shifts used further down to pick the order1
	//	slot and the 5-bit esc/tot seed then do not land on the fields they
	//	are named for.  left exactly as-is : changing it changes the
	//	coded bitstream.  confirm whether this is intended.
	h = 0;

	// 2 bits for the escC <= 3
	assert( escC <= 3 );
	h <<= 2;
	h |= escC;

	// 3 bits for totC
	h <<= 3;
	if      ( totC == 0 ) h |= 0;
	else if ( totC <= 1 ) h |= 1;
	else if ( totC <= 2 ) h |= 2;
	else if ( totC <= 4 ) h |= 3;
	else if ( totC <= 6 ) h |= 4;
	else if ( totC <= 9 ) h |= 5;
	else if ( totC <= 13) h |= 6;
	else                  h |= 7;

	// 2 bits for the order :
	h <<= 2;
	if ( escC >= 1 )
	{
		if ( order >= 3 )
			h |= 1;
	}
	else
	{
		z = order >> 1;
		if ( z > 3 ) z = 3;
		h |= z;
	}

	// 2 bits for num chars in parent
	// this was Malcolm's idea, and it helps *huge* (meaning about 0.02 bpc)
	//	paper2 -> 2.196 and trans -> 1.229 !!!
	// maybe I should use the actual-coded-parent by LOE instead of the direct parent?
	//	there is a problem there : the LOE decision depends on this!
	h <<= 2;
	if ( context && context->parent )
	{
		// @@ use full or upex for numsyms?
		z = context->parent->upex.numSyms;
		if ( z > 3 ) z = 3;
		h |= z;
	}

	// isdet bool ?
	//  helps a tiny bit (0.001) on files with lots of dets (trans,bib)
	//	doesn't affect others
	h <<= 1;
	if ( numSyms == 1 )
		h |= 1;

	// up to 8 bits from index : 2 bits from each the last 4 bytes :
	if ( order > 0 ) { h <<= 2; h |= ((cntx>>5)&0x3); }
	if ( order > 1 ) { h <<= 2; h |= ((cntx>>13)&0x3); }
	if ( escC <= 1 )
	{
		if ( order > 2 ) { h <<= 2; h |= ((cntx>>21)&0x3); }
		if ( order > 3 ) { h <<= 2; h |= ((cntx>>29)&0x3); }
	}

	// the bottom 5 bits of index[0] :
	h <<= 5;
	h |= cntx & 31;

	assert( h < order2_size );

	// the order1 slot that owns this hash
	o1 = (h >> (order2_bits - order1_bits));
	assert( o1 < order1_size );
	o1ss = &(see->order1[o1]);

	// search the slot's circular list of hashed states
	ss = o1ss->child;
	if ( ss )
	{
		for(;;)
		{
			if ( ss->hash == h )
			{
				// MTF : the hit becomes the head, so it is found first next time
				o1ss->child = ss;
				return ss;
			}
			ss = LN_Next(ss);
			if ( ss == o1ss->child )
				break;
		}
	}

	// not found : make a new state at the head of the list
	ss = MemPool_GetHunk(see->StatePool);
	if ( ! ss )
		return NULL;	// out of memory : caller falls back to plain counts

	ss->parent = o1ss;
	ss->child = NULL;	// hashed states are leaves
	ss->hash = h;
	ss->seen = 0;		// explicit, rather than relying on the pool to hand out cleared memory
	LN_Null(ss);
	if ( o1ss->child )
		LN_AddHead(o1ss->child,ss);
	o1ss->child = ss;

	// a fresh init seems better than inheriting esc/tot from the order1 slot
	SeeState_StatsFromHash(ss,h>>(order2_bits - 5));

return ss;
}

//-----------------------------------------------------------------------
//	Precondition ; also sets up parent pointers

/*
 * SeeState_StatsFromHash : seed ss->esc and ss->tot from a 5-bit
 *	esc/tot code.
 *
 *	FiveBits : (escC << 3) | totCH , with escC the 2-bit escape code and
 *		totCH the 3-bit quantized total
 *
 * only esc and tot are written; the other fields of ss are untouched.
 */
void SeeState_StatsFromHash(SeeState *ss,uint FiveBits)
{
// a representative totC for each 3-bit quantized code
static const uint totCFromCode[8] = { 0, 1, 2, 3, 5, 8, 11, 20 };
uint escapeCount,totSymCount;

	escapeCount = (FiveBits >> 3) + 1;
	totSymCount = totCFromCode[FiveBits & 7] + escapeCount;

	ss->esc = escapeCount * See_Init_Scale + See_Init_Esc;
	ss->tot = (escapeCount + totSymCount) * See_Init_Scale + See_Init_Tot;
}

/*
 * See_Precondition : seed every order1 and order0 slot and point each
 *	order1 slot at its order0 parent.
 *
 * the top 5 bits of an order1 index are the esc/tot code
 * ((escC << 3) | totCH); all slots sharing a code get the same seed.
 * for each slot this writes esc, tot, hash (and parent for order1);
 * seen, child and the list link are not touched here.
 */
void See_Precondition(See *see)
{
// a representative totC for each 3-bit quantized code
static const uint totCFromCode[8] = { 0, 1, 2, 3, 5, 8, 11, 20 };
const uint lowBits = order1_bits - 5;				// index bits below the esc/tot code
const uint toOrder0 = order1_bits - order0_bits;	// order1 index -> order0 index
uint code;

	fastRand_Seed(0);

	for(code = 0; code < 32; code++)
	{
	uint escapeCount = (code >> 3) + 1;
	uint totSymCount = totCFromCode[code & 7] + escapeCount;
	uint seedEsc = escapeCount * See_Init_Scale + See_Init_Esc;
	uint seedTot = (escapeCount + totSymCount) * See_Init_Scale + See_Init_Tot;
	uint first = code << lowBits;
	uint end = first + (1U << lowBits);
	uint h;

		// every order1 slot whose top 5 bits equal this code
		for(h = first; h < end; h++)
		{
		SeeState * o1ss;
		SeeState * o0ss;

			assert( h < order1_size );

			o1ss = &(see->order1[h]);
			o0ss = &(see->order0[ h >> toOrder0 ]);

			o1ss->esc = seedEsc;
			o1ss->tot = seedTot;
			o1ss->hash = h;
			o1ss->parent = o0ss;

			// many order1 slots share one order0 slot; it gets the same
			// values each time, since its index keeps the same top 5 bits
			o0ss->hash = (h >> toOrder0);
			o0ss->esc = seedEsc;
			o0ss->tot = seedTot;
		}
	}
}
