/*
 * Decompiled with CFR 0.152.
 */
package datasets;

import autodiff.Graph;
import datastructs.DataSequence;
import datastructs.DataSet;
import datastructs.DataStep;
import java.io.File;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import loss.LossSoftmax;
import matrix.Matrix;
import model.Model;
import nonlinearities.LinearUnit;
import nonlinearities.Nonlinearity;
import util.Util;

/**
 * Character-level text generation data set in which the input file is cut into
 * independent training sequences by a caller-supplied separator.
 *
 * <p>Every distinct character (Unicode code point) of the corpus gets a one-hot index.
 * Index 0 is reserved for a combined start/end token, which is placed before the first
 * and after the last character of every sequence. Each training step uses the current
 * character as input and the following character as target.
 */
public class TextGenerationMultiline
extends DataSet {
    /** Maximum number of characters generated per sample in {@link #DisplayReport}. */
    public int reportSequenceLength = 100;
    /** Whether {@link #DisplayReport} also computes the median perplexity over the training set. */
    public boolean reportPerplexity = true;

    /** Vocabulary entry for the shared start/end-of-sequence token. */
    private static final String START_END_TOKEN = "[START/END]";
    /** One-hot index reserved for {@link #START_END_TOKEN}. */
    private static final int START_END_TOKEN_INDEX = 0;
    /**
     * Character fed to the model right after the start token when generating text.
     * If the corpus does not contain it, generation starts directly after the start token.
     */
    private static final String SEED_CHARACTER = "R";

    private final Map<String, Integer> charToIndex = new HashMap<String, Integer>();
    private final Map<Integer, String> indexToChar = new HashMap<Integer, String>();
    /** Vocabulary size including the start/end token; the length of every one-hot vector. */
    private int dimension;
    /** One-hot vector of the start/end token, shared by every sequence as first input and last target. */
    private double[] vecStartEnd;

    /**
     * Generates text with {@code model}.
     *
     * <p>The model state is reset and the start token is fed. If the seed character
     * {@code "R"} is in the vocabulary it is emitted and fed next. After that, each chosen
     * character is fed back as the next input. Generation stops after {@code steps}
     * characters or as soon as the start/end token is chosen.
     *
     * @param model       model to sample from; {@code resetState()} is called first
     * @param steps       maximum number of characters generated after the (optional) seed
     * @param argmax      if true pick the most probable index, otherwise sample from the distribution
     * @param temperature temperature passed to {@link LossSoftmax#getSoftmaxProbs}
     * @param rng         random source used when sampling
     * @return the generated text, prefixed with the seed character when it is in the vocabulary
     * @throws Exception propagated from the model and softmax helpers
     */
    public String generateText(Model model, int steps, boolean argmax, double temperature, Random rng) throws Exception {
        model.resetState();
        Graph g = new Graph(false);
        Matrix startLogprobs = model.forward(oneHotMatrix(START_END_TOKEN_INDEX), g);

        StringBuilder result = new StringBuilder();
        // Next input to feed to the model; null means "predict from the start token's output".
        Matrix input = null;
        Integer seedIndex = this.charToIndex.get(SEED_CHARACTER);
        if (seedIndex != null) {
            result.append(SEED_CHARACTER);
            input = oneHotMatrix(seedIndex);
        }
        for (int s = 0; s < steps; s++) {
            Matrix logprobs = (input == null) ? startLogprobs : model.forward(input, g);
            int chosen = chooseIndex(logprobs, argmax, temperature, rng);
            if (chosen == START_END_TOKEN_INDEX) {
                break; // the model signalled end of sequence
            }
            result.append(this.indexToChar.get(chosen));
            input = oneHotMatrix(chosen);
        }
        return result.toString();
    }

    /**
     * Converts the model output of one step into the index of the next character.
     *
     * @return the first index with the highest softmax probability when {@code argmax} is true,
     *         otherwise an index drawn by {@link Util#pickIndexFromRandomVector}
     */
    private static int chooseIndex(Matrix logprobs, boolean argmax, double temperature, Random rng) throws Exception {
        Matrix probs = LossSoftmax.getSoftmaxProbs(logprobs, temperature);
        if (!argmax) {
            return Util.pickIndexFromRandomVector(probs, rng);
        }
        int best = -1;
        double high = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < probs.w.length; i++) {
            if (probs.w[i] > high) { // strict '>' keeps the first index on ties
                high = probs.w[i];
                best = i;
            }
        }
        return best;
    }

    /**
     * Renders the target characters of {@code sequence} as a quoted string followed by a newline.
     * The last step is skipped because its target is the start/end token. Steps whose target
     * contains no 1.0 entry contribute nothing.
     *
     * @param sequence sequence built from this data set's vocabulary
     * @return the decoded text wrapped in double quotes, ending with {@code "\n"}
     */
    public String sequenceToSentence(DataSequence sequence) {
        StringBuilder result = new StringBuilder("\"");
        int last = sequence.steps.size() - 1;
        for (int s = 0; s < last; s++) {
            DataStep step = sequence.steps.get(s);
            String ch = this.indexToChar.get(indexOfOne(step.targetOutput.w));
            if (ch != null) {
                result.append(ch);
            }
        }
        result.append("\"\n");
        return result.toString();
    }

    /** Returns the first index holding exactly 1.0, or -1 if there is none. */
    private static int indexOfOne(double[] values) {
        for (int i = 0; i < values.length; i++) {
            if (values[i] == 1.0) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Loads {@code path}, splits it into sequences and builds the one-hot training set.
     *
     * <p>The file is read line by line, and every line is re-joined with a trailing
     * {@code '\n'}. The resulting text is then cut with {@link String#split(String)}, so
     * {@code splitChar} is a regular expression and trailing empty pieces are dropped.
     * The file is decoded with the platform default charset.
     *
     * @param path      text file to load
     * @param splitChar regular expression separating the training sequences
     * @throws IllegalArgumentException if the split yields no sequences
     * @throws Exception                if the file cannot be read
     */
    public TextGenerationMultiline(String path, String splitChar) throws Exception {
        System.out.println("Text generation task");
        System.out.println("loading " + path + "...");
        String[] lines = readAndSplit(path, splitChar);

        this.charToIndex.put(START_END_TOKEN, START_END_TOKEN_INDEX);
        this.indexToChar.put(START_END_TOKEN_INDEX, START_END_TOKEN);
        buildVocabulary(lines);
        this.dimension = this.charToIndex.size();
        this.vecStartEnd = oneHotVector(START_END_TOKEN_INDEX);

        ArrayList<DataSequence> sequences = new ArrayList<DataSequence>(lines.length);
        int size = 0;
        for (String line : lines) {
            DataSequence sequence = encodeLine(line);
            size += sequence.steps.size();
            sequences.add(sequence);
        }
        System.out.println("Total unique chars = " + (this.dimension - 1));
        System.out.println(size + " steps in training set.");
        if (sequences.isEmpty()) {
            throw new IllegalArgumentException(
                    "no sequences found in " + path + " when splitting on \"" + splitChar + "\"");
        }

        this.training = sequences;
        this.lossTraining = new LossSoftmax();
        this.lossReporting = new LossSoftmax();
        // Every step is built with both an input and a target, so the first step defines both sizes.
        DataStep first = sequences.get(0).steps.get(0);
        this.inputDimension = first.input.w.length;
        this.outputDimension = first.targetOutput.w.length;
    }

    /** Reads the file, re-joins its lines with '\n' and splits the text on the {@code splitChar} regex. */
    private static String[] readAndSplit(String path, String splitChar) throws java.io.IOException {
        List<String> fileLines = Files.readAllLines(new File(path).toPath(), Charset.defaultCharset());
        StringBuilder text = new StringBuilder();
        for (String fileLine : fileLines) {
            text.append(fileLine).append('\n');
        }
        return text.toString().split(splitChar);
    }

    /** Assigns consecutive indices to characters in order of first appearance and prints them. */
    private void buildVocabulary(String[] lines) {
        System.out.println("Characters:");
        System.out.print("\t");
        for (String line : lines) {
            for (String ch : splitIntoCharacters(line)) {
                if (!this.charToIndex.containsKey(ch)) {
                    int id = this.charToIndex.size();
                    System.out.print(ch);
                    this.charToIndex.put(ch, id);
                    this.indexToChar.put(id, ch);
                }
            }
        }
        System.out.println();
    }

    /** Builds the (current -> next) steps of one line, framed by the start/end token on both sides. */
    private DataSequence encodeLine(String line) {
        DataSequence sequence = new DataSequence();
        double[] previous = this.vecStartEnd;
        for (String ch : splitIntoCharacters(line)) {
            double[] current = oneHotVector(this.charToIndex.get(ch));
            sequence.steps.add(new DataStep(previous, current));
            previous = current;
        }
        sequence.steps.add(new DataStep(previous, this.vecStartEnd));
        return sequence;
    }

    /** Splits {@code line} into code points so supplementary characters are not cut into surrogate halves. */
    private static List<String> splitIntoCharacters(String line) {
        List<String> characters = new ArrayList<String>(line.length());
        int offset = 0;
        while (offset < line.length()) {
            int codePoint = line.codePointAt(offset);
            characters.add(new String(Character.toChars(codePoint)));
            offset += Character.charCount(codePoint);
        }
        return characters;
    }

    /** Returns a fresh vector of length {@link #dimension} with a 1.0 at {@code index}. */
    private double[] oneHotVector(int index) {
        double[] vector = new double[this.dimension];
        vector[index] = 1.0;
        return vector;
    }

    /** Returns a fresh {@link Matrix} of size {@link #dimension} with a 1.0 at {@code index}. */
    private Matrix oneHotMatrix(int index) {
        Matrix matrix = new Matrix(this.dimension);
        matrix.w[index] = 1.0;
        return matrix;
    }

    /**
     * Prints a report to stdout. It contains the optional median perplexity over the
     * training set, one sampled text for each temperature in {1.0, 0.75, 0.5, 0.25, 0.1},
     * and finally an argmax (greedy) sample.
     */
    @Override
    public void DisplayReport(Model model, Random rng) throws Exception {
        System.out.println("========================================");
        System.out.println("REPORT:");
        if (this.reportPerplexity) {
            System.out.println("\ncalculating perplexity over entire data set...");
            double perplexity = LossSoftmax.calculateMedianPerplexity(model, this.training);
            System.out.println("\nMedian Perplexity = " + String.format("%.4f", perplexity));
        }
        double[] temperatures = {1.0, 0.75, 0.5, 0.25, 0.1};
        for (double temperature : temperatures) {
            System.out.println("\nTemperature " + temperature + " prediction:");
            System.out.println(this.generateText(model, this.reportSequenceLength, false, temperature, rng));
        }
        System.out.println("\nArgmax prediction:");
        System.out.println(this.generateText(model, this.reportSequenceLength, true, 1.0, rng));
        System.out.println("========================================");
    }

    /**
     * Uses a linear output unit; {@link #generateText} applies the softmax to the model
     * output itself via {@link LossSoftmax#getSoftmaxProbs}.
     */
    @Override
    public Nonlinearity getModelOutputUnitToUse() {
        return new LinearUnit();
    }
}

