From 5f8617e64d8c0167ca7e55d1838cac625f20e135 Mon Sep 17 00:00:00 2001 From: wanmeihuali Date: Wed, 12 Jul 2023 22:31:36 -0700 Subject: [PATCH 1/3] support bg color --- .../GaussianPointCloudRasterisation.py | 28 +++++++++++++++++-- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/taichi_3d_gaussian_splatting/GaussianPointCloudRasterisation.py b/taichi_3d_gaussian_splatting/GaussianPointCloudRasterisation.py index 9acfb052..f41861b3 100644 --- a/taichi_3d_gaussian_splatting/GaussianPointCloudRasterisation.py +++ b/taichi_3d_gaussian_splatting/GaussianPointCloudRasterisation.py @@ -455,6 +455,7 @@ def gaussian_point_rasterisation( # (H, W) pixel_offset_of_last_effective_point: ti.types.ndarray(ti.i32, ndim=2), pixel_valid_point_count: ti.types.ndarray(ti.i32, ndim=2), + background_color: ti.types.ndarray(ti.f32, ndim=1), # (3) ): ti.loop_config(block_dim=256) for pixel_offset in ti.ndrange(camera_height * camera_width): @@ -569,6 +570,10 @@ def gaussian_point_rasterisation( ti.simt.block.sync() if tile_saturated_pixel_count[0] == 256: break + + background_color_vector = ti.math.vec3([background_color[0], background_color[1], background_color[2]]) + background_color_vector = ti.math.clamp(background_color_vector, 0, 1) + accumulated_color += background_color_vector * T_i # end of point group id loop @@ -619,6 +624,7 @@ def gaussian_point_rasterisation_backward( point_uv_covariance: ti.types.ndarray(ti.f32, ndim=3), # (M, 2, 2) point_alpha_after_activation: ti.types.ndarray(ti.f32, ndim=1), # (M) point_color: ti.types.ndarray(ti.f32, ndim=2), # (M, 3) + background_color: ti.types.ndarray(ti.f32, ndim=1), # (3) ): camera_intrinsics_mat = ti.Matrix( [[camera_intrinsics[row, col] for col in ti.static(range(3))] for row in ti.static(range(3))]) @@ -664,6 +670,9 @@ def gaussian_point_rasterisation_backward( last_effective_point = pixel_offset_of_last_effective_point[pixel_v, pixel_u] accumulated_alpha: ti.f32 = pixel_accumulated_alpha[pixel_v, pixel_u] T_i = 1.0 - accumulated_alpha # T_i = \prod_{j=1}^{i-1} (1 - a_j) + T_final = T_i + background_color_vector = ti.math.vec3([ + background_color[0], background_color[1], background_color[2]]) # \frac{dC}{da_i} = c_i T(i) - \frac{1}{1 - a_i} \sum_{j=i+1}^{n} c_j a_j T(j) # let w_i = \sum_{j=i+1}^{n} c_j a_j T(j) # we have w_n = 0, w_{i-1} = w_i + c_i a_i T(i) @@ -771,6 +780,8 @@ def gaussian_point_rasterisation_backward( # \frac{dC}{da_i} = c_i T(i) - \frac{1}{1 - a_i} w_i alpha_grad_from_rgb = (color * T_i - w_i / (1. - alpha)) \ * pixel_rgb_grad + alpha_grad_from_rgb -= pixel_rgb_grad * background_color_vector * \ + T_final / (1 - alpha) # w_{i-1} = w_i + c_i a_i T(i) w_i += color * alpha * T_i alpha_grad: ti.f32 = alpha_grad_from_rgb[0] + \ @@ -940,6 +951,7 @@ class GaussianPointCloudRasterisationInput: camera_info: CameraInfo q_pointcloud_camera: torch.Tensor # Kx4, x to the right, y down, z forward, K is the number of objects t_pointcloud_camera: torch.Tensor # Kx3, x to the right, y down, z forward, K is the number of objects + background_color: Optional[torch.Tensor] = None # 3 color_max_sh_band: int = 2 @dataclass @@ -976,6 +988,7 @@ def forward(ctx, t_pointcloud_camera, camera_info, color_max_sh_band, + background_color, ): point_in_camera_mask = torch.zeros( size=(pointcloud.shape[0],), dtype=torch.int8, device=pointcloud.device) @@ -1120,7 +1133,8 @@ def forward(ctx, rasterized_depth=rasterized_depth, pixel_accumulated_alpha=pixel_accumulated_alpha, pixel_offset_of_last_effective_point=pixel_offset_of_last_effective_point, - pixel_valid_point_count=pixel_valid_point_count) + pixel_valid_point_count=pixel_valid_point_count, + background_color=background_color) ctx.save_for_backward( pointcloud, pointcloud_features, @@ -1142,6 +1156,7 @@ def forward(ctx, point_uv_covariance, point_alpha_after_activation, point_color, + background_color ) ctx.camera_info = camera_info ctx.color_max_sh_band = color_max_sh_band @@ -1170,7 +1185,8 @@ def backward(ctx, grad_rasterized_image, grad_rasterized_depth, grad_pixel_valid point_in_camera, \ point_uv_covariance, \ point_alpha_after_activation, \ - point_color = ctx.saved_tensors + point_color, \ + background_color = ctx.saved_tensors camera_info = ctx.camera_info color_max_sh_band = ctx.color_max_sh_band grad_rasterized_image = grad_rasterized_image.contiguous() @@ -1223,6 +1239,7 @@ def backward(ctx, grad_rasterized_image, grad_rasterized_depth, grad_pixel_valid point_uv_covariance=point_uv_covariance, point_alpha_after_activation=point_alpha_after_activation, point_color=point_color, + background_color=background_color, ) del tile_points_start, tile_points_end, pixel_accumulated_alpha, pixel_offset_of_last_effective_point grad_pointcloud_features = self._clear_grad_by_color_max_sh_band( @@ -1273,7 +1290,7 @@ def backward(ctx, grad_rasterized_image, grad_rasterized_depth, grad_pixel_valid None, \ grad_q_pointcloud_camera, \ grad_t_pointcloud_camera, \ - None, None + None, None, None self._module_function = _module_function @@ -1302,6 +1319,10 @@ def forward(self, input_data: GaussianPointCloudRasterisationInput): q_pointcloud_camera = input_data.q_pointcloud_camera t_pointcloud_camera = input_data.t_pointcloud_camera color_max_sh_band = input_data.color_max_sh_band + background_color = input_data.background_color + if background_color is None: + background_color = torch.ones((3, ), dtype=torch.float32, + device=pointcloud.device) camera_info = input_data.camera_info assert camera_info.camera_width % 16 == 0 assert camera_info.camera_height % 16 == 0 @@ -1314,4 +1335,5 @@ def forward(self, input_data: GaussianPointCloudRasterisationInput): t_pointcloud_camera, camera_info, color_max_sh_band, + background_color, ) From ed08d689dfbb562cd15a3b7611d4e26eb9ce9b05 Mon Sep 17 00:00:00 2001 From: wanmeihuali Date: Wed, 25 Oct 2023 11:38:32 -0700 Subject: [PATCH 2/3] fix typo when merging main --- taichi_3d_gaussian_splatting/GaussianPointCloudRasterisation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/taichi_3d_gaussian_splatting/GaussianPointCloudRasterisation.py b/taichi_3d_gaussian_splatting/GaussianPointCloudRasterisation.py index c2382102..3fbb3c33 100644 --- a/taichi_3d_gaussian_splatting/GaussianPointCloudRasterisation.py +++ b/taichi_3d_gaussian_splatting/GaussianPointCloudRasterisation.py @@ -275,7 +275,7 @@ def generate_point_attributes_in_camera_plane( # from 3d gaussian to 2d feature [t_camera_pointcloud[point_object_id[point_id], idx] for idx in ti.static(range(3))]) T_camera_pointcloud_mat = transform_matrix_from_quaternion_and_translation( q=point_q_camera_pointcloud, - t=point_t_camera_pointcloud>>>>>>> main, + t=point_t_camera_pointcloud, ) T_pointcloud_camera = taichi_inverse_SE3(T_camera_pointcloud_mat) ray_origin = ti.math.vec3( From 723ba480da6bd6887b3faa80918ccc61910e03ec Mon Sep 17 00:00:00 2001 From: wanmeihuali Date: Wed, 25 Oct 2023 11:39:56 -0700 Subject: [PATCH 3/3] use black as default backgroundd --- taichi_3d_gaussian_splatting/GaussianPointCloudRasterisation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/taichi_3d_gaussian_splatting/GaussianPointCloudRasterisation.py b/taichi_3d_gaussian_splatting/GaussianPointCloudRasterisation.py index 3fbb3c33..3e911a4e 100644 --- a/taichi_3d_gaussian_splatting/GaussianPointCloudRasterisation.py +++ b/taichi_3d_gaussian_splatting/GaussianPointCloudRasterisation.py @@ -1208,7 +1208,7 @@ def forward(self, input_data: GaussianPointCloudRasterisationInput): color_max_sh_band = input_data.color_max_sh_band background_color = input_data.background_color if background_color is None: - background_color = torch.ones((3, ), dtype=torch.float32, + background_color = torch.zeros((3, ), dtype=torch.float32, device=pointcloud.device) camera_info = input_data.camera_info assert camera_info.camera_width % TILE_WIDTH == 0