Skip to content

Commit

Permalink
Add bilateral pass
Browse files Browse the repository at this point in the history
  • Loading branch information
WANG-Ruipeng committed Dec 6, 2024
1 parent 4f7acc5 commit a622dd6
Show file tree
Hide file tree
Showing 7 changed files with 393 additions and 30 deletions.
236 changes: 236 additions & 0 deletions shaders/bilateral_cleanup_pass.comp
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@

#version 460
#extension GL_ARB_separate_shader_objects : enable
#extension GL_EXT_nonuniform_qualifier : enable
#extension GL_GOOGLE_include_directive : enable
#extension GL_EXT_scalar_block_layout : enable
#extension GL_EXT_ray_tracing : enable
#extension GL_EXT_ray_query : enable
#extension GL_ARB_shader_clock : enable // Using clockARB
#extension GL_EXT_shader_image_load_formatted : enable // The folowing extension allow to pass images as function parameters

#extension GL_NV_shader_sm_builtins : require // Debug - gl_WarpIDNV, gl_SMIDNV
#extension GL_ARB_gpu_shader_int64 : enable // Debug - heatmap value
#extension GL_EXT_shader_realtime_clock : enable // Debug - heatmap timing

#extension GL_EXT_shader_explicit_arithmetic_types_int64 : require
#extension GL_EXT_buffer_reference2 : require
#extension GL_EXT_debug_printf : enable
#extension GL_KHR_vulkan_glsl : enable

#include "host_device.h"

layout(push_constant) uniform _RtxState
{
RtxState rtxState;
};

#include "globals.glsl"

layout(set = 0, binding = 0) uniform image2D reflectionColor;
layout(set = 0, binding = 1) uniform image2D reflectionDirection;
layout(set = 0, binding = 2) uniform image2D reflectionPointBrdf;
layout(set = 0, binding = 3) uniform image2D filteredReflectionColor;
layout(set = 5, binding = 4) uniform image2D bilateralCleanupColor;

const ivec2 samplePattern[4][16] = ivec2[4][16](
ivec2[16](
ivec2(-4, 0),
ivec2(-4, 1),
ivec2(-3, -2),
ivec2(-3, 3),
ivec2(-2, 2),
ivec2(-1, -4),
ivec2(-1, -2),
ivec2(-1, 4),
ivec2(0, -1),
ivec2(0, 1),
ivec2(1, -2),
ivec2(1, 2),
ivec2(2, -2),
ivec2(3, -1),
ivec2(3, 1),
ivec2(3, 3)),
ivec2[16](
ivec2(-3, -3),
ivec2(-3, 0),
ivec2(-3, 2),
ivec2(-2, -1),
ivec2(-2, 0),
ivec2(-1, -3),
ivec2(-1, 2),
ivec2(0, 0),
ivec2(0, 2),
ivec2(1, -1),
ivec2(2, -3),
ivec2(2, 2),
ivec2(2, 3),
ivec2(3, -3),
ivec2(3, 0),
ivec2(4, 0)),
ivec2[16](
ivec2(-4, -1),
ivec2(-4, 2),
ivec2(-3, 1),
ivec2(-2, -3),
ivec2(-2, -2),
ivec2(-1, -1),
ivec2(-1, 1),
ivec2(-1, 3),
ivec2(1, -4),
ivec2(1, -3),
ivec2(1, 0),
ivec2(1, 3),
ivec2(1, 4),
ivec2(2, 0),
ivec2(2, 1),
ivec2(3, -2)),
ivec2[16](
ivec2(-4, -2),
ivec2(-3, -1),
ivec2(-2, 1),
ivec2(-2, 3),
ivec2(-1, 0),
ivec2(0, -4),
ivec2(0, -3),
ivec2(0, -2),
ivec2(0, 3),
ivec2(0, 4),
ivec2(1, 1),
ivec2(2, -1),
ivec2(2, 4),
ivec2(3, 2),
ivec2(4, -1),
ivec2(4, 1))
);

