/*
 * Decompiled with CFR 0.152.
 */
package theGhastModding.lstmStuff.main;

import autodiff.Graph;
import datastructs.DataSequence;
import datastructs.DataSet;
import datastructs.DataStep;
import java.awt.Color;
import java.awt.Dimension;
import java.awt.Graphics;
import java.awt.Image;
import java.awt.image.BufferedImage;
import java.awt.image.RenderedImage;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Random;
import javax.imageio.ImageIO;
import javax.swing.JButton;
import javax.swing.JDialog;
import loss.LossSumOfSquares;
import matrix.Tensor;
import model.ConvLayer;
import model.ConvUpsample;
import model.NeuralNetwork;
import nonlinearities.Nonlinearity;
import nonlinearities.ReLuUnit;
import nonlinearities.RectifiedLinearUnit;
import nonlinearities.TanhUnit;
import trainer.AMSGrad;
import trainer.Trainer;
import util.CLUtils;
import util.FileIO;
import util.NNDevice;

public class DeconvTest {
    /**
     * Entry point: builds the upsampling network, optionally restores it from
     * {@code deconvTest.dat}, and then either converts two texture folders
     * (when the first argument is {@code "LOLXD"}) or trains the network and
     * writes one original/output sample pair as PNG files.
     *
     * <p>Every path ends in {@link System#exit(int)}: status 0 on success,
     * status 1 after any exception (which is printed to stderr first).
     *
     * @param args command-line arguments; only {@code args[0]} is inspected
     */
    public static void main(String[] args) {
        final String savefile = "deconvTest.dat";
        final int iterations = 25;
        final int saveInterval = 25;
        try {
            FaithfulDataSet dataset = new FaithfulDataSet(new File("minceraft/"), 64);
            Random random = new Random();

            // Network: one upsample layer followed by three convolution layers.
            NeuralNetwork net = new NeuralNetwork();
            net.addLayer(new ConvUpsample(2));
            net.addLayer(new ConvLayer(32, 32, 3, 5, 5, 64, 1, 2, new RectifiedLinearUnit(0.05), 0.08, random, true, true, 4));
            net.addLayer(new ConvLayer(32, 32, 64, 1, 1, 32, 1, 0, new RectifiedLinearUnit(0.05), 0.08, random, true, true, 4));
            net.addLayer(new ConvLayer(32, 32, 32, 3, 3, 3, 1, 1, new TanhUnit(), 0.08, random, true, true, 4));

            // Resume from a previous run when a save file is present.
            File previousSave = new File(savefile);
            if (previousSave.exists()) {
                FileIO.loadNeuralNetwork(savefile, net);
            }
            System.out.println("Neural network loaded");

            // Conversion-only mode: run the network over two folders and quit.
            boolean convertOnly = args.length != 0 && args[0].equals("LOLXD");
            if (convertOnly) {
                dataset.convertAll(new File("C:\\Users\\lucah\\workspace\\LSTMStuff\\minceraft\\test\\blocks\\"), net);
                System.out.println("50%");
                dataset.convertAll(new File("C:\\Users\\lucah\\workspace\\LSTMStuff\\minceraft\\test\\items\\"), net);
                System.out.println("Done.");
                System.exit(0);
            }

            // Training uses the first device reported for the given vendor/name pair.
            NNDevice dev = CLUtils.findDevice("AMD", "Baffin")[0];
            Trainer trainer = new Trainer(new AMSGrad(0.5, 0.555, 1.0E-8), dev);

            for (int iteration = 0; iteration < iterations; iteration++) {
                System.err.println("Iteration " + iteration + "/" + iterations);
                trainer.train(net, 5.0E-4, 1, dataset, 1, savefile, false, false, random);
                // NOTE(review): with iterations == saveInterval == 25 this
                // checkpoint condition is never true inside the loop; the
                // model is only saved by the explicit call after training.
                boolean checkpointDue = iteration != 0 && iteration % saveInterval == 0;
                if (checkpointDue) {
                    FileIO.saveNeuralNetwork(savefile, net);
                    System.out.println("Model saved");
                }
            }
            System.out.println("Done training");
            FileIO.saveNeuralNetwork(savefile, net);
            System.out.println("Model saved");

            // Pick one random training input and dump it next to the network's output.
            int sequenceIndex = random.nextInt(dataset.training.size());
            DataSequence randSeq = (DataSequence) dataset.training.get(sequenceIndex);
            int stepIndex = random.nextInt((int) randSeq.getSequenceLength());
            Tensor randTensor = randSeq.getDataStep(stepIndex).input;
            BufferedImage img = DeconvTest.asImage(randTensor);
            BufferedImage res = DeconvTest.asImage(dataset.generate(randTensor, net));
            ImageIO.write(img, "png", new File("deconvTest_original.png"));
            ImageIO.write(res, "png", new File("deconvTest_output.png"));
            System.exit(0);
        } catch (Exception e) {
            System.err.println("Error: ");
            e.printStackTrace();
            System.exit(1);
        }
    }

