From 87159918cfa7459bbebeb114dd4f1bb9c9310d3e Mon Sep 17 00:00:00 2001 From: Vickie Ye Date: Wed, 24 Jan 2024 18:56:42 -0800 Subject: [PATCH] handling case in rasterization if there are no intersections in view (#109) --- gsplat/project_gaussians.py | 2 + gsplat/rasterize.py | 111 ++++++++++++++++++++++-------------- 2 files changed, 70 insertions(+), 43 deletions(-) diff --git a/gsplat/project_gaussians.py b/gsplat/project_gaussians.py index f98a229d5..4fd310fa6 100644 --- a/gsplat/project_gaussians.py +++ b/gsplat/project_gaussians.py @@ -96,6 +96,8 @@ def forward( clip_thresh: float = 0.01, ): num_points = means3d.shape[-2] + if num_points < 1 or means3d.shape[-1] != 3: + raise ValueError(f"Invalid shape for means3d: {means3d.shape}") ( cov3d, diff --git a/gsplat/rasterize.py b/gsplat/rasterize.py index c946b61ad..aa2103390 100644 --- a/gsplat/rasterize.py +++ b/gsplat/rasterize.py @@ -112,35 +112,52 @@ def forward( num_intersects, cum_tiles_hit = compute_cumulative_intersects(num_tiles_hit) - ( - isect_ids_unsorted, - gaussian_ids_unsorted, - isect_ids_sorted, - gaussian_ids_sorted, - tile_bins, - ) = bin_and_sort_gaussians( - num_points, num_intersects, xys, depths, radii, cum_tiles_hit, tile_bounds - ) - - if colors.shape[-1] == 3: - rasterize_fn = _C.rasterize_forward + if num_intersects < 1: + out_img = ( + torch.ones(img_height, img_width, colors.shape[-1], device=xys.device) + * background + ) + gaussian_ids_sorted = torch.zeros(0, 1, device=xys.device) + tile_bins = torch.zeros(0, 2, device=xys.device) + final_Ts = torch.zeros(img_height, img_width, device=xys.device) + final_idx = torch.zeros(img_height, img_width, device=xys.device) else: - rasterize_fn = _C.nd_rasterize_forward - out_img, final_Ts, final_idx = rasterize_fn( - tile_bounds, - block, - img_size, - gaussian_ids_sorted, - tile_bins, - xys, - conics, - colors, - opacity, - background, - ) + ( + isect_ids_unsorted, + gaussian_ids_unsorted, + isect_ids_sorted, + gaussian_ids_sorted, + tile_bins, + ) = bin_and_sort_gaussians( + num_points, + num_intersects, + xys, + depths, + radii, + cum_tiles_hit, + tile_bounds, + ) + if colors.shape[-1] == 3: + rasterize_fn = _C.rasterize_forward + else: + rasterize_fn = _C.nd_rasterize_forward + + out_img, final_Ts, final_idx = rasterize_fn( + tile_bounds, + block, + img_size, + gaussian_ids_sorted, + tile_bins, + xys, + conics, + colors, + opacity, + background, + ) ctx.img_width = img_width ctx.img_height = img_height + ctx.num_intersects = num_intersects ctx.save_for_backward( gaussian_ids_sorted, tile_bins, @@ -163,6 +180,7 @@ def forward( def backward(ctx, v_out_img, v_out_alpha=None): img_height = ctx.img_height img_width = ctx.img_width + num_intersects = ctx.num_intersects if v_out_alpha is None: v_out_alpha = torch.zeros_like(v_out_img[..., 0]) @@ -179,25 +197,32 @@ def backward(ctx, v_out_img, v_out_alpha=None): final_idx, ) = ctx.saved_tensors - if colors.shape[-1] == 3: - rasterize_fn = _C.rasterize_backward + if num_intersects < 1: + v_xy = torch.zeros_like(xys) + v_conic = torch.zeros_like(conics) + v_colors = torch.zeros_like(colors) + v_opacity = torch.zeros_like(opacity) + else: - rasterize_fn = _C.nd_rasterize_backward - v_xy, v_conic, v_colors, v_opacity = rasterize_fn( - img_height, - img_width, - gaussian_ids_sorted, - tile_bins, - xys, - conics, - colors, - opacity, - background, - final_Ts, - final_idx, - v_out_img, - v_out_alpha, - ) + if colors.shape[-1] == 3: + rasterize_fn = _C.rasterize_backward + else: + rasterize_fn = _C.nd_rasterize_backward + v_xy, v_conic, v_colors, v_opacity = rasterize_fn( + img_height, + img_width, + gaussian_ids_sorted, + tile_bins, + xys, + conics, + colors, + opacity, + background, + final_Ts, + final_idx, + v_out_img, + v_out_alpha, + ) return ( v_xy, # xys