From 2eaf9ec94e4dade1690be78c74f7f928fe46a8eb Mon Sep 17 00:00:00 2001 From: Akhmed Rakhmati Date: Wed, 31 Jan 2024 20:33:56 +0000 Subject: [PATCH] #4003: moved ttnn.add to C++ --- tests/ttnn/unit_tests/operations/test_add.py | 63 ++--- tt_eager/tensor/tensor.hpp | 31 ++- tt_eager/tensor/types.hpp | 229 ++++++++++++++++-- tt_eager/tt_lib/csrc/ttnn/module.hpp | 10 +- .../tt_lib/csrc/ttnn/operations/binary.hpp | 121 +++++++++ .../tt_lib/csrc/ttnn/operations/module.hpp | 24 ++ tt_eager/tt_lib/csrc/ttnn/tensor/module.hpp | 190 --------------- tt_eager/tt_lib/csrc/ttnn/types.hpp | 89 +++++++ ttnn/ttnn/__init__.py | 2 +- ttnn/ttnn/core.py | 5 +- ttnn/ttnn/decorators.py | 3 +- ttnn/ttnn/operations/binary.py | 105 +------- ttnn/ttnn/operations/core.py | 10 +- ttnn/ttnn/operations/data_movement.py | 4 +- ttnn/ttnn/types.py | 24 +- ttnn/ttnn/validation.py | 2 +- 16 files changed, 527 insertions(+), 385 deletions(-) create mode 100644 tt_eager/tt_lib/csrc/ttnn/operations/binary.hpp create mode 100644 tt_eager/tt_lib/csrc/ttnn/operations/module.hpp delete mode 100644 tt_eager/tt_lib/csrc/ttnn/tensor/module.hpp create mode 100644 tt_eager/tt_lib/csrc/ttnn/types.hpp diff --git a/tests/ttnn/unit_tests/operations/test_add.py b/tests/ttnn/unit_tests/operations/test_add.py index 75c5b5b4228..6ee43a38147 100644 --- a/tests/ttnn/unit_tests/operations/test_add.py +++ b/tests/ttnn/unit_tests/operations/test_add.py @@ -30,21 +30,6 @@ def test_add_1D_tensor_and_scalar(device, scalar, size): assert output_tensor.shape == (size,) -@pytest.mark.parametrize("alpha", [0.42]) -@pytest.mark.parametrize("scalar_input_tensor_b", [0.5]) -@pytest.mark.parametrize("h", [1]) -@pytest.mark.parametrize("w", [4]) -def test_add_scalar_and_alpha(device, alpha, scalar_input_tensor_b, h, w): - torch_input_tensor = torch.rand((h, w), dtype=torch.bfloat16) - torch_output_tensor = torch.add(torch_input_tensor, scalar_input_tensor_b, alpha=alpha) - - input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device) - output_tensor = ttnn.add(input_tensor, scalar_input_tensor_b, alpha=alpha) - output_tensor = ttnn.to_torch(output_tensor) - - assert_with_pcc(torch_output_tensor, output_tensor, 0.99999) - - @pytest.mark.parametrize("h", [32]) @pytest.mark.parametrize("w", [64]) def test_add_2D_tensors(device, h, w): @@ -106,44 +91,44 @@ def test_add_4D_tensors(device, h, w): @pytest.mark.parametrize("h", [32]) @pytest.mark.parametrize("w", [64]) def test_add_with_broadcast(device, h, w): - torch_a = torch.rand((2, 16, 1, w), dtype=torch.bfloat16) - torch_b = torch.rand((2, 16, h, w), dtype=torch.bfloat16) - torch_output = torch.add(torch_a, torch_b) + torch_input_tensor_a = torch.rand((2, 16, 1, w), dtype=torch.bfloat16) + torch_input_tensor_b = torch.rand((2, 16, h, w), dtype=torch.bfloat16) + torch_output_tensor = torch.add(torch_input_tensor_a, torch_input_tensor_b) - a = ttnn.from_torch(torch_a, layout=ttnn.TILE_LAYOUT, device=device) - b = ttnn.from_torch(torch_b, layout=ttnn.TILE_LAYOUT, device=device) - tt_output = ttnn.add(a, b) - tt_output = ttnn.to_torch(tt_output) + input_tensor_a = ttnn.from_torch(torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device) + input_tensor_b = ttnn.from_torch(torch_input_tensor_b, layout=ttnn.TILE_LAYOUT, device=device) + output_tensor = ttnn.add(input_tensor_a, input_tensor_b) + output_tensor = ttnn.to_torch(output_tensor) - assert_with_pcc(torch_output, tt_output, 0.9999) + assert_with_pcc(torch_output_tensor, output_tensor, 0.9999) @pytest.mark.parametrize("h", 
[500]) @pytest.mark.parametrize("w", [512]) def test_expand_and_broadcast(device, h, w): - torch_a = torch.rand((1, h, w), dtype=torch.bfloat16) - torch_b = torch.rand((h, w), dtype=torch.bfloat16) - torch_output = torch.add(torch_a, torch_b) + torch_input_tensor_a = torch.rand((1, h, w), dtype=torch.bfloat16) + torch_input_tensor_b = torch.rand((h, w), dtype=torch.bfloat16) + torch_output_tensor = torch.add(torch_input_tensor_a, torch_input_tensor_b) - a = ttnn.from_torch(torch_a, layout=ttnn.TILE_LAYOUT, device=device) - b = ttnn.from_torch(torch_b, layout=ttnn.TILE_LAYOUT, device=device) - tt_output = ttnn.add(a, b) - tt_output = ttnn.to_torch(tt_output) + input_tensor_a = ttnn.from_torch(torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device) + input_tensor_b = ttnn.from_torch(torch_input_tensor_b, layout=ttnn.TILE_LAYOUT, device=device) + output_tensor = ttnn.add(input_tensor_a, input_tensor_b) + output_tensor = ttnn.to_torch(output_tensor) - assert_with_pcc(torch_output, tt_output, 0.9999) + assert_with_pcc(torch_output_tensor, output_tensor, 0.9999) @pytest.mark.skip(reason="4005: Unable to broadcast on batch or seq dimension") @pytest.mark.parametrize("h", [32]) @pytest.mark.parametrize("w", [64]) def test_add_with_broadcast_on_batch(device, h, w): - torch_a = torch.rand((1, 16, 1, w), dtype=torch.bfloat16) - torch_b = torch.rand((2, 16, h, w), dtype=torch.bfloat16) - torch_output = torch.add(torch_a, torch_b) + torch_input_tensor_a = torch.rand((1, 16, 1, w), dtype=torch.bfloat16) + torch_input_tensor_b = torch.rand((2, 16, h, w), dtype=torch.bfloat16) + torch_output_tensor = torch.add(torch_input_tensor_a, torch_input_tensor_b) - a = ttnn.from_torch(torch_a, layout=ttnn.TILE_LAYOUT, device=device) - b = ttnn.from_torch(torch_b, layout=ttnn.TILE_LAYOUT, device=device) - tt_output = ttnn.add(a, b) - tt_output = ttnn.to_torch(tt_output) + input_tensor_a = ttnn.from_torch(torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device) + input_tensor_b = ttnn.from_torch(torch_input_tensor_b, layout=ttnn.TILE_LAYOUT, device=device) + output_tensor = ttnn.add(input_tensor_a, input_tensor_b) + output_tensor = ttnn.to_torch(output_tensor) - assert_with_pcc(torch_output, tt_output, 0.9999) + assert_with_pcc(torch_output_tensor, output_tensor, 0.9999) diff --git a/tt_eager/tensor/tensor.hpp b/tt_eager/tensor/tensor.hpp index 2bfd6282997..7d2c26d5b05 100644 --- a/tt_eager/tensor/tensor.hpp +++ b/tt_eager/tensor/tensor.hpp @@ -4,19 +4,19 @@ #pragma once -#include #include #include #include +#include +#include -#include "tensor/types.hpp" -#include "tt_metal/impl/device/device.hpp" -#include "tt_metal/impl/buffers/buffer.hpp" -#include "common/test_tiles.hpp" -#include "common/tt_backend_api_types.hpp" #include "common/bfloat16.hpp" #include "common/bfloat8.hpp" - +#include "common/test_tiles.hpp" +#include "common/tt_backend_api_types.hpp" +#include "tensor/types.hpp" +#include "tt_metal/impl/buffers/buffer.hpp" +#include "tt_metal/impl/device/device.hpp" #include "tt_metal/tt_stl/reflection.hpp" namespace tt { @@ -139,3 +139,20 @@ void memcpy(Tensor &dst, const Tensor &src); } // namespace tt_metal } // namespace tt + +namespace ttnn { +namespace types { + +struct Tensor { + const tt::tt_metal::Tensor value; + const ttnn::Shape shape; + + explicit Tensor(tt::tt_metal::Tensor &&tensor) : value{tensor}, shape{ttnn::Shape(tensor.shape())} {} + explicit Tensor(const tt::tt_metal::Tensor &tensor) : value{tensor}, shape{ttnn::Shape(tensor.shape())} {} +}; + +} // namespace types + +using 
types::Tensor;
+
+} // namespace ttnn
diff --git a/tt_eager/tensor/types.hpp b/tt_eager/tensor/types.hpp
index 64f57ba6b78..1ad43825b05 100644
--- a/tt_eager/tensor/types.hpp
+++ b/tt_eager/tensor/types.hpp
@@ -118,26 +118,21 @@ class Shape {
     explicit Shape(const Shape&, const Padding&);
 
     template <std::size_t Rank>
-    explicit Shape(
-        const std::array<uint32_t, Rank>& shape,
-        const std::optional<std::array<uint32_t, Rank>>& padded_shape = std::nullopt) :
-        rank_(Rank), dimensions_{}, padding_{Rank} {
-        if (padded_shape.has_value()) {
-            TT_ASSERT(shape.size() == padded_shape.value().size());
-            for (auto index = 0; index < Rank; index++) {
-                auto padded_dimension = padded_shape.value()[index];
-                this->dimensions_[index] = padded_dimension;
-                this->padding_[index] = {.front = 0, .back = padded_dimension - shape[index]};
-            }
-        } else {
-            for (auto index = 0; index < Rank; index++) {
-                this->dimensions_[index] = shape[index];
-            }
+    Shape(const std::array<uint32_t, Rank> &shape) : rank_(Rank), dimensions_{}, padding_{Rank} {
+        for (auto index = 0; index < Rank; index++) {
+            this->dimensions_[index] = shape[index];
         }
     }
 
-    // Add an implicit constructor from 4D array due to legacy code
-    Shape(const std::array<uint32_t, 4>& shape) : Shape(shape, std::optional<std::array<uint32_t, 4>>{std::nullopt}) {}
+    template <std::size_t Rank>
+    explicit Shape(const std::array<uint32_t, Rank> &shape, const std::array<uint32_t, Rank> &padded_shape) :
+        rank_(Rank), dimensions_{}, padding_{Rank} {
+        for (auto index = 0; index < Rank; index++) {
+            auto padded_dimension = padded_shape[index];
+            this->dimensions_[index] = padded_dimension;
+            this->padding_[index] = {.front = 0, .back = padded_dimension - shape[index]};
+        }
+    }
 
     std::size_t rank() const;
 
@@ -282,3 +277,203 @@ bool operator!=(const ShardSpec& spec_a, const ShardSpec& spec_b);
 
 } // namespace tt_metal
 } // namespace tt
+
+namespace ttnn {
+namespace types {
+
+namespace detail {
+template <std::size_t Rank>
+static tt::tt_metal::Shape compute_ttl_shape(
+    const std::array<uint32_t, Rank> &shape, const std::array<std::array<uint32_t, 2>, Rank> &padding) {
+    auto ttl_shape = std::array<uint32_t, Rank>{};
+    for (auto index = 0; index < Rank; index++) {
+        ttl_shape[index] = shape[index] + padding[index][0] + padding[index][1];
+    }
+    return tt::tt_metal::Shape{
+        tt::tt_metal::Shape{ttl_shape}, tt::tt_metal::Padding{padding, tt::tt_metal::Padding::PadValue::Any}};
+}
+
+} // namespace detail
+
+template <std::size_t Rank>
+struct RankedShape {
+    const std::size_t rank;
+    const tt::tt_metal::Shape value;
+
+    explicit RankedShape(tt::tt_metal::Shape &&shape) : rank{Rank}, value(shape) {}
+    explicit RankedShape(const tt::tt_metal::Shape &shape) : rank{Rank}, value(shape) {}
+
+    explicit RankedShape(const std::array<uint32_t, Rank> &shape) : rank{Rank}, value{shape} {}
+
+    explicit RankedShape(const std::array<uint32_t, Rank> &shape, const std::array<uint32_t, Rank> &padded_shape) :
+        rank{Rank}, value{shape, padded_shape} {}
+
+    explicit RankedShape(
+        const std::array<uint32_t, Rank> &shape, const std::array<std::array<uint32_t, 2>, Rank> &padding) :
+        rank{Rank}, value{detail::compute_ttl_shape(shape, padding)} {}
+
+    RankedShape padded() const {
+        return RankedShape{tt::tt_metal::Shape{this->value, tt::tt_metal::Padding{this->value.rank()}}};
+    }
+
+    RankedShape operator+(const std::array<std::array<uint32_t, 2>, Rank> &padding) const {
+        auto shape = this->value;
+        const auto &current_padding = this->value.padding();
+        auto accumulated_padding = padding;
+        for (auto index = 0; index < Rank; index++) {
+            shape[index] += padding[index][0] + padding[index][1];
+            accumulated_padding[index][0] += current_padding[index].front;
+            accumulated_padding[index][1] += current_padding[index].back;
+        }
+        return RankedShape{tt::tt_metal::Shape{
+            shape, tt::tt_metal::Padding{accumulated_padding, tt::tt_metal::Padding::PadValue::Any}}};
+    }
+
+    template <std::size_t OtherRank>
+    RankedShape operator+(const std::array<std::array<uint32_t, 2>, OtherRank> &padding) const {
+        TT_THROW("Invalid padding");
+    }
+
+    bool operator==(const RankedShape &other) const { return this->value == other.value; }
+
+    template <std::size_t OtherRank>
+    bool operator==(const RankedShape<OtherRank> &other) const {
+        return false;
+    }
+
+    const auto &operator[](std::int64_t index) const { return this->value.without_padding()[index]; }
+};
+
+template <std::size_t Rank>
+static std::ostream &operator<<(std::ostream &os, const RankedShape<Rank> &self) {
+    os << "ttnn.Shape([";
+    const auto shape = self.value.without_padding();
+    const auto &padding = self.value.padding();
+    const auto &padded_shape = self.value;
+    for (auto i = 0; i < Rank; ++i) {
+        if (i > 0) {
+            os << ", ";
+        }
+        if (padding[i].front > 0) {
+            os << padding[i].front << " + ";
+        }
+        os << shape[i];
+        if (padding[i].back > 0) {
+            os << " + " << padding[i].back;
+        }
+    }
+    os << "])";
+    return os;
+}
+
+struct Shape {
+    using RankedShapeVariant = std::variant<
+        const RankedShape<1>,
+        const RankedShape<2>,
+        const RankedShape<3>,
+        const RankedShape<4>,
+        const RankedShape<5>,
+        const RankedShape<6>,
+        const RankedShape<7>,
+        const RankedShape<8>>;
+
+    const RankedShapeVariant ranked_shape;
+
+   private:
+    RankedShapeVariant ttl_shape_to_ttnn_shape(const tt::tt_metal::Shape &shape) {
+        switch (shape.rank()) {
+            case 1: return RankedShape<1>{shape};
+            case 2: return RankedShape<2>{shape};
+            case 3: return RankedShape<3>{shape};
+            case 4: return RankedShape<4>{shape};
+            case 5: return RankedShape<5>{shape};
+            case 6: return RankedShape<6>{shape};
+            case 7: return RankedShape<7>{shape};
+            case 8: return RankedShape<8>{shape};
+        };
+        TT_THROW("Unsupported rank");
+    }
+
+   public:
+    explicit Shape(const tt::tt_metal::Shape &shape) : ranked_shape{ttl_shape_to_ttnn_shape(shape)} {}
+
+    template <std::size_t Rank>
+    explicit Shape(const RankedShape<Rank> &shape) : ranked_shape{shape} {}
+
+    template <std::size_t Rank>
+    explicit Shape(const std::array<uint32_t, Rank> &shape) : ranked_shape{RankedShape{shape}} {}
+
+    template <std::size_t Rank>
+    explicit Shape(const std::array<uint32_t, Rank> &shape, const std::array<uint32_t, Rank> &padded_shape) :
+        ranked_shape{RankedShape{shape, padded_shape}} {}
+
+    template <std::size_t Rank>
+    explicit Shape(const std::array<uint32_t, Rank> &shape, const std::array<std::array<uint32_t, 2>, Rank> &padding) :
+        ranked_shape{RankedShape{shape, padding}} {}
+
+    const auto rank() const {
+        return std::visit(
+            []<std::size_t Rank>(const RankedShape<Rank> &shape) -> const auto { return Rank; }, this->ranked_shape);
+    }
+
+    Shape padded() const {
+        return std::visit([](const auto &shape) -> Shape { return Shape(shape.padded()); }, this->ranked_shape);
+    }
+
+    template <std::size_t Rank>
+    Shape operator+(const std::array<std::array<uint32_t, 2>, Rank> &padding) const {
+        return std::visit(
+            [&padding](const auto &shape) -> Shape { return Shape(shape + padding); }, this->ranked_shape);
+    }
+
+    bool operator==(const Shape &other) const {
+        return std::visit(
+            [](const auto &shape, const auto &other) -> bool { return shape == other; },
+            this->ranked_shape,
+            other.ranked_shape);
+    }
+
+    const auto &operator[](std::int64_t index) const {
+        return std::visit([index](const auto &shape) -> decltype(auto) { return shape[index]; }, this->ranked_shape);
+    }
+
+    const auto &value() const {
+        return std::visit([](const auto &shape) -> const auto & { return shape.value; }, this->ranked_shape);
+    }
+
+    template <std::size_t NewRank>
+    const Shape to_rank() const {
+        return std::visit(
+            []<std::size_t Rank>(const RankedShape<Rank> &shape) {
+                if constexpr (Rank == NewRank) {
+                    return Shape(shape);
+                } else {
+                    auto num_missing_dims = NewRank - Rank;
+
+                    std::array<uint32_t, NewRank> new_shape{};
+                    std::array<uint32_t, NewRank> new_padded_shape{};
+
+                    new_shape.fill(1);
+                    new_padded_shape.fill(1);
+
+                    for (auto index = 0; index < Rank; index++) {
+                        new_shape[index + num_missing_dims] = shape[index];
+                        new_padded_shape[index + num_missing_dims] = shape.padded()[index];
+                    }
+                    return Shape(RankedShape(new_shape, new_padded_shape));
+                }
+            },
+            this->ranked_shape);
+    }
+};
+
+static std::ostream &operator<<(std::ostream &os, const Shape &self) {
+    std::visit([&os](const auto &shape) { os << shape; }, self.ranked_shape);
+    return os;
+}
+
+} // namespace types
+
+using types::Shape;
+
+} // namespace ttnn
diff --git a/tt_eager/tt_lib/csrc/ttnn/module.hpp b/tt_eager/tt_lib/csrc/ttnn/module.hpp
index 5255e2ef974..101c1bc0b2f 100644
--- a/tt_eager/tt_lib/csrc/ttnn/module.hpp
+++ b/tt_eager/tt_lib/csrc/ttnn/module.hpp
@@ -6,15 +6,19 @@
 
 #include <pybind11/pybind11.h>
 
-#include "tensor/module.hpp"
+#include "operations/module.hpp"
+#include "types.hpp"
 
 namespace py = pybind11;
 
 namespace ttnn {
 
 void py_module(py::module& m_ttnn) {
-    auto m_tensor = m_ttnn.def_submodule("tensor", "Tensor");
-    tensor::py_module(m_tensor);
+    auto m_types = m_ttnn.def_submodule("types", "ttnn Types");
+    types::py_module(m_types);
+
+    auto m_operations = m_ttnn.def_submodule("operations", "ttnn Operations");
+    operations::py_module(m_operations);
 }
 
 } // namespace ttnn
diff --git a/tt_eager/tt_lib/csrc/ttnn/operations/binary.hpp b/tt_eager/tt_lib/csrc/ttnn/operations/binary.hpp
new file mode 100644
index 00000000000..c85b182744e
--- /dev/null
+++ b/tt_eager/tt_lib/csrc/ttnn/operations/binary.hpp
@@ -0,0 +1,121 @@
+// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include <pybind11/pybind11.h>
+
+#include "tt_dnn/op_library/eltwise_binary/eltwise_binary_op.hpp"
+
+namespace py = pybind11;
+
+namespace ttnn {
+
+static const auto DRAM_MEMORY_CONFIG = tt::tt_metal::MemoryConfig{
+    .memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED, .buffer_type = tt::tt_metal::BufferType::DRAM};
+static const auto L1_MEMORY_CONFIG = tt::tt_metal::MemoryConfig{
+    .memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED, .buffer_type = tt::tt_metal::BufferType::L1};
+
+ttnn::Tensor reshape(const ttnn::Tensor& tensor, const ttnn::Shape& shape) {
+    return ttnn::Tensor(tensor.value.reshape(shape.value()));
+}
+
+ttnn::Tensor unsqueeze_to_4D(const ttnn::Tensor& tensor) {
+    const auto& tensor_shape = tensor.shape;
+    const auto rank = tensor_shape.rank();
+    if (rank == 4) {
+        return tensor;
+    }
+    if (rank > 4) {
+        TT_THROW("Tensor rank is greater than 4");
+    }
+
+    const auto tensor_shape_4D = tensor_shape.to_rank<4>();
+    return ttnn::reshape(tensor, tensor_shape_4D);
+}
+
+namespace operations {
+namespace binary {
+
+void py_module(py::module& m_binary) {
+    m_binary.def(
+        "add",
+        [](const ttnn::Tensor& input_tensor_a,
+           const ttnn::Tensor& input_tensor_b,
+           const tt::tt_metal::MemoryConfig& memory_config) {
+            const auto& original_shape = input_tensor_a.shape;
+            const auto& input_shape_b = input_tensor_b.shape;
+
+            std::size_t height_b{};
+            std::size_t width_b{};
+            if (input_shape_b.rank() == 1) {
+                height_b = 1;
+                width_b = input_shape_b[-1];
+            } else {
+                height_b = input_shape_b[-2];
+                width_b = input_shape_b[-1];
+            }
+
+            auto input_tensor_a_4D = ttnn::unsqueeze_to_4D(input_tensor_a);
+            auto input_tensor_b_4D = ttnn::unsqueeze_to_4D(input_tensor_b);
+
+            const auto& ttl_input_tensor_a = input_tensor_a_4D.value;
+            const auto& ttl_input_tensor_b = input_tensor_b_4D.value;
+
+            if (height_b == 1 or width_b == 1) {
+                tt::tt_metal::BcastOpDim bcast_op_dim;
+                if (height_b == 1 and width_b == 1) {
+                    bcast_op_dim = tt::tt_metal::BcastOpDim::HW;
+                } else if (height_b == 1) {
+                    bcast_op_dim = tt::tt_metal::BcastOpDim::H;
+                } else if (width_b == 1) {
+                    bcast_op_dim = tt::tt_metal::BcastOpDim::W;
+                } else {
+                    TT_THROW("Invalid broadcasting dimensions");
+                }
+                auto ttl_output = tt::tt_metal::bcast(
+                    ttl_input_tensor_a,
+                    ttl_input_tensor_b,
+                    tt::tt_metal::BcastOpMath::ADD,
+                    bcast_op_dim,
+                    memory_config);
+                auto output = ttnn::Tensor(ttl_output);
+                return ttnn::reshape(output, original_shape);
+            } else {
+                auto ttl_output =
+                    tt::tt_metal::add(ttl_input_tensor_a, ttl_input_tensor_b, std::nullopt, memory_config);
+                auto output = ttnn::Tensor(ttl_output);
+                return ttnn::reshape(output, original_shape);
+            }
+        },
+        py::arg("input_tensor_a"),
+        py::arg("input_tensor_b"),
+        py::kw_only(),
+        py::arg("memory_config") = DRAM_MEMORY_CONFIG
+
+    );
+
+    m_binary.def(
+        "add",
+        [](const ttnn::Tensor& input_tensor_a,
+           const float input_tensor_b,
+           const tt::tt_metal::MemoryConfig& memory_config) {
+            const auto& original_shape = input_tensor_a.shape;
+
+            auto input_tensor_a_4D = ttnn::unsqueeze_to_4D(input_tensor_a);
+            const auto& ttl_input_tensor_a = input_tensor_a_4D.value;
+
+            auto ttl_output = tt::tt_metal::add_unary(ttl_input_tensor_a, input_tensor_b, memory_config);
+            auto output = ttnn::Tensor(ttl_output);
+            return ttnn::reshape(output, original_shape);
+        },
+        py::arg("input_tensor_a"),
+        py::arg("input_tensor_b"),
+        py::kw_only(),
+        py::arg("memory_config") = DRAM_MEMORY_CONFIG);
+}
+
+} // namespace binary
+} // namespace operations
+} // namespace ttnn
diff --git a/tt_eager/tt_lib/csrc/ttnn/operations/module.hpp b/tt_eager/tt_lib/csrc/ttnn/operations/module.hpp
new file mode 100644
index 00000000000..b37200eaba9
--- /dev/null
+++ b/tt_eager/tt_lib/csrc/ttnn/operations/module.hpp
@@ -0,0 +1,24 @@
+// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include <pybind11/pybind11.h>
+
+#include "binary.hpp"
+
+namespace py = pybind11;
+
+namespace ttnn {
+
+namespace operations {
+
+void py_module(py::module& m_operations) {
+    auto m_binary = m_operations.def_submodule("binary", "binary operations");
+    binary::py_module(m_binary);
+}
+
+} // namespace operations
+
+} // namespace ttnn
diff --git a/tt_eager/tt_lib/csrc/ttnn/tensor/module.hpp b/tt_eager/tt_lib/csrc/ttnn/tensor/module.hpp
deleted file mode 100644
index a0980bdf066..00000000000
--- a/tt_eager/tt_lib/csrc/ttnn/tensor/module.hpp
+++ /dev/null
@@ -1,190 +0,0 @@
-// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
-//
-// SPDX-License-Identifier: Apache-2.0
-
-#pragma once
-
-#include <pybind11/pybind11.h>
-
-#include "tensor/tensor.hpp"
-
-namespace py = pybind11;
-
-namespace ttnn {
-namespace tensor {
-
-namespace detail {
-template <std::size_t Rank>
-tt::tt_metal::Shape compute_ttl_shape(
-    const std::array<uint32_t, Rank>& shape, const std::array<std::array<uint32_t, 2>, Rank>& padding) {
-    auto ttl_shape = std::array<uint32_t, Rank>{};
-    for (auto index = 0; index < Rank; index++) {
-        ttl_shape[index] = shape[index] + padding[index][0] + padding[index][1];
-    }
-    return tt::tt_metal::Shape{
-        tt::tt_metal::Shape{ttl_shape}, tt::tt_metal::Padding{padding, tt::tt_metal::Padding::PadValue::Any}};
-}
-}  // namespace detail
-
-struct Shape {
-    const std::size_t rank;
-    const tt::tt_metal::Shape value;
-
-    explicit Shape(tt::tt_metal::Shape&& shape) : rank{shape.rank()}, value(shape) {}
-    explicit Shape(const tt::tt_metal::Shape& shape) : rank{shape.rank()}, value(shape) {}
-
-    template <std::size_t Rank>
-    explicit Shape(
-        const std::array<uint32_t, Rank>& shape, const std::optional<std::array<uint32_t, Rank>>& padded_shape) :
-        rank{Rank}, value{shape, padded_shape} {}
-
-    template <std::size_t Rank>
-    explicit Shape(const std::array<uint32_t, Rank>& shape, const std::array<std::array<uint32_t, 2>, Rank>& padding) :
-        rank{Rank}, value{detail::compute_ttl_shape(shape, padding)} {}
-
-    Shape padded() const { return Shape{tt::tt_metal::Shape{this->value, Padding{this->value.rank()}}}; }
-
-    template <std::size_t Rank>
-    Shape operator+(const std::array<std::array<uint32_t, 2>, Rank>& padding) const {
-        auto shape = this->value;
-        const auto& current_padding = this->value.padding();
-        auto accumulated_padding = padding;
-        for (auto index = 0; index < Rank; index++) {
-            shape[index] += padding[index][0] + padding[index][1];
-            accumulated_padding[index][0] += current_padding[index].front;
-            accumulated_padding[index][1] += current_padding[index].back;
-        }
-        return Shape{tt::tt_metal::Shape{
-            shape, tt::tt_metal::Padding{accumulated_padding, tt::tt_metal::Padding::PadValue::Any}}};
-    }
-
-    bool operator==(const Shape& other) const { return this->value == other.value; }
-};
-
-std::ostream& operator<<(std::ostream& os, const Shape& self) {
-    os << "ttnn.Shape([";
-    const auto shape = self.value.without_padding();
-    const auto& padding = self.value.padding();
-    const auto& padded_shape = self.value;
-    for (auto i = 0; i < shape.rank(); ++i) {
-        if (i > 0) {
-            os << ", ";
-        }
-        if (padding[i].front > 0) {
-            os << padding[i].front << " + ";
-        }
-        os << shape[i];
-        if (padding[i].back > 0) {
-            os << " + " << padding[i].back;
-        }
-    }
-    os << "])";
-    return os;
-}
-
-struct Tensor {
-    const tt::tt_metal::Tensor value;
-
-    explicit Tensor(tt::tt_metal::Tensor&& tensor) : value{tensor} {}
-    explicit Tensor(const tt::tt_metal::Tensor& tensor) : value{tensor} {}
-};
-
-void py_module(py::module& m_tensor) {
-    py::class_<Shape>(m_tensor, "Shape")
-        .def(py::init<tt::tt_metal::Shape>())
-        .def(
-            py::init<const std::array<uint32_t, 1>&, const std::array<std::array<uint32_t, 2>, 1>&>(),
-            py::arg("shape"),
-            py::arg("padding"))
-        .def(
-            py::init<const std::array<uint32_t, 2>&, const std::array<std::array<uint32_t, 2>, 2>&>(),
-            py::arg("shape"),
-            py::arg("padding"))
-        .def(
-            py::init<const std::array<uint32_t, 3>&, const std::array<std::array<uint32_t, 2>, 3>&>(),
-            py::arg("shape"),
-            py::arg("padding"))
-        .def(
-            py::init<const std::array<uint32_t, 4>&, const std::array<std::array<uint32_t, 2>, 4>&>(),
-            py::arg("shape"),
-            py::arg("padding"))
-        .def(
-            py::init<const std::array<uint32_t, 1>&, const std::optional<std::array<uint32_t, 1>>&>(),
-            py::arg("shape"),
-            py::arg("padded_shape") = std::nullopt)
-        .def(
-            py::init<const std::array<uint32_t, 2>&, const std::optional<std::array<uint32_t, 2>>&>(),
-            py::arg("shape"),
-            py::arg("padded_shape") = std::nullopt)
-        .def(
-            py::init<const std::array<uint32_t, 3>&, const std::optional<std::array<uint32_t, 3>>&>(),
-            py::arg("shape"),
-            py::arg("padded_shape") = std::nullopt)
-        .def(
-            py::init<const std::array<uint32_t, 4>&, const std::optional<std::array<uint32_t, 4>>&>(),
-            py::arg("shape"),
-            py::arg("padded_shape") = std::nullopt)
-        .def(
-            "__add__",
-            [](const Shape& self, const std::array<std::array<uint32_t, 2>, 1>& padding) { return self + padding; })
-        .def(
-            "__add__",
-            [](const Shape& self, const std::array<std::array<uint32_t, 2>, 2>& padding) { return self + padding; })
-        .def(
-            "__add__",
-            [](const Shape& self, const std::array<std::array<uint32_t, 2>, 3>& padding) { return self + padding; })
-        .def(
-            "__add__",
-            [](const Shape& self, const std::array<std::array<uint32_t, 2>, 4>& padding) { return self + padding; })
-        .def_property_readonly("value", [](const Shape& self) { return self.value; })
-        .def("__len__", [](const Shape& self) { return self.value.rank(); })
-        .def(
-            "__getitem__",
-            [](const Shape& self, std::int64_t index) {
-                auto shape = self.value.without_padding();
-                return shape[index];
-            })
-        .def("__iter__", [](const Shape& self) { return py::iter(py::cast(self.value.without_padding())); })
-        .def("__eq__", [](const Shape& self, const Shape& other) { return self == other; })
-        .def(
-            "__eq__",
-            [](const Shape& self, const std::array<uint32_t, 1>& other) {
-                return Shape{self.value.without_padding()} == Shape{tt::tt_metal::Shape{other}};
-            })
-        .def(
-            "__eq__",
-            [](const Shape& self, const std::array<uint32_t, 2>& other) {
-                return Shape{self.value.without_padding()} == Shape{tt::tt_metal::Shape{other}};
-            })
-        .def(
-            "__eq__",
-            [](const Shape& self, const std::array<uint32_t, 3>& other) {
-                return Shape{self.value.without_padding()} == Shape{tt::tt_metal::Shape{other}};
-            })
-        .def(
-            "__eq__",
-            [](const Shape& self, const std::array<uint32_t, 4>& other) {
-                return Shape{self.value.without_padding()} == Shape{tt::tt_metal::Shape{other}};
-            })
-        .def("__eq__", [](const Shape& self, const py::none) { return false; })
-        .def(
-            "__repr__",
-            [](const Shape& self) {
-                std::stringstream ss;
-                ss << self;
-                return ss.str();
-            })
-        .def_property_readonly("rank", [](const Shape& self) { return self.rank; })
-        .def("padded", [](const Shape& self) { return self.padded(); });
-
-    py::class_<Tensor>(m_tensor, "Tensor")
-        .def(py::init<tt::tt_metal::Tensor>())
-        .def_property_readonly("value", [](const Tensor& self) -> auto& { return self.value; })
-        .def("__repr__", [](const Tensor& self) { return self.value.write_to_string(Layout::ROW_MAJOR, true); })
-        .def_property_readonly("shape", [](const Tensor& self) { return py::cast(Shape{self.value.shape()}); })
-        .def_property_readonly("dtype", [](const Tensor& self) { return self.value.dtype(); })
-        .def_property_readonly("layout", [](const Tensor& self) { return self.value.layout(); });
-}
-
-}  // namespace tensor
-}  // namespace ttnn
diff --git a/tt_eager/tt_lib/csrc/ttnn/types.hpp b/tt_eager/tt_lib/csrc/ttnn/types.hpp
new file mode 100644
index 00000000000..417a5d12b37
--- /dev/null
+++ b/tt_eager/tt_lib/csrc/ttnn/types.hpp
@@ -0,0 +1,89 @@
+// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include <pybind11/pybind11.h>
+
+#include "tensor/tensor.hpp"
+
+namespace py = pybind11;
+
+namespace ttnn {
+namespace types {
+
+void py_module(py::module& m_types) {
+    auto PyShape = py::class_<Shape>(m_types, "Shape");
+    PyShape.def(py::init<tt::tt_metal::Shape>());
+
+    [&PyShape]<std::size_t... Ns>(std::index_sequence<Ns...>) {
+        (
+            [&PyShape]() {
+                if constexpr (Ns > 0) {
+                    PyShape.def(py::init<const std::array<uint32_t, Ns>&>(), py::arg("shape"));
+
+                    PyShape.def(
+                        py::init<const std::array<uint32_t, Ns>&, const std::array<std::array<uint32_t, 2>, Ns>&>(),
+                        py::arg("shape"),
+                        py::arg("padding"));
+
+                    PyShape.def(
+                        py::init<const std::array<uint32_t, Ns>&, std::array<uint32_t, Ns>&>(),
+                        py::arg("shape"),
+                        py::arg("padded_shape"));
+
+                    PyShape.def(
+                        "__add__", [](const Shape& self, const std::array<std::array<uint32_t, 2>, Ns>& padding) {
+                            return self + padding;
+                        });
+
+                    PyShape.def("__eq__", [](const Shape& self, const std::array<uint32_t, Ns>& other) {
+                        return Shape{self.value().without_padding()} == Shape{tt::tt_metal::Shape{other}};
+                    });
+                }
+            }(),
+            ...);
+    }(std::make_index_sequence<8>());
+
+    PyShape.def_property_readonly("value", [](const Shape& self) { return self.value(); });
+    PyShape.def("__len__", [](const Shape& self) { return self.rank(); });
+    PyShape.def("__getitem__", [](const Shape& self, std::int64_t index) { return self[index]; });
+    PyShape.def("__iter__", [](const Shape& self) { return py::iter(py::cast(self.value().without_padding())); });
+    PyShape.def("__eq__", [](const Shape& self, const Shape& other) { return self == other; });
+    PyShape.def("__eq__", [](const Shape& self, const py::none) { return false; });
+    PyShape.def("__repr__", [](const Shape& self) {
+        std::stringstream ss;
+        ss << self;
+        return ss.str();
+    });
+    PyShape.def_property_readonly("rank", [](const Shape& self) -> std::size_t { return self.rank(); });
+    PyShape.def("padded", [](const Shape& self) { return self.padded(); });
+
+    py::class_<Tensor>(m_types, "Tensor")
+        .def(py::init<tt::tt_metal::Tensor>())
+        .def_property_readonly("value", [](const Tensor& self) -> auto& { return self.value; })
+        .def("__repr__", [](const Tensor& self) { return self.value.write_to_string(Layout::ROW_MAJOR, true); })
+        .def_property_readonly("shape", [](const Tensor& self) { return py::cast(Shape{self.value.shape()}); })
+        .def_property_readonly("dtype", [](const Tensor& self) { return self.value.dtype(); })
+        .def_property_readonly("layout", [](const Tensor& self) { return self.value.layout(); })
+        .def_property_readonly(
+            "device",
+            [](const Tensor& self) -> Device* {
+                if (self.value.storage_type() == tt::tt_metal::StorageType::DEVICE) {
+                    return self.value.device();
+                } else {
+                    throw std::runtime_error("Tensor is not on device!");
+                }
+            })
+        .def("is_contiguous", [](const Tensor& self) -> bool {
+            if (self.value.layout() == tt::tt_metal::Layout::ROW_MAJOR) {
+                return self.value.shape() == self.value.shape().without_padding();
+            } else {
+                return false;
+            }
+        });
+}
+
+} // namespace types
+} // namespace ttnn
diff --git a/ttnn/ttnn/__init__.py b/ttnn/ttnn/__init__.py
index 3207cfeb6f1..df5ca653428 100644
--- a/ttnn/ttnn/__init__.py
+++ b/ttnn/ttnn/__init__.py
@@ -29,10 +29,10 @@
     DEVICE_STORAGE_TYPE,
     Shape,
     Tensor,
-    has_storage_type_of,
 )
 
 from ttnn.core import (
+    has_storage_type_of,
     has_padding,
     is_sharded,
     get_memory_config,
diff --git a/ttnn/ttnn/core.py b/ttnn/ttnn/core.py
index 394d5942501..ce55e766078 100644
--- a/ttnn/ttnn/core.py
+++ b/ttnn/ttnn/core.py
@@ -7,7 +7,6 @@
 import tt_lib as ttl
 
 from ttnn.types import (
-    has_storage_type_of,
     DEVICE_STORAGE_TYPE,
     MemoryConfig,
     ShardStrategy,
@@ -18,6 +17,10 @@
 )
 
 
+def has_storage_type_of(tensor: "ttnn.Tensor", storage_type) -> bool:
+    return 
tensor.value.storage_type() == storage_type + + def is_sharded(tensor) -> bool: return tensor.value.is_sharded() diff --git a/ttnn/ttnn/decorators.py b/ttnn/ttnn/decorators.py index d7d79996cae..ad45b78d382 100644 --- a/ttnn/ttnn/decorators.py +++ b/ttnn/ttnn/decorators.py @@ -184,7 +184,8 @@ def call_wrapper(*function_args, **function_kwargs): decorated_function = validate_decorator(decorated_function) if ENABLE_DEBUG_DECORATOR: decorated_function = debug_decorator(decorated_function) - return decorated_function(*function_args, **function_kwargs) + output = decorated_function(*function_args, **function_kwargs) + return output return call_wrapper diff --git a/ttnn/ttnn/operations/binary.py b/ttnn/ttnn/operations/binary.py index 9df7c474951..abca0b25f1d 100644 --- a/ttnn/ttnn/operations/binary.py +++ b/ttnn/ttnn/operations/binary.py @@ -140,16 +140,15 @@ def add( input_tensor_a: ttnn.Tensor, input_tensor_b: Union[ttnn.Tensor, int, float], *, - alpha: Union[int, float] = 1, memory_config: ttnn.MemoryConfig = ttnn.DRAM_MEMORY_CONFIG, ) -> ttnn.Tensor: r""" - add(input_tensor_a: ttnn.Tensor, input_tensor_b: Union[ttnn.Tensor, int, float], *, alpha: Union[int, float]=1, memory_config: ttnn.MemoryConfig = ttnn.DRAM_MEMORY_CONFIG) -> ttnn.Tensor: + add(input_tensor_a: ttnn.Tensor, input_tensor_b: Union[ttnn.Tensor, int, float], *, memory_config: ttnn.MemoryConfig = ttnn.DRAM_MEMORY_CONFIG) -> ttnn.Tensor: - Adds :attr:`input_tensor_b`, scaled by :attr:`alpha`, to :attr:`input_tensor_a` and returns the tensor with the same layout as :attr:`input_tensor_a` + Adds :attr:`input_tensor_a` to :attr:`input_tensor_b` and returns the tensor with the same layout as :attr:`input_tensor_a` .. math:: - \mathrm{{input\_tensor\_a}}_i + \mathrm{{alpha}} \times \mathrm{{input\_tensor\_b}}_i + \mathrm{{input\_tensor\_a}}_i + \mathrm{{input\_tensor\_b}}_i Supports broadcasting. @@ -158,110 +157,22 @@ def add( * :attr:`input_tensor_b` (ttnn.Tensor or Number): the tensor or number to add to :attr:`input_tensor_a`. Keyword args: - :attr:`alpha` (Number): the multiplier for :attr:`input_tensor_b`. 
+ :attr:`memory_config` (ttnn.MemoryConfig): memory config for the output tensor Example:: >>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device) >>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.tensor((0, 1), dtype=torch.bfloat16)), device) - >>> output = ttnn.add(tensor1, tensor2, alpha=2) + >>> output = ttnn.add(tensor1, tensor2) >>> print(output) - ttnn.Tensor([ 1, 4], dtype=bfloat16 ) + ttnn.Tensor([ 1, 3], dtype=bfloat16 ) """ - - if not isinstance(input_tensor_a, ttnn.Tensor): - raise TypeError("Expected first argument to be a ttnn.Tensor") - # Swap tensors if input_tensor_a needs to be broadcasted to input_tensor_b - if ( - isinstance(input_tensor_a, ttnn.Tensor) - and isinstance(input_tensor_b, ttnn.Tensor) - and math.prod(input_tensor_a.shape) < math.prod(input_tensor_b.shape) - ): + if isinstance(input_tensor_b, ttnn.Tensor) and math.prod(input_tensor_a.shape) < math.prod(input_tensor_b.shape): input_tensor_a, input_tensor_b = input_tensor_b, input_tensor_a - original_shape = input_tensor_a.shape - input_tensor_a = ttnn.unsqueeze_to_4D(input_tensor_a) - ttl_input_tensor_a = input_tensor_a.value - - if not ttnn.has_storage_type_of(input_tensor_a, ttl.tensor.StorageType.DEVICE): - raise RuntimeError("input_tensor_a must be on device!") - - if _is_scalar(input_tensor_b): - output_tensor = ttnn.Tensor( - ttl.tensor.add_unary( - ttl_input_tensor_a, - input_tensor_b * alpha, - output_mem_config=memory_config, - ) - ) - return ttnn.reshape(output_tensor, original_shape) - elif isinstance(input_tensor_b, ttnn.Tensor): - input_shape_b = input_tensor_b.shape - - if len(input_shape_b) == 1: - height_b = 1 - (width_b,) = input_shape_b - else: - *_, height_b, width_b = input_shape_b - - input_tensor_b = ttnn.unsqueeze_to_4D(input_tensor_b) - ttl_input_tensor_b = input_tensor_b.value - if ttl_input_tensor_b.storage_type() != ttl.tensor.StorageType.DEVICE: - raise RuntimeError("input_tensor_a must be on device!") - else: - raise TypeError("Expected second argument to be a ttnn.Tensor or a scalar") - - ttl_input_tensor_b = input_tensor_b.value - if alpha != 1: - ttl_input_tensor_b = ttl.tensor.mul_unary( - ttl_input_tensor_b, - alpha, - output_mem_config=memory_config, - ) - - if height_b == 1 and width_b == 1: - output_tensor = ttnn.Tensor( - ttl.tensor.bcast( - ttl_input_tensor_a, - ttl_input_tensor_b, - ttl.tensor.BcastOpMath.ADD, - ttl.tensor.BcastOpDim.HW, - output_mem_config=memory_config, - ) - ) - elif height_b == 1: - output_tensor = ttnn.Tensor( - ttl.tensor.bcast( - ttl_input_tensor_a, - ttl_input_tensor_b, - ttl.tensor.BcastOpMath.ADD, - ttl.tensor.BcastOpDim.H, - output_mem_config=memory_config, - ) - ) - elif width_b == 1: - output_tensor = ttnn.Tensor( - ttl.tensor.bcast( - ttl_input_tensor_a, - ttl_input_tensor_b, - ttl.tensor.BcastOpMath.ADD, - ttl.tensor.BcastOpDim.W, - output_mem_config=memory_config, - ) - ) - else: - output_tensor = ttnn.Tensor( - ttl.tensor.add( - ttl_input_tensor_a, - ttl_input_tensor_b, - output_mem_config=memory_config, - ) - ) - - output_tensor = ttnn.reshape(output_tensor, original_shape) - return output_tensor + return ttl.ttnn.operations.binary.add(input_tensor_a, input_tensor_b, memory_config=memory_config) def _sub_validate_input_tensors(operation_name, input_tensor_a, input_tensor_b, *args, **kwargs): diff --git a/ttnn/ttnn/operations/core.py b/ttnn/ttnn/operations/core.py index 07fc780b7b5..5b0c07f0126 100644 --- a/ttnn/ttnn/operations/core.py +++ b/ttnn/ttnn/operations/core.py @@ -190,13 +190,15 @@ 
def unsqueeze_to_4D(tensor): raise RuntimeError("Tensor cannot have more than 4 dimensions!") num_missing_dims = 4 - len(tensor.shape) shape = tuple(tensor.shape) - full_shape = tuple(tensor.shape.padded()) + padded_shape = tuple(tensor.shape.padded()) shape = (1,) * num_missing_dims + shape - full_shape = (1,) * num_missing_dims + full_shape - return ttnn.reshape(tensor, shape=ttnn.Shape(shape, full_shape)) + padded_shape = (1,) * num_missing_dims + padded_shape + return ttnn.reshape(tensor, shape=ttnn.Shape(shape, padded_shape)) -def squeeze(tensor): +def squeeze(tensor, dim): + if dim != 0: + raise RuntimeError("Only dim=0 is supported for squeeze operation!") if len(tensor.shape) == 1: return tensor if len(tensor.shape) > 4: diff --git a/ttnn/ttnn/operations/data_movement.py b/ttnn/ttnn/operations/data_movement.py index 5b9343967b2..2af847c981a 100644 --- a/ttnn/ttnn/operations/data_movement.py +++ b/ttnn/ttnn/operations/data_movement.py @@ -144,7 +144,7 @@ def permute(input_tensor: ttnn.Tensor, order: Tuple[int, ...]) -> ttnn.Tensor: rank_should_be_updated = len(output_tensor.shape) > rank while rank_should_be_updated: prior_rank = len(output_tensor.shape) - output_tensor = ttnn.squeeze(output_tensor) + output_tensor = ttnn.squeeze(output_tensor, dim=0) rank_should_be_updated = prior_rank != len(output_tensor.shape) and len(output_tensor.shape) > rank if on_device and not ttnn.has_storage_type_of(output_tensor, ttnn.DEVICE_STORAGE_TYPE): @@ -260,7 +260,7 @@ def convert_to_ttl_tensor(tensor): rank_should_be_updated = len(output_tensor.shape) > rank while rank_should_be_updated: prior_rank = len(output_tensor.shape) - output_tensor = ttnn.squeeze(output_tensor) + output_tensor = ttnn.squeeze(output_tensor, dim=0) rank_should_be_updated = prior_rank != len(output_tensor.shape) and len(output_tensor.shape) > rank return output_tensor else: diff --git a/ttnn/ttnn/types.py b/ttnn/ttnn/types.py index 55b8f64a1c7..298991176fc 100644 --- a/ttnn/ttnn/types.py +++ b/ttnn/ttnn/types.py @@ -34,30 +34,10 @@ TILE_SIZE = 32 -Shape = ttl.ttnn.tensor.Shape +Shape = ttl.ttnn.types.Shape -class Cpu: - ... - - -def has_storage_type_of(tensor: "ttnn.Tensor", storage_type) -> bool: - return tensor.value.storage_type() == storage_type - - -class Tensor(ttl.ttnn.tensor.Tensor): - @property - def device(self: "Tensor") -> DataType: - if has_storage_type_of(self, DEVICE_STORAGE_TYPE): - return self.value.device() - else: - return Cpu() - - def is_contiguous(self: "Shape") -> bool: - if self.layout == ROW_MAJOR_LAYOUT: - return self.value.shape() == self.value.shape_without_padding() - else: - return False +Tensor = ttl.ttnn.types.Tensor class ShardStrategy(Enum): diff --git a/ttnn/ttnn/validation.py b/ttnn/ttnn/validation.py index 1c1c62dd164..928af55693b 100644 --- a/ttnn/ttnn/validation.py +++ b/ttnn/ttnn/validation.py @@ -9,8 +9,8 @@ Layout, DEVICE_STORAGE_TYPE, Tensor, - has_storage_type_of, ) +from ttnn.core import has_storage_type_of def validate_input_tensor(