From f47d034b60ec1529c04d2f106f3a7ca1782dc592 Mon Sep 17 00:00:00 2001 From: ruiren Date: Mon, 29 Jan 2024 06:06:01 +0000 Subject: [PATCH] fix buffer size issue and adapt topological sort --- onnxruntime/core/optimizer/gather_slice_fusion.cc | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/optimizer/gather_slice_fusion.cc b/onnxruntime/core/optimizer/gather_slice_fusion.cc index 3dfe8b8f4d455..fd7690737b77d 100644 --- a/onnxruntime/core/optimizer/gather_slice_fusion.cc +++ b/onnxruntime/core/optimizer/gather_slice_fusion.cc @@ -211,7 +211,7 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra // Fuse 2 Gathers and 1 slice to Split // Get those outputs as Split outputs - InlinedVector split_outputs; + InlinedVector split_outputs(3); InlinedVector> nodes_to_fuse; int64_t gather_node_count = 0, slice_node_count = 0; @@ -263,8 +263,9 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra Node& gather_node = *graph.GetNode(consumer->Index()); nodes_to_fuse.push_back(gather_node); NodeArg* gather_output_args = gather_node.MutableOutputDefs()[0]; - split_outputs.push_back(gather_output_args); gather_node_count++; + split_outputs[gather_node_count] = gather_output_args; + } // check the Slice Ops @@ -277,7 +278,7 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra Node& slice_node = *graph.GetNode(consumer->Index()); NodeArg* slice_output_args = slice_node.MutableOutputDefs()[0]; nodes_to_fuse.push_back(slice_node); - split_outputs.push_back(slice_output_args); + split_outputs[slice_node_count] = slice_output_args; slice_node_count++; } } @@ -314,7 +315,7 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra // Generate the Split Node ONNX_NAMESPACE::TensorProto split_initializer_proto; split_initializer_proto.set_name(graph.GenerateNodeName("fused_Split")); - split_initializer_proto.add_dims(static_cast(1)); + split_initializer_proto.add_dims(static_cast(3)); split_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); auto dim_value = shape->dim(static_cast(split_axis)).dim_value();