/* $Id: ac.c,v 1.26 1998/03/16 20:35:59 hatrack Exp $ */

/****************************************************************************/
/*   MPEG4 Visual Texture Coding (VTC) Mode Software                        */
/*                                                                          */
/*   This software was developed by                                         */
/*   Sarnoff Coporation                   and    Texas Instruments          */
/*   Iraj Sodagar   (iraj@sarnoff.com)           Jie Liang (liang@ti.com)   */
/*   Hung-Ju Lee    (hjlee@sarnoff.com)                                     */
/*   Paul Hatrack   (hatrack@sarnoff.com)                                   */
/*   Shipeng Li     (shipeng@sarnoff.com)                                   */
/*   Bing-Bing Chai (bchai@sarnoff.com)                                     */
/*                                                                          */
/* In the course of development of the MPEG-4 standard. This software       */
/* module is an implementation of a part of one or more MPEG-4 tools as     */
/* specified by the MPEG-4 standard.                                        */
/*                                                                          */
/* The copyright of this software belongs to ISO/IEC. ISO/IEC gives use     */
/* of the MPEG-4 standard free license to use this  software module or      */
/* modifications thereof for hardware or software products claiming         */
/* conformance to the MPEG-4 standard.                                      */
/*                                                                          */
/* Those intending to use this software module in hardware or software      */
/* products are advised that use may infringe existing  patents. The        */
/* original developers of this software module and their companies, the     */
/* subsequent editors and their companies, and ISO/IEC have no liability    */
/* and ISO/IEC have no liability for use of this software module or         */
/* modification thereof in an implementation.                               */
/*                                                                          */
/* Permission is granted to MPEG memebers to use, copy, modify,             */
/* and distribute the software modules ( or portions thereof )              */
/* for standardization activity within ISO/IEC JTC1/SC29/WG11.              */
/*                                                                          */
/* Copyright (C) 1998  Sarnoff Coporation and Texas Instruments             */ 
/****************************************************************************/

/************************************************************/
/*     Sarnoff Very Low Bit Rate Still Image Coder          */
/*     Copyright 1995, 1996, 1997, 1998 Sarnoff Corporation */
/************************************************************/

/************************************************************/
/*  Filename: ac.c                                          */
/*  Author: Bing-Bing Chai                                  */
/*  Date Modified: January 6, 1998                          */
/*                                                          */
/*  Descriptions:                                           */
/*    This file contains routines for integer arithmetic    */
/*    coding, which is based on the ac.c file from the SOL  */
/*    package. The original ac.c was obtained from public   */
/*    domain.                                               */
/*                                                          */
/*    The following routines are modified or created for    */
/*    the latest VTC package:                               */
/*      static Void output_bit ()                           */
/*      Void ac_encoder_init ()                             */
/*      int ac_encoder_done ()                              */
/*      static int input_bit ()                             */
/*      Void ac_decoder_done ()                             */
/*                                                          */
/************************************************************/

#include <stdio.h>
#include <stdlib.h>
#include "dataStruct.h"
#include "ac.h" 
#include "bitpack.h"
#include "errorHandler.h"
#include "msg.h"


#define Code_value_bits 16
#define Malloc(size) (malloc(size))
#define Calloc(nelm, elsz) (calloc(nelm, elsz))

/* Coder interval bounds: low/high/value live in [0, Top_value], i.e. in
   Code_value_bits-bit integers; the quarter points drive renormalisation. */
#define Top_value (((LInt)1<<Code_value_bits)-1)
#define First_qtr (Top_value/4+1)
#define Half	  (2*First_qtr)
#define Third_qtr (3*First_qtr)


/* static function prototypes */
static Void output_bit (ac_encoder *ace, Int bit);
static Void bit_plus_follow (ac_encoder *ace, Int bit);
static Void update_model (ac_model *acm0, ac_model *acm1, Int sym);
static Int input_bit (ac_decoder *acd);

/* Length of the current run of consecutive 0 bits, used by output_bit()
   (to insert stuffing bits) and input_bit() (to remove them).  Being a
   single file-scope variable, presumably only one encoder or decoder
   should be active at a time.  Explicitly typed: implicit int is not
   valid C99. */
static Int zero_accum=0;

/************************************************************************/
/*              Error Checking and Handling Macros                      */
/************************************************************************/

