
[Performance] Accuracy far off for bge when enabling attention fusion - cuda ep #18945

Closed
JingyaHuang opened this issue Dec 27, 2023 · 5 comments
Labels: ep:CUDA, model:transformer

Comments

@JingyaHuang

Describe the issue

When setting options.enable_attention=True in FusionOptions for bge models with the CUDA EP, the outputs are far off from PyTorch on CPU. Since bge shares the same modeling as BERT, I wonder what could lead to the accuracy issue.

To reproduce

Either test the optimization script in ORT with the model BAAI/bge-small-en-v1.5, or use the optimum CLI:

optimum-cli export onnx --model BAAI/bge-small-en-v1.5 --task feature-extraction --optimize O4 --device cuda bge_Opt_04_cu118
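
For reference, the ORT optimization-script path can be sketched roughly as follows; the paths are placeholders, num_heads=12 / hidden_size=384 are taken from the bge-small-en-v1.5 config, and the fp16 conversion step is assumed to mirror what O4 does on GPU:

```python
# Rough sketch of the pure-ORT optimization path (paths are placeholders).
from onnxruntime.transformers.optimizer import optimize_model
from onnxruntime.transformers.fusion_options import FusionOptions

fusion_options = FusionOptions("bert")
fusion_options.enable_attention = True  # the option that triggers the accuracy issue

optimized = optimize_model(
    "bge_small_en_v1.5/model.onnx",  # ONNX export of BAAI/bge-small-en-v1.5 (placeholder path)
    model_type="bert",
    num_heads=12,                    # from the bge-small-en-v1.5 config
    hidden_size=384,
    optimization_options=fusion_options,
    use_gpu=True,
)
optimized.convert_float_to_float16(keep_io_types=True)  # assumed equivalent of O4's mixed precision
optimized.save_model_to_file("bge_Opt_04_cu118/model.onnx")
```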

Urgency

bge is a very popular model; it would be nice to have it take full advantage of ORT optimization.

Platform

Linux

OS Version

Ubuntu 20.04

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.16.3

ONNX Runtime API

Python

Architecture

X64

Execution Provider

CUDA

Execution Provider Library Version

CUDA 11.8

Model File

https://huggingface.co/BAAI/bge-small-en-v1.5

Is this a quantized model?

No

@JingyaHuang
Author

Tested with BERT as well; I met the same accuracy issue when enabling attention fusion:

root@xxxxxxxxx:/workspace# optimum-cli export onnx --model bert-base-uncased --task feature-extraction --optimize O4 --device cuda bert_Opt_04_cu118
Framework not specified. Using pt to export to ONNX.
config.json: 100%|█████████████████████████████████████████████████| 570/570 [00:00<00:00, 83.7kB/s]
model.safetensors: 100%|██████████████████████████████████████████| 440M/440M [00:02<00:00, 214MB/s]
tokenizer_config.json: 100%|█████████████████████████████████████| 28.0/28.0 [00:00<00:00, 4.50kB/s]
vocab.txt: 100%|█████████████████████████████████████████████████| 232k/232k [00:00<00:00, 3.84MB/s]
tokenizer.json: 100%|████████████████████████████████████████████| 466k/466k [00:00<00:00, 6.90MB/s]
Using the export variant default. Available variants are:
    - default: The default ONNX variant.
Using framework PyTorch: 2.0.0+cu117
Overriding 1 configuration item(s)
        - use_cache -> False
============= Diagnostic Run torch.onnx.export version 2.0.0+cu117 =============
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================

2023-12-27 15:36:43.118577917 [W:onnxruntime:, session_state.cc:1162 VerifyEachNodeIsAssignedToAnEp] Some nodes were not assigned to the preferred execution providers which may or may not have an negative impact on performance. e.g. ORT explicitly assigns shape related ops to CPU to improve perf.
2023-12-27 15:36:43.118609489 [W:onnxruntime:, session_state.cc:1164 VerifyEachNodeIsAssignedToAnEp] Rerunning with verbose output on a non-minimal build will show node assignments.
Overridding for_gpu=False to for_gpu=True as half precision is available only on GPU.
/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/optimum/onnxruntime/configuration.py:770: FutureWarning: disable_embed_layer_norm will be deprecated soon, use disable_embed_layer_norm_fusion instead, disable_embed_layer_norm_fusion is set to True.
  warnings.warn(
Optimizing model...
2023-12-27 15:36:46.105477455 [W:onnxruntime:, session_state.cc:1162 VerifyEachNodeIsAssignedToAnEp] Some nodes were not assigned to the preferred execution providers which may or may not have an negative impact on performance. e.g. ORT explicitly assigns shape related ops to CPU to improve perf.
2023-12-27 15:36:46.105508845 [W:onnxruntime:, session_state.cc:1164 VerifyEachNodeIsAssignedToAnEp] Rerunning with verbose output on a non-minimal build will show node assignments.
symbolic shape inference disabled or failed.
symbolic shape inference disabled or failed.
Configuration saved in bert_Opt_04_cu118/ort_config.json
Optimized model saved at: bert_Opt_04_cu118 (external data format: False; saved all tensor to one file: True)
Post-processing the exported models...
Deduplicating shared (tied) weights...
Validating models in subprocesses...
Validating ONNX model bert_Opt_04_cu118/model.onnx...
2023-12-27 15:37:03.261404593 [W:onnxruntime:, session_state.cc:1162 VerifyEachNodeIsAssignedToAnEp] Some nodes were not assigned to the preferred execution providers which may or may not have an negative impact on performance. e.g. ORT explicitly assigns shape related ops to CPU to improve perf.
2023-12-27 15:37:03.261436085 [W:onnxruntime:, session_state.cc:1164 VerifyEachNodeIsAssignedToAnEp] Rerunning with verbose output on a non-minimal build will show node assignments.
        -[✓] ONNX model output names match reference model (last_hidden_state)
        - Validating ONNX Model output "last_hidden_state":
                -[✓] (2, 16, 768) matches (2, 16, 768)
                -[x] values not close enough, max diff: 2.3931784629821777 (atol: 0.0001)
The ONNX export succeeded with the warning: The maximum absolute difference between the output of the reference model and the ONNX exported model is not within the set tolerance 0.0001:
- last_hidden_state: max diff = 2.3931784629821777.
 The exported model was saved at: bert_Opt_04_cu118

@tianleiwu
Contributor

tianleiwu commented Dec 28, 2023

There is an assumption in BERT-like attention fusion that the input uses right-side padding, with an attention mask of 0 for the padding positions. Example:

{'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0]])}
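
For context, such right-padded masks are what a Hugging Face tokenizer produces by default for BERT-like models; a minimal sketch (model name, sentences, and max_length are arbitrary):

```python
# Sketch: right-side padding with a Hugging Face tokenizer.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
inputs = tokenizer(
    ["a short sentence", "another slightly longer example sentence"],
    padding="max_length",  # pad on the right up to max_length
    max_length=16,
    truncation=True,
    return_tensors="pt",
)
print(inputs["attention_mask"])  # 1s for real tokens, trailing 0s for padding
```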

If the attention mask meets this requirement, attention fusion has good accuracy and also better performance. Here is an example comparing optimization levels 1 and 2 (note that level 3 enables Gelu approximation and level 4 enables mixed precision, both of which lead to a larger difference; a 0.3 max difference looks reasonable for a mixed-precision model):

python -m onnxruntime.transformers.compare_bert_results --baseline_model bert_Opt_01_cu118/model.onnx  --optimized_model bert_Opt_02_cu118/model.onnx --batch_size 2 --sequence_length 16
100% passed for 100 random inputs given thresholds (rtol=0.001, atol=0.0001).
maximum absolute difference=6.222724914550781e-05

python -m onnxruntime.transformers.compare_bert_results --baseline_model bert_Opt_01_cu118/model.onnx  --optimized_model bert_Opt_04_cu118/model.onnx --batch_size 2 --sequence_length 16
WARNING: 100 out of 100 results NOT passed for thresholds (rtol=0.001, atol=0.0001).
maximum absolute difference=0.321255087852478
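
The same kind of check can also be done by hand with right-padded inputs; a rough sketch, assuming the optimum feature-extraction export above (i.e. inputs input_ids / attention_mask / token_type_ids and output last_hidden_state):

```python
# Sketch: manually comparing the O1 baseline and O4 optimized exports on right-padded inputs.
import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
enc = tokenizer(
    ["a short sentence", "another slightly longer example sentence"],
    padding="max_length", max_length=16, truncation=True, return_tensors="np",
)
feeds = {name: np.asarray(value, dtype=np.int64) for name, value in enc.items()}

baseline = ort.InferenceSession("bert_Opt_01_cu118/model.onnx", providers=["CUDAExecutionProvider"])
optimized = ort.InferenceSession("bert_Opt_04_cu118/model.onnx", providers=["CUDAExecutionProvider"])

ref = baseline.run(None, feeds)[0]                     # last_hidden_state
out = optimized.run(None, feeds)[0].astype(ref.dtype)  # O4 output may be fp16
print("max abs diff:", float(np.abs(ref - out).max()))
```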

In optimum, the dummy inputs for validation have a random mask like

{'attention_mask': tensor([[0, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0], [1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0]])}

So we can see that accuracy is far off in such a case.

If the attention mask is indeed like the above, a workaround is the following (the resulting optimized model can handle a random mask, but it will have worse performance than a model optimized for right-side padding):

optimum-cli export onnx --model bert-base-uncased --task feature-extraction --optimize O1 --device cuda bert_Opt_01

python -m onnxruntime.transformers.optimizer --input bert_Opt_01/model.onnx --output ./optimized.onnx --use_raw_attention_mask
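
Roughly the same workaround through the Python API (paths are placeholders; num_heads/hidden_size are for bert-base-uncased, and use_raw_attention_mask on FusionOptions is assumed to correspond to the CLI flag):

```python
# Sketch: Python-API equivalent of the --use_raw_attention_mask workaround.
from onnxruntime.transformers.optimizer import optimize_model
from onnxruntime.transformers.fusion_options import FusionOptions

fusion_options = FusionOptions("bert")
fusion_options.use_raw_attention_mask(True)  # keep the raw 2D mask so arbitrary padding stays correct

optimized = optimize_model(
    "bert_Opt_01/model.onnx",  # the O1 export from the previous command (placeholder path)
    model_type="bert",
    num_heads=12,
    hidden_size=768,
    optimization_options=fusion_options,
    use_gpu=True,
)
optimized.save_model_to_file("optimized.onnx")
```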

@JingyaHuang
Author

It's super clear, thanks @tianleiwu!

With use_raw_attention_mask, the accuracy looks good, and with right-padding + fp16 the difference looks reasonable as well (and the dummy attention mask in optimum should probably be either right-padded or left-padded instead of totally random; thanks for pointing that out).
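
For example, a hypothetical right-padded dummy-mask generator for validation could look roughly like this (the function name and shapes are made up for illustration):

```python
# Hypothetical sketch: right-padded dummy attention masks for export validation.
import torch

def random_right_padded_mask(batch_size: int, seq_len: int) -> torch.Tensor:
    # Draw a random real-token length per row, then place 1s on the left and 0s on the right.
    lengths = torch.randint(low=1, high=seq_len + 1, size=(batch_size,))
    positions = torch.arange(seq_len).unsqueeze(0)     # (1, seq_len)
    return (positions < lengths.unsqueeze(1)).long()   # (batch_size, seq_len)

print(random_right_padded_mask(2, 16))
```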

Also, do you know which other architectures are designed for right-padding and could lead to a similar issue? I would like to add the information to the Optimum documentation. Thanks!

@tianleiwu
Contributor

tianleiwu commented Jan 4, 2024

> With use_raw_attention_mask, the accuracy looks good, and with right-padding + fp16 the difference looks reasonable as well (and the dummy attention mask in optimum should probably be either right-padded or left-padded instead of totally random; thanks for pointing that out).
>
> Also, do you know which other architectures are designed for right-padding and could lead to a similar issue? I would like to add the information to the Optimum documentation. Thanks!

For BERT-like models (BERT, RoBERTa, ViT), right padding is used in most cases.

For text generation models, like GPT, left padding is used in most cases. Since the past key/value cache is used in text generation, it still needs some extra work to enable attention fusion for those models in Optimum. Currently, this issue does not impact those models in Optimum.

@JingyaHuang
Author

Thanks again for helping, @tianleiwu! I opened a PR in Optimum to better create dummy attention mask tensors and to add some tips on attention fusion to the Optimum documentation.
