/*
 * Decompiled with CFR 0.152.
 */
package model;

import autodiff.Graph;
import java.io.EOFException;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import matrix.Matrix;
import model.Model;
import nonlinearities.Nonlinearity;
import nonlinearities.SigmoidUnit;
import nonlinearities.TanhUnit;
import theGhastModding.utils.math.ByteConverters;

/**
 * A single LSTM layer built on the {@link Graph} autodiff operations.
 *
 * <p>Each call to {@link #forward} consumes one input, updates the recurrent hidden and cell
 * state stored in this layer, and returns the new hidden state:
 * <pre>
 *   i  = sigmoid(Wix*x + Wih*h + bi)
 *   f  = sigmoid(Wfx*x + Wfh*h + bf)
 *   o  = sigmoid(Wox*x + Woh*h + bo)
 *   z  = tanh(Wcx*x + Wch*h + bc)
 *   c' = f (.) c + i (.) z
 *   h' = o (.) tanh(c')
 * </pre>
 * where {@code (.)} is {@link Graph#elmul} (assumed element-wise product).
 */
public class LstmLayer
implements Model {
    int inputDimension;
    int outputDimension;
    Matrix Wix;
    Matrix Wih;
    Matrix bi;
    Matrix Wfx;
    Matrix Wfh;
    Matrix bf;
    Matrix Wox;
    Matrix Woh;
    Matrix bo;
    Matrix Wcx;
    Matrix Wch;
    Matrix bc;
    /** Hidden state carried between {@link #forward} calls (also the last output). */
    Matrix hiddenContext;
    /** Cell state carried between {@link #forward} calls. */
    Matrix cellContext;
    Nonlinearity fInputGate = new SigmoidUnit();
    Nonlinearity fForgetGate = new SigmoidUnit();
    Nonlinearity fOutputGate = new SigmoidUnit();
    Nonlinearity fCellInput = new TanhUnit();
    Nonlinearity fCellOutput = new TanhUnit();

    /**
     * Creates a layer with randomly initialized weights.
     *
     * @param inputDimension   size of each input vector
     * @param outputDimension  size of the hidden/cell state and of each output
     * @param initParamsStdDev spread passed to {@code Matrix.rand} for weight initialization
     * @param rng              random source for weight initialization
     */
    public LstmLayer(int inputDimension, int outputDimension, double initParamsStdDev, Random rng) {
        this.inputDimension = inputDimension;
        this.outputDimension = outputDimension;
        this.Wix = Matrix.rand(outputDimension, inputDimension, initParamsStdDev, rng);
        this.Wih = Matrix.rand(outputDimension, outputDimension, initParamsStdDev, rng);
        this.bi = new Matrix(outputDimension);
        this.Wfx = Matrix.rand(outputDimension, inputDimension, initParamsStdDev, rng);
        this.Wfh = Matrix.rand(outputDimension, outputDimension, initParamsStdDev, rng);
        // Forget-gate bias starts at 1 (the other biases use new Matrix(n)).
        this.bf = Matrix.ones(outputDimension, 1);
        this.Wox = Matrix.rand(outputDimension, inputDimension, initParamsStdDev, rng);
        this.Woh = Matrix.rand(outputDimension, outputDimension, initParamsStdDev, rng);
        this.bo = new Matrix(outputDimension);
        this.Wcx = Matrix.rand(outputDimension, inputDimension, initParamsStdDev, rng);
        this.Wch = Matrix.rand(outputDimension, outputDimension, initParamsStdDev, rng);
        this.bc = new Matrix(outputDimension);
        // Initialize the recurrent state here (same as resetState(), but without calling an
        // overridable method from the constructor). Otherwise forward()/saveState()/loadState()
        // would dereference null until resetState() is called.
        this.hiddenContext = new Matrix(outputDimension);
        this.cellContext = new Matrix(outputDimension);
    }

    /**
     * Runs one LSTM time step and updates the stored hidden and cell state.
     *
     * @param input the input for this time step
     * @param g     the autodiff graph that records the operations
     * @return the new hidden state (also stored as {@code hiddenContext})
     */
    @Override
    public Matrix forward(Matrix input, Graph g) throws Exception {
        Matrix inputGate = gate(g, this.fInputGate, this.Wix, this.Wih, this.bi, input);
        Matrix forgetGate = gate(g, this.fForgetGate, this.Wfx, this.Wfh, this.bf, input);
        Matrix outputGate = gate(g, this.fOutputGate, this.Wox, this.Woh, this.bo, input);
        Matrix cellInput = gate(g, this.fCellInput, this.Wcx, this.Wch, this.bc, input);
        Matrix retainCell = g.elmul(forgetGate, this.cellContext);
        Matrix writeCell = g.elmul(inputGate, cellInput);
        Matrix cellAct = g.add(retainCell, writeCell);
        Matrix output = g.elmul(outputGate, g.nonlin(this.fCellOutput, cellAct));
        this.hiddenContext = output;
        this.cellContext = cellAct;
        return output;
    }

    /**
     * Computes {@code f(Wx*input + Wh*hiddenContext + b)}, recording the graph operations in
     * the same order as the original inline code.
     */
    private Matrix gate(Graph g, Nonlinearity f, Matrix Wx, Matrix Wh, Matrix b, Matrix input)
            throws Exception {
        Matrix fromInput = g.mul(Wx, input);
        Matrix fromHidden = g.mul(Wh, this.hiddenContext);
        return g.nonlin(f, g.add(g.add(fromInput, fromHidden), b));
    }

    /** Replaces the hidden and cell state with fresh {@code new Matrix(outputDimension)} values. */
    @Override
    public void resetState() {
        this.hiddenContext = new Matrix(this.outputDimension);
        this.cellContext = new Matrix(this.outputDimension);
    }

    /**
     * Returns the trainable parameters in a fixed order (input, forget, output, cell-input
     * gates; each as input weights, recurrent weights, bias). {@link #saveState} and
     * {@link #loadState} rely on this order.
     */
    @Override
    public List<Matrix> getParameters() {
        List<Matrix> result = new ArrayList<Matrix>();
        result.add(this.Wix);
        result.add(this.Wih);
        result.add(this.bi);
        result.add(this.Wfx);
        result.add(this.Wfh);
        result.add(this.bf);
        result.add(this.Wox);
        result.add(this.Woh);
        result.add(this.bo);
        result.add(this.Wcx);
        result.add(this.Wch);
        result.add(this.bc);
        return result;
    }

    /**
     * Writes the dimensions, every parameter (in {@link #getParameters} order), then the hidden
     * and cell state.
     */
    @Override
    public void saveState(FileOutputStream fos) throws Exception {
        fos.write(ByteConverters.intToBytes(this.inputDimension));
        fos.write(ByteConverters.intToBytes(this.outputDimension));
        for (Matrix param : this.getParameters()) {
            param.save(fos);
        }
        this.hiddenContext.save(fos);
        this.cellContext.save(fos);
    }

    /**
     * Reads data written by {@link #saveState}, in the same order.
     *
     * @throws EOFException if the stream ends before a 4-byte dimension header is complete
     */
    @Override
    public void loadState(FileInputStream fis) throws Exception {
        this.inputDimension = readInt(fis);
        this.outputDimension = readInt(fis);
        for (Matrix param : this.getParameters()) {
            param.load(fis);
        }
        // Restore the recurrent state in the order saveState() wrote it: hidden, then cell.
        this.hiddenContext.load(fis);
        this.cellContext.load(fis);
    }

    /**
     * Reads exactly four bytes and decodes them with {@link ByteConverters#bytesToInt}.
     * A single {@code read} call may return fewer bytes than requested, so this loops.
     */
    private static int readInt(FileInputStream in) throws IOException {
        byte[] buffer = new byte[4];
        int off = 0;
        while (off < buffer.length) {
            int n = in.read(buffer, off, buffer.length - off);
            if (n < 0) {
                throw new EOFException("Unexpected end of stream while reading LstmLayer header");
            }
            off += n;
        }
        return ByteConverters.bytesToInt(buffer);
    }
}

