/*
 * Decompiled with CFR 0.152.
 */
package autodiff;

import autodiff.Graph;
import java.io.BufferedReader;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.UUID;
import matrix.Matrix;
import matrix.Tensor;
import nonlinearities.ExponentialLinearUnit;
import nonlinearities.LinearUnit;
import nonlinearities.Nonlinearity;
import nonlinearities.ReLuUnit;
import nonlinearities.RectifiedLinearUnit;
import nonlinearities.SigmoidUnit;
import nonlinearities.SineUnit;
import nonlinearities.TanhUnit;
import org.jocl.CL;
import org.jocl.NativePointerObject;
import org.jocl.Pointer;
import org.jocl.Sizeof;
import org.jocl.cl_command_queue;
import org.jocl.cl_context;
import org.jocl.cl_device_id;
import org.jocl.cl_event;
import org.jocl.cl_kernel;
import org.jocl.cl_mem;
import org.jocl.cl_program;

/**
 * {@link Graph} implementation that runs full convolutions (forward pass and backward pass)
 * on an OpenCL device through JOCL.
 *
 * <p>Device memory is pooled. Buffers returned through {@code unloadBuffer} are handed out again
 * by later allocations with exactly the same byte size and read-only flag. A pooled buffer is
 * released once it has sat unused for more than {@link #MAX_UNUSED_COUNT} calls to
 * {@link #backward()}. Filter/bias buffers can also be cached per parameter UUID (see
 * {@link #forceKeepParameters(boolean)}).
 *
 * <p>The command queue and context passed to the constructor are released by {@link #cleanUp()}.
 */
public class GPUGraph extends Graph {
    cl_command_queue command_queue;
    cl_context context;
    cl_device_id device;

    // Program sources are read once per JVM and shared by all instances. Each instance still
    // builds its own cl_program for its own context/device.
    private static String fullConvProgramCode;
    private static String fullConvBackwardProgramCode;

    private cl_program fullConvProgram;
    private cl_kernel fullConvKernel;
    private cl_kernel fullConvReLUKernel;
    private cl_kernel fullConvSigmoidKernel;
    private cl_kernel fullConvSineKernel;
    private cl_kernel fullConvTanhKernel;
    private cl_kernel fullConvELUKernel;
    private cl_program fullConvBackwardProgram;
    private cl_kernel convBackwardFiltersKernel;
    private cl_kernel convBackwardInputsKernel;

    // Fill pattern used to zero the gradient buffers before the backward kernels run.
    private final float[] zerosF = new float[]{0.0f};
    // Filter/bias device buffers cached per parameter UUID: [0] = filters, [1] = bias.
    private final Map<UUID, BufferInfo[]> paramBuffer = new HashMap<UUID, BufferInfo[]>(16);
    private boolean forceKeepParameters = false;
    // Device buffers not currently in use, available for reuse.
    private final List<BufferInfo> unusedMemBuffer = new ArrayList<BufferInfo>(16);
    // Number of backward() passes each pooled buffer has survived without being reused.
    private final Map<BufferInfo, Integer> unusedMemCntr = new HashMap<BufferInfo, Integer>();

    /** Pooled buffers unused for more than this many {@link #backward()} passes are released. */
    public static int MAX_UNUSED_COUNT = 5;
    /** When true, a {@code clFinish} separates the two backward kernels. */
    public static boolean DISABLE_CONCURRENT = false;
    /** When true, ReLU-style nonlinearities are not fused into the convolution kernel. */
    public static boolean ONLY_ACCELL_SLOW_NONLIN = true;
    // Public counters; not read or written anywhere in this class.
    public long cntr = 0L;
    public long time1 = 0L;
    public long time2 = 0L;

    static {
        // Turn OpenCL error codes into CLExceptions so failures surface at the failing call.
        CL.setExceptionsEnabled(true);
    }

