/*
 * Decompiled with CFR 0.152.
 */
package theGhastModding.lstmStuff.gpu;

import com.aparapi.Kernel;
import com.aparapi.Range;
import matrix.Matrix;
import matrix.Tensor;

public class GPUUtils {
    // One reusable kernel instance per operation. Each public method below
    // re-arms the matching kernel through its set* method before executing it,
    // so the references themselves never change and are declared final.

    /** Kernel behind {@link #mul(Matrix, Matrix)}. */
    private final MatrixMultiplier segmem = new MatrixMultiplier();
    /** Kernel behind the {@code adam(...)} overloads. */
    private final Adam adam = new Adam();
    /** Kernel behind {@code RMSProp(...)}. */
    private final RMSProp rmsProp = new RMSProp();
    /** Kernel behind {@link #matrixVectorMul(Matrix, Matrix)}. */
    private final MatrixVectorMul mul = new MatrixVectorMul();
    /** Kernel behind {@code grayscaleFilter(...)} and {@code grayscaleFilter2(...)}. */
    private final GrayscaleFilter grayscaleFilter = new GrayscaleFilter();
    /** Kernel behind {@code applyFilter(...)}. */
    private final ApplyFilter applyFilter = new ApplyFilter();

    static {
        // Set once at class-load time, before any kernel of this class is created.
        System.setProperty("com.aparapi.enableShowGeneratedOpenCL", "false");
    }

    /**
     * Runs the grayscale kernel over {@code in} and wraps the kernel's result
     * in a new single-channel tensor.
     *
     * @param in tensor whose matrices at index 0, 1 and 2 are read by the kernel
     * @return a new {@code Tensor(in.getWidth(), in.getHeight(), 1)} whose
     *         matrix 0 is the kernel result
     */
    public Tensor grayscaleFilter(Tensor in) {
        this.grayscaleFilter.execute(this.grayscaleFilter.setArguments(in));

        Tensor gray = new Tensor(in.getWidth(), in.getHeight(), 1);
        Matrix channel = this.grayscaleFilter.getResult();
        gray.setMatrixAt(0, channel);
        return gray;
    }

    /**
     * Same kernel run as {@code grayscaleFilter(Tensor)}, but hands back the
     * kernel's result matrix directly instead of wrapping it in a tensor.
     *
     * @param in tensor whose matrices at index 0, 1 and 2 are read by the kernel
     * @return the matrix produced by the grayscale kernel
     */
    public Matrix grayscaleFilter2(Tensor in) {
        Range workItems = this.grayscaleFilter.setArguments(in);
        this.grayscaleFilter.execute(workItems);
        Matrix gray = this.grayscaleFilter.getResult();
        return gray;
    }

    /**
     * Arms the {@code MatrixVectorMul} kernel with the two operands, executes
     * it, and returns whatever the kernel reports as its result.
     *
     * @param matrix left operand handed to the kernel
     * @param vector right operand handed to the kernel
     * @return the kernel's result matrix
     */
    public Matrix matrixVectorMul(Matrix matrix, Matrix vector) {
        Range workItems = this.mul.setArguments(matrix, vector);
        this.mul.execute(workItems);
        Matrix product = this.mul.getResult();
        return product;
    }

    /**
     * Executes the {@code ApplyFilter} kernel and adds up its per-input result
     * matrices element by element into a single output matrix.
     *
     * @param in        input matrices, one per kernel slice
     * @param filters   filter matrices passed to the kernel
     * @param stride    step between filter positions, passed to the kernel
     * @param outWidth  width of the output
     * @param outHeight height of the output
     * @return a new {@code Matrix(outHeight, outWidth)} holding the element-wise
     *         sum of all kernel result matrices
     * @throws Exception declared for callers; nothing in this method throws it explicitly
     */
    public Matrix applyFilter(Matrix[] in, Matrix[] filters, int stride, int outWidth, int outHeight) throws Exception {
        Range workItems = this.applyFilter.setParameters(in, filters, stride, outWidth, outHeight);
        this.applyFilter.execute(workItems);

        Matrix[] partials = this.applyFilter.getResults();
        Matrix summed = new Matrix(outHeight, outWidth);
        for (int p = 0; p < partials.length; p++) {
            for (int e = 0; e < summed.w.length; e++) {
                summed.w[e] += partials[p].w[e];
            }
        }
        return summed;
    }

