From 8a98874e7e09e78b3720e291545b2e9823b8151f Mon Sep 17 00:00:00 2001 From: pengwa Date: Tue, 21 May 2024 13:38:19 +0800 Subject: [PATCH] Flash attention recompute (#20603) ### Flash attn recompute 1. Allow PythonOp(FlashAttn) can be recomputed correctly. https://github.com/microsoft/onnxruntime/pull/20603/commits/45879ff5c20bf4cc11840b38b1808572126c5368 2. Use JSON to pass the selected-to-recompute subgraphs. https://github.com/microsoft/onnxruntime/pull/20603/commits/3c374da6788474cd09ba931eb0b00a45fa3f43e0 #### Better Memory Efficiency Customer model can run both PyTorch SPDA and Flash Attn, this PR make it possible to let the Flash Attn path work with ORTModule layerwise recompute. The peak drop from 45.xGB to 32.xGB if we only compare the layers (not including other pieces, BTW there are few more optimization targeting other pieces as well later). #### Better Perf Using Flash ATTN bring additionally 16% end to end time reduction, with highly aligned loss curve. ![image](https://github.com/microsoft/onnxruntime/assets/10530022/bb63894a-f281-49bc-a8e6-ff818439be38) #### Use JSON File to pass Recompute Plans To overcome the limitation of max length of the strings defined in session options. ### Motivation and Context --- cmake/onnxruntime_optimizer.cmake | 1 + docs/Memory_Optimizer.md | 36 +- .../onnxruntime_session_options_config_keys.h | 23 +- onnxruntime/core/session/inference_session.cc | 8 +- .../orttraining/core/agent/training_agent.cc | 2 + .../orttraining/core/agent/training_agent.h | 1 + .../core/optimizer/memory_optimizer/common.cc | 76 +-- .../core/optimizer/memory_optimizer/common.h | 2 +- .../memory_optimizer/memory_insight.cc | 65 ++- .../memory_optimizer/memory_insight.h | 5 +- .../memory_optimizer/memory_optimizer.cc | 149 ++++-- .../memory_optimizer/memory_optimizer.h | 16 +- .../memory_optimizer/recompute_analysis.cc | 56 ++- .../memory_optimizer/recompute_analysis.h | 5 +- .../memory_optimizer/transformer_specific.cc | 24 +- .../python/orttraining_pybind_state.cc | 12 +- .../_custom_autograd_function_exporter.py | 471 ++++++++++++------ .../training/ortmodule/_execution_agent.py | 4 +- .../ortmodule/_graph_execution_manager.py | 16 +- .../training/ortmodule/_runtime_inspector.py | 89 ++-- .../training/ortmodule/_training_manager.py | 3 +- .../python/training/ortmodule/options.py | 13 +- .../test/optimizer/memory_optimizer_test.cc | 65 ++- .../python/orttraining_test_ortmodule_api.py | 182 +++++++ .../orttraining_test_ortmodule_autograd.py | 59 +++ 25 files changed, 1002 insertions(+), 381 deletions(-) diff --git a/cmake/onnxruntime_optimizer.cmake b/cmake/onnxruntime_optimizer.cmake index 9ebd52f986207..3bae1b8a48e0f 100644 --- a/cmake/onnxruntime_optimizer.cmake +++ b/cmake/onnxruntime_optimizer.cmake @@ -113,6 +113,7 @@ onnxruntime_add_include_to_target(onnxruntime_optimizer onnxruntime_common onnxr target_include_directories(onnxruntime_optimizer PRIVATE ${ONNXRUNTIME_ROOT}) if (onnxruntime_ENABLE_TRAINING) target_include_directories(onnxruntime_optimizer PRIVATE ${ORTTRAINING_ROOT}) + onnxruntime_add_include_to_target(onnxruntime_optimizer nlohmann_json::nlohmann_json) if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) onnxruntime_add_include_to_target(onnxruntime_optimizer Python::Module) endif() diff --git a/docs/Memory_Optimizer.md b/docs/Memory_Optimizer.md index d08ba7b8f83c2..f8e015c3db9e4 100644 --- a/docs/Memory_Optimizer.md +++ b/docs/Memory_Optimizer.md @@ -31,7 +31,14 @@ Integrate models using `ORTModule`. There are two modes to enable the memory optimizations: - Transformer layerwise recompute, e.g. aggressively recompute all supported nodes within each transformer layer (usually including attention and mlp sublayers), enabled by `export ORTMODULE_MEMORY_OPT_LEVEL=1`. In this mode, `ORTMODULE_MEMORY_OPT_CONFIG` env values passed by users are not respected. -- Manual selected subgraph recompute, enabled by `export ORTMODULE_MEMORY_OPT_LEVEL=0` and `export ORTMODULE_MEMORY_OPT_CONFIG=,,...`. This is an advanced usage, that allows users to find the most suitable graphs to recompute, at the cost of overhead to look for the best plans. +- Manual selected subgraph recompute, enabled by `export ORTMODULE_MEMORY_OPT_LEVEL=0` and `export ORTMODULE_MEMORY_OPT_CONFIG=`. This is an advanced usage, that allows users to find the most suitable graphs to recompute, at the cost of overhead to look for the best plans. The format for its content is: + ``` + [ + "", + "", + ... + ] + ``` ### Mode 1 - Simple Usage (Transformer Layerwise Recompute) @@ -39,7 +46,7 @@ There are two modes to enable the memory optimizations: 1. Set memory optimization level to be TRANSFORMER_LAYERWISE_RECOMPUTE, by `export ORTMODULE_MEMORY_OPT_LEVEL=1` 2. Run the training as usual; check the logs, you could find something like this if the current log level <= LogLevel.INFO: ``` - Memory Optimizer : ON : Memory Optimization Level: [TRANSFORMER_LAYERWISE_RECOMPUTE], Optimization Config: [Reshape+Where+:1:-1,BiasSoftmax+:1:-1,Cast+:1:-1,BiasGelu+:1:-1,FusedMatMul+:1:-1,Add+:1:-1,Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1] + Memory Optimizer : ON : Memory Optimization Level: [TRANSFORMER_LAYERWISE_RECOMPUTE], Optimization Config: mem_opt.json Configs Freq Max Saving(Bytes) Saving Symbolic(Bytes) - Plan 1 : ON : Reshape+Where+:1:-1 1 134,217,728 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2 - Plan 2 : ON : BiasSoftmax+:1:-1 1 134,086,656 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1*(inputs_input_ids_dim1 - 1) @@ -59,7 +66,7 @@ There are two modes to enable the memory optimizations: 1. Be noted `ORTMODULE_MEMORY_OPT_LEVEL` is by default be 0. Run the training as usual; then stop it after training a few steps. 2. Check the logs, you could find something like this if the current log level <= LogLevel.INFO:: ``` - Memory Optimizer : OFF : Enable with env ORTMODULE_MEMORY_OPT_LEVEL=1 or ORTMODULE_MEMORY_OPT_CONFIG=,,... + Memory Optimizer : OFF : Enable with env ORTMODULE_MEMORY_OPT_LEVEL=1 or ORTMODULE_MEMORY_OPT_CONFIG= Configs Freq Max Saving(Bytes) Saving Symbolic(Bytes) - Plan 1 : OFF : Reshape+Where+:1:-1 1 134,217,728 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2 - Plan 2 : OFF : BiasSoftmax+:1:-1 1 134,086,656 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1*(inputs_input_ids_dim1 - 1) @@ -73,8 +80,15 @@ There are two modes to enable the memory optimizations: 3. As shown above, `Config` is a string representative for a re-computable subgraph. All are disabled for recompute in this case. 4. Set environment variable `ORTMODULE_MEMORY_OPT_CONFIG` to enable some of the subgraphs to do recompute. ```bash - # Use comma as a separator for enabling more than one subgraphs. - export ORTMODULE_MEMORY_OPT_CONFIG="BiasGelu+:1:1" + export ORTMODULE_MEMORY_OPT_CONFIG="mem_opt.json" + + # Content of mem_opt.json: + [ + "BiasGelu+:1:1", + "Dropout+:1:-1" + ] + # Use comma as a separator for enabling more than one subgraphs in the json file. + # Explanation: # > BiasGelu+ is the subgraph string representative; # > 1 in the middle indicates 'Recompute' is enabled (0, on the contrary indicates it's disabled) @@ -83,7 +97,7 @@ There are two modes to enable the memory optimizations: ``` 5. Then run the training again, and you will see logs like this: ``` - Memory Optimizer : ON : Memory Optimization Level: [USER_SPECIFIED], Optimization Config: [BiasGelu+:1:-1] + Memory Optimizer : ON : Memory Optimization Level: [USER_SPECIFIED], Optimization Config: mem_opt.json Configs Freq Max Saving(Bytes) Saving Symbolic(Bytes) - Plan 1 : OFF : Reshape+Where+:1:-1 1 134,217,728 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2 - Plan 2 : OFF : BiasSoftmax+:1:-1 1 134,086,656 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1*(inputs_input_ids_dim1 - 1) @@ -127,7 +141,7 @@ MemoryInsight Summary - User config: not provided |6 |For each row options are mutually exclusive, only one of them can be enabled. | | | | | |>>Option 1 : Recompute subgraph FusedMatMul+Add+Reshape+ | -| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=FusedMatMul+Add+Reshape+:1:-1 | +| | Status : Disabled. | | | Stashed Activations: | | | - ReuseFreq : Output 0(6), | | | - Output 0 : [((inputs_input_ids_dim0)*(inputs_input_ids_dim1)*(32)*(240))], byte/elem: 2, 100% saved | @@ -135,26 +149,26 @@ MemoryInsight Summary - User config: not provided |5 |For each row options are mutually exclusive, only one of them can be enabled. | | | | | |>>Option 1 : Recompute subgraph FusedMatMul+ | -| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=FusedMatMul+:1:-1 | +| | Status : Disabled. | | | Stashed Activations: | | | - Output 0 : [((inputs_input_ids_dim0)*(inputs_input_ids_dim1)*(10240))], byte/elem: 2, 100% saved | |_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _| |5 |For each row options are mutually exclusive, only one of them can be enabled. | | | | | |>>Option 1 : Recompute subgraph Cast+ | -| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Cast+:1:-1 | +| | Status : Disabled. | | | Stashed Activations: | | | - Output 0 : [((inputs_input_ids_dim0)*(32)*(inputs_input_ids_dim1)*(inputs_input_ids_dim1))], byte/elem: 2, 100% saved | |_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _| |1 |For each row options are mutually exclusive, only one of them can be enabled. | | | | | |>>Option 1 : Recompute subgraph Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+ | -| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1 | +| | Status : Disabled. | | | Stashed Activations: | | | - Output 0 : [((inputs_input_ids_dim0)*(1)*(1)*(inputs_input_ids_dim1))], byte/elem: 4, 100% saved | | | | | |>>Option 2 : RecomputeWithCompromise subgraph Cast+ | -| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Cast+:2:-1 | +| | Status : Disabled. | | | Stashed Activations: | | | - Output 0 : [((inputs_input_ids_dim0)*(1)*(1)*(inputs_input_ids_dim1))], byte/elem: 4, 50% saved | |_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _| diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index bb5e0344895e0..c32e2a77e8453 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -78,15 +78,20 @@ static const char* const kOrtSessionOptionsEnableGeluApproximation = "optimizati static const char* const kOrtSessionOptionsDisableAheadOfTimeFunctionInlining = "session.disable_aot_function_inlining"; #ifdef ENABLE_TRAINING -// Specifies a list of op types for memory footprint reduction. -// The value should be a ","-delimited list of pair of -// . -// For example, "Gelu+Cast+:1:0,Dropout+:1:1". -// A valid "subgraph string" should be one subgraph representation output by ORT graph transformations. -// "optimization strategy" currently has valid values: 0 - disabled, 1 - recompute. -// "number of subgraph to apply" is used to control how many subgraphs to apply optimization, to avoid "oversaving" -// the memory. -static const char* const kOrtSessionOptionsMemoryOptimizerEnabler = "optimization.memory_optimizer_config"; +// Specifies a path of the file containing a list of memory optimization configurations. +// The value should be a string indicating the file path of the config file. +// The content of the config file is a JSON struct like this: +// [ +// "Gelu+Cast+:1:0", +// "Dropout+:1:1" +// ] +// Taking the example of "Gelu+Cast+:1:0", +// > "Gelu+Cast+" is the subgraph string, a valid "subgraph string" should be one subgraph representation +// output by ORT graph transformations. +// > "1" is "optimization strategy", valid values: 0 - disabled, 1 - recompute. +// > "0" is "number of subgraph to apply" which is used to control how many subgraphs to apply optimization, +// to avoid "oversaving" the memory. +static const char* const kOrtSessionOptionsMemoryOptimizerApplyConfig = "optimization.memory_optimizer_config"; // Specifies the config for detecting subgraphs for memory footprint reduction. // The value should be a string contains int separated using commas. The default value is "0:0". diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 16f0752e3f603..d1add79f0cb00 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1266,15 +1266,15 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool } #ifdef ENABLE_TRAINING - // Enable memory optimizations (mainly insert recomputation nodes with priority). + // Enable memory optimizations. // Only applicable for training scenarios. { - const std::string memory_optimizer_config = - session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsMemoryOptimizerEnabler, ""); + const std::string memory_optimizer_config_file = + session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsMemoryOptimizerApplyConfig, ""); const std::string probe_config = session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsMemoryOptimizerProbeConfig, "0:0"); - MemoryOptimizer mem_transformer{memory_optimizer_config, probe_config}; + MemoryOptimizer mem_transformer{memory_optimizer_config_file, probe_config}; ORT_RETURN_IF_ERROR_SESSIONID_(apply_transformer_once(mem_transformer, *session_logger_, graph)); } #endif diff --git a/orttraining/orttraining/core/agent/training_agent.cc b/orttraining/orttraining/core/agent/training_agent.cc index 0b38a79cc21c9..cc8d341dbc084 100644 --- a/orttraining/orttraining/core/agent/training_agent.cc +++ b/orttraining/orttraining/core/agent/training_agent.cc @@ -110,6 +110,7 @@ void TrainingAgent::CreateAndInitializeFeedsFetchesManager(const SessionState& s std::string TrainingAgent::GetSerializedORTModuleMemoryStat(std::string_view memory_optimization_config, std::string_view recompute_probe_level, + const bool return_opportunity_table, std::map>& cluster_id_combinations_to_saved_symbolic_byte_map) const { @@ -120,6 +121,7 @@ std::string TrainingAgent::GetSerializedORTModuleMemoryStat(std::string_view mem session_state.GetGraphViewer(), memory_optimization_config, recompute_probe_level, + return_opportunity_table, *inference_session_.GetLogger(), cluster_id_combinations_to_saved_symbolic_byte_map, &ortvalue_name_to_idx_map, diff --git a/orttraining/orttraining/core/agent/training_agent.h b/orttraining/orttraining/core/agent/training_agent.h index 37e5272f66e32..8d88a6df39352 100644 --- a/orttraining/orttraining/core/agent/training_agent.h +++ b/orttraining/orttraining/core/agent/training_agent.h @@ -51,6 +51,7 @@ class TrainingAgent { std::string GetSerializedORTModuleMemoryStat(std::string_view memory_optimization_config, std::string_view recompute_probe_level, + const bool return_opportunity_table, std::map>& cluster_id_combinations_to_saved_symbolic_byte_map) const; diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/common.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/common.cc index d522e60125c36..2a4aab7ab9b4d 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/common.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/common.cc @@ -2,19 +2,23 @@ // Licensed under the MIT License. #include +#include #include #include #include "orttraining/core/optimizer/memory_optimizer/common.h" +#include "core/common/string_utils.h" +#include "core/framework/tensorprotoutils.h" #include "core/graph/graph_utils.h" -#include "core/optimizer/utils.h" #include "core/graph/graph_viewer.h" -#include "core/framework/tensorprotoutils.h" +#include "core/optimizer/utils.h" -#include "core/common/string_utils.h" +#include "nlohmann/json.hpp" namespace onnxruntime::optimizer::memory_optimizer { +using json = nlohmann::json; + namespace { constexpr const char empty_dim_param_placeholder[] = "empty_dim_param"; @@ -114,32 +118,48 @@ int ParseIntValueFromString(std::string_view str) { return int_value; } -Status ParseOptimizationConfigFromString(std::string_view memory_optimization_config, +void from_json(const json& j, UserConfig& mo) { + j.at("type").get_to(mo.type); + j.at("requested_count").get_to(mo.requested_count); +} + +Status ParseOptimizationConfigFromString(std::string_view memory_optimization_config_file_path, InlinedHashMap& cluster_id_to_config_map) { - if (!memory_optimization_config.empty()) { - const auto user_config_strs = utils::SplitString(memory_optimization_config, ","); - for (const auto& user_config_str : user_config_strs) { - const auto user_config = utils::SplitString(user_config_str, ":"); - ORT_RETURN_IF_NOT(user_config.size() == 3, - "User config should be in the format of SubgraphStr:OptimizationType:RequestApplyCount."); - - const std::string subgraph_string_representation(user_config[0]); - int optimization_type_int = ParseIntValueFromString(user_config[1]); - int requested_apply_count = ParseIntValueFromString(user_config[2]); - ORT_RETURN_IF_NOT(optimization_type_int < - static_cast(OptimizationType::TypeMax) && - optimization_type_int >= 0, - "Invalid optimization type specified for subgraph: ", - subgraph_string_representation); - - ORT_RETURN_IF_NOT(requested_apply_count == -1 || requested_apply_count >= 0, - "Invalid requested_apply_count specified for subgraph: ", requested_apply_count); - - // At this point, subgraph_string_representation is a pattern graph string representation. - // If a duplicated subgraph_string_representation is found in user config, the last one will be used. - cluster_id_to_config_map[subgraph_string_representation] = UserConfig{ - static_cast(optimization_type_int), - requested_apply_count}; + if (!memory_optimization_config_file_path.empty()) { + InlinedVector configs_by_cluster_id; // Each cluster_id might contains multiple plans. + try { + std::ifstream in{std::string(memory_optimization_config_file_path).c_str()}; + const json j = json::parse(in); + j.get_to>(configs_by_cluster_id); + } catch (const std::exception& ex) { + ORT_THROW("Fail to parse from json file: ", ex.what()); + } + + for (const auto& config_for_cur_cluster : configs_by_cluster_id) { + const auto configs_by_plan_id = utils::SplitString(config_for_cur_cluster, ","); + for (const auto& config_for_cur_plan : configs_by_plan_id) { + const auto user_config = utils::SplitString(config_for_cur_plan, ":"); + ORT_RETURN_IF_NOT(user_config.size() == 3, + "User config should be in the format of SubgraphStr:OptimizationType:RequestApplyCount."); + + const std::string subgraph_string_representation(user_config[0]); + int optimization_type_int = ParseIntValueFromString(user_config[1]); + int requested_apply_count = ParseIntValueFromString(user_config[2]); + ORT_RETURN_IF_NOT(optimization_type_int < + static_cast(OptimizationType::TypeMax) && + optimization_type_int >= 0, + "Invalid optimization type specified for subgraph: ", + subgraph_string_representation); + + ORT_RETURN_IF_NOT(requested_apply_count == -1 || requested_apply_count >= 0, + "Invalid requested_apply_count specified for subgraph: ", requested_apply_count); + + // At this point, subgraph_string_representation is a pattern graph string representation. + // If a duplicated subgraph_string_representation is found in user config, the last one will be used. + cluster_id_to_config_map[subgraph_string_representation] = UserConfig{ + static_cast(optimization_type_int), + requested_apply_count}; + } } } diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/common.h b/orttraining/orttraining/core/optimizer/memory_optimizer/common.h index 952fe49ffa657..651efff785c1e 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/common.h +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/common.h @@ -70,7 +70,7 @@ std::string GetTensorElemCountInSymbolicString(const Node* node, size_t output_i int ParseIntValueFromString(std::string_view str); -Status ParseOptimizationConfigFromString(std::string_view memory_optimization_config, +Status ParseOptimizationConfigFromString(std::string_view memory_optimization_config_file_path, InlinedHashMap& cluster_id_to_config_map); constexpr const ExecutionOrder TOPOLOGICAL_SORT_ALGORITHM = ExecutionOrder::MEMORY_EFFICIENT; diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc index 038ff0049b32a..cd99c82d0e2f8 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc @@ -40,18 +40,14 @@ constexpr const int kTitleWidthInSecondColumn = 15; * @param fw_op_output_arg_used_map Collected activation usage mapping. * - key: node arg name * - value: a pair of bool, representing whether the activation is used by forward nodes or by backward nodes. - * @param is_forward_nodes Collected node is forward pass op mapping. */ void GetForwardOutputUsageMap(const GraphViewer& graph_viewer, const ptrdiff_t boundary_op_order_in_topological_sort, const InlinedHashMap& node_index_to_its_order_in_topological_sort_map, - ActivationUsedMap& fw_op_output_arg_used_map, - InlinedHashMap& is_forward_nodes) { + ActivationUsedMap& fw_op_output_arg_used_map) { ORT_ENFORCE(boundary_op_order_in_topological_sort >= 0); const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder(TOPOLOGICAL_SORT_ALGORITHM); - is_forward_nodes.clear(); - is_forward_nodes.reserve(node_ids.size()); auto is_forward_pass_operator = [](ptrdiff_t op_order_in_topological_sort, ptrdiff_t boundary_op_order_in_topological_sort) -> bool { @@ -69,12 +65,9 @@ void GetForwardOutputUsageMap(const GraphViewer& graph_viewer, const Node& node = *p_node; bool is_forward_op = is_forward_pass_operator(static_cast(i), boundary_op_order_in_topological_sort); if (!is_forward_op) { - is_forward_nodes[p_node] = false; continue; } - is_forward_nodes[p_node] = true; - for (auto& output_arg : node.OutputDefs()) { if (!output_arg->Exists() || output_arg->Name().empty()) { continue; @@ -109,19 +102,15 @@ void GetForwardOutputUsageMap(const GraphViewer& graph_viewer, * * @param graph_viewer Graph to iterate. * @param boundary_op_order_in_topological_sort The order of the boundary op in the topological sort. - * @param fw_op_output_arg_used_map Activation usage mapping. - * @param candidate_output_args_map Candidate activations, which are consumed by both fw and bw ops. - * @param is_forward_nodes Whether a node is a forward node. + * @param candidate_output_args_map Candidate activations generated in forward, and are consumed by backward ops. * @param logger Logger. * @return Status */ Status GetStashedActivationCandidates(const GraphViewer& graph_viewer, const ptrdiff_t boundary_op_order_in_topological_sort, - ActivationUsedMap& fw_op_output_arg_used_map, InlinedHashMap>& candidate_output_args_map, - InlinedHashMap& is_forward_nodes, const logging::Logger& logger) { if (boundary_op_order_in_topological_sort < 0) { MO_LOG_DEBUG_INFO(logger, "No boundary op found. Skip memory optimization."); @@ -140,19 +129,20 @@ Status GetStashedActivationCandidates(const GraphViewer& graph_viewer, node_index_to_its_order_in_topological_sort_map[p_node->Index()] = i; } + ActivationUsedMap fw_op_output_arg_used_map; GetForwardOutputUsageMap(graph_viewer, boundary_op_order_in_topological_sort, node_index_to_its_order_in_topological_sort_map, - fw_op_output_arg_used_map, - is_forward_nodes); + fw_op_output_arg_used_map); for (auto& kv : fw_op_output_arg_used_map) { - // used by fw and bw, then it is a candidate. - if (kv.second.first && kv.second.second) { - const Node* n = graph_viewer.GetProducerNode(kv.first); + const auto& fw_out_arg = kv.first; + const Node* n = graph_viewer.GetProducerNode(fw_out_arg); + // Node run in forward pass, and the result is used by bw, then it is a candidate. + if (kv.second.second) { ORT_ENFORCE(n, "Activation should have a producer node"); size_t k = 0; for (k = 0; k < n->OutputDefs().size(); ++k) { - if (n->OutputDefs()[k]->Name().compare(kv.first) == 0) { + if (n->OutputDefs()[k]->Name().compare(fw_out_arg) == 0) { break; } } @@ -201,14 +191,9 @@ Status FindORTModuleMemoryOpportunity(const GraphViewer& graph_viewer, node_index_to_its_order_in_topological_sort_map[p_node->Index()] = static_cast(i); } - ActivationUsedMap fw_op_output_arg_used_map; - - InlinedHashMap is_forward_nodes; ORT_RETURN_IF_ERROR(GetStashedActivationCandidates(graph_viewer, yield_op_order_in_topological_sort, - fw_op_output_arg_used_map, candidate_output_args_map, - is_forward_nodes, logger)); InlinedVector layer_boundary_ln_nodes; @@ -236,7 +221,6 @@ Status FindORTModuleMemoryOpportunity(const GraphViewer& graph_viewer, CheckNodeForRecompute(graph_viewer, *p_node, probe_config, - fw_op_output_arg_used_map, node_index_to_its_order_in_topological_sort_map, candidate_output_args_map, layer_boundary_ln_nodes, @@ -254,7 +238,7 @@ Status FindORTModuleMemoryOpportunity(const GraphViewer& graph_viewer, // If the subgraph recompute can save memory by comprising the assumption - recompute graphs' input must exist // during backward pass, then we can consider to recompute them. std::unique_ptr recompute_with_compromise_plan = - CheckNodeForRecompute(graph_viewer, *p_node, probe_config, fw_op_output_arg_used_map, + CheckNodeForRecompute(graph_viewer, *p_node, probe_config, node_index_to_its_order_in_topological_sort_map, candidate_output_args_map, layer_boundary_ln_nodes, @@ -618,8 +602,7 @@ void FormatRecomputeMemoryRecords(int option_index, ", actual applied count=" + std::to_string(actual_count)); } else { rows.push_back(empty_first_col + ToFixedLengthString(" Status", kTitleWidthInSecondColumn) + - ": Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=" + - subgraph_str + ":" + std::to_string(static_cast(opt_type)) + ":-1"); + ": Disabled."); } std::string activation_str = empty_first_col + " Stashed Activations: "; @@ -658,7 +641,7 @@ void FormatRecomputeMemoryRecords(int option_index, std::string SerializeMemoryRecords( const std::vector>& records_grouped_by_node_cluster_id, - std::string_view user_config) { + std::string_view memory_optimization_config_file_path) { InlinedVector rows; rows.push_back(kTableBorder); rows.push_back("|" + ToFixedLengthString("Freq", kFirstColumnWidth) + @@ -714,7 +697,10 @@ std::string SerializeMemoryRecords( std::string table_border_full(max_length, '='); std::ostringstream summary; summary << std::endl; - summary << MakeString("MemoryInsight Summary - User config: ", (user_config.empty() ? "not provided" : user_config)) + summary << MakeString("MemoryInsight Summary - User config file path: ", + (memory_optimization_config_file_path.empty() + ? "not provided" + : memory_optimization_config_file_path)) << std::endl; for (auto& row : rows) { if (row == kTableRowSeparator) { @@ -732,8 +718,9 @@ std::string SerializeMemoryRecords( } std::string GetSerializedORTModuleMemoryStat(const GraphViewer& graph_viewer, - std::string_view memory_optimization_config, + std::string_view memory_optimization_config_file_path, std::string_view recompute_probe_config, + const bool return_opportunity_table, const logging::Logger& logger, std::map>& cluster_id_combinations_to_saved_symbolic_byte_map, @@ -764,8 +751,8 @@ std::string GetSerializedORTModuleMemoryStat(const GraphViewer& graph_viewer, NodeToClusterApplyContextMap node_to_apply_context_map; - if (!memory_optimization_config.empty()) { - ORT_ENFORCE(ParseOptimizationConfigFromString(memory_optimization_config, cluster_id_to_config_map) + if (!memory_optimization_config_file_path.empty()) { + ORT_ENFORCE(ParseOptimizationConfigFromString(memory_optimization_config_file_path, cluster_id_to_config_map) .IsOK()); InlinedHashMap> node_to_opt_plan_map; ORT_ENFORCE(memory_opt_planner.FinalizeNodePlansFromUserConfig(cluster_id_to_config_map, @@ -781,12 +768,16 @@ std::string GetSerializedORTModuleMemoryStat(const GraphViewer& graph_viewer, .IsOK()); } - std::vector> records; - GetMemoryRecordsGroupedByNodeClusterId(memory_opt_planner, node_to_apply_context_map, records); - GetMemorySavingSymbolicString(memory_opt_planner, logger, cluster_id_combinations_to_saved_symbolic_byte_map); - return SerializeMemoryRecords(records, memory_optimization_config); + if (return_opportunity_table) { + std::vector> records; + GetMemoryRecordsGroupedByNodeClusterId(memory_opt_planner, node_to_apply_context_map, records); + return SerializeMemoryRecords(records, memory_optimization_config_file_path); + } + + // Otherwise, return empty. + return ""; } } // namespace onnxruntime::optimizer::memory_optimizer diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.h b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.h index ca1df0633eb8f..bf31513019387 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.h +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.h @@ -65,7 +65,8 @@ class MemoryRecord { * @param logger Logger. * @param node_index_to_its_order_in_topological_sort_map The mapping of node index to its order in topological sort. * @param yield_op_order_in_topological_sort The order of the boundary op in the topological sort. - * @param candidate_output_args_map A map from node to its candidate activations, which are consumed by both fw and + * @param candidate_output_args_map A map from node to its candidate activations. The candidate activations are + * generated by forward node, and consumed by backward nodes. * @param mem_opt_stats A store to maintain all found optimization plans for related nodes. * @return Status */ @@ -111,6 +112,7 @@ std::string SerializeMemoryRecords(const std::vector>& diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc index e0c255e37daf3..91649c61c46d8 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc @@ -7,6 +7,7 @@ #include #include #include +#include #include "core/framework/random_seed.h" #include "core/framework/tensorprotoutils.h" @@ -50,14 +51,29 @@ bool SetSeedForDropoutNode(Node& node) { return false; } +bool SetTrainingModeForForwardPythonOpNode(Node& node) { + if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "PythonOp", {1}, kMSDomain)) { + auto training_mode_attr = graph_utils::GetNodeAttribute(node, "training_mode"); + if (training_mode_attr != nullptr) { + node.ClearAttribute("training_mode"); + } + + // Let forward node does not maintain information (ctx) for backward. + node.AddAttribute("training_mode", static_cast(0)); + return true; + } + + return false; +} + } // namespace -Status MemoryOptimizer::ParseOptimizationConfigFromString(const std::string& memory_optimizer_config, +Status MemoryOptimizer::ParseOptimizationConfigFromString(const std::string& memory_optimization_config_file_path, const std::string& recompute_probe_config) { - optimizer_config_ = memory_optimizer_config; + optimizer_config_file_path_ = memory_optimization_config_file_path; ORT_RETURN_IF_ERROR(optimizer::memory_optimizer::ParseOptimizationConfigFromString( - memory_optimizer_config, + memory_optimization_config_file_path, pattern_subgraph_to_user_optimizer_config_map_)); ORT_RETURN_IF_ERROR(optimizer::memory_optimizer::ParseProbeConfigFromString( @@ -70,8 +86,6 @@ Status MemoryOptimizer::ParseOptimizationConfigFromString(const std::string& mem bool MemoryOptimizer::ModifyGraph(Graph& graph, const InlinedHashMap& node_index_to_its_order_in_topological_sort_map, - const InlinedHashMap>& - candidate_output_args_map, const logging::Logger& logger, ptrdiff_t boundary_op_order_in_topological_sort, Node* node, @@ -100,43 +114,44 @@ bool MemoryOptimizer::ModifyGraph(Graph& graph, } else { ORT_THROW("unsupported optimization type found."); } + ORT_ENFORCE(replacement_node_ptr); graph_is_modified = true; - for (size_t output_index : candidate_output_args_map.at(node)) { - // Collect output edges (connecting to backward ops), to remove. - std::vector output_edges; - for (auto it = node->OutputEdgesBegin(), end = node->OutputEdgesEnd(); it != end; ++it) { - size_t src_output_idx = static_cast(it->GetSrcArgIndex()); - if (src_output_idx != output_index) { + std::vector output_edges; + for (auto it = node->OutputEdgesBegin(), end = node->OutputEdgesEnd(); it != end; ++it) { + auto tid = node_index_to_its_order_in_topological_sort_map.find(it->GetNode().Index()); + // It is possible the consumer node is newly added as the recompute node, so we need a check here. + // For those kinds of ops, we can treat them as backward ops. + if (tid == node_index_to_its_order_in_topological_sort_map.end() || + !IsForwardPassOperator(node_index_to_its_order_in_topological_sort_map.at(tid->first), + boundary_op_order_in_topological_sort)) { + // Ignore the rng state consumer update for the determinstic PythonOp. + if ((graph_utils::IsSupportedOptypeVersionAndDomain(*node, "PythonOp", {1}, kMSDomain) && + (it->GetSrcArgIndex() == 1 || it->GetSrcArgIndex() == 2))) { continue; } - - auto tid = node_index_to_its_order_in_topological_sort_map.find(it->GetNode().Index()); - // It is possible the consumer node is newly added as the recompute node, so we need a check here. - // For those kinds of ops, we can treat them as backward ops. - if (tid == node_index_to_its_order_in_topological_sort_map.end() || - !IsForwardPassOperator(node_index_to_its_order_in_topological_sort_map.at(tid->first), - boundary_op_order_in_topological_sort)) { - // Remove the edge only connecting to backward op. - output_edges.push_back(graph_utils::GraphEdge::CreateGraphEdge(*node, *it, false)); - } + // Remove the edge only connecting to backward op. + output_edges.push_back(graph_utils::GraphEdge::CreateGraphEdge(*node, *it, false)); } + } - if (!output_edges.empty()) { + if (!output_edges.empty()) { + // Create connections between the replacement node and the outgoing nodes. + for (const auto& output_edge : output_edges) { // Remove the output edges of the node first - graph_utils::GraphEdge::RemoveGraphEdges(graph, output_edges); + graph.RemoveEdge(output_edge.src_node, + output_edge.dst_node, + output_edge.src_arg_index, + output_edge.dst_arg_index); - // Create connections between the replacement node and the outgoing nodes. - for (const auto& output_edge : output_edges) { - graph.RemoveConsumerNode(node->MutableOutputDefs()[output_index]->Name(), node); + graph.RemoveConsumerNode(node->MutableOutputDefs()[output_edge.src_arg_index]->Name(), node); - // Add new edge connecting the input with the output nodes directly. - // This also updates the destination node's input node args - graph.AddEdge(replacement_node_ptr->Index(), output_edge.dst_node, static_cast(output_index), - output_edge.dst_arg_index); - } + // Add new edge connecting the input with the output nodes directly. + // This also updates the destination node's input node args + graph.AddEdge(replacement_node_ptr->Index(), output_edge.dst_node, static_cast(output_edge.src_arg_index), + output_edge.dst_arg_index); } } } @@ -146,7 +161,7 @@ bool MemoryOptimizer::ModifyGraph(Graph& graph, Status MemoryOptimizer::ApplyImpl(Graph& graph, bool& modified, int /*graph_level*/, const logging::Logger& logger) const { - LOGS(logger, VERBOSE) << "Memory optimization config: " << optimizer_config_ << ", probe level: " + LOGS(logger, VERBOSE) << "Memory optimization config: " << optimizer_config_file_path_ << ", probe level: " << static_cast(recompute_probe_config_.probe_level) << ", enable_transformer_layer_as_boundary:" << recompute_probe_config_.enable_transformer_layer_as_boundary; @@ -197,7 +212,7 @@ Status MemoryOptimizer::ApplyImpl(Graph& graph, bool& modified, int /*graph_leve bool has_been_modified = false; if (node_to_opt_plan_map.find(p_node) != node_to_opt_plan_map.end()) { has_been_modified = ModifyGraph(graph, node_index_to_its_order_in_topological_sort_map, - candidate_output_args_map, logger, + logger, yield_op_order_in_topological_sort, p_node, node_to_opt_plan_map[p_node], @@ -230,7 +245,7 @@ void MemoryOptimizer::PrintSummary(const optimizer::memory_optimizer::MemoryOpti optimizer::memory_optimizer::GetMemoryRecordsGroupedByNodeClusterId(memory_opt_planner, node_to_apply_contexts_map, records_grouped_by_node_cluster_id); - LOGS(logger, INFO) << SerializeMemoryRecords(records_grouped_by_node_cluster_id, optimizer_config_) << "\n"; + LOGS(logger, INFO) << SerializeMemoryRecords(records_grouped_by_node_cluster_id, optimizer_config_file_path_) << "\n"; } /****************************************************** @@ -259,8 +274,59 @@ Status MemoryOptimizer::CreateRecomputeGraph(Graph& graph, << ")."; } + const bool is_python_op = graph_utils::IsSupportedOptypeVersionAndDomain(*node_to_duplicate, "PythonOp", {1}, kMSDomain); + InlinedVector new_input_args; - new_input_args.reserve(node_to_duplicate->MutableInputDefs().size()); + NodeAttributes update_attrs = node_to_duplicate->GetAttributes(); + + if (is_python_op) { + new_input_args.reserve(node_to_duplicate->MutableInputDefs().size() + 2); + // Ignore the ctx input, and connect the rng output to the recompute node. + new_input_args.push_back(node_to_duplicate->MutableOutputDefs()[1]); + new_input_args.push_back(node_to_duplicate->MutableOutputDefs()[2]); + std::string input_convention = update_attrs.at("input_convention").s(); + input_convention[1] = 'd'; // Update the rng state input to be a tensor. + input_convention[2] = 'd'; + update_attrs["input_convention"] = ONNX_NAMESPACE::MakeAttribute("input_convention", input_convention); + + const auto& input_pointer_scalars_ints = update_attrs.at("input_pointer_scalars").ints(); + std::vector input_pointer_scalars(input_pointer_scalars_ints.begin(), + input_pointer_scalars_ints.end()); + // Remove the rng state input. + input_pointer_scalars.erase(input_pointer_scalars.begin() + 1, input_pointer_scalars.begin() + 3); + update_attrs["input_pointer_scalars"] = ONNX_NAMESPACE::MakeAttribute("input_pointer_scalars", + input_pointer_scalars); + + const auto& input_pointer_scalars_positions_ints = update_attrs.at("input_pointer_scalar_positions").ints(); + std::vector input_pointer_scalar_positions(input_pointer_scalars_positions_ints.begin(), + input_pointer_scalars_positions_ints.end()); + // Remove the rng state input. + input_pointer_scalar_positions.erase(input_pointer_scalar_positions.begin() + 1, + input_pointer_scalar_positions.begin() + 3); + update_attrs["input_pointer_scalar_positions"] = ONNX_NAMESPACE::MakeAttribute("input_pointer_scalar_positions", + input_pointer_scalar_positions); + + const auto& input_tensor_ranks_ints = update_attrs.at("input_tensor_ranks").ints(); + std::vector input_tensor_ranks(input_tensor_ranks_ints.begin(), + input_tensor_ranks_ints.end()); + // Insert the rng state input and cuda rng state input at the beginning. + input_tensor_ranks.insert(input_tensor_ranks.begin(), {1, 1}); + + update_attrs["input_tensor_ranks"] = ONNX_NAMESPACE::MakeAttribute("input_tensor_ranks", + input_tensor_ranks); + + const auto& input_tensor_types_ints = update_attrs.at("input_tensor_types").ints(); + std::vector input_tensor_types(input_tensor_types_ints.begin(), + input_tensor_types_ints.end()); + // Insert the uint8 type of rng state and cuda rng state at the beginning. + input_tensor_types.insert(input_tensor_types.begin(), {ONNX_NAMESPACE::TensorProto_DataType_UINT8, + ONNX_NAMESPACE::TensorProto_DataType_UINT8}); + update_attrs["input_tensor_types"] = ONNX_NAMESPACE::MakeAttribute("input_tensor_types", + input_tensor_types); + } else { + new_input_args.reserve(node_to_duplicate->MutableInputDefs().size()); + } + for (NodeArg* input_arg : node_to_duplicate->MutableInputDefs()) { if (self_contained_outputs_map.find(input_arg) == self_contained_outputs_map.end()) { NodeArg* recompute_input_arg = graph.GetNodeArg(graph_utils::RecomputeName(input_arg->Name())); @@ -285,7 +351,7 @@ Status MemoryOptimizer::CreateRecomputeGraph(Graph& graph, "Recompute of " + node_to_duplicate->Name(), new_input_args, new_output_args, - &node_to_duplicate->GetAttributes(), + &update_attrs, node_to_duplicate->Domain()); recompute_node.SetExecutionProviderType(node_to_duplicate->GetExecutionProviderType()); @@ -312,6 +378,17 @@ Status MemoryOptimizer::CreateRecomputeGraph(Graph& graph, graph.AddConsumerNode(input_arg->Name(), &recompute_node); } + + if (is_python_op) { + graph.AddConsumerNode(node_to_duplicate->MutableOutputDefs()[1]->Name(), &recompute_node); + graph.AddConsumerNode(node_to_duplicate->MutableOutputDefs()[2]->Name(), &recompute_node); + } + + bool training_mode_reset = SetTrainingModeForForwardPythonOpNode(*node_to_duplicate); + if (training_mode_reset) { + LOGS(logger, VERBOSE) << "Set training mode for Node " << node_to_duplicate->Name() + << "(" << node_to_duplicate->OpType() << ")."; + } } return Status::OK(); diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.h b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.h index 45d7c10cea41f..c50d5b5642624 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.h +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.h @@ -29,10 +29,13 @@ Find recompute subgraphs and enable them according to user configs. The way we c class MemoryOptimizer : public GraphTransformer { private: public: - MemoryOptimizer(const std::string& memory_optimizer_config, const std::string& recompute_probe_config) + MemoryOptimizer(const std::string& memory_optimization_config_file_path, + const std::string& recompute_probe_config) : GraphTransformer("MemoryOptimizer") { // Parse user-defined configs. - ORT_ENFORCE(ParseOptimizationConfigFromString(memory_optimizer_config, recompute_probe_config).IsOK()); + ORT_ENFORCE(ParseOptimizationConfigFromString( + memory_optimization_config_file_path, recompute_probe_config) + .IsOK()); } Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; @@ -40,7 +43,8 @@ class MemoryOptimizer : public GraphTransformer { bool ShouldOnlyApplyOnce() const override { return true; } private: - Status ParseOptimizationConfigFromString(const std::string& memory_optimizer_config, const std::string& recompute_probe_config); + Status ParseOptimizationConfigFromString(const std::string& memory_optimizer_config_file_path, + const std::string& recompute_probe_config); /** * @brief Apply graph modifications based on user configs. @@ -48,8 +52,6 @@ class MemoryOptimizer : public GraphTransformer { * @param graph Graph to iterate and modify. * @param node_index_to_its_order_in_topological_sort_map The mapping of node index to its order in topological sort. * Used to re-order the collected subgraph nodes. - * @param candidate_output_args_map A map from node to its candidate activations, which are consumed by both fw and - * bw ops. * @param logger Logger. * @param boundary_op_order_in_topological_sort index of the boundary op between fw and bw. * @param subgraph_stores A store to maintain all found subgraphs. @@ -60,8 +62,6 @@ class MemoryOptimizer : public GraphTransformer { bool ModifyGraph(Graph& graph, const InlinedHashMap& node_index_to_its_order_in_topological_sort_map, - const InlinedHashMap>& - candidate_output_args_map, const logging::Logger& logger, ptrdiff_t boundary_op_order_in_topological_sort, Node* node, @@ -103,7 +103,7 @@ class MemoryOptimizer : public GraphTransformer { // User-enabled map of the subgraph string representation to the alleviation type. InlinedHashMap pattern_subgraph_to_user_optimizer_config_map_; - std::string optimizer_config_; + std::string optimizer_config_file_path_; optimizer::memory_optimizer::ProbeConfig recompute_probe_config_; }; diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc index be2f1387fb66e..088fd345135db 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc @@ -13,13 +13,15 @@ #include "orttraining/core/optimizer/memory_optimizer/recompute_analysis.h" #include "core/common/string_utils.h" #include "core/framework/data_types.h" +#include "core/graph/graph_utils.h" #include "core/optimizer/utils.h" namespace onnxruntime::optimizer::memory_optimizer { namespace { -constexpr int32_t MAXIMUM_RECOMPUTE_NODE_COUNT = 50; +// We don't usually need maximum recompute node count from our existing use cases, so loose the constraints here. +constexpr size_t MAXIMUM_RECOMPUTE_NODE_COUNT = std::numeric_limits::max(); static size_t GetElementSize(const ONNX_NAMESPACE::DataType& tensor_type) { const ONNX_NAMESPACE::TypeProto& type_proto = ONNX_NAMESPACE::Utils::DataTypeUtils::ToTypeProto(tensor_type); @@ -251,6 +253,12 @@ const InlinedHashMap& GetAllowedRecompu {13, {}}, }, }, + { + utils::GetFullQualifiedOpName("MemcpyFromHost", kOnnxDomain), + { + {1, {0}}, // Ignore CPU input. + }, + }, { utils::GetFullQualifiedOpName("Mul", kOnnxDomain), { @@ -282,6 +290,13 @@ const InlinedHashMap& GetAllowedRecompu {1, {1, 2}}, // ignore the indices and unflatten_dims }, }, + { + // Be noted, NOT all PythonOp will be allowed to recompute, there will be further check. + utils::GetFullQualifiedOpName("PythonOp", kMSDomain), + { + {1, {}}, + }, + }, { utils::GetFullQualifiedOpName("Range", kOnnxDomain), { @@ -467,6 +482,24 @@ bool IsRecomputable(const Node& node, ProbeLevel probe_level) { if (it == op_table.end()) { return false; } + + // If node is PythonOp, we check whether it is a per-defined deterministic op. + if (node.OpType() == "PythonOp") { + const auto* func_name_attr = graph_utils::GetNodeAttribute(node, "func_name"); + if (func_name_attr != nullptr) { + static const std::set deterministic_python_ops = { + "flash_attn.bert_padding.IndexFirstAxis", + "flash_attn.bert_padding.IndexPutFirstAxis", + "flash_attn.flash_attn_interface.FlashAttnFunc", + "flash_attn.flash_attn_interface.FlashAttnVarlenFunc", + "orttraining_test_ortmodule_api.test_layerwise_recompute_determinstic..DropoutFunction", + }; + return deterministic_python_ops.find(func_name_attr->s()) != deterministic_python_ops.end(); + } + + return false; + } + return it->second.count(node.SinceVersion()); } @@ -483,8 +516,7 @@ const InlinedVector& GetIgnorableInputIndices(const Node& node, ProbeLevel * * @param entry_node The entry node to start the subgraph matching (bottom-up), usually the last node of found subgraphs. * @param probe_config The probe config to control recomputable subgraph detecting. - * @param node_output_index_candidates Candidate output indices of "node", which are consumed by both fw and bw ops. - * @param fw_op_output_arg_used_map The activation usage (in fw and bw) mapping. + * @param candidate_output_args_map Candidate node to output map. * @param node_index_to_its_order_in_topological_sort_map The mapping of node index to its order in topological sort. * Used to re-order the collected subgraph nodes. * @param nodes_in_topological_order Collected vector of nodes of found subgraph, in the order of the topological @@ -500,8 +532,8 @@ const InlinedVector& GetIgnorableInputIndices(const Node& node, ProbeLevel */ Status SelectRecomputeSubgraph(const Node& entry_node, const ProbeConfig& probe_config, - const InlinedVector& node_output_index_candidates, - const ActivationUsedMap& fw_op_output_arg_used_map, + const InlinedHashMap>& + candidate_output_args_map, const InlinedHashMap& node_index_to_its_order_in_topological_sort_map, const logging::Logger& logger, @@ -510,6 +542,7 @@ Status SelectRecomputeSubgraph(const Node& entry_node, bool& can_compromise_stashed_activation, float& save_ratio) { const ProbeLevel probe_level = probe_config.probe_level; + const InlinedVector& node_output_index_candidates = candidate_output_args_map.at(&entry_node); can_compromise_stashed_activation = false; @@ -575,7 +608,7 @@ Status SelectRecomputeSubgraph(const Node& entry_node, } } else { if (!is_recomputable) { - if (fw_op_output_arg_used_map.at(cur_output_arg_name).second) { + if (candidate_output_args_map.count(curr_node) > 0) { MO_LOG_DEBUG_INFO(logger, "Node " + curr_node->Name() + "(" + curr_node->OpType() + ") is **NOT** in recompute op list, but its output [" + cur_output_arg_name + @@ -592,7 +625,7 @@ Status SelectRecomputeSubgraph(const Node& entry_node, } } - if (fw_op_output_arg_used_map.at(cur_output_arg_name).second) { + if (candidate_output_args_map.count(curr_node) > 0) { MO_LOG_DEBUG_INFO(logger, "Node " + curr_node->Name() + "(" + curr_node->OpType() + ") " + "is in recompute op list, while its output [" + cur_output_arg_name + "] is used in backward, we don't need trace bottom-up further. Entry node: " + @@ -643,9 +676,10 @@ Status SelectRecomputeSubgraph(const Node& entry_node, bool all_constant_dim = true; int64_t num_elem = 1; for (int k = 0, dim_size = output_shape->dim_size(); k < dim_size; ++k) { - if (!output_shape->dim(k).has_dim_value()) { - all_constant_dim = false; + if (output_shape->dim(k).has_dim_value()) { num_elem *= output_shape->dim(k).dim_value(); + } else { + all_constant_dim = false; } } if (all_constant_dim && num_elem < 1 * 1024 * 1024) { @@ -744,7 +778,6 @@ Status ParseProbeConfigFromString(std::string_view recompute_probe_config, Probe std::unique_ptr CheckNodeForRecompute(const GraphViewer& graph_viewer, const Node& node, const ProbeConfig& probe_config, - const ActivationUsedMap& fw_op_output_arg_used_map, const InlinedHashMap& node_index_to_its_order_in_topological_sort_map, const InlinedHashMap>& @@ -786,8 +819,7 @@ std::unique_ptr CheckNodeForRecompute(const GraphViewer& grap float save_ratio = 1.f; ORT_ENFORCE(SelectRecomputeSubgraph(node, probe_config, - candidate_output_args_map.at(&node), - fw_op_output_arg_used_map, + candidate_output_args_map, node_index_to_its_order_in_topological_sort_map, logger, nodes_in_topological_order, diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.h b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.h index ac1021f5eb83b..5aa05b0f02e0f 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.h +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.h @@ -142,11 +142,9 @@ class NodeRecomputePlan : public NodeOptimizationPlanBase { * @param graph_viewer The graph viewer to get node information. * @param node The entry node to start the subgraph matching (bottom-up), usually the last node of found subgraphs. * @param probe_config The config for subgraph detecting. - * @param fw_op_output_arg_used_map The activation usage (in fw and bw) mapping. * @param node_index_to_its_order_in_topological_sort_map The mapping of node index to its order in topological sort. * Used to re-order the collected subgraph nodes. - * @param candidate_output_args_map A map from node to its candidate activations, which are consumed by both fw and - * bw ops. + * @param candidate_output_args_map A map from node to its candidate activations. * @param layer_boundary_ln_nodes A set of LayerNormalization nodes, which are used as the boundary for subgraph. * @param subgraph_stores A store to maintain all found subgraphs. * @param logger Logger. @@ -159,7 +157,6 @@ class NodeRecomputePlan : public NodeOptimizationPlanBase { std::unique_ptr CheckNodeForRecompute(const GraphViewer& graph_viewer, const Node& node, const ProbeConfig& probe_config, - const ActivationUsedMap& fw_op_output_arg_used_map, const InlinedHashMap& node_index_to_its_order_in_topological_sort_map, const InlinedHashMap>& diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc index 35ecf1159d321..a4fbacc8a1f4c 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc @@ -37,6 +37,24 @@ bool IsSoftmaxNode(const Node& node) { return softmax_ops.find(node.OpType()) != softmax_ops.end(); } +bool IsAttentionOp(const Node& node) { + if (node.OpType() != "PythonOp") { + return false; + } + + // Check the func_name attribute of the PythonOp node. + const auto* func_name_attr = graph_utils::GetNodeAttribute(node, "func_name"); + if (func_name_attr == nullptr) { + return false; + } + + static const std::set attn_op_names = { + "flash_attn.flash_attn_interface.FlashAttnVarlenFunc", + "flash_attn.flash_attn_interface.FlashAttnFunc", + }; + return attn_op_names.find(func_name_attr->s()) != attn_op_names.end(); +} + std::tuple IsResidualNodeArg(const GraphViewer& graph_viewer, const NodeArg* node_arg) { auto consumers = graph_viewer.GetConsumerNodes(node_arg->Name()); if (2 > consumers.size()) { @@ -104,8 +122,8 @@ void FindLayerBoundaryLayerNormNodes( InlinedVector& layer_boundary_ln_nodes) { // Loop all nodes to find LayerNormalization nodes. // For each LayerNormalization node, keep checking its output nodes, - // until find a node that is Softmax or BiasSoftmax or another LayerNormalization. - // If the found node is Softmax or BiasSoftmax, the LayerNormalization node as ATTENTION. + // until find a node that is Softmax or Attention or another LayerNormalization. + // If the found node is Softmax or Attention, the LayerNormalization node as ATTENTION. // If the found node is another LayerNormalization, the LayerNormalization node as MLP. layer_boundary_ln_nodes.clear(); @@ -159,7 +177,7 @@ void FindLayerBoundaryLayerNormNodes( } visited_nodes.insert(next_node); - if (IsSoftmaxNode(*next_node)) { + if (IsSoftmaxNode(*next_node) || IsAttentionOp(*next_node)) { MO_LOG_DEBUG_INFO(logger, "Found layer boundary node " + node.Name() + " with its input arg: " + input_arg->Name()); layer_boundary_ln_nodes.push_back(&node); diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc index 3f91dc1065d32..5ea60102f3ef8 100644 --- a/orttraining/orttraining/python/orttraining_pybind_state.cc +++ b/orttraining/orttraining/python/orttraining_pybind_state.cc @@ -435,15 +435,17 @@ void addObjectMethodsForTraining(py::module& m) { throw std::runtime_error("Error in backward pass execution: " + status.ErrorMessage()); } }) - .def("get_serialized_ortmodule_memory_stat", // for memory optimization - [](TrainingAgent* agent, // agent - const std::string& memory_optimization_config, // user config string - const std::string& recompute_probe_level // user config string for probe level + .def("get_serialized_ortmodule_memory_stat", // for memory optimization + [](TrainingAgent* agent, // agent + const std::string& memory_optimization_config_file_path, // user config file path + const std::string& recompute_probe_level, // user config string for probe level + const bool return_opportunity_table // return detailed opportunity_table or not. ) -> std::tuple>> { std::map> cluster_id_combinations_to_saved_symbolic_byte_map; std::string opportunity_table = - agent->GetSerializedORTModuleMemoryStat(memory_optimization_config, + agent->GetSerializedORTModuleMemoryStat(memory_optimization_config_file_path, recompute_probe_level, + return_opportunity_table, cluster_id_combinations_to_saved_symbolic_byte_map); return std::tuple>>( opportunity_table, cluster_id_combinations_to_saved_symbolic_byte_map); diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py index 8a890625003a8..1efc3a23eef34 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py @@ -162,119 +162,6 @@ def _export_pt_1_10(g, n, *args, **kwargs): if try_export is not None: return try_export - # Fall back to common exporter if not handled by high priority exporter. - - cconv = n.cconv() - - input_tensor_types = [] - input_tensor_ranks = [] - - input_bool_scalars = [] - input_bool_scalar_positions = [] - - input_int_scalars = [] - input_int_scalar_positions = [] - - input_float_scalars = [] - input_float_scalar_positions = [] - - input_bool_tuples = [] - input_bool_tuple_positions = [] - input_bool_tuple_begins = [] - - input_int_tuples = [] - input_int_tuple_positions = [] - input_int_tuple_begins = [] - - input_float_tuples = [] - input_float_tuple_positions = [] - input_float_tuple_begins = [] - - input_pointer_scalars = [] - input_pointer_scalar_positions = [] - - tensor_args = [] - debug_comment = "" - # Encode inputs to torch.autograd.Function. - for i, arg, call_type in zip(range(len(args)), args, cconv): - if call_type == "d": - # Got a tensor variable. - tensor_args.append(arg) - scalar_type = pytorch_type_to_onnx_dtype(arg.type().scalarType()) - input_tensor_types.append(scalar_type) - input_tensor_ranks.append(arg.type().dim()) - continue - - if call_type != "c": - raise wrap_exception( - ORTModuleONNXModelException, - Exception(f"Unknown calling convention found: {i}. Only 'd' and 'c' are supported"), - ) - - # Got a non-tensor variable. - # Non-tensor can't have gradient. - if isinstance(arg, float): - # A float. - input_float_scalar_positions.append(i) - input_float_scalars.append(arg) - continue - # bool check MUST be before int check since bool is a subclass of int - elif isinstance(arg, bool): - # A bool. - input_bool_scalar_positions.append(i) - input_bool_scalars.append(int(arg)) - continue - elif isinstance(arg, int): - # A int. - input_int_scalar_positions.append(i) - input_int_scalars.append(arg) - continue - - is_bool_tuple = False - is_int_tuple = False - is_float_tuple = False - if isinstance(arg, tuple) and len(arg) > 0: - # bool check MUST be before int check since bool is a subclass of int. - is_bool_tuple = all(isinstance(ele, bool) for ele in arg) - is_int_tuple = not is_bool_tuple and all(isinstance(ele, int) for ele in arg) - is_float_tuple = not is_bool_tuple and not is_int_tuple and all(isinstance(ele, float) for ele in arg) - - # Only support tuple of bool, int or float, for other types, handle it as a pointer. - if is_bool_tuple: - # A tuple of bool. - input_bool_tuple_positions.append(i) - input_bool_tuple_begins.append(len(input_bool_tuples)) - input_bool_tuples.extend([int(ele) for ele in arg]) - continue - elif is_int_tuple: - # A tuple of ints. - input_int_tuple_positions.append(i) - input_int_tuple_begins.append(len(input_int_tuples)) - input_int_tuples.extend(list(arg)) - continue - elif is_float_tuple: - # A tuple of floats. - input_float_tuple_positions.append(i) - input_float_tuple_begins.append(len(input_float_tuples)) - input_float_tuples.extend(list(arg)) - continue - - from onnxruntime.training.utils.hooks._statistics_subscriber import _InspectActivation - - is_inspect_activation = func_full_qual_name == get_fully_qualified_class_name(_InspectActivation) - if is_inspect_activation and isinstance(arg, str): - # _InspectActivation is a special case where the first argument is a string - # that is used to determine the activation name to be inspected. - debug_comment += arg - - # All other inputs are accessed via "pointers". - input_pointer_scalar_positions.append(i) - input_pointer_scalars.append(id(arg)) - - # For pointer (for example, ProcessGroup passed to PythonOp) needed for PythonOp execution, - # we append it into a global store to hold a reference (in case it is released after module exported). - register_miscellaneous_const_input(arg) - output_tensor_types = [] output_tensor_ranks = [] for arg in n.outputs(): @@ -282,58 +169,187 @@ def _export_pt_1_10(g, n, *args, **kwargs): scalar_type = pytorch_type_to_onnx_dtype(arg.type().scalarType()) output_tensor_types.append(scalar_type) output_tensor_ranks.append(arg.type().dim()) + # Fall back to common exporter if not handled by high priority exporter. + return _default_export( + g, + func_full_qual_name, + func_class, + n.cconv(), + n.outputsSize(), + output_tensor_types, + output_tensor_ranks, + *args, + **kwargs, + ) - attrs = { - "func_name_s": func_full_qual_name, - "input_convention_s": cconv, - "outputs": n.outputsSize(), - "input_tensor_types_i": input_tensor_types, - "input_tensor_ranks_i": input_tensor_ranks, - "output_tensor_types_i": output_tensor_types, - "output_tensor_ranks_i": output_tensor_ranks, - "training_mode_i": 1 if _get_training_mode() else 0, - "comment_s": debug_comment, - } - - if len(input_bool_scalars) > 0: - attrs["input_bool_scalars_i"] = input_bool_scalars - attrs["input_bool_scalar_positions_i"] = input_bool_scalar_positions - if len(input_int_scalars) > 0: - attrs["input_int_scalars_i"] = input_int_scalars - attrs["input_int_scalar_positions_i"] = input_int_scalar_positions - if len(input_float_scalars) > 0: - attrs["input_float_scalars_f"] = input_float_scalars - attrs["input_float_scalar_positions_i"] = input_float_scalar_positions - if len(input_bool_tuples) > 0: - attrs["input_bool_tuples_i"] = input_bool_tuples - attrs["input_bool_tuple_positions_i"] = input_bool_tuple_positions - attrs["input_bool_tuple_begins_i"] = input_bool_tuple_begins - if len(input_int_tuples) > 0: - attrs["input_int_tuples_i"] = input_int_tuples - attrs["input_int_tuple_positions_i"] = input_int_tuple_positions - attrs["input_int_tuple_begins_i"] = input_int_tuple_begins - if len(input_float_tuples) > 0: - attrs["input_float_tuples_f"] = input_float_tuples - attrs["input_float_tuple_positions_i"] = input_float_tuple_positions - attrs["input_float_tuple_begins_i"] = input_float_tuple_begins - if len(input_pointer_scalars) > 0: - attrs["input_pointer_scalars_i"] = input_pointer_scalars - attrs["input_pointer_scalar_positions_i"] = input_pointer_scalar_positions - - returned_args = g.op("com.microsoft::PythonOp", *tensor_args, **attrs) - - # Register function with class names. - register_torch_autograd_function(func_full_qual_name, func_class) - - register_custom_function_schema_supplementary(func_class) - - return returned_args except Exception as e: sys.stdout.flush() sys.stderr.flush() raise wrap_exception(ORTModuleONNXModelException, e) # noqa: B904 +def _default_export( + g, func_full_qual_name, func_class, cconv, output_size, output_tensor_types, output_tensor_ranks, *args, **kwargs +): + + input_tensor_types = [] + input_tensor_ranks = [] + + input_bool_scalars = [] + input_bool_scalar_positions = [] + + input_int_scalars = [] + input_int_scalar_positions = [] + + input_float_scalars = [] + input_float_scalar_positions = [] + + input_bool_tuples = [] + input_bool_tuple_positions = [] + input_bool_tuple_begins = [] + + input_int_tuples = [] + input_int_tuple_positions = [] + input_int_tuple_begins = [] + + input_float_tuples = [] + input_float_tuple_positions = [] + input_float_tuple_begins = [] + + input_pointer_scalars = [] + input_pointer_scalar_positions = [] + + tensor_args = [] + debug_comment = "" + assert len(args) == len(cconv), "Number of arguments does not match calling convention" + + # Encode inputs to torch.autograd.Function. + for i, arg, call_type in zip(range(len(args)), args, cconv): + if call_type == "d": + # Got a tensor variable. + tensor_args.append(arg) + scalar_type = pytorch_type_to_onnx_dtype(arg.type().scalarType()) + input_tensor_types.append(scalar_type) + input_tensor_ranks.append(arg.type().dim()) + continue + + if call_type != "c": + raise wrap_exception( + ORTModuleONNXModelException, + Exception(f"Unknown calling convention found: {i}. Only 'd' and 'c' are supported"), + ) + + # Got a non-tensor variable. + # Non-tensor can't have gradient. + if isinstance(arg, float): + # A float. + input_float_scalar_positions.append(i) + input_float_scalars.append(arg) + continue + # bool check MUST be before int check since bool is a subclass of int + elif isinstance(arg, bool): + # A bool. + input_bool_scalar_positions.append(i) + input_bool_scalars.append(int(arg)) + continue + elif isinstance(arg, int): + # A int. + input_int_scalar_positions.append(i) + input_int_scalars.append(arg) + continue + + is_bool_tuple = False + is_int_tuple = False + is_float_tuple = False + if isinstance(arg, tuple) and len(arg) > 0: + # bool check MUST be before int check since bool is a subclass of int. + is_bool_tuple = all(isinstance(ele, bool) for ele in arg) + is_int_tuple = not is_bool_tuple and all(isinstance(ele, int) for ele in arg) + is_float_tuple = not is_bool_tuple and not is_int_tuple and all(isinstance(ele, float) for ele in arg) + + # Only support tuple of bool, int or float, for other types, handle it as a pointer. + if is_bool_tuple: + # A tuple of bool. + input_bool_tuple_positions.append(i) + input_bool_tuple_begins.append(len(input_bool_tuples)) + input_bool_tuples.extend([int(ele) for ele in arg]) + continue + elif is_int_tuple: + # A tuple of ints. + input_int_tuple_positions.append(i) + input_int_tuple_begins.append(len(input_int_tuples)) + input_int_tuples.extend(list(arg)) + continue + elif is_float_tuple: + # A tuple of floats. + input_float_tuple_positions.append(i) + input_float_tuple_begins.append(len(input_float_tuples)) + input_float_tuples.extend(list(arg)) + continue + + from onnxruntime.training.utils.hooks._statistics_subscriber import _InspectActivation + + is_inspect_activation = func_full_qual_name == get_fully_qualified_class_name(_InspectActivation) + if is_inspect_activation and isinstance(arg, str): + # _InspectActivation is a special case where the first argument is a string + # that is used to determine the activation name to be inspected. + debug_comment += arg + + # All other inputs are accessed via "pointers". + input_pointer_scalar_positions.append(i) + input_pointer_scalars.append(id(arg)) + + # For pointer (for example, ProcessGroup passed to PythonOp) needed for PythonOp execution, + # we append it into a global store to hold a reference (in case it is released after module exported). + register_miscellaneous_const_input(arg) + + attrs = { + "func_name_s": func_full_qual_name, + "input_convention_s": cconv, + "outputs": output_size, + "input_tensor_types_i": input_tensor_types, + "input_tensor_ranks_i": input_tensor_ranks, + "output_tensor_types_i": output_tensor_types, + "output_tensor_ranks_i": output_tensor_ranks, + "training_mode_i": 1 if _get_training_mode() else 0, + "comment_s": debug_comment, + } + + if len(input_bool_scalars) > 0: + attrs["input_bool_scalars_i"] = input_bool_scalars + attrs["input_bool_scalar_positions_i"] = input_bool_scalar_positions + if len(input_int_scalars) > 0: + attrs["input_int_scalars_i"] = input_int_scalars + attrs["input_int_scalar_positions_i"] = input_int_scalar_positions + if len(input_float_scalars) > 0: + attrs["input_float_scalars_f"] = input_float_scalars + attrs["input_float_scalar_positions_i"] = input_float_scalar_positions + if len(input_bool_tuples) > 0: + attrs["input_bool_tuples_i"] = input_bool_tuples + attrs["input_bool_tuple_positions_i"] = input_bool_tuple_positions + attrs["input_bool_tuple_begins_i"] = input_bool_tuple_begins + if len(input_int_tuples) > 0: + attrs["input_int_tuples_i"] = input_int_tuples + attrs["input_int_tuple_positions_i"] = input_int_tuple_positions + attrs["input_int_tuple_begins_i"] = input_int_tuple_begins + if len(input_float_tuples) > 0: + attrs["input_float_tuples_f"] = input_float_tuples + attrs["input_float_tuple_positions_i"] = input_float_tuple_positions + attrs["input_float_tuple_begins_i"] = input_float_tuple_begins + if len(input_pointer_scalars) > 0: + attrs["input_pointer_scalars_i"] = input_pointer_scalars + attrs["input_pointer_scalar_positions_i"] = input_pointer_scalar_positions + + returned_args = g.op("com.microsoft::PythonOp", *tensor_args, **attrs) + + # Register function with class names. + register_torch_autograd_function(func_full_qual_name, func_class) + + register_custom_function_schema_supplementary(func_class) + + return returned_args + + _export = wrap_custom_export_function(_export_pt_1_10) @@ -488,3 +504,142 @@ def _matmul4bit_export(g, n, *args, **kwargs): ) # flatten to 1D tensor_args = [args[0], quant_weight, absmax] return g.op("com.microsoft::MatMulBnb4", *tensor_args, **attrs) + + +class DetermisticWrapper(torch.autograd.Function): + """ + A wrapper for run autograd function in a deterministic way. This is required for PythonOp that needs + recompute support. + """ + + @staticmethod + def forward(ctx, autograd_function: torch.autograd.Function, cpu_rng_state, device_rng_state, *args): + """For normal forward run, both cpu_rng_state and device_rng_state are None. + For recompute run, cpu_rng_state and device_rng_state are provided, and from the normal forward run results. + + If device_rng_state does not exist (in pure CPU training for example), we still return cpu_rng_state + as device_rng_state to avoid the exporter to handle the case where we return None as forward outputs. + """ + original_cpu_rng_state = None + original_cuda_rng_state = None + + if cpu_rng_state is None: + assert device_rng_state is None, "device_rng_state must be None if cpu_rng_state is None" + cpu_rng_state = torch.get_rng_state() + fwd_devices = list( + {arg.device for arg in args if isinstance(arg, torch.Tensor) and arg.device.type != "cpu"} + ) + if len(fwd_devices) > 0: + assert len(fwd_devices) == 1, "Only support single device for now" + assert fwd_devices[0].type == "cuda", "Only support cuda device for now" + device_rng_state = torch.cuda.get_rng_state() + else: + # Pass CPU RNG state as device RNG state if device RNG state is not provided. + # This is to workaround the tricky case where we return None|Tensor as forward outputs. + device_rng_state = cpu_rng_state + else: + assert device_rng_state is not None, "device_rng_state must be provided if cpu_rng_state is provided" + original_cpu_rng_state = torch.get_rng_state() + torch.set_rng_state(cpu_rng_state) + + if device_rng_state.data_ptr() != cpu_rng_state.data_ptr(): + original_cuda_rng_state = torch.cuda.get_rng_state() + torch.cuda.set_rng_state(device_rng_state) + + outputs = autograd_function.forward(ctx, *args) + + # Append the RNG states to the outputs in the beginning. + updated_outputs = [] + updated_outputs.append(cpu_rng_state) + updated_outputs.append(device_rng_state) + if isinstance(outputs, torch.Tensor): + updated_outputs.append(outputs) + elif isinstance(outputs, tuple): + updated_outputs.extend(outputs) + else: + raise ValueError(f"Unsupported outputs type: {type(outputs)}") + + ctx.autograd_function = autograd_function + + if original_cpu_rng_state is not None: + torch.set_rng_state(original_cpu_rng_state) + + if original_cuda_rng_state is not None: + torch.cuda.set_rng_state(original_cuda_rng_state) + + return tuple(updated_outputs) + + @staticmethod + def backward(ctx, *grad_outputs): + # Skip the first two RNG states grad, which should be None. + outputs = ctx.autograd_function.backward(ctx, *grad_outputs[2:]) + + updated_outputs = [None, None, None] + if isinstance(outputs, torch.Tensor): + updated_outputs.append(outputs) + elif isinstance(outputs, tuple): + updated_outputs.extend(outputs) + else: + raise ValueError(f"Unsupported outputs type: {type(outputs)}") + + return tuple(updated_outputs) + + +@register_high_priority_handler("flash_attn.bert_padding.IndexFirstAxis") +@register_high_priority_handler("flash_attn.bert_padding.IndexPutFirstAxis") +@register_high_priority_handler("flash_attn.flash_attn_interface.FlashAttnFunc") +@register_high_priority_handler("flash_attn.flash_attn_interface.FlashAttnVarlenFunc") +@register_high_priority_handler( + "orttraining_test_ortmodule_autograd.test_determistic_pythonop_export..TestFunction" +) +@register_high_priority_handler( + "orttraining_test_ortmodule_api.test_layerwise_recompute_determinstic..DropoutFunction" +) +def _determinstic_exporter(g, n, *args, **kwargs): + """ + Export torch.autograd.Function in ORT PythonOp with deterministic wrapper. This is required for PythonOp that needs + recompute support. + + Here, we will insert 3 inputs before the actual inputs: + 1. The first input is a constant pointer, which is not a tensor. It points to the real autograd function to execute. + 2. The second input is a tensor, which is the CPU RNG state (during export, we assign None; in memory optimizer, + the recomputed PythonOp will take the CPU RNG state output from normal forward PythonOp node). + 3. The third input is a tensor, which is the device RNG state (during export, we assign None; in memory optimizer, + the recomputed PythonOp will take the CUDA RNG state output from normal forward PythonOp node). + + """ + # The first input is a constant pointer, which is not a tensor. The second input rng_state is a tensor. + cconv = "ccc" + n.cconv() + func_class = n.pyobj().__self__ + updated_args = [func_class, None, None] + if isinstance(args, (tuple, list)): + updated_args.extend(args) + else: + updated_args.append(args) + + output_tensor_types = [] + output_tensor_ranks = [] + for arg in n.outputs(): + # Type of tensor's elements. + scalar_type = pytorch_type_to_onnx_dtype(arg.type().scalarType()) + output_tensor_types.append(scalar_type) + output_tensor_ranks.append(arg.type().dim()) + + for _ in range(2): + output_tensor_types.insert(0, pytorch_type_to_onnx_dtype(torch.uint8)) + output_tensor_ranks.insert(0, 1) + + func_full_qual_name = get_fully_qualified_class_name(func_class) + default_op_outputs = _default_export( + g, + func_full_qual_name, + DetermisticWrapper, + cconv, + n.outputsSize() + 2, + output_tensor_types, + output_tensor_ranks, + *updated_args, + **kwargs, + ) + + return default_op_outputs[2:] diff --git a/orttraining/orttraining/python/training/ortmodule/_execution_agent.py b/orttraining/orttraining/python/training/ortmodule/_execution_agent.py index 7a89aadee9950..84d7bf6410966 100644 --- a/orttraining/orttraining/python/training/ortmodule/_execution_agent.py +++ b/orttraining/orttraining/python/training/ortmodule/_execution_agent.py @@ -165,11 +165,11 @@ def run_backward(self, feeds, fetches, state): self._training_agent.run_backward(feeds, fetches, state) def get_serialized_ortmodule_memory_stat( - self, memory_optimization_config: str, recompute_probe_level: str + self, memory_optimization_config_file_path: str, recompute_probe_level: str, return_opportunity_table: bool ) -> Tuple[str, dict]: """ Get serialized memory stats for OrtModule. """ return self._training_agent.get_serialized_ortmodule_memory_stat( - memory_optimization_config, recompute_probe_level + memory_optimization_config_file_path, recompute_probe_level, return_opportunity_table ) diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index ad726a82ef6ff..8e383a5545e42 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -870,24 +870,14 @@ def _add_record(tbl, columns): ], ) - if self._runtime_options.memory_optimization_level == _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE: - opt_config_to_display = "ALL_RECOMPUTE_FOR_EACH_LAYER" - elif ( - self._runtime_options.memory_optimization_level - == _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE_WITH_COMPROMISE - ): - opt_config_to_display = "ALL_RECOMPUTE_FOR_EACH_LAYER_WITH_COMPROMISE" - else: - opt_config_to_display = self._runtime_options.memory_optimizer_config - mem_infos = "" if self._runtime_options.memory_optimizer_is_enabled(): mem_infos += ( f"Memory Optimization Level: [{_MemoryOptimizationLevel.to_string(self._runtime_options.memory_optimization_level)}], " - f"Optimization Config: [{opt_config_to_display}]" + f"Optimization Config: [{self._runtime_options.memory_optimizer_config_file_path}]" ) else: - mem_infos = "Enable with env ORTMODULE_MEMORY_OPT_LEVEL=1/2 or ORTMODULE_MEMORY_OPT_CONFIG=,,..." + mem_infos = "Enable with env ORTMODULE_MEMORY_OPT_LEVEL=1/2 or ORTMODULE_MEMORY_OPT_CONFIG=" mem_row = _add_record( tbl, @@ -900,7 +890,7 @@ def _add_record(tbl, columns): if self._runtime_inspector.memory_ob.is_enabled() and self._debug_options.logging.log_level < LogLevel.WARNING: mem_notes, mem_tbl = self._runtime_inspector.memory_ob.display_memory_optimization_plans( - self._runtime_options.memory_optimizer_config, + self._runtime_options.memory_optimizer_config_file_path, details=True, ) if mem_tbl is not None: diff --git a/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py b/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py index fb1a26661bc46..773c506d28ef4 100644 --- a/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py +++ b/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py @@ -3,6 +3,8 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- +import json +import tempfile from enum import IntEnum from logging import Logger from typing import Dict, List, Optional, Tuple, Union @@ -91,7 +93,6 @@ def __init__(self, m: torch.nn.Module, logger: Logger, training: bool): self._is_enabled = True # Memory optimization related. - self.memory_optimization_opportunity_table_str = None self.cluster_id_combination_to_saving_symbolics_map: Dict[str, MemoryOptimizationSummary] = {} ## The value is a list of symbolic dim values parsed from the first batch. self.symbolic_dim_name_to_value_map: Dict = {} @@ -119,6 +120,8 @@ def __init__(self, m: torch.nn.Module, logger: Logger, training: bool): self._m = m + self._json_file_for_layerwise_recompute = None + def is_enabled(self) -> bool: """Check if memory inspector is enabled.""" return self._is_enabled @@ -149,20 +152,22 @@ def find_memory_optimization_opportunity(self, execution_agent: TrainingAgent, r """ recompute_probe_config = runtime_options.recompute_probe_config - memory_optimizer_config = runtime_options.memory_optimizer_config + memory_optimizer_config_file_path = runtime_options.memory_optimizer_config_file_path # If the memory optimization level is aggressive, we will first collect all - # recompute subgraph by passing empty memory_optimizer_config to get_serialized_ortmodule_memory_stat. + # recompute subgraph by passing empty memory_optimizer_config_file_path to get_serialized_ortmodule_memory_stat. if runtime_options.memory_optimization_level in [ _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE, _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE_WITH_COMPROMISE, ]: - memory_optimizer_config = "" + memory_optimizer_config_file_path = "" ( - self.memory_optimization_opportunity_table_str, + _, memory_optimization_saving_symbolics, - ) = execution_agent.get_serialized_ortmodule_memory_stat(memory_optimizer_config, recompute_probe_config) + ) = execution_agent.get_serialized_ortmodule_memory_stat( + memory_optimizer_config_file_path, recompute_probe_config, False + ) cluster_id_to_saving_symbol_map: Dict[str, MemoryOptimizationSummary] = {} for cluster_id, memory_saving_stat in memory_optimization_saving_symbolics.items(): @@ -191,30 +196,43 @@ def find_memory_optimization_opportunity(self, execution_agent: TrainingAgent, r for cluster_id, values in sorted_list: self.cluster_id_combination_to_saving_symbolics_map[cluster_id] = values - # For aggressive memory optimization, we update the memory_optimizer_config using all. - if runtime_options.memory_optimization_level > 0: - recompute_configs = [] - for cluster_id in self.cluster_id_combination_to_saving_symbolics_map: - config_values = cluster_id.split(":") - opt_type = int(config_values[1]) - if ( - runtime_options.memory_optimization_level - == _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE - and opt_type == _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE - ): - recompute_configs.append(cluster_id) - elif ( - runtime_options.memory_optimization_level - == _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE_WITH_COMPROMISE - and opt_type - in [ - _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE, - _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE_WITH_COMPROMISE, - ] - ): - recompute_configs.append(cluster_id) + # For aggressive memory optimization, we update the memory_optimizer_config_file_path using all. + if runtime_options.memory_optimization_level in [ + _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE, + _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE_WITH_COMPROMISE, + ]: - runtime_options.memory_optimizer_config = ",".join(recompute_configs) + apply_config = [] + + for cluster_id in self.cluster_id_combination_to_saving_symbolics_map: + plans = cluster_id.split(",") + recompute_configs = [] + for plan in plans: + config_values = plan.split(":") + opt_type = int(config_values[1]) + if ( + runtime_options.memory_optimization_level + == _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE + and opt_type == _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE + ): + recompute_configs.append(plan) + elif ( + runtime_options.memory_optimization_level + == _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE_WITH_COMPROMISE + and opt_type + in [ + _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE, + _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE_WITH_COMPROMISE, + ] + ): + recompute_configs.append(plan) + + apply_config.append(",".join(recompute_configs)) + + self._json_file_for_layerwise_recompute = tempfile.NamedTemporaryFile(mode="w+") + json.dump(apply_config, self._json_file_for_layerwise_recompute) + self._json_file_for_layerwise_recompute.flush() + runtime_options.memory_optimizer_config_file_path = self._json_file_for_layerwise_recompute.name def inspect_memory(self, cur_phase: Phase): """Inspect memory usage and print statistics. @@ -263,7 +281,9 @@ def inspect_memory(self, cur_phase: Phase): def _increase_step(self): self._current_step += 1 - def display_memory_optimization_plans(self, memory_optimizer_config, details=False) -> Tuple[List[str], PTable]: + def display_memory_optimization_plans( + self, memory_optimizer_config_file_path, details=False + ) -> Tuple[List[str], PTable]: mem_plan_count = len(self.cluster_id_combination_to_saving_symbolics_map) if mem_plan_count > 0: @@ -288,8 +308,11 @@ def _get_user_config_without_freq(configs: str): return configs_with_out_freq user_configs_with_out_freq = [] - if memory_optimizer_config: - user_configs_with_out_freq = _get_user_config_without_freq(memory_optimizer_config) + if memory_optimizer_config_file_path: + with open(memory_optimizer_config_file_path) as conf: + data = json.load(conf) + for user_specified_plan in data: + user_configs_with_out_freq.extend(_get_user_config_without_freq(user_specified_plan)) for ( cluster_id, @@ -328,7 +351,7 @@ def _get_user_config_without_freq(configs: str): saving_recommendation = ( "Or use comma as a delimiter to selectively enable multiple memory optimization plans:\n" ) - saving_recommendation += " export ORTMODULE_MEMORY_OPT_CONFIG=,,..." + saving_recommendation += " export ORTMODULE_MEMORY_OPT_CONFIG=" notes.append(saving_recommendation) diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py index 476c169fd1409..f35e3f74ba60a 100644 --- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py @@ -466,8 +466,9 @@ def _create_execution_agent(self): del execution_agent # Enable memory optimization if it is enabled in the session options. + session_options.add_session_config_entry( - "optimization.memory_optimizer_config", self._runtime_options.memory_optimizer_config + "optimization.memory_optimizer_config", self._runtime_options.memory_optimizer_config_file_path ) session_options.add_session_config_entry( "optimization.enable_memory_probe_recompute_config", self._runtime_options.recompute_probe_config diff --git a/orttraining/orttraining/python/training/ortmodule/options.py b/orttraining/orttraining/python/training/ortmodule/options.py index eabeacee2530c..9145fb1712e88 100644 --- a/orttraining/orttraining/python/training/ortmodule/options.py +++ b/orttraining/orttraining/python/training/ortmodule/options.py @@ -282,8 +282,10 @@ def __init__(self, logger: Logger): # Configuration for memory optimization. self.memory_optimization_level = ( _MemoryOptimizationLevel.USER_SPECIFIED - ) # 0: use `memory_optimizer_config`; 1: aggressive optimization, enable all recomputable subgraphs. - self.memory_optimizer_config = "" # This is an advanced config, please refer to onnxruntime docs for details. + ) # 0: use `memory_optimizer_config_file_path`; 1: aggressive optimization, enable all recomputable subgraphs. + self.memory_optimizer_config_file_path = ( + "" # This is an advanced config, please refer to onnxruntime docs for details. + ) # 1 is the op set level; 0 indicates whether consider the Transformer-based model's layer boundary when # detecting recompute subgraphs. self.recompute_probe_config = "1:0" @@ -351,8 +353,9 @@ def _override_from_env_vars(self): # Configuration for memory optimization. self.memory_optimization_level = int(os.getenv("ORTMODULE_MEMORY_OPT_LEVEL", self.memory_optimization_level)) - user_given_memory_optimizer_config = os.getenv("ORTMODULE_MEMORY_OPT_CONFIG", self.memory_optimizer_config) - self.memory_optimizer_config = ",".join([c for c in user_given_memory_optimizer_config.split(",") if c]) + self.memory_optimizer_config_file_path = os.getenv( + "ORTMODULE_MEMORY_OPT_CONFIG", self.memory_optimizer_config_file_path + ) if self.memory_optimization_level in [ _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE, _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE_WITH_COMPROMISE, @@ -429,7 +432,7 @@ def _override_from_env_vars(self): def memory_optimizer_is_enabled(self) -> bool: """Check whether memory optimizer is enabled.""" if self.memory_optimization_level == _MemoryOptimizationLevel.USER_SPECIFIED: - return len(self.memory_optimizer_config) > 0 + return len(self.memory_optimizer_config_file_path) > 0 elif self.memory_optimization_level in [ _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE, _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE_WITH_COMPROMISE, diff --git a/orttraining/orttraining/test/optimizer/memory_optimizer_test.cc b/orttraining/orttraining/test/optimizer/memory_optimizer_test.cc index c34d0be5657e6..a0959c0df8868 100644 --- a/orttraining/orttraining/test/optimizer/memory_optimizer_test.cc +++ b/orttraining/orttraining/test/optimizer/memory_optimizer_test.cc @@ -18,6 +18,7 @@ #include "core/graph/graph_utils.h" #include "core/graph/graph_viewer.h" #include "core/graph/model.h" +#include "test/optimizer/graph_transform_test_builder.h" #include "core/optimizer/utils.h" #include "core/platform/env.h" #include "core/session/inference_session.h" @@ -26,6 +27,7 @@ #include "test/capturing_sink.h" #include "test/test_environment.h" #include "test/util/include/asserts.h" +#include "test/util/include/temp_dir.h" #include "orttraining/core/optimizer/memory_optimizer/common.h" #include "orttraining/core/optimizer/memory_optimizer/memory_optimizer.h" #include "orttraining/core/optimizer/memory_optimizer/memory_insight.h" @@ -63,9 +65,17 @@ TEST(MemoryOptimizerTests, GeluRecompute) { onnxruntime::GraphTransformerManager graph_transformation_mgr{1}; const std::string alleviation_config("Gelu+:1:-1"); + onnxruntime::test::TemporaryDirectory tmp_dir{ORT_TSTR("memory_optimizer_test_tmp_dir")}; + PathString config_path{ConcatPathComponent(tmp_dir.Path(), + ORT_TSTR("gelurecompute.json"))}; + const std::string config_path_str = ToUTF8String(config_path); + std::ofstream outfile(config_path_str); + outfile << "[\"" << alleviation_config << "\"]" << std::endl; + outfile.close(); + const std::string probe_config("1:0"); ASSERT_STATUS_OK(graph_transformation_mgr.Register( - std::make_unique(alleviation_config, probe_config), TransformerLevel::Level3)); + std::make_unique(config_path_str, probe_config), TransformerLevel::Level3)); ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level3, *logger)); @@ -105,9 +115,17 @@ TEST(MemoryOptimizerTests, TileRecompute) { onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; const std::string alleviation_config("Expand+Tile+:1:-1"); + onnxruntime::test::TemporaryDirectory tmp_dir{ORT_TSTR("memory_optimizer_test_tmp_dir")}; + PathString config_path{ConcatPathComponent(tmp_dir.Path(), + ORT_TSTR("tilerecompute.json"))}; + const std::string config_path_str = ToUTF8String(config_path); + std::ofstream outfile(config_path_str); + outfile << "[\"" << alleviation_config << "\"]" << std::endl; + outfile.close(); + const std::string probe_config("1:0"); ASSERT_STATUS_OK(graph_transformation_mgr.Register( - std::make_unique(alleviation_config, probe_config), TransformerLevel::Level3)); + std::make_unique(config_path_str, probe_config), TransformerLevel::Level3)); ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level3, *logger)); @@ -154,40 +172,67 @@ TEST(MemoryOptimizerTests, TransformerPerLayerRecompute) { // Find all optimizable subgraphs GraphViewer graph_viewer(graph); - const std::string initial_mem_config(""); + onnxruntime::test::TemporaryDirectory tmp_dir{ORT_TSTR("memory_optimizer_test_tmp_dir")}; + PathString config_path1{ConcatPathComponent(tmp_dir.Path(), + ORT_TSTR("layerrecompute_initial.json"))}; + const std::string config_path1_str = ToUTF8String(config_path1); + std::ofstream conf_stream(config_path1_str); + conf_stream << "[]" << std::endl; // Empty config. + conf_stream.close(); + const std::string probe_config("1:1"); std::map> cluster_id_combinations_to_saved_symbolic_byte_map; std::string record_str = optimizer::memory_optimizer::GetSerializedORTModuleMemoryStat(graph_viewer, - initial_mem_config, + config_path1_str, probe_config, + true, /*enable this for test converage*/ *logger, cluster_id_combinations_to_saved_symbolic_byte_map, nullptr, nullptr); InlinedHashMap cluster_id_to_config_map; + PathString config_path2{ConcatPathComponent(tmp_dir.Path(), + ORT_TSTR("layerrecompute_2.json"))}; + const std::string config_path2_str = ToUTF8String(config_path2); + std::ofstream conf_stream2(config_path2_str); + conf_stream2 << "[" << std::endl; // Empty config. + int index = 0; for (auto it = cluster_id_combinations_to_saved_symbolic_byte_map.begin(); it != cluster_id_combinations_to_saved_symbolic_byte_map.end(); ++it) { std::string cluster_id = it->first; - ORT_ENFORCE(optimizer::memory_optimizer::ParseOptimizationConfigFromString(cluster_id, cluster_id_to_config_map) - .IsOK()); + conf_stream2 << (index == 0 ? "" : ",") << "\"" << it->first << "\""; + index += 1; } + conf_stream2 << "]" << std::endl; + conf_stream2.close(); + + ORT_ENFORCE(optimizer::memory_optimizer::ParseOptimizationConfigFromString(config_path2_str, cluster_id_to_config_map) + .IsOK()); std::ostringstream oss; - int index = 0; + index = 0; + oss << "["; for (auto it = cluster_id_to_config_map.begin(); it != cluster_id_to_config_map.end(); ++it) { if (it->second.type == optimizer::memory_optimizer::OptimizationType::Recompute) { - oss << (index == 0 ? "" : ",") << it->first << ":1:-1"; + oss << (index == 0 ? "" : ",") << "\"" << it->first << ":1:-1\""; ++index; } } + oss << "]"; // Apply the transformer GraphTransformerManager graph_transformation_mgr{5}; - const std::string layer_wise_recompute_config(oss.str()); + PathString config_path{ConcatPathComponent(tmp_dir.Path(), + ORT_TSTR("layerrecompute.json"))}; + const std::string config_path_str = ToUTF8String(config_path); + std::ofstream outfile(config_path_str); + outfile << oss.str() << std::endl; + outfile.close(); + ASSERT_STATUS_OK(graph_transformation_mgr.Register( - std::make_unique(layer_wise_recompute_config, probe_config), TransformerLevel::Level3)); + std::make_unique(config_path_str, probe_config), TransformerLevel::Level3)); ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level3, *logger)); diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 9e226958dd924..f35bb47f6b41d 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -6743,3 +6743,185 @@ def forward(self, x): else: if "ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT" in os.environ: del os.environ["ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT"] + + +def test_layerwise_recompute_pythonop_determinstic(): + + original_val = os.environ.get("ORTMODULE_MEMORY_OPT_LEVEL", None) + + class DropoutFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, x): + return torch.nn.functional.dropout(x, p=0.5, training=True) + + @staticmethod + def backward(ctx, grad_output): + return grad_output + + class OneLayer(torch.nn.Module): + def __init__(self, hidden_size, num_attention_heads): + super().__init__() + self.num_attention_heads = num_attention_heads + self.attention_head_size = int(hidden_size / num_attention_heads) + self.all_head_size = num_attention_heads * self.attention_head_size + self.query = nn.Linear(hidden_size, self.all_head_size) + self.key = nn.Linear(hidden_size, self.all_head_size) + self.value = nn.Linear(hidden_size, self.all_head_size) + self.dense = nn.Linear(hidden_size, hidden_size) + self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-05) + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states, attention_mask): + query_layer = self.transpose_for_scores(self.query(hidden_states)) + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + attention_scores = attention_scores + attention_mask + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + attention_probs = DropoutFunction.apply(attention_probs) + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + output = self.dense(context_layer) + output = DropoutFunction.apply(output) + output = self.LayerNorm(output + hidden_states) + return output + + # This toy model is written referring to HuggingFace bert-large-uncased model in run_glue.py: + # https://github.com/huggingface/optimum/blob/72133e595f9a054c3221ec9ea87f42e0bdaa062b/examples/onnxruntime/training/text-classification/run_glue.py + # This is just a simple version of it for convenient testing. + class ToyModel(torch.nn.Module): + def __init__(self, num_hidden_layers, vocab_size, hidden_size, num_attention_heads, pad_token_id, num_labels): + super().__init__() + self.word_embeddings = nn.Embedding(vocab_size, hidden_size, padding_idx=pad_token_id) + self.token_type_embeddings = nn.Embedding(1, hidden_size) + self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-05) + self.layer = nn.ModuleList([OneLayer(hidden_size, num_attention_heads) for _ in range(num_hidden_layers)]) + self.out_proj = nn.Linear(hidden_size, num_labels) + + def forward(self, input_ids, attention_mask, target): + input_shape = input_ids.size() + token_type_ids = torch.zeros(input_shape, dtype=torch.long).to(device) + inputs_embeds = self.word_embeddings(input_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + embeddings = inputs_embeds + token_type_embeddings + embeddings = self.LayerNorm(embeddings) + hidden_states = DropoutFunction.apply(embeddings) + extended_attention_mask = attention_mask[:, None, None, :] + extended_attention_mask = extended_attention_mask.to(dtype=torch.float32) + extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(torch.float32).min + for _, layer_module in enumerate(self.layer): + hidden_states = layer_module(hidden_states, extended_attention_mask) + x = hidden_states[:, 0, :] + x = self.out_proj(x) + loss_fct = torch.nn.CrossEntropyLoss() + return loss_fct(x, target) + + def run_step(model, inputs, mask, target): + loss = model(inputs, mask, target) + loss.backward() + return loss + + # Generate one batch of inputs (shape:[batch_size, max_seq_length]) and masks (shape:[batch_size, max_seq_length]). + # Each input has random length from 1 to max_seq_length*0.8 with values from 2 to vocab_size and padded with 1 at + # [max_seq_length - length:]. Values of masks are 1 at [0:length] and 0 at [length:max_seq_length]. + def generate_inputs(batch_size, max_seq_length, vocab_size): + batched_inputs = [] + batched_masks = [] + for _ in range(batch_size): + # Generate random length from 1 to max_seq_length*0.8, to ensure sparsity > 20% + seq_len = random.randint(1, int(max_seq_length * 0.8)) + + # Generate input values and padding respectively and concatenate them + input_id = torch.randint(2, vocab_size, (seq_len,), dtype=torch.long, device=device) + padding = torch.ones((max_seq_length - seq_len,), dtype=torch.long, device=device) + batched_inputs.append(torch.cat((input_id, padding))) + + # Generate mask values and padding respectively and concatenate them + mask_ones = torch.ones((seq_len,), device=device) + mask_zeros = torch.zeros((max_seq_length - seq_len,), device=device) + batched_masks.append(torch.cat((mask_ones, mask_zeros))) + return torch.stack(batched_inputs), torch.stack(batched_masks) + + num_layers, vocab_size, hidden_size, num_attention_heads = 12, 50265, 768, 12 + batch_size, max_seq_length = 8, 128 + device = "cuda" + pt_model = ToyModel(num_layers, vocab_size, hidden_size, num_attention_heads, 1, 3).to(device) + + os.environ["ORTMODULE_MEMORY_OPT_LEVEL"] = "0" + ort_model1 = ORTModule(copy.deepcopy(pt_model)) + + torch.backends.cudnn.determinstic = True + torch.backends.cudnn.benchmark = False + + pt_input, pt_mask = generate_inputs(batch_size, max_seq_length, vocab_size) + ort_input = copy.deepcopy(pt_input) + ort_mask = copy.deepcopy(pt_mask) + pt_target = torch.randint(3, (batch_size,), device=device) + ort_target = copy.deepcopy(pt_target) + + seed = 5033 + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + # Run one step of forward and backward for torch and ort respectively + ort_prediction1 = run_step(ort_model1, ort_input, ort_mask, ort_target) + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + os.environ["ORTMODULE_MEMORY_OPT_LEVEL"] = "1" + ort_model2 = ORTModule(copy.deepcopy(pt_model), DebugOptions(save_onnx=True, onnx_prefix="recompute")) + ort_prediction2 = run_step(ort_model2, ort_input, ort_mask, ort_target) + + for ort_param1, ort_param2 in zip(ort_model1.parameters(), ort_model2.parameters()): + _test_helpers.assert_values_are_close(ort_param1.grad, ort_param2.grad, atol=1e-4, rtol=1e-5) + + if os.getenv("ORTMODULE_ROCM_TEST", "0") == "1": + # For ROCm EP, the difference between ORT and PyTorch is larger than CUDA EP. + _test_helpers.assert_values_are_close(ort_prediction1, ort_prediction2, atol=2e-3, rtol=2e-4) + else: + _test_helpers.assert_values_are_close(ort_prediction1, ort_prediction2, atol=1e-3, rtol=1e-4) + + execution_mgr = ort_model2._torch_module._execution_manager._training_manager + from onnxruntime.training.ortmodule._onnx_models import _get_onnx_file_name + + # Keep the logic aligned with _graph_execution_manager.py + path = os.path.join( + execution_mgr._debug_options.save_onnx_models.path, + _get_onnx_file_name( + execution_mgr._debug_options.save_onnx_models.name_prefix, "execution_model", execution_mgr._export_mode + ), + ) + + onnx_model = onnx.load(path) + onnx_nodes = onnx_model.graph.node + + recompute_nodes = 0 + for node in onnx_nodes: + if "_recompute" in node.name: + recompute_nodes += 1 + + assert recompute_nodes > 0, "No Recompute nodes are found" + + # Make sure environment variable is restored to its original value after the run is completed. + torch.cuda.synchronize() + if original_val is not None: + os.environ["ORTMODULE_MEMORY_OPT_LEVEL"] = original_val + else: + if "ORTMODULE_MEMORY_OPT_LEVEL" in os.environ: + del os.environ["ORTMODULE_MEMORY_OPT_LEVEL"] diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py index bd4fce2cde144..99c15034cdafe 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py @@ -1800,3 +1800,62 @@ def _run_step(model, input): assert_values_are_close(param.grad, pt_params[name].grad, rtol=1e-04, atol=1e-3) else: assert pt_params[name].grad is None + + +def test_determistic_pythonop_export(): + + class TestFunction(torch.autograd.Function): + @staticmethod + # bias is an optional argument + def forward(ctx, x): + ctx.save_for_backward(x) + return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors + tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) + ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out) + return ff * grad_output + + class TestModel(torch.nn.Module): + def __init__(self, output_size): + super().__init__() + self.custom_fn = TestFunction.apply + self.bias = Parameter(torch.empty(output_size, dtype=torch.float)) + + with torch.no_grad(): + self.bias.uniform_() + + def forward(self, model_input): + # model_input did not require_grad + out = self.custom_fn(model_input) + return out + self.bias + + output_size = 1024 + + ortmodule = ORTModule(TestModel(output_size)).train() + _ = ortmodule(torch.randn(output_size, dtype=torch.float)) + + onnx_nodes = ortmodule._torch_module._execution_manager._training_manager._onnx_models.exported_model.graph.node + + found_pythonop = False + for node in onnx_nodes: + if node.op_type == "PythonOp": + cconv = None + for attr in node.attribute: + if attr.name == "func_name": + func_name = attr.s.decode("utf-8") if isinstance(attr.s, bytes) else attr.s + if ( + func_name + == "orttraining_test_ortmodule_autograd.test_determistic_pythonop_export..TestFunction" + ): + found_pythonop = True + + if attr.name == "input_convention": + cconv = attr.s.decode("utf-8") if isinstance(attr.s, bytes) else attr.s + + if found_pythonop: + assert cconv == "cccd", f"Expected cconv to be ccdd, but actually got {cconv}" + + assert found_pythonop, "PythonOp should be found in the exported model"