// ==========================================================================
// Author:  Yee Hsu
// Date:    6/8/2006
//
// Desc:    LZARI Compressor. This is a compression library that compresses
//          any file by up to 90%. Very efficient and handy if you want to
//          save storage space by compressing files into smaller ones.
// ==========================================================================

#include "stdafx.h"
#include "LZARI_Compress.h"

// Definitions of the static bit-I/O state used by PutBit() and GetBit().
// These are static members, so they are shared by every LZARI_Compress
// object; ClearAll() resets them to these same values.
unsigned int LZARI_Compress::getbuffer    = 0;      // byte most recently fetched by GetBit()
unsigned int LZARI_Compress::putbuffer    = 0;      // bits collected so far by PutBit()
unsigned int LZARI_Compress::getmask    = 0;        // 0 makes the first GetBit() fetch a byte
unsigned int LZARI_Compress::putmask    = 128;      // next output bit position, starting at the MSB

// Constructs a compressor and puts every piece of coder state into its
// initial condition by delegating to ClearAll().
LZARI_Compress::LZARI_Compress()
{
    ClearAll();
}

// Empty destructor: no cleanup is performed here.
LZARI_Compress::~LZARI_Compress() {}

/* Fatal-error handler: terminates the whole process with EXIT_FAILURE.
   The message is accepted but not reported (diagnostic output is disabled). */
void LZARI_Compress::Error(char *message)
{
    (void)message;  /* intentionally unused while printing is disabled */
    exit(EXIT_FAILURE);
}

void LZARI_Compress::PutBit(int bit)  /* Output one bit (bit = 0,1) */
{
    //static unsigned int  buffer = 0, mask = 128;

    if (bit) putbuffer |= putmask;
    if ((putmask >>= 1) == 0) {
        if (putc(putbuffer, outfile) == EOF) Error("Write Error");
        putbuffer = 0;  putmask = 128;  codesize++;
    }
}

/* Pushes out any partially filled output byte by emitting seven 0 bits.
   Seven bits are enough to complete a byte that holds at least one bit,
   and never complete one that is still empty. */
void LZARI_Compress::FlushBitBuffer(void)
{
    int remaining = 7;

    while (remaining-- > 0) {
        PutBit(0);
    }
}

/* Reads the next input bit, most significant bit first, and returns 0 or 1.
   A new byte is fetched from infile whenever the mask is exhausted. The
   result of getc() is stored unchecked, so an EOF value is kept as-is in
   the unsigned buffer. */
int LZARI_Compress::GetBit(void)
{
    getmask >>= 1;
    if (getmask == 0) {
        getbuffer = getc(infile);
        getmask = 128;
    }
    return (getbuffer & getmask) ? 1 : 0;
}

/* Resets the binary search trees used for match finding.
   rson[N + 1 .. N + 256] hold the roots of the 256 trees (one per leading
   byte value); they are set to NIL (empty). dad[0 .. N - 1] is set to NIL,
   which marks each node as "not in any tree". lson/rson of ordinary nodes
   are left untouched because they are written when a node is inserted. */
void LZARI_Compress::InitTree(void)
{
    int node;

    /* tree roots */
    node = N + 1;
    while (node <= N + 256) {
        rson[node] = NIL;
        node++;
    }

    /* ordinary nodes */
    for (node = 0; node < N; ++node) {
        dad[node] = NIL;
    }
}

