Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[CUDA] Fix MultiHeadAttention thread safe and bias support (#21498)
### Description #### Issues Fixed (1) **TRT cross attention not thread safe**. [Core changes like this](6fd7aba) are used to make it thread-safe: * Add an once_flag to CumulatedSequenceLengthCache to make sure it is only initialized once; and change the cache to be read only after initialization. Previously, the content is not read-only so it might be changed by other thread and potentially cause buffer overrun. * The kernel initialization is not guarded (Although the factory of kernel loading has static mutex to guard multiple threading), so the mutable variable might be set by two different threads at the same time. Add an once_flag to avoid that. This requires need some workspace computation change as well. So I did not create a separated pull request. (2) **Bias for cross attention** That scenario has assumption that only query has bias, but not for key and value. However, such assumption is not verified in runtime and there was no comment of assumption, and there was no test case so the support of scenario was disabled by mistake. Actually, the scenario is used in whisper model (TODO: we shall add tests for whisper to CI pipeline, and also update fusion script to verify such assumptions if needed.) CUDA/CPU kernels supports bias for cross attention as long as bias is zero for key and value. I updated the check to support the scenario and added comments wherever there is such assumption. (3) **Fallback support** Previously, unfused kernel did not support packed qkv and packed kv formats. That means some case might fail since there is no fallback. I added new AddBiasTranpose cuda kernels for them to support fallback, so that all supported cases will not fail. #### Improvements (4) **QKV workspace size**. The logic for no_qkv_workspace could be easily out of sync since related code are scattered in different source files. I refactor the code to move all related code to one file (attention_prepare_qkv.cu) and add asserts, so that the logic can be in sync. (5) **Remove confusing concept of pass past in kv** parameters.pass_past_in_kv is confusing since the k/v in cross attention is not past state. Remove it and use parameters.qkv_format == Q_K_V_BSNH_BNSH_BNSH instead. New code does not use past_key/past_value for cross attention, so the logic is more clear. (6) **More coverage and less workspace and less transpose of flash and efficient attention** Previously, there is one condition does not run flash or efficient attention: ``` bool past_no_bias = (pass_key_value_as_past || past_key != nullptr || present_key != nullptr) && bias == nullptr; ``` After this change, we can use flash and efficient attention for the case, and also less workspace. For example, cross attention with bias, the original code uses two additional workspaces: ``` transpose: past_key (BxNxSxH) => temp_k_workspace (BxSxNxH), past_value (BxNxSxH_v) => temp_v_workspace (BxSxNxH_v) add bias: query => q, temp_k_workspace => k, temp_v_workspace => v ``` New logic is like ``` if (has bias) Add bias to query, key, value, and store in q, k, v workspace else Use query, key and value directly as q, k and v in kernel ``` We can see that, we do not need allocate temp_k_workspace and temp_v_workspace so use less memory. New code saved two transposes in this case. Flash and efficient attention supports BSNH or BNSH formats for k and v. In old code, k/v are also converted to BSNH format. Some is not necessary. I do some change to convert k/v to BSNH or BNSH case by case. So that there are more cases can be covered by flash or efficient attention to improve performance. (6) **Debugging support** Previously, there is less debug info. In this change, I add a flag for debug info in the AttentionData. So that we can output debug info during the processing. Also add functions to consolidate the dumping of inputs, QKV processing and outputs; Add an environment variable `ORT_ENABLE_GPU_DUMP` to allow disable dumping from cuda kernel. #### Summary of changes (1) Refactoring the CheckInputs, and pass in operator type. (2) Refactoring the PrepareQKV to support fallback for packed qkv or packed kv inputs. (3) Change a few case of PrepareQKV to allow more case covered by flash and efficient attention. (4) use parameters.qkv_format == Q_K_V_BSNH_BNSH_BNSH to replace parameters.pass_past_in_kv (5) Allow bias input for Q_K_V_BSNH_BNSH_BNSH, and add comments of assumption that key/value has no bias in this case. (6) Fix thread-safe issue in CumulatedSequenceLengthCache handling. (7) Add test cases to cover all supported scenarios. Current support scenarios for MultiHeadAttention for CUDA/CPU: | Q | K | V | pastK| pastV | presentK| presentV | Bias | Op desc | ---- | ---- | ---- | ------ | ----- | --------- | -------- | -----|--------- | BSNH | BLNH| BLNH| - | - | - | - | QKV | not packed | BLN3H| - | - | - | - | - | - | QKV | qkv packed <br> not support in CPU | BSNH | BLN2H| - | - | - | - | - | --- | kv packed <br> not support in CPU | BSNH | BNLH| BNLH| - | - | - | - | Q-- | cross attention <br> bias for Q only | BSNH | BLNH | BLNH | - | - | BNTH | BNTH | QKV | no past <br> only present | BSNH | BLNH | BLNH | BNPH | BNPH | BNTH | BNTH | QKV | past and present <br> (not share buffer) ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> #18854
- Loading branch information