Implement IsNaN-9,13,20 for CUDA along with tests #19807

Merged · merged 3 commits · Mar 7, 2024
5 changes: 4 additions & 1 deletion docs/OperatorKernels.md
@@ -162,7 +162,7 @@ Do not modify directly.*
|InstanceNormalization|*in* input:**T**<br> *in* scale:**T**<br> *in* B:**T**<br> *out* output:**T**|6+|**T** = tensor(float)|
|IsInf|*in* X:**T1**<br> *out* Y:**T2**|20+|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)<br/> **T2** = tensor(bool)|
|||[10, 19]|**T1** = tensor(double), tensor(float)<br/> **T2** = tensor(bool)|
|IsNaN|*in* X:**T1**<br> *out* Y:**T2**|20+|**T1** = tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)<br/> **T2** = tensor(bool)|
|IsNaN|*in* X:**T1**<br> *out* Y:**T2**|20+|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)<br/> **T2** = tensor(bool)|
|||[13, 19]|**T1** = tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(bool)|
|||[9, 12]|**T1** = tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(bool)|
|LRN|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(float)|
@@ -633,6 +633,9 @@ Do not modify directly.*
|InstanceNormalization|*in* input:**T**<br> *in* scale:**T**<br> *in* B:**T**<br> *out* output:**T**|6+|**T** = tensor(double), tensor(float), tensor(float16)|
|IsInf|*in* X:**T1**<br> *out* Y:**T2**|20+|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)<br/> **T2** = tensor(bool)|
|||[10, 19]|**T1** = tensor(double), tensor(float)<br/> **T2** = tensor(bool)|
|IsNaN|*in* X:**T1**<br> *out* Y:**T2**|20+|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)<br/> **T2** = tensor(bool)|
|||[13, 19]|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(bool)|
|||[9, 12]|**T1** = tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(bool)|
|LRN|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
|||[1, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
|LSTM|*in* X:**T**<br> *in* W:**T**<br> *in* R:**T**<br> *in* B:**T**<br> *in* sequence_lens:**T1**<br> *in* initial_h:**T**<br> *in* initial_c:**T**<br> *in* P:**T**<br> *out* Y:**T**<br> *out* Y_h:**T**<br> *out* Y_c:**T**|14+|**T** = tensor(double), tensor(float), tensor(float16)<br/> **T1** = tensor(int32)|
4 changes: 4 additions & 0 deletions onnxruntime/core/providers/cpu/cpu_execution_provider.cc
@@ -714,6 +714,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDoma
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, float, IsNaN);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, double, IsNaN);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, MLFloat16, IsNaN);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, BFloat16, IsNaN);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, bool, NonZero);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, NonZero);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, NonZero);
@@ -1023,6 +1024,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, IsNaN);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, IsNaN);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, MLFloat16, IsNaN);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, BFloat16, IsNaN);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Gelu);
#if !defined(DISABLE_FLOAT8_TYPES)
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FN, IsNaN);
@@ -2553,6 +2555,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, IsNaN)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, MLFloat16,
IsNaN)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, BFloat16,
IsNaN)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Gelu)>,
#if !defined(DISABLE_FLOAT8_TYPES)
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FN,
19 changes: 16 additions & 3 deletions onnxruntime/core/providers/cpu/tensor/isnan.cc
@@ -46,9 +46,11 @@ ADD_TYPED_ISNAN_OP_9(MLFloat16);
ADD_TYPED_ISNAN_OP_13(float);
ADD_TYPED_ISNAN_OP_13(double);
ADD_TYPED_ISNAN_OP_13(MLFloat16);
ADD_TYPED_ISNAN_OP_13(BFloat16);
ADD_TYPED_ISNAN_OP(float);
ADD_TYPED_ISNAN_OP(double);
ADD_TYPED_ISNAN_OP(MLFloat16);
ADD_TYPED_ISNAN_OP(BFloat16);

#if !defined(DISABLE_FLOAT8_TYPES)
ADD_TYPED_ISNAN_OP(Float8E4M3FN);
@@ -75,9 +77,7 @@ Status IsNaN<T>::Compute(OpKernelContext* context) const {
template <>
Status IsNaN<MLFloat16>::Compute(OpKernelContext* context) const {
const auto* X_ptr = context->Input<Tensor>(0);
if (!X_ptr) {
return Status(common::ONNXRUNTIME, common::FAIL, "Null input ptr");
}

auto X_data = X_ptr->Data<MLFloat16>();
auto& dims = X_ptr->Shape();
auto shape_size = dims.Size();
@@ -91,6 +91,19 @@ Status IsNaN<MLFloat16>::Compute(OpKernelContext* context) const {
return Status::OK();
}

template <>
Status IsNaN<BFloat16>::Compute(OpKernelContext* context) const {
const auto* X_ptr = context->Input<Tensor>(0);

auto X_data = X_ptr->DataAsSpan<BFloat16>();
auto& Y = *context->Output(0, X_ptr->Shape());

std::transform(X_data.begin(), X_data.end(), Y.MutableData<bool>(),
[](BFloat16 x) { return x.IsNaN(); });

return Status::OK();
}
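For context, `BFloat16::IsNaN` reduces to a bit test: a bfloat16 value is NaN exactly when its exponent bits are all ones and its mantissa is non-zero, i.e. its magnitude bits exceed the +infinity pattern. A minimal host-side sketch (standalone; the function name and raw-bits parameter are illustrative, not ORT API):

```cpp
#include <cstdint>

// Illustrative stand-in for the check BFloat16::IsNaN performs.
// bfloat16 +infinity is 0x7F80; anything with a larger magnitude is NaN.
bool IsBFloat16NaN(uint16_t bits) {
  return (bits & 0x7FFF) > 0x7F80;  // mask off the sign bit, compare to +inf
}
```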

#if !defined(DISABLE_FLOAT8_TYPES)
template <>
Status IsNaN<Float8E4M3FN>::Compute(OpKernelContext* context) const {
59 changes: 58 additions & 1 deletion onnxruntime/core/providers/cuda/cu_inc/common.cuh
@@ -485,7 +485,7 @@

#if !defined(DISABLE_FLOAT8_TYPES)

template<typename T>
template <typename T>
struct ReturnFalse {
constexpr static bool __device__ __inline__ IsInf(T) { return false; }
constexpr static bool __device__ __inline__ IsInfPos(T) { return false; }
@@ -532,6 +532,63 @@
}
};

// float and double
template <typename T>
struct _IsNan {
__device__ __inline__ bool operator()(T a) const {
return isnan(a);
}
};

template <>
struct _IsNan<half> {
__device__ __inline__ bool operator()(half a) const {
    return static_cast<uint16_t>(*reinterpret_cast<const uint16_t*>(&a) & ~MLFloat16::kSignMask)
           > MLFloat16::kPositiveInfinityBits;
⚠️ cpplint (common.cuh:546): Line ends in whitespace. Consider deleting these extra spaces. [whitespace/end_of_line]
}
};

template <>
struct _IsNan<BFloat16> {
__device__ __inline__ bool operator()(BFloat16 a) const {
    return static_cast<uint16_t>(*reinterpret_cast<const uint16_t*>(&a) & ~BFloat16::kSignMask)
           > BFloat16::kPositiveInfinityBits;
⚠️ cpplint (common.cuh:554): Line ends in whitespace. Consider deleting these extra spaces. [whitespace/end_of_line]
}
};

#if !defined(DISABLE_FLOAT8_TYPES)

template<>
struct _IsNan<Float8E4M3FN> {
__device__ __inline__ bool operator()(Float8E4M3FN a) const {
return (*reinterpret_cast<const uint8_t*>(&a) & 0x7f) == 0x7f;
}
};

template<>
struct _IsNan<Float8E4M3FNUZ> {
__device__ __inline__ bool operator()(Float8E4M3FNUZ a) const {
return *reinterpret_cast<const uint8_t*>(&a) == 0x80;
}
};

template<>
struct _IsNan<Float8E5M2> {
__device__ __inline__ bool operator()(Float8E5M2 a) const {
uint8_t c = *reinterpret_cast<const uint8_t*>(&a);
return ((c & 0x7c) == 0x7c) && ((c & 0x03) != 0x00);
}
};

template<>
struct _IsNan<Float8E5M2FNUZ> {
__device__ __inline__ bool operator()(Float8E5M2FNUZ a) const {
return *reinterpret_cast<const uint8_t*>(&a) == 0x80;
}
};

#endif
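These predicates mirror the IEEE-style bit layouts of each type. A hedged host-side cross-check of the same patterns (plain C++, not part of the PR; the constants are written out rather than taken from the ORT headers):

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // half / bfloat16: NaN <=> magnitude bits exceed the +infinity pattern.
  auto is_nan_fp16 = [](uint16_t b) { return (b & 0x7FFFu) > 0x7C00u; };
  auto is_nan_bf16 = [](uint16_t b) { return (b & 0x7FFFu) > 0x7F80u; };
  // float8, per the specializations above: E4M3FN reserves magnitude 0x7F,
  // E5M2 needs exponent 0x7C plus a non-zero mantissa, the FNUZ types use 0x80.
  auto is_nan_e4m3fn = [](uint8_t b) { return (b & 0x7Fu) == 0x7Fu; };
  auto is_nan_e5m2 = [](uint8_t b) { return ((b & 0x7Cu) == 0x7Cu) && ((b & 0x03u) != 0u); };

  assert(is_nan_fp16(0x7E00));   // canonical fp16 NaN
  assert(!is_nan_fp16(0x7C00));  // +infinity is not NaN
  assert(is_nan_bf16(0x7FC0));   // canonical bf16 NaN
  assert(is_nan_e4m3fn(0xFFu) && is_nan_e4m3fn(0x7Fu));
  assert(is_nan_e5m2(0x7Du) && !is_nan_e5m2(0x7Cu));
  return 0;
}
```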

// We would like to use 64-bit integer to support large matrices. However, CUDA seems to support only 32-bit integer
// For now, use int32_t to ensure that both Linux and Windows see this as 32 bit integer type.
#ifndef CUDA_LONG
7 changes: 6 additions & 1 deletion onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -746,6 +746,7 @@
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, uint32_t, Cast);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, uint64_t, Cast);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, bool, Cast);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, IsNaN);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 2, 10, float, Pad);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 2, 10, double, Pad);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 2, 10, MLFloat16, Pad);
@@ -938,7 +939,6 @@

// OpSet 12
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, Clip);

class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, float, MaxPool);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, double, MaxPool);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, MLFloat16, MaxPool);
@@ -1087,6 +1087,7 @@
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Concat);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Gather);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, GatherElements);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 19, IsNaN);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, MatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, MatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, MatMul);
@@ -1368,6 +1369,7 @@
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, double, Gelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, MLFloat16, Gelu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, IsInf);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, IsNaN);

template <>
KernelCreateInfo BuildKernelCreateInfo<void>() {
@@ -1553,6 +1555,7 @@
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, float, Erf)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, double, Erf)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, MLFloat16, Erf)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, IsNaN)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, bool, Not)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, float, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, double, BatchNormalization)>,
@@ -1979,6 +1982,7 @@
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, uint32_t, Cast)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, uint64_t, Cast)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, bool, Cast)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 19, IsNaN)>,

⚠️ cpplint (cuda_execution_provider.cc:1985): Lines should be <= 120 characters long [whitespace/line_length]
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, Reshape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 14, Shape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Size)>,
@@ -2279,6 +2283,7 @@
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, double, Gelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, MLFloat16, Gelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, IsInf)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, IsNaN)>,
#endif
};

44 changes: 44 additions & 0 deletions onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc
@@ -109,6 +109,50 @@ Status IsInf::ComputeInternal(OpKernelContext* context) const {
return Status::OK();
}

// IsNan
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
IsNaN,
kOnnxDomain,
9,
12,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T1", BuildKernelDefConstraints<ISNAN_OPSET9_FLOATS>())
.TypeConstraint("T2", DataTypeImpl::GetTensorType<bool>()),
IsNaN);

ONNX_OPERATOR_VERSIONED_KERNEL_EX(
IsNaN,
kOnnxDomain,
13,
19,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T1", BuildKernelDefConstraints<ISNAN_OPSET13_FLOATS>())
.TypeConstraint("T2", DataTypeImpl::GetTensorType<bool>()),
IsNaN);

ONNX_OPERATOR_KERNEL_EX(
IsNaN,
kOnnxDomain,
20,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T1", BuildKernelDefConstraints<ISNAN_OPSET20_FLOATS>())
.TypeConstraint("T2", DataTypeImpl::GetTensorType<bool>()),
IsNaN);

Status IsNaN::ComputeInternal(OpKernelContext* context) const {
UnaryElementwisePreparation p;
ORT_RETURN_IF_ERROR(UnaryElementwise::Prepare(context, &p));

Explicit_Impl_IsNan(Stream(context), p.input_tensor->GetElementType(), p.input_tensor->DataRaw(),
p.output_tensor->MutableData<bool>(),
p.input_tensor->Shape().Size());

return Status::OK();
}
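The PR title mentions accompanying tests; with ORT's usual OpTester harness, a case exercising this kernel might look like the sketch below (opset, shape, and values illustrative; assumes the standard test utilities from onnxruntime/test/providers):

```cpp
#include <cmath>
#include <limits>

// Hedged sketch of an IsNaN opset-20 test in the OpTester style.
TEST(IsNaNOpTest, IsNaNFloat20) {
  OpTester test("IsNaN", 20, kOnnxDomain);
  const std::vector<int64_t> dims{2, 2};
  const float nan = std::numeric_limits<float>::quiet_NaN();
  test.AddInput<float>("X", dims, {1.0f, nan, 2.0f, nan});
  test.AddOutput<bool>("Y", dims, {false, true, false, true});
  test.Run();
}
```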

#define UNARY_OP_VERSIONED_TYPED(name, startver, endver, T) \
UNARY_ELEMENTWISE_REGISTER_VERSIONED_KERNEL(name, startver, endver, T)

6 changes: 6 additions & 0 deletions onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h
@@ -131,5 +131,11 @@ class IsInf final : public UnaryElementwise {
int opset_;
};

class IsNaN : public UnaryElementwise {
public:
explicit IsNaN(const OpKernelInfo& info) : UnaryElementwise(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
};

} // namespace cuda
} // namespace onnxruntime
24 changes: 22 additions & 2 deletions onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu
@@ -315,13 +315,33 @@ void Explicit_Impl_IsInf(cudaStream_t stream, int op_set,
if (op_set < 20) {
utils::MLTypeCallDispatcher<float, double> dispatcher{input_data_type};
dispatcher.Invoke<isinf_details::IsInf_DispFunc>(stream, input_raw, output_data,
detect_positive, detect_negative, count);
detect_positive, detect_negative, count);
} else {
utils::MLTypeCallDispatcher<ISINF_OPSET20_ALL_FLOATS> dispatcher{input_data_type};
dispatcher.Invoke<isinf_details::IsInf_DispFunc>(stream, input_raw, output_data,
detect_positive, detect_negative, count);
detect_positive, detect_negative, count);
}
}

// IsNan

namespace isnan_details {
template <typename T>
struct IsNan_Disp {
void operator()(cudaStream_t stream, const void* input_raw, bool* output_data, size_t count) const {
using CudaType = typename ToCudaType<T>::MappedType;
const auto* input_data = reinterpret_cast<const CudaType*>(input_raw);
UnaryElementWiseImpl(stream, input_data, output_data, _IsNan<CudaType>{}, count);
}
};
} // namespace isnan_details

void Explicit_Impl_IsNan(cudaStream_t stream, int32_t input_data_type,
const void* input_raw, bool* output_data, size_t count) {
// KernelDef constraints ensure that only a subset of data types reaches this dispatcher.
utils::MLTypeCallDispatcher<ISNAN_OPSET20_FLOATS> dispatcher{input_data_type};
dispatcher.Invoke<isnan_details::IsNan_Disp>(stream, input_raw, output_data, count);
}
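Conceptually, MLTypeCallDispatcher selects the IsNan_Disp<T> instantiation that matches the runtime element-type tag. A simplified, hand-written stand-in for what the dispatch amounts to (illustrative only; the real dispatcher covers every type in ISNAN_OPSET20_FLOATS):

```cpp
// Sketch: explicit switch equivalent for two of the dispatched types.
void IsNanDispatchSketch(cudaStream_t stream, int32_t elem_type,
                         const void* input_raw, bool* output_data, size_t count) {
  switch (elem_type) {
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
      isnan_details::IsNan_Disp<float>{}(stream, input_raw, output_data, count);
      break;
    case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16:
      isnan_details::IsNan_Disp<BFloat16>{}(stream, input_raw, output_data, count);
      break;
    default:
      break;  // KernelDef type constraints keep other element types out
  }
}
```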

} // namespace cuda
} // namespace onnxruntime
14 changes: 14 additions & 0 deletions onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h
@@ -151,6 +151,20 @@ void Explicit_Impl_IsInf(cudaStream_t stream, int op_set,
int32_t input_data_type,
const void* input_raw, bool* output_data,
size_t count);

// IsNan
#define ISNAN_OPSET9_FLOATS float, double, MLFloat16
#define ISNAN_OPSET13_FLOATS float, double, MLFloat16, BFloat16
#if !defined(DISABLE_FLOAT8_TYPES)
#define ISNAN_OPSET20_FLOATS float, double, MLFloat16, BFloat16, Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, \
Float8E5M2FNUZ
#else
#define ISNAN_OPSET20_FLOATS ISNAN_OPSET13_FLOATS
#endif

void Explicit_Impl_IsNan(cudaStream_t stream, int32_t input_data_type,
const void* input_raw, bool* output_data, size_t count);

} // namespace cuda

} // namespace onnxruntime
57 changes: 57 additions & 0 deletions onnxruntime/core/providers/rocm/cu_inc/common.cuh
@@ -429,6 +429,63 @@
}
};

// float and double
template <typename T>
struct _IsNan {
__device__ __inline__ bool operator()(T a) const {
return isnan(a);
}
};

template <>
struct _IsNan<half> {
__device__ __inline__ bool operator()(half a) const {
    return static_cast<uint16_t>(*reinterpret_cast<const uint16_t*>(&a) & ~MLFloat16::kSignMask)
           > MLFloat16::kPositiveInfinityBits;
⚠️ cpplint (rocm/cu_inc/common.cuh:443): Line ends in whitespace. Consider deleting these extra spaces. [whitespace/end_of_line]
}
};

template <>
struct _IsNan<BFloat16> {
__device__ __inline__ bool operator()(BFloat16 a) const {
    return static_cast<uint16_t>(*reinterpret_cast<const uint16_t*>(&a) & ~BFloat16::kSignMask)
           > BFloat16::kPositiveInfinityBits;
⚠️ cpplint (rocm/cu_inc/common.cuh:451): Line ends in whitespace. Consider deleting these extra spaces. [whitespace/end_of_line]
}
};

#if !defined(DISABLE_FLOAT8_TYPES)

template <>
struct _IsNan<Float8E4M3FN> {
__device__ __inline__ bool operator()(Float8E4M3FN a) const {
return (*reinterpret_cast<const uint8_t*>(&a) & 0x7f) == 0x7f;
}
};

template <>
struct _IsNan<Float8E4M3FNUZ> {
__device__ __inline__ bool operator()(Float8E4M3FNUZ a) const {
return *reinterpret_cast<const uint8_t*>(&a) == 0x80;
}
};

template <>
struct _IsNan<Float8E5M2> {
__device__ __inline__ bool operator()(Float8E5M2 a) const {
uint8_t c = *reinterpret_cast<const uint8_t*>(&a);
return ((c & 0x7c) == 0x7c) && ((c & 0x03) != 0x00);
}
};

template <>
struct _IsNan<Float8E5M2FNUZ> {
__device__ __inline__ bool operator()(Float8E5M2FNUZ a) const {
return *reinterpret_cast<const uint8_t*>(&a) == 0x80;
}
};

#endif

// We would like to use 64-bit integer to support large matrices. However, ROCM seems to support only 32-bit integer
// For now, use int32_t to ensure that both Linux and Windows see this as 32 bit integer type.
#ifndef HIP_LONG