From 9500cb2c3a42e6404e8e9fff0c4ffd7137603016 Mon Sep 17 00:00:00 2001 From: Yi DING Date: Thu, 25 Jul 2024 12:09:10 +0800 Subject: [PATCH] XeTLA INT4 With BF16 Support (#311) * int4 with bf16 support * rename quantmode --- include/common/core/common_types.hpp | 2 +- include/common/core/explicit_conv.hpp | 13 ++++++ .../group/gemm/impl/int4_dequantize_xe.hpp | 18 +++----- .../gemm/impl/int4_dequantize_kslicing_xe.hpp | 14 +++--- .../subgroup/tile/impl/tile_op_functor.hpp | 4 +- .../gemm/int4_dequantization/main.cpp | 5 ++- .../int4_dequantization_bias/main_client.cpp | 4 +- .../gemm/int4_dequantization_bias/main_xe.cpp | 2 +- tests/integration/gemv/int4/main.cpp | 45 ++++++++++++------- 9 files changed, 65 insertions(+), 42 deletions(-) diff --git a/include/common/core/common_types.hpp b/include/common/core/common_types.hpp index cbd174462..30c3cc04d 100644 --- a/include/common/core/common_types.hpp +++ b/include/common/core/common_types.hpp @@ -27,7 +27,7 @@ enum class grf_mode : uint8_t { normal = 0, double_grf = 1 }; enum class mem_layout : uint8_t { row_major = 0, col_major = 1 }; -enum class quant_mode : uint8_t { S4_ASYM = 0, S4_FULLRANGE_NO_ZP = 1 }; +enum class quant_mode : uint8_t { I4_ASYM = 0, I4_SYM = 1 }; struct quant_info { quant_mode quant_mode; diff --git a/include/common/core/explicit_conv.hpp b/include/common/core/explicit_conv.hpp index 0c61f12bc..ba553ad2d 100644 --- a/include/common/core/explicit_conv.hpp +++ b/include/common/core/explicit_conv.hpp @@ -62,6 +62,19 @@ xetla_cvt(xetla_vector src) { return dst; } +/// @brief xetla explicit data conversion, bf16->fp16. +/// @tparam T_dst is the float16 data type. +/// @tparam T_src is the bfloat16 data type. +/// @tparam N is the element number in xetla_vector. +template +__XETLA_API typename std::enable_if_t< + std::is_same::value && std::is_same::value, + xetla_vector> +xetla_cvt(xetla_vector src) { + xetla_vector dst = src; + return dst; +} + /// @brief xetla explicit data conversion, bf16->fp32. /// @tparam T_dst is the bfloat16 data type. /// @tparam T_src is the float32 data type. diff --git a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp index cfb09fe22..71168e2f6 100644 --- a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp +++ b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp @@ -520,8 +520,7 @@ class gemm_t< // TODO 1D prefetch need pack to U32/U64 subgroup::tile_prefetch( scale_prefetch_payload); - if constexpr ( - compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) { + if constexpr (compute_policy::quant_mode != quant_mode::I4_SYM) { // TODO 1D prefetch need pack to U32/U64 subgroup::tile_prefetch( zero_pt_prefetch_payload); @@ -534,8 +533,7 @@ class gemm_t< if ((scale_prefetch_addr_i % scale_addr_update_freq) == 0) { scale_prefetch_payload.template update_tdesc( scale_t::tile_size_y); - if constexpr ( - compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) { + if constexpr (compute_policy::quant_mode != quant_mode::I4_SYM) { zero_pt_prefetch_payload .template update_tdesc( zero_pt_t::tile_size_y); @@ -564,8 +562,7 @@ class gemm_t< // matB, matB_payload); subgroup::tile_load( scale, scale_payload); - if constexpr ( - compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) { + if constexpr (compute_policy::quant_mode != quant_mode::I4_SYM) { subgroup::tile_load( zero_pt, zero_pt_payload); } @@ -579,8 +576,7 @@ class gemm_t< // TODO 1D prefetch need pack to U32/U64 subgroup::tile_prefetch( scale_prefetch_payload); - if constexpr ( - compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) { + if constexpr (compute_policy::quant_mode != quant_mode::I4_SYM) { // TODO 1D prefetch need pack to U32/U64 subgroup::tile_prefetch( zero_pt_prefetch_payload); @@ -593,8 +589,7 @@ class gemm_t< if (tile_k_idx % scale_addr_update_freq == 0) { scale_payload.template update_tdesc(scale_t::tile_size_y); } - if constexpr ( - compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) { + if constexpr (compute_policy::quant_mode != quant_mode::I4_SYM) { if (tile_k_idx % zero_pt_addr_update_freq == 0) { zero_pt_payload.template update_tdesc( zero_pt_t::tile_size_y); @@ -608,8 +603,7 @@ class gemm_t< if ((scale_prefetch_addr_i % scale_addr_update_freq) == 0) { scale_prefetch_payload.template update_tdesc( scale_t::tile_size_y); - if constexpr ( - compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) { + if constexpr (compute_policy::quant_mode != quant_mode::I4_SYM) { zero_pt_prefetch_payload .template update_tdesc( zero_pt_t::tile_size_y); diff --git a/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp b/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp index 79d37d517..e9ddff750 100644 --- a/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp +++ b/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp @@ -159,7 +159,7 @@ class gemm_universal_t< /// @brief GEMM arguments. /// This is the interface for users to pass the application-related runtime /// variables. - template + template struct arguments_t { /// @brief Is the size of the m dimension of the matrix multiplication (m x /// k x n). @@ -295,7 +295,7 @@ class gemm_universal_t< } }; template <> - struct arguments_t { + struct arguments_t { /// @brief Is the size of the m dimension of the matrix multiplication (m x /// k x n). uint32_t matrix_m; @@ -526,6 +526,10 @@ class gemm_universal_t< template static bool can_implement(arguments_t& args) { bool implementable = true; + if (arch_tag == gpu_arch::XeLpg) { + implementable &= !std::is_same_v; // XeLpg arch dosen't + // have bf16 related isa. + } if (gemm_t::msg_type_a != msg_type::unaligned_2d) { if (gemm_t::msg_type_a == msg_type::block_2d) { implementable &= kernel::block_2d::check_tensor( @@ -566,8 +570,7 @@ class gemm_universal_t< // check for int4x2 implementable &= ((args.matB_ld % pack_ratio == 0) && (args.matrix_n % pack_ratio == 0)); - if constexpr ( - gemm_t::compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) { + if constexpr (gemm_t::compute_policy::quant_mode != quant_mode::I4_SYM) { implementable &= (args.zero_pt_ld % pack_ratio == 0); } @@ -664,8 +667,7 @@ class gemm_universal_t< uint32_t inner_loop_start = (start_k + k_stride - 1) / k_stride; uint32_t inner_loop_count = (wg_tile_k + k_stride - 1) / k_stride; gemm_args_t gemm_args; - if constexpr ( - gemm_t::compute_policy::quant_mode == quant_mode::S4_FULLRANGE_NO_ZP) { + if constexpr (gemm_t::compute_policy::quant_mode == quant_mode::I4_SYM) { gemm_args = gemm_args_t( mem_desc_a, mem_desc_b, diff --git a/include/subgroup/tile/impl/tile_op_functor.hpp b/include/subgroup/tile/impl/tile_op_functor.hpp index 060302448..f5a41e931 100644 --- a/include/subgroup/tile/impl/tile_op_functor.hpp +++ b/include/subgroup/tile/impl/tile_op_functor.hpp @@ -130,7 +130,7 @@ struct dequant_int4_weight_t { (offset_y_in_tile) / dequant_s * scale_t::block_size_x + offset_x_in_tile; - if constexpr (quant_mode == quant_mode::S4_ASYM) { + if constexpr (quant_mode == quant_mode::I4_ASYM) { uint32_t zero_pt_idx = offset_y_in_tile / dequant_s * zero_pt_t::block_size_x + offset_x_in_tile / pack_ratio; @@ -149,7 +149,7 @@ struct dequant_int4_weight_t { cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) = cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) - zero_pt_i8; - } else if constexpr (quant_mode == quant_mode::S4_FULLRANGE_NO_ZP) { + } else if constexpr (quant_mode == quant_mode::I4_SYM) { cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) = cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) - int8_t(8); diff --git a/tests/integration/gemm/int4_dequantization/main.cpp b/tests/integration/gemm/int4_dequantization/main.cpp index 18e40ded5..88c21250e 100644 --- a/tests/integration/gemm/int4_dequantization/main.cpp +++ b/tests/integration/gemm/int4_dequantization/main.cpp @@ -229,8 +229,9 @@ void dequantize_gemm_run(uint32_t iter) { compute_attr_t; using perf_tuning_knob = xetla::group:: perf_tuning_knob_t; - - static constexpr quant_info quant_info{quant_mode::S4_ASYM, Test::dequant_s, layout_b}; + + static constexpr quant_info quant_info{ + quant_mode::I4_ASYM, Test::dequant_s, layout_b}; using compute_policy = xetla::group::compute_policy_int4_dequantize< compute_attr, diff --git a/tests/integration/gemm/int4_dequantization_bias/main_client.cpp b/tests/integration/gemm/int4_dequantization_bias/main_client.cpp index 69fdfc1fe..0597af758 100644 --- a/tests/integration/gemm/int4_dequantization_bias/main_client.cpp +++ b/tests/integration/gemm/int4_dequantization_bias/main_client.cpp @@ -622,7 +622,7 @@ void dequantize_gemm_run(int iter) { perf_tuning_knob_t; static constexpr quant_info quant_info{ - quant_mode::S4_FULLRANGE_NO_ZP, Test::dequant_s, layout_b}; + quant_mode::I4_SYM, Test::dequant_s, layout_b}; using compute_policy = xetla::group::compute_policy_int4_dequantize< compute_attr, @@ -1043,4 +1043,4 @@ REGISTER_TYPED_TEST_SUITE_P(dequantize_gemm_act_shuf_test, esimd); INSTANTIATE_TYPED_TEST_SUITE_P( dequantize_gemm_act_shuf_test_suite, dequantize_gemm_act_shuf_test, - tests); \ No newline at end of file + tests); diff --git a/tests/integration/gemm/int4_dequantization_bias/main_xe.cpp b/tests/integration/gemm/int4_dequantization_bias/main_xe.cpp index 1c42454df..0cc9d8d6f 100644 --- a/tests/integration/gemm/int4_dequantization_bias/main_xe.cpp +++ b/tests/integration/gemm/int4_dequantization_bias/main_xe.cpp @@ -388,7 +388,7 @@ void dequantize_gemm_run(int iter) { using perf_tuning_knob = xetla::group:: perf_tuning_knob_t; static constexpr quant_info quant_info{ - quant_mode::S4_FULLRANGE_NO_ZP, Test::dequant_s, layout_b}; + quant_mode::I4_SYM, Test::dequant_s, layout_b}; using compute_policy = xetla::group::compute_policy_int4_dequantize< compute_attr, diff --git a/tests/integration/gemv/int4/main.cpp b/tests/integration/gemv/int4/main.cpp index 2cda91eba..5edf0a0d6 100644 --- a/tests/integration/gemv/int4/main.cpp +++ b/tests/integration/gemv/int4/main.cpp @@ -27,6 +27,7 @@ constexpr int ITER = 200; #endif constexpr size_t UNDEFINED_DATA_SIZE = 1024; +template class test_col_major_1 { public: // Extract the parameters required by different test cases @@ -39,8 +40,8 @@ class test_col_major_1 { static constexpr size_t sg_n = 1; static constexpr size_t sg_k = 512 / sg_m; static constexpr size_t dequant_s = 128; - // static constexpr quant_mode quant_mode = quant_mode::S4_ASYM; - static constexpr quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP; + // static constexpr quant_mode quant_mode = quant_mode::I4_ASYM; + static constexpr quant_mode quant_mode = quant_mode::I4_SYM; static constexpr size_t local_kslicing = 1; static constexpr size_t global_kslicing = 1; @@ -48,9 +49,9 @@ class test_col_major_1 { static constexpr mem_layout layout_b = mem_layout::col_major; static constexpr mma_engine mma_eng = mma_engine::fpu; static constexpr gpu_arch arch = gpu_arch::XeLpg; - using data_type_a = fp16; + using data_type_a = scalar_t; using data_type_b = int4x8; - using data_type_c = fp16; + using data_type_c = scalar_t; }; class test_col_major_2 { public: @@ -120,7 +121,7 @@ int gemm_result_validate( } template < - quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP, + quant_mode quant_mode = quant_mode::I4_SYM, typename data_type_acc_in = fp16, typename data_type_b, typename data_type_scale, @@ -134,7 +135,7 @@ std::vector convert_int4( int8_t zero_pt_i8 = zero_pt & 0xf; for (uint32_t i = 0; i < dequant_fp16.size(); i++) { int8_t dequant_8bit = data_b & 0xf; - if constexpr (quant_mode == quant_mode::S4_FULLRANGE_NO_ZP) { + if constexpr (quant_mode == quant_mode::I4_SYM) { dequant_fp16[i] = scale * (dequant_8bit - 8); } else { dequant_fp16[i] = scale * (dequant_8bit - zero_pt_i8); @@ -147,7 +148,7 @@ std::vector convert_int4( template < size_t dequant_s, mem_layout layout_b = mem_layout::col_major, - quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP, + quant_mode quant_mode = quant_mode::I4_SYM, typename data_type_acc_in = fp16, typename data_type_b, typename data_type_scale, @@ -173,11 +174,11 @@ std::vector dequantize_weight( (j / step) * (matrix_n / pack_radio) + i / pack_radio; int start_out = layout_b == mem_layout::row_major ? 0 : i * matrix_k + j * pack_radio; + data_type_zero_pt zp_value = zero_pt[start_zero_pt_in]; + zp_value = zp_value >> (4 * (i % pack_radio)); for (uint32_t jj = 0; jj < step; jj++) { std::vector dequant_fp16 = convert_int4( - b[start_b_in + jj], - scale[start_scale_in], - zero_pt[start_zero_pt_in] >> (4 * (i % pack_radio))); + b[start_b_in + jj], scale[start_scale_in], zp_value); for (uint32_t jjj = 0; jjj < dequant_fp16.size(); jjj++) { b_out[start_out + pack_radio * jj + jjj] = dequant_fp16[jjj]; } @@ -474,7 +475,7 @@ void dequantize_gemv_run(int iter) { // It accepts the base pointer to matrix D, and its dimensions {bias_d, bias_add_shape}}); typename gemm_op_t::template arguments_t gemm_arg; - if constexpr (compute_policy::quant_mode == quant_mode::S4_FULLRANGE_NO_ZP) { + if constexpr (compute_policy::quant_mode == quant_mode::I4_SYM) { gemm_arg = typename gemm_op_t::template arguments_t( matrix_m, @@ -491,7 +492,7 @@ void dequantize_gemv_run(int iter) { Acc_d, Cnt_d, epilogue_args); - } else if constexpr (compute_policy::quant_mode == quant_mode::S4_ASYM) { + } else if constexpr (compute_policy::quant_mode == quant_mode::I4_ASYM) { gemm_arg = typename gemm_op_t::template arguments_t( matrix_m, @@ -551,9 +552,11 @@ void dequantize_gemv_run(int iter) { // performance prof.print_profiling_result(profiling_selector::GPU); // check result - std::vector dequantize_b = - dequantize_weight( - matrix_k, matrix_n, B_h, scale_h, zero_pt_h); + std::vector dequantize_b = dequantize_weight< + dequant_s, + layout_b, + compute_policy::quant_mode, + data_type_c>(matrix_k, matrix_n, B_h, scale_h, zero_pt_h); queue.memcpy((void*)C_h, (void*)C_d, size_c * sizeof(data_type_c)).wait(); ASSERT_EQ( @@ -585,6 +588,12 @@ void dequantize_gemv_run(int iter) { free(Cnt_d, context); } +// Placeholder for void test param +template <> +void dequantize_gemv_run(int) { + GTEST_SKIP(); +} + template class dequantize_gemv_test : public ::testing::Test {}; TYPED_TEST_SUITE_P(dequantize_gemv_test); @@ -594,7 +603,11 @@ TYPED_TEST_P(dequantize_gemv_test, esimd) { } REGISTER_TYPED_TEST_SUITE_P(dequantize_gemv_test, esimd); -using tests = ::testing::Types; +using tests = ::testing::Types< // + test_col_major_1, + test_col_major_1, + // test_col_major_2, + void>; INSTANTIATE_TYPED_TEST_SUITE_P( dequantize_gemv_test_suite,