diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc
index 91717486b77cb..a78ff69e5c894 100644
--- a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc
+++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc
@@ -757,9 +757,11 @@ static Status MinMaxMLFloat16(const OpKernel& inst, OpKernelContext* context) {
         EigenVectorArrayMap<Eigen::half> output_vec_map(output, num_elements);
 
         if (is_min) {
-          output_vec_map = input_1_vec_map.min(static_cast<Eigen::half>(per_iter_bh.ScalarInput0<MLFloat16>()));
+          output_vec_map = input_1_vec_map.template min<Eigen::PropagateNaN>(
+              static_cast<Eigen::half>(per_iter_bh.ScalarInput0<MLFloat16>()));
         } else {
-          output_vec_map = input_1_vec_map.max(static_cast<Eigen::half>(per_iter_bh.ScalarInput0<MLFloat16>()));
+          output_vec_map = input_1_vec_map.template max<Eigen::PropagateNaN>(
+              static_cast<Eigen::half>(per_iter_bh.ScalarInput0<MLFloat16>()));
         }
       },
       [](BroadcastHelper& per_iter_bh) {
@@ -772,9 +774,11 @@ static Status MinMaxMLFloat16(const OpKernel& inst, OpKernelContext* context) {
         EigenVectorArrayMap<Eigen::half> output_vec_map(output, num_elements);
 
         if (is_min) {
-          output_vec_map = input_0_vec_map.min(static_cast<Eigen::half>(per_iter_bh.ScalarInput1<MLFloat16>()));
+          output_vec_map = input_0_vec_map.template min<Eigen::PropagateNaN>(
+              static_cast<Eigen::half>(per_iter_bh.ScalarInput1<MLFloat16>()));
         } else {
-          output_vec_map = input_0_vec_map.max(static_cast<Eigen::half>(per_iter_bh.ScalarInput1<MLFloat16>()));
+          output_vec_map = input_0_vec_map.template max<Eigen::PropagateNaN>(
+              static_cast<Eigen::half>(per_iter_bh.ScalarInput1<MLFloat16>()));
         }
       },
       [](BroadcastHelper& per_iter_bh) {
@@ -790,9 +794,9 @@ static Status MinMaxMLFloat16(const OpKernel& inst, OpKernelContext* context) {
         EigenVectorArrayMap<Eigen::half> output_vec_map(output, num_elements);
 
         if (is_min) {
-          output_vec_map = input_0_vec_map.min(input_1_vec_map);
+          output_vec_map = input_0_vec_map.template min<Eigen::PropagateNaN>(input_1_vec_map);
         } else {
-          output_vec_map = input_0_vec_map.max(input_1_vec_map);
+          output_vec_map = input_0_vec_map.template max<Eigen::PropagateNaN>(input_1_vec_map);
         }
       }};
 
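For reviewers unfamiliar with Eigen's NaN-propagation policies, here is a minimal standalone sketch (not part of the patch; assumes Eigen 3.4+, which introduced the policy template parameter) contrasting the default coefficient-wise `min` with the `Eigen::PropagateNaN` variant the CPU kernel now uses. ONNX `Min`/`Max` opset 13 require NaN to propagate, which the default policy does not guarantee:

```cpp
// Sketch only: contrasts Eigen's default cwise min (NaN handling is
// unspecified; in practice the non-NaN operand often wins) with the
// PropagateNaN policy, which yields NaN whenever either input is NaN.
#include <iostream>
#include <limits>
#include <Eigen/Core>

int main() {
  const float nan = std::numeric_limits<float>::quiet_NaN();
  Eigen::Array3f a(1.0f, nan, 3.0f);
  Eigen::Array3f b(2.0f, 5.0f, nan);

  std::cout << a.min(b).transpose() << "\n";                       // e.g. "1 5 3"
  std::cout << a.min<Eigen::PropagateNaN>(b).transpose() << "\n";  // "1 nan nan"
}
```

The `.template` keyword in the patch itself is only needed because the call sites sit inside a dependent (lambda/template) context; a plain call site like the sketch above can omit it.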
diff --git a/onnxruntime/core/providers/cuda/cu_inc/common.cuh b/onnxruntime/core/providers/cuda/cu_inc/common.cuh
index db36754319309..55935a9eae86d 100644
--- a/onnxruntime/core/providers/cuda/cu_inc/common.cuh
+++ b/onnxruntime/core/providers/cuda/cu_inc/common.cuh
@@ -10,13 +10,10 @@
 #include <cuda_fp16.h>
 #include <cmath>
 #include <limits>
+#include <cuda_bf16.h>
 #include "core/providers/cuda/cuda_common.h"
 #include "core/providers/cuda/shared_inc/cuda_call.h"
 
-#if CUDA_VERSION >= 11000
-#include <cuda_bf16.h>
-#endif
-
 namespace onnxruntime {
 namespace cuda {
 
@@ -347,6 +344,21 @@ __device__ __inline__ double _Pow(double a, double b) { return pow(a, b); }
 template <>
 __device__ __inline__ half _Pow(half a, half b) { return half(powf((float)a, (float)b)); }
 
+#define ISNAN_HALF(v__) static_cast<uint16_t>(*reinterpret_cast<const uint16_t*>(&v__) & ~MLFloat16::kSignMask) \
+    > MLFloat16::kPositiveInfinityBits
+
+#define ISNAN_BFLOAT16(v__) static_cast<uint16_t>(*reinterpret_cast<const uint16_t*>(&v__) & ~BFloat16::kSignMask) \
+    > BFloat16::kPositiveInfinityBits
+
+// CUDART_NAN_BF16 and CUDART_NAN_FP16 constants were only added in CUDA 12.2,
+// so define our own equivalent constants to support older versions.
+// Note that there is no consistent canonical NaN for FP16 and BF16;
+// CUDA uses 0x7FFF for both, but ONNX Runtime uses 0x7E00 and 0x7FC1
+// for FP16 and BF16 respectively
+// (see Float16Impl::kPositiveQNaNBits and BFloat16Impl::kPositiveQNaNBits).
+#define NAN_HALF __ushort_as_half((unsigned short)0x7FFFU)
+#define NAN_BFLOAT16 BFloat16::FromBits((uint16_t)0x7FFFU)
+
 template <typename T>
 __device__ __inline__ T _Min(T a, T b) { return a < b ? a : b; }
 
@@ -360,6 +372,24 @@ __device__ __inline__ double _Min(double a, double b) {
   return (isnan(a) || isnan(b)) ? std::numeric_limits<double>::quiet_NaN() : ( a < b ? a : b );
 }
 
+template <>
+__device__ __inline__ half _Min(half a, half b) {
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) && ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2)))
+  return (ISNAN_HALF(a) || ISNAN_HALF(b)) ? NAN_HALF : (a < b ? a : b);
+#else
+  return __hmin_nan(a, b);
+#endif
+}
+
+template <>
+__device__ __inline__ BFloat16 _Min(BFloat16 a, BFloat16 b) {
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) && ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2)))
+  return (ISNAN_BFLOAT16(a) || ISNAN_BFLOAT16(b)) ? NAN_BFLOAT16 : (a < b ? a : b);
+#else
+  return BFloat16(__hmin_nan((__nv_bfloat16)a, (__nv_bfloat16)b));
+#endif
+}
+
 template <typename T>
 __device__ __inline__ T _Max(T a, T b) { return a > b ? a : b; }
 
@@ -373,6 +403,29 @@ __device__ __inline__ double _Max(double a, double b) {
   return (isnan(a) || isnan(b)) ? std::numeric_limits<double>::quiet_NaN() : ( a > b ? a : b );
 }
 
+template <>
+__device__ __inline__ half _Max(half a, half b) {
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) && ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2)))
+  return (ISNAN_HALF(a) || ISNAN_HALF(b)) ? NAN_HALF : (a > b ? a : b);
+#else
+  return __hmax_nan(a, b);
+#endif
+}
+
+template <>
+__device__ __inline__ BFloat16 _Max(BFloat16 a, BFloat16 b) {
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) && ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2)))
+  return (ISNAN_BFLOAT16(a) || ISNAN_BFLOAT16(b)) ? NAN_BFLOAT16 : (a > b ? a : b);
+#else
+  return BFloat16(__hmax_nan((__nv_bfloat16)a, (__nv_bfloat16)b));
+#endif
+}
+
+#undef ISNAN_HALF
+#undef ISNAN_BFLOAT16
+#undef NAN_HALF
+#undef NAN_BFLOAT16
+
 template <typename T>
 __device__ __inline__ T _Abs(T a) { return a > (T)0 ? a : -a; }
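The `ISNAN_HALF`/`ISNAN_BFLOAT16` macros above use a standard IEEE-754 bit trick: clear the sign bit and compare the remaining magnitude bits against the positive-infinity pattern. A small host-side sketch (not part of the patch; the local constants mirror `MLFloat16::kSignMask` and `MLFloat16::kPositiveInfinityBits` for binary16) that can be verified on the CPU:

```cpp
// Sketch only: a binary16 value is NaN iff its magnitude bits (sign
// cleared) exceed the +inf pattern (exponent all ones, mantissa != 0).
#include <cassert>
#include <cstdint>

constexpr uint16_t kSignMask = 0x8000;              // mirrors MLFloat16::kSignMask
constexpr uint16_t kPositiveInfinityBits = 0x7C00;  // mirrors MLFloat16::kPositiveInfinityBits

constexpr bool IsNaNFp16(uint16_t bits) {
  return static_cast<uint16_t>(bits & ~kSignMask) > kPositiveInfinityBits;
}

int main() {
  assert(IsNaNFp16(0x7E00));   // canonical quiet NaN (Float16Impl::kPositiveQNaNBits)
  assert(IsNaNFp16(0xFE00));   // negative NaN: the sign bit is ignored
  assert(IsNaNFp16(0x7FFF));   // CUDA's canonical NaN, per the comment in the patch
  assert(!IsNaNFp16(0x7C00));  // +inf is not NaN
  assert(!IsNaNFp16(0x3C00));  // 1.0
  return 0;
}
```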
diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc
index eb914646942fe..507ed8e91a728 100644
--- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc
+++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc
@@ -1787,54 +1787,90 @@ TEST(MathOpTest, Min_12_MLFloat16_Scalar1) {
   test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});  // TensorRT: Input batch size is inconsistent
 }
 
-TEST(MathOpTest, Min_12_MLFloat16_MatrixVector) {
-  OpTester test("Min", 12);
-  test.AddInput<MLFloat16>("data_0", {3, 3},
-                           MakeMLFloat16({1.0f, 1.0f, 1.0f,
-                                          -0.5f, 0.0f, -2.0f,
-                                          0.5f, 0.0f, 2.0f}));
-  test.AddInput<MLFloat16>("data_1", {3, 1},
-                           MakeMLFloat16({0.0f, -1.0f, 1.0f}));
-  test.AddOutput<MLFloat16>("min", {3, 3},
-                            MakeMLFloat16({0.0f, 0.0f, 0.0f,
-                                           -1.0f, -1.0f, -2.0f,
-                                           0.5f, 0.0f, 1.0f}));
-  if (nullptr != DefaultCpuExecutionProvider()) {
+void TestFloat16MinMax(
+    const char* op_name,
+    const std::vector<int64_t>& lhs_dim,
+    const std::initializer_list<float>& lhs_values,
+    const std::vector<int64_t>& rhs_dim,
+    const std::initializer_list<float>& rhs_values,
+    const std::vector<int64_t>& out_dim,
+    const std::initializer_list<float>& out_values) {
+  {
     std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
-    execution_providers.push_back(DefaultCpuExecutionProvider());
+    if (nullptr != DefaultCpuExecutionProvider()) {
+      execution_providers.push_back(DefaultCpuExecutionProvider());
+    }
+    if (nullptr != DefaultCudaExecutionProvider()) {
+      execution_providers.push_back(DefaultCudaExecutionProvider());
+    }
+    OpTester test(op_name, 13);
+    test.AddInput<MLFloat16>("data_0", lhs_dim, MakeMLFloat16(lhs_values));
+    test.AddInput<MLFloat16>("data_1", rhs_dim, MakeMLFloat16(rhs_values));
+    test.AddOutput<MLFloat16>("output", out_dim, MakeMLFloat16(out_values));
     test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
   }
   if (nullptr != DefaultCudaExecutionProvider()) {
     std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
     execution_providers.push_back(DefaultCudaExecutionProvider());
-    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
-  }
-}
-TEST(MathOpTest, Min_12_MLFloat16_VectorMatrix) {
-  OpTester test("Min", 12);
-  test.AddInput<MLFloat16>("data_0", {3, 1},
-                           MakeMLFloat16({0.0f, -1.0f, 1.0f}));
-  test.AddInput<MLFloat16>("data_1", {3, 4},
-                           MakeMLFloat16({1.0f, 1.0f, 1.0f, -1.0f,
-                                          -0.5f, 0.0f, -2.0f, -1.25f,
-                                          0.5f, 0.0f, 2.0f, 1.5f}));
-  test.AddOutput<MLFloat16>("min", {3, 4},
-                            MakeMLFloat16({0.0f, 0.0f, 0.0f, -1.0f,
-                                           -1.0f, -1.0f, -2.0f, -1.25f,
-                                           0.5f, 0.0f, 1.0f, 1.0f}));
-  if (nullptr != DefaultCpuExecutionProvider()) {
-    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
-    execution_providers.push_back(DefaultCpuExecutionProvider());
-    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
-  }
-  if (nullptr != DefaultCudaExecutionProvider()) {
-    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
-    execution_providers.push_back(DefaultCudaExecutionProvider());
+    OpTester test(op_name, 13);
+    test.AddInput<BFloat16>("data_0", lhs_dim, MakeBFloat16(lhs_values));
+    test.AddInput<BFloat16>("data_1", rhs_dim, MakeBFloat16(rhs_values));
+    test.AddOutput<BFloat16>("output", out_dim, MakeBFloat16(out_values));
     test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
   }
 }
 
+TEST(MathOpTest, Min_13_Float16_MatrixVector) {
+  TestFloat16MinMax("Min",
+                    {3, 3},
+                    {1.0f, 1.0f, 1.0f,
+                     -0.5f, 0.0f, -2.0f,
+                     0.5f, 0.0f, 2.0f},
+                    {3, 1}, {0.0f, -1.0f, 1.0f},
+                    {3, 3},
+                    {0.0f, 0.0f, 0.0f,
+                     -1.0f, -1.0f, -2.0f,
+                     0.5f, 0.0f, 1.0f});
+}
+
+TEST(MathOpTest, Min_13_Float16_VectorMatrix) {
+  TestFloat16MinMax("Min",
+                    {3, 1}, {0.0f, -1.0f, 1.0f},
+                    {3, 4},
+                    {1.0f, 1.0f, 1.0f, -1.0f,
+                     -0.5f, 0.0f, -2.0f, -1.25f,
+                     0.5f, 0.0f, 2.0f, 1.5f},
+                    {3, 4},
+                    {0.0f, 0.0f, 0.0f, -1.0f,
+                     -1.0f, -1.0f, -2.0f, -1.25f,
+                     0.5f, 0.0f, 1.0f, 1.0f});
+}
+
+TEST(MathOpTest, Min_13_Float16_Nan) {
+  TestFloat16MinMax("Min",
+                    {4, 1}, {-1.0f, std::numeric_limits<float>::quiet_NaN(), 1.0f, 0.5f},
+                    {4, 1}, {0.5f, 1.0f, 0.25f, std::numeric_limits<float>::quiet_NaN()},
+                    {4, 1},
+                    {-1.0f, std::numeric_limits<float>::quiet_NaN(), 0.25f, std::numeric_limits<float>::quiet_NaN()});
+}
+
+TEST(MathOpTest, Min_13_Float16_Nan_with_scalar) {
+  TestFloat16MinMax("Min",
+                    {3, 1}, {-1.0f, std::numeric_limits<float>::quiet_NaN(), 1.0f},
+                    {1}, {0.25f},
+                    {3, 1}, {-1.0f, std::numeric_limits<float>::quiet_NaN(), 0.25f});
+}
+
+TEST(MathOpTest, Min_13_Float16_with_scalar_Nan) {
+  TestFloat16MinMax("Min",
+                    {3, 1}, {-0.5f, 1.0f, 1.5f},
+                    {1}, {std::numeric_limits<float>::quiet_NaN()},
+                    {3, 1},
+                    {std::numeric_limits<float>::quiet_NaN(),
+                     std::numeric_limits<float>::quiet_NaN(),
+                     std::numeric_limits<float>::quiet_NaN()});
+}
 
 TEST(MathOpTest, Max_6) {
   OpTester test("Max", 6);
   std::vector<int64_t> dims{3, 3};
@@ -2185,54 +2221,57 @@ TEST(MathOpTest, Max_12_MLFloat16_Scalar1) {
   test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});  // TensorRT: Input batch size is inconsistent
 }
 
-TEST(MathOpTest, Max_12_MLFloat16_MatrixVector) {
-  OpTester test("Max", 12);
-  test.AddInput<MLFloat16>("data_0", {4, 3},
-                           MakeMLFloat16({1.0f, 1.0f, 1.0f,
-                                          -0.5f, 0.0f, -2.0f,
-                                          0.0f, 0.5f, 0.75f,
-                                          0.5f, 0.0f, 2.0f}));
-  test.AddInput<MLFloat16>("data_1", {4, 1},
-                           MakeMLFloat16({0.0f, -1.0f, 0.5f, 1.0f}));
-  test.AddOutput<MLFloat16>("max", {4, 3},
-                            MakeMLFloat16({1.0f, 1.0f, 1.0f,
-                                           -0.5f, 0.0f, -1.0f,
-                                           0.5f, 0.5f, 0.75f,
-                                           1.0f, 1.0f, 2.0f}));
-  if (nullptr != DefaultCpuExecutionProvider()) {
-    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
-    execution_providers.push_back(DefaultCpuExecutionProvider());
-    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
-  }
-  if (nullptr != DefaultCudaExecutionProvider()) {
-    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
-    execution_providers.push_back(DefaultCudaExecutionProvider());
-    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
-  }
-}
-
-TEST(MathOpTest, Max_12_MLFloat16_VectorMatrix) {
-  OpTester test("Max", 12);
-  test.AddInput<MLFloat16>("data_0", {3, 1},
-                           MakeMLFloat16({0.0f, -1.0f, 1.0f}));
-  test.AddInput<MLFloat16>("data_1", {3, 3},
-                           MakeMLFloat16({1.0f, 1.0f, 1.0f,
-                                          -0.5f, 0.0f, -2.0f,
-                                          0.5f, 0.0f, 2.0f}));
-  test.AddOutput<MLFloat16>("max", {3, 3},
-                            MakeMLFloat16({1.0f, 1.0f, 1.0f,
-                                           -0.5f, 0.0f, -1.0f,
-                                           1.0f, 1.0f, 2.0f}));
-  if (nullptr != DefaultCpuExecutionProvider()) {
-    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
-    execution_providers.push_back(DefaultCpuExecutionProvider());
-    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
-  }
-  if (nullptr != DefaultCudaExecutionProvider()) {
-    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
-    execution_providers.push_back(DefaultCudaExecutionProvider());
-    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
-  }
+TEST(MathOpTest, Max_13_Float16_MatrixVector) {
+  TestFloat16MinMax("Max",
+                    {4, 3},
+                    {1.0f, 1.0f, 1.0f,
+                     -0.5f, 0.0f, -2.0f,
+                     0.0f, 0.5f, 0.75f,
+                     0.5f, 0.0f, 2.0f},
+                    {4, 1}, {0.0f, -1.0f, 0.5f, 1.0f},
+                    {4, 3},
+                    {1.0f, 1.0f, 1.0f,
+                     -0.5f, 0.0f, -1.0f,
+                     0.5f, 0.5f, 0.75f,
+                     1.0f, 1.0f, 2.0f});
+}
+
+TEST(MathOpTest, Max_13_Float16_VectorMatrix) {
+  TestFloat16MinMax("Max",
+                    {3, 1}, {0.0f, -1.0f, 1.0f},
+                    {3, 3},
+                    {1.0f, 1.0f, 1.0f,
+                     -0.5f, 0.0f, -2.0f,
+                     0.5f, 0.0f, 2.0f},
+                    {3, 3},
+                    {1.0f, 1.0f, 1.0f,
+                     -0.5f, 0.0f, -1.0f,
+                     1.0f, 1.0f, 2.0f});
+}
+
+TEST(MathOpTest, Max_13_Float16_Nan) {
+  TestFloat16MinMax("Max",
+                    {4, 1}, {-1.0f, std::numeric_limits<float>::quiet_NaN(), 1.0f, 0.5f},
+                    {4, 1}, {0.5f, 1.0f, 0.25f, std::numeric_limits<float>::quiet_NaN()},
+                    {4, 1},
+                    {0.5f, std::numeric_limits<float>::quiet_NaN(), 1.0f, std::numeric_limits<float>::quiet_NaN()});
+}
+
+TEST(MathOpTest, Max_13_Float16_Nan_with_scalar) {
+  TestFloat16MinMax("Max",
+                    {3, 1}, {-1.0f, std::numeric_limits<float>::quiet_NaN(), 1.0f},
+                    {1}, {0.25f},
+                    {3, 1}, {0.25f, std::numeric_limits<float>::quiet_NaN(), 1.0f});
+}
+
+TEST(MathOpTest, Max_13_Float16_with_scalar_Nan) {
+  TestFloat16MinMax("Max",
+                    {3, 1}, {-0.5f, 1.0f, 1.5f},
+                    {1}, {std::numeric_limits<float>::quiet_NaN()},
+                    {3, 1},
+                    {std::numeric_limits<float>::quiet_NaN(),
+                     std::numeric_limits<float>::quiet_NaN(),
+                     std::numeric_limits<float>::quiet_NaN()});
 }
 
 TEST(MathOpTest, Not) {
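As a sanity check of the intrinsic behavior the new `_Min`/`_Max` specializations rely on, a throwaway kernel like the sketch below (not part of the patch; assumes sm_80 or newer, or CUDA 12.2+, which provides the `_nan` variants on older architectures, exactly the split the preprocessor guards in common.cuh encode) shows the difference between the plain and NaN-propagating half-precision intrinsics:

```cuda
// Sketch only: __hmin favors the numeric operand, while __hmin_nan
// propagates NaN, matching ONNX Min/Max opset-13 semantics.
// Build with e.g.: nvcc -arch=sm_80 min_demo.cu
#include <cstdio>
#include <cuda_fp16.h>

__global__ void MinDemo() {
  const half nan_h = __ushort_as_half((unsigned short)0x7E00U);  // quiet NaN
  const half one = __float2half(1.0f);
  printf("__hmin:     %f\n", __half2float(__hmin(nan_h, one)));      // prints 1.000000
  printf("__hmin_nan: %f\n", __half2float(__hmin_nan(nan_h, one)));  // prints nan
}

int main() {
  MinDemo<<<1, 1>>>();
  cudaDeviceSynchronize();
  return 0;
}
```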