Skip to content

Commit

Permalink
#4003: updated concat operation to raise an exception if the dimension is out of range
Browse files Browse the repository at this point in the history
  • Loading branch information
  • Loading branch information
arakhmati committed Jan 23, 2024
1 parent 283158f commit 9202d17
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 10 deletions.
3 changes: 1 addition & 2 deletions tests/ttnn/sweep_tests/sweeps/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,9 @@ def is_expected_to_fail(
return True, "You must have at least two tensors to concat!"

if dimension_to_concatenate_on >= rank_of_tensors:
dimension_range = f"[{-rank_of_tensors}, {rank_of_tensors - 1}]"
return (
True,
f"TTNN: Dimension out of range (expected to be in range of {dimension_range}, but got {dimension_to_concatenate_on})",
f"ttnn: Dimension out of range: dim {dimension_to_concatenate_on} cannot be used for tensors of rank {rank_of_tensors}",
)

return False, None
Expand Down
18 changes: 10 additions & 8 deletions ttnn/data_movement.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,18 @@ def concat(tensors: Union[ttnn.Tensor, List[ttnn.Tensor]], dim: int = 0) -> ttnn
if len(tensors) < 2:
raise RuntimeError("You must have at least two tensors to concat!")

rank = len(tensors[0].shape)
original_dim = dim
if dim < 0:
dim = rank + dim
if dim < 0 or dim >= rank:
raise RuntimeError(
f"ttnn: Dimension out of range: dim {original_dim} cannot be used for tensors of rank {rank}"
)

for input_tensor in tensors:
if not ttnn.has_storage_type_of(input_tensor, ttl.tensor.StorageType.DEVICE):
raise RuntimeError("All tensors must be on device!")
raise RuntimeError("ttnn: All tensors must be on device!")

dtype = tensors[0].dtype
device = tensors[0].device
Expand All @@ -179,13 +188,6 @@ def concat(tensors: Union[ttnn.Tensor, List[ttnn.Tensor]], dim: int = 0) -> ttnn
"All dimensions must be the same size except for the dimension along which the contenation is taking place."
)

rank_of_tensors = len(tensors[0].shape)
if dim >= rank_of_tensors:
dimension_range = f"[{-rank_of_tensors}, {rank_of_tensors - 1}]"
raise RuntimeError(
f"TTNN: Dimension out of range (expected to be in range of {dimension_range}, but got {dim})"
)

output_tensor = _torch_concat(tensors, dim=dim)

return ttnn.from_torch(output_tensor, dtype=dtype, device=device, layout=layout)
Expand Down

0 comments on commit 9202d17

Please sign in to comment.