diff --git a/onnxruntime/core/optimizer/gather_fusion.cc b/onnxruntime/core/optimizer/gather_fusion.cc index b994028cbca13..4903bc1d6b961 100644 --- a/onnxruntime/core/optimizer/gather_fusion.cc +++ b/onnxruntime/core/optimizer/gather_fusion.cc @@ -9,7 +9,8 @@ namespace onnxruntime { -bool GatherToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis, int64_t& indices_n_dims) const { +bool GatherToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis, + int64_t& indices_n_dims) const { if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gather", {1, 11, 13}) || !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) { return false; @@ -53,6 +54,22 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le GraphViewer graph_viewer(graph); const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + InlinedVector node_args; + for (auto node_arg : graph.GetInputs()) { + if (node_arg && graph.GetConsumerNodes(node_arg->Name()).size() > 1) { + node_args.push_back(node_arg); + } + } + + for (auto entry : graph.GetAllInitializedTensors()) { + if (graph.GetConsumerNodes(entry.first).size() > 1) { + auto node_arg = graph.GetNodeArg(entry.first); + if (node_arg) { + node_args.push_back(node_arg); + } + } + } + for (auto node_index : node_topology_list) { auto* p_node = graph.GetNode(node_index); if (p_node == nullptr) continue; // we removed the node as part of an earlier fusion @@ -73,7 +90,11 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le size_t output_count = node.GetOutputEdgesCount(); if (output_count <= 1) continue; - auto shape = node.MutableOutputDefs()[0]->Shape(); + node_args.push_back(node.OutputDefs()[0]); + } + + for (const NodeArg* node_arg : node_args) { + auto shape = node_arg->Shape(); if (!shape) continue; int64_t rank = static_cast(shape->dim_size()); @@ -81,11 +102,14 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le bool first_edge = true; int64_t split_axis = 0; int64_t indices_n_dims = -1; - InlinedVector gather_outputs(output_count, nullptr); + auto consumers = graph.GetConsumerNodes(node_arg->Name()); + size_t consumer_count = consumers.size(); + InlinedVector gather_outputs(consumer_count, nullptr); InlinedVector> nodes_to_fuse; - for (auto it = node.OutputNodesBegin(); it != node.OutputNodesEnd(); ++it) { + for (auto consumer : consumers) { int64_t index, axis, dims; - if (!IsSupportedGather(graph, *it, index, axis, dims)) { + if (!consumer || consumer->InputDefs()[0] != node_arg || + !IsSupportedGather(graph, *consumer, index, axis, dims)) { can_fuse = false; break; } @@ -99,7 +123,7 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le if (axis < 0) axis += rank; if (first_edge) { auto dim = shape->dim(static_cast(axis)); - if (!utils::HasDimValue(dim) || dim.dim_value() != static_cast(output_count)) { + if (!utils::HasDimValue(dim) || dim.dim_value() != static_cast(consumer_count)) { can_fuse = false; break; } @@ -109,12 +133,12 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le can_fuse = false; break; } - if (index < 0) index += static_cast(output_count); - if (index < 0 || index >= static_cast(output_count) || gather_outputs[static_cast(index)]) { + if (index < 0) index += static_cast(consumer_count); + if (index < 0 || index >= static_cast(consumer_count) || gather_outputs[static_cast(index)]) { can_fuse = false; break; } - Node& gather_node = *graph.GetNode(it->Index()); + Node& gather_node = *graph.GetNode(consumer->Index()); nodes_to_fuse.emplace_back(gather_node); gather_outputs[static_cast(index)] = gather_node.MutableOutputDefs()[0]; } @@ -122,8 +146,8 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le if (!can_fuse) continue; ONNX_NAMESPACE::TypeProto split_output_type; - const ONNX_NAMESPACE::TensorProto_DataType element_type = static_cast( - node.MutableOutputDefs()[0]->TypeAsProto()->tensor_type().elem_type()); + const ONNX_NAMESPACE::TensorProto_DataType element_type = + static_cast(node_arg->TypeAsProto()->tensor_type().elem_type()); split_output_type.mutable_tensor_type()->set_elem_type(element_type); for (int64_t i = 0; i < rank; ++i) { if (i == split_axis) { @@ -136,16 +160,17 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le InlinedVector split_outputs; bool add_squeeze_node = indices_n_dims == 0; if (add_squeeze_node) { - for (size_t i = 0; i < output_count; ++i) { + for (size_t i = 0; i < consumer_count; ++i) { split_outputs.emplace_back( &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("split" + std::to_string(i)), &split_output_type)); } } - Node& split_node = graph.AddNode(graph.GenerateNodeName("Split"), "Split", "Split for Fused Gather nodes", - {node.MutableOutputDefs()[0]}, add_squeeze_node ? split_outputs : gather_outputs); + Node& split_node = + graph.AddNode(graph.GenerateNodeName("Split"), "Split", "Split for Fused Gather nodes", + {graph.GetNodeArg(node_arg->Name())}, add_squeeze_node ? split_outputs : gather_outputs); split_node.AddAttribute("axis", split_axis); - split_node.SetExecutionProviderType(node.GetExecutionProviderType()); + split_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType()); // Squeeze-11, Squeee-13, Split-13, Split-18 have different schemas. int onnx_opset_version = -1; @@ -155,16 +180,16 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le if (onnx_opset_version < 13) { if (add_squeeze_node) { - for (size_t i = 0; i < output_count; ++i) { + for (size_t i = 0; i < consumer_count; ++i) { Node& squeeze_node = graph.AddNode(graph.GenerateNodeName("Squeeze" + std::to_string(i)), "Squeeze", "Squeeze for Fused Gather nodes", {split_outputs[i]}, {gather_outputs[i]}); squeeze_node.AddAttribute("axes", std::vector{split_axis}); - squeeze_node.SetExecutionProviderType(node.GetExecutionProviderType()); + squeeze_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType()); } } } else { if (onnx_opset_version >= 18) { - split_node.AddAttribute("num_outputs", static_cast(output_count)); + split_node.AddAttribute("num_outputs", static_cast(consumer_count)); } if (add_squeeze_node) { @@ -176,11 +201,11 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le axes_initializer_proto.set_raw_data(axes_value.data(), axes_value.size() * sizeof(int64_t)); NodeArg* axes_arg = &graph_utils::AddInitializer(graph, axes_initializer_proto); - for (size_t i = 0; i < output_count; ++i) { + for (size_t i = 0; i < consumer_count; ++i) { Node& squeeze_node = graph.AddNode(graph.GenerateNodeName("Squeeze" + std::to_string(i)), "Squeeze", "Squeeze for Fused Gather nodes", {split_outputs[i], axes_arg}, {gather_outputs[i]}); - squeeze_node.SetExecutionProviderType(node.GetExecutionProviderType()); + squeeze_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType()); } } } diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index b82f3345dfcd1..17b26ed7ca4ca 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -6910,7 +6910,10 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion) { builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}).AddAttribute("perm", std::vector{0, 2, 1}); }; - auto pre_graph_checker = [&](Graph& graph) { TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); return Status::OK(); }; + auto pre_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); + return Status::OK(); + }; // OpSet-12 { @@ -6933,8 +6936,8 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion) { }; std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), TransformerLevel::Level1, 1, - pre_graph_checker, post_graph_checker)); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } // OpSet-14 @@ -6962,8 +6965,8 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion) { }; std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1, - pre_graph_checker, post_graph_checker)); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } // OpSet-18 @@ -6991,8 +6994,8 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion) { }; std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer), TransformerLevel::Level1, 1, - pre_graph_checker, post_graph_checker)); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } } @@ -7023,7 +7026,10 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_NoSqueeze) { builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}).AddAttribute("perm", std::vector{0, 2, 1}); }; - auto pre_graph_checker = [&](Graph& graph) { TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); return Status::OK(); }; + auto pre_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); + return Status::OK(); + }; // OpSet-12 { @@ -7042,8 +7048,8 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_NoSqueeze) { }; std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), TransformerLevel::Level1, 1, - pre_graph_checker, post_graph_checker)); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } // OpSet-14 @@ -7063,8 +7069,8 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_NoSqueeze) { }; std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1, - pre_graph_checker, post_graph_checker)); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } // OpSet-18 @@ -7084,13 +7090,180 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_NoSqueeze) { }; std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer), TransformerLevel::Level1, 1, - pre_graph_checker, post_graph_checker)); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); + } +} + +TEST_F(GraphTransformationTests, GatherToSplitFusion_Consume_Input) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* data_arg = builder.MakeInput({{2, 3, 3, 3}}); + auto* gather_index_1 = builder.MakeInitializer({}, {static_cast(0)}); + auto* gather_index_2 = builder.MakeInitializer({}, {static_cast(1)}); + auto* gather_index_3 = builder.MakeInitializer({}, {static_cast(2)}); + auto* gather_out_1 = builder.MakeIntermediate(); + auto* gather_out_2 = builder.MakeIntermediate(); + auto* gather_out_3 = builder.MakeIntermediate(); + auto* transpose_out_1 = builder.MakeOutput(); + auto* transpose_out_2 = builder.MakeOutput(); + auto* transpose_out_3 = builder.MakeOutput(); + + builder.AddNode("Gather", {data_arg, gather_index_1}, {gather_out_1}).AddAttribute("axis", static_cast(2)); + builder.AddNode("Gather", {data_arg, gather_index_2}, {gather_out_2}) + .AddAttribute("axis", static_cast(-2)); + builder.AddNode("Gather", {data_arg, gather_index_3}, {gather_out_3}).AddAttribute("axis", static_cast(2)); + builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1}).AddAttribute("perm", std::vector{0, 2, 1}); + builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2}).AddAttribute("perm", std::vector{0, 2, 1}); + builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}).AddAttribute("perm", std::vector{0, 2, 1}); + }; + + auto pre_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); + return Status::OK(); + }; + + // OpSet-12 + { + auto post_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3); + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Split") { + auto& attrs = node.GetAttributes(); + TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); + TEST_RETURN_IF_NOT(2 == static_cast(attrs.at("axis").i())); + } else if (node.OpType() == "Squeeze") { + auto& attrs = node.GetAttributes(); + TEST_RETURN_IF_NOT(attrs.find("axes") != attrs.end()); + TEST_RETURN_IF_NOT(2 == static_cast(attrs.at("axes").ints().at(0))); + } + } + return Status::OK(); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); + } + + // OpSet-14 + { + auto post_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3); + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Split") { + auto& attrs = node.GetAttributes(); + TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); + TEST_RETURN_IF_NOT(2 == static_cast(attrs.at("axis").i())); + } else if (node.OpType() == "Squeeze") { + const NodeArg& input_arg = *(node.InputDefs()[1]); + const ONNX_NAMESPACE::TensorProto* tensor_proto = + graph_utils::GetConstantInitializer(graph, input_arg.Name()); + TEST_RETURN_IF_NOT(tensor_proto != nullptr); + Initializer init_const{*tensor_proto, graph.ModelPath()}; + TEST_RETURN_IF_NOT(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64); + TEST_RETURN_IF_NOT(2 == static_cast(*(init_const.data()))); + } + } + return Status::OK(); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); + } + + // OpSet-18 + { + auto post_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3); + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Split") { + auto& attrs = node.GetAttributes(); + TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); + TEST_RETURN_IF_NOT(2 == static_cast(attrs.at("axis").i())); + } else if (node.OpType() == "Squeeze") { + const NodeArg& input_arg = *(node.InputDefs()[1]); + const ONNX_NAMESPACE::TensorProto* tensor_proto = + graph_utils::GetConstantInitializer(graph, input_arg.Name()); + TEST_RETURN_IF_NOT(tensor_proto != nullptr); + Initializer init_const{*tensor_proto, graph.ModelPath()}; + TEST_RETURN_IF_NOT(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64); + TEST_RETURN_IF_NOT(2 == static_cast(*(init_const.data()))); + } + } + return Status::OK(); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } } +TEST_F(GraphTransformationTests, GatherToSplitFusion_Consume_Initializer) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* data_arg = builder.MakeInitializer({2, 3, 3, 3}, std::vector(54)); + auto* gather_index_1 = builder.MakeInitializer({}, {static_cast(0)}); + auto* gather_index_2 = builder.MakeInitializer({}, {static_cast(1)}); + auto* gather_index_3 = builder.MakeInitializer({}, {static_cast(2)}); + auto* gather_out_1 = builder.MakeIntermediate(); + auto* gather_out_2 = builder.MakeIntermediate(); + auto* gather_out_3 = builder.MakeIntermediate(); + auto* transpose_out_1 = builder.MakeOutput(); + auto* transpose_out_2 = builder.MakeOutput(); + auto* transpose_out_3 = builder.MakeOutput(); + + builder.AddNode("Gather", {data_arg, gather_index_1}, {gather_out_1}).AddAttribute("axis", static_cast(2)); + builder.AddNode("Gather", {data_arg, gather_index_2}, {gather_out_2}) + .AddAttribute("axis", static_cast(-2)); + builder.AddNode("Gather", {data_arg, gather_index_3}, {gather_out_3}).AddAttribute("axis", static_cast(2)); + builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1}).AddAttribute("perm", std::vector{0, 2, 1}); + builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2}).AddAttribute("perm", std::vector{0, 2, 1}); + builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}).AddAttribute("perm", std::vector{0, 2, 1}); + }; + + auto pre_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); + return Status::OK(); + }; + + auto post_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3); + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Split") { + auto& attrs = node.GetAttributes(); + TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); + TEST_RETURN_IF_NOT(2 == static_cast(attrs.at("axis").i())); + } else if (node.OpType() == "Squeeze") { + const NodeArg& input_arg = *(node.InputDefs()[1]); + const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, input_arg.Name()); + TEST_RETURN_IF_NOT(tensor_proto != nullptr); + Initializer init_const{*tensor_proto, graph.ModelPath()}; + TEST_RETURN_IF_NOT(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64); + TEST_RETURN_IF_NOT(2 == static_cast(*(init_const.data()))); + } + } + return Status::OK(); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, + 1, pre_graph_checker, post_graph_checker)); +} + TEST_F(GraphTransformationTests, GatherToSplitFusion_Invalid) { - auto pre_graph_checker = [&](Graph& graph) { TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); return Status::OK(); }; + auto pre_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); + return Status::OK(); + }; auto post_graph_checker = [&](Graph& graph) { TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 0); @@ -7130,8 +7303,8 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_Invalid) { }; std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), TransformerLevel::Level1, 1, - pre_graph_checker, post_graph_checker)); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } // Invalid Gather indices. @@ -7166,8 +7339,8 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_Invalid) { }; std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1, - pre_graph_checker, post_graph_checker)); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } // Invalid Gather axis. @@ -7202,8 +7375,8 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_Invalid) { }; std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1, - pre_graph_checker, post_graph_checker)); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } } @@ -7250,8 +7423,8 @@ TEST_F(GraphTransformationTests, GatherToSliceFusion) { }; std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), TransformerLevel::Level1, 1, - pre_graph_checker, post_graph_checker)); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } // OpSet-14, Tind is int64. @@ -7289,8 +7462,8 @@ TEST_F(GraphTransformationTests, GatherToSliceFusion) { }; std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1, - pre_graph_checker, post_graph_checker)); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } }