/*
 * Decompiled with CFR 0.152.
 */
package examples.deepRL;

import autodiff.Graph;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.util.List;
import java.util.Random;
import loss.Loss;
import matrix.Matrix;
import matrix.Tensor;
import model.NeuralNetwork;
import trainer.Optimizer;
import util.CLUtils;
import util.FileChannelInputStream;
import util.FileChannelOutputStream;
import util.NNDevice;

/**
 * Q-learning agent with an online network, a softly-updated target network, Double-DQN style
 * targets and a prioritized replay buffer.
 *
 * <p>Typical loop: {@link #think(Tensor)} to pick an action, {@link #feedback(Tensor, int, double,
 * boolean)} to store the resulting transition, and {@link #train(int)} to learn from sampled
 * transitions.
 */
public class DeepRLAgent {
    /** Minimum number of stored transitions before {@link #train(int)} performs any update. */
    private static final int MIN_REPLAY_SIZE = 16;
    /** Offset added before exponentiation so no transition ever gets a zero sampling priority. */
    private static final double PRIORITY_EPSILON = 0.01;

    private NeuralNetwork model;
    private NeuralNetwork targetModel;
    private int actionDim;
    private double learningRate;
    private Optimizer optimizer;
    private Loss lossToUse;
    private double discount;
    private double explore;
    private Random random;
    /** Autodiff graph used for every forward pass of this agent (inference and training). */
    public Graph graph;
    private double lastErr = 0.0;
    private Replay[] replayBuffer;
    private int replayPos = 0;
    private int replaySize = 0;
    private double prioritySum = 0.0;
    private double alpha = 0.7;
    private double tau = 0.01;
    private int replayCntr = 0;
    private int randomActFrames = 1;
    // Starts "exhausted" so think() never returns the uninitialised randomAct before an
    // exploratory action has actually been drawn.
    private int randomActCntr = Integer.MAX_VALUE;
    private int randomAct;

    /**
     * Creates an agent around {@code model}; the target network starts as a clone of it.
     *
     * @param model online Q-network
     * @param actionDim number of discrete actions
     * @param learningRate learning rate passed to the optimizer
     * @param optimizer optimizer applied to the online network
     * @param lossToUse loss between predicted and target Q-values
     * @param discount reward discount factor
     * @param explore probability of taking a uniformly random action in {@link #think(Tensor)}
     * @param replayBufferSize capacity of the replay buffer
     * @param random source of randomness for exploration and replay sampling
     * @param devices null/empty builds {@code new Graph(true)}; one device uses
     *     {@code CLUtils.createGraph(devices[0], true)}; two or more use the first two devices
     */
    public DeepRLAgent(NeuralNetwork model, int actionDim, double learningRate, Optimizer optimizer, Loss lossToUse, double discount, double explore, int replayBufferSize, Random random, NNDevice[] devices) throws Exception {
        this.model = model;
        this.targetModel = (NeuralNetwork)model.clone();
        this.actionDim = actionDim;
        this.learningRate = learningRate;
        this.optimizer = optimizer;
        this.lossToUse = lossToUse;
        this.discount = discount;
        this.explore = explore;
        this.replayBuffer = new Replay[replayBufferSize];
        this.random = random;
        if (devices == null || devices.length == 0) {
            this.graph = new Graph(true);
        } else if (devices.length == 1) {
            this.graph = CLUtils.createGraph(devices[0], true);
        } else {
            this.graph = CLUtils.createGraph(devices[0], devices[1], true);
        }
        this.teleport();
    }

    /**
     * Calls {@code resetState()} on both the online and the target network (presumably clearing
     * recurrent/episode state — confirm against NeuralNetwork).
     */
    public void teleport() {
        this.model.resetState();
        this.targetModel.resetState();
    }

    /**
     * Chooses an action for the given observation.
     *
     * <p>Returns the greedy (argmax-Q) action, unless an exploratory action is in progress. With
     * probability {@code explore} a uniformly random action is drawn instead, and that action is
     * then repeated for a total of {@code randomActFrames} consecutive calls.
     *
     * @param observations current observation
     * @return chosen action index in {@code [0, actionDim)}
     */
    public int think(Tensor observations) throws Exception {
        this.graph.setApplyingBackprop(false);
        // The forward pass runs even when a sticky random action is returned below, presumably so
        // any recurrent state inside the model advances every frame.
        int action = this.argmax(this.model.forward(observations, this.graph).matrices[0].w);
        if (this.randomActCntr < this.randomActFrames) {
            ++this.randomActCntr;
            return this.randomAct;
        }
        if (this.explore != 0.0 && this.random.nextDouble() <= this.explore) {
            action = this.random.nextInt(this.actionDim);
            this.randomAct = action;
            this.randomActCntr = 1;
        }
        return action;
    }

