Skip to content

Commit

Permalink
near-zero negative values must convert to 0 not NAN (#18473)
Browse files Browse the repository at this point in the history
for the Float8 types with unsigned zero, we must clear the sign bit when
rounding to zero;
otherwise we end up with 0x80 which is the encoding for NAN.

### Description
Handle all zero and near-zero values the same way, rounding to positive
zero.
Note that I removed one "if" level but did not re-indent the code in
this PR, to make it
easier to see what the actual changes are.

### Motivation and Context
For the two new 8-bit floating point types Float8E4M3FNUZ and
Float8E5M2FNUZ,
converting from a near-zero negative value would end up with the sign
bit set only;
this bit pattern is not negative zero but instead means NAN.
  • Loading branch information
arnej27959 authored Sep 6, 2024
1 parent 605a84f commit 493159b
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 73 deletions.
147 changes: 74 additions & 73 deletions include/onnxruntime/core/framework/float8.h
Original file line number Diff line number Diff line change
Expand Up @@ -219,48 +219,49 @@ struct Float8E4M3FNUZ {
} else {
uint8_t e = static_cast<uint8_t>((b & 0x7F800000) >> 23); // exponent
uint32_t m = static_cast<uint32_t>(b & 0x007FFFFF); // mantissa
if (e != 0) {
if (e < 116) {
} else if (e < 120) {
// denormalized number
auto d = 119 - e;
if (d < 3) {
val |= 1 << (2 - d);
val |= m >> (21 + d);
} else if (m > 0) {
val |= 1;
}
auto mask = 1 << (20 + d);
if ((m & mask) && ((val & 1) || ((m & (mask - 1)) > 0) || ((m & mask) && (m & (mask << 1)) && ((m & (mask - 1)) == 0)))) {

if (e < 116) {
// all near-zero numbers round to positive zero:
val = 0;
} else if (e < 120) {
// denormalized number
auto d = 119 - e;
if (d < 3) {
val |= 1 << (2 - d);
val |= m >> (21 + d);
} else if (m > 0) {
val |= 1;
} else {
// round to positive zero:
val = 0;
}
auto mask = 1 << (20 + d);
if ((m & mask) && ((val & 1) || ((m & (mask - 1)) > 0) || ((m & mask) && (m & (mask << 1)) && ((m & (mask - 1)) == 0)))) {
// rounding
val += 1;
}
} else if (e < 135) {
// normalized number
auto ex = e - 119;
if (ex == 0) {
val |= 0x4;
val |= m >> 21;
} else {
val |= ex << 3;
val |= m >> 20;
}
if ((m & 0x80000) && ((m & 0x100000) || (m & 0x7FFFF))) {
if ((val & 0x7F) < 0x7F) {
// rounding
val += 1;
} else if (!saturate) {
val = 0x80;
}
} else if (e < 135) {
// normalized number
auto ex = e - 119;
if (ex == 0) {
val |= 0x4;
val |= m >> 21;
} else {
val |= ex << 3;
val |= m >> 20;
}
if ((m & 0x80000) && ((m & 0x100000) || (m & 0x7FFFF))) {
if ((val & 0x7F) < 0x7F) {
// rounding
val += 1;
} else if (!saturate) {
val = 0x80;
}
}
} else if (saturate) {
val |= 0x7F;
} else {
val = 0x80;
}
} else if (m == 0) {
// -0
val = 0;
} else if (saturate) {
val |= 0x7F;
} else {
val = 0x80;
}
}
}
Expand Down Expand Up @@ -531,45 +532,45 @@ struct Float8E5M2FNUZ {
uint32_t e = (b & 0x7F800000) >> 23; // exponent
uint32_t m = b & 0x007FFFFF; // mantissa

if (e != 0) {
if (e < 109) {
} else if (e < 112) {
// denormalized number
auto d = 111 - e;
if (d < 2) {
val |= 1 << (1 - d);
val |= m >> (22 + d);
} else if (m > 0) {
val |= 1;
}
auto mask = 1 << (21 + d);
if ((m & mask) && ((val & 1) || ((m & (mask - 1)) > 0) || ((m & mask) && (m & (mask << 1)) && ((m & (mask - 1)) == 0)))) {
if (e < 109) {
// all near-zero numbers round to positive zero:
val = 0;
} else if (e < 112) {
// denormalized number
auto d = 111 - e;
if (d < 2) {
val |= 1 << (1 - d);
val |= m >> (22 + d);
} else if (m > 0) {
val |= 1;
} else {
// round to positive zero:
val = 0;
}
auto mask = 1 << (21 + d);
if ((m & mask) && ((val & 1) || ((m & (mask - 1)) > 0) || ((m & mask) && (m & (mask << 1)) && ((m & (mask - 1)) == 0)))) {
// rounding
val += 1;
}
} else if (e < 143) {
// normalized number
auto ex = e - 111;
val |= ex << 2;
val |= m >> 21;
if ((m & 0x100000) && ((m & 0xFFFFF) || (m & 0x200000))) {
if ((val & 0x7F) < 0x7F) {
// rounding
val += 1;
} else if (!saturate) {
val = 0x80;
}
} else if (e < 143) {
// normalized number
auto ex = e - 111;
val |= ex << 2;
val |= m >> 21;
if ((m & 0x100000) && ((m & 0xFFFFF) || (m & 0x200000))) {
if ((val & 0x7F) < 0x7F) {
// rounding
val += 1;
} else if (!saturate) {
val = 0x80;
}
}
} else if ((e == 255) && (m == 0)) {
val = 0x80;
} else if (saturate) {
val |= 0x7F;
} else {
val = 0x80;
}
} else if (m == 0) {
// -0
val = 0;
} else if ((e == 255) && (m == 0)) {
val = 0x80;
} else if (saturate) {
val |= 0x7F;
} else {
val = 0x80;
}
}
}
Expand Down
15 changes: 15 additions & 0 deletions onnxruntime/test/framework/float_8_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ TEST(Float8_Tests, NanE4M3FN) {
// 0x7FC00000 is the value used by numpy.
EXPECT_EQ(onnxruntime::Float8E4M3FN((float_bits{0x7FC00000}).val).val, static_cast<uint8_t>(0x7F));
EXPECT_EQ(onnxruntime::Float8E4M3FN((float_bits{0xFFC00000}).val).val, static_cast<uint8_t>(0xFF));
// small negative values should round to negative zero
EXPECT_EQ(onnxruntime::Float8E4M3FN(-0.00000001f).ToFloat(), -0.0f);
EXPECT_EQ(onnxruntime::Float8E4M3FN(-0.00000001f).val, static_cast<uint8_t>(0x80));
}

TEST(Float8_Tests, NanE4M3FNUZ) {
Expand All @@ -64,6 +67,11 @@ TEST(Float8_Tests, NanE4M3FNUZ) {
// 0x7FC00000 is the value used by numpy.
EXPECT_EQ(onnxruntime::Float8E4M3FNUZ((float_bits{0x7FC00000}).val).val, static_cast<uint8_t>(0x80));
EXPECT_EQ(onnxruntime::Float8E4M3FNUZ((float_bits{0xFFC00000}).val).val, static_cast<uint8_t>(0x80));
// small negative values should round to zero
EXPECT_EQ(onnxruntime::Float8E4M3FNUZ(-0.00000001f).ToFloat(), 0.0f);
EXPECT_EQ(onnxruntime::Float8E4M3FNUZ(-0.00000001f).val, static_cast<uint8_t>(0x00));
EXPECT_EQ(onnxruntime::Float8E4M3FNUZ(-0x1.0p-11f).ToFloat(), 0.0f);
EXPECT_EQ(onnxruntime::Float8E4M3FNUZ(-0x1.0p-11f).val, static_cast<uint8_t>(0x00));
}

TEST(Float8_Tests, NanE5M2) {
Expand All @@ -76,6 +84,9 @@ TEST(Float8_Tests, NanE5M2) {
// 0x7FC00000 is the value used by numpy.
EXPECT_EQ(onnxruntime::Float8E5M2((float_bits{0x7FC00000}).val).val, static_cast<uint8_t>(0x7F));
EXPECT_EQ(onnxruntime::Float8E5M2((float_bits{0xFFC00000}).val).val, static_cast<uint8_t>(0xFF));
// small negative values should round to negative zero
EXPECT_EQ(onnxruntime::Float8E5M2(-0.00000001f).ToFloat(), 0.0f);
EXPECT_EQ(onnxruntime::Float8E5M2(-0.00000001f).val, static_cast<uint8_t>(0x80));
}

TEST(Float8_Tests, NanE5M2FNUZ) {
Expand All @@ -88,6 +99,10 @@ TEST(Float8_Tests, NanE5M2FNUZ) {
// 0x7FC00000 is the value used by numpy.
EXPECT_EQ(onnxruntime::Float8E5M2FNUZ((float_bits{0x7FC00000}).val).val, static_cast<uint8_t>(0x80));
EXPECT_EQ(onnxruntime::Float8E5M2FNUZ((float_bits{0xFFC00000}).val).val, static_cast<uint8_t>(0x80));
// small negative values should round to zero
EXPECT_EQ(onnxruntime::Float8E5M2FNUZ(-0.00000001f).ToFloat(), 0.0f);
EXPECT_EQ(onnxruntime::Float8E5M2FNUZ(-0.00000001f).val, static_cast<uint8_t>(0x00));
EXPECT_EQ(onnxruntime::Float8E5M2FNUZ(-0x1.0p-18f).ToFloat(), 0.0f);
}

} // namespace test
Expand Down

0 comments on commit 493159b

Please sign in to comment.