/*
 * Decompiled with CFR 0.152.
 */
package theGhastModding.lstmStuff.main;

import autodiff.Graph;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import loss.Loss;
import loss.LossSoftmax;
import matrix.Matrix;
import model.Dropout;
import model.FeedForwardLayer;
import model.LstmLayer;
import model.NeuralNetwork;
import net.dv8tion.jda.core.AccountType;
import net.dv8tion.jda.core.JDA;
import net.dv8tion.jda.core.JDABuilder;
import net.dv8tion.jda.core.events.message.MessageReceivedEvent;
import net.dv8tion.jda.core.hooks.ListenerAdapter;
import nonlinearities.LinearUnit;
import trainer.AMSGrad;
import trainer.Optimizer;
import util.FileChannelInputStream;
import util.FileIO;
import util.Util;

/**
 * Character-level sequence-to-sequence chatbot trained on the movie-dialogue corpus in
 * {@code movie lines/}.
 *
 * <p>The encoder reads a line one character at a time, framed by a start and an end token (both
 * the one-hot index 0). Its output for the end token is concatenated with the previous character
 * and fed to the decoder, which predicts the reply one character at a time until it produces
 * index 0. Run with {@code bot} as the first argument to answer {@code !}-prefixed Discord
 * messages; otherwise the models are trained, sampled once and saved.
 */
public class ChatbotA2 {
    /** Conversations (runs of consecutive movie lines), filled by {@link #loadData()}. */
    private static List<List<String>> conversations = null;
    /** Character to one-hot index; index 0 is the start/end token. */
    private static Map<Character, Integer> charToIndex = null;
    /** Inverse of {@link #charToIndex}. */
    private static Map<Integer, Character> indexToChar = null;
    /** Longest training line, longest accepted bot input and longest generated reply, in chars. */
    private static final int maxChars = 300;
    /** Width of the encoder output that is prepended to every decoder input. */
    private static final int ENCODING_SIZE = 300;
    /** Characters occurring fewer times than this in the corpus are kept out of the vocabulary. */
    private static final int MIN_CHAR_OCCURRENCES = 705;
    /** Regex matching the " +++$+++ " separator between fields of the corpus files. */
    private static final String FIELD_SEPARATOR = " \\+\\+\\+\\$\\+\\+\\+ ";
    /** Environment variable that must hold the Discord bot token. */
    private static final String BOT_TOKEN_ENV_VAR = "DISCORD_BOT_TOKEN";
    /** Discord's maximum message length in characters; longer messages are rejected. */
    private static final int DISCORD_MAX_MESSAGE_LENGTH = 2000;
    /** Vocabulary size (one-hot width); -1 until {@link #loadData()} has run. */
    private static int dimension = -1;
    private static Random staticRNG = new Random();
    private static NeuralNetwork encoder;
    private static NeuralNetwork decoder;
    /** Graph used for inference; built with {@code false}, unlike the {@code true} training graph. */
    private static Graph staticGraph = new Graph(false);
    private static JDA jda;

    /**
     * Loads the corpus and models, then either starts the Discord bot ({@code bot} argument) or
     * trains for a fixed number of rounds, prints one sample and saves the models.
     * Any failure is printed and terminates the JVM with exit status 1.
     */
    public static void main(String[] args) {
        try {
            loadData();
            loadModels();
            if (args.length != 0 && args[0].equalsIgnoreCase("bot")) {
                createDiscordBot();
                return;
            }
            // Create the loss and optimizer once, so whatever state the optimizer keeps between
            // updates (if any) carries over from one training round to the next.
            Loss loss = new LossSoftmax();
            Optimizer optimizer = new AMSGrad(0.9, 0.99, 1.0E-8);
            int rounds = 15;
            for (int round = 0; round < rounds; round++) {
                System.out.println((round + 1) + "/" + rounds);
                double meanLoss = train(8, 5, loss, optimizer, 5.0E-4);
                System.out.println(Double.toString(meanLoss));
                if ((round + 1) % 5 == 0) {
                    saveModels();
                    System.out.println("Saved.");
                }
            }
            test(0.75);
            saveModels();
        } catch (Exception e) {
            System.err.println("Error: ");
            e.printStackTrace();
            System.exit(1);
        }
    }

