Skip to content

Commit

Permalink
add bilateral pass
Browse files Browse the repository at this point in the history
  • Loading branch information
WANG-Ruipeng committed Dec 6, 2024
1 parent a622dd6 commit aa9783e
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 79 deletions.
78 changes: 6 additions & 72 deletions shaders/bilateral_cleanup_pass.comp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ layout(set = 0, binding = 0) uniform image2D reflectionColor;
layout(set = 0, binding = 1) uniform image2D reflectionDirection;
layout(set = 0, binding = 2) uniform image2D reflectionPointBrdf;
layout(set = 0, binding = 3) uniform image2D filteredReflectionColor;
layout(set = 5, binding = 4) uniform image2D bilateralCleanupColor;
layout(set = 0, binding = 4) uniform image2D bilateralCleanupColor;

const ivec2 samplePattern[4][16] = ivec2[4][16](
ivec2[16](
Expand Down Expand Up @@ -125,79 +125,13 @@ void main()
if (imageCoords.x >= imageRes.x || imageCoords.y >= imageRes.y)
return;

// STAGE 1; Spatial reconstruction filtering
ivec2 halfImageCoords = imageCoords / 2;

vec3 result = vec3(0.0);
vec4 centerColor = imageLoad(filteredReflectionColor, imageCoords);
float variance = centerColor.w;
vec3 result = vec3(0.8,0.0,0.0);
float weightSum = 0.0;

// Base sampling radius
float radius = min(imageRes.x, imageRes.y) * 0.003;

// Get center pixel values as reference
vec4 centerColor = imageLoad(reflectionColor, halfImageCoords);
if (centerColor.a == 0.0)
{
imageStore(filteredReflectionColor, imageCoords, vec4(0.0));
return;
}

vec4 centerDirection = imageLoad(reflectionDirection, halfImageCoords);
vec4 centerBrdf = imageLoad(reflectionPointBrdf, halfImageCoords);

// Add center sample contribution
float centerPdfInv = max(centerDirection.a, 0.001);
float centerWeight = max(centerColor.a * centerPdfInv, 0.0);
vec3 centerContrib = centerColor.rgb * centerWeight;
result += centerContrib;
weightSum += centerWeight;

uint quadID = (imageCoords.y & 0x1) << 1 + (imageCoords.x & 0x1);
float variance = 0.0;

for (int i = 0; i < 16; ++i)
{
ivec2 offset = samplePattern[quadID][i];
ivec2 neighborCoords = halfImageCoords + offset;

vec4 neighborColor = imageLoad(reflectionColor, neighborCoords);
if (neighborColor.a == 0.0)
continue;

float brdfWeight = neighborColor.a;
float pdfInv = max(imageLoad(reflectionDirection, neighborCoords).a, 1e-5);

// Combine BRDF weight with distance weight
float weight = max(brdfWeight * pdfInv, 0.0);
vec3 contrib = neighborColor.rgb * weight;
vec3 diff = contrib - centerContrib;
variance += dot(diff, diff);

result += contrib;
weightSum += weight;
}

// Normalize result
if (weightSum > 1e-5) {
result /= weightSum;
} else {
result = centerColor.rgb;
}

// temporal accumulation
//float ratio = 1.0 / min(64.0, (1.0 + float(rtxState.frame)));
//vec3 colPrev = imageLoad(filteredReflectionColor, imageCoords).rgb;
//result = mix(colPrev, result, ratio);

imageStore(filteredReflectionColor, imageCoords, vec4(result, sqrt(variance)));
//barriering is fine for now, if we need to optimize, maybe spliting different stage into seperate shader will be good
memoryBarrier();
barrier();

// STAGE 2: bilateral filter
if(sqrt(variance) > 2.0 && !(length (centerColor.rgb ) < 0.001) ){
result = vec3(0.0);
weightSum = 0.0;

// Define sigma values
float sigmaS = clamp (sqrt(variance) , 2.0, 6.0);
float sigmaR = 0.5;
Expand Down Expand Up @@ -232,5 +166,5 @@ void main()
}
}

imageStore(filteredReflectionColor, imageCoords, vec4(result, 1.0));
imageStore(bilateralCleanupColor, imageCoords, vec4(result, 1.0));
}
4 changes: 2 additions & 2 deletions shaders/lightPass.frag
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,8 @@ void main()
else if (rtxState.debugging_mode == esEmissive)
fragColor.xyz = state.mat.emission;
else if (rtxState.debugging_mode == esReflectionBrdf){
uv = ( gl_FragCoord.xy) / vec2(textureSize(filteredReflectionColor,0));
fragColor.xyz = texture(filteredReflectionColor, uv).rgb;
uv = ( gl_FragCoord.xy) / vec2(textureSize(bilateralCleanupColor,0));
fragColor.xyz = texture(bilateralCleanupColor, uv).rgb;
fragColor.a = 1.0;
}
else
Expand Down
8 changes: 4 additions & 4 deletions src/bilateral_cleanup_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#include "nvvk/commands_vk.hpp"
#include "shaders/host_device.h"

