Skip to content

Commit

Permalink
#5389: Move ttnn.repeat_interleave to c++ (#8961)
Browse files Browse the repository at this point in the history
  • Loading branch information
ayerofieiev-tt authored Jun 6, 2024
1 parent 4276e5c commit 3082585
Show file tree
Hide file tree
Showing 14 changed files with 222 additions and 101 deletions.
1 change: 1 addition & 0 deletions tests/ttnn/unit_tests/gtests/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@

set(TTNN_UNIT_TESTS_SRC
${CMAKE_CURRENT_SOURCE_DIR}/test_add.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test_repeat_interleave.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test_async_runtime.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test_multiprod_queue.cpp
)
Expand Down
105 changes: 105 additions & 0 deletions tests/ttnn/unit_tests/gtests/test_repeat_interleave.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "gtest/gtest.h"

#include "tt_metal/common/bfloat16.hpp"
#include "ttnn/device.hpp"
#include "ttnn/operations/core.hpp"
#include "ttnn/async_runtime.hpp"
#include "ttnn/operations/data_movement.hpp"
#include "tt_numpy/functions.hpp"
#include "tt_metal/common/logger.hpp"

#include "ttnn_test_fixtures.hpp"

#include <memory>

namespace ttnn {
namespace operations {
namespace data_movement {
namespace test {

void run_repeat_interleave_test(tt::tt_metal::Device* device, const uint32_t repeats, const uint32_t dim) {
MemoryConfig mem_cfg;
mem_cfg.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED;
mem_cfg.buffer_type = BufferType::DRAM;

const uint32_t io_cq = 0;
const uint32_t input_buf_size_datums = 32 * 32;
const uint32_t output_buf_size_datums = input_buf_size_datums * repeats;
const uint32_t datum_size_bytes = 2;
ttnn::Shape input_shape = ttnn::Shape(tt::tt_metal::Shape({1, 1, 32, 32}));
auto host_data = std::shared_ptr<uint16_t[]>(new uint16_t[input_buf_size_datums]);
auto readback_data = std::shared_ptr<uint16_t[]>(new uint16_t[output_buf_size_datums]);

for (uint16_t i = 0; i < 32; i++) {
for (uint16_t j = 0; j < 32; j++) {
host_data[i * 32 + j] = i;
}
}

auto input_buffer = ttnn::allocate_buffer_on_device(input_buf_size_datums * datum_size_bytes, device, input_shape, DataType::UINT16, Layout::TILE, mem_cfg);
auto input_storage = tt::tt_metal::DeviceStorage{input_buffer};
Tensor input_tensor = Tensor(input_storage, input_shape, DataType::UINT16, Layout::TILE);
ttnn::write_buffer(io_cq, input_tensor, {host_data});

ttnn::Tensor output_tensor = ttnn::repeat_interleave(input_tensor, repeats, dim);

ttnn::read_buffer(io_cq, output_tensor, {readback_data});

tt::log_debug("input_data: \n {}", input_tensor.write_to_string());
tt::log_debug("readback_data: \n {}", output_tensor.write_to_string());

for (int i = 0; i < input_buf_size_datums; i++) {
auto input_value = host_data[i];
for(int r = 0; r < repeats; r++) {
auto value = readback_data[i + r * input_buf_size_datums];
ASSERT_EQ(input_value, value);
}
}

input_tensor.deallocate();
output_tensor.deallocate();
}

struct RepeatInterleaveParams {
int repeats = 0;
int dim = 0;
};

class RepeatInterleaveTest : public ttnn::TTNNFixtureWithDevice, public ::testing::WithParamInterface<RepeatInterleaveParams> {};

TEST_P(RepeatInterleaveTest, RunsCorrectly) {
RepeatInterleaveParams params = GetParam();
run_repeat_interleave_test(device_, params.repeats, params.dim);
}

INSTANTIATE_TEST_SUITE_P(
RepeatInterleaveWithDim0,
RepeatInterleaveTest,
::testing::Values(
RepeatInterleaveParams{1, 0},
RepeatInterleaveParams{2, 0},
RepeatInterleaveParams{3, 0}
)
);

// tests/ttnn/unit_tests/operations/test_repeat_interleave.py proves that it should work over dim 1 too
// likely need to fix the comparison in the test
INSTANTIATE_TEST_SUITE_P(
DISABLED_RepeatInterleaveWithDim1,
RepeatInterleaveTest,
::testing::Values(
RepeatInterleaveParams{1, 1},
RepeatInterleaveParams{2, 1},
RepeatInterleaveParams{3, 1}
)
);


} // namespace test
} // namespace binary
} // namespace operations
} // namespace ttnn
19 changes: 19 additions & 0 deletions tests/ttnn/unit_tests/gtests/ttnn_test_fixtures.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@

