From f47f0e10098b4ce53928e017e4f4c6492433b2e8 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Fri, 13 Oct 2023 12:00:37 -0400 Subject: [PATCH] Refactor checks, add more tests --- pytorch_finufft/checks.py | 81 +++++++----------------- pytorch_finufft/functional.py | 24 ++++---- tests/test_errors.py | 112 ++++++++++++++++++++++++++++++++++ 3 files changed, 146 insertions(+), 71 deletions(-) diff --git a/pytorch_finufft/checks.py b/pytorch_finufft/checks.py index 95431a4..ce90500 100644 --- a/pytorch_finufft/checks.py +++ b/pytorch_finufft/checks.py @@ -21,31 +21,31 @@ def check_devices(*tensors: torch.Tensor) -> None: ) -def check_dtypes(values: torch.Tensor, points: torch.Tensor) -> None: +def check_dtypes(data: torch.Tensor, points: torch.Tensor, name: str) -> None: """ - Checks that values is complex-valued + Checks that data is complex-valued and that points is real-valued of the same precision """ - complex_dtype = values.dtype + complex_dtype = data.dtype if complex_dtype is torch.complex128: real_dtype = torch.float64 elif complex_dtype is torch.complex64: real_dtype = torch.float32 else: raise TypeError( - "Values must have a dtype of torch.complex64 or torch.complex128" + f"{name} must have a dtype of torch.complex64 or torch.complex128" ) if points.dtype is not real_dtype: raise TypeError( - f"Points must have a dtype of {real_dtype} as values has a dtype of " - f"{complex_dtype}" + f"Points must have a dtype of {real_dtype} as {name.lower()} has a " + f"dtype of {complex_dtype}" ) -def check_sizes(values: torch.Tensor, points: torch.Tensor) -> None: +def check_sizes_t1(values: torch.Tensor, points: torch.Tensor) -> None: """ - Checks that values and points are 1d and of the same length. + Checks that values and points of the same length. This is used in type1. """ if len(values.shape) != 1: @@ -82,60 +82,23 @@ def check_output_shape(ndim: int, output_shape: Union[int, Tuple[int, ...]]) -> raise ValueError("Got output_shape that was not positive integer") -### TODO delete the following post-consolidation - -_COORD_CHAR_TABLE = "xyz" - - -def _type2_checks(points_tuple: torch.Tensor, targets: torch.Tensor) -> None: +def check_sizes_t2(targets: torch.Tensor, points: torch.Tensor) -> None: """ - Performs all type, precision, size, device, ... checks for the - type 2 FINUFFT - - Parameters - ---------- - points_tuple : Tuple[torch.Tensor, ...] - A tuple of all points tensors. Eg, (points, ), or (points_x, points_y) - targets : torch.Tensor - The targets tensor from the forward call to FINUFFT - - Raises - ------ - TypeError - In the case that targets is not complex-valued - ValueError - In the case that targets is not of the correct shape - TypeError - In the case that any of the points tensors are not of the correct - type or the correct precision - ValueError - In the case that the i'th dimension of targets is not of the same - length as the i'th points tensor + Checks that targets and points are of the same dimension. + This is used in type2. """ - - if not torch.is_complex(targets): - raise TypeError("Got values that is not complex-valued") - - complex_dtype = targets.dtype - real_dtype = torch.float32 if complex_dtype is torch.complex64 else torch.float64 - - dimension = len(points_tuple) targets_dim = len(targets.shape) + if len(points.shape) == 1: + points_dim = 1 + elif len(points.shape) == 2: + points_dim = points.shape[0] + else: + raise ValueError("The points tensor must be 1d or 2d") - if dimension != targets_dim: + if points_dim not in {1, 2, 3}: + raise ValueError(f"Points can be at most 3d, got {points.shape} instead") + + if targets_dim != points_dim: raise ValueError( - f"For type 2 {dimension}d FINUFFT, targets must be a {dimension}d " "tensor" + f"For type 2 {points_dim}d FINUFFT, targets must be a {points_dim}d tensor" ) - - coord_char = "" - - # Check dtypes (complex vs. real) on the inputs - for i in range(dimension): - coord_char = "" if dimension == 1 else ("_" + _COORD_CHAR_TABLE[i]) - - if points_tuple[i].dtype is not real_dtype: - raise TypeError( - f"Got points{coord_char} that is not {real_dtype}-valued; " - f"points{coord_char} must also be the same precision as " - "targets." - ) diff --git a/pytorch_finufft/functional.py b/pytorch_finufft/functional.py index 0109b66..f14287c 100644 --- a/pytorch_finufft/functional.py +++ b/pytorch_finufft/functional.py @@ -91,8 +91,8 @@ def forward( # type: ignore[override] raise NotImplementedError("In-place results are not yet implemented") checks.check_devices(values, points) - checks.check_dtypes(values, points) - checks.check_sizes(values, points) + checks.check_dtypes(values, points, "Values") + checks.check_sizes_t1(values, points) points = torch.atleast_2d(points) ndim = points.shape[0] checks.check_output_shape(ndim, output_shape) @@ -232,14 +232,14 @@ def forward( # type: ignore[override] """ if out is not None: - print("In-place results are not yet implemented") + raise NotImplementedError("In-place results are not yet implemented") - # TODO -- extend checks to 2d - checks._type2_checks(points, targets) + checks.check_devices(targets, points) + checks.check_dtypes(targets, points, "Targets") + checks.check_sizes_t2(targets, points) if finufftkwargs is None: finufftkwargs = dict() - finufftkwargs = {k: v for k, v in finufftkwargs.items()} _mode_ordering = finufftkwargs.pop( "modeord", 1 @@ -248,17 +248,17 @@ def forward( # type: ignore[override] "isign", -1 ) # isign=-1 is finufft default for type 2 - ndim = points.shape[0] - if _mode_ordering == 1: + points = torch.atleast_2d(points) + if _mode_ordering: targets = torch.fft.fftshift(targets) + ctx.save_for_backward(points, targets) + ctx.isign = _i_sign ctx.mode_ordering = _mode_ordering ctx.finufftkwargs = finufftkwargs - ctx.save_for_backward(points, targets) - - nufft_func = get_nufft_func(ndim, 2, points.device.type) + nufft_func = get_nufft_func(points.shape[0], 2, points.device.type) finufft_out = nufft_func( *points, @@ -325,7 +325,7 @@ def backward( # type: ignore[override] **finufftkwargs, ) - if _mode_ordering == 1: + if _mode_ordering: grad_targets = torch.fft.ifftshift(grad_targets) return ( diff --git a/tests/test_errors.py b/tests/test_errors.py index 9cc34ff..0386ab5 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -1,5 +1,6 @@ import warnings +import numpy as np import pytest import torch @@ -26,6 +27,25 @@ def test_t1_mismatch_cuda_index() -> None: pytorch_finufft.functional.finufft_type1.apply(points, values, (10, 10)) +def test_t2_mismatch_device_cuda_cpu() -> None: + g = np.mgrid[:10, :10] * 2 * np.pi / 10 + points = torch.from_numpy(g.reshape(2, -1)) + targets = torch.randn(*g[0].shape, dtype=torch.complex128).to("cuda:0") + + with pytest.raises(ValueError, match="Some tensors are not on the same device"): + pytorch_finufft.functional.finufft_type1.apply(points, targets) + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="require multiple GPUs") +def test_t2_mismatch_cuda_index() -> None: + g = np.mgrid[:10, :10] * 2 * np.pi / 10 + points = torch.from_numpy(g.reshape(2, -1)).to("cuda:0") + targets = torch.randn(*g[0].shape, dtype=torch.complex128).to("cuda:1") + + with pytest.raises(ValueError, match="Some tensors are not on the same device"): + pytorch_finufft.functional.finufft_type2.apply(points, targets) + + # dtypes @@ -88,6 +108,69 @@ def test_t1_mismatch_precision() -> None: pytorch_finufft.functional.finufft_type1.apply(points, values, (10, 10)) +def test_t2_non_complex_targets() -> None: + g = np.mgrid[:10, :10] * 2 * np.pi / 10 + points = torch.from_numpy(g.reshape(2, -1)) + targets = torch.randn(*g[0].shape, dtype=torch.float64) + + with pytest.raises( + TypeError, + match="Targets must have a dtype of torch.complex64 or torch.complex128", + ): + pytorch_finufft.functional.finufft_type2.apply(points, targets) + + +def test_t2_half_complex_targets() -> None: + g = np.mgrid[:10, :10] * 2 * np.pi / 10 + points = torch.from_numpy(g.reshape(2, -1)) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", UserWarning) + targets = torch.randn(*g[0].shape, dtype=torch.complex32) + + with pytest.raises( + TypeError, + match="Targets must have a dtype of torch.complex64 or torch.complex128", + ): + pytorch_finufft.functional.finufft_type2.apply(points, targets) + + +def test_t2_non_real_points() -> None: + g = np.mgrid[:10, :10] * 2 * np.pi / 10 + points = torch.from_numpy(g.reshape(2, -1)).to(torch.complex128) + targets = torch.randn(*g[0].shape, dtype=torch.complex128) + + with pytest.raises( + TypeError, + match="Points must have a dtype of torch.float64 as targets has " + "a dtype of torch.complex128", + ): + pytorch_finufft.functional.finufft_type2.apply(points, targets) + + +def test_t2_mismatch_precision() -> None: + g = np.mgrid[:10, :10] * 2 * np.pi / 10 + points = torch.from_numpy(g.reshape(2, -1)).to(torch.float32) + targets = torch.randn(*g[0].shape, dtype=torch.complex128) + + with pytest.raises( + TypeError, + match="Points must have a dtype of torch.float64 as targets has " + "a dtype of torch.complex128", + ): + pytorch_finufft.functional.finufft_type2.apply(points, targets) + + points = points.to(torch.float64) + targets = targets.to(torch.complex64) + + with pytest.raises( + TypeError, + match="Points must have a dtype of torch.float32 as targets has " + "a dtype of torch.complex64", + ): + pytorch_finufft.functional.finufft_type2.apply(points, targets) + + # sizes @@ -160,6 +243,35 @@ def test_t1_negative_output_dims() -> None: pytorch_finufft.functional.finufft_type1.apply(points, values, (10, -2)) +def test_t2_points_4d() -> None: + g = np.mgrid[:10, :10, :10, :10] * 2 * np.pi / 10 + points = torch.from_numpy(g.reshape(4, -1)).to(torch.float64) + targets = torch.randn(*g[0].shape, dtype=torch.complex128) + + with pytest.raises(ValueError, match="Points can be at most 3d, got"): + pytorch_finufft.functional.finufft_type2.apply(points, targets) + + +def test_t2_too_many_points_dims() -> None: + g = np.mgrid[:10, :10] * 2 * np.pi / 10 + points = torch.from_numpy(g.reshape(1, 2, -1)).to(torch.float64) + targets = torch.randn(*g[0].shape, dtype=torch.complex128) + + with pytest.raises(ValueError, match="The points tensor must be 1d or 2d"): + pytorch_finufft.functional.finufft_type2.apply(points, targets) + + +def test_t2_mismatch_dims() -> None: + g = np.mgrid[:10, :10, :10] * 2 * np.pi / 10 + points = torch.from_numpy(g.reshape(3, -1)).to(torch.float64) + targets = torch.randn(*g[0].shape[:-1], dtype=torch.complex128) + + with pytest.raises( + ValueError, match="For type 2 3d FINUFFT, targets must be a 3d tensor" + ): + pytorch_finufft.functional.finufft_type2.apply(points, targets) + + # dependencies def test_finufft_not_installed(): if not pytorch_finufft.functional.CUFINUFFT_AVAIL: