diff --git a/onnxruntime/core/optimizer/gather_slice_fusion.cc b/onnxruntime/core/optimizer/gather_slice_fusion.cc index 309d714ffbe79..b783f953827d7 100644 --- a/onnxruntime/core/optimizer/gather_slice_fusion.cc +++ b/onnxruntime/core/optimizer/gather_slice_fusion.cc @@ -185,9 +185,11 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra // |---> Gather // Reshape |---> Gather // |---> Slice - if (output_count != 3) continue; + // |... or (other ops) // Get the output into node args + if (output_count < 3) continue; + output_args.push_back(node.OutputDefs()[0]); } @@ -196,7 +198,6 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra auto shape = node_arg->Shape(); if (!shape) continue; - // ??? What is the consumers here ??? --> Reshape auto consumers = graph.GetConsumerNodes(node_arg->Name()); size_t consumer_count = consumers.size(); @@ -208,8 +209,9 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra int64_t split_axis = 0; int64_t indices_n_dims = -1; - // 2 Gather, and 1 slice... - InlinedVector reshape_outputs; + // Fuse 2 Gathers and 1 slice to Split + // Get those outputs as Split outputs + InlinedVector split_outputs; InlinedVector> nodes_to_fuse; int64_t gather_node_count = 0, slice_node_count = 0; @@ -224,7 +226,6 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra if ((!consumer || consumer->InputDefs()[0] != node_arg) || (!IsSupportedGatherOps && !IsSupportedSliceOps)) { - can_fuse = false; break; } @@ -262,7 +263,7 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra Node& gather_node = *graph.GetNode(consumer->Index()); nodes_to_fuse.push_back(gather_node); NodeArg* gather_output_args = gather_node.MutableOutputDefs()[0]; - reshape_outputs.push_back(gather_output_args); + split_outputs.push_back(gather_output_args); gather_node_count++; } @@ -276,7 +277,7 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra Node& slice_node = *graph.GetNode(consumer->Index()); NodeArg* slice_output_args = slice_node.MutableOutputDefs()[0]; nodes_to_fuse.push_back(slice_node); - reshape_outputs.push_back(slice_output_args); + split_outputs.push_back(slice_output_args); slice_node_count++; } } @@ -300,7 +301,6 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra } } - InlinedVector split_outputs; for (size_t i = 0; i < consumer_count; ++i) { split_outputs.push_back( @@ -310,8 +310,7 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra ); } - // how to have multiple output node - // do we need to add the Split [71, 1, 1] information here. + // Generate the Split Node ONNX_NAMESPACE::TensorProto split_initializer_proto; split_initializer_proto.set_name(graph.GenerateNodeName("fused_Split")); split_initializer_proto.add_dims(static_cast(1)); @@ -323,9 +322,10 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra split_initializer_proto.set_raw_data(split_value.data(), split_value.size() * sizeof(int64_t)); NodeArg* split_arg = &graph_utils::AddInitializer(graph, split_initializer_proto); + Node& split_node = graph.AddNode(graph.GenerateNodeName("Split"), "Split", "Split for fused Gather-Slice fusion", - {graph.GetNodeArg(node_arg->Name()), split_arg}, reshape_outputs); + split_inputs, split_outputs); split_node.AddAttribute("axis", split_axis); // to do here diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 4aedc354cfed0..fbfa43e3d3ee7 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -7647,31 +7647,35 @@ TEST_F(GraphTransformationTests, GatherSliceToSplitFusion) { { auto build_test_case = [&](ModelTestBuilder& builder) { auto* data_arg = builder.MakeInput({{54}}); - auto* shape_arg = builder.MakeInput({{4}}); + auto* reshape_arg = builder.MakeInput({{4}}); auto* reshape_out = builder.MakeIntermediate({{2, 512, 73, 64}}); - builder.AddNode("Reshape", {data_arg, shape_arg}, {reshape_out}); + builder.AddNode("Reshape", {data_arg, reshape_arg}, {reshape_out}); + + // Create Shape-0 Ops + auto* shape_output_0 = builder.MakeOutput(); + builder.AddNode("Shape", {reshape_out}, {shape_output_0}); // Create Gather-1 Ops - auto* gather_index_1 = builder.MakeInitializer({}, {static_cast(2)}); - auto* gather_out_1 = builder.MakeIntermediate(); + auto* gather_index_1 = builder.MakeInitializer({}, {static_cast(-2)}); + auto* gather_out_1 = builder.MakeIntermediate({{2, 512, 1, 64}}); builder.AddNode("Gather", {reshape_out, gather_index_1}, {gather_out_1}) .AddAttribute("axis", static_cast(2)); // Create Transpose 1-Ops auto* transpose_out_1 = builder.MakeOutput(); builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1}) - .AddAttribute("perm", std::vector{0, 2, 1}); + .AddAttribute("perm", std::vector{0, 2, 1, 3}); // Create Gather-2 Ops - auto* gather_index_2 = builder.MakeInitializer({}, {static_cast(1)}); - auto* gather_out_2 = builder.MakeIntermediate(); + auto* gather_index_2 = builder.MakeInitializer({}, {static_cast(-1)}); + auto* gather_out_2 = builder.MakeIntermediate({{2, 512, 1, 64}}); builder.AddNode("Gather", {reshape_out, gather_index_2}, {gather_out_2}) .AddAttribute("axis", static_cast(2)); // Create Transpose-2 Ops auto* transpose_out_2 = builder.MakeOutput(); builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2}) - .AddAttribute("perm", std::vector{0, 2, 1}); + .AddAttribute("perm", std::vector{0, 2, 1, 3}); // Create Slice Ops auto* slice_output = builder.MakeIntermediate(); @@ -7692,7 +7696,7 @@ TEST_F(GraphTransformationTests, GatherSliceToSplitFusion) { // Create Transpose-3 Ops auto* transpose_out_3 = builder.MakeOutput(); builder.AddNode("Transpose", {slice_output}, {transpose_out_3}) - .AddAttribute("perm", std::vector{0, 2, 1}); + .AddAttribute("perm", std::vector{0, 2, 1, 3}); }; auto pre_graph_checker = [&](Graph& graph) {