void LZARI_Compress::InsertNode(int r)
/* Inserts the string of length F, text_buf[r..r+F-1], into the tree
selected by its first byte (root index N + 1 + text_buf[r]) and records
the longest match found on the way down in the members match_position
and match_length. match_position is stored as the distance
(r - p) & (N - 1) between the new string and the matching node p; when
two candidates have the same length the smaller distance is kept.
Only matches longer than THRESHOLD are recorded.
If a match of the full length F is found, the old node p is replaced by
r in the tree (p is unlinked) instead of adding a second node.
Note r plays a double role, as tree node and position in the buffer. */
{
    int  i, p, cmp, temp;
    unsigned char  *key;

    /* cmp starts positive so the first step descends to the right of the root */
    cmp = 1;  key = &text_buf[r];  p = N + 1 + key[0];
    rson[r] = lson[r] = NIL;  match_length = 0;
    for ( ; ; ) {
        /* descend one level; attach r as a leaf and stop if the slot is empty */
        if (cmp >= 0) {
            if (rson[p] != NIL) p = rson[p];
            else {  rson[p] = r;  dad[r] = p;  return;  }
        } else {
            if (lson[p] != NIL) p = lson[p];
            else {  lson[p] = r;  dad[r] = p;  return;  }
        }
        /* compare from index 1: every string in this tree shares byte 0 */
        for (i = 1; i < F; i++)
            if ((cmp = key[i] - text_buf[p + i]) != 0)  break;
        /* i is now the number of leading bytes that match node p */
        if (i > THRESHOLD) {
            if (i > match_length) {
                match_position = (r - p) & (N - 1);
                if ((match_length = i) >= F) break;
            } else if (i == match_length) {
                /* same length: prefer the smaller distance */
                if ((temp = (r - p) & (N - 1)) < match_position)
                    match_position = temp;
            }
        }
    }
    /* full-length match: r takes over p's links and p's place under its parent */
    dad[r] = dad[p];  lson[r] = lson[p];  rson[r] = rson[p];
    dad[lson[p]] = r;  dad[rson[p]] = r;
    if (rson[dad[p]] == p) rson[dad[p]] = r;
    else                   lson[dad[p]] = r;
    dad[p] = NIL;  /* remove p */
}

/* Removes node p from whichever tree it is in. Does nothing when
   dad[p] is NIL, which marks a node that is not part of any tree.
   q is chosen as the node that takes p's place:
     - p has at most one child: that child (possibly NIL);
     - p has two children: the rightmost node of p's left subtree. */
void LZARI_Compress::DeleteNode(int p)  /* Delete node p from tree */
{
    int  q;

    if (dad[p] == NIL) return;  /* not in tree */
    if (rson[p] == NIL) q = lson[p];
    else if (lson[p] == NIL) q = rson[p];
    else {
        q = lson[p];
        if (rson[q] != NIL) {
            /* walk to the rightmost descendant of the left child */
            do {  q = rson[q];  } while (rson[q] != NIL);
            /* detach q from its parent, handing q's left subtree to that parent */
            rson[dad[q]] = lson[q];  dad[lson[q]] = dad[q];
            /* q adopts p's left subtree */
            lson[q] = lson[p];  dad[lson[p]] = q;
        }
        /* q adopts p's right subtree */
        rson[q] = rson[p];  dad[rson[p]] = q;
    }
    /* hook q in under p's former parent, on the side p occupied */
    dad[q] = dad[p];
    if (rson[dad[p]] == p) rson[dad[p]] = q;
    else                   lson[dad[p]] = q;
    dad[p] = NIL;
}

/********** Arithmetic Compression **********/

/* Initializes the adaptive symbol model and the fixed position model.
   Symbols: symbol s (1..N_CHAR) initially maps to character s - 1, every
   symbol gets frequency 1, and sym_cum[] holds the cumulative frequencies
   with sym_cum[N_CHAR] == 0 and sym_cum[0] equal to the total.
   Positions: position_cum[] is a cumulative table built from the weight
   10000 / (i + 200), an empirical distribution. */
void LZARI_Compress::StartModel(void)
{
    sym_cum[N_CHAR] = 0;
    for (int sym = N_CHAR; sym > 0; --sym) {
        const int ch = sym - 1;

        char_to_sym[ch] = sym;
        sym_to_char[sym] = ch;
        sym_freq[sym] = 1;
        sym_cum[sym - 1] = sym_cum[sym] + sym_freq[sym];
    }
    /* sentinel: must differ from sym_freq[1] so the scan in UpdateModel stops */
    sym_freq[0] = 0;

    position_cum[N] = 0;
    for (int pos = N; pos > 0; --pos) {
        position_cum[pos - 1] = position_cum[pos] + 10000 / (pos + 200);
    }
}

