ONNX Gelu Op in Opset 20 (#19560)
### ONNX Gelu Op in Opset 20

Refactor the code to support both the MS-domain Gelu and the ONNX-domain Gelu op introduced in opset 20 (a sketch of the two `approximate` modes follows the list below).

1. Move the CPU Gelu implementation from
`onnxruntime/contrib_ops/cpu/activations.h/cc` to
`onnxruntime/core/providers/cpu/tensor/gelu.h/cc`; it serves as the
implementation when the `approximate` attribute is 'none'.
2. Duplicate some logic from
`onnxruntime/contrib_ops/cpu/bert/bias_gelu.cc` into
`onnxruntime/core/providers/cpu/tensor/gelu.h/cc`; it serves as the
implementation when the `approximate` attribute is 'tanh'.
3. Register the ONNX-domain Gelu CPU kernel for opset 20 in
`onnxruntime/core/providers/cpu/cpu_execution_provider.cc`.
4. Move `onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.h/cu` to
`onnxruntime/core/providers/cuda/tensor/gelu_impl.h` and
`onnxruntime/core/providers/cuda/tensor/gelu_approximate_impl.cu`,
respectively; they serve as the implementation when the `approximate`
attribute is 'tanh'.
5. Implement the logic for the `approximate` attribute value 'none' in
`onnxruntime/core/providers/cuda/tensor/gelu_impl.cu`.
6. Register the ONNX-domain Gelu CUDA kernel for opset 20 in
`onnxruntime/core/providers/cuda/cuda_execution_provider.cc`.
7. Make the corresponding changes for the ROCm EP.
8. Enrich the tests for the ONNX-domain Gelu in
`onnxruntime/test/providers/cpu/activation/activation_op_test.cc`.
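
For reference, the two `approximate` modes compute the following. This is a minimal NumPy sketch of the formulas the kernels implement; the function names are illustrative, not identifiers from this change:

```python
import math
import numpy as np

def gelu_none(x: np.ndarray) -> np.ndarray:
    # approximate == 'none': 0.5 * x * (1 + erf(x / sqrt(2)))
    erf = np.vectorize(math.erf)
    return 0.5 * x * (1.0 + erf(x * (1.0 / math.sqrt(2.0))))

def gelu_tanh(x: np.ndarray) -> np.ndarray:
    # approximate == 'tanh': 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    return 0.5 * x * (1.0 + np.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * np.power(x, 3))))

x = np.linspace(-3, 3, 7, dtype=np.float32)
print(np.max(np.abs(gelu_none(x) - gelu_tanh(x))))  # the tanh approximation error is small
```

The 'none' path corresponds to the erf-based kernels moved under `core/providers/.../tensor/gelu*`, while the 'tanh' path reuses the FastGelu-style approximation.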
pengwa authored Feb 23, 2024
1 parent 29b1106 commit ae92d59
Showing 27 changed files with 395 additions and 197 deletions.
4 changes: 0 additions & 4 deletions cmake/onnxruntime_rocm_hipify.cmake
@@ -20,10 +20,6 @@ set(contrib_ops_excluded_files
"bert/fastertransformer_decoder_attention/*"
"bert/multihead_attention.cc"
"bert/multihead_attention.h"
"bert/fast_gelu_impl.cu"
"bert/fast_gelu_impl.h"
"bert/fast_gelu.cc"
"bert/fast_gelu.h"
"bert/relative_attn_bias.cc"
"bert/relative_attn_bias.h"
"bert/relative_attn_bias_impl.cu"
@@ -145,7 +145,7 @@ private void TestCUDAProviderOptions()
private void CanRunInferenceOnAModelWithTensorRT()
{
string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "squeezenet.onnx");

int deviceId = 0;
string deviceIdStr = System.Environment.GetEnvironmentVariable("ONNXRUNTIME_TEST_GPU_DEVICE_ID");
if (!string.IsNullOrEmpty(deviceIdStr) && int.TryParse(deviceIdStr, out int parsedValue) && parsedValue >= 0)
2 changes: 2 additions & 0 deletions docs/OperatorKernels.md
@@ -127,6 +127,7 @@ Do not modify directly.*
|GatherND|*in* data:**T**<br> *in* indices:**tensor(int64)**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **indices** = tensor(int64)|
|||12|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **indices** = tensor(int64)|
|||11|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **indices** = tensor(int64)|
|Gelu|*in* X:**T**<br> *out* Y:**T**|20+|**T** = tensor(float)|
|Gemm|*in* A:**T**<br> *in* B:**T**<br> *in* C:**T**<br> *out* Y:**T**|13+|**T** = tensor(double), tensor(float)|
|||[11, 12]|**T** = tensor(double), tensor(float)|
|||[9, 10]|**T** = tensor(double), tensor(float)|
@@ -606,6 +607,7 @@ Do not modify directly.*
|GatherND|*in* data:**T**<br> *in* indices:**tensor(int64)**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int64)<br/> **indices** = tensor(int64)|
|||12|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int64)<br/> **indices** = tensor(int64)|
|||11|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int64)<br/> **indices** = tensor(int64)|
|Gelu|*in* X:**T**<br> *out* Y:**T**|20+|**T** = tensor(double), tensor(float), tensor(float16)|
|Gemm|*in* A:**T**<br> *in* B:**T**<br> *in* C:**T**<br> *out* Y:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
|||[9, 10]|**T** = tensor(double), tensor(float), tensor(float16)|
2 changes: 1 addition & 1 deletion include/onnxruntime/core/providers/cuda/cuda_resource.h
@@ -19,4 +19,4 @@ enum CudaResource : int {
enable_skip_layer_norm_strict_mode_t,
prefer_nhwc_t,
use_tf32_t,
};
};
10 changes: 1 addition & 9 deletions onnxruntime/contrib_ops/cpu/activations.cc
@@ -2,7 +2,7 @@
// Licensed under the MIT License.

#include "core/providers/cpu/activation/activations.h"
#include "activations.h"
#include "contrib_ops/cpu/activations.h"

namespace onnxruntime {
namespace contrib {
@@ -26,14 +26,6 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
ThresholdedRelu<float>);

ONNX_OPERATOR_KERNEL_EX(
Gelu,
kMSDomain,
1,
kCpuExecutionProvider,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Gelu<float>);

ONNX_OPERATOR_KERNEL_EX(
QuickGelu,
kMSDomain,
41 changes: 0 additions & 41 deletions onnxruntime/contrib_ops/cpu/activations.h
@@ -54,47 +54,6 @@ namespace contrib {
DEFINE_ELE_KERNEL(ScaledTanh);
DEFINE_ELE_KERNEL(ParametricSoftplus);

template <typename T>
class Gelu : public OpKernel {
public:
Gelu(const OpKernelInfo& info) : OpKernel(info) {
}

Status Compute(OpKernelContext* context) const override {
const Tensor* input = context->Input<Tensor>(0);
const T* input_data = input->Data<T>();

Tensor* output = context->Output(0, input->Shape());
T* output_data = output->MutableData<T>();

concurrency::ThreadPool* tp = context->GetOperatorThreadPool();
int64_t elem_count = input->Shape().Size();
constexpr int64_t length_per_task = 4096; // this number comes from FastGelu.
int64_t task_count = (elem_count + length_per_task - 1) / length_per_task;
concurrency::ThreadPool::TryBatchParallelFor(
tp, static_cast<int32_t>(task_count),
[&](ptrdiff_t task_idx) {
const auto start = task_idx * length_per_task;
const T* p_input = input_data + start;
T* p_output = output_data + start;
int64_t count = std::min(length_per_task, elem_count - start);

for (int64_t i = 0; i < count; i++) {
T value = p_input[i];
p_output[i] = value * static_cast<T>(M_SQRT1_2);
}

MlasComputeErf(p_output, p_output, narrow<size_t>(count));

for (int64_t i = 0; i < count; i++) {
p_output[i] = 0.5f * p_input[i] * (p_output[i] + 1.0f);
}
},
0);
return Status::OK();
}
};

// Implement a new one instead of inheriting from ElementWiseRangedTransform so that we can call
// MlasComputeLogistic instead of using Eigen for better perf.
template <typename T>
1 change: 0 additions & 1 deletion onnxruntime/contrib_ops/cuda/activation/activations.cc
@@ -49,7 +49,6 @@ namespace cuda {
UNARY_ACTIVATION_OP_HFD(Affine, 1, kOnnxDomain);
UNARY_ACTIVATION_OP_HFD(ParametricSoftplus, 1, kOnnxDomain);
UNARY_ACTIVATION_OP_HFD(ScaledTanh, 1, kOnnxDomain);
UNARY_ACTIVATION_OP_HFD(Gelu, 1, kMSDomain);
UNARY_ACTIVATION_OP_HFD(QuickGelu, 1, kMSDomain);

REGISTER_ACTIVATION_KERNEL(ThresholdedRelu, 1, kOnnxDomain, MLFloat16)
11 changes: 0 additions & 11 deletions onnxruntime/contrib_ops/cuda/activation/activations.h
@@ -66,17 +66,6 @@ class ScaledTanh final : public UnaryElementwise {
float beta_;
};

template <typename T>
class Gelu final : public UnaryElementwise {
public:
Gelu(const OpKernelInfo& info) : UnaryElementwise(info) {}

Status ComputeInternal(OpKernelContext* context) const override;

private:
MAKE_FUNC_CTX_NULL()
};

template <typename T>
class QuickGelu final : public UnaryElementwise {
public:
14 changes: 0 additions & 14 deletions onnxruntime/contrib_ops/cuda/activation/activations_impl.cu
@@ -36,20 +36,6 @@ struct OP_ScaledTanh : public CtxScaledTanh {
}
};

template <typename T>
struct OP_Gelu : public CtxGelu {
__device__ __inline__ T operator()(const T& a) const {
return _Gelu(a);
}
};

template <>
struct OP_Gelu<half> : public CtxGelu {
__device__ __inline__ half operator()(const half& a) const {
return static_cast<half>(_Gelu(static_cast<float>(a)));
}
};

template <typename T>
struct OP_QuickGelu : public CtxQuickGelu {
__device__ __inline__ T operator()(const T& a) const {
2 changes: 0 additions & 2 deletions onnxruntime/contrib_ops/cuda/activation/activations_impl.h
@@ -11,14 +11,12 @@ namespace cuda {
typedef onnxruntime::cuda::CtxAlphaBeta CtxAffine;
typedef onnxruntime::cuda::CtxAlphaBeta CtxParametricSoftplus;
typedef onnxruntime::cuda::CtxAlphaBeta CtxScaledTanh;
typedef onnxruntime::cuda::CtxNull CtxGelu;
typedef onnxruntime::cuda::CtxAlpha CtxQuickGelu;

#define UNARY_CONTRIB_ACTIVATION_OPS() \
UNARY_ACTIVATION_OP_NAME(ScaledTanh) \
UNARY_ACTIVATION_OP_NAME(Affine) \
UNARY_ACTIVATION_OP_NAME(ParametricSoftplus) \
UNARY_ACTIVATION_OP_NAME(Gelu) \
UNARY_ACTIVATION_OP_NAME(QuickGelu)

#define UNARY_ACTIVATION_OP_NAME(name) UNARY_ACTIVATION_IMPL_DECLARATION(name);
20 changes: 18 additions & 2 deletions onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc
@@ -4,9 +4,14 @@
#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cuda/cudnn_common.h"
#include "fast_gelu.h"
#include "fast_gelu_impl.h"
#include "core/providers/cuda/tensor/gelu_impl.h"
#include "contrib_ops/cpu/bert/bias_gelu_helper.h"
#include "transformer_common.h"
#ifdef USE_ROCM
#include "contrib_ops/rocm/bert/elementwise.h"
#endif
#ifdef USE_CUDA
#include "contrib_ops/cuda/bert/transformer_common.h"
#endif

namespace onnxruntime {
namespace contrib {
@@ -31,8 +36,10 @@ using namespace ONNX_NAMESPACE;

template <typename T>
FastGelu<T>::FastGelu(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info) {
#ifdef USE_CUDA
const TransformerOptions* options = TransformerOptions::GetInstance();
use_half2_ = !options->DisableHalf2();
#endif
}

template <typename T>
@@ -50,6 +57,14 @@ Status FastGelu<T>::ComputeInternal(OpKernelContext* context) const {
int64_t bias_length = (nullptr == bias) ? 0 : bias->Shape().Size();
typedef typename ToCudaType<T>::MappedType CudaT;

#ifdef USE_ROCM
return LaunchElementwiseKernel<functor::FastGeLU, CudaT>(
GetTuningContext(), context->GetComputeStream(),
reinterpret_cast<const CudaT*>(input->Data<T>()), static_cast<int>(input_length),
(nullptr != bias) ? reinterpret_cast<const CudaT*>(bias->Data<T>()) : nullptr, static_cast<int>(bias_length),
reinterpret_cast<CudaT*>(output->MutableData<T>()));
#endif
#ifdef USE_CUDA
return LaunchFastGeluKernel<CudaT>(GetDeviceProp(),
Stream(context),
static_cast<int>(input_length),
Expand All @@ -58,6 +73,7 @@ Status FastGelu<T>::ComputeInternal(OpKernelContext* context) const {
(nullptr != bias) ? reinterpret_cast<const CudaT*>(bias->Data<T>()) : nullptr,
reinterpret_cast<CudaT*>(output->MutableData<T>()),
use_half2_);
#endif
}

} // namespace cuda
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/bert/fast_gelu.h
@@ -18,7 +18,7 @@ class FastGelu final : public CudaKernel {
Status ComputeInternal(OpKernelContext* ctx) const override;

private:
bool use_half2_;
bool use_half2_; // Only applicable to CUDA kernel (not ROCM).
};

} // namespace cuda
59 changes: 0 additions & 59 deletions onnxruntime/contrib_ops/rocm/bert/fast_gelu.cc

This file was deleted.

24 changes: 0 additions & 24 deletions onnxruntime/contrib_ops/rocm/bert/fast_gelu.h

This file was deleted.

2 changes: 2 additions & 0 deletions onnxruntime/core/providers/cpu/cpu_execution_provider.cc
@@ -1035,6 +1035,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, IsNaN);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, IsNaN);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, MLFloat16, IsNaN);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Gelu);
#if !defined(DISABLE_FLOAT8_TYPES)
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FN, IsNaN);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FNUZ, IsNaN);
@@ -2562,6 +2563,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, IsNaN)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, MLFloat16,
IsNaN)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Gelu)>,
#if !defined(DISABLE_FLOAT8_TYPES)
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FN,
IsNaN)>,
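
With the kernel registered above, the opset-20 Gelu can be exercised end to end. Below is a hedged sketch (not part of this change) that builds a one-node opset-20 model with the `onnx` helper API and runs it on the CPU execution provider through the Python bindings; names such as `gelu_test` are illustrative, and an onnx/onnxruntime build with opset-20 support is assumed:

```python
import numpy as np
import onnx
from onnx import TensorProto, helper
import onnxruntime as ort

x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [4])
y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [4])
# 'approximate' defaults to 'none'; set it to 'tanh' to hit the approximate kernel.
node = helper.make_node("Gelu", ["X"], ["Y"], approximate="none")
graph = helper.make_graph([node], "gelu_test", [x], [y])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 20)])
onnx.checker.check_model(model)

sess = ort.InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"])
(out,) = sess.run(None, {"X": np.array([-2.0, -0.5, 0.5, 2.0], dtype=np.float32)})
print(out)  # Gelu applied elementwise by the opset-20 CPU kernel
```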