Skip to content

Commit

Permalink
use X for bfloat16
Browse files Browse the repository at this point in the history
  • Loading branch information
Prathik Rao committed Nov 8, 2023
1 parent 9fe4767 commit f60f122
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,15 @@ namespace cuda {
ACTIVATION_GRAD_OP_TYPED(name, ver, domain, float) \
ACTIVATION_GRAD_OP_TYPED(name, ver, domain, double)

#define ACTIVATION_GRAD_OP_HFDB(name, ver, domain) \
#define ACTIVATION_GRAD_OP_HFDX(name, ver, domain) \
ACTIVATION_GRAD_OP_HFD(name, ver, domain) \
ACTIVATION_GRAD_OP_TYPED(name, ver, domain, BFloat16)

ACTIVATION_GRAD_OP_HFD(GeluGrad, 1, kMSDomain);
ACTIVATION_GRAD_OP_HFD(FastGeluGrad, 1, kMSDomain);
ACTIVATION_GRAD_OP_HFD(ReluGrad, 1, kMSDomain);
ACTIVATION_GRAD_OP_HFD(SigmoidGrad, 1, kMSDomain);
ACTIVATION_GRAD_OP_HFDB(QuickGeluGrad, 1, kMSDomain);
ACTIVATION_GRAD_OP_HFDX(QuickGeluGrad, 1, kMSDomain);
ACTIVATION_GRAD_OP_HFD(TanhGrad, 1, kMSDomain);
ACTIVATION_GRAD_OP_HFD(LeakyReluGrad, 1, kMSDomain);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,15 +83,15 @@ struct OP_LeakyReluGrad : public CtxLeakyReluGrad {
#define SPECIALIZED_BINARY_ELEMENTWISE_IMPL(name, T) \
template void Impl_##name<T>(cudaStream_t stream, const T* lhs_data, const T* rhs_data, T* output_data, const Ctx##name* func_ctx, size_t count);

#define SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFDB(x) \
#define SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFDX(x) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, half) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, float) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, double) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, BFloat16)

#define ACTIVATION_GRAD_OP_NAME(name) \
BINARY_ELEMENTWISE_IMPL(name); \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFDB(name)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFDX(name)

ACTIVATION_GRAD_OPS()
#undef ACTIVATION_GRAD_OP_NAME
Expand Down

0 comments on commit f60f122

Please sign in to comment.