Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement CUDA IsInf-10,20 #19772

Merged
merged 7 commits into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/onnxruntime/core/framework/data_types_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ class CallableDispatchableHelper {
return 0;
}

void CheckCalledOnce() {
void CheckCalledOnce() const {
ORT_ENFORCE(called_ == 1, "Unsupported data type: ", dt_type_);
}
};
Expand Down
58 changes: 49 additions & 9 deletions onnxruntime/core/providers/cpu/tensor/isinf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST(
using IsInfTypesOpset20 =
TypeList<
float,
double
double,
MLFloat16,
BFloat16
#if !defined(DISABLE_FLOAT8_TYPES)
,
Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, Float8E5M2FNUZ
Expand Down Expand Up @@ -87,29 +89,67 @@ namespace isinf_internal {
template <class T>
struct ComputeDispatchTarget {
void operator()(const Tensor& X, Tensor& Y, bool detect_positive, bool detect_negative) const {
const auto total_items = X.Shape().Size();
auto input_data = X.DataAsSpan<T>();
auto output_data = Y.MutableData<bool>();

if (detect_positive && detect_negative) {
EigenMap<bool>(Y) = EigenMap<T>(X).array().isInf();
} else if (detect_positive) {
auto input_data = X.Data<T>();
auto end_data = input_data + total_items;
std::transform(
input_data, end_data, output_data, [](T v) {
input_data.begin(), input_data.end(), output_data, [](T v) {
return (v == std::numeric_limits<T>::infinity());
});

} else if (detect_negative) {
auto input_data = X.Data<T>();
auto end_data = input_data + total_items;
std::transform(
input_data, end_data, output_data, [](T v) {
input_data.begin(), input_data.end(), output_data, [](T v) {
return (v == -std::numeric_limits<T>::infinity());
});
} else {
// all false
memset(output_data, false, onnxruntime::narrow<size_t>(total_items));
memset(output_data, false, input_data.size());
}
}
};

template <>
struct ComputeDispatchTarget<MLFloat16> {
void operator()(const Tensor& X, Tensor& Y, bool detect_positive, bool detect_negative) const {
auto output_data = Y.MutableData<bool>();
auto input_data = X.DataAsSpan<MLFloat16>();
if (detect_positive && detect_negative) {
std::transform(input_data.begin(), input_data.end(), output_data,
[](MLFloat16 v) { return v.IsInfinity(); });
} else if (detect_positive) {
std::transform(input_data.begin(), input_data.end(), output_data,
[](MLFloat16 v) { return v.IsPositiveInfinity(); });
} else if (detect_negative) {
std::transform(input_data.begin(), input_data.end(), output_data,
[](MLFloat16 v) { return v.IsNegativeInfinity(); });
} else {
// all false
memset(output_data, false, input_data.size());
}
}
};

template <>
struct ComputeDispatchTarget<BFloat16> {
void operator()(const Tensor& X, Tensor& Y, bool detect_positive, bool detect_negative) const {
auto output_data = Y.MutableData<bool>();
auto input_data = X.DataAsSpan<BFloat16>();
if (detect_positive && detect_negative) {
std::transform(input_data.begin(), input_data.end(), output_data,
[](BFloat16 v) { return v.IsInfinity(); });
} else if (detect_positive) {
std::transform(input_data.begin(), input_data.end(), output_data,
[](BFloat16 v) { return v.IsPositiveInfinity(); });
} else if (detect_negative) {
std::transform(input_data.begin(), input_data.end(), output_data,
[](BFloat16 v) { return v.IsNegativeInfinity(); });
} else {
// all false
memset(output_data, false, input_data.size());
}
}
};
Expand Down
94 changes: 94 additions & 0 deletions onnxruntime/core/providers/cuda/cu_inc/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,100 @@
return fmodf((float)a, (float)b);
}

namespace isinf_details {
template <typename T>
struct IsInfTyped {
static __device__ __inline__ bool IsInf(T a) {
// cast is needed because on non MS compilers,
// because there isinf() returns int
// and we want to avoid stupid warnings
return static_cast<bool>(isinf(a));
};

Check warning on line 449 in onnxruntime/core/providers/cuda/cu_inc/common.cuh

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 You don't need a ; after a } [readability/braces] [4] Raw Output: onnxruntime/core/providers/cuda/cu_inc/common.cuh:449: You don't need a ; after a } [readability/braces] [4]
static __device__ __inline__ bool IsInfPos(T a) {
return a == std::numeric_limits<T>::infinity();
}
static __device__ __inline__ bool IsInfNeg(T a) {
return a == -std::numeric_limits<T>::infinity();

Check warning on line 454 in onnxruntime/core/providers/cuda/cu_inc/common.cuh

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <limits> for numeric_limits<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/cuda/cu_inc/common.cuh:454: Add #include <limits> for numeric_limits<> [build/include_what_you_use] [4]
}
};

template <>
struct IsInfTyped<half> {
static __device__ __inline__ bool IsInf(half a) {
return MLFloat16::kPositiveInfinityBits ==
static_cast<uint16_t>(*reinterpret_cast<uint16_t*>(&a) & ~MLFloat16::kSignMask);
};

Check warning on line 463 in onnxruntime/core/providers/cuda/cu_inc/common.cuh

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 You don't need a ; after a } [readability/braces] [4] Raw Output: onnxruntime/core/providers/cuda/cu_inc/common.cuh:463: You don't need a ; after a } [readability/braces] [4]
static __device__ __inline__ bool IsInfPos(half a) {
return MLFloat16::kPositiveInfinityBits == *reinterpret_cast<uint16_t*>(&a);
}
static __device__ __inline__ bool IsInfNeg(half a) {
return MLFloat16::kNegativeInfinityBits == *reinterpret_cast<uint16_t*>(&a);
}
};

template <>
struct IsInfTyped<BFloat16> {
static __device__ __inline__ bool IsInf(BFloat16 a) {
return BFloat16::kPositiveInfinityBits ==
static_cast<uint16_t>(*reinterpret_cast<uint16_t*>(&a) & ~BFloat16::kSignMask);
};

Check warning on line 477 in onnxruntime/core/providers/cuda/cu_inc/common.cuh

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 You don't need a ; after a } [readability/braces] [4] Raw Output: onnxruntime/core/providers/cuda/cu_inc/common.cuh:477: You don't need a ; after a } [readability/braces] [4]
static __device__ __inline__ bool IsInfPos(BFloat16 a) {
return BFloat16::kPositiveInfinityBits == *reinterpret_cast<uint16_t*>(&a);
}
static __device__ __inline__ bool IsInfNeg(BFloat16 a) {
return BFloat16::kNegativeInfinityBits == *reinterpret_cast<uint16_t*>(&a);
}
};

#if !defined(DISABLE_FLOAT8_TYPES)

template<typename T>
struct ReturnFalse {
constexpr static bool __device__ __inline__ IsInf(T) { return false; }
constexpr static bool __device__ __inline__ IsInfPos(T) { return false; }
constexpr static bool __device__ __inline__ IsInfNeg(T) { return false; }
};

template <>
struct IsInfTyped<Float8E4M3FN> : ReturnFalse<Float8E4M3FN> {};

template <>
struct IsInfTyped<Float8E4M3FNUZ> : ReturnFalse<Float8E4M3FNUZ> {};

template <>
struct IsInfTyped<Float8E5M2> {
static __device__ __inline__ bool IsInf(Float8E5M2 a) {
return a.val == 0b01111100 || a.val == 0b11111100;
}
static __device__ __inline__ bool IsInfPos(Float8E5M2 a) {
return a.val == 0b01111100;
}
static __device__ __inline__ bool IsInfNeg(Float8E5M2 a) {
return a.val == 0b11111100;
}
};

template <>
struct IsInfTyped<Float8E5M2FNUZ> : ReturnFalse<Float8E5M2FNUZ> {};

#endif
} // namespace isinf_details

