#version 420 core

#include <constants.glsl>







#include <globals.glsl>
#include <cloud.glsl>
#include <lod.glsl>

// Stage interface: interpolated inputs from the previous shader stage and the
// single colour output. NOTE(review): input locations must match the producing
// stage's outputs — confirm against the paired vertex shader.

// Fragment position. main() compares it directly with cam_pos.xyz and adds
// focus_off.xyz to it for hashing, so it is presumably focus-relative world
// space — TODO confirm. main() later shadows it with a blended local f_pos.
layout(location = 0) in vec3 f_pos;
// Interpolated surface normal. main() shadows it with a local f_norm that
// blends a camera-facing version of this normal with lod_norm(f_pos.xy).
layout(location = 1) in vec3 f_norm;
// NOTE(review): not read anywhere in the visible part of main(); confirm
// whether this varying is still needed or can be dropped from both stages.
layout(location = 2) in float pull_down;
// Final shaded colour; main() writes surf_color here with alpha fixed at 1.0.
layout(location = 0) out vec4 tgt_color;

#include <sky.glsl>

void main() {
    float my_alt = /*f_pos.z;*/alt_at_real(f_pos.xy);
    vec3 my_pos = vec3(f_pos.xy, my_alt);
    vec3 my_norm = lod_norm(f_pos.xy/*, f_square*/);

    float which_norm = dot(my_norm, normalize(cam_pos.xyz - my_pos));
    vec3 f_norm = mix(faceforward(f_norm, cam_pos.xyz - f_pos, -f_norm), my_norm, which_norm);
    vec3 f_pos = mix(f_pos, my_pos, which_norm);
    vec3 delta_sides = mix(fract(f_pos) - 1.0, fract(f_pos), lessThan(sides, vec3(0.0)));
    // Three faces: xy, xz, and yz.
    // TODO: Handle zero slopes (for xz and yz).
    vec2 corner_xy = min(abs(f_norm.xy / f_norm.z * delta_sides.z), 1.0);
    vec2 corner_yz = min(abs(f_norm.yz / f_norm.x * delta_sides.x), 1.0);
    vec2 corner_xz = min(abs(f_norm.xz / f_norm.y * delta_sides.y), 1.0);
    // Now we just compute an (upper bounded) distance to the corner in each direction.
    // vec3 corner_distance = min(abs(corner_delta), 1.0);
    // Now, if both sides hit something, lerp to 0.25.  If one side hits something, lerp to 0.75.  And if no sides hit something,
    // lerp to 1.0 (TODO: incorporate the corner properly).
    // Bilinear interpolation on each plane:
    float ao_xy = dot(vec2(corner_xy.x, 1.0 - corner_xy.x), mat2(vec2(corner_xy.x < 1.00 ? corner_xy.y < 1.00 ? 0.25 : 0.5 : corner_xy.y < 1.00 ? 0.5 : 0.75, corner_xy.x < 1.00 ? 0.75 : 1.00), vec2(corner_xy.y < 1.00 ? 0.75 : 1.0, 1.0)) * vec2(corner_xy.y, 1.0 - corner_xy.y));
    float ao_yz = dot(vec2(corner_yz.x, 1.0 - corner_yz.x), mat2(vec2(corner_yz.x < 1.00 ? corner_yz.y < 1.00 ? 0.25 : 0.5 : corner_yz.y < 1.00 ? 0.5 : 0.75, corner_yz.x < 1.00 ? 0.75 : 1.00), vec2(corner_yz.y < 1.00 ? 0.75 : 1.0, 1.0)) * vec2(corner_yz.y, 1.0 - corner_yz.y));
    float ao_xz = dot(vec2(corner_xz.x, 1.0 - corner_xz.x), mat2(vec2(corner_xz.x < 1.00 ? corner_xz.y < 1.00 ? 0.25 : 0.5 : corner_xz.y < 1.00 ? 0.5 : 0.75, corner_xz.x < 1.00 ? 0.75 : 1.00), vec2(corner_xz.y < 1.00 ? 0.75 : 1.0, 1.0)) * vec2(corner_xz.y, 1.0 - corner_xz.y));
    // Now, multiply each component by the face "share" which is just the absolute value of its normal for that plane...
    vec3 f_ao_vec = mix(/*abs(voxel_norm)*/vec3(1.0, 1.0, 1.0), /*abs(voxel_norm) * */vec3(ao_yz, ao_xz, ao_xy), /*abs(voxel_norm)*/vec3(length(f_norm.yz), length(f_norm.xz), length(f_norm.xy))/*vec3(1.0)*//*sign(max(view_dir * sides, 0.0))*/);
    float f_orig_len = length(cam_pos.xyz - f_pos);
    vec3 f_orig_view_dir = normalize(cam_pos.xyz - f_pos);
    vec3 voxel_norm;

    const float VOXELIZE_DIST = 2000;
    float voxelize_factor = clamp(1.0 - (distance(focus_pos.xy, f_pos.xy) - view_distance.x) / VOXELIZE_DIST, 0, 1);
    vec3 cam_dir = normalize(cam_pos.xyz - f_pos.xyz);
    vec3 side_norm = normalize(vec3(my_norm.xy, 0));
    vec3 top_norm = vec3(0, 0, 1);
    float side_factor = 1.0 - my_norm.z;
    // min(dot(vec3(0, -sign(cam_dir.y), 0), -cam_dir), dot(vec3(-sign(cam_dir.x), 0, 0), -cam_dir))
    if (max(abs(my_norm.x), abs(my_norm.y)) < 0.01 || fract(my_alt) * clamp(dot(normalize(vec3(cam_dir.xy, 0)), side_norm), 0, 1) < cam_dir.z / my_norm.z) {
        f_ao *= mix(1.0, clamp(fract(my_alt) / length(my_norm.xy) + clamp(dot(side_norm, -cam_dir), 0, 1), 0, 1), voxelize_factor);
        voxel_norm = top_norm;
    } else {
        f_ao *= mix(1.0, clamp(pow(fract(my_alt), 0.5), 0, 1), voxelize_factor);

        if (fract(f_pos.x) * abs(my_norm.y / cam_dir.x) < fract(f_pos.y) * abs(my_norm.x / cam_dir.y)) {
            voxel_norm = vec3(sign(cam_dir.x), 0, 0);
        } else {
            voxel_norm = vec3(0, sign(cam_dir.y), 0);

    float shadow_alt = /*f_pos.z;*/alt_at(f_pos.xy);//max(alt_at(f_pos.xy), f_pos.z);
    // float shadow_alt = f_pos.z;
    float shadow_alt = f_pos.z;

    vec4 f_shadow = textureBicubic(t_horizon, s_horizon, pos_to_tex(f_pos.xy));
    float sun_shade_frac = horizon_at2(f_shadow, shadow_alt, f_pos, sun_dir);
    // float sun_shade_frac = 1.0;
    float sun_shade_frac = 1.0;//horizon_at2(f_shadow, shadow_alt, f_pos, sun_dir);
    float moon_shade_frac = 1.0;//horizon_at2(f_shadow, shadow_alt, f_pos, moon_dir);

    // Magic stop-gap code without any physical justification.
    vec3 lerpy_norm;
    if (my_norm.z/*f_norm.z*/ > 0.99999) {
        lerpy_norm = vec3(0, 0, 1);
    } else {
        vec3 side_norm = normalize(vec3(my_norm.xy, 0));
        // lerpy_norm = f_norm;
        float mix_factor = clamp(abs(dot(f_orig_view_dir, side_norm)), 0, 1);
        lerpy_norm = mix(
            mix(my_norm, side_norm, clamp(dot(side_norm, my_norm) + 0.5, 0, 1)),
    const float DIST = 0.07;
    voxel_norm = normalize(mix(voxel_norm, lerpy_norm, clamp(my_norm.z * my_norm.z - (1.0 - DIST), 0, 1) / DIST));

    f_pos.xyz += abs(voxel_norm) * delta_sides;
    voxel_norm = mix(my_norm, voxel_norm == vec3(0.0) ? f_norm : voxel_norm, voxelize_factor);

    vec3 hash_pos = f_pos + focus_off.xyz;
    const float A = 0.055;
    const float W_INV = 1 / (1 + A);
    const float W_2 = W_INV * W_INV;//pow(W_INV, 2.4);
    const float NOISE_FACTOR = 0.02;//pow(0.02, 1.2);
    float noise = hash(vec4(floor(hash_pos * 3.0 - voxel_norm * 0.5), 0));//0.005/* - 0.01*/;
    vec3 noise_delta = (sqrt(f_col_raw) * W_INV + noise * NOISE_FACTOR);
    // noise_delta = noise_delta * noise_delta * W_2 - f_col;
    // lum = W ⋅ col
    // lum + noise = W ⋅ (col + delta)
    // W ⋅ col + noise = W ⋅ col + W ⋅ delta
    // noise = W ⋅ delta
    // delta = noise / W
    // vec3 col = (f_col + noise_delta);
    // vec3 col = noise_delta * noise_delta * W_2;

    vec3 f_col = noise_delta * noise_delta * W_2;
    // f_col = /*srgb_to_linear*/(f_col + hash(vec4(floor(hash_pos * 3.0 - voxel_norm * 0.5), 0)) * 0.01/* - 0.01*/); // Small-scale noise

    // f_ao = 1.0;
    // f_ao = dot(f_ao_vec, sqrt(1.0 - delta_sides * delta_sides));

    f_ao *= dot(f_ao_vec, abs(voxel_norm));
    // f_ao = sqrt(dot(f_ao_vec * abs(voxel_norm), sqrt(1.0 - delta_sides * delta_sides)) / 3.0);

    vec3 cam_to_frag = normalize(f_pos - cam_pos.xyz);
    vec3 view_dir = -cam_to_frag;
    // vec3 view_dir = normalize(f_pos - cam_pos.xyz);

    // DirectionalLight sun_info = get_sun_info(sun_dir, sun_shade_frac, light_pos);
    DirectionalLight sun_info = get_sun_info(sun_dir, sun_shade_frac, /*sun_pos*/f_pos);
    DirectionalLight moon_info = get_moon_info(moon_dir, moon_shade_frac/*, light_pos*/);

    float alpha = 1.0;//0.1;//0.2;///1.0;//sqrt(2.0);
    const float n2 = 1.5;
    const float R_s2s0 = pow((1.0 - n2) / (1.0 + n2), 2);
    const float R_s1s0 = pow((1.3325 - n2) / (1.3325 + n2), 2);
    const float R_s2s1 = pow((1.0 - 1.3325) / (1.0 + 1.3325), 2);
    const float R_s1s2 = pow((1.3325 - 1.0) / (1.3325 + 1.0), 2);
    float cam_alt = alt_at(cam_pos.xy);
    float fluid_alt = medium.x == 1u ? max(cam_alt + 1, floor(shadow_alt)) : view_distance.w;
    float R_s = (f_pos.z < my_alt) ? mix(R_s2s1 * R_s1s0, R_s1s0, medium.x) : mix(R_s2s0, R_s1s2 * R_s2s0, medium.x);

    vec3 emitted_light, reflected_light;

    vec3 mu = medium.x == 1u/* && f_pos.z <= fluid_alt*/ ? MU_WATER : vec3(0.0);
    // NOTE: Default intersection point is camera position, meaning if we fail to intersect we assume the whole camera is in water.
    vec3 cam_attenuation = compute_attenuation_point(cam_pos.xyz, view_dir, mu, fluid_alt, /*cam_pos.z <= fluid_alt ? cam_pos.xyz : f_pos*/f_pos);
    // Use f_norm here for better shadows.
    // vec3 light_frac = light_reflection_factor(f_norm/*l_norm*/, view_dir, vec3(0, 0, -1.0), vec3(1.0), vec3(/*1.0*/R_s), alpha);

    float ao = f_ao;// /*pow(f_ao, 0.5)*/f_ao * 0.9 + 0.1;
    emitted_light *= ao;
    reflected_light *= ao;

    vec3 surf_color;
        if (length(f_col_raw - vec3(0.02, 0.06, 0.22)) < 0.025 && dot(vec3(0, 0, 1), f_norm) > 0.9) {
            vec3 water_color = (1.0 - MU_WATER) * MU_SCATTER;

            vec3 reflect_ray = cam_to_frag * vec3(1, 1, -1);

            float passthrough = dot(faceforward(f_norm, f_norm, cam_to_frag), -cam_to_frag);

            vec3 reflect_color = get_sky_color(reflect_ray, time_of_day.x, f_pos, vec3(-100000), 0.125, true);
            reflect_color = get_cloud_color(reflect_color, reflect_ray, cam_pos.xyz, time_of_day.x, 100000.0, 0.1);

            const float REFLECTANCE = 0.5;
            surf_color = illuminate(max_light, view_dir, f_col * emitted_light, reflect_color * REFLECTANCE + water_color * reflected_light);

            const vec3 underwater_col = vec3(0.0);
            float min_refl = min(emitted_light.r, min(emitted_light.g, emitted_light.b));
            surf_color = mix(underwater_col, surf_color, (1.0 - passthrough) * 1.0 / (1.0 + min_refl));
        } else {
            surf_color = illuminate(max_light, view_dir, f_col * emitted_light, f_col * reflected_light);
        surf_color = illuminate(max_light, view_dir, f_col * emitted_light, f_col * reflected_light);

    tgt_color = vec4(surf_color, 1.0);