/*
 * Decompiled with CFR 0.152.
 */
package theGhastModding.lstmStuff.gameThingy;

import autodiff.GPUGraph;
import edu.cornell.lassp.houle.RngPack.BetterRandom;
import edu.cornell.lassp.houle.RngPack.RanMT;
import edu.cornell.lassp.houle.RngPack.Ranecu;
import examples.deepRL.AggregationLayer;
import examples.deepRL.DeepRLAgent;
import examples.deepRL.DeepRLExample;
import examples.deepRL.NeuralNetworkDisplay;
import java.awt.Robot;
import java.awt.event.KeyEvent;
import java.awt.image.BufferedImage;
import java.io.BufferedWriter;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.time.LocalDateTime;
import java.util.Locale;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadPoolExecutor;
import loss.Loss;
import loss.LossSumOfSquares;
import matrix.Matrix;
import matrix.Tensor;
import model.ConvDense;
import model.ConvFlatten;
import model.ConvLayer;
import model.Dropout;
import model.FeedForwardLayer;
import model.NeuralNetwork;
import model.NormalizeLayer;
import nonlinearities.ExponentialLinearUnit;
import nonlinearities.LinearUnit;
import theGhastModding.lstmStuff.gameThingy.Touhou10Wrapper;
import trainer.Adam;
import trainer.Optimizer;
import util.CLUtils;
import util.NNDevice;

/**
 * Drives a {@link DeepRLAgent} against a running game through a {@link Touhou10Wrapper}:
 * frames are captured, down-sampled to grayscale and stacked four deep, the agent picks
 * one of {@value #ACTIONS} actions, and the action is turned into synthetic key presses
 * with {@link Robot}. After every game the agent is trained on what it recorded.
 *
 * <p>Threading: {@link #run} plays on the calling thread and hands the per-frame reward
 * bookkeeping to a single background task at a time. The task is always awaited through
 * its {@link Future} before the shared fields are touched again, so those fields need no
 * further synchronisation. The class is not otherwise thread-safe.
 */
public class TouhouAI {
    /** Number of discrete actions handed to the agent; see {@link #pressKeysFor}. */
    private static final int ACTIONS = 10;
    /** Value passed to the agent constructor (the original field was named "discount"). */
    private static final double DISCOUNT = 0.75;
    /** Minimum time between two processed frames (50 ms). */
    private static final long FRAME_TIME_NANOS = 50000000L;
    /** A frame whose processing took at least this long is treated as lag and gets no feedback. */
    private static final long LAG_THRESHOLD_MILLIS = 250L;
    /** A game is abandoned after this much wall-clock time. */
    private static final long MAX_GAME_MILLIS = 140000L;
    /** The post-game training is split into this many equally sized rounds. */
    private static final int TRAIN_ROUNDS = 4;
    /** How long a menu key is held down by {@link #tapKey}. */
    private static final long KEY_HOLD_MILLIS = 20L;
    /** Base path of the model/settings files used by {@link #main}. */
    private static final String MODEL_PATH = "D:/models/touhou_ai";
    /** Per-game statistics are appended to this file by {@link #run}. */
    private static final String CSV_PATH = "D:/models/touhou.csv";

    private DeepRLAgent agent;
    private double explore = 0.1;
    private int trainEpochs;
    private int totalGames = 0;
    /** True once the menu navigation of {@link #openGame} has completed. */
    private boolean isGameStarted;
    /** True from the moment the game process was launched until {@link #closeGame} ran. */
    private boolean gameProcessRunning = false;
    private Touhou10Wrapper wrapper;
    private Robot robot;
    private boolean gameOver = false;
    private int previousLives;
    private double previousScore;
    private int prevAction;
    private Tensor prevState;
    /** Exception raised inside the background feedback task, rethrown by the game loop. */
    private Exception testE = null;
    private static NeuralNetworkDisplay dp;
    private static NeuralNetworkDisplay dp2;
    private static NeuralNetworkDisplay dp3;

