Skip to content

Commit

Permalink
fix rocm build
Browse files Browse the repository at this point in the history
  • Loading branch information
pengwa committed Feb 22, 2024
1 parent b8cf88d commit a4c01a5
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
5 changes: 4 additions & 1 deletion onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
#include "fast_gelu.h"
#include "core/providers/cuda/tensor/gelu_impl.h"
#include "contrib_ops/cpu/bert/bias_gelu_helper.h"
#include "transformer_common.h"
#ifdef USE_ROCM
#include "contrib_ops/rocm/bert/elementwise.h"
#else
#include "contrib_ops/cuda/bert/transformer_common.h"
#endif

namespace onnxruntime {
Expand All @@ -34,8 +35,10 @@ using namespace ONNX_NAMESPACE;

template <typename T>
FastGelu<T>::FastGelu(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info) {
#ifdef USE_CUDA
const TransformerOptions* options = TransformerOptions::GetInstance();
use_half2_ = !options->DisableHalf2();
#endif
}

template <typename T>
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/bert/fast_gelu.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class FastGelu final : public CudaKernel {
Status ComputeInternal(OpKernelContext* ctx) const override;

private:
bool use_half2_;
bool use_half2_; // Only applicable to CUDA kernel (not ROCM).
};

} // namespace cuda
Expand Down

0 comments on commit a4c01a5

Please sign in to comment.