Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create edges with arg positions correctly accounting for non-existing args #18462

Merged
merged 6 commits into from
Nov 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion cmake/external/abseil-cpp.natvis
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
<Intrinsic Name="_capacity" Expression="_commonfields().capacity_"/>
<Intrinsic Name="_control" Expression="_commonfields().control_"/>
<Intrinsic Name="_slots" Expression="(slot_type*)(_commonfields().slots_)"/>
<DisplayString Condition="_size() == 0">empty</DisplayString>
<DisplayString IncludeView="noparens">size={ _size() }</DisplayString>
<DisplayString ExcludeView="noparens">size=({_size()})</DisplayString>
<Expand>
Expand Down
93 changes: 61 additions & 32 deletions onnxruntime/core/graph/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4062,7 +4062,9 @@ static void ReassignSubgraphDependentNodeArgs(const InlinedHashMap<std::string,
if (input_def->Exists()) {
auto hit = name_to_nodearg.find(input_def->Name());
if (hit != name_to_nodearg.cend()) {
input_def = hit->second;
// Make sure we create a local to this subgraph definition
const auto* new_name_arg = hit->second;
input_def = &graph.GetOrCreateNodeArg(new_name_arg->Name(), input_def->TypeAsProto());
}
}
}
Expand All @@ -4088,7 +4090,7 @@ Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const loggin

Graph& graph_to_inline = *sub_graph;

std::string unique_id{if_node.Name()};
std::string unique_id{"_if_"};
if (condition_value) {
unique_id.append(then_branch);
} else {
Expand All @@ -4107,7 +4109,7 @@ Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const loggin
// Reason: there are no explicit inputs to the subgraphs, and the subgraph's
// implicit inputs must be covered by the implicit inputs of the If node.
InlinedHashMap<std::string_view, NodeArg*> outer_scope_values;
const auto if_implicit_inputs = if_node.MutableImplicitInputDefs();
const auto& if_implicit_inputs = if_node.MutableImplicitInputDefs();
outer_scope_values.reserve(if_implicit_inputs.size());

for (auto* input : if_implicit_inputs) {
Expand All @@ -4121,8 +4123,8 @@ Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const loggin

// We are going to map the outputs of the graph to inline to the outputs of the If node.
// They are assumed to be in the same order.
const auto node_output_defs = if_node.MutableOutputDefs();
const auto graph_output_defs = graph_to_inline.GetOutputs();
const auto& node_output_defs = if_node.MutableOutputDefs();
const auto& graph_output_defs = graph_to_inline.GetOutputs();
for (size_t i = 0; i < graph_output_defs.size(); ++i) {
name_to_nodearg.emplace(graph_output_defs[i]->Name(), node_output_defs[i]);
}
Expand Down Expand Up @@ -4206,6 +4208,7 @@ Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const loggin
}
}

auto* non_existing_arg = &GetOrCreateNodeArg(std::string(), nullptr);
// We want to make sure we get nodes in topological order
// because Constant folding may cause the nodes appear in
// a different order.
Expand All @@ -4216,68 +4219,94 @@ Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const loggin
auto* node = graph_to_inline.GetNode(node_idx);
assert(node->OpType() != kConstant);

