Skip to content

Commit

Permalink
debug forward sppm finished
Browse files Browse the repository at this point in the history
  • Loading branch information
jkxing committed Apr 23, 2024
1 parent 6ee34b4 commit d33f999
Showing 1 changed file with 156 additions and 73 deletions.
229 changes: 156 additions & 73 deletions src/integrators/megappm_diff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -320,14 +320,22 @@ class MegakernelPhotonMappingDiffInstance : public DifferentiableIntegrator::Ins
_tau->write(pixel_id * dimension + i, (old_tau + phi) * value);
}
}


auto cur_w(Expr<uint> pixel_id) const noexcept {
if (!_shared_radius) {
//auto resolution = _film->node()->resolution();
return _cur_w->read(pixel_id);
} else {
return _cur_w->read(0u);
}
}
//weight=(weight+clamp(cur_w))*value, see pixel_info_update for useage
void update_weight(Expr<uint> pixel_id, Expr<float> value) noexcept {
//auto resolution = _film->node()->resolution();
//auto offset = pixel_id.y * resolution.x + pixel_id.x;
auto old_weight = _weight->read(pixel_id);
auto cur_w = cur_w(pixel_id);
_weight->write(pixel_id , (old_weight + cur_w) * value);
auto cur_w_ = cur_w(pixel_id);
_weight->write(pixel_id, (old_weight + cur_w_) * value);
}

auto n_photon(Expr<uint> pixel_id) const noexcept {
Expand All @@ -347,14 +355,9 @@ class MegakernelPhotonMappingDiffInstance : public DifferentiableIntegrator::Ins
return _cur_n->read(0u);
}
}

