Implement CUDA IsInf-10,20 (#19772)
### Description
Implement IsInf-10,20 for CUDA.
Also add the FP16 types (MLFloat16, BFloat16) on CPU.

### Motivation and Context
Certain models lag in performance because IsInf was not available on CUDA.
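
For readers unfamiliar with the operator: IsInf maps each input element to a bool, and the detect_positive/detect_negative attributes (both defaulting to 1) select which infinities are reported. A minimal host-side sketch of those semantics for plain float input (the names below are illustrative, not taken from the kernel):

#include <limits>
#include <vector>

// Illustrative reference for IsInf semantics; not the ORT implementation.
std::vector<bool> IsInfReference(const std::vector<float>& x,
                                 bool detect_positive = true,
                                 bool detect_negative = true) {
  std::vector<bool> y(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    const bool pos = x[i] == std::numeric_limits<float>::infinity();
    const bool neg = x[i] == -std::numeric_limits<float>::infinity();
    y[i] = (detect_positive && pos) || (detect_negative && neg);  // both flags off => all-false output
  }
  return y;
}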
yuslepukhin authored Mar 5, 2024
1 parent 06e684c commit 1e78bce
Showing 13 changed files with 420 additions and 15 deletions.
4 changes: 3 additions & 1 deletion docs/OperatorKernels.md
@@ -160,7 +160,7 @@ Do not modify directly.*
 |||[1, 10]|**B** = tensor(bool)<br/> **V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
 |ImageScaler|*in* input:**T**<br> *out* output:**T**|1+|**T** = tensor(float)|
 |InstanceNormalization|*in* input:**T**<br> *in* scale:**T**<br> *in* B:**T**<br> *out* output:**T**|6+|**T** = tensor(float)|
-|IsInf|*in* X:**T1**<br> *out* Y:**T2**|20+|**T1** = tensor(double), tensor(float), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)<br/> **T2** = tensor(bool)|
+|IsInf|*in* X:**T1**<br> *out* Y:**T2**|20+|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)<br/> **T2** = tensor(bool)|
 |||[10, 19]|**T1** = tensor(double), tensor(float)<br/> **T2** = tensor(bool)|
 |IsNaN|*in* X:**T1**<br> *out* Y:**T2**|20+|**T1** = tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)<br/> **T2** = tensor(bool)|
 |||[13, 19]|**T1** = tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(bool)|
@@ -631,6 +631,8 @@ Do not modify directly.*
 |||[1, 10]|**B** = tensor(bool)<br/> **V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
 |ImageScaler|*in* input:**T**<br> *out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
 |InstanceNormalization|*in* input:**T**<br> *in* scale:**T**<br> *in* B:**T**<br> *out* output:**T**|6+|**T** = tensor(double), tensor(float), tensor(float16)|
+|IsInf|*in* X:**T1**<br> *out* Y:**T2**|20+|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)<br/> **T2** = tensor(bool)|
+|||[10, 19]|**T1** = tensor(double), tensor(float)<br/> **T2** = tensor(bool)|
 |LRN|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
 |||[1, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
 |LSTM|*in* X:**T**<br> *in* W:**T**<br> *in* R:**T**<br> *in* B:**T**<br> *in* sequence_lens:**T1**<br> *in* initial_h:**T**<br> *in* initial_c:**T**<br> *in* P:**T**<br> *out* Y:**T**<br> *out* Y_h:**T**<br> *out* Y_c:**T**|14+|**T** = tensor(double), tensor(float), tensor(float16)<br/> **T1** = tensor(int32)|
2 changes: 1 addition & 1 deletion include/onnxruntime/core/framework/data_types_internal.h
@@ -305,7 +305,7 @@ class CallableDispatchableHelper {
     return 0;
   }

-  void CheckCalledOnce() {
+  void CheckCalledOnce() const {
    ORT_ENFORCE(called_ == 1, "Unsupported data type: ", dt_type_);
  }
 };
64 changes: 51 additions & 13 deletions onnxruntime/core/providers/cpu/tensor/isinf.cc
@@ -23,7 +23,9 @@ ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST(
 using IsInfTypesOpset20 =
     TypeList<
         float,
-        double
+        double,
+        MLFloat16,
+        BFloat16
 #if !defined(DISABLE_FLOAT8_TYPES)
         ,
         Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, Float8E5M2FNUZ
@@ -76,40 +76,76 @@ ONNX_CPU_OPERATOR_KERNEL(
     IsInf);

 IsInf::IsInf(const OpKernelInfo& info) : OpKernel(info) {
-  Status status = info.GetAttr("detect_positive", &detect_positive_);
-  ORT_ENFORCE(status.IsOK(), "Failed to obtain detect_positive");
-  status = info.GetAttr("detect_negative", &detect_negative_);
-  ORT_ENFORCE(status.IsOK(), "Failed to obtain detect_negative");
+  detect_positive_ = info.GetAttrOrDefault<int64_t>("detect_positive", 1);
+  detect_negative_ = info.GetAttrOrDefault<int64_t>("detect_negative", 1);
+  opset_ = info.node().SinceVersion();
 }

 namespace isinf_internal {
 template <class T>
 struct ComputeDispatchTarget {
   void operator()(const Tensor& X, Tensor& Y, bool detect_positive, bool detect_negative) const {
-    const auto total_items = X.Shape().Size();
+    auto input_data = X.DataAsSpan<T>();
     auto output_data = Y.MutableData<bool>();

     if (detect_positive && detect_negative) {
       EigenMap<bool>(Y) = EigenMap<T>(X).array().isInf();
     } else if (detect_positive) {
-      auto input_data = X.Data<T>();
-      auto end_data = input_data + total_items;
       std::transform(
-          input_data, end_data, output_data, [](T v) {
+          input_data.begin(), input_data.end(), output_data, [](T v) {
            return (v == std::numeric_limits<T>::infinity());
          });

     } else if (detect_negative) {
-      auto input_data = X.Data<T>();
-      auto end_data = input_data + total_items;
       std::transform(
-          input_data, end_data, output_data, [](T v) {
+          input_data.begin(), input_data.end(), output_data, [](T v) {
            return (v == -std::numeric_limits<T>::infinity());
          });
     } else {
       // all false
-      memset(output_data, false, onnxruntime::narrow<size_t>(total_items));
+      memset(output_data, false, input_data.size());
     }
   }
 };

+template <>
+struct ComputeDispatchTarget<MLFloat16> {
+  void operator()(const Tensor& X, Tensor& Y, bool detect_positive, bool detect_negative) const {
+    auto output_data = Y.MutableData<bool>();
+    auto input_data = X.DataAsSpan<MLFloat16>();
+    if (detect_positive && detect_negative) {
+      std::transform(input_data.begin(), input_data.end(), output_data,
+                     [](MLFloat16 v) { return v.IsInfinity(); });
+    } else if (detect_positive) {
+      std::transform(input_data.begin(), input_data.end(), output_data,
+                     [](MLFloat16 v) { return v.IsPositiveInfinity(); });
+    } else if (detect_negative) {
+      std::transform(input_data.begin(), input_data.end(), output_data,
+                     [](MLFloat16 v) { return v.IsNegativeInfinity(); });
+    } else {
+      // all false
+      memset(output_data, false, input_data.size());
+    }
+  }
+};
+
+template <>
+struct ComputeDispatchTarget<BFloat16> {
+  void operator()(const Tensor& X, Tensor& Y, bool detect_positive, bool detect_negative) const {
+    auto output_data = Y.MutableData<bool>();
+    auto input_data = X.DataAsSpan<BFloat16>();
+    if (detect_positive && detect_negative) {
+      std::transform(input_data.begin(), input_data.end(), output_data,
+                     [](BFloat16 v) { return v.IsInfinity(); });
+    } else if (detect_positive) {
+      std::transform(input_data.begin(), input_data.end(), output_data,
+                     [](BFloat16 v) { return v.IsPositiveInfinity(); });
+    } else if (detect_negative) {
+      std::transform(input_data.begin(), input_data.end(), output_data,
+                     [](BFloat16 v) { return v.IsNegativeInfinity(); });
+    } else {
+      // all false
+      memset(output_data, false, input_data.size());
+    }
+  }
+};
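
The MLFloat16/BFloat16 specializations above rely on the IsInfinity/IsPositiveInfinity/IsNegativeInfinity helpers on those types. For intuition, an infinity in these formats is an all-ones exponent with a zero mantissa, so the check reduces to a masked bit compare. A standalone sketch, assuming the standard bit layouts (fp16: +inf = 0x7C00, -inf = 0xFC00; bfloat16: +inf = 0x7F80, -inf = 0xFF80):

#include <cstdint>

// Sketch only: masking off the sign bit (0x8000) folds +inf and -inf together.
inline bool Fp16IsInf(uint16_t bits) { return (bits & 0x7FFFu) == 0x7C00u; }
inline bool Bf16IsInf(uint16_t bits) { return (bits & 0x7FFFu) == 0x7F80u; }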
94 changes: 94 additions & 0 deletions onnxruntime/core/providers/cuda/cu_inc/common.cuh
@@ -438,6 +438,100 @@ __device__ __inline__ BFloat16 _Fmod(BFloat16 a, BFloat16 b) {
   return fmodf((float)a, (float)b);
 }

+namespace isinf_details {
+template <typename T>
+struct IsInfTyped {
+  static __device__ __inline__ bool IsInf(T a) {
+    // The cast is needed because on non-MS compilers
+    // isinf() returns int rather than bool,
+    // and we want to avoid conversion warnings.
+    return static_cast<bool>(isinf(a));
+  }
+  static __device__ __inline__ bool IsInfPos(T a) {
+    return a == std::numeric_limits<T>::infinity();
+  }
+  static __device__ __inline__ bool IsInfNeg(T a) {
+    return a == -std::numeric_limits<T>::infinity();
+  }
+};
+
+template <>
+struct IsInfTyped<half> {
+  static __device__ __inline__ bool IsInf(half a) {
+    return MLFloat16::kPositiveInfinityBits ==
+           static_cast<uint16_t>(*reinterpret_cast<uint16_t*>(&a) & ~MLFloat16::kSignMask);
+  }
+  static __device__ __inline__ bool IsInfPos(half a) {
+    return MLFloat16::kPositiveInfinityBits == *reinterpret_cast<uint16_t*>(&a);
+  }
+  static __device__ __inline__ bool IsInfNeg(half a) {
+    return MLFloat16::kNegativeInfinityBits == *reinterpret_cast<uint16_t*>(&a);
+  }
+};
+
+template <>
+struct IsInfTyped<BFloat16> {
+  static __device__ __inline__ bool IsInf(BFloat16 a) {
+    return BFloat16::kPositiveInfinityBits ==
+           static_cast<uint16_t>(*reinterpret_cast<uint16_t*>(&a) & ~BFloat16::kSignMask);
+  }
+  static __device__ __inline__ bool IsInfPos(BFloat16 a) {
+    return BFloat16::kPositiveInfinityBits == *reinterpret_cast<uint16_t*>(&a);
+  }
+  static __device__ __inline__ bool IsInfNeg(BFloat16 a) {
+    return BFloat16::kNegativeInfinityBits == *reinterpret_cast<uint16_t*>(&a);
+  }
+};
+
+#if !defined(DISABLE_FLOAT8_TYPES)
+
+template <typename T>
+struct ReturnFalse {
+  constexpr static bool __device__ __inline__ IsInf(T) { return false; }
+  constexpr static bool __device__ __inline__ IsInfPos(T) { return false; }
+  constexpr static bool __device__ __inline__ IsInfNeg(T) { return false; }
+};
+
+template <>
+struct IsInfTyped<Float8E4M3FN> : ReturnFalse<Float8E4M3FN> {};
+
+template <>
+struct IsInfTyped<Float8E4M3FNUZ> : ReturnFalse<Float8E4M3FNUZ> {};
+
+template <>
+struct IsInfTyped<Float8E5M2> {
+  static __device__ __inline__ bool IsInf(Float8E5M2 a) {
+    return a.val == 0b01111100 || a.val == 0b11111100;
+  }
+  static __device__ __inline__ bool IsInfPos(Float8E5M2 a) {
+    return a.val == 0b01111100;
+  }
+  static __device__ __inline__ bool IsInfNeg(Float8E5M2 a) {
+    return a.val == 0b11111100;
+  }
+};
+
+template <>
+struct IsInfTyped<Float8E5M2FNUZ> : ReturnFalse<Float8E5M2FNUZ> {};
+
+#endif
+}  // namespace isinf_details
+
+template <typename T, bool detect_positive, bool detect_negative>
+struct _IsInf {
+  __device__ __inline__ bool operator()(T a) const {
+    if constexpr (detect_positive && detect_negative) {
+      return isinf_details::IsInfTyped<T>::IsInf(a);
+    } else if constexpr (detect_positive) {
+      return isinf_details::IsInfTyped<T>::IsInfPos(a);
+    } else if constexpr (detect_negative) {
+      return isinf_details::IsInfTyped<T>::IsInfNeg(a);
+    } else {
+      return false;
+    }
+  }
+};
+
 // We would like to use 64-bit integer to support large matrices. However, CUDA seems to support only 32-bit integer
 // For now, use int32_t to ensure that both Linux and Windows see this as 32 bit integer type.
 #ifndef CUDA_LONG
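
Note how _IsInf resolves the attribute combination with if constexpr, so each instantiation compiles down to a single predicate with no runtime branching on the flags. A hedged sketch of how such a functor could be applied by a generic elementwise kernel (the kernel name and launch shape here are assumptions, not the launcher this change actually uses):

// Illustrative grid-stride kernel applying a predicate functor per element.
template <typename T, typename FuncT>
__global__ void ElementwisePredicateKernel(const T* input, bool* output,
                                           int32_t count, FuncT func) {
  for (int32_t i = blockIdx.x * blockDim.x + threadIdx.x; i < count;
       i += gridDim.x * blockDim.x) {
    output[i] = func(input[i]);
  }
}

// Possible launch: report positive infinities only for float input.
// ElementwisePredicateKernel<<<num_blocks, 256, 0, stream>>>(
//     x, y, count, _IsInf<float, true, false>{});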
18 changes: 18 additions & 0 deletions onnxruntime/core/providers/cuda/cuda_common.h
@@ -70,6 +70,15 @@ class ToCudaType<Float8E4M3FN> {
   }
 };

+template <>
+class ToCudaType<Float8E4M3FNUZ> {
+ public:
+  typedef Float8E4M3FNUZ MappedType;
+  static MappedType FromFloat(float f) {
+    return MappedType(f);
+  }
+};
+
 template <>
 class ToCudaType<Float8E5M2> {
  public:
@@ -79,6 +88,15 @@ class ToCudaType<Float8E5M2> {
   }
 };

+template <>
+class ToCudaType<Float8E5M2FNUZ> {
+ public:
+  typedef Float8E5M2FNUZ MappedType;
+  static MappedType FromFloat(float f) {
+    return MappedType(f);
+  }
+};
+
 #endif

 inline bool CalculateFdmStrides(gsl::span<fast_divmod> p, const std::vector<int64_t>& dims) {
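
The two new specializations fill in the FNUZ float8 variants so that templated CUDA code can map every supported ONNX scalar type to a device-side representation. A typical usage pattern, sketched here under the assumption of a generic launcher (LaunchSomeKernel is hypothetical):

// Hypothetical caller: pick the device-side type via the trait, then reinterpret.
template <typename T>
void LaunchSomeKernel(cudaStream_t stream, const T* data, size_t count) {
  using CudaT = typename ToCudaType<T>::MappedType;
  const CudaT* device_data = reinterpret_cast<const CudaT*>(data);
  // ... launch a kernel templated on CudaT over device_data ...
}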
5 changes: 5 additions & 0 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -830,6 +830,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, MLFloat16, ThresholdedRelu);
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 10, TopK);
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 12, Mod);
+class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 19, IsInf);

 // opset 11
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, Compress);
@@ -1342,6 +1343,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, S
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, float, Gelu);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, double, Gelu);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, MLFloat16, Gelu);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, IsInf);

 template <>
 KernelCreateInfo BuildKernelCreateInfo<void>() {
@@ -1739,6 +1741,8 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 12, int8_t, DequantizeLinear)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 12, uint8_t, DequantizeLinear)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 12, Mod)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10,
+                                                                      19, IsInf)>,

       // opset 11
       BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, float, ArgMax)>,
@@ -2250,6 +2254,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, float, Gelu)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, double, Gelu)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, MLFloat16, Gelu)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, IsInf)>,
 #endif
   };
38 changes: 38 additions & 0 deletions onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc
@@ -71,6 +71,44 @@ Status UnaryElementwise::Prepare(OpKernelContext* context, UnaryElementwisePrepa
     return Status::OK();                                      \
   }

+ONNX_OPERATOR_VERSIONED_KERNEL_EX(
+    IsInf,
+    kOnnxDomain,
+    10,
+    19,
+    kCudaExecutionProvider,
+    (*KernelDefBuilder::Create())
+        .TypeConstraint("T1", BuildKernelDefConstraints<float, double>())
+        .TypeConstraint("T2", DataTypeImpl::GetTensorType<bool>()),
+    IsInf);
+
+ONNX_OPERATOR_KERNEL_EX(
+    IsInf,
+    kOnnxDomain,
+    20,
+    kCudaExecutionProvider,
+    (*KernelDefBuilder::Create())
+        .TypeConstraint("T1", BuildKernelDefConstraints<ISINF_OPSET20_ALL_FLOATS>())
+        .TypeConstraint("T2", DataTypeImpl::GetTensorType<bool>()),
+    IsInf);
+
+IsInf::IsInf(const OpKernelInfo& info) : UnaryElementwise(info) {
+  detect_positive_ = static_cast<bool>(info.GetAttrOrDefault<int64_t>("detect_positive", 1));
+  detect_negative_ = static_cast<bool>(info.GetAttrOrDefault<int64_t>("detect_negative", 1));
+  opset_ = info.node().SinceVersion();
+}
+
+Status IsInf::ComputeInternal(OpKernelContext* context) const {
+  UnaryElementwisePreparation p;
+  ORT_RETURN_IF_ERROR(UnaryElementwise::Prepare(context, &p));
+
+  Explicit_Impl_IsInf(Stream(context), opset_, detect_positive_, detect_negative_,
+                      p.input_tensor->GetElementType(), p.input_tensor->DataRaw(),
+                      p.output_tensor->MutableData<bool>(),
+                      p.input_tensor->Shape().Size());
+  return Status::OK();
+}
+
 #define UNARY_OP_VERSIONED_TYPED(name, startver, endver, T) \
   UNARY_ELEMENTWISE_REGISTER_VERSIONED_KERNEL(name, startver, endver, T)

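Explicit_Impl_IsInf, which dispatches on the runtime element type and opset, is defined in the corresponding .cu implementation file and is not shown on this page. A minimal sketch of the kernel it plausibly launches, reusing the IsInfTyped helpers from common.cuh (the real code routes through the compile-time _IsInf functor rather than runtime flags):

// Sketch only: one thread per element, runtime attribute flags for brevity.
template <typename T>
__global__ void IsInfKernelSketch(const T* x, bool* y, int32_t count,
                                  bool detect_positive, bool detect_negative) {
  int32_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < count) {
    const bool pos = detect_positive && isinf_details::IsInfTyped<T>::IsInfPos(x[i]);
    const bool neg = detect_negative && isinf_details::IsInfTyped<T>::IsInfNeg(x[i]);
    y[i] = pos || neg;
  }
}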
12 changes: 12 additions & 0 deletions onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h
@@ -2,6 +2,7 @@
 // Licensed under the MIT License.

 #pragma once
+
 #include "core/providers/cuda/cuda_kernel.h"

 namespace onnxruntime {
@@ -119,5 +120,16 @@ class Sign final : public UnaryElementwise {
   Status ComputeInternal(OpKernelContext* context) const override;
 };

+class IsInf final : public UnaryElementwise {
+ public:
+  explicit IsInf(const OpKernelInfo& info);
+  Status ComputeInternal(OpKernelContext* context) const override;
+
+ private:
+  bool detect_positive_{true};
+  bool detect_negative_{true};
+  int opset_;
+};
+
 }  // namespace cuda
 }  // namespace onnxruntime