Skip to content

Commit

Permalink
update
Browse files: browse the repository at this point in the history
  • Loading branch information
xiaoych030 committed Jan 26, 2024
1 parent d61119c commit 92c048e
Showing 1 changed file with 20 additions and 3 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -59,14 +59,28 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
logger.debug("fuse_conformer_attention: failed to match v path")
return

qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0])
qk_nodes = self.model.match_parent_path(
matmul_qkv,
["Softmax", "Add", "Add", "MatMul"],
[0, 0, 0, 0])

if qk_nodes is not None:
_, add_qk, matmul_qk = qk_nodes
_, add_mask_qk, add_embd_qk, matmul_qk = qk_nodes
else:
logger.debug("fuse_conformer_attention: failed to match qk path")
return

mask_nodes = self.model.match_parent_path(
add_mask_qk,
["Cast", "Reshape", "Where", "Equal", "Cast", "Cast"],
[1, 0, 0, 0, 0, 0],
)
if mask_nodes is not None:
_, _, where_mask, _, _, cast_mask = mask_nodes
else:
logger.debug("fuse_conformer_attention: failed to match mask path")
return

q_nodes = self.model.match_parent_path(
matmul_qk,
["Mul", "Transpose", "Reshape", "Add", "MatMul"],
Expand All @@ -78,6 +92,8 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
logger.debug("fuse_conformer_attention: failed to match q path")
return



k_nodes = self.model.match_parent_path(
matmul_qk,
["Transpose", "Concat", "Transpose", "Reshape", "Add", "MatMul"],
Expand Down Expand Up @@ -111,7 +127,8 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
num_heads,
hidden_size,
attention_last_node.output[0],
add_qk=add_qk.input[1],
key_padding_mask=cast_mask.output[0],
add_qk=add_embd_qk.input[1],
past_k=past_k,
past_v=past_v,
present_k=present_k,
Expand Down

0 comments on commit 92c048e

Please sign in to comment.