Skip to content

Commit

Permalink
#14999: Update scatter
Browse files Browse the repository at this point in the history
  • Loading branch information
VirdhatchaniKN committed Nov 14, 2024
1 parent ce6ff4c commit 33a4e96
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# SPDX-FileCopyrightText: © 2023-24 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
import random
from functools import partial
import ttnn


from tests.tt_eager.python_api_testing.sweep_tests import (
comparison_funcs,
generation_funcs,
)
from tests.tt_eager.python_api_testing.sweep_tests.run_pytorch_ci_tests import (
run_single_pytorch_test,
)
from models.utility_functions import is_grayskull

mem_configs = [
ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.DRAM),
ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.L1),
]


@pytest.mark.parametrize(
"input_shapes",
[
[[1, 1, 32, 32], [1, 1, 32, 32]],
[[1, 1, 320, 384], [1, 1, 320, 384]],
[[1, 3, 32, 32], [1, 3, 32, 32]],
[[1, 1, 32, 32], [1, 1, 64, 64]],
[[1, 1, 320, 320], [1, 1, 320, 384]],
],
)
@pytest.mark.parametrize(
"dst_mem_config",
mem_configs,
)
class TestScatter:
def test_run_scatter(
self,
input_shapes,
dst_mem_config,
device,
):
datagen_func = [
generation_funcs.gen_func_with_cast(partial(generation_funcs.gen_rand, low=-1e6, high=1e6), torch.bfloat16)
] * 2
test_args = generation_funcs.gen_default_dtype_layout_device(input_shapes)[0]
test_args.update({"output_mem_config": dst_mem_config})
comparison_func = comparison_funcs.comp_pcc

run_single_pytorch_test(
"eltwise-scatter",
input_shapes,
datagen_func,
comparison_func,
device,
test_args,
)
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def hypot(x, y, *args, **kwargs):


def scatter(x, y, *args, **kwargs):
y[:, :, : x.shape[-2], : x.shape[-1]] = x
y[0:, 0:, : x.shape[-2], : x.shape[-1]] = x
return y


Expand Down
2 changes: 1 addition & 1 deletion ttnn/ttnn/operations/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ def torch_squared_difference(x, y, *args, **kwargs):


def _golden_function_scatter(input_tensor_a, input_tensor_b, *args, **kwargs):
input_tensor_b[:, :, : input_tensor_a.shape[-2], : input_tensor_a.shape[-1]] = input_tensor_a
input_tensor_b[0, 0, : input_tensor_a.shape[-2], : input_tensor_a.shape[-1]] = input_tensor_a
return input_tensor_b


Expand Down

0 comments on commit 33a4e96

Please sign in to comment.