From 3ea17c10eb57c257f97b2cc43bb3583c95edb2ec Mon Sep 17 00:00:00 2001 From: Artem Yerofieiev Date: Wed, 29 May 2024 23:21:16 +0000 Subject: [PATCH] #5389: Move ttnn.repeat_interleave to c++ --- tests/ttnn/unit_tests/gtests/CMakeLists.txt | 1 + .../gtests/test_repeat_interleave.cpp | 105 ++++++++++++++++++ .../unit_tests/gtests/ttnn_test_fixtures.hpp | 19 ++++ .../operations/test_repeat_interleave.py | 17 +-- ttnn/cpp/pybind11/operations/binary.hpp | 2 +- ttnn/cpp/pybind11/operations/ccl.hpp | 2 +- ttnn/cpp/pybind11/operations/core.hpp | 7 +- .../cpp/pybind11/operations/data_movement.hpp | 43 ++++++- ttnn/cpp/pybind11/operations/embedding.hpp | 2 +- .../cpp/pybind11/operations/normalization.hpp | 2 +- ttnn/cpp/pybind11/operations/pool.hpp | 2 +- ttnn/cpp/pybind11/operations/unary.hpp | 8 +- ttnn/cpp/ttnn/operations/data_movement.hpp | 33 ++++++ ttnn/ttnn/operations/data_movement.py | 80 +------------ 14 files changed, 222 insertions(+), 101 deletions(-) create mode 100644 tests/ttnn/unit_tests/gtests/test_repeat_interleave.cpp diff --git a/tests/ttnn/unit_tests/gtests/CMakeLists.txt b/tests/ttnn/unit_tests/gtests/CMakeLists.txt index 6ad8e40a486..a132f44d868 100644 --- a/tests/ttnn/unit_tests/gtests/CMakeLists.txt +++ b/tests/ttnn/unit_tests/gtests/CMakeLists.txt @@ -1,6 +1,7 @@ set(TTNN_UNIT_TESTS_SRC ${CMAKE_CURRENT_SOURCE_DIR}/test_add.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/test_repeat_interleave.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test_async_runtime.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test_multiprod_queue.cpp ) diff --git a/tests/ttnn/unit_tests/gtests/test_repeat_interleave.cpp b/tests/ttnn/unit_tests/gtests/test_repeat_interleave.cpp new file mode 100644 index 00000000000..e899a97e6c8 --- /dev/null +++ b/tests/ttnn/unit_tests/gtests/test_repeat_interleave.cpp @@ -0,0 +1,105 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "gtest/gtest.h" + +#include "tt_metal/common/bfloat16.hpp" +#include "ttnn/device.hpp" +#include "ttnn/operations/core.hpp" +#include "ttnn/async_runtime.hpp" +#include "ttnn/operations/data_movement.hpp" +#include "tt_numpy/functions.hpp" +#include "tt_metal/common/logger.hpp" + +#include "ttnn_test_fixtures.hpp" + +#include + +namespace ttnn { +namespace operations { +namespace data_movement { +namespace test { + +void run_repeat_interleave_test(tt::tt_metal::Device* device, const uint32_t repeats, const uint32_t dim) { + MemoryConfig mem_cfg; + mem_cfg.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED; + mem_cfg.buffer_type = BufferType::DRAM; + + const uint32_t io_cq = 0; + const uint32_t input_buf_size_datums = 32 * 32; + const uint32_t output_buf_size_datums = input_buf_size_datums * repeats; + const uint32_t datum_size_bytes = 2; + ttnn::Shape input_shape = ttnn::Shape(tt::tt_metal::Shape({1, 1, 32, 32})); + auto host_data = std::shared_ptr(new uint16_t[input_buf_size_datums]); + auto readback_data = std::shared_ptr(new uint16_t[output_buf_size_datums]); + + for (uint16_t i = 0; i < 32; i++) { + for (uint16_t j = 0; j < 32; j++) { + host_data[i * 32 + j] = i; + } + } + + auto input_buffer = ttnn::allocate_buffer_on_device(input_buf_size_datums * datum_size_bytes, device, input_shape, DataType::UINT16, Layout::TILE, mem_cfg); + auto input_storage = tt::tt_metal::DeviceStorage{input_buffer}; + Tensor input_tensor = Tensor(input_storage, input_shape, DataType::UINT16, Layout::TILE); + ttnn::write_buffer(io_cq, input_tensor, {host_data}); + + ttnn::Tensor output_tensor = ttnn::repeat_interleave(input_tensor, repeats, dim); + + ttnn::read_buffer(io_cq, output_tensor, {readback_data}); + + tt::log_debug("input_data: \n {}", input_tensor.write_to_string()); + tt::log_debug("readback_data: \n {}", output_tensor.write_to_string()); + + for (int i = 0; i < input_buf_size_datums; i++) { + auto input_value = host_data[i]; + for(int r = 0; r < repeats; r++) { + auto value = readback_data[i + r * input_buf_size_datums]; + ASSERT_EQ(input_value, value); + } + } + + input_tensor.deallocate(); + output_tensor.deallocate(); +} + +struct RepeatInterleaveParams { + int repeats = 0; + int dim = 0; +}; + +class RepeatInterleaveTest : public ttnn::TTNNFixtureWithDevice, public ::testing::WithParamInterface {}; + +TEST_P(RepeatInterleaveTest, RunsCorrectly) { + RepeatInterleaveParams params = GetParam(); + run_repeat_interleave_test(device_, params.repeats, params.dim); +} + +INSTANTIATE_TEST_SUITE_P( + RepeatInterleaveWithDim0, + RepeatInterleaveTest, + ::testing::Values( + RepeatInterleaveParams{1, 0}, + RepeatInterleaveParams{2, 0}, + RepeatInterleaveParams{3, 0} + ) +); + +// tests/ttnn/unit_tests/operations/test_repeat_interleave.py proves that it should work over dim 1 too +// likely need to fix the comparison in the test +INSTANTIATE_TEST_SUITE_P( + DISABLED_RepeatInterleaveWithDim1, + RepeatInterleaveTest, + ::testing::Values( + RepeatInterleaveParams{1, 1}, + RepeatInterleaveParams{2, 1}, + RepeatInterleaveParams{3, 1} + ) +); + + +} // namespace test +} // namespace binary +} // namespace operations +} // namespace ttnn diff --git a/tests/ttnn/unit_tests/gtests/ttnn_test_fixtures.hpp b/tests/ttnn/unit_tests/gtests/ttnn_test_fixtures.hpp index e7ae534392a..86da367c3d3 100644 --- a/tests/ttnn/unit_tests/gtests/ttnn_test_fixtures.hpp +++ b/tests/ttnn/unit_tests/gtests/ttnn_test_fixtures.hpp @@ -10,6 +10,9 @@ #include "gtest/gtest.h" +#include "ttnn/device.hpp" +#include "tests/tt_metal/test_utils/env_vars.hpp" + namespace ttnn { class TTNNFixture : public ::testing::Test { @@ -26,4 +29,20 @@ class TTNNFixture : public ::testing::Test { void TearDown() override { tt::Cluster::instance().set_internal_routing_info_for_ethernet_cores(false); } }; + +class TTNNFixtureWithDevice : public TTNNFixture { + protected: + tt::tt_metal::Device* device_ = nullptr; + + void SetUp() override { + TTNNFixture::SetUp(); + device_ = tt::tt_metal::CreateDevice(0); + } + + void TearDown() override { + TTNNFixture::TearDown(); + tt::tt_metal::CloseDevice(device_); + } +}; + } // namespace ttnn diff --git a/tests/ttnn/unit_tests/operations/test_repeat_interleave.py b/tests/ttnn/unit_tests/operations/test_repeat_interleave.py index 4a550a1b7c6..aefd70b99c2 100644 --- a/tests/ttnn/unit_tests/operations/test_repeat_interleave.py +++ b/tests/ttnn/unit_tests/operations/test_repeat_interleave.py @@ -7,26 +7,29 @@ import torch import ttnn +from loguru import logger from tests.ttnn.utils_for_testing import assert_with_pcc -@pytest.mark.skip(reason="ttnn.repeat_interleave only supports repeat over dim 0 or 1") -def test_repeat_interleave(device): - torch_input_tensor = torch.tensor([[1, 2], [3, 4]]) - torch_result = torch.repeat_interleave(torch_input_tensor, 2, dim=0) +@pytest.mark.parametrize("repeats", [1, 2, 3]) +@pytest.mark.parametrize("dim", [0, 1, 2, 3]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +def test_repeat_interleave(device, repeats, dim, dtype): + torch_input_tensor = torch.rand(1, 1, 32, 32, dtype=dtype) + torch_result = torch.repeat_interleave(torch_input_tensor, repeats, dim=dim) input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device) - output = ttnn.repeat_interleave(input_tensor, 2, dim=0) + output = ttnn.repeat_interleave(input_tensor, repeats, dim=dim) output = ttnn.to_torch(output) assert_with_pcc(torch_result, output, 0.9999) -@pytest.mark.skip(reason="ttnn.repeat_interleave only supports repeat over dim 0 or 1") +@pytest.mark.skip(reason="ttnn.repeat_interleave only supports `repeats` as int") def test_repeat_interleave_with_repeat_tensor(device): - torch_input_tensor = torch.tensor([[1, 2], [3, 4]], dtype=torch.bfloat16) + torch_input_tensor = torch.rand(1, 2, 32, 32, dtype=torch.bfloat16) torch_repeats = torch.tensor([1, 2]) torch_result = torch.repeat_interleave(torch_input_tensor, torch_repeats, dim=1) input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device) diff --git a/ttnn/cpp/pybind11/operations/binary.hpp b/ttnn/cpp/pybind11/operations/binary.hpp index 4c9f2104b58..a7771348bcf 100644 --- a/ttnn/cpp/pybind11/operations/binary.hpp +++ b/ttnn/cpp/pybind11/operations/binary.hpp @@ -37,7 +37,7 @@ void bind_binary_operation(py::module& module, const binary_operation_t& operati * :attr:`dtype` (ttnn.DataType): data type for the output tensor * :attr:`activations` (List[str]): list of activation functions to apply to the output tensor - Example:: + Example: >>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device) >>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.tensor((0, 1), dtype=torch.bfloat16)), device) diff --git a/ttnn/cpp/pybind11/operations/ccl.hpp b/ttnn/cpp/pybind11/operations/ccl.hpp index 03d049b8ed9..fa68102a988 100644 --- a/ttnn/cpp/pybind11/operations/ccl.hpp +++ b/ttnn/cpp/pybind11/operations/ccl.hpp @@ -58,7 +58,7 @@ void py_module(py::module& module) { * :attr:`num_links` (int): Number of links to use for the all-gather operation. * :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation. - Example:: + Example: >>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device) >>> output = ttnn.all_gather(tensor, dim=0) diff --git a/ttnn/cpp/pybind11/operations/core.hpp b/ttnn/cpp/pybind11/operations/core.hpp index 04ec5536378..32d5295d82a 100644 --- a/ttnn/cpp/pybind11/operations/core.hpp +++ b/ttnn/cpp/pybind11/operations/core.hpp @@ -113,7 +113,7 @@ void py_module(py::module& module) { * :attr:`memory_config`: the desired MemoryConfig * :attr:`dtype`: the optional `ttnn` data type. - Example:: + >>> device_id = 0 >>> device = ttnn.open_device(device_id=device_id) >>> tensor = ttnn.to_device(ttnn.from_torch(torch.randn((10, 64, 32), dtype=torch.bfloat16)), device) @@ -133,7 +133,7 @@ void py_module(py::module& module) { * :attr:`tensor`: the ttnn.Tensor * :attr:`dtype`: `ttnn` data type. - Example:: + Example: >>> tensor = ttnn.from_torch(torch.randn((10, 64, 32), dtype=torch.bfloat16)) >>> tensor = ttnn.to_dtype(tensor, dtype=ttnn.uint16) )doc", @@ -252,7 +252,8 @@ void py_module(py::module& module) { * :attr:`dtype`: the optional output data type. * :attr:`memory_config`: the optional output memory configuration. * :attr:`device`: Device/DeviceMesh whose worker thread on host should be used for the layout conversion - Example:: + + Example: >>> device_id = 0 >>> device = ttnn.open_device(device_id=device_id) >>> tensor = ttnn.to_device(ttnn.from_torch(torch.randn((10, 64, 32), dtype=torch.bfloat16)), device) diff --git a/ttnn/cpp/pybind11/operations/data_movement.hpp b/ttnn/cpp/pybind11/operations/data_movement.hpp index 0b8484a895f..f73aca67acb 100644 --- a/ttnn/cpp/pybind11/operations/data_movement.hpp +++ b/ttnn/cpp/pybind11/operations/data_movement.hpp @@ -27,7 +27,7 @@ Permutes :attr:`input_tensor` using :attr:`order`. * :attr:`input_tensor`: the input tensor * :attr:`order`: the desired ordering of dimensions. -Example:: +Example: >>> tensor = ttnn.to_device(ttnn.from_torch(torch.zeros((1, 1, 64, 32), dtype=torch.bfloat16)), device) >>> output = ttnn.permute(tensor, (0, 1, 3, 2)) @@ -53,7 +53,7 @@ Concats :attr:`tensors` in the given :attr:`dim`. Keyword Args: * :attr:`memory_config`: the memory configuration to use for the operation -Example:: +Example: >>> tensor = ttnn.concat(ttnn.from_torch(torch.zeros((1, 1, 64, 32), ttnn.from_torch(torch.zeros((1, 1, 64, 32), dim=3)), device) @@ -79,7 +79,9 @@ The algorithms available for upsampling are 'nearest' for now. * :attr:`scale_factor`: multiplier for spatial size. Has to match input size if it is a tuple. )doc", ttnn::pybind_arguments_t{ - py::arg("input_tensor"), py::arg("scale_factor"), py::arg("memory_config") = std::nullopt}); + py::arg("input_tensor"), + py::arg("scale_factor"), + py::arg("memory_config") = std::nullopt}); ttnn::bind_registered_operation( module, @@ -96,7 +98,7 @@ Returns a new tensor filled with repetition of input :attr:`input_tensor` accord Keyword Args: * :attr:`memory_config`: the memory configuration to use for the operation -Example:: +Example: >>> tensor = ttnn.repeat(ttnn.from_torch(torch.tensor([[1, 2], [3, 4]]), 2,)), device) >>> print(tensor) @@ -107,6 +109,39 @@ Example:: )doc", ttnn::pybind_arguments_t{ py::arg("input_tensor"), py::arg("shape"), py::kw_only(), py::arg("memory_config") = std::nullopt}); + + ttnn::bind_registered_operation( + module, + ttnn::repeat_interleave, + R"doc( +repeat_interleave(input_tensor: ttnn.Tensor, repeats : int, dim: int = 0) -> ttnn.Tensor + +Repeats elements of a :attr:`tensor` in the given :attr:`dim`. + +Args: + * :attr:`input_tensor`: the input_tensor to apply the repeate interleave operation. + * :attr:`repeats`: The number of repetitions for each element. repeats is broadcasted to fit the shape of the given axis. + * :attr:`dim`: the dimension to expand with the repetitions. + +Example: + +torch_input_tensor = + torch_result = torch.repeat_interleave(torch_input_tensor, repeats, dim=dim) + + input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device) + + output = ttnn.repeat_interleave(input_tensor, repeats, dim=dim) + >>> a = ttnn.from_torch(torch.rand(1, 1, 32, 32, dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device) + >>> b = ttnn.repeat_interleave(a, 2, dim=0) + >>> print(a.shape, b.shape) + ttnn.Shape([1, 1, 32, 32]) ttnn.Shape([2, 1, 32, 32]) + )doc", + ttnn::pybind_arguments_t{ + py::arg("input_tensor"), + py::arg("repeats"), + py::arg("dim"), + py::kw_only(), + py::arg("memory_config") = std::nullopt}); } } // namespace data_movement diff --git a/ttnn/cpp/pybind11/operations/embedding.hpp b/ttnn/cpp/pybind11/operations/embedding.hpp index 388bea6bd1e..5261092004d 100644 --- a/ttnn/cpp/pybind11/operations/embedding.hpp +++ b/ttnn/cpp/pybind11/operations/embedding.hpp @@ -33,7 +33,7 @@ void py_module(py::module& module) { * :attr:`layout`: the layout of the input and output tensors. Default is ttnn.ROW_MAJOR_LAYOUT. * :attr:`memory_config`: the memory configuration of the output tensor. Default is input tensor memory config. - Example:: + Example: >>> device_id = 0 >>> device = ttnn.open_device(device_id=device_id) >>> input_tensor = ttnn.to_device(ttnn.from_torch(torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]]), dtype=ttnn.uint32), device) diff --git a/ttnn/cpp/pybind11/operations/normalization.hpp b/ttnn/cpp/pybind11/operations/normalization.hpp index 8d72f20a62c..73713517fe5 100644 --- a/ttnn/cpp/pybind11/operations/normalization.hpp +++ b/ttnn/cpp/pybind11/operations/normalization.hpp @@ -34,7 +34,7 @@ void py_module(py::module& module) { Keyword Args: * :attr:`memory_config`: the memory configuration for the output tensor. If not provided, the memory configuration of the input tensor is used. - Example:: + Example: >>> tensor = ttnn.to_device(ttnn.from_torch(torch.zeros((1, 1, 64, 32), dtype=torch.bfloat16)), device) >>> output = ttnn.softmax(tensor, -1) diff --git a/ttnn/cpp/pybind11/operations/pool.hpp b/ttnn/cpp/pybind11/operations/pool.hpp index 38ee2fcdab8..c775273cf14 100644 --- a/ttnn/cpp/pybind11/operations/pool.hpp +++ b/ttnn/cpp/pybind11/operations/pool.hpp @@ -38,7 +38,7 @@ void bind_global_avg_pool2d(py::module& module) { Returns: ttnn.Tensor: The tensor with the averaged values. The output tensor shape is (batch_size, channels, 1, 1). - Example:: + Example: >>> tensor = ttnn.from_torch(torch.randn((10, 3, 32, 32), dtype=ttnn.bfloat16), device=device) >>> output = {1}(tensor) diff --git a/ttnn/cpp/pybind11/operations/unary.hpp b/ttnn/cpp/pybind11/operations/unary.hpp index bb169e3ea6f..888a5402687 100644 --- a/ttnn/cpp/pybind11/operations/unary.hpp +++ b/ttnn/cpp/pybind11/operations/unary.hpp @@ -35,7 +35,7 @@ void bind_unary_operation(py::module& module, const unary_operation_t& operation Keyword Args: * :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation. - Example:: + Example: >>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device) >>> output = {1}(tensor) @@ -67,7 +67,7 @@ void bind_unary_operation_with_fast_and_approximate_mode(py::module& module, con * :attr:`fast_and_approximate_mode` (bool): "Use fast and approximate mode". * :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation. - Example:: + Example: >>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device) >>> output = {1}(tensor, fast_and_approximate_mode=true) @@ -107,7 +107,7 @@ void bind_unary_operation_with_float_parameter( * :attr:`{2}` (bool): {3}. * :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation. - Example:: + Example: >>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device) >>> output = {1}(tensor, {2}=true) @@ -145,7 +145,7 @@ void bind_softplus(py::module& module) { * :attr:`threshold` (float): Used to switch to a linear function for large values to improve numerical stability. This avoids issues with floating-point representation for very large values * :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation. - Example:: + Example: >>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device) >>> output = {1}(tensor, parameter=true) diff --git a/ttnn/cpp/ttnn/operations/data_movement.hpp b/ttnn/cpp/ttnn/operations/data_movement.hpp index 13ebf3eec10..bd9c02cc738 100644 --- a/ttnn/cpp/ttnn/operations/data_movement.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement.hpp @@ -7,6 +7,7 @@ #include "tt_eager/tt_dnn/op_library/concat/concat_op.hpp" #include "tt_eager/tt_dnn/op_library/permute/permute_op.hpp" #include "tt_eager/tt_dnn/op_library/repeat/repeat_op.hpp" +#include "tt_eager/tt_dnn/op_library/composite/composite_ops.hpp" #include "tt_eager/tt_dnn/op_library/upsample/upsample_op.hpp" #include "ttnn/cpp/ttnn/operations/core.hpp" @@ -283,8 +284,40 @@ struct Repeat { } }; +struct RepeatInterleave { + static inline const std::array input_tensor_schemas() { + return {ttnn::TensorSchema{ + 4, // min rank + 4, // max rank + {ttnn::bfloat16}, + {ttnn::TILE_LAYOUT}, + true, // can_be_on_device + true, // can_be_on_cpu + false, // can_be_scalar + false}}; // is_optional + } + + template + static auto input_tensors_to_validate(const ttnn::Tensor& input_tensor, Args&&... args) { + return std::make_tuple(input_tensor); + } + + // # This operation does not support the following cases: + // # - Shape([2[32], 2[32]]) -> repeats = 2, dim = 0 + // # - Shape([2[32], 2[32]]) -> repeats = Tensor[1,2], dim = 1 + static ttnn::Tensor execute_on_worker_thread(const ttnn::Tensor& input_tensor, + uint32_t repeats, + int32_t dim, + std::optional output_mem_config = std::nullopt) { + MemoryConfig mem_config = output_mem_config.value_or(input_tensor.memory_config()); + auto output_tensor = tt::tt_metal::repeat_interleave(input_tensor, repeats, dim, mem_config); + return output_tensor; + } +}; + } // namespace data_movement } // namespace operations constexpr auto upsample = ttnn::register_operation("ttnn::upsample"); constexpr auto repeat = ttnn::register_operation("ttnn::repeat"); +constexpr auto repeat_interleave = ttnn::register_operation("ttnn::repeat_interleave"); } // namespace ttnn diff --git a/ttnn/ttnn/operations/data_movement.py b/ttnn/ttnn/operations/data_movement.py index 2dd2bf9a5f8..bc47a7e7e9f 100644 --- a/ttnn/ttnn/operations/data_movement.py +++ b/ttnn/ttnn/operations/data_movement.py @@ -285,85 +285,9 @@ def _golden_function(tensor, repeats, dim=0, **_): return torch.repeat_interleave(tensor, repeats, dim=dim) -def _repeat_interleave_validate_input_tensors(operation_name, input_tensor, *args, **kwargs): - ttnn.validate_input_tensor( - operation_name, - input_tensor, - ranks=(2, 3, 4), - dtypes=(ttnn.bfloat16, ttnn.bfloat8_b, ttnn.uint16, ttnn.int32, ttnn.uint32), - layouts=(ttnn.TILE_LAYOUT,), - can_be_on_device=True, - can_be_on_cpu=True, - ) - - -# This operation does not support the following cases: -# - Shape([2[32], 2[32]]) -> repeats = 2, dim = 0 -# - Shape([2[32], 2[32]]) -> repeats = Tensor[1,2], dim = 1 -@ttnn.register_operation( - name="ttnn.repeat_interleave", - validate_input_tensors=_repeat_interleave_validate_input_tensors, - golden_function=_golden_function, +repeat_interleave = ttnn.register_operation(golden_function=_golden_function)( + ttnn._ttnn.operations.data_movement.repeat_interleave ) -def repeat_interleave(input_tensor: ttnn.Tensor, repeats: Union[ttnn.Tensor, int], dim: int = 0) -> ttnn.Tensor: - r""" - repeat_interleave(input_tensor: ttnn.Tensor, repeats : Union[ttnn.Tensor,int], dim: int = 0) -> ttnn.Tensor - - Repeats elements of a :attr:`tensor` in the given :attr:`dim`. - - Args: - * :attr:`input_tensor`: the input_tensor to apply the repeate interleave operation. - * :attr:`repeats`: The number of repetitions for each element. repeats is broadcasted to fit the shape of the given axis. - * :attr:`dim`: the dimension to expand with the repetitions. - - Example:: - - >>> a = ttnn.from_torch(torch.tensor([[1, 2], [3, 4]]), device=device, layout=ttnn.TILE_LAYOUT) - >>> b = ttnn.repeat_interleave(a, 2, dim=0) - >>> print(a.shape, b.shape) - ttnn.Shape([2[32], 2[32]]) ttnn.Shape([4[32], 2[32]]) - - """ - - if not isinstance(repeats, int) and not isinstance(repeats, ttnn.Tensor): - raise RuntimeError("ttnn: Expected repeat to either be an int or a ttnn.Tensor") - - rank_of_tensor = len(input_tensor.shape) - if dim >= rank_of_tensor: - dimension_range = f"[{-rank_of_tensor}, {rank_of_tensor - 1}]" - raise RuntimeError( - f"ttnn: Dimension out of range (expected to be in range of {dimension_range}, but got {dim})" - ) - - def custom_numel(tensor): - total_elements = 1 - for dimension in tensor.shape: - total_elements *= dimension - return total_elements - - if isinstance(repeats, ttnn.Tensor): - if input_tensor.shape[dim] != custom_numel(repeats): - raise RuntimeError("ttnn: repeats must have the same size as input along dim") - elif len(repeats.shape) != 1: - raise RuntimeError("ttnn: repeats must be 0-dim or 1-dim tensor") - - dtype = input_tensor.dtype - rank = len(input_tensor.shape) - if dtype == ttnn.bfloat16 and rank == 4 and dim != 2 and dim != 3: - output_tensor = ttl.tensor.repeat_interleave(input_tensor, repeats, dim=dim) - *batch, _, _ = output_tensor.shape - *_, h, w = input_tensor.shape - *_, padded_h, padded_w = input_tensor.shape.with_tile_padding() - if dim == 2: - *_, h, _ = output_tensor.shape - *_, padded_h, _ = output_tensor.shape.with_tile_padding() - elif dim == 3: - *_, _, w = output_tensor.shape - *_, _, padded_w = output_tensor.shape.with_tile_padding() - output_tensor = ttnn.reshape(output_tensor, shape=ttnn.Shape(batch + [h, w], batch + [padded_h, padded_w])) - return output_tensor - else: - raise NotImplementedError def _golden_function(tensor, shape, **_):