Add fusion patterns for conformer-transducer model #18461

Merged · 6 commits · Nov 19, 2023
Changes from 4 commits
6 changes: 3 additions & 3 deletions onnxruntime/python/tools/transformers/fusion_attention.py
@@ -657,7 +657,6 @@ def create_multihead_attention_node(
             return None
 
         graph_input_names = set([node.name for node in self.model.graph().input])
-        graph_output_names = set([node.name for node in self.model.graph().output])
         mha_node_name = self.model.create_node_name("Attention")
 
         # Add initial Q/K/V inputs for MHA
@@ -693,12 +692,13 @@ def create_multihead_attention_node(
             mha_inputs.append("")
 
         # Add optional inputs for MHA
-        if past_k and past_v and past_k in graph_input_names and past_v in graph_input_names:
+
+        if past_k and past_v:
             mha_inputs.extend([key_padding_mask, add_qk, past_k, past_v])
 
         # Add outputs for MHA
         mha_outputs = [output]
-        if present_k and present_v and present_k in graph_output_names and present_v in graph_output_names:
+        if present_k and present_v:
            mha_outputs.extend([present_k, present_v])
 
         mha_node = helper.make_node(
178 changes: 178 additions & 0 deletions onnxruntime/python/tools/transformers/fusion_conformer_attention.py
@@ -0,0 +1,178 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging

from fusion_attention import AttentionMask, FusionAttention
from onnx_model import OnnxModel

logger = logging.getLogger(__name__)


class FusionConformerAttention(FusionAttention):
    """
    Fuse Conformer Attention subgraph into one Attention node.
    """

    def __init__(
        self,
        model: OnnxModel,
        hidden_size: int,
        num_heads: int,
        attention_mask: AttentionMask,
    ):
        super().__init__(model, hidden_size, num_heads, attention_mask)

    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
        # SkipLayerNormalization has two inputs, and one of them is the root input for attention.
        qkv_nodes = self.model.match_parent_path(
            normalize_node,
            ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
            [1, 1, 0, 0, 0],
        )
        if qkv_nodes is not None:
            (
                _,
                _,
                reshape_qkv,
                transpose_qkv,
                matmul_qkv,
            ) = qkv_nodes
        else:
            return

        other_inputs = []
        for input in normalize_node.input:
            if input not in output_name_to_node:
                continue
            if input == qkv_nodes[0].output[0]:
                continue
            other_inputs.append(input)
        if len(other_inputs) != 1:
            return
        root_input = other_inputs[0]

        # Sometimes the input name to the attention MatMul nodes does not match the input name to the end
        # SkipLayerNormalization node (name saved in root_input). We find the true input name to the MatMul
        # nodes by getting the initial SkipLayerNormalization node and checking how many MatMul nodes are
        # children nodes for each of its output names.
        """
                                        root_input
                    +---------------------------------------------------+
                    |                                                   |
                    |                                                   |
        SkipLayerNormalization --> Attention --> MatMul --> SkipLayerNormalization
        """
        skip_layernorm = output_name_to_node[root_input]
        # For some attention blocks, the end SkipLayerNormalization node may point to an Add node whose
        # child is the LayerNormalization node.
        if skip_layernorm.op_type == "Add":
            skip_layernorm = self.model.get_children(skip_layernorm)[0]
        for output in skip_layernorm.output:
            if not output:
                continue
            children = input_name_to_nodes[output]
            children_types = [child.op_type for child in children]
            if children_types.count("MatMul") >= 1:
                root_input = output
                break

        # graph_input_names = set([node.name for node in self.model.graph().input])
        # graph_output_names = set([node.name for node in self.model.graph().output])

        v_nodes = self.model.match_parent_path(
            matmul_qkv,
            ["Concat", "Transpose", "Reshape", "Add", "MatMul"],
            [1, 1, 0, 0, 1],
        )

        add_v = None
        if v_nodes is not None:
            (concat_v, _, _, add_v, matmul_v) = v_nodes
            concat_parent = self.model.get_parent(concat_v, 0, None)
            present_v = concat_v.output[0]
            past_v = concat_parent.output[0]
        else:
            logger.debug("fuse_attention: failed to match v path")
            return

        qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0])

        if qk_nodes is not None:
            _, add_qk, matmul_qk = qk_nodes
        else:
            return

        q_nodes = self.model.match_parent_path(
            matmul_qk,
            ["Div", "Transpose", "Reshape", "Add", "MatMul"],
            [0, 0, 0, 0, 1],
        )
        if q_nodes is not None:
            _, _, reshape_q, add_q, matmul_q = q_nodes
        else:
            return

        k_nodes = self.model.match_parent_path(
            matmul_qk,
            ["Transpose", "Concat", "Transpose", "Reshape", "Add", "MatMul"],
            [1, 0, 1, 0, 0, 1],
        )

        matmul_k = None
        if k_nodes is not None:
            _, concat_k, _, _, add_k, matmul_k = k_nodes
            concat_parent = self.model.get_parent(concat_k, 0, None)
            past_k = concat_parent.output[0]
            present_k = concat_k.output[0]
        else:
            return

        attention_last_node = reshape_qkv
        num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q)

        if num_heads <= 0 or hidden_size <= 0 or (hidden_size % num_heads) != 0:
            logger.debug("fuse_attention: failed to detect num_heads or hidden_size")
            return

        new_node = None
        new_node = self.create_multihead_attention_node(
            matmul_q,
            matmul_k,
            matmul_v,
            add_q,
            add_k,
            add_v,
            num_heads,
            hidden_size,
            attention_last_node.output[0],
            add_qk=add_qk.input[1],
            past_k=past_k,
            past_v=past_v,
            present_k=present_k,
            present_v=present_v,
        )

        if new_node is None:
            return

        self.nodes_to_add.append(new_node)
        self.node_name_to_graph_name[new_node.name] = self.this_graph_name

        self.nodes_to_remove.extend([attention_last_node, transpose_qkv, matmul_qkv])
        self.nodes_to_remove.extend(qk_nodes)

        # When using multihead attention, keep MatMul nodes in original graph
        if q_nodes[-1].op_type == "MatMul":
            q_nodes.pop()
        if k_nodes[-1].op_type == "MatMul":
            k_nodes.pop()
        if v_nodes[-1].op_type == "MatMul":
            v_nodes.pop()

        self.nodes_to_remove.extend(k_nodes)
        self.nodes_to_remove.extend(v_nodes)

        # Use prune graph to remove mask nodes since they are shared by all attention nodes.
        self.prune_graph = True
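For orientation, the subgraph matched above is collapsed into a single com.microsoft MultiHeadAttention node by create_multihead_attention_node. A rough, illustrative sketch of what that fused node looks like is below; all tensor names and the num_heads value are placeholders, not values taken from this PR.

# Illustrative sketch only: approximate shape of the fused node that replaces the
# matched conformer attention subgraph. Names and num_heads are placeholders.
from onnx import helper

fused_mha = helper.make_node(
    "MultiHeadAttention",
    inputs=[
        "q_out",          # query path (Q MatMul, optionally followed by Add bias)
        "k_out",          # key path
        "v_out",          # value path
        "",               # bias slot (left empty when bias is not packed here)
        "",               # key_padding_mask (not supplied by this fusion call)
        "add_qk_input",   # additive attention bias matched from the Softmax -> Add path
        "past_key",       # past K from the KV cache (parent of the Concat node)
        "past_value",     # past V from the KV cache (parent of the Concat node)
    ],
    outputs=["attn_out", "present_key", "present_value"],
    name="Attention_0",
    domain="com.microsoft",
    num_heads=8,          # placeholder; in practice detected from the Reshape node
)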
33 changes: 33 additions & 0 deletions onnxruntime/python/tools/transformers/onnx_model_ct.py
@@ -0,0 +1,33 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
from typing import Optional

from fusion_attention import AttentionMask
from fusion_conformer_attention import FusionConformerAttention
from fusion_options import FusionOptions
from onnx_model_bert import BertOnnxModel

logger = logging.getLogger(__name__)


class ConformerOnnxModel(BertOnnxModel):
    def __init__(self, model, num_heads, hidden_size):
        super().__init__(model, num_heads, hidden_size)
        self.attention_mask = AttentionMask(self)
        self.attention_fusion = FusionConformerAttention(self, self.hidden_size, self.num_heads, self.attention_mask)

    def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bool = False):
        self.attention_fusion.use_multi_head_attention = False if options is None else options.use_multi_head_attention
        self.attention_fusion.disable_multi_head_attention_bias = (
            False if options is None else options.disable_multi_head_attention_bias
        )
        super().optimize(options, add_dynamic_axes)

    def fuse_attention(self):
        self.attention_fusion.apply()

    def preprocess(self):
        self.adjust_reshape_and_expand()
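As a usage sketch, ConformerOnnxModel can be driven directly on an exported conformer-transducer ONNX file. The path, num_heads, and hidden_size below are assumptions, and use_multi_head_attention is enabled because the fusion emits MultiHeadAttention contrib nodes.

# Hypothetical example (placeholder paths and sizes); assumes it is run from
# onnxruntime/python/tools/transformers, matching the flat imports used above.
import onnx

from fusion_options import FusionOptions
from onnx_model_ct import ConformerOnnxModel

model = onnx.load("conformer_decoder.onnx")  # placeholder path
conformer_model = ConformerOnnxModel(model, num_heads=8, hidden_size=512)

options = FusionOptions("ct")
options.use_multi_head_attention = True  # fuse into MultiHeadAttention contrib nodes
conformer_model.optimize(options)

conformer_model.save_model_to_file("conformer_decoder_opt.onnx")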
2 changes: 2 additions & 0 deletions onnxruntime/python/tools/transformers/optimizer.py
@@ -32,6 +32,7 @@
 from onnx_model_bert_keras import BertOnnxModelKeras
 from onnx_model_bert_tf import BertOnnxModelTF
 from onnx_model_clip import ClipOnnxModel
+from onnx_model_ct import ConformerOnnxModel
 from onnx_model_gpt2 import Gpt2OnnxModel
 from onnx_model_t5 import T5OnnxModel
 from onnx_model_tnlr import TnlrOnnxModel
@@ -56,6 +57,7 @@
     "unet": (UnetOnnxModel, "pytorch", 1),  # UNet in Stable Diffusion
     "vae": (VaeOnnxModel, "pytorch", 1),  # UAE in Stable Diffusion
     "vit": (BertOnnxModel, "pytorch", 1),
+    "ct": (ConformerOnnxModel, "pytorch", 1),
 }
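With the "ct" entry registered above, the same flow is reachable through the standard optimize_model entry point. A minimal sketch, assuming an exported model at a placeholder path and placeholder head/hidden sizes:

# Minimal sketch: route a conformer-transducer ONNX model through the new "ct"
# model type. Input path, num_heads, and hidden_size are placeholders.
from onnxruntime.transformers.fusion_options import FusionOptions
from onnxruntime.transformers.optimizer import optimize_model

options = FusionOptions("ct")
options.use_multi_head_attention = True  # assumption: enable the MultiHeadAttention-based fusion

opt_model = optimize_model(
    "conformer_transducer.onnx",  # placeholder path
    model_type="ct",
    num_heads=8,                  # placeholder
    hidden_size=512,              # placeholder
    optimization_options=options,
)
opt_model.save_model_to_file("conformer_transducer_opt.onnx")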