Skip to content

Commit

Permalink
#8364: Disable fallback for ttnn.repeat_interleave
Browse files Browse the repository at this point in the history
  • Loading branch information
ayerofieiev-tt committed May 22, 2024
1 parent 9c0ecf3 commit 6a1209e
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ def repeat_kv(key, values, repeats, device):
values = ttnn.to_layout(
ttnn.to_device(ttnn.from_torch(values, dtype=ttnn.bfloat16), device), layout=ttnn.TILE_LAYOUT
)
keys = ttnn.repeat_interleave(keys, repeats, dim)
values = ttnn.repeat_interleave(values, repeats, dim)
keys = ttnn.get_fallback_function(ttnn.repeat_interleave)(keys, repeats, dim)
values = ttnn.get_fallback_function(ttnn.repeat_interleave)(values, repeats, dim)
return keys, values


Expand Down
2 changes: 2 additions & 0 deletions tests/ttnn/unit_tests/operations/test_repeat_interleave.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from tests.ttnn.utils_for_testing import assert_with_pcc


@pytest.mark.skip(reason="ttnn.repeat_interleave only supports repeat over dim 0 or 1")
def test_repeat_interleave(device):
torch_input_tensor = torch.tensor([[1, 2], [3, 4]])
torch_result = torch.repeat_interleave(torch_input_tensor, 2, dim=0)
Expand All @@ -23,6 +24,7 @@ def test_repeat_interleave(device):
assert_with_pcc(torch_result, output, 0.9999)


@pytest.mark.skip(reason="ttnn.repeat_interleave only supports repeat over dim 0 or 1")
def test_repeat_interleave_with_repeat_tensor(device):
torch_input_tensor = torch.tensor([[1, 2], [3, 4]], dtype=torch.bfloat16)
torch_repeats = torch.tensor([1, 2])
Expand Down
10 changes: 6 additions & 4 deletions ttnn/ttnn/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,11 +757,13 @@ def operation_decorator(function: callable):
global OPERATION_TO_FALLBACK_FUNCTION

def fallback_function(*function_args, **function_kwargs):
updated_function_args, updated_function_kwargs = preprocess_golden_function_inputs(
function_args, function_kwargs
)
preprocess_inputs = preprocess_golden_function_inputs or default_preprocess_golden_function_inputs
postprocess_outputs = postprocess_golden_function_outputs or default_postprocess_golden_function_outputs

updated_function_args, updated_function_kwargs = preprocess_inputs(function_args, function_kwargs)
output = golden_function(*updated_function_args, **updated_function_kwargs)
output = postprocess_golden_function_outputs(output, function_args, function_kwargs)
output = postprocess_outputs(output, function_args, function_kwargs)

return output

if ttnn.CONFIG.enable_fast_runtime_mode:
Expand Down
5 changes: 4 additions & 1 deletion ttnn/ttnn/operations/data_movement.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,11 +299,14 @@ def _repeat_interleave_validate_input_tensors(operation_name, input_tensor, *arg
)


# This operation does not support the following cases:
# - Shape([2[32], 2[32]]) -> repeats = 2, dim = 0
# - Shape([2[32], 2[32]]) -> repeats = Tensor[1,2], dim = 1
@ttnn.register_operation(
name="ttnn.repeat_interleave",
validate_input_tensors=_repeat_interleave_validate_input_tensors,
golden_function=_golden_function,
allow_to_fallback_to_golden_function_on_failure=True,
allow_to_fallback_to_golden_function_on_failure=False,
)
def repeat_interleave(input_tensor: ttnn.Tensor, repeats: Union[ttnn.Tensor, int], dim: int = 0) -> ttnn.Tensor:
r"""
Expand Down

0 comments on commit 6a1209e

Please sign in to comment.