/*
 * Decompiled with CFR 0.152.
 */
package model;

import autodiff.Graph;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import matrix.Matrix;
import model.DenseLayer;
import nonlinearities.Nonlinearity;
import nonlinearities.SigmoidUnit;
import nonlinearities.TanhUnit;

/**
 * Gated recurrent unit (GRU) layer expressed with {@code autodiff.Graph} operations.
 *
 * <p>Each of the three gates owns an input-to-hidden matrix ({@code IH*}), a hidden-to-hidden
 * matrix ({@code HH*}) and a bias ({@code B*}):
 * <ul>
 *   <li><b>mix</b> – weights how much of the previous hidden state survives into the output;</li>
 *   <li><b>reset</b> – element-wise scales the previous hidden state before it feeds the
 *       candidate;</li>
 *   <li><b>new</b> – the candidate activation blended in with weight {@code oneMinus(mix)}.</li>
 * </ul>
 *
 * <p>The hidden state is held in {@link #context}. Every {@link #forward} call overwrites it
 * with that step's output. State therefore carries across calls until {@link #resetState()}
 * replaces it.
 */
public class GruLayer implements DenseLayer {
    /** Input size given to the constructor; column count of the {@code IH*} matrices. */
    int inputDimension;
    /** Hidden/output size; row count of every weight matrix and size of the state vector. */
    int outputDimension;

    // Mix gate parameters.
    Matrix IHmix;
    Matrix HHmix;
    Matrix Bmix;
    // Candidate ("new") parameters.
    Matrix IHnew;
    Matrix HHnew;
    Matrix Bnew;
    // Reset gate parameters.
    Matrix IHreset;
    Matrix HHreset;
    Matrix Breset;
    /** Hidden state carried from one forward step to the next. */
    Matrix context;

    // Activation functions; clone() shares these instances rather than copying them.
    Nonlinearity fMix = new SigmoidUnit();
    Nonlinearity fReset = new SigmoidUnit();
    Nonlinearity fNew = new TanhUnit();

    /**
     * Creates a layer with randomly initialised weights and an all-{@code 0.0} state.
     *
     * @param inputDimension   size of each input vector
     * @param outputDimension  size of the hidden state / output
     * @param initParamsStdDev spread passed to {@code Matrix.rand} for every weight matrix
     * @param rng              random source passed to {@code Matrix.rand}
     */
    public GruLayer(int inputDimension, int outputDimension, double initParamsStdDev, Random rng) {
        final int in = inputDimension;
        final int out = outputDimension;
        final double sd = initParamsStdDev;
        this.inputDimension = in;
        this.outputDimension = out;
        // NOTE: Matrix.rand is handed the shared rng, so the call order below presumably fixes
        // which draws land in which matrix for a given seed -- keep it stable.
        this.IHmix = Matrix.rand(out, in, sd, rng);
        this.HHmix = Matrix.rand(out, out, sd, rng);
        this.Bmix = Matrix.uniform(out, 1, 0.0);
        this.IHnew = Matrix.rand(out, in, sd, rng);
        this.HHnew = Matrix.rand(out, out, sd, rng);
        this.Bnew = Matrix.uniform(out, 1, 0.0);
        this.IHreset = Matrix.rand(out, in, sd, rng);
        this.HHreset = Matrix.rand(out, out, sd, rng);
        // NOTE(review): the other biases use Matrix.uniform(out, 1, 0.0); this is presumably
        // an equivalent out x 1 zero vector -- confirm before unifying.
        this.Breset = new Matrix(out);
        this.context = Matrix.uniform(out, 1, 0.0);
    }

    /**
     * Runs one recurrent step.
     *
     * <p>The output is {@code mix * context + oneMinus(mix) * candidate}. All products here
     * are {@code Graph.elmul}. It is stored as the new {@link #context} and also returned.
     *
     * <p>The sequence of {@code Graph} calls is kept identical to the classic formulation.
     * Do not reorder it without checking how {@code Graph} records operations.
     *
     * @param input the input vector for this step
     * @param g     graph used to build (and presumably record) the computation
     * @return the new hidden state
     */
    @Override
    public Matrix forward(Matrix input, Graph g) throws Exception {
        final Matrix prev = this.context;

        final Matrix mixGate =
                gate(g, this.fMix, g.mul(this.IHmix, input), this.HHmix, prev, this.Bmix);
        final Matrix resetGate =
                gate(g, this.fReset, g.mul(this.IHreset, input), this.HHreset, prev, this.Breset);

        // The input projection is taken before the reset-scaled state, as in the original order.
        final Matrix candidateFromInput = g.mul(this.IHnew, input);
        final Matrix resetScaled = g.elmul(resetGate, prev);
        final Matrix candidate =
                gate(g, this.fNew, candidateFromInput, this.HHnew, resetScaled, this.Bnew);

        final Matrix kept = g.elmul(mixGate, prev);
        final Matrix fresh = g.elmul(g.oneMinus(mixGate), candidate);
        final Matrix next = g.add(kept, fresh);

        this.context = next;
        return next;
    }

