diff --git a/onnxruntime/core/optimizer/gather_slice_fusion.cc b/onnxruntime/core/optimizer/gather_slice_fusion.cc index 0b7b2d53e0728..21266d356a020 100644 --- a/onnxruntime/core/optimizer/gather_slice_fusion.cc +++ b/onnxruntime/core/optimizer/gather_slice_fusion.cc @@ -211,7 +211,7 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra InlinedVector split_outputs(3); InlinedVector> nodes_to_fuse; - size_t gather_node_count = 0, slice_node_count = 0; + size_t gather_node_count = 2, slice_node_count = 0; // find the nodes to be merged for (auto consumer : consumers) { @@ -260,7 +260,7 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra Node& gather_node = *graph.GetNode(consumer->Index()); nodes_to_fuse.push_back(gather_node); NodeArg* gather_output_args = gather_node.MutableOutputDefs()[0]; - split_outputs[++gather_node_count] = gather_output_args; + split_outputs[gather_node_count--] = gather_output_args; } // check the Slice Ops @@ -273,13 +273,12 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra Node& slice_node = *graph.GetNode(consumer->Index()); NodeArg* slice_output_args = slice_node.MutableOutputDefs()[0]; nodes_to_fuse.push_back(slice_node); - split_outputs[slice_node_count] = slice_output_args; - slice_node_count++; + split_outputs[slice_node_count++] = slice_output_args; } } // condition check - if (!can_fuse || gather_node_count != 2 || slice_node_count != 1) continue; + if (!can_fuse || gather_node_count != 0 || slice_node_count != 1) continue; // generate the split node and merge the kernel ONNX_NAMESPACE::TypeProto split_output_type; @@ -310,7 +309,8 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra split_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); auto dim_value = shape->dim(static_cast(split_axis)).dim_value(); - int64_t slice_dim = static_cast(dim_value - gather_node_count); + // Optimize 2 Gather Nodes, so Slice_dim = dim_value - 2 + int64_t slice_dim = static_cast(dim_value - 2); InlinedVector split_value{{slice_dim, 1, 1}}; split_initializer_proto.set_raw_data(split_value.data(), split_value.size() * sizeof(int64_t)); NodeArg* split_arg = &graph_utils::AddInitializer(graph, split_initializer_proto);