Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement BF16 using generic_float class #3578

Merged
merged 74 commits into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from 70 commits
Commits
Show all changes
74 commits
Select commit Hold shift + click to select a range
c51c1ce
first pass at integrating generic float
richagadgil Oct 10, 2024
134b408
fix namespaces
richagadgil Oct 10, 2024
d4fa6eb
fix mantissa
richagadgil Oct 10, 2024
0b60841
refactor
richagadgil Oct 11, 2024
7a646f1
refactor
richagadgil Oct 11, 2024
ebe819b
add fp
richagadgil Oct 11, 2024
379a77a
fixed generic float class
richagadgil Oct 14, 2024
174384c
add fp32 test
richagadgil Oct 14, 2024
787b651
remove import
richagadgil Oct 14, 2024
1d1fa1c
update tests
richagadgil Oct 15, 2024
1791092
fp16 tests that work
richagadgil Oct 17, 2024
a2eb005
update tests
richagadgil Oct 18, 2024
ff8ffc7
updated fp16 and fp32 tests
richagadgil Oct 18, 2024
e36fd65
half tests
richagadgil Oct 22, 2024
9ac4e2a
underflow and overflow tests
richagadgil Oct 22, 2024
f05fd31
generate map
richagadgil Oct 22, 2024
cb4d92d
add more tests
richagadgil Oct 22, 2024
0cc1946
fix names
richagadgil Oct 22, 2024
85a761b
update tests
richagadgil Oct 23, 2024
65cf9ae
remove and
richagadgil Oct 24, 2024
fbabf54
disable warning
richagadgil Oct 24, 2024
549f5e6
fix tidy warning
richagadgil Oct 24, 2024
d302e5d
migraphx py fix
richagadgil Oct 25, 2024
8d475e3
add increments
richagadgil Oct 25, 2024
a0fd055
fix warnings
richagadgil Oct 25, 2024
41379fe
disable duplicate branch warning
richagadgil Oct 25, 2024
0c29c7b
add countzero_std
richagadgil Oct 28, 2024
4b012a8
ci error
richagadgil Oct 28, 2024
dbaa3a8
simplify countl
richagadgil Oct 28, 2024
b2bd2a0
fix ci
richagadgil Oct 28, 2024
6f328f0
src
richagadgil Oct 29, 2024
e6d9763
remove flag
richagadgil Oct 29, 2024
6538050
hide abi warning
richagadgil Oct 29, 2024
4e96d4d
revert changes
richagadgil Oct 29, 2024
ef11f1f
Merge branch 'develop' into generic_float
richagadgil Oct 29, 2024
e4a25bd
change half in tests
richagadgil Oct 29, 2024
3354c6e
Update generic_float.hpp
richagadgil Oct 29, 2024
6de079b
format
richagadgil Oct 29, 2024
7750874
Merge branch 'develop' into generic_float
richagadgil Oct 29, 2024
801f485
Merge branch 'develop' into generic_float
causten Oct 30, 2024
33e2c8d
fix bug
richagadgil Oct 30, 2024
9bb7198
Merge branch 'generic_float' of github.com:ROCm/AMDMIGraphX into gene…
richagadgil Oct 30, 2024
b3c345d
fix err
richagadgil Oct 30, 2024
03df6f9
edits
richagadgil Oct 31, 2024
ad817b2
tidy and format
richagadgil Oct 31, 2024
898417b
tidy etc
richagadgil Oct 31, 2024
aa5b9c9
gf
richagadgil Oct 31, 2024
6f72370
fix tidy errs
richagadgil Nov 1, 2024
0aab1a0
bf16 changes
richagadgil Nov 4, 2024
7b965c0
add flag to trace quantization passes (#3571)
shivadbhavsar Oct 30, 2024
5f5f13d
bf16
richagadgil Oct 30, 2024
d64b124
Update bf16.cpp
richagadgil Nov 1, 2024
a064eaa
Update bf16.hpp
richagadgil Nov 2, 2024
befbd9e
Update bf16.hpp
richagadgil Nov 2, 2024
08b9511
update files with working version
richagadgil Nov 4, 2024
b9d204e
Update bf16.cpp
richagadgil Nov 4, 2024
fb6df2d
Update generic_float.hpp
richagadgil Nov 4, 2024
bb78138
Merge branch 'develop' into bf16
richagadgil Nov 8, 2024
8e1f99e
add extra common type
richagadgil Nov 8, 2024
6192970
tidy
richagadgil Nov 8, 2024
c0d6bc4
Update bf16.hpp
richagadgil Nov 11, 2024
7bfc407
Update generic_float.hpp
richagadgil Nov 11, 2024
4cb96ad
Merge branch 'develop' into bf16
richagadgil Nov 11, 2024
ffd4ba2
remove imports
richagadgil Nov 12, 2024
8a10da3
Merge branch 'develop' into bf16
richagadgil Nov 12, 2024
1565a0e
ref tests
richagadgil Nov 13, 2024
e6d1155
migraphx_py fix
richagadgil Nov 13, 2024
867e960
fix test cae by index
richagadgil Nov 13, 2024
9852da5
add rocblas type
richagadgil Nov 13, 2024
bf50653
fix tgts err
richagadgil Nov 13, 2024
0ebd220
address changes
richagadgil Nov 18, 2024
043e322
Merge branch 'develop' into bf16
richagadgil Nov 18, 2024
21746a5
Merge branch 'develop' into bf16
causten Nov 19, 2024
47a1810
skip jit tests
richagadgil Nov 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/api/include/migraphx/migraphx.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#define MIGRAPHX_SHAPE_VISIT_TYPES(m) \
m(bool_type, bool) \
m(half_type, half) \
m(bf16_type, bf16) \
richagadgil marked this conversation as resolved.
Show resolved Hide resolved
m(float_type, float) \
m(double_type, double) \
m(uint8_type, uint8_t) \
Expand Down
39 changes: 39 additions & 0 deletions src/include/migraphx/bf16.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#ifndef MIGRAPHX_GUARD_RTGLIB_BF16_HPP
#define MIGRAPHX_GUARD_RTGLIB_BF16_HPP

#include <migraphx/generic_float.hpp>
#include <migraphx/config.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

using bf16 = migraphx::generic_float<7, 8>;

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
4 changes: 3 additions & 1 deletion src/include/migraphx/generic_float.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ struct float32_parts
unsigned int exponent : 8;
unsigned int sign : 1;

static constexpr unsigned int exponent_width() { return 8; }

static constexpr unsigned int mantissa_width() { return 23; }

static constexpr unsigned int max_exponent() { return all_ones<8>(); }
Expand Down Expand Up @@ -152,7 +154,7 @@ struct __attribute__((packed, may_alias)) generic_float
float32_parts f{};
f.sign = sign;

if(exponent == 0) // subnormal fps
if(exponent == 0 and ExponentSize != float32_parts::exponent_width()) // subnormal fps
{

if(mantissa == 0)
Expand Down
4 changes: 3 additions & 1 deletion src/include/migraphx/shape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include <migraphx/functional.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/half.hpp>
#include <migraphx/bf16.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/config.hpp>
Expand All @@ -52,6 +53,7 @@ struct MIGRAPHX_EXPORT shape
#define MIGRAPHX_SHAPE_VISIT_TYPES(m) \
m(bool_type, bool) \
m(half_type, half) \
m(bf16_type, bf16) \
richagadgil marked this conversation as resolved.
Show resolved Hide resolved
m(float_type, float) \
m(double_type, double) \
m(uint8_type, uint8_t) \
Expand All @@ -65,7 +67,7 @@ struct MIGRAPHX_EXPORT shape
m(fp8e4m3fnuz_type, migraphx::fp8::fp8e4m3fnuz) \
m(fp8e4m3fn_type, migraphx::fp8::fp8e4m3fn) \
m(fp8e5m2_type, migraphx::fp8::fp8e5m2)
// clang-format on
// clang-format on

#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x,
enum type_t
Expand Down
5 changes: 5 additions & 0 deletions src/include/migraphx/type_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

#include <type_traits>
#include <migraphx/half.hpp>
#include <migraphx/bf16.hpp>
#include <migraphx/config.hpp>
#include <migraphx/float8.hpp>

Expand All @@ -53,6 +54,10 @@ MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, half)

MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, bf16)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, bf16)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, bf16)

MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, migraphx::fp8::fp8e4m3fnuz)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, migraphx::fp8::fp8e4m3fnuz)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, migraphx::fp8::fp8e4m3fnuz)
Expand Down
11 changes: 11 additions & 0 deletions src/py/migraphx_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,17 @@ struct npy_format_descriptor<migraphx::fp8::fp8e5m2>
static constexpr auto name() { return _("fp8e5m2"); }
};

template <>
struct npy_format_descriptor<migraphx::bf16>
{
static std::string format()
{
// TODO: no standard format in numpy for bf16
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be an issue or is this tracked somewhere already?

return "z";
}
static constexpr auto name() { return _("bf16"); }
};

} // namespace detail
} // namespace pybind11

Expand Down
1 change: 1 addition & 0 deletions src/targets/gpu/gemm_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ rocblas_datatype get_type(shape::type_t type)
case shape::double_type: return rocblas_datatype_f64_r;
case shape::float_type: return rocblas_datatype_f32_r;
case shape::half_type: return rocblas_datatype_f16_r;
case shape::bf16_type: return rocblas_datatype_bf16_r;
case shape::int8_type: return rocblas_datatype_i8_r;
case shape::uint8_type: return rocblas_datatype_u8_r;
case shape::int32_type: return rocblas_datatype_i32_r;
Expand Down
1 change: 1 addition & 0 deletions src/targets/gpu/hip_gemm_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ hipDataType get_type_hipblas(shape::type_t type)
case shape::double_type: return HIP_R_64F;
case shape::float_type: return HIP_R_32F;
case shape::half_type: return HIP_R_16F;
case shape::bf16_type: return HIP_R_16BF;
case shape::int8_type: return HIP_R_8I;
case shape::uint8_type: return HIP_R_8U;
case shape::int32_type: return HIP_R_32I;
Expand Down
Loading
Loading