From 750d0f2a44750123c87a0707aa4ba081dd224000 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Choi=20HyungSuk=28=EC=B5=9C=ED=98=95=EC=84=9D=29?= Date: Thu, 7 Nov 2024 10:13:44 +0900 Subject: [PATCH] #14707: refactoring return optional tensors (#14708) * #14707: refactoring return optional tensors --- .../operations/test_moreh_group_norm.py | 8 +- ttnn/cpp/ttnn/decorators.hpp | 130 +++++++++--------- .../example_multiple_return.cpp | 14 +- .../example_multiple_return.hpp | 8 +- .../moreh/moreh_adam/moreh_adam.cpp | 26 +--- .../moreh/moreh_adam/moreh_adam.hpp | 5 +- .../moreh/moreh_adamw/moreh_adamw.cpp | 18 +-- .../moreh/moreh_adamw/moreh_adamw.hpp | 5 +- .../moreh_dot_backward/moreh_dot_backward.cpp | 18 +-- .../moreh_dot_backward/moreh_dot_backward.hpp | 5 +- .../moreh_group_norm/moreh_group_norm.cpp | 16 +-- .../moreh_group_norm/moreh_group_norm.hpp | 5 +- .../moreh_group_norm_backward.cpp | 25 ++-- .../moreh_group_norm_backward.hpp | 5 +- .../moreh_layer_norm/moreh_layer_norm.cpp | 16 +-- .../moreh_layer_norm/moreh_layer_norm.hpp | 5 +- .../moreh_layer_norm_backward.cpp | 16 +-- .../moreh_layer_norm_backward.hpp | 5 +- .../moreh_linear_backward.cpp | 18 +-- .../moreh_linear_backward.hpp | 5 +- .../moreh_matmul_backward.cpp | 18 +-- .../moreh_matmul_backward.hpp | 5 +- .../operations/moreh/moreh_sgd/moreh_sgd.cpp | 17 +-- .../operations/moreh/moreh_sgd/moreh_sgd.hpp | 5 +- ttnn/cpp/ttnn/run_operation_inl.hpp | 6 +- 25 files changed, 148 insertions(+), 256 deletions(-) diff --git a/tests/ttnn/unit_tests/operations/test_moreh_group_norm.py b/tests/ttnn/unit_tests/operations/test_moreh_group_norm.py index b990cae1aa1..ee35d8244e7 100644 --- a/tests/ttnn/unit_tests/operations/test_moreh_group_norm.py +++ b/tests/ttnn/unit_tests/operations/test_moreh_group_norm.py @@ -202,8 +202,6 @@ def make_input_tensors(input_shape, affine, do_backward=False): def run_test_moreh_group_norm(N, C_num_groups, HW, eps, affine, compute_mean_rstd, device): - torch.manual_seed(2024) - H, W = HW C, num_groups = C_num_groups input_shape = (N, C, H, W) @@ -284,6 +282,8 @@ def run_test_moreh_group_norm(N, C_num_groups, HW, eps, affine, compute_mean_rst ], ) def test_moreh_group_norm(N, C_num_groups, HW, eps, affine, compute_mean_rstd, device): + torch.manual_seed(2024) + run_test_moreh_group_norm(N, C_num_groups, HW, eps, affine, compute_mean_rstd, device) @@ -343,7 +343,8 @@ def run_test_moreh_group_norm_backward( if not affine and (gamma_requires_grad or beta_requires_grad): pytest.skip("gamma_requires_grad and beta_requires_grad are only valid when affine is True.") - torch.manual_seed(2024) + if not (input_requires_grad or gamma_requires_grad or beta_requires_grad): + pytest.skip("at least one requires_grad should be True.") C, num_groups = C_num_groups input_shape = (N, C, H, W) @@ -466,6 +467,7 @@ def run_test_moreh_group_norm_backward( def test_moreh_group_norm_backward( N, C_num_groups, HW, eps, affine, input_requires_grad, gamma_requires_grad, beta_requires_grad, device ): + torch.manual_seed(2024) run_test_moreh_group_norm_backward( N, C_num_groups, HW, eps, affine, input_requires_grad, gamma_requires_grad, beta_requires_grad, device ) diff --git a/ttnn/cpp/ttnn/decorators.hpp b/ttnn/cpp/ttnn/decorators.hpp index faa0bcb1881..66f344fa365 100644 --- a/ttnn/cpp/ttnn/decorators.hpp +++ b/ttnn/cpp/ttnn/decorators.hpp @@ -6,14 +6,14 @@ #include +#include "tt_metal/graph/graph_tracking.hpp" #include "tt_metal/third_party/tracy/public/tracy/Tracy.hpp" +#include "ttnn/common/constants.hpp" #include "ttnn/core.hpp" +#include "ttnn/device_operation.hpp" #include "ttnn/operation.hpp" #include "ttnn/run_operation.hpp" #include "ttnn/tensor/tensor.hpp" -#include "ttnn/common/constants.hpp" -#include "ttnn/device_operation.hpp" -#include "tt_metal/graph/graph_tracking.hpp" namespace ttnn { namespace decorators { @@ -51,8 +51,9 @@ auto extract_args_to_vector(args_t&&... args) { return result; } -template -inline Tensors create_async_output_tensors(const Tensors& inputs, const OptionalConstTensors& optional_inputs) { +template +inline auto create_async_output_tensors( + const Tensors& inputs, const OptionalConstTensors& optional_inputs, args_t&&... args) { bool enable_autoformat_device = false; constexpr bool custom_create_async_outputs = @@ -60,8 +61,20 @@ inline Tensors create_async_output_tensors(const Tensors& inputs, const Optional if constexpr (custom_create_async_outputs) { return operation_t::create_async_output_tensors(inputs, optional_inputs); + } else if constexpr (std::is_same_v, OptionalTensors>) { + constexpr bool custom_create_async_optional_outputs = requires(const operation_t& t) { + t.create_async_optional_output_tensors(std::forward(args)...); + }; + static_assert( + custom_create_async_optional_outputs, + "If the operation returns a vector of optional Tensors, it must " + "implement create_async_optional_output_tensors."); + + return operation_t::create_async_optional_output_tensors(std::forward(args)...); } else if constexpr (std::is_same_v, Tensor>) { - return {Tensor(operation::get_workers_for_op_output(inputs, optional_inputs, enable_autoformat_device))}; + return std::vector{ + Tensor(operation::get_workers_for_op_output(inputs, optional_inputs, enable_autoformat_device))}; + } else if constexpr (detail::is_homogenous_tuple()) { Tensors output_tensors; output_tensors.reserve(std::tuple_size_v); @@ -108,21 +121,13 @@ auto map_launch_op_args_to_execute_on_worker_thread_args( } template -auto map_execute_on_worker_thread_return_to_launch_op_return(const T&& value) -> Tensors { +auto map_execute_on_worker_thread_return_to_launch_op_return(const T&& value) { if constexpr (std::is_same_v, Tensors>) { return value; } else if constexpr (std::is_same_v, Tensor>) { - return {value}; + return std::vector{value}; } else if constexpr (std::is_same_v, OptionalTensors>) { - Tensors output_tensors; - auto size = value.size(); - output_tensors.reserve(size); - - auto dummy_tensor = Tensor(); - for (auto& val : value) { - output_tensors.push_back(val.value_or(dummy_tensor)); - } - return output_tensors; + return value; } else if constexpr (is_homogenous_tuple()) { Tensors output_tensors; output_tensors.reserve(std::tuple_size_v); @@ -135,8 +140,8 @@ auto map_execute_on_worker_thread_return_to_launch_op_return(const T&& value) -> } else { static_assert( tt::stl::concepts::always_false_v, - "Operation must return either a single Tensor or a vector of Tensors or implement " - "map_execute_on_worker_thread_return_to_launch_op_return."); + "Operation must return either a single Tensor or a vector of Tensors or a vector of optional Tensors " + "implement map_execute_on_worker_thread_return_to_launch_op_return."); } } @@ -194,7 +199,7 @@ template concept PrimitiveOperationConcept = device_operation::DeviceOperationConcept; // Composite operation allows any code to be executed -template +template concept CompositeOperationConcept = !PrimitiveOperationConcept; template @@ -213,10 +218,11 @@ struct registered_operation_t { } template - requires PrimitiveOperationConcept + requires PrimitiveOperationConcept auto invoke(uint8_t queue_id, args_t&&... args) const { - static_assert(requires { operation_t::invoke(std::forward(args)...); }, - "Primitive Operation must implement operator() method to be invoked."); + static_assert( + requires { operation_t::invoke(std::forward(args)...); }, + "Primitive Operation must implement operator() method to be invoked."); ZoneScopedN("Run primitive ttnn operation"); ZoneName(static_cast(cpp_fully_qualified_name.data.data()), cpp_fully_qualified_name.size()); auto [operation_attributes, tensors_args] = operation_t::invoke(std::forward(args)...); @@ -224,12 +230,11 @@ struct registered_operation_t { } template - requires(PrimitiveOperationConcept) + requires(PrimitiveOperationConcept) auto invoke(args_t&&... args) const { return invoke(DefaultQueueId, std::forward(args)...); } - template requires(not auto_launch_op) auto invoke_composite(args_t&&... args) const { @@ -247,16 +252,14 @@ struct registered_operation_t { // #8479: Fix and re-enable logging in cpp operation decorator // detail::log("Arguments: ", std::forward(args)...); - using execute_on_worker_thread_return_t = - decltype(operation_t::invoke(std::forward(args)...)); + using execute_on_worker_thread_return_t = decltype(operation_t::invoke(std::forward(args)...)); const Tensors input_tensors = detail::extract_args_to_vector(std::forward(args)...); const OptionalConstTensors optional_input_tensors = detail::extract_args_to_vector>(std::forward(args)...); - auto output_tensors = - detail::create_async_output_tensors( - input_tensors, optional_input_tensors); + auto output_tensors = detail::create_async_output_tensors( + input_tensors, optional_input_tensors, std::forward(args)...); const OptionalTensors optional_output_tensors = detail::extract_args_to_vector>(std::forward(args)...); @@ -266,11 +269,11 @@ struct registered_operation_t { [args...]( const Tensors& input_tensors, const OptionalConstTensors& optional_input_tensors, - const OptionalTensors& optional_output_tensors) mutable -> Tensors { + const OptionalTensors& optional_output_tensors) mutable { auto execute_on_worker_thread_args = detail::map_launch_op_args_to_execute_on_worker_thread_args( input_tensors, optional_input_tensors, optional_output_tensors, std::forward(args)...); return std::apply( - [](auto&&... args) -> Tensors { + [](auto&&... args) { return detail::map_execute_on_worker_thread_return_to_launch_op_return( operation_t::invoke(std::forward(args)...)); }, @@ -282,24 +285,12 @@ struct registered_operation_t { optional_output_tensors, enable_autoformat); - if constexpr (std::is_same_v, Tensor>) { return output_tensors.at(0); } else if constexpr (std::is_same_v) { return output_tensors; } else if constexpr (std::is_same_v) { - // convert tensor to optional tensor - auto size = output_tensors.size(); - std::vector> ret(size); - - auto return_flags = operation_t::create_async_return_flag(std::forward(args)...); - - for (uint32_t i = 0 ; i < size; i++) { - if (return_flags.at(i)) { - ret[i] = output_tensors.at(i); - } - } - return ret; + return output_tensors; } else if constexpr (detail::is_homogenous_tuple()) { return detail::make_tuple_from_vector(output_tensors); } else { @@ -312,7 +303,7 @@ struct registered_operation_t { } template - requires(CompositeOperationConcept) + requires(CompositeOperationConcept) auto invoke(args_t&&... args) const { return invoke_composite(std::forward(args)...); } @@ -336,24 +327,20 @@ struct registered_operation_t { } }; -template -struct operation_name_key_t{ +template +struct operation_name_key_t { friend consteval auto get(operation_name_key_t); }; -template -struct operation_key_t{ +template +struct operation_key_t { friend consteval auto get(operation_key_t); }; -template +template struct set_operation_t : std::true_type { - friend consteval auto get(operation_key_t) { - return operation; - } - friend consteval auto get(operation_name_key_t) { - return operation; - } + friend consteval auto get(operation_key_t) { return operation; } + friend consteval auto get(operation_name_key_t) { return operation; } }; constexpr reflect::fixed_string prim_namespace = "ttnn::prim"; @@ -361,18 +348,24 @@ constexpr reflect::fixed_string prim_namespace = "ttnn::prim"; template consteval void assert_operation_in_correct_namespace() { if constexpr (PrimitiveOperationConcept) { - if constexpr(cpp_fully_qualified_name.size() > prim_namespace.size()) { - constexpr auto namespace_substring = tt::stl::reflection::fixed_string_substring<0, prim_namespace.size()>(cpp_fully_qualified_name); - static_assert(tt::stl::reflection::fixed_string_equals(namespace_substring, prim_namespace), "Primitive operations must be in the `ttnn::prim` namespace."); + if constexpr (cpp_fully_qualified_name.size() > prim_namespace.size()) { + constexpr auto namespace_substring = + tt::stl::reflection::fixed_string_substring<0, prim_namespace.size()>(cpp_fully_qualified_name); + static_assert( + tt::stl::reflection::fixed_string_equals(namespace_substring, prim_namespace), + "Primitive operations must be in the `ttnn::prim` namespace."); } else { - #ifndef DISABLE_NAMESPACE_STATIC_ASSERT +#ifndef DISABLE_NAMESPACE_STATIC_ASSERT static_assert(false, "Primitive operations must be in the `ttnn::prim` namespace."); - #endif +#endif } } else { if constexpr (cpp_fully_qualified_name.size() > prim_namespace.size()) { - constexpr auto namespace_substring = tt::stl::reflection::fixed_string_substring<0, prim_namespace.size()>(cpp_fully_qualified_name); - static_assert(not tt::stl::reflection::fixed_string_equals(namespace_substring, prim_namespace), "Composite operations must not be in the `ttnn::prim` namespace."); + constexpr auto namespace_substring = + tt::stl::reflection::fixed_string_substring<0, prim_namespace.size()>(cpp_fully_qualified_name); + static_assert( + not tt::stl::reflection::fixed_string_equals(namespace_substring, prim_namespace), + "Composite operations must not be in the `ttnn::prim` namespace."); } } } @@ -381,13 +374,16 @@ template (); constexpr auto operation = registered_operation_t{}; - static_assert(not requires(operation_name_key_t key) { get(key); }, "Operation with this `cpp_fully_qualified_name` was already registered. Please use a different name."); - static_assert(not requires(operation_key_t key) { get(key); }, "Operation with this `operation_t` was already registered. Please use a different type."); + static_assert( + not requires(operation_name_key_t key) { get(key); }, + "Operation with this `cpp_fully_qualified_name` was already registered. Please use a different name."); + static_assert( + not requires(operation_key_t key) { get(key); }, + "Operation with this `operation_t` was already registered. Please use a different type."); static_assert(set_operation_t::value); return operation; } - template constexpr auto register_operation() { return register_operation_impl(); diff --git a/ttnn/cpp/ttnn/operations/examples/example_multiple_return/example_multiple_return.cpp b/ttnn/cpp/ttnn/operations/examples/example_multiple_return/example_multiple_return.cpp index 4bf3ae0b34d..12b28e54c70 100644 --- a/ttnn/cpp/ttnn/operations/examples/example_multiple_return/example_multiple_return.cpp +++ b/ttnn/cpp/ttnn/operations/examples/example_multiple_return/example_multiple_return.cpp @@ -11,17 +11,11 @@ std::vector> CompositeExampleMutipleReturnOperation::invok return prim::example_multiple_return(input_tensor, return_output1, return_output2); } -std::vector CompositeExampleMutipleReturnOperation::create_async_output_tensors( - const std::vector& input_tensors, const std::vector>& optional_inputs) { - const auto& input_tensor = input_tensors.at(0); +OptionalTensors CompositeExampleMutipleReturnOperation::create_async_optional_output_tensors( + const Tensor& input_tensor, bool return_output1, bool return_output2) { return { - Tensor(operation::get_workers_for_op_output({input_tensor})), - Tensor(operation::get_workers_for_op_output({input_tensor}))}; -} - -std::vector CompositeExampleMutipleReturnOperation::create_async_return_flag(const Tensor& input_tensor, bool return_output1, bool return_output2) { - - return {return_output1, return_output2}; + return_output1 ? std::optional(operation::get_workers_for_op_output({input_tensor})) : std::nullopt, + return_output2 ? std::optional(operation::get_workers_for_op_output({input_tensor})) : std::nullopt}; } } // namespace ttnn::operations::examples diff --git a/ttnn/cpp/ttnn/operations/examples/example_multiple_return/example_multiple_return.hpp b/ttnn/cpp/ttnn/operations/examples/example_multiple_return/example_multiple_return.hpp index 8f9857eee94..c43e03ac671 100644 --- a/ttnn/cpp/ttnn/operations/examples/example_multiple_return/example_multiple_return.hpp +++ b/ttnn/cpp/ttnn/operations/examples/example_multiple_return/example_multiple_return.hpp @@ -16,13 +16,9 @@ struct CompositeExampleMutipleReturnOperation { // is registered static std::vector> invoke(const Tensor& input_tensor, bool return_output1, bool return_output2); - static std::vector create_async_output_tensors( - const std::vector& input_tensors, const std::vector>& optional_inputs); + static OptionalTensors create_async_optional_output_tensors( + const Tensor& input_tensor, bool return_output1, bool return_output2); - // The parameters of this function must be identical to those of invoke. - static std::vector create_async_return_flag( - const Tensor& input_tensor, bool return_output1, bool return_output2 - ); }; } // namespace ttnn::operations::examples diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_adam/moreh_adam.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_adam/moreh_adam.cpp index f6779d926c0..a22063a68a7 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_adam/moreh_adam.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_adam/moreh_adam.cpp @@ -48,24 +48,7 @@ std::vector> MorehAdam::invoke( compute_kernel_config); } -std::vector MorehAdam::create_async_output_tensors( - const std::vector& input_tensors, const std::vector>& optional_inputs) { - const auto& param_in = input_tensors.at(0); - const auto& grad = input_tensors.at(1); - const auto& exp_avg_in = input_tensors.at(2); - const auto& exp_avg_sq_in = input_tensors.at(3); - - const auto& max_exp_avg_sq_in = optional_inputs.at(0); - - return { - Tensor(operation::get_workers_for_op_output({param_in, grad, exp_avg_in, exp_avg_sq_in}, {max_exp_avg_sq_in})), - Tensor(operation::get_workers_for_op_output({param_in, grad, exp_avg_in, exp_avg_sq_in}, {max_exp_avg_sq_in})), - Tensor(operation::get_workers_for_op_output({param_in, grad, exp_avg_in, exp_avg_sq_in}, {max_exp_avg_sq_in})), - Tensor(operation::get_workers_for_op_output({param_in, grad, exp_avg_in, exp_avg_sq_in}, {max_exp_avg_sq_in})), - }; -} - -std::vector MorehAdam::create_async_return_flag( +OptionalTensors MorehAdam::create_async_optional_output_tensors( const Tensor& param_in, const Tensor& grad, const Tensor& exp_avg_in, @@ -85,6 +68,11 @@ std::vector MorehAdam::create_async_return_flag( const std::optional& memory_config, const std::optional& compute_kernel_config) { // First three are always true, last one depends on amsgrad - return {true, true, true, amsgrad.value_or(false)}; + return { + std::optional(operation::get_workers_for_op_output({param_in, grad, exp_avg_in, exp_avg_sq_in}, {max_exp_avg_sq_in})), + std::optional(operation::get_workers_for_op_output({param_in, grad, exp_avg_in, exp_avg_sq_in}, {max_exp_avg_sq_in})), + std::optional(operation::get_workers_for_op_output({param_in, grad, exp_avg_in, exp_avg_sq_in}, {max_exp_avg_sq_in})), + amsgrad.value_or(false) ? std::optional(operation::get_workers_for_op_output({param_in, grad, exp_avg_in, exp_avg_sq_in}, {max_exp_avg_sq_in})) : std::nullopt + }; } } // namespace ttnn::operations::moreh::moreh_adam diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_adam/moreh_adam.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_adam/moreh_adam.hpp index ed36c62ba7b..443d449c503 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_adam/moreh_adam.hpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_adam/moreh_adam.hpp @@ -29,10 +29,7 @@ struct MorehAdam { const std::optional& memory_config, const std::optional& compute_kernel_config); - static std::vector create_async_output_tensors( - const std::vector& input_tensors, const std::vector>& optional_inputs); - - static std::vector create_async_return_flag( + static OptionalTensors create_async_optional_output_tensors( const Tensor& param_in, const Tensor& grad, const Tensor& exp_avg_in, diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_adamw/moreh_adamw.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_adamw/moreh_adamw.cpp index 28c2899cef9..958d30574f6 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_adamw/moreh_adamw.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_adamw/moreh_adamw.cpp @@ -50,17 +50,7 @@ std::vector> MorehAdamw::invoke( compute_kernel_config); } -std::vector MorehAdamw::create_async_output_tensors( - const std::vector& input_tensors, const std::vector>& optional_inputs) { - const auto& input_tensor = input_tensors.at(0); - return { - Tensor(operation::get_workers_for_op_output({input_tensor})), - Tensor(operation::get_workers_for_op_output({input_tensor})), - Tensor(operation::get_workers_for_op_output({input_tensor})), - Tensor(operation::get_workers_for_op_output({input_tensor}))}; -} - -std::vector MorehAdamw::create_async_return_flag( +OptionalTensors MorehAdamw::create_async_optional_output_tensors( const Tensor& param_in, const Tensor& grad, const Tensor& exp_avg_in, @@ -81,6 +71,10 @@ std::vector MorehAdamw::create_async_return_flag( const std::optional& max_exp_avg_sq_out, const std::optional& memory_config, const std::optional compute_kernel_config) { - return std::vector{true, true, true, amsgrad.value_or(false)}; + return { + std::optional(operation::get_workers_for_op_output({param_in, grad, exp_avg_in, exp_avg_sq_in}, {max_exp_avg_sq_in})), + std::optional(operation::get_workers_for_op_output({param_in, grad, exp_avg_in, exp_avg_sq_in}, {max_exp_avg_sq_in})), + std::optional(operation::get_workers_for_op_output({param_in, grad, exp_avg_in, exp_avg_sq_in}, {max_exp_avg_sq_in})), + amsgrad.value_or(false) ? std::optional(operation::get_workers_for_op_output({param_in, grad, exp_avg_in, exp_avg_sq_in}, {max_exp_avg_sq_in})) : std::nullopt,}; } } // namespace ttnn::operations::moreh::moreh_adamw diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_adamw/moreh_adamw.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_adamw/moreh_adamw.hpp index 11d8e294ff4..c2ebc82458c 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_adamw/moreh_adamw.hpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_adamw/moreh_adamw.hpp @@ -33,10 +33,7 @@ struct MorehAdamw { const std::optional& memory_config, const std::optional compute_kernel_config); - static std::vector create_async_output_tensors( - const std::vector& input_tensors, const std::vector>& optional_inputs); - - static std::vector create_async_return_flag( + static OptionalTensors create_async_optional_output_tensors( const Tensor& param_in, const Tensor& grad, const Tensor& exp_avg_in, diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_dot_backward/moreh_dot_backward.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_dot_backward/moreh_dot_backward.cpp index efc3a542a15..9f30eccc970 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_dot_backward/moreh_dot_backward.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_dot_backward/moreh_dot_backward.cpp @@ -17,26 +17,16 @@ std::vector> MorehDotBackward::invoke( return ttnn::prim::moreh_dot_backward(output_grad, input, other, input_grad, other_grad, memory_config); } -std::vector MorehDotBackward::create_async_output_tensors( - const std::vector &input_tensors, const std::vector> &optional_inputs) { - auto output_grad = input_tensors.at(0); - auto input = input_tensors.at(1); - auto other = input_tensors.at(2); - - return { - Tensor(operation::get_workers_for_op_output({output_grad, input, other})), - Tensor(operation::get_workers_for_op_output({output_grad, input, other})), - }; -} - -std::vector MorehDotBackward::create_async_return_flag( +OptionalTensors MorehDotBackward::create_async_optional_output_tensors( const Tensor &output_grad, const Tensor &input, const Tensor &other, const std::optional &input_grad, const std::optional &other_grad, const std::optional &memory_config) { - return {input_grad.has_value(), other_grad.has_value()}; + return { + input_grad.has_value() ? std::optional(operation::get_workers_for_op_output({output_grad, input, other})) : std::nullopt, + other_grad.has_value() ? std::optional(operation::get_workers_for_op_output({output_grad, input, other})) : std::nullopt}; } } // namespace ttnn::operations::moreh::moreh_dot_backward diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_dot_backward/moreh_dot_backward.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_dot_backward/moreh_dot_backward.hpp index e984c6d70ef..e66ccd59bb7 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_dot_backward/moreh_dot_backward.hpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_dot_backward/moreh_dot_backward.hpp @@ -14,10 +14,7 @@ struct MorehDotBackward { const std::optional &other_grad, const std::optional &memory_config); - static std::vector create_async_output_tensors( - const std::vector &input_tensors, const std::vector> &optional_inputs); - - static std::vector create_async_return_flag( + static OptionalTensors create_async_optional_output_tensors( const Tensor &output_grad, const Tensor &input, const Tensor &other, diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_group_norm/moreh_group_norm.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_group_norm/moreh_group_norm.cpp index c000b7c6c87..2efc14400c8 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_group_norm/moreh_group_norm.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_group_norm/moreh_group_norm.cpp @@ -36,15 +36,8 @@ std::vector> MorehGroupNorm::invoke( rstd_memory_config, compute_kernel_config); } -std::vector MorehGroupNorm::create_async_output_tensors( - const std::vector& input_tensors, const std::vector>& optional_inputs) { - return { - Tensor(operation::get_workers_for_op_output(input_tensors, optional_inputs)), - Tensor(operation::get_workers_for_op_output(input_tensors, optional_inputs)), - Tensor(operation::get_workers_for_op_output(input_tensors, optional_inputs)), - }; -} -std::vector MorehGroupNorm::create_async_return_flag( + +OptionalTensors MorehGroupNorm::create_async_optional_output_tensors( const Tensor& input, const uint32_t num_groups, const float eps, @@ -58,6 +51,9 @@ std::vector MorehGroupNorm::create_async_return_flag( const std::optional& mean_memory_config, const std::optional& rstd_memory_config, const std::optional& compute_kernel_config) { - return are_required_outputs; + return { + are_required_outputs.at(0) ? std::optional(operation::get_workers_for_op_output({input}, {gamma, beta})) : std::nullopt, + are_required_outputs.at(1) ? std::optional(operation::get_workers_for_op_output({input}, {gamma, beta})) : std::nullopt, + are_required_outputs.at(2) ? std::optional(operation::get_workers_for_op_output({input}, {gamma, beta})) : std::nullopt}; } } // namespace ttnn::operations::moreh::moreh_group_norm diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_group_norm/moreh_group_norm.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_group_norm/moreh_group_norm.hpp index 6fd10773a2c..b557c2ce82f 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_group_norm/moreh_group_norm.hpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_group_norm/moreh_group_norm.hpp @@ -22,9 +22,8 @@ struct MorehGroupNorm { const std::optional& mean_memory_config, const std::optional& rstd_memory_config, const std::optional& compute_kernel_config); - static std::vector create_async_output_tensors( - const std::vector& input_tensors, const std::vector>& optional_inputs); - static std::vector create_async_return_flag( + + static OptionalTensors create_async_optional_output_tensors( const Tensor& input, const uint32_t num_groups, const float eps, diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_group_norm_backward/moreh_group_norm_backward.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_group_norm_backward/moreh_group_norm_backward.cpp index e4e4a190f61..e0f05e63a7b 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_group_norm_backward/moreh_group_norm_backward.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_group_norm_backward/moreh_group_norm_backward.cpp @@ -24,6 +24,7 @@ std::vector> MorehGroupNormBackward::invoke( const std::optional& beta_grad_memory_config, const std::optional& compute_kernel_config) { std::vector> outputs; + if (are_required_outputs[0]) { outputs.push_back(ttnn::prim::moreh_group_norm_backward_input_grad( output_grad, @@ -61,16 +62,8 @@ std::vector> MorehGroupNormBackward::invoke( } return std::move(outputs); } -std::vector MorehGroupNormBackward::create_async_output_tensors( - const std::vector& input_tensors, const std::vector>& optional_inputs) { - return { - Tensor(operation::get_workers_for_op_output(input_tensors, optional_inputs)), - Tensor(operation::get_workers_for_op_output(input_tensors, optional_inputs)), - Tensor(operation::get_workers_for_op_output(input_tensors, optional_inputs)), - }; -} -std::vector MorehGroupNormBackward::create_async_return_flag( +OptionalTensors MorehGroupNormBackward::create_async_optional_output_tensors( const Tensor& output_grad, const Tensor& input, const Tensor& mean, @@ -85,6 +78,18 @@ std::vector MorehGroupNormBackward::create_async_return_flag( const std::optional& gamma_grad_memory_config, const std::optional& beta_grad_memory_config, const std::optional& compute_kernel_config) { - return are_required_outputs; + + TT_FATAL(are_required_outputs[0] or are_required_outputs[1] or are_required_outputs[2], "backward is called, but all gradients are not required"); + + return { + are_required_outputs.at(0) + ? std::optional(operation::get_workers_for_op_output({output_grad, input, mean, rstd}, {gamma})) + : std::nullopt, + are_required_outputs.at(1) + ? std::optional(operation::get_workers_for_op_output({output_grad, input, mean, rstd}, {gamma})) + : std::nullopt, + are_required_outputs.at(2) + ? std::optional(operation::get_workers_for_op_output({output_grad, input, mean, rstd}, {gamma})) + : std::nullopt}; } } // namespace ttnn::operations::moreh::moreh_group_norm_backward diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_group_norm_backward/moreh_group_norm_backward.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_group_norm_backward/moreh_group_norm_backward.hpp index 99f64c5f628..d9d474fbeb6 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_group_norm_backward/moreh_group_norm_backward.hpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_group_norm_backward/moreh_group_norm_backward.hpp @@ -23,9 +23,8 @@ struct MorehGroupNormBackward { const std::optional& gamma_grad_memory_config, const std::optional& beta_grad_memory_config, const std::optional& compute_kernel_config); - static std::vector create_async_output_tensors( - const std::vector& input_tensors, const std::vector>& optional_inputs); - static std::vector create_async_return_flag( + + static OptionalTensors create_async_optional_output_tensors( const Tensor& output_grad, const Tensor& input, const Tensor& mean, diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm/moreh_layer_norm.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm/moreh_layer_norm.cpp index d8268c7dbbb..d27c789a098 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm/moreh_layer_norm.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm/moreh_layer_norm.cpp @@ -22,16 +22,7 @@ std::vector> MorehLayerNorm::invoke( input, normalized_dims, eps, gamma, beta, output, mean, rstd, memory_config, compute_kernel_config); } -std::vector MorehLayerNorm::create_async_output_tensors( - const std::vector& input_tensors, const std::vector>& optional_inputs) { - const auto& input = input_tensors.at(0); - return { - Tensor(operation::get_workers_for_op_output({input})), - Tensor(operation::get_workers_for_op_output({input})), - Tensor(operation::get_workers_for_op_output({input}))}; -} - -std::vector MorehLayerNorm::create_async_return_flag( +OptionalTensors MorehLayerNorm::create_async_optional_output_tensors( const Tensor& input, const uint32_t normalized_dims, const float eps, @@ -45,6 +36,9 @@ std::vector MorehLayerNorm::create_async_return_flag( const auto return_mean = mean.has_value(); const auto return_rstd = rstd.has_value(); - return {true, return_mean, return_rstd}; + return { + std::optional(operation::get_workers_for_op_output({input}, {gamma, beta})), + return_mean ? std::optional(operation::get_workers_for_op_output({input}, {gamma, beta})) : std::nullopt, + return_rstd ? std::optional(operation::get_workers_for_op_output({input}, {gamma, beta})) : std::nullopt}; } } // namespace ttnn::operations::moreh::moreh_layer_norm diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm/moreh_layer_norm.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm/moreh_layer_norm.hpp index 0707982c08e..b3e3fcd890a 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm/moreh_layer_norm.hpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm/moreh_layer_norm.hpp @@ -21,11 +21,8 @@ struct MorehLayerNorm { const std::optional& memory_config, const std::optional& compute_kernel_config); - static std::vector create_async_output_tensors( - const std::vector& input_tensors, const std::vector>& optional_inputs); - // The parameters of this function must be identical to those of invoke. - static std::vector create_async_return_flag( + static OptionalTensors create_async_optional_output_tensors( const Tensor& input, const uint32_t normalized_dims, const float eps, diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/moreh_layer_norm_backward.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/moreh_layer_norm_backward.cpp index 4226e79967b..148be4c24e9 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/moreh_layer_norm_backward.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/moreh_layer_norm_backward.cpp @@ -70,16 +70,7 @@ std::vector> MorehLayerNormBackward::invoke( return outputs; } -std::vector MorehLayerNormBackward::create_async_output_tensors( - const std::vector& input_tensors, const std::vector>& optional_inputs) { - const auto& output_grad = input_tensors.at(0); - return { - Tensor(operation::get_workers_for_op_output({output_grad})), - Tensor(operation::get_workers_for_op_output({output_grad})), - Tensor(operation::get_workers_for_op_output({output_grad}))}; -} - -std::vector MorehLayerNormBackward::create_async_return_flag( +OptionalTensors MorehLayerNormBackward::create_async_optional_output_tensors( const Tensor& output_grad, const Tensor& input, const Tensor& mean, @@ -94,6 +85,9 @@ std::vector MorehLayerNormBackward::create_async_return_flag( const auto return_input_grad = input_grad.has_value(); const auto return_gamma_grad = gamma_grad.has_value(); const auto return_beta_grad = beta_grad.has_value(); - return {return_input_grad, return_gamma_grad, return_beta_grad}; + return { + return_input_grad ? std::optional(operation::get_workers_for_op_output({output_grad, input, mean, rstd}, {gamma})) : std::nullopt, + return_gamma_grad ? std::optional(operation::get_workers_for_op_output({output_grad, input, mean, rstd}, {gamma})) : std::nullopt, + return_beta_grad ? std::optional(operation::get_workers_for_op_output({output_grad, input, mean, rstd}, {gamma})) : std::nullopt}; } } // namespace ttnn::operations::moreh::moreh_layer_norm_backward diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/moreh_layer_norm_backward.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/moreh_layer_norm_backward.hpp index 95d3fbfc278..5e89ad87bc2 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/moreh_layer_norm_backward.hpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/moreh_layer_norm_backward.hpp @@ -22,11 +22,8 @@ struct MorehLayerNormBackward { const std::optional& memory_config, const std::optional& compute_kernel_config); - static std::vector create_async_output_tensors( - const std::vector& input_tensors, const std::vector>& optional_inputs); - // The parameters of this function must be identical to those of invoke. - static std::vector create_async_return_flag( + static OptionalTensors create_async_optional_output_tensors( const Tensor& output_grad, const Tensor& input, const Tensor& mean, diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_linear_backward/moreh_linear_backward.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_linear_backward/moreh_linear_backward.cpp index 3e859d075fd..bc72f47a2ac 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_linear_backward/moreh_linear_backward.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_linear_backward/moreh_linear_backward.cpp @@ -105,17 +105,6 @@ bool is_same_batch_dim(const Tensor& tensor_a, const Tensor& tensor_b) { return true; } -std::vector MorehLinearBackward::create_async_output_tensors( - const std::vector& input_tensors, const std::vector>& optional_inputs) { - const auto& output_grad = input_tensors.at(0); - const auto& input = input_tensors.at(1); - const auto& weight = input_tensors.at(2); - return { - Tensor(operation::get_workers_for_op_output({output_grad, input, weight})), - Tensor(operation::get_workers_for_op_output({output_grad, input, weight})), - Tensor(operation::get_workers_for_op_output({output_grad, input, weight}))}; -} - std::vector> MorehLinearBackward::invoke( const Tensor& output_grad, const Tensor& input, @@ -197,7 +186,7 @@ std::vector> MorehLinearBackward::invoke( return result; } -std::vector MorehLinearBackward::create_async_return_flag( +OptionalTensors MorehLinearBackward::create_async_optional_output_tensors( const Tensor& output_grad, const Tensor& input, const Tensor& weight, @@ -210,7 +199,10 @@ std::vector MorehLinearBackward::create_async_return_flag( const std::optional& weight_grad_memory_config, const std::optional& bias_grad_memory_config, const DeviceComputeKernelConfig compute_kernel_config) { - return are_required_outputs; + return { + are_required_outputs.at(0) ? std::optional(operation::get_workers_for_op_output({output_grad, input, weight})) : std::nullopt, + are_required_outputs.at(1) ? std::optional(operation::get_workers_for_op_output({output_grad, input, weight})) : std::nullopt, + are_required_outputs.at(2) ? std::optional(operation::get_workers_for_op_output({output_grad, input, weight})) : std::nullopt}; } } // namespace ttnn::operations::moreh::moreh_linear_backward diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_linear_backward/moreh_linear_backward.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_linear_backward/moreh_linear_backward.hpp index 75177991a64..43415c00326 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_linear_backward/moreh_linear_backward.hpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_linear_backward/moreh_linear_backward.hpp @@ -12,9 +12,6 @@ namespace ttnn::operations::moreh::moreh_linear_backward { struct MorehLinearBackward { static std::tuple get_required_outputs(const std::vector& are_required_outputs); - static std::vector create_async_output_tensors( - const std::vector& input_tensors, const std::vector>& optional_inputs); - static std::vector> invoke( const Tensor& output_grad, const Tensor& input, @@ -29,7 +26,7 @@ struct MorehLinearBackward { const std::optional& bias_grad_memory_config, const DeviceComputeKernelConfig compute_kernel_config); - static std::vector create_async_return_flag( + static OptionalTensors create_async_optional_output_tensors( const Tensor& output_grad, const Tensor& input, const Tensor& weight, diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_matmul_backward/moreh_matmul_backward.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul_backward/moreh_matmul_backward.cpp index 2b66b4459ec..0494ef1b96b 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_matmul_backward/moreh_matmul_backward.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul_backward/moreh_matmul_backward.cpp @@ -81,19 +81,7 @@ std::vector> MorehMatmulBackward::invoke( return outputs; } -std::vector MorehMatmulBackward::create_async_output_tensors( - const std::vector& input_tensors, const std::vector>& optional_inputs) { - const auto& output_grad = input_tensors.at(0); - const auto& input = input_tensors.at(1); - const auto& other = input_tensors.at(2); - - return { - Tensor(operation::get_workers_for_op_output({output_grad, input, other})), - Tensor(operation::get_workers_for_op_output({output_grad, input, other})), - }; -} - -std::vector MorehMatmulBackward::create_async_return_flag( +OptionalTensors MorehMatmulBackward::create_async_optional_output_tensors( const Tensor& output_grad, const Tensor& input, const Tensor& other, @@ -102,7 +90,9 @@ std::vector MorehMatmulBackward::create_async_return_flag( const std::optional& other_grad, const std::optional& memory_config, const std::optional compute_kernel_config) { - return {are_required_outputs.at(0), are_required_outputs.at(1)}; + return { + are_required_outputs.at(0) ? std::optional(operation::get_workers_for_op_output({output_grad, input, other})) : std::nullopt, + are_required_outputs.at(1) ? std::optional(operation::get_workers_for_op_output({output_grad, input, other})) : std::nullopt}; } } // namespace ttnn::operations::moreh::moreh_matmul_backward diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_matmul_backward/moreh_matmul_backward.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul_backward/moreh_matmul_backward.hpp index 05f6967f99e..1819649240f 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_matmul_backward/moreh_matmul_backward.hpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul_backward/moreh_matmul_backward.hpp @@ -17,10 +17,7 @@ struct MorehMatmulBackward { const std::optional& memory_config, const std::optional compute_kernel_config); - static std::vector create_async_output_tensors( - const std::vector& input_tensors, const std::vector>& optional_inputs); - - static std::vector create_async_return_flag( + static OptionalTensors create_async_optional_output_tensors( const Tensor& output_grad, const Tensor& input, const Tensor& other, diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_sgd/moreh_sgd.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_sgd/moreh_sgd.cpp index 773b54a5790..ce0997d1e5a 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_sgd/moreh_sgd.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_sgd/moreh_sgd.cpp @@ -39,16 +39,7 @@ std::vector> MorehSgd::invoke( compute_kernel_config); } -std::vector MorehSgd::create_async_output_tensors( - const std::vector& input_tensors, const std::vector>& optional_inputs) { - const auto& param_in = input_tensors.at(0); - const auto& grad = input_tensors.at(1); - return { - Tensor(operation::get_workers_for_op_output({param_in, grad})), - Tensor(operation::get_workers_for_op_output({param_in, grad}))}; -} - -std::vector MorehSgd::create_async_return_flag( +OptionalTensors MorehSgd::create_async_optional_output_tensors( const Tensor& param_in, const Tensor& grad, const std::optional& momentum_buffer_in, @@ -63,9 +54,9 @@ std::vector MorehSgd::create_async_return_flag( const std::optional& param_out_memory_config, const std::optional& momentum_buffer_out_memory_config, const std::optional& compute_kernel_config) { - if (momentum != 0.0f) - return {true, true}; - return {true, false}; + return { + std::optional(operation::get_workers_for_op_output({param_in, grad}, {momentum_buffer_in})), + (momentum != 0.0f) ? std::optional(operation::get_workers_for_op_output({param_in, grad}, {momentum_buffer_in})) : std::nullopt}; } } // namespace ttnn::operations::moreh::moreh_sgd diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_sgd/moreh_sgd.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_sgd/moreh_sgd.hpp index 63d95829680..f076070a578 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_sgd/moreh_sgd.hpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_sgd/moreh_sgd.hpp @@ -25,10 +25,7 @@ struct MorehSgd { const std::optional& momentum_buffer_out_memory_config, const std::optional& compute_kernel_config); - static std::vector create_async_output_tensors( - const std::vector& input_tensors, const std::vector>& optional_inputs); - - static std::vector create_async_return_flag( + static OptionalTensors create_async_optional_output_tensors( const Tensor& param_in, const Tensor& grad, const std::optional& momentum_buffer_in, diff --git a/ttnn/cpp/ttnn/run_operation_inl.hpp b/ttnn/cpp/ttnn/run_operation_inl.hpp index db57a4e7d2f..687a821f115 100644 --- a/ttnn/cpp/ttnn/run_operation_inl.hpp +++ b/ttnn/cpp/ttnn/run_operation_inl.hpp @@ -211,6 +211,7 @@ void launch_op( for (int i = 0; i < local_tensors.size(); i++) { auto output_tensor = get_tensor(outputs[i]); auto local_tensor = get_tensor(local_tensors[i]); + // not sure if it the case but in my opinion it should not happen // both output and local tensor should be presented or absent TT_ASSERT((output_tensor != nullptr && local_tensor != nullptr) || (local_tensor == nullptr && output_tensor == nullptr)); @@ -218,11 +219,6 @@ void launch_op( continue; } - // The return type is vector>, and this refers to the case where the i-th value is nullopt. - if (output_tensor->tensor_attributes.use_count() != 0 && local_tensor->tensor_attributes.use_count() == 0) { - continue; - } - if (std::holds_alternative(local_tensor->tensor_attributes->storage)) { TT_ASSERT( output_tensor->tensor_attributes->dynamic_storage,