Skip to content

Commit

Permalink
handling case in rasterization if there are no intersections in view (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
vye16 authored Jan 25, 2024
1 parent 45b253f commit 8715991
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 43 deletions.
2 changes: 2 additions & 0 deletions gsplat/project_gaussians.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ def forward(
clip_thresh: float = 0.01,
):
num_points = means3d.shape[-2]
if num_points < 1 or means3d.shape[-1] != 3:
raise ValueError(f"Invalid shape for means3d: {means3d.shape}")

(
cov3d,
Expand Down
111 changes: 68 additions & 43 deletions gsplat/rasterize.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,35 +112,52 @@ def forward(

num_intersects, cum_tiles_hit = compute_cumulative_intersects(num_tiles_hit)

(
isect_ids_unsorted,
gaussian_ids_unsorted,
isect_ids_sorted,
gaussian_ids_sorted,
tile_bins,
) = bin_and_sort_gaussians(
num_points, num_intersects, xys, depths, radii, cum_tiles_hit, tile_bounds
)

if colors.shape[-1] == 3:
rasterize_fn = _C.rasterize_forward
if num_intersects < 1:
out_img = (
torch.ones(img_height, img_width, colors.shape[-1], device=xys.device)
* background
)
gaussian_ids_sorted = torch.zeros(0, 1, device=xys.device)
tile_bins = torch.zeros(0, 2, device=xys.device)
final_Ts = torch.zeros(img_height, img_width, device=xys.device)
final_idx = torch.zeros(img_height, img_width, device=xys.device)
else:
rasterize_fn = _C.nd_rasterize_forward
out_img, final_Ts, final_idx = rasterize_fn(
tile_bounds,
block,
img_size,
gaussian_ids_sorted,
tile_bins,
xys,
conics,
colors,
opacity,
background,
)
(
isect_ids_unsorted,
gaussian_ids_unsorted,
isect_ids_sorted,
gaussian_ids_sorted,
tile_bins,
) = bin_and_sort_gaussians(
num_points,
num_intersects,
xys,
depths,
radii,
cum_tiles_hit,
tile_bounds,
)
if colors.shape[-1] == 3:
rasterize_fn = _C.rasterize_forward
else:
rasterize_fn = _C.nd_rasterize_forward

out_img, final_Ts, final_idx = rasterize_fn(
tile_bounds,
block,
img_size,
gaussian_ids_sorted,
tile_bins,
xys,
conics,
colors,
opacity,
background,
)

ctx.img_width = img_width
ctx.img_height = img_height
ctx.num_intersects = num_intersects
ctx.save_for_backward(
gaussian_ids_sorted,
tile_bins,
Expand All @@ -163,6 +180,7 @@ def forward(
def backward(ctx, v_out_img, v_out_alpha=None):
img_height = ctx.img_height
img_width = ctx.img_width
num_intersects = ctx.num_intersects

if v_out_alpha is None:
v_out_alpha = torch.zeros_like(v_out_img[..., 0])
Expand All @@ -179,25 +197,32 @@ def backward(ctx, v_out_img, v_out_alpha=None):
final_idx,
) = ctx.saved_tensors

if colors.shape[-1] == 3:
rasterize_fn = _C.rasterize_backward
if num_intersects < 1:
v_xy = torch.zeros_like(xys)
v_conic = torch.zeros_like(conics)
v_colors = torch.zeros_like(colors)
v_opacity = torch.zeros_like(opacity)

else:
rasterize_fn = _C.nd_rasterize_backward
v_xy, v_conic, v_colors, v_opacity = rasterize_fn(
img_height,
img_width,
gaussian_ids_sorted,
tile_bins,
xys,
conics,
colors,
opacity,
background,
final_Ts,
final_idx,
v_out_img,
v_out_alpha,
)
if colors.shape[-1] == 3:
rasterize_fn = _C.rasterize_backward
else:
rasterize_fn = _C.nd_rasterize_backward
v_xy, v_conic, v_colors, v_opacity = rasterize_fn(
img_height,
img_width,
gaussian_ids_sorted,
tile_bins,
xys,
conics,
colors,
opacity,
background,
final_Ts,
final_idx,
v_out_img,
v_out_alpha,
)

return (
v_xy, # xys
Expand Down

0 comments on commit 8715991

Please sign in to comment.