    /**
     * Executes one run of the {@code RMSProp} kernel for a single matrix.
     *
     * @param in                matrix supplying {@code w}, {@code dw} and {@code stepCache}
     * @param decay             decay factor passed to the kernel
     * @param learningRate      learning rate passed to the kernel
     * @param epsilon           epsilon passed to the kernel
     * @param gradientClipValue clip bound passed to the kernel
     * @return the matrix assembled by the kernel's {@code getResult()}
     * @throws Exception declared for callers; nothing in this method throws it explicitly
     */
    public Matrix RMSProp(Matrix in, double decay, double learningRate, double epsilon, double gradientClipValue) throws Exception {
        Range workItems = this.rmsProp.setArguments(in, decay, learningRate, epsilon, gradientClipValue);
        this.rmsProp.execute(workItems);
        Matrix updated = this.rmsProp.getResult();
        return updated;
    }

    /**
     * Applies the single-matrix {@code adam} overload to every element of
     * {@code in}, in index order, with the same hyper-parameters.
     *
     * @param in            matrices to update
     * @param decayB1       first decay factor
     * @param decayB2       second decay factor
     * @param t             step counter
     * @param learningRate  learning rate
     * @param smoothEpsilon epsilon used by the kernel
     * @return a new array of the same length; element {@code i} is the result for {@code in[i]}
     * @throws Exception propagated from the single-matrix overload
     */
    public Matrix[] adam(Matrix[] in, double decayB1, double decayB2, int t, double learningRate, double smoothEpsilon) throws Exception {
        final int count = in.length;
        Matrix[] updated = new Matrix[count];
        for (int idx = 0; idx < count; idx++) {
            updated[idx] = this.adam(in[idx], decayB1, decayB2, t, learningRate, smoothEpsilon);
        }
        return updated;
    }

    /**
     * Executes one run of the {@code Adam} kernel for a single matrix.
     *
     * @param in            matrix supplying {@code w}, {@code dw} and {@code stepCache}
     * @param decayB1       first decay factor
     * @param decayB2       second decay factor
     * @param t             step counter
     * @param learningRate  learning rate
     * @param smoothEpsilon epsilon used by the kernel
     * @return the matrix assembled by the kernel's {@code getResult()}
     * @throws Exception declared for callers; nothing in this method throws it explicitly
     */
    public Matrix adam(Matrix in, double decayB1, double decayB2, int t, double learningRate, double smoothEpsilon) throws Exception {
        Range workItems = this.adam.setArguments(in, decayB1, decayB2, t, learningRate, smoothEpsilon);
        this.adam.execute(workItems);
        Matrix updated = this.adam.getResult();
        return updated;
    }

    /**
     * Multiplies two matrices with the {@code MatrixMultiplier} kernel.
     *
     * @param m1 left operand
     * @param m2 right operand
     * @return the kernel's result matrix
     * @throws Exception if {@code m1.cols} differs from {@code m2.rows}
     */
    public Matrix mul(Matrix m1, Matrix m2) throws Exception {
        boolean compatible = m1.cols == m2.rows;
        if (!compatible) {
            throw new Exception("matrix dimension mismatch");
        }

        Range workItems = this.segmem.setArgs(m1, m2);
        this.segmem.execute(workItems);
        Matrix product = this.segmem.getResultMatrix();
        return product;
    }

    /**
     * Aparapi kernel performing an Adam-style update, one work item per
     * matrix element.
     *
     * <p>{@code stepCache} is indexed at {@code idx} and at {@code idx + length},
     * i.e. it holds two values per weight: the first block is combined with the
     * gradient, the second with the squared gradient.
     */
    class Adam extends Kernel {
        double[] w;
        double[] dw;
        double decayB1;
        double decayB2;
        double t;
        double[] stepCache;
        double[] outStepCache;
        double[] outW;
        int rows;
        int cols;
        int length;
        double learningRate;
        double smoothEpsilon;

        Adam() {
        }

        /**
         * Updates the element addressed by the two global ids and writes the
         * new weight and both new cache entries to the output arrays.
         */
        @Override
        public void run() {
            int row = this.getGlobalId(0);
            int col = this.getGlobalId(1);
            int first = this.cols * row + col;
            int second = first + this.length;

            // Step-dependent decay factors; both are recomputed by every work item.
            double b1 = this.decayB1 * (1.0 - this.pow(this.decayB1, this.t - 1.0)) / (1.0 - this.pow(this.decayB1, this.t));
            double b2 = this.decayB2 * (1.0 - this.pow(this.decayB2, this.t - 1.0)) / (1.0 - this.pow(this.decayB2, this.t));

            double m = b1 * this.stepCache[first] + (1.0 - b1) * this.dw[first];
            this.outStepCache[first] = m;

            double v = b2 * this.stepCache[second] + (1.0 - b2) * (this.dw[first] * this.dw[first]);
            this.outStepCache[second] = v;

            // The epsilon is added inside the square root.
            this.outW[first] = this.w[first] - this.learningRate * m / this.sqrt(v + this.smoothEpsilon);
        }

