/*
 * Decompiled with CFR 0.152.
 */
package util;

import autodiff.GPUGraph;
import autodiff.MultiGPUGraph;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import org.jocl.CL;
import org.jocl.Pointer;
import org.jocl.cl_command_queue;
import org.jocl.cl_context;
import org.jocl.cl_context_properties;
import org.jocl.cl_device_id;
import org.jocl.cl_platform_id;
import org.jocl.cl_queue_properties;
import util.NNDevice;

/**
 * Static helpers for enumerating OpenCL platforms/devices through JOCL,
 * reading device properties, and building {@link GPUGraph} /
 * {@link MultiGPUGraph} instances on selected devices.
 *
 * <p>JOCL exceptions are enabled when this class is loaded, so every failing
 * OpenCL call made here surfaces as an unchecked JOCL exception rather than
 * an error code.
 */
public class CLUtils {
    static {
        // Turn OpenCL error codes into exceptions for every CL.* call.
        CL.setExceptionsEnabled(true);
    }

    /**
     * Finds every device whose name contains {@code deviceHint} on every
     * platform whose name contains {@code platformHint}. Matching is
     * case-insensitive and locale-independent.
     *
     * @param platformHint substring to look for in the platform name;
     *                     {@code null} or empty matches every platform
     * @param deviceHint   substring to look for in the device name;
     *                     {@code null} or empty matches every device
     * @return the matching devices in enumeration order; empty if none match
     */
    public static NNDevice[] findDevice(String platformHint, String deviceHint) {
        // The hints do not change while scanning, so normalise them once.
        String platformKey = toMatchKey(platformHint);
        String deviceKey = toMatchKey(deviceHint);

        List<NNDevice> matches = new ArrayList<NNDevice>();
        for (cl_platform_id platform : listPlatforms()) {
            String platformName = CLUtils.getString(platform, CL.CL_PLATFORM_NAME);
            if (!toMatchKey(platformName).contains(platformKey)) {
                continue;
            }
            for (cl_device_id device : listDevices(platform)) {
                String deviceName = CLUtils.getString(device, CL.CL_DEVICE_NAME);
                if (!toMatchKey(deviceName).contains(deviceKey)) {
                    continue;
                }
                matches.add(new NNDevice(
                        platform,
                        device,
                        platformName,
                        deviceName,
                        CLUtils.getLongArr(device, CL.CL_DEVICE_MAX_WORK_ITEM_SIZES, 3),
                        CLUtils.getLong(device, CL.CL_DEVICE_MAX_WORK_GROUP_SIZE),
                        CLUtils.getLong(device, CL.CL_DEVICE_GLOBAL_MEM_SIZE),
                        CLUtils.getLong(device, CL.CL_DEVICE_LOCAL_MEM_SIZE)));
            }
        }
        return matches.toArray(new NNDevice[matches.size()]);
    }

    /**
     * Prints every platform and, indented beneath it, every device together
     * with a selection of its properties to {@code System.out}.
     *
     * @throws Exception declared for compatibility with existing callers
     */
    public static void printDevices() throws Exception {
        cl_platform_id[] platforms = listPlatforms();
        for (int p = 0; p < platforms.length; p++) {
            System.out.println(p + ":" + CLUtils.getString(platforms[p], CL.CL_PLATFORM_NAME));

            cl_device_id[] devices = listDevices(platforms[p]);
            for (int d = 0; d < devices.length; d++) {
                cl_device_id device = devices[d];
                System.out.println("\t" + d + ":" + CLUtils.getString(device, CL.CL_DEVICE_NAME));

                long[] workDims = CLUtils.getLongArr(device, CL.CL_DEVICE_MAX_WORK_ITEM_SIZES, 3);
                System.out.println("\t\tWork dimensions: " + workDims[0] + "," + workDims[1] + "," + workDims[2]);

                long workLimit = CLUtils.getLong(device, CL.CL_DEVICE_MAX_WORK_GROUP_SIZE);
                System.out.println("\t\tMax work size: " + workLimit);

                // Bytes -> megabytes.
                long memSize = CLUtils.getLong(device, CL.CL_DEVICE_GLOBAL_MEM_SIZE) / 1024L / 1024L;
                System.out.println("\t\tGlobal memory: " + memSize + "MB");

                // Bytes -> kilobytes.
                long localMemSize = CLUtils.getLong(device, CL.CL_DEVICE_LOCAL_MEM_SIZE) / 1024L;
                System.out.println("\t\tLocal memory: " + localMemSize + "KB");

                int speed = CLUtils.getInt(device, CL.CL_DEVICE_MAX_CLOCK_FREQUENCY);
                System.out.println("\t\tCurrent core clock speed: " + speed + "MHz");

                String ver = CLUtils.getString(device, CL.CL_DEVICE_VERSION);
                System.out.println("\t\tSupported OpenCL version: " + ver);

                String vendor = CLUtils.getString(device, CL.CL_DEVICE_VENDOR);
                System.out.println("\t\tVendor: " + vendor);

                long imageWidth = CLUtils.getLong(device, CL.CL_DEVICE_IMAGE3D_MAX_WIDTH);
                long imageHeight = CLUtils.getLong(device, CL.CL_DEVICE_IMAGE3D_MAX_HEIGHT);
                long imageDepth = CLUtils.getLong(device, CL.CL_DEVICE_IMAGE3D_MAX_DEPTH);
                System.out.println("\t\tMax image (Tensor) dimensions: " + imageWidth + "x" + imageHeight + "x" + imageDepth);
            }
        }
    }

