/*
 * Decompiled with CFR 0.152.
 */
package examples.deepRL;

import autodiff.GPUGraph;
import autodiff.Graph;
import edu.cornell.lassp.houle.RngPack.RanMT;
import examples.deepRL.AggregationLayer;
import examples.deepRL.DeepRLAgent;
import examples.deepRL.NeuralNetworkDisplay;
import java.awt.image.BufferedImage;
import java.io.BufferedWriter;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import javax.swing.UIManager;
import loss.LossSumOfSquares;
import matrix.Matrix;
import matrix.Tensor;
import model.ConvDense;
import model.ConvFlatten;
import model.ConvLayer;
import model.Dropout;
import model.FeedForwardLayer;
import model.NeuralNetwork;
import model.NormalizeLayer;
import model.TensorLayer;
import nonlinearities.ExponentialLinearUnit;
import nonlinearities.LinearUnit;
import theGhastModding.lstmStuff.gameThingy.BreakoutGame;
import trainer.Adam;
import trainer.Optimizer;
import util.CLUtils;
import util.NNDevice;

public class DeepRLExample {
    /** Game instance the agent plays; created in {@code startGame()} and released in {@code stopGame()}. */
    private BreakoutGame game;
    /** Number of games completed so far; persisted to {@code savePath2} and used to derive {@code epsilon}. */
    private int ts;
    /** Exploration rate handed to the agent via {@code setExplorationRate}; starts at 0.1. */
    private double epsilon = 0.1;
    /** Display windows created in {@code main} and attached to the model through {@code ViewerLayer}s. */
    private static NeuralNetworkDisplay dp;
    private static NeuralNetworkDisplay dp2;
    private static NeuralNetworkDisplay dp3;
    /** The reinforcement-learning agent built in the constructor. */
    private DeepRLAgent agent = null;
    /** File the agent is loaded from and saved to. */
    private File savePath;
    /** Companion file ({@code savePath + "_a.dat"}) holding the {@code ts} counter. */
    private File savePath2;

    /**
     * Creates an example without an explicit device list.
     * <p>
     * Delegates to the device-aware constructor with a {@code null} device array.
     *
     * @param model      network handed to the agent
     * @param savePath   path of the file the agent is loaded from / saved to
     * @param tau        value forwarded to {@code DeepRLAgent.setTau}
     * @param alpha      value forwarded to {@code DeepRLAgent.setAlpha}
     * @param bufferSize buffer size forwarded to the agent's constructor
     * @param method     optimizer forwarded to the agent's constructor
     * @throws Exception if constructing the agent fails
     */
    public DeepRLExample(NeuralNetwork model, String savePath, double tau, double alpha, int bufferSize, Optimizer method) throws Exception {
        // The cast spells out what a bare null already meant here: a null array, not a one-element array.
        this(model, savePath, tau, alpha, bufferSize, method, (NNDevice[])null);
    }

    public DeepRLExample(NeuralNetwork model, String savePath, double tau, double alpha, int bufferSize, Optimizer method, NNDevice ... dev) throws Exception {
        this.agent = new DeepRLAgent(model, 3, 1.0E-4, method, new LossSumOfSquares(), 0.8, this.epsilon, bufferSize, new Random(), dev);
        this.agent.setAlpha(alpha);
        this.agent.setTau(tau);
        this.savePath = new File(savePath);
        this.savePath2 = new File(String.valueOf(savePath) + "_a.dat");
        this.agent.setRandomActFrames(2);
        this.loadNetwork();
    }

    /**
     * Restores the agent and the game counter from disk, if the files exist.
     * <p>
     * The agent is loaded from {@code savePath} when that file exists. When
     * {@code savePath2} exists, {@code ts} is read from it, the exploration
     * rate is recomputed from that counter and pushed to the agent. Any failure
     * is reported on stderr and the defaults ({@code ts = 0},
     * {@code epsilon = 0.1}) are applied instead.
     */
    private void loadNetwork() {
        try {
            if (this.savePath.exists()) {
                this.agent.load(this.savePath);
            }
            if (this.savePath2.exists()) {
                DataInputStream dis = new DataInputStream(new FileInputStream(this.savePath2));
                try {
                    this.ts = dis.readInt();
                }
                finally {
                    // Release the file handle even when the read fails (e.g. truncated file).
                    dis.close();
                }
                // NOTE(review): run() recomputes the same schedule with a decay base of 0.9995,
                // while this path uses 0.9999 - confirm which one is intended.
                this.epsilon = Math.abs(Math.sin((double)this.ts / 4.0)) * Math.pow(0.9999, this.ts);
                if (this.epsilon < 0.05) {
                    this.epsilon += 0.05;
                }
                this.agent.setExplorationRate(this.epsilon);
            }
        }
        catch (Exception e) {
            System.err.println("Error loading settings (using defaults): ");
            e.printStackTrace();
            this.ts = 0;
            this.epsilon = 0.1;
            this.agent.setExplorationRate(this.epsilon);
        }
    }

