/*
 * Decompiled with CFR 0.152.
 */
package model;

import autodiff.Graph;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import matrix.Matrix;
import model.DenseLayer;

/**
 * A {@link DenseLayer} that runs two sub-layers on the same input and joins their outputs
 * with {@link Graph#concatVectors(Matrix, Matrix)}.
 *
 * <p>Ordering is significant everywhere in this class: {@code n1} is always processed before
 * {@code n2}. This applies to the forward pass, the parameter list and the serialized state. A
 * stream written by {@link #saveState} can therefore only be read back by {@link #loadState}
 * on a layer whose sub-layers have the same structure and order.
 */
public class ConcatLayer
implements DenseLayer {
    /** First sub-layer; its output is passed as the first argument to concatVectors. */
    private final DenseLayer n1;
    /** Second sub-layer; its output is passed as the second argument to concatVectors. */
    private final DenseLayer n2;

    /**
     * Creates a layer that concatenates the outputs of two sub-layers.
     *
     * @param n1 first sub-layer, must not be {@code null}
     * @param n2 second sub-layer, must not be {@code null}
     * @throws NullPointerException if either sub-layer is {@code null}
     */
    public ConcatLayer(DenseLayer n1, DenseLayer n2) {
        // Fail fast here; otherwise a null would only surface later inside forward/resetState/etc.
        this.n1 = Objects.requireNonNull(n1, "n1");
        this.n2 = Objects.requireNonNull(n2, "n2");
    }

    /**
     * Feeds {@code input} to both sub-layers and concatenates their results.
     *
     * <p>NOTE(review): the exact layout of the combined result (which axis is joined, size
     * requirements) is defined by {@code Graph.concatVectors}, not here — confirm there.
     *
     * @param input value passed unchanged to both sub-layers
     * @param g graph used by the sub-layers and for the concatenation
     * @return {@code g.concatVectors(n1Output, n2Output)}
     * @throws Exception whatever the sub-layers or the graph throw
     */
    @Override
    public Matrix forward(Matrix input, Graph g) throws Exception {
        Matrix m1 = this.n1.forward(input, g);
        Matrix m2 = this.n2.forward(input, g);
        return g.concatVectors(m1, m2);
    }

    /** Resets the state of both sub-layers, {@code n1} first. */
    @Override
    public void resetState() {
        this.n1.resetState();
        this.n2.resetState();
    }

    /**
     * Returns a fresh list holding the parameters of {@code n1} followed by those of
     * {@code n2}.
     *
     * @return a new mutable list; modifying it does not affect the sub-layers' own lists
     */
    @Override
    public List<Matrix> getParameters() {
        List<Matrix> first = this.n1.getParameters();
        List<Matrix> second = this.n2.getParameters();
        List<Matrix> params = new ArrayList<>(first.size() + second.size());
        params.addAll(first);
        params.addAll(second);
        return params;
    }

    /**
     * Writes the state of {@code n1} and then {@code n2} to {@code fos}.
     *
     * @param fos destination stream; it is not closed here
     * @throws Exception whatever the sub-layers throw while writing
     */
    @Override
    public void saveState(DataOutputStream fos) throws Exception {
        this.n1.saveState(fos);
        this.n2.saveState(fos);
    }

    /**
     * Reads the state of {@code n1} and then {@code n2} from {@code fis}. This is the same
     * order that {@link #saveState} writes.
     *
     * @param fis source stream; it is not closed here
     * @throws Exception whatever the sub-layers throw while reading
     */
    @Override
    public void loadState(DataInputStream fis) throws Exception {
        this.n1.loadState(fis);
        this.n2.loadState(fis);
    }

    /**
     * Returns a new {@code ConcatLayer} built from clones of both sub-layers.
     *
     * <p>How deep the copy goes depends on each sub-layer's own {@code clone()} implementation.
     *
     * @return a new instance wrapping {@code n1.clone()} and {@code n2.clone()}
     */
    @Override
    public DenseLayer clone() {
        return new ConcatLayer(this.n1.clone(), this.n2.clone());
    }
}

