/*
 * Decompiled with CFR 0.152.
 */
package theGhastModding.lstmStuff.main;

import autodiff.Graph;
import datastructs.TensorDataSequence;
import datastructs.TensorDataSet;
import datastructs.TensorDataStep;
import java.awt.Color;
import java.awt.Graphics;
import java.awt.image.BufferedImage;
import java.awt.image.RenderedImage;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Random;
import javax.imageio.ImageIO;
import loss.LossSumOfSquares;
import matrix.Matrix;
import matrix.Tensor;
import model.ConvDropout;
import model.ConvLayer;
import model.ConvNet;
import model.DeconvLayer;
import model.TensorLayer;
import nonlinearities.Nonlinearity;
import nonlinearities.ReLuUnit;
import trainer.AMSGrad;
import trainer.NewTrainer;
import util.CifarLoader;
import util.FileIO;

public class DeconvTest {

    /** Folder containing the {@code inputs/} and {@code outputs/} training image folders. */
    private static final String DATA_FOLDER = "minceraft/";
    /** File the network is loaded from (if present) and checkpointed to. */
    private static final String SAVE_FILE = "deconvTest.dat";
    /** Maximum number of image pairs placed in one training sequence. */
    private static final int MAX_IMAGES_PER_SEQUENCE = 64;
    /** Number of calls made to the trainer. */
    private static final int ITERATIONS = 150;
    /** The model is checkpointed every this many iterations (iteration 0 excluded). */
    private static final int SAVE_INTERVAL = 25;

    /**
     * Entry point.
     *
     * <p>With first argument {@code LOLXD}, runs the (possibly loaded) network over every 16x16 PNG in
     * two hard-coded test folders and exits. Otherwise trains the network, saves it, writes one random
     * training input and the network's output for it as PNGs, and dumps every parameter matrix as a
     * grayscale PNG. Any exception is printed and the process exits with status 1.
     */
    public static void main(String[] args) {
        try {
            FaithfulDataSet dataset = new FaithfulDataSet(new File(DATA_FOLDER), MAX_IMAGES_PER_SEQUENCE);
            Random random = new Random();
            ConvNet net = buildNetwork(random);
            if (new File(SAVE_FILE).exists()) {
                FileIO.loadNeuralNetwork(SAVE_FILE, net);
                System.out.println("Neural network loaded");
            } else {
                System.out.println("No saved network found, starting from freshly initialized weights");
            }
            if (args.length != 0 && args[0].equals("LOLXD")) {
                // NOTE(review): machine-specific absolute paths; consider taking these from args.
                dataset.convertAll(new File("C:\\Users\\lucah\\workspace\\LSTMStuff\\minceraft\\test\\blocks\\"), net);
                System.out.println("50%");
                dataset.convertAll(new File("C:\\Users\\lucah\\workspace\\LSTMStuff\\minceraft\\test\\items\\"), net);
                System.out.println("Done.");
                System.exit(0);
            }
            train(net, dataset, random);
            writeSample(net, dataset, random);
            dumpWeights(net);
            System.exit(0);
        } catch (Exception e) {
            System.err.println("Error: ");
            e.printStackTrace();
            System.exit(1);
        }
    }

    /** Builds the conv / dropout / deconv / dropout / deconv stack, initializing weights from {@code random}. */
    private static ConvNet buildNetwork(Random random) throws Exception {
        ArrayList<TensorLayer> layers = new ArrayList<TensorLayer>();
        layers.add(new ConvLayer(16, 16, 3, 4, 4, 5, 2, 1, 0.08, random, true, 4));
        layers.add(new ConvDropout(0.1));
        layers.add(new DeconvLayer(2, 8, 8, 5, 5, 5, 5, 1, 2, 0.08, random, true, 4));
        layers.add(new ConvDropout(0.1));
        layers.add(new DeconvLayer(2, 16, 16, 5, 5, 5, 3, 1, 2, 0.08, random, true, 4));
        return new ConvNet(layers);
    }

