import ImageProcessor.*;
import java.io.*;
import java.util.*;
import java.lang.*;

/**
 * Fixed-size table holding one {@link Symbol} entry per index 0..255.
 * Entry {@code i} is constructed as {@code new Symbol(i)}.
 */
class FreqTable {
	/** Number of entries allocated by the constructor. */
	private static final int ENTRY_COUNT = 256;

	/** Backing array; slot {@code i} holds the entry created for value {@code i}. */
	public Symbol[] table;

	/** Allocates the table and fills every slot with a fresh {@link Symbol}. */
	public FreqTable() {
		Symbol[] entries = new Symbol[ENTRY_COUNT];
		int value = 0;
		while (value < entries.length) {
			entries[value] = new Symbol(value);
			value++;
		}
		table = entries;
	}

	/** @return the number of slots in the table */
	public int length() {
		Symbol[] entries = table;
		return entries.length;
	}

	/**
	 * @param i table index
	 * @return the symbol value stored in slot {@code i}
	 */
	public int getSymb(int i) {
		Symbol entry = table[i];
		return entry.getSymbol();
	}

	/**
	 * @param i table index
	 * @return the frequency stored in slot {@code i}
	 */
	public int getFreq(int i) {
		Symbol entry = table[i];
		return entry.getFrequency();
	}
}

/**
 * A symbol value paired with a frequency counter.
 * A newly created symbol starts with frequency 1.
 */
class Symbol {
	/** The symbol value passed to the constructor. */
	public int symbol;
	/** Frequency counter; initialised to 1. */
	public int frequency;

	/**
	 * @param symb the symbol value to store
	 */
	public Symbol (int symb) {
		this.frequency = 1;
		this.symbol = symb;
	}

	/** @return the current frequency counter */
	public int getFrequency() {
		return this.frequency;
	}

	/** @return the stored symbol value */
	public int getSymbol() {
		return this.symbol;
	}
}

/**
 * One tree node stored in an array: links to other nodes are array indices,
 * and {@code -1} is the initial value of every link.
 */
class Node {
	/** Symbol value carried by this node. */
	public int symbol;
	/** Frequency counter; initialised to 1. */
	public int frequency;
	/** Index of the parent node, initially -1. */
	public int parent;
	/** Index of the left child, initially -1. */
	public int left;
	/** Index of the right child, initially -1. */
	public int right;

	/**
	 * Creates a node with frequency 1 and all three links set to -1.
	 *
	 * @param symb the symbol value to store
	 */
	public Node(int symb) {
		this.symbol = symb;
		this.frequency = 1;
		this.parent = this.left = this.right = -1;
	}

	// ---- accessors ----

	/** @return the stored symbol value */
	public int getSymbol() {
		return this.symbol;
	}

	/** @return the frequency counter */
	public int getFrequency() {
		return this.frequency;
	}

	/** @return the parent index */
	public int getParent() {
		return this.parent;
	}

	/** @return the left child index */
	public int getLeft() {
		return this.left;
	}

	/** @return the right child index */
	public int getRight() {
		return this.right;
	}

	// ---- mutators ----

	/** @param symb new symbol value */
	public void setSymbol(int symb) {
		this.symbol = symb;
	}

	/** @param freq new frequency value */
	public void setFrequency(int freq) {
		this.frequency = freq;
	}

	/** @param p new parent index */
	public void setParent(int p) {
		this.parent = p;
	}

	/** @param l new left child index */
	public void setLeft(int l) {
		this.left = l;
	}

	/** @param r new right child index */
	public void setRight(int r) {
		this.right = r;
	}
}

/**
 * Adaptive Huffman coder/decoder over byte values.
 *
 * <p>The tree is stored in a flat array of {@link Node}s: leaves are created
 * at indices 0..255, internal nodes follow, and the root is the last slot
 * ({@code nodeAmount - 1}). After every encoded or decoded symbol the tree is
 * updated by {@link #reconstructTree(int)}, so encoder and decoder must apply
 * exactly the same sequence of updates to stay in sync.
 *
 * <p>Coded file layout as produced by {@link #coding(String)}: the code bits
 * of every input byte, zero bits up to the next byte boundary, then eight bits
 * holding the number of padding bits that were added.
 */
