If Branch Constant Folding #18105

Merged
merged 65 commits into from
Nov 14, 2023
Changes from 62 commits
Commits
0de61fa
Add InlineSubgraph
yuslepukhin Sep 16, 2023
9bfaabf
AOT Inlining
yuslepukhin Sep 26, 2023
5637ad4
Add Span() based defs back
yuslepukhin Oct 9, 2023
9da7a1e
Lint
yuslepukhin Oct 9, 2023
c3419ad
Implement If node ConstantFolding.
yuslepukhin Oct 10, 2023
3da463d
Merge branch 'main' into yuslepukhin/aot_inline
yuslepukhin Oct 12, 2023
bf3a3b2
Address some review comments
yuslepukhin Oct 12, 2023
18df323
Revert InlineFunction
yuslepukhin Oct 12, 2023
5a19dcf
Remove local functions
yuslepukhin Oct 13, 2023
7b3efdc
Common graph_viewer
yuslepukhin Oct 13, 2023
2d56e5a
Merge branch 'main' into yuslepukhin/aot_inline
yuslepukhin Oct 13, 2023
fd8be36
Merge branch 'yuslepukhin/aot_inline' into yuslepukhin/if_constant_fo…
yuslepukhin Oct 13, 2023
c5ee067
Remove spurious test, lint.
yuslepukhin Oct 13, 2023
7c1f9a6
Merge branch 'yuslepukhin/aot_inline' into yuslepukhin/if_constant_fo…
yuslepukhin Oct 13, 2023
4845104
Prevent If Constant Fold in the QDQ test
yuslepukhin Oct 13, 2023
8dc17a3
If Constant Folding non-recursive
yuslepukhin Oct 14, 2023
c81c1bf
HF Bart works
yuslepukhin Oct 16, 2023
cce536f
Bug fixes
yuslepukhin Oct 17, 2023
39439bb
Merge branch 'main' into yuslepukhin/aot_inline
yuslepukhin Oct 17, 2023
d91740c
Do not remove non-inlined function definitins.
yuslepukhin Oct 17, 2023
bcb402f
Merge branch 'yuslepukhin/aot_inline' into yuslepukhin/if_constant_fo…
yuslepukhin Oct 17, 2023
04c6fd6
Add topo order. HF Bert works.
yuslepukhin Oct 17, 2023
c3c20b4
Compute function id from domain, OpType of the node.
yuslepukhin Oct 18, 2023
1108cc1
Compute function id from domain, OpType of the node.
yuslepukhin Oct 18, 2023
c945f96
Revert back to underscore
yuslepukhin Oct 18, 2023
6a8897d
Merge branch 'yuslepukhin/aot_inline' into yuslepukhin/if_constant_fo…
yuslepukhin Oct 18, 2023
2db2e59
Add kill switch for AOT
yuslepukhin Oct 20, 2023
c86919b
Merge branch 'yuslepukhin/aot_inline' into yuslepukhin/if_constant_fo…
yuslepukhin Oct 20, 2023
d9106d1
build error
yuslepukhin Oct 20, 2023
4468fc1
Merge branch 'yuslepukhin/aot_inline' into yuslepukhin/if_constant_fo…
yuslepukhin Oct 20, 2023
3d4e7b5
Merge branch 'main' into yuslepukhin/aot_inline
yuslepukhin Oct 20, 2023
115cfd1
Build
yuslepukhin Oct 21, 2023
41cf0af
Address code review
yuslepukhin Oct 23, 2023
e80feca
Merge branch 'main' into yuslepukhin/aot_inline
yuslepukhin Oct 23, 2023
a75826b
Merge branch 'yuslepukhin/aot_inline' into yuslepukhin/if_constant_fo…
yuslepukhin Oct 23, 2023
eaad46e
Lint
yuslepukhin Oct 23, 2023
fea5405
Merge branch 'yuslepukhin/aot_inline' into yuslepukhin/if_constant_fo…
yuslepukhin Oct 23, 2023
9705e7a
Merge branch 'main' into yuslepukhin/if_constant_folding
yuslepukhin Oct 24, 2023
7076473
Test parses
yuslepukhin Oct 25, 2023
88519c0
Test and minor fixes
yuslepukhin Oct 25, 2023
ec1fae7
PR ready
yuslepukhin Oct 26, 2023
f074431
Merge branch 'main' into yuslepukhin/if_constant_folding
yuslepukhin Oct 26, 2023
9110692
Address subgraph node reserrection.
yuslepukhin Oct 27, 2023
8d900d6
Do not remove proto on the main graph
yuslepukhin Oct 28, 2023
7527a4e
Fix move bug
yuslepukhin Oct 28, 2023
18f1373
Merge branch 'main' into yuslepukhin/if_constant_folding
yuslepukhin Oct 30, 2023
7b44ae2
Regenerate node protos in subgraphs to which If nodes are inlined,
yuslepukhin Oct 30, 2023
7ac0aaf
Regenerate nodes after Constant Folding a subgraph
yuslepukhin Oct 31, 2023
135cac8
Address compiler error
yuslepukhin Nov 1, 2023
5fed483
Make Node::ToProto() available in extended mimimal builds.
yuslepukhin Nov 2, 2023
c592853
Merge branch 'main' into yuslepukhin/if_constant_folding
yuslepukhin Nov 2, 2023
7ee567b
Alleviate unused arg
yuslepukhin Nov 2, 2023
cf60bab
Address review comments
yuslepukhin Nov 7, 2023
0bf0ac7
Address more comments
yuslepukhin Nov 7, 2023
5ff8410
Make name_mapping map straight to NodeArg.
yuslepukhin Nov 8, 2023
a0b8899
Remove profiling test
yuslepukhin Nov 8, 2023
f0e84ac
Merge branch 'main' into yuslepukhin/if_constant_folding
yuslepukhin Nov 8, 2023
a17b3e0
Immediate consequetive nested Ifs inlining causes an issue
yuslepukhin Nov 10, 2023
60d99fb
Merge branch 'main' into yuslepukhin/if_constant_folding
yuslepukhin Nov 10, 2023
f0480cb
Adjust map type
yuslepukhin Nov 10, 2023
0ea3e23
Merge branch 'main' into yuslepukhin/if_constant_folding
yuslepukhin Nov 13, 2023
d955fbf
Address review comments
yuslepukhin Nov 13, 2023
94274a9
Lint
yuslepukhin Nov 13, 2023
f38afc3
lint 2
yuslepukhin Nov 13, 2023
7805d47
map_defs usage
yuslepukhin Nov 13, 2023
35 changes: 32 additions & 3 deletions include/onnxruntime/core/graph/graph.h
@@ -397,6 +397,10 @@ class Node {
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
/** Remove the specified attribute from this Node */
bool ClearAttribute(const std::string& attr_name);

/** Gets the Node's mutable attributes. */
NodeAttributes& GetMutableAttributes() noexcept { return attributes_; }

#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

/**
@@ -406,8 +410,6 @@ class Node {
int PruneRemovableAttributes(gsl::span<const std::string> removable_attributes);

#if !defined(ORT_MINIMAL_BUILD)
- /** Gets the Node's mutable attributes. */
- NodeAttributes& GetMutableAttributes() noexcept { return attributes_; }

/** Gets the Graph instance that is instantiated from a GraphProto attribute during Graph::Resolve.
@param attr_name Attribute name for the GraphProto attribute.
@@ -441,6 +443,13 @@ class Node {
return attr_to_subgraph_map_;
}

/** Gets a map of attribute name to the mutable Graph instances for all subgraphs of the Node.
* @returns a mutable map of mutable subgraphs.
*/
std::unordered_map<std::string, gsl::not_null<Graph*>>& GetMutableMapOfAttributeNameToSubgraph() {
return attr_to_subgraph_map_;
}

/** Gets a map of attribute name to the const Graph instances for all subgraphs of the Node.
@returns Map of the attribute name that defines the subgraph to the subgraph's Graph instance.
nullptr if the Node has no subgraphs.
@@ -586,7 +595,7 @@ class Node {
// create a Graph instance for an attribute that contains a GraphProto
void CreateSubgraph(const std::string& attr_name);

- const std::vector<std::unique_ptr<Graph>>& MutableSubgraphs() noexcept { return subgraphs_; }
+ std::vector<std::unique_ptr<Graph>>& MutableSubgraphs() noexcept { return subgraphs_; }

// validate and update the input arg count
common::Status UpdateInputArgCount();
@@ -1134,6 +1143,26 @@ class Graph {
*/
Node& FuseSubGraph(const IndexedSubGraph& sub_graph, const std::string& fused_node_name);

/**
Directly insert one of the If node branches into this Graph.
`If` node condition must be a constant. The function renames
the nodes of the corresponding subgraph to make sure there is no conflict.

Explicit and implicit input references stay the same.

All of the outputs of the subgraph being inlined are renamed
to the outputs of the If node.

The function will process any subgraphs in each of the nodes being inlined,
and will rename any references to the new names introduced.

@param condition_value If condition value
@param if_node - the node that contains the graph_to_inline. This node is going
to be deleted and replaced by the corresponding graph (either then or else)
@param logger
*/
Status InlineIfSubgraph(bool condition_value, Node& if_node, const logging::Logger& logger);

/**
Directly insert the nodes in the function Node provided into this Graph.
The Graph needs to be Resolve()d after this call.
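
The renaming contract described in the InlineIfSubgraph comment above can be illustrated with a small standalone sketch. This is not the onnxruntime implementation: ToyNode, ToyGraph, and InlineBranch are hypothetical names, values are plain strings rather than NodeArg objects, and the branch nodes are assumed to be in topological order already. Unlike the real function, which throws on an input it cannot resolve, the sketch simply prefixes any unknown name.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

// Hypothetical toy types; these are not the onnxruntime Node/Graph classes.
struct ToyNode {
  std::string op_type;
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

struct ToyGraph {
  std::vector<ToyNode> nodes;        // assumed to be topologically sorted
  std::vector<std::string> outputs;  // same order as the If node outputs
};

// Returns copies of the branch nodes with value names rewritten so that
// they can be placed directly into the graph that contained the If node.
std::vector<ToyNode> InlineBranch(const ToyGraph& branch,
                                  const std::vector<std::string>& if_outputs,
                                  const std::unordered_set<std::string>& outer_scope,
                                  const std::string& prefix) {
  // Branch outputs and If outputs are matched by position.
  assert(branch.outputs.size() == if_outputs.size());

  // Branch output names map to the corresponding If output names.
  std::unordered_map<std::string, std::string> renames;
  for (std::size_t i = 0; i < branch.outputs.size(); ++i) {
    renames.emplace(branch.outputs[i], if_outputs[i]);
  }

  auto map_name = [&](const std::string& name) -> std::string {
    // Outer-scope (implicit input) names are kept as they are.
    if (outer_scope.count(name) != 0) return name;
    auto hit = renames.find(name);
    if (hit != renames.end()) return hit->second;
    // Any other name is local to the branch and gets a unique prefix.
    return renames.emplace(name, prefix + "_" + name).first->second;
  };

  std::vector<ToyNode> result;
  for (const ToyNode& node : branch.nodes) {
    ToyNode copy{node.op_type, {}, {}};
    for (const std::string& in : node.inputs) copy.inputs.push_back(map_name(in));
    for (const std::string& out : node.outputs) copy.outputs.push_back(map_name(out));
    result.push_back(std::move(copy));
  }
  return result;
}
```

Calling InlineBranch with a branch that computes `t = Add(x, c)` and `y = Relu(t)`, where `x` comes from the outer scope and `y` is the branch output, leaves `x` untouched, prefixes `c` and `t`, and renames `y` to the If node's output name.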
2 changes: 1 addition & 1 deletion onnxruntime/core/graph/function_utils.cc
@@ -432,7 +432,7 @@ class Inliner {
// Process a node:
void transform(NodeProto& n) {
if (!n.name().empty())
- n.set_name(prefix_ + n.name());
+ n.set_name(prefix_ + "_" + n.name());

for (auto& x : *n.mutable_input()) {
rename(x, false);
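
The PR does not state why the separator was introduced. As an illustration only, the sketch below compares the two naming rules from this hunk using hypothetical helpers (PrefixedName and ConcatenatedName are not part of onnxruntime). It shows one case where plain concatenation maps two different (prefix, name) pairs to the same string while the underscore form keeps them apart. A separator reduces such accidental collisions but does not rule them out in general.

```cpp
#include <cassert>
#include <string>

// Hypothetical helper mirroring the new rule in the hunk above:
// empty names stay empty, everything else becomes "<prefix>_<name>".
std::string PrefixedName(const std::string& prefix, const std::string& name) {
  assert(!prefix.empty());  // assumption of this sketch, not of the PR
  if (name.empty()) return name;
  return prefix + "_" + name;
}

// The previous rule (plain concatenation), kept only for comparison.
std::string ConcatenatedName(const std::string& prefix, const std::string& name) {
  assert(!prefix.empty());  // assumption of this sketch, not of the PR
  if (name.empty()) return name;
  return prefix + name;
}
```

For example, the pairs ("f1", "0_x") and ("f10", "_x") both concatenate to "f10_x", whereas the underscore form yields "f1_0_x" and "f10__x".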
304 changes: 304 additions & 0 deletions onnxruntime/core/graph/graph.cc
@@ -984,6 +984,7 @@ bool Node::ClearAttribute(const std::string& attr_name) {
graph_->SetGraphProtoSyncNeeded();
return attributes_.erase(attr_name) > 0;
}

#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

int Node::PruneRemovableAttributes(gsl::span<const std::string> removable_attributes) {
@@ -4047,6 +4048,309 @@ Status Graph::AddConstantProtoAsInitializer(const ONNX_NAMESPACE::NodeProto& nod
return Status::OK();
}

static void ReassignSubgraphDependentNodeArgs(const InlinedHashMap<std::string, NodeArg*>& name_to_nodearg,
Graph& graph) {
for (auto& node : graph.Nodes()) {
if (node.ContainsSubgraph()) {
for (auto& [name, subgraph] : node.GetAttributeNameToMutableSubgraphMap()) {
ReassignSubgraphDependentNodeArgs(name_to_nodearg, *subgraph);
}
}

// NodeArgs need to be updated
for (auto& input_def : node.MutableInputDefs()) {
if (input_def->Exists()) {
auto hit = name_to_nodearg.find(input_def->Name());
if (hit != name_to_nodearg.cend()) {
input_def = hit->second;
}
}
}
}
}
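
The recursion in ReassignSubgraphDependentNodeArgs can be pictured with a standalone sketch. It is a simplified model under stated assumptions: ToyNode, ToyGraph, and ReassignInputs are hypothetical names, inputs are strings instead of NodeArg pointers, and an empty string stands in for an input whose Exists() check fails.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

struct ToyGraph;

// Hypothetical toy node: input names plus any nested subgraphs it owns.
struct ToyNode {
  std::vector<std::string> inputs;
  std::vector<std::unique_ptr<ToyGraph>> subgraphs;
};

struct ToyGraph {
  std::vector<ToyNode> nodes;
};

// Replaces every input that appears in `renames`, at every nesting level.
void ReassignInputs(const std::unordered_map<std::string, std::string>& renames,
                    ToyGraph& graph) {
  for (ToyNode& node : graph.nodes) {
    // Descend first so nested graphs observe the same renaming.
    for (auto& subgraph : node.subgraphs) {
      assert(subgraph != nullptr);
      ReassignInputs(renames, *subgraph);
    }
    for (std::string& input : node.inputs) {
      if (input.empty()) continue;  // stands in for a missing optional input
      auto hit = renames.find(input);
      if (hit != renames.end()) input = hit->second;
    }
  }
}
```

Inputs that are not in the map, including missing optional inputs, are left untouched, which matches the behaviour of the loop in the diff.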

Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const logging::Logger& logger) {
static const std::string then_branch{"then_branch"};
static const std::string else_branch{"else_branch"};
Graph* sub_graph;
if (condition_value) {
sub_graph = if_node.GetMutableGraphAttribute(then_branch);
} else {
sub_graph = if_node.GetMutableGraphAttribute(else_branch);
}

if (sub_graph == nullptr) {
auto str = MakeString("Unable to constant fold If node: '", if_node.Name(), "' Unable to fetch: ",
(condition_value ? then_branch : else_branch));
LOGS(logger, WARNING) << str;
return Status::OK();
}

Graph& graph_to_inline = *sub_graph;

std::string unique_id{if_node.Name()};
if (condition_value) {
unique_id.append(then_branch);
} else {
unique_id.append(else_branch);
}

unique_id = GenerateNodeName(unique_id);

auto make_unique = [&unique_id](const std::string& name) {
return unique_id + '_' + name;
};

// Check if the name is an input or implicit input.
// These are not renamed, and we do not need to adjust subgraphs for them.
// Implicit inputs would cover both If node input and implicit inputs.
// Reason: there are no explicit inputs to the subgraphs, and the subgraph's
// implicit inputs must be covered by the implicit inputs of the If node.
InlinedHashMap<std::string_view, NodeArg*> outer_scope_values;
const auto if_implicit_inputs = if_node.MutableImplicitInputDefs();
outer_scope_values.reserve(if_implicit_inputs.size());

for (auto* input : if_implicit_inputs) {
const auto& name = input->Name();
ORT_IGNORE_RETURN_VALUE(outer_scope_values.emplace(name, input));
}

// Name mapping from the graph to inline to the graph we are inlining into
// we also use this to process any subgraphs in the graph we are inlining
InlinedHashMap<std::string, NodeArg*> name_to_nodearg;

// We are going to map the outputs of the graph to inline to the outputs of the If node.
// They are assumed to be in the same order.
const auto node_output_defs = if_node.MutableOutputDefs();
const auto graph_output_defs = graph_to_inline.GetOutputs();
for (size_t i = 0; i < graph_output_defs.size(); ++i) {
name_to_nodearg.emplace(graph_output_defs[i]->Name(), node_output_defs[i]);
}

// Move initializers from the subgraph to the destination graph.
for (int i = 0, limit = graph_to_inline.graph_proto_->initializer_size(); i < limit; ++i) {
auto* initializer = graph_to_inline.graph_proto_->mutable_initializer(i);
const std::string src_name = initializer->name();

#if !defined(DISABLE_SPARSE_TENSORS)
bool has_sparse_origin = false;
if (!graph_to_inline.sparse_tensor_names_.empty()) {
auto hit = graph_to_inline.sparse_tensor_names_.find(src_name);
if (hit != graph_to_inline.sparse_tensor_names_.cend()) {
has_sparse_origin = true;
// Erase the entry that will be invalidated
graph_to_inline.sparse_tensor_names_.erase(hit);
}
}
#endif

graph_to_inline.name_to_initial_tensor_.erase(src_name);
const gsl::not_null<TensorProto*> tensor{graph_proto_->add_initializer()};
*tensor = std::move(*initializer);

// Check if this is an output of the graph
auto hit = name_to_nodearg.find(src_name);
if (hit != name_to_nodearg.cend()) {
// We rename it to If node output.
tensor->set_name(hit->second->Name());
} else {
NodeArg* node_arg = graph_to_inline.GetNodeArg(src_name);
assert(node_arg != nullptr);
auto new_name = GenerateNodeArgName(make_unique(src_name));
NodeArg& new_arg = GetOrCreateNodeArg(new_name, node_arg->TypeAsProto());
ORT_IGNORE_RETURN_VALUE(name_to_nodearg.emplace(src_name, &new_arg));
tensor->set_name(std::move(new_name));
}

auto insert_result = name_to_initial_tensor_.emplace(tensor->name(), tensor);
ORT_ENFORCE(insert_result.second, "Initializer name: ", tensor->name(), " from graph: ",
graph_to_inline.Name(), " conflicts with graph initializer. Check name generation above.");

#if !defined(DISABLE_SPARSE_TENSORS)
if (has_sparse_origin) {
ORT_IGNORE_RETURN_VALUE(sparse_tensor_names_.emplace(tensor->name()));
}
#endif
}

// Look up nodes that would be providing input to our nodes (implicit and explicit)
// and any nodes that take the output of our nodes (used to be If output)
// Map of NodeArg name to pair of Node* and input index in the destination node
using NodeAndIndex = std::pair<gsl::not_null<Node*>, int>;
using ArgNameToNodeMap = InlinedHashMap<std::string_view, NodeAndIndex>;
ArgNameToNodeMap input_args;
// Map of NodeArg name to pair of Node* and output index in the source node.
ArgNameToNodeMap output_args;

auto map_defs = [](Node& node, ArgNameToNodeMap& map, bool input) {
const auto defs = (input) ? node.InputDefs() : node.OutputDefs();
map.reserve(defs.size());
int arg_pos = -1;
for (auto* node_arg : defs) {
++arg_pos;
if (node_arg->Exists()) {
map.emplace(node_arg->Name(), std::make_pair(&node, arg_pos));
}
}
};

const bool is_this_main_graph = (parent_graph_ == nullptr);
// Map the inputs and outputs of the If node to the nodes in the graph to inline.
if (!is_this_main_graph) {
for (auto& node : Nodes()) {
if (node.Index() == if_node.Index()) {
continue;
}
map_defs(node, input_args, true);
map_defs(node, output_args, false);
}
}

// We want to make sure we get nodes in topological order
// because constant folding may cause the nodes to appear in
// a different order.
InlinedVector<Node*> new_nodes;
GraphViewer graph(graph_to_inline);
for (const auto node_idx : graph.GetNodesInTopologicalOrder()) {
// GraphViewer filters out nullptrs
auto* node = graph_to_inline.GetNode(node_idx);
assert(node->OpType() != kConstant);

InlinedVector<NodeArg*> new_node_input_defs;
for (const auto* input_def : node->InputDefs()) {
if (input_def->Exists()) {
// Check if this is one of the implicit graph inputs
// then leave the name as is and re-use the NodeArg
const auto& input_name = input_def->Name();
auto outer_hit = outer_scope_values.find(input_name);
if (outer_hit != outer_scope_values.cend()) {
new_node_input_defs.push_back(outer_hit->second);
} else {
auto hit = name_to_nodearg.find(input_name);
if (hit != name_to_nodearg.cend()) {
// This is another node's output, a constant node, or an initializer that was renamed.
new_node_input_defs.push_back(hit->second);
} else {
ORT_THROW("Node's: ", node->Name(), " input: ", input_name,
" is not If node's input or previous node output in this subgraph");
}
}
}
}

InlinedVector<NodeArg*> new_node_output_defs;
for (const auto* output_def : node->OutputDefs()) {
const auto& output_name = output_def->Name();
auto hit = name_to_nodearg.find(output_name);
if (hit != name_to_nodearg.cend()) {
// This is one of the graph outputs, we rename it to
// If node output.
new_node_output_defs.push_back(hit->second);
} else {
// We generate an output to downstream nodes.
auto new_name = GenerateNodeArgName(make_unique(output_name));
NodeArg& new_arg = GetOrCreateNodeArg(new_name, output_def->TypeAsProto());
new_node_output_defs.push_back(&new_arg);
ORT_IGNORE_RETURN_VALUE(name_to_nodearg.emplace(output_name, &new_arg));
}
}

const auto new_node_name = GenerateNodeName(make_unique(node->OpType()));
Node& new_node = AddNode(new_node_name, node->OpType(), node->Description(),
new_node_input_defs,
new_node_output_defs,
nullptr,
node->Domain());

if (!is_this_main_graph) {
int arg_pos = -1;
for (auto* input_def : new_node_input_defs) {
++arg_pos;
input_args.insert_or_assign(input_def->Name(), std::make_pair(&new_node, arg_pos));
}
arg_pos = -1;
for (auto* output_def : new_node_output_defs) {
++arg_pos;
output_args.insert_or_assign(output_def->Name(), std::make_pair(&new_node, arg_pos));
}
new_nodes.push_back(&new_node);
}

new_node.SetSinceVersion(node->SinceVersion());
new_node.op_ = node->op_;

if (node->ContainsSubgraph()) {
auto& subgraphs = node->MutableSubgraphs();

// Check whether any of this node's implicit inputs are in the renaming map
int renames_subgraph_names = 0;
auto& new_implicit_defs = node->MutableImplicitInputDefs();
for (auto& input_def : new_implicit_defs) {
auto hit = name_to_nodearg.find(input_def->Name());
if (hit != name_to_nodearg.cend()) {
input_def = hit->second;
++renames_subgraph_names;
}
}

for (auto& subgraph : subgraphs) {
if (renames_subgraph_names > 0) {
// We need to update the NodeArg references inside the subgraph
// because they may refer to the implicit inputs
// that were renamed.
ReassignSubgraphDependentNodeArgs(name_to_nodearg, *subgraph);
}
subgraph->parent_node_ = &new_node;
subgraph->parent_graph_ = this;
}

new_node.MutableSubgraphs() = std::move(subgraphs);
new_node.GetMutableMapOfAttributeNameToSubgraph() = std::move(node->GetMutableMapOfAttributeNameToSubgraph());
new_node.MutableImplicitInputDefs() = std::move(new_implicit_defs);
}

new_node.GetMutableAttributes() = std::move(node->GetMutableAttributes());
}

// Let's rebuild local connections, so next time a GraphViewer is able to perform topological sort.
// We only need to do so if this graph is not the main graph, because the main graph is going to resolve
// and it is not possible to inline the same nodes again.
if (!is_this_main_graph) {
for (auto* node : new_nodes) {
int arg_pos = -1;
for (auto* input_def : node->InputDefs()) {
++arg_pos;
auto hit = output_args.find(input_def->Name());
if (hit != output_args.cend()) {
// The input to this node is an output from a previous node in this graph.
// Create relationship between this node (node), and the node providing the output (producer).
const auto& [producer, src_idx] = hit->second;
AddEdge(producer->Index(), node->Index(), src_idx, arg_pos);
}
}

// Check if any of the outputs for inlined nodes are inputs to other nodes in the graph.
// (outputs of If node)
arg_pos = -1;
for (auto& output_def : node->OutputDefs()) {
++arg_pos;
auto hit = input_args.find(output_def->Name());
if (hit != input_args.cend()) {
// The output of this node is an input to another node in this graph.
// Create relationship between this node (node), and the node using the input (consumer).
const auto& [consumer, dst_idx] = hit->second;
AddEdge(node->Index(), consumer->Index(), arg_pos, dst_idx);
}
}
}
}

LOGS(logger, INFO) << "Constant folded (inlined) " << (condition_value ? then_branch : else_branch)
<< " for If node: " << if_node.Name();

return Status::OK();
}
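
The edge rebuilding at the end of InlineIfSubgraph matches values by name: a map from value name to its producing node and output slot is consulted for every input. The sketch below shows that idea in isolation. It is a simplification, not the code from this PR: ToyNode, ToyEdge, and ConnectByName are hypothetical names, only the producer map is modelled (the real function also keeps a consumer map), and an empty name stands in for a missing optional value.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical toy node identified by an index, as graph nodes are.
struct ToyNode {
  std::size_t index;
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

// One edge: output slot of the source node feeds an input slot of the destination.
struct ToyEdge {
  std::size_t src_node;
  std::size_t dst_node;
  int src_slot;
  int dst_slot;
};

std::vector<ToyEdge> ConnectByName(const std::vector<ToyNode>& nodes) {
  // Value name -> (producing node index, output slot).
  std::unordered_map<std::string, std::pair<std::size_t, int>> producers;
  for (const ToyNode& node : nodes) {
    int slot = -1;
    for (const std::string& name : node.outputs) {
      ++slot;
      if (!name.empty()) producers.emplace(name, std::make_pair(node.index, slot));
    }
  }

  std::vector<ToyEdge> edges;
  for (const ToyNode& node : nodes) {
    int slot = -1;
    for (const std::string& name : node.inputs) {
      ++slot;
      auto hit = producers.find(name);
      if (hit == producers.end()) continue;    // value comes from outside this graph
      assert(hit->second.first != node.index);  // a node must not feed itself
      edges.push_back(ToyEdge{hit->second.first, node.index, hit->second.second, slot});
    }
  }
  return edges;
}
```

Inputs without a producer in the same graph (initializers, graph inputs, outer-scope values) produce no edge, just as the lookup in the diff skips names that are absent from output_args.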

Status Graph::InlineFunctionProto(const ONNX_NAMESPACE::FunctionProto& func_to_inline) {
auto to_node_arg = [this](const std::string& name) {
return &this->GetOrCreateNodeArg(name, nullptr);