[CUDA] enable causal in MultiHeadAttention (#21852)
### Description
Enable causal in the MultiHeadAttention CUDA operator.

All formats (Q_K_V_BSNH_BSNH_BSNH, Q_K_V_BSNH_BNSH_BNSH, Q_KV_BSNH_BSN2H
and QKV_BSN3H) support causal for now. Internally, causal is dispatched to
the flash attention, efficient attention, or unfused attention kernel.
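
For reference, "causal" here means the standard unidirectional (lower-triangular) attention mask. Below is a minimal PyTorch sketch of the computation the dispatched kernels are expected to match; it is an illustration only, not ONNX Runtime code.

```python
import math

import torch


def causal_attention_reference(q, k, v):
    """Naive causal self-attention; q, k, v have shape (batch, num_heads, seq_len, head_size)."""
    seq_len = q.shape[-2]
    scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(q.shape[-1])

    # Causal (unidirectional) mask: position i may only attend to positions 0..i.
    allowed = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=q.device))
    scores = scores.masked_fill(~allowed, float("-inf"))

    return torch.matmul(torch.softmax(scores, dim=-1), v)
```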

### Motivation and Context
Currently, MultiHeadAttention has causal enabled in the CPU EP but not in the
CUDA EP. This can cause issues in ONNX conversion, e.g. a model that runs on
CPU may fail on CUDA. Enabling causal in CUDA reduces the difference between
the CPU and CUDA support matrices.
tianleiwu authored Aug 26, 2024
1 parent d9c57ac commit ad38212
Showing 4 changed files with 37 additions and 28 deletions.
12 changes: 5 additions & 7 deletions onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
@@ -46,8 +46,6 @@ MultiHeadAttention<T>::MultiHeadAttention(const OpKernelInfo& info)
 
   scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);
   is_unidirectional_ = info.GetAttrOrDefault<int64_t>("unidirectional", 0) == 1;
-  ORT_ENFORCE(!is_unidirectional_,
-              "MHA support CUDA kernel does not Unidirectional. Consider using Attention or GQA instead.");
 
   kernel_options_ = this->GetAttentionKernelOptions();
 
@@ -208,13 +206,13 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
   bool use_fused_cross_attention =
       kernel_type == AttentionKernelType::AttentionKernel_Default &&
       !disable_fused_cross_attention_ &&
+      !is_unidirectional_ &&
       nullptr == key_padding_mask &&
       nullptr == attention_bias &&
       nullptr == past_key && nullptr == present_key &&
       (parameters.qkv_format == Q_K_V_BSNH || (parameters.qkv_format == Q_KV_BSNH_BSN2H && bias == nullptr)) &&
       parameters.hidden_size == parameters.v_hidden_size &&
-      has_fused_cross_attention_kernel(sm, parameters.head_size,
-                                       parameters.kv_sequence_length);
+      has_fused_cross_attention_kernel(sm, parameters.head_size, parameters.kv_sequence_length);
   if (use_fused_cross_attention) {
     if (fused_fp16_cross_attention_kernel_ == nullptr) {
       std::call_once(fused_cross_init_once_flag_, [&]() {
@@ -233,20 +231,20 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
   bool use_fused_runner =
       kernel_type == AttentionKernelType::AttentionKernel_Default &&
       !disable_fused_self_attention_ &&
+      !is_unidirectional_ &&
       nullptr == attention_bias &&
       (parameters.qkv_format == Q_K_V_BSNH || parameters.qkv_format == QKV_BSN3H) &&
       nullptr == past_key && nullptr == present_key &&
       is_mask_none_or_1d_k_len &&
       parameters.hidden_size == parameters.v_hidden_size &&
       parameters.sequence_length == parameters.kv_sequence_length &&  // self attention only for fused runner
       FusedMHARunnerFP16v2::IsSupported(sm, parameters.head_size, sequence_length,
-                                        enable_trt_flash_attention_, false);
+                                        enable_trt_flash_attention_, is_unidirectional_);
   if (use_fused_runner) {
     // Here we assume that num_heads and head_size does not change for a MultiHeadAttention node.
     if (nullptr == fused_fp16_runner_.get()) {
-      constexpr bool is_unidirectional = false;
       std::call_once(fused_fp16_runner_created_, [&]() {
-        fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional,
+        fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional_,
                                                           enable_trt_flash_attention_, parameters.scale);
       });
     }
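With the ORT_ENFORCE removed, a model that sets `unidirectional=1` on a com.microsoft MultiHeadAttention node can now run on the CUDA EP. A hypothetical, minimal repro sketch follows; the input names, float32 dtype, shapes, and opset versions are my assumptions for the Q_K_V_BSNH path, not taken from this PR.

```python
import numpy as np
import onnx
from onnx import TensorProto, helper
import onnxruntime as ort

batch, seq, num_heads, head_size = 2, 16, 4, 32
hidden = num_heads * head_size

node = helper.make_node(
    "MultiHeadAttention",
    inputs=["query", "key", "value"],
    outputs=["output"],
    domain="com.microsoft",
    num_heads=num_heads,
    unidirectional=1,  # causal attention; previously rejected by the CUDA kernel
)
graph = helper.make_graph(
    [node],
    "mha_causal",
    [helper.make_tensor_value_info(n, TensorProto.FLOAT, [batch, seq, hidden]) for n in ["query", "key", "value"]],
    [helper.make_tensor_value_info("output", TensorProto.FLOAT, [batch, seq, hidden])],
)
model = helper.make_model(
    graph, opset_imports=[helper.make_opsetid("", 17), helper.make_opsetid("com.microsoft", 1)]
)

sess = ort.InferenceSession(model.SerializeToString(), providers=["CUDAExecutionProvider"])
feeds = {n: np.random.randn(batch, seq, hidden).astype(np.float32) for n in ["query", "key", "value"]}
output = sess.run(None, feeds)[0]
```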
3 changes: 1 addition & 2 deletions onnxruntime/python/tools/transformers/io_binding_helper.py
@@ -304,7 +304,7 @@ def allocate_buffers(self, shape_dict: Dict[str, Union[Tuple[int], List[int]]]):
                 tensor.data_ptr(),
             )
 
-    def infer(self, feed_dict: Dict[str, torch.Tensor], run_options: RunOptions = None, synchronize: bool = False):
+    def infer(self, feed_dict: Dict[str, torch.Tensor], run_options: RunOptions = None, synchronize: bool = True):
         """Bind input tensors and run inference"""
         for name, tensor in feed_dict.items():
             assert isinstance(tensor, torch.Tensor) and tensor.is_contiguous()
@@ -317,7 +317,6 @@ def infer(self, feed_dict: Dict[str, torch.Tensor], run_options: RunOptions = No
             else:
                 self.bind_input_and_buffer_sharing(name, tensor)
 
-        # Synchronization are not needed in most cases unless different streams are used or inputs/outputs are in CPU.
         if synchronize:
             self.io_binding.synchronize_inputs()
             self.ort_session.run_with_iobinding(self.io_binding, run_options)
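The default flips to `synchronize=True` so callers get correct results without having to remember stream synchronization themselves. A rough sketch of the idea (assumed names and float16 buffers, not the helper's exact code) using onnxruntime's I/O binding API with pre-allocated CUDA tensors:

```python
import numpy as np
import onnxruntime as ort
import torch


def infer_with_binding(sess: ort.InferenceSession, inputs: dict, outputs: dict, synchronize: bool = True):
    """inputs/outputs map names to contiguous float16 torch tensors already on the GPU (assumed)."""
    binding = sess.io_binding()
    for name, tensor in inputs.items():
        binding.bind_input(name, "cuda", 0, np.float16, tuple(tensor.shape), tensor.data_ptr())
    for name, tensor in outputs.items():
        binding.bind_output(name, "cuda", 0, np.float16, tuple(tensor.shape), tensor.data_ptr())

    if synchronize:
        # Wait until the input buffers are ready, run, then wait for the run to finish so the
        # output tensors are safe to read (matters when other streams or CPU-side I/O are involved).
        binding.synchronize_inputs()
        sess.run_with_iobinding(binding)
        binding.synchronize_outputs()
    else:
        sess.run_with_iobinding(binding)
    return outputs
```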
5 changes: 2 additions & 3 deletions onnxruntime/test/python/transformers/benchmark_mha.py
@@ -587,8 +587,8 @@ def __init__(self, config: MultiHeadAttentionConfig, session_options=None, use_t
         self.ort_session = create_session(config, session_options, use_tf32=use_tf32)
         self.feed_dict = config.random_inputs()
 
-    def infer(self):
-        return self.ort_session.infer(self.feed_dict)
+    def infer(self, run_options=None, synchronize=True):
+        return self.ort_session.infer(self.feed_dict, run_options=run_options, synchronize=synchronize)
 
 
 def measure_latency(cuda_session: CudaSession, input_dict):
@@ -1356,7 +1356,6 @@ def _parse_arguments():
     args.repeats = 10000 if args.use_gpu else 100
 
     if args.use_gpu:
-        assert args.torch or not args.causal, "no causal cuda kernel in MHA op"
         assert torch.cuda.is_available()
         if not args.torch:
            assert "CUDAExecutionProvider" in get_available_providers()
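Exposing `synchronize` on `infer` lets a benchmark keep per-call synchronization out of the timed loop while still measuring completed GPU work. A timing sketch under the assumption that `session` is a wrapper like the one above with an `infer(run_options=None, synchronize=...)` method:

```python
import time

import torch


def measure_latency_ms(session, repeats: int = 100) -> float:
    session.infer(synchronize=True)  # warm-up run
    torch.cuda.synchronize()  # make sure the device is idle before timing starts
    start = time.perf_counter()
    for _ in range(repeats):
        session.infer(synchronize=False)  # skip per-call sync inside the timed loop
    torch.cuda.synchronize()  # count all launched work once at the end
    return (time.perf_counter() - start) * 1000.0 / repeats
```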
45 changes: 29 additions & 16 deletions onnxruntime/test/python/transformers/test_mha.py
@@ -68,6 +68,22 @@ def get_bias_support(format: InputFormats):
     raise RuntimeError(f"Unknown format: {format}")
 
 
+def get_causal_support(format: InputFormats):
+    if format == InputFormats.Q_K_V_BSNH_BSNH_BSNH:
+        return [True, False]
+
+    if format == InputFormats.Q_K_V_BSNH_BNSH_BNSH:
+        return [True, False]
+
+    if format == InputFormats.Q_KV_BSNH_BSN2H:
+        return [True, False]
+
+    if format == InputFormats.QKV_BSN3H:
+        return [True, False]
+
+    raise RuntimeError(f"Unknown format: {format}")
+
+
 def get_atten_bias_support():
     atten_bias_options = [
         # (has_attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1)
@@ -215,7 +231,7 @@ def no_kv_cache_test_cases(provider: str, comprehensive: bool):
                 for num_heads in heads:
                     for head_size in head_sizes:
                         for format in formats:
-                            for causal in [True, False]:
+                            for causal in get_causal_support(format):
                                 for mask_format in mask_formats:
                                     for has_bias in get_bias_support(format):
                                         for (
@@ -256,8 +272,8 @@ def no_kv_cache_test_cases(provider: str, comprehensive: bool):
             has_attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1 = atten_bias_options[
                 i % len(atten_bias_options)
             ]
-            for causal in [True, False]:
-                for format in formats:
+            for format in formats:
+                for causal in get_causal_support(format):
                     for has_bias in get_bias_support(format):
                         config = MultiHeadAttentionConfig(
                             batch_size=batch_size,
@@ -308,7 +324,7 @@ def kv_cache_test_cases(provider: str, comprehensive: bool):
                 for num_heads in heads:
                     for head_size in head_sizes:
                         for format in formats:
-                            for causal in [True, False]:
+                            for causal in get_causal_support(format):
                                 for has_past_input in [True, False]:
                                     for mask_format in mask_formats:
                                         for has_bias in get_bias_support(format):
@@ -353,8 +369,8 @@ def kv_cache_test_cases(provider: str, comprehensive: bool):
             has_attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1 = atten_bias_options[
                 i % len(atten_bias_options)
             ]
-            for causal in [True, False]:
-                for format in formats:
+            for format in formats:
+                for causal in get_causal_support(format):
                     for has_past_input in [True, False]:
                         for has_bias in get_bias_support(format):
                             sequence_length = 1 if has_past_input else past_sequence_length
@@ -397,7 +413,7 @@ def no_kv_cache_multi_thread_test_cases(provider: str, comprehensive: bool):
     device, dtype, formats = get_provider_support_info(provider, False)
 
     for format in formats:
-        for causal in [True, False]:
+        for causal in get_causal_support(format):
             for num_heads in heads:
                 for head_size in head_sizes:
                     configs = []  # list of configurations to run in parallel
@@ -437,7 +453,7 @@ def kv_cache_multi_thread_test_cases(provider: str, comprehensive: bool):
     device, dtype, formats = get_provider_support_info(provider, True)
 
     for format in formats:
-        for causal in [True, False]:
+        for causal in get_causal_support(format):
             for num_heads in heads:
                 for head_size in head_sizes:
                     configs = []
@@ -494,12 +510,8 @@ def parity_check_mha(
     rtol=1e-3,
     atol=1e-3,
 ):
-    # CUDA kernel does not support causal so skip such test cases.
-    if config.causal and config.provider == "CUDAExecutionProvider":
-        return
-
     ort_mha = OrtMultiHeadAttention(config, use_tf32=False)
-    ort_outputs = ort_mha.infer()
+    ort_outputs = ort_mha.infer(synchronize=True)
     out = ort_outputs["output"]
     out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size))
 
@@ -602,9 +614,6 @@ def parity_check_mha_multi_threading(
 ):
     # Use the first config to create a session, which is shared by all configs to run in parallel.
     config = test_inputs[0]["config"]
-    # For now, MHA CUDA kernel does not support causal so skip such test cases.
-    if config.causal and config.provider == "CUDAExecutionProvider":
-        return None
 
     # Some kernel does not support certain input format.
     if attention_kernel not in [
@@ -784,6 +793,10 @@ def run_mha_cpu(self):
 
     def run_mha_cuda_multi_threading(self, attention_kernel):
         for configs in multi_thread_test_cases("CUDAExecutionProvider", comprehensive_mode):
+            if configs and configs[0].causal and (SdpaKernel.TRT_CAUSAL_ATTENTION & attention_kernel != 0):
+                # TRT fused causal is disabled by default so skip the test of causal for multi-threading.
+                continue
+
             test_inputs = []
             for config in configs:
                 ort_inputs = config.random_inputs()
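For a quick sanity check outside the test harness, the causal output can also be compared against PyTorch's built-in scaled dot-product attention with `is_causal=True`. A short sketch in the spirit of `parity_check_mha`; the (batch, num_heads, seq_len, head_size) layout is my assumption:

```python
import torch


def check_causal_parity(ort_out, q, k, v, rtol=1e-3, atol=1e-3):
    """ort_out, q, k, v in (batch, num_heads, seq_len, head_size) layout (assumed)."""
    ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
    torch.testing.assert_close(ort_out, ref, rtol=rtol, atol=atol)
```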
