/*
 * Decompiled with CFR 0.152.
 */
package theGhastModding.lstmStuff.main;

import autodiff.Graph;
import datastructs.DataSequence;
import datastructs.DataSet;
import datastructs.DataStep;
import edu.cornell.lassp.houle.RngPack.RanMT;
import java.awt.Color;
import java.awt.Graphics2D;
import java.awt.geom.AffineTransform;
import java.awt.image.AffineTransformOp;
import java.awt.image.BufferedImage;
import java.awt.image.RenderedImage;
import java.io.File;
import java.io.FileInputStream;
import java.util.ArrayList;
import java.util.Random;
import javax.imageio.ImageIO;
import javax.swing.UIManager;
import matrix.Matrix;
import matrix.Tensor;
import model.ConvDense;
import model.ConvExpand;
import model.ConvFlatten;
import model.ConvLayer;
import model.ConvUpsample;
import model.FeedForwardLayer;
import model.NeuralNetwork;
import nonlinearities.ExponentialLinearUnit;
import nonlinearities.LinearUnit;
import nonlinearities.Nonlinearity;
import nonlinearities.TanhUnit;
import theGhastModding.lstmStuff.gui.ProgressBarWindow;
import trainer.AMSGrad;
import trainer.Adam;
import trainer.BasicSGD;
import trainer.GANTrainer;
import trainer.Optimizer;
import trainer.RMSProp;
import util.CLUtils;
import util.FileChannelInputStream;
import util.FileIO;
import util.NNDevice;

public class GANTest {
    private static NeuralNetwork generator;
    private static NeuralNetwork discriminator;
    private static int cores;
    private static final int seedLength = 768;
    private static final String saveBasePath = "./Pony_GAN";
    private static NNDevice[] devices;
    private static Graph staticGraph;

    static {
        cores = 0;
        devices = null;
        staticGraph = null;
    }