    /**
     * Converts an image into a three-channel tensor (red, green, blue in
     * {@code matrices[0..2]}), mapping each 8-bit channel value from
     * {@code [0, 255]} to {@code [-1, 1]}. The alpha byte is ignored.
     *
     * <p>Each value is stored with {@code setW(y, x, value)}, i.e. the image's
     * y coordinate is passed as the first index.
     *
     * @param image source image; read through {@link BufferedImage#getRGB(int, int)}
     * @return a new tensor constructed as {@code Tensor(width, height, 3)}
     */
    private static Tensor asTensor(BufferedImage image) {
        final int width = image.getWidth();
        final int height = image.getHeight();
        Tensor toReturn = new Tensor(width, height, 3);
        for (int x = 0; x < width; x++) {
            for (int y = 0; y < height; y++) {
                int argb = image.getRGB(x, y);
                // Channel 0 is red (bits 16-23), 1 is green (8-15), 2 is blue (0-7).
                for (int channel = 0; channel < 3; channel++) {
                    int shift = 16 - 8 * channel;
                    int value = argb >> shift & 0xFF;
                    toReturn.matrices[channel].setW(y, x, (double) value / 255.0 * 2.0 - 1.0);
                }
            }
        }
        return toReturn;
    }

    /**
     * Converts a tensor back into an image, mapping channel values from
     * {@code [-1, 1]} to {@code [0, 255]} (values outside that range are
     * clamped). Channels 0, 1 and 2 are read as red, green and blue via
     * {@code getW(y, x)}.
     *
     * <p>Tensors with fewer than four channels produce an opaque
     * {@link BufferedImage#TYPE_INT_RGB} image. Tensors with four or more
     * channels produce a {@link BufferedImage#TYPE_INT_ARGB} image whose alpha
     * is taken from channel 3; previously the alpha byte was always written as
     * zero, which made every pixel of such an image fully transparent.
     *
     * @param t tensor with at least three channels
     * @return a new image of size {@code t.width} x {@code t.height}
     */
    private static BufferedImage asImage(Tensor t) {
        boolean hasAlpha = t.depth >= 4;
        int imageType = hasAlpha ? BufferedImage.TYPE_INT_ARGB : BufferedImage.TYPE_INT_RGB;
        BufferedImage toReturn = new BufferedImage(t.width, t.height, imageType);
        for (int x = 0; x < toReturn.getWidth(); x++) {
            for (int y = 0; y < toReturn.getHeight(); y++) {
                int r = denormalizeChannel(t.matrices[0].getW(y, x));
                int g = denormalizeChannel(t.matrices[1].getW(y, x));
                int b = denormalizeChannel(t.matrices[2].getW(y, x));
                // NOTE(review): assumes a fourth channel, when present, holds
                // alpha in the same [-1, 1] encoding as the colour channels;
                // confirm against whatever produces 4-channel tensors.
                int a = hasAlpha ? denormalizeChannel(t.matrices[3].getW(y, x)) : 0xFF;
                // The alpha byte is ignored by TYPE_INT_RGB images.
                toReturn.setRGB(x, y, a << 24 | r << 16 | g << 8 | b);
            }
        }
        return toReturn;
    }

    /**
     * Maps a channel value from {@code [-1, 1]} to an integer in
     * {@code [0, 255]}. Rounds to nearest instead of truncating so that a
     * value produced by {@code asTensor} converts back to the byte it came
     * from even when floating-point error leaves it marginally below the
     * exact integer. NaN maps to 0.
     */
    private static int denormalizeChannel(double value) {
        long scaled = Math.round((value + 1.0) / 2.0 * 255.0);
        // Clamp while still a long: Math.round of an infinite input saturates
        // at Long.MIN_VALUE / Long.MAX_VALUE, which must not be narrowed first.
        return (int) Math.max(0L, Math.min(255L, scaled));
    }

