diff --git a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc index 670e460e5a189..fff6300d42f08 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc +++ b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc @@ -6,9 +6,10 @@ #include "fast_gelu.h" #include "core/providers/cuda/tensor/gelu_impl.h" #include "contrib_ops/cpu/bert/bias_gelu_helper.h" -#include "transformer_common.h" #ifdef USE_ROCM #include "contrib_ops/rocm/bert/elementwise.h" +#else +#include "contrib_ops/cuda/bert/transformer_common.h" #endif namespace onnxruntime { @@ -34,8 +35,10 @@ using namespace ONNX_NAMESPACE; template FastGelu::FastGelu(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info) { +#ifdef USE_CUDA const TransformerOptions* options = TransformerOptions::GetInstance(); use_half2_ = !options->DisableHalf2(); +#endif } template diff --git a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h index 3e642a70afef5..d563556593e6e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h +++ b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h @@ -18,7 +18,7 @@ class FastGelu final : public CudaKernel { Status ComputeInternal(OpKernelContext* ctx) const override; private: - bool use_half2_; + bool use_half2_; // Only applicable to CUDA kernel (not ROCM). }; } // namespace cuda