diff --git a/algorithmic_efficiency/data_utils.py b/algorithmic_efficiency/data_utils.py index 245d3768e..14e3c7c6c 100644 --- a/algorithmic_efficiency/data_utils.py +++ b/algorithmic_efficiency/data_utils.py @@ -71,7 +71,7 @@ def _prepare(x): def pad(tensor: np.ndarray, pad_size: int, padding_value: int = 0) -> np.ndarray: - if len(tensor) > 1: + if tensor.ndim > 1: pad_size = (pad_size, *tensor.shape[1:]) padding = np.full(pad_size, padding_value, dtype=tensor.dtype) padded_tensor = np.concatenate((tensor, padding), axis=0)