#include <limits.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include "mrp.h"

/* Neighbour offsets used by calc_prd(): prediction tap k refers to the
 * pixel at (y + dyx[k].y, x + dyx[k].x), clamped at the image borders. */
extern POINT dyx[];
/* Sigma tables handed to init_pmodels() through dec->sigma: init_decoder()
 * selects sigma_h for Huffman-coded streams and sigma_a otherwise. */
extern double sigma_h[], sigma_a[];

/*
 * Read the next n bits (most significant bit first) from fp and return them
 * right-aligned.  n <= 0 returns 0 without consuming input.
 *
 * The partially consumed input byte lives in function-static state, so this
 * reader serves a single stream at a time and is not reentrant.  At end of
 * file getc() yields EOF, which "& 0xff" turns into a byte of all ones.
 */
uint getbits(FILE *fp, int n)
{
    static int avail = 0;    /* number of unread bits held in 'cache' */
    static uint cache = 0;   /* the low 'avail' bits are still pending */
    int acc = 0;

    if (n <= 0) {
        return 0;
    }
    /* Drain everything buffered, then pull whole bytes until the rest of
     * the request fits inside the cache. */
    while (avail < n) {
        acc = (acc << avail) | cache;
        n -= avail;
        cache = getc(fp) & 0xff;
        avail = 8;
    }
    avail -= n;
    acc = (acc << n) | (cache >> avail);
    cache &= ((1 << avail) - 1);   /* keep only the bits not yet returned */
    return acc;
}

/*
 * Read the stream header from fp, allocate a DECODER plus its per-image
 * work buffers, and return it.
 *
 * Header fields, in read order with their bit widths: magic (16),
 * version (8), width (16), height (16), maxval (16), num_comp (6),
 * num_class (6), num_group (6), prd_order (6), num_pmodel - 1 (6),
 * coef_precision - 1 (4), pm_accuracy + 1 (3), f_huffman (1).
 *
 * The process exits with status 1 in two cases: when the magic number is
 * wrong, and when the header describes an image that the allocations and
 * fixed-point arithmetic below cannot represent (the header comes straight
 * from an untrusted file).
 */
DECODER *init_decoder(FILE *fp)
{
    DECODER *dec;

    dec = (DECODER *)alloc_mem(sizeof(DECODER));
    if (getbits(fp, 16) != MAGIC_NUMBER) {
        fprintf(stderr, "Not a compressed file!\n");
        exit(1);
    }
    dec->version = getbits(fp, 8);
    dec->width = getbits(fp, 16);
    dec->height = getbits(fp, 16);
    dec->maxval = getbits(fp, 16);
    dec->num_comp = getbits(fp, 6);
    dec->num_class = getbits(fp, 6);
    dec->num_group = getbits(fp, 6);
    dec->prd_order = getbits(fp, 6);
    dec->num_pmodel = getbits(fp, 6) + 1;
    dec->coef_precision = getbits(fp, 4) + 1;
    /* getbits() returns uint; subtract in signed arithmetic so a stored 0
     * yields -1 directly instead of wrapping to UINT_MAX and depending on an
     * implementation-defined conversion back to int. */
    dec->pm_accuracy = (int)getbits(fp, 3) - 1;
    dec->f_huffman = getbits(fp, 1);

    /* An empty image, no classes, or no groups would give zero-sized or
     * negative (num_group - 1) allocation dimensions below. */
    if (dec->width <= 0 || dec->height <= 0
        || dec->num_class <= 0 || dec->num_group <= 0) {
        fprintf(stderr, "Corrupted header!\n");
        exit(1);
    }
    /* maxprd = maxval << coef_precision must fit in an int; shifting a
     * 16-bit maxval by up to 16 bits can otherwise overflow (UB). */
    if (dec->maxval > (INT_MAX >> dec->coef_precision)) {
        fprintf(stderr, "Corrupted header!\n");
        exit(1);
    }
    dec->maxprd = dec->maxval << dec->coef_precision;

    dec->predictor = (int **)alloc_2d_array(dec->num_class, dec->prd_order,
                                            sizeof(int));
    dec->th = (int **)alloc_2d_array(dec->num_class, dec->num_group - 1,
                                     sizeof(int));
    dec->idx_E = (img_t **)alloc_2d_array(dec->height, dec->width,
                                          sizeof(img_t));
    dec->class = (uchar **)alloc_2d_array(dec->height, dec->width,
                                          sizeof(uchar));
    /* A per-group model index is only transmitted when there is a choice. */
    if (dec->num_pmodel > 1) {
        dec->pm_idx = (int *)alloc_mem(dec->num_group * sizeof(int));
    } else {
        dec->pm_idx = NULL;
    }
    /* freq and cumfreq share one buffer: cumfreq starts MAX_SYMBOL entries in. */
    dec->spm.freq = alloc_mem((MAX_SYMBOL * 2 + 1) * sizeof(uint));
    dec->spm.cumfreq = &(dec->spm.freq[MAX_SYMBOL]);
    if (dec->f_huffman == 1) {
        dec->sigma = sigma_h;
    } else {
        dec->sigma = sigma_a;
    }
    dec->mtfbuf = (int *)alloc_mem(dec->num_class * sizeof(int));
    return (dec);
}

/*
 * Decode one variable-length codeword from fp using the table in vlc and
 * return the symbol it stands for.
 *
 * Bits are read one at a time, up to vlc->max_len of them.  For each
 * length, vlc->off[len] is negative when no codeword has that length;
 * otherwise it is the position in vlc->index[] of the largest codeword of
 * that length.  This layout presumably comes from gen_huffcode() building
 * a canonical code (TODO confirm).  Decoding stops at the first length
 * whose largest code is >= the accumulated bits.
 */
int decode_vlc(FILE *fp, VLC *vlc)
{
    uint acc = 0;     /* codeword bits gathered so far */
    int first = 0;    /* smallest code value at the current length */
    int base = 0;     /* index[] position of that smallest code */
    int len, last, pos;

    for (len = 0; len < vlc->max_len; len++) {
        acc = (acc << 1) | getbits(fp, 1);
        last = vlc->off[len];
        if (last >= 0) {
            if (acc <= vlc->code[vlc->index[last]]) {
                break;
            }
            /* Codes of the next length start right after this one's last. */
            first = (vlc->code[vlc->index[last]] + 1) << 1;
            base = last + 1;
        } else {
            first <<= 1;
        }
    }
    pos = base + acc - first;
    return vlc->index[pos];
}

/*
 * Decode one Golomb-Rice coded non-negative integer with parameter m.
 * The quotient is sent in unary (a run of 0 bits ended by a 1 bit), and
 * the m low-order bits follow in binary.
 */
int decode_golomb(FILE *fp, int m)
{
    int quot;

    for (quot = 0; getbits(fp, 1) == 0; quot++) {
        /* count the unary prefix */
    }
    return (quot << m) | getbits(fp, m);
}

/*
 * Decode the prediction coefficients into dec->predictor[cls][tap] for
 * every class and every one of the dec->prd_order taps.
 *
 * Each tap is sent as a shape parameter followed by one magnitude per
 * class.  Every non-zero magnitude is followed by a sign flag.
 * - Huffman mode: the parameter is a raw 4-bit field and the magnitudes
 *   are Golomb-Rice codes.
 * - Arithmetic mode: everything goes through dec->spm.  Slots
 *   [MAX_COEF + 2, MAX_COEF + 18) hold a uniform 16-symbol model, used for
 *   the parameter and (its first two slots) for the sign.  Slots
 *   [0, MAX_COEF + 1) are reshaped by set_spmodel() for the magnitudes.
 */