        /**
         * Binds the kernel to {@code in} and the given hyper-parameters and
         * allocates fresh output arrays.
         *
         * @return a 2D range of {@code in.rows} by {@code in.cols} work items
         */
        public Range setArguments(Matrix in, double decayB1, double decayB2, int t, double learningRate, double smoothEpsilon) {
            // Inputs (shared with the caller's matrix, not copied).
            this.w = in.w;
            this.dw = in.dw;
            this.decayB1 = decayB1;
            this.decayB2 = decayB2;
            this.t = t;
            this.stepCache = in.stepCache;

            // Outputs are new arrays, so the run does not write into the inputs.
            this.outStepCache = new double[in.stepCache.length];
            this.outW = new double[in.w.length];

            this.rows = in.rows;
            this.cols = in.cols;
            this.length = in.w.length;
            this.learningRate = learningRate;
            this.smoothEpsilon = smoothEpsilon;
            return Range.create2D(in.rows, in.cols);
        }

        /**
         * @return a new {@code Matrix(rows, cols)} whose {@code w} and
         *         {@code stepCache} are this kernel's output arrays
         */
        public Matrix getResult() {
            Matrix updated = new Matrix(this.rows, this.cols);
            updated.w = this.outW;
            updated.stepCache = this.outStepCache;
            return updated;
        }
    }

    /**
     * Aparapi kernel that slides a filter over each input matrix. One work
     * item handles one output element of one input slice; the filter is read
     * back-to-front in both dimensions.
     *
     * <p>Slice {@code j} of the input is always paired with {@code filters[j]};
     * nothing here checks that {@code filters} has at least as many entries
     * as {@code in}.
     */
    class ApplyFilter extends Kernel {
        private double[][] in;
        private double[][] out;
        private double[][] filters;
        int filterHeight;
        int filterWidth;
        int stride;
        int inWidth;
        int outWidth;
        int outHeight;

        ApplyFilter() {
        }

        /**
         * Accumulates the filter response for the output element addressed by
         * global ids (slice, output row, output column).
         */
        @Override
        public void run() {
            int slice = this.getGlobalId(0);
            int outRow = this.getGlobalId(1);
            int outCol = this.getGlobalId(2);

            for (int fy = 0; fy < this.filterHeight; fy++) {
                for (int fx = 0; fx < this.filterWidth; fx++) {
                    double sample = this.in[slice][(outRow * this.stride + fy) * this.inWidth + (outCol * this.stride + fx)];
                    // Mirrored filter index: last filter row/column is applied first.
                    double weight = this.filters[slice][(this.filterHeight - fy - 1) * this.filterWidth + (this.filterWidth - fx - 1)];
                    this.out[slice][outRow * this.outWidth + outCol] += sample * weight;
                }
            }
        }

        /**
         * Binds inputs and filters, allocates a zeroed output per input slice,
         * and takes the filter size from {@code filters[0]} and the input row
         * width from {@code in[0]}.
         *
         * @return a 3D range of {@code in.length} x {@code outHeight} x {@code outWidth}
         */
        public Range setParameters(Matrix[] in, Matrix[] filters, int stride, int outWidth, int outHeight) {
            this.out = new double[in.length][outWidth * outHeight];

            this.in = new double[in.length][];
            for (int s = 0; s < in.length; s++) {
                this.in[s] = in[s].w;
            }

            this.filters = new double[filters.length][];
            for (int f = 0; f < filters.length; f++) {
                this.filters[f] = filters[f].w;
            }

            this.filterHeight = filters[0].rows;
            this.filterWidth = filters[0].cols;
            this.stride = stride;
            this.outWidth = outWidth;
            this.outHeight = outHeight;
            this.inWidth = in[0].cols;
            return Range.create3D(in.length, outHeight, outWidth);
        }

        /**
         * @return one new {@code Matrix(outHeight, outWidth)} per input slice,
         *         each backed directly by that slice's output array
         */
        public Matrix[] getResults() {
            Matrix[] results = new Matrix[this.out.length];
            for (int s = 0; s < results.length; s++) {
                Matrix slice = new Matrix(this.outHeight, this.outWidth);
                slice.w = this.out[s];
                results[s] = slice;
            }
            return results;
        }
    }

