Skip to content

Commit

Permalink
Fix pad_size in pad function
Browse files Browse the repository at this point in the history
  • Loading branch information
runame committed Sep 14, 2023
1 parent 35c8736 commit ad64fd1
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion algorithmic_efficiency/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def _prepare(x):
def pad(tensor: np.ndarray,
pad_size: int,
padding_value: int = 0) -> np.ndarray:
if len(tensor) > 1:
if tensor.ndim > 1:
pad_size = (pad_size, *tensor.shape[1:])
padding = np.full(pad_size, padding_value, dtype=tensor.dtype)
padded_tensor = np.concatenate((tensor, padding), axis=0)
Expand Down

0 comments on commit ad64fd1

Please sign in to comment.