/* error(m): flush stdout, print "FILE:LINE: error: <m>" to stderr and
   terminate with exit status 1.  The message is written verbatim with
   fputs (it is NOT used as a printf format), so a '%' in m is harmless. */
#define error(m)                                           \
do  {                                                      \
  fflush (stdout);                                         \
  fprintf (stderr, "%s:%d: error: ", __FILE__, __LINE__);  \
  fputs ((m), stderr);                                     \
  fprintf (stderr, "\n");                                  \
  exit (1);                                                \
}  while (0)

/* check(b,m): call error(m) when condition b is true. */
#define check(b,m)                                         \
do  {                                                      \
  if (b)                                                   \
    error (m);                                             \
}  while (0)

/************************************************************************/
/*                           Bit Output                                 */
/************************************************************************/

/**************************************************/
/*  Added bit stuffing to prevent start code      */
/*  emulation, i.e., add a "1" bit after every 22 */
/*  consecutive "0" bits in the bit stream        */
/*                                                */
/*  Modified to use a fixed buffer and write to   */
/*  file directly after the buffer is full. So the*/
/*  ace->bitstream_len now only has the # of bytes*/
/*  in current buffer. Total bits will indicate   */
/*  the total for arithmetic part.                */
/**************************************************/

static Void output_bit (ac_encoder *ace, Int bit)
{
  /* Shift the new bit in at the least-significant end of the pending
     byte, so the first bit of each byte ends up as its MSB. */
  ace->buffer <<= 1;
  if (bit != 0)
    ace->buffer |= 0x01;

  ++ace->total_bits;

  /* Eight bits assembled: store the byte and start a new one. */
  if (--ace->bits_to_go == 0) {
    if (ace->bitstream != NULL) {
      /* In-memory buffer is full: hand it to the bit packer and reuse it
         from the beginning. */
      if (ace->bitstream_len >= MAX_BUFFER) {
        write_to_bitstream (ace->bitstream, MAX_BUFFER << 3);
        ace->bitstream_len = 0;
      }
      ace->bitstream[ace->bitstream_len] = ace->buffer;
      ace->bitstream_len++;
    }
    if (ace->fp != NULL)
      putc (ace->buffer, ace->fp);
    ace->bits_to_go = 8;
  }

  /* Start-code emulation prevention: once 22 consecutive 0 bits have
     been emitted, insert a 1 bit (the recursive call itself resets the
     run counter). */
  if (bit != 0) {
    zero_accum = 0;
  } else if (++zero_accum == 22) {
    output_bit (ace, 1);
    zero_accum = 0;
  }
}

/* Emit `bit`, then release the ace->fbits pending follow bits (each the
   opposite of `bit`) that ac_encode_symbol deferred while the interval
   straddled the midpoint.  Leaves ace->fbits at 0. */
static Void bit_plus_follow (ac_encoder *ace, Int bit)
{
  Int opposite = !bit;

  output_bit (ace, bit);
  for (; ace->fbits > 0; ace->fbits--)
    output_bit (ace, opposite);
}

/* Void mput_bistream(UChar value, ac_encoder *ace) */
Void mput_bistream(Int value, Int nbits, ac_encoder *ace)
{

  Int extension;
  Int module = 1 << nbits;
  
  if (ace->bitstream) {
    if (ace->bitstream_len >= MAX_BUFFER) 
      if (realloc(ace->bitstream, sizeof(UChar)*
		  (ace->bitstream_len/MAX_BUFFER+1)*MAX_BUFFER+2)==NULL)
      {
	fprintf(stderr,"Couldn't reallocate memory in mput_bistream.\n");
	exit(-1);
      }


    while (value/module > 0) {
      extension = 1;
      ace->bitstream[ace->bitstream_len++] =
          (UChar)( (value%module) | (extension << nbits));
      value = value >> nbits;
    }
    extension = 0;
    ace->bitstream[ace->bitstream_len++] =
          (UChar)( (value%module) | (extension << nbits));
  }
}

/************************************************************************/
/*                             Encoder                                  */
/************************************************************************/

/*****************************************************/
/* Added zero_accum initialization for stuffing bits */
/*****************************************************/
/* Initialise the arithmetic encoder.
   fn:        optional file name.  When non-NULL, the file is opened "wb"
              (exits if that fails) and every completed byte is also
              written to it.
   bitstream: ignored.  The encoder allocates its own MAX_BUFFER+10 byte
              buffer in ace->bitstream, which ac_encoder_done() frees.
   Also resets the stuffing-bit run counter and emits a leading 1 bit via
   emit_bits() before any arithmetic-coded data. */
