From a4c01a591ea3e8e26609fcd638c083b3b7fbe14f Mon Sep 17 00:00:00 2001 From: "pengwa@microsoft.com" Date: Thu, 22 Feb 2024 02:11:46 +0000 Subject: [PATCH] fix rocm build --- onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc | 5 ++++- onnxruntime/contrib_ops/cuda/bert/fast_gelu.h | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc index 670e460e5a189..fff6300d42f08 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc +++ b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc @@ -6,9 +6,10 @@ #include "fast_gelu.h" #include "core/providers/cuda/tensor/gelu_impl.h" #include "contrib_ops/cpu/bert/bias_gelu_helper.h" -#include "transformer_common.h" #ifdef USE_ROCM #include "contrib_ops/rocm/bert/elementwise.h" +#else +#include "contrib_ops/cuda/bert/transformer_common.h" #endif namespace onnxruntime { @@ -34,8 +35,10 @@ using namespace ONNX_NAMESPACE; template FastGelu::FastGelu(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info) { +#ifdef USE_CUDA const TransformerOptions* options = TransformerOptions::GetInstance(); use_half2_ = !options->DisableHalf2(); +#endif } template diff --git a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h index 3e642a70afef5..d563556593e6e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h +++ b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h @@ -18,7 +18,7 @@ class FastGelu final : public CudaKernel { Status ComputeInternal(OpKernelContext* ctx) const override; private: - bool use_half2_; + bool use_half2_; // Only applicable to CUDA kernel (not ROCM). }; } // namespace cuda