/*
 * Decompiled with CFR 0.152.
 */
package theGhastModding.lstmStuff.main;

import WavFile.WavFile;
import autodiff.GPUGraph;
import autodiff.Graph;
import datastructs.DataSequence;
import datastructs.DataSet;
import datastructs.DataStep;
import edu.cornell.lassp.houle.RngPack.RanMT;
import java.awt.Color;
import java.awt.Graphics2D;
import java.awt.image.BufferedImage;
import java.awt.image.RenderedImage;
import java.io.File;
import java.util.ArrayList;
import java.util.Random;
import javax.imageio.ImageIO;
import loss.LossSumOfSquares;
import matrix.Matrix;
import matrix.Tensor;
import model.ConvLayer;
import model.ConvUpsample;
import model.NeuralNetwork;
import model.NormalizeLayer;
import nonlinearities.LinearUnit;
import nonlinearities.Nonlinearity;
import nonlinearities.RectifiedLinearUnit;
import theGhastModding.lstmStuff.main.VideoEncoder;
import trainer.Adam;
import trainer.Trainer;
import util.CLUtils;
import util.CifarLoader;
import util.FileIO;
import util.NNDevice;

public class VideoNN {
    // Encoder network; created and populated in loadNets(), written out by saveNets().
    private static NeuralNetwork encoder;
    // Decoder network; created and populated in loadNets(), written out by saveNets().
    private static NeuralNetwork decoder;
    // Container network that holds the encoder followed by the decoder (assembled in loadNets());
    // this is the model that gets trained and evaluated.
    private static NeuralNetwork fullCompressor;
    // NOTE(review): not referenced anywhere in this file; the "vidAI_enc.dat"/"vidAI_dec.dat"
    // file names used by loadNets()/saveNets() are spelled out as literals instead.
    private static final String saveLocation = "vidAI";

    /**
     * Program entry point: picks a compute device, loads (or freshly builds) the
     * networks, runs a short training session and then runs the compressor test.
     *
     * <p>The process always terminates through {@link System#exit(int)}: status
     * {@code 0} when every step completed, status {@code 1} when any step threw.
     *
     * @param args command line arguments (ignored)
     */
    public static void main(String[] args) {
        try {
            // Validate the lookup result before indexing it, so a missing device is
            // reported with a meaningful message instead of an index/NPE failure.
            NNDevice[] devices = CLUtils.findDevice("AMD", "Tahiti");
            if (devices == null || devices.length == 0 || devices[0] == null) {
                throw new IllegalStateException("No compute device matching \"AMD\" / \"Tahiti\" was found");
            }
            NNDevice dev = devices[0];
            VideoNN.loadNets();
            VideoNN.trainCompressor(5.0E-4, 10, 2, 2, dev);
            // NOTE(review): this literal contains doubled backslashes, unlike the frame
            // folder literal used by trainCompressor(); left untouched on purpose.
            VideoNN.testCompressor(new File("pony town stuff\\\\video frames\\\\"), dev);
            System.out.println("Done.");
            // Success must be reported with a zero exit status.
            System.exit(0);
        }
        catch (Exception e) {
            System.err.println("Error: ");
            e.printStackTrace();
            System.exit(1);
        }
    }

