/*
 * Decompiled with CFR 0.152.
 */
package model;

import autodiff.GPUGraph;
import autodiff.Graph;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.UUID;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadPoolExecutor;
import matrix.Matrix;
import matrix.Tensor;
import model.TensorLayer;
import nonlinearities.Nonlinearity;

public class ConvLayer
implements TensorLayer {
    private int stride;
    private int pad;
    private int inWidth;
    private int inHeight;
    private int inDepth;
    private int filterWidth;
    private int filterHeight;
    private int filtersPerDepth;
    private int outWidth;
    private int outHeight;
    private Matrix[] filters;
    private Matrix biases;
    private boolean gpu = false;
    private boolean multithreading = false;
    private int cores = 1;
    private ThreadPoolExecutor threadPool;
    private FilterThread[] tasks;
    private Nonlinearity nonlin;
    private final UUID parameterUUID = UUID.randomUUID();
    private Tensor multicoreT;

    public ConvLayer(int inWidth, int inHeight, int inDepth, int filterWidth, int filterHeight, int numFiltersPerDepth, int stride, int pad, Nonlinearity nonlin, double initParamsStdDev, Random rng, boolean gpu, boolean multithreading, int cores) {
        this(inWidth, inHeight, inDepth, filterWidth, filterHeight, numFiltersPerDepth, stride, pad, nonlin, initParamsStdDev, rng);
        if (gpu && !multithreading) {
            cores = 1;
            multithreading = true;
        }
        this.gpu = gpu;
        this.multithreading = multithreading;
        this.threadPool = (ThreadPoolExecutor)Executors.newFixedThreadPool(cores);
        this.cores = cores;
    }

    public ConvLayer(int inWidth, int inHeight, int inDepth, int filterWidth, int filterHeight, int numFiltersPerDepth, int stride, int pad, Nonlinearity nonlin, double initParamsStdDev, Random rng) {
        this.stride = stride;
        this.pad = pad;
        this.nonlin = nonlin;
        this.inWidth = inWidth;
        this.inHeight = inHeight;
        this.inDepth = inDepth;
        this.filterHeight = filterHeight;
        this.filterWidth = filterWidth;
        this.filtersPerDepth = numFiltersPerDepth;
        this.outWidth = (inWidth + pad * 2 - filterWidth) / stride + 1;
        this.outHeight = (inHeight + pad * 2 - filterHeight) / stride + 1;
        this.filters = new Matrix[numFiltersPerDepth * inDepth];
        int i = 0;
        while (i < this.filters.length) {
            this.filters[i] = Matrix.rand(filterHeight, filterWidth, initParamsStdDev, rng);
            ++i;
        }
        this.initTasks();
        this.biases = Matrix.uniform(numFiltersPerDepth, 1, 0.0);
        this.multithreading = false;
        this.gpu = false;
    }

    private void initTasks() {
        this.multicoreT = new Tensor(this.inWidth, this.inHeight, this.inDepth);
        int numTasks = this.filtersPerDepth;
        this.tasks = new FilterThread[numTasks];
        int i = 0;
        while (i < numTasks) {
            this.tasks[i] = new FilterThread(i);
            ++i;
        }
    }

    public void setFilter(int indx, Matrix filter) {
        if (filter.rows != this.filterHeight && filter.cols != this.filterWidth) {
            return;
        }
        if (indx >= this.filters.length) {
            return;
        }
        this.filters[indx] = filter;
    }

    public Matrix getFilter(int indx) {
        return this.filters[indx];
    }

    private synchronized void updateFilterGrad(int filterIndx, double[] dwBuffer) {
        Matrix filter = this.filters[filterIndx];
        int i = 0;
        while (i < filter.dw.length) {
            int n = i;
            filter.dw[n] = filter.dw[n] + dwBuffer[i];
            ++i;
        }
    }

    private synchronized void updateMulticoreTGrad(int matrixIndx, double[] dwBuffer) {
        Matrix m = this.multicoreT.matrices[matrixIndx];
        int i = 0;
        while (i < m.dw.length) {
            m.dw[i] = dwBuffer[i];
            ++i;
        }
    }

    /*
     * WARNING - void declaration
     */
    @Override
    public Tensor forward(final Tensor t, final Graph g) throws Exception {
        Tensor toReturn;
        Tensor res;
        if (t.depth != this.inDepth) {
            throw new Exception("Invalid tensor depth: Is " + Integer.toString(t.depth) + ", expected: " + Integer.toString(this.inDepth));
        }
        if (t.width != this.inWidth) {
            throw new Exception("Invalid tensor width: Is " + Integer.toString(t.width) + ", expected: " + Integer.toString(this.inWidth));
        }
        if (t.height != this.inHeight) {
            throw new Exception("Invalid tensor height: Is " + Integer.toString(t.height) + ", expected: " + Integer.toString(this.inHeight));
        }
        final Runnable[] rs = new Runnable[this.tasks.length];
        if (this.gpu && g instanceof GPUGraph) {
            toReturn = res = ((GPUGraph)g).fullConv(g.isApplyingBackprop(), this.pad, this.stride, this.filtersPerDepth, t, this.biases, this.filters, this.nonlin, this.parameterUUID);
        } else if (this.multithreading) {
            void var6_15;
            void var6_13;
            void var6_10;
            void var6_8;
            toReturn = new Tensor(this.outWidth, this.outHeight, this.filtersPerDepth);
            ArrayList a = new ArrayList();
            boolean bl = false;
            while (var6_8 < this.inDepth) {
                System.arraycopy(t.matrices[var6_8].w, 0, this.multicoreT.matrices[var6_8].w, 0, this.inWidth * this.inHeight);
                ++var6_8;
            }
            boolean bl2 = false;
            while (var6_10 < this.filtersPerDepth) {
                this.tasks[var6_10].res.w = new double[this.outWidth * this.outHeight];
                this.tasks[var6_10].doBackprop = g.isApplyingBackprop();
                a.add(this.threadPool.submit(this.tasks[var6_10]));
                ++var6_10;
            }
            for (Future future : a) {
                while (future.get() != null) {
                }
            }
            boolean bl3 = false;
            while (var6_13 < this.filtersPerDepth) {
                System.arraycopy(this.tasks[var6_13].res.w, 0, toReturn.matrices[var6_13].w, 0, this.outWidth * this.outHeight);
                ++var6_13;
            }
            boolean bl4 = false;
            while (var6_15 < this.tasks.length) {
                rs[var6_15] = this.tasks[var6_15].r;
                ++var6_15;
            }
        } else {
            void var6_17;
            toReturn = new Tensor(this.outWidth, this.outHeight, this.filtersPerDepth);
            res = g.convolution(this.pad, this.stride, this.filtersPerDepth, t, this.filters);
            boolean bl = false;
            while (var6_17 < res.depth) {
                Matrix m = res.matrices[var6_17];
                int j = 0;
                while (j < m.w.length) {
                    int n = j++;
                    m.w[n] = m.w[n] + this.biases.w[var6_17];
                }
                toReturn.matrices[var6_17] = m;
                ++var6_17;
            }
        }
        if (!(!g.isApplyingBackprop() || this.gpu && g instanceof GPUGraph)) {
            g.addBackprop(new Runnable(){

                /*
                 * WARNING - void declaration
                 */
                @Override
                public void run() {
                    try {
                        int i = 0;
                        while (i < ConvLayer.this.filtersPerDepth) {
                            Matrix matrix = toReturn.matrices[i];
                            int j = 0;
                            while (j < matrix.w.length) {
                                int n = i;
                                ConvLayer.this.biases.dw[n] = ConvLayer.this.biases.dw[n] + matrix.dw[j];
                                ++j;
                            }
                            ++i;
                        }
                        if (ConvLayer.this.multithreading) {
                            void var2_11;
                            void var2_8;
                            void var2_6;
                            ArrayList a = new ArrayList();
                            boolean bl = false;
                            while (var2_6 < ConvLayer.this.inDepth) {
                                System.arraycopy(t.matrices[var2_6].w, 0, ConvLayer.this.multicoreT.matrices[var2_6].w, 0, ConvLayer.this.inWidth * ConvLayer.this.inHeight);
                                ++var2_6;
                            }
                            boolean bl2 = false;
                            while (var2_8 < ConvLayer.this.filtersPerDepth) {
                                System.arraycopy(toReturn.matrices[var2_8].dw, 0, ConvLayer.this.tasks[var2_8].res.dw, 0, ConvLayer.this.filterWidth * ConvLayer.this.filterHeight);
                                a.add(ConvLayer.this.threadPool.submit(rs[var2_8]));
                                ++var2_8;
                            }
                            for (Future future : a) {
                                while (future.get() != null) {
                                }
                            }
                            boolean bl3 = false;
                            while (var2_11 < ConvLayer.this.inDepth) {
                                Matrix inm = t.matrices[var2_11];
                                Matrix taskm = ConvLayer.this.multicoreT.matrices[var2_11];
                                int k = 0;
                                while (k < ConvLayer.this.inWidth * ConvLayer.this.inHeight) {
                                    int n = k;
                                    inm.dw[n] = inm.dw[n] + taskm.dw[k];
                                    ++k;
                                }
                                k = 0;
                                while (k < ConvLayer.this.inWidth * ConvLayer.this.inHeight) {
                                    taskm.dw[k] = 0.0;
                                    ++k;
                                }
                                ++var2_11;
                            }
                        }
                    }
                    catch (Exception e) {
                        g.setBackpropException(e);
                        return;
                    }
                }
            });
        }
        if (!this.gpu || !(g instanceof GPUGraph)) {
            int i = 0;
            while (i < this.filtersPerDepth) {
                Matrix matrix = toReturn.matrices[i];
                toReturn.matrices[i] = g.nonlin(this.nonlin, matrix);
                ++i;
            }
        }
        return toReturn;
    }

    public int getOutWidth() {
        return this.outWidth;
    }

    public int getOutHeight() {
        return this.outHeight;
    }

    public int getFilterCount() {
        return this.filters.length;
    }

    public int getFilterWidth() {
        return this.filterWidth;
    }

    public int getFilterHeight() {
        return this.filterHeight;
    }

    @Override
    public List<Matrix> getParameters() {
        ArrayList<Matrix> result = new ArrayList<Matrix>();
        Matrix[] matrixArray = this.filters;
        int n = this.filters.length;
        int n2 = 0;
        while (n2 < n) {
            Matrix l = matrixArray[n2];
            result.add(l);
            ++n2;
        }
        result.add(this.biases);
        return result;
    }

    @Override
    public void resetState() {
    }

    public int getCores() {
        return this.multithreading ? this.cores : 1;
    }

    public void setCores(int cores) throws Exception {
        if (this.threadPool != null && this.threadPool.getActiveCount() != 0) {
            throw new Exception("Can't change core count while thread pool is in use");
        }
        if (this.threadPool != null) {
            this.threadPool.shutdown();
        }
        this.threadPool = (ThreadPoolExecutor)Executors.newFixedThreadPool(cores);
        this.cores = cores;
    }

    public boolean isMultithreading() {
        return this.multithreading;
    }

    public void setMultithreading(boolean mt) throws Exception {
        if (mt == this.multithreading) {
            return;
        }
        if (!mt) {
            if (this.threadPool != null && this.threadPool.getActiveCount() != 0) {
                throw new Exception("Can't disable multithreading while thread pool is in use");
            }
            if (this.threadPool != null) {
                this.threadPool.shutdown();
            }
        } else {
            this.threadPool = (ThreadPoolExecutor)Executors.newFixedThreadPool(this.cores);
        }
        this.multithreading = mt;
    }

    public boolean isGPUAccelerating() {
        return this.gpu;
    }

    public void setGPUAccelerating(boolean gpu) {
        this.gpu = gpu;
    }

    @Override
    public TensorLayer clone() {
        ConvLayer clone = new ConvLayer(this.inWidth, this.inHeight, this.inDepth, this.filterWidth, this.filterHeight, this.filtersPerDepth, this.stride, this.pad, this.nonlin, 1.0, new Random(), this.gpu, this.multithreading, this.cores);
        clone.biases = this.biases.clone();
        int i = 0;
        while (i < this.filters.length) {
            clone.filters[i] = this.filters[i].clone();
            ++i;
        }
        return clone;
    }

    @Override
    public void saveState(DataOutputStream fos) throws Exception {
        fos.writeInt(this.stride);
        fos.writeInt(this.inWidth);
        fos.writeInt(this.inHeight);
        fos.writeInt(this.inDepth);
        fos.writeInt(this.filterWidth);
        fos.writeInt(this.filterHeight);
        fos.writeInt(this.filtersPerDepth);
        fos.writeInt(this.outWidth);
        fos.writeInt(this.outHeight);
        fos.writeInt(this.pad);
        this.biases.save(fos);
        fos.writeInt(this.filters.length);
        int i = 0;
        while (i < this.filters.length) {
            this.filters[i].save(fos);
            ++i;
        }
    }

    @Override
    public void loadState(DataInputStream fis) throws Exception {
        this.stride = fis.readInt();
        this.inWidth = fis.readInt();
        this.inHeight = fis.readInt();
        this.inDepth = fis.readInt();
        this.filterWidth = fis.readInt();
        this.filterHeight = fis.readInt();
        this.filtersPerDepth = fis.readInt();
        this.outWidth = fis.readInt();
        this.outHeight = fis.readInt();
        this.pad = fis.readInt();
        this.biases.load(fis);
        this.filters = new Matrix[fis.readInt()];
        int i = 0;
        while (i < this.filters.length) {
            this.filters[i] = new Matrix(this.filterHeight, this.filterWidth);
            this.filters[i].load(fis);
            ++i;
        }
        this.initTasks();
    }

    private class FilterThread
    implements Runnable {
        private int i;
        private Matrix res;
        private Tensor t;
        private Runnable r;
        private boolean doBackprop;

        private FilterThread(int i) {
            this.res = new Matrix(ConvLayer.this.outHeight, ConvLayer.this.outWidth, false);
            this.i = i;
            this.t = ConvLayer.this.multicoreT;
        }

        @Override
        public void run() {
            int j = 0;
            while (j < this.t.depth) {
                Matrix inm = this.t.matrices[j];
                Matrix currFilter = ConvLayer.this.filters[this.i * this.t.depth + j];
                int k = 0;
                while (k < this.res.rows) {
                    int l = 0;
                    while (l < this.res.cols) {
                        int outIndx = k * this.res.cols + l;
                        int y = k * ConvLayer.this.stride - ConvLayer.this.pad;
                        int loopEnd1 = y + ConvLayer.this.filters[0].rows - 1 >= this.t.height ? ConvLayer.this.filters[0].rows + (this.t.height - (y + ConvLayer.this.filters[0].rows - 1)) - 1 : ConvLayer.this.filters[0].rows;
                        int x = l * ConvLayer.this.stride - ConvLayer.this.pad;
                        int loopEnd2 = x + ConvLayer.this.filters[0].cols - 1 >= this.t.width ? ConvLayer.this.filters[0].cols + (this.t.width - (x + ConvLayer.this.filters[0].cols - 1)) - 1 : ConvLayer.this.filters[0].cols;
                        int m = y < 0 ? -y : 0;
                        while (m < loopEnd1) {
                            int inIndx = (y + m) * inm.cols + x;
                            int filterIndx = m * currFilter.cols;
                            int n = x < 0 ? -x : 0;
                            while (n < loopEnd2) {
                                int n2 = outIndx;
                                this.res.w[n2] = this.res.w[n2] + inm.w[inIndx + n] * currFilter.w[filterIndx + n];
                                ++n;
                            }
                            ++m;
                        }
                        ++l;
                    }
                    ++k;
                }
                ++j;
            }
            if (this.doBackprop) {
                this.r = new Runnable(){

                    @Override
                    public void run() {
                        int j = 0;
                        while (j < FilterThread.this.t.depth) {
                            Matrix inm = FilterThread.this.t.matrices[j];
                            Matrix currFilter = ((FilterThread)FilterThread.this).ConvLayer.this.filters[FilterThread.this.i * FilterThread.this.t.depth + j];
                            double[] filterDW = new double[currFilter.dw.length];
                            double[] tDW = new double[((FilterThread)FilterThread.this).ConvLayer.this.inWidth * ((FilterThread)FilterThread.this).ConvLayer.this.inHeight];
                            int k = 0;
                            while (k < FilterThread.this.res.rows) {
                                int l = 0;
                                while (l < FilterThread.this.res.cols) {
                                    double grad = FilterThread.this.res.dw[k * FilterThread.this.res.cols + l];
                                    int y = k * ((FilterThread)FilterThread.this).ConvLayer.this.stride - ((FilterThread)FilterThread.this).ConvLayer.this.pad;
                                    int loopEnd1 = y + ((FilterThread)FilterThread.this).ConvLayer.this.filters[0].rows - 1 >= FilterThread.this.t.height ? ((FilterThread)FilterThread.this).ConvLayer.this.filters[0].rows + (FilterThread.this.t.height - (y + ((FilterThread)FilterThread.this).ConvLayer.this.filters[0].rows - 1)) - 1 : ((FilterThread)FilterThread.this).ConvLayer.this.filters[0].rows;
                                    int x = l * ((FilterThread)FilterThread.this).ConvLayer.this.stride - ((FilterThread)FilterThread.this).ConvLayer.this.pad;
                                    int loopEnd2 = x + ((FilterThread)FilterThread.this).ConvLayer.this.filters[0].cols - 1 >= FilterThread.this.t.width ? ((FilterThread)FilterThread.this).ConvLayer.this.filters[0].cols + (FilterThread.this.t.width - (x + ((FilterThread)FilterThread.this).ConvLayer.this.filters[0].cols - 1)) - 1 : ((FilterThread)FilterThread.this).ConvLayer.this.filters[0].cols;
                                    int m = y < 0 ? -y : 0;
                                    while (m < loopEnd1) {
                                        int inIndx = (y + m) * inm.cols + x;
                                        int filterIndx = m * currFilter.cols;
                                        int n = x < 0 ? -x : 0;
                                        while (n < loopEnd2) {
                                            int n2 = inIndx + n;
                                            tDW[n2] = tDW[n2] + currFilter.w[filterIndx + n] * grad;
                                            ++n;
                                        }
                                        n = x < 0 ? -x : 0;
                                        while (n < loopEnd2) {
                                            int n3 = filterIndx + n;
                                            filterDW[n3] = filterDW[n3] + inm.w[inIndx + n] * grad;
                                            ++n;
                                        }
                                        ++m;
                                    }
                                    ++l;
                                }
                                ++k;
                            }
                            ConvLayer.this.updateFilterGrad(FilterThread.this.i * FilterThread.this.t.depth + j, filterDW);
                            ConvLayer.this.updateMulticoreTGrad(j, tDW);
                            ++j;
                        }
                    }
                };
            }
        }
    }
}

