Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Memory optimization refactor and refinement (#17481)
### Memory optimization refactor and refinement Currently memory optimizer runs graph transformations and print recompute opportunities in INFO level, while ORT backend has many many INFO level logs making users hard to find those information. So we are looking for a Python binding API to retrieve the memory optimization opportunities instead of depending on the MemoryOptimizer's default logging. Then we can print ORTModule feature statistics using this information. Also, with such an API, we can create an ORT session created, where allocation plan is done, the analysis will consider buffer reuse as well. This can void giving some recomputation subgraphs that are reusing other subgraphs' output buffers. Check https://github.com/microsoft/onnxruntime/blob/pengwa/add_devinfo_level/docs/Memory_Optimizer.md for the new flow using `MemoryOptimizer`. This pull requests made following refactoring: 1. Print the log in ORTModule Python script, along with ORTModule feature enabling stats. This is implemented by exposing an API `get_serialized_ortmodule_memory_stat` to retrieve the memory optimization opportunities. 2. We are analyzing memory optimization opportunities considering ORT memory planning. This is done by firstly creating the execution graph without enabling MemoryOptimizer, then we call `execution_agent.get_serialized_ortmodule_memory_stat` which internally will consider the session memory allocation planner when analyzing memory optimization opportunity. As a direct result, the memory optimization opportunities can show those stashed activations that are reusing other buffers. 3. Move recompute analysis logic from memory_optimizer.h/cc to recompute_analysis.h/cc. 4. Abstract optimization strategies for their own implementation. This will make introducing new strategies (for example compression and decompression ) easier. New logging matrix (INFO Level), in WARNING level, the details will NOT show. ``` 2023-09-13 13:25:09,249 orttraining.rank-0 [WARNING] - ***** ONNX Runtime Training (ORTModule) is accelerating your model ***** ORTModule is enabled with following features ON/OFF for [training] mode: ATen Executor : ON : Dispatch ATen operators to ORT's ATen executor Cast Propagation : ON : Level 1 enabled Custom Function : ON : Support custom torch.autograd.Function export and execution Memory Optimizer : ON : RecomputeConfig: Reshape+Where+BiasSoftmax+:1:-1,Cast+:1:-1, ProbeLevel: 1, available configs: Config Freq Saving(B) Saving Symbolic(Bytes) - Plan 1 : ON : Reshape+Where+BiasSoftmax+:1:-1 5 671,088,640 640.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2 - Plan 2 : ON : Cast+:1:-1 6 402,587,648 inputs_input_ids_dim0*inputs_input_ids_dim1*(384.0*inputs_input_ids_dim1 - 64.0) - Plan 3 : OFF : Reshape+Where+:1:-1 1 134,217,728 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2 - Plan 4 : OFF : BiasSoftmax+:1:-1 1 134,086,656 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1*(inputs_input_ids_dim1 - 1) - Plan 5 : OFF : BiasGelu+:1:-1 6 125,808,640 inputs_input_ids_dim0*(122880.0*inputs_input_ids_dim1 - 20480.0) - Plan 6 : OFF : FusedMatMul+:1:-1 6 125,808,640 inputs_input_ids_dim0*(122880.0*inputs_input_ids_dim1 - 20480.0) - Plan 7 : OFF : FusedMatMul+Add+FusedMatMul+Add+Add+Add+:1:-1 5 26,214,400 25600.0*inputs_input_ids_dim0*inputs_input_ids_dim1 - Plan 8 : OFF : Add+:1:-1 1 5,237,760 5120.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1) - Plan 9 : OFF : Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1 1 4,096 4.0*inputs_input_ids_dim0*inputs_input_ids_dim1 - Plan 10 : OFF : Cast+:2:-1 1 2,048 2.0*inputs_input_ids_dim0*inputs_input_ids_dim1 Compute Optimizer : ON : Enable/Disable with env ORTMODULE_ENABLE_COMPUTE_OPTIMIZER=1/0 - FLOPReduction : ON : Reduce FLOPs by upstreaming shrinking-sized ops Auto Fallback : ON : Fallback to PyTorch when encountering unsupported ops TritonOp Enabled : OFF : ORT will switch to Triton for executing some ops to further accelerate training. ZeRO Stage3 Support : OFF : Enable/Disable with env ORTMODULE_ENABLE_ZERO_STAGE3=1/0 Total ORT initialization overhead is 10.73s where export takes 8.39s. Other overhead details: graph builder init takes 0.06s, runtime detection takes 0.01s, graph building takes 0.31s, session creation takes 1.96s Versions: ONNX Runtime - 1.16.0+cu118, ONNX - 1.11.0 Note 1: use comma to enable multiple plans at the same time. export ORTMODULE_MEMORY_OPT_CONFIG=<plan1 config>,<plan2 config>,... Note 2: saving is calculated based on the 1st batch symbolic dim values: inputs_input_ids_dim0=1, inputs_input_ids_dim1=1024, inputs_attention_mask_dim0=1, inputs_attention_mask_dim1=1024, inputs_labels_dim0=1, inputs_labels_dim1=1024, ************************************************************************ ``` If DEVINFO level is enabled, then more details about the memory optimizations are printed. ``` MemoryInsight Summary - User config: BiasGelu+:1:-1,Cast+:2:-1 ========================================================================================================================================== |Freq | Memory Optimization Opportunities (Clustered by node-level activation patterns) | |_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |3 |For each row options are mutually exclusive, only one of them can be enabled. | | | | | |>>Option 1 : Recompute subgraph FusedMatMul+Add+Reshape+ | | | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=FusedMatMul+Add+Reshape+:1:-1 | | | Stashed Activations: | | | - ReuseFreq : Output 0(3), | | | - Output 0 : [inputs_input_ids_dim0 x inputs_input_ids_dim1 x 32 x 240 x ], byte/elem: 2, 100% saved | |_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |2 |For each row options are mutually exclusive, only one of them can be enabled. | | | | | |>>Option 1 : Recompute subgraph Reshape+ | | | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Reshape+:1:-1 | | | Stashed Activations: | | | - ReuseFreq : Output 0(2), | | | - Output 0 : [ x 2560 x ], byte/elem: 2, 100% saved | |_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |2 |For each row options are mutually exclusive, only one of them can be enabled. | | | | | |>>Option 1 : Recompute subgraph FusedMatMul+ | | | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=FusedMatMul+:1:-1 | | | Stashed Activations: | | | - Output 0 : [inputs_input_ids_dim0 x inputs_input_ids_dim1 x 10240 x ], byte/elem: 2, 100% saved | |_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |2 |For each row options are mutually exclusive, only one of them can be enabled. | | | | | |>>Option 1 : Recompute subgraph Cast+ | | | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Cast+:1:-1 | | | Stashed Activations: | | | - Output 0 : [inputs_input_ids_dim0 x 32 x inputs_input_ids_dim1 x inputs_input_ids_dim1 x ], byte/elem: 2, 100% saved | |_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |2 |For each row options are mutually exclusive, only one of them can be enabled. | | | | | |>>Option 1 : Recompute subgraph Reshape+Where+BiasSoftmax+ | | | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Reshape+Where+BiasSoftmax+:1:-1 | | | Stashed Activations: | | | - Output 0 : [inputs_input_ids_dim0 x 32 x inputs_input_ids_dim1 x inputs_input_ids_dim1 x ], byte/elem: 4, 100% saved | |_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |2 |For each row options are mutually exclusive, only one of them can be enabled. | | | | | |>>Option 1 : Recompute subgraph BiasGelu+ | | | Status : Enabled, requested count=-1, actual applied count=2 | | | Stashed Activations: | | | - Output 0 : [inputs_input_ids_dim0 x inputs_input_ids_dim1 x 10240 x ], byte/elem: 2, 100% saved | |_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |2 |For each row options are mutually exclusive, only one of them can be enabled. | | | | | |>>Option 1 : Recompute subgraph FusedMatMul+Add+FusedMatMul+Add+Add+Add+ | | | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=FusedMatMul+Add+FusedMatMul+Add+Add+Add+:1:-1 | | | Stashed Activations: | | | - Output 0 : [inputs_input_ids_dim0 x inputs_input_ids_dim1 x 2560 x ], byte/elem: 2, 100% saved | |_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |1 |For each row options are mutually exclusive, only one of them can be enabled. | | | | | |>>Option 1 : Recompute subgraph Reshape+Where+ | | | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Reshape+Where+:1:-1 | | | Stashed Activations: | | | - Output 0 : [inputs_input_ids_dim0 x 32 x inputs_input_ids_dim1 x inputs_input_ids_dim1 x ], byte/elem: 4, 100% saved | |_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |1 |For each row options are mutually exclusive, only one of them can be enabled. | | | | | |>>Option 1 : Recompute subgraph FusedMatMul+ | | | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=FusedMatMul+:1:-1 | | | Stashed Activations: | | | - Output 0 : [inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1) x 10240 x ], byte/elem: 2, 100% saved | |_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |1 |For each row options are mutually exclusive, only one of them can be enabled. | | | | | |>>Option 1 : Recompute subgraph Cast+ | | | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Cast+:1:-1 | | | Stashed Activations: | | | - Output 0 : [inputs_input_ids_dim0 x 32 x inputs_input_ids_dim1 - 1 x inputs_input_ids_dim1 x ], byte/elem: 2, 100% saved | |_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |1 |For each row options are mutually exclusive, only one of them can be enabled. | | | | | |>>Option 1 : Recompute subgraph Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+ | | | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1 | | | Stashed Activations: | | | - Output 0 : [inputs_input_ids_dim0 x 1 x 1 x inputs_input_ids_dim1 x ], byte/elem: 4, 100% saved | | | | | |>>Option 2 : RecomputeWithCompromise subgraph Cast+ | | | Status : Enabled, requested count=-1, actual applied count=1 | | | Stashed Activations: | | | - Output 0 : [inputs_input_ids_dim0 x 1 x 1 x inputs_input_ids_dim1 x ], byte/elem: 4, 50% saved | |_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |1 |For each row options are mutually exclusive, only one of them can be enabled. | | | | | |>>Option 1 : Recompute subgraph BiasSoftmax+ | | | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=BiasSoftmax+:1:-1 | | | Stashed Activations: | | | - Output 0 : [inputs_input_ids_dim0 x 32 x inputs_input_ids_dim1 - 1 x inputs_input_ids_dim1 x ], byte/elem: 4, 100% saved | |_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |1 |For each row options are mutually exclusive, only one of them can be enabled. | | | | | |>>Option 1 : Recompute subgraph BiasGelu+ | | | Status : Enabled, requested count=-1, actual applied count=1 | | | Stashed Activations: | | | - Output 0 : [inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1) x 10240 x ], byte/elem: 2, 100% saved | |_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |1 |For each row options are mutually exclusive, only one of them can be enabled. | | | | | |>>Option 1 : Recompute subgraph Add+ | | | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Add+:1:-1 | | | Stashed Activations: | | | - Output 0 : [inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1) x 2560 x ], byte/elem: 2, 100% saved | |_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | ========================================================================================================================================== Note: use comma as a separator for enabling more than one subgraphs. ************************************************************************ ``` ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
- Loading branch information