/*
 * Decompiled with CFR 0.152.
 */
package theGhastModding.lstmStuff.main;

import datasets.TextGenerationUnbroken;
import datastructs.DataSet;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.nio.file.Files;
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.Random;
import model.FeedForwardLayer;
import model.GruLayer;
import model.LstmLayer;
import model.Model;
import model.MultipathLayer;
import model.NeuralNetwork;
import trainer.Adam;
import trainer.NewTrainer;
import trainer.RMSProp;
import util.FileIO;

/**
 * Command-line driver that trains (or samples from) a GRU/LSTM text model on
 * {@code FiO/allOfThem.txt}, persisting the model to {@code FiO.ser}.
 *
 * <p>Usage: no argument trains with RMSProp, {@code Adam} trains with Adam, and
 * {@code Generate} loads the saved model and prints generated text.
 */
public class TextLSTM {
    /** Path of the serialized model shared by training, backups and generation. */
    private static final String SAVE_PATH = "FiO.ser";
    /** Corpus used for both training and generation. */
    private static final String TRAINING_TEXT = "FiO/allOfThem.txt";
    /** Name of the per-iteration loss log written during training. */
    private static final String LOG_PATH = "train log.txt";
    /** Number of outer training iterations run by {@link #main}. */
    private static final int TRAINING_ITERATIONS = 100;
    /** Epochs passed to {@link #train} per outer iteration. */
    private static final int EPOCHS_PER_ITERATION = 2;
    /** Exception message that triggers a tenfold learning-rate reduction. */
    private static final String INVALID_LOSS_MESSAGE =
            "WARNING: invalid value for training loss. Try lowering learning rate.";

    private static NewTrainer trainer;

