Fix transformer layer detection for recompute #20106

Merged · 10 commits · Mar 29, 2024
@@ -257,10 +257,15 @@ Status FindORTModuleMemoryOpportunity(const GraphViewer& graph_viewer,
is_forward_nodes,
logger));

InlinedHashSet<const Node*> layer_boundary_ln_nodes;
InlinedVector<const Node*> layer_boundary_ln_nodes;
FindLayerBoundaryLayerNormNodes(graph_viewer, logger, node_index_to_its_order_in_topological_sort_map,
yield_op_order_in_topological_sort, layer_boundary_ln_nodes);

if (probe_config.enable_transformer_layer_as_boundary && layer_boundary_ln_nodes.size() == 0) {
LOGS(logger, WARNING) << "No transformer layer boundary nodes found; memory optimization "
"may not work as expected. Please check the model and the configuration.";
}

// The first pass - find the candidate subgraphs.
for (int i = static_cast<int>(node_ids.size()) - 1; i >= 0; --i) {
const Node* p_node = graph_viewer.GetNode(node_ids[i]);
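The hunk above switches layer_boundary_ln_nodes from InlinedHashSet to InlinedVector, so the boundary LayerNorm nodes retain the order in which they are discovered during the topological walk, which is what lets the new test at the bottom of this PR assert on specific indices. Below is a minimal standalone sketch of that trade-off, using standard containers in place of ORT's InlinedHashSet/InlinedVector; the names and values are illustrative only, not taken from the PR.

#include <algorithm>
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

int main() {
  // Boundary LayerNorm names in the order they would be discovered (illustrative values).
  const std::vector<std::string> discovered = {"ln_layer_0", "ln_layer_1", "ln_layer_2"};

  std::unordered_set<std::string> as_set(discovered.begin(), discovered.end());
  std::vector<std::string> as_vector(discovered.begin(), discovered.end());

  // Membership check, done with std::find when the container is a vector.
  const std::string query = "ln_layer_1";
  const bool in_vector =
      std::find(as_vector.begin(), as_vector.end(), query) != as_vector.end();
  std::cout << "found in vector: " << std::boolalpha << in_vector << "\n";

  // Iteration: the vector reports boundaries in discovery (topological) order,
  // while the unordered_set iterates in an implementation-defined order.
  for (const auto& name : as_vector) std::cout << name << " ";
  std::cout << "\n";
  for (const auto& name : as_set) std::cout << name << " ";  // order not guaranteed
  std::cout << "\n";
  return 0;
}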
@@ -386,6 +386,26 @@ const InlinedHashMap<std::string, OpsetToIgnorableIndicesMap>& GetAllowedRecompu
{1, {}},
},
},
{
utils::GetFullQualifiedOpName("SimplifiedLayerNormalization", kOnnxDomain),
{
// The official ONNX opset 1 does not define SimplifiedLayerNormalization,
// but our contrib op registers SimplifiedLayerNormalization under opset 1 in the ONNX domain.
{1, {}},
},
},
{
utils::GetFullQualifiedOpName("SkipLayerNormalization", kMSDomain),
{
{1, {}},
},
},
{
utils::GetFullQualifiedOpName("SkipSimplifiedLayerNormalization", kMSDomain),
{
{1, {}},
},
},
{
utils::GetFullQualifiedOpName("Softmax", kOnnxDomain),
{
@@ -691,7 +711,7 @@ std::unique_ptr<NodeRecomputePlan> CheckNodeForRecompute(const GraphViewer& grap
node_index_to_its_order_in_topological_sort_map,
const InlinedHashMap<const Node*, InlinedVector<size_t>>&
candidate_output_args_map,
const InlinedHashSet<const Node*>& layer_boundary_ln_nodes,
const InlinedVector<const Node*>& layer_boundary_ln_nodes,
const logging::Logger& logger,
bool compromise_stashed_activation,
bool& can_compromise_stashed_activation) {
@@ -709,13 +729,14 @@ std::unique_ptr<NodeRecomputePlan> CheckNodeForRecompute(const GraphViewer& grap
auto output_name = node.OutputDefs()[output_index]->Name();
auto consumers = graph_viewer.GetConsumerNodes(output_name);
for (auto& consumer : consumers) {
if (layer_boundary_ln_nodes.find(consumer) != layer_boundary_ln_nodes.end()) {
if (std::find(layer_boundary_ln_nodes.begin(), layer_boundary_ln_nodes.end(), consumer) !=
layer_boundary_ln_nodes.end()) {
int dest_in_index = optimizer_utils::IndexOfNodeInput(*consumer, *node.OutputDefs()[output_index]);
if (dest_in_index == 0) {
LOGS(logger, INFO) << "Node " << node.Name() << "(" << node.OpType()
<< ") is a Attention+MLP layer boundary node, "
<< "its stashed activation outputs are used by LayerNormalization's inputs, "
<< "we don't need to recompute it.";
MO_LOG_DEBUG_INFO(logger, "Node " + node.Name() + "(" + node.OpType() +
") is a Attention+MLP layer boundary node, " +
"its stashed activation outputs are used by LayerNormalization's inputs, " +
"we don't need to recompute it.");
return nullptr;
}
}
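In the CheckNodeForRecompute hunk above, a candidate's recompute plan is skipped when one of its stashed activation outputs is consumed at input index 0 of a layer-boundary LayerNorm. The following is a simplified sketch of that check with hypothetical stand-in types, not ORT's Node/GraphViewer API; it only illustrates the membership test over the boundary-node vector and the input-index-0 condition.

#include <algorithm>
#include <string>
#include <vector>

struct FakeNode {
  std::string name;
  std::vector<std::string> input_names;  // input names, by position
};

// Returns true when `output_name` feeds input index 0 of any layer-boundary node.
bool FeedsBoundaryLayerNormAtIndex0(const std::string& output_name,
                                    const std::vector<const FakeNode*>& consumers,
                                    const std::vector<const FakeNode*>& boundary_nodes) {
  for (const FakeNode* consumer : consumers) {
    const bool is_boundary =
        std::find(boundary_nodes.begin(), boundary_nodes.end(), consumer) != boundary_nodes.end();
    if (!is_boundary) {
      continue;
    }
    if (!consumer->input_names.empty() && consumer->input_names[0] == output_name) {
      return true;  // the stashed activation is the LayerNorm's data input; skip recompute
    }
  }
  return false;
}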
@@ -164,7 +164,7 @@ std::unique_ptr<NodeRecomputePlan> CheckNodeForRecompute(const GraphViewer& grap
node_index_to_its_order_in_topological_sort_map,
const InlinedHashMap<const Node*, InlinedVector<size_t>>&
candidate_output_args_map,
const InlinedHashSet<const Node*>& layer_boundary_ln_nodes,
const InlinedVector<const Node*>& layer_boundary_ln_nodes,
const logging::Logger& logger,
bool compromise_stashed_activation,
bool& can_compromise_stashed_activation);
@@ -2,6 +2,7 @@
// Licensed under the MIT License.

#include <charconv>
#include <tuple>
#include <vector>
#include <utility>

@@ -16,43 +17,139 @@

namespace onnxruntime::optimizer::memory_optimizer {

namespace {

bool IsLayerNormNode(const Node& node) {
static const std::set<std::string> layer_norm_ops = {
"LayerNormalization",
"SkipLayerNormalization",
"SimplifiedLayerNormalization",
"SkipSimplifiedLayerNormalization",
};
return layer_norm_ops.find(node.OpType()) != layer_norm_ops.end();
}

bool IsSoftmaxNode(const Node& node) {
static const std::set<std::string> softmax_ops = {
"Softmax",
"BiasSoftmax",
};
return softmax_ops.find(node.OpType()) != softmax_ops.end();
}

std::tuple<bool, const Node*, const Node*> IsResidualNodeArg(const GraphViewer& graph_viewer, const NodeArg* node_arg) {
auto consumers = graph_viewer.GetConsumerNodes(node_arg->Name());
if (2 > consumers.size()) {
return std::make_tuple(false, nullptr, nullptr);
}

// Find the Add node from the consumer list.
const Node* add_node = nullptr;
const Node* other_node = nullptr;
for (const auto* consumer : consumers) {
if (consumer->OpType() == "Add") {
add_node = consumer;
} else if (IsLayerNormNode(*consumer)) {
other_node = consumer;
}
}

return std::make_tuple(add_node != nullptr && other_node != nullptr, add_node, other_node);
}
} // namespace

/*
A classical transformer layer includes: 1) input layer norm, 2) attention, 3) residual add
(input layer norm input + attention output), 4) post-attention layer norm + feedforward, and 5) residual add
(post-attention layer norm input + feedforward output).

The pattern graph looks like below for each transformer layer (taking the example of MistralDecoderLayer):
|
Embedding
|
----------------------|
| |
| |
| SimplifiedLayerNormalization (layer boundary node)
| |
| |
| MistralAttention
| |
| |
|____________________Add
|
----------------------|
| |
| |
| SimplifiedLayerNormalization
| |
| |
| MultipleLayerPerception
| |
| |
|____________________Add
|
(new layer)
----------------------|
| |
| SimplifiedLayerNormalization
....
*/
void FindLayerBoundaryLayerNormNodes(
const GraphViewer& graph_viewer,
const logging::Logger&,
const logging::Logger& logger,
const InlinedHashMap<NodeIndex, ptrdiff_t>&
node_index_to_its_order_in_topological_sort_map,
const ptrdiff_t& yield_op_order_in_topological_sort,
InlinedHashSet<const Node*>& layer_boundary_ln_nodes) {
InlinedVector<const Node*>& layer_boundary_ln_nodes) {
// Loop over all nodes to find LayerNormalization nodes.
// For each LayerNormalization node, keep checking its output nodes
// until we find a node that is Softmax, BiasSoftmax, or another LayerNormalization.
// If the found node is Softmax or BiasSoftmax, treat the LayerNormalization node as ATTENTION.
// If the found node is another LayerNormalization, treat the LayerNormalization node as MLP.
const InlinedHashSet<std::string_view> softmax_ops{"Softmax", "BiasSoftmax"};
const InlinedHashSet<std::string_view> layernorm_ops{"LayerNormalization", "SkipLayerNormalization"};

layer_boundary_ln_nodes.clear();

const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED);
for (auto node_index : node_topology_list) {
auto& node = *graph_viewer.GetNode(node_index);

if (layernorm_ops.find(node.OpType()) == layernorm_ops.end()) {
if (!IsLayerNormNode(node)) {
continue;
}
const NodeArg* input_arg = node.InputDefs()[0];

// Check whether input_arg is a residual connection, i.e. it feeds both an Add node and a LayerNorm node.
auto [is_residual_node_arg, add_node, other_node] = IsResidualNodeArg(graph_viewer, input_arg);
if (!is_residual_node_arg) {
MO_LOG_DEBUG_INFO(logger, "Not a residual node arg " + input_arg->Name());
continue;
}

// At this point, there should not be any recompute node, so we don't need to check whether the node exists in
// node_index_to_its_order_in_topological_sort_map.
ptrdiff_t attention_residual_add_node_order =
node_index_to_its_order_in_topological_sort_map.at(add_node->Index());
ptrdiff_t attention_residual_ln_node_order =
node_index_to_its_order_in_topological_sort_map.at(other_node->Index());
if (attention_residual_add_node_order >= yield_op_order_in_topological_sort ||
attention_residual_ln_node_order >= yield_op_order_in_topological_sort) {
MO_LOG_DEBUG_INFO(logger, "Not a valid residual node arg " + input_arg->Name());
continue;
}

// Search all forward nodes that are before `add_node` in topological order, until we find a Softmax or
// nodes_to_check becomes empty.
std::deque<const Node*> nodes_to_check;
std::set<const Node*> visited_nodes;
for (auto node_it = node.OutputNodesBegin(); node_it != node.OutputNodesEnd(); ++node_it) {
// Ignore those nodes after YieldOp.
if (node_index_to_its_order_in_topological_sort_map.at(node_it->Index()) < yield_op_order_in_topological_sort) {
auto order = node_index_to_its_order_in_topological_sort_map.at(node_it->Index());
if (order < yield_op_order_in_topological_sort && order < attention_residual_add_node_order) {
nodes_to_check.push_back(&(*node_it));
}
}

bool unexpected_failure = false;
bool found_softmax = false;
bool found_layernorm = false;
ptrdiff_t next_layernorm_execution_oder = -1;
while (!nodes_to_check.empty()) {
const Node* next_node = nodes_to_check.front();
nodes_to_check.pop_front();
@@ -62,41 +159,21 @@ void FindLayerBoundaryLayerNormNodes(
}

visited_nodes.insert(next_node);
if (softmax_ops.find(next_node->OpType()) != softmax_ops.end()) {
found_softmax = true;
} else if (layernorm_ops.find(next_node->OpType()) != layernorm_ops.end()) {
if (found_layernorm) {
// If we find another LayerNormalization node, report a warning and abort layer boundary detection.
unexpected_failure = true;
break;
}
found_layernorm = true; // don't trace further
next_layernorm_execution_oder = node_index_to_its_order_in_topological_sort_map.at(next_node->Index());
continue;
} else {
for (auto node_it = next_node->OutputNodesBegin(); node_it != next_node->OutputNodesEnd(); ++node_it) {
// Stop if the node is after next Layernorm node in execution order.
if (found_layernorm &&
node_index_to_its_order_in_topological_sort_map.at(node_it->Index()) >= next_layernorm_execution_oder) {
continue;
}
if (IsSoftmaxNode(*next_node)) {
MO_LOG_DEBUG_INFO(logger, "Found layer boundary node " + node.Name() + " with its input arg: " +
input_arg->Name());
layer_boundary_ln_nodes.push_back(&node);
break;
}

for (auto node_it = next_node->OutputNodesBegin(); node_it != next_node->OutputNodesEnd(); ++node_it) {
// Only traverse nodes that execute before the residual Add node and before YieldOp.
auto order = node_index_to_its_order_in_topological_sort_map.at(node_it->Index());
if (order < yield_op_order_in_topological_sort && order < attention_residual_add_node_order) {
nodes_to_check.push_back(&(*node_it));
}
}
}

if (unexpected_failure) {
layer_boundary_ln_nodes.clear();
break;
}

if (found_softmax) {
layer_boundary_ln_nodes.insert(&node);
} else if (!found_layernorm) {
// If no Softmax and no other LayerNormalization is found, this should be the last LayerNormalization node,
// so we also consider it a boundary node.
layer_boundary_ln_nodes.insert(&node);
}
}
}

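The rewritten FindLayerBoundaryLayerNormNodes above first confirms that a LayerNorm's first input is a residual NodeArg (consumed by both an Add and a LayerNorm), then walks forward through consumers that execute before the residual Add and before YieldOp, marking the LayerNorm as a layer boundary once a Softmax or BiasSoftmax is reached. Below is a condensed sketch of that bounded forward search using hypothetical minimal graph types rather than ORT's GraphViewer/Node; it illustrates the algorithm, not the PR's exact implementation.

#include <cstddef>
#include <deque>
#include <set>
#include <string>
#include <vector>

struct MiniNode {
  std::string op_type;
  std::ptrdiff_t topo_order = 0;           // position in topological order
  std::vector<const MiniNode*> consumers;  // downstream consumer nodes
};

bool IsSoftmax(const MiniNode& n) {
  return n.op_type == "Softmax" || n.op_type == "BiasSoftmax";
}

// `ln` is a candidate boundary LayerNorm whose first input was already confirmed to be a
// residual NodeArg; `residual_add_order` is the topological order of the residual Add node,
// and `yield_order` that of YieldOp (the forward/backward boundary).
bool IsLayerBoundary(const MiniNode& ln, std::ptrdiff_t residual_add_order, std::ptrdiff_t yield_order) {
  std::deque<const MiniNode*> nodes_to_check;
  std::set<const MiniNode*> visited;
  for (const MiniNode* out : ln.consumers) {
    if (out->topo_order < yield_order && out->topo_order < residual_add_order) {
      nodes_to_check.push_back(out);
    }
  }
  while (!nodes_to_check.empty()) {
    const MiniNode* next = nodes_to_check.front();
    nodes_to_check.pop_front();
    if (!visited.insert(next).second) continue;  // already visited
    if (IsSoftmax(*next)) return true;           // attention found: ln is a layer boundary
    for (const MiniNode* out : next->consumers) {
      if (out->topo_order < yield_order && out->topo_order < residual_add_order) {
        nodes_to_check.push_back(out);
      }
    }
  }
  return false;
}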
@@ -23,6 +23,6 @@ void FindLayerBoundaryLayerNormNodes(const GraphViewer& graph_viewer,
const InlinedHashMap<NodeIndex, ptrdiff_t>&
node_index_to_its_order_in_topological_sort_map,
const ptrdiff_t& yield_op_order_in_topological_sort,
InlinedHashSet<const Node*>& layer_boundary_ln_nodes);
InlinedVector<const Node*>& layer_boundary_ln_nodes);

} // namespace onnxruntime::optimizer::memory_optimizer
41 changes: 41 additions & 0 deletions orttraining/orttraining/test/optimizer/memory_optimizer_test.cc
@@ -29,6 +29,7 @@
#include "orttraining/core/optimizer/memory_optimizer/common.h"
#include "orttraining/core/optimizer/memory_optimizer/memory_optimizer.h"
#include "orttraining/core/optimizer/memory_optimizer/memory_insight.h"
#include "orttraining/core/optimizer/memory_optimizer/transformer_specific.h"

using namespace std;
using namespace ONNX_NAMESPACE;
@@ -312,5 +313,45 @@ TEST(MemoryOptimizerTests, TransformerPerLayerRecompute) {
}
}

TEST(MemoryOptimizerTests, TransformerLayerDetectionTest) {
const logging::Logger* logger = &logging::LoggingManager::DefaultLogger();
auto model_uri = MODEL_FOLDER "3layer_bloom_optimized_training.onnx";
std::shared_ptr<Model> model;
ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger));
Graph& graph = model->MainGraph();
GraphViewer graph_viewer(graph);

InlinedHashMap<NodeIndex, ptrdiff_t> node_index_to_its_order_in_topological_sort_map;
const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED);

// Find the boundary op between the forward and backward passes; currently this is limited to YieldOp.
ptrdiff_t yield_op_order_in_topological_sort = -1;
for (size_t i = 0; i < node_ids.size(); ++i) {
const Node* p_node = graph_viewer.GetNode(node_ids[i]);
if (p_node == nullptr) { /* skip removed nodes*/
continue;
}

if (p_node->OpType() == "YieldOp") {
// Fail if there is more than one YieldOp in the graph.
ASSERT_EQ(yield_op_order_in_topological_sort, -1);
yield_op_order_in_topological_sort = static_cast<ptrdiff_t>(i);
}

node_index_to_its_order_in_topological_sort_map[p_node->Index()] = static_cast<ptrdiff_t>(i);
}

InlinedVector<const Node*> layer_boundary_ln_node;
optimizer::memory_optimizer::FindLayerBoundaryLayerNormNodes(graph_viewer, *logger,
node_index_to_its_order_in_topological_sort_map,
yield_op_order_in_topological_sort,
layer_boundary_ln_node);

ASSERT_EQ(layer_boundary_ln_node.size(), 3);
ASSERT_EQ(layer_boundary_ln_node[0]->Name(), "LayerNormalization_token_0");
ASSERT_EQ(layer_boundary_ln_node[1]->Name(), "LayerNormalization_token_6");
ASSERT_EQ(layer_boundary_ln_node[2]->Name(), "LayerNormalization_token_12");
}

} // namespace test
} // namespace onnxruntime