Skip to content

Commit

Permalink
fix buffer size issue and adapt topological sort
Browse files (browse the repository at this point in the history)
  • Loading branch information
ruiren committed Jan 29, 2024
1 parent e396f00 commit f47d034
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions onnxruntime/core/optimizer/gather_slice_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra

// Fuse 2 Gather nodes and 1 Slice node into a single Split node.
// Collect the outputs of those nodes so they can be used as the Split outputs.
InlinedVector<NodeArg*> split_outputs;
InlinedVector<NodeArg*> split_outputs(3);

InlinedVector<std::reference_wrapper<Node>> nodes_to_fuse;
int64_t gather_node_count = 0, slice_node_count = 0;
Expand Down Expand Up @@ -263,8 +263,9 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra
Node& gather_node = *graph.GetNode(consumer->Index());
nodes_to_fuse.push_back(gather_node);
NodeArg* gather_output_args = gather_node.MutableOutputDefs()[0];
split_outputs.push_back(gather_output_args);
gather_node_count++;
split_outputs[gather_node_count] = gather_output_args;

}

// check the Slice Ops
Expand All @@ -277,7 +278,7 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra
Node& slice_node = *graph.GetNode(consumer->Index());
NodeArg* slice_output_args = slice_node.MutableOutputDefs()[0];
nodes_to_fuse.push_back(slice_node);
split_outputs.push_back(slice_output_args);
split_outputs[slice_node_count] = slice_output_args;
slice_node_count++;
}
}
Expand Down Expand Up @@ -314,7 +315,7 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra
// Generate the Split Node
ONNX_NAMESPACE::TensorProto split_initializer_proto;
split_initializer_proto.set_name(graph.GenerateNodeName("fused_Split"));
split_initializer_proto.add_dims(static_cast<int64_t>(1));
split_initializer_proto.add_dims(static_cast<int64_t>(3));
split_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);

auto dim_value = shape->dim(static_cast<int>(split_axis)).dim_value();
Expand Down

0 comments on commit f47d034

Please sign in to comment.