/*
 * Decompiled with CFR 0.152.
 */
package theGhastModding.lstmStuff.main;

import autodiff.Graph;
import datasets.CifarDataset;
import java.awt.Color;
import java.awt.Dimension;
import java.awt.Frame;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.io.File;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.util.ArrayList;
import java.util.Random;
import javax.swing.ImageIcon;
import javax.swing.JButton;
import javax.swing.JDialog;
import javax.swing.JLabel;
import javax.swing.JPanel;
import javax.swing.UIManager;
import matrix.Matrix;
import matrix.Tensor;
import model.ConvFlatten;
import model.ConvLayer;
import model.ConvNet;
import model.ConvNonlinLayer;
import model.FeedForwardLayer;
import model.Model;
import model.NeuralNetwork;
import model.PoolLayer;
import nonlinearities.LinearUnit;
import nonlinearities.ReLuUnit;
import theGhastModding.console.main.Console;
import trainer.Adam;
import trainer.NewTrainer;
import trainer.RMSProp;
import util.CifarLoader;
import util.ConsoleHelper;
import util.FileIO;

/**
 * Command-line driver that trains a small convolutional network on CIFAR images, or evaluates a
 * previously saved one.
 *
 * <p>Progress is reported in a {@link Console} window. Every path through {@link #main} ends with
 * {@link System#exit(int)}.
 */
public class ConvTest {

    /** Image count passed to {@link CifarLoader#load} and the number of images scanned during evaluation. */
    private static final int NUM_IMAGES = 4000;

    /** File the network state is saved to during training and loaded from during evaluation. */
    private static final String SAVE_FILE = "cifar_test.ser";

    /**
     * Entry point.
     *
     * <p>{@code args[0]} (case-insensitive) selects the mode:
     * <ul>
     *   <li>{@code "true"}: load the network from {@value #SAVE_FILE}, report its accuracy over the
     *       first {@value #NUM_IMAGES} loaded images and show one classified sample in a dialog;</li>
     *   <li>{@code "Adam"}: train with the Adam optimizer;</li>
     *   <li>anything else, or no argument: train with RMSProp.</li>
     * </ul>
     *
     * <p>Exits with status 0 on success and 1 on any error.
     *
     * @param args optional mode selector, see above
     */
    public static void main(String[] args) {
        try {
            UIManager.setLookAndFeel(UIManager.getSystemLookAndFeelClassName());
        } catch (Exception e) {
            // Cosmetic only: keep going with the default look and feel.
            e.printStackTrace();
        }
        boolean testMode = args.length > 0 && args[0].equalsIgnoreCase("true");
        NewTrainer trainer = new NewTrainer(new RMSProp(0.999, 5.0, 1.0E-8));
        if (args.length > 0 && args[0].equalsIgnoreCase("Adam")) {
            trainer = new NewTrainer(new Adam(0.9, 0.999, 1.0E-4));
        }
        Console c = null;
        try {
            Random random = new Random();
            c = new Console(true, 960, 480, 0, "Console", 138, 130, 0);
            c.showFrame();
            c.print("Loading images...");
            CifarLoader loader = new CifarLoader(new File("cifar/"));
            loader.load(1, NUM_IMAGES);
            c.print("Complete!");
            c.print("Preparing dataset...");
            CifarDataset dataset = new CifarDataset(loader, 1000);
            if (!testMode) {
                // Only evaluation reads images back out of the loader; drop this reference before training.
                loader = null;
                System.gc();
            }
            c.print("Complete!");
            c.print("Preparing Convolutional Neural Network...");
            ConvNet convNet = buildNetwork(random);
            c.print("Complete!");
            if (testMode) {
                evaluate(convNet, loader, dataset, random, c); // never returns: exits the JVM
            }
            train(trainer, convNet, dataset, random, c); // never returns: exits the JVM
        } catch (Exception e) {
            System.err.println("Error: ");
            e.printStackTrace();
            if (c != null) {
                c.clear();
                c.setTextColor(Color.RED);
                c.print("Error: ");
                c.print(stackTraceToString(e));
                ConsoleHelper.pause(5000L);
                c.hideFrame();
            }
            System.exit(1);
        }
    }