    /**
     * Persists the agent to {@code savePath} and the game counter {@code ts}
     * to {@code savePath2}.
     * <p>
     * Failures are reported on stderr and otherwise ignored, so a failed
     * autosave does not interrupt training.
     */
    private void saveNetwork() {
        try {
            this.agent.save(this.savePath);
            DataOutputStream dos = new DataOutputStream(new FileOutputStream(this.savePath2));
            try {
                dos.writeInt(this.ts);
            }
            finally {
                // Close the stream even when the write fails so the file handle is not leaked.
                dos.close();
            }
        }
        catch (Exception e) {
            System.err.println("Error saving settings: ");
            e.printStackTrace();
        }
    }

    /**
     * Creates a new {@code BreakoutGame}, stores it in {@code game}, then
     * creates and initialises it and renders its first frame.
     */
    public void startGame() {
        BreakoutGame fresh = new BreakoutGame();
        this.game = fresh;
        fresh.create();
        fresh.init();
        fresh.renderNextFrame();
    }

    /**
     * Destroys the current game, if any, and clears the reference.
     * <p>
     * Calling this before {@code startGame()} or more than once is a no-op.
     */
    public void stopGame() {
        if (this.game == null) {
            // Nothing to tear down: never started, or already stopped.
            return;
        }
        this.game.destroy();
        this.game = null;
    }

