diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 5d41053..3ef15c8 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -17,7 +17,7 @@ repos:
 
   # Using this mirror lets us use mypyc-compiled black, which is about 2x faster
   - repo: https://github.com/psf/black-pre-commit-mirror
-    rev: 23.9.1
+    rev: 24.1.1
     hooks:
       - id: black
 
diff --git a/examples/convolution_2d.py b/examples/convolution_2d.py
index d746c5b..1d34437 100644
--- a/examples/convolution_2d.py
+++ b/examples/convolution_2d.py
@@ -3,7 +3,6 @@
 =================
 """
 
-
 #######################################################################################
 # Import packages
 # ---------------
diff --git a/pytorch_finufft/functional.py b/pytorch_finufft/functional.py
index e4c1a6c..cdd71bf 100644
--- a/pytorch_finufft/functional.py
+++ b/pytorch_finufft/functional.py
@@ -2,6 +2,7 @@
 Implementations of the corresponding Autograd functions
 """
 
+import functools
 import warnings
 from typing import Any, Callable, Dict, Optional, Tuple, Union
 
@@ -36,12 +37,16 @@
 
 
 def get_nufft_func(
-    dim: int, nufft_type: int, device_type: str
+    dim: int, nufft_type: int, device: torch.device
 ) -> Callable[..., torch.Tensor]:
-    if device_type == "cuda":
+    if device.type == "cuda":
         if not CUFINUFFT_AVAIL:
             raise RuntimeError("CUDA device requested but cufinufft failed to import")
-        return getattr(cufinufft, f"nufft{dim}d{nufft_type}")  # type: ignore
+        # note: in the future, cufinufft may figure out gpu_device_id on its own
+        # see: https://github.com/flatironinstitute/finufft/issues/420
+        return functools.partial(
+            getattr(cufinufft, f"nufft{dim}d{nufft_type}"), gpu_device_id=device.index
+        )
 
     if not FINUFFT_AVAIL:
         raise RuntimeError("CPU device requested but finufft failed to import")
@@ -137,7 +142,7 @@ def forward(  # type: ignore[override]
         # pop because cufinufft doesn't support modeord
         modeord = finufftkwargs.pop("modeord", FinufftType1.MODEORD_DEFAULT)
 
-        nufft_func = get_nufft_func(ndim, 1, points.device.type)
+        nufft_func = get_nufft_func(ndim, 1, points.device)
 
         batch_dims = values.shape[:-1]
         finufft_out = nufft_func(
@@ -217,7 +222,7 @@ def backward(  # type: ignore[override]
         grads_points = None
         grad_values = None
 
-        nufft_func = get_nufft_func(ndim, 2, device.type)
+        nufft_func = get_nufft_func(ndim, 2, device)
 
         if any(ctx.needs_input_grad):
             if _mode_ordering:
@@ -317,7 +322,7 @@ def forward(  # type: ignore[override]
         if modeord:
             targets = batch_fftshift(targets, ndim)
 
-        nufft_func = get_nufft_func(ndim, 2, points.device.type)
+        nufft_func = get_nufft_func(ndim, 2, points.device)
         batch_dims = targets.shape[:-ndim]
         shape = targets.shape[-ndim:]
         finufft_out = nufft_func(
@@ -378,7 +383,13 @@ def vmap(  # type: ignore[override]
     @staticmethod
     def backward(  # type: ignore[override]
         ctx: Any, grad_output: torch.Tensor
-    ) -> Tuple[Union[torch.Tensor, None], Union[torch.Tensor, None], None, None, None,]:
+    ) -> Tuple[
+        Union[torch.Tensor, None],
+        Union[torch.Tensor, None],
+        None,
+        None,
+        None,
+    ]:
         _i_sign = ctx.isign
         _mode_ordering = ctx.mode_ordering
         finufftkwargs = ctx.finufftkwargs
@@ -404,7 +415,7 @@ def backward(  # type: ignore[override]
 
         if ctx.needs_input_grad[0]:
             # wrt. points
-            nufft_func = get_nufft_func(ndim, 2, points.device.type)
+            nufft_func = get_nufft_func(ndim, 2, points.device)
 
             coord_ramps = coordinate_ramps(shape, device)
 
@@ -422,7 +433,7 @@ def backward(  # type: ignore[override]
 
         if ctx.needs_input_grad[1]:
             # wrt. targets
-            nufft_func = get_nufft_func(ndim, 1, points.device.type)
+            nufft_func = get_nufft_func(ndim, 1, points.device)
 
             grad_targets = nufft_func(
                 *points,
diff --git a/tests/test_t1_forward.py b/tests/test_t1_forward.py
index c81e264..9dbac83 100644
--- a/tests/test_t1_forward.py
+++ b/tests/test_t1_forward.py
@@ -69,3 +69,9 @@ def test_t1_forward_CPU(N, dim) -> None:
 @pytest.mark.parametrize("N, dim", Ns_and_dims)
 def test_t1_forward_cuda(N, dim) -> None:
     check_t1_forward(N, dim, "cuda")
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="require multiple GPUs")
+def test_t1_forward_cuda_device_1() -> None:
+    # added after https://github.com/flatironinstitute/pytorch-finufft/issues/103
+    check_t1_forward(3, 1, "cuda:1")