    /**
     * Builds the network: conv, ReLU, conv, ReLU, pool, conv, flatten, followed by a single
     * fully-connected {@link LinearUnit} layer mapping 20 inputs to 10 outputs.
     *
     * @param random source used for weight initialisation (layers are built in declaration order)
     * @return the assembled, untrained network
     * @throws Exception whatever the project layer constructors may throw
     */
    private static ConvNet buildNetwork(Random random) throws Exception {
        ConvNet convNet = new ConvNet();
        convNet.addLayer(new ConvLayer(2, 32, 32, 3, 2, 2, 16, 0.08, random, true, 8));
        convNet.addLayer(new ConvNonlinLayer(new ReLuUnit()));
        convNet.addLayer(new ConvLayer(1, 16, 16, 16, 5, 5, 20, 0.08, random, true, 8));
        convNet.addLayer(new ConvNonlinLayer(new ReLuUnit()));
        convNet.addLayer(new PoolLayer(2, 12, 12, 20, 2, 2));
        convNet.addLayer(new ConvLayer(1, 6, 6, 20, 6, 6, 20, 0.08, random, true, 8));
        convNet.addLayer(new ConvFlatten(1, 1, 20));
        FeedForwardLayer output = new FeedForwardLayer(20, 10, new LinearUnit(), 0.08, random);
        ArrayList<Model> fullyConnected = new ArrayList<Model>();
        fullyConnected.add(output);
        convNet.setFullyConnected(new NeuralNetwork(fullyConnected));
        return convNet;
    }

    /**
     * Evaluates a saved network, then exits the JVM with status 0.
     *
     * <p>It loads the weights from {@value #SAVE_FILE} and prints the accuracy over the first
     * {@value #NUM_IMAGES} loaded images. It then shows one randomly chosen image with its
     * predicted and expected categories in a modal dialog.
     *
     * @throws Exception if loading the network or a forward pass fails
     */
    private static void evaluate(ConvNet convNet, CifarLoader loader, CifarDataset dataset, Random random,
            Console c) throws Exception {
        c.print("Generating test output...");
        c.print("Loading Network from saved state...");
        FileIO.loadNeuralNetwork(SAVE_FILE, convNet);
        c.print("Complete!");
        c.print("Passing through many images...");
        int right = 0;
        for (int i = 0; i < NUM_IMAGES; i++) {
            if (predict(convNet, loader, i) == loader.getCategory(i)) {
                right++;
            }
        }
        double accuracy = (double) right / NUM_IMAGES * 100.0;
        c.print("Done. Model accuracy is " + Double.toString(accuracy) + "%");
        System.out.println(Double.toString(accuracy));

        c.print("Passing through sample image...");
        // NOTE(review): the index is bounded by the training-set size but looked up in the loader;
        // this assumes dataset.training.size() <= the number of loaded images — confirm.
        int index = random.nextInt(dataset.training.size());
        int predicted = predict(convNet, loader, index);
        c.print("Done. Displaying result...");
        int expected = loader.getCategory(index);
        String text = loader.getCategoryString(predicted) + " (" + Integer.toString(predicted) + "), should be "
                + loader.getCategoryString(expected) + " (" + Integer.toString(expected) + ")";
        showResultDialog(new ImageIcon(loader.getImage(index)), text);
        c.print("Program done. Exiting.");
        ConsoleHelper.pause(2000L);
        c.hideFrame();
        System.exit(0);
    }

