-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[CUDA] cuDNN Flash Attention (#21629)
### Description - [x] Add cuDNN flash attention using cudnn frontend, and enable it in MultiHeadAttention operator. - [x] Support attention mask. - [x] Support attention bias. - [x] Update tests and benchmark script. The cuDNN SDPA is disabled by default. To enable it, need the following: (1) Requires cuDNN 9.3 or newer version installed. (2) Set an environment variable `ORT_ENABLE_CUDNN_FLASH_ATTENTION=1` or set `sdpa_kernel=8` cuda provider option to enable it. (3) Only works on devices with compute capability >= 8.0. Note that some combinations of parameters might be rejected due to limited support of head dimension or sequence lengths. Future Works: (1) FP8 and BF16 APIs. Currently, only API for FP16 are exposed. (2) Add API to support ragged batching (padding removed in inputs). (3) Support other input formats (like QKV_BS3NH). (4) Currently, q are converted to BSNH, k/v are converted to either BSNH or BNSH format. May do some experiment to see whether converting q to BNSH could be better in some case. ### Example Benchmark Results on H100 The following tests are on FP16 MultiHeadAttention operator without attention mask and attention bias. #### Test Setting 1 batch_size | sequence_length | past_sequence_length | num_heads | head_size -- | -- | -- | -- | -- 16 | 256 | 0 | 32 | 128 format | average_latency | tflops | kernel -- | -- | -- | -- Q,K,V (BNSH) | 0.000075 | 229.5 | torch:flash Q,K,V (BNSH) | 0.000119 | 144.8 | torch:efficient Q,K,V (BNSH) | 0.000224 | 76.5 | torch:math Q,K,V (BSNH) | 0.000075 | 227.8 | ort:cudnn Q,K,V (BSNH) | 0.000094 | 182.8 | ort:flash Q,K,V (BSNH) | 0.000138 | 124.7 | ort:efficient Q,K,V (BSNH) | 0.000438 | 39.3 | ort:math Q,KV | 0.000129 | 133.0 | ort:cudnn Q,KV | 0.000151 | 114.1 | ort:flash Q,KV | 0.000194 | 88.5 | ort:efficient QKV | 0.000154 | 111.8 | ort:cudnn QKV | 0.000175 | 98.0 | ort:flash QKV | 0.000217 | 79.0 | ort:efficient #### Test Setting 2 batch_size | sequence_length | past_sequence_length | num_heads | head_size -- | -- | -- | -- | -- 16 | 512 | 0 | 16 | 64 format | average_latency | tflops | kernel -- | -- | -- | -- Q,K,V (BNSH) | 0.000069 | 249.2 | torch:flash Q,K,V (BNSH) | 0.000141 | 121.7 | torch:efficient Q,K,V (BNSH) | 0.000294 | 58.5 | torch:math Q,K,V (BSNH) | 0.000077 | 221.7 | ort:cudnn Q,K,V (BSNH) | 0.000087 | 196.6 | ort:flash Q,K,V (BSNH) | 0.000163 | 105.6 | ort:efficient Q,K,V (BSNH) | 0.000651 | 26.4 | ort:math Q,KV | 0.000103 | 167.1 | ort:cudnn Q,KV | 0.000117 | 146.3 | ort:flash Q,KV | 0.000192 | 89.6 | ort:efficient QKV | 0.000113 | 151.5 | ort:cudnn QKV | 0.000128 | 134.7 | ort:flash QKV | 0.000201 | 85.3 | ort:efficient
- Loading branch information
Showing
19 changed files
with
681 additions
and
50 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -107,5 +107,3 @@ elseif(CUDNN_MAJOR_VERSION EQUAL 9) | |
CUDNN::cudnn_heuristic | ||
) | ||
endif() | ||
|
||
mark_as_advanced(CUDNN_INCLUDE_DIR) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.