/*
 * Decompiled with CFR 0.152.
 */
package theGhastModding.lstmStuff.gameThingy;

import autodiff.GPUGraph;
import edu.cornell.lassp.houle.RngPack.BetterRandom;
import edu.cornell.lassp.houle.RngPack.RanMT;
import edu.cornell.lassp.houle.RngPack.Ranecu;
import examples.deepRL.AggregationLayer;
import examples.deepRL.DeepRLAgent;
import examples.deepRL.DeepRLExample;
import examples.deepRL.NeuralNetworkDisplay;
import java.awt.Robot;
import java.awt.event.KeyEvent;
import java.awt.image.BufferedImage;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.FileWriter;
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadPoolExecutor;
import loss.Loss;
import loss.LossSumOfSquares;
import matrix.Matrix;
import matrix.Tensor;
import model.ConvDense;
import model.ConvFlatten;
import model.ConvLayer;
import model.FeedForwardLayer;
import model.NeuralNetwork;
import net.dv8tion.jda.core.entities.TextChannel;
import nonlinearities.ExponentialLinearUnit;
import nonlinearities.LinearUnit;
import theGhastModding.botLogger.main.TheGhastBotLogger;
import theGhastModding.lstmStuff.gameThingy.SuperHexagonWrapper;
import trainer.Adam;
import trainer.Optimizer;
import util.CLUtils;
import util.NNDevice;

/**
 * Plays Super Hexagon with a {@link DeepRLAgent}.
 *
 * <p>Each frame, the game screen is captured through {@link SuperHexagonWrapper}, downsampled to
 * grayscale and pushed onto a stack of the four most recent frames. The agent uses the stack to pick
 * one of three actions: 0 = left arrow, 1 = right arrow, 2 = no key. Keys are sent with
 * {@link Robot}.
 *
 * <p>Rewards are handed to the agent on a background thread:
 * <ul>
 *   <li>-1 (terminal) once the wrapper reports the game is no longer in a stage,
 *   <li>+1 when the frame counter is a multiple of {@value #REWARD_INTERVAL_FRAMES},
 *   <li>0 otherwise.
 * </ul>
 * After every game the agent is trained on the collected experience and a row is appended to
 * {@code superhexagon.csv}.
 *
 * <p>NOTE(review): {@code agent.think} runs on the game-loop thread while the previous frame's
 * {@code agent.feedback} may still be running on the feedback thread. Confirm that
 * {@link DeepRLAgent} tolerates this overlap.
 */
public class SuperHexagonAI {
    /** Number of actions the agent chooses from (0 = left, 1 = right, 2 = no key). */
    private static final int ACTIONS = 3;
    /** Discount factor passed to the agent. */
    private static final double DISCOUNT = 0.75;
    /** Exploration rate the agent is constructed with. */
    private static final double INITIAL_EXPLORE = 0.1;
    /** Lower bound of the per-game exploration schedule. */
    private static final double MIN_EXPLORE = 0.05;
    /** Target frame period (about 15 frames per second). */
    private static final long FRAME_TIME_NANOS = 66_666_666L;
    /** Frames that take at least this long are reported as lag and get no feedback. */
    private static final long LAG_THRESHOLD_MILLIS = 250L;
    /** A +1 reward is given whenever the frame counter is a multiple of this. */
    private static final int REWARD_INTERVAL_FRAMES = 60;
    /** Training after each game is split into this many equally sized calls. */
    private static final int TRAIN_CHUNKS = 4;
    /** Side length of one downsampled frame fed to the network. */
    private static final int FRAME_SIZE = 240;
    /** Number of stacked frames fed to the network. */
    private static final int FRAME_STACK = 4;
    /** Window length of the moving average written by {@link #movingAverageCalculator}. */
    private static final int MOVING_AVERAGE_WINDOW = 32;
    private static final String CSV_FILE_NAME = "superhexagon.csv";
    private static final String SAVE_PATH = "D:/models/superhexagon_ai";
    private static final String MODEL_SUFFIX = "_model.dat";
    private static final String SETTINGS_SUFFIX = "_a.dat";

