From 116d36e3d149e79e80b0e357db357a55f15c3c62 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Mon, 4 Mar 2024 15:54:18 -0800 Subject: [PATCH 1/6] Implement but not registered --- .../core/framework/data_types_internal.h | 2 +- .../core/providers/cpu/tensor/isinf.cc | 58 ++++++++++-- .../core/providers/cuda/cu_inc/common.cuh | 94 +++++++++++++++++++ onnxruntime/core/providers/cuda/cuda_common.h | 18 ++++ .../cuda/math/unary_elementwise_ops.cc | 55 +++++++++++ .../cuda/math/unary_elementwise_ops.h | 13 +++ .../cuda/math/unary_elementwise_ops_impl.cu | 44 +++++++++ .../cuda/math/unary_elementwise_ops_impl.h | 8 ++ 8 files changed, 282 insertions(+), 10 deletions(-) diff --git a/include/onnxruntime/core/framework/data_types_internal.h b/include/onnxruntime/core/framework/data_types_internal.h index fbeee8a2aedc5..3a3b5cb6888f2 100644 --- a/include/onnxruntime/core/framework/data_types_internal.h +++ b/include/onnxruntime/core/framework/data_types_internal.h @@ -305,7 +305,7 @@ class CallableDispatchableHelper { return 0; } - void CheckCalledOnce() { + void CheckCalledOnce() const { ORT_ENFORCE(called_ == 1, "Unsupported data type: ", dt_type_); } }; diff --git a/onnxruntime/core/providers/cpu/tensor/isinf.cc b/onnxruntime/core/providers/cpu/tensor/isinf.cc index 1b449f46927a2..e16de0feea559 100644 --- a/onnxruntime/core/providers/cpu/tensor/isinf.cc +++ b/onnxruntime/core/providers/cpu/tensor/isinf.cc @@ -23,7 +23,9 @@ ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST( using IsInfTypesOpset20 = TypeList< float, - double + double, + MLFloat16, + BFloat16 #if !defined(DISABLE_FLOAT8_TYPES) , Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, Float8E5M2FNUZ @@ -87,29 +89,67 @@ namespace isinf_internal { template struct ComputeDispatchTarget { void operator()(const Tensor& X, Tensor& Y, bool detect_positive, bool detect_negative) const { - const auto total_items = X.Shape().Size(); + auto input_data = X.DataAsSpan(); auto output_data = Y.MutableData(); if (detect_positive && detect_negative) { EigenMap(Y) = EigenMap(X).array().isInf(); } else if (detect_positive) { - auto input_data = X.Data(); - auto end_data = input_data + total_items; std::transform( - input_data, end_data, output_data, [](T v) { + input_data.begin(), input_data.end(), output_data, [](T v) { return (v == std::numeric_limits::infinity()); }); } else if (detect_negative) { - auto input_data = X.Data(); - auto end_data = input_data + total_items; std::transform( - input_data, end_data, output_data, [](T v) { + input_data.begin(), input_data.end(), output_data, [](T v) { return (v == -std::numeric_limits::infinity()); }); } else { // all false - memset(output_data, false, onnxruntime::narrow(total_items)); + memset(output_data, false, input_data.size()); + } + } +}; + +template <> +struct ComputeDispatchTarget { + void operator()(const Tensor& X, Tensor& Y, bool detect_positive, bool detect_negative) const { + auto output_data = Y.MutableData(); + auto input_data = X.DataAsSpan(); + if (detect_positive && detect_negative) { + std::transform(input_data.begin(), input_data.end(), output_data, + [](MLFloat16 v) { return v.IsInfinity(); }); + } else if (detect_positive) { + std::transform(input_data.begin(), input_data.end(), output_data, + [](MLFloat16 v) { return v.IsPositiveInfinity(); }); + } else if (detect_negative) { + std::transform(input_data.begin(), input_data.end(), output_data, + [](MLFloat16 v) { return v.IsNegativeInfinity(); }); + } else { + // all false + memset(output_data, false, input_data.size()); + } + } +}; + +template <> +struct ComputeDispatchTarget { + void operator()(const Tensor& X, Tensor& Y, bool detect_positive, bool detect_negative) const { + auto output_data = Y.MutableData(); + auto input_data = X.DataAsSpan(); + if (detect_positive && detect_negative) { + std::transform(input_data.begin(), input_data.end(), output_data, + [](BFloat16 v) { return v.IsInfinity(); }); + } else if (detect_positive) { + std::transform(input_data.begin(), input_data.end(), output_data, + [](BFloat16 v) { return v.IsPositiveInfinity(); }); + } else if (detect_negative) { + std::transform(input_data.begin(), input_data.end(), output_data, + [](BFloat16 v) { return v.IsNegativeInfinity(); }); + } else { + // all false + memset(output_data, false, input_data.size()); } } }; diff --git a/onnxruntime/core/providers/cuda/cu_inc/common.cuh b/onnxruntime/core/providers/cuda/cu_inc/common.cuh index 66794f88d8670..e86435f1882f4 100644 --- a/onnxruntime/core/providers/cuda/cu_inc/common.cuh +++ b/onnxruntime/core/providers/cuda/cu_inc/common.cuh @@ -438,6 +438,100 @@ __device__ __inline__ BFloat16 _Fmod(BFloat16 a, BFloat16 b) { return fmodf((float)a, (float)b); } +namespace isinf_details { +template +struct IsInfTyped { + static __device__ __inline__ bool IsInf(T a) { + // cast is needed because on non MS compilers, + // because there isinf() returns int + // and we want to avoid stupid warnings + return static_cast(isinf(a)); + }; + static __device__ __inline__ bool IsInfPos(T a) { + return a == std::numeric_limits::infinity(); + } + static __device__ __inline__ bool IsInfNeg(T a) { + return a == -std::numeric_limits::infinity(); + } +}; + +template <> +struct IsInfTyped { + static __device__ __inline__ bool IsInf(half a) { + return MLFloat16::kPositiveInfinityBits == + static_cast(*reinterpret_cast(&a) & ~MLFloat16::kSignMask); + }; + static __device__ __inline__ bool IsInfPos(half a) { + return MLFloat16::kPositiveInfinityBits == *reinterpret_cast(&a); + } + static __device__ __inline__ bool IsInfNeg(half a) { + return MLFloat16::kNegativeInfinityBits == *reinterpret_cast(&a); + } +}; + +template <> +struct IsInfTyped { + static __device__ __inline__ bool IsInf(BFloat16 a) { + return BFloat16::kPositiveInfinityBits == + static_cast(*reinterpret_cast(&a) & ~BFloat16::kSignMask); + }; + static __device__ __inline__ bool IsInfPos(BFloat16 a) { + return BFloat16::kPositiveInfinityBits == *reinterpret_cast(&a); + } + static __device__ __inline__ bool IsInfNeg(BFloat16 a) { + return BFloat16::kNegativeInfinityBits == *reinterpret_cast(&a); + } +}; + +#if !defined(DISABLE_FLOAT8_TYPES) + +template +struct ReturnFalse { + constexpr static bool __device__ __inline__ IsInf(T) { return false; } + constexpr static bool __device__ __inline__ IsInfPos(T) { return false; } + constexpr static bool __device__ __inline__ IsInfNeg(T) { return false; } +}; + +template <> +struct IsInfTyped : ReturnFalse {}; + +template <> +struct IsInfTyped : ReturnFalse {}; + +template <> +struct IsInfTyped { + static __device__ __inline__ bool IsInf(Float8E5M2 a) { + return a.val == 0b01111100 || a.val == 0b11111100; + } + static __device__ __inline__ bool IsInfPos(Float8E5M2 a) { + return a.val == 0b01111100; + } + static __device__ __inline__ bool IsInfNeg(Float8E5M2 a) { + return a.val == 0b11111100; + } +}; + +template <> +struct IsInfTyped : ReturnFalse {}; + +#endif +} // namespace isinf_details + +template +struct _IsInf { + __device__ __inline__ bool operator()(T a) const { + if constexpr (detect_positive && detect_negative) { + return isinf_details::IsInfTyped::IsInf(a); + } else if constexpr (detect_positive) { + return isinf_details::IsInfTyped::IsInfPos(a); + } else if constexpr (detect_negative) { + return isinf_details::IsInfTyped::IsInfNeg(a); + } else { + return false; + } + } +}; + // We would like to use 64-bit integer to support large matrices. However, CUDA seems to support only 32-bit integer // For now, use int32_t to ensure that both Linux and Windows see this as 32 bit integer type. #ifndef CUDA_LONG diff --git a/onnxruntime/core/providers/cuda/cuda_common.h b/onnxruntime/core/providers/cuda/cuda_common.h index 41c999bacee13..61da125b40953 100644 --- a/onnxruntime/core/providers/cuda/cuda_common.h +++ b/onnxruntime/core/providers/cuda/cuda_common.h @@ -70,6 +70,15 @@ class ToCudaType { } }; +template <> +class ToCudaType { + public: + typedef Float8E4M3FNUZ MappedType; + static MappedType FromFloat(float f) { + return MappedType(f); + } +}; + template <> class ToCudaType { public: @@ -79,6 +88,15 @@ class ToCudaType { } }; +template <> +class ToCudaType { + public: + typedef Float8E5M2FNUZ MappedType; + static MappedType FromFloat(float f) { + return MappedType(f); + } +}; + #endif inline bool CalculateFdmStrides(gsl::span p, const std::vector& dims) { diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc index fd8b69d7bd2f5..119aba849656f 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc @@ -71,6 +71,61 @@ Status UnaryElementwise::Prepare(OpKernelContext* context, UnaryElementwisePrepa return Status::OK(); \ } +namespace op_kernel_type_control { +using IsInfTypesOpset10 = TypeList; +using IsInfTypesOpset20 = TypeList; +} // namespace op_kernel_type_control + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + IsInf, + kOnnxDomain, + 10, + 19, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", BuildKernelDefConstraints()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), + IsInf); + +ONNX_OPERATOR_KERNEL_EX( + IsInf, + kOnnxDomain, + 20, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", BuildKernelDefConstraintsFromTypeList()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), + IsInf); + +IsInf::IsInf(const OpKernelInfo& info) : UnaryElementwise(info) { + int64_t detect_positive = 0, detect_negative = 0; + Status status = info.GetAttr("detect_positive", &detect_positive); + ORT_ENFORCE(status.IsOK(), "Failed to obtain detect_positive"); + status = info.GetAttr("detect_negative", &detect_negative); + ORT_ENFORCE(status.IsOK(), "Failed to obtain detect_negative"); + + detect_positive_ = static_cast(detect_positive); + detect_negative_ = static_cast(detect_negative); + opset_ = info.node().SinceVersion(); +} + +Status IsInf::ComputeInternal(OpKernelContext* context) const { + UnaryElementwisePreparation p; + ORT_RETURN_IF_ERROR(UnaryElementwise::Prepare(context, &p)); + + Explicit_Impl_IsInf(Stream(context), opset_, detect_positive_, detect_negative_, + p.input_tensor->GetElementType(), p.input_tensor->DataRaw(), + p.output_tensor->MutableData(), + p.input_tensor->Shape().Size()); + return Status::OK(); +} + + #define UNARY_OP_VERSIONED_TYPED(name, startver, endver, T) \ UNARY_ELEMENTWISE_REGISTER_VERSIONED_KERNEL(name, startver, endver, T) diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h index 775b78c43a736..928b6e4370ba2 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h @@ -2,6 +2,7 @@ // Licensed under the MIT License. #pragma once + #include "core/providers/cuda/cuda_kernel.h" namespace onnxruntime { @@ -119,5 +120,17 @@ class Sign final : public UnaryElementwise { Status ComputeInternal(OpKernelContext* context) const override; }; +class IsInf final :public UnaryElementwise { + public: + explicit IsInf(const OpKernelInfo& info); + Status ComputeInternal(OpKernelContext* context) const override; + + private: + bool detect_positive_{true}; + bool detect_negative_{true}; + int opset_; +}; + + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu index 73c5ac80756be..36c1624ddc1ad 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu @@ -11,6 +11,7 @@ #endif namespace onnxruntime { + namespace cuda { #define OP(name, expr) \ @@ -284,5 +285,48 @@ EXPLICIT_IMPL_CASTSAT(__nv_bfloat16, Float8E5M2) #endif +namespace isinf_details { +template +struct IsInf_Impl { + void operator()(cudaStream_t stream, const void* input_raw, bool* output_data, + bool detect_positive, bool detect_negative, size_t count) const { + using CudaType = typename ToCudaType::MappedType; + const auto* input_data = reinterpret_cast(input_raw); + if (detect_positive && detect_negative) { + UnaryElementWiseImpl(stream, input_data, output_data, _IsInf(), count); + } else if (detect_positive) { + UnaryElementWiseImpl(stream, input_data, output_data, _IsInf(), count); + } else if (detect_negative) { + UnaryElementWiseImpl(stream, input_data, output_data, _IsInf(), count); + } else { + UnaryElementWiseImpl(stream, input_data, output_data, _IsInf(), count); + } + } +}; + +} // namespace isinf_details + +void Explicit_Impl_IsInf(cudaStream_t stream, int op_set, + bool detect_positive, bool detect_negative, + int32_t input_data_type, + const void* input_raw, bool* output_data, + size_t count) { + if (op_set < 20) { + utils::MLTypeCallDispatcher dispatcher{input_data_type}; + dispatcher.Invoke(stream, input_raw, output_data, + detect_positive, detect_negative, count); + } else { + utils::MLTypeCallDispatcher + dispatcher{input_data_type}; + dispatcher.Invoke(stream, input_raw, output_data, + detect_positive, detect_negative, count); + } +} + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h index 608a81a24cf4f..f09a3430338d2 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h @@ -46,6 +46,7 @@ namespace cuda { UNARY_OPS() #undef UNARY_OP_NAME_EXPR + // Cast #define DECL_IMPL_CAST(InT, OutT) \ @@ -137,5 +138,12 @@ void Impl_CastSat( #endif +// IsInf +void Explicit_Impl_IsInf(cudaStream_t stream, int op_set, + bool detect_positive, bool detect_negative, + int32_t input_data_type, + const void* input_raw, bool* output_data, + size_t count); } // namespace cuda + } // namespace onnxruntime From 56f2f72ce98355dc06b88cb64b4395a9763c5b46 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Mon, 4 Mar 2024 17:06:34 -0800 Subject: [PATCH 2/6] Add fp16 and bfp16, test pass --- .../providers/cuda/cuda_execution_provider.cc | 5 +++ .../cuda/math/unary_elementwise_ops.cc | 21 ++++------ .../cuda/math/unary_elementwise_ops.h | 3 +- .../cuda/math/unary_elementwise_ops_impl.h | 1 - .../providers/rocm/rocm_execution_provider.cc | 9 ++++ .../test/providers/cpu/tensor/isinf_test.cc | 42 +++++++++++++++++++ 6 files changed, 65 insertions(+), 16 deletions(-) diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 1ce089fd93044..982c84fb294ef 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -830,6 +830,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, MLFloat16, ThresholdedRelu); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 10, TopK); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 12, Mod); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 19, IsInf); // opset 11 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, Compress); @@ -1339,6 +1340,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, S class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, float, Gelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, double, Gelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, MLFloat16, Gelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, IsInf); template <> KernelCreateInfo BuildKernelCreateInfo() { @@ -1736,6 +1738,8 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, // opset 11 BuildKernelCreateInfo, @@ -2244,6 +2248,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, #endif }; diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc index 119aba849656f..12b3c96c4b295 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc @@ -70,16 +70,12 @@ Status UnaryElementwise::Prepare(OpKernelContext* context, UnaryElementwisePrepa \ return Status::OK(); \ } - -namespace op_kernel_type_control { -using IsInfTypesOpset10 = TypeList; -using IsInfTypesOpset20 = TypeList; -} // namespace op_kernel_type_control ONNX_OPERATOR_VERSIONED_KERNEL_EX( IsInf, @@ -88,8 +84,8 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 19, kCudaExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", BuildKernelDefConstraints()) - .TypeConstraint("T1", DataTypeImpl::GetTensorType()), + .TypeConstraint("T1", BuildKernelDefConstraints()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), IsInf); ONNX_OPERATOR_KERNEL_EX( @@ -98,8 +94,8 @@ ONNX_OPERATOR_KERNEL_EX( 20, kCudaExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", BuildKernelDefConstraintsFromTypeList()) - .TypeConstraint("T1", DataTypeImpl::GetTensorType()), + .TypeConstraint("T1", BuildKernelDefConstraints()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), IsInf); IsInf::IsInf(const OpKernelInfo& info) : UnaryElementwise(info) { @@ -125,7 +121,6 @@ Status IsInf::ComputeInternal(OpKernelContext* context) const { return Status::OK(); } - #define UNARY_OP_VERSIONED_TYPED(name, startver, endver, T) \ UNARY_ELEMENTWISE_REGISTER_VERSIONED_KERNEL(name, startver, endver, T) diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h index 928b6e4370ba2..3b7d6df7221b7 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h @@ -120,7 +120,7 @@ class Sign final : public UnaryElementwise { Status ComputeInternal(OpKernelContext* context) const override; }; -class IsInf final :public UnaryElementwise { +class IsInf final : public UnaryElementwise { public: explicit IsInf(const OpKernelInfo& info); Status ComputeInternal(OpKernelContext* context) const override; @@ -131,6 +131,5 @@ class IsInf final :public UnaryElementwise { int opset_; }; - } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h index f09a3430338d2..32dd622db85c9 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h @@ -46,7 +46,6 @@ namespace cuda { UNARY_OPS() #undef UNARY_OP_NAME_EXPR - // Cast #define DECL_IMPL_CAST(InT, OutT) \ diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index 0265c06b9a938..4a679b790ee40 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -793,6 +793,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, MLFloat16, ThresholdedRelu); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 10, TopK); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 12, Mod); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 19, IsInf); // opset 11 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, float, ArgMax); @@ -1342,6 +1343,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, R class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, Scan); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, Shape); +// Opset 20 +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 20, IsInf); + template <> KernelCreateInfo BuildKernelCreateInfo() { return {}; @@ -1738,6 +1742,8 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, // opset 11 BuildKernelCreateInfo, @@ -2294,6 +2300,9 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + + // opset 20 + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/test/providers/cpu/tensor/isinf_test.cc b/onnxruntime/test/providers/cpu/tensor/isinf_test.cc index 2e583c5d2547b..bd97306142f18 100644 --- a/onnxruntime/test/providers/cpu/tensor/isinf_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/isinf_test.cc @@ -99,6 +99,48 @@ TEST(IsInfTest, test_isinf_negative_double20) { run_is_inf_test(20, 0, 1, input, output); } +TEST(IsInfTest, test_isinf_mlfloat16) { + std::initializer_list input = {MLFloat16{-1.7f}, MLFloat16::NaN, MLFloat16::Infinity, 3.6_fp16, + MLFloat16::NegativeInfinity, MLFloat16::Infinity}; + std::initializer_list output = {false, false, true, false, true, true}; + run_is_inf_test(20, 1, 1, input, output); +} + +TEST(IsInfTest, test_isinf_positive_mlfloat16) { + std::initializer_list input = {MLFloat16{-1.7f}, MLFloat16::NaN, MLFloat16::Infinity, 3.6_fp16, + MLFloat16::NegativeInfinity, MLFloat16::Infinity}; + std::initializer_list output = {false, false, true, false, false, true}; + run_is_inf_test(20, 1, 0, input, output); +} + +TEST(IsInfTest, test_isinf_negative_mlfloat16) { + std::initializer_list input = {MLFloat16{-1.7f}, MLFloat16::NaN, MLFloat16::Infinity, 3.6_fp16, + MLFloat16::NegativeInfinity, MLFloat16::Infinity}; + std::initializer_list output = {false, false, false, false, true, false}; + run_is_inf_test(20, 0, 1, input, output); +} + +TEST(IsInfTest, test_isinf_bfloat16) { + std::initializer_list input = {BFloat16{-1.7f}, BFloat16::NaN, BFloat16::Infinity, 3.6_bfp16, + BFloat16::NegativeInfinity, BFloat16::Infinity}; + std::initializer_list output = {false, false, true, false, true, true}; + run_is_inf_test(20, 1, 1, input, output); +} + +TEST(IsInfTest, test_isinf_positive_bfloat16) { + std::initializer_list input = {BFloat16{-1.7f}, BFloat16::NaN, BFloat16::Infinity, 3.6_bfp16, + BFloat16::NegativeInfinity, BFloat16::Infinity}; + std::initializer_list output = {false, false, true, false, false, true}; + run_is_inf_test(20, 1, 0, input, output); +} + +TEST(IsInfTest, test_isinf_negative_bfloat16) { + std::initializer_list input = {BFloat16{-1.7f}, BFloat16::NaN, BFloat16::Infinity, 3.6_bfp16, + BFloat16::NegativeInfinity, BFloat16::Infinity}; + std::initializer_list output = {false, false, false, false, true, false}; + run_is_inf_test(20, 0, 1, input, output); +} + #if !defined(DISABLE_FLOAT8_TYPES) TEST(IsInfTest, test_Float8E4M3FN) { std::initializer_list input = { From 8c19067ae9ac0a663fc756e6a5eb6d5eaf954257 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Tue, 5 Mar 2024 10:29:51 -0800 Subject: [PATCH 3/6] Address build issues --- .../core/providers/cuda/cu_inc/common.cuh | 6 ++--- .../cuda/math/unary_elementwise_ops.cc | 8 +------ .../cuda/math/unary_elementwise_ops_impl.cu | 22 +++++++------------ .../cuda/math/unary_elementwise_ops_impl.h | 8 +++++++ 4 files changed, 20 insertions(+), 24 deletions(-) diff --git a/onnxruntime/core/providers/cuda/cu_inc/common.cuh b/onnxruntime/core/providers/cuda/cu_inc/common.cuh index e86435f1882f4..bba9178348132 100644 --- a/onnxruntime/core/providers/cuda/cu_inc/common.cuh +++ b/onnxruntime/core/providers/cuda/cu_inc/common.cuh @@ -446,7 +446,7 @@ struct IsInfTyped { // because there isinf() returns int // and we want to avoid stupid warnings return static_cast(isinf(a)); - }; + } static __device__ __inline__ bool IsInfPos(T a) { return a == std::numeric_limits::infinity(); } @@ -460,7 +460,7 @@ struct IsInfTyped { static __device__ __inline__ bool IsInf(half a) { return MLFloat16::kPositiveInfinityBits == static_cast(*reinterpret_cast(&a) & ~MLFloat16::kSignMask); - }; + } static __device__ __inline__ bool IsInfPos(half a) { return MLFloat16::kPositiveInfinityBits == *reinterpret_cast(&a); } @@ -474,7 +474,7 @@ struct IsInfTyped { static __device__ __inline__ bool IsInf(BFloat16 a) { return BFloat16::kPositiveInfinityBits == static_cast(*reinterpret_cast(&a) & ~BFloat16::kSignMask); - }; + } static __device__ __inline__ bool IsInfPos(BFloat16 a) { return BFloat16::kPositiveInfinityBits == *reinterpret_cast(&a); } diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc index 12b3c96c4b295..56ca139d18bf9 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc @@ -70,12 +70,6 @@ Status UnaryElementwise::Prepare(OpKernelContext* context, UnaryElementwisePrepa \ return Status::OK(); \ } -#if !defined(DISABLE_FLOAT8_TYPES) -#define OPSET20_CONSTRAINTS float, double, MLFloat16, BFloat16, Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, \ - Float8E5M2FNUZ -#else -#define OPSET20_CONSTRAINTS float, double, MLFloat16, BFloat16 -#endif ONNX_OPERATOR_VERSIONED_KERNEL_EX( IsInf, @@ -94,7 +88,7 @@ ONNX_OPERATOR_KERNEL_EX( 20, kCudaExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T1", BuildKernelDefConstraints()) + .TypeConstraint("T1", BuildKernelDefConstraints()) .TypeConstraint("T2", DataTypeImpl::GetTensorType()), IsInf); diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu index 36c1624ddc1ad..9adb28da22b0b 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu @@ -287,19 +287,19 @@ EXPLICIT_IMPL_CASTSAT(__nv_bfloat16, Float8E5M2) namespace isinf_details { template -struct IsInf_Impl { +struct IsInf_DispFunc { void operator()(cudaStream_t stream, const void* input_raw, bool* output_data, bool detect_positive, bool detect_negative, size_t count) const { using CudaType = typename ToCudaType::MappedType; const auto* input_data = reinterpret_cast(input_raw); if (detect_positive && detect_negative) { - UnaryElementWiseImpl(stream, input_data, output_data, _IsInf(), count); + UnaryElementWiseImpl(stream, input_data, output_data, _IsInf{}, count); } else if (detect_positive) { - UnaryElementWiseImpl(stream, input_data, output_data, _IsInf(), count); + UnaryElementWiseImpl(stream, input_data, output_data, _IsInf{}, count); } else if (detect_negative) { - UnaryElementWiseImpl(stream, input_data, output_data, _IsInf(), count); + UnaryElementWiseImpl(stream, input_data, output_data, _IsInf{}, count); } else { - UnaryElementWiseImpl(stream, input_data, output_data, _IsInf(), count); + UnaryElementWiseImpl(stream, input_data, output_data, _IsInf{}, count); } } }; @@ -313,17 +313,11 @@ void Explicit_Impl_IsInf(cudaStream_t stream, int op_set, size_t count) { if (op_set < 20) { utils::MLTypeCallDispatcher dispatcher{input_data_type}; - dispatcher.Invoke(stream, input_raw, output_data, + dispatcher.Invoke(stream, input_raw, output_data, detect_positive, detect_negative, count); } else { - utils::MLTypeCallDispatcher - dispatcher{input_data_type}; - dispatcher.Invoke(stream, input_raw, output_data, + utils::MLTypeCallDispatcher dispatcher{input_data_type}; + dispatcher.Invoke(stream, input_raw, output_data, detect_positive, detect_negative, count); } } diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h index 32dd622db85c9..63b293f1be396 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h @@ -138,6 +138,14 @@ void Impl_CastSat( #endif // IsInf + +#if !defined(DISABLE_FLOAT8_TYPES) +#define ISINF_OPSET20_CONSTRAINTS float, double, MLFloat16, BFloat16, Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, \ + Float8E5M2FNUZ +#else +#define ISINF_OPSET20_CONSTRAINTS float, double, MLFloat16, BFloat16 +#endif + void Explicit_Impl_IsInf(cudaStream_t stream, int op_set, bool detect_positive, bool detect_negative, int32_t input_data_type, From 668ae7b0260e531be0cdc1d1c6168df5ea3c6adf Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Tue, 5 Mar 2024 10:31:09 -0800 Subject: [PATCH 4/6] Update docs --- docs/OperatorKernels.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 71b0def659741..4514a85531d6b 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -160,7 +160,7 @@ Do not modify directly.* |||[1, 10]|**B** = tensor(bool)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |ImageScaler|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(float)| |InstanceNormalization|*in* input:**T**
*in* scale:**T**
*in* B:**T**
*out* output:**T**|6+|**T** = tensor(float)| -|IsInf|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(double), tensor(float), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**T2** = tensor(bool)| +|IsInf|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**T2** = tensor(bool)| |||[10, 19]|**T1** = tensor(double), tensor(float)
**T2** = tensor(bool)| |IsNaN|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**T2** = tensor(bool)| |||[13, 19]|**T1** = tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)| @@ -631,6 +631,8 @@ Do not modify directly.* |||[1, 10]|**B** = tensor(bool)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |ImageScaler|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |InstanceNormalization|*in* input:**T**
*in* scale:**T**
*in* B:**T**
*out* output:**T**|6+|**T** = tensor(double), tensor(float), tensor(float16)| +|IsInf|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**T2** = tensor(bool)| +|||[10, 19]|**T1** = tensor(double), tensor(float)
**T2** = tensor(bool)| |LRN|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| |||[1, 12]|**T** = tensor(double), tensor(float), tensor(float16)| |LSTM|*in* X:**T**
*in* W:**T**
*in* R:**T**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*in* initial_c:**T**
*in* P:**T**
*out* Y:**T**
*out* Y_h:**T**
*out* Y_c:**T**|14+|**T** = tensor(double), tensor(float), tensor(float16)
**T1** = tensor(int32)| From a5d8970b813d12f363aea3223ad994e2ee800a63 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Tue, 5 Mar 2024 10:42:02 -0800 Subject: [PATCH 5/6] Address types macro and attribute default values --- onnxruntime/core/providers/cpu/tensor/isinf.cc | 6 ++---- .../providers/cuda/math/unary_elementwise_ops.cc | 12 +++--------- .../cuda/math/unary_elementwise_ops_impl.cu | 2 +- .../providers/cuda/math/unary_elementwise_ops_impl.h | 6 +++--- 4 files changed, 9 insertions(+), 17 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/isinf.cc b/onnxruntime/core/providers/cpu/tensor/isinf.cc index e16de0feea559..9d18d1fa62288 100644 --- a/onnxruntime/core/providers/cpu/tensor/isinf.cc +++ b/onnxruntime/core/providers/cpu/tensor/isinf.cc @@ -78,10 +78,8 @@ ONNX_CPU_OPERATOR_KERNEL( IsInf); IsInf::IsInf(const OpKernelInfo& info) : OpKernel(info) { - Status status = info.GetAttr("detect_positive", &detect_positive_); - ORT_ENFORCE(status.IsOK(), "Failed to obtain detect_positive"); - status = info.GetAttr("detect_negative", &detect_negative_); - ORT_ENFORCE(status.IsOK(), "Failed to obtain detect_negative"); + detect_positive_ = info.GetAttrOrDefault("detect_positive", 1); + detect_negative_ = info.GetAttrOrDefault("detect_negative", 1); opset_ = info.node().SinceVersion(); } diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc index 56ca139d18bf9..00de1b37f3302 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc @@ -88,19 +88,13 @@ ONNX_OPERATOR_KERNEL_EX( 20, kCudaExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T1", BuildKernelDefConstraints()) + .TypeConstraint("T1", BuildKernelDefConstraints()) .TypeConstraint("T2", DataTypeImpl::GetTensorType()), IsInf); IsInf::IsInf(const OpKernelInfo& info) : UnaryElementwise(info) { - int64_t detect_positive = 0, detect_negative = 0; - Status status = info.GetAttr("detect_positive", &detect_positive); - ORT_ENFORCE(status.IsOK(), "Failed to obtain detect_positive"); - status = info.GetAttr("detect_negative", &detect_negative); - ORT_ENFORCE(status.IsOK(), "Failed to obtain detect_negative"); - - detect_positive_ = static_cast(detect_positive); - detect_negative_ = static_cast(detect_negative); + detect_positive_ = static_cast(info.GetAttrOrDefault("detect_positive", 1)); + detect_negative_ = static_cast(info.GetAttrOrDefault("detect_negative", 1)); opset_ = info.node().SinceVersion(); } diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu index 9adb28da22b0b..fd8f7929d4426 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu @@ -316,7 +316,7 @@ void Explicit_Impl_IsInf(cudaStream_t stream, int op_set, dispatcher.Invoke(stream, input_raw, output_data, detect_positive, detect_negative, count); } else { - utils::MLTypeCallDispatcher dispatcher{input_data_type}; + utils::MLTypeCallDispatcher dispatcher{input_data_type}; dispatcher.Invoke(stream, input_raw, output_data, detect_positive, detect_negative, count); } diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h index 63b293f1be396..a606d479bc79b 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h @@ -140,10 +140,10 @@ void Impl_CastSat( // IsInf #if !defined(DISABLE_FLOAT8_TYPES) -#define ISINF_OPSET20_CONSTRAINTS float, double, MLFloat16, BFloat16, Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, \ - Float8E5M2FNUZ +#define ISINF_OPSET20_ALL_FLOATS float, double, MLFloat16, BFloat16, Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, \ + Float8E5M2FNUZ #else -#define ISINF_OPSET20_CONSTRAINTS float, double, MLFloat16, BFloat16 +#define ISINF_OPSET20_ALL_FLOATS float, double, MLFloat16, BFloat16 #endif void Explicit_Impl_IsInf(cudaStream_t stream, int op_set, From 94c0da76a3760d2cf47dddbc7ab1d5200a7e616b Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Tue, 5 Mar 2024 11:17:11 -0800 Subject: [PATCH 6/6] Replicate IsInf CUDA functors to ROCM --- .../core/providers/rocm/cu_inc/common.cuh | 94 +++++++++++++++++++ 1 file changed, 94 insertions(+) diff --git a/onnxruntime/core/providers/rocm/cu_inc/common.cuh b/onnxruntime/core/providers/rocm/cu_inc/common.cuh index 5f966ac746fcb..f3685606c17f5 100644 --- a/onnxruntime/core/providers/rocm/cu_inc/common.cuh +++ b/onnxruntime/core/providers/rocm/cu_inc/common.cuh @@ -335,6 +335,100 @@ __device__ __inline__ BFloat16 _Fmod(BFloat16 a, BFloat16 b) { return fmodf((float)a, (float)b); } +namespace isinf_details { +template +struct IsInfTyped { + static __device__ __inline__ bool IsInf(T a) { + // cast is needed because on non MS compilers, + // because there isinf() returns int + // and we want to avoid stupid warnings + return static_cast(isinf(a)); + } + static __device__ __inline__ bool IsInfPos(T a) { + return a == std::numeric_limits::infinity(); + } + static __device__ __inline__ bool IsInfNeg(T a) { + return a == -std::numeric_limits::infinity(); + } +}; + +template <> +struct IsInfTyped { + static __device__ __inline__ bool IsInf(half a) { + return MLFloat16::kPositiveInfinityBits == + static_cast(*reinterpret_cast(&a) & ~MLFloat16::kSignMask); + } + static __device__ __inline__ bool IsInfPos(half a) { + return MLFloat16::kPositiveInfinityBits == *reinterpret_cast(&a); + } + static __device__ __inline__ bool IsInfNeg(half a) { + return MLFloat16::kNegativeInfinityBits == *reinterpret_cast(&a); + } +}; + +template <> +struct IsInfTyped { + static __device__ __inline__ bool IsInf(BFloat16 a) { + return BFloat16::kPositiveInfinityBits == + static_cast(*reinterpret_cast(&a) & ~BFloat16::kSignMask); + } + static __device__ __inline__ bool IsInfPos(BFloat16 a) { + return BFloat16::kPositiveInfinityBits == *reinterpret_cast(&a); + } + static __device__ __inline__ bool IsInfNeg(BFloat16 a) { + return BFloat16::kNegativeInfinityBits == *reinterpret_cast(&a); + } +}; + +#if !defined(DISABLE_FLOAT8_TYPES) + +template +struct ReturnFalse { + constexpr static bool __device__ __inline__ IsInf(T) { return false; } + constexpr static bool __device__ __inline__ IsInfPos(T) { return false; } + constexpr static bool __device__ __inline__ IsInfNeg(T) { return false; } +}; + +template <> +struct IsInfTyped : ReturnFalse {}; + +template <> +struct IsInfTyped : ReturnFalse {}; + +template <> +struct IsInfTyped { + static __device__ __inline__ bool IsInf(Float8E5M2 a) { + return a.val == 0b01111100 || a.val == 0b11111100; + } + static __device__ __inline__ bool IsInfPos(Float8E5M2 a) { + return a.val == 0b01111100; + } + static __device__ __inline__ bool IsInfNeg(Float8E5M2 a) { + return a.val == 0b11111100; + } +}; + +template <> +struct IsInfTyped : ReturnFalse {}; + +#endif +} // namespace isinf_details + +template +struct _IsInf { + __device__ __inline__ bool operator()(T a) const { + if constexpr (detect_positive && detect_negative) { + return isinf_details::IsInfTyped::IsInf(a); + } else if constexpr (detect_positive) { + return isinf_details::IsInfTyped::IsInfPos(a); + } else if constexpr (detect_negative) { + return isinf_details::IsInfTyped::IsInfNeg(a); + } else { + return false; + } + } +}; + // We would like to use 64-bit integer to support large matrices. However, ROCM seems to support only 32-bit integer // For now, use int32_t to ensure that both Linux and Windows see this as 32 bit integer type. #ifndef HIP_LONG