#include "gtest/gtest.h"

#include "ttnn/device.hpp"
#include "tests/tt_metal/test_utils/env_vars.hpp"

namespace ttnn {

class TTNNFixture : public ::testing::Test {
Expand All @@ -26,4 +29,20 @@ class TTNNFixture : public ::testing::Test {

void TearDown() override { tt::Cluster::instance().set_internal_routing_info_for_ethernet_cores(false); }
};

class TTNNFixtureWithDevice : public TTNNFixture {
protected:
tt::tt_metal::Device* device_ = nullptr;

void SetUp() override {
TTNNFixture::SetUp();
device_ = tt::tt_metal::CreateDevice(0);
}

void TearDown() override {
TTNNFixture::TearDown();
tt::tt_metal::CloseDevice(device_);
}
};

} // namespace ttnn
17 changes: 10 additions & 7 deletions tests/ttnn/unit_tests/operations/test_repeat_interleave.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,29 @@
import torch

import ttnn
from loguru import logger

from tests.ttnn.utils_for_testing import assert_with_pcc


@pytest.mark.skip(reason="ttnn.repeat_interleave only supports repeat over dim 0 or 1")
def test_repeat_interleave(device):
torch_input_tensor = torch.tensor([[1, 2], [3, 4]])
torch_result = torch.repeat_interleave(torch_input_tensor, 2, dim=0)
@pytest.mark.parametrize("repeats", [1, 2, 3])
@pytest.mark.parametrize("dim", [0, 1, 2, 3])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
def test_repeat_interleave(device, repeats, dim, dtype):
torch_input_tensor = torch.rand(1, 1, 32, 32, dtype=dtype)
torch_result = torch.repeat_interleave(torch_input_tensor, repeats, dim=dim)

input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)

output = ttnn.repeat_interleave(input_tensor, 2, dim=0)
output = ttnn.repeat_interleave(input_tensor, repeats, dim=dim)
output = ttnn.to_torch(output)

assert_with_pcc(torch_result, output, 0.9999)


