diff --git a/ttnn/core.py b/ttnn/core.py index 1dc6e305866..0ff07baef61 100644 --- a/ttnn/core.py +++ b/ttnn/core.py @@ -1110,18 +1110,15 @@ def ttnn_reshape(tensor, shape): ttnn_reshape = ttl.tensor.decorate_external_operation(ttnn_reshape, function_name="ttnn.reshape") - if input_tensor.is_contiguous(): - if has_storage_type_of(input_tensor, ttl.tensor.StorageType.DEVICE): - # Page size depends on the width, so only modify the shape if the width is the same - if input_tensor.shape[-1] == shape[-1]: - return ttnn_reshape(input_tensor, shape) - else: + # Page size depends on the width, so only modify the shape if the width is the same + if input_tensor.shape[-1] == shape[-1]: + if input_tensor.is_contiguous(): return ttnn_reshape(input_tensor, shape) - if input_tensor.layout == TILE_LAYOUT: - *_, new_height, new_width = tuple(shape.padded()) - if new_height % TILE_SIZE == 0 and new_width % TILE_SIZE == 0: - return ttnn_reshape(input_tensor, shape) + if input_tensor.layout == TILE_LAYOUT: + *_, new_height, new_width = tuple(shape.padded()) + if new_height % TILE_SIZE == 0 and new_width % TILE_SIZE == 0: + return ttnn_reshape(input_tensor, shape) if ( has_storage_type_of(input_tensor, ttl.tensor.StorageType.DEVICE)