-
Notifications
You must be signed in to change notification settings - Fork 86
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[BF16] GPU Implementation #3519
Comments
// So I started work on a
// Returns an unsigned value with the low N bits set (e.g. all_ones<8>() == 0xff).
template<unsigned int N>
constexpr unsigned int all_ones() noexcept
{
    // Use unsigned literals: (1 << N) on a signed int is undefined behavior
    // for N >= 31, and the result is declared unsigned anyway.
    return (1u << N) - 1u;
}
// Field-level view of an IEEE-754 binary32 value,
// least-significant field first (mantissa, exponent, sign).
struct float32_parts
{
    unsigned int mantissa : 23;
    unsigned int exponent : 8;
    unsigned int sign : 1;

    // Width of the mantissa field in bits.
    static constexpr unsigned int mantissa_width() { return 23; }

    // An all-ones exponent field marks Inf/NaN.
    static constexpr unsigned int max_exponent() { return all_ones<8>(); }

    // Exponent bias of binary32 (127).
    static constexpr int exponent_bias() { return all_ones<7>(); }

    // Reassemble the fields into the float they encode.
    constexpr float to_float() const noexcept { return bit_cast<float>(*this); }
};
// Decompose a float into its sign/exponent/mantissa fields.
constexpr float32_parts get_parts(float f)
{
    const auto parts = bit_cast<float32_parts>(f);
    return parts;
}
// A generic IEEE-754-style floating point type with a configurable number of
// mantissa and exponent bits, stored in (and converted through) float.
// Conversions from float truncate extra mantissa bits (round toward zero).
template<unsigned int MantissaSize, unsigned int ExponentSize, unsigned int Flags = 0>
struct generic_float
{
    // IEEE-style packing: mantissa in the low bits, then exponent, then sign.
    unsigned int mantissa : MantissaSize;
    unsigned int exponent : ExponentSize;
    unsigned int sign : 1;

    // Exponent bias of this format (e.g. 15 for fp16, 127 for bf16/fp32).
    static constexpr int exponent_bias() { return all_ones<ExponentSize - 1>(); }

    // Difference between float's bias and this format's bias;
    // zero when ExponentSize == 8 (bf16, fp32).
    static constexpr int bias_diff()
    {
        return float32_parts::exponent_bias() - exponent_bias();
    }

    // Mantissa bits dropped when narrowing from float (gained when widening).
    static constexpr unsigned int mantissa_shift()
    {
        return float32_parts::mantissa_width() - MantissaSize;
    }

    // Construct from a float, truncating toward zero.
    explicit constexpr generic_float(float f = 0.0) noexcept { from_float(get_parts(f)); }

    // Widen to float. Exact for all representable values; Inf/NaN map to
    // float Inf/NaN with the payload shifted up.
    constexpr float to_float() const noexcept
    {
        float32_parts f{};
        f.sign     = sign;
        f.mantissa = mantissa << mantissa_shift();
        if(exponent == all_ones<ExponentSize>())
        {
            // Inf or NaN: saturate the float exponent, keep the widened payload.
            f.exponent = float32_parts::max_exponent();
        }
        else if(exponent == 0)
        {
            if(mantissa == 0)
            {
                // Signed zero. (The old code added the bias difference here,
                // turning 0 into a small normal for formats like fp16.)
                f.exponent = 0;
            }
            else
            {
                // Subnormal: value = mantissa * 2^(1 - bias - MantissaSize).
                // Build it arithmetically; halving is exact, so no precision
                // is lost as long as the value fits in float (true for the
                // formats of interest: bf16, fp16, fp32).
                float result = static_cast<float>(mantissa);
                for(int i = 0; i < exponent_bias() + static_cast<int>(MantissaSize) - 1; ++i)
                    result *= 0.5f;
                return sign ? -result : result;
            }
        }
        else
        {
            // Normal number: rebase the exponent; the mantissa widens exactly.
            f.exponent = exponent + bias_diff();
        }
        return f.to_float();
    }

