From ae4310e54b0a30cc712933a9766c4de27e8088cc Mon Sep 17 00:00:00 2001 From: pengwa Date: Wed, 6 Mar 2024 21:54:16 +0800 Subject: [PATCH] Adapt memory optimizer to fit PHI2 (#19757) ### Adapt memory optimizer to fit PHI2 Few improvements and bug fixes: 1. Fix bug related to transformer layer detection. 2. Use default reversed typo order to create recompute node, to avoid the leaf nodes are handled too late, then having lowest priority for execution. 3. Add early stop when activation's element count is constant and total element count < 1M. This can avoid overhead to search subgraphs. Using export ORTMODULE_MEMORY_OPT_LEVEL=1 to enable layerwise recompute, on given recipe, memory consumption dropped from ~22GB to ~13GB . --- .../memory_optimizer/memory_insight.cc | 3 +- .../memory_optimizer/memory_optimizer.cc | 37 +++++++++++++++- .../memory_optimizer/recompute_analysis.cc | 18 +++++++- .../memory_optimizer/transformer_specific.cc | 42 +++++++++++++++++-- .../memory_optimizer/transformer_specific.h | 3 ++ 5 files changed, 95 insertions(+), 8 deletions(-) diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc index 08c402bf669c8..54c49db0597c7 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc @@ -258,7 +258,8 @@ Status FindORTModuleMemoryOpportunity(const GraphViewer& graph_viewer, logger)); InlinedHashSet layer_boundary_ln_nodes; - FindLayerBoundaryLayerNormNodes(graph_viewer, logger, layer_boundary_ln_nodes); + FindLayerBoundaryLayerNormNodes(graph_viewer, logger, node_index_to_its_order_in_topological_sort_map, + yield_op_order_in_topological_sort, layer_boundary_ln_nodes); // The first pass - find the candidate subgraphs. for (int i = static_cast(node_ids.size()) - 1; i >= 0; --i) { diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc index 525e3b4b8de35..40fa2fc5cc737 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc @@ -190,11 +190,44 @@ Status MemoryOptimizer::ApplyImpl(Graph& graph, bool& modified, int /*graph_leve .IsOK()); // The second pass - apply the transformation. - // Iterate through the nodes in reversed topological order and find the subgraph that can be alleviated. + // Note 1: Iterate through the nodes in reversed topological order and find the subgraph that can be alleviated. // The reason we do reversed topological order is that we want the later layers' recompute nodes can be appended // earlier than the earlier layers, in this way, the execution order of later layers will be in front of the earlier // layers. - const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED); + // + // Note 2: Here we use default typo order (which tries to BFS from the outputs, + // so the nearest node to graph output will be visited last). So in reversed default typo order, + // the neareast node to graph output will be visited first. + // Imagine there is a such subgraph + // input1 input2 input3 + // \ | / + // multiple layers + // | + // node M + // labels-------|----- + // \ | | + // node1 | | + // \ | | + // node2 / | + // \ / | + // node loss / + // | / + // YieldOp node1_recompute + // | / + // \ node2 recompute + // \ / + // node loss_grad + // | + // critical grad path + // + // In PriorityBased order, node1 will be visited first, so it's recompute node node1_recompute will be added + // at last because we do this following reversed topological order. Then node1_recompute node will have lowest + // priority to execute, as a result, if at this time, the queue to visit contains only recompute nodes, then + // node1_recompute will be run at last, affecting the backward critical path, which is not what we want. + // Current workaround is to use default order, which will execute node1_recompute earlier than other recompute nodes + // in this case. + + const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::DEFAULT); for (int i = static_cast(node_ids.size()) - 1; i >= 0; --i) { Node* p_node = graph.GetNode(node_ids[i]); if (p_node == nullptr) { diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc index 12c83591c0036..76b3325f36116 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc @@ -19,7 +19,7 @@ namespace onnxruntime::optimizer::memory_optimizer { namespace { -constexpr int32_t MAXIMUM_RECOMPUTE_NODE_COUNT = 15; +constexpr int32_t MAXIMUM_RECOMPUTE_NODE_COUNT = 50; static size_t GetElementSize(const ONNX_NAMESPACE::DataType& tensor_type) { const ONNX_NAMESPACE::TypeProto& type_proto = ONNX_NAMESPACE::Utils::DataTypeUtils::ToTypeProto(tensor_type); @@ -291,6 +291,22 @@ Status SelectRecomputeSubgraph(const Node& entry_node, const auto current_node_input_index = input_edge.GetDstArgIndex(); if (std::find(input_arg_indices.begin(), input_arg_indices.end(), current_node_input_index) != input_arg_indices.end()) { + // If the tensor size is constant and very small (Now < 1M), we stop adding the input edge into queue. + auto output_shape = parent_node.OutputDefs()[parent_node_output_index]->Shape(); + if (output_shape) { + bool all_constant_dim = true; + int64_t num_elem = 1; + for (int k = 0, dim_size = output_shape->dim_size(); k < dim_size; ++k) { + if (!output_shape->dim(k).has_dim_value()) { + all_constant_dim = false; + num_elem *= output_shape->dim(k).dim_value(); + } + } + if (all_constant_dim && num_elem < 1 * 1024 * 1024) { + // Skip this input index. + continue; + } + } NodeOutputPort next_p = std::make_pair(&parent_node, parent_node_output_index); MO_LOG_DEBUG_INFO(logger, "Node " + parent_node.Name() + "(" + parent_node.OpType() + ")'s " + diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc index 04f2679ac774f..c88a0f05d36b8 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc @@ -19,6 +19,9 @@ namespace onnxruntime::optimizer::memory_optimizer { void FindLayerBoundaryLayerNormNodes( const GraphViewer& graph_viewer, const logging::Logger&, + const InlinedHashMap& + node_index_to_its_order_in_topological_sort_map, + const ptrdiff_t& yield_op_order_in_topological_sort, InlinedHashSet& layer_boundary_ln_nodes) { // Loop all nodes to find LayerNormalization nodes. // For each LayerNormalization node, keep checking its output nodes, @@ -40,9 +43,16 @@ void FindLayerBoundaryLayerNormNodes( std::deque nodes_to_check; std::set visited_nodes; for (auto node_it = node.OutputNodesBegin(); node_it != node.OutputNodesEnd(); ++node_it) { - nodes_to_check.push_back(&(*node_it)); + // Ignore those nodes after YieldOp. + if (node_index_to_its_order_in_topological_sort_map.at(node_it->Index()) < yield_op_order_in_topological_sort) { + nodes_to_check.push_back(&(*node_it)); + } } + bool unexpected_failure = false; + bool found_softmax = false; + bool found_layernorm = false; + ptrdiff_t next_layernorm_execution_oder = -1; while (!nodes_to_check.empty()) { const Node* next_node = nodes_to_check.front(); nodes_to_check.pop_front(); @@ -53,16 +63,40 @@ void FindLayerBoundaryLayerNormNodes( visited_nodes.insert(next_node); if (softmax_ops.find(next_node->OpType()) != softmax_ops.end()) { - layer_boundary_ln_nodes.insert(&node); - break; + found_softmax = true; } else if (layernorm_ops.find(next_node->OpType()) != layernorm_ops.end()) { - break; + if (found_layernorm) { + // If we found another LayerNormalization node, we would report as warning, and do nothing for layer boundary detection. + unexpected_failure = true; + break; + } + found_layernorm = true; // don't trace further + next_layernorm_execution_oder = node_index_to_its_order_in_topological_sort_map.at(next_node->Index()); + continue; } else { for (auto node_it = next_node->OutputNodesBegin(); node_it != next_node->OutputNodesEnd(); ++node_it) { + // Stop if the node is after next Layernorm node in execution order. + if (found_layernorm && + node_index_to_its_order_in_topological_sort_map.at(node_it->Index()) >= next_layernorm_execution_oder) { + continue; + } nodes_to_check.push_back(&(*node_it)); } } } + + if (unexpected_failure) { + layer_boundary_ln_nodes.clear(); + break; + } + + if (found_softmax) { + layer_boundary_ln_nodes.insert(&node); + } else if (!found_layernorm) { + // If no Softmax found, and no other LayerNormalization found, this should be the last LayerNormalization node, + // we also consider it as boundary node. + layer_boundary_ln_nodes.insert(&node); + } } } diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.h b/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.h index f2cfd640b0840..b58d822124f43 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.h +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.h @@ -20,6 +20,9 @@ namespace onnxruntime::optimizer::memory_optimizer { void FindLayerBoundaryLayerNormNodes(const GraphViewer& graph_viewer, const logging::Logger& logger, + const InlinedHashMap& + node_index_to_its_order_in_topological_sort_map, + const ptrdiff_t& yield_op_order_in_topological_sort, InlinedHashSet& layer_boundary_ln_nodes); } // namespace onnxruntime::optimizer::memory_optimizer