    /**
     * Creates an OpenCL context and an out-of-order command queue on
     * {@code dev} and wraps them in a {@link GPUGraph}.
     *
     * @param dev           device (and its platform) to create the context on
     * @param applyBackprop forwarded unchanged to the {@link GPUGraph} constructor
     * @return the newly constructed graph
     * @throws Exception if the {@link GPUGraph} constructor fails
     */
    public static GPUGraph createGraph(NNDevice dev, boolean applyBackprop) throws Exception {
        cl_context_properties properties = new cl_context_properties();
        properties.addProperty(CL.CL_CONTEXT_PLATFORM, dev.getPlatform());
        cl_context context = CL.clCreateContext(properties, 1, new cl_device_id[]{dev.getDevice()}, null, null, null);

        cl_queue_properties queueProperties = new cl_queue_properties();
        queueProperties.addProperty(CL.CL_QUEUE_PROPERTIES, CL.CL_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE);

        cl_command_queue command_queue;
        try {
            command_queue = CL.clCreateCommandQueueWithProperties(context, dev.getDevice(), queueProperties, null);
        } catch (RuntimeException e) {
            // Nothing else references the context yet; release it so a failed
            // queue creation does not leak it.
            CL.clReleaseContext(context);
            throw e;
        }
        return new GPUGraph(command_queue, context, dev.getDevice(), applyBackprop);
    }

    /**
     * Creates one {@link GPUGraph} per device and combines them in a
     * {@link MultiGPUGraph}.
     *
     * @param dev1          first device
     * @param dev2          second device
     * @param applyBackprop forwarded to both graphs and to the {@link MultiGPUGraph}
     * @return the combined graph
     * @throws Exception if creating either graph or the combined graph fails
     */
    public static MultiGPUGraph createGraph(NNDevice dev1, NNDevice dev2, boolean applyBackprop) throws Exception {
        return new MultiGPUGraph(CLUtils.createGraph(dev1, applyBackprop), CLUtils.createGraph(dev2, applyBackprop), applyBackprop);
    }

    /**
     * Creates one {@link GPUGraph} per device and combines them in a
     * {@link MultiGPUGraph}, passing a numeric value along with each graph.
     *
     * @param dev1          first device
     * @param perf1         value passed with the first graph (presumably a
     *                      relative performance weight; confirm against MultiGPUGraph)
     * @param dev2          second device
     * @param perf2         value passed with the second graph (see {@code perf1})
     * @param applyBackprop forwarded to both graphs and to the {@link MultiGPUGraph}
     * @return the combined graph
     * @throws Exception if creating either graph or the combined graph fails
     */
    public static MultiGPUGraph createGraph(NNDevice dev1, double perf1, NNDevice dev2, double perf2, boolean applyBackprop) throws Exception {
        return new MultiGPUGraph(CLUtils.createGraph(dev1, applyBackprop), perf1, CLUtils.createGraph(dev2, applyBackprop), perf2, applyBackprop);
    }

    /**
     * Reads a string-valued device property.
     *
     * @param device    device to query
     * @param paramName {@code CL_DEVICE_*} parameter id
     * @return the value without its terminating NUL byte, decoded with the
     *         platform default charset; empty if the driver reports no bytes
     */
    public static String getString(cl_device_id device, int paramName) {
        long[] size = new long[1];
        CL.clGetDeviceInfo(device, paramName, 0L, null, size);
        if (size[0] <= 0L) {
            return "";
        }
        byte[] buffer = new byte[(int)size[0]];
        CL.clGetDeviceInfo(device, paramName, buffer.length, Pointer.to(buffer), null);
        return dropTerminator(buffer);
    }