    // Narrow a float into this format. Extra mantissa bits are truncated
    // (round toward zero); overflow saturates to infinity; values below the
    // normal range become subnormals or (signed) zero; NaN stays NaN.
    constexpr void from_float(float32_parts f) noexcept
    {
        sign     = f.sign;
        mantissa = f.mantissa >> mantissa_shift();
        if(f.exponent == 0)
        {
            // Float zero or subnormal. A float subnormal is only representable
            // when the biases match (bf16/fp32); otherwise it underflows to zero.
            exponent = 0;
            if(bias_diff() != 0)
                mantissa = 0;
        }
        else if(f.exponent == float32_parts::max_exponent())
        {
            // Inf or NaN.
            exponent = all_ones<ExponentSize>();
            // Keep NaNs as NaNs: if truncation dropped the whole payload,
            // set the quiet bit so the result does not turn into infinity.
            if(f.mantissa != 0 and mantissa == 0)
                mantissa = 1u << (MantissaSize - 1);
        }
        else
        {
            // Rebase the exponent. Compare as int: the old unsigned compare
            // sent negative (underflowing) exponents to the infinity branch.
            const int e = static_cast<int>(f.exponent) - bias_diff();
            if(e >= static_cast<int>(all_ones<ExponentSize>()))
            {
                // Exponent overflow: saturate to infinity.
                exponent = all_ones<ExponentSize>();
                mantissa = 0;
            }
            else if(e <= 0)
            {
                // Below the normal range: encode as a subnormal by shifting the
                // implicit-bit-extended mantissa down, or flush to zero.
                exponent                = 0;
                const int s             = 1 - e;
                const unsigned int full = (1u << MantissaSize) | mantissa;
                mantissa = s <= static_cast<int>(MantissaSize) ? (full >> s) : 0u;
            }
            else
            {
                exponent = e;
            }
        }
    }

    // True for normal values (not zero, subnormal, Inf, or NaN).
    constexpr bool is_normal() const noexcept
    {
        return exponent != all_ones<ExponentSize>() and exponent != 0;
    }
    constexpr bool is_inf() const noexcept
    {
        return exponent == all_ones<ExponentSize>() and mantissa == 0;
    }
    constexpr bool is_nan() const noexcept
    {
        return exponent == all_ones<ExponentSize>() and mantissa != 0;
    }
    constexpr bool is_finite() const noexcept { return exponent != all_ones<ExponentSize>(); }
    // True for both +0 and -0.
    constexpr bool is_zero() const noexcept { return exponent == 0 and mantissa == 0; }

    // Implicit widening to float.
    constexpr operator float() const noexcept { return this->to_float(); }

