#include "/lib/pi.glsl"

// Specular microfacet BRDF made of three terms:
//   - GGX / Trowbridge-Reitz normal distribution (D)
//   - Schlick Fresnel (F) with a fixed dielectric f0 = 0.04
//   - Height-correlated Smith-GGX visibility term V = G / (4 * n.v * n.l)
// Returns D * F * V, i.e. the BRDF value WITHOUT the n.l cosine factor.
//
//   normal    - surface normal (assumed normalized; TODO confirm at call sites)
//   dir       - presumably the view ray direction (eye -> surface): the half
//               vector below is built from light_dir and -dir
//   light_dir - direction from the surface toward the light (assumed normalized)
//   roughness - perceptual roughness; alpha = roughness^2, a_2 = alpha^2
//
// All intermediate math is done in fp32. Only the final value is narrowed
// to float16_t.
float16_t brdf(vec3 normal, vec3 dir, vec3 light_dir, float16_t roughness) {
	const float f0 = 0.04;
	// Largest finite float16 value; the result is clamped to it before
	// narrowing so the BRDF never returns +inf. (+inf would become NaN once
	// multiplied by a zero cosine or zero light intensity.)
	const float F16_MAX = 65504.0;
	// Floor for a_2. The smallest positive fp16 roughness (~5.96e-8) has a
	// 4th power of ~1.3e-29, which is above this floor, so the floor only
	// changes the roughness == 0 case. There a_2 == 0 would zero the
	// visibility denominator (0/0 = NaN) for lights at or below the horizon.
	const float A_2_MIN = 1.0e-30;

	// Half vector between the light direction and the view vector (-dir).
	immut vec3 h = normalize(light_dir - dir);

	// abs() + epsilon keeps n.v strictly positive so the visibility term
	// below never divides by zero because of the view angle.
	immut float n_d = abs(dot(normal, dir)) + 1.0e-5;
	immut float n_l = max(dot(normal, light_dir), 0.0);
	immut float n_h = max(dot(normal, h), 0.0);
	immut float l_h = max(dot(light_dir, h), 0.0);

	// a_2 = alpha^2 = roughness^4, computed in fp32. In fp16 this value
	// underflows to 0 for small roughness: below ~0.01, or ~0.09 when fp16
	// denormals are flushed. That zeroes the D term and yields 0/0 = NaN in
	// the visibility term whenever n_l == 0.
	immut float r = float(roughness);
	immut float r_2 = r * r;
	immut float a_2 = max(r_2 * r_2, A_2_MIN);

	// GGX normal distribution: D = a^2 / (pi * ((n.h)^2 * (a^2 - 1) + 1)^2).
	immut float denom = n_h * n_h * (a_2 - 1.0) + 1.0;
	immut float d = a_2 / (PI * denom * denom);

	// Schlick Fresnel: F = f0 + (1 - f0) * (1 - l.h)^5.
	immut float f = fma(pow(1.0 - l_h, 5.0), 1.0 - f0, f0);

	// Height-correlated Smith-GGX visibility:
	//   V = 0.5 / (n.l * sqrt(n.v^2 (1 - a^2) + a^2)
	//            + n.v * sqrt(n.l^2 (1 - a^2) + a^2))
	// With a_2 > 0 and n_d > 0, ggx_l > 0, so the sum below is never zero.
	immut float a_2_inv = 1.0 - a_2;
	immut float ggx_l = n_d * sqrt(a_2_inv * n_l * n_l + a_2);
	immut float ggx_d = n_l * sqrt(a_2_inv * n_d * n_d + a_2);

	immut float specular = 0.5 * f * d / (ggx_d + ggx_l);

	// NOTE(review): if h degenerates (light_dir == dir), normalize() is
	// undefined and a NaN may still propagate; min() does not guard NaN.
	return float16_t(min(specular, F16_MAX));
}