From db76bb48e88152eb16cc2a376a1d700cd8c8028d Mon Sep 17 00:00:00 2001 From: Jing Fang <126209182+fajin-corp@users.noreply.github.com> Date: Fri, 15 Nov 2024 22:59:11 +0000 Subject: [PATCH] [ARM] MatMulNBits fp16 support - connect kernels (#22856) ### Description A breakdown PR of https://github.com/microsoft/onnxruntime/pull/22651 ### Motivation and Context --- cmake/onnxruntime_mlas.cmake | 8 +- .../mlas/lib/hqnbitgemm_kernel_neon_fp16.cpp | 2 +- onnxruntime/core/mlas/lib/platform.cpp | 4 +- onnxruntime/core/mlas/lib/qnbitgemm.cpp | 94 +++- onnxruntime/core/mlas/lib/qnbitgemm.h | 81 ++- ...nel_neon.cpp => qnbitgemm_kernel_neon.cpp} | 17 +- ..._kernel_neon.h => qnbitgemm_kernel_neon.h} | 45 +- .../core/mlas/lib/sqnbitgemm_kernel_avx2.cpp | 4 +- .../mlas/lib/sqnbitgemm_kernel_avx512.cpp | 2 +- .../mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp | 2 +- .../mlas/lib/sqnbitgemm_kernel_neon_fp32.cpp | 4 +- .../mlas/lib/sqnbitgemm_kernel_neon_int8.cpp | 2 +- .../test/contrib_ops/matmul_4bits_test.cc | 74 +-- ...nch_sqnbitgemm.cpp => bench_qnbitgemm.cpp} | 91 ++-- onnxruntime/test/mlas/bench/bench_util.h | 25 +- .../mlas/unittest/test_hqnbitgemm_neon.cpp | 504 ++++++++++++++++++ onnxruntime/test/mlas/unittest/test_util.h | 12 +- 17 files changed, 850 insertions(+), 121 deletions(-) rename onnxruntime/core/mlas/lib/{sqnbitgemm_kernel_neon.cpp => qnbitgemm_kernel_neon.cpp} (87%) rename onnxruntime/core/mlas/lib/{sqnbitgemm_kernel_neon.h => qnbitgemm_kernel_neon.h} (74%) rename onnxruntime/test/mlas/bench/{bench_sqnbitgemm.cpp => bench_qnbitgemm.cpp} (65%) create mode 100644 onnxruntime/test/mlas/unittest/test_hqnbitgemm_neon.cpp diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index a85ea942c42a3..22971f3313a60 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -84,8 +84,8 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.h - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp + ${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.h + ${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp ${MLAS_SRC_DIR}/fp16_neon_common.cpp @@ -363,8 +363,8 @@ else() ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.h - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp + ${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.h + ${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp ) diff --git a/onnxruntime/core/mlas/lib/hqnbitgemm_kernel_neon_fp16.cpp b/onnxruntime/core/mlas/lib/hqnbitgemm_kernel_neon_fp16.cpp index f1bc013a469d9..69e37d2b916d1 100644 --- a/onnxruntime/core/mlas/lib/hqnbitgemm_kernel_neon_fp16.cpp +++ b/onnxruntime/core/mlas/lib/hqnbitgemm_kernel_neon_fp16.cpp @@ -24,7 +24,7 @@ Module Name: #include "fp16_common.h" #include "qnbitgemm.h" -#include "sqnbitgemm_kernel_neon.h" +#include "qnbitgemm_kernel_neon.h" namespace sqnbitgemm_neon { diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 12f7dd3e74dbc..81bef3b9f194c 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -531,6 +531,7 @@ Return Value: this->SymmQgemmDispatch = &MlasSymmQgemmS8DispatchNeon; this->ConvSymU8S8Dispatch = &MlasConvSymU8DispatchNeon; this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchNeon; + this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon; // // Check if the processor supports ASIMD dot product instructions. @@ -560,9 +561,6 @@ Return Value: this->SymmQgemmDispatch = &MlasSymmQgemmS8DispatchSdot; this->ConvSymU8S8Dispatch = &MlasConvSymU8DispatchDot; this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchDot; - - // MlasSQNBitGemmDispatchNeon has a dependency on dot product instructions - this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon; } #if defined(__linux__) diff --git a/onnxruntime/core/mlas/lib/qnbitgemm.cpp b/onnxruntime/core/mlas/lib/qnbitgemm.cpp index 635a3b47a23fa..f064a8e1d6a78 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/qnbitgemm.cpp @@ -82,7 +82,12 @@ MlasIsQNBitGemmAvailable( switch (Variant) { case SQNBitGemmVariant_BitWidth4_CompFp32: { return Dispatch->SQ4BitGemmM1Kernel_CompFp32 != nullptr && - Dispatch->Q4BitBlkDequantBForSgemm_CompFp32 != nullptr; + Dispatch->SQ4BitBlkDequantBForSgemm_CompFp32 != nullptr; + } + case HQNBitGemmVariant_BitWidth4_CompFp16: { + return Dispatch->HQ4BitGemmPackQuantBData != nullptr && + Dispatch->HQ4BitGemmKernel_CompFp16 != nullptr && + Dispatch->HQ4BitBlkDequantBForHgemm_CompFp16 != nullptr; } case SQNBitGemmVariant_BitWidth4_CompInt8: { // SQ4BitGemmKernel_BlkSum_CompInt8 return @@ -253,6 +258,16 @@ MlasQNBitGemmPackQuantBData( packed_quant_b, ThreadPool ); + } else if (ComputeType == HQNBIT_CompFp16 && Dispatch->HQ4BitGemmPackQuantBData != nullptr) { + Dispatch->HQ4BitGemmPackQuantBData( + N, + K, + BlkLen, + ComputeType, + static_cast(QuantBData), + static_cast(PackedQuantBDataAndOrBlkSumWorkspace), + ThreadPool + ); } else if (Dispatch->SQ4BitGemmPackQuantBData != nullptr) { // TODO: these assertions are true if called from matmul_nbits kernel but not from mlas tests. //assert(QuantBScale == nullptr); @@ -387,7 +402,7 @@ SQ4BitGemm_CompFp32( float* c_blk = C + n; const float* bias = (Bias == nullptr) ? nullptr : Bias + n; - GetMlasPlatform().QNBitGemmDispatch->Q4BitBlkDequantBForSgemm_CompFp32( + GetMlasPlatform().QNBitGemmDispatch->SQ4BitBlkDequantBForSgemm_CompFp32( BlkLen, dequant_b, b_col, b_col_scale, b_col_zp, CountN, K, k_blks ); @@ -419,6 +434,79 @@ SQ4BitGemm_CompFp32( } } +void +HQ4BitGemm_CompFp16( + const size_t BlkLen, + const size_t K, + const MLAS_QNBIT_GEMM_DATA_PARAMS* const DataParams, + void* const PerGemmWorkspace, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN +) +{ + constexpr size_t BlkBitWidth = 4; + MLAS_UNREFERENCED_PARAMETER(PerGemmWorkspace); + + const size_t lda = DataParams->lda; + const size_t ldc = DataParams->ldc; + const size_t k_blk_num = MlasDivRoundup(K, BlkLen); + const size_t qldb = k_blk_num * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t ldb = k_blk_num * BlkLen; + const size_t k_zp_bytes = MlasQNBitZeroPointsForBlksSizeInBytes(k_blk_num); + + const MLAS_FP16* A = DataParams->A + RangeStartM * lda; + MLAS_FP16* C = DataParams->C + RangeStartM * ldc + RangeStartN; + const std::byte* QuantBData = static_cast(DataParams->PackedQuantBData) + RangeStartN * qldb; + const MLAS_FP16* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blk_num; + const std::byte* QuantBZeroPoint = + (DataParams->QuantBZeroPoint == nullptr) + ? nullptr + : static_cast(DataParams->QuantBZeroPoint) + RangeStartN * k_zp_bytes; + const MLAS_FP16* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias; + + // 32N is the sweet spot of cache utilization. It is machine dependent though. + constexpr size_t StrideM = 2; + constexpr size_t StrideN = 32; + + // TODO(fajin): move allocation up to the op. + size_t bufsize = ldb * StrideN * sizeof(MLAS_FP16); + MlasThreadedBufAlloc(bufsize); + auto* dequant_b = reinterpret_cast(ThreadedBufHolder.get()); + + for (size_t n = 0, countN; n < RangeCountN; n += countN) { + countN = std::min(StrideN, RangeCountN - n); + GetMlasPlatform().QNBitGemmDispatch->HQ4BitBlkDequantBForHgemm_CompFp16( + BlkLen, dequant_b, QuantBData, QuantBScale, QuantBZeroPoint, countN, K, k_blk_num + ); + + const MLAS_FP16* a = A; + MLAS_FP16* c = C; + for (size_t m = 0, countM; m < RangeCountM; m += countM) { + countM = std::min(StrideM, RangeCountM - m); + GetMlasPlatform().QNBitGemmDispatch->HQ4BitGemmKernel_CompFp16( + a, dequant_b, Bias, c, countM, countN, K, lda, ldb, ldc + ); + + if (DataParams->PostProcessor != nullptr) { + DataParams->PostProcessor->Process( + DataParams->C, RangeStartM + m, RangeStartN + n, countM, countN, ldc + ); + } + + a += countM * lda; + c += countM * ldc; + } + + QuantBData += countN * qldb; + QuantBScale += countN * k_blk_num; + QuantBZeroPoint = QuantBZeroPoint ? QuantBZeroPoint + countN * k_zp_bytes : nullptr; + Bias = Bias ? Bias + countN : nullptr; + C += countN; + } +} + void SQ4BitGemm_CompInt8( const size_t BlkLen, @@ -720,7 +808,7 @@ GetQNBitGemm(QNBitGemmVariant variant) { switch (variant) { case HQNBitGemmVariant_BitWidth4_CompFp16: - return nullptr; + return HQ4BitGemm_CompFp16; default: return nullptr; } diff --git a/onnxruntime/core/mlas/lib/qnbitgemm.h b/onnxruntime/core/mlas/lib/qnbitgemm.h index 28e17f14b02c9..eb3d0b44ae3de 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm.h +++ b/onnxruntime/core/mlas/lib/qnbitgemm.h @@ -91,17 +91,17 @@ struct MLAS_QNBIT_GEMM_DISPATCH { // /** Gets size of packed quantized B data containing 4-bit integers. See MlasQNBitGemmPackQuantBDataSize(). */ - typedef size_t(SQ4BitGemmPackQuantBDataSize_Fn)( + typedef size_t(Q4BitGemmPackQuantBDataSize_Fn)( size_t N, size_t K, size_t BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ); - SQ4BitGemmPackQuantBDataSize_Fn* Q4BitGemmPackQuantBDataSize = nullptr; + Q4BitGemmPackQuantBDataSize_Fn* Q4BitGemmPackQuantBDataSize = nullptr; /** Packs quantized B data containing 4-bit integers. See MlasQNBitGemmPackQuantBData(). */ - typedef void(SQ4BitGemmPackQuantBData_Fn)( + typedef void(Q4BitGemmPackQuantBData_Fn)( size_t N, size_t K, size_t BlkLen, @@ -111,7 +111,8 @@ struct MLAS_QNBIT_GEMM_DISPATCH { MLAS_THREADPOOL* ThreadPool ); - SQ4BitGemmPackQuantBData_Fn* SQ4BitGemmPackQuantBData = nullptr; + Q4BitGemmPackQuantBData_Fn* SQ4BitGemmPackQuantBData = nullptr; + Q4BitGemmPackQuantBData_Fn* HQ4BitGemmPackQuantBData = nullptr; typedef void(SQ4BitGemmPackQuantBDataAndSumBlk_Fn)( size_t N, @@ -142,7 +143,7 @@ struct MLAS_QNBIT_GEMM_DISPATCH { * @param[in] BlkLen number of quantized values per block * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) */ - typedef size_t(SQ4BitGemmPerGemmWorkspaceSize_Fn)( + typedef size_t(Q4BitGemmPerGemmWorkspaceSize_Fn)( size_t M, size_t N, size_t K, @@ -150,7 +151,7 @@ struct MLAS_QNBIT_GEMM_DISPATCH { MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ); - SQ4BitGemmPerGemmWorkspaceSize_Fn* Q4BitGemmPerGemmWorkspaceSize = nullptr; + Q4BitGemmPerGemmWorkspaceSize_Fn* Q4BitGemmPerGemmWorkspaceSize = nullptr; /** * @brief Gets the required byte alignment of the per-GEMM intermediate workspace. @@ -158,12 +159,12 @@ struct MLAS_QNBIT_GEMM_DISPATCH { * @param[in] BlkLen number of quantized values per block * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) */ - typedef size_t(SQ4BitGemmPerGemmWorkspaceAlignment_Fn)( + typedef size_t(Q4BitGemmPerGemmWorkspaceAlignment_Fn)( size_t BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ); - SQ4BitGemmPerGemmWorkspaceAlignment_Fn* Q4BitGemmPerGemmWorkspaceAlignment = nullptr; + Q4BitGemmPerGemmWorkspaceAlignment_Fn* Q4BitGemmPerGemmWorkspaceAlignment = nullptr; // // SQNBIT_CompFp32 kernel function prototypes. @@ -229,7 +230,38 @@ struct MLAS_QNBIT_GEMM_DISPATCH { size_t BlockStrideQuantB ); - Q4BitBlkDequantBForSgemm_CompFp32_Fn* Q4BitBlkDequantBForSgemm_CompFp32 = nullptr; + Q4BitBlkDequantBForSgemm_CompFp32_Fn* SQ4BitBlkDequantBForSgemm_CompFp32 = nullptr; + + /** + * @brief Dequantize B into the format expected by the Sgemm kernel. + * B is a quantized 4-bit integer matrix that is block quantized and column major. + * This is equivalent to dequantizing B and then running MlasSgemmCopyPackB. + * + * @param BlkLen Number of values in a block. + * @param[out] FpData Supplies the output buffer for the dequantized B float data. + * It should have enough space for + * (CountN + 16 - 1) / 16 * 16 * (CountK + BlkLen - 1) / BlkLen * BlkLen + * elements. Only the first (CountN + 16 - 1) / 16 * 16 * CountK elements are + * useful, but the kernel implementation can be simplified with the extra space. + * @param QuantBData Supplies the quantized B matrix block data. + * @param QuantBScale Supplies the quantized B matrix block scale values. + * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. + * @param CountN Number of columns of B. + * @param CountK Number of rows of B. + * @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix. + */ + typedef void(Q4BitBlkDequantBForSgemm_CompFp16_Fn)( + size_t BlkLen, + MLAS_FP16* FpData, + const std::byte* QuantBData, + const MLAS_FP16* QuantBScale, + const std::byte* QuantBZeroPoint, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB + ); + + Q4BitBlkDequantBForSgemm_CompFp16_Fn* HQ4BitBlkDequantBForHgemm_CompFp16 = nullptr; // // SQNBIT_CompInt8 kernel function prototypes. @@ -338,4 +370,35 @@ struct MLAS_QNBIT_GEMM_DISPATCH { float* AScaledGroupSum // scale_k * Sum_blklen(a_i) ); QuantizeARowComputeBlkSum_CompInt8_Fn* QuantizeARowComputeBlkSum_CompInt8 = nullptr; + + /** + * @brief Multiply fp16 matrix A rows with fp16 matrix B columns. + * Results are written to fp16 matrix C. + * If bias is provided, the bias are added to the result. + * + * @param A first row of the A matrix segment. Row major. + * @param B first column of the B matrix segment. Column major. + * @param Bias the bias at the target column. Optional. + * @param[out] C first element of the output matrix segment. Row major. + * @param CountM the number of rows of A chunk. + * @param CountN the number of columns of B chunk. + * @param K the number of columns of A matrix and rows of B matrix. + * @param lda the leading dimension of A. + * @param ldb the leading dimension of B. + * @param ldc the leading dimension of C. + */ + typedef void(HQ4BitGemmKernel_CompFp16_Fn)( + const MLAS_FP16* A, + const MLAS_FP16* B, + const MLAS_FP16* Bias, + MLAS_FP16* C, + size_t CountM, + size_t CountN, + size_t K, + size_t lda, + size_t ldb, + size_t ldc + ); + + HQ4BitGemmKernel_CompFp16_Fn* HQ4BitGemmKernel_CompFp16 = nullptr; }; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp similarity index 87% rename from onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp rename to onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp index 03c8ce264c846..d05de64e68ec8 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp @@ -6,7 +6,7 @@ Licensed under the MIT License. Module Name: - sqnbitgemm_kernel_neon.cpp + qnbitgemm_kernel_neon.cpp Abstract: @@ -20,7 +20,7 @@ Module Name: #include #include "qnbitgemm.h" -#include "sqnbitgemm_kernel_neon.h" +#include "qnbitgemm_kernel_neon.h" #include "sqnbitgemm_q8_block.h" namespace sqnbitgemm_neon @@ -185,10 +185,17 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon = []() { d.Q4BitGemmPerGemmWorkspaceAlignment = sqnbitgemm_neon::Q4BitGemmPerGemmWorkspaceAlignment; d.SQ4BitGemmM1Kernel_CompFp32 = sqnbitgemm_neon::SQ4BitGemmM1Kernel_CompFp32; - d.Q4BitBlkDequantBForSgemm_CompFp32 = sqnbitgemm_neon::Q4BitBlkDequantBForSgemm_CompFp32; - - d.SQ4BitGemmKernel_CompInt8 = sqnbitgemm_neon::SQ4BitGemmKernel_CompInt8; + d.SQ4BitBlkDequantBForSgemm_CompFp32 = sqnbitgemm_neon::SQ4BitBlkDequantBForSgemm_CompFp32; + if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeonDot()) { + d.SQ4BitGemmKernel_CompInt8 = sqnbitgemm_neon::SQ4BitGemmKernel_CompInt8; + } d.QuantizeARow_CompInt8 = sqnbitgemm_neon::QuantizeARow_CompInt8; +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + d.HQ4BitGemmPackQuantBData = sqnbitgemm_neon::HQ4BitGemmPackQuantBData_CompFp16; + d.HQ4BitBlkDequantBForHgemm_CompFp16 = sqnbitgemm_neon::HQ4BitBlkDequantBForHgemm_CompFp16; + d.HQ4BitGemmKernel_CompFp16 = sqnbitgemm_neon::HQ4BitGemmKernel_CompFp16; +#endif // MLAS_F16VEC_INTRINSICS_SUPPORTED && MLAS_TARGET_ARM64 + return d; }(); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.h b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.h similarity index 74% rename from onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.h rename to onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.h index 247d885615393..ccadd24ac1991 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.h +++ b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.h @@ -6,7 +6,7 @@ Licensed under the MIT License. Module Name: - sqnbitgemm_kernel_neon.h + qnbitgemm_kernel_neon.h Abstract: @@ -53,7 +53,7 @@ SQ4BitGemmM1Kernel_CompFp32( ); void -Q4BitBlkDequantBForSgemm_CompFp32( +SQ4BitBlkDequantBForSgemm_CompFp32( size_t BlkLen, float* FpData, const std::byte* QuantBData, @@ -64,6 +64,47 @@ Q4BitBlkDequantBForSgemm_CompFp32( size_t BlockCountK ); +// HQNBIT_CompFp16 declarations +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) +void +HQ4BitGemmPackQuantBData_CompFp16( + size_t N, + size_t K, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool +); + +void +HQ4BitBlkDequantBForHgemm_CompFp16( + size_t BlkLen, + MLAS_FP16* FpData, + const std::byte* QuantBData, + const MLAS_FP16* QuantBScale, + const std::byte* QuantBZeroPoint, + size_t CountN, + size_t K, + size_t BlockCountK +); + +void +HQ4BitGemmKernel_CompFp16( + const MLAS_FP16* A, + const MLAS_FP16* B, + const MLAS_FP16* Bias, + MLAS_FP16* C, + size_t CountM, + size_t CountN, + size_t K, + size_t lda, + size_t ldb, + size_t ldc +); + +#endif // !(defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64)) + // SQNBIT_CompInt8 declarations void diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp index 01443e2ff077f..81615da46aa2e 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp @@ -1341,7 +1341,7 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() { d.Q4BitGemmPerGemmWorkspaceAlignment = Q4BitGemmPerGemmWorkspaceAlignment; d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx2; - d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; + d.SQ4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx2; d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx2; @@ -1360,7 +1360,7 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni = []() { d.Q4BitGemmPerGemmWorkspaceAlignment = Q4BitGemmPerGemmWorkspaceAlignment; d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx2; - d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; + d.SQ4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx2vnni; d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx2; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp index 425dfbe87c982..b4e25d4e4040a 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp @@ -363,7 +363,7 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() { d.Q4BitGemmPerGemmWorkspaceAlignment = Q4BitGemmPerGemmWorkspaceAlignment; d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx512; - d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; + d.SQ4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx512; d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx512; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp index 777d4609ef5d4..a4468bb906bbc 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp @@ -348,7 +348,7 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni = []() { d.Q4BitGemmPerGemmWorkspaceAlignment = Q4BitGemmPerGemmWorkspaceAlignment; d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32; - d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; + d.SQ4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx512vnni; d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx512; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_fp32.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_fp32.cpp index 7b9f05a9c385d..31a499b8243af 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_fp32.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_fp32.cpp @@ -22,7 +22,7 @@ Module Name: #include #include "qnbitgemm.h" -#include "sqnbitgemm_kernel_neon.h" +#include "qnbitgemm_kernel_neon.h" namespace sqnbitgemm_neon { @@ -608,7 +608,7 @@ Q4BitBlkDequantBForSgemm_CompFp32_Impl( } // namespace void -Q4BitBlkDequantBForSgemm_CompFp32( +SQ4BitBlkDequantBForSgemm_CompFp32( size_t BlkLen, float* FpData, const std::byte* QuantBData, diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp index f1acd99c7b693..73beb06a3cfad 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp @@ -22,7 +22,7 @@ Module Name: #include #include "qnbitgemm.h" -#include "sqnbitgemm_kernel_neon.h" +#include "qnbitgemm_kernel_neon.h" #include "sqnbitgemm_q8_block.h" namespace sqnbitgemm_neon diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index 7d99ffab9a88f..87a9ef762b9ac 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -82,6 +82,7 @@ struct TestOptions { bool has_bias{false}; std::optional output_abs_error{}; + std::optional output_rel_error{}; }; std::ostream& operator<<(std::ostream& os, const TestOptions& opts) { @@ -253,6 +254,10 @@ void RunTest(const TestOptions& opts, test.SetOutputAbsErr("Y", *opts.output_abs_error); } + if (opts.output_rel_error.has_value()) { + test.SetOutputRelErr("Y", *opts.output_rel_error); + } + if (!explicit_eps.empty()) { test.ConfigEps(std::move(explicit_eps)); } @@ -269,16 +274,11 @@ void TestMatMulNBitsTyped() { base_opts.block_size = block_size; base_opts.accuracy_level = accuracy_level; - if (base_opts.accuracy_level == 4) { + if constexpr (std::is_same::value) { + base_opts.output_abs_error = 0.055f; + base_opts.output_rel_error = 0.02f; + } else if (base_opts.accuracy_level == 4) { base_opts.output_abs_error = 0.1f; - } else { - if constexpr (std::is_same::value) { -#ifdef USE_WEBGPU - base_opts.output_abs_error = 0.03f; -#else - base_opts.output_abs_error = 0.01f; -#endif - } } { @@ -391,48 +391,48 @@ TEST(MatMulNBits, Float32_Accuracy4) { TestMatMulNBitsTyped(); } -#ifdef MLAS_TARGET_AMD64_IX86 +#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_ARM64) #if !defined(USE_DML) // Actual and expected difference is over 0.01 with DmlExecutionProvider. // Skip the tests instead of raising the tolerance to make is pass. +TEST(MatMulNBits, Float16_Accuracy2) { + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); +} + TEST(MatMulNBits, Float16_Accuracy0) { TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); } -TEST(MatMulNBits, Float16_Accuracy1) { - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); -} - TEST(MatMulNBits, Float16_Accuracy4) { TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); diff --git a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp b/onnxruntime/test/mlas/bench/bench_qnbitgemm.cpp similarity index 65% rename from onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp rename to onnxruntime/test/mlas/bench/bench_qnbitgemm.cpp index df543d8eca1fc..64d229889214b 100644 --- a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/bench/bench_qnbitgemm.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include "benchmark/benchmark.h" @@ -16,16 +17,16 @@ #include "core/util/thread_utils.h" #include "core/platform/env_var_utils.h" -template -void RunSQNBitGemmBenchmark(size_t BlkLen, - size_t M, size_t N, size_t K, - size_t Threads, - bool Symmetric, - bool HasBias, - MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, - benchmark::State& state) { +template +void RunQNBitGemmBenchmark(size_t BlkLen, + size_t M, size_t N, size_t K, + size_t Threads, + bool Symmetric, + bool HasBias, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + benchmark::State& state) { if (!MlasIsQNBitGemmAvailable(BlkBitWidth, BlkLen, ComputeType)) { - state.SkipWithMessage("SQNBitGemm is not available with the given configuration on the current machine."); + state.SkipWithMessage("QNBitGemm is not available with the given configuration on the current machine."); return; } @@ -43,19 +44,19 @@ void RunSQNBitGemmBenchmark(size_t BlkLen, onnxruntime::concurrency::CreateThreadPool(&onnxruntime::Env::Default(), tpo, onnxruntime::concurrency::ThreadPoolType::INTRA_OP)); - const auto A = RandomVectorUniform(M * K, -1.0f, 1.0f); - const auto B = RandomVectorUniform(K * N, -1.0f, 1.0f); + const auto A = RandomVectorUniform(M * K, AType(-1.0f), AType(1.0f)); + const auto B = RandomVectorUniform(K * N, AType(-1.0f), AType(1.0f)); - const auto Bias = HasBias ? RandomVectorUniform(N, -1.0f, 1.0f) : std::vector(); + const auto Bias = HasBias ? RandomVectorUniform(N, AType(-1.0f), AType(1.0f)) : std::vector(); - std::vector C(static_cast(M * N)); + std::vector C(static_cast(M * N)); std::vector QuantBData(QuantBDataSizeInBytes); - std::vector QuantBScale(QuantBScaleSize); + std::vector QuantBScale(QuantBScaleSize); std::vector QuantBZeroPoint(Symmetric ? 0 : QuantBZeroPointSizeInBytes); bool has_zp_input = !Symmetric; - MlasQuantizeBlockwise(QuantBData.data(), QuantBScale.data(), + MlasQuantizeBlockwise(QuantBData.data(), QuantBScale.data(), Symmetric ? nullptr : QuantBZeroPoint.data(), B.data(), static_cast(BlkLen), /* columnwise */ true, static_cast(K), static_cast(N), static_cast(N), @@ -76,7 +77,7 @@ void RunSQNBitGemmBenchmark(size_t BlkLen, tp.get()); } - MLAS_QNBIT_GEMM_DATA_PARAMS params{}; + MLAS_QNBIT_GEMM_DATA_PARAMS params{}; params.A = A.data(); params.lda = K; if (PackedQuantBData != nullptr) @@ -99,8 +100,8 @@ void RunSQNBitGemmBenchmark(size_t BlkLen, } } -template -void SQNBITGEMM(benchmark::State& state) { +template +void QNBITGEMM(benchmark::State& state) { using onnxruntime::narrow; const auto BlkLen = narrow(state.range(0)); @@ -112,44 +113,48 @@ void SQNBITGEMM(benchmark::State& state) { const bool HasBias = narrow(state.range(6)); const auto ComputeType = static_cast(state.range(7)); - RunSQNBitGemmBenchmark(BlkLen, M, N, K, Threads, Symmetric, HasBias, ComputeType, state); + RunQNBitGemmBenchmark(BlkLen, M, N, K, Threads, Symmetric, HasBias, ComputeType, state); } -static void SQNBitGemmArgs(benchmark::internal::Benchmark* b) { +template +static void QNBitGemmArgs(benchmark::internal::Benchmark* b) { b->ArgNames({"BlkLen", "M", "N", "K", "Threads", "Symmetric", "HasBias", "ComputeType"}); b->ArgsProduct({ - {128}, // BlkLen - {1}, // M - {4096, 11008}, // N - {4096, 11008}, // K - {1, 8}, // Threads - {int64_t{false}, int64_t{true}}, // Symmetric - {int64_t{false}, int64_t{true}}, // HasBias - {int64_t{SQNBIT_CompFp32}, int64_t{SQNBIT_CompInt8}}, // ComputeType + {128}, // BlkLen + {1, 4096}, // M + {4096, 11008}, // N + {4096, 11008}, // K + {1, 8}, // Threads + {int64_t{false}, int64_t{true}}, // Symmetric + {int64_t{false}, int64_t{true}}, // HasBias + std::is_same_v + ? std::vector{int64_t{HQNBIT_CompFp16}} + : std::vector{int64_t{SQNBIT_CompFp32}, int64_t{SQNBIT_CompInt8}}, // ComputeType }); } -BENCHMARK(SQNBITGEMM<4>)->Apply(SQNBitGemmArgs)->UseRealTime(); +BENCHMARK(QNBITGEMM)->Apply(QNBitGemmArgs)->UseRealTime(); +BENCHMARK(QNBITGEMM)->Apply(QNBitGemmArgs)->UseRealTime(); // This test gets benchmark arguments from environment variables. -template -void SQNBITGEMM_ENV(benchmark::State& state) { +template +void QNBITGEMM_ENV(benchmark::State& state) { using onnxruntime::ParseEnvironmentVariableWithDefault; - const auto BlkLen = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_BLKLEN", 32); - const auto M = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_M", 1); - const auto N = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_N", 4096); - const auto K = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_K", 4096); - const auto Threads = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_THREADS", 1); - const auto Symmetric = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_SYMMETRIC", true); - const auto HasBias = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_HAS_BIAS", false); - const auto ComputeType = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_COMPUTE_TYPE", + const auto BlkLen = ParseEnvironmentVariableWithDefault("ORT_QNBITGEMM_BLKLEN", 32); + const auto M = ParseEnvironmentVariableWithDefault("ORT_QNBITGEMM_M", 1); + const auto N = ParseEnvironmentVariableWithDefault("ORT_QNBITGEMM_N", 4096); + const auto K = ParseEnvironmentVariableWithDefault("ORT_QNBITGEMM_K", 4096); + const auto Threads = ParseEnvironmentVariableWithDefault("ORT_QNBITGEMM_THREADS", 1); + const auto Symmetric = ParseEnvironmentVariableWithDefault("ORT_QNBITGEMM_SYMMETRIC", true); + const auto HasBias = ParseEnvironmentVariableWithDefault("ORT_QNBITGEMM_HAS_BIAS", false); + const auto ComputeType = ParseEnvironmentVariableWithDefault("ORT_QNBITGEMM_COMPUTE_TYPE", static_cast(SQNBIT_CompFp32)); - RunSQNBitGemmBenchmark(BlkLen, M, N, K, Threads, Symmetric, HasBias, - static_cast(ComputeType), - state); + RunQNBitGemmBenchmark(BlkLen, M, N, K, Threads, Symmetric, HasBias, + static_cast(ComputeType), + state); std::ostringstream s; s << "BlkBitWidth:" << BlkBitWidth << "/BlkLen:" << BlkLen @@ -159,4 +164,4 @@ void SQNBITGEMM_ENV(benchmark::State& state) { state.SetLabel(s.str()); } -BENCHMARK(SQNBITGEMM_ENV<4>)->UseRealTime(); +BENCHMARK(QNBITGEMM_ENV)->UseRealTime(); diff --git a/onnxruntime/test/mlas/bench/bench_util.h b/onnxruntime/test/mlas/bench/bench_util.h index f96dd5c673b3d..78789ef1cbc1a 100644 --- a/onnxruntime/test/mlas/bench/bench_util.h +++ b/onnxruntime/test/mlas/bench/bench_util.h @@ -8,8 +8,12 @@ #include #include +#include "core/framework/float16.h" +#include "core/mlas/inc/mlas.h" + template -std::vector RandomVectorUniform( +typename std::enable_if_t, std::vector> +RandomVectorUniform( size_t N, ElementType min_value = std::numeric_limits::lowest(), ElementType max_value = std::numeric_limits::max()) { @@ -26,6 +30,25 @@ std::vector RandomVectorUniform( return r; } +template +typename std::enable_if_t, std::vector> +RandomVectorUniform( + size_t N, + ElementType min_value, + ElementType max_value) { + if (min_value.ToFloat() >= max_value.ToFloat()) { + return std::vector(N, min_value); + } + std::default_random_engine generator(static_cast(N)); + std::uniform_real_distribution distribution(min_value.ToFloat(), max_value.ToFloat()); + + std::vector r(N); + for (size_t i = 0; i < N; i++) { + r[i] = ElementType(distribution(generator)); + } + return r; +} + std::vector RandomVectorUniform(std::vector shape, float min_value, float max_value); std::vector BenchArgsVector(benchmark::State& state, size_t& start, size_t count); diff --git a/onnxruntime/test/mlas/unittest/test_hqnbitgemm_neon.cpp b/onnxruntime/test/mlas/unittest/test_hqnbitgemm_neon.cpp new file mode 100644 index 0000000000000..a455007c2f6ae --- /dev/null +++ b/onnxruntime/test/mlas/unittest/test_hqnbitgemm_neon.cpp @@ -0,0 +1,504 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + test_hqnbitgemm_neon.cpp + +Abstract: + + Tests for MLAS n-bit int block quantized GEMM on ARM CPU with input A type T1 fp16. + +--*/ + +#include +#include + +#include "test_util.h" +#include "core/mlas/lib/mlasi.h" +#include "core/mlas/lib/qnbitgemm.h" +#include "mlas_qnbit.h" + +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + +class MlasNeonFp16CastTest : public MlasTestBase { + private: + MatrixGuardBuffer fp32Buffer_; + MatrixGuardBuffer fp16Buffer_; + + template + void TestFp16ToFp32() { + const auto* src = fp16Buffer_.GetFilledBuffer(count, [](unsigned short* start, size_t size) { + for (size_t i = 0; i < size; i++) { + start[i] = static_cast(i); + } + }); + auto* dest = fp32Buffer_.GetBuffer(count, true); + + MlasCastF16ToF32KernelNeon(src, dest, count); + + for (size_t i = 0; i < count; i++) { + if ((src[i] & 0x1c00) == 0x1c00) continue; // skip inf and nan + ASSERT_EQ(dest[i], MLAS_FP16::FromBits(src[i]).ToFloat()); + } + } + + template + void TestFp32ToFp16() { + const auto* src = fp32Buffer_.GetFilledBuffer(count, [](float* p, size_t size) { + for (size_t i = 0; i < size; i++) { + p[i] = static_cast(i) + 0.125f; + } + }); + auto* dest = fp16Buffer_.GetBuffer(count, true); + + MlasCastF32ToF16KernelNeon(src, dest, count); + + for (size_t i = 0; i < count; i++) { + ASSERT_EQ(dest[i], MLAS_FP16(src[i]).val); + } + } + + public: + static const char* GetTestSuiteName() { + return "NeonFp16Cast"; + } + + void ExecuteShort(void) override { + TestFp16ToFp32<(1 << 16)>(); + TestFp16ToFp32<1>(); + TestFp16ToFp32<4>(); + TestFp16ToFp32<7>(); + TestFp32ToFp16<(1 << 16)>(); + TestFp32ToFp16<3>(); + TestFp32ToFp16<4>(); + TestFp32ToFp16<6>(); + } +}; + +class MlasNeonFp16PrepackTest : public MlasTestBase { + private: + std::random_device rd_; // a seed source for the random number engine + unsigned int seed_; + std::mt19937 gen_; // mersenne_twister_engine seeded with rd() + std::uniform_int_distribution<> distrib_; + MatrixGuardBuffer input_, ref_, packed_; + + template + MLAS_FORCEINLINE void Transpose8x8(const uint8_t* src, size_t n, size_t k, uint8_t* dst) { + for (size_t c = 0; c < 8; c++) { + for (size_t r = 0; r < 8; r++) { + size_t i = (n + c) * Ldb + r + k; + size_t j = n * Ldb + (r + k) * 8 + c; + dst[j] = src[i]; + } + } + } + + MLAS_FORCEINLINE + uint8_t GetInt4(uint8_t v, size_t i) { + return (i & 1) ? (v >> 4) : (v & 0x0f); + } + + MLAS_FORCEINLINE + void PrepackSlice(const uint8_t* src, size_t j, uint8_t* dst) { + for (size_t i = 0; i < 8; i++) { + uint8_t v0 = GetInt4(src[j + (i >> 1)], i); + uint8_t v1 = GetInt4(src[j + ((8 + i) >> 1)], i + 8); + dst[j + i] = v0 | (v1 << 4); + } + } + + template + MLAS_FORCEINLINE void Prepack(const uint8_t* src, uint8_t* dst) { + size_t n = 0; + for (; n + 8 <= N; n += 8) { + for (size_t k = 0; k < Ldb; k += 8) { + Transpose8x8(src, n, k, dst); + } + } + + for (; n < N; ++n) { + for (size_t k = 0; k < Ldb; k += 8) { + PrepackSlice(src, n * Ldb + k, dst); + } + } + } + + template + MLAS_FORCEINLINE void Check(const uint8_t* packed, const uint8_t* ref) { + size_t n = 0; + for (; n + 8 <= N; n += 8) { + for (size_t i = 0; i < K; i += 2) { + for (size_t j = 0; j < 8; ++j) { + ASSERT_EQ(packed[n * Ldb + (i >> 1) * 8 + j], ref[n * Ldb + (i >> 1) * 8 + j]) + << " seed " << seed_ + << " n " << n << " i " << i << " j " << j; + } + } + } + + for (; n < N; ++n) { + for (size_t i = 0; i < K; i += 2) { + ASSERT_EQ(packed[n * Ldb + (i >> 1)], ref[n * Ldb + (i >> 1)]) + << " seed " << seed_ + << " n " << n << " i " << i; + } + } + } + + template + void TestPrepack() { + constexpr size_t Bits = 4; + constexpr size_t Ldb = (((K + BlkLen - 1) & (~(BlkLen - 1))) * Bits + 7) / 8; + constexpr size_t BufferSize = N * Ldb; + auto InitializeBuffer = [this](uint8_t* buffer, size_t count) { + for (size_t i = 0; i < count; i++) { + buffer[i] = static_cast(distrib_(gen_)); + } + }; + + const auto* input = input_.GetFilledBuffer(BufferSize, InitializeBuffer); + auto* packed = packed_.GetBuffer(BufferSize, true); + auto* ref = ref_.GetBuffer(BufferSize, true); + MlasQNBitGemmPackQuantBData( + N, K, Bits, BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE::HQNBIT_CompFp16, input, packed, + nullptr, false, nullptr, nullptr); + Prepack(input, ref); + Check(packed, ref); + } + + public: + MlasNeonFp16PrepackTest() + : seed_(rd_()), gen_(seed_), distrib_(0, 255) { + } + + static const char* GetTestSuiteName() { + return "NeonFp16Prepack"; + } + + void ExecuteShort(void) override { + TestPrepack<1, 1, 16>(); + TestPrepack<1, 15, 16>(); + TestPrepack<1, 31, 16>(); + TestPrepack<8, 1, 16>(); + TestPrepack<8, 16, 16>(); + TestPrepack<9, 31, 16>(); + TestPrepack<9, 33, 32>(); + TestPrepack<15, 33, 16>(); + TestPrepack<17, 67, 16>(); + TestPrepack<17, 96, 128>(); + TestPrepack<263, 263, 16>(); + } +}; + +class MlasNeonFp16DequantBTest : public MlasTestBase { + private: + std::random_device rd_; // a seed source for the random number engine + unsigned int seed_; + std::mt19937 gen_; // mersenne_twister_engine seeded with rd() + std::uniform_int_distribution<> distrib_; + std::uniform_real_distribution _distribFp; + MatrixGuardBuffer input_, zero_points_; + MatrixGuardBuffer dequant_, ref_, scales_; + + MLAS_FORCEINLINE + uint8_t GetInt4(uint8_t v, size_t i) { + return (i & 1) ? (v >> 4) : (v & 0x0f); + } + + template + void DequantB(const uint8_t* src, MLAS_FP16* dst, const MLAS_FP16* scales, const uint8_t* zero_points) { + constexpr size_t blkNum = (K + BlkLen - 1) / BlkLen; + constexpr size_t ld_src = (blkNum * BlkLen + 1) / 2; + constexpr size_t ld_dst = blkNum * BlkLen; + constexpr size_t ld_zp = (blkNum + 1) / 2; + size_t n = 0; + for (; n + 8 <= N; n += 8) { + size_t i_src = n * ld_src, i_dst = n * ld_dst, i_scale = n * blkNum, i_zp = n * ld_zp; + for (size_t blk = 0; blk < blkNum; i_zp += (blk & 1), ++blk, ++i_scale) { + for (size_t i = 0; i < BlkLen; i += 2, i_dst += 8) { + for (size_t j = 0; j < 8; ++j, ++i_src, ++i_dst) { + uint8_t v = src[i_src]; + float v0 = static_cast(GetInt4(v, 0)); + float v1 = static_cast(GetInt4(v, 1)); + float zp = static_cast(UseZeroPoints ? GetInt4(zero_points[i_zp + ld_zp * j], blk) : 8); + float scale = scales[i_scale + blkNum * j]; + dst[i_dst] = MLAS_FP16(v0 * scale - zp * scale); + dst[i_dst + 8] = MLAS_FP16(v1 * scale - zp * scale); + } + } + } + } + + for (; n < N; ++n) { + size_t i_src = n * ld_src, i_dst = n * ld_dst, i_scale = n * blkNum, i_zp = n * ld_zp; + for (size_t blk = 0; blk < blkNum; i_zp += (blk & 1), ++blk, ++i_scale) { + float zp = static_cast(UseZeroPoints ? GetInt4(zero_points[i_zp], blk) : 8); + float scale = scales[i_scale]; + for (size_t i = 0; i < BlkLen; i += 16, i_dst += 8) { + for (size_t j = 0; j < 16; j += 2, ++i_src, ++i_dst) { + uint8_t v = src[i_src]; + float v0 = static_cast(GetInt4(v, 0)); + float v1 = static_cast(GetInt4(v, 1)); + dst[i_dst] = MLAS_FP16(v0 * scale - zp * scale); + dst[i_dst + 8] = MLAS_FP16(v1 * scale - zp * scale); + } + } + } + } + } + + MLAS_FORCEINLINE + bool FloatEqual(MLAS_FP16 v0, MLAS_FP16 v1, float rtol, float atol) { + float f0 = std::abs(v0.ToFloat()), f1 = std::abs(v1.ToFloat()); + return std::abs(f0 - f1) <= f1 * rtol + atol; + } + + template + MLAS_FORCEINLINE void Check(const MLAS_FP16* target, const MLAS_FP16* ref) { + size_t n = 0; + for (; n + 8 <= N; n += 8) { + for (size_t i = 0; i < K; ++i) { + for (size_t j = 0; j < 8; ++j) { + size_t idx = n * Ldb + i * 8 + j; + ASSERT_TRUE(FloatEqual(target[idx], ref[idx], 0.01f, 0.01f)) + << " seed " << seed_ + << " v0 " << target[idx] << " v1 " << ref[idx] + << " n " << n << " i " << i << " j " << j; + } + } + } + + for (; n < N; ++n) { + for (size_t i = 0; i < K; ++i) { + size_t idx = n * Ldb + i; + ASSERT_TRUE(FloatEqual(target[idx], ref[idx], 0.01f, 0.01f)) + << " seed " << seed_ + << " v0 " << target[idx] << " v1 " << ref[idx] + << " n " << n << " i " << i; + } + } + } + + template + void TestDequant() { + constexpr size_t BlkNum = (K + BlkLen - 1) / BlkLen; + constexpr size_t BCount = BlkNum * BlkLen * N; + constexpr size_t ScaleCount = N * BlkNum; + constexpr size_t ZpSize = N * ((BlkNum + 1) / 2); + + auto InitializeBuffer_i8 = [this](uint8_t* buffer, size_t count) { + for (size_t i = 0; i < count; i++) { + buffer[i] = static_cast(distrib_(gen_)); + } + }; + + auto InitializeBuffer_fp16 = [this](MLAS_FP16* buffer, size_t count) { + for (size_t i = 0; i < count; i++) { + buffer[i] = MLAS_FP16(_distribFp(gen_)); + } + }; + + const auto* input = input_.GetFilledBuffer(BCount / 2, InitializeBuffer_i8); + const auto* zero_points = zero_points_.GetFilledBuffer(ZpSize, InitializeBuffer_i8); + auto* dequant = dequant_.GetBuffer(BCount); + auto* ref = ref_.GetBuffer(BCount); + const auto* scales = scales_.GetFilledBuffer(ScaleCount, InitializeBuffer_fp16); + GetMlasPlatform().QNBitGemmDispatch->HQ4BitBlkDequantBForHgemm_CompFp16( + BlkLen, dequant, reinterpret_cast(input), scales, + UseZeroPoints ? reinterpret_cast(zero_points) : nullptr, + N, K, BlkNum); + DequantB(input, ref, scales, zero_points); + Check(dequant, ref); + } + + public: + MlasNeonFp16DequantBTest() + : seed_(rd_()), gen_(seed_), distrib_(0, 255), _distribFp(0.5f, 2.0f) { + } + + static const char* GetTestSuiteName() { + return "NeonFp16DequantB"; + } + + void ExecuteShort(void) override { + TestDequant<1, 1, 16, false>(); + TestDequant<1, 1, 16, true>(); + TestDequant<1, 15, 16, false>(); + TestDequant<1, 15, 16, true>(); + TestDequant<1, 31, 16, false>(); + TestDequant<1, 31, 16, true>(); + TestDequant<8, 1, 16, false>(); + TestDequant<8, 1, 16, true>(); + TestDequant<8, 16, 16, false>(); + TestDequant<8, 16, 16, true>(); + TestDequant<9, 31, 16, false>(); + TestDequant<9, 31, 16, true>(); + TestDequant<9, 33, 32, false>(); + TestDequant<9, 33, 32, true>(); + TestDequant<15, 33, 16, false>(); + TestDequant<15, 33, 16, true>(); + TestDequant<17, 67, 16, false>(); + TestDequant<17, 67, 16, true>(); + TestDequant<17, 96, 128, false>(); + TestDequant<17, 96, 128, true>(); + TestDequant<263, 263, 16, false>(); + TestDequant<263, 263, 16, true>(); + } +}; + +class MlasNeonFp16HQ4BitGemmKernelTest : public MlasTestBase { + private: + std::random_device rd_; // a seed source for the random number engine + unsigned int seed_; + std::mt19937 gen_; // mersenne_twister_engine seeded with rd() + MatrixGuardBuffer A_, B_, C_, ref_, bias_; + + MLAS_FORCEINLINE + void InitializeBuffer(MLAS_FP16* buffer, float min, float max, size_t count) { + std::uniform_real_distribution distrib(min, max); + for (size_t i = 0; i < count; i++) { + buffer[i] = MLAS_FP16(distrib(gen_)); + } + } + + MLAS_FORCEINLINE + bool FloatEqual(MLAS_FP16 v0, MLAS_FP16 v1, float rtol, float atol) { + float f0 = v0.ToFloat(), f1 = v1.ToFloat(); + return std::abs(f0 - f1) <= std::abs(f1 * rtol) + atol; + } + + template + float GetBVal(const MLAS_FP16* B, size_t n, size_t k) { + size_t i; + if ((N & (~7)) > n) { + size_t full8 = n & (~7); + i = full8 * ldb + 8 * k + (n - full8); + } else { + i = n * ldb + k; + } + return B[i].ToFloat(); + } + + template + void MatMul(const MLAS_FP16* A, const MLAS_FP16* B, const MLAS_FP16* bias, MLAS_FP16* C) { + for (size_t m = 0; m < M; ++m) { + for (size_t n = 0; n < N; ++n) { + float accu = UseBias ? bias[n] : 0.0f; + for (size_t k = 0; k < K; ++k) { + float a = A[m * K + k].ToFloat(); + float b = GetBVal(B, n, k); + accu = accu + a * b; + } + C[m * N + n] = MLAS_FP16(accu); + } + } + } + + template + MLAS_FORCEINLINE void Check(const MLAS_FP16* target, const MLAS_FP16* ref) { + for (size_t m = 0; m < M; ++m) { + for (size_t n = 0; n < N; ++n) { + size_t i = m * Ldc + n; + ASSERT_TRUE(FloatEqual(target[i], ref[i], 0.015f, 0.03f)) + << " seed " << seed_ + << " v0 " << target[i] << " v1 " << ref[i] + << " m " << m << " n " << n; + } + } + } + + template + void TestHQ4BitGemmKernel() { + static_assert(M <= 2); + constexpr size_t BlkNum = (K + BlkLen - 1) / BlkLen; + constexpr size_t ldb = BlkNum * BlkLen; + + const auto* A = A_.GetFilledBuffer(M * K, [this](MLAS_FP16* p, size_t t) { + InitializeBuffer(p, -0.25f, 0.25f, t); + }); + const auto* B = B_.GetFilledBuffer(ldb * N, [this](MLAS_FP16* p, size_t t) { + InitializeBuffer(p, -0.25f, 0.25f, t); + }); + auto* C = C_.GetBuffer(M * N, true); + auto* ref = ref_.GetBuffer(M * N, true); + auto* bias = bias_.GetFilledBuffer(N, [this](MLAS_FP16* p, size_t t) { + InitializeBuffer(p, -5.0f, 5.0f, t); + }); + + GetMlasPlatform().QNBitGemmDispatch->HQ4BitGemmKernel_CompFp16( + A, B, UseBias ? bias : nullptr, C, M, N, K, K, ldb, N); + + MatMul(A, B, bias, ref); + Check(C, ref); + } + + public: + MlasNeonFp16HQ4BitGemmKernelTest() + : seed_(rd_()), gen_(seed_) { + } + + static const char* GetTestSuiteName() { + return "NeonFp16HQ4BitGemmKernel"; + } + + template + void ExecuteShort_T(void) { + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + } + + void ExecuteShort(void) override { + ExecuteShort_T<1>(); + ExecuteShort_T<2>(); + } +}; + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { + size_t count = 0; + if (is_short_execute) { + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + if (GetMlasPlatform().QNBitGemmDispatch) { + if (GetMlasPlatform().QNBitGemmDispatch->HQ4BitGemmPackQuantBData) { + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + } + if (GetMlasPlatform().QNBitGemmDispatch->HQ4BitBlkDequantBForHgemm_CompFp16) { + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + } + if (GetMlasPlatform().QNBitGemmDispatch->HQ4BitGemmKernel_CompFp16) { + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + } + } + } + return count; +}); + +#endif // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) diff --git a/onnxruntime/test/mlas/unittest/test_util.h b/onnxruntime/test/mlas/unittest/test_util.h index 8eecda900ff27..a000e353f370d 100644 --- a/onnxruntime/test/mlas/unittest/test_util.h +++ b/onnxruntime/test/mlas/unittest/test_util.h @@ -115,24 +115,24 @@ class MatrixGuardBuffer { return GetFilledBuffer( Elements, [](T* start, size_t size) { - std::fill_n(start, size, T(0)); + std::fill_n(start, size, T(0.0f)); }); } return GetFilledBuffer( Elements, [](T* start, size_t size) { - constexpr int offset = -21; - constexpr int range = 43; + constexpr float offset = -21.f; + constexpr float range = 43.f; - int FillValue = 11; + float FillValue = 11.f; T* FillAddress = start; for (size_t i = 0; i < size; i++) { auto itemv = FillValue - offset; *FillAddress++ = (T)(itemv); - FillValue += 7; - FillValue %= range; + FillValue += 7.f; + FillValue = FillValue >= range ? FillValue - range : FillValue; } }); }