From c329e01566ee377887f7601d99302b456d5c23c7 Mon Sep 17 00:00:00 2001 From: Michael Eickenberg <michael.eickenberg@gmail.com> Date: Thu, 12 Oct 2023 12:20:12 -0400 Subject: [PATCH] TST add mode ordering to type2 forward test 2d --- tests/test_2d/test_forward_2d.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/test_2d/test_forward_2d.py b/tests/test_2d/test_forward_2d.py index 3a94f65..ba84f45 100644 --- a/tests/test_2d/test_forward_2d.py +++ b/tests/test_2d/test_forward_2d.py @@ -114,7 +114,8 @@ def test_2d_t2_forward_CPU(N: int) -> None: @pytest.mark.parametrize("N", Ns) -def test_t2_forward_CPU(N: int) -> None: +@pytest.mark.parametrize("fftshift", [False, True]) +def test_t2_forward_CPU(N: int, fftshift: bool) -> None: """ Tests against implementations of the FFT by setting up a uniform grid over which to call FINUFFT through the API. @@ -131,9 +132,14 @@ def test_t2_forward_CPU(N: int) -> None: finufft_out = pytorch_finufft.functional.finufft_type2.apply( points, targets, + None, + {'modeord': int(not fftshift)}, ) - against_torch = torch.fft.fft2(targets) + if fftshift: + against_torch = torch.fft.fft2(torch.fft.ifftshift(targets)) + else: + against_torch = torch.fft.fft2(targets) abs_errors = torch.abs(finufft_out - against_torch.ravel()) l_inf_error = abs_errors.max()