/*
 * Decompiled with CFR 0.152.
 */
package datasets;

import autodiff.Graph;
import datastructs.DataSequence;
import datastructs.DataSet;
import datastructs.DataStep;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Random;
import java.util.zip.DataFormatException;
import java.util.zip.Deflater;
import java.util.zip.Inflater;
import loss.LossSoftmax;
import matrix.Matrix;
import model.Model;
import nonlinearities.LinearUnit;
import nonlinearities.Nonlinearity;
import theGhastModding.console.main.Console;
import util.Util;

/**
 * A {@link DataSet} built from the compressed bytes of a set of files.
 *
 * <p>Each file is compressed with {@link Deflater} at level 9 and turned into one sequence of
 * one-hot vectors of width 257: indices 0-255 encode a byte value and index 256 is a marker
 * placed at both the start and the end of every sequence. Each step maps one vector to the
 * next, so the model is trained to predict the next compressed byte. Files whose compressed
 * form is longer than 10000 bytes are skipped.
 */
public class CompressedSet
extends DataSet {
    /** Width of the one-hot vectors: 256 byte values plus one start/end marker. */
    private static final int VECTOR_SIZE = 257;
    /** Index of the start/end marker within a one-hot vector. */
    private static final int MARKER_INDEX = 256;
    /** Longest compressed byte stream used for training or accepted from generation. */
    private static final int MAX_COMPRESSED_LENGTH = 10000;
    /** Shortest generated compressed output accepted by {@link #generateOutput}. */
    private static final int MIN_GENERATED_LENGTH = 300;

    /**
     * Builds the data set by compressing every given file and converting each compressed
     * byte stream into a training sequence.
     *
     * @param f the files to load; each one is read completely into memory
     * @throws IOException if a file cannot be read completely or is too large for a byte array
     * @throws Exception declared for compatibility with existing callers
     */
    public CompressedSet(File[] f) throws Exception {
        System.out.println("Loading compressed set");
        System.out.println("Compressing files");
        byte[][] allCompressedBytes = new byte[f.length][];
        // One deflater reused for all files; reset() keeps the compression level.
        Deflater df = new Deflater(Deflater.BEST_COMPRESSION);
        try {
            for (int i = 0; i < f.length; ++i) {
                System.out.println((i + 1) + "/" + f.length);
                allCompressedBytes[i] = compress(readFully(f[i]), df);
            }
        } finally {
            // The deflater's native memory is only released by end(), not by reset().
            df.end();
        }
        System.gc();
        System.out.println("Preparing dataset");
        this.training = new ArrayList<DataSequence>();
        int length = 0;
        // A single marker vector is shared by all sequences; it is only read, never written.
        double[] startEndVec = new double[VECTOR_SIZE];
        startEndVec[MARKER_INDEX] = 1.0;
        for (int n = 0; n < allCompressedBytes.length; ++n) {
            byte[] compressed = allCompressedBytes[n];
            // Drop the reference so each file's bytes can be collected once converted.
            allCompressedBytes[n] = null;
            if (compressed.length > MAX_COMPRESSED_LENGTH) {
                continue;
            }
            this.training.add(buildSequence(compressed, startEndVec));
            // A sequence of k bytes framed by two markers has k + 1 steps.
            length += compressed.length + 1;
        }
        this.lossTraining = new LossSoftmax();
        this.lossReporting = new LossSoftmax();
        this.inputDimension = VECTOR_SIZE;
        this.outputDimension = VECTOR_SIZE;
        System.gc();
        System.out.println("done. " + length + " steps in dataset");
    }

    /**
     * Reads the entire contents of a file.
     *
     * <p>This loops because a single {@code read} call may return fewer bytes than requested.
     */
    private static byte[] readFully(File file) throws IOException {
        long size = file.length();
        if (size > Integer.MAX_VALUE) {
            throw new IOException("File too large to load: " + file);
        }
        byte[] bytes = new byte[(int) size];
        FileInputStream fis = new FileInputStream(file);
        try {
            int offset = 0;
            while (offset < bytes.length) {
                int read = fis.read(bytes, offset, bytes.length - offset);
                if (read < 0) {
                    throw new IOException("Unexpected end of file: " + file);
                }
                offset += read;
            }
        } finally {
            fis.close();
        }
        return bytes;
    }

    /** Compresses {@code data} with {@code df}, then resets {@code df} so it can be reused. */
    private static byte[] compress(byte[] data, Deflater df) {
        df.setInput(data);
        df.finish();
        ByteArrayOutputStream out = new ByteArrayOutputStream(data.length);
        byte[] buffer = new byte[1024];
        while (!df.finished()) {
            int count = df.deflate(buffer);
            out.write(buffer, 0, count);
        }
        df.reset();
        return out.toByteArray();
    }

    /**
     * Converts compressed bytes into a next-byte prediction sequence.
     *
     * <p>The steps are marker -> b0, b0 -> b1, ..., b(k-1) -> marker.
     */
    private static DataSequence buildSequence(byte[] compressed, double[] startEndVec) {
        DataSequence sequence = new DataSequence();
        double[] previous = startEndVec;
        for (byte b : compressed) {
            double[] vec = new double[VECTOR_SIZE];
            vec[b & 0xFF] = 1.0;
            sequence.steps.add(new DataStep(previous, vec));
            previous = vec;
        }
        sequence.steps.add(new DataStep(previous, startEndVec));
        return sequence;
    }

    /**
     * Samples a compressed byte stream from the model and decompresses it.
     *
     * <p>Sampling starts from the marker vector. Markers sampled before the first byte are
     * ignored; a marker sampled after at least one byte ends the output. The model's state is
     * reset when sampling ends, whether it succeeds or fails.
     *
     * @param model the model to sample from
     * @param temperature softmax temperature passed to {@link LossSoftmax#getSoftmaxProbs}
     * @param rng random source for sampling
     * @return the decompressed bytes
     * @throws Exception "Invalid output length" if more than 10000 or fewer than 300 compressed
     *     bytes are produced, or if the sampled bytes are not a complete compressed stream
     */
    public static byte[] generateOutput(Model model, double temperature, Random rng) throws Exception {
        ArrayList<Byte> byteList = new ArrayList<Byte>();
        try {
            double[] startVec = new double[VECTOR_SIZE];
            startVec[MARKER_INDEX] = 1.0;
            Matrix startMatrix = new Matrix(VECTOR_SIZE);
            startMatrix.w = startVec;
            Matrix inOut = startMatrix.clone();
            Graph g = new Graph(false);
            boolean started = false;
            while (true) {
                Matrix res = model.forward(inOut, g);
                Matrix probs = LossSoftmax.getSoftmaxProbs(res, temperature);
                int indx = Util.pickIndexFromRandomVector(probs, rng);
                if (indx == MARKER_INDEX) {
                    if (started) {
                        break;
                    }
                    // A marker before any byte is ignored; the input stays the start marker.
                } else {
                    started = true;
                    byteList.add((byte) indx);
                    inOut.w = new double[VECTOR_SIZE];
                    inOut.w[indx] = 1.0;
                }
                if (byteList.size() > MAX_COMPRESSED_LENGTH) {
                    throw new Exception("Invalid output length");
                }
            }
        } finally {
            // Reset even on failure so the next generation does not start from a stale state.
            model.resetState();
        }
        byte[] byteArray = new byte[byteList.size()];
        for (int i = 0; i < byteArray.length; ++i) {
            byteArray[i] = byteList.get(i);
        }
        if (byteArray.length < MIN_GENERATED_LENGTH) {
            throw new Exception("Invalid output length");
        }
        System.out.println("Compressed size is: " + byteArray.length);
        byte[] uncompressedByteArray = CompressedSet.decompressShit(byteArray);
        System.out.println("Uncompressed size is: " + uncompressedByteArray.length);
        return uncompressedByteArray;
    }

    /**
     * Inflates a stream produced in the same format as {@link Deflater}'s default output.
     *
     * @throws DataFormatException if the data is invalid, or ends before the stream is complete
     */
    private static byte[] decompressShit(byte[] toDecompress) throws Exception {
        Inflater in = new Inflater();
        try {
            in.setInput(toDecompress);
            ByteArrayOutputStream outputStream = new ByteArrayOutputStream(toDecompress.length);
            byte[] buffer = new byte[1024];
            while (!in.finished()) {
                int count = in.inflate(buffer);
                // inflate() returns 0 forever on a truncated stream; without this check the
                // loop would never terminate.
                if (count == 0 && !in.finished() && (in.needsInput() || in.needsDictionary())) {
                    throw new DataFormatException("Truncated or incomplete compressed data");
                }
                outputStream.write(buffer, 0, count);
            }
            return outputStream.toByteArray();
        } finally {
            // Release the inflater's native memory; reset() alone does not free it.
            in.end();
        }
    }

    /** Prints the model's median perplexity over the training sequences. */
    @Override
    public void DisplayReport(Model model, Random rng, Console c) throws Exception {
        System.out.println("\ncalculating perplexity over entire data set...");
        double perplexity = LossSoftmax.calculateMedianPerplexity(model, this.training);
        System.out.println("\nMedian Perplexity = " + String.format("%.4f", perplexity));
        System.out.println("Yeah. You can't exactly display a more detailed log on this. Sorry.");
    }

    /**
     * Returns a linear output unit.
     *
     * <p>The softmax is applied outside the model, by {@link LossSoftmax} during training and
     * by {@code getSoftmaxProbs} in {@link #generateOutput}.
     */
    @Override
    public Nonlinearity getModelOutputUnitToUse() {
        return new LinearUnit();
    }
}

