Merge pull request #104 from flatironinstitute/cufinufft-request-device-id

Fix: Automatically set cufinufft's gpu_device_id parameter

WardBrian authored Feb 14, 2024
2 parents 9780906 + 5be2121 · commit 270369d

Showing 4 changed files with 27 additions and 11 deletions.
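Summary: the wrapper previously passed only the device *type* ("cuda") down to cufinufft, so transforms apparently ran against cufinufft's default device (GPU 0) even when the tensors lived on another GPU (see issue #103, referenced in the new test below). After this change, get_nufft_func receives the full torch.device and binds its index into the cufinufft call. A minimal sketch of the behavior being fixed, assuming the package's finufft_type1 entry point and a machine with at least two GPUs:

    import torch
    import pytorch_finufft.functional as F

    device = torch.device("cuda:1")
    points = 2 * torch.pi * torch.rand(2, 1000, device=device)  # 2-D nonuniform points
    values = torch.randn(1000, dtype=torch.complex64, device=device)

    # Before this commit, the underlying cufinufft call used its default GPU,
    # mismatching the data on GPU 1; now points.device.index is forwarded.
    out = F.finufft_type1(points, values, (32, 32))
    assert out.device == device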
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml

@@ -17,7 +17,7 @@ repos:

   # Using this mirror lets us use mypyc-compiled black, which is about 2x faster
   - repo: https://github.com/psf/black-pre-commit-mirror
-    rev: 23.9.1
+    rev: 24.1.1
     hooks:
       - id: black

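(Housekeeping: with the black hook bumped from 23.9.1 to 24.1.1, running the standard `pre-commit run --all-files` re-formats the repository with the new black release; the blank-line and signature reflows in the two Python files below appear consistent with that.)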
1 change: 0 additions & 1 deletion examples/convolution_2d.py

@@ -3,7 +3,6 @@
 =================
 """

-
 #######################################################################################
 # Import packages
 # ---------------
29 changes: 20 additions & 9 deletions pytorch_finufft/functional.py

@@ -2,6 +2,7 @@
 Implementations of the corresponding Autograd functions
 """

+import functools
 import warnings
 from typing import Any, Callable, Dict, Optional, Tuple, Union

@@ -36,12 +37,16 @@


 def get_nufft_func(
-    dim: int, nufft_type: int, device_type: str
+    dim: int, nufft_type: int, device: torch.device
 ) -> Callable[..., torch.Tensor]:
-    if device_type == "cuda":
+    if device.type == "cuda":
         if not CUFINUFFT_AVAIL:
             raise RuntimeError("CUDA device requested but cufinufft failed to import")
-        return getattr(cufinufft, f"nufft{dim}d{nufft_type}")  # type: ignore
+        # note: in the future, cufinufft may figure out gpu_device_id on its own
+        # see: https://github.com/flatironinstitute/finufft/issues/420
+        return functools.partial(
+            getattr(cufinufft, f"nufft{dim}d{nufft_type}"), gpu_device_id=device.index
+        )

     if not FINUFFT_AVAIL:
         raise RuntimeError("CPU device requested but finufft failed to import")
@@ -137,7 +142,7 @@ def forward(  # type: ignore[override]
         # pop because cufinufft doesn't support modeord
         modeord = finufftkwargs.pop("modeord", FinufftType1.MODEORD_DEFAULT)

-        nufft_func = get_nufft_func(ndim, 1, points.device.type)
+        nufft_func = get_nufft_func(ndim, 1, points.device)

         batch_dims = values.shape[:-1]
         finufft_out = nufft_func(
@@ -217,7 +222,7 @@ def backward(  # type: ignore[override]
         grads_points = None
         grad_values = None

-        nufft_func = get_nufft_func(ndim, 2, device.type)
+        nufft_func = get_nufft_func(ndim, 2, device)

         if any(ctx.needs_input_grad):
             if _mode_ordering:
@@ -317,7 +322,7 @@ def forward(  # type: ignore[override]
         if modeord:
             targets = batch_fftshift(targets, ndim)

-        nufft_func = get_nufft_func(ndim, 2, points.device.type)
+        nufft_func = get_nufft_func(ndim, 2, points.device)
         batch_dims = targets.shape[:-ndim]
         shape = targets.shape[-ndim:]
         finufft_out = nufft_func(
@@ -378,7 +383,13 @@ def vmap(  # type: ignore[override]
     @staticmethod
     def backward(  # type: ignore[override]
         ctx: Any, grad_output: torch.Tensor
-    ) -> Tuple[Union[torch.Tensor, None], Union[torch.Tensor, None], None, None, None,]:
+    ) -> Tuple[
+        Union[torch.Tensor, None],
+        Union[torch.Tensor, None],
+        None,
+        None,
+        None,
+    ]:
         _i_sign = ctx.isign
         _mode_ordering = ctx.mode_ordering
         finufftkwargs = ctx.finufftkwargs
@@ -404,7 +415,7 @@

         if ctx.needs_input_grad[0]:
             # wrt. points
-            nufft_func = get_nufft_func(ndim, 2, points.device.type)
+            nufft_func = get_nufft_func(ndim, 2, points.device)

             coord_ramps = coordinate_ramps(shape, device)

@@ … @@

         if ctx.needs_input_grad[1]:
             # wrt. targets
-            nufft_func = get_nufft_func(ndim, 1, points.device.type)
+            nufft_func = get_nufft_func(ndim, 1, points.device)

             grad_targets = nufft_func(
                 *points,
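One subtlety that makes the device.index forwarding safe: a bare torch.device("cuda") has index None, but the .device attribute of an allocated tensor is always fully qualified, so points.device.index above is always a concrete integer:

    import torch

    print(torch.device("cuda").index)    # None: bare device specification
    if torch.cuda.is_available():
        x = torch.zeros(1, device="cuda")
        print(x.device, x.device.index)  # cuda:0 0: tensors carry an explicit index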
6 changes: 6 additions & 0 deletions tests/test_t1_forward.py

@@ -69,3 +69,9 @@ def test_t1_forward_CPU(N, dim) -> None:
 @pytest.mark.parametrize("N, dim", Ns_and_dims)
 def test_t1_forward_cuda(N, dim) -> None:
     check_t1_forward(N, dim, "cuda")
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="require multiple GPUs")
+def test_t1_forward_cuda_device_1() -> None:
+    # added after https://github.com/flatironinstitute/pytorch-finufft/issues/103
+    check_t1_forward(3, 1, "cuda:1")
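The new regression test only runs on hosts with at least two visible CUDA devices; elsewhere pytest reports it as skipped. On a multi-GPU machine it can be exercised in isolation with the standard selector `pytest tests/test_t1_forward.py -k cuda_device_1`.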