Skip to content

Commit

Permalink
Fix _torch_impl.py
Browse files Browse the repository at this point in the history
  • Loading branch information
inuex35 committed Oct 12, 2024
1 parent c85f423 commit 1474c16
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions gsplat/cuda/_torch_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,24 +248,26 @@ def _spherical_proj(

tx, ty, tz = torch.unbind(means, dim=-1) # [C, N]
tr = torch.sqrt(tx**2 + ty**2 + tz**2)
xz_norm = torch.sqrt(tx * tx + tz * tz + 1e-8)
denom_xz = tx * tx + tz * tz + 1e-8
denom_r2 = tr * tr + 1e-8

longitude = torch.atan2(tx, tz)
latitude = torch.atan2(ty, torch.sqrt(tx**2 + tz**2))
latitude = torch.atan2(ty, xz_norm)

normalized_latitude = latitude / (torch.pi / 2.0)
normalized_longitude = longitude / torch.pi

means2d = torch.stack([(normalized_longitude + 1) * width / 2, (normalized_latitude + 1) * height / 2], dim=-1)

O = torch.zeros((C, N), device=means.device, dtype=means.dtype)
J = torch.stack(
[
tz / (tx**2 + tz**2),
-(tx * ty) / (tr**2 * torch.sqrt(tx**2 + tz**2)),
width / (2 * torch.pi) * (tz / denom_xz),
height / torch.pi * -(tx * ty) / (denom_r2 * xz_norm),
O,
torch.sqrt(tx**2 + tz**2) / (tr**2),
-tx / (tx**2 + tz**2),
-(tz * ty) / (tr**2 * torch.sqrt(tx**2 + tz**2)),
height / torch.pi * xz_norm / denom_r2,
width / (2 * torch.pi) * -tx / denom_xz,
height / torch.pi * -(tz * ty) / (denom_r2 * xz_norm)
],
dim=-1,
).reshape(C, N, 2, 3)
Expand Down

0 comments on commit 1474c16

Please sign in to comment.