From 1d3204ae0c0121eaa8767c1b5288732806d76708 Mon Sep 17 00:00:00 2001 From: Michael Eickenberg Date: Tue, 10 Oct 2023 21:32:05 -0400 Subject: [PATCH 01/11] ENH type 2 finufft with test for 2D --- pytorch_finufft/functional.py | 192 +++++++++++++++++++++++++++++++ tests/test_2d/test_forward_2d.py | 34 ++++++ 2 files changed, 226 insertions(+) diff --git a/pytorch_finufft/functional.py b/pytorch_finufft/functional.py index 95981bf..48e616b 100644 --- a/pytorch_finufft/functional.py +++ b/pytorch_finufft/functional.py @@ -836,3 +836,195 @@ def backward( # type: ignore[override] None, None, ) + + + +class finufft_type2(torch.autograd.Function): + """ + FINUFFT 2D problem type 2 + """ + + @staticmethod + def forward( + ctx: Any, + points: torch.Tensor, + targets: torch.Tensor, + out: Optional[torch.Tensor] = None, + fftshift: bool = False, + finufftkwargs: Dict[str, Union[int, float]] = {}, + ) -> torch.Tensor: + """ + Evaluates the Type 2 NUFFT on the inputs. + + NOTE: By default, the ordering is set to match that of Pytorch, + Numpy, and Scipy's FFT APIs. To match the mode ordering + native to FINUFFT, set fftshift=True. + + ``` + c[j] = SUM f[k1, k2] exp(+/-i (k1 x(j) + k2 y(j))) + k1, k2 + + for j = 0, ..., M-1, where the sum is over -N1/2 <= k1 <= (N1-1)/2, + -N2/2 <= k2 <= (N2-1)/2 + ``` + + Parameters + ---------- + ctx : Any + Pytorch context objecy + points_x : torch.Tensor + The non-uniform points x_j + points_y : torch.Tensor + The non-uniform points y_j + targets : torch.Tensor + The target Fourier mode coefficients f[k1, k2] + out : Optional[torch.Tensor], optional + Array to take the result in-place, by default None + fftshift : bool + If True centers the 0 mode in the resultant torch.Tensor, by default False + finufftkwargs : Dict[str, Union[int, float]] + Additional arguments will be passed into FINUFFT. See + https://finufft.readthedocs.io/en/latest/python.html. By default + an empty dictionary + + Returns + ------- + torch.Tensor + The resultant array c[j] + + Raises + ------ + ValueError + In the case of conflicting specification of the wave-mode ordering. + """ + + if out is not None: + print("In-place results are not yet implemented") + + # TODO -- extend checks to 2d + err._type2_checks(points, targets) + + finufftkwargs = {k: v for k, v in finufftkwargs.items()} + _mode_ordering = finufftkwargs.pop("modeord", 1) + _i_sign = finufftkwargs.pop("isign", -1) + + if fftshift: + if _mode_ordering != 1: # This seems like it is the wrong way round??????? + raise ValueError( + "Double specification of ordering; only one of fftshift and " + "modeord should be provided." + ) + _mode_ordering = 0 + + if _mode_ordering == 1: + targets = torch.fft.fftshift(targets, dim=(-2, -1)) + + + ctx.isign = _i_sign + ctx.mode_ordering = _mode_ordering + ctx.fftshift = fftshift + ctx.finufftkwargs = finufftkwargs + + ctx.save_for_backward(points, targets) + + finufft_out = finufft.nufft2d2( + *points.data.numpy(), + targets.data.numpy(), + #modeord=_mode_ordering, + isign=_i_sign, + **finufftkwargs, + ) + + return torch.from_numpy(finufft_out) + + @staticmethod + def backward( + ctx: Any, grad_output: torch.Tensor + ) -> Tuple[ + Union[torch.Tensor, None], + Union[torch.Tensor, None], + Union[torch.Tensor, None], + None, + None, + None, + ]: + """ + Implements derivatives wrt. each argument in the forward method. + + Parameters + ---------- + ctx : Any + Pytorch context object + grad_output : torch.Tensor + Backpass gradient output. + + Returns + ------- + Tuple[ Union[torch.Tensor, None], ...] 
+ A tuple of derivatives wrt. each argument in the forward method + """ + _i_sign = ctx.isign + _mode_ordering = ctx.mode_ordering + finufftkwargs = ctx.finufftkwargs + + points, targets = ctx.saved_tensors + + start_points = -(np.array(targets.shape) // 2) + end_points = start_points + targets.shape + slices = tuple(slice(start, end) for start, end in zip(start_points, end_points)) + + # CPU idiosyncracy that needs to be done differently + k_ramps = torch.from_numpy(np.mgrid[slices], dtype=points.dtype) + + grad_points_x = grad_points_y = grad_targets = None + + if ctx.needs_input_grad[0]: + # wrt. points_x + if _mode_ordering != 0: + k_ramps = torch.fft.ifftshift(k_ramps, dim=tuple(range(1, len(k_ramps.shape)))) + + # TODO analytically work out if we can simplify this *1j, + # the below conj, and below *values + ramped_targets = k_ramps * targets[np.newaxis] * 1j * _i_sign + + np_points = points.data.numpy() + np_ramped_targets = ramped_targets.data.numpy() + + grad_points = torch.from_numpy( + finufft.nufft2d2( + *np_points, + np_ramped_targets, + isign=_i_sign, + modeord=_mode_ordering, + **finufftkwargs, + ) + ).conj().to(targets.dtype) + + grad_points = grad_points * grad_output + grad_points = grad_points.real + + if ctx.needs_input_grad[1]: + # wrt. targets + + np_points = points.data.numpy() + np_grad_output = grad_output.data.numpy() + + grad_targets = torch.from_numpy( + finufft.nufft2d1( + *np_points, + np_grad_output, + len(targets), + modeord=_mode_ordering, + isign=-_i_sign, + **finufftkwargs, + ) + ) + + return ( + grad_points, + grad_targets, + None, + None, + None, + ) + diff --git a/tests/test_2d/test_forward_2d.py b/tests/test_2d/test_forward_2d.py index ad6ac50..3a94f65 100644 --- a/tests/test_2d/test_forward_2d.py +++ b/tests/test_2d/test_forward_2d.py @@ -9,6 +9,7 @@ # Case generation Ns = [ + 3, 10, 15, 75, @@ -110,3 +111,36 @@ def test_2d_t2_forward_CPU(N: int) -> None: assert l_inf_error < 1e-5 * N assert l_2_error < 1e-5 * N**2 assert l_1_error < 1e-5 * N**3 + + +@pytest.mark.parametrize("N", Ns) +def test_t2_forward_CPU(N: int) -> None: + """ + Tests against implementations of the FFT by setting up a uniform grid + over which to call FINUFFT through the API. 
+ """ + g = np.mgrid[:N, :N] * 2 * np.pi / N + points = torch.from_numpy(g.reshape(2, -1)) + + targets = torch.randn(*g[0].shape, dtype=torch.complex128) + + print("N is " + str(N)) + print("shape of points is " + str(points.shape)) + print("shape of targets is " + str(targets.shape)) + + finufft_out = pytorch_finufft.functional.finufft_type2.apply( + points, + targets, + ) + + against_torch = torch.fft.fft2(targets) + + abs_errors = torch.abs(finufft_out - against_torch.ravel()) + l_inf_error = abs_errors.max() + l_2_error = torch.sqrt(torch.sum(abs_errors**2)) + l_1_error = torch.sum(abs_errors) + + assert l_inf_error < 4.5e-5 * N + assert l_2_error < 1e-5 * N ** 2 + assert l_1_error < 1e-5 * N ** 3 + From 3da1eefc99052d9cf8848a9083c6cce5387effc1 Mon Sep 17 00:00:00 2001 From: Michael Eickenberg Date: Wed, 11 Oct 2023 12:49:56 -0400 Subject: [PATCH 02/11] FIX make forward function work also for 3D and 1D, add corresponding tests --- pytorch_finufft/functional.py | 10 ++++++---- tests/test_1d/test_forward_1d.py | 33 +++++++++++++++++++++++++++++++ tests/test_3d/test_forward_3d.py | 34 ++++++++++++++++++++++++++++++++ 3 files changed, 73 insertions(+), 4 deletions(-) diff --git a/pytorch_finufft/functional.py b/pytorch_finufft/functional.py index 48e616b..383b280 100644 --- a/pytorch_finufft/functional.py +++ b/pytorch_finufft/functional.py @@ -916,8 +916,9 @@ def forward( ) _mode_ordering = 0 + ndim = points.shape[0] if _mode_ordering == 1: - targets = torch.fft.fftshift(targets, dim=(-2, -1)) + targets = torch.fft.fftshift(targets, dim=tuple(range(-ndim, 0))) ctx.isign = _i_sign @@ -927,15 +928,16 @@ def forward( ctx.save_for_backward(points, targets) - finufft_out = finufft.nufft2d2( + nufft_func = get_nufft_func(ndim, 2, points.device.type) + + finufft_out = nufft_func( *points.data.numpy(), targets.data.numpy(), - #modeord=_mode_ordering, isign=_i_sign, **finufftkwargs, ) - return torch.from_numpy(finufft_out) + return finufft_out @staticmethod def backward( diff --git a/tests/test_1d/test_forward_1d.py b/tests/test_1d/test_forward_1d.py index 5f09c24..8ca8a9d 100644 --- a/tests/test_1d/test_forward_1d.py +++ b/tests/test_1d/test_forward_1d.py @@ -101,3 +101,36 @@ def test_1d_t2_forward_CPU(targets: torch.Tensor): assert torch.norm(finufft_out - against_torch) / N**2 == pytest.approx( 0, abs=1e-05 ) + + +@pytest.mark.parametrize("N", Ns) +def test_t2_forward_CPU(N: int) -> None: + """ + Tests against implementations of the FFT by setting up a uniform grid + over which to call FINUFFT through the API. 
+ """ + g = np.mgrid[:N,] * 2 * np.pi / N + points = torch.from_numpy(g.reshape(g.shape[0], -1)) + + targets = torch.randn(*g[0].shape, dtype=torch.complex128) + + print("N is " + str(N)) + print("shape of points is " + str(points.shape)) + print("shape of targets is " + str(targets.shape)) + + finufft_out = pytorch_finufft.functional.finufft_type2.apply( + points, + targets, + ) + + against_torch = torch.fft.fftn(targets) + + abs_errors = torch.abs(finufft_out - against_torch.ravel()) + l_inf_error = abs_errors.max() + l_2_error = torch.sqrt(torch.sum(abs_errors**2)) + l_1_error = torch.sum(abs_errors) + + assert l_inf_error < 4.5e-5 * N ** 1.1 + assert l_2_error < 6e-5 * N ** 2.1 + assert l_1_error < 1.2e-4 * N ** 3.2 + diff --git a/tests/test_3d/test_forward_3d.py b/tests/test_3d/test_forward_3d.py index a5c945e..f4cbd98 100644 --- a/tests/test_3d/test_forward_3d.py +++ b/tests/test_3d/test_forward_3d.py @@ -9,6 +9,7 @@ # Case generation Ns = [ + 3, 5, 10, 15, @@ -92,3 +93,36 @@ def test_3d_t2_forward_CPU(N: int) -> None: assert l_inf_error < 1e-5 * N**1.5 assert l_2_error < 1e-5 * N**3 assert l_1_error < 1e-5 * N**4.5 + + +@pytest.mark.parametrize("N", Ns) +def test_t2_forward_CPU(N: int) -> None: + """ + Tests against implementations of the FFT by setting up a uniform grid + over which to call FINUFFT through the API. + """ + g = np.mgrid[:N, :N, :N] * 2 * np.pi / N + points = torch.from_numpy(g.reshape(g.shape[0], -1)) + + targets = torch.randn(*g[0].shape, dtype=torch.complex128) + + print("N is " + str(N)) + print("shape of points is " + str(points.shape)) + print("shape of targets is " + str(targets.shape)) + + finufft_out = pytorch_finufft.functional.finufft_type2.apply( + points, + targets, + ) + + against_torch = torch.fft.fftn(targets) + + abs_errors = torch.abs(finufft_out - against_torch.ravel()) + l_inf_error = abs_errors.max() + l_2_error = torch.sqrt(torch.sum(abs_errors**2)) + l_1_error = torch.sum(abs_errors) + + assert l_inf_error < 4.5e-5 * N ** 1.1 + assert l_2_error < 6e-5 * N ** 2.1 + assert l_1_error < 1.2e-4 * N ** 3.2 + From 6186170a99b856f66d51f89aab2fc402b16d1455 Mon Sep 17 00:00:00 2001 From: Michael Eickenberg Date: Wed, 11 Oct 2023 14:18:14 -0400 Subject: [PATCH 03/11] FIX err -> checks --- pytorch_finufft/functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_finufft/functional.py b/pytorch_finufft/functional.py index 383b280..bdba849 100644 --- a/pytorch_finufft/functional.py +++ b/pytorch_finufft/functional.py @@ -902,7 +902,7 @@ def forward( print("In-place results are not yet implemented") # TODO -- extend checks to 2d - err._type2_checks(points, targets) + checks._type2_checks(points, targets) finufftkwargs = {k: v for k, v in finufftkwargs.items()} _mode_ordering = finufftkwargs.pop("modeord", 1) From 1aed7e20079ee60dffdc7a77192d2856c26ad055 Mon Sep 17 00:00:00 2001 From: Michael Eickenberg Date: Wed, 11 Oct 2023 17:11:56 -0400 Subject: [PATCH 04/11] FIX managed to get values and points gradients in 2D to pass tests. 
Do not understand why the finufft output needs to be conjugated and why that can't be replaced with a different choice of isign --- pytorch_finufft/functional.py | 70 +++++++++++++++++-------------- tests/test_2d/test_backward_2d.py | 48 +++++++++++++++++++++ 2 files changed, 87 insertions(+), 31 deletions(-) diff --git a/pytorch_finufft/functional.py b/pytorch_finufft/functional.py index bdba849..af13202 100644 --- a/pytorch_finufft/functional.py +++ b/pytorch_finufft/functional.py @@ -970,57 +970,65 @@ def backward( finufftkwargs = ctx.finufftkwargs points, targets = ctx.saved_tensors + device = points.device - start_points = -(np.array(targets.shape) // 2) - end_points = start_points + targets.shape - slices = tuple(slice(start, end) for start, end in zip(start_points, end_points)) + # start_points = -(np.array(targets.shape) // 2) + # end_points = start_points + targets.shape + # slices = tuple(slice(start, end) for start, end in zip(start_points, end_points)) - # CPU idiosyncracy that needs to be done differently - k_ramps = torch.from_numpy(np.mgrid[slices], dtype=points.dtype) + # # CPU idiosyncracy that needs to be done differently + # k_ramps = torch.from_numpy(np.mgrid[slices], dtype=points.dtype) - grad_points_x = grad_points_y = grad_targets = None + grad_points = grad_targets = None + + ## From type 1, commenting for now to understand whether needed + # if any(ctx.needs_input_grad) and _mode_ordering: + # grad_output = torch.fft.fftshift(grad_output) if ctx.needs_input_grad[0]: - # wrt. points_x - if _mode_ordering != 0: - k_ramps = torch.fft.ifftshift(k_ramps, dim=tuple(range(1, len(k_ramps.shape)))) + # wrt points + start_points = -(torch.tensor(targets.shape, device=device) // 2) + end_points = start_points + torch.tensor(targets.shape, device=device) + coord_ramps = torch.stack( + torch.meshgrid( + *( + torch.arange(start, end, device=device) + for start, end in zip(start_points, end_points) + ), + indexing="ij", + ) + ) - # TODO analytically work out if we can simplify this *1j, - # the below conj, and below *values - ramped_targets = k_ramps * targets[np.newaxis] * 1j * _i_sign + ndim = points.shape[0] - np_points = points.data.numpy() - np_ramped_targets = ramped_targets.data.numpy() + if ctx.needs_input_grad[0]: + ramped_targets = coord_ramps * targets[np.newaxis] * 1j * _i_sign + nufft_func = get_nufft_func(ndim, 2, points.device.type) - grad_points = torch.from_numpy( - finufft.nufft2d2( - *np_points, - np_ramped_targets, + grad_points = nufft_func( + *points, + ramped_targets, isign=_i_sign, - modeord=_mode_ordering, + #modeord=_mode_ordering, **finufftkwargs, - ) - ).conj().to(targets.dtype) + ).conj() # Currently don't really get why this is hard to replace with a flipped isign grad_points = grad_points * grad_output grad_points = grad_points.real if ctx.needs_input_grad[1]: # wrt. 
targets + nufft_func = get_nufft_func(ndim, 1, points.device.type) - np_points = points.data.numpy() - np_grad_output = grad_output.data.numpy() - - grad_targets = torch.from_numpy( - finufft.nufft2d1( - *np_points, - np_grad_output, - len(targets), - modeord=_mode_ordering, + grad_targets = nufft_func( + *points, + grad_output, + targets.shape, + #modeord=_mode_ordering, isign=-_i_sign, **finufftkwargs, ) - ) + return ( grad_points, diff --git a/tests/test_2d/test_backward_2d.py b/tests/test_2d/test_backward_2d.py index 903595c..7d67dce 100644 --- a/tests/test_2d/test_backward_2d.py +++ b/tests/test_2d/test_backward_2d.py @@ -202,3 +202,51 @@ def test_t2_backward_CPU_points_y( inputs = (points_x, points_y, targets) assert gradcheck(apply_finufft2d2(fftshift, isign), inputs) + + + +def check_t2_backward( + N: int, + modifier: int, + fftshift: bool, + isign: int, + device: str, + points_or_targets: bool, +) -> None: + points = torch.rand((2, N + modifier), dtype=torch.float64).to(device) * 2 * np.pi + targets = torch.randn(N, N, dtype=torch.complex128).to(device) + + points.requires_grad = points_or_targets + targets.requires_grad = not points_or_targets + + inputs = (points, targets) + + def func(points, targets): + return pytorch_finufft.functional.finufft_type2.apply( + points, + targets, + None, + dict(modeord=int(not fftshift), isign=isign), + ) + + assert gradcheck(func, inputs, atol=1e-5 * N) + + +@pytest.mark.parametrize("N", Ns) +@pytest.mark.parametrize("modifier", length_modifiers) +@pytest.mark.parametrize("fftshift", [False, True]) +@pytest.mark.parametrize("isign", [-1, 1]) +def test_t2_backward_CPU_values( + N: int, modifier: int, fftshift: bool, isign: int +) -> None: + check_t2_backward(N, modifier, fftshift, isign, "cpu", False) + +@pytest.mark.parametrize("N", Ns) +@pytest.mark.parametrize("modifier", length_modifiers) +@pytest.mark.parametrize("fftshift", [False, True]) +@pytest.mark.parametrize("isign", [-1, 1]) +def test_t2_backward_CPU_points( + N: int, modifier: int, fftshift: bool, isign: int +) -> None: + check_t2_backward(N, modifier, fftshift, isign, "cpu", True) + From c329e01566ee377887f7601d99302b456d5c23c7 Mon Sep 17 00:00:00 2001 From: Michael Eickenberg Date: Thu, 12 Oct 2023 12:20:12 -0400 Subject: [PATCH 05/11] TST add mode ordering to type2 forward test 2d --- tests/test_2d/test_forward_2d.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/test_2d/test_forward_2d.py b/tests/test_2d/test_forward_2d.py index 3a94f65..ba84f45 100644 --- a/tests/test_2d/test_forward_2d.py +++ b/tests/test_2d/test_forward_2d.py @@ -114,7 +114,8 @@ def test_2d_t2_forward_CPU(N: int) -> None: @pytest.mark.parametrize("N", Ns) -def test_t2_forward_CPU(N: int) -> None: +@pytest.mark.parametrize("fftshift", [False, True]) +def test_t2_forward_CPU(N: int, fftshift: bool) -> None: """ Tests against implementations of the FFT by setting up a uniform grid over which to call FINUFFT through the API. 
@@ -131,9 +132,14 @@ def test_t2_forward_CPU(N: int) -> None: finufft_out = pytorch_finufft.functional.finufft_type2.apply( points, targets, + None, + {'modeord': int(not fftshift)}, ) - against_torch = torch.fft.fft2(targets) + if fftshift: + against_torch = torch.fft.fft2(torch.fft.ifftshift(targets)) + else: + against_torch = torch.fft.fft2(targets) abs_errors = torch.abs(finufft_out - against_torch.ravel()) l_inf_error = abs_errors.max() From 729dac03adfdf1bcccde6d03871fed6610cc73c5 Mon Sep 17 00:00:00 2001 From: Michael Eickenberg Date: Thu, 12 Oct 2023 15:05:00 -0400 Subject: [PATCH 06/11] FIX 2D backward working --- pytorch_finufft/functional.py | 24 ++++++++++-------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/pytorch_finufft/functional.py b/pytorch_finufft/functional.py index af13202..a5ce56f 100644 --- a/pytorch_finufft/functional.py +++ b/pytorch_finufft/functional.py @@ -850,7 +850,6 @@ def forward( points: torch.Tensor, targets: torch.Tensor, out: Optional[torch.Tensor] = None, - fftshift: bool = False, finufftkwargs: Dict[str, Union[int, float]] = {}, ) -> torch.Tensor: """ @@ -908,13 +907,13 @@ def forward( _mode_ordering = finufftkwargs.pop("modeord", 1) _i_sign = finufftkwargs.pop("isign", -1) - if fftshift: - if _mode_ordering != 1: # This seems like it is the wrong way round??????? - raise ValueError( - "Double specification of ordering; only one of fftshift and " - "modeord should be provided." - ) - _mode_ordering = 0 + # if fftshift: + # if _mode_ordering != 1: # This seems like it is the wrong way round??????? + # raise ValueError( + # "Double specification of ordering; only one of fftshift and " + # "modeord should be provided." + # ) + # _mode_ordering = 0 ndim = points.shape[0] if _mode_ordering == 1: @@ -923,7 +922,6 @@ def forward( ctx.isign = _i_sign ctx.mode_ordering = _mode_ordering - ctx.fftshift = fftshift ctx.finufftkwargs = finufftkwargs ctx.save_for_backward(points, targets) @@ -980,10 +978,8 @@ def backward( # k_ramps = torch.from_numpy(np.mgrid[slices], dtype=points.dtype) grad_points = grad_targets = None + ndim = points.shape[0] - ## From type 1, commenting for now to understand whether needed - # if any(ctx.needs_input_grad) and _mode_ordering: - # grad_output = torch.fft.fftshift(grad_output) if ctx.needs_input_grad[0]: # wrt points @@ -999,8 +995,6 @@ def backward( ) ) - ndim = points.shape[0] - if ctx.needs_input_grad[0]: ramped_targets = coord_ramps * targets[np.newaxis] * 1j * _i_sign nufft_func = get_nufft_func(ndim, 2, points.device.type) @@ -1029,6 +1023,8 @@ def backward( **finufftkwargs, ) + if _mode_ordering == 1: + grad_targets = torch.fft.ifftshift(grad_targets, dim=tuple(range(-ndim, 0))) return ( grad_points, From 8fe4eab62e41c5bf0ffbd8b42b365fd4c1431084 Mon Sep 17 00:00:00 2001 From: Michael Eickenberg Date: Thu, 12 Oct 2023 15:08:52 -0400 Subject: [PATCH 07/11] TST backward type2 3d working --- tests/test_3d/test_backward_3d.py | 48 +++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/tests/test_3d/test_backward_3d.py b/tests/test_3d/test_backward_3d.py index 8dd41ea..50fa9b6 100644 --- a/tests/test_3d/test_backward_3d.py +++ b/tests/test_3d/test_backward_3d.py @@ -240,3 +240,51 @@ def test_t2_backward_CPU_points_z( inputs = (points_x, points_y, points_z, targets) assert gradcheck(apply_finufft3d2(fftshift, isign), inputs) + + + +def check_t2_backward( + N: int, + modifier: int, + fftshift: bool, + isign: int, + device: str, + points_or_targets: bool, +) -> None: + points = 
torch.rand((3, N + modifier), dtype=torch.float64).to(device) * 2 * np.pi + targets = torch.randn(N, N, N, dtype=torch.complex128).to(device) + + points.requires_grad = points_or_targets + targets.requires_grad = not points_or_targets + + inputs = (points, targets) + + def func(points, targets): + return pytorch_finufft.functional.finufft_type2.apply( + points, + targets, + None, + dict(modeord=int(not fftshift), isign=isign), + ) + + assert gradcheck(func, inputs, atol=1e-5 * N) + + +@pytest.mark.parametrize("N", Ns) +@pytest.mark.parametrize("modifier", length_modifiers) +@pytest.mark.parametrize("fftshift", [False, True]) +@pytest.mark.parametrize("isign", [-1, 1]) +def test_t2_backward_CPU_values( + N: int, modifier: int, fftshift: bool, isign: int +) -> None: + check_t2_backward(N, modifier, fftshift, isign, "cpu", False) + +@pytest.mark.parametrize("N", Ns) +@pytest.mark.parametrize("modifier", length_modifiers) +@pytest.mark.parametrize("fftshift", [False, True]) +@pytest.mark.parametrize("isign", [-1, 1]) +def test_t2_backward_CPU_points( + N: int, modifier: int, fftshift: bool, isign: int +) -> None: + check_t2_backward(N, modifier, fftshift, isign, "cpu", True) + From 5db77b2304781d320fed2e2b4952b5f7047522fc Mon Sep 17 00:00:00 2001 From: Michael Eickenberg Date: Thu, 12 Oct 2023 15:19:36 -0400 Subject: [PATCH 08/11] FIX backward 1d type 2 working --- pytorch_finufft/functional.py | 4 +-- tests/test_1d/test_backward_1d.py | 47 +++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 2 deletions(-) diff --git a/pytorch_finufft/functional.py b/pytorch_finufft/functional.py index a5ce56f..daf1ef0 100644 --- a/pytorch_finufft/functional.py +++ b/pytorch_finufft/functional.py @@ -1001,14 +1001,14 @@ def backward( grad_points = nufft_func( *points, - ramped_targets, + ramped_targets.squeeze(), isign=_i_sign, #modeord=_mode_ordering, **finufftkwargs, ).conj() # Currently don't really get why this is hard to replace with a flipped isign grad_points = grad_points * grad_output - grad_points = grad_points.real + grad_points = torch.atleast_2d(grad_points.real) if ctx.needs_input_grad[1]: # wrt. 
targets diff --git a/tests/test_1d/test_backward_1d.py b/tests/test_1d/test_backward_1d.py index 43793ed..271c353 100644 --- a/tests/test_1d/test_backward_1d.py +++ b/tests/test_1d/test_backward_1d.py @@ -161,3 +161,50 @@ def test_t2_backward_CPU_points( inputs = (points, targets) assert gradcheck(apply_finufft1d2(fftshift, isign), inputs, eps=1e-8, atol=1e-5 * N) + + +def check_t2_backward( + N: int, + modifier: int, + fftshift: bool, + isign: int, + device: str, + points_or_targets: bool, +) -> None: + points = torch.rand((1, N + modifier), dtype=torch.float64).to(device) * 2 * np.pi + targets = torch.randn(N, dtype=torch.complex128).to(device) + + points.requires_grad = points_or_targets + targets.requires_grad = not points_or_targets + + inputs = (points, targets) + + def func(points, targets): + return pytorch_finufft.functional.finufft_type2.apply( + points, + targets, + None, + dict(modeord=int(not fftshift), isign=isign), + ) + + assert gradcheck(func, inputs, atol=1.5e-4 * N) + + +@pytest.mark.parametrize("N", Ns) +@pytest.mark.parametrize("modifier", length_modifiers) +@pytest.mark.parametrize("fftshift", [False, True]) +@pytest.mark.parametrize("isign", [-1, 1]) +def test_t2_backward_CPU_values( + N: int, modifier: int, fftshift: bool, isign: int +) -> None: + check_t2_backward(N, modifier, fftshift, isign, "cpu", False) + +@pytest.mark.parametrize("N", Ns) +@pytest.mark.parametrize("modifier", length_modifiers) +@pytest.mark.parametrize("fftshift", [False, True]) +@pytest.mark.parametrize("isign", [-1, 1]) +def test_t2_backward_CPU_points( + N: int, modifier: int, fftshift: bool, isign: int +) -> None: + check_t2_backward(N, modifier, fftshift, isign, "cpu", True) + From f53348c9f21fd255d0495fbfee32f5b1a9308e86 Mon Sep 17 00:00:00 2001 From: Michael Eickenberg Date: Thu, 12 Oct 2023 16:28:46 -0400 Subject: [PATCH 09/11] FIX apply suggestions from Brian's PR comments including putting coordinate ramps in a helper --- pytorch_finufft/functional.py | 106 +++++++++++----------------------- 1 file changed, 35 insertions(+), 71 deletions(-) diff --git a/pytorch_finufft/functional.py b/pytorch_finufft/functional.py index daf1ef0..3e423bb 100644 --- a/pytorch_finufft/functional.py +++ b/pytorch_finufft/functional.py @@ -712,6 +712,21 @@ def f(*args, **kwargs): return f +def coordinate_ramps(shape, device): + start_points = -(torch.tensor(shape, device=device) // 2) + end_points = start_points + torch.tensor(shape, device=device) + coord_ramps = torch.stack( + torch.meshgrid( + *( + torch.arange(start, end, device=device) + for start, end in zip(start_points, end_points) + ), + indexing="ij", + ) + ) + + return coord_ramps + class finufft_type1(torch.autograd.Function): @staticmethod def forward( # type: ignore[override] @@ -799,17 +814,7 @@ def backward( # type: ignore[override] if ctx.needs_input_grad[0]: # wrt points - start_points = -(torch.tensor(grad_output.shape, device=device) // 2) - end_points = start_points + torch.tensor(grad_output.shape, device=device) - coord_ramps = torch.stack( - torch.meshgrid( - *( - torch.arange(start, end, device=device) - for start, end in zip(start_points, end_points) - ), - indexing="ij", - ) - ) + coord_ramps = coordinate_ramps(grad_output.shape, device) # we can't batch in 1d case so we squeeze and fix up the ouput later ramped_grad_output = ( @@ -850,51 +855,37 @@ def forward( points: torch.Tensor, targets: torch.Tensor, out: Optional[torch.Tensor] = None, - finufftkwargs: Dict[str, Union[int, float]] = {}, + finufftkwargs: Dict[str, 
Union[int, float]] = None, ) -> torch.Tensor: """ Evaluates the Type 2 NUFFT on the inputs. NOTE: By default, the ordering is set to match that of Pytorch, Numpy, and Scipy's FFT APIs. To match the mode ordering - native to FINUFFT, set fftshift=True. - - ``` - c[j] = SUM f[k1, k2] exp(+/-i (k1 x(j) + k2 y(j))) - k1, k2 - - for j = 0, ..., M-1, where the sum is over -N1/2 <= k1 <= (N1-1)/2, - -N2/2 <= k2 <= (N2-1)/2 - ``` + native to FINUFFT, add {'modeord': 0} to finufftkwargs. Parameters ---------- ctx : Any Pytorch context objecy - points_x : torch.Tensor - The non-uniform points x_j - points_y : torch.Tensor - The non-uniform points y_j + points : torch.Tensor, shape=(ndim, num_points) + The non-uniform points x targets : torch.Tensor - The target Fourier mode coefficients f[k1, k2] + The values on the input grid out : Optional[torch.Tensor], optional Array to take the result in-place, by default None - fftshift : bool - If True centers the 0 mode in the resultant torch.Tensor, by default False finufftkwargs : Dict[str, Union[int, float]] Additional arguments will be passed into FINUFFT. See - https://finufft.readthedocs.io/en/latest/python.html. By default - an empty dictionary + https://finufft.readthedocs.io/en/latest/python.html. Returns ------- torch.Tensor - The resultant array c[j] + The Fourier transform of the targets grid evaluated at the points `points` Raises ------ - ValueError - In the case of conflicting specification of the wave-mode ordering. + """ if out is not None: @@ -903,21 +894,17 @@ def forward( # TODO -- extend checks to 2d checks._type2_checks(points, targets) - finufftkwargs = {k: v for k, v in finufftkwargs.items()} - _mode_ordering = finufftkwargs.pop("modeord", 1) - _i_sign = finufftkwargs.pop("isign", -1) - # if fftshift: - # if _mode_ordering != 1: # This seems like it is the wrong way round??????? - # raise ValueError( - # "Double specification of ordering; only one of fftshift and " - # "modeord should be provided." 
- # ) - # _mode_ordering = 0 + if finufftkwargs is None: + finufftkwargs = dict() + + finufftkwargs = {k: v for k, v in finufftkwargs.items()} + _mode_ordering = finufftkwargs.pop("modeord", 1) # not finufft default, but corresponds to pytorch default + _i_sign = finufftkwargs.pop("isign", -1) # isign=-1 is finufft default for type 2 ndim = points.shape[0] if _mode_ordering == 1: - targets = torch.fft.fftshift(targets, dim=tuple(range(-ndim, 0))) + targets = torch.fft.fftshift(targets) ctx.isign = _i_sign @@ -929,8 +916,8 @@ def forward( nufft_func = get_nufft_func(ndim, 2, points.device.type) finufft_out = nufft_func( - *points.data.numpy(), - targets.data.numpy(), + *points, + targets, isign=_i_sign, **finufftkwargs, ) @@ -970,32 +957,11 @@ def backward( points, targets = ctx.saved_tensors device = points.device - # start_points = -(np.array(targets.shape) // 2) - # end_points = start_points + targets.shape - # slices = tuple(slice(start, end) for start, end in zip(start_points, end_points)) - - # # CPU idiosyncracy that needs to be done differently - # k_ramps = torch.from_numpy(np.mgrid[slices], dtype=points.dtype) - grad_points = grad_targets = None ndim = points.shape[0] - - if ctx.needs_input_grad[0]: - # wrt points - start_points = -(torch.tensor(targets.shape, device=device) // 2) - end_points = start_points + torch.tensor(targets.shape, device=device) - coord_ramps = torch.stack( - torch.meshgrid( - *( - torch.arange(start, end, device=device) - for start, end in zip(start_points, end_points) - ), - indexing="ij", - ) - ) - if ctx.needs_input_grad[0]: + coord_ramps = coordinate_ramps(targets.shape, device=device) ramped_targets = coord_ramps * targets[np.newaxis] * 1j * _i_sign nufft_func = get_nufft_func(ndim, 2, points.device.type) @@ -1003,7 +969,6 @@ def backward( *points, ramped_targets.squeeze(), isign=_i_sign, - #modeord=_mode_ordering, **finufftkwargs, ).conj() # Currently don't really get why this is hard to replace with a flipped isign @@ -1018,13 +983,12 @@ def backward( *points, grad_output, targets.shape, - #modeord=_mode_ordering, isign=-_i_sign, **finufftkwargs, ) if _mode_ordering == 1: - grad_targets = torch.fft.ifftshift(grad_targets, dim=tuple(range(-ndim, 0))) + grad_targets = torch.fft.ifftshift(grad_targets) return ( grad_points, From b7233e3c8685a5edd56457e0445992a9472524c2 Mon Sep 17 00:00:00 2001 From: Michael Eickenberg Date: Thu, 12 Oct 2023 16:29:10 -0400 Subject: [PATCH 10/11] FIX address PR comments for tests --- tests/test_1d/test_backward_1d.py | 20 +++++++++++++++++++- tests/test_2d/test_backward_2d.py | 18 ++++++++++++++++++ tests/test_3d/test_backward_3d.py | 18 ++++++++++++++++++ 3 files changed, 55 insertions(+), 1 deletion(-) diff --git a/tests/test_1d/test_backward_1d.py b/tests/test_1d/test_backward_1d.py index 271c353..fefc70f 100644 --- a/tests/test_1d/test_backward_1d.py +++ b/tests/test_1d/test_backward_1d.py @@ -187,7 +187,7 @@ def func(points, targets): dict(modeord=int(not fftshift), isign=isign), ) - assert gradcheck(func, inputs, atol=1.5e-4 * N) + assert gradcheck(func, inputs, atol=5e-3 * N) @pytest.mark.parametrize("N", Ns) @@ -208,3 +208,21 @@ def test_t2_backward_CPU_points( ) -> None: check_t2_backward(N, modifier, fftshift, isign, "cpu", True) +@pytest.mark.parametrize("N", Ns) +@pytest.mark.parametrize("modifier", length_modifiers) +@pytest.mark.parametrize("fftshift", [False, True]) +@pytest.mark.parametrize("isign", [-1, 1]) +def test_t2_backward_CPU_values( + N: int, modifier: int, fftshift: bool, isign: int +) -> None: 
+ check_t2_backward(N, modifier, fftshift, isign, "cuda", False) + +@pytest.mark.parametrize("N", Ns) +@pytest.mark.parametrize("modifier", length_modifiers) +@pytest.mark.parametrize("fftshift", [False, True]) +@pytest.mark.parametrize("isign", [-1, 1]) +def test_t2_backward_CPU_points( + N: int, modifier: int, fftshift: bool, isign: int +) -> None: + check_t2_backward(N, modifier, fftshift, isign, "cuda", True) + diff --git a/tests/test_2d/test_backward_2d.py b/tests/test_2d/test_backward_2d.py index 7d67dce..7650769 100644 --- a/tests/test_2d/test_backward_2d.py +++ b/tests/test_2d/test_backward_2d.py @@ -250,3 +250,21 @@ def test_t2_backward_CPU_points( ) -> None: check_t2_backward(N, modifier, fftshift, isign, "cpu", True) +@pytest.mark.parametrize("N", Ns) +@pytest.mark.parametrize("modifier", length_modifiers) +@pytest.mark.parametrize("fftshift", [False, True]) +@pytest.mark.parametrize("isign", [-1, 1]) +def test_t2_backward_CPU_values( + N: int, modifier: int, fftshift: bool, isign: int +) -> None: + check_t2_backward(N, modifier, fftshift, isign, "cuda", False) + +@pytest.mark.parametrize("N", Ns) +@pytest.mark.parametrize("modifier", length_modifiers) +@pytest.mark.parametrize("fftshift", [False, True]) +@pytest.mark.parametrize("isign", [-1, 1]) +def test_t2_backward_CPU_points( + N: int, modifier: int, fftshift: bool, isign: int +) -> None: + check_t2_backward(N, modifier, fftshift, isign, "cuda", True) + diff --git a/tests/test_3d/test_backward_3d.py b/tests/test_3d/test_backward_3d.py index 50fa9b6..5377c25 100644 --- a/tests/test_3d/test_backward_3d.py +++ b/tests/test_3d/test_backward_3d.py @@ -288,3 +288,21 @@ def test_t2_backward_CPU_points( ) -> None: check_t2_backward(N, modifier, fftshift, isign, "cpu", True) +@pytest.mark.parametrize("N", Ns) +@pytest.mark.parametrize("modifier", length_modifiers) +@pytest.mark.parametrize("fftshift", [False, True]) +@pytest.mark.parametrize("isign", [-1, 1]) +def test_t2_backward_CPU_values( + N: int, modifier: int, fftshift: bool, isign: int +) -> None: + check_t2_backward(N, modifier, fftshift, isign, "cuda", False) + +@pytest.mark.parametrize("N", Ns) +@pytest.mark.parametrize("modifier", length_modifiers) +@pytest.mark.parametrize("fftshift", [False, True]) +@pytest.mark.parametrize("isign", [-1, 1]) +def test_t2_backward_CPU_points( + N: int, modifier: int, fftshift: bool, isign: int +) -> None: + check_t2_backward(N, modifier, fftshift, isign, "cuda", True) + From 237ed96a8d64b0a228d9fab693aadf727dd3d445 Mon Sep 17 00:00:00 2001 From: Michael Eickenberg Date: Thu, 12 Oct 2023 16:44:58 -0400 Subject: [PATCH 11/11] FIX gave the cuda tests the correct name (CPU->cuda) --- tests/test_1d/test_backward_1d.py | 4 ++-- tests/test_2d/test_backward_2d.py | 4 ++-- tests/test_3d/test_backward_3d.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_1d/test_backward_1d.py b/tests/test_1d/test_backward_1d.py index fefc70f..988c578 100644 --- a/tests/test_1d/test_backward_1d.py +++ b/tests/test_1d/test_backward_1d.py @@ -212,7 +212,7 @@ def test_t2_backward_CPU_points( @pytest.mark.parametrize("modifier", length_modifiers) @pytest.mark.parametrize("fftshift", [False, True]) @pytest.mark.parametrize("isign", [-1, 1]) -def test_t2_backward_CPU_values( +def test_t2_backward_cuda_values( N: int, modifier: int, fftshift: bool, isign: int ) -> None: check_t2_backward(N, modifier, fftshift, isign, "cuda", False) @@ -221,7 +221,7 @@ def test_t2_backward_CPU_values( @pytest.mark.parametrize("modifier", 
length_modifiers) @pytest.mark.parametrize("fftshift", [False, True]) @pytest.mark.parametrize("isign", [-1, 1]) -def test_t2_backward_CPU_points( +def test_t2_backward_cuda_points( N: int, modifier: int, fftshift: bool, isign: int ) -> None: check_t2_backward(N, modifier, fftshift, isign, "cuda", True) diff --git a/tests/test_2d/test_backward_2d.py b/tests/test_2d/test_backward_2d.py index 7650769..88a51bf 100644 --- a/tests/test_2d/test_backward_2d.py +++ b/tests/test_2d/test_backward_2d.py @@ -254,7 +254,7 @@ def test_t2_backward_CPU_points( @pytest.mark.parametrize("modifier", length_modifiers) @pytest.mark.parametrize("fftshift", [False, True]) @pytest.mark.parametrize("isign", [-1, 1]) -def test_t2_backward_CPU_values( +def test_t2_backward_cuda_values( N: int, modifier: int, fftshift: bool, isign: int ) -> None: check_t2_backward(N, modifier, fftshift, isign, "cuda", False) @@ -263,7 +263,7 @@ def test_t2_backward_CPU_values( @pytest.mark.parametrize("modifier", length_modifiers) @pytest.mark.parametrize("fftshift", [False, True]) @pytest.mark.parametrize("isign", [-1, 1]) -def test_t2_backward_CPU_points( +def test_t2_backward_cuda_points( N: int, modifier: int, fftshift: bool, isign: int ) -> None: check_t2_backward(N, modifier, fftshift, isign, "cuda", True) diff --git a/tests/test_3d/test_backward_3d.py b/tests/test_3d/test_backward_3d.py index 5377c25..3c63014 100644 --- a/tests/test_3d/test_backward_3d.py +++ b/tests/test_3d/test_backward_3d.py @@ -292,7 +292,7 @@ def test_t2_backward_CPU_points( @pytest.mark.parametrize("modifier", length_modifiers) @pytest.mark.parametrize("fftshift", [False, True]) @pytest.mark.parametrize("isign", [-1, 1]) -def test_t2_backward_CPU_values( +def test_t2_backward_cuda_values( N: int, modifier: int, fftshift: bool, isign: int ) -> None: check_t2_backward(N, modifier, fftshift, isign, "cuda", False) @@ -301,7 +301,7 @@ def test_t2_backward_CPU_values( @pytest.mark.parametrize("modifier", length_modifiers) @pytest.mark.parametrize("fftshift", [False, True]) @pytest.mark.parametrize("isign", [-1, 1]) -def test_t2_backward_CPU_points( +def test_t2_backward_cuda_points( N: int, modifier: int, fftshift: bool, isign: int ) -> None: check_t2_backward(N, modifier, fftshift, isign, "cuda", True)
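
A note on the conjugation flagged in patch 04: with s = isign, the type-2 output is c[j] = SUM_k f[k] exp(i s k.x[j]), so dc[j]/dx[j] is itself a type-2 transform with the same isign applied to the ramped coefficients i s k f[k], which is the ramped_targets NUFFT computed in backward. Because the points are real while c[j] is complex, the chain rule for the point gradient takes Re(conj(dc[j]/dx[j]) * grad_output[j]); the conjugate is accounted for by this real-input/complex-output step, not by the transform itself. It also cannot be traded for a sign flip alone: the conjugate of a type-2 transform with isign s equals a type-2 transform with isign -s applied to the conjugated inputs, so flipping isign without also conjugating the ramped coefficients computes a different quantity. Conjugating the output, as the code does, is the simpler of the two equivalent forms.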
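
The snippet below is a minimal usage sketch distilled from the new forward and backward tests, not part of the patches themselves. It assumes the interface as it stands at the end of the series, finufft_type2.apply(points, targets, out=None, finufftkwargs=None), with the defaults modeord=1 (PyTorch/NumPy mode ordering) and isign=-1; it checks the result against torch.fft.fft2 on a uniform grid and then runs gradcheck on the grid values.

```
# Minimal sketch of the new type-2 interface, mirroring the added tests.
import numpy as np
import torch

import pytorch_finufft.functional

N = 16

# Place the "non-uniform" points on a uniform 2D grid so the type-2 transform
# can be checked against an ordinary FFT of the same grid values.
g = np.mgrid[:N, :N] * 2 * np.pi / N
points = torch.from_numpy(g.reshape(2, -1))           # (ndim, num_points), float64
targets = torch.randn(N, N, dtype=torch.complex128)   # Fourier coefficients on the grid

out = pytorch_finufft.functional.finufft_type2.apply(points, targets)

# With the default ordering this matches torch.fft.fft2 up to NUFFT accuracy.
reference = torch.fft.fft2(targets).ravel()
print(torch.abs(out - reference).max())

# Gradient check with respect to the grid values, as in the backward tests.
targets.requires_grad_(True)
torch.autograd.gradcheck(
    lambda t: pytorch_finufft.functional.finufft_type2.apply(points, t),
    (targets,),
    atol=1e-5 * N,
)
```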