#4003: moved ttnn.add to C++
arakhmati committed Feb 5, 2024
1 parent cc5403e commit 2eaf9ec
Showing 16 changed files with 527 additions and 385 deletions.
63 changes: 24 additions & 39 deletions tests/ttnn/unit_tests/operations/test_add.py
@@ -30,21 +30,6 @@ def test_add_1D_tensor_and_scalar(device, scalar, size):
    assert output_tensor.shape == (size,)


@pytest.mark.parametrize("alpha", [0.42])
@pytest.mark.parametrize("scalar_input_tensor_b", [0.5])
@pytest.mark.parametrize("h", [1])
@pytest.mark.parametrize("w", [4])
def test_add_scalar_and_alpha(device, alpha, scalar_input_tensor_b, h, w):
torch_input_tensor = torch.rand((h, w), dtype=torch.bfloat16)
torch_output_tensor = torch.add(torch_input_tensor, scalar_input_tensor_b, alpha=alpha)

input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)
output_tensor = ttnn.add(input_tensor, scalar_input_tensor_b, alpha=alpha)
output_tensor = ttnn.to_torch(output_tensor)

assert_with_pcc(torch_output_tensor, output_tensor, 0.99999)


@pytest.mark.parametrize("h", [32])
@pytest.mark.parametrize("w", [64])
def test_add_2D_tensors(device, h, w):
@@ -106,44 +91,44 @@ def test_add_4D_tensors(device, h, w):
@pytest.mark.parametrize("h", [32])
@pytest.mark.parametrize("w", [64])
def test_add_with_broadcast(device, h, w):
-    torch_a = torch.rand((2, 16, 1, w), dtype=torch.bfloat16)
-    torch_b = torch.rand((2, 16, h, w), dtype=torch.bfloat16)
-    torch_output = torch.add(torch_a, torch_b)
+    torch_input_tensor_a = torch.rand((2, 16, 1, w), dtype=torch.bfloat16)
+    torch_input_tensor_b = torch.rand((2, 16, h, w), dtype=torch.bfloat16)
+    torch_output_tensor = torch.add(torch_input_tensor_a, torch_input_tensor_b)

-    a = ttnn.from_torch(torch_a, layout=ttnn.TILE_LAYOUT, device=device)
-    b = ttnn.from_torch(torch_b, layout=ttnn.TILE_LAYOUT, device=device)
-    tt_output = ttnn.add(a, b)
-    tt_output = ttnn.to_torch(tt_output)
+    input_tensor_a = ttnn.from_torch(torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device)
+    input_tensor_b = ttnn.from_torch(torch_input_tensor_b, layout=ttnn.TILE_LAYOUT, device=device)
+    output_tensor = ttnn.add(input_tensor_a, input_tensor_b)
+    output_tensor = ttnn.to_torch(output_tensor)

-    assert_with_pcc(torch_output, tt_output, 0.9999)
+    assert_with_pcc(torch_output_tensor, output_tensor, 0.9999)


@pytest.mark.parametrize("h", [500])
@pytest.mark.parametrize("w", [512])
def test_expand_and_broadcast(device, h, w):
-    torch_a = torch.rand((1, h, w), dtype=torch.bfloat16)
-    torch_b = torch.rand((h, w), dtype=torch.bfloat16)
-    torch_output = torch.add(torch_a, torch_b)
+    torch_input_tensor_a = torch.rand((1, h, w), dtype=torch.bfloat16)
+    torch_input_tensor_b = torch.rand((h, w), dtype=torch.bfloat16)
+    torch_output_tensor = torch.add(torch_input_tensor_a, torch_input_tensor_b)

-    a = ttnn.from_torch(torch_a, layout=ttnn.TILE_LAYOUT, device=device)
-    b = ttnn.from_torch(torch_b, layout=ttnn.TILE_LAYOUT, device=device)
-    tt_output = ttnn.add(a, b)
-    tt_output = ttnn.to_torch(tt_output)
+    input_tensor_a = ttnn.from_torch(torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device)
+    input_tensor_b = ttnn.from_torch(torch_input_tensor_b, layout=ttnn.TILE_LAYOUT, device=device)
+    output_tensor = ttnn.add(input_tensor_a, input_tensor_b)
+    output_tensor = ttnn.to_torch(output_tensor)

-    assert_with_pcc(torch_output, tt_output, 0.9999)
+    assert_with_pcc(torch_output_tensor, output_tensor, 0.9999)


@pytest.mark.skip(reason="4005: Unable to broadcast on batch or seq dimension")
@pytest.mark.parametrize("h", [32])
@pytest.mark.parametrize("w", [64])
def test_add_with_broadcast_on_batch(device, h, w):
-    torch_a = torch.rand((1, 16, 1, w), dtype=torch.bfloat16)
-    torch_b = torch.rand((2, 16, h, w), dtype=torch.bfloat16)
-    torch_output = torch.add(torch_a, torch_b)
+    torch_input_tensor_a = torch.rand((1, 16, 1, w), dtype=torch.bfloat16)
+    torch_input_tensor_b = torch.rand((2, 16, h, w), dtype=torch.bfloat16)
+    torch_output_tensor = torch.add(torch_input_tensor_a, torch_input_tensor_b)

-    a = ttnn.from_torch(torch_a, layout=ttnn.TILE_LAYOUT, device=device)
-    b = ttnn.from_torch(torch_b, layout=ttnn.TILE_LAYOUT, device=device)
-    tt_output = ttnn.add(a, b)
-    tt_output = ttnn.to_torch(tt_output)
+    input_tensor_a = ttnn.from_torch(torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device)
+    input_tensor_b = ttnn.from_torch(torch_input_tensor_b, layout=ttnn.TILE_LAYOUT, device=device)
+    output_tensor = ttnn.add(input_tensor_a, input_tensor_b)
+    output_tensor = ttnn.to_torch(output_tensor)

-    assert_with_pcc(torch_output, tt_output, 0.9999)
+    assert_with_pcc(torch_output_tensor, output_tensor, 0.9999)
31 changes: 24 additions & 7 deletions tt_eager/tensor/tensor.hpp
@@ -4,19 +4,19 @@

#pragma once

-#include <vector>
+#include <array>
+#include <random>
+#include <tuple>
+#include <variant>
+#include <vector>

-#include "tensor/types.hpp"
-#include "tt_metal/impl/device/device.hpp"
-#include "tt_metal/impl/buffers/buffer.hpp"
-#include "common/test_tiles.hpp"
-#include "common/tt_backend_api_types.hpp"
#include "common/bfloat16.hpp"
#include "common/bfloat8.hpp"
+#include "common/test_tiles.hpp"
+#include "common/tt_backend_api_types.hpp"
+#include "tensor/types.hpp"
+#include "tt_metal/impl/buffers/buffer.hpp"
+#include "tt_metal/impl/device/device.hpp"
+#include "tt_metal/tt_stl/reflection.hpp"

namespace tt {
@@ -139,3 +139,20 @@ void memcpy(Tensor &dst, const Tensor &src);
} // namespace tt_metal

} // namespace tt

+namespace ttnn {
+namespace types {
+
+struct Tensor {
+    const tt::tt_metal::Tensor value;
+    const ttnn::Shape shape;
+
+    explicit Tensor(tt::tt_metal::Tensor &&tensor) : value{tensor}, shape{ttnn::Shape(tensor.shape())} {}
+    explicit Tensor(const tt::tt_metal::Tensor &tensor) : value{tensor}, shape{ttnn::Shape(tensor.shape())} {}
+};
+
+} // namespace types
+
+using types::Tensor;
+
+} // namespace ttnn
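The new `ttnn::types::Tensor` pairs the eager `tt::tt_metal::Tensor` with a pre-computed `ttnn::Shape`; both members are `const`, so a wrapped tensor is immutable once constructed. Below is a minimal, self-contained sketch of the same wrapper idiom; `MetalTensor` and `WrappedTensor` are hypothetical stand-ins, not the real tt-metal types:

```cpp
// Illustrative sketch only: MetalTensor stands in for tt::tt_metal::Tensor.
#include <cstdint>
#include <iostream>
#include <vector>

struct MetalTensor {
    std::vector<uint32_t> dims;
    const std::vector<uint32_t> &shape() const { return dims; }
};

// Mirrors ttnn::types::Tensor: an immutable value plus a shape cached at construction.
struct WrappedTensor {
    const MetalTensor value;            // owning copy of the underlying tensor
    const std::vector<uint32_t> shape;  // cached once, never recomputed

    explicit WrappedTensor(const MetalTensor &tensor) : value(tensor), shape(tensor.shape()) {}
};

int main() {
    MetalTensor t{{1, 1, 32, 64}};
    WrappedTensor wrapped{t};
    for (auto dim : wrapped.shape) {
        std::cout << dim << ' ';
    }
    std::cout << '\n';  // prints: 1 1 32 64
}
```

Caching the shape at construction lets callers query it without going back through the wrapped tensor, which fits the value-type treatment the struct above gives the underlying tensor.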
229 changes: 212 additions & 17 deletions tt_eager/tensor/types.hpp
@@ -118,26 +118,21 @@ class Shape {
    explicit Shape(const Shape&, const Padding&);

    template <std::size_t Rank>
-    explicit Shape(
-        const std::array<uint32_t, Rank>& shape,
-        const std::optional<std::array<uint32_t, Rank>>& padded_shape = std::nullopt) :
-        rank_(Rank), dimensions_{}, padding_{Rank} {
-        if (padded_shape.has_value()) {
-            TT_ASSERT(shape.size() == padded_shape.value().size());
-            for (auto index = 0; index < Rank; index++) {
-                auto padded_dimension = padded_shape.value()[index];
-                this->dimensions_[index] = padded_dimension;
-                this->padding_[index] = {.front = 0, .back = padded_dimension - shape[index]};
-            }
-        } else {
-            for (auto index = 0; index < Rank; index++) {
-                this->dimensions_[index] = shape[index];
-            }
+    Shape(const std::array<uint32_t, Rank> &shape) : rank_(Rank), dimensions_{}, padding_{Rank} {
+        for (auto index = 0; index < Rank; index++) {
+            this->dimensions_[index] = shape[index];
        }
    }

-    // Add an implicit constructor from 4D array due to legacy code
-    Shape(const std::array<uint32_t, 4>& shape) : Shape(shape, std::optional<std::array<uint32_t, 4>>{std::nullopt}) {}
+    template <std::size_t Rank>
+    explicit Shape(const std::array<uint32_t, Rank> &shape, const std::array<uint32_t, Rank> &padded_shape) :
+        rank_(Rank), dimensions_{}, padding_{Rank} {
+        for (auto index = 0; index < Rank; index++) {
+            auto padded_dimension = padded_shape[index];
+            this->dimensions_[index] = padded_dimension;
+            this->padding_[index] = {.front = 0, .back = padded_dimension - shape[index]};
+        }
+    }

    std::size_t rank() const;

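The hunk above replaces the old optional-padded-shape constructor with two unambiguous overloads: the one-array form stores the shape as given, while the two-array form stores the padded sizes in `dimensions_` and records `padded - logical` as back padding. A small sketch of that bookkeeping, using hypothetical names (`PaddedDims`, `dims`, `back_padding`) rather than the real class internals:

```cpp
// Sketch of the padded-shape bookkeeping, not the real tt::tt_metal::Shape.
#include <array>
#include <cstddef>
#include <cstdint>
#include <iostream>

template <std::size_t Rank>
struct PaddedDims {
    std::array<uint32_t, Rank> dims{};          // padded sizes, as stored in dimensions_
    std::array<uint32_t, Rank> back_padding{};  // back = padded - logical, as in padding_

    PaddedDims(const std::array<uint32_t, Rank> &shape, const std::array<uint32_t, Rank> &padded_shape) {
        for (std::size_t index = 0; index < Rank; index++) {
            dims[index] = padded_shape[index];
            back_padding[index] = padded_shape[index] - shape[index];
        }
    }
};

int main() {
    // Logical 30x60 shape padded up to 32x64 tiles.
    PaddedDims<2> p({30, 60}, {32, 64});
    std::cout << p.dims[0] << "x" << p.dims[1] << ", back padding "
              << p.back_padding[0] << "," << p.back_padding[1] << '\n';  // 32x64, back padding 2,4
}
```

Note that the one-array overload stays non-explicit, generalizing the old implicit 4D-only constructor to any rank.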
@@ -282,3 +277,203 @@ bool operator!=(const ShardSpec& spec_a, const ShardSpec& spec_b);
} // namespace tt_metal

} // namespace tt

+namespace ttnn {
+namespace types {
+
+namespace detail {
+template <std::size_t Rank>
+static tt::tt_metal::Shape compute_ttl_shape(
+    const std::array<uint32_t, Rank> &shape, const std::array<std::array<uint32_t, 2>, Rank> &padding) {
+    auto ttl_shape = std::array<uint32_t, Rank>{};
+    for (auto index = 0; index < Rank; index++) {
+        ttl_shape[index] = shape[index] + padding[index][0] + padding[index][1];
+    }
+    return tt::tt_metal::Shape{
+        tt::tt_metal::Shape{ttl_shape}, tt::tt_metal::Padding{padding, tt::tt_metal::Padding::PadValue::Any}};
+}
+
+} // namespace detail
+
+template <std::size_t Rank>
+struct RankedShape {
+    const std::size_t rank;
+    const tt::tt_metal::Shape value;
+
+    explicit RankedShape(tt::tt_metal::Shape &&shape) : rank{Rank}, value(shape) {}
+    explicit RankedShape(const tt::tt_metal::Shape &shape) : rank{Rank}, value(shape) {}
+
+    explicit RankedShape(const std::array<uint32_t, Rank> &shape) : rank{Rank}, value{shape} {}
+
+    explicit RankedShape(const std::array<uint32_t, Rank> &shape, const std::array<uint32_t, Rank> &padded_shape) :
+        rank{Rank}, value{shape, padded_shape} {}
+
+    explicit RankedShape(
+        const std::array<uint32_t, Rank> &shape, const std::array<std::array<uint32_t, 2>, Rank> &padding) :
+        rank{Rank}, value{detail::compute_ttl_shape(shape, padding)} {}
+
+    RankedShape<Rank> padded() const {
+        return RankedShape{tt::tt_metal::Shape{this->value, tt::tt_metal::Padding{this->value.rank()}}};
+    }
+
+    RankedShape<Rank> operator+(const std::array<std::array<uint32_t, 2>, Rank> &padding) const {
+        auto shape = this->value;
+        const auto &current_padding = this->value.padding();
+        auto accumulated_padding = padding;
+        for (auto index = 0; index < Rank; index++) {
+            shape[index] += padding[index][0] + padding[index][1];
+            accumulated_padding[index][0] += current_padding[index].front;
+            accumulated_padding[index][1] += current_padding[index].back;
+        }
+        return RankedShape<Rank>{tt::tt_metal::Shape{
+            shape, tt::tt_metal::Padding{accumulated_padding, tt::tt_metal::Padding::PadValue::Any}}};
+    }
+
+    template <std::size_t OtherRank>
+    RankedShape<Rank> operator+(const std::array<std::array<uint32_t, 2>, OtherRank> &padding) const {
+        TT_THROW("Invalid padding");
+    }
+
+    bool operator==(const RankedShape<Rank> &other) const { return this->value == other.value; }
+
+    template <std::size_t OtherRank>
+    bool operator==(const RankedShape<OtherRank> &other) const {
+        return false;
+    }
+
+    const auto &operator[](std::int64_t index) const { return this->value.without_padding()[index]; }
+};
+
+template <std::size_t Rank>
+static std::ostream &operator<<(std::ostream &os, const RankedShape<Rank> &self) {
+    os << "ttnn.Shape([";
+    const auto shape = self.value.without_padding();
+    const auto &padding = self.value.padding();
+    const auto &padded_shape = self.value;
+    for (auto i = 0; i < Rank; ++i) {
+        if (i > 0) {
+            os << ", ";
+        }
+        if (padding[i].front > 0) {
+            os << padding[i].front << " + ";
+        }
+        os << shape[i];
+        if (padding[i].back > 0) {
+            os << " + " << padding[i].back;
+        }
+    }
+    os << "])";
+    return os;
+}
+
+struct Shape {
+    using RankedShapeVariant = std::variant<
+        const RankedShape<1>,
+        const RankedShape<2>,
+        const RankedShape<3>,
+        const RankedShape<4>,
+        const RankedShape<5>,
+        const RankedShape<6>,
+        const RankedShape<7>,
+        const RankedShape<8>>;
+
+    const RankedShapeVariant ranked_shape;
+
+   private:
+    RankedShapeVariant ttl_shape_to_ttnn_shape(const tt::tt_metal::Shape &shape) {
+        switch (shape.rank()) {
+            case 1: return RankedShape<1>{shape};
+            case 2: return RankedShape<2>{shape};
+            case 3: return RankedShape<3>{shape};
+            case 4: return RankedShape<4>{shape};
+            case 5: return RankedShape<5>{shape};
+            case 6: return RankedShape<6>{shape};
+            case 7: return RankedShape<7>{shape};
+            case 8: return RankedShape<8>{shape};
+        };
+        TT_THROW("Unsupported rank");
+    }
+
+   public:
+    explicit Shape(const tt::tt_metal::Shape &shape) : ranked_shape{ttl_shape_to_ttnn_shape(shape)} {}
+
+    template <std::size_t Rank>
+    explicit Shape(const RankedShape<Rank> &shape) : ranked_shape{shape} {}
+
+    template <std::size_t Rank>
+    explicit Shape(const std::array<uint32_t, Rank> &shape) : ranked_shape{RankedShape<Rank>{shape}} {}
+
+    template <std::size_t Rank>
+    explicit Shape(const std::array<uint32_t, Rank> &shape, const std::array<uint32_t, Rank> &padded_shape) :
+        ranked_shape{RankedShape<Rank>{shape, padded_shape}} {}
+
+    template <std::size_t Rank>
+    explicit Shape(const std::array<uint32_t, Rank> &shape, const std::array<std::array<uint32_t, 2>, Rank> &padding) :
+        ranked_shape{RankedShape<Rank>{shape, padding}} {}
+
+    const auto rank() const {
+        return std::visit(
+            []<std::size_t Rank>(const RankedShape<Rank> &shape) -> const auto { return Rank; }, this->ranked_shape);
+    }
+
+    Shape padded() const {
+        return std::visit([](const auto &shape) -> Shape { return Shape(shape.padded()); }, this->ranked_shape);
+    }
+
+    template <std::size_t Rank>
+    Shape operator+(const std::array<std::array<uint32_t, 2>, Rank> &padding) const {
+        return std::visit(
+            [&padding](const auto &shape) -> Shape { return Shape(shape + padding); }, this->ranked_shape);
+    }
+
+    bool operator==(const Shape &other) const {
+        return std::visit(
+            [](const auto &shape, const auto &other) -> bool { return shape == other; },
+            this->ranked_shape,
+            other.ranked_shape);
+    }
+
+    const auto &operator[](std::int64_t index) const {
+        return std::visit([index](const auto &shape) -> decltype(auto) { return shape[index]; }, this->ranked_shape);
+    }
+
+    const auto &value() const {
+        return std::visit([](const auto &shape) -> const auto & { return shape.value; }, this->ranked_shape);
+    }
+
+    template <std::size_t NewRank>
+    const Shape to_rank() const {
+        return std::visit(
+            []<std::size_t Rank>(const RankedShape<Rank> &shape) {
+                if constexpr (Rank == NewRank) {
+                    return Shape(shape);
+                } else {
+                    auto num_missing_dims = NewRank - Rank;
+
+                    std::array<uint32_t, NewRank> new_shape{};
+                    std::array<uint32_t, NewRank> new_padded_shape{};
+
+                    new_shape.fill(1);
+                    new_padded_shape.fill(1);
+
+                    for (auto index = 0; index < Rank; index++) {
+                        new_shape[index + num_missing_dims] = shape[index];
+                        new_padded_shape[index + num_missing_dims] = shape.padded()[index];
+                    }
+                    return Shape(RankedShape<NewRank>(new_shape, new_padded_shape));
+                }
+            },
+            this->ranked_shape);
+    }
+};
+
+static std::ostream &operator<<(std::ostream &os, const Shape &self) {
+    std::visit([&os](const auto &shape) { os << shape; }, self.ranked_shape);
+    return os;
+}
+
+} // namespace types
+
+using types::Shape;
+
+} // namespace ttnn
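Taken together, `ttnn::Shape` erases the rank behind a variant of `RankedShape<1>` through `RankedShape<8>`, and construction, padding arithmetic, comparison, and printing all dispatch through `std::visit`. A usage sketch based only on the declarations in this diff; the `main()` harness and include path are assumptions, and it has not been compiled against the full repository:

```cpp
// Usage sketch for the new ttnn::Shape API, inferred from this diff.
#include <array>
#include <iostream>

#include "tensor/types.hpp"  // the header changed in this commit

int main() {
    // Logical 30x60 shape stored with tile padding up to 32x64.
    ttnn::Shape shape(std::array<uint32_t, 2>{30, 60}, std::array<uint32_t, 2>{32, 64});

    std::cout << shape.rank() << '\n';  // 2
    std::cout << shape << '\n';         // ttnn.Shape([30 + 2, 60 + 4])

    // operator+ accumulates extra {front, back} padding per dimension.
    auto grown = shape + std::array<std::array<uint32_t, 2>, 2>{{{0, 2}, {1, 0}}};
    std::cout << grown << '\n';

    // to_rank<4>() left-pads the rank with size-1 dimensions.
    std::cout << shape.to_rank<4>() << '\n';  // ttnn.Shape([1, 1, 30 + 2, 60 + 4])
}
```

Fixing the supported ranks at 1 through 8 keeps the variant closed, so a shape of unsupported rank fails fast in `ttl_shape_to_ttnn_shape` via `TT_THROW` rather than at use.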