void decode_predictor(FILE *fp, DECODER *dec)
{
    int tap, cls, mag, param;

    if (dec->f_huffman != 1) {
        PMODEL *pm = &dec->spm;
        int sym;

        pm->size = MAX_COEF + 18;
        pm->cumfreq[MAX_COEF + 2] = 0;
        for (sym = MAX_COEF + 2; sym < pm->size; sym++) {
            pm->freq[sym] = 1;
            pm->cumfreq[sym + 1] = pm->cumfreq[sym] + pm->freq[sym];
        }
        for (tap = 0; tap < dec->prd_order; tap++) {
            param = rc_decode(fp, dec->rc, pm, MAX_COEF + 2, MAX_COEF + 18)
                - (MAX_COEF + 2);
            set_spmodel(pm, MAX_COEF + 1, param);
            for (cls = 0; cls < dec->num_class; cls++) {
                mag = rc_decode(fp, dec->rc, pm, 0, MAX_COEF + 1);
                /* The sign is only present for non-zero magnitudes. */
                if (mag > 0
                    && rc_decode(fp, dec->rc, pm, MAX_COEF + 2, MAX_COEF + 4)
                       != MAX_COEF + 2) {
                    mag = -mag;
                }
                dec->predictor[cls][tap] = mag;
            }
        }
        return;
    }

    for (tap = 0; tap < dec->prd_order; tap++) {
        param = getbits(fp, 4);
        for (cls = 0; cls < dec->num_class; cls++) {
            mag = decode_golomb(fp, param);
            if (mag > 0 && getbits(fp, 1)) {
                mag = -mag;
            }
            dec->predictor[cls][tap] = mag;
        }
    }
}

/*
 * Decode the per-class group thresholds into dec->th[cl][0 .. num_group-2].
 * When dec->num_pmodel > 1, also decode the probability-model index for
 * each group into dec->pm_idx[].
 *
 * Thresholds are sent as cumulative increments.  Once the running value k
 * exceeds MAX_UPARA, no further increments are read for that class, and
 * the remaining thresholds repeat k.  In arithmetic mode the shared symbol
 * model dec->spm is rebuilt in place for each field.
 */
void decode_threshold(FILE *fp, DECODER *dec)
{
    int cl, gr, m, k;

    if (dec->f_huffman == 1) {
        m = getbits(fp, 4);
        for (cl = 0; cl < dec->num_class; cl++) {
            k = 0;
            for (gr = 1; gr < dec->num_group; gr++) {
                if (k <= MAX_UPARA) {
                    /* A flag bit says whether this increment is non-zero. */
                    if (getbits(fp, 1)) k += decode_golomb(fp, m) + 1;
                }
                dec->th[cl][gr - 1] = k;
            }
        }
        if (dec->num_pmodel > 1) {
            /* k = smallest bit count (>= 1) with 2^k >= num_pmodel. */
            for (k = 1; (1 << k) < dec->num_pmodel; k++);
            for (gr = 0; gr < dec->num_group; gr++) {
                dec->pm_idx[gr] = getbits(fp, k);
            }
        }
    } else {
        PMODEL *pm;

        pm = &dec->spm;
        /* Uniform 16-symbol model for the shape parameter m. */
        pm->size = 16;
        pm->cumfreq[0] = 0;
        for (k = 0; k < pm->size; k++) {
            pm->freq[k] = 1;
            pm->cumfreq[k + 1] = pm->cumfreq[k] + pm->freq[k];
        }
        m = rc_decode(fp, dec->rc, pm, 0, pm->size);
        set_spmodel(pm, MAX_UPARA + 2, m);
        for (cl = 0; cl < dec->num_class; cl++) {
            k = 0;
            for (gr = 1; gr < dec->num_group; gr++) {
                if (k <= MAX_UPARA) {
                    /* The decodable range shrinks as k grows. */
                    k += rc_decode(fp, dec->rc, pm, 0, pm->size - k);
                }
                dec->th[cl][gr - 1] = k;
            }
        }

        if (dec->num_pmodel > 1) {
            /* Uniform model over the num_pmodel candidates.  The cumulative
             * table must start at 0; the old code cleared freq[0] here
             * instead, which was a dead store (the loop overwrites it). */
            pm->size = dec->num_pmodel;
            pm->cumfreq[0] = 0;
            for (k = 0; k < pm->size; k++) {
                pm->freq[k] = 1;
                pm->cumfreq[k + 1] = pm->cumfreq[k] + pm->freq[k];
            }
            for (gr = 0; gr < dec->num_group; gr++) {
                dec->pm_idx[gr] = rc_decode(fp, dec->rc, pm, 0, pm->size);
            }
        }
    }
    return;
}

/*
 * Decode the block class map and expand it into dec->class[y][x].
 *
 * One symbol is decoded per BLOCK_SIZE x BLOCK_SIZE block, in raster order
 * of blocks.
 * - Huffman mode: the code lengths (4 bits each, per class) come first,
 *   then the block symbols.
 * - Arithmetic mode: a quantised probability level is sent for each class
 *   and turned into a static model for the block symbols.
 *
 * Each decoded symbol is mapped back to a class label through dec->mtfbuf,
 * which mtf_classlabel() updates per block.  The label is the k with
 * mtfbuf[k] equal to the symbol, so mtfbuf presumably maps label -> rank
 * (TODO confirm).
 */
