/*
 * Decompiled with CFR 0.152.
 */
package theGhastModding.lstmStuff.main;

import autodiff.Graph;
import java.io.BufferedReader;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.FilenameFilter;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import loss.Loss;
import loss.LossArgMax;
import loss.LossSumOfSquares;
import matrix.Matrix;
import model.FeedForwardLayer;
import model.LstmLayer;
import model.NeuralNetwork;
import nonlinearities.ReLuUnit;
import nonlinearities.TanhUnit;
import trainer.Adam;
import trainer.GradientNoise;
import trainer.TrainingMethod;
import util.FileIO;

public class CryptoAI {
    /** Price histories loaded by the constructor, one per non-gold CSV file. */
    private List<Cryptocurrency> allCryptos = new ArrayList<Cryptocurrency>();
    /** Gold price keyed by the date string of the gold CSV; filled by {@code loadGold}. */
    private Map<String, Double> datedGoldPrices = null;
    /** Network used by {@code play} and {@code train}; rebuilt on every {@code loadModel} call. */
    private NeuralNetwork model;
    /** File the network is loaded from and saved to. */
    private File savefile;
    /** Folder that receives one transition file per training play. */
    private File replayMemoryFolder;
    /** Probability with which {@code play} replaces the network's choice by a random action. */
    private double epsilon = 0.25;
    /** Number of consecutive data points traded in a single play. */
    private int tradingTime = 30;
    /** Cash balance every play starts with. */
    private double startingCaptial = 1000.0;
    /** Loss whose {@code backward} is invoked for every transition during training. */
    private static final Loss lossToUse = new LossSumOfSquares();
    /** Training method that applies the parameter updates in {@code train}. */
    private TrainingMethod method;
    /** Accepts every file whose name ends in ".csv", ignoring case. */
    static final FilenameFilter csvFilter = new FilenameFilter(){

        @Override
        public boolean accept(File dir, String name) {
            return name.toLowerCase().endsWith(".csv");
        }
    };

    /**
     * Loads the gold price table and every other CSV found in {@code folder}.
     *
     * @param folder       directory holding {@code Gold.csv} plus one CSV per cryptocurrency
     * @param savefile     file the network is loaded from / saved to
     * @param replayMemory directory for recorded transitions; created when missing
     * @param method       training method used by {@code train}
     * @throws Exception if {@code Gold.csv} is missing or any CSV cannot be read or parsed
     */
    public CryptoAI(File folder, File savefile, File replayMemory, TrainingMethod method) throws Exception {
        this.replayMemoryFolder = replayMemory;
        try {
            // Best effort only: a failure here is printed and construction carries on.
            if (!replayMemory.exists()) {
                replayMemory.mkdir();
            }
        }
        catch (Exception e) {
            e.printStackTrace();
        }

        File goldFile = new File(folder.getPath() + "/Gold.csv");
        if (!goldFile.exists()) {
            throw new Exception("CSV with historical Gold prices not found!");
        }
        // Gold prices must be available first: load(...) looks them up per row.
        this.loadGold(goldFile);

        for (File candidate : folder.listFiles(csvFilter)) {
            if (candidate.equals(goldFile)) {
                continue;
            }
            this.load(candidate);
        }

        this.savefile = savefile;
        this.method = method;
    }

