/*
 * Decompiled with CFR 0.152.
 */
package datasets;

import datastructs.DataSequence;
import datastructs.DataSet;
import datastructs.DataStep;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Random;
import loss.LossMultiDimensionalBinary;
import loss.LossSumOfSquares;
import model.Model;
import nonlinearities.Nonlinearity;
import nonlinearities.SigmoidUnit;
import theGhastModding.console.main.Console;

/**
 * Embedded Reber grammar dataset: symbol sequences sampled from a fixed finite-state grammar,
 * each step paired with the set of symbols the grammar allows next.
 *
 * <p>Symbols are encoded over the alphabet {B, T, P, S, X, V, E} (vector indices 0-6). The
 * training, validation and testing splits are sampled independently from the same grammar, in
 * that order, from the supplied {@link Random}.
 */
public class EmbeddedReberGrammar extends DataSet {

    /** Sequences generated per split by {@link #EmbeddedReberGrammar(Random)}. */
    private static final int DEFAULT_SEQUENCES_PER_SPLIT = 1000;

    // Symbol indices; each is also the position of that symbol in the one-hot / multi-hot vectors.
    private static final int B = 0;
    private static final int T = 1;
    private static final int P = 2;
    private static final int S = 3;
    private static final int X = 4;
    private static final int V = 5;
    private static final int E = 6;

    /** Size of the symbol alphabet, i.e. the length of every observation and target vector. */
    private static final int ALPHABET_SIZE = 7;

    /** Every sequence starts here; returning here ends the sequence. */
    private static final int START_STATE = 0;

    /**
     * Grammar transition table, indexed by state id; each entry is (next state, emitted symbol).
     * The order of transitions within a state determines which branch {@code r.nextInt(n)} picks.
     *
     * <p>Every sequence is {@code B} followed by either {@code T <inner> T E} (states 2-10) or
     * {@code P <inner> P E} (states 11-18); both branches share state 8 for the final {@code E}.
     * {@code <inner>} is a plain Reber-grammar string that starts with {@code B} and ends with
     * {@code E}. The symbol emitted just before the final {@code E} (state 7 or 16) repeats the
     * choice made in state 1, so predicting it requires remembering that choice across the whole
     * inner string.
     */
    private static final State[] STATES = new State[]{
        /*  0 */ new State(new Transition[]{new Transition(1, B)}),
        /*  1 */ new State(new Transition[]{new Transition(2, T), new Transition(11, P)}),
        // T branch
        /*  2 */ new State(new Transition[]{new Transition(3, B)}),
        /*  3 */ new State(new Transition[]{new Transition(4, T), new Transition(9, P)}),
        /*  4 */ new State(new Transition[]{new Transition(4, S), new Transition(5, X)}),
        /*  5 */ new State(new Transition[]{new Transition(6, S), new Transition(9, X)}),
        /*  6 */ new State(new Transition[]{new Transition(7, E)}),
        /*  7 */ new State(new Transition[]{new Transition(8, T)}),
        /*  8 */ new State(new Transition[]{new Transition(0, E)}),
        /*  9 */ new State(new Transition[]{new Transition(9, T), new Transition(10, V)}),
        /* 10 */ new State(new Transition[]{new Transition(5, P), new Transition(6, V)}),
        // P branch
        /* 11 */ new State(new Transition[]{new Transition(12, B)}),
        /* 12 */ new State(new Transition[]{new Transition(13, T), new Transition(17, P)}),
        /* 13 */ new State(new Transition[]{new Transition(13, S), new Transition(14, X)}),
        /* 14 */ new State(new Transition[]{new Transition(15, S), new Transition(17, X)}),
        /* 15 */ new State(new Transition[]{new Transition(16, E)}),
        /* 16 */ new State(new Transition[]{new Transition(8, P)}),
        /* 17 */ new State(new Transition[]{new Transition(17, T), new Transition(18, V)}),
        /* 18 */ new State(new Transition[]{new Transition(14, P), new Transition(15, V)}),
    };

    /**
     * Builds the dataset with 1000 sequences in each of the training, validation and testing
     * splits.
     *
     * @param r source of randomness for sequence sampling
     */
    public EmbeddedReberGrammar(Random r) throws Exception {
        this(r, DEFAULT_SEQUENCES_PER_SPLIT);
    }

