From b4e50758c0773d07368c7692027793e42089c882 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maximilian=20M=C3=BCller?= <44298237+gedoensmax@users.noreply.github.com> Date: Wed, 24 Apr 2024 01:19:57 +0200 Subject: [PATCH] Fix shape conv fuse opt (#20282) Fixes: - Multiple Convs feeding into an Add+Relu were fused into a single op although the intermediate outputs are still needed ![image](https://github.com/microsoft/onnxruntime/assets/44298237/0c85a30c-5f41-4e62-ae2e-f41eada6c2c3) - Also fixes an issue in the Shape input merge with initializers as input, which occurs when the input initializer is the same across multiple nodes but not all nodes are Shape nodes. --- .../core/optimizer/conv_activation_fusion.cc | 15 ++++++++++++++- onnxruntime/core/optimizer/shape_input_merge.cc | 6 +++--- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/optimizer/conv_activation_fusion.cc b/onnxruntime/core/optimizer/conv_activation_fusion.cc index b7cb3ba488c62..12746ad53123a 100644 --- a/onnxruntime/core/optimizer/conv_activation_fusion.cc +++ b/onnxruntime/core/optimizer/conv_activation_fusion.cc @@ -23,7 +23,20 @@ const Node* GetLoneConsumerNode(const GraphViewer& graph_viewer, const Node& nod if (!optimizer_utils::CheckOutputEdges(graph_viewer.GetGraph(), node, 1)) { return nullptr; } - return &*node.OutputNodesBegin(); + const Node* next_node = &*node.OutputNodesBegin(); + // ensure that the target node also has only one input that is not an initializer + const size_t input_edges_total = next_node->GetInputEdgesCount(); + int non_const_edges = 0; + for (size_t edge_idx = 0; edge_idx < input_edges_total; ++edge_idx) { + if (!graph_utils::NodeArgIsConstant(graph_viewer.GetGraph(), *next_node->InputDefs()[edge_idx])) { + ++non_const_edges; + } + } + if (non_const_edges > 1) { + return nullptr; + } else { + return next_node; + } } bool HasElementDataType(const NodeArg& node_arg, int32_t data_type) { diff --git a/onnxruntime/core/optimizer/shape_input_merge.cc 
b/onnxruntime/core/optimizer/shape_input_merge.cc index 9f20520e3e3f4..dec1382319f16 100644 --- a/onnxruntime/core/optimizer/shape_input_merge.cc +++ b/onnxruntime/core/optimizer/shape_input_merge.cc @@ -58,13 +58,13 @@ Status ShapeInputMerge::ApplyImpl(Graph& graph, bool& modified, int graph_level, for (size_t i = 1; i < kv.second.size(); ++i) { Node* p_node = kv.second[i]; const NodeArg* input_arg = p_node->InputDefs()[0]; - if (p_node->InputDefs()[0]->Name() == first_input_arg->Name()) continue; - if (!graph.IsInputsIncludingInitializers(input_arg)) { + if (input_arg->Name() == first_input_arg->Name()) continue; + if (!graph.IsInputsIncludingInitializers(input_arg) && p_node->GetInputEdgesCount()) { const Node::EdgeEnd& input_edge = *p_node->InputEdgesBegin(); graph.RemoveEdge(input_edge.GetNode().Index(), p_node->Index(), input_edge.GetSrcArgIndex(), 0); } graph_utils::ReplaceNodeInput(*p_node, 0, *first_input_arg); - if (!is_first_input_arg_graph_input) { + if (!is_first_input_arg_graph_input && kv.second[0]->GetInputEdgesCount()) { const Node::EdgeEnd& first_input_edge = *kv.second[0]->InputEdgesBegin(); graph.AddEdge(first_input_edge.GetNode().Index(), p_node->Index(), first_input_edge.GetSrcArgIndex(), 0); }