const vec2 poissonDisk[8] = vec2[](
vec2(-0.5, -0.5),
vec2( 0.5, -0.5),
vec2(-0.5, 0.5),
vec2( 0.5, 0.5),
vec2(-0.25, -0.75),
vec2( 0.25, -0.75),
vec2(-0.75, 0.25),
vec2( 0.75, 0.25)
);

layout(local_size_x = 8, local_size_y = 8, local_size_z = 1) in;

void main()
{
ivec2 imageRes = rtxState.size;
ivec2 imageCoords = ivec2(gl_GlobalInvocationID.xy);

if (imageCoords.x >= imageRes.x || imageCoords.y >= imageRes.y)
return;

// STAGE 1; Spatial reconstruction filtering
ivec2 halfImageCoords = imageCoords / 2;

vec3 result = vec3(0.0);
float weightSum = 0.0;

// Base sampling radius
float radius = min(imageRes.x, imageRes.y) * 0.003;

// Get center pixel values as reference
vec4 centerColor = imageLoad(reflectionColor, halfImageCoords);
if (centerColor.a == 0.0)
{
imageStore(filteredReflectionColor, imageCoords, vec4(0.0));
return;
}

vec4 centerDirection = imageLoad(reflectionDirection, halfImageCoords);
vec4 centerBrdf = imageLoad(reflectionPointBrdf, halfImageCoords);

// Add center sample contribution
float centerPdfInv = max(centerDirection.a, 0.001);
float centerWeight = max(centerColor.a * centerPdfInv, 0.0);
vec3 centerContrib = centerColor.rgb * centerWeight;
result += centerContrib;
weightSum += centerWeight;

uint quadID = (imageCoords.y & 0x1) << 1 + (imageCoords.x & 0x1);
float variance = 0.0;

for (int i = 0; i < 16; ++i)
{
ivec2 offset = samplePattern[quadID][i];
ivec2 neighborCoords = halfImageCoords + offset;

vec4 neighborColor = imageLoad(reflectionColor, neighborCoords);
if (neighborColor.a == 0.0)
continue;

float brdfWeight = neighborColor.a;
float pdfInv = max(imageLoad(reflectionDirection, neighborCoords).a, 1e-5);

// Combine BRDF weight with distance weight
float weight = max(brdfWeight * pdfInv, 0.0);
vec3 contrib = neighborColor.rgb * weight;
vec3 diff = contrib - centerContrib;
variance += dot(diff, diff);

result += contrib;
weightSum += weight;
}

// Normalize result
if (weightSum > 1e-5) {
result /= weightSum;
} else {
result = centerColor.rgb;
}

// temporal accumulation
//float ratio = 1.0 / min(64.0, (1.0 + float(rtxState.frame)));
//vec3 colPrev = imageLoad(filteredReflectionColor, imageCoords).rgb;
//result = mix(colPrev, result, ratio);

imageStore(filteredReflectionColor, imageCoords, vec4(result, sqrt(variance)));
//barriering is fine for now, if we need to optimize, maybe spliting different stage into seperate shader will be good
memoryBarrier();
barrier();

// STAGE 2: bilateral filter
if(sqrt(variance) > 2.0 && !(length (centerColor.rgb ) < 0.001) ){
result = vec3(0.0);
weightSum = 0.0;
// Define sigma values
float sigmaS = clamp (sqrt(variance) , 2.0, 6.0);
float sigmaR = 0.5;
// Define a kernel radius for the spatial domain
int kernelRadius = int(0.25 * sigmaS);
for (int y = -kernelRadius; y <= kernelRadius; ++y) {
for (int x = -kernelRadius; x <= kernelRadius; ++x) {
ivec2 neighborCoords = imageCoords + ivec2(x, y);
// Boundary check
if (neighborCoords.x < 0 || neighborCoords.y < 0 ||
neighborCoords.x >= imageRes.x || neighborCoords.y >= imageRes.y)
continue;
vec4 neighborColor = imageLoad(filteredReflectionColor, neighborCoords);
if(length (neighborColor.rgb) < 0.001) continue;
// Spatial weight
float dist2 = float(x * x + y * y);
float spatialWeight = exp(-dist2 / (2.0 * sigmaS * sigmaS));
// Range weight
float colorDiff = length(neighborColor.rgb - centerColor.rgb);
float rangeWeight = exp(-colorDiff * colorDiff / (2.0 * sigmaR * sigmaR));
// Combined weight
float weight = spatialWeight * rangeWeight;
result += neighborColor.rgb * weight;
weightSum += weight;
}
}
// Normalize result
if (weightSum > 1e-5) {
result /= weightSum;
} else {
result = centerColor.rgb;
}
}

imageStore(filteredReflectionColor, imageCoords, vec4(result, 1.0));
}
9 changes: 5 additions & 4 deletions shaders/lightPass.frag
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,11 @@ layout(set = 4, binding = 2) uniform sampler2D gbufferDepth;

