/*
 * Decompiled with CFR 0.152.
 */
package model;

import autodiff.Graph;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import matrix.Matrix;
import matrix.Tensor;
import model.TensorLayer;

/**
 * Normalizes each depth slice (channel) of an input {@link Tensor} to zero mean and unit
 * variance over its {@code width * height} elements, then applies a learned affine transform
 * {@code out = scale * xHat + shift}.
 *
 * <p>A single scale/shift pair is shared by all channels: {@code params.w[0]} is the scale and
 * {@code params.w[1]} is the shift. Statistics are computed from the current input on every
 * call; no running averages are stored, so {@link #resetState()} has nothing to reset.
 */
public class NormalizeLayer implements TensorLayer {
    /** Learned parameters: {@code w[0]} = scale, {@code w[1]} = shift. */
    private final Matrix params;
    /** Added to each channel's variance before the square root, guarding against division by zero. */
    private double epsilon;

    /**
     * Creates a layer whose scale and shift are initialised by {@code Matrix.rand}.
     *
     * @param epsilon value added to each channel's variance before taking the square root
     * @param initParamsStdDev forwarded to {@code Matrix.rand} (presumably the standard deviation
     *        of the initial values — confirm against {@code Matrix.rand})
     * @param rng random source used for the parameter initialisation
     */
    public NormalizeLayer(double epsilon, double initParamsStdDev, Random rng) {
        this(epsilon, Matrix.rand(2, 1, initParamsStdDev, rng));
    }

    /**
     * Creates a layer that uses {@code params} as-is (no copy). Used by {@link #clone()} so that
     * cloning does not perform a throw-away random initialisation.
     */
    private NormalizeLayer(double epsilon, Matrix params) {
        this.epsilon = epsilon;
        this.params = params;
    }

    /**
     * Normalizes every channel of {@code input} and applies the learned scale and shift.
     * When {@code g} is applying backprop, a callback is registered that accumulates gradients
     * into {@code input}'s {@code dw} arrays and into this layer's parameter gradients.
     *
     * @param input tensor to normalize; each channel's {@code w} array is assumed to hold
     *        exactly {@code width * height} values
     * @param g graph that collects backprop callbacks
     * @return a new tensor with the same dimensions as {@code input}
     * @throws Exception as declared by {@link TensorLayer}
     */
    @Override
    public Tensor forward(final Tensor input, Graph g) throws Exception {
        final Tensor output = new Tensor(input.width, input.height, input.depth);
        // Normalized values before the affine transform. Backprop needs them for the scale
        // gradient and reuses their dw arrays as per-channel scratch space.
        final Tensor xHats = new Tensor(input.width, input.height, input.depth);
        final int inSize = input.width * input.height;
        final Matrix means = new Matrix(input.depth);
        // sqrt(variance + epsilon) per channel; needed again during backprop.
        final Matrix stdDevs = new Matrix(input.depth);
        final double scale = this.params.w[0];
        final double shift = this.params.w[1];

        for (int c = 0; c < input.depth; c++) {
            double[] x = input.matrices[c].w;
            double[] xHat = xHats.matrices[c].w;
            double[] out = output.matrices[c].w;

            double mean = 0.0;
            for (int j = 0; j < inSize; j++) {
                mean += x[j];
            }
            mean /= inSize;

            // Two-pass (population) variance: mean of squared deviations.
            double variance = 0.0;
            for (int j = 0; j < inSize; j++) {
                double diff = x[j] - mean;
                variance += diff * diff;
            }
            variance /= inSize;

            double std = Math.sqrt(variance + this.epsilon);
            means.w[c] = mean;
            stdDevs.w[c] = std;

            for (int j = 0; j < inSize; j++) {
                xHat[j] = (x[j] - mean) / std;
                out[j] = scale * xHat[j] + shift;
            }
        }

        if (g.isApplyingBackprop()) {
            g.addBackprop(new Runnable() {
                @Override
                public void run() {
                    NormalizeLayer.this.backpropChannels(input, output, xHats, means, stdDevs, inSize);
                }
            });
        }
        return output;
    }

    /**
     * Backpropagates {@code output.dw} through the forward computation of every channel:
     * {@code out = scale * xHat + shift}, {@code xHat = (x - mean) / std},
     * {@code std = sqrt(var + epsilon)}, {@code var = mean((x - mean)^2)}, {@code mean = mean(x)}.
     * Adds into {@code params.dw} and {@code input}'s {@code dw} arrays, and overwrites the
     * {@code dw} arrays of {@code xHats}, which serve as scratch space.
     */
    private void backpropChannels(Tensor input, Tensor output, Tensor xHats,
                                  Matrix means, Matrix stdDevs, int inSize) {
        final double scale = this.params.w[0];
        for (int c = 0; c < input.depth; c++) {
            double[] x = input.matrices[c].w;
            double[] dx = input.matrices[c].dw;
            double[] dOut = output.matrices[c].dw;
            double[] xHat = xHats.matrices[c].w;
            double[] scratch = xHats.matrices[c].dw;
            double mean = means.w[c];
            double std = stdDevs.w[c];
            double invStd = 1.0 / std;

            // Pass 1: parameter gradients, the direct gradient w.r.t. (x - mean), and the
            // gradient w.r.t. the variance.
            double varianceGrad = 0.0;
            for (int j = 0; j < inSize; j++) {
                this.params.dw[1] += dOut[j];
                this.params.dw[0] += dOut[j] * xHat[j];
                double dXHat = dOut[j] * scale;
                varianceGrad += dXHat * (x[j] - mean);
                scratch[j] = dXHat * invStd;
            }
            varianceGrad *= -1.0 / (std * std);         // through 1 / std
            varianceGrad = 0.5 * invStd * varianceGrad;  // through std = sqrt(var + epsilon)
            // var averages inSize squared deviations, so each one receives 1/inSize of the gradient.
            double perElementVarianceGrad = varianceGrad / inSize;

            // Pass 2: total gradient w.r.t. (x - mean), kept in scratch, and its sum, which
            // flows (negated) into the mean.
            double meanGrad = 0.0;
            for (int j = 0; j < inSize; j++) {
                double dCentered = perElementVarianceGrad * (2.0 * (x[j] - mean));
                double total = dCentered + scratch[j];
                meanGrad += total;
                scratch[j] = total;
            }
            // mean averages inSize inputs, so each input receives 1/inSize of the mean gradient.
            meanGrad = -meanGrad / inSize;

            // Pass 3: accumulate into the input gradient.
            for (int j = 0; j < inSize; j++) {
                dx[j] += meanGrad + scratch[j];
            }
        }
    }

    /** No recurrent or running state is kept, so there is nothing to reset. */
    @Override
    public void resetState() {
    }

    /**
     * Returns the trainable parameters.
     *
     * @return a new list containing the single 2x1 parameter matrix (scale, shift)
     */
    @Override
    public List<Matrix> getParameters() {
        List<Matrix> parameters = new ArrayList<Matrix>();
        parameters.add(this.params);
        return parameters;
    }

    /**
     * Writes the parameter matrix followed by {@code epsilon}; {@link #loadState} reads the
     * same order.
     */
    @Override
    public void saveState(DataOutputStream fos) throws Exception {
        this.params.save(fos);
        fos.writeDouble(this.epsilon);
    }

    /**
     * Reads the parameter matrix (in place) followed by {@code epsilon}, mirroring
     * {@link #saveState}.
     */
    @Override
    public void loadState(DataInputStream fis) throws Exception {
        this.params.load(fis);
        this.epsilon = fis.readDouble();
    }

    /**
     * Returns an independent copy of this layer with the same {@code epsilon} and a cloned
     * parameter matrix.
     */
    @Override
    public TensorLayer clone() {
        return new NormalizeLayer(this.epsilon, this.params.clone());
    }
}

