-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add GQA support for ROCm #21032
Add GQA support for ROCm #21032
Conversation
…'t need to explicit unpack the packed qkv tensor
b6be9bd
to
14d1a1a
Compare
14d1a1a
to
2b0c46e
Compare
CI test revealed something like the following
and some sparse 'inf' in other tests. This however, happened to the |
…iled with nan and inf from reference values
…otaryEmbeddingKernel
LGTM except there is a build error: CMakeFiles/onnxruntime_providers_rocm.dir/onnxruntime_src/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.cu.o |
1384ff4
to
f4355d4
Compare
@snnn need an es approve. The some packages in CI are updated due to some nan and inf are produced from the reference impl, see my previous comment. |
The test_flash_attn_rocm.py from #21032 failed frequently. For example, I saw two failed jobs today: E Max absolute difference: 0.002167 E Max absolute difference: 0.002686 Adjust the abs threshold from 0.002 to 0.005, and use default relative tolerance rtol=0.001.
depends on