//Base functions
// Computes one filter-weight gradient per work-item.
//
// The work-item selects a filter (dimension 0), an input channel (dimension 1)
// and a position inside the filter window (dimension 2). It then accumulates
//     in[matrixIndx + inWidth * (k * stride + m) + n + l * stride]
//       * grads[glob0 * outWidth * outHeight + k * outWidth + l]
// over every output row k and output column l, and stores the sum in
// filterGrads at the flat index of that filter weight.
//
// info        - not referenced directly in this kernel.
//               NOTE(review): numFilters, inDepth, inWidth, inHeight,
//               filterWidth, filterHeight, outWidth, outHeight, stride and
//               wrap3 are defined outside this block; presumably they read
//               from info - confirm.
// in          - input values, read with the layout shown above.
// filters     - not read by this kernel.
// grads       - gradients of the output, outWidth * outHeight values per filter.
// filterGrads - receives one value per (filter, channel, window position).
__kernel void convBackwardFilters(__constant int *info, __global const float *in, __constant float *filters, __global const float *grads, __global float *filterGrads){

    // Ids are wrapped with a modulo, so ids beyond the valid range map back
    // onto valid indices instead of addressing past the buffers.
    int glob0 = get_global_id(0) % numFilters;
    int glob1 = get_global_id(1) % inDepth;
    int glob2 = get_global_id(2) % wrap3;

    int matrixIndx = glob1 * (inWidth * inHeight);
    int filterIndx = (glob0 * inDepth + glob1) * (filterWidth * filterHeight);

    // Row (m) and column (n) of this work-item inside the filter window.
    int m = glob2 / filterWidth;
    int n = glob2 % filterWidth;

    int gradIndx = filterIndx + filterWidth * m + n;

    // Each output row is consumed as `quads` groups of four columns followed
    // by `rest` (0..3) leftover columns. A non-positive outWidth means there
    // is nothing to accumulate, so both counts are forced to zero.
    int quads = outWidth > 0 ? outWidth / 4 : 0;
    int rest = outWidth > 0 ? (outWidth & 3) : 0; // = outWidth % 4

    float res = 0;

    // Running index into grads; it advances by exactly outWidth per row.
    int outputIndx = glob0 * (outWidth * outHeight);
    for(int k = 0; k < outHeight; k++) {
        // First input element that meets this filter weight in output row k.
        int indx1 = matrixIndx + inWidth * (k * stride + m) + n;

        for(int l = 0; l < quads; l++){
            float4 vf = (float4)(in[indx1], in[indx1 + stride], in[indx1 + stride * 2], in[indx1 + stride * 3]);
            float4 vf2 = (float4)(grads[outputIndx], grads[outputIndx + 1], grads[outputIndx + 2], grads[outputIndx + 3]);
            res += dot(vf, vf2);
            indx1 += 4 * stride;
            outputIndx += 4;
        }

        // Leftover columns of the row; unused vector lanes are zero-filled.
        if(rest == 3){
            float4 vf = (float4)(in[indx1], in[indx1 + stride], in[indx1 + stride * 2], 0);
            float4 vf2 = (float4)(grads[outputIndx], grads[outputIndx + 1], grads[outputIndx + 2], 0);
            res += dot(vf, vf2);
        }else if(rest == 2){
            float2 vf = (float2)(in[indx1], in[indx1 + stride]);
            float2 vf2 = (float2)(grads[outputIndx], grads[outputIndx + 1]);
            res += dot(vf, vf2);
        }else if(rest == 1){
            res += in[indx1] * grads[outputIndx];
        }
        outputIndx += rest;
    }
    filterGrads[gradIndx] = res;
}

