/*
 * Decompiled with CFR 0.152.
 */
package model;

import autodiff.Graph;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import matrix.Matrix;
import matrix.Tensor;
import model.ConvDense;
import model.DenseLayer;
import model.TensorLayer;

/**
 * An ordered stack of {@link TensorLayer}s that is itself a {@link TensorLayer}.
 *
 * <p>Every operation (forward pass, state reset, gradient reset, parameter
 * collection, cloning, save/load) is delegated to the contained layers in the
 * order in which they were added.
 *
 * <p>Only {@link #getLayers()} is declared {@code synchronized}; no other method
 * of this class performs any locking.
 */
public class NeuralNetwork
implements TensorLayer {
    /** Contained layers, in the order they are applied by {@link #forward(Tensor, Graph)}. */
    private final List<TensorLayer> layers = new ArrayList<TensorLayer>();

    /**
     * Integer counter that is copied by {@link #clone()} and persisted by
     * {@link #saveState(DataOutputStream)} / {@link #loadState(DataInputStream)}.
     * This class never modifies it otherwise.
     * NOTE(review): presumably an optimizer/training step counter - confirm against callers.
     */
    public int t = 0;

    /**
     * Runs the input through every layer in insertion order, feeding each
     * layer's output to the next one.
     *
     * @param input the tensor handed to the first layer
     * @param g     the graph object passed unchanged to every layer
     * @return the output of the last layer, or {@code input} itself when the
     *         network has no layers
     * @throws Exception whatever a contained layer throws
     */
    @Override
    public Tensor forward(Tensor input, Graph g) throws Exception {
        Tensor current = input;
        for (TensorLayer layer : this.layers) {
            current = layer.forward(current, g);
        }
        return current;
    }

    /**
     * Convenience overload that wraps {@code input} in a new {@link Tensor}
     * and delegates to {@link #forward(Tensor, Graph)}.
     *
     * @param input the matrix to wrap
     * @param g     the graph object passed unchanged to every layer
     * @return the output of the last layer
     * @throws Exception whatever a contained layer throws
     */
    public Tensor forward(Matrix input, Graph g) throws Exception {
        return this.forward(new Tensor(input), g);
    }

    /**
     * Appends a layer to the end of the network.
     *
     * @param layer the layer to append; must not be {@code null}
     * @throws NullPointerException if {@code layer} is {@code null}
     */
    public void addLayer(TensorLayer layer) {
        // Fail fast: a null entry would make every later traversal of the
        // layer list throw far away from the call that introduced it.
        if (layer == null) {
            throw new NullPointerException("layer must not be null");
        }
        this.layers.add(layer);
    }

    /**
     * Appends a dense layer to the end of the network, wrapped in a
     * {@link ConvDense} so that it can be stored as a {@link TensorLayer}.
     *
     * @param layer the dense layer to wrap and append; must not be {@code null}
     * @throws NullPointerException if {@code layer} is {@code null}
     */
    public void addLayer(DenseLayer layer) {
        // Same fail-fast rationale as addLayer(TensorLayer); also keeps a
        // null from being handed to the ConvDense constructor.
        if (layer == null) {
            throw new NullPointerException("layer must not be null");
        }
        this.layers.add(new ConvDense(layer));
    }

    /**
     * Returns the layer at the given position.
     *
     * @param indx zero-based position in insertion order
     * @return the layer stored at {@code indx}
     * @throws IndexOutOfBoundsException if {@code indx} is out of range
     */
    public TensorLayer getLayer(int indx) {
        return this.layers.get(indx);
    }

    /**
     * @return the number of layers currently in the network
     */
    public int getLayerCount() {
        return this.layers.size();
    }

    /**
     * Returns the internal layer list itself, not a copy: changes made through
     * the returned list are visible to this network.
     *
     * @return the live list of layers
     */
    public synchronized List<TensorLayer> getLayers() {
        return this.layers;
    }

    /**
     * Calls {@link TensorLayer#resetState()} on every layer, in insertion order.
     */
    @Override
    public void resetState() {
        for (TensorLayer layer : this.layers) {
            layer.resetState();
        }
    }

    /**
     * Overwrites the {@code dw} array of every parameter matrix of every layer
     * with zeros. Layers whose {@link TensorLayer#getParameters()} returns
     * {@code null} are skipped.
     */
    public void resetGradients() {
        for (TensorLayer layer : this.layers) {
            List<Matrix> params = layer.getParameters();
            if (params == null) {
                continue;
            }
            for (Matrix param : params) {
                Arrays.fill(param.dw, 0.0);
            }
        }
    }

    /**
     * Collects the parameters of all layers into one new list, in layer order.
     * Layers whose {@link TensorLayer#getParameters()} returns {@code null}
     * contribute nothing.
     *
     * @return a newly allocated list (never {@code null}) holding the same
     *         {@link Matrix} instances the layers returned
     */
    @Override
    public List<Matrix> getParameters() {
        ArrayList<Matrix> collected = new ArrayList<Matrix>();
        for (TensorLayer layer : this.layers) {
            List<Matrix> params = layer.getParameters();
            if (params == null) {
                continue;
            }
            collected.addAll(params);
        }
        return collected;
    }

    /**
     * Creates a new network with the same {@link #t} value whose layers are
     * the results of calling {@code clone()} on each of this network's layers,
     * in the same order.
     *
     * @return the new {@code NeuralNetwork}
     */
    @Override
    public TensorLayer clone() {
        NeuralNetwork copy = new NeuralNetwork();
        copy.t = this.t;
        for (TensorLayer layer : this.layers) {
            copy.layers.add(layer.clone());
        }
        return copy;
    }

    /**
     * Writes {@link #t} as a single int, then lets every layer write its own
     * state in insertion order. The number of layers is not written.
     *
     * @param fos the stream to write to
     * @throws Exception if writing fails or a layer throws
     */
    @Override
    public void saveState(DataOutputStream fos) throws Exception {
        fos.writeInt(this.t);
        for (TensorLayer layer : this.layers) {
            layer.saveState(fos);
        }
    }

    /**
     * Reads {@link #t} as a single int, then lets every layer currently in the
     * network read its own state in insertion order. No layers are created
     * here, so the network must already contain its layers before this is
     * called.
     *
     * @param fis the stream to read from
     * @throws Exception if reading fails or a layer throws
     */
    @Override
    public void loadState(DataInputStream fis) throws Exception {
        this.t = fis.readInt();
        for (TensorLayer layer : this.layers) {
            layer.loadState(fis);
        }
    }
}