/* Records one more occurrence of symbol sym in the adaptive model.
   1. When the total (sym_cum[0]) has reached MAX_CUM, all frequencies are
      halved (rounding up) and the cumulative table is rebuilt.
   2. sym is moved to the lowest index that shares its current frequency,
      by swapping the character mappings of the two slots.
   3. That slot's frequency and all cumulative counts below it are bumped. */
void LZARI_Compress::UpdateModel(int sym)
{
    if (sym_cum[0] >= MAX_CUM) {
        int running = 0;

        for (int k = N_CHAR; k > 0; --k) {
            sym_cum[k] = running;
            sym_freq[k] = (sym_freq[k] + 1) >> 1;
            running += sym_freq[k];
        }
        sym_cum[0] = running;
    }

    /* find the first slot of the run of equal frequencies that contains sym */
    int slot = sym;
    while (sym_freq[slot] == sym_freq[slot - 1]) {
        --slot;
    }

    if (slot < sym) {
        const int displaced = sym_to_char[slot];
        const int promoted  = sym_to_char[sym];

        sym_to_char[slot] = promoted;
        sym_to_char[sym]  = displaced;
        char_to_sym[displaced] = sym;
        char_to_sym[promoted]  = slot;
    }

    sym_freq[slot]++;
    for (int below = slot - 1; below >= 0; --below) {
        sym_cum[below]++;
    }
}

/* Emits bit, followed by one complemented bit for every pending shift.
   shifts is counted down to zero as the complements are written. */
void LZARI_Compress::Output(int bit)
{
    PutBit(bit);
    while (shifts > 0) {
        PutBit(!bit);
        shifts--;
    }
}

/* Arithmetic-encodes character code ch with the adaptive symbol model.
   The interval [low, high) is narrowed to the symbol's slice and then
   renormalized: settled leading bits are emitted via Output(), a
   straddling interval is recorded in shifts, and the interval is doubled.
   Finally the model is updated for the encoded symbol. */
void LZARI_Compress::EncodeChar(int ch)
{
    const int sym = char_to_sym[ch];
    const unsigned long int range = high - low;

    /* high must be derived from the old low, so it is computed first */
    high = low + (range * sym_cum[sym - 1]) / sym_cum[0];
    low  = low + (range * sym_cum[sym    ]) / sym_cum[0];

    while (true) {
        if (high <= Q2) {
            Output(0);
        } else if (low >= Q2) {
            Output(1);
            low -= Q2;
            high -= Q2;
        } else if (low >= Q1 && high <= Q3) {
            shifts++;
            low -= Q1;
            high -= Q1;
        } else {
            break;
        }
        low *= 2;
        high *= 2;
    }

    UpdateModel(sym);
}

/* Arithmetic-encodes a match position with the fixed position model.
   Narrowing and renormalization mirror EncodeChar(), except that the
   position table is static, so no model update follows. */
void LZARI_Compress::EncodePosition(int position)
{
    const unsigned long int range = high - low;

    /* high must be derived from the old low, so it is computed first */
    high = low + (range * position_cum[position    ]) / position_cum[0];
    low  = low + (range * position_cum[position + 1]) / position_cum[0];

    while (true) {
        if (high <= Q2) {
            Output(0);
        } else if (low >= Q2) {
            Output(1);
            low -= Q2;
            high -= Q2;
        } else if (low >= Q1 && high <= Q3) {
            shifts++;
            low -= Q1;
            high -= Q1;
        } else {
            break;
        }
        low *= 2;
        high *= 2;
    }
}

/* Terminates the arithmetic-coded stream: emits one final deciding bit
   (plus its pending complements) chosen by which quarter low lies in,
   then flushes the partially filled output byte. */
void LZARI_Compress::EncodeEnd(void)
{
    shifts++;
    Output(low < Q1 ? 0 : 1);
    FlushBitBuffer();
}

