[BF16] GPU Implementation #3519

Open
richagadgil opened this issue Oct 9, 2024 · 1 comment

richagadgil (Contributor) commented Oct 9, 2024

Idea:

Cast FP32/FP16 to BF16.

The casting differs based on the source type (see the sketch below):

  • FP32 to BF16: truncate the last 16 bits of the mantissa; the exponent stays the same (both formats use an 8-bit exponent).
  • FP16 to BF16: a more involved process -- re-bias the exponent first (FP16 uses a 5-bit exponent with bias 15, BF16 an 8-bit exponent with bias 127) and then truncate the mantissa from 10 bits to 7.

Both conversions may involve a very slight loss in precision, since BF16 keeps only 7 mantissa bits.
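
A minimal sketch of the two conversions, for illustration only (not MIGraphX code; the function names and the use of the _Float16 compiler extension are assumptions):

#include <cstdint>
#include <cstring>

// FP32 -> BF16 by truncating the low 16 mantissa bits. A production version
// would also handle rounding (round-to-nearest-even) and NaN payloads.
std::uint16_t fp32_to_bf16_truncate(float f)
{
    std::uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));          // bit-level view of the float
    return static_cast<std::uint16_t>(bits >> 16); // keep sign, exponent, top 7 mantissa bits
}

// FP16 -> BF16 by widening to FP32 first (this re-biases the 5-bit exponent
// to BF16's 8-bit one), then truncating as above.
std::uint16_t fp16_to_bf16_truncate(_Float16 h)
{
    return fp32_to_bf16_truncate(static_cast<float>(h));
}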

Workflow:

Follow a workflow similar to FP8.

  • Add the BF16 type to src/targets/gpu/target.cpp
  • Create a cast_to_bf16 function to cast to BF16 by modifying:
    • src/include/migraphx/bfloat16_impl.hpp
    • src/include/migraphx/bfloat16.hpp
  • Determine which operations can support BF16 and modify them accordingly (vectorization, math, etc.):
    • src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
    • src/targets/gpu/kernels/include/migraphx/kernels/*.hpp
  • Determine whether BF16 can be supported by external libraries (such as DNNL), then update:
    • src/targets/cpu/dnnl.cpp
    • src/targets/cpu/lowering.cpp
  • Write test cases for the supported operations
pfultz2 (Collaborator) commented Oct 10, 2024

So I started work on a generic_float class so we can specify any float type:

#include <bit>    // for std::bit_cast (C++20), or swap in an equivalent constexpr helper
#include <tuple>  // for std::tie, used by operator== below

using std::bit_cast;

// Returns an unsigned value with the low N bits set.
template<unsigned int N>
constexpr unsigned int all_ones() noexcept
{
    return (1 << N) - 1;
}

// Bit-level view of an IEEE 754 binary32 float.
struct float32_parts
{
    unsigned int mantissa : 23;
    unsigned int exponent : 8;
    unsigned int sign : 1;

    static constexpr unsigned int mantissa_width()
    {
        return 23;
    }

    static constexpr unsigned int max_exponent()
    {
        return all_ones<8>();
    }

    static constexpr int exponent_bias()
    {
        return all_ones<7>();
    }

    constexpr float to_float() const noexcept
    {
        return bit_cast<float>(*this);
    }
};

constexpr float32_parts get_parts(float f)
{
    return bit_cast<float32_parts>(f);
}

// IEEE-style float with MantissaSize mantissa bits, ExponentSize exponent bits,
// and one sign bit, stored in bit-fields. Flags is reserved for the fp8 variants.
template<unsigned int MantissaSize, unsigned int ExponentSize, unsigned int Flags = 0>
struct generic_float
{
    unsigned int mantissa : MantissaSize;
    unsigned int exponent : ExponentSize;
    unsigned int sign : 1;

    static constexpr int exponent_bias()
    {
        return all_ones<ExponentSize - 1>();
    }

    constexpr explicit generic_float(float f = 0.0) noexcept
    {
        from_float(get_parts(f));
    }

    constexpr float to_float() const noexcept
    {
        float32_parts f{};
        f.sign = sign;
        f.mantissa = mantissa << (float32_parts::mantissa_width() - MantissaSize);
        if(exponent == all_ones<ExponentSize>())
        {
            f.exponent = float32_parts::max_exponent();
        }
        else
        {
            constexpr const auto diff = float32_parts::exponent_bias() - exponent_bias();
            f.exponent = exponent + diff;
        }
        return f.to_float();
    }

    constexpr void from_float(float32_parts f) noexcept
    {
        sign  = f.sign;
        mantissa = f.mantissa >> (float32_parts::mantissa_width() - MantissaSize);

        if(f.exponent == 0)
        {
            exponent = 0;
        }
        else if(f.exponent == float32_parts::max_exponent())
        {
            exponent = all_ones<ExponentSize>();
        }
        else
        {
            constexpr const int diff = float32_parts::exponent_bias() - exponent_bias();
            auto e = int(f.exponent) - diff;
            // overflow for the narrower exponent: clamp to infinity
            // (cast to int avoids an unintended unsigned comparison when e is negative)
            if(e >= static_cast<int>(all_ones<ExponentSize>()))
            {
                exponent = all_ones<ExponentSize>();
                mantissa = 0;
            }
            // underflow: flush to zero (denormals are not handled here)
            else if(e < 0)
            {
                exponent = 0;
                mantissa = 0;
            }
            else
            {
                exponent = f.exponent - diff;
            }
        }
    }

    constexpr bool is_normal() const noexcept
    {
        return exponent != all_ones<ExponentSize>() and exponent != 0;
    }

    constexpr bool is_inf() const noexcept
    {
        return exponent == all_ones<ExponentSize>() and mantissa == 0;
    }

    constexpr bool is_nan() const noexcept
    {
        return exponent == all_ones<ExponentSize>() and mantissa != 0;
    }

    constexpr bool is_finite() const noexcept
    {
        return exponent != all_ones<ExponentSize>();
    }

    constexpr operator float() const noexcept
    {
        return this->to_float();
    }

    static constexpr generic_float infinity()
    {
        generic_float x{};
        x.exponent = all_ones<ExponentSize>();
        return x;
    }

    static constexpr generic_float snan()
    {
        generic_float x{};
        x.exponent = all_ones<ExponentSize>();
        x.mantissa = 1 << (MantissaSize - 2);
        return x;
    }

    static constexpr generic_float qnan()
    {
        generic_float x{};
        x.exponent = all_ones<ExponentSize>();
        x.mantissa = 1 << (MantissaSize - 1);
        return x;
    }

    static constexpr generic_float min()
    {
        generic_float x{};
        x.exponent = 1;
        x.mantissa = 0;
        return x;
    }

    static constexpr generic_float denorm_min()
    {
        generic_float x{};
        x.exponent = 0;
        x.mantissa = 1;
        x.sign = 0;
        return x;
    }

    static constexpr generic_float lowest()
    {
        generic_float x{};
        x.exponent = all_ones<ExponentSize>() - 1;
        x.mantissa = all_ones<MantissaSize>();
        x.sign = 1;
        return x;
    }

    static constexpr generic_float max()
    {
        generic_float x{};
        x.exponent = all_ones<ExponentSize>() - 1;
        x.mantissa = all_ones<MantissaSize>();
        x.sign = 0;
        return x;
    }

    static constexpr generic_float epsilon()
    {
        generic_float x{1.0};
        x.mantissa++;
        return generic_float{x.to_float() - 1.0f};
    }
// NOLINTNEXTLINE
#define MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(op) \
    constexpr generic_float& operator op(const generic_float& rhs) \
    { \
        float self = *this; \
        float frhs = rhs; \
        self op frhs; \
        *this = generic_float(self); \
        return *this; \
    }
    MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(*=)
    MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(-=)
    MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(+=)
    MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(/=)
// NOLINTNEXTLINE
#define MIGRAPHX_GENERIC_FLOAT_BINARY_OP(op) \
    friend constexpr generic_float operator op(const generic_float& x, const generic_float& y) \
    { \
        return generic_float(float(x) op float(y)); \
    }
    MIGRAPHX_GENERIC_FLOAT_BINARY_OP(*)
    MIGRAPHX_GENERIC_FLOAT_BINARY_OP(-)
    MIGRAPHX_GENERIC_FLOAT_BINARY_OP(+)
    MIGRAPHX_GENERIC_FLOAT_BINARY_OP(/)
// NOLINTNEXTLINE
#define MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(op) \
    friend constexpr bool operator op(const generic_float& x, const generic_float& y) \
    { \
        return float(x) op float(y); \
    }
    MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(<)
    MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(<=)
    MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(>)
    MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(>=)

    // Any comparison involving a non-finite value (inf or NaN) is treated as unequal.
    friend constexpr bool operator==(const generic_float& x, const generic_float& y)
    {
        if (not x.is_finite() or not y.is_finite())
            return false;
        return std::tie(x.mantissa, x.exponent, x.sign) == std::tie(y.mantissa, y.exponent, y.sign);
    }

    friend constexpr bool operator!=(const generic_float& x, const generic_float& y)
    {
        return not(x == y);
    }
};

I may be biased, but I find this much more readable than the float8 code. I have done some initial testing with the fp32 type:

using fp32 = generic_float<23, 8>;

#define CHECK_FLOAT(x, y) \
    CHECK(bit_equal(x, y)); \
    CHECK(bit_equal(x, y.to_float())); \
    CHECK(bit_equal(fp32{x}, y)); \
    CHECK(bit_equal(fp32{x}.to_float(), y.to_float()))


TEST_CASE(fp32_values)
{
    CHECK_FLOAT(1.0f, fp32{1.0f});
    CHECK_FLOAT(-1.0f, fp32{-1.0f});
    CHECK_FLOAT(std::numeric_limits<float>::min(), fp32::min());
    CHECK_FLOAT(std::numeric_limits<float>::lowest(), fp32::lowest());
    CHECK_FLOAT(std::numeric_limits<float>::max(), fp32::max());
    CHECK_FLOAT(std::numeric_limits<float>::epsilon(), fp32::epsilon());
    CHECK_FLOAT(std::numeric_limits<float>::infinity(), fp32::infinity());
    CHECK_FLOAT(std::numeric_limits<float>::quiet_NaN(), fp32::qnan());
    CHECK_FLOAT(std::numeric_limits<float>::signaling_NaN(), fp32::snan());
    CHECK_FLOAT(std::numeric_limits<float>::denorm_min(), fp32::denorm_min());
}

This doesn't test the truncation code, though. Specializations of std::numeric_limits need to be added (a rough sketch is below). The Flags parameter also needs to be filled in at some point to handle the fp8 types, but that shouldn't be a blocker for BF16.
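
For example, a rough sketch of the std::numeric_limits specialization (assuming bf16 = generic_float<7, 8>; only a few members are shown, the remaining ones would be filled in the same way):

#include <limits>

using bf16 = generic_float<7, 8>;

namespace std {
template <>
class numeric_limits<bf16>
{
public:
    static constexpr bool is_specialized = true;
    static constexpr bool has_infinity   = true;
    static constexpr bool has_quiet_NaN  = true;
    static constexpr bf16 min() noexcept { return bf16::min(); }
    static constexpr bf16 max() noexcept { return bf16::max(); }
    static constexpr bf16 lowest() noexcept { return bf16::lowest(); }
    static constexpr bf16 epsilon() noexcept { return bf16::epsilon(); }
    static constexpr bf16 infinity() noexcept { return bf16::infinity(); }
    static constexpr bf16 quiet_NaN() noexcept { return bf16::qnan(); }
    static constexpr bf16 signaling_NaN() noexcept { return bf16::snan(); }
    static constexpr bf16 denorm_min() noexcept { return bf16::denorm_min(); }
};
} // namespace std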

It would be good to start by replacing our current half type with generic_float, since that is already implemented and has initial tests. We need a test suite similar to the fp8 test suite, but fp16 has 64k values, so rather than enumerating them all in arrays we should sample, say, 1k values to test fp16 with (see the sketch below).
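
A rough sketch of the sampling idea, assuming fp16 = generic_float<10, 5> and reusing the TEST_CASE/CHECK/bit_equal helpers from the snippet above (the seed, sample count, and the decision to skip denormals are arbitrary choices for the sketch):

#include <random>

using fp16 = generic_float<10, 5>;

TEST_CASE(fp16_sampled_round_trip)
{
    std::mt19937 gen{42};
    std::uniform_int_distribution<unsigned int> dist(0, (1u << 16) - 1);
    for(int i = 0; i < 1000; i++)
    {
        // Build an fp16 from a random 16-bit pattern by setting the fields directly.
        unsigned int bits = dist(gen);
        fp16 x{};
        x.mantissa = bits & all_ones<10>();
        x.exponent = (bits >> 10) & all_ones<5>();
        x.sign     = bits >> 15;
        if(not x.is_finite() or x.exponent == 0)
            continue; // skip inf/NaN and denormals in this sketch
        // Converting a normal value to float and back should reproduce the same fields.
        fp16 y{x.to_float()};
        CHECK(bit_equal(x, y));
    }
}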

Then we can easily add BF16:

  • Add a typedef for bf16 and add a test suite similar to fp16/fp8.
  • Add bf16 to the shape class types, and for the GPU and CPU backends add it to the list for eliminate_data_type (this will convert it to float at first).
  • Remove bf16 from eliminate_data_type for the GPU, add verify tests for bf16 (those are just additional template instantiations), and then fix the issues needed to enable it (which requires a lot of small changes):
    • Add conversion to the MLIR type
    • Add conversion for JIT compilation
    • There may be some more changes that creep up

I split it like this so it allows smaller PRs, which should make them easier to review and merge. So it would be 3 PRs for the above and one more PR for the half type, 4 PRs in total.
