Bf16 gpu support #3630

Draft: wants to merge 80 commits into base: develop

Commits (80)
c51c1ce
first pass at integrating generic float
richagadgil Oct 10, 2024
134b408
fix namespaces
richagadgil Oct 10, 2024
d4fa6eb
fix mantissa
richagadgil Oct 10, 2024
0b60841
refactor
richagadgil Oct 11, 2024
7a646f1
refactor
richagadgil Oct 11, 2024
ebe819b
add fp
richagadgil Oct 11, 2024
379a77a
fixed generic float class
richagadgil Oct 14, 2024
174384c
add fp32 test
richagadgil Oct 14, 2024
787b651
remove import
richagadgil Oct 14, 2024
1d1fa1c
update tests
richagadgil Oct 15, 2024
1791092
fp16 tests that work
richagadgil Oct 17, 2024
a2eb005
update tests
richagadgil Oct 18, 2024
ff8ffc7
updated fp16 and fp32 tests
richagadgil Oct 18, 2024
e36fd65
half tests
richagadgil Oct 22, 2024
9ac4e2a
underflow and overflow tests
richagadgil Oct 22, 2024
f05fd31
generate map
richagadgil Oct 22, 2024
cb4d92d
add more tests
richagadgil Oct 22, 2024
0cc1946
fix names
richagadgil Oct 22, 2024
85a761b
update tests
richagadgil Oct 23, 2024
65cf9ae
remove and
richagadgil Oct 24, 2024
fbabf54
disable warning
richagadgil Oct 24, 2024
549f5e6
fix tidy warning
richagadgil Oct 24, 2024
d302e5d
migraphx py fix
richagadgil Oct 25, 2024
8d475e3
add increments
richagadgil Oct 25, 2024
a0fd055
fix warnings
richagadgil Oct 25, 2024
41379fe
disable duplicate branch warning
richagadgil Oct 25, 2024
0c29c7b
add countzero_std
richagadgil Oct 28, 2024
4b012a8
ci error
richagadgil Oct 28, 2024
dbaa3a8
simplify countl
richagadgil Oct 28, 2024
b2bd2a0
fix ci
richagadgil Oct 28, 2024
6f328f0
src
richagadgil Oct 29, 2024
e6d9763
remove flag
richagadgil Oct 29, 2024
6538050
hide abi warning
richagadgil Oct 29, 2024
4e96d4d
revert changes
richagadgil Oct 29, 2024
ef11f1f
Merge branch 'develop' into generic_float
richagadgil Oct 29, 2024
e4a25bd
change half in tests
richagadgil Oct 29, 2024
3354c6e
Update generic_float.hpp
richagadgil Oct 29, 2024
6de079b
format
richagadgil Oct 29, 2024
7750874
Merge branch 'develop' into generic_float
richagadgil Oct 29, 2024
801f485
Merge branch 'develop' into generic_float
causten Oct 30, 2024
33e2c8d
fix bug
richagadgil Oct 30, 2024
9bb7198
Merge branch 'generic_float' of github.com:ROCm/AMDMIGraphX into gene…
richagadgil Oct 30, 2024
b3c345d
fix err
richagadgil Oct 30, 2024
03df6f9
edits
richagadgil Oct 31, 2024
ad817b2
tidy and format
richagadgil Oct 31, 2024
898417b
tidy etc
richagadgil Oct 31, 2024
aa5b9c9
gf
richagadgil Oct 31, 2024
6f72370
fix tidy errs
richagadgil Nov 1, 2024
0aab1a0
bf16 changes
richagadgil Nov 4, 2024
7b965c0
add flag to trace quantization passes (#3571)
shivadbhavsar Oct 30, 2024
5f5f13d
bf16
richagadgil Oct 30, 2024
d64b124
Update bf16.cpp
richagadgil Nov 1, 2024
a064eaa
Update bf16.hpp
richagadgil Nov 2, 2024
befbd9e
Update bf16.hpp
richagadgil Nov 2, 2024
08b9511
update files with working version
richagadgil Nov 4, 2024
b9d204e
Update bf16.cpp
richagadgil Nov 4, 2024
fb6df2d
Update generic_float.hpp
richagadgil Nov 4, 2024
bb78138
Merge branch 'develop' into bf16
richagadgil Nov 8, 2024
8e1f99e
add extra common type
richagadgil Nov 8, 2024
6192970
tidy
richagadgil Nov 8, 2024
c0d6bc4
Update bf16.hpp
richagadgil Nov 11, 2024
7bfc407
Update generic_float.hpp
richagadgil Nov 11, 2024
4cb96ad
Merge branch 'develop' into bf16
richagadgil Nov 11, 2024
ffd4ba2
remove imports
richagadgil Nov 12, 2024
8a10da3
Merge branch 'develop' into bf16
richagadgil Nov 12, 2024
1565a0e
ref tests
richagadgil Nov 13, 2024
e6d1155
migraphx_py fix
richagadgil Nov 13, 2024
867e960
fix test cae by index
richagadgil Nov 13, 2024
9852da5
add rocblas type
richagadgil Nov 13, 2024
bf50653
fix tgts err
richagadgil Nov 13, 2024
0ebd220
address changes
richagadgil Nov 18, 2024
043e322
Merge branch 'develop' into bf16
richagadgil Nov 18, 2024
a3ca184
bf16 gpu support
richagadgil Nov 19, 2024
490d326
add vector types
richagadgil Nov 19, 2024
a63ac1e
rocblas
richagadgil Nov 19, 2024
94990bb
bf16 gpu testing
shivadbhavsar Nov 19, 2024
8aaae90
mlir bf16
shivadbhavsar Nov 19, 2024
208232e
fix type
richagadgil Nov 19, 2024
d4866d5
fix type
richagadgil Nov 19, 2024
59eec66
add type
richagadgil Nov 19, 2024
3 changes: 2 additions & 1 deletion src/api/include/migraphx/migraphx.h
@@ -47,7 +47,8 @@
m(uint64_type, uint64_t) \
m(fp8e4m3fnuz_type, migraphx::fp8::fp8e4m3fnuz) \
m(fp8e4m3fn_type, migraphx::fp8::fp8e4m3fn) \
m(fp8e5m2_type, migraphx::fp8::fp8e5m2)
m(fp8e5m2_type, migraphx::fp8::fp8e5m2) \
m(bf16_type, bf16)
// clang-format on

#ifdef __cplusplus
15 changes: 15 additions & 0 deletions src/driver/main.cpp
@@ -482,6 +482,7 @@ struct compiler
compiler_target ct;
compile_options co;
bool to_fp16 = false;
bool to_bf16 = false;
bool to_fp8 = false;
bool to_int8 = false;
bool to_int4 = false;
@@ -506,9 +507,11 @@ struct compiler
ap.help("Exhaustively search for best tuning parameters for kernels"),
ap.set_value(true));
ap(to_fp16, {"--fp16"}, ap.help("Quantize for fp16"), ap.set_value(true));
ap(to_bf16, {"--bf16"}, ap.help("Quantize for bf16"), ap.set_value(true));
ap(to_int8, {"--int8"}, ap.help("Quantize for int8"), ap.set_value(true));
ap(to_fp8, {"--fp8"}, ap.help("Quantize for fp8"), ap.set_value(true));
ap(to_int4, {"--int4-weights"}, ap.help("Quantize weights for int4"), ap.set_value(true));
}

