/*begin_replacements
inWidth;info[0]
inHeight;info[1]
inDepth;info[2]
outWidth;info[3]
outHeight;info[4]
filterHeight;info[5]
filterWidth;info[6]
stride;info[7]
numi;get_global_id(0)
end_replacements*/
//The replacement table above makes the loader replace certain keywords in this source with other text before compilation.
//This keeps the code readable while keeping it efficient: no function has to create a local variable for each keyword,
//which means less memory usage, direct accesses to the constants, and no need to keep looking up which index in info stands for which parameter.
//e.g. every occurrence of inWidth will be replaced by info[0].
//On each line, the semicolon separates the word to be replaced from its replacement.

//Base function: computes the pre-activation value of a single output element.
//For every input channel it takes the dot product of the filterHeight x filterWidth filter
//window with the input window whose top-left corner is at row k*stride, column l*stride.
//It sums these over all channels and adds the bias of output channel numi.
//  info    - layout parameters, accessed through the replacement keywords at the top of the file
//  in      - input, addressed as inDepth consecutive planes of inHeight x inWidth floats (row-major)
//  filters - weights; output channel numi owns inDepth consecutive planes of
//            filterHeight x filterWidth floats starting at numi * inDepth * filterWidth * filterHeight
//  bias    - indexed by output channel numi
//  k, l    - output row and output column
//No bounds checking is done; the caller must ensure the window lies inside the input.
float conv(__constant int *info, __global const float *in, __constant float *filters, __global const float *bias, int k, int l) {
    float res = 0.0f;

    //Each filter row is consumed as filterWidth / 4 float4 chunks, followed by a
    //tail of 0-3 elements. For filterWidth >= 1 this accumulates in the same order as
    //before, without special-casing filterWidth == 4. It also never reads past the
    //row when filterWidth is 0.
    int vecCount = filterWidth / 4;
    int tail = filterWidth & 3; // = filterWidth % 4 for non-negative widths

    //offset of the window's top-left corner inside one input plane
    int windowOffset = inWidth * (k * stride) + l * stride;

    int planeIndx = 0;
    int filterIndx = (numi * inDepth) * (filterWidth * filterHeight);
    for(int numj = 0; numj < inDepth; numj++){
        int inRow = planeIndx + windowOffset;
        int filtRow = filterIndx;
        for(int m = 0; m < filterHeight; m++) {
            int start = 0;
            for(int n = 0; n < vecCount; n++){
                res += dot(vload4(0, in + inRow + start), vload4(0, filters + filtRow + start));
                start += 4;
            }
            if(tail == 3){
                res += dot(vload3(0, in + inRow + start), vload3(0, filters + filtRow + start));
            }else if(tail == 2){
                res += dot(vload2(0, in + inRow + start), vload2(0, filters + filtRow + start));
            }else if(tail == 1){
                res += in[inRow + start] * filters[filtRow + start];
            }
            //next row of the input window / of the filter
            inRow += inWidth;
            filtRow += filterWidth;
        }
        //next input channel and its matching filter plane
        planeIndx += inWidth * inHeight;
        filterIndx += filterWidth * filterHeight;
    }
    return res + bias[numi];
}

//Identity activation: stores the raw convolution result.
//Dimension 0 of the NDRange selects the output channel (numi). Dimension 1 is the
//flattened spatial position p = row * outWidth + col inside that channel's output plane.
//No bounds check is performed; presumably the host sizes dimension 1 to outWidth * outHeight (verify).
__kernel void conv_linear(__constant int *info, __global const float *in, __constant float *filters __attribute__((max_constant_size(4096))), __global const float *bias, __global float *out){
    size_t pos = get_global_id(1);
    int row = pos / outWidth;
    int col = pos % outWidth;

    //one outHeight x outWidth plane per output channel, stored row-major
    int dst = numi * (outWidth * outHeight) + outWidth * row + col;

    float value = conv(info, in, filters, bias, row, col);
    out[dst] = value;
}

//tanh activation.
//Writes out = tanh(x) and gradMuls = 1 - tanh(x)^2, the derivative of tanh at x,
//where x is the convolution result.
//Dimension 0 of the NDRange is the output channel (numi); dimension 1 is the flattened
//position row * outWidth + col. No bounds check is performed.
__kernel void conv_tanh(__constant int *info, __global const float *in, __constant float *filters __attribute__((max_constant_size(4096))), __global const float *bias, __global float *out, __global float *gradMuls){

    int k = get_global_id(1) / outWidth;
    int l = get_global_id(1) % outWidth;

    int outputIndx = numi * (outWidth * outHeight) + outWidth * k + l;

    float res = conv(info, in, filters, bias, k, l);

    float t = tanh(res);
    out[outputIndx] = t;

    //1.0f (not 1.0) keeps this in single precision: an unsuffixed literal is a double,
    //which needs cl_khr_fp64 and is slow on most GPUs
    gradMuls[outputIndx] = 1.0f - t * t;
}

