Add fusion patterns for conformer-transducer model (#18461)
### Description
Add the conformer-transducer model type to the transformers optimizer. This PR adds pattern
matching for the attention subgraphs shown below:
Unfused attention:

![ct_unfused](https://github.com/microsoft/onnxruntime/assets/111780983/46c71ed8-67e0-4607-85b1-bcadba5a2956)

Fused attention:

![ct_fused](https://github.com/microsoft/onnxruntime/assets/111780983/fbb91c96-0d4b-4f0b-8674-1ae3b9b9a92e)
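
Once a model is exported, these fusions are reachable through the optimizer's new `conformer` model type. A minimal usage sketch (the file names, head count, and hidden size are placeholders, not values from this PR):

```python
from onnxruntime.transformers import optimizer

# Hypothetical file names and shapes; substitute your exported model's values.
opt_model = optimizer.optimize_model(
    "conformer_transducer.onnx",  # assumed input path
    model_type="conformer",       # model type registered by this PR
    num_heads=8,                  # must match the exported attention heads
    hidden_size=512,              # must match the exported hidden size
)
opt_model.save_model_to_file("conformer_transducer_opt.onnx")
```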
apsonawane authored Nov 19, 2023
1 parent 53917a3 commit 97cc40d
Showing 8 changed files with 802 additions and 3 deletions.
7 changes: 7 additions & 0 deletions cmake/onnxruntime_python.cmake
@@ -436,6 +436,9 @@ if (onnxruntime_BUILD_UNIT_TESTS)
   file(GLOB onnxruntime_python_transformers_testdata_whisper CONFIGURE_DEPENDS
     "${ONNXRUNTIME_ROOT}/test/python/transformers/test_data/models/whisper/*.onnx"
   )
+  file(GLOB onnxruntime_python_transformers_testdata_conformer CONFIGURE_DEPENDS
+    "${ONNXRUNTIME_ROOT}/test/python/transformers/test_data/models/conformer/*.onnx"
+  )
 endif()

 file(GLOB onnxruntime_python_tools_srcs CONFIGURE_DEPENDS
@@ -549,6 +552,7 @@ add_custom_command(
   COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/transformers/test_data/models
   COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/transformers/test_data/models/whisper
   COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/eager_test
+  COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/transformers/test_data/models/conformer
   COMMAND ${CMAKE_COMMAND} -E copy
     ${ONNXRUNTIME_ROOT}/__init__.py
     $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/
@@ -701,6 +705,9 @@ if (onnxruntime_BUILD_UNIT_TESTS)
   COMMAND ${CMAKE_COMMAND} -E copy
     ${onnxruntime_python_transformers_testdata_whisper}
     $<TARGET_FILE_DIR:${build_output_target}>/transformers/test_data/models/whisper/
+  COMMAND ${CMAKE_COMMAND} -E copy
+    ${onnxruntime_python_transformers_testdata_conformer}
+    $<TARGET_FILE_DIR:${build_output_target}>/transformers/test_data/models/conformer/
 )
 endif()

8 changes: 5 additions & 3 deletions onnxruntime/python/tools/transformers/fusion_attention.py
@@ -657,7 +657,6 @@ def create_multihead_attention_node(
             return None

         graph_input_names = set([node.name for node in self.model.graph().input])
-        graph_output_names = set([node.name for node in self.model.graph().output])
         mha_node_name = self.model.create_node_name("Attention")

         # Add initial Q/K/V inputs for MHA
@@ -693,12 +692,15 @@ def create_multihead_attention_node(
             mha_inputs.append("")

         # Add optional inputs for MHA
-        if past_k and past_v and past_k in graph_input_names and past_v in graph_input_names:
+
+        if past_k and past_v:
             mha_inputs.extend([key_padding_mask, add_qk, past_k, past_v])
         elif key_padding_mask or add_qk:
             mha_inputs.extend([key_padding_mask, add_qk])

         # Add outputs for MHA
         mha_outputs = [output]
-        if present_k and present_v and present_k in graph_output_names and present_v in graph_output_names:
+        if present_k and present_v:
             mha_outputs.extend([present_k, present_v])

         mha_node = helper.make_node(
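The relaxed checks above no longer require `past_k`/`past_v` to be graph inputs (or `present_k`/`present_v` to be graph outputs), which the conformer pattern needs because its cache tensors flow through in-graph `Concat` nodes. For context, a sketch of the optional-I/O layout the fused node ends up with, assuming the `com.microsoft` MultiHeadAttention contrib op; the tensor names are illustrative only:

```python
from onnx import helper

# Illustrative sketch: empty strings mark omitted optional inputs, per ONNX convention.
mha_node = helper.make_node(
    "MultiHeadAttention",
    inputs=[
        "query", "key", "value",   # Q/K/V after their Add (bias) projections
        "",                        # packed QKV bias slot (unused here)
        "key_padding_mask",        # optional mask
        "add_qk",                  # optional additive bias applied before Softmax
        "past_key", "past_value",  # optional KV cache from prior steps
    ],
    outputs=["attn_out", "present_key", "present_value"],
    name="Attention_0",            # placeholder node name
    domain="com.microsoft",
    num_heads=8,                   # placeholder head count
)
```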
143 changes: 143 additions & 0 deletions onnxruntime/python/tools/transformers/fusion_conformer_attention.py
@@ -0,0 +1,143 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging

from fusion_attention import AttentionMask, FusionAttention
from onnx_model import OnnxModel

logger = logging.getLogger(__name__)


class FusionConformerAttention(FusionAttention):
    """
    Fuse Conformer Attention subgraph into one MultiHeadAttention node.
    """

    def __init__(
        self,
        model: OnnxModel,
        hidden_size: int,
        num_heads: int,
        attention_mask: AttentionMask,
    ):
        super().__init__(model, hidden_size, num_heads, attention_mask)

    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
        # SkipLayerNormalization has two inputs, and one of them is the root input for attention.
        qkv_nodes = self.model.match_parent_path(
            normalize_node,
            ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
            [1, 1, 0, 0, 0],
        )
        if qkv_nodes is not None:
            (
                _,
                _,
                reshape_qkv,
                transpose_qkv,
                matmul_qkv,
            ) = qkv_nodes
        else:
            logger.debug("fuse_conformer_attention: failed to match qkv path")
            return

        v_nodes = self.model.match_parent_path(
            matmul_qkv,
            ["Concat", "Transpose", "Reshape", "Add", "MatMul"],
            [1, 1, 0, 0, 1],
        )

        add_v = None
        if v_nodes is not None:
            (concat_v, _, _, add_v, matmul_v) = v_nodes
            concat_parent = self.model.get_parent(concat_v, 0, None)
            present_v = concat_v.output[0]
            past_v = concat_parent.output[0]
        else:
            logger.debug("fuse_conformer_attention: failed to match v path")
            return

        qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0])

        if qk_nodes is not None:
            _, add_qk, matmul_qk = qk_nodes
        else:
            logger.debug("fuse_conformer_attention: failed to match qk path")
            return

        q_nodes = self.model.match_parent_path(
            matmul_qk,
            ["Div", "Transpose", "Reshape", "Add", "MatMul"],
            [0, 0, 0, 0, 1],
        )
        if q_nodes is not None:
            _, _, reshape_q, add_q, matmul_q = q_nodes
        else:
            logger.debug("fuse_conformer_attention: failed to match q path")
            return

        k_nodes = self.model.match_parent_path(
            matmul_qk,
            ["Transpose", "Concat", "Transpose", "Reshape", "Add", "MatMul"],
            [1, 0, 1, 0, 0, 1],
        )

        matmul_k = None
        if k_nodes is not None:
            _, concat_k, _, _, add_k, matmul_k = k_nodes
            concat_parent = self.model.get_parent(concat_k, 0, None)
            past_k = concat_parent.output[0]
            present_k = concat_k.output[0]
        else:
            logger.debug("fuse_conformer_attention: failed to match k path")
            return

        attention_last_node = reshape_qkv
        num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q)

        if num_heads <= 0 or hidden_size <= 0 or (hidden_size % num_heads) != 0:
            logger.debug("fuse_conformer_attention: failed to detect num_heads or hidden_size")
            return

        new_node = self.create_multihead_attention_node(
            matmul_q,
            matmul_k,
            matmul_v,
            add_q,
            add_k,
            add_v,
            num_heads,
            hidden_size,
            attention_last_node.output[0],
            add_qk=add_qk.input[1],
            past_k=past_k,
            past_v=past_v,
            present_k=present_k,
            present_v=present_v,
        )

        if new_node is None:
            logger.debug("fuse_conformer_attention: MultiHeadAttention node creation failed")
            return

        self.nodes_to_add.append(new_node)
        self.node_name_to_graph_name[new_node.name] = self.this_graph_name

        self.nodes_to_remove.extend([attention_last_node, transpose_qkv, matmul_qkv])
        self.nodes_to_remove.extend(qk_nodes)

        # When using multihead attention, keep MatMul nodes in original graph
        if q_nodes[-1].op_type == "MatMul":
            q_nodes.pop()
        if k_nodes[-1].op_type == "MatMul":
            k_nodes.pop()
        if v_nodes[-1].op_type == "MatMul":
            v_nodes.pop()

        self.nodes_to_remove.extend(k_nodes)
        self.nodes_to_remove.extend(v_nodes)

        # Use prune graph to remove mask nodes since they are shared by all attention nodes.
        self.prune_graph = True
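
Each match above uses `OnnxModel.match_parent_path`, which walks producer nodes upward: element *i* of the op-type list must be the node producing the input at index *i* of the node matched at step *i−1*. A minimal sketch of those semantics as assumed here (not the library's actual implementation):

```python
# Sketch of the assumed match_parent_path semantics. `output_name_to_node`
# maps a tensor name to the node that produces it.
def match_parent_path_sketch(node, op_types, input_indices, output_name_to_node):
    matched = []
    current = node
    for op_type, idx in zip(op_types, input_indices):
        if idx >= len(current.input):
            return None
        parent = output_name_to_node.get(current.input[idx])
        if parent is None or parent.op_type != op_type:
            return None
        matched.append(parent)
        current = parent
    return matched

# E.g. the qk match walks up from matmul_qkv through ["Softmax", "Add", "MatMul"]
# at input indices [0, 0, 0] and returns [softmax, add, matmul], or None as soon
# as any step mismatches -- which is what triggers the logger.debug early returns.
```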
33 changes: 33 additions & 0 deletions onnxruntime/python/tools/transformers/onnx_model_conformer.py
@@ -0,0 +1,33 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
from typing import Optional

from fusion_attention import AttentionMask
from fusion_conformer_attention import FusionConformerAttention
from fusion_options import FusionOptions
from onnx_model_bert import BertOnnxModel

logger = logging.getLogger(__name__)


class ConformerOnnxModel(BertOnnxModel):
    def __init__(self, model, num_heads, hidden_size):
        super().__init__(model, num_heads, hidden_size)
        self.attention_mask = AttentionMask(self)
        self.attention_fusion = FusionConformerAttention(self, self.hidden_size, self.num_heads, self.attention_mask)

    def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bool = False):
        self.attention_fusion.use_multi_head_attention = False if options is None else options.use_multi_head_attention
        self.attention_fusion.disable_multi_head_attention_bias = (
            False if options is None else options.disable_multi_head_attention_bias
        )
        super().optimize(options, add_dynamic_axes)

    def fuse_attention(self):
        self.attention_fusion.apply()

    def preprocess(self):
        self.adjust_reshape_and_expand()
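
`optimize` reads two switches from `FusionOptions` before delegating to `BertOnnxModel.optimize`. A hedged sketch of how a caller would set them, with attribute names taken from the code above and the packaged import path assumed:

```python
from onnxruntime.transformers.fusion_options import FusionOptions

# Sketch: configure the two flags ConformerOnnxModel.optimize consumes, then
# pass the object to optimize_model(..., optimization_options=options).
options = FusionOptions("conformer")
options.use_multi_head_attention = True
options.disable_multi_head_attention_bias = False
```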
2 changes: 2 additions & 0 deletions onnxruntime/python/tools/transformers/optimizer.py
@@ -32,6 +32,7 @@
 from onnx_model_bert_keras import BertOnnxModelKeras
 from onnx_model_bert_tf import BertOnnxModelTF
 from onnx_model_clip import ClipOnnxModel
+from onnx_model_conformer import ConformerOnnxModel
 from onnx_model_gpt2 import Gpt2OnnxModel
 from onnx_model_t5 import T5OnnxModel
 from onnx_model_tnlr import TnlrOnnxModel
@@ -56,6 +57,7 @@
     "unet": (UnetOnnxModel, "pytorch", 1),  # UNet in Stable Diffusion
     "vae": (VaeOnnxModel, "pytorch", 1),  # VAE in Stable Diffusion
     "vit": (BertOnnxModel, "pytorch", 1),
+    "conformer": (ConformerOnnxModel, "pytorch", 1),
 }


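The registry entry is what routes `model_type="conformer"` to the new class. A sketch of the dispatch, assuming the `(model_class, producer, default_opt_level)` tuple layout used by the entries above and that the dict is importable from the packaged module:

```python
from onnxruntime.transformers.optimizer import MODEL_TYPES

# Each MODEL_TYPES value is (model_class, producer, default_opt_level).
model_class, producer, default_opt_level = MODEL_TYPES["conformer"]
print(model_class.__name__, producer, default_opt_level)
# expected: ConformerOnnxModel pytorch 1
```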
