/*
 * Decompiled with CFR 0.152.
 */
package theGhastModding.lstmStuff.main;

import java.util.Arrays;
import java.util.Random;
import matrix.Matrix;
import matrix.Tensor;
import theGhastModding.lstmStuff.gpu.GPUUtils;

/**
 * Ad-hoc benchmark comparing {@link GPUUtils#applyFilter} against a naive CPU implementation of
 * the same strided 2-D convolution, and reporting how far the two results differ.
 *
 * <p>The shapes are the ones the original tester hard-coded: 32 random input channels, 64 output
 * channels, 5x5 filters and a stride of 2.
 */
public class GPUTester {
    private static final int INPUT_WIDTH = 159;
    private static final int INPUT_HEIGHT = 119;
    private static final int INPUT_DEPTH = 32;
    private static final int OUTPUT_DEPTH = 64;
    private static final int FILTER_WIDTH = 5;
    private static final int FILTER_HEIGHT = 5;
    private static final int STRIDE = 2;
    /** Scale argument passed to {@code Matrix.rand} for both inputs and filters. */
    private static final double INIT_SCALE = 0.08;

    public static void main(String[] args) {
        try {
            Random random = new Random();

            // NOTE(review): assumes Tensor(width, height, depth) holds `depth` matrices of
            // height rows x width cols, matching the Matrix.rand(rows, cols, ...) calls — confirm.
            Tensor inputs = new Tensor(INPUT_WIDTH, INPUT_HEIGHT, INPUT_DEPTH);
            for (int c = 0; c < inputs.getDepth(); c++) {
                inputs.setMatrixAt(c, Matrix.rand(INPUT_HEIGHT, INPUT_WIDTH, INIT_SCALE, random));
            }
            int inDepth = inputs.getDepth();

            // One filter per (output channel, input channel) pair, laid out output-channel-major:
            // filters[o * inDepth + c] is applied to input channel c for output channel o.
            Matrix[] filters = new Matrix[OUTPUT_DEPTH * inDepth];
            for (int f = 0; f < filters.length; f++) {
                filters[f] = Matrix.rand(FILTER_HEIGHT, FILTER_WIDTH, INIT_SCALE, random);
            }

            // "Valid" (no padding) strided convolution size; 78 x 58 for the defaults above.
            int outWidth = (INPUT_WIDTH - FILTER_WIDTH) / STRIDE + 1;
            int outHeight = (INPUT_HEIGHT - FILTER_HEIGHT) / STRIDE + 1;

            GPUUtils gpu = new GPUUtils();
            Tensor gpuResult = new Tensor(outWidth, outHeight, OUTPUT_DEPTH);

            // Untimed call whose result is discarded, presumably a warm-up so one-time GPU setup
            // is not measured. NOTE(review): it passes the whole filter array rather than one
            // per-output slice like the timed calls below — confirm applyFilter tolerates that.
            gpu.applyFilter(inputs.getMatrices(), filters, STRIDE, outWidth, outHeight);

            long start = System.nanoTime();
            for (int o = 0; o < OUTPUT_DEPTH; o++) {
                Matrix[] slice = Arrays.copyOfRange(filters, o * inDepth, (o + 1) * inDepth);
                gpuResult.setMatrixAt(o, gpu.applyFilter(inputs.getMatrices(), slice, STRIDE, outWidth, outHeight));
            }
            long totalGpuTime = System.nanoTime() - start;

            // The CPU reference writes into its own buffer so the GPU output stays intact
            // and the two can be compared afterwards.
            start = System.nanoTime();
            double[][][] cpuResult = convolveOnCpu(inputs, filters, OUTPUT_DEPTH, outWidth, outHeight);
            long totalCpuTime = System.nanoTime() - start;

            System.out.println("Total GPU Time = " + Double.toString(nanosToSeconds(totalGpuTime)) + " seconds");
            System.out.println("Total CPU Time = " + Double.toString(nanosToSeconds(totalCpuTime)) + " seconds");
            System.out.println(totalGpuTime < totalCpuTime);
            System.out.println("Max abs difference GPU vs CPU = " + maxAbsDifference(gpuResult, cpuResult));
        } catch (Exception e) {
            System.err.println("Error: ");
            e.printStackTrace();
            System.exit(1);
        }
    }

    /**
     * Naive CPU reference. For every output channel {@code o}, it sums over all input channels
     * {@code c} the strided convolution of channel {@code c} with {@code filters[o * depth + c]}.
     * The filter is indexed back-to-front, i.e. a true convolution rather than a cross-correlation.
     *
     * @param inputs      input channels; not modified
     * @param filters     output-channel-major filter bank, at least {@code outputDepth * depth} long
     * @param outputDepth number of output channels to compute
     * @param outWidth    output columns
     * @param outHeight   output rows
     * @return a fresh {@code [outputDepth][outHeight][outWidth]} array
     */
    private static double[][][] convolveOnCpu(Tensor inputs, Matrix[] filters, int outputDepth,
            int outWidth, int outHeight) {
        int inDepth = inputs.getDepth();
        double[][][] out = new double[outputDepth][outHeight][outWidth]; // zero-initialised by Java
        for (int o = 0; o < outputDepth; o++) {
            for (int c = 0; c < inDepth; c++) {
                Matrix input = inputs.getMatrixAt(c);
                Matrix filter = filters[o * inDepth + c];
                for (int y = 0; y < outHeight; y++) {
                    for (int x = 0; x < outWidth; x++) {
                        double sum = 0.0;
                        for (int fy = 0; fy < FILTER_HEIGHT; fy++) {
                            for (int fx = 0; fx < FILTER_WIDTH; fx++) {
                                sum += input.getW(y * STRIDE + fy, x * STRIDE + fx)
                                        * filter.getW(FILTER_HEIGHT - fy - 1, FILTER_WIDTH - fx - 1);
                            }
                        }
                        out[o][y][x] += sum;
                    }
                }
            }
        }
        return out;
    }

    /**
     * Returns the largest absolute element-wise difference between the GPU result and the CPU
     * reference (NaN if any compared element is NaN).
     */
    private static double maxAbsDifference(Tensor gpuResult, double[][][] cpuResult) {
        double max = 0.0;
        for (int o = 0; o < cpuResult.length; o++) {
            Matrix gpuChannel = gpuResult.getMatrixAt(o);
            for (int y = 0; y < cpuResult[o].length; y++) {
                for (int x = 0; x < cpuResult[o][y].length; x++) {
                    max = Math.max(max, Math.abs(gpuChannel.getW(y, x) - cpuResult[o][y][x]));
                }
            }
        }
        return max;
    }

    /** Converts a {@link System#nanoTime()} interval to seconds. */
    private static double nanosToSeconds(long nanos) {
        return (double) nanos / 1000.0 / 1000.0 / 1000.0;
    }
}

