Merge pull request #82 from flatironinstitute/refactor/t2-checks
Refactor Type 2 checks
WardBrian authored Oct 13, 2023
2 parents a9caabe + ca8f7ee commit 0c12b40
Showing 6 changed files with 153 additions and 78 deletions.
83 changes: 23 additions & 60 deletions pytorch_finufft/checks.py
@@ -21,31 +21,31 @@ def check_devices(*tensors: torch.Tensor) -> None:
         )


-def check_dtypes(values: torch.Tensor, points: torch.Tensor) -> None:
+def check_dtypes(data: torch.Tensor, points: torch.Tensor, name: str) -> None:
     """
-    Checks that values is complex-valued
+    Checks that data is complex-valued
     and that points is real-valued of the same precision
     """
-    complex_dtype = values.dtype
+    complex_dtype = data.dtype
     if complex_dtype is torch.complex128:
         real_dtype = torch.float64
     elif complex_dtype is torch.complex64:
         real_dtype = torch.float32
     else:
         raise TypeError(
-            "Values must have a dtype of torch.complex64 or torch.complex128"
+            f"{name} must have a dtype of torch.complex64 or torch.complex128"
         )

     if points.dtype is not real_dtype:
         raise TypeError(
-            f"Points must have a dtype of {real_dtype} as values has a dtype of "
-            f"{complex_dtype}"
+            f"Points must have a dtype of {real_dtype} as {name.lower()} has a "
+            f"dtype of {complex_dtype}"
         )


-def check_sizes(values: torch.Tensor, points: torch.Tensor) -> None:
+def check_sizes_t1(values: torch.Tensor, points: torch.Tensor) -> None:
     """
-    Checks that values and points are 1d and of the same length.
+    Checks that values and points are of the same length.
+    This is used in type1.
     """
     if len(values.shape) != 1:
@@ -56,7 +56,7 @@ def check_sizes(values: torch.Tensor, points: torch.Tensor) -> None:
             raise ValueError("The same number of points and values must be supplied")
     elif len(points.shape) == 2:
         if points.shape[0] not in {1, 2, 3}:
-            raise ValueError(f"Points can be at most 3d, got {points.shape} instead")
+            raise ValueError(f"Points can be at most 3d, got {points.shape[0]} instead")
         if len(values) != points.shape[1]:
             raise ValueError("The same number of points and values must be supplied")
     else:
@@ -82,60 +82,23 @@ def check_output_shape(ndim: int, output_shape: Union[int, Tuple[int, ...]]) ->
         raise ValueError("Got output_shape that was not positive integer")


-### TODO delete the following post-consolidation
-
-_COORD_CHAR_TABLE = "xyz"
-
-
-def _type2_checks(points_tuple: torch.Tensor, targets: torch.Tensor) -> None:
+def check_sizes_t2(targets: torch.Tensor, points: torch.Tensor) -> None:
     """
-    Performs all type, precision, size, device, ... checks for the
-    type 2 FINUFFT
-
-    Parameters
-    ----------
-    points_tuple : Tuple[torch.Tensor, ...]
-        A tuple of all points tensors. Eg, (points, ), or (points_x, points_y)
-    targets : torch.Tensor
-        The targets tensor from the forward call to FINUFFT
-
-    Raises
-    ------
-    TypeError
-        In the case that targets is not complex-valued
-    ValueError
-        In the case that targets is not of the correct shape
-    TypeError
-        In the case that any of the points tensors are not of the correct
-        type or the correct precision
-    ValueError
-        In the case that the i'th dimension of targets is not of the same
-        length as the i'th points tensor
+    Checks that targets and points are of the same dimension.
+    This is used in type2.
     """

-    if not torch.is_complex(targets):
-        raise TypeError("Got values that is not complex-valued")
-
-    complex_dtype = targets.dtype
-    real_dtype = torch.float32 if complex_dtype is torch.complex64 else torch.float64
-
-    dimension = len(points_tuple)
     targets_dim = len(targets.shape)
+    if len(points.shape) == 1:
+        points_dim = 1
+    elif len(points.shape) == 2:
+        points_dim = points.shape[0]
+    else:
+        raise ValueError("The points tensor must be 1d or 2d")

-    if dimension != targets_dim:
+    if points_dim not in {1, 2, 3}:
+        raise ValueError(f"Points can be at most 3d, got {points_dim} instead")
+
+    if targets_dim != points_dim:
         raise ValueError(
-            f"For type 2 {dimension}d FINUFFT, targets must be a {dimension}d " "tensor"
+            f"For type 2 {points_dim}d FINUFFT, targets must be a {points_dim}d tensor"
         )
-
-    coord_char = ""
-
-    # Check dtypes (complex vs. real) on the inputs
-    for i in range(dimension):
-        coord_char = "" if dimension == 1 else ("_" + _COORD_CHAR_TABLE[i])
-
-        if points_tuple[i].dtype is not real_dtype:
-            raise TypeError(
-                f"Got points{coord_char} that is not {real_dtype}-valued; "
-                f"points{coord_char} must also be the same precision as "
-                "targets."
-            )
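
The consolidated helpers above replace the old monolithic _type2_checks: the type 2 path now composes the same small, single-purpose functions as type 1. A minimal sketch of how they fit together; the tensor shapes and values here are illustrative, not from the repository:

import torch

import pytorch_finufft.checks as checks

# 2d example: points is (ndim, npoints), targets is an ndim-dimensional grid
points = torch.rand(2, 100, dtype=torch.float64)
targets = torch.randn(10, 10, dtype=torch.complex128)

checks.check_devices(targets, points)            # all tensors on one device
checks.check_dtypes(targets, points, "Targets")  # complex data, points of matching real precision
checks.check_sizes_t2(targets, points)           # targets needs one axis per points dimension

# The parametrized name now shows up in the error message:
real_targets = torch.randn(10, 10, dtype=torch.float64)
# checks.check_dtypes(real_targets, points, "Targets")
# -> TypeError: Targets must have a dtype of torch.complex64 or torch.complex128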
24 changes: 12 additions & 12 deletions pytorch_finufft/functional.py
@@ -91,8 +91,8 @@ def forward( # type: ignore[override]
             raise NotImplementedError("In-place results are not yet implemented")

         checks.check_devices(values, points)
-        checks.check_dtypes(values, points)
-        checks.check_sizes(values, points)
+        checks.check_dtypes(values, points, "Values")
+        checks.check_sizes_t1(values, points)
         points = torch.atleast_2d(points)
         ndim = points.shape[0]
         checks.check_output_shape(ndim, output_shape)
@@ -232,14 +232,14 @@ def forward( # type: ignore[override]
         """

         if out is not None:
-            print("In-place results are not yet implemented")
+            raise NotImplementedError("In-place results are not yet implemented")

-        # TODO -- extend checks to 2d
-        checks._type2_checks(points, targets)
+        checks.check_devices(targets, points)
+        checks.check_dtypes(targets, points, "Targets")
+        checks.check_sizes_t2(targets, points)

         if finufftkwargs is None:
             finufftkwargs = dict()

         finufftkwargs = {k: v for k, v in finufftkwargs.items()}
         _mode_ordering = finufftkwargs.pop(
             "modeord", 1
@@ -248,17 +248,17 @@ def forward( # type: ignore[override]
             "isign", -1
         )  # isign=-1 is finufft default for type 2

-        ndim = points.shape[0]
-        if _mode_ordering == 1:
+        points = torch.atleast_2d(points)
+        if _mode_ordering:
             targets = torch.fft.fftshift(targets)

-        ctx.save_for_backward(points, targets)
-
         ctx.isign = _i_sign
         ctx.mode_ordering = _mode_ordering
         ctx.finufftkwargs = finufftkwargs

+        ctx.save_for_backward(points, targets)
+
-        nufft_func = get_nufft_func(ndim, 2, points.device.type)
+        nufft_func = get_nufft_func(points.shape[0], 2, points.device.type)

         finufft_out = nufft_func(
             *points,
@@ -325,7 +325,7 @@ def backward( # type: ignore[override]
             **finufftkwargs,
         )

-        if _mode_ordering == 1:
+        if _mode_ordering:
             grad_targets = torch.fft.ifftshift(grad_targets)

         return (
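
With the refactor, the type 2 forward validates with the shared checks before dispatching, promotes points with torch.atleast_2d just as type 1 does, and tests modeord for truthiness (equivalent to the old == 1 comparison for the supported values 0 and 1). A usage sketch of the public entry point; the 10x10 grid mirrors the test setup and is illustrative:

import numpy as np
import torch

import pytorch_finufft.functional

# Nonuniform points covering [0, 2*pi)^2, flattened to shape (2, 100)
g = np.mgrid[:10, :10] * 2 * np.pi / 10
points = torch.from_numpy(g.reshape(2, -1))  # float64, so targets must be complex128

# One complex coefficient per grid mode
targets = torch.randn(*g[0].shape, dtype=torch.complex128)

# Passes check_devices, check_dtypes, and check_sizes_t2, then calls FINUFFT type 2
values = pytorch_finufft.functional.finufft_type2.apply(points, targets)
print(values.shape)  # torch.Size([100]): one value per nonuniform point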
4 changes: 2 additions & 2 deletions tests/test_1d/test_backward_1d.py
@@ -125,7 +125,7 @@ def func(points, targets):
 @pytest.mark.parametrize("modifier", length_modifiers)
 @pytest.mark.parametrize("fftshift", [False, True])
 @pytest.mark.parametrize("isign", [-1, 1])
-def test_t2_backward_CPU_values(
+def test_t2_backward_CPU_targets(
     N: int, modifier: int, fftshift: bool, isign: int
 ) -> None:
     check_t2_backward(N, modifier, fftshift, isign, "cpu", False)
@@ -145,7 +145,7 @@ def test_t2_backward_CPU_points(
 @pytest.mark.parametrize("modifier", length_modifiers)
 @pytest.mark.parametrize("fftshift", [False, True])
 @pytest.mark.parametrize("isign", [-1, 1])
-def test_t2_backward_cuda_values(
+def test_t2_backward_cuda_targets(
     N: int, modifier: int, fftshift: bool, isign: int
 ) -> None:
     check_t2_backward(N, modifier, fftshift, isign, "cuda", False)

4 changes: 2 additions & 2 deletions tests/test_2d/test_backward_2d.py
@@ -136,7 +136,7 @@ def func(points, targets):
 @pytest.mark.parametrize("modifier", length_modifiers)
 @pytest.mark.parametrize("fftshift", [False, True])
 @pytest.mark.parametrize("isign", [-1, 1])
-def test_t2_backward_CPU_values(
+def test_t2_backward_CPU_targets(
     N: int, modifier: int, fftshift: bool, isign: int
 ) -> None:
     check_t2_backward(N, modifier, fftshift, isign, "cpu", False)
@@ -156,7 +156,7 @@ def test_t2_backward_CPU_points(
 @pytest.mark.parametrize("modifier", length_modifiers)
 @pytest.mark.parametrize("fftshift", [False, True])
 @pytest.mark.parametrize("isign", [-1, 1])
-def test_t2_backward_cuda_values(
+def test_t2_backward_cuda_targets(
     N: int, modifier: int, fftshift: bool, isign: int
 ) -> None:
     check_t2_backward(N, modifier, fftshift, isign, "cuda", False)

4 changes: 2 additions & 2 deletions tests/test_3d/test_backward_3d.py
@@ -133,7 +133,7 @@ def func(points, targets):
 @pytest.mark.parametrize("modifier", length_modifiers)
 @pytest.mark.parametrize("fftshift", [False, True])
 @pytest.mark.parametrize("isign", [-1, 1])
-def test_t2_backward_CPU_values(
+def test_t2_backward_CPU_targets(
     N: int, modifier: int, fftshift: bool, isign: int
 ) -> None:
     check_t2_backward(N, modifier, fftshift, isign, "cpu", False)
@@ -153,7 +153,7 @@ def test_t2_backward_CPU_points(
 @pytest.mark.parametrize("modifier", length_modifiers)
 @pytest.mark.parametrize("fftshift", [False, True])
 @pytest.mark.parametrize("isign", [-1, 1])
-def test_t2_backward_cuda_values(
+def test_t2_backward_cuda_targets(
     N: int, modifier: int, fftshift: bool, isign: int
 ) -> None:
     check_t2_backward(N, modifier, fftshift, isign, "cuda", False)

112 changes: 112 additions & 0 deletions tests/test_errors.py
@@ -1,5 +1,6 @@
 import warnings

+import numpy as np
 import pytest
 import torch

@@ -26,6 +27,25 @@ def test_t1_mismatch_cuda_index() -> None:
         pytorch_finufft.functional.finufft_type1.apply(points, values, (10, 10))


+def test_t2_mismatch_device_cuda_cpu() -> None:
+    g = np.mgrid[:10, :10] * 2 * np.pi / 10
+    points = torch.from_numpy(g.reshape(2, -1))
+    targets = torch.randn(*g[0].shape, dtype=torch.complex128).to("cuda:0")
+
+    with pytest.raises(ValueError, match="Some tensors are not on the same device"):
+        pytorch_finufft.functional.finufft_type2.apply(points, targets)
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="require multiple GPUs")
+def test_t2_mismatch_cuda_index() -> None:
+    g = np.mgrid[:10, :10] * 2 * np.pi / 10
+    points = torch.from_numpy(g.reshape(2, -1)).to("cuda:0")
+    targets = torch.randn(*g[0].shape, dtype=torch.complex128).to("cuda:1")
+
+    with pytest.raises(ValueError, match="Some tensors are not on the same device"):
+        pytorch_finufft.functional.finufft_type2.apply(points, targets)
+
+
 # dtypes

@@ -88,6 +108,69 @@ def test_t1_mismatch_precision() -> None:
         pytorch_finufft.functional.finufft_type1.apply(points, values, (10, 10))


+def test_t2_non_complex_targets() -> None:
+    g = np.mgrid[:10, :10] * 2 * np.pi / 10
+    points = torch.from_numpy(g.reshape(2, -1))
+    targets = torch.randn(*g[0].shape, dtype=torch.float64)
+
+    with pytest.raises(
+        TypeError,
+        match="Targets must have a dtype of torch.complex64 or torch.complex128",
+    ):
+        pytorch_finufft.functional.finufft_type2.apply(points, targets)
+
+
+def test_t2_half_complex_targets() -> None:
+    g = np.mgrid[:10, :10] * 2 * np.pi / 10
+    points = torch.from_numpy(g.reshape(2, -1))
+
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore", UserWarning)
+        targets = torch.randn(*g[0].shape, dtype=torch.complex32)
+
+    with pytest.raises(
+        TypeError,
+        match="Targets must have a dtype of torch.complex64 or torch.complex128",
+    ):
+        pytorch_finufft.functional.finufft_type2.apply(points, targets)
+
+
+def test_t2_non_real_points() -> None:
+    g = np.mgrid[:10, :10] * 2 * np.pi / 10
+    points = torch.from_numpy(g.reshape(2, -1)).to(torch.complex128)
+    targets = torch.randn(*g[0].shape, dtype=torch.complex128)
+
+    with pytest.raises(
+        TypeError,
+        match="Points must have a dtype of torch.float64 as targets has "
+        "a dtype of torch.complex128",
+    ):
+        pytorch_finufft.functional.finufft_type2.apply(points, targets)
+
+
+def test_t2_mismatch_precision() -> None:
+    g = np.mgrid[:10, :10] * 2 * np.pi / 10
+    points = torch.from_numpy(g.reshape(2, -1)).to(torch.float32)
+    targets = torch.randn(*g[0].shape, dtype=torch.complex128)
+
+    with pytest.raises(
+        TypeError,
+        match="Points must have a dtype of torch.float64 as targets has "
+        "a dtype of torch.complex128",
+    ):
+        pytorch_finufft.functional.finufft_type2.apply(points, targets)
+
+    points = points.to(torch.float64)
+    targets = targets.to(torch.complex64)
+
+    with pytest.raises(
+        TypeError,
+        match="Points must have a dtype of torch.float32 as targets has "
+        "a dtype of torch.complex64",
+    ):
+        pytorch_finufft.functional.finufft_type2.apply(points, targets)
+
+
 # sizes

@@ -160,6 +243,35 @@ def test_t1_negative_output_dims() -> None:
         pytorch_finufft.functional.finufft_type1.apply(points, values, (10, -2))


+def test_t2_points_4d() -> None:
+    g = np.mgrid[:10, :10, :10, :10] * 2 * np.pi / 10
+    points = torch.from_numpy(g.reshape(4, -1)).to(torch.float64)
+    targets = torch.randn(*g[0].shape, dtype=torch.complex128)
+
+    with pytest.raises(ValueError, match="Points can be at most 3d, got"):
+        pytorch_finufft.functional.finufft_type2.apply(points, targets)
+
+
+def test_t2_too_many_points_dims() -> None:
+    g = np.mgrid[:10, :10] * 2 * np.pi / 10
+    points = torch.from_numpy(g.reshape(1, 2, -1)).to(torch.float64)
+    targets = torch.randn(*g[0].shape, dtype=torch.complex128)
+
+    with pytest.raises(ValueError, match="The points tensor must be 1d or 2d"):
+        pytorch_finufft.functional.finufft_type2.apply(points, targets)
+
+
+def test_t2_mismatch_dims() -> None:
+    g = np.mgrid[:10, :10, :10] * 2 * np.pi / 10
+    points = torch.from_numpy(g.reshape(3, -1)).to(torch.float64)
+    targets = torch.randn(*g[0].shape[:-1], dtype=torch.complex128)
+
+    with pytest.raises(
+        ValueError, match="For type 2 3d FINUFFT, targets must be a 3d tensor"
+    ):
+        pytorch_finufft.functional.finufft_type2.apply(points, targets)
+
+
 # dependencies
 def test_finufft_not_installed():
     if not pytorch_finufft.functional.CUFINUFFT_AVAIL:
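
All of the new type 2 error tests build their inputs from the same recipe, so it is worth spelling out once. A commented sketch of the shared setup; the grid size is illustrative:

import numpy as np
import torch

# np.mgrid[:10, :10] has shape (2, 10, 10); scaling maps grid indices onto [0, 2*pi)
g = np.mgrid[:10, :10] * 2 * np.pi / 10

# FINUFFT expects nonuniform points as (ndim, npoints), hence the flatten to (2, 100)
points = torch.from_numpy(g.reshape(2, -1))

# targets keeps the grid shape: one complex coefficient per Fourier mode
targets = torch.randn(*g[0].shape, dtype=torch.complex128)

Each test then breaks exactly one property of this pair (device placement, complex dtype, precision pairing, or dimensionality) and asserts that the corresponding check raises with the expected message.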