    public static void main(String[] args) {
        try {
            UIManager.setLookAndFeel(UIManager.getSystemLookAndFeelClassName());
        }
        catch (Exception e) {
            e.printStackTrace();
        }
        try {
            block29: {
                if (args.length != 0 && args[0].equalsIgnoreCase("generate")) {
                    GANTest.loadNets();
                    File outFolder = new File("GAN/");
                    outFolder.mkdirs();
                    ProgressBarWindow pgrs = new ProgressBarWindow(200);
                    int k = 0;
                    while (k < 200) {
                        BufferedImage img = GANTest.generateImage();
                        ImageIO.write((RenderedImage)img, "png", new File("GAN/GAN_output" + Integer.toString(k + 1) + ".png"));
                        pgrs.setProgress(k + 1);
                        pgrs.setLoss(k + 1);
                        ++k;
                    }
                    pgrs.close();
                    System.out.println("Done.");
                    staticGraph.cleanUp();
                    System.exit(0);
                }
                if (args.length != 6 && args.length != 10 && args.length != 7) {
                    if (args.length != 0 && args[0].equalsIgnoreCase("help")) {
                        System.out.println("Usage: [path to images] [iterations] [epochs] [batch size] [learning rate] [optimizer] {CPU cores} {opencl platform name hint} {opencl device name hint} {device no.}\n");
                        System.out.println("Optimizer can be either one of these:\n\tSGD\n\tRMSProp\n\tAdam (Recommended)\n\tAMSGrad (Experimental)");
                        System.out.println("\nRun with 'clList' as only argument to list OpenCL platforms & devices");
                        System.out.println("Platform name hint argument and device name argument have to be used together, CPU cores argument can be used alone");
                        System.exit(0);
                    }
                    if (args.length != 0 && args[0].equalsIgnoreCase("clList")) {
                        CLUtils.printDevices();
                        System.exit(0);
                    }
                    System.err.println("Usage: [path to images] [iterations] [epochs] [batch size] [learning rate] [optimizer] {CPU cores} {opencl platform name hint} {opencl device name hint} {device no.}");
                    System.exit(1);
                }
                cores = args.length == 7 || args.length == 10 ? Integer.parseInt(args[6]) : 1;
                if (args.length == 10) {
                    try {
                        devices = CLUtils.findDevice(args[7], args[8]);
                        if (devices.length == 0) {
                            throw new Exception("Device was not found");
                        }
                        if (Integer.parseInt(args[9]) == -1) {
                            staticGraph = devices.length > 1 ? CLUtils.createGraph(devices[0], devices[1], false) : CLUtils.createGraph(devices[0], false);
                            break block29;
                        }
                        int num = Integer.parseInt(args[9]);
                        if (num >= devices.length) {
                            throw new Exception("Device was not found");
                        }
                        staticGraph = CLUtils.createGraph(devices[num], false);
                        devices = new NNDevice[]{devices[num]};
                    }
                    catch (Exception e) {
                        System.err.println("Error getting OpenCL device: ");
                        e.printStackTrace();
                        System.exit(1);
                    }
                } else {
                    devices = null;
                    staticGraph = new Graph(false);
                }
            }
            File imagesFolder = new File(args[0]);
            if (!imagesFolder.exists()) {
                System.err.println(String.valueOf(imagesFolder.getPath()) + ": folder not found");
                System.exit(1);
            }
            if (imagesFolder.listFiles().length < 128) {
                System.err.println("Images folder does not contain enough images (128 images minimum)");
                System.exit(1);
            }
            GANTest.loadNets();
            int epochs = Integer.parseInt(args[2]);
            int iterations = Integer.parseInt(args[1]);
            int batchSize = Integer.parseInt(args[3]);
            double learningRate = Double.parseDouble(args[4]);
            String optimizerName = args[5];
            Optimizer optimizer = null;
            if (optimizerName.equalsIgnoreCase("SGD")) {
                optimizer = new BasicSGD(0.5);
            }
            if (optimizerName.equalsIgnoreCase("RMSProp")) {
                optimizer = new RMSProp(0.5, 5.0, 1.0E-9);
            }
            if (optimizerName.equalsIgnoreCase("Adam")) {
                optimizer = new Adam(0.5, 0.99, 1.0E-9);
            }
            if (optimizerName.equalsIgnoreCase("AMSGrad")) {
                System.err.println(String.valueOf(optimizerName) + ": WARNING: it is not recommended to use this optimizer with GANs");
                System.err.println(String.valueOf(optimizerName) + ": WARNING: untested optimizer");
                optimizer = new AMSGrad(0.5, 0.99, 1.0E-9);
            }
            if (optimizer == null) {
                System.err.println(String.valueOf(args[5]) + ": Invalid optimizer name");
                System.exit(0);
            }
            GANTrainer trainer = new GANTrainer(optimizer, 0.5, devices);
            File f = new File("stop.txt");
            if (f.exists()) {
                f.delete();
            }
            int i = 0;
            while (i < GANTest.generator.t / 2000) {
                learningRate *= 0.95;
                ++i;
            }
            int lastT = GANTest.generator.t / 2000;
            int i2 = 0;
            while (i2 < iterations) {
                block30: {
                    ImageDataset dataset;
                    System.out.println(String.valueOf(Integer.toString(i2 + 1)) + "/" + Integer.toString(iterations));
                    if (f.exists()) break;
                    try {
                        dataset = new ImageDataset(imagesFolder, 64, batchSize * epochs + batchSize * epochs + 16);
                    }
                    catch (Exception e) {
                        e.printStackTrace();
                        break block30;
                    }
                    System.out.println("Iteration " + Integer.toString(i2 + 1) + "/" + Integer.toString(iterations));
                    double loss = trainer.train(generator, discriminator, learningRate, epochs, batchSize, dataset, epochs, null, null, false, false, new Random());
                    if (GANTest.generator.t / 2000 > lastT) {
                        learningRate *= 0.95;
                        lastT = GANTest.generator.t / 2000;
                    }
                    GANTest.saveNetworks();
                    System.out.println(Double.toString(loss));
                }
                ++i2;
            }
            trainer.dispose();
            GANTest.saveNetworks();
            System.out.println("Done.");
            staticGraph.cleanUp();
            System.exit(0);
        }
        catch (Exception e) {
            System.err.println("Error: ");
            e.printStackTrace();
            System.exit(1);
        }
    }