    /**
     * Builds the encoder and decoder networks, restores their weights from
     * {@code vidAI_enc.dat} / {@code vidAI_dec.dat} when those files exist, chains
     * both into {@code fullCompressor} and prints the total parameter count.
     *
     * @throws Exception if constructing a layer or loading a saved network fails
     */
    public static void loadNets() throws Exception {
        System.out.println("Loading models...");
        final int cores = 3;

        encoder = new NeuralNetwork();
        VideoNN.populateEncoder(encoder, cores);

        decoder = new NeuralNetwork();
        VideoNN.populateDecoder(decoder, cores);

        // Saved weights are optional: a missing file simply leaves the fresh initialisation.
        File encoderFile = new File("vidAI_enc.dat");
        if (encoderFile.exists()) {
            FileIO.loadNeuralNetwork("vidAI_enc.dat", encoder);
        }
        File decoderFile = new File("vidAI_dec.dat");
        if (decoderFile.exists()) {
            FileIO.loadNeuralNetwork("vidAI_dec.dat", decoder);
        }

        fullCompressor = new NeuralNetwork();
        fullCompressor.addLayer(encoder);
        fullCompressor.addLayer(decoder);

        long parameterCount = 0L;
        for (Matrix params : fullCompressor.getParameters()) {
            parameterCount += (long)params.w.length;
        }
        System.out.println("Loaded. " + parameterCount + " total parameters.");
    }

    /**
     * Appends the encoder layer stack to {@code net}. The constructor arguments are
     * kept exactly as tuned; their meaning is defined by ConvLayer/NormalizeLayer.
     */
    private static void populateEncoder(NeuralNetwork net, int cores) throws Exception {
        net.addLayer(new ConvLayer(216, 120, 3, 3, 3, 64, 1, 1, new LinearUnit(), 0.00358, new Random(), true, true, cores));
        net.addLayer(new ConvLayer(216, 120, 64, 4, 4, 96, 2, 1, new RectifiedLinearUnit(0.05), 0.001098013, new Random(), true, true, cores));
        net.addLayer(new NormalizeLayer(1.0E-8, 0.08, new Random()));

        net.addLayer(new ConvLayer(108, 60, 96, 3, 3, 96, 1, 1, new LinearUnit(), 0.00155282498, new Random(), true, true, cores));
        net.addLayer(new ConvLayer(108, 60, 96, 4, 4, 128, 2, 1, new RectifiedLinearUnit(0.05), 0.00155282498, new Random(), true, true, cores));
        net.addLayer(new NormalizeLayer(1.0E-8, 0.08, new Random()));

        net.addLayer(new ConvLayer(54, 30, 128, 3, 3, 128, 1, 1, new LinearUnit(), 0.00253575258, new Random(), true, true, cores));
        net.addLayer(new ConvLayer(54, 30, 128, 4, 4, 140, 2, 1, new RectifiedLinearUnit(0.05), 0.00253575258, new Random(), true, true, cores));
        net.addLayer(new NormalizeLayer(1.0E-8, 0.08, new Random()));

        net.addLayer(new ConvLayer(27, 15, 140, 3, 3, 140, 1, 1, new LinearUnit(), 0.00439205, new Random(), true, true, cores));
        net.addLayer(new ConvLayer(27, 15, 140, 3, 3, 150, 2, 0, new RectifiedLinearUnit(0.05), 0.00439205, new Random(), true, true, cores));
        net.addLayer(new NormalizeLayer(1.0E-8, 0.08, new Random()));

        net.addLayer(new ConvLayer(13, 7, 150, 3, 3, 128, 1, 1, new LinearUnit(), 0.009265616, new Random(), true, true, cores));
    }

