Skip to content

Commit

Permalink
Merge pull request #79 from flatironinstitute/error-on-missing-dep
Browse files Browse the repository at this point in the history
Error on missing dependencies
  • Loading branch information
WardBrian authored Oct 12, 2023
2 parents 973ad34 + bf66ad6 commit 2b94275
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 0 deletions.
4 changes: 4 additions & 0 deletions pytorch_finufft/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,8 +696,12 @@ def get_nufft_func(
dim: int, nufft_type: int, device_type: str
) -> Callable[..., torch.Tensor]:
if device_type == "cuda":
if not CUFINUFFT_AVAIL:
raise RuntimeError("CUDA device requested but cufinufft failed to import")
return getattr(cufinufft, f"nufft{dim}d{nufft_type}") # type: ignore

if not FINUFFT_AVAIL:
raise RuntimeError("CPU device requested but finufft failed to import")
# CPU needs extra work to go to/from torch and numpy
finufft_func = getattr(finufft, f"nufft{dim}d{nufft_type}")

Expand Down
19 changes: 19 additions & 0 deletions tests/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,22 @@ def test_t1_negative_output_dims() -> None:
ValueError, match="Got output_shape that was not positive integer"
):
pytorch_finufft.functional.finufft_type1.apply(points, values, (10, -2))


# dependencies
def test_finufft_not_installed():
if not pytorch_finufft.functional.CUFINUFFT_AVAIL:
if not torch.cuda.is_available():
pytest.skip("CUDA unavailable")
points = torch.rand(10, dtype=torch.float64).to("cuda")
values = torch.randn(10, dtype=torch.complex128).to("cuda")

with pytest.raises(RuntimeError, match="cufinufft failed to import"):
pytorch_finufft.functional.finufft_type1.apply(points, values, 10)

elif not pytorch_finufft.functional.FINUFFT_AVAIL:
points = torch.rand(10, dtype=torch.float64).to("cpu")
values = torch.randn(10, dtype=torch.complex128).to("cpu")

with pytest.raises(RuntimeError, match="finufft failed to import"):
pytorch_finufft.functional.finufft_type1.apply(points, values, 10)

0 comments on commit 2b94275

Please sign in to comment.