    /**
     * Aparapi kernel that averages three colour channels into one grey
     * channel, one work item per element.
     *
     * <p>The result is written to a dedicated output buffer. Earlier the grey
     * value was stored back into the red input array, which is the caller's
     * own tensor data: the input lost its red channel and a second run on the
     * same tensor produced a different result.
     */
    class GrayscaleFilter extends Kernel {
        double[] inRed;
        double[] inGreen;
        double[] inBlue;
        /** Output buffer, freshly allocated by {@link #setArguments(Tensor)}. */
        double[] outGray;
        int inWidth;
        int inHeight;

        GrayscaleFilter() {
        }

        /**
         * Writes the mean of the three channel values at this work item's
         * index to {@code outGray}. The inputs are only read.
         */
        @Override
        public void run() {
            int idx = this.getGlobalId(0) * this.inHeight + this.getGlobalId(1);
            // Single-precision arithmetic is kept as it was so the numeric
            // output is unchanged; only the destination array differs.
            this.outGray[idx] = (float)this.inRed[idx] / 3.0f + (float)this.inGreen[idx] / 3.0f + (float)this.inBlue[idx] / 3.0f;
        }

        /**
         * Binds the kernel to matrices 0, 1 and 2 of {@code in} and allocates
         * a zeroed output buffer as long as the red channel.
         *
         * @param in tensor providing the three channels and the dimensions
         * @return a 2D range of {@code in.getWidth()} by {@code in.getHeight()} work items
         */
        public Range setArguments(Tensor in) {
            this.inRed = in.getMatrixAt((int)0).w;
            this.inGreen = in.getMatrixAt((int)1).w;
            this.inBlue = in.getMatrixAt((int)2).w;
            // Elements not covered by the range stay 0.0.
            this.outGray = new double[this.inRed.length];
            this.inWidth = in.getWidth();
            this.inHeight = in.getHeight();
            return Range.create2D(in.getWidth(), in.getHeight());
        }

        /**
         * @return a new {@code Matrix(inHeight, inWidth)} backed by the output
         *         buffer of the most recent run
         */
        public Matrix getResult() {
            Matrix gray = new Matrix(this.inHeight, this.inWidth);
            gray.w = this.outGray;
            return gray;
        }
    }

    /**
     * Aparapi kernel for a dense matrix product. Each work item computes one
     * element of the result as a dot product of a row of {@code mat1} and a
     * column of {@code mat2}; all three arrays are addressed row by row.
     */
    class MatrixMultiplier extends Kernel {
        double[] mat1;
        double[] mat2;
        double[] res;
        int rows;
        int cols1;
        int cols2;

        MatrixMultiplier() {
        }

        /** Computes the result element at (global id 0, global id 1). */
        @Override
        public void run() {
            int row = this.getGlobalId(0);
            int col = this.getGlobalId(1);

            double sum = 0.0;
            for (int k = 0; k < this.cols1; k++) {
                double left = this.mat1[this.cols1 * row + k];
                double right = this.mat2[this.cols2 * k + col];
                sum += left * right;
            }
            this.res[this.cols2 * row + col] = sum;
        }

        /**
         * Binds both operands and allocates a result array of
         * {@code m1.rows * m2.cols} elements. Dimension compatibility is not
         * checked here.
         *
         * @return a 2D range of {@code m1.rows} by {@code m2.cols} work items
         */
        public Range setArgs(Matrix m1, Matrix m2) {
            this.mat1 = m1.w;
            this.mat2 = m2.w;
            this.res = new double[m1.rows * m2.cols];
            this.rows = m1.rows;
            this.cols1 = m1.cols;
            this.cols2 = m2.cols;
            return Range.create2D(m1.rows, m2.cols);
        }

        /**
         * @return a new {@code Matrix(rows, cols2)} backed by the result array
         */
        public Matrix getResultMatrix() {
            Matrix product = new Matrix(this.rows, this.cols2);
            product.w = this.res;
            return product;
        }
    }

    /**
     * Aparapi kernel computing a matrix-vector product, one work item per
     * matrix row. The matrix is addressed as {@code cols * row + col}, the
     * same layout the other kernels in this file use.
     *
     * <p>Fixes relative to the previous version: the matrix dimensions are now
     * actually stored (they stayed 0, so the loop never ran), every term is
     * multiplied by the matching vector element, and the computed output is
     * returned instead of the untouched input vector.
     */
    class MatrixVectorMul extends Kernel {
        double[] matrix;
        double[] vector;
        double[] out;
        /** Number of matrix columns (row stride of {@code matrix}). */
        int inWidth;
        /** Number of matrix rows (length of {@code out}). */
        int inHeight;

