diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h
index f153e88909b8d..462d410e13769 100644
--- a/include/onnxruntime/core/graph/graph.h
+++ b/include/onnxruntime/core/graph/graph.h
@@ -3,6 +3,7 @@
 #pragma once

+#include <optional>
 #include
 #include
 #include
@@ -83,10 +84,10 @@ class Node {
        gsl::span<NodeArg* const> output_args,
        const NodeAttributes* attributes,
        std::string_view domain) {
-    Init(std::string{name}, std::string{op_type}, std::string{description},
-         std::vector<NodeArg*>{input_args.begin(), input_args.end()},
-         std::vector<NodeArg*>{output_args.begin(), output_args.end()},
-         attributes, std::string{domain});
+    Init(name, op_type, description,
+         input_args,
+         output_args,
+         attributes, domain);
   }
 #endif

@@ -563,13 +564,13 @@ class Node {
   ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Node);

 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
-  void Init(const std::string& name,
-            const std::string& op_type,
-            const std::string& description,
-            const std::vector<NodeArg*>& input_args,
-            const std::vector<NodeArg*>& output_args,
+  void Init(std::string_view name,
+            std::string_view op_type,
+            std::string_view description,
+            gsl::span<NodeArg* const> input_args,
+            gsl::span<NodeArg* const> output_args,
             const NodeAttributes* attributes,
-            const std::string& domain);
+            std::string_view domain);
 #endif

 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
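The point of this signature change: `std::string_view`/`gsl::span` parameters bind directly to the caller's existing storage, so the forwarding constructor above no longer materializes four temporary `std::string`s and two temporary `std::vector`s per node. A standalone sketch of the difference, using toy types rather than the real Node API:

#include <string>
#include <string_view>
#include <vector>
#include <gsl/gsl>

// Old shape: every call forces owning temporaries to be built.
void InitOld(const std::string& name, const std::vector<int*>& args);

// New shape: views bind to whatever contiguous storage the caller already has,
// and the callee copies only if it needs to retain the data.
void InitNew(std::string_view name, gsl::span<int* const> args);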
@@ -1141,8 +1142,22 @@ class Graph {
   */
   Status InlineFunction(Node& node);

+  /**
+  Directly insert the nodes in the provided function proto into the graph.
+  The function converts Constant nodes into initializers in the graph.
+  It then creates a graph node for each of the function's nodes.
+  All of the names are expected to be specialized, and, therefore, unique.
+  See function_utils::Specialize().
+
+  The Graph needs to be Resolve()d after this call.
+  @param func_to_inline
+  @returns Status indicating success or providing an error message.
+  */
+
+  Status InlineFunctionProto(const ONNX_NAMESPACE::FunctionProto& func_to_inline);
+
   /** Mark a NodeArg name as coming from the outer scope when programmatically constructing a Graph that will
-  be used as a GraphProto attribute in another Node..
+  be used as a GraphProto attribute in another Node.
       e.g. when creating a Graph instance that will be used as a subgraph in a control flow operator, it is necessary to
       define placeholder NodeArgs for outer scope values. This prevents these values from becoming explicit graph inputs
       when the Graph is resolved.
@@ -1391,6 +1406,13 @@ class Graph {
   Node& AddNode(const ONNX_NAMESPACE::NodeProto& node_proto,
                 const ArgNameToTypeMap& name_to_type);

+  /** Helper that converts and adds a constant node proto to an initializer in the graph.
+  @param constant_node_proto Constant node to convert.
+  @param new_name Optional new name for the initializer; if not given, the constant node's output name is used.
+  */
+  Status AddConstantProtoAsInitializer(const ONNX_NAMESPACE::NodeProto& constant_node_proto,
+                                       std::optional<std::string_view> new_name);
+
 #endif

   Version IrVersion() const noexcept {
diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
index 37545f41b43dd..831def24e4f5e 100644
--- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
+++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
@@ -67,6 +67,16 @@ static const char* const kOrtSessionOptionsEnableQuantQDQCleanup = "session.enab
 // GeluApproximation has side effects which may change the inference results. It is disabled by default due to this.
 static const char* const kOrtSessionOptionsEnableGeluApproximation = "optimization.enable_gelu_approximation";

+// This setting controls whether to enable ahead-of-time (AOT) function inlining.
+// AOT function inlining examines the graph and, with the help of the enabled execution providers,
+// attempts to inline as many locally defined functions in the model as possible.
+// This can reduce the number of function calls and improve performance because it is done before
+// Level 1 optimizers and constant folding. However, when the required EPs are not available,
+// one can disable AOT inlining, produce an optimized model, and postpone inlining until run time.
+// "0": enable; "1": disable.
+// Its default value is "0".
+static const char* const kOrtSessionOptionsDisableAheadOfTimeFunctionInlining = "session.disable_aot_function_inlining";
+
 #ifdef ENABLE_TRAINING
 // Specifies a list of op types for memory footprint reduction.
 // The value should be a ","-delimited list of pair of
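Since the key is read through the standard session-config mechanism, toggling it from user code is a one-liner. A minimal sketch against the public C++ API (the model path is hypothetical):

#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env;
  Ort::SessionOptions so;
  // "1" disables AOT inlining so locally defined functions survive into the optimized model.
  so.AddConfigEntry("session.disable_aot_function_inlining", "1");
  Ort::Session session(env, ORT_TSTR("model.onnx"), so);
}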
diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc
index 1b492a3561396..b028596fe4e6d 100644
--- a/onnxruntime/core/framework/graph_partitioner.cc
+++ b/onnxruntime/core/framework/graph_partitioner.cc
@@ -13,7 +13,9 @@
 #include "core/framework/kernel_registry_manager.h"
 #include "core/framework/kernel_registry.h"
 #include "core/graph/function.h"
+#include "core/graph/function_utils.h"
 #include "core/graph/graph_viewer.h"
+#include "core/graph/model.h"

 // uncomment this line to count non-CUDA ops in ONNX domain
 // #define COUNT_NON_CUDA_OPS
@@ -129,6 +131,21 @@ struct GetCapabilityForEPParams {
   std::reference_wrapper<const layout_transformation::DebugGraphFn> debug_graph_fn;
 #endif  // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
 };
+
+auto get_capabilities = [](const IExecutionProvider& ep,
+                           const GraphViewer& graph_viewer,
+                           const IExecutionProvider::IKernelLookup& kernel_lookup) {
+  auto capabilities = ep.GetCapability(graph_viewer, kernel_lookup);
+
+  // In theory an EP could return an empty capability. Remove those.
+  capabilities.erase(std::remove_if(capabilities.begin(), capabilities.end(),
+                                    [](const std::unique_ptr<ComputeCapability>& capability) {
+                                      return !capability || !capability->sub_graph;
+                                    }),
+                     capabilities.end());
+
+  return capabilities;
+};
 }  // namespace

 static Status GetCapabilityForEP(const GetCapabilityForEPParams& params) {
@@ -143,21 +160,6 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params) {
   }
 #endif  // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

-  auto get_capabilities = [](const IExecutionProvider& ep,
-                             const GraphViewer& graph_viewer,
-                             const IExecutionProvider::IKernelLookup& kernel_lookup) {
-    auto capabilities = ep.GetCapability(graph_viewer, kernel_lookup);
-
-    // In theory an EP could return an empty capability. Remove those.
-    capabilities.erase(std::remove_if(capabilities.begin(), capabilities.end(),
-                                      [](const std::unique_ptr<ComputeCapability>& capability) {
-                                        return !capability || !capability->sub_graph;
-                                      }),
-                       capabilities.end());
-
-    return capabilities;
-  };
-
   const auto& kernel_registry_mgr = params.kernel_registry_mgr.get();
   const auto kernel_registries_for_ep = kernel_registry_mgr.GetKernelRegistriesByProviderType(ep_type);
   const KernelLookup kernel_lookup{ep_type,
@@ -239,6 +241,26 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params) {
 }

 #if !defined(ORT_MINIMAL_BUILD)
+
+// This function queries the capabilities for a given EP, but it does not assign the nodes.
+// It also does not perform layout transformation. This will be done during normal partitioning.
+static Status GetCapabilityForEPForAotInlining(const GraphViewer& graph_viewer,
+                                               const KernelRegistryManager& kernel_registry_mgr,
+                                               const IExecutionProvider& current_ep,
+                                               std::vector<std::unique_ptr<ComputeCapability>>& capabilities) {
+  const auto& ep_type = current_ep.Type();
+
+  const auto kernel_registries_for_ep = kernel_registry_mgr.GetKernelRegistriesByProviderType(ep_type);
+  const KernelLookup kernel_lookup{ep_type,
+                                   kernel_registries_for_ep,
+                                   kernel_registry_mgr.GetKernelTypeStrResolver()};
+
+  // TODO: Provide EP with a capability to look inside the functions.
+  capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup);
+
+  return Status::OK();
+}
+
 /**
  * Check if a node can be placed on a specific provider.
  * Do nothing if the node is already assigned
@@ -518,7 +540,7 @@ static Status InlineNodes(Graph& graph, bool& modified_graph) {
   // successfully inlined, we re-run the partitioner on the modified graph.
   // NOTE: Inlining the function will change the nodes in the Graph instance, so we can't do that while iterating
   // using graph.Nodes().
-  std::vector<Node*> nodes_to_inline;
+  InlinedVector<Node*> nodes_to_inline;
   for (auto& node : graph.Nodes()) {
     if (node.GetExecutionProviderType().empty() && node.CanBeInlined()) {
       nodes_to_inline.push_back(&node);
@@ -533,6 +555,85 @@ static Status InlineNodes(Graph& graph, bool& modified_graph) {
   return Status::OK();
 }

+static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_providers,
+                                     const KernelRegistryManager& kernel_registry_mgr,
+                                     Graph& graph,
+                                     InlinedHashSet<std::string>& not_inlined,
+                                     size_t& inlined_count) {
+  // handle the testing edge case where optimizers or constant lifting results in a graph with no nodes.
+  // doing it here saves all providers from checking for this in GetCapability
+  if (graph.NumberOfNodes() == 0) {
+    return Status::OK();
+  }
+
+  for (auto& node : graph.Nodes()) {
+    for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) {
+      Graph* subgraph = entry.second;
+      // we pass through the FuncManager from the top-level graph
+      ORT_RETURN_IF_ERROR(InlineFunctionsAOTImpl(execution_providers,
+                                                 kernel_registry_mgr,
+                                                 *subgraph,
+                                                 not_inlined,
+                                                 inlined_count));
+    }
+  }
+
+  // Gather the candidates
+  InlinedVector<NodeIndex> inline_candidates;
+  for (auto& node : graph.Nodes()) {
+    if (node.CanBeInlined()) {
+      inline_candidates.push_back(node.Index());
+    }
+  }
+
+  if (inline_candidates.empty()) {
+    return Status::OK();
+  }
+
+  // Find out all the nodes that are already taken
+  const GraphViewer graph_viewer(graph);
+
+  InlinedHashSet<NodeIndex> claimed_by_ep;
+  for (const auto& ep : execution_providers) {
+    std::vector<std::unique_ptr<ComputeCapability>> capabilities;
+    ORT_RETURN_IF_ERROR(GetCapabilityForEPForAotInlining(graph_viewer, kernel_registry_mgr, *ep, capabilities));
+    for (auto& capability : capabilities) {
+      const auto& nodes = capability->sub_graph->nodes;
+      if (nodes.size() == 1) {
+        // Single node capability.
+        ORT_IGNORE_RETURN_VALUE(claimed_by_ep.insert(nodes[0]));
+      } else {
+        // Make sure none of the nodes is already claimed by another EP, mirroring the logic in
+        // PartitionOnnxFormatModelImpl.
+        if (std::all_of(nodes.cbegin(), nodes.cend(), [&claimed_by_ep](NodeIndex node_index) {
+              return claimed_by_ep.count(node_index) == 0;
+            })) {
+          claimed_by_ep.insert(nodes.cbegin(), nodes.cend());
+        }
+      }
+    }
+  }
+
+  // TODO: Insert version check. We need to collect all the opset versions
+  // imported by the model. If a function's version is not imported by
+  // the model, we cannot inline it.
+
+  for (auto node_index : inline_candidates) {
+    auto* node = graph.GetNode(node_index);
+    if (node != nullptr) {
+      if (claimed_by_ep.count(node_index) == 0) {
+        ORT_RETURN_IF_ERROR(graph.InlineFunction(*node));
+        ++inlined_count;
+      } else {
+        // OpType is the same as the function name.
+        auto function_id = function_utils::GetFunctionIdentifier(node->Domain(), node->OpType());
+        ORT_IGNORE_RETURN_VALUE(not_inlined.insert(std::move(function_id)));
+      }
+    }
+  }
+
+  return Status::OK();
+}
+
 static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, GraphPartitioner::Mode mode,
                                        const ExecutionProviders& execution_providers,
                                        KernelRegistryManager& kernel_registry_manager) {
@@ -693,6 +794,35 @@ static Status PartitionOrtFormatModel(const PartitionParams& partition_params,
   return Status::OK();
 }

+#ifndef ORT_MINIMAL_BUILD
+
+Status GraphPartitioner::InlineFunctionsAOT(Model& model,
+                                            const ExecutionProviders& execution_providers,
+                                            const KernelRegistryManager& kernel_registry_manager) const {
+  auto& graph = model.MainGraph();
+  InlinedHashSet<std::string> not_inlined;
+  do {
+    size_t inlined_count = 0;
+    ORT_RETURN_IF_ERROR(InlineFunctionsAOTImpl(execution_providers,
+                                               kernel_registry_manager,
+                                               graph,
+                                               not_inlined,
+                                               inlined_count));
+
+    if (inlined_count == 0) {
+      break;
+    }
+
+    ORT_RETURN_IF_ERROR(graph.Resolve());
+  } while (true);
+
+  model.RemoveLocalFunctionsProtos(not_inlined);
+
+  return Status::OK();
+}
+
+#endif
+
 Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr,
                                    const layout_transformation::TransformLayoutFunction& transform_layout_function,
                                    Mode mode,
diff --git a/onnxruntime/core/framework/graph_partitioner.h b/onnxruntime/core/framework/graph_partitioner.h
index 36a27e906c651..c1fa46de9145d 100644
--- a/onnxruntime/core/framework/graph_partitioner.h
+++ b/onnxruntime/core/framework/graph_partitioner.h
@@ -12,6 +12,7 @@ namespace onnxruntime {

 class ExecutionProviders;
 class KernelRegistryManager;
+class Model;

 class GraphPartitioner {
  public:
@@ -33,6 +34,26 @@ class GraphPartitioner {
                    Mode mode = Mode::kNormal,
                    const layout_transformation::DebugGraphFn& debug_graph_fn = {}) const;

+#ifndef ORT_MINIMAL_BUILD
+  /// <summary>
+  /// Ahead-of-time (AOT) function inlining. The main purpose of this function is to inline as many
+  /// functions as possible and to delete locally defined functions, reducing the size of the model.
+  /// This makes other optimizations more effective.
+  ///
+  /// The function performs GetCapability on the graph and its subgraphs bottom-up
+  /// and inlines any functions that are not claimed by any of the execution providers.
+  /// It does not attempt to run layout transformation, and it does not assign EPs.
+  /// The latter is done by graph partitioning after Level 1 optimizations run.
+  /// </summary>
+  /// <param name="model">model instance</param>
+  /// <param name="execution_providers">execution providers considered</param>
+  /// <param name="kernel_registry_manager">registry manager</param>
+  /// <returns>Status</returns>
+  Status InlineFunctionsAOT(Model& model,
+                            const ExecutionProviders& execution_providers,
+                            const KernelRegistryManager& kernel_registry_manager) const;
+#endif
+
  private:
   ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GraphPartitioner);
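For orientation: `not_inlined` stores the identifiers produced by `function_utils::GetFunctionIdentifier`, which keys a function by its domain plus name. The exact format is not shown in this diff; a sketch of the idea:

#include <string>
#include <string_view>

// Hypothetical stand-in for function_utils::GetFunctionIdentifier: the only
// requirement is that distinct (domain, name) pairs map to distinct keys.
std::string MakeFunctionId(std::string_view domain, std::string_view name) {
  std::string id{domain};
  id.append(":").append(name);  // separator is illustrative, not necessarily the real one
  return id;
}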
diff --git a/onnxruntime/core/graph/function_utils.cc b/onnxruntime/core/graph/function_utils.cc
index 7477f48088a15..7b0a834a7ffc0 100644
--- a/onnxruntime/core/graph/function_utils.cc
+++ b/onnxruntime/core/graph/function_utils.cc
@@ -373,7 +373,8 @@ class Inliner {
   // Replace given name with a unique version of the name, and cache the
   // renaming-binding in current scope.
   void make_unique(std::string& name) {
-    auto new_name = prefix_ + name;
+    auto new_name{prefix_};
+    new_name.append("_").append(name);
     auto& current_scope = rename_scopes_.back();
     current_scope[name] = new_name;
     name = std::move(new_name);
@@ -410,7 +411,7 @@ class Inliner {
       std::string rename_as = actuals.Get(i);
       if constexpr (isOutput) {
         if (rename_as.empty())
-          rename_as.assign(prefix_).append(formal);
+          rename_as.assign(prefix_).append("_").append(formal);
       }
       current_scope[formal] = rename_as;
       if (!rename_as.empty())
@@ -420,7 +421,7 @@ class Inliner {
       std::string& formal = *formals.Mutable(i);
       std::string rename_as;
       if constexpr (isOutput) {
-        rename_as.assign(prefix_).append(formal);
+        rename_as.assign(prefix_).append("_").append(formal);
       }
       current_scope[formal] = rename_as;
       if (!rename_as.empty())
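The newly inserted "_" separator is what makes the renaming unambiguous: with plain concatenation, prefix "ab" plus name "c" and prefix "a" plus name "bc" both produce "abc". A tiny self-contained illustration of the fixed scheme (the prefix value is hypothetical):

#include <iostream>
#include <string>

int main() {
  const std::string prefix = "_inlfunc_MyFun_0";  // hypothetical unique prefix
  std::string renamed{prefix};
  renamed.append("_").append("t0");  // intermediate value "t0" inside the function body
  std::cout << renamed << '\n';      // prints: _inlfunc_MyFun_0_t0
}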
diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc
index 383c1d689d3c3..cab9467501f55 100644
--- a/onnxruntime/core/graph/graph.cc
+++ b/onnxruntime/core/graph/graph.cc
@@ -860,18 +860,18 @@ Status Node::LoadEdgesFromOrtFormat(const onnxruntime::fbs::NodeEdge& fbs_node_e
 }

 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
-void Node::Init(const std::string& name,
-                const std::string& op_type,
-                const std::string& description,
-                const std::vector<NodeArg*>& input_args,
-                const std::vector<NodeArg*>& output_args,
+void Node::Init(std::string_view name,
+                std::string_view op_type,
+                std::string_view description,
+                gsl::span<NodeArg* const> input_args,
+                gsl::span<NodeArg* const> output_args,
                 const NodeAttributes* attributes,
-                const std::string& domain) {
+                std::string_view domain) {
   name_ = name;
   op_type_ = op_type;
   description_ = description;
-  definitions_.input_defs = input_args;
-  definitions_.output_defs = output_args;
+  definitions_.input_defs.assign(input_args.begin(), input_args.end());
+  definitions_.output_defs.assign(output_args.begin(), output_args.end());
   domain_ = domain;
   can_be_saved_ = true;
   priority_ = 0;
@@ -1145,7 +1145,8 @@ Graph::Graph(const Model& owning_model,
              IOnnxRuntimeOpSchemaCollectionPtr schema_registry, const logging::Logger& logger,
              bool strict_shape_type_inference)
-    : Graph(owning_model, graph_proto, domain_to_version, ir_version, schema_registry, nullptr, nullptr, logger, strict_shape_type_inference) {}
+    : Graph(owning_model, graph_proto, domain_to_version, ir_version,
+            schema_registry, nullptr, nullptr, logger, strict_shape_type_inference) {}

 Graph::Graph(const Model& owning_model,
              GraphProto* graph_proto, const std::unordered_map<std::string, int>& domain_to_version, Version ir_version,
@@ -3261,8 +3262,8 @@ Node& Graph::AddNode(const std::string& name,
                      gsl::span<NodeArg* const> output_args,
                      const NodeAttributes* attributes,
                      const std::string& domain) {
-  std::vector<NodeArg*> inputs;
-  std::vector<NodeArg*> outputs;
+  InlinedVector<NodeArg*> inputs;
+  InlinedVector<NodeArg*> outputs;
   inputs.resize(input_args.size());
   outputs.resize(output_args.size());
   int i = 0;
@@ -4019,69 +4020,100 @@ Node& Graph::FuseSubGraph(const IndexedSubGraph& sub_graph,
   return fused_node;
 }

+Status Graph::AddConstantProtoAsInitializer(const ONNX_NAMESPACE::NodeProto& node_proto,
+                                            std::optional<std::string_view> new_name) {
+  const gsl::not_null<TensorProto*> tensor{graph_proto_->add_initializer()};
+  ORT_RETURN_IF_ERROR(utils::ConstantNodeProtoToTensorProto(node_proto, ModelPath(), *tensor, node_proto.output(0)));
+
+  if (new_name.has_value()) {
+    tensor->set_name(std::string(new_name.value()));
+  }
+
+  auto insert_result = name_to_initial_tensor_.emplace(tensor->name(), tensor);
+  ORT_ENFORCE(insert_result.second, "Constant node name: ", tensor->name(),
+              " conflicts with graph initializer. Check that the node names have been made unique.");
+  if (GetNodeArg(tensor->name()) == nullptr) {
+    TypeProto t{TypeProtoFromTensorProto(*tensor)};
+    ORT_IGNORE_RETURN_VALUE(GetOrCreateNodeArg(tensor->name(), &t));
+  }
+
+#if !defined(DISABLE_SPARSE_TENSORS)
+  if (node_proto.attribute(0).type() == AttributeProto_AttributeType_SPARSE_TENSOR) {
+    ORT_IGNORE_RETURN_VALUE(sparse_tensor_names_.emplace(tensor->name()));
+  }
+#endif
+
+  return Status::OK();
+}
+
+Status Graph::InlineFunctionProto(const ONNX_NAMESPACE::FunctionProto& func_to_inline) {
+  auto to_node_arg = [this](const std::string& name) {
+    return &this->GetOrCreateNodeArg(name, nullptr);
+  };
+
+  // Process constant nodes first and create NodeArgs for them, as they become initializers.
+  // It is important for the initializers to have a NodeArg created: first, it is needed
+  // if the initializer is unused and removed; second, if a node depends on the initializer,
+  // the NodeArg can have a Type attached to it.
+  InlinedVector<const NodeProto*> non_constant_nodes;
+  non_constant_nodes.reserve(func_to_inline.node_size());
+  for (const auto& inlined_node : func_to_inline.node()) {
+    if (inlined_node.op_type() == kConstant) {
+      // Copy constant nodes _value to name_to_initial_tensor_
+      ORT_RETURN_IF_ERROR(AddConstantProtoAsInitializer(inlined_node, std::nullopt));
+    } else {
+      non_constant_nodes.push_back(&inlined_node);
+    }
+  }
+
+  for (const auto* inlined_node : non_constant_nodes) {
+    InlinedVector<NodeArg*> inputs;
+    InlinedVector<NodeArg*> outputs;
+
+    for (const auto& tensor_name : inlined_node->input())
+      inputs.push_back(to_node_arg(tensor_name));
+
+    for (const auto& tensor_name : inlined_node->output())
+      outputs.push_back(to_node_arg(tensor_name));
+
+    onnxruntime::NodeAttributes new_attr_map;
+    new_attr_map.reserve(inlined_node->attribute_size());
+    for (const auto& node_attr : inlined_node->attribute()) {
+      new_attr_map.insert_or_assign(node_attr.name(), node_attr);
+    }
+    ORT_IGNORE_RETURN_VALUE(AddNode(inlined_node->name(), inlined_node->op_type(),
+                                    inlined_node->doc_string(), inputs, outputs,
+                                    &new_attr_map, inlined_node->domain()));
+  }
+
+  return Status::OK();
+}
+
 Status Graph::InlineFunction(Node& callnode) {
-  const auto& model_path = ModelPath();
-  auto output_edges = callnode.GetRelationships().output_edges;
+  // Remove output edges. Requirement for RemoveNode() below.
+  auto output_edges = callnode.GetRelationships().output_edges;  // copy so RemoveEdge doesn't invalidate the iterator
   for (const auto& output_edge : output_edges) {
     RemoveEdge(callnode.Index(), output_edge.GetNode().Index(), output_edge.GetSrcArgIndex(), output_edge.GetDstArgIndex());
   }

   // create a unique identifier to append to every node name and intermediate inputs/outputs
   // to make sure there are no unintended duplicates
-  std::stringstream ss;
-  ss << "_inline_" << callnode.OpType();
-  auto uniq_identifier = GenerateNodeName(ss.str());
+  std::string base_uniq_identifier{"_inlfunc_"};
+  base_uniq_identifier.append(callnode.OpType());
+  const auto uniq_identifier = GenerateNodeName(base_uniq_identifier);

+  // Replace a (function-call) node by an inlined graph.
   if (!callnode.GetFunctionBody()) {
     // This is the normal use-case: inlining a FunctionProto (representing
     // a model-local function or a schema-defined function).
-    FunctionProto inlined_fp;
+    ONNX_NAMESPACE::FunctionProto inlined_fp;
     ORT_ENFORCE(callnode.TryGetFunctionProto(inlined_fp), "Node has no function body and cannot be inlined.");

-    function_utils::Specialize(inlined_fp, callnode, uniq_identifier);
-    auto to_node_arg = [this](const std::string& name) {
-      return &this->GetOrCreateNodeArg(name, nullptr);
-    };
-
-    // Process constant nodes first and create NodeArg for these as they become initializers
-    // It is important for the initializers to have NodeArg created, first they are needed
-    // if the initializer is unused and removed, second if the node depends on the initializer,
-    // we can have Type attached to it.
-    for (const auto& inlined_node : inlined_fp.node()) {
-      if (inlined_node.op_type() == kConstant) {
-        // Copy constant nodes _value to name_to_initial_tensor_
-        const gsl::not_null<TensorProto*> tensor{graph_proto_->add_initializer()};
-        ORT_RETURN_IF_ERROR(utils::ConstantNodeProtoToTensorProto(inlined_node, model_path, *tensor, inlined_node.output(0)));
-        auto insert_result = name_to_initial_tensor_.emplace(tensor->name(), tensor);
-        ORT_ENFORCE(insert_result.second, "Constant node name: ", tensor->name(), " in inlined function: ",
-                    inlined_fp.name(), " conflicts with graph initializer. Check Specializing code.");
-        TypeProto t{TypeProtoFromTensorProto(*tensor)};
-        ORT_IGNORE_RETURN_VALUE(GetOrCreateNodeArg(tensor->name(), &t));
-      }
-    }
-
-    for (const auto& inlined_node : inlined_fp.node()) {
-      if (inlined_node.op_type() != kConstant) {
-        InlinedVector<NodeArg*> inputs;
-        InlinedVector<NodeArg*> outputs;
-
-        for (const auto& tensor_name : inlined_node.input())
-          inputs.push_back(to_node_arg(tensor_name));
-
-        for (const auto& tensor_name : inlined_node.output())
-          outputs.push_back(to_node_arg(tensor_name));
-
-        onnxruntime::NodeAttributes new_attr_map;
-        new_attr_map.reserve(inlined_node.attribute_size());
-        for (const auto& node_attr : inlined_node.attribute()) {
-          onnx::AttributeProto attr_copy = node_attr;
-          new_attr_map[node_attr.name()] = std::move(attr_copy);
-        }
-        AddNode(inlined_node.name(), inlined_node.op_type(),
-                inlined_node.doc_string(), inputs, outputs, &new_attr_map, inlined_node.domain());
-      }
-    }
+    // Make all the names unique and resolve nested graphs' inputs to the outer scope.
+    function_utils::Specialize(inlined_fp, callnode, uniq_identifier);
+    // In this case, global Resolve() will take care of everything.
+    ORT_RETURN_IF_ERROR(InlineFunctionProto(inlined_fp));
   } else {
     // Uncommon scenario. Inlining a node representing a fused sub-graph.
     // TODO: Unclear that this feature is needed. Can this be removed?
@@ -4115,15 +4147,7 @@ Status Graph::InlineFunction(Node& callnode) {
         // Copy constant nodes _value to name_to_initial_tensor_
         ONNX_NAMESPACE::NodeProto subgraph_node_proto{};
         subgraph_node.ToProto(subgraph_node_proto);
-        const gsl::not_null<TensorProto*> tensor{graph_proto_->add_initializer()};
-        ORT_RETURN_IF_ERROR(utils::ConstantNodeProtoToTensorProto(subgraph_node_proto, model_path, *tensor, subgraph_node_proto.output(0)));
-        auto insert_result = name_to_initial_tensor_.emplace(tensor->name(), tensor);
-        ORT_ENFORCE(insert_result.second, "Constant node name: ", tensor->name(), " in inlined subgraph: ",
-                    subgraph.Name(), " conflicts with graph initializer. Check Specializing code.");
-        if (GetNodeArg(tensor->name()) == nullptr) {
-          TypeProto t{TypeProtoFromTensorProto(*tensor)};
-          ORT_IGNORE_RETURN_VALUE(GetOrCreateNodeArg(tensor->name(), &t));
-        }
+        ORT_RETURN_IF_ERROR(AddConstantProtoAsInitializer(subgraph_node_proto, std::nullopt));
       }
     }
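Taken together, the call-site contract for the refactored pieces is small. A hypothetical helper (this function is invented for illustration; Specialize's signature is taken from its use in InlineFunction above):

// Sketch, assuming ORT-internal headers: uniquify names, splice the function
// body into the graph, then resolve, as the InlineFunctionProto doc requires.
Status InlineSpecializedFunction(Graph& graph, ONNX_NAMESPACE::FunctionProto fp,
                                 Node& call_node, const std::string& unique_prefix) {
  function_utils::Specialize(fp, call_node, unique_prefix);  // make all names unique first
  ORT_RETURN_IF_ERROR(graph.InlineFunctionProto(fp));        // splice nodes + initializers
  return graph.Resolve();                                    // required after inlining
}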
diff --git a/onnxruntime/core/graph/model.cc b/onnxruntime/core/graph/model.cc
index 05747a7e5124d..076332a65c8f2 100644
--- a/onnxruntime/core/graph/model.cc
+++ b/onnxruntime/core/graph/model.cc
@@ -41,6 +41,35 @@ namespace onnxruntime {

 #if !defined(ORT_MINIMAL_BUILD)

+void Model::RemoveLocalFunctionsProtos(const InlinedHashSet<std::string>& retained) {
+  auto* local_functions = model_proto_.mutable_functions();
+  if (retained.empty()) {
+    model_local_function_templates_maps_.clear();
+    model_local_functions_.clear();
+    local_functions->erase(local_functions->begin(), local_functions->end());
+  } else {
+    const auto retained_end = retained.cend();
+    for (auto it = model_local_functions_.begin();
+         it != model_local_functions_.end();) {
+      if (retained.find(it->first) == retained_end) {
+        model_local_function_templates_maps_.erase(it->first);
+        it = model_local_functions_.erase(it);
+      } else {
+        ++it;
+      }
+    }
+
+    for (auto it = local_functions->begin(); it != local_functions->end();) {
+      const auto function_id = function_utils::GetFunctionIdentifier(it->domain(), it->name());
+      if (retained.find(function_id) == retained_end) {
+        it = local_functions->erase(it);
+      } else {
+        ++it;
+      }
+    }
+  }
+}
+
 static constexpr int DEFAULT_PROTOBUF_BLOCK_SIZE = 4 * 1024 * 1024;

 Model::Model(const std::string& graph_name,
@@ -95,10 +124,10 @@ Model::Model(const std::string& graph_name,
   for (auto& func : model_local_functions) {
     auto func_ptr = model_proto_.add_functions();
     func_ptr->CopyFrom(func);
-    model_local_functions_.insert_or_assign(function_utils::GetFunctionIdentifier(func_ptr->domain(), func_ptr->name()), func_ptr);
+    model_local_functions_.insert_or_assign(function_utils::GetFunctionIdentifier(func_ptr->domain(), func_ptr->name()),
+                                            func_ptr);
   }

-  model_local_function_templates_.reserve(model_proto_.functions().size());
   model_local_function_templates_maps_.reserve(model_proto_.functions().size());
   for (auto& func : model_proto_.functions()) {
     auto func_schema_ptr = function_utils::CreateSchema(func.domain(),
@@ -111,8 +140,8 @@ Model::Model(const std::string& graph_name,
     auto func_template_ptr = std::make_unique<FunctionTemplate>();
     func_template_ptr->op_schema_ = std::move(func_schema_ptr);
     func_template_ptr->onnx_func_proto_ = &func;
-    model_local_function_templates_.push_back(std::move(func_template_ptr));
-    model_local_function_templates_maps_[function_utils::GetFunctionIdentifier(func.domain(), func.name())] = model_local_function_templates_.back().get();
+    model_local_function_templates_maps_.insert_or_assign(function_utils::GetFunctionIdentifier(func.domain(), func.name()),
+                                                          std::move(func_template_ptr));
   }

   // need to call private ctor so can't use make_shared
@@ -220,7 +249,6 @@ Model::Model(ModelProto&& model_proto, const PathString& model_path,
     model_local_functions_.insert_or_assign(function_utils::GetFunctionIdentifier(func.domain(), func.name()), &func);
   }

-  model_local_function_templates_.reserve(model_proto_.functions().size());
   model_local_function_templates_maps_.reserve(model_proto_.functions().size());
   for (auto& func : model_proto_.functions()) {
     auto func_schema_ptr = function_utils::CreateSchema(func.domain(),
@@ -233,9 +261,7 @@ Model::Model(ModelProto&& model_proto, const PathString& model_path,
     auto func_template_ptr = std::make_unique<FunctionTemplate>();
     func_template_ptr->op_schema_ = std::move(func_schema_ptr);
     func_template_ptr->onnx_func_proto_ = &func;
-    model_local_function_templates_.push_back(std::move(func_template_ptr));
-    model_local_function_templates_maps_[function_utils::GetFunctionIdentifier(func.domain(), func.name())] =
-        model_local_function_templates_.back().get();
+    model_local_function_templates_maps_.insert_or_assign(function_utils::GetFunctionIdentifier(func.domain(), func.name()), std::move(func_template_ptr));
   }

   // create instance. need to call private ctor so can't use make_unique
@@ -244,7 +270,7 @@ Model::Model(ModelProto&& model_proto, const PathString& model_path,
                          logger, options.strict_shape_type_inference));
 }

-const InlinedHashMap<std::string, const FunctionTemplate*>& Model::GetModelLocalFunctionTemplates() const {
+const NodeHashMap<std::string, std::unique_ptr<FunctionTemplate>>& Model::GetModelLocalFunctionTemplates() const {
   return model_local_function_templates_maps_;
 }

@@ -332,7 +358,7 @@ const Graph& Model::MainGraph() const noexcept {
 }

 #if !defined(ORT_MINIMAL_BUILD)
-ModelProto Model::ToProto() {
+ModelProto Model::ToProto() const {
   // We want to return back the original proto
   // To that end invoke const overload of ToGraphProto()
   // that returns by value and, therefore, allows us to filter
@@ -346,7 +372,7 @@ ModelProto Model::ToProto() {
 ModelProto Model::ToGraphProtoWithExternalInitializers(const std::string& external_file_name,
                                                        const PathString& file_path,
-                                                       size_t initializer_size_threshold) {
+                                                       size_t initializer_size_threshold) const {
   ModelProto result(model_proto_);
   const auto& graph = *graph_;
   *(result.mutable_graph()) = graph.ToGraphProtoWithExternalInitializers(external_file_name,
diff --git a/onnxruntime/core/graph/model.h b/onnxruntime/core/graph/model.h
index 6bdb68dd734f0..4ce6660b794bc 100644
--- a/onnxruntime/core/graph/model.h
+++ b/onnxruntime/core/graph/model.h
@@ -139,7 +139,7 @@ class Model {
   // Returns empty string if not specified.
   const std::string GraphDocString() const;

-  const InlinedHashMap<std::string, const FunctionTemplate*>& GetModelLocalFunctionTemplates() const;
+  const NodeHashMap<std::string, std::unique_ptr<FunctionTemplate>>& GetModelLocalFunctionTemplates() const;

 #else
   // Get model's IR version.
@@ -182,14 +182,14 @@ class Model {

 #if !defined(ORT_MINIMAL_BUILD)
   // Get model's serialization proto data.
-  ONNX_NAMESPACE::ModelProto ToProto();
+  ONNX_NAMESPACE::ModelProto ToProto() const;

   // Get model's serialization proto data.
   // Save initializer larger than the given threshold (in bytes) into an external binary file
   // with the given name. This function is useful to avoid hitting the size limit of protobuf files.
   ONNX_NAMESPACE::ModelProto ToGraphProtoWithExternalInitializers(const std::string& external_file_name,
                                                                   const PathString& file_path,
-                                                                  size_t initializer_size_threshold);
+                                                                  size_t initializer_size_threshold) const;

 #ifdef _WIN32
   static common::Status Save(Model& model, const std::wstring& file_path);
@@ -291,6 +291,13 @@ class Model {
   common::Status SaveToOrtFormat(flatbuffers::FlatBufferBuilder& builder,
                                  flatbuffers::Offset<onnxruntime::fbs::Model>& model) const;

+  /// <summary>
+  /// Frees local function definitions in the model, excluding those in the `retained` set.
+  /// Called from GraphPartitioner::InlineFunctionsAOT.
+  /// </summary>
+  /// <param name="retained">contains function IDs that should not be removed.</param>
+  void RemoveLocalFunctionsProtos(const InlinedHashSet<std::string>& retained);
+
 #endif  // !defined(ORT_MINIMAL_BUILD)

   static common::Status LoadFromOrtFormat(const onnxruntime::fbs::Model& fbs_model,
@@ -312,14 +319,12 @@ class Model {
   // this map will be used for the local functions' schema's type/shape inference.
   // This container is used by ONNX code and must be an std::unordered_map.
   std::unordered_map<std::string, const ONNX_NAMESPACE::FunctionProto*> model_local_functions_;
-  // this is the container that hosts the generated schemas for model-local functions.
-  // the generated schemas will be used for graph resolving and type/shape inference.
-  // those schemas' type/shape inference will reference the model_local_functions_ as context,
-  // so need to keep them with the same lifetime.
-  InlinedVector<std::unique_ptr<FunctionTemplate>> model_local_function_templates_;
   // this is the map from function id to the local function template.
   // this map will be used by graph to instantiate the function body.
-  InlinedHashMap<std::string, const FunctionTemplate*> model_local_function_templates_maps_;
+  // Defined as a node-based map so the memory is released when not all of the functions
+  // are inlined and removed.
+  NodeHashMap<std::string, std::unique_ptr<FunctionTemplate>> model_local_function_templates_maps_;
+
 #else
   // properties that would normally come from ModelProto
   std::string producer_version_;
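Why a node-based map here: previously the `FunctionTemplate` objects were owned by the side `InlinedVector`, so erasing a map entry never released the template itself. With the map owning each template through `unique_ptr`, a node-based container both frees an erased entry immediately and keeps the remaining entries' addresses stable. A sketch of the resulting behavior (the function id is hypothetical):

NodeHashMap<std::string, std::unique_ptr<FunctionTemplate>> templates;
// ... populated as in Model's constructors above ...
templates.erase("local:myfun");  // the unique_ptr destructor frees this template right here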
diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc
index b4d47652942b7..cad55afdf73ac 100644
--- a/onnxruntime/core/session/inference_session.cc
+++ b/onnxruntime/core/session/inference_session.cc
@@ -984,14 +984,25 @@ common::Status InferenceSession::Load() {

 common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool saving_model_in_ort_format) {
   // The transformer order:
-  // 1. ensure potential QDQ node units have unique DQ nodes (required transformer).
+  // 1. ensure we inline as many functions as possible. We refer to this as ahead-of-time (AOT) function inlining.
+  // 2. ensure potential QDQ node units have unique DQ nodes (required transformer).
   //    - This is a required transformer as the ORT code has a hard requirement there are no overlapping QDQ node units.
   //    - We run it here in case optimizers are disabled.
-  // 2. run level 1 optimizations. these only use ONNX operators.
-  // 3. partition nodes based on EP capabilities. EPs may fuse nodes during this process.
-  // 4. run level 2+ optimizations. level 2 and 3 optimizations use contrib ops.
-  // 5. insert cast nodes (required transformer).
-  // 6. insert copy nodes (required transformer).
+  // 3. run level 1 optimizations. these only use ONNX operators.
+  // 4. partition nodes based on EP capabilities. EPs may fuse nodes during this process.
+  // 5. run level 2+ optimizations. level 2 and 3 optimizations use contrib ops.
+  // 6. insert cast nodes (required transformer).
+  // 7. insert copy nodes (required transformer).
+
+  // Run ahead-of-time function inlining
+  GraphPartitioner partitioner(kernel_registry_manager_, execution_providers_);
+  if (const bool disable_aot_function_inlining =
+          session_options_.config_options.GetConfigOrDefault(
+              kOrtSessionOptionsDisableAheadOfTimeFunctionInlining, "0") == "1";
+      !disable_aot_function_inlining) {
+    ORT_RETURN_IF_ERROR_SESSIONID_(partitioner.InlineFunctionsAOT(*model_,
+                                                                  execution_providers_, kernel_registry_manager_));
+  }

   auto apply_transformer_once = [](const GraphTransformer& transformer, const logging::Logger& logger,
                                    Graph& graph) {
@@ -1075,7 +1086,6 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool
   }

   // Do partitioning based on execution providers' capabilities.
-  GraphPartitioner partitioner(kernel_registry_manager_, execution_providers_);
   ORT_RETURN_IF_ERROR_SESSIONID_(partitioner.Partition(graph, session_state_->GetMutableFuncMgr(), transform_layout_fn,
                                                        mode, debug_graph_fn));
diff --git a/onnxruntime/test/framework/function_test.cc b/onnxruntime/test/framework/function_test.cc
index 6e745776ab6b0..41274ee0dedfa 100644
--- a/onnxruntime/test/framework/function_test.cc
+++ b/onnxruntime/test/framework/function_test.cc
@@ -6,36 +6,47 @@
 #include "onnx/defs/parser.h"

 #include "core/common/span_utils.h"
-#include "core/framework/float8.h"
+#include "core/framework/customregistry.h"
+#include "core/framework/op_kernel.h"
 #include "core/graph/model.h"
 #include "core/providers/cpu/cpu_execution_provider.h"
 #include "core/session/inference_session.h"

 #include "test/test_environment.h"
 #include "test/framework/test_utils.h"
+#include "inference_session_wrapper.h"
 #include "test/common/tensor_op_test_utils.h"
 #include "test/util/include/asserts.h"
+#include "test/providers/internal_testing/internal_testing_execution_provider.h"
+
 // Unit tests to check the implementation of functions, model-local functions,
 // function-inlining etc.

 namespace onnxruntime {
 namespace test {

-static void Check(const char* source,
-                  const char* input_name, std::vector<float> input_values,
-                  const char* output_name, std::vector<float> output_values) {
-  // Convert source-representation of model to ModelProto:
+// Convert source-representation of model to ModelProto:
+static void ParseOnnxSource(const char* source, std::string& result) {
   ONNX_NAMESPACE::OnnxParser parser(source);
   ONNX_NAMESPACE::ModelProto model;
   auto parse_status = parser.Parse(model);
   ASSERT_TRUE(parse_status.IsOK()) << parse_status.ErrorMessage();
   ASSERT_TRUE(parser.EndOfInput()) << "Extra unparsed input unexpected.";

-  // Serialize and then load model:
+  // Serialize
   std::string serialized_model;
   const bool serialization_status = model.SerializeToString(&serialized_model);
   ASSERT_TRUE(serialization_status) << "Failed to serialize proto to string";
+  result = std::move(serialized_model);
+}
+
+static void Check(const char* source,
+                  const char* input_name, std::vector<float> input_values,
+                  const char* output_name, std::vector<float> output_values) {
+  // Serialize and then load model:
+  std::string serialized_model;
+  ParseOnnxSource(source, serialized_model);

   SessionOptions session_options;
   InferenceSession session_object{session_options, GetEnvironment()};
@@ -76,8 +87,8 @@ static void Check(const char* source,
   }
 }

-TEST(FunctionTest, Basic) {
-  const char* code = R"(
+namespace {
+const char* basic_code = R"(
        <
        ir_version: 8,
        opset_import: [ "" : 16, "local" : 1 ]
        >
        agraph (float[N] x) => (float[N] y) {
            y = local.myfun (x)
        }

        <
        opset_import: [ "" : 16 ],
        domain: "local"
        >
        myfun (lx) => (ly) {
            two = Constant <value = float[1] {2.0}> ()
            ly = Mul (lx, two)
        }
    )";
+}

-  Check(code, "x", {1.0, 2.0, 3.0}, "y", {2.0, 4.0, 6.0});
+TEST(FunctionTest, Basic) {
+  Check(basic_code, "x", {1.0, 2.0, 3.0}, "y", {2.0, 4.0, 6.0});
 }

 // Check that variables are renamed to avoid conflicts when multiple
@@ -521,5 +534,56 @@ TEST(FunctionTest, ConstantFoldingInSubGraph) {
   Check(code, "X", {1.0, 2.0, 3.0}, "Y", {3.0, 4.0, 5.0, 3.0, 4.0, 5.0, 3.0, 4.0, 5.0});
 }

+TEST(FunctionTest, TestInlinedLocalFunctionRemoved) {
+  std::string serialized_model;
+  ParseOnnxSource(basic_code, serialized_model);
+
+  // Default is to do AOT function inlining
+  SessionOptions session_options;
+  InferenceSessionWrapper session_object{session_options, GetEnvironment()};
+
+  std::stringstream sstr(serialized_model);
+  auto status = session_object.Load(sstr);
+  ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
+
+  auto model_proto = session_object.GetModel().ToProto();
+  ASSERT_EQ(1, model_proto.functions_size());
+
+  status = session_object.Initialize();
+  ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
+
+  // All functions removed
+  model_proto = session_object.GetModel().ToProto();
+  ASSERT_EQ(0, model_proto.functions_size());
+}
+
+TEST(FunctionTest, TestInlinedLocalFunctionNotRemoved) {
+  std::string serialized_model;
+  ParseOnnxSource(basic_code, serialized_model);
+
+  // Default is to do AOT function inlining
+  SessionOptions session_options;
+  InferenceSessionWrapper session_object{session_options, GetEnvironment()};
+
+  using InternalTestingEP = onnxruntime::internal_testing_ep::InternalTestingExecutionProvider;
+  const std::unordered_set<std::string> empty_set;
+  auto internal_testing_ep = std::make_unique<InternalTestingEP>(empty_set, empty_set, DataLayout::NCHW);
+  internal_testing_ep->EnableStaticKernels().TakeAllNodes();
+
+  ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(std::move(internal_testing_ep)));
+
+  std::stringstream sstr(serialized_model);
+  ASSERT_STATUS_OK(session_object.Load(sstr));
+
+  auto model_proto = session_object.GetModel().ToProto();
+  ASSERT_EQ(1, model_proto.functions_size());
+
+  ASSERT_STATUS_OK(session_object.Initialize());
+
+  // myfun is not removed because it was claimed by InternalTestingEP
+  model_proto = session_object.GetModel().ToProto();
+  ASSERT_EQ(1, model_proto.functions_size());
+}
+
 }  // namespace test
 }  // namespace onnxruntime
diff --git a/onnxruntime/test/util/include/inference_session_wrapper.h b/onnxruntime/test/util/include/inference_session_wrapper.h
index eab83c26b681f..757caf7987d35 100644
--- a/onnxruntime/test/util/include/inference_session_wrapper.h
+++ b/onnxruntime/test/util/include/inference_session_wrapper.h
@@ -12,9 +12,8 @@ namespace test {
 // InferenceSession wrapper class for use in tests where we need access to the Graph and SessionState
 class InferenceSessionWrapper : public InferenceSession {
  public:
-  explicit InferenceSessionWrapper(const SessionOptions& session_options,
-                                   const Environment& env) : InferenceSession(session_options, env) {
-  }
+  // Expose the constructors from InferenceSession
+  using InferenceSession::InferenceSession;

   const Graph& GetGraph() const {
     return model_->MainGraph();
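A note on the `using InferenceSession::InferenceSession;` change: an inheriting-constructor declaration re-exposes every base-class constructor on the derived class, so existing test code keeps compiling while the wrapper also picks up the other InferenceSession overloads. The pattern in isolation, with toy types:

#include <string>
#include <utility>

struct Base {
  explicit Base(int id) : id_(id) {}
  Base(int id, std::string tag) : id_(id), tag_(std::move(tag)) {}
  int id_;
  std::string tag_;
};

struct Wrapper : Base {
  using Base::Base;  // inherits both Base constructors
};

int main() {
  Wrapper a{1};           // uses inherited Base(int)
  Wrapper b{2, "debug"};  // uses inherited Base(int, std::string)
}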