    static constexpr generic_float infinity()
    {
        generic_float x{};
        x.exponent = all_ones<ExponentSize>();
        return x;
    }
    // Signaling NaN: quiet (top mantissa) bit clear, a lower payload bit set.
    static constexpr generic_float snan()
    {
        generic_float x{};
        x.exponent = all_ones<ExponentSize>();
        x.mantissa = 1u << (MantissaSize - 2);
        return x;
    }
    // Quiet NaN: quiet (top mantissa) bit set.
    static constexpr generic_float qnan()
    {
        generic_float x{};
        x.exponent = all_ones<ExponentSize>();
        x.mantissa = 1u << (MantissaSize - 1);
        return x;
    }
    // Smallest positive normal value.
    static constexpr generic_float min()
    {
        generic_float x{};
        x.exponent = 1;
        x.mantissa = 0;
        return x;
    }
    // Smallest positive subnormal value.
    static constexpr generic_float denorm_min()
    {
        generic_float x{};
        x.exponent = 0;
        x.mantissa = 1;
        x.sign     = 0;
        return x;
    }
    // Most negative finite value.
    static constexpr generic_float lowest()
    {
        generic_float x{};
        x.exponent = all_ones<ExponentSize>() - 1;
        x.mantissa = all_ones<MantissaSize>();
        x.sign     = 1;
        return x;
    }
    // Largest finite value.
    static constexpr generic_float max()
    {
        generic_float x{};
        x.exponent = all_ones<ExponentSize>() - 1;
        x.mantissa = all_ones<MantissaSize>();
        x.sign     = 0;
        return x;
    }
    // Difference between 1 and the next representable value.
    static constexpr generic_float epsilon()
    {
        generic_float x{1.0};
        x.mantissa++;
        return generic_float{x.to_float() - 1.0f};
    }
// Compound assignment is performed in float and truncated back.
// NOLINTNEXTLINE
#define MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(op) \
    constexpr generic_float& operator op(const generic_float& rhs) \
    { \
        float self = *this; \
        float frhs = rhs; \
        self op frhs; \
        *this = generic_float(self); \
        return *this; \
    }
    MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(*=)
    MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(-=)
    MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(+=)
    MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(/=)
// Arithmetic is performed in float and truncated back.
// NOLINTNEXTLINE
#define MIGRAPHX_GENERIC_FLOAT_BINARY_OP(op) \
    friend constexpr generic_float operator op(const generic_float& x, const generic_float& y) \
    { \
        return generic_float(float(x) op float(y)); \
    }
    MIGRAPHX_GENERIC_FLOAT_BINARY_OP(*)
    MIGRAPHX_GENERIC_FLOAT_BINARY_OP(-)
    MIGRAPHX_GENERIC_FLOAT_BINARY_OP(+)
    MIGRAPHX_GENERIC_FLOAT_BINARY_OP(/)
// Comparisons return bool (not generic_float) and follow IEEE semantics:
// any comparison involving NaN is false.
// NOLINTNEXTLINE
#define MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(op) \
    friend constexpr bool operator op(const generic_float& x, const generic_float& y) \
    { \
        return float(x) op float(y); \
    }
    MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(<)
    MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(<=)
    MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(>)
    MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(>=)
    // IEEE equality: NaN != NaN (even to itself), +0 == -0, inf == inf.
    // Bit-fields cannot bind to std::tie's references, so compare fields directly.
    friend constexpr bool operator==(const generic_float& x, const generic_float& y)
    {
        if(x.is_nan() or y.is_nan())
            return false;
        if(x.is_zero() and y.is_zero())
            return true;
        return x.mantissa == y.mantissa and x.exponent == y.exponent and x.sign == y.sign;
    }
    friend constexpr bool operator!=(const generic_float& x, const generic_float& y)
    {
        return not(x == y);
    }
};
// I may be biased, but I do find this much more readable than the float8 code.
// I have done some initial testing with an fp32 type:
using fp32 = generic_float<23, 8>;
// Verify a float and a generic_float are bit-identical in both directions:
// the raw float vs the wrapper, and after round-tripping through to_float()
// and the fp32 constructor.
#define CHECK_FLOAT(x, y) \
CHECK(bit_equal(x, y)); \
CHECK(bit_equal(x, y.to_float())); \
CHECK(bit_equal(fp32{x}, y)); \
CHECK(bit_equal(fp32{x}.to_float(), y.to_float()))
// Round-trip every special value of std::numeric_limits<float> through the
// generic fp32 (23-bit mantissa, 8-bit exponent) wrapper; with matching field
// widths the conversion must be a bitwise identity.
TEST_CASE(fp32_values)
{
CHECK_FLOAT(1.0f, fp32{1.0f});
CHECK_FLOAT(-1.0f, fp32{-1.0f});
CHECK_FLOAT(std::numeric_limits<float>::min(), fp32::min());
CHECK_FLOAT(std::numeric_limits<float>::lowest(), fp32::lowest());
CHECK_FLOAT(std::numeric_limits<float>::max(), fp32::max());
CHECK_FLOAT(std::numeric_limits<float>::epsilon(), fp32::epsilon());
CHECK_FLOAT(std::numeric_limits<float>::infinity(), fp32::infinity());
CHECK_FLOAT(std::numeric_limits<float>::quiet_NaN(), fp32::qnan());
CHECK_FLOAT(std::numeric_limits<float>::signaling_NaN(), fp32::snan());
CHECK_FLOAT(std::numeric_limits<float>::denorm_min(), fp32::denorm_min());
} Although this doesnt test the truncation code. Specializations of It would be good to start by replacing our current half type with the generic_float since thats already implemented with initial tests already. We need to create a test suite similar to the fp8 test suite, but we probably cant create 64k arrays so we should probably create samples of say 1k values to test fp16 with. Then we can easily add BF16:
I split it like this to allow smaller PRs, which should make it easier to review and merge. So it would be 3 PRs for the above and one more PR for the half type — 4 PRs in total.
Idea:
Cast FP32/FP16 to BF16.
Casting will be different based on type:
May involve very slight loss in precision for both.
Workflow:
Follow similar workflow as FP8.
Add a cast_to_bf16
function to cast to BF16, by modifying the relevant conversion code.
The text was updated successfully, but these errors were encountered: