diff --git a/pytorch_finufft/functional.py b/pytorch_finufft/functional.py
index 96551b5..7b40e71 100644
--- a/pytorch_finufft/functional.py
+++ b/pytorch_finufft/functional.py
@@ -4,6 +4,7 @@
 from typing import Any, Dict, Optional, Tuple, Union
 
+import numpy as np
 import finufft
 import torch
 
@@ -1592,3 +1593,161 @@ def backward(
         None,
         None,
     )
+
+
+###############################################################################
+# Consolidated forward function for all 1D, 2D, and 3D problems for NUFFT type 1
+###############################################################################
+
+
+def get_nufft_func(dim, nufft_type):
+    return getattr(finufft, f"nufft{dim}d{nufft_type}")
+
+
+class finufft_type1(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx: Any,
+        points: torch.Tensor,
+        values: torch.Tensor,
+        output_shape: Union[int, Tuple[int, int], Tuple[int, int, int]],
+        out: Optional[torch.Tensor] = None,
+        fftshift: bool = False,
+        finufftkwargs: Optional[Dict[str, Union[int, float]]] = None,
+    ) -> torch.Tensor:
+        """
+        Evaluates the Type 1 NUFFT on the inputs.
+        """
+
+        if out is not None:
+            # Supporting this only requires a check that the out array has the
+            # correct shape.
+            raise NotImplementedError("In-place results are not yet implemented")
+
+        # TODO: revisit these error checks to take the shape of points into
+        # account instead of passing coordinates separately; they should also
+        # check consistency between output_shape and len(points).
+        err._type1_checks(points, values, output_shape)
+
+        if finufftkwargs is None:
+            finufftkwargs = dict()
+        finufftkwargs = {k: v for k, v in finufftkwargs.items()}
+        _mode_ordering = finufftkwargs.pop("modeord", 1)
+        _i_sign = finufftkwargs.pop("isign", -1)
+
+        if fftshift:
+            # TODO -- this check should be done elsewhere, or the error message
+            # changed to note that fftshift and modeord conflict.
+            if _mode_ordering != 1:
+                raise ValueError(
+                    "Double specification of ordering; only one of fftshift and modeord should be provided"
+                )
+            _mode_ordering = 0
+
+        ctx.save_for_backward(points, values)
+
+        ctx.isign = _i_sign
+        ctx.mode_ordering = _mode_ordering
+        ctx.finufftkwargs = finufftkwargs
+
+        # TODO: this should be a pre-check.
+        ndim = points.shape[0]
+        assert len(output_shape) == ndim
+
+        nufft_func = get_nufft_func(ndim, 1)
+        finufft_out = torch.from_numpy(
+            nufft_func(
+                *points.data.numpy(),
+                values.data.numpy(),
+                output_shape,
+                modeord=_mode_ordering,
+                isign=_i_sign,
+                **finufftkwargs,
+            )
+        )
+
+        return finufft_out
+
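+    # Usage sketch (editor's illustration, mirroring the tests below; the
+    # shapes and sizes are arbitrary):
+    #     points = 2 * np.pi * torch.rand((2, 100), dtype=torch.float64)
+    #     values = torch.randn(100, dtype=torch.complex128)
+    #     targets = finufft_type1.apply(points, values, (32, 32))
+    # computes targets[k] ~= sum_j values[j] * exp(-1j * (k . points[:, j]))
+    # over a 32x32 grid of modes k, with FFT-style mode ordering by default.
+
+    # Derivation sketch for backward() (editor's note, inferred from the code
+    # rather than quoted from FINUFFT's docs): differentiating the sum above
+    # wrt points[d, j] brings down a factor (isign * 1j * k[d]). The VJP wrt
+    # points is therefore a type-2 NUFFT of grad_output ramped by the mode
+    # indices k[d] (the coord_ramps below), combined with conj() and the
+    # strengths; the VJP wrt values is the adjoint type-2 NUFFT of grad_output.
+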
+    @staticmethod
+    def backward(
+        ctx: Any, grad_output: torch.Tensor
+    ) -> Tuple[Union[torch.Tensor, None], ...]:
+        """
+        Implements derivatives wrt. each argument in the forward method.
+
+        Parameters
+        ----------
+        ctx : Any
+            PyTorch context object.
+        grad_output : torch.Tensor
+            Backpass gradient output
+
+        Returns
+        -------
+        Tuple[Union[torch.Tensor, None], ...]
+            A tuple of derivatives wrt. each argument in the forward method
+        """
+        _i_sign = -1 * ctx.isign
+        _mode_ordering = ctx.mode_ordering
+        finufftkwargs = ctx.finufftkwargs
+
+        points, values = ctx.saved_tensors
+
+        start_points = -(np.array(grad_output.shape) // 2)
+        end_points = start_points + grad_output.shape
+        slices = tuple(
+            slice(start, end) for start, end in zip(start_points, end_points)
+        )
+
+        # CPU idiosyncrasy that needs to be done differently
+        coord_ramps = torch.from_numpy(np.mgrid[slices])
+
+        grads_points = None
+        grad_values = None
+
+        ndim = points.shape[0]
+
+        nufft_func = get_nufft_func(ndim, 2)
+
+        if ctx.needs_input_grad[0]:
+            # wrt points
+            if _mode_ordering != 0:
+                coord_ramps = torch.fft.ifftshift(
+                    coord_ramps, dim=tuple(range(1, ndim + 1))
+                )
+
+            ramped_grad_output = coord_ramps * grad_output[np.newaxis] * 1j * _i_sign
+
+            grads_points = []
+            for ramp in ramped_grad_output:  # TODO: batch this into one finufft call
+                backprop_ramp = torch.from_numpy(
+                    nufft_func(
+                        *points.numpy(),
+                        ramp.data.numpy(),
+                        isign=_i_sign,
+                        modeord=_mode_ordering,
+                        **finufftkwargs,
+                    )
+                )
+                grad_points = (backprop_ramp.conj() * values).real
+                grads_points.append(grad_points)
+
+            grads_points = torch.stack(grads_points)
+
+        if ctx.needs_input_grad[1]:
+            np_grad_output = grad_output.data.numpy()
+
+            grad_values = torch.from_numpy(
+                nufft_func(
+                    *points.numpy(),
+                    np_grad_output,
+                    isign=_i_sign,
+                    modeord=_mode_ordering,
+                    **finufftkwargs,
+                )
+            )
+
+        return (
+            grads_points,
+            grad_values,
+            None,
+            None,
+            None,
+            None,
+        )
diff --git a/tests/test_1d/test_forward_1d.py b/tests/test_1d/test_forward_1d.py
index 1bbd8d0..5379927 100644
--- a/tests/test_1d/test_forward_1d.py
+++ b/tests/test_1d/test_forward_1d.py
@@ -67,6 +67,16 @@ def test_1d_t1_forward_CPU(values: torch.Tensor) -> None:
     ) == pytest.approx(0, abs=1e-06)
 
+    abs_errors = torch.abs(finufft1D1_out - against_torch)
+    l_inf_error = abs_errors.max()
+    l_2_error = torch.sqrt(torch.sum(abs_errors**2))
+    l_1_error = torch.sum(abs_errors)
+
+    assert l_inf_error < 3.5e-3 * N**0.6
+    assert l_2_error < 7.5e-4 * N**1.1
+    assert l_1_error < 5e-4 * N**1.6
+
+
 @pytest.mark.parametrize("targets", cases)
 def test_1d_t2_forward_CPU(targets: torch.Tensor):
     """
@@ -96,6 +106,41 @@ def test_1d_t2_forward_CPU(targets: torch.Tensor):
     )
 
 
+@pytest.mark.parametrize("N", Ns)
+def test_t1_forward_CPU(N: int) -> None:
+    """
+    Tests against implementations of the FFT by setting up a uniform grid
+    over which to call FINUFFT through the API.
+    """
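+    # On the uniform grid x_j = 2*pi*j/N, a type-1 NUFFT with this wrapper's
+    # defaults (isign=-1, FFT-style mode ordering) coincides with the DFT, so
+    # torch.fft.fft provides the reference values.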
+ """ + g = np.mgrid[:N] * 2 * np.pi / N + g.shape = 1, -1 + points = torch.from_numpy(g.reshape(1, -1)) + + values = torch.randn(*points[0].shape, dtype=torch.complex128) + + print("N is " + str(N)) + print("shape of points is " + str(points.shape)) + print("shape of values is " + str(values.shape)) + + finufft_out = pytorch_finufft.functional.finufft_type1.apply( + points, + values, + (N,), + ) + + against_torch = torch.fft.fft(values.reshape(g[0].shape)) + + abs_errors = torch.abs(finufft_out - against_torch) + l_inf_error = abs_errors.max() + l_2_error = torch.sqrt(torch.sum(abs_errors**2)) + l_1_error = torch.sum(abs_errors) + + assert l_inf_error < 4.5e-5 * N + assert l_2_error < 1e-5 * N ** 2 + assert l_1_error < 1e-5 * N ** 3 + + + # @pytest.mark.parametrize("values", cases) # def test_1d_t3_forward_CPU(values: torch.Tensor) -> None: # """ diff --git a/tests/test_2d/test_backward_2d.py b/tests/test_2d/test_backward_2d.py index ce7122a..6a9b707 100644 --- a/tests/test_2d/test_backward_2d.py +++ b/tests/test_2d/test_backward_2d.py @@ -5,8 +5,11 @@ import pytorch_finufft +from functools import partial + torch.set_default_tensor_type(torch.DoubleTensor) torch.set_default_dtype(torch.float64) +torch.manual_seed(0) ###################################################################### # APPLY WRAPPERS @@ -97,6 +100,50 @@ def test_t1_backward_CPU_values( assert gradcheck(apply_finufft2d1(modifier, fftshift, isign), inputs) +@pytest.mark.parametrize("N", Ns) +@pytest.mark.parametrize("modifier", length_modifiers) +@pytest.mark.parametrize("fftshift", [False, True]) +@pytest.mark.parametrize("isign", [-1, 1]) +def test_t1_consolidated_backward_CPU_values(N: int, modifier: int, fftshift: bool, isign: int) -> None: + + points = torch.rand((2, N), dtype=torch.float64) * 2 * np.pi + values = torch.randn(N, dtype=torch.complex128) + + points.requires_grad = False + values.requires_grad = True + + inputs = (points, values) + + def func(points, values): + return pytorch_finufft.functional.finufft_type1.apply( + points, values, (N,N + modifier), None, fftshift, dict(isign=isign) + ) + + assert gradcheck(func, inputs) + + +@pytest.mark.parametrize("N", Ns) +@pytest.mark.parametrize("modifier", length_modifiers) +@pytest.mark.parametrize("fftshift", [False, True]) +@pytest.mark.parametrize("isign", [-1, 1]) +def test_t1_consolidated_backward_CPU_points(N: int, modifier: int, fftshift: bool, isign: int) -> None: + + points = torch.rand((2, N), dtype=torch.float64) * 2 * np.pi + values = torch.randn(N, dtype=torch.complex128) + + points.requires_grad = True + values.requires_grad = False + + inputs = (points, values) + + def func(points, values): + return pytorch_finufft.functional.finufft_type1.apply( + points, values, (N,N + modifier), None, fftshift, dict(isign=isign) + ) + + assert gradcheck(func, inputs, atol=1e-5 * N) + + @pytest.mark.parametrize("N", Ns) @pytest.mark.parametrize("modifier", length_modifiers) @pytest.mark.parametrize("fftshift", [True, False]) diff --git a/tests/test_2d/test_forward_2d.py b/tests/test_2d/test_forward_2d.py index 3a61938..1dda568 100644 --- a/tests/test_2d/test_forward_2d.py +++ b/tests/test_2d/test_forward_2d.py @@ -1,6 +1,7 @@ import numpy as np import pytest import torch +torch.manual_seed(0) import pytorch_finufft @@ -45,28 +46,14 @@ def test_2d_t1_forward_CPU(N: int) -> None: against_torch = torch.fft.fft2(values.reshape(g[0].shape)) - assert abs((finufft_out - against_torch).sum()) / (N**3) == pytest.approx( - 0, abs=1e-6 - ) + abs_errors = 
+    assert gradcheck(func, inputs, atol=1e-5 * N)
+
+
 @pytest.mark.parametrize("N", Ns)
 @pytest.mark.parametrize("modifier", length_modifiers)
 @pytest.mark.parametrize("fftshift", [True, False])
diff --git a/tests/test_2d/test_forward_2d.py b/tests/test_2d/test_forward_2d.py
index 3a61938..1dda568 100644
--- a/tests/test_2d/test_forward_2d.py
+++ b/tests/test_2d/test_forward_2d.py
@@ -1,6 +1,7 @@
 import numpy as np
 import pytest
 import torch
+torch.manual_seed(0)
 
 import pytorch_finufft
 
@@ -45,28 +46,14 @@ def test_2d_t1_forward_CPU(N: int) -> None:
     against_torch = torch.fft.fft2(values.reshape(g[0].shape))
 
-    assert abs((finufft_out - against_torch).sum()) / (N**3) == pytest.approx(
-        0, abs=1e-6
-    )
+    abs_errors = torch.abs(finufft_out - against_torch)
+    l_inf_error = abs_errors.max()
+    l_2_error = torch.sqrt(torch.sum(abs_errors**2))
+    l_1_error = torch.sum(abs_errors)
 
-    values = torch.randn(*x.shape, dtype=torch.complex64)
-
-    finufft_out = pytorch_finufft.functional.finufft2D1.apply(
-        torch.from_numpy(x).to(torch.float32),
-        torch.from_numpy(y).to(torch.float32),
-        values,
-        N,
-    )
-
-    against_torch = torch.fft.fft2(values.reshape(g[0].shape))
-
-    # NOTE -- the below tolerance is set to 1e-5 instead of -6 due
-    # to the occasional failing case that seems to be caused by
-    # the randomness of the test cases in addition to the expected
-    # accruation of numerical inaccuracies
-    assert abs((finufft_out - against_torch).sum()) / (N**3) == pytest.approx(
-        0, abs=1e-5
-    )
+    assert l_inf_error < 5e-5 * N
+    assert l_2_error < 1e-5 * N**2
+    assert l_1_error < 1e-5 * N**3
 
 
 @pytest.mark.parametrize("N", Ns)
@@ -109,9 +96,14 @@ def test_2d_t2_forward_CPU(N: int) -> None:
     against_torch = torch.fft.ifft2(values)
 
-    assert abs((finufft_out - against_torch).sum()) / (N**3) == pytest.approx(
-        0, abs=1e-6
-    )
+    abs_errors = torch.abs(finufft_out - against_torch)
+    l_inf_error = abs_errors.max()
+    l_2_error = torch.sqrt(torch.sum(abs_errors**2))
+    l_1_error = torch.sum(abs_errors)
+
+    assert l_inf_error < 1e-5 * N
+    assert l_2_error < 1e-5 * N**2
+    assert l_1_error < 1e-5 * N**3
 
 
 # @pytest.mark.parametrize("N", Ns)
@@ -128,3 +120,37 @@ def test_2d_t2_forward_CPU(N: int) -> None:
 #     assert abs((f - comparison).sum()) / (N**3) == pytest.approx(0, abs=1e-6)
 #     pass
+
+
+@pytest.mark.parametrize("N", Ns)
+def test_t1_forward_CPU(N: int) -> None:
+    """
+    Tests against implementations of the FFT by setting up a uniform grid
+    over which to call FINUFFT through the API.
+    """
+    g = np.mgrid[:N, :N] * 2 * np.pi / N
+    points = torch.from_numpy(g.reshape(2, -1))
+
+    values = torch.randn(*points[0].shape, dtype=torch.complex128)
+
+    print("N is " + str(N))
+    print("shape of points is " + str(points.shape))
+    print("shape of values is " + str(values.shape))
+
+    finufft_out = pytorch_finufft.functional.finufft_type1.apply(
+        points,
+        values,
+        (N, N),
+    )
+
+    against_torch = torch.fft.fft2(values.reshape(g[0].shape))
+
+    abs_errors = torch.abs(finufft_out - against_torch)
+    l_inf_error = abs_errors.max()
+    l_2_error = torch.sqrt(torch.sum(abs_errors**2))
+    l_1_error = torch.sum(abs_errors)
+
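+    # Threshold scaling (heuristic): an N x N grid has N^2 modes with roughly
+    # uniform per-mode error, so the l_inf, l_2, and l_1 error norms are
+    # expected to grow like N, N^2, and N^3 respectively.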
+    assert l_inf_error < 4.5e-5 * N
+    assert l_2_error < 1e-5 * N**2
+    assert l_1_error < 1e-5 * N**3
diff --git a/tests/test_3d/test_forward_3d.py b/tests/test_3d/test_forward_3d.py
index e470abc..45484aa 100644
--- a/tests/test_3d/test_forward_3d.py
+++ b/tests/test_3d/test_forward_3d.py
@@ -1,6 +1,7 @@
 import numpy as np
 import pytest
 import torch
+torch.manual_seed(0)
 
 import pytorch_finufft
 
@@ -13,8 +14,6 @@
     25,
     26,
     37,
-    100,
-    101,
 ]
 
 
@@ -41,9 +40,15 @@ def test_3d_t1_forward_CPU(N: int) -> None:
     against_torch = torch.fft.fftn(values.reshape(g[0].shape))
 
-    assert abs((finufft_out - against_torch).sum()) / (N**4) == pytest.approx(
-        0, abs=1e-6
-    )
+    abs_errors = torch.abs(finufft_out - against_torch)
+    l_inf_error = abs_errors.max()
+    l_2_error = torch.sqrt(torch.sum(abs_errors**2))
+    l_1_error = torch.sum(abs_errors)
+
+    assert l_inf_error < 2e-5 * N**1.5
+    assert l_2_error < 1e-5 * N**3
+    assert l_1_error < 1e-5 * N**4.5
 
 
 @pytest.mark.parametrize("N", Ns)
@@ -69,6 +74,44 @@ def test_3d_t2_forward_CPU(N: int) -> None:
     against_torch = torch.fft.ifftn(values)
 
-    assert (abs((finufft_out - against_torch).sum())) / (N**4) == pytest.approx(
-        0, abs=1e-6
-    )
+    abs_errors = torch.abs(finufft_out - against_torch)
+    l_inf_error = abs_errors.max()
+    l_2_error = torch.sqrt(torch.sum(abs_errors**2))
+    l_1_error = torch.sum(abs_errors)
+
+    assert l_inf_error < 1e-5 * N**1.5
+    assert l_2_error < 1e-5 * N**3
+    assert l_1_error < 1e-5 * N**4.5
+
+
+@pytest.mark.parametrize("N", Ns)
+def test_t1_forward_CPU(N: int) -> None:
+    """
+    Tests against implementations of the FFT by setting up a uniform grid
+    over which to call FINUFFT through the API.
+    """
+    g = np.mgrid[:N, :N, :N] * 2 * np.pi / N
+    points = torch.from_numpy(g.reshape(3, -1))
+
+    values = torch.randn(*points[0].shape, dtype=torch.complex128)
+
+    print("N is " + str(N))
+    print("shape of points is " + str(points.shape))
+    print("shape of values is " + str(values.shape))
+
+    finufft_out = pytorch_finufft.functional.finufft_type1.apply(
+        points,
+        values,
+        (N, N, N),
+    )
+
+    against_torch = torch.fft.fftn(values.reshape(g[0].shape))
+
+    abs_errors = torch.abs(finufft_out - against_torch)
+    l_inf_error = abs_errors.max()
+    l_2_error = torch.sqrt(torch.sum(abs_errors**2))
+    l_1_error = torch.sum(abs_errors)
+
+    assert l_inf_error < 1.5e-5 * N**1.5
+    assert l_2_error < 1e-5 * N**3
+    assert l_1_error < 1e-5 * N**4.5
\ No newline at end of file