-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
161 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
68 changes: 68 additions & 0 deletions
68
orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#include <charconv> | ||
Check warning on line 4 in orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc GitHub Actions / cpplint[cpplint] orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc#L4
Raw output
|
||
#include <vector> | ||
#include <utility> | ||
|
||
#include "orttraining/core/optimizer/memory_optimizer/common.h" | ||
#include "core/graph/graph_utils.h" | ||
#include "core/optimizer/utils.h" | ||
#include "core/graph/graph_viewer.h" | ||
#include "core/framework/tensorprotoutils.h" | ||
|
||
#include "core/common/string_utils.h" | ||
|
||
namespace onnxruntime::optimizer::memory_optimizer { | ||
|
||
void FindLayerBoundaryLayerNodeNodes( | ||
const GraphViewer& graph_viewer, | ||
const logging::Logger&, | ||
InlinedHashSet<const Node*>& layer_boundary_ln_nodes) { | ||
// Loop all nodes to find LayerNormalization nodes. | ||
// For each LayerNormalization node, keep checking its output nodes, | ||
// until find a node that is Softmax or BiasSoftmax or another LayerNormalization. | ||
// If the found node is Softmax or BiasSoftmax, the LayerNormalization node as ATTENTION. | ||
// If the found node is another LayerNormalization, the LayerNormalization node as MLP. | ||
const InlinedHashSet<std::string_view> softmax_ops{"Softmax", "BiasSoftmax"}; | ||
const InlinedHashSet<std::string_view> layernorm_ops{"LayerNormalization", "SkipLayerNormalization"}; | ||
|
||
layer_boundary_ln_nodes.clear(); | ||
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); | ||
for (auto node_index : node_topology_list) { | ||
auto& node = *graph_viewer.GetNode(node_index); | ||
|
||
if (layernorm_ops.find(node.OpType()) == layernorm_ops.end()) { | ||
continue; | ||
} | ||
|
||
std::deque<const Node*> nodes_to_check; | ||
std::set<const Node*> visited_nodes; | ||
for (auto node_it = node.OutputNodesBegin(); node_it != node.OutputNodesEnd(); ++node_it) { | ||
nodes_to_check.push_back(&(*node_it)); | ||
} | ||
|
||
while (!nodes_to_check.empty()) { | ||
const Node* next_node = nodes_to_check.front(); | ||
nodes_to_check.pop_front(); | ||
|
||
if (visited_nodes.find(next_node) != visited_nodes.end()) { | ||
continue; | ||
} | ||
|
||
visited_nodes.insert(next_node); | ||
if (softmax_ops.find(next_node->OpType()) != softmax_ops.end()) { | ||
layer_boundary_ln_nodes.insert(&node); | ||
break; | ||
} else if (layernorm_ops.find(next_node->OpType()) != layernorm_ops.end()) { | ||
break; | ||
} else { | ||
for (auto node_it = next_node->OutputNodesBegin(); node_it != next_node->OutputNodesEnd(); ++node_it) { | ||
nodes_to_check.push_back(&(*node_it)); | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
||
} // namespace onnxruntime::optimizer::memory_optimizer |
25 changes: 25 additions & 0 deletions
25
orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#pragma once | ||
|
||
#include <memory> | ||
#include <string> | ||
#include <unordered_map> | ||
#include <utility> | ||
|
||
#include "core/common/common.h" | ||
#include "core/common/logging/logging.h" | ||
#include "core/common/inlined_containers_fwd.h" | ||
#include "core/graph/basic_types.h" | ||
#include "core/framework/data_types.h" | ||
#include "core/graph/graph_viewer.h" | ||
#include "orttraining/core/optimizer/memory_optimizer/common.h" | ||
|
||
namespace onnxruntime::optimizer::memory_optimizer { | ||
|
||
// Finds the LayerNormalization / SkipLayerNormalization nodes from which a breadth-first walk
// over downstream output nodes reaches a Softmax / BiasSoftmax node before it reaches another
// LayerNormalization / SkipLayerNormalization node.
//
// graph_viewer: graph whose nodes are scanned in topological order.
// logger: not used by the current implementation.
// layer_boundary_ln_nodes: output set; it is cleared first and then filled with the nodes found.
void FindLayerBoundaryLayerNodeNodes(const GraphViewer& graph_viewer, | ||
const logging::Logger& logger, | ||
InlinedHashSet<const Node*>& layer_boundary_ln_nodes); | ||
|
||
} // namespace onnxruntime::optimizer::memory_optimizer |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.