    /**
     * Creates a graph bound to the given OpenCL queue, context and device. Builds the forward
     * program ({@code /cl/conv.cl}) and the backward program ({@code /cl/convBackward.cl}).
     *
     * @param command_queue queue on which all transfers and kernels are enqueued
     * @param context       context used to create programs and buffers
     * @param device        device the programs are built for
     * @param applyBackprop forwarded to {@link Graph}
     * @throws Exception if a program resource cannot be read or an OpenCL call fails
     */
    public GPUGraph(cl_command_queue command_queue, cl_context context, cl_device_id device, boolean applyBackprop) throws Exception {
        super(applyBackprop);
        this.command_queue = command_queue;
        this.context = context;
        this.device = device;
        if (fullConvProgramCode == null) {
            fullConvProgramCode = this.readProgram("/cl/conv.cl");
        }
        this.fullConvProgram = this.getProgramForDevice(device, fullConvProgramCode);
        this.fullConvKernel = CL.clCreateKernel(this.fullConvProgram, "conv_linear", null);
        this.fullConvReLUKernel = CL.clCreateKernel(this.fullConvProgram, "conv_ReLU", null);
        this.fullConvSigmoidKernel = CL.clCreateKernel(this.fullConvProgram, "conv_sigmoid", null);
        this.fullConvSineKernel = CL.clCreateKernel(this.fullConvProgram, "conv_sine", null);
        this.fullConvTanhKernel = CL.clCreateKernel(this.fullConvProgram, "conv_tanh", null);
        this.fullConvELUKernel = CL.clCreateKernel(this.fullConvProgram, "conv_ELU", null);
        if (fullConvBackwardProgramCode == null) {
            fullConvBackwardProgramCode = this.readProgram("/cl/convBackward.cl");
        }
        this.fullConvBackwardProgram = this.getProgramForDevice(device, fullConvBackwardProgramCode);
        this.convBackwardFiltersKernel = CL.clCreateKernel(this.fullConvBackwardProgram, "convBackwardFilters", null);
        this.convBackwardInputsKernel = CL.clCreateKernel(this.fullConvBackwardProgram, "convBackwardInput", null);
    }

    /**
     * Compiles {@code code} for {@code device} in this graph's context.
     *
     * @throws Exception (a CLException) if creation or compilation fails; the program object is
     *                   released before rethrowing
     */
    private cl_program getProgramForDevice(cl_device_id device, String code) throws Exception {
        cl_program prgrm = CL.clCreateProgramWithSource(this.context, 1, new String[]{code}, null, null);
        try {
            CL.clBuildProgram(prgrm, 1, new cl_device_id[]{device}, "-cl-single-precision-constant -cl-mad-enable -cl-no-signed-zeros", null, null);
        } catch (RuntimeException e) {
            // Do not leak the program object when compilation fails.
            CL.clReleaseProgram(prgrm);
            throw e;
        }
        return prgrm;
    }

    /** Binds a device buffer as argument {@code index} of {@code kernel}. */
    private static void setMemArg(cl_kernel kernel, int index, cl_mem mem) {
        CL.clSetKernelArg(kernel, index, Sizeof.cl_mem, Pointer.to((NativePointerObject) mem));
    }

    /**
     * Creates a read-only device buffer of {@code bytes} bytes and fills it from {@code source}
     * with a blocking write. The buffer is released if the write fails.
     */
    private cl_mem createReadOnlyBuffer(Pointer source, long bytes) {
        cl_mem mem = CL.clCreateBuffer(this.context, CL.CL_MEM_READ_ONLY, bytes, null, null);
        try {
            CL.clEnqueueWriteBuffer(this.command_queue, mem, true, 0L, bytes, source, 0, null, null);
        } catch (RuntimeException e) {
            CL.clReleaseMemObject(mem);
            throw e;
        }
        return mem;
    }

    /**
     * Adds {@code src} into the {@code dw} arrays of the first {@code count} matrices.
     * Consumes {@code perMatrix} consecutive values per matrix.
     */
    private static void accumulateGradients(float[] src, Matrix[] targets, int count, int perMatrix) {
        int pos = 0;
        for (int i = 0; i < count; i++) {
            double[] dw = targets[i].dw;
            for (int j = 0; j < perMatrix; j++) {
                dw[j] += src[pos++];
            }
        }
    }