    /**
     * Appends the decoder layer stack to {@code net}: four groups that each start
     * with a ConvUpsample(2) followed by two convolution layers.
     */
    private static void populateDecoder(NeuralNetwork net, int cores) throws Exception {
        net.addLayer(new ConvUpsample(2));
        net.addLayer(new ConvLayer(26, 14, 128, 5, 5, 150, 1, 3, new RectifiedLinearUnit(0.2), 0.004632, new Random(), true, true, cores));
        net.addLayer(new ConvLayer(28, 16, 150, 3, 3, 150, 1, 1, new RectifiedLinearUnit(0.2), 0.004175, new Random(), true, true, cores));
        net.addLayer(new NormalizeLayer(1.0E-8, 0.08, new Random()));

        net.addLayer(new ConvUpsample(2));
        net.addLayer(new ConvLayer(56, 32, 150, 5, 5, 140, 1, 1, new RectifiedLinearUnit(0.2), 0.00208797, new Random(), true, true, cores));
        net.addLayer(new ConvLayer(54, 30, 140, 3, 3, 140, 1, 1, new RectifiedLinearUnit(0.2), 0.00253575, new Random(), true, true, cores));
        net.addLayer(new NormalizeLayer(1.0E-8, 0.08, new Random()));

        net.addLayer(new ConvUpsample(2));
        net.addLayer(new ConvLayer(108, 60, 140, 5, 5, 128, 1, 2, new RectifiedLinearUnit(0.2), 0.00126787, new Random(), true, true, cores));
        net.addLayer(new ConvLayer(108, 60, 128, 3, 3, 96, 1, 1, new RectifiedLinearUnit(0.2), 0.00155282, new Random(), true, true, cores));
        net.addLayer(new NormalizeLayer(1.0E-8, 0.08, new Random()));

        net.addLayer(new ConvUpsample(2));
        net.addLayer(new ConvLayer(216, 120, 96, 5, 5, 64, 1, 2, new RectifiedLinearUnit(0.2), 0.001, new Random(), true, true, cores));
        net.addLayer(new ConvLayer(216, 120, 64, 3, 3, 3, 1, 1, new LinearUnit(), 0.001, new Random(), true, true, cores));
    }

    /**
     * Writes the encoder to {@code vidAI_enc.dat} and the decoder to
     * {@code vidAI_dec.dat}, logging before and after.
     *
     * @throws Exception if either network cannot be written
     */
    public static void saveNets() throws Exception {
        final String encoderPath = "vidAI_enc.dat";
        final String decoderPath = "vidAI_dec.dat";

        System.out.println("Saving models...");
        FileIO.saveNeuralNetwork(encoderPath, encoder);
        FileIO.saveNeuralNetwork(decoderPath, decoder);
        System.out.println("Saved.");
    }

    /**
     * Trains {@code fullCompressor} for a number of rounds. Every round draws a
     * fresh {@link CompressorDataset} (sequence length 16) from the frame folder,
     * trains on it and then saves both networks.
     *
     * @param learningRate learning rate handed to the trainer
     * @param iterations   number of dataset/train/save rounds
     * @param epochs       epochs per round, handed to the trainer
     * @param numSequences number of sequences in each freshly drawn dataset
     * @param dev          device the trainer is created with
     * @throws Exception if dataset creation, training or saving fails
     */
    private static void trainCompressor(double learningRate, int iterations, int epochs, int numSequences, NNDevice dev) throws Exception {
        Trainer trainer = new Trainer(new Adam(0.9, 0.995, 1.0E-8), dev);
        for (int completed = 0; completed < iterations; completed++) {
            System.out.println("Iteration " + (completed + 1) + "/" + iterations);
            File framesFolder = new File("pony town stuff\\video frames\\");
            CompressorDataset dataset = new CompressorDataset(framesFolder, numSequences, 16);
            trainer.train(fullCompressor, learningRate, epochs, dataset, 1, null, false, false, null);
            // Checkpoint after every round.
            VideoNN.saveNets();
        }
    }