    /**
     * Creates the agent around {@code model} and the {@link Robot} used for key input.
     *
     * @param model        network evaluated and trained by the agent
     * @param learningRate learning rate forwarded to the agent
     * @param optimizer    optimizer forwarded to the agent
     * @param lossToUse    loss forwarded to the agent
     * @param bufferSize   buffer size forwarded to the agent
     * @param devices      devices forwarded to the agent
     * @throws Exception if the agent or the {@link Robot} cannot be created
     */
    public TouhouAI(NeuralNetwork model, double learningRate, Optimizer optimizer, Loss lossToUse, int bufferSize, NNDevice ... devices) throws Exception {
        this.agent = new DeepRLAgent(model, ACTIONS, learningRate, optimizer, lossToUse, DISCOUNT, this.explore, bufferSize, new BetterRandom(new Ranecu(System.currentTimeMillis() + System.nanoTime())), devices);
        this.isGameStarted = false;
        this.robot = new Robot();
        this.agent.setAlpha(0.75);
        this.agent.setTau(0.002);
    }

    /**
     * Saves the agent to {@code <f>_model.dat} and the counters of this class
     * (train epochs, exploration rate, total games) to {@code <f>_a.dat}.
     *
     * @param f base path; the two suffixes are appended to {@code f.getPath()}
     * @throws Exception if the agent or the settings file cannot be written
     */
    public void save(File f) throws Exception {
        File modelFile = new File(f.getPath() + "_model.dat");
        File settingsFile = new File(f.getPath() + "_a.dat");
        this.agent.save(modelFile);
        // try-with-resources: the stream is closed even when a write fails.
        try (DataOutputStream dos = new DataOutputStream(new FileOutputStream(settingsFile))) {
            dos.writeInt(this.trainEpochs);
            dos.writeDouble(this.explore);
            dos.writeInt(this.totalGames);
        }
    }

    /**
     * Restores what {@link #save} wrote. If either file is missing a message is printed
     * to {@code System.err} and nothing is loaded.
     *
     * @param f base path, as given to {@link #save}
     * @throws Exception if the agent or the settings file cannot be read
     */
    public void load(File f) throws Exception {
        File modelFile = new File(f.getPath() + "_model.dat");
        File settingsFile = new File(f.getPath() + "_a.dat");
        if (!modelFile.exists() || !settingsFile.exists()) {
            System.err.println("Savefile incomplete/missing!");
            return;
        }
        this.agent.load(modelFile);
        // try-with-resources: the stream is closed even when a read fails.
        try (DataInputStream dis = new DataInputStream(new FileInputStream(settingsFile))) {
            this.trainEpochs = dis.readInt();
            this.explore = dis.readDouble();
            this.totalGames = dis.readInt();
        }
    }

    /**
     * Plays {@code games} games, training the agent after each one and appending one
     * statistics row per game (mean loss, duration in seconds, exploration rate) to the
     * CSV file. Any exception ends the run; it is reported on {@code System.err} and not
     * rethrown.
     *
     * @param games    number of games to play
     * @param savefile base path handed to {@link #save} when autosaving
     * @param autosave if true, the state is saved after every second game (index 0, 2, ...)
     */
    public void run(int games, File savefile, boolean autosave) {
        // One worker pool for the whole run; it is shut down in the finally block so its
        // thread does not outlive the run.
        ThreadPoolExecutor feedbackPool = (ThreadPoolExecutor) Executors.newCachedThreadPool();
        try (BufferedWriter csvWriter = new BufferedWriter(new FileWriter(new File(CSV_PATH), true))) {
            Ranecu ran = new Ranecu(System.currentTimeMillis());
            for (int game = 0; game < games; ++game) {
                System.out.println("Game " + Integer.toString(game + 1) + " of " + Integer.toString(games));
                Thread.sleep(2500L);
                this.startGame();
                Thread.sleep(250L);

                ++this.totalGames;
                // Oscillating exploration rate that decays with the number of games,
                // floored at 0.05; at the floor, one game in four (on average) runs greedy.
                this.explore = Math.abs(Math.sin((double) this.totalGames / 4.0)) * Math.pow(0.9991, this.totalGames);
                if (this.explore < 0.05) {
                    this.explore = 0.05;
                }
                if (this.explore == 0.05 && ran.coin(0.25)) {
                    this.explore = 0.0;
                }
                this.agent.setExplorationRate(this.explore);
                System.out.println("Exploration rate: " + Double.toString(this.explore));

                long startTime = System.currentTimeMillis();
                System.out.println("Now playing a game.");
                int frameCntr = this.playGame(feedbackPool, startTime);
                System.out.println("Game Over.");
                long endTime = System.currentTimeMillis() - startTime;
                System.out.println(Double.toString(1.0 / ((double) endTime / (double) frameCntr / 1000.0)) + " FPS");

                System.out.println("Now training for " + Integer.toString(frameCntr) + " iterations...");
                double meanLoss = this.trainAfterGame(frameCntr);
                System.out.println("Training done, final loss is " + Double.toString(meanLoss) + ".");

                if (frameCntr > 8) {
                    // Locale.ROOT keeps '.' as the decimal separator; the default locale may
                    // use ',' which would split the value into two CSV columns.
                    csvWriter.write(String.format(Locale.ROOT, "%.8f", meanLoss) + "," + Double.toString((double) endTime / 1000.0) + "," + Double.toString(this.agent.getExplorationRate()) + ",");
                    csvWriter.newLine();
                    csvWriter.flush();
                }
                if (autosave && game % 2 == 0) {
                    this.save(savefile);
                }
            }
            System.gc();
            System.out.println("All games completed successfully!");
        }
        catch (Exception e) {
            if (e instanceof InterruptedException) {
                // Preserve the interrupt for whoever owns this thread.
                Thread.currentThread().interrupt();
            }
            System.err.println("Error running game: ");
            e.printStackTrace();
        }
        finally {
            feedbackPool.shutdown();
        }
    }

    /**
     * Plays one already started game until the feedback task reports game over or
     * {@link #MAX_GAME_MILLIS} has elapsed, then sends Escape and closes the game.
     *
     * @param feedbackPool executor running the per-frame feedback task
     * @param startTime    {@code System.currentTimeMillis()} at the start of the game
     * @return number of frames processed (always at least 1)
     * @throws Exception anything thrown by the wrapper, the agent or the feedback task
     */
    private int playGame(ThreadPoolExecutor feedbackPool, long startTime) throws Exception {
        Tensor inBuffer = new Tensor(192, 224, 4, false);
        Runnable feedbackTask = this::recordFeedback;
        Future<?> pendingFeedback = null;
        int frameCntr = 0;
        boolean isFocusing = false;

        long lastFrametime = System.nanoTime();
        ((GPUGraph) this.agent.graph).forceKeepParameters(true);
        this.previousLives = 9;
        this.previousScore = 0.0;
        // The original grabbed one frame here and discarded it; the call is kept in case
        // the wrapper relies on it.
        this.wrapper.getGameScreen();
        // Held for the whole game. NOTE(review): presumably shot and dialogue skip - confirm.
        this.robot.keyPress(KeyEvent.VK_Y);
        this.robot.keyPress(KeyEvent.VK_CONTROL);
        this.gameOver = false;
        this.testE = null;

        try {
            while (true) {
                // Busy-wait frame pacing, kept from the original for timing precision.
                if (System.nanoTime() - lastFrametime < FRAME_TIME_NANOS) {
                    continue;
                }
                lastFrametime = System.nanoTime();
                long frameStartMillis = System.currentTimeMillis();

                Matrix currFrame = TouhouAI.grayscale(this.wrapper.getGameScreen());
                // Shift the frame history by one; slot 0 is always the newest frame.
                inBuffer.matrices[3] = inBuffer.matrices[2];
                inBuffer.matrices[2] = inBuffer.matrices[1];
                inBuffer.matrices[1] = inBuffer.matrices[0];
                inBuffer.matrices[0] = currFrame;
                int act = this.agent.think(inBuffer);

                // The previous frame's feedback reads prevState/prevAction, so it must have
                // finished before they are overwritten. Future.get() also makes the task's
                // writes (gameOver, testE, previousScore, ...) visible to this thread.
                if (pendingFeedback != null) {
                    pendingFeedback.get();
                    pendingFeedback = null;
                }
                if (this.testE != null) {
                    throw this.testE;
                }
                this.prevAction = act;
                this.prevState = inBuffer.clone();
                isFocusing = this.pressKeysFor(act, isFocusing);
                ++frameCntr;
                if (this.gameOver) {
                    break;
                }

                long frameMillis = System.currentTimeMillis() - frameStartMillis;
                if (frameMillis < LAG_THRESHOLD_MILLIS) {
                    pendingFeedback = feedbackPool.submit(feedbackTask);
                } else {
                    System.err.println("LAG!!");
                    this.releaseMovementKeys();
                }
                if (System.currentTimeMillis() - startTime >= MAX_GAME_MILLIS) {
                    break;
                }
            }
        }
        finally {
            // Also runs when the game aborts with an exception, so no key (Ctrl in
            // particular) stays logically pressed on the desktop.
            this.robot.keyRelease(KeyEvent.VK_Y);
            this.releaseMovementKeys();
            this.robot.keyRelease(KeyEvent.VK_SHIFT);
            this.robot.keyRelease(KeyEvent.VK_CONTROL);
        }

        this.robot.keyPress(KeyEvent.VK_ESCAPE);
        Thread.sleep(60L);
        this.robot.keyRelease(KeyEvent.VK_ESCAPE);
        this.closeGame();
        if (pendingFeedback != null) {
            pendingFeedback.get();
        }
        ((GPUGraph) this.agent.graph).forceKeepParameters(false);
        return frameCntr;
    }

