Skip to content

Commit

Permalink
#4514: fixed the bug in ttnn.reshape
Browse files Browse the repository at this point in the history
  • Loading branch information
arakhmati committed Jan 9, 2024
1 parent d5d6a1f commit ce6739f
Showing 1 changed file with 7 additions and 10 deletions.
17 changes: 7 additions & 10 deletions ttnn/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1110,18 +1110,15 @@ def ttnn_reshape(tensor, shape):

ttnn_reshape = ttl.tensor.decorate_external_operation(ttnn_reshape, function_name="ttnn.reshape")

if input_tensor.is_contiguous():
if has_storage_type_of(input_tensor, ttl.tensor.StorageType.DEVICE):
# Page size depends on the width, so only modify the shape if the width is the same
if input_tensor.shape[-1] == shape[-1]:
return ttnn_reshape(input_tensor, shape)
else:
# Page size depends on the width, so only modify the shape if the width is the same
if input_tensor.shape[-1] == shape[-1]:
if input_tensor.is_contiguous():
return ttnn_reshape(input_tensor, shape)

if input_tensor.layout == TILE_LAYOUT:
*_, new_height, new_width = tuple(shape.padded())
if new_height % TILE_SIZE == 0 and new_width % TILE_SIZE == 0:
return ttnn_reshape(input_tensor, shape)
if input_tensor.layout == TILE_LAYOUT:
*_, new_height, new_width = tuple(shape.padded())
if new_height % TILE_SIZE == 0 and new_width % TILE_SIZE == 0:
return ttnn_reshape(input_tensor, shape)

if (
has_storage_type_of(input_tensor, ttl.tensor.StorageType.DEVICE)
Expand Down

0 comments on commit ce6739f

Please sign in to comment.