From 83ab0ba0161ea78e24146a034d967abc37ccaebe Mon Sep 17 00:00:00 2001 From: Michael Eickenberg Date: Thu, 5 Oct 2023 00:11:04 -0400 Subject: [PATCH 01/10] FIX tighten the forward tests according to observed error wrt array size --- tests/test_1d/test_forward_1d.py | 10 ++++++++ tests/test_2d/test_forward_2d.py | 39 ++++++++++++-------------------- tests/test_3d/test_forward_3d.py | 23 ++++++++++++++----- 3 files changed, 42 insertions(+), 30 deletions(-) diff --git a/tests/test_1d/test_forward_1d.py b/tests/test_1d/test_forward_1d.py index 1bbd8d0..9050c52 100644 --- a/tests/test_1d/test_forward_1d.py +++ b/tests/test_1d/test_forward_1d.py @@ -67,6 +67,16 @@ def test_1d_t1_forward_CPU(values: torch.Tensor) -> None: ) == pytest.approx(0, abs=1e-06) + abs_errors = torch.abs(finufft1D1_out - against_torch) + l_inf_error = abs_errors.max() + l_2_error = torch.sqrt(torch.sum(abs_errors**2)) + l_1_error = torch.sum(abs_errors) + + assert l_inf_error < 3.5e-3 * N ** .6 + assert l_2_error < 7.5e-4 * N ** 1.1 + assert l_1_error < 5e-4 * N ** 1.6 + + @pytest.mark.parametrize("targets", cases) def test_1d_t2_forward_CPU(targets: torch.Tensor): """ diff --git a/tests/test_2d/test_forward_2d.py b/tests/test_2d/test_forward_2d.py index 3a61938..5cd74dd 100644 --- a/tests/test_2d/test_forward_2d.py +++ b/tests/test_2d/test_forward_2d.py @@ -45,28 +45,14 @@ def test_2d_t1_forward_CPU(N: int) -> None: against_torch = torch.fft.fft2(values.reshape(g[0].shape)) - assert abs((finufft_out - against_torch).sum()) / (N**3) == pytest.approx( - 0, abs=1e-6 - ) + abs_errors = torch.abs(finufft_out - against_torch) + l_inf_error = abs_errors.max() + l_2_error = torch.sqrt(torch.sum(abs_errors**2)) + l_1_error = torch.sum(abs_errors) - values = torch.randn(*x.shape, dtype=torch.complex64) - - finufft_out = pytorch_finufft.functional.finufft2D1.apply( - torch.from_numpy(x).to(torch.float32), - torch.from_numpy(y).to(torch.float32), - values, - N, - ) - - against_torch = torch.fft.fft2(values.reshape(g[0].shape)) - - # NOTE -- the below tolerance is set to 1e-5 instead of -6 due - # to the occasional failing case that seems to be caused by - # the randomness of the test cases in addition to the expected - # accruation of numerical inaccuracies - assert abs((finufft_out - against_torch).sum()) / (N**3) == pytest.approx( - 0, abs=1e-5 - ) + assert l_inf_error < 5e-5 * N + assert l_2_error < 1e-5 * N ** 2 + assert l_1_error < 1e-5 * N ** 3 @pytest.mark.parametrize("N", Ns) @@ -109,9 +95,14 @@ def test_2d_t2_forward_CPU(N: int) -> None: against_torch = torch.fft.ifft2(values) - assert abs((finufft_out - against_torch).sum()) / (N**3) == pytest.approx( - 0, abs=1e-6 - ) + abs_errors = torch.abs(finufft_out - against_torch) + l_inf_error = abs_errors.max() + l_2_error = torch.sqrt(torch.sum(abs_errors**2)) + l_1_error = torch.sum(abs_errors) + + assert l_inf_error < 1e-5 * N + assert l_2_error < 1e-5 * N ** 2 + assert l_1_error < 1e-5 * N ** 3 # @pytest.mark.parametrize("N", Ns) diff --git a/tests/test_3d/test_forward_3d.py b/tests/test_3d/test_forward_3d.py index e470abc..ae8f5f6 100644 --- a/tests/test_3d/test_forward_3d.py +++ b/tests/test_3d/test_forward_3d.py @@ -41,9 +41,15 @@ def test_3d_t1_forward_CPU(N: int) -> None: against_torch = torch.fft.fftn(values.reshape(g[0].shape)) - assert abs((finufft_out - against_torch).sum()) / (N**4) == pytest.approx( - 0, abs=1e-6 - ) + abs_errors = torch.abs(finufft_out - against_torch) + l_inf_error = abs_errors.max() + l_2_error = torch.sqrt(torch.sum(abs_errors**2)) + l_1_error 
= torch.sum(abs_errors) + + assert l_inf_error < 2e-5 * N ** 1.5 + assert l_2_error < 1e-5 * N ** 3 + assert l_1_error < 1e-5 * N ** 4.5 + @pytest.mark.parametrize("N", Ns) @@ -69,6 +75,11 @@ def test_3d_t2_forward_CPU(N: int) -> None: against_torch = torch.fft.ifftn(values) - assert (abs((finufft_out - against_torch).sum())) / (N**4) == pytest.approx( - 0, abs=1e-6 - ) + abs_errors = torch.abs(finufft_out - against_torch) + l_inf_error = abs_errors.max() + l_2_error = torch.sqrt(torch.sum(abs_errors**2)) + l_1_error = torch.sum(abs_errors) + + assert l_inf_error < 1e-5 * N ** 1.5 + assert l_2_error < 1e-5 * N ** 3 + assert l_1_error < 1e-5 * N ** 4.5 From 937d7ef5882c4ea9c44991ca156d81d2ab83ac94 Mon Sep 17 00:00:00 2001 From: Michael Eickenberg Date: Wed, 4 Oct 2023 12:36:16 -0400 Subject: [PATCH 02/10] WIP outline of consolidation idea for type 1 --- pytorch_finufft/functional.py | 148 ++++++++++++++++++++++++++++++++++ 1 file changed, 148 insertions(+) diff --git a/pytorch_finufft/functional.py b/pytorch_finufft/functional.py index 96551b5..a63289e 100644 --- a/pytorch_finufft/functional.py +++ b/pytorch_finufft/functional.py @@ -1592,3 +1592,151 @@ def backward( None, None, ) + + + + + +############################################################################### +# Consolidated forward function for all 1D, 2D, and 3D problems for nufft type 1 +############################################################################### + +# This function takes a ctx object and is supposed to later replace all type 1 +# functions above for all dimensionalities. + +def finufft_type1( + ctx: Any, + points: torch.Tensor, + values: torch.Tensor, + output_shape: Union[int, tuple[int, int], tuple[int, int, int]], + out: Optional[torch.Tensor]=None, + fftshift: bool=False, + finufftkwargs: dict[str, Union[int, float]]=None): + """ + Evaluates the Type 1 NUFFT on the inputs. + + """ + + if out is not None: + print("In-place results are not yet implemented") + # All this requires is a check on the out array to make sure it is the + # correct shape. + + err._type1_checks(*points.T, values, output_shape) # revisit these error checks to take into account the shape of points instead of passing them separately + + + if finufftkwargs is None: + finufftkwargs = dict() + finufftkwargs = {k: v for k, v in finufftkwargs.items()} + _mode_ordering = finufftkwargs.pop("modeord", 1) + _i_sign = finufftkwargs.pop("isign", -1) + + if fftshift: + # TODO -- this check should be done elsewhere? or error msg changed + # to note instead that there is a conflict in fftshift + if _mode_ordering != 1: + raise ValueError( + "Double specification of ordering; only one of fftshift and modeord should be provided" + ) + _mode_ordering = 0 + + ctx.save_for_backward(*points.T, values) + + ctx.isign = _i_sign + ctx.mode_ordering = _mode_ordering + ctx.finufftkwargs = finufftkwargs + + finufft_out = torch.from_numpy( + finufft.nufft3d1( + *points.data.T.numpy(), + values.data.numpy(), + output_shape, + modeord=_mode_ordering, + isign=_i_sign, + **finufftkwargs, + ) + ) + + return finufft_out + + + + +def backward_type1( + ctx: Any, grad_output: torch.Tensor +) -> tuple[Union[torch.Tensor, None], ...]: + """ + Implements derivatives wrt. each argument in the forward method. + + Parameters + ---------- + ctx : Any + Pytorch context object. + grad_output : torch.Tensor + Backpass gradient output + + Returns + ------- + tuple[Union[torch.Tensor, None], ...] + A tuple of derivatives wrt. 
each argument in the forward method + """ + _i_sign = -1 * ctx.isign + _mode_ordering = ctx.mode_ordering + finufftkwargs = ctx.finufftkwargs + + + points, values = ctx.saved_tensors + + start_points = -np.array(grad_output.shape) // 2 + end_points = start_points + grad_output.shape + slices = tuple(slice(start, end) for start, end in zip(start_points, end_points)) + + coord_ramps = torch.mgrid[slices] + + grads_points = None + grad_values = None + + if ctx.needs_input_grad[0]: + # wrt points + + if _mode_ordering != 0: + coord_ramps = torch.fft.ifftshift(coord_ramps) + + ramped_grad_output = coord_ramps * grad_output[np.newaxis] * 1j * _i_sign + + grads_points = [] + for ramp in ramped_grad_output: # we can batch this into finufft + backprop_ramp = torch.from_numpy( + finufft.nufft3d2( + *points.T.numpy(), + ramp.data.numpy(), + isign=_i_sign, + modeord=_mode_ordering, + **finufftkwargs, + )) + grad_points = (backprop_ramp.conj() * values).real + grads_points.append(grad_points) + + grads_points = torch.stack(grads_points) + + if ctx.needs_input_grad[1]: + np_grad_output = grad_output.data.numpy() + + grad_values = torch.from_numpy( + finufft.nufft3d2( + *points.T.numpy() + np_grad_output, + isign=_i_sign, + modeord=_mode_ordering, + **finufftkwargs, + ) + ) + + return ( + grads_points, + grad_values, + None, + None, + None, + None, + ) From 48410fff8f6851685384effe18c2a56ae60de796 Mon Sep 17 00:00:00 2001 From: Michael Eickenberg Date: Wed, 4 Oct 2023 12:41:41 -0400 Subject: [PATCH 03/10] WIP bring forward and backward into one class --- pytorch_finufft/functional.py | 217 +++++++++++++++++----------------- 1 file changed, 107 insertions(+), 110 deletions(-) diff --git a/pytorch_finufft/functional.py b/pytorch_finufft/functional.py index a63289e..0546f62 100644 --- a/pytorch_finufft/functional.py +++ b/pytorch_finufft/functional.py @@ -1601,142 +1601,139 @@ def backward( # Consolidated forward function for all 1D, 2D, and 3D problems for nufft type 1 ############################################################################### -# This function takes a ctx object and is supposed to later replace all type 1 -# functions above for all dimensionalities. - -def finufft_type1( - ctx: Any, - points: torch.Tensor, - values: torch.Tensor, - output_shape: Union[int, tuple[int, int], tuple[int, int, int]], - out: Optional[torch.Tensor]=None, - fftshift: bool=False, - finufftkwargs: dict[str, Union[int, float]]=None): - """ - Evaluates the Type 1 NUFFT on the inputs. +class finufft_type1(torch.autograd.Function): + @staticmethod + def forward( + ctx: Any, + points: torch.Tensor, + values: torch.Tensor, + output_shape: Union[int, tuple[int, int], tuple[int, int, int]], + out: Optional[torch.Tensor]=None, + fftshift: bool=False, + finufftkwargs: dict[str, Union[int, float]]=None): + """ + Evaluates the Type 1 NUFFT on the inputs. - """ + """ - if out is not None: - print("In-place results are not yet implemented") - # All this requires is a check on the out array to make sure it is the - # correct shape. + if out is not None: + print("In-place results are not yet implemented") + # All this requires is a check on the out array to make sure it is the + # correct shape. 
- err._type1_checks(*points.T, values, output_shape) # revisit these error checks to take into account the shape of points instead of passing them separately + err._type1_checks(*points.T, values, output_shape) # revisit these error checks to take into account the shape of points instead of passing them separately - if finufftkwargs is None: - finufftkwargs = dict() - finufftkwargs = {k: v for k, v in finufftkwargs.items()} - _mode_ordering = finufftkwargs.pop("modeord", 1) - _i_sign = finufftkwargs.pop("isign", -1) + if finufftkwargs is None: + finufftkwargs = dict() + finufftkwargs = {k: v for k, v in finufftkwargs.items()} + _mode_ordering = finufftkwargs.pop("modeord", 1) + _i_sign = finufftkwargs.pop("isign", -1) - if fftshift: - # TODO -- this check should be done elsewhere? or error msg changed - # to note instead that there is a conflict in fftshift - if _mode_ordering != 1: - raise ValueError( - "Double specification of ordering; only one of fftshift and modeord should be provided" - ) - _mode_ordering = 0 + if fftshift: + # TODO -- this check should be done elsewhere? or error msg changed + # to note instead that there is a conflict in fftshift + if _mode_ordering != 1: + raise ValueError( + "Double specification of ordering; only one of fftshift and modeord should be provided" + ) + _mode_ordering = 0 - ctx.save_for_backward(*points.T, values) + ctx.save_for_backward(*points.T, values) - ctx.isign = _i_sign - ctx.mode_ordering = _mode_ordering - ctx.finufftkwargs = finufftkwargs + ctx.isign = _i_sign + ctx.mode_ordering = _mode_ordering + ctx.finufftkwargs = finufftkwargs - finufft_out = torch.from_numpy( - finufft.nufft3d1( - *points.data.T.numpy(), - values.data.numpy(), - output_shape, - modeord=_mode_ordering, - isign=_i_sign, - **finufftkwargs, + finufft_out = torch.from_numpy( + finufft.nufft3d1( + *points.data.T.numpy(), + values.data.numpy(), + output_shape, + modeord=_mode_ordering, + isign=_i_sign, + **finufftkwargs, + ) ) - ) - return finufft_out + return finufft_out + @staticmethod + def backward( + ctx: Any, grad_output: torch.Tensor + ) -> tuple[Union[torch.Tensor, None], ...]: + """ + Implements derivatives wrt. each argument in the forward method. + Parameters + ---------- + ctx : Any + Pytorch context object. + grad_output : torch.Tensor + Backpass gradient output + Returns + ------- + tuple[Union[torch.Tensor, None], ...] + A tuple of derivatives wrt. each argument in the forward method + """ + _i_sign = -1 * ctx.isign + _mode_ordering = ctx.mode_ordering + finufftkwargs = ctx.finufftkwargs -def backward_type1( - ctx: Any, grad_output: torch.Tensor -) -> tuple[Union[torch.Tensor, None], ...]: - """ - Implements derivatives wrt. each argument in the forward method. - - Parameters - ---------- - ctx : Any - Pytorch context object. - grad_output : torch.Tensor - Backpass gradient output - - Returns - ------- - tuple[Union[torch.Tensor, None], ...] - A tuple of derivatives wrt. 
each argument in the forward method - """ - _i_sign = -1 * ctx.isign - _mode_ordering = ctx.mode_ordering - finufftkwargs = ctx.finufftkwargs + points, values = ctx.saved_tensors - points, values = ctx.saved_tensors + start_points = -np.array(grad_output.shape) // 2 + end_points = start_points + grad_output.shape + slices = tuple(slice(start, end) for start, end in zip(start_points, end_points)) - start_points = -np.array(grad_output.shape) // 2 - end_points = start_points + grad_output.shape - slices = tuple(slice(start, end) for start, end in zip(start_points, end_points)) + coord_ramps = torch.mgrid[slices] - coord_ramps = torch.mgrid[slices] + grads_points = None + grad_values = None - grads_points = None - grad_values = None + if ctx.needs_input_grad[0]: + # wrt points - if ctx.needs_input_grad[0]: - # wrt points + if _mode_ordering != 0: + coord_ramps = torch.fft.ifftshift(coord_ramps) + + ramped_grad_output = coord_ramps * grad_output[np.newaxis] * 1j * _i_sign + + grads_points = [] + for ramp in ramped_grad_output: # we can batch this into finufft + backprop_ramp = torch.from_numpy( + finufft.nufft3d2( + *points.T.numpy(), + ramp.data.numpy(), + isign=_i_sign, + modeord=_mode_ordering, + **finufftkwargs, + )) + grad_points = (backprop_ramp.conj() * values).real + grads_points.append(grad_points) + + grads_points = torch.stack(grads_points) - if _mode_ordering != 0: - coord_ramps = torch.fft.ifftshift(coord_ramps) - - ramped_grad_output = coord_ramps * grad_output[np.newaxis] * 1j * _i_sign + if ctx.needs_input_grad[1]: + np_grad_output = grad_output.data.numpy() - grads_points = [] - for ramp in ramped_grad_output: # we can batch this into finufft - backprop_ramp = torch.from_numpy( + grad_values = torch.from_numpy( finufft.nufft3d2( - *points.T.numpy(), - ramp.data.numpy(), + *points.T.numpy() + np_grad_output, isign=_i_sign, modeord=_mode_ordering, **finufftkwargs, - )) - grad_points = (backprop_ramp.conj() * values).real - grads_points.append(grad_points) - - grads_points = torch.stack(grads_points) - - if ctx.needs_input_grad[1]: - np_grad_output = grad_output.data.numpy() - - grad_values = torch.from_numpy( - finufft.nufft3d2( - *points.T.numpy() - np_grad_output, - isign=_i_sign, - modeord=_mode_ordering, - **finufftkwargs, + ) ) - ) - return ( - grads_points, - grad_values, - None, - None, - None, - None, - ) + return ( + grads_points, + grad_values, + None, + None, + None, + None, + ) From 745c343d72f1ccc5b2b3907746a6ca13e91a6640 Mon Sep 17 00:00:00 2001 From: Michael Eickenberg Date: Wed, 4 Oct 2023 15:25:37 -0400 Subject: [PATCH 04/10] WIP adding tests but found bug in other tests, going to fix those first rebased on top of those changes --- pytorch_finufft/functional.py | 21 +++++++++++++++------ tests/test_2d/test_forward_2d.py | 28 ++++++++++++++++++++++++++++ tests/test_3d/test_forward_3d.py | 28 ++++++++++++++++++++++++++++ 3 files changed, 71 insertions(+), 6 deletions(-) diff --git a/pytorch_finufft/functional.py b/pytorch_finufft/functional.py index 0546f62..ab88103 100644 --- a/pytorch_finufft/functional.py +++ b/pytorch_finufft/functional.py @@ -1601,6 +1601,10 @@ def backward( # Consolidated forward function for all 1D, 2D, and 3D problems for nufft type 1 ############################################################################### +def get_nufft_func(dim, nufft_type): + return getattr(finufft, f"nufft{dim}d{nufft_type}") + + class finufft_type1(torch.autograd.Function): @staticmethod def forward( @@ -1621,8 +1625,8 @@ def forward( # All this requires is a 
check on the out array to make sure it is the # correct shape. - err._type1_checks(*points.T, values, output_shape) # revisit these error checks to take into account the shape of points instead of passing them separately - + err._type1_checks(points, values, output_shape) # revisit these error checks to take into account the shape of points instead of passing them separately + # ^ make sure these checks check for consistency between output shape and len(points) if finufftkwargs is None: finufftkwargs = dict() @@ -1645,9 +1649,14 @@ def forward( ctx.mode_ordering = _mode_ordering ctx.finufftkwargs = finufftkwargs + # this below should be a pre-check + ndim = points.shape[0] + assert len(output_shape) == ndim + + nufft_func = get_nufft_func(ndim, 1) finufft_out = torch.from_numpy( - finufft.nufft3d1( - *points.data.T.numpy(), + nufft_func( + *points.data.numpy(), values.data.numpy(), output_shape, modeord=_mode_ordering, @@ -1705,7 +1714,7 @@ def backward( for ramp in ramped_grad_output: # we can batch this into finufft backprop_ramp = torch.from_numpy( finufft.nufft3d2( - *points.T.numpy(), + *points.numpy(), ramp.data.numpy(), isign=_i_sign, modeord=_mode_ordering, @@ -1721,7 +1730,7 @@ def backward( grad_values = torch.from_numpy( finufft.nufft3d2( - *points.T.numpy() + *points.numpy(), np_grad_output, isign=_i_sign, modeord=_mode_ordering, diff --git a/tests/test_2d/test_forward_2d.py b/tests/test_2d/test_forward_2d.py index 5cd74dd..343e279 100644 --- a/tests/test_2d/test_forward_2d.py +++ b/tests/test_2d/test_forward_2d.py @@ -119,3 +119,31 @@ def test_2d_t2_forward_CPU(N: int) -> None: # assert abs((f - comparison).sum()) / (N**3) == pytest.approx(0, abs=1e-6) # pass + + +@pytest.mark.parametrize("N", Ns) +def test_t1_forward_CPU(N: int) -> None: + """ + Tests against implementations of the FFT by setting up a uniform grid + over which to call FINUFFT through the API. + """ + g = np.mgrid[:N, :N] * 2 * np.pi / N + points = torch.from_numpy(g.reshape(2, -1)) + + values = torch.randn(*points[0].shape, dtype=torch.complex128) + + print("N is " + str(N)) + print("shape of points is " + str(points.shape)) + print("shape of values is " + str(values.shape)) + + finufft_out = pytorch_finufft.functional.finufft_type1.apply( + points, + values, + (N, N), + ) + + against_torch = torch.fft.fft2(values.reshape(g[0].shape)) + + assert (finufft_out - against_torch) == pytest.approx( + 0, abs=1e-6 + ) diff --git a/tests/test_3d/test_forward_3d.py b/tests/test_3d/test_forward_3d.py index ae8f5f6..0115f9c 100644 --- a/tests/test_3d/test_forward_3d.py +++ b/tests/test_3d/test_forward_3d.py @@ -83,3 +83,31 @@ def test_3d_t2_forward_CPU(N: int) -> None: assert l_inf_error < 1e-5 * N ** 1.5 assert l_2_error < 1e-5 * N ** 3 assert l_1_error < 1e-5 * N ** 4.5 + + +@pytest.mark.parametrize("N", Ns) +def test_t1_forward_CPU(N: int) -> None: + """ + Tests against implementations of the FFT by setting up a uniform grid + over which to call FINUFFT through the API. 
+ """ + g = np.mgrid[:N, :N, :N] * 2 * np.pi / N + points = torch.from_numpy(g.reshape(3, -1)) + + values = torch.randn(*points[0].shape, dtype=torch.complex128) + + print("N is " + str(N)) + print("shape of points is " + str(points.shape)) + print("shape of values is " + str(values.shape)) + + finufft_out = pytorch_finufft.functional.finufft_type1.apply( + points, + values, + (N, N, N), + ) + + against_torch = torch.fft.fft2(values.reshape(g[0].shape)) + + assert abs((finufft_out - against_torch).sum()) / N**3 == pytest.approx( + 0, abs=1e-6 + ) From edacea0df59fce7908b31958a46cb25361b789bf Mon Sep 17 00:00:00 2001 From: Michael Eickenberg Date: Thu, 5 Oct 2023 14:59:39 -0400 Subject: [PATCH 05/10] TST test for forward type 1 passing in 2D, also fixed random seed --- tests/test_2d/test_forward_2d.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/test_2d/test_forward_2d.py b/tests/test_2d/test_forward_2d.py index 343e279..1dda568 100644 --- a/tests/test_2d/test_forward_2d.py +++ b/tests/test_2d/test_forward_2d.py @@ -1,6 +1,7 @@ import numpy as np import pytest import torch +torch.manual_seed(0) import pytorch_finufft @@ -144,6 +145,12 @@ def test_t1_forward_CPU(N: int) -> None: against_torch = torch.fft.fft2(values.reshape(g[0].shape)) - assert (finufft_out - against_torch) == pytest.approx( - 0, abs=1e-6 - ) + abs_errors = torch.abs(finufft_out - against_torch) + l_inf_error = abs_errors.max() + l_2_error = torch.sqrt(torch.sum(abs_errors**2)) + l_1_error = torch.sum(abs_errors) + + assert l_inf_error < 4.5e-5 * N + assert l_2_error < 1e-5 * N ** 2 + assert l_1_error < 1e-5 * N ** 3 + From 2b3ac0fe82d536028291e37008055713d6885f78 Mon Sep 17 00:00:00 2001 From: Michael Eickenberg Date: Thu, 5 Oct 2023 15:04:57 -0400 Subject: [PATCH 06/10] TST tests passing for forward type 1 3d consolidated --- tests/test_3d/test_forward_3d.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/tests/test_3d/test_forward_3d.py b/tests/test_3d/test_forward_3d.py index 0115f9c..45484aa 100644 --- a/tests/test_3d/test_forward_3d.py +++ b/tests/test_3d/test_forward_3d.py @@ -1,6 +1,7 @@ import numpy as np import pytest import torch +torch.manual_seed(0) import pytorch_finufft @@ -13,8 +14,6 @@ 25, 26, 37, - 100, - 101, ] @@ -106,8 +105,13 @@ def test_t1_forward_CPU(N: int) -> None: (N, N, N), ) - against_torch = torch.fft.fft2(values.reshape(g[0].shape)) + against_torch = torch.fft.fftn(values.reshape(g[0].shape)) - assert abs((finufft_out - against_torch).sum()) / N**3 == pytest.approx( - 0, abs=1e-6 - ) + abs_errors = torch.abs(finufft_out - against_torch) + l_inf_error = abs_errors.max() + l_2_error = torch.sqrt(torch.sum(abs_errors**2)) + l_1_error = torch.sum(abs_errors) + + assert l_inf_error < 1.5e-5 * N ** 1.5 + assert l_2_error < 1e-5 * N ** 3 + assert l_1_error < 1e-5 * N ** 4.5 \ No newline at end of file From 290a8000fe386a626ff2af011604468277af57a5 Mon Sep 17 00:00:00 2001 From: Michael Eickenberg Date: Thu, 5 Oct 2023 15:10:11 -0400 Subject: [PATCH 07/10] TST tests passing for 1D. 
CAVEAT: we may need to check cases for 1-d arrays and add a channel axis somewhere

---
 tests/test_1d/test_forward_1d.py | 35 ++++++++++++++++++++++++++++++++
 1 file changed, 35 insertions(+)

diff --git a/tests/test_1d/test_forward_1d.py b/tests/test_1d/test_forward_1d.py
index 9050c52..5379927 100644
--- a/tests/test_1d/test_forward_1d.py
+++ b/tests/test_1d/test_forward_1d.py
@@ -106,6 +106,41 @@ def test_1d_t2_forward_CPU(targets: torch.Tensor):
     )
 
 
+@pytest.mark.parametrize("N", Ns)
+def test_t1_forward_CPU(N: int) -> None:
+    """
+    Tests against implementations of the FFT by setting up a uniform grid
+    over which to call FINUFFT through the API.
+    """
+    g = np.mgrid[:N] * 2 * np.pi / N
+    g.shape = 1, -1
+    points = torch.from_numpy(g.reshape(1, -1))
+
+    values = torch.randn(*points[0].shape, dtype=torch.complex128)
+
+    print("N is " + str(N))
+    print("shape of points is " + str(points.shape))
+    print("shape of values is " + str(values.shape))
+
+    finufft_out = pytorch_finufft.functional.finufft_type1.apply(
+        points,
+        values,
+        (N,),
+    )
+
+    against_torch = torch.fft.fft(values.reshape(g[0].shape))
+
+    abs_errors = torch.abs(finufft_out - against_torch)
+    l_inf_error = abs_errors.max()
+    l_2_error = torch.sqrt(torch.sum(abs_errors**2))
+    l_1_error = torch.sum(abs_errors)
+
+    assert l_inf_error < 4.5e-5 * N
+    assert l_2_error < 1e-5 * N ** 2
+    assert l_1_error < 1e-5 * N ** 3
+
+
+
 # @pytest.mark.parametrize("values", cases)
 # def test_1d_t3_forward_CPU(values: torch.Tensor) -> None:
 # """

From b1b8bf9f7abd89eda69939d082ac9d1feb26a763 Mon Sep 17 00:00:00 2001
From: Michael Eickenberg
Date: Thu, 5 Oct 2023 15:34:34 -0400
Subject: [PATCH 08/10] ENH/FIX made modifications to make 2D type 1 backward
 tests pass

---
 pytorch_finufft/functional.py     | 14 ++++++++++----
 tests/test_2d/test_backward_2d.py | 25 +++++++++++++++++++++++++
 2 files changed, 35 insertions(+), 4 deletions(-)

diff --git a/pytorch_finufft/functional.py b/pytorch_finufft/functional.py
index ab88103..837d281 100644
--- a/pytorch_finufft/functional.py
+++ b/pytorch_finufft/functional.py
@@ -4,6 +4,7 @@

 from typing import Any, Dict, Optional, Tuple, Union

+import numpy as np
 import finufft
 import torch

@@ -1643,7 +1644,7 @@ def forward(
             )
             _mode_ordering = 0

-        ctx.save_for_backward(*points.T, values)
+        ctx.save_for_backward(points, values)

         ctx.isign = _i_sign
         ctx.mode_ordering = _mode_ordering
@@ -1697,11 +1698,16 @@ def backward(
         end_points = start_points + grad_output.shape
         slices = tuple(slice(start, end) for start, end in zip(start_points, end_points))

-        coord_ramps = torch.mgrid[slices]
+        # CPU idiosyncrasy that needs to be handled differently
+        coord_ramps = torch.from_numpy(np.mgrid[slices])

         grads_points = None
         grad_values = None

+        ndim = points.shape[0]
+
+        nufft_func = get_nufft_func(ndim, 2)
+
         if ctx.needs_input_grad[0]:
             # wrt points

@@ -1713,7 +1719,7 @@ def backward(
             grads_points = []
             for ramp in ramped_grad_output:  # we can batch this into finufft
                 backprop_ramp = torch.from_numpy(
-                    finufft.nufft3d2(
+                    nufft_func(
                         *points.numpy(),
                         ramp.data.numpy(),
                         isign=_i_sign,
@@ -1729,7 +1735,7 @@ def backward(
             np_grad_output = grad_output.data.numpy()

             grad_values = torch.from_numpy(
-                finufft.nufft3d2(
+                nufft_func(
                     *points.numpy(),
                     np_grad_output,
                     isign=_i_sign,
diff --git a/tests/test_2d/test_backward_2d.py b/tests/test_2d/test_backward_2d.py
index ce7122a..3561eae 100644
--- a/tests/test_2d/test_backward_2d.py
+++ b/tests/test_2d/test_backward_2d.py
@@ -5,8 +5,11 @@

 import pytorch_finufft

+from functools import partial
+
torch.set_default_tensor_type(torch.DoubleTensor) torch.set_default_dtype(torch.float64) +torch.manual_seed(0) ###################################################################### # APPLY WRAPPERS @@ -97,6 +100,28 @@ def test_t1_backward_CPU_values( assert gradcheck(apply_finufft2d1(modifier, fftshift, isign), inputs) +@pytest.mark.parametrize("N", Ns) +@pytest.mark.parametrize("modifier", length_modifiers) +@pytest.mark.parametrize("fftshift", [False, True]) +@pytest.mark.parametrize("isign", [-1, 1]) +def test_t1_consolidated_backward_CPU_values(N: int, modifier: int, fftshift: bool, isign: int) -> None: + + points = torch.rand((2, N), dtype=torch.float64) * 2 * np.pi + values = torch.randn(N, dtype=torch.complex128) + + points.requires_grad = False + values.requires_grad = True + + inputs = (points, values) + + def func(points, values): + return pytorch_finufft.functional.finufft_type1.apply( + points, values, (N,N + modifier), None, fftshift, dict(isign=isign) + ) + + assert gradcheck(func, inputs) + + @pytest.mark.parametrize("N", Ns) @pytest.mark.parametrize("modifier", length_modifiers) @pytest.mark.parametrize("fftshift", [True, False]) From e1ce8f97d7215fa7da1d5edd733e40e2e257c26b Mon Sep 17 00:00:00 2001 From: Michael Eickenberg Date: Thu, 5 Oct 2023 16:35:59 -0400 Subject: [PATCH 09/10] BUG backward isn't working. Commit before adding debug code --- pytorch_finufft/functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_finufft/functional.py b/pytorch_finufft/functional.py index 837d281..da5fd39 100644 --- a/pytorch_finufft/functional.py +++ b/pytorch_finufft/functional.py @@ -1694,7 +1694,7 @@ def backward( points, values = ctx.saved_tensors - start_points = -np.array(grad_output.shape) // 2 + start_points = -(np.array(grad_output.shape) // 2) end_points = start_points + grad_output.shape slices = tuple(slice(start, end) for start, end in zip(start_points, end_points)) From f63c4d95d6cb574b8eda5c2e4205a2b5a624626c Mon Sep 17 00:00:00 2001 From: Michael Eickenberg Date: Thu, 5 Oct 2023 17:18:27 -0400 Subject: [PATCH 10/10] FIX consolidated backward working for 2D now --- pytorch_finufft/functional.py | 3 +-- tests/test_2d/test_backward_2d.py | 22 ++++++++++++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/pytorch_finufft/functional.py b/pytorch_finufft/functional.py index da5fd39..7b40e71 100644 --- a/pytorch_finufft/functional.py +++ b/pytorch_finufft/functional.py @@ -1691,7 +1691,6 @@ def backward( _mode_ordering = ctx.mode_ordering finufftkwargs = ctx.finufftkwargs - points, values = ctx.saved_tensors start_points = -(np.array(grad_output.shape) // 2) @@ -1712,7 +1711,7 @@ def backward( # wrt points if _mode_ordering != 0: - coord_ramps = torch.fft.ifftshift(coord_ramps) + coord_ramps = torch.fft.ifftshift(coord_ramps, dim=tuple(range(1, ndim+1))) ramped_grad_output = coord_ramps * grad_output[np.newaxis] * 1j * _i_sign diff --git a/tests/test_2d/test_backward_2d.py b/tests/test_2d/test_backward_2d.py index 3561eae..6a9b707 100644 --- a/tests/test_2d/test_backward_2d.py +++ b/tests/test_2d/test_backward_2d.py @@ -122,6 +122,28 @@ def func(points, values): assert gradcheck(func, inputs) +@pytest.mark.parametrize("N", Ns) +@pytest.mark.parametrize("modifier", length_modifiers) +@pytest.mark.parametrize("fftshift", [False, True]) +@pytest.mark.parametrize("isign", [-1, 1]) +def test_t1_consolidated_backward_CPU_points(N: int, modifier: int, fftshift: bool, isign: int) -> None: + + points = torch.rand((2, N), 
dtype=torch.float64) * 2 * np.pi + values = torch.randn(N, dtype=torch.complex128) + + points.requires_grad = True + values.requires_grad = False + + inputs = (points, values) + + def func(points, values): + return pytorch_finufft.functional.finufft_type1.apply( + points, values, (N,N + modifier), None, fftshift, dict(isign=isign) + ) + + assert gradcheck(func, inputs, atol=1e-5 * N) + + @pytest.mark.parametrize("N", Ns) @pytest.mark.parametrize("modifier", length_modifiers) @pytest.mark.parametrize("fftshift", [True, False])
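
Taken together, the series leaves finufft_type1 usable as a single autograd Function
across dimensionalities. The following is a minimal sketch of the end-to-end flow it
enables, mirroring the new forward and consolidated-backward tests above; the grid
size N, the fixed seed, and the reuse of the tests' l-infinity tolerance are
illustrative choices here, not values fixed by these patches.

    import numpy as np
    import torch

    import pytorch_finufft

    torch.manual_seed(0)

    N = 10  # illustrative; the tests sweep a list of Ns
    # Uniform grid on [0, 2*pi)^2, flattened to the (ndim, n_points) layout
    # that the consolidated interface expects.
    g = np.mgrid[:N, :N] * 2 * np.pi / N
    points = torch.from_numpy(g.reshape(2, -1))
    values = torch.randn(points.shape[1], dtype=torch.complex128, requires_grad=True)

    # Type 1 (nonuniform-to-uniform) NUFFT onto an (N, N) mode grid; with the
    # default modeord=1, uniform inputs reproduce torch.fft.fft2.
    finufft_out = pytorch_finufft.functional.finufft_type1.apply(points, values, (N, N))
    against_torch = torch.fft.fft2(values.detach().reshape(N, N))
    assert torch.abs(finufft_out - against_torch).max() < 4.5e-5 * N

    # The consolidated backward computes the gradient wrt values through a
    # type 2 transform, so plain autograd works.
    finufft_out.real.sum().backward()
    assert values.grad.shape == (N * N,)

Because get_nufft_func dispatches on points.shape[0], the identical call covers 1D
and 3D by changing the leading dimension of points and the length of output_shape.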