Skip to content

Commit

Permalink
Add kill switch for AOT
Browse files Browse the repository at this point in the history
  Add tests
  Rename the function.
  Remove functions directly from the partitioner.
  • Loading branch information
yuslepukhin committed Oct 20, 2023
1 parent c945f96 commit 2db2e59
Show file tree
Hide file tree
Showing 10 changed files with 132 additions and 69 deletions.
20 changes: 3 additions & 17 deletions include/onnxruntime/core/graph/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -1154,7 +1154,7 @@ class Graph {
@returns Status indicating success or providing an error message.
*/

Status FunctionToGraph(const ONNX_NAMESPACE::FunctionProto& func_to_inline);
Status InlineFunctionProto(const ONNX_NAMESPACE::FunctionProto& func_to_inline);

/** Mark a NodeArg name as coming from the outer scope when programmatically constructing a Graph that will
be used as a GraphProto attribute in another Node.
Expand Down Expand Up @@ -1384,8 +1384,7 @@ class Graph {
Version ir_version,
IOnnxRuntimeOpSchemaCollectionPtr schema_registry,
const logging::Logger& logger,
bool strict_shape_type_inference,
std::function<void(const InlinedHashSet<std::string>&)> model_functions_remover);
bool strict_shape_type_inference);

// internal use by the Graph class only
Graph(const Model& owning_model,
Expand All @@ -1396,17 +1395,10 @@ class Graph {
Graph* parent_graph,
const Node* parent_node,
const logging::Logger& logger,
bool strict_shape_type_inference,
std::function<void(const InlinedHashSet<std::string>&)> model_functions_remover);
bool strict_shape_type_inference);

ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Graph);

// Asks the model to remove its local function definitions, passing along the set
// of function IDs in `retained`. The work is delegated to the
// model_functions_remover_ callback; if that callback is empty (it is passed as a
// default-constructed std::function by some Graph constructors), this does nothing.
void RemoveLocalFunctions(const InlinedHashSet<std::string>& retained) {
if (model_functions_remover_) {
model_functions_remover_(retained);
}
}

private:
void InitializeStateFromModelFileGraphProto();

Expand Down Expand Up @@ -1623,12 +1615,6 @@ class Graph {
InlinedHashMap<std::string, std::reference_wrapper<ONNX_NAMESPACE::OpSchema>> reusable_fused_schema_map_;
#endif // !defined(ORT_MINIMAL_BUILD)

#if !defined(ORT_MINIMAL_BUILD)
// A function to call into the model to remove local functions which
// are inlined.
std::function<void(const InlinedHashSet<std::string>&)> model_functions_remover_;
#endif

// Graph nodes.
// Element in <nodes_> may be nullptr due to graph optimization.
std::vector<std::unique_ptr<Node>> nodes_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,16 @@ static const char* const kOrtSessionOptionsEnableQuantQDQCleanup = "session.enab
// GeluApproximation has side effects which may change the inference results. It is disabled by default due to this.
static const char* const kOrtSessionOptionsEnableGeluApproximation = "optimization.enable_gelu_approximation";

// This setting controls whether Ahead-Of-Time (AOT) function inlining is disabled.
// AOT function inlining examines the graph and attempts to inline as many locally defined functions in the model
// as possible with the help of enabled execution providers.
// This can reduce the number of function calls and improve performance because it is done before
// Level1 optimizers and constant folding. However, under some circumstances, when the EPs are not available,
// one can disable the AOT inlining, produce an optimized model and postpone AOT until run time.
// "0": AOT function inlining is enabled; "1": AOT function inlining is disabled.
// Its default value is "0" (AOT function inlining enabled).
static const char* const kOrtSessionOptionsDisableAheadOfTimeFunctionInlining = "session.disable_aot_function_inlining";

#ifdef ENABLE_TRAINING
// Specifies a list of op types for memory footprint reduction.
// The value should be a ","-delimited list of pair of
Expand Down
6 changes: 4 additions & 2 deletions onnxruntime/core/framework/graph_partitioner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "core/graph/function.h"
#include "core/graph/function_utils.h"
#include "core/graph/graph_viewer.h"
#include "core/graph/model.h"

// uncomment this line to count non-CUDA ops in ONNX domain
// #define COUNT_NON_CUDA_OPS
Expand Down Expand Up @@ -795,7 +796,8 @@ static Status PartitionOrtFormatModel(const PartitionParams& partition_params,

#ifndef ORT_MINIMAL_BUILD

Status GraphPartitioner::InlineFunctionsAOT(Graph& graph,
Status GraphPartitioner::InlineFunctionsAOT(Model& model,
Graph& graph,
const ExecutionProviders& execution_providers,
const KernelRegistryManager& kernel_registry_manager) const {
InlinedHashSet<std::string> not_inlined;
Expand All @@ -814,7 +816,7 @@ Status GraphPartitioner::InlineFunctionsAOT(Graph& graph,
ORT_RETURN_IF_ERROR(graph.Resolve());
} while (true);

graph.RemoveLocalFunctions(not_inlined);
model.RemoveLocalFunctionsProtos(not_inlined);

return Status::OK();
}
Expand Down
15 changes: 12 additions & 3 deletions onnxruntime/core/framework/graph_partitioner.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ namespace onnxruntime {

class ExecutionProviders;
class KernelRegistryManager;
class Model;

class GraphPartitioner {
public:
Expand All @@ -34,15 +35,23 @@ class GraphPartitioner {
const layout_transformation::DebugGraphFn& debug_graph_fn = {}) const;

#ifndef ORT_MINIMAL_BUILD
/// <summary>
// Ahead of Time Function inlining. The main purpose of the function is to inline as many
// functions as possible and delete locally defined functions to reduce the size of the model.
// This would enable other optimizations to be more effective.
// This would make other optimizations more effective.
//
// This function performs GetCapability on the graph and its subgraphs bottom up
// and inlines any functions that are not claimed by any of the execution providers.
// This function does not attempt to run layout transformation, and it does not assign EPs.
// The latter will be done by graph partitioning.
Status InlineFunctionsAOT(Graph& graph,
// The latter will be done by graph partitioning after Level1 optimizations are done.
/// </summary>
/// <param name="model">model instance</param>
/// <param name="graph">main graph</param>
/// <param name="execution_providers">execution providers considered</param>
/// <param name="kernel_registry_manager">registry manager</param>
/// <returns></returns>
Status InlineFunctionsAOT(Model& model,
Graph& graph,
const ExecutionProviders& execution_providers,
const KernelRegistryManager& kernel_registry_manager) const;
#endif
Expand Down
20 changes: 7 additions & 13 deletions onnxruntime/core/graph/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1144,26 +1144,22 @@ Graph::Graph(const Model& owning_model,
Version ir_version,
IOnnxRuntimeOpSchemaCollectionPtr schema_registry,
const logging::Logger& logger,
bool strict_shape_type_inference,
std::function<void(const InlinedHashSet<std::string>&)> model_functions_remover)
bool strict_shape_type_inference)
: Graph(owning_model, graph_proto, domain_to_version, ir_version,
schema_registry, nullptr, nullptr, logger,
strict_shape_type_inference, std::move(model_functions_remover)) {}
schema_registry, nullptr, nullptr, logger, strict_shape_type_inference) {}

Graph::Graph(const Model& owning_model,
GraphProto* graph_proto, const std::unordered_map<std::string, int>& domain_to_version, Version ir_version,
IOnnxRuntimeOpSchemaCollectionPtr schema_registry, Graph* parent_graph, const Node* parent_node,
const logging::Logger& logger,
bool strict_shape_type_inference,
std::function<void(const InlinedHashSet<std::string>&)> model_functions_remover)
bool strict_shape_type_inference)
: owning_model_(owning_model),
graph_proto_(graph_proto),
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
runtime_optimizations_ptr_(std::make_unique<RuntimeOptimizationRecordContainer>()),
runtime_optimizations_(*runtime_optimizations_ptr_),
#endif
schema_registry_(schema_registry),
model_functions_remover_(std::move(model_functions_remover)),
graph_resolve_needed_(true),
domain_to_version_(domain_to_version),
ir_version_(ir_version),
Expand Down Expand Up @@ -1328,8 +1324,7 @@ Graph::Graph(Graph& parent_graph, const Node& parent_node, ONNX_NAMESPACE::Graph
&parent_graph,
&parent_node,
parent_graph.logger_,
parent_graph.strict_shape_type_inference_,
std::function<void(const InlinedHashSet<std::string>&)>()) {
parent_graph.strict_shape_type_inference_) {
}

Graph::Graph(const Model& owning_model,
Expand All @@ -1346,8 +1341,7 @@ Graph::Graph(const Model& owning_model,
nullptr,
nullptr,
logger,
strict_shape_type_inference,
std::function<void(const InlinedHashSet<std::string>&)>()) {
strict_shape_type_inference) {
}

void Graph::InitializeStateFromModelFileGraphProto() {
Expand Down Expand Up @@ -4052,7 +4046,7 @@ Status Graph::AddConstantProtoAsInitializer(const ONNX_NAMESPACE::NodeProto& nod
return Status::OK();
}

Status Graph::FunctionToGraph(const ONNX_NAMESPACE::FunctionProto& func_to_inline) {
Status Graph::InlineFunctionProto(const ONNX_NAMESPACE::FunctionProto& func_to_inline) {
auto to_node_arg = [this](const std::string& name) {
return &this->GetOrCreateNodeArg(name, nullptr);
};
Expand Down Expand Up @@ -4119,7 +4113,7 @@ Status Graph::InlineFunction(Node& callnode) {
function_utils::Specialize(inlined_fp, callnode, uniq_identifier);

// In this case, global Resolve() will take care of everything.
ORT_RETURN_IF_ERROR(FunctionToGraph(inlined_fp));
ORT_RETURN_IF_ERROR(InlineFunctionProto(inlined_fp));
} else {
// Uncommon scenario. Inlining a node representing a fused sub-graph.
// TODO: Unclear that this feature is needed. Can this be removed?
Expand Down
16 changes: 4 additions & 12 deletions onnxruntime/core/graph/model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -145,14 +145,10 @@ Model::Model(const std::string& graph_name,
std::move(func_template_ptr));
}

std::function<void(const InlinedHashSet<std::string>&)> model_functions_remover =
[this](const InlinedHashSet<std::string>& retained) {
RemoveLocalFunctionsProtos(retained);
};
// need to call private ctor so can't use make_shared
GSL_SUPPRESS(r.11)
graph_.reset(new Graph(*this, model_proto_.mutable_graph(), *p_domain_to_version, IrVersion(), schema_registry,
logger, options.strict_shape_type_inference, std::move(model_functions_remover)));
logger, options.strict_shape_type_inference));
}

Model::Model(const ModelProto& model_proto, const PathString& model_path,
Expand Down Expand Up @@ -269,14 +265,10 @@ Model::Model(ModelProto&& model_proto, const PathString& model_path,
model_local_function_templates_maps_.insert_or_assign(function_utils::GetFunctionIdentifier(func.domain(), func.name()), std::move(func_template_ptr));
}

std::function<void(const InlinedHashSet<std::string>& retained)> model_functions_remover =
[this](const InlinedHashSet<std::string>& retained) {
RemoveLocalFunctionsProtos(retained);
};
// create instance. need to call private ctor so can't use make_unique
GSL_SUPPRESS(r.11)
graph_.reset(new Graph(*this, model_proto_.mutable_graph(), domain_to_version, IrVersion(), schema_registry,
logger, options.strict_shape_type_inference, std::move(model_functions_remover)));
logger, options.strict_shape_type_inference));
}

const NodeHashMap<std::string, std::unique_ptr<FunctionTemplate>>& Model::GetModelLocalFunctionTemplates() const {
Expand Down Expand Up @@ -367,7 +359,7 @@ const Graph& Model::MainGraph() const noexcept {
}

#if !defined(ORT_MINIMAL_BUILD)
ModelProto Model::ToProto() {
ModelProto Model::ToProto() const {
// We want to return back the original proto
// To that end invoke const overload of ToGraphProto()
// that returns by value and, therefore, allows us to filter
Expand All @@ -381,7 +373,7 @@ ModelProto Model::ToProto() {

ModelProto Model::ToGraphProtoWithExternalInitializers(const std::string& external_file_name,
const PathString& file_path,
size_t initializer_size_threshold) {
size_t initializer_size_threshold) const {
ModelProto result(model_proto_);
const auto& graph = *graph_;
*(result.mutable_graph()) = graph.ToGraphProtoWithExternalInitializers(external_file_name,
Expand Down
19 changes: 10 additions & 9 deletions onnxruntime/core/graph/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,14 +182,14 @@ class Model {

#if !defined(ORT_MINIMAL_BUILD)
// Get model's serialization proto data.
ONNX_NAMESPACE::ModelProto ToProto();
ONNX_NAMESPACE::ModelProto ToProto() const;

// Get model's serialization proto data.
// Save initializer larger than the given threshold (in bytes) into an external binary file
// with the given name. This function is useful to avoid hitting the size limit of protobuf files.
ONNX_NAMESPACE::ModelProto ToGraphProtoWithExternalInitializers(const std::string& external_file_name,
const PathString& file_path,
size_t initializer_size_threshold);
size_t initializer_size_threshold) const;

#ifdef _WIN32
static common::Status Save(Model& model, const std::wstring& file_path);
Expand Down Expand Up @@ -291,6 +291,14 @@ class Model {
common::Status SaveToOrtFormat(flatbuffers::FlatBufferBuilder& builder,
flatbuffers::Offset<onnxruntime::fbs::Model>& model) const;

/// <summary>
/// This function removes local function definitions from the model, excluding
/// those that are contained within `retained`.
/// This is called from GraphPartitioner::InlineFunctionsAOT.
/// </summary>
/// <param name="retained">contains function IDs that should not be removed.</param>
void RemoveLocalFunctionsProtos(const InlinedHashSet<std::string>& retained);

#endif // !defined(ORT_MINIMAL_BUILD)

static common::Status LoadFromOrtFormat(const onnxruntime::fbs::Model& fbs_model,
Expand Down Expand Up @@ -318,13 +326,6 @@ class Model {
// are inlined and removed.
NodeHashMap<std::string, std::unique_ptr<FunctionTemplate>> model_local_function_templates_maps_;

/// <summary>
/// This function removes local function definitions from the model, excluding
/// those that are contained within `retained`.
/// </summary>
/// <param name="retained">contains function IDs that should not be removed.</param>
void RemoveLocalFunctionsProtos(const InlinedHashSet<std::string>& retained);

#else
// properties that would normally come from ModelProto
std::string producer_version_;
Expand Down
8 changes: 7 additions & 1 deletion onnxruntime/core/session/inference_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -996,7 +996,13 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool

// Run Ahead Of time function inlining
GraphPartitioner partitioner(kernel_registry_manager_, execution_providers_);
ORT_RETURN_IF_ERROR_SESSIONID_(partitioner.InlineFunctionsAOT(graph, execution_providers_, kernel_registry_manager_));
if (const bool disable_aot_function_inlining =
session_options_.config_options.GetConfigOrDefault(
kOrtSessionOptionsDisableAheadOfTimeFunctionInlining, "0") == "1";
!disable_aot_function_inlining) {
ORT_RETURN_IF_ERROR_SESSIONID_(partitioner.InlineFunctionsAOT(*model_, graph,
execution_providers_, kernel_registry_manager_));
}

auto apply_transformer_once = [](const GraphTransformer& transformer, const logging::Logger& logger,
Graph& graph) {
Expand Down
Loading

0 comments on commit 2db2e59

Please sign in to comment.