    /**
     * Connects to Discord and registers the message listener.
     *
     * @throws IllegalStateException if the token environment variable is unset or blank
     */
    private static void createDiscordBot() throws Exception {
        System.out.println("Starting discord bot...");
        // The token is read from the environment instead of being committed to source control.
        // NOTE(review): the token that used to be hard-coded here is public and should be revoked.
        String token = System.getenv(BOT_TOKEN_ENV_VAR);
        if (token == null || token.trim().isEmpty()) {
            throw new IllegalStateException(
                    "Environment variable " + BOT_TOKEN_ENV_VAR + " must contain the Discord bot token");
        }
        // Reset before the listener is registered so no message is processed with stale state.
        encoder.resetState();
        decoder.resetState();
        jda = new JDABuilder(AccountType.BOT).setToken(token).buildBlocking();
        jda.addEventListener(new MessageListener());
        System.out.println("Bot created!");
    }

    /** Returns a fresh input matrix holding the start/end token (one-hot index 0). */
    private static Matrix tokenMatrix() {
        Matrix m = new Matrix(dimension);
        m.w = new double[dimension];
        m.w[0] = 1.0;
        return m;
    }

    /** Returns a fresh input matrix holding the one-hot encoding of {@code c}. */
    private static Matrix charMatrix(char c) {
        Matrix m = new Matrix(dimension);
        m.w = charToVector(c);
        return m;
    }

    /**
     * Feeds the start token, every character of {@code input} and the end token through the
     * encoder and returns the encoder output for the end token.
     *
     * <p>Each step gets its own input matrix: a recording graph may read an input's values again
     * during the backward pass, so one matrix must not be reused with different contents.
     */
    private static Matrix encode(String input, Graph graph) throws Exception {
        encoder.forward(tokenMatrix(), graph);
        for (int k = 0; k < input.length(); k++) {
            encoder.forward(charMatrix(input.charAt(k)), graph);
        }
        return encoder.forward(tokenMatrix(), graph).matrices[0];
    }

    /**
     * Samples a reply to {@code input} from the current (not reset) encoder/decoder state.
     *
     * @param input       text to respond to
     * @param temperature softmax temperature used when sampling each character
     * @return up to {@link #maxChars} sampled characters, followed by {@code "/null\"} when the
     *         decoder produced the end token
     */
    private static String getResponseFor(String input, double temperature) throws Exception {
        Matrix encoding = encode(input, staticGraph);
        StringBuilder response = new StringBuilder();
        Matrix decoderIn = new Matrix(ENCODING_SIZE + dimension);
        System.arraycopy(encoding.w, 0, decoderIn.w, 0, ENCODING_SIZE);
        // The first decoder step sees the start token in the character slot.
        decoderIn.w[ENCODING_SIZE] = 1.0;
        for (int step = 0; step < maxChars; step++) {
            Matrix out = decoder.forward(decoderIn, staticGraph).matrices[0];
            Matrix probs = LossSoftmax.getSoftmaxProbs(out, temperature);
            int index = Util.pickIndexFromRandomVector(probs, staticRNG);
            if (index == 0) {
                response.append("/null\\");
                break;
            }
            response.append(indexToChar.get(index).charValue());
            // Feed the sampled character back in as the next step's character input.
            double[] charIn = new double[dimension];
            charIn[index] = 1.0;
            System.arraycopy(charIn, 0, decoderIn.w, ENCODING_SIZE, dimension);
        }
        return response.toString();
    }

    /** Prints a sampled reply to the first line of a random conversation next to the real reply. */
    private static void test(double temperature) throws Exception {
        encoder.resetState();
        decoder.resetState();
        List<String> conversation = conversations.get(staticRNG.nextInt(conversations.size()));
        String input = conversation.get(0);
        String response = getResponseFor(input, temperature);
        System.out.println("Input: " + input);
        System.out.println("Response: " + response);
        System.out.println("Expected response: " + conversation.get(1));
        encoder.resetState();
        decoder.resetState();
    }

