Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Functions Ahead Of Time inlininng #17764

Merged
merged 22 commits into from
Oct 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 33 additions & 11 deletions include/onnxruntime/core/graph/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#pragma once

#include <functional>
#include <limits>
#include <memory>
#include <string>
Expand Down Expand Up @@ -83,10 +84,10 @@ class Node {
gsl::span<NodeArg* const> output_args,
const NodeAttributes* attributes,
std::string_view domain) {
Init(std::string{name}, std::string{op_type}, std::string{description},
std::vector<NodeArg*>{input_args.begin(), input_args.end()},
std::vector<NodeArg*>{output_args.begin(), output_args.end()},
attributes, std::string{domain});
Init(name, op_type, description,
input_args,
output_args,
attributes, domain);
}
#endif

Expand Down Expand Up @@ -563,13 +564,13 @@ class Node {
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Node);

#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
void Init(const std::string& name,
const std::string& op_type,
const std::string& description,
const std::vector<NodeArg*>& input_args,
const std::vector<NodeArg*>& output_args,
void Init(std::string_view name,
std::string_view op_type,
std::string_view description,
gsl::span<NodeArg* const> input_args,
gsl::span<NodeArg* const> output_args,
const NodeAttributes* attributes,
const std::string& domain);
std::string_view domain);
#endif

#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
Expand Down Expand Up @@ -1141,8 +1142,22 @@ class Graph {
*/
Status InlineFunction(Node& node);

/**
Directly insert the nodes in the function proto provided into the graph.
The function converts Constant nodes into the initializers in the graph.
It then creates a node in the graph for each of the function nodes.
All of the names are expected to be specialized, and, therefore unique.
See function_utils::Specialize().

The Graph needs to be Resolve()d after this call.
@param func_to_inline
@returns Status indicating success or providing an error message.
*/

Status InlineFunctionProto(const ONNX_NAMESPACE::FunctionProto& func_to_inline);

/** Mark a NodeArg name as coming from the outer scope when programmatically constructing a Graph that will
be used as a GraphProto attribute in another Node..
be used as a GraphProto attribute in another Node.
e.g. when creating a Graph instance that will be used as a subgraph in a control flow operator, it is necessary to
define placeholder NodeArgs for outer scope values. This prevents these values from becoming explicit graph inputs
when the Graph is resolved.
Expand Down Expand Up @@ -1391,6 +1406,13 @@ class Graph {
Node& AddNode(const ONNX_NAMESPACE::NodeProto& node_proto,
const ArgNameToTypeMap& name_to_type);

/** Helper that converts and adds constant node proto to an initializer in the graph.
@param constant_node_proto Constant node to convert
@param new_name use the new name for the initializer.
*/
Status AddConstantProtoAsInitializer(const ONNX_NAMESPACE::NodeProto& constant_node_proto,
std::optional<std::string_view> new_name);

#endif

Version IrVersion() const noexcept {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,16 @@ static const char* const kOrtSessionOptionsEnableQuantQDQCleanup = "session.enab
// GeluApproximation has side effects which may change the inference results. It is disabled by default due to this.
static const char* const kOrtSessionOptionsEnableGeluApproximation = "optimization.enable_gelu_approximation";

// This setting controls whether to enable AheadOfTime function inlining.
// AOT function inlining examines the graph and attempts to inline as many locally defined functions in the model
// as possible with the help of enabled execution providers.
// This can reduce the number of function calls and improve performance because it is done before
// Level1 optimizers and constant folding. However, under some circumstances, when the EPs are not available,
// one can disable the AOT inlining, produce an optimized model and postpone AOT until run time.
// "0": enable; "1": disable.
// Its default value is "0".
static const char* const kOrtSessionOptionsDisableAheadOfTimeFunctionInlining = "session.disable_aot_function_inlining";

#ifdef ENABLE_TRAINING
// Specifies a list of op types for memory footprint reduction.
// The value should be a ","-delimited list of pair of
Expand Down
162 changes: 146 additions & 16 deletions onnxruntime/core/framework/graph_partitioner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
#include "core/framework/kernel_registry_manager.h"
#include "core/framework/kernel_registry.h"
#include "core/graph/function.h"
#include "core/graph/function_utils.h"
#include "core/graph/graph_viewer.h"
#include "core/graph/model.h"

// uncomment this line to count non-CUDA ops in ONNX domain
// #define COUNT_NON_CUDA_OPS
Expand Down Expand Up @@ -129,6 +131,21 @@
std::reference_wrapper<const layout_transformation::DebugGraphFn> debug_graph_fn;
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
};

auto get_capabilities = [](const IExecutionProvider& ep,
const GraphViewer& graph_viewer,
const IExecutionProvider::IKernelLookup& kernel_lookup) {
auto capabilities = ep.GetCapability(graph_viewer, kernel_lookup);

// In theory an EP could return an empty capability. Remove those.
capabilities.erase(std::remove_if(capabilities.begin(), capabilities.end(),
[](const std::unique_ptr<ComputeCapability>& capability) {
return !capability || !capability->sub_graph;
}),
capabilities.end());

return capabilities;
};
} // namespace

