#include "/lib/core.glsl"
#include "/lib/config.glsl"

/* Deferred Light Index Apply */

// Each work group is 16x16 invocations; main() maps one invocation to one texel of a screen tile.
layout(local_size_x = 16, local_size_y = 16, local_size_z = 1) in;

// The `readonly` qualifier is applied to whatever declaration the include below begins with
// (presumably the `indirect` buffer block; main() only reads `indirect.coords` from it).
readonly
#include "/buf/indirect.glsl"

// `day` scales the sky light term in main(); `far` is not referenced in the visible part of this file.
uniform float day, far;
// Subtracted when building the offset that moves indexed light positions into player-eye space.
uniform vec3 cameraPositionFract;
// Used in main() to reconstruct view-space and player-eye-space positions from depth.
uniform mat4 gbufferModelViewInverse, gbufferProjectionInverse;
// Depth buffer fetched per texel in main().
uniform sampler2D depthtex0;

// colortex1: surface color (rgb is passed through linear(); alpha feeds the roughness term).
// colortex2: octahedron-encoded normals (xy = texture normal, zw = face normal, as decoded in main()).
uniform sampler2D colortex1, colortex2;
// colortex3: packed per-texel data read as one uint in main():
// bits 0-12 block light, bits 13-25 sky light, bits 26-29 emission.
uniform usampler2D colortex3;
// Output image; main() stores the final light color here (write-only).
uniform layout(r11f_g11f_b10f) restrict writeonly image2D deferredLight0;

#include "/lib/view.glsl"
#include "/lib/luminance.glsl"
#include "/lib/octa_normal.glsl"
#include "/lib/srgb.glsl"
#include "/lib/brdf.glsl"

// Two mutually exclusive modes:
//  - LIGHT_LEVELS: light-level visualization (main() calls visualize_ll, presumably provided by llv.glsl).
//  - INDEXED_BLOCK_LIGHT (non-zero): per-tile culling and shading of lights from the `index` buffer.
#ifdef LIGHT_LEVELS
	#include "/lib/llv.glsl"
#elif INDEXED_BLOCK_LIGHT
	// `readonly` qualifies the declaration the include begins with
	// (presumably the `index` buffer block; main() reads index.data, index.color, index.len and index.offset).
	readonly
	#include "/buf/index.glsl"

	// Capacity of the per-work-group light list: a fraction (LDS_RATIO, defined elsewhere) of the global index size.
	const uint local_index_size = uint(float(index.data.length()) * LDS_RATIO);

	// Integer bounding boxes of the lit texels in this work group,
	// in player-eye ("pe") space and in view space. Updated with atomicMin/atomicMax in main().
	shared ivec3 sh_bb_pe_min;
	shared ivec3 sh_bb_pe_max;
	shared ivec3 sh_bb_view_min;
	// Number of lights appended to the shared list below (incremented with atomicAdd in main()).
	shared uint sh_index_len;
	shared ivec3 sh_bb_view_max;
	// Per-work-group list of lights that passed culling: packed light data and packed 16-bit color.
	shared uint[local_index_size] sh_index_data;
	shared uint16_t[local_index_size] sh_index_color;

	// main() uses mhtn_dist() and mhtn_length(), presumably defined in this include.
	#include "/lib/mhtn_dist.glsl"
#endif

