/*
 * Decompiled with CFR 0.152.
 */
package autodiff;

import autodiff.Graph;
import matrix.Matrix;
import org.jocl.CL;
import org.jocl.Pointer;
import org.jocl.blast.CLBlast;
import org.jocl.cl_command_queue;
import org.jocl.cl_context;
import org.jocl.cl_event;
import org.jocl.cl_mem;

/**
 * A {@link Graph} that evaluates matrix-vector and matrix-matrix products on an
 * OpenCL device through CLBlast, while keeping the gradient (backprop) step on
 * the host.
 *
 * <p>Matrices are addressed as row-major {@code double} arrays: element
 * {@code (i, k)} of a matrix {@code m} is read from {@code m.w[m.cols * i + k]}
 * and its gradient is accumulated in {@code m.dw} at the same index.
 */
public class GPUGraph
extends Graph {
    /** CLBlast layout code passed for row-major storage. */
    private static final int LAYOUT_ROW_MAJOR = 101;
    /** CLBlast transpose code passed for "use the matrix as stored". */
    private static final int TRANSPOSE_NONE = 111;
    /** Size of one {@code double} element in bytes. */
    private static final long DOUBLE_BYTES = 8L;

    private final cl_command_queue command_queue;
    private final cl_context context;

    /**
     * Creates a graph bound to an existing OpenCL queue and context.
     *
     * @param command_queue queue on which all transfers and kernels are enqueued
     * @param context       context in which device buffers are allocated
     * @param applyBackprop forwarded to the {@link Graph} constructor; when set,
     *                      each operation registers a gradient step
     */
    public GPUGraph(cl_command_queue command_queue, cl_context context, boolean applyBackprop) {
        super(applyBackprop);
        this.command_queue = command_queue;
        this.context = context;
    }

    /**
     * Delegates unchanged to the {@link Graph} implementation; no device work is done here.
     */
    @Override
    public double subDot(Matrix m1, Matrix m2, int startm1row, int startm1col, int startm2row, int startm2col, int rows, int cols) throws Exception {
        return super.subDot(m1, m2, startm1row, startm1col, startm2row, startm2col, rows, cols);
    }

    /**
     * Computes {@code m * v} for a column vector {@code v} with CLBlast's DGEMV.
     *
     * @param m left operand, {@code m.rows x m.cols}
     * @param v right operand, must be a column vector with {@code m.cols} rows
     * @return a new {@code m.rows x 1} matrix holding the product
     * @throws Exception if {@code v} is not a column vector, if the dimensions
     *                   do not agree, or if CLBlast reports a failure status
     */
    public Matrix dgemv(final Matrix m, final Matrix v) throws Exception {
        if (v.cols != 1) {
            throw new Exception("Expected column vector");
        }
        // Without this check the kernel would read past the end of v's device
        // buffer and the backprop loop below would index v out of bounds.
        if (m.cols != v.rows) {
            throw new Exception("matrix dimension mismatch");
        }
        final Matrix out = new Matrix(m.rows, 1);
        // Sizes are computed in long so large matrices cannot overflow int.
        long bytesA = GPUGraph.byteSize(m.rows, m.cols);
        long bytesX = GPUGraph.byteSize(v.rows, 1);
        long bytesY = GPUGraph.byteSize(out.rows, 1);

        cl_mem memA = null;
        cl_mem memX = null;
        cl_mem memY = null;
        cl_event event = new cl_event();
        boolean eventPopulated = false;
        try {
            memA = this.upload(CL.CL_MEM_READ_ONLY, m.w, bytesA);
            memX = this.upload(CL.CL_MEM_READ_ONLY, v.w, bytesX);
            // The result buffer is seeded with out.w, exactly as before.
            memY = this.upload(CL.CL_MEM_READ_WRITE, out.w, bytesY);

            int lda = m.cols;
            int incX = 1;
            int incY = 1;
            int status = CLBlast.CLBlastDgemv(LAYOUT_ROW_MAJOR, TRANSPOSE_NONE, m.rows, m.cols, 1.0, memA, 0L, lda, memX, 0L, incX, 0.0, memY, 0L, incY, this.command_queue, event);
            if (status != CL.CL_SUCCESS) {
                throw new Exception("CLBlastDgemv failed with status code " + status);
            }
            eventPopulated = true;
            CL.clWaitForEvents(1, new cl_event[]{event});
            CL.clEnqueueReadBuffer(this.command_queue, memY, true, 0L, bytesY, Pointer.to(out.w), 0, null, null);
        }
        finally {
            // Release device resources on every path, including failures.
            GPUGraph.release(memA);
            GPUGraph.release(memX);
            GPUGraph.release(memY);
            if (eventPopulated) {
                CL.clReleaseEvent(event);
            }
        }

        if (this.applyBackprop) {
            Runnable bp = new Runnable(){

                @Override
                public void run() {
                    try {
                        for (int i = 0; i < m.rows; ++i) {
                            double grad = out.dw[i];
                            int rowBase = m.cols * i;
                            for (int k = 0; k < m.cols; ++k) {
                                // d(out_i)/d(m_ik) = v_k and d(out_i)/d(v_k) = m_ik
                                m.dw[rowBase + k] += v.w[k] * grad;
                                v.dw[k] += m.w[rowBase + k] * grad;
                            }
                        }
                    }
                    catch (Exception e) {
                        // Recorded on the graph instead of propagating out of run().
                        GPUGraph.this.backpropException = e;
                    }
                }
            };
            this.backprop.add(bp);
        }
        return out;
    }

    /**
     * Computes the matrix product {@code m1 * m2} with CLBlast's DGEMM. When
     * {@code m2} has a single column the call is routed to
     * {@link #dgemv(Matrix, Matrix)} instead.
     *
     * @param m1 left operand, {@code m1.rows x m1.cols}
     * @param m2 right operand, must have {@code m1.cols} rows
     * @return a new {@code m1.rows x m2.cols} matrix holding the product
     * @throws Exception if the inner dimensions differ or CLBlast reports a
     *                   failure status
     */
    @Override
    public Matrix mul(final Matrix m1, final Matrix m2) throws Exception {
        if (m1.cols != m2.rows) {
            throw new Exception("matrix dimension mismatch");
        }
        if (m2.cols == 1) {
            return this.dgemv(m1, m2);
        }
        final int m1cols = m1.cols;
        final int m2cols = m2.cols;
        final int outcols = m2.cols;
        final Matrix out = new Matrix(m1.rows, m2cols);
        // Sizes are computed in long so large matrices cannot overflow int.
        long bytesA = GPUGraph.byteSize(m1.rows, m1cols);
        long bytesB = GPUGraph.byteSize(m2.rows, m2.cols);
        long bytesC = GPUGraph.byteSize(out.rows, out.cols);

        cl_mem memA = null;
        cl_mem memB = null;
        cl_mem memC = null;
        cl_event event = new cl_event();
        boolean eventPopulated = false;
        try {
            memA = this.upload(CL.CL_MEM_READ_ONLY, m1.w, bytesA);
            memB = this.upload(CL.CL_MEM_READ_ONLY, m2.w, bytesB);
            // The result buffer is seeded with out.w, exactly as before.
            memC = this.upload(CL.CL_MEM_READ_WRITE, out.w, bytesC);

            // Leading dimensions are the row lengths of the row-major operands.
            int lda = m1.cols;
            int ldb = out.cols;
            int ldc = out.cols;
            int status = CLBlast.CLBlastDgemm(LAYOUT_ROW_MAJOR, TRANSPOSE_NONE, TRANSPOSE_NONE, out.rows, out.cols, m1.cols, 1.0, memA, 0L, lda, memB, 0L, ldb, 0.0, memC, 0L, ldc, this.command_queue, event);
            if (status != CL.CL_SUCCESS) {
                throw new Exception("CLBlastDgemm failed with status code " + status);
            }
            eventPopulated = true;
            CL.clWaitForEvents(1, new cl_event[]{event});
            CL.clEnqueueReadBuffer(this.command_queue, memC, true, 0L, bytesC, Pointer.to(out.w), 0, null, null);
        }
        finally {
            // Release device resources on every path, including failures.
            GPUGraph.release(memA);
            GPUGraph.release(memB);
            GPUGraph.release(memC);
            if (eventPopulated) {
                CL.clReleaseEvent(event);
            }
        }

        if (this.applyBackprop) {
            Runnable bp = new Runnable(){

                @Override
                public void run() {
                    try {
                        for (int i = 0; i < m1.rows; ++i) {
                            int outRowBase = outcols * i;
                            int m1RowBase = m1cols * i;
                            for (int j = 0; j < m2.cols; ++j) {
                                double grad = out.dw[outRowBase + j];
                                for (int k = 0; k < m1.cols; ++k) {
                                    int m1Index = m1RowBase + k;
                                    int m2Index = m2cols * k + j;
                                    // d(out_ij)/d(m1_ik) = m2_kj and d(out_ij)/d(m2_kj) = m1_ik
                                    m1.dw[m1Index] += m2.w[m2Index] * grad;
                                    m2.dw[m2Index] += m1.w[m1Index] * grad;
                                }
                            }
                        }
                    }
                    catch (Exception e) {
                        // Recorded on the graph instead of propagating out of run().
                        GPUGraph.this.backpropException = e;
                    }
                }
            };
            this.backprop.add(bp);
        }
        return out;
    }

    /** Returns the byte size of a {@code rows x cols} block of doubles, computed in {@code long}. */
    private static long byteSize(int rows, int cols) {
        return (long) rows * (long) cols * DOUBLE_BYTES;
    }

    /**
     * Allocates a device buffer of {@code bytes} bytes and fills it with a blocking
     * write from {@code host}; the buffer is released again if the write throws.
     */
    private cl_mem upload(long flags, double[] host, long bytes) {
        cl_mem buffer = CL.clCreateBuffer(this.context, flags, bytes, null, null);
        boolean written = false;
        try {
            CL.clEnqueueWriteBuffer(this.command_queue, buffer, true, 0L, bytes, Pointer.to(host), 0, null, null);
            written = true;
            return buffer;
        }
        finally {
            if (!written) {
                GPUGraph.release(buffer);
            }
        }
    }

    /** Releases a device buffer if one was actually allocated. */
    private static void release(cl_mem buffer) {
        if (buffer != null) {
            CL.clReleaseMemObject(buffer);
        }
    }
}