    /**
     * Training set of paired textures: 16x16 images from
     * {@code <folder>/inputs} matched by file name with 32x32 images from
     * {@code <folder>/outputs}. Pairs are converted to tensors and grouped into
     * {@link DataSequence}s of at most {@code maxImagesPerSequence} steps.
     */
    private static class FaithfulDataSet extends DataSet {
        /** Side length, in pixels, of the low-resolution input textures. */
        private static final int INPUT_SIZE = 16;
        /** Side length, in pixels, of the high-resolution target textures. */
        private static final int OUTPUT_SIZE = 32;

        /**
         * Loads all matching image pairs and builds the training sequences,
         * showing each pair in a small preview dialog while it is converted.
         *
         * @param folder               directory containing the {@code inputs}
         *                             and {@code outputs} sub-directories
         * @param maxImagesPerSequence maximum number of steps per sequence;
         *                             must be positive
         * @throws IllegalArgumentException if {@code maxImagesPerSequence} is
         *                                  not positive
         * @throws IOException              if the inputs directory cannot be listed
         * @throws Exception                for image decoding failures, or if the
         *                                  preview delay is interrupted
         */
        public FaithfulDataSet(File folder, int maxImagesPerSequence) throws Exception {
            if (maxImagesPerSequence <= 0) {
                throw new IllegalArgumentException(
                        "maxImagesPerSequence must be positive: " + maxImagesPerSequence);
            }
            this.inputDimension = new DataSet.TensorDimensions(INPUT_SIZE, INPUT_SIZE, 3);
            this.outputDimension = new DataSet.TensorDimensions(OUTPUT_SIZE, OUTPUT_SIZE, 3);
            this.lossTraining = new LossSumOfSquares();
            this.lossReporting = new LossSumOfSquares();
            this.training = new ArrayList<DataSequence>();

            File inputs = new File(folder, "inputs");
            File outputs = new File(folder, "outputs");
            // listFiles() returns null when the path is not a readable directory.
            File[] candidates = inputs.listFiles();
            if (candidates == null) {
                throw new IOException("Cannot list training inputs in " + inputs.getPath());
            }

            ArrayList<BufferedImage> originals = new ArrayList<BufferedImage>();
            ArrayList<BufferedImage> faithful = new ArrayList<BufferedImage>();
            for (File f : candidates) {
                if (!f.isFile() || !f.getName().endsWith(".png")) {
                    continue;
                }
                BufferedImage low = ImageIO.read(f);
                if (!hasSize(low, INPUT_SIZE)) {
                    continue;
                }
                File target = new File(outputs, f.getName());
                if (!target.exists()) {
                    continue;
                }
                BufferedImage high = ImageIO.read(target);
                if (!hasSize(high, OUTPUT_SIZE)) {
                    continue;
                }
                originals.add(low);
                faithful.add(high);
            }
            System.out.println("Loaded " + originals.size() + " images");

            // The no-argument constructor avoids the ambiguous
            // JDialog(null, String) overload, which does not compile.
            JDialog frame = new JDialog();
            frame.setTitle("test");
            frame.setResizable(false);
            frame.setDefaultCloseOperation(JDialog.DO_NOTHING_ON_CLOSE);
            frame.setPreferredSize(new Dimension(128, 128));
            frame.setLayout(null);
            JButton b = new JButton("OK");
            b.setBounds(0, 64, 128, 32);
            b.setEnabled(false);
            frame.getContentPane().add(b);
            frame.pack();
            frame.setVisible(true);
            try {
                // As before, the list always starts with one sequence, even
                // when no images were found.
                DataSequence current = new DataSequence();
                this.training.add(current);
                for (int i = 0; i < originals.size(); i++) {
                    // Open a new sequence BEFORE adding the step that would
                    // overflow the current one. The old check ran after the
                    // add and also fired at i == 0, producing a one-element
                    // first sequence and, sometimes, an empty trailing one.
                    if (i > 0 && i % maxImagesPerSequence == 0) {
                        current = new DataSequence();
                        this.training.add(current);
                    }
                    BufferedImage low = originals.get(i);
                    BufferedImage high = faithful.get(i);
                    current.addDataStep(new DataStep(DeconvTest.asTensor(low), DeconvTest.asTensor(high)));

                    frame.setTitle("test " + i + "/" + originals.size());
                    // Use ONE Graphics object per frame so setColor applies to
                    // the following fillRect, and dispose it afterwards.
                    Graphics g = frame.getContentPane().getGraphics();
                    if (g != null) {
                        try {
                            g.setColor(Color.BLACK);
                            g.fillRect(0, 0, 128, 64);
                            g.drawImage(low, 0, 0, 64, 64, null);
                            g.drawImage(high, 64, 0, 64, 64, null);
                        } finally {
                            g.dispose();
                        }
                    }
                    Thread.sleep(64L);
                }
            } finally {
                // Close the preview and release its native resources even if
                // loading fails.
                frame.dispose();
            }
            System.out.println("Dataset loaded\r\n" + this.training.size() + " sequences");
        }