    /**
     * Background task for one frame: waits 17 ms, releases the movement keys, derives a
     * reward from the score delta (or +/-1 when the life counter changed) and reports it
     * to the agent for the state/action stored in {@code prevState}/{@code prevAction}.
     * Exceptions are stored in {@code testE} for the game loop to rethrow.
     */
    private void recordFeedback() {
        try {
            Thread.sleep(17L);
            this.releaseMovementKeys();
            int extraLives = this.wrapper.getExtraLives();
            double score = this.wrapper.getScore();
            double feedBack = (score - this.previousScore) / 20000.0;
            this.previousScore = score;
            int livesDelta = extraLives - this.previousLives;
            if (livesDelta != 0) {
                feedBack = livesDelta > 0 ? 1.0 : -1.0;
                this.previousLives = extraLives;
            }
            // NOTE(review): the last argument is true for every reward below 0.5, i.e. on
            // almost every frame. If it marks terminal transitions this looks suspicious
            // (perhaps "< -0.5" was meant) - confirm against DeepRLAgent.feedback.
            this.agent.feedback(this.prevState, this.prevAction, feedBack, feedBack < 0.5);
            // NOTE(review): the game is considered over when the wrapper reports exactly 7;
            // the counter starts at 9 here - confirm the meaning of these values.
            if (extraLives == 7) {
                this.gameOver = true;
            }
        }
        catch (Exception e) {
            this.testE = e;
        }
    }

    /**
     * Presses the keys belonging to an action: 0-3 left/right/up/down, 4-7 the diagonals
     * up-left/up-right/down-left/down-right, 9 toggles Shift, anything else (including 8)
     * presses nothing.
     *
     * @param act        action index returned by the agent
     * @param isFocusing whether Shift is currently held
     * @return the new Shift state
     */
    private boolean pressKeysFor(int act, boolean isFocusing) {
        switch (act) {
            case 0:
                this.robot.keyPress(KeyEvent.VK_LEFT);
                break;
            case 1:
                this.robot.keyPress(KeyEvent.VK_RIGHT);
                break;
            case 2:
                this.robot.keyPress(KeyEvent.VK_UP);
                break;
            case 3:
                this.robot.keyPress(KeyEvent.VK_DOWN);
                break;
            case 4:
                this.robot.keyPress(KeyEvent.VK_UP);
                this.robot.keyPress(KeyEvent.VK_LEFT);
                break;
            case 5:
                this.robot.keyPress(KeyEvent.VK_UP);
                this.robot.keyPress(KeyEvent.VK_RIGHT);
                break;
            case 6:
                this.robot.keyPress(KeyEvent.VK_DOWN);
                this.robot.keyPress(KeyEvent.VK_LEFT);
                break;
            case 7:
                this.robot.keyPress(KeyEvent.VK_DOWN);
                this.robot.keyPress(KeyEvent.VK_RIGHT);
                break;
            case 9:
                if (isFocusing) {
                    this.robot.keyRelease(KeyEvent.VK_SHIFT);
                    return false;
                }
                this.robot.keyPress(KeyEvent.VK_SHIFT);
                return true;
            default:
                break;
        }
        return isFocusing;
    }

