diff --git a/shaders/bilateral_cleanup_pass.comp b/shaders/bilateral_cleanup_pass.comp index 4dae106..fc56c57 100644 --- a/shaders/bilateral_cleanup_pass.comp +++ b/shaders/bilateral_cleanup_pass.comp @@ -31,7 +31,7 @@ layout(set = 0, binding = 0) uniform image2D reflectionColor; layout(set = 0, binding = 1) uniform image2D reflectionDirection; layout(set = 0, binding = 2) uniform image2D reflectionPointBrdf; layout(set = 0, binding = 3) uniform image2D filteredReflectionColor; -layout(set = 5, binding = 4) uniform image2D bilateralCleanupColor; +layout(set = 0, binding = 4) uniform image2D bilateralCleanupColor; const ivec2 samplePattern[4][16] = ivec2[4][16]( ivec2[16]( @@ -125,79 +125,13 @@ void main() if (imageCoords.x >= imageRes.x || imageCoords.y >= imageRes.y) return; - // STAGE 1; Spatial reconstruction filtering - ivec2 halfImageCoords = imageCoords / 2; - - vec3 result = vec3(0.0); + vec4 centerColor = imageLoad(filteredReflectionColor, imageCoords); + float variance = centerColor.w; + vec3 result = vec3(0.8,0.0,0.0); float weightSum = 0.0; - - // Base sampling radius - float radius = min(imageRes.x, imageRes.y) * 0.003; - - // Get center pixel values as reference - vec4 centerColor = imageLoad(reflectionColor, halfImageCoords); - if (centerColor.a == 0.0) - { - imageStore(filteredReflectionColor, imageCoords, vec4(0.0)); - return; - } - - vec4 centerDirection = imageLoad(reflectionDirection, halfImageCoords); - vec4 centerBrdf = imageLoad(reflectionPointBrdf, halfImageCoords); - - // Add center sample contribution - float centerPdfInv = max(centerDirection.a, 0.001); - float centerWeight = max(centerColor.a * centerPdfInv, 0.0); - vec3 centerContrib = centerColor.rgb * centerWeight; - result += centerContrib; - weightSum += centerWeight; - uint quadID = (imageCoords.y & 0x1) << 1 + (imageCoords.x & 0x1); - float variance = 0.0; - - for (int i = 0; i < 16; ++i) - { - ivec2 offset = samplePattern[quadID][i]; - ivec2 neighborCoords = halfImageCoords + 
offset; - - vec4 neighborColor = imageLoad(reflectionColor, neighborCoords); - if (neighborColor.a == 0.0) - continue; - - float brdfWeight = neighborColor.a; - float pdfInv = max(imageLoad(reflectionDirection, neighborCoords).a, 1e-5); - - // Combine BRDF weight with distance weight - float weight = max(brdfWeight * pdfInv, 0.0); - vec3 contrib = neighborColor.rgb * weight; - vec3 diff = contrib - centerContrib; - variance += dot(diff, diff); - - result += contrib; - weightSum += weight; - } - - // Normalize result - if (weightSum > 1e-5) { - result /= weightSum; - } else { - result = centerColor.rgb; - } - - // temporal accumulation - //float ratio = 1.0 / min(64.0, (1.0 + float(rtxState.frame))); - //vec3 colPrev = imageLoad(filteredReflectionColor, imageCoords).rgb; - //result = mix(colPrev, result, ratio); - - imageStore(filteredReflectionColor, imageCoords, vec4(result, sqrt(variance))); - //barriering is fine for now, if we need to optimize, maybe spliting different stage into seperate shader will be good - memoryBarrier(); - barrier(); - - // STAGE 2: bilateral filter if(sqrt(variance) > 2.0 && !(length (centerColor.rgb ) < 0.001) ){ - result = vec3(0.0); - weightSum = 0.0; + // Define sigma values float sigmaS = clamp (sqrt(variance) , 2.0, 6.0); float sigmaR = 0.5; @@ -232,5 +166,5 @@ void main() } } - imageStore(filteredReflectionColor, imageCoords, vec4(result, 1.0)); + imageStore(bilateralCleanupColor, imageCoords, vec4(result, 1.0)); } \ No newline at end of file diff --git a/shaders/lightPass.frag b/shaders/lightPass.frag index 68e77a4..63d53c8 100644 --- a/shaders/lightPass.frag +++ b/shaders/lightPass.frag @@ -188,8 +188,8 @@ void main() else if (rtxState.debugging_mode == esEmissive) fragColor.xyz = state.mat.emission; else if (rtxState.debugging_mode == esReflectionBrdf){ - uv = ( gl_FragCoord.xy) / vec2(textureSize(filteredReflectionColor,0)); - fragColor.xyz = texture(filteredReflectionColor, uv).rgb; + uv = ( gl_FragCoord.xy) / 
vec2(textureSize(bilateralCleanupColor,0)); + fragColor.xyz = texture(bilateralCleanupColor, uv).rgb; fragColor.a = 1.0; } else diff --git a/src/bilateral_cleanup_pass.cpp b/src/bilateral_cleanup_pass.cpp index 59312ae..4360668 100644 --- a/src/bilateral_cleanup_pass.cpp +++ b/src/bilateral_cleanup_pass.cpp @@ -9,7 +9,7 @@ #include "nvvk/commands_vk.hpp" #include "shaders/host_device.h" -#include "autogen/temporal_spatial_pass.comp.h" +#include "autogen/bilateral_cleanup_pass.comp.h" void BilateralCleanupPass::setup(const VkDevice& device, const VkPhysicalDevice& physicalDevice, uint32_t familyIndex, nvvk::ResourceAllocator* allocator) { @@ -61,17 +61,17 @@ void BilateralCleanupPass::create(const VkExtent2D& fullSize, const std::vector< VkComputePipelineCreateInfo computePipelineCreateInfo{ VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO }; computePipelineCreateInfo.layout = m_pipelineLayout; computePipelineCreateInfo.stage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO; - computePipelineCreateInfo.stage.module = nvvk::createShaderModule(m_device, temporal_spatial_pass_comp, sizeof(temporal_spatial_pass_comp)); + computePipelineCreateInfo.stage.module = nvvk::createShaderModule(m_device, bilateral_cleanup_pass_comp, sizeof(bilateral_cleanup_pass_comp)); computePipelineCreateInfo.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT; computePipelineCreateInfo.stage.pName = "main"; vkCreateComputePipelines(m_device, {}, 1, &computePipelineCreateInfo, nullptr, &m_pipeline); - m_debug.setObjectName(m_pipeline, "Temporal Spatial Denoise Compute Pass"); + m_debug.setObjectName(m_pipeline, "Bilateral Clean Up Pass"); vkDestroyShaderModule(m_device, computePipelineCreateInfo.stage.module, nullptr); } const std::string BilateralCleanupPass::name() { - return "Reflection Compute Pass"; + return "Bilateral Clean Up Pass"; } \ No newline at end of file diff --git a/src/sample_example.cpp b/src/sample_example.cpp index ab482b1..f9d4dcf 100644 --- a/src/sample_example.cpp +++ 
b/src/sample_example.cpp @@ -84,6 +84,7 @@ void SampleExample::setup(const VkInstance& instance, m_surfelIntegratePass.setup(m_device, physicalDevice, queues[eGCT0].familyIndex, &m_alloc); m_reflectionComputePass.setup(m_device, physicalDevice, queues[eGCT0].familyIndex, &m_alloc); m_temporalSpatialPass.setup(m_device, physicalDevice, queues[eGCT0].familyIndex, &m_alloc); + m_bilateralCleanupPass.setup(m_device, physicalDevice, queues[eGCT0].familyIndex, &m_alloc); m_lightPass.setup(m_device, physicalDevice, queues[eGCT0].familyIndex, &m_alloc); // Create and setup all renderers @@ -368,6 +369,9 @@ void SampleExample::createReflectionPass() m_temporalSpatialPass.create(m_size, { m_reflectionComputePass.getSamplerDescSetLayout() }, &m_scene); + + m_bilateralCleanupPass.create(m_size, { + m_reflectionComputePass.getSamplerDescSetLayout() }, &m_scene); } //-------------------------------------------------------------------------------------------------- @@ -838,7 +842,7 @@ void SampleExample::computeReflection(const VkCommandBuffer& cmdBuf, nvvk::Profi m_reflectionComputePass.setPushContants(m_rtxState); m_temporalSpatialPass.setPushContants(m_rtxState); - + m_bilateralCleanupPass.setPushContants(m_rtxState); m_reflectionComputePass.run(cmdBuf, render_size, profiler, { m_accelStruct.getDescSet(), @@ -873,4 +877,8 @@ void SampleExample::computeReflection(const VkCommandBuffer& cmdBuf, nvvk::Profi } vkCmdPipelineBarrier(cmdBuf, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VK_PIPELINE_STAGE_FRAGMENT_SHADER_BIT, 0, 0, nullptr, 0, nullptr, static_cast<uint32_t>(barriers.size()), barriers.data()); + + m_bilateralCleanupPass.run(cmdBuf, render_size, profiler, { + m_reflectionComputePass.getSamplerDescSet() + }); } diff --git a/src/sample_example.hpp b/src/sample_example.hpp index 8e9eb39..f9271b7 100644 --- a/src/sample_example.hpp +++ b/src/sample_example.hpp @@ -79,6 +79,7 @@ typedef nvvk::ResourceAllocatorDedicated Allocator; #include "lightPass.h" #include "reflection_compute.h" #include 
"temporal_spatial_pass.h" +#include "bilateral_cleanup_pass.h" class SampleGUI; @@ -159,6 +160,7 @@ class SampleExample : public nvvkhl::AppBaseVk // reflection compute passes ReflectionComputePass m_reflectionComputePass; TemporalSpatialPass m_temporalSpatialPass; + BilateralCleanupPass m_bilateralCleanupPass; // light pass LightPass m_lightPass;