[Experimental] Add a path to fall back more nodes to CPUs #19875

Open · wants to merge 10 commits into main · Changes from 7 commits
10 changes: 10 additions & 0 deletions include/onnxruntime/core/framework/execution_provider.h
@@ -11,6 +11,7 @@
#include "core/common/logging/logging.h"
#include "core/common/status.h"
#include "core/framework/data_transfer.h"
#include "core/framework/session_options.h"
#include "core/framework/tensor.h"

namespace onnxruntime {
@@ -277,6 +278,14 @@ class IExecutionProvider {
return logger_;
}

void SetSessionOptions(const SessionOptions* session_options) {
session_options_ = session_options;
}

const SessionOptions* GetSessionOptions() const {
return session_options_;
}

virtual std::unique_ptr<profiling::EpProfiler> GetProfiler() {
return {};
}
@@ -330,5 +339,6 @@ class IExecutionProvider {

// It will be set when this object is registered to a session
const logging::Logger* logger_ = nullptr;
const SessionOptions* session_options_ = nullptr;
};
} // namespace onnxruntime
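The `session_options_` pointer above is populated when the provider is registered to a session. A minimal sketch of that wiring, assuming a `RegisterExecutionProvider` entry point on the session (the actual registration code is not part of this diff):

// Hypothetical illustration only: the real InferenceSession registration
// logic lives elsewhere and is not changed in this diff.
Status InferenceSession::RegisterExecutionProvider(
    const std::shared_ptr<IExecutionProvider>& ep) {
  // Hand the EP a read-only view of session-level config entries (e.g.,
  // session.aggressive_cpu_fallback) before GetCapability() is called.
  ep->SetSessionOptions(&session_options_);
  // ... existing registration steps ...
  return Status::OK();
}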
include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h

@@ -256,3 +256,48 @@ static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed
// - "0": Gemm FastMath mode is not enabled. [DEFAULT]
// - "1": Gemm FastMath mode is enabled.
static const char* const kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16 = "mlas.enable_gemm_fastmath_arm64_bfloat16";

// Optionally identify sub-graphs by traversing the graph in reverse order
// starting from all CPU-consuming nodes (e.g., for Reshape-13, the traversal
// starts from its 2nd input). Traversal stops when it hits a Size or Shape operator.
// The identified sub-graphs will be assigned to the CPU EP.
//
// See comments in the model defined by onnxscript in Python below for an example.
//
// @onnxscript.script(default_opset=opset18)
// def foo(x: FLOAT[12], w: FLOAT[6, 2], dim0: INT64[1], dim1: INT64[1]):
//     # This should be computed by CPU but is placed
//     # on CUDA (i.e., all inputs and outputs are GPU tensors)
//     # when this option is not set to 1.
//     dim2 = dim1 + 1
//     # Like `dim2 = dim1 + 1`, another GPU node
//     # when this option is not set to 1.
//     dim3 = dim2 - 1
//     # Like `dim2 = dim1 + 1`, another GPU node
//     # when this option is not set to 1.
//     new_shape = opset18.Concat(dim0, dim3, axis=0)
//
//     # A memcpy node will be inserted to copy GPU output
//     # `new_shape` to CPU since Reshape's 2nd input is a CPU tensor
//     # per schema definition.
//     #
//     # To
//     #  1. remove the memcpy node, and
//     #  2. fall back all computation above this line to CPU,
//     # use the following Python code:
//     #   import onnxruntime
//     #   so = onnxruntime.SessionOptions()
//     #   so.add_session_config_entry("session.aggressive_cpu_fallback", "1")
//     #
//     # Note that x and new_x stay on GPU with or without
//     # setting session.aggressive_cpu_fallback.
//     new_x = opset18.Reshape(x, new_shape)
//     # A pure GPU node.
//     y = opset18.MatMul(new_x, w)
//     return y
//
// Option values:
// - "0": Disable reverse-traversing CPU fallback. [DEFAULT]
// - "1": Enable reverse-traversing CPU fallback when calling GetCpuPreferredNodes(...).
// (i.e., adding nodes found by GetShapeRelatedNodes(...) to CPU node list internally).
static const char* const kOrtSessionOptionsAggressiveCpuFallback = "session.aggressive_cpu_fallback";
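The comment above gives the Python toggle; the same switch can be flipped from the C++ API. A minimal sketch, assuming the model path is a placeholder (AddConfigEntry is the standard Ort::SessionOptions helper for session config entries):

#include <onnxruntime_cxx_api.h>
#include <onnxruntime_session_options_config_keys.h>

int main() {
  Ort::Env env;
  Ort::SessionOptions so;
  // Same effect as so.add_session_config_entry(...) in Python.
  so.AddConfigEntry(kOrtSessionOptionsAggressiveCpuFallback, "1");
  // "model.onnx" is a placeholder; on Windows the path is ORTCHAR_T-based.
  Ort::Session session(env, "model.onnx", so);
  return 0;
}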
130 changes: 125 additions & 5 deletions onnxruntime/core/framework/fallback_cpu_capability.cc
@@ -2,8 +2,11 @@
// Licensed under the MIT License.

#include "core/framework/fallback_cpu_capability.h"
#include "core/framework/tensorprotoutils.h"
#include "core/common/inlined_containers.h"

#include <cstring>

Check warning on line 8 in onnxruntime/core/framework/fallback_cpu_capability.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Found C++ system header after other header. Should be: fallback_cpu_capability.h, c system, c++ system, other. [build/include_order] [4] Raw Output: onnxruntime/core/framework/fallback_cpu_capability.cc:8: Found C++ system header after other header. Should be: fallback_cpu_capability.h, c system, c++ system, other. [build/include_order] [4]
#include <cstdlib>

Check warning on line 9 in onnxruntime/core/framework/fallback_cpu_capability.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Found C++ system header after other header. Should be: fallback_cpu_capability.h, c system, c++ system, other. [build/include_order] [4] Raw Output: onnxruntime/core/framework/fallback_cpu_capability.cc:9: Found C++ system header after other header. Should be: fallback_cpu_capability.h, c system, c++ system, other. [build/include_order] [4]
#include <queue>

#include "onnx/defs/data_type_utils.h"
@@ -39,9 +42,122 @@
}
} // namespace

static InlinedHashSet<NodeIndex> GetShapeRelatedNodes(const onnxruntime::GraphViewer& viewer) {
// Conceptually, this function traverses the graph backward from shape-consuming
// nodes and falls back all of their upstream nodes to CPU.
//
// The traversal stops when
// 1. hitting Shape or Size nodes, graph inputs, or graph initializers;
// 2. [TODO] hitting nodes with some large inputs or outputs. Before that,
//    we need shape inference to determine the sizes of the inputs and outputs.
//    Some graph transforms add nodes without shape information, so
//    checking shapes would make the algorithm more unstable for now.
LOGS_DEFAULT(VERBOSE) << "Call GetShapeRelatedNodes to identify extra CPU nodes." << std::endl;

static const InlinedHashMap<std::string_view, InlinedHashMap<int64_t, std::vector<size_t>>>
    shape_related_inputs_in_nodes = {
        // 2nd input of Expand-13 is a shape-related input.
        {"Expand", {{13 /* since version */, {1} /* shape inputs' indices */}}},
        // 2nd input (indexed by 1) of Reshape-13, Reshape-14, Reshape-19, and Reshape-21 is a shape-related input.
        {"Reshape", {{13, {1}}, {14, {1}}, {19, {1}}, {21, {1}}}},
        // 2nd input of Unsqueeze-13 and Unsqueeze-21 is a shape-related input.
        {"Unsqueeze", {{13, {1}}, {21, {1}}}},
        // 1st input of ConstantOfShape is a shape-related input.
        {"ConstantOfShape", {{9, {0}}, {20, {0}}, {21, {0}}}},
        // 2nd to 5th inputs of Slice-13 are shape-related inputs.
        {"Slice", {{13, {1, 2, 3, 4}}}}};

auto& graph = viewer.GetGraph();
// Each shape-producing node produces a tensor consumed
// as a shape, axis, size, or indices tensor.
// E.g.,
//  shape = onnx::Concat(s0, s1)
//  reshaped = onnx::Reshape(x, shape)
// Then the shape-producing node is Concat.
InlinedHashSet<const Node*> shape_producing_nodes;
// This loop collects all shape-producing nodes by finding
// all nodes that produce tensors specified in shape_related_inputs_in_nodes.
// E.g., for the above example, Concat is a shape-producing node because
// "Reshape" has a shape-related input at index 1.
for (auto& node : graph.Nodes()) {
LOGS_DEFAULT(VERBOSE) << "Check if node " << node.Name() << " can be sink of shape sub-graph." << std::endl;
auto op_type_it = shape_related_inputs_in_nodes.find(node.OpType());
if (op_type_it == shape_related_inputs_in_nodes.end()) {
// This node doesn't consume tensor as shape,
// so we won't find any shape-producing node from it.
continue;
}
auto op_type_version_it = op_type_it->second.find(node.SinceVersion());
if (op_type_version_it == op_type_it->second.end()) {
// This node doesn't consume tensor as shape in this version,
// so we won't find any shape-producing node from it.
continue;
}

// shape-like inputs' indices in this node.
// E.g., for Reshape, it's [1] and for Slice, it's [1, 2, 3, 4].
auto& shape_input_indices = op_type_version_it->second;
// Now, this `node` is a shape-consuming node as defined by shape_related_inputs_in_nodes.
// Let's find producers for shape-like tensors consumed by this `node`.
// Consider this graph:
// shape = onnx::Concat(s0, s1)
// reshaped = onnx::Reshape(x, shape)
// The loop below does:
// 1. checks all `Reshape`'s inputs, `x` and `shape`,
// 2. finds `shape` is a shape-related variable since Reshape's 2nd input is a shape-related input,
// 3. and then records the producer of `shape` (i.e., `Concat`).
for (auto& input_index : shape_input_indices) {
auto input = node.InputDefs().at(input_index);
auto producer_node = graph.GetProducerNode(input->Name());
if (producer_node != nullptr && producer_node->OpType() != "Shape" && producer_node->OpType() != "Size") {
// Assume shape-computing sub-graphs begin with Shape, Size, or graph inputs.
// We should not fall back those nodes' upstream nodes to CPU; otherwise,
// it may change
//  GPU-tensor-x -> Mul -> GPU-tensor-y -> Shape -> CPU-tensor
// to
//  CPU-tensor-x -> Mul -> CPU-tensor -> Shape -> CPU-tensor
// and slow down the computation.

// After this for-loop, we will reversely traverse all nodes from every shape-producing node
// found here until hitting Shape or Size nodes, graph inputs, or graph initializers.
// All nodes on the traversal path will be forced to run on CPU.
LOGS_DEFAULT(VERBOSE) << "Find a shape-producing node (i.e., a node that produces a tensor "
                      << "consumed as a shape-like input by other nodes): "
                      << producer_node->Name() << std::endl;
shape_producing_nodes.insert(producer_node);
}
}
}

InlinedHashSet<NodeIndex> shape_related_node_indices;
for (auto& node : shape_producing_nodes) {
LOGS_DEFAULT(VERBOSE) << "Begin the (topologically reversed) traversal from shape-producing node: "
                      << node->Name() << std::endl;
std::vector<const Node*> start_nodes = {node};

auto to_stop = [](const Node* n1, const Node* n2) {
const bool stop = n2->OpType() == "Shape" || n2->OpType() == "Size";
if (stop) {
LOGS_DEFAULT(VERBOSE) << "Skip the traversal from " << n1->Name() << " to " << n2->Name()
                      << " since " << n2->Name() << " is a Shape or Size node." << std::endl;
}
return stop;
};

// Reversely traverse all nodes from the shape-producing node
// and force every visited node to run on CPU.
// Stop the traversal when a Shape or Size node is found.
graph.ReverseDFSFrom(
start_nodes,
[&shape_related_node_indices](const Node* n) {
LOGS_DEFAULT(VERBOSE) << "Find an upstream node in the shape sub-graph (fall it back to CPU): "
                      << n->Name() << std::endl;
shape_related_node_indices.insert(n->Index());
},
nullptr,
NodeCompare(),
to_stop);
}

return shape_related_node_indices;
}
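For intuition, here is a self-contained toy analogue of the traversal above, using plain adjacency lists instead of onnxruntime::Graph (all names are illustrative only): it walks producer edges backward from a shape-producing node, collecting nodes to fall back, while the stop predicate cuts the walk at Shape producers.

#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// Toy analogue of Graph::ReverseDFSFrom: walk producer edges backward from
// `start`, invoking `enter` per node unless `stop` cuts the edge.
void ReverseDfs(const std::unordered_map<std::string, std::vector<std::string>>& producers,
                const std::string& start,
                const std::function<void(const std::string&)>& enter,
                const std::function<bool(const std::string&, const std::string&)>& stop) {
  std::unordered_set<std::string> visited;
  std::vector<std::string> stack{start};
  while (!stack.empty()) {
    std::string n = stack.back();
    stack.pop_back();
    if (!visited.insert(n).second) continue;
    enter(n);
    auto it = producers.find(n);
    if (it == producers.end()) continue;
    for (const auto& p : it->second) {
      if (!stop(n, p)) stack.push_back(p);
    }
  }
}

int main() {
  // Concat <- Sub <- Add <- Shape, loosely mirroring the foo() model in the
  // session-option comment (a chain of small ops feeding Reshape's 2nd input).
  std::unordered_map<std::string, std::vector<std::string>> producers{
      {"Concat", {"Sub"}}, {"Sub", {"Add"}}, {"Add", {"Shape"}}};
  ReverseDfs(producers, "Concat",
             [](const std::string& n) { std::cout << n << " -> CPU\n"; },
             [](const std::string&, const std::string& p) { return p == "Shape"; });
  // Prints Concat, Sub, Add; Shape is excluded by the stop predicate.
}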

std::unordered_set<NodeIndex> GetCpuPreferredNodes(const onnxruntime::GraphViewer& graph,
const IExecutionProvider::IKernelLookup& kernel_lookup,
-gsl::span<const NodeIndex> tentative_nodes) {
gsl::span<const NodeIndex> tentative_nodes,
const bool aggressive_cpu_fallback) {
// automatic conversion from const std::vector&
const auto& ordered_nodes = graph.GetNodesInTopologicalOrder();
InlinedVector<size_t> node_id_to_order_map(graph.MaxNodeIndex());
@@ -83,7 +199,7 @@
auto consumer_nodes = graph.GetConsumerNodes(node_arg.Name());
for (auto& consumer_node : consumer_nodes) {
candidates.push(consumer_node->Index());
LOGS_DEFAULT(INFO) << "Candidate for fallback CPU execution: " << consumer_node->Name();
LOGS_DEFAULT(VERBOSE) << "Candidate for fallback CPU execution: " << consumer_node->Name();
}
}
return Status::OK();
@@ -159,9 +275,9 @@

if (place_in_cpu) {
cpu_nodes.insert(cur);
LOGS_DEFAULT(INFO) << "ORT optimization- Force fallback to CPU execution for node: " << node->Name()
<< " because the CPU execution path is deemed faster than overhead involved with execution on other EPs "
<< " capable of executing this node";
LOGS_DEFAULT(VERBOSE) << "ORT optimization- Force fallback to CPU execution for node: " << node->Name()
<< " because the CPU execution path is deemed faster than overhead involved with execution on other EPs "

Check warning on line 279 in onnxruntime/core/framework/fallback_cpu_capability.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/framework/fallback_cpu_capability.cc:279: Lines should be <= 120 characters long [whitespace/line_length] [2]
<< " capable of executing this node";
for (auto* output : node->OutputDefs()) {
cpu_output_args.insert(output);
}
@@ -171,6 +287,10 @@
}
}

if (aggressive_cpu_fallback) {
auto shape_related_node_indices = GetShapeRelatedNodes(graph);
cpu_nodes.insert(shape_related_node_indices.begin(), shape_related_node_indices.end());
}
return cpu_nodes;
}

4 changes: 3 additions & 1 deletion onnxruntime/core/framework/fallback_cpu_capability.h

@@ -16,9 +16,11 @@ namespace onnxruntime {
@param graph Graph viewer
@param kernel_lookup The kernel lookup for the target execution provider
@param tentative_nodes Nodes that are tentatively placed on the target EP
@param aggressive_cpu_fallback Set by the kOrtSessionOptionsAggressiveCpuFallback session option.
*/
std::unordered_set<NodeIndex> GetCpuPreferredNodes(const GraphViewer& graph,
const IExecutionProvider::IKernelLookup& kernel_lookup,
-gsl::span<const NodeIndex> tentative_nodes);
gsl::span<const NodeIndex> tentative_nodes,
const bool aggressive_cpu_fallback);

} // namespace onnxruntime
4 changes: 4 additions & 0 deletions onnxruntime/core/framework/session_options.h

@@ -146,6 +146,10 @@
// The configuration keys and value formats are defined in
// /include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
ConfigOptions config_options;

const ConfigOptions& GetConfigOptions() const {
return config_options;
}
std::unordered_map<std::string, const OrtValue*> initializers_to_share_map;

// See onnxruntime_c_api.h for detailed documentation.
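Call sites in this PR read the flag either through GetConfigOrDefault (CANN, DML, JS, ROCM) or GetConfigEntry (CUDA); both evaluate to true only when the entry is set to "1". A minimal sketch of the lookup through the new accessor (the wrapper function is hypothetical):

// Hypothetical helper; mirrors the pattern repeated in the EP changes below.
bool IsAggressiveCpuFallbackEnabled(const onnxruntime::SessionOptions& so) {
  // GetConfigOrDefault returns "0" when the key is absent, so no
  // std::optional handling is needed (unlike GetConfigEntry).
  return so.GetConfigOptions().GetConfigOrDefault(
             kOrtSessionOptionsAggressiveCpuFallback, "0") == "1";
}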
8 changes: 7 additions & 1 deletion onnxruntime/core/providers/cann/cann_execution_provider.cc

@@ -1296,7 +1296,13 @@ CANNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewe
candidates.push_back(node.Index());
}

-auto cpu_nodes = GetCpuPreferredNodes(graph_viewer, kernel_lookup, candidates);
auto p_session_options = GetSessionOptions();
bool aggressive_cpu_fallback = false;
if (p_session_options) {
aggressive_cpu_fallback = p_session_options->config_options.GetConfigOrDefault(
kOrtSessionOptionsAggressiveCpuFallback, "0") == "1";
}
auto cpu_nodes = GetCpuPreferredNodes(graph_viewer, kernel_lookup, candidates, aggressive_cpu_fallback);
for (auto& node_index : candidates) {
if (cpu_nodes.count(node_index) > 0)
continue;
9 changes: 8 additions & 1 deletion onnxruntime/core/providers/cuda/cuda_execution_provider.cc

@@ -13,6 +13,7 @@
#include "core/providers/cuda/gpu_data_transfer.h"
#include "core/providers/cuda/cuda_profiler.h"
#include "core/session/onnxruntime_run_options_config_keys.h"
#include "core/session/onnxruntime_session_options_config_keys.h"

#ifndef USE_CUDA_MINIMAL
#ifndef DISABLE_CONTRIB_OPS
@@ -2530,7 +2531,13 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
// For CUDA EP, exclude the subgraph that is preferred to be placed in CPU
// These are usually shape related computation subgraphs
// Following logic can be extended for other EPs
-auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes);
auto p_session_options = GetSessionOptions();
bool aggressive_cpu_fallback = false;
if (p_session_options) {
aggressive_cpu_fallback = p_session_options->GetConfigOptions().GetConfigEntry(
kOrtSessionOptionsAggressiveCpuFallback) == "1";
}
auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes, aggressive_cpu_fallback);
std::vector<std::unique_ptr<ComputeCapability>> result;
for (auto& node_index : candidates) {
if (cpu_nodes.count(node_index) > 0)
@@ -886,7 +886,13 @@ namespace Dml
}

// Get the list of nodes that should stay on the CPU
-auto cpuPreferredNodes = GetCpuPreferredNodes(graph, kernel_lookup, tentativeNodes);
auto p_session_options = GetSessionOptions();
bool aggressive_cpu_fallback = false;
if (p_session_options) {
aggressive_cpu_fallback = p_session_options->config_options.GetConfigOrDefault(
kOrtSessionOptionsAggressiveCpuFallback, "0") == "1";
}
auto cpuPreferredNodes = GetCpuPreferredNodes(graph, kernel_lookup, tentativeNodes, aggressive_cpu_fallback);

for (size_t nodeIndex : toplogicalOrder)
{
8 changes: 7 additions & 1 deletion onnxruntime/core/providers/js/js_execution_provider.cc

@@ -728,7 +728,13 @@ std::vector<std::unique_ptr<ComputeCapability>> JsExecutionProvider::GetCapabili
candidates.push_back(node.Index());
tenative_candidates.push_back(node.Index());
}
-auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tenative_candidates);
auto p_session_options = GetSessionOptions();
bool aggressive_cpu_fallback = false;
if (p_session_options) {
aggressive_cpu_fallback = p_session_options->config_options.GetConfigOrDefault(
kOrtSessionOptionsAggressiveCpuFallback, "0") == "1";
}
auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tenative_candidates, aggressive_cpu_fallback);
std::vector<std::unique_ptr<ComputeCapability>> result;
for (auto& node_index : candidates) {
if (cpu_nodes.count(node_index) > 0) {
8 changes: 7 additions & 1 deletion onnxruntime/core/providers/rocm/rocm_execution_provider.cc

@@ -2415,7 +2415,13 @@ ROCMExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
// For ROCM EP, exclude the subgraph that is preferred to be placed in CPU
// These are usually shape related computation subgraphs
// Following logic can be extended for other EPs
-auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, candidates);
auto p_session_options = GetSessionOptions();
bool aggressive_cpu_fallback = false;
if (p_session_options) {
aggressive_cpu_fallback = p_session_options->config_options.GetConfigOrDefault(
kOrtSessionOptionsAggressiveCpuFallback, "0") == "1";
}
auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, candidates, aggressive_cpu_fallback);
std::vector<std::unique_ptr<ComputeCapability>> result;
for (auto& node_index : candidates) {
if (cpu_nodes.count(node_index) > 0)
3 changes: 2 additions & 1 deletion onnxruntime/core/providers/shared_library/provider_api.h

@@ -280,7 +280,8 @@ std::unique_ptr<IDataTransfer> CreateGPUDataTransfer();

std::unordered_set<NodeIndex> GetCpuPreferredNodes(const onnxruntime::GraphViewer& graph,
const IExecutionProvider::IKernelLookup& kernel_lookup,
-gsl::span<const NodeIndex> tentative_nodes);
gsl::span<const NodeIndex> tentative_nodes,
const bool aggressive_cpu_fallback);

std::string GetEnvironmentVar(const std::string& var_name);

@@ -364,8 +364,9 @@ std::string GetEnvironmentVar(const std::string& var_name) {

std::unordered_set<NodeIndex> GetCpuPreferredNodes(const onnxruntime::GraphViewer& graph,
const IExecutionProvider::IKernelLookup& kernel_lookup,
-gsl::span<const NodeIndex> tentative_nodes) {
-return g_host->GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes);
gsl::span<const NodeIndex> tentative_nodes,
const bool aggressive_cpu_fallback) {
return g_host->GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes, aggressive_cpu_fallback);
}
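For the bridge call above to work, the ProviderHost interface must expose the matching parameter on its side as well; a sketch of the assumed host-side declaration (the corresponding provider-bridge change is not shown in this view, and the struct is abbreviated):

struct ProviderHost {
  virtual std::unordered_set<NodeIndex> GetCpuPreferredNodes(
      const GraphViewer& graph,
      const IExecutionProvider::IKernelLookup& kernel_lookup,
      gsl::span<const NodeIndex> tentative_nodes,
      bool aggressive_cpu_fallback) = 0;
  // ... many other entry points elided ...
};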

namespace profiling {