/*
 * Decompiled with CFR 0.152.
 */
package autodiff;

import java.util.ArrayList;
import java.util.List;
import matrix.Matrix;
import matrix.Tensor;
import nonlinearities.LinearUnit;
import nonlinearities.Nonlinearity;

/**
 * Eager computation graph with tape-based reverse-mode differentiation over {@link Matrix}
 * and {@link Tensor} values.
 *
 * <p>Every operation computes its result immediately. When backprop is enabled, it also
 * appends a step to {@link #backprop}. That step reads the gradient stored in the output's
 * {@code dw} array and adds ({@code +=}) the corresponding contribution into the {@code dw}
 * arrays of the operation's inputs. Because contributions are added rather than assigned, a
 * value that feeds several operations receives the sum of all of them. {@link #backward()}
 * replays the recorded steps in reverse registration order.
 *
 * <p>NOTE(review): gradient arrays are never zeroed by this class; presumably the caller
 * (e.g. an optimizer) resets {@code dw} between iterations — confirm.
 */
public class Graph {
    /** Whether operations record backward steps on {@link #backprop}. */
    protected boolean applyBackprop;
    /** Recorded backward steps, in the order their operations ran. */
    List<Runnable> backprop;
    /** Exception reported by a failing backward step; rethrown by {@link #backward()}. */
    Exception backpropException = null;

    /**
     * A backward step that may throw. {@link #addGuardedBackprop} converts any failure into
     * {@link #backpropException} so that {@link #backward()} rethrows it.
     */
    private interface BackpropStep {
        void run() throws Exception;
    }

    /** Creates a graph with backprop disabled. */
    public Graph() {
        this(false);
    }

    /**
     * Creates a graph.
     *
     * @param applyBackprop whether operations should record backward steps
     */
    public Graph(boolean applyBackprop) {
        this.applyBackprop = applyBackprop;
        this.backprop = new ArrayList<Runnable>();
    }

    /** @return whether operations currently record backward steps */
    public boolean isApplyingBackprop() {
        return this.applyBackprop;
    }

    /**
     * Enables or disables recording of backward steps for subsequent operations. Steps that
     * were already recorded are kept.
     *
     * @param applyBackprop whether subsequent operations should record backward steps
     */
    public void setApplyingBackprop(boolean applyBackprop) {
        this.applyBackprop = applyBackprop;
    }

    /**
     * Runs the recorded backward steps, newest first, removing each one before it runs.
     *
     * @throws IllegalStateException if this graph does not record backward steps
     * @throws Exception the first exception reported by a backward step; steps registered
     *     before the failing one remain in the list
     */
    public void backward() throws Exception {
        if (!this.applyBackprop) {
            throw new IllegalStateException("Attempting to do backpropagation on a non-backprop enabled Graph");
        }
        this.backpropException = null;
        for (int i = this.backprop.size() - 1; i >= 0; --i) {
            this.backprop.remove(i).run();
            if (this.backpropException != null) {
                throw this.backpropException;
            }
        }
    }

    /** Discards all recorded backward steps without running them. */
    public void cleanUp() {
        if (this.backprop != null) {
            this.backprop.clear();
        }
    }

    /** Discards all recorded backward steps without running them. */
    public void resetBackprop() {
        if (this.backprop != null) {
            this.backprop.clear();
        }
    }

    /**
     * Records {@code step}, wrapped so that any exception it throws is stored in
     * {@link #backpropException} instead of escaping {@link Runnable#run()}.
     */
    private void addGuardedBackprop(final BackpropStep step) {
        this.backprop.add(new Runnable() {
            @Override
            public void run() {
                try {
                    step.run();
                } catch (Exception e) {
                    Graph.this.backpropException = e;
                }
            }
        });
    }

    /** First filter offset whose input coordinate {@code origin + offset} is non-negative. */
    private static int windowStart(int origin) {
        return origin < 0 ? -origin : 0;
    }

    /** One past the last filter offset whose input coordinate stays below {@code inputSize}. */
    private static int windowEnd(int origin, int filterSize, int inputSize) {
        return Math.min(filterSize, inputSize - origin);
    }

