Follow up fix for Gelu impl (#19693)
### Follow up fix for Gelu impl

There were two minor review comments left on #19560; this pull request addresses them.


### Motivation and Context
pengwa authored Mar 1, 2024
1 parent 2a857d9 commit acbfc29
Showing 3 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion docs/ORTModule_Training_Guidelines.md
@@ -293,7 +293,7 @@ A classical usage of disabling the deep copy: when the deep copy before module e
 export ORTMODULE_MEMORY_OPT_LEVEL=0
 ```
-### ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT
+#### ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT
 - **Feature Area**: *ORTMODULE/Optimizations*
 - **Description**: By default, memory-efficient gradient management is turned off. Once a gradient is computed in ONNX Runtime, it triggers the corresponding parameter's backward function through the `PythonOpGrad` operator, which helps release the gradient buffer managed in ONNX Runtime; that buffer is otherwise released only after all backward computation finishes.
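For context, the flag documented in the hunk above is an environment variable read by ORTModule. Below is a minimal usage sketch, assuming the switch follows the same 0/1 convention as the other `ORTMODULE_*` variables in this guide; the toy model is illustrative only and not taken from the repository:

```python
import os

# Assumption: "1" enables the feature, matching other ORTMODULE_* switches.
os.environ["ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT"] = "1"

import torch
from onnxruntime.training.ortmodule import ORTModule

model = torch.nn.Linear(128, 64)
model = ORTModule(model)  # gradients computed by ORT can now release their buffers eagerly
```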
8 changes: 3 additions & 5 deletions onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc
@@ -8,8 +8,7 @@
 #include "contrib_ops/cpu/bert/bias_gelu_helper.h"
 #ifdef USE_ROCM
 #include "contrib_ops/rocm/bert/elementwise.h"
-#endif
-#ifdef USE_CUDA
+#else
 #include "contrib_ops/cuda/bert/transformer_common.h"
 #endif

@@ -36,7 +35,7 @@ using namespace ONNX_NAMESPACE;

 template <typename T>
 FastGelu<T>::FastGelu(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info) {
-#ifdef USE_CUDA
+#ifndef USE_ROCM
   const TransformerOptions* options = TransformerOptions::GetInstance();
   use_half2_ = !options->DisableHalf2();
 #endif
@@ -63,8 +62,7 @@ Status FastGelu<T>::ComputeInternal(OpKernelContext* context) const {
     reinterpret_cast<const CudaT*>(input->Data<T>()), static_cast<int>(input_length),
     (nullptr != bias) ? reinterpret_cast<const CudaT*>(bias->Data<T>()) : nullptr, static_cast<int>(bias_length),
     reinterpret_cast<CudaT*>(output->MutableData<T>()));
-#endif
-#ifdef USE_CUDA
+#else
   return LaunchFastGeluKernel<CudaT>(GetDeviceProp(),
                                      Stream(context),
                                      static_cast<int>(input_length),
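As background for the kernel being edited: `FastGelu` applies the tanh approximation of GELU to the (optionally bias-added) input, and `use_half2_` only controls whether the CUDA path uses vectorized half2 math for fp16, which is why the removed header comment notes it does not apply to ROCm. A minimal NumPy sketch of that approximation follows; the helper name and test values are illustrative, not from the repository:

```python
import numpy as np

def fast_gelu_reference(x, bias=None):
    """Tanh approximation of GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))."""
    if bias is not None:
        x = x + bias
    return 0.5 * x * (1.0 + np.tanh(0.7978845608028654 * (x + 0.044715 * np.power(x, 3))))

# Quick check on a few values.
x = np.array([-2.0, -0.5, 0.0, 0.5, 2.0], dtype=np.float32)
print(fast_gelu_reference(x))
```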
4 changes: 3 additions & 1 deletion onnxruntime/contrib_ops/cuda/bert/fast_gelu.h
@@ -18,7 +18,9 @@ class FastGelu final : public CudaKernel {
   Status ComputeInternal(OpKernelContext* ctx) const override;

  private:
-  bool use_half2_;  // Only applicable to CUDA kernel (not ROCM).
+#ifndef USE_ROCM
+  bool use_half2_;
+#endif
 };

} // namespace cuda