    private static void loadNets() throws Exception {
        generator = new NeuralNetwork();
        discriminator = new NeuralNetwork();
        Random rng = new Random();
        generator.addLayer(new ConvDense(new FeedForwardLayer(768, 8192, new LinearUnit(), 0.044, rng)));
        generator.addLayer(new ConvExpand(8192, 8, 8, 128));
        generator.addLayer(new ConvLayer(8, 8, 128, 5, 5, 135, 1, 2, new ExponentialLinearUnit(1.0), 0.011408, rng, true, true, cores));
        generator.addLayer(new ConvLayer(8, 8, 135, 3, 3, 135, 1, 1, new ExponentialLinearUnit(1.0), 0.011408, rng, true, true, cores));
        generator.addLayer(new ConvUpsample(2));
        generator.addLayer(new ConvLayer(16, 16, 135, 3, 3, 135, 1, 1, new ExponentialLinearUnit(1.0), 0.01, rng, true, true, cores));
        generator.addLayer(new ConvLayer(16, 16, 135, 3, 3, 135, 1, 1, new ExponentialLinearUnit(1.0), 0.01, rng, true, true, cores));
        generator.addLayer(new ConvUpsample(2));
        generator.addLayer(new ConvLayer(32, 32, 135, 3, 3, 120, 1, 1, new ExponentialLinearUnit(1.0), 0.01, rng, true, true, cores));
        generator.addLayer(new ConvLayer(32, 32, 120, 3, 3, 120, 1, 1, new ExponentialLinearUnit(1.0), 0.01, rng, true, true, cores));
        generator.addLayer(new ConvUpsample(2));
        generator.addLayer(new ConvLayer(64, 64, 120, 3, 3, 64, 1, 1, new ExponentialLinearUnit(1.0), 0.01, rng, true, true, cores));
        generator.addLayer(new ConvLayer(64, 64, 64, 3, 3, 3, 1, 1, new LinearUnit(), 0.01, rng, true, true, cores));
        discriminator.addLayer(new ConvLayer(64, 64, 3, 3, 3, 120, 1, 1, new ExponentialLinearUnit(1.0), 0.01, rng, true, true, cores));
        discriminator.addLayer(new ConvLayer(64, 64, 120, 3, 3, 120, 1, 1, new ExponentialLinearUnit(1.0), 0.01, rng, true, true, cores));
        discriminator.addLayer(new ConvLayer(64, 64, 120, 4, 4, 240, 2, 1, new ExponentialLinearUnit(1.0), 0.01, rng, true, true, cores));
        discriminator.addLayer(new ConvLayer(32, 32, 240, 3, 3, 240, 1, 1, new ExponentialLinearUnit(1.0), 0.01, rng, true, true, cores));
        discriminator.addLayer(new ConvLayer(32, 32, 240, 4, 4, 360, 2, 1, new ExponentialLinearUnit(1.0), 0.01, rng, true, true, cores));
        discriminator.addLayer(new ConvLayer(16, 16, 360, 3, 3, 360, 1, 1, new ExponentialLinearUnit(1.0), 0.01, rng, true, true, cores));
        discriminator.addLayer(new ConvLayer(16, 16, 360, 4, 4, 360, 2, 1, new ExponentialLinearUnit(1.0), 0.01, rng, true, true, cores));
        discriminator.addLayer(new ConvLayer(8, 8, 360, 3, 3, 480, 1, 1, new ExponentialLinearUnit(1.0), 0.01, rng, true, true, cores));
        discriminator.addLayer(new ConvLayer(8, 8, 480, 4, 4, 480, 2, 1, new ExponentialLinearUnit(1.0), 0.01, rng, true, true, cores));
        discriminator.addLayer(new ConvFlatten(4, 4, 480));
        discriminator.addLayer(new ConvDense(new FeedForwardLayer(7680, 768, new LinearUnit(), 0.0110485, rng)));
        File generatorFile = new File("./Pony_GAN_generator.dat");
        File discriminatorFile = new File("./Pony_GAN_discriminator.dat");
        if (generatorFile.exists() && discriminatorFile.exists()) {
            FileIO.loadNeuralNetwork(generatorFile.getPath(), generator);
            FileIO.loadNeuralNetwork(discriminatorFile.getPath(), discriminator);
        } else {
            System.err.println("WAIT!");
        }
        long totalParams = 0L;
        for (Matrix m : generator.getParameters()) {
            totalParams += (long)m.w.length;
        }
        for (Matrix m : discriminator.getParameters()) {
            totalParams += (long)m.w.length;
        }
        System.out.println("Loaded. " + Long.toString(totalParams) + " total parameters.");
    }

