From b87e8edb98ed85640ccf369625267786e91db920 Mon Sep 17 00:00:00 2001
From: liqun Fu
Date: Fri, 2 Aug 2024 10:20:22 -0700
Subject: [PATCH] Mlas int4 int8 with avx2/512 (#20687)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

### Description
model: phi-3-mini-4k-instruct

avx2 symmetric
blklen|updated prompt tps | baseline prompt tps | prompt tps change%|updated token gen tps | baseline token gen tps | token gen change%
-|-|-|-|-|-|-
16 |49.5|70.0|-29.2%|9.6|10.8|-34.2%
32 |76.8|52.4|9.7%|15.2|14.6|4.1%
64 |78.2|71.4|9.5%|16.6|16.3|1.8%
128 |72.9|70.6|3.2%|17.1|16.8|1.7%
256 |83.7|63.6|31.6%|18.1|17.4|4%

avx2 asymmetric
blklen|updated prompt tps | baseline prompt tps | prompt tps change%|updated token gen tps | baseline token gen tps | token gen change%
-|-|-|-|-|-|-
16 |50.7|61.5|-17.5%|9.6|9.2|4.3%
32 |77.4|52.4|47.7%|14.6|13.9|5.0%
64 |78.7|63.0|24.9%|16.2|15.9|1.8%
128 |80.0|61.9|29.2%|17.2|16.9|1.7%
256 |81.5|63.3|28.7%|17.9|17.3|3.4%

avx2vnni symmetric
blklen|updated prompt tps | baseline prompt tps | prompt tps change%|updated token gen tps | baseline token gen tps | token gen change%
-|-|-|-|-|-|-
16 |82.9|117.0|-29.0%|15.9|19.3|-17.6%
32 |133.0|100.4|32.4%|26.1|24.5|6.5%
64 |166.9|118.8|40.4%|28.3|27.1|4.4%
128 |165.9|119.6|38.7%|29.3|28.5|2.8%
256 |165.2|119.6|38.1%|30.2|29.0|4.1%

avx2vnni asymmetric
blklen|updated prompt tps | baseline prompt tps | prompt tps change%|updated token gen tps | baseline token gen tps | token gen change%
-|-|-|-|-|-|-
16 |80.2|118.9|-32.5%|15.1|16.7|-9.5%
32 |130.7|99.7|31.0%|25.0|23.8|5.0%
64 |168.7|124.9|35.0%|27.3|26.8|1.8%
128 |169.6|123.8|36.9%|29.2|27.9|4.6%
256 |175.0|125.7|39.0%|30.0|29.7|1.0%

avx512 symmetric
blklen|updated prompt tps | baseline prompt tps | prompt tps change%|updated token gen tps | baseline token gen tps | token gen change%
-|-|-|-|-|-|-
16 |135.2|156.5|-13.6%|25.5|23.8|7.1%
32 |150.0|159.5|-5.9%|34.9|29.6|17.9%
64 |167.5|157.5|6.3%|39.7|34.4|15.4%
128 |177.8|158.0|12.5%|40.3|35.4|13.8%
256 |182.6|157.3|16.0%|41.7|37.7|10.6%

avx512 asymmetric
blklen|updated prompt tps | baseline prompt tps | prompt tps change%|updated token gen tps | baseline token gen tps | token gen change%
-|-|-|-|-|-|-
16 |136.1|151.4|-10.1%|26.1|19.9|31.1%
32 |150.0|157.8|-4.9%|34.3|29.3|17.0%
64 |165.7|156.6|5.8%|38.7|30.7|26.0%
128 |180.4|156.6|15.1%|40.2|34.7|15.8%
256 |181.3|158.0|14.7%|41.6|36.6|13.6%

avx512vnni symmetric
blklen|updated prompt tps | baseline prompt tps | prompt tps change%|updated token gen tps | baseline token gen tps | token gen change%
-|-|-|-|-|-|-
16 |143.4|155.4|-7.7%|25.6|23.3|9.8%
32 |159.2|157.0|1.4%|34.1|29.8|14.4%
64 |182.0|159.5|14.1%|38.4|34.8|10.3%
128 |221.2|160.8|37.5%|41.0|36.4|12.6%
256 |250.5|162.4|54.2%|41.6|37.7|10.3%

avx512vnni asymmetric
blklen|updated prompt tps | baseline prompt tps | prompt tps change%|updated token gen tps | baseline token gen tps | token gen change%
-|-|-|-|-|-|-
16 |142.5|152.3|-6.4%|26.3|19.7|33.5%
32 |158.2|155.0|2.0%|34.3|29.2|17.4%
64 |184.1|156.6|17.5%|38.3|30.9|23.9%
128 |215.8|156.1|17.5%|41.3|35.0|17.9%
256 |249.2|155.9|59.8%|41.1|36.3|13.2%

4-bit gemm implementation with avx using tiles.

1. The tile size is 2 blk by 4. When the remaining size is smaller than a tile, the kernel falls back to 1 blk by 4, then 2 blk by 1, and lastly 1 blk by 1. In the internal kernel, weights and activations are loaded according to the SIMD register width and the blk length.
   avx2 (256-bit registers): 64 weights and activations are loaded.
   - blklen16: 4 blks are computed by the internal kernel
   - blklen32: 2 blks are computed by the internal kernel
   - blklen64: 1 blk is computed by the internal kernel
   - blklen128: 1 blk is computed 2 times by the internal kernel
   - blklen256: 1 blk is computed 4 times by the internal kernel

   avx512 (512-bit registers): 128 weights and activations are loaded.
   - blklen16: 8 blks are computed by the internal kernel
   - blklen32: 4 blks are computed by the internal kernel
   - blklen64: 2 blks are computed by the internal kernel
   - blklen128: 1 blk is computed by the internal kernel
   - blklen256: 1 blk is computed 2 times by the internal kernel

2. blksum is precomputed during prepacking, and the computation is reformed as:
   Sum1(scale_a * scale_b * Sum_blk(a_i * b_i)) + Sum2(blksum_a * blksum_b)
   where Sum_blk is over one blk, and Sum1 and Sum2 are each over all blks for one output.
   Sum2 is computed with sgemm in the current implementation. Further improvement is possible.
   A scalar sketch of this reformulation is shown below.
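For illustration only, here is a minimal scalar sketch of the reformed computation for a single output element. It is not the MLAS kernel: it assumes a symmetrically quantized int8 A (no zero point), an unpacked unsigned 4-bit B column with a per-block zero point `zp_b`, and the conventions `blksum_a = scale_a * Sum_blk(a_i)` and `blksum_b = -scale_b * zp_b`; the sign and layout of the packed blksum values in the real code may differ, and Sum2 is accumulated here in a plain loop instead of sgemm.

```cpp
// Scalar sketch of: Sum1(scale_a * scale_b * Sum_blk(a_i * b_i)) + Sum2(blksum_a * blksum_b)
// for one output element (one row of A dotted with one column of B).
#include <cstdint>
#include <vector>

float DotOneOutput(
    const std::vector<int8_t>& A,       // quantized A row, length BlockCountK * BlkLen
    const std::vector<uint8_t>& B,      // unpacked 4-bit B column (values 0..15), same length
    const std::vector<float>& scale_a,  // per-block scales of A, length BlockCountK
    const std::vector<float>& scale_b,  // per-block scales of B, length BlockCountK
    const std::vector<uint8_t>& zp_b,   // per-block zero points of B, length BlockCountK
    size_t BlkLen)
{
    const size_t BlockCountK = scale_a.size();
    float sum1 = 0.0f;  // Sum1: per-block int8 dot products, scaled
    float sum2 = 0.0f;  // Sum2: blksum correction (handled by sgemm in the real kernel)
    for (size_t blk = 0; blk < BlockCountK; ++blk) {
        int32_t dot = 0;    // Sum_blk(a_i * b_i), B zero point not yet removed
        int32_t a_sum = 0;  // Sum_blk(a_i), precomputed while quantizing A
        for (size_t i = 0; i < BlkLen; ++i) {
            const int32_t a = A[blk * BlkLen + i];
            const int32_t b = B[blk * BlkLen + i];
            dot += a * b;
            a_sum += a;
        }
        sum1 += scale_a[blk] * scale_b[blk] * static_cast<float>(dot);
        const float blksum_a = scale_a[blk] * static_cast<float>(a_sum);       // per A block
        const float blksum_b = -scale_b[blk] * static_cast<float>(zp_b[blk]);  // per B block (assumed sign)
        sum2 += blksum_a * blksum_b;
    }
    // sum1 + sum2 equals Sum_over_blks(scale_a * scale_b * Sum_blk(a_i * (b_i - zp_b))),
    // i.e. the dequantized dot product: the zero-point term factors out per block.
    return sum1 + sum2;
}
```

Because A is quantized without a zero point, only B's zero point needs this correction, and since blksum_a and blksum_b depend only on per-block metadata, the Sum2 contribution over all outputs is a small dense float matmul of the A block sums against the B block sums, which is why the kernels hand it to GemmFloatKernel.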
---------
Signed-off-by: Liqun Fu
Signed-off-by: liqunfu
Signed-off-by: Liqun Fu
---
 cmake/onnxruntime_mlas.cmake | 13 +-
 .../cpu/quantization/matmul_nbits.cc | 41 +-
 onnxruntime/core/mlas/inc/mlas_qnbit.h | 56 +-
 onnxruntime/core/mlas/lib/mlasi.h | 2 +
 onnxruntime/core/mlas/lib/platform.cpp | 1 +
 onnxruntime/core/mlas/lib/sqnbitgemm.cpp | 244 +++-
 onnxruntime/core/mlas/lib/sqnbitgemm.h | 102 ++
 .../core/mlas/lib/sqnbitgemm_kernel_avx2.cpp | 292 +++-
 .../sqnbitgemm_kernel_avx2_int8_blklen16.h | 727 ++++++++++
 .../sqnbitgemm_kernel_avx2_int8_blklen32.h | 1049 +++++++++++++++
 .../sqnbitgemm_kernel_avx2_int8_blklen64.h | 541 ++++++++
 .../mlas/lib/sqnbitgemm_kernel_avx512.cpp | 150 ++-
 .../mlas/lib/sqnbitgemm_kernel_avx512_int8.h | 1171 +++++++++++++++++
 .../sqnbitgemm_kernel_avx512_int8_blklen128.h | 581 ++++++++
 .../sqnbitgemm_kernel_avx512_int8_blklen16.h | 812 ++++++++++++
 .../sqnbitgemm_kernel_avx512_int8_blklen32.h | 852 ++++++++++++
 .../sqnbitgemm_kernel_avx512_int8_blklen64.h | 840 ++++++++++++
 .../mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp | 180 ++-
 .../mlas/lib/sqnbitgemm_kernel_avx_common.h | 273 +++-
 .../lib/sqnbitgemm_kernel_avx_common_int8.h | 51 +-
 ...bitgemm_m1_sym_kernel_avx2_int8_blklen32.h | 759 +++++++++++
 ...bitgemm_m1_sym_kernel_avx2_int8_blklen64.h | 312 +++++
 .../test/contrib_ops/matmul_4bits_test.cc | 5 +-
 onnxruntime/test/mlas/bench/bench_q4dq.cpp | 24 +-
 .../test/mlas/bench/bench_sqnbitgemm.cpp | 9 +-
 .../test/mlas/unittest/test_sqnbitgemm.cpp | 47 +-
 26 files changed, 8834 insertions(+), 300 deletions(-)
 create mode 100644 onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h
 create mode 100644 onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h
 create mode 100644 onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h
 create mode 100644 onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8.h
 create mode 100644 onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h
 create mode 100644 onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h
 create mode 100644 onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h
 create mode 100644 onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h
 create mode 100644 onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h
 create mode 100644 onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h

diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake
index 66f4aea606ef5..c02ac2096db2e 100644
--- a/cmake/onnxruntime_mlas.cmake
+++ b/cmake/onnxruntime_mlas.cmake
@@
-555,8 +555,17 @@ else() ${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp ) - set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma") +message(STATUS "CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}") +message(STATUS "CMAKE_CXX_COMPILER_VERSION: ${CMAKE_CXX_COMPILER_VERSION}") + +if(NOT "${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" OR CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "10") + message(STATUS "Using -mavx2 -mfma -mavxvnni flags") + set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mavxvnni") +else() + message(STATUS "Using -mavx2 -mfma flags") + set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma") +endif() set(mlas_platform_srcs_avx512f ${MLAS_SRC_DIR}/x86_64/DgemmKernelAvx512F.S ${MLAS_SRC_DIR}/x86_64/SgemmKernelAvx512F.S @@ -575,7 +584,7 @@ else() ${MLAS_SRC_DIR}/x86_64/ConvSymKernelAvx512Core.S ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512.cpp ) - set_source_files_properties(${mlas_platform_srcs_avx512core} PROPERTIES COMPILE_FLAGS "-mavx512bw -mavx512dq -mavx512vl") + set_source_files_properties(${mlas_platform_srcs_avx512core} PROPERTIES COMPILE_FLAGS "-mfma -mavx512vnni -mavx512bw -mavx512dq -mavx512vl") set(mlas_platform_srcs_avx512vnni ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512vnni.cpp diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 995babc857357..5fdd2b017b8a6 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -104,6 +104,8 @@ class MatMulNBits final : public OpKernel { ORT_ENFORCE(nbits_ == 4, "Only 4b quantization is supported for MatMulNBits op, additional bits support is planned."); + const Tensor* tensor_zero_point = nullptr; + has_zp_input_ = info.TryGetConstantInput(3, &tensor_zero_point); #ifdef ORT_NEURAL_SPEED const Tensor* tensor_B = nullptr; const Tensor* tensor_scale = nullptr; @@ -139,6 +141,7 @@ class MatMulNBits final : public OpKernel { IAllocatorUniquePtr packed_b_{}; size_t packed_b_size_{0}; + bool has_zp_input_{false}; #if defined(ORT_NEURAL_SPEED) bool is_asym_{false}; @@ -207,10 +210,10 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat is_packed = true; } -#else // defined(ORT_NEURAL_SPEED) - +#else // defined(ORT_NEURAL_SPEED) + ORT_UNUSED_PARAMETER(prepacked_weights); + const auto compute_type = static_cast(accuracy_level_); if (input_idx == InputIndex::B) { - const auto compute_type = static_cast(accuracy_level_); if (!MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type)) { return Status::OK(); } @@ -220,12 +223,20 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat } auto qptr = tensor.DataRaw(); packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size_, true); - MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, qptr, packed_b_.get()); - if (prepacked_weights) { - prepacked_weights->buffers_.push_back(std::move(packed_b_)); - prepacked_weights->buffer_sizes_.push_back(packed_b_size_); - } + MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, qptr, packed_b_.get(), nullptr, has_zp_input_, nullptr, nullptr); is_packed = true; + } else if (compute_type == CompInt8) { +#ifdef MLAS_TARGET_AMD64_IX86 + if (input_idx == InputIndex::scales && packed_b_ != nullptr) { + auto sptr = tensor.Data(); + 
MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, nullptr, packed_b_.get(), sptr, has_zp_input_, nullptr, nullptr); + is_packed = false; + } else if (input_idx == InputIndex::zero_points && packed_b_ != nullptr) { + auto zptr = tensor.Data(); + MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, nullptr, packed_b_.get(), nullptr, has_zp_input_, zptr, nullptr); + is_packed = false; + } +#endif } #endif // defined(ORT_NEURAL_SPEED) @@ -332,9 +343,9 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { const auto* bias_data = bias == nullptr ? nullptr : bias->Data(); IAllocatorUniquePtr workspace{}; - if (const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, batch_count, - nbits_, block_size_, compute_type); - workspace_size > 0) { + const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize( + M, N, K, batch_count, nbits_, block_size_, compute_type); + if (workspace_size > 0) { AllocatorPtr allocator; ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator)); workspace = IAllocator::MakeUniquePtr(allocator, workspace_size); @@ -344,14 +355,18 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { for (size_t i = 0; i < batch_count; ++i) { data[i].A = a_data + helper.LeftOffsets()[i]; data[i].lda = lda; - data[i].QuantBData = packed_b_.get(); +#ifdef MLAS_TARGET_AMD64_IX86 + if (compute_type == CompInt8) { + data[i].QuantBDataWorkspace = packed_b_.get(); + } +#endif + data[i].PackedQuantBData = static_cast(packed_b_.get()); data[i].QuantBScale = scales_data; data[i].QuantBZeroPoint = zero_points_data; data[i].Bias = bias_data; data[i].C = y_data + helper.OutputOffsets()[i]; data[i].ldc = N; } - MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type, data.data(), workspace.get(), thread_pool); diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h index 32e9cc98106d5..232bf2261ef4c 100644 --- a/onnxruntime/core/mlas/inc/mlas_qnbit.h +++ b/onnxruntime/core/mlas/inc/mlas_qnbit.h @@ -43,14 +43,16 @@ typedef enum { * @brief Data parameters for float/n-bit quantized int GEMM routine. 
*/ struct MLAS_SQNBIT_GEMM_DATA_PARAMS { - const float* A = nullptr; ///< address of A (float32 matrix) - size_t lda = 0; ///< leading dimension of A - const void* QuantBData = nullptr; ///< address of quantized B (quantized n-bit int values) - const float* QuantBScale = nullptr; ///< address of scale values of quantized B, one per block - const void* QuantBZeroPoint = nullptr; ///< optional address of zero point values of quantized B, one per block - const float* Bias = nullptr; ///< optional address of Bias, vector size N - float* C = nullptr; ///< address of result matrix - size_t ldc = 0; ///< leading dimension of C + const float* A = nullptr; ///< address of A (float32 matrix) + size_t lda = 0; ///< leading dimension of A + const void* QuantBDataWorkspace; ///< address of quantized B (quantized n-bit int values) + const std::byte* PackedQuantBData = nullptr; /// address of packed quantized B data + const float* QuantBScale = nullptr; ///< address of scale values of quantized B, one per block + const void* QuantBZeroPoint = nullptr; ///< optional address of zero point values of quantized B, one per block + const float* QuantBBlkSum = nullptr; ///< optional address of scale * zp, one per block + const float* Bias = nullptr; ///< optional address of Bias, vector size N + float* C = nullptr; ///< address of result matrix + size_t ldc = 0; ///< leading dimension of C ///< optional post processing to apply to result matrix MLAS_GEMM_POSTPROCESSOR* PostProcessor = nullptr; @@ -159,14 +161,29 @@ MlasSQNBitGemmPackQuantBDataSize( /** * @brief Packs the quantized B data in a format that the kernel expects. * - * @param[in] N column size of matrix B and C - * @param[in] K column size of matrix A and row size of matrix B - * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) - * @param[in] BlkLen number of quantized values per block - * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) - * @param[in] QuantBData quantized B data - * @param[out] PackedQuantBData packed quantized B data - * @param[in] ThreadPool optional thread pool to use + * If the function is called without QuantBScale and QuantBZeroPoint, + * it just packs QuantBData into PackedQuantBDataAndOrBlkSum. + * + * If the function is called with QuantBData, QuantBScale, and QuantBZeroPoint + * additional BlkSum (Scale * zeropoint) is computed and stored at the second part of PackedQuantBDataAndOrBlkSum. + * + * Because ORT OpKernel::PrePack is called for each input (in this case, QuantBData, + * QuantBScale, and QuantBZeroPoint) separately, this function may be called 3 times, first with QuantBData, + * and then QuantBScale and QuantBZeroPoint. When the function is called with QuantBScale without QuantBZeroPoint, + * BlkSum is computed with default zero point 8 and stored at the second part of PackedQuantBDataAndOrBlkSum. + * If there is a third call with QuantBZeroPoint, BlkSum is recomputed/adjusted with provided zeropoint. 
+ * + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) + * @param[in] BlkLen number of quantized values per block + * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) + * @param[in] QuantBData quantized B data + * @param[in] PackedQuantBDataAndOrBlkSum buffer to store packed quantized B data and/or BlkSum + * @param[in] QuantBScale quantized B scale + * @param[in] has_zp_input whether QuantBZeroPoint is provided + * @param[in] QuantBZeroPoint quantized B zero point + * @param[in] ThreadPool thread pool to use (no parallel if nullptr) */ void MLASCALL MlasSQNBitGemmPackQuantBData( @@ -176,6 +193,9 @@ MlasSQNBitGemmPackQuantBData( size_t BlkLen, MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, const void* QuantBData, - void* PackedQuantBData, - MLAS_THREADPOOL* ThreadPool = nullptr + void* PackedQuantBDataAndOrBlkSum, + const void* QuantBScale, + bool has_zp_input, + const void* QuantBZeroPoint, + MLAS_THREADPOOL* ThreadPool ); diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 83200187963e1..4239e2ecaeb6e 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -993,6 +993,8 @@ extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon; extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2; +extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni; + extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512; extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni; diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 859b7c2f560a4..ed437f20f7c2a 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -409,6 +409,7 @@ Return Value: this->GemmU8S8Kernel = MlasGemmU8S8KernelAvxVnni; this->GemvU8S8Kernel = MlasGemvU8S8KernelAvxVnni; this->ConvSymU8S8Dispatch = &MlasConvSymDispatchAvxVnni; + this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx2vnni; } #if !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp index 81789386a3200..a45494ef2e04f 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp @@ -16,11 +16,10 @@ Module Name: --*/ #include "sqnbitgemm.h" +#include "sqnbitgemm_q8_block.h" #include -#include "sqnbitgemm_q8_block.h" - namespace { @@ -80,9 +79,10 @@ MlasIsSQNBitGemmAvailable( return Dispatch->SQ4BitGemmM1Kernel_CompFp32 != nullptr && Dispatch->Q4BitBlkDequantBForSgemm_CompFp32 != nullptr; } - case SQNBitGemmVariant_BitWidth4_CompInt8: { - return Dispatch->SQ4BitGemmKernel_CompInt8 != nullptr && - Dispatch->QuantizeARow_CompInt8 != nullptr; + case SQNBitGemmVariant_BitWidth4_CompInt8: { // SQ4BitGemmKernel_BlkSum_CompInt8 + return + (Dispatch->SQ4BitGemmKernel_CompInt8 != nullptr && Dispatch->QuantizeARow_CompInt8 != nullptr) || + (Dispatch->SQ4BitGemmKernel_BlkSum_CompInt8 != nullptr && Dispatch->QuantizeARowComputeBlkSum_CompInt8 != nullptr); } default: { return false; @@ -197,6 +197,21 @@ MlasSQNBitGemmPackQuantBDataSize( return 0; } +struct PerGemmQuantAWorkspace { + PerGemmQuantAWorkspace(void* PerGemmWorkspace, size_t M, size_t BlockCountK, size_t BlkLen) + : PerGemmWorkspace_(PerGemmWorkspace), M_(M), BlockCountK_(BlockCountK), BlkLen_(BlkLen) + { + QuantData = 
(std::byte*)PerGemmWorkspace; + QuantScale = (float*)(QuantData + M * BlockCountK * BlkLen); + BlockSum = QuantScale + M * BlockCountK; + } + std::byte* QuantData; // NxBlockCountKxBlkLen + float* QuantScale; // NxBlockCountK + float* BlockSum; // NxBlockCountK + void* PerGemmWorkspace_; // memory for above data + size_t M_, BlockCountK_, BlkLen_; +}; + void MLASCALL MlasSQNBitGemmPackQuantBData( size_t N, @@ -205,7 +220,10 @@ MlasSQNBitGemmPackQuantBData( size_t BlkLen, MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, const void* QuantBData, - void* PackedQuantBData, + void* PackedQuantBDataAndOrBlkSumWorkspace, + const void* QuantBScale, + bool has_zp_input, + const void* QuantBZeroPoint, MLAS_THREADPOOL* ThreadPool ) { @@ -214,17 +232,37 @@ MlasSQNBitGemmPackQuantBData( return; } - if (BlkBitWidth == 4 && Dispatch->SQ4BitGemmPackQuantBData != nullptr) { - Dispatch->SQ4BitGemmPackQuantBData( - N, - K, - BlkLen, - ComputeType, - static_cast(QuantBData), - static_cast(PackedQuantBData), - ThreadPool - ); - return; + if (BlkBitWidth == 4) { + if (ComputeType == CompInt8 && Dispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + PackedQuantBDataStruct packed_quant_b(PackedQuantBDataAndOrBlkSumWorkspace, N, BlockCountK, BlkLen); + Dispatch->SQ4BitGemmPackQuantBDataAndBlkSum( + N, + K, + BlkLen, + ComputeType, + static_cast(QuantBData), + static_cast(QuantBScale), + has_zp_input, + static_cast(QuantBZeroPoint), + packed_quant_b, + ThreadPool + ); + } else if (Dispatch->SQ4BitGemmPackQuantBData != nullptr) { + // TODO: these assertions are true if called from matmul_nbits kernel but not from mlas tests. + //assert(QuantBScale == nullptr); + //assert(QuantBZeroPoint == nullptr); + Dispatch->SQ4BitGemmPackQuantBData( + N, + K, + BlkLen, + ComputeType, + static_cast(QuantBData), + static_cast(PackedQuantBDataAndOrBlkSumWorkspace), + ThreadPool + ); + return; + } } } @@ -293,7 +331,7 @@ SQ4BitGemm_CompFp32( const float* A = DataParams->A + RangeStartM * lda; - const std::byte* QuantBData = static_cast(DataParams->QuantBData) + RangeStartN * ldb; + const std::byte* QuantBData = static_cast(DataParams->PackedQuantBData) + RangeStartN * ldb; const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks; const std::byte* QuantBZeroPoint = (DataParams->QuantBZeroPoint == nullptr) @@ -373,7 +411,6 @@ SQ4BitGemm_CompFp32( if (bias) { AddBiasForGemm(bias, c_blk, RowsHandled, CountN, ldc); } - if (DataParams->PostProcessor != nullptr) { DataParams->PostProcessor->Process( DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN + n, @@ -383,7 +420,6 @@ SQ4BitGemm_CompFp32( c_blk += ldc * RowsHandled; a_row += lda * RowsHandled; - RowsRemaining -= RowsHandled; } } @@ -402,16 +438,33 @@ SQ4BitGemm_CompInt8( ) { #ifdef MLAS_TARGET_AMD64_IX86 - if (RangeCountM != 1) { - // perf experiment shows fp32 is faster than int8 in M > 1 cases. - // route to fp32 compute before int8 compute is improved. - SQ4BitGemm_CompFp32( - BlkLen, - K, DataParams, PerGemmWorkspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN - ); - return; - } -#endif + PerGemmQuantAWorkspace* const per_gemm_quant_a_workspace = static_cast(PerGemmWorkspace); + constexpr size_t BlkBitWidth = 4; + + const size_t k_blks = MlasDivRoundup(K, BlkLen); + + // quant A scale is embedded in QuantData if QuantScale is nullptr. + const size_t lda = k_blks * (per_gemm_quant_a_workspace->QuantScale ? 
BlkLen : Q8BlkSize(BlkLen)); + const size_t ldc = DataParams->ldc; + const size_t ldb = k_blks * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t k_blks_zp_bytes = MlasQNBitZeroPointsForBlksSizeInBytes(k_blks); + + const std::byte* QuantA = per_gemm_quant_a_workspace->QuantData + RangeStartM * lda; + const float* QuantAScale = per_gemm_quant_a_workspace->QuantScale + RangeStartM * k_blks; + + assert(RangeStartN % 4 == 0); + const std::byte* QuantBData = static_cast(DataParams->PackedQuantBData) + RangeStartN * ldb; + const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks; + const std::byte* QuantBZeroPoint = + (DataParams->QuantBZeroPoint == nullptr) + ? nullptr + : static_cast(DataParams->QuantBZeroPoint) + RangeStartN * k_blks_zp_bytes; + const float* ABlockSum = per_gemm_quant_a_workspace->BlockSum + RangeStartM * k_blks; + const float* QuantBBlkSum = DataParams->QuantBBlkSum + RangeStartN * k_blks; + float* C = DataParams->C + RangeStartM * ldc + RangeStartN; + + const float* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias + RangeStartN; +#else constexpr size_t BlkBitWidth = 4; const size_t k_blks = MlasDivRoundup(K, BlkLen); @@ -423,7 +476,7 @@ SQ4BitGemm_CompInt8( const std::byte* QuantA = static_cast(PerGemmWorkspace) + RangeStartM * lda; - const std::byte* QuantBData = static_cast(DataParams->QuantBData) + RangeStartN * ldb; + const std::byte* QuantBData = static_cast(DataParams->PackedQuantBData) + RangeStartN * ldb; const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks; const std::byte* QuantBZeroPoint = (DataParams->QuantBZeroPoint == nullptr) @@ -433,6 +486,7 @@ SQ4BitGemm_CompInt8( float* C = DataParams->C + RangeStartM * ldc + RangeStartN; const float* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias + RangeStartN; +#endif size_t CountN; for (size_t n = 0; n < RangeCountN; n += CountN) { @@ -446,25 +500,57 @@ SQ4BitGemm_CompInt8( float* c_blk = C + n; const float* bias = (Bias == nullptr) ? 
nullptr : Bias + n; - size_t RowsRemaining = RangeCountM; - while (RowsRemaining > 0) { - const auto RowsHandled = GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8( + if (GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8 != nullptr) { + size_t RowsRemaining = RangeCountM; + while (RowsRemaining > 0) { + const auto RowsHandled = GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8( + BlkLen, + a_row, b_col, b_col_scale, b_col_zp, c_blk, RowsRemaining, CountN, K, k_blks, ldc, bias + ); + + if (DataParams->PostProcessor != nullptr) { + DataParams->PostProcessor->Process( + DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN + n, + RowsHandled, CountN, ldc + ); + } + + c_blk += RowsHandled * ldc; + a_row += RowsHandled * lda; + + RowsRemaining -= RowsHandled; + } + } +#ifdef MLAS_TARGET_AMD64_IX86 + else if (GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_BlkSum_CompInt8 != nullptr) + { + const float* b_blk_sum = QuantBBlkSum + n * k_blks; + GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_BlkSum_CompInt8( BlkLen, - a_row, b_col, b_col_scale, b_col_zp, c_blk, RowsRemaining, CountN, K, k_blks, ldc, bias + QuantA, + QuantAScale, + b_col, + b_col_scale, + b_col_zp, + c_blk, + RangeCountM, + CountN, + K, + k_blks, + bias, + ldc, + ABlockSum, + b_blk_sum ); if (DataParams->PostProcessor != nullptr) { DataParams->PostProcessor->Process( - DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN + n, - RowsHandled, CountN, ldc + DataParams->C, RangeStartM, RangeStartN + n, + RangeCountM, CountN, ldc ); } - - c_blk += RowsHandled * ldc; - a_row += RowsHandled * lda; - - RowsRemaining -= RowsHandled; } +#endif } } @@ -496,23 +582,44 @@ InitializeWorkspace_CompInt8( MLAS_UNREFERENCED_PARAMETER(N); const auto QuantizeARow = GetMlasPlatform().SQNBitGemmDispatch->QuantizeARow_CompInt8; + const auto QuantizeARow2 = GetMlasPlatform().SQNBitGemmDispatch->QuantizeARowComputeBlkSum_CompInt8; const size_t BlockCountK = MlasDivRoundup(K, BlkLen); const size_t QuantAStride = BlockCountK * Q8BlkSize(BlkLen); - MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) { - const auto& data = DataParams[gemm_idx]; + // TODO: try parallel on BatchN * M threads because BatchN is usually 1. 
+ if (QuantizeARow) { + MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) { + const auto& data = DataParams[gemm_idx]; - const float* ARowPtr = data.A; - std::byte* QuantARowPtr = static_cast(Workspace) + gemm_idx * PerGemmWorkspaceStride; + const float* ARowPtr = data.A; + std::byte* QuantARowPtr = static_cast(Workspace) + gemm_idx * PerGemmWorkspaceStride; + for (size_t m = 0; m < M; ++m) { + QuantizeARow(BlkLen, ARowPtr, K, QuantARowPtr); - for (size_t m = 0; m < M; ++m) { - QuantizeARow(BlkLen, ARowPtr, K, QuantARowPtr); - - ARowPtr += data.lda; - QuantARowPtr += QuantAStride; - } - }); + ARowPtr += data.lda; + QuantARowPtr += QuantAStride; + } + }); + } else { + MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) { + const auto& data = DataParams[gemm_idx]; + const float* ARowPtr = data.A; + + void* PerGemmWorkspace = static_cast(Workspace) + gemm_idx * PerGemmWorkspaceStride; + PerGemmQuantAWorkspace quant_a_data(PerGemmWorkspace, M, BlockCountK, BlkLen); + std::byte* QuantARowPtr = quant_a_data.QuantData; + float* QuantARowScalePtr = quant_a_data.QuantScale; + float* QuantARowBlkSum = quant_a_data.BlockSum; + for (size_t m = 0; m < M; ++m) { + QuantizeARow2(BlkLen, ARowPtr, K, QuantARowPtr, QuantARowScalePtr, QuantARowBlkSum); + ARowPtr += data.lda; + QuantARowPtr += BlockCountK * BlkLen; + QuantARowScalePtr += BlockCountK; + QuantARowBlkSum += BlockCountK; + } + }); + } } struct Operations { @@ -530,7 +637,6 @@ constexpr auto OperationMap = []() { return ops; }(); - } // namespace void MLASCALL @@ -572,12 +678,23 @@ MlasSQNBitGemmBatch( const auto ComputeOperation = OperationMap[Variant].SQNBitGemm; + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + if (ThreadPool == nullptr) { for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) { const auto* Data = &DataParams[gemm_i]; void* PerGemmWorkspace = reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride; - ComputeOperation(BlkLen, K, Data, PerGemmWorkspace, 0, M, 0, N); + if (ComputeType == CompInt8 && GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { + PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); + const_cast(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; + const_cast(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; + const_cast(Data)->QuantBScale = packed_quant_b.PackedQuantBScale; + PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen); + ComputeOperation(BlkLen, K, Data, &per_gemm_quant_a_workspace, 0, M, 0, N); + } else { + ComputeOperation(BlkLen, K, Data, PerGemmWorkspace, 0, M, 0, N); + } } return; } @@ -627,9 +744,6 @@ MlasSQNBitGemmBatch( const auto gemm_i = tid / ThreadsPerGemm; const auto blk_i = tid % ThreadsPerGemm; const auto* Data = &DataParams[gemm_i]; - void* PerGemmWorkspace = reinterpret_cast( - reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride - ); const ptrdiff_t ThreadIdN = blk_i / ThreadCountM; const ptrdiff_t ThreadIdM = blk_i % ThreadCountM; @@ -640,6 +754,18 @@ MlasSQNBitGemmBatch( const size_t RangeStartN = ThreadIdN * StrideN; const size_t RangeCountN = std::min(N - RangeStartN, (size_t)StrideN); - ComputeOperation(BlkLen, K, Data, PerGemmWorkspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN); + void* PerGemmWorkspace = + reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride; + if (ComputeType == CompInt8 && GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { + 
PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); + const_cast(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; + const_cast(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; + const_cast(Data)->QuantBScale = packed_quant_b.PackedQuantBScale; + + PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen); + ComputeOperation(BlkLen, K, Data, &per_gemm_quant_a_workspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN); + } else { + ComputeOperation(BlkLen, K, Data, PerGemmWorkspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN); + } }); } diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.h b/onnxruntime/core/mlas/lib/sqnbitgemm.h index 8321dcc217e9a..2da336ca2f0ec 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.h @@ -25,12 +25,50 @@ Module Name: #include "mlas_qnbit.h" #include "mlasi.h" +constexpr MLAS_FORCEINLINE size_t +MlasQNBitQuantBBlkSumAlignment() +{ + // 16 floats. this alignment is required by GemmFloatKernel + return 16 * sizeof(float); +} + constexpr MLAS_FORCEINLINE size_t MlasQNBitBlkDataSizeInBytes(size_t BlkBitWidth, size_t BlkLen) { return BlkLen * BlkBitWidth / 8; } +MLAS_FORCEINLINE void* +MlasAlignAddress(void* addr, const size_t alignment) +{ + const uintptr_t QuantBBlkSumAddr = reinterpret_cast(addr); + addr = (void*)((QuantBBlkSumAddr + alignment - 1) & (~(alignment - 1))); + return addr; +} + +struct PackedQuantBDataStruct { + PackedQuantBDataStruct(void* PackedQuantBWorkspace, size_t N, size_t BlockCountK, size_t BlkLen) + : QuantBWorkspace_(PackedQuantBWorkspace), N_(N), BlockCountK_(BlockCountK), BlkLen_(BlkLen) + { + // TODO: duplicate code from SQ4BitGemmPackQuantBDataSize + constexpr size_t BlkBitWidth = 4; + const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + size_t BlkSumSize = MlasDivRoundup(N, 16) * BlockCountK * 16 * sizeof(float); + + // _mm256_load_si256 requires alignment on a 32-byte boundary + PackedQuantBData = (std::byte*)MlasAlignAddress(PackedQuantBWorkspace, 32); + QuantBBlkSum = (float*)(PackedQuantBData + PackedQuantBDataSize); + QuantBBlkSum = (float*)MlasAlignAddress(QuantBBlkSum, MlasQNBitQuantBBlkSumAlignment()); + PackedQuantBScale = (float*)((std::byte*)QuantBBlkSum + BlkSumSize); + } + std::byte* PackedQuantBData; + float* PackedQuantBScale; + float* QuantBBlkSum; + + void* QuantBWorkspace_; + size_t N_, BlockCountK_, BlkLen_; +}; + template constexpr MLAS_FORCEINLINE size_t MlasQNBitZeroPointsForBlksSizeInBytes(size_t BlkCount) @@ -74,6 +112,21 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { SQ4BitGemmPackQuantBData_Fn* SQ4BitGemmPackQuantBData = nullptr; + typedef void(SQ4BitGemmPackQuantBDataAndSumBlk_Fn)( + size_t N, + size_t K, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + const float* QuantBScaleBegin, + bool has_zp_input, + const std::byte* QuantBZPBegin, + PackedQuantBDataStruct& packed_quant_b, + MLAS_THREADPOOL* ThreadPool + ); + + SQ4BitGemmPackQuantBDataAndSumBlk_Fn* SQ4BitGemmPackQuantBDataAndBlkSum = nullptr; + // // Workspace size calculation function prototypes. // @@ -181,6 +234,45 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { // CompInt8 kernel function prototypes. // + /** + * @brief Multiply quantized 8-bit integer matrix A with quantized 4-bit integer matrix B. + * A and B are block quantized and B is column major. + * + * @param BlkLen Number of values in a block. 
+ * @param QuantA Supplies the quantized A matrix. + Binary data containing block quantized int8 data and scale values. + * @param QuantBData Supplies the quantized B matrix block data. + * @param QuantBScale Supplies the quantized B matrix block scale values. + * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. + * @param[out] C Supplies the output C matrix. + * @param CountN Number of columns of B and C. + * @param CountK Number of columns of A and rows of B. + * @param BlockCountK Number of blocks between adjacent columns of the quantized B matrix. + * @param Bias Bias vector of length N. + * @param ldc Number of elements between adjacent rows of C.. + * @param ABlockSum Supplies the blksum of A. + * @param QuantBBlkSum Supplies the blksum of B. + */ + typedef size_t(SQ4BitGemmKernel_BlkSum_CompInt8_Fn)( + size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t BlockCountK, + const float* Bias, + size_t ldc, + const float* ABlockSum, + const float* QuantBBlkSum + ); + + SQ4BitGemmKernel_BlkSum_CompInt8_Fn* SQ4BitGemmKernel_BlkSum_CompInt8 = nullptr; + /** * @brief Multiply quantized 8-bit integer matrix A with quantized 4-bit integer matrix B. * A and B are block quantized and B is column major. @@ -235,4 +327,14 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { ); QuantizeARow_CompInt8_Fn* QuantizeARow_CompInt8 = nullptr; + + typedef void(QuantizeARowComputeBlkSum_CompInt8_Fn)( + size_t BlkLen, + const float* A, + size_t CountK, + std::byte* QuantA, + float* QuantAScale, + float* AScaledGroupSum // scale_k * Sum_blklen(a_i) + ); + QuantizeARowComputeBlkSum_CompInt8_Fn* QuantizeARowComputeBlkSum_CompInt8 = nullptr; }; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp index 0922f5ef646be..55d86bb9cc18e 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp @@ -22,6 +22,12 @@ Module Name: #include "sqnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" #include "sqnbitgemm_kernel_avx_common_int8.h" +#include "sqnbitgemm_kernel_avx2_int8_blklen16.h" +#include "sqnbitgemm_kernel_avx2_int8_blklen32.h" +#include "sqnbitgemm_kernel_avx2_int8_blklen64.h" + +#include "sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h" +#include "sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h" MLAS_FORCEINLINE __m256 @@ -338,38 +344,92 @@ Q4BitBlkDequantBForSgemm_CompFp32_avx2( } } +template +MLAS_FORCEINLINE +void +SQ4BitGemmKernel_CompInt8_avx2( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + if (BlkLen == 16) { + MlasQ4Int8GemmKernelBlkLen16Avx2( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + CountK, + BlockCountK, + Bias, + ldc + ); + } else if (BlkLen == 32) { + MlasQ4Int8GemmKernelBlkLen32Avx2( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + CountK, + BlockCountK, + Bias, + ldc + ); + } else { + MlasQ4Int8GemmKernelBlkLen64Avx2( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } +} + +template MLAS_FORCEINLINE void 
SQ4BitGemmM1Kernel_CompInt8_avx2( size_t BlkLen, const std::byte* QuantA, + const float* QuantAScale, const std::byte* QuantBData, const float* QuantBScale, const std::byte* QuantBZeroPoint, float* C, size_t CountN, - size_t CountK, + size_t /*CountK*/, size_t BlockStrideQuantB, const float* Bias ) { - if (QuantBZeroPoint != nullptr) { - constexpr bool HasZeroPoint = true; + if (QuantBZeroPoint) { if (BlkLen == 16) { - SQ4BitGemmM1Kernel_BlkLen16_CompInt8_Impl( - QuantA, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - CountK, - BlockStrideQuantB, - Bias - ); } else if (BlkLen == 32) { - SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl>( + MlasQ4Int8GemmM1KernelBlkLen32Avx2( QuantA, + QuantAScale, QuantBData, QuantBScale, QuantBZeroPoint, @@ -379,36 +439,25 @@ SQ4BitGemmM1Kernel_CompInt8_avx2( Bias ); } else { - SQ4BitGemmM1Kernel_BlkLen64Plus_CompInt8_Impl( + MlasQ4Int8GemmKernelBlkLen64Avx2( BlkLen, QuantA, + QuantAScale, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, - CountK, BlockStrideQuantB, Bias ); } } else { - constexpr bool HasZeroPoint = false; if (BlkLen == 16) { - SQ4BitGemmM1Kernel_BlkLen16_CompInt8_Impl( - QuantA, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - CountK, - BlockStrideQuantB, - Bias - ); } else if (BlkLen == 32) { - SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl>( + MlasQ4Int8GemmM1KernelBlkLen32Avx2( QuantA, + QuantAScale, QuantBData, QuantBScale, QuantBZeroPoint, @@ -418,15 +467,15 @@ SQ4BitGemmM1Kernel_CompInt8_avx2( Bias ); } else { - SQ4BitGemmM1Kernel_BlkLen64Plus_CompInt8_Impl( + MlasQ4Int8GemmKernelBlkLen64Avx2( BlkLen, QuantA, + QuantAScale, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, - CountK, BlockStrideQuantB, Bias ); @@ -434,10 +483,12 @@ SQ4BitGemmM1Kernel_CompInt8_avx2( } } +MLAS_FORCEINLINE size_t -SQ4BitGemmKernel_CompInt8_avx2( - size_t BlkLen, +SQ4BitGemmKernel_BlkSum_CompInt8_avx2( + const size_t BlkLen, const std::byte* QuantA, + const float* QuantAScale, const std::byte* QuantBData, const float* QuantBScale, const std::byte* QuantBZeroPoint, @@ -446,30 +497,101 @@ SQ4BitGemmKernel_CompInt8_avx2( size_t CountN, size_t CountK, size_t BlockCountK, + const float* Bias, size_t ldc, - const float* Bias + const float* ABlockSum, + const float* QuantBBlkSum ) { - MLAS_UNREFERENCED_PARAMETER(ldc); + if (BlkLen >= 32 && CountM == 1) { + SQ4BitGemmM1Kernel_CompInt8_avx2(BlkLen, QuantA, QuantAScale, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, CountK, BlockCountK, Bias); + return CountM; + } + + SQ4BitGemmKernel_CompInt8_avx2( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + CountK, + BlockCountK, + Bias, + ldc + ); + float* c_blk = C; + const float* b_blk_sum = QuantBBlkSum; + + size_t RowsRemaining = CountM; + const float* a_blksum_row = ABlockSum; + while (RowsRemaining > 0) { + auto RowsHandled = GetMlasPlatform().GemmFloatKernel( + a_blksum_row, b_blk_sum, c_blk, BlockCountK, RowsRemaining, CountN, BlockCountK, ldc, 1.f, false + ); - if (CountM == 0) { - return 0; + c_blk += ldc * RowsHandled; + a_blksum_row += BlockCountK * RowsHandled; + RowsRemaining -= RowsHandled; } + return CountM; +} - SQ4BitGemmM1Kernel_CompInt8_avx2( +size_t +SQ4BitGemmKernel_BlkSum_CompInt8_avx2vnni( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t BlockCountK, + const float* Bias, + size_t 
ldc, + const float* ABlockSum, + const float* QuantBBlkSum +) +{ + if (BlkLen >= 32 && CountM == 1) { + SQ4BitGemmM1Kernel_CompInt8_avx2(BlkLen, QuantA, QuantAScale, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, CountK, BlockCountK, Bias); + return CountM; + } + + SQ4BitGemmKernel_CompInt8_avx2( BlkLen, QuantA, + QuantAScale, QuantBData, QuantBScale, - QuantBZeroPoint, C, + CountM, CountN, CountK, BlockCountK, - Bias + Bias, + ldc ); + float* c_blk = C; + const float* b_blk_sum = QuantBBlkSum; + + size_t RowsRemaining = CountM; + const float* a_blksum_row = ABlockSum; + while (RowsRemaining > 0) { + auto RowsHandled = GetMlasPlatform().GemmFloatKernel( + a_blksum_row, b_blk_sum, c_blk, BlockCountK, RowsRemaining, CountN, BlockCountK, ldc, 1.f, false + ); - return 1; + c_blk += ldc * RowsHandled; + a_blksum_row += BlockCountK * RowsHandled; + RowsRemaining -= RowsHandled; + } + return CountM; } template @@ -1053,30 +1175,23 @@ SQ4BitGemmM1Kernel_CompFp32_avx2( } } -MLAS_FORCEINLINE __m128i -convert_2_ps_to_epi8(__m256 v0, __m256 v1) -{ - __m256i v0_8_epi32 = _mm256_cvtps_epi32(v0); - __m256i v1_8_epi32 = _mm256_cvtps_epi32(v1); - - __m128i v0_8_epi16 = _mm_packs_epi32(_mm256_extractf128_si256(v0_8_epi32, 0), _mm256_extractf128_si256(v0_8_epi32, 1)); - __m128i v1_8_epi16 = _mm_packs_epi32(_mm256_extractf128_si256(v1_8_epi32, 0), _mm256_extractf128_si256(v1_8_epi32, 1)); - - return _mm_packs_epi16(v0_8_epi16, v1_8_epi16); -} - void MLASCALL QuantizeARow_CompInt8_avx2( size_t BlkLen, const float* A, size_t CountK, - std::byte* QuantA + std::byte* QuantA, + float* QuantAScale, + float* AScaledBlkSum // scale_k * Sum_blklen(a_i) ) { // port from MlasQ80BlkQuantRow assert(BlkLen % 16 == 0); const __m256 signBit = _mm256_set1_ps(-0.0f); + const __m256i one_16_epi16 = _mm256_srli_epi16( + _mm256_cmpeq_epi16(_mm256_castps_si256(signBit), _mm256_castps_si256(signBit)), 15); int8_t* blob = reinterpret_cast(QuantA); + float* scale_ptr = QuantAScale; for (size_t k = 0; k < CountK; k += BlkLen) { const size_t step = std::min(BlkLen, CountK - k); @@ -1097,13 +1212,14 @@ QuantizeARow_CompInt8_avx2( // Quantize these floats const float scale = maxScalar / 127.f; - *reinterpret_cast(blob) = scale; - blob += sizeof(float); + *scale_ptr = scale; + scale_ptr++; const float inverse_scale = (maxScalar != 0.0f) ? 
127.f / maxScalar : 0.0f; const __m256 mul = _mm256_set1_ps(inverse_scale); __m128i* dst = reinterpret_cast<__m128i*>(blob); + __m256i sum_16_epi16 = _mm256_setzero_si256(); for (size_t kk = 0; kk < step; kk += 16) { const int klen = std::min(16, (int)(step - kk)); @@ -1122,16 +1238,50 @@ QuantizeARow_CompInt8_avx2( v1 = _mm256_round_ps(v1, _MM_ROUND_NEAREST); } - __m128i i_8 = convert_2_ps_to_epi8(v0, v1); - _mm_storeu_si128(dst++, i_8); + __m128i i_16_epi8 = convert_2_ps_to_epi8(v0, v1); + _mm_storeu_si128(dst++, i_16_epi8); + + // accumulate Sum(a_i) + __m256i i_16_epi16 = _mm256_cvtepi8_epi16(i_16_epi8); + sum_16_epi16 = _mm256_hadds_epi16(sum_16_epi16, i_16_epi16); } if (step < BlkLen) { memset(blob + step, 0, BlkLen - step); } + + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + *AScaledBlkSum = scale * hsum_8_epi32(sum_8_epi32); + AScaledBlkSum++; blob += BlkLen; } } +static void +SQ4BitGemmPackQuantBDataAndBlkSum( + size_t N, + size_t K, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + const float* QuantBScaleBegin, + bool has_zp_input, + const std::byte* QuantBZPBegin, + PackedQuantBDataStruct& packed_quant_b, + MLAS_THREADPOOL* ThreadPool +) +{ + assert(BlkLen >= 16 && BlkLen % 16 == 0); + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + + // TODO: always use SubBlkLen = 64 in CompInt8 + size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 32 : 64); + if (BlkLen == 32 && ComputeType == CompInt8) { + SubBlkLen = 64; + } + PackQuantBDataAndBlkSum(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, QuantBScaleBegin, has_zp_input, QuantBZPBegin, packed_quant_b, ThreadPool); +} + // // Kernel dispatch structure definition. // @@ -1140,6 +1290,26 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() { d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; + d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum; + + d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; + d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment; + + d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx2; + d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; + + d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx2; + d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx2; + + return d; +}(); + +const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni = []() { + MLAS_SQNBIT_GEMM_DISPATCH d; + + d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; + d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; + d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum; d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment; @@ -1147,8 +1317,8 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() { d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx2; d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; - d.SQ4BitGemmKernel_CompInt8 = SQ4BitGemmKernel_CompInt8_avx2; - d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8_avx2; + d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx2vnni; + d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx2; return d; }(); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h 
b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h new file mode 100644 index 0000000000000..80d67806ea6e8 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h @@ -0,0 +1,727 @@ +#pragma once +#include +#include +#include + +#include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_avx_common.h" + + +MLAS_FORCEINLINE __m256 +load_and_broadcast_4_scale_2(const float* scale) +{ + // 3 2 1 0 3 2 1 0 (7) + __m256 scale_2_4_ps = _mm256_broadcast_ps((__m128 const*)scale); + + // 2 1 0 0 2 1 0 0 (1) + __m256 scale_2_4_ps_shifted = _mm256_castsi256_ps( + _mm256_bslli_epi128(_mm256_castps_si256(scale_2_4_ps), 4) + ); + + // 3 2 1 0 2 1 0 0: (3) cross lane + __m256 scale_2_4_ps_permutted = _mm256_permute2f128_ps( + scale_2_4_ps_shifted, scale_2_4_ps, 0b00110000 + ); + + // in accumulate_r1_4blk_dot and accumulate_r2_4blk_dot + // _mm256_hadd_epi16 inter leaved dot sum, resulting: + // a31b31|a30b30|a11b11|a10b10|a21b21|a20b20|a01b01|a00b00 + // therefore we need weight to be: + // 3 3 1 1 2 2 0 0 (1) + return _mm256_permute_ps(scale_2_4_ps_permutted, 0b11110101); +} + +MLAS_FORCEINLINE +__m256i +load_16_epi8_as_epi16(const std::byte* ablob) +{ + const __m128i av_epi8 = _mm_lddqu_si128(reinterpret_cast(ablob)); + __m256i av_epi16 = _mm256_cvtepi8_epi16(av_epi8); + return av_epi16; +} + +MLAS_FORCEINLINE void +accumulate_r1_4blk_dot( + const __m256i& av0_32_epi8, const __m256i& av1_32_epi8, + const __m256i& bv0_32_epi8, const __m256i& bv1_32_epi8, + const float* scale_a, const float* scale_b, + __m256& acc) +{ + const __m256i dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av0_32_epi8); + const __m256i dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av1_32_epi8); + const __m256i sum_16_inter_leaved_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); + const __m256i sum_8_inter_leaved_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_inter_leaved_epi16); + const __m256 sum_8_inter_leaved_ps = _mm256_cvtepi32_ps(sum_8_inter_leaved_epi32); + + // load 4 scales + __m256 scale_a_4_ps = load_and_broadcast_4_scale_2(scale_a); + __m256 scale_b_4_ps = load_and_broadcast_4_scale_2(scale_b); + __m256 scale_8_ps = _mm256_mul_ps(scale_a_4_ps, scale_b_4_ps); + acc = _mm256_fmadd_ps(sum_8_inter_leaved_ps, scale_8_ps, acc); +} + +MLAS_FORCEINLINE void +accumulate_r2_4blk_dot( + const __m256i& av00_32_epi8, const __m256i& av01_32_epi8, const __m256i& av10_32_epi8, const __m256i& av11_32_epi8, + const __m256i& bv0_32_epi8, const __m256i& bv1_32_epi8, + const float* scale_a0, const float* scale_a1, const float* scale_b, + __m256& acc0, __m256& acc1 +) +{ + const __m256i dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av00_32_epi8); + const __m256i dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av01_32_epi8); + const __m256i sum_16_inter_leaved_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); + const __m256i sum_8_inter_leaved_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_inter_leaved_epi16); + const __m256 sum_8_inter_leaved_ps = _mm256_cvtepi32_ps(sum_8_inter_leaved_epi32); + + // load 4 scales + __m256 scale_a0_4_ps = load_and_broadcast_4_scale_2(scale_a0); + __m256 scale_b_4_ps = load_and_broadcast_4_scale_2(scale_b); + __m256 scale_8_ps = _mm256_mul_ps(scale_a0_4_ps, scale_b_4_ps); + acc0 = _mm256_fmadd_ps(sum_8_inter_leaved_ps, scale_8_ps, acc0); + + const __m256i 
dot0_16_epi16_ = _mm256_maddubs_epi16(bv0_32_epi8, av10_32_epi8); + const __m256i dot1_16_epi16_ = _mm256_maddubs_epi16(bv1_32_epi8, av11_32_epi8); + const __m256i sum_16_inter_leaved_epi16_ = _mm256_hadd_epi16(dot0_16_epi16_, dot1_16_epi16_); + const __m256i sum_8_inter_leaved_epi32_ = _mm256_madd_epi16(one_16_epi16, sum_16_inter_leaved_epi16_); + const __m256 sum_inter_leaved_ps_ = _mm256_cvtepi32_ps(sum_8_inter_leaved_epi32_); + + __m256 scale_a1_4_ps = load_and_broadcast_4_scale_2(scale_a1); + scale_8_ps = _mm256_mul_ps(scale_a1_4_ps, scale_b_4_ps); + acc1 = _mm256_fmadd_ps(sum_inter_leaved_ps_, scale_8_ps, acc1); +} + +static MLAS_FORCEINLINE __m256i +load_4b_packed_1blk_blklen16(const std::byte* QuantBDataPtr) +{ + // | 0 8 |...| 7 15 | + const __m128i bv_packed_64 = _mm_loadl_epi64(reinterpret_cast(QuantBDataPtr)); + const __m128i low_mask = _mm_set1_epi8(0xF); + const __m128i lower_8_epu8 = _mm_and_si128(bv_packed_64, low_mask); // 0~7 + const __m128i upper_8_epu8 = _mm_bslli_si128(_mm_and_si128(_mm_srli_epi16(bv_packed_64, 4), low_mask), 8); // 8~15 + const __m256i bv_16_epu16 = _mm256_cvtepi8_epi16(_mm_add_epi8(upper_8_epu8, lower_8_epu8)); // 0~15 + return bv_16_epu16; +} + +static MLAS_FORCEINLINE void +load_4b_packed_4blk_blklen16(const std::byte* QuantBDataPtr, __m256i& bv0_32_epi8, __m256i& bv1_32_epi8) +{ + // | 0 8 |...| 7 15 | 16 24 |...| 23 31 ||| 32 40 |...| 39 47 | 48 56 |...| 55 63 | + const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + const __m256i low_mask = _mm256_set1_epi8(0x0F); + // 0~7, 16~22, 32~39, 48~55 + __m256i bv0_32_epi8_ = _mm256_and_si256(bv_packed, low_mask); + // 8~15, 24~31, 40~47, 56~63: (1) + __m256i bv1_32_epi8_ = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8_), 4); + // 0~7, 32~39, 16~22, 48~55 <- cross lane (3) + bv0_32_epi8_ = _mm256_permute4x64_epi64(bv0_32_epi8_, 0b11011000); + // 40~47, 8~15, 56~63, 24~31 <- cross lane (3) + bv1_32_epi8_ = _mm256_permute4x64_epi64(bv1_32_epi8_, 0b01110010); + + // 0~7, 8~15, 16~22, 24~31: (1) + bv0_32_epi8 = _mm256_blend_epi32(bv0_32_epi8_, bv1_32_epi8_, 0b11001100); + + // 40~47, 32~39, 56~63, 48~55: (1) + bv1_32_epi8 = _mm256_blend_epi32(bv0_32_epi8_, bv1_32_epi8_, 0b00110011); + + // 32~39, 40~47, 48~55, 56~63: (1) + bv1_32_epi8 = _mm256_shuffle_epi32(bv1_32_epi8, 0b01001110); +} + +static MLAS_FORCEINLINE void +accumulate_blklen16_r2c1blk4_avx2( + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const __m256i& av10_32_epi8, + const __m256i& av11_32_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m256& acc0, + __m256& acc1 +) +{ + __m256i bv0_32_epi8, bv1_32_epi8; + load_4b_packed_4blk_blklen16(QuantBDataPtr, bv0_32_epi8, bv1_32_epi8); + accumulate_r2_4blk_dot(av00_32_epi8, av01_32_epi8, av10_32_epi8, av11_32_epi8, bv0_32_epi8, bv1_32_epi8, + scale_a0, scale_a1, scale_b, acc0, acc1); +} + +static MLAS_FORCEINLINE void +accumulate_blklen16_r1c1blk4_avx2( + const __m256i& av0_32_epi8, + const __m256i& av1_32_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a, + const float* scale_b, + __m256& acc +) +{ + __m256i bv0_32_epi8, bv1_32_epi8; + load_4b_packed_4blk_blklen16(QuantBDataPtr, bv0_32_epi8, bv1_32_epi8); + accumulate_r1_4blk_dot(av0_32_epi8, av1_32_epi8, bv0_32_epi8, bv1_32_epi8, scale_a, scale_b, acc); +} + +static MLAS_FORCEINLINE void +accumulate_blklen16_r2c1blk1_avx2( + const __m256i& av0_32_epi8, + const __m256i& av1_32_epi8, + const std::byte* QuantBDataPtr, + 
const float& combined_scale0, + const float& combined_scale1, + __m256& acc0, + __m256& acc1 +) +{ + const __m256i bv_16_epu16 = load_4b_packed_1blk_blklen16(QuantBDataPtr); + + __m256i prod_8_epi32 = _mm256_madd_epi16(bv_16_epu16, av0_32_epi8); + __m256 prod_8_ps = _mm256_cvtepi32_ps(prod_8_epi32); + acc0 = _mm256_fmadd_ps(_mm256_set1_ps(combined_scale0), prod_8_ps, acc0); + + prod_8_epi32 = _mm256_madd_epi16(bv_16_epu16, av1_32_epi8); + prod_8_ps = _mm256_cvtepi32_ps(prod_8_epi32); + acc1 = _mm256_fmadd_ps(_mm256_set1_ps(combined_scale1), prod_8_ps, acc1); +} + +static MLAS_FORCEINLINE void +accumulate_blklen16_r1c1blk1_avx2( + const __m256i& av_16_epi8, + const std::byte* QuantBDataPtr, + const float& combined_scale, + __m256& acc +) +{ + // | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | + const __m256i bv_16_epu16 = load_4b_packed_1blk_blklen16(QuantBDataPtr); + + __m256i prod_8_epi32 = _mm256_madd_epi16(bv_16_epu16, av_16_epi8); + __m256 prod_8_ps = _mm256_cvtepi32_ps(prod_8_epi32); + acc = _mm256_fmadd_ps(_mm256_set1_ps(combined_scale), prod_8_ps, acc); +} + +MLAS_FORCEINLINE void +Q4Int8GemmR2xC4BlkLen16Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes8 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc[NCols4 * NRows2] = { + _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), + _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps() + }; + + // process 4 blks of 64 4b weights a time + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 3; k_blks_remaining -= PerAccuBlk4) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); + const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda + 32)); + + accumulate_blklen16_r2c1blk4_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); + accumulate_blklen16_r2c1blk4_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + StrideQuantBScale, acc[1], acc[NCols4 + 1]); + 
accumulate_blklen16_r2c1blk4_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * StrideQuantBScale, acc[2], acc[NCols4 + 2]); + accumulate_blklen16_r2c1blk4_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * StrideQuantBScale, acc[3], acc[NCols4 + 3]); + + QuantAPtr += BlkLen16 * PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += BlkDataSizeInBytes8 * PerAccuBlk4; + QuantBScalePtr += PerAccuBlk4; + } + + while (k_blks_remaining-- > 0) { + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av0_16_epi16 = load_16_epi8_as_epi16(QuantABlk0); + const __m256i av1_16_epi16 = load_16_epi8_as_epi16(QuantABlk0 + lda); + + const float& scale_a00 = *QuantAScalePtr; + const float& scale_a10 = *(QuantAScalePtr + BlockCountK); + + { + const float scale_00 = scale_a00 * (QuantBScalePtr)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_blklen16_r2c1blk1_avx2(av0_16_epi16, av1_16_epi16, QuantBDataPtr, scale_00, scale_10, acc[0], acc[NCols4]); + } + + { + const float scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr + StrideQuantBScale)[0]; + accumulate_blklen16_r2c1blk1_avx2(av0_16_epi16, av1_16_epi16, QuantBDataPtr + StrideQuantBData, scale_00, scale_10, acc[1], acc[NCols4 + 1]); + } + + { + const float scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + accumulate_blklen16_r2c1blk1_avx2(av0_16_epi16, av1_16_epi16, QuantBDataPtr + 2 * StrideQuantBData, scale_00, scale_10, acc[2], acc[NCols4 + 2]); + } + + { + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + accumulate_blklen16_r2c1blk1_avx2(av0_16_epi16, av1_16_epi16, QuantBDataPtr + 3 * StrideQuantBData, scale_00, scale_10, acc[3], acc[NCols4 + 3]); + } + QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes8; + QuantBScalePtr++; + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + __m128 acc_r1 = FoldAccumulators(acc[NCols4 + 0], acc[NCols4 + 1], acc[NCols4 + 2], acc[NCols4 + 3]); + if (BiasPtr != nullptr) { + const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); + acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); + acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); + } + _mm_storeu_ps(SumPtr, acc_r0); + _mm_storeu_ps(SumPtr + ldc, acc_r1); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * StrideQuantBScale; + + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + } + } +} + +void MLAS_FORCEINLINE Q4Int8GemmR2xC1BlkLen16Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes8 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + + // process 4 blks of 64 4b weights a time + constexpr size_t PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM % NRows2 == 0); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + float* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc0 = _mm256_setzero_ps(), acc1 = _mm256_setzero_ps(); + + // process 4 blks of 64 4b weights a time + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { + const std::byte* QuantABlk00 = QuantAPtr; + const std::byte* QuantABlk01 = QuantABlk00 + 32; + const std::byte* QuantABlk10 = QuantAPtr + lda; + const std::byte* QuantABlk11 = QuantABlk10 + 32; + + // load A: + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk00); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk01); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk10); + const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk11); + + accumulate_blklen16_r2c1blk4_avx2( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); + + // increment block pointers + QuantAPtr += BlkLen16 * PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += BlkDataSizeInBytes8 * PerAccuBlk4; + QuantBScalePtr += PerAccuBlk4; + } + + while (k_blks_remaining-- > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av0_16_epi16 = load_16_epi8_as_epi16(QuantABlk0); + const __m256i av1_16_epi16 = load_16_epi8_as_epi16(QuantABlk0 + lda); + + const float& scale_a00 = *QuantAScalePtr; + const float& scale_a10 = *(QuantAScalePtr + BlockCountK); + + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_blklen16_r2c1blk1_avx2(av0_16_epi16, av1_16_epi16, QuantBDataPtr, scale_00, scale_10, acc0, acc1); + + QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes8; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_8(acc0); + *(SumPtr + ldc) = hsum_float_8(acc1); + if (BiasPtr) { + *SumPtr += *BiasPtr; + *(SumPtr + ldc) += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; 
+ + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + +MLAS_FORCEINLINE void +Q4Int8GemmR1xC4BlkLen16Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes8 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM < NRows2); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc[NCols4] = {_mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps()}; + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); + + accumulate_blklen16_r1c1blk4_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, + QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_blklen16_r1c1blk4_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData, + QuantAScalePtr, QuantBScalePtr + StrideQuantBScale, acc[1]); + accumulate_blklen16_r1c1blk4_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData, + QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale, acc[2]); + accumulate_blklen16_r1c1blk4_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData, + QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale, acc[3]); + // increment block pointers + QuantAPtr += BlkLen16 * PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += BlkDataSizeInBytes8 * PerAccuBlk4; + QuantBScalePtr += PerAccuBlk4; + } + + while (k_blks_remaining-- > 0) { + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_00_epi8 = load_16_epi8_as_epi16(QuantABlk0); + + const float& scale_a00 = *QuantAScalePtr; + { + // Col0 + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc[0]); + } + { + // Col1 + const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; + accumulate_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + StrideQuantBData, scale_00, acc[1]); + } + { + // Col2 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + accumulate_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 2 * StrideQuantBData, scale_00, acc[2]); + } + { + // Col3 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + accumulate_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 3 * StrideQuantBData, scale_00, acc[3]); + } + 
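+                // advance A and B by one 16-element block before the next tail iteration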
QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes8; + QuantBScalePtr++; + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + + _mm_storeu_ps(SumPtr, acc_r0); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + +MLAS_FORCEINLINE void +Q4Int8GemmR1xC1BlkLen16Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes8 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + + // process 4 blks of 64 4b weights a time + constexpr size_t PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM < NRows2); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc0 = _mm256_setzero_ps(); + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); + + accumulate_blklen16_r1c1blk4_avx2( + av_00_epi8, av_01_epi8, QuantBDataPtr, + QuantAScalePtr, QuantBScalePtr, acc0); + + // increment block pointers + QuantAPtr += BlkLen16 * PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += BlkDataSizeInBytes8 * PerAccuBlk4; + QuantBScalePtr += PerAccuBlk4; + } + + while (k_blks_remaining-- > 0) { + const __m256i av_16_epi16 = load_16_epi8_as_epi16(QuantAPtr); + const float& scale_a00 = *QuantAScalePtr; + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_blklen16_r1c1blk1_avx2(av_16_epi16, QuantBDataPtr, scale_00, acc0); + + QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes8; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_8(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } + } +} + +MLAS_FORCEINLINE + size_t + MlasQ4Int8GemmKernelBlkLen16Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t /*CountK*/, + size_t BlockCountK, + const float* Bias, + size_t ldc + ) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + + const size_t lda = BlockCountK * BlkLen16 * sizeof(int8_t); + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleRows > 0 && multipleCols > 0) { + Q4Int8GemmR2xC4BlkLen16Avx2( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + ldc + ); + } + if (remainingCols > 0 && multipleRows > 0) { + Q4Int8GemmR2xC1BlkLen16Avx2( + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + if (remainingRows > 0 && multipleCols > 0) { + Q4Int8GemmR1xC4BlkLen16Avx2( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols, + BlockCountK, + Bias, + ldc); + } + + if (remainingCols > 0 && remainingRows > 0) { + Q4Int8GemmR1xC1BlkLen16Avx2( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleRows * ldc + multipleCols, + remainingRows, + remainingCols, + BlockCountK, + Bias ? 
Bias + multipleCols : nullptr, + ldc); + } + + return CountM; +} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h new file mode 100644 index 0000000000000..af6f52090adcb --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h @@ -0,0 +1,1049 @@ +#pragma once +#include +#include +#include + +#include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_avx_common.h" + + +MLAS_FORCEINLINE void +accumulate_1blk_dot(const __m256i& av_32_epi8, const __m256i& bv_32_epi8, + const float& combined_scale, const __m256i& one_16_epi16, __m256& acc) +{ + const __m256i dot_16_epi16 = _mm256_maddubs_epi16( + bv_32_epi8, av_32_epi8 + ); + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, dot_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + acc = _mm256_fmadd_ps(sum_ps, _mm256_set1_ps(combined_scale), acc); +} + +#if !defined(__GNUC__) || (__GNUC__ > 10) +MLAS_FORCEINLINE void +accumulate_1blk_dot_vnni(const __m256i& av_32_epi8, const __m256i& bv_32_epi8, const float& combined_scale, __m256& acc) +{ + __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv_32_epi8, av_32_epi8); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + acc = _mm256_fmadd_ps(sum_ps, _mm256_set1_ps(combined_scale), acc); +} +#endif + +template +static MLAS_FORCEINLINE void +accumulate_blklen32_r2c1blk2_avx2( + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const __m256i& av10_32_epi8, + const __m256i& av11_32_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m256& acc0, + __m256& acc1 +) +{ + // | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 | + const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + + // generating low_mask of 0x0Fs is not as fast as just calling _mm256_set1_epi8(0x0F). + const __m256i low_mask = _mm256_set1_epi8(0x0F); + //__m256i low_mask = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_packed, bv_packed), 12); + // low_mask = _mm256_packus_epi16(low_mask, low_mask); + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0~31 + // TODO: this (the second line below) is faster and does not keep low_mask in use. 
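+    // Subtracting the low nibbles first clears them, so the 16-bit right shift below cannot pull bits
+    // from a neighboring byte into the result and no second mask is needed.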
+ // const __m256i bv1_32_epi8 = _mm256_and_si256(_mm256_srli_epi16(bv_packed, 4), low_mask); + __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32~63 + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) { + __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); + { + const __m256i dot0_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av00_32_epi8); + const __m256i dot1_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv1_32_epi8, av01_32_epi8); + const __m256i sum_8_epi32 = _mm256_hadd_epi32(dot0_8_epi32, dot1_8_epi32); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a0)); + // 1 0 1 0 1 0 1 0 -> 1 1 0 0 1 1 0 0 + __m256 scale_8_ps = _mm256_permute_ps(_mm256_mul_ps(scale_a0_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); + acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); + } + { + const __m256i dot0_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av10_32_epi8); + const __m256i dot1_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv1_32_epi8, av11_32_epi8); + const __m256i sum_8_epi32 = _mm256_hadd_epi32(dot0_8_epi32, dot1_8_epi32); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a1_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a1)); + __m256 scale_8_ps = _mm256_permute_ps(_mm256_mul_ps(scale_a1_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); + acc1 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc1); + } + } else { +#endif + //{ + const __m256i dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av00_32_epi8); + const __m256i dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av01_32_epi8); + const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + + // generating constant 1s is faster here. + // __m256i one = _mm256_set1_epi16(1); + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a0)); + __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); + // 1 0 1 0 1 0 1 0 -> 1 1 0 0 1 1 0 0 + __m256 scale_8_ps = _mm256_permute_ps(_mm256_mul_ps(scale_a0_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); + acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); + //} + //{ + const __m256i dot0_16_epi16_ = _mm256_maddubs_epi16(bv0_32_epi8, av10_32_epi8); + const __m256i dot1_16_epi16_ = _mm256_maddubs_epi16(bv1_32_epi8, av11_32_epi8); + const __m256i sum_16_epi16_ = _mm256_hadd_epi16(dot0_16_epi16_, dot1_16_epi16_); + const __m256i sum_8_epi32_ = _mm256_madd_epi16(one_16_epi16, sum_16_epi16_); + const __m256 sum_ps_ = _mm256_cvtepi32_ps(sum_8_epi32_); + + __m256 scale_a1_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a1)); + __m256 scale_8_ps_ = _mm256_permute_ps(_mm256_mul_ps(scale_a1_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); + acc1 = _mm256_fmadd_ps(sum_ps_, scale_8_ps_, acc1); + //} +#if !defined(__GNUC__) || (__GNUC__ > 10) + } +#endif +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen32_r1c1blk2_avx2( + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_b, + __m256& acc0 +) +{ + // | v0 v32 | v1 v33 | ... 
| v30 v62 | v31 v63 | + const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + const __m256i low_mask = _mm256_set1_epi8(0x0F); + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0~31 + __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32~63 + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) { + const __m256i dot0_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av00_32_epi8); + const __m256i dot1_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv1_32_epi8, av01_32_epi8); + const __m256i sum_8_epi32 = _mm256_hadd_epi32(dot0_8_epi32, dot1_8_epi32); // 00110011 + + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a0)); + __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); + // 1 0 1 0 1 0 1 0 -> 1 1 0 0 1 1 0 0 + __m256 scale_8_ps = _mm256_permute_ps(_mm256_mul_ps(scale_a0_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); + acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); + } else { +#endif + const __m256i dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av00_32_epi8); + const __m256i dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av01_32_epi8); + const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a0)); + __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); + // 1 0 1 0 1 0 1 0 -> 1 1 0 0 1 1 0 0 + __m256 scale_8_ps = _mm256_permute_ps(_mm256_mul_ps(scale_a0_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); + acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); +#if !defined(__GNUC__) || (__GNUC__ > 10) + } +#endif +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen32_r2c1blk1_avx2( + const __m256i& av00_32_epi8, + const __m256i& av10_32_epi8, + const std::byte* QuantBDataPtr, + const float& combined_scale00, + const float& combined_scale10, + __m256& acc0, + __m256& acc1 +) +{ + // | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | + const __m128i bv_packed0 = _mm_loadu_si128(reinterpret_cast(QuantBDataPtr)); + __m256i bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); + bv_32_epi8 = _mm256_and_si256(_mm256_set1_epi8(0x0F), bv_32_epi8); + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) { + accumulate_1blk_dot_vnni(av00_32_epi8, bv_32_epi8, combined_scale00, acc0); + accumulate_1blk_dot_vnni(av10_32_epi8, bv_32_epi8, combined_scale10, acc1); + } else { +#endif + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_32_epi8, bv_32_epi8), 15); + accumulate_1blk_dot(av00_32_epi8, bv_32_epi8, combined_scale00, one_16_epi16, acc0); + accumulate_1blk_dot(av10_32_epi8, bv_32_epi8, combined_scale10, one_16_epi16, acc1); +#if !defined(__GNUC__) || (__GNUC__ > 10) + } +#endif +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen32_r1c1blk1_avx2( + const __m256i& av00_32_epi8, + const std::byte* QuantBDataPtr, + const float& combined_scale00, + __m256& acc0 +) +{ + // | v0 v16 | v1 v17 | ... 
| v14 v30 | v15 v31 | + const __m128i bv_packed0 = _mm_loadu_si128(reinterpret_cast(QuantBDataPtr)); + __m256i bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); + bv_32_epi8 = _mm256_and_si256(_mm256_set1_epi8(0x0F), bv_32_epi8); + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) { + accumulate_1blk_dot_vnni(av00_32_epi8, bv_32_epi8, combined_scale00, acc0); + } else { +#endif + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_32_epi8, bv_32_epi8), 15); + accumulate_1blk_dot(av00_32_epi8, bv_32_epi8, combined_scale00, one_16_epi16, acc0); +#if !defined(__GNUC__) || (__GNUC__ > 10) + } +#endif +} + +template +MLAS_FORCEINLINE void +Q4Int8Gemm2x4x2BlkLen32Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t lda = BlockCountK * BlkLen32; + //const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + //const size_t StrideQuantBScale = BlockCountK; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + const size_t StrideQuantBDataCol = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBData2 = 2 * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBData1 = 1 * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale2 = 2; + const size_t StrideQuantBScale1 = 1; + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc[NCols4 * NRows2] = { + _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), + _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps() + }; + + size_t k_blks_remaining = BlockCountK; + // process 2 blks of 64 4b weights a time + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + // load A: + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr)); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + BlkLen32)); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); + const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda + BlkLen32)); + + { + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); + } + { + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData2, QuantAScalePtr, QuantAScalePtr + 
BlockCountK, QuantBScalePtr + StrideQuantBScale2, acc[1], acc[NCols4 + 1]); + } + + { + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData2, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * StrideQuantBScale2, acc[2], acc[NCols4 + 2]); + } + + { + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData2, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * StrideQuantBScale2, acc[3], acc[NCols4 + 3]); + } + + // increment block pointers + QuantAPtr += BlkLen32 * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2 * NCols4; + QuantBScalePtr += PerAccuBlk2 * NCols4; + } // k_blks_remaining + + // TODO: use a loop in case PerAccuBlk2 is not 2. + if (k_blks_remaining > 0) { + // load A + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr)); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); + + const float& scale_a00 = *QuantAScalePtr; + const float& scale_a10 = *(QuantAScalePtr + BlockCountK); + + { + // Col0 + const float scale_00 = scale_a00 * (QuantBScalePtr)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc[0], acc[NCols4]); + } + + { + // Col1 + const float scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale1)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr + StrideQuantBScale1)[0]; + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + StrideQuantBData1, scale_00, scale_10, acc[1], acc[NCols4 + 1]); + } + + { + // Col2 + const float scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale1)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr + 2 * StrideQuantBScale1)[0]; + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 2 * StrideQuantBData1, scale_00, scale_10, acc[2], acc[NCols4 + 2]); + } + + { + // Col3 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale1)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr + 3 * StrideQuantBScale1)[0]; + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 3 * StrideQuantBData1, scale_00, scale_10, acc[3], acc[NCols4 + 3]); + } + } // k_blks_remaining + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + __m128 acc_r1 = FoldAccumulators(acc[NCols4 + 0], acc[NCols4 + 1], acc[NCols4 + 2], acc[NCols4 + 3]); + if (BiasPtr != nullptr) { + const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); + acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); + acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); + } + + _mm_storeu_ps(SumPtr, acc_r0); + _mm_storeu_ps(SumPtr + ldc, acc_r1); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBDataCol; + QuantBScaleColPtr += NCols4 * BlockCountK; + + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + } + } +} + +template +void MLAS_FORCEINLINE Q4Int8Gemm2xXBlkLen32Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t lda = BlockCountK * BlkLen32; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + //const size_t StrideQuantBScale = BlockCountK; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM % NRows2 == 0); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + float* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc0 = _mm256_setzero_ps(), acc1 = _mm256_setzero_ps(); + + size_t k_blks_remaining = BlockCountK; + // process 2 blks of 64 4b weights a time + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr)); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + BlkLen32)); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); + const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda + BlkLen32)); + + accumulate_blklen32_r2c1blk2_avx2( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); + + // increment block pointers + QuantAPtr += BlkLen32 * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + } + + if (k_blks_remaining > 0) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); + + const float& scale_a00 = *QuantAScalePtr; + const float& scale_a10 = *(QuantAScalePtr + BlockCountK); + + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc0, acc1); + } + + *SumPtr = hsum_float_8(acc0); + *(SumPtr + ldc) = hsum_float_8(acc1); + if (BiasPtr) { + *SumPtr += *BiasPtr; + *(SumPtr + ldc) += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += BlockCountK; + + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmXx4BlkLen32Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t lda = BlockCountK * BlkLen32; + //const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + //const size_t StrideQuantBScale = BlockCountK; + + assert(CountM < NRows2); + assert(CountN % NCols4 == 0); + const size_t StrideQuantBDataCol = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBData2 = 2 * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBData1 = 1 * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale2 = 2; + const size_t StrideQuantBScale1 = 1; + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc[NCols4] = {_mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps()}; + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr)); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + BlkLen32)); + + { + accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, + QuantAScalePtr, QuantBScalePtr, acc[0]); + } + { + accumulate_blklen32_r1c1blk2_avx2( + av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData2, + QuantAScalePtr, QuantBScalePtr + StrideQuantBScale2, acc[1] + ); + } + { + accumulate_blklen32_r1c1blk2_avx2( + av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData2, + QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale2, acc[2] + ); + } + { + accumulate_blklen32_r1c1blk2_avx2( + av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData2, + QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale2, acc[3] + ); + } + // increment block pointers + QuantAPtr += BlkLen32 * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2 * NCols4; + QuantBScalePtr += PerAccuBlk2 * NCols4; + } + + // TODO: use a loop in case PerAccuBlk2 is not 2. 
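+            // Tail: the loop above consumes PerAccuBlk2 (= 2) blocks per iteration, so at most one block remains here.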
+ if (k_blks_remaining > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); + + const float& scale_a00 = *QuantAScalePtr; + { + // Col0 + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc[0]); + } + { + // Col1 + const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale1)[0]; + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + StrideQuantBData1, scale_00, acc[1]); + } + { + // Col2 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale1)[0]; + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 2 * StrideQuantBData1, scale_00, acc[2]); + } + { + // Col3 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale1)[0]; + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 3 * StrideQuantBData1, scale_00, acc[3]); + } + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + + _mm_storeu_ps(SumPtr, acc_r0); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBDataCol; + QuantBScaleColPtr += NCols4 * BlockCountK; + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmXxXBlkLen32Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t lda = BlockCountK * BlkLen32; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + //const size_t StrideQuantBScale = BlockCountK; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM < NRows2); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc0 = _mm256_setzero_ps(); + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr)); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + BlkLen32)); + accumulate_blklen32_r1c1blk2_avx2( + av_00_epi8, av_01_epi8, QuantBDataPtr, + QuantAScalePtr, QuantBScalePtr, acc0 + ); + + // increment block pointers + QuantAPtr += BlkLen32 * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + } + + // TODO: use a loop in case PerAccuBlk2 is not 2. 
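+            // Leftover block when BlockCountK is odd: the A and B block scales are pre-multiplied into a single combined scale.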
+ if (k_blks_remaining > 0) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const float& scale_a00 = *QuantAScalePtr; + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc0); + } + + *SumPtr = hsum_float_8(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += BlockCountK; + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + +template +MLAS_FORCEINLINE + size_t + MlasQ4Int8GemmKernelBlkLen32Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t /*CountK*/, + size_t BlockCountK, + const float* Bias, + size_t ldc + ) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + + const size_t lda = BlockCountK * BlkLen32 * sizeof(int8_t); + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleRows > 0 && multipleCols > 0) { + Q4Int8Gemm2x4x2BlkLen32Avx2( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + ldc + ); + } + if (remainingCols > 0 && multipleRows > 0) { + Q4Int8Gemm2xXBlkLen32Avx2( + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + if (remainingRows > 0 && multipleCols > 0) { + Q4Int8GemmXx4BlkLen32Avx2( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols, + BlockCountK, + Bias, + ldc); + } + + if (remainingCols > 0 && remainingRows > 0) { + Q4Int8GemmXxXBlkLen32Avx2( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleRows * ldc + multipleCols, + remainingRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + return CountM; +} + +// this function is to explore larger NCols. With Avx2 it does not improve performance. +// Leave it here until the same is implemented in avx512. +template accumulator> +MLAS_FORCEINLINE +size_t +MlasQ4Int8TileGemmKernelBlkLen32Avx2( + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t /*CountK*/, + size_t BlockCountK, + const float* Bias, + size_t lda, + size_t ldc +) +{ + // We process 32 quantized values in a batch. 
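+    // Unlike the kernels above, this exploratory path reads A in Q8Blk layout (per-block scale stored
+    // alongside the data) and supports 4-bit B zero points packed two per byte.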
+ constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + const __m256i zero = _mm256_setzero_si256(); + const __m128i low_mask = _mm_set1_epi8(0xF); + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + + for (size_t m = 0; m < CountM; m++) { + // for each row of A, reset B pointers + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + int64_t nblk = (int64_t)(CountN)-NCols4; + while (nblk >= 0) { + const std::byte* QuantAPtr = QuantA + m * lda; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc[NCols4]; + + acc[0] = _mm256_setzero_ps(); + acc[1] = _mm256_setzero_ps(); + acc[2] = _mm256_setzero_ps(); + acc[3] = _mm256_setzero_ps(); + + if constexpr (NCols4 == 8) { + acc[4] = _mm256_setzero_ps(); + acc[5] = _mm256_setzero_ps(); + acc[6] = _mm256_setzero_ps(); + acc[7] = _mm256_setzero_ps(); + } + + size_t k_blks_remaining = BlockCountK; + + // process 2 blks of 64 4b weights a time + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const std::byte* QuantABlk0 = QuantAPtr; + const std::byte* QuantABlk1 = QuantABlk0 + Q8BlkSize(BlkLen32); + + // load A: + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + const __m256i av_1_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk1)); + + const float& scale_a0 = Q8BlkScale(QuantABlk0); + const float& scale_a1 = Q8BlkScale(QuantABlk1); + + // Col0 + const float& scale_00 = scale_a0 * QuantBScalePtr[0]; + const float& scale_01 = scale_a1 * QuantBScalePtr[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc[0]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 16), low_mask, zero, QuantBZeroPointPtr, false, scale_01, acc[0]); + + // Col1 + const float& scale_10 = scale_a0 * (QuantBScalePtr + StrideQuantBScale)[0]; + const float& scale_11 = scale_a1 * (QuantBScalePtr + StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, true, scale_10, acc[1]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, false, scale_11, acc[1]); + + // Col2 + const float& scale_20 = scale_a0 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + const float& scale_21 = scale_a1 * (QuantBScalePtr + 2 * StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 2 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, true, scale_20, acc[2]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 2 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 2 * 
StrideQuantBZeroPoint, false, scale_21, acc[2]); + + // Col3 + const float& scale_30 = scale_a0 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + const float& scale_31 = scale_a1 * (QuantBScalePtr + 3 * StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, true, scale_30, acc[3]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, false, scale_31, acc[3]); + + if constexpr (NCols4 == 8) { + // Col4 + const float& scale_40 = scale_a0 * (QuantBScalePtr + 4 * StrideQuantBScale)[0]; + const float& scale_41 = scale_a1 * (QuantBScalePtr + 4 * StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 4 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr, true, scale_40, acc[4]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 4 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr, false, scale_41, acc[4]); + + // Col5 + const float& scale_50 = scale_a0 * (QuantBScalePtr + 5 * StrideQuantBScale)[0]; + const float& scale_51 = scale_a1 * (QuantBScalePtr + 5 * StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 5 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, true, scale_50, acc[5]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 5 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, false, scale_51, acc[5]); + + // Col6 + const float& scale_60 = scale_a0 * (QuantBScalePtr + 6 * StrideQuantBScale)[0]; + const float& scale_61 = scale_a1 * (QuantBScalePtr + 6 * StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 6 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 6 * StrideQuantBZeroPoint, true, scale_60, acc[6]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 6 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 6 * StrideQuantBZeroPoint, false, scale_61, acc[6]); + + // Col7 + const float& scale_70 = scale_a0 * (QuantBScalePtr + 7 * StrideQuantBScale)[0]; + const float& scale_71 = scale_a1 * (QuantBScalePtr + 7 * StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 7 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 7 * StrideQuantBZeroPoint, true, scale_70, acc[7]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 7 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 7 * StrideQuantBZeroPoint, false, scale_71, acc[7]); + } + + // increment block pointers + QuantAPtr += Q8BlkSize(BlkLen32) * PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } // k_blks_remaining + + // TODO: use a loop in case PerAccuBlk2 is not 2. 
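+            // Tail: a single block can remain; it is accumulated per column with the same 'true' flag
+            // used for the first block of each pair above.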
+ if (k_blks_remaining > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + + const float& scale_a0 = Q8BlkScale(QuantABlk0); + + // Col0 + const float& scale_00 = scale_a0 * QuantBScalePtr[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc[0]); + + // Col1 + const float& scale_10 = scale_a0 * (QuantBScalePtr + StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, true, scale_10, acc[1]); + + // Col2 + const float& scale_20 = scale_a0 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 2 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, true, scale_20, acc[2]); + + // Col3 + const float& scale_30 = scale_a0 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, true, scale_30, acc[3]); + + if constexpr (NCols4 == 8) { + // Col4 + const float& scale_40 = scale_a0 * (QuantBScalePtr + 4 * StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 4 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 4 * StrideQuantBZeroPoint, true, scale_40, acc[4]); + + // Col5 + const float& scale_50 = scale_a0 * (QuantBScalePtr + 5 * StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 5 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 5 * StrideQuantBZeroPoint, true, scale_50, acc[5]); + + // Col6 + const float& scale_60 = scale_a0 * (QuantBScalePtr + 6 * StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 6 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 6 * StrideQuantBZeroPoint, true, scale_60, acc[6]); + + // Col7 + const float& scale_70 = scale_a0 * (QuantBScalePtr + 7 * StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 7 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 7 * StrideQuantBZeroPoint, true, scale_70, acc[7]); + } + } // k_blks_remaining + + if constexpr (NCols4 == 8) { + __m128 acc_0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + __m128 acc_1 = FoldAccumulators(acc[4], acc[5], acc[6], acc[7]); + if (BiasPtr != nullptr) { + acc_0 = _mm_add_ps(acc_0, _mm_loadu_ps(BiasPtr)); + acc_1 = _mm_add_ps(acc_1, _mm_loadu_ps(BiasPtr + 4)); + } + _mm_storeu_ps(SumPtr, acc_0); + _mm_storeu_ps(SumPtr+4, acc_1); + } else { + __m128 acc_x = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_x = _mm_add_ps(acc_x, _mm_loadu_ps(BiasPtr)); + } + _mm_storeu_ps(SumPtr, acc_x); + } + + // move to next NCols columns + + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += NCols4 * StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + nblk -= NCols4; + } // while (nblk >= 0) + + nblk += NCols4; + for (int64_t n = 0; n < nblk; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc0 = _mm256_setzero_ps(); + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const std::byte* QuantABlk0 = QuantAPtr; + const std::byte* QuantABlk1 = QuantABlk0 + Q8BlkSize(BlkLen32); + + // load A: + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + const __m256i av_1_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk1)); + + const float& scale_a0 = Q8BlkScale(QuantABlk0); + const float& scale_a1 = Q8BlkScale(QuantABlk1); + + // Col0 + const float& scale_00 = scale_a0 * QuantBScalePtr[0]; + const float& scale_01 = scale_a1 * QuantBScalePtr[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc0); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 16), low_mask, zero, QuantBZeroPointPtr, false, scale_01, acc0); + + // increment block pointers + QuantAPtr += Q8BlkSize(BlkLen32) * PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } + + // TODO: use a loop in case PerAccuBlk2 is not 2. + if (k_blks_remaining > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + + const float& scale_a0 = Q8BlkScale(QuantABlk0); + + // Col0 + const float& scale_00 = scale_a0 * QuantBScalePtr[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc0); + } + + *SumPtr = hsum_float_8(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } + } // m + return CountM; +} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h new file mode 100644 index 0000000000000..174ebc580904c --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h @@ -0,0 +1,541 @@ +#pragma once +#include +#include +#include + +#include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_avx_common.h" + +template +static MLAS_FORCEINLINE void +accumulate_blklen64_r2c1blk1_avx2( + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const __m256i& av10_32_epi8, + const __m256i& av11_32_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m256& acc0, + __m256& acc1 +) +{ + const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + const __m256i low_mask = _mm256_set1_epi8(0x0F); + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...30, 31 + __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32, 33,...62, 63 + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) { + __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av00_32_epi8); + sum_8_epi32 = _mm256_dpbusds_avx_epi32(sum_8_epi32, bv1_32_epi8, av01_32_epi8); + __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a0_ps = _mm256_broadcast_ss(scale_a0); + __m256 scale_b_ps = _mm256_broadcast_ss(scale_b); + + acc0 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a0_ps, scale_b_ps), acc0); + + sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av10_32_epi8); + sum_8_epi32 = _mm256_dpbusds_avx_epi32(sum_8_epi32, bv1_32_epi8, av11_32_epi8); + sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a1_ps = _mm256_broadcast_ss(scale_a1); + + acc1 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a1_ps, scale_b_ps), acc1); + + } else { +#endif + __m256i dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av00_32_epi8); + __m256i dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av01_32_epi8); + __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + + const __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); + + __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a0_ps = _mm256_broadcast_ss(scale_a0); + __m256 scale_b_ps = _mm256_broadcast_ss(scale_b); + + acc0 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a0_ps, scale_b_ps), acc0); + + dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av10_32_epi8); + dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av11_32_epi8); + sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + + sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a1_ps = _mm256_broadcast_ss(scale_a1); + + acc1 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a1_ps, scale_b_ps), acc1); +#if !defined(__GNUC__) || (__GNUC__ > 10) + } +#endif +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen64_r1c1blk1_avx2( + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a, + const float* scale_b, + __m256& acc0 +) +{ + // | v0 v32 | v1 v33 | ... 
| v30 v62 | v31 v63 | + const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + const __m256i low_mask = _mm256_set1_epi8(0x0F); + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...30, 31 + __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32, 33,...62, 63 + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) { + __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av00_32_epi8); + sum_8_epi32 = _mm256_dpbusds_avx_epi32(sum_8_epi32, bv1_32_epi8, av01_32_epi8); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a_8_ps = _mm256_broadcast_ss(scale_a); + __m256 scale_b_8_ps = _mm256_broadcast_ss(scale_b); + + acc0 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a_8_ps, scale_b_8_ps), acc0); + } else { +#endif + const __m256i dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av00_32_epi8); + const __m256i dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av01_32_epi8); + const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a_8_ps = _mm256_broadcast_ss(scale_a); + __m256 scale_b_8_ps = _mm256_broadcast_ss(scale_b); + + acc0 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a_8_ps, scale_b_8_ps), acc0); +#if !defined(__GNUC__) || (__GNUC__ > 9) + } +#endif +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmR2xC4BlkLen64Avx2( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 64; + + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + //const size_t StrideQuantBScale = BlockCountK; + + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc[NCols4 * NRows2] = { + _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), + _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps() + }; + + // process 1 blks of 64 4b weights a time + for (size_t k = 0; k < BlockCountK; ++k) { + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 
lda)); + const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda + 32)); + + accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); + accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + SubblkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 1, acc[1], acc[NCols4 + 1]); + accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * SubblkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2, acc[2], acc[NCols4 + 2]); + accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * SubblkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3, acc[3], acc[NCols4 + 3]); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += NCols4 * SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr += NCols4; + } // k_blks_remaining + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + __m128 acc_r1 = FoldAccumulators(acc[NCols4 + 0], acc[NCols4 + 1], acc[NCols4 + 2], acc[NCols4 + 3]); + if (BiasPtr != nullptr) { + const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); + acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); + acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); + } + _mm_storeu_ps(SumPtr, acc_r0); + _mm_storeu_ps(SumPtr + ldc, acc_r1); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * BlockCountK; + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + +template +void MLAS_FORCEINLINE +Q4Int8GemmR2xC1BlkLen64Avx2( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 64; + + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM % NRows2 == 0); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + float* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc0 = _mm256_setzero_ps(), acc1 = _mm256_setzero_ps(); + + for (size_t k = 0; k < BlockCountK; ++k) { + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); 
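+ // Each iteration consumes one SubblkLen (64) of int8 activations per row: two 32-byte
+ // loads for row m and two for row m + 1, which starts lda (= BlockCountK * BlkLen) bytes
+ // further into the quantized A buffer.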
+ const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda + 32)); + + accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_8(acc0); + *(SumPtr + ldc) = hsum_float_8(acc1); + if (BiasPtr) { + *SumPtr += *BiasPtr; + *(SumPtr + ldc) += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmR1xC4BlkLen64Avx2( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 64; + + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + //const size_t StrideQuantBScale = BlockCountK; + + assert(CountM < NRows2); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc[NCols4] = {_mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps()}; + for (size_t k = 0; k < BlockCountK; ++k) { + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); + accumulate_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + SubblkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 1, acc[1]); + accumulate_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * SubblkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 2, acc[2]); + accumulate_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * SubblkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 3, acc[3]); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += NCols4 * SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr += NCols4; + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + + _mm_storeu_ps(SumPtr, acc_r0); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * BlockCountK; 
+ BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmR1xC1BlkLen64Avx2( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 64; + + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM < NRows2); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc0 = _mm256_setzero_ps(); + for (size_t k = 0; k < BlockCountK; ++k) { + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); + + accumulate_blklen64_r1c1blk1_avx2( + av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0 + ); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_8(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } + } +} + +template +MLAS_FORCEINLINE size_t +MlasQ4Int8GemmKernelBlkLen64Avx2( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + + const size_t lda = BlockCountK * BlkLen * sizeof(int8_t); + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleRows > 0 && multipleCols > 0) { + Q4Int8GemmR2xC4BlkLen64Avx2( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + ldc + ); + } + if (remainingCols > 0 && multipleRows > 0) { + Q4Int8GemmR2xC1BlkLen64Avx2( + BlkLen, + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + if (remainingRows > 0 && multipleCols > 0) { + Q4Int8GemmR1xC4BlkLen64Avx2( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols, + BlockCountK, + Bias, + ldc); + } + + if (remainingCols > 0 && remainingRows > 0) { + Q4Int8GemmR1xC1BlkLen64Avx2( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleRows * ldc + multipleCols, + remainingRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + return CountM; +} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp index b868906760701..13bd369a065bb 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp @@ -22,6 +22,10 @@ Module Name: #include "sqnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" #include "sqnbitgemm_kernel_avx_common_int8.h" +#include "sqnbitgemm_kernel_avx512_int8_blklen16.h" +#include "sqnbitgemm_kernel_avx512_int8_blklen32.h" +#include "sqnbitgemm_kernel_avx512_int8_blklen64.h" +#include "sqnbitgemm_kernel_avx512_int8_blklen128.h" // // CompFp32 kernel implementation. @@ -150,18 +154,115 @@ SQ4BitGemmM1Kernel_CompFp32_avx512( // CompInt8 kernel implementation. 
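// The BlkSum kernel below computes each output in two pieces. With a_i the quantized
// int8 activation codes (per-block scale scale_a) and b_i the raw 4-bit weight codes
// (per-block scale scale_b and zero point zp_b), one output element decomposes as
//
//   C[m][n] = Sum_k( scale_a[m][k] * scale_b[k][n] * Sum_i( a_i * b_i ) )
//           + Sum_k( ABlockSum[m][k] * QuantBBlkSum[k][n] )
//
// where ABlockSum[m][k] = scale_a[m][k] * Sum_i(a_i) is produced by
// QuantizeARow_CompInt8_avx512 below, and QuantBBlkSum is expected to carry the
// -scale_b * zp_b contribution (precomputed during prepacking, not shown here).
// The BlkLen-specific int8 kernels accumulate the first term; the trailing
// GemmFloatKernel call adds the second term on top of C (alpha = 1.0f,
// ZeroMode = false, i.e. accumulate).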
// +MLAS_FORCEINLINE +size_t +SQ4BitGemmKernel_BlkSum_CompInt8_avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* /*QuantBZeroPoint*/, + float* C, + size_t CountM, + size_t CountN, + size_t /*CountK*/, + size_t BlockCountK, + const float* Bias, + size_t ldc, + const float* ABlockSum, + const float* QuantBBlkSum +) +{ + if (BlkLen == 16) { + MlasQ4Int8GemmKernelBlkLen16Avx512( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } else if (BlkLen == 32) { + MlasQ4Int8GemmKernelBlkLen32Avx512( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } else if (BlkLen == 64) { + MlasQ4Int8GemmKernelBlkLen64Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } else { + MlasQ4Int8GemmKernelBlkLen128Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } + + float* c_blk = C; + const float* b_blk_sum = QuantBBlkSum; + + size_t RowsRemaining = CountM; + const float* a_blksum_row = ABlockSum; + while (RowsRemaining > 0) { + auto RowsHandled = GetMlasPlatform().GemmFloatKernel( + a_blksum_row, b_blk_sum, c_blk, BlockCountK, RowsRemaining, CountN, BlockCountK, ldc, 1.f, false + ); + + c_blk += ldc * RowsHandled; + a_blksum_row += BlockCountK * RowsHandled; + RowsRemaining -= RowsHandled; + } + return CountM; +} + void MLASCALL -MlasQ80BlkQuantRow_avx512( +QuantizeARow_CompInt8_avx512( size_t BlkLen, const float* A, size_t CountK, - std::byte* QuantA + std::byte* QuantA, + float* QuantAScale, + float* AScaledBlkSum // scale_k * Sum_blklen(a_i) ) { // port from MlasQ80BlkQuantRow assert(BlkLen % 16 == 0); const __m512 signBit = _mm512_set1_ps(-0.0f); + const __m256i one_16_epi16 = _mm256_set1_epi16(1); int8_t* blob = reinterpret_cast(QuantA); + float* scale_ptr = QuantAScale; for (size_t k = 0; k < CountK; k += BlkLen) { const size_t step = std::min(BlkLen, CountK - k); @@ -185,13 +286,14 @@ MlasQ80BlkQuantRow_avx512( // Quantize these floats const float scale = maxScalar / 127.f; - *reinterpret_cast(blob) = scale; - blob += sizeof(float); + *scale_ptr = scale; + scale_ptr++; const float inverse_scale = (maxScalar != 0.0f) ? 
127.f / maxScalar : 0.0f; const __m512 mul = _mm512_set1_ps(inverse_scale); __m128i* dst = reinterpret_cast<__m128i*>(blob); + __m256i sum_16_epi16 = _mm256_setzero_si256(); for (size_t kk = 0; kk < step; kk += 16) { const size_t klen = std::min(size_t(16), step - kk); @@ -208,23 +310,46 @@ MlasQ80BlkQuantRow_avx512( // Convert int32 to int8 __m128i i0_8 = _mm512_cvtepi32_epi8(i0); _mm_storeu_si128(dst++, i0_8); + + // accumulate Sum(a_i) + __m256i i_16_epi16 = _mm256_cvtepi8_epi16(i0_8); + sum_16_epi16 = _mm256_hadds_epi16(sum_16_epi16, i_16_epi16); + } if (step < BlkLen) { memset(blob + step, 0, BlkLen - step); } + + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + *AScaledBlkSum = scale * hsum_8_epi32(sum_8_epi32); + AScaledBlkSum++; blob += BlkLen; } } -void MLASCALL -QuantizeARow_CompInt8_avx512( +static void +SQ4BitGemmPackQuantBDataAndBlkSum512( + size_t N, + size_t K, size_t BlkLen, - const float* A, - size_t CountK, - std::byte* QuantA + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + const float* QuantBScaleBegin, + bool has_zp_input, + const std::byte* QuantBZPBegin, + PackedQuantBDataStruct& packed_quant_b, + MLAS_THREADPOOL* ThreadPool ) { - MlasQ80BlkQuantRow_avx512(BlkLen, A, CountK, QuantA); + assert(BlkLen >= 16 && BlkLen % 16 == 0); + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + + size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 32 : 64); + if (ComputeType == CompInt8) { + SubBlkLen = 128; + } + PackQuantBDataAndBlkSum(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, QuantBScaleBegin, has_zp_input, QuantBZPBegin, packed_quant_b, ThreadPool); } const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() { @@ -232,6 +357,7 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() { d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; + d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum512; d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment; @@ -239,8 +365,8 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() { d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx512; d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; - d.SQ4BitGemmKernel_CompInt8 = SQ4BitGemmKernel_CompInt8_avx2; - d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8_avx512; + d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx512; + d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx512; return d; }(); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8.h new file mode 100644 index 0000000000000..7d9dc36854621 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8.h @@ -0,0 +1,1171 @@ +#pragma once +#include +#include +#include + +#include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_avx_common.h" + + +MLAS_FORCEINLINE void +accumulate_1blk_dot(const __m256i& av_32_epi8, const __m256i& bv_32_epi8, + const float& combined_scale, const __m256i& one_16_epi16, __m256& acc) +{ + const __m256i dot_16_epi16 = _mm256_maddubs_epi16( + _mm256_sign_epi8(bv_32_epi8, bv_32_epi8), _mm256_sign_epi8(av_32_epi8, bv_32_epi8) + ); + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, dot_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + 
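+ // _mm256_maddubs_epi16 needs an unsigned first operand: taking |b| via
+ // _mm256_sign_epi8(b, b) and moving b's sign onto a via _mm256_sign_epi8(a, b)
+ // keeps every product a_i * b_i intact; the madd against one_16_epi16 then widens
+ // the pairwise 16-bit sums to int32 before the scaled accumulate below.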
acc = _mm256_fmadd_ps(sum_ps, _mm256_set1_ps(combined_scale), acc); +} + +MLAS_FORCEINLINE void +accumulate_2blk_dot( + const __m256i& av0_32_epi8, const __m256i& av1_32_epi8, + const __m256i& bv0_32_epi8, const __m256i& bv1_32_epi8, + const float& combined_scale0, const float& combined_scale1, + const __m256i& one_16_epi16, + __m256& acc) +{ + const __m256i dot0_16_epi16 = _mm256_maddubs_epi16( + _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8) + ); + const __m256i dot1_16_epi16 = _mm256_maddubs_epi16( + _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8) + ); + const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + const __m256 scale_8_ps = _mm256_set_ps( + combined_scale1, combined_scale1, combined_scale0, combined_scale0, + combined_scale1, combined_scale1, combined_scale0, combined_scale0 + ); + acc = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc); +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen32_r2c1blk2_avx2( + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const __m256i& av10_32_epi8, + const __m256i& av11_32_epi8, + const std::byte* QuantBDataPtr, + const std::byte* QuantBZeroPointPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m256& acc0, + __m256& acc1 +) +{ + // | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | v32 v48 | v33 v49 | ... | v46 v62 | v47 v63 | + const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + const __m256i low_mask = _mm256_set1_epi8(0x0F); + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...30, 31 + // TODO: will this (the second line below) be faster and not keep low_mask in use? 
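+ // The high nibbles are recovered below by subtracting the already-extracted low
+ // nibbles and shifting each 16-bit lane right by 4: after the subtraction every
+ // byte's low 4 bits are zero, so no bits leak across byte boundaries and a second
+ // AND with low_mask is avoided.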
+ __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32, 33,...62, 63
+
+ int8_t zp0, zp1;
+ get_2_zps<HasZeroPoint>(QuantBZeroPointPtr, zp0, zp1);
+ bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, _mm256_set1_epi8(zp0));
+ bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, _mm256_set1_epi8(zp1));
+
+ //accumulate_2blk_dot(av00_32_epi8, av01_32_epi8, bv0_32_epi8, bv1_32_epi8, combined_scale00, combined_scale01, one_16_epi16, acc0);
+ //accumulate_2blk_dot(av10_32_epi8, av11_32_epi8, bv0_32_epi8, bv1_32_epi8, combined_scale10, combined_scale11, one_16_epi16, acc1);
+ const __m256i dot0_16_epi16 = _mm256_maddubs_epi16(
+ _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av00_32_epi8, bv0_32_epi8)
+ );
+ const __m256i dot1_16_epi16 = _mm256_maddubs_epi16(
+ _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av01_32_epi8, bv1_32_epi8)
+ );
+ const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16);
+
+ __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15);
+ const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16);
+ const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32);
+
+ __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a0));
+ __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b));
+ // 1 0 1 0 1 0 1 0 -> 1 1 0 0 1 1 0 0
+ __m256 scale_8_ps = _mm256_mul_ps(
+ _mm256_permute_ps(scale_a0_2_ps, _MM_SHUFFLE(1, 1, 0, 0)),
+ _mm256_permute_ps(scale_b_2_ps, _MM_SHUFFLE(1, 1, 0, 0)));
+
+ acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0);
+
+ const __m256i dot0_16_epi16_ = _mm256_maddubs_epi16(
+ _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av10_32_epi8, bv0_32_epi8)
+ );
+ const __m256i dot1_16_epi16_ = _mm256_maddubs_epi16(
+ _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av11_32_epi8, bv1_32_epi8)
+ );
+ const __m256i sum_16_epi16_ = _mm256_hadd_epi16(dot0_16_epi16_, dot1_16_epi16_);
+ const __m256i sum_8_epi32_ = _mm256_madd_epi16(one_16_epi16, sum_16_epi16_);
+ const __m256 sum_ps_ = _mm256_cvtepi32_ps(sum_8_epi32_);
+
+ __m256 scale_a1_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a1));
+ __m256 scale_8_ps_ = _mm256_mul_ps(
+ _mm256_permute_ps(scale_a1_2_ps, _MM_SHUFFLE(1, 1, 0, 0)),
+ _mm256_permute_ps(scale_b_2_ps, _MM_SHUFFLE(1, 1, 0, 0)));
+ acc1 = _mm256_fmadd_ps(sum_ps_, scale_8_ps_, acc1);
+}
+
+template <bool HasZeroPoint>
+static MLAS_FORCEINLINE void
+accumulate_blklen32_r2c1blk2_avx2(
+ const __m256i& av00_32_epi8,
+ const __m256i& av01_32_epi8,
+ const __m256i& av10_32_epi8,
+ const __m256i& av11_32_epi8,
+ const std::byte* QuantBDataPtr,
+ const std::byte* QuantBZeroPointPtr,
+ const float& combined_scale00,
+ const float& combined_scale01,
+ const float& combined_scale10,
+ const float& combined_scale11,
+ __m256& acc0,
+ __m256& acc1
+)
+{
+ // | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | v32 v48 | v33 v49 | ... | v46 v62 | v47 v63 |
+ const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(QuantBDataPtr));
+
+ // generating low_mask of 0x0Fs is not as fast as just calling _mm256_set1_epi8(0x0F).
+ // however, it is faster to generate one_16_epi16 than calling _mm256_set1_ep16(1); + const __m256i low_mask = _mm256_set1_epi8(0x0F); + //__m256i low_mask = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_packed, bv_packed), 12); + //low_mask = _mm256_packus_epi16(low_mask, low_mask); + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...14, 15, 32, 33,...46, 47 + // TODO: will this (the second line below) be faster and not keep low_mask in use? + // const __m256i bv1 = _mm256_and_si256(_mm256_srli_epi16(bv_packed, 4), low_mask); // 16, 17,...30, 31, 48, 49,...,62, 63 + __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 16, 17,...30, 31, 48, 49,...,62, 63 + + //__m256i bv0_32_epi8 = _mm256_set_m128i(_mm256_castsi256_si128(bv1), _mm256_castsi256_si128(bv0)); + + //// This (the second line below) saves one _mm256_extracti128_si256 against using _mm256_set_m128i. + ////__m256i bv1_32_epi8 = _mm256_set_m128i(_mm256_extracti128_si256(bv1, 1), _mm256_extracti128_si256(bv0, 1)); + //__m256i bv1_32_epi8 = _mm256_insertf128_si256(bv1, _mm256_extracti128_si256(bv0, 1), 0); + + int8_t zp0, zp1; + get_2_zps(QuantBZeroPointPtr, zp0, zp1); + bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, _mm256_set1_epi8(zp0)); + bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, _mm256_set1_epi8(zp1)); + + // generating constant 1s is fater here. + // __m256i one = _mm256_set1_epi16(1); + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); + + // performance gains 7% by calling this (accumulate_2blk_dot) instead of 2 accumulate_1blk_dot calls. + // accumulate_1blk_dot(av00_32_epi8, bv0_32_epi8, combined_scale00, one_16_epi16, acc0); + // accumulate_1blk_dot(av01_32_epi8, bv1_32_epi8, combined_scale01, one_16_epi16, acc0); + // accumulate_1blk_dot(av10_32_epi8, bv0_32_epi8, combined_scale10, one_16_epi16, acc1); + // accumulate_1blk_dot(av11_32_epi8, bv1_32_epi8, combined_scale11, one_16_epi16, acc1); + accumulate_2blk_dot(av00_32_epi8, av01_32_epi8, bv0_32_epi8, bv1_32_epi8, combined_scale00, combined_scale01, one_16_epi16, acc0); + accumulate_2blk_dot(av10_32_epi8, av11_32_epi8, bv0_32_epi8, bv1_32_epi8, combined_scale10, combined_scale11, one_16_epi16, acc1); +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen32_r2c1blk1_avx2( + const __m256i& av00_32_epi8, + const __m256i& av10_32_epi8, + const std::byte* QuantBDataPtr, + const std::byte* QuantBZeroPointPtr, + const float& combined_scale00, + const float& combined_scale10, + __m256& acc0, + __m256& acc1 +) +{ + // | v0 v16 | v1 v17 | ... 
| v14 v30 | v15 v31 | + const __m128i bv_packed0 = _mm_loadu_si128(reinterpret_cast(QuantBDataPtr)); + __m256i bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); + bv_32_epi8 = _mm256_and_si256(_mm256_set1_epi8(0x0F), bv_32_epi8); + + const int8_t zp = get_zp(true, QuantBZeroPointPtr); + const __m256i bzp = _mm256_set1_epi8(zp); + bv_32_epi8 = _mm256_sub_epi8(bv_32_epi8, bzp); + + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_32_epi8, bv_32_epi8), 15); + accumulate_1blk_dot(av00_32_epi8, bv_32_epi8, combined_scale00, one_16_epi16, acc0); + accumulate_1blk_dot(av10_32_epi8, bv_32_epi8, combined_scale10, one_16_epi16, acc1); +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen32_r1c1blk2_avx2( + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const std::byte* QuantBDataPtr, + const std::byte* QuantBZeroPointPtr, + const float& combined_scale00, + const float& combined_scale01, + __m256& acc0) +{ + // | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | v32 v48 | v33 v49 | ... | v46 v62 | v47 v63 | + const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + + const __m256i low_mask = _mm256_set1_epi8(0x0F); + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...14, 15, 32, 33,...46, 47 + // TODO: will this be faster and save a use of low_mask? + // const __m256i bv1 = _mm256_and_si256(_mm256_srli_epi16(bv_packed, 4), low_mask); // 16, 17,...30, 31, 48, 49,...,62, 63 + __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 16, 17,...30, 31, 48, 49,...,62, 63 + + //__m256i bv0_32_epi8 = _mm256_set_m128i(_mm256_castsi256_si128(bv1), _mm256_castsi256_si128(bv0)); + + //// This saves one _mm256_extracti128_si256 against using _mm256_set_m128i. + ////__m256i bv1_32_epi8 = _mm256_set_m128i(_mm256_extracti128_si256(bv1, 1), _mm256_extracti128_si256(bv0, 1)); + //__m256i bv1_32_epi8 = _mm256_insertf128_si256(bv1, _mm256_extracti128_si256(bv0, 1), 0); + + int8_t zp0, zp1; + get_2_zps(QuantBZeroPointPtr, zp0, zp1); + bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, _mm256_set1_epi8(zp0)); + bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, _mm256_set1_epi8(zp1)); + + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); + //accumulate_1blk_dot(av00_32_epi8, bv0_32_epi8, combined_scale00, one_16_epi16, acc0); + //accumulate_1blk_dot(av01_32_epi8, bv1_32_epi8, combined_scale01, one_16_epi16, acc0); + accumulate_2blk_dot(av00_32_epi8, av01_32_epi8, bv0_32_epi8, bv1_32_epi8, combined_scale00, combined_scale01, one_16_epi16, acc0); +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen32_r1c1blk1_avx2( + const __m256i& av00_32_epi8, + const std::byte* QuantBDataPtr, + const std::byte* QuantBZeroPointPtr, + const float& combined_scale00, + __m256& acc0 +) +{ + // | v0 v16 | v1 v17 | ... 
| v14 v30 | v15 v31 | + const __m128i bv_packed0 = _mm_loadu_si128(reinterpret_cast(QuantBDataPtr)); + __m256i bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); + bv_32_epi8 = _mm256_and_si256(_mm256_set1_epi8(0x0F), bv_32_epi8); + + const int8_t zp = get_zp(true, QuantBZeroPointPtr); + const __m256i bzp = _mm256_set1_epi8(zp); + bv_32_epi8 = _mm256_sub_epi8(bv_32_epi8, bzp); + + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_32_epi8, bv_32_epi8), 15); + accumulate_1blk_dot(av00_32_epi8, bv_32_epi8, combined_scale00, one_16_epi16, acc0); +} + +template +MLAS_FORCEINLINE void +Q4Int8Gemm2x4BlkLen32Avx2( + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t lda, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + constexpr size_t Q8Blk32Size = Q8BlkSize(BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc[NCols4 * NRows2] = { + _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), + _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps() + }; + + size_t k_blks_remaining = BlockCountK; + + // process 2 blks of 64 4b weights a time + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const std::byte* QuantABlk00 = QuantAPtr; + const std::byte* QuantABlk01 = QuantABlk00 + Q8Blk32Size; + const std::byte* QuantABlk10 = QuantAPtr + lda; + const std::byte* QuantABlk11 = QuantABlk10 + Q8Blk32Size; + + // load A: + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk00)); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk01)); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk10)); + const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk11)); + + const float& scale_a00 = Q8BlkScale(QuantABlk00); + const float& scale_a01 = Q8BlkScale(QuantABlk01); + const float& scale_a10 = Q8BlkScale(QuantABlk10); + const float& scale_a11 = Q8BlkScale(QuantABlk11); + + { + // Col0 + const float& scale_00 = scale_a00 * QuantBScalePtr[0]; + const float& scale_01 = scale_a01 * QuantBScalePtr[1]; + const float& scale_10 = scale_a10 * 
QuantBScalePtr[0]; + const float& scale_11 = scale_a11 * QuantBScalePtr[1]; + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_01, scale_10, scale_11, acc[0], acc[NCols4]); + } + + { + // Col1 + const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; + const float& scale_01 = scale_a01 * (QuantBScalePtr + StrideQuantBScale)[1]; + const float& scale_10 = scale_a10 * (QuantBScalePtr + StrideQuantBScale)[0]; + const float& scale_11 = scale_a11 * (QuantBScalePtr + StrideQuantBScale)[1]; + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, QuantBZeroPointPtr + StrideQuantBZeroPoint, scale_00, scale_01, scale_10, scale_11, acc[1], acc[NCols4 + 1]); + } + + { + // Col2 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + const float& scale_01 = scale_a01 * (QuantBScalePtr + 2 * StrideQuantBScale)[1]; + const float& scale_10 = scale_a10 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + const float& scale_11 = scale_a11 * (QuantBScalePtr + 2 * StrideQuantBScale)[1]; + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, scale_00, scale_01, scale_10, scale_11, acc[2], acc[NCols4 + 2]); + } + + { + // Col3 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + const float& scale_01 = scale_a01 * (QuantBScalePtr + 3 * StrideQuantBScale)[1]; + const float& scale_10 = scale_a10 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + const float& scale_11 = scale_a11 * (QuantBScalePtr + 3 * StrideQuantBScale)[1]; + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, scale_00, scale_01, scale_10, scale_11, acc[3], acc[NCols4 + 3]); + } + + // increment block pointers + QuantAPtr += Q8Blk32Size * PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } // k_blks_remaining + + // TODO: use a loop in case PerAccuBlk2 is not 2. 
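+ // With PerAccuBlk2 == 2 the pairwise loop above leaves at most one 32-value block,
+ // which the single-block (r2c1blk1) tail below handles for each of the four columns.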
+ if (k_blks_remaining > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0 + lda)); + + const float& scale_a00 = Q8BlkScale(QuantABlk0); + const float& scale_a10 = Q8BlkScale(QuantABlk0 + lda); + + { + // Col0 + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_10, acc[0], acc[NCols4]); + } + + { + // Col1 + const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr + StrideQuantBScale)[0]; + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + StrideQuantBData, QuantBZeroPointPtr + StrideQuantBZeroPoint, scale_00, scale_10, acc[1], acc[NCols4 + 1]); + } + + { + // Col2 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, scale_00, scale_10, acc[2], acc[NCols4 + 2]); + } + + { + // Col3 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, scale_00, scale_10, acc[3], acc[NCols4 + 3]); + } + } // k_blks_remaining + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + __m128 acc_r1 = FoldAccumulators(acc[NCols4 + 0], acc[NCols4 + 1], acc[NCols4 + 2], acc[NCols4 + 3]); + if (BiasPtr != nullptr) { + const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); + acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); + acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); + } + _mm_storeu_ps(SumPtr, acc_r0); + _mm_storeu_ps(SumPtr + ldc, acc_r1); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += NCols4 * StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + } + } +} + +template +void MLAS_FORCEINLINE Q4Int8Gemm2xXBlkLen32Avx2( + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t lda, + size_t ldc) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM % NRows2 == 0); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + const float* BiasPtr = Bias; + float* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + // accumulate_blklen32_r2c1_avx2 + const std::byte* QuantAPtr = QuantA + m * lda; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc0 = _mm256_setzero_ps(), acc1 = _mm256_setzero_ps(); + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const std::byte* QuantABlk00 = QuantAPtr; + const std::byte* QuantABlk01 = QuantABlk00 + Q8BlkSize(BlkLen32); + const std::byte* QuantABlk10 = QuantAPtr + lda; + const std::byte* QuantABlk11 = QuantABlk10 + Q8BlkSize(BlkLen32); + + // load A: + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk00)); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk01)); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk10)); + const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk11)); + + const float& scale_a00 = Q8BlkScale(QuantABlk00); + const float& scale_a01 = Q8BlkScale(QuantABlk01); + const float& scale_a10 = Q8BlkScale(QuantABlk10); + const float& scale_a11 = Q8BlkScale(QuantABlk11); + + const float& scale_00 = scale_a00 * QuantBScalePtr[0]; + const float& scale_01 = scale_a01 * QuantBScalePtr[1]; + const float& scale_10 = scale_a10 * QuantBScalePtr[0]; + const float& scale_11 = scale_a11 * QuantBScalePtr[1]; + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_01, scale_10, scale_11, acc0, acc1); + + // increment block pointers + QuantAPtr += Q8BlkSize(BlkLen32) * PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } + + // TODO: use a loop in case PerAccuBlk2 is not 2. 
+ if (k_blks_remaining > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0 + lda)); + + const float& scale_a00 = Q8BlkScale(QuantABlk0); + const float& scale_a10 = Q8BlkScale(QuantABlk0 + lda); + + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_10, acc0, acc1); + } + + *SumPtr = hsum_float_8(acc0); + *(SumPtr + ldc) = hsum_float_8(acc1); + if (BiasPtr) { + *SumPtr += *BiasPtr; + *(SumPtr + ldc) += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmXx4BlkLen32Avx2( + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t lda, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM < NRows2); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc[NCols4] = {_mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps()}; + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const std::byte* QuantABlk00 = QuantAPtr; + const std::byte* QuantABlk01 = QuantABlk00 + Q8BlkSize(BlkLen32); + + // load A: + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk00)); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk01)); + + const float& scale_a00 = Q8BlkScale(QuantABlk00); + const float& scale_a01 = Q8BlkScale(QuantABlk01); + { + // Col0 + const float& scale_00 = scale_a00 * QuantBScalePtr[0]; + const float& scale_01 = scale_a01 * QuantBScalePtr[1]; + accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_01, acc[0]); + } + { + // Col1 + const float& scale_00 = 
scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; + const float& scale_01 = scale_a01 * (QuantBScalePtr + StrideQuantBScale)[1]; + accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData, QuantBZeroPointPtr + 1 * StrideQuantBZeroPoint, scale_00, scale_01, acc[1]); + } + { + // Col2 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + const float& scale_01 = scale_a01 * (QuantBScalePtr + 2 * StrideQuantBScale)[1]; + accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, scale_00, scale_01, acc[2]); + } + { + // Col3 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + const float& scale_01 = scale_a01 * (QuantBScalePtr + 3 * StrideQuantBScale)[1]; + accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, scale_00, scale_01, acc[3]); + } + // increment block pointers + QuantAPtr += Q8BlkSize(BlkLen32) * PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } + + // TODO: use a loop in case PerAccuBlk2 is not 2. + if (k_blks_remaining > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + + const float& scale_a00 = Q8BlkScale(QuantABlk0); + { + // Col0 + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, acc[0]); + } + { + // Col1 + const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + StrideQuantBData, QuantBZeroPointPtr + StrideQuantBZeroPoint, scale_00, acc[1]); + } + { + // Col2 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, scale_00, acc[2]); + } + { + // Col3 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, scale_00, acc[3]); + } + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + _mm_storeu_ps(SumPtr, acc_r0); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += NCols4 * StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + } + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmXxXBlkLen32Avx2( + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t lda, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM < NRows2); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc0 = _mm256_setzero_ps(); + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const std::byte* QuantABlk00 = QuantAPtr; + const std::byte* QuantABlk01 = QuantABlk00 + Q8BlkSize(BlkLen32); + + // load A: + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk00)); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk01)); + + const float& scale_a00 = Q8BlkScale(QuantABlk00); + const float& scale_a01 = Q8BlkScale(QuantABlk01); + + const float& scale_00 = scale_a00 * QuantBScalePtr[0]; + const float& scale_01 = scale_a01 * QuantBScalePtr[1]; + accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_01, acc0); + + // increment block pointers + QuantAPtr += Q8BlkSize(BlkLen32) * PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } + + // TODO: use a loop in case PerAccuBlk2 is not 2. + if (k_blks_remaining > 0) { + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + + const float& scale_a00 = Q8BlkScale(QuantABlk0); + + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, acc0); + } + + *SumPtr = hsum_float_8(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } + } +} + +template +MLAS_FORCEINLINE + size_t + MlasQ4Int8TileGemmKernelBlkLen32Avx2( + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t /*CountK*/, + size_t BlockCountK, + const float* Bias, + size_t lda, + size_t ldc + ) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleRows > 0 && multipleCols > 0) { + Q4Int8Gemm2x4BlkLen32Avx2( + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + lda, + ldc + ); + } + if (remainingCols > 0 && multipleRows > 0) { + Q4Int8Gemm2xXBlkLen32Avx2( + QuantA, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + QuantBZeroPoint + multipleCols * StrideQuantBZeroPoint, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + lda, + ldc); + } + + if (remainingRows > 0 && multipleCols > 0) { + Q4Int8GemmXx4BlkLen32Avx2( + QuantA + multipleRows * lda, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C + multipleRows * ldc, + remainingRows, + multipleCols, + BlockCountK, + Bias, + lda, + ldc); + } + + if (remainingCols > 0 && remainingRows > 0) { + Q4Int8GemmXxXBlkLen32Avx2( + QuantA + multipleRows * lda, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + QuantBZeroPoint + multipleCols * StrideQuantBZeroPoint, + C + multipleRows * ldc + multipleCols, + remainingRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + lda, + ldc); + } + + return CountM; +} + +// this function is to explore larger NCols. With Avx2 it does not improve performance. +// Leave it here until the same is implemented in avx512. +template accumulator> +MLAS_FORCEINLINE +size_t +MlasQ4Int8GemmKernelBlkLen32Avx2( + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t /*CountK*/, + size_t BlockCountK, + const float* Bias, + size_t lda, + size_t ldc +) +{ + // We process 32 quantized values in a batch. 
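+ // Per 32-value block each accumulator call folds in one combined scale
+ // (scale_a * scale_b), so acc[c] ends up holding
+ // Sum_blk( scale_a * scale_b * Sum_i( a_i * (b_i - zp) ) ) for column c; the zero
+ // point is handled inside the accumulator passed as a template parameter.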
+ constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + const __m256i zero = _mm256_setzero_si256(); + const __m128i low_mask = _mm_set1_epi8(0xF); + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + + for (size_t m = 0; m < CountM; m++) { + // for each row of A, reset B pointers + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + int64_t nblk = (int64_t)(CountN)-NCols4; + while (nblk >= 0) { + const std::byte* QuantAPtr = QuantA + m * lda; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc[NCols4]; + + acc[0] = _mm256_setzero_ps(); + acc[1] = _mm256_setzero_ps(); + acc[2] = _mm256_setzero_ps(); + acc[3] = _mm256_setzero_ps(); + + if constexpr (NCols4 == 8) { + acc[4] = _mm256_setzero_ps(); + acc[5] = _mm256_setzero_ps(); + acc[6] = _mm256_setzero_ps(); + acc[7] = _mm256_setzero_ps(); + } + + size_t k_blks_remaining = BlockCountK; + + // process 2 blks of 64 4b weights a time + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const std::byte* QuantABlk0 = QuantAPtr; + const std::byte* QuantABlk1 = QuantABlk0 + Q8BlkSize(BlkLen32); + + // load A: + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + const __m256i av_1_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk1)); + + const float& scale_a0 = Q8BlkScale(QuantABlk0); + const float& scale_a1 = Q8BlkScale(QuantABlk1); + + // Col0 + const float& scale_00 = scale_a0 * QuantBScalePtr[0]; + const float& scale_01 = scale_a1 * QuantBScalePtr[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc[0]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 16), low_mask, zero, QuantBZeroPointPtr, false, scale_01, acc[0]); + + // Col1 + const float& scale_10 = scale_a0 * (QuantBScalePtr + StrideQuantBScale)[0]; + const float& scale_11 = scale_a1 * (QuantBScalePtr + StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, true, scale_10, acc[1]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, false, scale_11, acc[1]); + + // Col2 + const float& scale_20 = scale_a0 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + const float& scale_21 = scale_a1 * (QuantBScalePtr + 2 * StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 2 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, true, scale_20, acc[2]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 2 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 2 * 
StrideQuantBZeroPoint, false, scale_21, acc[2]); + + // Col3 + const float& scale_30 = scale_a0 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + const float& scale_31 = scale_a1 * (QuantBScalePtr + 3 * StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, true, scale_30, acc[3]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, false, scale_31, acc[3]); + + if constexpr (NCols4 == 8) { + // Col4 + const float& scale_40 = scale_a0 * (QuantBScalePtr + 4 * StrideQuantBScale)[0]; + const float& scale_41 = scale_a1 * (QuantBScalePtr + 4 * StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 4 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr, true, scale_40, acc[4]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 4 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr, false, scale_41, acc[4]); + + // Col5 + const float& scale_50 = scale_a0 * (QuantBScalePtr + 5 * StrideQuantBScale)[0]; + const float& scale_51 = scale_a1 * (QuantBScalePtr + 5 * StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 5 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, true, scale_50, acc[5]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 5 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, false, scale_51, acc[5]); + + // Col6 + const float& scale_60 = scale_a0 * (QuantBScalePtr + 6 * StrideQuantBScale)[0]; + const float& scale_61 = scale_a1 * (QuantBScalePtr + 6 * StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 6 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 6 * StrideQuantBZeroPoint, true, scale_60, acc[6]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 6 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 6 * StrideQuantBZeroPoint, false, scale_61, acc[6]); + + // Col7 + const float& scale_70 = scale_a0 * (QuantBScalePtr + 7 * StrideQuantBScale)[0]; + const float& scale_71 = scale_a1 * (QuantBScalePtr + 7 * StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 7 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 7 * StrideQuantBZeroPoint, true, scale_70, acc[7]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 7 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 7 * StrideQuantBZeroPoint, false, scale_71, acc[7]); + } + + // increment block pointers + QuantAPtr += Q8BlkSize(BlkLen32) * PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } // k_blks_remaining + + // TODO: use a loop in case PerAccuBlk2 is not 2. 
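+            // BlockCountK may be odd, so a single trailing block can remain at this
+            // point (e.g. BlockCountK == 5 with PerAccuBlk2 == 2 leaves one block).
+            // For that block only av_0 is loaded and one accumulator call is issued
+            // per column.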
+ if (k_blks_remaining > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + + const float& scale_a0 = Q8BlkScale(QuantABlk0); + + // Col0 + const float& scale_00 = scale_a0 * QuantBScalePtr[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc[0]); + + // Col1 + const float& scale_10 = scale_a0 * (QuantBScalePtr + StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, true, scale_10, acc[1]); + + // Col2 + const float& scale_20 = scale_a0 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 2 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, true, scale_20, acc[2]); + + // Col3 + const float& scale_30 = scale_a0 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, true, scale_30, acc[3]); + + if constexpr (NCols4 == 8) { + // Col4 + const float& scale_40 = scale_a0 * (QuantBScalePtr + 4 * StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 4 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 4 * StrideQuantBZeroPoint, true, scale_40, acc[4]); + + // Col5 + const float& scale_50 = scale_a0 * (QuantBScalePtr + 5 * StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 5 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 5 * StrideQuantBZeroPoint, true, scale_50, acc[5]); + + // Col6 + const float& scale_60 = scale_a0 * (QuantBScalePtr + 6 * StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 6 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 6 * StrideQuantBZeroPoint, true, scale_60, acc[6]); + + // Col7 + const float& scale_70 = scale_a0 * (QuantBScalePtr + 7 * StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 7 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 7 * StrideQuantBZeroPoint, true, scale_70, acc[7]); + } + } // k_blks_remaining + + if constexpr (NCols4 == 8) { + __m128 acc_0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + __m128 acc_1 = FoldAccumulators(acc[4], acc[5], acc[6], acc[7]); + if (BiasPtr != nullptr) { + acc_0 = _mm_add_ps(acc_0, _mm_loadu_ps(BiasPtr)); + acc_1 = _mm_add_ps(acc_1, _mm_loadu_ps(BiasPtr + 4)); + } + _mm_storeu_ps(SumPtr, acc_0); + _mm_storeu_ps(SumPtr+4, acc_1); + } else { + __m128 acc_x = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_x = _mm_add_ps(acc_x, _mm_loadu_ps(BiasPtr)); + } + _mm_storeu_ps(SumPtr, acc_x); + } + + // move to next NCols columns + + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += NCols4 * StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + nblk -= NCols4; + } // while (nblk >= 0) + + nblk += NCols4; + for (int64_t n = 0; n < nblk; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc0 = _mm256_setzero_ps(); + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const std::byte* QuantABlk0 = QuantAPtr; + const std::byte* QuantABlk1 = QuantABlk0 + Q8BlkSize(BlkLen32); + + // load A: + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + const __m256i av_1_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk1)); + + const float& scale_a0 = Q8BlkScale(QuantABlk0); + const float& scale_a1 = Q8BlkScale(QuantABlk1); + + // Col0 + const float& scale_00 = scale_a0 * QuantBScalePtr[0]; + const float& scale_01 = scale_a1 * QuantBScalePtr[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc0); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 16), low_mask, zero, QuantBZeroPointPtr, false, scale_01, acc0); + + // increment block pointers + QuantAPtr += Q8BlkSize(BlkLen32) * PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } + + // TODO: use a loop in case PerAccuBlk2 is not 2. + if (k_blks_remaining > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + + const float& scale_a0 = Q8BlkScale(QuantABlk0); + + // Col0 + const float& scale_00 = scale_a0 * QuantBScalePtr[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc0); + } + + *SumPtr = hsum_float_8(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } // m + return CountM; +} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h new file mode 100644 index 0000000000000..60a887345d0e0 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h @@ -0,0 +1,581 @@ +#pragma once +#include +#include +#include + +#include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_avx_common.h" +#include "sqnbitgemm_kernel_avx512_int8_blklen64.h" + +//static MLAS_FORCEINLINE __m512i +//combine_two_m256i_to_m512i(const __m256i& a, const __m256i& b) +//{ +// __m512i result = _mm512_castsi256_si512(a); +// result = _mm512_inserti64x4(result, b, 1); +// return result; +//} + +//static MLAS_FORCEINLINE void +//load_2blk_4b_packed_blklen64(const std::byte* QuantBDataPtr, __m512i& bv0_64_epi8, __m512i& bv1_64_epi8) +//{ +// // | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 | v64 v96 | ... 
| v95 v127 | +// const __m512i bv_packed = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr)); +// const __m512i low_mask = _mm512_set1_epi8(0x0F); +// __m512i bv0_64_epi8_ = _mm512_and_si512(bv_packed, low_mask); // 0~31, 64~95 +// __m512i bv1_64_epi8_ = _mm512_srli_epi16(_mm512_sub_epi8(bv_packed, bv0_64_epi8), 4); // 32~63, 96~127 +// +// // Extract lower and higher 256 bits from bv0_64_epi8 and bv1_64_epi8 +// __m256i bv0_lower = _mm512_castsi512_si256(bv0_64_epi8_); +// __m256i bv0_higher = _mm512_extracti64x4_epi64(bv0_64_epi8_, 1); +// __m256i bv1_lower = _mm512_castsi512_si256(bv1_64_epi8_); +// __m256i bv1_higher = _mm512_extracti64x4_epi64(bv1_64_epi8_, 1); +// +// // Compose new __m512i variables +// bv0_64_epi8 = _mm512_inserti64x4(_mm512_castsi256_si512(bv0_lower), bv1_lower, 1); +// bv1_64_epi8 = _mm512_inserti64x4(_mm512_castsi256_si512(bv0_higher), bv1_higher, 1); +//} + +static MLAS_FORCEINLINE void +dot_accumulate_1blk( + const __m512i& bv0_64_epi8, + const __m512i& bv1_64_epi8, + const __m512i& av0_64_epi8, + const __m512i& av1_64_epi8, + const float combined_scale, + __m512& acc +) +{ + __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av0_64_epi8); + __m512i dot1_32_epi16 = _mm512_maddubs_epi16(bv1_64_epi8, av1_64_epi8); + __m512i t1 = _mm512_unpacklo_epi32(dot0_32_epi16, dot1_32_epi16); + __m512i t2 = _mm512_unpackhi_epi32(dot0_32_epi16, dot1_32_epi16); + __m512i sum_32_epi16 = _mm512_add_epi16(t1, t2); + const __m512i zeros = _mm512_setzero_si512(); + const __m512i one_32_epi16 = _mm512_srli_epi16(_mm512_ternarylogic_epi32(zeros, zeros, zeros, 1), 15); + __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, sum_32_epi16); + __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc = _mm512_fmadd_ps(sum_16_ps, _mm512_set1_ps(combined_scale), acc); +} + +static MLAS_FORCEINLINE void +dot_accumulate_1blkvnni( + const __m512i& bv0_64_epi8, + const __m512i& bv1_64_epi8, + const __m512i& av0_64_epi8, + const __m512i& av1_64_epi8, + const float combined_scale, + __m512& acc +) +{ + __m512i dot0_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av0_64_epi8); + __m512i dot1_16_epi32 = _mm512_dpbusd_epi32(dot0_16_epi32, bv1_64_epi8, av1_64_epi8); + __m512 sum_16_ps = _mm512_cvtepi32_ps(dot1_16_epi32); + acc = _mm512_fmadd_ps(sum_16_ps, _mm512_set1_ps(combined_scale), acc); +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen128_r1c1blk1_avx512( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a, + const float* scale_b, + __m512& acc +) +{ + __m512i bv0_64_epi8, bv1_64_epi8; + load_2blk_4b_packed_blklen64(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); + + if constexpr (vnni) { + dot_accumulate_1blkvnni( + bv0_64_epi8, bv1_64_epi8, av00_64_epi8, av01_64_epi8, + (*scale_a) * (*scale_b), acc + ); + } else { + dot_accumulate_1blk( + bv0_64_epi8, bv1_64_epi8, av00_64_epi8, av01_64_epi8, + (*scale_a) * (*scale_b), acc + ); + } +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen128_r2c1blk1_avx512( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const __m512i& av10_64_epi8, + const __m512i& av11_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m512& acc0, + __m512& acc1 +) +{ + __m512i bv0_64_epi8, bv1_64_epi8; + load_2blk_4b_packed_blklen64(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); + + if constexpr (vnni) { + dot_accumulate_1blkvnni( + bv0_64_epi8, bv1_64_epi8, av00_64_epi8, 
av01_64_epi8, + (*scale_a0) * (*scale_b), acc0 + ); + dot_accumulate_1blkvnni( + bv0_64_epi8, bv1_64_epi8, av10_64_epi8, av11_64_epi8, + (*scale_a1) * (*scale_b), acc1 + ); + } else { + dot_accumulate_1blk( + bv0_64_epi8, bv1_64_epi8, av00_64_epi8, av01_64_epi8, + (*scale_a0) * (*scale_b), acc0 + ); + dot_accumulate_1blk( + bv0_64_epi8, bv1_64_epi8, av10_64_epi8, av11_64_epi8, + (*scale_a1) * (*scale_b), acc1 + ); + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmR2xC4BlkLen128Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 128; + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + //const size_t StrideQuantBScale = BlockCountK; + + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4 * NRows2] = { + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps() + }; + + // process 1 blks of 64 4b weights a time + for (size_t k = 0; k < BlockCountK; ++k) { + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m512i av00_64_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av01_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + SubblkLen / 2)); + const __m512i av10_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av11_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + SubblkLen / 2)); + + accumulate_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); + accumulate_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, QuantBDataPtr + SubblkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 1, acc[1], acc[NCols4 + 1]); + accumulate_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, QuantBDataPtr + 2 * SubblkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2, acc[2], acc[NCols4 + 2]); + accumulate_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, QuantBDataPtr + 3 * SubblkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3, acc[3], acc[NCols4 + 3]); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += NCols4 * SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr += 
NCols4; + } // k_blks_remaining + +#if 1 + *SumPtr = _mm512_reduce_add_ps(acc[0]); + *(SumPtr + 1) = _mm512_reduce_add_ps(acc[1]); + *(SumPtr + 2) = _mm512_reduce_add_ps(acc[2]); + *(SumPtr + 3) = _mm512_reduce_add_ps(acc[3]); + *(SumPtr + ldc) = _mm512_reduce_add_ps(acc[NCols4]); + *(SumPtr + ldc + 1) = _mm512_reduce_add_ps(acc[NCols4 + 1]); + *(SumPtr + ldc + 2) = _mm512_reduce_add_ps(acc[NCols4 + 2]); + *(SumPtr + ldc + 3) = _mm512_reduce_add_ps(acc[NCols4 + 3]); + if (BiasPtr != nullptr) { + *SumPtr += *BiasPtr; + *(SumPtr + 1) += *(BiasPtr + 1); + *(SumPtr + 2) += *(BiasPtr + 2); + *(SumPtr + 3) += *(BiasPtr + 3); + *(SumPtr + ldc) += *BiasPtr; + *(SumPtr + ldc + 1) += *(BiasPtr + 1); + *(SumPtr + ldc + 2) += *(BiasPtr + 2); + *(SumPtr + ldc + 3) += *(BiasPtr + 3); + } +#else + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + __m128 acc_r1 = FoldAccumulators(acc[NCols4 + 0], acc[NCols4 + 1], acc[NCols4 + 2], acc[NCols4 + 3]); + if (BiasPtr != nullptr) { + const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); + acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); + acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); + } + const __m128 level_r0 = _mm_loadu_ps(SumPtr); + _mm_storeu_ps(SumPtr, _mm_sub_ps(acc_r0, level_r0)); + + const __m128 level_r1 = _mm_loadu_ps(SumPtr + ldc); + _mm_storeu_ps(SumPtr + ldc, _mm_sub_ps(acc_r1, level_r1)); +#endif + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * BlockCountK; + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + +template +void MLAS_FORCEINLINE +Q4Int8GemmR2xC1BlkLen128Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 128; + + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM % NRows2 == 0); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + float* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc0 = _mm512_setzero_ps(), acc1 = _mm512_setzero_ps(); + + for (size_t k = 0; k < BlockCountK; ++k) { + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m512i av00_64_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av01_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + SubblkLen / 2)); + const __m512i av10_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av11_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + SubblkLen / 2)); + + 
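+                    // av00/av01 hold the 128 int8 activations of row m for this sub-block,
+                    // av10/av11 the same for row m + 1. The four calls below walk columns
+                    // n .. n+3: each consumes SubblkDataSizeInBytes of packed 4-bit weights
+                    // and one B scale, accumulating row m into acc[0..3] and row m + 1 into
+                    // acc[NCols4..NCols4+3].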
accumulate_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, + QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_16(acc0); + *(SumPtr + ldc) = hsum_float_16(acc1); + if (BiasPtr) { + *SumPtr += *BiasPtr; + *(SumPtr + ldc) += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmR1xC4BlkLen128Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 128; + + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + //const size_t StrideQuantBScale = BlockCountK; + + assert(CountM < NRows2); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4] = {_mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps()}; + for (size_t k = 0; k < BlockCountK; ++k) { + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m512i av0_64_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av1_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + SubblkLen / 2)); + accumulate_blklen128_r1c1blk1_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_blklen128_r1c1blk1_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + SubblkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 1, acc[1]); + accumulate_blklen128_r1c1blk1_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + 2 * SubblkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 2, acc[2]); + accumulate_blklen128_r1c1blk1_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + 3 * SubblkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 3, acc[3]); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += NCols4 * SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr +=NCols4; + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + + _mm_storeu_ps(SumPtr, acc_r0); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * BlockCountK; + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + } + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmR1xC1BlkLen128Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 128; + + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM < NRows2); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc0 = _mm512_setzero_ps(); + for (size_t k = 0; k < BlockCountK; ++k) { + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m512i av0_64_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av1_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + SubblkLen / 2)); + + accumulate_blklen128_r1c1blk1_avx512( + av0_64_epi8, av1_64_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0 + ); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_16(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } + } +} + +template +MLAS_FORCEINLINE size_t +MlasQ4Int8GemmKernelBlkLen128Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + + const size_t lda = BlockCountK * BlkLen * sizeof(int8_t); + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleRows > 0 && multipleCols > 0) { + Q4Int8GemmR2xC4BlkLen128Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + ldc + ); + } + if (remainingCols > 0 && multipleRows > 0) { + Q4Int8GemmR2xC1BlkLen128Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + if (remainingRows > 0 && multipleCols > 0) { + Q4Int8GemmR1xC4BlkLen128Avx512( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols, + BlockCountK, + Bias, + ldc); + } + + if (remainingCols > 0 && remainingRows > 0) { + Q4Int8GemmR1xC1BlkLen128Avx512( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleRows * ldc + multipleCols, + remainingRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + return CountM; +} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h new file mode 100644 index 0000000000000..3cd610796a5e3 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h @@ -0,0 +1,812 @@ +#pragma once +#include +#include +#include + +#include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_avx_common.h" +#include "sqnbitgemm_kernel_avx2_int8_blklen16.h" +#include "sqnbitgemm_kernel_avx512_int8_blklen32.h" +#include "sqnbitgemm_kernel_avx512_int8_blklen64.h" + + + +static MLAS_FORCEINLINE void +load_4blk_4b_packed_blklen16(const std::byte* QuantBDataPtr, __m512i& bv0_64_epi8, __m512i& bv1_64_epi8) +{ + // | v0 v64 | v1 v65 | ... 
| v62 v126 | v63 v127 | + const __m512i bv_packed = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr)); + const __m512i low_mask = _mm512_set1_epi8(0x0F); + bv0_64_epi8 = _mm512_and_si512(bv_packed, low_mask); // 0~63 + bv1_64_epi8 = _mm512_srli_epi16(_mm512_sub_epi8(bv_packed, bv0_64_epi8), 4); // 64~127 +} + +static MLAS_FORCEINLINE void +accumulate_blklen16_r1c1blk8_avx512( + const __m512i& av0_64_epi8, + const __m512i& av1_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a, + const float* scale_b, + __m512& acc0) +{ + __m512i bv0_64_epi8, bv1_64_epi8; + load_4blk_4b_packed_blklen16(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); + + const __m256 scale_b_ps = _mm256_loadu_ps(scale_b); // 01234567 + { + const __m256 scale_a0_ps = _mm256_loadu_ps(scale_a); // 01234567 + const __m256 scale_a0b_ps = _mm256_mul_ps(scale_b_ps, scale_a0_ps); + __m512 scale_a0b_16_ps = _mm512_castsi512_ps( + _mm512_broadcast_i64x4(_mm256_castps_si256(scale_a0b_ps)) + ); // 0123456701234567 + + __m512i idx = _mm512_set_epi32(7, 7, 3, 3, 6, 6, 2, 2, 5, 5, 1, 1, 4, 4, 0, 0); + scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); // 0044115522663377 + + const __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av0_64_epi8); // 0~0,1~1,2~2,3~3 + const __m512i dot1_32_epi16 = _mm512_maddubs_epi16(bv1_64_epi8, av1_64_epi8); // 4~4,5~5,6~6,7~7 + + const __m512i t1 = _mm512_unpacklo_epi64(dot0_32_epi16, dot1_32_epi16); // 00004444111155552222666633337777 + const __m512i t2 = _mm512_unpackhi_epi64(dot0_32_epi16, dot1_32_epi16); // 00004444111155552222666633337777 + const __m512i sum_32_epi16 = _mm512_add_epi16(t1, t2); + const __m512i one_32_epi16 = generate_ones_32_epi16(); + const __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, sum_32_epi16); // 0044115522663377 + const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0); + } +} + +static MLAS_FORCEINLINE void +accumulate_blklen16_r2c1blk4_avx512( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const __m512i& av10_64_epi8, + const __m512i& av11_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m512& acc0, + __m512& acc1 +) +{ + __m512i bv0_64_epi8, bv1_64_epi8; + load_2blk_4b_packed_blklen64(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); + + const __m256 scale_b_ps = _mm256_loadu_ps(scale_b); // 01234567 + { + const __m256 scale_a0_ps = _mm256_loadu_ps(scale_a0); // 01234567 + const __m256 scale_a0b_ps = _mm256_mul_ps(scale_b_ps, scale_a0_ps); + __m512 scale_a0b_16_ps = _mm512_castsi512_ps( + _mm512_broadcast_i64x4(_mm256_castps_si256(scale_a0b_ps)) + ); // 0123456701234567 + + // TODO: load from memory + __m512i idx = _mm512_set_epi32(7, 7, 3, 3, 6, 6, 2, 2, 5, 5, 1, 1, 4, 4, 0, 0); + scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); + + const __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av00_64_epi8); + const __m512i dot1_32_epi16 = _mm512_maddubs_epi16(bv1_64_epi8, av01_64_epi8); + + const __m512i t1 = _mm512_unpacklo_epi64(dot0_32_epi16, dot1_32_epi16); + const __m512i t2 = _mm512_unpackhi_epi64(dot0_32_epi16, dot1_32_epi16); + const __m512i sum_32_epi16 = _mm512_add_epi16(t1, t2); + const __m512i one_32_epi16 = generate_ones_32_epi16(); + const __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, sum_32_epi16); + const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0); + } + { + const __m256 scale_a1_ps 
= _mm256_loadu_ps(scale_a1); // 01234567 + const __m256 scale_a1b_ps = _mm256_mul_ps(scale_b_ps, scale_a1_ps); + __m512 scale_a1b_16_ps = _mm512_castsi512_ps( + _mm512_broadcast_i64x4(_mm256_castps_si256(scale_a1b_ps)) + ); // 0123456701234567 + + __m512i idx = _mm512_set_epi32(7, 7, 3, 3, 6, 6, 2, 2, 5, 5, 1, 1, 4, 4, 0, 0); + scale_a1b_16_ps = _mm512_permutexvar_ps(idx, scale_a1b_16_ps); + + const __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av10_64_epi8); + const __m512i dot1_32_epi16 = _mm512_maddubs_epi16(bv1_64_epi8, av11_64_epi8); + + const __m512i t1 = _mm512_unpacklo_epi64(dot0_32_epi16, dot1_32_epi16); + const __m512i t2 = _mm512_unpackhi_epi64(dot0_32_epi16, dot1_32_epi16); + const __m512i sum_32_epi16 = _mm512_add_epi16(t1, t2); + const __m512i one_32_epi16 = generate_ones_32_epi16(); + const __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, sum_32_epi16); + const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc1 = _mm512_fmadd_ps(sum_16_ps, scale_a1b_16_ps, acc1); + } +} + +static MLAS_FORCEINLINE void +accumulate_blklen16_r1c1blk8_avx512vnni( + const __m512i& av0_64_epi8, + const __m512i& av1_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a, + const float* scale_b, + __m512& acc0 +) +{ + __m512i bv0_64_epi8, bv1_64_epi8; + load_4blk_4b_packed_blklen16(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); + + const __m256 scale_b_ps = _mm256_loadu_ps(scale_b); // 01234567 + { + const __m256 scale_a0_ps = _mm256_loadu_ps(scale_a); // 01234567 + const __m256 scale_a0b_ps = _mm256_mul_ps(scale_b_ps, scale_a0_ps); + __m512 scale_a0b_16_ps = _mm512_castsi512_ps( + _mm512_broadcast_i64x4(_mm256_castps_si256(scale_a0b_ps)) + ); // 0123456701234567 + + __m512i idx = _mm512_set_epi32(7, 7, 3, 3, 6, 6, 2, 2, 5, 5, 1, 1, 4, 4, 0, 0); + scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); // 0044115522663377 + + const __m512i dot0_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av0_64_epi8); // 0000111122223333 + const __m512i dot1_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv1_64_epi8, av1_64_epi8); // 4444555566667777 + + const __m512i t1_16_epi32 = _mm512_unpacklo_epi64(dot0_16_epi32, dot1_16_epi32); // 0044115522663377 + const __m512i t2_16_epi32 = _mm512_unpackhi_epi64(dot0_16_epi32, dot1_16_epi32); // 0044115522663377 + const __m512i sum_16_epi32 = _mm512_add_epi32(t1_16_epi32, t2_16_epi32); + const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0); + } +} + +static MLAS_FORCEINLINE void +accumulate_blklen16_r2c1blk4_avx512vnni( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const __m512i& av10_64_epi8, + const __m512i& av11_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m512& acc0, + __m512& acc1 +) +{ + __m512i bv0_64_epi8, bv1_64_epi8; + load_2blk_4b_packed_blklen64(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); + + const __m256 scale_b_ps = _mm256_loadu_ps(scale_b); // 01234567 + { + const __m256 scale_a0_ps = _mm256_loadu_ps(scale_a0); // 01234567 + const __m256 scale_a0b_ps = _mm256_mul_ps(scale_b_ps, scale_a0_ps); + __m512 scale_a0b_16_ps = _mm512_castsi512_ps( + _mm512_broadcast_i64x4(_mm256_castps_si256(scale_a0b_ps)) + ); // 0123456701234567 + + // TODO: load from memory + __m512i idx = _mm512_set_epi32(7, 7, 3, 3, 6, 6, 2, 2, 5, 5, 1, 1, 4, 4, 0, 0); + scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); + + const __m512i dot0_16_epi32 = 
_mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av00_64_epi8); // 0000111122223333 + const __m512i dot1_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv1_64_epi8, av01_64_epi8); // 4444555566667777 + + const __m512i t1_16_epi32 = _mm512_unpacklo_epi64(dot0_16_epi32, dot1_16_epi32); + const __m512i t2_16_epi32 = _mm512_unpackhi_epi64(dot0_16_epi32, dot1_16_epi32); + const __m512i sum_16_epi32 = _mm512_add_epi32(t1_16_epi32, t2_16_epi32); + const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0); + } + { + const __m256 scale_a1_ps = _mm256_loadu_ps(scale_a1); // 01234567 + const __m256 scale_a1b_ps = _mm256_mul_ps(scale_b_ps, scale_a1_ps); + __m512 scale_a1b_16_ps = _mm512_castsi512_ps( + _mm512_broadcast_i64x4(_mm256_castps_si256(scale_a1b_ps)) + ); // 0123456701234567 + + __m512i idx = _mm512_set_epi32(7, 7, 3, 3, 6, 6, 2, 2, 5, 5, 1, 1, 4, 4, 0, 0); + scale_a1b_16_ps = _mm512_permutexvar_ps(idx, scale_a1b_16_ps); + + const __m512i dot0_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av10_64_epi8); // 0000111122223333 + const __m512i dot1_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv1_64_epi8, av11_64_epi8); // 4444555566667777 + + const __m512i t1_16_epi32 = _mm512_unpacklo_epi64(dot0_16_epi32, dot1_16_epi32); + const __m512i t2_16_epi32 = _mm512_unpackhi_epi64(dot0_16_epi32, dot1_16_epi32); + const __m512i sum_16_epi32 = _mm512_add_epi32(t1_16_epi32, t2_16_epi32); + const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc1 = _mm512_fmadd_ps(sum_16_ps, scale_a1b_16_ps, acc1); + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmR2xC4BlkLen16Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk8 = 8; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4 * NRows2] = { + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps() + }; + + size_t k_blks_remaining = BlockCountK; + // process 2 blks of 64 4b weights a time + for (; k_blks_remaining >= PerAccuBlk8; k_blks_remaining -= PerAccuBlk8) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i 
av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av_11_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + 64)); + + if constexpr (vnni) { + accumulate_blklen16_r2c1blk4_avx512vnni( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, + acc[0], acc[NCols4] + ); + accumulate_blklen16_r2c1blk4_avx512vnni( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + StrideQuantBScale, + acc[1], acc[NCols4 + 1] + ); + accumulate_blklen16_r2c1blk4_avx512vnni( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * StrideQuantBScale, + acc[2], acc[NCols4 + 2] + ); + accumulate_blklen16_r2c1blk4_avx512vnni( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * StrideQuantBScale, + acc[3], acc[NCols4 + 3] + ); + } else { + accumulate_blklen16_r2c1blk4_avx512( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, + acc[0], acc[NCols4] + ); + accumulate_blklen16_r2c1blk4_avx512( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + StrideQuantBScale, + acc[1], acc[NCols4 + 1] + ); + accumulate_blklen16_r2c1blk4_avx512( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * StrideQuantBScale, + acc[2], acc[NCols4 + 2] + ); + accumulate_blklen16_r2c1blk4_avx512( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * StrideQuantBScale, + acc[3], acc[NCols4 + 3] + ); + } + + // increment block pointers + QuantAPtr += BlkLen16 * PerAccuBlk8; + QuantAScalePtr += PerAccuBlk8; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk8; + QuantBScalePtr += PerAccuBlk8; + } // k_blks_remaining + + __m256 acc2[NCols4 * NRows2] = { + h_add_512(acc[0]), + h_add_512(acc[1]), + h_add_512(acc[2]), + h_add_512(acc[3]), + h_add_512(acc[4]), + h_add_512(acc[5]), + h_add_512(acc[6]), + h_add_512(acc[7]) + }; + + while (k_blks_remaining-- > 0) { + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av0_16_epi16 = load_16_epi8_as_epi16(QuantABlk0); + const __m256i av1_16_epi16 = load_16_epi8_as_epi16(QuantABlk0 + lda); + + const float& scale_a00 = *QuantAScalePtr; + const float& scale_a10 = *(QuantAScalePtr + BlockCountK); + + { + // Col0 + const float scale_00 = scale_a00 * (QuantBScalePtr)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_blklen16_r2c1blk1_avx2(av0_16_epi16, av1_16_epi16, QuantBDataPtr, scale_00, scale_10, acc2[0], acc2[NCols4]); + } + + { + // Col1 + const float scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr + StrideQuantBScale)[0]; + accumulate_blklen16_r2c1blk1_avx2(av0_16_epi16, av1_16_epi16, QuantBDataPtr + StrideQuantBData, scale_00, scale_10, + acc2[1], acc2[NCols4 + 1]); + } + + { + // Col2 + const float scale_00 = scale_a00 * (QuantBScalePtr + 2 * 
StrideQuantBScale)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + accumulate_blklen16_r2c1blk1_avx2(av0_16_epi16, av1_16_epi16, QuantBDataPtr + 2 * StrideQuantBData, scale_00, scale_10, + acc2[2], acc2[NCols4 + 2]); + } + + { + // Col3 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + accumulate_blklen16_r2c1blk1_avx2( + av0_16_epi16, av1_16_epi16, QuantBDataPtr + 3 * StrideQuantBData, scale_00, scale_10, + acc2[3], acc2[NCols4 + 3]); + } + QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes; + QuantBScalePtr++; + } // k_blks_remaining + + __m128 acc_r0 = FoldAccumulators(acc2[0], acc2[1], acc2[2], acc2[3]); + __m128 acc_r1 = FoldAccumulators(acc2[NCols4 + 0], acc2[NCols4 + 1], acc2[NCols4 + 2], acc2[NCols4 + 3]); + if (BiasPtr != nullptr) { + const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); + acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); + acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); + } + _mm_storeu_ps(SumPtr, acc_r0); + _mm_storeu_ps(SumPtr + ldc, acc_r1); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * StrideQuantBScale; + + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + +template +void MLAS_FORCEINLINE +Q4Int8GemmR2C1BlkLen16Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk8 = 8; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM % NRows2 == 0); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + float* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc0 = _mm512_setzero_ps(), acc1 = _mm512_setzero_ps(); + + size_t k_blks_remaining = BlockCountK; + // process 2 blks of 64 4b weights a time + for (; k_blks_remaining >= PerAccuBlk8; k_blks_remaining -= PerAccuBlk8) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av_11_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + 64)); + + if constexpr (vnni) { + accumulate_blklen16_r2c1blk4_avx512vnni( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, + QuantAScalePtr, QuantAScalePtr + 
BlockCountK, QuantBScalePtr, acc0, acc1 + ); + } else { + accumulate_blklen16_r2c1blk4_avx512( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1 + ); + } + + // increment block pointers + QuantAPtr += BlkLen16 * PerAccuBlk8; + QuantAScalePtr += PerAccuBlk8; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk8; + QuantBScalePtr += PerAccuBlk8; + } + + __m256 acc20 = h_add_512(acc0); + __m256 acc21 = h_add_512(acc1); + while (k_blks_remaining-- > 0) { + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av0_16_epi16 = load_16_epi8_as_epi16(QuantABlk0); + const __m256i av1_16_epi16 = load_16_epi8_as_epi16(QuantABlk0 + lda); + + const float& scale_a00 = *QuantAScalePtr; + const float& scale_a10 = *(QuantAScalePtr + BlockCountK); + + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_blklen16_r2c1blk1_avx2(av0_16_epi16, av1_16_epi16, QuantBDataPtr, scale_00, scale_10, acc20, acc21); + + QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_8(acc20); + *(SumPtr + ldc) = hsum_float_8(acc21); + if (BiasPtr) { + *SumPtr += *BiasPtr; + *(SumPtr + ldc) += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmR1xC4BlkLen16Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk8 = 8; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM < NRows2); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4] = { + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps() + }; + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk8; k_blks_remaining -= PerAccuBlk8) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + + if constexpr (vnni) { + accumulate_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale, acc[1]); + 
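+                        // Each of these calls folds PerAccuBlk8 (8) consecutive blklen16
+                        // blocks, i.e. 128 weights, for one column. Columns 2 and 3 below
+                        // reuse the same av_00/av_01 activation registers; only the B data
+                        // and scale pointers advance by the per-column strides.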
accumulate_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale, acc[2]); + accumulate_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale, acc[3]); + } else { + accumulate_blklen16_r1c1blk8_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_blklen16_r1c1blk8_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale, acc[1]); + accumulate_blklen16_r1c1blk8_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale, acc[2]); + accumulate_blklen16_r1c1blk8_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale, acc[3]); + } + + QuantAPtr += BlkLen16 * PerAccuBlk8; + QuantAScalePtr += PerAccuBlk8; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk8; + QuantBScalePtr += PerAccuBlk8; + } + + __m256 acc2[NCols4] = { + h_add_512(acc[0]), h_add_512(acc[1]), h_add_512(acc[2]), h_add_512(acc[3]) + }; + + while (k_blks_remaining-- > 0) { + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_00_epi8 = load_16_epi8_as_epi16(QuantABlk0); + + const float& scale_a00 = *QuantAScalePtr; + { + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc2[0]); + } + { + const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; + accumulate_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + StrideQuantBData, scale_00, acc2[1]); + } + { + const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + accumulate_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 2 * StrideQuantBData, scale_00, acc2[2]); + } + { + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + accumulate_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 3 * StrideQuantBData, scale_00, acc2[3]); + } + + QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes; + QuantBScalePtr++; + + } + + __m128 acc_r0 = FoldAccumulators(acc2[0], acc2[1], acc2[2], acc2[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + + _mm_storeu_ps(SumPtr, acc_r0); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + } + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmR1xC1BlkLen16Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk8 = 8; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM < NRows2); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc0 = _mm512_setzero_ps(); + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk8; k_blks_remaining -= PerAccuBlk8) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + + if constexpr (vnni) { + accumulate_blklen16_r1c1blk8_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + } else { + accumulate_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + } + + QuantAPtr += BlkLen16 * PerAccuBlk8; + QuantAScalePtr += PerAccuBlk8; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk8; + QuantBScalePtr += PerAccuBlk8; + } + + __m256 acc2 = h_add_512(acc0); + while (k_blks_remaining-- > 0) { + const __m256i av_00_epi8 = load_16_epi8_as_epi16(QuantAPtr); + + const float& scale_a00 = *QuantAScalePtr; + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc2); + + QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_8(acc2); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } + } +} + +template +MLAS_FORCEINLINE + size_t +MlasQ4Int8GemmKernelBlkLen16Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc + ) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + + const size_t lda = BlockCountK * BlkLen16 * sizeof(int8_t); + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleRows > 0 && multipleCols > 0) { + Q4Int8GemmR2xC4BlkLen16Avx512( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + ldc + ); + } + if (remainingCols > 0 && multipleRows > 0) { + Q4Int8GemmR2C1BlkLen16Avx512( + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + if (remainingRows > 0 && multipleCols > 0) { + Q4Int8GemmR1xC4BlkLen16Avx512( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols, + BlockCountK, + Bias, + ldc); + } + + if (remainingCols > 0 && remainingRows > 0) { + Q4Int8GemmR1xC1BlkLen16Avx512( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleRows * ldc + multipleCols, + remainingRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + return CountM; +} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h new file mode 100644 index 0000000000000..ca12cc14a7875 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h @@ -0,0 +1,852 @@ +#pragma once +#include +#include +#include + +#include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_avx_common.h" +#include "sqnbitgemm_kernel_avx2_int8_blklen32.h" +#include "sqnbitgemm_kernel_avx512_int8_blklen64.h" + +static MLAS_FORCEINLINE void +load_4blk_4b_packed_blklen32(const std::byte* QuantBDataPtr, __m512i& bv0_64_epi8, __m512i& bv1_64_epi8) +{ + // | v0 v64 | v1 v65 | ... 
| v62 v126 | v63 v127 | + const __m512i bv_packed = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr)); + const __m512i low_mask = _mm512_set1_epi8(0x0F); + bv0_64_epi8 = _mm512_and_si512(bv_packed, low_mask); // 0~63 + bv1_64_epi8 = _mm512_srli_epi16(_mm512_sub_epi8(bv_packed, bv0_64_epi8), 4); // 64~127 +} + +static const uint32_t index_array[16] = {0, 0, 2, 2, 0, 0, 2, 2, 1, 1, 3, 3, 1, 1, 3, 3}; + +static MLAS_FORCEINLINE void +accumulate_blklen32_r1c1blk4_avx512( + const __m512i& av0_64_epi8, + const __m512i& av1_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a, + const float* scale_b, + __m512& acc0) +{ + __m512i bv0_64_epi8, bv1_64_epi8; + load_4blk_4b_packed_blklen32(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); + + const __m128 scale_b_ps = _mm_loadu_ps(scale_b); // 0123 + { + const __m128 scale_a0_ps = _mm_loadu_ps(scale_a); // 0123 + const __m128 scale_a0b_ps = _mm_mul_ps(scale_b_ps, scale_a0_ps); + __m512 scale_a0b_16_ps = _mm512_broadcast_f32x4(scale_a0b_ps); // 0123012301230123 + + __m512i idx = _mm512_set_epi32(3, 3, 1, 1, 3, 3, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0); + // __m512i idx = _mm512_loadu_epi8(&index_array[0]); + scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); // 0022002211331133 + + const __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av0_64_epi8); // 0~0,1~1 + const __m512i dot1_32_epi16 = _mm512_maddubs_epi16(bv1_64_epi8, av1_64_epi8); // 2~2,3~3 + + const __m512i t1 = _mm512_unpacklo_epi64(dot0_32_epi16, dot1_32_epi16); // 00002222000022221111333311113333 + const __m512i t2 = _mm512_unpackhi_epi64(dot0_32_epi16, dot1_32_epi16); // 00002222000022221111333311113333 + const __m512i sum_32_epi16 = _mm512_add_epi16(t1, t2); // 00002222000022221111333311113333 + const __m512i one_32_epi16 = generate_ones_32_epi16(); + const __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, sum_32_epi16); // 0022002211331133 + const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0); + } +} + +static MLAS_FORCEINLINE void +accumulate_blklen32_r2c1blk4_avx512( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const __m512i& av10_64_epi8, + const __m512i& av11_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m512& acc0, + __m512& acc1 +) +{ + __m512i bv0_64_epi8, bv1_64_epi8; + load_2blk_4b_packed_blklen64(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); + + const __m128 scale_b_ps = _mm_loadu_ps(scale_b); // 0123 + { + const __m128 scale_a0_ps = _mm_loadu_ps(scale_a0); // 0123 + const __m128 scale_a0b_ps = _mm_mul_ps(scale_b_ps, scale_a0_ps); + __m512 scale_a0b_16_ps = _mm512_broadcast_f32x4(scale_a0b_ps); // 0123012301230123 + + __m512i idx = _mm512_set_epi32(3, 3, 1, 1, 3, 3, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0); + // __m512i idx = _mm512_loadu_epi8(&index_array[0]); + scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); // 0022002211331133 + + const __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av00_64_epi8); // 0~0,1~1 + const __m512i dot1_32_epi16 = _mm512_maddubs_epi16(bv1_64_epi8, av01_64_epi8); // 2~2,3~3 + + const __m512i t1 = _mm512_unpacklo_epi64(dot0_32_epi16, dot1_32_epi16); // 00002222000022221111333311113333 + const __m512i t2 = _mm512_unpackhi_epi64(dot0_32_epi16, dot1_32_epi16); // 00002222000022221111333311113333 + const __m512i sum_32_epi16 = _mm512_add_epi16(t1, t2); // 00002222000022221111333311113333 + const __m512i one_32_epi16 = generate_ones_32_epi16(); + const __m512i 
sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, sum_32_epi16); // 0022002211331133 + const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0); + } + { + const __m128 scale_a1_ps = _mm_loadu_ps(scale_a1); // 0123 + const __m128 scale_a1b_ps = _mm_mul_ps(scale_b_ps, scale_a1_ps); + __m512 scale_a1b_16_ps = _mm512_broadcast_f32x4(scale_a1b_ps); // 0123012301230123 + + __m512i idx = _mm512_set_epi32(3, 3, 1, 1, 3, 3, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0); + // __m512i idx = _mm512_loadu_epi8(&index_array[0]); + scale_a1b_16_ps = _mm512_permutexvar_ps(idx, scale_a1b_16_ps); // 0022002211331133 + + const __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av10_64_epi8); // 0~0,1~1 + const __m512i dot1_32_epi16 = _mm512_maddubs_epi16(bv1_64_epi8, av11_64_epi8); // 2~2,3~3 + + const __m512i t1 = _mm512_unpacklo_epi64(dot0_32_epi16, dot1_32_epi16); // 00002222000022221111333311113333 + const __m512i t2 = _mm512_unpackhi_epi64(dot0_32_epi16, dot1_32_epi16); // 00002222000022221111333311113333 + const __m512i sum_32_epi16 = _mm512_add_epi16(t1, t2); // 00002222000022221111333311113333 + const __m512i one_32_epi16 = generate_ones_32_epi16(); + const __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, sum_32_epi16); // 0022002211331133 + const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc1 = _mm512_fmadd_ps(sum_16_ps, scale_a1b_16_ps, acc1); + } +} + +static MLAS_FORCEINLINE void +accumulate_blklen32_r1c1blk4_avx512vnni( + const __m512i& av0_64_epi8, + const __m512i& av1_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a, + const float* scale_b, + __m512& acc0 +) +{ + __m512i bv0_64_epi8, bv1_64_epi8; + load_4blk_4b_packed_blklen32(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); + + const __m128 scale_b_ps = _mm_loadu_ps(scale_b); // 0123 + { + const __m128 scale_a0_ps = _mm_loadu_ps(scale_a); // 0123 + const __m128 scale_a0b_ps = _mm_mul_ps(scale_b_ps, scale_a0_ps); + __m512 scale_a0b_16_ps = _mm512_broadcast_f32x4(scale_a0b_ps); // 0123012301230123 + + __m512i idx = _mm512_set_epi32(3, 3, 1, 1, 3, 3, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0); + //__m512i idx = _mm512_loadu_epi8(&index_array[0]); + scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); // 0022002211331133 + + const __m512i dot0_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av0_64_epi8); // 0000000011111111 + const __m512i dot1_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv1_64_epi8, av1_64_epi8); // 2222222233333333 + + const __m512i t1_16_epi32 = _mm512_unpacklo_epi64(dot0_16_epi32, dot1_16_epi32); // 0022002211331133 + const __m512i t2_16_epi32 = _mm512_unpackhi_epi64(dot0_16_epi32, dot1_16_epi32); // 0022002211331133 + const __m512i sum_16_epi32 = _mm512_add_epi32(t1_16_epi32, t2_16_epi32); // 0022002211331133 + const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0); + } +} + +static MLAS_FORCEINLINE void +accumulate_blklen32_r2c1blk4_avx512vnni( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const __m512i& av10_64_epi8, + const __m512i& av11_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m512& acc0, + __m512& acc1 +) +{ + __m512i bv0_64_epi8, bv1_64_epi8; + load_2blk_4b_packed_blklen64(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); + __m512i idx = _mm512_set_epi32(3, 3, 1, 1, 3, 3, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0); + //__m512i idx = _mm512_loadu_epi8(&index_array[0]); + + const 
__m128 scale_b_ps = _mm_loadu_ps(scale_b); // 0123 + { + const __m128 scale_a0_ps = _mm_loadu_ps(scale_a0); // 0123 + const __m128 scale_a0b_ps = _mm_mul_ps(scale_b_ps, scale_a0_ps); + __m512 scale_a0b_16_ps = _mm512_broadcast_f32x4(scale_a0b_ps); // 0123012301230123 + + scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); // 0022002211331133 + + const __m512i dot0_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av00_64_epi8); // 0000000011111111 + const __m512i dot1_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv1_64_epi8, av01_64_epi8); // 2222222233333333 + + const __m512i t1_16_epi32 = _mm512_unpacklo_epi64(dot0_16_epi32, dot1_16_epi32); // 0022002211331133 + const __m512i t2_16_epi32 = _mm512_unpackhi_epi64(dot0_16_epi32, dot1_16_epi32); // 0022002211331133 + const __m512i sum_16_epi32 = _mm512_add_epi32(t1_16_epi32, t2_16_epi32); // 0022002211331133 + const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0); + } + { + const __m128 scale_a1_ps = _mm_loadu_ps(scale_a1); // 0123 + const __m128 scale_a1b_ps = _mm_mul_ps(scale_b_ps, scale_a1_ps); + __m512 scale_a1b_16_ps = _mm512_broadcast_f32x4(scale_a1b_ps); // 0123012301230123 + + scale_a1b_16_ps = _mm512_permutexvar_ps(idx, scale_a1b_16_ps); // 0022002211331133 + + const __m512i dot0_32_epi16 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av10_64_epi8); // 0000000011111111 + const __m512i dot1_32_epi16 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv1_64_epi8, av11_64_epi8); // 2222222233333333 + + const __m512i t1_16_epi32 = _mm512_unpacklo_epi64(dot0_32_epi16, dot1_32_epi16); // 0022002211331133 + const __m512i t2_16_epi32 = _mm512_unpackhi_epi64(dot0_32_epi16, dot1_32_epi16); // 0022002211331133 + const __m512i sum_16_epi32 = _mm512_add_epi32(t1_16_epi32, t2_16_epi32); // 0022002211331133 + const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc1 = _mm512_fmadd_ps(sum_16_ps, scale_a1b_16_ps, acc1); + } +} + +MLAS_FORCEINLINE void +accumulate_1blk_dot_avx512vnni(const __m256i& av_32_epi8, const __m256i& bv_32_epi8, const float& combined_scale, __m256& acc) +{ + __m256i sum_8_epi32 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bv_32_epi8, av_32_epi8); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + acc = _mm256_fmadd_ps(sum_ps, _mm256_set1_ps(combined_scale), acc); +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen32_r1c1blk1_avx512( + const __m256i& av00_32_epi8, + const std::byte* QuantBDataPtr, + const float& combined_scale00, + __m256& acc0 +) +{ + if constexpr (vnni) { + // | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | + const __m128i bv_packed0 = _mm_loadu_si128(reinterpret_cast(QuantBDataPtr)); + __m256i bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); + bv_32_epi8 = _mm256_and_si256(_mm256_set1_epi8(0x0F), bv_32_epi8); + accumulate_1blk_dot_avx512vnni(av00_32_epi8, bv_32_epi8, combined_scale00, acc0); + } else { + accumulate_blklen32_r1c1blk1_avx2(av00_32_epi8, QuantBDataPtr, combined_scale00, acc0); + } +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen32_r2c1blk1_avx512( + const __m256i& av00_32_epi8, + const __m256i& av10_32_epi8, + const std::byte* QuantBDataPtr, + const float& combined_scale00, + const float& combined_scale10, + __m256& acc0, + __m256& acc1 +) +{ + if constexpr (vnni) { + // | v0 v16 | v1 v17 | ... 
| v14 v30 | v15 v31 | + const __m128i bv_packed0 = _mm_loadu_si128(reinterpret_cast(QuantBDataPtr)); + __m256i bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); + bv_32_epi8 = _mm256_and_si256(_mm256_set1_epi8(0x0F), bv_32_epi8); + + accumulate_1blk_dot_avx512vnni(av00_32_epi8, bv_32_epi8, combined_scale00, acc0); + accumulate_1blk_dot_avx512vnni(av10_32_epi8, bv_32_epi8, combined_scale10, acc1); + } else { + accumulate_blklen32_r2c1blk1_avx2(av00_32_epi8, av10_32_epi8, QuantBDataPtr, combined_scale00, combined_scale10, acc0, acc1); + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmR2xC4BlkLen32Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen32; + const size_t StrideQuantBData = PerAccuBlk4 * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + //const size_t StrideQuantBScale = BlockCountK; + + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4 * NRows2] = { + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps() + }; + + size_t k_blks_remaining = BlockCountK; + // process 2 blks of 64 4b weights a time + for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av_11_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + 64)); + + if constexpr (vnni) { + accumulate_blklen32_r2c1blk4_avx512vnni( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, + acc[0], acc[NCols4] + ); + accumulate_blklen32_r2c1blk4_avx512vnni( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + PerAccuBlk4, + acc[1], acc[NCols4 + 1] + ); + accumulate_blklen32_r2c1blk4_avx512vnni( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * PerAccuBlk4, + acc[2], acc[NCols4 + 2] + ); + accumulate_blklen32_r2c1blk4_avx512vnni( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * PerAccuBlk4, + acc[3], 
acc[NCols4 + 3] + ); + } else { + accumulate_blklen32_r2c1blk4_avx512( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, + acc[0], acc[NCols4] + ); + accumulate_blklen32_r2c1blk4_avx512( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + PerAccuBlk4, + acc[1], acc[NCols4 + 1] + ); + accumulate_blklen32_r2c1blk4_avx512( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * PerAccuBlk4, + acc[2], acc[NCols4 + 2] + ); + accumulate_blklen32_r2c1blk4_avx512( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * PerAccuBlk4, + acc[3], acc[NCols4 + 3] + ); + } + + // increment block pointers + QuantAPtr += BlkLen32 * PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += StrideQuantBData * NCols4; + QuantBScalePtr += PerAccuBlk4 * NCols4; + } // k_blks_remaining + + __m256 acc2[NCols4 * NRows2] = { + h_add_512(acc[0]), + h_add_512(acc[1]), + h_add_512(acc[2]), + h_add_512(acc[3]), + h_add_512(acc[4]), + h_add_512(acc[5]), + h_add_512(acc[6]), + h_add_512(acc[7]) + }; + + while (k_blks_remaining-- > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantABlk0 + lda)); + + const float& scale_a00 = *QuantAScalePtr; + const float& scale_a10 = *(QuantAScalePtr + BlockCountK); + + { + // Col0 + const float scale_00 = scale_a00 * (QuantBScalePtr)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_blklen32_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc2[0], acc2[NCols4]); + } + + { + // Col1 + const float scale_00 = scale_a00 * (QuantBScalePtr + 1)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr + 1)[0]; + accumulate_blklen32_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr + BlkDataSizeInBytes16, scale_00, scale_10, acc2[1], acc2[NCols4 + 1]); + } + + { + // Col2 + const float scale_00 = scale_a00 * (QuantBScalePtr + 2)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr + 2)[0]; + accumulate_blklen32_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes16, scale_00, scale_10, acc2[2], acc2[NCols4 + 2]); + } + + { + // Col3 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr + 3)[0]; + accumulate_blklen32_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes16, scale_00, scale_10, acc2[3], acc2[NCols4 + 3]); + } + QuantAPtr += BlkLen32; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes16 * NCols4; + QuantBScalePtr += NCols4; + } // k_blks_remaining + + __m128 acc_r0 = FoldAccumulators(acc2[0], acc2[1], acc2[2], acc2[3]); + __m128 acc_r1 = FoldAccumulators(acc2[NCols4 + 0], acc2[NCols4 + 1], acc2[NCols4 + 2], acc2[NCols4 + 3]); + if (BiasPtr != nullptr) { + const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); + acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); + acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); + } + _mm_storeu_ps(SumPtr, acc_r0); + _mm_storeu_ps(SumPtr + ldc, acc_r1); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * BlockCountK * BlkDataSizeInBytes16; + 
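+            // Stride check with concrete numbers (illustrative only, not part of the kernel):
+            // for K = 4096 and BlkLen32 = 32, BlockCountK = 128 and BlkDataSizeInBytes16 = 16,
+            // so the two pointer bumps around this comment skip the NCols4 = 4 columns just
+            // processed: 4 * 128 * 16 = 8192 bytes of packed B data and 4 * 128 = 512
+            // per-block scale floats.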
QuantBScaleColPtr += NCols4 * BlockCountK; + + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + +template +void MLAS_FORCEINLINE +Q4Int8GemmR2C1BlkLen32Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen32; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM % NRows2 == 0); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + float* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc0 = _mm512_setzero_ps(), acc1 = _mm512_setzero_ps(); + + size_t k_blks_remaining = BlockCountK; + // process 2 blks of 64 4b weights a time + for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av_11_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + 64)); + + if constexpr (vnni) { + accumulate_blklen32_r2c1blk4_avx512vnni( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1 + ); + } else { + accumulate_blklen32_r2c1blk4_avx512( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1 + ); + } + + // increment block pointers + QuantAPtr += BlkLen32 * PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk4; + QuantBScalePtr += PerAccuBlk4; + } + + __m256 acc20 = h_add_512(acc0); + __m256 acc21 = h_add_512(acc1); + while (k_blks_remaining-- > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantABlk0 + lda)); + + const float& scale_a00 = *QuantAScalePtr; + const float& scale_a10 = *(QuantAScalePtr + BlockCountK); + + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_blklen32_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc20, acc21); + + QuantAPtr += BlkLen32; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes16; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_8(acc20); + *(SumPtr + ldc) = hsum_float_8(acc21); + if (BiasPtr) { + *SumPtr += *BiasPtr; + *(SumPtr + ldc) += *BiasPtr; + 
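+                // Rows m and m + 1 write to the same output column n here, so they share one
+                // bias element.  In scalar terms each of the two sums stored above is roughly
+                //     C[r][n] = sum over k of scale_a[r][k] * scale_b[n][k] * dot(A_blk[r][k], B_blk[n][k])
+                // with dot() the int8 dot product over one 32-element block; the vector code
+                // above evaluates that sum across SIMD lanes before hsum_float_8 collapses it.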
} + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmR1xC4BlkLen32Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen32; + //const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + //const size_t StrideQuantBScale = BlockCountK; + + assert(CountM < NRows2); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4] = { + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps() + }; + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + + if constexpr (vnni) { + accumulate_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + PerAccuBlk4, acc[1]); + accumulate_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk4, acc[2]); + accumulate_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk4, acc[3]); + } else { + accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + PerAccuBlk4, acc[1]); + accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk4, acc[2]); + accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk4, acc[3]); + } + + QuantAPtr += BlkLen32 * PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk4 * NCols4; + QuantBScalePtr += PerAccuBlk4 * NCols4; + } + + __m256 acc2[NCols4] = { + h_add_512(acc[0]), h_add_512(acc[1]), h_add_512(acc[2]), h_add_512(acc[3]) + }; + 
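+            // The 512-bit accumulators are folded to 256 bits so the tail loop below can
+            // reuse the AVX2 one-block helper.  A standalone sketch of the fold (same
+            // intrinsics as h_add_512, assuming only <immintrin.h>):
+            //
+            //     __m256 fold_to_256(__m512 v) {
+            //         return _mm256_add_ps(_mm512_castps512_ps256(v),      // lanes 0..7
+            //                              _mm512_extractf32x8_ps(v, 1));  // lanes 8..15
+            //     }
+            //
+            // Adding the halves lane-wise is fine because each column's result is reduced to
+            // a single scalar later anyway (FoldAccumulators / hsum_float_8).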
+            while (k_blks_remaining-- > 0) {
+                // load A
+                const std::byte* QuantABlk0 = QuantAPtr;
+                const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0);
+
+                const float& scale_a00 = *QuantAScalePtr;
+                {
+                    const float& scale_00 = scale_a00 * (QuantBScalePtr)[0];
+                    accumulate_blklen32_r1c1blk1_avx512<vnni>(av_00_epi8, QuantBDataPtr, scale_00, acc2[0]);
+                }
+                {
+                    const float& scale_00 = scale_a00 * (QuantBScalePtr + 1)[0];
+                    accumulate_blklen32_r1c1blk1_avx512<vnni>(av_00_epi8, QuantBDataPtr + BlkDataSizeInBytes16, scale_00, acc2[1]);
+                }
+                {
+                    const float& scale_00 = scale_a00 * (QuantBScalePtr + 2)[0];
+                    accumulate_blklen32_r1c1blk1_avx512<vnni>(av_00_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes16, scale_00, acc2[2]);
+                }
+                {
+                    const float& scale_00 = scale_a00 * (QuantBScalePtr + 3)[0];
+                    accumulate_blklen32_r1c1blk1_avx512<vnni>(av_00_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes16, scale_00, acc2[3]);
+                }
+
+                QuantAPtr += BlkLen32;
+                QuantAScalePtr++;
+                QuantBDataPtr += BlkDataSizeInBytes16 * NCols4;
+                QuantBScalePtr += NCols4;
+            }
+
+            __m128 acc_r0 = FoldAccumulators(acc2[0], acc2[1], acc2[2], acc2[3]);
+            if (BiasPtr != nullptr) {
+                acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr));
+            }
+
+            _mm_storeu_ps(SumPtr, acc_r0);
+
+            // move to next NCols columns
+            QuantBDataColPtr += NCols4 * BlockCountK * BlkDataSizeInBytes16;
+            QuantBScaleColPtr += NCols4 * BlockCountK;
+            BiasPtr += BiasPtr != nullptr ? NCols4 : 0;
+            SumPtr += NCols4;
+        }
+    }
+}
+
+template <bool vnni>
+MLAS_FORCEINLINE void
+Q4Int8GemmR1xC1BlkLen32Avx512(
+    const std::byte* QuantA,
+    const float* QuantAScale,
+    const std::byte* QuantBData,
+    const float* QuantBScale,
+    float* C,
+    size_t CountM,
+    size_t CountN,
+    size_t BlockCountK,
+    const float* Bias,
+    size_t ldc
+)
+{
+    constexpr size_t BlkLen32 = 32;
+    constexpr size_t BlkBitWidth4 = 4;
+    [[maybe_unused]] constexpr size_t NCols4 = 4;
+    [[maybe_unused]] constexpr size_t NRows2 = 2;
+    constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32);
+
+    // process 4 blks of 32 4b weights a time
+    constexpr size_t PerAccuBlk4 = 4;
+
+    const size_t lda = BlockCountK * BlkLen32;
+    const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32);
+    const size_t StrideQuantBScale = BlockCountK;
+
+    [[maybe_unused]] size_t QuantBZeroPointIdx = 0;  // track half byte increments with this index instead of a pointer
+    assert(CountM < NRows2);
+    assert(CountN < NCols4);
+
+    for (size_t m = 0; m < CountM; m++) {
+        const std::byte* QuantBDataColPtr = QuantBData;
+        const float* QuantBScaleColPtr = QuantBScale;
+        const float* BiasPtr = Bias;
+        auto* SumPtr = C + m * ldc;
+
+        for (size_t n = 0; n < CountN; n++) {
+            const std::byte* QuantAPtr = QuantA + m * lda;
+            const float* QuantAScalePtr = QuantAScale + m * BlockCountK;
+            const std::byte* QuantBDataPtr = QuantBDataColPtr;
+            const float* QuantBScalePtr = QuantBScaleColPtr;
+
+            __m512 acc0 = _mm512_setzero_ps();
+            size_t k_blks_remaining = BlockCountK;
+            for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) {
+                const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr);
+                const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64));
+
+                if constexpr (vnni) {
+                    accumulate_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0);
+                }
+                else {
+                    accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0);
+                }
+
+                QuantAPtr += BlkLen32 *
PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk4; + QuantBScalePtr += PerAccuBlk4; + } + + __m256 acc2 = h_add_512(acc0); + while (k_blks_remaining-- > 0) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + + const float& scale_a00 = *QuantAScalePtr; + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_blklen32_r1c1blk1_avx512(av_00_epi8, QuantBDataPtr, scale_00, acc2); + + QuantAPtr += BlkLen32; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes16; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_8(acc2); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + +template +MLAS_FORCEINLINE +size_t +MlasQ4Int8GemmKernelBlkLen32Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + + const size_t lda = BlockCountK * BlkLen32 * sizeof(int8_t); + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleRows > 0 && multipleCols > 0) { + Q4Int8GemmR2xC4BlkLen32Avx512( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + ldc + ); + } + if (remainingCols > 0 && multipleRows > 0) { + Q4Int8GemmR2C1BlkLen32Avx512( + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + if (remainingRows > 0 && multipleCols > 0) { + Q4Int8GemmR1xC4BlkLen32Avx512( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols, + BlockCountK, + Bias, + ldc); + } + + if (remainingCols > 0 && remainingRows > 0) { + Q4Int8GemmR1xC1BlkLen32Avx512( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleRows * ldc + multipleCols, + remainingRows, + remainingCols, + BlockCountK, + Bias ? 
Bias + multipleCols : nullptr, + ldc); + } + + return CountM; +} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h new file mode 100644 index 0000000000000..2a65ac4af0c1d --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h @@ -0,0 +1,840 @@ +#pragma once +#include +#include +#include + +#include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_avx_common.h" + +static MLAS_FORCEINLINE __m256 +h_add_512(__m512 a) +{ + return _mm256_add_ps(_mm512_castps512_ps256(a), _mm512_extractf32x8_ps(a, 1)); +} + +static MLAS_FORCEINLINE float +hsum_float_16(const __m512 x) +{ + __m256 hi = h_add_512(x); + __m128 hi128 = _mm256_extractf128_ps(hi, 1); + __m128 lo128 = _mm256_castps256_ps128(hi); + hi128 = _mm_add_ps(hi128, lo128); + hi128 = _mm_add_ps(hi128, _mm_movehl_ps(hi128, hi128)); + hi128 = _mm_add_ss(hi128, _mm_movehdup_ps(hi128)); + return _mm_cvtss_f32(hi128); +} + +static MLAS_FORCEINLINE __m512i +combine_two_m256i_to_m512i(const __m256i& a, const __m256i& b) +{ + __m512i result = _mm512_castsi256_si512(a); + result = _mm512_inserti64x4(result, b, 1); + return result; +} + +static MLAS_FORCEINLINE void +load_2blk_4b_packed_blklen64(const std::byte* QuantBDataPtr, __m512i& bv0_64_epi8, __m512i& bv1_64_epi8) +{ + // | v0 v64 | v1 v65 | ... | v62 v126 | v63 v127 | + const __m512i bv_packed = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr)); + const __m512i low_mask = _mm512_set1_epi8(0x0F); + bv0_64_epi8 = _mm512_and_si512(bv_packed, low_mask); // 0~63 + bv1_64_epi8 = _mm512_srli_epi16(_mm512_sub_epi8(bv_packed, bv0_64_epi8), 4); // 64~127 + + //// Extract lower and higher 256 bits from bv0_64_epi8 and bv1_64_epi8 + //__m256i bv0_lower = _mm512_castsi512_si256(bv0_64_epi8_); + //__m256i bv0_higher = _mm512_extracti64x4_epi64(bv0_64_epi8_, 1); + //__m256i bv1_lower = _mm512_castsi512_si256(bv1_64_epi8_); + //__m256i bv1_higher = _mm512_extracti64x4_epi64(bv1_64_epi8_, 1); + + //// Compose new __m512i variables + //bv0_64_epi8 = _mm512_inserti64x4(_mm512_castsi256_si512(bv0_lower), bv1_lower, 1); + //bv1_64_epi8 = _mm512_inserti64x4(_mm512_castsi256_si512(bv0_higher), bv1_higher, 1); +} + +static MLAS_FORCEINLINE __m512i +load_1blk_4b_packed_blklen64(const std::byte* QuantBDataPtr) +{ + // | v0 v32 | v1 v33 | ... 
| v30 v62 | v31 v63 | + const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + const __m256i low_mask = _mm256_set1_epi8(0x0F); + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...30, 31 + __m256i bv1_32_epi8 = _mm256_srli_epi16( + _mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32, 33,...62, 63 + __m512i bv_64_epi8 = combine_two_m256i_to_m512i(bv0_32_epi8, bv1_32_epi8); + return bv_64_epi8; +} + +static MLAS_FORCEINLINE __m512i +horizontal_add_epi32(__m512i a, __m512i b) +{ + __m512i t1 = _mm512_unpacklo_epi32(a, b); + __m512i t2 = _mm512_unpackhi_epi32(a, b); + __m512i sum = _mm512_add_epi32(t1, t2); + return sum; +} + +static MLAS_FORCEINLINE __m512i +generate_ones_32_epi16() +{ + const __m512i zeros = _mm512_setzero_si512(); + return _mm512_srli_epi16(_mm512_ternarylogic_epi64(zeros, zeros, zeros, 1), 15); +} + +static MLAS_FORCEINLINE void +dot_accumulate_2blk( + const __m512i& av0_64_epi8, + const __m512i& av1_64_epi8, + const float* scale_a, + const __m512i& bv0_64_epi8, + const __m512i& bv1_64_epi8, + const __m512& scale_b_16_ps, + //const __m512i& one_32_epi16, + __m512& acc) +{ + __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av0_64_epi8); + __m512i dot1_32_epi16 = _mm512_maddubs_epi16(bv1_64_epi8, av1_64_epi8); + + __m512i t1 = _mm512_unpacklo_epi32(dot0_32_epi16, dot1_32_epi16); + __m512i t2 = _mm512_unpackhi_epi32(dot0_32_epi16, dot1_32_epi16); + __m512i sum_32_epi16 = _mm512_add_epi16(t1, t2); // sum for blk: 0 0 1 1 0 0 1 1... + __m512i one_32_epi16 = generate_ones_32_epi16(); + __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, sum_32_epi16); // sum for blk: 0 1 0 1... + __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + + __m256 scale_a_8_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a)); + __m512 scale_a_16_ps = _mm512_broadcast_f32x8(scale_a_8_ps); + + acc = _mm512_fmadd_ps(sum_16_ps, _mm512_mul_ps(scale_a_16_ps, scale_b_16_ps), acc); +} + +static MLAS_FORCEINLINE void +dot_accumulate_2blkvnni( + const __m512i& av0_64_epi8, + const __m512i& av1_64_epi8, + const float* scale_a, + const __m512i& bv0_64_epi8, + const __m512i& bv1_64_epi8, + const __m512& scale_b_16_ps, + // const __m512i& one_32_epi16, + __m512& acc +) +{ + __m512i dot0_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av0_64_epi8); + __m512i dot1_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv1_64_epi8, av1_64_epi8); + + __m512i t1_16_epi32 = _mm512_unpacklo_epi32(dot0_16_epi32, dot1_16_epi32); + __m512i t2_16_epi32 = _mm512_unpackhi_epi32(dot0_16_epi32, dot1_16_epi32); + __m512i sum_16_epi32 = _mm512_add_epi32(t1_16_epi32, t2_16_epi32); // sum for blk: 0 0 1 1 0 0 1 1... 
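+    // VNNI note: _mm512_dpbusd_epi32 multiplies unsigned bytes (here the unpacked 4-bit
+    // weights, 0..15) by signed bytes (the quantized activations) and accumulates groups of
+    // four products directly into int32 lanes, so the int16 widening step of the non-VNNI
+    // dot_accumulate_2blk above is not needed.  A rough scalar model of one int32 lane
+    // (illustrative only):
+    //
+    //     acc += (int)b0 * (int)a0 + (int)b1 * (int)a1 + (int)b2 * (int)a2 + (int)b3 * (int)a3;
+    //
+    // The non-VNNI path instead uses _mm512_maddubs_epi16 (adjacent pairs into int16) followed
+    // by _mm512_madd_epi16 with all-ones to finish the widening into int32.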
+ __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + + __m256 scale_a_8_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a)); + __m512 scale_a_16_ps = _mm512_broadcast_f32x8(scale_a_8_ps); + + acc = _mm512_fmadd_ps(sum_16_ps, _mm512_mul_ps(scale_a_16_ps, scale_b_16_ps), acc); +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen64_r2c1blk2_avx512( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const __m512i& av10_64_epi8, + const __m512i& av11_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m512& acc0, + __m512& acc1 +) +{ + __m512i bv0_64_epi8, bv1_64_epi8; + load_2blk_4b_packed_blklen64(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); + + const __m256 scale_b_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); + const __m512 scale_b_16_ps = _mm512_broadcast_f32x8(scale_b_ps); + + if constexpr (vnni) { + dot_accumulate_2blkvnni( + av00_64_epi8, av01_64_epi8, scale_a0, + bv0_64_epi8, bv1_64_epi8, scale_b_16_ps, + acc0 + ); + + dot_accumulate_2blkvnni( + av10_64_epi8, av11_64_epi8, scale_a1, + bv0_64_epi8, bv1_64_epi8, scale_b_16_ps, + acc1 + ); + } else { + dot_accumulate_2blk( + av00_64_epi8, av01_64_epi8, scale_a0, + bv0_64_epi8, bv1_64_epi8, scale_b_16_ps, + acc0 + ); + + dot_accumulate_2blk( + av10_64_epi8, av11_64_epi8, scale_a1, + bv0_64_epi8, bv1_64_epi8, scale_b_16_ps, + acc1 + ); + } +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen64_r1c1blk2_avx512( + const __m512i& av0_64_epi8, + const __m512i& av1_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a, + const float* scale_b, + __m512& acc +) +{ + __m512i bv0_64_epi8, bv1_64_epi8; + load_2blk_4b_packed_blklen64(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); + + const __m256 scale_b_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); + const __m512 scale_b_16_ps = _mm512_broadcast_f32x8(scale_b_ps); + + if constexpr (vnni) { + dot_accumulate_2blkvnni( + av0_64_epi8, av1_64_epi8, scale_a, + bv0_64_epi8, bv1_64_epi8, scale_b_16_ps, + acc + ); + } else { + dot_accumulate_2blk( + av0_64_epi8, av1_64_epi8, scale_a, + bv0_64_epi8, bv1_64_epi8, scale_b_16_ps, + acc + ); + } +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen64_r2c1blk1_avx512( + const __m512i& av0_64_epi8, + const __m512i& av1_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m512& acc0, + __m512& acc1 +) +{ + __m512i bv_64_epi8 = load_1blk_4b_packed_blklen64(QuantBDataPtr); + + const __m128 scale_b_ps = _mm_broadcast_ss(scale_b); + const __m512 scale_b_16_ps = _mm512_broadcast_f32x2(scale_b_ps); + + if constexpr (vnni) { + { + __m512i dot_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv_64_epi8, av0_64_epi8); + __m512 sum_16_ps = _mm512_cvtepi32_ps(dot_16_epi32); + + __m128 scale_a0_ps = _mm_broadcast_ss(scale_a0); + __m512 scale_a0_16_ps = _mm512_broadcast_f32x2(scale_a0_ps); + + acc0 = _mm512_fmadd_ps(sum_16_ps, _mm512_mul_ps(scale_a0_16_ps, scale_b_16_ps), acc0); + } + + { + __m512i dot_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv_64_epi8, av1_64_epi8); + __m512 sum_16_ps = _mm512_cvtepi32_ps(dot_16_epi32); + + __m128 scale_a1_ps = _mm_broadcast_ss(scale_a1); + __m512 scale_a1_16_ps = _mm512_broadcast_f32x2(scale_a1_ps); + + acc1 = _mm512_fmadd_ps(sum_16_ps, _mm512_mul_ps(scale_a1_16_ps, scale_b_16_ps), acc1); + } + } else { + const __m512i zeros = _mm512_setzero_si512(); + // const __m512i one_32_epi16_ = 
_mm512_andnot_epi32(zeros, zeros); + // const __m512i one_32_epi16 = _mm512_srli_epi16(_mm512_andnot_epi32(zeros, zeros), 15); + + const __m512i one_32_epi16 = _mm512_srli_epi16(_mm512_ternarylogic_epi32(zeros, zeros, zeros, 1), 15); + { + __m512i dot_32_epi16 = _mm512_maddubs_epi16(bv_64_epi8, av0_64_epi8); + __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot_32_epi16); + + __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + + __m128 scale_a0_ps = _mm_broadcast_ss(scale_a0); + __m512 scale_a0_16_ps = _mm512_broadcast_f32x2(scale_a0_ps); + + acc0 = _mm512_fmadd_ps(sum_16_ps, _mm512_mul_ps(scale_a0_16_ps, scale_b_16_ps), acc0); + } + + { + __m512i dot_32_epi16 = _mm512_maddubs_epi16(bv_64_epi8, av1_64_epi8); + __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot_32_epi16); + __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + + __m128 scale_a1_ps = _mm_broadcast_ss(scale_a1); + __m512 scale_a1_16_ps = _mm512_broadcast_f32x2(scale_a1_ps); + + acc1 = _mm512_fmadd_ps(sum_16_ps, _mm512_mul_ps(scale_a1_16_ps, scale_b_16_ps), acc1); + } + } +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen64_r1c1blk1_avx512( + const __m512i& av_32_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a, + const float* scale_b, + __m512& acc +) +{ + __m512i bv_64_epi8 = load_1blk_4b_packed_blklen64(QuantBDataPtr); + + const __m128 scale_b_ps = _mm_broadcast_ss(scale_b); + const __m512 scale_b_16_ps = _mm512_broadcast_f32x2(scale_b_ps); + + if constexpr (vnni) { + __m512i dot_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv_64_epi8, av_32_epi8); + __m512 sum_16_ps = _mm512_cvtepi32_ps(dot_16_epi32); + + __m128 scale_a_ps = _mm_broadcast_ss(scale_a); + __m512 scale_a_16_ps = _mm512_broadcast_f32x2(scale_a_ps); + + acc = _mm512_fmadd_ps(sum_16_ps, _mm512_mul_ps(scale_a_16_ps, scale_b_16_ps), acc); + } else { + const __m512i one_32_epi16 = _mm512_set1_epi16(1); + + __m512i dot_32_epi16 = _mm512_maddubs_epi16(bv_64_epi8, av_32_epi8); + __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot_32_epi16); + + __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + + __m128 scale_a_ps = _mm_broadcast_ss(scale_a); + __m512 scale_a_16_ps = _mm512_broadcast_f32x2(scale_a_ps); + + acc = _mm512_fmadd_ps(sum_16_ps, _mm512_mul_ps(scale_a_16_ps, scale_b_16_ps), acc); + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmR2xC4BlkLen64Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen64 = 64; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen64); + + // process 2 blks of 128 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t lda = BlockCountK * BlkLen64; + const size_t StrideQuantBData = PerAccuBlk2 * BlkDataSizeInBytes; + //const size_t StrideQuantBScale = BlockCountK; + + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr 
= QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4 * NRows2] = { + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps() + }; + + size_t k_blks_remaining = BlockCountK; + // process 2 blks of 128 4b weights a time + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av_11_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + 64)); + + accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); + accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + PerAccuBlk2, acc[1], acc[NCols4 + 1]); + accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * PerAccuBlk2, acc[2], acc[NCols4 + 2]); + accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * PerAccuBlk2, acc[3], acc[NCols4 + 3]); + + // increment block pointers + QuantAPtr += BlkLen64 * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBDataPtr += StrideQuantBData * NCols4; + QuantBScalePtr += PerAccuBlk2 * NCols4; + } // k_blks_remaining + + while (k_blks_remaining-- > 0) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + + accumulate_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, + QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); + accumulate_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, + QuantBDataPtr + BlkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 1, acc[1], acc[NCols4 + 1]); + accumulate_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, + QuantBDataPtr + 2 * BlkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2, acc[2], acc[NCols4 + 2]); + accumulate_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, + QuantBDataPtr + 3 * BlkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3, acc[3], acc[NCols4 + 3]); + + QuantAPtr += BlkLen64; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes * NCols4; + QuantBScalePtr += NCols4; + } + +#if 1 + *SumPtr = _mm512_reduce_add_ps(acc[0]); + *(SumPtr + 1) = _mm512_reduce_add_ps(acc[1]); + *(SumPtr + 2) = _mm512_reduce_add_ps(acc[2]); + *(SumPtr + 3) = _mm512_reduce_add_ps(acc[3]); + *(SumPtr + ldc) = _mm512_reduce_add_ps(acc[NCols4]); + *(SumPtr + ldc + 1) = _mm512_reduce_add_ps(acc[NCols4 + 1]); + *(SumPtr + ldc + 2) = _mm512_reduce_add_ps(acc[NCols4 + 2]); + *(SumPtr + ldc + 3) = _mm512_reduce_add_ps(acc[NCols4 + 3]); + if (BiasPtr != nullptr) { + *SumPtr += *BiasPtr; + *(SumPtr + 1) += *(BiasPtr + 1); + *(SumPtr + 2) += *(BiasPtr + 2); + *(SumPtr + 3) += *(BiasPtr + 3); + *(SumPtr + ldc) += *BiasPtr; + 
*(SumPtr + ldc + 1) += *(BiasPtr + 1); + *(SumPtr + ldc + 2) += *(BiasPtr + 2); + *(SumPtr + ldc + 3) += *(BiasPtr + 3); + } +#else + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + __m128 acc_r1 = FoldAccumulators(acc[NCols4 + 0], acc[NCols4 + 1], acc[NCols4 + 2], acc[NCols4 + 3]); + if (BiasPtr != nullptr) { + const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); + acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); + acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); + } + _mm_storeu_ps(SumPtr, acc_r0); + _mm_storeu_ps(SumPtr + ldc, acc_r1); +#endif + // move to next NCols columns + QuantBDataColPtr += NCols4 * BlockCountK * BlkDataSizeInBytes; + QuantBScaleColPtr += NCols4 * BlockCountK; + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + +template +void MLAS_FORCEINLINE +Q4Int8GemmR2xC1BlkLen64Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkLen64 = 64; + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + + // process 2 blks of 128 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + //assert(CountM % NRows2 == 0); + //assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + float* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc0 = _mm512_setzero_ps(), acc1 = _mm512_setzero_ps(); + + size_t k_blks_remaining = BlockCountK; + // process 2 blks of 128 4b weights a time + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av_11_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + 64)); + + accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); + + // increment block pointers + QuantAPtr += BlkLen64 * PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + } + + while (k_blks_remaining-- > 0) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + + accumulate_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); + + QuantAPtr += BlkLen64; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_16(acc0); + 
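+            // hsum_float_16 (defined near the top of this header) reduces all 16 float lanes
+            // to one scalar: h_add_512 folds 512 -> 256 bits, then the 128-bit halves are
+            // added and two shuffles finish the sum.  A commented sketch of that last stage,
+            // with v8 as a placeholder __m256 (assuming <immintrin.h>):
+            //
+            //     __m128 s = _mm_add_ps(_mm256_castps256_ps128(v8), _mm256_extractf128_ps(v8, 1));
+            //     s = _mm_add_ps(s, _mm_movehl_ps(s, s));    // {0+2, 1+3, ...}
+            //     s = _mm_add_ss(s, _mm_movehdup_ps(s));     // lane 0 + lane 1
+            //     float sum = _mm_cvtss_f32(s);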
*(SumPtr + ldc) = hsum_float_16(acc1); + if (BiasPtr) { + *SumPtr += *BiasPtr; + *(SumPtr + ldc) += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmR1xC4BlkLen64Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkLen64 = 64; + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + + // process 2 blks of 128 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t lda = BlockCountK * BlkLen; + //const size_t StrideQuantBData = PerAccuBlk2 * BlkDataSizeInBytes; + //const size_t StrideQuantBScale = BlockCountK; + + //assert(CountM < NRows2); + //assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4] = {_mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps()}; + size_t k_blks_remaining = BlockCountK; + // process 2 blks of 128 4b weights a time + for (; k_blks_remaining >= PerAccuBlk2; k_blks_remaining -= PerAccuBlk2) { + const __m512i av0_64_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av1_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + accumulate_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + PerAccuBlk2 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + PerAccuBlk2, acc[1]); + accumulate_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + 2 * PerAccuBlk2 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk2, acc[2]); + accumulate_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + 3 * PerAccuBlk2 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk2, acc[3]); + + // increment block pointers + QuantAPtr += BlkLen64 * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBDataPtr += PerAccuBlk2 * BlkDataSizeInBytes * NCols4; + QuantBScalePtr += PerAccuBlk2 * NCols4; + } + + while (k_blks_remaining-- > 0) { + const __m512i av_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + + accumulate_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr + BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 1, acc[1]); + accumulate_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 2, acc[2]); + accumulate_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 3, acc[3]); + + QuantAPtr += BlkLen64; + 
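+                // With BlkLen64 a single 512-bit load covers exactly one quantization block:
+                // 64 int8 activations (64 bytes of A) against 64 4-bit weights (32 bytes of
+                // packed B, since BlkDataSizeInBytes = 64 * 4 / 8 = 32).  The tail loop
+                // therefore steps A by 64 bytes per block here, and the scale / packed-B
+                // pointers by one block per column (NCols4 blocks in total) just below.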
QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes * NCols4; + QuantBScalePtr += NCols4; + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + + _mm_storeu_ps(SumPtr, acc_r0); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * BlockCountK * BlkDataSizeInBytes; + QuantBScaleColPtr += NCols4 * BlockCountK; + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmR1xC1BlkLen64Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkLen64 = 64; + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + + // process 2 blks of 128 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + //assert(CountM < NRows2); + //assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc0 = _mm512_setzero_ps(); + size_t k_blks_remaining = BlockCountK; + // process 2 blks of 128 4b weights a time + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + + accumulate_blklen64_r1c1blk2_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + + // increment block pointers + QuantAPtr += BlkLen64 * PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + } + + while (k_blks_remaining-- > 0) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + + accumulate_blklen64_r1c1blk1_avx512(av_00_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + + QuantAPtr += BlkLen64; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_16(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } + } +} + +template +MLAS_FORCEINLINE size_t +MlasQ4Int8GemmKernelBlkLen64Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + + const size_t lda = BlockCountK * BlkLen * sizeof(int8_t); + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleRows > 0 && multipleCols > 0) { + if (NRows2 == 2) + Q4Int8GemmR2xC4BlkLen64Avx512( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + ldc + ); + else + Q4Int8GemmR1xC4BlkLen64Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + ldc + ); + } + if (remainingCols > 0 && multipleRows > 0) { + if (NRows2 == 2) + Q4Int8GemmR2xC1BlkLen64Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + else + Q4Int8GemmR1xC1BlkLen64Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc + ); + } + + if (remainingRows > 0 && multipleCols > 0) { + Q4Int8GemmR1xC4BlkLen64Avx512( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols, + BlockCountK, + Bias, + ldc); + } + if (remainingCols > 0 && remainingRows > 0) { + Q4Int8GemmR1xC1BlkLen64Avx512( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleRows * ldc + multipleCols, + remainingRows, + remainingCols, + BlockCountK, + Bias ? 
Bias + multipleCols : nullptr, + ldc); + } + + return CountM; +} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp index 6477a2019b21a..6a5c01162c51b 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp @@ -23,6 +23,10 @@ Module Name: #include "sqnbitgemm_kernel_avx_common.h" #include "sqnbitgemm_kernel_avx_common_fp32.h" #include "sqnbitgemm_kernel_avx_common_int8.h" +#include "sqnbitgemm_kernel_avx512_int8_blklen16.h" +#include "sqnbitgemm_kernel_avx512_int8_blklen32.h" +#include "sqnbitgemm_kernel_avx512_int8_blklen64.h" +#include "sqnbitgemm_kernel_avx512_int8_blklen128.h" MLAS_FORCEINLINE void SQ4BitGemmM1Kernel_CompFp32( @@ -146,6 +150,7 @@ void SQ4BitGemmM1Kernel_CompInt8_avx512vnni( size_t BlkLen, const std::byte* QuantA, + const float* QuantAScale, const std::byte* QuantBData, const float* QuantBScale, const std::byte* QuantBZeroPoint, @@ -157,44 +162,7 @@ SQ4BitGemmM1Kernel_CompInt8_avx512vnni( ) { if (QuantBZeroPoint != nullptr) { - constexpr bool HasZeroPoint = true; - if (BlkLen == 16) { - SQ4BitGemmM1Kernel_BlkLen16_CompInt8_Impl( - QuantA, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - CountK, - BlockStrideQuantB, - Bias - ); - } else if (BlkLen == 32) { - SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl>( - QuantA, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - BlockStrideQuantB, - Bias - ); - } else { - SQ4BitGemmM1Kernel_BlkLen64Plus_CompInt8_Impl( - BlkLen, - QuantA, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - CountK, - BlockStrideQuantB, - Bias - ); - } + assert(false); } else { constexpr bool HasZeroPoint = false; if (BlkLen == 16) { @@ -212,6 +180,7 @@ SQ4BitGemmM1Kernel_CompInt8_avx512vnni( } else if (BlkLen == 32) { SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl>( QuantA, + QuantAScale, QuantBData, QuantBScale, QuantBZeroPoint, @@ -237,52 +206,134 @@ SQ4BitGemmM1Kernel_CompInt8_avx512vnni( } } +MLAS_FORCEINLINE size_t -SQ4BitGemmKernel_CompInt8_avx512vnni( - size_t BlkLen, +SQ4BitGemmKernel_BlkSum_CompInt8_avx512vnni( + const size_t BlkLen, const std::byte* QuantA, + const float* QuantAScale, const std::byte* QuantBData, const float* QuantBScale, - const std::byte* QuantBZeroPoint, + const std::byte* /*QuantBZeroPoint*/, float* C, size_t CountM, size_t CountN, - size_t CountK, + size_t /*CountK*/, size_t BlockCountK, + const float* Bias, size_t ldc, - const float* Bias + const float* ABlockSum, + const float* QuantBBlkSum ) { - MLAS_UNREFERENCED_PARAMETER(ldc); - - if (CountM == 0) { - return 0; + if (BlkLen == 16) { + MlasQ4Int8GemmKernelBlkLen16Avx512( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } else if (BlkLen == 32) { + MlasQ4Int8GemmKernelBlkLen32Avx512( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } else if (BlkLen == 64) { + MlasQ4Int8GemmKernelBlkLen64Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } else { + MlasQ4Int8GemmKernelBlkLen128Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); } - SQ4BitGemmM1Kernel_CompInt8_avx512vnni( - BlkLen, - QuantA, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - CountK, - BlockCountK, - Bias - ); 
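+ // The int8 kernels above produce only the Sum_blk(scale_a * scale_b * Sum(a_i * b_i)) part of each output.
+ // The remaining correction term, Sum_blks(ABlockSum * QuantBBlkSum) with ABlockSum = scale_a * Sum_blk(a_i)
+ // and QuantBBlkSum = -scale_b * zero_point, is added below via the platform float GEMM kernel so that it
+ // accumulates on top of the partial sums already written to C.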
+ float* c_blk = C; + const float* b_blk_sum = QuantBBlkSum; - return 1; + size_t RowsRemaining = CountM; + const float* a_blksum_row = ABlockSum; + while (RowsRemaining > 0) { + auto RowsHandled = GetMlasPlatform().GemmFloatKernel( + a_blksum_row, b_blk_sum, c_blk, BlockCountK, RowsRemaining, CountN, BlockCountK, ldc, 1.f, false + ); + + c_blk += ldc * RowsHandled; + a_blksum_row += BlockCountK * RowsHandled; + RowsRemaining -= RowsHandled; + } + return CountM; } void MLASCALL -MlasQ80BlkQuantRow_avx512( +QuantizeARow_CompInt8_avx512( size_t BlkLen, const float* A, size_t CountK, - std::byte* QuantA + std::byte* QuantA, + float* QuantAScale, + float* AScaledBlkSum // scale_k * Sum_blklen(a_i) ); +static void +SQ4BitGemmPackQuantBDataAndBlkSum512vnni( + size_t N, + size_t K, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + const float* QuantBScaleBegin, + bool has_zp_input, + const std::byte* QuantBZPBegin, + PackedQuantBDataStruct& packed_quant_b, + MLAS_THREADPOOL* ThreadPool +) +{ + assert(BlkLen >= 16 && BlkLen % 16 == 0); + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + + size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 32 : 64); + if (ComputeType == CompInt8) { + SubBlkLen = 128; + } + PackQuantBDataAndBlkSum(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, QuantBScaleBegin, has_zp_input, QuantBZPBegin, packed_quant_b, ThreadPool); +} + // // Kernel dispatch structure definition. // @@ -291,6 +342,7 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni = []() { d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; + d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum512vnni; d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment; @@ -298,8 +350,8 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni = []() { d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32; d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; - d.SQ4BitGemmKernel_CompInt8 = SQ4BitGemmKernel_CompInt8_avx512vnni; - d.QuantizeARow_CompInt8 = MlasQ80BlkQuantRow_avx512; + d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx512vnni; + d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx512; return d; }(); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h index 706e08fc467ba..177f5518bb891 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h @@ -14,13 +14,24 @@ SQ4BitGemmPackQuantBDataSize( MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - MLAS_UNREFERENCED_PARAMETER(ComputeType); // same size regardless of ComputeType - constexpr size_t BlkBitWidth = 4; - const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - return PackedQuantBDataSize; + if (ComputeType == CompInt8) { + size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t ScaleSize = N * BlockCountK * sizeof(float); + size_t BlkSumSize = MlasDivRoundup(N, 16) * BlockCountK * 16 * sizeof(float); + + // _mm256_load_si256 requires alignment on a 32-byte boundary + constexpr size_t PackedQuantBDataAlignment = 32; + 
PackedQuantBDataSize += PackedQuantBDataAlignment - 1; + constexpr size_t BlkSumAlignment = MlasQNBitQuantBBlkSumAlignment(); + BlkSumSize += BlkSumAlignment - 1; + + return PackedQuantBDataSize + ScaleSize + BlkSumSize; + } else { + const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + return PackedQuantBDataSize; + } } static void @@ -100,6 +111,216 @@ SQ4BitGemmPackQuantBData( ); } +static size_t +GetContinueLayoutOffsetSubBlk(size_t N, const size_t n, const size_t SubOrBlkCountK, const size_t k_sub_or_blk) +{ + size_t T = n / 4, t = n % 4; + bool te = T == N / 4; + size_t scale_dst_offset = T * 4 * SubOrBlkCountK; + if (te) { + scale_dst_offset += t * SubOrBlkCountK + k_sub_or_blk; + } else { + scale_dst_offset += k_sub_or_blk * 4 + t; + } + return scale_dst_offset; +} + +static size_t +GetContinueLayoutOffsetBlkInSubBlk(size_t N, const size_t n, const size_t BlockCountK, const size_t k_blk, const int blks_per_sub) +{ + size_t T = n / 4, t = n % 4, k_subblk = k_blk / blks_per_sub, b = k_blk % blks_per_sub; + bool te = T == N / 4, be = k_subblk == BlockCountK / blks_per_sub; + size_t scale_dst_offset = T * 4 * BlockCountK; + if (te) { + scale_dst_offset += t * BlockCountK + k_blk; + } else { + scale_dst_offset += k_subblk * blks_per_sub * 4; + if (be) { + scale_dst_offset += b * 4 + t; + } else { + scale_dst_offset += t * blks_per_sub + b; + } + } + return scale_dst_offset; +} + +static void +PackQuantB( + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool, + const size_t N, + const size_t BlockCountK, + const size_t BlkLen, + const size_t SubBlkLen) +{ + constexpr size_t BlkBitWidth = 4; + const size_t BlkBytePairCount = BlkLen / 4; + const size_t BlkDataSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + + const size_t SubBlkDataSize = SubBlkLen / 2; + const size_t SubBlkBytePairCount = SubBlkLen / 4; + const size_t SubBlkCountK = MlasDivRoundup(BlockCountK * BlkLen, SubBlkLen); + const size_t Iterations = N * SubBlkCountK; // one iteration per sub block + + // for avx2 + // dst: | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 | + // for the remaining blk, it shall be: + // dst blklen32: | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | + // dst blklen16: | v0 v8 | v1 v9 | v2 v11 | v3 v12 | v4 v13 | v5 v14 | v6 v15 | v7 v16 | + + // for avx512 + // dst: | v0 v64 | v1 v65 | ... | v62 v126 | v63 v127 | + // for the remaining blk, it shall be: + // dst blklen64: | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 | + // dst blklen32: | v0 v16 | v1 v17 | ... 
| v14 v30 | v15 v31 | + // dst blklen16: | v0 v8 | v1 v9 | v2 v10 | v3 v11 | v4 v12 | v5 v13 | v6 v14 | v7 v15 | + MlasTrySimpleParallel( + ThreadPool, Iterations, + [&](ptrdiff_t tid) { + const size_t n = tid / SubBlkCountK; + const size_t k_subblk = tid % SubBlkCountK; + + const size_t src_data_offset = n * BlockCountK * BlkDataSize + k_subblk * SubBlkDataSize; + const std::byte* QuantBData = QuantBDataBegin + src_data_offset; + + size_t PackBytePairCount = SubBlkBytePairCount; + size_t PackDataSize = SubBlkDataSize; + + auto pack_subblk = []( + const std::byte* QuantBData, std::byte* PackedQuantBData, + size_t pack_byte_pair_count, size_t pack_data_size) { + for (size_t byte_pair_idx = 0; byte_pair_idx < pack_byte_pair_count; ++byte_pair_idx) { + const std::byte src0 = QuantBData[byte_pair_idx]; + const std::byte src1 = QuantBData[byte_pair_idx + pack_data_size / 2]; + + std::byte& dst0 = PackedQuantBData[2 * byte_pair_idx]; + std::byte& dst1 = PackedQuantBData[2 * byte_pair_idx + 1]; + + dst0 = (src0 & std::byte{0x0F}) | ((src1 & std::byte{0x0F}) << 4); + dst1 = (src0 >> 4) | ((src1 >> 4) << 4); + } }; + + if (SubBlkLen > BlkLen && k_subblk == SubBlkCountK - 1 && + SubBlkLen * SubBlkCountK > BlkLen * BlockCountK) { + // this is the last subblk of the column. check whether it extends beyond + // BlockCountK. If it does, pack per block so that the kernel can compute + // on each block instead of on each subblk. + PackBytePairCount = BlkBytePairCount; + PackDataSize = BlkDataSize; + const size_t k_blks_remaining = BlockCountK - (SubBlkCountK - 1) * SubBlkLen / BlkLen; + for (size_t k = 0; k < k_blks_remaining; k++) { + const size_t k_blk = k_subblk * SubBlkLen / BlkLen + k; + if (BlkLen == 16) { + // the compute order layout is not applied for BlkLen 16 yet + std::byte* PackedQuantBData = PackedQuantBDataBegin + src_data_offset; + pack_subblk(QuantBData + k * BlkLen / 2, PackedQuantBData + k * BlkLen / 2, PackBytePairCount, PackDataSize); + } else if (BlkLen >= SubBlkLen) { + // shall not reach here with avx2 + assert(SubBlkLen == 128); + } else { + int blks_per_sub = (int)(SubBlkLen / BlkLen); + const size_t dst_data_offset = GetContinueLayoutOffsetBlkInSubBlk(N, n, BlockCountK, k_blk, blks_per_sub); + std::byte* PackedQuantBData = PackedQuantBDataBegin + dst_data_offset * BlkLen / 2; + pack_subblk(QuantBData + k * BlkLen / 2, PackedQuantBData, PackBytePairCount, PackDataSize); + } + } + } else { + if (BlkLen == 16) { + // the compute order layout is not applied for BlkLen 16 yet + std::byte* PackedQuantBData = PackedQuantBDataBegin + src_data_offset; + pack_subblk(QuantBData, PackedQuantBData, PackBytePairCount, PackDataSize); + } else if (BlkLen >= SubBlkLen) { + const size_t dst_data_offset = GetContinueLayoutOffsetSubBlk(N, n, SubBlkCountK, k_subblk); + std::byte* PackedQuantBData = PackedQuantBDataBegin + dst_data_offset * SubBlkDataSize; + pack_subblk(QuantBData, PackedQuantBData, PackBytePairCount, PackDataSize); + } else { + int blks_per_sub = (int)(SubBlkLen / BlkLen); + const size_t k_blk = k_subblk * blks_per_sub; + const size_t dst_data_offset = GetContinueLayoutOffsetBlkInSubBlk(N, n, BlockCountK, k_blk, blks_per_sub); + std::byte* PackedQuantBData = PackedQuantBDataBegin + dst_data_offset * BlkLen / 2; + pack_subblk(QuantBData, PackedQuantBData, PackBytePairCount, PackDataSize); + } + } + } + ); +} + +//#include + +static void +ComputePackBlkSum( + size_t BlkLen, + size_t SubBlkLen, + size_t N, + float* QuantBScaleBegin, + const std::byte* QuantBZPBegin, + float* BlockSumBegin, + MLAS_THREADPOOL* ThreadPool, + const
size_t BlockCountK) +{ + std::vector QuantBScaleBeginCopy(N * BlockCountK); + std::copy(QuantBScaleBegin, QuantBScaleBegin + N * BlockCountK, QuantBScaleBeginCopy.begin()); + MlasTrySimpleParallel(ThreadPool, N * BlockCountK, [&](ptrdiff_t tid) { + const size_t n = tid / BlockCountK; + const size_t k_blk = tid % BlockCountK; + + const size_t src_blk_offset = n * BlockCountK + k_blk; + const float& QuantBScale = QuantBScaleBeginCopy[src_blk_offset]; + uint8_t zp = 8; + if (QuantBZPBegin) { + size_t ZPCountK = MlasDivRoundup(BlockCountK, 2); + size_t src_zp_offset = ZPCountK * n + k_blk / 2; + bool low_zp = k_blk % 2 == 0; + const std::byte* QuantBZP = QuantBZPBegin + src_zp_offset; + const std::byte low_mask{0X0F}; + zp = (uint8_t)(low_zp ? ((*QuantBZP) & low_mask) : ((*QuantBZP) >> 4)); + } + + // BlockSum is a width 16 row major matrix + const size_t dst_offset = ((n / 16) * BlockCountK + k_blk) * 16 + n % 16; + *(BlockSumBegin + dst_offset) = -QuantBScale * zp; + if (BlkLen == 16) { // TODO + + } else if (BlkLen >= SubBlkLen) { + const size_t scale_dst_offset = GetContinueLayoutOffsetSubBlk(N, n, BlockCountK, k_blk); + *(QuantBScaleBegin + scale_dst_offset) = QuantBScale; + } else { + int blks_per_sub = (int)(SubBlkLen / BlkLen); + size_t scale_dst_offset = GetContinueLayoutOffsetBlkInSubBlk(N, n, BlockCountK, k_blk, blks_per_sub); + *(QuantBScaleBegin + scale_dst_offset) = QuantBScale; + } + } + ); +} + +static void +PackQuantBDataAndBlkSum( + size_t N, + size_t BlockCountK, + size_t BlkLen, + size_t SubBlkLen, + const std::byte* QuantBDataBegin, + const float* QuantBScaleBegin, + bool has_zp_input, + const std::byte* QuantBZPBegin, + PackedQuantBDataStruct& packed_quant_b, + MLAS_THREADPOOL* ThreadPool +) +{ + if (QuantBDataBegin) { + PackQuantB(QuantBDataBegin, packed_quant_b.PackedQuantBData, ThreadPool, N, BlockCountK, BlkLen, SubBlkLen); + } + + if (QuantBScaleBegin) { + std::copy(QuantBScaleBegin, QuantBScaleBegin + N * BlockCountK, packed_quant_b.PackedQuantBScale); + } + + if ((QuantBScaleBegin && !has_zp_input) || QuantBZPBegin) { + ComputePackBlkSum(BlkLen, SubBlkLen, N, packed_quant_b.PackedQuantBScale, QuantBZPBegin, packed_quant_b.QuantBBlkSum, ThreadPool, BlockCountK); + } +} + // // Workspace size calculation function implementation. 
// @@ -119,7 +340,8 @@ SQ4BitGemmPerGemmWorkspaceSize( case CompInt8: { // workspace buffer is used for block quantization of A to int8 const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - const size_t PerGemmWorkspaceSize = M * BlockCountK * Q8BlkSize(BlkLen); + // QuantData + Scale + BlkSum + const size_t PerGemmWorkspaceSize = M * BlockCountK * (Q8BlkSize(BlkLen) + sizeof(float)); return PerGemmWorkspaceSize; } default: { @@ -288,6 +510,20 @@ load_and_mul_sum_s8_quads_with_zp_avx2( acc0 = _mm256_fmadd_ps(sum_ps, scale0, acc0); } +template +void MLAS_FORCEINLINE +get_2_zps(const std::byte* QuantBZeroPointPtr, int8_t& zp0, int8_t& zp1) +{ + if constexpr (HasZeroPoint) { + zp0 = std::to_integer((*QuantBZeroPointPtr) & std::byte{0x0F}); + zp1 = std::to_integer((*QuantBZeroPointPtr) >> 4); + } else { + zp0 = 8; + zp1 = 8; + (void)QuantBZeroPointPtr; + } +} + template int8_t MLAS_FORCEINLINE get_zp(bool is_lower_half_byte_zp, const std::byte* QuantBZeroPointPtr) @@ -375,7 +611,7 @@ FoldAccumulators(const __m256& acc0, const __m256& acc1, const __m256& acc2, con return acc_y; } -static inline float +static MLAS_FORCEINLINE float hsum_float_8(const __m256 x) { __m128 res = _mm256_extractf128_ps(x, 1); @@ -417,4 +653,27 @@ FoldAccumulators(const __m512& acc0, const __m512& acc1, const __m512& acc2, con _mm256_add_ps(_mm512_extractf32x8_ps(acc_lo0123, 0), _mm512_extractf32x8_ps(acc_lo0123, 1)); return _mm_add_ps(_mm256_extractf32x4_ps(acc_y, 0), _mm256_extractf32x4_ps(acc_y, 1)); } + +static MLAS_FORCEINLINE __m128i +convert_2_ps_to_epi8(__m256 v0, __m256 v1) +{ + __m256i v0_8_epi32 = _mm256_cvtps_epi32(v0); + __m256i v1_8_epi32 = _mm256_cvtps_epi32(v1); + + __m128i v0_8_epi16 = _mm_packs_epi32(_mm256_extractf128_si256(v0_8_epi32, 0), _mm256_extractf128_si256(v0_8_epi32, 1)); + __m128i v1_8_epi16 = _mm_packs_epi32(_mm256_extractf128_si256(v1_8_epi32, 0), _mm256_extractf128_si256(v1_8_epi32, 1)); + + return _mm_packs_epi16(v0_8_epi16, v1_8_epi16); +} + +// horizontally add 8 int32_t +static MLAS_FORCEINLINE int +hsum_8_epi32(const __m256i a_8_epi32) +{ + const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a_8_epi32), _mm256_extractf128_si256(a_8_epi32, 1)); + const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128); + const __m128i sum64 = _mm_add_epi32(hi64, sum128); + const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); + return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); +} } // namespace diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_int8.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_int8.h index 250ffeacd7c2f..895ce6cd091c2 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_int8.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_int8.h @@ -7,20 +7,6 @@ #include "sqnbitgemm_kernel_avx_common.h" #include "sqnbitgemm_q8_block.h" -void -SQ4BitGemmM1Kernel_CompInt8_avx2( - size_t BlkLen, - const std::byte* QuantA, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountN, - size_t CountK, - size_t BlockStrideQuantB, - const float* Bias -); - template MLAS_FORCEINLINE void ComputeDotProducts_BlkBitWidth4_CompInt8_SubBlkLen16( @@ -240,6 +226,7 @@ template accumulator> void SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl( const std::byte* QuantA, + const float* QuantAScale, const std::byte* QuantBData, const float* QuantBScale, const std::byte* QuantBZeroPoint, @@ -273,6 +260,7 @@ SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl( int64_t nblk = 
(int64_t)(CountN)-4; while (nblk >= 0) { const std::byte* QuantAPtr = QuantA; + const float* QuantAScalePtr = QuantAScale; const std::byte* QuantBDataPtr = QuantBDataColPtr; const float* QuantBScalePtr = QuantBScaleColPtr; const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; @@ -286,14 +274,14 @@ SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl( size_t k_blks_remaining = BlockCountK; for (; k_blks_remaining > 1; k_blks_remaining -= 2) { const std::byte* QuantABlk0 = QuantAPtr; - const std::byte* QuantABlk1 = QuantABlk0 + Q8BlkSize(BlkLen); + const std::byte* QuantABlk1 = QuantABlk0 + BlkLen; // load A: - const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); - const __m256i av_1_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk1)); + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); + const __m256i av_1_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk1); - const float& scale_a0 = Q8BlkScale(QuantABlk0); - const float& scale_a1 = Q8BlkScale(QuantABlk1); + const float& scale_a0 = *QuantAScalePtr; + const float& scale_a1 = *(QuantAScalePtr + 1); // Col0 const float& scale_00 = scale_a0 * QuantBScalePtr[0]; @@ -320,7 +308,8 @@ SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl( accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, false, scale_31, acc3); // increment block pointers - QuantAPtr += Q8BlkSize(BlkLen) * 2; + QuantAPtr += BlkLen * 2; + QuantAScalePtr += 2; QuantBDataPtr += 16 * 2; QuantBScalePtr += 2; if constexpr (HasZeroPoint) { @@ -331,9 +320,9 @@ SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl( if (k_blks_remaining > 0) { // load A const std::byte* QuantABlk0 = QuantAPtr; - const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); - const float& scale_a0 = Q8BlkScale(QuantABlk0); + const float& scale_a0 = *QuantAScalePtr; // Col0 const float& scale_00 = scale_a0 * QuantBScalePtr[0]; @@ -374,6 +363,7 @@ SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl( nblk += NCols; for (int64_t n = 0; n < nblk; n++) { const std::byte* QuantAPtr = QuantA; + const float* QuantAScalePtr = QuantAScale; const std::byte* QuantBDataPtr = QuantBDataColPtr; const float* QuantBScalePtr = QuantBScaleColPtr; const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; @@ -383,14 +373,14 @@ SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl( size_t k_blks_remaining = BlockCountK; for (; k_blks_remaining > 1; k_blks_remaining -= 2) { const std::byte* QuantABlk0 = QuantAPtr; - const std::byte* QuantABlk1 = QuantABlk0 + Q8BlkSize(BlkLen); + const std::byte* QuantABlk1 = QuantABlk0 + BlkLen; // load A: - const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); - const __m256i av_1_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk1)); + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); + const __m256i av_1_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk1); - const float& scale_a0 = Q8BlkScale(QuantABlk0); - const float& scale_a1 = Q8BlkScale(QuantABlk1); + const float& scale_a0 = *QuantAScalePtr; + const float& scale_a1 = *(QuantAScalePtr + 1); // Col0 const float& scale_00 = scale_a0 * QuantBScalePtr[0]; @@ -399,7 +389,8 @@ SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl( accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 16), low_mask, zero, QuantBZeroPointPtr, false, scale_01, acc0); // increment block pointers 
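+ // A now carries raw int8 block data with its scales in the separate QuantAScale array, so the data pointer advances by BlkLen and the scale pointer advances on its own, instead of stepping over interleaved Q8 blocks.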
- QuantAPtr += Q8BlkSize(BlkLen) * 2; + QuantAPtr += BlkLen * 2; + QuantAScalePtr += 2; QuantBDataPtr += 16 * 2; QuantBScalePtr += 2; if constexpr (HasZeroPoint) { @@ -410,9 +401,9 @@ SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl( if (k_blks_remaining > 0) { // load A const std::byte* QuantABlk0 = QuantAPtr; - const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); - const float& scale_a0 = Q8BlkScale(QuantABlk0); + const float& scale_a0 = *QuantAScalePtr; // Col0 const float& scale_00 = scale_a0 * QuantBScalePtr[0]; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h new file mode 100644 index 0000000000000..45c3963365e6b --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h @@ -0,0 +1,759 @@ +#pragma once +#include +#include +#include + +#include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_avx_common.h" + +template +static MLAS_FORCEINLINE void +accumulate_blklen32_r1c1blk1_zp_avx2( + const __m256i& av_32_epi8, + const std::byte* QuantBDataPtr, + const float& combined_scale, + const std::byte* QuantBZeroPointPtr, + __m256& acc, + const __m256i& low_mask +) +{ + // | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | + const __m128i bv_packed0 = _mm_loadu_si128(reinterpret_cast(QuantBDataPtr)); + __m256i bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); + bv_32_epi8 = _mm256_and_si256(low_mask, bv_32_epi8); + + bv_32_epi8 = _mm256_sub_epi8(bv_32_epi8, _mm256_set1_epi8(get_zp(true, QuantBZeroPointPtr))); + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) { + const __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv_32_epi8, bv_32_epi8), _mm256_sign_epi8(av_32_epi8, bv_32_epi8)); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + acc = _mm256_fmadd_ps(sum_ps, _mm256_set1_ps(combined_scale), acc); + } else { +#endif + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_32_epi8, bv_32_epi8), 15); + const __m256i dot_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv_32_epi8, bv_32_epi8), _mm256_sign_epi8(av_32_epi8, bv_32_epi8)); + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, dot_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + acc = _mm256_fmadd_ps(sum_ps, _mm256_set1_ps(combined_scale), acc); +#if !defined(__GNUC__) || (__GNUC__ > 10) + } +#endif +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen32_r1c1blk2_zp_avx2( + const __m256i& av0_32_epi8, + const __m256i& av1_32_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a, + const float* scale_b, + const std::byte* QuantBZeroPointPtr, + __m256& acc0, + const __m256i& low_mask +) +{ + const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0~31 + __m256i bv1_32_epi8 = _mm256_and_si256(_mm256_srli_epi16(bv_packed, 4), low_mask); // 32~63 + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) { + { + bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, _mm256_set1_epi8(get_zp(true, QuantBZeroPointPtr))); + __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8)); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + const __m256 scale = 
_mm256_set1_ps(*(scale_a) * *(scale_b)); + acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); + } + + { + bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, _mm256_set1_epi8(get_zp(false, QuantBZeroPointPtr))); + const __m256 scale = _mm256_set1_ps(*(scale_a + 1) * *(scale_b + 1)); + __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8)); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); + } + } else { +#endif + { + bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, _mm256_set1_epi8(get_zp(true, QuantBZeroPointPtr))); + const __m256 scale = _mm256_set1_ps(*(scale_a) * *(scale_b)); + __m256i dot_16_epi16 = _mm256_maddubs_epi16( + _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8) + ); + __m256i sum_8_epi32 = _mm256_madd_epi16(_mm256_set1_epi16(1), dot_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); + } + + { + bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, _mm256_set1_epi8(get_zp(false, QuantBZeroPointPtr))); + const __m256 scale = _mm256_set1_ps(*(scale_a + 1) * *(scale_b + 1)); + __m256i dot_16_epi16 = _mm256_maddubs_epi16( + _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8) + ); + __m256i sum_8_epi32 = _mm256_madd_epi16(_mm256_set1_epi16(1), dot_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); + } +#if !defined(__GNUC__) || (__GNUC__ > 10) + } +#endif +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen32_r1c1blk2_zp_is_8_avx2( + const __m256i& av0_32_epi8, + const __m256i& av1_32_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a, + const float* scale_b, + __m256& acc0, + const __m256i& low_mask, + const __m256i& bzp8 +) +{ + // accumulate_blklen32_r1c1blk2_zp_is_8_avx2 is much faster than + // accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2: + // BlkBitWidth:4/BlkLen:32/M:1/N:2560/K:2560/Threads:8/Symmetric:1/HasBias:0/ComputeType:4 + // 36591 vs 40270 ns (the main is 51836 ns). both are not as good as main with genai. + // TODO: consolidate with accumulate_blklen32_r1c1blk2_avx2 using a zp8 template option + // | v0 v32 | v1 v33 | ... 
| v30 v62 | v31 v63 | + + const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0~31 + __m256i bv1_32_epi8 = _mm256_and_si256(_mm256_srli_epi16(bv_packed, 4), low_mask); // 32~63 + + bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, bzp8); + bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, bzp8); + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) { + __m256i dot0_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8)); + __m256i dot1_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8)); + const __m256i sum_8_epi32 = _mm256_hadd_epi32(dot0_8_epi32, dot1_8_epi32); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a)); + __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); + // 1 0 1 0 1 0 1 0 -> 1 1 0 0 1 1 0 0 + __m256 scale_8_ps = _mm256_permute_ps(_mm256_mul_ps(scale_a0_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); + + acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); + } else { +#endif + __m256i dot0_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8)); + __m256i dot1_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8)); + const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + + const __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(low_mask, low_mask), 15); + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a)); + __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); + // 1 0 1 0 1 0 1 0 -> 1 1 0 0 1 1 0 0 + __m256 scale_8_ps = _mm256_permute_ps( + _mm256_mul_ps(scale_a0_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0) + ); + + acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); +#if !defined(__GNUC__) || (__GNUC__ > 10) + } +#endif +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2( + const __m256i& av0_32_epi8, + const __m256i& av1_32_epi8, + const __m256& scale_a0_8_ps, + const __m256& scale_a1_8_ps, + const std::byte* QuantBDataPtr, + const float* scale_b, + __m256& acc0, + const __m256i& low_mask, + const __m256i& bzp8 +) +{ + // TODO: consolidate with accumulate_blklen32_r1c1blk2_avx2 using a zp8 template option + // | v0 v32 | v1 v33 | ... 
| v30 v62 | v31 v63 | + const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0~31 + __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32~63 + + bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, bzp8); + bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, bzp8); + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) { + { + __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8)); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + const __m256 scale = _mm256_mul_ps(_mm256_set1_ps(*scale_b), scale_a0_8_ps); + acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); + } + { + __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8)); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + const __m256 scale = _mm256_mul_ps(_mm256_set1_ps(*(scale_b + 1)), scale_a1_8_ps); + acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); + } + } else { +#endif + { + __m256i dot0_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8)); + __m256i sum_8_epi32 = _mm256_madd_epi16(_mm256_set1_epi16(1), dot0_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + const __m256 scale = _mm256_mul_ps(_mm256_set1_ps(*scale_b), scale_a0_8_ps); + acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); + } + { + __m256i dot0_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8)); + __m256i sum_8_epi32 = _mm256_madd_epi16(_mm256_set1_epi16(1), dot0_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + const __m256 scale = _mm256_mul_ps(_mm256_set1_ps(*(scale_b + 1)), scale_a1_8_ps); + acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); + } +#if !defined(__GNUC__) || (__GNUC__ > 10) + } +#endif +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmM1C4BlkLen32Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t BlockCountK, + const float* Bias) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + //const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + //const size_t StrideQuantBScale = BlockCountK; + + assert(CountN % NCols4 == 0); + + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + const float* BiasPtr = Bias; + auto* SumPtr = C; + + const __m256i low_mask = _mm256_set1_epi8(0x0F); + //const __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(low_mask, low_mask), 15); + const size_t StrideQuantBDataCol = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBData2 = 2 * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBData1 = 1 * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale2 = 2; + const size_t StrideQuantBScale1 = 1; + const 
size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA; + const float* QuantAScalePtr = QuantAScale; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc[NCols4] = {_mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps()}; + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + BlkLen32)); + //const __m256 scale_a0_8_ps = _mm256_set1_ps(Q8BlkScale(QuantAPtr)); + //const __m256 scale_a1_8_ps = _mm256_set1_ps(Q8BlkScale(QuantAPtr + Q8BlkSize(BlkLen32))); + + //accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2(av_00_epi8, av_01_epi8, scale_a0_8_ps, scale_a1_8_ps, QuantBDataPtr, QuantBScalePtr, acc[0], low_mask, bzp8); + //accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2(av_00_epi8, av_01_epi8, scale_a0_8_ps, scale_a1_8_ps, QuantBDataPtr + StrideQuantBData, QuantBScalePtr + StrideQuantBScale, acc[1], low_mask, bzp8); + //accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2(av_00_epi8, av_01_epi8, scale_a0_8_ps, scale_a1_8_ps, QuantBDataPtr + 2 * StrideQuantBData, QuantBScalePtr + 2 * StrideQuantBScale, acc[2], low_mask, bzp8); + //accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2(av_00_epi8, av_01_epi8, scale_a0_8_ps, scale_a1_8_ps, QuantBDataPtr + 3 * StrideQuantBData, QuantBScalePtr + 3 * StrideQuantBScale, acc[3], low_mask, bzp8); + if constexpr (HasZeroPoint) { + accumulate_blklen32_r1c1blk2_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, QuantBZeroPointPtr, acc[0], low_mask); + accumulate_blklen32_r1c1blk2_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale2, QuantBZeroPointPtr + StrideQuantBZeroPoint, acc[1], low_mask); + accumulate_blklen32_r1c1blk2_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale2, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, acc[2], low_mask); + accumulate_blklen32_r1c1blk2_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale2, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, acc[3], low_mask); + + } else { + const __m256i bzp8 = _mm256_set1_epi8(8); + accumulate_blklen32_r1c1blk2_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0], low_mask, bzp8); + accumulate_blklen32_r1c1blk2_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale2, acc[1], low_mask, bzp8); + accumulate_blklen32_r1c1blk2_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale2, acc[2], low_mask, bzp8); + accumulate_blklen32_r1c1blk2_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale2, acc[3], low_mask, bzp8); + } + // increment block pointers + QuantAPtr += BlkLen32 * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2 * NCols4; + QuantBScalePtr += PerAccuBlk2 * NCols4; + if constexpr 
(HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } + + // TODO: use a loop in case PerAccuBlk2 is not 2. + if (k_blks_remaining > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); + const float& scale_a00 = *QuantAScalePtr; + { + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_blklen32_r1c1blk1_zp_avx2(av_00_epi8, QuantBDataPtr, scale_00, QuantBZeroPointPtr, acc[0], low_mask); + } + { + const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale1)[0]; + accumulate_blklen32_r1c1blk1_zp_avx2(av_00_epi8, QuantBDataPtr + StrideQuantBData1, scale_00, QuantBZeroPointPtr + StrideQuantBZeroPoint, acc[1], low_mask); + } + { + const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale1)[0]; + accumulate_blklen32_r1c1blk1_zp_avx2(av_00_epi8, QuantBDataPtr + 2 * StrideQuantBData1, scale_00, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, acc[2], low_mask); + } + { + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale1)[0]; + accumulate_blklen32_r1c1blk1_zp_avx2(av_00_epi8, QuantBDataPtr + 3 * StrideQuantBData1, scale_00, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, acc[3], low_mask); + } + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + + _mm_storeu_ps(SumPtr, acc_r0); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBDataCol; + QuantBScaleColPtr += NCols4 * BlockCountK; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += NCols4 * StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmM1C1BlkLen32Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t BlockCountK, + const float* Bias +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountN < NCols4); + + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + const float* BiasPtr = Bias; + auto* SumPtr = C; + + const __m256i low_mask = _mm256_set1_epi8(0x0F); + [[maybe_unused]] const __m256i bzp8 = _mm256_set1_epi8(8); + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA; + const float* QuantAScalePtr = QuantAScale; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc0 = _mm256_setzero_ps(); + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const 
__m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + BlkLen32)); + //const __m256 scale_a0_8_ps = _mm256_set1_ps(Q8BlkScale(QuantAPtr)); + //const __m256 scale_a1_8_ps = _mm256_set1_ps(Q8BlkScale(QuantAPtr + Q8BlkSize(BlkLen32))); + + if constexpr (HasZeroPoint) { + accumulate_blklen32_r1c1blk2_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, QuantBZeroPointPtr, acc0, low_mask); + } else { + accumulate_blklen32_r1c1blk2_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0, low_mask, bzp8); + } + + // increment block pointers + QuantAPtr += BlkLen32 * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } + + // TODO: use a loop in case PerAccuBlk2 is not 2. + if (k_blks_remaining > 0) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const float& scale_a00 = *QuantAScalePtr; + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_blklen32_r1c1blk1_zp_avx2(av_00_epi8, QuantBDataPtr, scale_00, QuantBZeroPointPtr, acc0, low_mask); + } + + *SumPtr = hsum_float_8(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } +} + +template +MLAS_FORCEINLINE +void +MlasQ4Int8GemmM1KernelBlkLen32Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t BlockCountK, + const float* Bias + ) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleCols > 0) { + Q4Int8GemmM1C4BlkLen32Avx2( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + multipleCols, + BlockCountK, + Bias); + } + + if (remainingCols > 0) { + Q4Int8GemmM1C1BlkLen32Avx2( + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + QuantBZeroPoint + multipleCols * StrideQuantBZeroPoint, + C + multipleCols, + remainingCols, + BlockCountK, + Bias ? 
Bias + multipleCols : nullptr); + } +} + +//#define SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout 1 +void SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t BlockCountK, + const float* Bias +) +{ + // port from neon implementation + constexpr size_t BlkBitWidth = 4; + constexpr size_t BlkLen = 32; +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout +#else + constexpr bool HasZeroPoint = false; +#endif + + float* CRowPtr = C; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + //const size_t StrideQuantBScale = BlockCountK; + const float* BiasPtr = Bias; + + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + + float* SumPtr = CRowPtr; + +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + const __m256i low_mask = _mm256_set1_epi8(0x0F); + const __m256i bzp8 = _mm256_set1_epi8(8); + const __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(low_mask, low_mask), 15); + (void)StrideQuantBZeroPoint; +#else + const __m256i zero = _mm256_setzero_si256(); + const __m128i low_mask = _mm_set1_epi8(0xF); +#endif + const size_t NCols = 4; + constexpr size_t StrideQuantBScale2 = 2; + constexpr size_t StrideQuantBScale1 = 1; + + int64_t nblk = (int64_t)(CountN)-4; + while (nblk >= 0) { + const std::byte* QuantAPtr = QuantA; + const float* QuantAScalePtr = QuantAScale; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + (void)QuantBZeroPointPtr; +#endif + __m256 + acc0 = _mm256_setzero_ps(), + acc1 = _mm256_setzero_ps(), + acc2 = _mm256_setzero_ps(), + acc3 = _mm256_setzero_ps(); + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= 2) { + const std::byte* QuantABlk0 = QuantAPtr; + const std::byte* QuantABlk1 = QuantABlk0 + BlkLen; + + // load A: + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); + const __m256i av_1_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk1); +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + const __m256 scale_a0_8_ps = _mm256_set1_ps(Q8BlkScale(QuantAPtr)); + const __m256 scale_a1_8_ps = _mm256_set1_ps(Q8BlkScale(QuantAPtr + Q8BlkSize(BlkLen))); +#else + const float& scale_a0 = QuantAScalePtr[0]; + const float& scale_a1 = QuantAScalePtr[1]; +#endif + + // Col0 +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2(av_0_epi8, av_1_epi8, scale_a0_8_ps, scale_a1_8_ps, QuantBDataPtr, QuantBScalePtr, acc0, low_mask, bzp8); +#else + const float& scale_00 = scale_a0 * QuantBScalePtr[0]; + const float& scale_01 = scale_a1 * QuantBScalePtr[1]; + accumulate_mul_sum_avx2(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc0); + accumulate_mul_sum_avx2(av_1_epi8, reinterpret_cast(QuantBDataPtr + 16), low_mask, zero, QuantBZeroPointPtr, false, scale_01, acc0); +#endif + + // Col1 +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + 
accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2(av_0_epi8, av_1_epi8, scale_a0_8_ps, scale_a1_8_ps, QuantBDataPtr + StrideQuantBData, QuantBScalePtr + StrideQuantBScale2, acc1, low_mask, bzp8); +#else + const float& scale_10 = scale_a0 * (QuantBScalePtr + StrideQuantBScale2)[0]; + const float& scale_11 = scale_a1 * (QuantBScalePtr + StrideQuantBScale2)[1]; + accumulate_mul_sum_avx2(av_0_epi8, reinterpret_cast(QuantBDataPtr + StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, true, scale_10, acc1); + accumulate_mul_sum_avx2(av_1_epi8, reinterpret_cast(QuantBDataPtr + StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, false, scale_11, acc1); +#endif + + // Col2 +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2(av_0_epi8, av_1_epi8, scale_a0_8_ps, scale_a1_8_ps, QuantBDataPtr + 2 * StrideQuantBData, QuantBScalePtr + 2 * StrideQuantBScale2, acc2, low_mask, bzp8); +#else + const float& scale_20 = scale_a0 * (QuantBScalePtr + 2 * StrideQuantBScale2)[0]; + const float& scale_21 = scale_a1 * (QuantBScalePtr + 2 * StrideQuantBScale2)[1]; + accumulate_mul_sum_avx2(av_0_epi8, reinterpret_cast(QuantBDataPtr + 2 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, true, scale_20, acc2); + accumulate_mul_sum_avx2(av_1_epi8, reinterpret_cast(QuantBDataPtr + 2 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, false, scale_21, acc2); +#endif + // Col3 +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2(av_0_epi8, av_1_epi8, scale_a0_8_ps, scale_a1_8_ps, QuantBDataPtr + 3 * StrideQuantBData, QuantBScalePtr + 3 * StrideQuantBScale2, acc3, low_mask, bzp8); +#else + const float& scale_30 = scale_a0 * (QuantBScalePtr + 3 * StrideQuantBScale2)[0]; + const float& scale_31 = scale_a1 * (QuantBScalePtr + 3 * StrideQuantBScale2)[1]; + accumulate_mul_sum_avx2(av_0_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, true, scale_30, acc3); + accumulate_mul_sum_avx2(av_1_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, false, scale_31, acc3); +#endif + // increment block pointers + QuantAPtr += BlkLen * 2; + QuantAScalePtr += 2; + QuantBDataPtr += 16 * 2; + QuantBScalePtr += 2 * NCols; + } + + if (k_blks_remaining > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); + + const float& scale_a0 = *QuantAScalePtr; + + // Col0 + const float& scale_0 = scale_a0 * QuantBScalePtr[0]; +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + accumulate_blklen32_r1c1blk1_zp_avx2(av_0_epi8, QuantBDataPtr, scale_0, acc0, low_mask, bzp8); +#else + accumulate_mul_sum_avx2(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_0, acc0); +#endif + + // Col1 + const float& scale_1 = scale_a0 * (QuantBScalePtr + StrideQuantBScale1)[0]; +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + accumulate_blklen32_r1c1blk1_zp_avx2(av_0_epi8, QuantBDataPtr + StrideQuantBData, scale_1, acc1, low_mask, bzp8); +#else + accumulate_mul_sum_avx2(av_0_epi8, reinterpret_cast(QuantBDataPtr + StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, true, scale_1, acc1); +#endif + + // Col2 + const 
float& scale_2 = scale_a0 * (QuantBScalePtr + 2 * StrideQuantBScale1)[0]; +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + accumulate_blklen32_r1c1blk1_zp_avx2(av_0_epi8, QuantBDataPtr + 2 * StrideQuantBData, scale_2, acc2, low_mask, bzp8); +#else + accumulate_mul_sum_avx2(av_0_epi8, reinterpret_cast(QuantBDataPtr + 2 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, true, scale_2, acc2); +#endif + + // Col3 + const float& scale_3 = scale_a0 * (QuantBScalePtr + 3 * StrideQuantBScale1)[0]; +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + accumulate_blklen32_r1c1blk1_zp_avx2(av_0_epi8, QuantBDataPtr + 3 * StrideQuantBData, scale_3, acc3, low_mask, bzp8); +#else + accumulate_mul_sum_avx2(av_0_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, true, scale_3, acc3); +#endif + } + + __m128 acc_x = FoldAccumulators(acc0, acc1, acc2, acc3); + if (BiasPtr != nullptr) { + acc_x = _mm_add_ps(acc_x, _mm_loadu_ps(BiasPtr)); + } + _mm_storeu_ps(SumPtr, acc_x); + + // move to next NCols columns + + QuantBDataColPtr += NCols * StrideQuantBData; + QuantBScaleColPtr += NCols * BlockCountK; + + BiasPtr += BiasPtr != nullptr ? NCols : 0; + SumPtr += NCols; + nblk -= NCols; + } + + nblk += NCols; + for (int64_t n = 0; n < nblk; n++) { + const std::byte* QuantAPtr = QuantA; + const float* QuantAScalePtr = QuantAScale; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + (void)QuantBZeroPointPtr; +#endif + __m256 acc0 = _mm256_setzero_ps(); + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= 2) { + const std::byte* QuantABlk0 = QuantAPtr; + const std::byte* QuantABlk1 = QuantABlk0 + BlkLen; + + // load A: + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); + const __m256i av_1_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk1); + +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + const __m256 scale_a0_8_ps = _mm256_set1_ps(Q8BlkScale(QuantABlk0)); + const __m256 scale_a1_8_ps = _mm256_set1_ps(Q8BlkScale(QuantABlk1)); +#else + const float& scale_a0 = QuantAScalePtr[0]; + const float& scale_a1 = QuantAScalePtr[1]; +#endif + + // Col0 +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2(av_0_epi8, av_1_epi8, scale_a0_8_ps, scale_a1_8_ps, QuantBDataPtr, QuantBScalePtr, acc0, low_mask, bzp8); +#else + const float& scale_00 = scale_a0 * QuantBScalePtr[0]; + const float& scale_01 = scale_a1 * QuantBScalePtr[1]; + accumulate_mul_sum_avx2(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc0); + accumulate_mul_sum_avx2(av_1_epi8, reinterpret_cast(QuantBDataPtr + 16), low_mask, zero, QuantBZeroPointPtr, false, scale_01, acc0); +#endif + // increment block pointers + QuantAPtr += BlkLen * 2; + QuantAScalePtr += 2; + QuantBDataPtr += 16 * 2; + QuantBScalePtr += 2; + } + + if (k_blks_remaining > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); + + const float& scale_a0 = *QuantAScalePtr; + + // Col0 + const float& scale_00 = scale_a0 * QuantBScalePtr[0]; +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + 
accumulate_blklen32_r1c1blk1_zp_avx2(av_0_epi8, QuantBDataPtr, scale_00, acc0, low_mask, bzp8); +#else + accumulate_mul_sum_avx2(av_0_epi8, reinterpret_cast<const __m128i*>(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc0); +#endif + } + + *SumPtr = hsum_float_8(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += BlockCountK; + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } +} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h b/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h new file mode 100644 index 0000000000000..e9c3812bde899 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h @@ -0,0 +1,312 @@ +#pragma once +#include <algorithm> +#include <cassert> +#include <utility> + +#include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_avx_common.h" + + +static MLAS_FORCEINLINE void +accumulate_blklen64_r1c1blk1_zp_avx2( + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a, + const float* scale_b, + const std::byte* QuantBZeroPointPtr, + const bool is_lower_half_byte_zp, + __m256& acc0, + const __m256i& low_mask ) +{ + // | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 | + const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(QuantBDataPtr)); + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...30, 31 + __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32, 33,...62, 63 + + const __m256i bzp8 = _mm256_set1_epi8(get_zp(is_lower_half_byte_zp, QuantBZeroPointPtr)); + bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, bzp8); + bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, bzp8); + + const __m256i dot0_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av00_32_epi8, bv0_32_epi8)); + const __m256i dot1_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av01_32_epi8, bv1_32_epi8)); + const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a_8_ps = _mm256_broadcast_ss(scale_a); + __m256 scale_b_8_ps = _mm256_broadcast_ss(scale_b); + + acc0 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a_8_ps, scale_b_8_ps), acc0); +} + +static MLAS_FORCEINLINE void +accumulate_blklen64_r1c1blk1_zp_is_8_avx2( + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a, + const float* scale_b, + __m256& acc0, + const __m256i& low_mask, + const __m256i& bzp8 ) +{ + // | v0 v32 | v1 v33 | ...
| v30 v62 | v31 v63 | + const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(QuantBDataPtr)); + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...30, 31 + __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32, 33,...62, 63 + + bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, bzp8); + bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, bzp8); + + const __m256i dot0_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av00_32_epi8, bv0_32_epi8)); + const __m256i dot1_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av01_32_epi8, bv1_32_epi8)); + const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a_8_ps = _mm256_broadcast_ss(scale_a); + __m256 scale_b_8_ps = _mm256_broadcast_ss(scale_b); + + acc0 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a_8_ps, scale_b_8_ps), acc0); +} + +template <bool HasZeroPoint> +MLAS_FORCEINLINE void +Q4Int8GemmM1C4BlkLen64Avx2( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t BlockCountK, + const float* Bias) +{ + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t SubblkLen64 = 64; + + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t PerBlkSubblkCount = BlkLen / SubblkLen64; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + assert(CountN % NCols4 == 0); + + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + const float* BiasPtr = Bias; + auto* SumPtr = C; + + const __m256i low_mask = _mm256_set1_epi8(0x0F); + const size_t StrideQuantBData1 = 1 * SubblkDataSizeInBytes; + const size_t StrideQuantBScale1 = 1; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes<BlkBitWidth4>(BlockCountK); + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA; + const float* QuantAScalePtr = QuantAScale; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc[NCols4] = {_mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps()}; + for (size_t k = 0; k < BlockCountK; ++k) { + [[maybe_unused]] const bool is_lower_half_byte_zp = (k % 2) == 0; + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); + if constexpr (HasZeroPoint) { + accumulate_blklen64_r1c1blk1_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, QuantBZeroPointPtr, is_lower_half_byte_zp, acc[0], low_mask); + accumulate_blklen64_r1c1blk1_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData1, QuantAScalePtr, QuantBScalePtr +
StrideQuantBScale1, QuantBZeroPointPtr + StrideQuantBZeroPoint, is_lower_half_byte_zp, acc[1], low_mask); + accumulate_blklen64_r1c1blk1_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData1, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale1, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, is_lower_half_byte_zp, acc[2], low_mask); + accumulate_blklen64_r1c1blk1_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData1, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale1, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, is_lower_half_byte_zp, acc[3], low_mask); + } else { + const __m256i bzp8 = _mm256_set1_epi8(8); + accumulate_blklen64_r1c1blk1_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0], low_mask, bzp8); + accumulate_blklen64_r1c1blk1_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData1, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale1, acc[1], low_mask, bzp8); + accumulate_blklen64_r1c1blk1_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData1, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale1, acc[2], low_mask, bzp8); + accumulate_blklen64_r1c1blk1_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData1, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale1, acc[3], low_mask, bzp8); + } + + // increment block pointers + QuantAPtr += SubblkLen64; + QuantBDataPtr += NCols4 * SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr += NCols4; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += k % 2; + } + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + + _mm_storeu_ps(SumPtr, acc_r0); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += NCols4 * StrideQuantBZeroPoint; + } + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + } +} + +template <bool HasZeroPoint> +MLAS_FORCEINLINE void +Q4Int8GemmM1C1BlkLen64Avx2( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t BlockCountK, + const float* Bias) +{ + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 64; + + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes<BlkBitWidth4>(BlockCountK); + + assert(CountN < NCols4); + + const __m256i low_mask = _mm256_set1_epi8(0x0F); + [[maybe_unused]] const __m256i bzp8 = _mm256_set1_epi8(8); + + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + const float* BiasPtr = Bias; + auto* SumPtr = C; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA; + const float* QuantAScalePtr = QuantAScale; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc0 = _mm256_setzero_ps(); + for (size_t k = 0; k < BlockCountK; ++k) { + [[maybe_unused]] const bool is_lower_half_byte_zp = (k % 2) == 0; + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); + + if constexpr (HasZeroPoint) { + accumulate_blklen64_r1c1blk1_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, QuantBZeroPointPtr, is_lower_half_byte_zp, acc0, low_mask); + } else { + accumulate_blklen64_r1c1blk1_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0, low_mask, bzp8); + } + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr++; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += k % 2; + } + } + + *SumPtr = hsum_float_8(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ?
1 : 0; + SumPtr += 1; + } +} + +template <bool HasZeroPoint> +MLAS_FORCEINLINE void +MlasQ4Int8GemmKernelBlkLen64Avx2( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t BlockCountK, + const float* Bias ) +{ + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes<BlkBitWidth4>(BlockCountK); + + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleCols > 0) { + Q4Int8GemmM1C4BlkLen64Avx2<HasZeroPoint>( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + multipleCols, + BlockCountK, + Bias); + } + + if (remainingCols > 0) { + Q4Int8GemmM1C1BlkLen64Avx2<HasZeroPoint>( + BlkLen, + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + QuantBZeroPoint + multipleCols * StrideQuantBZeroPoint, + C + multipleCols, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr); + } +} diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index dedc01de9655d..548f24e8ac69e 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -263,9 +263,10 @@ void RunTest(const TestOptions& opts, } // namespace TEST(MatMulNBits, Float32) { + // onnxruntime::profiling::Profiler::Profiler::Instance().StartProfiling("profile.json"); for (auto M : {1, 2, 100}) { - for (auto N : {1, 2, 32, 288}) { - for (auto K : {16, 32, 64, 128, 256, 1024, 93, 1234}) { + for (auto N : {/*2560, */ 1, 2, 32, 288}) { + for (auto K : {/*2560, */ 16, 32, 64, 128, 256, 1024, 93, 1234}) { for (auto block_size : {16, 32, 64, 128}) { for (auto accuracy_level : {0, 1, 4}) { TestOptions base_opts{}; diff --git a/onnxruntime/test/mlas/bench/bench_q4dq.cpp b/onnxruntime/test/mlas/bench/bench_q4dq.cpp index 9d15c9a6bf994..6d21ed2eef864 100644 --- a/onnxruntime/test/mlas/bench/bench_q4dq.cpp +++ b/onnxruntime/test/mlas/bench/bench_q4dq.cpp @@ -9,10 +9,10 @@ #include "core/util/thread_utils.h" static void BM_QDQBlockwiseQuantizer_QuantizeColumnwise(benchmark::State& state) { - int M = state.range(0); - int N = state.range(1); - int quant_block_size = state.range(2); - int threads = state.range(3); + int M = (int)state.range(0); + int N = (int)state.range(1); + int quant_block_size = (int)state.range(2); + int threads = (int)state.range(3); size_t scale_size = (M + quant_block_size - 1) / quant_block_size * N; auto src = RandomVectorUniform(M * N, -16.0f, 14.0f); @@ -37,10 +37,10 @@ static void BM_QDQBlockwiseQuantizer_QuantizeColumnwise(benchmark::State& state) } static void BM_MlasQuantizeBlockwise(benchmark::State& state) { - int M = state.range(0); - int N = state.range(1); - int quant_block_size = state.range(2); - int threads = state.range(3); + int M = (int)state.range(0); + int N = (int)state.range(1); + int quant_block_size = (int)state.range(2); + int threads = (int)state.range(3); size_t scale_size = (M + quant_block_size - 1) / quant_block_size * N; auto src = RandomVectorUniform(M * N, -16.0f, 14.0f); @@ -65,10 +65,10 @@ static void BM_MlasQuantizeBlockwise(benchmark::State& state) { } static void
BM_QDQBlockwiseQuantizer_TransposeColumnwise(benchmark::State& state) { - int M = state.range(0); - int N = state.range(1); - int quant_block_size = state.range(2); - int threads = state.range(3); + int M = (int)state.range(0); + int N = (int)state.range(1); + int quant_block_size = (int)state.range(2); + int threads = (int)state.range(3); bool add8 = state.range(4) != 0; int quant_num_M = (M + quant_block_size - 1) / quant_block_size; int blob_size = (quant_block_size + 1) / 2; diff --git a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp index 354621eff42b6..73c78b8cc3d47 100644 --- a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp @@ -53,6 +53,7 @@ void RunSQNBitGemmBenchmark(size_t BlkLen, std::vector<uint8_t> QuantBData(QuantBDataSizeInBytes); std::vector<float> QuantBScale(QuantBScaleSize); std::vector<uint8_t> QuantBZeroPoint(Symmetric ? 0 : QuantBZeroPointSizeInBytes); + bool has_zp_input = !Symmetric; MlasQuantizeBlockwise<float, 4>(QuantBData.data(), QuantBScale.data(), Symmetric ? nullptr : QuantBZeroPoint.data(), @@ -71,15 +72,17 @@ void RunSQNBitGemmBenchmark(size_t BlkLen, PackedQuantBDataSize > 0) { PackedQuantBData = std::make_unique<std::byte[]>(PackedQuantBDataSize); MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType, QuantBData.data(), PackedQuantBData.get(), + QuantBScale.data(), has_zp_input, QuantBZeroPoint.data(), tp.get()); } MLAS_SQNBIT_GEMM_DATA_PARAMS params{}; params.A = A.data(); params.lda = K; - params.QuantBData = PackedQuantBData != nullptr - ? static_cast<const void*>(PackedQuantBData.get()) - : static_cast<const void*>(QuantBData.data()); + if (PackedQuantBData != nullptr) + params.QuantBDataWorkspace = static_cast<const void*>(PackedQuantBData.get()); + else + params.QuantBDataWorkspace = static_cast<const void*>(QuantBData.data()); params.QuantBScale = QuantBScale.data(); params.QuantBZeroPoint = Symmetric ? nullptr : QuantBZeroPoint.data(); params.Bias = HasBias ? Bias.data() : nullptr; diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp index f391027de4d51..0710981fa17c6 100644 --- a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp @@ -55,8 +55,8 @@ class MlasSQNBitGemmTest : public MlasTestBase { size_t K, const float* A, size_t lda, - const void* QuantBData, - const void* PackedQuantBData, + const void* /*QuantBData*/, + const void* PackedQuantBDataWorkspace, const float* QuantBScale, const void* QuantBZeroPoint, const float* Bias, @@ -71,7 +71,12 @@ class MlasSQNBitGemmTest : public MlasTestBase { params.Bias = Bias; params.C = C; params.ldc = ldc; - params.QuantBData = PackedQuantBData != nullptr ?
PackedQuantBData : QuantBData; +#ifdef MLAS_TARGET_AMD64_IX86 + if (ComputeType == CompInt8) { + params.QuantBDataWorkspace = PackedQuantBDataWorkspace; + } +#endif + params.PackedQuantBData = static_cast<const std::byte*>(PackedQuantBDataWorkspace); params.QuantBScale = QuantBScale; params.QuantBZeroPoint = QuantBZeroPoint; params.PostProcessor = nullptr; @@ -213,12 +218,19 @@ class MlasSQNBitGemmTest : public MlasTestBase { auto print_matrix = [](size_t nrows, size_t ncols, const float* data) { for (size_t row = 0; row < nrows; ++row) { for (size_t col = 0; col < ncols; ++col) { - std::cout << data[row * ncols + col] << "\t"; + std::cout << data[row * ncols + col] << ", "; } std::cout << "\n"; } }; + auto print_matrix_col = [](size_t nrows, size_t ncols, size_t col, const float* data) { + for (size_t row = 0; row < nrows; ++row) { + std::cout << data[row * ncols + col] << ", "; + } + std::cout << "\n"; + }; + std::cout << "A:\n"; print_matrix(M, K, A); std::cout << "B:\n"; @@ -258,14 +270,25 @@ class MlasSQNBitGemmTest : public MlasTestBase { Workspace = BufferWorkspace.GetBuffer(WorkspaceSize); } - void* PackedQuantBData = nullptr; + void* PackedQuantBDataWorkspace = nullptr; if (const auto PackedQuantBDataSize = MlasSQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen, ComputeType); PackedQuantBDataSize > 0) { - PackedQuantBData = BufferPackedQuantBData.GetBuffer(PackedQuantBDataSize); - MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType, QuantBData, PackedQuantBData, + PackedQuantBDataWorkspace = BufferPackedQuantBData.GetBuffer(PackedQuantBDataSize); + bool has_zp_input = QuantBZeroPoint != nullptr; + MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType, QuantBData, PackedQuantBDataWorkspace, + QuantBScale, has_zp_input, QuantBZeroPoint, GetMlasThreadPool()); } + CallGemm(M, N, K, + A, /* lda */ K, + QuantBData, PackedQuantBDataWorkspace, QuantBScale, QuantBZeroPoint, + Bias, + C, /* ldc */ N, + Workspace, + ComputeType, + Threadpool); + if (ComputeType == CompFp32) { CallReferenceGemm_CompFp32(M, N, K, A, QuantBData, QuantBScale, QuantBZeroPoint, Bias, CReference); } else if (ComputeType == CompInt8) { @@ -275,15 +298,6 @@ class MlasSQNBitGemmTest : public MlasTestBase { << ComputeType << " (" << ComputeTypeName(ComputeType) << ")"; } - CallGemm(M, N, K, - A, /* lda */ K, - QuantBData, PackedQuantBData, QuantBScale, QuantBZeroPoint, - Bias, - C, /* ldc */ N, - Workspace, - ComputeType, - Threadpool); - size_t f = 0; for (size_t m = 0; m < M; m++) { for (size_t n = 0; n < N; n++, f++) { @@ -382,7 +396,6 @@ class SQNBitGemmShortExecuteTest : public MlasTestFixture