Void ac_encoder_init (ac_encoder *ace, C_Char *fn, 
		      UChar *bitstream)
{
  (void)bitstream;  /* unused; kept for interface compatibility */

  if (fn)  {
    ace->fp = fopen (fn, "wb");
    check (!ace->fp, "arithmetic encoder could not open file");
  }  else  {
    ace->fp = NULL;
  }

  ace->bits_to_go = 8;

  ace->low = 0;
  ace->high = Top_value;
  ace->fbits = 0;
  ace->buffer = 0;

  ace->total_bits = 0;

  /* The buffer must outlive this call; ac_encoder_done() releases it. */
  ace->bitstream_len = 0;
  ace->bitstream = (UChar *)malloc((MAX_BUFFER+10)*sizeof(UChar));
  if (ace->bitstream == NULL)
    errorHandler("Couldn't allocate memory for ace->bitstream in "
		 "ac_encoder_init.\n");

  zero_accum=0;

  /* always start arithmetic bitstream with a 1 bit. */
  emit_bits(1,1);
}



/***************************************************************/
/* Added stuffing bits to prevent start code emulation.        */
/* Reassigned bitbuffer pointer to prevent potential problem.  */
/* (Note: the current code does not use the bitbuffer arg.)    */
/***************************************************************/
/* Terminate the arithmetic code and release the encoder's buffer.
   - Emits the terminating bits (plus any pending follow bits).
   - Writes the partial final byte, MSB-aligned, to ace->fp (if open),
     then closes it.
   - Stores the partial final byte in ace->bitstream.  If its last coded
     bit is 0, it fills the padding with 1s and counts one extra bit.
   - Passes the buffered bits to write_to_bitstream().
   - If the stream ends byte-aligned on a 0 bit, appends a stuffing 1 bit
     via emit_bits() (start-code emulation prevention).
   bitbuffer: unused.
   Returns ace->total_bits, including the extra bits counted above. */
Int ac_encoder_done (ac_encoder *ace,UChar **bitbuffer)
{
  Int bits_to_write=0;

  (void)bitbuffer;  /* unused; kept for interface compatibility */

  ace->fbits += 1;
  if (ace->low < First_qtr)
    bit_plus_follow (ace, 0);
  else
    bit_plus_follow (ace, 1);

  if (ace->fp) {
    /* output_bit() assembles bytes MSB-first, so a partial final byte
       must be shifted LEFT into place (as done for ace->bitstream below).
       When bits_to_go == 8 every byte has already been written. */
    if (ace->bits_to_go != 8)
      putc ((ace->buffer << ace->bits_to_go) & 0xFF, ace->fp);
    fclose (ace->fp);
    ace->fp = NULL;
  }

  if (ace->bits_to_go != 8){
    ace->bitstream[ace->bitstream_len++] = (ace->buffer << ace->bits_to_go);
    /* Last coded bit is 0: set all padding bits to 1 and count one of
       them, so the written segment ends with a 1 bit. */
    if((ace->bitstream[ace->bitstream_len-1]&(1<<ace->bits_to_go)) == 0){
      ace->bitstream[ace->bitstream_len-1] += (1<<ace->bits_to_go)-1;
      ace->total_bits +=1;
    }
  }

  /* output_bit() has already flushed every earlier full MAX_BUFFER-byte
     block, so only the remainder (plus a full block when the final
     partial byte spilled past it) is still pending here. */
  if(ace->bitstream_len>MAX_BUFFER)
    bits_to_write=(MAX_BUFFER<<3);
  bits_to_write +=(ace->total_bits)%(MAX_BUFFER<<3);
  if(bits_to_write ==0 && ace->bitstream_len==MAX_BUFFER)
    bits_to_write=(MAX_BUFFER<<3); 

  write_to_bitstream(ace->bitstream,bits_to_write);

  if(ace->bits_to_go == 8 && (ace->bitstream[ace->bitstream_len-1]&1) ==0){
    /* stuffing bits to prevent start code emulation */
    emit_bits(1,1);
    ace->total_bits +=1;
  }

  free(ace->bitstream);
  ace->bitstream = NULL;  /* do not leave a dangling pointer behind */
  return ace->total_bits;
}

