Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix NaN propagation for float16 min and max operators #22161

Merged
merged 7 commits into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions onnxruntime/core/providers/cpu/math/element_wise_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -757,9 +757,11 @@ static Status MinMaxMLFloat16(const OpKernel& inst, OpKernelContext* context) {
EigenVectorArrayMap<Eigen::half> output_vec_map(output, num_elements);

if (is_min) {
output_vec_map = input_1_vec_map.min(static_cast<Eigen::half>(per_iter_bh.ScalarInput0<MLFloat16>()));
output_vec_map = input_1_vec_map.template min<Eigen::PropagateNaN>(
static_cast<Eigen::half>(per_iter_bh.ScalarInput0<MLFloat16>()));
} else {
output_vec_map = input_1_vec_map.max(static_cast<Eigen::half>(per_iter_bh.ScalarInput0<MLFloat16>()));
output_vec_map = input_1_vec_map.template max<Eigen::PropagateNaN>(
static_cast<Eigen::half>(per_iter_bh.ScalarInput0<MLFloat16>()));
}
},
[](BroadcastHelper& per_iter_bh) {
Expand All @@ -772,9 +774,11 @@ static Status MinMaxMLFloat16(const OpKernel& inst, OpKernelContext* context) {
EigenVectorArrayMap<Eigen::half> output_vec_map(output, num_elements);

if (is_min) {
output_vec_map = input_0_vec_map.min(static_cast<Eigen::half>(per_iter_bh.ScalarInput1<MLFloat16>()));
output_vec_map = input_0_vec_map.template min<Eigen::PropagateNaN>(
static_cast<Eigen::half>(per_iter_bh.ScalarInput1<MLFloat16>()));
} else {
output_vec_map = input_0_vec_map.max(static_cast<Eigen::half>(per_iter_bh.ScalarInput1<MLFloat16>()));
output_vec_map = input_0_vec_map.template max<Eigen::PropagateNaN>(
static_cast<Eigen::half>(per_iter_bh.ScalarInput1<MLFloat16>()));
}
},
[](BroadcastHelper& per_iter_bh) {
Expand All @@ -790,9 +794,9 @@ static Status MinMaxMLFloat16(const OpKernel& inst, OpKernelContext* context) {
EigenVectorArrayMap<Eigen::half> output_vec_map(output, num_elements);

if (is_min) {
output_vec_map = input_0_vec_map.min(input_1_vec_map);
output_vec_map = input_0_vec_map.template min<Eigen::PropagateNaN>(input_1_vec_map);
} else {
output_vec_map = input_0_vec_map.max(input_1_vec_map);
output_vec_map = input_0_vec_map.template max<Eigen::PropagateNaN>(input_1_vec_map);
}
}};

Expand Down
29 changes: 29 additions & 0 deletions onnxruntime/core/providers/cuda/cu_inc/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,12 @@ __device__ __inline__ double _Pow(double a, double b) { return pow(a, b); }
template <>
__device__ __inline__ half _Pow(half a, half b) { return half(powf((float)a, (float)b)); }

#define ISNAN_HALF(v__) static_cast<uint16_t>(*reinterpret_cast<const uint16_t*>(&v__) & ~MLFloat16::kSignMask) \
> MLFloat16::kPositiveInfinityBits

#define ISNAN_BFLOAT16(v__) static_cast<uint16_t>(*reinterpret_cast<const uint16_t*>(&v__) & ~BFloat16::kSignMask) \
> BFloat16::kPositiveInfinityBits

template <typename T>
__device__ __inline__ T _Min(T a, T b) { return a < b ? a : b; }

Expand All @@ -360,6 +366,16 @@ __device__ __inline__ double _Min(double a, double b) {
return (isnan(a) || isnan(b)) ? std::numeric_limits<double>::quiet_NaN() : ( a < b ? a : b );
}

template <>
__device__ __inline__ half _Min(half a, half b) {
return ISNAN_HALF(a) ? a : (ISNAN_HALF(b) ? b : (a < b ? a : b));
adamreeve marked this conversation as resolved.
Show resolved Hide resolved
}

template <>
__device__ __inline__ BFloat16 _Min(BFloat16 a, BFloat16 b) {
return ISNAN_BFLOAT16(a) ? a : (ISNAN_BFLOAT16(b) ? b : (a < b ? a : b));
}

template <typename T>
__device__ __inline__ T _Max(T a, T b) { return a > b ? a : b; }

Expand All @@ -373,6 +389,19 @@ __device__ __inline__ double _Max(double a, double b) {
return (isnan(a) || isnan(b)) ? std::numeric_limits<double>::quiet_NaN() : ( a > b ? a : b );
}

template <>
__device__ __inline__ half _Max(half a, half b) {
return ISNAN_HALF(a) ? a : (ISNAN_HALF(b) ? b : (a > b ? a : b));
}

template <>
__device__ __inline__ BFloat16 _Max(BFloat16 a, BFloat16 b) {
return ISNAN_BFLOAT16(a) ? a : (ISNAN_BFLOAT16(b) ? b : (a > b ? a : b));
}

#undef ISNAN_HALF
#undef ISNAN_BFLOAT16

template <typename T>
__device__ __inline__ T _Abs(T a) { return a > (T)0 ? a : -a; }

Expand Down
125 changes: 125 additions & 0 deletions onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1835,6 +1835,68 @@ TEST(MathOpTest, Min_12_MLFloat16_VectorMatrix) {
}
}

