/*
 * Decompiled with CFR 0.152.
 */
package examples.deepRL;

import autodiff.Graph;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.util.ArrayList;
import java.util.List;
import matrix.Matrix;
import matrix.Tensor;
import model.NeuralNetwork;
import model.TensorLayer;

/**
 * Combines a shared trunk network with an advantage head and a value head.
 *
 * <p>The input is passed through {@code net1}. Its output feeds both {@code a_net}
 * (advantages {@code A}) and {@code v_net} (scalar value {@code V}). The layer emits
 * {@code Q_i = V + (A_i - mean(A))} for every advantage entry {@code i}.
 */
public class AggregationLayer
implements TensorLayer {
    private NeuralNetwork net1;
    private NeuralNetwork a_net;
    private NeuralNetwork v_net;

    /**
     * @param net1  trunk network applied to the layer input; its output feeds both heads
     * @param a_net advantage head; its first output matrix must have exactly one column
     * @param v_net value head; its first output matrix must contain exactly one element
     */
    public AggregationLayer(NeuralNetwork net1, NeuralNetwork a_net, NeuralNetwork v_net) {
        this.net1 = net1;
        this.a_net = a_net;
        this.v_net = v_net;
    }

    /**
     * Computes {@code Q_i = V + (A_i - mean(A))}. When the graph is applying backprop,
     * it also registers the matching gradient step.
     *
     * @param input layer input, forwarded to {@code net1}
     * @param g     autodiff graph used for all three sub-networks
     * @return a new tensor whose first matrix holds one Q value per advantage entry
     * @throws Exception if the advantage output is not a single column, if the value
     *                   output is not a single element, or if a sub-network throws
     */
    @Override
    public Tensor forward(Tensor input, Graph g) throws Exception {
        Tensor trunkOut = this.net1.forward(input, g);
        final Matrix outA = this.a_net.forward(trunkOut, g).matrices[0];
        if (outA.cols != 1) {
            throw new Exception("Invalid advantage dim");
        }
        final Matrix outV = this.v_net.forward(trunkOut, g).matrices[0];
        if (outV.w.length != 1) {
            throw new Exception("Invalid value dim");
        }

        final int n = outA.w.length;
        double sumA = 0.0;
        for (int i = 0; i < n; i++) {
            sumA += outA.w[i];
        }
        final double avgA = sumA / n;

        // NOTE(review): assumes Tensor(1, n, 1) yields a first matrix with n elements;
        // the loop below writes indices 0..n-1 into it.
        final Tensor finalRes = new Tensor(1, n, 1);
        Matrix finalResM = finalRes.matrices[0];
        for (int i = 0; i < n; i++) {
            finalResM.w[i] = outV.w[0] + (outA.w[i] - avgA);
        }

        if (g.isApplyingBackprop()) {
            g.addBackprop(new Runnable() {

                @Override
                public void run() {
                    double[] dQ = finalRes.matrices[0].dw;
                    double sumDQ = 0.0;
                    for (int i = 0; i < n; i++) {
                        sumDQ += dQ[i];
                    }
                    // dQ_i/dV = 1 for every i, so V receives the sum of upstream gradients.
                    outV.dw[0] += sumDQ;
                    // dQ_i/dA_j = delta_ij - 1/n (the mean couples every entry), hence
                    // dL/dA_j = dQ_j - mean(dQ). Accumulate (+=) like outV.dw above.
                    double meanDQ = sumDQ / n;
                    for (int j = 0; j < n; j++) {
                        outA.dw[j] += dQ[j] - meanDQ;
                    }
                }
            });
        }
        return finalRes;
    }

    /** Resets the state of the trunk, advantage head and value head, in that order. */
    @Override
    public void resetState() {
        this.net1.resetState();
        this.a_net.resetState();
        this.v_net.resetState();
    }

    /**
     * Returns a new list with the parameters of the trunk, then the advantage head,
     * then the value head.
     */
    @Override
    public List<Matrix> getParameters() {
        List<Matrix> params = new ArrayList<Matrix>();
        params.addAll(this.net1.getParameters());
        params.addAll(this.a_net.getParameters());
        params.addAll(this.v_net.getParameters());
        return params;
    }

    /** Writes trunk, advantage-head and value-head state, in that order. */
    @Override
    public void saveState(DataOutputStream fos) throws Exception {
        this.net1.saveState(fos);
        this.a_net.saveState(fos);
        this.v_net.saveState(fos);
    }

    /** Reads state in the same order {@link #saveState} writes it. */
    @Override
    public void loadState(DataInputStream fis) throws Exception {
        this.net1.loadState(fis);
        this.a_net.loadState(fis);
        this.v_net.loadState(fis);
    }

    /** Returns a new layer built from clones of the three sub-networks. */
    @Override
    public TensorLayer clone() {
        return new AggregationLayer((NeuralNetwork)this.net1.clone(), (NeuralNetwork)this.a_net.clone(), (NeuralNetwork)this.v_net.clone());
    }
}