/* Encode `sym` with the zeroth-order model acm0, the first-order model
   acm1, or -- when both are non-NULL -- the sum of their counts.  At least
   one model must be non-NULL; the "primary" model is acm0 if given, else
   acm1.  Exits if sym is outside [0, primary->nsym).  Adapts the model(s)
   afterwards when primary->adapt is set.
   Returns the number of bits this symbol added to ace->total_bits. */
Int ac_encode_symbol (ac_encoder *ace, ac_model *acm0, ac_model *acm1, 
		      Int sym)
{
  ac_model *primary = (acm0 != NULL) ? acm0 : acm1;
  Int bits_before;
  LInt range;
  LInt cum_hi, cum_lo, total;

  if (sym < 0 || sym >= primary->nsym) {
    fprintf(stderr,"symbol out of range: sym=%d and acm->nsym=%d\n",
	    sym, primary->nsym);
    exit(-1);
  }

  bits_before = ace->total_bits;

  /* cfreq is cumulative and non-increasing: cfreq[0] is the total and
     [cfreq[sym+1], cfreq[sym]) is the slot belonging to sym. */
  if (acm1 == NULL) {                    /* zeroth order modeling */
    cum_hi = (Int)acm0->cfreq[sym];
    cum_lo = (Int)acm0->cfreq[sym+1];
    total  = (Int)acm0->cfreq[0];
  } else if (acm0 != NULL) {             /* mixture of both orders */
    cum_hi = (Int)(acm0->cfreq[sym]   + acm1->cfreq[sym]);
    cum_lo = (Int)(acm0->cfreq[sym+1] + acm1->cfreq[sym+1]);
    total  = (Int)(acm0->cfreq[0]     + acm1->cfreq[0]);
  } else {                               /* first order modeling */
    cum_hi = (Int)acm1->cfreq[sym];
    cum_lo = (Int)acm1->cfreq[sym+1];
    total  = (Int)acm1->cfreq[0];
  }

  /* Narrow [low, high] to the sub-interval of sym (high uses the old low). */
  range = (LInt)(ace->high - ace->low) + 1;
  ace->high = ace->low + (range*cum_hi)/total - 1;
  ace->low  = ace->low + (range*cum_lo)/total;

  /* Renormalise until the interval lies in neither half and is not
     confined to the middle half. */
  while (1) {
    if (ace->high < Half) {
      bit_plus_follow (ace, 0);
    } else if (ace->low >= Half) {
      bit_plus_follow (ace, 1);
      ace->low  -= Half;
      ace->high -= Half;
    } else if (ace->low >= First_qtr && ace->high < Third_qtr) {
      /* Straddles the midpoint: defer the bit as a follow bit. */
      ace->fbits++;
      ace->low  -= First_qtr;
      ace->high -= First_qtr;
    } else {
      break;
    }
    ace->low  = 2*ace->low;
    ace->high = 2*ace->high + 1;
  }

  if (primary->adapt)
    update_model (acm0, acm1, sym);

  return ace->total_bits - bits_before;
}



/************************************************************************/
/*                            Bit Input                                 */
/************************************************************************/

/*********************************************************/
/* Modified to be consistent with the functions in       */
/* bitpack.c, i.e., using nextinputbit() to get the new  */
/* bits from the bit stream.                             */
/*                                                       */
/* Added removal of stuffing bits; refer to output_bit() */
/* for more details.                                     */
/*********************************************************/
static Int input_bit (ac_decoder *acd)
{
  Int bit;

  /* Track the position inside the current byte (wrapping every 8 bits)
     and the running bit count. */
  if (!acd->bits_to_go)
    acd->bits_to_go = 8;
  --acd->bits_to_go;
  ++acd->total_bits;

  bit = nextinputbit();

  /* After 22 consecutive 0 bits, output_bit() inserted a 1 bit: consume
     it here (it must be 1) and do not return it. */
  if (bit != 0) {
    zero_accum = 0;
  } else if (++zero_accum == 22) {
    if (input_bit(acd) != 1)
      errorHandler("Error in decoding stuffing bits");
    zero_accum = 0;
  }

  return bit;
}

/* Read one value in the format written by mput_bistream().  Consumes bytes
   from acd->bitstream starting at acd->bitstream_ptr.  The low nbits of
   each byte supply the next, more significant, group of the value.
   Reading stops after the first byte whose bits above nbits are all zero. */
