/**
 * Fragment shader that blurs a single-channel texture.
 *
 * It takes a single-channel floating-point texture as input and writes the
 * blurred data to an output texture of the same format. It uses a bilateral
 * blur filter so that blurring does not cross depth discontinuities, which
 * helps preserve edges between objects in the foreground and objects in the
 * background.
 */
precision highp float;
out float blurred;                     // Blurred result; the output render target has only one (R) channel.

uniform sampler2D uSingleChannelTex;     // Single-channel input texture to blur (only its first channel is read).
uniform sampler2D uDepthTex; 	         // Depth texture used to weight neighbours by depth similarity.
uniform float uRadius;                   // Depth tolerance for the 3x3 blur, in the camera-space units returned by
                                         // getCameraDepth(): a neighbour whose depth differs from the centre's by
                                         // uRadius or more contributes nothing. Presumably the same hemispheric
                                         // kernel radius used when evaluating occlusion — confirm with the caller.
uniform mat4 uProjectionMatrixInverse;   // Inverse projection matrix, used to unproject depth back to camera space.

uniform bool uLargeKernel;                // Kernel selector: true -> largeBlur() (9x9 taps spaced SCALE texels apart),
                                          // false -> smallBlur() (3x3 neighbourhood).

// Number of entries in OFFSETS and GAUSSIAN_COEFFICIENTS (the 8 neighbours of a pixel).
const int OFFSET_COUNT = 8;
// Small constant that keeps denominators away from zero.
const float EPSILON = 0.0001;

// Offsets, in texels, of the 8 neighbours surrounding the centre pixel,
// starting at +x and walking around the ring.
const vec2[] OFFSETS = vec2[]
(
    vec2(1, 0),
    vec2(1, -1),
    vec2(0, -1),
    vec2(-1, -1),
    vec2(-1, 0),
    vec2(-1, 1),
    vec2(0, 1),
    vec2(1, 1)
);

// Spatial weights paired one-to-one with OFFSETS: 2 for axis-aligned
// neighbours, 1 for diagonal ones. Combined with the centre weight of 4 used
// in smallBlur(), this forms the 3x3 1-2-1 (binomial, Gaussian-like) kernel.
const float[] GAUSSIAN_COEFFICIENTS = float[]
(
    2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0
);

// Texture coordinate of the current fragment, interpolated from the vertex stage.
in vec2 vUv;

/*
 * Recovers the camera-space z of the surface visible at `uv`.
 *
 * The depth sample and the texture coordinate are remapped from [0, 1] to
 * [-1, 1] to form an NDC position. That position is unprojected with
 * uProjectionMatrixInverse, and the perspective divide is applied.
 * Assumes uDepthTex stores window-space depth in [0, 1] — TODO confirm.
 * With a conventional OpenGL projection the result is negative in front of the
 * camera.
 */
float getCameraDepth(const in vec2 uv) {
    vec3 ndc = vec3(uv, texture(uDepthTex, uv).r) * 2.0 - 1.0;
    vec4 viewPos = uProjectionMatrixInverse * vec4(ndc, 1.0);
    return viewPos.z / viewPos.w;
}

  // Tuning constants used only by largeBlur().
  // EDGE_SHARPNESS scales how fast the raw depth-buffer difference drives a
  // tap's weight to zero (a tap is rejected once the difference reaches
  // 1 / (2000 * EDGE_SHARPNESS)).
  const float EDGE_SHARPNESS = 1.0;
  // Spacing, in texels, between neighbouring taps of the 9x9 tap grid.
  const int SCALE = 2;
  
/*
 * Wide bilateral blur. It takes a 9x9 grid of taps spaced SCALE texels apart,
 * so with SCALE = 2 the footprint covers 17x17 texels around the fragment.
 *
 * The centre tap has weight 1.0. Each other tap's weight is the product of:
 *   - a near-uniform spatial term: 0.3 on the outermost ring, rising to about
 *     0.319 for taps next to the centre;
 *   - an edge term that falls linearly from 1 to 0 as the raw depth-buffer
 *     difference to the centre grows. Taps whose difference reaches
 *     1 / (2000 * EDGE_SHARPNESS) are ignored.
 * Returns the weighted average of the sampled values.
 *
 * NOTE(review): unlike smallBlur(), this compares raw depth-buffer values, not
 * linearised camera depth. Under a perspective projection those values are
 * non-linear, so the same threshold corresponds to different eye-space
 * distances at different depths. Confirm this is intended.
 */
float largeBlur() {
    // Loop-invariant size of one texel in UV space.
    vec2 texelSize = 1.0 / vec2(textureSize(uSingleChannelTex, 0));

    float originDepth = texture(uDepthTex, vUv).x;

    // Centre tap, weight 1.0.
    float sum = texture(uSingleChannelTex, vUv).x;
    float totalWeight = 1.0;

    for (int x = -4; x <= 4; x++) {
        for (int y = -4; y <= 4; y++) {
            // Skip the centre, which was already accumulated above.
            if (x != 0 || y != 0) {
                vec2 samplePosition = vUv + vec2(float(x * SCALE), float(y * SCALE)) * texelSize;
                float val = texture(uSingleChannelTex, samplePosition).x;
                float sampleDepth = texture(uDepthTex, samplePosition).x;

                // Spatial term: kx, ky are 0 on the outer ring and 4 on the centre row/column.
                int kx = 4 - abs(x);
                int ky = 4 - abs(y);
                float weight = 0.3 + (float(kx * ky) / (25.0 * 25.0));

                // Edge term: suppress taps across depth discontinuities.
                weight *= max(0.0, 1.0 - (EDGE_SHARPNESS * 2000.0) * abs(sampleDepth - originDepth));

                sum += val * weight;
                totalWeight += weight;
            }
        }
    }

    // totalWeight >= 1.0 because of the centre tap; EPSILON is only a safety margin.
    return sum / (totalWeight + EPSILON);
}

/*
 * 3x3 bilateral blur over the immediate neighbours of the fragment.
 *
 * Spatial weights follow the 1-2-1 x 1-2-1 kernel: 4 for the centre,
 * GAUSSIAN_COEFFICIENTS for the 8 neighbours. Each neighbour's weight is then
 * scaled by its depth "closeness", which drops linearly from 1 (same
 * camera-space depth as the centre) to 0 (depth differs by uRadius or more).
 * Neighbours across a depth discontinuity therefore do not bleed into this
 * fragment.
 * Returns the weighted average. The centre weight keeps the denominator >= 4.
 */
float smallBlur()
{
    vec2 texelSize = 1.0 / vec2(textureSize(uSingleChannelTex, 0));
    float centerDepth = getCameraDepth(vUv);

    // GL initialises unset uniforms to 0. An unguarded zero radius would give
    // 0/0 = NaN for same-depth neighbours and poison the whole result.
    float radius = max(uRadius, EPSILON);

    float result = texture(uSingleChannelTex, vUv).r * 4.0;
    float weightSum = 4.0;

    for (int i = 0; i < OFFSET_COUNT; ++i) {
        vec2 currUv = vUv + OFFSETS[i] * texelSize;
        float currDepth = getCameraDepth(currUv);

        // How close the neighbour is in depth to the centre, relative to the radius.
        float closeness = clamp(1.0 - abs(centerDepth - currDepth) / radius, 0.0, 1.0);
        float w = GAUSSIAN_COEFFICIENTS[i] * closeness;

        result += texture(uSingleChannelTex, currUv).r * w;
        weightSum += w;
    }
    return result / weightSum;
}
// Entry point: writes the blurred value using the kernel selected by uLargeKernel.
void main()
{
    blurred = uLargeKernel ? largeBlur() : smallBlur();
}