    private final DeepRLAgent agent;
    private final Robot robot;
    private double explore = INITIAL_EXPLORE;
    private int trainEpochs;
    private int totalGames = 0;
    private boolean isGameStarted;
    private SuperHexagonWrapper wrapper;
    // The next four fields are shared with the feedback thread. Visibility comes from
    // ExecutorService.submit (main -> worker) and Future.get (worker -> main) in playGame.
    private Exception feedbackError = null;
    private boolean gameOver = false;
    private int prevAction;
    private Tensor prevState;
    private int frameCntr;
    private static NeuralNetworkDisplay dp;
    private static NeuralNetworkDisplay dp2;
    private static NeuralNetworkDisplay dp3;
    private static TheGhastBotLogger botLogger;

    /**
     * Creates the agent and the {@link Robot} used for key input.
     *
     * @param model        network the agent uses
     * @param trainEpochs  stored and persisted by save/load; not otherwise used in this class
     * @param learningRate learning rate passed to the agent
     * @param optimizer    optimizer passed to the agent
     * @param lossToUse    loss function passed to the agent
     * @param bufferSize   buffer size passed to the agent
     * @param devices      compute devices passed to the agent
     * @throws Exception if the agent or the Robot cannot be created
     */
    public SuperHexagonAI(NeuralNetwork model, int trainEpochs, double learningRate, Optimizer optimizer, Loss lossToUse, int bufferSize, NNDevice ... devices) throws Exception {
        this.trainEpochs = trainEpochs;
        this.agent = new DeepRLAgent(model, ACTIONS, learningRate, optimizer, lossToUse, DISCOUNT, INITIAL_EXPLORE, bufferSize, new BetterRandom(new Ranecu(System.currentTimeMillis() + System.nanoTime())), devices);
        this.isGameStarted = false;
        this.robot = new Robot();
        this.agent.setAlpha(0.75);
        this.agent.setTau(0.002);
    }

    /** Returns {@code base} with {@code suffix} appended to its path. */
    private static File companionFile(File base, String suffix) {
        return new File(base.getPath() + suffix);
    }

    /**
     * Saves the agent and the training settings.
     *
     * <p>The agent goes to {@code <f>_model.dat}. {@code trainEpochs}, {@code explore} and
     * {@code totalGames} go to {@code <f>_a.dat}.
     *
     * @param f base path; the suffixes above are appended to it
     * @throws Exception if writing fails
     */
    public void save(File f) throws Exception {
        this.agent.save(companionFile(f, MODEL_SUFFIX));
        try (DataOutputStream dos = new DataOutputStream(new FileOutputStream(companionFile(f, SETTINGS_SUFFIX)))) {
            dos.writeInt(this.trainEpochs);
            dos.writeDouble(this.explore);
            dos.writeInt(this.totalGames);
        }
    }

    /**
     * Loads what {@link #save(File)} wrote.
     *
     * <p>If either file is missing, an error is printed and nothing is loaded. The settings fields
     * are only assigned after all three values have been read.
     *
     * @param f base path used when saving
     * @throws Exception if reading fails
     */
    public void load(File f) throws Exception {
        File modelFile = companionFile(f, MODEL_SUFFIX);
        File settingsFile = companionFile(f, SETTINGS_SUFFIX);
        if (!modelFile.exists() || !settingsFile.exists()) {
            System.err.println("Savefile incomplete/missing!");
            return;
        }
        this.agent.load(modelFile);
        try (DataInputStream dis = new DataInputStream(new FileInputStream(settingsFile))) {
            int loadedEpochs = dis.readInt();
            double loadedExplore = dis.readDouble();
            int loadedGames = dis.readInt();
            this.trainEpochs = loadedEpochs;
            this.explore = loadedExplore;
            this.totalGames = loadedGames;
        }
    }

