/*
 * Decompiled with CFR 0.152.
 */
package theGhastModding.lstmStuff.main;

import datasets.TextGenerationUnbroken;
import datastructs.DataSet;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStream;
import java.net.HttpURLConnection;
import java.net.URL;
import java.time.LocalDateTime;
import java.util.Random;
import java.util.regex.Pattern;
import model.Dropout;
import model.FeedForwardLayer;
import model.LstmLayer;
import model.NeuralNetwork;
import nonlinearities.ReLuUnit;
import trainer.AMSGrad;
import trainer.Adam;
import trainer.Trainer;
import util.FileIO;

public class TextLSTM {
    /** Trainer used by {@link #train}; created in {@link #main} (Adam or AMSGrad optimizer). */
    private static Trainer trainer;

    /** Serialized network: written/updated by training, loaded by generation. */
    private static final String SAVE_PATH = "voyager.ser";
    /** Plain-text corpus written by "convert" mode and read by training and generation. */
    private static final String CORPUS_PATH = "voyager.txt";
    /** One transcript URL per line; input to "convert" mode. */
    private static final String EPISODE_LIST_PATH = "episodes.txt";
    /** Receives one loss value per training iteration, one per line. */
    private static final String TRAIN_LOG_PATH = "train log.txt";
    /** Line written after each downloaded transcript in the corpus. */
    private static final String EPISODE_SEPARATOR = "@@@";

    /** Number of calls to {@link #train} made by {@link #main}. */
    private static final int TRAINING_ITERATIONS = 15;
    /** Epoch count passed to each {@link #train} call. */
    private static final int EPOCHS_PER_ITERATION = 2;
    /** Exception message on which {@link #main} divides the learning rate by 10. */
    private static final String INVALID_LOSS_MESSAGE =
            "WARNING: invalid value for training loss. Try lowering learning rate.";

    /** A run of one or more HTML tags (optionally whitespace-separated); replaced by one space. */
    private static final Pattern HTML_TAG_RUN = Pattern.compile("(?s)<[^>]*>(\\s*<[^>]*>)*");