    /**
     * Backward pass for {@link #fullConv}. Adds the input gradient into {@code t}'s matrices and
     * the filter gradient into {@code filters}.
     *
     * <p>If {@code gradientMultipliers} is non-null, {@code out}'s gradients are first multiplied
     * element-wise, in place, by it. Buffers {@code memT}, and {@code memFs} when
     * {@code doReleaseMem} is set, are returned to the pool on exit, including on failure.
     */
    private synchronized void internalBackpropFullConv(int pad, int stride, int numFiltersPerDepth, BufferInfo memT, Tensor t, Tensor out, BufferInfo memFs, Matrix[] filters, boolean doReleaseMem, Tensor gradientMultipliers) throws Exception {
        if (gradientMultipliers != null) {
            for (int i = 0; i < out.depth; i++) {
                double[] dw = out.matrices[i].dw;
                double[] mul = gradientMultipliers.matrices[i].w;
                for (int j = 0; j < dw.length; j++) {
                    dw[j] = dw[j] * mul[j];
                }
            }
        }
        int filterRows = filters[0].rows;
        int filterCols = filters[0].cols;
        int inputSize = t.width * t.height * t.depth;
        int filterGradSize = filterCols * filterRows * filters.length;
        long inputBytes = (long) inputSize * Sizeof.cl_float;
        long filterBytes = (long) filterGradSize * Sizeof.cl_float;

        BufferInfo memDW = null;
        BufferInfo memOutDW = null;
        BufferInfo memFsDW = null;
        cl_mem memInfo = null;
        try {
            memDW = this.putGradientMatricesInBuffer(true, out.matrices);
            memOutDW = this.tryReuseOrCreate(false, inputBytes);
            memFsDW = this.tryReuseOrCreate(false, filterBytes);
            // The kernels accumulate into these buffers, so zero them first.
            CL.clEnqueueFillBuffer(this.command_queue, memOutDW.getBuffer(), Pointer.to(this.zerosF), 1L, 0L, inputBytes, 0, null, null);
            CL.clEnqueueFillBuffer(this.command_queue, memFsDW.getBuffer(), Pointer.to(this.zerosF), 1L, 0L, filterBytes, 0, null, null);

            int x = t.width + pad * 2 - filterRows + 1;
            int y = t.height + pad * 2 - filterCols + 1;
            int[] info = new int[]{t.width, t.height, t.depth, out.width, out.height, filterRows, filterCols, x, y, pad, stride, numFiltersPerDepth, t.width + 2 * pad, t.height + 2 * pad};
            memInfo = this.createReadOnlyBuffer(Pointer.to(info), (long) info.length * Sizeof.cl_int);

            // Input-gradient kernel: info, filters, output gradient, input gradient (read back below).
            setMemArg(this.convBackwardInputsKernel, 0, memInfo);
            setMemArg(this.convBackwardInputsKernel, 1, memFs.getBuffer());
            setMemArg(this.convBackwardInputsKernel, 2, memDW.getBuffer());
            setMemArg(this.convBackwardInputsKernel, 3, memOutDW.getBuffer());
            // Filter-gradient kernel: info, input, filters, output gradient, filter gradient (read back below).
            setMemArg(this.convBackwardFiltersKernel, 0, memInfo);
            setMemArg(this.convBackwardFiltersKernel, 1, memT.getBuffer());
            setMemArg(this.convBackwardFiltersKernel, 2, memFs.getBuffer());
            setMemArg(this.convBackwardFiltersKernel, 3, memDW.getBuffer());
            setMemArg(this.convBackwardFiltersKernel, 4, memFsDW.getBuffer());
            CL.clFinish(this.command_queue);

            long[] globalSizesFilters = new long[]{numFiltersPerDepth, t.depth, filterRows * filterCols};
            long[] globalSizesInputs = new long[]{t.depth, (x + filterRows) * (y + filterCols)};
            CL.clEnqueueNDRangeKernel(this.command_queue, this.convBackwardInputsKernel, globalSizesInputs.length, null, globalSizesInputs, null, 0, null, null);
            if (DISABLE_CONCURRENT) {
                CL.clFinish(this.command_queue);
            }
            CL.clEnqueueNDRangeKernel(this.command_queue, this.convBackwardFiltersKernel, globalSizesFilters.length, null, globalSizesFilters, null, 0, null, null);
            CL.clFinish(this.command_queue);

            float[] inputGrad = new float[inputSize];
            CL.clEnqueueReadBuffer(this.command_queue, memOutDW.getBuffer(), true, 0L, inputBytes, Pointer.to(inputGrad), 0, null, null);
            accumulateGradients(inputGrad, t.matrices, t.depth, t.width * t.height);

            float[] filterGrad = new float[filterGradSize];
            CL.clEnqueueReadBuffer(this.command_queue, memFsDW.getBuffer(), true, 0L, filterBytes, Pointer.to(filterGrad), 0, null, null);
            accumulateGradients(filterGrad, filters, filters.length, filterRows * filterCols);
            CL.clFinish(this.command_queue);
        } finally {
            if (memInfo != null) {
                CL.clReleaseMemObject(memInfo);
            }
            this.unloadBuffer(memT);
            if (doReleaseMem) {
                this.unloadBuffer(memFs);
            }
            if (memOutDW != null) {
                this.unloadBuffer(memOutDW);
            }
            if (memDW != null) {
                this.unloadBuffer(memDW);
            }
            if (memFsDW != null) {
                this.unloadBuffer(memFsDW);
            }
        }
    }