    private static void saveNetworks() throws Exception {
        FileIO.saveNeuralNetwork("./Pony_GAN_generator.dat", generator);
        FileIO.saveNeuralNetwork("./Pony_GAN_discriminator.dat", discriminator);
        System.out.println("Saved.");
    }

    private static Tensor seedTensor() throws Exception {
        return Tensor.rand(1, 768, 1, 1.0, new Random());
    }

    private static BufferedImage generateImage() throws Exception {
        generator.resetState();
        Tensor genOut = generator.forward(GANTest.seedTensor(), staticGraph);
        generator.resetState();
        return GANTest.asImage(genOut);
    }

    public static BufferedImage asImage(Tensor t) throws Exception {
        BufferedImage img = new BufferedImage(t.width, t.height, 1);
        int i = 0;
        while (i < t.width) {
            int j = 0;
            while (j < t.height) {
                if (t.depth == 3) {
                    int r = (int)((t.matrices[0].getW(j, i) + 1.0) / 2.0 * 255.0);
                    int g = (int)((t.matrices[1].getW(j, i) + 1.0) / 2.0 * 255.0);
                    int b = (int)((t.matrices[2].getW(j, i) + 1.0) / 2.0 * 255.0);
                    if (r < 0) {
                        r = 0;
                    }
                    if (r > 255) {
                        r = 255;
                    }
                    if (g < 0) {
                        g = 0;
                    }
                    if (g > 255) {
                        g = 255;
                    }
                    if (b < 0) {
                        b = 0;
                    }
                    if (b > 255) {
                        b = 255;
                    }
                    img.setRGB(i, j, new Color(r, g, b).getRGB());
                } else {
                    int col = (int)((t.matrices[0].getW(j, i) + 1.0) / 2.0 * 255.0);
                    if (col < 0) {
                        col = 0;
                    }
                    if (col > 255) {
                        col = 255;
                    }
                    img.setRGB(i, j, new Color(col, col, col).getRGB());
                }
                ++j;
            }
            ++i;
        }
        return img;
    }

    public static Tensor asTensor(BufferedImage img, boolean color) {
        Tensor toReturn = new Tensor(img.getWidth(), img.getHeight(), color ? 3 : 1);
        int i = 0;
        while (i < img.getWidth()) {
            int j = 0;
            while (j < img.getHeight()) {
                int rgb = img.getRGB(i, j);
                int red = rgb >> 16 & 0xFF;
                int green = rgb >> 8 & 0xFF;
                int blue = rgb & 0xFF;
                if (!color) {
                    double col = red + green + blue;
                    col /= 3.0;
                    col = col / 255.0 * 2.0 - 1.0;
                    toReturn.matrices[0].setW(j, i, col);
                } else {
                    double r = (double)red / 255.0 * 2.0 - 1.0;
                    double g = (double)green / 255.0 * 2.0 - 1.0;
                    double b = (double)blue / 255.0 * 2.0 - 1.0;
                    toReturn.matrices[0].setW(j, i, r);
                    toReturn.matrices[1].setW(j, i, g);
                    toReturn.matrices[2].setW(j, i, b);
                }
                ++j;
            }
            ++i;
        }
        return toReturn;
    }