//Leaky ReLU activation with the negative-side slope taken from slope[0].
//out = x for x >= 0 and x * slope[0] for x < 0; gradMuls receives the matching derivative
//(1 or slope[0]). NaN inputs take the x >= 0 path, as before.
//Dimension 0 of the NDRange is the output channel (numi); dimension 1 is the flattened
//position row * outWidth + col. No bounds check is performed.
__kernel void conv_ReLU(__constant int *info, __global const float *in, __constant float *filters __attribute__((max_constant_size(4096))), __global const float *bias, __global float *out, __global float *gradMuls, __global const float *slope){

    int k = get_global_id(1) / outWidth;
    int l = get_global_id(1) % outWidth;

    int outputIndx = numi * (outWidth * outHeight) + outWidth * k + l;

    float res = conv(info, in, filters, bias, k, l);

    //Decide the gradient in a register so each global buffer is written exactly once
    //(previously gradMuls was written twice for negative inputs).
    float grad = 1.0f;
    if(res < 0){
        float s = slope[0];
        res *= s;
        grad = s;
    }

    gradMuls[outputIndx] = grad;
    out[outputIndx] = res;
}

//Logistic sigmoid activation.
//Writes out = 1 / (1 + e^-x) and gradMuls = s * (1 - s), the sigmoid derivative at x,
//where s is the sigmoid output.
//native_exp trades precision for speed (its accuracy is implementation-defined).
//Dimension 0 of the NDRange is the output channel (numi); dimension 1 is the flattened
//position row * outWidth + col. No bounds check is performed.
__kernel void conv_sigmoid(__constant int *info, __global const float *in, __constant float *filters __attribute__((max_constant_size(4096))), __global const float *bias, __global float *out, __global float *gradMuls){

    int k = get_global_id(1) / outWidth;
    int l = get_global_id(1) % outWidth;

    int outputIndx = numi * (outWidth * outHeight) + outWidth * k + l;

    float res = conv(info, in, filters, bias, k, l);

    //float literals keep the whole computation in single precision
    //(unsuffixed 1.0 is a double and forces fp64 arithmetic)
    float t = 1.0f / (1.0f + native_exp(-res));
    out[outputIndx] = t;

    gradMuls[outputIndx] = t * (1.0f - t);
}

//Sine activation.
//Writes out = sin(x) and gradMuls = cos(x), the derivative of sin at x,
//where x is the convolution result.
//The native_ variants are used for speed; their precision is implementation-defined.
//Dimension 0 of the NDRange is the output channel (numi); dimension 1 is the flattened
//position row * outWidth + col. No bounds check is performed.
__kernel void conv_sine(__constant int *info, __global const float *in, __constant float *filters __attribute__((max_constant_size(4096))), __global const float *bias, __global float *out, __global float *gradMuls){
    size_t pos = get_global_id(1);
    int row = pos / outWidth;
    int col = pos % outWidth;
    int dst = numi * (outWidth * outHeight) + outWidth * row + col;

    float x = conv(info, in, filters, bias, row, col);
    float sinX = native_sin(x);
    float cosX = native_cos(x);

    out[dst] = sinX;
    gradMuls[dst] = cosX;
}

//ELU activation, with alpha taken from alpha[0].
//out = x for x >= 0 and alpha * (e^x - 1) for x < 0.
//gradMuls receives the derivative: 1, or alpha * e^x = out + alpha.
//NaN inputs take the x >= 0 path, as before.
//Dimension 0 of the NDRange is the output channel (numi); dimension 1 is the flattened
//position row * outWidth + col. No bounds check is performed.
__kernel void conv_ELU(__constant int *info, __global const float *in, __constant float *filters __attribute__((max_constant_size(4096))), __global const float *bias, __global float *out, __global float *gradMuls, __global const float *alpha){

    int k = get_global_id(1) / outWidth;
    int l = get_global_id(1) % outWidth;

    int outputIndx = numi * (outWidth * outHeight) + outWidth * k + l;

    float res = conv(info, in, filters, bias, k, l);

    //Gradient is computed in a register so each global buffer is written once.
    //1.0f keeps the arithmetic in single precision (unsuffixed 1.0 would promote to double).
    float grad = 1.0f;
    if(res < 0){
        float a = alpha[0];
        res = a * (native_exp(res) - 1.0f);
        //alpha * e^x expressed through the activated value: f(x) + alpha
        grad = res + a;
    }

    gradMuls[outputIndx] = grad;
    out[outputIndx] = res;
}