Skip to content

Commit

Permalink
Optimize ReSTIR pipeline performance
Browse files Browse the repository at this point in the history
  • Loading branch information
LeonKang130 committed Jun 1, 2024
1 parent cffce10 commit 02f7a8a
Showing 1 changed file with 37 additions and 117 deletions.
154 changes: 37 additions & 117 deletions src/integrators/restir_di.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,8 @@ class ReSTIRDirectLightingInstance final : public ProgressiveIntegrator::Instanc
Buffer<float4x4> _prev_frame_view_matrix;
void _temporal_pass(const Camera::Instance *camera, Expr<uint> frame_index,
Expr<uint2> pixel_id, Expr<float> time,
Expr<uint> num_initial_sample, Expr<bool> enable_visibility_reuse, Expr<bool> enable_temporal_reuse,
Expr<bool> enable_decorrelation) const noexcept {
Expr<uint> num_initial_sample, Expr<bool> enable_visibility_reuse, Expr<bool> enable_temporal_reuse
) const noexcept {
auto resolution = camera->film()->node()->resolution();
sampler()->start(pixel_id, frame_index);
auto u_filter = sampler()->generate_pixel_2d();
Expand Down Expand Up @@ -277,34 +277,6 @@ class ReSTIRDirectLightingInstance final : public ProgressiveIntegrator::Instanc
};
};
};
// perturb the light samples to reduce correlation, use Metropolis to determine whether to accept the perturbation
$if(enable_decorrelation & reservoir.weight.total_weight > 0.f) {
auto MARKOV_CHAIN_LENGTH = 4u;
auto sample_box_muller = [](Expr<float2> u) noexcept {
auto r = sqrt(clamp(-2.f * log(u.x), 0.f, 1.f));
auto theta = 2.f * pi * u.y;
return make_float2(r * cos(theta), r * sin(theta));
};
$for(markov_iter, MARKOV_CHAIN_LENGTH) {
auto candidate = reservoir;
auto perturbation = 0.025f * sample_box_muller(sampler()->generate_2d());
candidate.sample.u_light_surface = reservoir.sample.u_light_surface + perturbation;
$if(any(candidate.sample.u_light_surface < 0.f) | any(candidate.sample.u_light_surface > 1.f)) { $continue; };
auto sample_target_pdf = def(0.f);
$if(enable_visibility_reuse) {
auto [L, _] = _evaluate_with_occlusion(candidate.sample, *it, wo, swl, time);
sample_target_pdf = pipeline().spectrum()->cie_y(swl, L);
} $else {
auto [L, _] = _evaluate_without_occlusion(candidate.sample, *it, wo, swl, time);
sample_target_pdf = pipeline().spectrum()->cie_y(swl, L);
};
auto accepting_prob = min(1.f, sample_target_pdf / candidate.weight.target_pdf);
candidate.weight.target_pdf = sample_target_pdf;
$if(sampler()->generate_1d() < accepting_prob) {
reservoir = candidate;
};
};
};
$break;
};
$if(dsl::isnan(reservoir.weight.total_weight) | dsl::isnan(reservoir.weight.target_pdf)) {
Expand Down Expand Up @@ -334,8 +306,6 @@ class ReSTIRDirectLightingInstance final : public ProgressiveIntegrator::Instanc
ArrayFloat<3u> valid_neighbor_m_array;
auto num_valid_neighbor = def(0u);
auto z = reservoir.weight.m;
auto depth_projector = inverse(camera->camera_to_world())[2];
auto current_pixel_depth = dot(depth_projector.xyz(), it->p()) + depth_projector.w;
$for(_, num_neighbor_sample) {
auto u_radius = sampler()->generate_1d(), u_theta = sampler()->generate_1d();
auto radius = neighbor_radius * sqrt(u_radius);
Expand All @@ -347,9 +317,7 @@ class ReSTIRDirectLightingInstance final : public ProgressiveIntegrator::Instanc
auto neighbor_hit = _visibility_buffer->hit(neighbor_id);
auto neighbor_it = pipeline().geometry()->interaction(neighbor_ray, neighbor_hit);
$if(neighbor_it->valid() & neighbor_it->shape().has_surface()) {
auto neighbor_pixel_depth = dot(depth_projector.xyz(), neighbor_it->p()) + depth_projector.w;
$if(abs(neighbor_pixel_depth - current_pixel_depth) < 0.1f * abs(current_pixel_depth) &
dot(it->ng(), neighbor_it->ng()) > 0.91f) {
$if(dot(it->ng(), neighbor_it->ng()) > 0.9f) {
auto neighbor_reservoir = _spatial_reservoir_buffer->read(neighbor_id);
$if(dsl::isnan(neighbor_reservoir.weight.total_weight) | dsl::isnan(neighbor_reservoir.weight.target_pdf) | neighbor_reservoir.weight.m == 0.f) { $continue; };
auto neighbor_target_pdf = def(0.f);
Expand Down Expand Up @@ -397,7 +365,8 @@ class ReSTIRDirectLightingInstance final : public ProgressiveIntegrator::Instanc
_temporal_reservoir_buffer->write(reservoir, pixel_id);
}
[[nodiscard]] Float3 Li(const Camera::Instance *camera, Expr<uint> frame_index, Expr<uint2> pixel_id, Expr<float> time,
Expr<bool> enable_spatial_reuse, Expr<bool> unbiased, Expr<bool> enable_visibility_reuse) const noexcept {
Expr<bool> enable_spatial_reuse, Expr<bool> unbiased, Expr<bool> enable_visibility_reuse,
Expr<bool> enable_decorrelation) const noexcept {
auto resolution = camera->film()->node()->resolution();
sampler()->start(pixel_id, frame_index << 1u | 1u);
auto spectrum = pipeline().spectrum();
Expand Down Expand Up @@ -428,83 +397,34 @@ class ReSTIRDirectLightingInstance final : public ProgressiveIntegrator::Instanc
// compute direct lighting
$if(!it->shape().has_surface()) { $break; };
$if(enable_spatial_reuse) {
$outline {
reservoir = _temporal_reservoir_buffer->read(pixel_id);
auto num_neighbor_sample = ite(unbiased, 3u, 5u);
auto constexpr neighbor_radius = 30.f;
ArrayVar<Ray, 3u> valid_neighbor_ray_array;
ArrayVar<Hit, 3u> valid_neighbor_hit_array;
ArrayFloat<3u> valid_neighbor_m_array;
auto num_valid_neighbor = def(0u);
auto z = reservoir.weight.m;
auto camera_to_world = camera->camera_to_world();
auto world_to_camera = inverse(camera_to_world);
auto current_pixel_depth = (world_to_camera * make_float4(it->p(), 1.f)).z;
$for(_, num_neighbor_sample) {
auto u_radius = sampler()->generate_1d(), u_theta = sampler()->generate_1d();
auto radius = neighbor_radius * sqrt(u_radius);
auto theta = 2.f * pi * u_theta;
auto offset = make_float2(radius * cos(theta), radius * sin(theta));
auto neighbor_id = make_uint2(clamp(make_float2(pixel_id) + offset, make_float2(0.f), make_float2(resolution) - 1.f));
$if(all(neighbor_id == pixel_id)) { $continue; };
auto neighbor_ray = _visibility_buffer->ray(neighbor_id);
auto neighbor_hit = _visibility_buffer->hit(neighbor_id);
auto neighbor_it = pipeline().geometry()->interaction(neighbor_ray, neighbor_hit);
$if(neighbor_it->valid() & neighbor_it->shape().has_surface()) {
auto neighbor_pixel_depth = (world_to_camera * make_float4(neighbor_it->p(), 1.f)).z;
$if(abs(neighbor_pixel_depth - current_pixel_depth) < 0.05f * abs(current_pixel_depth) &
dot(it->ng(), neighbor_it->ng()) > 0.91f) {
auto neighbor_reservoir = Reservoir::zero();
neighbor_reservoir = _temporal_reservoir_buffer->read(neighbor_id);
$if(dsl::isnan(neighbor_reservoir.weight.total_weight) | dsl::isnan(neighbor_reservoir.weight.target_pdf) | neighbor_reservoir.weight.m == 0.f) { $continue; };
auto neighbor_target_pdf = def(0.f);
$if(enable_visibility_reuse & unbiased) {
auto [L, _] = _evaluate_with_occlusion(neighbor_reservoir.sample, *it, wo, swl, time);
neighbor_target_pdf = pipeline().spectrum()->cie_y(swl, L);
} $else {
auto [L, _] = _evaluate_without_occlusion(neighbor_reservoir.sample, *it, wo, swl, time);
neighbor_target_pdf = pipeline().spectrum()->cie_y(swl, L);
};
neighbor_reservoir.weight.total_weight *= ite(neighbor_reservoir.weight.target_pdf == 0.f, 0.f, neighbor_target_pdf / neighbor_reservoir.weight.target_pdf);
neighbor_reservoir.weight.target_pdf = neighbor_target_pdf;
reservoir.update(neighbor_reservoir, sampler()->generate_1d());
$if(unbiased) {
valid_neighbor_ray_array[num_valid_neighbor] = neighbor_ray;
valid_neighbor_hit_array[num_valid_neighbor] = neighbor_hit;
valid_neighbor_m_array[num_valid_neighbor] = neighbor_reservoir.weight.m;
num_valid_neighbor += 1u;
};
};
};
};
$if(unbiased) {
$for(neighbor_index, num_valid_neighbor) {
auto neighbor_ray = valid_neighbor_ray_array[neighbor_index];
auto neighbor_hit = valid_neighbor_hit_array[neighbor_index];
auto neighbor_it = pipeline().geometry()->interaction(neighbor_ray, neighbor_hit);
auto out_of_domain = def(true);
$if(enable_visibility_reuse) {
auto [L, _] = _evaluate_with_occlusion(reservoir.sample, *neighbor_it, -neighbor_ray->direction(), swl, time);
out_of_domain = L.is_zero();
} $else {
auto [L, _] = _evaluate_without_occlusion(reservoir.sample, *neighbor_it, -neighbor_ray->direction(), swl, time);
out_of_domain = L.is_zero();
};
$if(!out_of_domain) {
z += valid_neighbor_m_array[neighbor_index];
};
};
reservoir.weight.total_weight *= reservoir.weight.m / z;
reservoir.weight.m = z;
};
};
reservoir = _temporal_reservoir_buffer->read(pixel_id);
}
$else {
reservoir = _spatial_reservoir_buffer->read(pixel_id);
};
// perturb the light samples to reduce correlation, use Metropolis to determine whether to accept the perturbation
// (skipped entirely when enable_decorrelation is false or the reservoir's total weight is not positive)
$if(enable_decorrelation & reservoir.weight.total_weight > 0.f) {
// number of perturbation proposals attempted for this reservoir
auto MARKOV_CHAIN_LENGTH = 2u;
// maps a 2D uniform sample to a 2D offset in Box-Muller form (radius from u.x, angle from u.y).
// The clamp bounds -2*log(u.x) to [0, 1], so the radius never exceeds 1; it saturates at 1
// whenever u.x < exp(-0.5).
// NOTE(review): this truncates the Gaussian radius — confirm the clamp range is intended.
auto sample_box_muller = [](Expr<float2> u) noexcept {
auto r = sqrt(clamp(-2.f * log(u.x), 0.f, 1.f));
auto theta = 2.f * pi * u.y;
return make_float2(r * cos(theta), r * sin(theta));
};
$for(markov_iter, MARKOV_CHAIN_LENGTH) {
// start each proposal from the current reservoir state
auto candidate = reservoir;
// proposal offset: the sampled 2D offset scaled by 0.025
auto perturbation = 0.025f * sample_box_muller(sampler()->generate_2d());
candidate.sample.u_light_surface = reservoir.sample.u_light_surface + perturbation;
// drop proposals where any component of u_light_surface falls outside [0, 1]
$if(any(candidate.sample.u_light_surface < 0.f) | any(candidate.sample.u_light_surface > 1.f)) { $continue; };
auto sample_target_pdf = def(0.f);
// target value of the proposal: cie_y of the L returned by _evaluate_with_occlusion
auto [L, _] = _evaluate_with_occlusion(candidate.sample, *it, wo, swl, time);
sample_target_pdf = pipeline().spectrum()->cie_y(swl, L);
// acceptance probability: new target value over the stored one, capped at 1
// NOTE(review): there is no explicit guard for candidate.weight.target_pdf == 0 here —
// confirm it cannot be zero when total_weight > 0.
auto accepting_prob = min(1.f, sample_target_pdf / candidate.weight.target_pdf);
// store the proposal's target value before the accept test so an accepted candidate carries it
candidate.weight.target_pdf = sample_target_pdf;
$if(sampler()->generate_1d() < accepting_prob) { reservoir = candidate; };
};
};
auto [L, _] = _evaluate_with_occlusion(reservoir.sample, *it, wo, swl, time);
auto contribution_weight = reservoir.contribution_weight();
Li += weight * contribution_weight * L;
Li += weight * reservoir.contribution_weight() * L;
$break;
};
_spatial_reservoir_buffer->write(reservoir, pixel_id);
Expand Down Expand Up @@ -546,22 +466,23 @@ class ReSTIRDirectLightingInstance final : public ProgressiveIntegrator::Instanc
}
using namespace luisa::compute;
Kernel2D temporal_pass_kernel = [&](UInt frame_index, Float time, UInt num_initial_sample, Bool enable_visibility_reuse,
Bool enable_temporal_reuse, Bool enable_decorrelation) noexcept {
Bool enable_temporal_reuse) noexcept {
set_block_size(16u, 16u, 1u);
auto pixel_id = dispatch_id().xy();
_temporal_pass(camera, frame_index, pixel_id, time, num_initial_sample, enable_visibility_reuse,
enable_temporal_reuse, enable_decorrelation);
enable_temporal_reuse);
};
// 2D kernel wrapping _spatial_pass: takes the frame index, time, and the unbiased /
// visibility-reuse flags as kernel arguments; camera is captured by reference from the
// enclosing scope.
Kernel2D spatial_pass_kernel = [&](UInt frame_index, Float time, Bool unbiased, Bool enable_visibility_reuse) noexcept {
set_block_size(16u, 16u, 1u);
// the xy components of the dispatch id are used as the pixel coordinate
auto pixel_id = dispatch_id().xy();
_spatial_pass(camera, frame_index, pixel_id, time, unbiased, enable_visibility_reuse);
};
Kernel2D render_kernel = [&](UInt frame_index, Float time, Float shutter_weight,
Bool enable_spatial_reuse, Bool unbiased, Bool enable_visibility_reuse) noexcept {
Bool enable_spatial_reuse, Bool unbiased, Bool enable_visibility_reuse,
Bool enable_decorrelation) noexcept {
set_block_size(16u, 16u, 1u);
auto pixel_id = dispatch_id().xy();
auto L = Li(camera, frame_index, pixel_id, time, enable_spatial_reuse, unbiased, enable_visibility_reuse);
auto L = Li(camera, frame_index, pixel_id, time, enable_spatial_reuse, unbiased, enable_visibility_reuse, enable_decorrelation);
camera->film()->accumulate(pixel_id, shutter_weight * L);
$if(all(pixel_id == 0u)) {
auto view_matrix = inverse(camera->camera_to_world());
Expand Down Expand Up @@ -595,20 +516,19 @@ class ReSTIRDirectLightingInstance final : public ProgressiveIntegrator::Instanc
for (auto s : shutter_samples) {
pipeline().update(command_buffer, s.point.time);
for (auto i = 0u; i < s.spp; i++) {
// camera->film()->clear(command_buffer);
auto constexpr num_spatial_reuse_pass = 2u;
camera->film()->clear(command_buffer);
command_buffer << temporal_pass(sample_id, s.point.time,
node<ReSTIRDirectLighting>()->num_initial_sample(),
node<ReSTIRDirectLighting>()->enable_visibility_reuse(),
node<ReSTIRDirectLighting>()->enable_temporal_reuse(),
node<ReSTIRDirectLighting>()->enable_decorrelation())
node<ReSTIRDirectLighting>()->enable_temporal_reuse())
.dispatch(resolution);
if (node<ReSTIRDirectLighting>()->enable_spatial_reuse()) {
command_buffer << spatial_pass(sample_id, s.point.time, node<ReSTIRDirectLighting>()->unbiased_spatial_reuse(),
node<ReSTIRDirectLighting>()->enable_visibility_reuse()).dispatch(resolution);
}
command_buffer << render(sample_id++, s.point.time, s.point.weight, node<ReSTIRDirectLighting>()->enable_spatial_reuse(),
node<ReSTIRDirectLighting>()->unbiased_spatial_reuse(), node<ReSTIRDirectLighting>()->enable_visibility_reuse())
node<ReSTIRDirectLighting>()->unbiased_spatial_reuse(), node<ReSTIRDirectLighting>()->enable_visibility_reuse(),
node<ReSTIRDirectLighting>()->enable_decorrelation())
.dispatch(resolution);
_temporal_reservoir_buffer->copy_from(command_buffer, *_spatial_reservoir_buffer);
if (auto &&p = pipeline().printer(); !p.empty()) {
Expand Down

0 comments on commit 02f7a8a

Please sign in to comment.