[CUDA] cuDNN Flash Attention #21629

Merged · 3 commits merged into main on Aug 20, 2024

Conversation

@tianleiwu (Contributor) commented on Aug 5, 2024

Description

  • Add cuDNN flash attention using the cuDNN frontend, and enable it in the MultiHeadAttention operator.
  • Support attention mask.
  • Support attention bias.
  • Update tests and benchmark script.

The cuDNN SDPA kernel is disabled by default. To enable it, all of the following are required:
(1) cuDNN 9.3 or a newer version must be installed.
(2) Set the environment variable ORT_ENABLE_CUDNN_FLASH_ATTENTION=1, or set the sdpa_kernel=8 CUDA provider option (see the sketch after this list).
(3) The device must have compute capability >= 8.0.

Note that some parameter combinations might be rejected due to limited support for certain head dimensions or sequence lengths.
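
A minimal sketch of the two ways to enable it from Python, based only on the options named above; the model path is a placeholder for any model containing a MultiHeadAttention node:

```python
import os
import onnxruntime as ort

# Option A: environment variable from this PR; set it before creating the session.
os.environ["ORT_ENABLE_CUDNN_FLASH_ATTENTION"] = "1"

# Option B: CUDA execution provider option sdpa_kernel=8 (also from this PR).
# "model_with_mha.onnx" is a hypothetical model path used for illustration.
session = ort.InferenceSession(
    "model_with_mha.onnx",
    providers=[("CUDAExecutionProvider", {"sdpa_kernel": 8}), "CPUExecutionProvider"],
)
```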

Future Work:
(1) FP8 APIs. Currently, only APIs for FP16 and BF16 are exposed in cudnn_flash_attention.h.
(2) Add an API to support ragged batching (inputs with padding removed).
(3) Support other input formats (like QKV_BS3NH).
(4) Currently, q is converted to BSNH, while k/v are converted to either BSNH or BNSH format. Some experiments could show whether converting q to BNSH is better in some cases (the two layouts are illustrated below).
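
For reference, the BSNH and BNSH layouts mentioned above differ only by a transpose of the sequence and head axes. A minimal NumPy sketch (not part of this PR, dimensions chosen to match Test Setting 1 below):

```python
import numpy as np

batch, seq_len, num_heads, head_size = 16, 256, 32, 128

# BSNH layout: (batch, sequence, num_heads, head_size)
q_bsnh = np.zeros((batch, seq_len, num_heads, head_size), dtype=np.float16)

# BNSH layout: (batch, num_heads, sequence, head_size) -- swap axes 1 and 2.
q_bnsh = q_bsnh.transpose(0, 2, 1, 3)
print(q_bnsh.shape)  # (16, 32, 256, 128)
```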

Example Benchmark Results on H100

The following tests are on the FP16 MultiHeadAttention operator without attention mask or attention bias.

Test Setting 1

```
python benchmark_mha.py -b 16 -s 256 -n 32 -d 128 --use_gpu -r 10000
python benchmark_mha.py -b 16 -s 256 -n 32 -d 128 --use_gpu -r 10000 --torch
```

| batch_size | sequence_length | past_sequence_length | num_heads | head_size |
|------------|-----------------|----------------------|-----------|-----------|
| 16         | 256             | 0                    | 32        | 128       |

| format            | average_latency | tflops | kernel          |
|-------------------|-----------------|--------|-----------------|
| Q,K,V (BNSH)      | 0.000075        | 229.5  | torch:flash     |
| Q,K,V (BNSH)      | 0.000119        | 144.8  | torch:efficient |
| Q,K,V (BNSH)      | 0.000224        | 76.5   | torch:math      |
| Q,K,V (BSNH)      | 0.000075        | 227.8  | ort:cudnn       |
| Q,K,V (BSNH)      | 0.000094        | 182.8  | ort:flash       |
| Q,K,V (BSNH)      | 0.000138        | 124.7  | ort:efficient   |
| Q,K,V (BSNH)      | 0.000438        | 39.3   | ort:math        |
| Q,KV (BSNH_BSN2H) | 0.000129        | 133.0  | ort:cudnn       |
| Q,KV (BSNH_BSN2H) | 0.000151        | 114.1  | ort:flash       |
| Q,KV (BSNH_BSN2H) | 0.000194        | 88.5   | ort:efficient   |
| QKV (BSN3H)       | 0.000154        | 111.8  | ort:cudnn       |
| QKV (BSN3H)       | 0.000175        | 98.0   | ort:flash       |
| QKV (BSN3H)       | 0.000217        | 79.0   | ort:efficient   |

Test Setting 2

| batch_size | sequence_length | past_sequence_length | num_heads | head_size |
|------------|-----------------|----------------------|-----------|-----------|
| 16         | 512             | 0                    | 16        | 64        |

| format            | average_latency | tflops | kernel          |
|-------------------|-----------------|--------|-----------------|
| Q,K,V (BNSH)      | 0.000069        | 249.2  | torch:flash     |
| Q,K,V (BNSH)      | 0.000141        | 121.7  | torch:efficient |
| Q,K,V (BNSH)      | 0.000294        | 58.5   | torch:math      |
| Q,K,V (BSNH)      | 0.000077        | 221.7  | ort:cudnn       |
| Q,K,V (BSNH)      | 0.000087        | 196.6  | ort:flash       |
| Q,K,V (BSNH)      | 0.000163        | 105.6  | ort:efficient   |
| Q,K,V (BSNH)      | 0.000651        | 26.4   | ort:math        |
| Q,KV (BSNH_BSN2H) | 0.000103        | 167.1  | ort:cudnn       |
| Q,KV (BSNH_BSN2H) | 0.000117        | 146.3  | ort:flash       |
| Q,KV (BSNH_BSN2H) | 0.000192        | 89.6   | ort:efficient   |
| QKV (BSN3H)       | 0.000113        | 151.5  | ort:cudnn       |
| QKV (BSN3H)       | 0.000128        | 134.7  | ort:flash       |
| QKV (BSN3H)       | 0.000201        | 85.3   | ort:efficient   |

@tianleiwu marked this pull request as draft on August 5, 2024 23:54
@tianleiwu force-pushed the tlwu/cudnn_flash_att branch from a3608a5 to 1ac4cf8 on August 6, 2024 07:00
@tianleiwu force-pushed the tlwu/cudnn_flash_att branch from 93d8708 to 9c78a6d on August 19, 2024 06:10
@tianleiwu marked this pull request as ready for review on August 19, 2024 06:17
@tianleiwu requested review from a team as code owners on August 19, 2024 16:14
@tianleiwu force-pushed the tlwu/cudnn_flash_att branch from f6f82ef to 388dabe on August 19, 2024 16:21
@tianleiwu merged commit fbc3927 into main on Aug 20, 2024 (93 of 97 checks passed)
@tianleiwu deleted the tlwu/cudnn_flash_att branch on August 20, 2024 15:50
tianleiwu added a commit that referenced this pull request on Aug 22, 2024:
… kernel (#21804)

Use debug info to identify the SDPA kernel actually used, and show it in the output of benchmark_mha.py. This updated benchmark script was used to get the benchmark results in #21629.
(1) Change the output format of the debug info to a line like SdpaKernel=*.
(2) Add a step to capture stdout from the onnxruntime session, and use a regular expression to parse SdpaKernel=* from the captured text (a sketch follows below).

Other minor changes:
(1) Set different default repeats during benchmark: 100 for CPU and 10000 for CUDA.
(2) Fix PrintTensorByDims used in the console dumper: if it is not enabled, do not dump the tensor.
(3) Update some comments.

### Motivation and Context

Sometimes a fallback is used for the requested sdpa_kernel. This could confuse users unless the benchmark reports the exact kernel that was used.
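
A minimal sketch of the parsing step described above, assuming the captured stdout contains a debug line such as SdpaKernel=FLASH_ATTENTION (the captured_text value and the kernel name are illustrative placeholders, not output from the actual script):

```python
import re

# Hypothetical stdout captured from an onnxruntime session run with debug info enabled.
captured_text = "... SdpaKernel=FLASH_ATTENTION ...\n"

# Parse the kernel name emitted in the SdpaKernel=* debug line.
match = re.search(r"SdpaKernel=(\w+)", captured_text)
kernel = match.group(1) if match else "unknown"
print(kernel)  # -> FLASH_ATTENTION
```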