From 01ac345d3c3304b89b8dbbed585447d4471c946b Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Wed, 25 Oct 2023 20:34:14 -0700 Subject: [PATCH] save work - got sqnbitgemm tests and a cpu impl --- cmake/onnxruntime_mlas.cmake | 3 + onnxruntime/core/mlas/.clang-format | 6 +- .../core/mlas/inc/mlas_gemm_postprocessor.h | 16 + onnxruntime/core/mlas/inc/mlas_q4.h | 23 +- onnxruntime/core/mlas/inc/mlas_qnbit.h | 68 +++ onnxruntime/core/mlas/lib/mlasi.h | 18 + onnxruntime/core/mlas/lib/platform.cpp | 1 + onnxruntime/core/mlas/lib/sqnbitgemm.cpp | 128 +++++ onnxruntime/core/mlas/lib/sqnbitgemm.h | 277 +++++++++++ .../core/mlas/lib/sqnbitgemm_kernel_neon.cpp | 213 ++++++++ .../test/mlas/unittest/test_halfgemm.h | 14 - onnxruntime/test/mlas/unittest/test_q4gemm.h | 14 - .../test/mlas/unittest/test_q8q4gemm.cpp | 14 - .../test/mlas/unittest/test_sqnbitgemm.cpp | 466 ++++++++++++++++++ onnxruntime/test/mlas/unittest/test_util.h | 16 + 15 files changed, 1214 insertions(+), 63 deletions(-) create mode 100644 onnxruntime/core/mlas/inc/mlas_gemm_postprocessor.h create mode 100644 onnxruntime/core/mlas/inc/mlas_qnbit.h create mode 100644 onnxruntime/core/mlas/lib/sqnbitgemm.cpp create mode 100644 onnxruntime/core/mlas/lib/sqnbitgemm.h create mode 100644 onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp create mode 100644 onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index af51df4838505..9985f5c8bc516 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -33,6 +33,7 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/qpostprocessor.cpp ${MLAS_SRC_DIR}/qlgavgpool.cpp ${MLAS_SRC_DIR}/qdwconv_kernelsize.cpp + ${MLAS_SRC_DIR}/sqnbitgemm.cpp ) if (NOT onnxruntime_ORT_MINIMAL_BUILD) @@ -69,6 +70,7 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp 
${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp ) set(mlas_platform_preprocess_srcs @@ -336,6 +338,7 @@ else() ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp ) if (NOT APPLE) set(mlas_platform_srcs diff --git a/onnxruntime/core/mlas/.clang-format b/onnxruntime/core/mlas/.clang-format index 4a89ef98cf049..16ad8bd8a7234 100644 --- a/onnxruntime/core/mlas/.clang-format +++ b/onnxruntime/core/mlas/.clang-format @@ -2,10 +2,12 @@ BasedOnStyle: Google IndentWidth: 4 -ColumnLimit: 100 +# Setting ColumnLimit to 0 so developer choices about where to break lines are maintained. +# Developers are responsible for adhering to the 120 character maximum. +ColumnLimit: 0 +AlignAfterOpenBracket: BlockIndent AlwaysBreakAfterReturnType: TopLevel AlwaysBreakTemplateDeclarations: Yes BinPackParameters: false BreakBeforeBraces: Linux ... - diff --git a/onnxruntime/core/mlas/inc/mlas_gemm_postprocessor.h b/onnxruntime/core/mlas/inc/mlas_gemm_postprocessor.h new file mode 100644 index 0000000000000..fe77ace18ecc4 --- /dev/null +++ b/onnxruntime/core/mlas/inc/mlas_gemm_postprocessor.h @@ -0,0 +1,16 @@ +#pragma once + +template +class MLAS_GEMM_POSTPROCESSOR +{ + public: + virtual void Process(T* C, /**< the address of matrix to process */ + size_t RangeStartM, /**< the start row index of matrix */ + size_t RangeStartN, /**< the start col index of matrix */ + size_t RangeCountM, /**< the element count per row to process */ + size_t RangeCountN, /**< the element count per col to process */ + size_t ldc /**< the leading dimension of matrix */ + ) const = 0; + + virtual ~MLAS_GEMM_POSTPROCESSOR() {} +}; diff --git a/onnxruntime/core/mlas/inc/mlas_q4.h b/onnxruntime/core/mlas/inc/mlas_q4.h index 65b48a3009e72..84d3d3fa6d54c 100644 --- a/onnxruntime/core/mlas/inc/mlas_q4.h +++ b/onnxruntime/core/mlas/inc/mlas_q4.h @@ -21,6 +21,7 @@ Module Name: 
#pragma once #include "mlas.h" +#include "mlas_gemm_postprocessor.h" #include #include @@ -39,7 +40,7 @@ typedef enum { * @brief Computes the number of bytes required to pack and int4-quantize * a weight matrix * @param QType type of block quantization - * @param N the number of columns of matrix B. + * @param N the number of columns of matrix B. * @param K the number of rows of matrix B. * @return size of the packing buffer, 0 if the operation is not yet supported. */ @@ -53,11 +54,11 @@ MlasQ4GemmPackBSize( /** * @brief Prepack and Quantize fp32 weight tensor to int4 blocks - * + * * @param QType type of block quantization * @param PackedBuf destination buffer * @param FpData the pointer to fp32 matrix - * @param N the number of columns of matrix B. + * @param N the number of columns of matrix B. * @param K the number of rows of matrix B. * @param ldb leading dimension of B */ @@ -95,22 +96,6 @@ MlasQ4GemmUnPackB( ); -template -class MLAS_GEMM_POSTPROCESSOR -{ - public: - virtual void Process(T*, /**< the address of matrix to process */ - size_t, /**< the start row index of matrix */ - size_t, /**< the start col index of matrix */ - size_t, /**< the element count per row to process */ - size_t, /**< the element count per col to process */ - size_t /**< the leading dimension of matrix */ - ) const = 0; - - virtual ~MLAS_GEMM_POSTPROCESSOR() {} -}; - - /** * @brief Data parameters for Q4 GEMM routine * C = A * B + Bias diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h new file mode 100644 index 0000000000000..26ff0f0271e32 --- /dev/null +++ b/onnxruntime/core/mlas/inc/mlas_qnbit.h @@ -0,0 +1,68 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + mlas_qnbit.h + +Abstract: + + This module contains the public data structures and procedure prototypes + for blocked n-bit quantized GEMM. 
+ + N-bit block quantization is used to compress weight tensors of large + language models. + +--*/ + +#pragma once + +#include "mlas.h" +#include "mlas_gemm_postprocessor.h" + +/** + * @brief Data parameters for float/n-bit quantized int GEMM routine. + */ +struct MLAS_SQNBIT_GEMM_DATA_PARAMS { + const float* A = nullptr; ///< address of A (float32 matrix) + const void* PackedBData = nullptr; ///< address of B (quantized and packed n-bit int values) + const float* PackedBScale = nullptr; ///< address of scale values of quantized B, one per block + const void* PackedBZeroPoint = nullptr; ///< optional address of zero point values of quantized B, one per block + bool IsBPacked = false; ///< whether B values are packed in the optimal format for the computation + const float* Bias = nullptr; ///< optional address of Bias, vector size N + float* C = nullptr; ///< address of result matrix + size_t lda = 0; ///< leading dimension of A + size_t ldc = 0; ///< leading dimension of C + + ///< optional post processing to apply to result matrix + MLAS_GEMM_POSTPROCESSOR* PostProcessor = nullptr; +}; + +/** + * @brief Batched GEMM: C = A * B + Bias + * A must be a float32 matrix + * B must be a quantized and packed n-bit int matrix + * + * @param[in] M row size of matrix A and C + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BatchN number of batches + * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) + * @param[in] BlkSize number of quantized values per block + * @param[inout] DataParams An array (size BatchN) of parameter blocks + * @param[in] ThreadPool optional thread pool to use + */ +void MLASCALL +MlasSQNBitGemmBatch( + const size_t M, + const size_t N, + const size_t K, + const size_t BatchN, + const size_t BlkBitWidth, + const size_t BlkSize, + const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, + MLAS_THREADPOOL* ThreadPool = nullptr +); diff --git 
a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index d037360cf1028..0d5e425018f37 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -882,15 +882,31 @@ extern const MLAS_CONV_SYM_DISPATCH MlasConvSymS8DispatchNeon; extern const MLAS_CONV_SYM_DISPATCH MlasConvSymU8DispatchDot; extern const MLAS_CONV_SYM_DISPATCH MlasConvSymS8DispatchDot; +// +// Quantized 8-bit integer/quantized 4-bit integer matrix/matrix multiply dispatch structure. +// + struct MLAS_Q8Q4GEMM_DISPATCH; extern const MLAS_Q8Q4GEMM_DISPATCH MlasQ8Q4GemmDispatchAvx512vnni; +// +// Float/quantized 4-bit integer matrix/matrix multiply dispatch structure. +// + struct MLAS_FPQ4GEMM_DISPATCH; extern const MLAS_FPQ4GEMM_DISPATCH MlasFpQ4GemmDispatchAvx512; extern const MLAS_FPQ4GEMM_DISPATCH MlasFpQ4GemmDispatchNeon; +// +// Float/quantized n-bit integer matrix/matrix multiply dispatch structure. +// + +struct MLAS_SQNBIT_GEMM_DISPATCH; + +extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon; + // // Quantized depthwise convolution kernels. // @@ -1022,6 +1038,8 @@ struct MLAS_PLATFORM { const MLAS_FPQ4GEMM_DISPATCH* FpQ4GemmDispatch{nullptr}; const MLAS_Q8Q4GEMM_DISPATCH* Q8Q4GemmDispatch{nullptr}; + + const MLAS_SQNBIT_GEMM_DISPATCH* SQNBitGemmDispatch{nullptr}; }; inline diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index d81c4338cc140..0a4d9e05c4cd2 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -450,6 +450,7 @@ Return Value: this->ConvSymU8S8Dispatch = &MlasConvSymU8DispatchNeon; this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchNeon; this->FpQ4GemmDispatch = &MlasFpQ4GemmDispatchNeon; + this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon; // // Check if the processor supports ASIMD dot product instructions. 
diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp
new file mode 100644
index 0000000000000..c952c38677224
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp
@@ -0,0 +1,128 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+    sqnbitgemm.cpp
+
+Abstract:
+
+    This module implements the float/quantized n-bit integer matrix
+    multiplication hardware agnostic entrypoint, MlasSQNBitGemmBatch.
+--*/
+
+#include "sqnbitgemm.h"
+
+namespace
+{
+
+// Map (`BlkBitWidth`, `BlkLen`) to the QuantVariant index used by the dispatch table.
+// Return -1 if the combination is unsupported. Note the argument order: bit width first.
+int32_t
+GetDispatchQuantVariant(size_t BlkBitWidth, size_t BlkLen)
+{
+    int32_t type = -1;
+    if (BlkBitWidth == 4 && BlkLen == 16) {
+        type = QuantVariant_BitWidth4_BlockSize16;
+    } else if (BlkBitWidth == 4 && BlkLen == 32) {
+        type = QuantVariant_BitWidth4_BlockSize32;
+    }
+
+    return type;
+}
+
+}  // namespace
+
+void MLASCALL
+MlasSQNBitGemmBatch(
+    const size_t M,
+    const size_t N,
+    const size_t K,
+    const size_t BatchN,
+    const size_t BlkBitWidth,
+    const size_t BlkLen,
+    const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams,
+    MLAS_THREADPOOL* ThreadPool
+)
+{
+    const int32_t QuantVariant = GetDispatchQuantVariant(BlkBitWidth, BlkLen);
+    if (QuantVariant == -1) {
+        MLAS_THROW_EX(std::invalid_argument, "Unsupported quantization block size / bit width.");
+    }
+    // SQNBitGemmDispatch stays nullptr on platforms that do not register a kernel set; do not dereference it.
+    const MLAS_SQNBIT_GEMM_DISPATCH* const Dispatch = GetMlasPlatform().SQNBitGemmDispatch;
+    MLAS_SQNBIT_GEMM_OPERATION* const Operation = (Dispatch != nullptr) ? Dispatch->Operations[QuantVariant] : nullptr;
+
+    if (Operation == nullptr) {
+        MLAS_THROW_EX(std::invalid_argument, "FpQNBitGemm is not implemented on this platform.");
+    }
+    // Serial path. An empty batch also exits here so the division by BatchN below is never reached with 0.
+    if (ThreadPool == nullptr || BatchN == 0) {
+        for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) {
+            auto Data = &DataParams[gemm_i];
+            Operation(K, Data, 0, M, 0, N);
+        }
+        return;
+    }
+
+    //
+    // Compute the number of target threads given the complexity of the SGEMM
+    // operation. Small requests should run using the single threaded path.
+    //
+
+    const double Complexity = double(M) * double(N) * double(K) * double(BatchN);
+
+    ptrdiff_t TargetThreadCount = ptrdiff_t(Complexity / double(MLAS_QGEMM_THREAD_COMPLEXITY)) + 1;
+
+    ptrdiff_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool) * 8;
+
+    if (TargetThreadCount >= MaximumThreadCount) {
+        TargetThreadCount = MaximumThreadCount;
+    }
+
+    ptrdiff_t ThreadsPerGemm = TargetThreadCount / BatchN;
+    if (ThreadsPerGemm < 1) {
+        ThreadsPerGemm = 1;
+    }
+
+    constexpr size_t StrideM = 128;
+
+    size_t nc = N;
+    if (ThreadsPerGemm > 1) {
+        // more than one thread per GEMM
+
+        const size_t BlockedM = MlasDivRoundup(M, StrideM);
+        const size_t max_nc = MlasDivRoundup(N * BlockedM, ThreadsPerGemm);
+        if (max_nc < nc) {
+            nc = std::min(
+                nc, MlasDivRoundup(max_nc, MLAS_QGEMM_STRIDEN_THREAD_ALIGN) *
+                        MLAS_QGEMM_STRIDEN_THREAD_ALIGN
+            );
+        }
+    }
+    const size_t StrideN = nc;
+
+    const size_t ThreadCountM = MlasDivRoundup(M, StrideM);
+    const size_t ThreadCountN = MlasDivRoundup(N, StrideN);
+    ThreadsPerGemm = ThreadCountM * ThreadCountN;
+
+    MlasTrySimpleParallel(ThreadPool, ThreadsPerGemm * BatchN, [&](ptrdiff_t tid) {
+        const auto gemm_i = tid / ThreadsPerGemm;
+        const auto blk_i = tid % ThreadsPerGemm;
+        auto Data = &DataParams[gemm_i];
+
+        const ptrdiff_t ThreadIdN = blk_i / ThreadCountM;
+        const ptrdiff_t ThreadIdM = blk_i % ThreadCountM;
+
+        const size_t RangeStartM = ThreadIdM * StrideM;
+        const size_t RangeCountM = std::min(M - RangeStartM, (size_t)StrideM);
+
+        const size_t RangeStartN = ThreadIdN * StrideN;
+        const size_t RangeCountN = std::min(N - RangeStartN, (size_t)StrideN);
+
+        Operation(K, Data, RangeStartM, RangeCountM, RangeStartN, RangeCountN);
+    });
+}
diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.h b/onnxruntime/core/mlas/lib/sqnbitgemm.h
new file mode 100644
index 0000000000000..404a9f8001fbe
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/sqnbitgemm.h
@@ -0,0 +1,277 @@
+/*++
+ +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + sqnbitgemm.h + +Abstract: + + This module includes: + + - Declaration of the set of template functions used to implement a kernel + for a matrix/matrix multiplication, A*B, where A is a float matrix and B is + a n-bit quantized integer matrix (QNBitGemm). + + - A shared kernel driver function template, MlasSQNBitGemmOperation. + + - Kernel dispatch structure. + + The B matrix is block quantized, which means that its values are grouped + into blocks which each have one scale and optional zero point. Each + quantized value in B is n-bits wide. + +--*/ + +#pragma once + +#include "mlas_qnbit.h" +#include "mlasi.h" + +// +// Kernel implementation template declarations +// + +/// +/// Multiply float matrix A with quantized n-bit integer matrix B. +/// +/// Hardware-specific kernel type. +/// Number of values in a block. +/// Bit width of each value in a block. +/// Supplies the A matrix. +/// Supplies the packed B matrix block data. +/// Supplies the packed B matrix block scale values. +/// Supplies the packed B matrix block zero point values. Optional. +/// Supplies the output C matrix. +/// Number of rows of A and C. +/// Number of columns of B and C. +/// Number of columns of A and rows of B. +/// Leading dimension of A. +/// +/// Number of blocks between adjacent columns of B (packed B values are transposed). +/// +/// Leading dimension of C. +/// Bias vector of length N. Optional. +/// Number of rows of A handled. 
+template <size_t BlkBitWidth, size_t BlkLen, typename KernelType>
+MLAS_FORCEINLINE size_t
+MlasSQNBitGemmKernel(
+    const float* A,
+    const uint8_t* PackedBData,
+    const float* PackedBScale,
+    const uint8_t* PackedBZeroPoint,
+    float* C,
+    size_t CountM,
+    size_t CountN,
+    size_t CountK,
+    size_t lda,
+    size_t BlockStridePackedB,
+    size_t ldc,
+    const float* Bias
+);
+
+// dequantize B into the format expected by MlasSgemmKernelZero
+template <size_t BlkBitWidth, size_t BlkLen, typename KernelType>
+MLAS_FORCEINLINE void
+MlasQNBitBlkDequantBForSgemm(
+    float* FpData,
+    const uint8_t* PackedBData,
+    const float* PackedBScale,
+    const uint8_t* PackedBZeroPoint,
+    size_t CountN,
+    size_t CountK,
+    size_t BlockStridePackedB
+);
+
+//
+// MlasQNBitGemmOperation and helpers
+//
+
+constexpr MLAS_FORCEINLINE size_t
+MlasQNBitBlkDataSizeInBytes(size_t BlkBitWidth, size_t BlkLen)
+{
+    return BlkLen * BlkBitWidth / 8;
+}
+
+template <size_t BlkBitWidth>
+constexpr MLAS_FORCEINLINE size_t
+MlasQNBitZeroPointsForBlksSizeInBytes(size_t BlkCount)
+{
+    if constexpr (BlkBitWidth <= 4) {
+        return MlasDivRoundup(BlkCount, 2);  // 2 blocks per byte
+    } else {
+        return BlkCount;
+    }
+}
+// Adds Bias[0..CountN) to each of the CountM rows of C; consecutive rows of C are ldc floats apart.
+MLAS_FORCEINLINE void
+MlasAddBiasForGemm(const float* Bias, float* C, size_t CountM, size_t CountN, size_t ldc)
+{
+    for (size_t m = 0; m < CountM; m++) {
+        const float* bias = Bias;
+        float* sum = C;
+        for (size_t n = 0; n < CountN; n += 4) {
+            if (CountN - n < 4) {
+                for (size_t nn = n; nn < CountN; nn++) {
+                    *sum += *bias;
+                    sum++;
+                    bias++;
+                }
+                break;
+            }
+
+            MLAS_FLOAT32X4 acc_x = MlasLoadFloat32x4(sum);
+            acc_x = MlasAddFloat32x4(acc_x, MlasLoadFloat32x4(bias));
+            MlasStoreFloat32x4(sum, acc_x);
+            bias += 4;
+            sum += 4;
+        }
+        C += ldc;
+    }
+}
+// Shared driver: computes rows [RangeStartM, +RangeCountM) x columns [RangeStartN, +RangeCountN) of C for one GEMM.
+template <size_t BlkBitWidth, size_t BlkLen, typename KernelType>
+MLAS_FORCEINLINE void MLASCALL
+MlasSQNBitGemmOperation(
+    const size_t K,
+    const MLAS_SQNBIT_GEMM_DATA_PARAMS* const DataParams,
+    const size_t RangeStartM,
+    const size_t RangeCountM,
+    const size_t RangeStartN,
+    const size_t RangeCountN
+)
+{
+    const size_t lda = DataParams->lda;
+    const size_t ldc = DataParams->ldc;
+
+    const size_t k_blks = MlasDivRoundup(K, BlkLen);
+    const size_t ldb = k_blks * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
+    const float* A = DataParams->A + RangeStartM * lda;
+    const uint8_t* PackedBData = static_cast<const uint8_t*>(DataParams->PackedBData);
+    const float* PackedBScale = DataParams->PackedBScale;
+    const uint8_t* PackedBZeroPoint = static_cast<const uint8_t*>(DataParams->PackedBZeroPoint);
+    float* C = DataParams->C + RangeStartM * ldc + RangeStartN;
+    const float* Bias = DataParams->Bias;
+
+    if (RangeCountM == 1) {
+        size_t CountN;
+        for (size_t n = 0; n < RangeCountN; n += CountN) {
+            CountN = std::min(RangeCountN - n, (size_t)128);
+
+            //
+            // Locate the bias, packed B data/scale/zero point and C for the column tile starting at RangeStartN + n.
+            //
+            const float* bias = (Bias == nullptr) ? nullptr : Bias + RangeStartN + n;
+            const uint8_t* b_col = PackedBData + (RangeStartN + n) * ldb;
+            const float* b_col_scale = PackedBScale + (RangeStartN + n) * k_blks;
+            const uint8_t* b_col_zp =
+                (PackedBZeroPoint == nullptr)
+                    ? nullptr
+                    : PackedBZeroPoint + (RangeStartN + n) * MlasQNBitZeroPointsForBlksSizeInBytes<BlkBitWidth>(k_blks);
+            float* c_blk = C + n;
+            const float* a_row = A;
+
+            size_t RowsRemaining = RangeCountM;
+            while (RowsRemaining > 0) {
+                auto RowsHandled = MlasSQNBitGemmKernel<BlkBitWidth, BlkLen, KernelType>(
+                    a_row, b_col, b_col_scale, b_col_zp, c_blk, RowsRemaining, CountN, K, lda, k_blks, ldc, bias
+                );
+                // Post-process exactly the tile just written: its first column is RangeStartN + n.
+                if (DataParams->PostProcessor != nullptr) {
+                    DataParams->PostProcessor->Process(
+                        DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN + n,
+                        RowsHandled, CountN, ldc
+                    );
+                }
+
+                c_blk += ldc * RowsHandled;
+                a_row += lda * RowsHandled;
+                RowsRemaining -= RowsHandled;
+            }
+        }
+        return;
+    }
+
+    constexpr size_t StrideN = 32;
+    size_t bufsize = k_blks * BlkLen * StrideN * sizeof(float);
+    MlasThreadedBufAlloc(bufsize);
+    auto* dequant_b = reinterpret_cast<float*>(ThreadedBufHolder.get());
+    //
+    // Step through each slice of matrix B along the N dimension.
+    //
+
+    size_t CountN;
+    for (size_t n = 0; n < RangeCountN; n += CountN) {
+        CountN = std::min(RangeCountN - n, (size_t)StrideN);
+
+        //
+        // Step through each slice of matrix A along the M dimension.
+        //
+        const float* bias = (Bias == nullptr) ? nullptr : Bias + RangeStartN + n;
+        const uint8_t* b_col = PackedBData + (RangeStartN + n) * ldb;
+        const float* b_col_scale = PackedBScale + (RangeStartN + n) * k_blks;
+        const uint8_t* b_col_zp =
+            (PackedBZeroPoint == nullptr)
+                ? nullptr
+                : PackedBZeroPoint + (RangeStartN + n) * MlasQNBitZeroPointsForBlksSizeInBytes<BlkBitWidth>(k_blks);
+        float* c_blk = C + n;
+        const float* a_row = A;
+
+        MlasQNBitBlkDequantBForSgemm<BlkBitWidth, BlkLen, KernelType>(
+            dequant_b, b_col, b_col_scale, b_col_zp, CountN, K, k_blks
+        );
+
+        size_t RowsRemaining = RangeCountM;
+        while (RowsRemaining > 0) {
+#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER)
+            auto RowsHandled = GetMlasPlatform().GemmFloatKernel(
+                a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f, true
+            );
+#else
+            auto RowsHandled = MlasSgemmKernelZero(a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f);
+#endif
+
+            if (bias) {
+                MlasAddBiasForGemm(bias, c_blk, RowsHandled, CountN, ldc);
+            }
+            if (DataParams->PostProcessor != nullptr) {
+                DataParams->PostProcessor->Process(
+                    DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN + n,
+                    RowsHandled, CountN, ldc
+                );
+            }
+
+            c_blk += ldc * RowsHandled;
+            a_row += lda * RowsHandled;
+            RowsRemaining -= RowsHandled;
+        }
+    }
+}
+
+//
+// Kernel dispatch structure.
+// + +typedef void(MLASCALL MLAS_SQNBIT_GEMM_OPERATION)( + size_t K, + const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, + size_t RangeStartM, + size_t RangeCountM, + size_t RangeStartN, + size_t RangeCountN +); + +enum QuantVariant { + QuantVariant_BitWidth4_BlockSize16, + QuantVariant_BitWidth4_BlockSize32, + QuantVariantCount, // keep this element last +}; + +struct MLAS_SQNBIT_GEMM_DISPATCH { + MLAS_SQNBIT_GEMM_OPERATION* Operations[QuantVariantCount] = { + // Initialized to nullptrs. Overwrite in hardware-specific kernel implementation. + }; +}; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp new file mode 100644 index 0000000000000..b6dce5d1c2578 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp @@ -0,0 +1,213 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + sqnbitgemm_kernel_neon.cpp + +Abstract: + + This module implements the float/quantized n-bit integer matrix + multiplication kernels for ARM NEON. + +--*/ + +#include + +#include "sqnbitgemm.h" + +// +// Hardware-specific kernel type. +// +struct MLAS_SQNBIT_GEMM_KERNEL_NEON { +}; + +// +// MlasSQNBitGemmKernel and helpers. 
+//
+
+template <size_t BlkBitWidth, size_t BlkLen>
+MLAS_FORCEINLINE size_t
+MlasSQNBitGemmKernelNeon(
+    const float* A,
+    const uint8_t* PackedBData,
+    const float* PackedBScale,
+    const uint8_t* PackedBZeroPoint,
+    float* C,
+    size_t CountM,
+    size_t CountN,
+    size_t CountK,
+    size_t lda,
+    size_t BlockStridePackedB,
+    size_t ldc,
+    const float* Bias
+)
+{
+    auto impl0_reference = [&]() {
+        static_assert(BlkBitWidth == 4);
+
+        for (size_t m = 0; m < CountM; ++m) {
+            for (size_t n = 0; n < CountN; ++n) {
+                float sum = 0.0f;
+                for (size_t k = 0, k_blk_idx = 0; k < CountK; k += BlkLen, k_blk_idx += 1) {
+                    const size_t kblocklen = std::min(CountK - k, BlkLen);
+                    // Zero points are packed two per byte, and each column of B is padded to a whole byte.
+                    const float b_s = PackedBScale[n * BlockStridePackedB + k_blk_idx];
+                    const uint8_t b_z = [&]() -> uint8_t {
+                        if (PackedBZeroPoint != nullptr) {
+                            const size_t zp_stride = MlasQNBitZeroPointsForBlksSizeInBytes<BlkBitWidth>(BlockStridePackedB);
+                            const uint8_t zp_packed = PackedBZeroPoint[n * zp_stride + k_blk_idx / 2];
+                            return ((k_blk_idx & 1) == 1) ? (zp_packed >> 4) : (zp_packed & 0x0F);
+                        } else {
+                            return 8;
+                        }
+                    }();
+                    const uint8_t* b_data =
+                        PackedBData + (n * BlockStridePackedB + k_blk_idx) * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
+
+                    for (size_t kk = 0; kk < kblocklen; ++kk) {
+                        uint8_t b_packed = b_data[kk / 2];
+                        uint8_t b_byte = ((kk & 1) == 1) ? (b_packed >> 4) : (b_packed & 0x0F);
+                        float b_value = (b_byte - b_z) * b_s;
+
+                        sum += A[m * lda + k + kk] * b_value;
+                    }
+                }
+
+                // Store (do not accumulate into) C so the result does not depend on what C held on entry.
+                C[m * ldc + n] = (Bias != nullptr) ? sum + Bias[n] : sum;
+            }
+        }
+
+        return CountM;
+    };
+
+    return impl0_reference();
+}
+
+template <>
+MLAS_FORCEINLINE size_t
+MlasSQNBitGemmKernel<4, 32, MLAS_SQNBIT_GEMM_KERNEL_NEON>(
+    const float* A,
+    const uint8_t* PackedBData,
+    const float* PackedBScale,
+    const uint8_t* PackedBZeroPoint,
+    float* C,
+    size_t CountM,
+    size_t CountN,
+    size_t CountK,
+    size_t lda,
+    size_t BlockStridePackedB,
+    size_t ldc,
+    const float* Bias
+)
+{
+    return MlasSQNBitGemmKernelNeon<4, 32>(
+        A, PackedBData, PackedBScale, PackedBZeroPoint, C, CountM, CountN, CountK, lda,
+        BlockStridePackedB, ldc, Bias
+    );
+}
+
+//
+// MlasQNBitBlkDequantBForSgemm and helpers.
+//
+
+template <size_t BlkBitWidth, size_t BlkLen>
+MLAS_FORCEINLINE void
+MlasQNBitBlkDequantBForSgemmNeon(
+    float* FpData,
+    const uint8_t* PackedBData,
+    const float* PackedBScale,
+    const uint8_t* PackedBZeroPoint,
+    size_t CountN,
+    size_t CountK,
+    size_t BlockStridePackedB
+)
+{
+    auto impl0_reference = [&]() {
+        static_assert(BlkBitWidth == 4);
+
+        float* Dst = FpData;
+
+        const uint8_t* PackedBDataCol = PackedBData;
+        const float* PackedBScaleCol = PackedBScale;
+        const uint8_t* PackedBZeroPointCol = PackedBZeroPoint;
+
+        for (size_t n = 0; n < CountN; n += 16) {
+            const size_t nnlen = std::min(CountN - n, size_t{16});
+
+            for (size_t nn = 0; nn < nnlen; ++nn) {
+
+                for (size_t k = 0, k_blk_idx = 0; k < CountK; k += BlkLen, k_blk_idx += 1) {
+                    const size_t kklen = std::min(CountK - k, BlkLen);
+
+                    const uint8_t* b_data =
+                        PackedBDataCol + k_blk_idx * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
+                    const float b_s = PackedBScaleCol[k_blk_idx];
+                    const uint8_t b_z =
+                        (PackedBZeroPointCol != nullptr)
+                            ? ((k_blk_idx & 1) == 1)
+                                  ? PackedBZeroPointCol[k_blk_idx / 2] >> 4
+                                  : PackedBZeroPointCol[k_blk_idx / 2] & 0x0F
+                            : 8;
+
+                    for (size_t kk = 0; kk < kklen; ++kk) {
+                        const uint8_t b_packed = b_data[kk / 2];
+                        const uint8_t b_byte = ((kk & 1) == 1) ? b_packed >> 4 : b_packed & 0x0F;
+                        const float b_value = (b_byte - b_z) * b_s;
+
+                        Dst[(k + kk) * 16 + nn] = b_value;
+                    }
+                }
+
+                PackedBDataCol += BlockStridePackedB * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
+                PackedBScaleCol += BlockStridePackedB;
+                if (PackedBZeroPointCol != nullptr) {
+                    PackedBZeroPointCol += MlasQNBitZeroPointsForBlksSizeInBytes<BlkBitWidth>(BlockStridePackedB);
+                }
+            }
+
+            // zero out any remaining columns
+
+            if (nnlen < 16) {
+                for (size_t k = 0; k < CountK; ++k) {
+                    std::fill_n(Dst + (k * 16) + nnlen, 16 - nnlen, 0.0f);
+                }
+            }
+
+            Dst += CountK * 16;
+        }
+    };
+
+    impl0_reference();
+}
+
+template <>
+MLAS_FORCEINLINE void
+MlasQNBitBlkDequantBForSgemm<4, 32, MLAS_SQNBIT_GEMM_KERNEL_NEON>(
+    float* FpData,
+    const uint8_t* PackedBData,
+    const float* PackedBScale,
+    const uint8_t* PackedBZeroPoint,
+    size_t CountN,
+    size_t CountK,
+    size_t BlockStridePackedB
+)
+{
+    MlasQNBitBlkDequantBForSgemmNeon<4, 32>(
+        FpData, PackedBData, PackedBScale, PackedBZeroPoint, CountN, CountK, BlockStridePackedB
+    );
+}
+
+//
+// Kernel dispatch structure definition.
+// + +const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon = []() { + MLAS_SQNBIT_GEMM_DISPATCH d; + d.Operations[QuantVariant_BitWidth4_BlockSize32] = MlasSQNBitGemmOperation<4, 32, MLAS_SQNBIT_GEMM_KERNEL_NEON>; + return d; +}(); diff --git a/onnxruntime/test/mlas/unittest/test_halfgemm.h b/onnxruntime/test/mlas/unittest/test_halfgemm.h index 2861b0e746fdc..4db5c2bebca40 100644 --- a/onnxruntime/test/mlas/unittest/test_halfgemm.h +++ b/onnxruntime/test/mlas/unittest/test_halfgemm.h @@ -18,20 +18,6 @@ Module Name: #include "test_fp16.h" -inline bool -CloseEnough(float actual, float expected) { - if (std::isnan(actual)) { - return std::isnan(expected); - } - float diff = std::abs(actual - expected); - float top = std::max(std::abs(actual), std::abs(expected)); - float ratio = 0; - if (top > 0.0001) { - ratio = diff / top; - } - return ratio < 0.005; -} - /** * @brief Test class for half precision GEMM * @tparam AType Data type of A matrix, can be either float or MLFp16 diff --git a/onnxruntime/test/mlas/unittest/test_q4gemm.h b/onnxruntime/test/mlas/unittest/test_q4gemm.h index 58a64491ae80b..97c6969b5bf91 100644 --- a/onnxruntime/test/mlas/unittest/test_q4gemm.h +++ b/onnxruntime/test/mlas/unittest/test_q4gemm.h @@ -19,20 +19,6 @@ Module Name: #include "test_util.h" #include "mlas_q4.h" -inline bool -CloseEnough(float actual, float expected) { - if (std::isnan(actual)) { - return std::isnan(expected); - } - float diff = std::abs(actual - expected); - float top = std::max(std::abs(actual), std::abs(expected)); - float ratio = 0; - if (top > 0.0001) { - ratio = diff / top; - } - return ratio < 0.005; -} - /** * @brief Test class for int4 block quantized GEMM * Note: only 2-D matmul supported for now diff --git a/onnxruntime/test/mlas/unittest/test_q8q4gemm.cpp b/onnxruntime/test/mlas/unittest/test_q8q4gemm.cpp index bac16b0103a6e..55aa9198e5818 100644 --- a/onnxruntime/test/mlas/unittest/test_q8q4gemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_q8q4gemm.cpp 
@@ -19,20 +19,6 @@ Module Name: #include "test_util.h" #include "mlas_q4.h" -inline bool -CloseEnough(float actual, float expected) { - if (std::isnan(actual)) { - return std::isnan(expected); - } - float diff = std::abs(actual - expected); - float top = std::max(std::abs(actual), std::abs(expected)); - float ratio = 0; - if (top > 0.0001) { - ratio = diff / top; - } - return ratio < 0.005; -} - template static void blkq8_dequant_reference(const int8_t* src, float* dst, size_t M, size_t K) { const size_t num_blks = K / QBlkLen; diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp new file mode 100644 index 0000000000000..129715f769d50 --- /dev/null +++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp @@ -0,0 +1,466 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + test_sqnbitgemm.h + +Abstract: + + Tests for MLAS n-bit int block quantized GEMM. + +--*/ + +#pragma once + +#include "test_util.h" +#include "mlas_qnbit.h" + +namespace { + +constexpr size_t DivRoundUp(size_t a, size_t b) { + return (a + b - 1) / b; +} + +constexpr size_t BlkDataSizeInBytes(size_t BlkBitWidth, size_t BlkLen) { + return BlkLen * BlkBitWidth / 8; +} + +template +constexpr size_t ZeroPointsForBlksSizeInBytes(size_t BlkCount) { + if constexpr (BlkBitWidth <= 4) { + return DivRoundUp(BlkCount, 2); + } else { + return BlkCount; + } +} + +template +struct ReferenceQNBitPacking { + static_assert(BlkBitWidth == 4, "Only implemented for BlkBitWidth == 4."); + + static void GetPackedBSizes(size_t CountN, size_t CountK, + size_t& PackedBDataSizeInBytes, + size_t& PackedBScaleElementCount, + size_t* PackedBZeroPointSizeInBytes) { + const size_t BlockCountK = DivRoundUp(CountK, BlkLen); + const size_t TotalBlockCount = CountN * BlockCountK; + + PackedBDataSizeInBytes = TotalBlockCount * BlkDataSizeInBytes(BlkLen, BlkBitWidth); + PackedBScaleElementCount = 
TotalBlockCount; + if (PackedBZeroPointSizeInBytes) { + *PackedBZeroPointSizeInBytes = CountN * ZeroPointsForBlksSizeInBytes(BlockCountK); + } + } + + static void PackB(size_t CountN, size_t CountK, + const float* BDataPtr, size_t ldb, + uint8_t* PackedBDataPtr, + float* PackedBScalePtr, + uint8_t* PackedBZeroPointPtr) { + const size_t BlockCountK = DivRoundUp(CountK, BlkLen); + + uint8_t* PackedBDataColPtr = PackedBDataPtr; + float* PackedBScaleColPtr = PackedBScalePtr; + uint8_t* PackedBZeroPointColPtr = PackedBZeroPointPtr; + + for (size_t n = 0; n < CountN; ++n) { + for (size_t k = 0, k_blk_idx = 0; k < CountK; k += BlkLen, k_blk_idx += 1) { + size_t kklen = std::min(BlkLen, CountK - k); + + uint8_t* PackedBBlkDataPtr = PackedBDataColPtr + k_blk_idx * BlkDataSizeInBytes(BlkLen, BlkBitWidth); + + if (PackedBZeroPointColPtr) { + float scale_block; + uint8_t zp_block; + QuantizeBlock(BDataPtr + k * ldb + n, ldb, kklen, PackedBBlkDataPtr, scale_block, zp_block); + + if ((k_blk_idx & 1) == 0) { + PackedBZeroPointColPtr[k_blk_idx / 2] = zp_block & 0x0F; + } else { + PackedBZeroPointColPtr[k_blk_idx / 2] |= zp_block << 4; + } + + PackedBScaleColPtr[k_blk_idx] = scale_block; + } else { + float scale_block; + QuantizeBlock(BDataPtr + k * ldb + n, ldb, kklen, PackedBBlkDataPtr, scale_block); + + PackedBScaleColPtr[k_blk_idx] = scale_block; + } + } + + PackedBDataColPtr += BlockCountK * BlkDataSizeInBytes(BlkLen, BlkBitWidth); + PackedBScaleColPtr += BlockCountK; + if (PackedBZeroPointColPtr != nullptr) { + PackedBZeroPointColPtr += ZeroPointsForBlksSizeInBytes(BlockCountK); + } + } + } + + static void UnpackB(size_t CountN, size_t CountK, + const uint8_t* PackedBDataPtr, const float* PackedBScalePtr, const uint8_t* PackedBZeroPointPtr, + float* BDataPtr, size_t ldb) { + const size_t BlockCountK = DivRoundUp(CountK, BlkLen); + + const uint8_t* PackedBDataColPtr = PackedBDataPtr; + const float* PackedBScaleColPtr = PackedBScalePtr; + const uint8_t* PackedBZeroPointColPtr = 
PackedBZeroPointPtr; + + for (size_t n = 0; n < CountN; ++n) { + for (size_t k = 0, k_blk_idx = 0; k < CountK; k += BlkLen, k_blk_idx += 1) { + size_t kklen = std::min(BlkLen, CountK - k); + + const uint8_t* PackedBBlkDataPtr = PackedBDataColPtr + k_blk_idx * BlkDataSizeInBytes(BlkLen, BlkBitWidth); + const float scale_block = PackedBScaleColPtr[k_blk_idx]; + + if (PackedBZeroPointColPtr) { + const uint8_t zp_block = ((k_blk_idx & 1) == 1) + ? (PackedBZeroPointColPtr[k_blk_idx / 2] >> 4) + : (PackedBZeroPointColPtr[k_blk_idx / 2] & 0x0F); + + DequantizeBlock(BDataPtr + k * ldb + n, ldb, kklen, PackedBBlkDataPtr, scale_block, zp_block); + } else { + DequantizeBlock(BDataPtr + k * ldb + n, ldb, kklen, PackedBBlkDataPtr, scale_block); + } + } + + PackedBDataColPtr += BlockCountK * BlkDataSizeInBytes(BlkLen, BlkBitWidth); + PackedBScaleColPtr += BlockCountK; + if (PackedBZeroPointColPtr != nullptr) { + PackedBZeroPointColPtr += ZeroPointsForBlksSizeInBytes(BlockCountK); + } + } + } + + static void QuantizeBlock(const float* b_begin, size_t ldb, size_t actual_block_size, + uint8_t* data_block, float& scale_block, uint8_t& zp_block) { + float min = *b_begin; + float max = *b_begin; + for (int32_t kk = 0; kk < actual_block_size; kk++) { + const float v = b_begin[ldb * kk]; + if (v < min) min = v; + if (v > max) max = v; + } + min = std::min(min, 0.0f); + max = std::max(max, 0.0f); + + scale_block = (max - min) / ((1 << 4) - 1); + + const float reciprocal_scale = scale_block ? 
1.0f / scale_block : 0.0f; + float zero_point_fp = min; + if (scale_block != 0.0f) { + zero_point_fp = 0.f - min / scale_block; + } + + // Handle any clamping + if (zero_point_fp < 0.0f) { + zp_block = 0; + } else if (zero_point_fp > 15.0f) { + zp_block = 15; + } else { + zp_block = (uint8_t)roundf(zero_point_fp); + } + + for (int32_t kk = 0; kk < actual_block_size; kk += 2) { + const float v0 = b_begin[ldb * kk]; + const uint8_t vi0 = (uint8_t)std::min(15.0f, std::max(0.0f, roundf(v0 * reciprocal_scale + zp_block))); + + const float v1 = (kk + 1 < actual_block_size) ? b_begin[ldb * (kk + 1)] : 0.f; + const uint8_t vi1 = (uint8_t)std::min(15.0f, std::max(0.0f, roundf(v1 * reciprocal_scale + zp_block))); + + data_block[kk / 2] = vi0 | (vi1 << 4); + } + } + + static void QuantizeBlock(const float* b_begin, size_t ldb, size_t actual_block_size, + uint8_t* data_block, float& scale_block) { + float amax = 0.0f; // abs(max) + float max = 0.0f; + + for (int32_t kk = 0; kk < actual_block_size; kk++) { + const float v = b_begin[ldb * kk]; + if (amax < fabsf(v)) { + amax = fabsf(v); + max = v; + } + } + + scale_block = max / (-8.f); + const float reciprocal_scale = scale_block ? 1.0f / scale_block : 0.0f; + + for (int32_t kk = 0; kk < actual_block_size; kk += 2) { + const float v0 = b_begin[ldb * kk] * reciprocal_scale; + const uint8_t vi0 = (uint8_t)std::min(15.0f, std::max(0.0f, roundf(v0 + 8.f))); + + const float v1 = (kk + 1 < actual_block_size) ? 
b_begin[ldb * (kk + 1)] * reciprocal_scale : 0; + const uint8_t vi1 = (uint8_t)std::min(15.0f, std::max(0.0f, roundf(v1 + 8.f))); + + data_block[kk / 2] = vi0 | (vi1 << 4); + } + } + + static void DequantizeBlock(float* b_begin, size_t ldb, size_t actual_block_size, + const uint8_t* data_block, float scale_block, uint8_t zp_block) { + for (size_t kk = 0; kk < actual_block_size; kk += 2) { + float x0 = static_cast(data_block[kk / 2] & 0x0F); + b_begin[ldb * kk] = scale_block * (x0 - zp_block); + + if (kk + 1 < actual_block_size) { + float x1 = static_cast(data_block[kk / 2] >> 4); + b_begin[ldb * (kk + 1)] = scale_block * (x1 - zp_block); + } + } + } + + static void DequantizeBlock(float* b_begin, size_t ldb, size_t actual_block_size, + const uint8_t* data_block, float scale_block) { + DequantizeBlock(b_begin, ldb, actual_block_size, data_block, scale_block, uint8_t{8}); + } +}; + +} // namespace + +/** + * @brief Test class for n-bit int block quantized GEMM + * Note: only 2-D matmul supported for now + */ +template +class MlasSQNBitGemmTest : public MlasTestBase { + private: + MatrixGuardBuffer BufferA; + MatrixGuardBuffer BufferB; + MatrixGuardBuffer BufferPackedBData; + MatrixGuardBuffer BufferPackedBZeroPoint; + MatrixGuardBuffer BufferPackedBScale; + MatrixGuardBuffer BufferUnpackedBReference; + MatrixGuardBuffer BufferBias; + MatrixGuardBuffer BufferC; + MatrixGuardBuffer BufferCReference; + + void CallGemm(size_t M, + size_t N, + size_t K, + const float* A, + size_t lda, + const uint8_t* PackedBData, + const float* PackedBScale, + const uint8_t* PackedBZeroPoint, + const float* Bias, + float* C, + size_t ldc, + MLAS_THREADPOOL* Threadpool) { + MLAS_SQNBIT_GEMM_DATA_PARAMS params; + params.A = A; + params.lda = lda; + params.Bias = Bias; + params.C = C; + params.ldc = ldc; + params.PackedBData = PackedBData; + params.PackedBScale = PackedBScale; + params.PackedBZeroPoint = PackedBZeroPoint; + params.PostProcessor = nullptr; + + MlasSQNBitGemmBatch(M, N, K, 1, 
BlkLen, BlkBitWidth, &params, Threadpool); + } + + void CallReferenceGemm(size_t M, + size_t N, + size_t K, + const float* A, + const uint8_t* PackedBData, + const float* PackedBScale, + const uint8_t* PackedBZeroPoint, + const float* Bias, + float* C) { + float* UnpackedBData = BufferUnpackedBReference.GetBuffer(K * N); + ReferenceQNBitPacking::UnpackB( + N, K, PackedBData, PackedBScale, PackedBZeroPoint, UnpackedBData, N); + + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++) { + const float* a = A + m * K; + const float* b = UnpackedBData + n; + float* c = C + (m * N) + n; + + float sum = Bias == nullptr ? 0.0f : Bias[n]; + for (size_t k = 0; k < K; k++) { + sum += (*a) * (*b); + b += N; + a += 1; + } + *c = sum; + } + } + } + + public: + void Test(size_t M, size_t N, size_t K, + bool WithBias, bool WithZeroPoint, bool WithThreadpool) { + MLAS_THREADPOOL* Threadpool = WithThreadpool ? GetMlasThreadPool() : nullptr; + + const float* A = BufferA.GetBuffer(K * M); + + const float* B = BufferB.GetBuffer(N * K); + + const float* Bias = nullptr; + if (WithBias) { + Bias = BufferBias.GetBuffer(N); + } + +#if 0 + auto print_matrix = [](size_t ncols, size_t nrows, const float* data) { + for (size_t row = 0; row < nrows; ++row) { + for (size_t col = 0; col < ncols; ++col) { + std::cout << data[row * nrows + col] << "\t"; + } + std::cout << "\n"; + } + }; + + std::cout << "A:\n"; + print_matrix(M, K, A); + std::cout << "B:\n"; + print_matrix(K, N, B); +#endif + + float* C = BufferC.GetBuffer(N * M, true); + float* CReference = BufferCReference.GetBuffer(N * M, true); + + // pack B + uint8_t* PackedBData = nullptr; + float* PackedBScale = nullptr; + uint8_t* PackedBZeroPoint = nullptr; + { + size_t PackedBDataSize, PackedBScaleSize, PackedBZeroPointSize; + ReferenceQNBitPacking::GetPackedBSizes( + N, K, PackedBDataSize, PackedBScaleSize, &PackedBZeroPointSize); + + PackedBData = BufferPackedBData.GetBuffer(PackedBDataSize); + PackedBScale = 
BufferPackedBScale.GetBuffer(PackedBScaleSize); + if (WithZeroPoint) { + PackedBZeroPoint = BufferPackedBZeroPoint.GetBuffer(PackedBZeroPointSize); + } + + ReferenceQNBitPacking::PackB(N, K, B, /* ldb */ N, + PackedBData, PackedBScale, PackedBZeroPoint); + } + + CallGemm(M, N, K, A, /* lda */ K, PackedBData, PackedBScale, PackedBZeroPoint, Bias, C, /* ldc */ N, Threadpool); + CallReferenceGemm(M, N, K, A, PackedBData, PackedBScale, PackedBZeroPoint, Bias, CReference); + + size_t f = 0; + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++, f++) { + ASSERT_TRUE(CloseEnough(C[f], CReference[f])) + << "Expected: " << CReference[f] << " Actual: " << C[f] << "@[" << m << "x" << n << "], " + << "M=" << M << ", N=" << N << ", K=" << K; + } + } + } + + public: + static const char* GetTestSuiteName() { + static std::string suite_name = std::string("SQNBitGemm") + + "BlkLen" + std::to_string(BlkLen) + + "BlkBitWidth" + std::to_string(BlkBitWidth); + return suite_name.c_str(); + } +}; + +// +// Short Execute() test helper to register each test separately by all parameters. +// +template +class SQNBitGemmShortExecuteTest : public MlasTestFixture> { + public: + explicit SQNBitGemmShortExecuteTest(size_t M, size_t N, size_t K, + bool WithThreadpool, bool WithZeroPoint, bool WithBias) + : M_(M), N_(N), K_(K), WithThreadpool_(WithThreadpool), WithZeroPoint_(WithZeroPoint), WithBias_(WithBias) { + } + + void TestBody() override { + MlasTestFixture::mlas_tester->Test( + M_, N_, K_, WithThreadpool_, WithZeroPoint_, WithBias_); + } + + static size_t RegisterSingleTest(size_t M, size_t N, size_t K, + bool WithThreadpool, bool WithZeroPoint, bool WithBias) { + std::stringstream ss; + ss << (WithThreadpool ? 
"SingleThread" : "Threaded") + << "/hasZeroPoint" << WithZeroPoint + << "/M" << M << "xN" << N << "xK" << K + << "/hasBias" << WithBias; + auto test_name = ss.str(); + + testing::RegisterTest( + MlasTesterType::GetTestSuiteName(), + test_name.c_str(), + nullptr, + test_name.c_str(), + __FILE__, + __LINE__, + // Important to use the fixture type as the return type here. + [=]() -> MlasTestFixture* { + return new SQNBitGemmShortExecuteTest( + M, N, K, WithThreadpool, WithZeroPoint, WithBias); + }); + + return 1; + } + + static size_t RegisterShortExecuteTests() { + size_t test_registered = 0; + + for (bool WithThreadpool : {false, true}) { + for (bool WithZeroPoint : {false, true}) { + for (size_t b = 1; b < 16; b++) { + test_registered += RegisterSingleTest(b, b, b, WithThreadpool, WithZeroPoint, false); + test_registered += RegisterSingleTest(b, b, b, WithThreadpool, WithZeroPoint, true); + } + for (size_t b = 16; b <= 256; b <<= 1) { + test_registered += RegisterSingleTest(b, b, b, WithThreadpool, WithZeroPoint, false); + test_registered += RegisterSingleTest(b, b, b, WithThreadpool, WithZeroPoint, true); + } + for (size_t b = 256; b < 320; b += 32) { + test_registered += RegisterSingleTest(b, b, b, WithThreadpool, WithZeroPoint, true); + } + for (size_t b = 1; b < 96; b++) { + test_registered += RegisterSingleTest(1, b, 32, WithThreadpool, WithZeroPoint, false); + test_registered += RegisterSingleTest(1, 32, b, WithThreadpool, WithZeroPoint, true); + test_registered += RegisterSingleTest(1, b, b, WithThreadpool, WithZeroPoint, false); + } + test_registered += RegisterSingleTest(43, 500, 401, WithThreadpool, WithZeroPoint, true); + + // test_registered += RegisterSingleTest(1001, 1027, 1031, WithThreadpool, WithZeroPoint, false); + } + } + + return test_registered; + } + + private: + size_t M_, N_, K_; + bool WithThreadpool_, WithZeroPoint_, WithBias_; +}; + +template <> +MlasSQNBitGemmTest<32, 4>* MlasTestFixture>::mlas_tester(nullptr); + +static size_t 
SQNBitGemmRegisterAllShortExecuteTests() { + size_t count = 0; + + count += SQNBitGemmShortExecuteTest<32, 4>::RegisterShortExecuteTests(); + + return count; +} + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { + if (is_short_execute) { + return SQNBitGemmRegisterAllShortExecuteTests() > 0; + } + return false; +}); diff --git a/onnxruntime/test/mlas/unittest/test_util.h b/onnxruntime/test/mlas/unittest/test_util.h index c5ee8b4b6115a..1d25af677ae1a 100644 --- a/onnxruntime/test/mlas/unittest/test_util.h +++ b/onnxruntime/test/mlas/unittest/test_util.h @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -177,6 +178,8 @@ bool AddTestRegister(TestRegister test_register); template class MlasTestFixture : public testing::Test { public: + using MlasTesterType = TMlasTester; + static void SetUpTestSuite() { mlas_tester = new TMlasTester(); }; @@ -254,3 +257,16 @@ inline void ReorderInputNchw(const int64_t* input_shape, const float* S, float* D += spatial_count * nchwc_channel_count; } } + +inline bool CloseEnough(float actual, float expected) { + if (std::isnan(actual)) { + return std::isnan(expected); + } + float diff = std::abs(actual - expected); + float top = std::max(std::abs(actual), std::abs(expected)); + float ratio = 0; + if (top > 0.0001) { + ratio = diff / top; + } + return ratio < 0.005; +}