/*
 * Decompiled with CFR 0.152.
 */
package model;

import autodiff.Graph;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import matrix.Matrix;
import model.DenseLayer;
import nonlinearities.Nonlinearity;
import nonlinearities.SigmoidUnit;
import nonlinearities.TanhUnit;

/**
 * A single recurrent LSTM layer built on the project's {@link Graph} / {@link Matrix} primitives.
 *
 * <p>Each call to {@link #forward(Matrix, Graph)} performs one time step. With {@code x} the
 * input, {@code h} the previous hidden context and {@code c} the previous cell context, the step
 * computes:
 * <pre>
 *   i  = fInputGate (Wix*x + Wih*h + bi)
 *   f  = fForgetGate(Wfx*x + Wfh*h + bf)
 *   o  = fOutputGate(Wox*x + Woh*h + bo)
 *   z  = fCellInput (Wcx*x + Wch*h + bc)
 *   c' = f .* c + i .* z
 *   h' = o .* fCellOutput(c')
 * </pre>
 * where {@code .*} is {@code Graph.elmul}. The gate nonlinearities are {@link SigmoidUnit}
 * instances and the cell nonlinearities are {@link TanhUnit} instances.
 *
 * <p>The layer is stateful: {@code h'} and {@code c'} are stored on the instance and used by the
 * next call. Use {@link #resetState()} to start a new sequence.
 */
public class LstmLayer
implements DenseLayer {
    int inputDimension;
    int outputDimension;

    // Input gate parameters (input weights, recurrent weights, bias).
    Matrix Wix;
    Matrix Wih;
    Matrix bi;

    // Forget gate parameters.
    Matrix Wfx;
    Matrix Wfh;
    Matrix bf;

    // Output gate parameters.
    Matrix Wox;
    Matrix Woh;
    Matrix bo;

    // Candidate cell input parameters.
    Matrix Wcx;
    Matrix Wch;
    Matrix bc;

    // Recurrent state carried from one forward() call to the next.
    Matrix hiddenContext;
    Matrix cellContext;

    Nonlinearity fInputGate = new SigmoidUnit();
    Nonlinearity fForgetGate = new SigmoidUnit();
    Nonlinearity fOutputGate = new SigmoidUnit();
    Nonlinearity fCellInput = new TanhUnit();
    Nonlinearity fCellOutput = new TanhUnit();

    /**
     * Creates a layer with randomly initialised weights and freshly allocated recurrent state.
     *
     * <p>Weights come from {@code Matrix.rand}; the input, output and cell biases come from
     * {@code Matrix.uniform(..., 0.0)} and the forget bias from {@code Matrix.ones}. The weight
     * matrices are drawn from {@code rng} in a fixed order (input, forget, output, cell), so a
     * seeded generator yields a reproducible layer.
     *
     * @param inputDimension   number of columns of each input weight matrix
     * @param outputDimension  number of rows of every parameter and the size passed to the
     *                         context matrices
     * @param initParamsStdDev forwarded to {@code Matrix.rand}
     *                         (presumably the standard deviation of the initial weights)
     * @param rng              random source forwarded to {@code Matrix.rand}
     */
    public LstmLayer(int inputDimension, int outputDimension, double initParamsStdDev, Random rng) {
        this.inputDimension = inputDimension;
        this.outputDimension = outputDimension;
        this.Wix = Matrix.rand(outputDimension, inputDimension, initParamsStdDev, rng);
        this.Wih = Matrix.rand(outputDimension, outputDimension, initParamsStdDev, rng);
        this.bi = Matrix.uniform(outputDimension, 1, 0.0);
        this.Wfx = Matrix.rand(outputDimension, inputDimension, initParamsStdDev, rng);
        this.Wfh = Matrix.rand(outputDimension, outputDimension, initParamsStdDev, rng);
        this.bf = Matrix.ones(outputDimension, 1);
        this.Wox = Matrix.rand(outputDimension, inputDimension, initParamsStdDev, rng);
        this.Woh = Matrix.rand(outputDimension, outputDimension, initParamsStdDev, rng);
        this.bo = Matrix.uniform(outputDimension, 1, 0.0);
        this.Wcx = Matrix.rand(outputDimension, inputDimension, initParamsStdDev, rng);
        this.Wch = Matrix.rand(outputDimension, outputDimension, initParamsStdDev, rng);
        this.bc = Matrix.uniform(outputDimension, 1, 0.0);
        this.hiddenContext = new Matrix(outputDimension);
        this.cellContext = new Matrix(outputDimension);
    }

    /**
     * Copy constructor backing {@link #clone()}: copies the dimensions and takes an independent
     * {@code Matrix.clone()} of every parameter and of both context matrices.
     */
    private LstmLayer(LstmLayer source) {
        this.inputDimension = source.inputDimension;
        this.outputDimension = source.outputDimension;
        this.Wix = source.Wix.clone();
        this.Wih = source.Wih.clone();
        this.bi = source.bi.clone();
        this.Wfx = source.Wfx.clone();
        this.Wfh = source.Wfh.clone();
        this.bf = source.bf.clone();
        this.Wox = source.Wox.clone();
        this.Woh = source.Woh.clone();
        this.bo = source.bo.clone();
        this.Wcx = source.Wcx.clone();
        this.Wch = source.Wch.clone();
        this.bc = source.bc.clone();
        // The contexts are mutable (loadState() overwrites them in place), so they must be
        // copied rather than shared between the source layer and its clone.
        this.hiddenContext = source.hiddenContext.clone();
        this.cellContext = source.cellContext.clone();
    }

    /**
     * Runs one time step and updates the stored hidden and cell contexts.
     *
     * @param input the input for this time step
     * @param g     the graph through which every operation of this step is issued
     * @return the new hidden context, which is also stored for the next call
     * @throws Exception propagated from the {@link Graph} operations
     */
    @Override
    public Matrix forward(Matrix input, Graph g) throws Exception {
        // The gates are evaluated in the order input, forget, output, cell so that operations
        // reach the graph in the same sequence as before the refactoring.
        Matrix inputGate = this.gate(g, this.fInputGate, this.Wix, this.Wih, this.bi, input);
        Matrix forgetGate = this.gate(g, this.fForgetGate, this.Wfx, this.Wfh, this.bf, input);
        Matrix outputGate = this.gate(g, this.fOutputGate, this.Wox, this.Woh, this.bo, input);
        Matrix cellInput = this.gate(g, this.fCellInput, this.Wcx, this.Wch, this.bc, input);

        // New cell state: retain part of the old cell, write part of the candidate.
        Matrix retainCell = g.elmul(forgetGate, this.cellContext);
        Matrix writeCell = g.elmul(inputGate, cellInput);
        Matrix cellAct = g.add(retainCell, writeCell);

        Matrix output = g.elmul(outputGate, g.nonlin(this.fCellOutput, cellAct));
        this.hiddenContext = output;
        this.cellContext = cellAct;
        return output;
    }

    /**
     * Computes {@code activation(inputWeights * input + hiddenWeights * hiddenContext + bias)}
     * through the graph, using the hidden context stored before this time step.
     */
    private Matrix gate(Graph g, Nonlinearity activation, Matrix inputWeights, Matrix hiddenWeights,
            Matrix bias, Matrix input) throws Exception {
        Matrix fromInput = g.mul(inputWeights, input);
        Matrix fromHidden = g.mul(hiddenWeights, this.hiddenContext);
        return g.nonlin(activation, g.add(g.add(fromInput, fromHidden), bias));
    }

    /**
     * Discards the recurrent state by replacing both contexts with freshly constructed
     * matrices of size {@code outputDimension}.
     */
    @Override
    public void resetState() {
        this.hiddenContext = new Matrix(this.outputDimension);
        this.cellContext = new Matrix(this.outputDimension);
    }

    /**
     * Returns the trainable parameters in a fixed order: input gate, forget gate, output gate,
     * cell input; within each group input weights, recurrent weights, bias.
     *
     * <p>{@link #saveState(DataOutputStream)} and {@link #loadState(DataInputStream)} rely on
     * this order, so changing it breaks compatibility with previously saved data.
     *
     * @return a new list holding references to the live parameter matrices (not copies)
     */
    @Override
    public List<Matrix> getParameters() {
        List<Matrix> result = new ArrayList<Matrix>();
        result.add(this.Wix);
        result.add(this.Wih);
        result.add(this.bi);
        result.add(this.Wfx);
        result.add(this.Wfh);
        result.add(this.bf);
        result.add(this.Wox);
        result.add(this.Woh);
        result.add(this.bo);
        result.add(this.Wcx);
        result.add(this.Wch);
        result.add(this.bc);
        return result;
    }

    /**
     * Returns an independent copy of this layer: the parameters and the current hidden and cell
     * contexts are all copied via {@code Matrix.clone()}, so the copy and the original do not
     * share any of these matrices.
     *
     * @return the copied layer
     */
    @Override
    public DenseLayer clone() {
        return new LstmLayer(this);
    }

    /**
     * @return the live hidden context matrix (not a copy)
     */
    public Matrix getHiddenContext() {
        return this.hiddenContext;
    }

    /**
     * @return the live cell context matrix (not a copy)
     */
    public Matrix getCellContext() {
        return this.cellContext;
    }

    /**
     * Writes the layer to {@code fos}: input dimension, output dimension, every parameter in
     * {@link #getParameters()} order, then the hidden context and the cell context.
     *
     * @param fos destination stream; it is not closed by this method
     * @throws Exception propagated from the stream or from {@code Matrix.save}
     */
    @Override
    public void saveState(DataOutputStream fos) throws Exception {
        fos.writeInt(this.inputDimension);
        fos.writeInt(this.outputDimension);
        for (Matrix parameter : this.getParameters()) {
            parameter.save(fos);
        }
        this.hiddenContext.save(fos);
        this.cellContext.save(fos);
    }

    /**
     * Restores the layer from {@code fis}, reading the fields in the order written by
     * {@link #saveState(DataOutputStream)}. The existing matrix objects are loaded in place.
     *
     * <p>NOTE(review): whether a stream saved with different dimensions is handled depends on
     * whether {@code Matrix.load} resizes its target - confirm against {@code Matrix}.
     *
     * @param fis source stream; it is not closed by this method
     * @throws Exception propagated from the stream or from {@code Matrix.load}
     */
    @Override
    public void loadState(DataInputStream fis) throws Exception {
        this.inputDimension = fis.readInt();
        this.outputDimension = fis.readInt();
        for (Matrix parameter : this.getParameters()) {
            parameter.load(fis);
        }
        this.hiddenContext.load(fis);
        this.cellContext.load(fis);
    }
}

