diff --git a/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_scatter.py b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_scatter.py new file mode 100644 index 00000000000..767931795e5 --- /dev/null +++ b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_scatter.py @@ -0,0 +1,62 @@ +# SPDX-FileCopyrightText: © 2023-24 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +import random +from functools import partial +import ttnn + + +from tests.tt_eager.python_api_testing.sweep_tests import ( + comparison_funcs, + generation_funcs, +) +from tests.tt_eager.python_api_testing.sweep_tests.run_pytorch_ci_tests import ( + run_single_pytorch_test, +) +from models.utility_functions import is_grayskull + +mem_configs = [ + ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.DRAM), + ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.L1), +] + + +@pytest.mark.parametrize( + "input_shapes", + [ + [[1, 1, 32, 32], [1, 1, 32, 32]], + [[1, 1, 320, 384], [1, 1, 320, 384]], + [[1, 3, 32, 32], [1, 3, 32, 32]], + [[1, 1, 32, 32], [1, 1, 64, 64]], + [[1, 1, 320, 320], [1, 1, 320, 384]], + ], +) +@pytest.mark.parametrize( + "dst_mem_config", + mem_configs, +) +class TestScatter: + def test_run_scatter( + self, + input_shapes, + dst_mem_config, + device, + ): + datagen_func = [ + generation_funcs.gen_func_with_cast(partial(generation_funcs.gen_rand, low=-1e6, high=1e6), torch.bfloat16) + ] * 2 + test_args = generation_funcs.gen_default_dtype_layout_device(input_shapes)[0] + test_args.update({"output_mem_config": dst_mem_config}) + comparison_func = comparison_funcs.comp_pcc + + run_single_pytorch_test( + "eltwise-scatter", + input_shapes, + datagen_func, + comparison_func, + device, + test_args, + ) diff --git a/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py b/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py index 7362670f625..f69d7fc6e5a 100644 --- a/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py +++ b/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py @@ -129,7 +129,7 @@ def hypot(x, y, *args, **kwargs): def scatter(x, y, *args, **kwargs): - y[:, :, : x.shape[-2], : x.shape[-1]] = x + y[0:, 0:, : x.shape[-2], : x.shape[-1]] = x return y diff --git a/ttnn/ttnn/operations/binary.py b/ttnn/ttnn/operations/binary.py index ce8b02488f0..d5a06251b01 100644 --- a/ttnn/ttnn/operations/binary.py +++ b/ttnn/ttnn/operations/binary.py @@ -357,7 +357,7 @@ def torch_squared_difference(x, y, *args, **kwargs): def _golden_function_scatter(input_tensor_a, input_tensor_b, *args, **kwargs): - input_tensor_b[:, :, : input_tensor_a.shape[-2], : input_tensor_a.shape[-1]] = input_tensor_a + input_tensor_b[0, 0, : input_tensor_a.shape[-2], : input_tensor_a.shape[-1]] = input_tensor_a return input_tensor_b