Skip to content

Commit

Permalink
Fix shape conv fuse opt (microsoft#20282)
Browse files Browse the repository at this point in the history
FIx:
- Multiples Convs into an Add+Relu will fuse the op although intermediates are needed

![image](https://github.com/microsoft/onnxruntime/assets/44298237/0c85a30c-5f41-4e62-ae2e-f41eada6c2c3)
- Also fixes an issue with Shape Initializers Merge as input, that
occurs when the input initializer is the same across multiple nodes but
not all nodes are Shape nodes.
  • Loading branch information
gedoensmax authored Apr 23, 2024
1 parent 8f53957 commit b4e5075
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 4 deletions.
15 changes: 14 additions & 1 deletion onnxruntime/core/optimizer/conv_activation_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,20 @@ const Node* GetLoneConsumerNode(const GraphViewer& graph_viewer, const Node& nod
if (!optimizer_utils::CheckOutputEdges(graph_viewer.GetGraph(), node, 1)) {
return nullptr;
}
return &*node.OutputNodesBegin();
const Node* next_node = &*node.OutputNodesBegin();
// ensure that the target node also has only one input that is not an initializer
const size_t input_edges_total = next_node->GetInputEdgesCount();
int non_const_edges = 0;
for (size_t edge_idx = 0; edge_idx < input_edges_total; ++edge_idx) {
if (!graph_utils::NodeArgIsConstant(graph_viewer.GetGraph(), *next_node->InputDefs()[edge_idx])) {
++non_const_edges;
}
}
if (non_const_edges > 1) {
return nullptr;
} else {
return next_node;
}
}

bool HasElementDataType(const NodeArg& node_arg, int32_t data_type) {
Expand Down
6 changes: 3 additions & 3 deletions onnxruntime/core/optimizer/shape_input_merge.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,13 @@ Status ShapeInputMerge::ApplyImpl(Graph& graph, bool& modified, int graph_level,
for (size_t i = 1; i < kv.second.size(); ++i) {
Node* p_node = kv.second[i];
const NodeArg* input_arg = p_node->InputDefs()[0];
if (p_node->InputDefs()[0]->Name() == first_input_arg->Name()) continue;
if (!graph.IsInputsIncludingInitializers(input_arg)) {
if (input_arg->Name() == first_input_arg->Name()) continue;
if (!graph.IsInputsIncludingInitializers(input_arg) && p_node->GetInputEdgesCount()) {
const Node::EdgeEnd& input_edge = *p_node->InputEdgesBegin();
graph.RemoveEdge(input_edge.GetNode().Index(), p_node->Index(), input_edge.GetSrcArgIndex(), 0);
}
graph_utils::ReplaceNodeInput(*p_node, 0, *first_input_arg);
if (!is_first_input_arg_graph_input) {
if (!is_first_input_arg_graph_input && kv.second[0]->GetInputEdgesCount()) {
const Node::EdgeEnd& first_input_edge = *kv.second[0]->InputEdgesBegin();
graph.AddEdge(first_input_edge.GetNode().Index(), p_node->Index(), first_input_edge.GetSrcArgIndex(), 0);
}
Expand Down

0 comments on commit b4e5075

Please sign in to comment.