add bfloat16 support for CUDA Neg kernel #18306

Merged (14 commits) on Nov 9, 2023.
docs/OperatorKernels.md (1 addition, 1 deletion)

@@ -665,7 +665,7 @@ Do not modify directly.*
 |Mul|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|14+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
 |||13|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
 |||[7, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
-|Neg|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8)|
+|Neg|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8)|
 |||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8)|
 |NonZero|*in* X:**T**<br> *out* Y:**tensor(int64)**|13+|**T** = tensor(bool), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint8)|
 |||[9, 12]|**T** = tensor(bool), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint8)|
onnxruntime/core/providers/cuda/cuda_execution_provider.cc (2 additions)

@@ -971,6 +971,7 @@
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Neg);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Neg);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Neg);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, BFloat16, Neg);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Floor);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Floor);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Floor);
@@ -1855,6 +1856,7 @@
 BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Neg)>,
 BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Neg)>,
 BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Neg)>,
+BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, BFloat16, Neg)>,

[GitHub Actions / cpplint] onnxruntime/core/providers/cuda/cuda_execution_provider.cc:1859: Lines should be <= 120 characters long [whitespace/line_length] [2]
 BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Floor)>,
 BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Floor)>,
 BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Floor)>,
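Both halves of this registration matter: the ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME line forward-declares the typed kernel class, and the matching BuildKernelCreateInfo entry adds it to the CUDA provider's kernel table so that a Neg node over tensor(bfloat16) can be assigned to the CUDA execution provider. A minimal sketch of the lookup this enables, using simplified stand-in types (KernelEntry and FindKernel are hypothetical; ORT's real KernelRegistry in onnxruntime/core/framework carries much more state):

```cpp
#include <optional>
#include <string>
#include <vector>

// Hypothetical, simplified stand-in for ORT's KernelCreateInfo.
struct KernelEntry {
  std::string op_type;  // e.g. "Neg"
  int since_version;    // e.g. 13
  std::string dtype;    // type-constraint value, e.g. "tensor(bfloat16)"
};

// The provider contributes one entry per registered (op, version, type)
// combination; this PR appends the bfloat16 row for Neg.
std::vector<KernelEntry> cuda_kernels = {
    {"Neg", 13, "tensor(float)"},
    {"Neg", 13, "tensor(double)"},
    {"Neg", 13, "tensor(float16)"},
    {"Neg", 13, "tensor(bfloat16)"},  // new in this PR
};

// Kernel resolution is essentially a lookup over that table: before this
// change, a bfloat16 Neg node found no CUDA entry and could not run on
// the CUDA execution provider.
std::optional<KernelEntry> FindKernel(const std::string& op, int opset,
                                      const std::string& dtype) {
  for (const auto& k : cuda_kernels) {
    if (k.op_type == op && opset >= k.since_version && k.dtype == dtype) {
      return k;
    }
  }
  return std::nullopt;
}
```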
@@ -99,6 +99,7 @@ Status UnaryElementwise::Prepare(OpKernelContext* context, UnaryElementwisePrepa
 // F: float
 // D: double
 // O: bool
+// X: BFloat16

 #define UNARY_OP_VERSIONED_HFD(name, startver, endver) \
   UNARY_OP_VERSIONED_TYPED(name, startver, endver, MLFloat16) \
@@ -124,12 +125,18 @@ Status UnaryElementwise::Prepare(OpKernelContext* context, UnaryElementwisePrepa
   UNARY_OP_TYPED(name, ver, float) \
   UNARY_OP_TYPED(name, ver, double)

+#define UNARY_OP_HFDX(name, ver) \
+  UNARY_OP_TYPED(name, ver, MLFloat16) \
+  UNARY_OP_TYPED(name, ver, BFloat16) \
+  UNARY_OP_TYPED(name, ver, float) \
+  UNARY_OP_TYPED(name, ver, double)
+
 #define UNARY_OP_CSILHFD(name, ver) \
   UNARY_OP_TYPED(name, ver, int8_t) \
   UNARY_OP_TYPED(name, ver, int16_t) \
   UNARY_OP_TYPED(name, ver, int32_t) \
   UNARY_OP_TYPED(name, ver, int64_t) \
-  UNARY_OP_HFD(name, ver)
+  UNARY_OP_HFDX(name, ver)

 #define UNARY_OP_BWUZCSILHFD(name, ver) \
   UNARY_OP_TYPED(name, ver, uint8_t) \
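The net effect on Neg: UNARY_OP_CSILHFD(Neg, 13) now routes its floating-point tail through the new UNARY_OP_HFDX, picking up a BFloat16 instantiation alongside the existing types. Hand-expanding one level of the macros makes this visible (a sketch; UNARY_OP_TYPED itself expands further into the actual kernel registration):

```cpp
// One level of macro expansion for UNARY_OP_CSILHFD(Neg, 13) after this change.
UNARY_OP_TYPED(Neg, 13, int8_t)
UNARY_OP_TYPED(Neg, 13, int16_t)
UNARY_OP_TYPED(Neg, 13, int32_t)
UNARY_OP_TYPED(Neg, 13, int64_t)
// ...via UNARY_OP_HFDX(Neg, 13):
UNARY_OP_TYPED(Neg, 13, MLFloat16)
UNARY_OP_TYPED(Neg, 13, BFloat16)  // the new instantiation
UNARY_OP_TYPED(Neg, 13, float)
UNARY_OP_TYPED(Neg, 13, double)
```

The next diff applies the matching change in the CUDA implementation file, where the per-type kernels are explicitly instantiated.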
@@ -53,13 +53,14 @@ UNARY_OPS()
 // F: float
 // D: double
 // O: bool
+// X: BFloat16

 #define SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(name) \
   SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, half) \
   SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, float) \
   SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, double)

-#define SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDB(name) \
+#define SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(name) \
   SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(name) \
   SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, BFloat16)

@@ -68,7 +69,7 @@ UNARY_OPS()
   SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, int16_t) \
   SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, int32_t) \
   SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, int64_t) \
-  SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(name)
+  SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(name)

 #define SPECIALIZED_UNARY_ELEMENTWISE_IMPL_BWUZCSILHFD(name) \
   SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, uint8_t) \

@@ -83,8 +84,8 @@ SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Floor)
 SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Ceil)
 SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Reciprocal)
 SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Sqrt)
-SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDB(Log)
-SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDB(Exp)
+SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(Log)
+SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(Exp)
 SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Erf)
 SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Round)
 SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Sin)
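On the .cu side, SPECIALIZED_UNARY_ELEMENTWISE_IMPL(Neg, BFloat16) is what forces the compiler to emit a BFloat16 instantiation of the shared elementwise kernel; without it, the BFloat16 registration in the provider would have no compiled kernel to call. A standalone sketch of what such a kernel boils down to (illustrative only, not ORT's actual implementation; NegKernel and LaunchNeg are hypothetical names):

```cpp
#include <cuda_bf16.h>  // __nv_bfloat16 and conversion intrinsics (CUDA 11+)

// Elementwise negation over bfloat16. The float round-trip keeps the sketch
// portable across architectures; on sm_80+ one could use __hneg(x[i]) instead.
__global__ void NegKernel(const __nv_bfloat16* x, __nv_bfloat16* y, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    y[i] = __float2bfloat16(-__bfloat162float(x[i]));
  }
}

// Launch helper: one thread per element, 256 threads per block.
void LaunchNeg(const __nv_bfloat16* x, __nv_bfloat16* y, int n,
               cudaStream_t stream) {
  const int threads = 256;
  const int blocks = (n + threads - 1) / threads;
  NegKernel<<<blocks, threads, 0, stream>>>(x, y, n);
}
```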