    /**
     * Returns whether {@code n} can be fused into the forward convolution kernel.
     * ReLU variants ({@link ReLuUnit}, {@link RectifiedLinearUnit}) qualify only when
     * {@link #ONLY_ACCELL_SLOW_NONLIN} is false.
     */
    public boolean isSupportedNonlinearity(Nonlinearity n) {
        boolean isReLU = n instanceof ReLuUnit || n instanceof RectifiedLinearUnit;
        if (isReLU) {
            return !ONLY_ACCELL_SLOW_NONLIN;
        }
        return n instanceof LinearUnit || n instanceof SigmoidUnit || n instanceof SineUnit || n instanceof TanhUnit || n instanceof ExponentialLinearUnit;
    }

    /**
     * Value passed to ReLU/ELU kernels as their extra argument. Uses the slope for
     * {@link RectifiedLinearUnit}/{@link ExponentialLinearUnit} and 0 for {@link ReLuUnit}.
     * When several types match, precedence is ReLuUnit, then ELU, then RectifiedLinearUnit.
     */
    private static float nonlinArgument(Nonlinearity nonlin) {
        if (nonlin instanceof ReLuUnit) {
            return 0.0f;
        }
        if (nonlin instanceof ExponentialLinearUnit) {
            return (float) ((ExponentialLinearUnit) nonlin).slope;
        }
        if (nonlin instanceof RectifiedLinearUnit) {
            return (float) ((RectifiedLinearUnit) nonlin).slope;
        }
        return 0.0f;
    }

    /** Picks the forward kernel; falls back to the linear kernel when nothing is fused. */
    private cl_kernel selectForwardKernel(Nonlinearity nonlin, boolean doNonlin) {
        if (!doNonlin) {
            return this.fullConvKernel;
        }
        if (nonlin instanceof ExponentialLinearUnit) {
            return this.fullConvELUKernel;
        }
        if (nonlin instanceof TanhUnit) {
            return this.fullConvTanhKernel;
        }
        if (nonlin instanceof SineUnit) {
            return this.fullConvSineKernel;
        }
        if (nonlin instanceof SigmoidUnit) {
            return this.fullConvSigmoidKernel;
        }
        if (nonlin instanceof RectifiedLinearUnit || nonlin instanceof ReLuUnit) {
            return this.fullConvReLUKernel;
        }
        return this.fullConvKernel;
    }

    /**
     * Adds the per-channel sum of {@code out}'s gradient to {@code bias.dw}. When
     * {@code gradMuls} is non-null, each element is first scaled by the matching multiplier.
     */
    private static void accumulateBiasGradient(Matrix bias, Tensor out, Tensor gradMuls, int outDepth) {
        for (int d = 0; d < outDepth; d++) {
            Matrix channel = out.matrices[d];
            double[] mul = gradMuls == null ? null : gradMuls.matrices[d].w;
            for (int j = 0; j < channel.w.length; j++) {
                bias.dw[d] = bias.dw[d] + (mul == null ? channel.dw[j] : channel.dw[j] * mul[j]);
            }
        }
    }

