From ed79ec7a2e3fe77792bbddd8b106ee058b1f7a5a Mon Sep 17 00:00:00 2001
From: Wei-Sheng Chin <wschin@outlook.com>
Date: Mon, 4 Mar 2024 14:47:32 -0800
Subject: [PATCH] Add a new function to fallback more nodes to CPUs.

Shape-related nodes don't only start with `Shape` or `Size`. In a
dynamo-captured ONNX model, they can start with a graph input. A new
transform is added to fallback `all` nodes which can be reversely
traversed from a `shape-like` variable. Some `shape-like` variables
are listed below.
- all inputs of Range
- 2nd input of Reshape
- 2nd input of Unsqueeze
- 1st input of ConstantOfShape
- 2nd-to-last inputs of Slice.

Fix header

Remove unused variable

Versioning shape inputs

Guard ORT_AGGRESSIVE_CPU_FALLBACK against unset env var (getenv may
return nullptr; passing nullptr to strcmp is undefined behavior)

Fix
---
 .../core/framework/fallback_cpu_capability.cc | 129 ++++++++++++++++++
 1 file changed, 129 insertions(+)

diff --git a/onnxruntime/core/framework/fallback_cpu_capability.cc b/onnxruntime/core/framework/fallback_cpu_capability.cc
index ef68b88187e08..8af05affe8aec 100644
--- a/onnxruntime/core/framework/fallback_cpu_capability.cc
+++ b/onnxruntime/core/framework/fallback_cpu_capability.cc
@@ -2,9 +2,13 @@
 // Licensed under the MIT License.
 
 #include "core/framework/fallback_cpu_capability.h"
+#include "core/framework/tensorprotoutils.h"
 #include "core/common/inlined_containers.h"
 
+#include <cstdlib>
+#include <cstring>
 #include <queue>
+#include <unordered_map>
 
 #include "onnx/defs/data_type_utils.h"
 
@@ -39,6 +43,122 @@ static bool IsSmallInitializer(const onnxruntime::GraphViewer& graph, const Node
 }
 }  // namespace
 
