Merge pull request #75 from flatironinstitute/mike-type2
[WIP] type 2 finufft consolidated
WardBrian authored Oct 13, 2023
2 parents 2b94275 + 237ed96 commit 92c5eaf
Showing 7 changed files with 477 additions and 11 deletions.
184 changes: 173 additions & 11 deletions pytorch_finufft/functional.py
@@ -716,6 +716,21 @@ def f(*args, **kwargs):
return f


def coordinate_ramps(shape, device):
start_points = -(torch.tensor(shape, device=device) // 2)
end_points = start_points + torch.tensor(shape, device=device)
coord_ramps = torch.stack(
torch.meshgrid(
*(
torch.arange(start, end, device=device)
for start, end in zip(start_points, end_points)
),
indexing="ij",
)
)

return coord_ramps
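
As a quick illustration (not part of the commit), coordinate_ramps stacks one centered integer frequency ramp per dimension, giving a tensor of shape (ndim, *shape). A minimal sketch for a hypothetical (2, 3) grid, assuming the helper is importable from pytorch_finufft.functional as the diff above suggests:

import torch
from pytorch_finufft.functional import coordinate_ramps  # import path assumed from the diff above

ramps = coordinate_ramps((2, 3), device=torch.device("cpu"))
print(ramps.shape)  # torch.Size([2, 2, 3]) -- one ramp per dimension
print(ramps[0])     # rows vary:    [[-1, -1, -1], [ 0,  0,  0]]
print(ramps[1])     # columns vary: [[-1,  0,  1], [-1,  0,  1]]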

class finufft_type1(torch.autograd.Function):
@staticmethod
def forward( # type: ignore[override]
@@ -803,17 +818,7 @@ def backward( # type: ignore[override]

if ctx.needs_input_grad[0]:
# wrt points
start_points = -(torch.tensor(grad_output.shape, device=device) // 2)
end_points = start_points + torch.tensor(grad_output.shape, device=device)
coord_ramps = torch.stack(
torch.meshgrid(
*(
torch.arange(start, end, device=device)
for start, end in zip(start_points, end_points)
),
indexing="ij",
)
)
coord_ramps = coordinate_ramps(grad_output.shape, device)

# we can't batch in 1d case so we squeeze and fix up the output later
ramped_grad_output = (
@@ -840,3 +845,160 @@ def backward( # type: ignore[override]
None,
None,
)



class finufft_type2(torch.autograd.Function):
"""
FINUFFT problem type 2
"""

@staticmethod
def forward(
ctx: Any,
points: torch.Tensor,
targets: torch.Tensor,
out: Optional[torch.Tensor] = None,
finufftkwargs: Optional[Dict[str, Union[int, float]]] = None,
) -> torch.Tensor:
"""
Evaluates the Type 2 NUFFT on the inputs.
NOTE: By default, the ordering is set to match that of Pytorch,
Numpy, and Scipy's FFT APIs. To match the mode ordering
native to FINUFFT, add {'modeord': 0} to finufftkwargs.
Parameters
----------
ctx : Any
Pytorch context object
points : torch.Tensor, shape=(ndim, num_points)
The non-uniform points x
targets : torch.Tensor
The values on the input grid
out : Optional[torch.Tensor], optional
Array to take the result in-place, by default None
finufftkwargs : Dict[str, Union[int, float]]
Additional arguments will be passed into FINUFFT. See
https://finufft.readthedocs.io/en/latest/python.html.
Returns
-------
torch.Tensor
The Fourier transform of the targets grid evaluated at the points `points`
"""

if out is not None:
print("In-place results are not yet implemented")

# TODO -- extend checks to 2d
checks._type2_checks(points, targets)


if finufftkwargs is None:
finufftkwargs = dict()

finufftkwargs = {k: v for k, v in finufftkwargs.items()}
_mode_ordering = finufftkwargs.pop("modeord", 1) # not finufft default, but corresponds to pytorch default
_i_sign = finufftkwargs.pop("isign", -1) # isign=-1 is finufft default for type 2

ndim = points.shape[0]
if _mode_ordering == 1:
targets = torch.fft.fftshift(targets)


ctx.isign = _i_sign
ctx.mode_ordering = _mode_ordering
ctx.finufftkwargs = finufftkwargs

ctx.save_for_backward(points, targets)

nufft_func = get_nufft_func(ndim, 2, points.device.type)

finufft_out = nufft_func(
*points,
targets,
isign=_i_sign,
**finufftkwargs,
)

return finufft_out
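
As an aside (not part of the diff), a minimal forward-only usage sketch with assumed values; the output has one complex value per non-uniform point:

import numpy as np
import torch
from pytorch_finufft.functional import finufft_type2

points = torch.rand(1, 17, dtype=torch.float64) * 2 * np.pi   # (ndim, num_points), hypothetical sizes
targets = torch.randn(8, dtype=torch.complex128)               # coefficients on an 8-mode grid

values = finufft_type2.apply(points, targets)
print(values.shape)  # torch.Size([17]) -- one value per non-uniform point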

@staticmethod
def backward(
ctx: Any, grad_output: torch.Tensor
) -> Tuple[
Union[torch.Tensor, None],
Union[torch.Tensor, None],
Union[torch.Tensor, None],
None,
None,
None,
]:
"""
Implements derivatives wrt. each argument in the forward method.
Parameters
----------
ctx : Any
Pytorch context object
grad_output : torch.Tensor
Backpass gradient output.
Returns
-------
Tuple[ Union[torch.Tensor, None], ...]
A tuple of derivatives wrt. each argument in the forward method
"""
_i_sign = ctx.isign
_mode_ordering = ctx.mode_ordering
finufftkwargs = ctx.finufftkwargs

points, targets = ctx.saved_tensors
device = points.device

grad_points = grad_targets = None
ndim = points.shape[0]

if ctx.needs_input_grad[0]:
coord_ramps = coordinate_ramps(targets.shape, device=device)
ramped_targets = coord_ramps * targets[np.newaxis] * 1j * _i_sign
nufft_func = get_nufft_func(ndim, 2, points.device.type)

grad_points = nufft_func(
*points,
ramped_targets.squeeze(),
isign=_i_sign,
**finufftkwargs,
).conj()  # TODO: unclear why this conj() cannot simply be replaced by flipping isign

grad_points = grad_points * grad_output
grad_points = torch.atleast_2d(grad_points.real)

if ctx.needs_input_grad[1]:
# wrt. targets
nufft_func = get_nufft_func(ndim, 1, points.device.type)

grad_targets = nufft_func(
*points,
grad_output,
targets.shape,
isign=-_i_sign,
**finufftkwargs,
)

if _mode_ordering == 1:
grad_targets = torch.fft.ifftshift(grad_targets)

return (
grad_points,
grad_targets,
None,
None,
None,
)
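
A sketch of the calculus behind the two branches above (not from the commit): with sign s = isign and centered modes k, the type-2 output at a non-uniform point x_j is

c_j = \sum_k f_k \, e^{i s \, k \cdot x_j},
\qquad
\frac{\partial c_j}{\partial x_{j,d}} = i s \sum_k k_d \, f_k \, e^{i s \, k \cdot x_j},
\qquad
\frac{\partial c_j}{\partial f_k} = e^{i s \, k \cdot x_j}.

So the point gradient is again a type-2 transform, applied to the coordinate-ramped coefficients (the coord_ramps * targets * 1j * _i_sign factor, up to the conjugation and real-part bookkeeping required by PyTorch's convention for complex gradients), while the gradient with respect to the grid values is the adjoint map, i.e. a type-1 transform with the opposite sign (the isign=-_i_sign call), with ifftshift undoing the earlier fftshift when modeord is 1.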

65 changes: 65 additions & 0 deletions tests/test_1d/test_backward_1d.py
@@ -161,3 +161,68 @@ def test_t2_backward_CPU_points(
inputs = (points, targets)

assert gradcheck(apply_finufft1d2(fftshift, isign), inputs, eps=1e-8, atol=1e-5 * N)


def check_t2_backward(
N: int,
modifier: int,
fftshift: bool,
isign: int,
device: str,
points_or_targets: bool,
) -> None:
points = torch.rand((1, N + modifier), dtype=torch.float64).to(device) * 2 * np.pi
targets = torch.randn(N, dtype=torch.complex128).to(device)

points.requires_grad = points_or_targets
targets.requires_grad = not points_or_targets

inputs = (points, targets)

def func(points, targets):
return pytorch_finufft.functional.finufft_type2.apply(
points,
targets,
None,
dict(modeord=int(not fftshift), isign=isign),
)

assert gradcheck(func, inputs, atol=5e-3 * N)


@pytest.mark.parametrize("N", Ns)
@pytest.mark.parametrize("modifier", length_modifiers)
@pytest.mark.parametrize("fftshift", [False, True])
@pytest.mark.parametrize("isign", [-1, 1])
def test_t2_backward_CPU_values(
N: int, modifier: int, fftshift: bool, isign: int
) -> None:
check_t2_backward(N, modifier, fftshift, isign, "cpu", False)

@pytest.mark.parametrize("N", Ns)
@pytest.mark.parametrize("modifier", length_modifiers)
@pytest.mark.parametrize("fftshift", [False, True])
@pytest.mark.parametrize("isign", [-1, 1])
def test_t2_backward_CPU_points(
N: int, modifier: int, fftshift: bool, isign: int
) -> None:
check_t2_backward(N, modifier, fftshift, isign, "cpu", True)

@pytest.mark.parametrize("N", Ns)
@pytest.mark.parametrize("modifier", length_modifiers)
@pytest.mark.parametrize("fftshift", [False, True])
@pytest.mark.parametrize("isign", [-1, 1])
def test_t2_backward_cuda_values(
N: int, modifier: int, fftshift: bool, isign: int
) -> None:
check_t2_backward(N, modifier, fftshift, isign, "cuda", False)

@pytest.mark.parametrize("N", Ns)
@pytest.mark.parametrize("modifier", length_modifiers)
@pytest.mark.parametrize("fftshift", [False, True])
@pytest.mark.parametrize("isign", [-1, 1])
def test_t2_backward_cuda_points(
N: int, modifier: int, fftshift: bool, isign: int
) -> None:
check_t2_backward(N, modifier, fftshift, isign, "cuda", True)

33 changes: 33 additions & 0 deletions tests/test_1d/test_forward_1d.py
@@ -101,3 +101,36 @@ def test_1d_t2_forward_CPU(targets: torch.Tensor):
assert torch.norm(finufft_out - against_torch) / N**2 == pytest.approx(
0, abs=1e-05
)


@pytest.mark.parametrize("N", Ns)
def test_t2_forward_CPU(N: int) -> None:
"""
Tests against implementations of the FFT by setting up a uniform grid
over which to call FINUFFT through the API.
"""
g = np.mgrid[:N,] * 2 * np.pi / N
points = torch.from_numpy(g.reshape(g.shape[0], -1))

targets = torch.randn(*g[0].shape, dtype=torch.complex128)

print("N is " + str(N))
print("shape of points is " + str(points.shape))
print("shape of targets is " + str(targets.shape))

finufft_out = pytorch_finufft.functional.finufft_type2.apply(
points,
targets,
)

against_torch = torch.fft.fftn(targets)

abs_errors = torch.abs(finufft_out - against_torch.ravel())
l_inf_error = abs_errors.max()
l_2_error = torch.sqrt(torch.sum(abs_errors**2))
l_1_error = torch.sum(abs_errors)

assert l_inf_error < 4.5e-5 * N ** 1.1
assert l_2_error < 6e-5 * N ** 2.1
assert l_1_error < 1.2e-4 * N ** 3.2
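
A small self-contained sketch (assumed, not from the diff) of why torch.fft.fftn is the right reference here: on the uniform grid x_j = 2*pi*j/N with isign=-1, the type-2 sum reduces to the plain DFT, and centered versus unshifted mode indices agree on this grid because exp(-2*pi*1j*(k +/- N)*j/N) = exp(-2*pi*1j*k*j/N):

import torch

N = 8
f = torch.randn(N, dtype=torch.complex128)
j = torch.arange(N)
k = torch.arange(N)  # unshifted (PyTorch/NumPy) frequency ordering
# direct evaluation of the type-2 sum on the uniform grid x_j = 2*pi*j/N with isign=-1
dft = (f[None, :] * torch.exp(-2j * torch.pi * j[:, None] * k[None, :] / N)).sum(dim=1)
print(torch.allclose(dft, torch.fft.fft(f)))  # True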

66 changes: 66 additions & 0 deletions tests/test_2d/test_backward_2d.py
@@ -202,3 +202,69 @@ def test_t2_backward_CPU_points_y(
inputs = (points_x, points_y, targets)

assert gradcheck(apply_finufft2d2(fftshift, isign), inputs)



def check_t2_backward(
N: int,
modifier: int,
fftshift: bool,
isign: int,
device: str,
points_or_targets: bool,
) -> None:
points = torch.rand((2, N + modifier), dtype=torch.float64).to(device) * 2 * np.pi
targets = torch.randn(N, N, dtype=torch.complex128).to(device)

points.requires_grad = points_or_targets
targets.requires_grad = not points_or_targets

inputs = (points, targets)

def func(points, targets):
return pytorch_finufft.functional.finufft_type2.apply(
points,
targets,
None,
dict(modeord=int(not fftshift), isign=isign),
)

assert gradcheck(func, inputs, atol=1e-5 * N)


@pytest.mark.parametrize("N", Ns)
@pytest.mark.parametrize("modifier", length_modifiers)
@pytest.mark.parametrize("fftshift", [False, True])
@pytest.mark.parametrize("isign", [-1, 1])
def test_t2_backward_CPU_values(
N: int, modifier: int, fftshift: bool, isign: int
) -> None:
check_t2_backward(N, modifier, fftshift, isign, "cpu", False)

@pytest.mark.parametrize("N", Ns)
@pytest.mark.parametrize("modifier", length_modifiers)
@pytest.mark.parametrize("fftshift", [False, True])
@pytest.mark.parametrize("isign", [-1, 1])
def test_t2_backward_CPU_points(
N: int, modifier: int, fftshift: bool, isign: int
) -> None:
check_t2_backward(N, modifier, fftshift, isign, "cpu", True)

@pytest.mark.parametrize("N", Ns)
@pytest.mark.parametrize("modifier", length_modifiers)
@pytest.mark.parametrize("fftshift", [False, True])
@pytest.mark.parametrize("isign", [-1, 1])
def test_t2_backward_cuda_values(
N: int, modifier: int, fftshift: bool, isign: int
) -> None:
check_t2_backward(N, modifier, fftshift, isign, "cuda", False)

@pytest.mark.parametrize("N", Ns)
@pytest.mark.parametrize("modifier", length_modifiers)
@pytest.mark.parametrize("fftshift", [False, True])
@pytest.mark.parametrize("isign", [-1, 1])
def test_t2_backward_cuda_points(
N: int, modifier: int, fftshift: bool, isign: int
) -> None:
check_t2_backward(N, modifier, fftshift, isign, "cuda", True)
