force fuse layernorm
tianleiwu committed Dec 15, 2024
1 parent 2f5b9b9 commit 699a64c
Showing 2 changed files with 74 additions and 59 deletions.
126 changes: 69 additions & 57 deletions onnxruntime/python/tools/transformers/fusion_layernorm.py
@@ -13,9 +13,10 @@


class FusionLayerNormalization(Fusion):
def __init__(self, model: OnnxModel, check_constant_and_dimension: bool = True):
def __init__(self, model: OnnxModel, check_constant_and_dimension: bool = True, force: bool = False):
super().__init__(model, "LayerNormalization", "ReduceMean")
self.check_constant_and_dimension = check_constant_and_dimension
self.force = force

def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
"""
@@ -97,63 +98,74 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):

if div_node.output[0] not in input_name_to_nodes:
return
temp_node = input_name_to_nodes[div_node.output[0]][0]
if temp_node.op_type == "Cast":
# Div --> Cast --> Mul
subgraph_nodes.append(temp_node) # add Cast node to list of subgraph nodes
if temp_node.output[0] not in input_name_to_nodes:
return
mul_node = input_name_to_nodes[temp_node.output[0]][0]
else:
# Div --> Mul
mul_node = temp_node
if mul_node.op_type != "Mul":
return

if mul_node.output[0] not in input_name_to_nodes:
return
last_add_node = input_name_to_nodes[mul_node.output[0]][0]
if last_add_node.op_type != "Add":
return

subgraph_nodes.append(node)
subgraph_nodes.extend(children)
subgraph_nodes.extend(parent_nodes[:-1])

subgraph_nodes.extend([last_add_node, mul_node, div_node])
if not self.model.is_safe_to_fuse_nodes(
subgraph_nodes,
last_add_node.output,
input_name_to_nodes,
output_name_to_node,
):
logger.debug("It is not safe to fuse LayerNormalization node. Skip")
return

node_before_weight = div_node if temp_node.op_type != "Cast" else temp_node
weight_input = mul_node.input[1 - self.model.input_index(node_before_weight.output[0], mul_node)]
if self.check_constant_and_dimension and not self.model.is_constant_with_specified_dimension(
weight_input, 1, "layernorm weight"
):
return

bias_input = last_add_node.input[1 - self.model.input_index(mul_node.output[0], last_add_node)]
if self.check_constant_and_dimension and not self.model.is_constant_with_specified_dimension(
bias_input, 1, "layernorm bias"
):
return

self.nodes_to_remove.extend(subgraph_nodes)

normalize_node = helper.make_node(
"LayerNormalization",
inputs=[node.input[0], weight_input, bias_input],
outputs=[last_add_node.output[0]],
name=self.model.create_node_name("LayerNormalization", name_prefix="LayerNorm"),
)
normalize_node.attribute.extend([helper.make_attribute("epsilon", float(epsilon))])
self.nodes_to_add.append(normalize_node)
self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name
            # In the MMDiT model, Div might have two Mul+Add children paths.
div_children = input_name_to_nodes[div_node.output[0]]
for temp_node in div_children:
if temp_node.op_type == "Cast":
# Div --> Cast --> Mul
subgraph_nodes.append(temp_node) # add Cast node to list of subgraph nodes
if temp_node.output[0] not in input_name_to_nodes:
continue
mul_node = input_name_to_nodes[temp_node.output[0]][0]
else:
# Div --> Mul
mul_node = temp_node
if mul_node.op_type != "Mul":
continue

if mul_node.output[0] not in input_name_to_nodes:
continue
last_add_node = input_name_to_nodes[mul_node.output[0]][0]
if last_add_node.op_type != "Add":
continue

subgraph_nodes.append(node)
subgraph_nodes.extend(children)
subgraph_nodes.extend(parent_nodes[:-1])

subgraph_nodes.extend([last_add_node, mul_node, div_node])

node_before_weight = div_node if temp_node.op_type != "Cast" else temp_node
weight_input = mul_node.input[1 - self.model.input_index(node_before_weight.output[0], mul_node)]
if self.check_constant_and_dimension and not self.model.is_constant_with_specified_dimension(
weight_input, 1, "layernorm weight"
):
continue

bias_input = last_add_node.input[1 - self.model.input_index(mul_node.output[0], last_add_node)]
if self.check_constant_and_dimension and not self.model.is_constant_with_specified_dimension(
bias_input, 1, "layernorm bias"
):
continue

layer_norm_output = last_add_node.output[0]
if not self.model.is_safe_to_fuse_nodes(
subgraph_nodes,
last_add_node.output,
input_name_to_nodes,
output_name_to_node,
):
                # If it is not safe to fuse, some computation may be duplicated when fusion is forced.
                # Whether forced fusion brings a performance gain or loss is unknown.
                # Users need to measure the performance impact to see whether forcing fusion helps.
if self.force:
self.prune_graph = True
else:
logger.debug("It is not safe to fuse LayerNormalization node. Skip")
continue
else:
self.nodes_to_remove.extend(subgraph_nodes)

normalize_node = helper.make_node(
"LayerNormalization",
inputs=[node.input[0], weight_input, bias_input],
outputs=[layer_norm_output],
name=self.model.create_node_name("LayerNormalization", name_prefix="LayerNorm"),
)
normalize_node.attribute.extend([helper.make_attribute("epsilon", float(epsilon))])
self.nodes_to_add.append(normalize_node)
self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name


class FusionLayerNormalizationNCHW(Fusion):
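For orientation, below is a minimal usage sketch (not part of this commit) showing how the new force flag can be exercised directly. FusionLayerNormalization, apply(), prune_graph(), and save_model_to_file() are existing APIs in this package; the file paths and the explicit prune_graph() call are illustrative assumptions.

import onnx
from onnxruntime.transformers.onnx_model import OnnxModel
from onnxruntime.transformers.fusion_layernorm import FusionLayerNormalization

# Load a model whose LayerNorm is still expressed as a ReduceMean/Sub/Div/Mul/Add subgraph.
model = OnnxModel(onnx.load("model.onnx"))  # hypothetical input path

# With force=True, the pattern is fused even when is_safe_to_fuse_nodes() reports that
# intermediate outputs are consumed elsewhere; the fusion then sets prune_graph so the
# now-dead subgraph nodes can be cleaned up later instead of being removed in place.
fusion = FusionLayerNormalization(model, check_constant_and_dimension=True, force=True)
fusion.apply()

# apply() may already prune when prune_graph was set; calling prune_graph() again here
# is a defensive assumption and is harmless on an already-clean graph.
model.prune_graph()
model.save_model_to_file("model_layernorm_fused.onnx")  # hypothetical output path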
7 changes: 5 additions & 2 deletions onnxruntime/python/tools/transformers/onnx_model_mmdit.py
@@ -39,7 +39,9 @@ def fuse_layer_norm(self):
"The optimized model requires LayerNormalization with broadcast support. "
"Please use onnxruntime-gpu>=1.21 for inference."
)
fusion = FusionLayerNormalization(self, check_constant_and_dimension=not layernorm_support_broadcast)
fusion = FusionLayerNormalization(
self, check_constant_and_dimension=not layernorm_support_broadcast, force=True
)
fusion.apply()

def fuse_multi_head_attention(self):
@@ -88,7 +90,8 @@ def _optimize(self, options: Optional[FusionOptions] = None, progress_bar=None):

# TODO: SkipLayerNormalization does not support broadcast yet.
# if (options is None) or options.enable_skip_layer_norm:
# self.fuse_skip_layer_norm()
# self.fuse_skip_simplified_layer_norm()
# self.fuse_skip_layer_norm()
# if (options is None) or options.enable_bias_skip_layer_norm:
# # Fuse SkipLayerNormalization and Add Bias before it.
# self.fuse_add_bias_skip_layer_norm()
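For completeness, a hedged sketch of reaching the MMDiT path through the top-level optimizer. optimize_model() and save_model_to_file() are existing APIs; the "mmdit" model_type key, the opt_level choice, and the file paths are assumptions to verify against MODEL_TYPES in optimizer.py.

from onnxruntime.transformers.optimizer import optimize_model

# opt_level=0 keeps the work in the Python fusion passes (no onnxruntime graph
# optimizations), so fuse_layer_norm() above, with force=True, rewrites the graph.
optimized = optimize_model(
    "mmdit_fp32.onnx",       # hypothetical input path
    model_type="mmdit",      # assumed registry key; check MODEL_TYPES in optimizer.py
    opt_level=0,
)
optimized.save_model_to_file("mmdit_fp32_fused.onnx")  # hypothetical output path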
