Skip to content

Commit

Permalink
add gpu support for bf16
Browse files Browse the repository at this point in the history
  • Loading branch information
richagadgil committed Nov 19, 2024
1 parent 043e322 commit d1acec9
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/driver/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,7 @@ struct compiler
bool to_fp8 = false;
bool to_int8 = false;
bool to_int4 = false;
bool to_bf16 = false;

std::vector<std::string> fill0;
std::vector<std::string> fill1;
Expand All @@ -509,6 +510,7 @@ struct compiler
ap(to_int8, {"--int8"}, ap.help("Quantize for int8"), ap.set_value(true));
ap(to_fp8, {"--fp8"}, ap.help("Quantize for fp8"), ap.set_value(true));
ap(to_int4, {"--int4-weights"}, ap.help("Quantize weights for int4"), ap.set_value(true));
ap(to_bf16, {"--bf16"}, ap.help("Quantize for bf16"), ap.set_value(true));
}

auto params(const program& p)
Expand Down Expand Up @@ -567,6 +569,10 @@ struct compiler
{
quantize_int4_weights(p);
}
if(to_bf16)
{
quantize_bf16(p);
}
p.compile(t, co);
l.save(p);
return p;
Expand Down
3 changes: 3 additions & 0 deletions src/include/migraphx/quantization.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ quantize_fp8(program& prog, const target& t, const std::vector<parameter_map>& c

MIGRAPHX_EXPORT void quantize_int4_weights(program& prog);

MIGRAPHX_EXPORT void quantize_bf16(program& prog,
const std::vector<std::string>& ins_names = {"all"});

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

Expand Down
4 changes: 4 additions & 0 deletions src/py/migraphx_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,10 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
},
"Auto-convert FP8 parameters and return values to Float for MIGraphX Program",
py::arg("prog"));
m.def("quantize_bf16",
&migraphx::quantize_bf16,
py::arg("prog"),
py::arg("ins_names") = std::vector<std::string>{"all"});

#ifdef HAVE_GPU
m.def("allocate_gpu", &migraphx::gpu::allocate_gpu, py::arg("s"), py::arg("host") = false);
Expand Down
10 changes: 10 additions & 0 deletions src/quantization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,16 @@ void quantize_fp16(program& prog, const std::vector<std::string>& ins_names)
quant_tracer());
}

// Quantize a program to bf16 precision.
// Mirrors quantize_fp16 above: runs normalize_ops, then truncate_float_pass
// with shape::bf16_type to convert instructions to bf16 (restricted to the
// ops listed in ins_names; the default {"all"} set by the header covers every
// instruction). optimize_module is run both before and after so that
// quantizelinear/dequantizelinear pairs produced by the truncation are folded.
// NOTE(review): exact truncation semantics live in truncate_float_pass —
// confirm it inserts convert ops rather than rewriting literals in place.
void quantize_bf16(program& prog, const std::vector<std::string>& ins_names)
{
run_passes(prog,
{normalize_ops{},
optimize_module{{"quantizelinear", "dequantizelinear"}},
truncate_float_pass{ins_names, shape::bf16_type},
optimize_module{{"quantizelinear", "dequantizelinear"}}},
quant_tracer());
}

void quantize_8bits(program& prog,
const target& t,
shape::type_t precision,
Expand Down
1 change: 1 addition & 0 deletions src/targets/gpu/target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
unsupported_types.erase(shape::type_t::uint8_type);
unsupported_types.erase(shape::type_t::int32_type);
unsupported_types.erase(shape::type_t::tuple_type);
unsupported_types.erase(shape::type_t::bf16_type);

// whitelist supported Ops for the FP8 types
// different between fp8e4m3fnuz and OCP types because rocBLAS only has
Expand Down

0 comments on commit d1acec9

Please sign in to comment.