    /**
     * Turns a run of prepared frames back into a video plus an audio track.
     *
     * <p>Frames {@code videoNNNNNN.bmp} are read from {@code imagesLocation} and
     * encoded into {@code test.mp4}. For each frame, 367 or 368 audio samples
     * (alternating per frame) are recovered from the red channel of the top-left
     * pixel of consecutive 2x2 cells, mapped from [0, 255] to [-1, 1], and
     * appended to {@code test_audio.wav}.
     *
     * @param imagesLocation folder holding the numbered frame files
     * @param doAll          {@code true} to process every file starting at frame 1;
     *                       {@code false} to process a randomly placed window of up
     *                       to 1800 frames
     * @throws IllegalArgumentException if {@code imagesLocation} cannot be listed
     * @throws Exception                if reading a frame, encoding or writing audio fails
     */
    public static void test(File imagesLocation, boolean doAll) throws Exception {
        String[] entries = imagesLocation.list();
        if (entries == null) {
            throw new IllegalArgumentException("Not a readable directory: " + imagesLocation);
        }
        int fileCount = entries.length;
        // Never ask for more frames than exist; otherwise the random window start
        // below would become negative and point at non-existent files.
        int frameLength = doAll ? fileCount : Math.min(1800, fileCount);
        int randomStartIndx = doAll ? 0 : (int)(new RanMT().raw() * (double)(fileCount - frameLength));

        VideoEncoder enc = new VideoEncoder(new File("test.mp4"), 30, 216, 120, 8, "placebo", 23, true);
        WavFile audio = null;
        try {
            audio = WavFile.newWavFile(new File("test_audio.wav"), 1, (int)Math.ceil(367.5 * (double)frameLength), 16, 11025L);
            for (int i = 0; i < frameLength; i++) {
                File imgFile = new File(imagesLocation, "video" + String.format("%06d", randomStartIndx + i + 1) + ".bmp");
                BufferedImage currentImage = ImageIO.read(imgFile);
                if (currentImage == null) {
                    // ImageIO.read returns null when no registered reader understands the file.
                    throw new IllegalStateException("Unsupported or corrupt image file: " + imgFile);
                }
                enc.encodeFrame(currentImage);

                // Sample count alternates 367/368 so that two frames carry 735 samples.
                double[] sampleBuffer = new double[367 + (randomStartIndx + i) % 2];
                int cellsPerRow = currentImage.getWidth() / 2;
                for (int j = 0; j < sampleBuffer.length; j++) {
                    int x = j % cellsPerRow * 2;
                    int y = j / cellsPerRow * 2;
                    Color c = new Color(currentImage.getRGB(x, y));
                    sampleBuffer[j] = (double)c.getRed() / 255.0 * 2.0 - 1.0;
                }
                audio.writeFrames(sampleBuffer, sampleBuffer.length);
            }
        }
        finally {
            // Release both outputs even when a frame fails, in the same order the
            // success path always used: finish the video first, then close the audio.
            // NOTE(review): assumes finishEncode() is safe to call after a partial encode - confirm.
            try {
                enc.finishEncode();
            }
            finally {
                if (audio != null) {
                    audio.close();
                }
            }
        }
    }