    /**
     * Entry point.
     *
     * <ul>
     *   <li>{@code convert}: downloads every URL listed in {@code episodes.txt}, strips HTML tags,
     *       writes the text to {@code voyager.txt} and exits (status 0 on success, 1 on error).</li>
     *   <li>{@code Generate}: prints 1500 generated characters from the saved model and exits.</li>
     *   <li>otherwise: trains for {@value #TRAINING_ITERATIONS} iterations (with the Adam optimizer
     *       if the first argument is {@code Adam}, AMSGrad otherwise), logging each iteration's loss,
     *       then prints 1000 generated characters.</li>
     * </ul>
     *
     * @param args optional mode selector in {@code args[0]} (case-insensitive)
     */
    public static void main(String[] args) {
        String mode = args.length > 0 ? args[0] : "";
        if (mode.equalsIgnoreCase("convert")) {
            System.exit(convertTranscripts() ? 0 : 1);
        }
        // Only the selected optimizer is constructed.
        trainer = mode.equalsIgnoreCase("Adam")
                ? new Trainer(new Adam(0.95, 0.999, 1.0E-7))
                : new Trainer(new AMSGrad(0.99, 0.999, 1.0E-7));
        // Generation is handled before the log file is opened: FileWriter truncates the file, so
        // opening it first would wipe the existing training log on every "Generate" run.
        if (mode.equalsIgnoreCase("Generate")) {
            try {
                System.out.println(TextLSTM.generate(1500, new File(CORPUS_PATH), new Random(), 0.8));
            } catch (Exception e) {
                System.err.println("Error generating output: ");
                e.printStackTrace();
                System.exit(1);
            }
            System.exit(0);
        }

        BufferedWriter logWriter = openLogWriter(); // null => logging disabled
        double learningRate = 0.001;
        for (int iteration = 0; iteration < TRAINING_ITERATIONS; iteration++) {
            System.out.println(iteration + " training iterations have passed (That's "
                    + iteration * EPOCHS_PER_ITERATION + " epochs in total)");
            System.out.println("Time is: " + LocalDateTime.now().toString());
            // NaN marks an iteration whose training failed, instead of re-logging an older loss.
            double currLoss = Double.NaN;
            try {
                currLoss = TextLSTM.train(learningRate, EPOCHS_PER_ITERATION, new File(CORPUS_PATH), new Random());
            } catch (Exception e) {
                if (INVALID_LOSS_MESSAGE.equals(e.getMessage())) {
                    learningRate /= 10.0;
                }
                e.printStackTrace();
            }
            if (logWriter != null) {
                logLoss(logWriter, currLoss, iteration == 0);
            }
        }
        if (logWriter != null) {
            try {
                logWriter.close();
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
        try {
            System.out.println(TextLSTM.generate(1000, new File(CORPUS_PATH), new Random(), 0.64));
        } catch (Exception e) {
            e.printStackTrace();
        }
        trainer.dispose();
        System.exit(0);
    }

    /**
     * Downloads every transcript URL listed in {@code episodes.txt} and writes the tag-stripped text
     * to {@code voyager.txt}, each page followed by a {@code @@@} separator line. Blank lines in the
     * list are skipped.
     *
     * @return {@code true} on success, {@code false} if any I/O error occurred (already reported)
     */
    private static boolean convertTranscripts() {
        // The reader is opened first so a missing episode list does not truncate the corpus.
        try (BufferedReader br = new BufferedReader(new FileReader(EPISODE_LIST_PATH));
             BufferedWriter bw = new BufferedWriter(new FileWriter(new File(CORPUS_PATH)))) {
            String link;
            while ((link = br.readLine()) != null) {
                System.err.println(link);
                String trimmed = link.trim();
                if (trimmed.isEmpty()) {
                    continue;
                }
                String page = downloadPage(trimmed);
                bw.write(HTML_TAG_RUN.matcher(page).replaceAll(" "));
                bw.newLine();
                bw.write(EPISODE_SEPARATOR);
                bw.newLine();
            }
            System.err.println("Done.");
            return true;
        } catch (Exception e) {
            System.err.println("Error downloading episode transcripts: ");
            e.printStackTrace();
            return false;
        }
    }

    /**
     * Fetches the full body of {@code link}, reading until end-of-stream.
     *
     * @param link absolute URL to download
     * @return the response body decoded with the platform default charset (the same charset
     *     {@link FileWriter} uses when the corpus is written)
     * @throws IOException if the connection or read fails
     */
    private static String downloadPage(String link) throws IOException {
        HttpURLConnection con = (HttpURLConnection) new URL(link).openConnection();
        try (InputStream in = con.getInputStream()) {
            ByteArrayOutputStream body = new ByteArrayOutputStream();
            byte[] buffer = new byte[8192];
            int read;
            // read() == -1 is the only reliable end-of-stream signal; available() is not.
            while ((read = in.read(buffer)) != -1) {
                body.write(buffer, 0, read); // only the bytes actually read this round
            }
            // Decode once at the end so multi-byte characters are never split across chunks.
            return body.toString();
        } finally {
            con.disconnect();
        }
    }

    /**
     * Opens {@code train log.txt} for writing (truncating any existing content).
     *
     * @return the writer, or {@code null} if it could not be created (logging is then disabled)
     */
    private static BufferedWriter openLogWriter() {
        try {
            return new BufferedWriter(new FileWriter(new File(TRAIN_LOG_PATH)));
        } catch (Exception e) {
            System.err.println("Error creating logger (disabling logging): ");
            e.printStackTrace();
            return null;
        }
    }

    /**
     * Appends one loss value to the training log and flushes it; errors are reported, not thrown.
     *
     * @param logWriter open log writer
     * @param loss      value to write
     * @param first     {@code true} for the first entry (no preceding line break)
     */
    private static void logLoss(BufferedWriter logWriter, double loss, boolean first) {
        try {
            if (!first) {
                logWriter.newLine();
            }
            logWriter.write(Double.toString(loss));
            logWriter.flush();
        } catch (Exception e) {
            System.err.println("Error logging message: ");
            e.printStackTrace();
        }
    }

    /**
     * Builds the dataset and model from {@code f}, then runs the shared trainer for {@code epochs}
     * epochs. If {@code voyager.ser} already exists, that fact is passed to the trainer
     * (presumably so it continues from the saved model; verify against {@code Trainer.train}).
     *
     * @param learningRate optimizer learning rate
     * @param epochs       number of epochs to run
     * @param f            text corpus file
     * @param random       RNG used for the dataset and model construction
     * @return the loss value returned by the trainer
     * @throws IllegalStateException if called before {@link #main} has created the trainer
     * @throws Exception             anything thrown by dataset loading or training
     */
    public static double train(double learningRate, int epochs, File f, Random random) throws Exception {
        if (trainer == null) {
            throw new IllegalStateException("Trainer not initialised; it is created in TextLSTM.main");
        }
        TextGenerationUnbroken data = loadDataSet(f, random);
        System.gc();
        TextGenerationUnbroken.reportPerplexity = false;
        TextGenerationUnbroken.reportSequenceLength = 100;
        NeuralNetwork lstm = TextLSTM.getModel(data, random);
        int reportEveryNthEpoch = 3;
        boolean saveFileExists = new File(SAVE_PATH).exists();
        return trainer.train(lstm, learningRate, epochs, data, reportEveryNthEpoch, SAVE_PATH,
                saveFileExists, true, new Random());
    }

    /**
     * Rebuilds the model for the corpus in {@code f}, loads the weights saved in
     * {@code voyager.ser}, and samples text from it.
     *
     * @param stringLength number of characters to generate
     * @param f            text corpus file (its path is printed first)
     * @param random       RNG used for dataset/model construction and sampling
     * @param temperature  sampling temperature passed to the dataset's generator
     * @return the generated text
     * @throws Exception if loading the corpus or the saved network fails
     */
    public static String generate(int stringLength, File f, Random random, double temperature) throws Exception {
        System.out.println(f.getPath());
        System.gc();
        TextGenerationUnbroken data = loadDataSet(f, random);
        NeuralNetwork lstm = TextLSTM.getModel(data, random);
        FileIO.loadNeuralNetwork(SAVE_PATH, lstm);
        return data.generateText(lstm, stringLength, false, temperature, random);
    }

    /**
     * Creates the dataset over {@code f}. Both training and generation go through here so the
     * model dimensions always match the saved network.
     */
    private static TextGenerationUnbroken loadDataSet(File f, Random random) throws Exception {
        // Constructor arguments are unchanged from the original; their meaning is defined by
        // TextGenerationUnbroken.
        return new TextGenerationUnbroken(f.getPath(), 15, 250, 1000, true, random);
    }

    /**
     * Builds the network: feed-forward (ReLU) -> dropout -> LSTM -> dropout -> LSTM -> dropout ->
     * feed-forward, with layer widths 128 -> 350 -> 128 between the dataset's input and output
     * dimensions. The output layer uses the unit type chosen by the dataset.
     *
     * @param data   dataset supplying input/output dimensions and the output unit
     * @param random RNG passed to each weighted layer's constructor
     * @return the untrained network
     */
    public static NeuralNetwork getModel(DataSet data, Random random) {
        NeuralNetwork lstm = new NeuralNetwork();
        lstm.addLayer(new FeedForwardLayer(data.inputDimension.getHeight(), 128, new ReLuUnit(), 0.08, random));
        lstm.addLayer(new Dropout(0.1));
        lstm.addLayer(new LstmLayer(128, 350, 0.08, random));
        lstm.addLayer(new Dropout(0.1));
        lstm.addLayer(new LstmLayer(350, 128, 0.08, random));
        lstm.addLayer(new Dropout(0.1));
        lstm.addLayer(new FeedForwardLayer(128, data.outputDimension.getHeight(), data.getModelOutputUnitToUse(), 0.08, random));
        return lstm;
    }
}