public class HuffTree {
	private FreqTable table;
	private int nodeAmount;
	private Node[] tree;
	private BitOutputStream bOut;

	/** Builds the initial tree from a fresh {@link FreqTable}. */
	private HuffTree() {
		table = new FreqTable();
		// A binary tree with n leaves where every internal node has two children has 2n - 1 nodes.
		nodeAmount = 2 * table.length() - 1;

		tree = new Node[nodeAmount];
		createHuffTree();
	}

	/**
	 * (Re)builds the initial tree: one leaf per table entry at indices
	 * 0..table.length()-1, then adjacent nodes (k, k+1) are paired under a new
	 * parent appended at the next free index until the array is full.
	 */
	private void createHuffTree() {
		emptyHuffTree();

		int treeNode = 0;

		for (int i = 0; i < table.length(); i++) {
			tree[i] = new Node(table.getSymb(i));
			tree[i].setFrequency(table.getFreq(i));
			treeNode++;
		}

		int child = 0;

		while (child < nodeAmount - 1) {
			int sumFreq = tree[child].getFrequency() + tree[child + 1].getFrequency();
			// Internal nodes carry the placeholder symbol 0; only leaves hold real symbols.
			tree[treeNode] = new Node(0);
			tree[treeNode].setFrequency(sumFreq);
			tree[treeNode].setLeft(child);
			tree[treeNode].setRight(child + 1);
			tree[child].setParent(treeNode);
			tree[child + 1].setParent(treeNode);

			treeNode++;
			child += 2;
		}
	}

	/** Clears every slot of the node array. */
	private void emptyHuffTree() {
		for (int i = 0; i < nodeAmount; i++)
			tree[i] = null;
	}

	/**
	 * @param index node index
	 * @return true when the node has no children (its left link is still -1)
	 */
	private boolean isLeaf(int index) {
		return tree[index].getLeft() == -1;
	}

	/**
	 * Returns the current code of a symbol and then updates the tree for it.
	 *
	 * @param symb symbol value to encode
	 * @return the path from the root to the symbol's leaf as a string of
	 *         '0' (left) and '1' (right), or "" when no leaf holds the symbol
	 */
	private String getSymbolCode(int symb) {
		// Locate the leaf. The leaf check matters: internal nodes carry symbol 0 and
		// swap() can move them to lower indices than the real leaf for symbol 0, so a
		// symbol-only match could pick an internal node and desynchronise the decoder.
		int leaf = -1;
		for (int i = 0; i < nodeAmount; i++) {
			if (isLeaf(i) && tree[i].getSymbol() == symb) {
				leaf = i;
				break;
			}
		}
		if (leaf == -1)
			return "";

		// encoding: climb to the root collecting bits leaf-first, then reverse
		StringBuffer bits = new StringBuffer();
		int root = nodeAmount - 1;
		for (int i = leaf; i < root; i = tree[i].getParent()) {
			if (tree[tree[i].getParent()].getLeft() == i)
				bits.append('0');
			else
				bits.append('1');
		}
		reconstructTree(leaf);
		return bits.reverse().toString();
	}

	/**
	 * Updates the tree after the node at {@code node} has been used once more.
	 *
	 * <p>The node is first exchanged with the highest-indexed node of equal
	 * frequency (if there is one above it). Only when that first exchange
	 * happened are the ancestors processed the same way up to the root.
	 * Finally the frequency of the node's resulting position and of every
	 * ancestor of that position is incremented by one.
	 *
	 * @param node index of the node that was just encoded/decoded
	 */
	private void reconstructTree(int node) {

		int max = highest(tree[node].getFrequency(), node);
		int updated = max; // position the node occupies after the (possible) swap

		// if it is not the same node - keep swapping until the root is reached
		if (node != max) {
			swap(node, max);

			node = tree[max].getParent();
			int root = nodeAmount - 1;

			while (node < root) {
				max = highest(tree[node].getFrequency(), node);
				if (node != max)
					swap(node, max);
				node = tree[max].getParent();
			}
		}
		// increment the frequencies along the path to the root
		while (updated != -1) {
			tree[updated].setFrequency(tree[updated].getFrequency() + 1);
			updated = tree[updated].getParent();
		}
	}

	/**
	 * Exchanges the contents (children and symbol) of two array slots and
	 * re-points the moved children at their new parents. Frequencies and the
	 * slots' own parent links are left untouched.
	 *
	 * @param i first node index
	 * @param j second node index
	 */
	private void swap(int i, int j) {
		int l = tree[i].getLeft();
		int r = tree[i].getRight();
		int symb = tree[i].getSymbol();
		tree[i].setLeft(tree[j].getLeft());
		tree[i].setRight(tree[j].getRight());
		tree[i].setSymbol(tree[j].getSymbol());
		tree[j].setLeft(l);
		tree[j].setRight(r);
		tree[j].setSymbol(symb);

		// children that used to hang under i now hang under j
		if (l != -1) {
			tree[l].setParent(j);
			tree[r].setParent(j);
		}

		l = tree[i].getLeft();

		// children that used to hang under j now hang under i
		if (l != -1) {
			r = tree[i].getRight();
			tree[l].setParent(i);
			tree[r].setParent(i);
		}
	}

	/**
	 * Finds the highest-indexed node above {@code node} with the given frequency.
	 *
	 * @param freq frequency to look for
	 * @param node index below which the search stops
	 * @return the largest index greater than {@code node} whose frequency equals
	 *         {@code freq}, or {@code node} itself when there is none
	 */
	private int highest(int freq, int node) {
		for (int i = nodeAmount - 1; i > node; i--) {
			if (tree[i].getFrequency() == freq)
				return i;
		}
		return node;
	}

	/**
	 * Writes every character of {@code s} to the bit output as one bit, using
	 * the character's numeric value ('0' -> 0, '1' -> 1).
	 *
	 * @param s string of '0'/'1' characters
	 * @throws IOException if the underlying bit stream fails
	 */
	private void writeBits(String s) throws IOException {
		for (int i = 0; i < s.length(); i++) {
			char ch = s.charAt(i);
			int bit = Character.getNumericValue(ch);
			bOut.writeBit(bit);
		}
	}

	/**
	 * Left-pads a string with '0' characters.
	 *
	 * @param str    string to pad
	 * @param length minimum length of the result
	 * @return {@code str} prefixed with zeros up to {@code length} characters;
	 *         unchanged when it is already that long
	 */
	private String fullByte(String str, int length) {
		for (int i = str.length(); i < length; i++) {
			str = "0" + str;
		}
		return str;
	}

	/**
	 * Encodes a file into "coded.txt" in the working directory.
	 *
	 * @param fName path of the file to encode
	 * @throws FileNotFoundException if the input or output file cannot be opened
	 * @throws IOException           on any other I/O failure
	 */
	private void coding(String fName) throws IOException {
		InputStream fIn = new BufferedInputStream(new FileInputStream(fName));
		try {
			FileOutputStream fOut = new FileOutputStream("coded.txt");
			try {
				bOut = new BitOutputStream(fOut);

				// long: an int bit counter overflows once the output passes 256 MiB
				long totalBits = 0;
				int c;
				while ((c = fIn.read()) != -1) {
					String code = getSymbolCode(c);
					totalBits += code.length();
					writeBits(code);
				}

				// pad the payload with zero bits up to a byte boundary
				int usedInLastByte = (int) (totalBits % 8);
				int num = (usedInLastByte == 0) ? 0 : 8 - usedInLastByte;
				for (int i = 0; i < num; i++) {
					writeBits("0");
				}
				System.out.println("Uzkoduoto failo ilgis: " + ((totalBits + num) / 8 + 1));

				// trailer: eight bits holding the padding count, read back by decoding()
				writeBits(fullByte(Integer.toBinaryString(num), 8));

				bOut.flushToFile();
				bOut.close();
			} finally {
				// also runs on failure so the file handle is never leaked
				fOut.close();
			}
		} finally {
			fIn.close();
		}
	}