    /**
     * Runs a randomly placed window of up to 450 frames through
     * {@code fullCompressor}, stores the reconstructed frames in a temporary
     * {@code compressor test/} folder, feeds that folder to
     * {@link #test(File, boolean)} and finally removes the folder again.
     *
     * <p>The chosen start index is printed to {@code System.err}.
     *
     * @param imagesLocation folder holding the numbered source frames
     * @param dev            device used to create the evaluation graph
     * @throws IllegalArgumentException if {@code imagesLocation} cannot be listed or is empty
     * @throws Exception                if reading, inference, writing or re-encoding fails
     */
    public static void testCompressor(File imagesLocation, NNDevice dev) throws Exception {
        String[] entries = imagesLocation.list();
        if (entries == null || entries.length == 0) {
            throw new IllegalArgumentException("No frames found in: " + imagesLocation);
        }
        int fileCount = entries.length;
        // Clamp the window so folders with fewer than 450 files do not yield a
        // negative start index.
        int frameCount = Math.min(450, fileCount);
        int randomStartIndx = (int)(new RanMT().raw() * (double)(fileCount - frameCount));
        System.err.println(randomStartIndx);

        File outputFolder = new File("compressor test/");
        outputFolder.mkdir();
        try {
            fullCompressor.resetState();
            GPUGraph g = CLUtils.createGraph(dev, false);
            try {
                for (int i = 0; i < frameCount; i++) {
                    File imgFile = new File(imagesLocation, "video" + String.format("%06d", randomStartIndx + i + 1) + ".bmp");
                    BufferedImage currentImage = ImageIO.read(imgFile);
                    if (currentImage == null) {
                        throw new IllegalStateException("Unsupported or corrupt image file: " + imgFile);
                    }
                    BufferedImage out = CifarLoader.asImage(fullCompressor.forward(CifarLoader.asTensor(currentImage), (Graph)g));
                    // Written as PNG data under a ".bmp" name: test() looks the frames up
                    // by that name, and ImageIO.read picks the decoder from the file content.
                    File outFile = new File(outputFolder, "video" + String.format("%06d", i + 1) + ".bmp");
                    ImageIO.write((RenderedImage)out, "png", outFile);
                }
            }
            finally {
                // Always release the graph, even when a frame fails.
                try {
                    fullCompressor.resetState();
                }
                finally {
                    ((Graph)g).cleanUp();
                }
            }
            VideoNN.test(outputFolder, true);
        }
        finally {
            // Best-effort removal of the temporary folder; failures are reported, not fatal.
            File[] leftovers = outputFolder.listFiles();
            if (leftovers != null) {
                for (File f : leftovers) {
                    if (!f.delete()) {
                        System.err.println("Could not delete temporary file: " + f);
                    }
                }
            }
            if (outputFolder.exists() && !outputFolder.delete()) {
                System.err.println("Could not delete temporary folder: " + outputFolder);
            }
        }
    }

    /**
     * Prepares training frames in place: every {@code videoNNNNNN.bmp} in
     * {@code imagesLocation} is rescaled to 216x120 and the matching slice of the
     * audio track is painted into its top-left area as grey 2x2 cells.
     *
     * <p>Each frame consumes 367 or 368 samples (alternating). A sample in
     * [-1, 1] is mapped to a grey level in [0, 255] (clamped). The result is
     * written back over the source file as PNG data. Progress is printed about
     * once per second.
     *
     * @param imagesLocation folder holding the numbered frame files (modified in place)
     * @param audioFile      WAV file that must have a sample rate of 11025 Hz
     * @throws IllegalArgumentException if {@code imagesLocation} cannot be listed
     * @throws Exception                if the sample rate is wrong or any read/write fails
     */
    public static void prepImages(File imagesLocation, File audioFile) throws Exception {
        WavFile audio = WavFile.openWavFile(audioFile);
        try {
            if (audio.getSampleRate() != 11025L) {
                throw new Exception("Invalid sample rate, has to be 11025Hz");
            }
            String[] entries = imagesLocation.list();
            if (entries == null) {
                throw new IllegalArgumentException("Not a readable directory: " + imagesLocation);
            }
            int fileCount = entries.length;
            long lastReport = System.currentTimeMillis();
            for (int i = 1; i <= fileCount; i++) {
                if (System.currentTimeMillis() - lastReport >= 1000L) {
                    lastReport = System.currentTimeMillis();
                    System.out.println(String.format("%05d", i) + "/" + String.format("%05d", fileCount));
                }
                File imgFile = new File(imagesLocation, "video" + String.format("%06d", i) + ".bmp");
                BufferedImage currentImage = ImageIO.read(imgFile);
                if (currentImage == null) {
                    // Refuse to overwrite a file we could not decode with a blank frame.
                    throw new IllegalStateException("Unsupported or corrupt image file: " + imgFile);
                }

                BufferedImage outputImage = new BufferedImage(216, 120, BufferedImage.TYPE_INT_RGB);
                Graphics2D g = outputImage.createGraphics();
                try {
                    g.drawImage(currentImage, 0, 0, 216, 120, null);
                }
                finally {
                    g.dispose();
                }

                // Sample count alternates 367/368 so that two frames carry 735 samples.
                double[] sampleBuffer = new double[367 + (i + 1) % 2];
                audio.readFrames(sampleBuffer, sampleBuffer.length);
                int cellsPerRow = outputImage.getWidth() / 2;
                for (int j = 0; j < sampleBuffer.length; j++) {
                    int grey = (int)((sampleBuffer[j] + 1.0) / 2.0 * 255.0);
                    grey = Math.max(0, Math.min(255, grey));
                    int rgb = new Color(grey, grey, grey).getRGB();
                    int x = j % cellsPerRow * 2;
                    int y = j / cellsPerRow * 2;
                    // Fill the whole 2x2 cell with the sample's grey value.
                    outputImage.setRGB(x, y, rgb);
                    outputImage.setRGB(x + 1, y, rgb);
                    outputImage.setRGB(x, y + 1, rgb);
                    outputImage.setRGB(x + 1, y + 1, rgb);
                }
                ImageIO.write((RenderedImage)outputImage, "png", imgFile);
            }
        }
        finally {
            // The reader was previously never closed.
            audio.close();
        }
    }

