/*
 * Decompiled with CFR 0.152.
 */
package datasets;

import autodiff.Graph;
import datastructs.DataSequence;
import datastructs.DataSet;
import datastructs.DataStep;
import java.io.File;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import loss.LossSoftmax;
import matrix.Matrix;
import model.NeuralNetwork;
import nonlinearities.LinearUnit;
import nonlinearities.Nonlinearity;
import util.Util;

/**
 * Character-level text generation data set.
 *
 * <p>Every sampled line of the input file becomes one training sequence. Characters are one-hot
 * encoded and each line is framed by a start/end token at index {@code START_END_TOKEN_INDEX},
 * so each step maps the current token to the next one (the last step maps to the end token).
 *
 * <p>NOTE(review): the vocabulary, the vector dimension and the autocorrect word list are static
 * and therefore shared by all instances. Each constructor call rebuilds them, so only the most
 * recently constructed instance has a consistent vocabulary.
 */
public class TextGeneration
extends DataSet {
    /** Number of tokens generated per sample in {@link #DisplayReport}. */
    public static int reportSequenceLength = 100;
    /** Whether {@link #DisplayReport} computes the median perplexity over the training set. */
    public static boolean reportPerplexity = true;

    /** Index of the token that marks both the start and the end of a line. */
    private static final int START_END_TOKEN_INDEX = 0;
    /** Characters seen fewer times than this are dropped when {@code removeUnused} is set. */
    private static final int MIN_CHAR_OCCURRENCES = 100;
    /** Softmax temperatures sampled in {@link #DisplayReport}, in the order they are reported. */
    private static final double[] REPORT_TEMPERATURES = {1.0, 0.75, 0.5, 0.25, 0.1};

    private static final Map<Character, Integer> charToIndex = new HashMap<Character, Integer>();
    private static final Map<Integer, Character> indexToChar = new HashMap<Integer, Character>();
    /** Space-separated, trimmed tokens of the source file; only filled when autocorrect is on. */
    private static final Set<String> words = new HashSet<String>();
    /** Vocabulary size including the start/end token. */
    private static int dimension;
    /** One-hot vector of the start/end token, shared by every encoded sequence. */
    private static double[] vecStartEnd;

    private boolean singleWordAutocorrect = false;

    /**
     * Samples text from {@code model}, one token per step.
     *
     * <p>Generation starts from the start/end token. Whenever the model emits that token the
     * current line is finished, the model state is reset and generation restarts from the start
     * token.
     *
     * @param model       network to sample from; its state is reset before generation
     * @param steps       number of tokens to generate (line ends count as steps)
     * @param argmax      if {@code true} always take the most probable token, otherwise sample
     * @param temperature softmax temperature passed to {@code LossSoftmax.getSoftmaxProbs}
     * @param rng         random source used when sampling
     * @return the generated lines; a trailing unfinished line is included only if non-empty
     * @throws Exception if the forward pass fails or the model yields no usable distribution
     */
    public List<String> generateText(NeuralNetwork model, int steps, boolean argmax, double temperature, Random rng) throws Exception {
        List<String> lines = new ArrayList<String>();
        Matrix start = new Matrix(dimension);
        start.w[START_END_TOKEN_INDEX] = 1.0;
        model.resetState();
        Graph g = new Graph(false);
        Matrix input = start.clone();
        StringBuilder line = new StringBuilder();
        for (int s = 0; s < steps; s++) {
            Matrix logprobs = model.forward(input, g).matrices[0];
            Matrix probs = LossSoftmax.getSoftmaxProbs(logprobs, temperature);
            if (this.singleWordAutocorrect) {
                applyAutocorrectMask(probs, line.toString());
            }
            int chosen = argmax ? argmaxIndex(probs) : Util.pickIndexFromRandomVector(probs, rng);
            if (chosen < 0 || chosen >= input.w.length) {
                // Happens e.g. when every probability is NaN; fail clearly instead of with an AIOOBE.
                throw new IllegalStateException("No valid token could be chosen at step " + s + " (index " + chosen + ")");
            }
            if (chosen == START_END_TOKEN_INDEX) {
                lines.add(line.toString());
                line.setLength(0);
                g = new Graph(false);
                model.resetState();
                input = start.clone();
            } else {
                line.append(indexToChar.get(chosen));
                for (int i = 0; i < input.w.length; i++) {
                    input.w[i] = 0.0;
                }
                input.w[chosen] = 1.0;
            }
        }
        if (line.length() > 0) {
            lines.add(line.toString());
        }
        return lines;
    }

    /**
     * Restricts {@code probs} in place to tokens allowed by {@link #singleWordAutocorrect} and
     * renormalises. If the allowed tokens carry no probability mass at all (e.g. softmax
     * underflow), {@code probs} is left unchanged instead of being turned into NaNs.
     */
    private static void applyAutocorrectMask(Matrix probs, String line) {
        Matrix allowed = singleWordAutocorrect(line);
        double total = 0.0;
        for (int i = 0; i < probs.w.length; i++) {
            total += probs.w[i] * allowed.w[i];
        }
        if (!(total > 0.0)) {
            return;
        }
        for (int i = 0; i < probs.w.length; i++) {
            probs.w[i] = probs.w[i] * allowed.w[i] / total;
        }
    }

    /** Index of the first maximal entry of {@code probs}, or -1 if no entry exceeds -Infinity (e.g. all NaN). */
    private static int argmaxIndex(Matrix probs) {
        int best = -1;
        double high = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < probs.w.length; i++) {
            if (probs.w[i] > high) {
                high = probs.w[i];
                best = i;
            }
        }
        return best;
    }

    /**
     * Builds a 0/1 mask of the tokens that may follow {@code sequence}.
     *
     * <p>The mask keeps the last (space-separated) partial word a prefix of some known word. When
     * the partial word already is a complete word, or none of its known continuations uses a
     * character in the vocabulary, the space and the end token are allowed. An all-ones mask
     * (no restriction) is returned when the sequence is empty, ends with a space, or its last
     * partial word is not a prefix of any known word.
     */
    private static Matrix singleWordAutocorrect(String sequence) {
        String text = sequence.replace("\"\n\"", " ");
        if (text.isEmpty() || text.endsWith(" ")) {
            return Matrix.ones(dimension, 1);
        }
        String[] parts = text.split(" ");
        String prefix = parts[parts.length - 1].trim();
        Matrix allowed = new Matrix(dimension);
        boolean anyMatch = false;
        boolean canExtend = false;
        for (String word : words) {
            if (!word.startsWith(prefix)) {
                continue;
            }
            anyMatch = true;
            if (word.length() == prefix.length()) {
                allowWordEnd(allowed);
                continue;
            }
            Integer next = charToIndex.get(word.charAt(prefix.length()));
            if (next == null) {
                // The continuation uses a character that was removed from the vocabulary.
                continue;
            }
            allowed.w[next.intValue()] = 1.0;
            canExtend = true;
        }
        if (!anyMatch) {
            return Matrix.ones(dimension, 1);
        }
        if (!canExtend) {
            allowWordEnd(allowed);
        }
        return allowed;
    }

    /** Marks the space (if it is in the vocabulary) and the end token as allowed in {@code mask}. */
    private static void allowWordEnd(Matrix mask) {
        Integer space = charToIndex.get(' ');
        if (space != null) {
            mask.w[space.intValue()] = 1.0;
        }
        mask.w[START_END_TOKEN_INDEX] = 1.0;
    }

    /**
     * Renders {@code sequence} as a quoted line by decoding the one-hot target of every step but
     * the last (for sequences built by this class, the last target is the end token). A target
     * with no entry equal to 1.0 is rendered as {@code "null"}.
     *
     * @param sequence sequence whose targets are one-hot over this class's vocabulary
     * @return the decoded text wrapped in double quotes and followed by a newline
     */
    public static String sequenceToSentence(DataSequence sequence) {
        StringBuilder result = new StringBuilder("\"");
        for (int s = 0; s < sequence.getSequenceLength() - 1; s++) {
            DataStep step = sequence.getDataStep(s);
            int index = -1;
            for (int i = 0; i < step.targetOutput.matrices[0].w.length; i++) {
                if (step.targetOutput.matrices[0].w[i] == 1.0) {
                    index = i;
                    break;
                }
            }
            result.append(indexToChar.get(index));
        }
        return result.append("\"\n").toString();
    }

    /**
     * Loads {@code path} and builds the training set.
     *
     * @param path         text file; each sampled line becomes one training sequence
     * @param autocorrect  if {@code true}, remember every word of the file and restrict sampling
     *                     in {@link #generateText} to continuations of known words
     * @param removeUnused if {@code true}, drop characters occurring fewer than 100 times in the file
     * @param maxSequences number of lines drawn uniformly at random (with replacement)
     * @throws Exception                if the file cannot be read
     * @throws IllegalArgumentException if no training sequence could be built (empty file or
     *                                  {@code maxSequences <= 0})
     */
    public TextGeneration(String path, boolean autocorrect, boolean removeUnused, int maxSequences) throws Exception {
        this.singleWordAutocorrect = autocorrect;
        System.out.println("Text generation task");
        System.out.println("loading " + path + "...");
        List<String> lines = Files.readAllLines(new File(path).toPath(), Charset.defaultCharset());

        resetVocabulary();
        Set<Character> rareChars = removeUnused ? findRareChars(lines) : new HashSet<Character>();

        System.out.println("Characters:");
        System.out.print("\t");
        List<String> linesToUse = sampleLines(lines, maxSequences, rareChars);

        Set<Character> chars = new HashSet<Character>();
        for (String line : lines) {
            if (autocorrect) {
                for (String part : line.split(" ")) {
                    words.add(part.trim());
                }
            }
            registerChars(line, rareChars, chars);
        }
        // Sampled lines may contain characters absent from the raw file (the right double
        // quotation mark is rewritten to an ASCII quote), so every encoded character needs an index.
        for (String line : linesToUse) {
            registerChars(line, rareChars, chars);
        }

        dimension = chars.size() + 1;
        vecStartEnd = new double[dimension];
        vecStartEnd[START_END_TOKEN_INDEX] = 1.0;

        ArrayList<DataSequence> sequences = new ArrayList<DataSequence>();
        int size = 0;
        long lastReport = System.currentTimeMillis();
        for (int done = 0; done < linesToUse.size(); done++) {
            if (System.currentTimeMillis() - lastReport >= 1000L) {
                lastReport = System.currentTimeMillis();
                // Progress is relative to the sampled lines, which is what this loop walks.
                System.out.println(Double.toString((double) done / (double) linesToUse.size() * 100.0) + "% done");
            }
            String line = linesToUse.get(done);
            sequences.add(encodeLine(line));
            size += line.length() + 1;
        }
        System.out.println("\nTotal unique chars = " + chars.size());
        System.out.println(size + " steps in training set.");
        if (sequences.isEmpty()) {
            throw new IllegalArgumentException("No training sequences built from " + path
                    + " (maxSequences=" + maxSequences + ", lines=" + lines.size() + ")");
        }

        this.training = sequences;
        this.lossTraining = new LossSoftmax();
        this.lossReporting = new LossSoftmax();
        DataSequence first = sequences.get(0);
        this.inputDimension = new DataSet.TensorDimensions(1, first.getDataStep(0).input.matrices[0].w.length, 1);
        int loc = 0;
        while (first.getDataStep(loc).targetOutput == null) {
            ++loc;
        }
        this.outputDimension = new DataSet.TensorDimensions(1, first.getDataStep(loc).targetOutput.matrices[0].w.length, 1);
    }

    /** Clears the shared vocabulary and word list, then registers the start/end token. */
    private static void resetVocabulary() {
        charToIndex.clear();
        indexToChar.clear();
        words.clear();
        charToIndex.put(Character.valueOf('\u0000'), START_END_TOKEN_INDEX);
        indexToChar.put(START_END_TOKEN_INDEX, Character.valueOf('\u0000'));
    }

    /** Returns the characters that occur fewer than {@code MIN_CHAR_OCCURRENCES} times in {@code lines}. */
    private static Set<Character> findRareChars(List<String> lines) {
        Map<Character, Integer> counts = new HashMap<Character, Integer>();
        for (String line : lines) {
            for (int i = 0; i < line.length(); i++) {
                Character c = Character.valueOf(line.charAt(i));
                Integer count = counts.get(c);
                counts.put(c, count == null ? 1 : count + 1);
            }
        }
        Set<Character> rare = new HashSet<Character>();
        for (Map.Entry<Character, Integer> entry : counts.entrySet()) {
            if (entry.getValue() < MIN_CHAR_OCCURRENCES) {
                rare.add(entry.getKey());
            }
        }
        return rare;
    }

    /**
     * Draws {@code count} lines uniformly at random (with replacement) from {@code lines},
     * rewrites right double quotation marks as ASCII quotes and strips every character in
     * {@code removed}. Returns an empty list when {@code lines} is empty.
     */
    private static List<String> sampleLines(List<String> lines, int count, Set<Character> removed) {
        List<String> sampled = new ArrayList<String>();
        if (lines.isEmpty()) {
            return sampled;
        }
        Random random = new Random();
        for (int i = 0; i < count; i++) {
            String raw = lines.get(random.nextInt(lines.size())).replace("\u201d", "\"");
            StringBuilder kept = new StringBuilder(raw.length());
            for (int j = 0; j < raw.length(); j++) {
                char c = raw.charAt(j);
                if (!removed.contains(c)) {
                    kept.append(c);
                }
            }
            sampled.add(kept.toString());
        }
        return sampled;
    }

    /**
     * Gives the next free index to every character of {@code line} that is neither in
     * {@code known} nor in {@code excluded}, adding it to {@code known} and echoing it to stdout.
     */
    private static void registerChars(String line, Set<Character> excluded, Set<Character> known) {
        for (int i = 0; i < line.length(); i++) {
            Character ch = Character.valueOf(line.charAt(i));
            if (known.contains(ch) || excluded.contains(ch)) {
                continue;
            }
            System.out.print(ch);
            known.add(ch);
            // Index 0 is reserved for the start/end token, so the n-th new character gets index n.
            int index = known.size();
            charToIndex.put(ch, index);
            indexToChar.put(index, ch);
        }
    }

    /**
     * One-hot encodes {@code line} framed by start/end tokens into a sequence whose step
     * {@code i} maps token {@code i} to token {@code i + 1}.
     */
    private static DataSequence encodeLine(String line) {
        List<double[]> vecs = new ArrayList<double[]>(line.length() + 2);
        vecs.add(vecStartEnd);
        for (int i = 0; i < line.length(); i++) {
            int index = charToIndex.get(Character.valueOf(line.charAt(i)));
            double[] vec = new double[dimension];
            vec[index] = 1.0;
            vecs.add(vec);
        }
        vecs.add(vecStartEnd);
        DataSequence sequence = new DataSequence();
        for (int i = 0; i + 1 < vecs.size(); i++) {
            sequence.addDataStep(new DataStep(vecs.get(i), vecs.get(i + 1)));
        }
        return sequence;
    }

    /**
     * Prints an evaluation report: optional median perplexity, one sampled text per temperature
     * in {@code REPORT_TEMPERATURES}, and one argmax text.
     */
    @Override
    public void DisplayReport(NeuralNetwork model, Random rng) throws Exception {
        System.out.println("========================================");
        System.out.println("REPORT:");
        if (reportPerplexity) {
            System.out.println("\ncalculating perplexity over entire data set...");
            double perplexity = LossSoftmax.calculateMedianPerplexity(model, this.training);
            System.out.println("\nMedian Perplexity = " + String.format("%.4f", perplexity));
        }
        String suffix = this.singleWordAutocorrect ? " (with single word autocorrect):" : ":";
        for (double temperature : REPORT_TEMPERATURES) {
            System.out.println("\nTemperature " + temperature + " prediction" + suffix);
            printGuesses(this.generateText(model, reportSequenceLength, false, temperature, rng));
        }
        System.out.println("\nArgmax prediction" + suffix);
        printGuesses(this.generateText(model, reportSequenceLength, true, 1.0, rng));
        System.out.println("========================================");
    }

    /** Prints each non-empty line in quotes, tab-indented; the last element gets a "..." suffix. */
    private static void printGuesses(List<String> guess) {
        for (int i = 0; i < guess.size(); i++) {
            String text = guess.get(i);
            if (text.isEmpty()) {
                continue;
            }
            if (i == guess.size() - 1) {
                System.out.println("\t\"" + text + "...\"");
            } else {
                System.out.println("\t\"" + text + "\"");
            }
        }
    }

    /** The network emits raw logits; the softmax is applied by {@link LossSoftmax}. */
    @Override
    public Nonlinearity getModelOutputUnitToUse() {
        return new LinearUnit();
    }
}

