diff --git a/python/finufft/test/test_finufft_plan.py b/python/finufft/test/test_finufft_plan.py index 5df87c38e..4d8a57ee2 100644 --- a/python/finufft/test/test_finufft_plan.py +++ b/python/finufft/test/test_finufft_plan.py @@ -12,16 +12,18 @@ N_PTS = [10, 11] DTYPES = [np.complex64, np.complex128] OUTPUT_ARGS = [False, True] +MODEORDS = [0, 1] @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("shape", SHAPES) @pytest.mark.parametrize("n_pts", N_PTS) @pytest.mark.parametrize("output_arg", OUTPUT_ARGS) -def test_finufft1_plan(dtype, shape, n_pts, output_arg): +@pytest.mark.parametrize("modeord", MODEORDS) +def test_finufft1_plan(dtype, shape, n_pts, output_arg, modeord): pts, coefs = utils.type1_problem(dtype, shape, n_pts) - plan = Plan(1, shape, dtype=dtype) + plan = Plan(1, shape, dtype=dtype, modeord=modeord) plan.setpts(*pts) @@ -31,6 +33,9 @@ def test_finufft1_plan(dtype, shape, n_pts, output_arg): sig = np.empty(shape, dtype=dtype) plan.execute(coefs, out=sig) + if modeord == 1: + sig = np.fft.fftshift(sig) + utils.verify_type1(pts, coefs, shape, sig, 1e-6) @@ -38,18 +43,24 @@ def test_finufft1_plan(dtype, shape, n_pts, output_arg): @pytest.mark.parametrize("shape", SHAPES) @pytest.mark.parametrize("n_pts", N_PTS) @pytest.mark.parametrize("output_arg", OUTPUT_ARGS) -def test_finufft2_plan(dtype, shape, n_pts, output_arg): +@pytest.mark.parametrize("modeord", MODEORDS) +def test_finufft2_plan(dtype, shape, n_pts, output_arg, modeord): pts, sig = utils.type2_problem(dtype, shape, n_pts) - plan = Plan(2, shape, dtype=dtype) + plan = Plan(2, shape, dtype=dtype, modeord=modeord) plan.setpts(*pts) + if modeord == 1: + _sig = np.fft.ifftshift(sig) + else: + _sig = sig + if not output_arg: - coefs = plan.execute(sig) + coefs = plan.execute(_sig) else: coefs = np.empty(n_pts, dtype=dtype) - plan.execute(sig, out=coefs) + plan.execute(_sig, out=coefs) utils.verify_type2(pts, sig, coefs, 1e-6) @@ -76,24 +87,6 @@ def test_finufft3_plan(dtype, dim, n_source_pts, n_target_pts, output_arg): utils.verify_type3(source_pts, source_coefs, target_pts, target_coefs, 1e-6) -def test_finufft_plan_modeord(): - dtype = "complex64" - shape = (8, 8) - n_pts = 17 - - plan = Plan(1, shape, dtype=dtype, modeord=1) - - pts, coefs = utils.type1_problem(dtype, shape, n_pts) - - plan.setpts(*pts) - - sig = plan.execute(coefs) - - sig = np.fft.fftshift(sig) - - utils.verify_type1(pts, coefs, shape, sig, 1e-6) - - def test_finufft_plan_errors(): with pytest.raises(RuntimeError, match="must be single or double"): Plan(1, (8, 8), dtype="uint32")