    /**
     * Trains encoder and decoder with teacher forcing on randomly drawn conversations.
     * Lines 0, 2, 4, ... of a conversation are inputs and the line that follows each is its target.
     *
     * @param iterations   number of batches to draw
     * @param batchSize    conversations per batch (also passed to {@code updateParameters})
     * @param lossToUse    loss applied to every decoder output
     * @param optimizer    optimizer used to update both networks
     * @param learningRate learning rate passed to the optimizer
     * @return mean loss over all decoder steps (NaN when no step ran)
     */
    private static double train(int iterations, int batchSize, Loss lossToUse, Optimizer optimizer, double learningRate) throws Exception {
        double lossSum = 0.0;
        double lossCount = 0.0;
        Graph g = new Graph(true);
        for (int it = 0; it < iterations; it++) {
            System.out.println("[" + (it + 1) + "/" + iterations + "]");
            List<List<String>> batch = new ArrayList<List<String>>(batchSize);
            for (int i = 0; i < batchSize; i++) {
                batch.add(conversations.get(staticRNG.nextInt(conversations.size())));
            }
            for (List<String> conversation : batch) {
                // States persist across the turns of one conversation, not across conversations.
                encoder.resetState();
                decoder.resetState();
                for (int j = 0; j < conversation.size() / 2; j++) {
                    final Matrix encoding = encode(conversation.get(j * 2), g);
                    String response = conversation.get(j * 2 + 1);
                    // Step -1 feeds the start token; the last step's target is the end token.
                    for (int k2 = -1; k2 < response.length(); k2++) {
                        final Matrix decoderIn = new Matrix(ENCODING_SIZE + dimension);
                        System.arraycopy(encoding.w, 0, decoderIn.w, 0, ENCODING_SIZE);
                        double[] charIn;
                        if (k2 == -1) {
                            charIn = new double[dimension];
                            charIn[0] = 1.0;
                        } else {
                            charIn = charToVector(response.charAt(k2));
                        }
                        System.arraycopy(charIn, 0, decoderIn.w, ENCODING_SIZE, charIn.length);
                        // decoderIn holds a copy of the encoding, so its gradient has to be routed
                        // back to the encoder output by hand. Registered before the decoder step,
                        // so it runs after that step's own backprop.
                        g.addBackprop(new Runnable() {
                            @Override
                            public void run() {
                                for (int o = 0; o < ENCODING_SIZE; o++) {
                                    encoding.dw[o] += decoderIn.dw[o];
                                }
                            }
                        });
                        Matrix target = (k2 == response.length() - 1)
                                ? tokenMatrix()
                                : charMatrix(response.charAt(k2 + 1));
                        Matrix out = decoder.forward(decoderIn, g).matrices[0];
                        lossSum += lossToUse.measure(out, target);
                        lossCount += 1.0;
                        lossToUse.backward(out, target);
                    }
                }
                // NOTE(review): parameters are updated after every conversation rather than once per
                // batch, and the same Graph is reused across conversations until cleanUp() below —
                // confirm Graph.backward() discards recorded steps, otherwise earlier conversations
                // are backpropagated again.
                g.backward();
                optimizer.updateParameters(decoder, learningRate, batchSize);
                optimizer.updateParameters(encoder, learningRate, batchSize);
            }
        }
        encoder.resetState();
        decoder.resetState();
        g.cleanUp();
        return lossSum / lossCount;
    }

    /**
     * Builds both networks and, when {@code chatbot_encoder.dat} / {@code chatbot_decoder.dat}
     * exist, loads their saved parameters. Requires {@link #dimension} to be set by
     * {@link #loadData()} first.
     */
    private static void loadModels() throws Exception {
        encoder = new NeuralNetwork();
        encoder.addLayer(new FeedForwardLayer(dimension, 256, new LinearUnit(), 2.0 / Math.sqrt(dimension), staticRNG));
        encoder.addLayer(new Dropout(0.25));
        encoder.addLayer(new LstmLayer(256, 300, 2.0 / Math.sqrt(256.0), staticRNG));
        encoder.addLayer(new Dropout(0.25));
        encoder.addLayer(new LstmLayer(300, 300, 2.0 / Math.sqrt(300.0), staticRNG));
        encoder.addLayer(new FeedForwardLayer(300, ENCODING_SIZE, new LinearUnit(), 2.0 / Math.sqrt(300.0), staticRNG));
        decoder = new NeuralNetwork();
        decoder.addLayer(new FeedForwardLayer(ENCODING_SIZE + dimension, 256, new LinearUnit(), 2.0 / Math.sqrt(ENCODING_SIZE + dimension), staticRNG));
        decoder.addLayer(new Dropout(0.25));
        decoder.addLayer(new LstmLayer(256, 350, 2.0 / Math.sqrt(256.0), staticRNG));
        decoder.addLayer(new Dropout(0.25));
        decoder.addLayer(new LstmLayer(350, 300, 2.0 / Math.sqrt(350.0), staticRNG));
        decoder.addLayer(new FeedForwardLayer(300, dimension, new LinearUnit(), 1.0 / Math.sqrt(300.0), staticRNG));
        if (new File("chatbot_encoder.dat").exists()) {
            FileIO.loadNeuralNetwork("chatbot_encoder.dat", encoder);
        }
        if (new File("chatbot_decoder.dat").exists()) {
            FileIO.loadNeuralNetwork("chatbot_decoder.dat", decoder);
        }
    }