template <typename T, bool detect_positive, bool detect_negative>
struct _IsInf {
__device__ __inline__ bool operator()(T a) const {
if constexpr (detect_positive && detect_negative) {
return isinf_details::IsInfTyped<T>::IsInf(a);
} else if constexpr (detect_positive) {
return isinf_details::IsInfTyped<T>::IsInfPos(a);
} else if constexpr (detect_negative) {
return isinf_details::IsInfTyped<T>::IsInfNeg(a);
} else {
return false;
}
}
};

// We would like to use 64-bit integer to support large matrices. However, CUDA seems to support only 32-bit integer
// For now, use int32_t to ensure that both Linux and Windows see this as 32 bit integer type.
#ifndef CUDA_LONG
Expand Down
18 changes: 18 additions & 0 deletions onnxruntime/core/providers/cuda/cuda_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,15 @@ class ToCudaType<Float8E4M3FN> {
}
};

template <>
class ToCudaType<Float8E4M3FNUZ> {
public:
typedef Float8E4M3FNUZ MappedType;
static MappedType FromFloat(float f) {
return MappedType(f);
}
};

template <>
class ToCudaType<Float8E5M2> {
public:
Expand All @@ -79,6 +88,15 @@ class ToCudaType<Float8E5M2> {
}
};

template <>
class ToCudaType<Float8E5M2FNUZ> {
public:
typedef Float8E5M2FNUZ MappedType;
static MappedType FromFloat(float f) {
return MappedType(f);
}
};

