/*
 * Decompiled with CFR 0.152.
 */
package autodiff;

import java.util.ArrayList;
import java.util.List;
import matrix.Matrix;
import matrix.Tensor;
import nonlinearities.Nonlinearity;

/**
 * Tape-based reverse-mode differentiation over {@link Matrix} values.
 *
 * <p>Each operation computes its forward result immediately. When backpropagation is enabled it
 * also appends a callback to an internal list. {@link #backward()} runs those callbacks in reverse
 * order of registration. Each built-in callback adds the gradient held in the output's {@code dw}
 * into the inputs' {@code dw}.
 *
 * <p>An exception raised inside a built-in callback is stored in {@code backpropException} and
 * rethrown by {@link #backward()}.
 */
public class Graph {
    /** Whether operations record backprop callbacks. */
    boolean applyBackprop;
    /**
     * Pending backprop callbacks in registration order. It is {@code null} when the graph was
     * built with {@code applyBackprop == false} or after {@link #cleanUp()}.
     */
    List<Runnable> backprop;
    /** Set by a failing callback (or {@link #setBackpropException}) and rethrown by {@link #backward()}. */
    Exception backpropException = null;

    /** Returns whether this graph records backprop callbacks. */
    public boolean applyBackprop() {
        return this.applyBackprop;
    }

    /** Creates a graph that records backprop callbacks. */
    public Graph() {
        // Delegate so the callback list is allocated. Setting only the flag left the list null,
        // and the first recorded operation failed with a NullPointerException.
        this(true);
    }

    /**
     * Creates a graph.
     *
     * @param applyBackprop whether operations should record backprop callbacks; when
     *     {@code false} no callback list is allocated
     */
    public Graph(boolean applyBackprop) {
        this.applyBackprop = applyBackprop;
        if (applyBackprop) {
            this.backprop = new ArrayList<Runnable>();
        }
    }

    /**
     * Runs every recorded callback in reverse order of registration, removing each one as it runs.
     *
     * <p>Does nothing when there is no callback list (backprop disabled, or after
     * {@link #cleanUp()}).
     *
     * @throws Exception the first exception reported by a callback; callbacks registered before
     *     the failing one remain in the list
     */
    public void backward() throws Exception {
        this.backpropException = null;
        if (this.backprop == null) {
            return;
        }
        // Removing from the tail of the ArrayList is O(1) and drops references to finished
        // callbacks as we go.
        for (int i = this.backprop.size() - 1; i >= 0; --i) {
            this.backprop.remove(i).run();
            if (this.backpropException != null) {
                throw this.backpropException;
            }
        }
        // Explicit GC hint kept from the original, presumably to reclaim the matrices captured by
        // the removed callbacks.
        System.gc();
    }

    /** Returns the live (mutable) list of pending callbacks, or {@code null} if there is none. */
    public List<Runnable> getBackprops() {
        return this.backprop;
    }

    /**
     * Discards all pending callbacks and releases the list. After this call, recording further
     * operations on this graph is not supported.
     */
    public void cleanUp() {
        if (this.backprop != null) {
            this.backprop.clear();
            this.backprop = null;
        }
    }

    /**
     * Not implemented.
     *
     * @throws Exception always
     */
    public Matrix convolution(Tensor t, Matrix ... m) throws Exception {
        throw new Exception("Unimplemented");
    }

    /**
     * Stacks two column vectors into one column vector: the entries of {@code m1} followed by
     * those of {@code m2}.
     *
     * <p>The output copies the inputs' {@code w} and {@code stepCache} values. On backward, the
     * matching slices of the output gradient are added to {@code m1.dw} and {@code m2.dw}.
     *
     * @throws Exception if either input has more than one column
     */
    public Matrix concatVectors(final Matrix m1, final Matrix m2) throws Exception {
        if (m1.cols > 1 || m2.cols > 1) {
            throw new Exception("Expected column vectors");
        }
        final Matrix out = new Matrix(m1.rows + m2.rows);
        final int n1 = m1.w.length;
        final int n2 = m2.w.length;
        System.arraycopy(m1.w, 0, out.w, 0, n1);
        System.arraycopy(m2.w, 0, out.w, n1, n2);
        System.arraycopy(m1.stepCache, 0, out.stepCache, 0, n1);
        System.arraycopy(m2.stepCache, 0, out.stepCache, n1, n2);
        // out.dw is deliberately NOT seeded from the inputs' dw. It must hold only the gradient
        // flowing into this output, so the backward pass can add it to m1/m2 without overwriting
        // gradients from other operations or double-counting (e.g. concatVectors(x, x)).
        if (this.applyBackprop) {
            Runnable bp = new Runnable() {

                @Override
                public void run() {
                    try {
                        for (int i = 0; i < n1; ++i) {
                            m1.dw[i] += out.dw[i];
                        }
                        for (int i = 0; i < n2; ++i) {
                            m2.dw[i] += out.dw[n1 + i];
                        }
                    }
                    catch (Exception e) {
                        Graph.this.backpropException = e;
                    }
                }
            };
            this.backprop.add(bp);
        }
        return out;
    }