    /**
     * Stores one transition in the circular replay buffer.
     *
     * <p>The initial priority is {@code (|reward| + 0.01)^alpha}. After writing, the next slot is
     * skipped with probability {@code priority(next) / prioritySum}, so high-priority entries
     * tend to survive longer before being overwritten.
     *
     * @param observations observation the action was taken in (copied)
     * @param action action taken
     * @param reward reward received
     * @param gameOver whether this transition ended the episode
     */
    public void feedback(Tensor observations, int action, double reward, boolean gameOver) throws Exception {
        double priority = Math.pow(Math.abs(reward) + PRIORITY_EPSILON, this.alpha);
        Replay evicted = this.replayBuffer[this.replayPos];
        if (evicted != null) {
            this.prioritySum -= evicted.getPriority();
        }
        this.replayBuffer[this.replayPos] = new Replay(observations.clone(), action, reward, gameOver, priority, this.replayCntr);
        ++this.replayCntr;
        this.prioritySum += priority;
        ++this.replayPos;
        this.replaySize = Math.max(this.replaySize, this.replayPos);
        if (this.replayPos >= this.replayBuffer.length) {
            this.replayPos = 0;
        }
        Replay next = this.replayBuffer[this.replayPos];
        if (next != null && this.random.nextDouble() < next.getPriority() / this.prioritySum) {
            ++this.replayPos;
            // Wrap again: without this the skip can push replayPos to buffer.length and the next
            // call would index past the end of the array.
            if (this.replayPos >= this.replayBuffer.length) {
                this.replayPos = 0;
            }
        }
    }

    /**
     * Runs one optimisation step built from {@code epochs} prioritized replay samples, then softly
     * moves the target network toward the online network.
     *
     * @param epochs number of transitions to sample for this step
     * @return mean loss over the samples; {@code -1.0} if {@code epochs <= 0}; {@code 0.1} if the
     *     buffer holds fewer than {@value #MIN_REPLAY_SIZE} transitions (no training is done)
     */
    public double train(int epochs) throws Exception {
        if (epochs <= 0) {
            return -1.0;
        }
        this.lastErr = 0.0;
        // NOTE(review): meaning of model.t is defined in NeuralNetwork (likely an optimizer step
        // counter); it is forced to be non-zero before training — confirm.
        if (this.model.t == 0) {
            this.model.t = 2;
        }
        if (this.replaySize < MIN_REPLAY_SIZE) {
            return 0.1;
        }
        double lossTotal = 0.0;
        for (int step = 0; step < epochs; ++step) {
            int idx = this.samplePrioritizedIndex();
            Replay r_t = this.replayBuffer[idx];
            Replay r_t1 = this.successorOf(idx);

            this.graph.setApplyingBackprop(false);
            double target = this.getTargetFor(r_t, r_t1);
            this.graph.setApplyingBackprop(true);

            Matrix out = this.model.forward(r_t.getObservation(), this.graph).matrices[0];
            Matrix targetOutput = out.clone();
            // Only the Q-value of the action actually taken receives a training signal.
            targetOutput.w[r_t.getAction()] = target;
            double error = this.lossToUse.measure(out, targetOutput);
            lossTotal += error;

            this.prioritySum -= r_t.getPriority();
            // Same epsilon form as feedback(): a zero error must not make the entry unsampleable.
            r_t.updatePriority(Math.pow(Math.abs(error) + PRIORITY_EPSILON, this.alpha));
            this.prioritySum += r_t.getPriority();

            this.lossToUse.backward(out, targetOutput);
        }
        this.graph.backward();
        this.optimizer.updateParameters(this.model, this.learningRate, epochs);
        // Clamp so large batches never extrapolate the target network past the online one.
        this.softUpdateTarget(Math.min(1.0, this.tau * (double)epochs));
        this.graph.setApplyingBackprop(false);
        this.lastErr = lossTotal / (double)epochs;
        return this.lastErr;
    }

