/*
 * Decompiled with CFR 0.152.
 */
package theGhastModding.lstmStuff.gameThingy;

import autodiff.Graph;
import edu.cornell.lassp.houle.RngPack.RanMT;
import java.awt.image.BufferedImage;
import java.io.BufferedWriter;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import javax.swing.UIManager;
import loss.Loss;
import loss.LossSumOfSquares;
import matrix.Matrix;
import matrix.Tensor;
import model.ConvDropout;
import model.ConvFlatten;
import model.ConvLayer;
import model.ConvNet;
import model.ConvNonlinLayer;
import model.Dropout;
import model.FeedForwardLayer;
import model.Model;
import model.NeuralNetwork;
import model.TensorLayer;
import nonlinearities.LinearUnit;
import nonlinearities.ReLuUnit;
import theGhastModding.lstmStuff.gameThingy.CatchGame;
import theGhastModding.lstmStuff.gui.NeuralNetworkDisplay;
import theGhastModding.lstmStuff.gui.ResourceDisplay;
import theGhastModding.lstmStuff.main.VideoEncoder;
import trainer.AMSGrad;
import trainer.GradientNoise;
import trainer.TrainingMethod;
import util.FileIO;

/**
 * Deep Q-learning agent that learns to play {@link CatchGame} from raw pixels.
 *
 * <p>The agent feeds a stack of the last {@link #FRAME_STACK} noisy grayscale frames into a
 * convolutional Q-network ({@link #model}), acts epsilon-greedily, stores transitions in a
 * fixed-size replay buffer and trains on randomly sampled transitions against a bootstrapped
 * target computed by a frozen copy of the network ({@link #targetModel}). The target copy is
 * refreshed after every episode. Per-episode statistics are appended to {@code scores.csv} and
 * every rendered frame is encoded to {@code training.mp4}.
 */
