From ca8f7eee54c2cd4e90e167ef8ff2e3f85cde6129 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Fri, 13 Oct 2023 15:22:02 -0400 Subject: [PATCH] Clean up checks --- pytorch_finufft/checks.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_finufft/checks.py b/pytorch_finufft/checks.py index ce90500..2560ff8 100644 --- a/pytorch_finufft/checks.py +++ b/pytorch_finufft/checks.py @@ -45,7 +45,7 @@ def check_dtypes(data: torch.Tensor, points: torch.Tensor, name: str) -> None: def check_sizes_t1(values: torch.Tensor, points: torch.Tensor) -> None: """ - Checks that values and points of the same length. + Checks that values and points are of the same length. This is used in type1. """ if len(values.shape) != 1: @@ -56,7 +56,7 @@ def check_sizes_t1(values: torch.Tensor, points: torch.Tensor) -> None: raise ValueError("The same number of points and values must be supplied") elif len(points.shape) == 2: if points.shape[0] not in {1, 2, 3}: - raise ValueError(f"Points can be at most 3d, got {points.shape} instead") + raise ValueError(f"Points can be at most 3d, got {points.shape[0]} instead") if len(values) != points.shape[1]: raise ValueError("The same number of points and values must be supplied") else: @@ -96,7 +96,7 @@ def check_sizes_t2(targets: torch.Tensor, points: torch.Tensor) -> None: raise ValueError("The points tensor must be 1d or 2d") if points_dim not in {1, 2, 3}: - raise ValueError(f"Points can be at most 3d, got {points.shape} instead") + raise ValueError(f"Points can be at most 3d, got {points_dim} instead") if targets_dim != points_dim: raise ValueError(