    /**
     * Runs image {@code index} from {@code loader} through the network, using a fresh
     * {@code Graph(false)}.
     *
     * @return index of the highest output score
     * @throws Exception if the forward pass fails
     */
    private static int predict(ConvNet convNet, CifarLoader loader, int index) throws Exception {
        Tensor in = CifarLoader.asTensor(loader.getImage(index));
        Matrix out = convNet.forward(in, new Graph(false)).getMatrixAt(0);
        return argmax(out);
    }

    /**
     * Returns the index of the largest entry in {@code m.w}.
     *
     * <p>The search is seeded with negative infinity, not {@code Double.MIN_VALUE}, which is the
     * smallest <em>positive</em> double. The linear output layer can produce all-negative scores.
     * Returns 0 for an empty matrix, or when no entry compares greater (e.g. all NaN).
     */
    private static int argmax(Matrix m) {
        int bestIndex = 0;
        double best = Double.NEGATIVE_INFINITY;
        for (int j = 0; j < m.w.length; j++) {
            if (m.w[j] > best) {
                best = m.w[j];
                bestIndex = j;
            }
        }
        return bestIndex;
    }

    /**
     * Shows a modal "Result" dialog containing {@code icon}, {@code text} and a Close button.
     * Blocks until the Close button is pressed.
     */
    private static void showResultDialog(ImageIcon icon, String text) {
        // The cast selects JDialog(Frame, String); a bare null is ambiguous between the
        // Frame, Dialog and Window owner constructors.
        final JDialog dialog = new JDialog((Frame) null, "Result");
        dialog.setResizable(false);
        dialog.setContentPane(new JPanel());
        dialog.getContentPane().setPreferredSize(new Dimension(225, 150));
        // Closing via the window decoration does nothing; the Close button is the way out.
        dialog.setDefaultCloseOperation(JDialog.DO_NOTHING_ON_CLOSE);
        dialog.getContentPane().add(new JLabel(icon));
        dialog.getContentPane().add(new JLabel(text));
        JButton close = new JButton("Close");
        close.addActionListener(new ActionListener() {

            @Override
            public void actionPerformed(ActionEvent e) {
                dialog.setVisible(false);
            }
        });
        dialog.getContentPane().add(close);
        dialog.setModal(true);
        dialog.pack();
        dialog.setVisible(true);
    }

    /**
     * Runs 25 training iterations, saving and resuming via {@value #SAVE_FILE}. Exits the JVM with
     * status 0 when done, or with status 1 if any iteration throws.
     */
    private static void train(NewTrainer trainer, ConvNet convNet, CifarDataset dataset, Random random,
            Console c) {
        c.print("Starting training...");
        int iterations = 25;
        double learningRate = 0.004;
        for (int i = 0; i < iterations; i++) {
            c.print("Progress info: Iteration " + Integer.toString(i + 1) + "/" + Integer.toString(iterations));
            try {
                trainer.trainConvNet(convNet, learningRate, 2, dataset, 2, SAVE_FILE, new File(SAVE_FILE).exists(),
                        true, random, c);
            } catch (Exception e) {
                e.printStackTrace();
                c.print("Error during training: ");
                c.print(stackTraceToString(e));
                c.print("Presuming learning rate is too high. Automatically decreasing learning rate to "
                        + Double.toString(learningRate / 10.0));
                learningRate /= 10.0;
                // NOTE(review): the message above announces a retry with a lower learning rate, but the
                // process exits here, so the decreased rate is never used. Either drop the exit (and
                // retry the iteration) or reword the message.
                System.exit(1);
            }
        }
        c.print("Complete!");
        c.print("Program done. Exiting.");
        ConsoleHelper.pause(5000L);
        c.hideFrame();
        System.exit(0);
    }

    /** Renders {@code e}'s stack trace (as printed by {@link Throwable#printStackTrace()}) into a string. */
    private static String stackTraceToString(Exception e) {
        StringWriter buffer = new StringWriter();
        PrintWriter printer = new PrintWriter(buffer);
        e.printStackTrace(printer);
        printer.flush();
        return buffer.toString();
    }
}

