-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support Smooth Softmax in GroupQueryAttention (#21867)
### Description Softmax (formula 1) is like the following: ```math y_{i} = \frac{exp(x_{i})}{\sum_{i} exp(x_{i})} ``` After applying softmax, each element will be in the range of $(0, 1)$, and the elements will add up to 1, so that they can be interpreted as probabilities. However, in language model, softmax has two issues: * When all elements are -inf (for example, a whole row is masked when a query token is padding), the result is not defined since exp(-inf)=0 and divided-by-zero is encountered in the above formula. * Why do we need normalize in a way that each query word are treated as equal important (each row has sum equals to1)? **Smooth Softmax** (formula 2) is a modified version that introduces a smooth factor like the following: ```math s_{i} = \frac{exp(x_{i})}{1+ \sum_{i} exp(x_{i})} ``` This formula could tackle the above two issues: * It could handle the special case that all elements are -inf: the result $s_{i}$ is 0 for every element in such case. * Sum of all elements $\sum_{i}{s_{i}} = \frac{\sum_{i}{exp(x_{i})}}{1+ \sum_{i} exp(x_{i})}$ is in the range of (0, 1), so that we can train the model to assign different importance to different query words. Since exponential is prone to overflow or underflow, to get stable result, formula 3 can be used: ```math s_{i} = \frac{exp(x_{i} + c)}{exp(c)+ \sum_{i} exp(x_{i} +c)} ``` c can be any value in theory. In practical, choice of constant c shall avoid $exp(c)$ and $exp(x_{i} +c)$ overflow (or underflow) at the same time. A reasonable choice is like formula 4: ```math c=-\max_{i} \{ x_i \} ``` or apply a constraint that c <=0 like the following formula 5: ```math c=-\max(0, \max_{i} \{ x_i \}) ``` The latter one (formula 5) ensures that $s_{i}$ will fallback to formula 2 when all elements are negative. For CPU provider, smooth softmax is implemented in MLAS. CPU implementation uses formula 5. @wangyems implemented the smooth softmax in flash attention for CUDA, which requires Ampere or newer GPU. The implementation of smooth softmax in flash attention uses formula 4. --------- Co-authored-by: Ye Wang
- Loading branch information
Showing
25 changed files
with
435 additions
and
161 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.