Int mget_bitstream(Int nbits, ac_decoder *acd)
{
  Int mask = (1 << nbits) - 1;
  Int shift = 0;
  Int result = 0;
  Int byte;

  for (;;) {
    byte = acd->bitstream[acd->bitstream_ptr++];
    result += (byte & mask) << shift;
    shift += nbits;
    if ((byte >> nbits) == 0)
      break;
  }
  return result;
}


/************************************************************************/
/*                             Decoder                                  */
/************************************************************************/

/* Open the optional decoder input file.
   fn: file name to open "rb" (exits if it cannot be opened), or NULL.
   acd->fp is always assigned, to NULL when there is no file, so that
   ac_decoder_done() never tests an uninitialised pointer. */
Void ac_decoder_open (ac_decoder *acd, C_Char *fn)
{
  if (fn != NULL) {
     acd->fp = fopen (fn, "rb");
     check (!acd->fp, "arithmetic decoder could not open file");
  } else {
     acd->fp = NULL;
  }
}

/* Prepare the decoder for a new arithmetic-coded segment.
   - Reads and checks the leading 1 bit (the encoder emits it with
     emit_bits(1,1) in ac_encoder_init).
   - Resets the stuffing-bit run counter.
   - Primes acd->value with the first Code_value_bits coded bits.
   The fn and bitstream parameters are not used.  total_bits is reset
   after priming, so the priming bits are not counted. */
Void ac_decoder_init (ac_decoder *acd, C_Char *fn, 
		      UChar *bitstream)
{
  Int nread;

  /* remove first stuffing bit */
  if (get_X_bits(1) == 0)
    errorHandler("Error in extracting the stuffing bit at the beginning of\n"
		 "arithmetic decoding.");

  zero_accum = 0;
  acd->bits_to_go = 0;

  acd->low = 0;
  acd->high = Top_value;

  acd->value = 0;
  for (nread = 0; nread < Code_value_bits; nread++)
    acd->value = acd->value*2 + input_bit(acd);

  /* Count only the bits consumed while decoding symbols. */
  acd->total_bits = 0;
}


/*******************************************************/
/* Added restore_arithmetic_offset() called to recover */
/* the extra bits read in by decoder. This routine is  */
/* defined in bitpack.c                                */
/*******************************************************/
/* Finish an arithmetic-coded segment.
   - Closes acd->fp if open and clears it, so a repeated call cannot
     fclose the same stream twice.
   - Calls restore_arithmetic_offset() (bitpack.c) with the acd->bits_to_go
     bits still pending in the current byte; per this file's notes, this
     recovers bits the decoder read ahead.
   - Adds those bits to total_bits and raises an error unless the total is
     a multiple of 8. */
Void ac_decoder_done (ac_decoder *acd)
{
  if (acd->fp) {
    fclose (acd->fp);
    acd->fp = NULL;
  }

  restore_arithmetic_offset(acd->bits_to_go);
  acd->total_bits +=acd->bits_to_go;

  if((acd->total_bits) %8 !=0)
    errorHandler("Did not get alignment in arithmetic decoding");
}

/* Decode one symbol using the same model combination as
   ac_encode_symbol: acm0 alone (zeroth order), acm1 alone (first order),
   or the sum of both when both are non-NULL.  At least one model must be
   non-NULL.  Adapts the model(s) afterwards when the primary model (acm0
   if given, else acm1) has `adapt` set.  Returns the decoded symbol. */