auto params(const program& p)
@@ -555,6 +558,10 @@
{
quantize_fp16(p);
}
if(to_bf16)
{
quantize_bf16(p);
}
if(to_int8)
{
quantize_int8(p, t, {host_params(p)});
Expand All @@ -567,6 +574,10 @@ struct compiler
{
quantize_int4_weights(p);
}
p.compile(t, co);
l.save(p);
return p;
@@ -639,6 +650,10 @@ struct verify : command<verify>
{
vo.quantize = precision::fp16;
}
if(c.to_bf16)
{
vo.quantize = precision::bf16;
}
if(c.to_int8)
{
vo.quantize = precision::int8;
1 change: 1 addition & 0 deletions src/driver/precision.hpp
@@ -32,6 +32,7 @@ enum class precision
{
fp32,
fp16,
bf16,
int8
};

4 changes: 4 additions & 0 deletions src/driver/verify.cpp
@@ -100,6 +100,10 @@ std::vector<argument> run_target(program p,
{
quantize_fp16(p);
}
if(vo.quantize == precision::bf16)
{
quantize_bf16(p);
}
p.compile(t, options);

parameter_map m;
39 changes: 39 additions & 0 deletions src/include/migraphx/bf16.hpp
@@ -0,0 +1,39 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#ifndef MIGRAPHX_GUARD_RTGLIB_BF16_HPP
#define MIGRAPHX_GUARD_RTGLIB_BF16_HPP

#include <migraphx/generic_float.hpp>
#include <migraphx/config.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

using bf16 = migraphx::generic_float<7, 8>;

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
4 changes: 3 additions & 1 deletion src/include/migraphx/generic_float.hpp
@@ -104,6 +104,8 @@ struct float32_parts
unsigned int exponent : 8;
unsigned int sign : 1;

static constexpr unsigned int exponent_width() { return 8; }

static constexpr unsigned int mantissa_width() { return 23; }

static constexpr unsigned int max_exponent() { return all_ones<8>(); }
Expand Down Expand Up @@ -152,7 +154,7 @@ struct __attribute__((packed, may_alias)) generic_float
float32_parts f{};
f.sign = sign;

if(exponent == 0) // subnormal fps
if(exponent == 0 and ExponentSize != float32_parts::exponent_width()) // subnormal fps
{

if(mantissa == 0)
3 changes: 3 additions & 0 deletions src/include/migraphx/quantization.hpp
@@ -51,6 +51,9 @@ quantize_fp8(program& prog, const target& t, const std::vector<parameter_map>& c

MIGRAPHX_EXPORT void quantize_int4_weights(program& prog);

MIGRAPHX_EXPORT void quantize_bf16(program& prog,
const std::vector<std::string>& ins_names = {"all"});

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

6 changes: 4 additions & 2 deletions src/include/migraphx/shape.hpp
@@ -34,6 +34,7 @@
#include <migraphx/functional.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/half.hpp>
#include <migraphx/bf16.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/config.hpp>
@@ -64,8 +65,9 @@ struct MIGRAPHX_EXPORT shape
m(uint64_type, uint64_t) \
m(fp8e4m3fnuz_type, migraphx::fp8::fp8e4m3fnuz) \
m(fp8e4m3fn_type, migraphx::fp8::fp8e4m3fn) \
m(fp8e5m2_type, migraphx::fp8::fp8e5m2)
// clang-format on
m(fp8e5m2_type, migraphx::fp8::fp8e5m2) \
m(bf16_type, bf16)
// clang-format on

#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x,
enum type_t
5 changes: 5 additions & 0 deletions src/include/migraphx/type_traits.hpp
@@ -27,6 +27,7 @@

#include <type_traits>
#include <migraphx/half.hpp>
#include <migraphx/bf16.hpp>
#include <migraphx/config.hpp>
#include <migraphx/float8.hpp>

@@ -53,6 +54,10 @@ MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, half)

MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, bf16)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, bf16)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, bf16)

MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, migraphx::fp8::fp8e4m3fnuz)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, migraphx::fp8::fp8e4m3fnuz)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, migraphx::fp8::fp8e4m3fnuz)
15 changes: 15 additions & 0 deletions src/py/migraphx_py.cpp
@@ -180,6 +180,17 @@ struct npy_format_descriptor<migraphx::fp8::fp8e5m2>
static constexpr auto name() { return _("fp8e5m2"); }
};