    /**
     * Applies {@code neuron.forward} elementwise to {@code m}.
     *
     * <p>On backward, {@code m.dw[i]} is increased by {@code neuron.backward(m.w[i]) * out.dw[i]},
     * i.e. the derivative is evaluated at the original input value.
     */
    public Matrix nonlin(final Nonlinearity neuron, final Matrix m) throws Exception {
        final Matrix out = new Matrix(m.rows, m.cols);
        final int n = m.w.length;
        for (int i = 0; i < n; ++i) {
            out.w[i] = neuron.forward(m.w[i]);
        }
        if (this.applyBackprop) {
            Runnable bp = new Runnable() {

                @Override
                public void run() {
                    try {
                        for (int i = 0; i < n; ++i) {
                            m.dw[i] += neuron.backward(m.w[i]) * out.dw[i];
                        }
                    }
                    catch (Exception e) {
                        Graph.this.backpropException = e;
                    }
                }
            };
            this.backprop.add(bp);
        }
        return out;
    }

    /** Stores {@code e} so that {@link #backward()} rethrows it after the current callback. */
    public void setBackpropException(Exception e) {
        this.backpropException = e;
    }

    /** Appends a custom callback to the pending list. */
    public void addBackprop(Runnable r) {
        this.backprop.add(r);
    }

    /** Appends several custom callbacks to the pending list, in iteration order. */
    public void addBackprops(List<Runnable> r) {
        this.backprop.addAll(r);
    }

    /**
     * Returns the sum of elementwise products over a {@code rows x cols} window. The window starts
     * at ({@code startm1row}, {@code startm1col}) in {@code m1} and at ({@code startm2row},
     * {@code startm2col}) in {@code m2}. This is not recorded for backprop.
     */
    public double subDot(Matrix m1, Matrix m2, int startm1row, int startm1col, int startm2row, int startm2col, int rows, int cols) throws Exception {
        double res = 0.0;
        for (int i = 0; i < rows; ++i) {
            for (int j = 0; j < cols; ++j) {
                res += m1.getW(startm1row + i, startm1col + j) * m2.getW(startm2row + i, startm2col + j);
            }
        }
        return res;
    }

    /**
     * Inner product of two equally sized column vectors, returned as a 1-element matrix.
     *
     * <p>On backward, with {@code g} the output gradient, {@code g * m2} is added to {@code m1.dw}
     * and {@code g * m1} is added to {@code m2.dw}.
     *
     * @throws Exception if the shapes differ or the inputs are not column vectors
     */
    public Matrix dot(final Matrix m1, final Matrix m2) throws Exception {
        if (m1.rows != m2.rows || m1.cols != m2.cols) {
            throw new Exception("matrix dimension mismatch");
        }
        if (m1.cols > 1 || m2.cols > 1) {
            throw new Exception("Expected column vectors");
        }
        final Matrix dot = new Matrix(1);
        double dotp = 0.0;
        for (int i = 0; i < m1.rows; ++i) {
            dotp += m1.getW(i, 0) * m2.getW(i, 0);
        }
        dot.w = new double[]{dotp};
        if (this.applyBackprop) {
            Runnable bp = new Runnable() {

                @Override
                public void run() {
                    try {
                        double g = dot.getDW(0, 0);
                        for (int i = 0; i < m1.rows; ++i) {
                            // Accumulate (not overwrite), like every other op. Overwriting dropped
                            // gradients from other uses of m1/m2 and halved the gradient of dot(x, x).
                            m1.setDW(i, 0, m1.getDW(i, 0) + m2.getW(i, 0) * g);
                            m2.setDW(i, 0, m2.getDW(i, 0) + m1.getW(i, 0) * g);
                        }
                    }
                    catch (Exception e) {
                        Graph.this.backpropException = e;
                    }
                }
            };
            this.backprop.add(bp);
        }
        return dot;
    }