TEST(MathOpTest, Min_12_MLFloat16_Nan) {
OpTester test("Min", 12);
test.AddInput<MLFloat16>("data_0", {4, 1},
MakeMLFloat16({-1.0f, std::numeric_limits<float>::quiet_NaN(), 1.0f, 0.5f}));
test.AddInput<MLFloat16>("data_1", {4, 1},
MakeMLFloat16({0.5f, 1.0f, 0.25f, std::numeric_limits<float>::quiet_NaN()}));
test.AddOutput<MLFloat16>("min", {4, 1},
MakeMLFloat16({-1.0f,
std::numeric_limits<float>::quiet_NaN(),
0.25f,
std::numeric_limits<float>::quiet_NaN()}));
if (nullptr != DefaultCpuExecutionProvider()) {
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCpuExecutionProvider());
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
if (nullptr != DefaultCudaExecutionProvider()) {
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCudaExecutionProvider());
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
}

TEST(MathOpTest, Min_12_MLFloat16_Nan_with_scalar) {
OpTester test("Min", 12);
test.AddInput<MLFloat16>("data_0", {3, 1},
MakeMLFloat16({-1.0f, std::numeric_limits<float>::quiet_NaN(), 1.0f}));
test.AddInput<MLFloat16>("data_1", {1}, MakeMLFloat16({0.25f}));
test.AddOutput<MLFloat16>("min", {3, 1},
MakeMLFloat16({-1.0f, std::numeric_limits<float>::quiet_NaN(), 0.25f}));
if (nullptr != DefaultCpuExecutionProvider()) {
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCpuExecutionProvider());
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
if (nullptr != DefaultCudaExecutionProvider()) {
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCudaExecutionProvider());
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
}

TEST(MathOpTest, Min_12_MLFloat16_with_scalar_Nan) {
OpTester test("Min", 12);
test.AddInput<MLFloat16>("data_0", {3, 1},
MakeMLFloat16({-0.5f, 1.0f, 1.5f}));
test.AddInput<MLFloat16>("data_1", {1}, MakeMLFloat16({std::numeric_limits<float>::quiet_NaN()}));
test.AddOutput<MLFloat16>("min", {3, 1},
MakeMLFloat16({std::numeric_limits<float>::quiet_NaN(),
std::numeric_limits<float>::quiet_NaN(),
std::numeric_limits<float>::quiet_NaN()}));
if (nullptr != DefaultCpuExecutionProvider()) {
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCpuExecutionProvider());
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
if (nullptr != DefaultCudaExecutionProvider()) {
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCudaExecutionProvider());
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
}
TEST(MathOpTest, Max_6) {
OpTester test("Max", 6);
std::vector<int64_t> dims{3, 3};
Expand Down Expand Up @@ -2235,6 +2297,69 @@ TEST(MathOpTest, Max_12_MLFloat16_VectorMatrix) {
}
}

TEST(MathOpTest, Max_12_MLFloat16_Nan) {
OpTester test("Max", 12);
test.AddInput<MLFloat16>("data_0", {4, 1},
MakeMLFloat16({-1.0f, std::numeric_limits<float>::quiet_NaN(), 1.0f, 0.5f}));
test.AddInput<MLFloat16>("data_1", {4, 1},
MakeMLFloat16({0.5f, 1.0f, 0.25f, std::numeric_limits<float>::quiet_NaN()}));
test.AddOutput<MLFloat16>("max", {4, 1},
MakeMLFloat16({0.5f,
std::numeric_limits<float>::quiet_NaN(),
1.0f,
std::numeric_limits<float>::quiet_NaN()}));
if (nullptr != DefaultCpuExecutionProvider()) {
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCpuExecutionProvider());
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
if (nullptr != DefaultCudaExecutionProvider()) {
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCudaExecutionProvider());
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
}

TEST(MathOpTest, Max_12_MLFloat16_Nan_with_scalar) {
OpTester test("Max", 12);
test.AddInput<MLFloat16>("data_0", {3, 1},
MakeMLFloat16({-1.0f, std::numeric_limits<float>::quiet_NaN(), 1.0f}));
test.AddInput<MLFloat16>("data_1", {1}, MakeMLFloat16({0.25f}));
test.AddOutput<MLFloat16>("max", {3, 1},
MakeMLFloat16({0.25f, std::numeric_limits<float>::quiet_NaN(), 1.0f}));
if (nullptr != DefaultCpuExecutionProvider()) {
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCpuExecutionProvider());
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
if (nullptr != DefaultCudaExecutionProvider()) {
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCudaExecutionProvider());
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
}

TEST(MathOpTest, Max_12_MLFloat16_with_scalar_Nan) {
OpTester test("Max", 12);
test.AddInput<MLFloat16>("data_0", {3, 1},
MakeMLFloat16({-0.5f, 1.0f, 1.5f}));
test.AddInput<MLFloat16>("data_1", {1}, MakeMLFloat16({std::numeric_limits<float>::quiet_NaN()}));
test.AddOutput<MLFloat16>("max", {3, 1},
MakeMLFloat16({std::numeric_limits<float>::quiet_NaN(),
std::numeric_limits<float>::quiet_NaN(),
std::numeric_limits<float>::quiet_NaN()}));
if (nullptr != DefaultCpuExecutionProvider()) {
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCpuExecutionProvider());
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
if (nullptr != DefaultCudaExecutionProvider()) {
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCudaExecutionProvider());
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
}

TEST(MathOpTest, Not) {
OpTester test("Not");
std::vector<int64_t> dims{2};
Expand Down
Loading