    /**
     * Plays {@code games} consecutive games and trains the agent after each one.
     *
     * <p>For every game with more than 8 frames, a row {@code loss,seconds,explorationRate,} is
     * appended to {@code superhexagon.csv}. Numbers are written locale-independently so the commas
     * stay unambiguous. Any exception ends the whole run and is printed to {@code System.err}.
     *
     * @param games    number of games to play
     * @param savefile base path passed to {@link #save(File)} when autosaving
     * @param autosave if true, the AI is saved after every other game (game indices 0, 2, 4, ...)
     */
    public void run(int games, File savefile, boolean autosave) {
        // Single worker: at most one feedback task is outstanding at a time (see playGame).
        ExecutorService feedbackExecutor = Executors.newSingleThreadExecutor();
        try (BufferedWriter csvWriter = new BufferedWriter(new FileWriter(new File(CSV_FILE_NAME), true))) {
            Ranecu ran = new Ranecu(System.currentTimeMillis());
            for (int i = 0; i < games; ++i) {
                logToAll("Game " + (i + 1) + " of " + games);
                Tensor inBuffer = new Tensor(FRAME_SIZE, FRAME_SIZE, FRAME_STACK);
                Thread.sleep(5000L);
                this.startGame();
                Thread.sleep(250L);
                this.frameCntr = 0;
                ++this.totalGames;
                this.updateExplorationRate(ran);
                System.out.println("Maximum Exploration rate: " + Double.toString(this.explore));

                long startTime = System.currentTimeMillis();
                System.out.println("Now playing a game.");
                this.playGame(inBuffer, feedbackExecutor);
                long endTime = System.currentTimeMillis() - startTime;
                double fps = 1.0 / ((double) endTime / (double) this.frameCntr / 1000.0);
                logToAll(Double.toString(fps) + " FPS");

                logToAll("Now training for " + this.frameCntr + " iterations...");
                double loss = this.trainOnCollectedFrames();
                logToAll("Training done, final loss is " + Double.toString(loss) + ".");

                if (this.frameCntr > 8) {
                    // Locale.ROOT: a decimal comma would break the CSV and movingAverageCalculator.
                    csvWriter.write(String.format(Locale.ROOT, "%.8f", loss) + "," + Double.toString((double) endTime / 1000.0) + "," + Double.toString(this.agent.getExplorationRate()) + ",");
                    csvWriter.newLine();
                    csvWriter.flush();
                }
                if (autosave && i % 2 == 0) {
                    this.save(savefile);
                }
            }
            System.out.println("All games completed successfully!");
        } catch (Exception e) {
            System.err.println("Error running game: ");
            e.printStackTrace();
        } finally {
            feedbackExecutor.shutdown();
        }
    }

    /**
     * Computes this game's exploration rate and hands it to the agent.
     *
     * <p>The rate is |sin(totalGames / 4)| * 0.9991^totalGames, clamped to at least
     * {@value #MIN_EXPLORE}. When it sits exactly at that floor, a 25% coin flip sets it to 0.
     */
    private void updateExplorationRate(Ranecu ran) throws Exception {
        double rate = Math.abs(Math.sin((double) this.totalGames / 4.0)) * Math.pow(0.9991, this.totalGames);
        if (rate < MIN_EXPLORE) {
            rate = MIN_EXPLORE;
        }
        if (rate == MIN_EXPLORE && ran.coin(0.25)) {
            rate = 0.0;
        }
        this.explore = rate;
        this.agent.setExplorationRate(this.explore);
    }

    /**
     * Runs the frame loop until the feedback thread reports that the game has ended.
     *
     * <p>Each frame: wait for the next frame slot, capture and push a grayscale frame, ask the agent
     * for an action, wait for the previous frame's feedback task, then press the key. Unless the
     * frame took {@value #LAG_THRESHOLD_MILLIS} ms or longer, a feedback task for this frame is
     * queued. Arrow keys are released and parameter pinning is undone even if the loop fails.
     *
     * @throws Exception if a frame step fails, or the exception captured by the feedback thread
     */
    private void playGame(Tensor inBuffer, ExecutorService feedbackExecutor) throws Exception {
        this.feedbackError = null;
        this.gameOver = false;
        Runnable feedbackTask = new Runnable() {
            @Override
            public void run() {
                SuperHexagonAI.this.sendFeedback();
            }
        };
        this.keepParameters(true);
        try {
            Future<?> pendingFeedback = null;
            long lastFrametime = System.nanoTime();
            while (true) {
                // Busy-wait until the next frame slot.
                while (System.nanoTime() - lastFrametime < FRAME_TIME_NANOS) {
                    // spin
                }
                lastFrametime = System.nanoTime();
                long frameStartMillis = System.currentTimeMillis();
                pushFrame(inBuffer, grayscale(this.wrapper.getGameScreen()));
                int act = this.agent.think(inBuffer);
                // Wait for the previous frame's feedback before overwriting prevState/prevAction.
                // Future.get also makes the task's writes (gameOver, feedbackError) visible here.
                if (pendingFeedback != null) {
                    pendingFeedback.get();
                }
                if (this.feedbackError != null) {
                    throw this.feedbackError;
                }
                this.prevAction = act;
                this.prevState = inBuffer.clone();
                this.performAction(act);
                ++this.frameCntr;
                if (this.gameOver) {
                    break;
                }
                long frameMillis = System.currentTimeMillis() - frameStartMillis;
                if (frameMillis < LAG_THRESHOLD_MILLIS) {
                    pendingFeedback = feedbackExecutor.submit(feedbackTask);
                } else {
                    System.err.println("LAG!!");
                    this.robot.keyRelease(KeyEvent.VK_RIGHT);
                    this.robot.keyRelease(KeyEvent.VK_LEFT);
                }
            }
        } finally {
            this.keepParameters(false);
            this.robot.keyRelease(KeyEvent.VK_LEFT);
            this.robot.keyRelease(KeyEvent.VK_RIGHT);
        }
        logToAll("Game Over.");
    }

