Skip to content

Commit

Permalink
#4003: updated ttnn.Tensor to be the same as ttl.tensor.Tensor in C++
Browse files Browse the repository at this point in the history
  • Loading branch information
arakhmati committed Feb 5, 2024
1 parent 2eaf9ec commit 0a4e5a1
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 51 deletions.
26 changes: 9 additions & 17 deletions tt_eager/tensor/tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,15 @@ class Tensor {
std::cref(this->storage_), std::cref(this->shape_), std::cref(this->dtype_), std::cref(this->layout_));
}

std::vector<uint32_t> host_page_ordering();
std::vector<uint32_t> host_page_ordering();

const ttnn::Shape ttnn_shape() const { return ttnn::Shape(this->shape_); }

private:
Storage storage_;
Shape shape_;
DataType dtype_;
Layout layout_;
Storage storage_;
Shape shape_;
DataType dtype_;
Layout layout_;
};


Expand All @@ -141,18 +144,7 @@ void memcpy(Tensor &dst, const Tensor &src);
} // namespace tt

namespace ttnn {
namespace types {

struct Tensor {
const tt::tt_metal::Tensor value;
const ttnn::Shape shape;

explicit Tensor(tt::tt_metal::Tensor &&tensor) : value{tensor}, shape{ttnn::Shape(tensor.shape())} {}
explicit Tensor(const tt::tt_metal::Tensor &tensor) : value{tensor}, shape{ttnn::Shape(tensor.shape())} {}
};

} // namespace types

using types::Tensor;
using Tensor = tt::tt_metal::Tensor;

} // namespace ttnn
6 changes: 5 additions & 1 deletion tt_eager/tensor/tensor_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@ namespace tt_metal {

template<typename T>
static std::size_t compute_volume(const T& shape) {
return std::accumulate(std::begin(shape), std::end(shape), 1, std::multiplies<uint32_t>());
auto volume = 1;
for (auto index = 0; index < shape.rank(); index++) {
volume *= shape[index];
}
return volume;
}

template<typename T>
Expand Down
44 changes: 21 additions & 23 deletions tt_eager/tt_lib/csrc/ttnn/operations/binary.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ static const auto L1_MEMORY_CONFIG = tt::tt_metal::MemoryConfig{
.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED, .buffer_type = tt::tt_metal::BufferType::L1};

ttnn::Tensor reshape(const ttnn::Tensor& tensor, const ttnn::Shape& shape) {
return ttnn::Tensor(tensor.value.reshape(shape.value()));
return ttnn::Tensor(tensor.reshape(shape.value()));
}

ttnn::Tensor unsqueeze_to_4D(const ttnn::Tensor& tensor) {
const auto& tensor_shape = tensor.shape;
const auto tensor_shape = tensor.ttnn_shape();
const auto rank = tensor_shape.rank();
if (rank == 4) {
return tensor;
Expand All @@ -41,11 +41,21 @@ namespace binary {
void py_module(py::module& m_binary) {
m_binary.def(
"add",
[](const ttnn::Tensor& input_tensor_a,
const ttnn::Tensor& input_tensor_b,
[](const ttnn::Tensor& input_tensor_a_arg,
const ttnn::Tensor& input_tensor_b_arg,
const tt::tt_metal::MemoryConfig& memory_config) {
const auto& original_shape = input_tensor_a.shape;
const auto& input_shape_b = input_tensor_b.shape;
auto&& [input_tensor_a, input_tensor_b] = [](const auto& input_tensor_a_arg,
const auto& input_tensor_b_arg) {
// Swap tensors if input_tensor_a needs to be broadcasted to input_tensor_b
if (tt::tt_metal::compute_volume(input_tensor_a_arg.ttnn_shape()) <
tt::tt_metal::compute_volume(input_tensor_b_arg.ttnn_shape())) {
return std::make_tuple(input_tensor_b_arg, input_tensor_a_arg);
}
return std::make_tuple(input_tensor_a_arg, input_tensor_b_arg);
}(input_tensor_a_arg, input_tensor_b_arg);

const auto original_shape = input_tensor_a.ttnn_shape();
const auto input_shape_b = input_tensor_b.ttnn_shape();

std::size_t height_b{};
std::size_t width_b{};
Expand All @@ -60,9 +70,6 @@ void py_module(py::module& m_binary) {
auto input_tensor_a_4D = ttnn::unsqueeze_to_4D(input_tensor_a);
auto input_tensor_b_4D = ttnn::unsqueeze_to_4D(input_tensor_b);

const auto& ttl_input_tensor_a = input_tensor_a_4D.value;
const auto& ttl_input_tensor_b = input_tensor_b_4D.value;

if (height_b == 1 or width_b == 1) {
tt::tt_metal::BcastOpDim bcast_op_dim;
if (height_b == 1 and width_b == 1) {
Expand All @@ -74,18 +81,11 @@ void py_module(py::module& m_binary) {
} else {
TT_THROW("Invalid broadcasting dimensions");
}
auto ttl_output = tt::tt_metal::bcast(
ttl_input_tensor_a,
ttl_input_tensor_b,
tt::tt_metal::BcastOpMath::ADD,
bcast_op_dim,
memory_config);
auto output = ttnn::Tensor(ttl_output);
auto output = tt::tt_metal::bcast(
input_tensor_a_4D, input_tensor_b_4D, tt::tt_metal::BcastOpMath::ADD, bcast_op_dim, memory_config);
return ttnn::reshape(output, original_shape);
} else {
auto ttl_output =
tt::tt_metal::add(ttl_input_tensor_a, ttl_input_tensor_b, std::nullopt, memory_config);
auto output = ttnn::Tensor(ttl_output);
auto output = tt::tt_metal::add(input_tensor_a_4D, input_tensor_b_4D, std::nullopt, memory_config);
return ttnn::reshape(output, original_shape);
}
},
Expand All @@ -101,13 +101,11 @@ void py_module(py::module& m_binary) {
[](const ttnn::Tensor& input_tensor_a,
const float input_tensor_b,
const tt::tt_metal::MemoryConfig& memory_config) {
const auto& original_shape = input_tensor_a.shape;
const auto original_shape = input_tensor_a.ttnn_shape();

auto input_tensor_a_4D = ttnn::unsqueeze_to_4D(input_tensor_a);
const auto& ttl_input_tensor_a = input_tensor_a_4D.value;

auto ttl_output = tt::tt_metal::add_unary(ttl_input_tensor_a, input_tensor_b, memory_config);
auto output = ttnn::Tensor(ttl_output);
auto output = tt::tt_metal::add_unary(input_tensor_a_4D, input_tensor_b, memory_config);
return ttnn::reshape(output, original_shape);
},
py::arg("input_tensor_a"),
Expand Down
4 changes: 4 additions & 0 deletions tt_eager/tt_lib/csrc/ttnn/types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ void py_module(py::module& m_types) {
PyShape.def_property_readonly("rank", [](const Shape& self) -> std::size_t { return self.rank(); });
PyShape.def("padded", [](const Shape& self) { return self.padded(); });

struct Tensor {
tt::tt_metal::Tensor value;
};

py::class_<Tensor>(m_types, "Tensor")
.def(py::init<tt::tt_metal::Tensor>())
.def_property_readonly("value", [](const Tensor& self) -> auto& { return self.value; })
Expand Down
9 changes: 4 additions & 5 deletions ttnn/ttnn/operations/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,11 +168,10 @@ def add(
ttnn.Tensor([ 1, 3], dtype=bfloat16 )
"""
# Swap tensors if input_tensor_a needs to be broadcasted to input_tensor_b
if isinstance(input_tensor_b, ttnn.Tensor) and math.prod(input_tensor_a.shape) < math.prod(input_tensor_b.shape):
input_tensor_a, input_tensor_b = input_tensor_b, input_tensor_a

return ttl.ttnn.operations.binary.add(input_tensor_a, input_tensor_b, memory_config=memory_config)
input_tensor_a = input_tensor_a.value
input_tensor_b = input_tensor_b.value if isinstance(input_tensor_b, ttnn.Tensor) else input_tensor_b
output = ttl.ttnn.operations.binary.add(input_tensor_a, input_tensor_b, memory_config=memory_config)
return ttnn.Tensor(output)


def _sub_validate_input_tensors(operation_name, input_tensor_a, input_tensor_b, *args, **kwargs):
Expand Down
8 changes: 3 additions & 5 deletions ttnn/ttnn/operations/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def reshape(input_tensor: ttnn.Tensor, shape: Union[ttnn.Shape, Tuple[int, ...]]
if input_tensor.shape == shape and list(input_tensor.shape) == list(shape):
return input_tensor

def ttnn_reshape(tensor, shape):
def ttnn_reshape(tensor: ttnn.Tensor, shape: ttnn.Shape) -> ttnn.Tensor:
ttl_input_tensor = tensor.value
return ttnn.Tensor(ttl_input_tensor.reshape(shape.value))

Expand Down Expand Up @@ -309,10 +309,8 @@ def impl(ttl_tensor):
output = output.squeeze()
return output

ttl_tensor = tensor.value
tensor = ttnn.Tensor(ttl_tensor.reshape(tensor.shape.padded().value))

return ttl.tensor.decorate_external_operation(impl, function_name="ttnn.to_torch")(ttl_tensor)
tensor = ttnn.reshape(tensor, tensor.shape.padded())
return ttl.tensor.decorate_external_operation(impl, function_name="ttnn.to_torch")(tensor.value)


def _to_device_validate_input_tensors(operation_name, tensor, *args, **kwargs):
Expand Down

0 comments on commit 0a4e5a1

Please sign in to comment.