/*
 * Decompiled with CFR 0.152.
 */
package examples;

import datasets.TextGeneration;
import datasets.TextGenerationUnbroken;
import datastructs.DataSet;
import java.io.File;
import java.util.List;
import java.util.Random;
import model.Dropout;
import model.FeedForwardLayer;
import model.LstmLayer;
import model.NeuralNetwork;
import nonlinearities.RectifiedLinearUnit;
import trainer.Adam;
import trainer.Trainer;
import util.FileIO;

public class TextGenExample {
    /** Checkpoint file: loaded before training if it exists, and overwritten after every iteration. */
    private static final String MODEL_FILE = "textGenExample.dat";

    /** Value passed to both {@link Dropout} layers in the network. */
    private static final double DROPOUT = 0.25;

    /**
     * Builds a feed-forward -> LSTM -> LSTM -> feed-forward network sized from the dataset, optionally
     * resumes it from {@value #MODEL_FILE}, trains it for a fixed number of iterations (saving a
     * checkpoint after each), and prints text generated by the trained network.
     *
     * <p>Any failure while loading data, training or generating prints a message and stack trace to
     * stderr and terminates the JVM with exit status 1.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        Trainer trainer = new Trainer(new Adam(0.9, 0.999, 1.0E-8));
        // finally guarantees dispose() also runs if network construction throws an unchecked exception;
        // the System.exit paths terminate the JVM regardless.
        try {
            double learningRate = 5.0E-4;
            int iterations = 5;
            double samplingTemperature = 0.74;
            int epochs = 2;
            int sequenceCount = 32;
            int hiddenUnits = 321;
            int secondLstmUnits = hiddenUnits - 64;
            int inputBottleneck = 128;
            String textfile = "text.txt";
            boolean breakLines = true;
            Random rng = new Random();

            DataSet data = loadDataSet(textfile, breakLines, sequenceCount, 2048, 8192, rng);
            NeuralNetwork lstm = buildNetwork(data, inputBottleneck, hiddenUnits, secondLstmUnits, rng);

            try {
                if (new File(MODEL_FILE).exists()) {
                    FileIO.loadNeuralNetwork(MODEL_FILE, lstm);
                }
                for (int i = 0; i < iterations; i++) {
                    System.out.println("Iteration " + (i + 1) + "/" + iterations);
                    trainer.train(lstm, learningRate, epochs, data, 5, MODEL_FILE, true, true, rng);
                    FileIO.saveNeuralNetwork(MODEL_FILE, lstm);
                    // NOTE(review): the unbroken variant is reloaded with 128/512, whereas the initial
                    // load uses 2048/8192. Preserved as-is; confirm this difference is intended.
                    data = loadDataSet(textfile, breakLines, sequenceCount, 128, 512, rng);
                }
            } catch (Exception e) {
                exitWithError("Error training: ", e);
            }

            printGeneratedText(data, lstm, breakLines, samplingTemperature, rng);
        } finally {
            trainer.dispose();
        }
    }

    /**
     * Constructs the dataset, or terminates the JVM (exit status 1) if construction throws.
     *
     * @param textfile       path of the training text
     * @param breakLines     {@code true} for {@link TextGeneration}, {@code false} for
     *                       {@link TextGenerationUnbroken}
     * @param sequenceCount  forwarded to either dataset constructor
     * @param unbrokenSizeA  third {@link TextGenerationUnbroken} constructor argument (unused when
     *                       {@code breakLines} is true)
     * @param unbrokenSizeB  fourth {@link TextGenerationUnbroken} constructor argument (unused when
     *                       {@code breakLines} is true)
     * @param rng            random source; only passed to {@link TextGenerationUnbroken}
     * @return the constructed dataset; never {@code null}
     */
    private static DataSet loadDataSet(String textfile, boolean breakLines, int sequenceCount,
            int unbrokenSizeA, int unbrokenSizeB, Random rng) {
        try {
            return breakLines
                    ? new TextGeneration(textfile, false, true, sequenceCount)
                    : new TextGenerationUnbroken(textfile, sequenceCount, unbrokenSizeA, unbrokenSizeB, true, rng);
        } catch (Exception e) {
            exitWithError("Error loading dataset: ", e);
            throw new AssertionError("unreachable: System.exit does not return");
        }
    }

    /**
     * Assembles the network. Its input and output widths are taken from the dataset's dimensions, and
     * each layer's initialisation scale is {@code sqrt(1 / fanIn)}.
     */
    private static NeuralNetwork buildNetwork(DataSet data, int inputBottleneck, int hiddenUnits,
            int secondLstmUnits, Random rng) {
        int inputSize = data.inputDimension.getHeight();
        NeuralNetwork net = new NeuralNetwork();
        net.addLayer(new FeedForwardLayer(inputSize, inputBottleneck, new RectifiedLinearUnit(0.1), initScale(inputSize), rng));
        net.addLayer(new Dropout(DROPOUT));
        net.addLayer(new LstmLayer(inputBottleneck, hiddenUnits, initScale(inputBottleneck), rng));
        net.addLayer(new Dropout(DROPOUT));
        net.addLayer(new LstmLayer(hiddenUnits, secondLstmUnits, initScale(hiddenUnits), rng));
        net.addLayer(new FeedForwardLayer(secondLstmUnits, data.outputDimension.getHeight(), data.getModelOutputUnitToUse(), initScale(secondLstmUnits), rng));
        return net;
    }

    /** Returns {@code sqrt(1 / fanIn)}, the initialisation scale passed to every weighted layer. */
    private static double initScale(int fanIn) {
        return Math.sqrt(1.0 / (double) fanIn);
    }

    /**
     * Generates text from the trained network and prints it to stdout. For line-broken data it
     * prints each returned line separately (255 passed as the length argument); otherwise it prints
     * one string (1024 passed as the length argument). Terminates the JVM (exit status 1) on failure.
     */
    private static void printGeneratedText(DataSet data, NeuralNetwork lstm, boolean breakLines,
            double samplingTemperature, Random rng) {
        try {
            if (breakLines) {
                List<String> lines = ((TextGeneration) data).generateText(lstm, 255, false, samplingTemperature, rng);
                for (String line : lines) {
                    System.out.println(line);
                }
            } else {
                String text = ((TextGenerationUnbroken) data).generateText(lstm, 1024, false, samplingTemperature, rng);
                System.out.println(text);
            }
        } catch (Exception e) {
            exitWithError("Error generating text: ", e);
        }
    }

    /** Prints {@code message} and the stack trace of {@code e} to stderr, then exits with status 1. */
    private static void exitWithError(String message, Exception e) {
        System.err.println(message);
        e.printStackTrace();
        System.exit(1);
    }
}

