From 92c048e6f409b85a1505f35f6572656b545ebd33 Mon Sep 17 00:00:00 2001 From: Xiaoyang Chen Date: Fri, 26 Jan 2024 12:24:13 +0000 Subject: [PATCH] update --- .../fusion_conformer_attention.py | 23 ++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/onnxruntime/python/tools/transformers/fusion_conformer_attention.py b/onnxruntime/python/tools/transformers/fusion_conformer_attention.py index e30c35bf156ba..5f098b33ce39c 100644 --- a/onnxruntime/python/tools/transformers/fusion_conformer_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_conformer_attention.py @@ -59,14 +59,28 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): logger.debug("fuse_conformer_attention: failed to match v path") return - qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0]) + qk_nodes = self.model.match_parent_path( + matmul_qkv, + ["Softmax", "Add", "Add", "MatMul"], + [0, 0, 0, 0]) if qk_nodes is not None: - _, add_qk, matmul_qk = qk_nodes + _, add_mask_qk, add_embd_qk, matmul_qk = qk_nodes else: logger.debug("fuse_conformer_attention: failed to match qk path") return + mask_nodes = self.model.match_parent_path( + add_mask_qk, + ["Cast", "Reshape", "Where", "Equal", "Cast", "Cast"], + [1, 0, 0, 0, 0, 0], + ) + if mask_nodes is not None: + _, _, where_mask, _, _, cast_mask = mask_nodes + else: + logger.debug("fuse_conformer_attention: failed to match mask path") + return + q_nodes = self.model.match_parent_path( matmul_qk, ["Mul", "Transpose", "Reshape", "Add", "MatMul"], @@ -78,6 +92,8 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): logger.debug("fuse_conformer_attention: failed to match q path") return + + k_nodes = self.model.match_parent_path( matmul_qk, ["Transpose", "Concat", "Transpose", "Reshape", "Add", "MatMul"], @@ -111,7 +127,8 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): num_heads, hidden_size, attention_last_node.output[0], - add_qk=add_qk.input[1], + key_padding_mask=cast_mask.output[0], + add_qk=add_embd_qk.input[1], past_k=past_k, past_v=past_v, present_k=present_k,