/* Binary search over sym_cum[1..N_CHAR] for cumulative value x.
   Returns the smallest index i in that range for which sym_cum[i] > x is
   false (N_CHAR if the search never narrows below it). For a table built
   by StartModel/UpdateModel this is the symbol with
   sym_cum[i - 1] > x >= sym_cum[i]. */
int LZARI_Compress::BinarySearchSym(unsigned int x)
{
    int first = 1;
    int last  = N_CHAR;

    while (first < last) {
        const int mid = (first + last) / 2;

        if (sym_cum[mid] > x) {
            first = mid + 1;
        } else {
            last = mid;
        }
    }
    return first;
}

/* Binary search over position_cum[1..N] for cumulative value x.
   Finds the smallest index i in that range for which position_cum[i] > x
   is false and returns i - 1, i.e. the position p with
   position_cum[p] > x >= position_cum[p + 1] for the table built by
   StartModel. The result lies in 0 .. N - 1. */
int LZARI_Compress::BinarySearchPos(unsigned int x)
{
    int first = 1;
    int last  = N;

    while (first < last) {
        const int mid = (first + last) / 2;

        if (position_cum[mid] > x) {
            first = mid + 1;
        } else {
            last = mid;
        }
    }
    return first - 1;
}

/* Primes the arithmetic decoder by shifting the first M + 2 input bits
   into value. */
void LZARI_Compress::StartDecode(void)
{
    int remaining = M + 2;

    while (remaining-- > 0) {
        value = 2 * value + GetBit();
    }
}

/* Decodes one symbol with the adaptive symbol model and returns its
   character code. The symbol is located from the position of value inside
   [low, high), the interval is narrowed exactly as in EncodeChar(), and
   renormalization pulls fresh bits into value. The character is read from
   sym_to_char before UpdateModel() runs, because the update may swap
   mappings. */
int LZARI_Compress::DecodeChar(void)
{
    const unsigned long int range = high - low;
    const int sym = BinarySearchSym((unsigned int)
        (((value - low + 1) * sym_cum[0] - 1) / range));

    /* high must be derived from the old low, so it is computed first */
    high = low + (range * sym_cum[sym - 1]) / sym_cum[0];
    low  = low + (range * sym_cum[sym    ]) / sym_cum[0];

    while (true) {
        if (low >= Q2) {
            value -= Q2;
            low -= Q2;
            high -= Q2;
        } else if (low >= Q1 && high <= Q3) {
            value -= Q1;
            low -= Q1;
            high -= Q1;
        } else if (high > Q2) {
            break;
        }
        /* reached also when high <= Q2, which needs no subtraction */
        low *= 2;
        high *= 2;
        value = 2 * value + GetBit();
    }

    const int ch = sym_to_char[sym];
    UpdateModel(sym);
    return ch;
}

/* Decodes one match position with the fixed position model and returns it.
   Mirrors DecodeChar() but uses position_cum[] and performs no model
   update. */
int LZARI_Compress::DecodePosition(void)
{
    const unsigned long int range = high - low;
    const int position = BinarySearchPos((unsigned int)
        (((value - low + 1) * position_cum[0] - 1) / range));

    /* high must be derived from the old low, so it is computed first */
    high = low + (range * position_cum[position    ]) / position_cum[0];
    low  = low + (range * position_cum[position + 1]) / position_cum[0];

    while (true) {
        if (low >= Q2) {
            value -= Q2;
            low -= Q2;
            high -= Q2;
        } else if (low >= Q1 && high <= Q3) {
            value -= Q1;
            low -= Q1;
            high -= Q1;
        } else if (high > Q2) {
            break;
        }
        /* reached also when high <= Q2, which needs no subtraction */
        low *= 2;
        high *= 2;
        value = 2 * value + GetBit();
    }
    return position;
}

/********** Encode and Decode **********/

