/*
 * Decompiled with CFR 0.152.
 */
package theGhastModding.lstmStuff.main;

import autodiff.Graph;
import java.awt.image.BufferedImage;
import java.awt.image.RenderedImage;
import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Random;
import javax.imageio.ImageIO;
import javax.swing.UIManager;
import loss.Loss;
import matrix.Matrix;
import matrix.Tensor;
import model.ConvFlatten;
import model.ConvLayer;
import model.ConvNet;
import model.ConvNonlinLayer;
import model.FeedForwardLayer;
import model.Model;
import model.NeuralNetwork;
import nonlinearities.LinearUnit;
import nonlinearities.ReLuUnit;
import theGhastModding.lstmStuff.gameThingy.BreakoutFrame;
import theGhastModding.lstmStuff.gui.ResourceDisplay;
import theGhastModding.utils.math.ByteConverters;
import trainer.Adam;
import trainer.TrainingMethod;
import util.FileIO;

/**
 * Plays Breakout with a convolutional network and optionally trains it online.
 *
 * <p>Each game frame is converted to grayscale and pushed into a 4-frame history tensor, which is
 * fed to the network. The arg-max of the 3 outputs picks the paddle input (0 = none, 1 = right,
 * 2 = left). With probability {@code epsilon} that choice is replaced by a random action.
 *
 * <p>While training, every played frame is written to the replay directory as a {@code .dat}/
 * {@code .png} pair. After each frame, one gradient step is taken on a randomly chosen window of
 * past frames: the output of the action taken in that window is regressed toward a discounted sum
 * of the rewards that followed it.
 */
public class BreakoutAI implements Runnable {
    /** Replay directory used by the three-argument constructor. */
    private static final String DEFAULT_REPLAY_DIR = "D:/BreakoutAI/";
    /** Suffix appended to {@code savePath} for the small settings file that stores {@code ts}. */
    private static final String SETTINGS_SUFFIX = ".a";
    /** Epsilon is computed as {@code EPSILON_BASE ^ ts}. */
    private static final double EPSILON_BASE = 0.97;
    /** Width/height of the network input planes (NOTE(review): frames are assumed to be 84x84). */
    private static final int FRAME_SIZE = 84;
    /** Number of stacked frames fed to the network. */
    private static final int HISTORY = 4;
    /** Number of future frames whose rewards contribute to a training target. */
    private static final int LOOKAHEAD = 48;
    /** Per-frame discount applied to future rewards in the training target. */
    private static final double DISCOUNT = 0.91;
    /** Rewards are divided by this before entering the training target. */
    private static final double REWARD_SCALE = 5.0;
    /**
     * Subtracted from the replay count when drawing a training window. It keeps the window's last
     * frame at index {@code <= available - 3}, i.e. inside the frames already written.
     */
    private static final int REPLAY_MARGIN = 54;
    /** Smallest replay count for which a training window can be drawn. */
    private static final int MIN_REPLAY_FRAMES = REPLAY_MARGIN + 1;
    /** Number of network outputs / game actions. */
    private static final int NUM_ACTIONS = 3;
    /** Training steps between advances of the model's step counter {@code model.t}. */
    private static final int STEPS_PER_T = 32;

    private BreakoutFrame game;
    private final ConvNet model;
    private final String savePath;
    private final File replaySavePath;
    private double epsilon = -1.0;
    private boolean training = true;
    private final TrainingMethod method;
    /** Number of finished training games; drives the epsilon decay and is persisted. */
    private int ts = 0;
    private int maxPlays = 5;
    private final Loss lossToUse;

    /**
     * Creates an AI that stores replay frames in the default directory {@code D:/BreakoutAI/}.
     *
     * @param model    network to play with; its weights are loaded from {@code savePath} if that file exists
     * @param savePath file the network is saved to/loaded from; {@code savePath + ".a"} holds settings
     * @param method   optimizer used for training steps
     * @throws Exception if loading an existing network fails
     */
    public BreakoutAI(ConvNet model, String savePath, TrainingMethod method) throws Exception {
        this(model, savePath, method, new File(DEFAULT_REPLAY_DIR));
    }

