diff --git a/onnxruntime/python/tools/transformers/fusion_attention_clip.py b/onnxruntime/python/tools/transformers/fusion_attention_clip.py
index d400e248d6cca..b027957fcc725 100644
--- a/onnxruntime/python/tools/transformers/fusion_attention_clip.py
+++ b/onnxruntime/python/tools/transformers/fusion_attention_clip.py
@@ -97,7 +97,33 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
         else:
             # Deal with the first attention after the embedding layer.
             for i in [0, 1]:
-                node_before_layer_norm = self.model.match_parent(normalize_node, "Add", i)
+                node_before_layer_norm = None
+
+                node_before_layer_norm_1 = self.model.match_parent(normalize_node, "Add", i)
+                node_before_layer_norm_2 = self.model.match_parent(normalize_node, "LayerNormalization", i)
+                if node_before_layer_norm_1 is not None:
+                    # Add ----------------+
+                    #  |                  |
+                    # LayerNorm           |
+                    #  |                  |
+                    # LayerNorm           |
+                    #  |                  |
+                    # Attention subgraph  |
+                    #  |                  |
+                    # SkipLayerNorm ------+
+                    node_before_layer_norm = node_before_layer_norm_1
+                elif node_before_layer_norm_2 is not None:
+                    # Add
+                    #  |
+                    # LayerNorm ----------+
+                    #  |                  |
+                    # LayerNorm           |
+                    #  |                  |
+                    # Attention subgraph  |
+                    #  |                  |
+                    # SkipLayerNorm ------+
+                    node_before_layer_norm = node_before_layer_norm_2
+
                 if node_before_layer_norm is None:
                     continue
                 child = self.model.find_first_child_by_type(
@@ -130,20 +156,32 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
             return
         (_, _, reshape_v, add_v, matmul_v) = v_nodes
 
+        add_mask = None
         add_mask_indices = []
-        qk_nodes = self.model.match_parent_path(
+        qk_nodes = None
+        qk_nodes_1 = self.model.match_parent_path(
             matmul_qkv,
             ["Softmax", "Reshape", "Add", "Reshape", "MatMul"],
             [0, 0, 0, None, 0],
             return_indice=add_mask_indices,
         )
-        if qk_nodes is None:
+        qk_nodes_2 = self.model.match_parent_path(
+            matmul_qkv,
+            ["Softmax", "MatMul"],
+            [0, 0],
+        )
+        if qk_nodes_1 is not None:
+            qk_nodes = qk_nodes_1
+            assert len(add_mask_indices) == 1
+            causal_mask_input_index = 1 - add_mask_indices[0]
+
+            (_softmax_qk, _, add_mask, _, matmul_qk) = qk_nodes
+        elif qk_nodes_2 is not None:
+            qk_nodes = qk_nodes_2
+            (_softmax_qk, matmul_qk) = qk_nodes
+        else:
             logger.debug("fuse_attention: failed to match qk path")
             return
-        assert len(add_mask_indices) == 1
-        causal_mask_input_index = 1 - add_mask_indices[0]
-
-        (_softmax_qk, _, add_mask, _, matmul_qk) = qk_nodes
 
         q_nodes = self.model.match_parent_path(
             matmul_qk, ["Reshape", "Transpose", "Reshape", "Mul", "Add", "MatMul"], [0, 0, 0, 0, None, None]
@@ -172,23 +210,24 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
 
         attention_last_node = reshape_qkv
 
-        # Here we do not match the whole subgraph since it is very complex. Instead, we just check whether a key path
-        # of computing causal mask.
-        causal_mask_nodes = self.model.match_parent_path(
-            add_mask,
-            ["Concat", "Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"],
-            [causal_mask_input_index, 0, 0, 0, 0, 0],
-        )
-        if causal_mask_nodes is None:
-            # If the model is exported with batch_size == 1, there is no Concat node
+        if add_mask is not None:
+            # Here we do not match the whole subgraph since it is very complex. Instead, we just check whether a key path
+            # of the causal mask computation exists.
             causal_mask_nodes = self.model.match_parent_path(
                 add_mask,
-                ["Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"],
-                [causal_mask_input_index, 0, 0, 0, 0],
+                ["Concat", "Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"],
+                [causal_mask_input_index, 0, 0, 0, 0, 0],
             )
             if causal_mask_nodes is None:
-                logger.debug("fuse_attention: failed to match causal mask subgraph")
-                return
+                # If the model is exported with batch_size == 1, there is no Concat node
+                causal_mask_nodes = self.model.match_parent_path(
+                    add_mask,
+                    ["Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"],
+                    [causal_mask_input_index, 0, 0, 0, 0],
+                )
+                if causal_mask_nodes is None:
+                    logger.debug("fuse_attention: failed to match causal mask subgraph")
+                    return
 
         new_node = self.create_attention_node(
             mask_index=None,
@@ -204,7 +243,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
             output=attention_last_node.output[0],
             add_qk_str=None,
             scale=None,
-            causal=True,
+            causal=(add_mask is not None),
         )
         if new_node is None:
             return
diff --git a/onnxruntime/python/tools/transformers/fusion_layernorm.py b/onnxruntime/python/tools/transformers/fusion_layernorm.py
index 68d26fc46fa23..678d8c42bad67 100644
--- a/onnxruntime/python/tools/transformers/fusion_layernorm.py
+++ b/onnxruntime/python/tools/transformers/fusion_layernorm.py
@@ -38,6 +38,7 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
          |                      |
          +----------------------+
         """
+        subgraph_nodes = []
         children = self.model.get_children(node, input_name_to_nodes)
         if len(children) == 0 or len(children) > 2:
             return
@@ -53,9 +54,16 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
 
         div_node = None
         for child in children:
-            div_node = self.model.find_first_child_by_type(child, "Div", input_name_to_nodes, recursive=False)
-            if div_node is not None:
-                break
+            # Check if Sub --> Div exists
+            div_node_1 = self.model.find_first_child_by_type(child, "Div", input_name_to_nodes, recursive=False)
+
+            # Check if Sub --> Cast --> Div exists
+            div_node_2 = self.model.match_child_path(child, ["Cast", "Div"], exclude=[])
+
+            if div_node_1 is not None:
+                div_node = div_node_1
+            elif div_node_2 is not None:
+                div_node = div_node_2[-1]
 
         if div_node is None:
             return
@@ -63,10 +71,7 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
             div_node,
             [
                 (["Sqrt", "Add", "ReduceMean", "Pow", "Sub"], [1, 0, 0, 0, 0]),
-                (
-                    ["Sqrt", "Add", "ReduceMean", "Pow", "Cast", "Sub"],
-                    [1, 0, 0, 0, 0, 0],
-                ),
+                (["Sqrt", "Add", "ReduceMean", "Pow", "Cast", "Sub"], [1, 0, 0, 0, 0, 0]),
             ],
             output_name_to_node,
         )
@@ -87,7 +92,14 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
         if self.model.find_constant_input(pow_node, 2.0) != 1:
             return
 
-        mul_node = input_name_to_nodes[div_node.output[0]][0]
+        temp_node = input_name_to_nodes[div_node.output[0]][0]
+        if temp_node.op_type == "Cast":
+            # Div --> Cast --> Mul
+            subgraph_nodes.append(temp_node)  # add Cast node to list of subgraph nodes
+            mul_node = input_name_to_nodes[temp_node.output[0]][0]
+        else:
+            # Div --> Mul
+            mul_node = temp_node
         if mul_node.op_type != "Mul":
             return
 
@@ -95,7 +107,7 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
         if last_add_node.op_type != "Add":
             return
 
-        subgraph_nodes = [node]
+        subgraph_nodes.append(node)
         subgraph_nodes.extend(children)
         subgraph_nodes.extend(parent_nodes[:-1])
 
@@ -109,7 +121,8 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
             logger.debug("It is not safe to fuse LayerNormalization node. Skip")
             return
 
-        weight_input = mul_node.input[1 - self.model.input_index(div_node.output[0], mul_node)]
+        node_before_weight = div_node if temp_node.op_type != "Cast" else temp_node
+        weight_input = mul_node.input[1 - self.model.input_index(node_before_weight.output[0], mul_node)]
         if not self.model.is_constant_with_specified_dimension(weight_input, 1, "layernorm weight"):
             return
 
diff --git a/onnxruntime/python/tools/transformers/fusion_quickgelu.py b/onnxruntime/python/tools/transformers/fusion_quickgelu.py
new file mode 100644
index 0000000000000..87154a1c421a8
--- /dev/null
+++ b/onnxruntime/python/tools/transformers/fusion_quickgelu.py
@@ -0,0 +1,79 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+
+import logging
+
+from fusion_base import Fusion
+from onnx import helper
+from onnx_model import OnnxModel
+
+logger = logging.getLogger(__name__)
+
+
+class FusionQuickGelu(Fusion):
+    def __init__(self, model: OnnxModel):
+        super().__init__(model, "QuickGelu", ["Mul"])
+
+    def fuse(self, node, input_name_to_nodes, output_name_to_node):
+        # Fuse the following subgraph into `QuickGelu`
+        #
+        #          root_input
+        #         /          \
+        #        |       Mul ----+
+        #        |  (B = ~1.702) |
+        #         \      |       |
+        #          \  Sigmoid    |---- `QuickGelu`
+        #           \    /       |
+        #            \  /        |
+        #             Mul ----+
+        #              |
+        #         root_output
+
+        if node.op_type != "Mul":
+            logger.debug("fuse_quickgelu: failed to match second Mul node")
+            return
+
+        second_mul_node = node
+        root_input = second_mul_node.input[0]
+
+        sigmoid_node = self.model.match_parent_path(second_mul_node, ["Sigmoid"], [1])
+        if sigmoid_node is None:
+            logger.debug("fuse_quickgelu: failed to match Sigmoid node")
+            return
+        sigmoid_node = sigmoid_node[0]
+
+        first_mul_node = self.model.match_parent_path(sigmoid_node, ["Mul"], [0])
+        if first_mul_node is None:
+            logger.debug("fuse_quickgelu: failed to match first Mul node")
+            return
+        first_mul_node = first_mul_node[0]
+
+        # Guard against a non-constant scale input, where get_constant_value returns None.
+        approximation_constant = self.model.get_constant_value(first_mul_node.input[1])
+        if approximation_constant is None:
+            logger.debug("fuse_quickgelu: failed to find constant input for first Mul node")
+            return
+        approximation_value = approximation_constant.item()
+        if abs(approximation_value - 1.7021484375) >= 1e-3:
+            logger.debug("fuse_quickgelu: failed to match approximation value")
+            return
+
+        if first_mul_node.input[0] != root_input:
+            logger.debug("fuse_quickgelu: failed to match root input with first Mul node's input")
+            return
+
+        new_node = helper.make_node(
+            "QuickGelu",
+            inputs=[root_input],
+            outputs=[second_mul_node.output[0]],
+            name=self.model.create_node_name("QuickGelu"),
+        )
+        new_node.domain = "com.microsoft"
+        new_node.attribute.extend([helper.make_attribute("alpha", approximation_value)])
+
+        self.nodes_to_remove.extend([first_mul_node, sigmoid_node, second_mul_node])
+        self.nodes_to_add.append(new_node)
+        self.node_name_to_graph_name[new_node.name] = self.this_graph_name
+        self.increase_counter("QuickGelu")
diff --git a/onnxruntime/python/tools/transformers/onnx_model_bert.py b/onnxruntime/python/tools/transformers/onnx_model_bert.py
index 431e64509e3cc..ad51c1cce0ec4 100644
--- a/onnxruntime/python/tools/transformers/onnx_model_bert.py
+++ b/onnxruntime/python/tools/transformers/onnx_model_bert.py
@@ -21,6 +21,7 @@
 from fusion_qordered_gelu import FusionQOrderedGelu
 from fusion_qordered_layernorm import FusionQOrderedLayerNormalization
 from fusion_qordered_matmul import FusionQOrderedMatMul
+from fusion_quickgelu import FusionQuickGelu
 from fusion_reshape import FusionReshape
 from fusion_rotary_attention import FusionRotaryEmbeddings
 from fusion_shape import FusionShape
@@ -65,6 +66,8 @@ def fuse_gelu(self):
         fusion.apply()
         fusion = FusionFastGelu(self)
         fusion.apply()
+        fusion = FusionQuickGelu(self)
+        fusion.apply()
         # Only relevant in models with Q-DQ nodes
         fusion = FusionQOrderedGelu(self)
         fusion.apply()
diff --git a/onnxruntime/python/tools/transformers/onnx_model_clip.py b/onnxruntime/python/tools/transformers/onnx_model_clip.py
index 9b4ca03a47a5b..32bddc3ca16a0 100644
--- a/onnxruntime/python/tools/transformers/onnx_model_clip.py
+++ b/onnxruntime/python/tools/transformers/onnx_model_clip.py
@@ -25,6 +25,7 @@ def get_fused_operator_statistics(self):
         ops = [
             "Attention",
             "LayerNormalization",
+            "QuickGelu",
             "SkipLayerNormalization",
         ]
         for op in ops:
diff --git a/onnxruntime/test/python/transformers/test_gelu_fusions.py b/onnxruntime/test/python/transformers/test_gelu_fusions.py
index 77a6491d4bd3c..f6c6348ae8c17 100644
--- a/onnxruntime/test/python/transformers/test_gelu_fusions.py
+++ b/onnxruntime/test/python/transformers/test_gelu_fusions.py
@@ -21,6 +21,11 @@ def forward(self, x):
         return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
 
 
+class HuggingfaceQuickGelu(torch.nn.Module):
+    def forward(self, x):
+        return x * torch.sigmoid(1.702 * x)
+
+
 class MegatronGelu(torch.nn.Module):
     def forward(self, x):
         # The original implementation using ones_like, which might cause problem for input with dynamic axes in onnx.
@@ -36,6 +41,7 @@ def forward(self, x):
 test_cases = [
     ("huggingface", "Gelu", HuggingfaceGelu),
     ("huggingface", "FastGelu", HuggingfaceFastGelu),
+    ("huggingface", "QuickGelu", HuggingfaceQuickGelu),
     ("megatron", "Gelu", MegatronGelu),
     ("megatron", "FastGelu", MegatronFastGelu),
 ]
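
Reviewer note (not part of the patch): the new fusion encodes the contract `QuickGelu(x) = x * sigmoid(alpha * x)` with `alpha ~= 1.702`, which is an approximation of exact GELU rather than an exact rewrite. A minimal sketch of that relationship and of the 1e-3 tolerance used in `fusion_quickgelu.py` (the helper names below are illustrative, not part of the patch):

```python
import torch


def quick_gelu(x: torch.Tensor, alpha: float = 1.702) -> torch.Tensor:
    # Same computation as the Mul -> Sigmoid -> Mul subgraph matched by FusionQuickGelu.
    return x * torch.sigmoid(alpha * x)


def exact_gelu(x: torch.Tensor) -> torch.Tensor:
    # Exact GELU via the Gauss error function.
    return 0.5 * x * (1.0 + torch.erf(x / (2.0**0.5)))


x = torch.linspace(-6.0, 6.0, steps=1001)
max_diff = (quick_gelu(x) - exact_gelu(x)).abs().max().item()
print(f"max |QuickGelu - GELU| on [-6, 6]: {max_diff:.4f}")  # small but nonzero (~1e-2)

# Both the 1.702 used by the HuggingfaceQuickGelu test module and the fp16-rounded
# 1.7021484375 checked in fusion_quickgelu.py pass the fusion's 1e-3 tolerance.
assert abs(1.702 - 1.7021484375) < 1e-3
```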
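
A hypothetical end-to-end check of the new pass is sketched below. The file name and input shape are placeholders, and `model_type="clip"` assumes the `clip` entry that `onnx_model_clip.py` registers with `optimizer.py`; `model_type="bert"` would exercise the same `fuse_gelu()` path via `onnx_model_bert.py`:

```python
import torch
from onnxruntime.transformers.optimizer import optimize_model


class ToyQuickGelu(torch.nn.Module):
    # Same forward as the HuggingfaceQuickGelu test module above.
    def forward(self, x):
        return x * torch.sigmoid(1.702 * x)


# Export the Mul -> Sigmoid -> Mul subgraph, then run the graph optimizer over it.
torch.onnx.export(ToyQuickGelu(), torch.randn(2, 8), "quickgelu.onnx", input_names=["x"], output_names=["y"])
optimized = optimize_model("quickgelu.onnx", model_type="clip")

# After fusion, the subgraph should be a single com.microsoft QuickGelu node.
quickgelu_nodes = [n for n in optimized.model.graph.node if n.op_type == "QuickGelu"]
print(f"QuickGelu nodes after fusion: {len(quickgelu_nodes)}")  # expect 1
```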