    /** Releases the four arrow keys (left, right, up, down - in that order). */
    private void releaseMovementKeys() {
        this.robot.keyRelease(KeyEvent.VK_LEFT);
        this.robot.keyRelease(KeyEvent.VK_RIGHT);
        this.robot.keyRelease(KeyEvent.VK_UP);
        this.robot.keyRelease(KeyEvent.VK_DOWN);
    }

    /**
     * Trains the agent in {@link #TRAIN_ROUNDS} rounds of {@code frameCntr / TRAIN_ROUNDS}
     * iterations each, with the network displays frozen while training runs.
     *
     * @param frameCntr number of frames played in the game just finished
     * @return the training loss averaged over the rounds
     */
    private double trainAfterGame(int frameCntr) {
        TouhouAI.setDisplaysFrozen(true);
        System.gc();
        double nomLoss = 0.0;
        for (int round = 0; round < TRAIN_ROUNDS; ++round) {
            System.out.println(Double.toString((double) round / TRAIN_ROUNDS * 100.0) + "% done...");
            ((GPUGraph) this.agent.graph).forceKeepParameters(true);
            nomLoss += this.agent.train(frameCntr / TRAIN_ROUNDS);
            ((GPUGraph) this.agent.graph).forceKeepParameters(false);
        }
        System.out.println("100% done...");
        System.gc();
        TouhouAI.setDisplaysFrozen(false);
        return nomLoss / TRAIN_ROUNDS;
    }

    /**
     * Freezes or unfreezes the three network displays. They are only created by
     * {@link #main}, so each one is skipped when it is still {@code null}.
     */
    private static void setDisplaysFrozen(boolean frozen) {
        if (dp != null) {
            dp.setFrozen(frozen);
        }
        if (dp2 != null) {
            dp2.setFrozen(frozen);
        }
        if (dp3 != null) {
            dp3.setFrozen(frozen);
        }
    }

    /**
     * Presses {@code keyCode}, holds it for {@link #KEY_HOLD_MILLIS}, releases it and then
     * waits {@code pauseAfterMillis}.
     */
    private void tapKey(int keyCode, long pauseAfterMillis) throws InterruptedException {
        this.robot.keyPress(keyCode);
        Thread.sleep(KEY_HOLD_MILLIS);
        this.robot.keyRelease(keyCode);
        Thread.sleep(pauseAfterMillis);
    }

    /**
     * Launches the game through a new {@link Touhou10Wrapper} and sends a fixed sequence
     * of Enter/Down key taps with fixed delays (presumably the menu path into a game).
     *
     * @throws Exception if the wrapper cannot be started or a sleep is interrupted
     */
    public void openGame() throws Exception {
        this.wrapper = new Touhou10Wrapper(new File("./th.lnk"));
        this.wrapper.start();
        // From here on there is something for closeGame() to close.
        this.gameProcessRunning = true;
        Thread.sleep(1500L);
        this.tapKey(KeyEvent.VK_ENTER, 1000L);
        this.tapKey(KeyEvent.VK_DOWN, 1000L);
        this.tapKey(KeyEvent.VK_ENTER, 1000L);
        this.tapKey(KeyEvent.VK_ENTER, 1000L);
        this.tapKey(KeyEvent.VK_ENTER, 1000L);
        this.tapKey(KeyEvent.VK_ENTER, 1000L);
        this.tapKey(KeyEvent.VK_DOWN, 1000L);
        this.tapKey(KeyEvent.VK_ENTER, 3500L);
        this.isGameStarted = true;
    }

