From 7a3da4526f98c9cfc6387a5faa1edeec7d88ef17 Mon Sep 17 00:00:00 2001
From: Prathik Rao
Date: Wed, 8 Nov 2023 18:32:12 -0800
Subject: [PATCH] add bfloat16 support for CUDA Neg kernel (#18306)

### Description
Registers the BFloat16 datatype as a valid input type for the CUDA Neg kernel.

### Motivation and Context
Enables `meta-llama/Llama-2-70b` to be fine-tuned with ONNX Runtime training.

---------

Co-authored-by: Prathik Rao
---
 docs/OperatorKernels.md                                   | 2 +-
 .../core/providers/cuda/cuda_execution_provider.cc        | 2 ++
 .../core/providers/cuda/math/unary_elementwise_ops.cc     | 9 ++++++++-
 .../providers/cuda/math/unary_elementwise_ops_impl.cu     | 9 +++++----
 4 files changed, 16 insertions(+), 6 deletions(-)

diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 38783ac044c22..8e546b30aa4cb 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -665,7 +665,7 @@ Do not modify directly.*
 |Mul|*in* A:**T**<br>*in* B:**T**<br>*out* C:**T**|14+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
 |||13|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
 |||[7, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
-|Neg|*in* X:**T**<br>*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8)|
+|Neg|*in* X:**T**<br>*out* Y:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8)|
 |||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8)|
 |NonZero|*in* X:**T**<br>*out* Y:**tensor(int64)**|13+|**T** = tensor(bool), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint8)|
 |||[9, 12]|**T** = tensor(bool), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint8)|
diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
index 2d242d7d6fb12..d8a0792209b0f 100644
--- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
+++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -971,6 +971,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Neg);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Neg);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Neg);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, BFloat16, Neg);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Floor);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Floor);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Floor);
@@ -1855,6 +1856,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Neg)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Neg)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Neg)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, BFloat16, Neg)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Floor)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Floor)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Floor)>,
diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc
index 9ede1f8d90ecc..655877f425054 100644
--- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc
+++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc
@@ -99,6 +99,7 @@ Status UnaryElementwise::Prepare(OpKernelContext* context, UnaryElementwisePrepa
 // F: float
 // D: double
 // O: bool
+// X: BFloat16
 
 #define UNARY_OP_VERSIONED_HFD(name, startver, endver) \
   UNARY_OP_VERSIONED_TYPED(name, startver, endver, MLFloat16) \
@@ -124,12 +125,18 @@ Status UnaryElementwise::Prepare(OpKernelContext* context, UnaryElementwisePrepa
   UNARY_OP_TYPED(name, ver, float)  \
   UNARY_OP_TYPED(name, ver, double)
 
+#define UNARY_OP_HFDX(name, ver)       \
+  UNARY_OP_TYPED(name, ver, MLFloat16) \
+  UNARY_OP_TYPED(name, ver, BFloat16)  \
+  UNARY_OP_TYPED(name, ver, float)     \
+  UNARY_OP_TYPED(name, ver, double)
+
 #define UNARY_OP_CSILHFD(name, ver)  \
   UNARY_OP_TYPED(name, ver, int8_t)  \
   UNARY_OP_TYPED(name, ver, int16_t) \
   UNARY_OP_TYPED(name, ver, int32_t) \
   UNARY_OP_TYPED(name, ver, int64_t) \
-  UNARY_OP_HFD(name, ver)
+  UNARY_OP_HFDX(name, ver)
 
 #define UNARY_OP_BWUZCSILHFD(name, ver) \
   UNARY_OP_TYPED(name, ver, uint8_t)    \
diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu
index 1298d53338337..5c3db4a499972 100644
--- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu
+++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu
@@ -53,13 +53,14 @@ UNARY_OPS()
 // F: float
 // D: double
 // O: bool
+// X: BFloat16
 
 #define SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(name) \
   SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, half)     \
   SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, float)    \
   SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, double)
 
-#define SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDB(name) \
+#define SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(name) \
   SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(name)        \
   SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, BFloat16)
 
@@ -68,7 +69,7 @@ UNARY_OPS()
   SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, int16_t) \
   SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, int32_t) \
   SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, int64_t) \
-  SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(name)
+  SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(name)
 
 #define SPECIALIZED_UNARY_ELEMENTWISE_IMPL_BWUZCSILHFD(name) \
   SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, uint8_t)          \
@@ -83,8 +84,8 @@ SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Floor)
 SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Ceil)
 SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Reciprocal)
 SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Sqrt)
-SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDB(Log)
-SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDB(Exp)
+SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(Log)
+SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(Exp)
 SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Erf)
 SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Round)
 SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Sin)
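The patch does not include a regression test for the new registration. Below is a minimal sketch of what one could look like, modeled on the `OpTester` pattern used under `onnxruntime/test/providers`; the test name, the `MakeBFloat16` helper, and the CPU-exclusion detail are assumptions based on existing ORT test utilities, not code from this change.

```cpp
// Hypothetical unit test for the BFloat16 Neg registration (not part of the
// patch). Assumes ORT's OpTester and MakeBFloat16 test helpers are available.
#include "gtest/gtest.h"

#include "test/common/tensor_op_test_utils.h"
#include "test/providers/provider_test_utils.h"

namespace onnxruntime {
namespace test {

TEST(MathOpTest, Neg_BFloat16) {
  OpTester test("Neg", 13);  // opset 13, matching the kernel class added above
  test.AddInput<BFloat16>("X", {3}, MakeBFloat16({1.0f, -2.0f, 0.5f}));
  test.AddOutput<BFloat16>("Y", {3}, MakeBFloat16({-1.0f, 2.0f, -0.5f}));
  // This change registers BFloat16 Neg only for the CUDA EP, so keep the
  // CPU EP out of the comparison run.
  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCpuExecutionProvider});
}

}  // namespace test
}  // namespace onnxruntime
```

Negating a bfloat16 value only flips the sign bit, and the inputs above are exactly representable in bfloat16, so the expected outputs should match without any tolerance adjustment.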