void LZARI_Compress::Encode(void)
{
    int  i, c, len, r, s, last_match_length;

    fseek(infile, 0L, SEEK_END);
    textsize = ftell(infile);
    if (fwrite(&textsize, sizeof textsize, 1, outfile) < 1)
        Error("Write Error");  /* output size of text */
    codesize += sizeof textsize;
    if (textsize == 0) return;
    rewind(infile);  textsize = 0;
    StartModel();  InitTree();
    s = 0;  r = N - F;
    for (i = s; i < r; i++) text_buf[i] = ' ';
    for (len = 0; len < F && (c = getc(infile)) != EOF; len++)
        text_buf[r + len] = c;
    textsize = len;
    for (i = 1; i <= F; i++) InsertNode(r - i);
    InsertNode(r);
    do {
        if (match_length > len) match_length = len;
        if (match_length <= THRESHOLD) {
            match_length = 1;  EncodeChar(text_buf[r]);
        } else {
            EncodeChar(255 - THRESHOLD + match_length);
            EncodePosition(match_position - 1);
        }
        last_match_length = match_length;
        for (i = 0; i < last_match_length &&
            (c = getc(infile)) != EOF; i++) {
                DeleteNode(s);  text_buf[s] = c;
                if (s < F - 1) text_buf[s + N] = c;
                s = (s + 1) & (N - 1);
                r = (r + 1) & (N - 1);
                InsertNode(r);
            }
            if ((textsize += i) > printcount) {
                //printf("%12ld\r", textsize); 
                printcount += 1024;
            }
            while (i++ < last_match_length) {
                DeleteNode(s);
                s = (s + 1) & (N - 1);
                r = (r + 1) & (N - 1);
                if (--len) InsertNode(r);
            }
    } while (len > 0);
    EncodeEnd();
    //printf("In : %lu bytes\n", textsize);
    //printf("Out: %lu bytes\n", codesize);
    //printf("Out/In: %.3f\n", (double)codesize / textsize);
}

void LZARI_Compress::Decode(void)
{
    int  i, j, k, r, c;
    unsigned long int  count;

    if (fread(&textsize, sizeof textsize, 1, infile) < 1)
        Error("Read Error");  /* read size of text */
    if (textsize == 0) return;
    StartDecode();  StartModel();
    for (i = 0; i < N - F; i++) text_buf[i] = ' ';
    r = N - F;
    for (count = 0; count < textsize; ) {
        c = DecodeChar();
        if (c < 256) {
            putc(c, outfile);  text_buf[r++] = c;
            r &= (N - 1);  count++;
        } else {
            i = (r - DecodePosition() - 1) & (N - 1);
            j = c - 255 + THRESHOLD;
            for (k = 0; k < j; k++) {
                c = text_buf[(i + k) & (N - 1)];
                putc(c, outfile);  text_buf[r++] = c;
                r &= (N - 1);  count++;
            }
        }
        if (count > printcount) {
            //printf("%12lu\r", count);  
            printcount += 1024;
        }
    }
    //printf("%12lu\n", count);
}


/*******************   DON'T MODIFY ANYTHING ABOVE THIS LINE   ********************************* */

void LZARI_Compress::ClearAll()
{
    // reset statics vars
    this->putmask = 128;
    this->getbuffer    = this->putbuffer = this->getmask = 0;

    // reset values
    this->textsize = this->codesize = this->printcount = 0;
    this->match_length  = this->match_position = 0;
    this->shifts = this->low = this->value = 0;
    this->infile = this->outfile = NULL;
    this->high = Q4;

    // zero all memories
    memset(this->text_buf, 0, sizeof(this->text_buf));
    memset(this->lson, 0, sizeof(this->lson));
    memset(this->rson, 0, sizeof(this->rson));
    memset(this->dad, 0, sizeof(this->dad));
    memset(this->char_to_sym, 0, sizeof(this->char_to_sym));
    memset(this->sym_to_char, 0, sizeof(this->sym_to_char));
    memset(this->sym_freq, 0, sizeof(this->sym_freq));
    memset(this->sym_cum, 0, sizeof(this->sym_cum));
    memset(this->position_cum, 0, sizeof(this->position_cum));
}