    /**
     * Matrix product {@code m1 * m2}, using row-major indexing ({@code w[cols * row + col]}).
     *
     * <p>On backward, for each output cell (i, j) with gradient {@code b}, {@code m2[k][j] * b} is
     * added to {@code m1.dw[i][k]} and {@code m1[i][k] * b} is added to {@code m2.dw[k][j]}.
     *
     * @throws Exception if {@code m1.cols != m2.rows}
     */
    public Matrix mul(final Matrix m1, final Matrix m2) throws Exception {
        if (m1.cols != m2.rows) {
            throw new Exception("matrix dimension mismatch");
        }
        final int m1rows = m1.rows;
        final int m1cols = m1.cols;
        final int m2cols = m2.cols;
        final Matrix out = new Matrix(m1rows, m2cols);
        final int outcols = m2cols;
        for (int i = 0; i < m1rows; ++i) {
            int m1row = m1cols * i;
            for (int j = 0; j < m2cols; ++j) {
                double acc = 0.0;
                for (int k = 0; k < m1cols; ++k) {
                    acc += m1.w[m1row + k] * m2.w[m2cols * k + j];
                }
                out.w[outcols * i + j] = acc;
            }
        }
        if (this.applyBackprop) {
            Runnable bp = new Runnable() {

                @Override
                public void run() {
                    try {
                        for (int i = 0; i < m1.rows; ++i) {
                            int outrow = outcols * i;
                            int m1row = m1cols * i;
                            for (int j = 0; j < m2.cols; ++j) {
                                double b = out.dw[outrow + j];
                                for (int k = 0; k < m1.cols; ++k) {
                                    int a = m1row + k;
                                    int c = m2cols * k + j;
                                    m1.dw[a] += m2.w[c] * b;
                                    m2.dw[c] += m1.w[a] * b;
                                }
                            }
                        }
                    }
                    catch (Exception e) {
                        Graph.this.backpropException = e;
                    }
                }
            };
            this.backprop.add(bp);
        }
        return out;
    }

    /**
     * Elementwise sum {@code m1 + m2}. On backward, the output gradient is added to both inputs.
     *
     * @throws Exception if the shapes differ
     */
    public Matrix add(final Matrix m1, final Matrix m2) throws Exception {
        if (m1.rows != m2.rows || m1.cols != m2.cols) {
            throw new Exception("matrix dimension mismatch");
        }
        final Matrix out = new Matrix(m1.rows, m1.cols);
        for (int i = 0; i < m1.w.length; ++i) {
            out.w[i] = m1.w[i] + m2.w[i];
        }
        if (this.applyBackprop) {
            Runnable bp = new Runnable() {

                @Override
                public void run() {
                    try {
                        for (int i = 0; i < m1.w.length; ++i) {
                            m1.dw[i] += out.dw[i];
                            m2.dw[i] += out.dw[i];
                        }
                    }
                    catch (Exception e) {
                        Graph.this.backpropException = e;
                    }
                }
            };
            this.backprop.add(bp);
        }
        return out;
    }

    /**
     * Returns {@code 1 - m}, computed as {@code sub(Matrix.ones(...), m)}. Matrix.ones is
     * presumably an all-ones matrix (defined elsewhere).
     */
    public Matrix oneMinus(Matrix m) throws Exception {
        Matrix ones = Matrix.ones(m.rows, m.cols);
        return this.sub(ones, m);
    }

    /** Returns {@code m1 - m2}, recorded as {@code add(m1, neg(m2))}. */
    public Matrix sub(Matrix m1, Matrix m2) throws Exception {
        return this.add(m1, this.neg(m2));
    }

    /**
     * Scalar multiply, recorded as {@code elmul(m, Matrix.uniform(rows, cols, s))}. Matrix.uniform
     * is presumably a constant matrix filled with {@code s}; confirm in Matrix.
     */
    public Matrix smul(Matrix m, double s) throws Exception {
        Matrix m2 = Matrix.uniform(m.rows, m.cols, s);
        return this.elmul(m, m2);
    }

    /** Same as {@link #smul(Matrix, double)} with the arguments swapped. */
    public Matrix smul(double s, Matrix m) throws Exception {
        return this.smul(m, s);
    }

    /**
     * Negation, recorded as {@code elmul(Matrix.negones(...), m)}. Matrix.negones is presumably
     * an all-minus-one matrix (defined elsewhere).
     */
    public Matrix neg(Matrix m) throws Exception {
        Matrix negones = Matrix.negones(m.rows, m.cols);
        return this.elmul(negones, m);
    }

    /**
     * Elementwise (Hadamard) product {@code m1 .* m2}.
     *
     * <p>On backward, {@code m2 .* out.dw} is added to {@code m1.dw} and {@code m1 .* out.dw} is
     * added to {@code m2.dw}.
     *
     * @throws Exception if the shapes differ
     */
    public Matrix elmul(final Matrix m1, final Matrix m2) throws Exception {
        if (m1.rows != m2.rows || m1.cols != m2.cols) {
            throw new Exception("matrix dimension mismatch");
        }
        final Matrix out = new Matrix(m1.rows, m1.cols);
        for (int i = 0; i < m1.w.length; ++i) {
            out.w[i] = m1.w[i] * m2.w[i];
        }
        if (this.applyBackprop) {
            Runnable bp = new Runnable() {

                @Override
                public void run() {
                    try {
                        for (int i = 0; i < m1.w.length; ++i) {
                            m1.dw[i] += m2.w[i] * out.dw[i];
                            m2.dw[i] += m1.w[i] * out.dw[i];
                        }
                    }
                    catch (Exception e) {
                        Graph.this.backpropException = e;
                    }
                }
            };
            this.backprop.add(bp);
        }
        return out;
    }
}

