From ac38c44bfb23895ee8cb2db918523a2f29bc0ea6 Mon Sep 17 00:00:00 2001
From: Tianlei Wu
Date: Fri, 9 Aug 2024 01:31:00 -0700
Subject: [PATCH] [CUDA] Fix MHA mask (#21655)

### Description
Fix a mask type check introduced by me in a recent commit: the old condition did not actually
test `parameters.mask_type`, so the fused attention runner could be selected for unsupported
mask types. Add tests covering no mask, 1D, and 2D attention masks.
---
 .../cuda/bert/multihead_attention.cc          |   4 +-
 .../test/python/transformers/benchmark_mha.py |  99 ++++++++-
 .../test/python/transformers/test_mha.py      | 204 ++++++++++++------
 3 files changed, 233 insertions(+), 74 deletions(-)

diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
index c36abc8e1d624..2835192abd298 100644
--- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
@@ -182,6 +182,8 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
   auto out_accum_buffer = GetScratchBuffer<void>(0, context->GetComputeStream());  // nullptr
 #endif
 
+  bool is_mask_none_or_1d_k_len = parameters.mask_type == AttentionMaskType::MASK_NONE ||
+                                  parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN;
   bool use_fused_cross_attention =
       !use_flash_attention &&
       !disable_fused_cross_attention_ &&
@@ -213,7 +215,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
       nullptr == relative_position_bias &&
       (parameters.qkv_format == Q_K_V_BSNH || parameters.qkv_format == QKV_BSN3H) &&
       nullptr == past_key && nullptr == present_key &&
-      (nullptr == key_padding_mask || AttentionMaskType::MASK_1D_KEY_SEQ_LEN) &&
+      is_mask_none_or_1d_k_len &&
       parameters.hidden_size == parameters.v_hidden_size &&
       parameters.sequence_length == parameters.kv_sequence_length &&  // self attention only for fused runner
       FusedMHARunnerFP16v2::IsSupported(sm, parameters.head_size, sequence_length,
diff --git a/onnxruntime/test/python/transformers/benchmark_mha.py b/onnxruntime/test/python/transformers/benchmark_mha.py
index ec350874af32c..0c52ee690af82 100644
--- a/onnxruntime/test/python/transformers/benchmark_mha.py
+++ b/onnxruntime/test/python/transformers/benchmark_mha.py
@@ -71,6 +71,13 @@ class SdpaKernel(IntEnum):
     TRT_CAUSAL_ATTENTION = 128
 
 
+# Since attention bias is supported, only masks up to 2D are needed.
+class AttentionMaskFormat(IntEnum):
+    Mask_None = 0  # No attention mask.
+    Mask_1D_Key_SeqLen = 1  # Shape (batch_size), actual sequence lengths (excluding padding on the right side).
+    Mask_2D_Key_PaddingMask = 2  # Shape (batch_size, total_sequence_length), key padding mask.
+
+
 class MultiHeadAttentionConfig:
     def __init__(
         self,
@@ -93,6 +100,7 @@ def __init__(
         input_format: int = InputFormats.Q_K_V_BSNH_BSNH_BSNH,
         verbose: bool = False,
         has_bias: bool = False,
+        mask_format: int = AttentionMaskFormat.Mask_None,
     ):
         self.operator = "MultiHeadAttention"
         self.batch_size = batch_size
@@ -144,6 +152,19 @@ def __init__(
         self.verbose = verbose
         self.has_bias = has_bias
 
+        assert mask_format in [
+            AttentionMaskFormat.Mask_None,
+            AttentionMaskFormat.Mask_1D_Key_SeqLen,
+            AttentionMaskFormat.Mask_2D_Key_PaddingMask,
+        ]
+        self.mask_format = mask_format
+
+        # mask_index_q and mask_index_kv will be updated in random_inputs() if mask_format is not Mask_None.
+ self.mask_index_kv = torch.ones(self.batch_size, dtype=torch.int32, device=self.device) * self.sequence_length + self.mask_index_q = ( + torch.ones(self.batch_size, dtype=torch.int32, device=self.device) * self.total_sequence_length + ) + def __repr__(self): return ( f"MultiHeadAttentionConfig(batch_size={self.batch_size}, sequence_length={self.sequence_length}, " @@ -154,7 +175,7 @@ def __repr__(self): f"share_past_present_buffer={self.share_past_present_buffer}, " f"provider={self.provider}, device={self.device}, enable_cuda_graph={self.enable_cuda_graph}, " f"dtype={self.dtype}, input_format={InputFormats.input_format_str(self.input_format)}, " - f"has_bias={self.has_bias}" + f"has_bias={self.has_bias}, mask_format={self.mask_format}" ) def shape_dict(self, input_format=None): @@ -207,6 +228,13 @@ def shape_dict(self, input_format=None): if self.has_bias: shapes["bias"] = (3 * self.num_heads * self.head_size,) + if self.mask_format == AttentionMaskFormat.Mask_1D_Key_SeqLen: + shapes["mask"] = (self.batch_size,) + elif self.mask_format == AttentionMaskFormat.Mask_2D_Key_PaddingMask: + shapes["mask"] = (self.batch_size, self.total_sequence_length) + else: + assert self.mask_format == AttentionMaskFormat.Mask_None + return shapes def symbolic_shape_dict(self, input_format=None): @@ -259,8 +287,35 @@ def symbolic_shape_dict(self, input_format=None): if self.has_bias: shapes["bias"] = (3 * self.num_heads * self.head_size,) + if self.mask_format == AttentionMaskFormat.Mask_1D_Key_SeqLen: + shapes["mask"] = (self.batch_size,) + elif self.mask_format == AttentionMaskFormat.Mask_2D_Key_PaddingMask: + shapes["mask"] = (self.batch_size, "total_sequence_length") + else: + assert self.mask_format == AttentionMaskFormat.Mask_None + return shapes + def right_side_padding_masks(self): + q_mask = torch.ones(self.batch_size, 1, self.sequence_length, 1, dtype=torch.bool, device=self.device) + k_mask = torch.ones(self.batch_size, 1, self.total_sequence_length, 1, dtype=torch.bool, device=self.device) + mask = torch.ones( + self.batch_size, + self.num_heads, + self.sequence_length, + self.total_sequence_length, + dtype=torch.bool, + device=self.device, + ) + + if self.mask_format != AttentionMaskFormat.Mask_None: + for i, (m, n) in enumerate(zip(self.mask_index_q, self.mask_index_kv)): + q_mask[i, :, m:, :] = False + k_mask[i, :, n:, :] = False + mask[i, :, m:, :] = False + mask[i, :, :, n:] = False + return q_mask, k_mask, mask + def random_inputs(self, seed: int = 123, no_bias_k_v: bool = False): device = self.device dtype = self.dtype @@ -325,13 +380,38 @@ def random_inputs(self, seed: int = 123, no_bias_k_v: bool = False): if self.has_bias: feeds["bias"] = torch.concat([bias_q, bias_k, bias_v], dim=0).reshape(shape_dict["bias"]).contiguous() + # Generate padding mask + if self.mask_format != AttentionMaskFormat.Mask_None: + self.mask_index_kv = torch.randint( + 1, self.total_sequence_length + 1, (self.batch_size,), dtype=torch.int32, device=self.device + ) + if self.past_sequence_length > 0: + self.mask_index_q = ( + torch.ones(self.batch_size, dtype=torch.int32, device=self.device) * self.sequence_length + ) + else: # prompt case + self.mask_index_q = self.mask_index_kv.clone() + + mask = None + if self.mask_format == AttentionMaskFormat.Mask_1D_Key_SeqLen: + mask = self.mask_index_kv.clone() + elif self.mask_format == AttentionMaskFormat.Mask_2D_Key_PaddingMask: + k_mask = torch.ones(self.batch_size, 1, self.total_sequence_length, 1, dtype=torch.bool, device=self.device) + for i, n in 
enumerate(self.mask_index_kv): + k_mask[i, :, n:, :] = False + mask = k_mask.reshape(self.batch_size, self.total_sequence_length) + else: + assert self.mask_format == AttentionMaskFormat.Mask_None + + if mask is not None: + feeds = {**feeds, "mask": mask.to(dtype=torch.int32)} # mask is int32 (not bool) for MultiHeadAttention op. + return feeds def get_input_output_names(self): if self.input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH: - return ["query", "key", "value"], ["output"] - - if self.input_format == InputFormats.QKV_BSN3H: + inputs, outputs = ["query", "key", "value"], ["output"] + elif self.input_format == InputFormats.QKV_BSN3H: inputs, outputs = ["query"], ["output"] elif self.input_format == InputFormats.Q_KV_BSNH_BSN2H: inputs, outputs = ["query", "key"], ["output"] @@ -339,8 +419,12 @@ def get_input_output_names(self): inputs, outputs = ["query", "key", "value"], ["output"] if self.has_bias: + assert self.input_format != InputFormats.Q_KV_BSNH_BSN2H inputs = [*inputs, "bias"] + if self.mask_format != AttentionMaskFormat.Mask_None: + inputs = [*inputs, "mask"] + if self.has_past_input: inputs = [*inputs, "past_key", "past_value"] @@ -351,7 +435,7 @@ def get_input_output_names(self): def fill_optional_mha_inputs(input_names): - inputs = ["query", "key", "value", "bias", "key_padding_mask", "relative_position_bias", "past_key", "past_value"] + inputs = ["query", "key", "value", "bias", "mask", "relative_position_bias", "past_key", "past_value"] # Remove optional inputs that are not in input_names with empty string inputs_with_optional = [input if input in input_names else "" for input in inputs] @@ -376,13 +460,16 @@ def create_multi_head_attention_onnx_model(config: MultiHeadAttentionConfig, use num_heads=config.num_heads, unidirectional=int(config.causal), scale=config.softmax_scale, + mask_filter_value=float("-inf"), domain="com.microsoft", ), ] shape_dict = config.symbolic_shape_dict() if use_symbolic_shape else config.shape_dict() inputs = [ - helper.make_tensor_value_info(input_name, float_type, list(shape_dict[input_name])) + helper.make_tensor_value_info( + input_name, TensorProto.INT32 if input_name == "mask" else float_type, list(shape_dict[input_name]) + ) for input_name in input_names if input_name ] diff --git a/onnxruntime/test/python/transformers/test_mha.py b/onnxruntime/test/python/transformers/test_mha.py index a35d02b0b9d52..5948f8b1ccfc1 100644 --- a/onnxruntime/test/python/transformers/test_mha.py +++ b/onnxruntime/test/python/transformers/test_mha.py @@ -14,9 +14,15 @@ import numpy import torch -from benchmark_mha import InputFormats, MultiHeadAttentionConfig, OrtMultiHeadAttention, SdpaKernel, create_ort_session +from benchmark_mha import ( + AttentionMaskFormat, + InputFormats, + MultiHeadAttentionConfig, + OrtMultiHeadAttention, + SdpaKernel, + create_ort_session, +) from einops import rearrange -from parameterized import parameterized import onnxruntime @@ -67,11 +73,11 @@ def attention_reference( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - scale: Optional[float] = None, mask: Optional[torch.Tensor] = None, + scale: Optional[float] = None, verbose: bool = False, ) -> torch.Tensor: - """Reference implementation of Dot Product Attention + """Reference implementation of SDPA Args: head_size (int): dimension per head @@ -82,7 +88,7 @@ def attention_reference( mask (Optional[torch.Tensor], optional): attention mask. Defaults to None. 
Returns: - torch.Tensor: result of dot product attention + torch.Tensor: result of SDPA """ if scale is None: scale = 1.0 / (head_size**0.5) @@ -93,6 +99,7 @@ def attention_reference( assert value.dim() == 4 if verbose: + torch.set_printoptions(precision=6, linewidth=200, sci_mode=False) print("query(SDPA)", query) print("key(SDPA)", key) print("value(SDPA)", value) @@ -101,11 +108,14 @@ def attention_reference( # Apply multi-head attention. attn = torch.einsum("bhmd,bhnd->bhmn", query, key).float() * scale - if mask is not None: - attn = attn.masked_fill((1 - mask.int()).bool(), float("-inf")) if verbose: print("QK(SDPA)", attn) + if mask is not None: + attn = attn.masked_fill((1 - mask.int()).bool(), float("-inf")) + if verbose: + print("masked QK(SDPA)", attn) + attn = attn.softmax(-1) if verbose: print("Softmax(SDPA)", attn) @@ -170,6 +180,12 @@ def no_kv_cache_test_cases(provider: str, comprehensive: bool): heads = [1, 3, 4, 16] head_sizes = [8, 16, 32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + mask_formats = [ + AttentionMaskFormat.Mask_None, + AttentionMaskFormat.Mask_1D_Key_SeqLen, + AttentionMaskFormat.Mask_2D_Key_PaddingMask, + ] + device, dtype, formats = get_provider_support_info(provider, False) if comprehensive: sequence_lengths = [*sequence_lengths, 2048] # Large sequence length is slow and need a lot of memory @@ -179,25 +195,27 @@ def no_kv_cache_test_cases(provider: str, comprehensive: bool): for head_size in head_sizes: for format in formats: for causal in [True, False]: - for has_bias in get_bias_support(format): - config = MultiHeadAttentionConfig( - batch_size=batch_size, - sequence_length=sequence_length, - num_heads=num_heads, - head_size=head_size, - causal=causal, - past_sequence_length=0, - kv_sequence_length=sequence_length, - max_cache_sequence_length=None, - provider=provider, - device=device, - dtype=dtype, - use_kv_cache=False, - share_past_present_buffer=False, - input_format=format, - has_bias=has_bias, - ) - yield config + for mask_format in mask_formats: + for has_bias in get_bias_support(format): + config = MultiHeadAttentionConfig( + batch_size=batch_size, + sequence_length=sequence_length, + num_heads=num_heads, + head_size=head_size, + causal=causal, + past_sequence_length=0, + kv_sequence_length=sequence_length, + max_cache_sequence_length=None, + provider=provider, + device=device, + dtype=dtype, + use_kv_cache=False, + share_past_present_buffer=False, + input_format=format, + has_bias=has_bias, + mask_format=mask_format, + ) + yield config else: test_cases = max(len(batch_sizes), len(sequence_lengths), len(heads), len(head_sizes)) for i in range(test_cases): @@ -205,6 +223,7 @@ def no_kv_cache_test_cases(provider: str, comprehensive: bool): sequence_length = sequence_lengths[i % len(sequence_lengths)] num_heads = heads[i % len(heads)] head_size = head_sizes[i % len(head_sizes)] + mask_format = mask_formats[i % len(mask_formats)] for causal in [True, False]: for format in formats: for has_bias in get_bias_support(format): @@ -224,6 +243,7 @@ def no_kv_cache_test_cases(provider: str, comprehensive: bool): share_past_present_buffer=False, input_format=format, has_bias=has_bias, + mask_format=mask_format, ) yield config @@ -238,6 +258,11 @@ def kv_cache_test_cases(provider: str, comprehensive: bool): heads = [1, 3, 4, 16] head_sizes = [8, 16, 32, 40, 64, 80, 96, 128, 160, 192, 224, 256] device, dtype, formats = get_provider_support_info(provider, True) + mask_formats = [ + AttentionMaskFormat.Mask_None, + AttentionMaskFormat.Mask_1D_Key_SeqLen, + 
AttentionMaskFormat.Mask_2D_Key_PaddingMask, + ] if comprehensive: sequence_lengths = [*sequence_lengths, 2048] # Large sequence length is slow and need a lot of memory @@ -248,28 +273,30 @@ def kv_cache_test_cases(provider: str, comprehensive: bool): for format in formats: for causal in [True, False]: for has_past_input in [True, False]: - for has_bias in get_bias_support(format): - sequence_length = 1 if has_past_input else past_sequence_length - past_seq_len = past_sequence_length if has_past_input else 0 - config = MultiHeadAttentionConfig( - batch_size=batch_size, - sequence_length=sequence_length, - num_heads=num_heads, - head_size=head_size, - causal=causal, - past_sequence_length=past_seq_len, - kv_sequence_length=sequence_length, - max_cache_sequence_length=None, - provider=provider, - device=device, - dtype=dtype, - use_kv_cache=True, - has_past_input=has_past_input, - share_past_present_buffer=False, - input_format=format, - has_bias=has_bias, - ) - yield config + for mask_format in mask_formats: + for has_bias in get_bias_support(format): + sequence_length = 1 if has_past_input else past_sequence_length + past_seq_len = past_sequence_length if has_past_input else 0 + config = MultiHeadAttentionConfig( + batch_size=batch_size, + sequence_length=sequence_length, + num_heads=num_heads, + head_size=head_size, + causal=causal, + past_sequence_length=past_seq_len, + kv_sequence_length=sequence_length, + max_cache_sequence_length=None, + provider=provider, + device=device, + dtype=dtype, + use_kv_cache=True, + has_past_input=has_past_input, + share_past_present_buffer=False, + input_format=format, + has_bias=has_bias, + mask_format=mask_format, + ) + yield config else: test_cases = max(len(batch_sizes), len(sequence_lengths), len(heads), len(head_sizes)) for i in range(test_cases): @@ -277,6 +304,7 @@ def kv_cache_test_cases(provider: str, comprehensive: bool): past_sequence_length = sequence_lengths[i % len(sequence_lengths)] num_heads = heads[i % len(heads)] head_size = head_sizes[i % len(head_sizes)] + mask_format = mask_formats[i % len(mask_formats)] for causal in [True, False]: for format in formats: for has_past_input in [True, False]: @@ -300,6 +328,7 @@ def kv_cache_test_cases(provider: str, comprehensive: bool): share_past_present_buffer=False, input_format=format, has_bias=has_bias, + mask_format=mask_format, ) yield config @@ -392,6 +421,23 @@ def causal_mask(seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=No return col_idx <= row_idx + sk - sq +def merge_padding_and_causal_masks(config): + + q_mask, k_mask, mask = config.right_side_padding_masks() + if config.causal: + query_padding_mask = q_mask.reshape(config.batch_size, config.sequence_length) + key_padding_mask = k_mask.reshape(config.batch_size, config.total_sequence_length) + mask = causal_mask( + config.sequence_length, + config.total_sequence_length, + query_padding_mask, + key_padding_mask, + device=config.device, + ) + + return mask + + def parity_check_mha( config: MultiHeadAttentionConfig, rtol=1e-3, @@ -406,6 +452,7 @@ def parity_check_mha( out = ort_outputs["output"] out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) + ort_input_format = config.input_format no_bias_k_v = config.input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH config.input_format = InputFormats.Q_K_V_BSNH_BSNH_BSNH ref_inputs = config.random_inputs(no_bias_k_v=no_bias_k_v) @@ -427,10 +474,7 @@ def parity_check_mha( k = k.transpose(1, 2) v = v.transpose(1, 2) - mask = None - 
if config.causal:
-        mask = causal_mask(config.sequence_length, config.total_sequence_length, device=config.device)
-
+    mask = merge_padding_and_causal_masks(config)
     k_cache = None
     v_cache = None
     if config.use_kv_cache:
@@ -440,6 +484,26 @@ def parity_check_mha(
     else:
         out_ref = attention_reference(config.head_size, q, k, v, mask=mask)
 
+    # Fill zeros for the padded tokens for comparison.
+    if config.mask_index_q is not None:
+        for i, m in enumerate(config.mask_index_q):
+            out[i, m:, :, :] = 0
+            out_ref[i, m:, :, :] = 0
+
+    if config.mask_index_kv is not None and config.use_kv_cache:
+        assert k_cache is not None
+        assert v_cache is not None
+        present_key = ort_outputs["present_key"]
+        present_value = ort_outputs["present_value"]
+        for i, n in enumerate(config.mask_index_kv):
+            k_cache[i, :, n:, :] = 0
+            present_key[i, :, n:, :] = 0
+            v_cache[i, :, n:, :] = 0
+            present_value[i, :, n:, :] = 0
+
+    # Restore the input format so that it shows up in the error message correctly.
+    config.input_format = ort_input_format
+
     numpy.testing.assert_allclose(
         out.detach().cpu().numpy(),
         out_ref.detach().cpu().numpy(),
@@ -540,10 +604,7 @@ def check_parity_with_config(i: int):
             .transpose(1, 2)
         )
 
-        mask = None
-        if config.causal:
-            mask = causal_mask(config.sequence_length, config.total_sequence_length, device=config.device)
-
+        mask = merge_padding_and_causal_masks(config)
         k_cache = None
         v_cache = None
         if config.use_kv_cache:
@@ -622,13 +683,13 @@ def multi_thread_test_cases(provider: str, comprehensive: bool):
 
 
 class TestMultiHeadAttention(unittest.TestCase):
-    @parameterized.expand(mha_test_cases("CUDAExecutionProvider", comprehensive_mode), skip_on_empty=True)
-    def test_mha_cuda(self, config):
-        parity_check_mha(config, rtol=5e-3, atol=5e-3)
+    def run_mha_cuda(self):
+        for config in mha_test_cases("CUDAExecutionProvider", comprehensive_mode):
+            parity_check_mha(config, rtol=5e-3, atol=5e-3)
 
-    @parameterized.expand(mha_test_cases("CPUExecutionProvider", comprehensive_mode), skip_on_empty=True)
-    def test_mha_cpu(self, config):
-        parity_check_mha(config, rtol=5e-3, atol=5e-3)
+    def run_mha_cpu(self):
+        for config in mha_test_cases("CPUExecutionProvider", comprehensive_mode):
+            parity_check_mha(config, rtol=5e-3, atol=5e-3)
 
     def run_mha_cuda_multi_threading(self, attention_kernel):
         for configs in multi_thread_test_cases("CUDAExecutionProvider", comprehensive_mode):
@@ -646,21 +707,21 @@ def run_mha_cuda_multi_threading(self, attention_kernel):
             exception = parity_check_mha_multi_threading(
                 test_inputs, attention_kernel=attention_kernel, max_threads=len(configs)
             )
-            assert exception is None, f"{attention_kernel=}, {vars(configs[0])}, {exception}"
+            assert exception is None, f"Multi-threading failed: {attention_kernel=}, {vars(configs[0])}, {exception}"
 
-    def test_mha_cuda_multi_threading(self):
+    def run_mha_cuda_multi_threading_default(self):
         if get_compute_capability() >= 60:
             self.run_mha_cuda_multi_threading(SdpaKernel.DEFAULT)
 
-    def test_mha_cuda_multi_threading_efficient(self):
+    def run_mha_cuda_multi_threading_efficient(self):
         if comprehensive_mode and get_compute_capability() >= 60:
             self.run_mha_cuda_multi_threading(SdpaKernel.EFFICIENT_ATTENTION)
 
-    def test_mha_cuda_multi_threading_math(self):
+    def run_mha_cuda_multi_threading_math(self):
         if comprehensive_mode and get_compute_capability() >= 60:
             self.run_mha_cuda_multi_threading(SdpaKernel.MATH)
 
-    def test_mha_cuda_multi_threading_trt(self):
+    def run_mha_cuda_multi_threading_trt(self):
         if get_compute_capability() in [75, 80, 86, 89]:
            self.run_mha_cuda_multi_threading(
                SdpaKernel.TRT_FUSED_ATTENTION
@@ -669,6 +730,15 @@ def test_mha_cuda_multi_threading_trt(self):
                | SdpaKernel.TRT_CROSS_ATTENTION
            )
 
+    def test_all(self):
+        # Run tests sequentially to avoid out-of-memory issues.
+        self.run_mha_cpu()
+        self.run_mha_cuda()
+        self.run_mha_cuda_multi_threading_default()
+        self.run_mha_cuda_multi_threading_efficient()
+        self.run_mha_cuda_multi_threading_math()
+        self.run_mha_cuda_multi_threading_trt()
+
 
 if __name__ == "__main__":
     with torch.no_grad():
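
For readers of the test changes above, here is a minimal standalone sketch (not part of the patch) of the mask convention the new tests exercise: a 2D key padding mask of shape (batch_size, total_sequence_length), fed to the MultiHeadAttention op as int32 with 1 = attend and 0 = padded, and applied in the reference path by filling masked scores with -inf before softmax (the same reason the test node sets `mask_filter_value=float("-inf")`). All shapes and tensor names below are illustrative, not taken from the repository.

```python
# Standalone sketch of the 2D key padding mask convention used by the new MHA tests.
# Assumes right-side padding; names like key_lengths are hypothetical.
import torch

batch_size, num_heads, seq_len, head_size = 2, 3, 4, 8
scale = 1.0 / (head_size**0.5)

q = torch.randn(batch_size, num_heads, seq_len, head_size)
k = torch.randn(batch_size, num_heads, seq_len, head_size)
v = torch.randn(batch_size, num_heads, seq_len, head_size)

# Per-batch valid key lengths (right-side padding), analogous to mask_index_kv in benchmark_mha.py.
key_lengths = torch.tensor([4, 2], dtype=torch.int32)

# 2D key padding mask, shape (batch_size, total_sequence_length), int32 like the op input:
# 1 marks a valid token, 0 marks a padded token.
mask_2d = (torch.arange(seq_len).unsqueeze(0) < key_lengths.unsqueeze(1)).to(torch.int32)

# Reference path: scores of padded keys are set to -inf before softmax so they get zero weight,
# mirroring attention_reference in test_mha.py.
attn = torch.einsum("bhmd,bhnd->bhmn", q, k) * scale
attn = attn.masked_fill((1 - mask_2d.view(batch_size, 1, 1, seq_len)).bool(), float("-inf"))
out = torch.einsum("bhmn,bhnd->bhmd", attn.softmax(-1), v)
print(out.shape)  # torch.Size([2, 3, 4, 8])
```

The 1D Mask_1D_Key_SeqLen format carries the same information as the per-batch lengths directly, which is why random_inputs() derives the 2D mask from the same mask_index_kv values.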