    /**
     * Runs {@link #ITERATIONS} training calls with AMSGrad, checkpointing every {@link #SAVE_INTERVAL}
     * iterations and once more at the end.
     */
    private static void train(ConvNet net, FaithfulDataSet dataset, Random random) throws Exception {
        NewTrainer trainer = new NewTrainer(new AMSGrad(0.9, 0.999, 1.0E-4));
        for (int i = 0; i < ITERATIONS; i++) {
            System.err.println("Iteration " + i + "/" + ITERATIONS);
            trainer.trainConvNet(net, 0.001, 1, dataset, 1, SAVE_FILE, false, false, random);
            if (i != 0 && i % SAVE_INTERVAL == 0) {
                FileIO.saveNeuralNetwork(SAVE_FILE, net);
                System.out.println("Model saved");
            }
        }
        System.out.println("Done training");
        FileIO.saveNeuralNetwork(SAVE_FILE, net);
        System.out.println("Model saved");
    }

    /**
     * Picks a random training input and writes it and the network's output for it as
     * {@code deconvTest_original.png} / {@code deconvTest_output.png}. Skipped when the chosen
     * sequence is empty (only possible when no training images were loaded).
     */
    private static void writeSample(ConvNet net, FaithfulDataSet dataset, Random random) throws Exception {
        TensorDataSequence sequence =
                (TensorDataSequence) dataset.training.get(random.nextInt(dataset.training.size()));
        int length = (int) sequence.getSequenceLength();
        if (length == 0) {
            System.out.println("No training images available, skipping sample output");
            return;
        }
        Tensor input = sequence.getDataStep(random.nextInt(length)).input;
        BufferedImage original = CifarLoader.asImage(input);
        BufferedImage generated = CifarLoader.asImage(dataset.generate(input, net));
        ImageIO.write(original, "png", new File("deconvTest_original.png"));
        ImageIO.write(generated, "png", new File("deconvTest_output.png"));
    }

    /** Writes every parameter matrix of {@code net} as {@code weight<index>.png}. */
    private static void dumpWeights(ConvNet net) throws Exception {
        int index = 0;
        for (Matrix m : net.getParameters()) {
            ImageIO.write(weightsToImage(m), "png", new File("weight" + index + ".png"));
            index++;
        }
    }

    /**
     * Renders a matrix as a grayscale image (one pixel per entry, row -> y, column -> x).
     * Values are clamped to [0, 1] and mapped linearly to 0..255; NaN renders as black.
     */
    private static BufferedImage weightsToImage(Matrix m) {
        BufferedImage image = new BufferedImage(m.cols, m.rows, BufferedImage.TYPE_INT_RGB);
        for (int row = 0; row < m.rows; row++) {
            for (int col = 0; col < m.cols; col++) {
                double clamped = Math.min(1.0, Math.max(0.0, m.getW(row, col)));
                int gray = (int) (clamped * 255.0);
                image.setRGB(col, row, new Color(gray, gray, gray).getRGB());
            }
        }
        return image;
    }