auto cur_w(Expr<uint> pixel_id) const noexcept {
if (!_shared_radius) {
//auto resolution = _film->node()->resolution();
return _cur_w->read(pixel_id);
} else {
return _cur_w->read(0u);
}

auto weight(Expr<uint> pixel_id) const noexcept {
return _weight->read(pixel_id);
}

auto phi(Expr<uint> pixel_id) const noexcept {
Expand Down Expand Up @@ -584,72 +587,153 @@ class MegakernelPhotonMappingDiffInstance : public DifferentiableIntegrator::Ins
auto photon_per_iter = node<MegakernelPhotonMappingDiff>()->photon_per_iter();
auto pixel_count = resolution.x * resolution.y;
auto spectrum = camera->pipeline().spectrum();



uint add_x = (photon_per_iter + resolution.y - 1) / resolution.y;
sampler()->reset(command_buffer, make_uint2(resolution.x + add_x, resolution.y), pixel_count + add_x * resolution.y, spp);

command_buffer << pipeline().printer().reset();
command_buffer << compute::synchronize();

LUISA_INFO(
"Rendering to '{}' of resolution {}x{} at {}spp.",
image_file.string(),
resolution.x, resolution.y, spp);

using namespace luisa::compute;
auto &&device = camera->pipeline().device();
auto radius = node<MegakernelPhotonMappingDiff>()->initial_radius();
if (radius < 0) {
auto _grid_size = spectrum->pipeline().geometry()->world_max() - spectrum->pipeline().geometry()->world_min();
radius = min(min(_grid_size.x / -radius, _grid_size.y / -radius), _grid_size.z / -radius);
}
auto clamp = camera->film()->node()->clamp() * photon_per_iter * pi * radius * radius;

auto viewpoints_per_iter = resolution.x * resolution.y;

logger = make_unique<PhotonMappingLogger>(pixel_count, node<MegakernelPhotonMappingDiff>()->max_depth(), spectrum);
indirect = make_unique<PixelIndirect>(viewpoints_per_iter, spectrum, camera->film(), clamp, node<MegakernelPhotonMappingDiff>()->shared_radius());
viewpoints = make_unique<ViewPointMap>(viewpoints_per_iter, spectrum);
//return;
//pathlogger = make_unique<PathLogger>(node<MegakernelPhotonMappingDiff>()->max_depth(), node<MegakernelPhotonMappingDiff>()->photon_per_iter(), spectrum);
//initialize PixelIndirect
Kernel1D indirect_initialize_kernel = [&]() noexcept {

auto index = dispatch_x();
auto radius = node<MegakernelPhotonMappingDiff>()->initial_radius();
if (radius < 0)
viewpoints->write_grid_len(viewpoints->split(-radius));
else
viewpoints->write_grid_len(node<MegakernelPhotonMappingDiff>()->initial_radius());
//camera->pipeline().printer().info("grid:{}", viewpoints->grid_len());
indirect->write_radius(index, viewpoints->grid_len());
//camera->pipeline().printer().info("rad:{}", indirect->radius(index));

indirect->write_cur_n(index, 0u);
indirect->write_cur_w(index, 0.f);
indirect->write_n_photon(index, 0u);
indirect->reset_phi(index);
indirect->reset_tau(index);
};

Kernel1D viewpoint_reset_kernel = [&]() noexcept {
auto index = static_cast<UInt>(dispatch_x());
viewpoints->reset(index);
};

Kernel2D view_path_gradient_compute_kernel = [&](BufferFloat grad_in) noexcept {
//EPSM compute color and position gradients w.r.t to view path.
Kernel2D viewpath_construct_kernel = [&](UInt frame_index, Float time, Float shutter_weight) noexcept {
//Construct view path
auto pixel_id = dispatch_id().xy();
return;
auto L = emit_viewpoint_bp(camera, frame_index, pixel_id, time, shutter_weight);
camera->film()->accumulate(pixel_id, L, 0.5f);
};

Kernel2D emit_photons_bp_kernel = [&](UInt frame_index, Float time, BufferFloat &grad_in) noexcept {
auto fake_pixel_id = dispatch_id().xy();
auto sampler_id = UInt2(fake_pixel_id.x + resolution.x, fake_pixel_id.y);
auto photon_id = fake_pixel_id.x * resolution.y + fake_pixel_id.y;
$if(photon_id < photon_per_iter) {
photon_tracing_bp(camera, frame_index, sampler_id, time, photon_id, grad_in);
Kernel1D build_grid_kernel = [&]() noexcept {
auto index = static_cast<UInt>(dispatch_x());
auto radius = node<MegakernelPhotonMappingDiff>()->initial_radius();
$if(viewpoints->nxt(index) == 0u) {
viewpoints->link(index);
};
};


//auto view_path_gradient_compute = pipeline().device().compile(view_path_gradient_compute_kernel);
auto emit_photons_bp = pipeline().device().compile(emit_photons_bp_kernel);

// auto shutter_samples = camera->node()->shutter_samples();
// command_buffer << synchronize();
// LUISA_INFO("Backward Rendering started.");
// Clock clock;
// ProgressBar progress;
// progress.update(0.);
// auto dispatch_count = 0u;
// auto sample_id = 0u;
// bool initial_flag = false;
// uint runtime_spp = 0u;

// command_buffer << indirect_initialize().dispatch(viewpoints_per_iter) << synchronize();
// pipeline().update(command_buffer, 0);
Kernel2D emit_photons_kernel = [&](UInt frame_index, Float time) noexcept {
auto pixel_id = dispatch_id().xy();
auto sampler_id = UInt2(pixel_id.x + resolution.x, pixel_id.y);
$if(pixel_id.x * resolution.y + pixel_id.y < photon_per_iter) {
photon_tracing_bp(camera, frame_index, sampler_id, time, pixel_id.x * resolution.y + pixel_id.y);
};
};

Kernel1D indirect_update_kernel = [&]() noexcept {
set_block_size(16u, 16u, 1u);
auto pixel_id = dispatch_x();
indirect->pixel_info_update(pixel_id);
};

Kernel1D shared_update_kernel = [&]() noexcept {
indirect->shared_update();
viewpoints->write_grid_len(indirect->radius(0u));
};

//accumulate the stored indirect light into final image
Kernel2D indirect_draw_kernel = [&](UInt tot_photon, UInt spp) noexcept {
set_block_size(16u, 16u, 1u);
auto pixel_id = dispatch_id().xy();
auto pixel_id_1d = pixel_id.x * resolution.y + pixel_id.y;
auto L = get_indirect(camera->pipeline().spectrum(), pixel_id_1d, tot_photon);
camera->film()->accumulate(pixel_id, L, 0.5f * spp);
};
Clock clock_compile;

auto indirect_initialize = pipeline().device().compile(indirect_initialize_kernel);
auto viewpoint_reset = pipeline().device().compile(viewpoint_reset_kernel);
auto viewpath_construct = pipeline().device().compile(viewpath_construct_kernel);
auto build_grid = pipeline().device().compile(build_grid_kernel);
auto emit_photon = pipeline().device().compile(emit_photons_kernel);

auto indirect_draw = pipeline().device().compile(indirect_draw_kernel);
auto indirect_update = pipeline().device().compile(indirect_update_kernel);
auto shared_update = pipeline().device().compile(shared_update_kernel);

auto integrator_shader_compilation_time = clock_compile.toc();
LUISA_INFO("Integrator shader compile in {} ms.", integrator_shader_compilation_time);
auto shutter_samples = camera->node()->shutter_samples();
command_buffer << synchronize();

LUISA_INFO("Rendering started.");
Clock clock;
ProgressBar progress;
progress.update(0.);
auto dispatch_count = 0u;
auto sample_id = 0u;
bool initial_flag = false;
uint runtime_spp = 0u;

command_buffer << emit_photons_bp(0, 0, grad_in).dispatch(make_uint2(add_x, resolution.y));

// for (auto s : shutter_samples) {
// runtime_spp+=spp;
// for (auto i = 0u; i < s.spp; i++) {
// //command_buffer << viewpoint_reset().dispatch(viewpoints->size());
// //command_buffer << viewpath_construct(sample_id++, s.point.time, s.point.weight).dispatch(resolution);
// //command_buffer << build_grid().dispatch(viewpoints->size());
// command_buffer << indirect_update().dispatch(viewpoints_per_iter);
// if (node<MegakernelPhotonMappingDiff>()->shared_radius()) {
// command_buffer << shared_update().dispatch(1u);
// }
// }
// }

// command_buffer << indirect_draw(node<MegakernelPhotonMappingDiff>()->photon_per_iter(), runtime_spp).dispatch(resolution);
// LUISA_INFO("Finishi indirect_draw");

command_buffer << indirect_initialize().dispatch(viewpoints_per_iter) << synchronize();
pipeline().update(command_buffer, 0);
for (auto s : shutter_samples) {
runtime_spp+=spp;
for (auto i = 0u; i < s.spp; i++) {
command_buffer << viewpoint_reset().dispatch(viewpoints->size());
command_buffer << viewpath_construct(sample_id++, s.point.time, s.point.weight).dispatch(resolution);
command_buffer << build_grid().dispatch(viewpoints->size());
command_buffer << emit_photon(sample_id++, s.point.time).dispatch(make_uint2(add_x, resolution.y));
command_buffer << indirect_update().dispatch(viewpoints_per_iter);
if (node<MegakernelPhotonMappingDiff>()->shared_radius()) {
command_buffer << shared_update().dispatch(1u);
}
}
}
command_buffer << indirect_draw(node<MegakernelPhotonMappingDiff>()->photon_per_iter(), runtime_spp).dispatch(resolution);
LUISA_INFO("Finishi indirect_draw");
command_buffer << synchronize();
command_buffer << pipeline().printer().retrieve();
progress.done();
auto render_time = clock.toc();
LUISA_INFO("Rendering finished in {} ms.", render_time);



LUISA_INFO("Backward Rendering finished in {} ms.", render_time);
}

void _render_one_camera(CommandBuffer &command_buffer, Camera::Instance *camera) noexcept override {
if (!pipeline().has_lighting()) [[unlikely]] {
LUISA_WARNING_WITH_LOCATION(
Expand All @@ -662,8 +746,6 @@ class MegakernelPhotonMappingDiffInstance : public DifferentiableIntegrator::Ins
auto photon_per_iter = node<MegakernelPhotonMappingDiff>()->photon_per_iter();
auto pixel_count = resolution.x * resolution.y;
auto spectrum = camera->pipeline().spectrum();



uint add_x = (photon_per_iter + resolution.y - 1) / resolution.y;
sampler()->reset(command_buffer, make_uint2(resolution.x + add_x, resolution.y), pixel_count + add_x * resolution.y, spp);
Expand Down Expand Up @@ -705,7 +787,7 @@ class MegakernelPhotonMappingDiffInstance : public DifferentiableIntegrator::Ins
//camera->pipeline().printer().info("rad:{}", indirect->radius(index));

indirect->write_cur_n(index, 0u);
indirect->write_cur_w(index, 0u);
indirect->write_cur_w(index, 0.f);
indirect->write_n_photon(index, 0u);
indirect->reset_phi(index);
indirect->reset_tau(index);
Expand All @@ -719,7 +801,7 @@ class MegakernelPhotonMappingDiffInstance : public DifferentiableIntegrator::Ins
Kernel2D viewpath_construct_kernel = [&](UInt frame_index, Float time, Float shutter_weight) noexcept {
//Construct view path
auto pixel_id = dispatch_id().xy();
auto L = emit_viewpoint_bp(camera, frame_index, pixel_id, time, shutter_weight);
auto L = emit_viewpoint(camera, frame_index, pixel_id, time, shutter_weight);
camera->film()->accumulate(pixel_id, L, 0.5f);
};

Expand Down Expand Up @@ -817,7 +899,7 @@ class MegakernelPhotonMappingDiffInstance : public DifferentiableIntegrator::Ins
[[nodiscard]] Float3 get_indirect(const Spectrum::Instance *spectrum, Expr<uint> pixel_id, Expr<uint> tot_photon) noexcept {
auto r = indirect->radius(pixel_id);
auto tau = indirect->tau(pixel_id);
Float3 L = tau / (_weight->read(pixel_id) * pi * r * r);
Float3 L = tau / (tot_photon * pi * r * r);
return L;
}

Expand Down Expand Up @@ -1021,7 +1103,7 @@ class MegakernelPhotonMappingDiffInstance : public DifferentiableIntegrator::Ins
auto wi_local = it->shading().world_to_local(wi);
Float3 Phi;
auto rel_dis = dis / indirect->radius(pixel_id);
auto weight = 1 - 6 * pow(rel_dis, 5.) + 15 * pow(rel_dis, 4.) - 10 * pow(rel_dis, 3.);
auto weight = 3.0f*(1.0f-rel_dis);//- 6 * pow(rel_dis, 5.) + 15 * pow(rel_dis, 4.) - 10 * pow(rel_dis, 3.);
if (!spectrum->node()->is_fixed()) {
auto viewpoint_swl = viewpoints->swl(pixel_id);
Phi = spectrum->wavelength_mul(swl, beta * (eval_viewpoint / abs_cos_theta(wi_local)), viewpoint_swl, viewpoint_beta);
Expand Down Expand Up @@ -1285,7 +1367,6 @@ class MegakernelPhotonMappingDiffInstance : public DifferentiableIntegrator::Ins
auto rr_depth = node<MegakernelPhotonMappingDiff>()->rr_depth();
$if(depth + 1u >= rr_depth) { u_rr = sampler()->generate_1d(); };
$if(depth > 0) {// add diffuse constraint?

auto grid = viewpoints->point_to_grid(it->p());
Float3 grad_beta = make_float3(0.f);
Float2 grad_bary = make_float2(0.f);
Expand Down Expand Up @@ -1328,13 +1409,14 @@ class MegakernelPhotonMappingDiffInstance : public DifferentiableIntegrator::Ins
requires_grad(bary, beta_diff);
Float3 photon_pos = point_0 * bary[0] + point_1 * bary[1] + point_2 * (1 - bary[0] - bary[1]);
auto rel_dis_diff = distance(position, photon_pos) / rad;
auto weight = 1- 6*pow(rel_dis_diff, 5.) + 15*pow(rel_dis_diff, 4.) - 10*pow(rel_dis_diff, 3.);
auto weight = (1-rel_dis_diff)*3;// 1- 6*pow(rel_dis_diff, 5.) + 15*pow(rel_dis_diff, 4.) - 10*pow(rel_dis_diff, 3.);
auto wi_local = it->shading().world_to_local(wi);
auto Phi = spectrum->srgb(swl, viewpoint_beta * eval_viewpoint / abs_cos_theta(wi_local));
auto Phi_beta = Phi * beta_diff * weight;
grad_pixel_0 = grad_in->read(pixel_id*_grad_dimension+0);
grad_pixel_1 = grad_in->read(pixel_id*_grad_dimension+1);
grad_pixel_2 = grad_in->read(pixel_id*_grad_dimension+2);
auto _grad_dimension = 5u;
auto grad_pixel_0 = grad_in->read(pixel_id*_grad_dimension+0);
auto grad_pixel_1 = grad_in->read(pixel_id * _grad_dimension + 1);
auto grad_pixel_2 = grad_in->read(pixel_id * _grad_dimension + 2);
auto dldPhi = (Phi_beta[0]*grad_pixel_0 + Phi_beta[1]*grad_pixel_1 + Phi_beta[2]*grad_pixel_2) / indirect->cur_w(pixel_id);
backward(dldPhi);
grad_bary += grad(bary).xy();
Expand Down Expand Up @@ -1391,9 +1473,11 @@ class MegakernelPhotonMappingDiffInstance : public DifferentiableIntegrator::Ins
beta *= ite(q < rr_threshold, 1.0f / q, 1.f);
};
};

$if(tot_neighbors>0)
//compute gradient w.r.t to photon position and power
EPSM_photon(path_size, points, normals, inst_ids, triangle_ids, bary_coords, etas, light_sample, grad_betas, grad_barys, mat_bary, mat_param);
{
EPSM_photon(path_size, points, normals, inst_ids, triangle_ids, bary_coords, etas, light_sample, grad_betas, grad_barys, mat_bary, mat_param);
};
}

void EPSM_photon(UInt path_size, ArrayFloat3<4> &points, ArrayFloat3<4> &normals, ArrayUInt<4> &inst_ids, ArrayUInt<4> &triangle_ids, ArrayFloat3<4> &bary_coords,
Expand Down Expand Up @@ -1514,7 +1598,6 @@ class MegakernelPhotonMappingDiffInstance : public DifferentiableIntegrator::Ins
auto bary_cur = bary_coords[0];

$for(id, path_size-1){

auto normal_cur_0 = v0->normal();
auto normal_cur_1 = v1->normal();
auto normal_cur_2 = v2->normal();
Expand Down

0 comments on commit d33f999

Please sign in to comment.