    /**
     * Creates an AI that stores replay frames in {@code replaySavePath}.
     *
     * @param model          network to play with; its weights are loaded from {@code savePath} if that file exists
     * @param savePath       file the network is saved to/loaded from; {@code savePath + ".a"} holds settings
     * @param method         optimizer used for training steps
     * @param replaySavePath directory for replay frames; it is created if missing
     * @throws Exception if loading an existing network fails
     */
    public BreakoutAI(ConvNet model, String savePath, TrainingMethod method, File replaySavePath) throws Exception {
        this.model = model;
        this.savePath = savePath;
        this.method = method;
        // Squared-error loss that ignores outputs whose target is NEGATIVE_INFINITY. train() uses
        // that value to mask the actions that were not taken, so only the taken action gets a gradient.
        this.lossToUse = new Loss() {

            @Override
            public void backward(Matrix actualOutput, Matrix targetOutput) throws Exception {
                for (int i = 0; i < targetOutput.w.length; i++) {
                    if (targetOutput.w[i] != Double.NEGATIVE_INFINITY) {
                        actualOutput.dw[i] += actualOutput.w[i] - targetOutput.w[i];
                    }
                }
            }

            @Override
            public double measure(Matrix actualOutput, Matrix targetOutput) throws Exception {
                double sum = 0.0;
                for (int i = 0; i < targetOutput.w.length; i++) {
                    if (targetOutput.w[i] != Double.NEGATIVE_INFINITY) {
                        double errDelta = actualOutput.w[i] - targetOutput.w[i];
                        sum += 0.5 * (errDelta * errDelta);
                    }
                }
                return sum;
            }
        };
        if (new File(savePath).exists()) {
            FileIO.loadNeuralNetwork(savePath, model);
        }
        this.replaySavePath = replaySavePath;
        if (!replaySavePath.exists() && !replaySavePath.mkdirs()) {
            System.err.println("Could not create replay directory: " + replaySavePath.getPath());
        }
        if (this.settingsFile().exists()) {
            this.loadSettings();
        }
    }

    /** Returns the settings file that sits next to the network save file. */
    private File settingsFile() {
        return new File(this.savePath + SETTINGS_SUFFIX);
    }

    /** Returns the replay file with the given frame index and extension (".dat" or ".png"). */
    private File replayFile(int index, String extension) {
        return new File(this.replaySavePath, index + extension);
    }

    /**
     * Reads {@code ts} from the settings file and derives {@code epsilon} from it. On any error the
     * defaults ({@code ts = 0}) are used.
     */
    private void loadSettings() {
        try (DataInputStream in = new DataInputStream(new FileInputStream(this.settingsFile()))) {
            byte[] intBuffer = new byte[4];
            // readFully rejects truncated files instead of silently parsing a partial buffer.
            in.readFully(intBuffer);
            this.ts = ByteConverters.bytesToInt(intBuffer);
        } catch (Exception e) {
            System.err.println("Error loading settings (using defaults): ");
            e.printStackTrace();
            this.ts = 0;
        }
        this.epsilon = Math.pow(EPSILON_BASE, this.ts);
    }

    /** Writes {@code ts} to the settings file; errors are reported but not propagated. */
    private void saveSettings() {
        try (FileOutputStream out = new FileOutputStream(this.settingsFile())) {
            out.write(ByteConverters.intToBytes(this.ts));
        } catch (Exception e) {
            System.err.println("Error saving settings: ");
            e.printStackTrace();
        }
    }

    /** Opens the game window, initialises the game and renders its first frame. */
    public void startGame() {
        this.game = new BreakoutFrame();
        this.game.create();
        this.game.game.init();
        this.game.game.renderNextFrame();
    }

    /** Destroys the game window, if one was started. */
    public void stopGame() {
        if (this.game == null) {
            return;
        }
        this.game.destroy();
        this.game = null;
    }

