From c329e01566ee377887f7601d99302b456d5c23c7 Mon Sep 17 00:00:00 2001
From: Michael Eickenberg <michael.eickenberg@gmail.com>
Date: Thu, 12 Oct 2023 12:20:12 -0400
Subject: [PATCH] TST add mode ordering to type2 forward test 2d

---
 tests/test_2d/test_forward_2d.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/tests/test_2d/test_forward_2d.py b/tests/test_2d/test_forward_2d.py
index 3a94f65..ba84f45 100644
--- a/tests/test_2d/test_forward_2d.py
+++ b/tests/test_2d/test_forward_2d.py
@@ -114,7 +114,8 @@ def test_2d_t2_forward_CPU(N: int) -> None:
 
 
 @pytest.mark.parametrize("N", Ns)
-def test_t2_forward_CPU(N: int) -> None:
+@pytest.mark.parametrize("fftshift", [False, True])
+def test_t2_forward_CPU(N: int, fftshift: bool) -> None:
     """
     Tests against implementations of the FFT by setting up a uniform grid
     over which to call FINUFFT through the API.
@@ -131,9 +132,14 @@ def test_t2_forward_CPU(N: int) -> None:
     finufft_out = pytorch_finufft.functional.finufft_type2.apply(
         points,
         targets,
+        None,
+        {'modeord': int(not fftshift)},
     )
 
-    against_torch = torch.fft.fft2(targets)
+    if fftshift:
+        against_torch = torch.fft.fft2(torch.fft.ifftshift(targets))
+    else:
+        against_torch = torch.fft.fft2(targets)
 
     abs_errors = torch.abs(finufft_out - against_torch.ravel())
     l_inf_error = abs_errors.max()