#endif

inline bool CalculateFdmStrides(gsl::span<fast_divmod> p, const std::vector<int64_t>& dims) {
Expand Down
5 changes: 5 additions & 0 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -830,6 +830,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, MLFloat16, ThresholdedRelu);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 10, TopK);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 12, Mod);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 19, IsInf);

// opset 11
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, Compress);
Expand Down Expand Up @@ -1339,6 +1340,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, S
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, float, Gelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, double, Gelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, MLFloat16, Gelu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, IsInf);

template <>
KernelCreateInfo BuildKernelCreateInfo<void>() {
Expand Down Expand Up @@ -1736,6 +1738,8 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 12, int8_t, DequantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 12, uint8_t, DequantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 12, Mod)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10,
19, IsInf)>,

// opset 11
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, float, ArgMax)>,
Expand Down Expand Up @@ -2244,6 +2248,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, float, Gelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, double, Gelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, MLFloat16, Gelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, IsInf)>,
#endif
};

Expand Down
50 changes: 50 additions & 0 deletions onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,56 @@ Status UnaryElementwise::Prepare(OpKernelContext* context, UnaryElementwisePrepa
\
return Status::OK(); \
}
#if !defined(DISABLE_FLOAT8_TYPES)
#define OPSET20_CONSTRAINTS float, double, MLFloat16, BFloat16, Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, \
yuslepukhin marked this conversation as resolved.
Show resolved Hide resolved
Float8E5M2FNUZ
#else
#define OPSET20_CONSTRAINTS float, double, MLFloat16, BFloat16
#endif

ONNX_OPERATOR_VERSIONED_KERNEL_EX(
IsInf,
kOnnxDomain,
10,
19,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T1", BuildKernelDefConstraints<float, double>())
.TypeConstraint("T2", DataTypeImpl::GetTensorType<bool>()),
IsInf);

ONNX_OPERATOR_KERNEL_EX(
IsInf,
kOnnxDomain,
20,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T1", BuildKernelDefConstraints<OPSET20_CONSTRAINTS>())
.TypeConstraint("T2", DataTypeImpl::GetTensorType<bool>()),
IsInf);

IsInf::IsInf(const OpKernelInfo& info) : UnaryElementwise(info) {
int64_t detect_positive = 0, detect_negative = 0;
Status status = info.GetAttr("detect_positive", &detect_positive);
yuslepukhin marked this conversation as resolved.
Show resolved Hide resolved
ORT_ENFORCE(status.IsOK(), "Failed to obtain detect_positive");
status = info.GetAttr("detect_negative", &detect_negative);
ORT_ENFORCE(status.IsOK(), "Failed to obtain detect_negative");

detect_positive_ = static_cast<bool>(detect_positive);
detect_negative_ = static_cast<bool>(detect_negative);
opset_ = info.node().SinceVersion();
}

Status IsInf::ComputeInternal(OpKernelContext* context) const {
UnaryElementwisePreparation p;
ORT_RETURN_IF_ERROR(UnaryElementwise::Prepare(context, &p));

Explicit_Impl_IsInf(Stream(context), opset_, detect_positive_, detect_negative_,
p.input_tensor->GetElementType(), p.input_tensor->DataRaw(),
p.output_tensor->MutableData<bool>(),
p.input_tensor->Shape().Size());
return Status::OK();
}

#define UNARY_OP_VERSIONED_TYPED(name, startver, endver, T) \
UNARY_ELEMENTWISE_REGISTER_VERSIONED_KERNEL(name, startver, endver, T)
Expand Down
12 changes: 12 additions & 0 deletions onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

#pragma once

#include "core/providers/cuda/cuda_kernel.h"

namespace onnxruntime {
Expand Down Expand Up @@ -119,5 +120,16 @@ class Sign final : public UnaryElementwise {
Status ComputeInternal(OpKernelContext* context) const override;
};

class IsInf final : public UnaryElementwise {
public:
explicit IsInf(const OpKernelInfo& info);
Status ComputeInternal(OpKernelContext* context) const override;

private:
bool detect_positive_{true};
bool detect_negative_{true};
int opset_;
};

} // namespace cuda
} // namespace onnxruntime
Loading
Loading