
add bfloat16 support for CUDA Neg kernel #18306

Merged: 14 commits merged into main on Nov 9, 2023
Conversation

prathikr
Contributor

@prathikr prathikr commented Nov 6, 2023

Description

Registers the BFloat16 data type as a valid input type for the CUDA Neg kernel.

Motivation and Context

Enables `meta-llama/Llama-2-70b` to be fine-tuned with ONNX Runtime training.

@prathikr prathikr changed the title add bfloat16 support add bfloat16 support for CUDA Neg kernel Nov 6, 2023
@prathikr prathikr requested a review from hanbitmyths November 8, 2023 05:04
@prathikr prathikr requested a review from hanbitmyths November 8, 2023 06:09
@prathikr prathikr merged commit 7a3da45 into main Nov 9, 2023
95 checks passed
@prathikr prathikr deleted the prathikrao/neg-bfloat16 branch November 9, 2023 02:32
kleiti pushed a commit to kleiti/onnxruntime that referenced this pull request Mar 22, 2024