layout(set = 5, binding = eSampler) uniform sampler2D indirectLightMap;

layout(set = 6, binding = 4) uniform sampler2D reflectionColor;
layout(set = 6, binding = 5) uniform sampler2D reflectionDirection;
layout(set = 6, binding = 6) uniform sampler2D reflectionPointBrdf;
layout(set = 6, binding = 7) uniform sampler2D filteredReflectionColor;
layout(set = 6, binding = 5) uniform sampler2D reflectionColor;
layout(set = 6, binding = 6) uniform sampler2D reflectionDirection;
layout(set = 6, binding = 7) uniform sampler2D reflectionPointBrdf;
layout(set = 6, binding = 8) uniform sampler2D filteredReflectionColor;
layout(set = 6, binding = 9) uniform sampler2D bilateralCleanupColor;

vec3 hsv2rgb(vec3 c) {
vec4 K = vec4(1.0, 2.0 / 3.0, 1.0 / 3.0, 3.0);
Expand Down
1 change: 1 addition & 0 deletions shaders/reflection_generation.comp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ layout(set = 5, binding = 0) uniform image2D reflectionColor;
layout(set = 5, binding = 1) uniform image2D reflectionDirection;
layout(set = 5, binding = 2) uniform image2D reflectionPointBrdf;
layout(set = 5, binding = 3) uniform image2D filteredReflectionColor;
layout(set = 5, binding = 4) uniform image2D bilateralCleanupColor;

float brdfWeight(vec3 V, vec3 N, vec3 L, float roughness)
{
Expand Down
20 changes: 1 addition & 19 deletions shaders/temporal_spatial_pass.comp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ layout(set = 0, binding = 0) uniform image2D reflectionColor;
layout(set = 0, binding = 1) uniform image2D reflectionDirection;
layout(set = 0, binding = 2) uniform image2D reflectionPointBrdf;
layout(set = 0, binding = 3) uniform image2D filteredReflectionColor;
layout(set = 5, binding = 4) uniform image2D bilateralCleanupColor;

const ivec2 samplePattern[4][16] = ivec2[4][16](
ivec2[16](
Expand Down Expand Up @@ -175,25 +176,6 @@ void main()
weightSum += weight;
}

// Sample with distance-based weights
// for (int i = 0; i < 8; i++) {
// ivec2 offset = ivec2(poissonDisk[i] * radius);
// ivec2 neighborCoords = halfImageCoords + offset;
//
// vec4 neighborColor = imageLoad(reflectionColor, neighborCoords);
// if (neighborColor.a == 0.0)
// continue;
//
// float brdfWeight = neighborColor.a;
// float pdfInv = max(imageLoad(reflectionDirection, neighborCoords).a, 1e-5);
//
// // Combine BRDF weight with distance weight
// float weight = max(brdfWeight * pdfInv, 0.0);
//
// result += neighborColor.rgb * weight;
// weightSum += weight;
// }

// Normalize result
if (weightSum > 1e-5) {
result /= weightSum;
Expand Down
77 changes: 77 additions & 0 deletions src/bilateral_cleanup_pass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
#include "bilateral_cleanup_pass.h"

#include "nvvk/images_vk.hpp"
#include "nvh/alignment.hpp"
#include "nvh/fileoperations.hpp"
#include "nvvk/shaders_vk.hpp"
#include "scene.hpp"
#include "tools.hpp"
#include "nvvk/commands_vk.hpp"
#include "shaders/host_device.h"

#include "autogen/temporal_spatial_pass.comp.h"

void BilateralCleanupPass::setup(const VkDevice& device, const VkPhysicalDevice& physicalDevice, uint32_t familyIndex, nvvk::ResourceAllocator* allocator)
{
m_device = device;
m_pAlloc = allocator;
m_queueIndex = familyIndex;
m_debug.setup(device);
}

void BilateralCleanupPass::destroy()
{
vkDestroyPipeline(m_device, m_pipeline, nullptr);
vkDestroyPipelineLayout(m_device, m_pipelineLayout, nullptr);

m_pipelineLayout = VK_NULL_HANDLE;
m_pipeline = VK_NULL_HANDLE;
}

void BilateralCleanupPass::run(const VkCommandBuffer& cmdBuf, const VkExtent2D& size, nvvk::ProfilerVK& profiler, const std::vector<VkDescriptorSet>& descSets)
{
LABEL_SCOPE_VK(cmdBuf);
const int GROUP_SIZE = 8;
// Preparing for the compute shader
vkCmdBindPipeline(cmdBuf, VK_PIPELINE_BIND_POINT_COMPUTE, m_pipeline);
vkCmdBindDescriptorSets(cmdBuf, VK_PIPELINE_BIND_POINT_COMPUTE, m_pipelineLayout, 0,
static_cast<uint32_t>(descSets.size()), descSets.data(), 0, nullptr);

// Sending the push constant information
vkCmdPushConstants(cmdBuf, m_pipelineLayout, VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(RtxState), &m_state);

// Dispatching the shader
vkCmdDispatch(cmdBuf, (size.width + (GROUP_SIZE - 1)) / GROUP_SIZE, (size.height + (GROUP_SIZE - 1)) / GROUP_SIZE, 1);
}

void BilateralCleanupPass::create(const VkExtent2D& fullSize, const std::vector<VkDescriptorSetLayout>& extraDescSetsLayout, Scene* _scene)
{
VkExtent2D size = { fullSize.width, fullSize.height};

std::vector<VkPushConstantRange> push_constants;
push_constants.push_back({ VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(RtxState) });

VkPipelineLayoutCreateInfo layout_info{ VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO };
layout_info.pushConstantRangeCount = static_cast<uint32_t>(push_constants.size());
layout_info.pPushConstantRanges = push_constants.data();
layout_info.setLayoutCount = static_cast<uint32_t>(extraDescSetsLayout.size());
layout_info.pSetLayouts = extraDescSetsLayout.data();
vkCreatePipelineLayout(m_device, &layout_info, nullptr, &m_pipelineLayout);

VkComputePipelineCreateInfo computePipelineCreateInfo{ VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO };
computePipelineCreateInfo.layout = m_pipelineLayout;
computePipelineCreateInfo.stage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
computePipelineCreateInfo.stage.module = nvvk::createShaderModule(m_device, temporal_spatial_pass_comp, sizeof(temporal_spatial_pass_comp));
computePipelineCreateInfo.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT;
computePipelineCreateInfo.stage.pName = "main";

vkCreateComputePipelines(m_device, {}, 1, &computePipelineCreateInfo, nullptr, &m_pipeline);

m_debug.setObjectName(m_pipeline, "Temporal Spatial Denoise Compute Pass");
vkDestroyShaderModule(m_device, computePipelineCreateInfo.stage.module, nullptr);
}

const std::string BilateralCleanupPass::name()
{
return "Reflection Compute Pass";
}
Loading

0 comments on commit a622dd6

Please sign in to comment.