@pytest.mark.skip(reason="ttnn.repeat_interleave only supports repeat over dim 0 or 1")
@pytest.mark.skip(reason="ttnn.repeat_interleave only supports `repeats` as int")
def test_repeat_interleave_with_repeat_tensor(device):
torch_input_tensor = torch.tensor([[1, 2], [3, 4]], dtype=torch.bfloat16)
torch_input_tensor = torch.rand(1, 2, 32, 32, dtype=torch.bfloat16)
torch_repeats = torch.tensor([1, 2])
torch_result = torch.repeat_interleave(torch_input_tensor, torch_repeats, dim=1)
input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)
Expand Down
2 changes: 1 addition & 1 deletion ttnn/cpp/pybind11/operations/binary.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ void bind_binary_operation(py::module& module, const binary_operation_t& operati
* :attr:`activations` (Optional[List[str]]): list of activation functions to apply to the output tensor
* :attr:`queue_id` (Optional[uint8]): command queue id
Example::
Example:
>>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device)
>>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.tensor((0, 1), dtype=torch.bfloat16)), device)
Expand Down
2 changes: 1 addition & 1 deletion ttnn/cpp/pybind11/operations/ccl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ void py_module(py::module& module) {
* :attr:`num_links` (int): Number of links to use for the all-gather operation.
* :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation.
Example::
Example:
>>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device)
>>> output = ttnn.all_gather(tensor, dim=0)
Expand Down
7 changes: 4 additions & 3 deletions ttnn/cpp/pybind11/operations/core.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ void py_module(py::module& module) {
* :attr:`memory_config`: the desired MemoryConfig
* :attr:`dtype`: the optional `ttnn` data type.
Example::
>>> device_id = 0
>>> device = ttnn.open_device(device_id=device_id)
>>> tensor = ttnn.to_device(ttnn.from_torch(torch.randn((10, 64, 32), dtype=torch.bfloat16)), device)
Expand All @@ -133,7 +133,7 @@ void py_module(py::module& module) {
* :attr:`tensor`: the ttnn.Tensor
* :attr:`dtype`: `ttnn` data type.
Example::
Example:
>>> tensor = ttnn.from_torch(torch.randn((10, 64, 32), dtype=torch.bfloat16))
>>> tensor = ttnn.to_dtype(tensor, dtype=ttnn.uint16)
)doc",
Expand Down Expand Up @@ -252,7 +252,8 @@ void py_module(py::module& module) {
* :attr:`dtype`: the optional output data type.
* :attr:`memory_config`: the optional output memory configuration.
* :attr:`device`: Device/DeviceMesh whose worker thread on host should be used for the layout conversion
Example::
Example:
>>> device_id = 0
>>> device = ttnn.open_device(device_id=device_id)
>>> tensor = ttnn.to_device(ttnn.from_torch(torch.randn((10, 64, 32), dtype=torch.bfloat16)), device)
Expand Down
43 changes: 39 additions & 4 deletions ttnn/cpp/pybind11/operations/data_movement.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ Permutes :attr:`input_tensor` using :attr:`order`.
* :attr:`input_tensor`: the input tensor
* :attr:`order`: the desired ordering of dimensions.
Example::
Example:
>>> tensor = ttnn.to_device(ttnn.from_torch(torch.zeros((1, 1, 64, 32), dtype=torch.bfloat16)), device)
>>> output = ttnn.permute(tensor, (0, 1, 3, 2))
Expand All @@ -53,7 +53,7 @@ Concats :attr:`tensors` in the given :attr:`dim`.
Keyword Args:
* :attr:`memory_config`: the memory configuration to use for the operation
Example::
Example:
>>> tensor = ttnn.concat(ttnn.from_torch(torch.zeros((1, 1, 64, 32), ttnn.from_torch(torch.zeros((1, 1, 64, 32), dim=3)), device)
Expand All @@ -79,7 +79,9 @@ The algorithms available for upsampling are 'nearest' for now.
* :attr:`scale_factor`: multiplier for spatial size. Has to match input size if it is a tuple.
)doc",
ttnn::pybind_arguments_t{
py::arg("input_tensor"), py::arg("scale_factor"), py::arg("memory_config") = std::nullopt});
py::arg("input_tensor"),
py::arg("scale_factor"),
py::arg("memory_config") = std::nullopt});

ttnn::bind_registered_operation(
module,
Expand All @@ -96,7 +98,7 @@ Returns a new tensor filled with repetition of input :attr:`input_tensor` accord
Keyword Args:
* :attr:`memory_config`: the memory configuration to use for the operation
Example::
Example:
>>> tensor = ttnn.repeat(ttnn.from_torch(torch.tensor([[1, 2], [3, 4]]), 2,)), device)
>>> print(tensor)
Expand All @@ -107,6 +109,39 @@ Example::
)doc",
ttnn::pybind_arguments_t{
py::arg("input_tensor"), py::arg("shape"), py::kw_only(), py::arg("memory_config") = std::nullopt});

