[Experimental] Add a path to fallback more nodes to CPUs #19875
base: main
Changes from 7 commits
039e489
0dac902
520332f
05c9a41
8f8c8fb
526b166
c087069
9f0ca0f
6db6a4d
0f9de47
onnxruntime/core/framework/fallback_cpu_capability.cc
@@ -2,8 +2,11 @@
// Licensed under the MIT License.

#include "core/framework/fallback_cpu_capability.h"
#include "core/framework/tensorprotoutils.h"
#include "core/common/inlined_containers.h"

#include <cstring>
#include <cstdlib>
#include <queue>

#include "onnx/defs/data_type_utils.h"
@@ -39,9 +42,122 @@
  }
}  // namespace

static InlinedHashSet<NodeIndex> GetShapeRelatedNodes(const onnxruntime::GraphViewer& viewer) {
  // Conceptually, this function traverses from shape-consuming nodes and
  // falls back all of their upstream nodes to CPU. Consider a graph such as
  //   shape = onnx::Concat(s0, s1)
  //   reshaped = onnx::Reshape(x, shape)
  // The traversal should stop when
  // 1. hitting Shape or Size nodes, graph inputs, or graph initializers;
  // 2. [TODO] hitting nodes with some large inputs or outputs. Before that,
  //    we need shape inference to determine the sizes of the inputs and outputs.
  //    Some graph transforms add nodes without shape information, so
  //    checking shapes would make the algorithm less stable for now.
  LOGS_DEFAULT(VERBOSE) << "Call GetShapeRelatedNodes to identify extra CPU nodes." << std::endl;

  const static InlinedHashMap<std::string_view, InlinedHashMap<int64_t, std::vector<size_t>>> shape_related_inputs_in_nodes = {
Review comment: Curious to hear your thoughts on this - https://github.com/microsoft/onnxruntime/pull/19769/files#r1515337499
      // 2nd input of Expand-13 is a shape-related input.
      {"Expand", {{13 /* since version */, {1} /* shape inputs' indices */}}},
      // 2nd input (indexed by 1) of Reshape-13, Reshape-14, Reshape-19, and Reshape-21 is a shape-related input.
      {"Reshape", {{13, {1}}, {14, {1}}, {19, {1}}, {21, {1}}}},
      // 2nd input of Unsqueeze-13 and Unsqueeze-21 is a shape-related input.
      {"Unsqueeze", {{13, {1}}, {21, {1}}}},
      // 1st input of ConstantOfShape is a shape-related input.
      {"ConstantOfShape", {{9, {0}}, {20, {0}}, {21, {0}}}},
      // 2nd to 5th inputs of Slice-13 are shape-related inputs.
      {"Slice", {{13, {1, 2, 3, 4}}}}};

  auto& graph = viewer.GetGraph();
  // Each shape-producing node produces a tensor consumed
  // as a shape, axis, size, or indices.
  // E.g.,
  //   shape = onnx::Concat(s0, s1)
  //   reshaped = onnx::Reshape(x, shape)
  // Here, the shape-producing node is Concat.
  InlinedHashSet<const Node*> shape_producing_nodes;
  // This loop collects all shape-producing nodes by finding
  // all nodes that produce tensors specified in shape_related_inputs_in_nodes.
  // E.g., in the example above, Concat is a shape-producing node because
  // Reshape has a shape-related input at index 1.
  for (auto& node : graph.Nodes()) {
    LOGS_DEFAULT(VERBOSE) << "Check if node " << node.Name() << " can be a sink of a shape sub-graph." << std::endl;
    auto op_type_it = shape_related_inputs_in_nodes.find(node.OpType());
    if (op_type_it == shape_related_inputs_in_nodes.end()) {
      // This node doesn't consume a tensor as a shape,
      // so we won't find any shape-producing node from it.
      continue;
    }
    auto op_type_version_it = op_type_it->second.find(node.SinceVersion());
    if (op_type_version_it == op_type_it->second.end()) {
      // This node doesn't consume a tensor as a shape in this opset version,
      // so we won't find any shape-producing node from it.
      continue;
    }

    // Shape-like inputs' indices in this node.
    // E.g., for Reshape it's [1], and for Slice it's [1, 2, 3, 4].
    auto& shape_input_indices = op_type_version_it->second;
    // Now, this `node` is a shape-consuming node as defined by shape_related_inputs_in_nodes.
    // Let's find the producers of the shape-like tensors consumed by this `node`.
    // Consider this graph:
    //   shape = onnx::Concat(s0, s1)
    //   reshaped = onnx::Reshape(x, shape)
    // The loop below
    // 1. checks all of `Reshape`'s inputs, `x` and `shape`,
    // 2. finds that `shape` is a shape-related variable since Reshape's 2nd input is a shape-related input,
    // 3. and then records the producer of `shape` (i.e., `Concat`).
    for (auto& input_index : shape_input_indices) {
      auto input = node.InputDefs().at(input_index);
      auto producer_node = graph.GetProducerNode(input->Name());
      if (producer_node != nullptr && producer_node->OpType() != "Shape" && producer_node->OpType() != "Size") {
        // Assume shape-computing sub-graphs begin with Shape or Size nodes, or graph inputs.
        // We should not fall back those nodes' upstream nodes to CPU; otherwise,
        // it may change
        //   GPU-tensor-x -> Mul -> GPU-tensor-y -> Shape -> CPU-tensor
        // to
        //   CPU-tensor-x -> Mul -> CPU-tensor -> Shape -> CPU-tensor
        // and slow down the computation.

        // After this for-loop, we will reversely traverse all nodes from every shape-producing node
        // found here until hitting Shape or Size nodes, graph inputs, or graph initializers.
        // All nodes on the traversal path will be forced to run on CPU.
        LOGS_DEFAULT(VERBOSE) << "Found a shape-producing node (i.e., a node that produces a tensor consumed as a shape-like input by other nodes): " << producer_node->Name() << std::endl;

        shape_producing_nodes.insert(producer_node);
      }
    }
  }

  InlinedHashSet<NodeIndex> shape_related_node_indices;
  for (auto& node : shape_producing_nodes) {
    LOGS_DEFAULT(VERBOSE) << "Begin the (topologically reverse) traversal from shape-producing node: " << node->Name() << std::endl;

    std::vector<const Node*> start_nodes = {node};

    auto to_stop = [](const Node* n1, const Node* n2) {
      const bool is_shape_or_size = n2->OpType() == "Shape" || n2->OpType() == "Size";
      if (is_shape_or_size) {
        LOGS_DEFAULT(VERBOSE) << "Skip the traversal from " << n1->Name() << " to " << n2->Name()
                              << " since " << n2->Name() << " is a Shape or Size node." << std::endl;
      }
      return is_shape_or_size;
    };

    // Reversely traverse all nodes from the shape-producing node
    // and force every node on the traversal path to run on CPU.
    // Stop the traversal when a Shape or Size node is found.
    graph.ReverseDFSFrom(
        start_nodes,
        [&shape_related_node_indices](const Node* n) {
          LOGS_DEFAULT(VERBOSE) << "Found an upstream node in the shape sub-graph (fall it back to CPU): " << n->Name() << std::endl;

          shape_related_node_indices.insert(n->Index());
        },
        nullptr,
        NodeCompare(),
        to_stop);
  }

  return shape_related_node_indices;
}
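
As a reading aid, below is a standalone sketch of the idea behind GetShapeRelatedNodes, reduced to a toy graph representation. It is not onnxruntime code: ToyNode, kShapeInputs, and ShapeRelatedNodes are hypothetical stand-ins for Node, the shape_related_inputs_in_nodes table, and the function above, and the two phases (collect shape-producing nodes, then reverse-traverse) are folded into one loop for brevity.

// shape_subgraph_sketch.cc -- illustrative only; mirrors the logic of
// GetShapeRelatedNodes on a minimal hand-rolled graph representation.
#include <cstddef>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

struct ToyNode {
  std::size_t index;
  std::string op_type;
  int since_version;
  std::vector<const ToyNode*> inputs;  // producer of each input (nullptr for graph inputs)
};

// Op type -> (since version -> indices of shape-like inputs),
// analogous to shape_related_inputs_in_nodes in this PR.
static const std::unordered_map<std::string,
                                std::unordered_map<int, std::vector<std::size_t>>>
    kShapeInputs = {
        {"Reshape", {{14, {1}}}},
        {"Expand", {{13, {1}}}},
};

// Returns the indices of nodes that belong to shape-computing sub-graphs:
// producers of shape-like inputs plus everything upstream of them,
// stopping at Shape/Size nodes and graph inputs.
std::unordered_set<std::size_t> ShapeRelatedNodes(const std::vector<ToyNode>& nodes) {
  std::unordered_set<std::size_t> result;
  for (const ToyNode& node : nodes) {
    auto op_it = kShapeInputs.find(node.op_type);
    if (op_it == kShapeInputs.end()) continue;
    auto ver_it = op_it->second.find(node.since_version);
    if (ver_it == op_it->second.end()) continue;
    for (std::size_t input_index : ver_it->second) {
      if (input_index >= node.inputs.size()) continue;
      const ToyNode* producer = node.inputs[input_index];
      if (producer == nullptr || producer->op_type == "Shape" || producer->op_type == "Size") continue;
      // Walk upstream from the shape-producing node; stop at Shape/Size
      // nodes and graph inputs so real compute stays on the original EP.
      std::vector<const ToyNode*> stack = {producer};
      while (!stack.empty()) {
        const ToyNode* n = stack.back();
        stack.pop_back();
        if (!result.insert(n->index).second) continue;  // already visited
        for (const ToyNode* upstream : n->inputs) {
          if (upstream != nullptr && upstream->op_type != "Shape" && upstream->op_type != "Size") {
            stack.push_back(upstream);
          }
        }
      }
    }
  }
  return result;
}

int main() {
  // Toy graph:  x -> Shape -> Concat -> (2nd input of) Reshape
  ToyNode shape{0, "Shape", 13, {}};
  ToyNode concat{1, "Concat", 13, {&shape}};
  ToyNode reshape{2, "Reshape", 14, {nullptr, &concat}};
  auto cpu = ShapeRelatedNodes({shape, concat, reshape});
  // cpu == {1}: Concat is pulled to CPU; the walk stops at the Shape node.
  return cpu.count(1) == 1 ? 0 : 1;
}

The traversal deliberately never crosses a Shape or Size node, which is what keeps the main compute path (e.g., the Mul in the comment above) on its original execution provider.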

std::unordered_set<NodeIndex> GetCpuPreferredNodes(const onnxruntime::GraphViewer& graph,
                                                   const IExecutionProvider::IKernelLookup& kernel_lookup,
-                                                  gsl::span<const NodeIndex> tentative_nodes) {
+                                                  gsl::span<const NodeIndex> tentative_nodes,
+                                                  const bool aggressive_cpu_fallback) {
  // automatic conversion from const std::vector&
  const auto& ordered_nodes = graph.GetNodesInTopologicalOrder();
  InlinedVector<size_t> node_id_to_order_map(graph.MaxNodeIndex());

@@ -83,7 +199,7 @@
      auto consumer_nodes = graph.GetConsumerNodes(node_arg.Name());
      for (auto& consumer_node : consumer_nodes) {
        candidates.push(consumer_node->Index());
-       LOGS_DEFAULT(INFO) << "Candidate for fallback CPU execution: " << consumer_node->Name();
+       LOGS_DEFAULT(VERBOSE) << "Candidate for fallback CPU execution: " << consumer_node->Name();
      }
    }
    return Status::OK();

@@ -159,9 +275,9 @@

    if (place_in_cpu) {
      cpu_nodes.insert(cur);
-     LOGS_DEFAULT(INFO) << "ORT optimization- Force fallback to CPU execution for node: " << node->Name()
-                        << " because the CPU execution path is deemed faster than overhead involved with execution on other EPs "
-                        << " capable of executing this node";
+     LOGS_DEFAULT(VERBOSE) << "ORT optimization- Force fallback to CPU execution for node: " << node->Name()
+                           << " because the CPU execution path is deemed faster than overhead involved with execution on other EPs "
+                           << " capable of executing this node";
      for (auto* output : node->OutputDefs()) {
        cpu_output_args.insert(output);
      }

@@ -171,6 +287,10 @@
    }
  }

  if (aggressive_cpu_fallback) {
    auto shape_related_node_indices = GetShapeRelatedNodes(graph);
    cpu_nodes.insert(shape_related_node_indices.begin(), shape_related_node_indices.end());
  }
  return cpu_nodes;
}
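
For orientation, here is a minimal, hypothetical call-site fragment showing the new parameter. graph_viewer, kernel_lookup, and tentative_nodes are assumed to already exist in the caller (for example, an execution provider deciding its capability); only the final argument is added by this change, and passing false preserves the previous behavior.

// Hypothetical caller-side fragment; everything except GetCpuPreferredNodes
// and its parameter order is an assumption, not code from this PR.
const std::unordered_set<NodeIndex> cpu_nodes =
    GetCpuPreferredNodes(graph_viewer, kernel_lookup, tentative_nodes,
                         /*aggressive_cpu_fallback=*/true);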
Review comment: Some more TODO documentation suggestions - #19769 (comment)