Int ac_decode_symbol (ac_decoder *acd, ac_model *acm0, ac_model *acm1)
{
  /* Explicitly typed: implicit int is not valid since C99. */
  register Int high,low,value;
  LInt range;
  Int cum;
  Int sym;
  Int last;  /* largest symbol index the search may return */

  high=acd->high; low=acd->low; value=acd->value;
  range = (LInt)(high-low)+1;

  /* The symbol searches are bounded by nsym-1: with a corrupt stream
     (value < low) cum can be negative, and an unbounded search would read
     past cfreq[nsym].  For valid streams the result is unchanged. */
  if (acm1 == NULL) /* zeroth order modeling */
  {
    cum = (((LInt)(value-low)+1)*(Int)(acm0->cfreq[0])-1)/range;

    last = acm0->nsym-1;
    for (sym = 0; sym < last && (Int)(acm0->cfreq[sym+1])>cum; sym++)
      /* do nothing */ ;
    
    high = low + (range*(Int)(acm0->cfreq[sym]))/(Int)(acm0->cfreq[0])-1;
    low  = low +  (range*(Int)(acm0->cfreq[sym+1]))/(Int)(acm0->cfreq[0]);
  }
  else if (acm0 != NULL) /* mixture of zeroth and first order modeling */
  {
    cum = (((LInt)(value-low)+1)*
	   (Int)(acm0->cfreq[0]+acm1->cfreq[0])-1)/range;
    
    last = acm0->nsym-1;
    for (sym = 0;
	 sym < last && (Int)(acm0->cfreq[sym+1]+acm1->cfreq[sym+1])>cum;
	 sym++)
      /* do nothing */ ;
    
    high = low + (range*(Int)(acm0->cfreq[sym]+acm1->cfreq[sym]))
      /(Int)(acm0->cfreq[0]+acm1->cfreq[0])-1;
    low  = low +  (range*(Int)(acm0->cfreq[sym+1]+acm1->cfreq[sym+1]))
      /(Int)(acm0->cfreq[0]+acm1->cfreq[0]);
  }
  else /* first order modeling */
  {
    cum = (((LInt)(value-low)+1)*(Int)(acm1->cfreq[0])-1)/range;

    last = acm1->nsym-1;
    for (sym = 0; sym < last && (Int)(acm1->cfreq[sym+1])>cum; sym++)
      /* do nothing */ ;
    
    high = low + (range*(Int)(acm1->cfreq[sym]))/(Int)(acm1->cfreq[0])-1;
    low  = low +  (range*(Int)(acm1->cfreq[sym+1]))/(Int)(acm1->cfreq[0]);
  }
  
  /* Renormalise in lock-step with the encoder, pulling in one new code
     bit per doubling. */
  for (;;)  {
    if (high<Half)  {
      /* do nothing */
    }  else if (low>=Half)  {
      value -= Half;
      low -= Half;
      high -= Half;
    }  else if (low>=First_qtr && high<Third_qtr)  {
      value -= First_qtr;
      low -= First_qtr;
      high -= First_qtr;
    }  else
      break;
    low = low<<1;
    high = (high<<1)+1;
    value = (value<<1) + input_bit(acd);
  }
  acd->high=high; acd->low=low; acd->value=value;

  if ((acm0!=NULL) ? acm0->adapt : acm1->adapt)
    update_model (acm0, acm1, sym);

  return sym;
}


/************************************************************************/
/*                       Probability Model                              */
/************************************************************************/

/* Initialise a frequency model over nsym symbols.
   ifreq: optional initial counts (nsym entries).  When NULL, every count
          is 1.
   adapt: stored in acm->adapt; the coders call update_model() after each
          symbol when it is non-zero.
   inc:   stored in acm->inc; the amount added to a count on each update.
   acm->Max_frequency is read here but not set, so the caller must set it
   beforehand.  Initial counts whose total exceeds it are halved once; if
   they still exceed it, the function exits.  Allocation failures also
   exit. */
Void ac_model_init (ac_model *acm, Int nsym, UShort *ifreq, 
		    Int adapt, Int inc)
{
  Int i;
  
  acm->inc = inc;
  acm->nsym = nsym; 
  acm->adapt = adapt;

  acm->freq = (UShort *) Malloc (nsym*sizeof (UShort));
  check (!acm->freq, "arithmetic coder model allocation failure");
  acm->cfreq = (UShort *) Calloc (nsym+1, sizeof (UShort));
  check (!acm->cfreq, "arithmetic coder model allocation failure");
  
  if (ifreq)  
  {
    /* Keep the true total in an LInt (assumed wider than UShort).  The
       UShort cfreq entries wrap above 65535, so testing cfreq[0] could let
       an oversized table slip past the Max_frequency checks. */
    LInt total = 0;

    acm->cfreq[acm->nsym] = 0;
    for (i=acm->nsym-1; i>=0; i--)  {
      acm->freq[i] = ifreq[i];
      total += acm->freq[i];
      acm->cfreq[i] = (UShort)total;
    }
    
    /* NOTE: This check won't always work for mixture of models */
    if (total > acm->Max_frequency)  {
      total = 0;
      acm->cfreq[acm->nsym] = 0;
      for (i = acm->nsym-1; i>=0; i--)  {
	acm->freq[i] = ((Int)acm->freq[i] + 1) / 2;
	total += acm->freq[i];
	acm->cfreq[i] = (UShort)total;
      }
    }
    
    if (total > acm->Max_frequency)
      error ("arithmetic coder model max frequency exceeded");
  }  
  else  {
    for (i=0; i<acm->nsym; i++) {
      acm->freq[i] = 1;
      acm->cfreq[i] = acm->nsym - i;
    }
    acm->cfreq[acm->nsym] = 0;
  }

}

/* Release the model's frequency tables and leave it empty: nsym is 0 and
   both table pointers are NULL, so calling this twice is harmless. */
Void ac_model_done (ac_model *acm)
{
  UShort *counts = acm->freq;
  UShort *cumulative = acm->cfreq;

  acm->nsym = 0;
  acm->freq = NULL;
  acm->cfreq = NULL;

  free (counts);
  free (cumulative);
}

/* Adapt the model(s) after coding `sym`: add `inc` to sym's count and to
   every cumulative entry cfreq[0..sym], first halving all counts
   (rounding up) when the total reaches the limit.  The model combination
   mirrors ac_encode_symbol/ac_decode_symbol (acm0 only, acm1 only, or
   both).
   NOTE(review): the single-model branches rescale only when cfreq[0]
   equals Max_frequency exactly, so with inc > 1 the total can step over
   the limit.  The mixture branch halves both tables using acm0->nsym
   (it assumes equal sizes) and compares against acm0->Max_frequency
   minus acm1->inc.  These rules are left unchanged because encoder and
   decoder must adapt identically; altering them would change coding
   results for existing streams. */
static Void update_model (ac_model *acm0, ac_model *acm1, Int sym)
{
  register UShort *freq0,*cfreq0;
  register UShort *freq1,*cfreq1;
  register Int i;  /* explicitly typed: implicit int is not valid C99 */

  if (acm1 == NULL) /* zeroth order modeling */
  {
    freq0= acm0->freq; 
    cfreq0= acm0->cfreq;

    /* scale freq count down */
    if (cfreq0[0]==acm0->Max_frequency)  {
      register Int cum=0;
      cfreq0[acm0->nsym] = 0;
      for (i = acm0->nsym-1; i>=0; i--)  {
	freq0[i] = ((Int)freq0[i] + 1) /2;
	cum += freq0[i];
	cfreq0[i] = cum;
      }
    }
    
    freq0[sym] += acm0->inc;
    for (i=sym; i>=0; i--)
      cfreq0[i] += acm0->inc;
  }
  else if (acm0!=NULL) /* mixture of zeroth and first order modeling */
  {
    freq0= acm0->freq; 
    cfreq0= acm0->cfreq;
    freq1= acm1->freq; 
    cfreq1= acm1->cfreq;
    
    /* scale freq count down */
    if ((Int)(cfreq0[0]+cfreq1[0]) >= (Int)acm0->Max_frequency-acm1->inc)  {
      register Int cum0=0;
      register Int cum1=0;
      cfreq0[acm0->nsym] = cfreq1[acm1->nsym] = 0;
      for (i = acm0->nsym-1; i>=0; i--)  {
	freq0[i]= ((Int)freq0[i] + 1) /2;
	freq1[i]= ((Int)freq1[i] + 1) /2;
	cum0 += freq0[i];
	cum1 += freq1[i];
	cfreq0[i] = cum0;
	cfreq1[i] = cum1;
      }
    }
    
    freq0[sym] += acm0->inc;
    for (i=sym; i>=0; i--)
      cfreq0[i] += acm0->inc;

    freq1[sym] += acm1->inc;
    for (i=sym; i>=0; i--)
      cfreq1[i] += acm1->inc;
  }
  else /* first order modeling */
  {
    freq1= acm1->freq; 
    cfreq1= acm1->cfreq;
    
    /* scale freq count down */
    if (cfreq1[0]==acm1->Max_frequency)  {
      register Int cum=0;
      cfreq1[acm1->nsym] = 0;
      for (i = acm1->nsym-1; i>=0; i--)  {
	freq1[i] = ((Int)freq1[i] + 1) /2;
	cum += freq1[i];
	cfreq1[i] = cum;
      }
    }
    
    freq1[sym] += acm1->inc;
    for (i=sym; i>=0; i--)
      cfreq1[i] += acm1->inc;
    
  }

  return;
}