    /** Writes both networks to {@code chatbot_encoder.dat} and {@code chatbot_decoder.dat}. */
    private static void saveModels() throws Exception {
        FileIO.saveNeuralNetwork("chatbot_encoder.dat", encoder);
        FileIO.saveNeuralNetwork("chatbot_decoder.dat", decoder);
    }

    /** Maps typographic quotes to their ASCII forms, as is done when the vocabulary is built. */
    private static char normalizeQuote(char c) {
        if (c == '\u201d') {
            return '"';
        }
        if (c == '\u2019') {
            return '\'';
        }
        return c;
    }

    /**
     * One-hot encodes {@code c} after quote normalisation. Characters outside the vocabulary
     * (including the rare characters excluded by {@link #loadData()}) are encoded as a space.
     */
    private static double[] charToVector(char c) {
        double[] vector = new double[dimension];
        Integer index = charToIndex.get(normalizeQuote(c));
        if (index == null) {
            index = charToIndex.get(' ');
        }
        vector[index] = 1.0;
        return vector;
    }

    /**
     * Reads the movie-line and conversation files, fills {@link #conversations} with every
     * conversation of more than three usable lines, and builds the character vocabulary.
     * Lines longer than {@link #maxChars} and malformed records are skipped.
     */
    public static void loadData() throws Exception {
        Map<Integer, String> movieLines = new HashMap<Integer, String>();
        Map<Character, Integer> charCounts = new HashMap<Character, Integer>();
        try (FileInputStream fis = new FileInputStream(new File("movie lines/movie_lines.txt"));
             FileChannelInputStream in = new FileChannelInputStream(fis.getChannel());
             BufferedReader reader = new BufferedReader(new InputStreamReader(in))) {
            String line;
            while ((line = reader.readLine()) != null) {
                String[] fields = line.split(FIELD_SEPARATOR);
                if (fields.length != 5) {
                    continue;
                }
                // The first field is the line ID: 'L' followed by a number.
                int lineNum = Integer.parseInt(fields[0].substring(1));
                String text = fields[4];
                if (text.length() > maxChars) {
                    continue;
                }
                for (char c : text.toCharArray()) {
                    charCounts.merge(c, 1, Integer::sum);
                }
                movieLines.put(lineNum, text);
                if (movieLines.size() % 10000 == 0) {
                    System.out.println(movieLines.size() + "/304713 lines parsed");
                }
            }
        }
        Set<Character> rareChars = new HashSet<Character>();
        for (Map.Entry<Character, Integer> entry : charCounts.entrySet()) {
            if (entry.getValue() < MIN_CHAR_OCCURRENCES) {
                rareChars.add(entry.getKey());
            }
        }
        System.out.println(movieLines.size() + " valid lines read");

        conversations = new ArrayList<List<String>>();
        try (FileInputStream fis = new FileInputStream(new File("movie lines/movie_conversations.txt"));
             FileChannelInputStream in = new FileChannelInputStream(fis.getChannel());
             BufferedReader reader = new BufferedReader(new InputStreamReader(in))) {
            String line;
            while ((line = reader.readLine()) != null) {
                String[] fields = line.split(FIELD_SEPARATOR);
                if (fields.length < 4) {
                    continue; // malformed record without a list of line IDs
                }
                // The fourth field lists line IDs, e.g. ['L194', 'L195'], reduced to "L194L195".
                String lineIds = fields[3].replace("[", "").replace("]", "").replace("'", "").replace(",", "").replace(" ", "");
                List<String> conversation = new ArrayList<String>();
                for (String id : lineIds.split("L")) {
                    if (id.isEmpty()) {
                        continue;
                    }
                    // Rare characters are stored as-is: they never enter the vocabulary, and
                    // charToVector encodes them as a space.
                    String text = movieLines.get(Integer.parseInt(id));
                    if (text != null) {
                        conversation.add(text);
                    }
                }
                if (conversation.size() > 3) {
                    conversations.add(conversation);
                }
            }
        }
        System.out.println(conversations.size() + " conversations mapped");
        movieLines.clear();
        System.gc();

        // Index 0 is the start/end token; other characters get consecutive ids in order of first
        // appearance (after quote normalisation), so the vocabulary order matches saved models.
        charToIndex = new HashMap<Character, Integer>();
        indexToChar = new HashMap<Integer, Character>();
        charToIndex.put('\u0000', 0);
        indexToChar.put(0, '\u0000');
        int nextId = 1;
        System.out.println("Characters:");
        System.out.print("\t");
        for (List<String> conversation : conversations) {
            for (String text : conversation) {
                for (int i = 0; i < text.length(); i++) {
                    char ch = normalizeQuote(text.charAt(i));
                    if (charToIndex.containsKey(ch) || rareChars.contains(ch)) {
                        continue;
                    }
                    System.out.print(ch);
                    charToIndex.put(ch, nextId);
                    indexToChar.put(nextId, ch);
                    nextId++;
                }
            }
        }
        dimension = charToIndex.size();
        System.out.println("\nTotal unique chars = " + (charToIndex.size() - 1));
        System.gc();
        Runtime runtime = Runtime.getRuntime();
        System.out.println(Double.toString((double) (runtime.totalMemory() - runtime.freeMemory()) / 1024.0 / 1024.0) + "MB");
    }

