Skip to content

Commit

Permalink
Push the comments, might be helpful for understanding the program log…
Browse files Browse the repository at this point in the history
…ic (#111)

Push the comments, which might be helpful for understanding the program logic.
The comments come from an offline discussion. They aren't very detailed,
but are still good to have.
  • Loading branch information
yanzhoupan authored Aug 7, 2023
1 parent f3ec21d commit f7631e3
Showing 1 changed file with 54 additions and 34 deletions.
88 changes: 54 additions & 34 deletions taichi_3d_gaussian_splatting/GaussianPointCloudRasterisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ def filter_point_in_camera(
pointcloud: ti.types.ndarray(ti.f32, ndim=2), # (N, 3)
point_invalid_mask: ti.types.ndarray(ti.i8, ndim=1), # (N)
camera_intrinsics: ti.types.ndarray(ti.f32, ndim=2), # (3, 3)
point_object_id: ti.types.ndarray(ti.i32, ndim=1), # (N)
point_object_id: ti.types.ndarray(ti.i32, ndim=1), # (N), every element is in [0, K-1] corresponding to the camera id
q_camera_pointcloud: ti.types.ndarray(ti.f32, ndim=2), # (K, 4)
t_camera_pointcloud: ti.types.ndarray(ti.f32, ndim=2), # (K, 3)
point_in_camera_mask: ti.types.ndarray(ti.i8, ndim=1), # (N)
point_in_camera_mask: ti.types.ndarray(ti.i8, ndim=1), # (N), output
near_plane: ti.f32,
far_plane: ti.f32,
camera_width: ti.i32,
Expand Down Expand Up @@ -99,7 +99,7 @@ def get_bounding_box_by_point_and_radii(

@ti.kernel
def generate_num_overlap_tiles(
num_overlap_tiles: ti.types.ndarray(ti.i32, ndim=1), # (M)
num_overlap_tiles: ti.types.ndarray(ti.i32, ndim=1), # (M), output,a number for each point, count how many tiles the point can be projected in.
point_uv: ti.types.ndarray(ti.f32, ndim=2), # (M, 2)
point_radii: ti.types.ndarray(ti.f32, ndim=1), # (M)
camera_width: ti.i32, # required to be multiple of 16
Expand Down Expand Up @@ -168,9 +168,9 @@ def generate_point_sort_key_by_num_overlap_tiles(
def find_tile_start_and_end(
point_in_camera_sort_key: ti.types.ndarray(ti.i64, ndim=1), # (M)
# (tiles_per_row * tiles_per_col), for output
tile_points_start: ti.types.ndarray(ti.i32, ndim=1),
tile_points_start: ti.types.ndarray(ti.i32, ndim=1), # output
# (tiles_per_row * tiles_per_col), for output
tile_points_end: ti.types.ndarray(ti.i32, ndim=1),
tile_points_end: ti.types.ndarray(ti.i32, ndim=1), # output
):
for idx in range(point_in_camera_sort_key.shape[0] - 1):
sort_key = point_in_camera_sort_key[idx]
Expand Down Expand Up @@ -228,20 +228,20 @@ def load_point_cloud_row_into_gaussian_point_3d(
return gaussian_point_3d

@ti.kernel
def generate_point_attributes_in_camera_plane(
def generate_point_attributes_in_camera_plane( # from 3D Gaussians to 2D features, including color, alpha, and 2D Gaussian covariance
pointcloud: ti.types.ndarray(ti.f32, ndim=2), # (N, 3)
pointcloud_features: ti.types.ndarray(ti.f32, ndim=2), # (N, M)
pointcloud_features: ti.types.ndarray(ti.f32, ndim=2), # (N, 56) 56 features (cov_rotation, xxx, r, g, b)
camera_intrinsics: ti.types.ndarray(ti.f32, ndim=2), # (3, 3)
point_object_id: ti.types.ndarray(ti.i32, ndim=1), # (N)
point_object_id: ti.types.ndarray(ti.i32, ndim=1), # (N) [0, K-1] camera_id
q_camera_pointcloud: ti.types.ndarray(ti.f32, ndim=2), # (K, 4)
t_camera_pointcloud: ti.types.ndarray(ti.f32, ndim=2), # (K, 3)
point_id_list: ti.types.ndarray(ti.i32, ndim=1), # (M)
point_uv: ti.types.ndarray(ti.f32, ndim=2), # (M, 2)
point_in_camera: ti.types.ndarray(ti.f32, ndim=2), # (M, 3)
point_uv_conic: ti.types.ndarray(ti.f32, ndim=2), # (M, 3)
point_alpha_after_activation: ti.types.ndarray(ti.f32, ndim=1), # (M)
point_color: ti.types.ndarray(ti.f32, ndim=2), # (M, 3)
point_radii: ti.types.ndarray(ti.f32, ndim=1), # (M)
point_id_list: ti.types.ndarray(ti.i32, ndim=1), # (M) the point id in the view frustum.
point_uv: ti.types.ndarray(ti.f32, ndim=2), # (M, 2) # output
point_in_camera: ti.types.ndarray(ti.f32, ndim=2), # (M, 3) # output
point_uv_conic: ti.types.ndarray(ti.f32, ndim=2), # (M, 3)# output
point_alpha_after_activation: ti.types.ndarray(ti.f32, ndim=1), # (M), output,alpha after sigmoid
point_color: ti.types.ndarray(ti.f32, ndim=2), # (M, 3)# output
point_radii: ti.types.ndarray(ti.f32, ndim=1), # (M)# output, estimated eigenvalues, basically the size of gaussian
):
for idx in range(point_id_list.shape[0]):
camera_intrinsics_mat = ti.Matrix(
Expand Down Expand Up @@ -315,33 +315,34 @@ def gaussian_point_rasterisation(
point_uv_conic: ti.types.ndarray(ti.f32, ndim=2), # (M, 3)
point_alpha_after_activation: ti.types.ndarray(ti.f32, ndim=1), # (M)
point_color: ti.types.ndarray(ti.f32, ndim=2), # (M, 3)
rasterized_image: ti.types.ndarray(ti.f32, ndim=3), # (H, W, 3)
rasterized_depth: ti.types.ndarray(ti.f32, ndim=2), # (H, W)
pixel_accumulated_alpha: ti.types.ndarray(ti.f32, ndim=2), # (H, W)
rasterized_image: ti.types.ndarray(ti.f32, ndim=3), # (H, W, 3) # output
rasterized_depth: ti.types.ndarray(ti.f32, ndim=2), # (H, W) # output, Note: think about handling the occlusion
pixel_accumulated_alpha: ti.types.ndarray(ti.f32, ndim=2), # (H, W) # output
# (H, W)
pixel_offset_of_last_effective_point: ti.types.ndarray(ti.i32, ndim=2),
pixel_valid_point_count: ti.types.ndarray(ti.i32, ndim=2),
rgb_only: ti.template(),
pixel_offset_of_last_effective_point: ti.types.ndarray(ti.i32, ndim=2), # output
pixel_valid_point_count: ti.types.ndarray(ti.i32, ndim=2), # output
rgb_only: ti.template(), # input
):
ti.loop_config(block_dim=256)
for pixel_offset in ti.ndrange(camera_height * camera_width):
tile_id = pixel_offset // 256
thread_id = pixel_offset % 256
tile_u = ti.cast(tile_id % (camera_width // 16), ti.i32)
for pixel_offset in ti.ndrange(camera_height * camera_width): # 1920*1080
# initialize
tile_id = pixel_offset // 256 # put each 16x16 tile in the same CUDA thread group (block)
thread_id = pixel_offset % 256 # can wait for other threads in the same group, also have a shared memory.
tile_u = ti.cast(tile_id % (camera_width // 16), ti.i32) # tile position
tile_v = ti.cast(tile_id // (camera_width // 16), ti.i32)
pixel_offset_in_tile = pixel_offset - tile_id * 256
pixel_offset_in_tile = pixel_offset - tile_id * 256 # pixel position in tile (The relative position of the pixel in the tile)
pixel_u = tile_u * 16 + pixel_offset_in_tile % 16
pixel_v = tile_v * 16 + pixel_offset_in_tile // 16
start_offset = tile_points_start[tile_id]
end_offset = tile_points_end[tile_id]
T_i = 1.0
T_i = 1.0 # The initial value of accumulated alpha (initial value of accumulated multiplication)
accumulated_color = ti.math.vec3([0., 0., 0.])
accumulated_depth = 0.
depth_normalization_factor = 0.
offset_of_last_effective_point = start_offset

valid_point_count: ti.i32 = 0

# open the shared memory
tile_point_uv = ti.simt.block.SharedArray((2, 256), dtype=ti.f32)
tile_point_uv_conic = ti.simt.block.SharedArray(
(3, 256), dtype=ti.f32)
Expand Down Expand Up @@ -418,11 +419,12 @@ def gaussian_point_rasterisation(
next_T_i = T_i * (1 - alpha)
if next_T_i < 0.0001:
pixel_saturated = True
continue
continue # somehow faster than directly breaking
offset_of_last_effective_point = idx_point_offset_with_sort_key + 1
accumulated_color += color * alpha * T_i

if not rgb_only:
# Weighted depth for all valid points.
depth = tile_point_depth[point_group_offset]
accumulated_depth += depth * alpha * T_i
depth_normalization_factor += alpha * T_i
Expand Down Expand Up @@ -788,6 +790,7 @@ def forward(ctx,
pointcloud.shape[0], dtype=torch.int32, device=pointcloud.device)
q_camera_pointcloud, t_camera_pointcloud = inverse_se3_qt_torch(
q=q_pointcloud_camera, t=t_pointcloud_camera)
# Step 1: filter points
filter_point_in_camera(
pointcloud=pointcloud,
point_invalid_mask=point_invalid_mask,
Expand All @@ -802,13 +805,17 @@ def forward(ctx,
camera_width=camera_info.camera_width,
)
point_in_camera_mask = point_in_camera_mask.bool()

# Get id based on the camera_mask
point_id_in_camera_list = point_id[point_in_camera_mask].contiguous(
)
del point_id
del point_in_camera_mask

# Number of points in camera
num_points_in_camera = point_id_in_camera_list.shape[0]

# Allocate memory
point_uv = torch.empty(
size=(num_points_in_camera, 2), dtype=torch.float32, device=pointcloud.device)
point_alpha_after_activation = torch.empty(
Expand All @@ -822,7 +829,7 @@ def forward(ctx,
point_radii = torch.empty(
size=(num_points_in_camera,), dtype=torch.float32, device=pointcloud.device)


# Step 2: get 2d features
generate_point_attributes_in_camera_plane(
pointcloud=pointcloud,
pointcloud_features=pointcloud_features,
Expand All @@ -839,6 +846,7 @@ def forward(ctx,
point_radii=point_radii,
)

# Step 3: get how many tiles overlapped, in order to allocate memory
num_overlap_tiles = torch.empty_like(point_id_in_camera_list)
generate_num_overlap_tiles(
num_overlap_tiles=num_overlap_tiles,
Expand All @@ -847,28 +855,36 @@ def forward(ctx,
camera_width=camera_info.camera_width,
camera_height=camera_info.camera_height,
)
# Calculate the prefix sum (cumulative sum) of num_overlap_tiles
accumulated_num_overlap_tiles = torch.cumsum(
num_overlap_tiles, dim=0)
if len(accumulated_num_overlap_tiles) > 0:
total_num_overlap_tiles = accumulated_num_overlap_tiles[-1]
else:
total_num_overlap_tiles = 0
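# Shift the cumulative sum right by one (prepend 0, drop the last entry) so that each entry is the sum over all preceding points.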
accumulated_num_overlap_tiles = torch.cat(
(torch.zeros(size=(1,), dtype=torch.int32, device=pointcloud.device),
accumulated_num_overlap_tiles[:-1]))

# del num_overlap_tiles

# 64-bits key
point_in_camera_sort_key = torch.empty(
size=(total_num_overlap_tiles,), dtype=torch.int64, device=pointcloud.device)
# Corresponding to the original position, the record is the point offset in the frustum (engineering optimization)
point_offset_with_sort_key = torch.empty(
size=(total_num_overlap_tiles,), dtype=torch.int32, device=pointcloud.device)

# Step 4: calculate the sort key
if point_in_camera_sort_key.shape[0] > 0:
generate_point_sort_key_by_num_overlap_tiles(
point_uv=point_uv,
point_in_camera=point_in_camera,
point_radii=point_radii,
accumulated_num_overlap_tiles=accumulated_num_overlap_tiles,
point_offset_with_sort_key=point_offset_with_sort_key,
point_in_camera_sort_key=point_in_camera_sort_key,
accumulated_num_overlap_tiles=accumulated_num_overlap_tiles, # input
point_offset_with_sort_key=point_offset_with_sort_key, # output
point_in_camera_sort_key=point_in_camera_sort_key, # output
camera_width=camera_info.camera_width,
camera_height=camera_info.camera_height,
depth_to_sort_key_scale=self.config.depth_to_sort_key_scale,
Expand All @@ -878,20 +894,22 @@ def forward(ctx,
point_offset_with_sort_key = point_offset_with_sort_key[permutation].contiguous(
) # now the point_offset_with_sort_key is sorted by the sort_key
del permutation

tiles_per_row = camera_info.camera_width // 16
tiles_per_col = camera_info.camera_height // 16
tile_points_start = torch.zeros(size=(
tiles_per_row * tiles_per_col,), dtype=torch.int32, device=pointcloud.device)
tile_points_end = torch.zeros(size=(
tiles_per_row * tiles_per_col,), dtype=torch.int32, device=pointcloud.device)

# Find tile's start and end.
if point_in_camera_sort_key.shape[0] > 0:
find_tile_start_and_end(
point_in_camera_sort_key=point_in_camera_sort_key,
tile_points_start=tile_points_start,
tile_points_end=tile_points_end,
)

# Allocate space for the image.
rasterized_image = torch.empty(
camera_info.camera_height, camera_info.camera_width, 3, dtype=torch.float32, device=pointcloud.device)
rasterized_depth = torch.empty(
Expand All @@ -903,6 +921,8 @@ def forward(ctx,
pixel_valid_point_count = torch.empty(
camera_info.camera_height, camera_info.camera_width, dtype=torch.int32, device=pointcloud.device)
# print(f"num_points: {pointcloud.shape[0]}, num_points_in_camera: {num_points_in_camera}, num_points_rendered: {point_in_camera_sort_key.shape[0]}")

# Step 5: render
if point_in_camera_sort_key.shape[0] > 0:
gaussian_point_rasterisation(
camera_height=camera_info.camera_height,
Expand Down

0 comments on commit f7631e3

Please sign in to comment.