Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Recognize NaN operands in Min and Max ops #19984

Merged
merged 5 commits into from
Mar 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions onnxruntime/core/providers/cuda/cu_inc/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
#include <stdint.h>
#include <vector>
#include <mutex>
#include <limits>
#include <assert.h>
#include <math.h>

Check warning on line 10 in onnxruntime/core/providers/cuda/cu_inc/common.cuh

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Found C system header after C++ system header. Should be: common.h, c system, c++ system, other. [build/include_order] [4] Raw Output: onnxruntime/core/providers/cuda/cu_inc/common.cuh:10: Found C system header after C++ system header. Should be: common.h, c system, c++ system, other. [build/include_order] [4]
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include "core/providers/cuda/cuda_common.h"
Expand Down Expand Up @@ -345,9 +347,29 @@
template <typename T>
__device__ __inline__ T _Min(T a, T b) { return a < b ? a : b; }

template <>
__device__ __inline__ float _Min(float a, float b) {
return (isnan(a) || isnan(b)) ? std::numeric_limits<float>::quiet_NaN() : ( a < b ? a : b );
}

template <>
__device__ __inline__ double _Min(double a, double b) {
return (isnan(a) || isnan(b)) ? std::numeric_limits<double>::quiet_NaN() : ( a < b ? a : b );
}

template <typename T>
__device__ __inline__ T _Max(T a, T b) { return a > b ? a : b; }

template <>
__device__ __inline__ float _Max(float a, float b) {
return (isnan(a) || isnan(b)) ? std::numeric_limits<float>::quiet_NaN() : ( a > b ? a : b );
}

template <>
__device__ __inline__ double _Max(double a, double b) {
return (isnan(a) || isnan(b)) ? std::numeric_limits<double>::quiet_NaN() : ( a > b ? a : b );
}

template <typename T>
__device__ __inline__ T _Abs(T a) { return a > (T)0 ? a : -a; }

Expand Down
114 changes: 114 additions & 0 deletions onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "test/common/cuda_op_test_utils.h"
#include "core/util/math.h"
#include <algorithm>
#include <limits>

Check warning on line 11 in onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Found C++ system header after other header. Should be: element_wise_ops_test.h, c system, c++ system, other. [build/include_order] [4] Raw Output: onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc:11: Found C++ system header after other header. Should be: element_wise_ops_test.h, c system, c++ system, other. [build/include_order] [4]
#include <math.h>

namespace onnxruntime {
Expand Down Expand Up @@ -1506,6 +1507,34 @@
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); // TensorRT: Input batch size is inconsistent
}

TEST(MathOpTest, Min_12_Float_Nan) {
OpTester test("Min", 12);
test.AddInput<float>("data_2", {3, 3},
{std::numeric_limits<float>::quiet_NaN(),
std::numeric_limits<float>::quiet_NaN(),
std::numeric_limits<float>::quiet_NaN(),
-0.5f, 0.0f, -2.0f,
0.5f, 0.0f, 2.0f});
test.AddInput<float>("data_1", {3, 1},
{0.0f, -1.0f, 1.0f});
test.AddOutput<float>("min", {3, 3},
{std::numeric_limits<float>::quiet_NaN(),
std::numeric_limits<float>::quiet_NaN(),
std::numeric_limits<float>::quiet_NaN(),
-1.0f, -1.0f, -2.0f,
0.5f, 0.0f, 1.0f});
if (nullptr != DefaultCpuExecutionProvider().get()) {
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCpuExecutionProvider());
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
if (nullptr != DefaultCudaExecutionProvider().get()) {
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCudaExecutionProvider());
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
}

