/*
 * Decompiled with CFR 0.152.
 */
package theGhastModding.lstmStuff.main;

import autodiff.GPUGraph;
import autodiff.Graph;
import java.nio.charset.StandardCharsets;
import java.util.Locale;
import java.util.Random;
import matrix.Matrix;
import org.jocl.CL;
import org.jocl.Pointer;
import org.jocl.blast.CLBlast;
import org.jocl.cl_command_queue;
import org.jocl.cl_context;
import org.jocl.cl_context_properties;
import org.jocl.cl_device_id;
import org.jocl.cl_platform_id;
import org.jocl.cl_queue_properties;

public class GPUTester {
    /** Case-insensitive substring searched for in each OpenCL platform name. */
    private static final String PLATFORM_NAME_HINT = "AMD";
    /** Case-insensitive substring searched for in each device name of a matching platform. */
    private static final String DEVICE_NAME_HINT = "Tahiti";
    /** Number of random matrix/vector pairs multiplied in the timing benchmark. */
    private static final int TEST_SIZE = 256;
    /** Side length of the square benchmark matrices; the vectors are BENCH_DIM x 1. */
    private static final int BENCH_DIM = 512;
    private static final long NANOS_PER_MILLI = 1000000L;

    /**
     * Selects an OpenCL device and creates a context and command queue on it. Prints the CPU
     * ({@link Graph}) and OpenCL ({@link GPUGraph}) results of two small hand-written products for
     * visual comparison, then times both implementations on a batch of random matrix-vector products.
     *
     * <p>Any exception is reported on stderr. The command queue and context are released whether or
     * not the run succeeded.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        cl_device_id device = null;
        cl_context context = null;
        cl_command_queue commandQueue = null;
        try {
            CL.setExceptionsEnabled(true);
            CLBlast.setExceptionsEnabled(true);

            DeviceSelection selection = selectDevice(getPlatforms());
            device = selection.device;

            cl_context_properties properties = new cl_context_properties();
            properties.addProperty(CL.CL_CONTEXT_PLATFORM, selection.platform);
            context = CL.clCreateContext(properties, 1, new cl_device_id[] {device}, null, null, null);
            commandQueue =
                    CL.clCreateCommandQueueWithProperties(context, device, new cl_queue_properties(), null);

            // Sanity check 1: (4x2) * (2x4), CPU first, then OpenCL.
            Graph cpuGraph = new Graph(false);
            Matrix m1 = matrixOf(4, 2, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0);
            Matrix m2 = matrixOf(2, 4, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0);
            System.out.println(cpuGraph.mul(m1, m2).toString());
            GPUGraph gpuGraph = new GPUGraph(commandQueue, context, false);
            System.out.println(gpuGraph.mul(m1, m2).toString());

            // Sanity check 2: (4x4) * (4x1), CPU first, then OpenCL.
            m1 = matrixOf(4, 4,
                    1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
                    9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0);
            m2 = matrixOf(4, 1, 4.0, 3.0, 2.0, 1.0);
            System.out.println(cpuGraph.mul(m1, m2).toString());
            System.out.println(gpuGraph.mul(m1, m2).toString());

            runBenchmark(cpuGraph, gpuGraph, new Random(105L));
        } catch (Exception e) {
            System.err.println("Error: ");
            e.printStackTrace();
        } finally {
            // Release in reverse order of creation; skip anything that was never created.
            if (commandQueue != null) {
                CL.clReleaseCommandQueue(commandQueue);
            }
            if (context != null) {
                CL.clReleaseContext(context);
            }
            if (device != null) {
                CL.clReleaseDevice(device);
            }
        }
    }

    /**
     * Times TEST_SIZE matrix-vector products on the OpenCL graph and then on the CPU graph, and
     * prints both totals plus whether OpenCL was faster.
     *
     * <p>NOTE(review): if {@code GPUGraph.mul} can return before the device has finished, the OpenCL
     * figure is understated. TODO: confirm that it blocks on the result.
     */
    private static void runBenchmark(Graph cpuGraph, GPUGraph gpuGraph, Random random) {
        Matrix[] lhs = new Matrix[TEST_SIZE];
        Matrix[] rhs = new Matrix[TEST_SIZE];
        for (int i = 0; i < TEST_SIZE; i++) {
            // Keep the lhs/rhs interleaving so the seeded Random yields the same data as before.
            lhs[i] = Matrix.rand(BENCH_DIM, BENCH_DIM, 1.0, random);
            rhs[i] = Matrix.rand(BENCH_DIM, 1, 1.0, random);
        }

        System.out.println("Running on OpenCL device...");
        long start = System.nanoTime();
        for (int i = 0; i < TEST_SIZE; i++) {
            gpuGraph.mul(lhs[i], rhs[i]);
        }
        long openClMillis = (System.nanoTime() - start) / NANOS_PER_MILLI;

        System.out.println("Done.\nRunning in Java...");
        start = System.nanoTime();
        for (int i = 0; i < TEST_SIZE; i++) {
            cpuGraph.mul(lhs[i], rhs[i]);
        }
        // Stop the clock before printing so console I/O is not counted as Java compute time.
        long javaMillis = (System.nanoTime() - start) / NANOS_PER_MILLI;
        System.out.println("Done.");

        System.out.println("Total OpenCL processing time: " + Long.toString(openClMillis) + "ms ("
                + Double.toString((double) openClMillis / 1000.0) + "s)");
        System.out.println("Total Java processing time: " + Long.toString(javaMillis) + "ms ("
                + Double.toString((double) javaMillis / 1000.0) + "s)");
        System.out.println(Boolean.toString(openClMillis < javaMillis));
    }