    /**
     * Reads a string-valued platform property.
     *
     * @param platform  platform to query
     * @param paramName {@code CL_PLATFORM_*} parameter id
     * @return the value without its terminating NUL byte, decoded with the
     *         platform default charset; empty if the driver reports no bytes
     */
    public static String getString(cl_platform_id platform, int paramName) {
        long[] size = new long[1];
        CL.clGetPlatformInfo(platform, paramName, 0L, null, size);
        if (size[0] <= 0L) {
            return "";
        }
        byte[] buffer = new byte[(int)size[0]];
        CL.clGetPlatformInfo(platform, paramName, buffer.length, Pointer.to(buffer), null);
        return dropTerminator(buffer);
    }

    /**
     * Reads a device property made of {@code size} consecutive 8-byte values.
     *
     * <p>NOTE(review): properties typed {@code size_t} are only 8 bytes wide
     * on 64-bit OpenCL hosts; on a 32-bit host the elements would be decoded
     * incorrectly.
     *
     * @param device    device to query
     * @param paramName {@code CL_DEVICE_*} parameter id
     * @param size      number of 8-byte elements to read
     * @return the decoded values, in host byte order
     */
    public static long[] getLongArr(cl_device_id device, int paramName, int size) {
        long[] toReturn = new long[size];
        readDeviceInfo(device, paramName, 8 * size).asLongBuffer().get(toReturn);
        return toReturn;
    }

    /**
     * Reads a device property as a single 8-byte value.
     *
     * @param device    device to query
     * @param paramName {@code CL_DEVICE_*} parameter id
     * @return the decoded value, in host byte order
     */
    public static long getLong(cl_device_id device, int paramName) {
        return readDeviceInfo(device, paramName, 8).getLong();
    }

    /**
     * Reads a boolean device property.
     *
     * <p>OpenCL's {@code cl_bool} is a 4-byte {@code cl_uint}, so the value is
     * fetched as an int; a smaller buffer is rejected with
     * {@code CL_INVALID_VALUE}.
     *
     * @param device    device to query
     * @param paramName {@code CL_DEVICE_*} parameter id
     * @return {@code true} if the property is non-zero
     */
    public static boolean getBool(cl_device_id device, int paramName) {
        return CLUtils.getInt(device, paramName) != 0;
    }

    /**
     * Reads a device property as a single 4-byte value.
     *
     * @param device    device to query
     * @param paramName {@code CL_DEVICE_*} parameter id
     * @return the decoded value, in host byte order
     */
    public static int getInt(cl_device_id device, int paramName) {
        return readDeviceInfo(device, paramName, 4).getInt();
    }

    /** Returns all available platforms, or an empty array if the count is zero. */
    private static cl_platform_id[] listPlatforms() {
        int[] count = new int[1];
        CL.clGetPlatformIDs(0, null, count);
        cl_platform_id[] platforms = new cl_platform_id[count[0]];
        if (platforms.length > 0) {
            // A second call with zero entries and a non-null array is invalid.
            CL.clGetPlatformIDs(platforms.length, platforms, null);
        }
        return platforms;
    }

    /**
     * Returns all devices of every type on {@code platform}. A platform that
     * exposes no devices still surfaces as the JOCL exception raised by
     * {@code clGetDeviceIDs}.
     */
    private static cl_device_id[] listDevices(cl_platform_id platform) {
        int[] count = new int[1];
        CL.clGetDeviceIDs(platform, CL.CL_DEVICE_TYPE_ALL, 0, null, count);
        cl_device_id[] devices = new cl_device_id[count[0]];
        if (devices.length > 0) {
            CL.clGetDeviceIDs(platform, CL.CL_DEVICE_TYPE_ALL, devices.length, devices, null);
        }
        return devices;
    }

    /** Fetches {@code byteCount} raw bytes of a device property, positioned at 0 in host byte order. */
    private static ByteBuffer readDeviceInfo(cl_device_id device, int paramName, int byteCount) {
        byte[] raw = new byte[byteCount];
        CL.clGetDeviceInfo(device, paramName, raw.length, Pointer.to(raw), null);
        // The native side writes values in the host's byte order.
        return ByteBuffer.wrap(raw).order(ByteOrder.nativeOrder());
    }

    /** Decodes a non-empty NUL-terminated byte string, dropping the final byte. */
    private static String dropTerminator(byte[] raw) {
        return new String(raw, 0, raw.length - 1);
    }

    /** Upper-cases {@code text} independent of the default locale; {@code null} becomes "". */
    private static String toMatchKey(String text) {
        return text == null ? "" : text.toUpperCase(Locale.ROOT);
    }
}

