/*
 * Decompiled with CFR 0.152.
 */
package model;

import autodiff.Graph;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import matrix.Matrix;
import model.DenseLayer;

/**
 * A linear layer without bias: {@link #forward} returns {@code g.mul(W, input)} and applies
 * no bias term and no activation. Its only trainable parameter is the weight matrix {@code W}.
 *
 * <p>The layer keeps no state between calls, so {@link #resetState()} does nothing.
 */
public class LinearLayer
implements DenseLayer {
    /**
     * Weight matrix, created as {@code Matrix.rand(outputDimension, inputDimension, ...)}.
     * NOTE(review): presumably rows = outputDimension and cols = inputDimension — confirm
     * against {@code Matrix.rand}. Left package-private because other classes in
     * {@code model} may access it directly.
     */
    Matrix W;

    /**
     * Creates a layer with randomly initialized weights.
     *
     * @param inputDimension   second dimension argument passed to {@code Matrix.rand}
     * @param outputDimension  first dimension argument passed to {@code Matrix.rand}
     * @param initParamsStdDev spread parameter forwarded to {@code Matrix.rand}
     *                         (presumably a standard deviation, per its name — confirm)
     * @param rng              random source used for initialization
     */
    public LinearLayer(int inputDimension, int outputDimension, double initParamsStdDev, Random rng) {
        this.W = Matrix.rand(outputDimension, inputDimension, initParamsStdDev, rng);
    }

    /**
     * Wraps an existing weight matrix without copying it. Used by {@link #clone()} so that
     * cloning does not build and discard a randomly initialized matrix.
     *
     * @param W weight matrix to adopt as-is
     */
    private LinearLayer(Matrix W) {
        this.W = W;
    }

    /**
     * Applies the layer by delegating to {@code g.mul(W, input)}.
     *
     * @param input layer input (shape assumed compatible with {@code W} — TODO confirm)
     * @param g     graph that performs the multiplication (and presumably records it for
     *              backpropagation — confirm against {@code Graph})
     * @return the matrix returned by {@code g.mul}
     * @throws Exception whatever {@code Graph.mul} throws
     */
    @Override
    public Matrix forward(Matrix input, Graph g) throws Exception {
        return g.mul(this.W, input);
    }

    /** Does nothing: this layer carries no state between forward passes. */
    @Override
    public void resetState() {
    }

    /**
     * Returns this layer's trainable parameters.
     *
     * @return a new mutable list holding the live weight matrix {@code W} (not a copy)
     */
    @Override
    public List<Matrix> getParameters() {
        List<Matrix> result = new ArrayList<Matrix>(1);
        result.add(this.W);
        return result;
    }

    /**
     * Returns a new {@code LinearLayer} whose weights come from {@code W.clone()}.
     * The original seeded a fresh {@code Random} and generated a discarded 1x1 matrix first;
     * that wasted work is avoided here.
     *
     * @return an independent layer holding the cloned weights
     */
    @Override
    public DenseLayer clone() {
        return new LinearLayer(this.W.clone());
    }

    /**
     * Writes the weight matrix to {@code fos} via {@code Matrix.save}.
     *
     * @param fos destination stream
     * @throws Exception whatever {@code Matrix.save} throws
     */
    @Override
    public void saveState(DataOutputStream fos) throws Exception {
        this.W.save(fos);
    }

    /**
     * Reads weights from {@code fis} into the existing matrix via {@code Matrix.load}.
     * NOTE(review): presumably the stored dimensions must match this layer's {@code W} —
     * confirm against {@code Matrix.load}.
     *
     * @param fis source stream
     * @throws Exception whatever {@code Matrix.load} throws
     */
    @Override
    public void loadState(DataInputStream fis) throws Exception {
        this.W.load(fis);
    }
}

