forked from microsoft/onnxruntime
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[CUDA] Fix SkipLayerNorm strict mode when skip has broadcast (microso…
…ft#17896) In SLN strict mode, current code (microsoft#16510) does not handle skip broadcast nicely . There are two issues: (1) skip related parameters is not passed to cuda kernel in strict mode (2) Strict mode kernel also has bug in handling skip broadcasting (like cuWelfordMuSigma2 does not handle skip broadcasting). Here we remove the support of skip broadcasting in strict mode, and operator will return error message that strict mode only support same shape of input and skip. Other changes: * skip_size is misleading when there is no broadcasting. Change to correct value. * Refactor the code to be more efficient: (1) no need to check whether there is broadcasting in kernel. (2) remove one local buffer (load input to sum_v directly to save a local buffer copy). * compute input + bias + skip instead of input + skip + bias. The order is followed common pattern in transformers model (Here assume graph fusion will distinguish input and skip correctly, need double check fusion code later). * update unit test so that strict mode is triggered in each test case (unless skip broadcasting) to have higher test coverage. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> SLN strict mode does not support skip broadcast but current code will silently run (kernel might fail)
- Loading branch information
Showing
6 changed files
with
171 additions
and
148 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.