void decode_class(FILE *fp, DECODER *dec)
{
    int i, j, k, blk, x, y, num_block, sym, bh, bw;
    uchar *index;

    /* Use ceiling division so partial blocks on the right and bottom edges
     * are counted.  The scan below visits exactly this many blocks.  The
     * former height * width / BLOCK_SIZE^2 undercounted for such images,
     * and the scan then read past the end of index[]. */
    num_block = ((dec->height + BLOCK_SIZE - 1) / BLOCK_SIZE)
        * ((dec->width + BLOCK_SIZE - 1) / BLOCK_SIZE);
    index = (uchar *)alloc_mem(num_block * sizeof(uchar));
    if (dec->f_huffman == 1) {
        VLC *vlc;
        vlc = (VLC *)alloc_mem(sizeof(VLC));
        vlc->size = dec->num_class;
        vlc->max_len = 16;
        vlc->len = (int *)alloc_mem(vlc->size * sizeof(int));
        for (i = 0; i < vlc->size; i++) {
            vlc->len[i] = getbits(fp, 4) + 1;
        }
        gen_huffcode(vlc);
        for (blk = 0; blk < num_block; blk++) {
            index[blk] = decode_vlc(fp, vlc);
        }
        free_vlc(vlc);
    } else {
        double p;
        int l;
        PMODEL *pm;
        range_t freqoff;

        pm = &dec->spm;
        /* Slots [num_class, num_class + PMCLASS_LEVEL) hold an adaptive
         * model for the per-class probability levels. */
        pm->size = dec->num_class + PMCLASS_LEVEL;
        pm->cumfreq[dec->num_class] = 0;
        for (i = dec->num_class; i < pm->size; i++) {
            pm->freq[i] = 1;
            pm->cumfreq[i + 1] = pm->cumfreq[i] + pm->freq[i];
        }
        freqoff = pm->cumfreq[dec->num_class];
        for (i = 0; i < dec->num_class; i++) {
            l = rc_decode(fp, dec->rc, pm, dec->num_class, pm->size);
            /* Level -> probability 2^-((level + 0.5) * PMCLASS_MAX / PMCLASS_LEVEL). */
            p = exp(-log(2.0) * ((double)(l - dec->num_class)+0.5)
                    * PMCLASS_MAX/PMCLASS_LEVEL);
            /* Adapt: while the level model's total is below 2 * MAX_TOTFREQ,
             * double the frequencies from the decoded level upward. */
            if ((pm->cumfreq[pm->size] - freqoff) < (MAX_TOTFREQ << 1)) {
                for (; l < pm->size; l++) {
                    pm->freq[l] *= 2;
                    pm->cumfreq[l + 1] = pm->cumfreq[l] + pm->freq[l];
                }
            }
            pm->freq[i] = p * (1 << 16);
            if (pm->freq[i] == 0) pm->freq[i]++;
        }
        /* Static model over the class symbols [0, num_class). */
        pm->cumfreq[0] = 0;
        for (i = 0; i < dec->num_class; i++) {
            pm->cumfreq[i + 1] = pm->cumfreq[i] + pm->freq[i];
        }
        for (blk = 0; blk < num_block; blk++) {
            index[blk] = rc_decode(fp, dec->rc, pm, 0, dec->num_class);
        }
    }

    for (i = 0; i < dec->num_class; i++) {
        dec->mtfbuf[i] = i;
    }
    blk = 0;
    for (y = 0; y < dec->height; y += BLOCK_SIZE) {
        /* Clip edge blocks so the fill stays inside the image. */
        bh = (dec->height - y < BLOCK_SIZE) ? dec->height - y : BLOCK_SIZE;
        for (x = 0; x < dec->width; x += BLOCK_SIZE) {
            bw = (dec->width - x < BLOCK_SIZE) ? dec->width - x : BLOCK_SIZE;
            mtf_classlabel(x, y, dec->class, dec->mtfbuf,
                           dec->width, dec->num_class);
            sym = index[blk++];
            /* conversion: find the label whose mtfbuf entry is sym */
            for (k = 0; k < dec->num_class; k++) {
                if (dec->mtfbuf[k] == sym) break;
            }
            for (i = 0; i < bh; i++) {
                for (j = 0; j < bw; j++) {
                    dec->class[y + i][x + j] = k;
                }
            }
        }
    }
    free(index);
    return;
}

