[Performance] Accuracy far off for bge when enabling attention fusion - cuda ep #18945
Tested with BERT as well and met the same accuracy issue when enabling attention fusion.
The BERT-like attention fusion assumes right-side padding, with the attention mask set to 0 for padding tokens. Example:
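The original example is not reproduced above, so here is a minimal illustration of the assumed mask layout (the shape and values are made up for illustration): each row contains 1s for real tokens followed only by trailing 0s for padding.

```python
import numpy as np

# Right-side-padded batch: the pattern the fused Attention op assumes.
attention_mask = np.array(
    [
        [1, 1, 1, 1, 0, 0],  # 4 real tokens, padding only at the end
        [1, 1, 1, 0, 0, 0],  # 3 real tokens, padding only at the end
    ],
    dtype=np.int64,
)
```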
If the attention mask meets that requirement, attention fusion gives good accuracy as well as better performance. Here is an example of comparing optimization levels 1 and 2 (note that level 3 enables Gelu approximation and level 4 enables mixed precision, both of which lead to larger differences; a max difference of 0.3 looks reasonable for a mixed-precision model):
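The comparison output itself is not shown above. A minimal sketch of how such a comparison can be run (the file names bge_O1.onnx and bge_O2.onnx are hypothetical) is to feed both optimized models the same right-padded batch and inspect the maximum absolute difference of the first output:

```python
import numpy as np
import onnxruntime as ort

# Same right-padded batch for both models.
inputs = {
    "input_ids": np.random.randint(0, 30522, size=(2, 6), dtype=np.int64),
    "attention_mask": np.array([[1, 1, 1, 1, 0, 0], [1, 1, 1, 0, 0, 0]], dtype=np.int64),
    "token_type_ids": np.zeros((2, 6), dtype=np.int64),
}

sess_o1 = ort.InferenceSession("bge_O1.onnx", providers=["CUDAExecutionProvider"])
sess_o2 = ort.InferenceSession("bge_O2.onnx", providers=["CUDAExecutionProvider"])

out_o1 = sess_o1.run(None, inputs)[0]
out_o2 = sess_o2.run(None, inputs)[0]
print("max abs diff (level 1 vs level 2):", np.abs(out_o1 - out_o2).max())
```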
In optimum, the dummy inputs used for validation have a random mask like the following:
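The actual snippet is not shown above; as an illustration (the shape is arbitrary), a random mask of this kind can have 0s and 1s scattered anywhere in each row rather than only as trailing padding:

```python
import numpy as np

rng = np.random.default_rng(seed=0)
# 0s and 1s can appear anywhere in each row, so the right-side-padding
# assumption made by the fused Attention op does not hold.
attention_mask = rng.integers(low=0, high=2, size=(2, 6), dtype=np.int64)
print(attention_mask)
```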
So we can see that accuracy is far off for such a case. If the attention mask is indeed random like the above, a workaround is the following (the optimized model can then handle a random mask, but it will have worse performance than the model optimized for right-side padding):
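The workaround code itself is not reproduced above. What follows is a sketch of one such workaround using onnxruntime's transformers optimizer, under a few assumptions: the model path is hypothetical, and the head count and hidden size are taken from the bge-small config, so adjust as needed. The idea is to keep attention fusion enabled but have the fused op consume the raw 2D attention mask instead of a mask index derived under the right-side-padding assumption.

```python
from onnxruntime.transformers.fusion_options import FusionOptions
from onnxruntime.transformers.optimizer import optimize_model

fusion_options = FusionOptions("bert")
fusion_options.enable_attention = True
# Use the raw 2D attention mask in the fused Attention op instead of a mask
# index that presumes right-side padding.
fusion_options.use_raw_attention_mask(True)

optimized = optimize_model(
    "bge-small-en-v1.5.onnx",   # hypothetical path to the exported ONNX model
    model_type="bert",
    num_heads=12,               # bge-small-en-v1.5 config; adjust if different
    hidden_size=384,
    optimization_options=fusion_options,
    use_gpu=True,
)
optimized.save_model_to_file("bge-small-en-v1.5_opt.onnx")
```

As noted above, the raw-mask path handles random masks correctly but is typically slower than the mask-index path used with right-side padding.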
It's super clear, thanks @tianleiwu! Do you know what other architectures are designed for right-padding and could lead to a similar issue? I would like to add this information to the Optimum documentation. Thanks!
For BERT-like models (BERT, RoBERTa, ViT), right padding is used in most cases. For text generation models like GPT, left padding is used in most cases. Since a past key/value cache is used in text generation, some extra work is still needed to enable attention fusion for those models in Optimum. Currently this issue does not impact those models in Optimum.
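For reference, a small illustration of the padding sides mentioned above (defaults can vary by checkpoint; the left-padding setting for GPT-2 below is the common choice for batched generation, not something fixed by the tokenizer itself):

```python
from transformers import AutoTokenizer

encoder_tok = AutoTokenizer.from_pretrained("BAAI/bge-small-en-v1.5")
print(encoder_tok.padding_side)  # typically "right" for BERT-like encoders

gpt_tok = AutoTokenizer.from_pretrained("gpt2")
gpt_tok.padding_side = "left"          # commonly set for batched text generation
gpt_tok.pad_token = gpt_tok.eos_token  # GPT-2 has no pad token by default
```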
Thanks again for helping @tianleiwu! I opened a PR in Optimum to better create dummy attn mask tensors and add some tips for attention fusion in the Optimum documentation.
Describe the issue
When setting options.enable_attention=True of FusionOptions for bge models with the CUDA EP, the outputs are far off from PyTorch on CPU. Since bge shares the same modeling code as BERT, I wonder what could lead to the accuracy issue.
To reproduce
Either test the optimization script in ORT with the model BAAI/bge-small-en-v1.5, or use the optimum CLI. A sketch of checking the optimized model against PyTorch follows:
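A minimal repro sketch under stated assumptions (the optimized model file name is a placeholder, and the exact optimum-cli invocation is not shown): run the attention-fused model on the CUDA EP and compare its last hidden state against the PyTorch model on CPU.

```python
import numpy as np
import onnxruntime as ort
import torch
from transformers import AutoModel, AutoTokenizer

name = "BAAI/bge-small-en-v1.5"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModel.from_pretrained(name).eval()

batch = tokenizer(
    ["hello world", "a longer sentence so the batch gets padded"],
    padding=True,
    return_tensors="pt",
)

# PyTorch reference on CPU.
with torch.no_grad():
    ref = model(**batch).last_hidden_state.numpy()

# "bge_attention_fused.onnx" is a placeholder for the model produced with
# attention fusion enabled (e.g. by the ORT optimizer or the optimum CLI).
sess = ort.InferenceSession("bge_attention_fused.onnx", providers=["CUDAExecutionProvider"])
ort_inputs = {k: v.numpy() for k, v in batch.items()}
out = sess.run(None, ort_inputs)[0]

print("max abs diff vs PyTorch:", np.abs(out - ref).max())
```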
Urgency
bge is a very popular model; it would be nice to have it take full advantage of ORT optimization.
Platform
Linux
OS Version
ubuntu 20.04
ONNX Runtime Installation
Released Package
ONNX Runtime Version or Commit ID
1.16.3
ONNX Runtime API
Python
Architecture
X64
Execution Provider
CUDA
Execution Provider Library Version
CUDA 11.8
Model File
https://huggingface.co/BAAI/bge-small-en-v1.5
Is this a quantized model?
No