Fix Attention Runtime Error for CLIP model (microsoft#17729)
### Description
The condition check that selects between the GPT (causal) and BERT (non-causal) fused attention paths is not correct:
```
if (is_unidirectional_ && enable_fused_causal_attention_) {  // GPT
}
else { // BERT
}
```

Change it to:
```
if (is_unidirectional_) {  // GPT
}
else { // BERT
}
```

Another workaround is to enable fused causal attention by setting the environment variable
`ORT_ENABLE_FUSED_CAUSAL_ATTENTION=1` before running stable diffusion.
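
As a concrete illustration of the workaround (a minimal sketch only; the model path, input name, dtype, and shape below are hypothetical, not from this PR), the variable can be exported in the shell or set in-process before the ONNX Runtime session is created:

```
import os

# Must be set before the InferenceSession is created so that the Attention
# kernel takes the fused causal (GPT) path. Running
# `export ORT_ENABLE_FUSED_CAUSAL_ATTENTION=1` in the shell before launching
# the pipeline works as well.
os.environ["ORT_ENABLE_FUSED_CAUSAL_ATTENTION"] = "1"

import numpy as np
import onnxruntime as ort

# Hypothetical path to the optimized CLIP text encoder of a stable diffusion pipeline.
session = ort.InferenceSession("clip_text_encoder_optimized.onnx",
                               providers=["CUDAExecutionProvider"])

# Hypothetical input name, dtype and shape for tokenized prompts.
input_ids = np.zeros((1, 77), dtype=np.int32)
outputs = session.run(None, {"input_ids": input_ids})
```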

### Motivation and Context

Without the fix, the optimized CLIP model of stable diffusion encounters an
error when running an Attention node. CLIP's text encoder uses unidirectional
(causal) attention, so with `ORT_ENABLE_FUSED_CAUSAL_ATTENTION` unset the node
falls into the BERT (non-causal) fused kernel path and fails the causal-mask
check at runtime:

2023-09-24 16:15:31.206037898 [E:onnxruntime:,
sequential_executor.cc:514 ExecuteKernel] Non-zero status code returned
while running Attention node. Name:'Attention_0' Status Message:
/onnxruntime_src/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.cu:207
bool
onnxruntime::contrib::cuda::FusedMHARunnerFP16v2::mhaImpl::is_flash_attention(int)
const interface->mHasCausalMask == false was false.

Note that the bug has been there for a long time. It only surfaced recently
because we added an attention fusion for CLIP, which triggers the error.

We will add a comprehensive test for causal attention later to avoid
such corner cases.
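
As a rough sketch of what such a regression test could look like (not the actual test; the model path, input name, and shapes are hypothetical assumptions), it would run a model containing a unidirectional Attention node on the CUDA execution provider, which exercises the fused-kernel selection in attention.cc:

```
import numpy as np
import onnxruntime as ort

# Hypothetical model containing a com.microsoft Attention node with unidirectional=1,
# e.g. an optimized CLIP text encoder.
MODEL_PATH = "clip_text_encoder_optimized.onnx"

def test_causal_attention_on_cuda():
    sess = ort.InferenceSession(MODEL_PATH, providers=["CUDAExecutionProvider"])
    feeds = {"input_ids": np.zeros((1, 77), dtype=np.int32)}  # hypothetical input name/shape
    # Before this fix, run() failed in the Attention node with
    # "interface->mHasCausalMask == false was false" unless
    # ORT_ENABLE_FUSED_CAUSAL_ATTENTION=1 was set in the environment.
    outputs = sess.run(None, feeds)
    for out in outputs:
        assert np.all(np.isfinite(out))
```

A comprehensive version would additionally compare the fused output against a non-fused reference.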
tianleiwu authored and kleiti committed Mar 22, 2024
1 parent 5e89c73 commit a397e1f
Showing 1 changed file with 22 additions and 20 deletions.
42 changes: 22 additions & 20 deletions onnxruntime/contrib_ops/cuda/bert/attention.cc
@@ -140,27 +140,29 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
 #endif
 
   if (!use_flash_attention) {
-    if (is_unidirectional_ && enable_fused_causal_attention_) {  // GPT
-      // GPT fused kernels requires left side padding. mask can be:
-      //     none (no padding), 1D sequence lengths or 2d mask.
-      // Fused kernels don't support different sequence lengths of q and kv, so only apply to the first token
-      // where past state is empty.
-      bool is_mask_2d_key_padding = parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING;
-      bool use_causal_fused_runner = (nullptr == mask_index || is_mask_1d_seq_len || is_mask_2d_key_padding) &&
-                                     nullptr == relative_position_bias &&
-                                     parameters.past_sequence_length == 0 &&
-                                     parameters.hidden_size == parameters.v_hidden_size &&
-                                     FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length,
-                                                                        enable_trt_flash_attention_, true);
-      if (use_causal_fused_runner) {
-        // Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node.
-        if (nullptr == fused_fp16_runner_.get()) {
-          fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional_,
-                                                            enable_trt_flash_attention_, parameters.scale);
+    if (is_unidirectional_) {  // GPT
+      if (enable_fused_causal_attention_) {
+        // GPT fused kernels requires left side padding. mask can be:
+        //     none (no padding), 1D sequence lengths or 2d mask.
+        // Fused kernels don't support different sequence lengths of q and kv, so only apply to the first token
+        // where past state is empty.
+        bool is_mask_2d_key_padding = parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING;
+        bool use_causal_fused_runner = (nullptr == mask_index || is_mask_1d_seq_len || is_mask_2d_key_padding) &&
+                                       nullptr == relative_position_bias &&
+                                       parameters.past_sequence_length == 0 &&
+                                       parameters.hidden_size == parameters.v_hidden_size &&
+                                       FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length,
+                                                                          enable_trt_flash_attention_, true);
+        if (use_causal_fused_runner) {
+          // Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node.
+          if (nullptr == fused_fp16_runner_.get()) {
+            fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional_,
+                                                              enable_trt_flash_attention_, parameters.scale);
+          }
+
+          // Here we assume all causal kernels can be loaded into shared memory. TODO: add a function to check.
+          fused_runner = fused_fp16_runner_.get();
         }
-
-        // Here we assume all causal kernels can be loaded into shared memory. TODO: add a function to check.
-        fused_runner = fused_fp16_runner_.get();
       }
     } else {  // BERT
       bool use_fused_runner = !disable_fused_self_attention_ &&
