Add fusion patterns for conformer-transducer model #18461

Merged 6 commits on Nov 19, 2023
7 changes: 7 additions & 0 deletions cmake/onnxruntime_python.cmake
@@ -443,6 +443,9 @@ if (onnxruntime_BUILD_UNIT_TESTS)
  file(GLOB onnxruntime_python_transformers_testdata_whisper CONFIGURE_DEPENDS
    "${ONNXRUNTIME_ROOT}/test/python/transformers/test_data/models/whisper/*.onnx"
  )
  file(GLOB onnxruntime_python_transformers_testdata_conformer CONFIGURE_DEPENDS
    "${ONNXRUNTIME_ROOT}/test/python/transformers/test_data/models/conformer/*.onnx"
  )
endif()

file(GLOB onnxruntime_python_tools_srcs CONFIGURE_DEPENDS
@@ -556,6 +559,7 @@ add_custom_command(
  COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/transformers/test_data/models
  COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/transformers/test_data/models/whisper
  COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/eager_test
  COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/transformers/test_data/models/conformer
  COMMAND ${CMAKE_COMMAND} -E copy
      ${ONNXRUNTIME_ROOT}/__init__.py
      $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/
@@ -711,6 +715,9 @@ if (onnxruntime_BUILD_UNIT_TESTS)
  COMMAND ${CMAKE_COMMAND} -E copy
      ${onnxruntime_python_transformers_testdata_whisper}
      $<TARGET_FILE_DIR:${build_output_target}>/transformers/test_data/models/whisper/
  COMMAND ${CMAKE_COMMAND} -E copy
      ${onnxruntime_python_transformers_testdata_conformer}
      $<TARGET_FILE_DIR:${build_output_target}>/transformers/test_data/models/conformer/
  )
endif()

8 changes: 5 additions & 3 deletions onnxruntime/python/tools/transformers/fusion_attention.py
@@ -657,7 +657,6 @@ def create_multihead_attention_node(
            return None

        graph_input_names = set([node.name for node in self.model.graph().input])
-        graph_output_names = set([node.name for node in self.model.graph().output])
        mha_node_name = self.model.create_node_name("Attention")

        # Add initial Q/K/V inputs for MHA
@@ -693,12 +692,15 @@ def create_multihead_attention_node(
mha_inputs.append("")

# Add optional inputs for MHA
if past_k and past_v and past_k in graph_input_names and past_v in graph_input_names:

if past_k and past_v:
apsonawane marked this conversation as resolved.
Show resolved Hide resolved
mha_inputs.extend([key_padding_mask, add_qk, past_k, past_v])
elif key_padding_mask or add_qk:
mha_inputs.extend([key_padding_mask, add_qk])

# Add outputs for MHA
mha_outputs = [output]
if present_k and present_v and present_k in graph_output_names and present_v in graph_output_names:
if present_k and present_v:
kunal-vaishnavi marked this conversation as resolved.
Show resolved Hide resolved
mha_outputs.extend([present_k, present_v])

mha_node = helper.make_node(
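Note: the relaxed checks above are what let the conformer fusion (new file below) wire past/present key/value tensors that come from Concat nodes inside the graph rather than from top-level graph inputs and outputs. As a rough sketch of the node the fusion ultimately emits (tensor names and num_heads are illustrative, not taken from this PR):

from onnx import helper

# Hypothetical sketch of the fused com.microsoft MultiHeadAttention node with a KV cache.
mha_node = helper.make_node(
    "MultiHeadAttention",
    inputs=[
        "query",       # Q path: MatMul output (+ optional Add bias)
        "key",         # K path
        "value",       # V path
        "",            # packed bias slot, left empty in this sketch
        "",            # key_padding_mask, unused in this sketch
        "add_qk",      # additive attention bias matched from the Softmax <- Add path
        "past_key",    # no longer required to be a graph input
        "past_value",
    ],
    outputs=["attn_output", "present_key", "present_value"],
    name="Attention_0",
    domain="com.microsoft",
    num_heads=8,
)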
143 changes: 143 additions & 0 deletions onnxruntime/python/tools/transformers/fusion_conformer_attention.py
@@ -0,0 +1,143 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging

from fusion_attention import AttentionMask, FusionAttention
from onnx_model import OnnxModel

logger = logging.getLogger(__name__)


class FusionConformerAttention(FusionAttention):
"""
Fuse Conformer Attention subgraph into one MultiHeadAttention node.
"""

def __init__(
self,
model: OnnxModel,
hidden_size: int,
num_heads: int,
attention_mask: AttentionMask,
):
super().__init__(model, hidden_size, num_heads, attention_mask)

def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
# SkipLayerNormalization has two inputs, and one of them is the root input for attention.
qkv_nodes = self.model.match_parent_path(
normalize_node,
["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
[1, 1, 0, 0, 0],
)
if qkv_nodes is not None:
(
_,
_,
reshape_qkv,
transpose_qkv,
matmul_qkv,
) = qkv_nodes
else:
logger.debug("fuse_conformer_attention: failed to match qkv path")
return

v_nodes = self.model.match_parent_path(
matmul_qkv,
["Concat", "Transpose", "Reshape", "Add", "MatMul"],
[1, 1, 0, 0, 1],
)

add_v = None
if v_nodes is not None:
(concat_v, _, _, add_v, matmul_v) = v_nodes
concat_parent = self.model.get_parent(concat_v, 0, None)
present_v = concat_v.output[0]
past_v = concat_parent.output[0]
else:
logger.debug("fuse_conformer_attention: failed to match v path")
return

qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0])

if qk_nodes is not None:
_, add_qk, matmul_qk = qk_nodes
else:
logger.debug("fuse_conformer_attention: failed to match qk path")
return

q_nodes = self.model.match_parent_path(
matmul_qk,
["Div", "Transpose", "Reshape", "Add", "MatMul"],
[0, 0, 0, 0, 1],
)
if q_nodes is not None:
_, _, reshape_q, add_q, matmul_q = q_nodes
else:
logger.debug("fuse_conformer_attention: failed to match q path")
return

k_nodes = self.model.match_parent_path(
matmul_qk,
["Transpose", "Concat", "Transpose", "Reshape", "Add", "MatMul"],
[1, 0, 1, 0, 0, 1],
)

matmul_k = None
if k_nodes is not None:
_, concat_k, _, _, add_k, matmul_k = k_nodes
concat_parent = self.model.get_parent(concat_k, 0, None)
past_k = concat_parent.output[0]
present_k = concat_k.output[0]
else:
logger.debug("fuse_conformer_attention: failed to match k path")
return

attention_last_node = reshape_qkv
num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q)

if num_heads <= 0 or hidden_size <= 0 or (hidden_size % num_heads) != 0:
logger.debug("fuse_conformer_attention: failed to detect num_heads or hidden_size")
return

new_node = self.create_multihead_attention_node(
matmul_q,
matmul_k,
matmul_v,
add_q,
add_k,
add_v,
num_heads,
hidden_size,
attention_last_node.output[0],
add_qk=add_qk.input[1],
past_k=past_k,
past_v=past_v,
present_k=present_k,
present_v=present_v,
)

if new_node is None:
logger.debug("fuse_conformer_attention: MultiHeadAttention node creation failed")
return

self.nodes_to_add.append(new_node)
self.node_name_to_graph_name[new_node.name] = self.this_graph_name

self.nodes_to_remove.extend([attention_last_node, transpose_qkv, matmul_qkv])
self.nodes_to_remove.extend(qk_nodes)

# When using multihead attention, keep MatMul nodes in original graph
if q_nodes[-1].op_type == "MatMul":
q_nodes.pop()
if k_nodes[-1].op_type == "MatMul":
k_nodes.pop()
if v_nodes[-1].op_type == "MatMul":
v_nodes.pop()

self.nodes_to_remove.extend(k_nodes)
self.nodes_to_remove.extend(v_nodes)

# Use prune graph to remove mask nodes since they are shared by all attention nodes.
self.prune_graph = True
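A note on the past/present wiring above, since it is easy to misread: for both the K and V paths, the Concat node appends the current step's states to the incoming cache. So the Concat's first parent supplies the past tensor, while the Concat's own output is the present tensor. A sketch restating the v-path logic under that reading (the k path is symmetric):

# Illustrative wiring for the v path:
#
#   past_value (input 0) ──────────────────┐
#                                          ├──► Concat ──► present_value ──► MatMul(probs @ V)
#   Transpose(Reshape(Add(MatMul))) ───────┘  (input 1, this step's values)
#
# Hence parent 0 of concat_v yields past_v, and concat_v.output[0] serves as
# both present_v and the value tensor consumed by the attention MatMul.
concat_parent = self.model.get_parent(concat_v, 0, None)
past_v = concat_parent.output[0]
present_v = concat_v.output[0]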
33 changes: 33 additions & 0 deletions onnxruntime/python/tools/transformers/onnx_model_conformer.py
@@ -0,0 +1,33 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
from typing import Optional

from fusion_attention import AttentionMask
from fusion_conformer_attention import FusionConformerAttention
from fusion_options import FusionOptions
from onnx_model_bert import BertOnnxModel

logger = logging.getLogger(__name__)


class ConformerOnnxModel(BertOnnxModel):
    def __init__(self, model, num_heads, hidden_size):
        super().__init__(model, num_heads, hidden_size)
        self.attention_mask = AttentionMask(self)
        self.attention_fusion = FusionConformerAttention(self, self.hidden_size, self.num_heads, self.attention_mask)

    def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bool = False):
        self.attention_fusion.use_multi_head_attention = False if options is None else options.use_multi_head_attention
        self.attention_fusion.disable_multi_head_attention_bias = (
            False if options is None else options.disable_multi_head_attention_bias
        )
        super().optimize(options, add_dynamic_axes)

    def fuse_attention(self):
        self.attention_fusion.apply()

    def preprocess(self):
        self.adjust_reshape_and_expand()

Code scanning / CodeQL raised two warnings here (Overwriting attribute in super-class or sub-class): the attention_mask and attention_fusion assignments in __init__ overwrite attributes previously defined in superclass BertOnnxModel.
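For completeness, a minimal sketch of driving this class directly, assuming a conformer encoder already exported to ONNX with KV-cache inputs; the file names and sizes below are hypothetical:

import onnx
from fusion_options import FusionOptions
from onnx_model_conformer import ConformerOnnxModel

options = FusionOptions("conformer")
options.use_multi_head_attention = True  # the conformer fusion targets the MHA contrib op

model = ConformerOnnxModel(onnx.load("conformer_encoder.onnx"), num_heads=8, hidden_size=512)
model.optimize(options)
model.save_model_to_file("conformer_encoder_opt.onnx")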
2 changes: 2 additions & 0 deletions onnxruntime/python/tools/transformers/optimizer.py
@@ -32,6 +32,7 @@
from onnx_model_bert_keras import BertOnnxModelKeras
from onnx_model_bert_tf import BertOnnxModelTF
from onnx_model_clip import ClipOnnxModel
from onnx_model_conformer import ConformerOnnxModel
from onnx_model_gpt2 import Gpt2OnnxModel
from onnx_model_t5 import T5OnnxModel
from onnx_model_tnlr import TnlrOnnxModel
@@ -56,6 +57,7 @@
"unet": (UnetOnnxModel, "pytorch", 1), # UNet in Stable Diffusion
"vae": (VaeOnnxModel, "pytorch", 1), # UAE in Stable Diffusion
"vit": (BertOnnxModel, "pytorch", 1),
"conformer": (ConformerOnnxModel, "pytorch", 1),
}
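With the registration above, the standard optimizer entry point can route conformer models through these fusions. A hedged usage example (paths and sizes are illustrative, not from this PR):

from fusion_options import FusionOptions
from optimizer import optimize_model  # onnxruntime.transformers.optimizer when pip-installed

options = FusionOptions("conformer")
options.use_multi_head_attention = True

opt_model = optimize_model(
    "conformer_encoder.onnx",  # hypothetical exported model
    model_type="conformer",
    num_heads=8,               # should match the exported model
    hidden_size=512,
    optimization_options=options,
)
opt_model.save_model_to_file("conformer_encoder_opt.onnx")
print(opt_model.get_fused_operator_statistics())  # expect MultiHeadAttention counts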