static Status GetCapabilityForEP(const GetCapabilityForEPParams& params) {
Expand All @@ -143,21 +160,6 @@
}
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

auto get_capabilities = [](const IExecutionProvider& ep,
const GraphViewer& graph_viewer,
const IExecutionProvider::IKernelLookup& kernel_lookup) {
auto capabilities = ep.GetCapability(graph_viewer, kernel_lookup);

// In theory an EP could return an empty capability. Remove those.
capabilities.erase(std::remove_if(capabilities.begin(), capabilities.end(),
[](const std::unique_ptr<ComputeCapability>& capability) {
return !capability || !capability->sub_graph;
}),
capabilities.end());

return capabilities;
};

const auto& kernel_registry_mgr = params.kernel_registry_mgr.get();
const auto kernel_registries_for_ep = kernel_registry_mgr.GetKernelRegistriesByProviderType(ep_type);
const KernelLookup kernel_lookup{ep_type,
Expand Down Expand Up @@ -239,6 +241,26 @@
}

#if !defined(ORT_MINIMAL_BUILD)

// This function queries the capabilities for a given EP, but it does not assign the nodes.
// It also does not perform layout transformation. This will be done during normal partitioning.
static Status GetCapabilityForEPForAotInlining(const GraphViewer& graph_viewer,
const KernelRegistryManager& kernel_registry_mgr,
const IExecutionProvider& current_ep,
std::vector<std::unique_ptr<ComputeCapability>>& capabilities) {
const auto& ep_type = current_ep.Type();

const auto kernel_registries_for_ep = kernel_registry_mgr.GetKernelRegistriesByProviderType(ep_type);
const KernelLookup kernel_lookup{ep_type,
kernel_registries_for_ep,
kernel_registry_mgr.GetKernelTypeStrResolver()};

// TODO: Provide EP with a capability to look inside the functions.

Check warning on line 258 in onnxruntime/core/framework/graph_partitioner.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/framework/graph_partitioner.cc#L258

Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
Raw output
onnxruntime/core/framework/graph_partitioner.cc:258:  Missing username in TODO; it should look like "// TODO(my_username): Stuff."  [readability/todo] [2]
capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup);
yuslepukhin marked this conversation as resolved.
Show resolved Hide resolved

return Status::OK();
}

