From 8940c0a520bcbad853a19c29166507d0483bb4ce Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Mon, 4 Dec 2023 14:04:44 -0800 Subject: [PATCH 01/31] only register q4gemm benchmarks if q4gemm is available --- onnxruntime/test/mlas/bench/bench_q4gemm.cpp | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/onnxruntime/test/mlas/bench/bench_q4gemm.cpp b/onnxruntime/test/mlas/bench/bench_q4gemm.cpp index 87e3601612761..57c6c93969840 100644 --- a/onnxruntime/test/mlas/bench/bench_q4gemm.cpp +++ b/onnxruntime/test/mlas/bench/bench_q4gemm.cpp @@ -112,9 +112,16 @@ static void GemmSizeProducts(benchmark::internal::Benchmark* b) { ArgsProduct(b, {{1, 1024, 2048}, {4096}, {4096}, {8}}); } -BENCHMARK_CAPTURE(Q4GEMM, Q4Sym, BlkQ4Sym)->Apply(GemmSizeProducts)->UseRealTime(); -BENCHMARK_CAPTURE(Q4GEMM, Q4Zp8, BlkQ4Zp8)->Apply(GemmSizeProducts)->UseRealTime(); -BENCHMARK_CAPTURE(Q4GEMM, Q4Sym128, BlkQ4Sym)->Apply(GemmSizeProducts)->UseRealTime(); -BENCHMARK_CAPTURE(Q8Q4GEMM, Q4Sym, BlkQ4Sym)->Apply(GemmSizeProducts)->UseRealTime(); -BENCHMARK_CAPTURE(Q8Q4GEMM, Q4Zp8, BlkQ4Zp8)->Apply(GemmSizeProducts)->UseRealTime(); -BENCHMARK_CAPTURE(Q8Q4GEMM, Q4Sym128, BlkQ4Zp8)->Apply(GemmSizeProducts)->UseRealTime(); +[[maybe_unused]] static const bool benchmarks_registered = []() { + const bool is_q4gemm_supported = MlasQ4GemmPackBSize(BlkQ4Sym, 1, 1) > 0; + if (is_q4gemm_supported) { + BENCHMARK_CAPTURE(Q4GEMM, Q4Sym, BlkQ4Sym)->Apply(GemmSizeProducts)->UseRealTime(); + BENCHMARK_CAPTURE(Q4GEMM, Q4Zp8, BlkQ4Zp8)->Apply(GemmSizeProducts)->UseRealTime(); + BENCHMARK_CAPTURE(Q4GEMM, Q4Sym128, BlkQ4Sym)->Apply(GemmSizeProducts)->UseRealTime(); + BENCHMARK_CAPTURE(Q8Q4GEMM, Q4Sym, BlkQ4Sym)->Apply(GemmSizeProducts)->UseRealTime(); + BENCHMARK_CAPTURE(Q8Q4GEMM, Q4Zp8, BlkQ4Zp8)->Apply(GemmSizeProducts)->UseRealTime(); + BENCHMARK_CAPTURE(Q8Q4GEMM, Q4Sym128, BlkQ4Zp8)->Apply(GemmSizeProducts)->UseRealTime(); + return true; + } + return false; +}(); From a6a8ce627ad42cd280e10355ff8f49c61d57e5f0 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Mon, 4 Dec 2023 18:25:26 -0800 Subject: [PATCH 02/31] some mlas cmake updates --- cmake/onnxruntime_mlas.cmake | 45 ++++++++++++++++++++++++++++-------- 1 file changed, 35 insertions(+), 10 deletions(-) diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 04efa5c2b4f6d..f455aa26024a6 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -1,7 +1,9 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. 
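A note on the registration idiom in PATCH 01 above: BENCHMARK and BENCHMARK_CAPTURE expand to static registration objects, which is also legal at function scope, so deferring them into a lambda that initializes a static keeps registration conditional without changing when it runs. A minimal sketch of the same pattern, with a hypothetical IsFeatureSupported() predicate standing in for the MlasQ4GemmPackBSize() probe:

    #include <benchmark/benchmark.h>

    static bool IsFeatureSupported() { return true; }  // placeholder predicate

    static void BM_Example(benchmark::State& state) {
      for (auto _ : state) {
        benchmark::DoNotOptimize(state.range(0));
      }
    }

    // Registration happens during static initialization either way; the lambda
    // just lets it be skipped when the feature probe fails.
    [[maybe_unused]] static const bool example_registered = []() {
      if (!IsFeatureSupported()) {
        return false;
      }
      BENCHMARK(BM_Example)->Arg(1024)->UseRealTime();
      return true;
    }();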
-set(MLAS_SRC_DIR ${ONNXRUNTIME_ROOT}/core/mlas/lib) +set(MLAS_ROOT ${ONNXRUNTIME_ROOT}/core/mlas) +set(MLAS_SRC_DIR ${MLAS_ROOT}/lib) +set(MLAS_INC_DIR ${MLAS_ROOT}/inc) # # All hardware agnostic source files here @@ -9,6 +11,7 @@ set(MLAS_SRC_DIR ${ONNXRUNTIME_ROOT}/core/mlas/lib) # multi-target build # onnxruntime_add_static_library(onnxruntime_mlas + ${MLAS_SRC_DIR}/mlasi.h ${MLAS_SRC_DIR}/platform.cpp ${MLAS_SRC_DIR}/threading.cpp ${MLAS_SRC_DIR}/sgemm.cpp @@ -33,9 +36,18 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/qpostprocessor.cpp ${MLAS_SRC_DIR}/qlgavgpool.cpp ${MLAS_SRC_DIR}/qdwconv_kernelsize.cpp + ${MLAS_SRC_DIR}/sqnbitgemm.h ${MLAS_SRC_DIR}/sqnbitgemm.cpp ) +target_sources(onnxruntime_mlas PRIVATE + ${MLAS_INC_DIR}/mlas_float16.h + ${MLAS_INC_DIR}/mlas_gemm_postprocessor.h + ${MLAS_INC_DIR}/mlas_q4.h + ${MLAS_INC_DIR}/mlas_qnbit.h + ${MLAS_INC_DIR}/mlas.h +) + if (NOT onnxruntime_ORT_MINIMAL_BUILD) target_sources(onnxruntime_mlas PRIVATE ${MLAS_SRC_DIR}/q4_dq.cpp @@ -134,10 +146,6 @@ function(setup_mlas_source_for_windows) target_sources(onnxruntime_mlas PRIVATE ${MLAS_SRC_DIR}/arm/sgemmc.cpp ) - # it should be removed after Visual Stuio is upgraded to 17.7 - if (MSVC) - add_compile_options("-d2SSAOptimizer-") - endif() elseif(onnxruntime_target_platform STREQUAL "x64") file(GLOB_RECURSE mlas_platform_srcs_avx CONFIGURE_DEPENDS @@ -290,8 +298,8 @@ else() if(APPLE) get_target_property(ONNXRUNTIME_MLAS_MACOSX_ARCH onnxruntime_mlas OSX_ARCHITECTURES) endif() - list(LENGTH ONNXRUNTIME_MLAS_MACOSX_ARCH ONNXRUNTIME_MLAS_MACOSX_ARCH_LENGH) - if(ONNXRUNTIME_MLAS_MACOSX_ARCH_LENGH GREATER 1) + list(LENGTH ONNXRUNTIME_MLAS_MACOSX_ARCH ONNXRUNTIME_MLAS_MACOSX_ARCH_LENGTH) + if(ONNXRUNTIME_MLAS_MACOSX_ARCH_LENGTH GREATER 1) set(ONNXRUNTIME_MLAS_MULTI_ARCH TRUE) endif() #If ONNXRUNTIME_MLAS_MULTI_ARCH is true, we need to go through every if branch below @@ -583,10 +591,12 @@ else() endif() foreach(mlas_target ${ONNXRUNTIME_MLAS_LIBS}) - target_include_directories(${mlas_target} PRIVATE ${ONNXRUNTIME_ROOT}/core/mlas/inc ${MLAS_SRC_DIR}) + target_include_directories(${mlas_target} PRIVATE ${MLAS_INC_DIR} ${MLAS_SRC_DIR}) onnxruntime_add_include_to_target(${mlas_target} ${GSL_TARGET}) + + set_target_properties(${mlas_target} PROPERTIES FOLDER "ONNXRuntime") endforeach() -set_target_properties(onnxruntime_mlas PROPERTIES FOLDER "ONNXRuntime") + if (WIN32) target_compile_options(onnxruntime_mlas PRIVATE "$<$:/wd6385>" "$<$:/wd4127>") if (onnxruntime_ENABLE_STATIC_ANALYSIS) @@ -602,6 +612,21 @@ if (NOT onnxruntime_BUILD_SHARED_LIB) FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) endif() +# set up source group for MLAS source files +block() + set(source_group_srcs) + foreach(mlas_target ${ONNXRUNTIME_MLAS_LIBS}) + get_target_property(mlas_target_srcs ${mlas_target} SOURCES) + foreach(mlas_target_src ${mlas_target_srcs}) + cmake_path(IS_PREFIX MLAS_ROOT ${mlas_target_src} in_mlas_root) + if(in_mlas_root) + list(APPEND source_group_srcs ${mlas_target_src}) + endif() + endforeach() + endforeach() + source_group(TREE ${MLAS_ROOT} FILES ${source_group_srcs}) +endblock() + if (NOT onnxruntime_ORT_MINIMAL_BUILD) @@ -613,7 +638,7 @@ if (NOT onnxruntime_ORT_MINIMAL_BUILD) onnxruntime_add_executable(onnxruntime_mlas_q4dq ${MLAS_SRC_DIR}/q4_dq_cli.cpp ) - target_include_directories(onnxruntime_mlas_q4dq PRIVATE ${ONNXRUNTIME_ROOT}/core/mlas/inc ${MLAS_SRC_DIR}) + target_include_directories(onnxruntime_mlas_q4dq PRIVATE ${MLAS_INC_DIR} ${MLAS_SRC_DIR}) 
set_target_properties(onnxruntime_mlas_q4dq PROPERTIES FOLDER "ONNXRuntimeTest") target_link_libraries(onnxruntime_mlas_q4dq PRIVATE ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common) From 53a46ca8d63d48191ba0bebf84871c16dd064079 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Mon, 11 Dec 2023 16:04:30 -0800 Subject: [PATCH 03/31] change BlkLen from template param to function param --- onnxruntime/core/mlas/lib/sqnbitgemm.cpp | 17 +-- onnxruntime/core/mlas/lib/sqnbitgemm.h | 26 ++-- .../core/mlas/lib/sqnbitgemm_kernel_neon.cpp | 120 ++++++++---------- 3 files changed, 73 insertions(+), 90 deletions(-) diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp index f964b1affec31..9a8694b0085ce 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp @@ -24,17 +24,10 @@ namespace int32_t GetDispatchQuantVariant(size_t BlkBitWidth, size_t BlkLen) { + MLAS_UNREFERENCED_PARAMETER(BlkLen); int32_t type = -1; - if (BlkBitWidth == 4 && BlkLen == 16) { - type = QuantVariant_BitWidth4_BlockSize16; - } else if (BlkBitWidth == 4 && BlkLen == 32) { - type = QuantVariant_BitWidth4_BlockSize32; - } else if (BlkBitWidth == 4 && BlkLen == 64) { - type = QuantVariant_BitWidth4_BlockSize64; - } else if (BlkBitWidth == 4 && BlkLen == 128) { - type = QuantVariant_BitWidth4_BlockSize128; - } else if (BlkBitWidth == 4 && BlkLen == 256) { - type = QuantVariant_BitWidth4_BlockSize256; + if (BlkBitWidth == 4) { + type = QuantVariant_BitWidth4; } return type; @@ -60,7 +53,7 @@ MlasSQNBitGemmBatch( if (ThreadPool == nullptr) { for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) { auto Data = &DataParams[gemm_i]; - Operation(K, Data, 0, M, 0, N); + Operation(BlkLen, K, Data, 0, M, 0, N); } return; } @@ -120,7 +113,7 @@ MlasSQNBitGemmBatch( const size_t RangeStartN = ThreadIdN * StrideN; const size_t RangeCountN = std::min(N - RangeStartN, (size_t)StrideN); - Operation(K, Data, RangeStartM, RangeCountM, RangeStartN, RangeCountN); + Operation(BlkLen, K, Data, RangeStartM, RangeCountM, RangeStartN, RangeCountN); }); } diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.h b/onnxruntime/core/mlas/lib/sqnbitgemm.h index f8f7dcd43699f..870359b60ea8e 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.h @@ -41,9 +41,9 @@ Module Name: * This kernel handles the special case where M, the number of rows of A and C, is 1. * * @tparam BlkBitWidth Bit width of each value in a block. - * @tparam BlkLen Number of values in a block. * @tparam KernelType Hardware-specific kernel type. * + * @param BlkLen Number of values in a block. * @param A Supplies the A matrix. * @param QuantBData Supplies the quantized B matrix block data. * @param QuantBScale Supplies the quantized B matrix block scale values. @@ -54,9 +54,10 @@ Module Name: * @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix. * @param Bias Bias vector of length N. */ -template +template MLAS_FORCEINLINE void MlasSQNBitGemmM1Kernel( + size_t BlkLen, const float* A, const uint8_t* QuantBData, const float* QuantBScale, @@ -75,9 +76,9 @@ MlasSQNBitGemmM1Kernel( * MlasSgemmCopyPackB. * * @tparam BlkBitWidth Bit width of each value in a block. - * @tparam BlkLen Number of values in a block. * @tparam KernelType Hardware-specific kernel type. * + * @param BlkLen Number of values in a block. * @param[out] FpData Supplies the output buffer for the dequantized B float data. 
* @param QuantBData Supplies the quantized B matrix block data. * @param QuantBScale Supplies the quantized B matrix block scale values. @@ -86,9 +87,10 @@ MlasSQNBitGemmM1Kernel( * @param CountK Number of rows of B. * @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix. */ -template +template MLAS_FORCEINLINE void MlasQNBitBlkDequantBForSgemm( + size_t BlkLen, float* FpData, const uint8_t* QuantBData, const float* QuantBScale, @@ -145,9 +147,10 @@ MlasAddBiasForGemm(const float* Bias, float* C, size_t CountM, size_t CountN, si } } -template +template MLAS_FORCEINLINE void MLASCALL MlasSQNBitGemmOperation( + const size_t BlkLen, const size_t K, const MLAS_SQNBIT_GEMM_DATA_PARAMS* const DataParams, const size_t RangeStartM, @@ -189,7 +192,8 @@ MlasSQNBitGemmOperation( float* c_blk = C + n; const float* bias = (Bias == nullptr) ? nullptr : Bias + n; - MlasSQNBitGemmM1Kernel( + MlasSQNBitGemmM1Kernel( + BlkLen, a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias ); @@ -226,7 +230,8 @@ MlasSQNBitGemmOperation( float* c_blk = C + n; const float* bias = (Bias == nullptr) ? nullptr : Bias + n; - MlasQNBitBlkDequantBForSgemm( + MlasQNBitBlkDequantBForSgemm( + BlkLen, dequant_b, b_col, b_col_scale, b_col_zp, CountN, K, k_blks ); @@ -262,6 +267,7 @@ MlasSQNBitGemmOperation( // typedef void(MLASCALL MLAS_SQNBIT_GEMM_OPERATION)( + size_t BlkLen, size_t K, const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, size_t RangeStartM, @@ -271,11 +277,7 @@ typedef void(MLASCALL MLAS_SQNBIT_GEMM_OPERATION)( ); enum QuantVariant { - QuantVariant_BitWidth4_BlockSize16, - QuantVariant_BitWidth4_BlockSize32, - QuantVariant_BitWidth4_BlockSize64, - QuantVariant_BitWidth4_BlockSize128, - QuantVariant_BitWidth4_BlockSize256, + QuantVariant_BitWidth4, QuantVariantCount, // Keep this element last and ensure that its value is the number of other QuantVariant values. // Its value is used as an array size. }; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp index 63afe57dd9137..08dbd8507bd0b 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp @@ -70,7 +70,7 @@ FoldAccumulators(float32x4_t a0, float32x4_t a1, float32x4_t a2, float32x4_t a3) template MLAS_FORCEINLINE void -LoadData(const float* src, size_t count, float32x4_t (& dst)[Capacity / 4]) +LoadData(const float* src, size_t count, float32x4_t (&dst)[Capacity / 4]) { static_assert(Capacity % 4 == 0, "Capacity must be divisible by 4."); @@ -101,9 +101,10 @@ LoadData(const float* src, size_t count, float32x4_t (& dst)[Capacity / 4]) } } -template +template MLAS_FORCEINLINE void ComputeDotProducts( + size_t BlkLen, const float* ARowPtr, const uint8_t* QuantBDataColPtr, const float* QuantBScaleColPtr, @@ -118,6 +119,9 @@ ComputeDotProducts( { static_assert(NCols == 1 || NCols == 4, "NCols must be 1 or 4"); + constexpr size_t SubBlkLen = 16; // number of block elements to process in a sub-block iteration + assert(BlkLen % SubBlkLen == 0); + const uint8x8_t LowMask = vdup_n_u8(0x0F); // Manual conversion to float takes place in two steps: @@ -162,8 +166,6 @@ ComputeDotProducts( }); } - constexpr size_t SubBlkLen = 16; // number of block elements to process in one iteration - for (size_t k_idx_in_blk = 0; k_idx_in_blk < k_blk_len; k_idx_in_blk += SubBlkLen) { // load A row vector elements @@ -268,9 +270,10 @@ ComputeDotProducts( // MlasSQNBitGemmKernel and helpers. 
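An aside on the SubBlkLen assertion introduced above: BlkLen is now a runtime value, so the divisibility that the old per-BlkLen templates guaranteed by construction has to be asserted. All supported block lengths (16 through 256) are multiples of 16, so the assert mostly documents the contract. A minimal sketch of the iteration pattern, with a hypothetical ProcessSubBlock() standing in for the vectorized inner loop:

    #include <cassert>
    #include <cstddef>

    constexpr std::size_t SubBlkLen = 16;  // fixed vector-friendly chunk size

    inline void ProcessSubBlock(const float* /*sub_blk*/) { /* vector math here */ }

    inline void ProcessBlock(std::size_t BlkLen, const float* blk) {
      assert(BlkLen % SubBlkLen == 0);  // holds for BlkLen in {16, 32, 64, 128, 256}
      for (std::size_t i = 0; i < BlkLen; i += SubBlkLen) {
        ProcessSubBlock(blk + i);
      }
    }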
// -template +template MLAS_FORCEINLINE void MlasSQNBitGemmM1KernelNeon( + size_t BlkLen, const float* A, const uint8_t* QuantBData, const float* QuantBScale, @@ -304,7 +307,8 @@ MlasSQNBitGemmM1KernelNeon( int64_t nblk = static_cast(CountN) - NCols; while (nblk >= 0) { - ComputeDotProducts( + ComputeDotProducts( + BlkLen, ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, BiasPtr @@ -327,7 +331,8 @@ MlasSQNBitGemmM1KernelNeon( // left over columns less than `NCols`? nblk += NCols; for (int64_t n = 0; n < nblk; ++n) { - ComputeDotProducts( + ComputeDotProducts( + BlkLen, ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, BiasPtr @@ -346,42 +351,36 @@ MlasSQNBitGemmM1KernelNeon( } } -#define SPECIALIZE_SQNBIT_GEMM_M1_KERNEL(BlkBitWidth, BlkLen) \ - template <> \ - MLAS_FORCEINLINE void \ - MlasSQNBitGemmM1Kernel( \ - const float* A, \ - const uint8_t* QuantBData, \ - const float* QuantBScale, \ - const uint8_t* QuantBZeroPoint, \ - float* C, \ - size_t CountN, \ - size_t CountK, \ - size_t BlockStrideQuantB, \ - const float* Bias \ - ) \ - { \ - return MlasSQNBitGemmM1KernelNeon( \ - A, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, CountK, \ - BlockStrideQuantB, Bias \ - ); \ - } - -SPECIALIZE_SQNBIT_GEMM_M1_KERNEL(4, 16) -SPECIALIZE_SQNBIT_GEMM_M1_KERNEL(4, 32) -SPECIALIZE_SQNBIT_GEMM_M1_KERNEL(4, 64) -SPECIALIZE_SQNBIT_GEMM_M1_KERNEL(4, 128) -SPECIALIZE_SQNBIT_GEMM_M1_KERNEL(4, 256) - -#undef SPECIALIZE_SQNBIT_GEMM_M1_KERNEL +template <> +MLAS_FORCEINLINE void +MlasSQNBitGemmM1Kernel<4, MLAS_SQNBIT_GEMM_KERNEL_NEON>( + size_t BlkLen, + const float* A, + const uint8_t* QuantBData, + const float* QuantBScale, + const uint8_t* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias +) +{ + return MlasSQNBitGemmM1KernelNeon<4>( + BlkLen, + A, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, CountK, + BlockStrideQuantB, Bias + ); +} // // MlasQNBitBlkDequantBForSgemm and helpers. 
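For readers skimming the diff, the essential shape of patch 03's refactor, reduced to hypothetical names rather than the real MLAS signatures: a template parameter is demoted to an ordinary function argument, so a single explicit specialization per bit width replaces the macro-stamped specializations per (bit width, block length) pair.

    #include <cstddef>

    // Before: five instantiations per kernel (BlkLen = 16/32/64/128/256),
    // each wired into the dispatch table separately.
    template <std::size_t BlkBitWidth, std::size_t BlkLen, typename KernelType>
    void GemmM1Kernel(const float* A /*, ... */);

    // After: BlkLen is passed at runtime; one specialization covers them all.
    template <std::size_t BlkBitWidth, typename KernelType>
    void GemmM1Kernel(std::size_t BlkLen, const float* A /*, ... */);

The trade-off, presumably, is giving up BlkLen as a compile-time constant (and any unrolling keyed to it) in exchange for fewer instantiations and a smaller dispatch surface.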
// -template +template MLAS_FORCEINLINE void MlasQNBitBlkDequantBForSgemmNeon( + size_t BlkLen, float* FpData, const uint8_t* QuantBData, const float* QuantBScale, @@ -448,31 +447,24 @@ MlasQNBitBlkDequantBForSgemmNeon( impl0_reference(); } -#define SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(BlkBitWidth, BlkLen) \ - template <> \ - MLAS_FORCEINLINE void \ - MlasQNBitBlkDequantBForSgemm( \ - float* FpData, \ - const uint8_t* QuantBData, \ - const float* QuantBScale, \ - const uint8_t* QuantBZeroPoint, \ - size_t CountN, \ - size_t CountK, \ - size_t BlockStrideQuantB \ - ) \ - { \ - MlasQNBitBlkDequantBForSgemmNeon( \ - FpData, QuantBData, QuantBScale, QuantBZeroPoint, CountN, CountK, BlockStrideQuantB \ - ); \ - } - -SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(4, 16) -SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(4, 32) -SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(4, 64) -SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(4, 128) -SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(4, 256) - -#undef SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM +template <> +MLAS_FORCEINLINE void +MlasQNBitBlkDequantBForSgemm<4, MLAS_SQNBIT_GEMM_KERNEL_NEON>( + size_t BlkLen, + float* FpData, + const uint8_t* QuantBData, + const float* QuantBScale, + const uint8_t* QuantBZeroPoint, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB +) +{ + MlasQNBitBlkDequantBForSgemmNeon<4>( + BlkLen, + FpData, QuantBData, QuantBScale, QuantBZeroPoint, CountN, CountK, BlockStrideQuantB + ); +} // // Kernel dispatch structure definition. @@ -480,10 +472,6 @@ SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(4, 256) const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon = []() { MLAS_SQNBIT_GEMM_DISPATCH d; - d.Operations[QuantVariant_BitWidth4_BlockSize16] = MlasSQNBitGemmOperation<4, 16, MLAS_SQNBIT_GEMM_KERNEL_NEON>; - d.Operations[QuantVariant_BitWidth4_BlockSize32] = MlasSQNBitGemmOperation<4, 32, MLAS_SQNBIT_GEMM_KERNEL_NEON>; - d.Operations[QuantVariant_BitWidth4_BlockSize64] = MlasSQNBitGemmOperation<4, 64, MLAS_SQNBIT_GEMM_KERNEL_NEON>; - d.Operations[QuantVariant_BitWidth4_BlockSize128] = MlasSQNBitGemmOperation<4, 128, MLAS_SQNBIT_GEMM_KERNEL_NEON>; - d.Operations[QuantVariant_BitWidth4_BlockSize256] = MlasSQNBitGemmOperation<4, 256, MLAS_SQNBIT_GEMM_KERNEL_NEON>; + d.Operations[QuantVariant_BitWidth4] = MlasSQNBitGemmOperation<4, MLAS_SQNBIT_GEMM_KERNEL_NEON>; return d; }(); From e2a9eee8a5dc37dd5a165139bf9bf0295bc1ac29 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Wed, 13 Dec 2023 18:54:00 -0800 Subject: [PATCH 04/31] Save work --- onnxruntime/core/mlas/inc/mlas_qnbit.h | 37 +- onnxruntime/core/mlas/lib/sqnbitgemm.cpp | 414 ++++++++++++++++-- onnxruntime/core/mlas/lib/sqnbitgemm.h | 362 ++++++--------- .../core/mlas/lib/sqnbitgemm_kernel_neon.cpp | 171 +++++--- .../test/mlas/bench/bench_sqnbitgemm.cpp | 1 + .../test/mlas/unittest/test_sqnbitgemm.cpp | 246 ++++++++--- 6 files changed, 864 insertions(+), 367 deletions(-) diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h index 9620dd42d1da9..8de6670203a65 100644 --- a/onnxruntime/core/mlas/inc/mlas_qnbit.h +++ b/onnxruntime/core/mlas/inc/mlas_qnbit.h @@ -23,6 +23,12 @@ Module Name: #include "mlas.h" #include "mlas_gemm_postprocessor.h" +// TODO add documentation +enum MLAS_SQNBITGEMM_COMPUTE_TYPE { + CompFp32, // fp32 A, fp32 accumulator + CompInt8, // int8 A, int32 accumulator +}; + /** * @brief Data parameters for float/n-bit quantized int GEMM routine. 
*/ @@ -32,11 +38,16 @@ struct MLAS_SQNBIT_GEMM_DATA_PARAMS { const void* QuantBData = nullptr; ///< address of quantized B (quantized n-bit int values) const float* QuantBScale = nullptr; ///< address of scale values of quantized B, one per block const void* QuantBZeroPoint = nullptr; ///< optional address of zero point values of quantized B, one per block - bool IsBPacked = false; ///< whether B values are packed in an optimized format for the computation const float* Bias = nullptr; ///< optional address of Bias, vector size N float* C = nullptr; ///< address of result matrix size_t ldc = 0; ///< leading dimension of C + /** + * Address of intermediate workspace buffer. + * Only required if MlasSQNBitGemmWorkspaceSize returns a non-zero value. + */ + void* Workspace = nullptr; + ///< optional post processing to apply to result matrix MLAS_GEMM_POSTPROCESSOR* PostProcessor = nullptr; }; @@ -54,6 +65,7 @@ struct MLAS_SQNBIT_GEMM_DATA_PARAMS { * @param[in] BlkLen number of quantized values per block * @param[inout] DataParams An array (size BatchN) of parameter blocks * @param[in] ThreadPool optional thread pool to use + * // TODO update param doc */ void MLASCALL MlasSQNBitGemmBatch( @@ -63,6 +75,7 @@ MlasSQNBitGemmBatch( size_t BatchN, size_t BlkBitWidth, size_t BlkLen, + MLAS_SQNBITGEMM_COMPUTE_TYPE ComputeType, const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, MLAS_THREADPOOL* ThreadPool = nullptr ); @@ -71,9 +84,29 @@ MlasSQNBitGemmBatch( * @brief Determines whether a float32/quantized n-bit int GEMM implementation is available on the current platform. * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) * @param[in] BlkLen number of quantized values per block + * TODO update param doc */ bool MLASCALL MlasIsSQNBitGemmAvailable( + size_t M, + size_t N, + size_t K, + size_t BlkBitWidth, + size_t BlkLen, + MLAS_SQNBITGEMM_COMPUTE_TYPE ComputeType +); + +/** + * @brief Gets the size in bytes of the intermediate workspace buffer required by the float32/quantized n-bit int GEMM + * implementation. If zero, no intermediate workspace is required. + * // TODO update param doc + */ +size_t MLASCALL +MlasSQNBitGemmWorkspaceSize( + size_t M, + size_t N, + size_t K, size_t BlkBitWidth, - size_t BlkLen + size_t BlkLen, + MLAS_SQNBITGEMM_COMPUTE_TYPE ComputeType ); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp index 9a8694b0085ce..a4d6aedc83ca6 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp @@ -11,28 +11,387 @@ Module Name: Abstract: This module implements the float/quantized n-bit integer matrix - multiplication hardware agnostic entrypoint, MlasSQNBitGemmBatch. + multiplication hardware agnostic entrypoint, MlasSQNBitGemmBatch, + as well as some SQNBitGemm-related query functions. --*/ #include "sqnbitgemm.h" +#include + +namespace +{ + +enum SQNBitGemmVariant { + SQNBitGemmVariantInvalid = -1, + + // Valid variants + + SQNBitGemmVariant_BitWidth4_CompFp32 = 0, + SQNBitGemmVariant_BitWidth4_CompInt8, + + // End of valid variants + + // Keep this element last and ensure that its value is the number of valid SQNBitGemmVariant values. + // Its value is used as an array size. 
+ SQNBitGemmVariantCount, +}; + +SQNBitGemmVariant +GetSQNBitGemmVariant( + size_t M, + size_t N, + size_t K, + size_t BlkBitWidth, + size_t BlkLen, + MLAS_SQNBITGEMM_COMPUTE_TYPE ComputeType +) +{ + MLAS_UNREFERENCED_PARAMETER(N); + MLAS_UNREFERENCED_PARAMETER(K); + + if (BlkBitWidth == 4 && + (BlkLen == 16 || BlkLen == 32 || BlkLen == 64 || BlkLen == 128 || BlkLen == 256)) { + if (ComputeType == CompFp32) { + return SQNBitGemmVariant_BitWidth4_CompFp32; + } else if (ComputeType == CompInt8 && M == 1) { + return SQNBitGemmVariant_BitWidth4_CompInt8; + } + } + + return SQNBitGemmVariantInvalid; +} + +} // namespace + +bool MLASCALL +MlasIsSQNBitGemmAvailable( + size_t M, + size_t N, + size_t K, + size_t BlkBitWidth, + size_t BlkLen, + MLAS_SQNBITGEMM_COMPUTE_TYPE ComputeType +) +{ + const auto* dispatch = GetMlasPlatform().SQNBitGemmDispatch; + if (dispatch == nullptr) { + return false; + } + + const auto variant = GetSQNBitGemmVariant(M, N, K, BlkBitWidth, BlkLen, ComputeType); + + switch (variant) { + case SQNBitGemmVariant_BitWidth4_CompFp32: { + return dispatch->SQNBitGemmM1Kernel_BlkBitWidth4_CompFp32 != nullptr && + dispatch->QNBitBlkDequantBForSgemm_BlkBitWidth4_CompFp32 != nullptr; + } + case SQNBitGemmVariant_BitWidth4_CompInt8: { + return dispatch->SQNBitGemmM1Kernel_BlkBitWidth4_CompInt8 != nullptr && + dispatch->QuantizeA_CompInt8 != nullptr; + } + default: { + return false; + } + } +} + +size_t MLASCALL +MlasSQNBitGemmWorkspaceSize( + size_t M, + size_t N, + size_t K, + size_t BlkBitWidth, + size_t BlkLen, + MLAS_SQNBITGEMM_COMPUTE_TYPE ComputeType +) +{ + const auto variant = GetSQNBitGemmVariant(M, N, K, BlkBitWidth, BlkLen, ComputeType); + + switch (variant) { + case SQNBitGemmVariant_BitWidth4_CompInt8: { + // workspace buffer is used for block quantization of A to int8 + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + const size_t RequiredBufferSize = M * BlockCountK * Q8BlkSize(BlkLen); + const size_t RequiredAlignment = Q8BlkAlignment(BlkLen); + return (RequiredBufferSize + RequiredAlignment - 1) / RequiredAlignment * RequiredAlignment; + } + default: { + return 0; + } + } +} + namespace { -// Get quantization variant based on `BlkBitWidth` and `BlkLen`. -// Return -1 if the input values are unsupported. 
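To make the computation in MlasSQNBitGemmWorkspaceSize above concrete, a worked example with illustrative values (M = 1, K = 4096, BlkLen = 32):

    #include <cstddef>

    constexpr std::size_t M = 1, K = 4096, BlkLen = 32;
    constexpr std::size_t BlockCountK = (K + BlkLen - 1) / BlkLen;                 // 128 blocks per row
    constexpr std::size_t Q8BlkSizeBytes = sizeof(float) + BlkLen;                 // 4-byte scale + 32 int8s = 36
    constexpr std::size_t RequiredBufferSize = M * BlockCountK * Q8BlkSizeBytes;   // 4608
    constexpr std::size_t Alignment = alignof(float);                              // 4
    constexpr std::size_t WorkspaceSize =
        (RequiredBufferSize + Alignment - 1) / Alignment * Alignment;              // 4608, already a multiple of 4
    static_assert(WorkspaceSize == 4608);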
-int32_t -GetDispatchQuantVariant(size_t BlkBitWidth, size_t BlkLen) +MLAS_FORCEINLINE void +AddBiasForGemm(const float* Bias, float* C, size_t CountM, size_t CountN, size_t ldc) +{ + for (size_t m = 0; m < CountM; m++) { + const float* bias = Bias; + float* sum = C; + for (size_t n = 0; n < CountN; n += 4) { + if (CountN - n < 4) { + for (size_t nn = n; nn < CountN; nn++) { + *sum += *bias; + sum++; + bias++; + } + break; + } + + MLAS_FLOAT32X4 acc_x = MlasLoadFloat32x4(sum); + acc_x = MlasAddFloat32x4(acc_x, MlasLoadFloat32x4(bias)); + MlasStoreFloat32x4(sum, acc_x); + bias += 4; + sum += 4; + } + C += ldc; + } +} + +typedef void(SQNBitGemmFn)( + const size_t BlkLen, + const size_t K, + const MLAS_SQNBIT_GEMM_DATA_PARAMS* const DataParams, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN +); + +void +SQNBitGemm_BlkBitWidth4_CompFp32( + const size_t BlkLen, + const size_t K, + const MLAS_SQNBIT_GEMM_DATA_PARAMS* const DataParams, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN +) +{ + constexpr size_t BlkBitWidth = 4; + + const size_t lda = DataParams->lda; + const size_t ldc = DataParams->ldc; + + const size_t k_blks = MlasDivRoundup(K, BlkLen); + const size_t ldb = k_blks * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t k_blks_zp_bytes = MlasQNBitZeroPointsForBlksSizeInBytes(k_blks); + + const float* A = DataParams->A + RangeStartM * lda; + + const uint8_t* QuantBData = static_cast(DataParams->QuantBData) + RangeStartN * ldb; + const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks; + const uint8_t* QuantBZeroPoint = + (DataParams->QuantBZeroPoint == nullptr) + ? nullptr + : static_cast(DataParams->QuantBZeroPoint) + RangeStartN * k_blks_zp_bytes; + + float* C = DataParams->C + RangeStartM * ldc + RangeStartN; + + const float* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias + RangeStartN; + + if (RangeCountM == 1) { + size_t CountN; + for (size_t n = 0; n < RangeCountN; n += CountN) { + CountN = std::min(RangeCountN - n, size_t{128}); + + const float* a_row = A; + const uint8_t* b_col = QuantBData + n * ldb; + const float* b_col_scale = QuantBScale + n * k_blks; + const uint8_t* b_col_zp = + (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; + float* c_blk = C + n; + const float* bias = (Bias == nullptr) ? nullptr : Bias + n; + + GetMlasPlatform().SQNBitGemmDispatch->SQNBitGemmM1Kernel_BlkBitWidth4_CompFp32( + BlkLen, + a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias + ); + + if (DataParams->PostProcessor != nullptr) { + DataParams->PostProcessor->Process( + DataParams->C, RangeStartM, RangeStartN + n, + RangeCountM, CountN, ldc + ); + } + } + return; + } + + constexpr size_t StrideN = 32; + size_t bufsize = k_blks * BlkLen * StrideN * sizeof(float); + MlasThreadedBufAlloc(bufsize); + auto* dequant_b = reinterpret_cast(ThreadedBufHolder.get()); + + // + // Step through each slice of matrix B along the N dimension. + // + size_t CountN; + for (size_t n = 0; n < RangeCountN; n += CountN) { + CountN = std::min(RangeCountN - n, StrideN); + + // + // Step through each slice of matrix A along the M dimension. + // + const float* a_row = A; + const uint8_t* b_col = QuantBData + n * ldb; + const float* b_col_scale = QuantBScale + n * k_blks; + const uint8_t* b_col_zp = + (QuantBZeroPoint == nullptr) ? 
nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; + float* c_blk = C + n; + const float* bias = (Bias == nullptr) ? nullptr : Bias + n; + + GetMlasPlatform().SQNBitGemmDispatch->QNBitBlkDequantBForSgemm_BlkBitWidth4_CompFp32( + BlkLen, + dequant_b, b_col, b_col_scale, b_col_zp, CountN, K, k_blks + ); + + size_t RowsRemaining = RangeCountM; + while (RowsRemaining > 0) { +#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) + auto RowsHandled = GetMlasPlatform().GemmFloatKernel( + a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f, true + ); +#else + auto RowsHandled = MlasSgemmKernelZero(a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f); +#endif + + if (bias) { + AddBiasForGemm(bias, c_blk, RowsHandled, CountN, ldc); + } + if (DataParams->PostProcessor != nullptr) { + DataParams->PostProcessor->Process( + DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN, + RowsHandled, CountN, ldc + ); + } + + c_blk += ldc * RowsHandled; + a_row += lda * RowsHandled; + RowsRemaining -= RowsHandled; + } + } +} + +void +SQNBitGemm_BlkBitWidth4_CompInt8( + const size_t BlkLen, + const size_t K, + const MLAS_SQNBIT_GEMM_DATA_PARAMS* const DataParams, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN +) { - MLAS_UNREFERENCED_PARAMETER(BlkLen); - int32_t type = -1; - if (BlkBitWidth == 4) { - type = QuantVariant_BitWidth4; + constexpr size_t BlkBitWidth = 4; + + const size_t k_blks = MlasDivRoundup(K, BlkLen); + + const size_t lda = k_blks * Q8BlkSize(BlkLen); + const size_t ldc = DataParams->ldc; + const size_t ldb = k_blks * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t k_blks_zp_bytes = MlasQNBitZeroPointsForBlksSizeInBytes(k_blks); + + const std::byte* QuantA = static_cast(DataParams->Workspace) + RangeStartM * lda; + + const uint8_t* QuantBData = static_cast(DataParams->QuantBData) + RangeStartN * ldb; + const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks; + const uint8_t* QuantBZeroPoint = + (DataParams->QuantBZeroPoint == nullptr) + ? nullptr + : static_cast(DataParams->QuantBZeroPoint) + RangeStartN * k_blks_zp_bytes; + + float* C = DataParams->C + RangeStartM * ldc + RangeStartN; + + const float* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias + RangeStartN; + + if (RangeCountM == 1) { + size_t CountN; + for (size_t n = 0; n < RangeCountN; n += CountN) { + CountN = std::min(RangeCountN - n, size_t{128}); + + const std::byte* a_row = QuantA; + const uint8_t* b_col = QuantBData + n * ldb; + const float* b_col_scale = QuantBScale + n * k_blks; + const uint8_t* b_col_zp = + (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; + float* c_blk = C + n; + const float* bias = (Bias == nullptr) ? 
nullptr : Bias + n; + + GetMlasPlatform().SQNBitGemmDispatch->SQNBitGemmM1Kernel_BlkBitWidth4_CompInt8( + BlkLen, + a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias + ); + + if (DataParams->PostProcessor != nullptr) { + DataParams->PostProcessor->Process( + DataParams->C, RangeStartM, RangeStartN + n, + RangeCountM, CountN, ldc + ); + } + } + return; } - return type; + assert(false && "not implemented for M > 1"); +} + +typedef void(InitializeWorkspaceFn)( + size_t M, + size_t N, + size_t K, + size_t BatchN, + size_t BlkLen, + const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, + MLAS_THREADPOOL* ThreadPool +); + +void +InitializeWorkspace_CompInt8( + size_t M, + size_t N, + size_t K, + size_t BatchN, + size_t BlkLen, + const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, + MLAS_THREADPOOL* ThreadPool +) +{ + MLAS_UNREFERENCED_PARAMETER(N); + + MLAS_UNREFERENCED_PARAMETER(ThreadPool); + + const auto QuantizeA = GetMlasPlatform().SQNBitGemmDispatch->QuantizeA_CompInt8; + + // TODO use threading + for (size_t gemm_idx = 0; gemm_idx < BatchN; ++gemm_idx) { + auto& data = DataParams[gemm_idx]; + + QuantizeA(BlkLen, data.A, M, K, K, static_cast(data.Workspace)); + } } +struct Operations { + InitializeWorkspaceFn* InitializeWorkspace = nullptr; + SQNBitGemmFn* SQNBitGemm = nullptr; +}; + +constexpr auto OperationMap = []() { + std::array ops; + + ops[SQNBitGemmVariant_BitWidth4_CompFp32].SQNBitGemm = SQNBitGemm_BlkBitWidth4_CompFp32; + + ops[SQNBitGemmVariant_BitWidth4_CompInt8].InitializeWorkspace = InitializeWorkspace_CompInt8; + ops[SQNBitGemmVariant_BitWidth4_CompInt8].SQNBitGemm = SQNBitGemm_BlkBitWidth4_CompInt8; + + return ops; +}(); + } // namespace void MLASCALL @@ -43,17 +402,25 @@ MlasSQNBitGemmBatch( const size_t BatchN, const size_t BlkBitWidth, const size_t BlkLen, + MLAS_SQNBITGEMM_COMPUTE_TYPE ComputeType, const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, MLAS_THREADPOOL* ThreadPool ) { - const int32_t QuantVariant = GetDispatchQuantVariant(BlkBitWidth, BlkLen); - MLAS_SQNBIT_GEMM_OPERATION* const Operation = GetMlasPlatform().SQNBitGemmDispatch->Operations[QuantVariant]; + const auto Variant = GetSQNBitGemmVariant(M, N, K, BlkBitWidth, BlkLen, ComputeType); + assert(Variant != SQNBitGemmVariantInvalid); + + if (const auto InitializeWorkspaceOperation = OperationMap[Variant].InitializeWorkspace; + InitializeWorkspaceOperation != nullptr) { + InitializeWorkspaceOperation(M, N, K, BatchN, BlkLen, DataParams, ThreadPool); + } + + const auto ComputeOperation = OperationMap[Variant].SQNBitGemm; if (ThreadPool == nullptr) { for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) { auto Data = &DataParams[gemm_i]; - Operation(BlkLen, K, Data, 0, M, 0, N); + ComputeOperation(BlkLen, K, Data, 0, M, 0, N); } return; } @@ -113,25 +480,6 @@ MlasSQNBitGemmBatch( const size_t RangeStartN = ThreadIdN * StrideN; const size_t RangeCountN = std::min(N - RangeStartN, (size_t)StrideN); - Operation(BlkLen, K, Data, RangeStartM, RangeCountM, RangeStartN, RangeCountN); + ComputeOperation(BlkLen, K, Data, RangeStartM, RangeCountM, RangeStartN, RangeCountN); }); } - -bool MLASCALL -MlasIsSQNBitGemmAvailable( - size_t BlkBitWidth, - size_t BlkLen -) -{ - const int32_t QuantVariant = GetDispatchQuantVariant(BlkBitWidth, BlkLen); - if (QuantVariant == -1) { - return false; - } - - if (GetMlasPlatform().SQNBitGemmDispatch == nullptr || - GetMlasPlatform().SQNBitGemmDispatch->Operations[QuantVariant] == nullptr) { - return false; - } - - return true; -} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.h 
b/onnxruntime/core/mlas/lib/sqnbitgemm.h index 870359b60ea8e..b23955183f34c 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.h @@ -10,7 +10,7 @@ Module Name: Abstract: - This module includes: + This module includes: // TODO update - Declaration of the set of template functions used to implement a kernel for a matrix/matrix multiplication, A*B, where A is a float matrix and B is @@ -31,79 +31,6 @@ Module Name: #include "mlas_qnbit.h" #include "mlasi.h" -// -// Kernel implementation template declarations -// - -/** - * @brief Multiply float matrix A with quantized n-bit integer matrix B. - * B is block quantized and column major. - * This kernel handles the special case where M, the number of rows of A and C, is 1. - * - * @tparam BlkBitWidth Bit width of each value in a block. - * @tparam KernelType Hardware-specific kernel type. - * - * @param BlkLen Number of values in a block. - * @param A Supplies the A matrix. - * @param QuantBData Supplies the quantized B matrix block data. - * @param QuantBScale Supplies the quantized B matrix block scale values. - * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. - * @param[out] C Supplies the output C matrix. - * @param CountN Number of columns of B and C. - * @param CountK Number of columns of A and rows of B. - * @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix. - * @param Bias Bias vector of length N. - */ -template -MLAS_FORCEINLINE void -MlasSQNBitGemmM1Kernel( - size_t BlkLen, - const float* A, - const uint8_t* QuantBData, - const float* QuantBScale, - const uint8_t* QuantBZeroPoint, - float* C, - size_t CountN, - size_t CountK, - size_t BlockStrideQuantB, - const float* Bias -); - -/** - * @brief Dequantize B into the format expected by the Sgemm kernel. - * B is block quantized and column major. - * This is equivalent to dequantizing B and then running - * MlasSgemmCopyPackB. - * - * @tparam BlkBitWidth Bit width of each value in a block. - * @tparam KernelType Hardware-specific kernel type. - * - * @param BlkLen Number of values in a block. - * @param[out] FpData Supplies the output buffer for the dequantized B float data. - * @param QuantBData Supplies the quantized B matrix block data. - * @param QuantBScale Supplies the quantized B matrix block scale values. - * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. - * @param CountN Number of columns of B. - * @param CountK Number of rows of B. - * @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix. - */ -template -MLAS_FORCEINLINE void -MlasQNBitBlkDequantBForSgemm( - size_t BlkLen, - float* FpData, - const uint8_t* QuantBData, - const float* QuantBScale, - const uint8_t* QuantBZeroPoint, - size_t CountN, - size_t CountK, - size_t BlockStrideQuantB -); - -// -// MlasQNBitGemmOperation and helpers -// - constexpr MLAS_FORCEINLINE size_t MlasQNBitBlkDataSizeInBytes(size_t BlkBitWidth, size_t BlkLen) { @@ -121,169 +48,152 @@ MlasQNBitZeroPointsForBlksSizeInBytes(size_t BlkCount) } } -MLAS_FORCEINLINE void -MlasAddBiasForGemm(const float* Bias, float* C, size_t CountM, size_t CountN, size_t ldc) +// +// Quantized int8 block helpers. 
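The Q8 block layout implied by the helpers defined just below: each quantized-A block is a 4-byte float scale followed immediately by BlkLen int8 values, with blocks laid out contiguously. Illustrative usage of the accessors (WriteQ8Blk is not part of MLAS; the caller is assumed to supply Q8BlkSize(BlkLen) bytes of float-aligned storage):

    // byte offset:  0        4                  4 + BlkLen
    //               +--------+------------------+
    //               | scale  |  BlkLen x int8   |
    //               +--------+------------------+
    inline void WriteQ8Blk(std::byte* blk, float scale, const int8_t* values, size_t blk_len) {
      Q8BlkScale(blk) = scale;        // store the per-block scale first
      int8_t* data = Q8BlkData(blk);  // quantized values start right after the scale
      for (size_t i = 0; i < blk_len; ++i) {
        data[i] = values[i];
      }
    }

The real writers of this format are the QuantizeA_CompInt8 implementations later in this patch.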
+// + +MLAS_FORCEINLINE +const float& +Q8BlkScale(const std::byte* BlkPtr) { - for (size_t m = 0; m < CountM; m++) { - const float* bias = Bias; - float* sum = C; - for (size_t n = 0; n < CountN; n += 4) { - if (CountN - n < 4) { - for (size_t nn = n; nn < CountN; nn++) { - *sum += *bias; - sum++; - bias++; - } - break; - } - - MLAS_FLOAT32X4 acc_x = MlasLoadFloat32x4(sum); - acc_x = MlasAddFloat32x4(acc_x, MlasLoadFloat32x4(bias)); - MlasStoreFloat32x4(sum, acc_x); - bias += 4; - sum += 4; - } - C += ldc; - } + return *reinterpret_cast(BlkPtr); } -template -MLAS_FORCEINLINE void MLASCALL -MlasSQNBitGemmOperation( - const size_t BlkLen, - const size_t K, - const MLAS_SQNBIT_GEMM_DATA_PARAMS* const DataParams, - const size_t RangeStartM, - const size_t RangeCountM, - const size_t RangeStartN, - const size_t RangeCountN -) +MLAS_FORCEINLINE +float& +Q8BlkScale(std::byte* BlkPtr) { - const size_t lda = DataParams->lda; - const size_t ldc = DataParams->ldc; - - const size_t k_blks = MlasDivRoundup(K, BlkLen); - const size_t ldb = k_blks * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - const size_t k_blks_zp_bytes = MlasQNBitZeroPointsForBlksSizeInBytes(k_blks); - - const float* A = DataParams->A + RangeStartM * lda; - - const uint8_t* QuantBData = static_cast(DataParams->QuantBData) + RangeStartN * ldb; - const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks; - const uint8_t* QuantBZeroPoint = - (DataParams->QuantBZeroPoint == nullptr) - ? nullptr - : static_cast(DataParams->QuantBZeroPoint) + RangeStartN * k_blks_zp_bytes; - - float* C = DataParams->C + RangeStartM * ldc + RangeStartN; - - const float* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias + RangeStartN; - - if (RangeCountM == 1) { - size_t CountN; - for (size_t n = 0; n < RangeCountN; n += CountN) { - CountN = std::min(RangeCountN - n, size_t{128}); - - const float* a_row = A; - const uint8_t* b_col = QuantBData + n * ldb; - const float* b_col_scale = QuantBScale + n * k_blks; - const uint8_t* b_col_zp = - (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; - float* c_blk = C + n; - const float* bias = (Bias == nullptr) ? nullptr : Bias + n; - - MlasSQNBitGemmM1Kernel( - BlkLen, - a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias - ); - - if (DataParams->PostProcessor != nullptr) { - DataParams->PostProcessor->Process( - DataParams->C, RangeStartM, RangeStartN + n, - RangeCountM, CountN, ldc - ); - } - } - return; - } + return *reinterpret_cast(BlkPtr); +} - constexpr size_t StrideN = 32; - size_t bufsize = k_blks * BlkLen * StrideN * sizeof(float); - MlasThreadedBufAlloc(bufsize); - auto* dequant_b = reinterpret_cast(ThreadedBufHolder.get()); - // - // Step through each slice of matrix B along the N dimension. - // +MLAS_FORCEINLINE +const int8_t* +Q8BlkData(const std::byte* BlkPtr) +{ + return reinterpret_cast(BlkPtr + sizeof(float)); +} - size_t CountN; - for (size_t n = 0; n < RangeCountN; n += CountN) { - CountN = std::min(RangeCountN - n, StrideN); - - // - // Step through each slice of matrix A along the M dimension. - // - const float* a_row = A; - const uint8_t* b_col = QuantBData + n * ldb; - const float* b_col_scale = QuantBScale + n * k_blks; - const uint8_t* b_col_zp = - (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; - float* c_blk = C + n; - const float* bias = (Bias == nullptr) ? 
nullptr : Bias + n; - - MlasQNBitBlkDequantBForSgemm( - BlkLen, - dequant_b, b_col, b_col_scale, b_col_zp, CountN, K, k_blks - ); - - size_t RowsRemaining = RangeCountM; - while (RowsRemaining > 0) { -#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) - auto RowsHandled = GetMlasPlatform().GemmFloatKernel( - a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f, true - ); -#else - auto RowsHandled = MlasSgemmKernelZero(a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f); -#endif - - if (bias) { - MlasAddBiasForGemm(bias, c_blk, RowsHandled, CountN, ldc); - } - if (DataParams->PostProcessor != nullptr) { - DataParams->PostProcessor->Process( - DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN, - RowsHandled, CountN, ldc - ); - } - - c_blk += ldc * RowsHandled; - a_row += lda * RowsHandled; - RowsRemaining -= RowsHandled; - } - } +MLAS_FORCEINLINE +int8_t* +Q8BlkData(std::byte* BlkPtr) +{ + return reinterpret_cast(BlkPtr + sizeof(float)); +} + +MLAS_FORCEINLINE +constexpr size_t +Q8BlkSize(size_t BlkLen) +{ + const size_t BlkSize = sizeof(float) + BlkLen * sizeof(int8_t); + // Currently, the strictest alignment requirement of a block is for a float. + // Ensure contiguous blocks are suitably aligned. + // assert(BlkSize % alignof(float) == 0); // TODO needs include, put it in .cpp? + return BlkSize; +} + +MLAS_FORCEINLINE +constexpr size_t +Q8BlkAlignment(size_t BlkLen) +{ + MLAS_UNREFERENCED_PARAMETER(BlkLen); + return alignof(float); } // // Kernel dispatch structure. // -typedef void(MLASCALL MLAS_SQNBIT_GEMM_OPERATION)( - size_t BlkLen, - size_t K, - const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, - size_t RangeStartM, - size_t RangeCountM, - size_t RangeStartN, - size_t RangeCountN -); - -enum QuantVariant { - QuantVariant_BitWidth4, - QuantVariantCount, // Keep this element last and ensure that its value is the number of other QuantVariant values. - // Its value is used as an array size. -}; - struct MLAS_SQNBIT_GEMM_DISPATCH { - MLAS_SQNBIT_GEMM_OPERATION* Operations[QuantVariantCount] = { - // Initialized to nullptrs. Overwrite in hardware-specific kernel implementation. - }; + // + // CompFp32 kernels + // + + /** + * @brief Multiply float matrix A with quantized n-bit integer matrix B. + * B is block quantized and column major. + * This kernel handles the special case where M, the number of rows of A and C, is 1. + * + * @param BlkLen Number of values in a block. + * @param A Supplies the A matrix. + * @param QuantBData Supplies the quantized B matrix block data. + * @param QuantBScale Supplies the quantized B matrix block scale values. + * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. + * @param[out] C Supplies the output C matrix. + * @param CountN Number of columns of B and C. + * @param CountK Number of columns of A and rows of B. + * @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix. + * @param Bias Bias vector of length N. + */ + typedef void(SQNBitGemmM1Kernel_BlkBitWidth4_CompFp32_Fn)( + size_t BlkLen, + const float* A, + const uint8_t* QuantBData, + const float* QuantBScale, + const uint8_t* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias + ); + + SQNBitGemmM1Kernel_BlkBitWidth4_CompFp32_Fn* SQNBitGemmM1Kernel_BlkBitWidth4_CompFp32 = nullptr; + + /** + * @brief Dequantize B into the format expected by the Sgemm kernel. + * B is block quantized and column major. 
+ * This is equivalent to dequantizing B and then running + * MlasSgemmCopyPackB. + * + * @param BlkLen Number of values in a block. + * @param[out] FpData Supplies the output buffer for the dequantized B float data. + * @param QuantBData Supplies the quantized B matrix block data. + * @param QuantBScale Supplies the quantized B matrix block scale values. + * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. + * @param CountN Number of columns of B. + * @param CountK Number of rows of B. + * @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix. + */ + typedef void(QNBitBlkDequantBForSgemm_BlkBitWidth4_CompFp32_Fn)( + size_t BlkLen, + float* FpData, + const uint8_t* QuantBData, + const float* QuantBScale, + const uint8_t* QuantBZeroPoint, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB + ); + + QNBitBlkDequantBForSgemm_BlkBitWidth4_CompFp32_Fn* QNBitBlkDequantBForSgemm_BlkBitWidth4_CompFp32 = nullptr; + + // + // CompInt8 kernels + // + + typedef void(SQNBitGemmM1Kernel_BlkBitWidth4_CompInt8_Fn)( + size_t BlkLen, + const std::byte* QuantA, + const uint8_t* QuantBData, + const float* QuantBScale, + const uint8_t* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias + ); + + SQNBitGemmM1Kernel_BlkBitWidth4_CompInt8_Fn* SQNBitGemmM1Kernel_BlkBitWidth4_CompInt8 = nullptr; + + typedef void(QuantizeA_CompInt8_Fn)( + size_t BlkLen, + const float* A, + size_t CountM, + size_t CountK, + size_t lda, + std::byte* QuantA + ); + + QuantizeA_CompInt8_Fn* QuantizeA_CompInt8 = nullptr; }; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp index 08dbd8507bd0b..38342febbf8ff 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp @@ -15,19 +15,13 @@ Module Name: --*/ -#include "sqnbitgemm.h" - #include #include #include #include -// -// Hardware-specific kernel type. -// -struct MLAS_SQNBIT_GEMM_KERNEL_NEON { -}; +#include "sqnbitgemm.h" namespace { @@ -264,12 +258,6 @@ ComputeDotProducts( } } -} // namespace - -// -// MlasSQNBitGemmKernel and helpers. -// - template MLAS_FORCEINLINE void MlasSQNBitGemmM1KernelNeon( @@ -351,32 +339,6 @@ MlasSQNBitGemmM1KernelNeon( } } -template <> -MLAS_FORCEINLINE void -MlasSQNBitGemmM1Kernel<4, MLAS_SQNBIT_GEMM_KERNEL_NEON>( - size_t BlkLen, - const float* A, - const uint8_t* QuantBData, - const float* QuantBScale, - const uint8_t* QuantBZeroPoint, - float* C, - size_t CountN, - size_t CountK, - size_t BlockStrideQuantB, - const float* Bias -) -{ - return MlasSQNBitGemmM1KernelNeon<4>( - BlkLen, - A, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, CountK, - BlockStrideQuantB, Bias - ); -} - -// -// MlasQNBitBlkDequantBForSgemm and helpers. 
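Before the CompInt8 reference implementations that follow, it is worth spelling out the 4-bit packing convention they assume: two quantized values per byte, the even-indexed value in the low nibble, the odd-indexed value in the high nibble, and an implicit zero point of 8 when no zero-point data is stored. An illustrative helper (not part of MLAS) that mirrors the unpacking arithmetic used below:

    #include <cstddef>
    #include <cstdint>

    inline int8_t UnpackQ4(const uint8_t* packed, std::size_t i, uint8_t zero_point = 8) {
      const uint8_t byte = packed[i / 2];
      const uint8_t nibble = (i & 1) ? (byte >> 4) : (byte & 0x0F);
      return static_cast<int8_t>(nibble) - static_cast<int8_t>(zero_point);
    }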
-// - template MLAS_FORCEINLINE void MlasQNBitBlkDequantBForSgemmNeon( @@ -447,31 +409,142 @@ MlasQNBitBlkDequantBForSgemmNeon( impl0_reference(); } -template <> -MLAS_FORCEINLINE void -MlasQNBitBlkDequantBForSgemm<4, MLAS_SQNBIT_GEMM_KERNEL_NEON>( +// +// CompInt8 kernel implementation and related helpers +// + +void MLASCALL +QuantizeA_CompInt8( size_t BlkLen, - float* FpData, + const float* A, + size_t CountM, + size_t CountK, + size_t lda, + std::byte* QuantA +) +{ + auto impl0_reference = [&]() { + const size_t BlockCountK = MlasDivRoundup(CountK, BlkLen); + + const size_t QuantAStride = BlockCountK * Q8BlkSize(BlkLen); + + for (size_t m = 0; m < CountM; ++m) { + const float* ADataRowPtr = A + m * lda; + std::byte* QuantARowPtr = QuantA + m * QuantAStride; + + for (size_t k = 0, k_blk = 0; k < CountK; k += BlkLen, ++k_blk) { + const size_t k_blk_len = std::min(CountK - k, BlkLen); + + const float* ADataBlkPtr = ADataRowPtr + k; + + // scan block values first to determine scale + + float amax = 0.0f; // max of absolute values of A block + + for (size_t kk = 0; kk < k_blk_len; ++kk) { + float a = ADataBlkPtr[kk]; + amax = std::max(amax, fabsf(a)); + } + + constexpr float range_max = (1 << 7) - 1; + const float scale = amax / range_max; + const float scale_reciprocal = scale != 0.0f ? 1.0f / scale : 0.0f; + + std::byte* QuantABlkPtr = QuantARowPtr + k_blk * Q8BlkSize(BlkLen); + + Q8BlkScale(QuantABlkPtr) = scale; + int8_t* QuantABlkData = Q8BlkData(QuantABlkPtr); + + for (size_t kk = 0; kk < k_blk_len; ++kk) { + const float q = ADataBlkPtr[kk] * scale_reciprocal; + QuantABlkData[kk] = static_cast( + std::clamp( + q, + static_cast(std::numeric_limits::min()), + static_cast(std::numeric_limits::max()) + ) + ); + } + } + } + }; + + // TODO neon impl + + impl0_reference(); +} + +MLAS_FORCEINLINE +void +SQNBitGemmM1Kernel_BlkBitWidth4_CompInt8( + size_t BlkLen, + const std::byte* QuantA, const uint8_t* QuantBData, const float* QuantBScale, const uint8_t* QuantBZeroPoint, + float* C, size_t CountN, size_t CountK, - size_t BlockStrideQuantB + size_t BlockStrideQuantB, + const float* Bias ) { - MlasQNBitBlkDequantBForSgemmNeon<4>( - BlkLen, - FpData, QuantBData, QuantBScale, QuantBZeroPoint, CountN, CountK, BlockStrideQuantB - ); + auto impl0_reference = [&]() { + const std::byte* QuantARowPtr = QuantA; + + for (size_t n = 0; n < CountN; ++n) { + float sum = Bias != nullptr ? Bias[n] : 0.0f; + + for (size_t k = 0, k_blk = 0; k < CountK; k += BlkLen, ++k_blk) { + const size_t k_blk_len = std::min(CountK - k, BlkLen); + + const std::byte* QuantABlkPtr = QuantARowPtr + k_blk * Q8BlkSize(BlkLen); + + const float a_scale = Q8BlkScale(QuantABlkPtr); + + const float b_scale = QuantBScale[n * BlockStrideQuantB + k_blk]; + + int8_t b_zp = 8; + if (QuantBZeroPoint != nullptr) { + const uint8_t b_zp_byte = QuantBZeroPoint[n * ((BlockStrideQuantB + 1) / 2) + k_blk / 2]; + b_zp = (k_blk & 1) ? static_cast(b_zp_byte >> 4) : static_cast(b_zp_byte & 0x0F); + } + + int32_t qsum = 0; + + const int8_t* QuantABlkData = Q8BlkData(QuantABlkPtr); + for (size_t kk = 0; kk < k_blk_len; ++kk) { + const int8_t qa = QuantABlkData[kk]; + const uint8_t qb_byte = QuantBData[(n * BlockStrideQuantB * BlkLen + k + kk) / 2]; + const int8_t qb = ((kk & 1) == 1 ? 
static_cast(qb_byte >> 4) : static_cast(qb_byte & 0x0F)) - b_zp; + qsum += qa * qb; + } + + sum += static_cast(qsum) * a_scale * b_scale; + } + + C[n] = sum; + } + }; + + // TODO neon impl + + impl0_reference(); } +} // namespace + // // Kernel dispatch structure definition. // const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon = []() { MLAS_SQNBIT_GEMM_DISPATCH d; - d.Operations[QuantVariant_BitWidth4] = MlasSQNBitGemmOperation<4, MLAS_SQNBIT_GEMM_KERNEL_NEON>; + + d.SQNBitGemmM1Kernel_BlkBitWidth4_CompFp32 = MlasSQNBitGemmM1KernelNeon<4>; + d.QNBitBlkDequantBForSgemm_BlkBitWidth4_CompFp32 = MlasQNBitBlkDequantBForSgemmNeon<4>; + d.SQNBitGemmM1Kernel_BlkBitWidth4_CompInt8 = SQNBitGemmM1Kernel_BlkBitWidth4_CompInt8; + d.QuantizeA_CompInt8 = QuantizeA_CompInt8; + return d; }(); diff --git a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp index 2f2635dab0512..0ca48dcea8d51 100644 --- a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp @@ -74,6 +74,7 @@ static void GemmSizeProducts(benchmark::internal::Benchmark* b) { ArgsProduct(b, {{1, 1024, 2048}, {4096, 11008}, {4096, 11008}, {8}}); } +// TODO disable if unavailable BENCHMARK(SQNBITGEMM<4, 16, false>)->Apply(GemmSizeProducts)->UseRealTime(); BENCHMARK(SQNBITGEMM<4, 16, true>)->Apply(GemmSizeProducts)->UseRealTime(); BENCHMARK(SQNBITGEMM<4, 32, false>)->Apply(GemmSizeProducts)->UseRealTime(); diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp index 6c97d60301573..c442d1441b35d 100644 --- a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp @@ -18,6 +18,17 @@ Module Name: #include "mlas_q4.h" #include "mlas_qnbit.h" +static constexpr const char* ComputeTypeName(MLAS_SQNBITGEMM_COMPUTE_TYPE ComputeType) { + switch (ComputeType) { + case CompFp32: + return "Fp32"; + case CompInt8: + return "Int8"; + default: + return "unknown"; + } +} + /** * @brief Test class for n-bit int block quantized GEMM * Note: only 2-D matmul supported for now @@ -26,12 +37,15 @@ template class MlasSQNBitGemmTest : public MlasTestBase { private: MatrixGuardBuffer BufferA; + MatrixGuardBuffer BufferQuantAData; + MatrixGuardBuffer BufferQuantAScale; MatrixGuardBuffer BufferB; MatrixGuardBuffer BufferQuantBData; MatrixGuardBuffer BufferQuantBZeroPoint; MatrixGuardBuffer BufferQuantBScale; MatrixGuardBuffer BufferDequantizedB; MatrixGuardBuffer BufferBias; + MatrixGuardBuffer BufferWorkspace; MatrixGuardBuffer BufferC; MatrixGuardBuffer BufferCReference; @@ -46,6 +60,8 @@ class MlasSQNBitGemmTest : public MlasTestBase { const float* Bias, float* C, size_t ldc, + void* Workspace, + MLAS_SQNBITGEMM_COMPUTE_TYPE ComputeType, MLAS_THREADPOOL* Threadpool) { MLAS_SQNBIT_GEMM_DATA_PARAMS params; params.A = A; @@ -56,20 +72,104 @@ class MlasSQNBitGemmTest : public MlasTestBase { params.QuantBData = QuantBData; params.QuantBScale = QuantBScale; params.QuantBZeroPoint = QuantBZeroPoint; + params.Workspace = Workspace; params.PostProcessor = nullptr; - MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ¶ms, Threadpool); + MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, ¶ms, Threadpool); + } + + void QuantizeA(size_t M, size_t K, const float* A, int8_t* QuantAData, float* QuantAScale) { + const size_t BlockCountK = (K + BlkLen - 1) / BlkLen; + const size_t lda = K; + for (size_t m = 0; m < M; ++m) { + for (size_t k = 0, k_blk = 0; 
k < K; k += BlkLen, ++k_blk) {
+        const size_t k_blk_len = std::min(K - k, BlkLen);
+        float blk_a[BlkLen]{};
+        std::copy_n(A + m * lda + k, k_blk_len, blk_a);
+
+        float amax = 0.0f;  // max of absolute values of A block
+        for (size_t kk = 0; kk < k_blk_len; ++kk) {
+          float a = blk_a[kk];
+          amax = std::max(amax, fabsf(a));
+        }
+
+        constexpr float range_max = (1 << 7) - 1;
+        const float scale = amax / range_max;
+        const float scale_reciprocal = scale != 0.0f ? 1.0f / scale : 0.0f;
+
+        QuantAScale[m * BlockCountK + k_blk] = scale;
+
+        for (size_t kk = 0; kk < BlkLen; ++kk) {
+          const float q = blk_a[kk] * scale_reciprocal;
+          QuantAData[m * BlockCountK * BlkLen + k + kk] =
+              static_cast<int8_t>(
+                  std::clamp(q,
+                             static_cast<float>(std::numeric_limits<int8_t>::min()),
+                             static_cast<float>(std::numeric_limits<int8_t>::max())));
+        }
+      }
+    }
+  }
+
+  void CallReferenceGemm_CompInt8(size_t M,
+                                  size_t N,
+                                  size_t K,
+                                  const float* A,
+                                  const uint8_t* QuantBData,
+                                  const float* QuantBScale,
+                                  const uint8_t* QuantBZeroPoint,
+                                  const float* Bias,
+                                  float* C) {
+    const size_t BlockCountK = (K + BlkLen - 1) / BlkLen;
+
+    int8_t* QuantAData = BufferQuantAData.GetBuffer(M * BlockCountK * BlkLen);
+    float* QuantAScale = BufferQuantAScale.GetBuffer(M * BlockCountK);
+    QuantizeA(M, K, A, QuantAData, QuantAScale);
+
+    for (size_t m = 0; m < M; ++m) {
+      for (size_t n = 0; n < N; ++n) {
+        float sum = Bias == nullptr ? 0.0f : Bias[n];
+        for (size_t k = 0, k_blk = 0; k < K; k += BlkLen, ++k_blk) {
+          const size_t k_blk_len = std::min(K - k, BlkLen);
+
+          const float a_scale = QuantAScale[m * BlockCountK + k_blk];
+
+          const float b_scale = QuantBScale[n * BlockCountK + k_blk];
+
+          static_assert(BlkBitWidth == 4, "only implemented for 4-bit quantized B");
+
+          uint8_t b_zp = 8;
+          if (QuantBZeroPoint != nullptr) {
+            const uint8_t b_zp_byte = QuantBZeroPoint[n * ((BlockCountK + 1) / 2) + k_blk / 2];
+            b_zp = (k_blk & 1) ? (b_zp_byte >> 4) : (b_zp_byte & 0x0F);
+          }
+
+          int32_t qsum = 0;
+
+          for (size_t kk = 0; kk < k_blk_len; ++kk) {
+            const int8_t qa = QuantAData[m * BlockCountK * BlkLen + k + kk];
+            const uint8_t qb_byte = QuantBData[(n * BlockCountK * BlkLen + k + kk) / 2];
+            const int8_t qb = ((kk & 1) == 1 ? (qb_byte >> 4) : (qb_byte & 0x0F)) - b_zp;
+            qsum += qa * qb;
+          }
+
+          sum += static_cast<float>(qsum) * a_scale * b_scale;
+        }
+
+        C[m * N + n] = sum;
+      }
+    }
+  }
 
-  void CallReferenceGemm(size_t M,
-                         size_t N,
-                         size_t K,
-                         const float* A,
-                         const uint8_t* QuantBData,
-                         const float* QuantBScale,
-                         const uint8_t* QuantBZeroPoint,
-                         const float* Bias,
-                         float* C) {
+  void CallReferenceGemm_CompFp32(size_t M,
+                                  size_t N,
+                                  size_t K,
+                                  const float* A,
+                                  const uint8_t* QuantBData,
+                                  const float* QuantBScale,
+                                  const uint8_t* QuantBZeroPoint,
+                                  const float* Bias,
+                                  float* C) {
     float* DequantizedBData = BufferDequantizedB.GetBuffer(K * N);
     MlasDequantizeBlockwise<float, BlkBitWidth>(
         DequantizedBData, QuantBData, QuantBScale, QuantBZeroPoint, BlkLen, /* columnwise */ true,
@@ -95,10 +195,11 @@ class MlasSQNBitGemmTest : public MlasTestBase {
  public:
   void Test(size_t M, size_t N, size_t K,
+            MLAS_SQNBITGEMM_COMPUTE_TYPE ComputeType,
             bool WithBias, bool Symmetric, bool WithThreadpool) {
     MLAS_THREADPOOL* Threadpool = WithThreadpool ?
GetMlasThreadPool() : nullptr; - const float* A = BufferA.GetBuffer(K * M); + float* A = BufferA.GetBuffer(K * M); const float* B = BufferB.GetBuffer(N * K); @@ -126,7 +227,7 @@ class MlasSQNBitGemmTest : public MlasTestBase { float* C = BufferC.GetBuffer(N * M, true); float* CReference = BufferCReference.GetBuffer(N * M, true); - // pack B + // quantize B uint8_t* QuantBData = nullptr; float* QuantBScale = nullptr; uint8_t* QuantBZeroPoint = nullptr; @@ -138,20 +239,35 @@ class MlasSQNBitGemmTest : public MlasTestBase { QuantBData = BufferQuantBData.GetBuffer(QuantBDataSizeInBytes); QuantBScale = BufferQuantBScale.GetBuffer(QuantBScaleSize); - if (Symmetric) { + if (!Symmetric) { QuantBZeroPoint = BufferQuantBZeroPoint.GetBuffer(QuantBZeroPointSizeInBytes); } - MlasQuantizeBlockwise(QuantBData, QuantBScale, QuantBZeroPoint, - B, BlkLen, - /* columnwise */ true, - static_cast(K), static_cast(N), - static_cast(N), - GetMlasThreadPool()); + MlasQuantizeBlockwise(QuantBData, QuantBScale, QuantBZeroPoint, + B, BlkLen, + /* columnwise */ true, + static_cast(K), static_cast(N), + static_cast(N), + GetMlasThreadPool()); + } + + void* Workspace = nullptr; + if (const auto WorkspaceSize = MlasSQNBitGemmWorkspaceSize(M, N, K, BlkBitWidth, BlkLen, ComputeType); + WorkspaceSize > 0) { + Workspace = BufferWorkspace.GetBuffer(WorkspaceSize); } - CallGemm(M, N, K, A, /* lda */ K, QuantBData, QuantBScale, QuantBZeroPoint, Bias, C, /* ldc */ N, Threadpool); - CallReferenceGemm(M, N, K, A, QuantBData, QuantBScale, QuantBZeroPoint, Bias, CReference); + if (ComputeType == CompFp32) { + CallReferenceGemm_CompFp32(M, N, K, A, QuantBData, QuantBScale, QuantBZeroPoint, Bias, CReference); + } else if (ComputeType == CompInt8) { + CallReferenceGemm_CompInt8(M, N, K, A, QuantBData, QuantBScale, QuantBZeroPoint, Bias, CReference); + } else { + FAIL() << "Test is not implemented for compute type " + << ComputeType << " (" << ComputeTypeName(ComputeType) << ")"; + } + + CallGemm(M, N, K, A, /* lda */ K, QuantBData, QuantBScale, QuantBZeroPoint, Bias, C, /* ldc */ N, Workspace, + ComputeType, Threadpool); size_t f = 0; for (size_t m = 0; m < M; m++) { @@ -179,74 +295,90 @@ template class SQNBitGemmShortExecuteTest : public MlasTestFixture> { public: explicit SQNBitGemmShortExecuteTest(size_t M, size_t N, size_t K, + MLAS_SQNBITGEMM_COMPUTE_TYPE ComputeType, bool WithThreadpool, bool Symmetric, bool WithBias) - : M_(M), N_(N), K_(K), WithThreadpool_(WithThreadpool), Symmetric_(Symmetric), WithBias_(WithBias) { + : M_(M), + N_(N), + K_(K), + ComputeType_(ComputeType), + WithThreadpool_(WithThreadpool), + Symmetric_(Symmetric), + WithBias_(WithBias) { } void TestBody() override { MlasTestFixture>::mlas_tester->Test( - M_, N_, K_, WithThreadpool_, Symmetric_, WithBias_); + M_, N_, K_, ComputeType_, WithThreadpool_, Symmetric_, WithBias_); } static size_t RegisterSingleTest(size_t M, size_t N, size_t K, + MLAS_SQNBITGEMM_COMPUTE_TYPE ComputeType, bool WithThreadpool, bool Symmetric, bool WithBias) { - std::stringstream ss; - ss << (WithThreadpool ? "SingleThread" : "Threaded") - << "/isSymmetric" << Symmetric - << "/M" << M << "xN" << N << "xK" << K - << "/hasBias" << WithBias; - auto test_name = ss.str(); - - testing::RegisterTest( - MlasSQNBitGemmTest::GetTestSuiteName(), - test_name.c_str(), - nullptr, - test_name.c_str(), - __FILE__, - __LINE__, - // Important to use the fixture type as the return type here. 
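The registration code above relies on GoogleTest's dynamic-registration hook, `testing::RegisterTest`. As background, here is a minimal self-contained example of that API; the fixture, test body, and values are invented for illustration and are not part of this patch:

// Minimal standalone illustration of GoogleTest dynamic registration.
// MyFixture/MyTest and the values are hypothetical.
#include <gtest/gtest.h>

#include <string>

class MyFixture : public ::testing::Test {};

class MyTest : public MyFixture {
 public:
  explicit MyTest(int value) : value_(value) {}
  void TestBody() override { EXPECT_GE(value_, 0); }

 private:
  int value_;
};

int main(int argc, char** argv) {
  ::testing::InitGoogleTest(&argc, argv);
  for (int v : {0, 1, 2}) {
    const std::string name = "value" + std::to_string(v);
    ::testing::RegisterTest(
        "MySuite", name.c_str(),
        nullptr,  // type_param
        nullptr,  // value_param
        __FILE__, __LINE__,
        // As in RegisterSingleTest: the factory's return type must be the fixture type.
        [=]() -> MyFixture* { return new MyTest(v); });
  }
  return RUN_ALL_TESTS();
}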
- [=]() -> MlasTestFixture>* { - return new SQNBitGemmShortExecuteTest( - M, N, K, WithThreadpool, Symmetric, WithBias); - }); - - return 1; + size_t tests_registered = 0; + + if (MlasIsSQNBitGemmAvailable(M, N, K, BlkBitWidth, BlkLen, ComputeType)) { + std::stringstream ss; + ss << (WithThreadpool ? "SingleThread" : "Threaded") + << "/isSymmetric" << Symmetric + << "/M" << M << "xN" << N << "xK" << K + << "/hasBias" << WithBias + << "/computeType" << ComputeTypeName(ComputeType); + auto test_name = ss.str(); + + testing::RegisterTest( + MlasSQNBitGemmTest::GetTestSuiteName(), + test_name.c_str(), + nullptr, + test_name.c_str(), + __FILE__, + __LINE__, + // Important to use the fixture type as the return type here. + [=]() -> MlasTestFixture>* { + return new SQNBitGemmShortExecuteTest( + M, N, K, ComputeType, WithThreadpool, Symmetric, WithBias); + }); + + tests_registered += 1; + } + + return tests_registered; } static size_t RegisterShortExecuteTests() { - size_t test_registered = 0; + size_t tests_registered = 0; - if (MlasIsSQNBitGemmAvailable(BlkBitWidth, BlkLen)) { + for (MLAS_SQNBITGEMM_COMPUTE_TYPE ComputeType : {CompFp32, CompInt8}) { for (bool WithThreadpool : {false, true}) { for (bool Symmetric : {false, true}) { for (size_t b = 1; b < 16; b++) { - test_registered += RegisterSingleTest(b, b, b, WithThreadpool, Symmetric, false); - test_registered += RegisterSingleTest(b, b, b, WithThreadpool, Symmetric, true); + tests_registered += RegisterSingleTest(b, b, b, ComputeType, WithThreadpool, Symmetric, false); + tests_registered += RegisterSingleTest(b, b, b, ComputeType, WithThreadpool, Symmetric, true); } for (size_t b = 16; b <= 256; b <<= 1) { - test_registered += RegisterSingleTest(b, b, b, WithThreadpool, Symmetric, false); - test_registered += RegisterSingleTest(b, b, b, WithThreadpool, Symmetric, true); + tests_registered += RegisterSingleTest(b, b, b, ComputeType, WithThreadpool, Symmetric, false); + tests_registered += RegisterSingleTest(b, b, b, ComputeType, WithThreadpool, Symmetric, true); } for (size_t b = 256; b < 320; b += 32) { - test_registered += RegisterSingleTest(b, b, b, WithThreadpool, Symmetric, true); + tests_registered += RegisterSingleTest(b, b, b, ComputeType, WithThreadpool, Symmetric, true); } for (size_t b = 1; b < 96; b++) { - test_registered += RegisterSingleTest(1, b, 32, WithThreadpool, Symmetric, false); - test_registered += RegisterSingleTest(1, 32, b, WithThreadpool, Symmetric, true); - test_registered += RegisterSingleTest(1, b, b, WithThreadpool, Symmetric, false); + tests_registered += RegisterSingleTest(1, b, 32, ComputeType, WithThreadpool, Symmetric, false); + tests_registered += RegisterSingleTest(1, 32, b, ComputeType, WithThreadpool, Symmetric, true); + tests_registered += RegisterSingleTest(1, b, b, ComputeType, WithThreadpool, Symmetric, false); } - test_registered += RegisterSingleTest(43, 500, 401, WithThreadpool, Symmetric, true); + tests_registered += RegisterSingleTest(43, 500, 401, ComputeType, WithThreadpool, Symmetric, true); - // test_registered += RegisterSingleTest(1001, 1027, 1031, WithThreadpool, Symmetric, false); + // tests_registered += RegisterSingleTest(1001, 1027, 1031, ComputeType, WithThreadpool, Symmetric, false); } } } - return test_registered; + return tests_registered; } private: size_t M_, N_, K_; + MLAS_SQNBITGEMM_COMPUTE_TYPE ComputeType_; bool WithThreadpool_, Symmetric_, WithBias_; }; From 966a915064eb0a30cd0dbaebd2f991e27e53433b Mon Sep 17 00:00:00 2001 From: edgchen1 
<18449977+edgchen1@users.noreply.github.com> Date: Thu, 14 Dec 2023 10:52:29 -0800 Subject: [PATCH 05/31] only enable benchmark if available --- onnxruntime/test/mlas/bench/bench_q4gemm.cpp | 2 +- onnxruntime/test/mlas/bench/bench_sconv.cpp | 3 +- onnxruntime/test/mlas/bench/bench_sgemm.cpp | 10 +-- .../test/mlas/bench/bench_sqnbitgemm.cpp | 68 +++++++++++-------- onnxruntime/test/mlas/bench/bench_util.cpp | 11 +-- onnxruntime/test/mlas/bench/bench_util.h | 8 ++- 6 files changed, 58 insertions(+), 44 deletions(-) diff --git a/onnxruntime/test/mlas/bench/bench_q4gemm.cpp b/onnxruntime/test/mlas/bench/bench_q4gemm.cpp index 57c6c93969840..61b3f57d8daac 100644 --- a/onnxruntime/test/mlas/bench/bench_q4gemm.cpp +++ b/onnxruntime/test/mlas/bench/bench_q4gemm.cpp @@ -109,7 +109,7 @@ void Q8Q4GEMM(benchmark::State& state, MLAS_BLK_QUANT_TYPE qtype) { static void GemmSizeProducts(benchmark::internal::Benchmark* b) { b->ArgNames(q4gemm_bench_arg_names); - ArgsProduct(b, {{1, 1024, 2048}, {4096}, {4096}, {8}}); + b->ArgsProduct({{1, 1024, 2048}, {4096}, {4096}, {8}}); } [[maybe_unused]] static const bool benchmarks_registered = []() { diff --git a/onnxruntime/test/mlas/bench/bench_sconv.cpp b/onnxruntime/test/mlas/bench/bench_sconv.cpp index 115641f6a6efb..39d135236b89c 100644 --- a/onnxruntime/test/mlas/bench/bench_sconv.cpp +++ b/onnxruntime/test/mlas/bench/bench_sconv.cpp @@ -224,8 +224,7 @@ BENCHMARK_CAPTURE(SCONV_NCHW, TeamsModel, "")->Apply(TeamsModel)->UseRealTime(); static void General_Conv2d(benchmark::internal::Benchmark* b) { b->ArgNames(ArgNamesForConv(2)); - ArgsProduct( - b, + b->ArgsProduct( {{2}, // Rank, {1}, // N {1, 2}, // Groups diff --git a/onnxruntime/test/mlas/bench/bench_sgemm.cpp b/onnxruntime/test/mlas/bench/bench_sgemm.cpp index e6e34bc88ad59..a94d33cd77f63 100644 --- a/onnxruntime/test/mlas/bench/bench_sgemm.cpp +++ b/onnxruntime/test/mlas/bench/bench_sgemm.cpp @@ -103,14 +103,14 @@ void SGEMM(benchmark::State& state, bool pack_b, bool trans_a, bool trans_b, flo static void GemmSizeWithOne(benchmark::internal::Benchmark* b) { b->ArgNames(sgemm_bench_arg_names); - ArgsProduct(b, {{1}, {63, 255, 1023}, {63, 255, 1023}}); - ArgsProduct(b, {{63, 255, 1023}, {1}, {63, 255, 1023}}); - ArgsProduct(b, {{63, 255, 1023}, {63, 255, 1023}, {1}}); + b->ArgsProduct({{1}, {63, 255, 1023}, {63, 255, 1023}}); + b->ArgsProduct({{63, 255, 1023}, {1}, {63, 255, 1023}}); + b->ArgsProduct({{63, 255, 1023}, {63, 255, 1023}, {1}}); } static void GemmSizeProducts(benchmark::internal::Benchmark* b) { b->ArgNames(sgemm_bench_arg_names); - ArgsProduct(b, {{63, 255, 1023}, {63, 255, 1023}, {63, 255, 1023}}); + b->ArgsProduct({{63, 255, 1023}, {63, 255, 1023}, {63, 255, 1023}}); } BENCHMARK_CAPTURE(SGEMM, NORMAL_NoTrans, false, false, false)->Apply(GemmSizeProducts)->UseRealTime(); @@ -128,7 +128,7 @@ BENCHMARK_CAPTURE(SGEMM, PACKB_TransA, true, true, false)->Apply(GemmSizeProduct static void GemmLLMSizeProducts(benchmark::internal::Benchmark* b) { b->ArgNames(sgemm_bench_arg_names); - ArgsProduct(b, {{1, 1024, 2048}, {4096, 11008}, {4096, 11008}}); + b->ArgsProduct({{1, 1024, 2048}, {4096, 11008}, {4096, 11008}}); } BENCHMARK_CAPTURE(SGEMM, LLM, false, false, true)->Apply(GemmLLMSizeProducts)->UseRealTime(); diff --git a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp index 0ca48dcea8d51..8a86f75f8cf65 100644 --- a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp @@ -10,27 +10,28 @@ #include 
"bench_util.h" #include "core/util/thread_utils.h" +#include "core/common/narrow.h" -template -void SQNBITGEMM(benchmark::State& state) { - if (state.range(0) <= 0) throw std::invalid_argument("M must greater than 0!"); - if (state.range(1) <= 0) throw std::invalid_argument("N must greater than 0!"); - if (state.range(2) <= 0) throw std::invalid_argument("K must greater than 0!"); - if (state.range(3) <= 0) throw std::invalid_argument("Threads must greater than 0!"); +using onnxruntime::narrow; - const size_t M = static_cast(state.range(0)); - const size_t N = static_cast(state.range(1)); - const size_t K = static_cast(state.range(2)); - const size_t threads = static_cast(state.range(3)); +template +void SQNBITGEMM(benchmark::State& state) { + const auto BlkLen = narrow(state.range(0)); + const auto M = narrow(state.range(1)); + const auto N = narrow(state.range(2)); + const auto K = narrow(state.range(3)); + const auto Threads = narrow(state.range(4)); + const auto Symmetric = narrow(state.range(5)); + const auto ComputeType = static_cast(state.range(6)); size_t QuantBDataSizeInBytes, QuantBScaleSize, QuantBZeroPointSizeInBytes; MlasBlockwiseQuantizedBufferSizes( - BlkBitWidth, BlkLen, /* columnwise */ true, + BlkBitWidth, static_cast(BlkLen), /* columnwise */ true, static_cast(K), static_cast(N), QuantBDataSizeInBytes, QuantBScaleSize, &QuantBZeroPointSizeInBytes); OrtThreadPoolParams tpo; - tpo.thread_pool_size = static_cast(threads); + tpo.thread_pool_size = static_cast(Threads); tpo.auto_set_affinity = true; std::unique_ptr tp( @@ -47,7 +48,7 @@ void SQNBITGEMM(benchmark::State& state) { MlasQuantizeBlockwise(QuantBData.data(), QuantBScale.data(), Symmetric ? nullptr : QuantBZeroPoint.data(), - B.data(), BlkLen, /* columnwise */ true, + B.data(), static_cast(BlkLen), /* columnwise */ true, static_cast(K), static_cast(N), static_cast(N), tp.get()); @@ -62,26 +63,35 @@ void SQNBITGEMM(benchmark::State& state) { params.ldc = N; // warm up run - MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ¶ms, tp.get()); + MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, ¶ms, tp.get()); for (auto _ : state) { - MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ¶ms, tp.get()); + MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, ¶ms, tp.get()); } } -static void GemmSizeProducts(benchmark::internal::Benchmark* b) { - b->ArgNames({"M", "N", "K", "Threads"}); - ArgsProduct(b, {{1, 1024, 2048}, {4096, 11008}, {4096, 11008}, {8}}); +static void SQNBitGemmArgs(benchmark::internal::Benchmark* b) { + b->ArgNames({"BlkLen", "M", "N", "K", "Threads", "Symmetric", "ComputeType"}); + + ArgsProductWithFilter(b, + + {{16, 32, 64, 128, 256}, // BlkLen + {1, 1024, 2048}, // M + {4096, 11008}, // N + {4096, 11008}, // K + {8}, // Threads + {int64_t{false}, int64_t{true}}, // Symmetric + {int64_t{CompFp32}, int64_t{CompInt8}}}, // ComputeType + + [](const std::vector& args) { + return MlasIsSQNBitGemmAvailable( + // M, N, K + narrow(args[1]), narrow(args[2]), narrow(args[3]), + // BlkBitWidth, BlkLen + 4, narrow(args[0]), + // ComputeType + static_cast(args[6])); + }); } -// TODO disable if unavailable -BENCHMARK(SQNBITGEMM<4, 16, false>)->Apply(GemmSizeProducts)->UseRealTime(); -BENCHMARK(SQNBITGEMM<4, 16, true>)->Apply(GemmSizeProducts)->UseRealTime(); -BENCHMARK(SQNBITGEMM<4, 32, false>)->Apply(GemmSizeProducts)->UseRealTime(); -BENCHMARK(SQNBITGEMM<4, 32, true>)->Apply(GemmSizeProducts)->UseRealTime(); -BENCHMARK(SQNBITGEMM<4, 64, 
false>)->Apply(GemmSizeProducts)->UseRealTime();
-BENCHMARK(SQNBITGEMM<4, 64, true>)->Apply(GemmSizeProducts)->UseRealTime();
-BENCHMARK(SQNBITGEMM<4, 128, false>)->Apply(GemmSizeProducts)->UseRealTime();
-BENCHMARK(SQNBITGEMM<4, 128, true>)->Apply(GemmSizeProducts)->UseRealTime();
-BENCHMARK(SQNBITGEMM<4, 256, false>)->Apply(GemmSizeProducts)->UseRealTime();
-BENCHMARK(SQNBITGEMM<4, 256, true>)->Apply(GemmSizeProducts)->UseRealTime();
+BENCHMARK(SQNBITGEMM<4>)->Apply(SQNBitGemmArgs)->UseRealTime();

diff --git a/onnxruntime/test/mlas/bench/bench_util.cpp b/onnxruntime/test/mlas/bench/bench_util.cpp
index b79cd3a2a40aa..d57564615b04e 100644
--- a/onnxruntime/test/mlas/bench/bench_util.cpp
+++ b/onnxruntime/test/mlas/bench/bench_util.cpp
@@ -23,10 +23,9 @@ std::vector<float> RandomVectorUniform(std::vector<int64_t> shape, float min_val
   return RandomVectorUniform(static_cast<size_t>(sz), min_value, max_value);
 }
 
-// The Benchmark used here do not contains this as in newer version.
-// Use the code from newer version.
-void ArgsProduct(benchmark::internal::Benchmark* bench,
-                 const std::vector<std::vector<int64_t>>& arglists) {
+void ArgsProductWithFilter(benchmark::internal::Benchmark* bench,
+                           const std::vector<std::vector<int64_t>>& arglists,
+                           std::function<bool(const std::vector<int64_t>& args)> include_filter) {
   std::vector<std::size_t> indices(arglists.size(), 0);
   const std::size_t total = std::accumulate(
       std::begin(arglists), std::end(arglists), std::size_t{1},
@@ -39,7 +38,9 @@ void ArgsProduct(benchmark::internal::Benchmark* bench,
     for (std::size_t arg = 0; arg < arglists.size(); arg++) {
       args.push_back(arglists[arg][indices[arg]]);
     }
-    bench->Args(args);
+    if (include_filter(args)) {
+      bench->Args(args);
+    }
     args.clear();
 
     std::size_t arg = 0;
diff --git a/onnxruntime/test/mlas/bench/bench_util.h b/onnxruntime/test/mlas/bench/bench_util.h
index a2b49e117da38..ee2ec42d0f755 100644
--- a/onnxruntime/test/mlas/bench/bench_util.h
+++ b/onnxruntime/test/mlas/bench/bench_util.h
@@ -5,10 +5,14 @@
 
 #include <benchmark/benchmark.h>
 
+#include <functional>
 #include <vector>
 
-void ArgsProduct(benchmark::internal::Benchmark* bench,
-                 const std::vector<std::vector<int64_t>>& arglists);
+// Specifies benchmark arguments from the cartesian product of `arglists`, like Benchmark::ArgsProduct().
+// `include_filter` is called to determine whether a given set of arguments should be included.
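To illustrate the intended use of the declaration that follows, a hypothetical caller might keep only a subset of the cartesian product; the benchmark below is an example, not part of the patch:

// Hypothetical usage of ArgsProductWithFilter: take the product of M and N
// but register only square problems.
#include <cstdint>
#include <vector>

#include "bench_util.h"

static void SquareOnlyArgs(benchmark::internal::Benchmark* b) {
  b->ArgNames({"M", "N"});
  ArgsProductWithFilter(b,
                        {{64, 128, 256}, {64, 128, 256}},
                        [](const std::vector<int64_t>& args) {
                          return args[0] == args[1];  // drop non-square (M, N) pairs
                        });
}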
+void ArgsProductWithFilter(benchmark::internal::Benchmark* bench, + const std::vector>& arglists, + std::function& args)> include_filter); template std::vector RandomVectorUniform( From b59e7e13c186e8374d6fcd88475af5f91ec08ca3 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Thu, 14 Dec 2023 16:36:38 -0800 Subject: [PATCH 06/31] handle workspace in benchmark --- onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp index 8a86f75f8cf65..2d05c4fc55340 100644 --- a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp @@ -4,7 +4,9 @@ #include "mlas_q4.h" #include "mlas_qnbit.h" +#include #include +#include #include "benchmark/benchmark.h" @@ -52,6 +54,12 @@ void SQNBITGEMM(benchmark::State& state) { static_cast(K), static_cast(N), static_cast(N), tp.get()); + std::unique_ptr Workspace; + if (const auto WorkspaceSize = MlasSQNBitGemmWorkspaceSize(M, N, K, BlkBitWidth, BlkLen, ComputeType); + WorkspaceSize > 0) { + Workspace = std::make_unique(WorkspaceSize); + } + MLAS_SQNBIT_GEMM_DATA_PARAMS params{}; params.A = A.data(); params.lda = K; @@ -61,6 +69,7 @@ void SQNBITGEMM(benchmark::State& state) { params.Bias = nullptr; params.C = C.data(); params.ldc = N; + params.Workspace = Workspace.get(); // warm up run MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, ¶ms, tp.get()); From 585103bed3bd822c3b3bcc0e52f4a93f38866aee Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Thu, 14 Dec 2023 16:49:29 -0800 Subject: [PATCH 07/31] QuantizeARow neon impl1 --- onnxruntime/core/mlas/lib/sqnbitgemm.cpp | 19 +- onnxruntime/core/mlas/lib/sqnbitgemm.h | 6 +- .../core/mlas/lib/sqnbitgemm_kernel_neon.cpp | 178 ++++++++++++++---- .../test/mlas/unittest/test_sqnbitgemm.cpp | 4 +- 4 files changed, 156 insertions(+), 51 deletions(-) diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp index a4d6aedc83ca6..06e9e2f24a38f 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp @@ -88,7 +88,7 @@ MlasIsSQNBitGemmAvailable( } case SQNBitGemmVariant_BitWidth4_CompInt8: { return dispatch->SQNBitGemmM1Kernel_BlkBitWidth4_CompInt8 != nullptr && - dispatch->QuantizeA_CompInt8 != nullptr; + dispatch->QuantizeARow_CompInt8 != nullptr; } default: { return false; @@ -366,13 +366,24 @@ InitializeWorkspace_CompInt8( MLAS_UNREFERENCED_PARAMETER(ThreadPool); - const auto QuantizeA = GetMlasPlatform().SQNBitGemmDispatch->QuantizeA_CompInt8; + const auto QuantizeARow = GetMlasPlatform().SQNBitGemmDispatch->QuantizeARow_CompInt8; + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + const size_t QuantAStride = BlockCountK * Q8BlkSize(BlkLen); // TODO use threading for (size_t gemm_idx = 0; gemm_idx < BatchN; ++gemm_idx) { - auto& data = DataParams[gemm_idx]; + const auto& data = DataParams[gemm_idx]; + + const float* ARowPtr = data.A; + std::byte* QuantARowPtr = static_cast(data.Workspace); - QuantizeA(BlkLen, data.A, M, K, K, static_cast(data.Workspace)); + for (size_t m = 0; m < M; ++m) { + QuantizeARow(BlkLen, ARowPtr, K, QuantARowPtr); + + ARowPtr += data.lda; + QuantARowPtr += QuantAStride; + } } } diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.h b/onnxruntime/core/mlas/lib/sqnbitgemm.h index b23955183f34c..471a15d570e84 100644 --- 
a/onnxruntime/core/mlas/lib/sqnbitgemm.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.h @@ -186,14 +186,12 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { SQNBitGemmM1Kernel_BlkBitWidth4_CompInt8_Fn* SQNBitGemmM1Kernel_BlkBitWidth4_CompInt8 = nullptr; - typedef void(QuantizeA_CompInt8_Fn)( + typedef void(QuantizeARow_CompInt8_Fn)( size_t BlkLen, const float* A, - size_t CountM, size_t CountK, - size_t lda, std::byte* QuantA ); - QuantizeA_CompInt8_Fn* QuantizeA_CompInt8 = nullptr; + QuantizeARow_CompInt8_Fn* QuantizeARow_CompInt8 = nullptr; }; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp index 38342febbf8ff..474b43d47eaf8 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp @@ -64,7 +64,7 @@ FoldAccumulators(float32x4_t a0, float32x4_t a1, float32x4_t a2, float32x4_t a3) template MLAS_FORCEINLINE void -LoadData(const float* src, size_t count, float32x4_t (&dst)[Capacity / 4]) +LoadFloatData(const float* src, size_t count, float32x4_t (&dst)[Capacity / 4]) { static_assert(Capacity % 4 == 0, "Capacity must be divisible by 4."); @@ -166,7 +166,7 @@ ComputeDotProducts( // load `SubBlkLen` elements from A, padded with 0's if there aren't enough const size_t k_subblk_len = std::min(k_blk_len - k_idx_in_blk, SubBlkLen); float32x4_t av[4]{}; - LoadData(ARowPtr + k + k_idx_in_blk, k_subblk_len, av); + LoadFloatData(ARowPtr + k + k_idx_in_blk, k_subblk_len, av); // load B column vectors uint8x8_t bv_packed[NCols]; @@ -413,65 +413,161 @@ MlasQNBitBlkDequantBForSgemmNeon( // CompInt8 kernel implementation and related helpers // -void MLASCALL -QuantizeA_CompInt8( +template +MLAS_FORCEINLINE void +QuantizeBlock( size_t BlkLen, const float* A, - size_t CountM, - size_t CountK, - size_t lda, + size_t ElementCount, std::byte* QuantA ) { - auto impl0_reference = [&]() { - const size_t BlockCountK = MlasDivRoundup(CountK, BlkLen); + static_assert(SubBlkLen >= 16 && SubBlkLen % 16 == 0); - const size_t QuantAStride = BlockCountK * Q8BlkSize(BlkLen); + assert(BlkLen % SubBlkLen == 0); - for (size_t m = 0; m < CountM; ++m) { - const float* ADataRowPtr = A + m * lda; - std::byte* QuantARowPtr = QuantA + m * QuantAStride; + constexpr size_t VectorCount = SubBlkLen / 4; - for (size_t k = 0, k_blk = 0; k < CountK; k += BlkLen, ++k_blk) { - const size_t k_blk_len = std::min(CountK - k, BlkLen); + // + // Scan block values first to determine scale. + // - const float* ADataBlkPtr = ADataRowPtr + k; + float amax = 0.0f; // max of absolute values of A block - // scan block values first to determine scale + size_t k; + for (k = 0; k < ElementCount; k += SubBlkLen) { + const size_t SubBlkElementCount = std::min(ElementCount - k, SubBlkLen); - float amax = 0.0f; // max of absolute values of A block + float32x4_t a[VectorCount]{}; + LoadFloatData(A + k, SubBlkElementCount, a); - for (size_t kk = 0; kk < k_blk_len; ++kk) { - float a = ADataBlkPtr[kk]; - amax = std::max(amax, fabsf(a)); - } + float32x4_t abs_a[VectorCount]; + UnrolledLoop([&](size_t i) { + abs_a[i] = vabsq_f32(a[i]); + }); - constexpr float range_max = (1 << 7) - 1; - const float scale = amax / range_max; - const float scale_reciprocal = scale != 0.0f ? 
1.0f / scale : 0.0f; + // find amax of SubBlkLen elements + for (size_t interval = VectorCount / 2; interval > 0; interval /= 2) { + for (size_t i = 0; i < interval; ++i) { + abs_a[i] = vmaxq_f32(abs_a[i], abs_a[i + interval]); + } + } - std::byte* QuantABlkPtr = QuantARowPtr + k_blk * Q8BlkSize(BlkLen); + // update existing amax + amax = std::max(amax, vmaxvq_f32(abs_a[0])); + } - Q8BlkScale(QuantABlkPtr) = scale; - int8_t* QuantABlkData = Q8BlkData(QuantABlkPtr); + constexpr float range_max = (1 << 7) - 1; + const float scale = amax / range_max; + const float scale_reciprocal = scale != 0.0f ? 1.0f / scale : 0.0f; - for (size_t kk = 0; kk < k_blk_len; ++kk) { - const float q = ADataBlkPtr[kk] * scale_reciprocal; - QuantABlkData[kk] = static_cast( - std::clamp( - q, - static_cast(std::numeric_limits::min()), - static_cast(std::numeric_limits::max()) - ) - ); - } + Q8BlkScale(QuantA) = scale; + + // + // Compute quantized block values. + // + + int8_t* QuantAData = Q8BlkData(QuantA); + + for (k = 0; k < ElementCount; k += SubBlkLen) { + const size_t SubBlkElementCount = std::min(ElementCount - k, SubBlkLen); + + float32x4_t a[VectorCount]{}; + LoadFloatData(A + k, SubBlkElementCount, a); + + UnrolledLoop([&](size_t i) { + a[i] = vmulq_n_f32(a[i], scale_reciprocal); + }); + + int32x4_t a_s32[VectorCount]; + UnrolledLoop([&](size_t i) { + a_s32[i] = vcvtaq_s32_f32(a[i]); + }); + + UnrolledLoop([&](size_t i) { + QuantAData[k + i * 4 + 0] = static_cast(vgetq_lane_s32(a_s32[i], 0)); + QuantAData[k + i * 4 + 1] = static_cast(vgetq_lane_s32(a_s32[i], 1)); + QuantAData[k + i * 4 + 2] = static_cast(vgetq_lane_s32(a_s32[i], 2)); + QuantAData[k + i * 4 + 3] = static_cast(vgetq_lane_s32(a_s32[i], 3)); + }); + } + + // + // Zero out any remaining sub-block elements. + // + + for (; k < BlkLen; k += SubBlkLen) { + const int8x16_t Zeros = vdupq_n_s8(0); + UnrolledLoop([&](size_t i) { + vst1q_s8(QuantAData + k + i * 16, Zeros); + }); + } +} + +void MLASCALL +QuantizeARow_CompInt8( + size_t BlkLen, + const float* A, + size_t CountK, + std::byte* QuantA +) +{ + [[maybe_unused]] auto impl0_reference = [&]() { + const float* ADataRowPtr = A; + std::byte* QuantARowPtr = QuantA; + + for (size_t k = 0, k_blk = 0; k < CountK; k += BlkLen, ++k_blk) { + const size_t k_blk_len = std::min(CountK - k, BlkLen); + + const float* ADataBlkPtr = ADataRowPtr + k; + + // scan block values first to determine scale + + float amax = 0.0f; // max of absolute values of A block + + for (size_t kk = 0; kk < k_blk_len; ++kk) { + float a = ADataBlkPtr[kk]; + amax = std::max(amax, fabsf(a)); + } + + constexpr float range_max = (1 << 7) - 1; + const float scale = amax / range_max; + const float scale_reciprocal = scale != 0.0f ? 
1.0f / scale : 0.0f; + + std::byte* QuantABlkPtr = QuantARowPtr + k_blk * Q8BlkSize(BlkLen); + + Q8BlkScale(QuantABlkPtr) = scale; + int8_t* QuantABlkData = Q8BlkData(QuantABlkPtr); + + for (size_t kk = 0; kk < k_blk_len; ++kk) { + const float q = roundf(ADataBlkPtr[kk] * scale_reciprocal); + QuantABlkData[kk] = static_cast( + std::clamp( + q, + static_cast(std::numeric_limits::min()), + static_cast(std::numeric_limits::max()) + ) + ); } } }; - // TODO neon impl + [[maybe_unused]] auto impl1 = [&]() { + const float* ADataBlkPtr = A; + std::byte* QuantABlkPtr = QuantA; - impl0_reference(); + for (size_t k = 0; k < CountK; k += BlkLen) { + const size_t k_blk_len = std::min(CountK - k, BlkLen); + + QuantizeBlock<16>(BlkLen, ADataBlkPtr, k_blk_len, QuantABlkPtr); + + ADataBlkPtr += BlkLen; + QuantABlkPtr += Q8BlkSize(BlkLen); + } + }; + + // impl0_reference(); + impl1(); } MLAS_FORCEINLINE @@ -544,7 +640,7 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon = []() { d.SQNBitGemmM1Kernel_BlkBitWidth4_CompFp32 = MlasSQNBitGemmM1KernelNeon<4>; d.QNBitBlkDequantBForSgemm_BlkBitWidth4_CompFp32 = MlasQNBitBlkDequantBForSgemmNeon<4>; d.SQNBitGemmM1Kernel_BlkBitWidth4_CompInt8 = SQNBitGemmM1Kernel_BlkBitWidth4_CompInt8; - d.QuantizeA_CompInt8 = QuantizeA_CompInt8; + d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8; return d; }(); diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp index c442d1441b35d..2b3f5b4f01ee7 100644 --- a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp @@ -88,7 +88,7 @@ class MlasSQNBitGemmTest : public MlasTestBase { std::copy_n(A + m * lda + k, local_blk_len, blk_a); float amax = 0.0f; // max of absolute values of A block - for (size_t kk = 0; kk < k_blk_len; ++kk) { + for (size_t kk = 0; kk < local_blk_len; ++kk) { float a = blk_a[kk]; amax = std::max(amax, fabsf(a)); } @@ -100,7 +100,7 @@ class MlasSQNBitGemmTest : public MlasTestBase { QuantAScale[m * BlockCountK + k_blk] = scale; for (size_t kk = 0; kk < BlkLen; ++kk) { - const float q = blk_a[kk] * scale_reciprocal; + const float q = roundf(blk_a[kk] * scale_reciprocal); QuantAData[m * BlockCountK * BlkLen + k + kk] = static_cast( std::clamp(q, From c26cef4ff68f60513f015114c57509faaf79706f Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Thu, 14 Dec 2023 18:46:29 -0800 Subject: [PATCH 08/31] dot compint8 neon impl --- .../core/mlas/lib/sqnbitgemm_kernel_neon.cpp | 201 +++++++++++++++++- 1 file changed, 197 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp index 474b43d47eaf8..075606d1e0ea5 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp @@ -566,10 +566,136 @@ QuantizeARow_CompInt8( } }; - // impl0_reference(); + //impl0_reference(); impl1(); } +template +MLAS_FORCEINLINE void +ComputeDotProducts_BlkBitWidth4_CompInt8( + size_t BlkLen, + const std::byte* QuantARowPtr, + const uint8_t* QuantBDataColPtr, + const float* QuantBScaleColPtr, + const uint8_t* QuantBZeroPointColPtr, + float* SumPtr, + size_t CountK, + size_t StrideQuantBData, + size_t StrideQuantBScale, + size_t StrideQuantBZeroPoint, + const float* BiasPtr +) +{ + static_assert(NCols == 1 || NCols == 4, "NCols must be 1 or 4"); + + constexpr size_t BlkBitWidth = 4; + + constexpr size_t SubBlkLen = 16; // number of 
block elements to process in a sub-block iteration + assert(BlkLen % SubBlkLen == 0); + + const uint8x8_t LowMask = vdup_n_u8(0x0F); + + const std::byte* QuantA = QuantARowPtr; + + const uint8_t* QuantBData = QuantBDataColPtr; + const float* QuantBScale = QuantBScaleColPtr; + size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + + float32x4_t acc[NCols]{}; + + for (size_t k = 0; k < CountK; k += BlkLen) { + const size_t k_blk_len = std::min(CountK - k, BlkLen); + + const float a_scale = Q8BlkScale(QuantA); + const int8_t* a_data = Q8BlkData(QuantA); + + float b_scale[NCols]; + UnrolledLoop([&](size_t i) { b_scale[i] = QuantBScale[i * StrideQuantBScale]; }); + + int8_t b_zp[NCols]; + if (QuantBZeroPointColPtr != nullptr) { + UnrolledLoop([&](size_t i) { + const uint8_t zp_packed = + QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; + b_zp[i] = ((QuantBZeroPointIdx & 1) == 1) + ? static_cast(zp_packed >> 4) + : static_cast(zp_packed & 0x0F); + }); + } else { + UnrolledLoop([&](size_t i) { + b_zp[i] = 8; + }); + } + + for (size_t k_idx_in_blk = 0; k_idx_in_blk < k_blk_len; k_idx_in_blk += SubBlkLen) { + // load A row vector + int8x16_t av = vld1q_s8(a_data + k_idx_in_blk); + + // load B column vectors + uint8x8_t bv_packed[NCols]; + UnrolledLoop([&](size_t i) { + const size_t b_data_block_offset = k_idx_in_blk * BlkBitWidth / 8; + bv_packed[i] = vld1_u8(QuantBData + i * StrideQuantBData + b_data_block_offset); + }); + + uint8x8_t bv_u8_unzipped[NCols][2]; + UnrolledLoop([&](size_t i) { + bv_u8_unzipped[i][0] = vand_u8(bv_packed[i], LowMask); + bv_u8_unzipped[i][1] = vand_u8(vshr_n_u8(bv_packed[i], 4), LowMask); + }); + + int8x16_t bv[NCols]; + UnrolledLoop([&](size_t i) { + const int8x8_t lo = vreinterpret_s8_u8(vzip1_u8(bv_u8_unzipped[i][0], bv_u8_unzipped[i][1])); + const int8x8_t hi = vreinterpret_s8_u8(vzip2_u8(bv_u8_unzipped[i][0], bv_u8_unzipped[i][1])); + bv[i] = vcombine_s8(lo, hi); + }); + + // subtract B zero point + UnrolledLoop([&](size_t i) { + const int8x16_t zp_v = vdupq_n_s8(b_zp[i]); + bv[i] = vsubq_s8(bv[i], zp_v); + }); + + // compute quantized dot product + int32x4_t dot[NCols]; + UnrolledLoop([&](size_t i) { + const int32x4_t zero_v = vdupq_n_s32(0); + dot[i] = vdotq_s32(zero_v, av, bv[i]); + }); + + // convert to float and add to `acc` + UnrolledLoop([&](size_t i) { + const float32x4_t scale_v = vdupq_n_f32(a_scale * b_scale[i]); + acc[i] = vfmaq_f32(acc[i], vcvtq_f32_s32(dot[i]), scale_v); + }); + } + + // increment pointers to next block + QuantA += Q8BlkSize(BlkLen); + QuantBData += MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + QuantBScale += 1; + QuantBZeroPointIdx += 1; + } + + if constexpr (NCols == 4) { + float32x4_t sum = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + + if (BiasPtr != nullptr) { + sum = vaddq_f32(sum, vld1q_f32(BiasPtr)); + } + + vst1q_f32(SumPtr, sum); + } else { + for (size_t i = 0; i < NCols; ++i) { + SumPtr[i] = vaddvq_f32(acc[i]); + if (BiasPtr != nullptr) { + SumPtr[i] += BiasPtr[i]; + } + } + } +} + MLAS_FORCEINLINE void SQNBitGemmM1Kernel_BlkBitWidth4_CompInt8( @@ -585,7 +711,7 @@ SQNBitGemmM1Kernel_BlkBitWidth4_CompInt8( const float* Bias ) { - auto impl0_reference = [&]() { + [[maybe_unused]] auto impl0_reference = [&]() { const std::byte* QuantARowPtr = QuantA; for (size_t n = 0; n < CountN; ++n) { @@ -623,9 +749,76 @@ SQNBitGemmM1Kernel_BlkBitWidth4_CompInt8( } }; - // TODO neon impl + [[maybe_unused]] auto impl1 = [&]() { + constexpr size_t BlkBitWidth 
= 4;
+        constexpr size_t NCols = 4;
 
-    impl0_reference();
+        const std::byte* QuantARowPtr = QuantA;
+        float* CRowPtr = C;
+
+        const size_t BlockCountK = BlockStrideQuantB;
+
+        const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
+        const size_t StrideQuantBScale = BlockCountK;
+        const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK);
+
+        const float* BiasPtr = Bias;
+
+        const uint8_t* QuantBDataColPtr = QuantBData;
+        const float* QuantBScaleColPtr = QuantBScale;
+        const uint8_t* QuantBZeroPointColPtr = QuantBZeroPoint;
+
+        float* SumPtr = CRowPtr;
+
+        int64_t nblk = static_cast<int64_t>(CountN) - NCols;
+
+        while (nblk >= 0) {
+            ComputeDotProducts_BlkBitWidth4_CompInt8<NCols>(
+                BlkLen,
+                QuantARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK,
+                StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint,
+                BiasPtr
+            );
+
+            // move to next `NCols` columns
+
+            QuantBDataColPtr += NCols * StrideQuantBData;
+            QuantBScaleColPtr += NCols * StrideQuantBScale;
+            if (QuantBZeroPointColPtr != nullptr) {
+                QuantBZeroPointColPtr += NCols * StrideQuantBZeroPoint;
+            }
+
+            BiasPtr += BiasPtr != nullptr ? NCols : 0;
+            SumPtr += NCols;
+
+            nblk -= NCols;
+        }
+
+        // left over columns less than `NCols`?
+        nblk += NCols;
+        for (int64_t n = 0; n < nblk; ++n) {
+            ComputeDotProducts_BlkBitWidth4_CompInt8<1>(
+                BlkLen,
+                QuantARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK,
+                StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint,
+                BiasPtr
+            );
+
+            // move to next column
+
+            QuantBDataColPtr += StrideQuantBData;
+            QuantBScaleColPtr += StrideQuantBScale;
+            if (QuantBZeroPointColPtr != nullptr) {
+                QuantBZeroPointColPtr += StrideQuantBZeroPoint;
+            }
+
+            BiasPtr += BiasPtr != nullptr ?
1 : 0; + SumPtr += 1; + } + }; + + //impl0_reference(); + impl1(); } } // namespace From 1b7d81b43d24eeb7fb3cf157f23bf0593476cb14 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Fri, 15 Dec 2023 17:15:02 -0800 Subject: [PATCH 09/31] use single workspace pointer in interface, get matmul_nbits working --- .../cpu/quantization/matmul_nbits.cc | 15 +- onnxruntime/core/mlas/inc/mlas_qnbit.h | 13 +- onnxruntime/core/mlas/lib/sqnbitgemm.cpp | 150 ++++++++++++++---- onnxruntime/core/mlas/lib/sqnbitgemm.h | 7 +- .../test/mlas/bench/bench_sqnbitgemm.cpp | 7 +- .../test/mlas/unittest/test_sqnbitgemm.cpp | 5 +- 6 files changed, 146 insertions(+), 51 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 320a05bb97dac..f31f6917f9351 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -67,7 +67,8 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { const size_t K = static_cast(helper.K()); const size_t lda = helper.Lda(false); - if (MlasIsSQNBitGemmAvailable(nbits_, block_size_)) { + const MLAS_SQNBITGEMM_COMPUTE_TYPE compute_type = CompFp32; + if (MlasIsSQNBitGemmAvailable(M, N, K, nbits_, block_size_, compute_type)) { // number of bytes or elements between adjacent matrices size_t b_data_matrix_stride_in_bytes, b_scale_matrix_stride, b_zero_point_matrix_stride_in_bytes; MlasBlockwiseQuantizedBufferSizes(static_cast(nbits_), static_cast(block_size_), /* columnwise */ true, @@ -77,6 +78,15 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { const size_t b_matrix_size = K * N; + IAllocatorUniquePtr workspace{}; + if (const size_t workspace_size = MlasSQNBitGemmWorkspaceSize(M, N, K, batch_count, + nbits_, block_size_, compute_type); + workspace_size > 0) { + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator)); + workspace = IAllocator::MakeUniquePtr(allocator, workspace_size); + } + InlinedVector data(batch_count); for (size_t i = 0; i < batch_count; ++i) { const size_t b_matrix_offset = helper.RightOffsets()[i] / b_matrix_size; @@ -92,7 +102,8 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { data[i].ldc = N; } - MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, data.data(), thread_pool); + MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type, data.data(), workspace.get(), + thread_pool); return Status::OK(); } diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h index 8de6670203a65..f7f8c875287de 100644 --- a/onnxruntime/core/mlas/inc/mlas_qnbit.h +++ b/onnxruntime/core/mlas/inc/mlas_qnbit.h @@ -42,12 +42,6 @@ struct MLAS_SQNBIT_GEMM_DATA_PARAMS { float* C = nullptr; ///< address of result matrix size_t ldc = 0; ///< leading dimension of C - /** - * Address of intermediate workspace buffer. - * Only required if MlasSQNBitGemmWorkspaceSize returns a non-zero value. 
- */ - void* Workspace = nullptr; - ///< optional post processing to apply to result matrix MLAS_GEMM_POSTPROCESSOR* PostProcessor = nullptr; }; @@ -63,9 +57,12 @@ struct MLAS_SQNBIT_GEMM_DATA_PARAMS { * @param[in] BatchN number of batches * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) * @param[in] BlkLen number of quantized values per block + * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) * @param[inout] DataParams An array (size BatchN) of parameter blocks + * @param[in] Workspace Address of intermediate workspace buffer. + If MlasSQNBitGemmWorkspaceSize() returns a non-zero value, this should be a buffer + with at least that many bytes. Otherwise, it can be nullptr. * @param[in] ThreadPool optional thread pool to use - * // TODO update param doc */ void MLASCALL MlasSQNBitGemmBatch( @@ -77,6 +74,7 @@ MlasSQNBitGemmBatch( size_t BlkLen, MLAS_SQNBITGEMM_COMPUTE_TYPE ComputeType, const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, + void* Workspace, MLAS_THREADPOOL* ThreadPool = nullptr ); @@ -106,6 +104,7 @@ MlasSQNBitGemmWorkspaceSize( size_t M, size_t N, size_t K, + size_t BatchN, size_t BlkBitWidth, size_t BlkLen, MLAS_SQNBITGEMM_COMPUTE_TYPE ComputeType diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp index 06e9e2f24a38f..01ccef971a9e9 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp @@ -74,21 +74,21 @@ MlasIsSQNBitGemmAvailable( MLAS_SQNBITGEMM_COMPUTE_TYPE ComputeType ) { - const auto* dispatch = GetMlasPlatform().SQNBitGemmDispatch; - if (dispatch == nullptr) { + const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch; + if (Dispatch == nullptr) { return false; } - const auto variant = GetSQNBitGemmVariant(M, N, K, BlkBitWidth, BlkLen, ComputeType); + const auto Variant = GetSQNBitGemmVariant(M, N, K, BlkBitWidth, BlkLen, ComputeType); - switch (variant) { + switch (Variant) { case SQNBitGemmVariant_BitWidth4_CompFp32: { - return dispatch->SQNBitGemmM1Kernel_BlkBitWidth4_CompFp32 != nullptr && - dispatch->QNBitBlkDequantBForSgemm_BlkBitWidth4_CompFp32 != nullptr; + return Dispatch->SQNBitGemmM1Kernel_BlkBitWidth4_CompFp32 != nullptr && + Dispatch->QNBitBlkDequantBForSgemm_BlkBitWidth4_CompFp32 != nullptr; } case SQNBitGemmVariant_BitWidth4_CompInt8: { - return dispatch->SQNBitGemmM1Kernel_BlkBitWidth4_CompInt8 != nullptr && - dispatch->QuantizeARow_CompInt8 != nullptr; + return Dispatch->SQNBitGemmM1Kernel_BlkBitWidth4_CompInt8 != nullptr && + Dispatch->QuantizeARow_CompInt8 != nullptr; } default: { return false; @@ -96,25 +96,39 @@ MlasIsSQNBitGemmAvailable( } } -size_t MLASCALL -MlasSQNBitGemmWorkspaceSize( +namespace +{ + +size_t +SQNBitGemmWorkspaceAlignment(SQNBitGemmVariant Variant) +{ + switch (Variant) { + case SQNBitGemmVariant_BitWidth4_CompInt8: { + return Q8BlkAlignment(); + } + default: { + return 1; + } + } +} + +size_t +SQNBitGemmPerGemmWorkspaceSize( + SQNBitGemmVariant Variant, size_t M, size_t N, size_t K, - size_t BlkBitWidth, - size_t BlkLen, - MLAS_SQNBITGEMM_COMPUTE_TYPE ComputeType + size_t BlkLen ) { - const auto variant = GetSQNBitGemmVariant(M, N, K, BlkBitWidth, BlkLen, ComputeType); + MLAS_UNREFERENCED_PARAMETER(N); - switch (variant) { + switch (Variant) { case SQNBitGemmVariant_BitWidth4_CompInt8: { // workspace buffer is used for block quantization of A to int8 const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - const size_t RequiredBufferSize = M * BlockCountK * Q8BlkSize(BlkLen); 
- const size_t RequiredAlignment = Q8BlkAlignment(BlkLen); - return (RequiredBufferSize + RequiredAlignment - 1) / RequiredAlignment * RequiredAlignment; + const size_t PerGemmWorkspaceSize = M * BlockCountK * Q8BlkSize(BlkLen); + return PerGemmWorkspaceSize; } default: { return 0; @@ -122,6 +136,47 @@ MlasSQNBitGemmWorkspaceSize( } } +size_t +SQNBitGemmPerGemmWorkspaceStride( + SQNBitGemmVariant Variant, + size_t M, + size_t N, + size_t K, + size_t BlkLen +) +{ + const auto Size = SQNBitGemmPerGemmWorkspaceSize(Variant, M, N, K, BlkLen); + const auto Alignment = SQNBitGemmWorkspaceAlignment(Variant); + return MlasDivRoundup(Size, Alignment) * Alignment; +} + +} // namespace + +size_t MLASCALL +MlasSQNBitGemmWorkspaceSize( + size_t M, + size_t N, + size_t K, + size_t BatchN, + size_t BlkBitWidth, + size_t BlkLen, + MLAS_SQNBITGEMM_COMPUTE_TYPE ComputeType +) +{ + const auto Variant = GetSQNBitGemmVariant(M, N, K, BlkBitWidth, BlkLen, ComputeType); + + const size_t PerGemmWorkspaceStride = SQNBitGemmPerGemmWorkspaceStride(Variant, M, N, K, BlkLen); + if (PerGemmWorkspaceStride == 0) { + return 0; + } + + const size_t Alignment = SQNBitGemmWorkspaceAlignment(Variant); + + const size_t WorkspaceSize = BatchN * PerGemmWorkspaceStride; + + return WorkspaceSize + Alignment - 1; +} + namespace { @@ -152,13 +207,14 @@ AddBiasForGemm(const float* Bias, float* C, size_t CountM, size_t CountN, size_t } typedef void(SQNBitGemmFn)( - const size_t BlkLen, - const size_t K, - const MLAS_SQNBIT_GEMM_DATA_PARAMS* const DataParams, - const size_t RangeStartM, - const size_t RangeCountM, - const size_t RangeStartN, - const size_t RangeCountN + size_t BlkLen, + size_t K, + const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, + void* PerGemmWorkspace, + size_t RangeStartM, + size_t RangeCountM, + size_t RangeStartN, + size_t RangeCountN ); void @@ -166,6 +222,7 @@ SQNBitGemm_BlkBitWidth4_CompFp32( const size_t BlkLen, const size_t K, const MLAS_SQNBIT_GEMM_DATA_PARAMS* const DataParams, + void* const PerGemmWorkspace, const size_t RangeStartM, const size_t RangeCountM, const size_t RangeStartN, @@ -174,6 +231,8 @@ SQNBitGemm_BlkBitWidth4_CompFp32( { constexpr size_t BlkBitWidth = 4; + MLAS_UNREFERENCED_PARAMETER(PerGemmWorkspace); + const size_t lda = DataParams->lda; const size_t ldc = DataParams->ldc; @@ -282,6 +341,7 @@ SQNBitGemm_BlkBitWidth4_CompInt8( const size_t BlkLen, const size_t K, const MLAS_SQNBIT_GEMM_DATA_PARAMS* const DataParams, + void* const PerGemmWorkspace, const size_t RangeStartM, const size_t RangeCountM, const size_t RangeStartN, @@ -297,7 +357,7 @@ SQNBitGemm_BlkBitWidth4_CompInt8( const size_t ldb = k_blks * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); const size_t k_blks_zp_bytes = MlasQNBitZeroPointsForBlksSizeInBytes(k_blks); - const std::byte* QuantA = static_cast(DataParams->Workspace) + RangeStartM * lda; + const std::byte* QuantA = static_cast(PerGemmWorkspace) + RangeStartM * lda; const uint8_t* QuantBData = static_cast(DataParams->QuantBData) + RangeStartN * ldb; const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks; @@ -348,6 +408,8 @@ typedef void(InitializeWorkspaceFn)( size_t BatchN, size_t BlkLen, const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, + void* Workspace, + size_t PerGemmWorkspaceStride, MLAS_THREADPOOL* ThreadPool ); @@ -359,6 +421,8 @@ InitializeWorkspace_CompInt8( size_t BatchN, size_t BlkLen, const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, + void* Workspace, + size_t PerGemmWorkspaceStride, MLAS_THREADPOOL* ThreadPool ) { @@ -376,7 +440,7 
@@ InitializeWorkspace_CompInt8( const auto& data = DataParams[gemm_idx]; const float* ARowPtr = data.A; - std::byte* QuantARowPtr = static_cast(data.Workspace); + std::byte* QuantARowPtr = static_cast(Workspace) + gemm_idx * PerGemmWorkspaceStride; for (size_t m = 0; m < M; ++m) { QuantizeARow(BlkLen, ARowPtr, K, QuantARowPtr); @@ -415,23 +479,42 @@ MlasSQNBitGemmBatch( const size_t BlkLen, MLAS_SQNBITGEMM_COMPUTE_TYPE ComputeType, const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, + void* Workspace, MLAS_THREADPOOL* ThreadPool ) { const auto Variant = GetSQNBitGemmVariant(M, N, K, BlkBitWidth, BlkLen, ComputeType); assert(Variant != SQNBitGemmVariantInvalid); + // + // Ensure `Workspace` has correct alignment. + // + if (Workspace != nullptr) { + const size_t Alignment = SQNBitGemmWorkspaceAlignment(Variant); + const uintptr_t WorkspaceAddress = reinterpret_cast(Workspace); + Workspace = reinterpret_cast( + (WorkspaceAddress + Alignment - 1) & (~(Alignment - 1)) + ); + } + + const size_t PerGemmWorkspaceStride = SQNBitGemmPerGemmWorkspaceStride(Variant, M, N, K, BlkLen); + if (const auto InitializeWorkspaceOperation = OperationMap[Variant].InitializeWorkspace; InitializeWorkspaceOperation != nullptr) { - InitializeWorkspaceOperation(M, N, K, BatchN, BlkLen, DataParams, ThreadPool); + InitializeWorkspaceOperation( + M, N, K, BatchN, BlkLen, DataParams, Workspace, PerGemmWorkspaceStride, ThreadPool + ); } const auto ComputeOperation = OperationMap[Variant].SQNBitGemm; if (ThreadPool == nullptr) { for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) { - auto Data = &DataParams[gemm_i]; - ComputeOperation(BlkLen, K, Data, 0, M, 0, N); + const auto* Data = &DataParams[gemm_i]; + void* PerGemmWorkspace = reinterpret_cast( + reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride + ); + ComputeOperation(BlkLen, K, Data, PerGemmWorkspace, 0, M, 0, N); } return; } @@ -480,7 +563,10 @@ MlasSQNBitGemmBatch( MlasTrySimpleParallel(ThreadPool, ThreadsPerGemm * BatchN, [&](ptrdiff_t tid) { const auto gemm_i = tid / ThreadsPerGemm; const auto blk_i = tid % ThreadsPerGemm; - auto Data = &DataParams[gemm_i]; + const auto* Data = &DataParams[gemm_i]; + void* PerGemmWorkspace = reinterpret_cast( + reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride + ); const ptrdiff_t ThreadIdN = blk_i / ThreadCountM; const ptrdiff_t ThreadIdM = blk_i % ThreadCountM; @@ -491,6 +577,6 @@ MlasSQNBitGemmBatch( const size_t RangeStartN = ThreadIdN * StrideN; const size_t RangeCountN = std::min(N - RangeStartN, (size_t)StrideN); - ComputeOperation(BlkLen, K, Data, RangeStartM, RangeCountM, RangeStartN, RangeCountN); + ComputeOperation(BlkLen, K, Data, PerGemmWorkspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN); }); } diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.h b/onnxruntime/core/mlas/lib/sqnbitgemm.h index 471a15d570e84..1691976bd8a71 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.h @@ -31,6 +31,8 @@ Module Name: #include "mlas_qnbit.h" #include "mlasi.h" +#include + constexpr MLAS_FORCEINLINE size_t MlasQNBitBlkDataSizeInBytes(size_t BlkBitWidth, size_t BlkLen) { @@ -87,15 +89,14 @@ Q8BlkSize(size_t BlkLen) const size_t BlkSize = sizeof(float) + BlkLen * sizeof(int8_t); // Currently, the strictest alignment requirement of a block is for a float. // Ensure contiguous blocks are suitably aligned. - // assert(BlkSize % alignof(float) == 0); // TODO needs include, put it in .cpp? 
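As background for the assertion added below: a Q8 block is a float scale immediately followed by BlkLen int8 values, so Q8BlkSize is sizeof(float) + BlkLen bytes and contiguous blocks stay float-aligned. A standalone sketch of the same layout math (names suffixed "Sketch" to mark them as illustrations, not the MLAS declarations):

// Standalone sketch of the Q8 block layout used for the quantized A workspace:
// a float scale stored first, then BlkLen int8 values.
#include <cassert>
#include <cstddef>
#include <cstdint>

constexpr size_t Q8BlkSizeSketch(size_t BlkLen) {
    return sizeof(float) + BlkLen * sizeof(int8_t);
}

inline float& Q8BlkScaleSketch(std::byte* Blk) {
    return *reinterpret_cast<float*>(Blk);  // scale occupies the first 4 bytes
}

inline int8_t* Q8BlkDataSketch(std::byte* Blk) {
    return reinterpret_cast<int8_t*>(Blk + sizeof(float));  // data follows the scale
}

static_assert(Q8BlkSizeSketch(32) % alignof(float) == 0,
              "contiguous blocks keep the scale float-aligned");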
+ assert(BlkSize % alignof(float) == 0); return BlkSize; } MLAS_FORCEINLINE constexpr size_t -Q8BlkAlignment(size_t BlkLen) +Q8BlkAlignment() { - MLAS_UNREFERENCED_PARAMETER(BlkLen); return alignof(float); } diff --git a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp index 2d05c4fc55340..f671f925afaa6 100644 --- a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp @@ -55,7 +55,7 @@ void SQNBITGEMM(benchmark::State& state) { tp.get()); std::unique_ptr Workspace; - if (const auto WorkspaceSize = MlasSQNBitGemmWorkspaceSize(M, N, K, BlkBitWidth, BlkLen, ComputeType); + if (const auto WorkspaceSize = MlasSQNBitGemmWorkspaceSize(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType); WorkspaceSize > 0) { Workspace = std::make_unique(WorkspaceSize); } @@ -69,13 +69,12 @@ void SQNBITGEMM(benchmark::State& state) { params.Bias = nullptr; params.C = C.data(); params.ldc = N; - params.Workspace = Workspace.get(); // warm up run - MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, ¶ms, tp.get()); + MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, ¶ms, Workspace.get(), tp.get()); for (auto _ : state) { - MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, ¶ms, tp.get()); + MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, ¶ms, Workspace.get(), tp.get()); } } diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp index 2b3f5b4f01ee7..8e1fab3e386f0 100644 --- a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp @@ -72,10 +72,9 @@ class MlasSQNBitGemmTest : public MlasTestBase { params.QuantBData = QuantBData; params.QuantBScale = QuantBScale; params.QuantBZeroPoint = QuantBZeroPoint; - params.Workspace = Workspace; params.PostProcessor = nullptr; - MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, ¶ms, Threadpool); + MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, ¶ms, Workspace, Threadpool); } void QuantizeA(size_t M, size_t K, const float* A, int8_t* QuantAData, float* QuantAScale) { @@ -252,7 +251,7 @@ class MlasSQNBitGemmTest : public MlasTestBase { } void* Workspace = nullptr; - if (const auto WorkspaceSize = MlasSQNBitGemmWorkspaceSize(M, N, K, BlkBitWidth, BlkLen, ComputeType); + if (const auto WorkspaceSize = MlasSQNBitGemmWorkspaceSize(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType); WorkspaceSize > 0) { Workspace = BufferWorkspace.GetBuffer(WorkspaceSize); } From 71bd3a92853af706319ab1c07cd152688189c1c9 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Wed, 27 Dec 2023 15:26:19 -0800 Subject: [PATCH 10/31] renaming and cleanup --- .../cpu/quantization/matmul_nbits.cc | 6 +- onnxruntime/core/mlas/inc/mlas_qnbit.h | 69 ++++++++++--------- onnxruntime/core/mlas/lib/sqnbitgemm.cpp | 14 ++-- .../test/mlas/bench/bench_sqnbitgemm.cpp | 2 +- .../test/mlas/unittest/test_sqnbitgemm.cpp | 2 +- 5 files changed, 50 insertions(+), 43 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 319fea9476536..de61bb88e6309 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -161,7 +161,7 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { gemm_params[i].C = y_data + 
helper.OutputOffsets()[i]; gemm_params[i].ldc = N; } - auto ws_size = MlasSQNBitsGemmBatchWorkspaceSize(M, N, K, max_len, gemm_params.data()); + auto ws_size = MlasSQNBitsGemmBatchPackedBWorkspaceSize(M, N, K, max_len, gemm_params.data()); // workspace for activation process(dynamic quantization and others) auto ws_ptr = IAllocator::MakeUniquePtr(allocator, ws_size); MlasSQNBitsGemmBatchPackedB(M, N, K, max_len, gemm_params.data(), ws_ptr.get(), @@ -207,8 +207,8 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { const size_t b_matrix_size = K * N; IAllocatorUniquePtr workspace{}; - if (const size_t workspace_size = MlasSQNBitGemmWorkspaceSize(M, N, K, batch_count, - nbits_, block_size_, compute_type); + if (const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, batch_count, + nbits_, block_size_, compute_type); workspace_size > 0) { AllocatorPtr allocator; ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator)); diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h index 43f863e60a983..8fc2a2bb04536 100644 --- a/onnxruntime/core/mlas/inc/mlas_qnbit.h +++ b/onnxruntime/core/mlas/inc/mlas_qnbit.h @@ -23,24 +23,31 @@ Module Name: #include "mlas.h" #include "mlas_gemm_postprocessor.h" -// TODO add documentation -enum MLAS_SQNBITGEMM_COMPUTE_TYPE { - CompFp32, // fp32 A, fp32 accumulator - CompInt8, // int8 A, int32 accumulator -}; +/** + * @brief Define compute types of block quantization + */ +typedef enum { + CompUndef = 0, /*!< undef */ + CompFp32 = 1, /*!< input fp32, accumulator fp32 */ + CompFp16 = 2, /*!< input fp16, accumulator fp16 */ + CompBf16 = 3, /*!< input bf16, accumulator fp32 */ + CompInt8 = 4 /*!< input int8, accumulator int32 */ +} MLAS_SQNBIT_COMPUTE_TYPE; + +using MLAS_SQNBITGEMM_COMPUTE_TYPE = MLAS_SQNBIT_COMPUTE_TYPE; // TODO consolidate these /** * @brief Data parameters for float/n-bit quantized int GEMM routine. */ struct MLAS_SQNBIT_GEMM_DATA_PARAMS { - const float* A = nullptr; ///< address of A (float32 matrix) - size_t lda = 0; ///< leading dimension of A - const void* QuantBData = nullptr; ///< address of quantized B (quantized n-bit int values) - const float* QuantBScale = nullptr; ///< address of scale values of quantized B, one per block - const void* QuantBZeroPoint = nullptr; ///< optional address of zero point values of quantized B, one per block - const float* Bias = nullptr; ///< optional address of Bias, vector size N - float* C = nullptr; ///< address of result matrix - size_t ldc = 0; ///< leading dimension of C + const float* A = nullptr; ///< address of A (float32 matrix) + size_t lda = 0; ///< leading dimension of A + const void* QuantBData = nullptr; ///< address of quantized B (quantized n-bit int values) + const float* QuantBScale = nullptr; ///< address of scale values of quantized B, one per block + const void* QuantBZeroPoint = nullptr; ///< optional address of zero point values of quantized B, one per block + const float* Bias = nullptr; ///< optional address of Bias, vector size N + float* C = nullptr; ///< address of result matrix + size_t ldc = 0; ///< leading dimension of C ///< optional post processing to apply to result matrix MLAS_GEMM_POSTPROCESSOR* PostProcessor = nullptr; @@ -60,8 +67,8 @@ struct MLAS_SQNBIT_GEMM_DATA_PARAMS { * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) * @param[inout] DataParams An array (size BatchN) of parameter blocks * @param[in] Workspace Address of intermediate workspace buffer. 
- If MlasSQNBitGemmWorkspaceSize() returns a non-zero value, this should be a buffer - with at least that many bytes. Otherwise, it can be nullptr. + If MlasSQNBitGemmBatchWorkspaceSize() returns a non-zero value, this must be a + buffer with at least that many bytes. Otherwise, it may be nullptr. * @param[in] ThreadPool optional thread pool to use */ void MLASCALL @@ -80,9 +87,14 @@ MlasSQNBitGemmBatch( /** * @brief Determines whether a float32/quantized n-bit int GEMM implementation is available on the current platform. + * Ensure that this returns true before calling MlasSQNBitGemmBatch(). + * + * @param[in] M row size of matrix A and C + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) * @param[in] BlkLen number of quantized values per block - * TODO update param doc + * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) */ bool MLASCALL MlasIsSQNBitGemmAvailable( @@ -97,10 +109,16 @@ MlasIsSQNBitGemmAvailable( /** * @brief Gets the size in bytes of the intermediate workspace buffer required by the float32/quantized n-bit int GEMM * implementation. If zero, no intermediate workspace is required. - * // TODO update param doc + * @param[in] M row size of matrix A and C + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BatchN number of batches + * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) + * @param[in] BlkLen number of quantized values per block + * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) */ size_t MLASCALL -MlasSQNBitGemmWorkspaceSize( +MlasSQNBitGemmBatchWorkspaceSize( size_t M, size_t N, size_t K, @@ -110,17 +128,6 @@ MlasSQNBitGemmWorkspaceSize( MLAS_SQNBITGEMM_COMPUTE_TYPE ComputeType ); -/** - * @brief Define compute types of block quantization - */ -typedef enum { - CompUndef = 0, /*!< undef */ - CompFp32 = 1, /*!< input fp32, accumulator fp32 */ - CompFp16 = 2, /*!< input fp16, accumulator fp16 */ - CompBf16 = 3, /*!< input bf16, accumulator fp32 */ - CompInt8 = 4 /*!< input int8, accumulator int32 */ -} MLAS_SQNBIT_COMPUTE_TYPE; - /** * @brief Data parameters for NBits GEMM routine * C = A * B @@ -171,7 +178,7 @@ MlasNBitsGemmPackBSize( * @param last_call flag to activate the epilogue process of packB. OpKernel::PrePack will query input tensor * one by one: QData, Scale, Zp (if is_asym is true). But kernel prefers to pack all tensors into one blob data where * they can share the common attributes like: block_size. Meanwhile, kernel has some pre-computations to speed up - * inference which require that all blob data are ready. So, you need to set this flag to true when passing Scale + * inference which require that all blob data are ready. So, you need to set this flag to true when passing Scale * (is_asym is false) and Zp(is_asym is true). 
* @param thread_pool */ @@ -218,7 +225,7 @@ MlasNBitsGemmUnPackB( * @return Workspace size in bytes */ size_t MLASCALL -MlasSQNBitsGemmBatchWorkspaceSize( +MlasSQNBitsGemmBatchPackedBWorkspaceSize( const size_t M, const size_t N, const size_t K, diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp index 328cbabfe9ddb..59d1aa5aafdfc 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp @@ -16,12 +16,13 @@ Module Name: --*/ #include "sqnbitgemm.h" + +#include + #ifdef MLAS_JBLAS #include "jblas_gemm.h" #endif -#include - namespace { @@ -156,7 +157,7 @@ SQNBitGemmPerGemmWorkspaceStride( } // namespace size_t MLASCALL -MlasSQNBitGemmWorkspaceSize( +MlasSQNBitGemmBatchWorkspaceSize( size_t M, size_t N, size_t K, @@ -514,9 +515,8 @@ MlasSQNBitGemmBatch( if (ThreadPool == nullptr) { for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) { const auto* Data = &DataParams[gemm_i]; - void* PerGemmWorkspace = reinterpret_cast( - reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride - ); + void* PerGemmWorkspace = + reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride; ComputeOperation(BlkLen, K, Data, PerGemmWorkspace, 0, M, 0, N); } return; @@ -662,7 +662,7 @@ MlasNBitsGemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, s } size_t MLASCALL -MlasSQNBitsGemmBatchWorkspaceSize( +MlasSQNBitsGemmBatchPackedBWorkspaceSize( const size_t M, const size_t N, const size_t K, diff --git a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp index c25fc844e34f4..0105cd75a917d 100644 --- a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp @@ -55,7 +55,7 @@ void SQNBITGEMM(benchmark::State& state) { tp.get()); std::unique_ptr Workspace; - if (const auto WorkspaceSize = MlasSQNBitGemmWorkspaceSize(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType); + if (const auto WorkspaceSize = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType); WorkspaceSize > 0) { Workspace = std::make_unique(WorkspaceSize); } diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp index 8e1fab3e386f0..8ede6ae09ab92 100644 --- a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp @@ -251,7 +251,7 @@ class MlasSQNBitGemmTest : public MlasTestBase { } void* Workspace = nullptr; - if (const auto WorkspaceSize = MlasSQNBitGemmWorkspaceSize(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType); + if (const auto WorkspaceSize = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType); WorkspaceSize > 0) { Workspace = BufferWorkspace.GetBuffer(WorkspaceSize); } From f7127f9f695248098d6095fd3815d61f6ecf36cf Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Wed, 27 Dec 2023 17:05:11 -0800 Subject: [PATCH 11/31] try different comp types in matmulnbits --- .../cpu/quantization/matmul_nbits.cc | 78 ++++++++++--------- onnxruntime/core/mlas/inc/mlas_qnbit.h | 17 ++-- 2 files changed, 52 insertions(+), 43 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index de61bb88e6309..9a086c02539f3 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -195,45 +195,49 @@ Status 
MatMulNBits::Compute(OpKernelContext* ctx) const {
   const size_t K = static_cast<size_t>(helper.K());
   const size_t lda = helper.Lda(false);
 
-  const MLAS_SQNBITGEMM_COMPUTE_TYPE compute_type = CompFp32;
-  if (MlasIsSQNBitGemmAvailable(M, N, K, nbits_, block_size_, compute_type)) {
-    // number of bytes or elements between adjacent matrices
-    size_t b_data_matrix_stride_in_bytes, b_scale_matrix_stride, b_zero_point_matrix_stride_in_bytes;
-    MlasBlockwiseQuantizedBufferSizes(static_cast<int>(nbits_), static_cast<int>(block_size_), /* columnwise */ true,
-                                      static_cast<int>(K), static_cast<int>(N),
-                                      b_data_matrix_stride_in_bytes, b_scale_matrix_stride,
-                                      &b_zero_point_matrix_stride_in_bytes);
-
-    const size_t b_matrix_size = K * N;
-
-    IAllocatorUniquePtr<std::byte> workspace{};
-    if (const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, batch_count,
-                                                                       nbits_, block_size_, compute_type);
-        workspace_size > 0) {
-      AllocatorPtr allocator;
-      ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator));
-      workspace = IAllocator::MakeUniquePtr<std::byte>(allocator, workspace_size);
-    }
+  for (int64_t accuracy_level = accuracy_level_;
+       accuracy_level >= static_cast<int64_t>(CompMostAccurate);
+       --accuracy_level) {
+    const auto compute_type = static_cast<MLAS_SQNBITGEMM_COMPUTE_TYPE>(accuracy_level);
+    if (MlasIsSQNBitGemmAvailable(M, N, K, nbits_, block_size_, compute_type)) {
+      // number of bytes or elements between adjacent matrices
+      size_t b_data_matrix_stride_in_bytes, b_scale_matrix_stride, b_zero_point_matrix_stride_in_bytes;
+      MlasBlockwiseQuantizedBufferSizes(static_cast<int>(nbits_), static_cast<int>(block_size_), /* columnwise */ true,
+                                        static_cast<int>(K), static_cast<int>(N),
+                                        b_data_matrix_stride_in_bytes, b_scale_matrix_stride,
+                                        &b_zero_point_matrix_stride_in_bytes);
+
+      const size_t b_matrix_size = K * N;
+
+      IAllocatorUniquePtr<std::byte> workspace{};
+      if (const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, batch_count,
+                                                                         nbits_, block_size_, compute_type);
+          workspace_size > 0) {
+        AllocatorPtr allocator;
+        ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator));
+        workspace = IAllocator::MakeUniquePtr<std::byte>(allocator, workspace_size);
+      }
+
+      InlinedVector<MLAS_SQNBIT_GEMM_DATA_PARAMS> data(batch_count);
+      for (size_t i = 0; i < batch_count; ++i) {
+        const size_t b_matrix_offset = helper.RightOffsets()[i] / b_matrix_size;
+
+        data[i].A = a_data + helper.LeftOffsets()[i];
+        data[i].lda = lda;
+        data[i].QuantBData = b_data + b_matrix_offset * b_data_matrix_stride_in_bytes;
+        data[i].QuantBScale = scales_data + b_matrix_offset * b_scale_matrix_stride;
+        data[i].QuantBZeroPoint = zero_points_data != nullptr
+                                      ? zero_points_data + b_matrix_offset * b_zero_point_matrix_stride_in_bytes
+                                      : nullptr;
+        data[i].C = y_data + helper.OutputOffsets()[i];
+        data[i].ldc = N;
+      }
+
+      MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type, data.data(), workspace.get(),
+                          thread_pool);
 
-    InlinedVector<MLAS_SQNBIT_GEMM_DATA_PARAMS> data(batch_count);
-    for (size_t i = 0; i < batch_count; ++i) {
-      const size_t b_matrix_offset = helper.RightOffsets()[i] / b_matrix_size;
-
-      data[i].A = a_data + helper.LeftOffsets()[i];
-      data[i].lda = lda;
-      data[i].QuantBData = b_data + b_matrix_offset * b_data_matrix_stride_in_bytes;
-      data[i].QuantBScale = scales_data + b_matrix_offset * b_scale_matrix_stride;
-      data[i].QuantBZeroPoint = zero_points_data != nullptr
-                                    ?
zero_points_data + b_matrix_offset * b_zero_point_matrix_stride_in_bytes - : nullptr; - data[i].C = y_data + helper.OutputOffsets()[i]; - data[i].ldc = N; + return Status::OK(); } - - MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type, data.data(), workspace.get(), - thread_pool); - - return Status::OK(); } const size_t ldb = helper.Ldb(true); diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h index 8fc2a2bb04536..1637c1795a66a 100644 --- a/onnxruntime/core/mlas/inc/mlas_qnbit.h +++ b/onnxruntime/core/mlas/inc/mlas_qnbit.h @@ -24,14 +24,19 @@ Module Name: #include "mlas_gemm_postprocessor.h" /** - * @brief Define compute types of block quantization + * @brief Define compute types of block quantization, in order of decreasing accuracy. */ typedef enum { - CompUndef = 0, /*!< undef */ - CompFp32 = 1, /*!< input fp32, accumulator fp32 */ - CompFp16 = 2, /*!< input fp16, accumulator fp16 */ - CompBf16 = 3, /*!< input bf16, accumulator fp32 */ - CompInt8 = 4 /*!< input int8, accumulator int32 */ + CompUndef = 0, /*!< undef */ + CompFp32, /*!< input fp32, accumulator fp32 */ + CompFp16, /*!< input fp16, accumulator fp16 */ + CompBf16, /*!< input bf16, accumulator fp32 */ + CompInt8, /*!< input int8, accumulator int32 */ + + // special values that should be the first and last actual values + + CompMostAccurate = CompUndef, + CompLeastAccurate = CompInt8, } MLAS_SQNBIT_COMPUTE_TYPE; using MLAS_SQNBITGEMM_COMPUTE_TYPE = MLAS_SQNBIT_COMPUTE_TYPE; // TODO consolidate these From b3147c6cf97807f1544045d99d72d857a5133ed8 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Thu, 28 Dec 2023 12:11:58 -0800 Subject: [PATCH 12/31] rename enum, add doc --- .../cpu/quantization/matmul_nbits.cc | 2 +- onnxruntime/core/mlas/inc/mlas_qnbit.h | 22 +++++---- onnxruntime/core/mlas/lib/sqnbitgemm.cpp | 8 ++-- onnxruntime/core/mlas/lib/sqnbitgemm.h | 48 +++++++++++++------ .../test/mlas/bench/bench_sqnbitgemm.cpp | 4 +- .../test/mlas/unittest/test_sqnbitgemm.cpp | 14 +++--- 6 files changed, 60 insertions(+), 38 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 9a086c02539f3..bf4ee13ea4003 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -198,7 +198,7 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { for (int64_t accuracy_level = accuracy_level_; accuracy_level >= static_cast(CompMostAccurate); --accuracy_level) { - const auto compute_type = static_cast(accuracy_level); + const auto compute_type = static_cast(accuracy_level); if (MlasIsSQNBitGemmAvailable(M, N, K, nbits_, block_size_, compute_type)) { // number of bytes or elements between adjacent matrices size_t b_data_matrix_stride_in_bytes, b_scale_matrix_stride, b_zero_point_matrix_stride_in_bytes; diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h index 1637c1795a66a..416080a9eea30 100644 --- a/onnxruntime/core/mlas/inc/mlas_qnbit.h +++ b/onnxruntime/core/mlas/inc/mlas_qnbit.h @@ -27,11 +27,11 @@ Module Name: * @brief Define compute types of block quantization, in order of decreasing accuracy. 
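 *
 * Because the values are ordered by decreasing accuracy, a caller holding a
 * requested accuracy level can probe downward for the best implementation
 * available on the current platform. A minimal sketch, mirroring the
 * MatMulNBits::Compute() change in this patch (names are illustrative):
 *
 *   for (int64_t level = requested_level;
 *        level >= static_cast<int64_t>(CompMostAccurate);
 *        --level) {
 *     const auto type = static_cast<MLAS_SQNBIT_COMPUTE_TYPE>(level);
 *     if (MlasIsSQNBitGemmAvailable(M, N, K, BlkBitWidth, BlkLen, type)) {
 *       break;  // use `type`
 *     }
 *   }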
*/ typedef enum { - CompUndef = 0, /*!< undef */ - CompFp32, /*!< input fp32, accumulator fp32 */ - CompFp16, /*!< input fp16, accumulator fp16 */ - CompBf16, /*!< input bf16, accumulator fp32 */ - CompInt8, /*!< input int8, accumulator int32 */ + CompUndef = 0, /*!< undef */ + CompFp32, /*!< input fp32, accumulator fp32 */ + CompFp16, /*!< input fp16, accumulator fp16 */ + CompBf16, /*!< input bf16, accumulator fp32 */ + CompInt8, /*!< input int8, accumulator int32 */ // special values that should be the first and last actual values @@ -39,7 +39,7 @@ typedef enum { CompLeastAccurate = CompInt8, } MLAS_SQNBIT_COMPUTE_TYPE; -using MLAS_SQNBITGEMM_COMPUTE_TYPE = MLAS_SQNBIT_COMPUTE_TYPE; // TODO consolidate these +using MLAS_SQNBIT_GEMM_COMPUTE_TYPE = MLAS_SQNBIT_COMPUTE_TYPE; // TODO consolidate these /** * @brief Data parameters for float/n-bit quantized int GEMM routine. @@ -63,6 +63,8 @@ struct MLAS_SQNBIT_GEMM_DATA_PARAMS { * A must be a float32 matrix * B must be a quantized and packed n-bit int matrix * + * Call MlasIsSQNBitGemmAvailable() with the same parameters to determine whether this function may be called. + * * @param[in] M row size of matrix A and C * @param[in] N column size of matrix B and C * @param[in] K column size of matrix A and row size of matrix B @@ -84,7 +86,7 @@ MlasSQNBitGemmBatch( size_t BatchN, size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBITGEMM_COMPUTE_TYPE ComputeType, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, void* Workspace, MLAS_THREADPOOL* ThreadPool = nullptr @@ -92,7 +94,6 @@ MlasSQNBitGemmBatch( /** * @brief Determines whether a float32/quantized n-bit int GEMM implementation is available on the current platform. - * Ensure that this returns true before calling MlasSQNBitGemmBatch(). * * @param[in] M row size of matrix A and C * @param[in] N column size of matrix B and C @@ -108,12 +109,13 @@ MlasIsSQNBitGemmAvailable( size_t K, size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBITGEMM_COMPUTE_TYPE ComputeType + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType ); /** * @brief Gets the size in bytes of the intermediate workspace buffer required by the float32/quantized n-bit int GEMM * implementation. If zero, no intermediate workspace is required. 
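 *
 * Typical call sequence, as exercised by the benchmark and unit test changes
 * in this series: query this size, allocate at least that many bytes when it
 * is non-zero, and pass the buffer as the Workspace argument of
 * MlasSQNBitGemmBatch().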
+ * * @param[in] M row size of matrix A and C * @param[in] N column size of matrix B and C * @param[in] K column size of matrix A and row size of matrix B @@ -130,7 +132,7 @@ MlasSQNBitGemmBatchWorkspaceSize( size_t BatchN, size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBITGEMM_COMPUTE_TYPE ComputeType + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType ); /** diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp index 59d1aa5aafdfc..12f20721ea622 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp @@ -48,7 +48,7 @@ GetSQNBitGemmVariant( size_t K, size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBITGEMM_COMPUTE_TYPE ComputeType + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType ) { MLAS_UNREFERENCED_PARAMETER(N); @@ -75,7 +75,7 @@ MlasIsSQNBitGemmAvailable( size_t K, size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBITGEMM_COMPUTE_TYPE ComputeType + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType ) { const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch; @@ -164,7 +164,7 @@ MlasSQNBitGemmBatchWorkspaceSize( size_t BatchN, size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBITGEMM_COMPUTE_TYPE ComputeType + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType ) { const auto Variant = GetSQNBitGemmVariant(M, N, K, BlkBitWidth, BlkLen, ComputeType); @@ -481,7 +481,7 @@ MlasSQNBitGemmBatch( const size_t BatchN, const size_t BlkBitWidth, const size_t BlkLen, - MLAS_SQNBITGEMM_COMPUTE_TYPE ComputeType, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, void* Workspace, MLAS_THREADPOOL* ThreadPool diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.h b/onnxruntime/core/mlas/lib/sqnbitgemm.h index 1691976bd8a71..8481140dc2722 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.h @@ -10,19 +10,13 @@ Module Name: Abstract: - This module includes: // TODO update + This module includes kernel function prototypes and helper functions for + implementing SQNBitGemm. - - Declaration of the set of template functions used to implement a kernel - for a matrix/matrix multiplication, A*B, where A is a float matrix and B is - a n-bit quantized integer matrix (QNBitGemm). - - - A shared kernel driver function template, MlasSQNBitGemmOperation. - - - Kernel dispatch structure. - - The B matrix is block quantized, which means that its values are grouped - into blocks which each have one scale and optional zero point. Each - quantized value in B is n-bits wide. + SQNBitGemm is a matrix/matrix multiplication, A*B, where A is a float + matrix and B is a n-bit quantized integer matrix. B is block quantized, + meaning values of B are divided into blocks and each block has its own + scale and optional zero point. --*/ @@ -106,7 +100,7 @@ Q8BlkAlignment() struct MLAS_SQNBIT_GEMM_DISPATCH { // - // CompFp32 kernels + // CompFp32 kernel function prototypes. // /** @@ -169,9 +163,26 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { QNBitBlkDequantBForSgemm_BlkBitWidth4_CompFp32_Fn* QNBitBlkDequantBForSgemm_BlkBitWidth4_CompFp32 = nullptr; // - // CompInt8 kernels + // CompInt8 kernel function prototypes. // + /** + * @brief Multiply quantized int8 matrix A with quantized n-bit integer matrix B. + * A and B are block quantized and B is column major. + * This kernel handles the special case where M, the number of rows of A and C, is 1. + * + * @param BlkLen Number of values in a block. + * @param QuantA Supplies the quantized A matrix. + Binary data containing block quantized int8 data and scale values. 
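 *                    (Each quantized A block stores a float scale alongside
 *                    BlkLen int8 values; the Q8BlkSize()/Q8BlkScale()/Q8BlkData()
 *                    helpers in this header describe the exact layout.)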
+ * @param QuantBData Supplies the quantized B matrix block data. + * @param QuantBScale Supplies the quantized B matrix block scale values. + * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. + * @param[out] C Supplies the output C matrix. + * @param CountN Number of columns of B and C. + * @param CountK Number of columns of A and rows of B. + * @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix. + * @param Bias Bias vector of length N. + */ typedef void(SQNBitGemmM1Kernel_BlkBitWidth4_CompInt8_Fn)( size_t BlkLen, const std::byte* QuantA, @@ -187,6 +198,15 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { SQNBitGemmM1Kernel_BlkBitWidth4_CompInt8_Fn* SQNBitGemmM1Kernel_BlkBitWidth4_CompInt8 = nullptr; + /** + * @brief Block quantize values from one row of matrix A from float to int8. + * + * @param BlkLen Number of values in a block. + * @param A Supplies the A matrix. + * @param CountK Number of columns of A. + * @param[out] QuantA Supplies the output quantized A matrix. + * Binary data containing block quantized int8 data and scale values. + */ typedef void(QuantizeARow_CompInt8_Fn)( size_t BlkLen, const float* A, diff --git a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp index 0105cd75a917d..d0b83812b63b1 100644 --- a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp @@ -24,7 +24,7 @@ void SQNBITGEMM(benchmark::State& state) { const auto K = narrow(state.range(3)); const auto Threads = narrow(state.range(4)); const auto Symmetric = narrow(state.range(5)); - const auto ComputeType = static_cast(state.range(6)); + const auto ComputeType = static_cast(state.range(6)); size_t QuantBDataSizeInBytes, QuantBScaleSize, QuantBZeroPointSizeInBytes; MlasBlockwiseQuantizedBufferSizes( @@ -98,7 +98,7 @@ static void SQNBitGemmArgs(benchmark::internal::Benchmark* b) { // BlkBitWidth, BlkLen 4, narrow(args[0]), // ComputeType - static_cast(args[6])); + static_cast(args[6])); }); } diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp index 8ede6ae09ab92..d2b99b9adaaae 100644 --- a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp @@ -18,7 +18,7 @@ Module Name: #include "mlas_q4.h" #include "mlas_qnbit.h" -static constexpr const char* ComputeTypeName(MLAS_SQNBITGEMM_COMPUTE_TYPE ComputeType) { +static constexpr const char* ComputeTypeName(MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType) { switch (ComputeType) { case CompFp32: return "Fp32"; @@ -61,7 +61,7 @@ class MlasSQNBitGemmTest : public MlasTestBase { float* C, size_t ldc, void* Workspace, - MLAS_SQNBITGEMM_COMPUTE_TYPE ComputeType, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, MLAS_THREADPOOL* Threadpool) { MLAS_SQNBIT_GEMM_DATA_PARAMS params; params.A = A; @@ -194,7 +194,7 @@ class MlasSQNBitGemmTest : public MlasTestBase { public: void Test(size_t M, size_t N, size_t K, - MLAS_SQNBITGEMM_COMPUTE_TYPE ComputeType, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, bool WithBias, bool Symmetric, bool WithThreadpool) { MLAS_THREADPOOL* Threadpool = WithThreadpool ? 
GetMlasThreadPool() : nullptr; @@ -294,7 +294,7 @@ template class SQNBitGemmShortExecuteTest : public MlasTestFixture> { public: explicit SQNBitGemmShortExecuteTest(size_t M, size_t N, size_t K, - MLAS_SQNBITGEMM_COMPUTE_TYPE ComputeType, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, bool WithThreadpool, bool Symmetric, bool WithBias) : M_(M), N_(N), @@ -311,7 +311,7 @@ class SQNBitGemmShortExecuteTest : public MlasTestFixture Date: Thu, 28 Dec 2023 14:12:10 -0800 Subject: [PATCH 13/31] change quant b params from uint8_t* to std::byte* --- onnxruntime/core/mlas/lib/sqnbitgemm.cpp | 24 ++-- onnxruntime/core/mlas/lib/sqnbitgemm.h | 12 +- .../core/mlas/lib/sqnbitgemm_kernel_neon.cpp | 107 ++++++++++-------- 3 files changed, 78 insertions(+), 65 deletions(-) diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp index 12f20721ea622..3fcbb57aa8533 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp @@ -246,12 +246,12 @@ SQNBitGemm_BlkBitWidth4_CompFp32( const float* A = DataParams->A + RangeStartM * lda; - const uint8_t* QuantBData = static_cast(DataParams->QuantBData) + RangeStartN * ldb; + const std::byte* QuantBData = static_cast(DataParams->QuantBData) + RangeStartN * ldb; const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks; - const uint8_t* QuantBZeroPoint = + const std::byte* QuantBZeroPoint = (DataParams->QuantBZeroPoint == nullptr) ? nullptr - : static_cast(DataParams->QuantBZeroPoint) + RangeStartN * k_blks_zp_bytes; + : static_cast(DataParams->QuantBZeroPoint) + RangeStartN * k_blks_zp_bytes; float* C = DataParams->C + RangeStartM * ldc + RangeStartN; @@ -263,9 +263,9 @@ SQNBitGemm_BlkBitWidth4_CompFp32( CountN = std::min(RangeCountN - n, size_t{128}); const float* a_row = A; - const uint8_t* b_col = QuantBData + n * ldb; + const std::byte* b_col = QuantBData + n * ldb; const float* b_col_scale = QuantBScale + n * k_blks; - const uint8_t* b_col_zp = + const std::byte* b_col_zp = (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; float* c_blk = C + n; const float* bias = (Bias == nullptr) ? nullptr : Bias + n; @@ -301,9 +301,9 @@ SQNBitGemm_BlkBitWidth4_CompFp32( // Step through each slice of matrix A along the M dimension. // const float* a_row = A; - const uint8_t* b_col = QuantBData + n * ldb; + const std::byte* b_col = QuantBData + n * ldb; const float* b_col_scale = QuantBScale + n * k_blks; - const uint8_t* b_col_zp = + const std::byte* b_col_zp = (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; float* c_blk = C + n; const float* bias = (Bias == nullptr) ? nullptr : Bias + n; @@ -363,12 +363,12 @@ SQNBitGemm_BlkBitWidth4_CompInt8( const std::byte* QuantA = static_cast(PerGemmWorkspace) + RangeStartM * lda; - const uint8_t* QuantBData = static_cast(DataParams->QuantBData) + RangeStartN * ldb; + const std::byte* QuantBData = static_cast(DataParams->QuantBData) + RangeStartN * ldb; const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks; - const uint8_t* QuantBZeroPoint = + const std::byte* QuantBZeroPoint = (DataParams->QuantBZeroPoint == nullptr) ? 
nullptr - : static_cast(DataParams->QuantBZeroPoint) + RangeStartN * k_blks_zp_bytes; + : static_cast(DataParams->QuantBZeroPoint) + RangeStartN * k_blks_zp_bytes; float* C = DataParams->C + RangeStartM * ldc + RangeStartN; @@ -380,9 +380,9 @@ SQNBitGemm_BlkBitWidth4_CompInt8( CountN = std::min(RangeCountN - n, size_t{128}); const std::byte* a_row = QuantA; - const uint8_t* b_col = QuantBData + n * ldb; + const std::byte* b_col = QuantBData + n * ldb; const float* b_col_scale = QuantBScale + n * k_blks; - const uint8_t* b_col_zp = + const std::byte* b_col_zp = (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; float* c_blk = C + n; const float* bias = (Bias == nullptr) ? nullptr : Bias + n; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.h b/onnxruntime/core/mlas/lib/sqnbitgemm.h index 8481140dc2722..a02aa6987d518 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.h @@ -122,9 +122,9 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { typedef void(SQNBitGemmM1Kernel_BlkBitWidth4_CompFp32_Fn)( size_t BlkLen, const float* A, - const uint8_t* QuantBData, + const std::byte* QuantBData, const float* QuantBScale, - const uint8_t* QuantBZeroPoint, + const std::byte* QuantBZeroPoint, float* C, size_t CountN, size_t CountK, @@ -152,9 +152,9 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { typedef void(QNBitBlkDequantBForSgemm_BlkBitWidth4_CompFp32_Fn)( size_t BlkLen, float* FpData, - const uint8_t* QuantBData, + const std::byte* QuantBData, const float* QuantBScale, - const uint8_t* QuantBZeroPoint, + const std::byte* QuantBZeroPoint, size_t CountN, size_t CountK, size_t BlockStrideQuantB @@ -186,9 +186,9 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { typedef void(SQNBitGemmM1Kernel_BlkBitWidth4_CompInt8_Fn)( size_t BlkLen, const std::byte* QuantA, - const uint8_t* QuantBData, + const std::byte* QuantBData, const float* QuantBScale, - const uint8_t* QuantBZeroPoint, + const std::byte* QuantBZeroPoint, float* C, size_t CountN, size_t CountK, diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp index 075606d1e0ea5..22dbff546a0ad 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp @@ -95,14 +95,14 @@ LoadFloatData(const float* src, size_t count, float32x4_t (&dst)[Capacity / 4]) } } -template +template MLAS_FORCEINLINE void -ComputeDotProducts( +ComputeDotProducts_BlkBitWidth4_CompFp32( size_t BlkLen, const float* ARowPtr, - const uint8_t* QuantBDataColPtr, + const std::byte* QuantBDataColPtr, const float* QuantBScaleColPtr, - const uint8_t* QuantBZeroPointColPtr, + const std::byte* QuantBZeroPointColPtr, float* SumPtr, size_t CountK, size_t StrideQuantBData, @@ -111,6 +111,8 @@ ComputeDotProducts( const float* BiasPtr ) { + constexpr size_t BlkBitWidth = 4; + static_assert(NCols == 1 || NCols == 4, "NCols must be 1 or 4"); constexpr size_t SubBlkLen = 16; // number of block elements to process in a sub-block iteration @@ -133,7 +135,7 @@ ComputeDotProducts( float32x4_t acc[NCols]{}; - const uint8_t* QuantBData = QuantBDataColPtr; + const std::byte* QuantBData = QuantBDataColPtr; const float* QuantBScale = QuantBScaleColPtr; size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer @@ -148,10 +150,12 @@ ComputeDotProducts( float offset[NCols]; // Includes zero point and float conversion offset of 16. 
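      // Zero points are packed two per byte, low nibble first: an even block
      // index reads (zp_packed & 0x0F) and an odd one reads (zp_packed >> 4),
      // which is why QuantBZeroPointIdx advances in half-byte steps. When no
      // zero-point buffer is supplied, the kernels fall back to 8, the
      // midpoint of the 4-bit range.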
if (QuantBZeroPointColPtr != nullptr) { UnrolledLoop([&](size_t i) { - const uint8_t zp_packed = + const std::byte zp_packed = QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; - const uint8_t zp = ((QuantBZeroPointIdx & 1) == 1) ? (zp_packed >> 4) : (zp_packed & 0x0F); - offset[i] = 16.0f + zp; + const std::byte zp = ((QuantBZeroPointIdx & 1) == 1) + ? (zp_packed >> 4) + : (zp_packed & std::byte{0x0F}); + offset[i] = 16.0f + std::to_integer(zp); }); } else { UnrolledLoop([&](size_t i) { @@ -172,7 +176,9 @@ ComputeDotProducts( uint8x8_t bv_packed[NCols]; UnrolledLoop([&](size_t i) { const size_t b_data_block_offset = k_idx_in_blk * BlkBitWidth / 8; - bv_packed[i] = vld1_u8(QuantBData + i * StrideQuantBData + b_data_block_offset); + bv_packed[i] = vld1_u8( + reinterpret_cast(QuantBData) + i * StrideQuantBData + b_data_block_offset + ); }); uint8x8_t bv_u8_unzipped[NCols][2]; @@ -258,14 +264,13 @@ ComputeDotProducts( } } -template MLAS_FORCEINLINE void -MlasSQNBitGemmM1KernelNeon( +SQNBitGemmM1Kernel_BlkBitWidth4_CompFp32( size_t BlkLen, const float* A, - const uint8_t* QuantBData, + const std::byte* QuantBData, const float* QuantBScale, - const uint8_t* QuantBZeroPoint, + const std::byte* QuantBZeroPoint, float* C, size_t CountN, size_t CountK, @@ -273,6 +278,7 @@ MlasSQNBitGemmM1KernelNeon( const float* Bias ) { + constexpr size_t BlkBitWidth = 4; constexpr size_t NCols = 4; const float* ARowPtr = A; @@ -286,16 +292,16 @@ MlasSQNBitGemmM1KernelNeon( const float* BiasPtr = Bias; - const uint8_t* QuantBDataColPtr = QuantBData; + const std::byte* QuantBDataColPtr = QuantBData; const float* QuantBScaleColPtr = QuantBScale; - const uint8_t* QuantBZeroPointColPtr = QuantBZeroPoint; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; float* SumPtr = CRowPtr; int64_t nblk = static_cast(CountN) - NCols; while (nblk >= 0) { - ComputeDotProducts( + ComputeDotProducts_BlkBitWidth4_CompFp32( BlkLen, ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, @@ -319,7 +325,7 @@ MlasSQNBitGemmM1KernelNeon( // left over columns less than `NCols`? 
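    // At most NCols - 1 columns remain; they are handled one at a time below
    // with the single-column instantiation of the same dot-product helper,
    // while the main loop above processed columns in groups of NCols so that
    // four dot products could share each load of A.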
nblk += NCols; for (int64_t n = 0; n < nblk; ++n) { - ComputeDotProducts( + ComputeDotProducts_BlkBitWidth4_CompFp32<1>( BlkLen, ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, @@ -339,27 +345,26 @@ MlasSQNBitGemmM1KernelNeon( } } -template MLAS_FORCEINLINE void -MlasQNBitBlkDequantBForSgemmNeon( +QNBitBlkDequantBForSgemm_BlkBitWidth4_CompFp32( size_t BlkLen, float* FpData, - const uint8_t* QuantBData, + const std::byte* QuantBData, const float* QuantBScale, - const uint8_t* QuantBZeroPoint, + const std::byte* QuantBZeroPoint, size_t CountN, size_t CountK, size_t BlockStrideQuantB ) { auto impl0_reference = [&]() { - static_assert(BlkBitWidth == 4); + constexpr size_t BlkBitWidth = 4; float* Dst = FpData; - const uint8_t* QuantBDataCol = QuantBData; + const std::byte* QuantBDataCol = QuantBData; const float* QuantBScaleCol = QuantBScale; - const uint8_t* QuantBZeroPointCol = QuantBZeroPoint; + const std::byte* QuantBZeroPointCol = QuantBZeroPoint; for (size_t n = 0; n < CountN; n += 16) { const size_t nnlen = std::min(CountN - n, size_t{16}); @@ -368,20 +373,20 @@ MlasQNBitBlkDequantBForSgemmNeon( for (size_t k = 0, k_blk_idx = 0; k < CountK; k += BlkLen, k_blk_idx += 1) { const size_t kklen = std::min(CountK - k, BlkLen); - const uint8_t* b_data = + const std::byte* b_data = QuantBDataCol + k_blk_idx * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); const float b_s = QuantBScaleCol[k_blk_idx]; const uint8_t b_z = (QuantBZeroPointCol != nullptr) ? ((k_blk_idx & 1) == 1) - ? QuantBZeroPointCol[k_blk_idx / 2] >> 4 - : QuantBZeroPointCol[k_blk_idx / 2] & 0x0F + ? std::to_integer(QuantBZeroPointCol[k_blk_idx / 2] >> 4) + : std::to_integer(QuantBZeroPointCol[k_blk_idx / 2] & std::byte{0x0F}) : 8; for (size_t kk = 0; kk < kklen; ++kk) { - const uint8_t b_packed = b_data[kk / 2]; - const uint8_t b_byte = ((kk & 1) == 1) ? b_packed >> 4 : b_packed & 0x0F; - const float b_value = (b_byte - b_z) * b_s; + const std::byte b_packed = b_data[kk / 2]; + const std::byte b_byte = ((kk & 1) == 1) ? b_packed >> 4 : b_packed & std::byte{0x0F}; + const float b_value = (std::to_integer(b_byte) - b_z) * b_s; Dst[(k + kk) * 16 + nn] = b_value; } @@ -575,9 +580,9 @@ MLAS_FORCEINLINE void ComputeDotProducts_BlkBitWidth4_CompInt8( size_t BlkLen, const std::byte* QuantARowPtr, - const uint8_t* QuantBDataColPtr, + const std::byte* QuantBDataColPtr, const float* QuantBScaleColPtr, - const uint8_t* QuantBZeroPointColPtr, + const std::byte* QuantBZeroPointColPtr, float* SumPtr, size_t CountK, size_t StrideQuantBData, @@ -597,7 +602,7 @@ ComputeDotProducts_BlkBitWidth4_CompInt8( const std::byte* QuantA = QuantARowPtr; - const uint8_t* QuantBData = QuantBDataColPtr; + const std::byte* QuantBData = QuantBDataColPtr; const float* QuantBScale = QuantBScaleColPtr; size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer @@ -615,11 +620,11 @@ ComputeDotProducts_BlkBitWidth4_CompInt8( int8_t b_zp[NCols]; if (QuantBZeroPointColPtr != nullptr) { UnrolledLoop([&](size_t i) { - const uint8_t zp_packed = + const std::byte zp_packed = QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; b_zp[i] = ((QuantBZeroPointIdx & 1) == 1) - ? static_cast(zp_packed >> 4) - : static_cast(zp_packed & 0x0F); + ? 
std::to_integer(zp_packed >> 4) + : std::to_integer(zp_packed & std::byte{0x0F}); }); } else { UnrolledLoop([&](size_t i) { @@ -635,7 +640,9 @@ ComputeDotProducts_BlkBitWidth4_CompInt8( uint8x8_t bv_packed[NCols]; UnrolledLoop([&](size_t i) { const size_t b_data_block_offset = k_idx_in_blk * BlkBitWidth / 8; - bv_packed[i] = vld1_u8(QuantBData + i * StrideQuantBData + b_data_block_offset); + bv_packed[i] = vld1_u8( + reinterpret_cast(QuantBData) + i * StrideQuantBData + b_data_block_offset + ); }); uint8x8_t bv_u8_unzipped[NCols][2]; @@ -701,9 +708,9 @@ void SQNBitGemmM1Kernel_BlkBitWidth4_CompInt8( size_t BlkLen, const std::byte* QuantA, - const uint8_t* QuantBData, + const std::byte* QuantBData, const float* QuantBScale, - const uint8_t* QuantBZeroPoint, + const std::byte* QuantBZeroPoint, float* C, size_t CountN, size_t CountK, @@ -728,8 +735,10 @@ SQNBitGemmM1Kernel_BlkBitWidth4_CompInt8( int8_t b_zp = 8; if (QuantBZeroPoint != nullptr) { - const uint8_t b_zp_byte = QuantBZeroPoint[n * ((BlockStrideQuantB + 1) / 2) + k_blk / 2]; - b_zp = (k_blk & 1) ? static_cast(b_zp_byte >> 4) : static_cast(b_zp_byte & 0x0F); + const std::byte b_zp_byte = QuantBZeroPoint[n * ((BlockStrideQuantB + 1) / 2) + k_blk / 2]; + b_zp = (k_blk & 1) + ? std::to_integer(b_zp_byte >> 4) + : std::to_integer(b_zp_byte & std::byte{0x0F}); } int32_t qsum = 0; @@ -737,8 +746,12 @@ SQNBitGemmM1Kernel_BlkBitWidth4_CompInt8( const int8_t* QuantABlkData = Q8BlkData(QuantABlkPtr); for (size_t kk = 0; kk < k_blk_len; ++kk) { const int8_t qa = QuantABlkData[kk]; - const uint8_t qb_byte = QuantBData[(n * BlockStrideQuantB * BlkLen + k + kk) / 2]; - const int8_t qb = ((kk & 1) == 1 ? static_cast(qb_byte >> 4) : static_cast(qb_byte & 0x0F)) - b_zp; + const std::byte qb_byte = QuantBData[(n * BlockStrideQuantB * BlkLen + k + kk) / 2]; + const int8_t qb = + ((kk & 1) == 1 + ? 
std::to_integer(qb_byte >> 4) + : std::to_integer(qb_byte & std::byte{0x0F})) - + b_zp; qsum += qa * qb; } @@ -764,9 +777,9 @@ SQNBitGemmM1Kernel_BlkBitWidth4_CompInt8( const float* BiasPtr = Bias; - const uint8_t* QuantBDataColPtr = QuantBData; + const std::byte* QuantBDataColPtr = QuantBData; const float* QuantBScaleColPtr = QuantBScale; - const uint8_t* QuantBZeroPointColPtr = QuantBZeroPoint; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; float* SumPtr = CRowPtr; @@ -830,8 +843,8 @@ SQNBitGemmM1Kernel_BlkBitWidth4_CompInt8( const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon = []() { MLAS_SQNBIT_GEMM_DISPATCH d; - d.SQNBitGemmM1Kernel_BlkBitWidth4_CompFp32 = MlasSQNBitGemmM1KernelNeon<4>; - d.QNBitBlkDequantBForSgemm_BlkBitWidth4_CompFp32 = MlasQNBitBlkDequantBForSgemmNeon<4>; + d.SQNBitGemmM1Kernel_BlkBitWidth4_CompFp32 = SQNBitGemmM1Kernel_BlkBitWidth4_CompFp32; + d.QNBitBlkDequantBForSgemm_BlkBitWidth4_CompFp32 = QNBitBlkDequantBForSgemm_BlkBitWidth4_CompFp32; d.SQNBitGemmM1Kernel_BlkBitWidth4_CompInt8 = SQNBitGemmM1Kernel_BlkBitWidth4_CompInt8; d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8; From 039dd92b6500b22e3f9d2b359dffb587cc257f67 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Thu, 28 Dec 2023 14:23:48 -0800 Subject: [PATCH 14/31] handle CompUndef --- onnxruntime/core/mlas/lib/sqnbitgemm.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp index 3fcbb57aa8533..caed10489ad11 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp @@ -56,7 +56,8 @@ GetSQNBitGemmVariant( if (BlkBitWidth == 4 && (BlkLen == 16 || BlkLen == 32 || BlkLen == 64 || BlkLen == 128 || BlkLen == 256)) { - if (ComputeType == CompFp32) { + if (ComputeType == CompFp32 || + ComputeType == CompUndef) { // treat CompUndef (undefined) as CompFp32 return SQNBitGemmVariant_BitWidth4_CompFp32; } else if (ComputeType == CompInt8 && M == 1) { return SQNBitGemmVariant_BitWidth4_CompInt8; From cb9f42879307c22312215cc8754a622c899df1d1 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Fri, 29 Dec 2023 10:04:39 -0800 Subject: [PATCH 15/31] check if dot product instructions are available before setting SQNBitGemm neon kernel --- onnxruntime/core/mlas/lib/platform.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 8329a34f1338f..1310ed3f384b9 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -482,7 +482,6 @@ Return Value: this->SymmQgemmDispatch = &MlasSymmQgemmS8DispatchNeon; this->ConvSymU8S8Dispatch = &MlasConvSymU8DispatchNeon; this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchNeon; - this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon; // // Check if the processor supports ASIMD dot product instructions. 
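// (In this series the SQNBitGemm NEON dispatch registration moves below,
// into the block guarded by this dot-product check: the CompInt8 kernel is
// compiled with -march=armv8.2-a+dotprod and uses vdotq_s32, so it must not
// be selected on cores without the dot-product extension.)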
@@ -512,6 +511,9 @@ Return Value: this->SymmQgemmDispatch = &MlasSymmQgemmS8DispatchSdot; this->ConvSymU8S8Dispatch = &MlasConvSymU8DispatchDot; this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchDot; + + // MlasSQNBitGemmDispatchNeon has a dependency on dot product instructions + this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon; } #if defined(__linux__) From 437ad52abcc74d6091200a6c8fdc430cc1279c21 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Fri, 29 Dec 2023 10:53:18 -0800 Subject: [PATCH 16/31] try to fix compile issue --- cmake/onnxruntime_mlas.cmake | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 28dfcb91155f6..b995b27123218 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -58,7 +58,7 @@ endif() set(ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas) function(add_jblas) - add_subdirectory(${MLAS_SRC_DIR}/x86_64/jblas jblas) + add_subdirectory(${MLAS_SRC_DIR}/x86_64/jblas jblas) target_link_libraries(onnxruntime_mlas PRIVATE jblas::jblas) target_sources(onnxruntime_mlas PRIVATE ${MLAS_SRC_DIR}/jblas_gemm.cpp @@ -356,6 +356,8 @@ else() ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp ) + set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp + PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod") if (NOT APPLE) set(mlas_platform_srcs ${mlas_platform_srcs} From 241ca27d03e5faaaedcb28011bd133b89c30499e Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Fri, 29 Dec 2023 13:50:17 -0800 Subject: [PATCH 17/31] move zero initialize out of unrolled loop --- onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp index 22dbff546a0ad..6945505a46d7b 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp @@ -665,10 +665,9 @@ ComputeDotProducts_BlkBitWidth4_CompInt8( }); // compute quantized dot product - int32x4_t dot[NCols]; + int32x4_t dot[NCols]{}; UnrolledLoop([&](size_t i) { - const int32x4_t zero_v = vdupq_n_s32(0); - dot[i] = vdotq_s32(zero_v, av, bv[i]); + dot[i] = vdotq_s32(dot[i], av, bv[i]); }); // convert to float and add to `acc` From 53e2ae292eb8a54fb4f903f595f551e03d365ce5 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Tue, 2 Jan 2024 11:33:57 -0800 Subject: [PATCH 18/31] update comment --- onnxruntime/core/mlas/lib/sqnbitgemm.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp index caed10489ad11..689d2b4ea6c21 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp @@ -433,6 +433,7 @@ InitializeWorkspace_CompInt8( { MLAS_UNREFERENCED_PARAMETER(N); + // Note: Multi-threading did not bring a significant performance gain. Using a single thread for simplicity. 
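// For reference, the workspace filled in below holds one quantized row per
// row of A for each GEMM in the batch: BlockCountK blocks of
// Q8BlkSize(BlkLen) bytes each, so consecutive rows are QuantAStride bytes
// apart.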
MLAS_UNREFERENCED_PARAMETER(ThreadPool); const auto QuantizeARow = GetMlasPlatform().SQNBitGemmDispatch->QuantizeARow_CompInt8; @@ -440,7 +441,6 @@ InitializeWorkspace_CompInt8( const size_t BlockCountK = MlasDivRoundup(K, BlkLen); const size_t QuantAStride = BlockCountK * Q8BlkSize(BlkLen); - // TODO use threading for (size_t gemm_idx = 0; gemm_idx < BatchN; ++gemm_idx) { const auto& data = DataParams[gemm_idx]; From d5b26b4d0aabdda456ec375d7f40ddbdfbcea09d Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Tue, 2 Jan 2024 13:03:03 -0800 Subject: [PATCH 19/31] split out float conversion --- onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp index 6945505a46d7b..f644307054612 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp @@ -670,10 +670,16 @@ ComputeDotProducts_BlkBitWidth4_CompInt8( dot[i] = vdotq_s32(dot[i], av, bv[i]); }); - // convert to float and add to `acc` + // convert dot product result to float + float32x4_t dot_f32[NCols]; + UnrolledLoop([&](size_t i) { + dot_f32[i] = vcvtq_f32_s32(dot[i]); + }); + + // multiply dot product result by scale and update accumulator UnrolledLoop([&](size_t i) { const float32x4_t scale_v = vdupq_n_f32(a_scale * b_scale[i]); - acc[i] = vfmaq_f32(acc[i], vcvtq_f32_s32(dot[i]), scale_v); + acc[i] = vfmaq_f32(acc[i], dot_f32[i], scale_v); }); } From 02cf7b37f0f10895d0f4b073efebfa5877cdad86 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Tue, 2 Jan 2024 13:14:47 -0800 Subject: [PATCH 20/31] remove impl0_reference --- .../core/mlas/lib/sqnbitgemm_kernel_neon.cpp | 206 +++++------------- 1 file changed, 56 insertions(+), 150 deletions(-) diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp index f644307054612..489fb80f7a8f8 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp @@ -517,62 +517,17 @@ QuantizeARow_CompInt8( std::byte* QuantA ) { - [[maybe_unused]] auto impl0_reference = [&]() { - const float* ADataRowPtr = A; - std::byte* QuantARowPtr = QuantA; + const float* ADataBlkPtr = A; + std::byte* QuantABlkPtr = QuantA; - for (size_t k = 0, k_blk = 0; k < CountK; k += BlkLen, ++k_blk) { - const size_t k_blk_len = std::min(CountK - k, BlkLen); - - const float* ADataBlkPtr = ADataRowPtr + k; - - // scan block values first to determine scale - - float amax = 0.0f; // max of absolute values of A block - - for (size_t kk = 0; kk < k_blk_len; ++kk) { - float a = ADataBlkPtr[kk]; - amax = std::max(amax, fabsf(a)); - } - - constexpr float range_max = (1 << 7) - 1; - const float scale = amax / range_max; - const float scale_reciprocal = scale != 0.0f ? 
1.0f / scale : 0.0f; - - std::byte* QuantABlkPtr = QuantARowPtr + k_blk * Q8BlkSize(BlkLen); - - Q8BlkScale(QuantABlkPtr) = scale; - int8_t* QuantABlkData = Q8BlkData(QuantABlkPtr); - - for (size_t kk = 0; kk < k_blk_len; ++kk) { - const float q = roundf(ADataBlkPtr[kk] * scale_reciprocal); - QuantABlkData[kk] = static_cast( - std::clamp( - q, - static_cast(std::numeric_limits::min()), - static_cast(std::numeric_limits::max()) - ) - ); - } - } - }; - - [[maybe_unused]] auto impl1 = [&]() { - const float* ADataBlkPtr = A; - std::byte* QuantABlkPtr = QuantA; - - for (size_t k = 0; k < CountK; k += BlkLen) { - const size_t k_blk_len = std::min(CountK - k, BlkLen); - - QuantizeBlock<16>(BlkLen, ADataBlkPtr, k_blk_len, QuantABlkPtr); + for (size_t k = 0; k < CountK; k += BlkLen) { + const size_t k_blk_len = std::min(CountK - k, BlkLen); - ADataBlkPtr += BlkLen; - QuantABlkPtr += Q8BlkSize(BlkLen); - } - }; + QuantizeBlock<16>(BlkLen, ADataBlkPtr, k_blk_len, QuantABlkPtr); - //impl0_reference(); - impl1(); + ADataBlkPtr += BlkLen; + QuantABlkPtr += Q8BlkSize(BlkLen); + } } template @@ -723,120 +678,71 @@ SQNBitGemmM1Kernel_BlkBitWidth4_CompInt8( const float* Bias ) { - [[maybe_unused]] auto impl0_reference = [&]() { - const std::byte* QuantARowPtr = QuantA; - - for (size_t n = 0; n < CountN; ++n) { - float sum = Bias != nullptr ? Bias[n] : 0.0f; - - for (size_t k = 0, k_blk = 0; k < CountK; k += BlkLen, ++k_blk) { - const size_t k_blk_len = std::min(CountK - k, BlkLen); - - const std::byte* QuantABlkPtr = QuantARowPtr + k_blk * Q8BlkSize(BlkLen); - - const float a_scale = Q8BlkScale(QuantABlkPtr); - - const float b_scale = QuantBScale[n * BlockStrideQuantB + k_blk]; - - int8_t b_zp = 8; - if (QuantBZeroPoint != nullptr) { - const std::byte b_zp_byte = QuantBZeroPoint[n * ((BlockStrideQuantB + 1) / 2) + k_blk / 2]; - b_zp = (k_blk & 1) - ? std::to_integer(b_zp_byte >> 4) - : std::to_integer(b_zp_byte & std::byte{0x0F}); - } - - int32_t qsum = 0; - - const int8_t* QuantABlkData = Q8BlkData(QuantABlkPtr); - for (size_t kk = 0; kk < k_blk_len; ++kk) { - const int8_t qa = QuantABlkData[kk]; - const std::byte qb_byte = QuantBData[(n * BlockStrideQuantB * BlkLen + k + kk) / 2]; - const int8_t qb = - ((kk & 1) == 1 - ? 
std::to_integer(qb_byte >> 4) - : std::to_integer(qb_byte & std::byte{0x0F})) - - b_zp; - qsum += qa * qb; - } - - sum += static_cast(qsum) * a_scale * b_scale; - } - - C[n] = sum; - } - }; + constexpr size_t BlkBitWidth = 4; + constexpr size_t NCols = 4; - [[maybe_unused]] auto impl1 = [&]() { - constexpr size_t BlkBitWidth = 4; - constexpr size_t NCols = 4; + const std::byte* QuantARowPtr = QuantA; + float* CRowPtr = C; - const std::byte* QuantARowPtr = QuantA; - float* CRowPtr = C; + const size_t BlockCountK = BlockStrideQuantB; - const size_t BlockCountK = BlockStrideQuantB; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - const size_t StrideQuantBScale = BlockCountK; - const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + const float* BiasPtr = Bias; - const float* BiasPtr = Bias; + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + float* SumPtr = CRowPtr; - float* SumPtr = CRowPtr; + int64_t nblk = static_cast(CountN) - NCols; - int64_t nblk = static_cast(CountN) - NCols; + while (nblk >= 0) { + ComputeDotProducts_BlkBitWidth4_CompInt8( + BlkLen, + QuantARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + BiasPtr + ); - while (nblk >= 0) { - ComputeDotProducts_BlkBitWidth4_CompInt8( - BlkLen, - QuantARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, - StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, - BiasPtr - ); + // move to next `NCols` columns - // move to next `NCols` columns + QuantBDataColPtr += NCols * StrideQuantBData; + QuantBScaleColPtr += NCols * StrideQuantBScale; + if (QuantBZeroPointColPtr != nullptr) { + QuantBZeroPointColPtr += NCols * StrideQuantBZeroPoint; + } - QuantBDataColPtr += NCols * StrideQuantBData; - QuantBScaleColPtr += NCols * StrideQuantBScale; - if (QuantBZeroPointColPtr != nullptr) { - QuantBZeroPointColPtr += NCols * StrideQuantBZeroPoint; - } + BiasPtr += BiasPtr != nullptr ? NCols : 0; + SumPtr += NCols; - BiasPtr += BiasPtr != nullptr ? NCols : 0; - SumPtr += NCols; + nblk -= NCols; + } - nblk -= NCols; - } + // left over columns less than `NCols`? + nblk += NCols; + for (int64_t n = 0; n < nblk; ++n) { + ComputeDotProducts_BlkBitWidth4_CompInt8<1>( + BlkLen, + QuantARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + BiasPtr + ); - // left over columns less than `NCols`? 
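// As in the reference implementation removed here, each block contributes
// sum_k(qa[k] * (qb[k] - b_zp)) * a_scale * b_scale to the running total;
// the vectorized helper accumulates the integer products with vdotq_s32 and
// applies the combined a_scale * b_scale factor in floating point.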
- nblk += NCols; - for (int64_t n = 0; n < nblk; ++n) { - ComputeDotProducts_BlkBitWidth4_CompInt8<1>( - BlkLen, - QuantARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, - StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, - BiasPtr - ); - - // move to next column - - QuantBDataColPtr += StrideQuantBData; - QuantBScaleColPtr += StrideQuantBScale; - if (QuantBZeroPointColPtr != nullptr) { - QuantBZeroPointColPtr += StrideQuantBZeroPoint; - } + // move to next column - BiasPtr += BiasPtr != nullptr ? 1 : 0; - SumPtr += 1; + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if (QuantBZeroPointColPtr != nullptr) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; } - }; - //impl0_reference(); - impl1(); + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } } } // namespace From 5b4a86c7bf83878c05f9aac46c5d112c6a8dc80a Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Tue, 2 Jan 2024 14:55:01 -0800 Subject: [PATCH 21/31] use thread per gemm in prepare workspace fn, reorder include --- onnxruntime/core/mlas/lib/sqnbitgemm.cpp | 7 ++----- onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp | 4 ++-- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp index 689d2b4ea6c21..803ba4d62241a 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp @@ -433,15 +433,12 @@ InitializeWorkspace_CompInt8( { MLAS_UNREFERENCED_PARAMETER(N); - // Note: Multi-threading did not bring a significant performance gain. Using a single thread for simplicity. - MLAS_UNREFERENCED_PARAMETER(ThreadPool); - const auto QuantizeARow = GetMlasPlatform().SQNBitGemmDispatch->QuantizeARow_CompInt8; const size_t BlockCountK = MlasDivRoundup(K, BlkLen); const size_t QuantAStride = BlockCountK * Q8BlkSize(BlkLen); - for (size_t gemm_idx = 0; gemm_idx < BatchN; ++gemm_idx) { + MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) { const auto& data = DataParams[gemm_idx]; const float* ARowPtr = data.A; @@ -453,7 +450,7 @@ InitializeWorkspace_CompInt8( ARowPtr += data.lda; QuantARowPtr += QuantAStride; } - } + }); } struct Operations { diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp index 489fb80f7a8f8..4d7362c01f7ae 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp @@ -15,14 +15,14 @@ Module Name: --*/ +#include "sqnbitgemm.h" + #include #include #include #include -#include "sqnbitgemm.h" - namespace { From 61998ea6d00a54bc6728221c096de9666dc6f4d2 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Wed, 3 Jan 2024 10:15:55 -0800 Subject: [PATCH 22/31] make pointer const --- onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp index d2b99b9adaaae..a1f88e075f34a 100644 --- a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp @@ -198,7 +198,7 @@ class MlasSQNBitGemmTest : public MlasTestBase { bool WithBias, bool Symmetric, bool WithThreadpool) { MLAS_THREADPOOL* Threadpool = WithThreadpool ? 
GetMlasThreadPool() : nullptr; - float* A = BufferA.GetBuffer(K * M); + const float* A = BufferA.GetBuffer(K * M); const float* B = BufferB.GetBuffer(N * K); From d54cbd96ccfebfb72d9196ed2519a0f3726b9263 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Wed, 10 Jan 2024 10:46:53 -0800 Subject: [PATCH 23/31] remove unneeded and --- onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp index 4d7362c01f7ae..3e527d6a846f8 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp @@ -184,7 +184,7 @@ ComputeDotProducts_BlkBitWidth4_CompFp32( uint8x8_t bv_u8_unzipped[NCols][2]; UnrolledLoop([&](size_t i) { bv_u8_unzipped[i][0] = vand_u8(bv_packed[i], LowMask); - bv_u8_unzipped[i][1] = vand_u8(vshr_n_u8(bv_packed[i], 4), LowMask); + bv_u8_unzipped[i][1] = vshr_n_u8(bv_packed[i], 4); }); uint8x8_t bv_u8[NCols][2]; @@ -603,7 +603,7 @@ ComputeDotProducts_BlkBitWidth4_CompInt8( uint8x8_t bv_u8_unzipped[NCols][2]; UnrolledLoop([&](size_t i) { bv_u8_unzipped[i][0] = vand_u8(bv_packed[i], LowMask); - bv_u8_unzipped[i][1] = vand_u8(vshr_n_u8(bv_packed[i], 4), LowMask); + bv_u8_unzipped[i][1] = vshr_n_u8(bv_packed[i], 4); }); int8x16_t bv[NCols]; From 6d88a0b47dd0c8bbab8fe1bd4b55dd57c3fbb5e2 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Wed, 10 Jan 2024 14:05:44 -0800 Subject: [PATCH 24/31] move code from merge conflict --- onnxruntime/core/mlas/lib/sqnbitgemm.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp index 803ba4d62241a..36d1466ca8479 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp @@ -316,7 +316,7 @@ SQNBitGemm_BlkBitWidth4_CompFp32( size_t RowsRemaining = RangeCountM; while (RowsRemaining > 0) { -#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) +#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_LARCH64) auto RowsHandled = GetMlasPlatform().GemmFloatKernel( a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f, true ); From ccaa994778ec82bb2ed3bb3cdea4fa26c18a240a Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Wed, 10 Jan 2024 18:51:48 -0800 Subject: [PATCH 25/31] pack quant b data --- onnxruntime/core/mlas/inc/mlas_qnbit.h | 19 ++++ onnxruntime/core/mlas/lib/sqnbitgemm.cpp | 87 +++++++++++++++++++ .../core/mlas/lib/sqnbitgemm_kernel_neon.cpp | 32 +++---- .../test/mlas/bench/bench_sqnbitgemm.cpp | 11 ++- .../test/mlas/unittest/test_sqnbitgemm.cpp | 25 ++++-- 5 files changed, 149 insertions(+), 25 deletions(-) diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h index 416080a9eea30..1f0899507e9d4 100644 --- a/onnxruntime/core/mlas/inc/mlas_qnbit.h +++ b/onnxruntime/core/mlas/inc/mlas_qnbit.h @@ -135,6 +135,25 @@ MlasSQNBitGemmBatchWorkspaceSize( MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType ); +size_t MLASCALL +MlasSQNBitGemmPackQuantBDataSize( + size_t N, + size_t K, + size_t BlkBitWidth, + size_t BlkLen +); + +void MLASCALL +MlasSQNBitGemmPackQuantBData( + size_t N, + size_t K, + size_t BlkBitWidth, + size_t BlkLen, + const void* QuantBData, + void* PackedQuantBData, + 
MLAS_THREADPOOL* ThreadPool = nullptr +); + /** * @brief Data parameters for NBits GEMM routine * C = A * B diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp index 36d1466ca8479..e7a6f3296c7eb 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp @@ -182,6 +182,93 @@ MlasSQNBitGemmBatchWorkspaceSize( return WorkspaceSize + Alignment - 1; } +namespace +{ +void +SQNBitGemmPackQuantBData_BlkBitWidth4( + size_t N, + size_t K, + size_t BlkLen, + const std::byte* QuantBData, + std::byte* PackedQuantBData, + MLAS_THREADPOOL* ThreadPool +) +{ + MLAS_UNREFERENCED_PARAMETER(ThreadPool); // TODO use ThreadPool + + assert(BlkLen % 16 == 0); + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + + for (size_t n = 0; n < N; ++n) { + for (size_t k_blk = 0; k_blk < BlockCountK; ++k_blk) { + // + // Pack 16 4-bit values (8 bytes) at a time like this: + // + // src: | v0 v1 | v2 v3 | v4 v5 | v6 v7 | v8 v9 | vA vB | vC vD | vE vF | + // => + // dst: | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF | + // + for (size_t kk = 0; kk < BlkLen; kk += 16) { + for (size_t byte_pair_idx = 0; byte_pair_idx < 4; ++byte_pair_idx) { + const std::byte src0 = QuantBData[byte_pair_idx]; + const std::byte src1 = QuantBData[byte_pair_idx + 4]; + + std::byte& dst0 = PackedQuantBData[2 * byte_pair_idx]; + std::byte& dst1 = PackedQuantBData[2 * byte_pair_idx + 1]; + + dst0 = (src0 & std::byte{0x0F}) | ((src1 & std::byte{0x0F}) << 4); + dst1 = (src0 >> 4) | ((src1 >> 4) << 4); + } + + QuantBData += 8; + PackedQuantBData += 8; + } + } + } +} +} + +size_t MLASCALL +MlasSQNBitGemmPackQuantBDataSize( + size_t N, + size_t K, + size_t BlkBitWidth, + size_t BlkLen +) +{ + if (BlkBitWidth == 4) { + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + return PackedQuantBDataSize; + } + + return 0; +} + +void MLASCALL +MlasSQNBitGemmPackQuantBData( + size_t N, + size_t K, + size_t BlkBitWidth, + size_t BlkLen, + const void* QuantBData, + void* PackedQuantBData, + MLAS_THREADPOOL* ThreadPool +) +{ + if (BlkBitWidth == 4) { + SQNBitGemmPackQuantBData_BlkBitWidth4( + N, + K, + BlkLen, + static_cast(QuantBData), + static_cast(PackedQuantBData), + ThreadPool + ); + } +} + namespace { diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp index 3e527d6a846f8..bf791c99a5caf 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp @@ -181,16 +181,10 @@ ComputeDotProducts_BlkBitWidth4_CompFp32( ); }); - uint8x8_t bv_u8_unzipped[NCols][2]; - UnrolledLoop([&](size_t i) { - bv_u8_unzipped[i][0] = vand_u8(bv_packed[i], LowMask); - bv_u8_unzipped[i][1] = vshr_n_u8(bv_packed[i], 4); - }); - uint8x8_t bv_u8[NCols][2]; UnrolledLoop([&](size_t i) { - bv_u8[i][0] = vzip1_u8(bv_u8_unzipped[i][0], bv_u8_unzipped[i][1]); - bv_u8[i][1] = vzip2_u8(bv_u8_unzipped[i][0], bv_u8_unzipped[i][1]); + bv_u8[i][0] = vand_u8(bv_packed[i], LowMask); + bv_u8[i][1] = vshr_n_u8(bv_packed[i], 4); }); // dequantize B @@ -384,9 +378,15 @@ QNBitBlkDequantBForSgemm_BlkBitWidth4_CompFp32( : 8; for (size_t kk = 0; kk < kklen; ++kk) { - const std::byte b_packed = b_data[kk / 2]; - const std::byte b_byte = ((kk & 1) == 1) ? 
diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp
index 3e527d6a846f8..bf791c99a5caf 100644
--- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp
+++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp
@@ -181,16 +181,10 @@ ComputeDotProducts_BlkBitWidth4_CompFp32(
         );
     });
 
-    uint8x8_t bv_u8_unzipped[NCols][2];
-    UnrolledLoop<NCols>([&](size_t i) {
-        bv_u8_unzipped[i][0] = vand_u8(bv_packed[i], LowMask);
-        bv_u8_unzipped[i][1] = vshr_n_u8(bv_packed[i], 4);
-    });
-
     uint8x8_t bv_u8[NCols][2];
     UnrolledLoop<NCols>([&](size_t i) {
-        bv_u8[i][0] = vzip1_u8(bv_u8_unzipped[i][0], bv_u8_unzipped[i][1]);
-        bv_u8[i][1] = vzip2_u8(bv_u8_unzipped[i][0], bv_u8_unzipped[i][1]);
+        bv_u8[i][0] = vand_u8(bv_packed[i], LowMask);
+        bv_u8[i][1] = vshr_n_u8(bv_packed[i], 4);
     });
 
     // dequantize B
@@ -384,9 +378,15 @@ QNBitBlkDequantBForSgemm_BlkBitWidth4_CompFp32(
                         : 8;
 
                 for (size_t kk = 0; kk < kklen; ++kk) {
-                    const std::byte b_packed = b_data[kk / 2];
-                    const std::byte b_byte = ((kk & 1) == 1) ? b_packed >> 4 : b_packed & std::byte{0x0F};
-                    const float b_value = (std::to_integer<int32_t>(b_byte) - b_z) * b_s;
+                    const size_t packed_idx = kk % 16;
+
+                    const bool is_low_half = packed_idx < 8;
+                    const size_t packed_byte_idx = packed_idx % 8;
+                    const size_t packed_range_offset = (kk / 16) * 8;
+
+                    const std::byte b_packed = b_data[packed_range_offset + packed_byte_idx];
+                    const std::byte b_byte = is_low_half ? (b_packed & std::byte{0x0F}) : (b_packed >> 4);
+                    const float b_value = (std::to_integer<int32_t>(b_byte) - b_z) * b_s;
 
                     Dst[(k + kk) * 16 + nn] = b_value;
                 }
@@ -600,16 +600,10 @@ ComputeDotProducts_BlkBitWidth4_CompInt8(
         );
     });
 
-    uint8x8_t bv_u8_unzipped[NCols][2];
-    UnrolledLoop<NCols>([&](size_t i) {
-        bv_u8_unzipped[i][0] = vand_u8(bv_packed[i], LowMask);
-        bv_u8_unzipped[i][1] = vshr_n_u8(bv_packed[i], 4);
-    });
-
     int8x16_t bv[NCols];
     UnrolledLoop<NCols>([&](size_t i) {
-        const int8x8_t lo = vreinterpret_s8_u8(vzip1_u8(bv_u8_unzipped[i][0], bv_u8_unzipped[i][1]));
-        const int8x8_t hi = vreinterpret_s8_u8(vzip2_u8(bv_u8_unzipped[i][0], bv_u8_unzipped[i][1]));
+        const int8x8_t lo = vreinterpret_s8_u8(vand_u8(bv_packed[i], LowMask));
+        const int8x8_t hi = vreinterpret_s8_u8(vshr_n_u8(bv_packed[i], 4));
         bv[i] = vcombine_s8(lo, hi);
     });
 
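In the new dequantization indexing above, each 16-value run of a block occupies 8 packed bytes: values 0..7 of the run live in the low nibbles and values 8..15 in the high nibbles of those same bytes. A scalar sketch of the lookup (illustration only; UnpackValue is not part of the patch):

    #include <cstddef>
    #include <cstdint>

    // Fetch the 4-bit value at index kk within one packed block.
    uint8_t UnpackValue(const uint8_t* packed_blk, size_t kk) {
        const size_t group = kk / 16;  // 8 packed bytes per 16-value group
        const size_t idx = kk % 16;    // 0..7 -> low nibbles, 8..15 -> high nibbles
        const uint8_t byte = packed_blk[group * 8 + idx % 8];
        return (idx < 8) ? (byte & 0x0F) : (byte >> 4);
    }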
diff --git a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp
index d0b83812b63b1..2a56d37b899f8 100644
--- a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp
+++ b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp
@@ -60,10 +60,19 @@ void SQNBITGEMM(benchmark::State& state) {
     Workspace = std::make_unique<std::byte[]>(WorkspaceSize);
   }
 
+  std::unique_ptr<std::byte[]> PackedQuantBData;
+  if (const auto PackedQuantBDataSize = MlasSQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen);
+      PackedQuantBDataSize > 0) {
+    PackedQuantBData = std::make_unique<std::byte[]>(PackedQuantBDataSize);
+    MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, QuantBData.data(), PackedQuantBData.get(), tp.get());
+  }
+
   MLAS_SQNBIT_GEMM_DATA_PARAMS params{};
   params.A = A.data();
   params.lda = K;
-  params.QuantBData = QuantBData.data();
+  params.QuantBData = PackedQuantBData != nullptr
+                          ? static_cast<const void*>(PackedQuantBData.get())
+                          : static_cast<const void*>(QuantBData.data());
   params.QuantBScale = QuantBScale.data();
   params.QuantBZeroPoint = Symmetric ? nullptr : QuantBZeroPoint.data();
   params.Bias = nullptr;
diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp
index a1f88e075f34a..4fb8ab41745d5 100644
--- a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp
+++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp
@@ -41,6 +41,7 @@ class MlasSQNBitGemmTest : public MlasTestBase {
  MatrixGuardBuffer<float> BufferQuantAScale;
  MatrixGuardBuffer<float> BufferB;
  MatrixGuardBuffer<uint8_t> BufferQuantBData;
+ MatrixGuardBuffer<uint8_t> BufferPackedQuantBData;
  MatrixGuardBuffer<uint8_t> BufferQuantBZeroPoint;
  MatrixGuardBuffer<float> BufferQuantBScale;
  MatrixGuardBuffer<float> BufferDequantizedB;
@@ -54,9 +55,10 @@ class MlasSQNBitGemmTest : public MlasTestBase {
       size_t K,
       const float* A,
       size_t lda,
-      const uint8_t* QuantBData,
+      const void* QuantBData,
+      const void* PackedQuantBData,
       const float* QuantBScale,
-      const uint8_t* QuantBZeroPoint,
+      const void* QuantBZeroPoint,
       const float* Bias,
       float* C,
       size_t ldc,
@@ -69,7 +71,7 @@ class MlasSQNBitGemmTest : public MlasTestBase {
     params.Bias = Bias;
     params.C = C;
     params.ldc = ldc;
-    params.QuantBData = QuantBData;
+    params.QuantBData = PackedQuantBData != nullptr ? PackedQuantBData : QuantBData;
     params.QuantBScale = QuantBScale;
     params.QuantBZeroPoint = QuantBZeroPoint;
     params.PostProcessor = nullptr;
@@ -256,6 +258,13 @@ class MlasSQNBitGemmTest : public MlasTestBase {
       Workspace = BufferWorkspace.GetBuffer(WorkspaceSize);
     }
 
+    void* PackedQuantBData = nullptr;
+    if (const auto PackedQuantBDataSize = MlasSQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen);
+        PackedQuantBDataSize > 0) {
+      PackedQuantBData = BufferPackedQuantBData.GetBuffer(PackedQuantBDataSize);
+      MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, QuantBData, PackedQuantBData, GetMlasThreadPool());
+    }
+
     if (ComputeType == CompFp32) {
       CallReferenceGemm_CompFp32(M, N, K, A, QuantBData, QuantBScale, QuantBZeroPoint, Bias, CReference);
     } else if (ComputeType == CompInt8) {
@@ -265,8 +274,14 @@ class MlasSQNBitGemmTest : public MlasTestBase {
           << ComputeType << " (" << ComputeTypeName(ComputeType) << ")";
     }
 
-    CallGemm(M, N, K, A, /* lda */ K, QuantBData, QuantBScale, QuantBZeroPoint, Bias, C, /* ldc */ N, Workspace,
-             ComputeType, Threadpool);
+    CallGemm(M, N, K,
+             A, /* lda */ K,
+             QuantBData, PackedQuantBData, QuantBScale, QuantBZeroPoint,
+             Bias,
+             C, /* ldc */ N,
+             Workspace,
+             ComputeType,
+             Threadpool);
 
     size_t f = 0;
     for (size_t m = 0; m < M; m++) {
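Taken together, these call sites show the intended usage pattern for the new packing API: query the packed size, pack only when the size is non-zero, and point MLAS_SQNBIT_GEMM_DATA_PARAMS::QuantBData at whichever buffer applies. A condensed sketch (illustration only; setup of the other buffers is elided and QuantBData is assumed to be a raw pointer to the unpacked data):

    const size_t PackedSize = MlasSQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen);
    std::unique_ptr<std::byte[]> Packed;
    if (PackedSize > 0) {
      Packed = std::make_unique<std::byte[]>(PackedSize);
      MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, QuantBData, Packed.get(), ThreadPool);
    }

    MLAS_SQNBIT_GEMM_DATA_PARAMS Params{};
    Params.QuantBData = (Packed != nullptr) ? static_cast<const void*>(Packed.get())
                                            : static_cast<const void*>(QuantBData);
    // ... fill in A, lda, scales, zero points, C, ldc, then:
    // MlasSQNBitGemmBatch(M, N, K, BatchN, BlkBitWidth, BlkLen, ComputeType, &Params, Workspace, ThreadPool);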
From cff3cb47cb1fd56ebbad146fc7ecaf4d5bd994c6 Mon Sep 17 00:00:00 2001
From: edgchen1 <18449977+edgchen1@users.noreply.github.com>
Date: Wed, 10 Jan 2024 19:59:33 -0800
Subject: [PATCH 26/31] get matmulnbits working, add docs

---
 .../cpu/quantization/matmul_nbits.cc          | 118 +++++++++++-------
 onnxruntime/core/mlas/inc/mlas_qnbit.h        |  30 +++++
 onnxruntime/core/mlas/lib/sqnbitgemm.cpp      |   2 +-
 3 files changed, 105 insertions(+), 45 deletions(-)

diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
index cbe90459d7719..406c73c95d444 100644
--- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
+++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
@@ -64,6 +64,9 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat
   if (!all_constant_) {
     return Status::OK();
   }
+
+#if defined(MLAS_JBLAS)
+
   auto compt_type = static_cast<MLAS_SQNBIT_COMPUTE_TYPE>(accuracy_level_);
   MLAS_THREADPOOL* pool = NULL;
   if (input_idx == 1) {
@@ -101,12 +104,32 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat
     is_packed = true;
   }
 
+#else  // defined(MLAS_JBLAS)
+
+  if (input_idx == 1) {
+    packed_b_size_ = MlasSQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_);
+    if (packed_b_size_ == 0) return Status::OK();
+    auto qptr = tensor.DataRaw();
+    packed_b_ = IAllocator::MakeUniquePtr<void>(alloc, packed_b_size_, true);
+    MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, qptr, packed_b_.get());
+    if (prepacked_weights) {
+      prepacked_weights->buffers_.push_back(std::move(packed_b_));
+      prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
+    }
+    is_packed = true;
+  }
+
+#endif  // defined(MLAS_JBLAS)
+
   return Status::OK();
 }
 
 Status MatMulNBits::UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers, int input_idx,
                                               /*out*/ bool& used_shared_buffers) {
   used_shared_buffers = false;
+
+#if defined(MLAS_JBLAS)
+
   // Pack three tensors into one buffer
   if (input_idx == 1) {
     used_shared_buffers = true;
@@ -120,6 +143,15 @@ Status MatMulNBits::UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prep
     used_shared_buffers = true;
     packed_b_ = std::move(prepacked_buffers[0]);
   }
+
+#else  // defined(MLAS_JBLAS)
+
+  if (input_idx == 1) {
+    used_shared_buffers = true;
+    packed_b_ = std::move(prepacked_buffers[0]);
+  }
+
+#endif  // defined(MLAS_JBLAS)
 
   return Status::OK();
 }
@@ -129,6 +161,8 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
   const Tensor* a = ctx->Input<Tensor>(0);
   const auto* a_data = a->Data<float>();
 
+#if defined(MLAS_JBLAS)
+
   if (packed_b_.get()) {
     TensorShape b_shape({static_cast<int64_t>(N_), static_cast<int64_t>(K_)});
 
@@ -166,10 +200,10 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
     return Status::OK();
   }
 
-  const Tensor* b = ctx->Input<Tensor>(1);
+#endif  // defined(MLAS_JBLAS)
+
   const Tensor* scales = ctx->Input<Tensor>(2);
   const Tensor* zero_points = ctx->Input<Tensor>(3);
-  const uint8_t* b_data = b->Data<uint8_t>();
   const auto* scales_data = scales->Data<float>();
   const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->Data<uint8_t>();
 
@@ -181,8 +215,9 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
   Tensor* y = ctx->Output(0, helper.OutputShape());
 
   // Bail out early if the output is going to be empty
-  if (y->Shape().Size() == 0)
+  if (y->Shape().Size() == 0) {
     return Status::OK();
+  }
 
   auto* y_data = y->MutableData<float>();
 
@@ -192,51 +227,46 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
   const size_t K = static_cast<size_t>(helper.K());
   const size_t lda = helper.Lda(false);
 
-  for (int64_t accuracy_level = accuracy_level_;
-       accuracy_level >= static_cast<int64_t>(CompMostAccurate);
-       --accuracy_level) {
-    const auto compute_type = static_cast<MLAS_SQNBIT_GEMM_COMPUTE_TYPE>(accuracy_level);
-    if (MlasIsSQNBitGemmAvailable(M, N, K, nbits_, block_size_, compute_type)) {
-      // number of bytes or elements between adjacent matrices
-      size_t b_data_matrix_stride_in_bytes, b_scale_matrix_stride, b_zero_point_matrix_stride_in_bytes;
-      MlasBlockwiseQuantizedBufferSizes(static_cast<int>(nbits_), static_cast<int>(block_size_), /* columnwise */ true,
-                                        static_cast<int>(K), static_cast<int>(N),
-                                        b_data_matrix_stride_in_bytes, b_scale_matrix_stride,
-                                        &b_zero_point_matrix_stride_in_bytes);
-
-      const size_t b_matrix_size = K * N;
-
-      IAllocatorUniquePtr<std::byte> workspace{};
-      if (const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, batch_count,
-                                                                         nbits_, block_size_, compute_type);
-          workspace_size > 0) {
-        AllocatorPtr allocator;
-        ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator));
-        workspace = IAllocator::MakeUniquePtr<std::byte>(allocator, workspace_size);
-      }
-
-      InlinedVector<MLAS_SQNBIT_GEMM_DATA_PARAMS> data(batch_count);
-      for (size_t i = 0; i < batch_count; ++i) {
-        const size_t b_matrix_offset = helper.RightOffsets()[i] / b_matrix_size;
-
-        data[i].A = a_data + helper.LeftOffsets()[i];
-        data[i].lda = lda;
-        data[i].QuantBData = b_data + b_matrix_offset * b_data_matrix_stride_in_bytes;
-        data[i].QuantBScale = scales_data + b_matrix_offset * b_scale_matrix_stride;
-        data[i].QuantBZeroPoint = zero_points_data != nullptr
-                                      ? zero_points_data + b_matrix_offset * b_zero_point_matrix_stride_in_bytes
-                                      : nullptr;
-        data[i].C = y_data + helper.OutputOffsets()[i];
-        data[i].ldc = N;
+  const bool has_single_b_matrix = std::all_of(helper.RightOffsets().begin(), helper.RightOffsets().end(),
+                                               [](size_t offset) { return offset == 0; });
+
+  if (has_single_b_matrix && packed_b_) {
+    for (int64_t accuracy_level = accuracy_level_;
+         accuracy_level >= static_cast<int64_t>(CompMostAccurate);
+         --accuracy_level) {
+      const auto compute_type = static_cast<MLAS_SQNBIT_GEMM_COMPUTE_TYPE>(accuracy_level);
+      if (MlasIsSQNBitGemmAvailable(M, N, K, nbits_, block_size_, compute_type)) {
+        IAllocatorUniquePtr<std::byte> workspace{};
+        if (const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, batch_count,
+                                                                           nbits_, block_size_, compute_type);
+            workspace_size > 0) {
+          AllocatorPtr allocator;
+          ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator));
+          workspace = IAllocator::MakeUniquePtr<std::byte>(allocator, workspace_size);
+        }
+
+        InlinedVector<MLAS_SQNBIT_GEMM_DATA_PARAMS> data(batch_count);
+        for (size_t i = 0; i < batch_count; ++i) {
+          data[i].A = a_data + helper.LeftOffsets()[i];
+          data[i].lda = lda;
+          data[i].QuantBData = packed_b_.get();
+          data[i].QuantBScale = scales_data;
+          data[i].QuantBZeroPoint = zero_points_data;
+          data[i].C = y_data + helper.OutputOffsets()[i];
+          data[i].ldc = N;
+        }
+
+        MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type, data.data(), workspace.get(),
+                            thread_pool);
+
+        return Status::OK();
       }
-
-      MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type, data.data(), workspace.get(),
-                          thread_pool);
-
-      return Status::OK();
     }
   }
 
+  const Tensor* b = ctx->Input<Tensor>(1);
+  const uint8_t* b_data = b->Data<uint8_t>();
+
   const size_t ldb = helper.Ldb(true);
 
   AllocatorPtr allocator;
diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h
index 1f0899507e9d4..bc0bfc92c85a0 100644
--- a/onnxruntime/core/mlas/inc/mlas_qnbit.h
+++ b/onnxruntime/core/mlas/inc/mlas_qnbit.h
@@ -65,6 +65,13 @@ struct MLAS_SQNBIT_GEMM_DATA_PARAMS {
  *
  * Call MlasIsSQNBitGemmAvailable() with the same parameters to determine whether this function may be called.
  *
+ * Call MlasSQNBitGemmPackQuantBDataSize() with the same parameters to determine whether
+ * MLAS_SQNBIT_GEMM_DATA_PARAMS::QuantBData in `DataParams` should point to a buffer packed with
+ * MlasSQNBitGemmPackQuantBData().
+ *
+ * Call MlasSQNBitGemmBatchWorkspaceSize() with the same parameters to determine whether `Workspace` should
+ * point to an intermediate workspace buffer.
+ *
  * @param[in] M row size of matrix A and C
  * @param[in] N column size of matrix B and C
  * @param[in] K column size of matrix A and row size of matrix B
@@ -135,6 +142,18 @@ MlasSQNBitGemmBatchWorkspaceSize(
     MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
 );
 
+/**
+ * @brief Gets the size in bytes of the packed quantized B data.
+ * If non-zero, the quantized B data must first be packed by calling MlasSQNBitGemmPackQuantBData() with a buffer of
+ * this size, and then that packed quantized B data buffer must be passed to MlasSQNBitGemmBatch().
+ * If zero, MlasSQNBitGemmPackQuantBData() must not be called and the quantized B data must be directly passed to
+ * MlasSQNBitGemmBatch().
+ *
+ * @param[in] N column size of matrix B and C
+ * @param[in] K column size of matrix A and row size of matrix B
+ * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints)
+ * @param[in] BlkLen number of quantized values per block
+ */
 size_t MLASCALL
 MlasSQNBitGemmPackQuantBDataSize(
     size_t N,
@@ -143,6 +162,17 @@ MlasSQNBitGemmPackQuantBDataSize(
     size_t BlkLen
 );
 
+/**
+ * @brief Packs the quantized B data in a format that the kernel expects.
+ *
+ * @param[in] N column size of matrix B and C
+ * @param[in] K column size of matrix A and row size of matrix B
+ * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints)
+ * @param[in] BlkLen number of quantized values per block
+ * @param[in] QuantBData quantized B data
+ * @param[out] PackedQuantBData packed quantized B data
+ * @param[in] ThreadPool optional thread pool to use
+ */
 void MLASCALL
 MlasSQNBitGemmPackQuantBData(
     size_t N,
diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp
index e7a6f3296c7eb..b17a154331fd1 100644
--- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp
+++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp
@@ -227,7 +227,7 @@ SQNBitGemmPackQuantBData_BlkBitWidth4(
         }
     }
 }
-}
+}  // namespace
 
 size_t MLASCALL
 MlasSQNBitGemmPackQuantBDataSize(

From 33e6dd903804ff1e983a90778b8ad4d97985bf24 Mon Sep 17 00:00:00 2001
From: edgchen1 <18449977+edgchen1@users.noreply.github.com>
Date: Thu, 11 Jan 2024 10:33:55 -0800
Subject: [PATCH 27/31] use threadpool to pack b data

---
 onnxruntime/core/mlas/lib/sqnbitgemm.cpp | 22 ++++++++++++++++------
 1 file changed, 16 insertions(+), 6 deletions(-)

diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp
index b17a154331fd1..47f48f113e305 100644
--- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp
+++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp
@@ -189,19 +189,29 @@ SQNBitGemmPackQuantBData_BlkBitWidth4(
     size_t N,
     size_t K,
     size_t BlkLen,
-    const std::byte* QuantBData,
-    std::byte* PackedQuantBData,
+    const std::byte* QuantBDataBegin,
+    std::byte* PackedQuantBDataBegin,
     MLAS_THREADPOOL* ThreadPool
 )
 {
-    MLAS_UNREFERENCED_PARAMETER(ThreadPool);  // TODO use ThreadPool
+    constexpr size_t BlkBitWidth = 4;
 
     assert(BlkLen % 16 == 0);
 
     const size_t BlockCountK = MlasDivRoundup(K, BlkLen);
+    const size_t BlkDataSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
+    const size_t Iterations = N * BlockCountK;  // one iteration per block
+
+    MlasTrySimpleParallel(
+        ThreadPool, Iterations,
+        [&](ptrdiff_t tid) {
+            const size_t n = tid / BlockCountK;
+            const size_t k_blk = tid % BlockCountK;
+
+            const size_t data_offset = n * BlockCountK * BlkDataSize + k_blk * BlkDataSize;
+            const std::byte* QuantBData = QuantBDataBegin + data_offset;
+            std::byte* PackedQuantBData = PackedQuantBDataBegin + data_offset;
 
-    for (size_t n = 0; n < N; ++n) {
-        for (size_t k_blk = 0; k_blk < BlockCountK; ++k_blk) {
             //
             // Pack 16 4-bit values (8 bytes) at a time like this:
             //
@@ -224,6 +234,6 @@ SQNBitGemmPackQuantBData_BlkBitWidth4(
                 QuantBData += 8;
                 PackedQuantBData += 8;
             }
         }
-    }
+    );
 }
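MlasTrySimpleParallel distributes a flat iteration space across threads, so the nested (n, k_blk) loops become Iterations = N * BlockCountK independent work items with tid = n * BlockCountK + k_blk; each block is packed independently, so no synchronization is needed. A serial equivalent of the index decomposition (illustration only):

    for (ptrdiff_t tid = 0; tid < static_cast<ptrdiff_t>(N * BlockCountK); ++tid) {
        const size_t n = static_cast<size_t>(tid) / BlockCountK;      // column index
        const size_t k_blk = static_cast<size_t>(tid) % BlockCountK;  // block index within the column
        // ... pack block (n, k_blk) exactly as in the lambda body above ...
    }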
From 4cd2474cea94dd806edeb8510c9e5530d337a335 Mon Sep 17 00:00:00 2001
From: edgchen1 <18449977+edgchen1@users.noreply.github.com>
Date: Thu, 11 Jan 2024 14:19:36 -0800
Subject: [PATCH 28/31] shorten names, update docs

---
 onnxruntime/core/mlas/lib/sqnbitgemm.cpp      | 20 +++++++-------
 onnxruntime/core/mlas/lib/sqnbitgemm.h        | 27 +++++++++----------
 .../core/mlas/lib/sqnbitgemm_kernel_neon.cpp  | 12 ++++-----
 3 files changed, 29 insertions(+), 30 deletions(-)

diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp
index 47f48f113e305..0f36f7f37beef 100644
--- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp
+++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp
@@ -88,11 +88,11 @@ MlasIsSQNBitGemmAvailable(
     switch (Variant) {
         case SQNBitGemmVariant_BitWidth4_CompFp32: {
-            return Dispatch->SQNBitGemmM1Kernel_BlkBitWidth4_CompFp32 != nullptr &&
-                   Dispatch->QNBitBlkDequantBForSgemm_BlkBitWidth4_CompFp32 != nullptr;
+            return Dispatch->SQ4BitGemmM1Kernel_CompFp32 != nullptr &&
+                   Dispatch->Q4BitBlkDequantBForSgemm_CompFp32 != nullptr;
         }
         case SQNBitGemmVariant_BitWidth4_CompInt8: {
-            return Dispatch->SQNBitGemmM1Kernel_BlkBitWidth4_CompInt8 != nullptr &&
+            return Dispatch->SQ4BitGemmM1Kernel_CompInt8 != nullptr &&
                    Dispatch->QuantizeARow_CompInt8 != nullptr;
         }
         default: {
@@ -320,7 +320,7 @@ typedef void(SQNBitGemmFn)(
 );
 
 void
-SQNBitGemm_BlkBitWidth4_CompFp32(
+SQ4BitGemm_CompFp32(
     const size_t BlkLen,
     const size_t K,
     const MLAS_SQNBIT_GEMM_DATA_PARAMS* const DataParams,
@@ -368,7 +368,7 @@
             float* c_blk = C + n;
             const float* bias = (Bias == nullptr) ? nullptr : Bias + n;
 
-            GetMlasPlatform().SQNBitGemmDispatch->SQNBitGemmM1Kernel_BlkBitWidth4_CompFp32(
+            GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompFp32(
                 BlkLen,
                 a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias
             );
@@ -406,7 +406,7 @@
         float* c_blk = C + n;
         const float* bias = (Bias == nullptr) ? nullptr : Bias + n;
 
-        GetMlasPlatform().SQNBitGemmDispatch->QNBitBlkDequantBForSgemm_BlkBitWidth4_CompFp32(
+        GetMlasPlatform().SQNBitGemmDispatch->Q4BitBlkDequantBForSgemm_CompFp32(
             BlkLen,
             dequant_b, b_col, b_col_scale, b_col_zp, CountN, K, k_blks
         );
@@ -439,7 +439,7 @@
 void
-SQNBitGemm_BlkBitWidth4_CompInt8(
+SQ4BitGemm_CompInt8(
     const size_t BlkLen,
     const size_t K,
     const MLAS_SQNBIT_GEMM_DATA_PARAMS* const DataParams,
@@ -485,7 +485,7 @@
             float* c_blk = C + n;
             const float* bias = (Bias == nullptr) ? nullptr : Bias + n;
 
-            GetMlasPlatform().SQNBitGemmDispatch->SQNBitGemmM1Kernel_BlkBitWidth4_CompInt8(
+            GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompInt8(
                 BlkLen,
                 a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias
             );
@@ -558,10 +558,10 @@ struct Operations {
 constexpr auto OperationMap = []() {
     std::array<Operations, SQNBitGemmVariantCount> ops;
 
-    ops[SQNBitGemmVariant_BitWidth4_CompFp32].SQNBitGemm = SQNBitGemm_BlkBitWidth4_CompFp32;
+    ops[SQNBitGemmVariant_BitWidth4_CompFp32].SQNBitGemm = SQ4BitGemm_CompFp32;
 
     ops[SQNBitGemmVariant_BitWidth4_CompInt8].InitializeWorkspace = InitializeWorkspace_CompInt8;
-    ops[SQNBitGemmVariant_BitWidth4_CompInt8].SQNBitGemm = SQNBitGemm_BlkBitWidth4_CompInt8;
+    ops[SQNBitGemmVariant_BitWidth4_CompInt8].SQNBitGemm = SQ4BitGemm_CompInt8;
 
     return ops;
 }();
diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.h b/onnxruntime/core/mlas/lib/sqnbitgemm.h
index a02aa6987d518..a66db79dc290a 100644
--- a/onnxruntime/core/mlas/lib/sqnbitgemm.h
+++ b/onnxruntime/core/mlas/lib/sqnbitgemm.h
@@ -22,11 +22,11 @@ Module Name:
 
 #pragma once
 
+#include <cassert>
+
 #include "mlas_qnbit.h"
 #include "mlasi.h"
 
-#include <cassert>
-
 constexpr MLAS_FORCEINLINE size_t
 MlasQNBitBlkDataSizeInBytes(size_t BlkBitWidth, size_t BlkLen)
 {
@@ -104,7 +104,7 @@ struct MLAS_SQNBIT_GEMM_DISPATCH {
     //
 
     /**
-     * @brief Multiply float matrix A with quantized n-bit integer matrix B.
+     * @brief Multiply float matrix A with quantized 4-bit integer matrix B.
      *        B is block quantized and column major.
      *        This kernel handles the special case where M, the number of rows of A and C, is 1.
      *
@@ -119,7 +119,7 @@ struct MLAS_SQNBIT_GEMM_DISPATCH {
      * @param       BlockStrideQuantB   Number of blocks between adjacent columns of the quantized B matrix.
     * @param       Bias                Bias vector of length N.
      */
-    typedef void(SQNBitGemmM1Kernel_BlkBitWidth4_CompFp32_Fn)(
+    typedef void(SQ4BitGemmM1Kernel_CompFp32_Fn)(
         size_t BlkLen,
         const float* A,
         const std::byte* QuantBData,
@@ -132,13 +132,12 @@ struct MLAS_SQNBIT_GEMM_DISPATCH {
         const float* Bias
     );
 
-    SQNBitGemmM1Kernel_BlkBitWidth4_CompFp32_Fn* SQNBitGemmM1Kernel_BlkBitWidth4_CompFp32 = nullptr;
+    SQ4BitGemmM1Kernel_CompFp32_Fn* SQ4BitGemmM1Kernel_CompFp32 = nullptr;
 
     /**
      * @brief Dequantize B into the format expected by the Sgemm kernel.
-     *        B is block quantized and column major.
-     *        This is equivalent to dequantizing B and then running
-     *        MlasSgemmCopyPackB.
+     *        B is a quantized 4-bit integer matrix that is block quantized and column major.
+     *        This is equivalent to dequantizing B and then running MlasSgemmCopyPackB.
      *
      * @param       BlkLen              Number of values in a block.
      * @param[out]  FpData              Supplies the output buffer for the dequantized B float data.
@@ -149,7 +148,7 @@ struct MLAS_SQNBIT_GEMM_DISPATCH {
      * @param       CountK              Number of rows of B.
      * @param       BlockStrideQuantB   Number of blocks between adjacent columns of the quantized B matrix.
      */
-    typedef void(QNBitBlkDequantBForSgemm_BlkBitWidth4_CompFp32_Fn)(
+    typedef void(Q4BitBlkDequantBForSgemm_CompFp32_Fn)(
         size_t BlkLen,
         float* FpData,
         const std::byte* QuantBData,
@@ -160,14 +159,14 @@ struct MLAS_SQNBIT_GEMM_DISPATCH {
     );
 
-    QNBitBlkDequantBForSgemm_BlkBitWidth4_CompFp32_Fn* QNBitBlkDequantBForSgemm_BlkBitWidth4_CompFp32 = nullptr;
+    Q4BitBlkDequantBForSgemm_CompFp32_Fn* Q4BitBlkDequantBForSgemm_CompFp32 = nullptr;
 
     //
     // CompInt8 kernel function prototypes.
     //
 
     /**
-     * @brief Multiply quantized int8 matrix A with quantized n-bit integer matrix B.
+     * @brief Multiply quantized 8-bit integer matrix A with quantized 4-bit integer matrix B.
      *        A and B are block quantized and B is column major.
      *        This kernel handles the special case where M, the number of rows of A and C, is 1.
      *
@@ -183,7 +182,7 @@ struct MLAS_SQNBIT_GEMM_DISPATCH {
      * @param       BlockStrideQuantB   Number of blocks between adjacent columns of the quantized B matrix.
      * @param       Bias                Bias vector of length N.
      */
-    typedef void(SQNBitGemmM1Kernel_BlkBitWidth4_CompInt8_Fn)(
+    typedef void(SQ4BitGemmM1Kernel_CompInt8_Fn)(
         size_t BlkLen,
         const std::byte* QuantA,
         const std::byte* QuantBData,
@@ -196,10 +195,10 @@ struct MLAS_SQNBIT_GEMM_DISPATCH {
         const float* Bias
     );
 
-    SQNBitGemmM1Kernel_BlkBitWidth4_CompInt8_Fn* SQNBitGemmM1Kernel_BlkBitWidth4_CompInt8 = nullptr;
+    SQ4BitGemmM1Kernel_CompInt8_Fn* SQ4BitGemmM1Kernel_CompInt8 = nullptr;
 
     /**
-     * @brief Block quantize values from one row of matrix A from float to int8.
+     * @brief Block quantize values from one row of matrix A from floats to quantized 8-bit integers.
      *
      * @param       BlkLen              Number of values in a block.
     * @param       A                   Supplies the A matrix.
diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp
index bf791c99a5caf..45e7809a07575 100644
--- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp
+++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp
@@ -259,7 +259,7 @@ ComputeDotProducts_BlkBitWidth4_CompFp32(
 }
 
 MLAS_FORCEINLINE void
-SQNBitGemmM1Kernel_BlkBitWidth4_CompFp32(
+SQ4BitGemmM1Kernel_CompFp32(
     size_t BlkLen,
     const float* A,
     const std::byte* QuantBData,
@@ -340,7 +340,7 @@
 }
 
 MLAS_FORCEINLINE void
-QNBitBlkDequantBForSgemm_BlkBitWidth4_CompFp32(
+Q4BitBlkDequantBForSgemm_CompFp32(
     size_t BlkLen,
     float* FpData,
     const std::byte* QuantBData,
@@ -659,7 +659,7 @@ ComputeDotProducts_BlkBitWidth4_CompInt8(
 
 MLAS_FORCEINLINE
 void
-SQNBitGemmM1Kernel_BlkBitWidth4_CompInt8(
+SQ4BitGemmM1Kernel_CompInt8(
     size_t BlkLen,
     const std::byte* QuantA,
     const std::byte* QuantBData,
@@ -748,9 +748,9 @@ SQNBitGemmM1Kernel_BlkBitWidth4_CompInt8(
 const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon = []() {
     MLAS_SQNBIT_GEMM_DISPATCH d;
 
-    d.SQNBitGemmM1Kernel_BlkBitWidth4_CompFp32 = SQNBitGemmM1Kernel_BlkBitWidth4_CompFp32;
-    d.QNBitBlkDequantBForSgemm_BlkBitWidth4_CompFp32 = QNBitBlkDequantBForSgemm_BlkBitWidth4_CompFp32;
-    d.SQNBitGemmM1Kernel_BlkBitWidth4_CompInt8 = SQNBitGemmM1Kernel_BlkBitWidth4_CompInt8;
+    d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32;
+    d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32;
+    d.SQ4BitGemmM1Kernel_CompInt8 = SQ4BitGemmM1Kernel_CompInt8;
     d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8;
 
     return d;

From 9244a3f16176bfca7ecb7e93ae75a96f10b290be Mon Sep 17 00:00:00 2001
From: edgchen1 <18449977+edgchen1@users.noreply.github.com>
Date: Thu, 11 Jan 2024 15:44:24 -0800
Subject: [PATCH 29/31] rename another function, add check for implementation
 in MlasSQNBitGemmPackQuantBDataSize()

---
 onnxruntime/core/mlas/lib/sqnbitgemm.cpp | 20 ++++++++++++++++++--
 1 file changed, 18 insertions(+), 2 deletions(-)

diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp
index 0f36f7f37beef..7d877848017fe 100644
--- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp
+++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp
@@ -184,8 +184,9 @@ MlasSQNBitGemmBatchWorkspaceSize(
 
 namespace
 {
+
 void
-SQNBitGemmPackQuantBData_BlkBitWidth4(
+SQ4BitGemmPackQuantBData(
     size_t N,
     size_t K,
     size_t BlkLen,
@@ -237,6 +238,7 @@ SQ4BitGemmPackQuantBData(
         }
     );
 }
+
 }  // namespace
 
 size_t MLASCALL
 MlasSQNBitGemmPackQuantBDataSize(
@@ -247,6 +249,20 @@ MlasSQNBitGemmPackQuantBDataSize(
     size_t BlkLen
 )
 {
+    // Ensure that a general implementation is available on this platform.
+    // For now, all implementations share the same packed format.
+    {
+        // Currently, there are implementations specific to M = 1, so pick a more general M > 1.
+        constexpr size_t M = 2;
+        // A CompUndef implementation should be available if any is available.
+        constexpr MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType = CompUndef;
+        const bool HasGeneralImplementation =
+            MlasIsSQNBitGemmAvailable(M, N, K, BlkBitWidth, BlkLen, ComputeType);
+        if (!HasGeneralImplementation) {
+            return 0;
+        }
+    }
+
     if (BlkBitWidth == 4) {
         const size_t BlockCountK = MlasDivRoundup(K, BlkLen);
         const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
         return PackedQuantBDataSize;
@@ -268,7 +284,7 @@ MlasSQNBitGemmPackQuantBData(
 )
 {
     if (BlkBitWidth == 4) {
-        SQNBitGemmPackQuantBData_BlkBitWidth4(
+        SQ4BitGemmPackQuantBData(
             N,
             K,
             BlkLen,

From 86f84ea0fab60dc8efbb8a90b5cdad7b86993f29 Mon Sep 17 00:00:00 2001
From: edgchen1 <18449977+edgchen1@users.noreply.github.com>
Date: Thu, 11 Jan 2024 16:03:00 -0800
Subject: [PATCH 30/31] move b_data_block_offset out of unrolled loop body

---
 onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp
index 45e7809a07575..53b9cacf54f72 100644
--- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp
+++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp
@@ -593,8 +593,8 @@ ComputeDotProducts_BlkBitWidth4_CompInt8(
 
     // load B column vectors
     uint8x8_t bv_packed[NCols];
+    const size_t b_data_block_offset = k_idx_in_blk * BlkBitWidth / 8;
     UnrolledLoop<NCols>([&](size_t i) {
-        const size_t b_data_block_offset = k_idx_in_blk * BlkBitWidth / 8;
         bv_packed[i] = vld1_u8(
             reinterpret_cast<const uint8_t*>(QuantBData) + i * StrideQuantBData + b_data_block_offset
         );

From 23373759562f74bfcdef30679b1ae825c230f725 Mon Sep 17 00:00:00 2001
From: edgchen1 <18449977+edgchen1@users.noreply.github.com>
Date: Fri, 12 Jan 2024 10:53:29 -0800
Subject: [PATCH 31/31] move b data offset out of unrolled loop in compfp32
 kernel

---
 onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp
index 53b9cacf54f72..69fd427fa574a 100644
--- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp
+++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp
@@ -174,8 +174,8 @@ ComputeDotProducts_BlkBitWidth4_CompFp32(
 
     // load B column vectors
     uint8x8_t bv_packed[NCols];
+    const size_t b_data_block_offset = k_idx_in_blk * BlkBitWidth / 8;
     UnrolledLoop<NCols>([&](size_t i) {
-        const size_t b_data_block_offset = k_idx_in_blk * BlkBitWidth / 8;
        bv_packed[i] = vld1_u8(
             reinterpret_cast<const uint8_t*>(QuantBData) + i * StrideQuantBData + b_data_block_offset
         );
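The hoist in these last two patches works because b_data_block_offset depends only on k_idx_in_blk, not on the per-column index i, while UnrolledLoop<NCols> expands its lambda body once per column at compile time; computing the offset inside the body repeats loop-invariant work in every expansion. A minimal helper with the shape of such an unrolled loop (illustration only; the actual MLAS definition may differ):

    #include <cstddef>
    #include <utility>

    // Calls fn(0), fn(1), ..., fn(N - 1), unrolled at compile time.
    template <size_t... Ix, typename Fn>
    void UnrolledLoopImpl(Fn&& fn, std::index_sequence<Ix...>) {
        (fn(Ix), ...);
    }

    template <size_t N, typename Fn>
    void UnrolledLoop(Fn&& fn) {
        UnrolledLoopImpl(std::forward<Fn>(fn), std::make_index_sequence<N>());
    }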