/*
 * Decompiled with CFR 0.152.
 */
package examples;

import autodiff.Graph;
import datasets.CifarDataset;
import java.awt.Dimension;
import java.awt.Frame;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.io.File;
import java.util.Random;
import javax.swing.ImageIcon;
import javax.swing.JButton;
import javax.swing.JDialog;
import javax.swing.JLabel;
import javax.swing.JPanel;
import javax.swing.UIManager;
import matrix.Matrix;
import matrix.Tensor;
import model.ConvDense;
import model.ConvFlatten;
import model.ConvLayer;
import model.ConvPool;
import model.Dropout;
import model.FeedForwardLayer;
import model.NeuralNetwork;
import model.NormalizeLayer;
import nonlinearities.ExponentialLinearUnit;
import nonlinearities.SigmoidUnit;
import trainer.AMSGrad;
import trainer.Trainer;
import util.CLUtils;
import util.CifarLoader;
import util.FileIO;
import util.NNDevice;

/**
 * Example program for the CIFAR-10 image set.
 *
 * <p>By default it builds a small convolutional network and trains it, loading a different image
 * batch after every training iteration. When the first command-line argument is {@code "true"}
 * (case-insensitive) it instead loads network state from {@value #SAVE_FILE}, prints the accuracy
 * over {@value #IMAGES_PER_BATCH} loaded images and shows the prediction for one randomly chosen
 * image in a modal dialog.
 */
public class Cifar10 {
    /** Image count passed to {@code CifarLoader.load}; also the number of images scanned in test mode. */
    private static final int IMAGES_PER_BATCH = 10000;
    /** Network state file: handed to the trainer during training and read back in test mode. */
    private static final String SAVE_FILE = "cifar_test.dat";

    public static void main(String[] args) {
        try {
            UIManager.setLookAndFeel(UIManager.getSystemLookAndFeelClassName());
        }
        catch (Exception e) {
            // Cosmetic only: Swing keeps its default look and feel.
            e.printStackTrace();
        }
        boolean testMode = args.length > 0 && args[0].equalsIgnoreCase("true");
        // Never assigned in this file, so every `dev != null` branch is dead and the plain
        // Graph / Trainer(..., null) path is always used.
        NNDevice dev = null;
        try {
            Random random = new Random();
            System.out.println("Loading images...");
            CifarLoader loader = new CifarLoader(new File("cifar/"));
            // NOTE(review): the first argument presumably selects the CIFAR batch file; batch 5 is
            // only ever loaded in test mode (training draws 1-4) — confirm against CifarLoader.
            loader.load(testMode ? 5 : 2, IMAGES_PER_BATCH);
            System.out.println("Complete!");
            System.out.println("Preparing dataset...");
            // NOTE(review): 64 and 10 look like batch size and class count — confirm.
            CifarDataset dataset = new CifarDataset(loader, 64, 10, random);
            System.gc();
            System.out.println("Complete!");
            System.out.println("Preparing Convolutional Neural Network...");
            NeuralNetwork convNet = buildNetwork(random);
            System.out.println("Complete!");
            if (testMode) {
                evaluateSavedNetwork(convNet, loader, dev, random);
                System.out.println("Program done. Exiting.");
                System.exit(0);
            }
            long elapsedMillis = train(convNet, loader, dataset, dev, random);
            System.out.println("Program done. Exiting.");
            System.err.println("Training time (ms): " + elapsedMillis);
            System.exit(0);
        }
        catch (Exception e) {
            System.err.println("Error: ");
            e.printStackTrace();
            System.exit(1);
        }
    }

    /**
     * Builds the convolutional network used for both training and evaluation.
     *
     * <p>Each layer's input size must match the previous layer's output; for example the
     * {@code ConvFlatten(3, 3, 128)} output of 3 * 3 * 128 = 1152 values feeds the first dense layer.
     */
    private static NeuralNetwork buildNetwork(Random random) throws Exception {
        int cores = 4;
        NeuralNetwork convNet = new NeuralNetwork();
        convNet.addLayer(new ConvLayer(32, 32, 3, 5, 5, 64, 1, 0, new ExponentialLinearUnit(1.0), 0.018, random, true, true, cores));
        convNet.addLayer(new ConvPool(28, 28, 64, 2, 2, 2));
        convNet.addLayer(new NormalizeLayer(1.0E-8, 0.08, random));
        convNet.addLayer(new ConvLayer(14, 14, 64, 3, 3, 128, 1, 1, new ExponentialLinearUnit(1.0), 0.0089, random, true, true, cores));
        convNet.addLayer(new ConvPool(14, 14, 128, 2, 2, 2));
        convNet.addLayer(new NormalizeLayer(1.0E-8, 0.08, random));
        convNet.addLayer(new ConvLayer(7, 7, 128, 2, 2, 128, 1, 0, new ExponentialLinearUnit(1.0), 0.012, random, true, true, cores));
        convNet.addLayer(new ConvPool(6, 6, 128, 2, 2, 2));
        convNet.addLayer(new ConvFlatten(3, 3, 128));
        convNet.addLayer(new Dropout(0.25));
        convNet.addLayer(new ConvDense(new FeedForwardLayer(1152, 512, new ExponentialLinearUnit(1.0), 0.029, random)));
        convNet.addLayer(new ConvDense(new FeedForwardLayer(512, 10, new SigmoidUnit(), 0.044, random)));
        return convNet;
    }