    /**
     * Samples a buffer index with probability proportional to its priority.
     *
     * @return index in {@code [0, replaySize)}; falls back to the last entry if the incrementally
     *     maintained {@code prioritySum} has drifted above the true total
     */
    private int samplePrioritizedIndex() {
        double threshold = this.random.nextDouble() * this.prioritySum;
        double cumulative = 0.0;
        for (int i = 0; i < this.replaySize; ++i) {
            cumulative += this.replayBuffer[i].getPriority();
            if (cumulative >= threshold) {
                return i;
            }
        }
        return this.replaySize - 1;
    }

    /**
     * Returns the transition recorded immediately after the one at {@code idx}, or {@code null}
     * if the next slot is outside the stored range or holds a non-consecutive transition.
     */
    private Replay successorOf(int idx) {
        int nextIdx = idx + 1;
        if (nextIdx >= this.replaySize) {
            return null;
        }
        Replay next = this.replayBuffer[nextIdx];
        return next.getNum() == this.replayBuffer[idx].getNum() + 1 ? next : null;
    }

    /** Polyak update: {@code target = rate * online + (1 - rate) * target}, element-wise. */
    private void softUpdateTarget(double rate) {
        List<Matrix> targetParams = this.targetModel.getParameters();
        List<Matrix> mainParams = this.model.getParameters();
        for (int i = 0; i < targetParams.size(); ++i) {
            double[] targetW = targetParams.get(i).w;
            double[] mainW = mainParams.get(i).w;
            for (int j = 0; j < targetW.length; ++j) {
                targetW[j] = rate * mainW[j] + (1.0 - rate) * targetW[j];
            }
        }
    }

    /**
     * Computes the Double-DQN target for {@code r_t}.
     *
     * <p>The result is {@code reward + discount * Q_target(s', argmax_a Q_online(s', a))}. When
     * {@code r_t} ended the episode or its successor is unavailable, the result is just the
     * reward.
     */
    private double getTargetFor(Replay r_t, Replay r_t1) throws Exception {
        double target = r_t.getReward();
        if (!r_t.isGameOver() && r_t1 != null) {
            Matrix res = this.targetModel.forward(r_t1.getObservation(), this.graph).matrices[0];
            int action = this.argmax(this.model.forward(r_t1.getObservation(), this.graph).matrices[0].w);
            target += this.discount * res.w[action];
        }
        return target;
    }

