/*
 * Decompiled with CFR 0.152.
 */
package theGhastModding.lstmStuff.main;

import autodiff.Graph;
import datasets.CifarDataset;
import java.awt.Color;
import java.awt.Dimension;
import java.awt.Frame;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.image.BufferedImage;
import java.awt.image.RenderedImage;
import java.io.File;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.util.Random;
import javax.imageio.ImageIO;
import javax.swing.ImageIcon;
import javax.swing.JButton;
import javax.swing.JDialog;
import javax.swing.JLabel;
import javax.swing.JPanel;
import javax.swing.UIManager;
import matrix.Matrix;
import matrix.Tensor;
import model.ConvDropout;
import model.ConvFlatten;
import model.ConvLayer;
import model.ConvNet;
import model.ConvNonlinLayer;
import model.PoolLayer;
import nonlinearities.ReLuUnit;
import trainer.AMSGrad;
import trainer.GradientNoise;
import trainer.NewTrainer;
import trainer.RMSProp;
import util.CifarLoader;
import util.FileIO;

/**
 * Command-line driver that either trains a small convolutional network on CIFAR images or,
 * in evaluation mode, loads a saved network and reports its accuracy.
 *
 * <p>Only {@code args[0]} is inspected:
 * <ul>
 *   <li>{@code true} (case-insensitive): evaluation mode. Loads {@code cifar_test.ser}, measures
 *       accuracy on the first {@code EVAL_IMAGE_COUNT} loaded images, writes every parameter
 *       matrix out as a grayscale PNG, shows one random prediction in a modal dialog, then exits.</li>
 *   <li>{@code AMS} (case-insensitive): train using AMSGrad wrapped in GradientNoise.</li>
 *   <li>anything else, or no argument: train using RMSProp.</li>
 * </ul>
 */
public class ConvTest {
    /** Serialized network file: loaded in evaluation mode and handed (by name) to the trainer. */
    private static final String MODEL_FILE = "cifar_test.ser";
    /** Number of images scored in evaluation mode; the sample image is also drawn from this range. */
    private static final int EVAL_IMAGE_COUNT = 4000;
    /** Passed as the last ConvLayer constructor argument (presumably a worker-thread count — TODO confirm). */
    private static final int CORES = 4;

    public static void main(String[] args) {
        try {
            UIManager.setLookAndFeel(UIManager.getSystemLookAndFeelClassName());
        }
        catch (Exception e) {
            // Non-fatal: Swing keeps its default look and feel.
            e.printStackTrace();
        }
        boolean evaluationMode = args.length > 0 && args[0].equalsIgnoreCase("true");
        boolean useAmsGrad = args.length > 0 && args[0].equalsIgnoreCase("AMS");
        NewTrainer trainer = useAmsGrad
                ? new NewTrainer(new GradientNoise(new AMSGrad(0.9, 0.999, 1.0E-5)))
                : new NewTrainer(new RMSProp(0.999, 5.0, 1.0E-4));
        try {
            Random random = new Random();
            System.out.println("Loading images...");
            CifarLoader loader = new CifarLoader(new File("cifar/"));
            loader.load(evaluationMode ? 3 : 2, 10000);
            System.out.println("Complete!");
            System.out.println("Preparing dataset...");
            CifarDataset dataset = new CifarDataset(loader, 2000);
            if (!evaluationMode) {
                // The raw loader is only read again in evaluation mode; release it for the GC.
                loader = null;
                System.gc();
            }
            System.out.println("Complete!");
            System.out.println("Preparing Convolutional Neural Network...");
            ConvNet convNet = buildNetwork(random);
            System.out.println("Complete!");
            if (evaluationMode) {
                runEvaluation(convNet, loader, random);
                System.out.println("Program done. Exiting.");
                System.exit(0);
            }
            train(trainer, convNet, dataset, random);
            System.out.println("Complete!");
            System.out.println("Program done. Exiting.");
            System.exit(0);
        }
        catch (Exception e) {
            System.err.println("Error: ");
            e.printStackTrace();
            System.exit(1);
        }
    }