template <>
struct npy_format_descriptor<migraphx::bf16>
{
static std::string format()
{
// TODO: no standard format in numpy for bf16
return "z";
}
static constexpr auto name() { return _("bf16"); }
};

} // namespace detail
} // namespace pybind11

@@ -623,6 +634,10 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
},
"Auto-convert FP8 parameters and return values to Float for MIGraphX Program",
py::arg("prog"));
m.def("quantize_bf16",
&migraphx::quantize_bf16,
py::arg("prog"),
py::arg("ins_names") = std::vector<std::string>{"all"});

#ifdef HAVE_GPU
m.def("allocate_gpu", &migraphx::gpu::allocate_gpu, py::arg("s"), py::arg("host") = false);
10 changes: 10 additions & 0 deletions src/quantization.cpp
@@ -74,6 +74,16 @@
quant_tracer());
}

void quantize_bf16(program& prog, const std::vector<std::string>& ins_names)

{
run_passes(prog,
{normalize_ops{},
optimize_module{{"quantizelinear", "dequantizelinear"}},
truncate_float_pass{ins_names, shape::bf16_type},
optimize_module{{"quantizelinear", "dequantizelinear"}}},

quant_tracer());
}


void quantize_8bits(program& prog,
const target& t,
shape::type_t precision,
7 changes: 5 additions & 2 deletions src/targets/gpu/fuse_mlir.cpp
@@ -361,6 +361,7 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
const auto& name = i.name();
const auto result_type = i.get_shape().type();
const std::initializer_list<type_t> allowed_types = {type_t::float_type,
type_t::bf16_type,
type_t::half_type,
type_t::fp8e4m3fnuz_type,
type_t::fp8e4m3fn_type,
@@ -407,6 +408,7 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
};
std::set<shape::type_t> float_types = {type_t::float_type,
type_t::half_type,
type_t::bf16_type,
type_t::fp8e4m3fnuz_type,
type_t::fp8e4m3fn_type,
type_t::fp8e5m2_type};
@@ -426,7 +428,7 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
return false;
} // else
return std::all_of(i.inputs().begin(), i.inputs().end(), [](const auto& arg) {
return contains({type_t::float_type, type_t::half_type}, arg->get_shape().type());
return contains({type_t::float_type, type_t::half_type, type_t::bf16_type}, arg->get_shape().type());
});
}
return false;
@@ -438,7 +440,7 @@ bool is_reduce_op_supported_by_mlir(const instruction& i)
const auto& name = i.name();
const auto result_type = i.get_shape().type();
const std::initializer_list<type_t> allowed_types = {
type_t::float_type, type_t::half_type, type_t::fp8e4m3fnuz_type};
type_t::float_type, type_t::half_type, type_t::bf16_type, type_t::fp8e4m3fnuz_type};
// Preliminary type check.
if(not contains(allowed_types, result_type))
{
@@ -695,6 +697,7 @@ struct find_mlir_standalone_op
if(std::any_of(gemm_based_op->inputs().begin(), gemm_based_op->inputs().end(), [&](auto i) {
return not contains({shape::type_t::float_type,
shape::type_t::half_type,
shape::type_t::bf16_type,
shape::type_t::int8_type,
shape::type_t::fp8e4m3fnuz_type,
shape::type_t::fp8e4m3fn_type,
3 changes: 2 additions & 1 deletion src/targets/gpu/gemm_impl.cpp
@@ -72,6 +72,7 @@ rocblas_datatype get_type(shape::type_t type)
case shape::int16_type:
case shape::int64_type:
case shape::uint64_type: MIGRAPHX_THROW("ROCBLAS_GEMM: data type not supported!");
case shape::bf16_type: return rocblas_datatype_bf16_r;
}

MIGRAPHX_THROW("ROCBLAS_GEMM: data type not supported!");
@@ -221,7 +222,7 @@ struct gemm_impl
compute_type = rb_compute_type{output_type};
if(compute_fp32)
{
if(arg_type == rocblas_datatype_f16_r)
if(arg_type == rocblas_datatype_f16_r or arg_type == rocblas_datatype_bf16_r)
compute_type = rocblas_datatype_f32_r;
}
if(arg_type == rocblas_datatype_f8_r)
1 change: 1 addition & 0 deletions src/targets/gpu/hip_gemm_impl.cpp
@@ -84,6 +84,7 @@ hipDataType get_type_hipblas(shape::type_t type)
case shape::int16_type:
case shape::int64_type:
case shape::uint64_type: MIGRAPHX_THROW("HIPBLAS_GEMM: data type not supported!");
case shape::bf16_type: return HIP_R_16BF;
}

MIGRAPHX_THROW("HIPBLAS_GEMM: data type not supported!");
2 changes: 2 additions & 0 deletions src/targets/gpu/kernels/include/migraphx/kernels/hip.hpp
@@ -27,7 +27,9 @@
#ifndef MIGRAPHX_USE_HIPRTC
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <hip/hip_bf16.h>
#include <hip/math_functions.h>

Check warning on line 32 in src/targets/gpu/kernels/include/migraphx/kernels/hip.hpp (GitHub Actions / tidy): duplicate include [readability-duplicate-include,-warnings-as-errors]
#endif

#endif // MIGRAPHX_GUARD_KERNELS_HIP_HPP
@@ -76,6 +76,7 @@ using vec = T __attribute__((ext_vector_type(N)));

using half = _Float16;
using half2 = migraphx::vec<half, 2>;
using bf16 = __bf16;

} // namespace migraphx

2 changes: 2 additions & 0 deletions src/targets/gpu/mlir.cpp
@@ -312,6 +312,8 @@ struct mlir_program
result = mlirF32TypeGet(ctx.get());
else if(as.type_enum() == shape::half_type)
result = mlirF16TypeGet(ctx.get());
else if(as.type_enum() == shape::bf16_type)
result = mlirBF16TypeGet(ctx.get());
else if(as.type_enum() == shape::fp8e4m3fnuz_type)
result = mlirFloat8E4M3FNUZTypeGet(ctx.get());
else if(as.type_enum() == shape::fp8e4m3fn_type)
1 change: 1 addition & 0 deletions src/targets/gpu/target.cpp
@@ -99,6 +99,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
unsupported_types.erase(shape::type_t::uint8_type);
unsupported_types.erase(shape::type_t::int32_type);
unsupported_types.erase(shape::type_t::tuple_type);
unsupported_types.erase(shape::type_t::bf16_type);

// whitelist supported Ops for the FP8 types
// different between fp8e4m3fnuz and OCP types because rocBLAS only has