    /**
     * Plays {@code games} games with the agent and trains it after each one.
     * <p>
     * For every game a fresh four-slot frame stack is primed with three noisy
     * grayscale frames. The agent is then repeatedly asked for an action, the
     * action is applied, the next frame is rendered and the score delta is fed
     * back to the agent until the game is over. After each game the agent is
     * trained with the number of steps just played and one line is appended to
     * {@code scores.csv}: final score, loss, and the percentage of right / left /
     * still actions. The exploration rate is then recomputed from the game
     * counter; the network is saved every 16th game and once more at the end.
     * <p>
     * A non-finite training loss aborts the run without the final save. Any
     * exception is reported on stderr and ends the run. The score file is closed
     * on every exit path.
     *
     * @param games number of games to play
     */
    public void run(int games) {
        this.startGame();
        this.game.renderNextFrame();
        RanMT random = new RanMT();
        try {
            this.game.renderNextFrame();
            this.game.setPaddleSpeed(3);
            BufferedWriter bw = new BufferedWriter(new FileWriter(new File("scores.csv"), true));
            try {
                for (int x = 0; x < games; ++x) {
                    int stillCount = 0;
                    int rightCount = 0;
                    int leftCount = 0;
                    int stepCounter = 0;

                    // Allocate the per-game frame stack BEFORE priming it. Previously the stack
                    // was reallocated after the priming loop, which discarded the primed frames.
                    Tensor inBuffer = new Tensor(84, 84, 4);
                    for (int i = 0; i < 3; ++i) {
                        BufferedImage primer = this.game.getCurrentFrame();
                        this.game.renderNextFrame();
                        Matrix primed = this.grayscale(primer);
                        DeepRLExample.addNoise(primed, random, 0.05);
                        // Slots 1..3 are filled; the first step below shifts them down to 0..2.
                        inBuffer.matrices[i + 1] = primed;
                    }
                    BufferedImage t = this.game.getCurrentFrame();
                    System.out.println((x + 1) + "," + games);
                    double previousScore = 0.0;
                    if (this.agent.graph instanceof GPUGraph) {
                        ((GPUGraph)this.agent.graph).forceKeepParameters(true);
                    }

                    while (!this.game.isGameOver()) {
                        ++stepCounter;
                        // Slide the frame window by one and append the newest frame.
                        inBuffer.matrices[0] = inBuffer.matrices[1];
                        inBuffer.matrices[1] = inBuffer.matrices[2];
                        inBuffer.matrices[2] = inBuffer.matrices[3];
                        Matrix current = this.grayscale(t);
                        DeepRLExample.addNoise(current, random, 0.01);
                        inBuffer.matrices[3] = current;

                        int action = this.agent.think(inBuffer);
                        switch (action) {
                            case 0: {
                                ++stillCount;
                                break;
                            }
                            case 1: {
                                ++rightCount;
                                break;
                            }
                            case 2: {
                                ++leftCount;
                                break;
                            }
                            default: {
                                break;
                            }
                        }
                        // 1 = move right, 2 = move left, anything else = stand still.
                        this.game.setRight(action == 1);
                        this.game.setLeft(action == 2);

                        this.game.renderNextFrame();
                        double score = this.game.getScore();
                        double reward = score - previousScore;
                        previousScore = score;
                        // The last argument is true whenever the score dropped on this step.
                        this.agent.feedback(inBuffer, action, reward, reward < 0.0);
                        t = this.game.getCurrentFrame();
                    }

                    double finalScore = this.game.getScore();
                    System.out.println("[DEBUG]: Final score was " + Double.toString(finalScore));
                    System.out.println("[DEBUG]: Now training...");
                    double loss = this.agent.train(stepCounter);
                    if (Double.isInfinite(loss) || Double.isNaN(loss)) {
                        System.err.println("[ERR]: Invalid loss of " + Double.toString(loss) + "!");
                        // The finally block below closes the score file.
                        return;
                    }
                    System.out.println("[DEBUG]: Training done.");
                    bw.write(finalScore + "," + loss + ","
                            + DeepRLExample.percent(rightCount, stepCounter) + ","
                            + DeepRLExample.percent(leftCount, stepCounter) + ","
                            + DeepRLExample.percent(stillCount, stepCounter) + ",");
                    bw.newLine();
                    bw.flush();

                    this.game.init();
                    System.gc();
                    ++this.ts;
                    this.epsilon = Math.abs(Math.sin((double)this.ts / 4.0)) * Math.pow(0.9995, this.ts);
                    if (this.epsilon < 0.05) {
                        this.epsilon += 0.05;
                    }
                    this.agent.setExplorationRate(this.epsilon);
                    System.out.println("[DEBUG]: Epsilon: " + this.epsilon);
                    if (x % 16 == 0) {
                        System.out.println("[DEBUG]: Autosaving...");
                        this.saveNetwork();
                    }
                }
            }
            finally {
                // Always release scores.csv, including when a game step or training throws.
                bw.close();
            }
            System.out.println("[DEBUG]: Autosaving...");
            this.saveNetwork();
            this.agent.cleanUpGraphs();
        }
        catch (Exception e) {
            System.err.println("Error while playing game: ");
            e.printStackTrace();
        }
    }

    /** Adds {@code rng.raw() * amplitude} to every element of {@code frame.w}, in place. */
    private static void addNoise(Matrix frame, RanMT rng, double amplitude) {
        for (int k = 0; k < frame.w.length; ++k) {
            frame.w[k] = frame.w[k] + rng.raw() * amplitude;
        }
    }

    /** Returns {@code count} as a percentage of {@code total}. */
    private static double percent(int count, int total) {
        return (double)count / (double)total * 100.0;
    }

    /**
     * Converts an image to a matrix of gray levels in the range 0..1.
     * <p>
     * Each pixel's gray level is the mean of its red, green and blue channels
     * divided by 255. The matrix is constructed with (image height, image width)
     * and pixel (x, y) is written with {@code setW(y, x, value)}.
     *
     * @param in image to convert
     * @return matrix holding one gray value per pixel
     * @throws Exception if writing into the matrix fails
     */
    private Matrix grayscale(BufferedImage in) throws Exception {
        final int width = in.getWidth();
        final int height = in.getHeight();
        Matrix result = new Matrix(height, width);
        for (int x = 0; x < width; ++x) {
            for (int y = 0; y < height; ++y) {
                int pixel = in.getRGB(x, y);
                double red = pixel >> 16 & 0xFF;
                double green = pixel >> 8 & 0xFF;
                double blue = pixel & 0xFF;
                // Keep the per-channel division order so the floating-point result is unchanged.
                double mean = red / 3.0 + green / 3.0 + blue / 3.0;
                result.setW(y, x, mean / 255.0);
            }
        }
        return result;
    }

