Skip to content

Commit

Permalink
Fix unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
apsonawane committed Nov 18, 2023
1 parent 3fe90d1 commit 93e9bda
Show file tree
Hide file tree
Showing 4 changed files with 183 additions and 110 deletions.
7 changes: 7 additions & 0 deletions cmake/onnxruntime_python.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,9 @@ if (onnxruntime_BUILD_UNIT_TESTS)
file(GLOB onnxruntime_python_transformers_testdata_whisper CONFIGURE_DEPENDS
  "${ONNXRUNTIME_ROOT}/test/python/transformers/test_data/models/whisper/*.onnx"
)
# Collect conformer test models into their own variable. The original code
# reused the whisper variable name here, which silently overwrote the whisper
# glob results above and broke the whisper test-data copy step.
file(GLOB onnxruntime_python_transformers_testdata_conformer CONFIGURE_DEPENDS
  "${ONNXRUNTIME_ROOT}/test/python/transformers/test_data/models/conformer/*.onnx"
)
endif()

file(GLOB onnxruntime_python_tools_srcs CONFIGURE_DEPENDS
Expand Down Expand Up @@ -556,6 +559,7 @@ add_custom_command(
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/transformers/test_data/models
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/transformers/test_data/models/whisper
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/eager_test
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/transformers/test_data/models/conformer
COMMAND ${CMAKE_COMMAND} -E copy
${ONNXRUNTIME_ROOT}/__init__.py
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/
Expand Down Expand Up @@ -711,6 +715,9 @@ if (onnxruntime_BUILD_UNIT_TESTS)
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_transformers_testdata_whisper}
$<TARGET_FILE_DIR:${build_output_target}>/transformers/test_data/models/whisper/
# Copy the conformer test models (not the whisper ones) into the conformer
# test-data directory; the original copied the whisper glob here, so the
# conformer unit tests never saw their models.
COMMAND ${CMAKE_COMMAND} -E copy
  ${onnxruntime_python_transformers_testdata_conformer}
  $<TARGET_FILE_DIR:${build_output_target}>/transformers/test_data/models/conformer/
)
endif()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

class FusionConformerAttention(FusionAttention):
"""
Fuse Conformer Attention subgraph into one Attention node.
Fuse Conformer Attention subgraph into one MultiHeadAttention node.
"""

def __init__(
Expand Down Expand Up @@ -40,6 +40,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
matmul_qkv,
) = qkv_nodes
else:
logger.debug("fuse_conformer_attention: failed to match qkv path")
return

v_nodes = self.model.match_parent_path(
Expand All @@ -55,14 +56,15 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
present_v = concat_v.output[0]
past_v = concat_parent.output[0]
else:
logger.debug("fuse_attention: failed to match v path")
logger.debug("fuse_conformer_attention: failed to match v path")
return

qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0])

if qk_nodes is not None:
_, add_qk, matmul_qk = qk_nodes
else:
logger.debug("fuse_conformer_attention: failed to match qk path")
return

q_nodes = self.model.match_parent_path(
Expand All @@ -73,6 +75,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
if q_nodes is not None:
_, _, reshape_q, add_q, matmul_q = q_nodes
else:
logger.debug("fuse_conformer_attention: failed to match q path")
return

k_nodes = self.model.match_parent_path(
Expand All @@ -88,16 +91,16 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
past_k = concat_parent.output[0]
present_k = concat_k.output[0]
else:
logger.debug("fuse_conformer_attention: failed to match k path")
return

attention_last_node = reshape_qkv
num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q)

if num_heads <= 0 or hidden_size <= 0 or (hidden_size % num_heads) != 0:
logger.debug("fuse_attention: failed to detect num_heads or hidden_size")
logger.debug("fuse_conformer_attention: failed to detect num_heads or hidden_size")
return

new_node = None
new_node = self.create_multihead_attention_node(
matmul_q,
matmul_k,
Expand All @@ -116,6 +119,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
)

if new_node is None:
logger.debug("fuse_conformer_attention: MultiHeadAttention node creation failed")
return

self.nodes_to_add.append(new_node)
Expand Down
Loading

0 comments on commit 93e9bda

Please sign in to comment.