-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[CUDA] FusedMHARunnerFP16v2 thread-safe (#21420)
### Description - [x] Rewrite FusedMHARunnerFP16v2 to make it thread-safe. - [x] Add multi-threading tests Previously, the kernel parameters params is stored as a member of mha runner, which means that different threads might change the params at the same time and impacts the other threads. For example, if batch_size and seq_len was changed by another thread to larger values in setup(...), buffer overrun might happen in run(...) because a kernel could read/write memory out of range of allocated buffers. In new implementation, I change the api and remove mutable member variables to make it thread safe. Below is summary of change: Before: ``` class FusedMHARunnerFP16v2::mhaImpl { void setup(int seq_len, int batch_size) { // change scalar params } void run(input, output) { // change params for input and output pointers // launch kernel using params } Fused_multihead_attention_params_v2 params; // mutable, not thread-safe } ``` After: ``` class FusedMHARunnerFP16v2::FmhaImpl { void setup(int seq_len, int batch_size, Fused_multihead_attention_params_v2& params) { // change params } void run(params, input, output) { // change params with input and output pointers // launch kernel using params } } ``` ### Motivation and Context #18854 #21413
- Loading branch information
Showing
10 changed files
with
534 additions
and
234 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.