
[Performance] Does com.microsoft.Attention use FlashAttention-2? #18474

Open
Numeri opened this issue Nov 16, 2023 · 6 comments
Labels
ep:CUDA (issues related to the CUDA execution provider)

Comments

@Numeri

Numeri commented Nov 16, 2023

Describe the issue

I recently replaced the self-attention subgraphs in a custom ONNX graph with the com.microsoft.Attention operator, which resulted in a noticeable speed boost (20–50% with a batch size of 1, depending on sequence length, on a T4 GPU with the CUDA Execution Provider). This matched my expectations for how much performance could be improved by fusing those operations.
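
For reference, each fused block boils down to a single contrib-op node along these lines. This is only a rough sketch with onnx.helper; the tensor names and num_heads are illustrative, and it assumes the usual Attention inputs of input, fused QKV weight, fused QKV bias, and an optional mask_index:

from onnx import helper

attn_node = helper.make_node(
    "Attention",
    inputs=["layer_input", "qkv_weight", "qkv_bias", "mask_index"],
    outputs=["attn_output"],
    domain="com.microsoft",   # contrib-op domain
    num_heads=12,             # required attribute
)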

When I upgraded from 1.15.1 to 1.16.2, however, I was hoping to see further speed improvements, as the release notes for 1.16.0 say:

Added FlashAttention v2 support for Attention, MultiHeadAttention and PackedMultiHeadAttention ops

Unfortunately, the latency is essentially identical for both versions. Looking into the code, I noticed that these FlashAttention tests only test the MultiHeadAttention, PackedMultiHeadAttention and GroupQueryAttention operators. Are those the only operators with FlashAttention?


Follow-up question: When running the graph with sequence lengths > 1024, I get the following exception:

Status Message: Attention CUDA operator does not support total sequence length > 1024.

I would have expected FlashAttention to support much longer sequence lengths. Does this just mean I'm somehow still using the old Attention kernel even in version 1.16.2?

To reproduce

I unfortunately can't provide the models. I can put together an MWE model and code, but I was hoping this could be answered quickly without them.

Urgency

Not urgent.

Platform

Linux

OS Version

Ubuntu 20.04

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.16.2

ONNX Runtime API

Python

Architecture

X64

Execution Provider

Default CPU, CUDA

Execution Provider Library Version

CUDA 11.8

Model File

No response

Is this a quantized model?

No

github-actions bot added the ep:CUDA label on Nov 16, 2023
@tianleiwu
Contributor

It depends on the input. Currently, a set of conditions must all hold to trigger flash attention (for example: no attention mask, no past state, sequence length > 512, Linux only, and a CUDA GPU with architecture SM 80–89):

bool use_flash_attention = !disable_flash_attention_ &&
                           (nullptr == relative_position_bias) &&
                           nullptr == past &&
                           nullptr == present &&
                           parameters.hidden_size == parameters.v_hidden_size &&
                           nullptr == mask_index &&
                           onnxruntime::flash::is_supported(device_prop,
                                                            parameters.head_size,
                                                            parameters.num_heads,
                                                            parameters.num_heads);
// When input is packed QKV format, TensorRT kernel might be faster when sequence length <= 512.
if (use_flash_attention && parameters.sequence_length < min_seq_len_for_flash_attention_packed_qkv_) {
  use_flash_attention = false;
}
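
A quick way to check whether this branch is actually taken for a given model is to time the same run with flash attention force-disabled and compare. A minimal sketch, assuming the ORT_DISABLE_FLASH_ATTENTION environment variable (the env knob behind disable_flash_attention_ above) and feeds that satisfy the conditions listed; the safest approach is to set the variable before the process starts and run each configuration separately:

import time
import onnxruntime as ort

def bench(model_path, feeds, runs=20):
    # Build a CUDA session and measure average latency over a few runs.
    sess = ort.InferenceSession(model_path, providers=["CUDAExecutionProvider"])
    sess.run(None, feeds)  # warm-up
    start = time.perf_counter()
    for _ in range(runs):
        sess.run(None, feeds)
    return (time.perf_counter() - start) / runs

# Feeds should meet the conditions above: no mask or past inputs, sequence
# length > 512, fp16 model on an SM 80-89 GPU under Linux. Then compare, e.g.:
#   python bench.py                                 (flash attention eligible)
#   ORT_DISABLE_FLASH_ATTENTION=1 python bench.py   (flash attention off)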

@Numeri
Author

Numeri commented Nov 17, 2023

Interesting – why? Only one of those restrictions applies to FlashAttention-2 (the GPU architecture):

  • The flash_attn_with_kvcache function's cache_seqlens argument supports right padding, just like Attention's mask_index (see the sketch after this list)
  • It supports a past state cache (although only in place, with a pre-allocated buffer) via the same flash_attn_with_kvcache function
  • I'm guessing the restriction to sequence lengths above 512 is based on profiling, but it really surprises me that 512 is the sweet spot (I'm sure this depends heavily on GPU, batch size, etc., but here they report a crossover at about 150 to 200 tokens)
  • It does only support Ampere GPUs at the moment, but support for Turing GPUs is being worked on
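
For reference, a minimal sketch of the flash-attn 2.x kv-cache call the first two points refer to (shapes, dtypes, and the pre-allocated cache length are illustrative; it assumes flash-attn is installed and an Ampere GPU is available):

import torch
from flash_attn import flash_attn_with_kvcache

batch, heads, head_dim = 2, 8, 64
max_len = 2048  # cache pre-allocated to the maximum length
k_cache = torch.zeros(batch, max_len, heads, head_dim, dtype=torch.float16, device="cuda")
v_cache = torch.zeros_like(k_cache)
# Per-sequence valid lengths, i.e. right padding of the cache.
cache_seqlens = torch.tensor([100, 37], dtype=torch.int32, device="cuda")

# One decoding step: a single new token per sequence.
q = torch.randn(batch, 1, heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# The new k/v are appended into the caches in place at the positions given by
# cache_seqlens, and attention only reads the valid prefix of each cache entry.
out = flash_attn_with_kvcache(q, k_cache, v_cache, k=k, v=v,
                              cache_seqlens=cache_seqlens, causal=True)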

@tianleiwu
Contributor

Some of these limitations can be removed later (like supporting padding and cache, and maybe adding a provider configuration so that users can set their own threshold).

@Numeri
Author

Numeri commented Dec 4, 2023

Are there currently any plans to do this? I'd love to speed up my models, but I always do inference with a past state.

@joewue

joewue commented Apr 30, 2024

Hi guys! Do you have any update on removing these limitations? Would be super helpful to us :)

Thanks!

@tianleiwu
Contributor

If you want higher coverage of FlashAttention-2, consider the following options:

For a GPT-like model with past state, use the MultiHeadAttention or GroupQueryAttention operator. Please make sure the KV cache buffer is allocated to the maximum length, and share the buffer between past and present using I/O binding (see the sketch below).
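
A minimal sketch of that buffer-sharing pattern with the Python I/O binding API; the KV tensor names ("past_key_values.0.key" / "present.0.key") and shapes are hypothetical and depend on how the model was exported:

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("decoder.onnx", providers=["CUDAExecutionProvider"])
batch, heads, max_len, head_dim = 1, 16, 2048, 64

# Allocate the KV cache once, at maximum length, directly on the GPU.
kv_shape = (batch, heads, max_len, head_dim)
key_cache = ort.OrtValue.ortvalue_from_numpy(
    np.zeros(kv_shape, dtype=np.float16), "cuda", 0)

binding = sess.io_binding()
# Bind the SAME OrtValue as the past input and the present output, so each
# step appends into the pre-allocated buffer instead of copying the cache.
binding.bind_ortvalue_input("past_key_values.0.key", key_cache)
binding.bind_ortvalue_output("present.0.key", key_cache)
# ...bind the value cache, the remaining layers, and the regular inputs too,
# then run:
sess.run_with_iobinding(binding)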

For a BERT-like model, try the PackedMultiHeadAttention operator. You can use the onnxruntime transformers optimizer.py to get an optimized model with MultiHeadAttention operators, then run the following script to convert MultiHeadAttention to PackedMultiHeadAttention:
https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/transformers/convert_to_packing_mode.py
Basically, padding will be removed in most parts of the model (a sketch of the two steps follows).
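
A sketch of those two steps, assuming a standard BERT export; the use_multi_head_attention fusion option and the packing script's command-line flags are assumptions on my part, so check the fusion options and the script's --help in your version:

from onnxruntime.transformers import optimizer
from onnxruntime.transformers.fusion_options import FusionOptions

# Step 1: fuse attention into MultiHeadAttention and convert to fp16
# (the flash/packed CUDA kernels need fp16).
opts = FusionOptions("bert")
opts.use_multi_head_attention = True  # assumed option: emit MultiHeadAttention instead of Attention
opt = optimizer.optimize_model("bert.onnx", model_type="bert",
                               num_heads=12, hidden_size=768,
                               optimization_options=opts)
opt.convert_float_to_float16()
opt.save_model_to_file("bert_opt.onnx")

# Step 2: convert MultiHeadAttention -> PackedMultiHeadAttention (flags assumed):
#   python -m onnxruntime.transformers.convert_to_packing_mode \
#       --input bert_opt.onnx --output bert_packed.onnx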