    /**
     * Builds the network, opens the feature displays, and trains the agent.
     * <p>
     * The network is a three-layer convolutional trunk followed by two dense
     * heads (3 outputs and 1 output) combined through an {@code AggregationLayer}.
     *
     * @param args optional overrides: {@code args[0]} is the save path of the
     *             agent (default {@code D:\models\CatchAI.dat}) and
     *             {@code args[1]} is the number of games to play (default 10240)
     */
    public static void main(String[] args) {
        String modelPath = "D:\\models\\CatchAI.dat";
        int games = 10240;
        if (args != null && args.length > 0) {
            modelPath = args[0];
        }
        if (args != null && args.length > 1) {
            try {
                games = Integer.parseInt(args[1]);
            }
            catch (NumberFormatException e) {
                // Fail fast instead of silently training for the default number of games.
                System.err.println("Invalid number of games: " + args[1]);
                System.exit(1);
            }
        }

        try {
            UIManager.setLookAndFeel(UIManager.getSystemLookAndFeelClassName());
        }
        catch (Exception e) {
            e.printStackTrace();
        }

        dp = new NeuralNetworkDisplay(20, 20, 12);
        dp2 = new NeuralNetworkDisplay(9, 9, 27);
        dp3 = new NeuralNetworkDisplay(7, 7, 35);
        dp.setTitle(0, 0, 1, 2);
        dp.setLocation(700, 150);
        dp2.setTitle(1, 0, 1, 2);
        dp2.setLocation(700, 300);
        dp3.setTitle(2, 0, 1, 2);
        dp3.setLocation(700, 450);

        Random random = new Random();
        int cores = 4;

        // Convolutional trunk; a viewer after each stage shows channels 0, 1 and 2.
        NeuralNetwork model = new NeuralNetwork();
        model.addLayer(new ConvLayer(84, 84, 4, 8, 8, 32, 4, 0, new ExponentialLinearUnit(1.0), 0.01, random, true, true, cores));
        model.addLayer(new NormalizeLayer(1.0E-8, 0.08, random));
        model.addLayer(new ViewerLayer(dp, false, 0, 1, 2));
        model.addLayer(new ConvLayer(20, 20, 32, 4, 4, 64, 2, 0, new ExponentialLinearUnit(1.0), 0.01, random, true, true, cores));
        model.addLayer(new NormalizeLayer(1.0E-8, 0.08, random));
        model.addLayer(new ViewerLayer(dp2, false, 0, 1, 2));
        model.addLayer(new ConvLayer(9, 9, 64, 3, 3, 64, 1, 0, new ExponentialLinearUnit(1.0), 0.013888888888888888, random, true, true, cores));
        model.addLayer(new ViewerLayer(dp3, false, 0, 1, 2));
        model.addLayer(new ConvFlatten(7, 7, 64));

        // Head with three outputs.
        NeuralNetwork a_net = new NeuralNetwork();
        a_net.addLayer(new Dropout(0.1));
        a_net.addLayer(new ConvDense(new FeedForwardLayer(3136, 3, new LinearUnit(), Math.sqrt(0.0033333333333333335), random)));

        // Head with a single output.
        NeuralNetwork v_net = new NeuralNetwork();
        v_net.addLayer(new Dropout(0.1));
        v_net.addLayer(new ConvDense(new FeedForwardLayer(3136, 1, new LinearUnit(), 0.0625, random)));

        NeuralNetwork finalModel = new NeuralNetwork();
        finalModel.addLayer(new AggregationLayer(model, a_net, v_net));

        // Count the trunk AND both heads; previously only the trunk was included in the total.
        long total = DeepRLExample.countParameters(model)
                + DeepRLExample.countParameters(a_net)
                + DeepRLExample.countParameters(v_net);
        System.out.println(total + " total parameters.");

        dp.setVisible(true);
        dp2.setVisible(true);
        dp3.setVisible(true);

        NNDevice[] dev = CLUtils.findDevice("AMD", "Tahiti");
        DeepRLExample ai = null;
        try {
            ai = new DeepRLExample(finalModel, modelPath, 0.001, 0.75, 16384, new Adam(0.9, 0.999, 1.0E-8), dev);
        }
        catch (Exception e) {
            System.err.println("Error creating AI: ");
            e.printStackTrace();
            System.exit(1);
        }
        ai.run(games);
        try {
            Thread.sleep(1000L);
            ai.stopGame();
        }
        catch (Exception e) {
            System.err.println("Error closing game: ");
            e.printStackTrace();
            System.exit(1);
        }
        dp.setVisible(false);
        dp2.setVisible(false);
        dp3.setVisible(false);
        System.exit(0);
    }