// Reads the entire file sFile (opened in binary mode) into sOut.
//
// sFile : path of the file to read.
// sOut  : receives the file contents, which may include embedded NUL bytes.
//         It is replaced only on success and left untouched on failure.
// Returns true when the whole file was read; false when the file cannot be
// opened, its size cannot be determined, or fewer bytes than expected could
// be read.
bool LZARI_Compress::ReadFileContents(const std::string sFile, std::string& sOut)
{
    FILE* iFile = fopen(sFile.c_str(), "rb");
    if (iFile == NULL)
        return false;

    bool bOk = false;

    if (fseek(iFile, 0L, SEEK_END) == 0)
    {
        const long nSize = ftell(iFile);    // -1 on failure

        if (nSize >= 0 && fseek(iFile, 0L, SEEK_SET) == 0)
        {
            // Read straight into a string: no manual new[]/delete[] pairing.
            std::string sContents(static_cast<size_t>(nSize), '\0');
            const size_t nRead = (nSize > 0)
                ? fread(&sContents[0], 1, sContents.size(), iFile)
                : 0;

            if (nRead == sContents.size())
            {
                sOut.swap(sContents);
                bOk = true;
            }
        }
    }

    fclose(iFile);
    return bOk;
}

// Compresses or decompresses the file sIn into the file sOut.
//
// METHOD_FLAG : METHOD_DEFLATE runs Encode(); any other value runs Decode().
// sIn         : path of the source file.
// sOut        : path of the destination file (created or truncated).
// Returns true when both files could be opened and the output was closed
// successfully, false otherwise. The destination is opened only after the
// source, so a missing source no longer creates or truncates it, and no
// handle stays open on the failure paths.
// Note: Encode()/Decode() report fatal I/O problems through Error(), which
// terminates the process instead of returning here.
bool LZARI_Compress::Compress_From_File_To_File(const int METHOD_FLAG, const std::string sIn, std::string sOut)
{
    this->ClearAll();

    this->infile = fopen(sIn.c_str(), "rb");
    if (this->infile == NULL)
        return false;

    this->outfile = fopen(sOut.c_str(), "wb");
    if (this->outfile == NULL)
    {
        fclose(this->infile);
        this->infile = NULL;
        return false;
    }

    if (METHOD_FLAG == METHOD_DEFLATE)
        this->Encode();
    else
        this->Decode();

    fclose(this->infile);
    // buffered output is written at close time, so a failing fclose means
    // the destination is incomplete
    const bool bFlushed = (fclose(this->outfile) == 0);
    this->infile = this->outfile = NULL;    // never keep closed handles around

    return bFlushed;
}

// Compresses or decompresses the file sIn and returns the result in sOut.
// The result is staged in the temporary file TEMP_FILENAME and then read
// back into memory.
//
// METHOD_FLAG : METHOD_DEFLATE runs Encode(); any other value runs Decode().
// sIn         : path of the source file.
// sOut        : receives the produced bytes on success.
// Returns true on success; false when a file cannot be opened, the staged
// output cannot be closed cleanly, or it cannot be read back. The temporary
// file is removed on every path once it has been created, and no handle
// stays open on the failure paths.
// Note: Encode()/Decode() report fatal I/O problems through Error(), which
// terminates the process instead of returning here.
bool LZARI_Compress::Compress_From_File_To_Data(const int METHOD_FLAG, const std::string sIn, std::string& sOut)
{
    this->ClearAll();

    this->infile = fopen(sIn.c_str(), "rb");
    if (this->infile == NULL)
        return false;

    this->outfile = fopen(TEMP_FILENAME, "wb");
    if (this->outfile == NULL)
    {
        fclose(this->infile);
        this->infile = NULL;
        return false;
    }

    if (METHOD_FLAG == METHOD_DEFLATE)
        this->Encode();
    else
        this->Decode();

    fclose(this->infile);
    const bool bFlushed = (fclose(this->outfile) == 0);
    this->infile = this->outfile = NULL;    // never keep closed handles around

    const bool bOk = bFlushed && this->ReadFileContents(TEMP_FILENAME, sOut);
    remove(TEMP_FILENAME);                  // clean up whether or not the read-back worked

    return bOk;
}