    /** Shifts the frame stack by one slot (dropping the oldest) and puts {@code newest} in slot 0. */
    private static void pushFrame(Tensor stack, Matrix newest) {
        for (int k = FRAME_STACK - 1; k > 0; --k) {
            stack.matrices[k] = stack.matrices[k - 1];
        }
        stack.matrices[0] = newest;
    }

    /**
     * Feedback-thread body.
     *
     * <p>Waits 17 ms and releases both arrow keys. It then reports the last (state, action) pair to
     * the agent: -1 and terminal if the game left the stage (and flags game over), otherwise +1 on
     * reward frames and 0 on all others. Any exception is stored for the game loop to rethrow.
     */
    private void sendFeedback() {
        try {
            Thread.sleep(17L);
            this.robot.keyRelease(KeyEvent.VK_LEFT);
            this.robot.keyRelease(KeyEvent.VK_RIGHT);
            if (!this.wrapper.isInStage()) {
                this.agent.feedback(this.prevState, this.prevAction, -1.0, true);
                this.gameOver = true;
            } else {
                double reward = this.frameCntr % REWARD_INTERVAL_FRAMES == 0 ? 1.0 : 0.0;
                this.agent.feedback(this.prevState, this.prevAction, reward, false);
            }
        } catch (Exception e) {
            this.feedbackError = e;
        }
    }

    /** Presses the key for {@code act}; the feedback task releases it again. */
    private void performAction(int act) throws InterruptedException {
        switch (act) {
            case 0:
                this.robot.keyPress(KeyEvent.VK_LEFT);
                Thread.sleep(17L);
                break;
            case 1:
                this.robot.keyPress(KeyEvent.VK_RIGHT);
                Thread.sleep(17L);
                break;
            case 2:
                Thread.sleep(8L);
                break;
            default:
                break;
        }
    }

    /**
     * Trains the agent in {@value #TRAIN_CHUNKS} calls of {@code frameCntr / TRAIN_CHUNKS}
     * iterations each. The network displays are frozen while training, and unfrozen even on failure.
     *
     * @return the mean of the losses returned by the individual {@code agent.train} calls
     */
    private double trainOnCollectedFrames() throws Exception {
        setDisplaysFrozen(true);
        try {
            System.gc();
            double totalLoss = 0.0;
            for (int j = 0; j < TRAIN_CHUNKS; ++j) {
                System.out.println(Double.toString((double) j / TRAIN_CHUNKS * 100.0) + "% done...");
                this.keepParameters(true);
                totalLoss += this.agent.train(this.frameCntr / TRAIN_CHUNKS);
                this.keepParameters(false);
            }
            System.out.println("100% done...");
            System.gc();
            return totalLoss / TRAIN_CHUNKS;
        } finally {
            setDisplaysFrozen(false);
        }
    }

    /** Calls {@code forceKeepParameters(keep)} on the agent's graph, which is cast to {@link GPUGraph}. */
    private void keepParameters(boolean keep) throws Exception {
        ((GPUGraph) this.agent.graph).forceKeepParameters(keep);
    }