    /**
     * Dataset of 16x16 input textures paired with same-named 32x32 target textures, read from
     * {@code <folder>/inputs} and {@code <folder>/outputs}. Transparent pixels are flattened onto black.
     */
    private static class FaithfulDataSet
    extends TensorDataSet {

        /** Required width and height of an input image. */
        private static final int IN_SIZE = 16;
        /** Required width and height of a target image. */
        private static final int OUT_SIZE = 32;

        /**
         * Loads all matching image pairs and groups them into training sequences of at most
         * {@code maxImagesPerSequence} steps each (only the last sequence may be shorter).
         *
         * @throws IllegalArgumentException if {@code maxImagesPerSequence <= 0}
         * @throws IOException if the inputs folder cannot be listed or an image cannot be read
         */
        public FaithfulDataSet(File folder, int maxImagesPerSequence) throws Exception {
            if (maxImagesPerSequence <= 0) {
                throw new IllegalArgumentException("maxImagesPerSequence must be positive: " + maxImagesPerSequence);
            }
            this.inputDimension = new TensorDataSet.TensorDimensions(16, 16, 3);
            this.outputDimension = new TensorDataSet.TensorDimensions(32, 32, 3);
            this.lossTraining = new LossSumOfSquares();
            this.lossReporting = new LossSumOfSquares();
            this.training = new ArrayList<>();

            File inputs = new File(folder.getPath() + "/inputs");
            File outputs = new File(folder.getPath() + "/outputs");
            File[] inputFiles = inputs.listFiles();
            if (inputFiles == null) {
                throw new IOException("Cannot list input images in " + inputs.getPath());
            }
            ArrayList<BufferedImage> originals = new ArrayList<BufferedImage>();
            ArrayList<BufferedImage> faithful = new ArrayList<BufferedImage>();
            for (File f : inputFiles) {
                if (!f.getName().endsWith(".png")) {
                    continue;
                }
                BufferedImage in = readIfSize(f, IN_SIZE);
                if (in == null) {
                    continue;
                }
                File outFile = new File(outputs.getPath() + "/" + f.getName());
                if (!outFile.exists()) {
                    continue;
                }
                BufferedImage out = readIfSize(outFile, OUT_SIZE);
                if (out == null) {
                    continue;
                }
                originals.add(this.fixTransparency(in));
                faithful.add(this.fixTransparency(out));
            }
            System.out.println("Loaded " + originals.size() + " images");

            // Start a new sequence only when the current one is full, so every sequence except
            // possibly the last holds exactly maxImagesPerSequence steps and none is left empty.
            TensorDataSequence current = new TensorDataSequence();
            this.training.add(current);
            int inCurrent = 0;
            for (int i = 0; i < originals.size(); i++) {
                if (inCurrent == maxImagesPerSequence) {
                    current = new TensorDataSequence();
                    this.training.add(current);
                    inCurrent = 0;
                }
                current.addDataStep(new TensorDataStep(
                        CifarLoader.asTensor(originals.get(i)), CifarLoader.asTensor(faithful.get(i))));
                inCurrent++;
            }
            System.out.println("Dataset loaded\r\n" + this.training.size() + " sequences");
        }

        /**
         * Reads {@code file} as an image.
         *
         * @return the image, or {@code null} if no ImageIO reader could decode it or it is not
         *     {@code size}x{@code size}
         */
        private static BufferedImage readIfSize(File file, int size) throws IOException {
            BufferedImage image = ImageIO.read(file);
            if (image == null || image.getWidth() != size || image.getHeight() != size) {
                return null;
            }
            return image;
        }

        /**
         * Runs {@code net} over every 16x16 PNG in {@code folder} and writes the results, under the
         * same file names, into a sibling folder named {@code <folder>_out}.
         *
         * @throws IOException if {@code folder} cannot be listed or the output folder cannot be created
         */
        private void convertAll(File folder, ConvNet net) throws Exception {
            File[] files = folder.listFiles();
            if (files == null) {
                throw new IOException("Cannot list images in " + folder.getPath());
            }
            File outputFolder = new File(folder.getPath() + "_out/");
            if (!outputFolder.isDirectory() && !outputFolder.mkdirs()) {
                throw new IOException("Cannot create output folder " + outputFolder.getPath());
            }
            for (File f : files) {
                if (!f.getName().endsWith(".png")) {
                    continue;
                }
                BufferedImage in = readIfSize(f, IN_SIZE);
                if (in == null) {
                    continue;
                }
                Tensor result = this.generate(CifarLoader.asTensor(this.fixTransparency(in)), net);
                File outFile = new File(outputFolder.getPath() + "/" + f.getName());
                ImageIO.write(CifarLoader.asImage(result), "png", outFile);
            }
        }

        /** Returns an opaque RGB copy of {@code in}, with transparent areas drawn over black. */
        private BufferedImage fixTransparency(BufferedImage in) {
            BufferedImage opaque = new BufferedImage(in.getWidth(), in.getHeight(), BufferedImage.TYPE_INT_RGB);
            Graphics g = opaque.getGraphics();
            try {
                g.drawImage(in, 0, 0, opaque.getWidth(), opaque.getHeight(), Color.BLACK, null);
            } finally {
                g.dispose();
            }
            return opaque;
        }

        /**
         * Runs a single forward pass of {@code net} on {@code t}. The network state is reset before
         * the pass and afterwards, even if the pass throws.
         */
        public Tensor generate(Tensor t, ConvNet net) throws Exception {
            net.resetState();
            try {
                Graph g = new Graph(false);
                return net.forward(t, g);
            } finally {
                net.resetState();
            }
        }

        /** Placeholder report: only prints "a" to stderr. */
        @Override
        public void DisplayReport(ConvNet model, Random rng) throws Exception {
            System.err.println("a");
        }

        /** Uses a ReLU output unit. */
        @Override
        public Nonlinearity getModelOutputUnitToUse() {
            return new ReLuUnit();
        }
    }
}

