From d2e6dd25ea8bd528f614250ba0165a535734305e Mon Sep 17 00:00:00 2001
From: Vincent Wang
Date: Thu, 29 Feb 2024 13:45:58 +0800
Subject: [PATCH] Merge GatherToSplitFusion and #19218 to a General Fusion
 (#19600)

#19218 tried to fuse Gather/Slice to Split, but the logic has a problem.
Scalar indices and 1-dim indices in a Gather node produce different results:
scalar indices produce a result tensor with the axis dim removed, while
1-dim indices keep that dim, even when its value is 1. For example,

Node |-> Gather(indices=[0], axis=axis)
     |-> Gather(indices=[1], axis=axis)
     |-> Slice(index=2, axis=axis)

is the same as

Node |-> Split(axis=axis)

But

Node |-> Gather(indices=0, axis=axis)
     |-> Gather(indices=1, axis=axis)
     |-> Slice(index=2, axis=axis)

is the same as

Node |-> Split(axis=axis) |-> Squeeze(axis=axis)
                          |-> Squeeze(axis=axis)
                          |->

The previous PR doesn't take such Squeeze/Unsqueeze-related cases into
account. This PR merges #19218 and GatherToSplitFusion into one general
fusion, which relaxes the limit on the number of Gather and Slice nodes:
it checks all Gather and Slice consumers, and if the indices of the Gather
nodes and the start/end ranges of the Slice nodes exactly cover the
specific dim of the input tensor, they can be fused to a single Split,
adding a Squeeze where necessary according to the dim count of the indices
tensor in each Gather.

@rui-ren, please check if the fix can still be applied to your model.
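To make the shape rule concrete, here is a minimal standalone C++ sketch of
the behavior described above (GatherOutputShape is a hypothetical helper
written only for illustration; it is not part of ONNX Runtime or of this
patch):

    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Output shape of Gather(data, indices, axis): scalar indices (rank 0)
    // remove the gathered axis; 1-dim indices of length 1 keep it with size 1.
    std::vector<int64_t> GatherOutputShape(std::vector<int64_t> data_shape,
                                           size_t axis, int indices_rank) {
      if (indices_rank == 0) {
        data_shape.erase(data_shape.begin() + axis);  // e.g. {2, 3, 4} -> {2, 4}
      } else {
        data_shape[axis] = 1;                         // e.g. {2, 3, 4} -> {2, 1, 4}
      }
      return data_shape;
    }

    int main() {
      for (int indices_rank : {0, 1}) {
        std::cout << "indices rank " << indices_rank << ":";
        for (int64_t d : GatherOutputShape({2, 3, 4}, 1, indices_rank)) std::cout << " " << d;
        std::cout << "\n";
      }
      // Split(axis=1) alone yields {2, 1, 4} per output, matching the 1-dim
      // indices case; appending Squeeze(axes=[1]) yields {2, 4}, matching the
      // scalar case. That is why the fusion adds Squeeze only for rank-0 indices.
      return 0;
    }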
---
 onnxruntime/core/optimizer/gather_fusion.cc   | 318 ++++----
 onnxruntime/core/optimizer/gather_fusion.h    |  16 +-
 .../core/optimizer/gather_slice_fusion.cc     | 344 -----------
 .../core/optimizer/gather_slice_fusion.h      |  32 -
 .../core/optimizer/graph_transformer_utils.cc |   4 +-
 .../test/optimizer/graph_transform_test.cc    | 550 +++++-------------
 .../core/optimizer/graph_transformer_utils.cc |   4 +-
 7 files changed, 352 insertions(+), 916 deletions(-)
 delete mode 100644 onnxruntime/core/optimizer/gather_slice_fusion.cc
 delete mode 100644 onnxruntime/core/optimizer/gather_slice_fusion.h

diff --git a/onnxruntime/core/optimizer/gather_fusion.cc b/onnxruntime/core/optimizer/gather_fusion.cc
index 4903bc1d6b961..90cabff88122c 100644
--- a/onnxruntime/core/optimizer/gather_fusion.cc
+++ b/onnxruntime/core/optimizer/gather_fusion.cc
@@ -9,55 +9,144 @@ namespace onnxruntime {

-bool GatherToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis,
-                                            int64_t& indices_n_dims) const {
-  if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gather", {1, 11, 13}) ||
+namespace {
+
+static int64_t GetGatherAxis(const Node& node, int64_t rank) {
+  int64_t axis = 0;
+  auto& attrs = node.GetAttributes();
+  if (attrs.find("axis") != attrs.end()) {
+    auto& axis_attr = attrs.at("axis");
+    if (utils::HasInt(axis_attr)) {
+      axis = axis_attr.i();
+      if (axis < 0) axis += rank;
+    }
+  }
+  return axis;
+}
+
+static bool GetScalarInt64Initializer(const Graph& graph, const NodeArg& node_arg, int64_t& value, int64_t& rank) {
+  if (!optimizer_utils::IsScalar(node_arg)) return false;
+  const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, node_arg.Name());
+  if (!tensor_proto || tensor_proto->data_type() != ONNX_NAMESPACE::TensorProto::INT64) return false;
+  Initializer init_const{*tensor_proto, graph.ModelPath()};
+  value = *(init_const.data<int64_t>());
+  rank = tensor_proto->dims_size();
+  return true;
+}
+
+static bool GetSliceAxis(const Graph& graph, const Node& node, int64_t rank, int64_t& axis) {
+  if (node.InputDefs().size() < 4) return false;
+  int64_t unused = 0;
+  if (!GetScalarInt64Initializer(graph, *node.InputDefs()[3], axis, unused)) return false;
+  if (axis < 0) axis += rank;
+  return true;
+}
+
+static bool GetAxis(const Graph& graph, const Node& node, int64_t rank, int64_t& axis) {
+  if (node.OpType() == "Gather") {
+    axis = GetGatherAxis(node, rank);
+    return true;
+  }
+  if (node.OpType() == "Slice") {
+    return GetSliceAxis(graph, node, rank, axis);
+  }
+  return false;
+}
+
+}  // namespace
+
+bool GatherSliceToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node, int64_t rank,
+                                                 int64_t target_axis, int64_t dim_size, InlinedVector<bool>& consumed,
+                                                 int64_t& start, bool& need_squeeze) const {
+  if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gather", {13}) ||
       !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) {
     return false;
   }

-  const NodeArg& input_arg = *(node.InputDefs()[1]);
-  if (!optimizer_utils::IsScalar(input_arg)) return false;
-  const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, input_arg.Name());
-  if (!tensor_proto) return false;
-  if (tensor_proto->data_type() != ONNX_NAMESPACE::TensorProto_DataType_INT64) return false;
-  Initializer init_const{*tensor_proto, graph.ModelPath()};
-  index = *(init_const.data<int64_t>());
-  axis = 0;  // Default value.
-  auto& attrs = node.GetAttributes();
-  if (attrs.find("axis") != attrs.end()) {
-    auto& axis_attr = attrs.at("axis");
-    if (utils::HasInt(axis_attr)) axis = axis_attr.i();
+  if (GetGatherAxis(node, rank) != target_axis) return false;
+  // Require the indices input to be a scalar tensor for now. Normally if not, the exporter will choose Slice.
+  // We can relax this later if needed.
+  int64_t indices_n_dims = 0;
+  if (!GetScalarInt64Initializer(graph, *(node.InputDefs()[1]), start, indices_n_dims)) return false;
+  if (start < 0) start += dim_size;
+  if (start < 0 || start >= dim_size || consumed[static_cast<size_t>(start)]) return false;
+  consumed[static_cast<size_t>(start)] = true;
+  need_squeeze = indices_n_dims == 0;
+  return true;
+}
+
+bool GatherSliceToSplitFusion::IsSupportedSlice(const Graph& graph, const Node& node, int64_t rank, int64_t target_axis,
+                                                int64_t dim_size, InlinedVector<bool>& consumed, int64_t& start,
+                                                int64_t& end) const {
+  if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Slice", {13}) ||
+      !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) {
+    return false;
+  }
+
+  int64_t axis = 0;
+  if (!GetSliceAxis(graph, node, rank, axis) || axis != target_axis) return false;
+  int64_t unused = 0;
+  if (!GetScalarInt64Initializer(graph, *node.InputDefs()[1], start, unused) ||
+      !GetScalarInt64Initializer(graph, *node.InputDefs()[2], end, unused)) {
+    return false;
+  }
+  // Handling start and end according to schema definition.
+  if (start < 0) start += dim_size;
+  if (end < 0) end += dim_size;
+  if (start < 0)
+    start = 0;
+  else if (start > dim_size)
+    start = dim_size;
+  if (end < 0)
+    end = 0;
+  else if (end > dim_size)
+    end = dim_size;
+  if (start >= end) return false;
+  if (node.InputDefs().size() >= 5) {
+    int64_t step = 0;
+    if (!GetScalarInt64Initializer(graph, *node.InputDefs()[4], step, unused) || step != 1) return false;
+  }
+  for (int64_t i = start; i < end; ++i) {
+    if (consumed[static_cast<size_t>(i)]) return false;
+    consumed[static_cast<size_t>(i)] = true;
   }
-  indices_n_dims = tensor_proto->dims_size();
   return true;
 }
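// --- Editor's illustrative sketch (not part of the patch) -------------------
// The clamping above follows the ONNX Slice schema: a negative start/end is
// offset by dim_size once, then both are clamped into [0, dim_size]. A
// minimal standalone equivalent (hypothetical helper, assuming <algorithm>
// and <cstdint> are available):
//
//   void NormalizeSliceRange(int64_t dim_size, int64_t& start, int64_t& end) {
//     if (start < 0) start += dim_size;
//     if (end < 0) end += dim_size;
//     start = std::min(std::max<int64_t>(start, 0), dim_size);
//     end = std::min(std::max<int64_t>(end, 0), dim_size);
//   }
//
// For dim_size = 8: (start=-2, end=16) normalizes to [6, 8) and
// (start=2, end=-2) to [2, 6), matching the Slice tests below.
// -----------------------------------------------------------------------------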

 /*
-GatherToSplitFusion is to fuse:
-Node -> Gather(index=0, axis=axis)
-     |-> Gather(index=1, axis=axis)
-     |-> Gather(index=2, axis=axis)
+GatherSliceToSplitFusion is to fuse:
+Node -> Gather(indices=0, axis=axis)
+     |-> Gather(indices=[1], axis=axis)
+     |-> Slice(start=2, end=3, axes=[axis])
      |...
 To
 Node -> Split -> Squeeze(axis=axis)
-     |-> Squeeze(axis=axis)
-     |-> Squeeze(axis=axis)
+     |->
+     |->
      |...
 So that we can use one kernel to finish the job.
+The fusion requires that the indices of the Gather nodes and the start/end ranges of the Slice nodes do not overlap
+and together cover all the elements of the target axis. The step of each Slice node must be 1.
 */

-Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
-                                      const logging::Logger& logger) const {
+Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
+                                           const logging::Logger& logger) const {
+  // Squeeze, Gather, Slice and Split have different schemas before and after OpSet 13.
+  // To keep the code simple, support OpSet >= 13 only.
+  int onnx_opset_version = -1;
+  if (graph.DomainToVersionMap().find(kOnnxDomain) != graph.DomainToVersionMap().end()) {
+    onnx_opset_version = graph.DomainToVersionMap().at(kOnnxDomain);
+  }
+  if (onnx_opset_version < 13) return Status::OK();
+
   GraphViewer graph_viewer(graph);
   const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();

-  InlinedVector<const NodeArg*> node_args;
+  InlinedVector<const NodeArg*> candidate_args;
   for (auto node_arg : graph.GetInputs()) {
     if (node_arg && graph.GetConsumerNodes(node_arg->Name()).size() > 1) {
-      node_args.push_back(node_arg);
+      candidate_args.push_back(node_arg);
     }
   }

@@ -65,7 +154,7 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
     if (graph.GetConsumerNodes(entry.first).size() > 1) {
       auto node_arg = graph.GetNodeArg(entry.first);
       if (node_arg) {
-        node_args.push_back(node_arg);
+        candidate_args.push_back(node_arg);
       }
     }
   }
@@ -90,129 +179,108 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
     size_t output_count = node.GetOutputEdgesCount();
     if (output_count <= 1) continue;

-    node_args.push_back(node.OutputDefs()[0]);
+    candidate_args.push_back(node.OutputDefs()[0]);
   }

-  for (const NodeArg* node_arg : node_args) {
+  for (const NodeArg* node_arg : candidate_args) {
     auto shape = node_arg->Shape();
     if (!shape) continue;
     int64_t rank = static_cast<int64_t>(shape->dim_size());
-
-    bool can_fuse = true;
-    bool first_edge = true;
-    int64_t split_axis = 0;
-    int64_t indices_n_dims = -1;
     auto consumers = graph.GetConsumerNodes(node_arg->Name());
-    size_t consumer_count = consumers.size();
-    InlinedVector<NodeArg*> gather_outputs(consumer_count, nullptr);
-    InlinedVector<std::reference_wrapper<Node>> nodes_to_fuse;
+    InlinedVector<const Node*> candidate_consumers;
     for (auto consumer : consumers) {
-      int64_t index, axis, dims;
-      if (!consumer || consumer->InputDefs()[0] != node_arg ||
-          !IsSupportedGather(graph, *consumer, index, axis, dims)) {
-        can_fuse = false;
-        break;
-      }
-      if (indices_n_dims == -1) {
-        indices_n_dims = dims;
-      } else if (indices_n_dims != dims) {
-        // Not the same number of dimensions (0 or 1) for all scalar indices.
-        can_fuse = false;
-        break;
+      if (consumer && consumer->InputDefs()[0] == node_arg &&
+          (consumer->OpType() == "Gather" || consumer->OpType() == "Slice")) {
+        candidate_consumers.emplace_back(consumer);
       }
-      if (axis < 0) axis += rank;
-      if (first_edge) {
-        auto dim = shape->dim(static_cast<int>(axis));
-        if (!utils::HasDimValue(dim) || dim.dim_value() != static_cast<int64_t>(consumer_count)) {
-          can_fuse = false;
-          break;
-        }
-        split_axis = axis;
-        first_edge = false;
-      } else if (axis != split_axis) {
+    }
+    if (candidate_consumers.size() < 2) continue;
+    int64_t axis = 0;
+    if (!GetAxis(graph, *candidate_consumers[0], rank, axis)) continue;
+    auto dim = shape->dim(static_cast<int>(axis));
+    if (!utils::HasDimValue(dim)) continue;
+    int64_t dim_size = dim.dim_value();
+    InlinedVector<bool> consumed(static_cast<size_t>(dim_size), false);
+    bool can_fuse = true;
+    InlinedVector<std::reference_wrapper<Node>> nodes_to_fuse;
+    InlinedVector<int64_t> starts;
+    InlinedHashMap<int64_t, std::tuple<NodeArg*, int64_t, bool>> output_info_map;
+    for (auto consumer : candidate_consumers) {
+      if (!consumer || consumer->InputDefs()[0] != node_arg) {
         can_fuse = false;
         break;
       }
-      if (index < 0) index += static_cast<int64_t>(consumer_count);
-      if (index < 0 || index >= static_cast<int64_t>(consumer_count) || gather_outputs[static_cast<size_t>(index)]) {
+      int64_t start = 0, end = 0;
+      bool need_squeeze = false;
+      if (IsSupportedGather(graph, *consumer, rank, axis, dim_size, consumed, start, need_squeeze)) {
+        Node& gather_node = *graph.GetNode(consumer->Index());
+        nodes_to_fuse.emplace_back(gather_node);
+        starts.emplace_back(start);
+        output_info_map[start] = std::make_tuple(gather_node.MutableOutputDefs()[0], 1, need_squeeze);
+      } else if (IsSupportedSlice(graph, *consumer, rank, axis, dim_size, consumed, start, end)) {
+        Node& slice_node = *graph.GetNode(consumer->Index());
+        nodes_to_fuse.emplace_back(slice_node);
+        starts.emplace_back(start);
+        output_info_map[start] = std::make_tuple(slice_node.MutableOutputDefs()[0], end - start, false);
+      } else {
         can_fuse = false;
         break;
       }
-      Node& gather_node = *graph.GetNode(consumer->Index());
-      nodes_to_fuse.emplace_back(gather_node);
-      gather_outputs[static_cast<size_t>(index)] = gather_node.MutableOutputDefs()[0];
-    }
-
-    if (!can_fuse) continue;
-
-    ONNX_NAMESPACE::TypeProto split_output_type;
-    const ONNX_NAMESPACE::TensorProto_DataType element_type =
-        static_cast<ONNX_NAMESPACE::TensorProto_DataType>(node_arg->TypeAsProto()->tensor_type().elem_type());
-    split_output_type.mutable_tensor_type()->set_elem_type(element_type);
-    for (int64_t i = 0; i < rank; ++i) {
-      if (i == split_axis) {
-        split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1LL);
-      } else {
-        *(split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()) = shape->dim(static_cast<int>(i));
-      }
+    }
+    if (!can_fuse || std::find(consumed.begin(), consumed.end(), false) != consumed.end()) continue;
+    std::sort(starts.begin(), starts.end());

     InlinedVector<NodeArg*> split_outputs;
-    bool add_squeeze_node = indices_n_dims == 0;
-    if (add_squeeze_node) {
-      for (size_t i = 0; i < consumer_count; ++i) {
-        split_outputs.emplace_back(
-            &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("split" + std::to_string(i)), &split_output_type));
-      }
-    }
-
-    Node& split_node =
-        graph.AddNode(graph.GenerateNodeName("Split"), "Split", "Split for Fused Gather nodes",
-                      {graph.GetNodeArg(node_arg->Name())}, add_squeeze_node ? split_outputs : gather_outputs);
-    split_node.AddAttribute("axis", split_axis);
-    split_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType());
-
-    // Squeeze-11, Squeee-13, Split-13, Split-18 have different schemas.
-    int onnx_opset_version = -1;
-    if (graph.DomainToVersionMap().find(kOnnxDomain) != graph.DomainToVersionMap().end()) {
-      onnx_opset_version = graph.DomainToVersionMap().at(kOnnxDomain);
-    }
-
-    if (onnx_opset_version < 13) {
-      if (add_squeeze_node) {
-        for (size_t i = 0; i < consumer_count; ++i) {
-          Node& squeeze_node = graph.AddNode(graph.GenerateNodeName("Squeeze" + std::to_string(i)), "Squeeze",
-                                             "Squeeze for Fused Gather nodes", {split_outputs[i]}, {gather_outputs[i]});
-          squeeze_node.AddAttribute("axes", std::vector<int64_t>{split_axis});
-          squeeze_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType());
+    InlinedVector<int64_t> split_values;
+    for (int64_t start : starts) {
+      auto& output_info = output_info_map[start];
+      NodeArg* original_output_arg = std::get<0>(output_info);
+      int64_t split_value = std::get<1>(output_info);
+      split_values.emplace_back(split_value);
+      if (std::get<2>(output_info)) {
+        ONNX_NAMESPACE::TypeProto split_output_type;
+        const ONNX_NAMESPACE::TensorProto_DataType element_type =
+            static_cast<ONNX_NAMESPACE::TensorProto_DataType>(node_arg->TypeAsProto()->tensor_type().elem_type());
+        split_output_type.mutable_tensor_type()->set_elem_type(element_type);
+        for (int64_t i = 0; i < rank; ++i) {
+          if (i == axis) {
+            split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(split_value);
+          } else {
+            *(split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()) = shape->dim(static_cast<int>(i));
+          }
         }
-      }
-    } else {
-      if (onnx_opset_version >= 18) {
-        split_node.AddAttribute("num_outputs", static_cast<int64_t>(consumer_count));
-      }
-
-      if (add_squeeze_node) {
+        NodeArg* split_output_arg =
+            &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("split_output"), &split_output_type);
         ONNX_NAMESPACE::TensorProto axes_initializer_proto;
-        axes_initializer_proto.set_name(graph.GenerateNodeName("SqueezeAxesInitializer"));
+        axes_initializer_proto.set_name(graph.GenerateNodeName("squeeze_axes"));
         axes_initializer_proto.add_dims(static_cast<int64_t>(1));
         axes_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
-        InlinedVector<int64_t> axes_value{split_axis};
-        axes_initializer_proto.set_raw_data(axes_value.data(), axes_value.size() * sizeof(int64_t));
+        axes_initializer_proto.add_int64_data(axis);
        NodeArg* axes_arg = &graph_utils::AddInitializer(graph, axes_initializer_proto);
-
-        for (size_t i = 0; i < consumer_count; ++i) {
-          Node& squeeze_node =
-              graph.AddNode(graph.GenerateNodeName("Squeeze" + std::to_string(i)), "Squeeze",
-                            "Squeeze for Fused Gather nodes", {split_outputs[i], axes_arg}, {gather_outputs[i]});
-          squeeze_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType());
-        }
+        Node& squeeze_node =
+            graph.AddNode(graph.GenerateNodeName("Squeeze"), "Squeeze", "Squeeze for Fused Gather nodes",
+                          {split_output_arg, axes_arg}, {original_output_arg});
+        squeeze_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType());
+        split_outputs.emplace_back(split_output_arg);
+      } else {
+        split_outputs.emplace_back(original_output_arg);
       }
     }

-    for (Node& n : nodes_to_fuse) {
-      graph_utils::RemoveNodeOutputEdges(graph, n);
-      graph.RemoveNode(n.Index());
+    ONNX_NAMESPACE::TensorProto split_initializer_proto;
+    split_initializer_proto.set_name(graph.GenerateNodeName("splits"));
+    split_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
+    split_initializer_proto.add_dims(static_cast<int64_t>(split_values.size()));
+    split_initializer_proto.mutable_int64_data()->Add(split_values.begin(), split_values.end());
+    NodeArg* split_initializer_arg = &graph_utils::AddInitializer(graph, split_initializer_proto);
+    Node& split_node = graph.AddNode(graph.GenerateNodeName("Split"), "Split", "Split for Fused Gather nodes",
+                                     {graph.GetNodeArg(node_arg->Name()), split_initializer_arg}, split_outputs);
+    split_node.AddAttribute("axis", axis);
+    split_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType());
+
+    for (Node& node : nodes_to_fuse) {
+      graph_utils::RemoveNodeOutputEdges(graph, node);
+      graph.RemoveNode(node.Index());
     }

     modified = true;
diff --git a/onnxruntime/core/optimizer/gather_fusion.h b/onnxruntime/core/optimizer/gather_fusion.h
index 44c235915b6cc..098278a77dafe 100644
--- a/onnxruntime/core/optimizer/gather_fusion.h
+++ b/onnxruntime/core/optimizer/gather_fusion.h
@@ -8,19 +8,23 @@ namespace onnxruntime {

 /**
-@Class GatherToSplitFusion
+@Class GatherSliceToSplitFusion

-Fuse multiple Gather nodes that comsuming one output to one Split node.
+Fuse multiple Gather/Slice nodes that consume one output into one Split node.
 */
-class GatherToSplitFusion : public GraphTransformer {
+class GatherSliceToSplitFusion : public GraphTransformer {
  public:
-  GatherToSplitFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept
-      : GraphTransformer("GatherToSplitFusion", compatible_execution_providers) {}
+  GatherSliceToSplitFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept
+      : GraphTransformer("GatherSliceToSplitFusion", compatible_execution_providers) {}

   Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;

  private:
-  bool IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis, int64_t& indices_n_dims) const;
+  bool IsSupportedGather(const Graph& graph, const Node& node, int64_t rank, int64_t target_axis, int64_t dim_size,
+                         InlinedVector<bool>& consumed, int64_t& start, bool& need_squeeze) const;
+
+  bool IsSupportedSlice(const Graph& graph, const Node& node, int64_t rank, int64_t target_axis, int64_t dim_size,
+                        InlinedVector<bool>& consumed, int64_t& start, int64_t& end) const;
 };

 /**
diff --git a/onnxruntime/core/optimizer/gather_slice_fusion.cc b/onnxruntime/core/optimizer/gather_slice_fusion.cc
deleted file mode 100644
index 21266d356a020..0000000000000
--- a/onnxruntime/core/optimizer/gather_slice_fusion.cc
+++ /dev/null
@@ -1,344 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
- -#include "core/optimizer/gather_slice_fusion.h" -#include "core/graph/graph_utils.h" -#include "core/optimizer/initializer.h" -#include "core/optimizer/utils.h" - -namespace onnxruntime { - -bool GatherSliceToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, - int64_t& axis, int64_t& indices_n_dims) const { - if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gather", {1, 11, 13}) || - !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) { - return false; - } - - const NodeArg& input_arg = *(node.InputDefs()[1]); - - if (!optimizer_utils::IsScalar(input_arg)) return false; - - const ONNX_NAMESPACE::TensorProto* indices_init = graph_utils::GetConstantInitializer(graph, input_arg.Name()); - - if (!indices_init) return false; - - if (indices_init->data_type() != ONNX_NAMESPACE::TensorProto::INT64) return false; - - // get the index value - Initializer init_const(*indices_init, graph.ModelPath()); - index = *(init_const.data()); - - // get attributes value - axis = 0; - auto& attrs = node.GetAttributes(); - if (attrs.find("axis") != attrs.end()) { - auto& axis_attr = attrs.at("axis"); - if (utils::HasInt(axis_attr)) axis = axis_attr.i(); - } - - indices_n_dims = indices_init->dims_size(); - return true; -} - -bool GatherSliceToSplitFusion::IsSupportedSlice(const Graph& graph, const Node& node, - InlinedVector& starts, - InlinedVector& ends, - InlinedVector& axes, - InlinedVector& steps) const { - // check the version of Slice ops - if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Slice", {1, 10, 11, 13}) || - !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) { - return false; - } - - // get the opset version - int onnx_opset_version = -1; - if (graph.DomainToVersionMap().find(kOnnxDomain) != graph.DomainToVersionMap().end()) { - onnx_opset_version = graph.DomainToVersionMap().at(kOnnxDomain); - } - - // If Slice op of opset version 1 - if (onnx_opset_version == 1) { - if (!graph_utils::GetRepeatedNodeAttributeValues(node, "starts", starts) || - !graph_utils::GetRepeatedNodeAttributeValues(node, "ends", ends) || - starts.size() != ends.size()) { - return false; - } - - if (graph_utils::GetRepeatedNodeAttributeValues(node, "axes", axes) && (axes.size() != starts.size())) { - return false; - } - } - - // If Slice op of opset version >= 10 - if (onnx_opset_version >= 10) { - // node inputs include: starts - ends - axes - steps - - // return a pointer to the corresponding NodeArg if input of the node at the index exists - auto get_input_if_exists = [&node](size_t input_index) -> const NodeArg* { - const auto& input_defs = node.InputDefs(); - const NodeArg* input = (input_defs.size() > input_index) ? input_defs[input_index] : nullptr; - return (input == nullptr || !input->Exists()) ? nullptr : input; - }; - - // return a pointer to the initializer if it is constant; otherwise, a nullptr - auto get_initializer_if_constant = - [&graph, get_input_if_exists](size_t input_index) -> const ONNX_NAMESPACE::TensorProto* { - const NodeArg* input = get_input_if_exists(input_index); - return input ? 
graph_utils::GetConstantInitializer(graph, input->Name()) : nullptr; - }; - - // return the initialization data if it is constant - auto get_initializer_data = - [&graph](const ONNX_NAMESPACE::TensorProto* slice_initializer) -> InlinedVector { - Initializer init(*slice_initializer, graph.ModelPath()); - if (slice_initializer->data_type() == ONNX_NAMESPACE::TensorProto::INT32) { - int32_t* init_data = init.data(); - return InlinedVector(init_data, init_data + init.size()); - } - - if (slice_initializer->data_type() == ONNX_NAMESPACE::TensorProto::INT64) { - int64_t* init_data = init.data(); - return InlinedVector(init_data, init_data + init.size()); - } - return {}; - }; - - // starts and ends inputs have to exist, be constants and be of the same size. - const ONNX_NAMESPACE::TensorProto* starts_init = get_initializer_if_constant(1); - const ONNX_NAMESPACE::TensorProto* ends_init = get_initializer_if_constant(2); - const ONNX_NAMESPACE::TensorProto* axes_init = get_initializer_if_constant(3); - const ONNX_NAMESPACE::TensorProto* steps_init = get_initializer_if_constant(4); - - if (!starts_init || !ends_init || !axes_init || !steps_init) { - return false; - } - - starts = get_initializer_data(starts_init); - ends = get_initializer_data(ends_init); - axes = get_initializer_data(axes_init); - steps = get_initializer_data(steps_init); - - if (starts.size() == 0 || ends.size() == 0 || starts.size() != ends.size()) { - return false; - } - - if (axes_init->dims_size() != 1 || static_cast(axes_init->dims().Get(0)) != starts.size()) { - return false; - } - - // if steps exists, it should be constant and all value should be 1 - if (steps.size() != starts.size()) { - return false; - } - - for (int64_t step : steps) { - if (step != 1) { - return false; - } - } - } - - return true; -} - -/* -GatherToSplitFusion is to fuse: - Node - |-> Gather(index=0, axis=axis) - |-> Gather(index=1, axis=axis) - |-> Slice(index=2, axis=axis) -To - Node - |-> Split(index=0) -So that we can use one kernel to finish the job. -*/ - -Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, - const logging::Logger& logger) const { - GraphViewer graph_viewer(graph); - - const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); - - InlinedVector output_args; - - // Iterate the topological order and get Reshape ops - for (auto node_index : node_topology_list) { - auto* p_node = graph.GetNode(node_index); - - if (p_node == nullptr) continue; - - Node& node = *p_node; - - ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); - - // Currently only catch after Reshape ops, optimize in the future - if (node.OpType() != "Reshape") continue; - - size_t output_count = node.GetOutputEdgesCount(); - - // We only catch 1 scenario for Multi Query Attention for now. - // |---> Gather - // Reshape |---> Gather - // |---> Slice - // |... 
or (other ops) - - // Get the output into node args - if (output_count < 3) continue; - - output_args.push_back(node.OutputDefs()[0]); - } - - // iterate the children of Reshape node - for (const NodeArg* node_arg : output_args) { - auto shape = node_arg->Shape(); - if (!shape) continue; - - auto consumers = graph.GetConsumerNodes(node_arg->Name()); - size_t consumer_count = consumers.size(); - - // get the tensor rank - int64_t rank = static_cast(shape->dim_size()); - - bool can_fuse = true; - bool first_edge = true; - int64_t split_axis = 0; - int64_t indices_n_dims = -1; - - // Fuse 2 Gathers and 1 slice to Split - // Get those outputs as Split outputs - InlinedVector split_outputs(3); - - InlinedVector> nodes_to_fuse; - size_t gather_node_count = 2, slice_node_count = 0; - - // find the nodes to be merged - for (auto consumer : consumers) { - int64_t index, axis, dims; - InlinedVector starts, ends, axes, steps; - - bool IsSupportedGatherOps = IsSupportedGather(graph, *consumer, index, axis, dims); - bool IsSupportedSliceOps = IsSupportedSlice(graph, *consumer, starts, ends, axes, steps); - - if ((!consumer || consumer->InputDefs()[0] != node_arg) || - (!IsSupportedGatherOps && !IsSupportedSliceOps)) { - break; - } - - if (IsSupportedGatherOps) { - if (indices_n_dims == -1) { - indices_n_dims = dims; - } else if (indices_n_dims != dims) { - // Not the same number of dimensions (0 or 1) for all scalar indices. - can_fuse = false; - break; - } - - if (axis < 0) axis += rank; - - if (first_edge) { - auto dim = shape->dim(static_cast(axis)); - // dim.dim_value() = 73 - if (!utils::HasDimValue(dim)) { - can_fuse = false; - break; - } - split_axis = axis; - first_edge = false; - } else if (axis != split_axis) { - can_fuse = false; - break; - } - - if (index < 0) index += static_cast(consumer_count); - if (index < 0 || index >= static_cast(consumer_count)) { - can_fuse = false; - break; - } - - Node& gather_node = *graph.GetNode(consumer->Index()); - nodes_to_fuse.push_back(gather_node); - NodeArg* gather_output_args = gather_node.MutableOutputDefs()[0]; - split_outputs[gather_node_count--] = gather_output_args; - } - - // check the Slice Ops - if (IsSupportedSliceOps) { - if (axes[0] != axis && !first_edge) { - can_fuse = false; - break; - } - - Node& slice_node = *graph.GetNode(consumer->Index()); - NodeArg* slice_output_args = slice_node.MutableOutputDefs()[0]; - nodes_to_fuse.push_back(slice_node); - split_outputs[slice_node_count++] = slice_output_args; - } - } - - // condition check - if (!can_fuse || gather_node_count != 0 || slice_node_count != 1) continue; - - // generate the split node and merge the kernel - ONNX_NAMESPACE::TypeProto split_output_type; - const ONNX_NAMESPACE::TensorProto_DataType element_type = static_cast( - node_arg->TypeAsProto()->tensor_type().elem_type()); - - split_output_type.mutable_tensor_type()->set_elem_type(element_type); - - for (int64_t i = 0; i < rank; i++) { - if (i == split_axis) - split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1LL); - else - *(split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()) = shape->dim(static_cast(i)); - } - - InlinedVector split_output_types; - - for (size_t i = 0; i < consumer_count; ++i) { - split_output_types.push_back( - &graph.GetOrCreateNodeArg( - graph.GenerateNodeArgName("fused_split_" + std::to_string(i)), &split_output_type)); - } - - // Generate the Split Node - ONNX_NAMESPACE::TensorProto split_initializer_proto; - 
split_initializer_proto.set_name(graph.GenerateNodeName("fused_Split")); - split_initializer_proto.add_dims(static_cast(3)); - split_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); - - auto dim_value = shape->dim(static_cast(split_axis)).dim_value(); - // Optimize 2 Gather Nodes, so Slice_dim = dim_value - 2 - int64_t slice_dim = static_cast(dim_value - 2); - InlinedVector split_value{{slice_dim, 1, 1}}; - split_initializer_proto.set_raw_data(split_value.data(), split_value.size() * sizeof(int64_t)); - NodeArg* split_arg = &graph_utils::AddInitializer(graph, split_initializer_proto); - - Node& split_node = - graph.AddNode(graph.GenerateNodeName("Split"), "Split", "Split for fused Gather-Slice fusion", - {graph.GetNodeArg(node_arg->Name()), split_arg}, split_outputs); - - split_node.AddAttribute("axis", split_axis); - - split_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType()); - - int onnx_opset_version = -1; - if (graph.DomainToVersionMap().find(kOnnxDomain) != graph.DomainToVersionMap().end()) { - onnx_opset_version = graph.DomainToVersionMap().at(kOnnxDomain); - } - - if (onnx_opset_version >= 18) { - split_node.AddAttribute("num_outputs", static_cast(consumer_count)); - } - - for (Node& node_to_fuse : nodes_to_fuse) { - graph_utils::RemoveNodeOutputEdges(graph, node_to_fuse); - graph.RemoveNode(node_to_fuse.Index()); - } - modified = true; - } - - return Status::OK(); -} -} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/gather_slice_fusion.h b/onnxruntime/core/optimizer/gather_slice_fusion.h deleted file mode 100644 index 1c5c307efed7f..0000000000000 --- a/onnxruntime/core/optimizer/gather_slice_fusion.h +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/optimizer/graph_transformer.h" - -namespace onnxruntime { - -/** -@class GatherSliceToSplitFusion -Fuse (2 Gather nodes + 1 Slice) to 1 split node. 
-*/ - -class GatherSliceToSplitFusion : public GraphTransformer { - private: - bool IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis, - int64_t& indices_n_dims) const; - - bool IsSupportedSlice(const Graph& graph, const Node& node, - InlinedVector& starts, - InlinedVector& ends, - InlinedVector& axes, - InlinedVector& steps) const; - - public: - GatherSliceToSplitFusion(const InlinedHashSet& compatible_execution_providers = {}) noexcept - : GraphTransformer("GatherSliceToSplitFusion", compatible_execution_providers) {} - - Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; -}; -} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 4e939fe3c7b6b..8376b87aee6b2 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -37,7 +37,6 @@ #include "core/optimizer/fast_gelu_fusion.h" #include "core/optimizer/free_dim_override_transformer.h" #include "core/optimizer/gather_fusion.h" -#include "core/optimizer/gather_slice_fusion.h" #include "core/optimizer/gelu_approximation.h" #include "core/optimizer/gelu_fusion.h" #include "core/optimizer/gemm_activation_fusion.h" @@ -307,9 +306,8 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); - transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); - transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); + transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index e1fcf835c6043..16f38bac62713 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -42,7 +42,6 @@ #include "core/optimizer/expand_elimination.h" #include "core/optimizer/fast_gelu_fusion.h" #include "core/optimizer/gather_fusion.h" -#include "core/optimizer/gather_slice_fusion.h" #include "core/optimizer/gelu_approximation.h" #include "core/optimizer/gelu_fusion.h" #include "core/optimizer/gemm_activation_fusion.h" @@ -7059,13 +7058,13 @@ TEST_F(GraphTransformationTests, ConstantSharing_ShouldNotShareForGraphOutput) { } } -TEST_F(GraphTransformationTests, GatherToSplitFusion) { +TEST_F(GraphTransformationTests, GatherSliceToSplitFusion_AllGather) { auto build_test_case = [&](ModelTestBuilder& builder) { auto* data_arg = builder.MakeInput({{54}}); auto* shape_arg = builder.MakeInput({{4}}); auto* reshape_out = builder.MakeIntermediate({{2, 3, 3, 3}}); auto* gather_index_1 = builder.MakeInitializer({}, {static_cast(0)}); - auto* gather_index_2 = builder.MakeInitializer({}, {static_cast(1)}); + auto* gather_index_2 = builder.MakeInitializer({1}, {static_cast(1)}); auto* gather_index_3 = builder.MakeInitializer({}, {static_cast(2)}); auto* gather_out_1 = builder.MakeIntermediate(); auto* gather_out_2 = builder.MakeIntermediate(); @@ -7082,7 +7081,8 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion) { builder.AddNode("Gather", 
{reshape_out, gather_index_3}, {gather_out_3})
         .AddAttribute("axis", static_cast<int64_t>(2));
     builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1}).AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
-    builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2}).AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
+    builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2})
+        .AddAttribute("perm", std::vector<int64_t>{0, 2, 1, 3});
     builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}).AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
   };

@@ -7091,27 +7091,16 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion) {
     return Status::OK();
   };

-  // OpSet-12
+  // OpSet-12, not supported
   {
     auto post_graph_checker = [&](Graph& graph) {
-      TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0);
-      TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1);
-      TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3);
-      for (auto& node : graph.Nodes()) {
-        if (node.OpType() == "Split") {
-          auto& attrs = node.GetAttributes();
-          TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end());
-          TEST_RETURN_IF_NOT(2 == static_cast<int64_t>(attrs.at("axis").i()));
-        } else if (node.OpType() == "Squeeze") {
-          auto& attrs = node.GetAttributes();
-          TEST_RETURN_IF_NOT(attrs.find("axes") != attrs.end());
-          TEST_RETURN_IF_NOT(2 == static_cast<int64_t>(attrs.at("axes").ints().at(0)));
-        }
-      }
+      TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3);
+      TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 0);
+      TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 0);
       return Status::OK();
     };

-    std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
+    std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherSliceToSplitFusion>();
     ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer),
                                           TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
   }
@@ -7121,7 +7110,7 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion) {
     auto post_graph_checker = [&](Graph& graph) {
       TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0);
       TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1);
-      TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3);
+      TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 2);
       for (auto& node : graph.Nodes()) {
         if (node.OpType() == "Split") {
           auto& attrs = node.GetAttributes();
@@ -7140,249 +7129,140 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion) {
     return Status::OK();
   };

-    std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
+    std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherSliceToSplitFusion>();
     ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer),
                                           TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
   }
-
-  // OpSet-18
-  {
-    auto post_graph_checker = [&](Graph& graph) {
-      TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0);
-      TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1);
-      TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3);
-      for (auto& node : graph.Nodes()) {
-        if (node.OpType() == "Split") {
-          auto& attrs = node.GetAttributes();
-          TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end());
-          TEST_RETURN_IF_NOT(2 == static_cast<int64_t>(attrs.at("axis").i()));
-        } else if (node.OpType() == "Squeeze") {
-          const NodeArg& input_arg = *(node.InputDefs()[1]);
-          const ONNX_NAMESPACE::TensorProto* tensor_proto =
-              graph_utils::GetConstantInitializer(graph, input_arg.Name());
-          TEST_RETURN_IF_NOT(tensor_proto != nullptr);
-          Initializer init_const{*tensor_proto, graph.ModelPath()};
-          TEST_RETURN_IF_NOT(tensor_proto->data_type() ==
ONNX_NAMESPACE::TensorProto_DataType_INT64); - TEST_RETURN_IF_NOT(2 == static_cast(*(init_const.data()))); - } - } - return Status::OK(); - }; - - std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer), - TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); - } } -TEST_F(GraphTransformationTests, GatherToSplitFusion_NoSqueeze) { +TEST_F(GraphTransformationTests, GatherSliceToSplitFusion_AllSlice_GraphInput) { auto build_test_case = [&](ModelTestBuilder& builder) { - auto* data_arg = builder.MakeInput({{54}}); - auto* shape_arg = builder.MakeInput({{4}}); - auto* reshape_out = builder.MakeIntermediate({{2, 3, 3, 3}}); - auto* gather_index_1 = builder.MakeInitializer({1}, {static_cast(0)}); - auto* gather_index_2 = builder.MakeInitializer({1}, {static_cast(1)}); - auto* gather_index_3 = builder.MakeInitializer({1}, {static_cast(2)}); - auto* gather_out_1 = builder.MakeIntermediate(); - auto* gather_out_2 = builder.MakeIntermediate(); - auto* gather_out_3 = builder.MakeIntermediate(); + auto* data_arg = builder.MakeInput({{2, 3, 8, 3}}); + auto* starts_1 = builder.MakeInitializer({1}, {0}); + auto* ends_1 = builder.MakeInitializer({1}, {2}); + auto* axes_1 = builder.MakeInitializer({1}, {2}); + auto* steps_1 = builder.MakeInitializer({1}, {1}); + auto* starts_2 = builder.MakeInitializer({1}, {2}); + auto* ends_2 = builder.MakeInitializer({1}, {-2}); + auto* axes_2 = builder.MakeInitializer({1}, {-2}); + auto* steps_2 = builder.MakeInitializer({1}, {1}); + auto* starts_3 = builder.MakeInitializer({1}, {-2}); + auto* ends_3 = builder.MakeInitializer({1}, {16}); + auto* axes_3 = builder.MakeInitializer({1}, {2}); + auto* slice_out_1 = builder.MakeIntermediate(); + auto* slice_out_2 = builder.MakeIntermediate(); + auto* slice_out_3 = builder.MakeIntermediate(); auto* transpose_out_1 = builder.MakeOutput(); auto* transpose_out_2 = builder.MakeOutput(); auto* transpose_out_3 = builder.MakeOutput(); - builder.AddNode("Reshape", {data_arg, shape_arg}, {reshape_out}); - builder.AddNode("Gather", {reshape_out, gather_index_1}, {gather_out_1}) - .AddAttribute("axis", static_cast(2)); - builder.AddNode("Gather", {reshape_out, gather_index_2}, {gather_out_2}) - .AddAttribute("axis", static_cast(-2)); - builder.AddNode("Gather", {reshape_out, gather_index_3}, {gather_out_3}) - .AddAttribute("axis", static_cast(2)); - builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1}).AddAttribute("perm", std::vector{0, 2, 1}); - builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2}).AddAttribute("perm", std::vector{0, 2, 1}); - builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}).AddAttribute("perm", std::vector{0, 2, 1}); + builder.AddNode("Slice", {data_arg, starts_1, ends_1, axes_1, steps_1}, {slice_out_1}); + builder.AddNode("Slice", {data_arg, starts_2, ends_2, axes_2, steps_2}, {slice_out_2}); + builder.AddNode("Slice", {data_arg, starts_3, ends_3, axes_3}, {slice_out_3}); + builder.AddNode("Transpose", {slice_out_1}, {transpose_out_1}) + .AddAttribute("perm", std::vector{0, 2, 1, 3}); + builder.AddNode("Transpose", {slice_out_2}, {transpose_out_2}) + .AddAttribute("perm", std::vector{0, 2, 1, 3}); + builder.AddNode("Transpose", {slice_out_3}, {transpose_out_3}) + .AddAttribute("perm", std::vector{0, 2, 1, 3}); }; auto pre_graph_checker = [&](Graph& graph) { - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Slice"] == 3); 
return Status::OK(); }; - // OpSet-12 - { - auto post_graph_checker = [&](Graph& graph) { - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 0); - for (auto& node : graph.Nodes()) { - if (node.OpType() == "Split") { - auto& attrs = node.GetAttributes(); - TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); - TEST_RETURN_IF_NOT(2 == static_cast(attrs.at("axis").i())); - } - } - return Status::OK(); - }; - - std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), - TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); - } - - // OpSet-14 - { - auto post_graph_checker = [&](Graph& graph) { - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 0); - for (auto& node : graph.Nodes()) { - if (node.OpType() == "Split") { - auto& attrs = node.GetAttributes(); - TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); - TEST_RETURN_IF_NOT(2 == static_cast(attrs.at("axis").i())); - } - } - return Status::OK(); - }; - - std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), - TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); - } - - // OpSet-18 - { - auto post_graph_checker = [&](Graph& graph) { - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 0); - for (auto& node : graph.Nodes()) { - if (node.OpType() == "Split") { - auto& attrs = node.GetAttributes(); - TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); - TEST_RETURN_IF_NOT(2 == static_cast(attrs.at("axis").i())); - } + auto post_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 0); + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Split") { + auto& attrs = node.GetAttributes(); + TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); + TEST_RETURN_IF_NOT(2 == static_cast(attrs.at("axis").i())); } - return Status::OK(); - }; + } + return Status::OK(); + }; - std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer), - TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); - } + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer), TransformerLevel::Level1, + 1, pre_graph_checker, post_graph_checker)); } -TEST_F(GraphTransformationTests, GatherToSplitFusion_Consume_Input) { +TEST_F(GraphTransformationTests, GatherSliceToSplitFusion_Combined) { auto build_test_case = [&](ModelTestBuilder& builder) { - auto* data_arg = builder.MakeInput({{2, 3, 3, 3}}); - auto* gather_index_1 = builder.MakeInitializer({}, {static_cast(0)}); - auto* gather_index_2 = builder.MakeInitializer({}, {static_cast(1)}); - auto* gather_index_3 = builder.MakeInitializer({}, {static_cast(2)}); + auto* data_arg = builder.MakeInput({{144}}); + auto* shape_arg = builder.MakeInput({{4}}); + auto* reshape_out = 
builder.MakeIntermediate({{2, 8, 3, 3}}); + auto* gather_index_1 = builder.MakeInitializer({}, {static_cast(5)}); + auto* starts_2 = builder.MakeInitializer({1}, {6}); + auto* ends_2 = builder.MakeInitializer({1}, {8}); + auto* axes_2 = builder.MakeInitializer({1}, {-3}); + auto* steps_2 = builder.MakeInitializer({1}, {1}); + auto* gather_index_3 = builder.MakeInitializer({1}, {static_cast(4)}); + auto* starts_4 = builder.MakeInitializer({1}, {-16}); + auto* ends_4 = builder.MakeInitializer({1}, {4}); + auto* axes_4 = builder.MakeInitializer({1}, {1}); auto* gather_out_1 = builder.MakeIntermediate(); - auto* gather_out_2 = builder.MakeIntermediate(); + auto* slice_out_2 = builder.MakeIntermediate(); auto* gather_out_3 = builder.MakeIntermediate(); + auto* slice_out_4 = builder.MakeIntermediate(); auto* transpose_out_1 = builder.MakeOutput(); auto* transpose_out_2 = builder.MakeOutput(); auto* transpose_out_3 = builder.MakeOutput(); + auto* transpose_out_4 = builder.MakeOutput(); - builder.AddNode("Gather", {data_arg, gather_index_1}, {gather_out_1}).AddAttribute("axis", static_cast(2)); - builder.AddNode("Gather", {data_arg, gather_index_2}, {gather_out_2}) - .AddAttribute("axis", static_cast(-2)); - builder.AddNode("Gather", {data_arg, gather_index_3}, {gather_out_3}).AddAttribute("axis", static_cast(2)); + builder.AddNode("Reshape", {data_arg, shape_arg}, {reshape_out}); + builder.AddNode("Gather", {reshape_out, gather_index_1}, {gather_out_1}) + .AddAttribute("axis", static_cast(1)); + builder.AddNode("Slice", {reshape_out, starts_2, ends_2, axes_2, steps_2}, {slice_out_2}); + builder.AddNode("Gather", {reshape_out, gather_index_3}, {gather_out_3}) + .AddAttribute("axis", static_cast(-3)); + builder.AddNode("Slice", {reshape_out, starts_4, ends_4, axes_4}, {slice_out_4}); builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1}).AddAttribute("perm", std::vector{0, 2, 1}); - builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2}).AddAttribute("perm", std::vector{0, 2, 1}); - builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}).AddAttribute("perm", std::vector{0, 2, 1}); + builder.AddNode("Transpose", {slice_out_2}, {transpose_out_2}) + .AddAttribute("perm", std::vector{0, 2, 1, 3}); + builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}) + .AddAttribute("perm", std::vector{0, 2, 1, 3}); + builder.AddNode("Transpose", {slice_out_4}, {transpose_out_4}) + .AddAttribute("perm", std::vector{0, 2, 1, 3}); }; auto pre_graph_checker = [&](Graph& graph) { - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 2); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Slice"] == 2); return Status::OK(); }; - // OpSet-12 - { - auto post_graph_checker = [&](Graph& graph) { - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3); - for (auto& node : graph.Nodes()) { - if (node.OpType() == "Split") { - auto& attrs = node.GetAttributes(); - TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); - TEST_RETURN_IF_NOT(2 == static_cast(attrs.at("axis").i())); - } else if (node.OpType() == "Squeeze") { - auto& attrs = node.GetAttributes(); - TEST_RETURN_IF_NOT(attrs.find("axes") != attrs.end()); - TEST_RETURN_IF_NOT(2 == static_cast(attrs.at("axes").ints().at(0))); - } - } - return Status::OK(); - }; - - std::unique_ptr transformer = std::make_unique(); - 
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), - TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); - } - - // OpSet-14 - { - auto post_graph_checker = [&](Graph& graph) { - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3); - for (auto& node : graph.Nodes()) { - if (node.OpType() == "Split") { - auto& attrs = node.GetAttributes(); - TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); - TEST_RETURN_IF_NOT(2 == static_cast(attrs.at("axis").i())); - } else if (node.OpType() == "Squeeze") { - const NodeArg& input_arg = *(node.InputDefs()[1]); - const ONNX_NAMESPACE::TensorProto* tensor_proto = - graph_utils::GetConstantInitializer(graph, input_arg.Name()); - TEST_RETURN_IF_NOT(tensor_proto != nullptr); - Initializer init_const{*tensor_proto, graph.ModelPath()}; - TEST_RETURN_IF_NOT(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64); - TEST_RETURN_IF_NOT(2 == static_cast(*(init_const.data()))); - } - } - return Status::OK(); - }; - - std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), - TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); - } - - // OpSet-18 - { - auto post_graph_checker = [&](Graph& graph) { - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3); - for (auto& node : graph.Nodes()) { - if (node.OpType() == "Split") { - auto& attrs = node.GetAttributes(); - TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); - TEST_RETURN_IF_NOT(2 == static_cast(attrs.at("axis").i())); - } else if (node.OpType() == "Squeeze") { - const NodeArg& input_arg = *(node.InputDefs()[1]); - const ONNX_NAMESPACE::TensorProto* tensor_proto = - graph_utils::GetConstantInitializer(graph, input_arg.Name()); - TEST_RETURN_IF_NOT(tensor_proto != nullptr); - Initializer init_const{*tensor_proto, graph.ModelPath()}; - TEST_RETURN_IF_NOT(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64); - TEST_RETURN_IF_NOT(2 == static_cast(*(init_const.data()))); - } + auto post_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 1); + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Split") { + auto& attrs = node.GetAttributes(); + TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); + TEST_RETURN_IF_NOT(1 == static_cast(attrs.at("axis").i())); + } else if (node.OpType() == "Squeeze") { + const NodeArg& input_arg = *(node.InputDefs()[1]); + const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, input_arg.Name()); + TEST_RETURN_IF_NOT(tensor_proto != nullptr); + Initializer init_const{*tensor_proto, graph.ModelPath()}; + TEST_RETURN_IF_NOT(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64); + TEST_RETURN_IF_NOT(1 == static_cast(*(init_const.data()))); } - return Status::OK(); - }; + } + return Status::OK(); + }; - std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer), - TransformerLevel::Level1, 1, pre_graph_checker, 
post_graph_checker));
 }

-TEST_F(GraphTransformationTests, GatherToSplitFusion_Consume_Initializer) {
+TEST_F(GraphTransformationTests, GatherSliceToSplitFusion_Consume_Initializer) {
   auto build_test_case = [&](ModelTestBuilder& builder) {
     auto* data_arg = builder.MakeInitializer<float>({2, 3, 3, 3}, std::vector<float>(54));
     auto* gather_index_1 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(0)});
@@ -7430,31 +7310,31 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_Consume_Initializer) {
     return Status::OK();
   };

-  std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
+  std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherSliceToSplitFusion>();
   ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer),
                                         TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
 }

-TEST_F(GraphTransformationTests, GatherToSplitFusion_Invalid) {
+TEST_F(GraphTransformationTests, GatherSliceToSplitFusion_Invalid) {
   auto pre_graph_checker = [&](Graph& graph) {
-    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3);
+    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] > 0 || CountOpsInGraph(graph)["Slice"] > 0);
     return Status::OK();
   };
   auto post_graph_checker = [&](Graph& graph) {
-    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3);
+    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] > 0 || CountOpsInGraph(graph)["Slice"] > 0);
     TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 0);
     TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 0);
     return Status::OK();
   };

-  // Invalid shape.
+  // Does not cover all elements of the specific dimension.
{ auto build_test_case = [&](ModelTestBuilder& builder) { auto* data_arg = builder.MakeInput({{72}}); - auto* shape_arg = builder.MakeInput({{1}}); + auto* shape_arg = builder.MakeInput({{4}}); auto* reshape_out = builder.MakeIntermediate({{2, 3, 4, 3}}); auto* gather_index_1 = builder.MakeInitializer({}, {static_cast(0)}); - auto* gather_index_2 = builder.MakeInitializer({}, {static_cast(1)}); + auto* gather_index_2 = builder.MakeInitializer({1}, {static_cast(1)}); auto* gather_index_3 = builder.MakeInitializer({}, {static_cast(2)}); auto* gather_out_1 = builder.MakeIntermediate(); auto* gather_out_2 = builder.MakeIntermediate(); @@ -7467,63 +7347,65 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_Invalid) { builder.AddNode("Gather", {reshape_out, gather_index_1}, {gather_out_1}) .AddAttribute("axis", static_cast(2)); builder.AddNode("Gather", {reshape_out, gather_index_2}, {gather_out_2}) - .AddAttribute("axis", static_cast(2)); + .AddAttribute("axis", static_cast(-2)); builder.AddNode("Gather", {reshape_out, gather_index_3}, {gather_out_3}) .AddAttribute("axis", static_cast(2)); builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1}) .AddAttribute("perm", std::vector{0, 2, 1}); builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2}) - .AddAttribute("perm", std::vector{0, 2, 1}); + .AddAttribute("perm", std::vector{0, 2, 1, 3}); builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}) .AddAttribute("perm", std::vector{0, 2, 1}); }; - std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } - // Invalid Gather indices. + // Has overlap. 
{ auto build_test_case = [&](ModelTestBuilder& builder) { - auto* data_arg = builder.MakeInput({{54}}); - auto* shape_arg = builder.MakeInput({{1}}); - auto* reshape_out = builder.MakeIntermediate({{2, 3, 3, 3}}); - auto* gather_index_1 = builder.MakeInitializer({}, {static_cast(0)}); - auto* gather_index_2 = builder.MakeInitializer({}, {static_cast(1)}); - auto* gather_index_3 = builder.MakeInitializer({}, {static_cast(1)}); - auto* gather_out_1 = builder.MakeIntermediate(); - auto* gather_out_2 = builder.MakeIntermediate(); - auto* gather_out_3 = builder.MakeIntermediate(); + auto* data_arg = builder.MakeInput({{2, 3, 8, 3}}); + auto* starts_1 = builder.MakeInitializer({1}, {0}); + auto* ends_1 = builder.MakeInitializer({1}, {3}); + auto* axes_1 = builder.MakeInitializer({1}, {2}); + auto* steps_1 = builder.MakeInitializer({1}, {1}); + auto* starts_2 = builder.MakeInitializer({1}, {2}); + auto* ends_2 = builder.MakeInitializer({1}, {-2}); + auto* axes_2 = builder.MakeInitializer({1}, {-2}); + auto* steps_2 = builder.MakeInitializer({1}, {1}); + auto* starts_3 = builder.MakeInitializer({1}, {-2}); + auto* ends_3 = builder.MakeInitializer({1}, {16}); + auto* axes_3 = builder.MakeInitializer({1}, {2}); + auto* slice_out_1 = builder.MakeIntermediate(); + auto* slice_out_2 = builder.MakeIntermediate(); + auto* slice_out_3 = builder.MakeIntermediate(); auto* transpose_out_1 = builder.MakeOutput(); auto* transpose_out_2 = builder.MakeOutput(); auto* transpose_out_3 = builder.MakeOutput(); - builder.AddNode("Reshape", {data_arg, shape_arg}, {reshape_out}); - builder.AddNode("Gather", {reshape_out, gather_index_1}, {gather_out_1}) - .AddAttribute("axis", static_cast(2)); - builder.AddNode("Gather", {reshape_out, gather_index_2}, {gather_out_2}) - .AddAttribute("axis", static_cast(2)); - builder.AddNode("Gather", {reshape_out, gather_index_3}, {gather_out_3}) - .AddAttribute("axis", static_cast(2)); - builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1}) - .AddAttribute("perm", std::vector{0, 2, 1}); - builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2}) - .AddAttribute("perm", std::vector{0, 2, 1}); - builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}) - .AddAttribute("perm", std::vector{0, 2, 1}); + builder.AddNode("Slice", {data_arg, starts_1, ends_1, axes_1, steps_1}, {slice_out_1}); + builder.AddNode("Slice", {data_arg, starts_2, ends_2, axes_2, steps_2}, {slice_out_2}); + builder.AddNode("Slice", {data_arg, starts_3, ends_3, axes_3}, {slice_out_3}); + builder.AddNode("Transpose", {slice_out_1}, {transpose_out_1}) + .AddAttribute("perm", std::vector{0, 2, 1, 3}); + builder.AddNode("Transpose", {slice_out_2}, {transpose_out_2}) + .AddAttribute("perm", std::vector{0, 2, 1, 3}); + builder.AddNode("Transpose", {slice_out_3}, {transpose_out_3}) + .AddAttribute("perm", std::vector{0, 2, 1, 3}); }; - std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer), TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } - // Invalid Gather axis. + // Invalid axis. 
-  // Invalid Gather axis.
+  // Invalid axis.
  {
    auto build_test_case = [&](ModelTestBuilder& builder) {
      auto* data_arg = builder.MakeInput<float>({{54}});
-      auto* shape_arg = builder.MakeInput<int64_t>({{1}});
+      auto* shape_arg = builder.MakeInput<int64_t>({{4}});
      auto* reshape_out = builder.MakeIntermediate<float>({{2, 3, 3, 3}});
      auto* gather_index_1 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(0)});
      auto* gather_index_2 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(1)});
@@ -7550,7 +7432,7 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_Invalid) {
          .AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
    };

-    std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
+    std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherSliceToSplitFusion>();
    ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer),
                                          TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
  }
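The merged fusion can treat Gather and Slice consumers uniformly by reducing both to ranges on the candidate axis, as the commit message describes. A sketch of how a single-index Gather might be folded into the same half-open representation the Slice ranges use (the helper is hypothetical, for illustration only):

#include <cstdint>

// 'dim' is the size of the input tensor on the candidate split axis.
// A Gather with one (possibly negative) index covers exactly one element,
// i.e. the half-open range [index, index + 1).
bool GatherIndexToRange(int64_t index, int64_t dim, int64_t& start, int64_t& end) {
  if (index < 0) index += dim;                   // normalize a negative index
  if (index < 0 || index >= dim) return false;   // out of range: skip fusion
  start = index;
  end = index + 1;
  return true;
}

With every consumer expressed as a range, the coverage check sketched earlier decides fusion for pure-Gather, pure-Slice, and mixed patterns alike, which is what lets the dedicated tests deleted below collapse into the updated ones above.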
@@ -7643,143 +7525,5 @@ TEST_F(GraphTransformationTests, GatherToSliceFusion) {
  }
}

-TEST_F(GraphTransformationTests, GatherSliceToSplitFusion) {
-  {
-    auto build_test_case = [&](ModelTestBuilder& builder) {
-      auto* data_arg = builder.MakeInput<float>({{54}});
-      auto* reshape_arg = builder.MakeInput<int64_t>({{4}});
-      auto* reshape_out = builder.MakeIntermediate<float>({{2, 512, 73, 64}});
-      builder.AddNode("Reshape", {data_arg, reshape_arg}, {reshape_out});
-
-      // Create Gather-1 Ops
-      auto* gather_index_1 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(-2)});
-      auto* gather_out_1 = builder.MakeIntermediate<float>({{2, 512, 1, 64}});
-      builder.AddNode("Gather", {reshape_out, gather_index_1}, {gather_out_1})
-          .AddAttribute("axis", static_cast<int64_t>(2));
-
-      // Create Transpose 1-Ops
-      auto* transpose_out_1 = builder.MakeOutput();
-      builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1})
-          .AddAttribute("perm", std::vector<int64_t>{0, 2, 1, 3});
-
-      // Create Gather-2 Ops
-      auto* gather_index_2 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(-1)});
-      auto* gather_out_2 = builder.MakeIntermediate<float>({{2, 512, 1, 64}});
-      builder.AddNode("Gather", {reshape_out, gather_index_2}, {gather_out_2})
-          .AddAttribute("axis", static_cast<int64_t>(2));
-
-      // Create Transpose-2 Ops
-      auto* transpose_out_2 = builder.MakeOutput();
-      builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2})
-          .AddAttribute("perm", std::vector<int64_t>{0, 2, 1, 3});
-
-      // Create Slice Ops
-      auto* slice_output = builder.MakeIntermediate();
-      auto* starts = builder.MakeInitializer<int64_t>({1}, {0});
-      auto* ends = builder.MakeInitializer<int64_t>({1}, {-2});
-      auto* axes = builder.MakeInitializer<int64_t>({1}, {2});
-      auto* steps = builder.MakeInitializer<int64_t>({1}, {1});
-      builder.AddNode("Slice", {reshape_out, starts, ends, axes, steps}, {slice_output});
-
-      // Create Shape-1 Ops
-      auto* shape_output_1 = builder.MakeOutput();
-      builder.AddNode("Shape", {slice_output}, {shape_output_1});
-
-      // Create Shape-2 Ops
-      auto* shape_output_2 = builder.MakeOutput();
-      builder.AddNode("Shape", {slice_output}, {shape_output_2});
-
-      // Create Transpose-3 Ops
-      auto* transpose_out_3 = builder.MakeOutput();
-      builder.AddNode("Transpose", {slice_output}, {transpose_out_3})
-          .AddAttribute("perm", std::vector<int64_t>{0, 2, 1, 3});
-    };
-
-    auto pre_graph_checker = [&](Graph& graph) {
-      TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 2);
-      TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Slice"] == 1);
-      return Status::OK();
-    };
-
-    auto post_graph_checker = [&](Graph& graph) {
-      TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0);
-      TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Slice"] == 0);
-      TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1);
-
-      for (auto& node : graph.Nodes()) {
-        if (node.OpType() == "Split") {
-          auto& attrs = node.GetAttributes();
-          TEST_RETURN_IF_NOT(static_cast<int64_t>(attrs.at("axis").i()) == 2);
-        }
-      }
-      return Status::OK();
-    };
-
-    std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherSliceToSplitFusion>();
-    ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer),
-                                          TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
-  }
-}
-
-TEST_F(GraphTransformationTests, GatherSliceToSplitFusion_Invalid) {
-  {
-    auto build_test_case = [&](ModelTestBuilder& builder) {
-      auto* data_arg = builder.MakeInput<float>({{54}});
-      auto* reshape_arg = builder.MakeInput<int64_t>({{4}});
-      auto* reshape_out = builder.MakeIntermediate<float>({{2, 512, 73, 64}});
-      builder.AddNode("Reshape", {data_arg, reshape_arg}, {reshape_out});
-
-      // Create Gather-1 Ops
-      auto* gather_index_1 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(-2)});
-      auto* gather_out_1 = builder.MakeIntermediate<float>({{2, 512, 1, 64}});
-      builder.AddNode("Gather", {reshape_out, gather_index_1}, {gather_out_1})
-          .AddAttribute("axis", static_cast<int64_t>(2));
-
-      // Create Transpose 1-Ops
-      auto* transpose_out_1 = builder.MakeOutput();
-      builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1})
-          .AddAttribute("perm", std::vector<int64_t>{0, 2, 1, 3});
-
-      // Create Slice Ops
-      auto* slice_output = builder.MakeIntermediate();
-      auto* starts = builder.MakeInitializer<int64_t>({1}, {0});
-      auto* ends = builder.MakeInitializer<int64_t>({1}, {-2});
-      auto* axes = builder.MakeInitializer<int64_t>({1}, {2});
-      auto* steps = builder.MakeInitializer<int64_t>({1}, {1});
-      builder.AddNode("Slice", {reshape_out, starts, ends, axes, steps}, {slice_output});
-
-      // Create Shape-1 Ops
-      auto* shape_output_1 = builder.MakeOutput();
-      builder.AddNode("Shape", {slice_output}, {shape_output_1});
-
-      // Create Shape-2 Ops
-      auto* shape_output_2 = builder.MakeOutput();
-      builder.AddNode("Shape", {slice_output}, {shape_output_2});
-
-      // Create Transpose-3 Ops
-      auto* transpose_out_3 = builder.MakeOutput();
-      builder.AddNode("Transpose", {slice_output}, {transpose_out_3})
-          .AddAttribute("perm", std::vector<int64_t>{0, 2, 1, 3});
-    };
-
-    auto pre_graph_checker = [&](Graph& graph) {
-      TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 1);
-      TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Slice"] == 1);
-      return Status::OK();
-    };
-
-    auto post_graph_checker = [&](Graph& graph) {
-      TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 1);
-      TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Slice"] == 1);
-      TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 0);
-      return Status::OK();
-    };
-
-    std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherSliceToSplitFusion>();
-    ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer),
-                                          TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
-  }
-}
-
 }  // namespace test
 }  // namespace onnxruntime
diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc
index 0b68dc65e41cd..5d527369a1b75 100644
--- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc
+++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc
@@ -24,7 +24,6 @@
 #include "core/optimizer/fast_gelu_fusion.h"
 #include "core/optimizer/free_dim_override_transformer.h"
 #include "core/optimizer/gather_fusion.h"
-#include "core/optimizer/gather_slice_fusion.h"
 #include "core/optimizer/gelu_approximation.h"
 #include "core/optimizer/gelu_fusion.h"
 #include "core/optimizer/gemm_activation_fusion.h"
@@ -139,9 +138,8 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
       transformers.emplace_back(std::make_unique(compatible_eps));
       transformers.emplace_back(std::make_unique(compatible_eps));
       transformers.emplace_back(std::make_unique(compatible_eps));
-      transformers.emplace_back(std::make_unique<GatherToSplitFusion>(compatible_eps));
-      transformers.emplace_back(std::make_unique<GatherSliceToSplitFusion>(compatible_eps));
       transformers.emplace_back(std::make_unique<GatherToSliceFusion>(compatible_eps));
+      transformers.emplace_back(std::make_unique<GatherSliceToSplitFusion>(compatible_eps));
       // If a model with Q, DQ nodes is being used for the purpose of training, it must be for
       // Quantization Aware Training. So, replace QDQ nodes with FakeQuant.
       transformers.emplace_back(std::make_unique<QDQFusion>(compatible_eps));
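To make the end state concrete, here is a minimal sketch of the graph shape the merged fusion is expected to produce for three scalar-index Gathers on axis 2 of a rank-4 tensor, written in the ModelTestBuilder style of the tests above (names and shapes are illustrative, not taken from the patch):

auto build_fused_pattern = [&](ModelTestBuilder& builder) {
  auto* input_arg = builder.MakeInput<float>({{2, 3, 3, 3}});
  auto* split_out_1 = builder.MakeIntermediate();
  auto* split_out_2 = builder.MakeIntermediate();
  auto* split_out_3 = builder.MakeIntermediate();
  auto* squeeze_axes = builder.MakeInitializer<int64_t>({1}, {2});
  auto* out_1 = builder.MakeOutput();
  auto* out_2 = builder.MakeOutput();
  auto* out_3 = builder.MakeOutput();
  // One Split replaces all covering consumers; with no 'split' input it
  // divides the axis evenly across the three outputs.
  builder.AddNode("Split", {input_arg}, {split_out_1, split_out_2, split_out_3})
      .AddAttribute("axis", static_cast<int64_t>(2));
  // Scalar Gather indices dropped the split axis, so each Split output gets
  // a Squeeze (axes are an input from opset 13 on) to restore that shape.
  builder.AddNode("Squeeze", {split_out_1, squeeze_axes}, {out_1});
  builder.AddNode("Squeeze", {split_out_2, squeeze_axes}, {out_2});
  builder.AddNode("Squeeze", {split_out_3, squeeze_axes}, {out_3});
};

Had the consumers used 1-D indices of length 1 (or equivalent Slices), the Squeeze nodes would be omitted and the Split outputs consumed directly.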