    /** Freezes or unfreezes whichever network displays have been created. */
    private static void setDisplaysFrozen(boolean frozen) throws Exception {
        NeuralNetworkDisplay[] displays = {dp, dp2, dp3};
        for (NeuralNetworkDisplay display : displays) {
            if (display != null) {
                display.setFrozen(frozen);
            }
        }
    }

    /** Prints {@code message} to stdout and, if a bot logger is active, logs it there too. */
    private static void logToAll(String message) throws Exception {
        System.out.println(message);
        if (botLogger != null) {
            botLogger.log(message);
        }
    }

    /** Presses {@code keyCode}, holds it for 60 ms, then releases it. */
    private void tapKey(int keyCode) throws InterruptedException {
        this.robot.keyPress(keyCode);
        Thread.sleep(60L);
        this.robot.keyRelease(keyCode);
    }

    /**
     * Launches the game through {@code ./AI.bat.lnk} and navigates into a stage.
     *
     * <p>Taps the right arrow, then Enter, then resets the wrapper.
     */
    public void openGame() throws Exception {
        this.wrapper = new SuperHexagonWrapper(new File("./AI.bat.lnk"));
        this.wrapper.start();
        Thread.sleep(500L);
        this.tapKey(KeyEvent.VK_RIGHT);
        Thread.sleep(500L);
        this.tapKey(KeyEvent.VK_ENTER);
        Thread.sleep(1000L);
        this.isGameStarted = true;
        this.wrapper.reset();
    }

    /** Opens the game if it is not running; otherwise restarts the stage with the R key. */
    public void startGame() throws Exception {
        if (!this.isGameStarted) {
            this.openGame();
            return;
        }
        this.tapKey(KeyEvent.VK_R);
        Thread.sleep(1000L);
        this.wrapper.reset();
    }

    /**
     * Stops the wrapper and sends Alt+F4.
     *
     * <p>Errors are printed, not thrown. The game is marked as not started in either case.
     */
    public void closeGame() {
        try {
            this.wrapper.stop();
            this.robot.keyPress(KeyEvent.VK_ALT);
            this.robot.keyPress(KeyEvent.VK_F4);
            Thread.sleep(100L);
            this.robot.keyRelease(KeyEvent.VK_ALT);
            this.robot.keyRelease(KeyEvent.VK_F4);
            Thread.sleep(2500L);
        } catch (Exception e) {
            e.printStackTrace();
        }
        this.isGameStarted = false;
    }

    /**
     * Builds the dueling-style network.
     *
     * <p>A convolutional trunk (with viewer layers showing into dp/dp2/dp3) ends in a 768-unit dense
     * layer. Its output is combined with a 3-output head and a 1-output head by an
     * {@link AggregationLayer}. The layer order fixes the order in which {@code rng} is consumed.
     */
    private static NeuralNetwork buildNetwork(BetterRandom rng) throws Exception {
        NeuralNetwork network = new NeuralNetwork();
        network.addLayer(new ConvLayer(240, 240, 4, 8, 8, 32, 4, 2, new ExponentialLinearUnit(1.0), 0.01, rng, true, true, 4));
        network.addLayer(new DeepRLExample.ViewerLayer(dp, false, 0, 1, 2));
        network.addLayer(new ConvLayer(60, 60, 32, 8, 8, 72, 4, 0, new ExponentialLinearUnit(1.0), 0.01, rng, true, true, 4));
        network.addLayer(new DeepRLExample.ViewerLayer(dp2, false, 0, 1, 2));
        network.addLayer(new ConvLayer(14, 14, 72, 4, 4, 128, 2, 1, new ExponentialLinearUnit(1.0), 0.01, rng, true, true, 4));
        network.addLayer(new DeepRLExample.ViewerLayer(dp3, false, 0, 1, 2));
        network.addLayer(new ConvLayer(7, 7, 128, 3, 3, 96, 1, 0, new ExponentialLinearUnit(1.0), 0.0126, rng, true, true, 4));
        network.addLayer(new ConvFlatten(5, 5, 96));
        network.addLayer(new ConvDense(new FeedForwardLayer(2400, 768, new ExponentialLinearUnit(1.0), Math.sqrt(4.166666666666667E-4), rng)));
        NeuralNetwork aNet = new NeuralNetwork();
        aNet.addLayer(new ConvDense(new FeedForwardLayer(768, 3, new LinearUnit(), Math.sqrt(0.001953125), rng)));
        NeuralNetwork vNet = new NeuralNetwork();
        vNet.addLayer(new ConvDense(new FeedForwardLayer(768, 1, new LinearUnit(), Math.sqrt(0.001953125), rng)));
        NeuralNetwork finalNet = new NeuralNetwork();
        finalNet.addLayer(new AggregationLayer(network, aNet, vNet));
        return finalNet;
    }