    /**
     * 2-D multi-channel convolution (cross-correlation: filters are not flipped) with zero
     * padding.
     *
     * <p>Output channel {@code i} is the sum over input channels {@code j} of input channel
     * {@code j} correlated with {@code filters[i * t.depth + j]}. Input positions that fall
     * outside the tensor (padding) are skipped, which is equivalent to zero padding. All
     * filters are assumed to share the dimensions of {@code filters[0]}.
     *
     * <p>NOTE(review): the forward pass adds into the new output tensor, so it relies on
     * {@code new Tensor(...)} starting with zeroed matrices — confirm in {@code Tensor}.
     *
     * @param pad zero padding added on every side
     * @param stride step between successive filter positions; must be positive
     * @param outDepth number of output channels
     * @param t input tensor
     * @param filters {@code t.depth * outDepth} filters, grouped by output channel
     * @return the output tensor, {@code (t.width + 2*pad - fcols) / stride + 1} wide and
     *     {@code (t.height + 2*pad - frows) / stride + 1} high (integer division)
     * @throws IllegalArgumentException if the filter count is wrong, no filters are given, or
     *     {@code stride} is not positive
     */
    public Tensor convolution(final int pad, final int stride, final int outDepth, final Tensor t, final Matrix ... filters) throws Exception {
        if (filters.length != t.depth * outDepth) {
            throw new IllegalArgumentException("Filter count does not match tensor depth");
        }
        if (filters.length == 0) {
            throw new IllegalArgumentException("At least one filter is required");
        }
        if (stride <= 0) {
            throw new IllegalArgumentException("stride must be positive");
        }
        final int filterRows = filters[0].rows;
        final int filterCols = filters[0].cols;
        final Tensor out = new Tensor((t.width + pad * 2 - filterCols) / stride + 1, (t.height + pad * 2 - filterRows) / stride + 1, outDepth);
        for (int i = 0; i < outDepth; ++i) {
            Matrix outm = out.matrices[i];
            for (int j = 0; j < t.depth; ++j) {
                Matrix inm = t.matrices[j];
                Matrix filter = filters[i * t.depth + j];
                for (int k = 0; k < outm.rows; ++k) {
                    int y = k * stride - pad;
                    int mStart = windowStart(y);
                    int mEnd = windowEnd(y, filterRows, t.height);
                    for (int l = 0; l < outm.cols; ++l) {
                        int x = l * stride - pad;
                        int nStart = windowStart(x);
                        int nEnd = windowEnd(x, filterCols, t.width);
                        int outIndx = k * outm.cols + l;
                        double sum = outm.w[outIndx];
                        for (int m = mStart; m < mEnd; ++m) {
                            int inRow = (y + m) * inm.cols + x;
                            int filterRow = m * filter.cols;
                            for (int n = nStart; n < nEnd; ++n) {
                                sum += inm.w[inRow + n] * filter.w[filterRow + n];
                            }
                        }
                        outm.w[outIndx] = sum;
                    }
                }
            }
        }
        if (this.applyBackprop) {
            addGuardedBackprop(new BackpropStep() {
                @Override
                public void run() {
                    for (int i = 0; i < outDepth; ++i) {
                        Matrix outm = out.matrices[i];
                        for (int j = 0; j < t.depth; ++j) {
                            Matrix inm = t.matrices[j];
                            Matrix filter = filters[i * t.depth + j];
                            for (int k = 0; k < outm.rows; ++k) {
                                int y = k * stride - pad;
                                int mStart = windowStart(y);
                                int mEnd = windowEnd(y, filterRows, t.height);
                                for (int l = 0; l < outm.cols; ++l) {
                                    int x = l * stride - pad;
                                    int nStart = windowStart(x);
                                    int nEnd = windowEnd(x, filterCols, t.width);
                                    double grad = outm.dw[k * outm.cols + l];
                                    for (int m = mStart; m < mEnd; ++m) {
                                        int inRow = (y + m) * inm.cols + x;
                                        int filterRow = m * filter.cols;
                                        for (int n = nStart; n < nEnd; ++n) {
                                            // Each window position touches one input and one filter
                                            // element; both receive the other's value times grad.
                                            inm.dw[inRow + n] += filter.w[filterRow + n] * grad;
                                            filter.dw[filterRow + n] += inm.w[inRow + n] * grad;
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            });
        }
        return out;
    }

    /**
     * Stacks two column vectors: {@code m1} on top of {@code m2}.
     *
     * @return a column vector with {@code m1.rows + m2.rows} rows
     * @throws IllegalArgumentException if either input has more than one column
     */
    public Matrix concatVectors(final Matrix m1, final Matrix m2) throws Exception {
        if (m1.cols > 1 || m2.cols > 1) {
            throw new IllegalArgumentException("Expected column vectors");
        }
        final Matrix out = new Matrix(m1.rows + m2.rows);
        System.arraycopy(m1.w, 0, out.w, 0, m1.w.length);
        System.arraycopy(m2.w, 0, out.w, m1.w.length, m2.w.length);
        if (this.applyBackprop) {
            addGuardedBackprop(new BackpropStep() {
                @Override
                public void run() {
                    // Accumulate (not copy) so gradients from other consumers of m1/m2 survive.
                    int offset = m1.dw.length;
                    for (int i = 0; i < offset; ++i) {
                        m1.dw[i] += out.dw[i];
                    }
                    for (int i = 0; i < m2.dw.length; ++i) {
                        m2.dw[i] += out.dw[offset + i];
                    }
                }
            });
        }
        return out;
    }

    /**
     * Applies {@code neuron} element-wise. A {@link LinearUnit} is treated as the identity:
     * values are copied and gradients pass through unchanged.
     *
     * <p>In the backward step, {@code neuron.backward} is evaluated at the input value
     * {@code m.w[i]}, not at the output.
     *
     * @return a new matrix with the same shape as {@code m}
     */
    public Matrix nonlin(final Nonlinearity neuron, final Matrix m) throws Exception {
        final Matrix out = new Matrix(m.rows, m.cols);
        final int n = m.w.length;
        final boolean linear = neuron instanceof LinearUnit;
        if (linear) {
            System.arraycopy(m.w, 0, out.w, 0, n);
        } else {
            for (int i = 0; i < n; ++i) {
                out.w[i] = neuron.forward(m.w[i]);
            }
        }
        if (this.applyBackprop) {
            addGuardedBackprop(new BackpropStep() {
                @Override
                public void run() throws Exception {
                    if (linear) {
                        for (int i = 0; i < n; ++i) {
                            m.dw[i] += out.dw[i];
                        }
                    } else {
                        for (int i = 0; i < n; ++i) {
                            m.dw[i] += neuron.backward(m.w[i]) * out.dw[i];
                        }
                    }
                }
            });
        }
        return out;
    }

    /**
     * Sets the exception that {@link #backward()} will rethrow after the current step.
     *
     * @param e the failure to report
     */
    public void setBackpropException(Exception e) {
        this.backpropException = e;
    }

    /**
     * Records a caller-supplied backward step. Unlike the steps registered by this class's
     * operations, it is not wrapped: an unchecked exception it throws propagates directly out
     * of {@link #backward()}. It may instead report a failure via
     * {@link #setBackpropException(Exception)}.
     *
     * @param r the step to run during {@link #backward()}
     */
    public void addBackprop(Runnable r) {
        this.backprop.add(r);
    }

    /**
     * Inner product of two column vectors of equal length.
     *
     * @return a 1x1 matrix holding the dot product
     * @throws IllegalArgumentException if the shapes differ or the inputs are not column vectors
     */
    public Matrix dot(final Matrix m1, final Matrix m2) throws Exception {
        if (m1.rows != m2.rows || m1.cols != m2.cols) {
            throw new IllegalArgumentException("matrix dimension mismatch");
        }
        if (m1.cols > 1 || m2.cols > 1) {
            throw new IllegalArgumentException("Expected column vectors");
        }
        final Matrix dot = new Matrix(1);
        double dotp = 0.0;
        for (int i = 0; i < m1.rows; ++i) {
            dotp += m1.getW(i, 0) * m2.getW(i, 0);
        }
        dot.w = new double[]{dotp};
        if (this.applyBackprop) {
            addGuardedBackprop(new BackpropStep() {
                @Override
                public void run() {
                    double g = dot.getDW(0, 0);
                    for (int i = 0; i < m1.rows; ++i) {
                        // Accumulate rather than overwrite; this also makes dot(x, x) yield 2*x*g.
                        m1.setDW(i, 0, m1.getDW(i, 0) + m2.getW(i, 0) * g);
                        m2.setDW(i, 0, m2.getDW(i, 0) + m1.getW(i, 0) * g);
                    }
                }
            });
        }
        return dot;
    }

    /**
     * Matrix product {@code m1 × m2}. Both operands are indexed row-major
     * ({@code w[row * cols + col]}).
     *
     * @return a new {@code m1.rows × m2.cols} matrix
     * @throws IllegalArgumentException if {@code m1.cols != m2.rows}
     */
    public Matrix mul(final Matrix m1, final Matrix m2) throws Exception {
        if (m1.cols != m2.rows) {
            throw new IllegalArgumentException("matrix dimension mismatch");
        }
        final int rows = m1.rows;
        final int inner = m1.cols;
        final int cols = m2.cols;
        final Matrix out = new Matrix(rows, cols);
        for (int i = 0; i < rows; ++i) {
            int m1Row = inner * i;
            for (int j = 0; j < cols; ++j) {
                double sum = 0.0;
                for (int k = 0; k < inner; ++k) {
                    sum += m1.w[m1Row + k] * m2.w[cols * k + j];
                }
                out.w[cols * i + j] = sum;
            }
        }
        if (this.applyBackprop) {
            addGuardedBackprop(new BackpropStep() {
                @Override
                public void run() {
                    for (int i = 0; i < rows; ++i) {
                        int m1Row = inner * i;
                        int outRow = cols * i;
                        // d m1[i][k] += sum_j m2[k][j] * d out[i][j]
                        for (int j = 0; j < cols; ++j) {
                            double g = out.dw[outRow + j];
                            for (int k = 0; k < inner; ++k) {
                                m1.dw[m1Row + k] += m2.w[cols * k + j] * g;
                            }
                        }
                        // d m2[k][j] += m1[i][k] * d out[i][j]
                        for (int k = 0; k < inner; ++k) {
                            double a = m1.w[m1Row + k];
                            for (int j = 0; j < cols; ++j) {
                                m2.dw[cols * k + j] += a * out.dw[outRow + j];
                            }
                        }
                    }
                }
            });
        }
        return out;
    }

    /**
     * Element-wise sum.
     *
     * @throws IllegalArgumentException if the shapes differ
     */
    public Matrix add(final Matrix m1, final Matrix m2) throws Exception {
        if (m1.rows != m2.rows || m1.cols != m2.cols) {
            throw new IllegalArgumentException("matrix dimension mismatch");
        }
        final Matrix out = new Matrix(m1.rows, m1.cols);
        for (int i = 0; i < m1.w.length; ++i) {
            out.w[i] = m1.w[i] + m2.w[i];
        }
        if (this.applyBackprop) {
            addGuardedBackprop(new BackpropStep() {
                @Override
                public void run() {
                    for (int i = 0; i < m1.w.length; ++i) {
                        m1.dw[i] += out.dw[i];
                        m2.dw[i] += out.dw[i];
                    }
                }
            });
        }
        return out;
    }

    /** Computes {@code 1 - m} element-wise, built from {@link #sub}. */
    public Matrix oneMinus(Matrix m) throws Exception {
        Matrix ones = Matrix.ones(m.rows, m.cols);
        return this.sub(ones, m);
    }

    /** Computes {@code m1 - m2} element-wise as {@code add(m1, neg(m2))}. */
    public Matrix sub(Matrix m1, Matrix m2) throws Exception {
        return this.add(m1, this.neg(m2));
    }

    /**
     * Multiplies every element of {@code m} by the scalar {@code s.w[0]}. Other elements of
     * {@code s}, if any, are ignored.
     *
     * @return a new matrix with the same shape as {@code m}
     */
    public Matrix scalMul(final Matrix m, final Matrix s) throws Exception {
        final Matrix out = new Matrix(m.rows, m.cols);
        final double scalar = s.w[0];
        for (int i = 0; i < m.w.length; ++i) {
            out.w[i] = m.w[i] * scalar;
        }
        if (this.applyBackprop) {
            addGuardedBackprop(new BackpropStep() {
                @Override
                public void run() {
                    double sc = s.w[0];
                    for (int i = 0; i < m.w.length; ++i) {
                        m.dw[i] += out.dw[i] * sc;
                        // The scalar feeds every output element, so its gradient is the sum.
                        s.dw[0] += out.dw[i] * m.w[i];
                    }
                }
            });
        }
        return out;
    }

    /**
     * Multiplies {@code m} by the constant {@code s} via {@link #elmul} with a filled matrix.
     *
     * @deprecated allocates a full-size constant matrix; {@link #scalMul} computes the same
     *     product from a 1x1 scalar matrix.
     */
    @Deprecated
    public Matrix smul(Matrix m, double s) throws Exception {
        Matrix m2 = Matrix.uniform(m.rows, m.cols, s);
        return this.elmul(m, m2);
    }

    /**
     * Argument-order variant of {@link #smul(Matrix, double)}.
     *
     * @deprecated see {@link #smul(Matrix, double)}.
     */
    @Deprecated
    public Matrix smul(double s, Matrix m) throws Exception {
        return this.smul(m, s);
    }

    /**
     * Negates {@code m} via {@link #elmul} with a matrix from {@code Matrix.negones}.
     *
     * @deprecated still used internally by {@link #sub}; reason for deprecation is not
     *     recorded here.
     */
    @Deprecated
    public Matrix neg(Matrix m) throws Exception {
        Matrix negones = Matrix.negones(m.rows, m.cols);
        return this.elmul(negones, m);
    }

    /**
     * Element-wise (Hadamard) product.
     *
     * @throws IllegalArgumentException if the shapes differ
     */
    public Matrix elmul(final Matrix m1, final Matrix m2) throws Exception {
        if (m1.rows != m2.rows || m1.cols != m2.cols) {
            throw new IllegalArgumentException("matrix dimension mismatch");
        }
        final Matrix out = new Matrix(m1.rows, m1.cols);
        for (int i = 0; i < m1.w.length; ++i) {
            out.w[i] = m1.w[i] * m2.w[i];
        }
        if (this.applyBackprop) {
            addGuardedBackprop(new BackpropStep() {
                @Override
                public void run() {
                    for (int i = 0; i < m1.w.length; ++i) {
                        m1.dw[i] += m2.w[i] * out.dw[i];
                        m2.dw[i] += m1.w[i] * out.dw[i];
                    }
                }
            });
        }
        return out;
    }
}