    /**
     * Loads the saved network state into {@code convNet} and prints its accuracy over the first
     * {@value #IMAGES_PER_BATCH} loaded images. It then shows the prediction for one random image.
     */
    private static void evaluateSavedNetwork(NeuralNetwork convNet, CifarLoader loader, NNDevice dev, Random random) throws Exception {
        System.out.println("Generating test output...");
        System.out.println("Loading Network from saved state...");
        FileIO.loadNeuralNetwork(SAVE_FILE, convNet);
        System.out.println("Complete!");
        System.out.println("Passing through many images...");
        Graph g;
        if (dev != null) {
            g = CLUtils.createGraph(dev, false);
        }
        else {
            g = new Graph(false);
        }
        int right = 0;
        for (int i = 0; i < IMAGES_PER_BATCH; i++) {
            Tensor in = CifarLoader.asTensor(loader.getImage(i));
            Matrix out = convNet.forward(in, g).matrices[0];
            if (argmax(out) == loader.getCategory(i)) {
                right++;
            }
        }
        System.out.println("Done. Model accuracy is " + Double.toString((double) right / IMAGES_PER_BATCH * 100.0) + "%");
        showSamplePrediction(convNet, loader, random);
    }

    /**
     * Classifies one randomly chosen loaded image. It then blocks on a modal dialog showing the
     * image together with the predicted and the expected category.
     */
    private static void showSamplePrediction(NeuralNetwork convNet, CifarLoader loader, Random random) throws Exception {
        System.out.println("Passing through sample image...");
        int indx = random.nextInt(IMAGES_PER_BATCH);
        Tensor in = CifarLoader.asTensor(loader.getImage(indx));
        Matrix out = convNet.forward(in, new Graph(false)).matrices[0];
        System.out.println("Done. Displaying result...");
        int predicted = argmax(out);
        int expected = loader.getCategory(indx);
        String str = loader.getCategoryString(predicted) + " (" + predicted + "), should be "
                + loader.getCategoryString(expected) + " (" + expected + ")";
        // A bare `null` owner is ambiguous between JDialog(Frame, String), JDialog(Dialog, String)
        // and JDialog(Window, String); the cast selects the Frame overload.
        // NOTE(review): Swing components are created here on the main thread, not the Event
        // Dispatch Thread.
        final JDialog dialog = new JDialog((Frame) null, "Result");
        dialog.setResizable(false);
        dialog.setContentPane(new JPanel());
        dialog.getContentPane().setPreferredSize(new Dimension(225, 150));
        // The title-bar close control is ignored; the "Close" button below hides the dialog.
        dialog.setDefaultCloseOperation(JDialog.DO_NOTHING_ON_CLOSE);
        dialog.getContentPane().add(new JLabel(new ImageIcon(loader.getImage(indx))));
        dialog.getContentPane().add(new JLabel(str));
        JButton close = new JButton("Close");
        close.addActionListener(new ActionListener() {
            @Override
            public void actionPerformed(ActionEvent e) {
                dialog.setVisible(false);
            }
        });
        dialog.getContentPane().add(close);
        dialog.setModal(true);
        dialog.pack();
        // Modal: blocks until the dialog is hidden by the Close button.
        dialog.setVisible(true);
    }

    /**
     * Trains {@code convNet} with AMSGrad for a fixed number of iterations. After each iteration
     * it loads a random image batch (1-4) and rebuilds the dataset from it. Any failure during an
     * iteration terminates the JVM with status 1.
     *
     * @return wall-clock milliseconds spent from the first iteration until the trainer was disposed
     */
    private static long train(NeuralNetwork convNet, CifarLoader loader, CifarDataset dataset, NNDevice dev, Random random) throws Exception {
        System.out.println("Starting training...");
        Trainer trainer = new Trainer(new AMSGrad(0.9, 0.999, 1.0E-8), dev);
        int iterations = 50;
        double learningRate = 0.001;
        long startTime = System.currentTimeMillis();
        for (int i = 0; i < iterations; i++) {
            System.out.println("Progress info: Iteration " + (i + 1) + "/" + iterations);
            try {
                trainer.train(convNet, learningRate, 2, dataset, 2, SAVE_FILE, new File(SAVE_FILE).exists(), true, random);
                System.out.println("Loading new images...");
                loader.load(random.nextInt(4) + 1, IMAGES_PER_BATCH);
                System.out.println("Complete!");
                System.out.println("Updating dataset...");
                dataset = new CifarDataset(loader, 64, 10, random);
                System.out.println("Complete!");
                System.gc();
            }
            catch (Exception e) {
                // Report before the trace (and on stderr) so the message labels the trace that follows.
                System.err.println("Error during training: ");
                e.printStackTrace();
                System.exit(1);
            }
        }
        trainer.dispose();
        System.out.println("Complete!");
        return System.currentTimeMillis() - startTime;
    }

    /**
     * Returns the index of the largest entry in {@code m.w}, or 0 if it is empty or holds only NaN.
     *
     * <p>The search is seeded with negative infinity. {@code Double.MIN_VALUE} is the smallest
     * <em>positive</em> double, so seeding with it would never select an output of zero or below.
     */
    private static int argmax(Matrix m) {
        int best = 0;
        double bestVal = Double.NEGATIVE_INFINITY;
        for (int j = 0; j < m.w.length; j++) {
            if (m.w[j] > bestVal) {
                bestVal = m.w[j];
                best = j;
            }
        }
        return best;
    }
}