    /**
     * Computes a full convolution of {@code t} with {@code filters} plus {@code bias} on the
     * device. The nonlinearity is fused into the kernel when
     * {@link #isSupportedNonlinearity(Nonlinearity)} allows it.
     *
     * <p>{@code filters.length} must equal {@code t.depth * outDepth}. When {@code doBackprop}
     * is set and the graph applies backprop, a backward step is registered. That step
     * accumulates gradients into {@code bias.dw}, {@code t}'s matrices and {@code filters}.
     *
     * @param parameterUUID if non-null, identifies the filter/bias pair so their device buffers
     *                      can be cached and reused across calls
     * @return the output tensor. If the nonlinearity could not be fused (and is not linear), a
     *         new tensor produced by {@code Graph.nonlin} on each channel.
     * @throws Exception if the filter count is inconsistent or an OpenCL call fails
     */
    public synchronized Tensor fullConv(boolean doBackprop, final int pad, final int stride, final int outDepth, final Tensor t, final Matrix bias, final Matrix[] filters, Nonlinearity nonlin, UUID parameterUUID) throws Exception {
        if (filters.length != t.depth * outDepth) {
            throw new Exception("Filter count does not match input depth times out depth");
        }
        final Tensor out = new Tensor((t.width + pad * 2 - filters[0].cols) / stride + 1, (t.height + pad * 2 - filters[0].rows) / stride + 1, outDepth);
        final boolean doNonlin = this.isSupportedNonlinearity(nonlin) && !(nonlin instanceof LinearUnit);
        final Tensor gradMuls = doNonlin ? new Tensor(out.width, out.height, out.depth, false) : null;
        final int outSize = out.width * out.height * out.depth;
        final long outBytes = (long) outSize * Sizeof.cl_float;

        final BufferInfo memT = this.putMatricesInBuffer(true, pad, t.matrices);
        final BufferInfo memFs;
        final BufferInfo memB;
        final boolean hasUUID;
        BufferInfo[] cached = parameterUUID != null ? this.paramBuffer.get(parameterUUID) : null;
        if (cached != null) {
            memFs = cached[0];
            memB = cached[1];
            hasUUID = true;
        } else {
            memFs = this.putMatricesInBuffer(true, filters);
            memB = this.putMatricesInBuffer(true, bias);
            hasUUID = parameterUUID != null && (doBackprop && this.isApplyingBackprop() || this.forceKeepParameters);
            if (hasUUID) {
                this.paramBuffer.put(parameterUUID, new BufferInfo[]{memFs, memB});
            }
        }
        final boolean keepForBackprop = doBackprop && this.applyBackprop;

        BufferInfo memOut = this.tryReuseOrCreate(false, outBytes);
        BufferInfo memMuls = null;
        cl_mem memNonlinArg = null;
        cl_mem memInfo = null;
        float[] arr = new float[outSize];
        // Derivative multipliers are only needed when a backward step will consume them.
        float[] mulArr = doNonlin && keepForBackprop ? new float[outSize] : null;
        try {
            if (doNonlin) {
                memMuls = this.tryReuseOrCreate(false, outBytes);
                if (nonlin instanceof RectifiedLinearUnit || nonlin instanceof ExponentialLinearUnit || nonlin instanceof ReLuUnit) {
                    memNonlinArg = this.createReadOnlyBuffer(Pointer.to(new float[]{nonlinArgument(nonlin)}), Sizeof.cl_float);
                }
            }
            int[] info = new int[]{t.width + 2 * pad, t.height + 2 * pad, t.depth, out.width, out.height, filters[0].rows, filters[0].cols, stride};
            memInfo = this.createReadOnlyBuffer(Pointer.to(info), (long) info.length * Sizeof.cl_int);

            cl_kernel kernelToUse = this.selectForwardKernel(nonlin, doNonlin);
            setMemArg(kernelToUse, 0, memInfo);
            setMemArg(kernelToUse, 1, memT.getBuffer());
            setMemArg(kernelToUse, 2, memFs.getBuffer());
            setMemArg(kernelToUse, 3, memB.getBuffer());
            setMemArg(kernelToUse, 4, memOut.getBuffer());
            if (doNonlin) {
                setMemArg(kernelToUse, 5, memMuls.getBuffer());
                if (memNonlinArg != null) {
                    setMemArg(kernelToUse, 6, memNonlinArg);
                }
            }
            CL.clFinish(this.command_queue);

            long[] globalSizes = new long[]{outDepth, out.height * out.width};
            cl_event event = new cl_event();
            CL.clEnqueueNDRangeKernel(this.command_queue, kernelToUse, globalSizes.length, null, globalSizes, null, 0, null, event);
            try {
                CL.clWaitForEvents(1, new cl_event[]{event});
            } finally {
                // Event objects are reference counted; without this every forward pass leaks one.
                CL.clReleaseEvent(event);
            }

            CL.clEnqueueReadBuffer(this.command_queue, memOut.getBuffer(), true, 0L, outBytes, Pointer.to(arr), 0, null, null);
            if (mulArr != null) {
                CL.clEnqueueReadBuffer(this.command_queue, memMuls.getBuffer(), true, 0L, outBytes, Pointer.to(mulArr), 0, null, null);
            }
            for (int i = 0; i < out.depth; i++) {
                int startIndx = i * out.width * out.height;
                double[] mw = out.matrices[i].w;
                for (int j = 0; j < mw.length; j++) {
                    mw[j] = arr[startIndx + j];
                }
                if (mulArr != null) {
                    double[] gw = gradMuls.matrices[i].w;
                    for (int j = 0; j < gw.length; j++) {
                        gw[j] = mulArr[startIndx + j];
                    }
                }
            }
            CL.clFinish(this.command_queue);
        } finally {
            if (memInfo != null) {
                CL.clReleaseMemObject(memInfo);
            }
            if (memNonlinArg != null) {
                CL.clReleaseMemObject(memNonlinArg);
            }
        }

        if (doNonlin) {
            this.unloadBuffer(memMuls);
        }
        this.unloadBuffer(memOut);
        if (keepForBackprop) {
            // memT and memFs are still needed by the backward step, which returns them to the pool.
            if (!hasUUID) {
                this.unloadBuffer(memB);
            }
            this.addBackprop(new Runnable() {

                @Override
                public void run() {
                    // internalBackpropFullConv scales out.dw by the fused nonlinearity's
                    // multipliers (gradMuls) before propagating, so the bias gradient must be
                    // scaled the same way. Previously it used the unscaled gradient.
                    // NOTE(review): assumes conv.cl adds the bias before the nonlinearity; confirm.
                    accumulateBiasGradient(bias, out, gradMuls, outDepth);
                    try {
                        GPUGraph.this.internalBackpropFullConv(pad, stride, outDepth, memT, t, out, memFs, filters, !hasUUID, gradMuls);
                    }
                    catch (Exception e) {
                        GPUGraph.this.backpropException = e;
                    }
                }
            });
        } else {
            this.unloadBuffer(memT);
            if (!hasUUID) {
                this.unloadBuffer(memFs);
                this.unloadBuffer(memB);
            }
        }

        if (!doNonlin && !(nonlin instanceof LinearUnit)) {
            // The nonlinearity was not fused, so apply it on the host per channel.
            Tensor finalOut = new Tensor(out.width, out.height, out.depth);
            for (int i = 0; i < out.depth; i++) {
                finalOut.matrices[i] = super.nonlin(nonlin, out.matrices[i]);
            }
            return finalOut;
        }
        return out;
    }

