Skip to content

Commit

Permalink
update FastGelu and RMSNorm fusions
Browse files Browse the repository at this point in the history
  • Loading branch information
tianleiwu committed Dec 5, 2024
1 parent 9b2dcc0 commit 7f925ce
Show file tree
Hide file tree
Showing 6 changed files with 201 additions and 111 deletions.
122 changes: 122 additions & 0 deletions onnxruntime/python/tools/transformers/fusion_fastgelu.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ def fuse(self, tanh_node, input_name_to_nodes: Dict, output_name_to_node: Dict):
if self.fuse_3(tanh_node, input_name_to_nodes, output_name_to_node):
return

if self.fuse_4(tanh_node, input_name_to_nodes, output_name_to_node):
return

def fuse_1(self, tanh_node, input_name_to_nodes, output_name_to_node) -> Optional[bool]:
"""
Fuse Gelu with tanh into one node:
Expand Down Expand Up @@ -358,3 +361,122 @@ def fuse_3(self, tanh_node, input_name_to_nodes: Dict, output_name_to_node: Dict
self.nodes_to_add.append(fused_node)
self.node_name_to_graph_name[fused_node.name] = self.this_graph_name
return True

def fuse_4(self, tanh_node, input_name_to_nodes: Dict, output_name_to_node: Dict) -> Optional[bool]:

Check notice

Code scanning / CodeQL

Explicit returns mixed with implicit (fall through) returns Note

Mixing implicit and explicit returns may indicate an error as implicit returns always return None.
"""
This pattern is from stable diffusion 3.5 model.
Fuse Gelu with tanh into one node:
+-----------------+------------------+
| | |
| v v
[root] ==> Mul --> Mul --> Mul -----> Add --> Mul --> Tanh --> Add -----> Mul --> Mul -->
| (A=0.0447) (A=0.7978) (A=1) ^ (A=0.5)
| |
+-------------------------------------------------------------------------+
Note that constant input for Add and Mul could be first or second input.
"""
if tanh_node.output[0] not in input_name_to_nodes:
return

children = input_name_to_nodes[tanh_node.output[0]]
if len(children) != 1 or children[0].op_type != "Add":
return
add_after_tanh = children[0]

if not self.model.has_constant_input(add_after_tanh, 1.0):
return

if add_after_tanh.output[0] not in input_name_to_nodes:
return
children = input_name_to_nodes[add_after_tanh.output[0]]
if len(children) != 1 or children[0].op_type != "Mul":
return
mul_after_tanh = children[0]

if mul_after_tanh.output[0] not in input_name_to_nodes:
return
children = input_name_to_nodes[mul_after_tanh.output[0]]
if len(children) != 1 or children[0].op_type != "Mul":
return
mul_half = children[0]
if not self.model.has_constant_input(mul_half, 0.5):
return

root_input = mul_after_tanh.input[0 if mul_after_tanh.input[1] == add_after_tanh.output[0] else 1]

mul_before_tanh = self.model.match_parent(tanh_node, "Mul", 0, output_name_to_node)
if mul_before_tanh is None:
return

i = self.model.find_constant_input(mul_before_tanh, 0.7978, delta=0.0001)
if i < 0:
return

add_before_tanh = self.model.match_parent(mul_before_tanh, "Add", 0 if i == 1 else 1, output_name_to_node)
if add_before_tanh is None:
return

if add_before_tanh.input[0] == root_input:
another = 1
elif add_before_tanh.input[1] == root_input:
another = 0
else:
return

mul_after_pow = self.model.match_parent(add_before_tanh, "Mul", another, output_name_to_node)
if mul_after_pow is None:
return

i = self.model.find_constant_input(mul_after_pow, 0.0447, delta=0.0001)
if i < 0:
return

mul = self.model.match_parent(mul_after_pow, "Mul", 0 if i == 1 else 1, output_name_to_node)
if mul is None:
return

if mul.input[0] == root_input:
another = 1
elif mul.input[1] == root_input:
another = 0
else:
return

mul2 = self.model.match_parent(mul, "Mul", another, output_name_to_node)
if mul2 is None:
return

if mul2.input[0] != root_input or mul2.input[1] != root_input:
return

subgraph_nodes = [
mul2,
mul,
mul_after_pow,
add_before_tanh,
mul_before_tanh,
tanh_node,
add_after_tanh,
mul_after_tanh,
mul_half,
]

if not self.model.is_safe_to_fuse_nodes(
subgraph_nodes,
[mul_half.output[0]],
input_name_to_nodes,
output_name_to_node,
):
return

self.nodes_to_remove.extend(subgraph_nodes)
fused_node = helper.make_node(
"FastGelu",
inputs=[root_input],
outputs=mul_half.output,
name=self.model.create_node_name("FastGelu"),
)
fused_node.domain = "com.microsoft"
self.nodes_to_add.append(fused_node)
self.node_name_to_graph_name[fused_node.name] = self.this_graph_name
return True
24 changes: 16 additions & 8 deletions onnxruntime/python/tools/transformers/fusion_layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,26 +56,28 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
for child in children:
# Check if Sub --> Div exists
div_node_1 = self.model.find_first_child_by_type(child, "Div", input_name_to_nodes, recursive=False)

# Check if Sub --> Cast --> Div
div_node_2 = self.model.match_child_path(child, ["Cast", "Div"], exclude=[])

if div_node_1 is not None:
div_node = div_node_1
elif div_node_2 is not None:
div_node = div_node_2[-1]
break
else:
# Check if Sub --> Cast --> Div
div_node_2 = self.model.match_child_path(child, ["Cast", "Div"], exclude=[])
if div_node_2 is not None:
div_node = div_node_2[-1]
break

if div_node is None:
return

path_id, parent_nodes, _ = self.model.match_parent_paths(
_path_id, parent_nodes, _ = self.model.match_parent_paths(
div_node,
[
(["Sqrt", "Add", "ReduceMean", "Pow", "Sub"], [1, 0, 0, 0, 0]),
(["Sqrt", "Add", "ReduceMean", "Pow", "Cast", "Sub"], [1, 0, 0, 0, 0, 0]),
],
output_name_to_node,
)
if path_id < 0:
if parent_nodes is None:
return

sub_node = parent_nodes[-1]
Expand All @@ -92,17 +94,23 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
if self.model.find_constant_input(pow_node, 2.0) != 1:
return

if div_node.output[0] not in input_name_to_nodes:
return
temp_node = input_name_to_nodes[div_node.output[0]][0]
if temp_node.op_type == "Cast":
# Div --> Cast --> Mul
subgraph_nodes.append(temp_node) # add Cast node to list of subgraph nodes
if temp_node.output[0] not in input_name_to_nodes:
return
mul_node = input_name_to_nodes[temp_node.output[0]][0]
else:
# Div --> Mul
mul_node = temp_node
if mul_node.op_type != "Mul":
return

if mul_node.output[0] not in input_name_to_nodes:
return
last_add_node = input_name_to_nodes[mul_node.output[0]][0]
if last_add_node.op_type != "Add":
return
Expand Down
152 changes: 54 additions & 98 deletions onnxruntime/python/tools/transformers/fusion_simplified_layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,134 +18,90 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
return

sim_ln_nodes = None
# SimplifiedLayerNorm calculation (notation from https://onnx.ai/onnx/operators/onnx__LayerNormalization.html#summary):
# DD = Pow(D, 2)
# RMSNorm calculation (notation from https://onnx.ai/onnx/operators/onnx__LayerNormalization.html#summary):
# DD = Pow(D, 2) or DD = Mul(D, D)
# Var = ReduceMean(DD)
# VarEps = Add(Var, epsilon)
# StdDev = Sqrt(VarEps)
# InvStdDev = Div(1, StdDev)
# Normalized = Mul(D, InvStdDev)
# NormalizedScaled = Mul(Normalized, Scale)

# SimplifiedLayerNorm
# +-------------------------------------------------------+
# | |
# Add --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Mul
# |
# node
sim_ln_nodes_1 = self.model.match_parent_path(
node,
["Mul", "Div", "Sqrt", "Add", "ReduceMean", "Pow", "Add"],
[1, 1, 1, 0, 0, 0, 0],
)
# SimplifiedLayerNorm
# +-------------------------------------------------------+
# | |
# Gather --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Mul
# |
# node
sim_ln_nodes_2 = self.model.match_parent_path(
node,
["Mul", "Div", "Sqrt", "Add", "ReduceMean", "Pow", "Gather"],
[1, 1, 1, 0, 0, 0, 0],
)

# For LLaMA from Microsoft custom export:
# sim_ln_nodes_3 uses a different start parent index than sim_ln_nodes_1
#
# SimplifiedLayerNorm
# +-------------------------------------------------------+
# | |
# Add --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Mul
# |
# node
sim_ln_nodes_3 = self.model.match_parent_path(
node,
["Mul", "Div", "Sqrt", "Add", "ReduceMean", "Pow", "Add"],
[0, 1, 1, 0, 0, 0, 0],
)

# sim_ln_nodes_4 starts with a graph input instead of an Add node like sim_ln_nodes_3
# (root_input) ---------------------------------------+
# | |
# v v
# Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Mul (node)
# (B=2) (A/B=eps) (A=1) (A/B=scale)
#
# SimplifiedLayerNorm
# +-----------------------------------------------+
# | |
# graph_input --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul
# |
# node
sim_ln_nodes_4 = self.model.match_parent_path(
# (root_input) ---------------------------------------+
# | | |
# v v v
# Mul --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Mul (node)
# (B=2) (A/B=eps) (A=1) (A/B=scale)

return_indice = []
sim_ln_nodes = self.model.match_parent_path(
node,
["Mul", "Div", "Sqrt", "Add", "ReduceMean", "Pow"],
[0, 1, 1, 0, 0, 0],
["Mul", "Div", "Sqrt", "Add", "ReduceMean"],
[None, 1, 1, 0, None],
output_name_to_node=output_name_to_node,
return_indice=return_indice,
)

# For Gemma from Microsoft custom export, which has a Multiply after the Gather:
#
# SimplifiedLayerNorm
# +-------------------------------------------------------+
# | |
# Mul --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Mul
# |
# node
sim_ln_nodes_5 = self.model.match_parent_path(
node,
["Mul", "Div", "Sqrt", "Add", "ReduceMean", "Pow", "Mul"],
[1, 1, 1, 0, 0, 0, 0],
)
if sim_ln_nodes is None:
return

add_node, pow_node = None, None
if sim_ln_nodes_1 is not None:
sim_ln_nodes = sim_ln_nodes_1
add_node = sim_ln_nodes[3]
pow_node = sim_ln_nodes[-2]
elif sim_ln_nodes_2 is not None:
sim_ln_nodes = sim_ln_nodes_2
add_node = sim_ln_nodes[3]
pow_node = sim_ln_nodes[-2]
elif sim_ln_nodes_3 is not None:
sim_ln_nodes = sim_ln_nodes_3
add_node = sim_ln_nodes[3]
pow_node = sim_ln_nodes[-2]
elif sim_ln_nodes_4 is not None:
sim_ln_nodes = sim_ln_nodes_4
add_node = sim_ln_nodes[3]
pow_node = sim_ln_nodes[-1]
# Verify that parent input to Pow node is graph_input
if pow_node.input[0] not in self.model.get_graphs_input_names():
mul_node, div_node, _sqrt_node, add_node, reduce_mean_node = sim_ln_nodes

pow_or_mul_node = self.model.get_parent(reduce_mean_node, 0, output_name_to_node)
if pow_or_mul_node is None or pow_or_mul_node.op_type not in ["Pow", "Mul"]:
return

if pow_or_mul_node.op_type == "Pow":
if self.model.find_constant_input(pow_or_mul_node, 2.0) != 1:
return
elif sim_ln_nodes_5 is not None:
sim_ln_nodes = sim_ln_nodes_5
add_node = sim_ln_nodes[3]
pow_node = sim_ln_nodes[-2]
else:
assert pow_or_mul_node.op_type == "Mul"
if pow_or_mul_node[0] != pow_or_mul_node[1]:
return

root_input = pow_or_mul_node.input[0]
if root_input != mul_node.input[0]:
return

layernorm_weight_index = 1 if sim_ln_nodes in (sim_ln_nodes_3, sim_ln_nodes_4) else 0
starts_with_graph_input = sim_ln_nodes == sim_ln_nodes_4
if not self.model.has_constant_input(div_node, 1.0):
return

if self.model.find_constant_input(pow_node, 2.0) != 1:
_i, epsilon = self.model.get_constant_input(add_node)
if epsilon is None or epsilon <= 0 or epsilon > 1.0e-4:
logger.warning(f"epsilon value is not expected: {epsilon}")
return

root_input = pow_node.input[0]
if root_input != sim_ln_nodes[0].input[0]:
# ReduceMean must have keepdims == 1
keepdims = self.model.get_node_attribute(reduce_mean_node, "keepdims")
if not keepdims:
return

i, add_weight = self.model.get_constant_input(add_node)
if add_weight is None or add_weight <= 0 or add_weight > 1.0e-4:
logger.warning(f"epsilon value is not expected: {add_weight}")
# ReduceMean axes must refer only to the last dimension.
# Axes became an input in opset 18. Before then, axes was an attribute.
axes = self.model.get_node_attribute(reduce_mean_node, "axes")
if (not axes) and len(reduce_mean_node.input) > 1:
axes = self.model.get_constant_value(reduce_mean_node.input[1])
# Make sure only one axis as required by SimplifiedLayerNormalization spec.
if not axes or len(axes) != 1:
return

self.nodes_to_remove.extend(sim_ln_nodes[:-1] if not starts_with_graph_input else sim_ln_nodes)
self.nodes_to_remove.extend(sim_ln_nodes)
self.nodes_to_remove.append(node)

normalize_node = helper.make_node(
"SimplifiedLayerNormalization",
inputs=[root_input, node.input[layernorm_weight_index]],
inputs=[root_input, node.input[1 - return_indice[0]]],
outputs=[node.output[0]],
name=self.model.create_node_name("SimplifiedLayerNormalization", name_prefix="LayerNorm"),
)
normalize_node.attribute.extend([helper.make_attribute("epsilon", float(add_weight))])
normalize_node.attribute.extend([helper.make_attribute("axis", -1)])
normalize_node.attribute.extend([helper.make_attribute("epsilon", float(epsilon))])
normalize_node.attribute.extend([helper.make_attribute("axis", axes[0])])
normalize_node.attribute.extend([helper.make_attribute("stash_type", 1)])
self.nodes_to_add.append(normalize_node)
self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
import coloredlogs
import onnx
from fusion_options import FusionOptions
from onnx_model_mmdit import MmditOnnxModel
from onnx_model_clip import ClipOnnxModel
from onnx_model_mmdit import MmditOnnxModel
from onnx_model_unet import UnetOnnxModel
from onnx_model_vae import VaeOnnxModel
from optimizer import optimize_by_onnxruntime, optimize_model
Expand Down
Loading

0 comments on commit 7f925ce

Please sign in to comment.