/*
 * Decompiled with CFR 0.152.
 */
package model;

import autodiff.Graph;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.util.ArrayList;
import java.util.List;
import matrix.Matrix;
import matrix.Tensor;
import model.NeuralNetwork;
import model.TensorLayer;

/**
 * A composite {@link TensorLayer} that applies a list of tensor layers in order and,
 * optionally, a trailing {@link NeuralNetwork} to the single remaining matrix.
 *
 * <p>The fully connected network is optional: it is {@code null} until
 * {@link #setFullyConnected(NeuralNetwork)} is called, and every method here tolerates
 * its absence.
 *
 * <p>Only {@link #getLayers()} and {@link #getFullyConnected()} are {@code synchronized};
 * the mutators and {@link #forward(Tensor, Graph)} are not, so this class does not by
 * itself guarantee safe concurrent modification.
 */
public class ConvNet
implements TensorLayer {
    /** Layers applied in iteration order by {@link #forward(Tensor, Graph)}. */
    private List<TensorLayer> layers = new ArrayList<TensorLayer>();
    /** Optional network applied after the tensor layers; {@code null} when not configured. */
    private NeuralNetwork fullyConnected = null;
    /**
     * Integer counter persisted by {@link #saveState(DataOutputStream)} and restored by
     * {@link #loadState(DataInputStream)}. This class never modifies it otherwise.
     * NOTE(review): presumably a training step counter maintained by callers; confirm.
     */
    public int t = 0;

    /** Creates an empty network with no layers and no fully connected stage. */
    public ConvNet() {
    }

    /**
     * Creates a network backed by the given layer list.
     *
     * @param layers the layers to use; the list is stored by reference (not copied), so
     *               later changes to it are visible to this network
     */
    public ConvNet(List<TensorLayer> layers) {
        this.layers = layers;
    }

    /**
     * Runs the input through every layer in order and, if a fully connected network is
     * configured and the resulting tensor has depth 1, through that network as well.
     *
     * @param input the tensor fed to the first layer
     * @param g     the graph passed through to every layer and to the fully connected network
     * @return the fully connected output wrapped in a depth-1 tensor, or the output of the
     *         last tensor layer when no fully connected network is set or the depth is not 1
     * @throws Exception if any layer or the fully connected network fails
     */
    @Override
    public Tensor forward(Tensor input, Graph g) throws Exception {
        Tensor current = input;
        for (TensorLayer layer : this.layers) {
            current = layer.forward(current, g);
        }
        // The fully connected stage consumes one matrix, so it only runs on depth-1 tensors.
        if (this.fullyConnected == null || current.getDepth() != 1) {
            return current;
        }
        Matrix out = this.fullyConnected.forward(current.getMatrixAt(0), g);
        Tensor wrapped = new Tensor(out.cols, out.rows, 1);
        wrapped.setMatrixAt(0, out);
        return wrapped;
    }

    /**
     * Appends a layer to the end of the layer list.
     *
     * @param layer the layer to append
     */
    public void addLayer(TensorLayer layer) {
        this.layers.add(layer);
    }

    /**
     * Returns the layer at the given position.
     *
     * @param indx zero-based index into the layer list
     * @return the layer stored at {@code indx}
     */
    public TensorLayer getLayer(int indx) {
        return this.layers.get(indx);
    }

    /**
     * @return the number of tensor layers (the fully connected network is not counted)
     */
    public int getLayerCount() {
        return this.layers.size();
    }

    /**
     * @return the internal layer list itself (not a copy)
     */
    public synchronized List<TensorLayer> getLayers() {
        return this.layers;
    }

    /**
     * Sets or clears the optional fully connected network.
     *
     * @param nn the network to apply after the tensor layers, or {@code null} for none
     */
    public void setFullyConnected(NeuralNetwork nn) {
        this.fullyConnected = nn;
    }

    /**
     * @return the fully connected network, or {@code null} if none is configured
     */
    public synchronized NeuralNetwork getFullyConnected() {
        return this.fullyConnected;
    }

    /** Resets the state of every layer and of the fully connected network, if present. */
    @Override
    public void resetState() {
        for (TensorLayer layer : this.layers) {
            layer.resetState();
        }
        if (this.fullyConnected != null) {
            this.fullyConnected.resetState();
        }
    }

    /**
     * Collects the parameters of all layers, followed by those of the fully connected
     * network when one is configured.
     *
     * @return a new list of parameter matrices; layers that report {@code null} contribute nothing
     */
    @Override
    public List<Matrix> getParameters() {
        List<Matrix> collected = new ArrayList<Matrix>();
        for (TensorLayer layer : this.layers) {
            List<Matrix> params = layer.getParameters();
            if (params != null) {
                collected.addAll(params);
            }
        }
        if (this.fullyConnected != null) {
            collected.addAll(this.fullyConnected.getParameters());
        }
        return collected;
    }

    /**
     * Creates a copy of this network with each layer cloned individually and {@code t}
     * copied over.
     *
     * @return a new {@code ConvNet}; its fully connected network is a clone of this one's,
     *         or {@code null} when this network has none
     */
    @Override
    public TensorLayer clone() {
        ConvNet copy = new ConvNet();
        copy.t = this.t;
        // The fully connected stage is optional; only clone it when it exists.
        if (this.fullyConnected != null) {
            copy.fullyConnected = (NeuralNetwork)this.fullyConnected.clone();
        }
        for (TensorLayer layer : this.layers) {
            copy.layers.add(layer.clone());
        }
        return copy;
    }

    /**
     * Writes {@code t}, then each layer's state in order, then an int flag (1 or 0)
     * saying whether a fully connected network follows, then that network's state if present.
     *
     * <p>The number of layers is not written, so the reader must already hold the same
     * layer structure.
     *
     * @param fos the stream to write to
     * @throws Exception if writing fails or a layer fails to save
     */
    @Override
    public void saveState(DataOutputStream fos) throws Exception {
        fos.writeInt(this.t);
        for (TensorLayer layer : this.layers) {
            layer.saveState(fos);
        }
        boolean hasFullyConnected = this.fullyConnected != null;
        fos.writeInt(hasFullyConnected ? 1 : 0);
        if (hasFullyConnected) {
            this.fullyConnected.saveState(fos);
        }
    }

    /**
     * Restores state in the layout written by {@link #saveState(DataOutputStream)}:
     * {@code t}, each existing layer's state in order, the presence flag, and the fully
     * connected network's state when the flag is 1.
     *
     * @param fis the stream to read from
     * @throws IllegalStateException if the stream contains fully connected state but this
     *                               instance has no fully connected network to load it into
     * @throws Exception             if reading fails or a layer fails to load
     */
    @Override
    public void loadState(DataInputStream fis) throws Exception {
        this.t = fis.readInt();
        for (TensorLayer layer : this.layers) {
            layer.loadState(fis);
        }
        if (fis.readInt() == 1) {
            if (this.fullyConnected == null) {
                throw new IllegalStateException(
                        "Saved state contains a fully connected network, but none is configured on this ConvNet");
            }
            this.fullyConnected.loadState(fis);
        }
    }
}