    /**
     * Entry point.
     *
     * <p>Shows the network displays, builds the network and loads {@code D:/models/superhexagon_ai}.
     * It then plays batches of 32 games (saving after each batch) until two batches have started
     * during the 14:00 local hour. Afterwards it closes the game, saves again and post-processes
     * {@code superhexagon.csv}.
     *
     * @param args if {@code args[0]} equals "true" (case-insensitive), the bot logger is started
     */
    public static void main(String[] args) {
        try {
            dp = new NeuralNetworkDisplay(60, 60, 4);
            dp2 = new NeuralNetworkDisplay(14, 14, 17);
            dp3 = new NeuralNetworkDisplay(7, 7, 34);
            dp.setTitle(0, 0, 1, 2);
            dp.setLocation(700, 450);
            dp2.setTitle(1, 0, 1, 2);
            dp2.setLocation(1100, 450);
            dp3.setTitle(2, 0, 1, 2);
            dp3.setLocation(1500, 450);
            NeuralNetwork finalNet = buildNetwork(new BetterRandom(new RanMT()));
            dp.setVisible(true);
            dp2.setVisible(true);
            dp3.setVisible(true);
            if (args.length != 0 && args[0].equalsIgnoreCase("true")) {
                SuperHexagonAI.setUpLogger(true, "TheGhastBotServer", "bot-test-channel");
            }
            NNDevice[] devices = CLUtils.findDevice("AMD", "Tahiti");
            SuperHexagonAI ai = new SuperHexagonAI(finalNet, 3, 1.0E-4, new Adam(0.9, 0.999, 1.0E-8), new LossSumOfSquares(), 2048, devices);
            File savefile = new File(SAVE_PATH);
            ai.load(savefile);
            int actualIters = 0;
            int i = 0;
            while (i < 2) {
                // Outside the 14:00 hour the decrement cancels the ++i below, so only batches
                // started during that hour count towards the two required.
                if (LocalDateTime.now().getHour() != 14) {
                    --i;
                }
                System.err.println(Integer.toString(++actualIters + 1));
                if (botLogger != null) {
                    botLogger.log("Iteration " + Integer.toString(i + 1) + " of 32");
                }
                ai.run(32, savefile, false);
                ai.save(savefile);
                ++i;
            }
            ai.closeGame();
            ai.save(savefile);
            dp.setVisible(false);
            dp2.setVisible(false);
            dp3.setVisible(false);
            SuperHexagonAI.movingAverageCalculator(new File(CSV_FILE_NAME), 1, 3);
            SuperHexagonAI.stopLogger();
        } catch (Exception e) {
            System.err.println("Error: ");
            e.printStackTrace();
            System.exit(1);
        }
        System.exit(0);
    }

