From 056448b2f2014cba39a1d67ae7cfef19de0d534b Mon Sep 17 00:00:00 2001
From: pengwa
Date: Fri, 1 Mar 2024 10:57:14 +0800
Subject: [PATCH] Follow up fix for Gelu impl (#19693)

### Follow up fix for Gelu impl

There are two minor comments in https://github.com/microsoft/onnxruntime/pull/19560. Fix them in this pull request.

### Motivation and Context
---
 docs/ORTModule_Training_Guidelines.md          | 2 +-
 onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc | 8 +++-----
 onnxruntime/contrib_ops/cuda/bert/fast_gelu.h  | 4 +++-
 3 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/docs/ORTModule_Training_Guidelines.md b/docs/ORTModule_Training_Guidelines.md
index 91057d3dfb120..f50b18b736936 100644
--- a/docs/ORTModule_Training_Guidelines.md
+++ b/docs/ORTModule_Training_Guidelines.md
@@ -293,7 +293,7 @@ A classical usage of disabling the deep copy: when the deep copy before module e
 	export ORTMODULE_MEMORY_OPT_LEVEL=0
 	```
 
-### ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT
+#### ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT
 
 - **Feature Area**: *ORTMODULE/Optimizations*
 - **Description**: By default, the memory-efficient gradient management is turned off. The gradient after it is computed in ONNX Runtime, will trigger the corresponding parameter's backward function through `PythonOpGrad` operator. This would help release the gradient buffer managed in ONNX Runtime, which originally is released once all backward computation finishes.
diff --git a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc
index e8974a29476b6..8b8e4e267f895 100644
--- a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc
@@ -8,8 +8,7 @@
 #include "contrib_ops/cpu/bert/bias_gelu_helper.h"
 #ifdef USE_ROCM
 #include "contrib_ops/rocm/bert/elementwise.h"
-#endif
-#ifdef USE_CUDA
+#else
 #include "contrib_ops/cuda/bert/transformer_common.h"
 #endif
 
@@ -36,7 +35,7 @@ using namespace ONNX_NAMESPACE;
 
 template <typename T>
 FastGelu<T>::FastGelu(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info) {
-#ifdef USE_CUDA
+#ifndef USE_ROCM
   const TransformerOptions* options = TransformerOptions::GetInstance();
   use_half2_ = !options->DisableHalf2();
 #endif
@@ -63,8 +62,7 @@ Status FastGelu<T>::ComputeInternal(OpKernelContext* context) const {
       reinterpret_cast<const CudaT*>(input->Data<T>()), static_cast<int>(input_length),
      (nullptr != bias) ? reinterpret_cast<const CudaT*>(bias->Data<T>()) : nullptr, static_cast<int>(bias_length),
       reinterpret_cast<CudaT*>(output->MutableData<T>()));
-#endif
-#ifdef USE_CUDA
+#else
   return LaunchFastGeluKernel<CudaT>(GetDeviceProp(),
                                      Stream(context),
                                      static_cast<int>(input_length),
diff --git a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h
index d563556593e6e..26f3bd5a03928 100644
--- a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h
+++ b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h
@@ -18,7 +18,9 @@ class FastGelu final : public CudaKernel {
   Status ComputeInternal(OpKernelContext* ctx) const override;
 
  private:
-  bool use_half2_;  // Only applicable to CUDA kernel (not ROCM).
+#ifndef USE_ROCM
+  bool use_half2_;
+#endif
 };
 
 }  // namespace cuda
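
The guard pattern this patch converges on can be illustrated with a minimal stand-alone sketch: because the ROCm and CUDA builds are mutually exclusive, a single `#ifdef USE_ROCM ... #else ... #endif` pair replaces two independent `#ifdef` blocks, and members that only exist in the CUDA build are compiled out entirely under ROCm. The class and launcher names below (`FastGeluKernel`, `launch_rocm_gelu`, `launch_cuda_gelu`) are hypothetical placeholders, not ONNX Runtime APIs.

```cpp
// Illustrative sketch of the single-guard pattern, assuming USE_ROCM is the
// only alternative backend macro (as in the patch above).
#include <cstddef>

namespace sketch {

class FastGeluKernel {
 public:
  FastGeluKernel() {
#ifndef USE_ROCM
    // Only the CUDA build reads the half2 configuration.
    use_half2_ = true;
#endif
  }

  void Compute(const float* input, float* output, std::size_t length) const {
#ifdef USE_ROCM
    launch_rocm_gelu(input, output, length);
#else
    launch_cuda_gelu(input, output, length, use_half2_);
#endif
  }

 private:
#ifndef USE_ROCM
  bool use_half2_;  // Compiled out of the ROCm build, mirroring fast_gelu.h.
#endif

  // Hypothetical launchers standing in for the real kernel entry points.
  static void launch_rocm_gelu(const float*, float*, std::size_t) {}
  static void launch_cuda_gelu(const float*, float*, std::size_t, bool) {}
};

}  // namespace sketch
```

Compiling the member out under ROCm, rather than leaving an unused `bool`, keeps the ROCm build free of dead state and avoids a field that is never initialized on that path.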