    /**
     * Builds the dataset with a caller-chosen number of sequences per split.
     *
     * @param r source of randomness for sequence sampling
     * @param sequencesPerSplit sequences in each of the training, validation and testing splits
     */
    public EmbeddedReberGrammar(Random r, int sequencesPerSplit) throws Exception {
        this.inputDimension = ALPHABET_SIZE;
        this.outputDimension = ALPHABET_SIZE;
        this.lossTraining = new LossSumOfSquares();
        this.lossReporting = new LossMultiDimensionalBinary();
        // Order matters: all three splits draw from the same Random stream.
        this.training = generateSequences(r, sequencesPerSplit);
        this.validation = generateSequences(r, sequencesPerSplit);
        this.testing = generateSequences(r, sequencesPerSplit);
    }

    /**
     * Samples sequences from the embedded Reber grammar.
     *
     * <p>Each sequence walks the grammar from state 0 until it returns to state 0. Every transition
     * emits one symbol. Each recorded step pairs that symbol (one-hot) with a multi-hot target
     * marking every symbol the grammar allows next. The final {@code E}, which returns to state 0,
     * has no successor, so no step is recorded for it.
     *
     * @param r source of randomness; {@code r.nextInt(n)} is drawn only in states with
     *     {@code n > 1} outgoing transitions
     * @param sequences number of sequences to generate; zero or negative yields an empty list
     * @return a new list holding the generated sequences
     * @throws NullPointerException if {@code r} is null
     */
    public static List<DataSequence> generateSequences(Random r, int sequences) {
        Objects.requireNonNull(r, "r");
        ArrayList<DataSequence> result = new ArrayList<DataSequence>(Math.max(sequences, 0));
        for (int sequence = 0; sequence < sequences; sequence++) {
            result.add(generateSequence(r));
        }
        return result;
    }

    /** Samples one sequence by walking the grammar from the start state back to it. */
    private static DataSequence generateSequence(Random r) {
        ArrayList<DataStep> steps = new ArrayList<DataStep>();
        int stateId = START_STATE;
        while (true) {
            Transition taken = chooseTransition(STATES[stateId].transitions, r);
            double[] observation = new double[ALPHABET_SIZE];
            observation[taken.token] = 1.0;
            stateId = taken.next_state_id;
            if (stateId == START_STATE) {
                break; // sequence complete; nothing follows the final symbol
            }
            steps.add(new DataStep(observation, allowedNextSymbols(STATES[stateId])));
        }
        return new DataSequence(steps);
    }

    /**
     * Picks one outgoing transition uniformly at random. A state with a single transition draws
     * nothing from {@code r}, so the random stream advances only at real branch points.
     */
    private static Transition chooseTransition(Transition[] options, Random r) {
        if (options.length == 1) {
            return options[0];
        }
        return options[r.nextInt(options.length)];
    }

    /** Returns a multi-hot vector with 1.0 at every symbol {@code state} can emit next. */
    private static double[] allowedNextSymbols(State state) {
        double[] target = new double[ALPHABET_SIZE];
        for (Transition t : state.transitions) {
            target[t.token] = 1.0;
        }
        return target;
    }

    /** This dataset produces no custom report; the method intentionally does nothing. */
    @Override
    public void DisplayReport(Model model, Random rng, Console c) throws Exception {
    }

    /**
     * Returns a sigmoid output unit. Targets can have more than one active symbol (any state with
     * two outgoing transitions), so each output unit is squashed independently rather than
     * normalised across symbols.
     */
    @Override
    public Nonlinearity getModelOutputUnitToUse() {
        return new SigmoidUnit();
    }

    /** A grammar state: the transitions available from it. */
    public static class State {
        public Transition[] transitions;

        public State(Transition[] transitions) {
            this.transitions = transitions;
        }
    }

    /** A grammar edge: the state it leads to and the symbol index it emits. */
    public static class Transition {
        public int next_state_id;
        public int token;

        public Transition(int next_state_id, int token) {
            this.next_state_id = next_state_id;
            this.token = token;
        }
    }
}

