Skip to content

Commit

Permalink
Create edges with arg positions correctly accounting for non-existing …
Browse files Browse the repository at this point in the history
…args (#18462)

### Description
Truncate trailing non-existing arguments.
  Make sure we do not skip on the non-existing arguments in the middle,
  because shape inference relies on their proper position.
This also affects the argument position in the Edges that must be
properly rebuilt
  each time If node branch is inlined.
Make sure that when we rename Defs in subgraphs, new renamed defs are
created in those subgraphs
  instead of pointing to outer scope defs.
  Add unit test.

### Motivation and Context
This is a follow up for
#18105
Currently, the non-trailing arguments are simply ignored and the edges
are created
with potentially incorrect positions.
  • Loading branch information
yuslepukhin authored Nov 20, 2023
1 parent 247ce21 commit cc54202
Show file tree
Hide file tree
Showing 3 changed files with 217 additions and 33 deletions.
1 change: 0 additions & 1 deletion cmake/external/abseil-cpp.natvis
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
<Intrinsic Name="_capacity" Expression="_commonfields().capacity_"/>
<Intrinsic Name="_control" Expression="_commonfields().control_"/>
<Intrinsic Name="_slots" Expression="(slot_type*)(_commonfields().slots_)"/>
<DisplayString Condition="_size() == 0">empty</DisplayString>
<DisplayString IncludeView="noparens">size={ _size() }</DisplayString>
<DisplayString ExcludeView="noparens">size=({_size()})</DisplayString>
<Expand>
Expand Down
93 changes: 61 additions & 32 deletions onnxruntime/core/graph/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4062,7 +4062,9 @@ static void ReassignSubgraphDependentNodeArgs(const InlinedHashMap<std::string,
if (input_def->Exists()) {
auto hit = name_to_nodearg.find(input_def->Name());
if (hit != name_to_nodearg.cend()) {
input_def = hit->second;
// Make sure we create a local to this subgraph definition
const auto* new_name_arg = hit->second;
input_def = &graph.GetOrCreateNodeArg(new_name_arg->Name(), input_def->TypeAsProto());
}
}
}
Expand All @@ -4088,7 +4090,7 @@ Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const loggin

Graph& graph_to_inline = *sub_graph;

std::string unique_id{if_node.Name()};
std::string unique_id{"_if_"};
if (condition_value) {
unique_id.append(then_branch);
} else {
Expand All @@ -4107,7 +4109,7 @@ Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const loggin
// Reason: there are no explicit inputs to the subgraphs, and the subgraph's
// implicit inputs must be covered by the implicit inputs of the If node.
InlinedHashMap<std::string_view, NodeArg*> outer_scope_values;
const auto if_implicit_inputs = if_node.MutableImplicitInputDefs();
const auto& if_implicit_inputs = if_node.MutableImplicitInputDefs();
outer_scope_values.reserve(if_implicit_inputs.size());

for (auto* input : if_implicit_inputs) {
Expand All @@ -4121,8 +4123,8 @@ Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const loggin

// We are going to map the outputs of the graph to inline to the outputs of the If node.
// They are assumed to be in the same order.
const auto node_output_defs = if_node.MutableOutputDefs();
const auto graph_output_defs = graph_to_inline.GetOutputs();
const auto& node_output_defs = if_node.MutableOutputDefs();
const auto& graph_output_defs = graph_to_inline.GetOutputs();
for (size_t i = 0; i < graph_output_defs.size(); ++i) {
name_to_nodearg.emplace(graph_output_defs[i]->Name(), node_output_defs[i]);
}
Expand Down Expand Up @@ -4206,6 +4208,7 @@ Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const loggin
}
}

auto* non_existing_arg = &GetOrCreateNodeArg(std::string(), nullptr);
// We want to make sure we get nodes in topological order
// because Constant folding may cause the nodes appear in
// a different order.
Expand All @@ -4216,68 +4219,94 @@ Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const loggin
auto* node = graph_to_inline.GetNode(node_idx);
assert(node->OpType() != kConstant);

InlinedVector<NodeArg*> new_node_input_defs;
for (const auto* input_def : node->InputDefs()) {
// Inputs
// Chop off trailing non-existing defs, but preserve non-existing in the middle
auto& input_defs = node->MutableInputDefs();
auto last_existing = std::find_if(input_defs.rbegin(), input_defs.rend(),
[](const NodeArg* node_arg) { return node_arg->Exists(); });
input_defs.resize(std::distance(input_defs.begin(), last_existing.base()));

InlinedVector<NodeArg*> new_input_defs;
for (auto* input_def : node->InputDefs()) {
if (input_def->Exists()) {
// Check if this is one of the implicit graph inputs
// then leave the name as is and re-use the NodeArg
// then re-assign the def to the outer scope value.
const auto& input_name = input_def->Name();
auto outer_hit = outer_scope_values.find(input_name);
if (outer_hit != outer_scope_values.cend()) {
new_node_input_defs.push_back(outer_hit->second);
// get/create local definition
NodeArg* outer_arg = outer_hit->second;
auto& this_scope_arg = GetOrCreateNodeArg(outer_arg->Name(), input_def->TypeAsProto());
new_input_defs.push_back(&this_scope_arg);
} else {
auto hit = name_to_nodearg.find(input_name);
if (hit != name_to_nodearg.cend()) {
// This is other node output, constant node or initializer that was renamed.
new_node_input_defs.push_back(hit->second);
// This is other node output in the dest graph,
// constant node or initializer that was renamed.
new_input_defs.push_back(hit->second);
} else {
ORT_THROW("Node's: ", node->Name(), " input: ", input_name,
" is not If node's input or previous node output in this subgraph");
}
}
} else {
new_input_defs.push_back(non_existing_arg);
}
}

InlinedVector<NodeArg*> new_node_output_defs;
for (const auto* output_def : node->OutputDefs()) {
const auto& output_name = output_def->Name();
auto hit = name_to_nodearg.find(output_name);
if (hit != name_to_nodearg.cend()) {
// This is one of the graph outputs, we rename it to
// If node output.
new_node_output_defs.push_back(hit->second);
// Outputs
// Chop off trailing non-existing defs
auto& output_defs = node->MutableOutputDefs();
last_existing = std::find_if(output_defs.rbegin(), output_defs.rend(),
[](const NodeArg* node_arg) { return node_arg->Exists(); });
output_defs.resize(std::distance(output_defs.begin(), last_existing.base()));

InlinedVector<NodeArg*> new_output_defs;
for (auto* output_def : node->OutputDefs()) {
if (output_def->Exists()) {
const auto& output_name = output_def->Name();
auto hit = name_to_nodearg.find(output_name);
if (hit != name_to_nodearg.cend()) {
// This is one of the If node outputs, simply reassign the def.
// If node defs are already in the destination graph
new_output_defs.push_back(hit->second);
} else {
// We generate an output to downstream nodes.
auto new_name = GenerateNodeArgName(make_unique(output_name));
NodeArg& new_arg = GetOrCreateNodeArg(new_name, output_def->TypeAsProto());
new_output_defs.push_back(&new_arg);
ORT_IGNORE_RETURN_VALUE(name_to_nodearg.emplace(output_name, &new_arg));
}
} else {
// We generate an output to downstream nodes.
auto new_name = GenerateNodeArgName(make_unique(output_name));
NodeArg& new_arg = GetOrCreateNodeArg(new_name, output_def->TypeAsProto());
new_node_output_defs.push_back(&new_arg);
ORT_IGNORE_RETURN_VALUE(name_to_nodearg.emplace(output_name, &new_arg));
new_output_defs.push_back(non_existing_arg);
}
}

const auto new_node_name = GenerateNodeName(make_unique(node->OpType()));
Node& new_node = AddNode(new_node_name, node->OpType(), node->Description(),
new_node_input_defs,
new_node_output_defs,
new_input_defs,
new_output_defs,
nullptr,
node->Domain());

new_node.SetSinceVersion(node->SinceVersion());
new_node.op_ = node->op_;

if (!is_this_main_graph) {
map_defs(new_node, input_args, true);
map_defs(new_node, output_args, false);
new_nodes.push_back(&new_node);
}

new_node.SetSinceVersion(node->SinceVersion());
new_node.op_ = node->op_;

if (node->ContainsSubgraph()) {
auto& subgraphs = node->MutableSubgraphs();

// Check if any of this node implicit inputs of this graph is in the renaming map
// that would mean they come from the destination graph, not from the parent
// of the destination graph.
int renames_subgraph_names = 0;
auto& new_implicit_defs = node->MutableImplicitInputDefs();
for (auto& input_def : new_implicit_defs) {
auto& implicit_defs = node->MutableImplicitInputDefs();
for (auto& input_def : implicit_defs) {
auto hit = name_to_nodearg.find(input_def->Name());
if (hit != name_to_nodearg.cend()) {
input_def = hit->second;
Expand All @@ -4298,7 +4327,7 @@ Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const loggin

new_node.MutableSubgraphs() = std::move(subgraphs);
new_node.GetMutableMapOfAttributeNameToSubgraph() = std::move(node->GetMutableMapOfAttributeNameToSubgraph());
new_node.MutableImplicitInputDefs() = std::move(new_implicit_defs);
new_node.MutableImplicitInputDefs() = std::move(implicit_defs);
}

new_node.GetMutableAttributes() = std::move(node->GetMutableAttributes());
Expand Down
156 changes: 156 additions & 0 deletions onnxruntime/test/optimizer/graph_transform_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1176,6 +1176,162 @@ TEST_F(GraphTransformationTests, ConstantFoldingIfConstantInliningRebuildEdges)
ASSERT_EQ(op_to_count["Cast"], 2);
}

TEST_F(GraphTransformationTests, ConstantFoldingIfConstantInliningEdgesWithMiddleArgNonExisting) {
// Verifies that If-node constant folding rebuilds input edges correctly when a node
// (here Resize()) has a non-existing (optional, omitted) argument in the middle of its
// input list. The edge destination arg positions must account for the missing slot.
// This test is only valid if the Resize() node resides in a nested subgraph which gets inlined;
// however, the destination graph must not be the main graph. Then we test that the edges are
// rebuilt properly. Also Resize() should not be the first node in the resulting subgraph,
// so it has input edges to verify.
const char* code = R"(
<
ir_version: 8,
opset_import: [ "" : 16, "local" : 1 ]
>
agraph (float[128] x, float[128] x1) => (float[N] y)
{
y = local.aten_gather <dim: int = 1, sparse_grad: int = 0> (x, x1)
}
<
opset_import: [ "" : 16, "local" : 1],
domain: "local"
>
aten_gather <dim>(self, index) => (result_16)
{
resize_scales = Constant <value_floats: floats = [1.5]> ()
tmp_0 = Size (index)
int64_0 = Constant <value: tensor = int64 int64_0 {0}> ()
int64_0_cast = CastLike (int64_0, tmp_0)
cond = Equal (tmp_0, int64_0_cast)
result_16 = If (cond) <then_branch: graph = thenGraph_10 () => ( result) {
result = Identity (self)
}, else_branch: graph = elseGraph_10 () => ( result_15) {
tmp_1 = Shape (self)
tmp_2 = Size (tmp_1)
int64_0_3 = Constant <value: tensor = int64 int64_0_3 {0}> ()
int64_0_3_cast = CastLike (int64_0_3, tmp_2)
cond_4 = Equal (tmp_2, int64_0_3_cast)
self_8 = If (cond_4) <then_branch: graph = thenGraph_13 () => ( self_6) {
tmp_5 = Constant <value_ints: ints = [-1]> ()
self_6 = Reshape (self, tmp_5)
}, else_branch: graph = elseGraph_13 () => ( self_7) {
self_71 = Mul(self, self)
float_size = CastLike (tmp_0, resize_scales)
non_constant_resize_scales = Mul(float_size, resize_scales)
self_7 = Resize(self_71,, non_constant_resize_scales)
}>
tmp_9 = Size (index)
int64_0_10 = Constant <value: tensor = int64 int64_0_10 {0}> ()
int64_0_10_cast = CastLike (int64_0_10, tmp_9)
cond_11 = Equal (tmp_9, int64_0_10_cast)
result_15 = If (cond_11) <then_branch: graph = thenGraph_15 () => ( result_12) {
result_12 = CastLike (index, self_8)
}, else_branch: graph = elseGraph_15 () => ( result_14) {
index_13 = Cast <to: int = 7> (index)
result_14 = GatherElements <axis: int = @dim> (self_8, index_13)
}>
}>
}
)";

/** Optimized model graph
<
ir_version: 8,
opset_import: ["" : 16,
"local" : 1,
"com.microsoft.nchwc" : 1,
"ai.onnx.ml" : 4,
"ai.onnx.training" : 1,
"ai.onnx.preview.training" : 1,
"com.microsoft" : 1,
"com.microsoft.experimental" : 1, "org.pytorch.aten" : 1]
>
agraph (float[128] x, float[128] x1) => (float[128] y)
<float[1] _inlfunc_aten_gather_resize_scales = {1.5}, int64 ortshared_7_0_1_0_token_8 = {0}>
{
_inlfunc_aten_gather_tmp_0 = Size (x1)
_inlfunc_aten_gather_cond = Equal (_inlfunc_aten_gather_tmp_0, ortshared_7_0_1_0_token_8)
y = If (_inlfunc_aten_gather_cond) <then_branch: graph = thenGraph_10 () =>
(float[128] _inlfunc_aten_gather_result) {
_inlfunc_aten_gather_result = Identity (x)
}, else_branch: graph = elseGraph_10 () => (float[128] _inlfunc_aten_gather_result_15)
<int64 _inlfunc_aten_gather_int64_0_10 = {0}>
{
_if_else_branch__inlfunc_aten_gather_self_71 = Mul (x, x)
_if_else_branch__inlfunc_aten_gather_float_size = Cast <to: int = 1> (_inlfunc_aten_gather_tmp_0)
_if_else_branch__inlfunc_aten_gather_non_constant_resize_scales = Mul (
_if_else_branch__inlfunc_aten_gather_float_size, _inlfunc_aten_gather_resize_scales)
_inlfunc_aten_gather_self_8 = Resize <exclude_outside: int = 0, coordinate_transformation_mode:
string = "half_pixel", cubic_coeff_a: float = -0.75, extrapolation_value: float = 0, mode:
string = "nearest", nearest_mode: string = "round_prefer_floor"> (
_if_else_branch__inlfunc_aten_gather_self_71, ,
_if_else_branch__inlfunc_aten_gather_non_constant_resize_scales)
_inlfunc_aten_gather_tmp_9 = Size (x1)
_inlfunc_aten_gather_cond_11 = Equal (_inlfunc_aten_gather_tmp_9, _inlfunc_aten_gather_int64_0_10)
_inlfunc_aten_gather_result_15 = If (_inlfunc_aten_gather_cond_11) <then_branch: graph = thenGraph_15 () =>
(float[128] _inlfunc_aten_gather_result_12) {
_inlfunc_aten_gather_result_12 = Cast <to: int = 1> (x1)
}, else_branch: graph = elseGraph_15 () => (float[128] _inlfunc_aten_gather_result_14) {
_inlfunc_aten_gather_index_13 = Cast <to: int = 7> (x1)
_inlfunc_aten_gather_result_14 = GatherElements <axis: int = 1> (
_inlfunc_aten_gather_self_8, _inlfunc_aten_gather_index_13)
}>
}>
}
*/

// Parse the textual model into a ModelProto.
ONNX_NAMESPACE::OnnxParser parser(code);
ONNX_NAMESPACE::ModelProto model_proto;
auto parse_status = parser.Parse(model_proto);
ASSERT_TRUE(parse_status.IsOK()) << parse_status.ErrorMessage();
ASSERT_TRUE(parser.EndOfInput()) << "Extra unparsed input unexpected.";

std::string serialized_model;
const bool serialization_status = model_proto.SerializeToString(&serialized_model);
ASSERT_TRUE(serialization_status) << "Failed to serialize proto to string";

// AOT inlining is necessary in this case, so the If nodes within the function
// are brought out to the outer scope. So we load this into a session object.
// Initialize() runs the graph transformations (including If constant folding) under test.
SessionOptions session_options;
InferenceSessionWrapper session_object{session_options, GetEnvironment()};
std::stringstream sstr(serialized_model);
ASSERT_STATUS_OK(session_object.Load(sstr));
ASSERT_STATUS_OK(session_object.Initialize());

// Verify the correctness of the rebuilt edges in the Resize node that still
// resides within an if/else subgraph after optimization.
auto& graph = session_object.GetModel().MainGraph();
auto op_to_count = CountOpsInGraph(graph);
ASSERT_EQ(op_to_count["If"], 2);
ASSERT_EQ(op_to_count["Resize"], 1);

// Locate the remaining outer If node in the main graph.
auto if_node = std::find_if(graph.Nodes().begin(), graph.Nodes().end(),
[](const auto& node) { return node.OpType() == "If"; });
ASSERT_NE(graph.Nodes().cend(), if_node);
// Resize is in the else branch
auto subgraph_map = if_node->GetAttributeNameToSubgraphMap();
auto branch = subgraph_map.find("else_branch");
ASSERT_NE(subgraph_map.cend(), branch);

auto resize_node = std::find_if(branch->second->Nodes().begin(), branch->second->Nodes().end(),
[](const auto& node) { return node.OpType() == "Resize"; });
ASSERT_NE(branch->second->Nodes().cend(), resize_node);

// Check the edges: Resize should have exactly two input edges whose destination
// arg positions are 0 and 2 — position 1 is skipped because the middle input
// was omitted (non-existing) in the model text.
ASSERT_EQ(2U, resize_node->GetInputEdgesCount());
// Edge iteration order is not asserted; collect both destination indices into a set.
InlinedHashSet<size_t> dest_edges;
auto zero_edge = resize_node->InputEdgesBegin();
dest_edges.insert(zero_edge->GetDstArgIndex());
++zero_edge;
dest_edges.insert(zero_edge->GetDstArgIndex());
ASSERT_TRUE(dest_edges.find(0) != dest_edges.end());
ASSERT_TRUE(dest_edges.find(2) != dest_edges.end());
}

// Check transformations in the case of a subgraph with constant inputs.
TEST_F(GraphTransformationTests, SubgraphWithConstantInputs) {
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "constant-subgraph.onnx";
Expand Down

0 comments on commit cc54202

Please sign in to comment.