    /**
     * Builds the fixed conv / ReLU / dropout / pool stack. The size arguments of each layer must
     * match the previous layer's output (presumably 32x32x3 -> 28x28x16 -> 14x14x16 -> 10x10x20
     * -> 5x5x20 -> 1x1x10 — TODO confirm against the layer constructors).
     *
     * @param random source of randomness handed to each ConvLayer
     * @return the assembled, untrained network
     */
    private static ConvNet buildNetwork(Random random) throws Exception {
        ConvNet convNet = new ConvNet();
        convNet.addLayer(new ConvLayer(32, 32, 3, 5, 5, 16, 1, 0, 0.08, random, true, CORES));
        convNet.addLayer(new ConvNonlinLayer(new ReLuUnit()));
        convNet.addLayer(new ConvDropout(0.1));
        convNet.addLayer(new PoolLayer(28, 28, 16, 2, 2, 2));
        convNet.addLayer(new ConvLayer(14, 14, 16, 5, 5, 20, 1, 0, 0.08, random, true, CORES));
        convNet.addLayer(new ConvNonlinLayer(new ReLuUnit()));
        convNet.addLayer(new ConvDropout(0.1));
        convNet.addLayer(new PoolLayer(10, 10, 20, 2, 2, 2));
        convNet.addLayer(new ConvLayer(5, 5, 20, 5, 5, 10, 1, 0, 0.08, random, true, CORES));
        convNet.addLayer(new ConvNonlinLayer(new ReLuUnit()));
        convNet.addLayer(new ConvFlatten(1, 1, 10));
        return convNet;
    }

    /**
     * Evaluation mode: loads saved weights, scores the first {@code EVAL_IMAGE_COUNT} images,
     * dumps the parameters as PNGs (as {@code weightN.png} before and {@code N.png} after the
     * dialog), and shows one random prediction in a modal dialog.
     *
     * @param convNet network whose parameters are overwritten from {@code MODEL_FILE}
     * @param loader  loader holding at least {@code EVAL_IMAGE_COUNT} images
     * @param random  used to pick the sample image
     */
    private static void runEvaluation(ConvNet convNet, CifarLoader loader, Random random) throws Exception {
        System.out.println("Generating test output...");
        System.out.println("Loading Network from saved state...");
        FileIO.loadNeuralNetwork(MODEL_FILE, convNet);
        System.out.println("Complete!");
        System.out.println("Passing through many images...");
        int right = 0;
        for (int i = 0; i < EVAL_IMAGE_COUNT; ++i) {
            if (predict(convNet, CifarLoader.asTensor(loader.getImage(i))) == loader.getCategory(i)) {
                ++right;
            }
        }
        writeParameterImages(convNet, "weight");
        double accuracy = (double) right / EVAL_IMAGE_COUNT * 100.0;
        System.out.println("Done. Model accuracy is " + Double.toString(accuracy) + "%");
        System.out.println(Double.toString(accuracy));

        System.out.println("Passing through sample image...");
        int indx = random.nextInt(EVAL_IMAGE_COUNT);
        int predicted = predict(convNet, CifarLoader.asTensor(loader.getImage(indx)));
        System.out.println("Done. Displaying result...");
        int expected = loader.getCategory(indx);
        String str = loader.getCategoryString(predicted) + " (" + Integer.toString(predicted) + "), should be "
                + loader.getCategoryString(expected) + " (" + Integer.toString(expected) + ")";
        showResultDialog(new ImageIcon(loader.getImage(indx)), str);

        int written = writeParameterImages(convNet, "");
        System.err.println(Integer.toString(written));
    }

    /**
     * Runs one inference pass (no backprop graph) and returns the index of the highest output.
     */
    private static int predict(ConvNet convNet, Tensor input) throws Exception {
        Matrix out = convNet.forward(input, new Graph(false)).getMatrixAt(0);
        return argmax(out);
    }