    /**
     * Rewrites {@code csvFile} in place with a {@value #MOVING_AVERAGE_WINDOW}-sample moving average
     * of column {@code valuesCol}.
     *
     * <p>Behaviour:
     * <ul>
     *   <li>The first line is treated as a header and gets {@code ",moving average"} appended.
     *   <li>Empty lines are copied unchanged.
     *   <li>In every other row, the average replaces column {@code outputCol} if it exists;
     *       otherwise it is appended. Each cell is written followed by a comma.
     *   <li>Until the window is full, its unfilled slots count as 0.
     *   <li>An empty file is left untouched.
     * </ul>
     * Errors are printed and leave the file unmodified.
     */
    private static void movingAverageCalculator(File csvFile, int valuesCol, int outputCol) {
        try {
            List<String> finalTable = new ArrayList<>();
            try (BufferedReader br = new BufferedReader(new FileReader(csvFile))) {
                String header = br.readLine();
                if (header == null) {
                    return;
                }
                finalTable.add(header + ",moving average");
                double[] window = new double[MOVING_AVERAGE_WINDOW];
                int pos = 0;
                String line;
                while ((line = br.readLine()) != null) {
                    if (line.isEmpty()) {
                        finalTable.add(line);
                        continue;
                    }
                    String[] elements = line.split(",");
                    window[pos] = Double.parseDouble(elements[valuesCol]);
                    pos = (pos + 1) % window.length;
                    double sum = 0.0;
                    for (double value : window) {
                        sum += value;
                    }
                    String average = Double.toString(sum / (double) window.length);
                    if (elements.length >= outputCol + 1) {
                        elements[outputCol] = average;
                    }
                    StringBuilder row = new StringBuilder();
                    for (String element : elements) {
                        row.append(element).append(',');
                    }
                    if (elements.length < outputCol + 1) {
                        row.append(average).append(',');
                    }
                    finalTable.add(row.toString());
                }
            }
            try (BufferedWriter bw = new BufferedWriter(new FileWriter(csvFile))) {
                for (String row : finalTable) {
                    bw.write(row);
                    bw.newLine();
                }
            }
        } catch (Exception e) {
            System.err.println("aaaaaaaaaaaaa: ");
            e.printStackTrace();
        }
    }

    /**
     * Starts the bot logger and directs its output to {@code channelName} on {@code serverName}.
     *
     * <p>Matching is case-insensitive. Does nothing if a logger is already running. If the logger
     * fails to start, the error is printed and no logger is set. If the channel is not found, the
     * logger is stopped again.
     */
    public static void setUpLogger(boolean doWait, String serverName, String channelName) {
        if (botLogger != null) {
            return;
        }
        try {
            botLogger = TheGhastBotLogger.start(doWait);
        } catch (Exception e) {
            e.printStackTrace();
            botLogger = null;
            return;
        }
        if (botLogger == null) {
            return;
        }
        TextChannel target = null;
        for (TextChannel tc : botLogger.getAvailableChannels()) {
            if (tc.getName().equalsIgnoreCase(channelName) && tc.getManager().getGuild().getName().equalsIgnoreCase(serverName)) {
                target = tc;
                break;
            }
        }
        if (target == null) {
            System.err.println("[ERROR] Channel and/or Server for logger not found");
            SuperHexagonAI.stopLogger();
            return;
        }
        botLogger.setLogOutputChannel(target);
    }

    /** Stops the bot logger if one is running, waits 1 s, and clears the field. Errors are printed. */
    public static void stopLogger() {
        if (botLogger == null) {
            return;
        }
        try {
            botLogger.stop();
            Thread.sleep(1000L);
        } catch (Exception e) {
            e.printStackTrace();
        }
        botLogger = null;
    }

    /**
     * Downsamples {@code in} by 2 in each dimension and converts it to grayscale.
     *
     * <p>The pixel at every even (x, y) is kept. Each output value is (R + G + B) / 765, i.e. in
     * [0, 1]; alpha is ignored. The output is (height / 2) x (width / 2), filled row by row. Odd
     * trailing rows and columns are dropped, so non-square and odd-sized images are handled
     * correctly.
     *
     * <p>NOTE(review): assumes {@code Matrix(rows, cols).w} holds rows * cols values in row-major
     * order. Confirm against matrix.Matrix.
     */
    private static Matrix grayscale(BufferedImage in) {
        int width = in.getWidth();
        int height = in.getHeight();
        int outRows = height / 2;
        int outCols = width / 2;
        Matrix res = new Matrix(outRows, outCols);
        int[] rgbArray = new int[width * height];
        // getRGB fills the array row-major: index = y * width + x.
        in.getRGB(0, 0, width, height, rgbArray, 0, width);
        int cntr = 0;
        for (int y = 0; y < outRows * 2; y += 2) {
            int rowOffset = y * width;
            for (int x = 0; x < outCols * 2; x += 2) {
                int rgb = rgbArray[rowOffset + x];
                res.w[cntr++] = ((float) (rgb & 0xFF) + (float) (rgb >> 8 & 0xFF) + (float) (rgb >> 16 & 0xFF)) / 765.0f;
            }
        }
        return res;
    }
}