    /**
     * Entry point. Either generates text from the saved model ({@code Generate})
     * or runs {@value #TRAINING_ITERATIONS} training iterations, logging each
     * iteration's loss and making an hourly backup of the model file.
     *
     * @param args optional mode: {@code Adam}, {@code Generate}, or nothing for RMSProp
     */
    public static void main(String[] args) {
        String mode = args.length > 0 ? args[0] : "";
        if (mode.equalsIgnoreCase("Adam")) {
            trainer = new NewTrainer(new Adam(0.9, 0.999, 0.1));
        } else {
            trainer = new NewTrainer(new RMSProp(0.999, 3.0, 0.1));
        }

        // Handle generation before opening the log so that sampling does not
        // truncate the existing training log file.
        if (mode.equalsIgnoreCase("Generate")) {
            try {
                System.out.println(TextLSTM.generate(15000, new File(TRAINING_TEXT), new Random(), 0.64));
                System.exit(0);
            } catch (Exception e) {
                System.err.println("Error generating output: ");
                e.printStackTrace();
                System.exit(1);
            }
        }

        BufferedWriter logWriter = openLog();
        double learningRate = 0.001;
        int currentHour = LocalDateTime.now().getHour();
        double lastLoss = Double.MAX_VALUE;
        double currLoss = Double.MAX_VALUE;
        long seed = -1701074527L;

        for (int iteration = 0; iteration < TRAINING_ITERATIONS; iteration++) {
            System.out.println(Integer.toString(iteration) + " training iterations have passed (That's "
                    + Integer.toString(iteration * EPOCHS_PER_ITERATION) + " epochs in total)");
            System.out.println("Time is: " + LocalDateTime.now().toString());

            int nowHour = LocalDateTime.now().getHour();
            // Use != rather than "now - current >= 1": the subtraction goes
            // negative when the clock wraps past midnight (0 - 23), which would
            // stop every subsequent hourly backup.
            if (nowHour != currentHour) {
                seed += (long) currentHour * 1000L;
                currentHour = nowHour;
                backupModel(nowHour, lastLoss);
            }

            try {
                lastLoss = currLoss;
                currLoss = TextLSTM.train(learningRate, EPOCHS_PER_ITERATION, new File(TRAINING_TEXT), new Random(seed));
            } catch (Exception e) {
                if (INVALID_LOSS_MESSAGE.equals(e.getMessage())) {
                    learningRate /= 10.0;
                }
                e.printStackTrace();
            }

            if (logWriter != null) {
                try {
                    if (iteration != 0) {
                        logWriter.newLine();
                    }
                    // NOTE(review): if train() threw, this re-logs the previous loss.
                    logWriter.write(Double.toString(currLoss));
                    logWriter.flush();
                } catch (Exception e) {
                    System.err.println("Error logging message: ");
                    e.printStackTrace();
                }
            }
        }

        if (logWriter != null) {
            try {
                logWriter.close();
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
        System.exit(0);
    }

    /**
     * Opens the training-loss log, truncating any previous contents.
     *
     * @return the writer, or {@code null} if the file could not be opened (logging disabled)
     */
    private static BufferedWriter openLog() {
        try {
            return new BufferedWriter(new FileWriter(new File(LOG_PATH)));
        } catch (Exception e) {
            System.err.println("Error creating logger (disabling logging): ");
            e.printStackTrace();
            return null;
        }
    }

    /**
     * Copies the model file to {@code FiO_backup_<hour>.ser} and writes a companion
     * {@code FiO_backup_<hour>_info.txt} containing the last reported loss.
     * If the copy fails the info file is not written. Failures are printed and
     * otherwise ignored, so training continues.
     *
     * @param hour     hour of day used in the backup file names
     * @param lastLoss loss value recorded in the info file
     */
    private static void backupModel(int hour, double lastLoss) {
        String prefix = "FiO_backup_" + Integer.toString(hour);
        try {
            try (FileOutputStream modelOut = new FileOutputStream(new File(prefix + ".ser"))) {
                Files.copy(new File(SAVE_PATH).toPath(), modelOut);
            }
            try (FileOutputStream infoOut = new FileOutputStream(new File(prefix + "_info.txt"))) {
                String str = "Last reported Loss of this backup is: " + Double.toString(lastLoss);
                infoOut.write(str.getBytes());
                infoOut.flush();
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Builds the network architecture shared by training and generation:
     * GRU(70) -> Multipath(300) -> LSTM(200) -> FeedForward(output).
     *
     * @param data dataset providing input/output dimensions and the output unit
     * @param rng  random source handed to every layer's constructor
     * @return a freshly constructed network
     * @throws Exception if any layer construction fails
     */
    private static NeuralNetwork buildNetwork(TextGenerationUnbroken data, Random rng) throws Exception {
        ArrayList<Model> layers = new ArrayList<Model>();
        layers.add(new GruLayer(data.inputDimension, 70, 0.08, rng));
        layers.add(new MultipathLayer(70, 300, 3, 0.008, rng));
        layers.add(new LstmLayer(300, 200, 0.08, rng));
        layers.add(new FeedForwardLayer(200, data.outputDimension, data.getModelOutputUnitToUse(), 0.08, rng));
        return new NeuralNetwork(layers);
    }

    /**
     * Runs one training session on the given text file and returns the loss
     * reported by the trainer. The model is saved to / resumed from {@value #SAVE_PATH}.
     *
     * @param learningRate learning rate passed to the trainer
     * @param epochs       number of epochs to train
     * @param f            training text file
     * @param random       random source used to build the dataset
     * @return the loss returned by {@code NewTrainer.train}
     * @throws Exception propagated from dataset construction or training
     */
    public static double train(double learningRate, int epochs, File f, Random random) throws Exception {
        if (trainer == null) {
            // main() normally selects the trainer; fall back to its RMSProp default.
            trainer = new NewTrainer(new RMSProp(0.999, 3.0, 0.1));
        }
        // NOTE(review): the meaning of 100/300/500 is defined by TextGenerationUnbroken;
        // generate() passes 250 instead of 300 here — confirm the difference is intended.
        TextGenerationUnbroken data = new TextGenerationUnbroken(f.getPath(), 100, 300, 500, random);
        System.gc();
        TextGenerationUnbroken.reportPerplexity = false;
        TextGenerationUnbroken.reportSequenceLength = 100;
        // NOTE(review): this Random is unseeded, so layer initialisation is not
        // reproducible even when `random` is seeded — confirm this is intended.
        Random rng = new Random();
        NeuralNetwork lstm = buildNetwork(data, rng);
        int reportEveryNthEpoch = 3;
        // The existence check is passed to the trainer — presumably to resume from the saved model.
        return trainer.train(lstm, learningRate, epochs, data, reportEveryNthEpoch,
                SAVE_PATH, new File(SAVE_PATH).exists(), true, rng, null);
    }

    /**
     * Loads the saved model from {@value #SAVE_PATH} and generates text.
     *
     * @param stringLength number of characters to generate
     * @param f            text file used to build the dataset (vocabulary)
     * @param random       random source for dataset, layers and sampling
     * @param temperature  sampling temperature passed to {@code generateText}
     * @return the generated text
     * @throws Exception propagated from dataset construction, loading or generation
     */
    public static String generate(int stringLength, File f, Random random, double temperature) throws Exception {
        System.out.println(f.getPath());
        System.gc();
        TextGenerationUnbroken.reportPerplexity = false;
        TextGenerationUnbroken.reportSequenceLength = stringLength;
        TextGenerationUnbroken data = new TextGenerationUnbroken(f.getPath(), 100, 250, 500, random);
        NeuralNetwork lstm = buildNetwork(data, random);
        FileIO.loadNeuralNetwork(SAVE_PATH, lstm);
        return data.generateText(lstm, stringLength, false, temperature, random);
    }
}

