[Performance] Does com.microsoft.Attention use FlashAttention-2? #18474
Comments
It depends on the input. Currently, several conditions must all hold for the Attention operator to take the flash attention path: no attention mask, no past state, sequence length > 512, Linux only, and a CUDA GPU with architecture SM 80–89. See: onnxruntime/onnxruntime/contrib_ops/cuda/bert/attention.cc Lines 124 to 137 in adb56df
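As an illustration, here is a minimal benchmarking sketch for checking whether the flash path makes a difference for a model containing com.microsoft.Attention. The model path, input name, dtype, and hidden size are assumptions, and the ORT_DISABLE_FLASH_ATTENTION environment variable is used for comparison on the assumption that your build honors it:

```python
import os
import time

import numpy as np
import onnxruntime as ort

# Hypothetical model containing a com.microsoft.Attention node; adjust name/shape/dtype.
MODEL_PATH = "model_with_attention.onnx"


def time_run(disable_flash: bool, seq_len: int = 1024, hidden: int = 768) -> float:
    # Assumption: ORT reads this environment variable when the attention kernel is
    # created; "1" disables the flash attention kernel, "0" leaves it enabled.
    os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" if disable_flash else "0"
    sess = ort.InferenceSession(MODEL_PATH, providers=["CUDAExecutionProvider"])

    # No attention mask and no past state, so the flash path is eligible
    # (given a long enough sequence and a supported GPU).
    feeds = {"hidden_states": np.random.randn(1, seq_len, hidden).astype(np.float16)}

    sess.run(None, feeds)  # warm-up
    start = time.perf_counter()
    for _ in range(20):
        sess.run(None, feeds)
    return (time.perf_counter() - start) / 20


print("flash enabled :", time_run(disable_flash=False))
print("flash disabled:", time_run(disable_flash=True))
```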
Interesting – why? Only one of those restrictions applies to FlashAttention-v2 (the GPU architecture):
Some limitations can be removed later (like supporting padding and cache, and maybe adding a provider configuration option so that users can set their own threshold).
Are there currently any plans to do this? I'd love to speed up my models, but I always do inference with a past state.
Hi guys! Do you have any update on removing these limitations? It would be super helpful to us :) Thanks!
If you want higher coverage of FlashAttention-2, consider the following options:
For a GPT-like model with past state, use the MultiHeadAttention or GroupQueryAttention operator. Please make sure the KV cache buffer is allocated to the maximum length, and share the buffer of past and present using I/O binding (see the sketch after this comment).
For a BERT-like model, try the PackedMultiHeadAttention operator. You can use the onnxruntime transformers optimizer.py to get an optimized model with MultiHeadAttention operators, then run a script to convert MultiHeadAttention to PackedMultiHeadAttention:
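To illustrate the I/O-binding point above (not the MultiHeadAttention-to-PackedMultiHeadAttention conversion script), here is a minimal sketch of binding a preallocated max-length KV cache so that past and present share one GPU buffer. The model path, input/output names, head counts, and dtypes are assumptions for a hypothetical GroupQueryAttention decoder; match them to your own export.

```python
import numpy as np
import onnxruntime as ort

# Hypothetical decoder exported with GroupQueryAttention; names/shapes are assumptions.
MODEL_PATH = "decoder_with_gqa.onnx"
BATCH, NUM_KV_HEADS, MAX_LEN, HEAD_DIM = 1, 8, 2048, 128

sess = ort.InferenceSession(MODEL_PATH, providers=["CUDAExecutionProvider"])
binding = sess.io_binding()

# Preallocate the KV cache on the GPU at max length once, then bind the same
# buffers as both past (input) and present (output) so no copy happens per step.
kv_shape = (BATCH, NUM_KV_HEADS, MAX_LEN, HEAD_DIM)
past_key = ort.OrtValue.ortvalue_from_numpy(np.zeros(kv_shape, dtype=np.float16), "cuda", 0)
past_value = ort.OrtValue.ortvalue_from_numpy(np.zeros(kv_shape, dtype=np.float16), "cuda", 0)

binding.bind_ortvalue_input("past_key_values.0.key", past_key)
binding.bind_ortvalue_input("past_key_values.0.value", past_value)
binding.bind_ortvalue_output("present.0.key", past_key)      # shared buffer
binding.bind_ortvalue_output("present.0.value", past_value)  # shared buffer

# GroupQueryAttention models also take sequence-length inputs (e.g. seqlens_k,
# total_sequence_length); bind those as your model requires.

# Bind the token input for one decoding step and let ORT allocate the logits.
input_ids = ort.OrtValue.ortvalue_from_numpy(np.array([[42]], dtype=np.int64), "cuda", 0)
binding.bind_ortvalue_input("input_ids", input_ids)
binding.bind_output("logits", "cuda", 0)

sess.run_with_iobinding(binding)
```

For the BERT-like path, the optimizer is typically invoked along the lines of python -m onnxruntime.transformers.optimizer --input model.onnx --output model_opt.onnx --model_type bert --num_heads 12 --hidden_size 768 --use_gpu (flag names may vary by version; check --help).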
Describe the issue
I recently replaced the self-attention subgraphs in a custom ONNX graph with the com.microsoft.Attention operator, which resulted in a noticeable speed boost (20–50% with a batch size of 1, depending on sequence length, on a T4 GPU with the CUDA execution provider). This matched my expectations for how much performance could be improved by fusing those operations; a sketch of what the fused node looks like is shown at the end of this description.
When I upgraded from 1.15.1 to 1.16.2, however, I was hoping to see further speed improvements, as the release notes for 1.16.0 say:
Unfortunately, the latency is essentially identical for both versions. Looking into the code, I noticed that these FlashAttention tests only cover the MultiHeadAttention, PackedMultiHeadAttention and GroupQueryAttention operators. Are those the only operators with FlashAttention?
Follow-up question: When running the graph with sequence lengths > 1024, I get the following exception:
I would have expected FlashAttention to support much longer sequence lengths. Does this just mean I'm somehow still using the old Attention operator even in version 1.16.1?
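For reference, the replacement looks roughly like the following. This is a minimal, self-contained sketch: the tensor names, dimensions, and opset versions are placeholders rather than the actual graph, and the optional inputs such as mask_index and past are omitted.

```python
import numpy as np
import onnx
from onnx import TensorProto, helper, numpy_helper

# Placeholder dimensions; the real graph's values differ.
HIDDEN, NUM_HEADS, SEQ = 768, 12, 1024

# Fused QKV projection weight/bias as initializers (shapes per the
# com.microsoft.Attention spec: weight is [hidden, 3*hidden], bias is [3*hidden]).
weight = numpy_helper.from_array(
    np.random.randn(HIDDEN, 3 * HIDDEN).astype(np.float32), name="attn_qkv_weight")
bias = numpy_helper.from_array(
    np.random.randn(3 * HIDDEN).astype(np.float32), name="attn_qkv_bias")

# Optional inputs (mask_index, past, ...) are left out, matching the
# "no mask / no past" case discussed above.
attn_node = helper.make_node(
    "Attention",
    inputs=["hidden_states", "attn_qkv_weight", "attn_qkv_bias"],
    outputs=["attn_output"],
    name="fused_self_attention",
    domain="com.microsoft",
    num_heads=NUM_HEADS,
)

graph = helper.make_graph(
    [attn_node],
    "attention_only",
    inputs=[helper.make_tensor_value_info(
        "hidden_states", TensorProto.FLOAT, ["batch", SEQ, HIDDEN])],
    outputs=[helper.make_tensor_value_info(
        "attn_output", TensorProto.FLOAT, ["batch", SEQ, HIDDEN])],
    initializer=[weight, bias],
)

model = helper.make_model(
    graph,
    opset_imports=[helper.make_opsetid("", 17), helper.make_opsetid("com.microsoft", 1)])
onnx.save(model, "attention_only.onnx")
```

In the real graph, the weight and bias come from the original Q/K/V projections concatenated along the output dimension.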
To reproduce
I unfortunately can't provide the models. I can make an MWE model and code, but was hoping that this could be answered quickly without needing those.
Urgency
Not urgent.
Platform
Linux
OS Version
Ubuntu 20.04
ONNX Runtime Installation
Released Package
ONNX Runtime Version or Commit ID
1.16.2
ONNX Runtime API
Python
Architecture
X64
Execution Provider
Default CPU, CUDA
Execution Provider Library Version
CUDA 11.8
Model File
No response
Is this a quantized model?
No