    /** Builds a rows x cols Matrix whose backing array {@code w} is set to {@code values}. */
    private static Matrix matrixOf(int rows, int cols, double... values) {
        Matrix m = new Matrix(rows, cols);
        m.w = values;
        return m;
    }

    /** A platform together with one of its own devices, so the two can never disagree. */
    private static final class DeviceSelection {
        final cl_platform_id platform;
        final cl_device_id device;

        DeviceSelection(cl_platform_id platform, cl_device_id device) {
            this.platform = platform;
            this.device = device;
        }
    }

    /**
     * Lists every platform and device, and picks one to use.
     *
     * <p>The last device whose name contains DEVICE_NAME_HINT, on a platform whose name contains
     * PLATFORM_NAME_HINT, wins. If nothing matches, the first device of the last platform that has
     * any devices is used. Selected entries are printed on stderr; all others go to stdout.
     *
     * @throws IllegalStateException if no platform exposes any device
     */
    private static DeviceSelection selectDevice(cl_platform_id[] platforms) throws InterruptedException {
        DeviceSelection match = null;
        DeviceSelection fallback = null;
        for (int i = 0; i < platforms.length; i++) {
            String platformName = getString(platforms[i], CL.CL_PLATFORM_NAME);
            boolean platformMatches = containsIgnoreCase(platformName, PLATFORM_NAME_HINT);
            printListing(platformMatches, i + ":" + platformName);

            cl_device_id[] devices = getDevices(platforms[i]);
            for (int j = 0; j < devices.length; j++) {
                String deviceName = getString(devices[j], CL.CL_DEVICE_NAME);
                boolean deviceMatches = platformMatches && containsIgnoreCase(deviceName, DEVICE_NAME_HINT);
                if (deviceMatches) {
                    // Record platform and device together so they always belong to each other.
                    match = new DeviceSelection(platforms[i], devices[j]);
                }
                printListing(deviceMatches, "\t" + j + ":" + deviceName);
            }
            if (devices.length > 0) {
                fallback = new DeviceSelection(platforms[i], devices[0]);
            }
        }
        if (match != null) {
            return match;
        }
        if (fallback != null) {
            return fallback;
        }
        throw new IllegalStateException("No OpenCL devices found");
    }

    /**
     * Prints one enumeration line. Highlighted lines go to stderr, surrounded by short pauses —
     * presumably to keep them from interleaving out of order with stdout in the console.
     */
    private static void printListing(boolean highlighted, String line) throws InterruptedException {
        if (highlighted) {
            Thread.sleep(100L);
            System.err.println(line);
            Thread.sleep(100L);
        } else {
            System.out.println(line);
        }
    }

    /** Returns all OpenCL platform IDs, using the usual two-call count-then-fill pattern. */
    private static cl_platform_id[] getPlatforms() {
        int[] count = new int[1];
        CL.clGetPlatformIDs(0, null, count);
        cl_platform_id[] platforms = new cl_platform_id[count[0]];
        CL.clGetPlatformIDs(count[0], platforms, null);
        return platforms;
    }

    /** Returns the IDs of all devices, of any type, on {@code platform}. */
    private static cl_device_id[] getDevices(cl_platform_id platform) {
        int[] count = new int[1];
        CL.clGetDeviceIDs(platform, CL.CL_DEVICE_TYPE_ALL, 0, null, count);
        cl_device_id[] devices = new cl_device_id[count[0]];
        CL.clGetDeviceIDs(platform, CL.CL_DEVICE_TYPE_ALL, count[0], devices, null);
        return devices;
    }

    /** Locale-independent, case-insensitive substring test. */
    private static boolean containsIgnoreCase(String haystack, String needle) {
        return haystack.toUpperCase(Locale.ROOT).contains(needle.toUpperCase(Locale.ROOT));
    }

    /**
     * Reads a string-valued device info parameter.
     *
     * @param device the device to query
     * @param paramName a {@code CL_DEVICE_*} query returning a string, e.g. {@code CL.CL_DEVICE_NAME}
     * @return the value without its trailing NUL; empty if the driver reports a zero-length value
     */
    private static String getString(cl_device_id device, int paramName) {
        long[] size = new long[1];
        CL.clGetDeviceInfo(device, paramName, 0L, null, size);
        byte[] buffer = new byte[(int) size[0]];
        CL.clGetDeviceInfo(device, paramName, buffer.length, Pointer.to(buffer), null);
        return decodeCString(buffer);
    }

    /**
     * Reads a string-valued platform info parameter.
     *
     * @param platform the platform to query
     * @param paramName a {@code CL_PLATFORM_*} query returning a string, e.g. {@code CL.CL_PLATFORM_NAME}
     * @return the value without its trailing NUL; empty if the driver reports a zero-length value
     */
    private static String getString(cl_platform_id platform, int paramName) {
        long[] size = new long[1];
        CL.clGetPlatformInfo(platform, paramName, 0L, null, size);
        byte[] buffer = new byte[(int) size[0]];
        CL.clGetPlatformInfo(platform, paramName, buffer.length, Pointer.to(buffer), null);
        return decodeCString(buffer);
    }

    /**
     * Decodes a NUL-terminated byte buffer, stopping at the first NUL or at the end of the buffer.
     * Unlike {@code buffer.length - 1}, this does not throw on an empty buffer.
     */
    private static String decodeCString(byte[] buffer) {
        int length = 0;
        while (length < buffer.length && buffer[length] != 0) {
            length++;
        }
        return new String(buffer, 0, length, StandardCharsets.UTF_8);
    }
}

