diff --git a/taichi_3d_gaussian_splatting/GaussianPointCloudRasterisation.py b/taichi_3d_gaussian_splatting/GaussianPointCloudRasterisation.py index f9a8f891..3e911a4e 100644 --- a/taichi_3d_gaussian_splatting/GaussianPointCloudRasterisation.py +++ b/taichi_3d_gaussian_splatting/GaussianPointCloudRasterisation.py @@ -339,6 +339,7 @@ def gaussian_point_rasterisation( # output pixel_offset_of_last_effective_point: ti.types.ndarray(ti.i32, ndim=2), pixel_valid_point_count: ti.types.ndarray(ti.i32, ndim=2), # output + background_color: ti.types.ndarray(ti.f32, ndim=1), # (3) rgb_only: ti.template(), # input ): ti.loop_config(block_dim=(TILE_WIDTH * TILE_HEIGHT)) @@ -469,6 +470,10 @@ def gaussian_point_rasterisation( valid_point_count += 1 T_i = next_T_i # end of point group loop + + background_color_vector = ti.math.vec3([background_color[0], background_color[1], background_color[2]]) + background_color_vector = ti.math.clamp(background_color_vector, 0, 1) + accumulated_color += background_color_vector * T_i # end of point group id loop @@ -517,7 +522,7 @@ def gaussian_point_rasterisation_backward( point_uv_conic_and_rescale: ti.types.ndarray(ti.f32, ndim=2), # (M, 3) point_alpha_after_activation: ti.types.ndarray(ti.f32, ndim=1), # (M) point_color: ti.types.ndarray(ti.f32, ndim=2), # (M, 3) - + background_color: ti.types.ndarray(ti.f32, ndim=1), # (3) need_extra_info: ti.template(), magnitude_grad_viewspace: ti.types.ndarray(ti.f32, ndim=1), # (N) # (H, W, 2) @@ -558,6 +563,9 @@ def gaussian_point_rasterisation_backward( last_effective_point = pixel_offset_of_last_effective_point[pixel_v, pixel_u] accumulated_alpha: ti.f32 = pixel_accumulated_alpha[pixel_v, pixel_u] T_i = 1.0 - accumulated_alpha # T_i = \prod_{j=1}^{i-1} (1 - a_j) + T_final = T_i + background_color_vector = ti.math.vec3([ + background_color[0], background_color[1], background_color[2]]) # \frac{dC}{da_i} = c_i T(i) - \frac{1}{1 - a_i} \sum_{j=i+1}^{n} c_j a_j T(j) # let w_i = \sum_{j=i+1}^{n} c_j a_j T(j) # we have w_n = 0, w_{i-1} = w_i + c_i a_i T(i) @@ -652,6 +660,8 @@ def gaussian_point_rasterisation_backward( # \frac{dC}{da_i} = c_i T(i) - \frac{1}{1 - a_i} w_i alpha_grad_from_rgb = (color * T_i - w_i / (1. - alpha)) \ * pixel_rgb_grad + alpha_grad_from_rgb -= pixel_rgb_grad * background_color_vector * \ + T_final / (1 - alpha) # w_{i-1} = w_i + c_i a_i T(i) w_i += color * alpha * T_i alpha_grad: ti.f32 = alpha_grad_from_rgb.sum() @@ -800,7 +810,9 @@ class GaussianPointCloudRasterisationInput: # Kx4, x to the right, y down, z forward, K is the number of objects q_pointcloud_camera: torch.Tensor # Kx3, x to the right, y down, z forward, K is the number of objects + t_pointcloud_camera: torch.Tensor + background_color: Optional[torch.Tensor] = None # 3 color_max_sh_band: int = 2 @dataclass @@ -837,6 +849,7 @@ def forward(ctx, t_pointcloud_camera, camera_info, color_max_sh_band, + background_color, ): point_in_camera_mask = torch.zeros( size=(pointcloud.shape[0],), dtype=torch.int8, device=pointcloud.device) @@ -994,7 +1007,8 @@ def forward(ctx, rasterized_depth=rasterized_depth, pixel_accumulated_alpha=pixel_accumulated_alpha, pixel_offset_of_last_effective_point=pixel_offset_of_last_effective_point, - pixel_valid_point_count=pixel_valid_point_count) + pixel_valid_point_count=pixel_valid_point_count, + background_color=background_color) ctx.save_for_backward( pointcloud, pointcloud_features, @@ -1016,6 +1030,7 @@ def forward(ctx, point_uv_conic_and_rescale, point_alpha_after_activation, point_color, + background_color ) ctx.camera_info = camera_info ctx.color_max_sh_band = color_max_sh_band @@ -1044,7 +1059,8 @@ def backward(ctx, grad_rasterized_image, grad_rasterized_depth, grad_pixel_valid point_in_camera, \ point_uv_conic, \ point_alpha_after_activation, \ - point_color = ctx.saved_tensors + point_color, \ + background_color = ctx.saved_tensors camera_info = ctx.camera_info color_max_sh_band = ctx.color_max_sh_band grad_rasterized_image = grad_rasterized_image.contiguous() @@ -1093,6 +1109,7 @@ def backward(ctx, grad_rasterized_image, grad_rasterized_depth, grad_pixel_valid point_uv_conic_and_rescale=point_uv_conic.contiguous(), point_alpha_after_activation=point_alpha_after_activation.contiguous(), point_color=point_color.contiguous(), + background_color=background_color.contiguous(), need_extra_info=True, magnitude_grad_viewspace=magnitude_grad_viewspace.contiguous(), magnitude_grad_viewspace_on_image=magnitude_grad_viewspace_on_image.contiguous(), @@ -1160,7 +1177,7 @@ def backward(ctx, grad_rasterized_image, grad_rasterized_depth, grad_pixel_valid None, \ grad_q_pointcloud_camera, \ grad_t_pointcloud_camera, \ - None, None + None, None, None self._module_function = _module_function @@ -1189,6 +1206,10 @@ def forward(self, input_data: GaussianPointCloudRasterisationInput): q_pointcloud_camera = input_data.q_pointcloud_camera t_pointcloud_camera = input_data.t_pointcloud_camera color_max_sh_band = input_data.color_max_sh_band + background_color = input_data.background_color + if background_color is None: + background_color = torch.zeros((3, ), dtype=torch.float32, + device=pointcloud.device) camera_info = input_data.camera_info assert camera_info.camera_width % TILE_WIDTH == 0 assert camera_info.camera_height % TILE_HEIGHT == 0 @@ -1201,4 +1222,5 @@ def forward(self, input_data: GaussianPointCloudRasterisationInput): t_pointcloud_camera, camera_info, color_max_sh_band, + background_color, )