Skip to content

Commit

Permalink
add bfloat16 support for CUDA Neg kernel (microsoft#18306)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->

Registers BFloat16 datatype as valid input type for CUDA Neg Kernel.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

Enabling `meta-llama/Llama-2-70b` to be finetuned with ONNX Runtime
training.

---------

Co-authored-by: Prathik Rao <[email protected]@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
  • Loading branch information
prathikr and Prathik Rao authored Nov 9, 2023
1 parent 4dc6369 commit 7a3da45
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 6 deletions.
2 changes: 1 addition & 1 deletion docs/OperatorKernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,7 @@ Do not modify directly.*
|Mul|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|14+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
|||13|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
|||[7, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
|Neg|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8)|
|Neg|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8)|
|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8)|
|NonZero|*in* X:**T**<br> *out* Y:**tensor(int64)**|13+|**T** = tensor(bool), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint8)|
|||[9, 12]|**T** = tensor(bool), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint8)|
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -971,6 +971,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Neg);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Neg);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Neg);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, BFloat16, Neg);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Floor);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Floor);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Floor);
Expand Down Expand Up @@ -1855,6 +1856,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Neg)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Neg)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Neg)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, BFloat16, Neg)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Floor)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Floor)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Floor)>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ Status UnaryElementwise::Prepare(OpKernelContext* context, UnaryElementwisePrepa
// F: float
// D: double
// O: bool
// X: BFloat16

#define UNARY_OP_VERSIONED_HFD(name, startver, endver) \
UNARY_OP_VERSIONED_TYPED(name, startver, endver, MLFloat16) \
Expand All @@ -124,12 +125,18 @@ Status UnaryElementwise::Prepare(OpKernelContext* context, UnaryElementwisePrepa
UNARY_OP_TYPED(name, ver, float) \
UNARY_OP_TYPED(name, ver, double)

#define UNARY_OP_HFDX(name, ver) \
UNARY_OP_TYPED(name, ver, MLFloat16) \
UNARY_OP_TYPED(name, ver, BFloat16) \
UNARY_OP_TYPED(name, ver, float) \
UNARY_OP_TYPED(name, ver, double)

#define UNARY_OP_CSILHFD(name, ver) \
UNARY_OP_TYPED(name, ver, int8_t) \
UNARY_OP_TYPED(name, ver, int16_t) \
UNARY_OP_TYPED(name, ver, int32_t) \
UNARY_OP_TYPED(name, ver, int64_t) \
UNARY_OP_HFD(name, ver)
UNARY_OP_HFDX(name, ver)

#define UNARY_OP_BWUZCSILHFD(name, ver) \
UNARY_OP_TYPED(name, ver, uint8_t) \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,14 @@ UNARY_OPS()
// F: float
// D: double
// O: bool
// X: BFloat16

#define SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(name) \
SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, half) \
SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, float) \
SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, double)

#define SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDB(name) \
#define SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(name) \
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(name) \
SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, BFloat16)

Expand All @@ -68,7 +69,7 @@ UNARY_OPS()
SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, int16_t) \
SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, int32_t) \
SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, int64_t) \
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(name)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(name)

#define SPECIALIZED_UNARY_ELEMENTWISE_IMPL_BWUZCSILHFD(name) \
SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, uint8_t) \
Expand All @@ -83,8 +84,8 @@ SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Floor)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Ceil)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Reciprocal)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Sqrt)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDB(Log)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDB(Exp)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(Log)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(Exp)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Erf)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Round)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Sin)
Expand Down

0 comments on commit 7a3da45

Please sign in to comment.