diff --git a/onnxruntime/core/optimizer/attention_fusion.cc b/onnxruntime/core/optimizer/attention_fusion.cc
index 64a38214caff0..ff8943de79679 100644
--- a/onnxruntime/core/optimizer/attention_fusion.cc
+++ b/onnxruntime/core/optimizer/attention_fusion.cc
@@ -50,7 +50,8 @@ void MergeWeights(const T* q, const T* k, const T* v, std::vector<T>& result, in
 
 // Merge 2-D weights (q, k and v) by concatenating them row by row.
 template <typename T>
-void MergeMatMulWeights(const T* q_weight, const T* k_weight, const T* v_weight, std::vector<T>& result, int64_t hidden_size) {
+void MergeMatMulWeights(const T* q_weight, const T* k_weight, const T* v_weight,
+                        std::vector<T>& result, int64_t hidden_size) {
   const T* q = q_weight;
   const T* k = k_weight;
   const T* v = v_weight;
@@ -144,7 +145,8 @@ static NodeArg& MergeQkvWeights(Graph& graph, int64_t hidden_size,
   return graph_utils::AddInitializer(graph, initializer);
 }
 
-static NodeArg* ConvertMaskToInt32(Graph& graph, NodeArg* mask_input, ProviderType provider_type, const logging::Logger& logger) {
+static NodeArg* ConvertMaskToInt32(Graph& graph, NodeArg* mask_input, ProviderType provider_type,
+                                   const logging::Logger& logger) {
   // Validate mask input shape (batch_size, sequence_length) and data type.
   // Note that batch_size and sequence_length could be symbolic.
   const TensorShapeProto* mask_shape = mask_input->Shape();
@@ -208,9 +210,11 @@ Status AttentionFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
     Node& node = *p_node;
     ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
 
-    if ((node.GetOutputEdgesCount() >= 2 && node.GetOutputEdgesCount() <= 6) &&  // Add node.GetOutputEdgesCount() == 5/6 for distilbert
+    // Add node.GetOutputEdgesCount() == 5/6 for distilbert
+    if ((node.GetOutputEdgesCount() >= 2 && node.GetOutputEdgesCount() <= 6) &&
         graph_utils::IsSupportedOptypeVersionAndDomain(node, "LayerNormalization", {1, 17}, kOnnxDomain) &&
-        graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) && node.InputDefs().size() > 2) {
+        graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) &&
+        node.InputDefs().size() > 2 && node.InputDefs()[2]->Exists()) {  // Bias is an optional input for LayerNorm
       // Get hidden size from layer norm bias tensor shape.
       const NodeArg& layer_norm_bias = *(node.InputDefs()[2]);
       if (!optimizer_utils::IsShapeKnownOnAllDims(layer_norm_bias, 1)) {
@@ -242,8 +246,10 @@ Status AttentionFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
           fused_count++;
           modified = true;
         }
-      } else if (reshape_count == 1 && (shape_count == 1 || shape_count == 3) && (static_cast<size_t>(reshape_count) + shape_count) == node.GetOutputEdgesCount()) {  // GPT
-        if (AttentionFusionHelper::FuseGptAttention(node, graph, hidden_size, mask_int32_map, shape_count == 1, logger)) {
+      } else if (reshape_count == 1 && (shape_count == 1 || shape_count == 3) &&
+                 (static_cast<size_t>(reshape_count) + shape_count) == node.GetOutputEdgesCount()) {  // GPT
+        if (AttentionFusionHelper::FuseGptAttention(node, graph, hidden_size, mask_int32_map, shape_count == 1,
+                                                    logger)) {
           fused_count++;
           modified = true;
         }
@@ -301,7 +307,8 @@ static bool FuseSubGraphQKImpl(Node& layer_norm,
     return false;
   }
 
-  if (!AttentionFusionHelper::CheckNodesInPathQ(graph, pivot_nodes[1].get(), q_reshape, q_transpose, num_heads, head_size, logger)) {
+  if (!AttentionFusionHelper::CheckNodesInPathQ(graph, pivot_nodes[1].get(),
+                                                q_reshape, q_transpose, num_heads, head_size, logger)) {
     DEBUG_LOG("CheckNodesInPathQ returns false");
     return false;
   }
@@ -365,7 +372,8 @@ static bool FuseSubGraphQKImpl(Node& layer_norm,
   }
 
   // Now everything is ready, we will start fusing subgraph.
-  NodeArg* mask_int32 = ConvertMaskToInt32(graph, mask_input, mask_int32_map, layer_norm.GetExecutionProviderType(), logger);
+  NodeArg* mask_int32 = ConvertMaskToInt32(graph, mask_input, mask_int32_map, layer_norm.GetExecutionProviderType(),
+                                           logger);
   if (nullptr == mask_int32) {
     DEBUG_LOG("Failed to convert mask to int32");
     return false;
@@ -438,7 +446,8 @@ static bool FuseSubGraphQK(Node& layer_norm,
   }
 
   std::vector<NodeIndex> nodes_to_remove;
-  if (!FuseSubGraphQKImpl(layer_norm, graph, parent_path_nodes, mask_input, mask_int32_map, edges, nodes_to_remove, hidden_size,
+  if (!FuseSubGraphQKImpl(layer_norm, graph, parent_path_nodes,
+                          mask_input, mask_int32_map, edges, nodes_to_remove, hidden_size,
                           num_heads, head_size, mask_nodes.mask_filter_value, logger)) {
     return false;
   }
@@ -529,7 +538,8 @@ static bool FuseSubGraphQKDistilBert(Node& layer_norm,
   }
 
   std::vector<NodeIndex> nodes_to_remove;
-  if (!FuseSubGraphQKImpl(layer_norm, graph, parent_path_nodes, mask_input, mask_int32_map, edges, nodes_to_remove, hidden_size,
+  if (!FuseSubGraphQKImpl(layer_norm, graph, parent_path_nodes,
+                          mask_input, mask_int32_map, edges, nodes_to_remove, hidden_size,
                           num_heads, head_size, mask_nodes.mask_filter_value, logger)) {
     return false;
   }
@@ -615,7 +625,12 @@ After Fusion:
     |   |
      Add
 */
-bool AttentionFusion::FuseSubGraph(Node& layer_norm, const Node& add_after_layer_norm, Graph& graph, int64_t hidden_size, std::map<std::string, NodeArg*>& mask_int32_map, const logging::Logger& logger) {
+bool AttentionFusion::FuseSubGraph(Node& layer_norm,
+                                   const Node& add_after_layer_norm,
+                                   Graph& graph,
+                                   int64_t hidden_size,
+                                   std::map<std::string, NodeArg*>& mask_int32_map,
+                                   const logging::Logger& logger) {
   std::vector<graph_utils::EdgeEndToMatch> parent_path{
       {0, 0, "Add", {7, 13}, kOnnxDomain},
       {0, 0, "MatMul", {1, 9, 13}, kOnnxDomain},
@@ -657,7 +672,9 @@ bool AttentionFusion::FuseSubGraph(Node& layer_norm, const Node& add_after_layer
   int64_t num_heads = 0;          // will be updated in CheckNodesInPathV
   int64_t head_size = 0;          // will be updated in CheckNodesInPathV
   NodeIndex record_node_idx = 0;  // will be updated in CheckNodesInPathV if it's distilbert model
-  if (!AttentionFusionHelper::CheckNodesInPathV(graph, reshape, transpose, qkv_matmul, v_transpose, v_reshape, num_heads, head_size, hidden_size, record_node_idx, logger)) {
+  if (!AttentionFusionHelper::CheckNodesInPathV(graph, reshape, transpose,
+                                                qkv_matmul, v_transpose, v_reshape, num_heads,
+                                                head_size, hidden_size, record_node_idx, logger)) {
     DEBUG_LOG("CheckNodesInPathV return false");
     return false;
   }
@@ -672,7 +689,8 @@ bool AttentionFusion::FuseSubGraph(Node& layer_norm, const Node& add_after_layer
   }
 
   // store parent path
-  std::vector<std::reference_wrapper<const Node>> parent_path_nodes{reshape, transpose, qkv_matmul, v_transpose, v_reshape, v_add, v_matmul};
+  std::vector<std::reference_wrapper<const Node>> parent_path_nodes{
+      reshape, transpose, qkv_matmul, v_transpose, v_reshape, v_add, v_matmul};
 
   // Find mask nodes: Unsqueeze -> Unsqueeze -> (Cast) -> Sub -> Mul -> Add -> Softmax --> [MatMul]
   // The "Cast" node in parentheses is optional.
@@ -681,10 +699,13 @@ bool AttentionFusion::FuseSubGraph(Node& layer_norm, const Node& add_after_layer
 
   if (AttentionFusionHelper::MatchInputMaskSubgraph(graph, qkv_matmul, mask_nodes, logger, false)) {
     NodeArg* mask_input = graph.GetNode(mask_nodes.unsqueeze_1->Index())->MutableInputDefs()[0];
-    return FuseSubGraphQK(layer_norm, graph, mask_nodes, mask_input, parent_path_nodes, hidden_size, num_heads, head_size, mask_int32_map, logger);
-  } else if (AttentionFusionHelper::MatchInputMaskSubgraph(graph, layer_norm, qkv_matmul, mask_nodes_distilbert, record_node_idx, logger)) {
+    return FuseSubGraphQK(layer_norm, graph, mask_nodes, mask_input,
+                          parent_path_nodes, hidden_size, num_heads, head_size, mask_int32_map, logger);
+  } else if (AttentionFusionHelper::MatchInputMaskSubgraph(graph, layer_norm, qkv_matmul,
+                                                           mask_nodes_distilbert, record_node_idx, logger)) {
     NodeArg* mask_input = graph.GetNode(mask_nodes_distilbert.equal->Index())->MutableInputDefs()[0];
-    return FuseSubGraphQKDistilBert(layer_norm, graph, mask_nodes_distilbert, mask_input, parent_path_nodes, hidden_size, num_heads, head_size, mask_int32_map, logger);
+    return FuseSubGraphQKDistilBert(layer_norm, graph, mask_nodes_distilbert, mask_input,
+                                    parent_path_nodes, hidden_size, num_heads, head_size, mask_int32_map, logger);
   } else {
     DEBUG_LOG("Failed in match input mask subgraph");
     return false;