    /**
     * Autoencoder training set: every data step uses a randomly chosen frame both
     * as input and (as a clone) as target, with sum-of-squares loss for training
     * and reporting. Input and output dimensions are 216 x 120 x 3.
     */
    private static class CompressorDataset
    extends DataSet {
        /**
         * Builds {@code numSequences} sequences of {@code sequenceLength} randomly
         * chosen frames read from {@code videoNNNNNN.bmp} files.
         *
         * @param imagesFolder   folder holding the numbered frame files
         * @param numSequences   number of sequences to create
         * @param sequenceLength number of frames per sequence
         * @throws IllegalArgumentException if {@code imagesFolder} cannot be listed or is empty
         * @throws Exception                if a frame cannot be read
         */
        public CompressorDataset(File imagesFolder, int numSequences, int sequenceLength) throws Exception {
            this.training = new ArrayList<>();
            this.inputDimension = new DataSet.TensorDimensions(216, 120, 3);
            this.outputDimension = new DataSet.TensorDimensions(216, 120, 3);
            this.lossTraining = new LossSumOfSquares();
            this.lossReporting = new LossSumOfSquares();

            String[] entries = imagesFolder.list();
            if (entries == null || entries.length == 0) {
                throw new IllegalArgumentException("No frames found in: " + imagesFolder);
            }
            int imageCount = entries.length;
            RanMT rn = new RanMT();
            for (int i = 0; i < numSequences; i++) {
                DataSequence seq = new DataSequence();
                for (int j = 0; j < sequenceLength; j++) {
                    // Frame numbers are 1-based. The previous formula added the sequence
                    // counter to the random index and so ran past the last frame once
                    // numSequences exceeded 2; draw uniformly from 1..imageCount instead.
                    // The min() guards against raw() ever returning exactly 1.0.
                    int frameNumber = Math.min((int)(rn.raw() * (double)imageCount) + 1, imageCount);
                    File frameFile = new File(imagesFolder, "video" + String.format("%06d", frameNumber) + ".bmp");
                    BufferedImage currImage = ImageIO.read(frameFile);
                    if (currImage == null) {
                        throw new IllegalStateException("Unsupported or corrupt image file: " + frameFile);
                    }
                    Tensor t = CifarLoader.asTensor(currImage);
                    // Target is a copy of the input: the network learns to reproduce the frame.
                    seq.addDataStep(new DataStep(t, t.clone()));
                }
                this.training.add(seq);
            }
            System.out.println((long)numSequences * (long)sequenceLength + " images in dataset");
        }

        /** Placeholder report: only prints a marker line. */
        @Override
        public void DisplayReport(NeuralNetwork model, Random rng) throws Exception {
            System.out.println("a");
        }

        /** @return a fresh {@link LinearUnit} as the output nonlinearity */
        @Override
        public Nonlinearity getModelOutputUnitToUse() {
            return new LinearUnit();
        }
    }
}

