test attention bias
tianleiwu committed Aug 9, 2024
1 parent ccbbce8 commit 93d8708
Showing 3 changed files with 93 additions and 67 deletions.
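The new options exercise an additive attention bias for MultiHeadAttention: a tensor added to the scaled QK^T scores before softmax, optionally broadcast over the batch and/or head dimensions (see the atten_bias_options tuples below). A minimal PyTorch sketch of that semantics, with illustrative names and shapes that are not taken from the test file (the authoritative reference is attention_reference in test_mha.py):

import math
import torch

def attention_with_bias(q, k, v, attn_bias=None):
    # q: [b, n, s_q, h]; k, v: [b, n, s_kv, h]
    # attn_bias: broadcastable to [b, n, s_q, s_kv], e.g. [1, n, s_q, s_kv] or [b, 1, s_q, s_kv]
    scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(q.shape[-1])
    if attn_bias is not None:
        scores = scores + attn_bias  # broadcasting covers the [1, ...] variants
    return torch.softmax(scores, dim=-1) @ v

b, n, s_q, s_kv, h = 2, 4, 8, 8, 32
q, k, v = torch.randn(b, n, s_q, h), torch.randn(b, n, s_kv, h), torch.randn(b, n, s_kv, h)
bias = torch.randn(1, n, s_q, s_kv)  # broadcast over the batch dimension
out = attention_with_bias(q, k, v, bias)  # [b, n, s_q, h]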
1 change: 1 addition & 0 deletions cmake/onnxruntime_rocm_hipify.cmake
@@ -5,6 +5,7 @@ find_package(Python3 COMPONENTS Interpreter REQUIRED)

# GLOB pattern of file to be excluded
set(contrib_ops_excluded_files
"bert/cudnn_fmha/*"
"bert/cutlass_fmha/*"
"bert/fastertransformer_decoder_attention/*"
"bert/flash_attention/*"
6 changes: 5 additions & 1 deletion onnxruntime/test/python/transformers/benchmark_mha.py
@@ -840,7 +840,11 @@ def run_tflops_test(
input_dict = config.random_inputs()

# warm up session
_ = measure_latency(session, input_dict)
try:
_ = measure_latency(session, input_dict)
except Exception as e:
print(f"Failed to run {kernel=} for {config=}. Exception: {e}")
continue

latency_list = []
for _ in range(repeats):
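The warm-up call is now wrapped in try/except so that a kernel/configuration combination the provider cannot run is reported and skipped with continue, instead of aborting the whole benchmark sweep. measure_latency itself is defined elsewhere in benchmark_mha.py; a rough sketch of what such a helper typically does, assuming an onnxruntime InferenceSession and a feed dict (this is not the file's actual implementation):

import time

def measure_latency_sketch(session, input_dict):
    # Time a single InferenceSession.run call with the given feeds.
    start = time.perf_counter()
    session.run(None, input_dict)
    return time.perf_counter() - start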
153 changes: 87 additions & 66 deletions onnxruntime/test/python/transformers/test_mha.py
@@ -123,11 +123,6 @@ def attention_reference(
if verbose:
print("masked QK(ref)", attn)

if mask is not None:
attn = attn.masked_fill((1 - mask.int()).bool(), float("-inf"))
if verbose:
print("masked QK(SDPA)", attn)

attn = attn.softmax(-1)
if verbose:
print("Softmax(ref)", attn)
@@ -193,6 +188,15 @@ def no_kv_cache_test_cases(provider: str, comprehensive: bool):
heads = [1, 3, 4, 16]
head_sizes = [8, 16, 32, 40, 64, 80, 96, 128, 160, 192, 224, 256]

atten_bias_options = [
# (has_attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1)
(False, False, False),
(True, False, False), # [b, n, s_q, s_kv]
# (True, True, False), # [1, n, s_q, s_kv]
# (True, False, True), # [b, 1, s_q, s_kv]
# (True, True, True), # [1, 1, s_q, s_kv]
]

mask_formats = [
AttentionMaskFormat.Mask_None,
AttentionMaskFormat.Mask_1D_Key_SeqLen,
@@ -210,25 +214,33 @@ def no_kv_cache_test_cases(provider: str, comprehensive: bool):
for causal in [True, False]:
for mask_format in mask_formats:
for has_bias in get_bias_support(format):
config = MultiHeadAttentionConfig(
batch_size=batch_size,
sequence_length=sequence_length,
num_heads=num_heads,
head_size=head_size,
causal=causal,
past_sequence_length=0,
kv_sequence_length=sequence_length,
max_cache_sequence_length=None,
provider=provider,
device=device,
dtype=dtype,
use_kv_cache=False,
share_past_present_buffer=False,
input_format=format,
has_bias=has_bias,
mask_format=mask_format,
)
yield config
for (
has_attn_bias,
broadcast_attn_bias_dim_0,
broadcast_attn_bias_dim_1,
) in atten_bias_options:
config = MultiHeadAttentionConfig(
batch_size=batch_size,
sequence_length=sequence_length,
num_heads=num_heads,
head_size=head_size,
causal=causal,
past_sequence_length=0,
kv_sequence_length=sequence_length,
max_cache_sequence_length=None,
provider=provider,
device=device,
dtype=dtype,
use_kv_cache=False,
share_past_present_buffer=False,
input_format=format,
has_bias=has_bias,
mask_format=mask_format,
has_attn_bias=has_attn_bias,
broadcast_attn_bias_dim_0=broadcast_attn_bias_dim_0,
broadcast_attn_bias_dim_1=broadcast_attn_bias_dim_1,
)
yield config
else:
test_cases = max(len(batch_sizes), len(sequence_lengths), len(heads), len(head_sizes))
for i in range(test_cases):
@@ -237,6 +249,9 @@ def no_kv_cache_test_cases(provider: str, comprehensive: bool):
num_heads = heads[i % len(heads)]
head_size = head_sizes[i % len(head_sizes)]
mask_format = mask_formats[i % len(mask_formats)]
has_attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1 = atten_bias_options[
i % len(atten_bias_options)
]
for causal in [True, False]:
for format in formats:
for has_bias in get_bias_support(format):
@@ -257,6 +272,9 @@ def no_kv_cache_test_cases(provider: str, comprehensive: bool):
input_format=format,
has_bias=has_bias,
mask_format=mask_format,
has_attn_bias=has_attn_bias,
broadcast_attn_bias_dim_0=broadcast_attn_bias_dim_0,
broadcast_attn_bias_dim_1=broadcast_attn_bias_dim_1,
)
yield config

@@ -277,6 +295,15 @@ def kv_cache_test_cases(provider: str, comprehensive: bool):
AttentionMaskFormat.Mask_2D_Key_PaddingMask,
]

atten_bias_options = [
# (has_attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1)
(False, False, False),
(True, False, False), # [b, n, s_q, s_kv]
# (True, True, False), # [1, n, s_q, s_kv]
# (True, False, True), # [b, 1, s_q, s_kv]
# (True, True, True), # [1, 1, s_q, s_kv]
]

if comprehensive:
sequence_lengths = [*sequence_lengths, 2048] # Large sequence length is slow and need a lot of memory
for batch_size in batch_sizes:
@@ -288,28 +315,36 @@ def kv_cache_test_cases(provider: str, comprehensive: bool):
for has_past_input in [True, False]:
for mask_format in mask_formats:
for has_bias in get_bias_support(format):
sequence_length = 1 if has_past_input else past_sequence_length
past_seq_len = past_sequence_length if has_past_input else 0
config = MultiHeadAttentionConfig(
batch_size=batch_size,
sequence_length=sequence_length,
num_heads=num_heads,
head_size=head_size,
causal=causal,
past_sequence_length=past_seq_len,
kv_sequence_length=sequence_length,
max_cache_sequence_length=None,
provider=provider,
device=device,
dtype=dtype,
use_kv_cache=True,
has_past_input=has_past_input,
share_past_present_buffer=False,
input_format=format,
has_bias=has_bias,
mask_format=mask_format,
)
yield config
for (
has_attn_bias,
broadcast_attn_bias_dim_0,
broadcast_attn_bias_dim_1,
) in atten_bias_options:
sequence_length = 1 if has_past_input else past_sequence_length
past_seq_len = past_sequence_length if has_past_input else 0
config = MultiHeadAttentionConfig(
batch_size=batch_size,
sequence_length=sequence_length,
num_heads=num_heads,
head_size=head_size,
causal=causal,
past_sequence_length=past_seq_len,
kv_sequence_length=sequence_length,
max_cache_sequence_length=None,
provider=provider,
device=device,
dtype=dtype,
use_kv_cache=True,
has_past_input=has_past_input,
share_past_present_buffer=False,
input_format=format,
has_bias=has_bias,
mask_format=mask_format,
has_attn_bias=has_attn_bias,
broadcast_attn_bias_dim_0=broadcast_attn_bias_dim_0,
broadcast_attn_bias_dim_1=broadcast_attn_bias_dim_1,
)
yield config
else:
test_cases = max(len(batch_sizes), len(sequence_lengths), len(heads), len(head_sizes))
for i in range(test_cases):
@@ -318,6 +353,9 @@ def kv_cache_test_cases(provider: str, comprehensive: bool):
num_heads = heads[i % len(heads)]
head_size = head_sizes[i % len(head_sizes)]
mask_format = mask_formats[i % len(mask_formats)]
has_attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1 = atten_bias_options[
i % len(atten_bias_options)
]
for causal in [True, False]:
for format in formats:
for has_past_input in [True, False]:
@@ -342,6 +380,9 @@ def kv_cache_test_cases(provider: str, comprehensive: bool):
input_format=format,
has_bias=has_bias,
mask_format=mask_format,
has_attn_bias=has_attn_bias,
broadcast_attn_bias_dim_0=broadcast_attn_bias_dim_0,
broadcast_attn_bias_dim_1=broadcast_attn_bias_dim_1,
)
yield config

@@ -523,26 +564,6 @@ def parity_check_mha(
# Restore the input format so that it shows up in the error message correctly.
config.input_format = ort_input_format

# Fill zeros for the padded tokens for comparison.
if config.mask_index_q is not None:
for i, m in enumerate(config.mask_index_q):
out[i, m:, :, :] = 0
out_ref[i, m:, :, :] = 0

if config.mask_index_kv is not None and config.use_kv_cache:
assert k_cache is not None
assert v_cache is not None
present_key = ort_outputs["present_key"]
present_value = ort_outputs["present_value"]
for i, n in enumerate(config.mask_index_kv):
k_cache[i, :, n:, :] = 0
present_key[i, :, n:, :] = 0
v_cache[i, :, n:, :] = 0
present_value[i, :, n:, :] = 0

# Restore the input format so that it shows up in the error message correctly.
config.input_format = ort_input_format

numpy.testing.assert_allclose(
out.detach().cpu().numpy(),
out_ref.detach().cpu().numpy(),
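The comments next to atten_bias_options spell out how the two broadcast flags shrink the leading dimensions of the bias. A small hypothetical helper, not part of the test file, that builds a bias tensor of the shape each tuple implies:

import torch

def make_attn_bias(has_attn_bias, broadcast_dim_0, broadcast_dim_1,
                   batch_size, num_heads, seq_q, seq_kv, dtype=torch.float16):
    # (True, False, False) -> [b, n, s_q, s_kv]
    # (True, True,  False) -> [1, n, s_q, s_kv]
    # (True, False, True)  -> [b, 1, s_q, s_kv]
    # (True, True,  True)  -> [1, 1, s_q, s_kv]
    if not has_attn_bias:
        return None
    dim_0 = 1 if broadcast_dim_0 else batch_size
    dim_1 = 1 if broadcast_dim_1 else num_heads
    return torch.randn(dim_0, dim_1, seq_q, seq_kv, dtype=dtype)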