    /**
     * Plays one trading episode of {@code tradingTime} steps on a randomly chosen
     * cryptocurrency, starting at a random position of its price series.
     *
     * <p>At every step the network output is reduced to an action with {@code argmax}
     * (0 = sell everything, 1 = buy with all cash, anything else = hold); with
     * probability {@code epsilon} a uniformly random action is taken instead. The
     * reward stored per step is the change of the scaled portfolio value since the
     * previous step.
     *
     * @param silent   when {@code true}, per-step output and the 100 ms pause between
     *                 steps are skipped
     * @param training when {@code true}, the recorded transitions are written to a new
     *                 file in the replay memory folder
     * @throws Exception if no data is loaded, the chosen series is too short for one
     *                   episode, or model/transition I/O fails
     */
    private void play(boolean silent, boolean training) throws Exception {
        double money = this.startingCaptial;
        double cryptoOwnings = 0.0;
        this.loadModel(this.savefile);
        this.model.resetState();
        Random rng = new Random();
        if (this.allCryptos.isEmpty()) {
            throw new Exception("No cryptocurrency data loaded, nothing to trade");
        }
        Cryptocurrency crypto = this.allCryptos.get(rng.nextInt(this.allCryptos.size()));
        // Random.nextInt rejects non-positive bounds; report the real cause instead.
        int startRange = crypto.numEntries() - this.tradingTime - 1;
        if (startRange <= 0) {
            throw new Exception("Not enough data for " + crypto.getName() + ": " + crypto.numEntries()
                    + " entries, need more than " + (this.tradingTime + 1));
        }
        int startPos = rng.nextInt(startRange);
        int pos = startPos;
        Matrix inBuffer = new Matrix(5);
        List<Transition> replayMemory = new ArrayList<Transition>();
        // Graph created with 'false': this pass is forward-only, g.backward() is never called here.
        Graph g = new Graph(false);
        System.out.println("Starting forward pass using " + crypto.getName() + "...");
        double lastReward = 0.0;
        while (true) {
            double currValue = crypto.getPrice(pos);
            double currCap = crypto.getMarketCap(pos);
            double currGoldValue = crypto.getGoldPrice(pos);
            // Same fixed input scaling as used when replaying transitions in train(...).
            inBuffer.w[0] = money / 1000.0;
            inBuffer.w[1] = cryptoOwnings / 100.0;
            inBuffer.w[2] = currValue / 10000.0;
            inBuffer.w[3] = currCap;
            inBuffer.w[4] = currGoldValue / 2000.0;
            Matrix outBuffer = this.model.forward(inBuffer, g);
            int selectedAction = this.argmax(outBuffer);
            int action = selectedAction;
            if (rng.nextDouble() < this.epsilon) {
                action = rng.nextInt(3);
            }
            // Balances before the action are what gets stored in the transition.
            double moneyPrev = money;
            double owningsPrev = cryptoOwnings;
            if (action == 0) {
                if (cryptoOwnings > 0.0) {
                    if (!silent) {
                        System.out.println("Action: sell");
                    }
                    money = cryptoOwnings * currValue;
                    cryptoOwnings = 0.0;
                }
            } else if (action == 1) {
                if (money > 0.0) {
                    if (!silent) {
                        System.out.println("Action: buy");
                    }
                    cryptoOwnings = money / currValue;
                    money = 0.0;
                }
            } else if (!silent) {
                System.out.println("Action: hold");
            }
            double total = money + cryptoOwnings * currValue;
            // Scaled profit since the start, minus what earlier steps were already credited.
            double reward = (total - this.startingCaptial) / 1000.0 - lastReward;
            lastReward += reward;
            if (!silent) {
                System.out.println("Reward=" + Double.toString(reward));
                System.out.println(total);
            }
            replayMemory.add(new Transition(currValue, moneyPrev, currCap, owningsPrev, currGoldValue, action, selectedAction, reward));
            if (++pos - startPos >= this.tradingTime) {
                break;
            }
            if (!silent) {
                Thread.sleep(100L);
            }
        }
        // NOTE(review): the summary values holdings at the LAST entry of the series, not at
        // the price where the episode ended - confirm this is intended.
        double lastPrice = crypto.getPrice(crypto.numEntries() - 1);
        double finalTotal = money + cryptoOwnings * lastPrice;
        // Sign prefix compares against the configured starting capital rather than a literal.
        System.out.println("Finished.\nFinal balances: " + Double.toString(money) + "USD and " + Double.toString(cryptoOwnings) + " " + crypto.getName() + " (" + Double.toString(cryptoOwnings * lastPrice) + "). Total: " + Double.toString(finalTotal) + "USD, " + (finalTotal >= this.startingCaptial ? "+" : "") + Double.toString(finalTotal / this.startingCaptial * 100.0 - 100.0));
        if (training) {
            this.saveTransitions(replayMemory, new File(String.valueOf(this.replayMemoryFolder.getPath()) + "/" + Long.toString(System.currentTimeMillis()) + ".dat"));
        }
        this.model.resetState();
    }