    /** Sums the element counts of all parameter matrices reported by {@code net}. */
    private static long countParameters(NeuralNetwork net) {
        long count = 0L;
        for (Matrix m : net.getParameters()) {
            count += (long)m.w.length;
        }
        return count;
    }

    /**
     * Pass-through layer that shows selected matrices of its input tensor in a
     * {@code NeuralNetworkDisplay}.
     * <p>
     * The layer returns its input unchanged and has no parameters or state. When
     * {@code showGradients} is off, the selected matrices are displayed during
     * the forward pass. When it is on, the display call is instead registered
     * as a backprop step on the graph (only while the graph is applying
     * backprop).
     */
    public static class ViewerLayer
    implements TensorLayer {
        private final NeuralNetworkDisplay dp;
        /** Indices into {@code Tensor.matrices} of the matrices to display. */
        private final int[] o;
        private boolean showGradients;

        /**
         * @param dp            display to render into
         * @param showGradients whether to display during backprop instead of forward;
         *                      also forwarded to {@code dp.setShowGradients}
         * @param a             indices of the matrices to show; the first three are used
         *                      when three or more are given, and nothing is shown when empty
         */
        public ViewerLayer(NeuralNetworkDisplay dp, boolean showGradients, int ... a) {
            this.dp = dp;
            this.showGradients = showGradients;
            this.dp.setShowGradients(showGradients);
            // Defensive copy so later changes to the caller's array do not affect this layer.
            this.o = a.clone();
        }

        /**
         * Displays the selected matrices (now, or later during backprop) and
         * returns {@code input} unchanged.
         */
        @Override
        public Tensor forward(final Tensor input, Graph g) throws Exception {
            if (!this.showGradients) {
                this.display(input);
            } else if (g.isApplyingBackprop()) {
                g.addBackprop(new Runnable(){

                    @Override
                    public void run() {
                        ViewerLayer.this.display(input);
                    }
                });
            }
            return input;
        }

        /** Sends one, two or three of the selected matrices of {@code t} to the display. */
        private void display(Tensor t) {
            if (this.o.length >= 3) {
                this.dp.showFeatures(t.matrices[this.o[0]], t.matrices[this.o[1]], t.matrices[this.o[2]]);
            } else if (this.o.length == 2) {
                this.dp.showFeatures(t.matrices[this.o[0]], t.matrices[this.o[1]]);
            } else if (this.o.length == 1) {
                this.dp.showFeatures(t.matrices[this.o[0]]);
            }
        }

        public boolean isShowGradients() {
            return this.showGradients;
        }

        /** Updates the flag on this layer and on the attached display. */
        public void setShowGradients(boolean showGradients) {
            this.showGradients = showGradients;
            this.dp.setShowGradients(showGradients);
        }

        /** No-op: the layer keeps no state between calls. */
        @Override
        public void resetState() {
        }

        /**
         * Returns an empty list: this layer has no trainable parameters.
         * An empty list (not {@code null}) lets callers iterate without a null check.
         */
        @Override
        public List<Matrix> getParameters() {
            return new ArrayList<Matrix>();
        }

        /** Returns a new layer sharing the same display, with the same flag and indices. */
        @Override
        public TensorLayer clone() {
            return new ViewerLayer(this.dp, this.showGradients, this.o);
        }

        /** No-op: there is nothing to persist. */
        @Override
        public void saveState(DataOutputStream fos) throws Exception {
        }

        /** No-op: there is nothing to restore. */
        @Override
        public void loadState(DataInputStream fis) throws Exception {
        }
    }
}

