Skip to content

Commit

Permalink
Adapt memory optimizer to fit PHI2 (#19757)
Browse files Browse the repository at this point in the history
### Adapt memory optimizer to fit PHI2

Few improvements and bug fixes:
1. Fix bug related to transformer layer detection. 
2. Use default reversed typo order to create recompute node, to avoid
the leaf nodes are handled too late, then having lowest priority for
execution.
3. Add early stop when activation's element count is constant and total
element count < 1M. This can avoid overhead to search subgraphs.


Using export ORTMODULE_MEMORY_OPT_LEVEL=1 to enable layerwise recompute,
on given recipe, memory consumption dropped from ~22GB to ~13GB .
  • Loading branch information
pengwa authored Mar 6, 2024
1 parent e93a860 commit d9bf856
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,8 @@ Status FindORTModuleMemoryOpportunity(const GraphViewer& graph_viewer,
logger));

InlinedHashSet<const Node*> layer_boundary_ln_nodes;
FindLayerBoundaryLayerNormNodes(graph_viewer, logger, layer_boundary_ln_nodes);
FindLayerBoundaryLayerNormNodes(graph_viewer, logger, node_index_to_its_order_in_topological_sort_map,
yield_op_order_in_topological_sort, layer_boundary_ln_nodes);

// The first pass - find the candidate subgraphs.
for (int i = static_cast<int>(node_ids.size()) - 1; i >= 0; --i) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,11 +190,44 @@ Status MemoryOptimizer::ApplyImpl(Graph& graph, bool& modified, int /*graph_leve
.IsOK());

// The second pass - apply the transformation.
// Iterate through the nodes in reversed topological order and find the subgraph that can be alleviated.
// Note 1: Iterate through the nodes in reversed topological order and find the subgraph that can be alleviated.
// The reason we do reversed topological order is that we want the later layers' recompute nodes can be appended
// earlier than the earlier layers, in this way, the execution order of later layers will be in front of the earlier
// layers.
const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED);
//
// Note 2: Here we use default typo order (which tries to BFS from the outputs,
// so the nearest node to graph output will be visited last). So in reversed default typo order,
// the neareast node to graph output will be visited first.
// Imagine there is a such subgraph
// input1 input2 input3
// \ | /
// multiple layers
// |
// node M
// labels-------|-----
// \ | |
// node1 | |
// \ | |
// node2 / |
// \ / |
// node loss /
// | /
// YieldOp node1_recompute
// | /
// \ node2 recompute
// \ /
// node loss_grad
// |
// critical grad path
//
// In PriorityBased order, node1 will be visited first, so it's recompute node node1_recompute will be added
// at last because we do this following reversed topological order. Then node1_recompute node will have lowest
// priority to execute, as a result, if at this time, the queue to visit contains only recompute nodes, then
// node1_recompute will be run at last, affecting the backward critical path, which is not what we want.
// Current workaround is to use default order, which will execute node1_recompute earlier than other recompute nodes
// in this case.

const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::DEFAULT);
for (int i = static_cast<int>(node_ids.size()) - 1; i >= 0; --i) {
Node* p_node = graph.GetNode(node_ids[i]);
if (p_node == nullptr) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ namespace onnxruntime::optimizer::memory_optimizer {

namespace {

constexpr int32_t MAXIMUM_RECOMPUTE_NODE_COUNT = 15;
constexpr int32_t MAXIMUM_RECOMPUTE_NODE_COUNT = 50;

static size_t GetElementSize(const ONNX_NAMESPACE::DataType& tensor_type) {
const ONNX_NAMESPACE::TypeProto& type_proto = ONNX_NAMESPACE::Utils::DataTypeUtils::ToTypeProto(tensor_type);
Expand Down Expand Up @@ -291,6 +291,22 @@ Status SelectRecomputeSubgraph(const Node& entry_node,
const auto current_node_input_index = input_edge.GetDstArgIndex();
if (std::find(input_arg_indices.begin(), input_arg_indices.end(), current_node_input_index) !=
input_arg_indices.end()) {
// If the tensor size is constant and very small (Now < 1M), we stop adding the input edge into queue.
auto output_shape = parent_node.OutputDefs()[parent_node_output_index]->Shape();
if (output_shape) {
bool all_constant_dim = true;
int64_t num_elem = 1;
for (int k = 0, dim_size = output_shape->dim_size(); k < dim_size; ++k) {
if (!output_shape->dim(k).has_dim_value()) {
all_constant_dim = false;
num_elem *= output_shape->dim(k).dim_value();
}
}
if (all_constant_dim && num_elem < 1 * 1024 * 1024) {
// Skip this input index.
continue;
}
}
NodeOutputPort next_p = std::make_pair(&parent_node, parent_node_output_index);

MO_LOG_DEBUG_INFO(logger, "Node " + parent_node.Name() + "(" + parent_node.OpType() + ")'s " +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ namespace onnxruntime::optimizer::memory_optimizer {
void FindLayerBoundaryLayerNormNodes(
const GraphViewer& graph_viewer,
const logging::Logger&,
const InlinedHashMap<NodeIndex, ptrdiff_t>&
node_index_to_its_order_in_topological_sort_map,
const ptrdiff_t& yield_op_order_in_topological_sort,
InlinedHashSet<const Node*>& layer_boundary_ln_nodes) {
// Loop all nodes to find LayerNormalization nodes.
// For each LayerNormalization node, keep checking its output nodes,
Expand All @@ -40,9 +43,16 @@ void FindLayerBoundaryLayerNormNodes(
std::deque<const Node*> nodes_to_check;
std::set<const Node*> visited_nodes;
for (auto node_it = node.OutputNodesBegin(); node_it != node.OutputNodesEnd(); ++node_it) {
nodes_to_check.push_back(&(*node_it));
// Ignore those nodes after YieldOp.
if (node_index_to_its_order_in_topological_sort_map.at(node_it->Index()) < yield_op_order_in_topological_sort) {
nodes_to_check.push_back(&(*node_it));
}
}

bool unexpected_failure = false;
bool found_softmax = false;
bool found_layernorm = false;
ptrdiff_t next_layernorm_execution_oder = -1;
while (!nodes_to_check.empty()) {
const Node* next_node = nodes_to_check.front();
nodes_to_check.pop_front();
Expand All @@ -53,16 +63,40 @@ void FindLayerBoundaryLayerNormNodes(

visited_nodes.insert(next_node);
if (softmax_ops.find(next_node->OpType()) != softmax_ops.end()) {
layer_boundary_ln_nodes.insert(&node);
break;
found_softmax = true;
} else if (layernorm_ops.find(next_node->OpType()) != layernorm_ops.end()) {
break;
if (found_layernorm) {
// If we found another LayerNormalization node, we would report as warning, and do nothing for layer boundary detection.
unexpected_failure = true;
break;
}
found_layernorm = true; // don't trace further
next_layernorm_execution_oder = node_index_to_its_order_in_topological_sort_map.at(next_node->Index());
continue;
} else {
for (auto node_it = next_node->OutputNodesBegin(); node_it != next_node->OutputNodesEnd(); ++node_it) {
// Stop if the node is after next Layernorm node in execution order.
if (found_layernorm &&
node_index_to_its_order_in_topological_sort_map.at(node_it->Index()) >= next_layernorm_execution_oder) {
continue;
}
nodes_to_check.push_back(&(*node_it));
}
}
}

if (unexpected_failure) {
layer_boundary_ln_nodes.clear();
break;
}

if (found_softmax) {
layer_boundary_ln_nodes.insert(&node);
} else if (!found_layernorm) {
// If no Softmax found, and no other LayerNormalization found, this should be the last LayerNormalization node,
// we also consider it as boundary node.
layer_boundary_ln_nodes.insert(&node);
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ namespace onnxruntime::optimizer::memory_optimizer {

void FindLayerBoundaryLayerNormNodes(const GraphViewer& graph_viewer,
const logging::Logger& logger,
const InlinedHashMap<NodeIndex, ptrdiff_t>&
node_index_to_its_order_in_topological_sort_map,
const ptrdiff_t& yield_op_order_in_topological_sort,
InlinedHashSet<const Node*>& layer_boundary_ln_nodes);

} // namespace onnxruntime::optimizer::memory_optimizer

0 comments on commit d9bf856

Please sign in to comment.