Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Functions Ahead Of Time inlining #17764

Merged
merged 22 commits into from
Oct 24, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions include/onnxruntime/core/graph/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -1385,7 +1385,7 @@ class Graph {
IOnnxRuntimeOpSchemaCollectionPtr schema_registry,
const logging::Logger& logger,
bool strict_shape_type_inference,
std::function<void()> model_functions_remover);
std::function<void(const InlinedHashSet<std::string>&)> model_functions_remover);

// internal use by the Graph class only
Graph(const Model& owning_model,
Expand All @@ -1397,13 +1397,13 @@ class Graph {
const Node* parent_node,
const logging::Logger& logger,
bool strict_shape_type_inference,
std::function<void()> model_functions_remover);
std::function<void(const InlinedHashSet<std::string>&)> model_functions_remover);
yuslepukhin marked this conversation as resolved.
Show resolved Hide resolved

ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Graph);

void RemoveLocalFunctions() {
void RemoveLocalFunctions(const InlinedHashSet<std::string>& retained) {
pranavsharma marked this conversation as resolved.
Show resolved Hide resolved
if (model_functions_remover_) {
model_functions_remover_();
model_functions_remover_(retained);
}
}

Expand Down Expand Up @@ -1626,7 +1626,7 @@ class Graph {
#if !defined(ORT_MINIMAL_BUILD)
// A function to call into the model to remove local functions which
// are inlined.
std::function<void()> model_functions_remover_;
std::function<void(const InlinedHashSet<std::string>&)> model_functions_remover_;
#endif

// Graph nodes.
Expand Down
17 changes: 13 additions & 4 deletions onnxruntime/core/framework/graph_partitioner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "core/framework/kernel_registry_manager.h"
#include "core/framework/kernel_registry.h"
#include "core/graph/function.h"
#include "core/graph/function_utils.h"
#include "core/graph/graph_viewer.h"

// uncomment this line to count non-CUDA ops in ONNX domain
Expand Down Expand Up @@ -253,7 +254,7 @@
kernel_registries_for_ep,
kernel_registry_mgr.GetKernelTypeStrResolver()};

// TODO: Provide EP with a capability to look inside the functions.

Check warning on line 257 in onnxruntime/core/framework/graph_partitioner.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/framework/graph_partitioner.cc#L257

Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
Raw output
onnxruntime/core/framework/graph_partitioner.cc:257:  Missing username in TODO; it should look like "// TODO(my_username): Stuff."  [readability/todo] [2]
capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup);
yuslepukhin marked this conversation as resolved.
Show resolved Hide resolved

return Status::OK();
Expand Down Expand Up @@ -556,6 +557,7 @@
static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_providers,
const KernelRegistryManager& kernel_registry_mgr,
Graph& graph,
InlinedHashSet<std::string>& not_inlined,
size_t& inlined_count) {
// handle testing edge case where optimizers or constant lifting results in graph with no nodes.
// doing it here saves all providers checking for this in GetCapability
Expand All @@ -570,6 +572,7 @@
ORT_RETURN_IF_ERROR(InlineFunctionsAOTImpl(execution_providers,
kernel_registry_mgr,
*subgraph,
not_inlined,
inlined_count));
}
}
Expand Down Expand Up @@ -609,16 +612,20 @@
}
}

// TODO: Insert version check. We need to collect all the versions

Check warning on line 615 in onnxruntime/core/framework/graph_partitioner.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/framework/graph_partitioner.cc#L615

Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
Raw output
onnxruntime/core/framework/graph_partitioner.cc:615:  Missing username in TODO; it should look like "// TODO(my_username): Stuff."  [readability/todo] [2]
// that are imported by the model. If the version is not supported by
// the model, we can not inline it.

