Add fusion patterns for conformer-transducer model #18461
Conversation
Force-pushed from c58247f to 98dea4d, then to 84600a0.
Please add a test case for the attention fusion. Otherwise, it will not be able to prevent regressions in the future.
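One way to write such a regression test is to assert on the op-type counts of the graph before and after optimization. A minimal stand-in sketch follows; the op lists here are hypothetical, and a real test would read the node types from the optimizer's output model with `onnx` instead:

```python
from collections import Counter

def op_counts(op_types):
    """Count occurrences of each op type in a graph's node list."""
    return Counter(op_types)

# Hypothetical op-type lists: before fusion, the raw attention pattern
# (MatMul/Softmax/...) is present; after fusion, a single Attention op
# should replace it. A real test would extract these from the ONNX graph.
unfused = ["MatMul", "Transpose", "Softmax", "MatMul", "Reshape"]
fused = ["Attention", "Reshape"]

before, after = op_counts(unfused), op_counts(fused)
assert after["Attention"] == 1   # fusion produced one Attention node
assert after["Softmax"] == 0     # raw attention pattern is gone
assert before["Softmax"] == 1    # unfused graph still contained it
```

Checking both directions (fused op present, raw pattern absent) keeps the test from passing if fusion silently stops matching.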
Force-pushed from 8fc8a31 to fd8d044, then to 7a730c2; later from 3fe5d00 to dda939e.
Review comments on onnxruntime/python/tools/transformers/fusion_conformer_attention.py (outdated, resolved).
```python
from typing import List

import numpy as np
import onnx
```
Check notice — Code scanning / CodeQL: Module is imported with 'import' and 'import from' (Note, test). Module 'onnxruntime.test.onnx' is imported with both 'import' and 'import from'.
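The notice refers to mixing the two import styles for one module. A small stdlib illustration of the flagged pattern and a consistent alternative:

```python
# Pattern CodeQL flags: the same module pulled in two ways.
import json
from json import dumps  # 'json' imported with both 'import' and 'import from'

# Consistent alternative: import once and reference attributes through it.
encoded = json.dumps({"fused": True})
assert encoded == '{"fused": true}'
```

Picking one style per module keeps name origins unambiguous and silences the notice.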
```python
class ConformerOnnxModel(BertOnnxModel):
    def __init__(self, model, num_heads, hidden_size):
        super().__init__(model, num_heads, hidden_size)
        self.attention_mask = AttentionMask(self)
```
Check warning — Code scanning / CodeQL: Overwriting attribute in super-class or sub-class (Warning): `BertOnnxModel`.
```python
    def __init__(self, model, num_heads, hidden_size):
        super().__init__(model, num_heads, hidden_size)
        self.attention_mask = AttentionMask(self)
        self.attention_fusion = FusionConformerAttention(self, self.hidden_size, self.num_heads, self.attention_mask)
```
Check warning — Code scanning / CodeQL: Overwriting attribute in super-class or sub-class (Warning): `BertOnnxModel`.
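This warning fires when a subclass `__init__` reassigns an attribute that the super-class constructor already set, making the base assignment dead work. A minimal illustration with hypothetical class names:

```python
class Base:
    def __init__(self):
        self.attention_mask = "base-mask"   # attribute set by the super-class

class Sub(Base):
    def __init__(self):
        super().__init__()                  # sets self.attention_mask ...
        self.attention_mask = "sub-mask"    # ... then overwrites it; CodeQL flags this

s = Sub()
assert s.attention_mask == "sub-mask"       # base assignment had no lasting effect
```

It is usually benign, but it can be avoided by letting the base class accept the value as a parameter or by documenting the intentional override.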
Force-pushed from d2ea8e5 to 93e9bda, then to df11324.
### Description

Add conformer-transducer model type to the optimizer. This PR adds pattern matches for the attention subgraph shown below.

Unfused attention:

![ct_unfused](https://github.com/microsoft/onnxruntime/assets/111780983/46c71ed8-67e0-4607-85b1-bcadba5a2956)

Fused attention:

![ct_fused](https://github.com/microsoft/onnxruntime/assets/111780983/fbb91c96-0d4b-4f0b-8674-1ae3b9b9a92e)