From 8caffaeecaa7f42a86fabf4d9cf75c43fbbffb7d Mon Sep 17 00:00:00 2001
From: Eyon
Date: Mon, 10 Jun 2024 19:24:02 +0000
Subject: [PATCH] #9527: Remove usage of bcast for cases where multiply could be used

---
 tt_eager/tensor/host_buffer/functions.hpp     |   4 +
 tt_eager/tensor/tensor.cpp                    |   4 +
 tt_eager/tensor/tensor.hpp                    |   1 +
 tt_eager/tensor/tensor_utils.cpp              |   7 +
 tt_eager/tt_dnn/op_library/auto_format.cpp    |  18 +
 tt_eager/tt_dnn/op_library/auto_format.hpp    |  14 +-
 .../op_library/backward/backward_ops.cpp      |  10 +
 .../bcast/multi_core_h/bcast_op_sharded_h.cpp |   4 +-
 .../op_library/composite/composite_ops.cpp    | 484 +++++++++++-------
 .../op_library/composite/composite_ops.hpp    |  39 +-
 .../eltwise_unary/eltwise_unary_op.cpp        |  10 +-
 .../moreh_clip_grad_norm_op.cpp               |   5 +-
 .../transformer_tms/transformer_tms.cpp       |   5 +
 .../op_library/transpose/transpose_op.cpp     |   1 +
 tt_eager/tt_dnn/op_library/unpad/unpad_op.cpp |   1 +
 .../csrc/tt_lib_bindings_tensor_pytensor.cpp  |   5 +
 .../detail/reports/compilation_reporter.cpp   |   2 +-
 tt_metal/tt_stl/reflection.hpp                |  10 +
 .../ttnn/op_library/to_dtype/to_dtype_op.hpp  |   2 +
 ttnn/cpp/ttnn/operations/creation.hpp         |  36 ++
 .../eltwise/binary/device/binary_op.cpp       |   2 +
 21 files changed, 444 insertions(+), 220 deletions(-)

diff --git a/tt_eager/tensor/host_buffer/functions.hpp b/tt_eager/tensor/host_buffer/functions.hpp
index 4d9a4beee69..541253e5d48 100644
--- a/tt_eager/tensor/host_buffer/functions.hpp
+++ b/tt_eager/tensor/host_buffer/functions.hpp
@@ -8,6 +8,7 @@

 #include "tensor/host_buffer/types.hpp"
 #include "tensor/tensor.hpp"
+#include "tt_metal/tt_stl/reflection.hpp"

 namespace tt {

@@ -41,6 +42,7 @@ Buffer<T> get_as(BorrowedBuffer& buffer) {

 template <typename T>
 Buffer<T> get_as(const BorrowedBuffer& buffer) {
+    TT_ASSERT(std::holds_alternative<Buffer<T>>(buffer), fmt::format("Unexpected type {} in {}:{} ",tt::stl::get_active_type_name_in_variant(buffer),__FILE__, __LINE__));
     return std::get<Buffer<T>>(buffer);
 }

@@ -155,12 +157,14 @@ namespace host_buffer {

 template <typename T>
 borrowed_buffer::Buffer<T> get_as(OwnedBuffer& buffer) {
+    TT_ASSERT(std::holds_alternative<owned_buffer::Buffer<T>>(buffer), fmt::format("Unexpected type {} in {}:{} ",tt::stl::get_active_type_name_in_variant(buffer),__FILE__, __LINE__));
     auto& owned_buffer = std::get<owned_buffer::Buffer<T>>(buffer);
     return borrowed_buffer::Buffer<T>(owned_buffer.begin(), owned_buffer.size());
 }

 template <typename T>
 borrowed_buffer::Buffer<T> get_as(const OwnedBuffer& buffer) {
+    TT_ASSERT(std::holds_alternative<owned_buffer::Buffer<T>>(buffer), fmt::format("Unexpected type {} in {}:{} ",tt::stl::get_active_type_name_in_variant(buffer),__FILE__, __LINE__));
     auto owned_buffer = std::get<owned_buffer::Buffer<T>>(buffer);
     return borrowed_buffer::Buffer<T>(owned_buffer.begin(), owned_buffer.size());
 }
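The TT_ASSERT lines added above follow the standard guarded-variant-access pattern: check `std::holds_alternative` before `std::get` so a storage-type mismatch fails with a readable message rather than an unexplained `std::bad_variant_access`. A minimal self-contained sketch of the pattern, with generic names rather than the repo's types:

```cpp
#include <cassert>
#include <cstdio>
#include <variant>
#include <vector>

using Storage = std::variant<std::vector<float>, std::vector<int>>;

// Guarded access: assert the expected alternative before std::get,
// so failures carry context instead of a bare std::bad_variant_access.
template <typename T>
std::vector<T>& get_as(Storage& s) {
    assert(std::holds_alternative<std::vector<T>>(s) && "Unexpected storage type");
    return std::get<std::vector<T>>(s);
}

int main() {
    Storage s = std::vector<float>{1.0f, 2.0f};
    std::printf("%zu\n", get_as<float>(s).size());  // OK: prints 2
    // get_as<int>(s) would trip the assert with a readable message.
}
```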
diff --git a/tt_eager/tensor/tensor.cpp b/tt_eager/tensor/tensor.cpp
index 1409c091a83..341e3a62422 100644
--- a/tt_eager/tensor/tensor.cpp
+++ b/tt_eager/tensor/tensor.cpp
@@ -188,6 +188,7 @@ void Tensor::deallocate(bool force) {
                 auto dealloc_lambda = std::make_shared<std::function<void(Device*)>>(
                     [force, attr = this->tensor_attributes](Device* worker) mutable {
                         ZoneScopedN("ShardDeallocate");
+                        TT_ASSERT(std::holds_alternative<MultiDeviceStorage>(attr->storage), fmt::format("Unexpected type {} in {}:{} ",tt::stl::get_active_type_name_in_variant(attr->storage),__FILE__, __LINE__));
                         auto& s = std::get<MultiDeviceStorage>(attr->storage);
                         if (s.has_buffer_for_device(worker)) {
                             auto& device_buffer = s.get_buffer_for_device(worker);
@@ -809,6 +810,8 @@ const Shape Tensor::strides() const { return detail::compute_strides(this->get_l

 uint32_t Tensor::volume() const { return tt::tt_metal::compute_volume(this->get_legacy_shape()); }

+uint32_t Tensor::intended_volume() const { return tt::tt_metal::compute_volume(this->get_shape()); }
+
 Tensor create_device_tensor(
     const Shape& shape, DataType data_type, Layout layout, Device* device, const MemoryConfig& memory_config) {
     ZoneScoped;
@@ -1037,6 +1040,7 @@ void write_tensor(Tensor host_tensor, Tensor device_tensor, uint8_t cq_id) {
             auto host_storage = std::get<BorrowedStorage>(async_safe_tensor.get_storage());
             std::visit([&host_data](auto&& b) { host_data = b.data(); }, host_storage.buffer);
         } else {
+            TT_ASSERT(std::holds_alternative<OwnedStorage>(async_safe_tensor.get_storage()), fmt::format("Unexpected type {} in {}:{} ",tt::stl::get_active_type_name_in_variant(async_safe_tensor.get_storage()),__FILE__, __LINE__));
             auto host_storage = std::get<OwnedStorage>(async_safe_tensor.get_storage());
             std::visit([&host_data](auto&& b) { host_data = b.begin(); }, host_storage.get_buffer());
         }

diff --git a/tt_eager/tensor/tensor.hpp b/tt_eager/tensor/tensor.hpp
index a8989dce303..7ba6c809e3e 100644
--- a/tt_eager/tensor/tensor.hpp
+++ b/tt_eager/tensor/tensor.hpp
@@ -311,6 +311,7 @@ struct Tensor {
     StorageType storage_type() const;
     const Shape strides() const;
     uint32_t volume() const;
+    uint32_t intended_volume() const;

     bool is_allocated() const;
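The new `intended_volume()` counts elements of the logical shape reported by `get_shape()`, while `volume()` counts the tile-padded legacy shape; the `_mac` change later in this patch relies on that distinction to detect scalars. A small sketch with hypothetical shapes for a logical scalar stored in one 32x32 tile:

```cpp
#include <array>
#include <cstdint>
#include <cstdio>
#include <functional>
#include <numeric>

int main() {
    std::array<std::uint32_t, 4> logical{1, 1, 1, 1};   // what get_shape() reports
    std::array<std::uint32_t, 4> padded{1, 1, 32, 32};  // what get_legacy_shape() reports after tile padding

    auto vol = [](const auto& s) {
        return std::accumulate(s.begin(), s.end(), 1u, std::multiplies<>());
    };
    std::printf("intended_volume = %u, volume = %u\n", vol(logical), vol(padded));
    // Only intended_volume == 1 correctly identifies the scalar.
}
```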
diff --git a/tt_eager/tensor/tensor_utils.cpp b/tt_eager/tensor/tensor_utils.cpp
index 82d1cc9dd90..3ca2f4e10e3 100644
--- a/tt_eager/tensor/tensor_utils.cpp
+++ b/tt_eager/tensor/tensor_utils.cpp
@@ -420,6 +420,7 @@ bool is_multi_device_tensor(const Tensor& tensor) {
 std::vector<Tensor> get_tensors_from_multi_device_storage(const Tensor& multi_device_tensor) {
     std::vector<Tensor> tensors;
     if (multi_device_tensor.storage_type() == StorageType::MULTI_DEVICE) {
+        TT_ASSERT(std::holds_alternative<MultiDeviceStorage>(multi_device_tensor.get_storage()), fmt::format("Unexpected type {} in {}:{} ",tt::stl::get_active_type_name_in_variant(multi_device_tensor.get_storage()),__FILE__, __LINE__));
         const auto& tensor_storage = std::get<MultiDeviceStorage>(multi_device_tensor.get_storage());
         tensors = std::vector<Tensor>(tensor_storage.num_buffers(), Tensor());
         for (int i = 0; i < tensor_storage.ordered_device_ids.size(); ++i) {
@@ -432,6 +433,7 @@ std::vector<Tensor> get_tensors_from_multi_device_storage(const Tensor& multi_de
         }
         return tensors;
     } else if (multi_device_tensor.storage_type() == StorageType::MULTI_DEVICE_HOST) {
+        TT_ASSERT(std::holds_alternative<MultiDeviceHostStorage>(multi_device_tensor.get_storage()), fmt::format("Unexpected type {} in {}:{} ",tt::stl::get_active_type_name_in_variant(multi_device_tensor.get_storage()),__FILE__, __LINE__));
         const auto& tensor_storage = std::get<MultiDeviceHostStorage>(multi_device_tensor.get_storage());
         for (int i = 0; i < tensor_storage.num_buffers(); ++i) {
             tensors.push_back(Tensor{
@@ -448,9 +450,11 @@ std::vector<Tensor> get_tensors_from_multi_device_storage(const Tensor& multi_de

 DistributedTensorConfig get_distributed_tensor_config_from_tensor(const Tensor& tensor) {
     if (tensor.storage_type() == StorageType::MULTI_DEVICE) {
+        TT_ASSERT(std::holds_alternative<MultiDeviceStorage>(tensor.get_storage()), fmt::format("Unexpected type {} in {}:{} ",tt::stl::get_active_type_name_in_variant(tensor.get_storage()),__FILE__, __LINE__));
         const auto& tensor_storage = std::get<MultiDeviceStorage>(tensor.get_storage());
         return tensor_storage.strategy;
     } else if (tensor.storage_type() == StorageType::MULTI_DEVICE_HOST) {
+        TT_ASSERT(std::holds_alternative<MultiDeviceHostStorage>(tensor.get_storage()), fmt::format("Unexpected type {} in {}:{} ",tt::stl::get_active_type_name_in_variant(tensor.get_storage()),__FILE__, __LINE__));
         const auto& tensor_storage = std::get<MultiDeviceHostStorage>(tensor.get_storage());
         return tensor_storage.strategy;
     }
@@ -468,6 +472,7 @@ Tensor create_multi_device_tensor(
     std::unordered_map<int, Shape> shapes;
     std::unordered_map<int, DeviceBuffer> device_buffers;
     for (const auto& tensor : tensors) {
+        TT_ASSERT(std::holds_alternative<DeviceStorage>(tensor.get_storage()), fmt::format("Unexpected type {} in {}:{} ",tt::stl::get_active_type_name_in_variant(tensor.get_storage()),__FILE__, __LINE__));
         Device* device = std::get<DeviceStorage>(tensor.get_storage()).buffer->device();
         auto device_id = device->id();
         ordered_device_ids.push_back(device_id);
@@ -483,6 +488,7 @@ Tensor create_multi_device_tensor(
     std::vector<OwnedBuffer> owned_buffers;
     std::vector<Shape> shapes;
     for (const auto& tensor : tensors) {
+        TT_ASSERT(std::holds_alternative<OwnedStorage>(tensor.get_storage()), fmt::format("Unexpected type {} in {}:{} ",tt::stl::get_active_type_name_in_variant(tensor.get_storage()),__FILE__, __LINE__));
         owned_buffers.push_back(std::get<OwnedStorage>(tensor.get_storage()).buffer);
         shapes.push_back(tensor.get_legacy_shape());
     }
@@ -516,6 +522,7 @@ void apply(const Tensor& tensor, std::function<void(const Tensor&)> callable) {
 std::vector<Device*> get_devices(const Tensor& tensor) {
     std::vector<Device*> devices;
     if (tensor.storage_type() == tt::tt_metal::StorageType::MULTI_DEVICE) {
+        TT_ASSERT(std::holds_alternative<MultiDeviceStorage>(tensor.get_storage()), fmt::format("Unexpected type {} in {}:{} ",tt::stl::get_active_type_name_in_variant(tensor.get_storage()),__FILE__, __LINE__));
         const auto& tensor_storage = std::get<MultiDeviceStorage>(tensor.get_storage());
         for (int i = 0; i < tensor_storage.ordered_device_ids.size(); ++i) {
             auto device_id = tensor_storage.ordered_device_ids[i];

diff --git a/tt_eager/tt_dnn/op_library/auto_format.cpp b/tt_eager/tt_dnn/op_library/auto_format.cpp
index 56ad92cc639..744124b9763 100644
--- a/tt_eager/tt_dnn/op_library/auto_format.cpp
+++ b/tt_eager/tt_dnn/op_library/auto_format.cpp
@@ -40,6 +40,24 @@ Tensor AutoFormat::move_tensor_to_mem_config(const Tensor& input, const MemoryCo
     }
 }

+// This code is a workaround for cases where we need to remove autoformat but other dependent ops
+// are not quite ready. So here we basically just put the tensor back on device.
+// Used in backward_ops.cpp
+// See: Remove auto format within permute_op.cpp #9404
+Tensor AutoFormat::move_tensor_to_device_and_pad(const Tensor& input, Device* device, Layout target_layout, std::optional<MemoryConfig> target_mem_config) {
+    const auto intended_shape = input.get_shape();
+    const auto device_shape = input.get_legacy_shape();
+    const auto new_intended_shape = std::array<std::uint32_t, 4>{intended_shape[0], intended_shape[1], intended_shape[-2], intended_shape[-1]};
+    const auto new_device_shape = std::array<std::uint32_t, 4>{
+        device_shape[0],
+        device_shape[1],
+        (device_shape[-2] % TILE_HEIGHT != 0 ? (device_shape[-2] / TILE_HEIGHT + 1) * TILE_HEIGHT : device_shape[-2]),
+        (device_shape[-1] % TILE_WIDTH != 0 ? (device_shape[-1] / TILE_WIDTH + 1) * TILE_WIDTH : device_shape[-1])
+    };
+    const auto new_shape = tt_metal::Shape(new_intended_shape, new_device_shape);
+    return AutoFormat::format_input_tensor(input, device, new_shape, 0.0, target_layout, target_mem_config);
+}
+
 Tensor AutoFormat::format_input_tensor(
     const Tensor& input,
     Device* device,
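The conditional padding expressions in `move_tensor_to_device_and_pad` are the usual round-up-to-multiple idiom; an equivalent closed form is `((x + m - 1) / m) * m`. A compile-time check of the equivalence (the TILE_HEIGHT value of 32 is assumed here, matching tt-metal's tile size):

```cpp
#include <cstdint>
#include <cstdio>

constexpr std::uint32_t TILE_HEIGHT = 32;

// Two equivalent round-up-to-multiple forms for pad sizing.
constexpr std::uint32_t round_up_branchy(std::uint32_t x, std::uint32_t m) {
    return x % m != 0 ? (x / m + 1) * m : x;  // form used in the patch
}
constexpr std::uint32_t round_up_closed(std::uint32_t x, std::uint32_t m) {
    return ((x + m - 1) / m) * m;  // common closed form
}

static_assert(round_up_branchy(33, TILE_HEIGHT) == 64);
static_assert(round_up_closed(33, TILE_HEIGHT) == 64);
static_assert(round_up_closed(64, TILE_HEIGHT) == 64);  // already aligned: unchanged

int main() { std::printf("%u\n", round_up_branchy(1, TILE_HEIGHT)); }  // 32
```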
diff --git a/tt_eager/tt_dnn/op_library/auto_format.hpp b/tt_eager/tt_dnn/op_library/auto_format.hpp
index c2de0e0542f..0e6f9056ae3 100644
--- a/tt_eager/tt_dnn/op_library/auto_format.hpp
+++ b/tt_eager/tt_dnn/op_library/auto_format.hpp
@@ -34,10 +34,10 @@ class AutoFormat {
     static Shape pad_to_tile_shape(const Shape& unpadded_shape, bool pad_c=false, bool pad_n=false, bool pad_h=true, bool pad_w=true) {
-        auto n = pad_n ? round_up(unpadded_shape[0], TILE_HEIGHT) : unpadded_shape[0];
-        auto c = pad_c ? round_up(unpadded_shape[1], TILE_WIDTH) : unpadded_shape[1];
-        auto h = pad_h ? round_up(unpadded_shape[2], TILE_HEIGHT) : unpadded_shape[2];
-        auto w = pad_w ? round_up(unpadded_shape[3], TILE_WIDTH) : unpadded_shape[3];
+        auto n = pad_n ? round_up(unpadded_shape.rank() >= 4 ? unpadded_shape[-4] : 1, TILE_HEIGHT) : unpadded_shape.rank() >= 4 ? unpadded_shape[-4] : 1;
+        auto c = pad_c ? round_up(unpadded_shape.rank() >= 3 ? unpadded_shape[-3] : 1, TILE_WIDTH) : unpadded_shape.rank() >= 3 ? unpadded_shape[-3] : 1;
+        auto h = pad_h ? round_up(unpadded_shape[-2], TILE_HEIGHT) : unpadded_shape[-2];
+        auto w = pad_w ? round_up(unpadded_shape[-1], TILE_WIDTH) : unpadded_shape[-1];
         Shape padded_shape = {n, c, h, w};
         return padded_shape;
     }
@@ -83,6 +83,12 @@ class AutoFormat {
         return false;
     }

+    // This code is a workaround for cases where we need to remove autoformat but other dependent ops
+    // are not quite ready. So here we basically just put the tensor back on device.
+    // Used in backward_ops.cpp
+    // See: Remove auto format within permute_op.cpp #9404
+    static Tensor move_tensor_to_device_and_pad(const Tensor& input, Device *device, Layout target_layout, std::optional<MemoryConfig> target_mem_config);
+
     static Tensor move_tensor_to_device(const Tensor &input, Device * device, const MemoryConfig& mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

     static Tensor move_tensor_to_mem_config(const Tensor &input, const MemoryConfig& mem_config);

diff --git a/tt_eager/tt_dnn/op_library/backward/backward_ops.cpp b/tt_eager/tt_dnn/op_library/backward/backward_ops.cpp
index 863bc5cda55..4109b3919b1 100644
--- a/tt_eager/tt_dnn/op_library/backward/backward_ops.cpp
+++ b/tt_eager/tt_dnn/op_library/backward/backward_ops.cpp
@@ -2218,6 +2218,11 @@ std::vector<Tensor> _prod_bw(
             std::vector<int64_t> after_permute_dims = {0, 2, 3, 1};
             Tensor tensor_1 = permute(tensor_1_temp, after_permute_dims, output_mem_config);
             Tensor tensor_2 = permute(temp, after_permute_dims, output_mem_config);
+
+            // put the tensor back on device because permute throws it off device
+            // See: Remove auto format within permute_op.cpp #9404
+            tensor_2 = AutoFormat::move_tensor_to_device_and_pad(tensor_2, tensor_1.device(), tensor_1.get_layout(), tensor_1.memory_config());
+
             after_permute_dims = {0, 3, 1, 2};
             Tensor result = permute(
                 bcast(tensor_1, tensor_2, BcastOpMath::MUL, BcastOpDim::W, output_mem_config),
@@ -2250,6 +2255,11 @@ std::vector<Tensor> _prod_bw(
             std::vector<int64_t> after_permute_dims = {3, 1, 2, 0};
             Tensor tensor_1 = permute(tensor_1_temp, after_permute_dims, output_mem_config);
             Tensor tensor_2 = permute(temp, after_permute_dims, output_mem_config);
+
+            // put the tensor back on device because permute throws it off device
+            // See: Remove auto format within permute_op.cpp #9404
+            tensor_2 = AutoFormat::move_tensor_to_device_and_pad(tensor_2, tensor_1.device(), tensor_1.get_layout(), tensor_1.memory_config());
+
             Tensor result = permute(
                 bcast(tensor_1, tensor_2, BcastOpMath::MUL, BcastOpDim::W, output_mem_config),
                 after_permute_dims,
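Both `pad_to_tile_shape` above and `bcast_sharded_h` below now index shapes from the back, so tensors with rank below 4 default the missing N/C dims to 1. A toy illustration of the convention (this `Shape` is a stand-in, not the repo's class):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Toy stand-in for the Shape negative-indexing used in the patch:
// s[-1] is the innermost dim; missing outer dims default to 1 at the call site.
struct Shape {
    std::vector<std::uint32_t> dims;
    int rank() const { return static_cast<int>(dims.size()); }
    std::uint32_t operator[](int i) const {
        return dims[static_cast<std::size_t>(i < 0 ? rank() + i : i)];
    }
};

int main() {
    Shape s{{64, 128}};  // rank-2 tensor
    std::uint32_t N = s.rank() >= 4 ? s[-4] : 1;
    std::uint32_t C = s.rank() >= 3 ? s[-3] : 1;
    std::printf("N=%u C=%u H=%u W=%u\n", N, C, s[-2], s[-1]);  // N=1 C=1 H=64 W=128
}
```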
diff --git a/tt_eager/tt_dnn/op_library/bcast/multi_core_h/bcast_op_sharded_h.cpp b/tt_eager/tt_dnn/op_library/bcast/multi_core_h/bcast_op_sharded_h.cpp
index 06885ce922b..fd9fe860a62 100644
--- a/tt_eager/tt_dnn/op_library/bcast/multi_core_h/bcast_op_sharded_h.cpp
+++ b/tt_eager/tt_dnn/op_library/bcast/multi_core_h/bcast_op_sharded_h.cpp
@@ -22,8 +22,8 @@ namespace tt_metal {
 operation::ProgramWithCallbacks bcast_sharded_h(const Tensor &a, const Tensor &b, const Tensor& output, BcastOpMath bcast_math/*, BcastOpDim bcast_dim*/){
     const auto ashape = a.get_legacy_shape();
     const auto bshape = b.get_legacy_shape();
-    uint32_t N = ashape[0], C = ashape[1], H = ashape[2], W = ashape[3];
-    uint32_t bN = bshape[0], bC = bshape[1], bH = bshape[2], bW = bshape[3];
+    uint32_t N = ashape.rank() >= 4 ? ashape[-4] : 1, C = ashape.rank() >= 3 ? ashape[-3] : 1, H = ashape[-2], W = ashape[-1];
+    uint32_t bN = bshape.rank() >= 4 ? bshape[-4] : 1, bC = bshape.rank() >= 3 ? bshape[-3] : 1, bH = bshape[-2], bW = bshape[-1];
     uint32_t NC = N*C;
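The remaining hunks repeat one transformation: a tiled scalar built with `mk_tiled_scalar` and combined through `bcast(..., BcastOpMath::MUL, BcastOpDim::HW, ...)` becomes a `create_scalar` tensor combined with `ttnn::multiply`, which broadcasts the scalar implicitly. Numerically the HW broadcast is just "scale every element"; a host-side reference of what that computes:

```cpp
#include <cstdio>
#include <vector>

// Minimal sketch of what "multiply with implicit broadcast" computes for the
// scalar case this patch targets: every element scaled by one value.
std::vector<float> multiply_scalar(const std::vector<float>& t, float s) {
    std::vector<float> out(t.size());
    for (std::size_t i = 0; i < t.size(); ++i) out[i] = t[i] * s;  // HW-broadcast of s
    return out;
}

int main() {
    auto r = multiply_scalar({1.f, 2.f, 3.f}, 0.5f);
    for (float v : r) std::printf("%g ", v);  // 0.5 1 1.5
}
```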
diff --git a/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp b/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp
index 96b7dbe3952..8eb2f920c6d 100644
--- a/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp
+++ b/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp
@@ -28,34 +28,42 @@ namespace tt {

 namespace tt_metal {

-Tensor mk_zero_tensor_like(uint8_t queue_id, const Tensor& reference_tensor, const MemoryConfig& output_mem_config, std::optional<Tensor> output_tensor = std::nullopt) {
-    // Tensor zero_like = bcast(reference_tensor, , BcastOpMath::MUL, BcastOpDim::HW);
-    Tensor zero = mk_tiled_scalar(0.0f, reference_tensor.get_dtype());
-    if(output_tensor.has_value()){
-        bcast(queue_id, reference_tensor, zero, BcastOpMath::MUL, BcastOpDim::HW, output_mem_config, output_tensor);
-    }
-    else{
-        output_tensor = bcast(queue_id, reference_tensor, zero, BcastOpMath::MUL, BcastOpDim::HW, output_mem_config);
-    }
-    return output_tensor.value();
+Tensor mk_zero_tensor_like(
+    uint8_t queue_id,
+    const Tensor& reference_tensor,
+    const MemoryConfig& output_mem_config,
+    std::optional<Tensor> output_tensor = std::nullopt) {
+    const DataType& dtype =
+        output_tensor.has_value() ? output_tensor.value().get_dtype() : reference_tensor.get_dtype();
+    Tensor zero = ttnn::operations::creation::create_scalar(0.0f, dtype, Layout::TILE, reference_tensor.device());
+    return ttnn::multiply(queue_id, reference_tensor, zero, std::nullopt, output_mem_config, output_tensor);
 }
-Tensor mk_zero_tensor_like(const Tensor& reference_tensor, const MemoryConfig& output_mem_config, std::optional<Tensor> output_tensor = std::nullopt) {
+
+Tensor mk_zero_tensor_like(
+    const Tensor& reference_tensor,
+    const MemoryConfig& output_mem_config,
+    std::optional<Tensor> output_tensor = std::nullopt) {
     uint8_t default_queue_id = 0;
     return mk_zero_tensor_like(default_queue_id, reference_tensor, output_mem_config, output_tensor);
 }

 // TODO: enable zeroes(), ones() and eye() type functions on-device using this type of logic
 template <typename T>
-Tensor mk_filled_tensor_like(const Tensor& reference_tensor, T val, const MemoryConfig& output_mem_config, std::optional<Tensor> output_tensor = std::nullopt, uint8_t queue_id = 0) {
-    Tensor k = mk_tiled_scalar(val, reference_tensor.get_dtype());
+Tensor mk_filled_tensor_like(
+    const Tensor& reference_tensor,
+    T val,
+    const MemoryConfig& output_mem_config,
+    std::optional<Tensor> output_tensor = std::nullopt,
+    uint8_t queue_id = 0) {
+    const DataType& dtype =
+        output_tensor.has_value() ? output_tensor.value().get_dtype() : reference_tensor.get_dtype();
+    Tensor k = ttnn::operations::creation::create_scalar(val, dtype, Layout::TILE, reference_tensor.device());
     Tensor zero_like = mk_zero_tensor_like(reference_tensor, output_mem_config);
-    if(output_tensor.has_value()){
-        bcast(queue_id, zero_like, k, BcastOpMath::ADD, BcastOpDim::HW, output_mem_config, output_tensor);
-    }
-    else{
-        output_tensor = bcast(queue_id, zero_like, k, BcastOpMath::ADD, BcastOpDim::HW, output_mem_config);
+    if (output_tensor.has_value()) {
+        return bcast(queue_id, zero_like, k, BcastOpMath::ADD, BcastOpDim::HW, output_mem_config, output_tensor);
+    } else {
+        return bcast(queue_id, zero_like, k, BcastOpMath::ADD, BcastOpDim::HW, output_mem_config);
     }
-    return output_tensor.value();
 }

 // Function: softshrink
@@ -66,7 +74,8 @@ Tensor _softshrink(const Tensor& a, float param, const MemoryConfig& output_mem_
     Tensor t1 = ttnn::multiply(ltz(t_a_plus_param, output_mem_config), t_a_plus_param, std::nullopt, output_mem_config);
     t_a_plus_param.deallocate();
     Tensor t_a_minus_param = sub_unary(a, param, output_mem_config);
-    Tensor t2 = ttnn::multiply(gtz(t_a_minus_param, output_mem_config), t_a_minus_param, std::nullopt, output_mem_config);
+    Tensor t2 =
+        ttnn::multiply(gtz(t_a_minus_param, output_mem_config), t_a_minus_param, std::nullopt, output_mem_config);
     t_a_minus_param.deallocate();
     return ttnn::add(t1, t2, std::nullopt, output_mem_config);
 }
@@ -191,7 +200,8 @@ Tensor _lgamma(const Tensor& x, const MemoryConfig& output_mem_config) {
         }
         temp_log = log(temp, output_mem_config);
         result = add_unary(
-            ttnn::multiply(add_unary(input, 0.5f, output_mem_config), t_log, std::nullopt, output_mem_config),
+            ttnn::multiply(
+                add_unary(input, 0.5f, output_mem_config), t_log, std::nullopt, output_mem_config),
             0.918938531357171f,
             output_mem_config);
     }
@@ -220,12 +230,12 @@ Tensor lgamma(const Tensor& a, const MemoryConfig& output_mem_config) {
 // Ref : https://pytorch.org/docs/stable/special.html#torch.special.multigammaln
 Tensor _multigammaln(const Tensor& x, const MemoryConfig& output_mem_config) {
     Tensor result = lgamma(x, output_mem_config);
-    result =
-        ttnn::add(result, lgamma(sub_unary(x, 0.5f, output_mem_config), output_mem_config), std::nullopt, output_mem_config);
-    result =
-        ttnn::add(result, lgamma(sub_unary(x, 1.0f, output_mem_config), output_mem_config), std::nullopt, output_mem_config);
-    result =
-        ttnn::add(result, lgamma(sub_unary(x, 1.5f, output_mem_config), output_mem_config), std::nullopt, output_mem_config);
+    result = ttnn::add(
+        result, lgamma(sub_unary(x, 0.5f, output_mem_config), output_mem_config), std::nullopt, output_mem_config);
+    result = ttnn::add(
+        result, lgamma(sub_unary(x, 1.0f, output_mem_config), output_mem_config), std::nullopt, output_mem_config);
+    result = ttnn::add(
+        result, lgamma(sub_unary(x, 1.5f, output_mem_config), output_mem_config), std::nullopt, output_mem_config);
     result = add_unary(result, 3.434189657547f, output_mem_config);
     return result;
 }
@@ -273,21 +283,22 @@ Tensor mish(const Tensor& a, const MemoryConfig& output_mem_config) {
 Tensor _selu(const Tensor& x, const float scale, const float alpha, const MemoryConfig& output_mem_config) {
     // term 2
     Tensor x_Exp = exp(x, output_mem_config);
-    Tensor minus_one = mk_tiled_scalar(-1.0f);
+    Tensor minus_one = ttnn::operations::creation::create_scalar(-1.0f, x.get_dtype(), Layout::TILE, x.device());
     Tensor x_Exp_minus_1 = bcast(x_Exp, minus_one, BcastOpMath::ADD, BcastOpDim::HW, output_mem_config);
     x_Exp.deallocate();
     minus_one.deallocate();
-    Tensor t_alpha = mk_tiled_scalar(alpha);
-    Tensor result_t2_ = bcast(x_Exp_minus_1, t_alpha, BcastOpMath::MUL, BcastOpDim::HW, output_mem_config);
+    Tensor t_alpha = ttnn::operations::creation::create_scalar(alpha, x.get_dtype(), Layout::TILE, x.device());
+    Tensor result_t2_ = ttnn::multiply(x_Exp_minus_1, t_alpha, std::nullopt, output_mem_config);
     x_Exp_minus_1.deallocate();
     t_alpha.deallocate();
-    Tensor result_term2 = ttnn::multiply(gtz(result_t2_, output_mem_config), result_t2_, std::nullopt, output_mem_config);
+    Tensor result_term2 =
+        ttnn::multiply(gtz(result_t2_, output_mem_config), result_t2_, std::nullopt, output_mem_config);
     result_t2_.deallocate();

     // term 1
-    Tensor t_scale = mk_tiled_scalar(scale);
+    Tensor t_scale = ttnn::operations::creation::create_scalar(scale, x.get_dtype(), Layout::TILE, x.device());
     Tensor x_relu = relu(x, output_mem_config);
-    Tensor result_term1 = bcast(x_relu, t_scale, BcastOpMath::MUL, BcastOpDim::HW, output_mem_config);
+    Tensor result_term1 = ttnn::multiply(x_relu, t_scale, std::nullopt, output_mem_config);
     t_scale.deallocate();
     x_relu.deallocate();
     Tensor result_selu = ttnn::add(result_term1, result_term2, std::nullopt, output_mem_config);
@@ -306,7 +317,10 @@ Tensor selu(const Tensor& x, const float scale, const float alpha, const MemoryC
 Tensor rpow(const Tensor& a, float k, const MemoryConfig& output_mem_config) {
     TT_ASSERT(k > 0.0, "rpow cannot be calculated for non-positive numbers");
     float log_k = logf(k);
-    Tensor result = bcast(a, mk_tiled_scalar(log_k), BcastOpMath::MUL, BcastOpDim::HW, output_mem_config);
+
+    Tensor scalar = ttnn::operations::creation::create_scalar(log_k, a.get_dtype(), Layout::TILE, a.device());
+    Tensor result = ttnn::multiply(a, scalar, std::nullopt, output_mem_config);
+    scalar.deallocate();
     return exp(result, output_mem_config);
 }
@@ -367,13 +381,22 @@ Tensor _polyval(const Tensor& input_tensor, std::vector<float> coeffs, const Mem
         return mk_filled_tensor_like(input_tensor, coeffs[0], output_mem_config);
     }

-    Tensor result =
-        bcast(input_tensor, mk_tiled_scalar(coeffs[0]), BcastOpMath::MUL, BcastOpDim::HW, output_mem_config);
+    Tensor scalar = ttnn::operations::creation::create_scalar(
+        coeffs[0], input_tensor.get_dtype(), Layout::TILE, input_tensor.device());
+    Tensor result = ttnn::multiply(input_tensor, scalar, std::nullopt, output_mem_config);
+    scalar.deallocate();
     for (int idx = 1; idx < coeffs.size() - 1; idx++) {
-        result = bcast(result, mk_tiled_scalar(coeffs[idx]), BcastOpMath::ADD, BcastOpDim::HW, output_mem_config);
+        Tensor scalar = ttnn::operations::creation::create_scalar(
+            coeffs[idx], input_tensor.get_dtype(), Layout::TILE, input_tensor.device());
+        result = bcast(result, scalar, BcastOpMath::ADD, BcastOpDim::HW, output_mem_config);
+        scalar.deallocate();
         result = ttnn::multiply(input_tensor, result, std::nullopt, output_mem_config);
     }
-    return bcast(result, mk_tiled_scalar(coeffs.back()), BcastOpMath::ADD, BcastOpDim::HW, output_mem_config);
+    Tensor last_coeffs = ttnn::operations::creation::create_scalar(
+        coeffs.back(), input_tensor.get_dtype(), Layout::TILE, input_tensor.device());
+    Tensor final_tensor = bcast(result, last_coeffs, BcastOpMath::ADD, BcastOpDim::HW, output_mem_config);
+    last_coeffs.deallocate();
+    return final_tensor;
 }
 Tensor polyval(const Tensor& input_tensor, std::vector<float> coeffs, const MemoryConfig& output_mem_config) {
     return operation::decorate_as_composite(__func__, _polyval)(input_tensor, coeffs, output_mem_config);
 }
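`_polyval` above evaluates coefficients highest power first with a Horner-style recurrence. A host-side reference with a worked value:

```cpp
#include <cstdio>
#include <vector>

// Horner evaluation matching _polyval's loop order (coeffs[0] is the highest power).
float polyval(const std::vector<float>& coeffs, float x) {
    float result = coeffs[0] * x;
    for (std::size_t i = 1; i + 1 < coeffs.size(); ++i) {
        result = (result + coeffs[i]) * x;  // accumulate, then re-scale by x
    }
    return result + coeffs.back();
}

int main() {
    // 2x^2 + 3x + 4 at x = 5 -> 69
    std::printf("%g\n", polyval({2.f, 3.f, 4.f}, 5.f));
}
```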
@@ -383,9 +406,9 @@ Tensor polyval(const Tensor& input_tensor, std::vector<float> coeffs, const Memo

 // compute multiply-accumulate: y = a * b + c, over the 8 combinations of a, b, c
 // each being a scalar or tensor
 Tensor _mac(const Tensor& a, const Tensor& b, const Tensor& c, const MemoryConfig& output_mem_config) {
-    bool a_is_scalar = a.volume() == 1;
-    bool b_is_scalar = b.volume() == 1;
-    bool c_is_scalar = c.volume() == 1;
+    bool a_is_scalar = a.intended_volume() == 1;
+    bool b_is_scalar = b.intended_volume() == 1;
+    bool c_is_scalar = c.intended_volume() == 1;

     const auto dim = BcastOpDim::HW;
     if (!a_is_scalar && !b_is_scalar && !c_is_scalar) {
@@ -393,24 +416,26 @@ Tensor _mac(const Tensor& a, const Tensor& b, const Tensor& c, const MemoryConfi
         return ttnn::add(ttnn::multiply(a, b, std::nullopt, output_mem_config), c, std::nullopt, output_mem_config);
     } else if (!a_is_scalar && !b_is_scalar && c_is_scalar) {
         // a - tensor, b - tensor, c - is scalar
-        return bcast(ttnn::multiply(a, b, std::nullopt, output_mem_config), c, BcastOpMath::ADD, dim, output_mem_config);
+        return bcast(
+            ttnn::multiply(a, b, std::nullopt, output_mem_config), c, BcastOpMath::ADD, dim, output_mem_config);
     } else if (!a_is_scalar && b_is_scalar && !c_is_scalar) {
         // a - tensor, b - scalar, c - is tensor
-        return ttnn::add(bcast(a, b, BcastOpMath::MUL, dim, output_mem_config), c, std::nullopt, output_mem_config);
+        return ttnn::add(ttnn::multiply(a, b, std::nullopt, output_mem_config), c, std::nullopt, output_mem_config);
     } else if (!a_is_scalar && b_is_scalar && c_is_scalar) {
         // a - tensor, b - scalar, c - is scalar
         return bcast(
-            bcast(a, b, BcastOpMath::MUL, dim, output_mem_config), c, BcastOpMath::ADD, dim, output_mem_config);
+            ttnn::multiply(a, b, std::nullopt, output_mem_config), c, BcastOpMath::ADD, dim, output_mem_config);
     } else if (a_is_scalar && !b_is_scalar && !c_is_scalar) {
         // a - scalar, b - tensor, c - tensor
-        return ttnn::add(bcast(b, a, BcastOpMath::MUL, dim, output_mem_config), c, std::nullopt, output_mem_config);
+        return ttnn::add(ttnn::multiply(b, a, std::nullopt, output_mem_config), c, std::nullopt, output_mem_config);
     } else if (a_is_scalar && !b_is_scalar && c_is_scalar) {
         // a - scalar, b - tensor, c - is scalar
         return bcast(
-            bcast(b, a, BcastOpMath::MUL, dim, output_mem_config), c, BcastOpMath::ADD, dim, output_mem_config);
+            ttnn::multiply(b, a, std::nullopt, output_mem_config), c, BcastOpMath::ADD, dim, output_mem_config);
     } else if (a_is_scalar && b_is_scalar && !c_is_scalar) {
         // a - scalar, b - scalar, c - is tensor
-        return bcast(c, ttnn::multiply(a, b, std::nullopt, output_mem_config), BcastOpMath::ADD, dim, output_mem_config);
+        return bcast(
+            c, ttnn::multiply(a, b, std::nullopt, output_mem_config), BcastOpMath::ADD, dim, output_mem_config);
     }

     // all scalars
@@ -423,9 +448,12 @@ Tensor mac(const Tensor& a, const Tensor& b, const Tensor& c, const MemoryConfig
 }

 Tensor _mac_overload(const Tensor& a, float b, float c, const MemoryConfig& output_mem_config) {
-    Tensor t_b = mk_scalar(b);
-    Tensor t_c = mk_scalar(c);
-    return mac(a, t_b, t_c, output_mem_config);
+    Tensor t_b = ttnn::operations::creation::create_scalar(b, a.get_dtype(), Layout::TILE, a.device());
+    Tensor t_c = ttnn::operations::creation::create_scalar(c, a.get_dtype(), Layout::TILE, a.device());
+    Tensor return_tensor = mac(a, t_b, t_c, output_mem_config);
+    t_b.deallocate();
+    t_c.deallocate();
+    return return_tensor;
 }
 Tensor mac(const Tensor& input_a, float b, float c, const MemoryConfig& output_mem_config) {
     return operation::decorate_as_composite(__func__, _mac_overload)(input_a, b, c, output_mem_config);
 }
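Every branch of `_mac` computes the same arithmetic, y = a * b + c; the dispatch only chooses which operands broadcast, now that `ttnn::multiply` handles the scalar multiplies. Reference semantics:

```cpp
#include <cstdio>

// Reference semantics shared by all 8 _mac branches: y = a * b + c.
// The branches differ only in which operands are broadcast on device.
float mac(float a, float b, float c) { return a * b + c; }

int main() {
    // Any mix of "scalar" and "tensor" operands must agree with this value.
    std::printf("mac(2, 3, 4) = %g\n", mac(2.f, 3.f, 4.f));  // 10
}
```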
@@ -469,7 +497,10 @@ Tensor _sinh(const Tensor& input_a, const MemoryConfig& output_mem_config) {
     Tensor nr_term = ttnn::subtract(e_pos_x, e_neg_x, std::nullopt, output_mem_config);
     e_pos_x.deallocate();
     e_neg_x.deallocate();
-    return bcast(nr_term, mk_tiled_scalar(0.5f), BcastOpMath::MUL, BcastOpDim::HW, output_mem_config);
+    Tensor scalar = ttnn::operations::creation::create_scalar(0.5f, input_a.get_dtype(), Layout::TILE, input_a.device());
+    Tensor result = ttnn::multiply(nr_term, scalar, std::nullopt, output_mem_config);
+    scalar.deallocate();
+    return result;
 }
 Tensor sinh(const Tensor& input_a, const MemoryConfig& output_mem_config) {
     return operation::decorate_as_composite(__func__, _sinh)(input_a, output_mem_config);
 }
@@ -482,7 +513,10 @@ Tensor _cosh(const Tensor& input_a, const MemoryConfig& output_mem_config) {
     Tensor nr_term = ttnn::add(e_pos_x, e_neg_x, std::nullopt, output_mem_config);
     e_pos_x.deallocate();
     e_neg_x.deallocate();
-    return bcast(nr_term, mk_tiled_scalar(0.5f), BcastOpMath::MUL, BcastOpDim::HW, output_mem_config);
+    Tensor scalar = ttnn::operations::creation::create_scalar(0.5f, input_a.get_dtype(), Layout::TILE, input_a.device());
+    Tensor result = ttnn::multiply(nr_term, scalar, std::nullopt, output_mem_config);
+    scalar.deallocate();
+    return result;
 }
 Tensor cosh(const Tensor& input_a, const MemoryConfig& output_mem_config) {
     return operation::decorate_as_composite(__func__, _cosh)(input_a, output_mem_config);
 }
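The two composites compute the textbook identities sinh(x) = (e^x - e^-x)/2 and cosh(x) = (e^x + e^-x)/2, with the final halving now done by `ttnn::multiply` against a 0.5 scalar. A quick numerical check:

```cpp
#include <cmath>
#include <cstdio>

// sinh/cosh from exponentials, as the composites build them on device.
int main() {
    float x = 1.0f;
    float e_pos = std::exp(x), e_neg = std::exp(-x);
    std::printf("sinh: %f vs %f\n", 0.5f * (e_pos - e_neg), std::sinh(x));
    std::printf("cosh: %f vs %f\n", 0.5f * (e_pos + e_neg), std::cosh(x));
}
```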
@@ -498,7 +532,8 @@ Tensor _asinh(const Tensor& input_a, const MemoryConfig& output_mem_config) {
             Tensor x_sq = square(input_a, output_mem_config);
             x_sq_p1 = add_unary(x_sq, 1.0f, output_mem_config);
         }
-        ln_res = log(ttnn::add(x_abs, sqrt(x_sq_p1, output_mem_config), std::nullopt, output_mem_config), output_mem_config);
+        ln_res =
+            log(ttnn::add(x_abs, sqrt(x_sq_p1, output_mem_config), std::nullopt, output_mem_config), output_mem_config);
     }
     // input is negative, output is -asinh(input)
     Tensor result = where(input_a, ln_res, neg(ln_res, output_mem_config), output_mem_config);
@@ -521,19 +556,19 @@ Tensor _acosh(const Tensor& input_a, const MemoryConfig& output_mem_config) {
             Tensor x_sq = square(x_abs, output_mem_config);
             x_sq_m1 = sub_unary(x_sq, 1.0f, output_mem_config);
         }
-        ln_res =
-            log(ttnn::add(x_abs, sqrt(x_sq_m1, output_mem_config), std::nullopt, output_mem_config), output_mem_config);
+        ln_res = log(
+            ttnn::add(x_abs, sqrt(x_sq_m1, output_mem_config), std::nullopt, output_mem_config), output_mem_config);
     }
     // To handle inputs <= 1
     // input < 1, output is nan
     // input > 1, output is acosh(input)
-    Tensor nan_res = bcast(
-        ttnn::le(input_a, t_one, std::nullopt, output_mem_config),
-        mk_tiled_scalar(std::nanf("")),
-        BcastOpMath::MUL,
-        BcastOpDim::HW,
-        output_mem_config);
-    t_result = ttnn::multiply(ttnn::gt(input_a, t_one, std::nullopt, output_mem_config), ln_res, std::nullopt, output_mem_config);
+    Tensor scalar = ttnn::operations::creation::create_scalar(
+        std::nanf(""), input_a.get_dtype(), Layout::TILE, input_a.device());
+    Tensor nan_res = ttnn::multiply(
+        ttnn::le(input_a, t_one, std::nullopt, output_mem_config), scalar, std::nullopt, output_mem_config);
+    scalar.deallocate();
+    t_result = ttnn::multiply(
+        ttnn::gt(input_a, t_one, std::nullopt, output_mem_config), ln_res, std::nullopt, output_mem_config);
     t_result = ttnn::add(nan_res, t_result, std::nullopt, output_mem_config);
     }
     // input == 1, output is 0
@@ -553,7 +588,8 @@ Tensor _atanh(const Tensor& input_a, const MemoryConfig& output_mem_config) {
     Tensor pos_x = add_unary(input_a, 1.0f, output_mem_config);
     Tensor neg_x = sub_unary(input_a, 1.0f, output_mem_config);
     nr_term = log(
-        ttnn::multiply(pos_x, recip(neg(neg_x, output_mem_config), output_mem_config), std::nullopt, output_mem_config),
+        ttnn::multiply(
+            pos_x, recip(neg(neg_x, output_mem_config), output_mem_config), std::nullopt, output_mem_config),
         output_mem_config);
     }
     comp_result = mul_unary(nr_term, 0.5f, output_mem_config);
@@ -571,9 +607,10 @@ Tensor atanh(const Tensor& input_a, const MemoryConfig& output_mem_config) {

 // lerp(input, end, weight) = start + weight * (end - start)
 Tensor _lerp(const Tensor& input_a, const Tensor& input_b, float value, const MemoryConfig& output_mem_config) {
-    Tensor t_value = mk_tiled_scalar(value);
+    Tensor t_value =
+        ttnn::operations::creation::create_scalar(value, input_a.get_dtype(), Layout::TILE, input_a.device());
     Tensor t_diff = ttnn::subtract(input_b, input_a, std::nullopt, output_mem_config);
-    Tensor t_mul = bcast(t_diff, t_value, BcastOpMath::MUL, BcastOpDim::HW, output_mem_config);
+    Tensor t_mul = ttnn::multiply(t_diff, t_value, std::nullopt, output_mem_config);
     Tensor result = ttnn::add(input_a, t_mul, std::nullopt, output_mem_config);
     return result;
 }
@@ -584,11 +621,11 @@ Tensor lerp(const Tensor& input_a, const Tensor& input_b, float value, const Mem
 Tensor _atan2(const Tensor& input_a, const Tensor& input_b, const MemoryConfig& output_mem_config) {
     Tensor result(input_a);
     {
-        Tensor atan_input =
-            ttnn::multiply(abs(input_b, output_mem_config),
-                recip(abs(input_a, output_mem_config), output_mem_config),
-                std::nullopt,
-                output_mem_config);
+        Tensor atan_input = ttnn::multiply(
+            abs(input_b, output_mem_config),
+            recip(abs(input_a, output_mem_config), output_mem_config),
+            std::nullopt,
+            output_mem_config);
         result = atan(atan_input, output_mem_config);
     }
     Tensor res(result);
@@ -621,8 +658,8 @@ Tensor atan2(const Tensor& input_a, const Tensor& input_b, const MemoryConfig& o
 // lerp(input, end, weight) = start + weight * (end - start)
 Tensor _lerp_overload(
     const Tensor& input_a, const Tensor& input_b, const Tensor& input_c, const MemoryConfig& output_mem_config) {
-    Tensor t_diff =
-        ttnn::multiply(ttnn::subtract(input_b, input_a, std::nullopt, output_mem_config), input_c, std::nullopt, output_mem_config);
+    Tensor t_diff = ttnn::multiply(
+        ttnn::subtract(input_b, input_a, std::nullopt, output_mem_config), input_c, std::nullopt, output_mem_config);
     Tensor result = ttnn::add(input_a, t_diff, std::nullopt, output_mem_config);
     return result;
 }
@@ -727,11 +764,18 @@ Tensor _addalpha(
     const MemoryConfig& output_mem_config,
     std::optional<Tensor> output_tensor) {
     if (output_tensor.has_value()) {
-        ttnn::add(cq_id, mul_unary(cq_id, input_b, alpha, output_mem_config), input_a, std::nullopt, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, output_tensor);
+        ttnn::add(
+            cq_id,
+            mul_unary(cq_id, input_b, alpha, output_mem_config),
+            input_a,
+            std::nullopt,
+            operation::DEFAULT_OUTPUT_MEMORY_CONFIG,
+            output_tensor);
         return output_tensor.value();
     }

-    return ttnn::add(cq_id, mul_unary(cq_id, input_b, alpha, output_mem_config), input_a, std::nullopt, output_mem_config);
+    return ttnn::add(
+        cq_id, mul_unary(cq_id, input_b, alpha, output_mem_config), input_a, std::nullopt, output_mem_config);
 }

 Tensor addalpha(
@@ -740,8 +784,8 @@ Tensor addalpha(
     float alpha,
     const MemoryConfig& output_mem_config,
     std::optional<Tensor> output_tensor) {
-    uint8_t default_queue_id = 0;
-    return operation::decorate_as_composite(__func__, _addalpha)(
+    uint8_t default_queue_id = 0;
+    return operation::decorate_as_composite(__func__, _addalpha)(
         default_queue_id, input_a, input_b, alpha, output_mem_config, output_tensor);
 }
@@ -833,9 +877,10 @@ Tensor _addcmul(
     const Tensor& input_c,
     float value,
     const MemoryConfig& output_mem_config) {
-    Tensor t_value = mk_tiled_scalar(value);
+    Tensor t_value =
+        ttnn::operations::creation::create_scalar(value, input_a.get_dtype(), Layout::TILE, input_a.device());
     Tensor t_mul = ttnn::multiply(input_b, input_c, std::nullopt, output_mem_config);
-    Tensor t_factor = bcast(t_mul, t_value, BcastOpMath::MUL, BcastOpDim::HW, output_mem_config);
+    Tensor t_factor = ttnn::multiply(t_mul, t_value, std::nullopt, output_mem_config);
     t_mul.deallocate();
     t_value.deallocate();
     Tensor result = ttnn::add(input_a, t_factor, std::nullopt, output_mem_config);
@@ -857,9 +902,10 @@ Tensor _addcdiv(
     const Tensor& input_c,
     float value,
     const MemoryConfig& output_mem_config) {
-    Tensor t_value = mk_tiled_scalar(value);
+    Tensor t_value =
+        ttnn::operations::creation::create_scalar(value, input_a.get_dtype(), Layout::TILE, input_a.device());
     Tensor t_div = ttnn::multiply(input_b, recip(input_c, output_mem_config), std::nullopt, output_mem_config);
-    Tensor t_factor = bcast(t_div, t_value, BcastOpMath::MUL, BcastOpDim::HW, output_mem_config);
+    Tensor t_factor = ttnn::multiply(t_div, t_value, std::nullopt, output_mem_config);
     t_div.deallocate();
     t_value.deallocate();
     Tensor result = ttnn::add(input_a, t_factor, std::nullopt, output_mem_config);
@@ -955,7 +1001,7 @@ Tensor _round(const Tensor& input, int64_t decimals, const MemoryConfig& output_
     auto arch = input.device()->arch();
     TT_FATAL(arch == tt::ARCH::WORMHOLE_B0, "Op is only supported on Wormhole");
     Tensor floor_res = tt::tt_metal::floor(input, output_mem_config);
-    if (decimals != 0) { //TODO: For decimal value!=0
+    if (decimals != 0) {  // TODO: For decimal value!=0
         Tensor power_10 =
             pow(full_like(input, 10.0f, output_mem_config), static_cast<float>(decimals), output_mem_config);
         Tensor rounded_non_half = tt::tt_metal::floor(
@@ -965,7 +1011,8 @@ Tensor _round(const Tensor& input, int64_t decimals, const MemoryConfig& output_
         return rounded_non_half;
     } else {  // Bankers' Rounding
         Tensor rounded_non_half = tt::tt_metal::floor(
-            ttnn::add(input,
+            ttnn::add(
+                input,
                 where(ttnn::logical_and(gte_unary(input, 0.4), lte_unary(input, 0.5)), 0.4f, 0.5f, output_mem_config),
                 std::nullopt,
                 output_mem_config),
@@ -1009,6 +1056,7 @@ Tensor _floor_div_overload(const Tensor& input, float value, const MemoryConfig&
             t_nan,
             ttnn::multiply(t_inf, sign(input, output_mem_config), std::nullopt, output_mem_config),
             output_mem_config);
+
     }
     Tensor temp = div_unary(input, value);
     return floor(temp);
@@ -1039,10 +1087,10 @@ Tensor _remainder(const Tensor& input_a, const Tensor& input_b, const MemoryConf
     DataType input_dtype = input_a.get_dtype();
     Tensor a = typecast(input_a, DataType::FLOAT32);
     Tensor b = typecast(input_b, DataType::FLOAT32);
-    Tensor result = ttnn::subtract(a, ttnn::multiply(b, floor_div(a, b), std::nullopt, output_mem_config));
+    Tensor result = ttnn::subtract(a, ttnn::multiply(b, floor_div(input_a, input_b, output_mem_config), std::nullopt, output_mem_config));
     result = where(ttnn::ge(result, b), ttnn::subtract(result, b), result);
     result = where(ltz(b), ttnn::add(result, b), result);
-    result = where(ttnn::eq(a, b), 0, result);
+    result = where(ttnn::eq(a, b, std::nullopt, output_mem_config), full_like(input_a, 0.0f, output_mem_config), result, output_mem_config);
     return typecast(result, input_dtype);
 }
 Tensor remainder(const Tensor& input_a, const Tensor& input_b, const MemoryConfig& output_mem_config) {
@@ -1053,8 +1101,8 @@ Tensor _fmod(const Tensor& input_a, const Tensor& input_b, const MemoryConfig& o
     DataType input_dtype = input_a.get_dtype();
     Tensor a = typecast(input_a, DataType::FLOAT32);
     Tensor b = typecast(input_b, DataType::FLOAT32);
-    Tensor result = ttnn::subtract(a, ttnn::multiply(b, div(input_a, input_b, true, "trunc"), std::nullopt, output_mem_config));
-    result = where(ttnn::eq(a, b), 0, result);
+    Tensor result = ttnn::subtract(a, ttnn::multiply(b, div(input_a, input_b, true, "trunc", output_mem_config), std::nullopt, output_mem_config));
+    result = where(ttnn::eq(a, b, std::nullopt, output_mem_config), full_like(input_a, 0.0f, output_mem_config), result, output_mem_config);
     return typecast(result, input_dtype);
 }
 Tensor fmod(const Tensor& input_a, const Tensor& input_b, const MemoryConfig& output_mem_config) {
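`remainder` and `fmod` differ only in the division they subtract: floor division makes the result take the divisor's sign (Python-style `%`), truncating division makes it take the dividend's sign (C `fmod`). Host-side reference:

```cpp
#include <cmath>
#include <cstdio>

// remainder(): a - b * floor(a/b), sign follows the divisor.
float floor_mod(float a, float b) { return a - b * std::floor(a / b); }
// fmod(): a - b * trunc(a/b), sign follows the dividend.
float trunc_mod(float a, float b) { return a - b * std::trunc(a / b); }

int main() {
    std::printf("floor_mod(-7, 3) = %g\n", floor_mod(-7.f, 3.f));  // 2
    std::printf("trunc_mod(-7, 3) = %g\n", trunc_mod(-7.f, 3.f));  // -1
}
```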
@@ -1077,7 +1125,8 @@ Tensor _logit(const Tensor& input_a, float eps, const MemoryConfig& output_mem_c
     t_eps.deallocate();
     t1m_eps.deallocate();
     Tensor linput_m1 = rsub(logit_input, 1.0, output_mem_config);
-    Tensor log_input = ttnn::multiply(logit_input, recip(linput_m1, output_mem_config), std::nullopt, output_mem_config);
+    Tensor log_input =
+        ttnn::multiply(logit_input, recip(linput_m1, output_mem_config), std::nullopt, output_mem_config);
     linput_m1.deallocate();
     Tensor t_inf =
         mul_unary(sign(input_a, output_mem_config), std::numeric_limits<float>::infinity(), output_mem_config);
@@ -1299,7 +1348,7 @@ Tensor _normalize(const Tensor& y, const MemoryConfig& output_mem_config) {
     Tensor y_minus_mean_y = bcast(y, mean_y, BcastOpMath::SUB, BcastOpDim::HW);
     Tensor std_y = tt::tt_metal::_std(y, mean_y, y_minus_mean_y, output_mem_config);
     Tensor recip_std_y = recip(std_y, output_mem_config);
-    Tensor z = bcast(y_minus_mean_y, recip_std_y, BcastOpMath::MUL, BcastOpDim::HW);
+    Tensor z = ttnn::multiply(y_minus_mean_y, recip_std_y);
     return z;
 }
 Tensor normalize_hw(const Tensor& y, const MemoryConfig& output_mem_config) {
@@ -1360,18 +1409,20 @@ Tensor scatter(const Tensor& input_a, const Tensor& input_b, const MemoryConfig&
 }

 // threshold(a,t,v) = (a <= t)*v + (a > t)*a
-Tensor _threshold(const Tensor& input_a, float threshold, float value, const MemoryConfig& output_mem_config) {
-    Tensor t_threshold = mk_tiled_scalar(threshold, input_a.get_dtype());
-    Tensor t0 = bcast(input_a, t_threshold, BcastOpMath::SUB, BcastOpDim::HW, output_mem_config);
+Tensor _threshold(const Tensor& input_tensor, float threshold, float value, const MemoryConfig& output_mem_config) {
+    Tensor t_threshold = ttnn::operations::creation::create_scalar(
+        threshold, input_tensor.get_dtype(), Layout::TILE, input_tensor.device());
+    Tensor t0 = bcast(input_tensor, t_threshold, BcastOpMath::SUB, BcastOpDim::HW, output_mem_config);
     t_threshold.deallocate();
-    Tensor t_value = mk_tiled_scalar(value, input_a.get_dtype());
-    Tensor t1 = bcast(lez(t0), t_value, BcastOpMath::MUL, BcastOpDim::HW, output_mem_config);
+    Tensor t_value =
+        ttnn::operations::creation::create_scalar(value, input_tensor.get_dtype(), Layout::TILE, input_tensor.device());
+    Tensor t1 = ttnn::multiply(lez(t0), t_value, std::nullopt, output_mem_config);
     t_value.deallocate();
-    Tensor t2 = ttnn::multiply(gtz(t0, output_mem_config), input_a, std::nullopt, output_mem_config);
+    Tensor t2 = ttnn::multiply(gtz(t0, output_mem_config), input_tensor, std::nullopt, output_mem_config);
     return ttnn::add(t1, t2, std::nullopt, output_mem_config);
 }

-Tensor threshold(const Tensor& input_a, float threshold, float value, const MemoryConfig& output_mem_config) {
-    return operation::decorate_as_composite(__func__, _threshold)(input_a, threshold, value, output_mem_config);
+Tensor threshold(const Tensor& input_tensor, float threshold, float value, const MemoryConfig& output_mem_config) {
+    return operation::decorate_as_composite(__func__, _threshold)(input_tensor, threshold, value, output_mem_config);
 }

 // TODO: In the future, uplift the op once floor and tan are supported.
@@ -1388,27 +1439,33 @@ Tensor _digamma(const Tensor& input_a, const MemoryConfig& output_mem_config) {

     // (1/120) * x^4
     tmp = ttnn::multiply(tmp, val_square, std::nullopt, output_mem_config);
-    output = ttnn::add(output, mul_unary(tmp, 0.008333333333333333f, output_mem_config), std::nullopt, output_mem_config);
+    output =
+        ttnn::add(output, mul_unary(tmp, 0.008333333333333333f, output_mem_config), std::nullopt, output_mem_config);

     //(1/252) * x^6
     tmp = ttnn::multiply(tmp, val_square, std::nullopt, output_mem_config);
-    output = ttnn::subtract(output, mul_unary(tmp, 0.003968253968253968f, output_mem_config), std::nullopt, output_mem_config);
+    output = ttnn::subtract(
+        output, mul_unary(tmp, 0.003968253968253968f, output_mem_config), std::nullopt, output_mem_config);

     // (1/240) * x^8
     tmp = ttnn::multiply(tmp, val_square, std::nullopt, output_mem_config);
-    output = ttnn::add(output, mul_unary(tmp, 0.004166666666666667f, output_mem_config), std::nullopt, output_mem_config);
+    output =
+        ttnn::add(output, mul_unary(tmp, 0.004166666666666667f, output_mem_config), std::nullopt, output_mem_config);

     //(1/132) * x^10
     tmp = ttnn::multiply(tmp, val_square, std::nullopt, output_mem_config);
-    output = ttnn::subtract(output, mul_unary(tmp, 0.007575757575757576, output_mem_config), std::nullopt, output_mem_config);
+    output = ttnn::subtract(
+        output, mul_unary(tmp, 0.007575757575757576, output_mem_config), std::nullopt, output_mem_config);

     //(691/32760) * x^12
     tmp = ttnn::multiply(tmp, val_square, std::nullopt, output_mem_config);
-    output = ttnn::add(output, mul_unary(tmp, 0.021092796092796094, output_mem_config), std::nullopt, output_mem_config);
+    output =
+        ttnn::add(output, mul_unary(tmp, 0.021092796092796094, output_mem_config), std::nullopt, output_mem_config);

     //(1/12) * x^14
     tmp = ttnn::multiply(tmp, val_square, std::nullopt, output_mem_config);
-    output = ttnn::subtract(output, mul_unary(tmp, 0.08333333333333333, output_mem_config), std::nullopt, output_mem_config);
+    output = ttnn::subtract(
+        output, mul_unary(tmp, 0.08333333333333333, output_mem_config), std::nullopt, output_mem_config);

     return ttnn::subtract(t_log_out, output, std::nullopt, output_mem_config);
 }
@@ -1418,16 +1475,18 @@ Tensor digamma(const Tensor& input_a, const MemoryConfig& output_mem_config) {

 // cbrt(a) = pow(a,1/3) or (cbrt(a))**3 = a.
 //         = exp[ (1/3)*log[a] ]
-Tensor _cbrt(const Tensor& input_a, const MemoryConfig& output_mem_config) {
+Tensor _cbrt(const Tensor& input_tensor, const MemoryConfig& output_mem_config) {
     constexpr float scale = (float)(1.0 / 3.0);
-    Tensor t_scale = mk_tiled_scalar(scale);
-    Tensor t_ln_input = log(abs(input_a, output_mem_config), output_mem_config);  // negative log is not useful here
-    Tensor t1 = bcast(t_ln_input, t_scale, BcastOpMath::MUL, BcastOpDim::HW, output_mem_config);
+    Tensor t_scale =
+        ttnn::operations::creation::create_scalar(scale, input_tensor.get_dtype(), Layout::TILE, input_tensor.device());
+    Tensor t_ln_input =
+        log(abs(input_tensor, output_mem_config), output_mem_config);  // negative log is not useful here
+    Tensor t1 = ttnn::multiply(t_ln_input, t_scale, std::nullopt, output_mem_config);
     t_scale.deallocate();
     t_ln_input.deallocate();
     Tensor t2 = exp(t1, output_mem_config);
     t1.deallocate();
-    Tensor t3 = ttnn::multiply(t2, sign(input_a, output_mem_config), std::nullopt, output_mem_config);
+    Tensor t3 = ttnn::multiply(t2, sign(input_tensor, output_mem_config), std::nullopt, output_mem_config);
     return t3;
 }
 Tensor cbrt(const Tensor& input_a, const MemoryConfig& output_mem_config) {
@@ -1443,57 +1502,87 @@ Tensor _where(
     const Tensor& value_false,
     const MemoryConfig& output_mem_config,
     std::optional<Tensor> output_tensor) {
-
-    Tensor t2 = ttnn::multiply(queue_id, gtz(predicate, output_mem_config), value_true, std::nullopt, output_mem_config);
-    if(output_tensor.has_value())
-    {
-        ttnn::multiply(queue_id, lez(predicate, output_mem_config), value_false, std::nullopt, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, output_tensor);
-        ttnn::add(queue_id, t2, output_tensor.value(), std::nullopt, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, output_tensor);
-    }
-    else
-    {
-        Tensor t1 = ttnn::multiply(queue_id, lez(predicate, output_mem_config), value_false, std::nullopt, output_mem_config);
+    Tensor t2 =
+        ttnn::multiply(queue_id, gtz(predicate, output_mem_config), value_true, std::nullopt, output_mem_config);
+    if (output_tensor.has_value()) {
+        ttnn::multiply(
+            queue_id,
+            lez(predicate, output_mem_config),
+            value_false,
+            std::nullopt,
+            operation::DEFAULT_OUTPUT_MEMORY_CONFIG,
+            output_tensor);
+        ttnn::add(
+            queue_id, t2, output_tensor.value(), std::nullopt, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, output_tensor);
+    } else {
+        Tensor t1 =
+            ttnn::multiply(queue_id, lez(predicate, output_mem_config), value_false, std::nullopt, output_mem_config);
        output_tensor = ttnn::add(queue_id, t2, t1, std::nullopt, output_mem_config);
     }
     return output_tensor.value();
 }
 Tensor _where_v1(
-    uint8_t queue_id, const Tensor& predicate, const float value_true, const Tensor& value_false, const MemoryConfig& output_mem_config, std::optional<Tensor> output_tensor) {
-
+    uint8_t queue_id,
+    const Tensor& predicate,
+    const float value_true,
+    const Tensor& value_false,
+    const MemoryConfig& output_mem_config,
+    std::optional<Tensor> output_tensor) {
     Tensor t2 = mul_unary(queue_id, gtz(predicate, output_mem_config), value_true, output_mem_config);
-    if(output_tensor.has_value()){
-        ttnn::multiply(queue_id, lez(predicate, output_mem_config), value_false, std::nullopt, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, output_tensor);
-        ttnn::add(queue_id, t2, output_tensor.value(), std::nullopt, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, output_tensor);
-    }
-    else
-    {
-        Tensor t1 = ttnn::multiply(queue_id, lez(predicate, output_mem_config), value_false, std::nullopt, output_mem_config);
+    if (output_tensor.has_value()) {
+        ttnn::multiply(
+            queue_id,
+            lez(predicate, output_mem_config),
+            value_false,
+            std::nullopt,
+            operation::DEFAULT_OUTPUT_MEMORY_CONFIG,
+            output_tensor);
+        ttnn::add(
+            queue_id, t2, output_tensor.value(), std::nullopt, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, output_tensor);
+    } else {
+        Tensor t1 =
+            ttnn::multiply(queue_id, lez(predicate, output_mem_config), value_false, std::nullopt, output_mem_config);
         output_tensor = ttnn::add(queue_id, t2, t1, std::nullopt, output_mem_config);
     }
     return output_tensor.value();
 }
 Tensor _where_v2(
-    uint8_t queue_id, const Tensor& predicate, const Tensor& value_true, float value_false, const MemoryConfig& output_mem_config, std::optional<Tensor> output_tensor) {
-
+    uint8_t queue_id,
+    const Tensor& predicate,
+    const Tensor& value_true,
+    float value_false,
+    const MemoryConfig& output_mem_config,
+    std::optional<Tensor> output_tensor) {
     Tensor t1 = mul_unary(queue_id, lez(predicate, output_mem_config), value_false, output_mem_config);
-    if(output_tensor.has_value()){
-        ttnn::multiply(queue_id, gtz(predicate, output_mem_config), value_true, std::nullopt, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, output_tensor);
-        ttnn::add(queue_id, output_tensor.value(), t1, std::nullopt, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, output_tensor);
-    }
-    else
-    {
-        Tensor t2 = ttnn::multiply(queue_id, gtz(predicate, output_mem_config), value_true, std::nullopt, output_mem_config);
+    if (output_tensor.has_value()) {
+        ttnn::multiply(
+            queue_id,
+            gtz(predicate, output_mem_config),
+            value_true,
+            std::nullopt,
+            operation::DEFAULT_OUTPUT_MEMORY_CONFIG,
+            output_tensor);
+        ttnn::add(
+            queue_id, output_tensor.value(), t1, std::nullopt, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, output_tensor);
+    } else {
+        Tensor t2 =
+            ttnn::multiply(queue_id, gtz(predicate, output_mem_config), value_true, std::nullopt, output_mem_config);
        output_tensor = ttnn::add(queue_id, t2, t1, std::nullopt, output_mem_config);
     }
     return output_tensor.value();
 }
 Tensor _where_v3(
-    uint8_t queue_id, const Tensor& predicate, const float value_true, const float value_false, const MemoryConfig& output_mem_config, std::optional<Tensor> output_tensor) {
+    uint8_t queue_id,
+    const Tensor& predicate,
+    const float value_true,
+    const float value_false,
+    const MemoryConfig& output_mem_config,
+    std::optional<Tensor> output_tensor) {
     Tensor t2 = mul_unary(queue_id, gtz(predicate, output_mem_config), value_true, output_mem_config);
     Tensor t1 = mul_unary(queue_id, lez(predicate, output_mem_config), value_false, output_mem_config);
-    if(output_tensor.has_value()){
+    if (output_tensor.has_value()) {
         ttnn::add(queue_id, t2, t1, std::nullopt, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, output_tensor);
     } else {
         output_tensor = ttnn::add(queue_id, t2, t1, std::nullopt, output_mem_config);
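All the `_where` variants build select from arithmetic: out = (predicate > 0) * value_true + (predicate <= 0) * value_false, which is why an optional preallocated output can be filled with two multiplies and an add. Host-side reference:

```cpp
#include <cstdio>
#include <vector>

// _where as arithmetic select: mask-multiply each branch, then add.
std::vector<float> where(const std::vector<float>& pred,
                         const std::vector<float>& t,
                         const std::vector<float>& f) {
    std::vector<float> out(pred.size());
    for (std::size_t i = 0; i < pred.size(); ++i)
        out[i] = (pred[i] > 0.f ? 1.f : 0.f) * t[i] + (pred[i] > 0.f ? 0.f : 1.f) * f[i];
    return out;
}

int main() {
    auto r = where({1.f, 0.f, -2.f}, {10.f, 20.f, 30.f}, {-1.f, -2.f, -3.f});
    for (float v : r) std::printf("%g ", v);  // 10 -2 -3
}
```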
@@ -1508,22 +1597,38 @@ Tensor where(
     const MemoryConfig& output_mem_config,
     std::optional<Tensor> output_tensor) {
     uint8_t default_queue_id = 0;
-    return operation::decorate_as_composite(__func__, _where)(default_queue_id, predicate, value_true, value_false, output_mem_config, output_tensor);
+    return operation::decorate_as_composite(__func__, _where)(
+        default_queue_id, predicate, value_true, value_false, output_mem_config, output_tensor);
 }
 Tensor where(
-    const Tensor& predicate, const float value_true, const Tensor& value_false, const MemoryConfig& output_mem_config, std::optional<Tensor> output_tensor) {
+    const Tensor& predicate,
+    const float value_true,
+    const Tensor& value_false,
+    const MemoryConfig& output_mem_config,
+    std::optional<Tensor> output_tensor) {
     uint8_t default_queue_id = 0;
-    return operation::decorate_as_composite(__func__, _where_v1)(default_queue_id, predicate, value_true, value_false, output_mem_config, output_tensor);
+    return operation::decorate_as_composite(__func__, _where_v1)(
+        default_queue_id, predicate, value_true, value_false, output_mem_config, output_tensor);
 }
 Tensor where(
-    const Tensor& predicate, const Tensor& value_true, const float value_false, const MemoryConfig& output_mem_config, std::optional<Tensor> output_tensor) {
+    const Tensor& predicate,
+    const Tensor& value_true,
+    const float value_false,
+    const MemoryConfig& output_mem_config,
+    std::optional<Tensor> output_tensor) {
     uint8_t default_queue_id = 0;
-    return operation::decorate_as_composite(__func__, _where_v2)(default_queue_id, predicate, value_true, value_false, output_mem_config, output_tensor);
+    return operation::decorate_as_composite(__func__, _where_v2)(
+        default_queue_id, predicate, value_true, value_false, output_mem_config, output_tensor);
 }
 Tensor where(
-    const Tensor& predicate, const float value_true, const float value_false, const MemoryConfig& output_mem_config, std::optional<Tensor> output_tensor) {
+    const Tensor& predicate,
+    const float value_true,
+    const float value_false,
+    const MemoryConfig& output_mem_config,
+    std::optional<Tensor> output_tensor) {
     uint8_t default_queue_id = 0;
-    return operation::decorate_as_composite(__func__, _where_v3)(default_queue_id, predicate, value_true, value_false, output_mem_config, output_tensor);
+    return operation::decorate_as_composite(__func__, _where_v3)(
+        default_queue_id, predicate, value_true, value_false, output_mem_config, output_tensor);
 }

 Tensor where(
@@ -1533,29 +1638,50 @@ Tensor where(
     const Tensor& value_false,
     const MemoryConfig& output_mem_config,
     std::optional<Tensor> output_tensor) {
-    return operation::decorate_as_composite(__func__, _where)(queue_id, predicate, value_true, value_false, output_mem_config, output_tensor);
+    return operation::decorate_as_composite(__func__, _where)(
+        queue_id, predicate, value_true, value_false, output_mem_config, output_tensor);
 }
 Tensor where(
     uint8_t queue_id,
-    const Tensor& predicate, const float value_true, const Tensor& value_false, const MemoryConfig& output_mem_config, std::optional<Tensor> output_tensor) {
-    return operation::decorate_as_composite(__func__, _where_v1)(queue_id, predicate, value_true, value_false, output_mem_config, output_tensor);
+    const Tensor& predicate,
+    const float value_true,
+    const Tensor& value_false,
+    const MemoryConfig& output_mem_config,
+    std::optional<Tensor> output_tensor) {
+    return operation::decorate_as_composite(__func__, _where_v1)(
+        queue_id, predicate, value_true, value_false, output_mem_config, output_tensor);
 }
 Tensor where(
     uint8_t queue_id,
-    const Tensor& predicate, const Tensor& value_true, const float value_false, const MemoryConfig& output_mem_config, std::optional<Tensor> output_tensor) {
-    return operation::decorate_as_composite(__func__, _where_v2)(queue_id, predicate, value_true, value_false, output_mem_config, output_tensor);
+    const Tensor& predicate,
+    const Tensor& value_true,
+    const float value_false,
+    const MemoryConfig& output_mem_config,
+    std::optional<Tensor> output_tensor) {
+    return operation::decorate_as_composite(__func__, _where_v2)(
+        queue_id, predicate, value_true, value_false, output_mem_config, output_tensor);
 }
 Tensor where(
     uint8_t queue_id,
-    const Tensor& predicate, const float value_true, const float value_false, const MemoryConfig& output_mem_config, std::optional<Tensor> output_tensor) {
-    return operation::decorate_as_composite(__func__, _where_v3)(queue_id, predicate, value_true, value_false, output_mem_config, output_tensor);
+    const Tensor& predicate,
+    const float value_true,
+    const float value_false,
+    const MemoryConfig& output_mem_config,
+    std::optional<Tensor> output_tensor) {
+    return operation::decorate_as_composite(__func__, _where_v3)(
+        queue_id, predicate, value_true, value_false, output_mem_config, output_tensor);
 }

 // on-device tensor creation 0s like @reference_tensor
-Tensor zeros_like(uint8_t queue_id, const Tensor& reference_tensor, const MemoryConfig& output_mem_config, std::optional<Tensor> output_tensor) {
+Tensor zeros_like(
+    uint8_t queue_id,
+    const Tensor& reference_tensor,
+    const MemoryConfig& output_mem_config,
+    std::optional<Tensor> output_tensor) {
     return mk_zero_tensor_like(reference_tensor, output_mem_config, output_tensor);
 }
-Tensor zeros_like(const Tensor& reference_tensor, const MemoryConfig& output_mem_config, std::optional<Tensor> output_tensor) {
+Tensor zeros_like(
+    const Tensor& reference_tensor, const MemoryConfig& output_mem_config, std::optional<Tensor> output_tensor) {
     uint8_t default_queue_id = 0;
     return mk_zero_tensor_like(default_queue_id, reference_tensor, output_mem_config, output_tensor);
 }
@@ -1566,11 +1692,20 @@ Tensor ones_like(const Tensor& reference_tensor, const MemoryConfig& output_mem_
 }

 // on-device tensor creation with value like @reference_tensor
-Tensor full_like(const Tensor& reference_tensor, float value, const MemoryConfig& output_mem_config, std::optional<Tensor> output_tensor) {
+Tensor full_like(
+    const Tensor& reference_tensor,
+    float value,
+    const MemoryConfig& output_mem_config,
+    std::optional<Tensor> output_tensor) {
     uint8_t default_queue_id = 0;
     return mk_filled_tensor_like(reference_tensor, value, output_mem_config, output_tensor, default_queue_id);
 }
-Tensor full_like(uint8_t queue_id, const Tensor& reference_tensor, float value, const MemoryConfig& output_mem_config, std::optional<Tensor> output_tensor) {
+Tensor full_like(
+    uint8_t queue_id,
+    const Tensor& reference_tensor,
+    float value,
+    const MemoryConfig& output_mem_config,
+    std::optional<Tensor> output_tensor) {
     return mk_filled_tensor_like(reference_tensor, value, output_mem_config, output_tensor, queue_id);
 }
@@ -1791,7 +1926,8 @@ Tensor sfpu_eps(const Shape shape, Layout layout, Device* device, const MemoryCo

 // tril : select lower triangular region of input matrix
 Tensor _tril(const Tensor& input_a, int32_t diag, const MemoryConfig& output_mem_config) {
-    Tensor index_l = tt::numpy::index_tril<bfloat16>(input_a.get_legacy_shape(), diag, DataType::BFLOAT16, Layout::TILE, input_a.device(), output_mem_config);
+    Tensor index_l = tt::numpy::index_tril<bfloat16>(
+        input_a.get_legacy_shape(), diag, DataType::BFLOAT16, Layout::TILE, input_a.device(), output_mem_config);
     return ttnn::multiply(input_a, index_l, std::nullopt, output_mem_config);
 }
 Tensor tril(
@@ -1803,7 +1939,8 @@ Tensor tril(

 // triu : select upper triangular region of input matrix
 Tensor _triu(const Tensor& input_a, int32_t diag, const MemoryConfig& output_mem_config) {
-    Tensor index_u = tt::numpy::index_triu<bfloat16>(input_a.get_legacy_shape(), diag, DataType::BFLOAT16, Layout::TILE, input_a.device(), output_mem_config);
+    Tensor index_u = tt::numpy::index_triu<bfloat16>(
+        input_a.get_legacy_shape(), diag, DataType::BFLOAT16, Layout::TILE, input_a.device(), output_mem_config);
     return ttnn::multiply(input_a, index_u, std::nullopt, output_mem_config);
 }
 Tensor triu(
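`_tril` and `_triu` reduce to a single `ttnn::multiply` because the triangle is selected by an index-generated 0/1 mask. A host-side sketch of the masking:

```cpp
#include <cstdio>

// tril selects the lower triangle by multiplying with a 0/1 mask;
// triu is the same idea with the comparison flipped.
int main() {
    const int N = 4;
    float a[N][N], out[N][N];
    for (int i = 0; i < N; ++i)
        for (int j = 0; j < N; ++j) a[i][j] = 1.0f + i * N + j;

    int diag = 0;  // same role as the diag argument of tril()
    for (int i = 0; i < N; ++i)
        for (int j = 0; j < N; ++j) {
            float mask = (j - i <= diag) ? 1.0f : 0.0f;  // lower-triangular mask
            out[i][j] = a[i][j] * mask;
        }
    std::printf("out[0][3] = %g, out[3][0] = %g\n", out[0][3], out[3][0]);  // 0, 13
}
```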
@@ -1824,7 +1961,8 @@ Tensor _power_fp(const Tensor& input_a, float exponent, const MemoryConfig& outp
     Tensor pow_frac = exp(pow_trunc_log, output_mem_config);
     pow_trunc_log.deallocate();
     float t_nan = std::nanf("");
-    Tensor result = ttnn::multiply(power(input_a, exponent_floor, output_mem_config), pow_frac, std::nullopt, output_mem_config);
+    Tensor result =
+        ttnn::multiply(power(input_a, exponent_floor, output_mem_config), pow_frac, std::nullopt, output_mem_config);
     // To handle negative inputs:
     // in torch For -ve inputs with float exponent power returns nan
     result = where(ltz(input_a, output_mem_config), t_nan, result);
@@ -1877,14 +2015,13 @@ Tensor _argmax(const Tensor& input_t, int64_t _dim, bool all, const MemoryConfig
             bool is_width = (dim == (input_shape.rank() - 1));
             Tensor max_val = max(input_a, dim, output_mem_config);
             Tensor max_tensor = zeros_like(input_a, output_mem_config);
-            Tensor tindex = tt::numpy::index_width(input_shape, DataType::BFLOAT16, Layout::TILE, input_a.device(), output_mem_config);
-            if (is_width)
-            {
+            Tensor tindex = tt::numpy::index_width(
+                input_shape, DataType::BFLOAT16, Layout::TILE, input_a.device(), output_mem_config);
+            if (is_width) {
                 max_tensor = bcast(max_tensor, max_val, BcastOpMath::ADD, BcastOpDim::W, output_mem_config);
-            }
-            else
-            {
-                tindex = tt::numpy::index_height(input_shape, DataType::BFLOAT16, Layout::TILE, input_a.device(), output_mem_config);
+            } else {
+                tindex = tt::numpy::index_height(
+                    input_shape, DataType::BFLOAT16, Layout::TILE, input_a.device(), output_mem_config);
                 max_tensor = bcast(max_tensor, max_val, BcastOpMath::ADD, BcastOpDim::H, output_mem_config);
             }
             tindex = tindex.to(input_a.device());
@@ -1921,13 +2058,14 @@ Tensor _argmax(const Tensor& input_t, int64_t _dim, bool all, const MemoryConfig
                 concat_out = ttnn::reshape(concat_out, input_a.get_shape());
                 Tensor cmp_results = ttnn::eq(input_a, concat_out, std::nullopt, output_mem_config);
                 concat_out.deallocate();
-                Tensor tindex = tt::numpy::index_channel(input_shape, DataType::BFLOAT16, Layout::TILE, input_a.device(), output_mem_config);
-                if (!is_channel)
-                {
-                    tindex = tt::numpy::index_batch(input_shape, DataType::BFLOAT16, Layout::TILE, input_a.device(), output_mem_config);
+                Tensor tindex = tt::numpy::index_channel(
+                    input_shape, DataType::BFLOAT16, Layout::TILE, input_a.device(), output_mem_config);
+                if (!is_channel) {
+                    tindex = tt::numpy::index_batch(
+                        input_shape, DataType::BFLOAT16, Layout::TILE, input_a.device(), output_mem_config);
                 }
                 tindex = tindex.to(input_a.device());
-                Tensor max_indices = ttnn::multiply(cmp_results, tindex, std::nullopt, output_mem_config);
+                Tensor max_indices = ttnn::multiply(cmp_results, tindex, std::nullopt, output_mem_config);
                 cmp_results.deallocate();
                 Tensor midx = full_like(max_indices, size);
                 Tensor result = where(eqz(max_indices), midx, max_indices, output_mem_config);
@@ -1945,8 +2083,10 @@ Tensor _argmax(const Tensor& input_t, int64_t _dim, bool all, const MemoryConfig
             }
         }
     }
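The `_argmax` hunks above build argmax out of elementwise ops: the input is compared against its broadcast maximum, the 0/1 comparison mask is multiplied by an index tensor, zeros are swapped for the sentinel `size`, and a min-reduction recovers the first matching position. A plain-C++ analogue of that logic, loosely mirroring the device code (illustrative only, not part of the patch):

    #include <algorithm>
    #include <vector>
    // Positions equal to the max keep their index; all others become the
    // sentinel x.size() (the role `midx` plays above); the minimum surviving
    // value is the first argmax.
    std::size_t argmax_like(const std::vector<float>& x) {
        const float mx = *std::max_element(x.begin(), x.end());
        std::size_t best = x.size();
        for (std::size_t i = 0; i < x.size(); ++i)
            if (x[i] == mx) best = std::min(best, i);
        return best;
    }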
-    //TODO: Fix the index generation code. With the fix the code will work for argmax that return entire maximum value index
-    Tensor tindex = tt::numpy::index_all(input_shape, DataType::BFLOAT16, Layout::TILE, input_a.device(), output_mem_config);
+    // TODO: Fix the index generation code. With the fix, the code will work for argmax that returns the
+    // indices of all maximum values
+    Tensor tindex = tt::numpy::index_all(
+        input_shape, DataType::BFLOAT16, Layout::TILE, input_a.device(), output_mem_config);
     Tensor max_val = global_max(input_a, output_mem_config);
     Tensor max_tensor = zeros_like(input_a, output_mem_config);
     max_tensor = bcast(max_tensor, max_val, BcastOpMath::ADD, BcastOpDim::HW, output_mem_config);
diff --git a/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp b/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp
index 40555a7dc75..31b399fc981 100644
--- a/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp
+++ b/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp
@@ -11,6 +11,7 @@
 #include "tt_dnn/op_library/bcast/bcast_op.hpp"
 #include "tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp"
 #include "tt_metal/common/constants.hpp"
+#include "ttnn/cpp/ttnn/operations/creation.hpp"
 
 #include "ttnn/operations/eltwise/binary/device/binary_op.hpp"
 
@@ -24,44 +25,6 @@ using binary_tensor_op_t = Tensor(const Tensor& a, const Tensor& b);
 
 // Note: inline doesn't allow pybind to work well so we keep few function not inlined.
 
-template <typename T>
-Tensor mk_scalar(T value) {
-    assert(std::is_scalar<T>::value && "T should be scalar");
-    std::array<uint32_t, 4> shape = {1, 1, 1, 1};
-    auto buffer = owned_buffer::create(std::vector<bfloat16>{bfloat16(value)});
-    Tensor scalar = Tensor(OwnedStorage{buffer}, shape, DataType::BFLOAT16, Layout::ROW_MAJOR);
-    return scalar;
-}
-
-template <typename T>
-Tensor mk_tiled_scalar(T value) {
-    assert(std::is_scalar<T>::value && "T should be scalar");
-    std::array<uint32_t, 4> shape = {1, 1, TILE_HEIGHT, TILE_WIDTH};
-    std::vector<bfloat16> buffer_vec(TILE_HW, bfloat16(0));
-    buffer_vec[0] = bfloat16(value);
-    auto buffer = owned_buffer::create(std::move(buffer_vec));
-    Tensor scalar = Tensor(OwnedStorage{buffer}, shape, DataType::BFLOAT16, Layout::TILE);
-    return scalar;
-}
-
-template <typename T>
-Tensor mk_tiled_scalar(T value, DataType dtype) {
-    assert(std::is_scalar<T>::value && "T should be scalar");
-    std::array<uint32_t, 4> shape = {1, 1, TILE_HEIGHT, TILE_WIDTH};
-    if(dtype == DataType::BFLOAT8_B)
-    {
-        std::vector<float> buffer_vec(TILE_HW, float(0));
-        buffer_vec[0] = float(value);
-        auto output_packed_data = pack_fp32_vec_as_bfp8_tiles(buffer_vec, /*row_major_input=*/false, /*is_exp_a=*/false);
-        auto output_uint32_buffer = owned_buffer::create<uint32_t>(std::move(output_packed_data));
-        return Tensor(std::move(OwnedStorage{std::move(output_uint32_buffer)}), shape, DataType::BFLOAT8_B, Layout::TILE);
-    }
-    std::vector<bfloat16> buffer_vec(TILE_HW, bfloat16(0));
-    buffer_vec[0] = bfloat16(value);
-    auto buffer = owned_buffer::create(std::move(buffer_vec));
-    Tensor scalar = Tensor(OwnedStorage{buffer}, shape, DataType::BFLOAT16, Layout::TILE);
-    return scalar;
-}
 // Function: softshrink
 // Ref: https://pytorch.org/docs/stable/generated/torch.nn.Softshrink.html
 Tensor softshrink(
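The three `mk_*` helpers removed here are superseded by `ttnn::operations::creation::create_scalar`, added later in this patch in ttnn/cpp/ttnn/operations/creation.hpp. A sketch of the replacement for a former `mk_tiled_scalar(value)` call, with `input` standing in for whichever tensor supplies the dtype and device:

    // Builds the padded scalar tile directly on the target device instead of
    // a host-side 1x1xTILE_HEIGHTxTILE_WIDTH bfloat16 tensor.
    Tensor t_value = ttnn::operations::creation::create_scalar(
        value, input.get_dtype(), Layout::TILE, input.device());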
diff --git a/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp b/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp
index b876ff5f465..bb8008a5b3a 100644
--- a/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp
+++ b/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp
@@ -482,8 +482,14 @@ const operation::Hash EltwiseUnary::compute_program_hash(const std::vector
 
-Tensor tie_binop_to_unary(uint8_t queue_id, const Tensor& input_tensor, float value, const MemoryConfig& output_mem_config, std::optional<Tensor> output_tensor = std::nullopt) {
-    Tensor t_value = mk_tiled_scalar(value, input_tensor.get_dtype());
+Tensor tie_binop_to_unary(
+    uint8_t queue_id,
+    const Tensor& input_tensor,
+    float value,
+    const MemoryConfig& output_mem_config,
+    std::optional<Tensor> output_tensor = std::nullopt) {
+    const DataType& dtype = output_tensor.has_value() ? output_tensor.value().get_dtype() : input_tensor.get_dtype();
+    Tensor t_value = ttnn::operations::creation::create_scalar(value, dtype, Layout::TILE, input_tensor.device());
     return bcast(queue_id, input_tensor, t_value, OP, BcastOpDim::HW, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, output_tensor);
 }
 
diff --git a/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_op.cpp b/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_op.cpp
index 33197420599..f87918b615e 100644
--- a/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_op.cpp
+++ b/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_op.cpp
@@ -8,6 +8,7 @@
 #include
 
 #include "tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_op.hpp"
+#include "ttnn/cpp/ttnn/operations/creation.hpp"
 
 namespace tt {
 
@@ -225,7 +226,9 @@ Tensor moreh_clip_grad_norm_impl(
     // max_norm / (total_norm + 1e-6)
     const auto &clip_coef = div_unary(max_norm, add_unary(total_norm, 1e-6f));
     // min(clip_coef, 1.0f)
-    const auto &clip_coef_clamped = min(clip_coef, mk_tiled_scalar(1.0f));
+    Tensor scalar = ttnn::operations::creation::create_scalar(
+        1.0f, inputs.at(0).get_dtype(), Layout::TILE, inputs.at(0).device());
+    const auto &clip_coef_clamped = min(clip_coef, scalar);
+    scalar.deallocate();
     // Inplace update inputs(inputs *= clip_coef_clamped)
     moreh_clip_grad_norm_step3(inputs, clip_coef_clamped);
 
diff --git a/tt_eager/tt_dnn/op_library/transformer_tms/transformer_tms.cpp b/tt_eager/tt_dnn/op_library/transformer_tms/transformer_tms.cpp
index 901e1807864..f30ca5f9df1 100644
--- a/tt_eager/tt_dnn/op_library/transformer_tms/transformer_tms.cpp
+++ b/tt_eager/tt_dnn/op_library/transformer_tms/transformer_tms.cpp
@@ -286,6 +286,8 @@ tt::stl::reflection::Attributes AttnMatmul::attributes() const {
 }
 
 const operation::Hash AttnMatmul::compute_program_hash(const std::vector<Tensor>& input_tensors) const {
+    TT_ASSERT(std::holds_alternative<DeviceStorage>(input_tensors.at(0).storage()), fmt::format("Unexpected type {} in {}:{} ",tt::stl::get_active_type_name_in_variant(input_tensors.at(0).get_storage()),__FILE__, __LINE__));
+    TT_ASSERT(std::holds_alternative<DeviceStorage>(input_tensors.at(1).storage()), fmt::format("Unexpected type {} in {}:{} ",tt::stl::get_active_type_name_in_variant(input_tensors.at(1).get_storage()),__FILE__, __LINE__));
     return operation::hash_operation<AttnMatmul>(
         this->transpose_hw,
         this->output_mem_config,
@@ -492,6 +494,9 @@ const operation::Hash GroupAttnMatmul::compute_program_hash(const std::vector
 
+    TT_ASSERT(std::holds_alternative<DeviceStorage>(input_tensor_a.storage()), fmt::format("Unexpected type {} in {}:{} ",tt::stl::get_active_type_name_in_variant(input_tensor_a.storage()),__FILE__, __LINE__));
+    TT_ASSERT(std::holds_alternative<DeviceStorage>(input_tensor_b.storage()), fmt::format("Unexpected type {} in {}:{} ",tt::stl::get_active_type_name_in_variant(input_tensor_b.storage()),__FILE__, __LINE__));
+
     return operation::hash_operation<GroupAttnMatmul>(
         this->transpose_hw,
         this->out_subblock_w,
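These `compute_program_hash` implementations go on to read the memory config out of `DeviceStorage`, so the new asserts turn a silent `std::get` failure into a message naming the variant alternative that was actually held. The pattern, sketched generically with `tensor` as a stand-in:

    // Check the alternative before std::get, and report the active type
    // if the precondition fails.
    TT_ASSERT(
        std::holds_alternative<DeviceStorage>(tensor.storage()),
        fmt::format("Unexpected type {} in {}:{} ",
            tt::stl::get_active_type_name_in_variant(tensor.storage()), __FILE__, __LINE__));
    auto mem_config = std::get<DeviceStorage>(tensor.storage()).memory_config();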
diff --git a/tt_eager/tt_dnn/op_library/transpose/transpose_op.cpp b/tt_eager/tt_dnn/op_library/transpose/transpose_op.cpp
index b730ce879f3..588289546a1 100644
--- a/tt_eager/tt_dnn/op_library/transpose/transpose_op.cpp
+++ b/tt_eager/tt_dnn/op_library/transpose/transpose_op.cpp
@@ -150,6 +150,7 @@ TransposeOpParallelizationStrategy Transpose::get_parallelization_strategy(const
 const operation::Hash Transpose::compute_program_hash(
     const std::vector<Tensor> &input_tensors) const {
     auto input_tensor = input_tensors.at(0);
+    TT_ASSERT(std::holds_alternative<DeviceStorage>(input_tensor.storage()), fmt::format("Unexpected type {} in {}:{} ",tt::stl::get_active_type_name_in_variant(input_tensor.get_storage()),__FILE__, __LINE__));
     auto input_mem_config = std::get<DeviceStorage>(input_tensor.storage()).memory_config();
     auto output_mem_config = this->output_mem_config;
     auto dtype = input_tensor.dtype();
diff --git a/tt_eager/tt_dnn/op_library/unpad/unpad_op.cpp b/tt_eager/tt_dnn/op_library/unpad/unpad_op.cpp
index b8f437d2138..3ea74358be9 100644
--- a/tt_eager/tt_dnn/op_library/unpad/unpad_op.cpp
+++ b/tt_eager/tt_dnn/op_library/unpad/unpad_op.cpp
@@ -147,6 +147,7 @@ tt::stl::reflection::Attributes Unpad::attributes() const {
 
 const operation::Hash Unpad::compute_program_hash(const std::vector<Tensor> &input_tensors) const {
     auto input_tensor = input_tensors.at(0);
+    TT_ASSERT(std::holds_alternative<DeviceStorage>(input_tensor.storage()), fmt::format("Unexpected type {} in {}:{} ",tt::stl::get_active_type_name_in_variant(input_tensor.get_storage()),__FILE__, __LINE__));
     auto input_mem_config = std::get<DeviceStorage>(input_tensor.storage()).memory_config();
     auto output_mem_config = this->output_mem_config;
     auto dtype = input_tensor.dtype();
diff --git a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_pytensor.cpp b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_pytensor.cpp
index d27cc6592e3..045b50e1c24 100644
--- a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_pytensor.cpp
+++ b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_pytensor.cpp
@@ -363,6 +363,7 @@ Tensor convert_python_tensors_to_tt_tensors(py::list tensor_shards, std::optiona
         std::vector<OwnedBuffer> host_owned_buffers;
         std::vector<Shape> host_owned_shapes;
         for (const auto &shard : tt_shards) {
+            TT_ASSERT(std::holds_alternative<OwnedStorage>(shard.get_storage()), fmt::format("Unexpected type {} in {}:{} ",tt::stl::get_active_type_name_in_variant(shard.get_storage()),__FILE__, __LINE__));
             host_owned_buffers.push_back(std::get<OwnedStorage>(shard.get_storage()).buffer);
             host_owned_shapes.push_back(shard.get_legacy_shape());
         }
@@ -428,12 +429,14 @@ Tensor convert_python_tensors_to_tt_tensors(py::list tensor_shards, std::optiona
 
     auto tt_dtype = tt_tensor.get_dtype();
     if (tt_dtype == DataType::BFLOAT8_B) {
+        TT_ASSERT(std::holds_alternative<OwnedBuffer>(buffer), fmt::format("Unexpected type {} in {}:{} ",tt::stl::get_active_type_name_in_variant(buffer),__FILE__, __LINE__));
         auto uint32_data = std::get<owned_buffer::Buffer<std::uint32_t>>(std::get<OwnedBuffer>(buffer)).get();
         auto float_unpacked_data = unpack_bfp8_tiles_into_float_vec(uint32_data, /*row_major_output=*/false, /*is_exp_a=*/false);
         buffer = owned_buffer::create<float>(std::move(float_unpacked_data));
         tt_dtype = DataType::FLOAT32;
     }
     if (tt_dtype == DataType::BFLOAT4_B) {
+        TT_ASSERT(std::holds_alternative<OwnedBuffer>(buffer), fmt::format("Unexpected type {} in {}:{} ",tt::stl::get_active_type_name_in_variant(buffer),__FILE__, __LINE__));
         auto uint32_data = std::get<owned_buffer::Buffer<std::uint32_t>>(std::get<OwnedBuffer>(buffer)).get();
         auto float_unpacked_data = unpack_bfp4_tiles_into_float_vec(uint32_data, /*row_major_output=*/false, /*is_exp_a=*/false);
         buffer = owned_buffer::create<float>(std::move(float_unpacked_data));
@@ -489,6 +492,7 @@ Tensor convert_python_tensors_to_tt_tensors(py::list tensor_shards, std::optiona
 
     auto tt_dtype = tt_tensor.get_dtype();
     if (tt_dtype == DataType::BFLOAT8_B) {
+        TT_ASSERT(std::holds_alternative<OwnedBuffer>(buffer), fmt::format("Unexpected type {} in {}:{} ",tt::stl::get_active_type_name_in_variant(buffer),__FILE__, __LINE__));
         auto uint32_data = std::get<owned_buffer::Buffer<std::uint32_t>>(std::get<OwnedBuffer>(buffer)).get();
         auto float_unpacked_data = unpack_bfp8_tiles_into_float_vec(uint32_data, /*row_major_output=*/false, /*is_exp_a=*/false);
@@ -496,6 +500,7 @@ Tensor convert_python_tensors_to_tt_tensors(py::list tensor_shards, std::optiona
         tt_dtype = DataType::FLOAT32;
     }
     if (tt_dtype == DataType::BFLOAT4_B) {
+        TT_ASSERT(std::holds_alternative<OwnedBuffer>(buffer), fmt::format("Unexpected type {} in {}:{} ",tt::stl::get_active_type_name_in_variant(buffer),__FILE__, __LINE__));
         auto uint32_data = std::get<owned_buffer::Buffer<std::uint32_t>>(std::get<OwnedBuffer>(buffer)).get();
         auto float_unpacked_data = unpack_bfp4_tiles_into_float_vec(uint32_data, /*row_major_output=*/false, /*is_exp_a=*/false);
diff --git a/tt_metal/detail/reports/compilation_reporter.cpp b/tt_metal/detail/reports/compilation_reporter.cpp
index 5efb85ed0a5..fabee0a191b 100644
--- a/tt_metal/detail/reports/compilation_reporter.cpp
+++ b/tt_metal/detail/reports/compilation_reporter.cpp
@@ -66,7 +66,7 @@ std::string kernel_attributes_str(std::shared_ptr<Kernel> kernel) {
     if (std::holds_alternative<DataMovementConfig>(config)) {
         attr_str += "NOC: " + std::to_string(std::get<DataMovementConfig>(config).noc) + " ";
     } else {
-        TT_ASSERT(std::holds_alternative<ComputeConfig>(config));
+        TT_ASSERT(std::holds_alternative<ComputeConfig>(config), fmt::format("Unexpected type {} in {}:{} ",tt::stl::get_active_type_name_in_variant(config),__FILE__, __LINE__));
         auto compute_config = std::get<ComputeConfig>(config);
         std::stringstream math_fidel_str;
         math_fidel_str << compute_config.math_fidelity;
diff --git a/tt_metal/tt_stl/reflection.hpp b/tt_metal/tt_stl/reflection.hpp
index c684ca25098..3b225fe47d9 100644
--- a/tt_metal/tt_stl/reflection.hpp
+++ b/tt_metal/tt_stl/reflection.hpp
@@ -32,6 +32,16 @@ constexpr std::string_view get_type_name(const T& object) {
     return get_type_name<T>();
 }
 
+template <typename T>
+concept IsVariant = requires { typename std::variant_size<T>::type; };
+
+template <IsVariant Variant>
+constexpr auto get_active_type_name_in_variant(const Variant& v) {
+    return std::visit([](auto&& arg) -> std::string_view {
+        return short_type_name<std::decay_t<decltype(arg)>>;
+    }, v);
+}
+
 // Forward Declare hash_object
 namespace hash {
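`get_active_type_name_in_variant` visits the variant and returns `short_type_name` of whichever alternative is currently held, which is what the new TT_ASSERT messages interpolate. A small usage sketch (the variant and its alternatives here are illustrative):

    std::variant<int, float> v = 3.14f;
    // name is the short name of the active alternative, e.g. "float" here.
    std::string_view name = tt::stl::get_active_type_name_in_variant(v);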
diff --git a/ttnn/cpp/ttnn/op_library/to_dtype/to_dtype_op.hpp b/ttnn/cpp/ttnn/op_library/to_dtype/to_dtype_op.hpp
index 803529b0127..82268ce6ff2 100644
--- a/ttnn/cpp/ttnn/op_library/to_dtype/to_dtype_op.hpp
+++ b/ttnn/cpp/ttnn/op_library/to_dtype/to_dtype_op.hpp
@@ -44,12 +44,14 @@ inline Tensor convert_to_cpp_supported_dtype(const Tensor& input_tensor) {
         input_tensor.get_storage());
 
     if (input_dtype == DataType::BFLOAT8_B) {
+        TT_ASSERT(std::holds_alternative<OwnedBuffer>(buffer), fmt::format("Unexpected type {} in {}:{} ",tt::stl::get_active_type_name_in_variant(buffer),__FILE__, __LINE__));
         auto uint32_data = std::get<owned_buffer::Buffer<std::uint32_t>>(std::get<OwnedBuffer>(buffer)).get();
         auto float_unpacked_data = unpack_bfp8_tiles_into_float_vec(uint32_data, /*row_major_output=*/false, /*is_exp_a=*/false);
         buffer = owned_buffer::create<float>(std::move(float_unpacked_data));
         input_dtype = DataType::FLOAT32;
     } else if (input_dtype == DataType::BFLOAT4_B) {
+        TT_ASSERT(std::holds_alternative<OwnedBuffer>(buffer), fmt::format("Unexpected type {} in {}:{} ",tt::stl::get_active_type_name_in_variant(buffer),__FILE__, __LINE__));
         auto uint32_data = std::get<owned_buffer::Buffer<std::uint32_t>>(std::get<OwnedBuffer>(buffer)).get();
         auto float_unpacked_data = unpack_bfp4_tiles_into_float_vec(uint32_data, /*row_major_output=*/false, /*is_exp_a=*/false);
diff --git a/ttnn/cpp/ttnn/operations/creation.hpp b/ttnn/cpp/ttnn/operations/creation.hpp
index 41bb024b840..f082bb2ae99 100644
--- a/ttnn/cpp/ttnn/operations/creation.hpp
+++ b/ttnn/cpp/ttnn/operations/creation.hpp
@@ -17,6 +17,42 @@ namespace ttnn {
 namespace operations {
 namespace creation {
 
+template <typename T, std::size_t rank = 4>
+Tensor create_scalar(T scalar, DataType data_type, Layout layout, Device* device) {
+    static_assert(rank >= 2, "Rank must be at least 2 when creating a tensor with TILE_LAYOUT");
+    std::array<std::uint32_t, rank> intended_shape = {};
+    intended_shape.fill(1);
+    std::array<std::uint32_t, rank> device_shape = {};
+    device_shape.fill(1);
+
+    if (layout == Layout::ROW_MAJOR) {
+        device_shape[device_shape.size() - 2] = 2;
+        auto host_buffer = owned_buffer::create<::bfloat16>(static_cast<std::size_t>(2));
+        host_buffer[0] = scalar;
+        Tensor scalar_tensor_host = Tensor(
+            OwnedStorage{host_buffer},
+            ttnn::Shape(intended_shape, device_shape),
+            data_type,
+            Layout::ROW_MAJOR);
+        return scalar_tensor_host.to(device);
+    } else if (layout == Layout::TILE) {
+        device_shape[device_shape.size() - 2] = TILE_HEIGHT;
+        device_shape[device_shape.size() - 1] = TILE_WIDTH;
+        auto host_buffer = owned_buffer::create<::bfloat16>(static_cast<std::size_t>(TILE_HEIGHT * TILE_WIDTH));
+        host_buffer[0] = scalar;
+        Tensor scalar_tensor_host = Tensor(
+            OwnedStorage{host_buffer},
+            ttnn::Shape(intended_shape, device_shape),
+            data_type,
+            Layout::TILE);
+        return scalar_tensor_host.to(device);
+    } else {
+        throw std::runtime_error("Unsupported layout");
+    }
+}
+
 template <typename T>
 inline ttnn::Tensor full(
     const ttnn::Shape& shape,
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_op.cpp
index 876c29d2243..db148946524 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_op.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_op.cpp
@@ -361,6 +361,8 @@ tt::stl::hash::hash_t Binary::compute_program_hash(
     const auto& input_tensor_b = tensor_args.input_tensor_b;
 
     auto program_factory = select_program_factory(attributes, tensor_args);
+    TT_ASSERT(std::holds_alternative<DeviceStorage>(input_tensor_a.get_storage()), fmt::format("Unexpected type {} in {}:{} ",tt::stl::get_active_type_name_in_variant(input_tensor_a.get_storage()),__FILE__, __LINE__));
+    TT_ASSERT(std::holds_alternative<DeviceStorage>(input_tensor_b.get_storage()), fmt::format("Unexpected type {} in {}:{} ",tt::stl::get_active_type_name_in_variant(input_tensor_b.get_storage()),__FILE__, __LINE__));
     operation::Hash hash = operation::hash_operation<Binary>(
         attributes,
         program_factory.index(),