//TODO: Optimize the hecking heck heck out of this
// Computes one input-gradient value per work-item.
//
// Dimension 0 selects the input channel (numj). Dimension 1 is a flat index
// into a grid that is (ax2 + filterWidth) wide, giving the position (y2, x2).
// Positions for which (y2 - pad, x2 - pad) falls outside
// [0, inHeight2) x [0, inWidth2) return without writing anything.
//
// For every filter numi and every filter row m in
// loopStartM, loopStartM + stride, ... (< loopEndM), the kernel pairs
//     filters[filterIndx + filterWidth * m + loopStartN + i * stride]
// with
//     grads[numi * outWidth * outHeight + outWidth * ((y2 - m) / stride)
//           + (x2 - loopStartN) / stride - i]
// for i = 0 .. loopCnt - 1 and accumulates the products. The sum is stored at
// outGrad[matrixIndx + inWidth2 * (y2 - pad) + x2 - pad].
//
// info    - not referenced directly in this kernel.
//           NOTE(review): inDepth, numFilters, filterWidth, filterHeight,
//           outWidth, outHeight, stride, pad, ax2, wrap2, inWidth2 and
//           inHeight2 are defined outside this block; presumably they read
//           from info - confirm.
// filters - filter weights, filterWidth * filterHeight values per
//           (filter, channel) pair.
// grads   - gradients of the output, outWidth * outHeight values per filter.
// outGrad - receives the accumulated gradient for each in-range position.
__kernel void convBackwardInput(__constant int *info, __constant float *filters __attribute__((max_constant_size(4096))), __global const float *grads, __global float *outGrad){

    // Ids are wrapped with a modulo so they always map onto valid indices.
    int numj = get_global_id(0) % inDepth;
    int glob1 = get_global_id(1) % wrap2;

    int y2 = glob1 / (ax2 + filterWidth);
    int x2 = glob1 % (ax2 + filterWidth);

    if(y2 - pad < 0 || x2 - pad < 0 || y2 - pad >= inHeight2 || x2 - pad >= inWidth2) return;

    int matrixIndx = numj * (inWidth2 * inHeight2);

    // Smallest filter column/row congruent to x2/y2 modulo stride whose
    // matching output column/row index is below outWidth/outHeight.
    int loopStartN = x2 % stride;
    while((x2 - loopStartN) / stride >= outWidth) loopStartN += stride;
    int loopStartM = y2 % stride;
    while((y2 - loopStartM) / stride >= outHeight) loopStartM += stride;

    int loopEndN = min(filterWidth, x2 + 1);
    int loopEndM = min(filterHeight, y2 + 1);

    // Number of filter columns visited per filter row, i.e. how many values
    // loopStartN, loopStartN + stride, ... lie below loopEndN. An empty or
    // negative span must contribute no terms at all.
    int span = loopEndN - loopStartN;
    int loopCnt = span > 0 ? (span + stride - 1) / stride : 0;
    int quads = loopCnt / 4;
    int rest = loopCnt & 3; // = loopCnt % 4

    float res = 0;

    for(int numi = 0; numi < numFilters; numi++) {
        int filterIndx = (numi * inDepth + numj) * (filterWidth * filterHeight);
        int gradsIndxBaseBase = numi * (outWidth * outHeight) + ((x2 - loopStartN) / stride);
        for(int m = loopStartM; m < loopEndM; m+=stride) {

            int k = y2 - m;

            // gradPos walks backwards one element at a time while filterPos
            // walks forwards by stride. Every access uses an explicit offset,
            // so the pairing does not depend on the order in which the vector
            // components are evaluated.
            int gradPos = gradsIndxBaseBase + outWidth * (k / stride);
            int filterPos = filterIndx + filterWidth * m + loopStartN;

            for(int q = 0; q < quads; q++){
                float4 gv = (float4)(grads[gradPos], grads[gradPos - 1], grads[gradPos - 2], grads[gradPos - 3]);
                float4 fv = (float4)(filters[filterPos], filters[filterPos + stride], filters[filterPos + stride * 2], filters[filterPos + stride * 3]);
                res += dot(gv, fv);
                gradPos -= 4;
                filterPos += 4 * stride;
            }

            // Leftover terms; unused vector lanes are zero-filled.
            if(rest == 3){
                float4 gv = (float4)(grads[gradPos], grads[gradPos - 1], grads[gradPos - 2], 0);
                float4 fv = (float4)(filters[filterPos], filters[filterPos + stride], filters[filterPos + stride * 2], 0);
                res += dot(gv, fv);
            }else if(rest == 2){
                float4 gv = (float4)(grads[gradPos], grads[gradPos - 1], 0, 0);
                float4 fv = (float4)(filters[filterPos], filters[filterPos + stride], 0, 0);
                res += dot(gv, fv);
            }else if(rest == 1){
                res += grads[gradPos] * filters[filterPos];
            }
        }
    }

    outGrad[matrixIndx + inWidth2 * (y2 - pad) + x2 - pad] = res;
}