/*
 * Compute the linear prediction for pixel (y, x) using class cl's
 * coefficients.  Tap k reads the already decoded pixel at
 * (y + dyx[k].y, x + dyx[k].x).  Out-of-range references are clamped
 * differently for each border case:
 *   - first pixel: every reference counts as (maxval + 1) / 2;
 *   - first row:   row forced to 0, column clamped to [0, x - 1];
 *   - first column: row clamped to [0, y - 1], column clamped below at 0;
 *   - interior:    row clamped below at 0, column clamped to [0, width - 1].
 * The weighted sum is clamped to [0, dec->maxprd] before it is returned.
 */
int calc_prd(IMAGE *img, DECODER *dec, int cl, int y, int x)
{
    const int *coef = dec->predictor[cl];
    int order = dec->prd_order;
    int sum = 0;
    int n, row, col;

    if (y == 0 && x == 0) {
        for (n = 0; n < order; n++) {
            sum += coef[n];
        }
        sum *= ((img->maxval + 1) >> 1);
    } else {
        for (n = 0; n < order; n++) {
            row = y + dyx[n].y;
            col = x + dyx[n].x;
            if (y == 0) {
                row = 0;
                if (col < 0) {
                    col = 0;
                } else if (col >= x) {
                    col = x - 1;
                }
            } else {
                if (row < 0) {
                    row = 0;
                } else if (x == 0 && row >= y) {
                    row = y - 1;
                }
                if (col < 0) {
                    col = 0;
                } else if (x != 0 && col >= img->width) {
                    col = img->width - 1;
                }
            }
            sum += coef[n] * img->val[row][col];
        }
    }
    if (sum < 0) {
        return 0;
    }
    return (sum > dec->maxprd) ? dec->maxprd : sum;
}

/*
 * Return the context group for class cl and context value u.  This is the
 * index of the first threshold in dec->th[cl] that exceeds u, or
 * dec->num_group - 1 when none does.
 */
static int dec_find_group(DECODER *dec, int cl, int u)
{
    const int *limit = dec->th[cl];
    int g = 0;

    while (g < dec->num_group - 1 && u >= limit[g]) {
        g++;
    }
    return g;
}

/*
 * Decode all pixels in raster order and return the reconstructed image.
 *
 * For each pixel the class comes from dec->class, the group from the
 * context value calc_u() and the class thresholds, and the prediction from
 * calc_prd().  Three coding modes exist:
 *   - Huffman: decode an error index with the group's VLC table.
 *   - Arithmetic, pm_accuracy < 0: decode an error index with the group's
 *     single model.
 *   - Arithmetic, pm_accuracy >= 0: decode the pixel value directly.  The
 *     model is picked from the fine prediction offset, and the decoded
 *     range is shifted by 'base'.
 * In every mode dec->idx_E[y][x] receives the error index that calc_u()
 * reads for later pixels.
 */
