diff --git a/docs/Memory_Optimizer.md b/docs/Memory_Optimizer.md
index 0afd24224a549..f26262918cd2c 100644
--- a/docs/Memory_Optimizer.md
+++ b/docs/Memory_Optimizer.md
@@ -17,42 +17,41 @@ Classical scenarios include:
 
 Not all models and recipes need this optimizer technique. Imagine your training recipe uses a batch size of 6 (GPU compute and memory are fully saturated), and you don't need to bump it to 8 to maintain a fixed global batch size. Enabling recompute may not bring better throughput at batch size 8 than at the original batch size 6.
 
+## Usage
+
-## Simple Usage (Aggressively Recompute All)
-1. Make sure ONNX Runtime training wheel is installed and correctly configured.
-2. Integrate models using `ORTModule`.
+Make sure the ONNX Runtime training wheel is installed and correctly configured.
+Integrate models using `ORTModule`.
 > ort_model = ORTModule(pt_model)
-3. Set memory optimization level to be AGGRESSIVE_FULL_RECOMPUTE, by `export ORTMODULE_MEMORY_OPT_LEVEL=1`
-4. Run the training as usual; check the logs, you could find something like this:
+
+There are two modes to enable the memory optimizations:
+- Aggressively Recompute All, enabled by `export ORTMODULE_MEMORY_OPT_LEVEL=1`, recomputes all detected subgraphs. The advantage of this mode is that it is easy to enable, but note that the resulting recompute plan may NOT be the best one.
+- User Specified Subgraph Recompute, enabled by `export ORTMODULE_MEMORY_OPT_LEVEL=0` and `export ORTMODULE_MEMORY_OPT_CONFIG=<plan1 config>,<plan2 config>,...`. This is an advanced usage that lets users pick the most suitable subgraphs to recompute, at the cost of the overhead of searching for the best plans.
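+
+A minimal end-to-end sketch of the integration, assuming Mode 1 (the model, shapes, and optimizer below are illustrative, not prescribed by `ORTModule`):
+
+```python
+import os
+
+# The env var must be set before ORTModule builds its training graph.
+os.environ["ORTMODULE_MEMORY_OPT_LEVEL"] = "1"  # Mode 1: Aggressively Recompute All
+
+import torch
+from onnxruntime.training import ORTModule
+
+pt_model = torch.nn.Sequential(
+    torch.nn.Linear(784, 4096),
+    torch.nn.GELU(),
+    torch.nn.Linear(4096, 10),
+)
+ort_model = ORTModule(pt_model)
+
+optimizer = torch.optim.AdamW(ort_model.parameters(), lr=1e-4)
+for _ in range(3):  # train as usual; the memory optimizer summary shows up in the logs
+    loss = ort_model(torch.randn(8, 784)).sum()
+    loss.backward()
+    optimizer.step()
+    optimizer.zero_grad()
+```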
+
+### Mode 1 - Simple Usage (Aggressively Recompute All)
+
+1. Set the memory optimization level to AGGRESSIVE_FULL_RECOMPUTE via `export ORTMODULE_MEMORY_OPT_LEVEL=1`.
+2. Run the training as usual, then check the logs; you should find something like this:
 ```
-Memory Optimizer     :   ON   :  Memory Optimization Level: [AGGRESSIVE_FULL_RECOMPUTE], Optimization Config: [Reshape+Where+:1:-1,BiasSoftmax+:1:-1,Cast+:1:-1,BiasGelu+:1:-1,FusedMatMul+:1:-1,Add+:1:-1,Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1], Probe Level: [1]
-                 Configs                                               Freq  Max Saving(Bytes)  Saving Symbolic(Bytes)
- - Plan 1 : ON  : Reshape+Where+:1:-1                                    1   134,217,728       128.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2
- - Plan 2 : ON  : BiasSoftmax+:1:-1                                      1   134,086,656       128.0*inputs_input_ids_dim0*inputs_input_ids_dim1*(inputs_input_ids_dim1 - 1)
- - Plan 3 : ON  : Cast+:1:-1                                             1   67,043,328        64.0*inputs_input_ids_dim0*inputs_input_ids_dim1*(inputs_input_ids_dim1 - 1)
- - Plan 4 : ON  : BiasGelu+:1:-1                                         1   20,951,040        20480.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1)
- - Plan 5 : ON  : FusedMatMul+:1:-1                                      1   20,951,040        20480.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1)
- - Plan 6 : ON  : Add+:1:-1                                              1   5,237,760         5120.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1)
- - Plan 7 : ON  : Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1    1   4,096             4.0*inputs_input_ids_dim0*inputs_input_ids_dim1
- - Plan 8 : OFF : Cast+:2:-1                                             1   2,048             2.0*inputs_input_ids_dim0*inputs_input_ids_dim1
- - Note 1: use comma as delimiter to enable multiple memory optimization plans at the same time:
-   export ORTMODULE_MEMORY_OPT_CONFIG=<plan1 config>,<plan2 config>,...
- - Note 2: memory saving is calculated based on the 1st batch symbolic dim values:
-   inputs_input_ids_dim0=1, inputs_input_ids_dim1=1024, inputs_attention_mask_dim0=1, inputs_attention_mask_dim1=1024, inputs_labels_dim0=1, inputs_labels_dim1=1024,
+  Memory Optimizer     :   ON   :  Memory Optimization Level: [AGGRESSIVE_FULL_RECOMPUTE], Optimization Config: [Reshape+Where+:1:-1,BiasSoftmax+:1:-1,Cast+:1:-1,BiasGelu+:1:-1,FusedMatMul+:1:-1,Add+:1:-1,Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1], Probe Level: [1]
+                  Configs                                               Freq  Max Saving(Bytes)  Saving Symbolic(Bytes)
+  - Plan 1 : ON  : Reshape+Where+:1:-1                                    1   134,217,728       128.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2
+  - Plan 2 : ON  : BiasSoftmax+:1:-1                                      1   134,086,656       128.0*inputs_input_ids_dim0*inputs_input_ids_dim1*(inputs_input_ids_dim1 - 1)
+  - Plan 3 : ON  : Cast+:1:-1                                             1   67,043,328        64.0*inputs_input_ids_dim0*inputs_input_ids_dim1*(inputs_input_ids_dim1 - 1)
+  - Plan 4 : ON  : BiasGelu+:1:-1                                         1   20,951,040        20480.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1)
+  - Plan 5 : ON  : FusedMatMul+:1:-1                                      1   20,951,040        20480.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1)
+  - Plan 6 : ON  : Add+:1:-1                                              1   5,237,760         5120.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1)
+  - Plan 7 : ON  : Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1    1   4,096             4.0*inputs_input_ids_dim0*inputs_input_ids_dim1
+  - Plan 8 : OFF : Cast+:2:-1                                             1   2,048             2.0*inputs_input_ids_dim0*inputs_input_ids_dim1
 ```
-5. As shown above, `Config` is a string representative for a re-computable subgraph. All are enabled for recompute in this case.
+3. As shown above, `Config` is a string representation of a re-computable subgraph. All are enabled for recompute in this case.
 
-## Advanced Usage (User Selected Subgraph Recompute)
+### Mode 2 - Advanced Usage (User Specified Subgraph Recompute)
 
-1. Make sure ONNX Runtime training wheel is installed and correctly configured.
-2. Integrate models using `ORTModule`.
-   > ort_model = ORTModule(pt_model)
-3. Be noted `ORTMODULE_MEMORY_OPT_LEVEL` is by default be 0. Run the training as usual; then stop it after training a few steps.
-4. Check the logs, you could find something like this:
+1. Note that `ORTMODULE_MEMORY_OPT_LEVEL` is 0 by default. Run the training as usual, then stop it after training a few steps.
+2. Check the logs; you should find something like this:
 ```
 Memory Optimizer     :  OFF   :  Enable with env ORTMODULE_MEMORY_OPT_LEVEL=1 or ORTMODULE_MEMORY_OPT_CONFIG=<plan1 config>,<plan2 config>,...
                  Configs                                               Freq  Max Saving(Bytes)  Saving Symbolic(Bytes)
@@ -65,13 +64,18 @@ Memory Optimizer     :   ON   :  Memory Optimization Level: [AGGRESSIVE_FULL_RECO
  - Plan 7 : OFF : Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1    1   4,096             4.0*inputs_input_ids_dim0*inputs_input_ids_dim1
  - Plan 8 : OFF : Cast+:2:-1                                             1   2,048             2.0*inputs_input_ids_dim0*inputs_input_ids_dim1
 ```
-5. As shown above, `Config` is a string representative for a re-computable subgraph. All are disabled for recompute in this case.
-6. Set environment variable `ORTMODULE_MEMORY_OPT_CONFIG` to enable some of the subgraph to do recompute. In below example, `1` `BiasGelu+` related subgraphs are allowed to recompute.
-`BiasGelu+` is the subgraph string representative; `1` in the middle indicates 'Recompute' is enabled (0, on the contrary indicates it's disabled); `6` means the initial 6 subgraph occurrences will be recomputed, all others are left as it is, filling `-1` will make all occurrences be recomputed.
-   ```
-   export ORTMODULE_MEMORY_OPT_CONFIG="BiasGelu+:1:1" # Use comma as a separator for enabling more than one subgraphs.
-   ```
+3. As shown above, `Config` is a string representation of a re-computable subgraph. All are disabled for recompute in this case.
+4. Set the environment variable `ORTMODULE_MEMORY_OPT_CONFIG` to enable recompute for some of the subgraphs (a combined sketch follows this list):
+   ```bash
+   # Use a comma as the separator to enable more than one subgraph.
+   export ORTMODULE_MEMORY_OPT_CONFIG="BiasGelu+:1:1"
+   # One `BiasGelu+`-related subgraph is allowed to recompute:
+   # `BiasGelu+` is the subgraph's string representation;
+   # the `1` in the middle indicates 'Recompute' is enabled (`0` indicates it is disabled);
+   # the last `1` means only the first subgraph occurrence will be recomputed, and all
+   # others are left as they are; filling in `-1` makes all occurrences be recomputed.
+   ```
-7. Then run the training again, and you will see logs like this:
+5. Then run the training again, and you will see logs like this:
 ```
 Memory Optimizer     :   ON   :  Memory Optimization Level: [USER_SPECIFIED], Optimization Config: [BiasGelu+:1:-1], Probe Level: [1]
                  Configs                                               Freq  Max Saving(Bytes)  Saving Symbolic(Bytes)
@@ -84,7 +88,7 @@ Memory Optimizer     :   ON   :  Memory Optimization Level: [AGGRESSIVE_FULL_RECO
  - Plan 7 : OFF : Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1    1   4,096             4.0*inputs_input_ids_dim0*inputs_input_ids_dim1
  - Plan 8 : OFF : Cast+:2:-1                                             1   2,048             2.0*inputs_input_ids_dim0*inputs_input_ids_dim1
 ```
-8. You may need iterate few times on step 6 and 7 until you find a good config for this model to run a bigger batch size. Or you may fail to find if memory optimization does not apply to the model well.
+6. You may need to iterate on steps 4 and 5 a few times until you find a good config for this model to run a bigger batch size. Or you may fail to find one, if memory optimization does not apply well to the model.
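+
+Putting steps 4 and 5 together, one search iteration could look like the sketch below (the plan strings come from the log above; everything else is illustrative). As with all of these env vars, they must be set before `ORTModule` builds its training graph:
+
+```python
+import os
+
+os.environ["ORTMODULE_MEMORY_OPT_LEVEL"] = "0"  # the default, i.e. user-specified mode
+# Try recomputing all occurrences of two of the detected subgraphs:
+os.environ["ORTMODULE_MEMORY_OPT_CONFIG"] = ",".join([
+    "BiasGelu+:1:-1",     # recompute every BiasGelu+ occurrence
+    "FusedMatMul+:1:-1",  # recompute every FusedMatMul+ occurrence
+])
+# ...then construct ORTModule, train a few steps, and re-check the logs.
+```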
 
 ## Optimization Configuration
 
diff --git a/docs/ORTModule_Training_Guidelines.md b/docs/ORTModule_Training_Guidelines.md
index 7fa89cca381d9..e9a39756b6e25 100644
--- a/docs/ORTModule_Training_Guidelines.md
+++ b/docs/ORTModule_Training_Guidelines.md
@@ -278,6 +278,15 @@ data sparsity based performance optimizations.
   export ORTMODULE_USE_EFFICIENT_ATTENTION=1
   ```
 
+#### ORTMODULE_MEMORY_OPT_LEVEL
+
+- **Feature Area**: *ORTMODULE/Optimizations*
+- **Description**: By default, the level is 0. This env var can be used to enable recomputation to reduce the peak memory requirement. Setting the level to 1 means all detected subgraphs generating stashed activations will be recomputed; this is conceptually equivalent to PyTorch's gradient checkpointing. When the level is 0, recompute can still be enabled for user-specified subgraphs; check [Memory Optimizer for ONNX Runtime Training](Memory_Optimizer.md) for more details.
+
+  ```bash
+  export ORTMODULE_MEMORY_OPT_LEVEL=0
+  ```
+
 ### 2.2 Memory Optimization
 
 Q: *Want to run a bigger batch size?*
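+
+For comparison with the PyTorch-native mechanism referenced above (level 1 is described as conceptually equivalent to PyTorch's gradient checkpointing), a minimal checkpointing sketch, illustrative only and using the public `torch.utils.checkpoint` API:
+
+```python
+import torch
+from torch.utils.checkpoint import checkpoint
+
+# Activations inside `block` are recomputed during backward instead of stashed.
+block = torch.nn.Sequential(
+    torch.nn.Linear(512, 2048),
+    torch.nn.GELU(),
+    torch.nn.Linear(2048, 512),
+)
+
+x = torch.randn(4, 512, requires_grad=True)
+y = checkpoint(block, x, use_reentrant=False)  # the non-reentrant variant is recommended
+y.sum().backward()  # block's forward runs again here to rebuild activations
+```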