/*
 * Decompiled with CFR 0.152.
 */
package model;

import autodiff.Graph;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.util.List;
import java.util.Random;
import matrix.Matrix;
import matrix.Tensor;
import model.ConvLayer;
import model.TensorLayer;

/**
 * Wraps a {@link ConvLayer} and rescales each of its filters before the wrapped
 * forward pass runs (spectral normalization by one power-iteration step per
 * filter).
 *
 * <p>Two snapshots are kept per filter: {@code clonedFilters} holds the filter
 * as it was before rescaling, {@code iteredFilters} holds the rescaled filter.
 * The rescaled filters are installed for the duration of the wrapped forward
 * pass (and, when back-propagating, until the registered backprop callback
 * runs); afterwards the un-rescaled snapshot is put back into the layer.
 */
public class ConvSpectralNorm
implements TensorLayer {
    /** Wrapped convolution layer whose filters are rescaled. */
    private final ConvLayer conv;
    /** One persistent power-iteration row vector per filter. */
    private Matrix[] u;
    /**
     * Sentinel parameter exposed through {@link #getParameters()}. A non-zero
     * {@code w[0]} is what triggers a new power-iteration step in
     * {@link #forward}. Presumably the optimizer moves {@code w[0]} away from
     * zero on each update because {@code dw[0]} is pinned to 1 - confirm
     * against the trainer.
     */
    private final Matrix test;
    /** Added to the L2 norm in {@link #l2Normalize} to avoid dividing by zero. */
    private double epsilon = 1.0E-12;
    /** Private graph used only for the power-iteration arithmetic. */
    private final Graph scratchGraph = new Graph(false);
    /** Per-filter snapshot taken before rescaling. */
    private Matrix[] clonedFilters;
    /** Per-filter snapshot taken after rescaling; {@code null} until the first forward pass. */
    private Matrix[] iteredFilters;

    /**
     * Creates the wrapper and draws one random {@code 1 x filterWidth} vector per filter.
     *
     * <p>A non-square filter only produces a warning on {@code System.err};
     * construction continues regardless.
     *
     * @param conv layer whose filters will be rescaled
     * @param rng  random source handed to {@code Matrix.rand}
     */
    public ConvSpectralNorm(ConvLayer conv, Random rng) {
        this.conv = conv;
        this.test = new Matrix(1);
        this.test.dw[0] = 1.0;
        int filterWidth = conv.getFilterWidth();
        int filterHeight = conv.getFilterHeight();
        if (filterWidth != filterHeight) {
            System.err.println("Filters must be square");
        }
        int filterCount = conv.getFilterCount();
        this.u = new Matrix[filterCount];
        for (int i = 0; i < filterCount; ++i) {
            this.u[i] = Matrix.rand(1, filterWidth, 1.0, rng);
        }
        this.clonedFilters = new Matrix[filterCount];
        this.iteredFilters = new Matrix[filterCount];
    }

    /**
     * Runs the wrapped layer with rescaled filters.
     *
     * <p>A fresh power-iteration step is taken when no rescaled snapshot exists
     * yet, or when the sentinel weight is non-zero while {@code g} is applying
     * backprop. Otherwise the cached rescaled filters are reused.
     *
     * @param input tensor passed to the wrapped layer
     * @param g     graph the wrapped forward pass is recorded on
     * @return the wrapped layer's output
     * @throws Exception propagated from the wrapped layer or the matrix arithmetic
     */
    @Override
    public Tensor forward(Tensor input, Graph g) throws Exception {
        int filterCount = this.conv.getFilterCount();
        boolean refresh = (this.test.w[0] != 0.0 && g.isApplyingBackprop())
                || this.iteredFilters[0] == null;
        if (refresh) {
            // Re-arm the sentinel so the next change to w[0] is detectable.
            this.test.w[0] = 0.0;
            this.test.dw[0] = 1.0;
            this.test.stepCache[0] = 0.0;
            // NOTE(review): index 1 assumes stepCache of a Matrix(1) has at least two entries - confirm against Matrix.
            this.test.stepCache[1] = 0.0;
            for (int i = 0; i < filterCount; ++i) {
                this.clonedFilters[i] = this.conv.getFilter(i).clone();
                this.powerIter(i);
                this.iteredFilters[i] = this.conv.getFilter(i).clone();
            }
        } else {
            for (int i = 0; i < filterCount; ++i) {
                this.conv.setFilter(i, this.iteredFilters[i]);
            }
        }
        if (g.isApplyingBackprop()) {
            // Registered before the wrapped layer's forward call; restores the
            // un-rescaled filters when the graph executes this callback.
            g.addBackprop(new Runnable(){

                @Override
                public void run() {
                    ConvSpectralNorm.this.restoreOriginalFilters();
                }
            });
        }
        Tensor output = this.conv.forward(input, g);
        if (!g.isApplyingBackprop()) {
            // No backprop callback will run, so restore immediately.
            this.restoreOriginalFilters();
        }
        return output;
    }

    /** Puts the pre-rescaling snapshot of every filter back into the wrapped layer. */
    private void restoreOriginalFilters() {
        int filterCount = this.conv.getFilterCount();
        for (int i = 0; i < filterCount; ++i) {
            this.conv.setFilter(i, this.clonedFilters[i]);
        }
    }

    /**
     * Performs one power-iteration step for filter {@code indx} and divides the
     * filter's weights in place by the resulting scalar.
     *
     * <p>Assuming {@code Graph.mul} is a matrix product and {@code Graph.elmul}
     * an element-wise product: {@code v = normalize(u W)},
     * {@code u' = normalize(v W^T)}, {@code sigma = sum((u' W) * v)},
     * {@code W <- W / sigma}; {@code u'} is then stored for the next step.
     *
     * @param indx index of the filter to rescale
     * @throws Exception propagated from the graph operations
     */
    private void powerIter(int indx) throws Exception {
        Matrix filter = this.conv.getFilter(indx);
        Matrix newV = this.scratchGraph.mul(this.u[indx], filter);
        this.l2Normalize(newV);
        Matrix newU = this.scratchGraph.mul(newV, Matrix.transpose(filter));
        this.l2Normalize(newU);
        Matrix projected = this.scratchGraph.elmul(this.scratchGraph.mul(newU, filter), newV);
        double sigma = 0.0;
        for (double value : projected.w) {
            sigma += value;
        }
        // NOTE(review): sigma is not guarded by epsilon; a filter for which it
        // is zero ends up with non-finite weights.
        for (int i = 0; i < filter.w.length; ++i) {
            filter.w[i] = filter.w[i] / sigma;
        }
        this.u[indx] = newU;
    }

    /**
     * Divides every weight of {@code m} in place by {@code sqrt(sum of squares) + epsilon}.
     *
     * @param m matrix to normalize
     * @throws Exception never thrown by this body; kept so the signature is unchanged
     */
    private void l2Normalize(Matrix m) throws Exception {
        double sumOfSquares = 0.0;
        for (double value : m.w) {
            sumOfSquares += value * value;
        }
        double norm = Math.sqrt(sumOfSquares) + this.epsilon;
        for (int i = 0; i < m.w.length; ++i) {
            m.w[i] = m.w[i] / norm;
        }
    }

    /** Delegates to the wrapped layer. */
    @Override
    public void resetState() {
        this.conv.resetState();
    }

    /**
     * Returns the wrapped layer's parameter list with the sentinel matrix appended.
     *
     * <p>NOTE(review): the sentinel is added to the very list object returned by
     * {@code conv.getParameters()}; if that method hands out an internal list
     * rather than a fresh one, the sentinel accumulates across calls - confirm.
     */
    @Override
    public List<Matrix> getParameters() {
        List<Matrix> params = this.conv.getParameters();
        params.add(this.test);
        return params;
    }

    /**
     * Writes the wrapped layer, then per filter the power-iteration vector and
     * both snapshots, then the sentinel matrix and epsilon.
     *
     * @param fos destination stream
     * @throws IllegalStateException if a filter has no snapshots yet (no forward
     *         pass has run); checked before anything is written so the stream
     *         is not left truncated
     * @throws Exception propagated from the underlying writes
     */
    @Override
    public void saveState(DataOutputStream fos) throws Exception {
        int filterCount = this.conv.getFilterCount();
        for (int i = 0; i < filterCount; ++i) {
            if (this.clonedFilters[i] == null || this.iteredFilters[i] == null) {
                throw new IllegalStateException("ConvSpectralNorm.saveState: filter " + i
                        + " has no cached snapshot; run forward() at least once before saving");
            }
        }
        this.conv.saveState(fos);
        for (int i = 0; i < filterCount; ++i) {
            this.u[i].save(fos);
            this.clonedFilters[i].save(fos);
            this.iteredFilters[i].save(fos);
        }
        this.test.save(fos);
        fos.writeDouble(this.epsilon);
    }

    /**
     * Reads state in the order written by {@link #saveState}.
     *
     * <p>All three per-filter arrays are re-allocated to the filter count of the
     * freshly loaded wrapped layer, so a layer loaded with a different number
     * of filters than it was constructed with does not overrun them.
     *
     * @param fis source stream
     * @throws Exception propagated from the underlying reads
     */
    @Override
    public void loadState(DataInputStream fis) throws Exception {
        this.conv.loadState(fis);
        int filterCount = this.conv.getFilterCount();
        this.u = new Matrix[filterCount];
        this.clonedFilters = new Matrix[filterCount];
        this.iteredFilters = new Matrix[filterCount];
        for (int i = 0; i < filterCount; ++i) {
            this.u[i] = new Matrix(1);
            this.u[i].load(fis);
            this.clonedFilters[i] = new Matrix(1);
            this.clonedFilters[i].load(fis);
            this.iteredFilters[i] = new Matrix(1);
            this.iteredFilters[i].load(fis);
        }
        this.test.load(fis);
        this.epsilon = fis.readDouble();
    }

    /**
     * Returns a new wrapper around a clone of the wrapped layer.
     *
     * <p>The copy gets freshly randomized power-iteration vectors and no cached
     * snapshots, so its first forward pass recomputes them. The current
     * {@code epsilon} is carried over.
     */
    @Override
    public TensorLayer clone() {
        ConvSpectralNorm copy = new ConvSpectralNorm((ConvLayer)this.conv.clone(), new Random());
        copy.epsilon = this.epsilon;
        return copy;
    }
}