IMAGE *decode_image(FILE *fp, DECODER *dec)
{
    IMAGE *img;
    PMODEL *pm;
    int row, col, cls, grp, pred, half, sym, fold;
    int huffman, fine, base;
    int mask = 0, shift = 0;

    img = alloc_image(dec->width, dec->height, dec->maxval);
    huffman = (dec->f_huffman == 1);
    fine = !huffman && dec->pm_accuracy >= 0;
    if (huffman) {
        dec->vlcs = init_vlcs(dec->pmodels, dec->num_group, 1);
    } else if (fine) {
        mask = (1 << dec->pm_accuracy) - 1;
        shift = dec->coef_precision - dec->pm_accuracy;
    }

    for (row = 0; row < dec->height; row++) {
        for (col = 0; col < dec->width; col++) {
            cls = dec->class[row][col];
            grp = dec_find_group(dec, cls,
                                 calc_u(dec->idx_E, dec->width, row, col));
            pred = calc_prd(img, dec, cls, row, col);

            if (fine) {
                /* Low pm_accuracy bits of the (rounded) offset pick the
                 * model; the high bits shift the symbol range. */
                base = (dec->maxprd - pred + (1 << shift) / 2) >> shift;
                pm = dec->pmodels[grp][0] + (base & mask);
                base >>= dec->pm_accuracy;
                pred >>= (dec->coef_precision - 1);
                sym = rc_decode(fp, dec->rc, pm, base, base + dec->maxval + 1)
                    - base;
                img->val[row][col] = sym;
                /* Fold the signed error into a non-negative index <= maxval. */
                fold = (sym << 1) - pred;
                if (fold < 0) fold = -(fold + 1);
                if (fold > dec->maxval) fold = dec->maxval;
                dec->idx_E[row][col] = fold;
                continue;
            }

            pred >>= (dec->coef_precision - 1);
            half = (pred + 1) >> 1;
            if (huffman) {
                sym = decode_vlc(fp, &dec->vlcs[grp][0]);
            } else {
                pm = dec->pmodels[grp][0];
                sym = rc_decode(fp, dec->rc, pm, 0, pm->size);
            }
            dec->idx_E[row][col] = sym;
            /* E2e presumably unfolds the index into a signed error (TODO confirm). */
            img->val[row][col] = half + E2e(sym, half, pred & 1, dec->maxval);
        }
    }
    return img;
}

/*
 * Write img to filename as a binary (P5) PGM file.
 *
 * PGM stores one byte per sample when maxval <= 255 and two bytes per
 * sample, most significant byte first, when maxval > 255.  The previous
 * code always wrote one byte, which corrupted images with maxval > 255.
 * Exits with status 1 if any write, or the final flush in fclose(), fails.
 */
void write_pgm(IMAGE *img, char *filename)
{
    int i, j, v, err;
    FILE *fp;

    fp = fileopen(filename, "wb");
    fprintf(fp, "P5\n%d %d\n%d\n", img->width, img->height, img->maxval);
    for (i = 0; i < img->height; i++) {
        for (j = 0; j < img->width; j++) {
            v = img->val[i][j];
            if (img->maxval > 255) {
                putc((v >> 8) & 0xff, fp);
            }
            putc(v & 0xff, fp);
        }
    }
    /* Check the stream error flag before closing, then always close. */
    err = ferror(fp);
    if (fclose(fp) != 0 || err) {
        fprintf(stderr, "Write error on %s\n", filename);
        exit(1);
    }
    return;
}

/*
 * Entry point: decmrp infile outfile.
 *
 * The first argument names the compressed input.  Each later argument
 * overwrites the output name, so the last one wins.  If either name is
 * missing, the program prints the banner and usage and exits with status 0.
 * Otherwise it decodes the side information and then the pixels, writes a
 * PGM, and reports the value returned by cpu_time().
 */
int main(int argc, char **argv)
{
    char *infile = NULL;
    char *outfile = NULL;
    DECODER *dec;
    IMAGE *img;
    FILE *fp;
    int argi;

    cpu_time();
    setbuf(stdout, 0);
    for (argi = 1; argi < argc; argi++) {
        if (infile != NULL) {
            outfile = argv[argi];
        } else {
            infile = argv[argi];
        }
    }
    if (!infile || !outfile) {
        printf(BANNER"\n", 0.1 * VERSION);
        printf("usage: decmrp infile outfile\n");
        printf("infile:     Input file\n");
        printf("outfile:    Output file\n");
        exit(0);
    }

    fp = fileopen(infile, "rb");
    dec = init_decoder(fp);
    /* The range decoder is only needed for arithmetic-coded streams. */
    if (dec->f_huffman == 0) {
        dec->rc = rc_init();
        rc_startdec(fp, dec->rc);
    }
    /* Side information is read in stream order before the pixel data. */
    decode_class(fp, dec);
    decode_predictor(fp, dec);
    decode_threshold(fp, dec);
    dec->pmodels = init_pmodels(dec->num_group, dec->num_pmodel,
                                dec->pm_accuracy, dec->pm_idx, dec->sigma,
                                dec->maxval + 1);
    img = decode_image(fp, dec);
    fclose(fp);

    write_pgm(img, outfile);
    printf("cpu time :%.2f sec.\n", cpu_time());
    return 0;
}