    /**
     * Returns the index of the largest entry of {@code scores.w}. Ties resolve to the lowest
     * index, NaN entries are never selected, and 0 is returned for an empty vector.
     */
    private static int argmax(Matrix scores) {
        int best = 0;
        // NEGATIVE_INFINITY, not Double.MIN_VALUE: MIN_VALUE is the smallest *positive* double,
        // which would make every non-positive score unselectable.
        double bestVal = Double.NEGATIVE_INFINITY;
        for (int j = 0; j < scores.w.length; ++j) {
            if (scores.w[j] > bestVal) {
                bestVal = scores.w[j];
                best = j;
            }
        }
        return best;
    }

    /**
     * Writes each parameter matrix of {@code convNet} as {@code <prefix><index>.png} in the
     * working directory.
     *
     * @return the number of images written
     */
    private static int writeParameterImages(ConvNet convNet, String prefix) throws Exception {
        int count = 0;
        for (Matrix m : convNet.getParameters()) {
            ImageIO.write(toGrayscaleImage(m), "png", new File(prefix + Integer.toString(count) + ".png"));
            ++count;
        }
        return count;
    }

    /**
     * Shows {@code icon} and {@code text} in a modal "Result" dialog and blocks until the user
     * clicks Close. The window's close button is disabled, so Close is the only way out.
     */
    private static void showResultDialog(ImageIcon icon, String text) {
        // The cast is required: a bare null matches JDialog(Frame, String), JDialog(Dialog, String)
        // and JDialog(Window, String), which is a compile-time ambiguity.
        final JDialog dialog = new JDialog((Frame) null, "Result");
        dialog.setResizable(false);
        dialog.setContentPane(new JPanel());
        dialog.getContentPane().setPreferredSize(new Dimension(225, 150));
        dialog.setDefaultCloseOperation(JDialog.DO_NOTHING_ON_CLOSE);
        dialog.getContentPane().add(new JLabel(icon));
        dialog.getContentPane().add(new JLabel(text));
        JButton close = new JButton("Close");
        close.addActionListener(new ActionListener() {

            @Override
            public void actionPerformed(ActionEvent e) {
                dialog.setVisible(false);
            }
        });
        dialog.getContentPane().add(close);
        dialog.setModal(true);
        dialog.pack();
        // Modal: returns only after the Close button hides the dialog.
        dialog.setVisible(true);
        dialog.dispose();
    }

    /**
     * Runs the fixed number of training iterations. Any exception from the trainer terminates
     * the process with exit code 1 after printing its stack trace to stdout.
     */
    private static void train(NewTrainer trainer, ConvNet convNet, CifarDataset dataset, Random random) {
        System.out.println("Starting training...");
        int iterations = 150;
        double learningRate = 0.01;
        for (int i = 0; i < iterations; ++i) {
            System.out.println("Progress info: Iteration " + Integer.toString(i + 1) + "/" + Integer.toString(iterations));
            try {
                trainer.trainConvNet(convNet, learningRate, 2, dataset, 2, MODEL_FILE, new File(MODEL_FILE).exists(), true, random);
            }
            catch (Exception e) {
                System.out.println("Error during training: ");
                e.printStackTrace(System.out);
                // The process exits here, so do not claim the learning rate is adjusted
                // automatically; just suggest a lower value for the next run.
                System.out.println("Training aborted. The learning rate may be too high; consider retrying with "
                        + Double.toString(learningRate / 10.0));
                System.exit(1);
            }
        }
    }

    /**
     * Renders a matrix as a grayscale image with one pixel per entry: row {@code r}, column
     * {@code c} maps to pixel (x = c, y = r). Values are clamped to [0, 1] and scaled to 0..255;
     * NaN renders as black.
     */
    private static BufferedImage toGrayscaleImage(Matrix m) {
        BufferedImage bi = new BufferedImage(m.cols, m.rows, BufferedImage.TYPE_INT_RGB);
        for (int y = 0; y < m.rows; ++y) {
            for (int x = 0; x < m.cols; ++x) {
                double g = Math.max(0.0, Math.min(1.0, m.getW(y, x)));
                int level = (int) (g * 255.0);
                bi.setRGB(x, y, new Color(level, level, level).getRGB());
            }
        }
        return bi;
    }
}

