Commit 7217ad3: docs

pengwa committed Nov 23, 2023 · 1 parent 7164c2f
Showing 2 changed files with 48 additions and 35 deletions.
74 changes: 39 additions & 35 deletions docs/Memory_Optimizer.md

Not all models and recipes need this optimizer technique. For example, if your training recipe already uses a batch size of 6 (GPU compute and memory are fully saturated) and you don't need to bump it to 8 to maintain a fixed global batch size, enabling recompute at batch size 8 may not bring better throughput than the original batch size 6.

## Usage

Make sure the ONNX Runtime training wheel is installed and correctly configured, then integrate your model using `ORTModule`:
> ort_model = ORTModule(pt_model)
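
For reference, here is the one-line blockquote above expanded into a minimal runnable sketch; the toy model, data, and optimizer are placeholders, not part of this doc:

```python
import torch
from onnxruntime.training.ortmodule import ORTModule

# Toy stand-in for your real PyTorch model (placeholder).
pt_model = torch.nn.Sequential(
    torch.nn.Linear(128, 256), torch.nn.GELU(), torch.nn.Linear(256, 10)
)
ort_model = ORTModule(pt_model)  # wrap once; the training loop stays unchanged

optimizer = torch.optim.AdamW(ort_model.parameters(), lr=1e-4)
loss_fn = torch.nn.CrossEntropyLoss()

for _ in range(3):  # a few toy steps with random data
    x = torch.randn(4, 128)
    y = torch.randint(0, 10, (4,))
    optimizer.zero_grad()
    loss = loss_fn(ort_model(x), y)
    loss.backward()
    optimizer.step()
```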

There are two modes for enabling the memory optimizations:
- Aggressively Recompute All, enabled by `export ORTMODULE_MEMORY_OPT_LEVEL=1`, recomputes all detected subgraphs. The advantage of this mode is that it is easy to enable; be aware, however, that the resulting recompute plan may NOT be the best one.
- User Specified Subgraph Recompute, enabled by `export ORTMODULE_MEMORY_OPT_LEVEL=0` and `export ORTMODULE_MEMORY_OPT_CONFIG=<plan1 config>,<plan2 config>,...`. This is an advanced usage that lets users find the most suitable subgraphs to recompute, at the cost of the overhead of searching for the best plans.
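
For quick orientation, the corresponding environment settings look like this (the plan configs in the second mode are examples taken from the logs shown below):

```bash
# Mode 1: aggressively recompute all detected subgraphs.
export ORTMODULE_MEMORY_OPT_LEVEL=1

# Mode 2: recompute only user-specified subgraphs.
export ORTMODULE_MEMORY_OPT_LEVEL=0
export ORTMODULE_MEMORY_OPT_CONFIG="BiasGelu+:1:-1,FusedMatMul+:1:-1"
```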

### Mode 1 - Simple Usage (Aggressively Recompute All)


1. Set the memory optimization level to AGGRESSIVE_FULL_RECOMPUTE with `export ORTMODULE_MEMORY_OPT_LEVEL=1`.
2. Run the training as usual; in the logs, you will find something like this:
```
Memory Optimizer : ON : Memory Optimization Level: [AGGRESSIVE_FULL_RECOMPUTE], Optimization Config: [Reshape+Where+:1:-1,BiasSoftmax+:1:-1,Cast+:1:-1,BiasGelu+:1:-1,FusedMatMul+:1:-1,Add+:1:-1,Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1], Probe Level: [1]
Configs Freq Max Saving(Bytes) Saving Symbolic(Bytes)
- Plan 1 : ON : Reshape+Where+:1:-1 1 134,217,728 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2
- Plan 2 : ON : BiasSoftmax+:1:-1 1 134,086,656 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1*(inputs_input_ids_dim1 - 1)
- Plan 3 : ON : Cast+:1:-1 1 67,043,328 64.0*inputs_input_ids_dim0*inputs_input_ids_dim1*(inputs_input_ids_dim1 - 1)
- Plan 4 : ON : BiasGelu+:1:-1 1 20,951,040 20480.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1)
- Plan 5 : ON : FusedMatMul+:1:-1 1 20,951,040 20480.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1)
- Plan 6 : ON : Add+:1:-1 1 5,237,760 5120.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1)
- Plan 7 : ON : Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1 1 4,096 4.0*inputs_input_ids_dim0*inputs_input_ids_dim1
- Plan 8 : OFF : Cast+:2:-1 1 2,048 2.0*inputs_input_ids_dim0*inputs_input_ids_dim1
```
3. As shown above, each `Config` is a string representation of a re-computable subgraph. All are enabled for recompute in this case.
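
To verify the effect on your own run, one option (a sketch using standard PyTorch CUDA statistics, not something this optimizer requires) is to compare peak memory with and without the env var set:

```python
import torch

torch.cuda.reset_peak_memory_stats()
# ... run a few training steps, once with and once without
#     ORTMODULE_MEMORY_OPT_LEVEL=1 set in the environment ...
print(f"Peak GPU memory: {torch.cuda.max_memory_allocated() / 1024**2:.1f} MiB")
```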


### Mode 2 - Advanced Usage (User Selected Subgraph Recompute)

1. Note that `ORTMODULE_MEMORY_OPT_LEVEL` is 0 by default. Run the training as usual, then stop it after a few training steps.
2. Check the logs; you will find something like this:
```
Memory Optimizer : OFF : Enable with env ORTMODULE_MEMORY_OPT_LEVEL=1 or ORTMODULE_MEMORY_OPT_CONFIG=<plan1 config>,<plan2 config>,...
Configs Freq Max Saving(Bytes) Saving Symbolic(Bytes)
- Plan 1 : OFF : Reshape+Where+:1:-1 1 134,217,728 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2
- Plan 2 : OFF : BiasSoftmax+:1:-1 1 134,086,656 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1*(inputs_input_ids_dim1 - 1)
- Plan 3 : OFF : Cast+:1:-1 1 67,043,328 64.0*inputs_input_ids_dim0*inputs_input_ids_dim1*(inputs_input_ids_dim1 - 1)
- Plan 4 : OFF : BiasGelu+:1:-1 1 20,951,040 20480.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1)
- Plan 5 : OFF : FusedMatMul+:1:-1 1 20,951,040 20480.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1)
- Plan 6 : OFF : Add+:1:-1 1 5,237,760 5120.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1)
- Plan 7 : OFF : Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1 1 4,096 4.0*inputs_input_ids_dim0*inputs_input_ids_dim1
- Plan 8 : OFF : Cast+:2:-1 1 2,048 2.0*inputs_input_ids_dim0*inputs_input_ids_dim1
```
3. As shown above, each `Config` is a string representation of a re-computable subgraph. All are disabled for recompute in this case.
4. Set the environment variable `ORTMODULE_MEMORY_OPT_CONFIG` to enable recompute for some of the subgraphs:
```bash
# Use a comma as the separator to enable more than one subgraph.
export ORTMODULE_MEMORY_OPT_CONFIG="BiasGelu+:1:1"
# With the config above, 1 `BiasGelu+`-related subgraph is allowed to recompute:
# - `BiasGelu+` is the subgraph's string representation;
# - the `1` in the middle indicates 'Recompute' is enabled (`0` indicates it is disabled);
# - the last `1` means only the first 1 occurrence of the subgraph will be recomputed,
#   leaving all others as they are; use `-1` to recompute all occurrences.
```
5. Then run the training again, and you will see logs like this:
```
Memory Optimizer : ON : Memory Optimization Level: [USER_SPECIFIED], Optimization Config: [BiasGelu+:1:-1], Probe Level: [1]
Configs Freq Max Saving(Bytes) Saving Symbolic(Bytes)
- Plan 1 : OFF : Reshape+Where+:1:-1 1 134,217,728 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2
- Plan 2 : OFF : BiasSoftmax+:1:-1 1 134,086,656 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1*(inputs_input_ids_dim1 - 1)
- Plan 3 : OFF : Cast+:1:-1 1 67,043,328 64.0*inputs_input_ids_dim0*inputs_input_ids_dim1*(inputs_input_ids_dim1 - 1)
- Plan 4 : ON : BiasGelu+:1:-1 1 20,951,040 20480.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1)
- Plan 5 : OFF : FusedMatMul+:1:-1 1 20,951,040 20480.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1)
- Plan 6 : OFF : Add+:1:-1 1 5,237,760 5120.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1)
- Plan 7 : OFF : Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1 1 4,096 4.0*inputs_input_ids_dim0*inputs_input_ids_dim1
- Plan 8 : OFF : Cast+:2:-1 1 2,048 2.0*inputs_input_ids_dim0*inputs_input_ids_dim1
```
6. You may need to iterate over steps 4 and 5 a few times until you find a good config that lets this model run with a bigger batch size (see the sketch below). You may also fail to find one, if memory optimization does not apply well to the model.
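
For example, the iteration over steps 4 and 5 might look like this; the plan names are taken from the example logs above, and `train.py` is a placeholder for your launch command:

```bash
# ORTMODULE_MEMORY_OPT_LEVEL stays at its default of 0.

# Iteration 1: enable the plan with the largest reported saving.
export ORTMODULE_MEMORY_OPT_CONFIG="Reshape+Where+:1:-1"
python train.py

# Iteration 2: if the target batch size still does not fit, add more plans.
export ORTMODULE_MEMORY_OPT_CONFIG="Reshape+Where+:1:-1,BiasSoftmax+:1:-1,BiasGelu+:1:-1"
python train.py
```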

## Optimization Configuration

9 changes: 9 additions & 0 deletions docs/ORTModule_Training_Guidelines.md

#### ORTMODULE_MEMORY_OPT_LEVEL

- **Feature Area**: *ORTMODULE/Optimizations*
- **Description**: By default, the level is 0. This env var can be used to enable recomputation, reducing the peak memory requirement. Setting the level to 1 means all detected subgraphs generating stashed activations will be recomputed; this is conceptually equivalent to PyTorch's gradient checkpointing. Check [Memory Optimizer for ONNX Runtime Training](Memory_Optimizer.md) for more details.

```bash
export ORTMODULE_MEMORY_OPT_LEVEL=0
```
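
To actually turn on recomputation, set the level to 1 as described above:

```bash
export ORTMODULE_MEMORY_OPT_LEVEL=1
```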

### 2.2 Memory Optimization

Q: *Want to run a bigger batch size?*
