/*
 * Decompiled with CFR 0.152.
 */
package model;

import autodiff.Graph;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import matrix.Matrix;
import model.DenseLayer;
import nonlinearities.Nonlinearity;

/**
 * A simple recurrent layer whose output at each step is
 * {@code f(W * [input; context] + b)}; that output is kept as the
 * {@code context} fed into the next call to {@link #forward}.
 *
 * <p>{@code W} has {@code outputDimension} rows and
 * {@code inputDimension + outputDimension} columns. The input vector is
 * concatenated before the previous output.
 */
public class RnnLayer
implements DenseLayer {
    int inputDimension;
    int outputDimension;
    /** Weight matrix applied to the concatenation {@code [input; context]}. */
    Matrix W;
    /** Bias vector ({@code outputDimension} x 1), created by {@code Matrix.uniform(..., 0.0)}. */
    Matrix b;
    /** Output of the previous {@link #forward} call (the recurrent state). */
    Matrix context;
    /** Nonlinearity applied to the pre-activation. */
    Nonlinearity f;

    /**
     * Creates a layer with randomly initialised weights.
     *
     * @param inputDimension   length of each input vector
     * @param outputDimension  length of the output / recurrent state vector
     * @param hiddenUnit       nonlinearity applied to the pre-activation
     * @param initParamsStdDev spread passed to {@code Matrix.rand} when initialising {@code W}
     * @param rng              random source used to initialise {@code W}
     */
    public RnnLayer(int inputDimension, int outputDimension, Nonlinearity hiddenUnit, double initParamsStdDev, Random rng) {
        this.inputDimension = inputDimension;
        this.outputDimension = outputDimension;
        this.context = new Matrix(outputDimension);
        this.f = hiddenUnit;
        this.W = Matrix.rand(outputDimension, inputDimension + outputDimension, initParamsStdDev, rng);
        this.b = Matrix.uniform(outputDimension, 1, 0.0);
    }

    /**
     * Copy constructor used by {@link #clone()}. It deep-copies the parameter
     * and state matrices and shares the (stateless-by-reference) nonlinearity,
     * without allocating throwaway random weights.
     */
    private RnnLayer(RnnLayer other) {
        this.inputDimension = other.inputDimension;
        this.outputDimension = other.outputDimension;
        this.W = other.W.clone();
        this.b = other.b.clone();
        this.context = other.context.clone();
        this.f = other.f;
    }

    /**
     * Runs one recurrent step and records the operations on {@code g}.
     *
     * <p>The returned matrix also becomes the new {@code context}, so successive
     * calls are chained until {@link #resetState()} is called.
     *
     * @param input the input vector for this step
     * @param g     computation graph the operations are recorded on
     * @return {@code f(W * [input; context] + b)}
     * @throws Exception propagated from the graph operations
     */
    @Override
    public Matrix forward(Matrix input, Graph g) throws Exception {
        Matrix concat = g.concatVectors(input, this.context);
        Matrix preActivation = g.add(g.mul(this.W, concat), this.b);
        this.context = g.nonlin(this.f, preActivation);
        return this.context;
    }

    /** Discards the recurrent state by replacing {@code context} with a fresh {@code Matrix(outputDimension)}. */
    @Override
    public void resetState() {
        this.context = new Matrix(this.outputDimension);
    }

    /**
     * Returns the current recurrent state.
     *
     * @return the live {@code context} matrix (not a copy)
     */
    public Matrix getContext() {
        return this.context;
    }

    /**
     * Returns the trainable parameters.
     *
     * @return a new mutable list containing {@code W} then {@code b}; the
     *         matrices themselves are the live instances, not copies
     */
    @Override
    public List<Matrix> getParameters() {
        List<Matrix> result = new ArrayList<Matrix>(2);
        result.add(this.W);
        result.add(this.b);
        return result;
    }

    /**
     * Creates an independent copy of this layer. Weights, bias and context are
     * deep-copied via {@code Matrix.clone()}; the nonlinearity is shared.
     *
     * @return the copy
     */
    @Override
    public DenseLayer clone() {
        return new RnnLayer(this);
    }

    /**
     * Writes the layer to {@code fos}. The layout is: input dimension (int),
     * output dimension (int), {@code W}, {@code b}, then {@code context}.
     *
     * @param fos destination stream
     * @throws Exception propagated from the stream or {@code Matrix.save}
     */
    @Override
    public void saveState(DataOutputStream fos) throws Exception {
        fos.writeInt(this.inputDimension);
        fos.writeInt(this.outputDimension);
        for (Matrix param : this.getParameters()) {
            param.save(fos);
        }
        this.context.save(fos);
    }

    /**
     * Reads state written by {@link #saveState} into this layer, in the same order.
     *
     * <p>NOTE(review): the dimensions are read but {@code W}, {@code b} and
     * {@code context} are loaded in place and are not re-allocated. This is only
     * correct if {@code Matrix.load} resizes the target or the stored dimensions
     * match this instance. Confirm against {@code Matrix.load}.
     *
     * @param fis source stream
     * @throws Exception propagated from the stream or {@code Matrix.load}
     */
    @Override
    public void loadState(DataInputStream fis) throws Exception {
        this.inputDimension = fis.readInt();
        this.outputDimension = fis.readInt();
        for (Matrix param : this.getParameters()) {
            param.load(fis);
        }
        this.context.load(fis);
    }
}