        MatrixVectorMul() {
        }

        /** Computes the dot product of one matrix row with the vector. */
        @Override
        public void run() {
            int row = this.getGlobalId(0);
            double dot = 0.0;
            for (int col = 0; col < this.inWidth; col++) {
                dot += this.matrix[this.inWidth * row + col] * this.vector[col];
            }
            this.out[row] = dot;
        }

        /**
         * Binds the operands and allocates the output.
         *
         * @param m matrix operand
         * @param v vector operand; its data array must hold exactly {@code m.cols} values
         * @return a 1D range with one work item per matrix row
         * @throws IllegalArgumentException if the vector length does not match {@code m.cols}
         */
        public Range setArguments(Matrix m, Matrix v) {
            // Checked up front: an out-of-range read inside the kernel is not
            // guaranteed to be reported.
            if (v.w.length != m.cols) {
                throw new IllegalArgumentException("vector length " + v.w.length + " does not match matrix columns " + m.cols);
            }
            this.matrix = m.w;
            this.vector = v.w;
            this.inWidth = m.cols;
            this.inHeight = m.rows;
            this.out = new double[m.rows];
            return Range.create(m.rows);
        }

        /**
         * @return a new vector-shaped matrix with one entry per matrix row,
         *         backed by the output of the most recent run
         */
        public Matrix getResult() {
            Matrix product = new Matrix(this.out.length);
            product.w = this.out;
            return product;
        }
    }

    /**
     * Aparapi kernel performing an RMSProp-style update with gradient
     * clipping and a small weight-proportional penalty, one work item per
     * matrix element.
     */
    class RMSProp extends Kernel {
        double[] w;
        double[] dw;
        double decay;
        double[] stepCache;
        double[] outStepCache;
        double[] outW;
        int rows;
        int cols;
        int length;
        double learningRate;
        double epsilon;
        double gradientClipValue;
        final double regularization = 1.0E-6;

        RMSProp() {
        }

        /**
         * Updates the element addressed by the two global ids. The cache is
         * refreshed from the raw gradient; clipping is applied afterwards and
         * only affects the weight step.
         */
        @Override
        public void run() {
            int row = this.getGlobalId(0);
            int col = this.getGlobalId(1);
            int idx = this.cols * row + col;

            double grad = this.dw[idx];
            double cache = this.stepCache[idx] * this.decay + (1.0 - this.decay) * grad * grad;
            this.outStepCache[idx] = cache;

            // Clamp the gradient into [-gradientClipValue, gradientClipValue].
            if (grad > this.gradientClipValue) {
                grad = this.gradientClipValue;
            }
            if (grad < -this.gradientClipValue) {
                grad = -this.gradientClipValue;
            }

            // 'regularization' is a compile-time constant (1.0E-6).
            this.outW[idx] = this.w[idx] + (-this.learningRate * grad / this.sqrt(cache + this.epsilon) - regularization * this.w[idx]);
        }

        /**
         * Binds the kernel to {@code in} and the given hyper-parameters and
         * allocates fresh output arrays.
         *
         * @return a 2D range of {@code in.rows} by {@code in.cols} work items
         */
        public Range setArguments(Matrix in, double decay, double learningRate, double epsilon, double gradientClipValue) {
            // Inputs (shared with the caller's matrix, not copied).
            this.w = in.w;
            this.dw = in.dw;
            this.decay = decay;
            this.stepCache = in.stepCache;

            // Outputs are new arrays, so the run does not write into the inputs.
            this.outStepCache = new double[in.stepCache.length];
            this.outW = new double[in.w.length];

            this.rows = in.rows;
            this.cols = in.cols;
            this.length = in.w.length;
            this.learningRate = learningRate;
            this.epsilon = epsilon;
            this.gradientClipValue = gradientClipValue;
            return Range.create2D(in.rows, in.cols);
        }

        /**
         * @return a new {@code Matrix(rows, cols)} whose {@code w} and
         *         {@code stepCache} are this kernel's output arrays
         */
        public Matrix getResult() {
            Matrix updated = new Matrix(this.rows, this.cols);
            updated.w = this.outW;
            updated.stepCache = this.outStepCache;
            return updated;
        }
    }
}

