Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
Browse files Browse the repository at this point in the history
…zhanyi/nettarget
  • Loading branch information
mszhanyi committed Dec 13, 2023
2 parents 87fa31a + dbe886a commit b8ee367
Show file tree
Hide file tree
Showing 103 changed files with 4,492 additions and 802 deletions.
7 changes: 7 additions & 0 deletions cmake/onnxruntime_python.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,9 @@ file(GLOB onnxruntime_python_quantization_operators_src CONFIGURE_DEPENDS
file(GLOB onnxruntime_python_quantization_cal_table_flatbuffers_src CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/python/tools/quantization/CalTableFlatBuffers/*.py"
)
file(GLOB onnxruntime_python_quantization_fusions_src CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/python/tools/quantization/fusions/*.py"
)
file(GLOB onnxruntime_python_quantization_ep_qnn_src CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/python/tools/quantization/execution_providers/qnn/*.py"
)
Expand Down Expand Up @@ -550,6 +553,7 @@ add_custom_command(
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/quantization
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/quantization/operators
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/quantization/CalTableFlatBuffers
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/quantization/fusions
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/quantization/execution_providers
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/quantization/execution_providers/qnn
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/quantization
Expand Down Expand Up @@ -622,6 +626,9 @@ add_custom_command(
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_quantization_cal_table_flatbuffers_src}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/quantization/CalTableFlatBuffers/
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_quantization_fusions_src}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/quantization/fusions/
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_quantization_ep_qnn_src}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/quantization/execution_providers/qnn/
Expand Down
120 changes: 75 additions & 45 deletions docs/Memory_Optimizer.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,67 +17,97 @@ Classical scenarios include:

Not all models and recipes need this optimizer technique. Imagine if your training recipe uses a batch size 6 (GPU compute and memory are fully saturated), and you don't need bump it to 8 to maintain a fixed global batch size. Enabling recompute maybe not bring better throughput on batch size 8 than the original batch size 6.

## Quick trial
## Usage

1. Make sure ONNX Runtime training wheel is installed and correctly configured.
2. Integrate models using `ORTModule`, be noted log_level should be equal or lower than INFO.
> ort_model = ORTModule(pt_model, DebugOptions(log_level=LogLevel.INFO))
3. Run the training as usual; then stop it after training few steps.
4. Check the logs, you could find something like this:

Make sure ONNX Runtime training wheel is installed and correctly configured.
Integrate models using `ORTModule`.
```diff
model = build_model()

+ from onnxruntime.training.ortmodule import ORTModule
+ model = ORTModule(model)
```

There are two modes to enable the memory optimizations:
- Aggressively Recompute All Within Each Transformer Layer, enabled by `export ORTMODULE_MEMORY_OPT_LEVEL=1`. This will recompute all detected subgraphs within each Transformer Attention+MLP layer. It is easy to enable, but be noted this recompute plan may NOT be the best one. In this mode, `ORTMODULE_MEMORY_OPT_CONFIG` env values passed by users are not respected.
- User Specified Subgraph Recompute, enabled by `export ORTMODULE_MEMORY_OPT_LEVEL=0` and `export ORTMODULE_MEMORY_OPT_CONFIG=<plan1 config>,<plan2 config>,...`. This is an advanced usage, that allows users to find the most suitable graphs to recompute, at the cost of overhead to look for the best plans.

### Mode 1 - Simple Usage (Aggressively Recompute All Within Each Transformer Layer)


1. Set memory optimization level to be TRANSFORMER_LAYERWISE_RECOMPUTE, by `export ORTMODULE_MEMORY_OPT_LEVEL=1`
2. Run the training as usual; check the logs, you could find something like this if the current log level <= LogLevel.INFO:
```
Memory Optimizer : OFF : Enable with env ORTMODULE_MEMORY_OPT_CONFIG=<config>, available configs:
Config Freq Max Saving(B) Saving Symbolic(Bytes)
- Plan 1 : OFF : Reshape+Where+BiasSoftmax+:1:-1 5 671,088,640 640.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2
- Plan 2 : OFF : Cast+:1:-1 6 402,587,648 inputs_input_ids_dim0*inputs_input_ids_dim1*(384.0*inputs_input_ids_dim1 - 64.0)
- Plan 3 : OFF : Reshape+Where+:1:-1 1 134,217,728 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2
- Plan 4 : OFF : BiasSoftmax+:1:-1 1 134,086,656 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1*(inputs_input_ids_dim1 - 1)
- Plan 5 : OFF : BiasGelu+:1:-1 6 125,808,640 inputs_input_ids_dim0*(122880.0*inputs_input_ids_dim1 - 20480.0)
- Plan 6 : OFF : FusedMatMul+:1:-1 6 125,808,640 inputs_input_ids_dim0*(122880.0*inputs_input_ids_dim1 - 20480.0)
- Plan 7 : OFF : FusedMatMul+Add+FusedMatMul+Add+Add+Add+:1:-1 5 26,214,400 25600.0*inputs_input_ids_dim0*inputs_input_ids_dim1
- Plan 8 : OFF : Add+:1:-1 1 5,237,760 5120.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1)
- Plan 9 : OFF : Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1 1 4,096 4.0*inputs_input_ids_dim0*inputs_input_ids_dim1
- Plan 10 : OFF : Cast+:2:-1 1 2,048 2.0*inputs_input_ids_dim0*inputs_input_ids_dim1


Note 1: use comma as delimiter to enable multiple memory optimization plans at the same time:
export ORTMODULE_MEMORY_OPT_CONFIG=<plan1 config>,<plan2 config>,...
Note 2: memory saving is calculated based on the 1st batch symbolic dim values:
inputs_input_ids_dim0=1, inputs_input_ids_dim1=1024, inputs_attention_mask_dim0=1, inputs_attention_mask_dim1=1024, inputs_labels_dim0=1, inputs_labels_dim1=1024,
Memory Optimizer : ON : Memory Optimization Level: [TRANSFORMER_LAYERWISE_RECOMPUTE], Optimization Config: [Reshape+Where+:1:-1,BiasSoftmax+:1:-1,Cast+:1:-1,BiasGelu+:1:-1,FusedMatMul+:1:-1,Add+:1:-1,Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1]
Configs Freq Max Saving(Bytes) Saving Symbolic(Bytes)
- Plan 1 : ON : Reshape+Where+:1:-1 1 134,217,728 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2
- Plan 2 : ON : BiasSoftmax+:1:-1 1 134,086,656 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1*(inputs_input_ids_dim1 - 1)
- Plan 3 : ON : Cast+:1:-1 1 67,043,328 64.0*inputs_input_ids_dim0*inputs_input_ids_dim1*(inputs_input_ids_dim1 - 1)
- Plan 4 : ON : BiasGelu+:1:-1 1 20,951,040 20480.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1)
- Plan 5 : ON : FusedMatMul+:1:-1 1 20,951,040 20480.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1)
- Plan 6 : ON : Add+:1:-1 1 5,237,760 5120.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1)
- Plan 7 : ON : Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1 1 4,096 4.0*inputs_input_ids_dim0*inputs_input_ids_dim1
- Plan 8 : OFF : Cast+:2:-1 1 2,048 2.0*inputs_input_ids_dim0*inputs_input_ids_dim1
```
5. As shown above, `Config` is a string representative for a re-computable subgraph. All are disabled for recompute in this case.
6. Set environment variable `ORTMODULE_MEMORY_OPT_CONFIG` to enable some of the subgraph to do recompute. In below example, `6` `BiasGelu+` related subgraphs are allowed to recompute.
`BiasGelu+` is the subgraph string representative; `1` in the middle indicates 'Recompute' is enabled (0, on the contrary indicates it's disabled); `6` means the initial 6 subgraph occurrences will be recomputed, all others are left as it is, filling `-1` will make all occurrences be recomputed.
3. As shown above, `Config` is a string representative for a re-computable subgraph. All are enabled for recompute in this case.


### Mode 2 - Advanced Usage (User Selected Subgraph Recompute)

1. Be noted `ORTMODULE_MEMORY_OPT_LEVEL` is by default be 0. Run the training as usual; then stop it after training a few steps.
2. Check the logs, you could find something like this if the current log level <= LogLevel.INFO::
```
export ORTMODULE_MEMORY_OPT_CONFIG="BiasGelu+:1:6" # Use comma as separator for enabling more than one subgraphs.
Memory Optimizer : OFF : Enable with env ORTMODULE_MEMORY_OPT_LEVEL=1 or ORTMODULE_MEMORY_OPT_CONFIG=<plan1 config>,<plan2 config>,...
Configs Freq Max Saving(Bytes) Saving Symbolic(Bytes)
- Plan 1 : OFF : Reshape+Where+:1:-1 1 134,217,728 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2
- Plan 2 : OFF : BiasSoftmax+:1:-1 1 134,086,656 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1*(inputs_input_ids_dim1 - 1)
- Plan 3 : OFF : Cast+:1:-1 1 67,043,328 64.0*inputs_input_ids_dim0*inputs_input_ids_dim1*(inputs_input_ids_dim1 - 1)
- Plan 4 : OFF : BiasGelu+:1:-1 1 20,951,040 20480.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1)
- Plan 5 : OFF : FusedMatMul+:1:-1 1 20,951,040 20480.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1)
- Plan 6 : OFF : Add+:1:-1 1 5,237,760 5120.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1)
- Plan 7 : OFF : Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1 1 4,096 4.0*inputs_input_ids_dim0*inputs_input_ids_dim1
- Plan 8 : OFF : Cast+:2:-1 1 2,048 2.0*inputs_input_ids_dim0*inputs_input_ids_dim1
```
7. Then run the training again, and you will see logs like this:
3. As shown above, `Config` is a string representative for a re-computable subgraph. All are disabled for recompute in this case.
4. Set environment variable `ORTMODULE_MEMORY_OPT_CONFIG` to enable some of the subgraphs to do recompute.
```bash
# Use comma as a separator for enabling more than one subgraphs.
export ORTMODULE_MEMORY_OPT_CONFIG="BiasGelu+:1:1"
# Explanation:
# > BiasGelu+ is the subgraph string representative;
# > 1 in the middle indicates 'Recompute' is enabled (0, on the contrary indicates it's disabled)
# > The last 1 means the initial 1 subgraph occurrences will be recomputed, all others are left as it is, filling `-1` will make all occurrences be recomputed.

```
5. Then run the training again, and you will see logs like this:
```
Memory Optimizer : ON : User config: Reshape+Where+BiasSoftmax+:1:-1, probe level: 1, available configs:
Config Freq Max Saving(B) Saving Symbolic(Bytes)
- Plan 1 : OFF : Reshape+Where+BiasSoftmax+:1:-1 5 671,088,640 640.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2
- Plan 2 : OFF : Cast+:1:-1 6 402,587,648 inputs_input_ids_dim0*inputs_input_ids_dim1*(384.0*inputs_input_ids_dim1 - 64.0)
- Plan 3 : OFF : Reshape+Where+:1:-1 1 134,217,728 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2
- Plan 4 : OFF : BiasSoftmax+:1:-1 1 134,086,656 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1*(inputs_input_ids_dim1 - 1)
- Plan 5 : ON : BiasGelu+:1:-1 6 125,808,640 inputs_input_ids_dim0*(122880.0*inputs_input_ids_dim1 - 20480.0)
- Plan 6 : OFF : FusedMatMul+:1:-1 6 125,808,640 inputs_input_ids_dim0*(122880.0*inputs_input_ids_dim1 - 20480.0)
- Plan 7 : OFF : FusedMatMul+Add+FusedMatMul+Add+Add+Add+:1:-1 5 26,214,400 25600.0*inputs_input_ids_dim0*inputs_input_ids_dim1
- Plan 8 : OFF : Add+:1:-1 1 5,237,760 5120.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1)
- Plan 9 : OFF : Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1 1 4,096 4.0*inputs_input_ids_dim0*inputs_input_ids_dim1
- Plan 10 : OFF : Cast+:2:-1 1 2,048 2.0*inputs_input_ids_dim0*inputs_input_ids_dim1
Memory Optimizer : ON : Memory Optimization Level: [USER_SPECIFIED], Optimization Config: [BiasGelu+:1:-1]
Configs Freq Max Saving(Bytes) Saving Symbolic(Bytes)
- Plan 1 : OFF : Reshape+Where+:1:-1 1 134,217,728 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2
- Plan 2 : OFF : BiasSoftmax+:1:-1 1 134,086,656 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1*(inputs_input_ids_dim1 - 1)
- Plan 3 : OFF : Cast+:1:-1 1 67,043,328 64.0*inputs_input_ids_dim0*inputs_input_ids_dim1*(inputs_input_ids_dim1 - 1)
- Plan 4 : ON : BiasGelu+:1:-1 1 20,951,040 20480.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1)
- Plan 5 : OFF : FusedMatMul+:1:-1 1 20,951,040 20480.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1)
- Plan 6 : OFF : Add+:1:-1 1 5,237,760 5120.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1)
- Plan 7 : OFF : Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1 1 4,096 4.0*inputs_input_ids_dim0*inputs_input_ids_dim1
- Plan 8 : OFF : Cast+:2:-1 1 2,048 2.0*inputs_input_ids_dim0*inputs_input_ids_dim1
```
8. You may need iterate few times on step 6 and 7 until you find a good config for this model to run a bigger batch size. Or you may fail to find if memory optimization does not apply to the model well.
6. You may need iterate a few times on step 4 and 5 until you find a good config for this model to run a bigger batch size. Or you may fail to find if memory optimization does not apply to the model well.

## Optimization Configuration

The basic optimization unit is represented with a unique `cluster id`, for example `BiasGelu+` is one `cluster id`.
Following `cluster id` is the `optimization strategy`: 0 - none, 1 - recompute, 2 - recompute with compromised memory saving.
Following `optimization strategy` is the `request count` to apply the given optimization. Using `-1` to apply all. This would give user a bit more flexibility to avoid unnecessary memory saving.

## Compromised Recompute
### Compromised Recompute

If you check the above logs, there is a config `Cast+:2:-1`, `2` indicates it's a recomputation than can save part of the stashed activation size, not all. Recompute the subgraphs under it usually will save part of the activation (for example half of them), not all of them. Follow the same way to enable it.

## Memory Optimization Debug Infos
## Dev Notes

### Memory Optimization Debug Infos

Using following log level
> ort_model = ORTModule(pt_model, DebugOptions(log_level=LogLevel.DEVINFO))
Expand Down Expand Up @@ -132,4 +162,4 @@ MemoryInsight Summary - User config: not provided

## Notes

The feature is in experimental stage, we will tune and refine it according to real use cases.
The feature is in the experimental stage, we will tune and refine it according to real use cases.
14 changes: 9 additions & 5 deletions docs/ORTModule_Training_Guidelines.md
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,6 @@ Check [DebugOptions implementation](../orttraining/orttraining/python/training/o
export ORTMODULE_ONNX_OPSET_VERSION=14
```


#### ORTMODULE_FALLBACK_POLICY

- **Feature Area**: *ORTMODULE/FallbackToPytorch*
Expand All @@ -155,7 +154,6 @@ Check [DebugOptions implementation](../orttraining/orttraining/python/training/o
export ORTMODULE_FALLBACK_POLICY="FALLBACK_DISABLE"
```


#### ORTMODULE_LOG_LEVEL

- **Feature Area**: *ORTMODULE/DebugOptions*
Expand All @@ -182,7 +180,6 @@ The output directory of the onnx models by default is set to the current working
> On the other hand, if the wrapped computation graph is small, it is reasonable to allow it.
> Overall users should be aware that ORT performance boost might be trivial when they explicitly allow it.


#### ORTMODULE_ENABLE_CUSTOM_AUTOGRAD

- **Feature Area**: *ORTMODULE/PythonOp (torch.autograd.Function)*
Expand All @@ -199,8 +196,6 @@ The output directory of the onnx models by default is set to the current working
enable_custom_autograd_support(False)
```



#### ORTMODULE_ENABLE_COMPUTE_OPTIMIZER

- **Feature Area**: *ORTMODULE/Optimizations*
Expand Down Expand Up @@ -289,6 +284,15 @@ A classical usage of disabling the deep copy: when the deep copy before module e
export ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT=0 # Disable
```

#### ORTMODULE_MEMORY_OPT_LEVEL

- **Feature Area**: *ORTMODULE/Optimizations*
- **Description**: By default, the level is 0. This env var can be used for enabling recomputation for reducing memory peak requirement. Setting the level to be 0 means all detected subgraphs with each transformer-based model layer generating stashed activations will be recomputed. This is conceptually equivalent to PyTorch's gradient checkpoint. When level is not 0, check Check [Memory Optimizer for ONNX Runtime Training](Memory_Optimizer.md) for more details.
```bash
export ORTMODULE_MEMORY_OPT_LEVEL=0
```
### 2.2 Memory Optimization
Q: *Want to run a bigger batch size?*
Expand Down
3 changes: 3 additions & 0 deletions include/onnxruntime/core/graph/constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,7 @@ constexpr const char* kAzureExecutionProvider = "AzureExecutionProvider";
constexpr const char* kExecutionProviderSharedLibraryPath = "shared_lib_path";
constexpr const char* kExecutionProviderSharedLibraryEntry = "provider_factory_entry_point";

// For Priority based graph topology sorting.
constexpr const char* kBackwardNodeAttributeName = "__backwardpass";

} // namespace onnxruntime
Loading

0 comments on commit b8ee367

Please sign in to comment.