/*
 * Decompiled with CFR 0.152.
 */
package examples;

import autodiff.Graph;
import datastructs.DataSequence;
import datastructs.DataSet;
import datastructs.DataStep;
import java.awt.Color;
import java.awt.Dimension;
import java.awt.Graphics;
import java.awt.Graphics2D;
import java.awt.image.BufferedImage;
import java.io.BufferedInputStream;
import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Random;
import javax.swing.JFrame;
import loss.LossSoftmax;
import matrix.Tensor;
import model.FeedForwardLayer;
import model.NeuralNetwork;
import nonlinearities.Nonlinearity;
import nonlinearities.ReLuUnit;
import nonlinearities.SigmoidUnit;
import trainer.AMSGrad;
import trainer.Trainer;
import util.FileIO;

public class MnistDemo {
    public static void main(String[] args) {
        try {
            MNISTDataset dataset = new MNISTDataset(new File("mnist-images.dat"), new File("mnist-labels.dat"));
            System.out.println("Dataset loaded\nInitalizing network");
            NeuralNetwork net = new NeuralNetwork();
            net.addLayer(new FeedForwardLayer(784, 25, new ReLuUnit(), 0.035714, new Random()));
            System.out.println("Fully connected; 784 --> 25; Sigmoid;");
            net.addLayer(new FeedForwardLayer(25, 10, new SigmoidUnit(), 0.4, new Random()));
            System.out.println("Fully connected; 25 --> 10; Sigmoid;");
            System.out.println();
            String savefile = "mnist.dat";
            if (new File("mnist.dat").exists()) {
                System.out.println("Loading model from saved state");
                FileIO.loadNeuralNetwork("mnist.dat", net);
                System.out.println("Loaded");
            }
            System.out.println("Using optimizer \"AMSGrad\" and loss \"Softmax\"");
            Trainer trainer = new Trainer(new AMSGrad(0.9, 0.999, 1.0E-8));
            System.out.println("Learning rate is 0.001");
            Thread.sleep(10000L);
            int iters = 50;
            JFrame frame = new JFrame("a");
            frame.getContentPane().setLayout(null);
            frame.getContentPane().setPreferredSize(new Dimension(280, 280));
            frame.setDefaultCloseOperation(0);
            frame.setResizable(false);
            frame.pack();
            frame.setVisible(true);
            Random rng = new Random();
            Graph g = new Graph(false);
            BufferedImage displayImg = new BufferedImage(280, 280, 1);
            Graphics2D gr = (Graphics2D)displayImg.getGraphics();
            int i = 0;
            while (i < 50) {
                System.out.println("Iteration " + Integer.toString(i + 1) + "/" + Integer.toString(50));
                frame.setTitle("Iteration " + Integer.toString(i + 1) + "/" + Integer.toString(50));
                double loss = trainer.train(net, 0.001, 5, dataset, 5, "mnist.dat", false, true, rng);
                if (i != 0 && i % 2 == 0) {
                    dataset.DisplayReport(net, new Random());
                }
                int indx = rng.nextInt(4000);
                gr.drawImage(dataset.asImage(((DataSequence)dataset.training.get((int)0)).getDataStep((int)indx).input.matrices[0].w), 0, 0, 280, 280, frame);
                double[] netPrediction = net.forward((Tensor)((DataSequence)dataset.training.get((int)0)).getDataStep((int)indx).input, (Graph)g).matrices[0].w;
                int argmax = 0;
                double max = Double.NEGATIVE_INFINITY;
                int i2 = 0;
                while (i2 < 10) {
                    if (netPrediction[i2] > max) {
                        max = netPrediction[i2];
                        argmax = i2;
                    }
                    ++i2;
                }
                gr.setColor(Color.WHITE);
                gr.drawString(String.valueOf(Integer.toString(argmax)) + "   -   " + Double.toString(loss), 12, 265);
                frame.getContentPane().getGraphics().drawImage(displayImg, 0, 0, 280, 280, null);
                ++i;
            }
            trainer.dispose();
            Thread.sleep(5000L);
            frame.setVisible(false);
            System.out.println("All done");
            System.exit(0);
        }
        catch (Exception e) {
            System.err.println("Error: ");
            e.printStackTrace();
            System.exit(1);
        }
    }

    private static class MNISTDataset
    extends DataSet {
        private int bytesToInt(byte[] bytes) {
            ByteBuffer buff = ByteBuffer.allocate(4);
            buff.put(bytes);
            buff.flip();
            int x = buff.getInt(0);
            buff.clear();
            return x;
        }

        public MNISTDataset(File images, File labels) throws Exception {
            this.training = new ArrayList();
            this.training.add(new DataSequence());
            this.lossTraining = new LossSoftmax();
            this.lossReporting = new LossSoftmax();
            FileInputStream fisImages = new FileInputStream(images);
            byte[] intBuffer = new byte[4];
            fisImages.read(intBuffer);
            if (this.bytesToInt(intBuffer) != 2051) {
                fisImages.close();
                throw new Exception("Invalid magic no.");
            }
            fisImages.skip(12L);
            FileInputStream fisLabels = new FileInputStream(labels);
            fisLabels.read(intBuffer);
            if (this.bytesToInt(intBuffer) != 2049) {
                fisImages.close();
                fisLabels.close();
                throw new Exception("Invalid magic no.");
            }
            fisLabels.skip(4L);
            int i = 0;
            while (i < 10000) {
                double[] newInputVec = new double[784];
                int j = 0;
                while (j < 784) {
                    int col = fisImages.read() & 0xFF;
                    newInputVec[j] = (double)col / 255.0;
                    ++j;
                }
                int label = fisLabels.read() & 0xFF;
                double[] newOutputVec = new double[10];
                newOutputVec[label] = 1.0;
                ((DataSequence)this.training.get(this.training.size() - 1)).addDataStep(new DataStep(newInputVec, newOutputVec));
                if (i != 0 && i % 5000 == 0) {
                    System.out.println(String.valueOf(Integer.toString(i)) + "/10000");
                    this.training.add(new DataSequence());
                }
                ++i;
            }
            fisImages.close();
            fisLabels.close();
        }

        public BufferedImage asImage(double[] arr) {
            BufferedImage newImage = new BufferedImage(28, 28, 1);
            int j = 0;
            while (j < 784) {
                int col = (int)(arr[j] * 255.0);
                newImage.setRGB(j % 28, j / 28, new Color(col, col, col).getRGB());
                ++j;
            }
            return newImage;
        }

        @Override
        public void DisplayReport(NeuralNetwork model, Random rng) throws Exception {
            System.out.println("Median perplexity: " + LossSoftmax.calculateMedianPerplexity(model, this.training));
        }

        @Override
        public Nonlinearity getModelOutputUnitToUse() {
            return new SigmoidUnit();
        }
    }
}

