/*
 * Decompiled with CFR 0.152.
 */
package model;

import autodiff.Graph;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.util.ArrayList;
import java.util.List;
import matrix.Matrix;
import model.Model;

/**
 * A sequential container of {@link Model} layers that itself behaves as a {@link Model}.
 *
 * <p>The forward pass feeds the output of each layer into the next, in insertion order.
 * Parameter collection, state reset and (de)serialization are all delegated to the layers
 * in that same order.
 */
public class NeuralNetwork implements Model {
    /** Layers in execution order. May be a caller-supplied list (see {@link #NeuralNetwork(List)}). */
    private final List<Model> layers;

    /**
     * Integer counter copied by {@link #clone()} and the copy constructor, and written or read
     * first by {@link #saveState}/{@link #loadState}. It is not interpreted by this class.
     * NOTE(review): presumably a time/training-step counter maintained by callers — confirm.
     */
    public int t = 0;

    /**
     * Creates a network backed directly by {@code layers}.
     *
     * <p>The list is not copied. Later calls to {@link #addLayer} mutate the caller's list,
     * and external changes to that list are visible to this network. Passing an unmodifiable
     * list makes {@link #addLayer} fail.
     *
     * @param layers the layers to run, in order
     */
    public NeuralNetwork(List<Model> layers) {
        this.layers = layers;
    }

    /** Creates an empty network; add layers with {@link #addLayer}. */
    public NeuralNetwork() {
        this.layers = new ArrayList<>();
    }

    /**
     * Creates a shallow copy of {@code original}.
     *
     * <p>The layer list itself is new, but the layer objects are the same instances as in
     * {@code original}, so both networks share layer state. Use {@link #clone()} to obtain
     * copies produced by each layer's own {@code clone()}.
     *
     * @param original the network to copy
     */
    public NeuralNetwork(NeuralNetwork original) {
        this.layers = new ArrayList<>(original.layers);
        this.t = original.t;
    }

    /**
     * Appends a layer to the end of the network.
     *
     * @param layer the layer to run after all currently registered layers
     */
    public void addLayer(Model layer) {
        layers.add(layer);
    }

    /**
     * Runs {@code input} through every layer in order.
     *
     * @param input the input to the first layer
     * @param g the graph handed unchanged to each layer's {@code forward}
     * @return the last layer's output, or {@code input} itself when there are no layers
     * @throws Exception whatever a layer's {@code forward} throws
     */
    @Override
    public Matrix forward(Matrix input, Graph g) throws Exception {
        Matrix activation = input;
        for (Model stage : layers) {
            activation = stage.forward(activation, g);
        }
        return activation;
    }

    /** Calls {@code resetState()} on every layer, in order. */
    @Override
    public void resetState() {
        for (Model stage : layers) {
            stage.resetState();
        }
    }

    /**
     * Concatenates all layer parameters, in layer order.
     *
     * <p>A layer whose {@code getParameters()} returns {@code null} contributes nothing.
     *
     * @return a new mutable list; never {@code null}
     */
    @Override
    public List<Matrix> getParameters() {
        List<Matrix> collected = new ArrayList<>();
        for (Model stage : layers) {
            List<Matrix> stageParams = stage.getParameters();
            if (stageParams != null) {
                collected.addAll(stageParams);
            }
        }
        return collected;
    }

    /**
     * Builds a new network with the same {@link #t} and a clone of every layer.
     *
     * <p>How deep each copy is depends on the individual layer's {@code clone()}.
     *
     * @return an independent {@code NeuralNetwork} instance
     */
    @Override
    public Model clone() {
        NeuralNetwork copy = new NeuralNetwork();
        copy.t = t;
        for (Model stage : layers) {
            copy.layers.add(stage.clone());
        }
        return copy;
    }

    /**
     * Serializes {@link #t} as an {@code int}, followed by each layer's state in order.
     *
     * @param fos the destination stream
     * @throws Exception on I/O failure or any error raised by a layer
     */
    @Override
    public void saveState(DataOutputStream fos) throws Exception {
        fos.writeInt(t);
        for (Model stage : layers) {
            stage.saveState(fos);
        }
    }

    /**
     * Restores state written by {@link #saveState}.
     *
     * <p>Reads {@link #t} first, then delegates to each layer in order. This only lines up
     * if the network has the same layer sequence as when the state was saved.
     *
     * @param fis the source stream
     * @throws Exception on I/O failure or any error raised by a layer
     */
    @Override
    public void loadState(DataInputStream fis) throws Exception {
        t = fis.readInt();
        for (Model stage : layers) {
            stage.loadState(fis);
        }
    }
}