    /**
     * Returns the index of the first maximum in {@code arr}.
     *
     * @return that index, or 0 for an empty array or an array of NaN / -Infinity values
     */
    private int argmax(double[] arr) {
        int best = 0;
        double max = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < arr.length; ++i) {
            if (arr[i] > max) {
                max = arr[i];
                best = i;
            }
        }
        return best;
    }

    /**
     * Writes both networks, the hyper-parameters and the replay buffer to {@code f}.
     *
     * <p>The field order must match {@link #load(File)}.
     */
    public void save(File f) throws Exception {
        try (FileOutputStream fos = new FileOutputStream(f);
             DataOutputStream dos = new DataOutputStream(new FileChannelOutputStream(fos.getChannel()))) {
            this.model.saveState(dos);
            this.targetModel.saveState(dos);
            dos.writeInt(this.actionDim);
            dos.writeDouble(this.learningRate);
            dos.writeDouble(this.discount);
            dos.writeDouble(this.explore);
            dos.writeDouble(this.prioritySum);
            dos.writeDouble(this.alpha);
            dos.writeDouble(this.tau);
            dos.writeInt(this.replayPos);
            dos.writeInt(this.replaySize);
            for (int i = 0; i < this.replaySize; ++i) {
                this.replayBuffer[i].save(dos);
            }
            dos.writeInt(this.replayCntr);
        }
    }

    /**
     * Restores state written by {@link #save(File)}.
     *
     * <p>The buffer grows if the file holds more transitions than it can fit. Entries beyond the
     * loaded range are cleared, because their priorities are not part of the loaded
     * {@code prioritySum}.
     */
    public void load(File f) throws Exception {
        try (FileInputStream fis = new FileInputStream(f);
             DataInputStream dis = new DataInputStream(new FileChannelInputStream(fis.getChannel()))) {
            this.model.loadState(dis);
            this.targetModel.loadState(dis);
            this.actionDim = dis.readInt();
            this.learningRate = dis.readDouble();
            this.discount = dis.readDouble();
            this.explore = dis.readDouble();
            this.prioritySum = dis.readDouble();
            this.alpha = dis.readDouble();
            this.tau = dis.readDouble();
            this.replayPos = dis.readInt();
            this.replaySize = dis.readInt();
            if (this.replaySize > this.replayBuffer.length) {
                this.replayBuffer = new Replay[this.replaySize];
            }
            for (int i = 0; i < this.replaySize; ++i) {
                this.replayBuffer[i] = new Replay(null, 0, 0.0, false, 0.0, 0);
                this.replayBuffer[i].load(dis);
            }
            this.replayCntr = dis.readInt();
        }
        for (int i = this.replaySize; i < this.replayBuffer.length; ++i) {
            this.replayBuffer[i] = null;
        }
        if (this.replayPos >= this.replayBuffer.length) {
            this.replayPos = 0;
        }
    }

    /** @return learning rate passed to the optimizer */
    public double getLearningRate() {
        return this.learningRate;
    }

    /** @param newLR learning rate passed to the optimizer on subsequent {@link #train(int)} calls */
    public void setLearningRate(double newLR) {
        this.learningRate = newLR;
    }

    /** @return probability of a random exploratory action in {@link #think(Tensor)} */
    public double getExplorationRate() {
        return this.explore;
    }

    /** @param newExplore probability of a random exploratory action in {@link #think(Tensor)} */
    public void setExplorationRate(double newExplore) {
        this.explore = newExplore;
    }

    /** @return reward discount factor */
    public double getDiscount() {
        return this.discount;
    }

    /** @param newDiscount reward discount factor */
    public void setDiscount(double newDiscount) {
        this.discount = newDiscount;
    }

    /** @return exponent applied when turning rewards/errors into replay priorities */
    public double getAlpha() {
        return this.alpha;
    }

    /** @param alpha exponent applied when turning rewards/errors into replay priorities */
    public void setAlpha(double alpha) {
        this.alpha = alpha;
    }

    /** @return per-sample soft target-update rate (multiplied by epochs, capped at 1) */
    public double getTau() {
        return this.tau;
    }

    /** @param tau per-sample soft target-update rate (multiplied by epochs, capped at 1) */
    public void setTau(double tau) {
        this.tau = tau;
    }

    /** @return mean loss of the last {@link #train(int)} call that reached the update step */
    public double getLastErr() {
        return this.lastErr;
    }

    /** Delegates to {@code graph.cleanUp()}. */
    public void cleanUpGraphs() {
        this.graph.cleanUp();
    }

    /** @return number of consecutive calls an exploratory random action is held for */
    public int getRandomActFrames() {
        return this.randomActFrames;
    }

    /** @param randomActFrames number of consecutive calls an exploratory random action is held for */
    public void setRandomActFrames(int randomActFrames) {
        this.randomActFrames = randomActFrames;
    }

    /**
     * One stored transition plus its sampling priority.
     *
     * <p>{@code num} is a monotonically increasing insertion counter, used to detect whether two
     * adjacent buffer slots hold consecutive transitions.
     */
    private static class Replay {
        private Tensor observations;
        private int action;
        private double reward;
        private boolean gameOver;
        private double priority;
        private int num;

        private Replay(Tensor observations, int action, double reward, boolean gameOver, double priority, int num) {
            this.observations = observations;
            this.action = action;
            this.reward = reward;
            this.gameOver = gameOver;
            this.priority = priority;
            this.num = num;
        }

        public Tensor getObservation() {
            return this.observations;
        }

        public int getAction() {
            return this.action;
        }

        public double getReward() {
            return this.reward;
        }

        public boolean isGameOver() {
            return this.gameOver;
        }

        public double getPriority() {
            return this.priority;
        }

        public int getNum() {
            return this.num;
        }

        public void updatePriority(double newPriority) {
            this.priority = newPriority;
        }

        /** Serializes this transition; the field order must match {@link #load(DataInputStream)}. */
        public void save(DataOutputStream dos) throws Exception {
            this.observations.save(dos);
            dos.writeInt(this.action);
            dos.writeDouble(this.reward);
            dos.writeBoolean(this.gameOver);
            dos.writeDouble(this.priority);
            dos.writeInt(this.num);
        }

        /**
         * Reads a transition written by {@link #save(DataOutputStream)}.
         *
         * <p>It allocates a placeholder tensor if none exists; presumably {@code Tensor.load}
         * resizes it — confirm.
         */
        public void load(DataInputStream dis) throws Exception {
            if (this.observations == null) {
                this.observations = new Tensor(1, 1, 1);
            }
            this.observations.load(dis);
            this.action = dis.readInt();
            this.reward = dis.readDouble();
            this.gameOver = dis.readBoolean();
            this.priority = dis.readDouble();
            this.num = dis.readInt();
        }
    }
}

