/*
 * Decompiled with CFR 0.152.
 */
package model;

import autodiff.Graph;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import matrix.Matrix;
import model.FeedForwardLayer;
import model.LstmLayer;
import model.Model;
import nonlinearities.ReLuUnit;

/**
 * A layer composed of several parallel sub-layers ("paths") that all receive the same input.
 *
 * <p>Each path is an {@link LstmLayer} when the layer is recurrent, otherwise a
 * {@link FeedForwardLayer} with a {@link ReLuUnit} activation. The paths' outputs are summed
 * with {@link Graph#add} to produce this layer's output. {@link #forward} evaluates the paths
 * concurrently, one thread per path.
 */
public class MultipathLayer
implements Model {
    private Model[] layers;

    /**
     * Creates a multipath layer whose paths all map {@code inputDimension} to
     * {@code outputDimension}.
     *
     * @param inputDimension   input size passed to every path's constructor
     * @param outputDimension  output size passed to every path's constructor
     * @param recurrent        {@code true} for {@link LstmLayer} paths, {@code false} for
     *                         {@link FeedForwardLayer} paths with {@link ReLuUnit}
     * @param paths            number of parallel paths; must be at least 1
     * @param initParamsStdDev parameter-initialisation std-dev passed to every path
     * @param rng              random source shared by all paths, used in path order
     * @throws IllegalArgumentException if {@code paths < 1}
     */
    public MultipathLayer(int inputDimension, int outputDimension, boolean recurrent, int paths, double initParamsStdDev, Random rng) {
        // forward() sums path outputs starting from path 0, so at least one path is required.
        if (paths < 1) {
            throw new IllegalArgumentException("paths must be >= 1, got " + paths);
        }
        this.layers = new Model[paths];
        for (int i = 0; i < paths; i++) {
            this.layers[i] = recurrent
                    ? new LstmLayer(inputDimension, outputDimension, initParamsStdDev, rng)
                    : new FeedForwardLayer(inputDimension, outputDimension, new ReLuUnit(), initParamsStdDev, rng);
        }
    }

    /** Wraps already-built paths; used by {@link #clone()} to avoid building throwaway layers. */
    private MultipathLayer(Model[] layers) {
        this.layers = layers;
    }

    /**
     * Runs every path on {@code input} concurrently and returns the sum of their outputs.
     *
     * <p>Each path runs on its own thread against a private {@link Graph}. After all threads
     * finish, the paths' recorded backprops are merged into {@code g} in path order. The outputs
     * are then combined with {@code g.add}. If any path fails, the first failure (in path order)
     * is rethrown and {@code g} receives none of the paths' backprops.
     *
     * @param input the input shared by all paths
     * @param g     the graph that receives the merged backprops and the summation ops
     * @return the element-wise sum of all path outputs
     * @throws Exception the first exception thrown by any path, or
     *                   {@link InterruptedException} if interrupted while waiting
     */
    @Override
    public Matrix forward(Matrix input, Graph g) throws Exception {
        int n = this.layers.length;
        ForwarderThread[] workers = new ForwarderThread[n];
        Thread[] threads = new Thread[n];
        for (int i = 0; i < n; i++) {
            workers[i] = new ForwarderThread(this.layers[i], input, g);
            threads[i] = new Thread(workers[i]);
        }
        for (Thread t : threads) {
            t.start();
        }
        // Wait for every path before inspecting results so no worker is still running when a
        // failure is reported. join() also makes the workers' writes visible to this thread.
        for (Thread t : threads) {
            t.join();
        }
        // Report failures before mutating g (previously they were printed and swallowed, which
        // led to a null output or an unrelated NullPointerException in g.add).
        for (ForwarderThread w : workers) {
            if (w.failure != null) {
                rethrow(w.failure);
            }
        }
        // Same ordering as before: merge all path backprops first, then record the summation.
        for (ForwarderThread w : workers) {
            g.addBackprops(w.g.getBackprops());
        }
        Matrix finalOutput = workers[0].output;
        for (int i = 1; i < n; i++) {
            finalOutput = g.add(finalOutput, workers[i].output);
        }
        return finalOutput;
    }

    /** Rethrows a worker failure unchanged when possible, wrapping only non-standard throwables. */
    private static void rethrow(Throwable t) throws Exception {
        if (t instanceof Exception) {
            throw (Exception) t;
        }
        if (t instanceof Error) {
            throw (Error) t;
        }
        throw new RuntimeException(t);
    }

    /** Resets the state of every path, in path order. */
    @Override
    public void resetState() {
        for (Model l : this.layers) {
            l.resetState();
        }
    }

    /**
     * Returns the parameters of all paths, concatenated in path order.
     *
     * @return a new list owned by the caller
     */
    @Override
    public List<Matrix> getParameters() {
        List<Matrix> toReturn = new ArrayList<Matrix>();
        for (Model l : this.layers) {
            toReturn.addAll(l.getParameters());
        }
        return toReturn;
    }

    /**
     * Returns a copy whose paths are produced by each path's own {@code clone()}, in path order.
     */
    @Override
    public Model clone() {
        Model[] copies = new Model[this.layers.length];
        for (int i = 0; i < copies.length; i++) {
            copies[i] = this.layers[i].clone();
        }
        return new MultipathLayer(copies);
    }

    /**
     * Writes each path's state to {@code fos} in path order. The number of paths is not
     * written, so the loading layer must be built with the same path count and kind.
     */
    @Override
    public void saveState(DataOutputStream fos) throws Exception {
        for (Model l : this.layers) {
            l.saveState(fos);
        }
    }

    /**
     * Reads each path's state from {@code fis} in path order, mirroring {@link #saveState}.
     */
    @Override
    public void loadState(DataInputStream fis) throws Exception {
        for (Model l : this.layers) {
            l.loadState(fis);
        }
    }

    /**
     * Runs one path's forward pass on a private graph. The fields {@code output} and
     * {@code failure} are written by the worker thread and must only be read after
     * {@link Thread#join()}.
     */
    private static final class ForwarderThread
    implements Runnable {
        private final Model layer;
        private final Matrix m;
        private final Graph g;
        Matrix output;
        Throwable failure;

        ForwarderThread(Model layer, Matrix m, Graph parent) throws Exception {
            this.layer = layer;
            this.m = m;
            // The child graph is configured from parent.applyBackprop(); presumably this tells it
            // whether to record backprop steps like the parent does. TODO confirm in Graph.
            this.g = new Graph(parent.applyBackprop());
        }

        @Override
        public void run() {
            try {
                this.output = this.layer.forward(this.m, this.g);
            }
            catch (Throwable t) {
                // Capture everything (including Errors) so forward() can report it instead of
                // silently continuing with a null output.
                this.failure = t;
            }
        }
    }
}