#include "autogen/temporal_spatial_pass.comp.h"
#include "autogen/bilateral_cleanup_pass.comp.h"

void BilateralCleanupPass::setup(const VkDevice& device, const VkPhysicalDevice& physicalDevice, uint32_t familyIndex, nvvk::ResourceAllocator* allocator)
{
Expand Down Expand Up @@ -61,17 +61,17 @@ void BilateralCleanupPass::create(const VkExtent2D& fullSize, const std::vector<
VkComputePipelineCreateInfo computePipelineCreateInfo{ VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO };
computePipelineCreateInfo.layout = m_pipelineLayout;
computePipelineCreateInfo.stage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
computePipelineCreateInfo.stage.module = nvvk::createShaderModule(m_device, temporal_spatial_pass_comp, sizeof(temporal_spatial_pass_comp));
computePipelineCreateInfo.stage.module = nvvk::createShaderModule(m_device, bilateral_cleanup_pass_comp, sizeof(bilateral_cleanup_pass_comp));
computePipelineCreateInfo.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT;
computePipelineCreateInfo.stage.pName = "main";

vkCreateComputePipelines(m_device, {}, 1, &computePipelineCreateInfo, nullptr, &m_pipeline);

m_debug.setObjectName(m_pipeline, "Temporal Spatial Denoise Compute Pass");
m_debug.setObjectName(m_pipeline, "Bilateral Clean Up Pass");
vkDestroyShaderModule(m_device, computePipelineCreateInfo.stage.module, nullptr);
}

const std::string BilateralCleanupPass::name()
{
return "Reflection Compute Pass";
return "Bilateral Clean Up Pass";
}
10 changes: 9 additions & 1 deletion src/sample_example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ void SampleExample::setup(const VkInstance& instance,
m_surfelIntegratePass.setup(m_device, physicalDevice, queues[eGCT0].familyIndex, &m_alloc);
m_reflectionComputePass.setup(m_device, physicalDevice, queues[eGCT0].familyIndex, &m_alloc);
m_temporalSpatialPass.setup(m_device, physicalDevice, queues[eGCT0].familyIndex, &m_alloc);
m_bilateralCleanupPass.setup(m_device, physicalDevice, queues[eGCT0].familyIndex, &m_alloc);
m_lightPass.setup(m_device, physicalDevice, queues[eGCT0].familyIndex, &m_alloc);

// Create and setup all renderers
Expand Down Expand Up @@ -368,6 +369,9 @@ void SampleExample::createReflectionPass()

m_temporalSpatialPass.create(m_size, {
m_reflectionComputePass.getSamplerDescSetLayout() }, &m_scene);

m_bilateralCleanupPass.create(m_size, {
m_reflectionComputePass.getSamplerDescSetLayout() }, &m_scene);
}

//--------------------------------------------------------------------------------------------------
Expand Down Expand Up @@ -838,7 +842,7 @@ void SampleExample::computeReflection(const VkCommandBuffer& cmdBuf, nvvk::Profi

m_reflectionComputePass.setPushContants(m_rtxState);
m_temporalSpatialPass.setPushContants(m_rtxState);

m_bilateralCleanupPass.setPushContants(m_rtxState);

m_reflectionComputePass.run(cmdBuf, render_size, profiler, {
m_accelStruct.getDescSet(),
Expand Down Expand Up @@ -873,4 +877,8 @@ void SampleExample::computeReflection(const VkCommandBuffer& cmdBuf, nvvk::Profi
}

vkCmdPipelineBarrier(cmdBuf, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VK_PIPELINE_STAGE_FRAGMENT_SHADER_BIT, 0, 0, nullptr, 0, nullptr, static_cast<uint32_t>(barriers.size()), barriers.data());

m_bilateralCleanupPass.run(cmdBuf, render_size, profiler, {
m_reflectionComputePass.getSamplerDescSet()
});
}
2 changes: 2 additions & 0 deletions src/sample_example.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ typedef nvvk::ResourceAllocatorDedicated Allocator;
#include "lightPass.h"
#include "reflection_compute.h"
#include "temporal_spatial_pass.h"
#include "bilateral_cleanup_pass.h"

class SampleGUI;

Expand Down Expand Up @@ -159,6 +160,7 @@ class SampleExample : public nvvkhl::AppBaseVk
// reflection compute passes
ReflectionComputePass m_reflectionComputePass;
TemporalSpatialPass m_temporalSpatialPass;
BilateralCleanupPass m_bilateralCleanupPass;

// light pass
LightPass m_lightPass;
Expand Down

0 comments on commit aa9783e

Please sign in to comment.