/*
	Deferred block-light pass for one 16x16 screen tile.

	Each work group handles the tile whose origin is stored in `indirect.coords[gl_WorkGroupID.x]`;
	each invocation handles one texel of that tile:

	1. Reconstruct the view-space and player-eye-space ("pe") position of the texel from depth.
	2. (INDEXED_BLOCK_LIGHT without LIGHT_LEVELS only)
	   a. Build bounding boxes of all lit texels of the tile in shared memory.
	   b. Cooperatively cull the global light index against those boxes into a shared per-tile list.
	   c. Accumulate diffuse and specular contributions of the listed lights for this texel.
	3. Add sky light and emission and store the result in `deferredLight0`.

	The shared light list has a fixed capacity of `local_index_size` entries. Lights that do not
	fit are dropped instead of being written past the end of the shared arrays.
*/
void main() {
	// Tile origin: either a native 16-bit vector, or x/y packed into the low/high 16 bits of a uint.
	#ifdef REAL_INT16
		immut i16vec2 texel = indirect.coords[gl_WorkGroupID.x] + i16vec2(gl_LocalInvocationID.xy);
	#else
		immut uint tile = indirect.coords[gl_WorkGroupID.x];
		immut ivec2 texel = ivec2(uvec2(bitfieldExtract(tile, 0, 16), bitfieldExtract(tile, 16, 16)) + gl_LocalInvocationID.xy);
	#endif

	// Reconstruct the position at the texel center: screen -> NDC -> view space -> player-eye space (rotation only).
	immut float depth = texelFetch(depthtex0, texel, 0).r;
	immut vec2 texel_size = 1.0 / vec2(view());
	immut vec2 coord = fma(vec2(texel), texel_size, 0.5 * texel_size);
	immut vec4 view_undiv = gbufferProjectionInverse * vec4(fma(vec3(coord, depth), vec3(2.0), vec3(-1.0)), 1.0);
	immut vec3 view = view_undiv.xyz / view_undiv.w;
	immut vec3 pe = mat3(gbufferModelViewInverse) * view;

	// Packed per-texel data: bits 0-12 block light, bits 13-25 sky light, bits 26-29 emission.
	immut uint gbuffer_data = texelFetch(colortex3, texel, 0).r;

	#if INDEXED_BLOCK_LIGHT && !defined LIGHT_LEVELS
		// One invocation resets the shared state. The boxes start inverted (min > max)
		// so that a tile without any lit texel can be detected after the reduction.
		if (gl_LocalInvocationIndex == 0u) {
			sh_index_len = 0u;

			const ivec3 i32_max = ivec3(0x7fffffff);
			const ivec3 i32_min = ivec3(0x80000000);

			sh_bb_pe_min = i32_max;
			sh_bb_pe_max = i32_min;
			sh_bb_view_min = i32_max;
			sh_bb_view_max = i32_min;
		}

		immut f16vec3 abs_pe = abs(f16vec3(pe));
		immut float16_t chebyshev_dist = max3(abs_pe.x, abs_pe.y, abs_pe.z);

		// todo!() fix whatever is causing 260 to be the limit instead of 0
		immut bool lit = bitfieldExtract(gbuffer_data, 0, 13) > 260u && chebyshev_dist < float16_t(INDEX_DIST);

		// Make the reset above visible to the whole work group before the atomics below.
		barrier();

		if (lit) {
			// Round to the nearest integer (half away from zero) before the integer atomics.
			immut ivec3 i_pe = ivec3(fma(sign(pe), vec3(0.5), pe));

			atomicMin(sh_bb_pe_min.x, i_pe.x); atomicMax(sh_bb_pe_max.x, i_pe.x);
			atomicMin(sh_bb_pe_min.y, i_pe.y); atomicMax(sh_bb_pe_max.y, i_pe.y);
			atomicMin(sh_bb_pe_min.z, i_pe.z); atomicMax(sh_bb_pe_max.z, i_pe.z);

			immut ivec3 i_view = ivec3(fma(sign(view), vec3(0.5), view));

			atomicMin(sh_bb_view_min.x, i_view.x); atomicMax(sh_bb_view_max.x, i_view.x);
			atomicMin(sh_bb_view_min.y, i_view.y); atomicMax(sh_bb_view_max.y, i_view.y);
			atomicMin(sh_bb_view_min.z, i_view.z); atomicMax(sh_bb_view_max.z, i_view.z);
		}

		/*
			if (subgroupAny(lit)) {
				immut vec3 sg_pe_min = subgroupMin(lit ? pe : vec3(1.0/0.0));
				immut vec3 sg_pe_max = subgroupMax(lit ? pe : vec3(-1.0/0.0));

				immut vec3 sg_view_min = subgroupMin(lit ? view : vec3(1.0/0.0));
				immut vec3 sg_view_max = subgroupMax(lit ? view : vec3(-1.0/0.0));

				if (subgroupElect()) {
					immut ivec3 i_sg_pe_min = ivec3(fma(sign(sg_pe_min), vec3(0.5), sg_pe_min));
					immut ivec3 i_sg_bb_max = ivec3(fma(sign(sg_pe_max), vec3(0.5), sg_pe_max));

					atomicMin(sh_bb_pe_min.x, i_sg_pe_min.x); atomicMax(sh_bb_pe_max.x, i_sg_bb_max.x);
					atomicMin(sh_bb_pe_min.y, i_sg_pe_min.y); atomicMax(sh_bb_pe_max.y, i_sg_bb_max.y);
					atomicMin(sh_bb_pe_min.z, i_sg_pe_min.z); atomicMax(sh_bb_pe_max.z, i_sg_bb_max.z);

					immut ivec3 i_sg_view_min = ivec3(fma(sign(sg_view_min), vec3(0.5), sg_view_min));
					immut ivec3 i_sg_view_max = ivec3(fma(sign(sg_view_max), vec3(0.5), sg_view_max));

					atomicMin(sh_bb_view_min.x, i_sg_view_min.x); atomicMax(sh_bb_view_max.x, i_sg_view_max.x);
					atomicMin(sh_bb_view_min.y, i_sg_view_min.y); atomicMax(sh_bb_view_max.y, i_sg_view_max.y);
					atomicMin(sh_bb_view_min.z, i_sg_view_min.z); atomicMax(sh_bb_view_max.z, i_sg_view_max.z);
				}
			}
		*/

		// Wait until every invocation has contributed to the bounding boxes.
		barrier();

		immut f16vec3 bb_pe_min = f16vec3(sh_bb_pe_min);
		immut f16vec3 bb_pe_max = f16vec3(sh_bb_pe_max);

		// Added to the 9-bit-per-axis light coordinates to get player-eye-space light positions.
		immut f16vec3 index_offset = f16vec3(-255.5 - cameraPositionFract - gbufferModelViewInverse[3].xyz + index.offset);

		if (all(greaterThanEqual(bb_pe_max, bb_pe_min))) { // make sure this tile isn't fully unlit, out of range or sky
			immut f16vec3 bb_view_min = f16vec3(sh_bb_view_min);
			immut f16vec3 bb_view_max = f16vec3(sh_bb_view_max);

			// The invocations of the work group stride over the global light index together.
			immut uint16_t global_len = index.len;
			for (uint16_t i = uint16_t(gl_LocalInvocationIndex); i < global_len; i += uint16_t(gl_WorkGroupSize.x * gl_WorkGroupSize.y)) {
				// Packed light: bits 0-8 x, 9-17 y, 18-26 z, 27-30 intensity, bit 31 "wide" flag.
				immut uint light_data = index.data[i];

				immut f16vec3 pe_light = f16vec3(
					bitfieldExtract(light_data, 0, 9),
					bitfieldExtract(light_data, 9, 9),
					bitfieldExtract(light_data, 18, 9)
				) + index_offset;

				immut float16_t intensity = float16_t(bitfieldExtract(light_data.x, 27, 4)) + float16_t(1.0); // not sure why this +1 is needed here

				// distance between light and closest point on bounding box
				// in world-aligned space (player-eye) we can use Manhattan distance
				immut bool pe_visible = mhtn_dist(pe_light, clamp(pe_light, bb_pe_min, bb_pe_max)) <= intensity;

				immut f16vec3 v_light = f16vec3(pe_light * mat3(gbufferModelViewInverse));
				immut bool view_visible = distance(v_light, clamp(v_light, bb_view_min, bb_view_max)) <= intensity;

				if (pe_visible && view_visible) {
					// The counter keeps counting past the capacity so overflow stays detectable
					// (see the debug line near the end), but only slots inside the shared arrays are written.
					// Out-of-bounds writes to shared memory are undefined behavior.
					immut uint j = atomicAdd(sh_index_len, 1u);

					if (j < local_index_size) {
						sh_index_data[j] = light_data;
						sh_index_color[j] = index.color[i];
					}
				}
			}
		}

		// The shared light list must be complete before any invocation starts shading from it.
		barrier();
	#endif

	f16vec4 color_s = f16vec4(texelFetch(colortex1, texel, 0));
	color_s.rgb = linear(color_s.rgb);

	// light.x = block light, light.y = sky light, both normalized from 13 bits to [0, 1].
	immut f16vec2 light = f16vec2(
		bitfieldExtract(gbuffer_data, 0, 13),
		bitfieldExtract(gbuffer_data, 13, 13)
	) / float16_t(8191.0);

	#ifdef LIGHT_LEVELS
		f16vec3 block_light = f16vec3(visualize_ll(light.x));
	#else
		f16vec3 block_light = light.x*light.x * f16vec3(1.2, 1.2, 1.0);
	#endif

	#if INDEXED_BLOCK_LIGHT && !defined LIGHT_LEVELS
		if (lit) {
			immut f16vec4 octa_normal = f16vec4(texelFetch(colortex2, texel, 0));
			immut float16_t roughness = float16_t(1.0) - sqrt(abs(color_s.a));

			immut f16vec3 n_pe = f16vec3(normalize(pe));
			// Turns unpacked light coordinates directly into a vector from this texel to the light.
			immut vec3 offset = vec3(index_offset) - pe;

			immut f16vec3 w_face_normal = normalize(octa_decode(octa_normal.zw));
			immut f16vec3 w_tex_normal = normalize(octa_decode(octa_normal.xy));

			f16vec3 diffuse = f16vec3(0.0);
			f16vec3 specular = f16vec3(0.0);

			// Only entries below the capacity were actually stored, so never read beyond it.
			immut uint16_t index_len = uint16_t(min(sh_index_len, local_index_size));
			for (uint16_t i = uint16_t(0u); i < index_len; ++i) {
				immut uint light_data = sh_index_data[i];

				immut vec3 w_rel_light = vec3(
					bitfieldExtract(light_data, 0, 9),
					bitfieldExtract(light_data, 9, 9),
					bitfieldExtract(light_data, 18, 9)
				) + offset;

				immut f16vec3 f16_w_rel_light = f16vec3(w_rel_light);
				immut float16_t intensity = float16_t(bitfieldExtract(light_data.x, 27, 4));
				immut float16_t mhtn_dist = mhtn_length(f16_w_rel_light);

				if (mhtn_dist <= intensity) { // use this in culling too
					immut uint16_t light_color = sh_index_color[i];
					immut f16vec3 n_w_rel_light = f16vec3(normalize(w_rel_light));
					immut float16_t tex_lambertian = dot(w_tex_normal, n_w_rel_light);

					float16_t dist_light = length(f16_w_rel_light);
					if (bitfieldExtract(light_data, 31, 1) == 0u) dist_light *= dist_light;
					// use linear falloff instead of inverse square law when the "wide" flag is set

					immut float16_t brightness = min(intensity - mhtn_dist, float16_t(1.0)) * intensity / dist_light;
					immut uint light_color_32 = uint(light_color);
					// Packed color: bits 0-5 second channel (6 bits), bits 6-10 first channel, bits 11-15 third channel.
					immut f16vec3 illum = min(brightness, float16_t(48.0)) * f16vec3(
						bitfieldExtract(light_color_32, 6, 5),
						bitfieldExtract(light_color_32, 0, 6),
						bitfieldExtract(light_color_32, 11, 5)
					);

					// Direct lighting only when both the texture normal and the face normal point towards the light.
					if (tex_lambertian > float16_t(0.0) && dot(w_face_normal, n_w_rel_light) > float16_t(0.0)) {
						immut float16_t specular_light = brdf(w_tex_normal, n_pe, n_w_rel_light, roughness) * float16_t(0.03);
						specular = fma(specular_light.xxx, illum, specular);

						diffuse = fma((tex_lambertian * (float16_t(1.0) - specular_light) / PI_16).xxx, illum, diffuse);
					}

					diffuse = fma(f16vec3(IND_ILLUM), illum, diffuse); // very fake GI
					/*
						float lighting = IND_ILLUM;

						if (lit) {
							bool visible = true;

							immut vec3 v_pos = (gbufferModelView * vec4(n_pos * -1.5 + pos + pe, 1.0)).xyz;

							for (uint i = 1u; i < 32u && visible; ++i) {
								vec4 clip_sample = gbufferProjection * vec4(view * v_pos / mix(v_pos, view, float(i) / 32.0), 1.0);
								immut vec3 screen_sample = (clip_sample.xyz / clip_sample.w) * 0.5 + 0.5;

								if (screen_sample.z > textureLod(depthtex0, screen_sample.xy, 0).r - 0.001) visible = false;
							}

							lighting += float(visible);
						}
					*/
				}
			}

			const f16vec3 packing_scale = f16vec3(1.0 / (15.0 * vec3(31.0, 63.0, 31.0))); // Undo the multiplication from packing light color and brightness
			diffuse *= packing_scale;
			specular *= packing_scale;

			immut f16vec3 new_light = float16_t(INDEXED_BLOCK_LIGHT * 2) * light.x * f16vec3(diffuse + specular / max(color_s.rgb, float16_t(1.0e-4)));

			// Fade from indexed lighting back to the plain block light over the last 15 units before INDEX_DIST.
			block_light = mix(new_light, block_light, smoothstep(float16_t(INDEX_DIST - 15), float16_t(INDEX_DIST), chebyshev_dist));
			// block_light++; // DEBUG `lit`
		}

		// Debug culling & LDS overflow
		// block_light.gb += f16vec2(sh_index_len < index.len, sh_index_len == 0);
		// block_light.rgb += distance(max(float16_t(sh_bb_view_min), float16_t(0.0)), max(float16_t(sh_bb_view_max), float16_t(0.0))) * float16_t(0.01);
		// if (sh_index_len > local_index_size) block_light *= 10;
	#endif

	#ifdef LIGHT_LEVELS
		const float16_t sky_light = float16_t(0.0);
	#else
		#ifdef NETHER
			const f16vec3 sky_light = f16vec3(0.3, 0.15, 0.2);
		#elif defined END
			const f16vec3 sky_light = f16vec3(0.15, 0.075, 0.2);
		#else
			immut float16_t sky_light = float16_t(1.0) - sqrt(float16_t(1.0) - light.y * clamp(float16_t(day) * float16_t(10.0), float16_t(0.25), float16_t(1.0)));
		#endif
	#endif

	// Emission combines the 4-bit emission value with the luminance of the surface color.
	immut float16_t emission = float16_t(bitfieldExtract(gbuffer_data, 26, 4));
	immut float16_t emi = fma(emission, float16_t(0.2), float16_t(luminance(color_s.rgb)));
	immut f16vec3 final_light = sky_light * float16_t(0.4) + emi*emi*emi*emi * float16_t(0.005) + block_light;

	imageStore(deferredLight0, texel, vec4(final_light, 0.0));
}