    /** Allocates a new device buffer of {@code size} bytes (read-only or read-write). */
    private BufferInfo createBuffer(boolean readOnly, long size) {
        cl_mem mem = CL.clCreateBuffer(this.context, readOnly ? CL.CL_MEM_READ_ONLY : CL.CL_MEM_READ_WRITE, size, null, null);
        return new BufferInfo(mem, readOnly, size);
    }

    /** Takes a matching buffer from the pool, or allocates a new one if none matches. */
    private BufferInfo tryReuseOrCreate(boolean readOnly, long size) {
        BufferInfo inf = this.tryReuseBuffer(size, readOnly);
        if (inf == null) {
            return this.createBuffer(readOnly, size);
        }
        return inf;
    }

    /**
     * Packs the first {@code rows * cols} entries (taken from {@code ms[0]}) of each matrix
     * back to back. Reads {@code dw} when {@code gradients} is set, otherwise {@code w}.
     */
    private static float[] flattenMatrices(Matrix[] ms, boolean gradients) {
        int matrixSize = ms[0].rows * ms[0].cols;
        float[] arr = new float[matrixSize * ms.length];
        for (int i = 0; i < ms.length; i++) {
            double[] src = gradients ? ms[i].dw : ms[i].w;
            int startIndx = i * matrixSize;
            for (int j = 0; j < matrixSize; j++) {
                arr[startIndx + j] = (float) src[j];
            }
        }
        return arr;
    }

    /** Copies {@code data} into a pooled or new device buffer with a blocking write. */
    private BufferInfo uploadFloats(boolean readOnly, float[] data) {
        long bytes = (long) data.length * Sizeof.cl_float;
        BufferInfo mem = this.tryReuseOrCreate(readOnly, bytes);
        CL.clEnqueueWriteBuffer(this.command_queue, mem.getBuffer(), true, 0L, bytes, Pointer.to(data), 0, null, null);
        return mem;
    }

    /** Uploads the weights ({@code w}) of {@code ms} into one contiguous device buffer. */
    private BufferInfo putMatricesInBuffer(boolean readOnly, Matrix ... ms) throws Exception {
        return this.uploadFloats(readOnly, flattenMatrices(ms, false));
    }

