/*
 * Decompiled with CFR 0.152.
 */
package model;

import autodiff.Graph;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadPoolExecutor;
import matrix.Matrix;
import matrix.Tensor;
import model.TensorLayer;

public class ConvLayer
implements TensorLayer {
    private int stride;
    private int pad;
    private int inWidth;
    private int inHeight;
    private int inDepth;
    private int filterWidth;
    private int filterHeight;
    private int filtersPerDepth;
    private int outWidth;
    private int outHeight;
    private Matrix[] filters;
    private Tensor bias;
    private Matrix patchedGradient;
    private boolean multithreading = false;
    private int cores = 1;
    private ThreadPoolExecutor threadPool;
    private FilterThread[] tasks;
    private BackpropThread[] backpropTasks;

    public ConvLayer(int inWidth, int inHeight, int inDepth, int filterWidth, int filterHeight, int numFiltersPerDepth, int stride, int pad, double initParamsStdDev, Random rng, boolean multithreading, int cores) {
        this(inWidth, inHeight, inDepth, filterWidth, filterHeight, numFiltersPerDepth, stride, pad, initParamsStdDev, rng);
        this.multithreading = multithreading;
        this.threadPool = (ThreadPoolExecutor)Executors.newFixedThreadPool(cores);
        this.cores = cores;
    }

    public ConvLayer(int inWidth, int inHeight, int inDepth, int filterWidth, int filterHeight, int numFiltersPerDepth, int stride, int pad, double initParamsStdDev, Random rng) {
        this.stride = stride;
        this.pad = pad;
        this.inWidth = inWidth;
        this.inHeight = inHeight;
        this.inDepth = inDepth;
        this.filterHeight = filterHeight;
        this.filterWidth = filterWidth;
        this.filtersPerDepth = numFiltersPerDepth;
        this.outWidth = (inWidth + pad * 2 - filterWidth) / stride + 1;
        this.outHeight = (inHeight + pad * 2 - filterHeight) / stride + 1;
        this.filters = new Matrix[numFiltersPerDepth * inDepth];
        int i = 0;
        while (i < this.filters.length) {
            this.filters[i] = Matrix.rand(filterHeight, filterWidth, initParamsStdDev, rng);
            ++i;
        }
        this.tasks = new FilterThread[numFiltersPerDepth];
        i = 0;
        while (i < numFiltersPerDepth) {
            this.tasks[i] = new FilterThread(i);
            ++i;
        }
        this.bias = new Tensor(this.outWidth, this.outHeight, this.filtersPerDepth);
        i = 0;
        while (i < this.bias.getDepth()) {
            this.bias.setMatrixAt(i, Matrix.uniform(this.outHeight, this.outWidth, 0.0));
            int j = 0;
            while (j < this.bias.getMatrixAt((int)i).w.length) {
                this.bias.getMatrixAt((int)i).w[j] = rng.nextGaussian() * initParamsStdDev;
                ++j;
            }
            ++i;
        }
        this.backpropTasks = new BackpropThread[numFiltersPerDepth];
        i = 0;
        while (i < numFiltersPerDepth) {
            this.backpropTasks[i] = new BackpropThread(i);
            ++i;
        }
        this.multithreading = false;
    }

    public void setFilter(int indx, Matrix filter) {
        if (filter.rows != this.filterHeight && filter.cols != this.filterWidth) {
            return;
        }
        if (indx >= this.filters.length) {
            return;
        }
        this.filters[indx] = filter;
    }

    public Matrix getFilter(int indx) {
        return this.filters[indx];
    }

    /*
     * WARNING - void declaration
     */
    @Override
    public Tensor forward(final Tensor t, final Graph g) throws Exception {
        final Tensor toReturn = new Tensor(this.outWidth, this.outHeight, this.filtersPerDepth);
        if (t.getDepth() != this.inDepth) {
            throw new Exception("Invalid tensor depth: Is " + Integer.toString(t.getDepth()) + ", expected: " + Integer.toString(this.inDepth));
        }
        if (t.getWidth() != this.inWidth) {
            throw new Exception("Invalid tensor width: Is " + Integer.toString(t.getWidth()) + ", expected: " + Integer.toString(this.inWidth));
        }
        if (t.getHeight() != this.inHeight) {
            throw new Exception("Invalid tensor height: Is " + Integer.toString(t.getHeight()) + ", expected: " + Integer.toString(this.inHeight));
        }
        if (this.multithreading) {
            void var5_11;
            void var5_8;
            ArrayList a = new ArrayList();
            boolean bl = false;
            while (var5_8 < this.filtersPerDepth) {
                this.tasks[var5_8].t = t;
                this.tasks[var5_8].res = toReturn.getMatrixAt((int)var5_8);
                a.add(this.threadPool.submit(this.tasks[var5_8]));
                ++var5_8;
            }
            for (Future future : a) {
                while (future.get() != null) {
                }
            }
            boolean bl2 = false;
            while (var5_11 < this.filtersPerDepth) {
                toReturn.getMatrixAt((int)var5_11).w = ((FilterThread)this.tasks[var5_11]).res.w;
                ++var5_11;
            }
        } else {
            int i = 0;
            while (i < this.filtersPerDepth) {
                int j = 0;
                while (j < this.inDepth) {
                    int filterIndex = i * this.inDepth + j;
                    int k = 0;
                    while (k < this.outHeight) {
                        int l = 0;
                        while (l < this.outWidth) {
                            int m = 0;
                            while (m < this.filterHeight) {
                                int n = 0;
                                while (n < this.filterWidth) {
                                    int y = k * this.stride + m - this.pad;
                                    int x = l * this.stride + n - this.pad;
                                    if (x >= 0 && y >= 0 && x < this.inWidth && y < this.inHeight) {
                                        double w2 = t.getMatrixAt(j).getW(y, x);
                                        toReturn.getMatrixAt(i).setW(k, l, toReturn.getMatrixAt(i).getW(k, l) + (w2 *= this.filters[filterIndex].getW(m, n)));
                                    }
                                    ++n;
                                }
                                ++m;
                            }
                            ++l;
                        }
                        ++k;
                    }
                    ++j;
                }
                ++i;
            }
        }
        if (g.applyBackprop()) {
            g.addBackprop(new Runnable(){

                /*
                 * WARNING - void declaration
                 */
                @Override
                public void run() {
                    try {
                        if (!ConvLayer.this.multithreading) {
                            int numi = 0;
                            while (numi < ConvLayer.this.filtersPerDepth) {
                                ConvLayer.this.patchedGradient = new Matrix(ConvLayer.this.inHeight + ConvLayer.this.pad * 2 - ConvLayer.this.filterHeight + 1, ConvLayer.this.inWidth + ConvLayer.this.pad * 2 - ConvLayer.this.filterWidth + 1);
                                int i = 0;
                                while (i < toReturn.getMatrixAt((int)numi).cols) {
                                    int j = 0;
                                    while (j < toReturn.getMatrixAt((int)numi).rows) {
                                        ConvLayer.this.patchedGradient.setDW(j * ConvLayer.this.stride, i * ConvLayer.this.stride, toReturn.getMatrixAt(numi).getDW(j, i));
                                        ++j;
                                    }
                                    ++i;
                                }
                                int numj = 0;
                                while (numj < ConvLayer.this.inDepth) {
                                    int filterIndx = numi * ConvLayer.this.inDepth + numj;
                                    int k = 0;
                                    while (k < ((ConvLayer)ConvLayer.this).patchedGradient.rows) {
                                        int l = 0;
                                        while (l < ((ConvLayer)ConvLayer.this).patchedGradient.cols) {
                                            double w2 = ConvLayer.this.patchedGradient.getDW(k, l);
                                            int m = 0;
                                            while (m < ConvLayer.this.filterHeight) {
                                                int n = 0;
                                                while (n < ConvLayer.this.filterWidth) {
                                                    int y2 = k + m - ConvLayer.this.pad;
                                                    int x2 = l + n - ConvLayer.this.pad;
                                                    if (y2 >= 0 && x2 >= 0 && y2 < ConvLayer.this.inHeight && x2 < ConvLayer.this.inWidth) {
                                                        ConvLayer.this.filters[filterIndx].setDW(m, n, ConvLayer.this.filters[filterIndx].getDW(m, n) + w2 * t.getMatrixAt(numj).getW(y2, x2));
                                                        t.getMatrixAt(numj).setDW(y2, x2, t.getMatrixAt(numj).getDW(y2, x2) + w2 * ConvLayer.this.filters[filterIndx].getW(m, n));
                                                    }
                                                    ++n;
                                                }
                                                ++m;
                                            }
                                            ++l;
                                        }
                                        ++k;
                                    }
                                    ++numj;
                                }
                                ++numi;
                            }
                        } else {
                            void var2_21;
                            void var2_18;
                            ArrayList a = new ArrayList();
                            boolean bl = false;
                            while (var2_18 < ConvLayer.this.filtersPerDepth) {
                                ConvLayer.this.backpropTasks[var2_18].toReturn = toReturn.getMatrixAt((int)var2_18);
                                ConvLayer.this.backpropTasks[var2_18].t2 = new Tensor(ConvLayer.this.inWidth, ConvLayer.this.inHeight, ConvLayer.this.inDepth);
                                ConvLayer.this.backpropTasks[var2_18].t3 = t;
                                a.add(ConvLayer.this.threadPool.submit(ConvLayer.this.backpropTasks[var2_18]));
                                ++var2_18;
                            }
                            for (Future future : a) {
                                while (future.get() != null) {
                                }
                            }
                            boolean bl2 = false;
                            while (var2_21 < ConvLayer.this.filtersPerDepth) {
                                int j = 0;
                                while (j < ConvLayer.this.inDepth) {
                                    int k = 0;
                                    while (k < ConvLayer.this.inWidth) {
                                        int l = 0;
                                        while (l < ConvLayer.this.inHeight) {
                                            t.getMatrixAt(j).setDW(l, k, t.getMatrixAt(j).getDW(l, k) + ConvLayer.this.backpropTasks[var2_21].t2.getMatrixAt(j).getDW(l, k));
                                            ++l;
                                        }
                                        ++k;
                                    }
                                    ++j;
                                }
                                ++var2_21;
                            }
                        }
                    }
                    catch (Exception e) {
                        g.setBackpropException(e);
                        return;
                    }
                }
            });
        }
        int i = 0;
        while (i < this.filtersPerDepth) {
            toReturn.setMatrixAt(i, g.add(toReturn.getMatrixAt(i), this.bias.getMatrixAt(i)));
            ++i;
        }
        return toReturn;
    }

    public int getOutWidth() {
        return this.outWidth;
    }

    public int getOutHeight() {
        return this.outHeight;
    }

    @Override
    public List<Matrix> getParameters() {
        ArrayList<Matrix> result = new ArrayList<Matrix>();
        Matrix[] matrixArray = this.filters;
        int n = this.filters.length;
        int n2 = 0;
        while (n2 < n) {
            Matrix l = matrixArray[n2];
            result.add(l);
            ++n2;
        }
        matrixArray = this.bias.getMatrices();
        n = matrixArray.length;
        n2 = 0;
        while (n2 < n) {
            Matrix m = matrixArray[n2];
            result.add(m);
            ++n2;
        }
        return result;
    }

    @Override
    public void resetState() {
    }

    public int getCores() {
        return this.multithreading ? this.cores : 1;
    }

    public void setCores(int cores) throws Exception {
        if (this.threadPool != null && this.threadPool.getActiveCount() != 0) {
            throw new Exception("Can't change core count while thread pool is in use");
        }
        if (this.threadPool != null) {
            this.threadPool.shutdown();
        }
        this.threadPool = (ThreadPoolExecutor)Executors.newFixedThreadPool(cores);
        this.cores = cores;
    }

    public boolean isMultithreading() {
        return this.multithreading;
    }

    public void setMultithreading(boolean mt) throws Exception {
        if (mt == this.multithreading) {
            return;
        }
        if (!mt) {
            if (this.threadPool != null && this.threadPool.getActiveCount() != 0) {
                throw new Exception("Can't disable multithreading while thread pool is in use");
            }
            if (this.threadPool != null) {
                this.threadPool.shutdown();
            }
        } else {
            this.threadPool = (ThreadPoolExecutor)Executors.newFixedThreadPool(this.cores);
        }
        this.multithreading = mt;
    }

    @Override
    public TensorLayer clone() {
        ConvLayer clone = new ConvLayer(this.inWidth, this.inHeight, this.inDepth, this.filterWidth, this.filterHeight, this.filtersPerDepth, this.stride, this.pad, 1.0, new Random(), this.multithreading, this.cores);
        clone.bias = this.bias.clone();
        int i = 0;
        while (i < this.filters.length) {
            clone.filters[i] = this.filters[i].clone();
            ++i;
        }
        return clone;
    }

    @Override
    public void saveState(DataOutputStream fos) throws Exception {
        fos.writeInt(this.stride);
        fos.writeInt(this.inWidth);
        fos.writeInt(this.inHeight);
        fos.writeInt(this.inDepth);
        fos.writeInt(this.filterWidth);
        fos.writeInt(this.filterHeight);
        fos.writeInt(this.filtersPerDepth);
        fos.writeInt(this.outWidth);
        fos.writeInt(this.outHeight);
        fos.writeInt(this.pad);
        this.bias.save(fos);
        fos.writeInt(this.filters.length);
        int i = 0;
        while (i < this.filters.length) {
            this.filters[i].save(fos);
            ++i;
        }
    }

    @Override
    public void loadState(DataInputStream fis) throws Exception {
        this.stride = fis.readInt();
        this.inWidth = fis.readInt();
        this.inHeight = fis.readInt();
        this.inDepth = fis.readInt();
        this.filterWidth = fis.readInt();
        this.filterHeight = fis.readInt();
        this.filtersPerDepth = fis.readInt();
        this.outWidth = fis.readInt();
        this.outHeight = fis.readInt();
        this.pad = fis.readInt();
        this.bias.load(fis);
        this.filters = new Matrix[fis.readInt()];
        int i = 0;
        while (i < this.filters.length) {
            this.filters[i] = new Matrix(this.filterHeight, this.filterWidth);
            this.filters[i].load(fis);
            ++i;
        }
        this.tasks = new FilterThread[this.filtersPerDepth];
        i = 0;
        while (i < this.filtersPerDepth) {
            this.tasks[i] = new FilterThread(i);
            ++i;
        }
        this.backpropTasks = new BackpropThread[this.filtersPerDepth];
        i = 0;
        while (i < this.filtersPerDepth) {
            this.backpropTasks[i] = new BackpropThread(i);
            ++i;
        }
    }

    private class BackpropThread
    implements Runnable {
        private int numi;
        private Matrix patchedGrad;
        private Matrix toReturn;
        private Tensor t2;
        private Tensor t3;

        private BackpropThread(int i) {
            this.numi = i;
        }

        @Override
        public void run() {
            this.patchedGrad = new Matrix(ConvLayer.this.inHeight + ConvLayer.this.pad * 2 - ConvLayer.this.filterHeight + 1, ConvLayer.this.inWidth + ConvLayer.this.pad * 2 - ConvLayer.this.filterWidth + 1);
            int i = 0;
            while (i < this.toReturn.cols) {
                int j = 0;
                while (j < this.toReturn.rows) {
                    this.patchedGrad.setDW(j * ConvLayer.this.stride, i * ConvLayer.this.stride, this.toReturn.getDW(j, i));
                    ++j;
                }
                ++i;
            }
            int numj = 0;
            while (numj < ConvLayer.this.inDepth) {
                int filterIndx = this.numi * ConvLayer.this.inDepth + numj;
                int k = 0;
                while (k < this.patchedGrad.rows) {
                    int l = 0;
                    while (l < this.patchedGrad.cols) {
                        double w2 = this.patchedGrad.getDW(k, l);
                        int m = 0;
                        while (m < ConvLayer.this.filterHeight) {
                            int n = 0;
                            while (n < ConvLayer.this.filterWidth) {
                                int y2 = k + m - ConvLayer.this.pad;
                                int x2 = l + n - ConvLayer.this.pad;
                                if (y2 >= 0 && x2 >= 0 && y2 < ConvLayer.this.inHeight && x2 < ConvLayer.this.inWidth) {
                                    ConvLayer.this.filters[filterIndx].setDW(m, n, ConvLayer.this.filters[filterIndx].getDW(m, n) + w2 * this.t3.getMatrixAt(numj).getW(y2, x2));
                                    this.t2.getMatrixAt(numj).setDW(y2, x2, this.t2.getMatrixAt(numj).getDW(y2, x2) + w2 * ConvLayer.this.filters[filterIndx].getW(m, n));
                                }
                                ++n;
                            }
                            ++m;
                        }
                        ++l;
                    }
                    ++k;
                }
                ++numj;
            }
        }
    }

    private class FilterThread
    implements Runnable {
        private int i;
        private Matrix res;
        private Tensor t;

        private FilterThread(int i) {
            this.res = new Matrix(ConvLayer.this.outHeight, ConvLayer.this.outWidth);
            this.i = i;
        }

        @Override
        public void run() {
            int j = 0;
            while (j < ConvLayer.this.inDepth) {
                int filterIndx = this.i * ConvLayer.this.inDepth + j;
                int k = 0;
                while (k < ConvLayer.this.outHeight) {
                    int l = 0;
                    while (l < ConvLayer.this.outWidth) {
                        int m = 0;
                        while (m < ConvLayer.this.filterHeight) {
                            int n = 0;
                            while (n < ConvLayer.this.filterWidth) {
                                int y = k * ConvLayer.this.stride + m - ConvLayer.this.pad;
                                int x = l * ConvLayer.this.stride + n - ConvLayer.this.pad;
                                if (x >= 0 && y >= 0 && x < ConvLayer.this.inWidth && y < ConvLayer.this.inHeight) {
                                    double w2 = this.t.getMatrixAt(j).getW(y, x);
                                    this.res.setW(k, l, this.res.getW(k, l) + (w2 *= ConvLayer.this.filters[filterIndx].getW(m, n)));
                                }
                                ++n;
                            }
                            ++m;
                        }
                        ++l;
                    }
                    ++k;
                }
                ++j;
            }
        }
    }
}