    private static class ImageDataset
    extends DataSet {
        private boolean color = true;

        /*
         * Unable to fully structure code
         */
        public ImageDataset(File framesFolder, int targetSize, int numImages) throws Exception {
            super();
            this.training = new ArrayList<E>();
            this.inputDimension = new DataSet.TensorDimensions(1, 768, 1);
            this.outputDimension = new DataSet.TensorDimensions(targetSize, targetSize, this.color != false ? 3 : 1);
            rn = new RanMT();
            files = framesFolder.listFiles();
            imageCount = files.length;
            seq = new DataSequence();
            System.out.println("Loading images...");
            prevPercent = 0;
            System.out.print("|");
            i = 0;
            while (i < 98) {
                System.out.print("-");
                ++i;
            }
            System.out.println("|");
            j = 0;
            ** GOTO lbl76
            {
                System.out.print(">");
                ++prevPercent;
                do {
                    block14: {
                        if ((int)((double)j / (double)numImages * 100.0) > prevPercent) continue block3;
                        randomImageIndx = 0;
                        if (numImages >= imageCount) {
                            randomImageIndx = j;
                            if (j >= imageCount) {
                                break block3;
                            }
                        } else {
                            randomImageIndx = (int)(rn.raw() * (double)(imageCount - 3)) + 1;
                        }
                        fis = new FileInputStream(files[randomImageIndx]);
                        in = new FileChannelInputStream(fis.getChannel());
                        currImage = null;
                        try {
                            currImage = ImageIO.read(in);
                        }
                        catch (Exception e) {
                            System.err.println("Error reading image (probably not a valid image file): ");
                            e.printStackTrace();
                            break block14;
                        }
                        in.close();
                        fis.close();
                        if (currImage != null) {
                            if (currImage.getWidth() != currImage.getHeight()) {
                                width = currImage.getWidth();
                                if (width < (height = currImage.getHeight())) {
                                    width = (int)((double)targetSize / (double)height * (double)width);
                                    height = targetSize;
                                } else {
                                    height = (int)((double)targetSize / (double)width * (double)height);
                                    width = targetSize;
                                }
                                startx = (targetSize - width) / 2;
                                starty = (targetSize - height) / 2;
                                resImage = new BufferedImage(targetSize, targetSize, 1);
                                resImage.getGraphics().drawImage(currImage, startx, starty, width, height, null);
                                resImage.getGraphics().dispose();
                                t = GANTest.asTensor(resImage, this.color);
                                seq.addDataStep(new DataStep(t.clone(), null));
                            } else {
                                resImage = new BufferedImage(targetSize, targetSize, 1);
                                resImage.getGraphics().drawImage(currImage, 0, 0, targetSize, targetSize, null);
                                resImage.getGraphics().dispose();
                                t = GANTest.asTensor(resImage, this.color);
                                seq.addDataStep(new DataStep(t.clone(), null));
                                tx = AffineTransform.getScaleInstance(1.0, -1.0);
                                tx.translate(0.0, -resImage.getHeight(null));
                                op = new AffineTransformOp(tx, 1);
                                resImage = op.filter(resImage, null);
                                t = GANTest.asTensor(resImage, this.color);
                                seq.addDataStep(new DataStep(t, null));
                            }
                        }
                    }
                    ++j;
lbl76:
                    // 2 sources

                } while (j < numImages);
            }
            i = 0;
            while (i < 100 - prevPercent) {
                System.out.print(">");
                ++i;
            }
            System.out.println();
            this.training.add(seq);
        }

        @Override
        public void DisplayReport(NeuralNetwork model, Random rng) throws Exception {
            try {
                long time = System.currentTimeMillis();
                ImageIO.write((RenderedImage)GANTest.generateImage(), "png", new File("GAN/GAN_output" + Long.toString(time) + ".png"));
                ImageIO.write((RenderedImage)GANTest.asImage(((DataSequence)this.training.get((int)0)).getDataStep((int)rng.nextInt((int)((DataSequence)this.training.get((int)0)).getSequenceLength())).input), "png", new File("GAN/GAN_output" + Long.toString(time + 1L) + ".png"));
            }
            catch (Exception e) {
                System.err.println("Error generating sample image: ");
                e.printStackTrace();
            }
            System.err.println("a");
        }

        @Override
        public Nonlinearity getModelOutputUnitToUse() {
            return new TanhUnit();
        }
    }
}