    /**
     * Computes {@code f(fromInput + hiddenWeights * hidden + bias)} on {@code g}.
     * The Graph calls run in the order mul, add, add, nonlin.
     */
    private static Matrix gate(Graph g, Nonlinearity f, Matrix fromInput,
                               Matrix hiddenWeights, Matrix hidden, Matrix bias) throws Exception {
        final Matrix recurrent = g.mul(hiddenWeights, hidden);
        final Matrix preActivation = g.add(g.add(fromInput, recurrent), bias);
        return g.nonlin(f, preActivation);
    }

    /**
     * Discards the hidden state.
     *
     * <p>A fresh Matrix is allocated instead of mutating the old one. Matrices returned by
     * earlier {@link #forward} calls therefore stay untouched.
     */
    @Override
    public void resetState() {
        this.context = new Matrix(this.outputDimension);
    }

    /** @return the live hidden-state matrix (not a copy) */
    public Matrix getContext() {
        return this.context;
    }

    /**
     * Returns a new mutable list of the nine trainable matrices.
     *
     * <p>The gates appear in the order mix, new, reset, and each gate lists IH, HH, B. The
     * hidden state is excluded. This order is also the on-disk layout used by
     * {@link #saveState} and {@link #loadState}.
     */
    @Override
    public List<Matrix> getParameters() {
        final Matrix[] ordered = {
            this.IHmix, this.HHmix, this.Bmix,
            this.IHnew, this.HHnew, this.Bnew,
            this.IHreset, this.HHreset, this.Breset
        };
        final List<Matrix> params = new ArrayList<Matrix>(ordered.length);
        for (final Matrix m : ordered) {
            params.add(m);
        }
        return params;
    }

    /**
     * Returns a deep copy of the weights, biases and hidden state.
     *
     * <p>Each matrix is copied with {@code Matrix.clone()}. The nonlinearity objects are
     * shared with this layer.
     */
    @Override
    public DenseLayer clone() {
        // The only constructor needs dimensions and an rng. A 1x1 placeholder is built and
        // every field is overwritten below.
        final GruLayer copy = new GruLayer(1, 1, 1.0, new Random());
        copy.inputDimension = this.inputDimension;
        copy.outputDimension = this.outputDimension;

        copy.IHmix = this.IHmix.clone();
        copy.HHmix = this.HHmix.clone();
        copy.Bmix = this.Bmix.clone();

        copy.IHnew = this.IHnew.clone();
        copy.HHnew = this.HHnew.clone();
        copy.Bnew = this.Bnew.clone();

        copy.IHreset = this.IHreset.clone();
        copy.HHreset = this.HHreset.clone();
        copy.Breset = this.Breset.clone();

        copy.context = this.context.clone();

        copy.fMix = this.fMix;
        copy.fReset = this.fReset;
        copy.fNew = this.fNew;
        return copy;
    }

    /**
     * Serializes the layer in this order: the two dimensions as ints, then each parameter in
     * {@link #getParameters()} order, then the hidden state.
     */
    @Override
    public void saveState(DataOutputStream fos) throws Exception {
        fos.writeInt(this.inputDimension);
        fos.writeInt(this.outputDimension);
        final List<Matrix> params = this.getParameters();
        for (int i = 0; i < params.size(); i++) {
            params.get(i).save(fos);
        }
        this.context.save(fos);
    }

    /**
     * Reads back what {@link #saveState} wrote, using the same field order.
     *
     * <p>NOTE(review): the dimensions are read, but the existing Matrix objects are loaded in
     * place rather than reallocated. This presumably relies on {@code Matrix.load} adopting
     * the stored size -- confirm.
     */
    @Override
    public void loadState(DataInputStream fis) throws Exception {
        this.inputDimension = fis.readInt();
        this.outputDimension = fis.readInt();
        final List<Matrix> params = this.getParameters();
        for (int i = 0; i < params.size(); i++) {
            params.get(i).load(fis);
        }
        this.context.load(fis);
    }
}

