From 290a8000fe386a626ff2af011604468277af57a5 Mon Sep 17 00:00:00 2001 From: Michael Eickenberg Date: Thu, 5 Oct 2023 15:10:11 -0400 Subject: [PATCH] TST tests passing for 1D. CAVEAT: we may need to check cases for 1-d arrays and add a channel axis somewhere --- tests/test_1d/test_forward_1d.py | 35 ++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/tests/test_1d/test_forward_1d.py b/tests/test_1d/test_forward_1d.py index 9050c52..5379927 100644 --- a/tests/test_1d/test_forward_1d.py +++ b/tests/test_1d/test_forward_1d.py @@ -106,6 +106,41 @@ def test_1d_t2_forward_CPU(targets: torch.Tensor): ) +@pytest.mark.parametrize("N", Ns) +def test_t1_forward_CPU(N: int) -> None: + """ + Tests against implementations of the FFT by setting up a uniform grid + over which to call FINUFFT through the API. + """ + g = np.mgrid[:N] * 2 * np.pi / N + g.shape = 1, -1 + points = torch.from_numpy(g.reshape(1, -1)) + + values = torch.randn(*points[0].shape, dtype=torch.complex128) + + print("N is " + str(N)) + print("shape of points is " + str(points.shape)) + print("shape of values is " + str(values.shape)) + + finufft_out = pytorch_finufft.functional.finufft_type1.apply( + points, + values, + (N,), + ) + + against_torch = torch.fft.fft(values.reshape(g[0].shape)) + + abs_errors = torch.abs(finufft_out - against_torch) + l_inf_error = abs_errors.max() + l_2_error = torch.sqrt(torch.sum(abs_errors**2)) + l_1_error = torch.sum(abs_errors) + + assert l_inf_error < 4.5e-5 * N + assert l_2_error < 1e-5 * N ** 2 + assert l_1_error < 1e-5 * N ** 3 + + + # @pytest.mark.parametrize("values", cases) # def test_1d_t3_forward_CPU(values: torch.Tensor) -> None: # """