/*
 * Decompiled with CFR 0.152.
 */
package model;

import autodiff.Graph;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadPoolExecutor;
import matrix.Matrix;
import matrix.Tensor;
import model.TensorLayer;
import theGhastModding.utils.math.ByteConverters;

public class ConvLayer
implements TensorLayer {
    private int stride;
    private int inWidth;
    private int inHeight;
    private int inDepth;
    private int filterWidth;
    private int filterHeight;
    private int filtersPerDepth;
    private int outWidth;
    private int outHeight;
    private Matrix[] filters;
    private Tensor bias;
    private Matrix patchedGradient;
    private boolean multithreading = false;
    private int cores = 1;
    private ThreadPoolExecutor threadPool;
    private FilterThread[] tasks;
    private BackpropThread[] backpropTasks;

    public ConvLayer(int stride, int inWidth, int inHeight, int inDepth, int filterWidth, int filterHeight, int numFiltersPerDepth, double initParamsStdDev, Random rng, boolean multithreading, int cores) {
        this(stride, inWidth, inHeight, inDepth, filterWidth, filterHeight, numFiltersPerDepth, initParamsStdDev, rng);
        this.multithreading = multithreading;
        this.threadPool = (ThreadPoolExecutor)Executors.newFixedThreadPool(cores);
        this.cores = cores;
    }

    public ConvLayer(int stride, int inWidth, int inHeight, int inDepth, int filterWidth, int filterHeight, int numFiltersPerDepth, double initParamsStdDev, Random rng) {
        this.stride = stride;
        this.inWidth = inWidth;
        this.inHeight = inHeight;
        this.inDepth = inDepth;
        this.filterHeight = filterHeight;
        this.filterWidth = filterWidth;
        this.filtersPerDepth = numFiltersPerDepth;
        this.outWidth = (inWidth - filterWidth) / stride + 1;
        this.outHeight = (inHeight - filterHeight) / stride + 1;
        this.filters = new Matrix[numFiltersPerDepth * inDepth];
        int i = 0;
        while (i < this.filters.length) {
            this.filters[i] = Matrix.rand(filterHeight, filterWidth, initParamsStdDev, rng);
            ++i;
        }
        this.bias = new Tensor(this.outWidth, this.outHeight, this.filtersPerDepth);
        i = 0;
        while (i < this.bias.getDepth()) {
            this.bias.setMatrixAt(i, Matrix.rand(this.outHeight, this.outWidth, initParamsStdDev, rng));
            ++i;
        }
        this.patchedGradient = new Matrix(inHeight - filterHeight + 1, inWidth - filterWidth + 1);
        this.tasks = new FilterThread[numFiltersPerDepth];
        i = 0;
        while (i < numFiltersPerDepth) {
            this.tasks[i] = new FilterThread(i);
            ++i;
        }
        this.backpropTasks = new BackpropThread[numFiltersPerDepth];
        i = 0;
        while (i < numFiltersPerDepth) {
            this.backpropTasks[i] = new BackpropThread(i);
            ++i;
        }
        this.multithreading = false;
    }

    public void setFilter(int indx, Matrix filter) {
        if (filter.rows != this.filterHeight && filter.cols != this.filterWidth) {
            return;
        }
        if (indx >= this.filters.length) {
            return;
        }
        this.filters[indx] = filter;
    }

    public Matrix getFilter(int indx) {
        return this.filters[indx];
    }

    /*
     * WARNING - void declaration
     */
    @Override
    public Tensor forward(final Tensor t, final Graph g) throws Exception {
        final Tensor toReturn = new Tensor(this.outWidth, this.outHeight, this.filtersPerDepth);
        if (t.getDepth() != this.inDepth) {
            throw new Exception("Invalid tensor depth: Is " + Integer.toString(t.getDepth()) + ", expected: " + Integer.toString(this.inDepth));
        }
        if (t.getWidth() != this.inWidth) {
            throw new Exception("Invalid tensor width: Is " + Integer.toString(t.getWidth()) + ", expected: " + Integer.toString(this.inWidth));
        }
        if (t.getHeight() != this.inHeight) {
            throw new Exception("Invalid tensor height: Is " + Integer.toString(t.getHeight()) + ", expected: " + Integer.toString(this.inHeight));
        }
        if (this.multithreading) {
            void var5_11;
            void var5_8;
            ArrayList a = new ArrayList();
            boolean bl = false;
            while (var5_8 < this.filtersPerDepth) {
                this.tasks[var5_8].t = t;
                this.tasks[var5_8].res = toReturn.getMatrixAt((int)var5_8);
                a.add(this.threadPool.submit(this.tasks[var5_8]));
                ++var5_8;
            }
            for (Future future : a) {
                while (future.get() != null) {
                }
            }
            boolean bl2 = false;
            while (var5_11 < this.filtersPerDepth) {
                toReturn.setMatrixAt((int)var5_11, this.tasks[var5_11].res);
                ++var5_11;
            }
        } else {
            int i = 0;
            while (i < this.filtersPerDepth) {
                int j = 0;
                while (j < this.inDepth) {
                    int k = 0;
                    while (k < this.outHeight) {
                        int l = 0;
                        while (l < this.outWidth) {
                            int m = 0;
                            while (m < this.filterHeight) {
                                int n = 0;
                                while (n < this.filterWidth) {
                                    double w2 = t.getMatrixAt(j).getW(k * this.stride + m, l * this.stride + n);
                                    toReturn.getMatrixAt(i).setW(k, l, toReturn.getMatrixAt(i).getW(k, l) + (w2 *= this.filters[i * this.inDepth + j].getW(this.filterHeight - m - 1, this.filterWidth - n - 1)));
                                    ++n;
                                }
                                ++m;
                            }
                            ++l;
                        }
                        ++k;
                    }
                    ++j;
                }
                ++i;
            }
        }
        if (g.applyBackprop()) {
            g.addBackprop(new Runnable(){

                /*
                 * WARNING - void declaration
                 */
                @Override
                public void run() {
                    try {
                        if (!ConvLayer.this.multithreading) {
                            int numi = 0;
                            while (numi < ConvLayer.this.filtersPerDepth) {
                                int i = 0;
                                while (i < ((ConvLayer)ConvLayer.this).patchedGradient.cols) {
                                    int j = 0;
                                    while (j < ((ConvLayer)ConvLayer.this).patchedGradient.rows) {
                                        if (i % ConvLayer.this.stride == 0 && j % ConvLayer.this.stride == 0) {
                                            ConvLayer.this.patchedGradient.setDW(j, i, toReturn.getMatrixAt(numi).getDW(j / ConvLayer.this.stride, i / ConvLayer.this.stride));
                                        } else {
                                            ConvLayer.this.patchedGradient.setDW(j, i, 0.0);
                                        }
                                        ++j;
                                    }
                                    ++i;
                                }
                                int numj = 0;
                                while (numj < ConvLayer.this.inDepth) {
                                    int k = -ConvLayer.this.filterHeight + 1;
                                    while (k < ((ConvLayer)ConvLayer.this).patchedGradient.rows) {
                                        int l = -ConvLayer.this.filterWidth + 1;
                                        while (l < ((ConvLayer)ConvLayer.this).patchedGradient.cols) {
                                            if (k % ConvLayer.this.stride == 0 && l % ConvLayer.this.stride == 0) {
                                                int m = 0;
                                                while (m < ConvLayer.this.filterHeight) {
                                                    int n = 0;
                                                    while (n < ConvLayer.this.filterWidth) {
                                                        double w2 = k + m < 0 || l + n < 0 || k + m >= ((ConvLayer)ConvLayer.this).patchedGradient.rows || l + n >= ((ConvLayer)ConvLayer.this).patchedGradient.cols ? 0.0 : ConvLayer.this.patchedGradient.getDW(k + m, l + n);
                                                        ConvLayer.this.filters[numi * ConvLayer.this.inDepth + numj].setDW(m, n, ConvLayer.this.filters[numi * ConvLayer.this.inDepth + numj].getDW(m, n) + w2 * t.getMatrixAt(numj).getW(k + ConvLayer.this.filterHeight - 1, l + ConvLayer.this.filterWidth - 1));
                                                        t.getMatrixAt(numj).setDW(k + ConvLayer.this.filterHeight - 1, l + ConvLayer.this.filterWidth - 1, t.getMatrixAt(numj).getDW(k + ConvLayer.this.filterHeight - 1, l + ConvLayer.this.filterWidth - 1) + w2 * ConvLayer.this.filters[numi * ConvLayer.this.inDepth + numj].getW(m, n));
                                                        ++n;
                                                    }
                                                    ++m;
                                                }
                                            }
                                            ++l;
                                        }
                                        ++k;
                                    }
                                    ++numj;
                                }
                                ++numi;
                            }
                        } else {
                            void var2_18;
                            void var2_15;
                            ArrayList a = new ArrayList();
                            boolean bl = false;
                            while (var2_15 < ConvLayer.this.filtersPerDepth) {
                                ConvLayer.this.backpropTasks[var2_15].toReturn = toReturn.getMatrixAt((int)var2_15);
                                ConvLayer.this.backpropTasks[var2_15].t2 = new Tensor(ConvLayer.this.inWidth, ConvLayer.this.inHeight, ConvLayer.this.inDepth);
                                a.add(ConvLayer.this.threadPool.submit(ConvLayer.this.backpropTasks[var2_15]));
                                ++var2_15;
                            }
                            for (Future future : a) {
                                while (future.get() != null) {
                                }
                            }
                            boolean bl2 = false;
                            while (var2_18 < ConvLayer.this.filtersPerDepth) {
                                int j = 0;
                                while (j < ConvLayer.this.inDepth) {
                                    int k = 0;
                                    while (k < ConvLayer.this.inWidth) {
                                        int l = 0;
                                        while (l < ConvLayer.this.inHeight) {
                                            t.getMatrixAt(j).setDW(l, k, t.getMatrixAt(j).getDW(l, k) + ConvLayer.this.backpropTasks[var2_18].t2.getMatrixAt(j).getDW(l, k));
                                            ++l;
                                        }
                                        ++k;
                                    }
                                    ++j;
                                }
                                ++var2_18;
                            }
                        }
                    }
                    catch (Exception e) {
                        g.setBackpropException(e);
                        return;
                    }
                }
            });
        }
        int i = 0;
        while (i < this.filtersPerDepth) {
            toReturn.setMatrixAt(i, g.add(toReturn.getMatrixAt(i), this.bias.getMatrixAt(i)));
            ++i;
        }
        return toReturn;
    }

    public int getOutWidth() {
        return this.outWidth;
    }

    public int getOutHeight() {
        return this.outHeight;
    }

    @Override
    public List<Matrix> getParameters() {
        ArrayList<Matrix> result = new ArrayList<Matrix>();
        Matrix[] matrixArray = this.filters;
        int n = this.filters.length;
        int n2 = 0;
        while (n2 < n) {
            Matrix l = matrixArray[n2];
            result.add(l);
            ++n2;
        }
        matrixArray = this.bias.getMatrices();
        n = matrixArray.length;
        n2 = 0;
        while (n2 < n) {
            Matrix m = matrixArray[n2];
            result.add(m);
            ++n2;
        }
        return result;
    }

    @Override
    public void resetState() {
    }

    public int getCores() {
        return this.multithreading ? this.cores : 1;
    }

    public void setCores(int cores) throws Exception {
        if (this.threadPool != null && this.threadPool.getActiveCount() != 0) {
            throw new Exception("Can't change core count while thread pool is in use");
        }
        if (this.threadPool != null) {
            this.threadPool.shutdown();
        }
        this.threadPool = (ThreadPoolExecutor)Executors.newFixedThreadPool(cores);
        this.cores = cores;
    }

    public boolean isMultithreading() {
        return this.multithreading;
    }

    public void setMultithreading(boolean mt) throws Exception {
        if (mt == this.multithreading) {
            return;
        }
        if (!mt) {
            if (this.threadPool != null && this.threadPool.getActiveCount() != 0) {
                throw new Exception("Can't disable multithreading while thread pool is in use");
            }
            if (this.threadPool != null) {
                this.threadPool.shutdown();
            }
        } else {
            this.threadPool = (ThreadPoolExecutor)Executors.newFixedThreadPool(this.cores);
        }
        this.multithreading = mt;
    }

    @Override
    public void saveState(FileOutputStream fos) throws Exception {
        fos.write(ByteConverters.intToBytes(this.stride));
        fos.write(ByteConverters.intToBytes(this.inWidth));
        fos.write(ByteConverters.intToBytes(this.inHeight));
        fos.write(ByteConverters.intToBytes(this.inDepth));
        fos.write(ByteConverters.intToBytes(this.filterWidth));
        fos.write(ByteConverters.intToBytes(this.filterHeight));
        fos.write(ByteConverters.intToBytes(this.filtersPerDepth));
        fos.write(ByteConverters.intToBytes(this.outWidth));
        fos.write(ByteConverters.intToBytes(this.outHeight));
        this.bias.save(fos);
        fos.write(ByteConverters.intToBytes(this.filters.length));
        Matrix[] matrixArray = this.filters;
        int n = this.filters.length;
        int n2 = 0;
        while (n2 < n) {
            Matrix m = matrixArray[n2];
            m.save(fos);
            ++n2;
        }
    }

    @Override
    public void loadState(FileInputStream fis) throws Exception {
        byte[] intBuffer = new byte[4];
        fis.read(intBuffer);
        this.stride = ByteConverters.bytesToInt(intBuffer);
        fis.read(intBuffer);
        this.inWidth = ByteConverters.bytesToInt(intBuffer);
        fis.read(intBuffer);
        this.inHeight = ByteConverters.bytesToInt(intBuffer);
        fis.read(intBuffer);
        this.inDepth = ByteConverters.bytesToInt(intBuffer);
        fis.read(intBuffer);
        this.filterWidth = ByteConverters.bytesToInt(intBuffer);
        fis.read(intBuffer);
        this.filterHeight = ByteConverters.bytesToInt(intBuffer);
        fis.read(intBuffer);
        this.filtersPerDepth = ByteConverters.bytesToInt(intBuffer);
        fis.read(intBuffer);
        this.outWidth = ByteConverters.bytesToInt(intBuffer);
        fis.read(intBuffer);
        this.outHeight = ByteConverters.bytesToInt(intBuffer);
        this.bias.load(fis);
        fis.read(intBuffer);
        this.filters = new Matrix[ByteConverters.bytesToInt(intBuffer)];
        int i = 0;
        while (i < this.filters.length) {
            this.filters[i] = new Matrix(this.filterHeight, this.filterWidth);
            this.filters[i].load(fis);
            ++i;
        }
        this.tasks = new FilterThread[this.filtersPerDepth];
        i = 0;
        while (i < this.filtersPerDepth) {
            this.tasks[i] = new FilterThread(i);
            ++i;
        }
        this.backpropTasks = new BackpropThread[this.filtersPerDepth];
        i = 0;
        while (i < this.filtersPerDepth) {
            this.backpropTasks[i] = new BackpropThread(i);
            ++i;
        }
    }

    private class BackpropThread
    implements Runnable {
        private int numi;
        private Matrix patchedGrad;
        private Matrix toReturn;
        private Tensor t2;

        private BackpropThread(int i) {
            this.patchedGrad = new Matrix(ConvLayer.this.inHeight - ConvLayer.this.filterHeight + 1, ConvLayer.this.inWidth - ConvLayer.this.filterWidth + 1);
            this.numi = i;
        }

        @Override
        public void run() {
            int i = 0;
            while (i < this.patchedGrad.cols) {
                int j = 0;
                while (j < this.patchedGrad.rows) {
                    if (i % ConvLayer.this.stride == 0 && j % ConvLayer.this.stride == 0) {
                        this.patchedGrad.setDW(j, i, this.toReturn.getDW(j / ConvLayer.this.stride, i / ConvLayer.this.stride));
                    } else {
                        this.patchedGrad.setDW(j, i, 0.0);
                    }
                    ++j;
                }
                ++i;
            }
            int numj = 0;
            while (numj < ConvLayer.this.inDepth) {
                int k = -ConvLayer.this.filterHeight + 1;
                while (k < this.patchedGrad.rows) {
                    int l = -ConvLayer.this.filterWidth + 1;
                    while (l < this.patchedGrad.cols) {
                        if (k % ConvLayer.this.stride == 0 && l % ConvLayer.this.stride == 0) {
                            int m = 0;
                            while (m < ConvLayer.this.filterHeight) {
                                int n = 0;
                                while (n < ConvLayer.this.filterWidth) {
                                    double w2 = k + m < 0 || l + n < 0 || k + m >= this.patchedGrad.rows || l + n >= this.patchedGrad.cols ? 0.0 : this.patchedGrad.getDW(k + m, l + n);
                                    ConvLayer.this.filters[this.numi * ConvLayer.this.inDepth + numj].setDW(m, n, ConvLayer.this.filters[this.numi * ConvLayer.this.inDepth + numj].getDW(m, n) + w2 * this.t2.getMatrixAt(numj).getW(k + ConvLayer.this.filterHeight - 1, l + ConvLayer.this.filterWidth - 1));
                                    this.t2.getMatrixAt(numj).setDW(k + ConvLayer.this.filterHeight - 1, l + ConvLayer.this.filterWidth - 1, this.t2.getMatrixAt(numj).getDW(k + ConvLayer.this.filterHeight - 1, l + ConvLayer.this.filterWidth - 1) + w2 * ConvLayer.this.filters[this.numi * ConvLayer.this.inDepth + numj].getW(m, n));
                                    ++n;
                                }
                                ++m;
                            }
                        }
                        ++l;
                    }
                    ++k;
                }
                ++numj;
            }
        }
    }

    private class FilterThread
    implements Runnable {
        private int i;
        private Matrix res;
        private Tensor t;

        private FilterThread(int i) {
            this.res = new Matrix(ConvLayer.this.outHeight, ConvLayer.this.outWidth);
            this.i = i;
        }

        @Override
        public void run() {
            int j = 0;
            while (j < ConvLayer.this.inDepth) {
                int k = 0;
                while (k < ConvLayer.this.outHeight) {
                    int l = 0;
                    while (l < ConvLayer.this.outWidth) {
                        int m = 0;
                        while (m < ConvLayer.this.filterHeight) {
                            int n = 0;
                            while (n < ConvLayer.this.filterWidth) {
                                double w2 = this.t.getMatrixAt(j).getW(k * ConvLayer.this.stride + m, l * ConvLayer.this.stride + n);
                                this.res.setW(k, l, this.res.getW(k, l) + (w2 *= ConvLayer.this.filters[this.i * ConvLayer.this.inDepth + j].getW(ConvLayer.this.filterHeight - m - 1, ConvLayer.this.filterWidth - n - 1)));
                                ++n;
                            }
                            ++m;
                        }
                        ++l;
                    }
                    ++k;
                }
                ++j;
            }
        }
    }
}

