allow layer-wise recompute
pengwa committed Nov 24, 2023
1 parent 15a0640 commit 4a88196
Showing 11 changed files with 161 additions and 23 deletions.
8 changes: 4 additions & 4 deletions docs/Memory_Optimizer.md
@@ -30,16 +30,16 @@ Integrate models using `ORTModule`.
```

There are two modes to enable the memory optimizations:
-- Aggressively Recompute All, enabled by `export ORTMODULE_MEMORY_OPT_LEVEL=1`. This will recompute all detected subgraphs. It is easy to enable, but be noted this recompute plan may NOT be the best one. In this mode `ORTMODULE_MEMORY_OPT_CONFIG` env values passed by users are not respected.
+- Aggressively Recompute All for Transformer Models, enabled by `export ORTMODULE_MEMORY_OPT_LEVEL=1`. This will recompute all detected subgraphs within each Transformer Attention or MLP layer. It is easy to enable, but note that this recompute plan may NOT be the best one. In this mode, `ORTMODULE_MEMORY_OPT_CONFIG` env values passed by users are not respected.
- User Specified Subgraph Recompute, enabled by `export ORTMODULE_MEMORY_OPT_LEVEL=0` and `export ORTMODULE_MEMORY_OPT_CONFIG=<plan1 config>,<plan2 config>,...`. This is an advanced usage that allows users to find the most suitable graphs to recompute, at the cost of the overhead of searching for the best plans; a minimal usage sketch follows below.
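As an illustration of the user-specified mode, here is a minimal Python sketch (assuming a standard onnxruntime-training install; the plan strings are example cluster ids taken from the log format shown below, so substitute the ones reported for your own model, and set the environment variables before `ORTModule` wraps the model):

```python
import os

# Example plan configs (a hypothetical choice; copy real cluster ids from the
# "Memory Optimizer" table printed in your own run's logs).
os.environ["ORTMODULE_MEMORY_OPT_LEVEL"] = "0"
os.environ["ORTMODULE_MEMORY_OPT_CONFIG"] = "BiasGelu+:1:-1,Cast+:1:-1"

import torch
from onnxruntime.training.ortmodule import ORTModule

class TinyModel(torch.nn.Module):  # placeholder model, for illustration only
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x).relu()

model = ORTModule(TinyModel())  # the env vars are read when the ORT graph is built
loss = model(torch.randn(2, 8)).sum()
loss.backward()  # recompute plans take effect in the backward pass
```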

### Mode 1 - Simple Usage (Aggressively Recompute All)


-1. Set memory optimization level to be AGGRESSIVE_FULL_RECOMPUTE, by `export ORTMODULE_MEMORY_OPT_LEVEL=1`
+1. Set the memory optimization level to TRANSFORMER_LAYERWISE_RECOMPUTE by `export ORTMODULE_MEMORY_OPT_LEVEL=1`
2. Run the training as usual, then check the logs; you should find something like this:
```
-Memory Optimizer : ON : Memory Optimization Level: [AGGRESSIVE_FULL_RECOMPUTE], Optimization Config: [Reshape+Where+:1:-1,BiasSoftmax+:1:-1,Cast+:1:-1,BiasGelu+:1:-1,FusedMatMul+:1:-1,Add+:1:-1,Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1], Probe Level: [1]
+Memory Optimizer : ON : Memory Optimization Level: [TRANSFORMER_LAYERWISE_RECOMPUTE], Optimization Config: [Reshape+Where+:1:-1,BiasSoftmax+:1:-1,Cast+:1:-1,BiasGelu+:1:-1,FusedMatMul+:1:-1,Add+:1:-1,Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1]
Configs Freq Max Saving(Bytes) Saving Symbolic(Bytes)
- Plan 1 : ON : Reshape+Where+:1:-1 1 134,217,728 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2
- Plan 2 : ON : BiasSoftmax+:1:-1 1 134,086,656 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1*(inputs_input_ids_dim1 - 1)
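As a sanity check on the symbolic savings: assuming the log above came from a run with `inputs_input_ids_dim0` (batch size) = 1 and `inputs_input_ids_dim1` (sequence length) = 1024, Plan 1's expression gives 128.0*1*1024**2 = 134,217,728 bytes and Plan 2's gives 128.0*1*1024*(1024 - 1) = 134,086,656 bytes, matching the Max Saving(Bytes) column.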
@@ -82,7 +82,7 @@ There are two modes to enable the memory optimizations:
```
5. Then run the training again, and you will see logs like this:
```
-Memory Optimizer : ON : Memory Optimization Level: [USER_SPECIFIED], Optimization Config: [BiasGelu+:1:-1], Probe Level: [1]
+Memory Optimizer : ON : Memory Optimization Level: [USER_SPECIFIED], Optimization Config: [BiasGelu+:1:-1]
Configs Freq Max Saving(Bytes) Saving Symbolic(Bytes)
- Plan 1 : OFF : Reshape+Where+:1:-1 1 134,217,728 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2
- Plan 2 : OFF : BiasSoftmax+:1:-1 1 134,086,656 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1*(inputs_input_ids_dim1 - 1)
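(Note that plans stay OFF unless their cluster ids are listed in `ORTMODULE_MEMORY_OPT_CONFIG`; in this example only `BiasGelu+:1:-1` was requested, so Plans 1 and 2 are detected but left disabled.)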
@@ -15,6 +15,7 @@
#include "orttraining/core/optimizer/memory_optimizer/optimization_planner.h"
#include "orttraining/core/optimizer/memory_optimizer/recompute_analysis.h"
#include "orttraining/core/optimizer/memory_optimizer/memory_insight.h"
#include "orttraining/core/optimizer/memory_optimizer/transformer_specific.h"

namespace onnxruntime::optimizer::memory_optimizer {

@@ -209,6 +210,9 @@ Status FindORTModuleMemoryOpportunity(const GraphViewer& graph_viewer,
is_forward_nodes,
logger));

InlinedHashSet<const Node*> layer_boundary_ln_nodes;
FindLayerBoundaryLayerNodeNodes(graph_viewer, logger, layer_boundary_ln_nodes);

// The first pass - find the candidate subgraphs.
for (int i = static_cast<int>(node_ids.size()) - 1; i >= 0; --i) {
const Node* p_node = graph_viewer.GetNode(node_ids[i]);
@@ -222,11 +226,13 @@ Status FindORTModuleMemoryOpportunity(const GraphViewer& graph_viewer,

bool can_compromise_stashed_activation = false;
std::unique_ptr<NodeRecomputePlan> recompute_plan =
-CheckNodeForRecompute(*p_node,
+CheckNodeForRecompute(graph_viewer,
+*p_node,
probe_level,
fw_op_output_arg_used_map,
node_index_to_its_order_in_topological_sort_map,
candidate_output_args_map,
layer_boundary_ln_nodes,
logger, false,
can_compromise_stashed_activation);
if (recompute_plan != nullptr) {
@@ -239,9 +245,10 @@ Status FindORTModuleMemoryOpportunity(const GraphViewer& graph_viewer,
// If the subgraph recompute can save memory by compromising the assumption that recompute subgraphs' inputs
// must exist during the backward pass, then we can consider recomputing them.
std::unique_ptr<NodeRecomputePlan> recompute_with_compromise_plan =
-CheckNodeForRecompute(*p_node, probe_level, fw_op_output_arg_used_map,
+CheckNodeForRecompute(graph_viewer, *p_node, probe_level, fw_op_output_arg_used_map,
node_index_to_its_order_in_topological_sort_map,
candidate_output_args_map,
layer_boundary_ln_nodes,
logger, true,
can_compromise_stashed_activation);
if (recompute_with_compromise_plan != nullptr) {
@@ -710,7 +717,7 @@ std::string GetSerializedORTModuleMemoryStat(const GraphViewer& graph_viewer,
ORT_ENFORCE(probe_level_int < static_cast<int>(ProbeLevel::LevelMax) &&
probe_level_int >= 0,
"Invalid probe level specified: ", recompute_probe_level);
-probe_level = static_cast<ProbeLevel>(probe_level);
+probe_level = static_cast<ProbeLevel>(probe_level_int);
}

ptrdiff_t yield_op_order_in_topological_sort;
@@ -105,6 +105,8 @@ Status MemoryOptimizationPlanner::FinalizeNodePlansFromUserConfig(
const auto& node = node_to_optimization_plan.first;
const auto& node_plans = node_to_optimization_plan.second;

std::cout << "FinalizeNodePlansFromUserConfig loop node name: " << node->Name() << std::endl;

for (auto& node_plan : node_plans) {
const std::string cluster_id = node_plan->GetClusterId();
if (cluster_id_to_user_configs.find(cluster_id) == cluster_id_to_user_configs.end()) {
@@ -9,8 +9,10 @@
#include <utility>

#include "orttraining/core/optimizer/memory_optimizer/common.h"
#include "orttraining/core/optimizer/memory_optimizer/transformer_specific.h"
#include "orttraining/core/optimizer/memory_optimizer/recompute_analysis.h"
#include "core/framework/data_types.h"
#include "core/optimizer/utils.h"

namespace onnxruntime::optimizer::memory_optimizer {

@@ -53,7 +55,7 @@ struct AllowedRecomputeNodeConfig {
InlinedVector<int> input_arg_indices; // input index to iterate further (bottom up)
};

-// The op types that are supported predefined.
+// The op types that are supported are predefined.

const InlinedHashMap<std::string, AllowedRecomputeNodeConfig>& GetAllowedRecomputeOps(int probe_op_level) {
static InlinedHashMap<int, InlinedHashMap<std::string, AllowedRecomputeNodeConfig>> recomputable_op_table_map;
@@ -131,7 +133,7 @@ bool IsRecomputable(const Node& node, ProbeLevel probe_level) {
* @param compromise_stashed_activation Whether to compromise stashed activation, e.g. if we cannot find a
* recomputable subgraph to save a stashed activation, we can compromise to find a recomputable subgraph to reduce the
* size of stashed activation.
-* @param can_compromise_stashed_activation A bool return value, to indicate there is opportunaties for finding a
+* @param can_compromise_stashed_activation A bool return value, to indicate there are opportunities for finding a
* compromised subgraph.
* @param save_ratio The ratio of memory saving if we can find a recomputable subgraph.
* @return Status
@@ -335,20 +337,48 @@ void NodesInTopoOrderToString(gsl::span<const Node* const> nodes_in_topological_

} // namespace

-std::unique_ptr<NodeRecomputePlan> CheckNodeForRecompute(const Node& node,
+std::unique_ptr<NodeRecomputePlan> CheckNodeForRecompute(const GraphViewer& graph_viewer,
+const Node& node,
const ProbeLevel probe_level,
const ActivationUsedMap& fw_op_output_arg_used_map,
const InlinedHashMap<NodeIndex, ptrdiff_t>&
node_index_to_its_order_in_topological_sort_map,
const InlinedHashMap<const Node*, InlinedVector<size_t>>&
candidate_output_args_map,
const InlinedHashSet<const Node*>& layer_boundary_ln_nodes,
const logging::Logger& logger,
bool compromise_stashed_activation,
bool& can_compromise_stashed_activation) {
if (!IsRecomputable(node, probe_level)) {
return nullptr;
}

// Check whether the node's stashed activation outputs are used by a LayerNormalization's inputs.
// If so, for Transformers, we don't need to recompute the node, because we treat
// LayerNormalization as the boundary for subgraph searching.
// Note that a Transformer layer contains an attention sublayer and an MLP sublayer, and each
// sublayer has its own LayerNormalization; we treat each such LayerNormalization as a boundary for subgraph searching.
if (probe_level == ProbeLevel::Transformers) {
// Check whether at least one of the stashed activation outputs is used as the 1st input
// of a LayerNormalization, e.g. it will be used as an input of LayerNormalizationGrad.
for (auto& output_index : candidate_output_args_map.at(&node)) {
auto output_name = node.OutputDefs()[output_index]->Name();
auto consumers = graph_viewer.GetConsumerNodes(output_name);
for (auto& consumer : consumers) {
if (layer_boundary_ln_nodes.find(consumer) != layer_boundary_ln_nodes.end()) {
int dest_in_index = optimizer_utils::IndexOfNodeInput(*consumer, *node.OutputDefs()[output_index]);
if (dest_in_index == 0) {
LOGS(logger, WARNING) << "Node " << node.Name() << "(" << node.OpType()
<< ") is an Attention+MLP layer boundary node; "
<< "its stashed activation outputs are used by a LayerNormalization's inputs, "
<< "so we don't need to recompute it.";
return nullptr;
}
}
}
}
}

InlinedVector<const Node*> nodes_in_topological_order;
float save_ratio = 1.f;
ORT_ENFORCE(SelectRecomputeSubgraph(node,
@@ -19,7 +19,8 @@ namespace onnxruntime::optimizer::memory_optimizer {
enum class ProbeLevel {
Basic = 0,
Advanced = 1,
-LevelMax = 2,
+Transformers = 2,  // On top of Advanced, LayerNorm is used as the boundary for subgraph searching.
+LevelMax = 3,
};
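In this scheme, a probe level controls which operators the bottom-up recompute subgraph search may traverse; the new `Transformers` level keeps the `Advanced` op set but additionally treats the LayerNormalization nodes collected by `FindLayerBoundaryLayerNodeNodes` as boundaries, so a recompute subgraph does not cross from one attention or MLP sublayer into another.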

/**
@@ -75,13 +76,15 @@ class NodeRecomputePlan : public NodeOptimizationPlanBase {
/**
* @brief For the node producing stashed activation, check whether a recomputable subgraph can be found or not.
*
* @param graph_viewer The graph viewer to get node information.
* @param node The entry node to start the subgraph matching (bottom-up), usually the last node of found subgraphs.
* @param probe_level The level to control allowed operations during subgraph detecting.
* @param fw_op_output_arg_used_map The activation usage (in fw and bw) mapping.
* @param node_index_to_its_order_in_topological_sort_map The mapping of node index to its order in topological sort.
* Used to re-order the collected subgraph nodes.
* @param candidate_output_args_map A map from node to its candidate activations, which are consumed by both fw and
* bw ops.
* @param layer_boundary_ln_nodes A set of LayerNormalization nodes, which are used as the boundary for subgraph.
* @param subgraph_stores A store to maintain all found subgraphs.
* @param logger Logger.
* @param compromise_stashed_activation Whether to compromise stashed activation, e.g. if we cannot find a
@@ -90,13 +93,15 @@ class NodeRecomputePlan : public NodeOptimizationPlanBase {
* @param can_compromise_stashed_activation A bool return value, to indicate there are opportunities for finding a
* compromised subgraph.
*/
-std::unique_ptr<NodeRecomputePlan> CheckNodeForRecompute(const Node& node,
+std::unique_ptr<NodeRecomputePlan> CheckNodeForRecompute(const GraphViewer& graph_viewer,
+const Node& node,
const ProbeLevel probe_level,
const ActivationUsedMap& fw_op_output_arg_used_map,
const InlinedHashMap<NodeIndex, ptrdiff_t>&
node_index_to_its_order_in_topological_sort_map,
const InlinedHashMap<const Node*, InlinedVector<size_t>>&
candidate_output_args_map,
const InlinedHashSet<const Node*>& layer_boundary_ln_nodes,
const logging::Logger& logger,
bool compromise_stashed_activation,
bool& can_compromise_stashed_activation);
68 changes: 68 additions & 0 deletions orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc
@@ -0,0 +1,68 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "orttraining/core/optimizer/memory_optimizer/transformer_specific.h"

#include <charconv>
#include <deque>
#include <set>
#include <utility>
#include <vector>

#include "orttraining/core/optimizer/memory_optimizer/common.h"
#include "core/graph/graph_utils.h"
#include "core/optimizer/utils.h"
#include "core/graph/graph_viewer.h"
#include "core/framework/tensorprotoutils.h"

#include "core/common/string_utils.h"

namespace onnxruntime::optimizer::memory_optimizer {

void FindLayerBoundaryLayerNodeNodes(
const GraphViewer& graph_viewer,
const logging::Logger&,
InlinedHashSet<const Node*>& layer_boundary_ln_nodes) {
// Loop over all nodes to find LayerNormalization nodes.
// For each LayerNormalization node, keep checking its output nodes
// until we find a node that is Softmax, BiasSoftmax, or another LayerNormalization.
// If the found node is Softmax or BiasSoftmax, classify the LayerNormalization node as ATTENTION.
// If the found node is another LayerNormalization, classify the LayerNormalization node as MLP.
const InlinedHashSet<std::string_view> softmax_ops{"Softmax", "BiasSoftmax"};
const InlinedHashSet<std::string_view> layernorm_ops{"LayerNormalization", "SkipLayerNormalization"};

layer_boundary_ln_nodes.clear();
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
for (auto node_index : node_topology_list) {
auto& node = *graph_viewer.GetNode(node_index);

if (layernorm_ops.find(node.OpType()) == layernorm_ops.end()) {
continue;
}

std::deque<const Node*> nodes_to_check;
std::set<const Node*> visited_nodes;
for (auto node_it = node.OutputNodesBegin(); node_it != node.OutputNodesEnd(); ++node_it) {
nodes_to_check.push_back(&(*node_it));
}

while (!nodes_to_check.empty()) {
const Node* next_node = nodes_to_check.front();
nodes_to_check.pop_front();

if (visited_nodes.find(next_node) != visited_nodes.end()) {
continue;
}

visited_nodes.insert(next_node);
if (softmax_ops.find(next_node->OpType()) != softmax_ops.end()) {
layer_boundary_ln_nodes.insert(&node);
break;
} else if (layernorm_ops.find(next_node->OpType()) != layernorm_ops.end()) {
break;
} else {
for (auto node_it = next_node->OutputNodesBegin(); node_it != next_node->OutputNodesEnd(); ++node_it) {
nodes_to_check.push_back(&(*node_it));
}
}
}
}
}

} // namespace onnxruntime::optimizer::memory_optimizer
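To make the traversal above concrete, here is a small Python mirror of the same classification logic (illustration only, not part of the commit; the function name and the adjacency-list graph representation are invented for this sketch):

```python
from collections import deque

SOFTMAX_OPS = {"Softmax", "BiasSoftmax"}
LAYERNORM_OPS = {"LayerNormalization", "SkipLayerNormalization"}

def find_layer_boundary_ln_nodes(consumers, op_type):
    """Return LayerNorm nodes that reach a Softmax before any other LayerNorm
    when walking consumer edges, i.e. the attention-side layer boundaries."""
    boundaries = set()
    for node, op in op_type.items():
        if op not in LAYERNORM_OPS:
            continue
        to_check = deque(consumers.get(node, []))
        visited = set()
        while to_check:
            nxt = to_check.popleft()
            if nxt in visited:
                continue
            visited.add(nxt)
            if op_type[nxt] in SOFTMAX_OPS:
                boundaries.add(node)  # LN feeding attention: record as a boundary
                break
            if op_type[nxt] in LAYERNORM_OPS:
                break  # hit the next sublayer's LN first: MLP side, not recorded
            to_check.extend(consumers.get(nxt, []))
    return boundaries

# Toy layer: ln_attn -> MatMul -> Softmax (attention path),
#            ln_mlp  -> Gelu   -> ln_next (another LayerNorm, so MLP path).
op_type = {"ln_attn": "LayerNormalization", "matmul": "MatMul",
           "softmax": "Softmax", "ln_mlp": "LayerNormalization",
           "gelu": "Gelu", "ln_next": "LayerNormalization"}
consumers = {"ln_attn": ["matmul"], "matmul": ["softmax"],
             "ln_mlp": ["gelu"], "gelu": ["ln_next"]}
print(find_layer_boundary_ln_nodes(consumers, op_type))  # {'ln_attn'}
```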
25 changes: 25 additions & 0 deletions orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.h
@@ -0,0 +1,25 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <memory>
#include <string>
#include <unordered_map>
#include <utility>

#include "core/common/common.h"
#include "core/common/logging/logging.h"
#include "core/common/inlined_containers_fwd.h"
#include "core/graph/basic_types.h"
#include "core/framework/data_types.h"
#include "core/graph/graph_viewer.h"
#include "orttraining/core/optimizer/memory_optimizer/common.h"

namespace onnxruntime::optimizer::memory_optimizer {

void FindLayerBoundaryLayerNodeNodes(const GraphViewer& graph_viewer,
const logging::Logger& logger,
InlinedHashSet<const Node*>& layer_boundary_ln_nodes);

} // namespace onnxruntime::optimizer::memory_optimizer
@@ -652,17 +652,16 @@ def _add_record(tbl, columns):
)

opt_config_to_display = self._runtime_options.memory_optimizer_config
-if self._runtime_options.memory_optimization_level == _MemoryOptimizationLevel.AGGRESSIVE_FULL_RECOMPUTE:
-opt_config_to_display = "ALL_RECOMPUTE_CONFIGS"
+if self._runtime_options.memory_optimization_level == _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE:
+opt_config_to_display = "ALL_RECOMPUTE_FOR_EACH_LAYER"
mem_row = _add_record(
tbl,
[
"Memory Optimizer",
len(self._runtime_options.memory_optimizer_config) > 0,
(
f"Memory Optimization Level: [{_MemoryOptimizationLevel.to_string(self._runtime_options.memory_optimization_level)}], "
f"Optimization Config: [{opt_config_to_display}], "
f"Probe Level: [{self._runtime_options.probe_level}]"
f"Optimization Config: [{opt_config_to_display}]"
if len(self._runtime_options.memory_optimizer_config) > 0
else "Enable with env ORTMODULE_MEMORY_OPT_LEVEL=1 or ORTMODULE_MEMORY_OPT_CONFIG=<plan1 config>,<plan2 config>,..."
),