// Compresses or decompresses the in-memory buffer sIn into the file sOut.
// The input is first staged in the temporary file TEMP_FILENAME because
// Encode()/Decode() operate on FILE streams.
//
// METHOD_FLAG : METHOD_DEFLATE runs Encode(); any other value runs Decode().
// sIn         : input bytes (may contain embedded NULs).
// sOut        : path of the destination file (created or truncated).
// Returns true on success; false when a file cannot be opened, the input
// cannot be staged completely, or the destination cannot be closed cleanly.
// The temporary file is removed and all handles are closed on every path.
// Note: Encode()/Decode() report fatal I/O problems through Error(), which
// terminates the process instead of returning here.
bool LZARI_Compress::Compress_From_Data_To_File(const int METHOD_FLAG, const std::string sIn, std::string sOut)
{
    this->ClearAll();

    this->infile = fopen(TEMP_FILENAME, "wb+");
    if (this->infile == NULL)
        return false;

    this->outfile = fopen(sOut.c_str(), "wb");
    bool bOk = (this->outfile != NULL);

    if (bOk)
    {
        // stage the input and rewind so the coder reads it from the start
        const size_t nWritten = fwrite(sIn.data(), 1, sIn.length(), this->infile);
        bOk = (nWritten == sIn.length())
           && (fseek(this->infile, 0L, SEEK_SET) == 0);
    }

    if (bOk)
    {
        if (METHOD_FLAG == METHOD_DEFLATE)
            this->Encode();
        else
            this->Decode();
    }

    fclose(this->infile);
    if (this->outfile != NULL && fclose(this->outfile) != 0)
        bOk = false;                        // destination not fully written
    this->infile = this->outfile = NULL;    // never keep closed handles around

    remove(TEMP_FILENAME);

    return bOk;
}

// Compresses or decompresses the in-memory buffer sIn and returns the result
// in sOut. Both directions are staged through temporary files named
// "_I_" + TEMP_FILENAME and "_O_" + TEMP_FILENAME because Encode()/Decode()
// operate on FILE streams.
//
// METHOD_FLAG : METHOD_DEFLATE runs Encode(); any other value runs Decode().
// sIn         : input bytes (may contain embedded NULs).
// sOut        : receives the produced bytes on success.
// Returns true on success; false when a temporary file cannot be opened,
// the input cannot be staged completely, the staged output cannot be closed
// cleanly, or it cannot be read back. Both temporary files are removed and
// all handles are closed on every path.
// Note: Encode()/Decode() report fatal I/O problems through Error(), which
// terminates the process instead of returning here.
bool LZARI_Compress::Compress_From_Data_To_Data(const int METHOD_FLAG, const std::string sIn, std::string& sOut)
{
    this->ClearAll();

    const std::string sTempIn  = std::string("_I_") + TEMP_FILENAME;
    const std::string sTempOut = std::string("_O_") + TEMP_FILENAME;

    this->infile = fopen(sTempIn.c_str(), "wb+");
    if (this->infile == NULL)
        return false;

    this->outfile = fopen(sTempOut.c_str(), "wb+");
    bool bOk = (this->outfile != NULL);

    if (bOk)
    {
        // stage the input and rewind so the coder reads it from the start
        const size_t nWritten = fwrite(sIn.data(), 1, sIn.length(), this->infile);
        bOk = (nWritten == sIn.length())
           && (fseek(this->infile, 0L, SEEK_SET) == 0);
    }

    if (bOk)
    {
        if (METHOD_FLAG == METHOD_DEFLATE)
            this->Encode();
        else
            this->Decode();
    }

    fclose(this->infile);
    if (this->outfile != NULL && fclose(this->outfile) != 0)
        bOk = false;                        // staged output not fully written
    this->infile = this->outfile = NULL;    // never keep closed handles around

    if (bOk)
        bOk = this->ReadFileContents(sTempOut, sOut);

    // clean up whether or not the run succeeded; removing a file that was
    // never created simply fails without side effects
    remove(sTempIn.c_str());
    remove(sTempOut.c_str());

    return bOk;
}