From 6a1209e64cb88f869df86d2eaec142de2e12929a Mon Sep 17 00:00:00 2001 From: Artem Yerofieiev Date: Wed, 22 May 2024 00:16:29 +0000 Subject: [PATCH] #8364: Disable fallback for ttnn.repeat_interleave --- .../functional_mistral/tt/ttnn_functional_attention.py | 4 ++-- .../unit_tests/operations/test_repeat_interleave.py | 2 ++ ttnn/ttnn/decorators.py | 10 ++++++---- ttnn/ttnn/operations/data_movement.py | 5 ++++- 4 files changed, 14 insertions(+), 7 deletions(-) diff --git a/models/experimental/functional_mistral/tt/ttnn_functional_attention.py b/models/experimental/functional_mistral/tt/ttnn_functional_attention.py index e2e786bbda3..63293cb5913 100644 --- a/models/experimental/functional_mistral/tt/ttnn_functional_attention.py +++ b/models/experimental/functional_mistral/tt/ttnn_functional_attention.py @@ -64,8 +64,8 @@ def repeat_kv(key, values, repeats, device): values = ttnn.to_layout( ttnn.to_device(ttnn.from_torch(values, dtype=ttnn.bfloat16), device), layout=ttnn.TILE_LAYOUT ) - keys = ttnn.repeat_interleave(keys, repeats, dim) - values = ttnn.repeat_interleave(values, repeats, dim) + keys = ttnn.get_fallback_function(ttnn.repeat_interleave)(keys, repeats, dim) + values = ttnn.get_fallback_function(ttnn.repeat_interleave)(values, repeats, dim) return keys, values diff --git a/tests/ttnn/unit_tests/operations/test_repeat_interleave.py b/tests/ttnn/unit_tests/operations/test_repeat_interleave.py index 5350c5f3b57..4a550a1b7c6 100644 --- a/tests/ttnn/unit_tests/operations/test_repeat_interleave.py +++ b/tests/ttnn/unit_tests/operations/test_repeat_interleave.py @@ -11,6 +11,7 @@ from tests.ttnn.utils_for_testing import assert_with_pcc +@pytest.mark.skip(reason="ttnn.repeat_interleave only supports repeat over dim 0 or 1") def test_repeat_interleave(device): torch_input_tensor = torch.tensor([[1, 2], [3, 4]]) torch_result = torch.repeat_interleave(torch_input_tensor, 2, dim=0) @@ -23,6 +24,7 @@ def test_repeat_interleave(device): assert_with_pcc(torch_result, output, 0.9999) +@pytest.mark.skip(reason="ttnn.repeat_interleave only supports repeat over dim 0 or 1") def test_repeat_interleave_with_repeat_tensor(device): torch_input_tensor = torch.tensor([[1, 2], [3, 4]], dtype=torch.bfloat16) torch_repeats = torch.tensor([1, 2]) diff --git a/ttnn/ttnn/decorators.py b/ttnn/ttnn/decorators.py index dc1d2e8a9b5..f95b0bd56a9 100644 --- a/ttnn/ttnn/decorators.py +++ b/ttnn/ttnn/decorators.py @@ -757,11 +757,13 @@ def operation_decorator(function: callable): global OPERATION_TO_FALLBACK_FUNCTION def fallback_function(*function_args, **function_kwargs): - updated_function_args, updated_function_kwargs = preprocess_golden_function_inputs( - function_args, function_kwargs - ) + preprocess_inputs = preprocess_golden_function_inputs or default_preprocess_golden_function_inputs + postprocess_outputs = postprocess_golden_function_outputs or default_postprocess_golden_function_outputs + + updated_function_args, updated_function_kwargs = preprocess_inputs(function_args, function_kwargs) output = golden_function(*updated_function_args, **updated_function_kwargs) - output = postprocess_golden_function_outputs(output, function_args, function_kwargs) + output = postprocess_outputs(output, function_args, function_kwargs) + return output if ttnn.CONFIG.enable_fast_runtime_mode: diff --git a/ttnn/ttnn/operations/data_movement.py b/ttnn/ttnn/operations/data_movement.py index b1002a5ec3f..5f0c04bd1be 100644 --- a/ttnn/ttnn/operations/data_movement.py +++ b/ttnn/ttnn/operations/data_movement.py @@ -299,11 +299,14 @@ def _repeat_interleave_validate_input_tensors(operation_name, input_tensor, *arg ) +# This operation does not support the following cases: +# - Shape([2[32], 2[32]]) -> repeats = 2, dim = 0 +# - Shape([2[32], 2[32]]) -> repeats = Tensor[1,2], dim = 1 @ttnn.register_operation( name="ttnn.repeat_interleave", validate_input_tensors=_repeat_interleave_validate_input_tensors, golden_function=_golden_function, - allow_to_fallback_to_golden_function_on_failure=True, + allow_to_fallback_to_golden_function_on_failure=False, ) def repeat_interleave(input_tensor: ttnn.Tensor, repeats: Union[ttnn.Tensor, int], dim: int = 0) -> ttnn.Tensor: r"""