From ec1fb47408f9d53e01f39072a0ce08705a2b71f5 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Thu, 19 Oct 2023 09:32:03 -0400 Subject: [PATCH] Add ifft helpers, basic tests --- docs/api.rst | 11 +++++ pytorch_finufft/functional.py | 31 +++++++++++++ tests/test_inverses.py | 86 +++++++++++++++++++++++++++++++++++ 3 files changed, 128 insertions(+) create mode 100644 tests/test_inverses.py diff --git a/docs/api.rst b/docs/api.rst index c1d1601..32484c2 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -6,3 +6,14 @@ API Reference .. autofunction:: pytorch_finufft.functional.finufft_type1 .. autofunction:: pytorch_finufft.functional.finufft_type2 + +Inverse Transform helper functions +---------------------------------- + +Both of these functions are provided merely as helpers, they call +the above functions just with different default arguments and scaling +to provide the equivalent of an ifft function. + +.. autofunction:: pytorch_finufft.functional.finuifft_type1 + +.. autofunction:: pytorch_finufft.functional.finuifft_type2 diff --git a/pytorch_finufft/functional.py b/pytorch_finufft/functional.py index ca42977..7758bc1 100644 --- a/pytorch_finufft/functional.py +++ b/pytorch_finufft/functional.py @@ -401,3 +401,34 @@ def finufft_type2( """ res: torch.Tensor = FinufftType2.apply(points, targets, finufftkwargs) return res + + +def finuifft_type1( + points: torch.Tensor, + values: torch.Tensor, + output_shape: Union[int, Tuple[int], Tuple[int, int], Tuple[int, int, int]], + **finufftkwargs: Union[int, float], +) -> torch.Tensor: + """ + Equivalent to :func:`~pytorch_finufft.functional.finufft_type1`, + but passing ``isign=1`` and dividing by the size of the output. + """ + finufftkwargs["isign"] = 1 + res: torch.Tensor = finufft_type1(points, values, output_shape, **finufftkwargs) + res = res / res.numel() + return res + + +def finuifft_type2( + points: torch.Tensor, + targets: torch.Tensor, + **finufftkwargs: Union[int, float], +) -> torch.Tensor: + """ + Equivalent to :func:`~pytorch_finufft.functional.finufft_type2`, + but passing ``isign=1`` and dividing by the size of the output. + """ + finufftkwargs["isign"] = 1 + res: torch.Tensor = finufft_type2(points, targets, **finufftkwargs) + res = res / res.numel() + return res diff --git a/tests/test_inverses.py b/tests/test_inverses.py new file mode 100644 index 0000000..3e49417 --- /dev/null +++ b/tests/test_inverses.py @@ -0,0 +1,86 @@ +import numpy as np +import pytest +import torch + +import pytorch_finufft + + +def check_t2_ifft_undoes_t1(N: int, dim: int, device: str) -> None: + """ + Tests that nuifft_type2 undoes nufft_type1 + """ + slices = tuple(slice(None, N) for _ in range(dim)) + g = np.mgrid[slices] * 2 * np.pi / N + points = torch.from_numpy(g.reshape(dim, -1)).to(device) + + values = torch.randn(*points[0].shape, dtype=torch.complex128).to(device) + + print("N is " + str(N)) + print("dim is " + str(dim)) + print("shape of points is " + str(points.shape)) + print("shape of values is " + str(values.shape)) + + finufft_out = pytorch_finufft.functional.finufft_type1( + points, + values, + tuple(N for _ in range(dim)), + ) + + back = pytorch_finufft.functional.finuifft_type2( + points, + finufft_out, + ) + + np.testing.assert_allclose(values.cpu().numpy(), back.cpu().numpy(), atol=1e-4) + + +Ns = [ + 5, + 10, + 15, + 100, + 101, +] + +dims = [1, 2, 3] + + +@pytest.mark.parametrize("N", Ns) +@pytest.mark.parametrize("dim", dims) +def test_t2_ifft_undoes_t1_forward_CPU(N, dim): + check_t2_ifft_undoes_t1(N, dim, "cpu") + + +def check_t1_ifft_undoes_t2(N: int, dim: int, device: str) -> None: + """ + Tests that nuifft_type1 undoes nufft_type2 + """ + slices = tuple(slice(None, N) for _ in range(dim)) + g = np.mgrid[slices] * 2 * np.pi / N + points = torch.from_numpy(g.reshape(g.shape[0], -1)).to(device) + + targets = torch.randn(*g[0].shape, dtype=torch.complex128).to(device) + + print("N is " + str(N)) + print("dim is " + str(dim)) + print("shape of points is " + str(points.shape)) + print("shape of targets is " + str(targets.shape)) + + finufft_out = pytorch_finufft.functional.finufft_type2( + points, + targets, + ) + + back = pytorch_finufft.functional.finuifft_type1( + points, + finufft_out, + tuple(N for _ in range(dim)), + ) + + np.testing.assert_allclose(targets.cpu().numpy(), back.cpu().numpy(), atol=1e-4) + + +@pytest.mark.parametrize("N", Ns) +@pytest.mark.parametrize("dim", dims) +def test_t1_ifft_undoes_t2_forward_CPU(N, dim): + check_t1_ifft_undoes_t2(N, dim, "cpu")