	/**
	 * Decodes as many complete symbols as possible from a bit string,
	 * updating the tree after each one.
	 *
	 * @param code    string of '0'/'1' characters, starting at a code boundary
	 * @param symbols list that receives each decoded symbol as an {@link Integer}
	 * @return the trailing bits that do not yet form a complete code
	 *         ("" when everything was consumed)
	 */
	private String getSymbols(String code, List symbols) {
		int root = nodeAmount - 1;
		int start = 0; // index of the first bit not yet turned into a symbol

		while (true) {
			int node = root;
			int pos = start;

			// walk down from the root until a leaf is hit or the bits run out
			while (!isLeaf(node) && pos < code.length()) {
				if (code.charAt(pos) == '0')
					node = tree[node].getLeft();
				else
					node = tree[node].getRight();
				pos++;
			}

			if (!isLeaf(node)) {
				// incomplete code: hand the unconsumed bits back to the caller
				return code.substring(start);
			}

			symbols.add(Integer.valueOf(tree[node].getSymbol()));
			reconstructTree(node);
			start = pos;
		}
	}

	/**
	 * Writes every decoded symbol as one raw byte and empties the list.
	 *
	 * @param symbols list of {@link Integer} symbol values
	 * @param out     destination stream
	 * @throws IOException if writing fails
	 */
	private void writeSymbols(List symbols, OutputStream out) throws IOException {
		for (int j = 0; j < symbols.size(); j++) {
			out.write(((Integer) symbols.get(j)).intValue());
		}
		symbols.clear();
	}

	/**
	 * Decodes a file produced by {@link #coding(String)} into "decoded.txt"
	 * in the working directory.
	 *
	 * @param inFile path of the coded file
	 * @throws FileNotFoundException if the input or output file cannot be opened
	 * @throws IOException           on any other I/O failure
	 */
	private void decoding(String inFile) throws IOException {
		// Read the whole input once; readFully guarantees the buffer is filled.
		File fInput = new File(inFile);
		byte[] buf = new byte[(int) fInput.length()];
		DataInputStream dIn = new DataInputStream(new FileInputStream(fInput));
		try {
			dIn.readFully(buf);
		} finally {
			dIn.close();
		}

		File oFile = new File("decoded.txt");
		// Raw byte output: a character writer would re-encode values >= 0x80
		// with the platform charset and corrupt the decoded data.
		OutputStream out = new BufferedOutputStream(new FileOutputStream(oFile));
		try {
			System.out.println("\nIsvesties failas: " + oFile);

			// The last byte holds the number of padding bits; payload is everything
			// before the padding and the trailer byte.
			long payloadBits = 0;
			if (buf.length > 0) {
				int padding = buf[buf.length - 1] & 0xFF;
				payloadBits = (long) buf.length * 8 - padding - 8;
			}

			BitInputStream bIn = new BitInputStream(buf);
			List symbols = new ArrayList();
			String remainder = "";                  // bits of a code split across chunks
			StringBuffer chunk = new StringBuffer(8);

			// decoding of the file, eight bits at a time
			for (long i = 0; i < payloadBits; i++) {
				chunk.append(bIn.readBit());
				if (chunk.length() == 8) {
					remainder = getSymbols(remainder + chunk, symbols);
					writeSymbols(symbols, out);
					chunk.setLength(0);
				}
			}
			// whatever is left when the payload is not a multiple of eight bits
			getSymbols(remainder + chunk, symbols);
			writeSymbols(symbols, out);

			bIn.close();
		} finally {
			out.close();
		}
	}

	/** Prints the command-line usage line to standard output. */
	private static void printUsage() {
		System.out.println("Naudojimas: java HuffTree c/d filename");
	}

	/**
	 * Command-line entry point: {@code c <file>} encodes, {@code d <file>} decodes.
	 *
	 * @param args mode ("c" or "d", case-insensitive) followed by the input file name
	 * @throws IOException on I/O failures other than a missing file
	 */
	public static void main(String[] args) throws IOException {
		if (args.length < 2) {
			printUsage();
			System.exit(0);
		}

		String mode = args[0];
		String inFile = args[1];
		boolean encode = mode.equalsIgnoreCase("c");
		boolean decode = mode.equalsIgnoreCase("d");

		if (!encode && !decode) {
			// previously an unknown mode was silently ignored
			printUsage();
			System.exit(0);
		}

		HuffTree t = new HuffTree();
		try {
			if (encode)
				t.coding(inFile);
			else
				t.decoding(inFile);
		} catch (FileNotFoundException e) {
			System.err.println("Failas " + inFile + " nerastas!");
			System.exit(0);
		}
	}
}