Skip to content

Commit

Permalink
Flash attention recompute (#20603)
Browse files Browse the repository at this point in the history
### Flash attn recompute

1. Allow PythonOp(FlashAttn) can be recomputed correctly.
45879ff
2. Use JSON to pass the selected-to-recompute subgraphs.
3c374da

#### Better Memory Efficiency 

Customer model can run both PyTorch SPDA and Flash Attn, this PR make it
possible to let the Flash Attn path work with ORTModule layerwise
recompute. The peak drop from 45.xGB to 32.xGB if we only compare the
layers (not including other pieces, BTW there are few more optimization
targeting other pieces as well later).

#### Better Perf

Using Flash ATTN bring additionally 16% end to end time reduction, with
highly aligned loss curve.


![image](https://github.com/microsoft/onnxruntime/assets/10530022/bb63894a-f281-49bc-a8e6-ff818439be38)

#### Use JSON File to pass Recompute Plans

To overcome the limitation of max length of the strings defined in
session options.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
  • Loading branch information
pengwa authored May 21, 2024
1 parent 8acf60f commit 8a98874
Show file tree
Hide file tree
Showing 25 changed files with 1,002 additions and 381 deletions.
1 change: 1 addition & 0 deletions cmake/onnxruntime_optimizer.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ onnxruntime_add_include_to_target(onnxruntime_optimizer onnxruntime_common onnxr
target_include_directories(onnxruntime_optimizer PRIVATE ${ONNXRUNTIME_ROOT})
if (onnxruntime_ENABLE_TRAINING)
target_include_directories(onnxruntime_optimizer PRIVATE ${ORTTRAINING_ROOT})
onnxruntime_add_include_to_target(onnxruntime_optimizer nlohmann_json::nlohmann_json)
if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
onnxruntime_add_include_to_target(onnxruntime_optimizer Python::Module)
endif()
Expand Down
36 changes: 25 additions & 11 deletions docs/Memory_Optimizer.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,22 @@ Integrate models using `ORTModule`.

There are two modes to enable the memory optimizations:
- Transformer layerwise recompute, e.g. aggressively recompute all supported nodes within each transformer layer (usually including attention and mlp sublayers), enabled by `export ORTMODULE_MEMORY_OPT_LEVEL=1`. In this mode, `ORTMODULE_MEMORY_OPT_CONFIG` env values passed by users are not respected.
- Manual selected subgraph recompute, enabled by `export ORTMODULE_MEMORY_OPT_LEVEL=0` and `export ORTMODULE_MEMORY_OPT_CONFIG=<plan1 config>,<plan2 config>,...`. This is an advanced usage, that allows users to find the most suitable graphs to recompute, at the cost of overhead to look for the best plans.
- Manual selected subgraph recompute, enabled by `export ORTMODULE_MEMORY_OPT_LEVEL=0` and `export ORTMODULE_MEMORY_OPT_CONFIG=<config file path>`. This is an advanced usage, that allows users to find the most suitable graphs to recompute, at the cost of overhead to look for the best plans. The format for its content is:
```
[
"<plan1 config>",
"<plan2 config>",
...
]
```

### Mode 1 - Simple Usage (Transformer Layerwise Recompute)


1. Set memory optimization level to be TRANSFORMER_LAYERWISE_RECOMPUTE, by `export ORTMODULE_MEMORY_OPT_LEVEL=1`
2. Run the training as usual; check the logs, you could find something like this if the current log level <= LogLevel.INFO:
```
Memory Optimizer : ON : Memory Optimization Level: [TRANSFORMER_LAYERWISE_RECOMPUTE], Optimization Config: [Reshape+Where+:1:-1,BiasSoftmax+:1:-1,Cast+:1:-1,BiasGelu+:1:-1,FusedMatMul+:1:-1,Add+:1:-1,Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1]
Memory Optimizer : ON : Memory Optimization Level: [TRANSFORMER_LAYERWISE_RECOMPUTE], Optimization Config: mem_opt.json
Configs Freq Max Saving(Bytes) Saving Symbolic(Bytes)
- Plan 1 : ON : Reshape+Where+:1:-1 1 134,217,728 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2
- Plan 2 : ON : BiasSoftmax+:1:-1 1 134,086,656 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1*(inputs_input_ids_dim1 - 1)
Expand All @@ -59,7 +66,7 @@ There are two modes to enable the memory optimizations:
1. Be noted `ORTMODULE_MEMORY_OPT_LEVEL` is by default be 0. Run the training as usual; then stop it after training a few steps.
2. Check the logs, you could find something like this if the current log level <= LogLevel.INFO::
```
Memory Optimizer : OFF : Enable with env ORTMODULE_MEMORY_OPT_LEVEL=1 or ORTMODULE_MEMORY_OPT_CONFIG=<plan1 config>,<plan2 config>,...
Memory Optimizer : OFF : Enable with env ORTMODULE_MEMORY_OPT_LEVEL=1 or ORTMODULE_MEMORY_OPT_CONFIG=<config file path>
Configs Freq Max Saving(Bytes) Saving Symbolic(Bytes)
- Plan 1 : OFF : Reshape+Where+:1:-1 1 134,217,728 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2
- Plan 2 : OFF : BiasSoftmax+:1:-1 1 134,086,656 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1*(inputs_input_ids_dim1 - 1)
Expand All @@ -73,8 +80,15 @@ There are two modes to enable the memory optimizations:
3. As shown above, `Config` is a string representative for a re-computable subgraph. All are disabled for recompute in this case.
4. Set environment variable `ORTMODULE_MEMORY_OPT_CONFIG` to enable some of the subgraphs to do recompute.
```bash
# Use comma as a separator for enabling more than one subgraphs.
export ORTMODULE_MEMORY_OPT_CONFIG="BiasGelu+:1:1"
export ORTMODULE_MEMORY_OPT_CONFIG="mem_opt.json"

# Content of mem_opt.json:
[
"BiasGelu+:1:1",
"Dropout+:1:-1"
]
# Use comma as a separator for enabling more than one subgraphs in the json file.

# Explanation:
# > BiasGelu+ is the subgraph string representative;
# > 1 in the middle indicates 'Recompute' is enabled (0, on the contrary indicates it's disabled)
Expand All @@ -83,7 +97,7 @@ There are two modes to enable the memory optimizations:
```
5. Then run the training again, and you will see logs like this:
```
Memory Optimizer : ON : Memory Optimization Level: [USER_SPECIFIED], Optimization Config: [BiasGelu+:1:-1]
Memory Optimizer : ON : Memory Optimization Level: [USER_SPECIFIED], Optimization Config: mem_opt.json
Configs Freq Max Saving(Bytes) Saving Symbolic(Bytes)
- Plan 1 : OFF : Reshape+Where+:1:-1 1 134,217,728 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2
- Plan 2 : OFF : BiasSoftmax+:1:-1 1 134,086,656 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1*(inputs_input_ids_dim1 - 1)
Expand Down Expand Up @@ -127,34 +141,34 @@ MemoryInsight Summary - User config: not provided
|6 |For each row options are mutually exclusive, only one of them can be enabled. |
| | |
| |>>Option 1 : Recompute subgraph FusedMatMul+Add+Reshape+ |
| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=FusedMatMul+Add+Reshape+:1:-1 |
| | Status : Disabled. |
| | Stashed Activations: |
| | - ReuseFreq : Output 0(6), |
| | - Output 0 : [((inputs_input_ids_dim0)*(inputs_input_ids_dim1)*(32)*(240))], byte/elem: 2, 100% saved |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _|
|5 |For each row options are mutually exclusive, only one of them can be enabled. |
| | |
| |>>Option 1 : Recompute subgraph FusedMatMul+ |
| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=FusedMatMul+:1:-1 |
| | Status : Disabled. |
| | Stashed Activations: |
| | - Output 0 : [((inputs_input_ids_dim0)*(inputs_input_ids_dim1)*(10240))], byte/elem: 2, 100% saved |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _|
|5 |For each row options are mutually exclusive, only one of them can be enabled. |
| | |
| |>>Option 1 : Recompute subgraph Cast+ |
| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Cast+:1:-1 |
| | Status : Disabled. |
| | Stashed Activations: |
| | - Output 0 : [((inputs_input_ids_dim0)*(32)*(inputs_input_ids_dim1)*(inputs_input_ids_dim1))], byte/elem: 2, 100% saved |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _|
|1 |For each row options are mutually exclusive, only one of them can be enabled. |
| | |
| |>>Option 1 : Recompute subgraph Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+ |
| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1 |
| | Status : Disabled. |
| | Stashed Activations: |
| | - Output 0 : [((inputs_input_ids_dim0)*(1)*(1)*(inputs_input_ids_dim1))], byte/elem: 4, 100% saved |
| | |
| |>>Option 2 : RecomputeWithCompromise subgraph Cast+ |
| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Cast+:2:-1 |
| | Status : Disabled. |
| | Stashed Activations: |
| | - Output 0 : [((inputs_input_ids_dim0)*(1)*(1)*(inputs_input_ids_dim1))], byte/elem: 4, 50% saved |
|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _|
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,15 +78,20 @@ static const char* const kOrtSessionOptionsEnableGeluApproximation = "optimizati
static const char* const kOrtSessionOptionsDisableAheadOfTimeFunctionInlining = "session.disable_aot_function_inlining";

#ifdef ENABLE_TRAINING
// Specifies a list of op types for memory footprint reduction.
// The value should be a ","-delimited list of pair of
// <subgraph string: optimization strategy: number of subgraph to apply>.
// For example, "Gelu+Cast+:1:0,Dropout+:1:1".
// A valid "subgraph string" should be one subgraph representation output by ORT graph transformations.
// "optimization strategy" currently has valid values: 0 - disabled, 1 - recompute.
// "number of subgraph to apply" is used to control how many subgraphs to apply optimization, to avoid "oversaving"
// the memory.
static const char* const kOrtSessionOptionsMemoryOptimizerEnabler = "optimization.memory_optimizer_config";
// Specifies a path of the file containing a list of memory optimization configurations.
// The value should be a string indicating the file path of the config file.
// The content of the config file is a JSON struct like this:
// [
// "Gelu+Cast+:1:0",
// "Dropout+:1:1"
// ]
// Taking the example of "Gelu+Cast+:1:0",
// > "Gelu+Cast+" is the subgraph string, a valid "subgraph string" should be one subgraph representation
// output by ORT graph transformations.
// > "1" is "optimization strategy", valid values: 0 - disabled, 1 - recompute.
// > "0" is "number of subgraph to apply" which is used to control how many subgraphs to apply optimization,
// to avoid "oversaving" the memory.
static const char* const kOrtSessionOptionsMemoryOptimizerApplyConfig = "optimization.memory_optimizer_config";

// Specifies the config for detecting subgraphs for memory footprint reduction.
// The value should be a string contains int separated using commas. The default value is "0:0".
Expand Down
8 changes: 4 additions & 4 deletions onnxruntime/core/session/inference_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1266,15 +1266,15 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool
}

#ifdef ENABLE_TRAINING
// Enable memory optimizations (mainly insert recomputation nodes with priority).
// Enable memory optimizations.
// Only applicable for training scenarios.
{
const std::string memory_optimizer_config =
session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsMemoryOptimizerEnabler, "");
const std::string memory_optimizer_config_file =
session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsMemoryOptimizerApplyConfig, "");
const std::string probe_config =
session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsMemoryOptimizerProbeConfig, "0:0");

MemoryOptimizer mem_transformer{memory_optimizer_config, probe_config};
MemoryOptimizer mem_transformer{memory_optimizer_config_file, probe_config};
ORT_RETURN_IF_ERROR_SESSIONID_(apply_transformer_once(mem_transformer, *session_logger_, graph));
}
#endif
Expand Down
2 changes: 2 additions & 0 deletions orttraining/orttraining/core/agent/training_agent.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ void TrainingAgent::CreateAndInitializeFeedsFetchesManager(const SessionState& s

std::string TrainingAgent::GetSerializedORTModuleMemoryStat(std::string_view memory_optimization_config,
std::string_view recompute_probe_level,
const bool return_opportunity_table,
std::map<std::string, std::pair<std::string, int>>&
cluster_id_combinations_to_saved_symbolic_byte_map)
const {
Expand All @@ -120,6 +121,7 @@ std::string TrainingAgent::GetSerializedORTModuleMemoryStat(std::string_view mem
session_state.GetGraphViewer(),
memory_optimization_config,
recompute_probe_level,
return_opportunity_table,
*inference_session_.GetLogger(),
cluster_id_combinations_to_saved_symbolic_byte_map,
&ortvalue_name_to_idx_map,
Expand Down
1 change: 1 addition & 0 deletions orttraining/orttraining/core/agent/training_agent.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class TrainingAgent {

std::string GetSerializedORTModuleMemoryStat(std::string_view memory_optimization_config,
std::string_view recompute_probe_level,
const bool return_opportunity_table,
std::map<std::string, std::pair<std::string, int>>&
cluster_id_combinations_to_saved_symbolic_byte_map) const;

Expand Down
76 changes: 48 additions & 28 deletions orttraining/orttraining/core/optimizer/memory_optimizer/common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,23 @@
// Licensed under the MIT License.

#include <charconv>
#include <fstream>
#include <vector>
#include <utility>

#include "orttraining/core/optimizer/memory_optimizer/common.h"
#include "core/common/string_utils.h"
#include "core/framework/tensorprotoutils.h"
#include "core/graph/graph_utils.h"
#include "core/optimizer/utils.h"
#include "core/graph/graph_viewer.h"
#include "core/framework/tensorprotoutils.h"
#include "core/optimizer/utils.h"

#include "core/common/string_utils.h"
#include "nlohmann/json.hpp"

namespace onnxruntime::optimizer::memory_optimizer {

using json = nlohmann::json;

namespace {

constexpr const char empty_dim_param_placeholder[] = "empty_dim_param";
Expand Down Expand Up @@ -114,32 +118,48 @@ int ParseIntValueFromString(std::string_view str) {
return int_value;
}

Status ParseOptimizationConfigFromString(std::string_view memory_optimization_config,
void from_json(const json& j, UserConfig& mo) {
j.at("type").get_to(mo.type);
j.at("requested_count").get_to(mo.requested_count);
}

Status ParseOptimizationConfigFromString(std::string_view memory_optimization_config_file_path,
InlinedHashMap<std::string, UserConfig>& cluster_id_to_config_map) {
if (!memory_optimization_config.empty()) {
const auto user_config_strs = utils::SplitString(memory_optimization_config, ",");
for (const auto& user_config_str : user_config_strs) {
const auto user_config = utils::SplitString(user_config_str, ":");
ORT_RETURN_IF_NOT(user_config.size() == 3,
"User config should be in the format of SubgraphStr:OptimizationType:RequestApplyCount.");

const std::string subgraph_string_representation(user_config[0]);
int optimization_type_int = ParseIntValueFromString(user_config[1]);
int requested_apply_count = ParseIntValueFromString(user_config[2]);
ORT_RETURN_IF_NOT(optimization_type_int <
static_cast<int>(OptimizationType::TypeMax) &&
optimization_type_int >= 0,
"Invalid optimization type specified for subgraph: ",
subgraph_string_representation);

ORT_RETURN_IF_NOT(requested_apply_count == -1 || requested_apply_count >= 0,
"Invalid requested_apply_count specified for subgraph: ", requested_apply_count);

// At this point, subgraph_string_representation is a pattern graph string representation.
// If a duplicated subgraph_string_representation is found in user config, the last one will be used.
cluster_id_to_config_map[subgraph_string_representation] = UserConfig{
static_cast<OptimizationType>(optimization_type_int),
requested_apply_count};
if (!memory_optimization_config_file_path.empty()) {
InlinedVector<std::string> configs_by_cluster_id; // Each cluster_id might contains multiple plans.
try {
std::ifstream in{std::string(memory_optimization_config_file_path).c_str()};
const json j = json::parse(in);
j.get_to<InlinedVector<std::string>>(configs_by_cluster_id);
} catch (const std::exception& ex) {
ORT_THROW("Fail to parse from json file: ", ex.what());
}

for (const auto& config_for_cur_cluster : configs_by_cluster_id) {
const auto configs_by_plan_id = utils::SplitString(config_for_cur_cluster, ",");
for (const auto& config_for_cur_plan : configs_by_plan_id) {
const auto user_config = utils::SplitString(config_for_cur_plan, ":");
ORT_RETURN_IF_NOT(user_config.size() == 3,
"User config should be in the format of SubgraphStr:OptimizationType:RequestApplyCount.");

const std::string subgraph_string_representation(user_config[0]);
int optimization_type_int = ParseIntValueFromString(user_config[1]);
int requested_apply_count = ParseIntValueFromString(user_config[2]);
ORT_RETURN_IF_NOT(optimization_type_int <
static_cast<int>(OptimizationType::TypeMax) &&
optimization_type_int >= 0,
"Invalid optimization type specified for subgraph: ",
subgraph_string_representation);

ORT_RETURN_IF_NOT(requested_apply_count == -1 || requested_apply_count >= 0,
"Invalid requested_apply_count specified for subgraph: ", requested_apply_count);

// At this point, subgraph_string_representation is a pattern graph string representation.
// If a duplicated subgraph_string_representation is found in user config, the last one will be used.
cluster_id_to_config_map[subgraph_string_representation] = UserConfig{
static_cast<OptimizationType>(optimization_type_int),
requested_apply_count};
}
}
}

Expand Down
Loading

0 comments on commit 8a98874

Please sign in to comment.