TEST(MathOpTest, Min_12_Double) {
OpTester test("Min", 12);
test.AddInput<double>("data_0", {1, 3},
Expand All @@ -1523,6 +1552,34 @@
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Input batch size is inconsistent
}

TEST(MathOpTest, Min_12_Double_Nan) {
OpTester test("Min", 12);
test.AddInput<double>("data_2", {3, 3},
{std::numeric_limits<double>::quiet_NaN(),
std::numeric_limits<double>::quiet_NaN(),
std::numeric_limits<double>::quiet_NaN(),
-0.5, 0.0, -2.0,
0.5, 0.0, 2.0});
test.AddInput<double>("data_1", {3, 1},
{0.0, -1.0, 1.0});
test.AddOutput<double>("min", {3, 3},
{std::numeric_limits<double>::quiet_NaN(),
std::numeric_limits<double>::quiet_NaN(),
std::numeric_limits<double>::quiet_NaN(),
-1.0, -1.0, -2.0,
0.5, 0.0, 1.0});
if (nullptr != DefaultCpuExecutionProvider().get()) {
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCpuExecutionProvider());
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
if (nullptr != DefaultCudaExecutionProvider().get()) {
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCudaExecutionProvider());
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
}

TEST(MathOpTest, Min_12_Int32) {
OpTester test("Min", 12);
test.AddInput<int32_t>("data_0", {1, 3},
Expand Down Expand Up @@ -1629,6 +1686,7 @@
MakeMLFloat16({-10.f, -10.f, -10.f}));
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Input batch size is inconsistent
}

Check warning on line 1689 in onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc:1689: Lines should be <= 120 characters long [whitespace/line_length] [2]
TEST(MathOpTest, Max_6) {
OpTester test("Max", 6);
std::vector<int64_t> dims{3, 3};
Expand Down Expand Up @@ -1717,6 +1775,34 @@
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); // TensorRT: Input batch size is inconsistent
}

TEST(MathOpTest, Max_12_Float_Nan) {
OpTester test("Max", 12);
test.AddInput<float>("data_2", {3, 3},
{std::numeric_limits<float>::quiet_NaN(),
std::numeric_limits<float>::quiet_NaN(),
std::numeric_limits<float>::quiet_NaN(),
-0.5f, 0.0f, -2.0f,
0.5f, 0.0f, 2.0f});
test.AddInput<float>("data_1", {3, 1},
{0.0f, -1.0f, 1.0f});
test.AddOutput<float>("max", {3, 3},
{std::numeric_limits<float>::quiet_NaN(),
std::numeric_limits<float>::quiet_NaN(),
std::numeric_limits<float>::quiet_NaN(),
-0.5f, 0.0f, -1.0f,
1.0f, 1.0f, 2.0f});
if (nullptr != DefaultCpuExecutionProvider().get()) {
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCpuExecutionProvider());
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
if (nullptr != DefaultCudaExecutionProvider().get()) {
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCudaExecutionProvider());
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
}

TEST(MathOpTest, Max_12_Double) {
OpTester test("Max", 12);
test.AddInput<double>("data_0", {1, 3},
Expand All @@ -1734,6 +1820,34 @@
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Input batch size is inconsistent
}

TEST(MathOpTest, Max_12_Double_Nan) {
OpTester test("Max", 12);
test.AddInput<double>("data_2", {3, 3},
{std::numeric_limits<double>::quiet_NaN(),
std::numeric_limits<double>::quiet_NaN(),
std::numeric_limits<double>::quiet_NaN(),
-0.5, 0.0, -2.0,
0.5, 0.0, 2.0});
test.AddInput<double>("data_1", {3, 1},
{0.0, -1.0, 1.0});
test.AddOutput<double>("max", {3, 3},
{std::numeric_limits<double>::quiet_NaN(),
std::numeric_limits<double>::quiet_NaN(),
std::numeric_limits<double>::quiet_NaN(),
-0.5, 0.0, -1.0,
1.0, 1.0, 2.0});
if (nullptr != DefaultCpuExecutionProvider().get()) {
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCpuExecutionProvider());
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
if (nullptr != DefaultCudaExecutionProvider().get()) {
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCudaExecutionProvider());
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
}

TEST(MathOpTest, Max_12_Int32) {
OpTester test("Max", 12);
test.AddInput<int32_t>("data_0", {1, 3},
Expand Down
Loading