    /**
     * Plays {@code maxPlays} games. Each frame: read the reward (score delta / 10), push the frame
     * into the input history, pick an action (epsilon-greedy) and, when training, record the frame
     * and take one training step. When training, the network and settings are saved at game-over
     * when the finished-game counter {@code x} is a multiple of 25 (including before the first
     * increment, when {@code x == 0}).
     */
    @Override
    public void run() {
        this.model.resetState();
        this.startGame();
        Tensor inBuffer = new Tensor(FRAME_SIZE, FRAME_SIZE, HISTORY);
        // Rewards are measured relative to the score right after (re)initialisation.
        int previousScore = this.game.game.getScore();
        File[] existing = this.replaySavePath.listFiles();
        // Every replay frame is stored as a .dat/.png pair, so the next free index is files / 2.
        int fc = existing == null ? 0 : existing.length / 2;
        Graph g = new Graph(false);
        Random random = new Random();
        int cntr2 = 0;
        double nomLoss = 0.0;
        double denomLoss = 0.0;
        try {
            int x = 0;
            while (x < this.maxPlays) {
                if (this.game.game.isGameOver()) {
                    this.game.game.init();
                    this.model.resetState();
                    // Without this reset the first frame of the next game would see the old score
                    // and receive a large spurious negative reward.
                    previousScore = this.game.game.getScore();
                    if (this.training) {
                        this.epsilon = Math.pow(EPSILON_BASE, this.ts);
                        ++this.ts;
                        System.out.println("Epsilon: " + this.epsilon);
                        System.out.println("[DEBUG]: Current loss is: " + Double.toString(nomLoss / denomLoss));
                        nomLoss = 0.0;
                        denomLoss = 0.0;
                    }
                    if (x % 25 == 0 && this.training) {
                        FileIO.saveNeuralNetwork(this.savePath, this.model);
                        this.saveSettings();
                    }
                    ++x;
                    continue;
                }
                Thread.sleep(32L);
                int score = this.game.game.getScore();
                // Reward is the score gained since the previous frame, scaled down by 10.
                double reward = (score - previousScore) / 10.0;
                previousScore = score;
                BufferedImage frame = this.game.game.getCurrentFrame();
                // Shift the frame history left and append the newest frame in slot HISTORY - 1.
                for (int slot = 0; slot < HISTORY - 1; slot++) {
                    inBuffer.setMatrixAt(slot, inBuffer.getMatrixAt(slot + 1));
                }
                inBuffer.setMatrixAt(HISTORY - 1, this.grayscale(frame));
                Matrix outBuffer = this.model.forward(inBuffer, g).getMatrixAt(0);
                int action = this.evaluateOutput(outBuffer);
                if (random.nextDouble() < this.epsilon) {
                    action = random.nextInt(NUM_ACTIONS);
                }
                // 0 = no input, 1 = move right, 2 = move left.
                this.game.game.setRight(action == 1);
                this.game.game.setLeft(action == 2);
                if (this.training) {
                    ReplayFrame rf = new ReplayFrame(frame, this.game.game.getScore(), reward, action);
                    rf.save(this.replayFile(fc, ".dat"), this.replayFile(fc, ".png"));
                    ++fc;
                }
                // Skip training until enough replay frames exist to draw a window from.
                if (this.training && x > 1 && fc >= MIN_REPLAY_FRAMES) {
                    double loss = this.train(random, 0.01);
                    if (Double.isInfinite(loss) || Double.isNaN(loss)) {
                        System.err.println("[ERR]: Invalid loss of " + Double.toString(loss) + "!");
                        return;
                    }
                    nomLoss += loss;
                    denomLoss += 1.0;
                    // NOTE(review): presumably updateParameters() advances model.t (the optimizer
                    // step counter). Undo that so it only advances once every STEPS_PER_T steps.
                    --this.model.t;
                    if (++cntr2 >= STEPS_PER_T) {
                        ++this.model.t;
                        cntr2 = 0;
                    }
                }
                this.game.game.renderNextFrame();
            }
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                // Preserve the interrupt for whoever owns this thread.
                Thread.currentThread().interrupt();
            }
            System.err.println("Error while playing game: ");
            e.printStackTrace();
        }
    }

    /** Returns the index of the largest entry of {@code m.w} (the first one on ties; 0 if none is larger than -inf). */
    private int evaluateOutput(Matrix m) {
        double largest = Double.NEGATIVE_INFINITY;
        int largestIndex = 0;
        for (int i = 0; i < m.w.length; i++) {
            if (m.w[i] > largest) {
                largestIndex = i;
                largest = m.w[i];
            }
        }
        return largestIndex;
    }

    /**
     * Converts an image to a (height x width) matrix of gray values: the mean of R, G and B divided
     * by 256, giving values in [0, 255/256].
     */
    private Matrix grayscale(BufferedImage in) throws Exception {
        Matrix toReturn = new Matrix(in.getHeight(), in.getWidth());
        for (int col = 0; col < in.getWidth(); col++) {
            for (int row = 0; row < in.getHeight(); row++) {
                int argb = in.getRGB(col, row);
                int r = argb >> 16 & 0xFF;
                int g = argb >> 8 & 0xFF;
                int b = argb & 0xFF;
                double gray = (double) r / 3.0 + (double) g / 3.0 + (double) b / 3.0;
                toReturn.setW(row, col, gray / 256.0);
            }
        }
        return toReturn;
    }

    /** Returns the probability of replacing the network's action with a random one. */
    public double getEpsilon() {
        return this.epsilon;
    }

    /** Sets the probability of replacing the network's action with a random one. */
    public void setEpsilon(double epsilon) {
        this.epsilon = epsilon;
    }

    /** Returns whether frames are recorded and training steps are taken while playing. */
    public boolean isTraining() {
        return this.training;
    }

    /** Enables or disables recording of frames and training while playing. */
    public void setIsTraining(boolean training) {
        this.training = training;
    }

    /** Returns the number of games {@link #run()} plays before returning. */
    public int getMaxPlays() {
        return this.maxPlays;
    }

    /** Sets the number of games {@link #run()} plays before returning. */
    public void setMaxPlays(int maxPlays) {
        this.maxPlays = maxPlays;
    }

    /**
     * Performs one training step on a random window of recorded frames.
     *
     * <p>The window holds {@code HISTORY + LOOKAHEAD} consecutive frames. Its first {@code HISTORY}
     * frames form the network input. The target for the action taken in frame {@code HISTORY - 1}
     * is {@code sum_{k>=0} DISCOUNT^k * reward[HISTORY + k] / REWARD_SCALE}; all other outputs are
     * masked with {@code NEGATIVE_INFINITY} so the loss ignores them.
     *
     * @param random       source of randomness for choosing the window
     * @param learningRate learning rate passed to the optimizer
     * @return the loss measured before the update
     * @throws IllegalStateException if fewer than {@code MIN_REPLAY_FRAMES} frames are recorded
     * @throws Exception             if reading replay files or the network update fails
     */
    private double train(Random random, double learningRate) throws Exception {
        if (this.model.t == 0) {
            this.model.t = 1;
        }
        File[] replays = this.replaySavePath.listFiles();
        int available = replays == null ? 0 : replays.length / 2;
        if (available < MIN_REPLAY_FRAMES) {
            throw new IllegalStateException("Not enough replay frames to train on: " + available
                    + " (need " + MIN_REPLAY_FRAMES + ")");
        }
        // randomFile is the index of the newest input frame; the window starts HISTORY - 1 before it.
        int randomFile = random.nextInt(available - REPLAY_MARGIN) + HISTORY;
        int first = randomFile - (HISTORY - 1);
        ReplayFrame[] replayFrames = new ReplayFrame[LOOKAHEAD + HISTORY];
        for (int j = 0; j < replayFrames.length; j++) {
            int index = first + j;
            replayFrames[j] = ReplayFrame.read(this.replayFile(index, ".dat"), this.replayFile(index, ".png"));
        }
        Tensor netInput = new Tensor(FRAME_SIZE, FRAME_SIZE, HISTORY);
        for (int j = HISTORY - 1; j >= 0; j--) {
            netInput.setMatrixAt(j, this.grayscale(replayFrames[j].getFrame()));
        }
        int taken = replayFrames[HISTORY - 1].getAction();
        Matrix targetOutput = new Matrix(NUM_ACTIONS);
        targetOutput.w[taken] = replayFrames[HISTORY].getReward() / REWARD_SCALE;
        for (int j = HISTORY + 1; j < replayFrames.length; j++) {
            targetOutput.w[taken] += Math.pow(DISCOUNT, j - HISTORY) * (replayFrames[j].getReward() / REWARD_SCALE);
        }
        for (int a = 0; a < NUM_ACTIONS; a++) {
            if (a != taken) {
                targetOutput.w[a] = Double.NEGATIVE_INFINITY;
            }
        }
        Graph g = new Graph(true);
        Matrix out = this.model.forward(netInput, g).getMatrixAt(0);
        double loss = this.lossToUse.measure(out, targetOutput);
        this.lossToUse.backward(out, targetOutput);
        g.backward();
        this.method.updateParameters(this.model, learningRate, 1, null);
        g.cleanUp();
        return loss;
    }

    /**
     * Builds the network and runs the AI. Passing {@code play} as the first argument plays 3 games
     * greedily without training; otherwise it trains for 1024 games.
     */
    public static void main(String[] args) {
        try {
            UIManager.setLookAndFeel(UIManager.getSystemLookAndFeelClassName());
        } catch (Exception e) {
            e.printStackTrace();
        }
        Random random = new Random();
        ConvNet model = new ConvNet();
        model.addLayer(new ConvLayer(4, 84, 84, 4, 8, 8, 32, 0.08, random, true, 8));
        model.addLayer(new ConvNonlinLayer(new ReLuUnit()));
        model.addLayer(new ConvLayer(2, 20, 20, 32, 4, 4, 64, 0.08, random, true, 8));
        model.addLayer(new ConvNonlinLayer(new ReLuUnit()));
        model.addLayer(new ConvLayer(1, 9, 9, 64, 3, 3, 64, 0.08, random, true, 8));
        model.addLayer(new ConvNonlinLayer(new ReLuUnit()));
        model.addLayer(new ConvFlatten(7, 7, 64));
        ArrayList<Model> fcLayers = new ArrayList<Model>();
        fcLayers.add(new FeedForwardLayer(3136, 512, new LinearUnit(), 0.08, random));
        fcLayers.add(new FeedForwardLayer(512, 3, new LinearUnit(), 0.08, random));
        model.setFullyConnected(new NeuralNetwork(fcLayers));
        ResourceDisplay rd = null;
        try {
            rd = new ResourceDisplay(null, "Intel(R) Xeon(R) CPU E5-1620 v4", "DDR4 SDRAM PC4-17000, ECC, registered");
            rd.setLocation(650, 16);
            rd.showFrame();
        } catch (Exception e) {
            System.err.println("Error creating resource display: ");
            e.printStackTrace();
            rd = null;
        }
        BreakoutAI ai = null;
        try {
            ai = new BreakoutAI(model, "D:/BreakoutAI.ser", new Adam(0.9, 0.999, 1.0E-4));
        } catch (Exception e) {
            System.err.println("Error creating AI: ");
            e.printStackTrace();
            System.exit(1);
        }
        boolean playOnly = args.length > 0 && args[0].equalsIgnoreCase("play");
        if (playOnly) {
            ai.setEpsilon(-1.0);
            ai.setIsTraining(false);
            ai.setMaxPlays(3);
        } else {
            ai.setEpsilon(1.0);
            ai.setMaxPlays(1024);
        }
        Thread t = new Thread(ai);
        t.start();
        try {
            t.join();
        } catch (Exception e) {
            e.printStackTrace();
        }
        try {
            Thread.sleep(1000L);
            ai.stopGame();
        } catch (Exception e) {
            System.err.println("Error closing game: ");
            e.printStackTrace();
            System.exit(1);
        }
        if (rd != null) {
            rd.closeFrame();
        }
        System.exit(0);
    }

    /**
     * One recorded frame: the image plus its game score, reward and chosen action. It is persisted
     * as a binary {@code .dat} file (int score, double reward, int action, encoded with
     * {@code ByteConverters}) and a {@code .png} image.
     */
    private static final class ReplayFrame {
        private BufferedImage frame;
        private int gameScore;
        private double aiScore;
        private int action;

        private ReplayFrame(BufferedImage frame, int gameScore, double aiScore, int action) {
            this.frame = frame;
            this.gameScore = gameScore;
            this.aiScore = aiScore;
            this.action = action;
        }

        /** Loads a frame previously written by {@link #save(File, File)}. */
        private static ReplayFrame read(File data, File image) throws Exception {
            ReplayFrame rf = new ReplayFrame(null, 0, 0.0, 0);
            rf.load(data, image);
            return rf;
        }

        /** Writes the metadata to {@code data} and the image as PNG to {@code image}. */
        private void save(File data, File image) throws Exception {
            try (FileOutputStream fos = new FileOutputStream(data)) {
                fos.write(ByteConverters.intToBytes(this.gameScore));
                fos.write(ByteConverters.doubleToBytes(this.aiScore));
                fos.write(ByteConverters.intToBytes(this.action));
            }
            ImageIO.write(this.frame, "png", image);
        }

        /** Reads the metadata from {@code data} and the image from {@code image}. */
        private void load(File data, File image) throws Exception {
            try (DataInputStream in = new DataInputStream(new FileInputStream(data))) {
                byte[] intBuffer = new byte[4];
                byte[] doubleBuffer = new byte[8];
                // readFully fails on truncated files instead of returning garbage values.
                in.readFully(intBuffer);
                this.gameScore = ByteConverters.bytesToInt(intBuffer);
                in.readFully(doubleBuffer);
                this.aiScore = ByteConverters.bytesToDouble(doubleBuffer);
                in.readFully(intBuffer);
                this.action = ByteConverters.bytesToInt(intBuffer);
            }
            this.frame = ImageIO.read(image);
            if (this.frame == null) {
                // ImageIO.read returns null rather than throwing when it cannot decode the file.
                throw new IOException("Could not decode replay image: " + image.getPath());
            }
        }

        public BufferedImage getFrame() {
            return this.frame;
        }

        public double getReward() {
            return this.aiScore;
        }

        public int getAction() {
            return this.action;
        }
    }
}

