Allow SkipLayerNorm and LayerNorm in rotary attention fusion (#18288)
Although SimplifiedLayerNorm is faster than LayerNorm, DML does not yet have an
optimized implementation for the former, so LayerNorm ends up being faster there.
PatriceVignola authored Nov 7, 2023
1 parent fb6737e commit 276918d
Showing 1 changed file with 8 additions and 2 deletions.
onnxruntime/python/tools/transformers/fusion_rotary_attention.py

@@ -29,7 +29,13 @@ def __init__(
             hidden_size,
             num_heads,
             use_multi_head_attention=True,
-            search_op_types=["SimplifiedLayerNormalization", "SkipSimplifiedLayerNormalization", "Add"],
+            search_op_types=[
+                "SimplifiedLayerNormalization",
+                "SkipSimplifiedLayerNormalization",
+                "LayerNormalization",
+                "SkipLayerNormalization",
+                "Add",
+            ],
         )

     def create_mha_node(
@@ -318,7 +324,7 @@ def check_runtime_shape_paths_for_nodes(
         return True

     def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
-        if normalize_node.op_type != "SkipSimplifiedLayerNormalization" and normalize_node.op_type != "Add":
+        if normalize_node.op_type not in {"SkipSimplifiedLayerNormalization", "SkipLayerNormalization", "Add"}:
             return

         # qkv_nodes_1 is for LLaMA-2 Microsoft
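
The first hunk lets the fusion pass visit LayerNormalization and SkipLayerNormalization nodes in addition to the simplified variants; the second hunk is the behavioral gate, which now also accepts SkipLayerNormalization as the root of the pattern. A minimal sketch of that gating pattern follows; the standalone function and the SUPPORTED_ROOT_OPS name are illustrative only, since in the repository this check lives inside FusionRotaryAttention.fuse().

# Illustrative sketch of the op-type gate changed above. SUPPORTED_ROOT_OPS and
# try_fuse_rotary_attention are hypothetical names used for this example only.
SUPPORTED_ROOT_OPS = {
    "SkipSimplifiedLayerNormalization",
    "SkipLayerNormalization",  # newly accepted, so LayerNorm-based graphs (e.g. on DML) still fuse
    "Add",
}

def try_fuse_rotary_attention(normalize_node, input_name_to_nodes, output_name_to_node):
    # Only nodes that can be the root of a rotary attention subgraph are considered;
    # everything else is left untouched.
    if normalize_node.op_type not in SUPPORTED_ROOT_OPS:
        return
    # ...match the rotary/QKV subgraph and replace it with a fused MultiHeadAttention
    # node (omitted here; see create_mha_node in fusion_rotary_attention.py).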
