diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
index 0960a9efe7699..52bfe61608f62 100644
--- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
@@ -46,8 +46,6 @@ MultiHeadAttention<T>::MultiHeadAttention(const OpKernelInfo& info)
   scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);
   is_unidirectional_ = info.GetAttrOrDefault<int64_t>("unidirectional", 0) == 1;
-  ORT_ENFORCE(!is_unidirectional_,
-              "MHA support CUDA kernel does not Unidirectional. Consider using Attention or GQA instead.");
 
   kernel_options_ = this->GetAttentionKernelOptions();
 
@@ -208,13 +206,13 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
   bool use_fused_cross_attention =
       kernel_type == AttentionKernelType::AttentionKernel_Default &&
       !disable_fused_cross_attention_ &&
+      !is_unidirectional_ &&
       nullptr == key_padding_mask &&
       nullptr == attention_bias &&
       nullptr == past_key && nullptr == present_key &&
       (parameters.qkv_format == Q_K_V_BSNH || (parameters.qkv_format == Q_KV_BSNH_BSN2H && bias == nullptr)) &&
       parameters.hidden_size == parameters.v_hidden_size &&
-      has_fused_cross_attention_kernel(sm, parameters.head_size,
-                                       parameters.kv_sequence_length);
+      has_fused_cross_attention_kernel(sm, parameters.head_size, parameters.kv_sequence_length);
   if (use_fused_cross_attention) {
     if (fused_fp16_cross_attention_kernel_ == nullptr) {
       std::call_once(fused_cross_init_once_flag_, [&]() {
@@ -233,6 +231,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
   bool use_fused_runner =
       kernel_type == AttentionKernelType::AttentionKernel_Default &&
       !disable_fused_self_attention_ &&
+      !is_unidirectional_ &&
       nullptr == attention_bias &&
       (parameters.qkv_format == Q_K_V_BSNH || parameters.qkv_format == QKV_BSN3H) &&
       nullptr == past_key && nullptr == present_key &&
@@ -240,13 +239,12 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
       parameters.hidden_size == parameters.v_hidden_size &&
       parameters.sequence_length == parameters.kv_sequence_length &&  // self attention only for fused runner
       FusedMHARunnerFP16v2::IsSupported(sm, parameters.head_size, sequence_length,
-                                        enable_trt_flash_attention_, false);
+                                        enable_trt_flash_attention_, is_unidirectional_);
   if (use_fused_runner) {
     // Here we assume that num_heads and head_size does not change for a MultiHeadAttention node.
     if (nullptr == fused_fp16_runner_.get()) {
-      constexpr bool is_unidirectional = false;
       std::call_once(fused_fp16_runner_created_, [&]() {
-        fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional,
+        fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional_,
                                                           enable_trt_flash_attention_, parameters.scale);
       });
     }
diff --git a/onnxruntime/python/tools/transformers/io_binding_helper.py b/onnxruntime/python/tools/transformers/io_binding_helper.py
index 4f46242a4f402..2375104ac96f5 100644
--- a/onnxruntime/python/tools/transformers/io_binding_helper.py
+++ b/onnxruntime/python/tools/transformers/io_binding_helper.py
@@ -304,7 +304,7 @@ def allocate_buffers(self, shape_dict: Dict[str, Union[Tuple[int], List[int]]]):
                     tensor.data_ptr(),
                 )
 
-    def infer(self, feed_dict: Dict[str, torch.Tensor], run_options: RunOptions = None, synchronize: bool = False):
+    def infer(self, feed_dict: Dict[str, torch.Tensor], run_options: RunOptions = None, synchronize: bool = True):
         """Bind input tensors and run inference"""
         for name, tensor in feed_dict.items():
             assert isinstance(tensor, torch.Tensor) and tensor.is_contiguous()
@@ -317,7 +317,6 @@ def infer(self, feed_dict: Dict[str, torch.Tensor], run_options: RunOptions = No
                 else:
                     self.bind_input_and_buffer_sharing(name, tensor)
 
-        # Synchronization are not needed in most cases unless different streams are used or inputs/outputs are in CPU.
         if synchronize:
             self.io_binding.synchronize_inputs()
             self.ort_session.run_with_iobinding(self.io_binding, run_options)
diff --git a/onnxruntime/test/python/transformers/benchmark_mha.py b/onnxruntime/test/python/transformers/benchmark_mha.py
index 2a3541db4c9b5..d8acb66158ed2 100644
--- a/onnxruntime/test/python/transformers/benchmark_mha.py
+++ b/onnxruntime/test/python/transformers/benchmark_mha.py
@@ -587,8 +587,8 @@ def __init__(self, config: MultiHeadAttentionConfig, session_options=None, use_t
         self.ort_session = create_session(config, session_options, use_tf32=use_tf32)
         self.feed_dict = config.random_inputs()
 
-    def infer(self):
-        return self.ort_session.infer(self.feed_dict)
+    def infer(self, run_options=None, synchronize=True):
+        return self.ort_session.infer(self.feed_dict, run_options=run_options, synchronize=synchronize)
 
 
 def measure_latency(cuda_session: CudaSession, input_dict):
@@ -1356,7 +1356,6 @@ def _parse_arguments():
         args.repeats = 10000 if args.use_gpu else 100
 
     if args.use_gpu:
-        assert args.torch or not args.causal, "no causal cuda kernel in MHA op"
         assert torch.cuda.is_available()
         if not args.torch:
             assert "CUDAExecutionProvider" in get_available_providers()
diff --git a/onnxruntime/test/python/transformers/test_mha.py b/onnxruntime/test/python/transformers/test_mha.py
index 92653ffb053ce..69f0035ef8a17 100644
--- a/onnxruntime/test/python/transformers/test_mha.py
+++ b/onnxruntime/test/python/transformers/test_mha.py
@@ -68,6 +68,22 @@ def get_bias_support(format: InputFormats):
     raise RuntimeError(f"Unknown format: {format}")
 
 
+def get_causal_support(format: InputFormats):
+    if format == InputFormats.Q_K_V_BSNH_BSNH_BSNH:
+        return [True, False]
+
+    if format == InputFormats.Q_K_V_BSNH_BNSH_BNSH:
+        return [True, False]
+
+    if format == InputFormats.Q_KV_BSNH_BSN2H:
+        return [True, False]
+
+    if format == InputFormats.QKV_BSN3H:
+        return [True, False]
+
+    raise RuntimeError(f"Unknown format: {format}")
+
+
 def get_atten_bias_support():
     atten_bias_options = [
         # (has_attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1)
@@ -215,7 +231,7 @@ def no_kv_cache_test_cases(provider: str, comprehensive: bool):
                 for num_heads in heads:
                     for head_size in head_sizes:
                         for format in formats:
-                            for causal in [True, False]:
+                            for causal in get_causal_support(format):
                                 for mask_format in mask_formats:
                                     for has_bias in get_bias_support(format):
                                         for (
@@ -256,8 +272,8 @@ def no_kv_cache_test_cases(provider: str, comprehensive: bool):
             has_attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1 = atten_bias_options[
                 i % len(atten_bias_options)
             ]
-            for causal in [True, False]:
-                for format in formats:
+            for format in formats:
+                for causal in get_causal_support(format):
                     for has_bias in get_bias_support(format):
                         config = MultiHeadAttentionConfig(
                             batch_size=batch_size,
@@ -308,7 +324,7 @@ def kv_cache_test_cases(provider: str, comprehensive: bool):
                 for num_heads in heads:
                     for head_size in head_sizes:
                         for format in formats:
-                            for causal in [True, False]:
+                            for causal in get_causal_support(format):
                                 for has_past_input in [True, False]:
                                     for mask_format in mask_formats:
                                         for has_bias in get_bias_support(format):
@@ -353,8 +369,8 @@ def kv_cache_test_cases(provider: str, comprehensive: bool):
             has_attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1 = atten_bias_options[
                 i % len(atten_bias_options)
             ]
-            for causal in [True, False]:
-                for format in formats:
+            for format in formats:
+                for causal in get_causal_support(format):
                     for has_past_input in [True, False]:
                         for has_bias in get_bias_support(format):
                             sequence_length = 1 if has_past_input else past_sequence_length
@@ -397,7 +413,7 @@ def no_kv_cache_multi_thread_test_cases(provider: str, comprehensive: bool):
     device, dtype, formats = get_provider_support_info(provider, False)
 
     for format in formats:
-        for causal in [True, False]:
+        for causal in get_causal_support(format):
            for num_heads in heads:
                 for head_size in head_sizes:
                     configs = []  # list of configurations to run in parallel
@@ -437,7 +453,7 @@ def kv_cache_multi_thread_test_cases(provider: str, comprehensive: bool):
     device, dtype, formats = get_provider_support_info(provider, True)
 
     for format in formats:
-        for causal in [True, False]:
+        for causal in get_causal_support(format):
             for num_heads in heads:
                 for head_size in head_sizes:
                     configs = []
@@ -494,12 +510,8 @@ def parity_check_mha(
     rtol=1e-3,
     atol=1e-3,
 ):
-    # CUDA kernel does not support causal so skip such test cases.
-    if config.causal and config.provider == "CUDAExecutionProvider":
-        return
-
     ort_mha = OrtMultiHeadAttention(config, use_tf32=False)
-    ort_outputs = ort_mha.infer()
+    ort_outputs = ort_mha.infer(synchronize=True)
     out = ort_outputs["output"]
     out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size))
 
@@ -602,9 +614,6 @@ def parity_check_mha_multi_threading(
 ):
     # Use the first config to create a session, which is shared by all configs to run in parallel.
     config = test_inputs[0]["config"]
-    # For now, MHA CUDA kernel does not support causal so skip such test cases.
-    if config.causal and config.provider == "CUDAExecutionProvider":
-        return None
 
     # Some kernel does not support certain input format.
     if attention_kernel not in [
@@ -784,6 +793,10 @@ def run_mha_cpu(self):
 
     def run_mha_cuda_multi_threading(self, attention_kernel):
         for configs in multi_thread_test_cases("CUDAExecutionProvider", comprehensive_mode):
+            if configs and configs[0].causal and (SdpaKernel.TRT_CAUSAL_ATTENTION & attention_kernel != 0):
+                # TRT fused causal is disabled by default so skip the test of causal for multi-threading.
+                continue
+
             test_inputs = []
             for config in configs:
                 ort_inputs = config.random_inputs()
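# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the patch above: with the ORT_ENFORCE on
# is_unidirectional_ removed, a com.microsoft MultiHeadAttention node with
# unidirectional=1 can be run on CUDAExecutionProvider. Only the query/key/value
# inputs and the num_heads/unidirectional attributes referenced in the diff are
# relied upon; the rest of the graph wiring (shapes, opset versions, names) is
# an assumption chosen for a minimal smoke test.
# ---------------------------------------------------------------------------
import numpy as np
from onnx import TensorProto, helper
import onnxruntime as ort

batch, seq_len, num_heads, head_size = 2, 16, 4, 32
hidden = num_heads * head_size

node = helper.make_node(
    "MultiHeadAttention",
    inputs=["query", "key", "value"],
    outputs=["output"],
    domain="com.microsoft",
    num_heads=num_heads,
    unidirectional=1,  # causal mask; previously rejected by the CUDA kernel
)
graph = helper.make_graph(
    [node],
    "mha_causal_smoke_test",
    [
        helper.make_tensor_value_info(name, TensorProto.FLOAT16, [batch, seq_len, hidden])
        for name in ("query", "key", "value")
    ],
    [helper.make_tensor_value_info("output", TensorProto.FLOAT16, [batch, seq_len, hidden])],
)
model = helper.make_model(
    graph,
    opset_imports=[helper.make_opsetid("", 17), helper.make_opsetid("com.microsoft", 1)],
)

sess = ort.InferenceSession(model.SerializeToString(), providers=["CUDAExecutionProvider"])
feeds = {name: np.random.randn(batch, seq_len, hidden).astype(np.float16) for name in ("query", "key", "value")}
print(sess.run(None, feeds)[0].shape)  # expected: (2, 16, 128)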