Skip to content

Commit

Permalink
Clean up checks
Browse files Browse the repository at this point in the history
  • Loading branch information
WardBrian committed Oct 13, 2023
1 parent fcc2072 commit ca8f7ee
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions pytorch_finufft/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def check_dtypes(data: torch.Tensor, points: torch.Tensor, name: str) -> None:

def check_sizes_t1(values: torch.Tensor, points: torch.Tensor) -> None:
"""
Checks that values and points of the same length.
Checks that values and points are of the same length.
This is used in type1.
"""
if len(values.shape) != 1:
Expand All @@ -56,7 +56,7 @@ def check_sizes_t1(values: torch.Tensor, points: torch.Tensor) -> None:
raise ValueError("The same number of points and values must be supplied")
elif len(points.shape) == 2:
if points.shape[0] not in {1, 2, 3}:
raise ValueError(f"Points can be at most 3d, got {points.shape} instead")
raise ValueError(f"Points can be at most 3d, got {points.shape[0]} instead")
if len(values) != points.shape[1]:
raise ValueError("The same number of points and values must be supplied")
else:
Expand Down Expand Up @@ -96,7 +96,7 @@ def check_sizes_t2(targets: torch.Tensor, points: torch.Tensor) -> None:
raise ValueError("The points tensor must be 1d or 2d")

if points_dim not in {1, 2, 3}:
raise ValueError(f"Points can be at most 3d, got {points.shape} instead")
raise ValueError(f"Points can be at most 3d, got {points_dim} instead")

if targets_dim != points_dim:
raise ValueError(
Expand Down

0 comments on commit ca8f7ee

Please sign in to comment.