Skip to content

Commit

Permalink
bfloat16 support for quickgelugrad
Browse files — browse the repository at this point in the history
  • Loading branch information
Prathik Rao committed Nov 7, 2023
1 parent 096307c commit 9fe4767
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,15 @@ namespace cuda {
ACTIVATION_GRAD_OP_TYPED(name, ver, domain, float) \
ACTIVATION_GRAD_OP_TYPED(name, ver, domain, double)

// Registers the typed activation-gradient kernel for half, float, and double
// (via the _HFD macro) and additionally for BFloat16 — the trailing "B" in the
// macro name. Used below for ops that support bfloat16, e.g. QuickGeluGrad.
#define ACTIVATION_GRAD_OP_HFDB(name, ver, domain) \
ACTIVATION_GRAD_OP_HFD(name, ver, domain) \
ACTIVATION_GRAD_OP_TYPED(name, ver, domain, BFloat16)

// Kernel registrations for the activation gradient ops in the MS domain.
// QuickGeluGrad uses the _HFDB variant to also register a BFloat16 kernel;
// the remaining ops register half/float/double only.
// NOTE: the stale pre-change line `ACTIVATION_GRAD_OP_HFD(QuickGeluGrad, ...)`
// was removed — keeping it alongside the _HFDB line would define the
// half/float/double QuickGeluGrad kernel classes twice (redefinition error).
ACTIVATION_GRAD_OP_HFD(GeluGrad, 1, kMSDomain);
ACTIVATION_GRAD_OP_HFD(FastGeluGrad, 1, kMSDomain);
ACTIVATION_GRAD_OP_HFD(ReluGrad, 1, kMSDomain);
ACTIVATION_GRAD_OP_HFD(SigmoidGrad, 1, kMSDomain);
ACTIVATION_GRAD_OP_HFDB(QuickGeluGrad, 1, kMSDomain);
ACTIVATION_GRAD_OP_HFD(TanhGrad, 1, kMSDomain);
ACTIVATION_GRAD_OP_HFD(LeakyReluGrad, 1, kMSDomain);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,15 @@ struct OP_LeakyReluGrad : public CtxLeakyReluGrad {
// Emits an explicit template instantiation of Impl_<name> for element type T,
// so the definition compiled in this translation unit is linkable from others.
#define SPECIALIZED_BINARY_ELEMENTWISE_IMPL(name, T) \
template void Impl_##name<T>(cudaStream_t stream, const T* lhs_data, const T* rhs_data, T* output_data, const Ctx##name* func_ctx, size_t count);

// Explicitly instantiates Impl_<x> for half, float, double, and BFloat16
// (the "B" suffix). The diff capture had left the removed _HFD #define opener
// and a stale `double` line lacking its continuation backslash fused in here,
// which truncated the macro mid-definition; this is the coherent post-commit
// form.
#define SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFDB(x) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, half) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, float) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, double) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, BFloat16)

// For a given op name, emits the out-of-line Impl_<name> definition plus its
// explicit instantiations for half/float/double/BFloat16. The stale removed
// line invoking the renamed-away *_HFD macro (which also terminated the macro
// early, lacking a trailing backslash) has been dropped; only the *_HFDB call
// remains.
#define ACTIVATION_GRAD_OP_NAME(name) \
BINARY_ELEMENTWISE_IMPL(name); \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFDB(name)

ACTIVATION_GRAD_OPS()
#undef ACTIVATION_GRAD_OP_NAME
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
// Forward declarations of the typed QuickGeluGrad kernel classes; each is
// registered in RegisterCudaTrainingKernels below. BFloat16 is added alongside
// the pre-existing float/double/MLFloat16 declarations.
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, QuickGeluGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, QuickGeluGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, QuickGeluGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, QuickGeluGrad);

class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, TanhGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, TanhGrad);
Expand Down Expand Up @@ -378,6 +379,7 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) {
// Kernel-registry entries for QuickGeluGrad, one per supported element type
// (float, double, MLFloat16, and the newly added BFloat16). These must match
// the class declarations above.
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, QuickGeluGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, QuickGeluGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, QuickGeluGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, QuickGeluGrad)>,

Check warning on line 382 in orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc#L382

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc:382:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
// Kernel-registry entries for TanhGrad (float/double/MLFloat16 only — no
// BFloat16 kernel is registered for this op here).
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, TanhGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, TanhGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, TanhGrad)>,
Expand Down

0 comments on commit 9fe4767

Please sign in to comment.