Skip to content

Commit

Permalink
Merge pull request #89 from flatironinstitute/ref/functional-wrappers
Browse files Browse the repository at this point in the history
Add function wrappers around .apply
  • Loading branch information
WardBrian authored Oct 17, 2023
2 parents 11cc785 + 8819feb commit 020842c
Show file tree
Hide file tree
Showing 8 changed files with 131 additions and 55 deletions.
83 changes: 80 additions & 3 deletions pytorch_finufft/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,13 @@ def coordinate_ramps(shape, device):
return coord_ramps


class finufft_type1(torch.autograd.Function):
class FinufftType1(torch.autograd.Function):
@staticmethod
def forward( # type: ignore[override]
ctx: Any,
points: torch.Tensor,
values: torch.Tensor,
output_shape: Union[int, Tuple[int, int], Tuple[int, int, int]],
output_shape: Union[int, Tuple[int], Tuple[int, int], Tuple[int, int, int]],
finufftkwargs: Optional[Dict[str, Union[int, float]]] = None,
) -> torch.Tensor:
"""
Expand Down Expand Up @@ -181,7 +181,7 @@ def backward( # type: ignore[override]
)


class finufft_type2(torch.autograd.Function):
class FinufftType2(torch.autograd.Function):
"""
FINUFFT 2D problem type 2
"""
Expand Down Expand Up @@ -324,3 +324,80 @@ def backward( # type: ignore[override]
None,
None,
)


def finufft_type1(
points: torch.Tensor,
values: torch.Tensor,
output_shape: Union[int, Tuple[int], Tuple[int, int], Tuple[int, int, int]],
**finufftkwargs: Union[int, float],
) -> torch.Tensor:
"""
Evaluates the Type 1 (nonuniform-to-uniform) NUFFT on the inputs.
This is a wrapper around :func:`finufft.nufft1d1`, :func:`finufft.nufft2d1`, and
:func:`finufft.nufft3d1` on CPU, and :func:`cufinufft.nufft1d1`,
:func:`cufinufft.nufft2d1`, and :func:`cufinufft.nufft3d1` on GPU.
Parameters
----------
points : torch.Tensor
DxN tensor of locations of the non-uniform points.
Points must lie in the range ``[-3pi, 3pi]``.
values : torch.Tensor
Length N complex-valued tensor of values at the non-uniform points
output_shape : int | tuple(int, ...)
Requested output shape of Fourier modes. Must be a tuple of length D or
an integer (1D only).
**finufftkwargs : int | float
Additional keyword arguments are forwarded to the underlying
FINUFFT functions. A few notable options are
- ``eps``: precision requested (default: ``1e-6``)
- ``modeord``: 0 for FINUFFT default, 1 for Pytorch default (default: ``1``)
- ``isign``: Sign of the exponent in the Fourier transform (default: ``-1``)
Returns
-------
torch.Tensor
Tensor with shape ``output_shape`` containing the Fourier
transform of the values.
"""
res: torch.Tensor = FinufftType1.apply(points, values, output_shape, finufftkwargs)
return res


def finufft_type2(
points: torch.Tensor,
targets: torch.Tensor,
**finufftkwargs: Union[int, float],
) -> torch.Tensor:
"""
Evaluates the Type 2 (uniform-to-nonuniform) NUFFT on the inputs.
This is a wrapper around :func:`finufft.nufft1d2`, :func:`finufft.nufft2d2`, and
:func:`finufft.nufft3d2` on CPU, and :func:`cufinufft.nufft1d2`,
:func:`cufinufft.nufft2d2`, and :func:`cufinufft.nufft3d2` on GPU.
Parameters
----------
points : torch.Tensor
DxN tensor of locations of the non-uniform points.
Points must lie in the range ``[-3pi, 3pi]``.
targets : torch.Tensor
D-dimensional complex-valued tensor of Fourier modes to evaluate at the points
**finufftkwargs : int | float
Additional keyword arguments are forwarded to the underlying
FINUFFT functions. A few notable options are
- ``eps``: precision requested (default: ``1e-6``)
- ``modeord``: 0 for FINUFFT default, 1 for Pytorch default (default: ``1``)
- ``isign``: Sign of the exponent in the Fourier transform (default: ``-1``)
Returns
-------
torch.Tensor
A DxN tensor of values at the non-uniform points.
"""
res: torch.Tensor = FinufftType2.apply(points, targets, finufftkwargs)
return res
10 changes: 6 additions & 4 deletions tests/test_1d/test_backward_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,12 @@ def check_t1_backward(
inputs = (points, values)

def func(points, values):
return pytorch_finufft.functional.finufft_type1.apply(
return pytorch_finufft.functional.finufft_type1(
points,
values,
(N + modifier,),
dict(modeord=int(not fftshift), isign=isign),
modeord=int(not fftshift),
isign=isign,
)

assert gradcheck(func, inputs, eps=1e-8, atol=1e-5 * N)
Expand Down Expand Up @@ -110,10 +111,11 @@ def check_t2_backward(
inputs = (points, targets)

def func(points, targets):
return pytorch_finufft.functional.finufft_type2.apply(
return pytorch_finufft.functional.finufft_type2(
points,
targets,
dict(modeord=int(not fftshift), isign=isign),
modeord=int(not fftshift),
isign=isign,
)

assert gradcheck(func, inputs, eps=1e-8, atol=1.5e-3 * N)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_1d/test_forward_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def check_t1_forward(N: int, device: str) -> None:
print("shape of points is " + str(points.shape))
print("shape of values is " + str(values.shape))

finufft_out = pytorch_finufft.functional.finufft_type1.apply(
finufft_out = pytorch_finufft.functional.finufft_type1(
points,
values,
(N,),
Expand Down Expand Up @@ -89,7 +89,7 @@ def test_t2_forward_CPU(N: int) -> None:
print("shape of points is " + str(points.shape))
print("shape of targets is " + str(targets.shape))

finufft_out = pytorch_finufft.functional.finufft_type2.apply(
finufft_out = pytorch_finufft.functional.finufft_type2(
points,
targets,
)
Expand Down
12 changes: 5 additions & 7 deletions tests/test_2d/test_backward_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,8 @@ def check_t1_backward(
inputs = (points, values)

def func(points, values):
return pytorch_finufft.functional.finufft_type1.apply(
points,
values,
(N, N + modifier),
dict(modeord=int(not fftshift), isign=isign),
return pytorch_finufft.functional.finufft_type1(
points, values, (N, N + modifier), modeord=int(not fftshift), isign=isign
)

assert gradcheck(func, inputs, atol=1e-5 * N)
Expand Down Expand Up @@ -121,10 +118,11 @@ def check_t2_backward(
inputs = (points, targets)

def func(points, targets):
return pytorch_finufft.functional.finufft_type2.apply(
return pytorch_finufft.functional.finufft_type2(
points,
targets,
dict(modeord=int(not fftshift), isign=isign),
modeord=int(not fftshift),
isign=isign,
)

assert gradcheck(func, inputs, atol=1e-5 * N)
Expand Down
6 changes: 3 additions & 3 deletions tests/test_2d/test_forward_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def check_t1_forward(N: int, device: str) -> None:
print("shape of points is " + str(points.shape))
print("shape of values is " + str(values.shape))

finufft_out = pytorch_finufft.functional.finufft_type1.apply(
finufft_out = pytorch_finufft.functional.finufft_type1(
points,
values,
(N, N),
Expand Down Expand Up @@ -79,10 +79,10 @@ def test_t2_forward_CPU(N: int, fftshift: bool) -> None:
print("shape of points is " + str(points.shape))
print("shape of targets is " + str(targets.shape))

finufft_out = pytorch_finufft.functional.finufft_type2.apply(
finufft_out = pytorch_finufft.functional.finufft_type2(
points,
targets,
{"modeord": int(not fftshift)},
modeord=int(not fftshift),
)

if fftshift:
Expand Down
11 changes: 5 additions & 6 deletions tests/test_3d/test_backward_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,12 @@ def check_t1_backward(
inputs = (points, values)

def func(points, values):
return pytorch_finufft.functional.finufft_type1.apply(
return pytorch_finufft.functional.finufft_type1(
points,
values,
(N, N + modifier, N + 2 * modifier),
dict(modeord=int(not fftshift), isign=isign),
modeord=int(not fftshift),
isign=isign,
)

assert gradcheck(func, inputs, eps=1e-8, atol=1e-5 * N)
Expand Down Expand Up @@ -118,10 +119,8 @@ def check_t2_backward(
inputs = (points, targets)

def func(points, targets):
return pytorch_finufft.functional.finufft_type2.apply(
points,
targets,
dict(modeord=int(not fftshift), isign=isign),
return pytorch_finufft.functional.finufft_type2(
points, targets, modeord=int(not fftshift), isign=isign
)

assert gradcheck(func, inputs, atol=1e-5 * N)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_3d/test_forward_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def check_t1_forward(N: int, device: str) -> None:
print("shape of points is " + str(points.shape))
print("shape of values is " + str(values.shape))

finufft_out = pytorch_finufft.functional.finufft_type1.apply(
finufft_out = pytorch_finufft.functional.finufft_type1(
points,
values,
(N, N, N),
Expand Down Expand Up @@ -77,7 +77,7 @@ def test_t2_forward_CPU(N: int) -> None:
print("shape of points is " + str(points.shape))
print("shape of targets is " + str(targets.shape))

finufft_out = pytorch_finufft.functional.finufft_type2.apply(
finufft_out = pytorch_finufft.functional.finufft_type2(
points,
targets,
)
Expand Down
Loading

0 comments on commit 020842c

Please sign in to comment.