+// Returns the indices of all nodes that belong to shape-computing sub-graphs,
+// i.e., nodes whose outputs are (transitively) consumed as shape-like inputs
+// (shapes, axes, sizes, slice bounds) by other nodes. These nodes are cheap,
+// scalar-ish computations that are better forced onto CPU.
+std::unordered_set<NodeIndex> GetShapeRelatedNodes(const onnxruntime::GraphViewer& viewer) {
+  // Conceptually, this function traverses backward from shape-consuming nodes
+  // to fallback all of their upstream nodes to CPU. Consider a graph like
+  //
+  //   x -> Shape -> Gather -> Concat -> Reshape -> y
+  //
+  // The traversal should stop when
+  // 1. hitting Shape, Size nodes, graph inputs, or graph initializers.
+  // 2. hitting nodes with some large inputs or outputs.
+  LOGS_DEFAULT(INFO) << "Call GetShapeRelatedNodes to identify extra CPU nodes." << std::endl;
+
+  // Maps op type -> (since-version -> indices of that op's shape-like inputs).
+  std::unordered_map<std::string, std::unordered_map<int, std::vector<int>>> shape_related_inputs_in_nodes = {
+      // 2nd input of Expand-13 is a shape-related input.
+      {"Expand", {{13 /* since version */, {1} /* shape inputs' indices */}}},
+      // 2nd input (indexed by 1) of Reshape-13, Reshape-14, Reshape-19, Reshape-21 is a shape-related input.
+      {"Reshape", {{13, {1}}, {14, {1}}, {19, {1}}, {21, {1}}}},
+      // 2nd input of Unsqueeze-13 and Unsqueeze-21 is a shape-related input.
+      {"Unsqueeze", {{13, {1}}, {21, {1}}}},
+      // 1st input of ConstantOfShape is a shape-related input.
+      {"ConstantOfShape", {{9, {0}}, {20, {0}}, {21, {0}}}},
+      // 2nd to 5th inputs of Slice-13 are shape-related inputs.
+      {"Slice", {{13, {1, 2, 3, 4}}}}};
+
+  auto& graph = viewer.GetGraph();
+  // Each shape-producing node produces a tensor consumed
+  // as shape, axis, size, or indices.
+  // E.g.,
+  //   shape = onnx::Concat(s0, s1)
+  //   reshaped = onnx::Reshape(x, shape)
+  // Then, the shape-producing node is Concat.
+  std::unordered_set<const Node*> shape_producing_nodes;
+  // This loop collects all shape-producing nodes by finding
+  // all nodes that produce tensors specified in shape_related_inputs_in_nodes.
+  // E.g., for the above example, Concat is a shape-producing node because
+  // "Reshape" has a shape-related input at index 1.
+  for (auto& node : graph.Nodes()) {
+    LOGS_DEFAULT(INFO) << "Check if node " << node.Name() << " can be sink of shape sub-graph." << std::endl;
+    auto op_type_it = shape_related_inputs_in_nodes.find(node.OpType());
+    if (op_type_it == shape_related_inputs_in_nodes.end()) {
+      // This node doesn't consume tensor as shape,
+      // so we won't find any shape-producing node from it.
+      continue;
+    }
+    auto op_type_version_it = op_type_it->second.find(node.SinceVersion());
+    if (op_type_version_it == op_type_it->second.end()) {
+      // This node doesn't consume tensor as shape in this version,
+      // so we won't find any shape-producing node from it.
+      continue;
+    }
+
+    // Shape-like inputs' indices in this node.
+    // E.g., for Reshape, it's [1] and for Slice, it's [1, 2, 3, 4].
+    auto& shape_input_indices = op_type_version_it->second;
+    // Now, this `node` is a shape-consuming node as defined by shape_related_inputs_in_nodes.
+    // Let's find producers for shape-like tensors consumed by this `node`.
+    // Consider this graph:
+    //   shape = onnx::Concat(s0, s1)
+    //   reshaped = onnx::Reshape(x, shape)
+    // The loop below does:
+    //   1. checks all `Reshape`'s inputs, `x` and `shape`,
+    //   2. finds `shape` is a shape-related variable since Reshape's 2nd input is a shape-related input,
+    //   3. and then records the producer of `shape` (i.e., `Concat`).
+    for (auto& input_index : shape_input_indices) {
+      auto input = node.InputDefs().at(input_index);
+      auto producer_node = graph.GetProducerNode(input->Name());
+      if (producer_node != nullptr && producer_node->OpType() != "Shape" && producer_node->OpType() != "Size") {
+        // Assume shape-computing sub-graphs begin with Shape, Size, or graph inputs.
+        // We should not fallback those nodes' upstream nodes to CPU; otherwise,
+        // it may change
+        //   GPU-tensor-x -> Mul -> GPU-tensor-y -> Shape -> CPU-tensor
+        // to
+        //   CPU-tensor-x -> Mul -> CPU-tensor -> Shape -> CPU-tensor
+        // and slows down the computation.
+
+        // After this for-loop, we will reversely traverse all nodes from every shape-producing node
+        // found here until hitting Shape, Size nodes, graph inputs, or graph initializers.
+        // All nodes on the traversal path will be forced to run on CPU.
+        // NOTE: log the producer (the node being recorded), not the consumer.
+        LOGS_DEFAULT(INFO) << "Find a shape producing node (i.e., a node produces a tensor consumed as shape-like input in other nodes): " << producer_node->Name() << std::endl;
+        shape_producing_nodes.insert(producer_node);
+      }
+    }
+  }
+
+  std::unordered_set<NodeIndex> shape_related_node_indices;
+  for (auto& node : shape_producing_nodes) {
+    LOGS_DEFAULT(INFO) << "Begin the (topologically reverse) traversing from shape producing node: " << node->Name() << std::endl;
+    std::vector<const Node*> start_nodes = {node};
+
+    // Stop when the upstream node is a Shape or Size node; only log in that
+    // case (an unconditional log here would falsely claim a skip on every
+    // traversed edge).
+    auto to_stop = [](const Node* n1, const Node* n2) {
+      const bool stop = n2->OpType() == "Shape" || n2->OpType() == "Size";
+      if (stop) {
+        LOGS_DEFAULT(INFO) << "Skip the traversal from " << n1->Name() << " to " << n2->Name() << " since " << n2->Name() << " is a Shape or Size node." << std::endl;
+      }
+      return stop;
+    };
+
+    // Reversely traverse all nodes from the shape-producing node.
+    // Force nodes to be run on CPU when all inputs and outputs are small.
+    // Stop the traversal when a "Shape" node is found.
+    graph.ReverseDFSFrom(
+        start_nodes,
+        [&shape_related_node_indices](const Node* n) {
+          LOGS_DEFAULT(INFO) << "Find an upstream node in shape sub-graph (let's fallback it to CPU): " << n->Name() << std::endl;
+          shape_related_node_indices.insert(n->Index());
+        },
+        nullptr,
+        NodeCompare(),
+        to_stop);
+  }
+
+  return shape_related_node_indices;
+}
+
 std::unordered_set<NodeIndex> GetCpuPreferredNodes(const onnxruntime::GraphViewer& graph,
                                                    const IExecutionProvider::IKernelLookup& kernel_lookup,
                                                    gsl::span<const NodeIndex> tentative_nodes) {
@@ -171,6 +291,15 @@ std::unordered_set<NodeIndex> GetCpuPreferredNodes(const onnxruntime::GraphViewe
     }
   }
 
+  // ORT_AGGRESSIVE_CPU_FALLBACK=1 opts in to falling back whole shape-computing
+  // sub-graphs to CPU in addition to the heuristics above.
+  // std::getenv returns nullptr when the variable is unset, so it must be
+  // null-checked before being passed to strcmp (nullptr there is UB).
+  const char* const aggressive_cpu_fallback = std::getenv("ORT_AGGRESSIVE_CPU_FALLBACK");
+  if (aggressive_cpu_fallback != nullptr && std::strcmp(aggressive_cpu_fallback, "1") == 0) {
+    auto shape_related_node_indices = GetShapeRelatedNodes(graph);
+    cpu_nodes.insert(shape_related_node_indices.begin(), shape_related_node_indices.end());
+  }
 
   return cpu_nodes;
 }