public class BreakoutAI
implements Runnable {
    /** Side length (pixels) of the square frames stacked into one network input. */
    private static final int FRAME_SIZE = 84;
    /** Number of consecutive frames stacked into one network input. */
    private static final int FRAME_STACK = 4;
    /** Scale of the per-pixel noise ({@code random.raw() * PIXEL_NOISE}) added to every frame. */
    private static final double PIXEL_NOISE = 0.05;
    /** Discount factor applied to the target network's best Q-value of the next state. */
    private static final double GAMMA = 0.85;

    private CatchGame game;
    /** Online Q-network: trained by {@link #train} and used for action selection. */
    private ConvNet model;
    /** Frozen copy of {@link #model} used for bootstrap targets; re-cloned after each episode. */
    private ConvNet targetModel;
    private String savePath;
    /** Directory created by the constructor; not otherwise used in this class. */
    private File replaySavePath;
    private double epsilon = 1.0;
    private boolean training = true;
    private TrainingMethod method;
    /** Number of completed training episodes; persisted in {@code savePath + ".a"}. */
    private int ts = 0;
    private int maxPlays = 5;
    private Loss lossToUse;
    /** Number of replay samples drawn per {@link #train} call. */
    private final int batchSize = 10;
    private static NeuralNetworkDisplay dp;
    private static NeuralNetworkDisplay dp2;
    private static NeuralNetworkDisplay dp3;
    /** Capacity of the replay buffer; the newest transition always sits at the last index. */
    private final int bufferSize = 400;
    private ReplayFrame[] replayBuffer;
    // NOTE(review): the fields below are not used anywhere in this class; kept because they are
    // package-visible and might be referenced elsewhere.
    double a;
    double b;
    double c = 0.0;
    int d;
    int e;
    int f = 0;

    /**
     * Creates the agent, loading network weights and training progress if present on disk.
     *
     * @param model          Q-network to train/use; weights are loaded into it if {@code savePath} exists
     * @param savePath       file the network is saved to/loaded from; settings go to {@code savePath + ".a"}
     * @param replaySavePath directory that is created if missing
     * @param epsilon        initial exploration rate (note: {@link #run()} resets it to 1.0 when training)
     * @param method         optimizer used to apply gradients
     * @throws Exception if loading the network fails
     */
    public BreakoutAI(ConvNet model, String savePath, String replaySavePath, double epsilon, TrainingMethod method) throws Exception {
        this.model = model;
        this.savePath = savePath;
        this.epsilon = epsilon;
        this.method = method;
        this.lossToUse = new LossSumOfSquares();
        if (new File(savePath).exists()) {
            FileIO.loadNeuralNetwork(savePath, model);
        }
        this.replaySavePath = new File(replaySavePath);
        if (!this.replaySavePath.exists()) {
            this.replaySavePath.mkdirs();
        }
        if (new File(savePath + ".a").exists()) {
            this.loadSettings();
        }
    }

    /** Reads the episode counter {@link #ts} from {@code savePath + ".a"}; falls back to 0 on any error. */
    private void loadSettings() {
        try (DataInputStream dis = new DataInputStream(new FileInputStream(new File(this.savePath + ".a")))) {
            this.ts = dis.readInt();
        }
        catch (Exception e) {
            System.err.println("Error loading settings (using defaults): ");
            e.printStackTrace();
            this.ts = 0;
        }
    }

    /** Writes the episode counter {@link #ts} to {@code savePath + ".a"}; errors are reported, not thrown. */
    private void saveSettings() {
        try (DataOutputStream dos = new DataOutputStream(new FileOutputStream(new File(this.savePath + ".a")))) {
            dos.writeInt(this.ts);
        }
        catch (Exception e) {
            System.err.println("Error saving settings: ");
            e.printStackTrace();
        }
    }

    /** Creates, initializes and renders the first frame of a new game instance. */
    public void startGame() {
        this.game = new CatchGame();
        this.game.create();
        this.game.init();
        this.game.renderNextFrame();
    }

    /** Destroys the current game instance and drops the reference to it. */
    public void stopGame() {
        this.game.destroy();
        this.game = null;
    }

    /**
     * Plays {@link #maxPlays} episodes, training the network (from the second episode on) when
     * {@link #training} is set. Errors are reported to stderr rather than propagated.
     */
    @Override
    public void run() {
        this.model.resetState();
        this.targetModel = (ConvNet)this.model.clone();
        this.targetModel.resetState();
        this.startGame();
        this.game.renderNextFrame();
        this.replayBuffer = new ReplayFrame[this.bufferSize];
        Graph g = new Graph(false);
        RanMT random = new RanMT();
        int cntr2 = 0;
        double nomLoss = 0.0;
        double denomLoss = 0.0;
        // Training starts fully exploratory. In play mode the epsilon set via setEpsilon() is kept;
        // resetting it unconditionally made "play" mode act completely at random.
        if (this.training) {
            this.epsilon = 1.0;
        }
        VideoEncoder enc;
        try {
            enc = new VideoEncoder(new File("training.mp4"), 30, 84, 84, 4, "medium", 18, false);
        }
        catch (Exception e) {
            System.err.println("Error while playing game: ");
            e.printStackTrace();
            return;
        }
        try {
            this.game.renderNextFrame();
            try (BufferedWriter bw = new BufferedWriter(new FileWriter(new File("scores.csv"), true))) {
                for (int x = 0; x < this.maxPlays; ++x) {
                    int stillCount = 0;
                    int rightCount = 0;
                    int leftCount = 0;
                    int stepCounter = 0;
                    // Fresh frame stack per episode. It must exist BEFORE the warm-up frames are
                    // written into slots 1..FRAME_STACK-1 (it used to be re-allocated afterwards,
                    // silently discarding them).
                    Tensor inBuffer = new Tensor(FRAME_SIZE, FRAME_SIZE, FRAME_STACK);
                    BufferedImage t;
                    for (int slot = 1; slot < FRAME_STACK; ++slot) {
                        t = this.game.getCurrentFrame();
                        this.game.renderNextFrame();
                        inBuffer.setMatrixAt(slot, this.noisyGrayscale(t, random));
                    }
                    t = this.game.getCurrentFrame();
                    enc.encodeFrame(t);
                    System.out.println(Integer.toString(x + 1) + "," + Integer.toString(this.maxPlays));
                    double previousScore = 0.0;
                    while (!this.game.isGameOver()) {
                        if (!this.training) {
                            Thread.sleep(32L);
                        }
                        ++stepCounter;
                        // Slide the history one slot towards index 0 and append the newest frame.
                        for (int k = 0; k < FRAME_STACK - 1; ++k) {
                            inBuffer.setMatrixAt(k, inBuffer.getMatrixAt(k + 1));
                        }
                        inBuffer.setMatrixAt(FRAME_STACK - 1, this.noisyGrayscale(t, random));
                        Matrix outBuffer = this.model.forward(inBuffer, g).getMatrixAt(0);
                        int netAction = this.argmax(outBuffer);
                        int usedAction = netAction;
                        if (random.raw() < this.epsilon) {
                            // RanMT.choose(lo, hi) includes hi, so only sample actions the network has
                            // outputs for; the old bound of 3 produced an action that made train()
                            // index past the end of the 3-element Q-vector.
                            usedAction = random.choose(0, outBuffer.w.length - 1);
                        }
                        switch (netAction) {
                            case 0:
                                ++stillCount;
                                break;
                            case 1:
                                ++rightCount;
                                break;
                            case 2:
                                ++leftCount;
                                break;
                            default:
                                break;
                        }
                        this.applyAction(usedAction);
                        this.game.renderNextFrame();
                        double reward = this.game.getScore() - previousScore;
                        previousScore = this.game.getScore();
                        if (this.training) {
                            this.pushReplay(new ReplayFrame(inBuffer.clone(), reward, netAction, usedAction, this.game.isGameOver()));
                        }
                        if (this.training && x > 0) {
                            double loss;
                            setDisplaysFrozen(true);
                            try {
                                loss = this.train(random, 1.0E-4, this.batchSize);
                            }
                            finally {
                                // Unfreeze even if training throws.
                                setDisplaysFrozen(false);
                            }
                            if (Double.isInfinite(loss) || Double.isNaN(loss)) {
                                System.err.println("[ERR]: Invalid loss of " + Double.toString(loss) + "!");
                                // bw is closed by try-with-resources, the video is finished in finally.
                                return;
                            }
                            nomLoss += loss;
                            denomLoss += 1.0;
                            if (++cntr2 >= 64) {
                                ++this.model.t;
                                cntr2 = 0;
                            }
                        }
                        t = this.game.getCurrentFrame();
                        enc.encodeFrame(t);
                    }
                    System.out.println("[DEBUG]: Final score was " + Double.toString(this.game.getScore()));
                    // CSV row: score, mean loss, right %, left %, still %, epsilon (trailing comma kept).
                    bw.write(Double.toString(this.game.getScore()) + ","
                            + Double.toString(nomLoss / denomLoss) + ","
                            + Double.toString(percent(rightCount, stepCounter)) + ","
                            + Double.toString(percent(leftCount, stepCounter)) + ","
                            + Double.toString(percent(stillCount, stepCounter)) + ","
                            + Double.toString(this.epsilon) + ",");
                    bw.newLine();
                    bw.flush();
                    this.game.init();
                    this.model.resetState();
                    // Refresh the target network once per episode.
                    this.targetModel = (ConvNet)this.model.clone();
                    this.targetModel.resetState();
                    if (this.training) {
                        ++this.ts;
                        // Cyclic exploration schedule in [0, 1] driven by the episode counter.
                        this.epsilon = Math.abs(Math.sin(this.ts));
                        System.out.println("[DEBUG]: Epsilon: " + this.epsilon);
                        System.out.println("[DEBUG]: Current loss is: " + Double.toString(nomLoss / denomLoss));
                        nomLoss = 0.0;
                        denomLoss = 0.0;
                    }
                    if (x % 8 == 0 && this.training) {
                        System.out.println("[DEBUG]: Autosaving...");
                        FileIO.saveNeuralNetwork(this.savePath, this.model);
                        this.saveSettings();
                    }
                }
                if (this.training) {
                    System.out.println("[DEBUG]: Autosaving...");
                    FileIO.saveNeuralNetwork(this.savePath, this.model);
                    this.saveSettings();
                }
            }
        }
        catch (InterruptedException e) {
            // Preserve the interrupt for whoever owns this thread.
            Thread.currentThread().interrupt();
            System.err.println("Error while playing game: ");
            e.printStackTrace();
        }
        catch (Exception e) {
            System.err.println("Error while playing game: ");
            e.printStackTrace();
        }
        finally {
            try {
                enc.finishEncode();
            }
            catch (Exception e) {
                System.err.println("Error finishing video encode: ");
                e.printStackTrace();
            }
        }
    }

    /** Converts {@code frame} to grayscale and adds {@code random.raw() * PIXEL_NOISE} to each pixel. */
    private Matrix noisyGrayscale(BufferedImage frame, RanMT random) throws Exception {
        Matrix m = this.grayscale(frame);
        for (int i = 0; i < m.w.length; ++i) {
            m.w[i] = m.w[i] + random.raw() * PIXEL_NOISE;
        }
        return m;
    }

    /** Maps an action index to paddle input: 1 = right, 2 = left, anything else = no movement. */
    private void applyAction(int action) {
        this.game.setRight(action == 1);
        this.game.setLeft(action == 2);
    }

    /** Appends a transition at the end of the replay buffer, dropping the oldest one at index 0. */
    private void pushReplay(ReplayFrame rf) {
        // arraycopy handles the overlapping source/destination ranges correctly.
        System.arraycopy(this.replayBuffer, 1, this.replayBuffer, 0, this.replayBuffer.length - 1);
        this.replayBuffer[this.replayBuffer.length - 1] = rf;
    }

    /** Freezes or unfreezes all three feature-map displays. */
    private static void setDisplaysFrozen(boolean frozen) {
        dp.setFrozen(frozen);
        dp2.setFrozen(frozen);
        dp3.setFrozen(frozen);
    }

    /** Returns {@code count / total * 100}; NaN when both are zero. */
    private static double percent(int count, int total) {
        return (double)count / (double)total * 100.0;
    }

    /** Returns the index of the first maximal entry of {@code m.w} (0 if it is empty). */
    private int argmax(Matrix m) {
        double largest = Double.NEGATIVE_INFINITY;
        int largestIndex = 0;
        for (int i = 0; i < m.w.length; ++i) {
            if (m.w[i] > largest) {
                largestIndex = i;
                largest = m.w[i];
            }
        }
        return largestIndex;
    }

    /** Returns the largest entry of {@code m.w} (negative infinity if it is empty). */
    private double max(Matrix m) {
        double largest = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < m.w.length; ++i) {
            if (m.w[i] > largest) {
                largest = m.w[i];
            }
        }
        return largest;
    }

    /**
     * Converts an image to a height x width matrix holding the mean of R, G and B scaled to [0, 1].
     */
    private Matrix grayscale(BufferedImage in) throws Exception {
        Matrix toReturn = new Matrix(in.getHeight(), in.getWidth());
        for (int x = 0; x < in.getWidth(); ++x) {
            for (int y = 0; y < in.getHeight(); ++y) {
                int argb = in.getRGB(x, y);
                int r = argb >> 16 & 0xFF;
                int g = argb >> 8 & 0xFF;
                int b = argb & 0xFF;
                double gray = (double)r / 3.0 + (double)g / 3.0 + (double)b / 3.0;
                toReturn.setW(y, x, gray / 255.0);
            }
        }
        return toReturn;
    }

    public double getEpsilon() {
        return this.epsilon;
    }

    public void setEpsilon(double epsilon) {
        this.epsilon = epsilon;
    }

    public boolean isTraining() {
        return this.training;
    }

    public void setIsTraining(boolean training) {
        this.training = training;
    }

    public int getMaxPlays() {
        return this.maxPlays;
    }

    public void setMaxPlays(int maxPlays) {
        this.maxPlays = maxPlays;
    }

    /**
     * Performs one Q-learning update from {@code epochs} randomly sampled replay transitions.
     *
     * <p>For each sample the target for the taken action is {@code reward} plus, unless the
     * episode ended, {@code GAMMA * max Q_target(next state)}; other outputs keep the network's
     * own prediction so they contribute no error.
     *
     * @param random       random source used for sampling
     * @param learningRate learning rate passed to the optimizer
     * @param epochs       number of transitions to sample
     * @return mean loss over the samples, or the placeholder 1.0 if the buffer is not yet warm
     */
    private double train(RanMT random, double learningRate, int epochs) throws Exception {
        if (this.model.t == 0) {
            this.model.t = 1;
        }
        // Require the newest 5 slots to be filled before sampling. The 1.0 returned here is a
        // placeholder and is averaged into the reported loss by the caller.
        if (this.replayBuffer[this.bufferSize - 5] == null) {
            this.model.resetState();
            return 1.0;
        }
        double numLoss = 0.0;
        double denomLoss = 0.0;
        Graph g = new Graph(true);
        Graph g2 = new Graph(false);
        for (int l = 0; l < epochs; ++l) {
            this.model.resetState();
            // Pick a pair of consecutive, populated slots (current state, next state). This
            // terminates because the newest slots are known to be populated.
            int randomReplay;
            do {
                randomReplay = random.choose(0, this.bufferSize - 2);
            } while (this.replayBuffer[randomReplay] == null || this.replayBuffer[randomReplay + 1] == null);
            ReplayFrame current = this.replayBuffer[randomReplay];
            ReplayFrame next = this.replayBuffer[randomReplay + 1];
            Matrix out = this.model.forward(current.getFrame(), g).getMatrixAt(0);
            Matrix targetOutput = out.clone();
            double target = current.getReward();
            if (!current.isGameOver()) {
                Matrix res = this.targetModel.forward(next.getFrame(), g2).getMatrixAt(0);
                target += GAMMA * this.max(res);
            }
            targetOutput.w[current.getUsedAction()] = target;
            numLoss += this.lossToUse.measure(out, targetOutput);
            denomLoss += 1.0;
            this.lossToUse.backward(out, targetOutput);
        }
        g.backward();
        // NOTE(review): third argument is hard-coded to 10; presumably the batch size — confirm
        // whether it should follow `epochs`.
        this.method.updateParameters(this.model, learningRate, 10);
        g.cleanUp();
        this.model.resetState();
        return numLoss / denomLoss;
    }

    /**
     * Builds the network and displays, then runs the agent on a background thread.
     * Pass {@code play} as the first argument to play 2 episodes without training.
     */
    public static void main(String[] args) {
        System.out.println(BreakoutAI.thing(5, false) + "\n\n" + BreakoutAI.thing(5, true));
        try {
            UIManager.setLookAndFeel(UIManager.getSystemLookAndFeelClassName());
        }
        catch (Exception e) {
            e.printStackTrace();
        }
        dp = new NeuralNetworkDisplay(20, 20, 12);
        dp2 = new NeuralNetworkDisplay(9, 9, 27);
        dp3 = new NeuralNetworkDisplay(7, 7, 35);
        dp.setTitle(0, 0, 1, 2);
        dp.setLocation(700, 150);
        dp2.setTitle(1, 0, 1, 2);
        dp2.setLocation(700, 300);
        dp3.setTitle(2, 0, 1, 2);
        dp3.setLocation(700, 450);
        Random random = new Random();
        int cores = 8;
        // Convolutional trunk: 84x84x4 -> 20x20x32 -> 9x9x64 -> 7x7x64, each stage viewable.
        ConvNet model = new ConvNet();
        model.addLayer(new ConvLayer(84, 84, 4, 8, 8, 32, 4, 0, 0.08, random, true, cores));
        model.addLayer(new ConvNonlinLayer(new ReLuUnit()));
        model.addLayer(new ViewerLayer(dp, false, 0, 1, 2));
        model.addLayer(new ConvDropout(0.1));
        model.addLayer(new ConvLayer(20, 20, 32, 4, 4, 64, 2, 0, 0.08, random, true, cores));
        model.addLayer(new ConvNonlinLayer(new ReLuUnit()));
        model.addLayer(new ViewerLayer(dp2, false, 0, 1, 2));
        model.addLayer(new ConvDropout(0.1));
        model.addLayer(new ConvLayer(9, 9, 64, 3, 3, 64, 1, 0, 0.08, random, true, cores));
        model.addLayer(new ConvNonlinLayer(new ReLuUnit()));
        model.addLayer(new ViewerLayer(dp3, false, 0, 1, 2));
        model.addLayer(new ConvDropout(0.1));
        model.addLayer(new ConvFlatten(7, 7, 64));
        // Fully connected head: 3136 -> 512 -> 3 Q-values (still, right, left).
        ArrayList<Model> fcLayers = new ArrayList<Model>();
        fcLayers.add(new FeedForwardLayer(3136, 512, new ReLuUnit(), 0.08, random));
        fcLayers.add(new Dropout(0.1));
        fcLayers.add(new FeedForwardLayer(512, 3, new LinearUnit(), 0.08, random));
        NeuralNetwork fc = new NeuralNetwork(fcLayers);
        model.setFullyConnected(fc);
        long total = 0L;
        for (Matrix m : model.getParameters()) {
            total += (long)m.w.length;
        }
        System.out.println(Long.toString(total) + " total parameters.");
        ResourceDisplay rd = null;
        try {
            rd = new ResourceDisplay(null, "Intel(R) Xeon(R) CPU E5-1620 v4", "DDR4 SDRAM PC4-17000, ECC, registered");
            rd.setLocation(650, 16);
            rd.showFrame();
        }
        catch (Exception e) {
            System.err.println("Error creating resource display: ");
            e.printStackTrace();
            rd = null;
        }
        dp.setVisible(true);
        dp2.setVisible(true);
        dp3.setVisible(true);
        BreakoutAI ai = null;
        try {
            ai = new BreakoutAI(model, "BreakoutAI.ser", "BreakoutAI/", 0.15, new GradientNoise(new AMSGrad(0.99, 0.999, 1.0E-5)));
        }
        catch (Exception e) {
            System.err.println("Error creating AI: ");
            e.printStackTrace();
            System.exit(1);
        }
        boolean playOnly = args.length > 0 && args[0].equalsIgnoreCase("play");
        try {
            if (playOnly) {
                ai.setEpsilon(0.0);
                ai.setIsTraining(false);
                ai.setMaxPlays(2);
            } else {
                ai.setMaxPlays(64);
            }
        }
        catch (Exception e) {
            System.err.println("Error starting game: ");
            e.printStackTrace();
            System.exit(1);
        }
        // NOTE(review): looks like leftover debug output; prints "aaaaaa" depending on two random ints.
        int x = random.nextInt(20);
        int y = random.nextInt(20);
        boolean b = x > 5;
        b = b || y < 10;
        b = b && x != y;
        if (Boolean.toString(!b).startsWith("f") == Boolean.parseBoolean("false")) {
            System.out.println("aaaaaa");
        }
        Thread t = new Thread(ai);
        t.start();
        try {
            t.join();
        }
        catch (Exception e) {
            e.printStackTrace();
        }
        try {
            Thread.sleep(1000L);
            ai.stopGame();
        }
        catch (Exception e) {
            System.err.println("Error closing game: ");
            e.printStackTrace();
            System.exit(1);
        }
        if (rd != null) {
            rd.closeFrame();
        }
        if (dp != null) {
            dp.setVisible(false);
        }
        if (dp2 != null) {
            dp2.setVisible(false);
        }
        if (dp3 != null) {
            dp3.setVisible(false);
        }
        System.exit(0);
    }

    /**
     * Builds a small ASCII pattern.
     *
     * @param n        number of cells
     * @param vertical if true, {@code n} lines alternating {@code " * "} and {@code "***"};
     *                 otherwise a row of {@code n} {@code " * "}, a newline (only if {@code n > 0})
     *                 and a row of {@code n} {@code "***"}
     * @return the pattern
     */
    public static String thing(int n, boolean vertical) {
        StringBuilder result = new StringBuilder();
        if (vertical) {
            for (int i = 0; i < n; ++i) {
                result.append(i % 2 == 0 ? " * \n" : "***\n");
            }
        } else {
            for (int j = 0; j < n; ++j) {
                result.append(" * ");
            }
            if (n > 0) {
                result.append("\n");
            }
            for (int j = 0; j < n; ++j) {
                result.append("***");
            }
        }
        return result.toString();
    }

    /** One replay transition: the frame stack seen, the reward received and the actions involved. */
    private static class ReplayFrame {
        private Tensor frame;
        private double reward;
        /** Greedy action proposed by the network. */
        private int netAction;
        /** Action actually executed (may differ from netAction due to exploration). */
        private int usedAction;
        /** Whether the game ended right after this step. */
        private boolean gameOver;

        private ReplayFrame(Tensor frame, double reward, int netAction, int usedAction, boolean gameOver) {
            this.frame = frame;
            this.reward = reward;
            this.netAction = netAction;
            this.usedAction = usedAction;
            this.gameOver = gameOver;
        }

        public Tensor getFrame() {
            return this.frame;
        }

        public double getReward() {
            return this.reward;
        }

        public int getNetAction() {
            return this.netAction;
        }

        public int getUsedAction() {
            return this.usedAction;
        }

        public boolean isGameOver() {
            return this.gameOver;
        }
    }

    /**
     * Pass-through layer that shows up to three selected channels of its input on a
     * {@link NeuralNetworkDisplay}. It shows them either during the forward pass or, when
     * {@code showGradients} is set, from a backprop callback.
     */
    public static class ViewerLayer
    implements TensorLayer {
        private NeuralNetworkDisplay dp;
        /** Channel indices to display (only the first three are used). */
        private int[] o;
        private boolean showGradients;

        public ViewerLayer(NeuralNetworkDisplay dp, boolean showGradients, int ... a) {
            this.dp = dp;
            this.showGradients = showGradients;
            this.dp.setShowGradients(showGradients);
            this.o = a.clone();
        }

        /** Sends the selected channels of {@code input} to the display. */
        private void showSelected(Tensor input) {
            if (this.o.length >= 3) {
                this.dp.showFeatures(input.getMatrixAt(this.o[0]), input.getMatrixAt(this.o[1]), input.getMatrixAt(this.o[2]));
            } else if (this.o.length == 2) {
                this.dp.showFeatures(input.getMatrixAt(this.o[0]), input.getMatrixAt(this.o[1]));
            } else if (this.o.length == 1) {
                this.dp.showFeatures(input.getMatrixAt(this.o[0]));
            }
        }

        @Override
        public Tensor forward(final Tensor input, Graph g) throws Exception {
            if (!this.showGradients) {
                this.showSelected(input);
            }
            if (this.showGradients && g.applyBackprop()) {
                g.addBackprop(new Runnable(){

                    @Override
                    public void run() {
                        showSelected(input);
                    }
                });
            }
            return input;
        }

        public boolean isShowGradients() {
            return this.showGradients;
        }

        public void setShowGradients(boolean showGradients) {
            this.showGradients = showGradients;
            this.dp.setShowGradients(showGradients);
        }

        @Override
        public void resetState() {
        }

        /** This layer has no trainable parameters; returns an empty (mutable) list rather than null. */
        @Override
        public List<Matrix> getParameters() {
            return new ArrayList<Matrix>();
        }

        @Override
        public TensorLayer clone() {
            return new ViewerLayer(this.dp, this.showGradients, this.o);
        }

        @Override
        public void saveState(DataOutputStream fos) throws Exception {
        }

        @Override
        public void loadState(DataInputStream fis) throws Exception {
        }
    }
}