for (auto node_index : inline_candidates) {
if (claimed_by_ep.count(node_index) == 0) {
auto* node = graph.GetNode(node_index);
if (node != nullptr) {
auto* node = graph.GetNode(node_index);
if (node != nullptr) {
if (claimed_by_ep.count(node_index) == 0) {
ORT_RETURN_IF_ERROR(graph.InlineFunction(*node));
++inlined_count;
} else {
// OpType is the same as function name.
auto function_id = function_utils::GetFunctionIdentifier(node->Domain(), node->OpType());
ORT_IGNORE_RETURN_VALUE(not_inlined.insert(std::move(function_id)));
}
}
}
Expand Down Expand Up @@ -791,11 +798,13 @@
Status GraphPartitioner::InlineFunctionsAOT(Graph& graph,
const ExecutionProviders& execution_providers,
const KernelRegistryManager& kernel_registry_manager) const {
InlinedHashSet<std::string> not_inlined;

Check warning on line 801 in onnxruntime/core/framework/graph_partitioner.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/framework/graph_partitioner.cc#L801

Add #include <string> for string [build/include_what_you_use] [4]
Raw output
onnxruntime/core/framework/graph_partitioner.cc:801:  Add #include <string> for string  [build/include_what_you_use] [4]
do {
size_t inlined_count = 0;
ORT_RETURN_IF_ERROR(InlineFunctionsAOTImpl(execution_providers,
kernel_registry_manager,
graph,
not_inlined,
inlined_count));

if (inlined_count == 0) {
Expand All @@ -805,7 +814,7 @@
ORT_RETURN_IF_ERROR(graph.Resolve());
} while (true);

graph.RemoveLocalFunctions();
graph.RemoveLocalFunctions(not_inlined);

return Status::OK();
}
Expand Down
9 changes: 5 additions & 4 deletions onnxruntime/core/graph/function_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,8 @@ class Inliner {
// Replace given name with a unique version of the name, and cache the
// renaming-binding in current scope.
void make_unique(std::string& name) {
auto new_name = prefix_ + name;
auto new_name{prefix_};
new_name.append("/").append(name);
yuslepukhin marked this conversation as resolved.
Show resolved Hide resolved
auto& current_scope = rename_scopes_.back();
current_scope[name] = new_name;
name = std::move(new_name);
Expand Down Expand Up @@ -410,7 +411,7 @@ class Inliner {
std::string rename_as = actuals.Get(i);
if constexpr (isOutput) {
if (rename_as.empty())
rename_as.assign(prefix_).append(formal);
rename_as.assign(prefix_).append("/").append(formal);
}
current_scope[formal] = rename_as;
if (!rename_as.empty())
Expand All @@ -420,7 +421,7 @@ class Inliner {
std::string& formal = *formals.Mutable(i);
std::string rename_as;
if constexpr (isOutput) {
rename_as.assign(prefix_).append(formal);
rename_as.assign(prefix_).append("/").append(formal);
}
current_scope[formal] = rename_as;
if (!rename_as.empty())
Expand All @@ -431,7 +432,7 @@ class Inliner {
// Process a node:
void transform(NodeProto& n) {
if (!n.name().empty())
n.set_name(prefix_ + n.name());
n.set_name(prefix_ + "/" + n.name());

for (auto& x : *n.mutable_input()) {
rename(x, false);
Expand Down
10 changes: 5 additions & 5 deletions onnxruntime/core/graph/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1145,7 +1145,7 @@ Graph::Graph(const Model& owning_model,
IOnnxRuntimeOpSchemaCollectionPtr schema_registry,
const logging::Logger& logger,
bool strict_shape_type_inference,
std::function<void()> model_functions_remover)
std::function<void(const InlinedHashSet<std::string>&)> model_functions_remover)
: Graph(owning_model, graph_proto, domain_to_version, ir_version,
schema_registry, nullptr, nullptr, logger,
strict_shape_type_inference, std::move(model_functions_remover)) {}
Expand All @@ -1155,7 +1155,7 @@ Graph::Graph(const Model& owning_model,
IOnnxRuntimeOpSchemaCollectionPtr schema_registry, Graph* parent_graph, const Node* parent_node,
const logging::Logger& logger,
bool strict_shape_type_inference,
std::function<void()> model_functions_remover)
std::function<void(const InlinedHashSet<std::string>&)> model_functions_remover)
: owning_model_(owning_model),
graph_proto_(graph_proto),
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
Expand Down Expand Up @@ -1329,7 +1329,7 @@ Graph::Graph(Graph& parent_graph, const Node& parent_node, ONNX_NAMESPACE::Graph
&parent_node,
parent_graph.logger_,
parent_graph.strict_shape_type_inference_,
std::function<void()>()) {
std::function<void(const InlinedHashSet<std::string>&)>()) {
}

Graph::Graph(const Model& owning_model,
Expand All @@ -1347,7 +1347,7 @@ Graph::Graph(const Model& owning_model,
nullptr,
logger,
strict_shape_type_inference,
std::function<void()>()) {
std::function<void(const InlinedHashSet<std::string>&)>()) {
}

void Graph::InitializeStateFromModelFileGraphProto() {
Expand Down Expand Up @@ -4104,7 +4104,7 @@ Status Graph::InlineFunction(Node& callnode) {

// create a uniq_identifier to append to every node name and intermediate inputs/outputs
// to make sure there are no unintended duplicates
std::string base_uniq_identifier{"_inline_"};
std::string base_uniq_identifier{"_inlfunc/"};
base_uniq_identifier.append(callnode.OpType());
const auto uniq_identifier = GenerateNodeName(base_uniq_identifier);

Expand Down
62 changes: 43 additions & 19 deletions onnxruntime/core/graph/model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,36 @@

#if !defined(ORT_MINIMAL_BUILD)

void Model::RemoveLocalFunctionsProtos(const InlinedHashSet<std::string>& retained) {
yuslepukhin marked this conversation as resolved.
Show resolved Hide resolved
auto* local_functions = model_proto_.mutable_functions();
if (retained.empty()) {
model_local_function_templates_maps_.clear();
model_local_functions_.clear();
local_functions->erase(local_functions->begin(), local_functions->end());
} else {
const auto retained_end = retained.cend();
for (auto it = model_local_functions_.begin();
it != model_local_functions_.end();) {
if (retained.find(it->first) == retained_end) {
model_local_function_templates_maps_.erase(it->first);
// Post increment for compatibility between abseil and STL
model_local_functions_.erase(it++);
yuslepukhin marked this conversation as resolved.
Show resolved Hide resolved
} else {
++it;
}
}

for (auto it = local_functions->begin(); it != local_functions->end();) {
const auto function_id = function_utils::GetFunctionIdentifier(it->domain(), it->name());
if (retained.find(function_id) == retained_end) {
it = local_functions->erase(it);
} else {
++it;
}
}
}
}

static constexpr int DEFAULT_PROTOBUF_BLOCK_SIZE = 4 * 1024 * 1024;

Model::Model(const std::string& graph_name,
Expand Down Expand Up @@ -95,7 +125,8 @@
for (auto& func : model_local_functions) {
auto func_ptr = model_proto_.add_functions();
func_ptr->CopyFrom(func);
model_local_functions_.insert_or_assign(function_utils::GetFunctionIdentifier(func_ptr->domain(), func_ptr->name()), func_ptr);
model_local_functions_.insert_or_assign(function_utils::GetFunctionIdentifier(func_ptr->domain(), func_ptr->name()),
func_ptr);
}

model_local_function_templates_maps_.reserve(model_proto_.functions().size());
Expand All @@ -110,12 +141,14 @@
auto func_template_ptr = std::make_unique<FunctionTemplate>();
func_template_ptr->op_schema_ = std::move(func_schema_ptr);
func_template_ptr->onnx_func_proto_ = &func;
model_local_function_templates_maps_.insert_or_assign(function_utils::GetFunctionIdentifier(func.domain(), func.name()), std::move(func_template_ptr));
model_local_function_templates_maps_.insert_or_assign(function_utils::GetFunctionIdentifier(func.domain(), func.name()),

Check warning on line 144 in onnxruntime/core/graph/model.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/graph/model.cc#L144

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/core/graph/model.cc:144:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
std::move(func_template_ptr));
}

std::function<void()> model_functions_remover = [this]() {
RemoveLocalFunctionsProtos();
};
std::function<void(const InlinedHashSet<std::string>&)> model_functions_remover =
[this](const InlinedHashSet<std::string>& retained) {
RemoveLocalFunctionsProtos(retained);
};
// need to call private ctor so can't use make_shared
GSL_SUPPRESS(r.11)
graph_.reset(new Graph(*this, model_proto_.mutable_graph(), *p_domain_to_version, IrVersion(), schema_registry,
Expand Down Expand Up @@ -233,19 +266,20 @@
auto func_template_ptr = std::make_unique<FunctionTemplate>();
func_template_ptr->op_schema_ = std::move(func_schema_ptr);
func_template_ptr->onnx_func_proto_ = &func;
model_local_function_templates_maps_.insert_or_assign(function_utils::GetFunctionIdentifier(func.domain(), func.name()), std::move(func_template_ptr));

Check warning on line 269 in onnxruntime/core/graph/model.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/graph/model.cc#L269

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/core/graph/model.cc:269:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
}

std::function<void()> model_functions_remover = [this]() {
RemoveLocalFunctionsProtos();
};
std::function<void(const InlinedHashSet<std::string>& retained)> model_functions_remover =
[this](const InlinedHashSet<std::string>& retained) {
RemoveLocalFunctionsProtos(retained);
};
// create instance. need to call private ctor so can't use make_unique
GSL_SUPPRESS(r.11)
graph_.reset(new Graph(*this, model_proto_.mutable_graph(), domain_to_version, IrVersion(), schema_registry,
logger, options.strict_shape_type_inference, std::move(model_functions_remover)));
gramalingam marked this conversation as resolved.
Show resolved Hide resolved
}

const InlinedHashMap<std::string, std::unique_ptr<FunctionTemplate>>& Model::GetModelLocalFunctionTemplates() const {
// Returns a read-only reference to the model-local function templates map
// (model_local_function_templates_maps_), keyed by function identifier string.
const NodeHashMap<std::string, std::unique_ptr<FunctionTemplate>>& Model::GetModelLocalFunctionTemplates() const {
return model_local_function_templates_maps_;
}

Expand Down Expand Up @@ -803,16 +837,6 @@
return Status::OK();
}

// Removes every model-local function from the model: clears the function
// template map and the function-id lookup map, then empties the `functions`
// field of the underlying ModelProto.
void Model::RemoveLocalFunctionsProtos() {
  // Clear the lookup maps before the protos they refer to are destroyed.
  model_local_function_templates_maps_.clear();
  model_local_functions_.clear();

  // Erase the whole range in a single call instead of element by element.
  auto* local_functions = model_proto_.mutable_functions();
  local_functions->erase(local_functions->begin(), local_functions->end());
}

#endif // !defined(ORT_MINIMAL_BUILD)

Model::Model() : model_path_{} {
Expand Down
15 changes: 11 additions & 4 deletions onnxruntime/core/graph/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ class Model {
// Returns empty string if not specified.
const std::string GraphDocString() const;

const InlinedHashMap<std::string, std::unique_ptr<FunctionTemplate>>& GetModelLocalFunctionTemplates() const;
const NodeHashMap<std::string, std::unique_ptr<FunctionTemplate>>& GetModelLocalFunctionTemplates() const;

#else
// Get model's IR version.
Expand Down Expand Up @@ -314,9 +314,16 @@ class Model {
std::unordered_map<std::string, const ONNX_NAMESPACE::FunctionProto*> model_local_functions_;
// this is the map from function id to the local function template.
// this map will be used by graph to instantiate the function body.
InlinedHashMap<std::string, std::unique_ptr<FunctionTemplate>> model_local_function_templates_maps_;

void RemoveLocalFunctionsProtos();
// Defined as a node based map so the memory is released when not all of the functions
// are inlined and removed.
NodeHashMap<std::string, std::unique_ptr<FunctionTemplate>> model_local_function_templates_maps_;

/// <summary>
/// This function removes local function definitions from the model, excluding
/// those whose IDs are contained in the retained set.
/// </summary>
/// <param name="retained">contains function IDs that should not be removed.</param>
void RemoveLocalFunctionsProtos(const InlinedHashSet<std::string>& retained);

#else
// properties that would normally come from ModelProto
Expand Down
Loading