InlinedVector<NodeArg*> new_node_input_defs;
for (const auto* input_def : node->InputDefs()) {
// Inputs
// Chop off trailing non-existing defs, but preserve non-existing in the middle
auto& input_defs = node->MutableInputDefs();
auto last_existing = std::find_if(input_defs.rbegin(), input_defs.rend(),
[](const NodeArg* node_arg) { return node_arg->Exists(); });
input_defs.resize(std::distance(input_defs.begin(), last_existing.base()));

InlinedVector<NodeArg*> new_input_defs;
for (auto* input_def : node->InputDefs()) {
if (input_def->Exists()) {
// Check if this is one of the implicit graph inputs
// then leave the name as is and re-use the NodeArg
// then re-assign the def to the outer scope value.
const auto& input_name = input_def->Name();
auto outer_hit = outer_scope_values.find(input_name);
if (outer_hit != outer_scope_values.cend()) {
new_node_input_defs.push_back(outer_hit->second);
// get/create local definition
NodeArg* outer_arg = outer_hit->second;
auto& this_scope_arg = GetOrCreateNodeArg(outer_arg->Name(), input_def->TypeAsProto());
new_input_defs.push_back(&this_scope_arg);
} else {
auto hit = name_to_nodearg.find(input_name);
if (hit != name_to_nodearg.cend()) {
// This is other node output, constant node or initializer that was renamed.
new_node_input_defs.push_back(hit->second);
// This is other node output in the dest graph,
// constant node or initializer that was renamed.
new_input_defs.push_back(hit->second);
} else {
ORT_THROW("Node's: ", node->Name(), " input: ", input_name,
" is not If node's input or previous node output in this subgraph");
}
}
} else {
new_input_defs.push_back(non_existing_arg);
}
}

InlinedVector<NodeArg*> new_node_output_defs;
for (const auto* output_def : node->OutputDefs()) {
const auto& output_name = output_def->Name();
auto hit = name_to_nodearg.find(output_name);
if (hit != name_to_nodearg.cend()) {
// This is one of the graph outputs, we rename it to
// If node output.
new_node_output_defs.push_back(hit->second);
// Outputs
// Chop off trailing non-existing defs
auto& output_defs = node->MutableOutputDefs();
last_existing = std::find_if(output_defs.rbegin(), output_defs.rend(),
[](const NodeArg* node_arg) { return node_arg->Exists(); });
output_defs.resize(std::distance(output_defs.begin(), last_existing.base()));

InlinedVector<NodeArg*> new_output_defs;
for (auto* output_def : node->OutputDefs()) {
if (output_def->Exists()) {
const auto& output_name = output_def->Name();
auto hit = name_to_nodearg.find(output_name);
if (hit != name_to_nodearg.cend()) {
// This is one of the If node outputs, simply reassign the def.
// If node defs are already in the destination graph
new_output_defs.push_back(hit->second);
} else {
// We generate an output to downstream nodes.
auto new_name = GenerateNodeArgName(make_unique(output_name));
NodeArg& new_arg = GetOrCreateNodeArg(new_name, output_def->TypeAsProto());
new_output_defs.push_back(&new_arg);
ORT_IGNORE_RETURN_VALUE(name_to_nodearg.emplace(output_name, &new_arg));
}
} else {
// We generate an output to downstream nodes.
auto new_name = GenerateNodeArgName(make_unique(output_name));
NodeArg& new_arg = GetOrCreateNodeArg(new_name, output_def->TypeAsProto());
new_node_output_defs.push_back(&new_arg);
ORT_IGNORE_RETURN_VALUE(name_to_nodearg.emplace(output_name, &new_arg));
new_output_defs.push_back(non_existing_arg);
}
}

const auto new_node_name = GenerateNodeName(make_unique(node->OpType()));
Node& new_node = AddNode(new_node_name, node->OpType(), node->Description(),
new_node_input_defs,
new_node_output_defs,
new_input_defs,
new_output_defs,
nullptr,
node->Domain());

new_node.SetSinceVersion(node->SinceVersion());
new_node.op_ = node->op_;

if (!is_this_main_graph) {
map_defs(new_node, input_args, true);
map_defs(new_node, output_args, false);
new_nodes.push_back(&new_node);
}

new_node.SetSinceVersion(node->SinceVersion());
new_node.op_ = node->op_;

if (node->ContainsSubgraph()) {
auto& subgraphs = node->MutableSubgraphs();

// Check if any of this node implicit inputs of this graph is in the renaming map
// that would mean they come from the destination graph, not from the parent
// of the destination graph.
int renames_subgraph_names = 0;
auto& new_implicit_defs = node->MutableImplicitInputDefs();
for (auto& input_def : new_implicit_defs) {
auto& implicit_defs = node->MutableImplicitInputDefs();
for (auto& input_def : implicit_defs) {
auto hit = name_to_nodearg.find(input_def->Name());
if (hit != name_to_nodearg.cend()) {
input_def = hit->second;
Expand All @@ -4298,7 +4327,7 @@ Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const loggin

new_node.MutableSubgraphs() = std::move(subgraphs);
new_node.GetMutableMapOfAttributeNameToSubgraph() = std::move(node->GetMutableMapOfAttributeNameToSubgraph());
new_node.MutableImplicitInputDefs() = std::move(new_implicit_defs);
new_node.MutableImplicitInputDefs() = std::move(implicit_defs);
}

new_node.GetMutableAttributes() = std::move(node->GetMutableAttributes());
Expand Down
156 changes: 156 additions & 0 deletions onnxruntime/test/optimizer/graph_transform_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1176,6 +1176,162 @@ TEST_F(GraphTransformationTests, ConstantFoldingIfConstantInliningRebuildEdges)
ASSERT_EQ(op_to_count["Cast"], 2);
}

TEST_F(GraphTransformationTests, ConstantFoldingIfConstantInliningEdgesWithMiddleArgNonExisting) {
  // Regression test for input-edge reconstruction during If constant folding.
  //
  // The model below contains a Resize() call whose middle (roi) argument is
  // intentionally non-existing (note the double comma in `Resize(self_71,, ...)`).
  // We want to make sure that when the subgraph containing Resize() is inlined,
  // its input edges are rebuilt with destination arg positions that account for
  // the missing middle argument (i.e. dst arg indices 0 and 2, with 1 absent).
  //
  // This test is only valid if the Resize() node resides in a nested subgraph that
  // gets inlined while the destination graph is NOT the main graph; only then do we
  // exercise the edge-rebuilding path under test. Resize() must also not be the
  // first node in the resulting subgraph, so that it actually has input edges.
  const char* code = R"(
  <
  ir_version: 8,
  opset_import: [ "" : 16, "local" : 1 ]
  >
  agraph (float[128] x, float[128] x1) => (float[N] y)
  {
    y = local.aten_gather <dim: int = 1, sparse_grad: int = 0> (x, x1)
  }
  <
  opset_import: [ "" : 16, "local" : 1],
  domain: "local"
  >
  aten_gather <dim>(self, index) => (result_16)
  {
    resize_scales = Constant <value_floats: floats = [1.5]> ()
    tmp_0 = Size (index)
    int64_0 = Constant <value: tensor = int64 int64_0 {0}> ()
    int64_0_cast = CastLike (int64_0, tmp_0)
    cond = Equal (tmp_0, int64_0_cast)
    result_16 = If (cond) <then_branch: graph = thenGraph_10 () => ( result) {
      result = Identity (self)
    }, else_branch: graph = elseGraph_10 () => ( result_15) {
      tmp_1 = Shape (self)
      tmp_2 = Size (tmp_1)
      int64_0_3 = Constant <value: tensor = int64 int64_0_3 {0}> ()
      int64_0_3_cast = CastLike (int64_0_3, tmp_2)
      cond_4 = Equal (tmp_2, int64_0_3_cast)
      self_8 = If (cond_4) <then_branch: graph = thenGraph_13 () => ( self_6) {
        tmp_5 = Constant <value_ints: ints = [-1]> ()
        self_6 = Reshape (self, tmp_5)
      }, else_branch: graph = elseGraph_13 () => ( self_7) {
        self_71 = Mul(self, self)
        float_size = CastLike (tmp_0, resize_scales)
        non_constant_resize_scales = Mul(float_size, resize_scales)
        self_7 = Resize(self_71,, non_constant_resize_scales)
      }>
      tmp_9 = Size (index)
      int64_0_10 = Constant <value: tensor = int64 int64_0_10 {0}> ()
      int64_0_10_cast = CastLike (int64_0_10, tmp_9)
      cond_11 = Equal (tmp_9, int64_0_10_cast)
      result_15 = If (cond_11) <then_branch: graph = thenGraph_15 () => ( result_12) {
        result_12 = CastLike (index, self_8)
      }, else_branch: graph = elseGraph_15 () => ( result_14) {
        index_13 = Cast <to: int = 7> (index)
        result_14 = GatherElements <axis: int = @dim> (self_8, index_13)
      }>
    }>
  }
  )";

  /** Optimized model graph (for reference — the expected shape of the graph after
      AOT function inlining and If constant folding; note the Resize call still has
      its middle input absent: `(..._self_71, , ..._non_constant_resize_scales)`)
  <
   ir_version: 8,
   opset_import: ["" : 16,
   "local" : 1,
   "com.microsoft.nchwc" : 1,
   "ai.onnx.ml" : 4,
   "ai.onnx.training" : 1,
   "ai.onnx.preview.training" : 1,
   "com.microsoft" : 1,
   "com.microsoft.experimental" : 1, "org.pytorch.aten" : 1]
  >
  agraph (float[128] x, float[128] x1) => (float[128] y)
     <float[1] _inlfunc_aten_gather_resize_scales = {1.5}, int64 ortshared_7_0_1_0_token_8 = {0}>
  {
     _inlfunc_aten_gather_tmp_0 = Size (x1)
     _inlfunc_aten_gather_cond = Equal (_inlfunc_aten_gather_tmp_0, ortshared_7_0_1_0_token_8)
     y = If (_inlfunc_aten_gather_cond) <then_branch: graph = thenGraph_10 () =>
        (float[128] _inlfunc_aten_gather_result) {
        _inlfunc_aten_gather_result = Identity (x)
     }, else_branch: graph = elseGraph_10 () => (float[128] _inlfunc_aten_gather_result_15)
       <int64 _inlfunc_aten_gather_int64_0_10 = {0}>
     {
        _if_else_branch__inlfunc_aten_gather_self_71 = Mul (x, x)
        _if_else_branch__inlfunc_aten_gather_float_size = Cast <to: int = 1> (_inlfunc_aten_gather_tmp_0)
        _if_else_branch__inlfunc_aten_gather_non_constant_resize_scales = Mul (
            _if_else_branch__inlfunc_aten_gather_float_size, _inlfunc_aten_gather_resize_scales)
        _inlfunc_aten_gather_self_8 = Resize <exclude_outside: int = 0, coordinate_transformation_mode:
            string = "half_pixel", cubic_coeff_a: float = -0.75, extrapolation_value: float = 0, mode:
            string = "nearest", nearest_mode: string = "round_prefer_floor"> (
             _if_else_branch__inlfunc_aten_gather_self_71, ,
             _if_else_branch__inlfunc_aten_gather_non_constant_resize_scales)
        _inlfunc_aten_gather_tmp_9 = Size (x1)
        _inlfunc_aten_gather_cond_11 = Equal (_inlfunc_aten_gather_tmp_9, _inlfunc_aten_gather_int64_0_10)
        _inlfunc_aten_gather_result_15 = If (_inlfunc_aten_gather_cond_11) <then_branch: graph = thenGraph_15 () =>
          (float[128] _inlfunc_aten_gather_result_12) {
           _inlfunc_aten_gather_result_12 = Cast <to: int = 1> (x1)
        }, else_branch: graph = elseGraph_15 () => (float[128] _inlfunc_aten_gather_result_14) {
           _inlfunc_aten_gather_index_13 = Cast <to: int = 7> (x1)
           _inlfunc_aten_gather_result_14 = GatherElements <axis: int = 1> (
             _inlfunc_aten_gather_self_8, _inlfunc_aten_gather_index_13)
        }>
     }>
  }

  */

  // Parse the textual ONNX model above into a ModelProto.
  ONNX_NAMESPACE::OnnxParser parser(code);
  ONNX_NAMESPACE::ModelProto model_proto;
  auto parse_status = parser.Parse(model_proto);
  ASSERT_TRUE(parse_status.IsOK()) << parse_status.ErrorMessage();
  ASSERT_TRUE(parser.EndOfInput()) << "Extra unparsed input unexpected.";

  std::string serialized_model;
  const bool serialization_status = model_proto.SerializeToString(&serialized_model);
  ASSERT_TRUE(serialization_status) << "Failed to serialize proto to string";

  // AOT inlining is necessary in this case, so the If nodes within the function
  // are brought out to the outer scope. So we load this into a session object.
  // Initialize() runs the graph transformations (including If constant folding).
  SessionOptions session_options;
  InferenceSessionWrapper session_object{session_options, GetEnvironment()};
  std::stringstream sstr(serialized_model);
  ASSERT_STATUS_OK(session_object.Load(sstr));
  ASSERT_STATUS_OK(session_object.Initialize());

  // Verify the correctness of the rebuilt edges of the Resize node that still
  // resides within an if/else subgraph after optimization.
  auto& graph = session_object.GetModel().MainGraph();
  auto op_to_count = CountOpsInGraph(graph);
  ASSERT_EQ(op_to_count["If"], 2);
  ASSERT_EQ(op_to_count["Resize"], 1);

  // Locate the surviving If node in the main graph.
  auto if_node = std::find_if(graph.Nodes().begin(), graph.Nodes().end(),
                              [](const auto& node) { return node.OpType() == "If"; });
  ASSERT_NE(graph.Nodes().cend(), if_node);
  // Resize is in the else branch
  auto subgraph_map = if_node->GetAttributeNameToSubgraphMap();
  auto branch = subgraph_map.find("else_branch");
  ASSERT_NE(subgraph_map.cend(), branch);

  auto resize_node = std::find_if(branch->second->Nodes().begin(), branch->second->Nodes().end(),
                                  [](const auto& node) { return node.OpType() == "Resize"; });
  ASSERT_NE(branch->second->Nodes().cend(), resize_node);

  // Check the edges: exactly two input edges must exist.
  ASSERT_EQ(2U, resize_node->GetInputEdgesCount());
  // The edges should target destination arg positions 0 and 2,
  // with position 1 (the non-existing roi input) absent.
  InlinedHashSet<size_t> dest_edges;
  auto zero_edge = resize_node->InputEdgesBegin();
  dest_edges.insert(zero_edge->GetDstArgIndex());
  ++zero_edge;
  dest_edges.insert(zero_edge->GetDstArgIndex());
  ASSERT_TRUE(dest_edges.find(0) != dest_edges.end());
  ASSERT_TRUE(dest_edges.find(2) != dest_edges.end());
}

// Check transformations in the case of a subgraph with constant inputs.
TEST_F(GraphTransformationTests, SubgraphWithConstantInputs) {
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "constant-subgraph.onnx";
Expand Down
Loading