From 9323dece14a125841356f375e744458e7d2fef67 Mon Sep 17 00:00:00 2001
From: Changming Sun
Date: Tue, 12 Dec 2023 13:35:42 -0800
Subject: [PATCH] Revert "Allow layer-wise recompute (#18566)"

This reverts commit ccf3b2054b47c3a48001bd9305957d430ac02f0e.
---
 docs/Memory_Optimizer.md                      | 120 +++++------
 docs/ORTModule_Training_Guidelines.md         |  14 +-
 include/onnxruntime/core/graph/constants.h    |   3 -
 .../onnxruntime_session_options_config_keys.h |   6 +-
 onnxruntime/core/graph/graph_viewer.cc        |  11 -
 onnxruntime/core/session/inference_session.cc |   8 +-
 .../3layer_bloom_optimized_training.onnx      | Bin 245088 -> 0 bytes
 .../3layer_bloom_optimized_training.py        |  84 --------
 .../memory_optimizer.cc                       |  37 ++--
 .../{memory_optimizer => }/memory_optimizer.h |  18 +-
 .../core/optimizer/memory_optimizer/common.cc |  12 +-
 .../core/optimizer/memory_optimizer/common.h  |  12 +-
 .../memory_optimizer/memory_insight.cc        | 105 +++-------
 .../memory_optimizer/memory_insight.h         |  14 +-
 .../memory_optimizer/optimization_planner.cc  |   2 +-
 .../memory_optimizer/optimization_planner.h   |  16 --
 .../memory_optimizer/recompute_analysis.cc    | 151 ++++----------
 .../memory_optimizer/recompute_analysis.h     |  29 +--
 .../memory_optimizer/transformer_specific.cc  |  69 -------
 .../memory_optimizer/transformer_specific.h   |  25 ---
 .../ortmodule/_graph_execution_manager.py     |  49 ++---
 .../python/training/ortmodule/_onnx_models.py |   2 +-
 .../training/ortmodule/_runtime_inspector.py  |  72 +++----
 .../training/ortmodule/_training_manager.py   |  10 +-
 .../python/training/ortmodule/options.py      |  35 +---
 .../python/training/utils/ptable.py           |  13 +-
 .../test/optimizer/memory_optimizer_test.cc   | 190 +-----------------
 .../python/orttraining_test_ortmodule_api.py  |  55 -----
 28 files changed, 231 insertions(+), 931 deletions(-)
 delete mode 100644 onnxruntime/test/testdata/transform/recompute/3layer_bloom_optimized_training.onnx
 delete mode 100644 onnxruntime/test/testdata/transform/recompute/3layer_bloom_optimized_training.py
 rename orttraining/orttraining/core/optimizer/{memory_optimizer => }/memory_optimizer.cc (91%)
 rename orttraining/orttraining/core/optimizer/{memory_optimizer => }/memory_optimizer.h (88%)
 delete mode 100644 orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc
 delete mode 100644 orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.h

diff --git a/docs/Memory_Optimizer.md b/docs/Memory_Optimizer.md
index 97f7e7ff2c14b..0147a937db81d 100644
--- a/docs/Memory_Optimizer.md
+++ b/docs/Memory_Optimizer.md
@@ -17,83 +17,55 @@ Classical scenarios include:
 
 Not all models and recipes need this optimizer technique. Imagine if your training recipe uses a batch size of 6 (GPU compute and memory are fully saturated), and you don't need to bump it to 8 to maintain a fixed global batch size. Enabling recompute may not bring better throughput at batch size 8 than at the original batch size 6.
 
-## Usage
+## Quick trial
 
-
-Make sure ONNX Runtime training wheel is installed and correctly configured.
-Integrate models using `ORTModule`.
-```diff
-	model = build_model()
-
-+	from onnxruntime.training.ortmodule import ORTModule
-+	model = ORTModule(model)
-```
-
-There are two modes to enable the memory optimizations:
-- Aggressively Recompute All Within Each Transformer Layer, enabled by `export ORTMODULE_MEMORY_OPT_LEVEL=1`. This will recompute all detected subgraphs within each Transformer Attention+MLP layer. It is easy to enable, but be noted this recompute plan may NOT be the best one. In this mode, `ORTMODULE_MEMORY_OPT_CONFIG` env values passed by users are not respected.
-- User Specified Subgraph Recompute, enabled by `export ORTMODULE_MEMORY_OPT_LEVEL=0` and `export ORTMODULE_MEMORY_OPT_CONFIG=<config1>,<config2>,...`. This is an advanced usage, that allows users to find the most suitable graphs to recompute, at the cost of overhead to look for the best plans.
-
-### Mode 1 - Simple Usage (Aggressively Recompute All Within Each Transformer Layer)
-
-
-1. Set memory optimization level to be TRANSFORMER_LAYERWISE_RECOMPUTE, by `export ORTMODULE_MEMORY_OPT_LEVEL=1`
-2. Run the training as usual; check the logs, you could find something like this if the current log level <= LogLevel.INFO:
+1. Make sure ONNX Runtime training wheel is installed and correctly configured.
+2. Integrate models using `ORTModule`; note that log_level should be equal to or lower than INFO.
+   > ort_model = ORTModule(pt_model, DebugOptions(log_level=LogLevel.INFO))
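+
+   As a minimal end-to-end sketch of this step (the toy model and loss below are placeholders; the import path
+   follows the ORTModule examples in this repository):
+
+   ```python
+   import torch
+   from onnxruntime.training.ortmodule import ORTModule, DebugOptions, LogLevel
+
+   pt_model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())  # placeholder model
+   ort_model = ORTModule(pt_model, DebugOptions(log_level=LogLevel.INFO))
+
+   # Memory-optimization opportunities are detected and logged while training steps run.
+   loss = ort_model(torch.randn(4, 8)).sum()
+   loss.backward()
+   ```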
+3. Run the training as usual; then stop it after training a few steps.
+4. Check the logs; you could find something like this:
+   ```
+   Memory Optimizer     :  OFF   :  Enable with env ORTMODULE_MEMORY_OPT_CONFIG=<plan1 config>, available configs:
+   Config                                                     Freq    Max Saving(B)    Saving Symbolic(Bytes)
+   - Plan 1 : OFF : Reshape+Where+BiasSoftmax+:1:-1           5       671,088,640      640.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2
+   - Plan 2 : OFF : Cast+:1:-1                                6       402,587,648      inputs_input_ids_dim0*inputs_input_ids_dim1*(384.0*inputs_input_ids_dim1 - 64.0)
+   - Plan 3 : OFF : Reshape+Where+:1:-1                       1       134,217,728      128.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2
+   - Plan 4 : OFF : BiasSoftmax+:1:-1                         1       134,086,656      128.0*inputs_input_ids_dim0*inputs_input_ids_dim1*(inputs_input_ids_dim1 - 1)
+   - Plan 5 : OFF : BiasGelu+:1:-1                            6       125,808,640      inputs_input_ids_dim0*(122880.0*inputs_input_ids_dim1 - 20480.0)
+   - Plan 6 : OFF : FusedMatMul+:1:-1                         6       125,808,640      inputs_input_ids_dim0*(122880.0*inputs_input_ids_dim1 - 20480.0)
+   - Plan 7 : OFF : FusedMatMul+Add+FusedMatMul+Add+Add+Add+:1:-1  5  26,214,400       25600.0*inputs_input_ids_dim0*inputs_input_ids_dim1
+   - Plan 8 : OFF : Add+:1:-1                                 1       5,237,760        5120.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1)
+   - Plan 9 : OFF : Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1  1  4,096      4.0*inputs_input_ids_dim0*inputs_input_ids_dim1
+   - Plan 10 : OFF : Cast+:2:-1                               1       2,048            2.0*inputs_input_ids_dim0*inputs_input_ids_dim1
+
+
+   Note 1: use comma as delimiter to enable multiple memory optimization plans at the same time:
+     export ORTMODULE_MEMORY_OPT_CONFIG=<plan1 config>,<plan2 config>,...
+   Note 2: memory saving is calculated based on the 1st batch symbolic dim values:
+     inputs_input_ids_dim0=1, inputs_input_ids_dim1=1024, inputs_attention_mask_dim0=1, inputs_attention_mask_dim1=1024, inputs_labels_dim0=1, inputs_labels_dim1=1024
+   ```
+5. As shown above, `Config` is a string representative for a re-computable subgraph. All are disabled for recompute in this case.
+6. Set environment variable `ORTMODULE_MEMORY_OPT_CONFIG` to enable some of the subgraphs to do recompute. In the example below, `6` `BiasGelu+`-related subgraphs are allowed to recompute.
+`BiasGelu+` is the subgraph string representative; `1` in the middle indicates 'Recompute' is enabled (`0`, on the contrary, indicates it is disabled); `6` means the initial 6 subgraph occurrences will be recomputed, while all others are left as they are; filling in `-1` will make all occurrences be recomputed. See the parsing sketch after the example below.
+   ```
+   export ORTMODULE_MEMORY_OPT_CONFIG="BiasGelu+:1:6"    # Use comma as separator for enabling more than one subgraph.
+   ```
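+
+   For illustration only, here is how the documented `SubgraphStr:OptimizationType:RequestApplyCount` format
+   decomposes (a hypothetical Python sketch, not the parser ONNX Runtime itself uses):
+
+   ```python
+   def parse_memory_opt_config(config: str) -> list:
+       """Split a config like "BiasGelu+:1:6,Cast+:2:-1" into per-subgraph plans."""
+       plans = []
+       for entry in config.split(","):  # comma separates independently enabled plans
+           subgraph, opt_type, count = entry.rsplit(":", 2)
+           plans.append({
+               "subgraph": subgraph,         # pattern string, e.g. "BiasGelu+"
+               "type": int(opt_type),        # 0: disabled, 1: recompute, 2: compromised recompute
+               "request_count": int(count),  # how many occurrences to apply; -1 means all
+           })
+       return plans
+
+   print(parse_memory_opt_config("BiasGelu+:1:6,Cast+:2:-1"))
+   ```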
+7. Then run the training again, and you will see logs like this:
+   ```
+   Memory Optimizer     :  ON   :  User config: Reshape+Where+BiasSoftmax+:1:-1, probe level: 1, available configs:
+   Config                                                     Freq    Max Saving(B)    Saving Symbolic(Bytes)
+   - Plan 1 : OFF : Reshape+Where+BiasSoftmax+:1:-1           5       671,088,640      640.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2
+   - Plan 2 : OFF : Cast+:1:-1                                6       402,587,648      inputs_input_ids_dim0*inputs_input_ids_dim1*(384.0*inputs_input_ids_dim1 - 64.0)
+   - Plan 3 : OFF : Reshape+Where+:1:-1                       1       134,217,728      128.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2
+   - Plan 4 : OFF : BiasSoftmax+:1:-1                         1       134,086,656      128.0*inputs_input_ids_dim0*inputs_input_ids_dim1*(inputs_input_ids_dim1 - 1)
+   - Plan 5 : ON : BiasGelu+:1:-1                             6       125,808,640      inputs_input_ids_dim0*(122880.0*inputs_input_ids_dim1 - 20480.0)
+   - Plan 6 : OFF : FusedMatMul+:1:-1                         6       125,808,640      inputs_input_ids_dim0*(122880.0*inputs_input_ids_dim1 - 20480.0)
+   - Plan 7 : OFF : FusedMatMul+Add+FusedMatMul+Add+Add+Add+:1:-1  5  26,214,400       25600.0*inputs_input_ids_dim0*inputs_input_ids_dim1
+   - Plan 8 : OFF : Add+:1:-1                                 1       5,237,760        5120.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1)
+   - Plan 9 : OFF : Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1  1  4,096      4.0*inputs_input_ids_dim0*inputs_input_ids_dim1
+   - Plan 10 : OFF : Cast+:2:-1                               1       2,048            2.0*inputs_input_ids_dim0*inputs_input_ids_dim1
+   ```
+8. You may need to iterate a few times on steps 6 and 7 until you find a good config for this model to run a bigger batch size, or you may fail to find one if memory optimization does not apply well to the model.
 
 ## Optimization Configuration
 
@@ -101,13 +73,11 @@ The basic optimization unit is represented with a unique `cluster id`, for examp
 Following `cluster id` is the `optimization strategy`: 0 - none, 1 - recompute, 2 - recompute with compromised memory saving.
 Following `optimization strategy` is the `request count` to apply the given optimization. Using `-1` to apply all. This gives the user a bit more flexibility to avoid unnecessary memory saving.
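+
+As a sanity check, a `Saving Symbolic(Bytes)` expression from the summary tables can be evaluated by hand for a
+concrete batch shape. A small sketch in plain Python (not an ONNX Runtime API; the formula string is copied from
+the `BiasGelu+` plan above):
+
+```python
+def symbolic_saving_bytes(expr: str, **dims: int) -> float:
+    # Evaluate a "Saving Symbolic(Bytes)" expression with concrete dim values.
+    return eval(expr, {"__builtins__": {}}, dims)
+
+saving = symbolic_saving_bytes(
+    "inputs_input_ids_dim0*(122880.0*inputs_input_ids_dim1 - 20480.0)",
+    inputs_input_ids_dim0=1,
+    inputs_input_ids_dim1=1024,
+)
+print(f"{saving:,.0f}")  # 125,808,640 -- matches the Max Saving(B) column for that plan
+```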
-### Compromised Recompute
+## Compromised Recompute
 
 If you check the above logs, there is a config `Cast+:2:-1`; the `2` indicates it is a recomputation that can save part of the stashed activation size, not all. Recomputing the subgraphs under it usually saves part of the activations (for example, half of them), not all. Follow the same way to enable it.
 
-## Dev Notes
-
-### Memory Optimization Debug Infos
+## Memory Optimization Debug Infos
 
 Using the following log level
 > ort_model = ORTModule(pt_model, DebugOptions(log_level=LogLevel.DEVINFO))
 
@@ -162,4 +132,4 @@ MemoryInsight Summary - User config: not provided
 
 ## Notes
 
-The feature is in the experimental stage, we will tune and refine it according to real use cases.
+The feature is in experimental stage, we will tune and refine it according to real use cases.
diff --git a/docs/ORTModule_Training_Guidelines.md b/docs/ORTModule_Training_Guidelines.md
index bede16204d420..a3cceb441a2a9 100644
--- a/docs/ORTModule_Training_Guidelines.md
+++ b/docs/ORTModule_Training_Guidelines.md
@@ -146,6 +146,7 @@ Check [DebugOptions implementation](../orttraining/orttraining/python/training/o
 	export ORTMODULE_ONNX_OPSET_VERSION=14
 	```
 
+
 #### ORTMODULE_FALLBACK_POLICY
 
 - **Feature Area**: *ORTMODULE/FallbackToPytorch*
@@ -154,6 +155,7 @@ Check [DebugOptions implementation](../orttraining/orttraining/python/training/o
 	export ORTMODULE_FALLBACK_POLICY="FALLBACK_DISABLE"
 	```
 
+
 #### ORTMODULE_LOG_LEVEL
 
 - **Feature Area**: *ORTMODULE/DebugOptions*
@@ -180,6 +182,7 @@ The output directory of the onnx models by default is set to the current working
 > On the other hand, if the wrapped computation graph is small, it is reasonable to allow it.
 > Overall users should be aware that ORT performance boost might be trivial when they explicitly allow it.
 
+
 #### ORTMODULE_ENABLE_CUSTOM_AUTOGRAD
 
 - **Feature Area**: *ORTMODULE/PythonOp (torch.autograd.Function)*
@@ -196,6 +199,8 @@ The output directory of the onnx models by default is set to the current working
 	enable_custom_autograd_support(False)
 	```
 
+
+
 #### ORTMODULE_ENABLE_COMPUTE_OPTIMIZER
 
 - **Feature Area**: *ORTMODULE/Optimizations*
@@ -284,15 +289,6 @@ A classical usage of disabling the deep copy: when the deep copy before module e
 	export ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT=0 # Disable
 	```
 
-#### ORTMODULE_MEMORY_OPT_LEVEL
-
-- **Feature Area**: *ORTMODULE/Optimizations*
-- **Description**: By default, the level is 0. This env var can be used for enabling recomputation for reducing memory peak requirement. Setting the level to be 1 means all detected subgraphs within each transformer-based model layer generating stashed activations will be recomputed. This is conceptually equivalent to PyTorch's gradient checkpoint. When the level is not 0, check [Memory Optimizer for ONNX Runtime Training](Memory_Optimizer.md) for more details.
-
-  ```bash
-  export ORTMODULE_MEMORY_OPT_LEVEL=0
-  ```
-
 ### 2.2 Memory Optimization
 
 Q: *Want to run a bigger batch size?*
diff --git a/include/onnxruntime/core/graph/constants.h b/include/onnxruntime/core/graph/constants.h
index 9b26ba914c7dd..7e59aad80cc47 100644
--- a/include/onnxruntime/core/graph/constants.h
+++ b/include/onnxruntime/core/graph/constants.h
@@ -55,7 +55,4 @@ constexpr const char* kAzureExecutionProvider = "AzureExecutionProvider";
 constexpr const char* kExecutionProviderSharedLibraryPath = "shared_lib_path";
 constexpr const char* kExecutionProviderSharedLibraryEntry = "provider_factory_entry_point";
 
-// For Priority based graph topology sorting.
-constexpr const char* kBackwardNodeAttributeName = "__backwardpass"; - } // namespace onnxruntime diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index a94973b2cc5d7..4628afbb5a702 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -88,9 +88,9 @@ static const char* const kOrtSessionOptionsDisableAheadOfTimeFunctionInlining = // the memory. static const char* const kOrtSessionOptionsMemoryOptimizerEnabler = "optimization.memory_optimizer_config"; -// Specifies the config for detecting subgraphs for memory footprint reduction. -// The value should be a string contains int separated using commas. The default value is "0:0". -static const char* const kOrtSessionOptionsMemoryOptimizerProbeConfig = "optimization.enable_memory_probe_recompute_config"; +// Specifies the level for detecting subgraphs for memory footprint reduction. +// The value should be an integer. The default value is 0. +static const char* const kOrtSessionOptionsMemoryOptimizerProbeLevel = "optimization.enable_memory_probe_recompute_level"; #endif // Enable or disable using device allocator for allocating initialized tensor memory. "1": enable; "0": disable. The default is "0". diff --git a/onnxruntime/core/graph/graph_viewer.cc b/onnxruntime/core/graph/graph_viewer.cc index cf78040ea5ac6..b1e07714cd3c8 100644 --- a/onnxruntime/core/graph/graph_viewer.cc +++ b/onnxruntime/core/graph/graph_viewer.cc @@ -35,17 +35,6 @@ struct PriorityNodeCompare { return n1->Priority() > n2->Priority(); } - // nodes of forward pass will be output first - auto n1_attrs = n1->GetAttributes(); - auto n2_attrs = n2->GetAttributes(); - int64_t n1_is_forward = static_cast(n1_attrs.find(kBackwardNodeAttributeName) == n1_attrs.cend()) || - (n1_attrs.at(kBackwardNodeAttributeName).i() + 1) % 2; - int64_t n2_is_forward = static_cast(n2_attrs.find(kBackwardNodeAttributeName) == n2_attrs.cend()) || - (n2_attrs.at(kBackwardNodeAttributeName).i() + 1) % 2; - if (n1_is_forward != n2_is_forward) { - return n2_is_forward > n1_is_forward; - } - // otherwise, nodes with lower index will be output first return n1->Index() > n2->Index(); } diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 575529a06fb7a..2ed2ae93cb5c4 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -77,7 +77,7 @@ #ifdef ENABLE_TRAINING #include "core/framework/partial_graph_execution_state.h" #include "core/framework/stream_execution_context.h" -#include "orttraining/core/optimizer/memory_optimizer/memory_optimizer.h" +#include "orttraining/core/optimizer/memory_optimizer.h" #endif using namespace ONNX_NAMESPACE; @@ -1204,10 +1204,10 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool { const std::string memory_optimizer_config = session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsMemoryOptimizerEnabler, ""); - const std::string probe_config = - session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsMemoryOptimizerProbeConfig, "0:0"); + const std::string probe_level = + session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsMemoryOptimizerProbeLevel, "0"); - MemoryOptimizer mem_transformer{memory_optimizer_config, probe_config}; + MemoryOptimizer 
mem_transformer{memory_optimizer_config, probe_level};
     ORT_RETURN_IF_ERROR_SESSIONID_(apply_transformer_once(mem_transformer, *session_logger_, graph));
   }
 #endif
diff --git a/onnxruntime/test/testdata/transform/recompute/3layer_bloom_optimized_training.onnx b/onnxruntime/test/testdata/transform/recompute/3layer_bloom_optimized_training.onnx
deleted file mode 100644
index ade409c22b4d4f4631107f4d18073df44e970d3e..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001

literal 245088
[245,088 bytes of base85-encoded GIT binary patch data for the deleted ONNX test model omitted]

diff --git a/onnxruntime/test/testdata/transform/recompute/3layer_bloom_optimized_training.py b/onnxruntime/test/testdata/transform/recompute/3layer_bloom_optimized_training.py
deleted file mode 100644
index 
01be120903ea3..0000000000000 --- a/onnxruntime/test/testdata/transform/recompute/3layer_bloom_optimized_training.py +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -"""This file is used to generate test data for MemoryOptimizer tests in - onnxruntime/test/optimizer/memory_optimizer_test.cc. - - The libs used to generate 3 layer bloom model. - - optimum: f6adbef5c4a6bd16a17e3b22712028ed5ae3709b - huggingface: 4.34.1 - deepspeed: 0.11.1 - PyTorch: 2.1.0.dev20230803+cu118 - - Change below line in optimum/onnxruntime/trainer.py - "model = ORTModule(self.model)" - to - "model = ORTModule(self.model, DebugOptions(save_onnx=True, log_level=LogLevel.WARNING, onnx_prefix="3layer_bloom"))" - - Add below in examples/onnxruntime/training/language-modeling/run_clm.py before the config is used to load the model. - "config.num_hidden_layers = 3" - - Run below command to generate the model, there will be 3layer_bloom_optimized_training.onnx generated. - #!/bin/bash - ds_config=`mktemp --suffix ".json"` - echo the deepspeed config is put at $ds_config - cat << EOF > $ds_config - { - "fp16": { - "enabled": true, - "loss_scale": 0, - "loss_scale_window": 1000, - "hysteresis": 2, - "min_loss_scale": 1 - }, - "zero_optimization": { - "stage": 1, - "allgather_partitions": true, - "allgather_bucket_size": 200000000, - "overlap_comm": true, - "reduce_scatter": true, - "reduce_bucket_size": 200000000, - "contiguous_gradients": false, - "cpu_offload": false, - "memory_efficient_linear": true - }, - "zero_allow_untested_optimizer": true, - "optimizer": { - "type": "AdamW", - "params": { - "lr": "auto", - "betas": "auto", - "eps": "auto", - "weight_decay": "auto" - } - }, - "scheduler": { - "type": "WarmupLR", - "params": { - "warmup_min_lr": "auto", - "warmup_max_lr": "auto", - "warmup_num_steps": "auto" - } - }, - "steps_per_print": 2000, - "train_micro_batch_size_per_gpu": "auto" - } - EOF - - num_gpus=1 - export ORTMODULE_ENABLE_CUSTOM_AUTOGRAD=0 # GELU PythonOp will be used if this is set to 1 - torchrun --nproc_per_node $num_gpus \ - examples/onnxruntime/training/language-modeling/run_clm.py \ - --model_name_or_path bigscience/bloom-560m \ - --dataset_name wikitext \ - --dataset_config_name wikitext-2-raw-v1 \ - --per_device_train_batch_size 2 \ - --per_device_eval_batch_size 1 \ - --do_train \ - --output_dir /tmp/test-clm --overwrite_output_dir \ - --fp16 \ - --report_to none \ - --max_steps 10000 --logging_steps 1 --use_module_with_loss \ - --deepspeed $ds_config - """ diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc b/orttraining/orttraining/core/optimizer/memory_optimizer.cc similarity index 91% rename from orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc rename to orttraining/orttraining/core/optimizer/memory_optimizer.cc index 49e026ca86bd3..834e5ebb5f6f3 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer.cc @@ -13,7 +13,7 @@ #include "core/graph/graph_utils.h" #include "core/optimizer/utils.h" #include "orttraining/core/graph/recompute_graph_utils.h" -#include "orttraining/core/optimizer/memory_optimizer/memory_optimizer.h" +#include "orttraining/core/optimizer/memory_optimizer.h" #include "orttraining/core/optimizer/memory_optimizer/common.h" #include "orttraining/core/optimizer/memory_optimizer/optimization_planner.h" #include 
"orttraining/core/optimizer/memory_optimizer/recompute_analysis.h" @@ -30,17 +30,19 @@ constexpr bool IsForwardPassOperator(ptrdiff_t op_order_in_topological_sort, } // namespace -Status MemoryOptimizer::ParseOptimizationConfigFromString(const std::string& memory_optimizer_config, - const std::string& recompute_probe_config) { +Status MemoryOptimizer::ParseConfigFromString(const std::string& memory_optimizer_config, + const std::string& level) { optimizer_config_ = memory_optimizer_config; - ORT_RETURN_IF_ERROR(optimizer::memory_optimizer::ParseOptimizationConfigFromString( + ORT_RETURN_IF_ERROR(optimizer::memory_optimizer::ParseConfigFromString( memory_optimizer_config, pattern_subgraph_to_user_optimizer_config_map_)); - ORT_RETURN_IF_ERROR(optimizer::memory_optimizer::ParseProbeConfigFromString( - recompute_probe_config, - recompute_probe_config_)); + int probe_level = optimizer::memory_optimizer::ParseIntValueFromString(level); + ORT_RETURN_IF_NOT(probe_level < static_cast(optimizer::memory_optimizer::ProbeLevel::LevelMax) && + probe_level >= 0, + "Invalid probe level specified: ", level); + recompute_probe_level_ = static_cast(probe_level); return Status::OK(); } @@ -124,21 +126,14 @@ bool MemoryOptimizer::ModifyGraph(Graph& graph, Status MemoryOptimizer::ApplyImpl(Graph& graph, bool& modified, int /*graph_level*/, const logging::Logger& logger) const { - // Reset the backward pass attribute for all nodes. - ORT_RETURN_IF_ERROR(optimizer::memory_optimizer::ResetNodeBackwardPassAttribute(graph, modified)); - LOGS(logger, VERBOSE) << "Memory optimization config: " << optimizer_config_ << ", probe level: " - << static_cast(recompute_probe_config_.probe_level) - << ", enable_transformer_layer_as_boundary:" - << recompute_probe_config_.enable_transformer_layer_as_boundary; + << static_cast(recompute_probe_level_); if (pattern_subgraph_to_user_optimizer_config_map_.empty()) { LOGS(logger, VERBOSE) << "No optimization pattern is specified, skip memory optimization."; return Status::OK(); } - size_t recomputed_node_count = 0; - ptrdiff_t yield_op_order_in_topological_sort; InlinedHashMap> candidate_output_args_map; InlinedHashMap node_index_to_its_order_in_topological_sort_map; @@ -148,7 +143,7 @@ Status MemoryOptimizer::ApplyImpl(Graph& graph, bool& modified, int /*graph_leve optimizer::memory_optimizer::MemoryOptimizationPlanner memory_opt_planner; ORT_ENFORCE(optimizer::memory_optimizer::FindORTModuleMemoryOpportunity( graph_viewer, - recompute_probe_config_, + recompute_probe_level_, logger, node_index_to_its_order_in_topological_sort_map, yield_op_order_in_topological_sort, @@ -171,7 +166,7 @@ Status MemoryOptimizer::ApplyImpl(Graph& graph, bool& modified, int /*graph_leve // The reason we do reversed topological order is that we want the later layers' recompute nodes can be appended // earlier than the earlier layers, in this way, the execution order of later layers will be in front of the earlier // layers. 
- const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED); + const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder(); for (int i = static_cast(node_ids.size()) - 1; i >= 0; --i) { Node* p_node = graph.GetNode(node_ids[i]); if (p_node == nullptr) { @@ -188,17 +183,9 @@ Status MemoryOptimizer::ApplyImpl(Graph& graph, bool& modified, int /*graph_leve node_to_apply_context_map[p_node]); } - if (has_been_modified) { - recomputed_node_count += 1; - } - modified = modified || has_been_modified; } - if (recomputed_node_count > 0) { - LOGS(logger, INFO) << "Total number of recomputed nodes: " << recomputed_node_count; - } - PrintSummary(memory_opt_planner, node_to_apply_context_map, logger); return Status::OK(); diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.h b/orttraining/orttraining/core/optimizer/memory_optimizer.h similarity index 88% rename from orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.h rename to orttraining/orttraining/core/optimizer/memory_optimizer.h index b3e05fd334e48..13eb4cdb242f4 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.h +++ b/orttraining/orttraining/core/optimizer/memory_optimizer.h @@ -16,6 +16,8 @@ namespace onnxruntime { /** @Class MemoryOptimizer +(TODO) move to orttraining/orttraining/core/optimizer/memory_optimizer/ folder. + Find recompute subgraphs and enable them according to user configs. The way we collect subgraphs (in orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.h) in brief is: 1. Find all nodes that generate stashed activations. @@ -29,10 +31,10 @@ Find recompute subgraphs and enable them according to user configs. The way we c class MemoryOptimizer : public GraphTransformer { private: public: - MemoryOptimizer(const std::string& memory_optimizer_config, const std::string& recompute_probe_config) + MemoryOptimizer(const std::string& memory_optimizer_config, const std::string& level) : GraphTransformer("MemoryOptimizer") { - // Parse user-defined configs. - ORT_ENFORCE(ParseOptimizationConfigFromString(memory_optimizer_config, recompute_probe_config).IsOK()); + // Parse user defined configs. + ORT_ENFORCE(ParseConfigFromString(memory_optimizer_config, level).IsOK()); } Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; @@ -40,7 +42,7 @@ class MemoryOptimizer : public GraphTransformer { bool ShouldOnlyApplyOnce() const override { return true; } private: - Status ParseOptimizationConfigFromString(const std::string& memory_optimizer_config, const std::string& recompute_probe_config); + Status ParseConfigFromString(const std::string& memory_optimizer_config, const std::string& level); /** * @brief Apply graph modifications based on user configs. 
@@ -81,7 +83,7 @@ class MemoryOptimizer : public GraphTransformer { const logging::Logger& logger) const; /************************************************** - ** Recompute-related function definition starts ** + ** Recompute related function definition starts ** *************************************************/ /** @@ -97,13 +99,13 @@ class MemoryOptimizer : public GraphTransformer { Node*& recompute_subgraph_output_node) const; /************************************************** - ** Recompute-related function definition ends ** + ** Recompute related function definition ends ** *************************************************/ - // User-enabled map of the subgraph string representation to the alleviation type. + // User enabled map of the subgraph string representation to the alleviation type. InlinedHashMap pattern_subgraph_to_user_optimizer_config_map_; std::string optimizer_config_; - optimizer::memory_optimizer::ProbeConfig recompute_probe_config_; + optimizer::memory_optimizer::ProbeLevel recompute_probe_level_; }; } // namespace onnxruntime diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/common.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/common.cc index d522e60125c36..2291d7e4f37a6 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/common.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/common.cc @@ -83,8 +83,8 @@ std::string GetTensorElemCountInSymbolicString(const Node* node, size_t output_i std::string shape_str = TensorShapeProtoToString(shape); - // If the output shape contains an unknown dimension, we try to get the shape from the input. - // Though the input shape might be different, its elem size and count should be the same + // If the output shape contains unknown dimension, we try to get the shape from input. + // though the input shape might be different, but its elem size and count should be the same // with the output. if (node->OpType() == "Reshape" && HasUnknowDimension(shape) && !HasUnknowDimension(node->InputDefs()[0]->Shape())) { @@ -114,14 +114,14 @@ int ParseIntValueFromString(std::string_view str) { return int_value; } -Status ParseOptimizationConfigFromString(std::string_view memory_optimization_config, - InlinedHashMap& cluster_id_to_config_map) { +Status ParseConfigFromString(std::string_view memory_optimization_config, + InlinedHashMap& cluster_id_to_config_map) { if (!memory_optimization_config.empty()) { const auto user_config_strs = utils::SplitString(memory_optimization_config, ","); for (const auto& user_config_str : user_config_strs) { const auto user_config = utils::SplitString(user_config_str, ":"); ORT_RETURN_IF_NOT(user_config.size() == 3, - "User config should be in the format of SubgraphStr:OptimizationType:RequestApplyCount."); + "User config should be in format of SubgraphStr:OptimizationType:RequestApplyCount."); const std::string subgraph_string_representation(user_config[0]); int optimization_type_int = ParseIntValueFromString(user_config[1]); @@ -136,7 +136,7 @@ Status ParseOptimizationConfigFromString(std::string_view memory_optimization_co "Invalid requested_apply_count specified for subgraph: ", requested_apply_count); // At this point, subgraph_string_representation is a pattern graph string representation. - // If a duplicated subgraph_string_representation is found in user config, the last one will be used. + // If duplicated subgraph_string_representation is found in user config, the last one will be used. 
cluster_id_to_config_map[subgraph_string_representation] = UserConfig{ static_cast(optimization_type_int), requested_apply_count}; diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/common.h b/orttraining/orttraining/core/optimizer/memory_optimizer/common.h index 268ed84f7a85f..85e2bf4f5d683 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/common.h +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/common.h @@ -24,7 +24,10 @@ namespace onnxruntime::optimizer::memory_optimizer { #ifdef MO_NEED_LOG_DEBUG_INFO #define MO_LOG_DEBUG_INFO(logger, message) LOGS(logger, WARNING) << message #else -#define MO_LOG_DEBUG_INFO(logger, message) LOGS(logger, VERBOSE) << message +#define MO_LOG_DEBUG_INFO(logger, message) \ + ORT_UNUSED_PARAMETER(logger); \ + do { \ + } while (0) #endif #endif @@ -58,9 +61,6 @@ struct UserConfig { /** * @brief Get total element count inn format of a symbolic string. - * Be noted: this function is used to generate a unique string for a tensor shape. - * For empty dim param, it is possible to have different symbolic string for the same shape, because there is - * a static index_empty_dim used to generate empty dim param as a string. * * @param node The node to get element count. * @param output_index The output index of the node. @@ -70,7 +70,7 @@ std::string GetTensorElemCountInSymbolicString(const Node* node, size_t output_i int ParseIntValueFromString(std::string_view str); -Status ParseOptimizationConfigFromString(std::string_view memory_optimization_config, - InlinedHashMap& cluster_id_to_config_map); +Status ParseConfigFromString(std::string_view memory_optimization_config, + InlinedHashMap& cluster_id_to_config_map); } // namespace onnxruntime::optimizer::memory_optimizer diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc index 9b77832abb6f1..60f62a9881ef4 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc @@ -15,7 +15,6 @@ #include "orttraining/core/optimizer/memory_optimizer/optimization_planner.h" #include "orttraining/core/optimizer/memory_optimizer/recompute_analysis.h" #include "orttraining/core/optimizer/memory_optimizer/memory_insight.h" -#include "orttraining/core/optimizer/memory_optimizer/transformer_specific.h" namespace onnxruntime::optimizer::memory_optimizer { @@ -47,7 +46,7 @@ void GetForwardOutputUsageMap(const GraphViewer& graph_viewer, ActivationUsedMap& fw_op_output_arg_used_map, InlinedHashMap& is_forward_nodes) { ORT_ENFORCE(boundary_op_order_in_topological_sort >= 0); - const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED); + const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder(); is_forward_nodes.clear(); is_forward_nodes.reserve(node_ids.size()); @@ -65,6 +64,7 @@ void GetForwardOutputUsageMap(const GraphViewer& graph_viewer, } const Node& node = *p_node; + bool is_forward_op = is_forward_pass_operator(static_cast(i), boundary_op_order_in_topological_sort); if (!is_forward_op) { is_forward_nodes[p_node] = false; @@ -122,11 +122,11 @@ Status GetStashedActivationCandidates(const GraphViewer& graph_viewer, InlinedHashMap& is_forward_nodes, const logging::Logger& logger) { if (boundary_op_order_in_topological_sort < 0) { - MO_LOG_DEBUG_INFO(logger, "No boundary op found. 
Skip memory optimization."); + LOGS(logger, VERBOSE) << "No boundary op found. Skip memory optimization."; return Status::OK(); } - const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED); + const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder(); InlinedHashMap node_index_to_its_order_in_topological_sort_map; for (size_t i = 0; i < node_ids.size(); ++i) { @@ -161,54 +161,8 @@ Status GetStashedActivationCandidates(const GraphViewer& graph_viewer, } candidate_output_args_map[n].push_back(k); - MO_LOG_DEBUG_INFO(logger, "Find candidate output named [" + kv.first + "] of Node " + - n->Name() + "(" + n->OpType() + ")"); - } - } - - return Status::OK(); -} - -Status ResetNodeBackwardPassAttribute(Graph& graph, bool& modified) { - // Find the YieldOp node. - Node* yield_op_node = nullptr; - for (auto& node : graph.Nodes()) { - if (node.OpType() == "YieldOp") { - yield_op_node = &node; - break; - } - } - - if (yield_op_node == nullptr) { - return Status::OK(); - } - - // Reverse BFS from YieldOp to find all "forward" nodes. - std::vector fw_nodes; - std::vector end_nodes{yield_op_node}; - graph.ReverseDFSFrom( - end_nodes, - nullptr, - [&fw_nodes](const Node* n) { - fw_nodes.push_back(n); - }, - nullptr); - - // Set the attribute to true for all backward nodes. - for (auto& node : graph.Nodes()) { - if (std::find(fw_nodes.begin(), fw_nodes.end(), &node) == fw_nodes.end()) { - auto& attrs = node.GetAttributes(); - if (attrs.count(kBackwardNodeAttributeName)) { - continue; - } - node.AddAttribute(kBackwardNodeAttributeName, static_cast(1)); - modified = true; - } else { - auto& attrs = node.GetAttributes(); - if (attrs.count(kBackwardNodeAttributeName)) { - node.ClearAttribute(kBackwardNodeAttributeName); - modified = true; - } + LOGS(logger, VERBOSE) << "Find candidate output named [" << kv.first << "] of Node " << n->Name() << "(" + << n->OpType() << ")"; } } @@ -216,7 +170,7 @@ Status ResetNodeBackwardPassAttribute(Graph& graph, bool& modified) { } Status FindORTModuleMemoryOpportunity(const GraphViewer& graph_viewer, - const ProbeConfig& probe_config, + const ProbeLevel probe_level, const logging::Logger& logger, InlinedHashMap& node_index_to_its_order_in_topological_sort_map, @@ -224,7 +178,7 @@ Status FindORTModuleMemoryOpportunity(const GraphViewer& graph_viewer, InlinedHashMap>& candidate_output_args_map, MemoryOptimizationPlanner& memory_opt_planner) { - const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED); + const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder(); // Find boundary ops between forward and backward pass, currently, it's limited to YieldOp. yield_op_order_in_topological_sort = -1; @@ -255,9 +209,6 @@ Status FindORTModuleMemoryOpportunity(const GraphViewer& graph_viewer, is_forward_nodes, logger)); - InlinedHashSet layer_boundary_ln_nodes; - FindLayerBoundaryLayerNormNodes(graph_viewer, logger, layer_boundary_ln_nodes); - // The first pass - find the candidate subgraphs. 
for (int i = static_cast(node_ids.size()) - 1; i >= 0; --i) { const Node* p_node = graph_viewer.GetNode(node_ids[i]); @@ -271,13 +222,11 @@ Status FindORTModuleMemoryOpportunity(const GraphViewer& graph_viewer, bool can_compromise_stashed_activation = false; std::unique_ptr recompute_plan = - CheckNodeForRecompute(graph_viewer, - *p_node, - probe_config, + CheckNodeForRecompute(*p_node, + probe_level, fw_op_output_arg_used_map, node_index_to_its_order_in_topological_sort_map, candidate_output_args_map, - layer_boundary_ln_nodes, logger, false, can_compromise_stashed_activation); if (recompute_plan != nullptr) { @@ -285,15 +234,14 @@ Status FindORTModuleMemoryOpportunity(const GraphViewer& graph_viewer, } if (can_compromise_stashed_activation) { - MO_LOG_DEBUG_INFO(logger, "Searching Node " + p_node->Name() + "(" + p_node->OpType() + - ") for compromised recompute"); + LOGS(logger, VERBOSE) << "Searching Node " << p_node->Name() << "(" << p_node->OpType() + << ") for compromised recompute"; // If the subgraph recompute can save memory by comprising the assumption - recompute graphs' input must exist // during backward pass, then we can consider to recompute them. std::unique_ptr recompute_with_compromise_plan = - CheckNodeForRecompute(graph_viewer, *p_node, probe_config, fw_op_output_arg_used_map, + CheckNodeForRecompute(*p_node, probe_level, fw_op_output_arg_used_map, node_index_to_its_order_in_topological_sort_map, candidate_output_args_map, - layer_boundary_ln_nodes, logger, true, can_compromise_stashed_activation); if (recompute_with_compromise_plan != nullptr) { @@ -324,7 +272,7 @@ void GetMemoryRecordsGroupedByNodeClusterId(const MemoryOptimizationPlanner& mem // Collect more information for display. for (auto& plan : node_plans) { - // Same node cluster id, plans might still have different reuse_buffer patterns, so we need to collect all of them. + // Same node cluster id, plans might still have different reuse_buffer pattern, so we need to collect all of them. if (plan->reuse_buffers.size() > 0) { gsl::span output_indices = plan->GetActivationOutputIndices(); for (auto output_index : output_indices) { @@ -367,13 +315,13 @@ void GetMemoryRecordsGroupedByNodeClusterId(const MemoryOptimizationPlanner& mem if (plan->GetOptimizationType() == OptimizationType::RecomputeWithCompromise) { record.compromise_recomputed_outputs.emplace_back( output_index, - plan->GetActivationOutputDimParamString(output_index), + GetTensorElemCountInSymbolicString(node, output_index), byte_count_per_element, plan->GetSaveRatio()); } else if (plan->GetOptimizationType() == OptimizationType::Recompute) { record.recomputed_outputs.emplace_back(output_index, - plan->GetActivationOutputDimParamString(output_index), + GetTensorElemCountInSymbolicString(node, output_index), byte_count_per_element, plan->GetSaveRatio()); } @@ -400,7 +348,6 @@ void GetMemoryRecordsGroupedByNodeClusterId(const MemoryOptimizationPlanner& mem } // If apply context is provided, also update the actual applied count. - // Be noted, node_to_apply_contexts_map contains some or all of the nodes in node_to_optimization_plan_map. 
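The loop above makes at most two attempts per stashed activation: a strict recompute plan first, then, only if the analysis reports room for it, a compromised plan that shrinks rather than eliminates the stashed tensor. A schematic version of that control flow (the checker callback is hypothetical):

```python
def find_recompute_plans(nodes_in_reverse_topo_order, check_node_for_recompute):
    """check_node_for_recompute(node, compromise) -> (plan | None, can_compromise)."""
    plans = []
    for node in nodes_in_reverse_topo_order:
        plan, can_compromise = check_node_for_recompute(node, compromise=False)
        if plan is not None:
            plans.append(plan)
        if can_compromise:
            # Second attempt: allow plans that only shrink, not remove, the stashed tensor.
            plan, _ = check_node_for_recompute(node, compromise=True)
            if plan is not None:
                plans.append(plan)
    return plans

def demo_checker(node, compromise):
    # Toy checker: only matmul_1 has a strict recompute plan.
    return ("recompute:" + node, False) if node == "matmul_1" else (None, False)

print(find_recompute_plans(["gelu_2", "matmul_1"], demo_checker))
```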
  if (node_to_apply_contexts_map.size() > 0) {
    InlinedHashMap<std::string, MemoryRecord*> node_cluster_id_to_record_map;
    for (auto& p : generated_records) {
@@ -411,10 +358,6 @@ void GetMemoryRecordsGroupedByNodeClusterId(const MemoryOptimizationPlanner& mem
      const auto& node = p.first;
      const auto& apply_context = p.second;
      std::string node_cluster_id = memory_opt_planner.GenerateNodeClusterId(node);
-
-      ORT_ENFORCE(node_cluster_id_to_record_map.find(node_cluster_id) != node_cluster_id_to_record_map.end(),
-                  "Node cluster id not found in memory record map: ", node_cluster_id);
-
      if (apply_context->type == OptimizationType::Recompute) {
        node_cluster_id_to_record_map[node_cluster_id]->actual_recompute_count += 1;
        node_cluster_id_to_record_map[node_cluster_id]->request_recompute_count = apply_context->requested_count;
@@ -755,14 +698,20 @@ std::string SerializeMemoryRecords(
 std::string GetSerializedORTModuleMemoryStat(const GraphViewer& graph_viewer,
                                              std::string_view memory_optimization_config,
-                                             std::string_view recompute_probe_config,
+                                             std::string_view recompute_probe_level,
                                              const logging::Logger& logger,
                                              std::map<std::string, std::pair<std::string, int>>&
                                                  cluster_id_combinations_to_saved_symbolic_byte_map,
                                              const OrtValueNameIdxMap* ortvalue_name_to_idx_map,
                                              const SequentialExecutionPlan* p_seq_exec_plan) {
-  ProbeConfig probe_config;
-  ORT_ENFORCE(ParseProbeConfigFromString(recompute_probe_config, probe_config).IsOK());
+  ProbeLevel probe_level = ProbeLevel::Advanced;
+  if (!recompute_probe_level.empty()) {
+    int probe_level_int = ParseIntValueFromString(recompute_probe_level);
+    ORT_ENFORCE(probe_level_int < static_cast<int>(ProbeLevel::LevelMax) &&
+                    probe_level_int >= 0,
+                "Invalid probe level specified: ", recompute_probe_level);
+    probe_level = static_cast<ProbeLevel>(probe_level_int);
+  }

   ptrdiff_t yield_op_order_in_topological_sort;
   InlinedHashMap<const Node*, InlinedVector<size_t>> candidate_output_args_map;
@@ -772,7 +721,7 @@ std::string GetSerializedORTModuleMemoryStat(const GraphViewer& graph_viewer,
   MemoryOptimizationPlanner memory_opt_planner;
   ORT_ENFORCE(FindORTModuleMemoryOpportunity(
                   graph_viewer,
-                  probe_config,
+                  probe_level,
                   logger,
                   node_index_to_its_order_in_topological_sort_map,
                   yield_op_order_in_topological_sort,
@@ -787,7 +736,7 @@ std::string GetSerializedORTModuleMemoryStat(const GraphViewer& graph_viewer,
   NodeToClusterApplyContextMap node_to_apply_context_map;

   if (!memory_optimization_config.empty()) {
-    ORT_ENFORCE(ParseOptimizationConfigFromString(memory_optimization_config, cluster_id_to_config_map)
+    ORT_ENFORCE(ParseConfigFromString(memory_optimization_config, cluster_id_to_config_map)
                     .IsOK());
    InlinedHashMap<const Node*, std::shared_ptr<NodeOptimizationPlanBase>> node_to_opt_plan_map;
    ORT_ENFORCE(memory_opt_planner.FinalizeNodePlansFromUserConfig(cluster_id_to_config_map,
diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.h b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.h
index 3f0a1a9a96f88..c4267efdbea51 100644
--- a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.h
+++ b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.h
@@ -57,21 +57,11 @@ class MemoryRecord {
   int freq = 0;
 };

-/**
- * @brief Reset `__backwardpass` attribute for all backward nodes in the graph.
- * `__backwardpass` is used by Priority-Based topology sorting.
- *
- * @param graph To be scanned and modified.
- * @param modified Whether the graph is modified.
- * @return Status
- */
-Status ResetNodeBackwardPassAttribute(Graph& graph, bool& modified);
-
 /**
  * @brief Iterate the graph and find all possible memory optimization opportunities for related nodes.
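The restored parsing above accepts a single integer probe level and range-checks it against `ProbeLevel::LevelMax`. A Python sketch of the same validation, assuming the enum values Basic=0, Advanced=1 implied by the `ProbeLevel` enum in `recompute_analysis.h` further below:

```python
from enum import IntEnum

class ProbeLevel(IntEnum):  # mirrors the C++ enum: Basic=0, Advanced=1, LevelMax=2
    BASIC = 0
    ADVANCED = 1
    LEVEL_MAX = 2

def parse_probe_level(text: str, default: ProbeLevel = ProbeLevel.ADVANCED) -> ProbeLevel:
    """Parse and range-check a probe level string such as '1'."""
    if not text:
        return default
    value = int(text)
    if not 0 <= value < ProbeLevel.LEVEL_MAX:
        raise ValueError(f"Invalid probe level specified: {text}")
    return ProbeLevel(value)

print(parse_probe_level("1"))
```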
* * @param graph_viewer The graph to iterate. - * @param probe_config The config for recomputable subgraph detecting. + * @param probe_level The level to control allowed operations during recomputable subgraph detecting. * @param logger Logger. * @param node_index_to_its_order_in_topological_sort_map The mapping of node index to its order in topological sort. * @param yield_op_order_in_topological_sort The order of the boundary op in the topological sort. @@ -80,7 +70,7 @@ Status ResetNodeBackwardPassAttribute(Graph& graph, bool& modified); * @return Status */ Status FindORTModuleMemoryOpportunity(const GraphViewer& graph_viewer, - const ProbeConfig& probe_config, + const ProbeLevel probe_level, const logging::Logger& logger, InlinedHashMap& node_index_to_its_order_in_topological_sort_map, diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/optimization_planner.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/optimization_planner.cc index 64e99a4a0bca5..7e042031f66a2 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/optimization_planner.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/optimization_planner.cc @@ -34,7 +34,7 @@ std::string NodeOptimizationPlanBase::GetMemorySavingSymbolicString() const { if (!saving_str.empty()) { saving_str += " + "; } - saving_str = "(" + GetActivationOutputDimParamString(output_index) + " * " + + saving_str = "(" + GetTensorElemCountInSymbolicString(node, output_index) + " * " + std::to_string(byte_count_per_element) + " * " + std::to_string(GetSaveRatio()) + ")"; } diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/optimization_planner.h b/orttraining/orttraining/core/optimizer/memory_optimizer/optimization_planner.h index c585b2810b39d..0e5e2967ec15a 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/optimization_planner.h +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/optimization_planner.h @@ -39,14 +39,6 @@ class NodeOptimizationPlanBase { : node(node), activation_output_indices_(activation_output_indices.begin(), activation_output_indices.end()), save_ratio_(save_ratio) { - activation_output_dim_params_.reserve(activation_output_indices_.size()); - - // Generate dim params once for all outputs to guarantee they are unique across different calls. - // because GetTensorElemCountInSymbolicString called to use a static index_empty_dim - // when generating empty dim param as a string. 
- for (auto output_index : activation_output_indices_) { - activation_output_dim_params_[output_index] = GetTensorElemCountInSymbolicString(node, output_index); - } } virtual ~NodeOptimizationPlanBase() = default; @@ -85,20 +77,12 @@ class NodeOptimizationPlanBase { */ std::string GetMemorySavingSymbolicString() const; - std::string GetActivationOutputDimParamString(size_t index) const { - ORT_ENFORCE(activation_output_dim_params_.find(index) != activation_output_dim_params_.end(), - "activation_output_dim_params_ does not contain index: ", index); - - return activation_output_dim_params_.at(index); - } - const Node* node; // A map: output index reusing other node's output (other_node, output index) InlinedHashMap reuse_buffers; private: InlinedVector activation_output_indices_; - InlinedHashMap activation_output_dim_params_; float save_ratio_ = 1.0f; }; diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc index 52dea571a1eaf..0782cbdae2eec 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc @@ -9,11 +9,8 @@ #include #include "orttraining/core/optimizer/memory_optimizer/common.h" -#include "orttraining/core/optimizer/memory_optimizer/transformer_specific.h" #include "orttraining/core/optimizer/memory_optimizer/recompute_analysis.h" -#include "core/common/string_utils.h" #include "core/framework/data_types.h" -#include "core/optimizer/utils.h" namespace onnxruntime::optimizer::memory_optimizer { @@ -56,7 +53,7 @@ struct AllowedRecomputeNodeConfig { InlinedVector input_arg_indices; // input index to iterate further (bottom up) }; -// The supported op types are predefined. +// The op types that are supported predefined. const InlinedHashMap& GetAllowedRecomputeOps(int probe_op_level) { static InlinedHashMap> recomputable_op_table_map; @@ -79,19 +76,16 @@ const InlinedHashMap& GetAllowedRecompu /// The shape input is trivial whether it exists or not in backward. {"Reshape", AllowedRecomputeNodeConfig{{0}}}, {"Squeeze", AllowedRecomputeNodeConfig{{0}}}, - {"Transpose", AllowedRecomputeNodeConfig{{0}}}, {"Unsqueeze", AllowedRecomputeNodeConfig{{0}}}, // Unary elementwise - {"Dropout", AllowedRecomputeNodeConfig{{0}}}, - {"BiasGelu", AllowedRecomputeNodeConfig{{0, 1}}}, /// The ratio and mode input are trivial whether they exist or not in backward {"BitmaskDropout", AllowedRecomputeNodeConfig{{0}}}, /// The axis input is trivial whether it exists or not in backward {"CumSum", AllowedRecomputeNodeConfig{{0}}}, - {"Expand", AllowedRecomputeNodeConfig{{0}}}, - {"FastGelu", AllowedRecomputeNodeConfig{{0}}}, + {"Dropout", AllowedRecomputeNodeConfig{{0}}}, {"Gelu", AllowedRecomputeNodeConfig{{0}}}, + {"FastGelu", AllowedRecomputeNodeConfig{{0}}}, // Ternary elementwise {"Where", AllowedRecomputeNodeConfig{{0, 1, 2}}}, @@ -99,16 +93,11 @@ const InlinedHashMap& GetAllowedRecompu // Data copy {"Tile", AllowedRecomputeNodeConfig{{0}}}, {"Cast", AllowedRecomputeNodeConfig{{0}}}, - {"ConcatTraining", AllowedRecomputeNodeConfig{{0, 1}}}, // Input could be more than 2. But mostly 2. 
- {"Slice", AllowedRecomputeNodeConfig{{0}}}, - {"Split", AllowedRecomputeNodeConfig{{0}}}, - {"Gather", AllowedRecomputeNodeConfig{{0}}}, }); } if (probe_op_level >= static_cast(ProbeLevel::Advanced)) { recomputable_op_table.insert({ - {"LayerNormalization", AllowedRecomputeNodeConfig{{0, 1, 2}}}, {"MatMul", AllowedRecomputeNodeConfig{{0, 1}}}, {"FusedMatMul", AllowedRecomputeNodeConfig{{0, 1}}}, {"Softmax", AllowedRecomputeNodeConfig{{0}}}, @@ -131,8 +120,7 @@ bool IsRecomputable(const Node& node, ProbeLevel probe_level) { /** * @brief Find recomputable subgraphs (has at least one nodes, at most MAXIMUM_RECOMPUTE_NODE_COUNT nodes). * - * @param entry_node The entry node to start the subgraph matching (bottom-up), usually the last node of found subgraphs. - * @param probe_config The probe config to control recomputable subgraph detecting. + * @param node The entry node to start the subgraph matching (bottom-up), usually the last node of found subgraphs. * @param node_output_index_candidates Candidate output indices of "node", which are consumed by both fw and bw ops. * @param fw_op_output_arg_used_map The activation usage (in fw and bw) mapping. * @param node_index_to_its_order_in_topological_sort_map The mapping of node index to its order in topological sort. @@ -143,13 +131,13 @@ bool IsRecomputable(const Node& node, ProbeLevel probe_level) { * @param compromise_stashed_activation Whether to compromise stashed activation, e.g. if we cannot find a * recomputable subgraph to save a stashed activation, we can compromise to find a recomputable subgraph to reduce the * size of stashed activation. - * @param can_compromise_stashed_activation A bool return value, to indicate there are opportunities for finding a + * @param can_compromise_stashed_activation A bool return value, to indicate there is opportunaties for finding a * compromised subgraph. * @param save_ratio The ratio of memory saving if we can find a recomputable subgraph. 
* @return Status */ Status SelectRecomputeSubgraph(const Node& entry_node, - const ProbeConfig& probe_config, + const ProbeLevel probe_level, const InlinedVector& node_output_index_candidates, const ActivationUsedMap& fw_op_output_arg_used_map, const InlinedHashMap& @@ -159,13 +147,12 @@ Status SelectRecomputeSubgraph(const Node& entry_node, bool compromise_stashed_activation, bool& can_compromise_stashed_activation, float& save_ratio) { - const ProbeLevel probe_level = probe_config.probe_level; const auto& recomputable_op_table = GetAllowedRecomputeOps(static_cast(probe_level)); can_compromise_stashed_activation = false; - MO_LOG_DEBUG_INFO(logger, "Enter SelectRecomputeSubgraph for Node " + entry_node.Name() + - "(" + entry_node.OpType() + ")"); + LOGS(logger, VERBOSE) << "Enter SelectRecomputeSubgraph for Node " << entry_node.Name() << "(" + << entry_node.OpType() << ")"; nodes.clear(); std::deque q; @@ -220,34 +207,33 @@ Status SelectRecomputeSubgraph(const Node& entry_node, // (either of the above checks is true for entry node outputs) if (op_recompute_config_it == recomputable_op_table.end()) { early_stop = true; - MO_LOG_DEBUG_INFO(logger, "Entry Node " + curr_node->Name() + "(" + curr_node->OpType() + - ") is **NOT** in recompute op list, search terminates."); + LOGS(logger, VERBOSE) << "Entry Node " << curr_node->Name() << "(" << curr_node->OpType() << ") is **NOT** " + << "in recompute op list, search terminates."; break; } } else { if (op_recompute_config_it == recomputable_op_table.end()) { if (fw_op_output_arg_used_map.at(cur_output_arg_name).second) { - MO_LOG_DEBUG_INFO(logger, "Node " + curr_node->Name() + "(" + curr_node->OpType() + - ") is **NOT** in recompute op list, but its output [" + - cur_output_arg_name + - "] is used in backward, we don't need trace bottom-up further. Entry node: " + - entry_node.Name() + "(" + entry_node.OpType() + ")"); + LOGS(logger, VERBOSE) << "Node " << curr_node->Name() << "(" << curr_node->OpType() << ") is **NOT** in " + << "recompute op list, but its output [" << cur_output_arg_name << "] is used in " + << "backward, we don't need trace bottom-up further. Entry node: " + << entry_node.Name() << "(" << entry_node.OpType() << ")"; continue; } else { early_stop = true; - MO_LOG_DEBUG_INFO(logger, "Node " + curr_node->Name() + "(" + curr_node->OpType() + ") is **NOT** in " + - "recompute op list, and its output [" + cur_output_arg_name + - "] does not exist in backward, search terminates. Entry node: " + - entry_node.Name() + "(" + entry_node.OpType() + ")"); + LOGS(logger, VERBOSE) << "Node " << curr_node->Name() << "(" << curr_node->OpType() << ") is **NOT** in " + << "recompute op list, and its output [" << cur_output_arg_name + << "] does not exist in backward, search terminates. Entry node: " + << entry_node.Name() << "(" << entry_node.OpType() << ")"; break; } } if (fw_op_output_arg_used_map.at(cur_output_arg_name).second) { - MO_LOG_DEBUG_INFO(logger, "Node " + curr_node->Name() + "(" + curr_node->OpType() + ") " + - "is in recompute op list, while its output [" + cur_output_arg_name + - "] is used in backward, we don't need trace bottom-up further. Entry node: " + - entry_node.Name() + "(" + entry_node.OpType() + ")"); + LOGS(logger, VERBOSE) << "Node " << curr_node->Name() << "(" << curr_node->OpType() << ") " + << "is in recompute op list, while its output [" << cur_output_arg_name + << "] is used in backward, we don't need trace bottom-up further. 
Entry node: " + << entry_node.Name() << "(" << entry_node.OpType() << ")"; continue; } } @@ -255,8 +241,8 @@ Status SelectRecomputeSubgraph(const Node& entry_node, // Append node to the selected graph. if (std::find(nodes.begin(), nodes.end(), curr_node) == nodes.end()) { nodes.push_back(curr_node); - MO_LOG_DEBUG_INFO(logger, "Node " + curr_node->Name() + "(" + curr_node->OpType() + - ") is added in selected subgraph"); + LOGS(logger, VERBOSE) << "Node " << curr_node->Name() << "(" << curr_node->OpType() + << ") is added in selected subgraph "; } // This check is not matured now, subject to change. @@ -265,16 +251,15 @@ Status SelectRecomputeSubgraph(const Node& entry_node, float is_current_node_compromisable = (ratio < 1.f); can_compromise_stashed_activation = can_compromise_stashed_activation || is_current_node_compromisable; if (is_current_node_compromisable) { - MO_LOG_DEBUG_INFO(logger, "Node " + curr_node->Name() + "(" + curr_node->OpType() + - ") has input/output size " + std::to_string(ratio) + - " < 1.f, can compromise stashed activation"); + LOGS(logger, VERBOSE) << "Node " << curr_node->Name() << "(" << curr_node->OpType() + << ") has input/output size " << ratio << " < 1.f, can compromise stashed activation"; } if (is_current_node_compromisable && compromise_stashed_activation) { - MO_LOG_DEBUG_INFO(logger, "Node " + curr_node->Name() + "(" + curr_node->OpType() + ") is in " + - "recompute op list, and its output [" + cur_output_arg_name + - "] does not exist in backward, while it meets compromised check, we don't need trace " + - "bottom-up further."); + LOGS(logger, VERBOSE) << "Node " << curr_node->Name() << "(" << curr_node->OpType() << ") is in " + << "recompute op list, and its output [" << cur_output_arg_name + << "] does not exist in backward, while it meets compromised check, we don't need trace " + << "bottom-up further."; save_ratio = saving_ratio; continue; } @@ -290,10 +275,10 @@ Status SelectRecomputeSubgraph(const Node& entry_node, input_arg_indices.end()) { NodeOutputPort next_p = std::make_pair(&parent_node, parent_node_output_index); - MO_LOG_DEBUG_INFO(logger, "Node " + parent_node.Name() + "(" + parent_node.OpType() + ")'s " + - std::to_string(parent_node_output_index) + "th output [" + - parent_node.OutputDefs()[parent_node_output_index]->Name() + - "] is added in recompute search list"); + LOGS(logger, VERBOSE) << "Node " << parent_node.Name() << "(" << parent_node.OpType() << ")'s " + << parent_node_output_index + << "th output [" << parent_node.OutputDefs()[parent_node_output_index]->Name() + << "] is added in recompute search list "; q.push_back(next_p); } @@ -305,9 +290,8 @@ Status SelectRecomputeSubgraph(const Node& entry_node, // If input args are not found in bw, but op count exceed MAXIMUM_RECOMPUTE_NODE_COUNT, skip recompute. if (!q.empty() || early_stop) { - MO_LOG_DEBUG_INFO(logger, "Fail to find a solution for recompute: current node count is " + - std::to_string(nodes.size()) + ", queue size: " + std::to_string(q.size()) + - ", early stop: " + std::to_string(early_stop)); + LOGS(logger, VERBOSE) << "Fail to find a solution for recompute: current node count is " << nodes.size() + << ", queue size: " << q.size() << ", early stop: " << early_stop; nodes.clear(); } else { // Re-order the nodes in topological order. 
@@ -351,75 +335,24 @@ void NodesInTopoOrderToString(gsl::span<const Node* const> nodes_in_topological_
 }  // namespace

-Status ParseProbeConfigFromString(std::string_view recompute_probe_config, ProbeConfig& probe_config) {
-  int transformer_layer_as_boundary = 0;
-  if (!recompute_probe_config.empty()) {
-    const auto probe_configs = utils::SplitString(recompute_probe_config, ":");
-    ORT_ENFORCE(probe_configs.size() >= 1, "Probe config information is not complete.");
-    int probe_level_int = ParseIntValueFromString(probe_configs[0]);
-    ORT_ENFORCE(probe_level_int < static_cast<int>(ProbeLevel::LevelMax) &&
-                    probe_level_int >= 0,
-                "Invalid probe level specified: ", probe_configs[0]);
-
-    if (probe_configs.size() > 1) {
-      transformer_layer_as_boundary = ParseIntValueFromString(probe_configs[1]);
-      ORT_ENFORCE(transformer_layer_as_boundary == 0 || transformer_layer_as_boundary == 1,
-                  "Invalid transformer_layer_as_boundary specified: ", probe_configs[1]);
-    }
-
-    probe_config.probe_level = static_cast<ProbeLevel>(probe_level_int);
-  }
-
-  probe_config.enable_transformer_layer_as_boundary = transformer_layer_as_boundary == 1;
-
-  return Status::OK();
-}
-
-std::unique_ptr<NodeRecomputePlan> CheckNodeForRecompute(const GraphViewer& graph_viewer,
-                                                         const Node& node,
-                                                         const ProbeConfig& probe_config,
+std::unique_ptr<NodeRecomputePlan> CheckNodeForRecompute(const Node& node,
+                                                         const ProbeLevel probe_level,
                                                          const ActivationUsedMap& fw_op_output_arg_used_map,
                                                          const InlinedHashMap<NodeIndex, ptrdiff_t>&
                                                              node_index_to_its_order_in_topological_sort_map,
                                                          const InlinedHashMap<const Node*, InlinedVector<size_t>>&
                                                              candidate_output_args_map,
-                                                         const InlinedHashSet<const Node*>& layer_boundary_ln_nodes,
                                                          const logging::Logger& logger,
                                                          bool compromise_stashed_activation,
                                                          bool& can_compromise_stashed_activation) {
-  if (!IsRecomputable(node, probe_config.probe_level)) {
+  if (!IsRecomputable(node, probe_level)) {
     return nullptr;
   }

-  if (probe_config.enable_transformer_layer_as_boundary) {
-    // Check whether the node's stashed activation outputs are used by LayerNormalization's inputs.
-    // If yes, for Transformers, we don't need to recompute the node, because we treated
-    // LayerNormalization of Attention as the boundary for subgraph searching.
-    // Check at least one of the stashed activation output is used as the 1st input
-    // of LayerNormalization, e.g. will be used as input of LayerNormalizationGrad.
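The deleted `ParseProbeConfigFromString` accepted `<probe_level>[:<transformer_layer_as_boundary>]`, e.g. `1:0` or `1:1`. A rough Python equivalent; the defaults mirror a default-constructed `ProbeConfig`, which is an assumption here:

```python
def parse_probe_config(text: str):
    """Parse '<probe_level>[:<layer_as_boundary>]', e.g. '1:0' or '1:1'."""
    level, boundary = 0, 0  # assumed defaults: Basic level, boundary off
    if text:
        parts = text.split(":")
        level = int(parts[0])
        if not 0 <= level <= 1:
            raise ValueError(f"Invalid probe level specified: {parts[0]}")
        if len(parts) > 1:
            boundary = int(parts[1])
            if boundary not in (0, 1):
                raise ValueError(f"Invalid transformer_layer_as_boundary: {parts[1]}")
    return level, boundary == 1

print(parse_probe_config("1:1"))
```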
- for (auto& output_index : candidate_output_args_map.at(&node)) { - auto output_name = node.OutputDefs()[output_index]->Name(); - auto consumers = graph_viewer.GetConsumerNodes(output_name); - for (auto& consumer : consumers) { - if (layer_boundary_ln_nodes.find(consumer) != layer_boundary_ln_nodes.end()) { - int dest_in_index = optimizer_utils::IndexOfNodeInput(*consumer, *node.OutputDefs()[output_index]); - if (dest_in_index == 0) { - LOGS(logger, INFO) << "Node " << node.Name() << "(" << node.OpType() - << ") is a Attention+MLP layer boundary node, " - << "its stashed activation outputs are used by LayerNormalization's inputs, " - << "we don't need to recompute it."; - return nullptr; - } - } - } - } - } - InlinedVector nodes_in_topological_order; float save_ratio = 1.f; ORT_ENFORCE(SelectRecomputeSubgraph(node, - probe_config, + probe_level, candidate_output_args_map.at(&node), fw_op_output_arg_used_map, node_index_to_its_order_in_topological_sort_map, @@ -436,7 +369,7 @@ std::unique_ptr CheckNodeForRecompute(const GraphViewer& grap std::string subgraph_str_representation, log_info; NodesInTopoOrderToString(nodes_in_topological_order, subgraph_str_representation, log_info); - MO_LOG_DEBUG_INFO(logger, "Node " + node.Name() + "(" + node.OpType() + ") can be recomputed" + log_info); + LOGS(logger, VERBOSE) << "Node " << node.Name() << "(" << node.OpType() << ") can be recomputed" << log_info; return std::make_unique(&node, candidate_output_args_map.at(&node), nodes_in_topological_order, @@ -455,7 +388,7 @@ std::string NodeRecomputePlan::NormalizeForNodeClusterId() const { oss << "recompute:" << node->OpType() << "-" << compromise_recompute_ << "-"; for (auto& output_index : GetActivationOutputIndices()) { - oss << output_index << ":" << GetActivationOutputDimParamString(output_index); + oss << output_index << ":" << GetTensorElemCountInSymbolicString(node, output_index); oss << ":" << node->OutputDefs()[output_index]->TypeAsProto()->tensor_type().elem_type() << "-"; } diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.h b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.h index d9693835313b8..9211e5044cd86 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.h +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.h @@ -22,25 +22,6 @@ enum class ProbeLevel { LevelMax = 2, }; -/** - * @brief Configuration to control recompute subgraph detection. - */ -class ProbeConfig { - public: - ProbeConfig() = default; - - ProbeConfig(ProbeLevel level, bool transformer_layer_as_boundary = false) { - probe_level = level; - enable_transformer_layer_as_boundary = transformer_layer_as_boundary; - } - - ProbeLevel probe_level{ProbeLevel::Basic}; - bool enable_transformer_layer_as_boundary{false}; -}; - -Status ParseProbeConfigFromString(std::string_view recompute_probe_config, - ProbeConfig& probe_config); - /** * @brief A child class used for Recompute/RecomputeWithCompromise optimization plan. * @@ -94,15 +75,13 @@ class NodeRecomputePlan : public NodeOptimizationPlanBase { /** * @brief For the node producing stashed activation, check whether a recomputable subgraph can be found or not. * - * @param graph_viewer The graph viewer to get node information. * @param node The entry node to start the subgraph matching (bottom-up), usually the last node of found subgraphs. - * @param probe_config The config for subgraph detecting. 
+ * @param probe_level The level to control allowed operations during subgraph detecting. * @param fw_op_output_arg_used_map The activation usage (in fw and bw) mapping. * @param node_index_to_its_order_in_topological_sort_map The mapping of node index to its order in topological sort. * Used to re-order the collected subgraph nodes. * @param candidate_output_args_map A map from node to its candidate activations, which are consumed by both fw and * bw ops. - * @param layer_boundary_ln_nodes A set of LayerNormalization nodes, which are used as the boundary for subgraph. * @param subgraph_stores A store to maintain all found subgraphs. * @param logger Logger. * @param compromise_stashed_activation Whether to compromise stashed activation, e.g. if we cannot find a @@ -111,15 +90,13 @@ class NodeRecomputePlan : public NodeOptimizationPlanBase { * @param can_compromise_stashed_activation A bool return value, to indicate there is opportunaties for finding a * compromised subgraph. */ -std::unique_ptr CheckNodeForRecompute(const GraphViewer& graph_viewer, - const Node& node, - const ProbeConfig& probe_config, +std::unique_ptr CheckNodeForRecompute(const Node& node, + const ProbeLevel probe_level, const ActivationUsedMap& fw_op_output_arg_used_map, const InlinedHashMap& node_index_to_its_order_in_topological_sort_map, const InlinedHashMap>& candidate_output_args_map, - const InlinedHashSet& layer_boundary_ln_nodes, const logging::Logger& logger, bool compromise_stashed_activation, bool& can_compromise_stashed_activation); diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc deleted file mode 100644 index 04f2679ac774f..0000000000000 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include -#include -#include - -#include "orttraining/core/optimizer/memory_optimizer/common.h" -#include "orttraining/core/optimizer/memory_optimizer/transformer_specific.h" -#include "core/graph/graph_utils.h" -#include "core/optimizer/utils.h" -#include "core/graph/graph_viewer.h" -#include "core/framework/tensorprotoutils.h" - -#include "core/common/string_utils.h" - -namespace onnxruntime::optimizer::memory_optimizer { - -void FindLayerBoundaryLayerNormNodes( - const GraphViewer& graph_viewer, - const logging::Logger&, - InlinedHashSet& layer_boundary_ln_nodes) { - // Loop all nodes to find LayerNormalization nodes. - // For each LayerNormalization node, keep checking its output nodes, - // until find a node that is Softmax or BiasSoftmax or another LayerNormalization. - // If the found node is Softmax or BiasSoftmax, the LayerNormalization node as ATTENTION. - // If the found node is another LayerNormalization, the LayerNormalization node as MLP. 
- const InlinedHashSet softmax_ops{"Softmax", "BiasSoftmax"}; - const InlinedHashSet layernorm_ops{"LayerNormalization", "SkipLayerNormalization"}; - - layer_boundary_ln_nodes.clear(); - const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED); - for (auto node_index : node_topology_list) { - auto& node = *graph_viewer.GetNode(node_index); - - if (layernorm_ops.find(node.OpType()) == layernorm_ops.end()) { - continue; - } - - std::deque nodes_to_check; - std::set visited_nodes; - for (auto node_it = node.OutputNodesBegin(); node_it != node.OutputNodesEnd(); ++node_it) { - nodes_to_check.push_back(&(*node_it)); - } - - while (!nodes_to_check.empty()) { - const Node* next_node = nodes_to_check.front(); - nodes_to_check.pop_front(); - - if (visited_nodes.find(next_node) != visited_nodes.end()) { - continue; - } - - visited_nodes.insert(next_node); - if (softmax_ops.find(next_node->OpType()) != softmax_ops.end()) { - layer_boundary_ln_nodes.insert(&node); - break; - } else if (layernorm_ops.find(next_node->OpType()) != layernorm_ops.end()) { - break; - } else { - for (auto node_it = next_node->OutputNodesBegin(); node_it != next_node->OutputNodesEnd(); ++node_it) { - nodes_to_check.push_back(&(*node_it)); - } - } - } - } -} - -} // namespace onnxruntime::optimizer::memory_optimizer diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.h b/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.h deleted file mode 100644 index f2cfd640b0840..0000000000000 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.h +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
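The deleted detector above classifies each LayerNormalization by what it reaches first when walking its consumers: a Softmax-family op marks it as the attention-side layer boundary, while hitting another LayerNormalization first means it sits on the MLP side and is skipped. A compact sketch over an adjacency map (toy graph, not the ORT `GraphViewer` API):

```python
from collections import deque

SOFTMAX_OPS = {"Softmax", "BiasSoftmax"}
LAYERNORM_OPS = {"LayerNormalization", "SkipLayerNormalization"}

def find_layer_boundary_layernorms(op_types, consumers):
    """op_types: node -> op type; consumers: node -> downstream nodes."""
    boundaries = set()
    for node, op in op_types.items():
        if op not in LAYERNORM_OPS:
            continue
        queue, visited = deque(consumers.get(node, [])), set()
        while queue:
            nxt = queue.popleft()
            if nxt in visited:
                continue
            visited.add(nxt)
            if op_types[nxt] in SOFTMAX_OPS:
                boundaries.add(node)  # LN feeding attention: treat as layer boundary
                break
            if op_types[nxt] in LAYERNORM_OPS:
                break  # reached the next LN without a softmax: MLP-side LN
            queue.extend(consumers.get(nxt, []))
    return boundaries

ops = {"ln1": "LayerNormalization", "mm": "MatMul", "sm": "Softmax"}
print(find_layer_boundary_layernorms(ops, {"ln1": ["mm"], "mm": ["sm"]}))
```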
- -#pragma once - -#include -#include -#include -#include - -#include "core/common/common.h" -#include "core/common/logging/logging.h" -#include "core/common/inlined_containers_fwd.h" -#include "core/graph/basic_types.h" -#include "core/framework/data_types.h" -#include "core/graph/graph_viewer.h" -#include "orttraining/core/optimizer/memory_optimizer/common.h" - -namespace onnxruntime::optimizer::memory_optimizer { - -void FindLayerBoundaryLayerNormNodes(const GraphViewer& graph_viewer, - const logging::Logger& logger, - InlinedHashSet& layer_boundary_ln_nodes); - -} // namespace onnxruntime::optimizer::memory_optimizer diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index 76943b954837b..dd6d5a568cb18 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -37,7 +37,7 @@ from ._runtime_inspector import RuntimeInspector from ._utils import check_function_has_param, get_rank from ._zero_stage3_compatibility import stage3_export_context -from .options import DebugOptions, LogLevel, _MemoryOptimizationLevel, _RuntimeOptions +from .options import DebugOptions, LogLevel, _RuntimeOptions from .torch_cpp_extensions.cpu.aten_op_executor import load_aten_op_executor_cpp_extension @@ -650,7 +650,10 @@ def _log_feature_stats(self): if get_rank() != 0: return - tbl = PTable(sortable=True) + if self._runtime_inspector.memory_ob.is_enabled() and self._debug_options.log_level <= LogLevel.DEVINFO: + self._logger.info(self._runtime_inspector.memory_ob.memory_optimization_opportunity_table_str) + + tbl = PTable() def _add_record(tbl, columns): return tbl.add_row([columns[0], ":", "ON" if columns[1] else "OFF", ":", columns[2]]) @@ -675,35 +678,29 @@ def _add_record(tbl, columns): ], ) - if self._runtime_options.memory_optimization_level == _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE: - opt_config_to_display = "ALL_RECOMPUTE_FOR_EACH_LAYER" - else: - opt_config_to_display = self._runtime_options.memory_optimizer_config - + output_memory_optimization_details = self._debug_options.log_level <= LogLevel.INFO mem_row = _add_record( tbl, [ "Memory Optimizer", len(self._runtime_options.memory_optimizer_config) > 0, ( - f"Memory Optimization Level: [{_MemoryOptimizationLevel.to_string(self._runtime_options.memory_optimization_level)}], " - f"Optimization Config: [{opt_config_to_display}]" + f"User config: {self._runtime_options.memory_optimizer_config}, probe level: {self._runtime_options.probe_level}" if len(self._runtime_options.memory_optimizer_config) > 0 - else "Enable with env ORTMODULE_MEMORY_OPT_LEVEL=1 or ORTMODULE_MEMORY_OPT_CONFIG=,,..." 
+ else "Enable with env ORTMODULE_MEMORY_OPT_CONFIG=" ), ], ) - if self._runtime_inspector.memory_ob.is_enabled() and self._debug_options.logging.log_level < LogLevel.WARNING: + if self._runtime_inspector.memory_ob.is_enabled() and output_memory_optimization_details: mem_notes, mem_tbl = self._runtime_inspector.memory_ob.display_memory_optimization_plans( - self._runtime_options.memory_optimizer_config, - details=True, + self._runtime_options.memory_optimizer_config ) if mem_tbl is not None: mem_row.append_annotation_table(mem_tbl) notes.extend(mem_notes) - compute_opt_row = _add_record( + _add_record( tbl, [ "Compute Optimizer", @@ -711,12 +708,10 @@ def _add_record(tbl, columns): "Enable/Disable with env ORTMODULE_ENABLE_COMPUTE_OPTIMIZER=1/0", ], ) - - compute_opt_annotation_tbl = PTable() _add_record( - compute_opt_annotation_tbl, + tbl, [ - " - FLOP Reduction", + " - FLOPReduction", self._runtime_options.enable_compute_optimizer, "Reduce FLOPs by upstreaming shrinking-sized ops", ], @@ -725,18 +720,14 @@ def _add_record(tbl, columns): if self._runtime_options.enable_compute_optimizer: if len(self._runtime_options.label_sparsity_ratio) > 0: _add_record( - compute_opt_annotation_tbl, - [" - Label Sparsity Opt", True, f"Input density: {self._runtime_options.label_sparsity_ratio}"], + tbl, [" - LabelSparsityOpt", True, f"Input density: {self._runtime_options.label_sparsity_ratio}"] ) if len(self._runtime_options.embed_sparsity_ratio) > 0: _add_record( - compute_opt_annotation_tbl, - [" - Embed Sparsity Opt", True, f"Input density: {self._runtime_options.embed_sparsity_ratio}"], + tbl, [" - EmbedSparsityOpt", True, f"Input density: {self._runtime_options.embed_sparsity_ratio}"] ) - compute_opt_row.append_annotation_table(compute_opt_annotation_tbl) - # Add fallback _add_record( tbl, @@ -748,7 +739,7 @@ def _add_record(tbl, columns): ) # Add Triton - triton_row = _add_record( + _add_record( tbl, [ "TritonOp Enabled", @@ -757,16 +748,14 @@ def _add_record(tbl, columns): ], ) - triton_annotation_tbl = PTable() - if self._runtime_options.enable_tuning: desc = "Enable tunning Ops online" if self._runtime_options.tuning_results_path: desc += f", save tuning results to {self._runtime_options.tuning_results_path}" - _add_record(triton_annotation_tbl, ["Online Op Tuning", True, desc]) + _add_record(tbl, ["Online Op Tuning", True, desc]) elif self._runtime_options.tuning_results_path: _add_record( - triton_annotation_tbl, + tbl, [ "Offline Op Tuning", True, @@ -774,8 +763,6 @@ def _add_record(tbl, columns): ], ) - triton_row.append_annotation_table(triton_annotation_tbl) - _add_record( tbl, [ diff --git a/orttraining/orttraining/python/training/ortmodule/_onnx_models.py b/orttraining/orttraining/python/training/ortmodule/_onnx_models.py index d687bc24384ed..ac09c838af838 100644 --- a/orttraining/orttraining/python/training/ortmodule/_onnx_models.py +++ b/orttraining/orttraining/python/training/ortmodule/_onnx_models.py @@ -25,7 +25,7 @@ class ONNXModels: 1. exported_model: Model that is exported by torch.onnx.export 2. optimized_model: For eval mode it's exported_model with concrete input shapes set if needed, - for training mode, it's an optimized model after the gradients graph has been built. + for training mode, it's optimized model after gradients graph has been built. In addition, ORTModule also saves two other models, to the user-provided path: a. the pre_grad_model which is the model before the gradients graph is built. b. the execution_model which is the model that is being executed by ORT. 
diff --git a/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py b/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py index 078ce4d27cd6f..05a5f30683824 100644 --- a/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py +++ b/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py @@ -17,7 +17,6 @@ from onnxruntime.training.utils import PTable from ._execution_agent import TrainingAgent -from .options import _MemoryOptimizationLevel, _RuntimeOptions class Phase(IntEnum): @@ -530,26 +529,20 @@ def collect_symbolic_dim_values( dim_idx ] - def find_memory_optimization_opportunity(self, execution_agent: TrainingAgent, runtime_options: _RuntimeOptions): + def find_memory_optimization_opportunity( + self, execution_agent: TrainingAgent, memory_optimizer_config, probe_level + ): """Find memory optimization opportunity. Args: execution_agent: TrainingAgent. - runtime_options: Runtime options. + memory_optimizer_config: Memory optimization config. + probe_level: Memory probe level. """ - - recompute_probe_config = runtime_options.recompute_probe_config - memory_optimizer_config = runtime_options.memory_optimizer_config - - # If the memory optimization level is aggressive, we will first collect all - # recompute subgraph by passing empty memory_optimizer_config to get_serialized_ortmodule_memory_stat. - if runtime_options.memory_optimization_level == _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE: - memory_optimizer_config = "" - ( self.memory_optimization_opportunity_table_str, memory_optimization_saving_symbolics, - ) = execution_agent.get_serialized_ortmodule_memory_stat(memory_optimizer_config, recompute_probe_config) + ) = execution_agent.get_serialized_ortmodule_memory_stat(memory_optimizer_config, probe_level) cluster_id_to_saving_symbol_map: Dict[str, MemoryOptimizationSummary] = {} for cluster_id, memory_saving_stat in memory_optimization_saving_symbolics.items(): @@ -578,20 +571,6 @@ def find_memory_optimization_opportunity(self, execution_agent: TrainingAgent, r for cluster_id, values in sorted_list: self.cluster_id_combination_to_saving_symbolics_map[cluster_id] = values - # For aggressive memory optimization, we update the memory_optimizer_config using all. - if runtime_options.memory_optimization_level == _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE: - recompute_configs = [] - for cluster_id in self.cluster_id_combination_to_saving_symbolics_map: - config_values = cluster_id.split(":") - opt_type = int(config_values[1]) - # TODO(pengwa): use enum instead of 1 here. - if opt_type != 1: - continue - - recompute_configs.append(cluster_id) - - runtime_options.memory_optimizer_config = ",".join(recompute_configs) - def inspect_memory(self, cur_phase: Phase): """Inspect memory usage and print statistics. 
@@ -611,7 +590,7 @@ def inspect_memory(self, cur_phase: Phase): if self._rank != 0: return - if cur_phase < Phase.PRE_FORWARD or (cur_phase > Phase.POST_BACKWARD): + if cur_phase < Phase.PRE_FORWARD or (cur_phase <= self._last_phase): raise RuntimeError(f"Invalid phase detected: {cur_phase}, last_phase: {self._last_phase}") if (cur_phase - self._pre_phase) != 1: @@ -658,13 +637,12 @@ def _increase_step(self): def _normalize(self, mem_size_in_bytes: Union[float, int]) -> str: return f"{float(mem_size_in_bytes) / MemoryObserver.NORMALIZER_FACTOR:.0f}" - def display_memory_optimization_plans(self, memory_optimizer_config, details=False) -> Tuple[List[str], PTable]: + def display_memory_optimization_plans(self, memory_optimizer_config) -> Tuple[List[str], PTable]: mem_plan_count = len(self.cluster_id_combination_to_saving_symbolics_map) if mem_plan_count > 0: mem_tbl = PTable() - if details: - mem_tbl.add_row(["", "", "", "", "Configs", "Freq", "Max Saving(Bytes)", "Saving Symbolic(Bytes)"]) + mem_tbl.add_row(["", "", "", "", "Configs", "Freq", "Max Saving(Bytes)", "Saving Symbolic(Bytes)"]) index = 1 @@ -682,9 +660,7 @@ def _get_user_config_without_freq(configs: str): return configs_with_out_freq - user_configs_with_out_freq = [] - if memory_optimizer_config: - user_configs_with_out_freq = _get_user_config_without_freq(memory_optimizer_config) + user_configs_with_out_freq = _get_user_config_without_freq(memory_optimizer_config) for ( cluster_id, @@ -705,28 +681,26 @@ def _get_user_config_without_freq(configs: str): else "OFF", ":", cluster_id, - saving_symbolic.freq if details else "", - saving_bytes if details else "", - saving_symbolic.simplified_symbolic_saving_expr if details else "", + saving_symbolic.freq, + saving_bytes, + saving_symbolic.simplified_symbolic_saving_expr, ] ) index += 1 - notes = [] - if details: - notes.append( - "[Memory Optimizer] Use ORTMODULE_MEMORY_OPT_LEVEL=1 to enable all recomputable subgraphs per transformer layer." - ) - saving_recommendation = "[Memory Optimizer] Or use comma as a delimiter to selectively enable multiple memory optimization plans:\n" - saving_recommendation += " export ORTMODULE_MEMORY_OPT_CONFIG=,,..." + saving_recommendation = ( + "use comma as delimiter to enable multiple memory optimization plans at the same time:\n" + ) + saving_recommendation += " export ORTMODULE_MEMORY_OPT_CONFIG=,,..." 
- notes.append(saving_recommendation) + notes = [] + notes.append(saving_recommendation) - saving_recommendation = "memory saving is calculated based on the 1st batch symbolic dim values:\n" - for dim_param, dim_value in self.symbolic_dim_name_to_value_map.items(): - saving_recommendation += f" {dim_param}={dim_value}," - notes.append(saving_recommendation) + saving_recommendation = "memory saving is calculated based on the 1st batch symbolic dim values:\n" + for dim_param, dim_value in self.symbolic_dim_name_to_value_map.items(): + saving_recommendation += f" {dim_param}={dim_value}," + notes.append(saving_recommendation) return notes, mem_tbl diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py index 5b2c673ce94cb..96a95557bb9a1 100644 --- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py @@ -18,7 +18,7 @@ from ._gradient_accumulation_manager import GradientAccumulationManager from ._graph_execution_manager import GraphExecutionManager, _RunStateInfo from ._io import _FlattenedModule, _InputInfo, unflatten_user_output -from ._logger import ORTModuleInitPhase, TrackTime +from ._logger import LogLevel, ORTModuleInitPhase, TrackTime from ._runtime_inspector import Phase from ._utils import save_tuning_results, set_tuning_results from .graph_optimizer_registry import GraphOptimizerRegistry @@ -432,9 +432,11 @@ def _create_execution_agent(self): local_device_rank = self._device.index if device_type == "ort" else _utils.get_device_index(self._device) + # When log level is <= INFO, we would collect memory optimization opportunities. + # (TODO: consider to enable by default once memory optimization feature is stable and well improved.) # Create a training agent without enabling memory optimization here is beneficial for memory analyzing # when we have an allocation plan in place, and reuse information is available. - if self._runtime_inspector.memory_ob.is_enabled(): + if self._runtime_inspector.memory_ob.is_enabled() and self._debug_options.log_level <= LogLevel.INFO: # Create a training agent without enabling memory optimization. execution_agent = TrainingAgent( self._onnx_models.optimized_model.SerializeToString(), @@ -449,7 +451,7 @@ def _create_execution_agent(self): ) self._runtime_inspector.memory_ob.find_memory_optimization_opportunity( - execution_agent, self._runtime_options + execution_agent, self._runtime_options.memory_optimizer_config, self._runtime_options.probe_level ) # Release it as early as possible. 
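The hunk just below renames the probe session config key. For reference, setting the two entries on a plain `onnxruntime.SessionOptions` looks like this; the key strings are taken verbatim from the restored code, while the values are example configs in the formats discussed earlier:

```python
import onnxruntime as ort

session_options = ort.SessionOptions()
# Keys copied from the hunk below; values follow the formats shown earlier.
session_options.add_session_config_entry(
    "optimization.memory_optimizer_config", "BiasGelu+:1:-1")
session_options.add_session_config_entry(
    "optimization.enable_memory_probe_recompute_level", "1")
```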
@@ -460,7 +462,7 @@ def _create_execution_agent(self): "optimization.memory_optimizer_config", self._runtime_options.memory_optimizer_config ) session_options.add_session_config_entry( - "optimization.enable_memory_probe_recompute_config", self._runtime_options.recompute_probe_config + "optimization.enable_memory_probe_recompute_level", self._runtime_options.probe_level ) self._execution_agent = TrainingAgent( diff --git a/orttraining/orttraining/python/training/ortmodule/options.py b/orttraining/orttraining/python/training/ortmodule/options.py index a93f6413b7ab4..ffa3f4afa7b30 100644 --- a/orttraining/orttraining/python/training/ortmodule/options.py +++ b/orttraining/orttraining/python/training/ortmodule/options.py @@ -192,23 +192,6 @@ def is_disabled(self): return _SkipCheck.SKIP_CHECK_DISABLED in self -class _MemoryOptimizationLevel(IntFlag): - """Enumeration to specify memory optimization level""" - - USER_SPECIFIED = 0 # Fully respect user-specified config - TRANSFORMER_LAYERWISE_RECOMPUTE = 1 # Enable all recomputable subgraphs per layer - - @staticmethod - def to_string(memory_optimization_level): - if memory_optimization_level == _MemoryOptimizationLevel.USER_SPECIFIED: - return "USER_SPECIFIED" - - if memory_optimization_level == _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE: - return "TRANSFORMER_LAYERWISE_RECOMPUTE" - - return "" - - class _RuntimeOptions: """Configurable runtime options for ORTModule.""" @@ -274,13 +257,8 @@ def __init__(self, logger: Logger): self.enable_embedding_sparse_optimizer = False # TODO(pengwa): remove once validation on more models are done. # Configuration for memory optimization. - self.memory_optimization_level = ( - _MemoryOptimizationLevel.USER_SPECIFIED - ) # 0: use `memory_optimizer_config`; 1: aggressive optimization, enable all recomputable subgraphs. - self.memory_optimizer_config = "" # This is an advanced config, please refer to onnxruntime docs for details. - # 1 is the op set level; 0 indicates whether consider the Transformer-based model's layer boundary when - # detecting recompute subgraphs. - self.recompute_probe_config = "1:0" + self.memory_optimizer_config = "" + self.probe_level = "1" # Configuration for dev tools. self.print_input_density = False @@ -338,13 +316,8 @@ def _override_from_env_vars(self): ) # Configuration for memory optimization. - self.memory_optimization_level = int(os.getenv("ORTMODULE_MEMORY_OPT_LEVEL", self.memory_optimization_level)) - user_given_memory_optimizer_config = os.getenv("ORTMODULE_MEMORY_OPT_CONFIG", self.memory_optimizer_config) - self.memory_optimizer_config = ",".join([c for c in user_given_memory_optimizer_config.split(",") if c]) - if self.memory_optimization_level == _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE: - # For transformer layer-wise recompute, we enable layer boundary when detecting subgraphs. - # Then all detected subgraphs will not cross different layers. - self.recompute_probe_config = "1:1" + self.memory_optimizer_config = os.getenv("ORTMODULE_MEMORY_OPT_CONFIG", self.memory_optimizer_config) + self.probe_level = os.getenv("ORTMODULE_MEMORY_OPT_PROBE_RECOMPUTE_LEVEL", self.probe_level) # Configuration for dev tools. 
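The environment override at the end of the options hunk above reduces to two `os.getenv` reads; a minimal reproduction, with variable names from the diff and defaults as in `_RuntimeOptions`:

```python
import os

memory_optimizer_config = os.getenv("ORTMODULE_MEMORY_OPT_CONFIG", "")
probe_level = os.getenv("ORTMODULE_MEMORY_OPT_PROBE_RECOMPUTE_LEVEL", "1")
print(f"config={memory_optimizer_config!r}, probe_level={probe_level!r}")
```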
if "ORTMODULE_PRINT_INPUT_DENSITY" in os.environ: diff --git a/orttraining/orttraining/python/training/utils/ptable.py b/orttraining/orttraining/python/training/utils/ptable.py index 5e06864800666..3b3b80d29ed92 100644 --- a/orttraining/orttraining/python/training/utils/ptable.py +++ b/orttraining/orttraining/python/training/utils/ptable.py @@ -20,10 +20,9 @@ def append_annotation_table(self, ptable) -> None: class PTable: """A table that can be printed to the console.""" - def __init__(self, sortable=False) -> None: + def __init__(self) -> None: self._rows: List[Row] = [] self._column_count = None - self._sortable = sortable # allow the rows to be sorted by the first column def add_row(self, columns: List[str]) -> Row: """Add a row to the table. The number of columns must match the number of columns in the table.""" @@ -36,9 +35,6 @@ def add_row(self, columns: List[str]) -> Row: def get_string(self, first_column_width=None, second_column_width=None) -> str: """Serialize the table to a string.""" - if len(self._rows) == 0: - return "" - # Collect the max width of each column column_widths = [] for row in self._rows: @@ -56,12 +52,7 @@ def get_string(self, first_column_width=None, second_column_width=None) -> str: column_widths[2] = max(second_column_width, column_widths[2]) serialized_table = "" - if self._sortable: - sorted_rows = sorted(self._rows, key=lambda row: row._columns[0]) - else: - sorted_rows = self._rows - - for row in sorted_rows: + for row in self._rows: for i, column in enumerate(row._columns): serialized_table += f"{str(column).ljust(column_widths[i] + 2)}" serialized_table += "\n" diff --git a/orttraining/orttraining/test/optimizer/memory_optimizer_test.cc b/orttraining/orttraining/test/optimizer/memory_optimizer_test.cc index 22f1da1327547..a7a246519419a 100644 --- a/orttraining/orttraining/test/optimizer/memory_optimizer_test.cc +++ b/orttraining/orttraining/test/optimizer/memory_optimizer_test.cc @@ -26,9 +26,7 @@ #include "test/capturing_sink.h" #include "test/test_environment.h" #include "test/util/include/asserts.h" -#include "orttraining/core/optimizer/memory_optimizer/common.h" -#include "orttraining/core/optimizer/memory_optimizer/memory_optimizer.h" -#include "orttraining/core/optimizer/memory_optimizer/memory_insight.h" +#include "orttraining/core/optimizer/memory_optimizer.h" using namespace std; using namespace ONNX_NAMESPACE; @@ -62,9 +60,9 @@ TEST(MemoryOptimizerTests, GeluRecompute) { onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; const std::string alleviation_config("Gelu+:1:-1"); - const std::string probe_config("1:0"); + const std::string alleviation_level("1"); ASSERT_STATUS_OK(graph_transformation_mgr.Register( - std::make_unique(alleviation_config, probe_config), TransformerLevel::Level3)); + std::make_unique(alleviation_config, alleviation_level), TransformerLevel::Level3)); ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level3, *logger)); @@ -92,7 +90,8 @@ TEST(MemoryOptimizerTests, GeluRecompute) { ASSERT_EQ(original_gelu_node->Priority(), static_cast(ExecutionPriority::DEFAULT)); } -TEST(MemoryOptimizerTests, TileRecompute) { +// Disable this UT for now. It has strong dependency on graph topological order, which is not correct logically. 
+TEST(MemoryOptimizerTests, DISABLED_TileRecompute) {
   const logging::Logger* logger = &logging::LoggingManager::DefaultLogger();
   auto model_uri = MODEL_FOLDER "recompute_tile.onnx";
   std::shared_ptr<Model> model;
   ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger));
   Graph& graph = model->MainGraph();
@@ -105,15 +104,15 @@
   onnxruntime::GraphTransformerManager graph_transformation_mgr{5};

-  const std::string alleviation_config("Expand+Tile+:1:-1");
-  const std::string probe_config("1:0");
+  const std::string alleviation_config("Tile+:1:-1");
+  const std::string alleviation_level("1");
   ASSERT_STATUS_OK(graph_transformation_mgr.Register(
-      std::make_unique<MemoryOptimizer>(alleviation_config, probe_config), TransformerLevel::Level3));
+      std::make_unique<MemoryOptimizer>(alleviation_config, alleviation_level), TransformerLevel::Level3));

   ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level3, *logger));

   op_to_count = CountOpsInGraph(graph);
-  ASSERT_EQ(op_to_count["Tile"], 2);
+  ASSERT_TRUE(op_to_count["Tile"] == 2);
   ASSERT_TRUE(op_to_count["com.microsoft.YieldOp"] == 1);
   ASSERT_TRUE(op_to_count["com.microsoft.FusedMatMul"] == 3);
@@ -137,180 +136,13 @@
   ASSERT_TRUE(original_tile_node);
   ASSERT_TRUE(query_layer_grad_node);

-  const Node* recompute_expand_node = graph.GetProducerNode(recompute_tile_node->InputDefs()[0]->Name());
-  ASSERT_TRUE(recompute_expand_node);
-
-  const Node* original_expand_node = graph.GetProducerNode(original_tile_node->InputDefs()[0]->Name());
-  ASSERT_TRUE(original_expand_node);
-
-  ASSERT_EQ(recompute_expand_node->InputDefs()[0]->Name(), original_expand_node->InputDefs()[0]->Name());
-  ASSERT_EQ(query_layer_grad_node->InputDefs()[1]->Name(), recompute_tile_node->OutputDefs()[0]->Name());
+  ASSERT_EQ(recompute_tile_node->MutableInputDefs()[0]->Name(), original_tile_node->MutableInputDefs()[0]->Name());
+  ASSERT_EQ(query_layer_grad_node->InputDefs()[1]->Name(), recompute_tile_node->MutableOutputDefs()[0]->Name());

   ASSERT_EQ(recompute_tile_node->Priority(), static_cast<int>(ExecutionPriority::LOCAL_LOW));
   ASSERT_EQ(original_tile_node->Priority(), static_cast<int>(ExecutionPriority::DEFAULT));
   ASSERT_EQ(query_layer_grad_node->Priority(), static_cast<int>(ExecutionPriority::DEFAULT));
 }

-TEST(MemoryOptimizerTests, TransformerPerLayerRecompute) {
-  const logging::Logger* logger = &logging::LoggingManager::DefaultLogger();
-  auto model_uri = MODEL_FOLDER "3layer_bloom_optimized_training.onnx";
-  std::shared_ptr<Model> model;
-  ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger));
-  Graph& graph = model->MainGraph();
-
-  // Find all optimizable subgraphs
-  GraphViewer graph_viewer(graph);
-  const std::string initial_mem_config("");
-  const std::string probe_config("1:1");
-  std::map<std::string, std::pair<std::string, int>>
-      cluster_id_combinations_to_saved_symbolic_byte_map;
-  std::string record_str =
-      optimizer::memory_optimizer::GetSerializedORTModuleMemoryStat(graph_viewer,
-                                                                    initial_mem_config,
-                                                                    probe_config,
-                                                                    *logger,
-                                                                    cluster_id_combinations_to_saved_symbolic_byte_map,
-                                                                    nullptr,
-                                                                    nullptr);
-
-  InlinedHashMap<std::string, optimizer::memory_optimizer::UserConfig> cluster_id_to_config_map;
-  for (auto it = cluster_id_combinations_to_saved_symbolic_byte_map.begin();
-       it != cluster_id_combinations_to_saved_symbolic_byte_map.end(); ++it) {
-    std::string cluster_id = it->first;
-    ORT_ENFORCE(optimizer::memory_optimizer::ParseOptimizationConfigFromString(cluster_id, cluster_id_to_config_map)
-                    .IsOK());
-  }
-  std::ostringstream oss;
-  int index = 0;
-  for (auto it = cluster_id_to_config_map.begin(); it != cluster_id_to_config_map.end(); ++it) {
-    if (it->second.type ==
-    if (it->second.type == optimizer::memory_optimizer::OptimizationType::Recompute) {
-      oss << (index == 0 ? "" : ",") << it->first << ":1:-1";
-      ++index;
-    }
-  }
-
-  // Apply the transformer
-  GraphTransformerManager graph_transformation_mgr{5};
-  const std::string layer_wise_recompute_config(oss.str());
-  ASSERT_STATUS_OK(graph_transformation_mgr.Register(
-      std::make_unique<MemoryOptimizer>(layer_wise_recompute_config, probe_config), TransformerLevel::Level3));
-
-  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level3, *logger));
-
-  std::vector<const Node*> bw_nodes_in_expected_order;
-  const Node* yield_op_node = nullptr;
-  for (auto& node : graph.Nodes()) {
-    if (node.OpType().compare("YieldOp") == 0) {
-      yield_op_node = &node;
-    }
-  }
-  ASSERT_TRUE(yield_op_node != nullptr);
-  bw_nodes_in_expected_order.push_back(yield_op_node);
-
-  for (int layer_index = 2; layer_index >= 0; --layer_index) {
-    const Node* input_layer_norm_grad_node = nullptr;
-    {
-      // The input of LayerNormalization node in Attention should not be recomputed for the transformer layerwise probe.
-      auto consumers = graph.GetConsumerNodes("_original_module._original_model.transformer.h." +
-                                              std::to_string(layer_index) + ".input_layernorm.weight");
-      // Check there are two LayerNormalization nodes, one of them is the original one,
-      // and the other is the recomputed one
-      const Node* original_ln_node = nullptr;
-      const Node* recompute_ln_node = nullptr;
-      const Node* original_ln_node_parent_add_or_ln_node = nullptr;
-      const Node* recompute_ln_node_parent_add_or_ln_node = nullptr;
-
-      for (auto& consumer : consumers) {
-        if (consumer->OpType().compare("LayerNormalization") == 0) {
-          if (consumer->Name().find("_recompute") != std::string::npos) {
-            recompute_ln_node = consumer;
-            ASSERT_EQ(consumer->Priority(), static_cast<int>(ExecutionPriority::LOCAL_LOW));
-            recompute_ln_node_parent_add_or_ln_node = graph.GetProducerNode(consumer->InputDefs()[0]->Name());
-            ASSERT_TRUE(recompute_ln_node_parent_add_or_ln_node != nullptr);
-            ASSERT_EQ(recompute_ln_node_parent_add_or_ln_node->Priority(), static_cast<int>(ExecutionPriority::DEFAULT));
-            ASSERT_TRUE(recompute_ln_node_parent_add_or_ln_node->Name().find("_recompute") == std::string::npos);
-          } else {
-            original_ln_node = consumer;
-            ASSERT_EQ(consumer->Priority(), static_cast<int>(ExecutionPriority::DEFAULT));
-            original_ln_node_parent_add_or_ln_node = graph.GetProducerNode(consumer->InputDefs()[0]->Name());
-            ASSERT_TRUE(original_ln_node_parent_add_or_ln_node);
-            ASSERT_EQ(original_ln_node_parent_add_or_ln_node->Priority(), static_cast<int>(ExecutionPriority::DEFAULT));
-            ASSERT_TRUE(original_ln_node_parent_add_or_ln_node->Name().find("_recompute") == std::string::npos);
-          }
-        } else if (consumer->OpType().compare("LayerNormalizationGrad") == 0) {
-          input_layer_norm_grad_node = consumer;
-          ASSERT_EQ(consumer->Priority(), static_cast<int>(ExecutionPriority::DEFAULT));
-        }
-      }
-
-      ASSERT_TRUE(recompute_ln_node);
-      ASSERT_TRUE(original_ln_node);
-      ASSERT_TRUE(input_layer_norm_grad_node);
-    }
-
-    {
-      auto consumers = graph.GetConsumerNodes("_original_module._original_model.transformer.h." +
-                                              std::to_string(layer_index) + ".post_attention_layernorm.weight");
-      // Check there are two LayerNormalization nodes, one of them is the original one,
-      // and the other is the recomputed one
-      const Node* original_ln_node = nullptr;
-      const Node* recompute_ln_node = nullptr;
-      const Node* original_ln_node_parent_add_node = nullptr;
-      const Node* recompute_ln_node_parent_add_node = nullptr;
-      const Node* ln_grad_node = nullptr;
-
-      for (auto& consumer : consumers) {
-        if (consumer->OpType().compare("LayerNormalization") == 0) {
-          if (consumer->Name().find("_recompute") != std::string::npos) {
-            recompute_ln_node = consumer;
-            ASSERT_EQ(consumer->Priority(), static_cast<int>(ExecutionPriority::LOCAL_LOW));
-            recompute_ln_node_parent_add_node = graph.GetProducerNode(consumer->InputDefs()[0]->Name());
-            ASSERT_TRUE(recompute_ln_node_parent_add_node);
-            ASSERT_EQ(recompute_ln_node_parent_add_node->OpType(), "Add");
-            ASSERT_EQ(recompute_ln_node_parent_add_node->Priority(), static_cast<int>(ExecutionPriority::LOCAL_LOW));
-            ASSERT_TRUE(recompute_ln_node_parent_add_node->Name().find("_recompute") != std::string::npos);
-          } else {
-            original_ln_node = consumer;
-            ASSERT_EQ(consumer->Priority(), static_cast<int>(ExecutionPriority::DEFAULT));
-            original_ln_node_parent_add_node = graph.GetProducerNode(consumer->InputDefs()[0]->Name());
-            ASSERT_TRUE(original_ln_node_parent_add_node);
-          }
-        } else if (consumer->OpType().compare("LayerNormalizationGrad") == 0) {
-          ln_grad_node = consumer;
-          ASSERT_EQ(consumer->Priority(), static_cast<int>(ExecutionPriority::DEFAULT));
-        }
-      }
-
-      ASSERT_TRUE(recompute_ln_node);
-      ASSERT_TRUE(original_ln_node);
-      ASSERT_TRUE(ln_grad_node);
-
-      bw_nodes_in_expected_order.push_back(recompute_ln_node_parent_add_node);
-      bw_nodes_in_expected_order.push_back(ln_grad_node);  // ln gradient need the recomputed ln node's add node as input
-    }
-    bw_nodes_in_expected_order.push_back(input_layer_norm_grad_node);
-  }
-
-  std::vector<size_t> nodes_in_topological_order;
-  nodes_in_topological_order.reserve(bw_nodes_in_expected_order.size());
-  const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();  // ExecutionOrder::PRIORITY_BASED
-
-  size_t j = 0;
-  for (auto node_index : node_topology_list) {
-    auto* node_ptr = graph.GetNode(node_index);
-    if (!node_ptr) continue;  // Node was removed.
-
-    if (std::find(bw_nodes_in_expected_order.begin(), bw_nodes_in_expected_order.end(), node_ptr) !=
-        bw_nodes_in_expected_order.end()) {
-      nodes_in_topological_order.push_back(j);
-      j++;
-    }
-  }
-
-  for (size_t i = 1; i < nodes_in_topological_order.size(); ++i) {
-    ASSERT_TRUE(nodes_in_topological_order[i - 1] < nodes_in_topological_order[i]);
-  }
-}
-
 }  // namespace test
 }  // namespace onnxruntime
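The remaining tests hand `MemoryOptimizer` entries shaped like `Gelu+:1:-1`: a fused-subgraph pattern, an optimization type, and a requested apply count, where (on my reading of the memory optimizer's config format) type 1 selects recompute and count -1 means every matching occurrence. A hedged Python sketch of how such an entry decomposes (the helper itself is hypothetical, not an ORT API):

```python
# Illustrative parser for memory optimizer config entries such as
# "Gelu+:1:-1" used by the C++ tests above. Field meanings are an
# assumption based on the memory optimizer docs; the helper is not
# part of onnxruntime.
def parse_memory_opt_entry(entry: str) -> dict:
    pattern, opt_type, count = entry.rsplit(":", 2)
    return {
        "subgraph_pattern": pattern,         # e.g. "Gelu+" or "Reshape+Where+" (a chain of ops)
        "optimization_type": int(opt_type),  # 1 = recompute
        "requested_count": int(count),       # -1 = apply to all matching occurrences
    }

print(parse_memory_opt_entry("Gelu+:1:-1"))
print(parse_memory_opt_entry("Reshape+Where+:1:-1"))
```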
diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
index eb71f212a4b11..0efedf14fb3b8 100644
--- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
+++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
@@ -6394,58 +6394,3 @@ def run_step(model, x):
 
     if conv_algo_search is not None:
         del os.environ["ORTMODULE_CONV_ALGO_SEARCH"]
-
-
-def test_bert_result_with_layerwise_recompute():
-    original_val = os.environ["ORTMODULE_MEMORY_OPT_LEVEL"] if "ORTMODULE_MEMORY_OPT_LEVEL" in os.environ else None
-    # Create PyTorch model with dropout disabled.
-    pt_model = _get_bert_for_sequence_classification_model(
-        "cuda", is_training=True, hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0
-    )
-    ort_model = ORTModule(copy.deepcopy(pt_model))
-
-    os.environ["ORTMODULE_MEMORY_OPT_LEVEL"] = "1"
-    ort_model_with_reompute = ORTModule(
-        copy.deepcopy(pt_model), DebugOptions(save_onnx=True, onnx_prefix="layerwise_recompute_test")
-    )
-
-    def run_step(model, x, y, z):
-        outputs = model(x, y, None, None, None, None, z)
-        loss = outputs[0]
-        loss.backward()
-        return outputs[0]
-
-    for _ in range(10):
-        x, y, z = _get_bert_for_sequence_classification_sample_data_with_random_shapes("cuda")
-
-        ort_p = run_step(ort_model, x, y, z)
-        ort_p_with_reompute = run_step(ort_model_with_reompute, x, y, z)
-
-        _test_helpers.assert_values_are_close(ort_p, ort_p_with_reompute, atol=1e-02)
-        _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, ort_model_with_reompute)
-
-    execution_mgr = ort_model_with_reompute._torch_module._execution_manager._training_manager
-    from onnxruntime.training.ortmodule._onnx_models import _get_onnx_file_name
-
-    # Keep the logic aligned with _graph_execution_manager.py
-    path = os.path.join(
-        execution_mgr._debug_options.save_onnx_models.path,
-        _get_onnx_file_name(
-            execution_mgr._debug_options.save_onnx_models.name_prefix, "execution_model", execution_mgr._export_mode
-        ),
-    )
-
-    onnx_model = onnx.load(path)
-    onnx_nodes = onnx_model.graph.node
-
-    recompute_nodes = 0
-    for node in onnx_nodes:
-        if "_recompute" in node.name:
-            recompute_nodes += 1
-
-    assert recompute_nodes > 0, "No Recompute nodes are found"
-
-    # Make sure environment variable is restored to its original value after the run is completed.
-    torch.cuda.synchronize()
-    if original_val is not None:
-        os.environ["ORTMODULE_MEMORY_OPT_LEVEL"] = original_val
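The deleted test saves and restores `ORTMODULE_MEMORY_OPT_LEVEL` by hand around the run. Tests that toggle ORTMODULE environment variables can instead scope the override; a minimal sketch, assuming only the standard library (this helper is not part of the ORT test suite):

```python
import os
from contextlib import contextmanager


@contextmanager
def scoped_env(name: str, value: str):
    """Set an environment variable for the duration of a block, then restore it."""
    original = os.environ.get(name)
    os.environ[name] = value
    try:
        yield
    finally:
        if original is None:
            os.environ.pop(name, None)  # variable was unset before; unset it again
        else:
            os.environ[name] = original


# Usage: construct the recompute-enabled ORTModule inside the scope.
with scoped_env("ORTMODULE_MEMORY_OPT_LEVEL", "1"):
    pass  # e.g. ort_model = ORTModule(copy.deepcopy(pt_model))
```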