Skip to content

Commit

Permalink
refactor and move SH test
Browse files Browse the repository at this point in the history
  • Loading branch information
maturk committed Oct 3, 2023
1 parent 986c66c commit e663632
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 61 deletions.
61 changes: 0 additions & 61 deletions examples/test_sh.py

This file was deleted.

50 changes: 50 additions & 0 deletions tests/test_sh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import pytest
import torch


device = torch.device("cuda:0")


@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
def test_sh():
from diff_rast import _torch_impl
from diff_rast import sh

num_points = 1
degree = 4
gt_colors = torch.ones(num_points, 3, device=device) * 0.5
viewdirs = torch.randn(num_points, 3, device=device)
viewdirs /= torch.linalg.norm(viewdirs, dim=-1, keepdim=True)
sh_coeffs = torch.rand(
num_points, sh.num_sh_bases(degree), 3, device=device, requires_grad=True
)
optim = torch.optim.Adam([sh_coeffs], lr=1e-2)

num_iters = 1000
for _ in range(num_iters):
optim.zero_grad()

# compute PyTorch's color and grad
check_colors = _torch_impl.compute_sh_color(viewdirs, sh_coeffs)
check_loss = torch.square(check_colors - gt_colors).mean()
check_loss.backward()
check_grad = sh_coeffs.grad.detach()

optim.zero_grad()

# compute our colors and grads
colors = sh.SphericalHarmonics.apply(degree, viewdirs, sh_coeffs)
loss = torch.square(colors - gt_colors).mean()
loss.backward()
grad = sh_coeffs.grad.detach()
optim.step()

torch.testing.assert_close(check_grad, grad)
torch.testing.assert_close(check_colors, colors)

# check final optimized color
torch.testing.assert_close(check_colors, gt_colors)


if __name__ == "__main__":
test_sh()

0 comments on commit e663632

Please sign in to comment.