    /** Cuts {@code text} down to Discord's maximum message length. */
    private static String truncateForDiscord(String text) {
        return text.length() <= DISCORD_MAX_MESSAGE_LENGTH ? text : text.substring(0, DISCORD_MAX_MESSAGE_LENGTH);
    }

    /**
     * Answers {@code !<text>} messages with a sampled reply; {@code !stopBot} shuts the bot down.
     */
    private static class MessageListener
    extends ListenerAdapter {
        @Override
        public void onMessageReceived(MessageReceivedEvent event) {
            // Ignore bots, this one included: a generated reply that happens to start with '!'
            // would otherwise be taken as a new command and the bot could keep answering itself.
            if (event.getAuthor().isBot()) {
                return;
            }
            String content = event.getMessage().getContentRaw();
            if (!content.startsWith("!")) {
                return;
            }
            if (content.equals("!stopBot")) {
                // NOTE(review): any user who can post where the bot reads can shut it down.
                event.getChannel().sendMessage("Bot shut down!").queue();
                jda.shutdown();
                System.out.println("Bot shut down!");
                return;
            }
            String input = content.substring(1);
            if (input.length() > maxChars) {
                event.getChannel().sendMessage("Error: Input too long. Must be " + maxChars
                        + " characters at most (Your message was " + input.length() + " characters long).").queue();
                return;
            }
            if (input.length() < 2) {
                event.getChannel().sendMessage("Error: Input too short. Must be at least 2 characters.").queue();
                return;
            }
            System.out.println(input);
            try {
                String output = getResponseFor(input, 0.75);
                event.getChannel().sendMessage(output).queue();
                System.out.println(output + "\n");
            } catch (Exception e) {
                event.getChannel().sendMessage("Error processing message: ").queue();
                System.err.println("Error processing message: ");
                e.printStackTrace();
                StringWriter trace = new StringWriter();
                e.printStackTrace(new PrintWriter(trace, true));
                // A full stack trace easily exceeds Discord's length limit and would be rejected.
                event.getChannel().sendMessage(truncateForDiscord(trace.toString())).queue();
            }
        }
    }
}

