Commit

Add a new function to fall back more nodes to CPU.
Shape-related nodes don't necessarily start with `Shape` or `Size`.
In a dynamo-captured ONNX model, a shape sub-graph can start from a graph input.
A new transform is added to fall back all nodes that can be
reversely traversed from a `shape-like` variable. Some
`shape-like` variables are listed below; a small standalone sketch of the lookup follows the list.
- all inputs of Range
- 2nd input of Reshape
- 2nd input of Unsqueeze
- 1st input of ConstantOfShape
- 2nd-to-last inputs of Slice.
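
For illustration, here is a minimal standalone sketch (not part of the commit) of the lookup the new transform relies on: an op type plus its since-version maps to the indices of its shape-like inputs. The entries mirror the table added in this commit; the `main` driver and the queried op/version are only for demonstration.

// Standalone sketch: map op type -> (since version -> indices of shape-like inputs),
// then query which inputs of Reshape-14 are shape-like.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  std::unordered_map<std::string, std::unordered_map<int64_t, std::vector<std::size_t>>> shape_inputs = {
      {"Reshape", {{13, {1}}, {14, {1}}, {19, {1}}, {21, {1}}}},
      {"Unsqueeze", {{13, {1}}, {21, {1}}}},
      {"ConstantOfShape", {{9, {0}}, {20, {0}}, {21, {0}}}},
      {"Slice", {{13, {1, 2, 3, 4}}}}};

  auto op_it = shape_inputs.find("Reshape");
  if (op_it != shape_inputs.end()) {
    auto ver_it = op_it->second.find(14);
    if (ver_it != op_it->second.end()) {
      for (std::size_t idx : ver_it->second) {
        std::cout << "Reshape-14 shape-like input index: " << idx << "\n";
      }
    }
  }
  return 0;
}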

Fix header

Remove unused variable

Versioning shape inputs
wschin committed Mar 5, 2024
1 parent 27b1dc9 commit ffa61d7
Showing 1 changed file with 117 additions and 0 deletions.
117 changes: 117 additions & 0 deletions onnxruntime/core/framework/fallback_cpu_capability.cc
@@ -2,9 +2,13 @@
// Licensed under the MIT License.

#include "core/framework/fallback_cpu_capability.h"
#include "core/framework/tensorprotoutils.h"
#include "core/common/inlined_containers.h"

#include <cstring>
#include <cstdlib>
#include <queue>
#include <unordered_map>

#include "onnx/defs/data_type_utils.h"

@@ -39,6 +43,115 @@ static bool IsSmallInitializer(const onnxruntime::GraphViewer& graph, const Node
}
} // namespace

std::unordered_set<NodeIndex> GetShapeRelatedNodes(const onnxruntime::GraphViewer& viewer) {
// Conceptually, this function traverses from shape-consuming nodes
// and falls back all of their upstream nodes to CPU. Consider a graph
//
//   shape = onnx::Concat(s0, s1)
//   reshaped = onnx::Reshape(x, shape)
//
// The traversal should stop when
// 1. hitting Shape, Size nodes, graph inputs, or graph initializers.
// 2. hitting nodes with some large inputs or outputs.
LOGS_DEFAULT(INFO) << "Call GetShapeRelatedNodes to identify extra CPU nodes." << std::endl;

std::unordered_map<std::string, std::unordered_map<int64_t, std::vector<size_t>>> shape_related_inputs_in_nodes = {
// 2nd input of Expand-13 is a shape-related input.
{"Expand", {13 /* since version */, {1} /* shape inputs' indices */}},
// 2nd input (indexed by 1) of Reshape-13, Reshape-14, Reshape-19, Reshape-21 is a shape-related input.
{"Reshape", {13 , {1}}, {14, {1}}, {19, {1}}, {21, {1}}},
// 2nd input of Unsqueeze-13 and Unsqueeze-21 is a shape-related input.
{"Unsqueeze", {13, {1}}, {21, {1}}},
// 1st input of ConstantOfShape is a shape-related input.
{"ConstantOfShape", {9, {0}}, {20, {0}}, {21, {0}}},
// 2nd to 5th inputs of Slice-13 are shape-related inputs.
{"Slice", {13, {1, 2, 3, 4}}}
};

auto& graph = viewer.GetGraph();
// Each shape-producing node produces a tensor consumed
// as a shape, axes, size, or indices input.
// E.g.,
// shape = onnx::Concat(s0, s1)
// reshaped = onnx::Reshape(x, shape)
// Then, the shape-producing node is Concat.
std::unordered_set<const Node*> shape_producing_nodes;
// This loop collects all shape-producing nodes by finding
// all nodes that produce tensors specified in shape_related_inputs_in_nodes.
// E.g., for the above example, Concat is a shape-producing node because
// "Reshape" has a shape-related input at index 1.
for (auto& node : graph.Nodes()) {
LOGS_DEFAULT(INFO) << "Check if node " << node.Name() << " can be sink of shape sub-graph." << std::endl;
auto op_type_it = shape_related_inputs_in_nodes.find(node.OpType());
if (op_type_it == shape_related_inputs_in_nodes.end()) {
// This node doesn't consume a tensor as a shape,
// so we won't find any shape-producing node from it.
continue;
}
if (op_type_it->second.find(node.SinceVersion()) == op_type_it->second.end()) {
// This node doesn't consume a tensor as a shape in this version,
// so we won't find any shape-producing node from it.
continue;
}

// shape-like inputs' indices in this node.
// E.g., for Reshape, it's [1] and for Slice, it's [1, 2, 3, 4].
auto& shape_input_indices = op_type_it->second.at(node.SinceVersion());
// Now, this `node` is a shape-consuming node as defined by shape_related_inputs_in_nodes.
// Let's find producers for shape-like tensors consumed by this `node`.
// Consider this graph:
// shape = onnx::Concat(s0, s1)
// reshaped = onnx::Reshape(x, shape)
// The loop below does:
// 1. checks all `Reshape`'s inputs, `x` and `shape`,
// 2. finds `shape` is a shape-related variable since Reshape's 2nd input is a shape-related input,
// 3. and then records the producer of `shape` (i.e., `Concat`).
for (auto& input_index : shape_input_indices) {
auto input = node.InputDefs().at(input_index);
auto producer_node = graph.GetProducerNode(input->Name());
if (producer_node != nullptr && producer_node->OpType() != "Shape" && producer_node->OpType() != "Size") {
// Assume shape-computing sub-graphs begin with Shape, Size, or graph inputs.
// We should not fall back those nodes' upstream nodes to CPU; otherwise,
// it may change
// GPU-tensor-x -> Mul -> GPU-tensor-y -> Shape -> CPU-tensor
// to
// CPU-tensor-x -> Mul -> CPU-tensor -> Shape -> CPU-tensor
// and slows down the computation.

// After this for-loop, we will reversely traverse all nodes from every shape-producing node
// found here until hitting Shape, Size nodes, graph inputs, or graph initializers.
// All nodes on the traversal path will be forced to run on CPU.
LOGS_DEFAULT(INFO) << "Find a shape producing node (i.e., a node produces a tensor consumed as shape-like input in other nodes): " << node.Name() << std::endl;

Check warning on line 122 in onnxruntime/core/framework/fallback_cpu_capability.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/framework/fallback_cpu_capability.cc:122: Lines should be <= 120 characters long [whitespace/line_length] [2]
shape_producing_nodes.insert(producer_node);
}
}
}

std::unordered_set<NodeIndex> shape_related_node_indices;
for (auto& node : shape_producing_nodes) {
LOGS_DEFAULT(INFO) << "Begin the (topologically reverse) traversing from shape producing node: " << node->Name() << std::endl;

Check warning on line 130 in onnxruntime/core/framework/fallback_cpu_capability.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/framework/fallback_cpu_capability.cc:130: Lines should be <= 120 characters long [whitespace/line_length] [2]
std::vector<const Node*> start_nodes = {node};

auto to_stop = [](const Node* n1, const Node* n2) {
LOGS_DEFAULT(INFO) << "Skip the traversal from " << n1->Name() << " to " << n2->Name() << " since " << n2->Name() << " is a Shape or Size node." << std::endl;

Check warning on line 134 in onnxruntime/core/framework/fallback_cpu_capability.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/framework/fallback_cpu_capability.cc:134: Lines should be <= 120 characters long [whitespace/line_length] [2]
return n2->OpType() == "Shape" || n2->OpType() == "Size";
};

// Reversely traverse all nodes from the shape-producing node and
// mark every visited node to be placed on CPU.
// Stop the traversal when a Shape or Size node is reached.
graph.ReverseDFSFrom(
start_nodes,
[&shape_related_node_indices](const Node* n) {
LOGS_DEFAULT(INFO) << "Find an upstream node in shape sub-graph (let's fallback it to CPU): " << n->Name() << std::endl;

Check warning on line 144 in onnxruntime/core/framework/fallback_cpu_capability.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/framework/fallback_cpu_capability.cc:144: Lines should be <= 120 characters long [whitespace/line_length] [2]
shape_related_node_indices.insert(n->Index());
},
nullptr,
NodeCompare(),
to_stop);
}

return shape_related_node_indices;
}

std::unordered_set<NodeIndex> GetCpuPreferredNodes(const onnxruntime::GraphViewer& graph,
const IExecutionProvider::IKernelLookup& kernel_lookup,
gsl::span<const NodeIndex> tentative_nodes) {
@@ -171,6 +284,10 @@ std::unordered_set<NodeIndex> GetCpuPreferredNodes(const onnxruntime::GraphViewe
}
}

const char* aggressive_cpu_fallback = std::getenv("ORT_AGGRESSIVE_CPU_FALLBACK");
if (aggressive_cpu_fallback != nullptr && std::strcmp(aggressive_cpu_fallback, "1") == 0) {
auto shape_related_node_indices = GetShapeRelatedNodes(graph);
cpu_nodes.insert(shape_related_node_indices.begin(), shape_related_node_indices.end());
}
return cpu_nodes;
}
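
As a usage note beyond the diff: a minimal sketch of how the new path could be exercised, assuming a POSIX environment and the ONNX Runtime C++ API; the model path and logging identifier are hypothetical. Per the check above, the fallback only activates when the environment variable ORT_AGGRESSIVE_CPU_FALLBACK is set to "1" before the session is created.

// Usage sketch (assumptions: POSIX setenv, ONNX Runtime C++ API, hypothetical model path).
#include <cstdlib>

#include <onnxruntime_cxx_api.h>

int main() {
  // Must be set before session creation so GetCpuPreferredNodes observes it.
  setenv("ORT_AGGRESSIVE_CPU_FALLBACK", "1", /*overwrite=*/1);

  Ort::Env env(ORT_LOGGING_LEVEL_INFO, "aggressive-cpu-fallback-demo");
  Ort::SessionOptions session_options;
  // A non-CPU execution provider (e.g., CUDA) would need to be registered here
  // for the CPU fallback decision to matter.
  Ort::Session session(env, "model.onnx", session_options);  // hypothetical model path
  return 0;
}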
