diff --git a/include/onnxruntime/core/framework/float8.h b/include/onnxruntime/core/framework/float8.h index dd607cbbc6952..5e39849186756 100644 --- a/include/onnxruntime/core/framework/float8.h +++ b/include/onnxruntime/core/framework/float8.h @@ -219,48 +219,49 @@ struct Float8E4M3FNUZ { } else { uint8_t e = static_cast((b & 0x7F800000) >> 23); // exponent uint32_t m = static_cast(b & 0x007FFFFF); // mantissa - if (e != 0) { - if (e < 116) { - } else if (e < 120) { - // denormalized number - auto d = 119 - e; - if (d < 3) { - val |= 1 << (2 - d); - val |= m >> (21 + d); - } else if (m > 0) { - val |= 1; - } - auto mask = 1 << (20 + d); - if ((m & mask) && ((val & 1) || ((m & (mask - 1)) > 0) || ((m & mask) && (m & (mask << 1)) && ((m & (mask - 1)) == 0)))) { + + if (e < 116) { + // all near-zero numbers round to positive zero: + val = 0; + } else if (e < 120) { + // denormalized number + auto d = 119 - e; + if (d < 3) { + val |= 1 << (2 - d); + val |= m >> (21 + d); + } else if (m > 0) { + val |= 1; + } else { + // round to positive zero: + val = 0; + } + auto mask = 1 << (20 + d); + if ((m & mask) && ((val & 1) || ((m & (mask - 1)) > 0) || ((m & mask) && (m & (mask << 1)) && ((m & (mask - 1)) == 0)))) { + // rounding + val += 1; + } + } else if (e < 135) { + // normalized number + auto ex = e - 119; + if (ex == 0) { + val |= 0x4; + val |= m >> 21; + } else { + val |= ex << 3; + val |= m >> 20; + } + if ((m & 0x80000) && ((m & 0x100000) || (m & 0x7FFFF))) { + if ((val & 0x7F) < 0x7F) { // rounding val += 1; + } else if (!saturate) { + val = 0x80; } - } else if (e < 135) { - // normalized number - auto ex = e - 119; - if (ex == 0) { - val |= 0x4; - val |= m >> 21; - } else { - val |= ex << 3; - val |= m >> 20; - } - if ((m & 0x80000) && ((m & 0x100000) || (m & 0x7FFFF))) { - if ((val & 0x7F) < 0x7F) { - // rounding - val += 1; - } else if (!saturate) { - val = 0x80; - } - } - } else if (saturate) { - val |= 0x7F; - } else { - val = 0x80; } - } else if (m == 0) { - // -0 - val = 0; + } else if (saturate) { + val |= 0x7F; + } else { + val = 0x80; } } } @@ -531,45 +532,45 @@ struct Float8E5M2FNUZ { uint32_t e = (b & 0x7F800000) >> 23; // exponent uint32_t m = b & 0x007FFFFF; // mantissa - if (e != 0) { - if (e < 109) { - } else if (e < 112) { - // denormalized number - auto d = 111 - e; - if (d < 2) { - val |= 1 << (1 - d); - val |= m >> (22 + d); - } else if (m > 0) { - val |= 1; - } - auto mask = 1 << (21 + d); - if ((m & mask) && ((val & 1) || ((m & (mask - 1)) > 0) || ((m & mask) && (m & (mask << 1)) && ((m & (mask - 1)) == 0)))) { + if (e < 109) { + // all near-zero numbers round to positive zero: + val = 0; + } else if (e < 112) { + // denormalized number + auto d = 111 - e; + if (d < 2) { + val |= 1 << (1 - d); + val |= m >> (22 + d); + } else if (m > 0) { + val |= 1; + } else { + // round to positive zero: + val = 0; + } + auto mask = 1 << (21 + d); + if ((m & mask) && ((val & 1) || ((m & (mask - 1)) > 0) || ((m & mask) && (m & (mask << 1)) && ((m & (mask - 1)) == 0)))) { + // rounding + val += 1; + } + } else if (e < 143) { + // normalized number + auto ex = e - 111; + val |= ex << 2; + val |= m >> 21; + if ((m & 0x100000) && ((m & 0xFFFFF) || (m & 0x200000))) { + if ((val & 0x7F) < 0x7F) { // rounding val += 1; + } else if (!saturate) { + val = 0x80; } - } else if (e < 143) { - // normalized number - auto ex = e - 111; - val |= ex << 2; - val |= m >> 21; - if ((m & 0x100000) && ((m & 0xFFFFF) || (m & 0x200000))) { - if ((val & 0x7F) < 0x7F) { - // rounding - val += 1; - } else if (!saturate) { - val = 0x80; - } - } - } else if ((e == 255) && (m == 0)) { - val = 0x80; - } else if (saturate) { - val |= 0x7F; - } else { - val = 0x80; } - } else if (m == 0) { - // -0 - val = 0; + } else if ((e == 255) && (m == 0)) { + val = 0x80; + } else if (saturate) { + val |= 0x7F; + } else { + val = 0x80; } } } diff --git a/onnxruntime/test/framework/float_8_test.cc b/onnxruntime/test/framework/float_8_test.cc index 948e0e05a9141..62a82e50d4c8a 100644 --- a/onnxruntime/test/framework/float_8_test.cc +++ b/onnxruntime/test/framework/float_8_test.cc @@ -52,6 +52,9 @@ TEST(Float8_Tests, NanE4M3FN) { // 0x7FC00000 is the value used by numpy. EXPECT_EQ(onnxruntime::Float8E4M3FN((float_bits{0x7FC00000}).val).val, static_cast(0x7F)); EXPECT_EQ(onnxruntime::Float8E4M3FN((float_bits{0xFFC00000}).val).val, static_cast(0xFF)); + // small negative values should round to negative zero + EXPECT_EQ(onnxruntime::Float8E4M3FN(-0.00000001f).ToFloat(), -0.0f); + EXPECT_EQ(onnxruntime::Float8E4M3FN(-0.00000001f).val, static_cast(0x80)); } TEST(Float8_Tests, NanE4M3FNUZ) { @@ -64,6 +67,11 @@ TEST(Float8_Tests, NanE4M3FNUZ) { // 0x7FC00000 is the value used by numpy. EXPECT_EQ(onnxruntime::Float8E4M3FNUZ((float_bits{0x7FC00000}).val).val, static_cast(0x80)); EXPECT_EQ(onnxruntime::Float8E4M3FNUZ((float_bits{0xFFC00000}).val).val, static_cast(0x80)); + // small negative values should round to zero + EXPECT_EQ(onnxruntime::Float8E4M3FNUZ(-0.00000001f).ToFloat(), 0.0f); + EXPECT_EQ(onnxruntime::Float8E4M3FNUZ(-0.00000001f).val, static_cast(0x00)); + EXPECT_EQ(onnxruntime::Float8E4M3FNUZ(-0x1.0p-11f).ToFloat(), 0.0f); + EXPECT_EQ(onnxruntime::Float8E4M3FNUZ(-0x1.0p-11f).val, static_cast(0x00)); } TEST(Float8_Tests, NanE5M2) { @@ -76,6 +84,9 @@ TEST(Float8_Tests, NanE5M2) { // 0x7FC00000 is the value used by numpy. EXPECT_EQ(onnxruntime::Float8E5M2((float_bits{0x7FC00000}).val).val, static_cast(0x7F)); EXPECT_EQ(onnxruntime::Float8E5M2((float_bits{0xFFC00000}).val).val, static_cast(0xFF)); + // small negative values should round to negative zero + EXPECT_EQ(onnxruntime::Float8E5M2(-0.00000001f).ToFloat(), 0.0f); + EXPECT_EQ(onnxruntime::Float8E5M2(-0.00000001f).val, static_cast(0x80)); } TEST(Float8_Tests, NanE5M2FNUZ) { @@ -88,6 +99,10 @@ TEST(Float8_Tests, NanE5M2FNUZ) { // 0x7FC00000 is the value used by numpy. EXPECT_EQ(onnxruntime::Float8E5M2FNUZ((float_bits{0x7FC00000}).val).val, static_cast(0x80)); EXPECT_EQ(onnxruntime::Float8E5M2FNUZ((float_bits{0xFFC00000}).val).val, static_cast(0x80)); + // small negative values should round to zero + EXPECT_EQ(onnxruntime::Float8E5M2FNUZ(-0.00000001f).ToFloat(), 0.0f); + EXPECT_EQ(onnxruntime::Float8E5M2FNUZ(-0.00000001f).val, static_cast(0x00)); + EXPECT_EQ(onnxruntime::Float8E5M2FNUZ(-0x1.0p-18f).ToFloat(), 0.0f); } } // namespace test