    /**
     * Starts a game: launches the game via {@link #openGame} when it is not running yet,
     * otherwise sends the key sequence Down, Down, Enter, Up, Enter to the running game.
     *
     * @throws Exception if launching fails or a sleep is interrupted
     */
    public void startGame() throws Exception {
        if (!this.isGameStarted) {
            this.openGame();
            return;
        }
        this.tapKey(KeyEvent.VK_DOWN, 1000L);
        this.tapKey(KeyEvent.VK_DOWN, 1000L);
        this.tapKey(KeyEvent.VK_ENTER, 1000L);
        this.tapKey(KeyEvent.VK_UP, 1000L);
        this.tapKey(KeyEvent.VK_ENTER, 2500L);
    }

    /**
     * Stops the wrapper and sends Alt+F4. Does nothing when no game process was launched
     * (or it was already closed), so a repeated call cannot send Alt+F4 to whichever
     * unrelated window currently has focus. Failures are printed, not thrown.
     */
    public void closeGame() {
        if (!this.gameProcessRunning) {
            return;
        }
        try {
            this.wrapper.stop();
            this.robot.keyPress(KeyEvent.VK_ALT);
            this.robot.keyPress(KeyEvent.VK_F4);
            Thread.sleep(100L);
            this.robot.keyRelease(KeyEvent.VK_ALT);
            this.robot.keyRelease(KeyEvent.VK_F4);
            Thread.sleep(2500L);
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            e.printStackTrace();
        }
        catch (Exception e) {
            e.printStackTrace();
        }
        this.isGameStarted = false;
        this.gameProcessRunning = false;
    }

