From a67e6925468effd1897c2f541821d32a2860a037 Mon Sep 17 00:00:00 2001 From: rui-ren Date: Wed, 14 Feb 2024 15:07:56 -0800 Subject: [PATCH] add GatherSliceToSplitFusion and Unittest (#19218) ### Multi Query Attention Optimization in multi-query attention ``` batch_size, seq_length, three_times_hidden_size = fused_qkv.shape fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads + 2, self.head_dim) return fused_qkv[..., :-2, :], fused_qkv[..., [-2], :], fused_qkv[..., [-1], :] ``` which can be optimized to ``` batch_size, seq_length, three_times_hidden_size = fused_qkv.shape fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads + 2, self.head_dim) (query, key, value) = fused_qkv.split([self.num_heads, 1, 1], dim=2) return query, key, value ``` This optimization can be validated via Nsight profiling and perf benchmarking. image As such, this PR optimizes the `Gather/Gather/Slice` ops into a single `Split` kernel. ### Optimization Target Since the 2 `Gather` and 1 `Slice` kernels are time-consuming in backward prop, it is more efficient to use a single `Split` kernel. ### Example - Before Fusion image - After Fusion image ### Perf Gain After the optimization, there is a **~7%** perf gain. > The `Transpose` kernel could be fused too; however, after testing Transpose op fusion on the Falcon model, there was no perf gain, so a follow-up PR will not be created. 
--------- Co-authored-by: ruiren --- .../core/optimizer/gather_slice_fusion.cc | 344 ++++++++++++++++++ .../core/optimizer/gather_slice_fusion.h | 32 ++ .../core/optimizer/graph_transformer_utils.cc | 2 + .../test/optimizer/graph_transform_test.cc | 139 +++++++ .../core/optimizer/graph_transformer_utils.cc | 2 + 5 files changed, 519 insertions(+) create mode 100644 onnxruntime/core/optimizer/gather_slice_fusion.cc create mode 100644 onnxruntime/core/optimizer/gather_slice_fusion.h diff --git a/onnxruntime/core/optimizer/gather_slice_fusion.cc b/onnxruntime/core/optimizer/gather_slice_fusion.cc new file mode 100644 index 0000000000000..21266d356a020 --- /dev/null +++ b/onnxruntime/core/optimizer/gather_slice_fusion.cc @@ -0,0 +1,344 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/optimizer/gather_slice_fusion.h" +#include "core/graph/graph_utils.h" +#include "core/optimizer/initializer.h" +#include "core/optimizer/utils.h" + +namespace onnxruntime { + +bool GatherSliceToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, + int64_t& axis, int64_t& indices_n_dims) const { + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gather", {1, 11, 13}) || + !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) { + return false; + } + + const NodeArg& input_arg = *(node.InputDefs()[1]); + + if (!optimizer_utils::IsScalar(input_arg)) return false; + + const ONNX_NAMESPACE::TensorProto* indices_init = graph_utils::GetConstantInitializer(graph, input_arg.Name()); + + if (!indices_init) return false; + + if (indices_init->data_type() != ONNX_NAMESPACE::TensorProto::INT64) return false; + + // get the index value + Initializer init_const(*indices_init, graph.ModelPath()); + index = *(init_const.data()); + + // get attributes value + axis = 0; + auto& attrs = node.GetAttributes(); + if (attrs.find("axis") != attrs.end()) { + auto& axis_attr = 
attrs.at("axis"); + if (utils::HasInt(axis_attr)) axis = axis_attr.i(); + } + + indices_n_dims = indices_init->dims_size(); + return true; +} + +bool GatherSliceToSplitFusion::IsSupportedSlice(const Graph& graph, const Node& node, + InlinedVector& starts, + InlinedVector& ends, + InlinedVector& axes, + InlinedVector& steps) const { + // check the version of Slice ops + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Slice", {1, 10, 11, 13}) || + !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) { + return false; + } + + // get the opset version + int onnx_opset_version = -1; + if (graph.DomainToVersionMap().find(kOnnxDomain) != graph.DomainToVersionMap().end()) { + onnx_opset_version = graph.DomainToVersionMap().at(kOnnxDomain); + } + + // If Slice op of opset version 1 + if (onnx_opset_version == 1) { + if (!graph_utils::GetRepeatedNodeAttributeValues(node, "starts", starts) || + !graph_utils::GetRepeatedNodeAttributeValues(node, "ends", ends) || + starts.size() != ends.size()) { + return false; + } + + if (graph_utils::GetRepeatedNodeAttributeValues(node, "axes", axes) && (axes.size() != starts.size())) { + return false; + } + } + + // If Slice op of opset version >= 10 + if (onnx_opset_version >= 10) { + // node inputs include: starts - ends - axes - steps + + // return a pointer to the corresponding NodeArg if input of the node at the index exists + auto get_input_if_exists = [&node](size_t input_index) -> const NodeArg* { + const auto& input_defs = node.InputDefs(); + const NodeArg* input = (input_defs.size() > input_index) ? input_defs[input_index] : nullptr; + return (input == nullptr || !input->Exists()) ? nullptr : input; + }; + + // return a pointer to the initializer if it is constant; otherwise, a nullptr + auto get_initializer_if_constant = + [&graph, get_input_if_exists](size_t input_index) -> const ONNX_NAMESPACE::TensorProto* { + const NodeArg* input = get_input_if_exists(input_index); + return input ? 
graph_utils::GetConstantInitializer(graph, input->Name()) : nullptr; + }; + + // return the initialization data if it is constant + auto get_initializer_data = + [&graph](const ONNX_NAMESPACE::TensorProto* slice_initializer) -> InlinedVector { + Initializer init(*slice_initializer, graph.ModelPath()); + if (slice_initializer->data_type() == ONNX_NAMESPACE::TensorProto::INT32) { + int32_t* init_data = init.data(); + return InlinedVector(init_data, init_data + init.size()); + } + + if (slice_initializer->data_type() == ONNX_NAMESPACE::TensorProto::INT64) { + int64_t* init_data = init.data(); + return InlinedVector(init_data, init_data + init.size()); + } + return {}; + }; + + // starts and ends inputs have to exist, be constants and be of the same size. + const ONNX_NAMESPACE::TensorProto* starts_init = get_initializer_if_constant(1); + const ONNX_NAMESPACE::TensorProto* ends_init = get_initializer_if_constant(2); + const ONNX_NAMESPACE::TensorProto* axes_init = get_initializer_if_constant(3); + const ONNX_NAMESPACE::TensorProto* steps_init = get_initializer_if_constant(4); + + if (!starts_init || !ends_init || !axes_init || !steps_init) { + return false; + } + + starts = get_initializer_data(starts_init); + ends = get_initializer_data(ends_init); + axes = get_initializer_data(axes_init); + steps = get_initializer_data(steps_init); + + if (starts.size() == 0 || ends.size() == 0 || starts.size() != ends.size()) { + return false; + } + + if (axes_init->dims_size() != 1 || static_cast(axes_init->dims().Get(0)) != starts.size()) { + return false; + } + + // if steps exists, it should be constant and all value should be 1 + if (steps.size() != starts.size()) { + return false; + } + + for (int64_t step : steps) { + if (step != 1) { + return false; + } + } + } + + return true; +} + +/* +GatherToSplitFusion is to fuse: + Node + |-> Gather(index=0, axis=axis) + |-> Gather(index=1, axis=axis) + |-> Slice(index=2, axis=axis) +To + Node + |-> Split(index=0) +So that we can use 
one kernel to finish the job. +*/ + +Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, + const logging::Logger& logger) const { + GraphViewer graph_viewer(graph); + + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + + InlinedVector output_args; + + // Iterate the topological order and get Reshape ops + for (auto node_index : node_topology_list) { + auto* p_node = graph.GetNode(node_index); + + if (p_node == nullptr) continue; + + Node& node = *p_node; + + ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); + + // Currently only catch after Reshape ops, optimize in the future + if (node.OpType() != "Reshape") continue; + + size_t output_count = node.GetOutputEdgesCount(); + + // We only catch 1 scenario for Multi Query Attention for now. + // |---> Gather + // Reshape |---> Gather + // |---> Slice + // |... or (other ops) + + // Get the output into node args + if (output_count < 3) continue; + + output_args.push_back(node.OutputDefs()[0]); + } + + // iterate the children of Reshape node + for (const NodeArg* node_arg : output_args) { + auto shape = node_arg->Shape(); + if (!shape) continue; + + auto consumers = graph.GetConsumerNodes(node_arg->Name()); + size_t consumer_count = consumers.size(); + + // get the tensor rank + int64_t rank = static_cast(shape->dim_size()); + + bool can_fuse = true; + bool first_edge = true; + int64_t split_axis = 0; + int64_t indices_n_dims = -1; + + // Fuse 2 Gathers and 1 slice to Split + // Get those outputs as Split outputs + InlinedVector split_outputs(3); + + InlinedVector> nodes_to_fuse; + size_t gather_node_count = 2, slice_node_count = 0; + + // find the nodes to be merged + for (auto consumer : consumers) { + int64_t index, axis, dims; + InlinedVector starts, ends, axes, steps; + + bool IsSupportedGatherOps = IsSupportedGather(graph, *consumer, index, axis, dims); + bool IsSupportedSliceOps = IsSupportedSlice(graph, *consumer, starts, ends, 
axes, steps); + + if ((!consumer || consumer->InputDefs()[0] != node_arg) || + (!IsSupportedGatherOps && !IsSupportedSliceOps)) { + break; + } + + if (IsSupportedGatherOps) { + if (indices_n_dims == -1) { + indices_n_dims = dims; + } else if (indices_n_dims != dims) { + // Not the same number of dimensions (0 or 1) for all scalar indices. + can_fuse = false; + break; + } + + if (axis < 0) axis += rank; + + if (first_edge) { + auto dim = shape->dim(static_cast(axis)); + // dim.dim_value() = 73 + if (!utils::HasDimValue(dim)) { + can_fuse = false; + break; + } + split_axis = axis; + first_edge = false; + } else if (axis != split_axis) { + can_fuse = false; + break; + } + + if (index < 0) index += static_cast(consumer_count); + if (index < 0 || index >= static_cast(consumer_count)) { + can_fuse = false; + break; + } + + Node& gather_node = *graph.GetNode(consumer->Index()); + nodes_to_fuse.push_back(gather_node); + NodeArg* gather_output_args = gather_node.MutableOutputDefs()[0]; + split_outputs[gather_node_count--] = gather_output_args; + } + + // check the Slice Ops + if (IsSupportedSliceOps) { + if (axes[0] != axis && !first_edge) { + can_fuse = false; + break; + } + + Node& slice_node = *graph.GetNode(consumer->Index()); + NodeArg* slice_output_args = slice_node.MutableOutputDefs()[0]; + nodes_to_fuse.push_back(slice_node); + split_outputs[slice_node_count++] = slice_output_args; + } + } + + // condition check + if (!can_fuse || gather_node_count != 0 || slice_node_count != 1) continue; + + // generate the split node and merge the kernel + ONNX_NAMESPACE::TypeProto split_output_type; + const ONNX_NAMESPACE::TensorProto_DataType element_type = static_cast( + node_arg->TypeAsProto()->tensor_type().elem_type()); + + split_output_type.mutable_tensor_type()->set_elem_type(element_type); + + for (int64_t i = 0; i < rank; i++) { + if (i == split_axis) + split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1LL); + else + 
*(split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()) = shape->dim(static_cast(i)); + } + + InlinedVector split_output_types; + + for (size_t i = 0; i < consumer_count; ++i) { + split_output_types.push_back( + &graph.GetOrCreateNodeArg( + graph.GenerateNodeArgName("fused_split_" + std::to_string(i)), &split_output_type)); + } + + // Generate the Split Node + ONNX_NAMESPACE::TensorProto split_initializer_proto; + split_initializer_proto.set_name(graph.GenerateNodeName("fused_Split")); + split_initializer_proto.add_dims(static_cast(3)); + split_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + + auto dim_value = shape->dim(static_cast(split_axis)).dim_value(); + // Optimize 2 Gather Nodes, so Slice_dim = dim_value - 2 + int64_t slice_dim = static_cast(dim_value - 2); + InlinedVector split_value{{slice_dim, 1, 1}}; + split_initializer_proto.set_raw_data(split_value.data(), split_value.size() * sizeof(int64_t)); + NodeArg* split_arg = &graph_utils::AddInitializer(graph, split_initializer_proto); + + Node& split_node = + graph.AddNode(graph.GenerateNodeName("Split"), "Split", "Split for fused Gather-Slice fusion", + {graph.GetNodeArg(node_arg->Name()), split_arg}, split_outputs); + + split_node.AddAttribute("axis", split_axis); + + split_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType()); + + int onnx_opset_version = -1; + if (graph.DomainToVersionMap().find(kOnnxDomain) != graph.DomainToVersionMap().end()) { + onnx_opset_version = graph.DomainToVersionMap().at(kOnnxDomain); + } + + if (onnx_opset_version >= 18) { + split_node.AddAttribute("num_outputs", static_cast(consumer_count)); + } + + for (Node& node_to_fuse : nodes_to_fuse) { + graph_utils::RemoveNodeOutputEdges(graph, node_to_fuse); + graph.RemoveNode(node_to_fuse.Index()); + } + modified = true; + } + + return Status::OK(); +} +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/gather_slice_fusion.h 
b/onnxruntime/core/optimizer/gather_slice_fusion.h new file mode 100644 index 0000000000000..1c5c307efed7f --- /dev/null +++ b/onnxruntime/core/optimizer/gather_slice_fusion.h @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/graph_transformer.h" + +namespace onnxruntime { + +/** +@class GatherSliceToSplitFusion +Fuse (2 Gather nodes + 1 Slice) to 1 split node. +*/ + +class GatherSliceToSplitFusion : public GraphTransformer { + private: + bool IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis, + int64_t& indices_n_dims) const; + + bool IsSupportedSlice(const Graph& graph, const Node& node, + InlinedVector& starts, + InlinedVector& ends, + InlinedVector& axes, + InlinedVector& steps) const; + + public: + GatherSliceToSplitFusion(const InlinedHashSet& compatible_execution_providers = {}) noexcept + : GraphTransformer("GatherSliceToSplitFusion", compatible_execution_providers) {} + + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; +}; +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index cd3c49be15aa4..4e939fe3c7b6b 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -37,6 +37,7 @@ #include "core/optimizer/fast_gelu_fusion.h" #include "core/optimizer/free_dim_override_transformer.h" #include "core/optimizer/gather_fusion.h" +#include "core/optimizer/gather_slice_fusion.h" #include "core/optimizer/gelu_approximation.h" #include "core/optimizer/gelu_fusion.h" #include "core/optimizer/gemm_activation_fusion.h" @@ -308,6 +309,7 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); 
transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); + transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index bf02c1741725f..e1fcf835c6043 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -42,6 +42,7 @@ #include "core/optimizer/expand_elimination.h" #include "core/optimizer/fast_gelu_fusion.h" #include "core/optimizer/gather_fusion.h" +#include "core/optimizer/gather_slice_fusion.h" #include "core/optimizer/gelu_approximation.h" #include "core/optimizer/gelu_fusion.h" #include "core/optimizer/gemm_activation_fusion.h" @@ -7642,5 +7643,143 @@ TEST_F(GraphTransformationTests, GatherToSliceFusion) { } } +TEST_F(GraphTransformationTests, GatherSliceToSplitFusion) { + { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* data_arg = builder.MakeInput({{54}}); + auto* reshape_arg = builder.MakeInput({{4}}); + auto* reshape_out = builder.MakeIntermediate({{2, 512, 73, 64}}); + builder.AddNode("Reshape", {data_arg, reshape_arg}, {reshape_out}); + + // Create Gather-1 Ops + auto* gather_index_1 = builder.MakeInitializer({}, {static_cast(-2)}); + auto* gather_out_1 = builder.MakeIntermediate({{2, 512, 1, 64}}); + builder.AddNode("Gather", {reshape_out, gather_index_1}, {gather_out_1}) + .AddAttribute("axis", static_cast(2)); + + // Create Transpose 1-Ops + auto* transpose_out_1 = builder.MakeOutput(); + builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1}) + .AddAttribute("perm", std::vector{0, 2, 1, 3}); + + // Create Gather-2 Ops + auto* gather_index_2 = builder.MakeInitializer({}, {static_cast(-1)}); + auto* gather_out_2 = 
builder.MakeIntermediate({{2, 512, 1, 64}}); + builder.AddNode("Gather", {reshape_out, gather_index_2}, {gather_out_2}) + .AddAttribute("axis", static_cast(2)); + + // Create Transpose-2 Ops + auto* transpose_out_2 = builder.MakeOutput(); + builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2}) + .AddAttribute("perm", std::vector{0, 2, 1, 3}); + + // Create Slice Ops + auto* slice_output = builder.MakeIntermediate(); + auto* starts = builder.MakeInitializer({1}, {0}); + auto* ends = builder.MakeInitializer({1}, {-2}); + auto* axes = builder.MakeInitializer({1}, {2}); + auto* steps = builder.MakeInitializer({1}, {1}); + builder.AddNode("Slice", {reshape_out, starts, ends, axes, steps}, {slice_output}); + + // Create Shape-1 Ops + auto* shape_output_1 = builder.MakeOutput(); + builder.AddNode("Shape", {slice_output}, {shape_output_1}); + + // Create Shape-2 Ops + auto* shape_output_2 = builder.MakeOutput(); + builder.AddNode("Shape", {slice_output}, {shape_output_2}); + + // Create Transpose-3 Ops + auto* transpose_out_3 = builder.MakeOutput(); + builder.AddNode("Transpose", {slice_output}, {transpose_out_3}) + .AddAttribute("perm", std::vector{0, 2, 1, 3}); + }; + + auto pre_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 2); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Slice"] == 1); + return Status::OK(); + }; + + auto post_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Slice"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); + + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Split") { + auto& attrs = node.GetAttributes(); + TEST_RETURN_IF_NOT(static_cast(attrs.at("axis").i()) == 2); + } + } + return Status::OK(); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), + 
TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); + } +} + +TEST_F(GraphTransformationTests, GatherSliceToSplitFusion_Invalid) { + { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* data_arg = builder.MakeInput({{54}}); + auto* reshape_arg = builder.MakeInput({{4}}); + auto* reshape_out = builder.MakeIntermediate({{2, 512, 73, 64}}); + builder.AddNode("Reshape", {data_arg, reshape_arg}, {reshape_out}); + + // Create Gather-1 Ops + auto* gather_index_1 = builder.MakeInitializer({}, {static_cast(-2)}); + auto* gather_out_1 = builder.MakeIntermediate({{2, 512, 1, 64}}); + builder.AddNode("Gather", {reshape_out, gather_index_1}, {gather_out_1}) + .AddAttribute("axis", static_cast(2)); + + // Create Transpose 1-Ops + auto* transpose_out_1 = builder.MakeOutput(); + builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1}) + .AddAttribute("perm", std::vector{0, 2, 1, 3}); + + // Create Slice Ops + auto* slice_output = builder.MakeIntermediate(); + auto* starts = builder.MakeInitializer({1}, {0}); + auto* ends = builder.MakeInitializer({1}, {-2}); + auto* axes = builder.MakeInitializer({1}, {2}); + auto* steps = builder.MakeInitializer({1}, {1}); + builder.AddNode("Slice", {reshape_out, starts, ends, axes, steps}, {slice_output}); + + // Create Shape-1 Ops + auto* shape_output_1 = builder.MakeOutput(); + builder.AddNode("Shape", {slice_output}, {shape_output_1}); + + // Create Shape-2 Ops + auto* shape_output_2 = builder.MakeOutput(); + builder.AddNode("Shape", {slice_output}, {shape_output_2}); + + // Create Transpose-3 Ops + auto* transpose_out_3 = builder.MakeOutput(); + builder.AddNode("Transpose", {slice_output}, {transpose_out_3}) + .AddAttribute("perm", std::vector{0, 2, 1, 3}); + }; + + auto pre_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Slice"] == 1); + return Status::OK(); + }; + + auto post_graph_checker = [&](Graph& 
graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Slice"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 0); + return Status::OK(); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); + } +} + } // namespace test } // namespace onnxruntime diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc index 894fe3b052fb2..0b68dc65e41cd 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc +++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc @@ -24,6 +24,7 @@ #include "core/optimizer/fast_gelu_fusion.h" #include "core/optimizer/free_dim_override_transformer.h" #include "core/optimizer/gather_fusion.h" +#include "core/optimizer/gather_slice_fusion.h" #include "core/optimizer/gelu_approximation.h" #include "core/optimizer/gelu_fusion.h" #include "core/optimizer/gemm_activation_fusion.h" @@ -140,6 +141,7 @@ std::vector> GeneratePreTrainingTransformers( transformers.emplace_back(std::make_unique(compatible_eps)); transformers.emplace_back(std::make_unique(compatible_eps)); transformers.emplace_back(std::make_unique(compatible_eps)); + transformers.emplace_back(std::make_unique(compatible_eps)); // If a model with Q, DQ nodes is being used for the purpose of training, it must be for // Quantization Aware Training. So, replace QDQ nodes with FakeQuant. transformers.emplace_back(std::make_unique(compatible_eps));