    /** Uploads the gradients ({@code dw}) of {@code ms} into one contiguous device buffer. */
    private BufferInfo putGradientMatricesInBuffer(boolean readOnly, Matrix ... ms) throws Exception {
        return this.uploadFloats(readOnly, flattenMatrices(ms, true));
    }

    /**
     * Uploads the weights of {@code ms}, each surrounded by a zero border {@code pad} wide.
     * With {@code pad == 0} this is the same as the unpadded overload.
     */
    private BufferInfo putMatricesInBuffer(boolean readOnly, int pad, Matrix ... ms) throws Exception {
        if (pad == 0) {
            return this.putMatricesInBuffer(readOnly, ms);
        }
        int matrixSize = (ms[0].rows + 2 * pad) * (ms[0].cols + 2 * pad);
        float[] arr = new float[matrixSize * ms.length];
        for (int i = 0; i < ms.length; i++) {
            Matrix m = ms[i];
            double[] mw = m.w;
            int paddedRowLength = m.cols + 2 * pad;
            for (int r = 0; r < m.rows; r++) {
                int dst = i * matrixSize + (r + pad) * paddedRowLength + pad;
                int src = r * m.cols;
                for (int c = 0; c < m.cols; c++) {
                    arr[dst + c] = (float) mw[src + c];
                }
            }
        }
        return this.uploadFloats(readOnly, arr);
    }

    /** Returns a buffer to the reuse pool and resets its idle counter. */
    private void unloadBuffer(BufferInfo buffer) {
        this.unusedMemBuffer.add(buffer);
        this.unusedMemCntr.put(buffer, 0);
    }

    /**
     * Removes and returns the first pooled buffer with exactly this size and flag.
     * Returns null if none matches.
     */
    private BufferInfo tryReuseBuffer(long bufferSize, boolean readOnly) {
        for (int i = 0; i < this.unusedMemBuffer.size(); i++) {
            BufferInfo candidate = this.unusedMemBuffer.get(i);
            if (candidate.isReadOnly() == readOnly && candidate.getSize() == bufferSize) {
                this.unusedMemBuffer.remove(i);
                this.unusedMemCntr.remove(candidate);
                return candidate;
            }
        }
        return null;
    }

    /** Releases every pooled (currently unused) device buffer. */
    public void cleanMemory() {
        for (BufferInfo inf : this.unusedMemBuffer) {
            CL.clReleaseMemObject(inf.getBuffer());
        }
        this.unusedMemBuffer.clear();
        this.unusedMemCntr.clear();
    }

    /**
     * Releases all OpenCL resources held by this graph. This includes cached parameter
     * buffers, pooled buffers, kernels, programs, and the command queue and context supplied
     * to the constructor.
     */
    @Override
    public void cleanUp() {
        super.cleanUp();
        // Release memory objects before the context they belong to.
        for (BufferInfo[] mems : this.paramBuffer.values()) {
            for (BufferInfo mem : mems) {
                CL.clReleaseMemObject(mem.getBuffer());
            }
        }
        this.paramBuffer.clear();
        this.cleanMemory();
        CL.clReleaseKernel(this.fullConvKernel);
        CL.clReleaseKernel(this.fullConvReLUKernel);
        CL.clReleaseKernel(this.fullConvSigmoidKernel);
        CL.clReleaseKernel(this.fullConvSineKernel);
        CL.clReleaseKernel(this.fullConvTanhKernel);
        CL.clReleaseKernel(this.fullConvELUKernel);
        CL.clReleaseKernel(this.convBackwardFiltersKernel);
        CL.clReleaseKernel(this.convBackwardInputsKernel);
        CL.clReleaseProgram(this.fullConvProgram);
        CL.clReleaseProgram(this.fullConvBackwardProgram);
        CL.clReleaseCommandQueue(this.command_queue);
        CL.clReleaseContext(this.context);
    }

    /** Moves all cached parameter buffers into the reuse pool and clears the UUID cache. */
    private void releaseParameterBuffers() {
        for (BufferInfo[] mems : this.paramBuffer.values()) {
            for (BufferInfo mem : mems) {
                this.unloadBuffer(mem);
            }
        }
        this.paramBuffer.clear();
    }

    /**
     * Ages every pooled buffer by one pass. Releases buffers that were already idle for more
     * than {@link #MAX_UNUSED_COUNT} passes, together with their counter entries.
     */
    private void ageUnusedBuffers() {
        Iterator<BufferInfo> it = this.unusedMemBuffer.iterator();
        while (it.hasNext()) {
            BufferInfo mem = it.next();
            Integer seen = this.unusedMemCntr.get(mem);
            int count = seen == null ? 0 : seen;
            if (count > MAX_UNUSED_COUNT) {
                CL.clReleaseMemObject(mem.getBuffer());
                it.remove();
                this.unusedMemCntr.remove(mem);
            } else {
                this.unusedMemCntr.put(mem, count + 1);
            }
        }
    }

    /** Resets backprop state; drops cached parameter buffers unless they are being force-kept. */
    @Override
    public void resetBackprop() {
        super.resetBackprop();
        if (!this.forceKeepParameters) {
            this.releaseParameterBuffers();
        }
    }

    /**
     * Controls whether filter/bias buffers cached by UUID survive {@link #backward()} and
     * {@link #resetBackprop()}. Turning it off returns the cached buffers to the pool at once.
     */
    public void forceKeepParameters(boolean force) {
        this.forceKeepParameters = force;
        if (!force) {
            this.releaseParameterBuffers();
        }
    }

    /**
     * Runs the registered backward steps, then recycles parameter buffers (unless force-kept)
     * and ages the buffer pool. An exception from the backward steps is rethrown after this
     * housekeeping completes.
     */
    @Override
    public void backward() throws Exception {
        Exception ee = null;
        try {
            super.backward();
        }
        catch (Exception e) {
            ee = e;
        }
        if (!this.forceKeepParameters) {
            this.releaseParameterBuffers();
        }
        this.ageUnusedBuffers();
        if (ee != null) {
            throw ee;
        }
    }

    /**
     * Reads an OpenCL program from a classpath resource and applies textual replacements.
     *
     * <p>Lines between a line containing {@code begin_replacements} and one containing
     * {@code end_replacements} must have the form {@code key;value}. They are not part of the
     * program. Every later line has each key replaced by its value, applied in
     * {@link HashMap} iteration order.
     *
     * @throws Exception if the resource is missing, a replacement line is malformed, or the
     *                   resulting program is empty
     */
    private String readProgram(String loc) throws Exception {
        InputStream in = this.getClass().getResourceAsStream(loc);
        if (in == null) {
            throw new Exception("OpenCL program resource not found: " + loc);
        }
        Map<String, String> replacementMap = new HashMap<String, String>();
        StringBuilder fullProgram = new StringBuilder();
        boolean inReplacements = false;
        try (BufferedReader reader = new BufferedReader(new InputStreamReader(in, StandardCharsets.UTF_8))) {
            String line;
            while ((line = reader.readLine()) != null) {
                if (inReplacements) {
                    if (line.contains("end_replacements")) {
                        inReplacements = false;
                        continue;
                    }
                    String[] parts = line.split(";");
                    if (parts.length != 2) {
                        throw new Exception("Invalid syntax for replacement: " + line);
                    }
                    replacementMap.put(parts[0], parts[1]);
                    continue;
                }
                if (line.contains("begin_replacements")) {
                    inReplacements = true;
                    continue;
                }
                for (Map.Entry<String, String> replacement : replacementMap.entrySet()) {
                    line = line.replace(replacement.getKey(), replacement.getValue());
                }
                fullProgram.append(line).append('\n');
            }
        }
        String program = fullProgram.toString().trim();
        if (program.isEmpty()) {
            throw new Exception("Empty program");
        }
        return program;
    }

    /**
     * A pooled device buffer plus the properties the pool matches on: the read-only flag and
     * the byte size. Used as a {@link HashMap} key, so {@code equals} and {@code hashCode} are
     * kept consistent.
     */
    private static final class BufferInfo {
        private final cl_mem buffer;
        private final boolean readOnly;
        private final long size;

        BufferInfo(cl_mem buffer, boolean readOnly, long size) {
            this.buffer = buffer;
            this.readOnly = readOnly;
            this.size = size;
        }

        public cl_mem getBuffer() {
            return this.buffer;
        }

        public boolean isReadOnly() {
            return this.readOnly;
        }

        public long getSize() {
            return this.size;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof BufferInfo)) {
                return false;
            }
            BufferInfo a = (BufferInfo) o;
            return a.size == this.size && a.readOnly == this.readOnly && a.buffer.equals(this.buffer);
        }

        @Override
        public int hashCode() {
            return Objects.hash(this.buffer, this.readOnly, this.size);
        }
    }
}