    /**
     * Builds the network (convolutional trunk plus two small heads combined by an
     * {@link AggregationLayer}), loads a previous save if present and runs batches of
     * four games, saving after each batch. The JVM exits with status 0 on success and 1
     * on any exception.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        System.out.println(TouhouAI.thing(15, false));
        System.out.println(TouhouAI.thing(3, true));
        try {
            dp = new NeuralNetworkDisplay(60, 60, 4);
            dp2 = new NeuralNetworkDisplay(14, 14, 16);
            dp3 = new NeuralNetworkDisplay(7, 7, 31);
            dp.setTitle(0, 0, 1, 2);
            dp.setLocation(700, 425);
            dp2.setTitle(1, 0, 1, 2);
            dp2.setLocation(800, 425);
            dp3.setTitle(2, 6, 7, 8);
            dp3.setLocation(900, 425);
            BetterRandom random = new BetterRandom(new RanMT());
            NeuralNetwork model = new NeuralNetwork();
            int cores = 4;
            model.addLayer(new ConvLayer(192, 224, 4, 8, 8, 32, 4, 0, new ExponentialLinearUnit(1.0), 0.01, random, true, true, cores));
            model.addLayer(new DeepRLExample.ViewerLayer(dp, false, 0, 1, 2));
            model.addLayer(new ConvLayer(47, 55, 32, 7, 7, 72, 2, 0, new ExponentialLinearUnit(1.0), 0.01, random, true, true, cores));
            model.addLayer(new DeepRLExample.ViewerLayer(dp2, false, 0, 1, 2));
            model.addLayer(new NormalizeLayer(1.0E-8, 0.08, random));
            model.addLayer(new ConvLayer(21, 25, 72, 3, 3, 72, 2, 0, new ExponentialLinearUnit(1.0), 0.01, random, true, true, cores));
            model.addLayer(new ConvLayer(10, 12, 72, 3, 3, 72, 1, 0, new ExponentialLinearUnit(1.0), 0.0107, random, true, true, cores));
            model.addLayer(new DeepRLExample.ViewerLayer(dp3, false, 6, 7, 8));
            model.addLayer(new ConvLayer(8, 10, 72, 3, 3, 96, 1, 0, new ExponentialLinearUnit(1.0), 0.012, random, true, true, cores));
            model.addLayer(new ConvFlatten(6, 8, 96));
            model.addLayer(new ConvDense(new FeedForwardLayer(4608, 768, new ExponentialLinearUnit(1.0), 0.0147, random)));
            NeuralNetwork a_net = new NeuralNetwork();
            a_net.addLayer(new Dropout(0.1));
            a_net.addLayer(new ConvDense(new FeedForwardLayer(768, 10, new LinearUnit(), 0.036, random)));
            NeuralNetwork v_net = new NeuralNetwork();
            v_net.addLayer(new Dropout(0.1));
            v_net.addLayer(new ConvDense(new FeedForwardLayer(768, 1, new LinearUnit(), 0.036, random)));
            NeuralNetwork finalNet = new NeuralNetwork();
            finalNet.addLayer(new AggregationLayer(model, a_net, v_net));
            dp.setVisible(true);
            dp2.setVisible(true);
            dp3.setVisible(true);
            NNDevice[] devices = CLUtils.findDevice("AMD", "Tahiti");
            TouhouAI ai = new TouhouAI(finalNet, 1.0E-4, new Adam(0.9, 0.999, 1.0E-8), new LossSumOfSquares(), 3024, devices);
            File modelFile = new File(MODEL_PATH);
            ai.load(modelFile);
            int actualIters = 0;
            int i = 0;
            while (i < 4) {
                // Batches started before 14:00 local time do not count towards the four
                // batches, so the loop keeps going until four batches started after 14:00.
                if (LocalDateTime.now().getHour() < 14) {
                    --i;
                }
                System.err.println(Integer.toString(++actualIters));
                ai.run(4, modelFile, false);
                ai.save(modelFile);
                ++i;
            }
            ai.closeGame();
            // The original wrote this final save to "D:/models/superhexagon_ai", which
            // looks like a leftover from another project and would overwrite that
            // project's model; save to the path this program loads from instead.
            ai.save(modelFile);
            dp.setVisible(false);
            dp2.setVisible(false);
            dp3.setVisible(false);
        }
        catch (Exception e) {
            System.err.println("Error: ");
            e.printStackTrace();
            System.exit(1);
        }
        System.exit(0);
    }

    /**
     * Down-samples an image by taking every second pixel of every second row and
     * converts it to gray values in [0, 1] (mean of the three colour channels).
     * The values are written to {@code res.w} row by row.
     *
     * <p>For odd widths/heights the trailing column/row is dropped, so exactly
     * {@code (height / 2) * (width / 2)} values are written - the size the matrix is
     * created with. (Sampling every even index, as before, produced more values than
     * that for odd dimensions.)
     *
     * @param in source image
     * @return matrix created as {@code new Matrix(height / 2, width / 2, false)}
     */
    private static Matrix grayscale(BufferedImage in) {
        int width = in.getWidth();
        int height = in.getHeight();
        int outRows = height / 2;
        int outCols = width / 2;
        Matrix res = new Matrix(outRows, outCols, false);
        int[] rgbArray = new int[width * height];
        in.getRGB(0, 0, width, height, rgbArray, 0, width);
        int out = 0;
        for (int row = 0; row < outRows; ++row) {
            int rowStart = 2 * row * width;
            for (int col = 0; col < outCols; ++col) {
                int rgb = rgbArray[rowStart + 2 * col];
                res.w[out] = ((float)(rgb & 0xFF) + (float)(rgb >> 8 & 0xFF) + (float)(rgb >> 16 & 0xFF)) / 765.0f;
                ++out;
            }
        }
        return res;
    }

    /**
     * Builds a small ASCII pattern.
     *
     * <p>Vertical: {@code n} lines alternating {@code " * "} and {@code "***"}, each
     * terminated by CRLF. Horizontal: {@code " * "} repeated {@code n} times, a CRLF
     * (only when {@code n > 0}), then {@code "***"} repeated {@code n} times.
     *
     * @param n        number of repetitions
     * @param vertical selects the vertical layout
     * @return the pattern; empty when {@code n <= 0}
     */
    public static String thing(int n, boolean vertical) {
        StringBuilder result = new StringBuilder();
        if (vertical) {
            for (int i = 0; i < n; ++i) {
                result.append(i % 2 == 0 ? " * \r\n" : "***\r\n");
            }
        } else {
            for (int i = 0; i < 3; ++i) {
                for (int j = 0; j < n; ++j) {
                    if (i == 1) {
                        // The middle pass only contributes the single line break.
                        if (j == 0) {
                            result.append("\r\n");
                        }
                    } else {
                        result.append(i == 0 ? " * " : "***");
                    }
                }
            }
        }
        return result.toString();
    }
}

