Skip to content

Commit

Permalink
Use poisson sampling for spatial pass
Browse files Browse the repository at this point in the history
  • Loading branch information
WANG-Ruipeng committed Dec 4, 2024
1 parent 8cdda63 commit 5ae3903
Showing 1 changed file with 57 additions and 23 deletions.
80 changes: 57 additions & 23 deletions shaders/temporal_spatial_pass.comp
Original file line number Diff line number Diff line change
Expand Up @@ -31,44 +31,78 @@ layout(set = 0, binding = 1) uniform image2D reflectionDirection;
layout(set = 0, binding = 2) uniform image2D reflectionPointBrdf;
layout(set = 0, binding = 3) uniform image2D filteredReflectionColor;

const vec2 poissonDisk[8] = vec2[](
vec2(-0.5, -0.5),
vec2( 0.5, -0.5),
vec2(-0.5, 0.5),
vec2( 0.5, 0.5),
vec2(-0.25, -0.75),
vec2( 0.25, -0.75),
vec2(-0.75, 0.25),
vec2( 0.75, 0.25)
);


//--------------------------------------------------------------------------------------------------
//--------------------------------------------------------------------------------------------------
layout(local_size_x = 16, local_size_y = 16, local_size_z = 1) in;

void main()
{
ivec2 imageRes = imageSize(reflectionColor);
ivec2 imageRes = imageSize(reflectionColor);
ivec2 imageCoords = ivec2(gl_GlobalInvocationID.xy);

if (imageCoords.x >= imageRes.x || imageCoords.y >= imageRes.y)
return;

vec3 result = vec3(0.0);
float weightSum = 0.0;

// Base sampling radius
float radius = min(imageRes.x, imageRes.y) * 0.01;

// Get center pixel values as reference
vec4 centerColor = imageLoad(reflectionColor, imageCoords);
vec4 centerDirection = imageLoad(reflectionDirection, imageCoords);
vec4 centerBrdf = imageLoad(reflectionPointBrdf, imageCoords);

// Sample with distance-based weights
for (int i = 0; i < 8; i++) {
ivec2 offset = ivec2(poissonDisk[i] * radius);
ivec2 neighborCoords = imageCoords + offset;

// Boundary check
if (neighborCoords.x < 0 || neighborCoords.y < 0 ||
neighborCoords.x >= imageRes.x || neighborCoords.y >= imageRes.y)
continue;

// Calculate distance-based weight
float dist = length(vec2(offset));
float distanceWeight = exp(-dist * dist / (2.0 * radius * radius));

vec4 neighborColor = imageLoad(reflectionColor, neighborCoords);
vec4 neighborDirection = imageLoad(reflectionDirection, neighborCoords);
vec4 neighborBrdf = imageLoad(reflectionPointBrdf, neighborCoords);

vec3 brdf = neighborBrdf.rgb;
float pdf = max(neighborDirection.a, 0.001);

// Combine BRDF weight with distance weight
float weight = max(dot(brdf, vec3(1.0)) / pdf, 0.001) * distanceWeight;

result += neighborColor.rgb * weight;
weightSum += weight;
}

for (int dy = -1; dy <= 1; dy++) {
for (int dx = -1; dx <= 1; dx++) {
ivec2 neighborCoords = imageCoords + ivec2(dx, dy);
if (neighborCoords.x < 0 || neighborCoords.y < 0 || neighborCoords.x >= imageRes.x || neighborCoords.y >= imageRes.y)
continue;

vec4 neighborColor = imageLoad(reflectionColor, neighborCoords);
vec4 neighborDirection = imageLoad(reflectionDirection, neighborCoords);
vec4 neighborBrdf = imageLoad(reflectionPointBrdf, neighborCoords);

vec3 brdf = neighborBrdf.rgb;
float pdf = 1.0 / neighborDirection.a;

float weight = max(dot(brdf, vec3(1.0)) / pdf, 0.001);
// Add center sample contribution
float centerPdf = max(centerDirection.a, 0.001);
float centerWeight = max(dot(centerBrdf.rgb, vec3(1.0)) / centerPdf, 0.001) * 1.0; // Center has full weight
result += centerColor.rgb * centerWeight;
weightSum += centerWeight;

result += neighborColor.rgb * weight;
weightSum += weight;
}
}
// Normalize result
if (weightSum > 0.0) {
result /= weightSum;
} else {
result = centerColor.rgb;
}

imageStore(filteredReflectionColor, imageCoords, vec4(result, 1.0));
}

0 comments on commit 5ae3903

Please sign in to comment.