From c4e1e9d068886022feac998256be6e891e17cb8c Mon Sep 17 00:00:00 2001 From: Akhmed Rakhmati Date: Fri, 7 Jun 2024 04:32:46 +0000 Subject: [PATCH] #8835: added TMP-based device operation infra --- .../unit_tests/operations/test_relational.py | 16 +- tt_eager/tensor/tensor.hpp | 12 +- tt_eager/tensor/types.hpp | 22 +- tt_eager/tt_dnn/op_library/run_operation.cpp | 4 +- tt_eager/tt_dnn/op_library/run_operation.hpp | 6 - .../csrc/tt_lib_bindings_tensor_pytensor.cpp | 2 +- tt_metal/impl/device/program_cache.hpp | 2 +- tt_metal/tools/profiler/op_profiler.hpp | 43 ++- ttnn/cpp/ttnn/decorators.hpp | 4 +- ttnn/cpp/ttnn/device_operation.hpp | 311 ++++++++++++++++ .../ttnn/operations/eltwise/binary/binary.hpp | 18 +- .../eltwise/binary/device/binary_op.cpp | 335 +++++++---------- .../eltwise/binary/device/binary_op.hpp | 165 +++++--- .../eltwise/binary/device/binary_op_type.hpp | 49 +++ ...t_and_width_multi_core_program_factory.hpp | 352 ++++++++++++++++++ ...cast_height_multi_core_program_factory.hpp | 314 ++++++++++++++++ ...dcast_width_multi_core_program_factory.hpp | 316 ++++++++++++++++ ...ement_wise_multi_core_program_factory.hpp} | 275 +++++++------- 18 files changed, 1813 insertions(+), 433 deletions(-) create mode 100644 ttnn/cpp/ttnn/device_operation.hpp create mode 100644 ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_op_type.hpp create mode 100644 ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_and_width_multi_core_program_factory.hpp create mode 100644 ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_multi_core_program_factory.hpp create mode 100644 ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_width_multi_core_program_factory.hpp rename ttnn/cpp/ttnn/operations/eltwise/binary/device/{binary_program_factory.hpp => element_wise_multi_core_program_factory.hpp} (60%) diff --git a/tests/ttnn/unit_tests/operations/test_relational.py b/tests/ttnn/unit_tests/operations/test_relational.py index 6bbcd725483..36860e882a6 100644 --- a/tests/ttnn/unit_tests/operations/test_relational.py +++ b/tests/ttnn/unit_tests/operations/test_relational.py @@ -236,16 +236,16 @@ def test_expand_and_broadcast(device, h, w): @pytest.mark.parametrize("h", [500]) @pytest.mark.parametrize("w", [512]) def test_expand_and_broadcast_reversed(device, h, w): - torch_a = torch.rand((1, h, w), dtype=torch.bfloat16) - torch_b = torch.rand((h, w), dtype=torch.bfloat16) - torch_output = torch.lt(torch_b, torch_a) + torch_input_tensor_a = torch.rand((1, h, w), dtype=torch.bfloat16) + torch_input_tensor_b = torch.rand((h, w), dtype=torch.bfloat16) + torch_output = torch.lt(torch_input_tensor_b, torch_input_tensor_a) - a = ttnn.from_torch(torch_a, layout=ttnn.TILE_LAYOUT, device=device) - b = ttnn.from_torch(torch_b, layout=ttnn.TILE_LAYOUT, device=device) - tt_output = ttnn.lt(b, a) - tt_output = ttnn.to_torch(tt_output) + input_tensor_a = ttnn.from_torch(torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device) + input_tensor_b = ttnn.from_torch(torch_input_tensor_b, layout=ttnn.TILE_LAYOUT, device=device) + output = ttnn.lt(input_tensor_b, input_tensor_a) + output = ttnn.to_torch(output) - assert_with_pcc(torch_output, tt_output, 0.9999) + assert_with_pcc(torch_output, output, 0.9999) @pytest.mark.parametrize("atol", [1e-8, 1e-10]) diff --git a/tt_eager/tensor/tensor.hpp b/tt_eager/tensor/tensor.hpp index 226f7913e13..39da6adc550 100644 --- a/tt_eager/tensor/tensor.hpp +++ b/tt_eager/tensor/tensor.hpp @@ -43,7 +43,8 @@ struct Tensor { bool track_ref_count = false; 
TensorAttributes(const Storage storage, const ttnn::Shape shape, DataType dtype, Layout layout) : storage(storage), shape(shape), dtype(dtype), layout(layout) {} - TensorAttributes() : shape({0xff, 0xff, 0xff, 0xff}), dtype(DataType::INVALID), layout(Layout::INVALID) {} + TensorAttributes() : + shape(std::array{0xff, 0xff, 0xff, 0xff}), dtype(DataType::INVALID), layout(Layout::INVALID) {} ~TensorAttributes() = default; // Use these functions to manage the main_thread_ref_count for a tensor attr instance. @@ -392,6 +393,15 @@ Tensor create_device_tensor( Device *device, const MemoryConfig &memory_config = {.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED}); +static Tensor create_device_tensor( + const ttnn::Shape &shape, + DataType dtype, + Layout layout, + Device *device, + const MemoryConfig &memory_config = {.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED}) { + return create_device_tensor(shape.value(), dtype, layout, device, memory_config); +} + // template // void *get_host_buffer(const Tensor &tensor); void *get_raw_host_data_ptr(const Tensor &tensor); diff --git a/tt_eager/tensor/types.hpp b/tt_eager/tensor/types.hpp index cb361a46c1f..952321a71b3 100644 --- a/tt_eager/tensor/types.hpp +++ b/tt_eager/tensor/types.hpp @@ -173,19 +173,21 @@ class Shape { } template - explicit Shape(const std::array &shape, const std::array &shape_tile_padding) : + explicit Shape(const std::array &shape, const std::array &shape_with_tile_padding) : rank_(Rank), dimensions_{}, padding_{Rank} { for (auto index = 0; index < Rank; index++) { - auto padded_dimension = shape_tile_padding[index]; + auto padded_dimension = shape_with_tile_padding[index]; this->dimensions_[index] = padded_dimension; this->padding_[index] = {.front = 0, .back = padded_dimension - shape[index]}; } } - explicit Shape(const std::vector &shape, const std::vector &shape_tile_padding) : + explicit Shape(const std::vector &shape, const std::vector &shape_with_tile_padding) : rank_(shape.size()), dimensions_{}, padding_{shape.size()} { - TT_ASSERT(shape.size() == shape_tile_padding.size(), "Shape and shape_tile_padding must have the same size"); + TT_ASSERT( + shape.size() == shape_with_tile_padding.size(), + "Shape and shape_with_tile_padding must have the same size"); for (auto index = 0; index < shape.size(); index++) { - auto padded_dimension = shape_tile_padding[index]; + auto padded_dimension = shape_with_tile_padding[index]; this->dimensions_[index] = padded_dimension; this->padding_[index] = {.front = 0, .back = padded_dimension - shape[index]}; } @@ -720,14 +722,20 @@ struct Shape { explicit Shape(const std::array &shape) : ranked_shape{RankedShape{shape}} {} template - explicit Shape(const std::array &shape, const std::array &shape_tile_padding) : - ranked_shape{RankedShape{shape, shape_tile_padding}} {} + explicit Shape(const std::array &shape, const std::array &shape_with_tile_padding) : + ranked_shape{RankedShape{shape, shape_with_tile_padding}} {} template explicit Shape( const std::array &shape, const std::array, Rank> &tile_padding) : ranked_shape{RankedShape{shape, tile_padding}} {} + static Shape from_vector(const std::vector &shape) { return Shape{tt::tt_metal::Shape{shape}}; } + + static Shape from_vector(const std::vector &shape, const std::vector &shape_with_tile_padding) { + return Shape{tt::tt_metal::Shape{shape, shape_with_tile_padding}}; + } + const auto rank() const { return std::visit( [](const RankedShape &shape) -> const auto { return Rank; }, this->ranked_shape); diff --git 
a/tt_eager/tt_dnn/op_library/run_operation.cpp b/tt_eager/tt_dnn/op_library/run_operation.cpp index 050913c3b8b..389527ed3b2 100644 --- a/tt_eager/tt_dnn/op_library/run_operation.cpp +++ b/tt_eager/tt_dnn/op_library/run_operation.cpp @@ -119,7 +119,7 @@ constexpr auto decorate_device_operation(const Function& function) { template OutputTensors run_host_operation(const HostOperation& operation, const Tensors& input_tensors) { ZoneScopedN("TT_DNN_HOST_OP"); - uint32_t op_id = assign_id(); + uint32_t op_id = assign_operation_id(); operation.validate(input_tensors); auto output_tensors = operation.compute_output_tensors(input_tensors); @@ -143,7 +143,7 @@ OutputTensors run_device_operation( const OptionalConstTensors& optional_input_tensors, const OptionalTensors& optional_output_tensors) { ZoneScopedN("TT_DNN_DEVICE_OP"); - uint32_t op_id = assign_id(); + uint32_t op_id = assign_operation_id(); std::function, std::reference_wrapper>( const DeviceOperation&, diff --git a/tt_eager/tt_dnn/op_library/run_operation.hpp b/tt_eager/tt_dnn/op_library/run_operation.hpp index 94692be8b53..ec9c971ca8c 100644 --- a/tt_eager/tt_dnn/op_library/run_operation.hpp +++ b/tt_eager/tt_dnn/op_library/run_operation.hpp @@ -275,12 +275,6 @@ inline void log_operation( const OptionalTensors& optional_output_tensors = {}) {} #endif -inline uint32_t assign_id() -{ - static std::atomic atomic_count{0}; - return atomic_count.fetch_add(1); -} - template OutputTensors run( const HostOperation& operation, diff --git a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_pytensor.cpp b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_pytensor.cpp index af73c400051..bb26e0e062c 100644 --- a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_pytensor.cpp +++ b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_pytensor.cpp @@ -567,7 +567,7 @@ Tensor convert_python_tensors_to_tt_tensors(py::list tensor_shards, std::optiona ZoneScopedN("TT_DNN_FALLBACK_OP"); auto [op, input_tensors] = detail::parse_external_operation(function, args, kwargs, function_name); operation::log_operation(op, input_tensors); - uint32_t op_id = tt::tt_metal::operation::assign_id(); + uint32_t op_id = tt::tt_metal::assign_operation_id(); auto output_tensors = function(*args, **kwargs); diff --git a/tt_metal/impl/device/program_cache.hpp b/tt_metal/impl/device/program_cache.hpp index 5c34f97c33e..a35b1a9fd65 100644 --- a/tt_metal/impl/device/program_cache.hpp +++ b/tt_metal/impl/device/program_cache.hpp @@ -51,7 +51,7 @@ struct ProgramCache { private: inline static bool is_enabled_ = false; - std::unordered_map> cache_{}; + std::unordered_map> cache_{}; }; } diff --git a/tt_metal/tools/profiler/op_profiler.hpp b/tt_metal/tools/profiler/op_profiler.hpp index 778948e271e..5d41032f266 100644 --- a/tt_metal/tools/profiler/op_profiler.hpp +++ b/tt_metal/tools/profiler/op_profiler.hpp @@ -22,6 +22,11 @@ namespace tt { namespace tt_metal { +static uint32_t assign_operation_id() { + static std::atomic atomic_count{0}; + return atomic_count.fetch_add(1); +} + namespace op_profiler { enum class OpType { python_fallback, tt_dnn_cpu, tt_dnn_device, unknown }; @@ -251,6 +256,23 @@ inline json get_base_json( return j; } +inline json get_base_json(uint32_t opID, const auto& op) { + ZoneScoped; + json j; + j["global_call_count"] = opID; + + std::string opName = "device operation"; + + std::replace(opName.begin(), opName.end(), ',', ';'); + j["op_code"] = opName; + + json attributesObj; + j["attributes"] = attributesObj; + j["input_tensors"] = get_tensors_json(std::vector{}); + j["output_tensors"] = 
get_tensors_json(std::vector{}); + return j; +} + inline std::string op_meta_data_serialized_json( uint32_t opID, const tt::tt_metal::operation::ExternalOperation& op, const std::vector& input_tensors) { auto j = get_base_json(opID, op, input_tensors); @@ -321,6 +343,12 @@ inline std::string op_meta_data_serialized_json( return fmt::format("{}{}`", cached_ops.at(device_id).at(opHash), opID); } } +inline std::string op_meta_data_serialized_json(uint32_t opID, const auto& op) { + auto j = get_base_json(opID, op); + j["op_type"] = magic_enum::enum_name(OpType::tt_dnn_device); + std::string ser = j.dump(4); + return fmt::format("`Device Operation:{} ->\n{}`", j["op_code"], ser); +} #define TracyOpTTNNDevice( \ op_id, op_hash, is_cached, device_id, operation, program, input_tensors, optional_input_tensors, output_tensors) \ @@ -338,6 +366,12 @@ inline std::string op_meta_data_serialized_json( ZoneText(op_text.c_str(), op_text.size()); \ TracyMessage(op_message.c_str(), op_message.size()); +#define TracyOpTNNNDeviceV2(op_id, op) \ + std::string op_message = op_profiler::op_meta_data_serialized_json(op_id, op); \ + std::string op_text = fmt::format("id:{}", op_id); \ + ZoneText(op_text.c_str(), op_text.size()); \ + TracyMessage(op_message.c_str(), op_message.size()); + #define TracyOpTTNNHost(op_id, operation, input_tensors, output_tensors) \ std::string op_message = \ op_profiler::op_meta_data_serialized_json(op_id, operation, input_tensors, output_tensors); \ @@ -345,16 +379,17 @@ inline std::string op_meta_data_serialized_json( ZoneText(op_text.c_str(), op_text.size()); \ TracyMessage(op_message.c_str(), op_message.size()); -#define TracyOpTTNNExternal(op_id, op, input_tensors) \ - std::string op_message = op_profiler::op_meta_data_serialized_json(op_id, op, input_tensors); \ - std::string op_text = fmt::format("id:{}", op_id); \ - ZoneText(op_text.c_str(), op_text.size()); \ +#define TracyOpTTNNExternal(op_id, op, input_tensors) \ + std::string op_message = op_profiler::op_meta_data_serialized_json(op_id, op); \ + std::string op_text = fmt::format("id:{}", op_id); \ + ZoneText(op_text.c_str(), op_text.size()); \ TracyMessage(op_message.c_str(), op_message.size()); #else #define TracyOpTTNNDevice( \ op_id, op_hash, is_cached, device_id, operation, program, input_tensors, optional_input_tensors, output_tensors) +#define TracyOpTNNNDeviceV2(op_id, op) #define TracyOpTTNNHost(op_id, operation, input_tensors, output_tensors) #define TracyOpTTNNExternal(op_id, op, input_tensors) diff --git a/ttnn/cpp/ttnn/decorators.hpp b/ttnn/cpp/ttnn/decorators.hpp index 3ebccd3ad41..513ed01dabf 100644 --- a/ttnn/cpp/ttnn/decorators.hpp +++ b/ttnn/cpp/ttnn/decorators.hpp @@ -213,7 +213,7 @@ struct operation_t { template auto operator()(args_t&&... args) const { - ZoneScoped; + ZoneScopedN("Run ttnn operation (struct-based)"); ZoneName(this->cpp_fully_qualified_name, std::strlen(this->cpp_fully_qualified_name)); tt::log_debug(tt::LogOp, "Started C++ ttnn operation: {}", this->cpp_fully_qualified_name); @@ -324,7 +324,7 @@ struct lambda_operation_t { template auto operator()(args_t&&... 
args) const { - ZoneScoped; + ZoneScopedN("Run ttnn operation (lambda-based)"); ZoneName(this->cpp_fully_qualified_name, std::strlen(this->cpp_fully_qualified_name)); tt::log_debug(tt::LogOp, "Started C++ ttnn operation: {}", this->cpp_fully_qualified_name); auto output = this->lambda(std::forward(args)...); diff --git a/ttnn/cpp/ttnn/device_operation.hpp b/ttnn/cpp/ttnn/device_operation.hpp new file mode 100644 index 00000000000..1f55e3e8649 --- /dev/null +++ b/ttnn/cpp/ttnn/device_operation.hpp @@ -0,0 +1,311 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include + +#include "third_party/magic_enum/magic_enum.hpp" +#include "tt_dnn/op_library/operation_history.hpp" +#include "tt_metal/third_party/tracy/public/tracy/Tracy.hpp" +#include "tt_metal/tools/profiler/op_profiler.hpp" +#include "tt_stl/concepts.hpp" +#include "tt_stl/reflection.hpp" +#include "tt_stl/unique_any.hpp" + +namespace ttnn { + +namespace device_operation { + +template +struct CachedProgram { + tt::tt_metal::Program program; + // Cached program needs to share program_attributes between create and override_runtime_arguments functions + std::tuple program_attributes; + + CachedProgram(tt::tt_metal::Program&& program, program_attributes_t... program_attributes) : + program{std::move(program)}, program_attributes{std::tuple{program_attributes...}} {} +}; + +struct CachedProgramFactory { + tt::stl::unique_any<1024, 32> cached_program; + std::size_t program_factory_index; + + template + CachedProgramFactory(CachedProgram&& cached_program, std::size_t program_factory_index) : + cached_program{std::move(cached_program)}, program_factory_index{program_factory_index} {} +}; + +template +concept ProgramFactoryConcept = requires { + [](const auto& operation_attributes, const auto& tensor_args, auto& tensor_return_value) { + auto cached_program = program_factory_t::create(operation_attributes, tensor_args, tensor_return_value); + program_factory_t::override_runtime_arguments( + cached_program, operation_attributes, tensor_args, tensor_return_value); + }; +}; + +template +concept DeviceOperationConcept = requires { + [](const typename operation_t::operation_attributes_t& operation_attributes, + const typename operation_t::tensor_args_t& tensor_args) { + operation_t::validate_on_program_cache_hit(operation_attributes, tensor_args); + operation_t::validate_on_program_cache_miss(operation_attributes, tensor_args); + + using shape_return_value_t = typename operation_t::shape_return_value_t; + static_assert(std::same_as< + decltype(operation_t::compute_output_shapes(operation_attributes, tensor_args)), + shape_return_value_t>); + + using tensor_return_value_t = typename operation_t::tensor_return_value_t; + static_assert(std::same_as< + decltype(operation_t::create_output_tensors(operation_attributes, tensor_args)), + tensor_return_value_t>); + + const auto program_factory = operation_t::select_program_factory(operation_attributes, tensor_args); + std::visit( + [](auto&& program_factory) { + using program_factory_t = std::decay_t; + static_assert(ProgramFactoryConcept); + }, + program_factory); + }; +}; + +template +concept DeviceOperationWithCustomProgramCacheConcept = DeviceOperationConcept and requires { + [](auto&& program_factory, + const typename operation_t::operation_attributes_t& operation_attributes, + const typename operation_t::tensor_args_t& tensor_args) { + operation_t::compute_program_hash(operation_attributes, tensor_args); + }; 
+}; + +template +[[nodiscard]] std::variant constexpr map_index_to_variant(std::size_t i, std::variant) { + assert(i < sizeof...(Ts)); + static constexpr std::variant table[] = { Ts{ }... }; + return table[i]; +} + +template + requires std::same_as, Tensor> +constexpr auto visit_tensor(auto callback, T&& value) { + callback(value); +} + +template +constexpr auto visit_tensor(auto callback, const std::optional& value) { + if (value.has_value()) { + const auto& tensor = value.value(); + visit_tensor(callback, tensor); + } +} + +template +constexpr auto visit_tensor(auto callback, const std::vector& value) { + for (auto& tensor : value) { + visit_tensor(callback, tensor); + } +} + +template +constexpr auto visit_tensor(auto callback, const std::array& value) { + for (auto& tensor : value) { + visit_tensor(callback, tensor); + } +} + +template +constexpr auto visit_tensor(auto callback, const std::tuple& value) { + constexpr auto num_attributes = sizeof...(Ts); + [&callback, &value](std::index_sequence) { + (visit_tensor(callback, std::get(value)), ...); + }(std::make_index_sequence{}); +} + +template + requires(not std::same_as, Tensor>) and requires { std::decay_t::attribute_names; } +constexpr auto visit_tensor(auto callback, T&& object) { + constexpr auto num_attributes = std::tuple_size_v::attribute_names)>; + visit_tensor(callback, object.attribute_values()); +} + +template + requires std::same_as, Tensor> +constexpr auto get_first_tensor(T&& value) { + return std::cref(value); +} + +template +constexpr auto get_first_tensor(const std::optional& value) { + if (value.has_value()) { + const auto& tensor = value.value(); + return get_first_tensor(tensor); + } +} + +template +constexpr auto get_first_tensor(const std::vector& value) { + for (auto& tensor : value) { + return get_first_tensor(tensor); + } +} + +template +constexpr auto get_first_tensor(const std::array& value) { + for (auto& tensor : value) { + return get_first_tensor(tensor); + } +} + +template +constexpr auto get_first_tensor(const std::tuple& value) { + constexpr auto num_attributes = sizeof...(Ts); + return get_first_tensor(std::get<0>(value)); +} + +template + requires requires { std::decay_t::attribute_names; } and (not std::same_as, Tensor>) +constexpr auto get_first_tensor(T&& object) { + constexpr auto num_attributes = std::tuple_size_v::attribute_names)>; + return get_first_tensor(object.attribute_values()); +} + +inline const auto USE_FAST_DISPATCH = std::getenv("TT_METAL_SLOW_DISPATCH_MODE") == nullptr; + +template +inline auto compute_program_hash( + const typename operation_t::operation_attributes_t& operation_attributes, + const typename operation_t::tensor_args_t& tensor_args) { + if constexpr (DeviceOperationWithCustomProgramCacheConcept) { + ZoneScopedN("Compute custom program hash"); + return operation_t::compute_program_hash(operation_attributes, tensor_args); + } else { + ZoneScopedN("Compute default program hash"); + return tt::stl::hash::hash_objects_with_default_seed( + tt::stl::hash::type_hash, operation_attributes, tensor_args); + } +} + +template +inline auto& create_or_get_program_from_cache( + auto& program_cache, + auto cache_hit, + auto program_hash, + const typename operation_t::operation_attributes_t& operation_attributes, + const typename operation_t::tensor_args_t& tensor_args, + typename operation_t::tensor_return_value_t& tensor_return_value) { + if (not cache_hit) { + ZoneScopedN("Program Cache Miss"); + auto program_factory = operation_t::select_program_factory(operation_attributes, 
tensor_args);
+
+        auto& program = std::visit(
+            [&program_cache,
+             &program_hash,
+             &operation_attributes,
+             &tensor_args,
+             &tensor_return_value,
+             program_factory_index = program_factory.index()](auto&& program_factory) -> auto& {
+                using program_factory_t = std::decay_t<decltype(program_factory)>;
+                using cached_program_t =
+                    decltype(program_factory_t::create(operation_attributes, tensor_args, tensor_return_value));
+                program_cache.insert(
+                    program_hash,
+                    CachedProgramFactory{
+                        program_factory_t::create(operation_attributes, tensor_args, tensor_return_value),
+                        program_factory_index});
+                auto& cached_program_factory = program_cache.template get<CachedProgramFactory>(program_hash);
+                auto& cached_program = cached_program_factory.cached_program.template get<cached_program_t>();
+                return cached_program.program;
+            },
+            program_factory);
+        return program;
+    } else {
+        ZoneScopedN("Program Cache Hit");
+        auto& cached_program_factory = program_cache.template get<CachedProgramFactory>(program_hash);
+        auto program_factory_index = cached_program_factory.program_factory_index;
+
+        using program_factory_variant_t =
+            decltype(operation_t::select_program_factory(operation_attributes, tensor_args));
+        auto program_factory = map_index_to_variant(program_factory_index, program_factory_variant_t{});
+
+        auto& program = std::visit(
+            [&cached_program_factory, &operation_attributes, &tensor_args, &tensor_return_value](
+                auto&& program_factory) -> auto& {
+                using program_factory_t = std::decay_t<decltype(program_factory)>;
+
+                using cached_program_t =
+                    decltype(program_factory_t::create(operation_attributes, tensor_args, tensor_return_value));
+                auto& cached_program = cached_program_factory.cached_program.template get<cached_program_t>();
+
+                program_factory_t::override_runtime_arguments(
+                    cached_program, operation_attributes, tensor_args, tensor_return_value);
+
+                return cached_program.program;
+            },
+            program_factory);
+        return program;
+    }
+}
+
+template <typename operation_t>
+    requires DeviceOperationConcept<operation_t>
+typename operation_t::tensor_return_value_t run(
+    uint8_t cq_id,
+    const typename operation_t::operation_attributes_t& operation_attributes,
+    const typename operation_t::tensor_args_t& tensor_args) {
+    ZoneScopedN("TT_DNN_DEVICE_OP");
+    uint32_t operation_id = assign_operation_id();
+
+    using tensor_return_value_t = typename operation_t::tensor_return_value_t;
+    static_assert(not std::same_as<tensor_return_value_t, void>, "Operation return type cannot be void");
+
+    auto device = get_first_tensor(tensor_args).get().device();
+    auto& program_cache = device->program_cache;
+
+    auto program_hash = compute_program_hash<operation_t>(operation_attributes, tensor_args);
+    auto cache_hit = program_cache.contains(program_hash);
+
+    if (cache_hit) {
+        operation_t::validate_on_program_cache_hit(operation_attributes, tensor_args);
+    } else {
+        operation_t::validate_on_program_cache_miss(operation_attributes, tensor_args);
+    }
+    auto tensor_return_value = operation_t::create_output_tensors(operation_attributes, tensor_args);
+
+    auto& program = create_or_get_program_from_cache<operation_t>(
+        program_cache, cache_hit, program_hash, operation_attributes, tensor_args, tensor_return_value);
+
+    if (USE_FAST_DISPATCH) {
+        ZoneScopedN("EnqueueProgram");
+        auto& queue = device->command_queue(cq_id);
+        // Program will temporarily own the input buffers. This is required, since with Async command
+        // queues, the input tensor can preemptively be deallocated on device, unless the program maintains
+        // explicit ownership. This invocation of the program will give up ownership once it is enqueued.
+ auto assign_global_buffer_to_program = [&program](auto&& tensor) { + AssignGlobalBufferToProgram(tensor.device_buffer(), program); + }; + visit_tensor(assign_global_buffer_to_program, tensor_args); + tt::tt_metal::EnqueueProgram(queue, program, false); + } else { + ZoneScopedN("LaunchProgram"); + ::detail::LaunchProgram(device, program); + } + + // Visit output tensors with the sole purpose of checking the return type to make sure that it only has Tensors + // TODO: come up with a better way of checking the return type + visit_tensor([](auto&& tensor) {}, tensor_return_value); + + // TODO: update this to work properly take program cache info, as well as tensors + TracyOpTNNNDeviceV2(operation_id, operation_attributes); + + return tensor_return_value; +} + +} // namespace device_operation + +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/binary.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary/binary.hpp index dd9c0078f99..a0cf1c9d938 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/binary.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/binary.hpp @@ -6,7 +6,7 @@ #pragma once #include "device/binary_op.hpp" - +#include "ttnn/device_operation.hpp" #include "ttnn/operations/data_movement.hpp" namespace ttnn { @@ -108,17 +108,11 @@ struct ExecuteBinary { dtype = optional_output_tensor.value().get_dtype(); } - auto output_tensors = operation::run(Binary{BinaryProgramConfig{binary_op_type, - in_place, - activations, - output_memory_config, - dtype}}, - {input_tensor_a, input_tensor_b}, - {}, - {optional_output_tensor}, - queue_id); - - return output_tensors.at(0); + return ttnn::device_operation::run( + queue_id, + Binary::operation_attributes_t{ + binary_op_type, in_place, activations, output_memory_config, dtype, std::nullopt}, + Binary::tensor_args_t{input_tensor_a, input_tensor_b, optional_output_tensor}); } template diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_op.cpp index cbec394fe99..38461c4d9e9 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_op.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_op.cpp @@ -3,18 +3,13 @@ // SPDX-License-Identifier: Apache-2.0 #include "binary_op.hpp" -#include "binary_program_factory.hpp" #include "third_party/magic_enum/magic_enum.hpp" - #include "tt_eager/tt_dnn/op_library/bcast/bcast_op.hpp" - #include "tt_metal/common/constants.hpp" #include "tt_metal/host_api.hpp" #include "tt_metal/tools/profiler/op_profiler.hpp" - - namespace ttnn::operations::binary { namespace utils { @@ -140,103 +135,38 @@ std::map get_defines( } // namespace utils -enum class BinaryProgramType { - ElementWiseMultiCore, - BroadcastWidthMultiCore, - BroadcastHeightMultiCore, - BroadcastHeightAndWidthMultiCore, -}; - -inline BinaryProgramType get_program_type(const Binary& operation, const std::vector& input_tensors) { - const auto& input_tensor_a = input_tensors.at(0); - const auto& input_tensor_b = input_tensors.at(1); +/* static */ Binary::program_factory_t Binary::select_program_factory( + const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { + const auto& input_shape_a = tensor_args.input_tensor_a.shape(); + const auto& input_shape_b = tensor_args.input_tensor_b.shape(); - const auto& input_shape_a = input_tensor_a.get_shape(); - const auto& input_shape_b = input_tensor_b.get_shape(); - - auto batch_size_0_a = input_shape_a.rank() >= 4 ? 
input_shape_a[-4] : 1; - auto batch_size_1_a = input_shape_a.rank() >= 3 ? input_shape_a[-3] : 1; auto height_a = input_shape_a[-2]; auto width_a = input_shape_a[-1]; - auto batch_size_0_b = input_shape_b.rank() >= 4 ? input_shape_b[-4] : 1; - auto batch_size_1_b = input_shape_b.rank() >= 3 ? input_shape_b[-3] : 1; auto height_b = input_shape_b[-2]; auto width_b = input_shape_b[-1]; - /* - fmt::print("input_shape_a: {}, input_shape_b: {}\n", input_shape_a, input_shape_b); - fmt::print( - "batch_size_0_a: {}, batch_size_1_a: {}, height_a: {}, width_a: {}\n", - batch_size_0_a, - batch_size_1_a, - height_a, - width_a); - fmt::print( - "batch_size_0_b: {}, batch_size_1_b: {}, height_b: {}, width_b: {}\n", - batch_size_0_b, - batch_size_1_b, - height_b, - width_b); - */ - - if (batch_size_0_a == batch_size_0_b and batch_size_1_a == batch_size_1_b and height_a == height_b and - width_a == width_b) { - return BinaryProgramType::ElementWiseMultiCore; + if (height_a == height_b and width_a == width_b) { + return ElementWiseMultiCore{}; } else if (height_b == 1 or width_b == 1) { - if (operation.dtype != input_tensor_a.get_dtype()) { - TT_THROW("ttnn::operations::binary::Binary: cannot change dtype when broadcasting"); - } if (height_b == 1 and width_b == 1) { - // fmt::print("BinaryProgramType::BroadcastHeightAndWidthMultiCore\n"); - return BinaryProgramType::BroadcastHeightAndWidthMultiCore; + return BroadcastHeightAndWidthMultiCore{}; } else if (height_b == 1) { - // fmt::print("BinaryProgramType::BroadcastHeightMultiCore\n"); - return BinaryProgramType::BroadcastHeightMultiCore; + return BroadcastHeightMultiCore{}; } else if (width_b == 1) { - // fmt::print("BinaryProgramType::BroadcastWidthMultiCore\n"); - return BinaryProgramType::BroadcastWidthMultiCore; + return BroadcastWidthMultiCore{}; } } TT_THROW("ttnn::operations::binary::Binary: unsupported broadcast"); } -void Binary::validate_with_output_tensors(const std::vector &input_tensors, const std::vector>& output_tensors) const { - auto program_type = get_program_type(*this, input_tensors); - - const auto& input_tensor_a = input_tensors.at(0); - const auto& input_tensor_b = input_tensors.at(1); - - const auto& input_shape_a = input_tensor_a.get_shape(); - const auto& input_shape_b = input_tensor_b.get_shape(); - - auto batch_size_0_a = input_shape_a.rank() >= 4 ? input_shape_a[-4] : 1; - auto batch_size_1_a = input_shape_a.rank() >= 3 ? input_shape_a[-3] : 1; - auto height_a = input_shape_a[-2]; - auto width_a = input_shape_a[-1]; - - auto batch_size_0_b = input_shape_b.rank() >= 4 ? input_shape_b[-4] : 1; - auto batch_size_1_b = input_shape_b.rank() >= 3 ? 
input_shape_b[-3] : 1; - auto height_b = input_shape_b[-2]; - auto width_b = input_shape_b[-1]; +/* static */ void Binary::validate_on_program_cache_miss( + const operation_attributes_t& attributes, const tensor_args_t& tensor_args) { + const auto& input_tensor_a = tensor_args.input_tensor_a; + const auto& input_tensor_b = tensor_args.input_tensor_b; + const auto& output_tensor = tensor_args.output_tensor; - // Input shape b must be the same as or broadcastable to input shape a - if (batch_size_0_a != batch_size_0_b) { - TT_ASSERT( - batch_size_0_a > batch_size_0_b and batch_size_0_b == 1, - "ttnn::operations::binary::Binary: batch size mismatch"); - } - if (batch_size_1_a != batch_size_1_b) { - TT_ASSERT( - batch_size_1_a > batch_size_1_b and batch_size_1_b == 1, - "ttnn::operations::binary::Binary: batch size mismatch"); - } - if (height_a != height_b) { - TT_ASSERT(height_a > height_b and height_b == 1, "ttnn::operations::binary::Binary: height mismatch"); - } - if (width_a != width_b) { - TT_ASSERT(width_a > width_b and width_b == 1, "ttnn::operations::binary::Binary: width mismatch"); - } + Binary::validate_on_program_cache_hit(attributes, tensor_args); TT_FATAL( input_tensor_a.device() == input_tensor_b.device(), @@ -244,10 +174,10 @@ void Binary::validate_with_output_tensors(const std::vector &input_tenso TT_FATAL( (input_tensor_a.get_layout() == Layout::TILE && input_tensor_b.get_layout() == Layout::TILE), "Inputs to eltwise binary must be tilized"); - if (this->in_place) { - TT_FATAL(input_tensor_a.memory_config().memory_layout == this->memory_config.memory_layout); - TT_FATAL(input_tensor_a.memory_config().buffer_type == this->memory_config.buffer_type); - TT_FATAL(input_tensor_a.get_dtype() == this->dtype); + if (attributes.in_place) { + TT_FATAL(input_tensor_a.memory_config().memory_layout == attributes.memory_config.memory_layout); + TT_FATAL(input_tensor_a.memory_config().buffer_type == attributes.memory_config.buffer_type); + TT_FATAL(input_tensor_a.get_dtype() == attributes.dtype); } if (input_tensor_a.memory_config().is_sharded()) { if (input_tensor_a.memory_config().memory_layout != TensorMemoryLayout::HEIGHT_SHARDED) { @@ -260,70 +190,127 @@ void Binary::validate_with_output_tensors(const std::vector &input_tenso TT_FATAL(input_tensor_a.memory_config().memory_layout == input_tensor_b.memory_config().memory_layout); TT_FATAL(input_tensor_a.shard_spec().value() == input_tensor_b.shard_spec().value()); } - if (this->memory_config.is_sharded()) { - TT_FATAL(input_tensor_a.memory_config().memory_layout == this->memory_config.memory_layout); + if (attributes.memory_config.is_sharded()) { + TT_FATAL(input_tensor_a.memory_config().memory_layout == attributes.memory_config.memory_layout); } else { - TT_FATAL(this->memory_config.memory_layout == TensorMemoryLayout::INTERLEAVED); + TT_FATAL(attributes.memory_config.memory_layout == TensorMemoryLayout::INTERLEAVED); } } else if (input_tensor_b.memory_config().is_sharded()) { TT_FATAL(input_tensor_b.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED); TT_FATAL(input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED); - if (this->memory_config.is_sharded()) { - TT_FATAL(input_tensor_b.memory_config().memory_layout == this->memory_config.memory_layout); + if (attributes.memory_config.is_sharded()) { + TT_FATAL(input_tensor_b.memory_config().memory_layout == attributes.memory_config.memory_layout); } else { - TT_FATAL(this->memory_config.memory_layout == TensorMemoryLayout::INTERLEAVED); + 
TT_FATAL(attributes.memory_config.memory_layout == TensorMemoryLayout::INTERLEAVED); } } else { TT_FATAL(input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED); TT_FATAL(input_tensor_b.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED); - if (this->memory_config.is_sharded()) { - TT_FATAL(this->memory_config.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED); + if (attributes.memory_config.is_sharded()) { + TT_FATAL(attributes.memory_config.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED); uint32_t num_blocks = input_tensor_a.volume() / input_tensor_a.get_legacy_shape()[-1] / TILE_HEIGHT; auto core_grid = input_tensor_a.device()->compute_with_storage_grid_size(); uint32_t num_cores = core_grid.x * core_grid.y; TT_FATAL(num_blocks < num_cores or num_blocks % num_cores == 0); } else { - TT_FATAL(this->memory_config.memory_layout == TensorMemoryLayout::INTERLEAVED); + TT_FATAL(attributes.memory_config.memory_layout == TensorMemoryLayout::INTERLEAVED); } } - if (program_type != BinaryProgramType::ElementWiseMultiCore) { - TT_FATAL(not this->activations.has_value()); + auto program_factory = select_program_factory(attributes, tensor_args); + std::visit( + [&attributes](auto&& program_factory) { + if constexpr (std::is_same_v) { + TT_FATAL(not attributes.activations.has_value()); + } + }, + program_factory); + + if (output_tensor.has_value()) { + TT_FATAL( + not attributes.in_place, + "Operation is configured as in_place. First input is used as output. Provided output tensor is " + "ignored"); } +} +/* static */ void Binary::validate_on_program_cache_hit( + const operation_attributes_t& attributes, const tensor_args_t& tensor_args) { + const auto& input_tensor_a = tensor_args.input_tensor_a; + const auto& input_tensor_b = tensor_args.input_tensor_b; + const auto& output_tensor = tensor_args.output_tensor; - if (!output_tensors.empty()) { - TT_FATAL(output_tensors.size() == 1, "Must have 1 output tensors"); + const auto& input_shape_a = input_tensor_a.get_shape(); + const auto& input_shape_b = input_tensor_b.get_shape(); - if(output_tensors.at(0).has_value()) { - TT_FATAL(!this->in_place, "Operation is configured as in_place. First input is used as output. Provided output tensor is ignored"); - } + auto batch_size_0_a = input_shape_a.rank() >= 4 ? input_shape_a[-4] : 1; + auto batch_size_1_a = input_shape_a.rank() >= 3 ? input_shape_a[-3] : 1; + auto height_a = input_shape_a[-2]; + auto width_a = input_shape_a[-1]; + + auto batch_size_0_b = input_shape_b.rank() >= 4 ? input_shape_b[-4] : 1; + auto batch_size_1_b = input_shape_b.rank() >= 3 ? 
input_shape_b[-3] : 1; + auto height_b = input_shape_b[-2]; + auto width_b = input_shape_b[-1]; + + // Input shape b must be the same as or broadcastable to input shape a + if (batch_size_0_a != batch_size_0_b) { + TT_ASSERT( + batch_size_0_a > batch_size_0_b and batch_size_0_b == 1, + "ttnn::operations::binary::Binary: batch size mismatch"); + } + if (batch_size_1_a != batch_size_1_b) { + TT_ASSERT( + batch_size_1_a > batch_size_1_b and batch_size_1_b == 1, + "ttnn::operations::binary::Binary: batch size mismatch"); + } + if (height_a != height_b) { + TT_ASSERT(height_a > height_b and height_b == 1, "ttnn::operations::binary::Binary: height mismatch"); + } + if (width_a != width_b) { + TT_ASSERT(width_a > width_b and width_b == 1, "ttnn::operations::binary::Binary: width mismatch"); } } -std::vector Binary::compute_output_shapes(const std::vector& input_tensors) const { - const auto& input_tensor_a = input_tensors.at(0); - const auto& input_tensor_b = input_tensors.at(1); - if (input_tensor_a.get_shape().rank() >= input_tensor_b.get_shape().rank()) { - return {input_tensor_a.get_legacy_shape()}; +/* static */ Binary::shape_return_value_t Binary::compute_output_shapes( + const operation_attributes_t&, const tensor_args_t& tensor_args) { + const auto input_shape_a = tensor_args.input_tensor_a.shape(); + const auto input_shape_b = tensor_args.input_tensor_b.shape(); + + auto rank = std::max(input_shape_a.rank(), input_shape_b.rank()); + std::vector output_shape(rank, 0); + std::vector output_shape_with_tile_padding(rank, 0); + + for (int i = -1; i >= -rank; --i) { + auto dim_a = i + input_shape_a.rank() < input_shape_a.rank() ? input_shape_a[i] : 1; + auto dim_b = i + input_shape_b.rank() < input_shape_b.rank() ? input_shape_b[i] : 1; + output_shape[i + rank] = std::max(dim_a, dim_b); + + auto dim_a_with_tile_padding = + i + input_shape_a.rank() < input_shape_a.rank() ? input_shape_a.with_tile_padding()[i] : 1; + auto dim_b_with_tile_padding = + i + input_shape_b.rank() < input_shape_b.rank() ? 
input_shape_b.with_tile_padding()[i] : 1; + output_shape_with_tile_padding[i + rank] = std::max(dim_a_with_tile_padding, dim_b_with_tile_padding); } - return {input_tensor_b.get_legacy_shape()}; + return ttnn::Shape::from_vector(output_shape, output_shape_with_tile_padding); } -std::vector Binary::create_output_tensors(const std::vector& input_tensors, const std::vector>& output_tensors) const { - const auto& input_tensor_a = input_tensors.at(0); - const auto& input_tensor_b = input_tensors.at(1); - if (this->in_place) { +/* static */ Binary::tensor_return_value_t Binary::create_output_tensors( + const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { + auto output_shape = compute_output_shapes(operation_attributes, tensor_args); + const auto& input_tensor_a = tensor_args.input_tensor_a; + const auto& input_tensor_b = tensor_args.input_tensor_b; + const auto& output_tensor = tensor_args.output_tensor; + if (operation_attributes.in_place) { return {input_tensor_a}; } else { - if (!output_tensors.empty() && output_tensors.at(0).has_value()) { - return {output_tensors.at(0).value()}; + if (output_tensor.has_value()) { + return output_tensor.value(); } - auto program_type = get_program_type(*this, input_tensors); - - if (program_type == BinaryProgramType::ElementWiseMultiCore) { - if (this->memory_config.is_sharded()) { + auto program_factory = select_program_factory(operation_attributes, tensor_args); + if (std::holds_alternative(program_factory)) { + if (operation_attributes.memory_config.is_sharded()) { ShardSpec shard_spec{CoreRangeSet({}), {0, 0}}; if (input_tensor_a.memory_config().is_sharded()) { shard_spec = input_tensor_a.shard_spec().value(); @@ -339,90 +326,50 @@ std::vector Binary::create_output_tensors(const std::vector& inp num_blocks / target_num_cores * TILE_HEIGHT, input_tensor_a.get_legacy_shape()[-1]}; shard_spec.orientation = ShardOrientation::ROW_MAJOR; } - auto memory_config = this->memory_config; + auto memory_config = operation_attributes.memory_config; memory_config.shard_spec = shard_spec; - return {create_device_tensor( - this->compute_output_shapes(input_tensors).at(0), - this->dtype, + return create_device_tensor( + output_shape, + operation_attributes.dtype, Layout::TILE, input_tensor_a.device(), - memory_config)}; + operation_attributes.memory_config); } } else { - if (this->memory_config.is_sharded()) { + if (operation_attributes.memory_config.is_sharded()) { ShardSpec shard_spec{CoreRangeSet({}), {0, 0}}; if (input_tensor_a.memory_config().is_sharded()) { // Derive output shard_spec based on input shard_spec = input_tensor_a.shard_spec().value(); } - auto memory_config = this->memory_config; + auto memory_config = operation_attributes.memory_config; memory_config.shard_spec = shard_spec; - return {create_device_tensor( - this->compute_output_shapes(input_tensors).at(0), - this->dtype, + return create_device_tensor( + output_shape, + operation_attributes.dtype, Layout::TILE, input_tensor_a.device(), - memory_config)}; + operation_attributes.memory_config); } } - return operation::generic_create_output_tensors( - *this, input_tensors, this->dtype, Layout::TILE, this->memory_config); - } -} - -const std::optional binary_op_type_to_bcast_op_math(const BinaryOpType binary_op_type) { - switch (binary_op_type) { - case BinaryOpType::ADD: return tt::tt_metal::BcastOpMath::ADD; - case BinaryOpType::SUB: return tt::tt_metal::BcastOpMath::SUB; - case BinaryOpType::MUL: return tt::tt_metal::BcastOpMath::MUL; - default: return std::nullopt; + 
return create_device_tensor( + output_shape, + operation_attributes.dtype, + Layout::TILE, + input_tensor_a.device(), + operation_attributes.memory_config); } } -operation::ProgramWithCallbacks Binary::create_program( - const std::vector& input_tensors, std::vector& output_tensors) const { - const auto& input_tensor_a = input_tensors.at(0); - const auto& input_tensor_b = input_tensors.at(1); - const auto& output_tensor = output_tensors.at(0); - - std::vector activations; - if (this->program_config.activations.has_value()) { - activations = this->program_config.activations.value(); - } +/* static */ tt::stl::hash::hash_t Binary::compute_program_hash( + const operation_attributes_t& attributes, const tensor_args_t& tensor_args) { + const auto& input_tensor_a = tensor_args.input_tensor_a; + const auto& input_tensor_b = tensor_args.input_tensor_b; - auto program_type = get_program_type(*this, input_tensors); - auto bcast_op_math = binary_op_type_to_bcast_op_math(this->binary_op_type); - if (bcast_op_math.has_value()) { - switch (program_type) { - case BinaryProgramType::ElementWiseMultiCore: - return eltwise_binary_multi_core( - input_tensor_a, input_tensor_b, output_tensor, this->binary_op_type, activations); - case BinaryProgramType::BroadcastHeightAndWidthMultiCore: - return bcast_multi_core_hw( - input_tensor_a, input_tensor_b, output_tensor, bcast_op_math.value(), false /* in-place */); - case BinaryProgramType::BroadcastHeightMultiCore: - return bcast_multi_core_h(input_tensor_a, input_tensor_b, output_tensor, bcast_op_math.value()); - case BinaryProgramType::BroadcastWidthMultiCore: - return bcast_multi_core_w(input_tensor_a, input_tensor_b, output_tensor, bcast_op_math.value()); - default: TT_THROW("Invalid program type"); - } - } else { - switch (program_type) { - case BinaryProgramType::ElementWiseMultiCore: - return eltwise_binary_multi_core( - input_tensor_a, input_tensor_b, output_tensor, this->binary_op_type, activations); - default: TT_THROW("Invalid program type"); - } - } -} - -const operation::Hash Binary::compute_program_hash(const std::vector& input_tensors) const { - const auto& input_tensor_a = input_tensors.at(0); - const auto& input_tensor_b = input_tensors.at(1); - auto program_type = get_program_type(*this, input_tensors); + auto program_factory = select_program_factory(attributes, tensor_args); operation::Hash hash = operation::hash_operation( - this->program_config, - program_type, + attributes, + program_factory.index(), input_tensor_a.dtype(), std::get(input_tensor_a.storage()).memory_config(), input_tensor_b.dtype(), @@ -430,22 +377,24 @@ const operation::Hash Binary::compute_program_hash(const std::vector& in return hash; } -operation::OpPerformanceModel Binary::create_op_performance_model( - const std::vector& input_tensors, - const std::vector>& optional_input_tensors, - std::vector& output_tensors) const { +/* static */ operation::OpPerformanceModel Binary::create_op_performance_model( + const operation_attributes_t& attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& tensor_return_value) { + const auto& input_tensor_a = tensor_args.input_tensor_a; + const auto& input_tensor_b = tensor_args.input_tensor_b; // GS specific parameters // 80 B/cycle unpacker BW shared // 128 datums per cycle math, but unpacker cant keep up constexpr int num_cores = 9 * 12; int total_bytes = 0; - for (const auto& t : input_tensors) { - total_bytes += t.volume() * t.element_size(); - } + total_bytes += input_tensor_a.volume() * input_tensor_a.element_size(); + 
total_bytes += input_tensor_b.volume() * input_tensor_b.element_size(); int ideal_eltwise_cycles = total_bytes / 80 / num_cores; - operation::OpPerformanceModel result(input_tensors, output_tensors, ideal_eltwise_cycles); + // TODO: update OpPerformanceModel to work on variadic arguments + operation::OpPerformanceModel result({}, {}, ideal_eltwise_cycles); #if 0 tt::log_info(tt::LogOp, "Binary PerfModel:"); tt::log_info(tt::LogOp, "\t Data (Bytes): {}", total_bytes); diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_op.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_op.hpp index 82c7186d711..164d5ed8312 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_op.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_op.hpp @@ -6,85 +6,128 @@ #include #include +#include +#include "binary_op_type.hpp" +#include "broadcast_height_and_width_multi_core_program_factory.hpp" +#include "broadcast_height_multi_core_program_factory.hpp" +#include "broadcast_width_multi_core_program_factory.hpp" +#include "element_wise_multi_core_program_factory.hpp" #include "tensor/tensor.hpp" #include "third_party/magic_enum/magic_enum.hpp" #include "tt_eager/tensor/host_buffer/functions.hpp" #include "tt_eager/tensor/tensor_utils.hpp" #include "tt_eager/tt_dnn/op_library/compute_kernel_config.hpp" - -#include "tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp" #include "tt_eager/tt_dnn/op_library/run_operation.hpp" #include "tt_metal/host_api.hpp" #include "tt_metal/impl/dispatch/command_queue.hpp" #include "ttnn/core.hpp" #include "ttnn/decorators.hpp" +#include "ttnn/device_operation.hpp" #include "ttnn/types.hpp" namespace ttnn::operations::binary { -enum class BinaryOpType { - ADD, - SUB, - MUL, - GT, - LT, - LTE, - GTE, - EQ, - NE, - SQUARED_DIFFERENCE, - BIAS_GELU, - LOGADDEXP, - LOGICAL_AND, - LOGICAL_OR, - LDEXP, - LOGADDEXP2, - DIV_FAST -}; - -using FusedActivations = std::vector; -namespace utils { - -std::map get_defines(BinaryOpType op_type, const std::optional in_dtype = std::nullopt, const std::optional out_dtype = std::nullopt, - const std::optional fused_activations = std::nullopt); - -} // namespace utils - constexpr uint8_t DefaultQueueId = 0; struct Binary { - BinaryOpType binary_op_type; - bool in_place; - const std::optional> activations; - const MemoryConfig memory_config; - const DataType dtype; - std::optional compute_kernel_config; - - void validate_with_output_tensors(const std::vector &input_tensors, const std::vector>& output_tensors) const; - std::vector compute_output_shapes(const std::vector &input_tensors) const; - std::vector create_output_tensors(const std::vector &input_tensors, const std::vector>& output_tensors) const; - operation::ProgramWithCallbacks create_program( - const std::vector &input_tensors, std::vector &output_tensors) const; - - const operation::Hash compute_program_hash(const std::vector &input_tensors) const; - - operation::OpPerformanceModel create_op_performance_model( - const std::vector &input_tensors, - const std::vector> &optional_input_tensors, - std::vector &output_tensors) const; - - static constexpr auto attribute_names = std::forward_as_tuple( - "binary_op_type", "in_place", "activations", "memory_config", "dtype", "compute_kernel_config"); - const auto attribute_values() const { - return std::forward_as_tuple( - this->binary_op_type, - this->in_place, - this->activations, - this->memory_config, - this->dtype, - this->compute_kernel_config); - } + struct operation_attributes_t { + 
BinaryOpType binary_op_type; + bool in_place; + const std::optional activations; + const MemoryConfig memory_config; + const DataType dtype; + std::optional compute_kernel_config; + + static constexpr auto attribute_names = std::forward_as_tuple( + "binary_op_type", "in_place", "activations", "memory_config", "dtype", "compute_kernel_config"); + const auto attribute_values() const { + return std::forward_as_tuple( + this->binary_op_type, + this->in_place, + this->activations, + this->memory_config, + this->dtype, + this->compute_kernel_config); + } + }; + struct tensor_args_t { + const Tensor& input_tensor_a; + const Tensor& input_tensor_b; + std::optional output_tensor; + + static constexpr auto attribute_names = + std::forward_as_tuple("input_tensor_a", "input_tensor_b", "output_tensor"); + const auto attribute_values() const { + return std::forward_as_tuple(this->input_tensor_a, this->input_tensor_b, this->output_tensor); + } + }; + using shape_return_value_t = ttnn::Shape; + using tensor_return_value_t = Tensor; + + struct ElementWiseMultiCore { + static auto create(auto&&... args) { + return element_wise_multi_core_program_factory::create(std::forward(args)...); + } + static void override_runtime_arguments(auto&&... args) { + element_wise_multi_core_program_factory::override_runtime_arguments(std::forward(args)...); + } + }; + + struct BroadcastWidthMultiCore { + static auto create(auto&&... args) { + return broadcast_width_multi_core_program_factory::create(std::forward(args)...); + } + static void override_runtime_arguments(auto&&... args) { + broadcast_width_multi_core_program_factory::override_runtime_arguments( + std::forward(args)...); + } + }; + + struct BroadcastHeightMultiCore { + static auto create(auto&&... args) { + return broadcast_height_multi_core_program_factory::create(std::forward(args)...); + } + static void override_runtime_arguments(auto&&... args) { + broadcast_height_multi_core_program_factory::override_runtime_arguments( + std::forward(args)...); + } + }; + + struct BroadcastHeightAndWidthMultiCore { + static auto create(auto&&... args) { + return broadcast_height_and_width_multi_core_program_factory::create(std::forward(args)...); + } + static void override_runtime_arguments(auto&&... 
args) { + broadcast_height_and_width_multi_core_program_factory::override_runtime_arguments( + std::forward(args)...); + } + }; + + using program_factory_t = std::variant< + ElementWiseMultiCore, + BroadcastWidthMultiCore, + BroadcastHeightMultiCore, + BroadcastHeightAndWidthMultiCore>; + + static program_factory_t select_program_factory(const operation_attributes_t&, const tensor_args_t&); + + static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&); + static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&); + + static shape_return_value_t compute_output_shapes( + const operation_attributes_t&, const tensor_args_t&); + + static tensor_return_value_t create_output_tensors( + const operation_attributes_t& operation_attributes, const tensor_args_t&); + + static tt::stl::hash::hash_t compute_program_hash( + const operation_attributes_t&, const tensor_args_t&); + + static operation::OpPerformanceModel create_op_performance_model( + const operation_attributes_t& attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& tensor_return_value); }; } // namespace ttnn::operations::binary diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_op_type.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_op_type.hpp new file mode 100644 index 00000000000..232808250b8 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_op_type.hpp @@ -0,0 +1,49 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include +#include + +#include "tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp" +#include "ttnn/types.hpp" + +namespace ttnn::operations::binary { + +enum class BinaryOpType { + ADD, + SUB, + MUL, + GT, + LT, + LTE, + GTE, + EQ, + NE, + SQUARED_DIFFERENCE, + BIAS_GELU, + LOGADDEXP, + LOGICAL_AND, + LOGICAL_OR, + LDEXP, + LOGADDEXP2, + DIV_FAST +}; + +using FusedActivations = std::vector; + +namespace utils { + +std::map get_defines( + BinaryOpType op_type, + const std::optional in_dtype = std::nullopt, + const std::optional out_dtype = std::nullopt, + const std::optional fused_activations = std::nullopt); + +} // namespace utils + +} // namespace ttnn::operations::binary diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_and_width_multi_core_program_factory.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_and_width_multi_core_program_factory.hpp new file mode 100644 index 00000000000..dbb5272b6a4 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_and_width_multi_core_program_factory.hpp @@ -0,0 +1,352 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include "impl/buffers/buffer.hpp" +#include "tensor/tensor.hpp" +#include "tt_dnn/op_library/bcast/bcast_op.hpp" +#include "tt_dnn/op_library/work_split.hpp" +#include "tt_metal/common/constants.hpp" +#include "tt_metal/detail/util.hpp" +#include "tt_metal/host_api.hpp" +#include "ttnn/device_operation.hpp" +#include "binary_op_type.hpp" + +namespace ttnn::operations::binary::broadcast_height_and_width_multi_core_program_factory { + +static const tt::tt_metal::BcastOpMath binary_op_type_to_bcast_op_math(const BinaryOpType binary_op_type) { + switch (binary_op_type) { + case BinaryOpType::ADD: return tt::tt_metal::BcastOpMath::ADD; + case BinaryOpType::SUB: return tt::tt_metal::BcastOpMath::SUB; + case BinaryOpType::MUL: return tt::tt_metal::BcastOpMath::MUL; + default: TT_THROW("BinaryOpType cannot be mapped to BcastOpMath"); + } +} + +inline auto create(const auto& operation_attributes, const auto& tensor_args, auto& tensor_return) { + using namespace tt; + using namespace tt::tt_metal; + + const auto& a = tensor_args.input_tensor_a; + const auto& b = tensor_args.input_tensor_b; + auto& output = tensor_return; + auto bcast_math = binary_op_type_to_bcast_op_math(operation_attributes.binary_op_type); + const auto ashape = a.get_legacy_shape(); + const auto bshape = b.get_legacy_shape(); + uint32_t N = ashape.rank() >= 4 ? ashape[-4] : 1; + uint32_t C = ashape.rank() >= 3 ? ashape[-3] : 1; + uint32_t H = ashape[-2]; + uint32_t W = ashape[-1]; + uint32_t bN = bshape.rank() >= 4 ? bshape[-4] : 1; + uint32_t bC = bshape.rank() >= 3 ? bshape[-3] : 1; + uint32_t bH = bshape[-2]; + uint32_t bW = bshape[-1]; + uint32_t NC = N * C; + uint32_t HW = H * W; + + uint32_t Wt = W / TILE_WIDTH; + uint32_t Ht = H / TILE_HEIGHT; + uint32_t HtWt = Ht * Wt; + + uint32_t num_tensor_tiles = NC * Ht * Wt; + + uint32_t bnc1 = (bN * bC == 1); + + tt_metal::Program program = tt_metal::CreateProgram(); + + tt_metal::Device* device = a.device(); + + std::optional shard_spec = std::nullopt; + bool src0_sharded = a.memory_config().is_sharded(); + bool output_sharded = output.memory_config().is_sharded(); + if (src0_sharded) { + shard_spec = a.shard_spec().value(); + } else if (output_sharded) { + shard_spec = output.shard_spec().value(); + } + + tt::DataFormat src0_cb_data_format = tt_metal::datatype_to_dataformat_converter(a.get_dtype()); + tt::DataFormat src1_cb_data_format = tt_metal::datatype_to_dataformat_converter(b.get_dtype()); + tt::DataFormat dst_cb_data_format = tt_metal::datatype_to_dataformat_converter(output.get_dtype()); + + uint32_t src0_single_tile_size = tt_metal::detail::TileSize(src0_cb_data_format); + uint32_t src1_single_tile_size = tt_metal::detail::TileSize(src1_cb_data_format); + uint32_t dst_single_tile_size = tt_metal::detail::TileSize(dst_cb_data_format); + + auto compute_with_storage_grid_size = device->compute_with_storage_grid_size(); + uint32_t num_cores_x = compute_with_storage_grid_size.x; + uint32_t num_cores_y = compute_with_storage_grid_size.y; + auto all_device_cores = CoreRange({0, 0}, {num_cores_x - 1, num_cores_y - 1}); + + auto [num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2] = + split_work_to_cores(compute_with_storage_grid_size, num_tensor_tiles); + + auto src0_buffer = a.buffer(); + auto src1_buffer = b.buffer(); + auto dst_buffer = output.buffer(); + TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!"); + + uint32_t src0_cb_index = 
0; + uint32_t num_input_tiles = 2; + uint32_t num_tiles_per_shard = 0; + if (shard_spec.has_value()) { + num_tiles_per_shard = shard_spec.value().shape[0] * shard_spec.value().shape[1] / TILE_HW; + num_tiles_per_core_group_1 = num_tiles_per_shard; + num_tiles_per_core_group_2 = 0; + all_cores = shard_spec.value().grid; + core_group_1 = all_cores; + core_group_2 = CoreRangeSet({}); + } + + uint32_t num_input_tiles_cb0 = src0_sharded ? num_tiles_per_shard : num_input_tiles; + + tt_metal::CircularBufferConfig src0_cb_config = + tt_metal::CircularBufferConfig( + num_input_tiles_cb0 * src0_single_tile_size, {{src0_cb_index, src0_cb_data_format}}) + .set_page_size(src0_cb_index, src0_single_tile_size); + if (src0_sharded) { + src0_cb_config = src0_cb_config.set_globally_allocated_address(*a.buffer()); + } + auto cb_src0 = tt_metal::CreateCircularBuffer(program, all_device_cores, src0_cb_config); + + uint32_t src1_cb_index = 1; + tt_metal::CircularBufferConfig src1_cb_config = + tt_metal::CircularBufferConfig(num_input_tiles * src1_single_tile_size, {{src1_cb_index, src1_cb_data_format}}) + .set_page_size(src1_cb_index, src1_single_tile_size); + auto cb_src1 = tt_metal::CreateCircularBuffer(program, all_device_cores, src1_cb_config); + + uint32_t output_cb_index = 16; // output operands start at index 16 + uint32_t num_output_tiles = output_sharded ? num_tiles_per_shard : 2; + tt_metal::CircularBufferConfig output_cb_config = + tt_metal::CircularBufferConfig(num_output_tiles * dst_single_tile_size, {{output_cb_index, dst_cb_data_format}}) + .set_page_size(output_cb_index, dst_single_tile_size); + if (output_sharded) { + output_cb_config = output_cb_config.set_globally_allocated_address(*output.buffer()); + } + auto cb_output = tt_metal::CreateCircularBuffer(program, all_device_cores, output_cb_config); + + bool src0_is_dram = src0_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; + bool src1_is_dram = src1_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; + std::vector reader_compile_time_args = {(uint32_t)src0_is_dram, (uint32_t)src1_is_dram}; + + bool dst_is_dram = dst_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 
1 : 0; + std::vector writer_compile_time_args = {(std::uint32_t)output_cb_index, (std::uint32_t)dst_is_dram}; + + std::map reader_defines; + std::map bcast_compute_defines = bcast_op_utils::get_defines(BcastOpDim::HW, bcast_math); + if (bnc1) { + reader_defines["BCAST_SCALAR"] = "1"; + bcast_compute_defines["BCAST_SCALAR"] = "1"; + } + if (src0_sharded) { + reader_defines["IN0_SHARDED"] = "1"; + } + KernelHandle binary_reader_kernel_id = tt_metal::CreateKernel( + program, + "tt_eager/tt_dnn/op_library/bcast/kernels/dataflow/reader_bcast_hw_interleaved_partitioned.cpp", + all_device_cores, + tt_metal::ReaderDataMovementConfig(reader_compile_time_args, reader_defines)); + + std::map writer_defines; + if (output_sharded) { + writer_defines["OUT_SHARDED"] = "1"; + } + KernelHandle unary_writer_kernel_id = tt_metal::CreateKernel( + program, + "tt_eager/tt_dnn/kernels/dataflow/writer_unary_interleaved_start_id.cpp", + all_device_cores, + tt_metal::WriterDataMovementConfig(writer_compile_time_args, writer_defines)); + + auto bcast_kernel_id = tt_metal::CreateKernel( + program, + "tt_eager/tt_dnn/op_library/bcast/kernels/compute/bcast_hw.cpp", + all_device_cores, + tt_metal::ComputeConfig{.compile_args = {}, .defines = bcast_compute_defines}); + + for (uint32_t i = 0, num_tiles_read = 0; i < num_cores_y * num_cores_x; i++) { + CoreCoord core = {i / num_cores_y, i % num_cores_y}; + uint32_t num_tensor_tiles_per_core; + if (core_group_1.core_coord_in_core_ranges(core)) { + num_tensor_tiles_per_core = num_tiles_per_core_group_1; + } else if (core_group_2.core_coord_in_core_ranges(core)) { + num_tensor_tiles_per_core = num_tiles_per_core_group_2; + } else { + tt_metal::SetRuntimeArgs(program, binary_reader_kernel_id, core, std::vector(7, 0)); + tt_metal::SetRuntimeArgs(program, bcast_kernel_id, core, std::vector(3, 0)); + tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, core, std::vector(3, 0)); + continue; + } + + tt_metal::SetRuntimeArgs( + program, + binary_reader_kernel_id, + core, + {a.buffer()->address(), // 0 + b.buffer()->address(), + num_tensor_tiles_per_core, + HtWt, + num_tiles_read / HtWt * HtWt, + num_tiles_read % HtWt, + bnc1 ? 
0 : num_tiles_read / HtWt}); + + tt_metal::SetRuntimeArgs( + program, + bcast_kernel_id, + core, + { + 1, // B + 1, // Ht + num_tensor_tiles_per_core // Wt + }); + + tt_metal::SetRuntimeArgs( + program, + unary_writer_kernel_id, + core, + { + output.buffer()->address(), + num_tensor_tiles_per_core, + num_tiles_read, + }); + num_tiles_read += num_tensor_tiles_per_core; + } + + return device_operation::CachedProgram{ + std::move(program), + binary_reader_kernel_id, + unary_writer_kernel_id, + bcast_kernel_id, + compute_with_storage_grid_size, + cb_src0, + src0_single_tile_size, + src1_single_tile_size, + dst_single_tile_size, + cb_output}; +} + +inline void override_runtime_arguments( + auto& cached_program, const auto& operation_attributes, const auto& tensor_args, auto& tensor_return) { + using namespace tt; + using namespace tt::tt_metal; + + const auto& input_tensor_a = tensor_args.input_tensor_a; + const auto& input_tensor_b = tensor_args.input_tensor_b; + auto& output_tensor = tensor_return; + + auto&& [binary_reader_kernel_id, unary_writer_kernel_id, bcast_kernel_id, compute_with_storage_grid_size, cb_src0, src0_single_tile_size, src1_single_tile_size, dst_single_tile_size, cb_output] = + cached_program.program_attributes; + + auto& program = cached_program.program; + uint32_t num_cores_x = compute_with_storage_grid_size.x; + uint32_t num_cores_y = compute_with_storage_grid_size.y; + + auto src_buffer_a = input_tensor_a.buffer(); + auto src_dram_buffer_b = input_tensor_b.buffer(); + std::optional shard_spec = std::nullopt; + bool src0_sharded = input_tensor_a.memory_config().is_sharded(); + bool out_sharded = output_tensor.memory_config().is_sharded(); + + if (src0_sharded) { + shard_spec = input_tensor_a.shard_spec().value(); + } else if (out_sharded) { + shard_spec = output_tensor.shard_spec().value(); + } + + auto dst_buffer = output_tensor.buffer(); + + const auto ashape = input_tensor_a.get_legacy_shape(); + const auto bshape = input_tensor_b.get_legacy_shape(); + uint32_t N = ashape.rank() >= 4 ? ashape[-4] : 1; + uint32_t C = ashape.rank() >= 3 ? ashape[-3] : 1; + uint32_t H = ashape[-2]; + uint32_t W = ashape[-1]; + uint32_t bN = bshape.rank() >= 4 ? bshape[-4] : 1; + uint32_t bC = bshape.rank() >= 3 ? 
bshape[-3] : 1; + uint32_t bH = bshape[-2]; + uint32_t bW = bshape[-1]; + uint32_t NC = N * C; + uint32_t HW = H * W; + + uint32_t Wt = W / TILE_WIDTH; + uint32_t Ht = H / TILE_HEIGHT; + uint32_t HtWt = Ht * Wt; + + uint32_t num_tensor_tiles = NC * Ht * Wt; + + uint32_t bnc1 = (bN * bC == 1); + + auto [num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2] = + split_work_to_cores(compute_with_storage_grid_size, num_tensor_tiles); + + if (shard_spec.has_value()) { + uint32_t num_tiles_per_shard = 0; + num_tiles_per_shard = shard_spec.value().shape[0] * shard_spec.value().shape[1] / TILE_HW; + num_tiles_per_core_group_1 = num_tiles_per_shard; + num_tiles_per_core_group_2 = 0; + all_cores = shard_spec.value().grid; + core_group_1 = all_cores; + core_group_2 = CoreRangeSet({}); + } + + for (uint32_t i = 0, num_tiles_read = 0; i < num_cores_y * num_cores_x; i++) { + CoreCoord core = {i / num_cores_y, i % num_cores_y}; + uint32_t num_tensor_tiles_per_core; + if (core_group_1.core_coord_in_core_ranges(core)) { + num_tensor_tiles_per_core = num_tiles_per_core_group_1; + } else if (core_group_2.core_coord_in_core_ranges(core)) { + num_tensor_tiles_per_core = num_tiles_per_core_group_2; + } else { + tt_metal::SetRuntimeArgs(program, binary_reader_kernel_id, core, std::vector(7, 0)); + tt_metal::SetRuntimeArgs(program, bcast_kernel_id, core, std::vector(3, 0)); + tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, core, std::vector(3, 0)); + continue; + } + + tt_metal::SetRuntimeArgs( + program, + binary_reader_kernel_id, + core, + {src_buffer_a->address(), // 0 + src_dram_buffer_b->address(), + num_tensor_tiles_per_core, + HtWt, + num_tiles_read / HtWt * HtWt, + num_tiles_read % HtWt, + bnc1 ? 0 : num_tiles_read / HtWt}); + + tt_metal::SetRuntimeArgs( + program, + bcast_kernel_id, + core, + { + 1, // B + 1, // Ht + num_tensor_tiles_per_core // Wt + }); + + tt_metal::SetRuntimeArgs( + program, + unary_writer_kernel_id, + core, + { + dst_buffer->address(), + num_tensor_tiles_per_core, + num_tiles_read, + }); + num_tiles_read += num_tensor_tiles_per_core; + } + + if (src0_sharded) { + UpdateDynamicCircularBufferAddress(program, cb_src0, *src_buffer_a); + UpdateCircularBufferTotalSize(program, cb_src0, num_tiles_per_core_group_1 * src0_single_tile_size); + } + + if (out_sharded) { + UpdateDynamicCircularBufferAddress(program, cb_output, *dst_buffer); + UpdateCircularBufferTotalSize(program, cb_output, num_tiles_per_core_group_1 * dst_single_tile_size); + } +} + +} // namespace ttnn::operations::binary::broadcast_height_and_width_multi_core_program_factory diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_multi_core_program_factory.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_multi_core_program_factory.hpp new file mode 100644 index 00000000000..71cb2cf07fa --- /dev/null +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_multi_core_program_factory.hpp @@ -0,0 +1,314 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#include "tensor/tensor.hpp" +#include "tt_dnn/op_library/bcast/bcast_op.hpp" +#include "tt_dnn/op_library/work_split.hpp" +#include "tt_metal/common/constants.hpp" +#include "tt_metal/detail/util.hpp" +#include "tt_metal/host_api.hpp" +#include "ttnn/device_operation.hpp" +#include "binary_op_type.hpp" + +namespace ttnn::operations::binary::broadcast_height_multi_core_program_factory { + +static const tt::tt_metal::BcastOpMath binary_op_type_to_bcast_op_math(const BinaryOpType binary_op_type) { + switch (binary_op_type) { + case BinaryOpType::ADD: return tt::tt_metal::BcastOpMath::ADD; + case BinaryOpType::SUB: return tt::tt_metal::BcastOpMath::SUB; + case BinaryOpType::MUL: return tt::tt_metal::BcastOpMath::MUL; + default: TT_THROW("BinaryOpType cannot be mapped to BcastOpMath"); + } +} + +inline auto create(const auto& operation_attributes, const auto& tensor_args, auto& tensor_return) { + using namespace tt; + using namespace tt::tt_metal; + + const auto& a = tensor_args.input_tensor_a; + const auto& b = tensor_args.input_tensor_b; + auto& output = tensor_return; + auto bcast_math = binary_op_type_to_bcast_op_math(operation_attributes.binary_op_type); + + const auto ashape = a.get_legacy_shape(); + const auto bshape = b.get_legacy_shape(); + uint32_t N = ashape.rank() >= 4 ? ashape[-4] : 1; + uint32_t C = ashape.rank() >= 3 ? ashape[-3] : 1; + uint32_t H = ashape[-2]; + uint32_t W = ashape[-1]; + uint32_t bN = bshape.rank() >= 4 ? bshape[-4] : 1; + uint32_t bC = bshape.rank() >= 3 ? bshape[-3] : 1; + uint32_t bH = bshape[-2]; + uint32_t bW = bshape[-1]; + uint32_t NC = N * C; + uint32_t HW = H * W; + + uint32_t Wt = W / TILE_WIDTH; + uint32_t Ht = H / TILE_HEIGHT; + + uint32_t num_tensor_tiles = NC * Ht * Wt; + uint32_t num_btensor_tiles = NC * bH * bW / TILE_HW; + + uint32_t bnc1 = (bN * bC == 1) ? 
1 : 0; + + tt_metal::Program program = tt_metal::CreateProgram(); + + tt_metal::Device* device = a.device(); + + tt::DataFormat src0_cb_data_format = tt_metal::datatype_to_dataformat_converter(a.get_dtype()); + tt::DataFormat src1_cb_data_format = tt_metal::datatype_to_dataformat_converter(b.get_dtype()); + tt::DataFormat dst_cb_data_format = tt_metal::datatype_to_dataformat_converter(output.get_dtype()); + + uint32_t src0_single_tile_size = tt_metal::detail::TileSize(src0_cb_data_format); + uint32_t src1_single_tile_size = tt_metal::detail::TileSize(src1_cb_data_format); + uint32_t dst_single_tile_size = tt_metal::detail::TileSize(dst_cb_data_format); + + auto compute_with_storage_grid_size = device->compute_with_storage_grid_size(); + uint32_t num_cores_x = compute_with_storage_grid_size.x; + uint32_t num_cores_y = compute_with_storage_grid_size.y; + auto all_device_cores = CoreRange({0, 0}, {num_cores_x - 1, num_cores_y - 1}); + + auto [num_cores, all_cores, core_group_1, core_group_2, Ht_per_core_group_1, Ht_per_core_group_2] = + split_work_to_cores(compute_with_storage_grid_size, Ht); + + auto src0_buffer = a.buffer(); + auto src1_buffer = b.buffer(); + auto dst_buffer = output.buffer(); + TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!"); + + uint32_t src0_cb_index = 0; + uint32_t num_input_tiles = 2; + + tt_metal::CircularBufferConfig src0_cb_config = + tt_metal::CircularBufferConfig(num_input_tiles * src0_single_tile_size, {{src0_cb_index, src0_cb_data_format}}) + .set_page_size(src0_cb_index, src0_single_tile_size); + auto cb_src0 = tt_metal::CreateCircularBuffer(program, all_device_cores, src0_cb_config); + + uint32_t src1_cb_index = 1; + tt_metal::CircularBufferConfig src1_cb_config = + tt_metal::CircularBufferConfig(num_input_tiles * src1_single_tile_size, {{src1_cb_index, src1_cb_data_format}}) + .set_page_size(src1_cb_index, src1_single_tile_size); + auto cb_src1 = tt_metal::CreateCircularBuffer(program, all_device_cores, src1_cb_config); + + uint32_t output_cb_index = 16; // output operands start at index 16 + uint32_t num_output_tiles = 2; + tt_metal::CircularBufferConfig output_cb_config = + tt_metal::CircularBufferConfig(num_output_tiles * dst_single_tile_size, {{output_cb_index, dst_cb_data_format}}) + .set_page_size(output_cb_index, dst_single_tile_size); + auto cb_output = tt_metal::CreateCircularBuffer(program, all_device_cores, output_cb_config); + + bool src0_is_dram = src0_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; + bool src1_is_dram = src1_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; + std::vector reader_compile_time_args = {(uint32_t)src0_is_dram, (uint32_t)src1_is_dram}; + + bool dst_is_dram = dst_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 
1 : 0; + std::vector writer_compile_time_args = {(uint32_t)dst_is_dram}; + + KernelHandle binary_reader_kernel_id = tt_metal::CreateKernel( + program, + "tt_eager/tt_dnn/op_library/bcast/kernels/dataflow/reader_bcast_h_interleaved_input_rows_partitioned.cpp", + all_device_cores, + tt_metal::ReaderDataMovementConfig(reader_compile_time_args)); + + KernelHandle unary_writer_kernel_id = tt_metal::CreateKernel( + program, + "tt_eager/tt_dnn/op_library/bcast/kernels/dataflow/writer_unary_interleaved_input_cols_batched.cpp", + all_device_cores, + tt_metal::WriterDataMovementConfig(writer_compile_time_args)); + + std::map bcast_defines = bcast_op_utils::get_defines(BcastOpDim::H, bcast_math); + auto bcast_kernel_id = tt_metal::CreateKernel( + program, + "tt_eager/tt_dnn/op_library/bcast/kernels/compute/bcast_h.cpp", + all_device_cores, + tt_metal::ComputeConfig{.compile_args = {}, .defines = bcast_defines}); + + for (uint32_t i = 0, num_Wtiles_read = 0; i < num_cores_y * num_cores_x; i++) { + CoreCoord core = {i / num_cores_y, i % num_cores_y}; + uint32_t Ht_per_core; + if (core_group_1.core_coord_in_core_ranges(core)) { + Ht_per_core = Ht_per_core_group_1; + } else if (core_group_2.core_coord_in_core_ranges(core)) { + Ht_per_core = Ht_per_core_group_2; + } else { + tt_metal::SetRuntimeArgs(program, binary_reader_kernel_id, core, std::vector(15, 0)); + tt_metal::SetRuntimeArgs(program, bcast_kernel_id, core, std::vector(3, 0)); + tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, core, std::vector(9, 0)); + continue; + } + uint32_t num_tensor_tiles_per_core = NC * Ht_per_core * Wt; + + tt_metal::SetRuntimeArgs( + program, + binary_reader_kernel_id, + core, + { + a.buffer()->address(), // 0 + 0, // 1 + 0, // 2 + num_tensor_tiles_per_core, // 3 + b.buffer()->address(), // 4 + 0, // 5 + 0, // 6 + num_btensor_tiles, // 7 + num_tensor_tiles_per_core, // 8 + NC, // 9 + Ht_per_core, // 10 + Wt, // 11 + bnc1, // 12 + num_Wtiles_read, // 13 + Ht * Wt, // 14 + }); + + tt_metal::SetRuntimeArgs( + program, + bcast_kernel_id, + core, + { + NC, // B + Ht_per_core, // Ht + Wt // Wt + }); + + tt_metal::SetRuntimeArgs( + program, + unary_writer_kernel_id, + core, + { + output.buffer()->address(), + 0, + 0, + Ht_per_core, + Wt, + num_Wtiles_read, + 0, + NC, + Ht * Wt, + }); + + num_Wtiles_read += Ht_per_core * Wt; + } + + return device_operation::CachedProgram{ + std::move(program), + binary_reader_kernel_id, + unary_writer_kernel_id, + bcast_kernel_id, + compute_with_storage_grid_size}; +} + +inline void override_runtime_arguments( + auto& cached_program, const auto& operation_attributes, const auto& tensor_args, auto& tensor_return) { + using namespace tt; + using namespace tt::tt_metal; + + const auto& input_tensor_a = tensor_args.input_tensor_a; + const auto& input_tensor_b = tensor_args.input_tensor_b; + auto& output_tensor = tensor_return; + + auto&& [binary_reader_kernel_id, unary_writer_kernel_id, bcast_kernel_id, compute_with_storage_grid_size] = + cached_program.program_attributes; + + auto& program = cached_program.program; + + uint32_t num_cores_x = compute_with_storage_grid_size.x; + uint32_t num_cores_y = compute_with_storage_grid_size.y; + + auto src_dram_buffer_a = input_tensor_a.buffer(); + auto src_dram_buffer_b = input_tensor_b.buffer(); + + auto dst_dram_buffer = output_tensor.buffer(); + + const auto ashape = input_tensor_a.get_legacy_shape(); + const auto bshape = input_tensor_b.get_legacy_shape(); + uint32_t N = ashape.rank() >= 4 ? ashape[-4] : 1; + uint32_t C = ashape.rank() >= 3 ? 
ashape[-3] : 1; + uint32_t H = ashape[-2]; + uint32_t W = ashape[-1]; + uint32_t bN = bshape.rank() >= 4 ? bshape[-4] : 1; + uint32_t bC = bshape.rank() >= 3 ? bshape[-3] : 1; + uint32_t bH = bshape[-2]; + uint32_t bW = bshape[-1]; + uint32_t NC = N * C; + uint32_t HW = H * W; + + uint32_t Wt = W / TILE_WIDTH; + uint32_t Ht = H / TILE_HEIGHT; + + uint32_t num_tensor_tiles = NC * Ht * Wt; + uint32_t num_btensor_tiles = NC * bH * bW / TILE_HW; + + uint32_t bnc1 = (bN * bC == 1) ? 1 : 0; + + auto [num_cores, all_cores, core_group_1, core_group_2, Ht_per_core_group_1, Ht_per_core_group_2] = + split_work_to_cores(compute_with_storage_grid_size, Ht); + + for (uint32_t i = 0, num_Wtiles_read = 0; i < num_cores_y * num_cores_x; i++) { + CoreCoord core = {i / num_cores_y, i % num_cores_y}; + uint32_t Ht_per_core; + if (core_group_1.core_coord_in_core_ranges(core)) { + Ht_per_core = Ht_per_core_group_1; + } else if (core_group_2.core_coord_in_core_ranges(core)) { + Ht_per_core = Ht_per_core_group_2; + } else { + tt_metal::SetRuntimeArgs(program, binary_reader_kernel_id, core, std::vector(15, 0)); + tt_metal::SetRuntimeArgs(program, bcast_kernel_id, core, std::vector(3, 0)); + tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, core, std::vector(9, 0)); + continue; + } + uint32_t num_tensor_tiles_per_core = NC * Ht_per_core * Wt; + + tt_metal::SetRuntimeArgs( + program, + binary_reader_kernel_id, + core, + { + src_dram_buffer_a->address(), // 0 + 0, // 1 + 0, // 2 + num_tensor_tiles_per_core, // 3 + src_dram_buffer_b->address(), // 4 + 0, // 5 + 0, // 6 + num_btensor_tiles, // 7 + num_tensor_tiles_per_core, // 8 + NC, // 9 + Ht_per_core, // 10 + Wt, // 11 + bnc1, // 12 + num_Wtiles_read, // 13 + Ht * Wt, // 14 + }); + + tt_metal::SetRuntimeArgs( + program, + bcast_kernel_id, + core, + { + NC, // B + Ht_per_core, // Ht + Wt // Wt + }); + + tt_metal::SetRuntimeArgs( + program, + unary_writer_kernel_id, + core, + { + dst_dram_buffer->address(), + 0, + 0, + Ht_per_core, + Wt, + num_Wtiles_read, + 0, + NC, + Ht * Wt, + }); + + num_Wtiles_read += Ht_per_core * Wt; + } +} + +} // namespace ttnn::operations::binary::broadcast_height_multi_core_program_factory diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_width_multi_core_program_factory.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_width_multi_core_program_factory.hpp new file mode 100644 index 00000000000..f565d5526c1 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_width_multi_core_program_factory.hpp @@ -0,0 +1,316 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#include "tensor/tensor.hpp" +#include "tt_dnn/op_library/bcast/bcast_op.hpp" +#include "tt_dnn/op_library/work_split.hpp" +#include "tt_metal/common/constants.hpp" +#include "tt_metal/detail/util.hpp" +#include "tt_metal/host_api.hpp" +#include "ttnn/device_operation.hpp" +#include "binary_op_type.hpp" + +namespace ttnn::operations::binary::broadcast_width_multi_core_program_factory { + +static const tt::tt_metal::BcastOpMath binary_op_type_to_bcast_op_math(const BinaryOpType binary_op_type) { + switch (binary_op_type) { + case BinaryOpType::ADD: return tt::tt_metal::BcastOpMath::ADD; + case BinaryOpType::SUB: return tt::tt_metal::BcastOpMath::SUB; + case BinaryOpType::MUL: return tt::tt_metal::BcastOpMath::MUL; + default: TT_THROW("BinaryOpType cannot be mapped to BcastOpMath"); + } +} + +inline auto create(const auto& operation_attributes, const auto& tensor_args, auto& tensor_return) { + using namespace tt; + using namespace tt::tt_metal; + + const auto& a = tensor_args.input_tensor_a; + const auto& b = tensor_args.input_tensor_b; + auto& output = tensor_return; + auto bcast_math = binary_op_type_to_bcast_op_math(operation_attributes.binary_op_type); + + const auto ashape = a.get_legacy_shape(); + const auto bshape = b.get_legacy_shape(); + uint32_t N = ashape.rank() >= 4 ? ashape[-4] : 1; + uint32_t C = ashape.rank() >= 3 ? ashape[-3] : 1; + uint32_t H = ashape[-2]; + uint32_t W = ashape[-1]; + uint32_t bN = bshape.rank() >= 4 ? bshape[-4] : 1; + uint32_t bC = bshape.rank() >= 3 ? bshape[-3] : 1; + uint32_t bH = bshape[-2]; + uint32_t bW = bshape[-1]; + uint32_t NC = N * C; + uint32_t HW = H * W; + + uint32_t Wt = W / TILE_WIDTH; + uint32_t Ht = H / TILE_HEIGHT; + + uint32_t num_tensor_tiles = NC * Ht * Wt; + uint32_t num_btensor_tiles = NC * bH * bW / TILE_HW; + + uint32_t bnc1 = (bN * bC == 1) ? 
1 : 0; + + tt_metal::Program program = tt_metal::CreateProgram(); + + tt_metal::Device* device = a.device(); + + tt::DataFormat src0_cb_data_format = tt_metal::datatype_to_dataformat_converter(a.get_dtype()); + tt::DataFormat src1_cb_data_format = tt_metal::datatype_to_dataformat_converter(b.get_dtype()); + tt::DataFormat dst_cb_data_format = tt_metal::datatype_to_dataformat_converter(output.get_dtype()); + + uint32_t src0_single_tile_size = tt_metal::detail::TileSize(src0_cb_data_format); + uint32_t src1_single_tile_size = tt_metal::detail::TileSize(src1_cb_data_format); + uint32_t dst_single_tile_size = tt_metal::detail::TileSize(dst_cb_data_format); + + auto compute_with_storage_grid_size = device->compute_with_storage_grid_size(); + uint32_t num_cores_x = compute_with_storage_grid_size.x; + uint32_t num_cores_y = compute_with_storage_grid_size.y; + auto all_device_cores = CoreRange({0, 0}, {num_cores_x - 1, num_cores_y - 1}); + + auto [num_cores, all_cores, core_group_1, core_group_2, Wt_per_core_group_1, Wt_per_core_group_2] = + split_work_to_cores(compute_with_storage_grid_size, Wt); + + auto src0_buffer = a.buffer(); + auto src1_buffer = b.buffer(); + auto dst_buffer = output.buffer(); + TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!"); + + uint32_t src0_cb_index = 0; + uint32_t num_input_tiles = 2; + + tt_metal::CircularBufferConfig src0_cb_config = + tt_metal::CircularBufferConfig(num_input_tiles * src0_single_tile_size, {{src0_cb_index, src0_cb_data_format}}) + .set_page_size(src0_cb_index, src0_single_tile_size); + auto cb_src0 = tt_metal::CreateCircularBuffer(program, all_device_cores, src0_cb_config); + + uint32_t src1_cb_index = 1; + tt_metal::CircularBufferConfig src1_cb_config = + tt_metal::CircularBufferConfig(num_input_tiles * src1_single_tile_size, {{src1_cb_index, src1_cb_data_format}}) + .set_page_size(src1_cb_index, src1_single_tile_size); + auto cb_src1 = tt_metal::CreateCircularBuffer(program, all_device_cores, src1_cb_config); + + uint32_t output_cb_index = 16; // output operands start at index 16 + uint32_t num_output_tiles = 2; + tt_metal::CircularBufferConfig output_cb_config = + tt_metal::CircularBufferConfig(num_output_tiles * dst_single_tile_size, {{output_cb_index, dst_cb_data_format}}) + .set_page_size(output_cb_index, dst_single_tile_size); + auto cb_output = tt_metal::CreateCircularBuffer(program, all_device_cores, output_cb_config); + + bool src0_is_dram = src0_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; + bool src1_is_dram = src1_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; + std::vector reader_compile_time_args = {(uint32_t)src0_is_dram, (uint32_t)src1_is_dram}; + + bool dst_is_dram = dst_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 
1 : 0; + std::vector writer_compile_time_args = {(uint32_t)dst_is_dram}; + + KernelHandle binary_reader_kernel_id = tt_metal::CreateKernel( + program, + "tt_eager/tt_dnn/op_library/bcast/kernels/dataflow/reader_bcast_w_interleaved_input_cols_partitioned.cpp", + all_device_cores, + tt_metal::ReaderDataMovementConfig(reader_compile_time_args)); + + KernelHandle unary_writer_kernel_id = tt_metal::CreateKernel( + program, + "tt_eager/tt_dnn/op_library/bcast/kernels/dataflow/writer_unary_interleaved_input_cols_batched.cpp", + all_device_cores, + tt_metal::WriterDataMovementConfig(writer_compile_time_args)); + + std::map bcast_defines = bcast_op_utils::get_defines(BcastOpDim::W, bcast_math); + auto bcast_kernel_id = tt_metal::CreateKernel( + program, + "tt_eager/tt_dnn/op_library/bcast/kernels/compute/bcast_w.cpp", + all_device_cores, + tt_metal::ComputeConfig{.compile_args = {}, .defines = bcast_defines}); + + for (uint32_t i = 0, num_Wtiles_read = 0; i < num_cores_y * num_cores_x; i++) { + CoreCoord core = {i / num_cores_y, i % num_cores_y}; + uint32_t Wt_per_core; + if (core_group_1.core_coord_in_core_ranges(core)) { + Wt_per_core = Wt_per_core_group_1; + } else if (core_group_2.core_coord_in_core_ranges(core)) { + Wt_per_core = Wt_per_core_group_2; + } else { + tt_metal::SetRuntimeArgs(program, binary_reader_kernel_id, core, std::vector(16, 0)); + tt_metal::SetRuntimeArgs(program, bcast_kernel_id, core, std::vector(3, 0)); + tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, core, std::vector(9, 0)); + continue; + } + uint32_t num_tensor_tiles_per_core = NC * Ht * Wt_per_core; + uint32_t Wt_skip = Wt - Wt_per_core; + + tt_metal::SetRuntimeArgs( + program, + binary_reader_kernel_id, + core, + { + a.buffer()->address(), // 0 + 0, // 1 + 0, // 2 + num_tensor_tiles_per_core, // 3 + b.buffer()->address(), // 4 + 0, // 5 + 0, // 6 + num_btensor_tiles, // 7 + num_tensor_tiles_per_core, // 8 + NC, // 9 + Ht, // 10 + Wt_per_core, // 11 + bnc1, // 12 + num_Wtiles_read, // 13 + Ht * Wt, // 14 + Wt_skip, // 15 + }); + + tt_metal::SetRuntimeArgs( + program, + bcast_kernel_id, + core, + { + NC, // B + Ht, // Ht + Wt_per_core // Wt + }); + + tt_metal::SetRuntimeArgs( + program, + unary_writer_kernel_id, + core, + { + output.buffer()->address(), + 0, + 0, + Ht, + Wt_per_core, + num_Wtiles_read, + Wt_skip, + NC, + Ht * Wt, + }); + num_Wtiles_read += Wt_per_core; + } + + return device_operation::CachedProgram{ + std::move(program), + binary_reader_kernel_id, + unary_writer_kernel_id, + bcast_kernel_id, + compute_with_storage_grid_size}; +} + +inline void override_runtime_arguments( + auto& cached_program, const auto& operation_attributes, const auto& tensor_args, auto& tensor_return) { + using namespace tt; + using namespace tt::tt_metal; + + const auto& input_tensor_a = tensor_args.input_tensor_a; + const auto& input_tensor_b = tensor_args.input_tensor_b; + auto& output_tensor = tensor_return; + + auto&& [binary_reader_kernel_id, unary_writer_kernel_id, bcast_kernel_id, compute_with_storage_grid_size] = + cached_program.program_attributes; + + auto& program = cached_program.program; + + uint32_t num_cores_x = compute_with_storage_grid_size.x; + uint32_t num_cores_y = compute_with_storage_grid_size.y; + + auto src_dram_buffer_a = input_tensor_a.buffer(); + auto src_dram_buffer_b = input_tensor_b.buffer(); + + auto dst_dram_buffer = output_tensor.buffer(); + + const auto ashape = input_tensor_a.get_legacy_shape(); + const auto bshape = input_tensor_b.get_legacy_shape(); + uint32_t N = ashape.rank() >= 4 
? ashape[-4] : 1; + uint32_t C = ashape.rank() >= 3 ? ashape[-3] : 1; + uint32_t H = ashape[-2]; + uint32_t W = ashape[-1]; + uint32_t bN = bshape.rank() >= 4 ? bshape[-4] : 1; + uint32_t bC = bshape.rank() >= 3 ? bshape[-3] : 1; + uint32_t bH = bshape[-2]; + uint32_t bW = bshape[-1]; + uint32_t NC = N * C; + uint32_t HW = H * W; + + uint32_t Wt = W / TILE_WIDTH; + uint32_t Ht = H / TILE_HEIGHT; + + uint32_t num_tensor_tiles = NC * Ht * Wt; + uint32_t num_btensor_tiles = NC * bH * bW / TILE_HW; + + uint32_t bnc1 = (bN * bC == 1) ? 1 : 0; + + auto [num_cores, all_cores, core_group_1, core_group_2, Wt_per_core_group_1, Wt_per_core_group_2] = + split_work_to_cores(compute_with_storage_grid_size, Wt); + + for (uint32_t i = 0, num_Wtiles_read = 0; i < num_cores_y * num_cores_x; i++) { + CoreCoord core = {i / num_cores_y, i % num_cores_y}; + uint32_t Wt_per_core; + if (core_group_1.core_coord_in_core_ranges(core)) { + Wt_per_core = Wt_per_core_group_1; + } else if (core_group_2.core_coord_in_core_ranges(core)) { + Wt_per_core = Wt_per_core_group_2; + } else { + tt_metal::SetRuntimeArgs(program, binary_reader_kernel_id, core, std::vector(16, 0)); + tt_metal::SetRuntimeArgs(program, bcast_kernel_id, core, std::vector(3, 0)); + tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, core, std::vector(9, 0)); + continue; + } + uint32_t num_tensor_tiles_per_core = NC * Ht * Wt_per_core; + uint32_t Wt_skip = Wt - Wt_per_core; + + tt_metal::SetRuntimeArgs( + program, + binary_reader_kernel_id, + core, + { + src_dram_buffer_a->address(), // 0 + 0, // 1 + 0, // 2 + num_tensor_tiles_per_core, // 3 + src_dram_buffer_b->address(), // 4 + 0, // 5 + 0, // 6 + num_btensor_tiles, // 7 + num_tensor_tiles_per_core, // 8 + NC, // 9 + Ht, // 10 + Wt_per_core, // 11 + bnc1, // 12 + num_Wtiles_read, // 13 + Ht * Wt, // 14 + Wt_skip, // 15 + }); + + tt_metal::SetRuntimeArgs( + program, + bcast_kernel_id, + core, + { + NC, // B + Ht, // Ht + Wt_per_core // Wt + }); + + tt_metal::SetRuntimeArgs( + program, + unary_writer_kernel_id, + core, + { + dst_dram_buffer->address(), + 0, + 0, + Ht, + Wt_per_core, + num_Wtiles_read, + Wt_skip, + NC, + Ht * Wt, + }); + num_Wtiles_read += Wt_per_core; + } +} + +} // namespace ttnn::operations::binary::broadcast_width_multi_core_program_factory diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_program_factory.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/element_wise_multi_core_program_factory.hpp similarity index 60% rename from ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_program_factory.hpp rename to ttnn/cpp/ttnn/operations/eltwise/binary/device/element_wise_multi_core_program_factory.hpp index 589634601f4..c37ffc80677 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_program_factory.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/element_wise_multi_core_program_factory.hpp @@ -4,37 +4,33 @@ #include -#include "binary_op.hpp" - -#include "tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp" -#include "tt_dnn/op_library/work_split.hpp" - -#include "tt_metal/host_api.hpp" +#include "binary_op_type.hpp" +#include "tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp" +#include "tt_eager/tt_dnn/op_library/work_split.hpp" #include "tt_metal/common/constants.hpp" #include "tt_metal/detail/util.hpp" +#include "tt_metal/host_api.hpp" - -namespace ttnn::operations::binary { +namespace ttnn::operations::binary::element_wise_multi_core_program_factory { template inline __attribute__((always_inline)) void 
set_eltwise_binary_runtime_args( - tt::tt_metal::Program& program, - const tt::tt_metal::Tensor& a, - const tt::tt_metal::Tensor& b, - const tt::tt_metal::Tensor& output, - const tt::tt_metal::KernelHandle binary_reader_kernel_id, - const tt::tt_metal::KernelHandle unary_writer_kernel_id, - const tt::tt_metal::KernelHandle eltwise_binary_kernel_id, - const tt::tt_metal::CBHandle cb_src0, - const tt::tt_metal::CBHandle cb_src1, - const tt::tt_metal::CBHandle cb_output, + Program& program, + const Tensor& a, + const Tensor& b, + const Tensor& output, + const KernelHandle binary_reader_kernel_id, + const KernelHandle unary_writer_kernel_id, + const KernelHandle eltwise_binary_kernel_id, + const CBHandle cb_src0, + const CBHandle cb_src1, + const CBHandle cb_output, const CoreCoord compute_with_storage_grid_size, const uint32_t src0_single_tile_size, const uint32_t src1_single_tile_size, - const uint32_t dst_single_tile_size){ - + const uint32_t dst_single_tile_size) { + using namespace tt; using namespace tt::tt_metal; - using namespace tt::constants; auto src_buffer_a = a.buffer(); auto src_buffer_b = b.buffer(); @@ -71,7 +67,8 @@ inline __attribute__((always_inline)) void set_eltwise_binary_runtime_args( uint32_t block_cnt_per_core_group_1, block_cnt_per_core_group_2; bool row_major; - uint32_t block_height = 0, block_width = 0, block_size = 0, output_width = 0, last_unpadded_block_height = 0, last_unpadded_block_width = 0; + uint32_t block_height = 0, block_width = 0, block_size = 0, output_width = 0, last_unpadded_block_height = 0, + last_unpadded_block_width = 0; CoreCoord end_core; vector cores; @@ -95,14 +92,16 @@ inline __attribute__((always_inline)) void set_eltwise_binary_runtime_args( end_core = (*shard_spec.value().grid.ranges().begin()).end; output_width = output.get_legacy_shape()[-1] / TILE_WIDTH; uint32_t output_height = output.volume() / output.get_legacy_shape()[-1] / TILE_HEIGHT; - last_unpadded_block_height = block_height - (tt::round_up(output_height, block_height) - output_height); - last_unpadded_block_width = block_width - (tt::round_up(output_width, block_width) - output_width); + last_unpadded_block_height = block_height - (round_up(output_height, block_height) - output_height); + last_unpadded_block_width = block_width - (round_up(output_width, block_width) - output_width); } auto bbox = core_group_1.bounding_box(); cores = grid_to_cores_with_noop(bbox.end.x, bbox.end.y, num_cores_x, num_cores_y, row_major); } else { row_major = true; - std::tie(num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2) = split_work_to_cores(compute_with_storage_grid_size, num_tiles, row_major); + std::tie( + num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2) = + split_work_to_cores(compute_with_storage_grid_size, num_tiles, row_major); block_cnt_per_core_group_1 = num_tiles_per_core_group_1; block_cnt_per_core_group_2 = num_tiles_per_core_group_2; cores = grid_to_cores(num_cores_x * num_cores_y, num_cores_x, num_cores_y, row_major); @@ -111,25 +110,24 @@ inline __attribute__((always_inline)) void set_eltwise_binary_runtime_args( uint32_t g1_numcores = core_group_1.num_cores(); uint32_t g2_numcores = core_group_2.num_cores(); - - std::vector< std::vector > binary_reader_args; - std::vector< std::vector > eltwise_binary_args; - std::vector< std::vector > unary_writer_args; - if constexpr(initialize_args) { - binary_reader_args = { cores.size(), std::vector(4) }; - eltwise_binary_args 
= { cores.size(), std::vector(2) }; + std::vector> binary_reader_args; + std::vector> eltwise_binary_args; + std::vector> unary_writer_args; + if constexpr (initialize_args) { + binary_reader_args = {cores.size(), std::vector(4)}; + eltwise_binary_args = {cores.size(), std::vector(2)}; if (block_sharded and not out_sharded) - unary_writer_args = { cores.size(), std::vector(7) }; + unary_writer_args = {cores.size(), std::vector(7)}; else - unary_writer_args = { cores.size(), std::vector(3) }; + unary_writer_args = {cores.size(), std::vector(3)}; } auto& cached_reader_args = GetRuntimeArgs(program, binary_reader_kernel_id); auto& cached_eltwise_args = GetRuntimeArgs(program, eltwise_binary_kernel_id); auto& cached_writer_args = GetRuntimeArgs(program, unary_writer_kernel_id); - for (uint32_t i = 0, num_tiles_read = 0; i < num_cores_total; ++i){ - const CoreCoord &core = cores.at(i); + for (uint32_t i = 0, num_tiles_read = 0; i < num_cores_total; ++i) { + const CoreCoord& core = cores.at(i); uint32_t num_tiles_per_core = 0; uint32_t block_cnt_per_core = 0; uint32_t block_size_per_core = 0; @@ -154,8 +152,9 @@ inline __attribute__((always_inline)) void set_eltwise_binary_runtime_args( } continue; } - if constexpr(initialize_args) { - binary_reader_args[i] = {src_buffer_a->address(), src_buffer_b->address(), num_tiles_per_core, num_tiles_read}; + if constexpr (initialize_args) { + binary_reader_args[i] = { + src_buffer_a->address(), src_buffer_b->address(), num_tiles_per_core, num_tiles_read}; eltwise_binary_args[i] = {block_cnt_per_core, block_size_per_core}; } else { auto& reader_args = cached_reader_args.at(core.x).at(core.y); @@ -191,16 +190,17 @@ inline __attribute__((always_inline)) void set_eltwise_binary_runtime_args( unpadded_block_height = last_unpadded_block_height; } } - if constexpr(initialize_args) { - unary_writer_args[i] = { dst_buffer->address(), - block_height, - block_width, - unpadded_block_height, - unpadded_block_width, - output_width, - block_size, - block_start_height_offset * output_width + block_start_width_offset, - 0 }; + if constexpr (initialize_args) { + unary_writer_args[i] = { + dst_buffer->address(), + block_height, + block_width, + unpadded_block_height, + unpadded_block_width, + output_width, + block_size, + block_start_height_offset * output_width + block_start_width_offset, + 0}; } else { auto& writer_args = cached_writer_args.at(core.x).at(core.y); writer_args[0] = dst_buffer->address(); @@ -214,8 +214,8 @@ inline __attribute__((always_inline)) void set_eltwise_binary_runtime_args( writer_args[8] = 0; } } else { - if constexpr(initialize_args) { - unary_writer_args[i] = { dst_buffer->address(), num_tiles_per_core, num_tiles_read }; + if constexpr (initialize_args) { + unary_writer_args[i] = {dst_buffer->address(), num_tiles_per_core, num_tiles_read}; } else { auto& writer_args = cached_writer_args.at(core.x).at(core.y); writer_args[0] = dst_buffer->address(); @@ -226,7 +226,7 @@ inline __attribute__((always_inline)) void set_eltwise_binary_runtime_args( num_tiles_read += num_tiles_per_core; } - if constexpr(initialize_args) { + if constexpr (initialize_args) { SetRuntimeArgs(program, binary_reader_kernel_id, cores, binary_reader_args); SetRuntimeArgs(program, eltwise_binary_kernel_id, cores, eltwise_binary_args); SetRuntimeArgs(program, unary_writer_kernel_id, cores, unary_writer_args); @@ -244,27 +244,33 @@ inline __attribute__((always_inline)) void set_eltwise_binary_runtime_args( UpdateDynamicCircularBufferAddress(program, cb_output, *dst_buffer); 
UpdateCircularBufferTotalSize(program, cb_output, num_tiles_per_core_group_1 * dst_single_tile_size); } - } -operation::ProgramWithCallbacks eltwise_binary_multi_core(const ttnn::Tensor &a, const ttnn::Tensor &b, const ttnn::Tensor& output, - BinaryOpType op_type, const std::optional> fused_activations) { +inline auto create(const auto& operation_attributes, const auto& tensor_args, auto& tensor_return) { + using namespace tt; using namespace tt::tt_metal; - using namespace tt::constants; + + const auto& a = tensor_args.input_tensor_a; + const auto& b = tensor_args.input_tensor_b; + auto& output = tensor_return; + const auto& op_type = operation_attributes.binary_op_type; + + std::vector fused_activations = + operation_attributes.activations.value_or(std::vector{}); Program program{}; - tt::DataFormat src0_cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(a.get_dtype()); - uint32_t src0_single_tile_size = tt::tt_metal::detail::TileSize(src0_cb_data_format); - tt::DataFormat src1_cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(b.get_dtype()); - uint32_t src1_single_tile_size = tt::tt_metal::detail::TileSize(src1_cb_data_format); - tt::DataFormat dst_cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(output.get_dtype()); - uint32_t dst_single_tile_size = tt::tt_metal::detail::TileSize(dst_cb_data_format); + tt::DataFormat src0_cb_data_format = tt_metal::datatype_to_dataformat_converter(a.get_dtype()); + uint32_t src0_single_tile_size = tt_metal::detail::TileSize(src0_cb_data_format); + tt::DataFormat src1_cb_data_format = tt_metal::datatype_to_dataformat_converter(b.get_dtype()); + uint32_t src1_single_tile_size = tt_metal::detail::TileSize(src1_cb_data_format); + tt::DataFormat dst_cb_data_format = tt_metal::datatype_to_dataformat_converter(output.get_dtype()); + uint32_t dst_single_tile_size = tt_metal::detail::TileSize(dst_cb_data_format); - tt::tt_metal::Buffer *src0_buffer = a.buffer(); - tt::tt_metal::Buffer *src1_buffer = b.buffer(); + tt_metal::Buffer* src0_buffer = a.buffer(); + tt_metal::Buffer* src1_buffer = b.buffer(); - tt::tt_metal::Device *device = a.device(); + tt_metal::Device* device = a.device(); std::optional shard_spec = std::nullopt; bool src0_sharded = a.memory_config().is_sharded(); @@ -294,50 +300,56 @@ operation::ProgramWithCallbacks eltwise_binary_multi_core(const ttnn::Tensor &a, max_block_size = find_max_block_size(num_tiles_per_shard); } - tt::tt_metal::Buffer *dst_buffer = output.buffer(); + tt_metal::Buffer* dst_buffer = output.buffer(); TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!"); auto all_device_cores = CoreRange({0, 0}, {num_cores_x - 1, num_cores_y - 1}); uint32_t src0_cb_index = 0; uint32_t num_input_tiles = src0_sharded ? 
num_tiles_per_shard : 2 * max_block_size; - tt::tt_metal::CircularBufferConfig cb_src0_config = tt::tt_metal::CircularBufferConfig(num_input_tiles * src0_single_tile_size, {{src0_cb_index, src0_cb_data_format}}) - .set_page_size(src0_cb_index, src0_single_tile_size); + tt_metal::CircularBufferConfig cb_src0_config = + tt_metal::CircularBufferConfig(num_input_tiles * src0_single_tile_size, {{src0_cb_index, src0_cb_data_format}}) + .set_page_size(src0_cb_index, src0_single_tile_size); if (src0_sharded) { cb_src0_config = cb_src0_config.set_globally_allocated_address(*a.buffer()); } - auto cb_src0 = tt::tt_metal::CreateCircularBuffer(program, all_device_cores, cb_src0_config); + auto cb_src0 = tt_metal::CreateCircularBuffer(program, all_device_cores, cb_src0_config); uint32_t src1_cb_index = 1; num_input_tiles = src1_sharded ? num_tiles_per_shard : 2 * max_block_size; - tt::tt_metal::CircularBufferConfig cb_src1_config = tt::tt_metal::CircularBufferConfig(num_input_tiles * src1_single_tile_size, {{src1_cb_index, src1_cb_data_format}}) - .set_page_size(src1_cb_index, src1_single_tile_size); + tt_metal::CircularBufferConfig cb_src1_config = + tt_metal::CircularBufferConfig(num_input_tiles * src1_single_tile_size, {{src1_cb_index, src1_cb_data_format}}) + .set_page_size(src1_cb_index, src1_single_tile_size); if (src1_sharded) { cb_src1_config = cb_src1_config.set_globally_allocated_address(*b.buffer()); } - auto cb_src1 = tt::tt_metal::CreateCircularBuffer(program, all_device_cores, cb_src1_config); + auto cb_src1 = tt_metal::CreateCircularBuffer(program, all_device_cores, cb_src1_config); - std::map eltwise_defines = utils::get_defines(op_type, a.get_dtype(), output.get_dtype(), fused_activations); + std::map eltwise_defines = + utils::get_defines(op_type, a.get_dtype(), output.get_dtype(), fused_activations); if (eltwise_defines.find("SFPU_OP_INIT_PRE_IN0_0") != eltwise_defines.end()) { - tt::tt_metal::CircularBufferConfig cb_interm_config = tt::tt_metal::CircularBufferConfig(1 * src0_single_tile_size, {{tt::CB::c_intermed0, src0_cb_data_format}}) - .set_page_size(tt::CB::c_intermed0, src0_single_tile_size); - auto cb_interm = tt::tt_metal::CreateCircularBuffer(program, all_device_cores, cb_interm_config); + tt_metal::CircularBufferConfig cb_interm_config = + tt_metal::CircularBufferConfig(1 * src0_single_tile_size, {{CB::c_intermed0, src0_cb_data_format}}) + .set_page_size(CB::c_intermed0, src0_single_tile_size); + auto cb_interm = tt_metal::CreateCircularBuffer(program, all_device_cores, cb_interm_config); } if (eltwise_defines.find("SFPU_OP_INIT_PRE_IN1_0") != eltwise_defines.end()) { - tt::tt_metal::CircularBufferConfig cb_interm2_config = tt::tt_metal::CircularBufferConfig(1 * src1_single_tile_size, {{tt::CB::c_intermed1, src1_cb_data_format}}) - .set_page_size(tt::CB::c_intermed1, src1_single_tile_size); - auto cb_interm2 = tt::tt_metal::CreateCircularBuffer(program, all_device_cores, cb_interm2_config); + tt_metal::CircularBufferConfig cb_interm2_config = + tt_metal::CircularBufferConfig(1 * src1_single_tile_size, {{CB::c_intermed1, src1_cb_data_format}}) + .set_page_size(CB::c_intermed1, src1_single_tile_size); + auto cb_interm2 = tt_metal::CreateCircularBuffer(program, all_device_cores, cb_interm2_config); } - uint32_t output_cb_index = 16; // output operands start at index 16 + uint32_t output_cb_index = 16; // output operands start at index 16 uint32_t num_output_tiles = (out_sharded || block_sharded) ? 
num_tiles_per_shard : 2 * max_block_size; - tt::tt_metal::CircularBufferConfig cb_output_config = tt::tt_metal::CircularBufferConfig(num_output_tiles * dst_single_tile_size, {{output_cb_index, dst_cb_data_format}}) - .set_page_size(output_cb_index, dst_single_tile_size); + tt_metal::CircularBufferConfig cb_output_config = + tt_metal::CircularBufferConfig(num_output_tiles * dst_single_tile_size, {{output_cb_index, dst_cb_data_format}}) + .set_page_size(output_cb_index, dst_single_tile_size); if (out_sharded) { cb_output_config = cb_output_config.set_globally_allocated_address(*output.buffer()); } - auto cb_output = tt::tt_metal::CreateCircularBuffer(program, all_device_cores, cb_output_config); + auto cb_output = tt_metal::CreateCircularBuffer(program, all_device_cores, cb_output_config); std::map reader_defines; if (src0_sharded) { @@ -351,38 +363,35 @@ operation::ProgramWithCallbacks eltwise_binary_multi_core(const ttnn::Tensor &a, writer_defines["OUT_SHARDED"] = "1"; } - bool src0_is_dram = src0_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; - bool src1_is_dram = src1_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; - std::vector reader_compile_time_args = { - (std::uint32_t) src0_is_dram, - (std::uint32_t) src1_is_dram - }; + bool src0_is_dram = src0_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; + bool src1_is_dram = src1_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; + std::vector reader_compile_time_args = {(std::uint32_t)src0_is_dram, (std::uint32_t)src1_is_dram}; - bool dst_is_dram = dst_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; - std::vector writer_compile_time_args = { - (std::uint32_t) output_cb_index, - (std::uint32_t) dst_is_dram - }; + bool dst_is_dram = dst_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; + std::vector writer_compile_time_args = {(std::uint32_t)output_cb_index, (std::uint32_t)dst_is_dram}; - KernelHandle binary_reader_kernel_id = tt::tt_metal::CreateKernel( + KernelHandle binary_reader_kernel_id = tt_metal::CreateKernel( program, "ttnn/cpp/ttnn/operations/eltwise/binary/device/kernels/dataflow/reader_binary_interleaved_start_id.cpp", all_device_cores, - tt::tt_metal::ReaderDataMovementConfig(reader_compile_time_args, reader_defines)); + tt_metal::ReaderDataMovementConfig(reader_compile_time_args, reader_defines)); - KernelHandle unary_writer_kernel_id = tt::tt_metal::CreateKernel( + KernelHandle unary_writer_kernel_id = tt_metal::CreateKernel( program, - (block_sharded and not out_sharded) ? "tt_eager/tt_dnn/op_library/sharded/kernels/dataflow/writer_unary_sharded_blocks_interleaved_start_id.cpp" : "tt_eager/tt_dnn/kernels/dataflow/writer_unary_interleaved_start_id.cpp", + (block_sharded and not out_sharded) + ? 
"tt_eager/tt_dnn/op_library/sharded/kernels/dataflow/writer_unary_sharded_blocks_interleaved_start_id.cpp" + : "tt_eager/tt_dnn/kernels/dataflow/writer_unary_interleaved_start_id.cpp", all_device_cores, - tt::tt_metal::WriterDataMovementConfig(writer_compile_time_args, writer_defines)); + tt_metal::WriterDataMovementConfig(writer_compile_time_args, writer_defines)); - bool fp32_dest_acc_en = dst_cb_data_format == tt::DataFormat::UInt32 || dst_cb_data_format == tt::DataFormat::Int32 || dst_cb_data_format == tt::DataFormat::Float32; - auto eltwise_binary_kernel_id = tt::tt_metal::CreateKernel( + bool fp32_dest_acc_en = dst_cb_data_format == tt::DataFormat::UInt32 || + dst_cb_data_format == tt::DataFormat::Int32 || + dst_cb_data_format == tt::DataFormat::Float32; + auto eltwise_binary_kernel_id = tt_metal::CreateKernel( program, "ttnn/cpp/ttnn/operations/eltwise/binary/device/kernels/compute/eltwise_binary_kernel.cpp", all_device_cores, - tt::tt_metal::ComputeConfig{.fp32_dest_acc_en=fp32_dest_acc_en, .defines = eltwise_defines}); - + tt_metal::ComputeConfig{.fp32_dest_acc_en = fp32_dest_acc_en, .defines = eltwise_defines}); set_eltwise_binary_runtime_args( program, @@ -400,8 +409,8 @@ operation::ProgramWithCallbacks eltwise_binary_multi_core(const ttnn::Tensor &a, src1_single_tile_size, dst_single_tile_size); - - auto override_runtime_arguments_callback = [ + return device_operation::CachedProgram{ + std::move(program), binary_reader_kernel_id, unary_writer_kernel_id, eltwise_binary_kernel_id, @@ -411,37 +420,33 @@ operation::ProgramWithCallbacks eltwise_binary_multi_core(const ttnn::Tensor &a, compute_with_storage_grid_size, src0_single_tile_size, src1_single_tile_size, - dst_single_tile_size - ] - ( - const void* operation, - tt::tt_metal::Program& program, - const std::vector& input_tensors, - const std::vector>&, - const std::vector& output_tensors - ) { - - auto src_buffer_a = input_tensors.at(0).buffer(); - auto src_buffer_b = input_tensors.at(1).buffer(); - const auto& output_tensor = output_tensors.size() == 1 ? output_tensors.at(0) : input_tensors.at(0); - - set_eltwise_binary_runtime_args( - program, - input_tensors.at(0), - input_tensors.at(1), - output_tensor, - binary_reader_kernel_id, - unary_writer_kernel_id, - eltwise_binary_kernel_id, - cb_src0, - cb_src1, - cb_output, - compute_with_storage_grid_size, - src0_single_tile_size, - src1_single_tile_size, - dst_single_tile_size); - }; - return {.program=std::move(program), .override_runtime_arguments_callback=override_runtime_arguments_callback}; + dst_single_tile_size}; } +inline void override_runtime_arguments( + auto& cached_program, auto& operation_attributes, auto& tensor_args, auto& tensor_return) { + const auto& input_tensor_a = tensor_args.input_tensor_a; + const auto& input_tensor_b = tensor_args.input_tensor_b; + auto& output_tensor = tensor_return; + + auto&& [binary_reader_kernel_id, unary_writer_kernel_id, eltwise_binary_kernel_id, cb_src0, cb_src1, cb_output, compute_with_storage_grid_size, src0_single_tile_size, src1_single_tile_size, dst_single_tile_size] = + cached_program.program_attributes; + + set_eltwise_binary_runtime_args( + cached_program.program, + input_tensor_a, + input_tensor_b, + output_tensor, + binary_reader_kernel_id, + unary_writer_kernel_id, + eltwise_binary_kernel_id, + cb_src0, + cb_src1, + cb_output, + compute_with_storage_grid_size, + src0_single_tile_size, + src1_single_tile_size, + dst_single_tile_size); } + +} // namespace ttnn::operations::binary::element_wise_multi_core_program_factory