        /** Returns true when {@code image} was decoded and is a square of the given side length. */
        private static boolean hasSize(BufferedImage image, int side) {
            // ImageIO.read returns null when no registered reader can decode the file.
            return image != null && image.getWidth() == side && image.getHeight() == side;
        }

        /**
         * Runs every 16x16 {@code .png} in {@code folder} through the network
         * and writes the post-processed result, under the same file name, to a
         * sibling directory named {@code <folder>_out}.
         *
         * @param folder directory of input textures
         * @param n      network used to generate the output images
         * @throws IOException if {@code folder} cannot be listed or the output
         *                     directory cannot be created
         * @throws Exception   for image I/O or network evaluation failures
         */
        private void convertAll(File folder, NeuralNetwork n) throws Exception {
            File[] candidates = folder.listFiles();
            if (candidates == null) {
                throw new IOException("Cannot list images in " + folder.getPath());
            }
            File outputFolder = new File(String.valueOf(folder.getPath()) + "_out/");
            if (!outputFolder.isDirectory() && !outputFolder.mkdirs()) {
                throw new IOException("Cannot create output folder " + outputFolder.getPath());
            }
            for (File f : candidates) {
                if (!f.isFile() || !f.getName().endsWith(".png")) {
                    continue;
                }
                BufferedImage source = ImageIO.read(f);
                if (!hasSize(source, INPUT_SIZE)) {
                    continue;
                }
                BufferedImage upscaled = DeconvTest.asImage(this.generate(DeconvTest.asTensor(source), n));
                ImageIO.write(this.postprocess(upscaled), "png", new File(outputFolder, f.getName()));
            }
        }

        /**
         * Copies {@code in} into a new ARGB image of the same size, turning
         * every pure-black pixel (red, green and blue all zero) fully
         * transparent and making every other pixel fully opaque.
         *
         * @param in image to copy; only its RGB components are read
         * @return a new {@link BufferedImage#TYPE_INT_ARGB} image
         */
        private BufferedImage postprocess(BufferedImage in) {
            int width = in.getWidth();
            int height = in.getHeight();
            BufferedImage toReturn = new BufferedImage(width, height, BufferedImage.TYPE_INT_ARGB);
            for (int x = 0; x < width; x++) {
                for (int y = 0; y < height; y++) {
                    int rgb = in.getRGB(x, y) & 0xFFFFFF;
                    // Alpha 0 for black, alpha 255 for everything else.
                    toReturn.setRGB(x, y, rgb == 0 ? 0 : 0xFF000000 | rgb);
                }
            }
            return toReturn;
        }

        /**
         * Evaluates the network on a single tensor, resetting the network
         * state both before and after the forward pass.
         *
         * @param t   input tensor
         * @param net network to evaluate
         * @return the tensor returned by {@code net.forward}
         * @throws Exception if the forward pass fails
         */
        public Tensor generate(Tensor t, NeuralNetwork net) throws Exception {
            net.resetState();
            Graph g = new Graph(false);
            Tensor toReturn = net.forward(t, g);
            net.resetState();
            return toReturn;
        }

        /** Placeholder report: only prints {@code "a"} to stderr. */
        @Override
        public void DisplayReport(NeuralNetwork model, Random rng) throws Exception {
            System.err.println("a");
        }

        /** Returns a fresh {@link ReLuUnit} as the model output nonlinearity. */
        @Override
        public Nonlinearity getModelOutputUnitToUse() {
            return new ReLuUnit();
        }
    }
}

