diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 4514a85531d6b..9f5cd4cc842dc 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -162,7 +162,7 @@ Do not modify directly.*
|InstanceNormalization|*in* input:**T**
*in* scale:**T**
*in* B:**T**
*out* output:**T**|6+|**T** = tensor(float)|
|IsInf|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**T2** = tensor(bool)|
|||[10, 19]|**T1** = tensor(double), tensor(float)
**T2** = tensor(bool)|
-|IsNaN|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**T2** = tensor(bool)|
+|IsNaN|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**T2** = tensor(bool)|
|||[13, 19]|**T1** = tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)|
|||[9, 12]|**T1** = tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)|
|LRN|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(float)|
@@ -633,6 +633,9 @@ Do not modify directly.*
|InstanceNormalization|*in* input:**T**
*in* scale:**T**
*in* B:**T**
*out* output:**T**|6+|**T** = tensor(double), tensor(float), tensor(float16)|
|IsInf|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**T2** = tensor(bool)|
|||[10, 19]|**T1** = tensor(double), tensor(float)
**T2** = tensor(bool)|
+|IsNaN|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**T2** = tensor(bool)|
+|||[13, 19]|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)|
+|||[9, 12]|**T1** = tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)|
|LRN|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
|||[1, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
|LSTM|*in* X:**T**
*in* W:**T**
*in* R:**T**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*in* initial_c:**T**
*in* P:**T**
*out* Y:**T**
*out* Y_h:**T**
*out* Y_c:**T**|14+|**T** = tensor(double), tensor(float), tensor(float16)
**T1** = tensor(int32)|
diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
index 7e0f919deb0a7..c3d5a51b636ef 100644
--- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
+++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
@@ -714,6 +714,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDoma
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, float, IsNaN);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, double, IsNaN);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, MLFloat16, IsNaN);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, BFloat16, IsNaN);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, bool, NonZero);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, NonZero);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, NonZero);
@@ -1023,6 +1024,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, IsNaN);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, IsNaN);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, MLFloat16, IsNaN);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, BFloat16, IsNaN);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Gelu);
#if !defined(DISABLE_FLOAT8_TYPES)
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FN, IsNaN);
@@ -2553,6 +2555,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
#if !defined(DISABLE_FLOAT8_TYPES)
BuildKernelCreateInfo::Compute(OpKernelContext* context) const {
template <>
Status IsNaN::Compute(OpKernelContext* context) const {
const auto* X_ptr = context->Input(0);
- if (!X_ptr) {
- return Status(common::ONNXRUNTIME, common::FAIL, "Null input ptr");
- }
+
auto X_data = X_ptr->Data();
auto& dims = X_ptr->Shape();
auto shape_size = dims.Size();
@@ -91,6 +91,19 @@ Status IsNaN::Compute(OpKernelContext* context) const {
return Status::OK();
}
+template <>
+Status IsNaN::Compute(OpKernelContext* context) const {
+ const auto* X_ptr = context->Input(0);
+
+ auto X_data = X_ptr->DataAsSpan();
+ auto& Y = *context->Output(0, X_ptr->Shape());
+
+ std::transform(X_data.begin(), X_data.end(), Y.MutableData(),
+ [](BFloat16 x) { return x.IsNaN(); });
+
+ return Status::OK();
+}
+
#if !defined(DISABLE_FLOAT8_TYPES)
template <>
Status IsNaN::Compute(OpKernelContext* context) const {
diff --git a/onnxruntime/core/providers/cuda/cu_inc/common.cuh b/onnxruntime/core/providers/cuda/cu_inc/common.cuh
index bba9178348132..bed2f677166d6 100644
--- a/onnxruntime/core/providers/cuda/cu_inc/common.cuh
+++ b/onnxruntime/core/providers/cuda/cu_inc/common.cuh
@@ -485,7 +485,7 @@ struct IsInfTyped {
#if !defined(DISABLE_FLOAT8_TYPES)
-template
+template
struct ReturnFalse {
constexpr static bool __device__ __inline__ IsInf(T) { return false; }
constexpr static bool __device__ __inline__ IsInfPos(T) { return false; }
@@ -532,6 +532,63 @@ struct _IsInf {
}
};
+// float and double
+template
+struct _IsNan {
+ __device__ __inline__ bool operator()(T a) const {
+ return isnan(a);
+ }
+};
+
+template <>
+struct _IsNan {
+ __device__ __inline__ bool operator()(half a) const {
+ return static_cast(*reinterpret_cast(&a) & ~MLFloat16::kSignMask)
+ > MLFloat16::kPositiveInfinityBits;
+ }
+};
+
+template <>
+struct _IsNan {
+ __device__ __inline__ bool operator()(BFloat16 a) const {
+ return static_cast(*reinterpret_cast(&a) & ~BFloat16::kSignMask)
+ > BFloat16::kPositiveInfinityBits;
+ }
+};
+
+#if !defined(DISABLE_FLOAT8_TYPES)
+
+template<>
+struct _IsNan {
+ __device__ __inline__ bool operator()(Float8E4M3FN a) const {
+ return (*reinterpret_cast(&a) & 0x7f) == 0x7f;
+ }
+};
+
+template<>
+struct _IsNan {
+ __device__ __inline__ bool operator()(Float8E4M3FNUZ a) const {
+ return *reinterpret_cast(&a) == 0x80;
+ }
+};
+
+template<>
+struct _IsNan {
+ __device__ __inline__ bool operator()(Float8E5M2 a) const {
+ uint8_t c = *reinterpret_cast(&a);
+ return ((c & 0x7c) == 0x7c) && ((c & 0x03) != 0x00);
+ }
+};
+
+template<>
+struct _IsNan {
+ __device__ __inline__ bool operator()(Float8E5M2FNUZ a) const {
+ return *reinterpret_cast(&a) == 0x80;
+ }
+};
+
+#endif
+
// We would like to use 64-bit integer to support large matrices. However, CUDA seems to support only 32-bit integer
// For now, use int32_t to ensure that both Linux and Windows see this as 32 bit integer type.
#ifndef CUDA_LONG
diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
index bade2faf8f2e2..18c7334af6611 100644
--- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
+++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -746,6 +746,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, uint32_t, Cast);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, uint64_t, Cast);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, bool, Cast);
+class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, IsNaN);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 2, 10, float, Pad);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 2, 10, double, Pad);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 2, 10, MLFloat16, Pad);
@@ -938,7 +939,6 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDom
// OpSet 12
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, Clip);
-
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, float, MaxPool);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, double, MaxPool);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, MLFloat16, MaxPool);
@@ -1087,6 +1087,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, U
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Concat);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Gather);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, GatherElements);
+class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 19, IsNaN);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, MatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, MatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, MatMul);
@@ -1368,6 +1369,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, double, Gelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, MLFloat16, Gelu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, IsInf);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, IsNaN);
template <>
KernelCreateInfo BuildKernelCreateInfo() {
@@ -1553,6 +1555,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
@@ -1979,6 +1982,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
@@ -2279,6 +2283,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
#endif
};
diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc
index 00de1b37f3302..24593b255371c 100644
--- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc
+++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc
@@ -109,6 +109,50 @@ Status IsInf::ComputeInternal(OpKernelContext* context) const {
return Status::OK();
}
+// IsNan
+ONNX_OPERATOR_VERSIONED_KERNEL_EX(
+ IsNaN,
+ kOnnxDomain,
+ 9,
+ 12,
+ kCudaExecutionProvider,
+ (*KernelDefBuilder::Create())
+ .TypeConstraint("T1", BuildKernelDefConstraints())
+ .TypeConstraint("T2", DataTypeImpl::GetTensorType()),
+ IsNaN);
+
+ONNX_OPERATOR_VERSIONED_KERNEL_EX(
+ IsNaN,
+ kOnnxDomain,
+ 13,
+ 19,
+ kCudaExecutionProvider,
+ (*KernelDefBuilder::Create())
+ .TypeConstraint("T1", BuildKernelDefConstraints())
+ .TypeConstraint("T2", DataTypeImpl::GetTensorType()),
+ IsNaN);
+
+ONNX_OPERATOR_KERNEL_EX(
+ IsNaN,
+ kOnnxDomain,
+ 20,
+ kCudaExecutionProvider,
+ (*KernelDefBuilder::Create())
+ .TypeConstraint("T1", BuildKernelDefConstraints())
+ .TypeConstraint("T2", DataTypeImpl::GetTensorType()),
+ IsNaN);
+
+Status IsNaN::ComputeInternal(OpKernelContext* context) const {
+ UnaryElementwisePreparation p;
+ ORT_RETURN_IF_ERROR(UnaryElementwise::Prepare(context, &p));
+
+ Explicit_Impl_IsNan(Stream(context), p.input_tensor->GetElementType(), p.input_tensor->DataRaw(),
+ p.output_tensor->MutableData(),
+ p.input_tensor->Shape().Size());
+
+ return Status::OK();
+}
+
#define UNARY_OP_VERSIONED_TYPED(name, startver, endver, T) \
UNARY_ELEMENTWISE_REGISTER_VERSIONED_KERNEL(name, startver, endver, T)
diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h
index 3b7d6df7221b7..95d68b5e1d534 100644
--- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h
+++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h
@@ -131,5 +131,11 @@ class IsInf final : public UnaryElementwise {
int opset_;
};
+class IsNaN : public UnaryElementwise {
+ public:
+ explicit IsNaN(const OpKernelInfo& info) : UnaryElementwise(info) {}
+ Status ComputeInternal(OpKernelContext* context) const override;
+};
+
} // namespace cuda
} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu
index 554d5908cf854..2cdfcda5be26a 100644
--- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu
+++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu
@@ -315,13 +315,33 @@ void Explicit_Impl_IsInf(cudaStream_t stream, int op_set,
if (op_set < 20) {
utils::MLTypeCallDispatcher dispatcher{input_data_type};
dispatcher.Invoke(stream, input_raw, output_data,
- detect_positive, detect_negative, count);
+ detect_positive, detect_negative, count);
} else {
utils::MLTypeCallDispatcher dispatcher{input_data_type};
dispatcher.Invoke(stream, input_raw, output_data,
- detect_positive, detect_negative, count);
+ detect_positive, detect_negative, count);
}
}
+// IsNan
+
+namespace isnan_details {
+template
+struct IsNan_Disp {
+ void operator()(cudaStream_t stream, const void* input_raw, bool* output_data, size_t count) const {
+ using CudaType = typename ToCudaType::MappedType;
+ const auto* input_data = reinterpret_cast(input_raw);
+ UnaryElementWiseImpl(stream, input_data, output_data, _IsNan{}, count);
+ }
+};
+} // namespace isnan_details
+
+void Explicit_Impl_IsNan(cudaStream_t stream, int32_t input_data_type,
+ const void* input_raw, bool* output_data, size_t count) {
+ // KernelDef constraints would ensure only subset of datatypes is used.
+ utils::MLTypeCallDispatcher dispatcher{input_data_type};
+ dispatcher.Invoke(stream, input_raw, output_data, count);
+}
+
} // namespace cuda
} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h
index a606d479bc79b..2588f56e32c12 100644
--- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h
+++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h
@@ -151,6 +151,20 @@ void Explicit_Impl_IsInf(cudaStream_t stream, int op_set,
int32_t input_data_type,
const void* input_raw, bool* output_data,
size_t count);
+
+// IsNan
+#define ISNAN_OPSET9_FLOATS float, double, MLFloat16
+#define ISNAN_OPSET13_FLOATS float, double, MLFloat16, BFloat16
+#if !defined(DISABLE_FLOAT8_TYPES)
+#define ISNAN_OPSET20_FLOATS float, double, MLFloat16, BFloat16, Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, \
+ Float8E5M2FNUZ
+#else
+#define ISNAN_OPSET20_FLOATS ISNAN_OPSET13_FLOATS
+#endif
+
+void Explicit_Impl_IsNan(cudaStream_t stream, int32_t input_data_type,
+ const void* input_raw, bool* output_data, size_t count);
+
} // namespace cuda
} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/rocm/cu_inc/common.cuh b/onnxruntime/core/providers/rocm/cu_inc/common.cuh
index f3685606c17f5..1698e5ca8478c 100644
--- a/onnxruntime/core/providers/rocm/cu_inc/common.cuh
+++ b/onnxruntime/core/providers/rocm/cu_inc/common.cuh
@@ -429,6 +429,63 @@ struct _IsInf {
}
};
+// float and double
+template
+struct _IsNan {
+ __device__ __inline__ bool operator()(T a) const {
+ return isnan(a);
+ }
+};
+
+template <>
+struct _IsNan {
+ __device__ __inline__ bool operator()(half a) const {
+ return static_cast(*reinterpret_cast(&a) & ~MLFloat16::kSignMask)
+ > MLFloat16::kPositiveInfinityBits;
+ }
+};
+
+template <>
+struct _IsNan {
+ __device__ __inline__ bool operator()(BFloat16 a) const {
+ return static_cast(*reinterpret_cast(&a) & ~BFloat16::kSignMask)
+ > BFloat16::kPositiveInfinityBits;
+ }
+};
+
+#if !defined(DISABLE_FLOAT8_TYPES)
+
+template <>
+struct _IsNan {
+ __device__ __inline__ bool operator()(Float8E4M3FN a) const {
+ return (*reinterpret_cast(&a) & 0x7f) == 0x7f;
+ }
+};
+
+template <>
+struct _IsNan {
+ __device__ __inline__ bool operator()(Float8E4M3FNUZ a) const {
+ return *reinterpret_cast(&a) == 0x80;
+ }
+};
+
+template <>
+struct _IsNan {
+ __device__ __inline__ bool operator()(Float8E5M2 a) const {
+ uint8_t c = *reinterpret_cast(&a);
+ return ((c & 0x7c) == 0x7c) && ((c & 0x03) != 0x00);
+ }
+};
+
+template <>
+struct _IsNan {
+ __device__ __inline__ bool operator()(Float8E5M2FNUZ a) const {
+ return *reinterpret_cast(&a) == 0x80;
+ }
+};
+
+#endif
+
// We would like to use 64-bit integer to support large matrices. However, ROCM seems to support only 32-bit integer
// For now, use int32_t to ensure that both Linux and Windows see this as 32 bit integer type.
#ifndef HIP_LONG
diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
index 32be74550951e..87daaeea969ac 100644
--- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
+++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
@@ -734,6 +734,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, float, Shrink);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, double, Shrink);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, MLFloat16, Shrink);
+class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, IsNaN);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, float, Less);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, double, Less);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, Less);
@@ -1067,6 +1068,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, uint32_t, Cast);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, uint64_t, Cast);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, bool, Cast);
+class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 19, IsNaN);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, Reshape);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 14, Shape);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Size);
@@ -1346,6 +1348,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, S
// Opset 20
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 20, IsInf);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 20, IsNaN);
template <>
KernelCreateInfo BuildKernelCreateInfo() {
@@ -1531,6 +1534,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
// BuildKernelCreateInfo,
@@ -1941,6 +1945,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
@@ -2304,6 +2309,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
// opset 20
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
};
for (auto& function_table_entry : function_table) {
diff --git a/onnxruntime/test/providers/cpu/tensor/isnan_test.cc b/onnxruntime/test/providers/cpu/tensor/isnan_test.cc
index 0f1e5c07cdd9b..3cf99fde2cce7 100644
--- a/onnxruntime/test/providers/cpu/tensor/isnan_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/isnan_test.cc
@@ -38,9 +38,23 @@ TEST(IsNaNOpTest, IsNaNFloat16_9) {
run_is_nan_test(9, dims, input, output);
}
+TEST(IsNaNOpTest, IsNaNFloat16_13) {
+ std::vector dims{2, 2};
+ std::initializer_list input = {MLFloat16::One, MLFloat16::NaN, MLFloat16(2.0f), MLFloat16::NaN};
+ std::initializer_list output = {false, true, false, true};
+ run_is_nan_test(13, dims, input, output);
+}
+
TEST(IsNaNOpTest, IsNaNFloat16_20) {
std::vector dims{2, 2};
- std::initializer_list input = {MLFloat16(1.0f), MLFloat16::NaN, MLFloat16(2.0f), MLFloat16::NaN};
+ std::initializer_list input = {MLFloat16::One, MLFloat16::NaN, MLFloat16(2.0f), MLFloat16::NaN};
+ std::initializer_list output = {false, true, false, true};
+ run_is_nan_test(20, dims, input, output);
+}
+
+TEST(IsNaNOpTest, IsNaNBFloat16_20) {
+ std::vector dims{2, 2};
+ std::initializer_list input = {BFloat16::One, BFloat16::NaN, BFloat16(2.0f), BFloat16::NaN};
std::initializer_list output = {false, true, false, true};
run_is_nan_test(20, dims, input, output);
}