/**
* Check if a node can be placed on a specific provider.
* Do nothing if the node is already assigned
Expand Down Expand Up @@ -518,7 +540,7 @@
// successfully inlined, we re-run the partitioner on the modified graph.
// NOTE: Inlining the function will change the nodes in the Graph instance, so we can't do that while iterating
// using graph.Nodes().
std::vector<Node*> nodes_to_inline;
InlinedVector<Node*> nodes_to_inline;
for (auto& node : graph.Nodes()) {
if (node.GetExecutionProviderType().empty() && node.CanBeInlined()) {
nodes_to_inline.push_back(&node);
Expand All @@ -533,6 +555,85 @@
return Status::OK();
}

static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_providers,
const KernelRegistryManager& kernel_registry_mgr,
Graph& graph,
InlinedHashSet<std::string>& not_inlined,
size_t& inlined_count) {
// handle testing edge case where optimizers or constant lifting results in graph with no nodes.
// doing it here saves all providers checking for this in GetCapability
if (graph.NumberOfNodes() == 0) {
return Status::OK();
}

for (auto& node : graph.Nodes()) {
for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) {
Graph* subgraph = entry.second;
// we pass through the FuncManager from the top level graph
yuslepukhin marked this conversation as resolved.
Show resolved Hide resolved
ORT_RETURN_IF_ERROR(InlineFunctionsAOTImpl(execution_providers,
kernel_registry_mgr,
*subgraph,
not_inlined,
inlined_count));
}
}

// Gather the candidates
InlinedVector<NodeIndex> inline_candidates;
for (auto& node : graph.Nodes()) {
if (node.CanBeInlined()) {
inline_candidates.push_back(node.Index());
}
}

if (inline_candidates.empty()) {
return Status::OK();
}

// Find out all the nodes that are already taken
const GraphViewer graph_viewer(graph);

InlinedHashSet<NodeIndex> claimed_by_ep;
for (const auto& ep : execution_providers) {
std::vector<std::unique_ptr<ComputeCapability>> capabilities;
ORT_RETURN_IF_ERROR(GetCapabilityForEPForAotInlining(graph_viewer, kernel_registry_mgr, *ep, capabilities));
for (auto& capability : capabilities) {
const auto& nodes = capability->sub_graph->nodes;
if (nodes.size() == 1) {
// Single node capability.
ORT_IGNORE_RETURN_VALUE(claimed_by_ep.insert(nodes[0]));
} else {
// Make sure none is claimed by other EPs mirroring the logic in PartitionOnnxFormatModelImpl.
if (std::all_of(nodes.cbegin(), nodes.cend(), [&claimed_by_ep](NodeIndex node_index) {
return claimed_by_ep.count(node_index) == 0;
})) {
claimed_by_ep.insert(nodes.cbegin(), nodes.cend());
}
}
}
}

// TODO: Insert version check. We need to collect all the versions

Check warning on line 616 in onnxruntime/core/framework/graph_partitioner.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/framework/graph_partitioner.cc#L616

Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
Raw output
onnxruntime/core/framework/graph_partitioner.cc:616:  Missing username in TODO; it should look like "// TODO(my_username): Stuff."  [readability/todo] [2]
// that imported by the model. If the version is not supported by
// the model, we can not inline it.

for (auto node_index : inline_candidates) {
auto* node = graph.GetNode(node_index);
if (node != nullptr) {
if (claimed_by_ep.count(node_index) == 0) {
ORT_RETURN_IF_ERROR(graph.InlineFunction(*node));
++inlined_count;
} else {
// OpType is the same as function name.
auto function_id = function_utils::GetFunctionIdentifier(node->Domain(), node->OpType());
ORT_IGNORE_RETURN_VALUE(not_inlined.insert(std::move(function_id)));
}
}
}

return Status::OK();
}

static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, GraphPartitioner::Mode mode,
const ExecutionProviders& execution_providers,
KernelRegistryManager& kernel_registry_manager) {
Expand Down Expand Up @@ -693,6 +794,35 @@
return Status::OK();
}

#ifndef ORT_MINIMAL_BUILD

Status GraphPartitioner::InlineFunctionsAOT(Model& model,
const ExecutionProviders& execution_providers,
const KernelRegistryManager& kernel_registry_manager) const {
auto& graph = model.MainGraph();
InlinedHashSet<std::string> not_inlined;

Check warning on line 803 in onnxruntime/core/framework/graph_partitioner.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/framework/graph_partitioner.cc#L803

Add #include <string> for string [build/include_what_you_use] [4]
Raw output
onnxruntime/core/framework/graph_partitioner.cc:803:  Add #include <string> for string  [build/include_what_you_use] [4]
do {
size_t inlined_count = 0;
ORT_RETURN_IF_ERROR(InlineFunctionsAOTImpl(execution_providers,
kernel_registry_manager,
graph,
not_inlined,
inlined_count));

if (inlined_count == 0) {
break;
}

ORT_RETURN_IF_ERROR(graph.Resolve());
} while (true);

model.RemoveLocalFunctionsProtos(not_inlined);

return Status::OK();
}

#endif

Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr,
const layout_transformation::TransformLayoutFunction& transform_layout_function,
Mode mode,
Expand Down
21 changes: 21 additions & 0 deletions onnxruntime/core/framework/graph_partitioner.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ namespace onnxruntime {

class ExecutionProviders;
class KernelRegistryManager;
class Model;

class GraphPartitioner {
public:
Expand All @@ -33,6 +34,26 @@ class GraphPartitioner {
Mode mode = Mode::kNormal,
const layout_transformation::DebugGraphFn& debug_graph_fn = {}) const;

#ifndef ORT_MINIMAL_BUILD
/// <summary>
// Ahead of Time Function inlining. The main purpose of the function is to inline as many
// functions as possible and delete locally defined functions to reduce the size of the model.
// This would make other optimizations to be more effective.
//
// This function performs GetCapability on the graph and its subgraphs bottom up
// and inlines any functions that are not claimed by any of the execution providers.
// This function does not attempt to run layout transformation, and it does not assign EPs.
// The latter will be done by graph partitioning after Level1 optimizations are done.
/// </summary>
/// <param name="model">model instance</param>
/// <param name="execution_providers">execution providers considered</param>
/// <param name="kernel_registry_manager">registry manager</param>
/// <returns></returns>
Status InlineFunctionsAOT(Model& model,
const ExecutionProviders& execution_providers,
const KernelRegistryManager& kernel_registry_manager) const;
#endif

private:
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GraphPartitioner);

Expand Down
7 changes: 4 additions & 3 deletions onnxruntime/core/graph/function_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,8 @@ class Inliner {
// Replace given name with a unique version of the name, and cache the
// renaming-binding in current scope.
void make_unique(std::string& name) {
auto new_name = prefix_ + name;
auto new_name{prefix_};
new_name.append("_").append(name);
auto& current_scope = rename_scopes_.back();
current_scope[name] = new_name;
name = std::move(new_name);
Expand Down Expand Up @@ -410,7 +411,7 @@ class Inliner {
std::string rename_as = actuals.Get(i);
if constexpr (isOutput) {
if (rename_as.empty())
rename_as.assign(prefix_).append(formal);
rename_as.assign(prefix_).append("_").append(formal);
}
current_scope[formal] = rename_as;
if (!rename_as.empty())
Expand All @@ -420,7 +421,7 @@ class Inliner {
std::string& formal = *formals.Mutable(i);
std::string rename_as;
if constexpr (isOutput) {
rename_as.assign(prefix_).append(formal);
rename_as.assign(prefix_).append("_").append(formal);
}
current_scope[formal] = rename_as;
if (!rename_as.empty())
Expand Down
Loading
Loading