    /**
     * Returns the index of the largest entry of {@code m.w}.
     *
     * <p>Ties resolve to the lowest index. Index 0 is returned when the array is
     * empty or no entry compares greater than negative infinity (e.g. all NaN).
     *
     * @param m matrix whose {@code w} values are scanned
     * @return index of the maximum value
     */
    private int argmax(Matrix m) {
        // Start below every finite value so arbitrarily negative inputs are handled too.
        double largest = Double.NEGATIVE_INFINITY;
        int largestIndx = 0;
        for (int i = 0; i < m.w.length; i++) {
            if (m.w[i] > largest) {
                largest = m.w[i];
                largestIndx = i;
            }
        }
        return largestIndx;
    }

    /**
     * Returns the largest entry of {@code m.w}.
     *
     * @param m matrix whose {@code w} values are scanned
     * @return the maximum value, or {@link Double#NEGATIVE_INFINITY} when the array is
     *         empty or contains no value greater than negative infinity
     */
    private double max(Matrix m) {
        // Start below every finite value so the result is a real entry, not a sentinel.
        double largest = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < m.w.length; i++) {
            if (m.w[i] > largest) {
                largest = m.w[i];
            }
        }
        return largest;
    }

    /**
     * Trains the network on {@code iterations} randomly drawn replay files and saves it.
     *
     * <p>For each drawn file the recorded transitions are replayed in order. The
     * target for a step equals the network's own output, except at the index of the
     * action that was actually taken, which is set to the recorded reward plus
     * 0.5 times the maximum output for the following transition (no such term for the
     * last transition). Gradients are accumulated over the whole file and applied once.
     *
     * @param iterations   number of replay files to draw (with replacement)
     * @param learningRate learning rate handed to the training method
     * @return the summed {@code LossArgMax} measure divided by the number of processed
     *         transitions; NaN when {@code iterations} is not positive
     * @throws Exception if the replay folder has no files, a replay file holds no
     *                   transitions, or model/transition I/O fails
     */
    private double train(int iterations, double learningRate) throws Exception {
        LossArgMax testLoss = new LossArgMax();
        this.loadModel(this.savefile);
        double numLoss = 0.0;
        double denomLoss = 0.0;
        File[] replays = this.replayMemoryFolder.listFiles();
        // listFiles() yields null for a missing/unreadable folder; an empty array would
        // make Random.nextInt(0) throw an unrelated-looking exception.
        if (replays == null || replays.length == 0) {
            throw new Exception("No replay files found in " + this.replayMemoryFolder.getPath());
        }
        Random rng = new Random();
        int i = 0;
        while (i < iterations) {
            File replay = replays[rng.nextInt(replays.length)];
            List<Transition> transitions = this.loadTransitions(replay);
            if (transitions.isEmpty()) {
                throw new Exception("Replay file contains no transitions: " + replay.getPath());
            }
            Graph g = new Graph(true);
            this.model.resetState();
            Matrix outBuffer = this.model.forward(this.transitionToInput(transitions.get(0)), g);
            int j = 0;
            while (j < transitions.size()) {
                Transition current = transitions.get(j);
                int usedAction = current.getUsedAction();
                // Target = current prediction, overridden only for the action actually taken.
                Matrix targetOutput = new Matrix(3);
                System.arraycopy(outBuffer.w, 0, targetOutput.w, 0, targetOutput.w.length);
                targetOutput.w[usedAction] = current.getReward();
                System.out.println(outBuffer);
                Matrix newOutBuffer = null;
                if (j != transitions.size() - 1) {
                    newOutBuffer = this.model.forward(this.transitionToInput(transitions.get(j + 1)), g);
                    // Add 0.5 * the best predicted value for the next transition.
                    targetOutput.w[usedAction] = targetOutput.w[usedAction] + 0.5 * this.max(newOutBuffer);
                }
                numLoss += testLoss.measure(outBuffer, targetOutput);
                denomLoss += 1.0;
                lossToUse.backward(outBuffer, targetOutput);
                outBuffer = newOutBuffer;
                ++j;
            }
            ++this.model.t;
            g.backward();
            this.method.updateParameters(this.model, learningRate, transitions.size());
            ++i;
        }
        this.model.resetState();
        this.saveModel(this.savefile);
        return numLoss / denomLoss;
    }

    /**
     * Builds a fresh 5-element network input from a recorded transition, using the same
     * scaling constants as {@code play}.
     *
     * <p>A new matrix is allocated per call on purpose: the graph in {@code train} is
     * created with backprop enabled and g.backward() runs only after all steps were
     * fed forward, so an input overwritten in place could be read with the wrong values
     * (NOTE(review): assumes the graph keeps references to its operands - confirm
     * against autodiff.Graph).
     */
    private Matrix transitionToInput(Transition t) {
        Matrix input = new Matrix(5);
        input.w[0] = t.getCurrMoney() / 1000.0;
        input.w[1] = t.getCurrCryptoOwnings() / 100.0;
        input.w[2] = t.getCurrValue() / 10000.0;
        input.w[3] = t.getCurrMarketCap();
        input.w[4] = t.getCurrGoldValue() / 2000.0;
        return input;
    }

    /**
     * Alternates between recording one silent training play and training on the
     * replay memory, {@code plays} times.
     *
     * @param learningRate learning rate forwarded to {@code train}
     * @param plays        number of play/train rounds
     * @param epochs       iteration count forwarded to {@code train} after every play
     * @return the loss reported by the last round, or 0.0 when no round ran
     * @throws Exception with message "NaN loss" when a round reports NaN, or whatever
     *                   {@code play}/{@code train} throw
     */
    public double autoTrain(double learningRate, int plays, int epochs) throws Exception {
        double latestLoss = 0.0;
        for (int round = 0; round < plays; round++) {
            System.out.println((round + 1) + "/" + plays);
            this.play(true, true);
            latestLoss = this.train(epochs, learningRate);
            if (Double.isNaN(latestLoss)) {
                throw new Exception("NaN loss");
            }
            System.out.println("Loss: " + latestLoss);
        }
        return latestLoss;
    }

    /**
     * Replaces {@code model} with a freshly built network (feed-forward 5 to 64 with ReLU,
     * LSTM 64 to 32, LSTM 32 to 16, feed-forward 16 to 3 with tanh) and, when {@code f}
     * exists, loads its stored contents into that network.
     *
     * @param f file to load the network from; skipped when it does not exist
     * @throws Exception if building or loading the network fails
     */
    private void loadModel(File f) throws Exception {
        // Value shared by all four layer constructors (presumably the initialisation
        // scale - confirm against the layer classes).
        final double layerInit = 0.08;
        Random rng = new Random();
        NeuralNetwork net = new NeuralNetwork();
        // Publish the new network before adding layers, exactly like a direct field build-up.
        this.model = net;
        net.addLayer(new FeedForwardLayer(5, 64, new ReLuUnit(), layerInit, rng));
        net.addLayer(new LstmLayer(64, 32, layerInit, rng));
        net.addLayer(new LstmLayer(32, 16, layerInit, rng));
        net.addLayer(new FeedForwardLayer(16, 3, new TanhUnit(), layerInit, rng));
        if (!f.exists()) {
            return;
        }
        FileIO.loadNeuralNetwork(f.getPath(), net);
    }

    /**
     * Writes the current {@code model} to the path of {@code f} via {@code FileIO}.
     *
     * @param f destination file
     * @throws Exception if saving fails
     */
    private void saveModel(File f) throws Exception {
        String target = f.getPath();
        FileIO.saveNeuralNetwork(target, this.model);
    }

    /**
     * Reads the gold price CSV into {@code datedGoldPrices}.
     *
     * <p>The first line is skipped. For every following line, column 0 is used as the
     * key and column 2 as the price; when column 2 is absent or empty, column 1 is
     * used instead. Reading stops at the first empty line or at end of file.
     *
     * @param csv gold price file
     * @throws Exception if the file cannot be read or a price cannot be parsed
     */
    private void loadGold(File csv) throws Exception {
        this.datedGoldPrices = new HashMap<String, Double>();
        BufferedReader reader = new BufferedReader(new FileReader(csv));
        try {
            reader.readLine(); // first line is discarded
            String l = reader.readLine();
            while (l != null && !l.isEmpty()) {
                String[] ls = l.split(",");
                String date = ls[0];
                // split(",") drops trailing empty fields, so an empty last column shows
                // up as a shorter array rather than as an empty string.
                boolean hasThirdColumn = ls.length > 2 && !ls[2].isEmpty();
                double value = Double.parseDouble(hasThirdColumn ? ls[2] : ls[1]);
                this.datedGoldPrices.put(date, value);
                l = reader.readLine();
            }
        } finally {
            // Release the file handle even when a row fails to parse.
            reader.close();
        }
    }

    /**
     * Reads one cryptocurrency CSV and appends it to {@code allCryptos}.
     *
     * <p>The name is the file name up to its first dot. The first line is skipped; per
     * row, column 0 is the date (its part before the first '-' is parsed as the year),
     * column 3 the market cap and column 4 the price. Only rows with year 2017 or later
     * are stored, each paired with the gold price of the same date or, when that date is
     * unknown, the gold price used for the previously stored row (0.0 initially).
     * Stored market caps are finally divided by the largest market cap seen in ANY row
     * of the file, including rows that were not stored.
     *
     * @param csv price history file
     * @throws Exception if the file cannot be read or a field cannot be parsed
     */
    private void load(File csv) throws Exception {
        Cryptocurrency crypto = new Cryptocurrency(csv.getName().split("\\.")[0]);
        double largestCap = 0.0;
        BufferedReader reader = new BufferedReader(new FileReader(csv));
        try {
            reader.readLine(); // first line is discarded
            String l = reader.readLine();
            double previousGoldPrice = 0.0;
            while (l != null && !l.isEmpty()) {
                String[] ls = l.split(",");
                double value = Double.parseDouble(ls[4]);
                double marketCap = Double.parseDouble(ls[3]);
                if (marketCap > largestCap) {
                    largestCap = marketCap;
                }
                String date = ls[0];
                int year = Integer.parseInt(date.split("-")[0]);
                if (year >= 2017) {
                    Double datedGold = this.datedGoldPrices.get(date);
                    double goldPrice = datedGold == null ? previousGoldPrice : datedGold.doubleValue();
                    crypto.addData(value, marketCap, goldPrice);
                    previousGoldPrice = goldPrice;
                }
                l = reader.readLine();
            }
        } finally {
            // Release the file handle even when a row fails to parse.
            reader.close();
        }
        // Without a positive market cap the division would yield NaN/Infinity inputs;
        // leave the raw values untouched in that case.
        if (largestCap > 0.0) {
            for (int i = 0; i < crypto.numEntries(); i++) {
                crypto.marketCaps.set(i, crypto.marketCaps.get(i) / largestCap);
            }
        }
        this.allCryptos.add(crypto);
    }

    /**
     * Writes the transitions to {@code f} in binary form: an int count followed, per
     * transition, by value, money, market cap, crypto owned and gold value (doubles),
     * used and selected action (ints) and reward (double).
     *
     * @param transitions transitions to store, written in list order
     * @param f           destination file; overwritten when it exists
     * @throws Exception if the file cannot be opened or written
     */
    private void saveTransitions(List<Transition> transitions, File f) throws Exception {
        DataOutputStream dos = new DataOutputStream(new FileOutputStream(f));
        try {
            dos.writeInt(transitions.size());
            for (Transition t : transitions) {
                dos.writeDouble(t.getCurrValue());
                dos.writeDouble(t.getCurrMoney());
                dos.writeDouble(t.getCurrMarketCap());
                dos.writeDouble(t.getCurrCryptoOwnings());
                dos.writeDouble(t.getCurrGoldValue());
                dos.writeInt(t.getUsedAction());
                dos.writeInt(t.getSelectedAction());
                dos.writeDouble(t.getReward());
            }
        } finally {
            // Close the stream even when a write fails so the file handle is not leaked.
            dos.close();
        }
    }

    /**
     * Reads a transition file in the layout produced by {@code saveTransitions}: an int
     * count followed by that many records.
     *
     * @param f file to read
     * @return the transitions in file order; empty when the stored count is not positive
     * @throws Exception if the file cannot be opened or ends before all records are read
     */
    private List<Transition> loadTransitions(File f) throws Exception {
        ArrayList<Transition> transitions = new ArrayList<Transition>();
        DataInputStream dis = new DataInputStream(new FileInputStream(f));
        try {
            int count = dis.readInt();
            for (int i = 0; i < count; i++) {
                // Field order must mirror saveTransitions exactly.
                double currValue = dis.readDouble();
                double currMoney = dis.readDouble();
                double currMarketCap = dis.readDouble();
                double currCryptoOwnings = dis.readDouble();
                double currGoldValue = dis.readDouble();
                int usedAction = dis.readInt();
                int selectedAction = dis.readInt();
                double reward = dis.readDouble();
                transitions.add(new Transition(currValue, currMoney, currMarketCap, currCryptoOwnings, currGoldValue, usedAction, selectedAction, reward));
            }
        } finally {
            // Close the stream even when the file is truncated or unreadable.
            dis.close();
        }
        return transitions;
    }

    /**
     * Entry point: loads the CSV data, trains for roughly five minutes of wall-clock
     * time with epsilon 0.7, then runs one verbose play with epsilon 0.0 and prints the
     * last and best loss. Any failure is reported on stderr and exits with status 1.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        final int playsPerRound = 5;
        final int epochsPerPlay = 15;
        final long trainingBudgetMillis = 300000L;

        CryptoAI ai = null;
        try {
            File dataFolder = new File("cryptos/");
            File modelFile = new File("D:/crypto.ser");
            File replayFolder = new File("D:/crypto/");
            TrainingMethod trainer = new GradientNoise(new Adam(0.9, 0.999, 1.0E-8));
            ai = new CryptoAI(dataFolder, modelFile, replayFolder, trainer);
        }
        catch (Exception e) {
            System.err.println("Error loading csv files: ");
            e.printStackTrace();
            System.exit(1);
        }

        for (Cryptocurrency loaded : ai.allCryptos) {
            System.out.println(loaded.getName());
        }

        double loss = 0.0;
        double bestLoss = Double.POSITIVE_INFINITY;
        try {
            ai.epsilon = 0.7;
            int totalPlays = 0;
            long startTime = System.currentTimeMillis();
            // The budget is only checked between rounds, so the last round may overrun it.
            while (System.currentTimeMillis() - startTime <= trainingBudgetMillis) {
                loss = ai.autoTrain(0.001, playsPerRound, epochsPerPlay);
                totalPlays += playsPerRound;
                if (loss < bestLoss) {
                    bestLoss = loss;
                }
            }
            System.out.println("Total plays: " + totalPlays);
            System.out.println(Integer.toString(totalPlays * epochsPerPlay));
        }
        catch (Exception e) {
            System.err.println("Error training AI: ");
            e.printStackTrace();
            System.exit(1);
        }

        try {
            // Epsilon 0.0: the test play always follows the network's own choice.
            ai.epsilon = 0.0;
            ai.play(false, false);
        }
        catch (Exception e) {
            System.err.println("Error testing AI: ");
            e.printStackTrace();
            System.exit(1);
        }

        System.out.println("Last recorded loss was: " + loss);
        System.out.println("Best loss was: " + bestLoss);
    }

    /**
     * Named price history: three parallel lists (price, market cap, gold price) that
     * always grow together through {@link #addData}.
     */
    private class Cryptocurrency {
        private final String name;
        private final List<Double> prices = new ArrayList<Double>();
        private final List<Double> marketCaps = new ArrayList<Double>();
        private final List<Double> goldPrices = new ArrayList<Double>();

        /** Creates an empty history with the given display name. */
        public Cryptocurrency(String name) {
            this.name = name;
        }

        /** Appends one data point to all three lists. */
        private void addData(double price, double marketCap, double goldPrice) {
            this.goldPrices.add(goldPrice);
            this.marketCaps.add(marketCap);
            this.prices.add(price);
        }

        /** @return the name passed to the constructor */
        public String getName() {
            return this.name;
        }

        /** @return number of stored data points */
        public int numEntries() {
            return this.prices.size();
        }

        /** @return price stored at {@code index} */
        public double getPrice(int index) {
            return this.prices.get(index).doubleValue();
        }

        /** @return market cap stored at {@code index} */
        public double getMarketCap(int index) {
            return this.marketCaps.get(index).doubleValue();
        }

        /** @return gold price stored at {@code index} */
        public double getGoldPrice(int index) {
            return this.goldPrices.get(index).doubleValue();
        }
    }

    /**
     * Immutable record of a single step of a play: the values observed at that step,
     * the action the network selected, the action actually used, and the reward.
     */
    private class Transition {
        private final double currValue;
        private final double currMoney;
        private final double currMarketCap;
        private final double currCryptoOwnings;
        private final double currGoldValue;
        private final int usedAction;
        private final int selectedAction;
        private final double reward;

        /** Stores every argument unchanged in the field of the same name. */
        public Transition(double currValue, double currMoney, double currMarketCap, double currCryptoOwnings, double currGoldValue, int usedAction, int selectedAction, double reward) {
            this.reward = reward;
            this.selectedAction = selectedAction;
            this.usedAction = usedAction;
            this.currGoldValue = currGoldValue;
            this.currCryptoOwnings = currCryptoOwnings;
            this.currMarketCap = currMarketCap;
            this.currMoney = currMoney;
            this.currValue = currValue;
        }

        /** @return the recorded price */
        public double getCurrValue() {
            return this.currValue;
        }

        /** @return the recorded cash balance */
        public double getCurrMoney() {
            return this.currMoney;
        }

        /** @return the recorded market cap */
        public double getCurrMarketCap() {
            return this.currMarketCap;
        }

        /** @return the recorded amount of cryptocurrency held */
        public double getCurrCryptoOwnings() {
            return this.currCryptoOwnings;
        }

        /** @return the recorded gold price */
        public double getCurrGoldValue() {
            return this.currGoldValue;
        }

        /** @return the action that was actually carried out */
        public int getUsedAction() {
            return this.usedAction;
        }

        /** @return the action the network selected */
        public int getSelectedAction() {
            return this.selectedAction;
        }

        /** @return the recorded reward */
        public double getReward() {
            return this.reward;
        }
    }
}

