Skip to content

Commit

Permalink
Do not run AOT function inlining when the model does not define any l…
Browse files Browse the repository at this point in the history
…ocal functions (#18302)

### Description
Check if the model defines any local functions.
if not, skip AOT inlining including any schema based functions.
The latter would be inlined during partitioning.

### Motivation and Context
This prevents calls GetCapability() to EPs and enhahces  compatibility.
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

---------

Co-authored-by: Pranav Sharma <[email protected]>
  • Loading branch information
yuslepukhin and pranavsharma authored Nov 7, 2023
1 parent 606356d commit 096307c
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 3 deletions.
17 changes: 16 additions & 1 deletion onnxruntime/core/framework/graph_partitioner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -798,7 +798,16 @@ static Status PartitionOrtFormatModel(const PartitionParams& partition_params,

Status GraphPartitioner::InlineFunctionsAOT(Model& model,
const ExecutionProviders& execution_providers,
const KernelRegistryManager& kernel_registry_manager) const {
const KernelRegistryManager& kernel_registry_manager,
const logging::Logger& logger) const {
const auto local_functions_num = model.GetModelLocalFunctionTemplates().size();
const bool is_there_local_functions = local_functions_num > 0;

if (!is_there_local_functions) {
LOGS(logger, INFO) << "This model does not have any local functions defined. AOT Inlining is not performed";
return Status::OK();
}

auto& graph = model.MainGraph();
InlinedHashSet<std::string> not_inlined;
do {
Expand All @@ -818,6 +827,12 @@ Status GraphPartitioner::InlineFunctionsAOT(Model& model,

model.RemoveLocalFunctionsProtos(not_inlined);

LOGS(logger, INFO)
<< "AOT inlining completed. (" << (local_functions_num - model.GetModelLocalFunctionTemplates().size())
<< ") functions of ("
<< local_functions_num
<< ") pruned.";

return Status::OK();
}

Expand Down
4 changes: 3 additions & 1 deletion onnxruntime/core/framework/graph_partitioner.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,12 @@ class GraphPartitioner {
/// <param name="model">model instance</param>
/// <param name="execution_providers">execution providers considered</param>
/// <param name="kernel_registry_manager">registry manager</param>
/// <param name="logger">session logger</param>
/// <returns></returns>
Status InlineFunctionsAOT(Model& model,
const ExecutionProviders& execution_providers,
const KernelRegistryManager& kernel_registry_manager) const;
const KernelRegistryManager& kernel_registry_manager,
const logging::Logger& logger) const;
#endif

private:
Expand Down
4 changes: 3 additions & 1 deletion onnxruntime/core/session/inference_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1029,7 +1029,9 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool
kOrtSessionOptionsDisableAheadOfTimeFunctionInlining, "0") == "1";
!disable_aot_function_inlining) {
ORT_RETURN_IF_ERROR_SESSIONID_(partitioner.InlineFunctionsAOT(*model_,
execution_providers_, kernel_registry_manager_));
execution_providers_,
kernel_registry_manager_,
*session_logger_));
}

auto apply_transformer_once = [](const GraphTransformer& transformer, const logging::Logger& logger,
Expand Down

0 comments on commit 096307c

Please sign in to comment.