Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CUDA] Fix SkipLayerNorm strict mode when skip has broadcast #17896

Merged
merged 3 commits into from
Oct 13, 2023

Conversation

tianleiwu
Copy link
Contributor

@tianleiwu tianleiwu commented Oct 12, 2023

Description

In SLN strict mode, current code (#16510) does not handle skip broadcast nicely . There are two issues:
(1) skip related parameters is not passed to cuda kernel in strict mode
(2) Strict mode kernel also has bug in handling skip broadcasting (like cuWelfordMuSigma2 does not handle skip broadcasting).

Here we remove the support of skip broadcasting in strict mode, and operator will return error message that strict mode only support same shape of input and skip.

Other changes:

  • skip_size is misleading when there is no broadcasting. Change to correct value.
  • Refactor the code to be more efficient: (1) no need to check whether there is broadcasting in kernel. (2) remove one local buffer (load input to sum_v directly to save a local buffer copy).
  • compute input + bias + skip instead of input + skip + bias. The order is followed common pattern in transformers model (Here assume graph fusion will distinguish input and skip correctly, need double check fusion code later).
  • update unit test so that strict mode is triggered in each test case (unless skip broadcasting) to have higher test coverage.

Motivation and Context

SLN strict mode does not support skip broadcast but current code will silently run (kernel might fail)

@tianleiwu tianleiwu changed the title Fix SkipLayerNorm strict mode when skip has broadcast [CUDA] Fix SkipLayerNorm strict mode when skip has broadcast Oct 12, 2023
@tianleiwu tianleiwu marked this pull request as draft October 12, 2023 16:19
@tianleiwu tianleiwu marked this pull request as ready for review October 12, 2023 16:21
@tianleiwu tianleiwu merged commit 67d7eb3 into main Oct 13, 2023
89 of 91 checks passed
@tianleiwu tianleiwu deleted the tlwu/sln_update branch October 13, 2023 14:51
jchen351 pushed a commit that referenced this pull request Oct 18, 2023
In SLN strict mode, current code (#16510) does not handle skip broadcast
nicely . There are two issues:
(1) skip related parameters is not passed to cuda kernel in strict mode
(2) Strict mode kernel also has bug in handling skip broadcasting (like
cuWelfordMuSigma2 does not handle skip broadcasting).

Here we remove the support of skip broadcasting in strict mode, and
operator will return error message that strict mode only support same
shape of input and skip.

Other changes:
* skip_size is misleading when there is no broadcasting. Change to
correct value.
* Refactor the code to be more efficient: (1) no need to check whether
there is broadcasting in kernel. (2) remove one local buffer (load input
to sum_v directly to save a local buffer copy).
* compute input + bias + skip instead of input + skip + bias. The order
is followed common pattern in transformers model (Here assume graph
fusion will distinguish input and skip correctly, need double check
fusion code later).
* update unit test so that strict mode is triggered in each test case
(unless skip broadcasting) to have higher test coverage.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

SLN strict mode does not support skip broadcast but current code will
silently run (kernel might fail)
@faxu faxu added triage:approved Approved for cherrypicks for release sdxl_llama labels Oct 25, 2023
tianleiwu added a commit that referenced this pull request Oct 31, 2023
In SLN strict mode, current code (#16510) does not handle skip broadcast
nicely . There are two issues:
(1) skip related parameters is not passed to cuda kernel in strict mode
(2) Strict mode kernel also has bug in handling skip broadcasting (like
cuWelfordMuSigma2 does not handle skip broadcasting).

Here we remove the support of skip broadcasting in strict mode, and
operator will return error message that strict mode only support same
shape of input and skip.

Other changes:
* skip_size is misleading when there is no broadcasting. Change to
correct value.
* Refactor the code to be more efficient: (1) no need to check whether
there is broadcasting in kernel. (2) remove one local buffer (load input
to sum_v directly to save a local buffer copy).
* compute input + bias + skip instead of input + skip + bias. The order
is followed common pattern in transformers model (Here assume graph
fusion will distinguish input and skip correctly, need double check
fusion code later).
* update unit test so that strict mode is triggered in each test case
(unless skip broadcasting) to have higher test coverage.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

SLN strict mode does not support skip broadcast but current code will
silently run (kernel might fail)
@tianleiwu tianleiwu removed triage:approved Approved for cherrypicks for release release:1.16.2 labels Nov 1, 2023
kleiti pushed a commit to kleiti/onnxruntime that referenced this pull request Mar 22, 2024
…ft#17896)

In SLN strict mode, current code (microsoft#16510) does not handle skip broadcast
nicely . There are two issues:
(1) skip related parameters is not passed to cuda kernel in strict mode
(2) Strict mode kernel also has bug in handling skip broadcasting (like
cuWelfordMuSigma2 does not handle skip broadcasting).

Here we remove the support of skip broadcasting in strict mode, and
operator will return error message that strict mode only support same
shape of input and skip.

Other changes:
* skip_size is misleading when there is no broadcasting. Change to
correct value.
* Refactor the code to be more efficient: (1) no need to check whether
there is broadcasting in kernel. (2) remove one local buffer (load input
to sum_v directly to save a local buffer copy).
* compute input + bias + skip instead of input + skip + bias. The order
is followed common pattern in transformers model (Here assume graph
fusion will distinguish input and skip correctly, need double check
fusion code later).
* update unit test so that strict mode is triggered in each test case
(unless skip broadcasting) to have higher test coverage.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

SLN strict mode does not support skip broadcast but current code will
silently run (kernel might fail)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants