From 5ae390355965ff7af616c013154a8eaf1334aa63 Mon Sep 17 00:00:00 2001 From: Wang Ruipeng <98011585+WANG-Ruipeng@users.noreply.github.com> Date: Wed, 4 Dec 2024 18:58:56 -0500 Subject: [PATCH] Use poisson sampling for spatial pass --- shaders/temporal_spatial_pass.comp | 80 +++++++++++++++++++++--------- 1 file changed, 57 insertions(+), 23 deletions(-) diff --git a/shaders/temporal_spatial_pass.comp b/shaders/temporal_spatial_pass.comp index f692a7e..9af14b4 100644 --- a/shaders/temporal_spatial_pass.comp +++ b/shaders/temporal_spatial_pass.comp @@ -31,44 +31,78 @@ layout(set = 0, binding = 1) uniform image2D reflectionDirection; layout(set = 0, binding = 2) uniform image2D reflectionPointBrdf; layout(set = 0, binding = 3) uniform image2D filteredReflectionColor; +const vec2 poissonDisk[8] = vec2[]( + vec2(-0.5, -0.5), + vec2( 0.5, -0.5), + vec2(-0.5, 0.5), + vec2( 0.5, 0.5), + vec2(-0.25, -0.75), + vec2( 0.25, -0.75), + vec2(-0.75, 0.25), + vec2( 0.75, 0.25) +); - -//-------------------------------------------------------------------------------------------------- -//-------------------------------------------------------------------------------------------------- layout(local_size_x = 16, local_size_y = 16, local_size_z = 1) in; void main() { - ivec2 imageRes = imageSize(reflectionColor); + ivec2 imageRes = imageSize(reflectionColor); ivec2 imageCoords = ivec2(gl_GlobalInvocationID.xy); - + if (imageCoords.x >= imageRes.x || imageCoords.y >= imageRes.y) return; vec3 result = vec3(0.0); float weightSum = 0.0; + + // Base sampling radius + float radius = min(imageRes.x, imageRes.y) * 0.01; + + // Get center pixel values as reference + vec4 centerColor = imageLoad(reflectionColor, imageCoords); + vec4 centerDirection = imageLoad(reflectionDirection, imageCoords); + vec4 centerBrdf = imageLoad(reflectionPointBrdf, imageCoords); + + // Sample with distance-based weights + for (int i = 0; i < 8; i++) { + ivec2 offset = ivec2(poissonDisk[i] * radius); + ivec2 neighborCoords = imageCoords + offset; + + // Boundary check + if (neighborCoords.x < 0 || neighborCoords.y < 0 || + neighborCoords.x >= imageRes.x || neighborCoords.y >= imageRes.y) + continue; + + // Calculate distance-based weight + float dist = length(vec2(offset)); + float distanceWeight = exp(-dist * dist / (2.0 * radius * radius)); + + vec4 neighborColor = imageLoad(reflectionColor, neighborCoords); + vec4 neighborDirection = imageLoad(reflectionDirection, neighborCoords); + vec4 neighborBrdf = imageLoad(reflectionPointBrdf, neighborCoords); + + vec3 brdf = neighborBrdf.rgb; + float pdf = max(neighborDirection.a, 0.001); + + // Combine BRDF weight with distance weight + float weight = max(dot(brdf, vec3(1.0)) / pdf, 0.001) * distanceWeight; + + result += neighborColor.rgb * weight; + weightSum += weight; + } - for (int dy = -1; dy <= 1; dy++) { - for (int dx = -1; dx <= 1; dx++) { - ivec2 neighborCoords = imageCoords + ivec2(dx, dy); - if (neighborCoords.x < 0 || neighborCoords.y < 0 || neighborCoords.x >= imageRes.x || neighborCoords.y >= imageRes.y) - continue; - - vec4 neighborColor = imageLoad(reflectionColor, neighborCoords); - vec4 neighborDirection = imageLoad(reflectionDirection, neighborCoords); - vec4 neighborBrdf = imageLoad(reflectionPointBrdf, neighborCoords); - - vec3 brdf = neighborBrdf.rgb; - float pdf = 1.0 / neighborDirection.a; - - float weight = max(dot(brdf, vec3(1.0)) / pdf, 0.001); + // Add center sample contribution + float centerPdf = max(centerDirection.a, 0.001); + float centerWeight = max(dot(centerBrdf.rgb, vec3(1.0)) / centerPdf, 0.001) * 1.0; // Center has full weight + result += centerColor.rgb * centerWeight; + weightSum += centerWeight; - result += neighborColor.rgb * weight; - weightSum += weight; - } - } + // Normalize result if (weightSum > 0.0) { result /= weightSum; + } else { + result = centerColor.rgb; } + imageStore(filteredReflectionColor, imageCoords, vec4(result, 1.0)); } \ No newline at end of file