From 699a64cf6c56e0e4f45697567a62c90a0adee1b1 Mon Sep 17 00:00:00 2001
From: Tianlei Wu
Date: Sun, 15 Dec 2024 01:16:08 +0000
Subject: [PATCH] force fuse layernorm

---
 .../tools/transformers/fusion_layernorm.py | 126 ++++++++++--------
 .../tools/transformers/onnx_model_mmdit.py |   7 +-
 2 files changed, 74 insertions(+), 59 deletions(-)

diff --git a/onnxruntime/python/tools/transformers/fusion_layernorm.py b/onnxruntime/python/tools/transformers/fusion_layernorm.py
index d1e30351564a9..277bd0799cf16 100644
--- a/onnxruntime/python/tools/transformers/fusion_layernorm.py
+++ b/onnxruntime/python/tools/transformers/fusion_layernorm.py
@@ -13,9 +13,10 @@


 class FusionLayerNormalization(Fusion):
-    def __init__(self, model: OnnxModel, check_constant_and_dimension: bool = True):
+    def __init__(self, model: OnnxModel, check_constant_and_dimension: bool = True, force: bool = False):
         super().__init__(model, "LayerNormalization", "ReduceMean")
         self.check_constant_and_dimension = check_constant_and_dimension
+        self.force = force

     def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
         """
@@ -97,63 +98,74 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
         if div_node.output[0] not in input_name_to_nodes:
             return

-        temp_node = input_name_to_nodes[div_node.output[0]][0]
-        if temp_node.op_type == "Cast":
-            # Div --> Cast --> Mul
-            subgraph_nodes.append(temp_node)  # add Cast node to list of subgraph nodes
-            if temp_node.output[0] not in input_name_to_nodes:
-                return
-            mul_node = input_name_to_nodes[temp_node.output[0]][0]
-        else:
-            # Div --> Mul
-            mul_node = temp_node
-        if mul_node.op_type != "Mul":
-            return
-
-        if mul_node.output[0] not in input_name_to_nodes:
-            return
-        last_add_node = input_name_to_nodes[mul_node.output[0]][0]
-        if last_add_node.op_type != "Add":
-            return
-
-        subgraph_nodes.append(node)
-        subgraph_nodes.extend(children)
-        subgraph_nodes.extend(parent_nodes[:-1])
-
-        subgraph_nodes.extend([last_add_node, mul_node, div_node])
-        if not self.model.is_safe_to_fuse_nodes(
-            subgraph_nodes,
-            last_add_node.output,
-            input_name_to_nodes,
-            output_name_to_node,
-        ):
-            logger.debug("It is not safe to fuse LayerNormalization node. Skip")
-            return
-
-        node_before_weight = div_node if temp_node.op_type != "Cast" else temp_node
-        weight_input = mul_node.input[1 - self.model.input_index(node_before_weight.output[0], mul_node)]
-        if self.check_constant_and_dimension and not self.model.is_constant_with_specified_dimension(
-            weight_input, 1, "layernorm weight"
-        ):
-            return
-
-        bias_input = last_add_node.input[1 - self.model.input_index(mul_node.output[0], last_add_node)]
-        if self.check_constant_and_dimension and not self.model.is_constant_with_specified_dimension(
-            bias_input, 1, "layernorm bias"
-        ):
-            return
-
-        self.nodes_to_remove.extend(subgraph_nodes)
-
-        normalize_node = helper.make_node(
-            "LayerNormalization",
-            inputs=[node.input[0], weight_input, bias_input],
-            outputs=[last_add_node.output[0]],
-            name=self.model.create_node_name("LayerNormalization", name_prefix="LayerNorm"),
-        )
-        normalize_node.attribute.extend([helper.make_attribute("epsilon", float(epsilon))])
-        self.nodes_to_add.append(normalize_node)
-        self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name
+        # In the MMDiT model, Div might have two Mul+Add child paths.
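+        # Each child path is expected to match Div --> (optional Cast) --> Mul(weight) --> Add(bias);
+        # candidates that do not match are skipped with `continue` rather than aborting the fusion.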
+        div_children = input_name_to_nodes[div_node.output[0]]
+        for temp_node in div_children:
+            if temp_node.op_type == "Cast":
+                # Div --> Cast --> Mul
+                subgraph_nodes.append(temp_node)  # add Cast node to list of subgraph nodes
+                if temp_node.output[0] not in input_name_to_nodes:
+                    continue
+                mul_node = input_name_to_nodes[temp_node.output[0]][0]
+            else:
+                # Div --> Mul
+                mul_node = temp_node
+            if mul_node.op_type != "Mul":
+                continue
+
+            if mul_node.output[0] not in input_name_to_nodes:
+                continue
+            last_add_node = input_name_to_nodes[mul_node.output[0]][0]
+            if last_add_node.op_type != "Add":
+                continue
+
+            subgraph_nodes.append(node)
+            subgraph_nodes.extend(children)
+            subgraph_nodes.extend(parent_nodes[:-1])
+
+            subgraph_nodes.extend([last_add_node, mul_node, div_node])
+
+            node_before_weight = div_node if temp_node.op_type != "Cast" else temp_node
+            weight_input = mul_node.input[1 - self.model.input_index(node_before_weight.output[0], mul_node)]
+            if self.check_constant_and_dimension and not self.model.is_constant_with_specified_dimension(
+                weight_input, 1, "layernorm weight"
+            ):
+                continue
+
+            bias_input = last_add_node.input[1 - self.model.input_index(mul_node.output[0], last_add_node)]
+            if self.check_constant_and_dimension and not self.model.is_constant_with_specified_dimension(
+                bias_input, 1, "layernorm bias"
+            ):
+                continue
+
+            layer_norm_output = last_add_node.output[0]
+            if not self.model.is_safe_to_fuse_nodes(
+                subgraph_nodes,
+                last_add_node.output,
+                input_name_to_nodes,
+                output_name_to_node,
+            ):
+                # If it is not safe to fuse, some computation may be duplicated when the fusion is forced.
+                # Whether forced fusion brings a performance gain or loss is unknown in advance,
+                # so users need to measure the performance impact to see whether forcing it helps.
+                if self.force:
+                    self.prune_graph = True
+                else:
+                    logger.debug("It is not safe to fuse LayerNormalization node. Skip")
+                    continue
+            else:
+                self.nodes_to_remove.extend(subgraph_nodes)
+
+            normalize_node = helper.make_node(
+                "LayerNormalization",
+                inputs=[node.input[0], weight_input, bias_input],
+                outputs=[layer_norm_output],
+                name=self.model.create_node_name("LayerNormalization", name_prefix="LayerNorm"),
+            )
+            normalize_node.attribute.extend([helper.make_attribute("epsilon", float(epsilon))])
+            self.nodes_to_add.append(normalize_node)
+            self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name


 class FusionLayerNormalizationNCHW(Fusion):
diff --git a/onnxruntime/python/tools/transformers/onnx_model_mmdit.py b/onnxruntime/python/tools/transformers/onnx_model_mmdit.py
index 7593450f7dd74..9e30130d033e7 100644
--- a/onnxruntime/python/tools/transformers/onnx_model_mmdit.py
+++ b/onnxruntime/python/tools/transformers/onnx_model_mmdit.py
@@ -39,7 +39,9 @@ def fuse_layer_norm(self):
                 "The optimized model requires LayerNormalization with broadcast support. "
                 "Please use onnxruntime-gpu>=1.21 for inference."
             )
-        fusion = FusionLayerNormalization(self, check_constant_and_dimension=not layernorm_support_broadcast)
+        fusion = FusionLayerNormalization(
+            self, check_constant_and_dimension=not layernorm_support_broadcast, force=True
+        )
         fusion.apply()

     def fuse_multi_head_attention(self):
@@ -88,7 +90,8 @@ def _optimize(self, options: Optional[FusionOptions] = None, progress_bar=None):
         # TODO: SkipLayerNormalization does not support broadcast yet.
         # if (options is None) or options.enable_skip_layer_norm:
-        #     self.fuse_skip_layer_norm()
+        #     self.fuse_skip_simplified_layer_norm()
+        #     self.fuse_skip_layer_norm()
         # if (options is None) or options.enable_bias_skip_layer_norm:
         #     # Fuse SkipLayerNormalization and Add Bias before it.
         #     self.fuse_add_bias_skip_layer_norm()
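
Note for reviewers: below is a minimal sketch of how the new force flag could be exercised outside the MMDiT wiring above. The model paths are placeholders, and it assumes the script runs from onnxruntime/python/tools/transformers so the local modules import; it also assumes Fusion.apply() calls OnnxModel.prune_graph() when fuse() has set self.prune_graph, which is how the force=True branch takes effect.

    import onnx

    from fusion_layernorm import FusionLayerNormalization
    from onnx_model import OnnxModel

    # "input.onnx" is a placeholder for a model whose LayerNorm subgraph fails
    # the is_safe_to_fuse_nodes check (e.g., an intermediate output is also
    # consumed by a node outside the matched subgraph).
    model = OnnxModel(onnx.load("input.onnx"))

    # force=True: instead of skipping an unsafe match, fuse() still adds the
    # LayerNormalization node and sets prune_graph, so apply() afterwards
    # removes whatever nodes became dead once the fused node owns the output.
    fusion = FusionLayerNormalization(model, check_constant_and_dimension=True, force=True)
    fusion.apply()

    onnx.save(model.model, "output.onnx")

Because the unsafe subgraph is not removed directly, any of its nodes that still feed other consumers are kept alive by pruning, which is where the duplicated computation mentioned in the new comment can come from.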