diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 71b0def659741..4514a85531d6b 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -160,7 +160,7 @@ Do not modify directly.*
|||[1, 10]|**B** = tensor(bool)<br/> **V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|ImageScaler|*in* input:**T**<br> *out* output:**T**|1+|**T** = tensor(float)|
|InstanceNormalization|*in* input:**T**<br> *in* scale:**T**<br> *in* B:**T**<br> *out* output:**T**|6+|**T** = tensor(float)|
-|IsInf|*in* X:**T1**<br> *out* Y:**T2**|20+|**T1** = tensor(double), tensor(float), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)<br/> **T2** = tensor(bool)|
+|IsInf|*in* X:**T1**<br> *out* Y:**T2**|20+|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)<br/> **T2** = tensor(bool)|
|||[10, 19]|**T1** = tensor(double), tensor(float)<br/> **T2** = tensor(bool)|
|IsNaN|*in* X:**T1**<br> *out* Y:**T2**|20+|**T1** = tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)<br/> **T2** = tensor(bool)|
|||[13, 19]|**T1** = tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(bool)|
@@ -631,6 +631,8 @@ Do not modify directly.*
|||[1, 10]|**B** = tensor(bool)<br/> **V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|ImageScaler|*in* input:**T**<br> *out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|InstanceNormalization|*in* input:**T**<br> *in* scale:**T**<br> *in* B:**T**<br> *out* output:**T**|6+|**T** = tensor(double), tensor(float), tensor(float16)|
+|IsInf|*in* X:**T1**<br> *out* Y:**T2**|20+|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)<br/> **T2** = tensor(bool)|
+|||[10, 19]|**T1** = tensor(double), tensor(float)<br/> **T2** = tensor(bool)|
|LRN|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
|||[1, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
|LSTM|*in* X:**T**<br> *in* W:**T**<br> *in* R:**T**<br> *in* B:**T**<br> *in* sequence_lens:**T1**<br> *in* initial_h:**T**<br> *in* initial_c:**T**<br> *in* P:**T**<br> *out* Y:**T**<br> *out* Y_h:**T**<br> *out* Y_c:**T**|14+|**T** = tensor(double), tensor(float), tensor(float16)<br/> **T1** = tensor(int32)|
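
The table rows above encode the IsInf contract: X is tested elementwise, Y is a bool tensor, and the detect_positive/detect_negative attributes (both defaulting to 1) select which infinities count. A minimal host-side C++ sketch of those semantics, independent of ONNX Runtime (is_inf_ref is a hypothetical reference helper, not an ORT API):

#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

// Reference semantics of ONNX IsInf: the two attributes select which
// infinities map to true; NaN and finite values always map to false.
bool is_inf_ref(float v, bool detect_positive, bool detect_negative) {
  if (detect_positive && v == std::numeric_limits<float>::infinity()) return true;
  if (detect_negative && v == -std::numeric_limits<float>::infinity()) return true;
  return false;
}

int main() {
  const float inf = std::numeric_limits<float>::infinity();
  std::vector<float> xs{-1.7f, NAN, inf, 3.6f, -inf};
  for (float x : xs)
    std::printf("%d ", is_inf_ref(x, /*detect_positive=*/true, /*detect_negative=*/false));
  std::printf("\n");  // prints: 0 0 1 0 0
}
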
diff --git a/include/onnxruntime/core/framework/data_types_internal.h b/include/onnxruntime/core/framework/data_types_internal.h
index fbeee8a2aedc5..3a3b5cb6888f2 100644
--- a/include/onnxruntime/core/framework/data_types_internal.h
+++ b/include/onnxruntime/core/framework/data_types_internal.h
@@ -305,7 +305,7 @@ class CallableDispatchableHelper {
return 0;
}
- void CheckCalledOnce() {
+ void CheckCalledOnce() const {
ORT_ENFORCE(called_ == 1, "Unsupported data type: ", dt_type_);
}
};
diff --git a/onnxruntime/core/providers/cpu/tensor/isinf.cc b/onnxruntime/core/providers/cpu/tensor/isinf.cc
index 1b449f46927a2..9d18d1fa62288 100644
--- a/onnxruntime/core/providers/cpu/tensor/isinf.cc
+++ b/onnxruntime/core/providers/cpu/tensor/isinf.cc
@@ -23,7 +23,9 @@ ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST(
using IsInfTypesOpset20 =
TypeList<
float,
- double
+ double,
+ MLFloat16,
+ BFloat16
#if !defined(DISABLE_FLOAT8_TYPES)
,
Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, Float8E5M2FNUZ
@@ -76,10 +78,8 @@ ONNX_CPU_OPERATOR_KERNEL(
IsInf);
IsInf::IsInf(const OpKernelInfo& info) : OpKernel(info) {
- Status status = info.GetAttr("detect_positive", &detect_positive_);
- ORT_ENFORCE(status.IsOK(), "Failed to obtain detect_positive");
- status = info.GetAttr("detect_negative", &detect_negative_);
- ORT_ENFORCE(status.IsOK(), "Failed to obtain detect_negative");
+ detect_positive_ = info.GetAttrOrDefault<int64_t>("detect_positive", 1);
+ detect_negative_ = info.GetAttrOrDefault<int64_t>("detect_negative", 1);
opset_ = info.node().SinceVersion();
}
@@ -87,29 +87,67 @@ namespace isinf_internal {
template <typename T>
struct ComputeDispatchTarget {
void operator()(const Tensor& X, Tensor& Y, bool detect_positive, bool detect_negative) const {
- const auto total_items = X.Shape().Size();
+ auto input_data = X.DataAsSpan<T>();
auto output_data = Y.MutableData<bool>();
if (detect_positive && detect_negative) {
EigenMap<bool>(Y) = EigenMap<T>(X).array().isInf();
} else if (detect_positive) {
- auto input_data = X.Data<T>();
- auto end_data = input_data + total_items;
std::transform(
- input_data, end_data, output_data, [](T v) {
+ input_data.begin(), input_data.end(), output_data, [](T v) {
return (v == std::numeric_limits<T>::infinity());
});
} else if (detect_negative) {
- auto input_data = X.Data<T>();
- auto end_data = input_data + total_items;
std::transform(
- input_data, end_data, output_data, [](T v) {
+ input_data.begin(), input_data.end(), output_data, [](T v) {
return (v == -std::numeric_limits<T>::infinity());
});
} else {
// all false
- memset(output_data, false, onnxruntime::narrow<size_t>(total_items));
+ memset(output_data, false, input_data.size());
+ }
+ }
+};
+
+template <>
+struct ComputeDispatchTarget<MLFloat16> {
+ void operator()(const Tensor& X, Tensor& Y, bool detect_positive, bool detect_negative) const {
+ auto output_data = Y.MutableData<bool>();
+ auto input_data = X.DataAsSpan<MLFloat16>();
+ if (detect_positive && detect_negative) {
+ std::transform(input_data.begin(), input_data.end(), output_data,
+ [](MLFloat16 v) { return v.IsInfinity(); });
+ } else if (detect_positive) {
+ std::transform(input_data.begin(), input_data.end(), output_data,
+ [](MLFloat16 v) { return v.IsPositiveInfinity(); });
+ } else if (detect_negative) {
+ std::transform(input_data.begin(), input_data.end(), output_data,
+ [](MLFloat16 v) { return v.IsNegativeInfinity(); });
+ } else {
+ // all false
+ memset(output_data, false, input_data.size());
+ }
+ }
+};
+
+template <>
+struct ComputeDispatchTarget<BFloat16> {
+ void operator()(const Tensor& X, Tensor& Y, bool detect_positive, bool detect_negative) const {
+ auto output_data = Y.MutableData<bool>();
+ auto input_data = X.DataAsSpan<BFloat16>();
+ if (detect_positive && detect_negative) {
+ std::transform(input_data.begin(), input_data.end(), output_data,
+ [](BFloat16 v) { return v.IsInfinity(); });
+ } else if (detect_positive) {
+ std::transform(input_data.begin(), input_data.end(), output_data,
+ [](BFloat16 v) { return v.IsPositiveInfinity(); });
+ } else if (detect_negative) {
+ std::transform(input_data.begin(), input_data.end(), output_data,
+ [](BFloat16 v) { return v.IsNegativeInfinity(); });
+ } else {
+ // all false
+ memset(output_data, false, input_data.size());
}
}
};
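
The CPU specializations above defer to the MLFloat16/BFloat16 helper methods, which reduce to IEEE bit tests. A self-contained sketch of the fp16 bit patterns involved, assuming nothing from ONNX Runtime (the constants and helper names below are illustrative; ORT's own MLFloat16::kPositiveInfinityBits and kSignMask play the same role):

#include <cstdint>
#include <cstdio>

// IEEE fp16 infinity: exponent all-ones (0x7C00) with a zero mantissa.
constexpr uint16_t kFp16PosInf = 0x7C00;    // 0 11111 0000000000
constexpr uint16_t kFp16NegInf = 0xFC00;    // 1 11111 0000000000
constexpr uint16_t kFp16SignMask = 0x8000;  // top bit is the sign

// Mask off the sign to test "either infinity" with one comparison.
bool IsInfHalfBits(uint16_t bits) { return (bits & ~kFp16SignMask) == kFp16PosInf; }
bool IsPosInfHalfBits(uint16_t bits) { return bits == kFp16PosInf; }
bool IsNegInfHalfBits(uint16_t bits) { return bits == kFp16NegInf; }

int main() {
  std::printf("%d %d %d\n", IsInfHalfBits(0xFC00), IsPosInfHalfBits(0x7C00),
              IsNegInfHalfBits(0x7E00));  // 1 1 0 (0x7E00 has a nonzero mantissa: NaN)
}
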
diff --git a/onnxruntime/core/providers/cuda/cu_inc/common.cuh b/onnxruntime/core/providers/cuda/cu_inc/common.cuh
index 66794f88d8670..bba9178348132 100644
--- a/onnxruntime/core/providers/cuda/cu_inc/common.cuh
+++ b/onnxruntime/core/providers/cuda/cu_inc/common.cuh
@@ -438,6 +438,100 @@ __device__ __inline__ BFloat16 _Fmod(BFloat16 a, BFloat16 b) {
return fmodf((float)a, (float)b);
}
+namespace isinf_details {
+template <typename T>
+struct IsInfTyped {
+ static __device__ __inline__ bool IsInf(T a) {
+ // cast is needed because on non-MS compilers
+ // isinf() returns int rather than bool,
+ // and we want to avoid conversion warnings
+ return static_cast<bool>(isinf(a));
+ }
+ static __device__ __inline__ bool IsInfPos(T a) {
return a == std::numeric_limits<T>::infinity();
+ }
+ static __device__ __inline__ bool IsInfNeg(T a) {
return a == -std::numeric_limits<T>::infinity();
+ }
+};
+
+template <>
+struct IsInfTyped<half> {
+ static __device__ __inline__ bool IsInf(half a) {
+ return MLFloat16::kPositiveInfinityBits ==
static_cast<uint16_t>(*reinterpret_cast<uint16_t*>(&a) & ~MLFloat16::kSignMask);
+ }
+ static __device__ __inline__ bool IsInfPos(half a) {
+ return MLFloat16::kPositiveInfinityBits == *reinterpret_cast<uint16_t*>(&a);
+ }
+ static __device__ __inline__ bool IsInfNeg(half a) {
+ return MLFloat16::kNegativeInfinityBits == *reinterpret_cast<uint16_t*>(&a);
+ }
+};
+
+template <>
+struct IsInfTyped<BFloat16> {
+ static __device__ __inline__ bool IsInf(BFloat16 a) {
+ return BFloat16::kPositiveInfinityBits ==
static_cast<uint16_t>(*reinterpret_cast<uint16_t*>(&a) & ~BFloat16::kSignMask);
+ }
+ static __device__ __inline__ bool IsInfPos(BFloat16 a) {
+ return BFloat16::kPositiveInfinityBits == *reinterpret_cast<uint16_t*>(&a);
+ }
+ static __device__ __inline__ bool IsInfNeg(BFloat16 a) {
+ return BFloat16::kNegativeInfinityBits == *reinterpret_cast<uint16_t*>(&a);
+ }
+};
+
+#if !defined(DISABLE_FLOAT8_TYPES)
+
+template <typename T>
+struct ReturnFalse {
+ constexpr static bool __device__ __inline__ IsInf(T) { return false; }
+ constexpr static bool __device__ __inline__ IsInfPos(T) { return false; }
+ constexpr static bool __device__ __inline__ IsInfNeg(T) { return false; }
+};
+
+template <>
+struct IsInfTyped<Float8E4M3FN> : ReturnFalse<Float8E4M3FN> {};
+
+template <>
+struct IsInfTyped<Float8E4M3FNUZ> : ReturnFalse<Float8E4M3FNUZ> {};
+
+template <>
+struct IsInfTyped<Float8E5M2> {
+ static __device__ __inline__ bool IsInf(Float8E5M2 a) {
+ return a.val == 0b01111100 || a.val == 0b11111100;
+ }
+ static __device__ __inline__ bool IsInfPos(Float8E5M2 a) {
+ return a.val == 0b01111100;
+ }
+ static __device__ __inline__ bool IsInfNeg(Float8E5M2 a) {
+ return a.val == 0b11111100;
+ }
+};
+
+template <>
+struct IsInfTyped<Float8E5M2FNUZ> : ReturnFalse<Float8E5M2FNUZ> {};
+
+#endif
+} // namespace isinf_details
+
+template <typename T, bool detect_positive, bool detect_negative>
+struct _IsInf {
+ __device__ __inline__ bool operator()(T a) const {
+ if constexpr (detect_positive && detect_negative) {
+ return isinf_details::IsInfTyped<T>::IsInf(a);
+ } else if constexpr (detect_positive) {
+ return isinf_details::IsInfTyped<T>::IsInfPos(a);
+ } else if constexpr (detect_negative) {
+ return isinf_details::IsInfTyped<T>::IsInfNeg(a);
+ } else {
+ return false;
+ }
+ }
+};
+
// We would like to use 64-bit integer to support large matrices. However, CUDA seems to support only 32-bit integer
// For now, use int32_t to ensure that both Linux and Windows see this as 32 bit integer type.
#ifndef CUDA_LONG
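
The _IsInf functor above selects its predicate at compile time, so each template instantiation carries no runtime flag checks. A host-side C++17 sketch of the same if-constexpr dispatch (using plain std::numeric_limits in place of the typed helpers, so it runs off-device):

#include <cstdio>
#include <limits>

// The detect flags are template parameters: every instantiation compiles
// down to a single fixed predicate, with the branching resolved at compile time.
template <typename T, bool detect_positive, bool detect_negative>
struct IsInfFunctor {
  bool operator()(T a) const {
    if constexpr (detect_positive && detect_negative) {
      return a == std::numeric_limits<T>::infinity() ||
             a == -std::numeric_limits<T>::infinity();
    } else if constexpr (detect_positive) {
      return a == std::numeric_limits<T>::infinity();
    } else if constexpr (detect_negative) {
      return a == -std::numeric_limits<T>::infinity();
    } else {
      return false;
    }
  }
};

int main() {
  IsInfFunctor<float, true, false> pos_only;
  std::printf("%d %d\n", pos_only(std::numeric_limits<float>::infinity()),
              pos_only(-std::numeric_limits<float>::infinity()));  // 1 0
}
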
diff --git a/onnxruntime/core/providers/cuda/cuda_common.h b/onnxruntime/core/providers/cuda/cuda_common.h
index 41c999bacee13..61da125b40953 100644
--- a/onnxruntime/core/providers/cuda/cuda_common.h
+++ b/onnxruntime/core/providers/cuda/cuda_common.h
@@ -70,6 +70,15 @@ class ToCudaType {
}
};
+template <>
+class ToCudaType<Float8E4M3FNUZ> {
+ public:
+ typedef Float8E4M3FNUZ MappedType;
+ static MappedType FromFloat(float f) {
+ return MappedType(f);
+ }
+};
+
template <>
class ToCudaType<Float8E5M2> {
public:
@@ -79,6 +88,15 @@ class ToCudaType {
}
};
+template <>
+class ToCudaType<Float8E5M2FNUZ> {
+ public:
+ typedef Float8E5M2FNUZ MappedType;
+ static MappedType FromFloat(float f) {
+ return MappedType(f);
+ }
+};
+
#endif
inline bool CalculateFdmStrides(gsl::span<fast_divmod> p, const std::vector<int64_t>& dims) {
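
These additions complete the ToCudaType trait for the FNUZ float8 variants, so the dispatcher further below can name a device-side type for every opset-20 input type. A standalone sketch of the trait pattern, with hypothetical HostHalf/DeviceHalf types standing in for MLFloat16/half:

#include <cstdio>
#include <type_traits>

// Map a host scalar type to the type the device kernels compute with.
// The primary template is the identity mapping.
template <typename T>
struct ToDeviceType {
  using MappedType = T;
};

// A specialization redirects one host type to a different device type.
struct HostHalf { unsigned short bits; };
struct DeviceHalf { unsigned short bits; };
template <>
struct ToDeviceType<HostHalf> {
  using MappedType = DeviceHalf;
};

int main() {
  static_assert(std::is_same_v<ToDeviceType<float>::MappedType, float>);
  static_assert(std::is_same_v<ToDeviceType<HostHalf>::MappedType, DeviceHalf>);
  std::puts("trait mappings verified at compile time");
}
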
diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
index 8ba282031a5d4..3c0930638a205 100644
--- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
+++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -830,6 +830,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, MLFloat16, ThresholdedRelu);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 10, TopK);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 12, Mod);
+class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 19, IsInf);
// opset 11
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, Compress);
@@ -1342,6 +1343,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, S
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, float, Gelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, double, Gelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, MLFloat16, Gelu);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, IsInf);
template <>
KernelCreateInfo BuildKernelCreateInfo<void>() {
@@ -1739,6 +1741,8 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 19, IsInf)>,
// opset 11
BuildKernelCreateInfo,
@@ -2250,6 +2254,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, IsInf)>,
#endif
};
diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc
index fd8b69d7bd2f5..00de1b37f3302 100644
--- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc
+++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc
@@ -71,6 +71,44 @@ Status UnaryElementwise::Prepare(OpKernelContext* context, UnaryElementwisePrepa
return Status::OK(); \
}
+ONNX_OPERATOR_VERSIONED_KERNEL_EX(
+ IsInf,
+ kOnnxDomain,
+ 10,
+ 19,
+ kCudaExecutionProvider,
+ (*KernelDefBuilder::Create())
+ .TypeConstraint("T1", BuildKernelDefConstraints())
+ .TypeConstraint("T2", DataTypeImpl::GetTensorType()),
+ IsInf);
+
+ONNX_OPERATOR_KERNEL_EX(
+ IsInf,
+ kOnnxDomain,
+ 20,
+ kCudaExecutionProvider,
+ (*KernelDefBuilder::Create())
+ .TypeConstraint("T1", BuildKernelDefConstraints())
+ .TypeConstraint("T2", DataTypeImpl::GetTensorType()),
+ IsInf);
+
+IsInf::IsInf(const OpKernelInfo& info) : UnaryElementwise(info) {
+ detect_positive_ = static_cast<bool>(info.GetAttrOrDefault<int64_t>("detect_positive", 1));
+ detect_negative_ = static_cast<bool>(info.GetAttrOrDefault<int64_t>("detect_negative", 1));
+ opset_ = info.node().SinceVersion();
+}
+
+Status IsInf::ComputeInternal(OpKernelContext* context) const {
+ UnaryElementwisePreparation p;
+ ORT_RETURN_IF_ERROR(UnaryElementwise::Prepare(context, &p));
+
+ Explicit_Impl_IsInf(Stream(context), opset_, detect_positive_, detect_negative_,
+ p.input_tensor->GetElementType(), p.input_tensor->DataRaw(),
+ p.output_tensor->MutableData<bool>(),
+ p.input_tensor->Shape().Size());
+ return Status::OK();
+}
+
#define UNARY_OP_VERSIONED_TYPED(name, startver, endver, T) \
UNARY_ELEMENTWISE_REGISTER_VERSIONED_KERNEL(name, startver, endver, T)
diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h
index 775b78c43a736..3b7d6df7221b7 100644
--- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h
+++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h
@@ -2,6 +2,7 @@
// Licensed under the MIT License.
#pragma once
+
#include "core/providers/cuda/cuda_kernel.h"
namespace onnxruntime {
@@ -119,5 +120,16 @@ class Sign final : public UnaryElementwise {
Status ComputeInternal(OpKernelContext* context) const override;
};
+class IsInf final : public UnaryElementwise {
+ public:
+ explicit IsInf(const OpKernelInfo& info);
+ Status ComputeInternal(OpKernelContext* context) const override;
+
+ private:
+ bool detect_positive_{true};
+ bool detect_negative_{true};
+ int opset_;
+};
+
} // namespace cuda
} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu
index 73c5ac80756be..fd8f7929d4426 100644
--- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu
+++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu
@@ -11,6 +11,7 @@
#endif
namespace onnxruntime {
+
namespace cuda {
#define OP(name, expr) \
@@ -284,5 +285,42 @@ EXPLICIT_IMPL_CASTSAT(__nv_bfloat16, Float8E5M2)
#endif
+namespace isinf_details {
+template <typename T>
+struct IsInf_DispFunc {
+ void operator()(cudaStream_t stream, const void* input_raw, bool* output_data,
+ bool detect_positive, bool detect_negative, size_t count) const {
+ using CudaType = typename ToCudaType<T>::MappedType;
+ const auto* input_data = reinterpret_cast<const CudaType*>(input_raw);
+ if (detect_positive && detect_negative) {
+ UnaryElementWiseImpl(stream, input_data, output_data, _IsInf<CudaType, true, true>{}, count);
+ } else if (detect_positive) {
+ UnaryElementWiseImpl(stream, input_data, output_data, _IsInf<CudaType, true, false>{}, count);
+ } else if (detect_negative) {
+ UnaryElementWiseImpl(stream, input_data, output_data, _IsInf<CudaType, false, true>{}, count);
+ } else {
+ UnaryElementWiseImpl(stream, input_data, output_data, _IsInf<CudaType, false, false>{}, count);
+ }
+ }
+};
+
+} // namespace isinf_details
+
+void Explicit_Impl_IsInf(cudaStream_t stream, int op_set,
+ bool detect_positive, bool detect_negative,
+ int32_t input_data_type,
+ const void* input_raw, bool* output_data,
+ size_t count) {
+ if (op_set < 20) {
+ utils::MLTypeCallDispatcher<float, double> dispatcher{input_data_type};
+ dispatcher.Invoke<isinf_details::IsInf_DispFunc>(stream, input_raw, output_data,
+ detect_positive, detect_negative, count);
+ } else {
+ utils::MLTypeCallDispatcher<ISINF_OPSET20_ALL_FLOATS> dispatcher{input_data_type};
+ dispatcher.Invoke<isinf_details::IsInf_DispFunc>(stream, input_raw, output_data,
+ detect_positive, detect_negative, count);
+ }
+}
+
} // namespace cuda
} // namespace onnxruntime
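
Explicit_Impl_IsInf turns a runtime element-type id into a template instantiation via MLTypeCallDispatcher, keyed on opset: float/double below opset 20, the full ISINF_OPSET20_ALL_FLOATS list from opset 20 on. A minimal host-side sketch of that dispatch mechanism (TypeTag, TagOf, and Dispatch are illustrative stand-ins, not ONNX Runtime APIs):

#include <cstdint>
#include <cstdio>

// Illustrative runtime tags; not ONNX Runtime's element-type enum values.
enum class TypeTag : int32_t { kFloat = 1, kDouble = 11 };

template <typename T> constexpr TypeTag TagOf();
template <> constexpr TypeTag TagOf<float>() { return TypeTag::kFloat; }
template <> constexpr TypeTag TagOf<double>() { return TypeTag::kDouble; }

// Fold over the candidate type list; invoke the functor template for the
// type whose tag matches, returning whether any candidate matched.
template <template <typename> class Fn, typename... Types, typename... Args>
bool Dispatch(TypeTag tag, Args&&... args) {
  bool hit = ((tag == TagOf<Types>() && (Fn<Types>{}(args...), true)) || ...);
  return hit;
}

template <typename T>
struct PrintSize {
  void operator()(const char* label) const {
    std::printf("%s: element size %zu\n", label, sizeof(T));
  }
};

int main() {
  Dispatch<PrintSize, float, double>(TypeTag::kDouble, "input");  // input: element size 8
}
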
diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h
index 608a81a24cf4f..a606d479bc79b 100644
--- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h
+++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h
@@ -137,5 +137,20 @@ void Impl_CastSat(
#endif
+// IsInf
+
+#if !defined(DISABLE_FLOAT8_TYPES)
+#define ISINF_OPSET20_ALL_FLOATS float, double, MLFloat16, BFloat16, Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, \
+ Float8E5M2FNUZ
+#else
+#define ISINF_OPSET20_ALL_FLOATS float, double, MLFloat16, BFloat16
+#endif
+
+void Explicit_Impl_IsInf(cudaStream_t stream, int op_set,
+ bool detect_positive, bool detect_negative,
+ int32_t input_data_type,
+ const void* input_raw, bool* output_data,
+ size_t count);
} // namespace cuda
+
} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/rocm/cu_inc/common.cuh b/onnxruntime/core/providers/rocm/cu_inc/common.cuh
index 5f966ac746fcb..f3685606c17f5 100644
--- a/onnxruntime/core/providers/rocm/cu_inc/common.cuh
+++ b/onnxruntime/core/providers/rocm/cu_inc/common.cuh
@@ -335,6 +335,100 @@ __device__ __inline__ BFloat16 _Fmod(BFloat16 a, BFloat16 b) {
return fmodf((float)a, (float)b);
}
+namespace isinf_details {
+template <typename T>
+struct IsInfTyped {
+ static __device__ __inline__ bool IsInf(T a) {
+ // cast is needed because on non-MS compilers
+ // isinf() returns int rather than bool,
+ // and we want to avoid conversion warnings
+ return static_cast<bool>(isinf(a));
+ }
+ static __device__ __inline__ bool IsInfPos(T a) {
return a == std::numeric_limits<T>::infinity();
+ }
+ static __device__ __inline__ bool IsInfNeg(T a) {
return a == -std::numeric_limits<T>::infinity();
+ }
+};
+
+template <>
+struct IsInfTyped<half> {
+ static __device__ __inline__ bool IsInf(half a) {
+ return MLFloat16::kPositiveInfinityBits ==
static_cast<uint16_t>(*reinterpret_cast<uint16_t*>(&a) & ~MLFloat16::kSignMask);
+ }
+ static __device__ __inline__ bool IsInfPos(half a) {
+ return MLFloat16::kPositiveInfinityBits == *reinterpret_cast<uint16_t*>(&a);
+ }
+ static __device__ __inline__ bool IsInfNeg(half a) {
+ return MLFloat16::kNegativeInfinityBits == *reinterpret_cast<uint16_t*>(&a);
+ }
+};
+
+template <>
+struct IsInfTyped<BFloat16> {
+ static __device__ __inline__ bool IsInf(BFloat16 a) {
+ return BFloat16::kPositiveInfinityBits ==
static_cast<uint16_t>(*reinterpret_cast<uint16_t*>(&a) & ~BFloat16::kSignMask);
+ }
+ static __device__ __inline__ bool IsInfPos(BFloat16 a) {
+ return BFloat16::kPositiveInfinityBits == *reinterpret_cast<uint16_t*>(&a);
+ }
+ static __device__ __inline__ bool IsInfNeg(BFloat16 a) {
+ return BFloat16::kNegativeInfinityBits == *reinterpret_cast<uint16_t*>(&a);
+ }
+};
+
+#if !defined(DISABLE_FLOAT8_TYPES)
+
+template <typename T>
+struct ReturnFalse {
+ constexpr static bool __device__ __inline__ IsInf(T) { return false; }
+ constexpr static bool __device__ __inline__ IsInfPos(T) { return false; }
+ constexpr static bool __device__ __inline__ IsInfNeg(T) { return false; }
+};
+
+template <>
+struct IsInfTyped<Float8E4M3FN> : ReturnFalse<Float8E4M3FN> {};
+
+template <>
+struct IsInfTyped<Float8E4M3FNUZ> : ReturnFalse<Float8E4M3FNUZ> {};
+
+template <>
+struct IsInfTyped<Float8E5M2> {
+ static __device__ __inline__ bool IsInf(Float8E5M2 a) {
+ return a.val == 0b01111100 || a.val == 0b11111100;
+ }
+ static __device__ __inline__ bool IsInfPos(Float8E5M2 a) {
+ return a.val == 0b01111100;
+ }
+ static __device__ __inline__ bool IsInfNeg(Float8E5M2 a) {
+ return a.val == 0b11111100;
+ }
+};
+
+template <>
+struct IsInfTyped<Float8E5M2FNUZ> : ReturnFalse<Float8E5M2FNUZ> {};
+
+#endif
+} // namespace isinf_details
+
+template <typename T, bool detect_positive, bool detect_negative>
+struct _IsInf {
+ __device__ __inline__ bool operator()(T a) const {
+ if constexpr (detect_positive && detect_negative) {
+ return isinf_details::IsInfTyped<T>::IsInf(a);
+ } else if constexpr (detect_positive) {
+ return isinf_details::IsInfTyped<T>::IsInfPos(a);
+ } else if constexpr (detect_negative) {
+ return isinf_details::IsInfTyped<T>::IsInfNeg(a);
+ } else {
+ return false;
+ }
+ }
+};
+
// We would like to use 64-bit integer to support large matrices. However, ROCM seems to support only 32-bit integer
// For now, use int32_t to ensure that both Linux and Windows see this as 32 bit integer type.
#ifndef HIP_LONG
diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
index 0265c06b9a938..4a679b790ee40 100644
--- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
+++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
@@ -793,6 +793,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, MLFloat16, ThresholdedRelu);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 10, TopK);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 12, Mod);
+class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 19, IsInf);
// opset 11
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, float, ArgMax);
@@ -1342,6 +1343,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, R
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, Scan);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, Shape);
+// Opset 20
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 20, IsInf);
+
template <>
KernelCreateInfo BuildKernelCreateInfo<void>() {
return {};
@@ -1738,6 +1742,8 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 19, IsInf)>,
// opset 11
BuildKernelCreateInfo,
@@ -2294,6 +2300,9 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+
+ // opset 20
+ BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 20, IsInf)>,
};
for (auto& function_table_entry : function_table) {
diff --git a/onnxruntime/test/providers/cpu/tensor/isinf_test.cc b/onnxruntime/test/providers/cpu/tensor/isinf_test.cc
index 2e583c5d2547b..bd97306142f18 100644
--- a/onnxruntime/test/providers/cpu/tensor/isinf_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/isinf_test.cc
@@ -99,6 +99,48 @@ TEST(IsInfTest, test_isinf_negative_double20) {
run_is_inf_test(20, 0, 1, input, output);
}
+TEST(IsInfTest, test_isinf_mlfloat16) {
+ std::initializer_list<MLFloat16> input = {MLFloat16{-1.7f}, MLFloat16::NaN, MLFloat16::Infinity, 3.6_fp16,
+ MLFloat16::NegativeInfinity, MLFloat16::Infinity};
+ std::initializer_list<bool> output = {false, false, true, false, true, true};
+ run_is_inf_test(20, 1, 1, input, output);
+}
+
+TEST(IsInfTest, test_isinf_positive_mlfloat16) {
+ std::initializer_list<MLFloat16> input = {MLFloat16{-1.7f}, MLFloat16::NaN, MLFloat16::Infinity, 3.6_fp16,
+ MLFloat16::NegativeInfinity, MLFloat16::Infinity};
+ std::initializer_list<bool> output = {false, false, true, false, false, true};
+ run_is_inf_test(20, 1, 0, input, output);
+}
+
+TEST(IsInfTest, test_isinf_negative_mlfloat16) {
+ std::initializer_list<MLFloat16> input = {MLFloat16{-1.7f}, MLFloat16::NaN, MLFloat16::Infinity, 3.6_fp16,
+ MLFloat16::NegativeInfinity, MLFloat16::Infinity};
+ std::initializer_list<bool> output = {false, false, false, false, true, false};
+ run_is_inf_test(20, 0, 1, input, output);
+}
+
+TEST(IsInfTest, test_isinf_bfloat16) {
+ std::initializer_list<BFloat16> input = {BFloat16{-1.7f}, BFloat16::NaN, BFloat16::Infinity, 3.6_bfp16,
+ BFloat16::NegativeInfinity, BFloat16::Infinity};
+ std::initializer_list<bool> output = {false, false, true, false, true, true};
+ run_is_inf_test(20, 1, 1, input, output);
+}
+
+TEST(IsInfTest, test_isinf_positive_bfloat16) {
+ std::initializer_list<BFloat16> input = {BFloat16{-1.7f}, BFloat16::NaN, BFloat16::Infinity, 3.6_bfp16,
+ BFloat16::NegativeInfinity, BFloat16::Infinity};
+ std::initializer_list<bool> output = {false, false, true, false, false, true};
+ run_is_inf_test(20, 1, 0, input, output);
+}
+
+TEST(IsInfTest, test_isinf_negative_bfloat16) {
+ std::initializer_list<BFloat16> input = {BFloat16{-1.7f}, BFloat16::NaN, BFloat16::Infinity, 3.6_bfp16,
+ BFloat16::NegativeInfinity, BFloat16::Infinity};
+ std::initializer_list<bool> output = {false, false, false, false, true, false};
+ run_is_inf_test(20, 0, 1, input, output);
+}
+
#if !defined(DISABLE_FLOAT8_TYPES)
TEST(IsInfTest, test_Float8E4M3FN) {
std::initializer_list<Float8E4M3FN> input = {