-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add GatherSliceToSplitFusion and Unittest (#19218)
### Multi Query Attention Optimization in multi-query attention ``` batch_size, seq_length, three_times_hidden_size = fused_qkv.shape fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads + 2, self.head_dim) return fused_qkv[..., :-2, :], fused_qkv[..., [-2], :], fused_qkv[..., [-1], :] ``` which can be optimized to ``` batch_size, seq_length, three_times_hidden_size = fused_qkv.shape fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads + 2, self.head_dim) (query, key, value) = fused_qkv.split([self.num_heads, 1, 1], dim=2) return query, key, value ``` This optimization can be validated from Nsight profiling and perf benchmarking. <img width="545" alt="image" src="https://github.com/microsoft/onnxruntime/assets/15321482/cefcd061-4a01-4aaf-a008-8e265f7f63e9"> As such, this PR optimizes the `Gather/Gather/Slice` ops into a single `Split` kernel. ### Optimization Target <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> As 2 `Gather` and 1 `Slice` kernels are time-consuming for backward prop, it is more efficient to use 1 `Split` kernel. ### Example - Before Fusion <img width="419" alt="image" src="https://github.com/microsoft/onnxruntime/assets/15321482/17410319-57ea-4176-afd4-1efdcd3fdbae"> - After Fusion <img width="424" alt="image" src="https://github.com/microsoft/onnxruntime/assets/15321482/f1ee1582-96d4-45f4-8778-49d1f3fd370a"> ### Perf Gain After the optimization, there is a **~7%** perf gain. > The `Transpose` kernel can be fused too; will update it in a next PR. However, after testing Transpose ops fusion on the Falcon model, there is no perf gain. Will not create a new PR. --------- Co-authored-by: ruiren <[email protected]>
- Loading branch information
Showing
5 changed files
with
519 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,344 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#include "core/optimizer/gather_slice_fusion.h" | ||
#include "core/graph/graph_utils.h" | ||
#include "core/optimizer/initializer.h" | ||
#include "core/optimizer/utils.h" | ||
|
||
namespace onnxruntime { | ||
|
||
bool GatherSliceToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, | ||
int64_t& axis, int64_t& indices_n_dims) const { | ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gather", {1, 11, 13}) || | ||
!graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) { | ||
return false; | ||
} | ||
|
||
const NodeArg& input_arg = *(node.InputDefs()[1]); | ||
|
||
if (!optimizer_utils::IsScalar(input_arg)) return false; | ||
|
||
const ONNX_NAMESPACE::TensorProto* indices_init = graph_utils::GetConstantInitializer(graph, input_arg.Name()); | ||
|
||
if (!indices_init) return false; | ||
|
||
if (indices_init->data_type() != ONNX_NAMESPACE::TensorProto::INT64) return false; | ||
|
||
// get the index value | ||
Initializer init_const(*indices_init, graph.ModelPath()); | ||
index = *(init_const.data<int64_t>()); | ||
|
||
// get attributes value | ||
axis = 0; | ||
auto& attrs = node.GetAttributes(); | ||
if (attrs.find("axis") != attrs.end()) { | ||
auto& axis_attr = attrs.at("axis"); | ||
if (utils::HasInt(axis_attr)) axis = axis_attr.i(); | ||
} | ||
|
||
indices_n_dims = indices_init->dims_size(); | ||
return true; | ||
} | ||
|
||
bool GatherSliceToSplitFusion::IsSupportedSlice(const Graph& graph, const Node& node, | ||
InlinedVector<int64_t>& starts, | ||
InlinedVector<int64_t>& ends, | ||
InlinedVector<int64_t>& axes, | ||
InlinedVector<int64_t>& steps) const { | ||
// check the version of Slice ops | ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Slice", {1, 10, 11, 13}) || | ||
!graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) { | ||
return false; | ||
} | ||
|
||
// get the opset version | ||
int onnx_opset_version = -1; | ||
if (graph.DomainToVersionMap().find(kOnnxDomain) != graph.DomainToVersionMap().end()) { | ||
onnx_opset_version = graph.DomainToVersionMap().at(kOnnxDomain); | ||
} | ||
|
||
// If Slice op of opset version 1 | ||
if (onnx_opset_version == 1) { | ||
if (!graph_utils::GetRepeatedNodeAttributeValues(node, "starts", starts) || | ||
!graph_utils::GetRepeatedNodeAttributeValues(node, "ends", ends) || | ||
starts.size() != ends.size()) { | ||
return false; | ||
} | ||
|
||
if (graph_utils::GetRepeatedNodeAttributeValues(node, "axes", axes) && (axes.size() != starts.size())) { | ||
return false; | ||
} | ||
} | ||
|
||
// If Slice op of opset version >= 10 | ||
if (onnx_opset_version >= 10) { | ||
// node inputs include: starts - ends - axes - steps | ||
|
||
// return a pointer to the corresponding NodeArg if input of the node at the index exists | ||
auto get_input_if_exists = [&node](size_t input_index) -> const NodeArg* { | ||
const auto& input_defs = node.InputDefs(); | ||
const NodeArg* input = (input_defs.size() > input_index) ? input_defs[input_index] : nullptr; | ||
return (input == nullptr || !input->Exists()) ? nullptr : input; | ||
}; | ||
|
||
// return a pointer to the initializer if it is constant; otherwise, a nullptr | ||
auto get_initializer_if_constant = | ||
[&graph, get_input_if_exists](size_t input_index) -> const ONNX_NAMESPACE::TensorProto* { | ||
const NodeArg* input = get_input_if_exists(input_index); | ||
return input ? graph_utils::GetConstantInitializer(graph, input->Name()) : nullptr; | ||
}; | ||
|
||
// return the initialization data if it is constant | ||
auto get_initializer_data = | ||
[&graph](const ONNX_NAMESPACE::TensorProto* slice_initializer) -> InlinedVector<int64_t> { | ||
Initializer init(*slice_initializer, graph.ModelPath()); | ||
if (slice_initializer->data_type() == ONNX_NAMESPACE::TensorProto::INT32) { | ||
int32_t* init_data = init.data<int32_t>(); | ||
return InlinedVector<int64_t>(init_data, init_data + init.size()); | ||
} | ||
|
||
if (slice_initializer->data_type() == ONNX_NAMESPACE::TensorProto::INT64) { | ||
int64_t* init_data = init.data<int64_t>(); | ||
return InlinedVector<int64_t>(init_data, init_data + init.size()); | ||
} | ||
return {}; | ||
}; | ||
|
||
// starts and ends inputs have to exist, be constants and be of the same size. | ||
const ONNX_NAMESPACE::TensorProto* starts_init = get_initializer_if_constant(1); | ||
const ONNX_NAMESPACE::TensorProto* ends_init = get_initializer_if_constant(2); | ||
const ONNX_NAMESPACE::TensorProto* axes_init = get_initializer_if_constant(3); | ||
const ONNX_NAMESPACE::TensorProto* steps_init = get_initializer_if_constant(4); | ||
|
||
if (!starts_init || !ends_init || !axes_init || !steps_init) { | ||
return false; | ||
} | ||
|
||
starts = get_initializer_data(starts_init); | ||
ends = get_initializer_data(ends_init); | ||
axes = get_initializer_data(axes_init); | ||
steps = get_initializer_data(steps_init); | ||
|
||
if (starts.size() == 0 || ends.size() == 0 || starts.size() != ends.size()) { | ||
return false; | ||
} | ||
|
||
if (axes_init->dims_size() != 1 || static_cast<size_t>(axes_init->dims().Get(0)) != starts.size()) { | ||
return false; | ||
} | ||
|
||
// if steps exists, it should be constant and all value should be 1 | ||
if (steps.size() != starts.size()) { | ||
return false; | ||
} | ||
|
||
for (int64_t step : steps) { | ||
if (step != 1) { | ||
return false; | ||
} | ||
} | ||
} | ||
|
||
return true; | ||
} | ||
|
||
/*
GatherSliceToSplitFusion fuses this consumer pattern:
   Node
    |-> Gather(indices=i,   axis=axis)
    |-> Gather(indices=i+1, axis=axis)
    |-> Slice(axes=[axis], steps all 1)
into:
   Node
    |-> Split(axis=axis)  -- three outputs
so that one kernel produces all three outputs.
*/
|
||
// Scans the graph for Reshape outputs consumed by exactly (2 Gather + 1 Slice)
// along the same axis and replaces the three consumers with a single Split node.
// Sets `modified` to true when any fusion is performed.
Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
                                           const logging::Logger& logger) const {
  GraphViewer graph_viewer(graph);

  const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();

  // Candidate NodeArgs: outputs of Reshape nodes with enough consumers to
  // possibly match the 2xGather + 1xSlice pattern.
  InlinedVector<const NodeArg*> output_args;

  // Pass 1: iterate the topological order and collect Reshape outputs.
  for (auto node_index : node_topology_list) {
    auto* p_node = graph.GetNode(node_index);

    if (p_node == nullptr) continue;

    Node& node = *p_node;

    ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));

    // We currently only catch the Multi Query Attention scenario:
    //           |---> Gather
    //   Reshape |---> Gather
    //           |---> Slice
    // so only Reshape producers are considered; optimize in the future.
    if (node.OpType() != "Reshape") continue;

    // The pattern needs at least 3 consumers (2 Gather + 1 Slice).
    if (node.GetOutputEdgesCount() < 3) continue;

    output_args.push_back(node.OutputDefs()[0]);
  }

  // Pass 2: for each candidate, try to fuse its consumers into one Split.
  for (const NodeArg* node_arg : output_args) {
    auto shape = node_arg->Shape();
    if (!shape) continue;

    auto consumers = graph.GetConsumerNodes(node_arg->Name());
    size_t consumer_count = consumers.size();

    // Rank of the tensor being split.
    int64_t rank = static_cast<int64_t>(shape->dim_size());

    bool can_fuse = true;
    bool first_edge = true;
    int64_t split_axis = 0;
    int64_t indices_n_dims = -1;

    // Outputs of the fused Split: [0] = the Slice output,
    // [1] and [2] = the Gather outputs (filled back-to-front below).
    InlinedVector<NodeArg*> split_outputs(3);

    InlinedVector<std::reference_wrapper<Node>> nodes_to_fuse;
    size_t gather_node_count = 2, slice_node_count = 0;

    // Find the nodes to be merged.
    for (auto consumer : consumers) {
      // BUGFIX: validate the consumer pointer and its first input BEFORE
      // dereferencing it. The previous code called IsSupportedGather/
      // IsSupportedSlice on `*consumer` before the null check.
      if (consumer == nullptr || consumer->InputDefs()[0] != node_arg) {
        can_fuse = false;
        break;
      }

      int64_t index = 0, axis = 0, dims = 0;
      InlinedVector<int64_t> starts, ends, axes, steps;

      if (IsSupportedGather(graph, *consumer, index, axis, dims)) {
        // All scalar indices must share the same number of dimensions (0 or 1),
        // otherwise the fused Split outputs would need different shapes.
        if (indices_n_dims == -1) {
          indices_n_dims = dims;
        } else if (indices_n_dims != dims) {
          can_fuse = false;
          break;
        }

        // Normalize a negative axis.
        if (axis < 0) axis += rank;

        if (first_edge) {
          // The split axis must have a statically known dimension value.
          auto dim = shape->dim(static_cast<int>(axis));
          if (!utils::HasDimValue(dim)) {
            can_fuse = false;
            break;
          }
          split_axis = axis;
          first_edge = false;
        } else if (axis != split_axis) {
          // All fused consumers must operate on the same axis.
          can_fuse = false;
          break;
        }

        // Normalize and bound-check the gathered index.
        if (index < 0) index += static_cast<int64_t>(consumer_count);
        if (index < 0 || index >= static_cast<int64_t>(consumer_count)) {
          can_fuse = false;
          break;
        }

        Node& gather_node = *graph.GetNode(consumer->Index());
        nodes_to_fuse.push_back(gather_node);
        // Fill slots 2 then 1 so the two Gather outputs follow the Slice output.
        split_outputs[gather_node_count--] = gather_node.MutableOutputDefs()[0];
      } else if (IsSupportedSlice(graph, *consumer, starts, ends, axes, steps)) {
        // Defensive: an opset-1 Slice may report no axes; nothing to compare then.
        if (axes.empty()) {
          can_fuse = false;
          break;
        }

        // BUGFIX: the previous code compared axes[0] against the local `axis`,
        // which is only assigned by IsSupportedGather and is uninitialized here.
        // Compare the (normalized) slice axis against the common split axis.
        int64_t slice_axis = axes[0];
        if (slice_axis < 0) slice_axis += rank;

        if (first_edge) {
          auto dim = shape->dim(static_cast<int>(slice_axis));
          if (!utils::HasDimValue(dim)) {
            can_fuse = false;
            break;
          }
          split_axis = slice_axis;
          first_edge = false;
        } else if (slice_axis != split_axis) {
          can_fuse = false;
          break;
        }

        Node& slice_node = *graph.GetNode(consumer->Index());
        nodes_to_fuse.push_back(slice_node);
        split_outputs[slice_node_count++] = slice_node.MutableOutputDefs()[0];
      } else {
        // Any other consumer breaks the pattern.
        can_fuse = false;
        break;
      }
    }

    // Fuse only when the exact 2xGather + 1xSlice pattern was found.
    if (!can_fuse || gather_node_count != 0 || slice_node_count != 1) continue;

    // Build the constant "split" input: the Slice keeps (dim_value - 2) elements
    // along the split axis and each of the two Gathers takes 1 element.
    ONNX_NAMESPACE::TensorProto split_initializer_proto;
    // BUGFIX: initializers are NodeArgs, so generate a NodeArg name (the previous
    // code used GenerateNodeName, which draws from the node-name namespace).
    split_initializer_proto.set_name(graph.GenerateNodeArgName("fused_Split"));
    split_initializer_proto.add_dims(static_cast<int64_t>(3));
    split_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);

    auto dim_value = shape->dim(static_cast<int>(split_axis)).dim_value();
    // Two Gather nodes are fused away, so the Slice keeps dim_value - 2.
    int64_t slice_dim = static_cast<int64_t>(dim_value - 2);
    InlinedVector<int64_t> split_value{{slice_dim, 1, 1}};
    split_initializer_proto.set_raw_data(split_value.data(), split_value.size() * sizeof(int64_t));
    NodeArg* split_arg = &graph_utils::AddInitializer(graph, split_initializer_proto);

    // The Split reuses the Gather/Slice output NodeArgs directly, so downstream
    // consumers keep their inputs unchanged.
    Node& split_node =
        graph.AddNode(graph.GenerateNodeName("Split"), "Split", "Split for fused Gather-Slice fusion",
                      {graph.GetNodeArg(node_arg->Name()), split_arg}, split_outputs);

    split_node.AddAttribute("axis", split_axis);

    split_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType());

    int onnx_opset_version = -1;
    if (graph.DomainToVersionMap().find(kOnnxDomain) != graph.DomainToVersionMap().end()) {
      onnx_opset_version = graph.DomainToVersionMap().at(kOnnxDomain);
    }

    if (onnx_opset_version >= 18) {
      // NOTE(review): ONNX Split-18 documents "split" input and "num_outputs"
      // attribute as mutually exclusive; confirm setting both is intended here.
      split_node.AddAttribute("num_outputs", static_cast<int64_t>(consumer_count));
    }

    // Remove the fused nodes; their outputs are now produced by the Split.
    for (Node& node_to_fuse : nodes_to_fuse) {
      graph_utils::RemoveNodeOutputEdges(graph, node_to_fuse);
      graph.RemoveNode(node_to_fuse.Index());
    }
    modified = true;
  }

  return Status::OK();
}
} // namespace onnxruntime |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#pragma once | ||
|
||
#include "core/optimizer/graph_transformer.h" | ||
|
||
namespace onnxruntime { | ||
|
||
/** | ||
@class GatherSliceToSplitFusion | ||
Fuse (2 Gather nodes + 1 Slice) to 1 split node. | ||
*/ | ||
|
||
class GatherSliceToSplitFusion : public GraphTransformer { | ||
private: | ||
bool IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis, | ||
int64_t& indices_n_dims) const; | ||
|
||
bool IsSupportedSlice(const Graph& graph, const Node& node, | ||
InlinedVector<int64_t>& starts, | ||
InlinedVector<int64_t>& ends, | ||
InlinedVector<int64_t>& axes, | ||
InlinedVector<int64_t>& steps) const; | ||
|
||
public: | ||
GatherSliceToSplitFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept | ||
: GraphTransformer("GatherSliceToSplitFusion", compatible_execution_providers) {} | ||
|
||
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; | ||
}; | ||
} // namespace onnxruntime |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.