From 9d0b66e5b5e4ee619cb55fd0d4ddffd29bc1fe41 Mon Sep 17 00:00:00 2001 From: ruiren Date: Mon, 22 Jan 2024 07:33:01 +0000 Subject: [PATCH 01/11] add GatherSliceToSplitFusion and Unittest --- .../core/optimizer/gather_slice_fusion.cc | 353 ++++++++++++++++++ .../core/optimizer/gather_slice_fusion.h | 34 ++ .../core/optimizer/graph_transformer_utils.cc | 2 + .../test/optimizer/graph_transform_test.cc | 79 ++++ .../optimizer/graph_transform_test_builder.cc | 2 +- .../core/optimizer/graph_transformer_utils.cc | 2 + 6 files changed, 471 insertions(+), 1 deletion(-) create mode 100644 onnxruntime/core/optimizer/gather_slice_fusion.cc create mode 100644 onnxruntime/core/optimizer/gather_slice_fusion.h diff --git a/onnxruntime/core/optimizer/gather_slice_fusion.cc b/onnxruntime/core/optimizer/gather_slice_fusion.cc new file mode 100644 index 0000000000000..0a94ec2b18229 --- /dev/null +++ b/onnxruntime/core/optimizer/gather_slice_fusion.cc @@ -0,0 +1,353 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + + +#include "core/optimizer/gather_slice_fusion.h" +#include "core/graph/graph_utils.h" +#include "core/optimizer/initializer.h" +#include "core/optimizer/utils.h" + + +namespace onnxruntime { + +bool GatherSliceToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, + int64_t& axis, int64_t& indices_n_dims) const { + + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gather", {1, 11, 13}) || + !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) { + return false; + } + + const NodeArg& input_arg = *(node.InputDefs()[1]); + + if (!optimizer_utils::IsScalar(input_arg)) return false; + + const ONNX_NAMESPACE::TensorProto* indices_init = graph_utils::GetConstantInitializer(graph, input_arg.Name()); + + if (!indices_init) return false; + + if (indices_init->data_type() != ONNX_NAMESPACE::TensorProto::INT64) return false; + + // get the index value + Initializer init_const(*indices_init, graph.ModelPath()); + index = *(init_const.data()); + + // get attributes value + axis = 0; + auto& attrs = node.GetAttributes(); + if (attrs.find("axis") != attrs.end()) { + auto& axis_attr = attrs.at("axis"); + if (utils::HasInt(axis_attr)) axis = axis_attr.i(); + } + + indices_n_dims = indices_init->dims_size(); + return true; +} + + +bool GatherSliceToSplitFusion::IsSupportedSlice(const Graph& graph, const Node& node, + InlinedVector& starts, + InlinedVector& ends, + InlinedVector& axes, + InlinedVector& steps) const { + // check the version of Slice ops + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Slice", {1, 10, 11, 13}) || + !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) { + return false; + } + + // get the opset version + int onnx_opset_version = -1; + if (graph.DomainToVersionMap().find(kOnnxDomain) != graph.DomainToVersionMap().end()) { + onnx_opset_version = graph.DomainToVersionMap().at(kOnnxDomain); + }; + + // If Slice op of opset version 1 + if (onnx_opset_version == 1) { + if (!graph_utils::GetRepeatedNodeAttributeValues(node, "starts", starts) || + !graph_utils::GetRepeatedNodeAttributeValues(node, "ends", ends) || + starts.size() != ends.size()) { + return false; + } + + if (graph_utils::GetRepeatedNodeAttributeValues(node, "axes", axes) && (axes.size() != starts.size())) { + return false; + } + }; + + // If Slice op of opset version >= 10 + if (onnx_opset_version >= 10) { + // node inputs include: starts - ends - axes - steps + + // return a pointer to the corresponding NodeArg if input of the node at the index exists + auto get_input_if_exists = [&node](size_t input_index) -> const NodeArg* { + const auto& input_defs = node.InputDefs(); + const NodeArg* input = (input_defs.size() > input_index) ? input_defs[input_index] : nullptr; + return (input == nullptr || !input->Exists()) ? nullptr : input; + }; + + // return a pointer to the initializer if it is constant; otherwise, a nullptr + auto get_initializer_if_constant = + [&graph, get_input_if_exists](size_t input_index) -> const ONNX_NAMESPACE::TensorProto* { + const NodeArg* input = get_input_if_exists(input_index); + return input ? graph_utils::GetConstantInitializer(graph, input->Name()) : nullptr; + }; + + // return the initialization data if it is constant + auto get_initializer_data = + [&graph](const ONNX_NAMESPACE::TensorProto* slice_initializer) -> InlinedVector { + Initializer init(*slice_initializer, graph.ModelPath()); + if (slice_initializer->data_type() == ONNX_NAMESPACE::TensorProto::INT32) { + int32_t* init_data = init.data(); + return InlinedVector(init_data, init_data + init.size()); + } + + if (slice_initializer->data_type() == ONNX_NAMESPACE::TensorProto::INT64) { + int64_t* init_data = init.data(); + return InlinedVector(init_data, init_data + init.size()); + } + return {}; + }; + + // starts and ends inputs have to exist, be constants and be of the same size. + const ONNX_NAMESPACE::TensorProto* starts_init = get_initializer_if_constant(1); + const ONNX_NAMESPACE::TensorProto* ends_init = get_initializer_if_constant(2); + const ONNX_NAMESPACE::TensorProto* axes_init = get_initializer_if_constant(3); + const ONNX_NAMESPACE::TensorProto* steps_init = get_initializer_if_constant(4); + + if (!starts_init || !ends_init || !axes_init || !steps_init) { + return false; + } + + starts = get_initializer_data(starts_init); + ends = get_initializer_data(ends_init); + axes = get_initializer_data(axes_init); + steps = get_initializer_data(steps_init); + + if (starts.size() == 0 || ends.size() == 0 || starts.size() != ends.size()) { + return false; + } + + // TODO: what does this mean ? + if (axes_init->dims_size() != 1 || static_cast(axes_init->dims().Get(0)) != starts.size()) { + return false; + } + + // if steps exists, it should be constant and all value should be 1 + if (steps.size() != starts.size()) { + return false; + } + + for (int64_t step : steps) { + if (step != 1) { + return false; + } + } + } + + return true; +} + +/* +GatherToSplitFusion is to fuse: + Node + |-> Gather(index=0, axis=axis) + |-> Gather(index=1, axis=axis) + |-> Slice(index=2, axis=axis) +To + Node + |-> Split(index=0) +So that we can use one kernel to finish the job. +*/ + +Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, + const logging::Logger& logger) const { + GraphViewer graph_viewer(graph); + + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + + InlinedVector output_args; + + // Iterate the topological order and get Reshape ops + for (auto node_index : node_topology_list) { + auto* p_node = graph.GetNode(node_index); + + if (p_node == nullptr) continue; + + Node& node = *p_node; + + ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); + + // Currently only catch after Reshape ops, optimize in the future + if (node.OpType() != "Reshape") continue; + + size_t output_count = node.GetOutputEdgesCount(); + + // We only catch 1 scenario for Multi Query Attention for now. + // |---> Gather + // Reshape |---> Gather + // |---> Slice + if (output_count != 3) continue; + + // Get the output into node args + output_args.push_back(node.OutputDefs()[0]); + } + + // iterate the children of Reshape node + for (const NodeArg* node_arg : output_args) { + auto shape = node_arg->Shape(); + if (!shape) continue; + + // ??? What is the consumers here ??? --> Reshape + auto consumers = graph.GetConsumerNodes(node_arg->Name()); + size_t consumer_count = consumers.size(); + + // get the tensor rank + int64_t rank = static_cast(shape->dim_size()); + + bool can_fuse = true; + bool first_edge = true; + int64_t split_axis = 0; + int64_t indices_n_dims = -1; + + // TODO: How to catch up the Slice output value + // 2 Gather, and 1 slice... + InlinedVector reshape_outputs; + + InlinedVector> nodes_to_fuse; + int64_t gather_node_count = 0, slice_node_count = 0; + + // find the nodes to be merged + for (auto consumer : consumers) { + int64_t index, axis, dims; + InlinedVector starts, ends, axes, steps; + + bool IsSupportedGatherOps = IsSupportedGather(graph, *consumer, index, axis, dims); + bool IsSupportedSliceOps = IsSupportedSlice(graph, *consumer, starts, ends, axes, steps); + + if ((!consumer || consumer->InputDefs()[0] != node_arg) || + (!IsSupportedGatherOps && !IsSupportedSliceOps)) { + can_fuse = false; + break; + } + + if (IsSupportedGatherOps) { + if (indices_n_dims == -1) { + indices_n_dims = dims; + } else if (indices_n_dims != dims) { + // Not the same number of dimensions (0 or 1) for all scalar indices. + can_fuse = false; + break; + } + + if (axis <0) axis += rank; + + if (first_edge) { + auto dim = shape->dim(static_cast(axis)); + // dim.dim_value() = 73 + if (!utils::HasDimValue(dim)) { + can_fuse = false; + break; + } + split_axis = axis; + first_edge = false; + } else if (axis != split_axis) { + can_fuse = false; + break; + } + + if (index < 0) index += static_cast(consumer_count); + if (index < 0 || index >= static_cast(consumer_count)) { + can_fuse = false; + break; + } + + Node& gather_node = *graph.GetNode(consumer->Index()); + nodes_to_fuse.push_back(gather_node); + NodeArg* gather_output_args = gather_node.MutableOutputDefs()[0]; + reshape_outputs.push_back(gather_output_args); + gather_node_count++; + } + + // check the Slice Ops + if (IsSupportedSliceOps) { + if (axes[0] != axis && !first_edge) { + can_fuse = false; + break; + } + + Node& slice_node = *graph.GetNode(consumer->Index()); + NodeArg* slice_output_args = slice_node.MutableOutputDefs()[0]; + nodes_to_fuse.push_back(slice_node); + reshape_outputs.push_back(slice_output_args); + slice_node_count++; + } + } + + // condition check + if (!can_fuse || gather_node_count != 2 || slice_node_count != 1) continue; + + // generate the split node and merge the kernel + ONNX_NAMESPACE::TypeProto split_output_type; + const ONNX_NAMESPACE::TensorProto_DataType element_type = static_cast( + node_arg->TypeAsProto()->tensor_type().elem_type()); + + split_output_type.mutable_tensor_type()->set_elem_type(element_type); + + for (int64_t i = 0; i < rank; i++) { + if (i == split_axis) + split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1LL); + else { + *(split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()) = shape->dim(static_cast(i)); + } + }; + + InlinedVector split_outputs; + + for (size_t i = 0; i < consumer_count; ++i) { + split_outputs.push_back( + &graph.GetOrCreateNodeArg( + graph.GenerateNodeArgName("fused_split_" + std::to_string(i)), &split_output_type + ) + ); + } + + // how to have multiple output node + // do we need to add the Split [71, 1, 1] information here. + ONNX_NAMESPACE::TensorProto split_initializer_proto; + split_initializer_proto.set_name(graph.GenerateNodeName("fused_Split")); + split_initializer_proto.add_dims(static_cast(1)); + split_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + + auto dim_value = shape->dim(static_cast(split_axis))->dim_value(); + InlinedVector split_value{{dim_value - gather_node_count, 1, 1}}; + split_initializer_proto.set_raw_data(split_value.data(), split_value.size() * sizeof(int64_t)); + NodeArg* split_arg = &graph_utils::AddInitializer(graph, split_initializer_proto); + + Node& split_node = + graph.AddNode(graph.GenerateNodeName("Split"), "Split", "Split for fused Gather-Slice fusion", + {graph.GetNodeArg(node_arg->Name()), split_arg}, reshape_outputs); + + split_node.AddAttribute("axis", split_axis); + // to do here + split_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType()); + + int onnx_opset_version = -1; + if (graph.DomainToVersionMap().find(kOnnxDomain) != graph.DomainToVersionMap().end()) { + onnx_opset_version = graph.DomainToVersionMap().at(kOnnxDomain); + } + + if (onnx_opset_version >= 18) { + split_node.AddAttribute("num_outputs", static_cast(consumer_count)); + } + + for (Node& node_to_fuse : nodes_to_fuse) { + graph_utils::RemoveNodeOutputEdges(graph, node_to_fuse); + graph.RemoveNode(node_to_fuse.Index()); + } + modified = true; + } + + return Status::OK(); +} +} diff --git a/onnxruntime/core/optimizer/gather_slice_fusion.h b/onnxruntime/core/optimizer/gather_slice_fusion.h new file mode 100644 index 0000000000000..63416c5274f0b --- /dev/null +++ b/onnxruntime/core/optimizer/gather_slice_fusion.h @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + + +#pragma once + +#include "core/optimizer/graph_transformer.h" + +namespace onnxruntime { + +/** +@class GatherSliceToSplitFusion +Fuse (2 Gather nodes + 1 Slice) to 1 split node. +*/ + +class GatherSliceToSplitFusion : public GraphTransformer { +private: + bool IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis, + int64_t& indices_n_dims) const; + + bool IsSupportedSlice(const Graph& graph, const Node& node, + InlinedVector& starts, + InlinedVector& ends, + InlinedVector& axes, + InlinedVector& steps) const; + +public: + GatherSliceToSplitFusion(const InlinedHashSet& compatible_execution_providers = {}) noexcept + : GraphTransformer("GatherSliceToSplitFusion", compatible_execution_providers) {} + + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; + +}; +} diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index cd3c49be15aa4..4e939fe3c7b6b 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -37,6 +37,7 @@ #include "core/optimizer/fast_gelu_fusion.h" #include "core/optimizer/free_dim_override_transformer.h" #include "core/optimizer/gather_fusion.h" +#include "core/optimizer/gather_slice_fusion.h" #include "core/optimizer/gelu_approximation.h" #include "core/optimizer/gelu_fusion.h" #include "core/optimizer/gemm_activation_fusion.h" @@ -308,6 +309,7 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); + transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index bf02c1741725f..4aedc354cfed0 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -42,6 +42,7 @@ #include "core/optimizer/expand_elimination.h" #include "core/optimizer/fast_gelu_fusion.h" #include "core/optimizer/gather_fusion.h" +#include "core/optimizer/gather_slice_fusion.h" #include "core/optimizer/gelu_approximation.h" #include "core/optimizer/gelu_fusion.h" #include "core/optimizer/gemm_activation_fusion.h" @@ -7642,5 +7643,83 @@ TEST_F(GraphTransformationTests, GatherToSliceFusion) { } } +TEST_F(GraphTransformationTests, GatherSliceToSplitFusion) { + { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* data_arg = builder.MakeInput({{54}}); + auto* shape_arg = builder.MakeInput({{4}}); + auto* reshape_out = builder.MakeIntermediate({{2, 512, 73, 64}}); + builder.AddNode("Reshape", {data_arg, shape_arg}, {reshape_out}); + + // Create Gather-1 Ops + auto* gather_index_1 = builder.MakeInitializer({}, {static_cast(2)}); + auto* gather_out_1 = builder.MakeIntermediate(); + builder.AddNode("Gather", {reshape_out, gather_index_1}, {gather_out_1}) + .AddAttribute("axis", static_cast(2)); + + // Create Transpose 1-Ops + auto* transpose_out_1 = builder.MakeOutput(); + builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1}) + .AddAttribute("perm", std::vector{0, 2, 1}); + + // Create Gather-2 Ops + auto* gather_index_2 = builder.MakeInitializer({}, {static_cast(1)}); + auto* gather_out_2 = builder.MakeIntermediate(); + builder.AddNode("Gather", {reshape_out, gather_index_2}, {gather_out_2}) + .AddAttribute("axis", static_cast(2)); + + // Create Transpose-2 Ops + auto* transpose_out_2 = builder.MakeOutput(); + builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2}) + .AddAttribute("perm", std::vector{0, 2, 1}); + + // Create Slice Ops + auto* slice_output = builder.MakeIntermediate(); + auto* starts = builder.MakeInitializer({1}, {0}); + auto* ends = builder.MakeInitializer({1}, {-2}); + auto* axes = builder.MakeInitializer({1}, {2}); + auto* steps = builder.MakeInitializer({1}, {1}); + builder.AddNode("Slice", {reshape_out, starts, ends, axes, steps}, {slice_output}); + + // Create Shape-1 Ops + auto* shape_output_1 = builder.MakeOutput(); + builder.AddNode("Shape", {slice_output}, {shape_output_1}); + + // Create Shape-2 Ops + auto* shape_output_2 = builder.MakeOutput(); + builder.AddNode("Shape", {slice_output}, {shape_output_2}); + + // Create Transpose-3 Ops + auto* transpose_out_3 = builder.MakeOutput(); + builder.AddNode("Transpose", {slice_output}, {transpose_out_3}) + .AddAttribute("perm", std::vector{0, 2, 1}); + }; + + auto pre_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 2); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Slice"] == 1); + return Status::OK(); + }; + + auto post_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Slice"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); + + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Split") { + auto& attrs = node.GetAttributes(); + TEST_RETURN_IF_NOT(static_cast(attrs.at("axis").i()) == 2); + } + } + return Status::OK(); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); + }; +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/optimizer/graph_transform_test_builder.cc b/onnxruntime/test/optimizer/graph_transform_test_builder.cc index a5024f510b3cd..1ad45d9d27aad 100644 --- a/onnxruntime/test/optimizer/graph_transform_test_builder.cc +++ b/onnxruntime/test/optimizer/graph_transform_test_builder.cc @@ -15,7 +15,7 @@ #include "test/util/include/inference_session_wrapper.h" // enable to dump model for debugging -#define SAVE_TEST_GRAPH 0 +#define SAVE_TEST_GRAPH 1 namespace onnxruntime { namespace test { diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc index 894fe3b052fb2..0b68dc65e41cd 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc +++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc @@ -24,6 +24,7 @@ #include "core/optimizer/fast_gelu_fusion.h" #include "core/optimizer/free_dim_override_transformer.h" #include "core/optimizer/gather_fusion.h" +#include "core/optimizer/gather_slice_fusion.h" #include "core/optimizer/gelu_approximation.h" #include "core/optimizer/gelu_fusion.h" #include "core/optimizer/gemm_activation_fusion.h" @@ -140,6 +141,7 @@ std::vector> GeneratePreTrainingTransformers( transformers.emplace_back(std::make_unique(compatible_eps)); transformers.emplace_back(std::make_unique(compatible_eps)); transformers.emplace_back(std::make_unique(compatible_eps)); + transformers.emplace_back(std::make_unique(compatible_eps)); // If a model with Q, DQ nodes is being used for the purpose of training, it must be for // Quantization Aware Training. So, replace QDQ nodes with FakeQuant. transformers.emplace_back(std::make_unique(compatible_eps)); From 2edb4b31d8984a4da00397a37dcb75f2ab5c2cba Mon Sep 17 00:00:00 2001 From: ruiren Date: Mon, 22 Jan 2024 08:16:17 +0000 Subject: [PATCH 02/11] lint --- .../core/optimizer/gather_slice_fusion.cc | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/onnxruntime/core/optimizer/gather_slice_fusion.cc b/onnxruntime/core/optimizer/gather_slice_fusion.cc index 0a94ec2b18229..309d714ffbe79 100644 --- a/onnxruntime/core/optimizer/gather_slice_fusion.cc +++ b/onnxruntime/core/optimizer/gather_slice_fusion.cc @@ -12,7 +12,6 @@ namespace onnxruntime { bool GatherSliceToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis, int64_t& indices_n_dims) const { - if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gather", {1, 11, 13}) || !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) { return false; @@ -60,7 +59,7 @@ bool GatherSliceToSplitFusion::IsSupportedSlice(const Graph& graph, const Node& int onnx_opset_version = -1; if (graph.DomainToVersionMap().find(kOnnxDomain) != graph.DomainToVersionMap().end()) { onnx_opset_version = graph.DomainToVersionMap().at(kOnnxDomain); - }; + } // If Slice op of opset version 1 if (onnx_opset_version == 1) { @@ -73,7 +72,7 @@ bool GatherSliceToSplitFusion::IsSupportedSlice(const Graph& graph, const Node& if (graph_utils::GetRepeatedNodeAttributeValues(node, "axes", axes) && (axes.size() != starts.size())) { return false; } - }; + } // If Slice op of opset version >= 10 if (onnx_opset_version >= 10) { @@ -128,7 +127,6 @@ bool GatherSliceToSplitFusion::IsSupportedSlice(const Graph& graph, const Node& return false; } - // TODO: what does this mean ? if (axes_init->dims_size() != 1 || static_cast(axes_init->dims().Get(0)) != starts.size()) { return false; } @@ -210,7 +208,6 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra int64_t split_axis = 0; int64_t indices_n_dims = -1; - // TODO: How to catch up the Slice output value // 2 Gather, and 1 slice... InlinedVector reshape_outputs; @@ -298,9 +295,10 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra if (i == split_axis) split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1LL); else { - *(split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()) = shape->dim(static_cast(i)); + *(split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()) + = shape->dim(static_cast(i)); } - }; + } InlinedVector split_outputs; @@ -319,8 +317,9 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra split_initializer_proto.add_dims(static_cast(1)); split_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); - auto dim_value = shape->dim(static_cast(split_axis))->dim_value(); - InlinedVector split_value{{dim_value - gather_node_count, 1, 1}}; + auto dim_value = shape->dim(static_cast(split_axis)).dim_value(); + int64_t slice_dim = static_cast(dim_value - gather_node_count); + InlinedVector split_value{{slice_dim, 1, 1}}; split_initializer_proto.set_raw_data(split_value.data(), split_value.size() * sizeof(int64_t)); NodeArg* split_arg = &graph_utils::AddInitializer(graph, split_initializer_proto); From f5e7c81352912efd05a47e1ea117436ab3757287 Mon Sep 17 00:00:00 2001 From: ruiren Date: Wed, 24 Jan 2024 04:17:18 +0000 Subject: [PATCH 03/11] update the unitest --- .../core/optimizer/gather_slice_fusion.cc | 22 +++++++++---------- .../test/optimizer/graph_transform_test.cc | 22 +++++++++++-------- 2 files changed, 24 insertions(+), 20 deletions(-) diff --git a/onnxruntime/core/optimizer/gather_slice_fusion.cc b/onnxruntime/core/optimizer/gather_slice_fusion.cc index 309d714ffbe79..b783f953827d7 100644 --- a/onnxruntime/core/optimizer/gather_slice_fusion.cc +++ b/onnxruntime/core/optimizer/gather_slice_fusion.cc @@ -185,9 +185,11 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra // |---> Gather // Reshape |---> Gather // |---> Slice - if (output_count != 3) continue; + // |... or (other ops) // Get the output into node args + if (output_count < 3) continue; + output_args.push_back(node.OutputDefs()[0]); } @@ -196,7 +198,6 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra auto shape = node_arg->Shape(); if (!shape) continue; - // ??? What is the consumers here ??? --> Reshape auto consumers = graph.GetConsumerNodes(node_arg->Name()); size_t consumer_count = consumers.size(); @@ -208,8 +209,9 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra int64_t split_axis = 0; int64_t indices_n_dims = -1; - // 2 Gather, and 1 slice... - InlinedVector reshape_outputs; + // Fuse 2 Gathers and 1 slice to Split + // Get those outputs as Split outputs + InlinedVector split_outputs; InlinedVector> nodes_to_fuse; int64_t gather_node_count = 0, slice_node_count = 0; @@ -224,7 +226,6 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra if ((!consumer || consumer->InputDefs()[0] != node_arg) || (!IsSupportedGatherOps && !IsSupportedSliceOps)) { - can_fuse = false; break; } @@ -262,7 +263,7 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra Node& gather_node = *graph.GetNode(consumer->Index()); nodes_to_fuse.push_back(gather_node); NodeArg* gather_output_args = gather_node.MutableOutputDefs()[0]; - reshape_outputs.push_back(gather_output_args); + split_outputs.push_back(gather_output_args); gather_node_count++; } @@ -276,7 +277,7 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra Node& slice_node = *graph.GetNode(consumer->Index()); NodeArg* slice_output_args = slice_node.MutableOutputDefs()[0]; nodes_to_fuse.push_back(slice_node); - reshape_outputs.push_back(slice_output_args); + split_outputs.push_back(slice_output_args); slice_node_count++; } } @@ -300,7 +301,6 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra } } - InlinedVector split_outputs; for (size_t i = 0; i < consumer_count; ++i) { split_outputs.push_back( @@ -310,8 +310,7 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra ); } - // how to have multiple output node - // do we need to add the Split [71, 1, 1] information here. + // Generate the Split Node ONNX_NAMESPACE::TensorProto split_initializer_proto; split_initializer_proto.set_name(graph.GenerateNodeName("fused_Split")); split_initializer_proto.add_dims(static_cast(1)); @@ -323,9 +322,10 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra split_initializer_proto.set_raw_data(split_value.data(), split_value.size() * sizeof(int64_t)); NodeArg* split_arg = &graph_utils::AddInitializer(graph, split_initializer_proto); + Node& split_node = graph.AddNode(graph.GenerateNodeName("Split"), "Split", "Split for fused Gather-Slice fusion", - {graph.GetNodeArg(node_arg->Name()), split_arg}, reshape_outputs); + split_inputs, split_outputs); split_node.AddAttribute("axis", split_axis); // to do here diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 4aedc354cfed0..fbfa43e3d3ee7 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -7647,31 +7647,35 @@ TEST_F(GraphTransformationTests, GatherSliceToSplitFusion) { { auto build_test_case = [&](ModelTestBuilder& builder) { auto* data_arg = builder.MakeInput({{54}}); - auto* shape_arg = builder.MakeInput({{4}}); + auto* reshape_arg = builder.MakeInput({{4}}); auto* reshape_out = builder.MakeIntermediate({{2, 512, 73, 64}}); - builder.AddNode("Reshape", {data_arg, shape_arg}, {reshape_out}); + builder.AddNode("Reshape", {data_arg, reshape_arg}, {reshape_out}); + + // Create Shape-0 Ops + auto* shape_output_0 = builder.MakeOutput(); + builder.AddNode("Shape", {reshape_out}, {shape_output_0}); // Create Gather-1 Ops - auto* gather_index_1 = builder.MakeInitializer({}, {static_cast(2)}); - auto* gather_out_1 = builder.MakeIntermediate(); + auto* gather_index_1 = builder.MakeInitializer({}, {static_cast(-2)}); + auto* gather_out_1 = builder.MakeIntermediate({{2, 512, 1, 64}}); builder.AddNode("Gather", {reshape_out, gather_index_1}, {gather_out_1}) .AddAttribute("axis", static_cast(2)); // Create Transpose 1-Ops auto* transpose_out_1 = builder.MakeOutput(); builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1}) - .AddAttribute("perm", std::vector{0, 2, 1}); + .AddAttribute("perm", std::vector{0, 2, 1, 3}); // Create Gather-2 Ops - auto* gather_index_2 = builder.MakeInitializer({}, {static_cast(1)}); - auto* gather_out_2 = builder.MakeIntermediate(); + auto* gather_index_2 = builder.MakeInitializer({}, {static_cast(-1)}); + auto* gather_out_2 = builder.MakeIntermediate({{2, 512, 1, 64}}); builder.AddNode("Gather", {reshape_out, gather_index_2}, {gather_out_2}) .AddAttribute("axis", static_cast(2)); // Create Transpose-2 Ops auto* transpose_out_2 = builder.MakeOutput(); builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2}) - .AddAttribute("perm", std::vector{0, 2, 1}); + .AddAttribute("perm", std::vector{0, 2, 1, 3}); // Create Slice Ops auto* slice_output = builder.MakeIntermediate(); @@ -7692,7 +7696,7 @@ TEST_F(GraphTransformationTests, GatherSliceToSplitFusion) { // Create Transpose-3 Ops auto* transpose_out_3 = builder.MakeOutput(); builder.AddNode("Transpose", {slice_output}, {transpose_out_3}) - .AddAttribute("perm", std::vector{0, 2, 1}); + .AddAttribute("perm", std::vector{0, 2, 1, 3}); }; auto pre_graph_checker = [&](Graph& graph) { From fe3304963f8ec82f4d968d1ad1f6e45cb2eeeea1 Mon Sep 17 00:00:00 2001 From: ruiren Date: Wed, 24 Jan 2024 16:29:07 +0000 Subject: [PATCH 04/11] fix type and test on build --- onnxruntime/core/optimizer/gather_slice_fusion.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/onnxruntime/core/optimizer/gather_slice_fusion.cc b/onnxruntime/core/optimizer/gather_slice_fusion.cc index b783f953827d7..10b69761c2047 100644 --- a/onnxruntime/core/optimizer/gather_slice_fusion.cc +++ b/onnxruntime/core/optimizer/gather_slice_fusion.cc @@ -322,10 +322,9 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra split_initializer_proto.set_raw_data(split_value.data(), split_value.size() * sizeof(int64_t)); NodeArg* split_arg = &graph_utils::AddInitializer(graph, split_initializer_proto); - Node& split_node = graph.AddNode(graph.GenerateNodeName("Split"), "Split", "Split for fused Gather-Slice fusion", - split_inputs, split_outputs); + {graph.GetNodeArg(node_arg->Name()), split_arg}, split_outputs); split_node.AddAttribute("axis", split_axis); // to do here From 79c37ba25a2de8aeb9dc3bb4f100167d755b91b8 Mon Sep 17 00:00:00 2001 From: ruiren Date: Wed, 24 Jan 2024 17:43:42 +0000 Subject: [PATCH 05/11] typo --- onnxruntime/core/optimizer/gather_slice_fusion.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/optimizer/gather_slice_fusion.cc b/onnxruntime/core/optimizer/gather_slice_fusion.cc index 10b69761c2047..3dfe8b8f4d455 100644 --- a/onnxruntime/core/optimizer/gather_slice_fusion.cc +++ b/onnxruntime/core/optimizer/gather_slice_fusion.cc @@ -301,9 +301,10 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra } } + InlinedVector split_output_types; for (size_t i = 0; i < consumer_count; ++i) { - split_outputs.push_back( + split_output_types.push_back( &graph.GetOrCreateNodeArg( graph.GenerateNodeArgName("fused_split_" + std::to_string(i)), &split_output_type ) From f47d034b60ec1529c04d2f106f3a7ca1782dc592 Mon Sep 17 00:00:00 2001 From: ruiren Date: Mon, 29 Jan 2024 06:06:01 +0000 Subject: [PATCH 06/11] fix buffer size issue and adapt topological sort --- onnxruntime/core/optimizer/gather_slice_fusion.cc | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/optimizer/gather_slice_fusion.cc b/onnxruntime/core/optimizer/gather_slice_fusion.cc index 3dfe8b8f4d455..fd7690737b77d 100644 --- a/onnxruntime/core/optimizer/gather_slice_fusion.cc +++ b/onnxruntime/core/optimizer/gather_slice_fusion.cc @@ -211,7 +211,7 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra // Fuse 2 Gathers and 1 slice to Split // Get those outputs as Split outputs - InlinedVector split_outputs; + InlinedVector split_outputs(3); InlinedVector> nodes_to_fuse; int64_t gather_node_count = 0, slice_node_count = 0; @@ -263,8 +263,9 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra Node& gather_node = *graph.GetNode(consumer->Index()); nodes_to_fuse.push_back(gather_node); NodeArg* gather_output_args = gather_node.MutableOutputDefs()[0]; - split_outputs.push_back(gather_output_args); gather_node_count++; + split_outputs[gather_node_count] = gather_output_args; + } // check the Slice Ops @@ -277,7 +278,7 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra Node& slice_node = *graph.GetNode(consumer->Index()); NodeArg* slice_output_args = slice_node.MutableOutputDefs()[0]; nodes_to_fuse.push_back(slice_node); - split_outputs.push_back(slice_output_args); + split_outputs[slice_node_count] = slice_output_args; slice_node_count++; } } @@ -314,7 +315,7 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra // Generate the Split Node ONNX_NAMESPACE::TensorProto split_initializer_proto; split_initializer_proto.set_name(graph.GenerateNodeName("fused_Split")); - split_initializer_proto.add_dims(static_cast(1)); + split_initializer_proto.add_dims(static_cast(3)); split_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); auto dim_value = shape->dim(static_cast(split_axis)).dim_value(); From f550f7aa407ed533ba0b5a64754375ccfdbe5f66 Mon Sep 17 00:00:00 2001 From: ruiren Date: Mon, 29 Jan 2024 06:26:42 +0000 Subject: [PATCH 07/11] update unitest --- onnxruntime/test/optimizer/graph_transform_test.cc | 4 ---- 1 file changed, 4 deletions(-) diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index fbfa43e3d3ee7..bd125b7318335 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -7651,10 +7651,6 @@ TEST_F(GraphTransformationTests, GatherSliceToSplitFusion) { auto* reshape_out = builder.MakeIntermediate({{2, 512, 73, 64}}); builder.AddNode("Reshape", {data_arg, reshape_arg}, {reshape_out}); - // Create Shape-0 Ops - auto* shape_output_0 = builder.MakeOutput(); - builder.AddNode("Shape", {reshape_out}, {shape_output_0}); - // Create Gather-1 Ops auto* gather_index_1 = builder.MakeInitializer({}, {static_cast(-2)}); auto* gather_out_1 = builder.MakeIntermediate({{2, 512, 1, 64}}); From abf2994b9a57db85cabbc4d2ef7e38d5f5a6de07 Mon Sep 17 00:00:00 2001 From: ruiren Date: Wed, 7 Feb 2024 05:46:27 +0000 Subject: [PATCH 08/11] update the test file and lint format --- .../core/optimizer/gather_slice_fusion.cc | 539 +++++++++--------- .../core/optimizer/gather_slice_fusion.h | 28 +- .../test/optimizer/graph_transform_test.cc | 78 ++- .../optimizer/graph_transform_test_builder.cc | 2 +- 4 files changed, 349 insertions(+), 298 deletions(-) diff --git a/onnxruntime/core/optimizer/gather_slice_fusion.cc b/onnxruntime/core/optimizer/gather_slice_fusion.cc index fd7690737b77d..51b0ab8e937d4 100644 --- a/onnxruntime/core/optimizer/gather_slice_fusion.cc +++ b/onnxruntime/core/optimizer/gather_slice_fusion.cc @@ -1,149 +1,146 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. - #include "core/optimizer/gather_slice_fusion.h" #include "core/graph/graph_utils.h" #include "core/optimizer/initializer.h" #include "core/optimizer/utils.h" - namespace onnxruntime { bool GatherSliceToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, - int64_t& axis, int64_t& indices_n_dims) const { - if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gather", {1, 11, 13}) || - !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) { - return false; - } + int64_t& axis, int64_t& indices_n_dims) const { + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gather", {1, 11, 13}) || + !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) { + return false; + } - const NodeArg& input_arg = *(node.InputDefs()[1]); + const NodeArg& input_arg = *(node.InputDefs()[1]); - if (!optimizer_utils::IsScalar(input_arg)) return false; + if (!optimizer_utils::IsScalar(input_arg)) return false; - const ONNX_NAMESPACE::TensorProto* indices_init = graph_utils::GetConstantInitializer(graph, input_arg.Name()); + const ONNX_NAMESPACE::TensorProto* indices_init = graph_utils::GetConstantInitializer(graph, input_arg.Name()); - if (!indices_init) return false; + if (!indices_init) return false; - if (indices_init->data_type() != ONNX_NAMESPACE::TensorProto::INT64) return false; + if (indices_init->data_type() != ONNX_NAMESPACE::TensorProto::INT64) return false; - // get the index value - Initializer init_const(*indices_init, graph.ModelPath()); - index = *(init_const.data()); + // get the index value + Initializer init_const(*indices_init, graph.ModelPath()); + index = *(init_const.data()); - // get attributes value - axis = 0; - auto& attrs = node.GetAttributes(); - if (attrs.find("axis") != attrs.end()) { - auto& axis_attr = attrs.at("axis"); - if (utils::HasInt(axis_attr)) axis = axis_attr.i(); - } + // get attributes value + axis = 0; + auto& attrs = node.GetAttributes(); + if (attrs.find("axis") != attrs.end()) { + auto& axis_attr = attrs.at("axis"); + if (utils::HasInt(axis_attr)) axis = axis_attr.i(); + } - indices_n_dims = indices_init->dims_size(); - return true; + indices_n_dims = indices_init->dims_size(); + return true; } - bool GatherSliceToSplitFusion::IsSupportedSlice(const Graph& graph, const Node& node, InlinedVector& starts, InlinedVector& ends, InlinedVector& axes, InlinedVector& steps) const { - // check the version of Slice ops - if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Slice", {1, 10, 11, 13}) || - !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) { - return false; + // check the version of Slice ops + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Slice", {1, 10, 11, 13}) || + !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) { + return false; + } + + // get the opset version + int onnx_opset_version = -1; + if (graph.DomainToVersionMap().find(kOnnxDomain) != graph.DomainToVersionMap().end()) { + onnx_opset_version = graph.DomainToVersionMap().at(kOnnxDomain); + } + + // If Slice op of opset version 1 + if (onnx_opset_version == 1) { + if (!graph_utils::GetRepeatedNodeAttributeValues(node, "starts", starts) || + !graph_utils::GetRepeatedNodeAttributeValues(node, "ends", ends) || + starts.size() != ends.size()) { + return false; } - // get the opset version - int onnx_opset_version = -1; - if (graph.DomainToVersionMap().find(kOnnxDomain) != graph.DomainToVersionMap().end()) { - onnx_opset_version = graph.DomainToVersionMap().at(kOnnxDomain); + if (graph_utils::GetRepeatedNodeAttributeValues(node, "axes", axes) && (axes.size() != starts.size())) { + return false; } - - // If Slice op of opset version 1 - if (onnx_opset_version == 1) { - if (!graph_utils::GetRepeatedNodeAttributeValues(node, "starts", starts) || - !graph_utils::GetRepeatedNodeAttributeValues(node, "ends", ends) || - starts.size() != ends.size()) { - return false; - } - - if (graph_utils::GetRepeatedNodeAttributeValues(node, "axes", axes) && (axes.size() != starts.size())) { - return false; - } + } + + // If Slice op of opset version >= 10 + if (onnx_opset_version >= 10) { + // node inputs include: starts - ends - axes - steps + + // return a pointer to the corresponding NodeArg if input of the node at the index exists + auto get_input_if_exists = [&node](size_t input_index) -> const NodeArg* { + const auto& input_defs = node.InputDefs(); + const NodeArg* input = (input_defs.size() > input_index) ? input_defs[input_index] : nullptr; + return (input == nullptr || !input->Exists()) ? nullptr : input; + }; + + // return a pointer to the initializer if it is constant; otherwise, a nullptr + auto get_initializer_if_constant = + [&graph, get_input_if_exists](size_t input_index) -> const ONNX_NAMESPACE::TensorProto* { + const NodeArg* input = get_input_if_exists(input_index); + return input ? graph_utils::GetConstantInitializer(graph, input->Name()) : nullptr; + }; + + // return the initialization data if it is constant + auto get_initializer_data = + [&graph](const ONNX_NAMESPACE::TensorProto* slice_initializer) -> InlinedVector { + Initializer init(*slice_initializer, graph.ModelPath()); + if (slice_initializer->data_type() == ONNX_NAMESPACE::TensorProto::INT32) { + int32_t* init_data = init.data(); + return InlinedVector(init_data, init_data + init.size()); + } + + if (slice_initializer->data_type() == ONNX_NAMESPACE::TensorProto::INT64) { + int64_t* init_data = init.data(); + return InlinedVector(init_data, init_data + init.size()); + } + return {}; + }; + + // starts and ends inputs have to exist, be constants and be of the same size. + const ONNX_NAMESPACE::TensorProto* starts_init = get_initializer_if_constant(1); + const ONNX_NAMESPACE::TensorProto* ends_init = get_initializer_if_constant(2); + const ONNX_NAMESPACE::TensorProto* axes_init = get_initializer_if_constant(3); + const ONNX_NAMESPACE::TensorProto* steps_init = get_initializer_if_constant(4); + + if (!starts_init || !ends_init || !axes_init || !steps_init) { + return false; } - // If Slice op of opset version >= 10 - if (onnx_opset_version >= 10) { - // node inputs include: starts - ends - axes - steps - - // return a pointer to the corresponding NodeArg if input of the node at the index exists - auto get_input_if_exists = [&node](size_t input_index) -> const NodeArg* { - const auto& input_defs = node.InputDefs(); - const NodeArg* input = (input_defs.size() > input_index) ? input_defs[input_index] : nullptr; - return (input == nullptr || !input->Exists()) ? nullptr : input; - }; - - // return a pointer to the initializer if it is constant; otherwise, a nullptr - auto get_initializer_if_constant = - [&graph, get_input_if_exists](size_t input_index) -> const ONNX_NAMESPACE::TensorProto* { - const NodeArg* input = get_input_if_exists(input_index); - return input ? graph_utils::GetConstantInitializer(graph, input->Name()) : nullptr; - }; - - // return the initialization data if it is constant - auto get_initializer_data = - [&graph](const ONNX_NAMESPACE::TensorProto* slice_initializer) -> InlinedVector { - Initializer init(*slice_initializer, graph.ModelPath()); - if (slice_initializer->data_type() == ONNX_NAMESPACE::TensorProto::INT32) { - int32_t* init_data = init.data(); - return InlinedVector(init_data, init_data + init.size()); - } - - if (slice_initializer->data_type() == ONNX_NAMESPACE::TensorProto::INT64) { - int64_t* init_data = init.data(); - return InlinedVector(init_data, init_data + init.size()); - } - return {}; - }; - - // starts and ends inputs have to exist, be constants and be of the same size. - const ONNX_NAMESPACE::TensorProto* starts_init = get_initializer_if_constant(1); - const ONNX_NAMESPACE::TensorProto* ends_init = get_initializer_if_constant(2); - const ONNX_NAMESPACE::TensorProto* axes_init = get_initializer_if_constant(3); - const ONNX_NAMESPACE::TensorProto* steps_init = get_initializer_if_constant(4); - - if (!starts_init || !ends_init || !axes_init || !steps_init) { - return false; - } - - starts = get_initializer_data(starts_init); - ends = get_initializer_data(ends_init); - axes = get_initializer_data(axes_init); - steps = get_initializer_data(steps_init); + starts = get_initializer_data(starts_init); + ends = get_initializer_data(ends_init); + axes = get_initializer_data(axes_init); + steps = get_initializer_data(steps_init); - if (starts.size() == 0 || ends.size() == 0 || starts.size() != ends.size()) { - return false; - } + if (starts.size() == 0 || ends.size() == 0 || starts.size() != ends.size()) { + return false; + } - if (axes_init->dims_size() != 1 || static_cast(axes_init->dims().Get(0)) != starts.size()) { - return false; - } + if (axes_init->dims_size() != 1 || static_cast(axes_init->dims().Get(0)) != starts.size()) { + return false; + } - // if steps exists, it should be constant and all value should be 1 - if (steps.size() != starts.size()) { - return false; - } + // if steps exists, it should be constant and all value should be 1 + if (steps.size() != starts.size()) { + return false; + } - for (int64_t step : steps) { - if (step != 1) { - return false; - } - } + for (int64_t step : steps) { + if (step != 1) { + return false; + } } + } - return true; + return true; } /* @@ -159,195 +156,191 @@ So that we can use one kernel to finish the job. */ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, - const logging::Logger& logger) const { - GraphViewer graph_viewer(graph); + const logging::Logger& logger) const { + GraphViewer graph_viewer(graph); - const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); - InlinedVector output_args; + InlinedVector output_args; - // Iterate the topological order and get Reshape ops - for (auto node_index : node_topology_list) { - auto* p_node = graph.GetNode(node_index); + // Iterate the topological order and get Reshape ops + for (auto node_index : node_topology_list) { + auto* p_node = graph.GetNode(node_index); - if (p_node == nullptr) continue; + if (p_node == nullptr) continue; - Node& node = *p_node; + Node& node = *p_node; - ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); + ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); - // Currently only catch after Reshape ops, optimize in the future - if (node.OpType() != "Reshape") continue; + // Currently only catch after Reshape ops, optimize in the future + if (node.OpType() != "Reshape") continue; - size_t output_count = node.GetOutputEdgesCount(); + size_t output_count = node.GetOutputEdgesCount(); - // We only catch 1 scenario for Multi Query Attention for now. - // |---> Gather - // Reshape |---> Gather - // |---> Slice - // |... or (other ops) + // We only catch 1 scenario for Multi Query Attention for now. + // |---> Gather + // Reshape |---> Gather + // |---> Slice + // |... or (other ops) - // Get the output into node args - if (output_count < 3) continue; + // Get the output into node args + if (output_count < 3) continue; - output_args.push_back(node.OutputDefs()[0]); - } + output_args.push_back(node.OutputDefs()[0]); + } - // iterate the children of Reshape node - for (const NodeArg* node_arg : output_args) { - auto shape = node_arg->Shape(); - if (!shape) continue; - - auto consumers = graph.GetConsumerNodes(node_arg->Name()); - size_t consumer_count = consumers.size(); - - // get the tensor rank - int64_t rank = static_cast(shape->dim_size()); - - bool can_fuse = true; - bool first_edge = true; - int64_t split_axis = 0; - int64_t indices_n_dims = -1; - - // Fuse 2 Gathers and 1 slice to Split - // Get those outputs as Split outputs - InlinedVector split_outputs(3); - - InlinedVector> nodes_to_fuse; - int64_t gather_node_count = 0, slice_node_count = 0; - - // find the nodes to be merged - for (auto consumer : consumers) { - int64_t index, axis, dims; - InlinedVector starts, ends, axes, steps; - - bool IsSupportedGatherOps = IsSupportedGather(graph, *consumer, index, axis, dims); - bool IsSupportedSliceOps = IsSupportedSlice(graph, *consumer, starts, ends, axes, steps); - - if ((!consumer || consumer->InputDefs()[0] != node_arg) || - (!IsSupportedGatherOps && !IsSupportedSliceOps)) { - break; - } - - if (IsSupportedGatherOps) { - if (indices_n_dims == -1) { - indices_n_dims = dims; - } else if (indices_n_dims != dims) { - // Not the same number of dimensions (0 or 1) for all scalar indices. - can_fuse = false; - break; - } - - if (axis <0) axis += rank; - - if (first_edge) { - auto dim = shape->dim(static_cast(axis)); - // dim.dim_value() = 73 - if (!utils::HasDimValue(dim)) { - can_fuse = false; - break; - } - split_axis = axis; - first_edge = false; - } else if (axis != split_axis) { - can_fuse = false; - break; - } - - if (index < 0) index += static_cast(consumer_count); - if (index < 0 || index >= static_cast(consumer_count)) { - can_fuse = false; - break; - } - - Node& gather_node = *graph.GetNode(consumer->Index()); - nodes_to_fuse.push_back(gather_node); - NodeArg* gather_output_args = gather_node.MutableOutputDefs()[0]; - gather_node_count++; - split_outputs[gather_node_count] = gather_output_args; - - } - - // check the Slice Ops - if (IsSupportedSliceOps) { - if (axes[0] != axis && !first_edge) { - can_fuse = false; - break; - } - - Node& slice_node = *graph.GetNode(consumer->Index()); - NodeArg* slice_output_args = slice_node.MutableOutputDefs()[0]; - nodes_to_fuse.push_back(slice_node); - split_outputs[slice_node_count] = slice_output_args; - slice_node_count++; - } - } + // iterate the children of Reshape node + for (const NodeArg* node_arg : output_args) { + auto shape = node_arg->Shape(); + if (!shape) continue; - // condition check - if (!can_fuse || gather_node_count != 2 || slice_node_count != 1) continue; + auto consumers = graph.GetConsumerNodes(node_arg->Name()); + size_t consumer_count = consumers.size(); - // generate the split node and merge the kernel - ONNX_NAMESPACE::TypeProto split_output_type; - const ONNX_NAMESPACE::TensorProto_DataType element_type = static_cast( - node_arg->TypeAsProto()->tensor_type().elem_type()); + // get the tensor rank + int64_t rank = static_cast(shape->dim_size()); - split_output_type.mutable_tensor_type()->set_elem_type(element_type); + bool can_fuse = true; + bool first_edge = true; + int64_t split_axis = 0; + int64_t indices_n_dims = -1; - for (int64_t i = 0; i < rank; i++) { - if (i == split_axis) - split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1LL); - else { - *(split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()) - = shape->dim(static_cast(i)); - } - } + // Fuse 2 Gathers and 1 slice to Split + // Get those outputs as Split outputs + InlinedVector split_outputs(3); + + InlinedVector> nodes_to_fuse; + int64_t gather_node_count = 0, slice_node_count = 0; - InlinedVector split_output_types; + // find the nodes to be merged + for (auto consumer : consumers) { + int64_t index, axis, dims; + InlinedVector starts, ends, axes, steps; - for (size_t i = 0; i < consumer_count; ++i) { - split_output_types.push_back( - &graph.GetOrCreateNodeArg( - graph.GenerateNodeArgName("fused_split_" + std::to_string(i)), &split_output_type - ) - ); + bool IsSupportedGatherOps = IsSupportedGather(graph, *consumer, index, axis, dims); + bool IsSupportedSliceOps = IsSupportedSlice(graph, *consumer, starts, ends, axes, steps); + + if ((!consumer || consumer->InputDefs()[0] != node_arg) || + (!IsSupportedGatherOps && !IsSupportedSliceOps)) { + break; + } + + if (IsSupportedGatherOps) { + if (indices_n_dims == -1) { + indices_n_dims = dims; + } else if (indices_n_dims != dims) { + // Not the same number of dimensions (0 or 1) for all scalar indices. + can_fuse = false; + break; } - // Generate the Split Node - ONNX_NAMESPACE::TensorProto split_initializer_proto; - split_initializer_proto.set_name(graph.GenerateNodeName("fused_Split")); - split_initializer_proto.add_dims(static_cast(3)); - split_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); - - auto dim_value = shape->dim(static_cast(split_axis)).dim_value(); - int64_t slice_dim = static_cast(dim_value - gather_node_count); - InlinedVector split_value{{slice_dim, 1, 1}}; - split_initializer_proto.set_raw_data(split_value.data(), split_value.size() * sizeof(int64_t)); - NodeArg* split_arg = &graph_utils::AddInitializer(graph, split_initializer_proto); - - Node& split_node = - graph.AddNode(graph.GenerateNodeName("Split"), "Split", "Split for fused Gather-Slice fusion", - {graph.GetNodeArg(node_arg->Name()), split_arg}, split_outputs); - - split_node.AddAttribute("axis", split_axis); - // to do here - split_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType()); - - int onnx_opset_version = -1; - if (graph.DomainToVersionMap().find(kOnnxDomain) != graph.DomainToVersionMap().end()) { - onnx_opset_version = graph.DomainToVersionMap().at(kOnnxDomain); + if (axis < 0) axis += rank; + + if (first_edge) { + auto dim = shape->dim(static_cast(axis)); + // dim.dim_value() = 73 + if (!utils::HasDimValue(dim)) { + can_fuse = false; + break; + } + split_axis = axis; + first_edge = false; + } else if (axis != split_axis) { + can_fuse = false; + break; } - if (onnx_opset_version >= 18) { - split_node.AddAttribute("num_outputs", static_cast(consumer_count)); + if (index < 0) index += static_cast(consumer_count); + if (index < 0 || index >= static_cast(consumer_count)) { + can_fuse = false; + break; } - for (Node& node_to_fuse : nodes_to_fuse) { - graph_utils::RemoveNodeOutputEdges(graph, node_to_fuse); - graph.RemoveNode(node_to_fuse.Index()); + Node& gather_node = *graph.GetNode(consumer->Index()); + nodes_to_fuse.push_back(gather_node); + NodeArg* gather_output_args = gather_node.MutableOutputDefs()[0]; + gather_node_count++; + split_outputs[gather_node_count] = gather_output_args; + } + + // check the Slice Ops + if (IsSupportedSliceOps) { + if (axes[0] != axis && !first_edge) { + can_fuse = false; + break; } - modified = true; + + Node& slice_node = *graph.GetNode(consumer->Index()); + NodeArg* slice_output_args = slice_node.MutableOutputDefs()[0]; + nodes_to_fuse.push_back(slice_node); + split_outputs[slice_node_count] = slice_output_args; + slice_node_count++; + } } - return Status::OK(); -} + // condition check + if (!can_fuse || gather_node_count != 2 || slice_node_count != 1) continue; + + // generate the split node and merge the kernel + ONNX_NAMESPACE::TypeProto split_output_type; + const ONNX_NAMESPACE::TensorProto_DataType element_type = static_cast( + node_arg->TypeAsProto()->tensor_type().elem_type()); + + split_output_type.mutable_tensor_type()->set_elem_type(element_type); + + for (int64_t i = 0; i < rank; i++) { + if (i == split_axis) + split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1LL); + else { + *(split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()) = shape->dim(static_cast(i)); + } + } + + InlinedVector split_output_types; + + for (size_t i = 0; i < consumer_count; ++i) { + split_output_types.push_back( + &graph.GetOrCreateNodeArg( + graph.GenerateNodeArgName("fused_split_" + std::to_string(i)), &split_output_type)); + } + + // Generate the Split Node + ONNX_NAMESPACE::TensorProto split_initializer_proto; + split_initializer_proto.set_name(graph.GenerateNodeName("fused_Split")); + split_initializer_proto.add_dims(static_cast(3)); + split_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + + auto dim_value = shape->dim(static_cast(split_axis)).dim_value(); + int64_t slice_dim = static_cast(dim_value - gather_node_count); + InlinedVector split_value{{slice_dim, 1, 1}}; + split_initializer_proto.set_raw_data(split_value.data(), split_value.size() * sizeof(int64_t)); + NodeArg* split_arg = &graph_utils::AddInitializer(graph, split_initializer_proto); + + Node& split_node = + graph.AddNode(graph.GenerateNodeName("Split"), "Split", "Split for fused Gather-Slice fusion", + {graph.GetNodeArg(node_arg->Name()), split_arg}, split_outputs); + + split_node.AddAttribute("axis", split_axis); + + split_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType()); + + int onnx_opset_version = -1; + if (graph.DomainToVersionMap().find(kOnnxDomain) != graph.DomainToVersionMap().end()) { + onnx_opset_version = graph.DomainToVersionMap().at(kOnnxDomain); + } + + if (onnx_opset_version >= 18) { + split_node.AddAttribute("num_outputs", static_cast(consumer_count)); + } + + for (Node& node_to_fuse : nodes_to_fuse) { + graph_utils::RemoveNodeOutputEdges(graph, node_to_fuse); + graph.RemoveNode(node_to_fuse.Index()); + } + modified = true; + } + + return Status::OK(); } +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/gather_slice_fusion.h b/onnxruntime/core/optimizer/gather_slice_fusion.h index 63416c5274f0b..1c5c307efed7f 100644 --- a/onnxruntime/core/optimizer/gather_slice_fusion.h +++ b/onnxruntime/core/optimizer/gather_slice_fusion.h @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. - #pragma once #include "core/optimizer/graph_transformer.h" @@ -14,21 +13,20 @@ Fuse (2 Gather nodes + 1 Slice) to 1 split node. */ class GatherSliceToSplitFusion : public GraphTransformer { -private: - bool IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis, - int64_t& indices_n_dims) const; - - bool IsSupportedSlice(const Graph& graph, const Node& node, - InlinedVector& starts, - InlinedVector& ends, - InlinedVector& axes, - InlinedVector& steps) const; + private: + bool IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis, + int64_t& indices_n_dims) const; -public: - GatherSliceToSplitFusion(const InlinedHashSet& compatible_execution_providers = {}) noexcept - : GraphTransformer("GatherSliceToSplitFusion", compatible_execution_providers) {} + bool IsSupportedSlice(const Graph& graph, const Node& node, + InlinedVector& starts, + InlinedVector& ends, + InlinedVector& axes, + InlinedVector& steps) const; - Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; + public: + GatherSliceToSplitFusion(const InlinedHashSet& compatible_execution_providers = {}) noexcept + : GraphTransformer("GatherSliceToSplitFusion", compatible_execution_providers) {} + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; }; -} +} // namespace onnxruntime diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index bd125b7318335..6b7af531f391e 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -7655,30 +7655,30 @@ TEST_F(GraphTransformationTests, GatherSliceToSplitFusion) { auto* gather_index_1 = builder.MakeInitializer({}, {static_cast(-2)}); auto* gather_out_1 = builder.MakeIntermediate({{2, 512, 1, 64}}); builder.AddNode("Gather", {reshape_out, gather_index_1}, {gather_out_1}) - .AddAttribute("axis", static_cast(2)); + .AddAttribute("axis", static_cast(2)); // Create Transpose 1-Ops auto* transpose_out_1 = builder.MakeOutput(); builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1}) - .AddAttribute("perm", std::vector{0, 2, 1, 3}); + .AddAttribute("perm", std::vector{0, 2, 1, 3}); // Create Gather-2 Ops auto* gather_index_2 = builder.MakeInitializer({}, {static_cast(-1)}); auto* gather_out_2 = builder.MakeIntermediate({{2, 512, 1, 64}}); builder.AddNode("Gather", {reshape_out, gather_index_2}, {gather_out_2}) - .AddAttribute("axis", static_cast(2)); + .AddAttribute("axis", static_cast(2)); // Create Transpose-2 Ops auto* transpose_out_2 = builder.MakeOutput(); builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2}) - .AddAttribute("perm", std::vector{0, 2, 1, 3}); + .AddAttribute("perm", std::vector{0, 2, 1, 3}); // Create Slice Ops auto* slice_output = builder.MakeIntermediate(); auto* starts = builder.MakeInitializer({1}, {0}); - auto* ends = builder.MakeInitializer({1}, {-2}); - auto* axes = builder.MakeInitializer({1}, {2}); - auto* steps = builder.MakeInitializer({1}, {1}); + auto* ends = builder.MakeInitializer({1}, {-2}); + auto* axes = builder.MakeInitializer({1}, {2}); + auto* steps = builder.MakeInitializer({1}, {1}); builder.AddNode("Slice", {reshape_out, starts, ends, axes, steps}, {slice_output}); // Create Shape-1 Ops @@ -7692,7 +7692,7 @@ TEST_F(GraphTransformationTests, GatherSliceToSplitFusion) { // Create Transpose-3 Ops auto* transpose_out_3 = builder.MakeOutput(); builder.AddNode("Transpose", {slice_output}, {transpose_out_3}) - .AddAttribute("perm", std::vector{0, 2, 1, 3}); + .AddAttribute("perm", std::vector{0, 2, 1, 3}); }; auto pre_graph_checker = [&](Graph& graph) { @@ -7717,7 +7717,67 @@ TEST_F(GraphTransformationTests, GatherSliceToSplitFusion) { std::unique_ptr transformer = std::make_unique(); ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), - TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); + }; +} + +TEST_F(GraphTransformationTests, GatherSliceToSplitFusion_Invalid) { + { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* data_arg = builder.MakeInput({{54}}); + auto* reshape_arg = builder.MakeInput({{4}}); + auto* reshape_out = builder.MakeIntermediate({{2, 512, 73, 64}}); + builder.AddNode("Reshape", {data_arg, reshape_arg}, {reshape_out}); + + // Create Gather-1 Ops + auto* gather_index_1 = builder.MakeInitializer({}, {static_cast(-2)}); + auto* gather_out_1 = builder.MakeIntermediate({{2, 512, 1, 64}}); + builder.AddNode("Gather", {reshape_out, gather_index_1}, {gather_out_1}) + .AddAttribute("axis", static_cast(2)); + + // Create Transpose 1-Ops + auto* transpose_out_1 = builder.MakeOutput(); + builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1}) + .AddAttribute("perm", std::vector{0, 2, 1, 3}); + + // Create Slice Ops + auto* slice_output = builder.MakeIntermediate(); + auto* starts = builder.MakeInitializer({1}, {0}); + auto* ends = builder.MakeInitializer({1}, {-2}); + auto* axes = builder.MakeInitializer({1}, {2}); + auto* steps = builder.MakeInitializer({1}, {1}); + builder.AddNode("Slice", {reshape_out, starts, ends, axes, steps}, {slice_output}); + + // Create Shape-1 Ops + auto* shape_output_1 = builder.MakeOutput(); + builder.AddNode("Shape", {slice_output}, {shape_output_1}); + + // Create Shape-2 Ops + auto* shape_output_2 = builder.MakeOutput(); + builder.AddNode("Shape", {slice_output}, {shape_output_2}); + + // Create Transpose-3 Ops + auto* transpose_out_3 = builder.MakeOutput(); + builder.AddNode("Transpose", {slice_output}, {transpose_out_3}) + .AddAttribute("perm", std::vector{0, 2, 1, 3}); + }; + + auto pre_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 2); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Slice"] == 1); + return Status::OK(); + }; + + auto post_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Slice"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 0); + return Status::OK(); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); }; } diff --git a/onnxruntime/test/optimizer/graph_transform_test_builder.cc b/onnxruntime/test/optimizer/graph_transform_test_builder.cc index 1ad45d9d27aad..a5024f510b3cd 100644 --- a/onnxruntime/test/optimizer/graph_transform_test_builder.cc +++ b/onnxruntime/test/optimizer/graph_transform_test_builder.cc @@ -15,7 +15,7 @@ #include "test/util/include/inference_session_wrapper.h" // enable to dump model for debugging -#define SAVE_TEST_GRAPH 1 +#define SAVE_TEST_GRAPH 0 namespace onnxruntime { namespace test { From 7870789d80c6b67d59f90f681869a93ccb464f35 Mon Sep 17 00:00:00 2001 From: ruiren Date: Wed, 7 Feb 2024 06:44:18 +0000 Subject: [PATCH 09/11] update test and format --- onnxruntime/core/optimizer/gather_slice_fusion.cc | 3 +-- onnxruntime/test/optimizer/graph_transform_test.cc | 6 +++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/onnxruntime/core/optimizer/gather_slice_fusion.cc b/onnxruntime/core/optimizer/gather_slice_fusion.cc index 51b0ab8e937d4..dc76c4a1f1a5b 100644 --- a/onnxruntime/core/optimizer/gather_slice_fusion.cc +++ b/onnxruntime/core/optimizer/gather_slice_fusion.cc @@ -292,9 +292,8 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra for (int64_t i = 0; i < rank; i++) { if (i == split_axis) split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1LL); - else { + else *(split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()) = shape->dim(static_cast(i)); - } } InlinedVector split_output_types; diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 6b7af531f391e..e1fcf835c6043 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -7718,7 +7718,7 @@ TEST_F(GraphTransformationTests, GatherSliceToSplitFusion) { std::unique_ptr transformer = std::make_unique(); ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); - }; + } } TEST_F(GraphTransformationTests, GatherSliceToSplitFusion_Invalid) { @@ -7763,7 +7763,7 @@ TEST_F(GraphTransformationTests, GatherSliceToSplitFusion_Invalid) { }; auto pre_graph_checker = [&](Graph& graph) { - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 2); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 1); TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Slice"] == 1); return Status::OK(); }; @@ -7778,7 +7778,7 @@ TEST_F(GraphTransformationTests, GatherSliceToSplitFusion_Invalid) { std::unique_ptr transformer = std::make_unique(); ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); - }; + } } } // namespace test From 5a7e0c058bd4872651a9f3a8bc8b3a8ca73f69eb Mon Sep 17 00:00:00 2001 From: ruiren Date: Wed, 7 Feb 2024 07:38:05 +0000 Subject: [PATCH 10/11] fix the warning --- onnxruntime/core/optimizer/gather_slice_fusion.cc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/optimizer/gather_slice_fusion.cc b/onnxruntime/core/optimizer/gather_slice_fusion.cc index dc76c4a1f1a5b..0b7b2d53e0728 100644 --- a/onnxruntime/core/optimizer/gather_slice_fusion.cc +++ b/onnxruntime/core/optimizer/gather_slice_fusion.cc @@ -211,7 +211,7 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra InlinedVector split_outputs(3); InlinedVector> nodes_to_fuse; - int64_t gather_node_count = 0, slice_node_count = 0; + size_t gather_node_count = 0, slice_node_count = 0; // find the nodes to be merged for (auto consumer : consumers) { @@ -260,8 +260,7 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra Node& gather_node = *graph.GetNode(consumer->Index()); nodes_to_fuse.push_back(gather_node); NodeArg* gather_output_args = gather_node.MutableOutputDefs()[0]; - gather_node_count++; - split_outputs[gather_node_count] = gather_output_args; + split_outputs[++gather_node_count] = gather_output_args; } // check the Slice Ops From 721a5c772c013c50179ebebf80421f2845577c64 Mon Sep 17 00:00:00 2001 From: rui-ren Date: Wed, 14 Feb 2024 19:13:56 +0000 Subject: [PATCH 11/11] small fix --- onnxruntime/core/optimizer/gather_slice_fusion.cc | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/optimizer/gather_slice_fusion.cc b/onnxruntime/core/optimizer/gather_slice_fusion.cc index 0b7b2d53e0728..21266d356a020 100644 --- a/onnxruntime/core/optimizer/gather_slice_fusion.cc +++ b/onnxruntime/core/optimizer/gather_slice_fusion.cc @@ -211,7 +211,7 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra InlinedVector split_outputs(3); InlinedVector> nodes_to_fuse; - size_t gather_node_count = 0, slice_node_count = 0; + size_t gather_node_count = 2, slice_node_count = 0; // find the nodes to be merged for (auto consumer : consumers) { @@ -260,7 +260,7 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra Node& gather_node = *graph.GetNode(consumer->Index()); nodes_to_fuse.push_back(gather_node); NodeArg* gather_output_args = gather_node.MutableOutputDefs()[0]; - split_outputs[++gather_node_count] = gather_output_args; + split_outputs[gather_node_count--] = gather_output_args; } // check the Slice Ops @@ -273,13 +273,12 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra Node& slice_node = *graph.GetNode(consumer->Index()); NodeArg* slice_output_args = slice_node.MutableOutputDefs()[0]; nodes_to_fuse.push_back(slice_node); - split_outputs[slice_node_count] = slice_output_args; - slice_node_count++; + split_outputs[slice_node_count++] = slice_output_args; } } // condition check - if (!can_fuse || gather_node_count != 2 || slice_node_count != 1) continue; + if (!can_fuse || gather_node_count != 0 || slice_node_count != 1) continue; // generate the split node and merge the kernel ONNX_NAMESPACE::TypeProto split_output_type; @@ -310,7 +309,8 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra split_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); auto dim_value = shape->dim(static_cast(split_axis)).dim_value(); - int64_t slice_dim = static_cast(dim_value - gather_node_count); + // Optimize 2 Gather Nodes, so Slice_dim = dim_value - 2 + int64_t slice_dim = static_cast(dim_value - 2); InlinedVector split_value{{slice_dim, 1, 1}}; split_initializer_proto.set_raw_data(split_value.data(), split_value.size() * sizeof(int64_t)); NodeArg* split_arg = &graph_utils::AddInitializer(graph, split_initializer_proto);