ttnn::bind_registered_operation(
module,
ttnn::repeat_interleave,
R"doc(
repeat_interleave(input_tensor: ttnn.Tensor, repeats : int, dim: int = 0) -> ttnn.Tensor
Repeats elements of a :attr:`tensor` in the given :attr:`dim`.
Args:
* :attr:`input_tensor`: the input_tensor to apply the repeate interleave operation.
* :attr:`repeats`: The number of repetitions for each element. repeats is broadcasted to fit the shape of the given axis.
* :attr:`dim`: the dimension to expand with the repetitions.
Example:
torch_input_tensor =
torch_result = torch.repeat_interleave(torch_input_tensor, repeats, dim=dim)
input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)
output = ttnn.repeat_interleave(input_tensor, repeats, dim=dim)
>>> a = ttnn.from_torch(torch.rand(1, 1, 32, 32, dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device)
>>> b = ttnn.repeat_interleave(a, 2, dim=0)
>>> print(a.shape, b.shape)
ttnn.Shape([1, 1, 32, 32]) ttnn.Shape([2, 1, 32, 32])
)doc",
ttnn::pybind_arguments_t{
py::arg("input_tensor"),
py::arg("repeats"),
py::arg("dim"),
py::kw_only(),
py::arg("memory_config") = std::nullopt});
}

} // namespace data_movement
Expand Down
2 changes: 1 addition & 1 deletion ttnn/cpp/pybind11/operations/embedding.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ void py_module(py::module& module) {
* :attr:`layout`: the layout of the input and output tensors. Default is ttnn.ROW_MAJOR_LAYOUT.
* :attr:`memory_config`: the memory configuration of the output tensor. Default is input tensor memory config.
Example::
Example:
>>> device_id = 0
>>> device = ttnn.open_device(device_id=device_id)
>>> input_tensor = ttnn.to_device(ttnn.from_torch(torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]]), dtype=ttnn.uint32), device)
Expand Down
2 changes: 1 addition & 1 deletion ttnn/cpp/pybind11/operations/normalization.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ void py_module(py::module& module) {
Keyword Args:
* :attr:`memory_config`: the memory configuration for the output tensor. If not provided, the memory configuration of the input tensor is used.
Example::
Example:
>>> tensor = ttnn.to_device(ttnn.from_torch(torch.zeros((1, 1, 64, 32), dtype=torch.bfloat16)), device)
>>> output = ttnn.softmax(tensor, -1)
Expand Down
2 changes: 1 addition & 1 deletion ttnn/cpp/pybind11/operations/pool.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ void bind_global_avg_pool2d(py::module& module) {
Returns:
ttnn.Tensor: The tensor with the averaged values. The output tensor shape is (batch_size, channels, 1, 1).
Example::
Example:
>>> tensor = ttnn.from_torch(torch.randn((10, 3, 32, 32), dtype=ttnn.bfloat16), device=device)
>>> output = {1}(tensor)
Expand Down
8 changes: 4 additions & 4 deletions ttnn/cpp/pybind11/operations/unary.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ void bind_unary_operation(py::module& module, const unary_operation_t& operation
Keyword Args:
* :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation.
Example::
Example:
>>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device)
>>> output = {1}(tensor)
Expand Down Expand Up @@ -67,7 +67,7 @@ void bind_unary_operation_with_fast_and_approximate_mode(py::module& module, con
* :attr:`fast_and_approximate_mode` (bool): "Use fast and approximate mode".
* :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation.
Example::
Example:
>>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device)
>>> output = {1}(tensor, fast_and_approximate_mode=true)
Expand Down Expand Up @@ -108,7 +108,7 @@ void bind_unary_operation_with_float_parameter(
* :attr:`{2}` (bool): {3}.
* :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation.
Example::
Example:
>>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device)
>>> output = {1}(tensor, {2}=true)
Expand Down Expand Up @@ -147,7 +147,7 @@ void bind_softplus(py::module& module) {
* :attr:`threshold` (float): Used to switch to a linear function for large values to improve numerical stability. This avoids issues with floating-point representation for very large values
* :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation.
Example::
Example:
>>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device)
>>> output = {1}(tensor, parameter=true)
Expand Down
Loading

0 comments on commit 3082585

Please sign in to comment.