diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 56bcce28ef5f8..23657d3379fae 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -88,7 +88,7 @@ option(onnxruntime_USE_QNN "Build with QNN support" OFF) option(onnxruntime_USE_SNPE "Build with SNPE support" OFF) option(onnxruntime_USE_RKNPU "Build with RKNPU support" OFF) option(onnxruntime_USE_DNNL "Build with DNNL support" OFF) -option(onnxruntime_USE_NEURAL_SPEED "Build with Neural Speed support" ON) +option(onnxruntime_USE_NEURAL_SPEED "Build with Neural Speed support" OFF) option(onnxruntime_USE_JSEP "Build with JavaScript implemented kernels support" OFF) option(onnxruntime_BUILD_UNIT_TESTS "Build ONNXRuntime unit tests" ON) option(onnxruntime_BUILD_CSHARP "Build C# library" OFF) diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index f7103c3b00a37..682dcfc5fe84f 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -167,6 +167,9 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp ${MLAS_SRC_DIR}/qgemm_kernel_sse41.cpp ${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512vnni.cpp ${MLAS_SRC_DIR}/amd64/QgemmU8S8KernelAmx.asm ${MLAS_SRC_DIR}/amd64/QgemmU8S8KernelAvx2.asm ${MLAS_SRC_DIR}/amd64/QgemmU8U8KernelAvx2.asm @@ -530,6 +533,7 @@ else() ${MLAS_SRC_DIR}/x86_64/ErfKernelFma3.S ${MLAS_SRC_DIR}/intrinsics/avx2/qladd_avx2.cpp ${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp ) set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma") @@ -549,9 +553,15 @@ else() ${MLAS_SRC_DIR}/x86_64/QgemvU8S8KernelAvx512Vnni.S ${MLAS_SRC_DIR}/x86_64/QgemmU8X8KernelAvx512Core.S ${MLAS_SRC_DIR}/x86_64/ConvSymKernelAvx512Core.S + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512.cpp ) set_source_files_properties(${mlas_platform_srcs_avx512core} PROPERTIES COMPILE_FLAGS "-mavx512bw -mavx512dq -mavx512vl") + set(mlas_platform_srcs_avx512vnni + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512vnni.cpp + ) + set_source_files_properties(${mlas_platform_srcs_avx512vnni} PROPERTIES COMPILE_FLAGS "-mfma -mavx512vnni -mavx512bw -mavx512dq -mavx512vl -mavx512f") + set(mlas_platform_srcs ${MLAS_SRC_DIR}/activate_fp16.cpp ${MLAS_SRC_DIR}/dwconv.cpp @@ -563,6 +573,7 @@ else() ${mlas_platform_srcs_avx2} ${mlas_platform_srcs_avx512f} ${mlas_platform_srcs_avx512core} + ${mlas_platform_srcs_avx512vnni} ) if (NOT onnxruntime_ORT_MINIMAL_BUILD) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 8d3d80751483d..64d2af955d9c7 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2876,7 +2876,7 @@ This version of the operator has been available since version 1 of the 'com.micr And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,.. 3. Input B's scale and zero point are specified by input scales and zero_points. - Input is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which: + Input B is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which: - n_blocks_per_col = (K + block_size - 1) / block_size - blob_size = CeilDiv(block_size * bits, bitsof(uint8_t)<8>) For all bits from 2-8, a row of data is stored squeezely and represented by uint8_t. 
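As a quick illustration of the B layout described in the doc text above, the following standalone sketch (not part of this change; the chosen N, K, block_size and bits values are only examples) computes the packed shape:

    // Illustrative sketch of the packed shape of MatMulNBits input B,
    // following the n_blocks_per_col / blob_size formulas above.
    // Not part of this diff; values are example choices.
    #include <cstddef>
    #include <cstdio>

    int main()
    {
        const std::size_t N = 4096, K = 4096, block_size = 32, bits = 4;
        const std::size_t n_blocks_per_col = (K + block_size - 1) / block_size;  // CeilDiv(K, block_size)
        const std::size_t blob_size = (block_size * bits + 7) / 8;               // CeilDiv(block_size * bits, 8)
        std::printf("B is stored as uint8_t[%zu][%zu][%zu]\n", N, n_blocks_per_col, blob_size);
        return 0;
    }
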
diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 602dd98d8c0d6..6144ca46ac180 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -184,7 +184,7 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat is_packed = true; } -#else // defined(ORT_NEURAL_SPEED) +#else // defined(ORT_NEURAL_SPEED) if (input_idx == 1) { const auto compute_type = static_cast(accuracy_level_); @@ -204,7 +204,6 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat } is_packed = true; } - #endif // defined(ORT_NEURAL_SPEED) return Status::OK(); @@ -315,6 +314,7 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { const size_t K = static_cast(helper.K()); const size_t lda = helper.Lda(false); +#ifndef ORT_NEURAL_SPEED const bool has_single_b_matrix = (!act_order_) && (!zero_point_is_not_quant_) && std::all_of(helper.RightOffsets().begin(), helper.RightOffsets().end(), [](size_t offset) { return offset == 0; }); @@ -358,6 +358,7 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { return Status::OK(); } } +#endif // not defined(ORT_NEURAL_SPEED) const Tensor* b = ctx->Input(1); const uint8_t* b_data = b->Data(); diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 0f364b8880066..5cf1818bbf9e8 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -3407,7 +3407,7 @@ MatMulNBits is a MatMul with weight quantized with N bits(e.g., 2, 3, 4, 5, 6, 7 And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,.. 3. Input B's scale and zero point are specified by input scales and zero_points. - Input is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which: + Input B is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which: - n_blocks_per_col = (K + block_size - 1) / block_size - blob_size = CeilDiv(block_size * bits, bitsof(uint8_t)<8>) For all bits from 2-8, a row of data is stored squeezely and represented by uint8_t. diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 4b93dde1bcef9..04da9ab4fd749 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -971,6 +971,12 @@ struct MLAS_SQNBIT_GEMM_DISPATCH; extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon; +extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2; + +extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512; + +extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni; + // // Quantized depthwise convolution kernels. // diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index a53c5085b10cf..3f86b3f7c5062 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -373,6 +373,7 @@ Return Value: this->ConvDepthwiseS8S8Kernel = MlasConvDepthwiseKernelAvx2; this->ConvDepthwiseS8U8Kernel = MlasConvDepthwiseKernelAvx2; this->ComputeSumExpF32Kernel = MlasComputeSumExpF32KernelFma3; + this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx2; // // Check if the processor supports Hybrid core architecture. 
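The SQNBitGemmDispatch assignment added above (and the AVX-512 / AVX-512 VNNI assignments in the hunks that follow) rely on the later, more capable ISA branches overwriting the earlier choice. A simplified standalone sketch of that selection pattern, with hypothetical boolean flags standing in for the real CPUID feature checks performed in platform.cpp:

    // Simplified sketch of the nested ISA-based dispatch selection.
    // The actual assignments happen in MLAS platform initialization;
    // the flag names here are placeholders for the real feature checks.
    struct Dispatch { const char* name; };
    static const Dispatch DispatchAvx2{"avx2"}, DispatchAvx512{"avx512"}, DispatchAvx512Vnni{"avx512vnni"};

    static const Dispatch* SelectSQNBitGemmDispatch(bool HasAvx2, bool HasAvx512Core, bool HasAvx512Vnni)
    {
        const Dispatch* d = nullptr;
        if (HasAvx2) {
            d = &DispatchAvx2;                // baseline x64 kernel set
            if (HasAvx512Core) {
                d = &DispatchAvx512;          // overrides the AVX2 choice
                if (HasAvx512Vnni) {
                    d = &DispatchAvx512Vnni;  // most specific ISA wins
                }
            }
        }
        return d;
    }
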
@@ -439,6 +440,7 @@ Return Value: this->GemmU8U8Kernel = MlasGemmU8U8KernelAvx512Core; this->ConvSymU8S8Dispatch = &MlasConvSymDispatchAvx512Core; this->FpQ4GemmDispatch = &MlasFpQ4GemmDispatchAvx512; + this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx512; // // Check if the processor supports AVX512VNNI. @@ -451,6 +453,7 @@ Return Value: this->GemvU8S8Kernel = MlasGemvU8S8KernelAvx512Vnni; this->ConvSymU8S8Dispatch = &MlasConvSymDispatchAvx512Vnni; this->Q8Q4GemmDispatch = &MlasQ8Q4GemmDispatchAvx512vnni; + this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx512vnni; } } } diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp index 38c31c8841761..4fd28f5e6998f 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp @@ -394,6 +394,17 @@ SQ4BitGemm_CompInt8( const size_t RangeCountN ) { +#ifdef MLAS_TARGET_AMD64_IX86 + if (RangeCountM != 1) { + // perf experiment shows fp32 is faster than int8 in M > 1 cases. + // route to fp32 compute before int8 compute is improved. + SQ4BitGemm_CompFp32( + BlkLen, + K, DataParams, PerGemmWorkspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN + ); + return; + } +#endif constexpr size_t BlkBitWidth = 4; const size_t k_blks = MlasDivRoundup(K, BlkLen); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp new file mode 100644 index 0000000000000..b5d7a4e78fbe2 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp @@ -0,0 +1,1113 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + sqnbitgemm_kernel_avx2.cpp.h + +Abstract: + + This module implements the float/quantized n-bit integer matrix + multiplication kernels for x64 avx2. 
+ +--*/ + +#include +#include +#include + +#include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_avx_common.h" +#include "sqnbitgemm_kernel_avx_common_int8.h" + +MLAS_FORCEINLINE +__m256 +load_float_n_avx2(const float* data, int n) +{ + assert(n <= 8); + if (n <= 0) { + return _mm256_setzero_ps(); + } + static const int32_t mask_buffer[16] = {-1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0}; + const __m256i load_mask = _mm256_loadu_si256((const __m256i*)(mask_buffer + 8 - n)); + return _mm256_maskload_ps(data, load_mask); +} + +MLAS_FORCEINLINE void +Q4BitBlkDequantBForSgemmBlkLen16_CompFp32_avx2( + float* FpData, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + const size_t CountN, + const size_t CountK, + const size_t BlockCountK +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth4 = 4; + + constexpr size_t blk_data_size_in_bytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + const size_t b_data_col_stride_in_bytes = BlockCountK * blk_data_size_in_bytes; + // TODO: constexpr use temaplte parameter + /*constexpr*/ const bool HasZeroPoint = QuantBZeroPoint != nullptr; + const size_t zp_col_stride_in_bytes = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + constexpr size_t NCols8 = 8; // process NCols8 columns of QuantB at a time + constexpr size_t GemmFloatKernelWidth16 = 16; // mlas GemmFloatKernel requires B with width 16 + const __m128i low_mask = _mm_set1_epi8(0xF); + for (size_t col = 0; col < CountN; col += NCols8) { + const int cols = std::min((int)NCols8, (int)CountN - (int)col); + for (size_t k = 0; k < BlockCountK; k++) { + // count # of tiles plus blks of the current tile from top + const size_t tile_count = col / GemmFloatKernelWidth16; + float* dst_ptr = FpData + (tile_count * CountK + k * BlkLen16) * GemmFloatKernelWidth16; + if (col % GemmFloatKernelWidth16 >= NCols8) { + // for the second half to 16 width tile + dst_ptr += NCols8; + } + const std::byte* b_data_ptr = QuantBData + col * b_data_col_stride_in_bytes + k * blk_data_size_in_bytes; + const float* scale_ptr = QuantBScale + col * BlockCountK + k; + const std::byte* zp_ptr = QuantBZeroPoint + col * zp_col_stride_in_bytes + k / 2; + bool is_lower = (k % 2) == 0; + + __m256i weight_16_epi16[NCols8]; + __m256 scale_8_ps[NCols8]; + UnrolledLoop([&](size_t col_) { + if ((int)col_ < cols) { + // dst: | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF | + __m128i bvi = _mm_loadl_epi64((__m128i const*)(b_data_ptr + col_ * b_data_col_stride_in_bytes)); + const __m128i lower = _mm_and_si128(bvi, low_mask); + const __m128i upper = _mm_bslli_si128(_mm_and_si128(_mm_srli_epi16(bvi, 4), low_mask), 8); + __m128i weight_16_epi8 = _mm_add_epi8(upper, lower); + + if (HasZeroPoint) { + std::byte zp_packed = *(zp_ptr + col_ * zp_col_stride_in_bytes); + uint8_t zp = std::to_integer(is_lower ? 
(zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); + weight_16_epi8 = _mm_sub_epi8(weight_16_epi8, _mm_set1_epi8(zp)); + } else { + const __m128i eight = _mm_set1_epi8(8); + weight_16_epi8 = _mm_sub_epi8(weight_16_epi8, eight); + } + weight_16_epi16[col_] = _mm256_cvtepi8_epi16(weight_16_epi8); + scale_8_ps[col_] = _mm256_set1_ps(*(scale_ptr + col_ * BlockCountK)); + } else { + weight_16_epi16[col_] = _mm256_setzero_si256(); + scale_8_ps[col_] = _mm256_setzero_ps(); + } + }); + for (int i_of_2 = 0; i_of_2 < 2; i_of_2++) { + __m256 weight_8_ps[8]; + for (size_t col_ = 0; col_ < 8; col_++) { + if ((int)col_ < cols) { + if (i_of_2 == 0) { + __m256i weight_i_8_epi32 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(weight_16_epi16[col_], 0)); + weight_8_ps[col_] = _mm256_mul_ps(_mm256_cvtepi32_ps(weight_i_8_epi32), scale_8_ps[col_]); + } else { + __m256i weight_i_8_epi32 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(weight_16_epi16[col_], 1)); + weight_8_ps[col_] = _mm256_mul_ps(_mm256_cvtepi32_ps(weight_i_8_epi32), scale_8_ps[col_]); + } + } else { + weight_8_ps[col_] = _mm256_setzero_ps(); + } + } + // transpose and store + __m256 a0 = _mm256_unpacklo_ps(weight_8_ps[0], weight_8_ps[1]); + __m256 a1 = _mm256_unpackhi_ps(weight_8_ps[0], weight_8_ps[1]); + __m256 a2 = _mm256_unpacklo_ps(weight_8_ps[2], weight_8_ps[3]); + __m256 a3 = _mm256_unpackhi_ps(weight_8_ps[2], weight_8_ps[3]); + __m256 a4 = _mm256_unpacklo_ps(weight_8_ps[4], weight_8_ps[5]); + __m256 a5 = _mm256_unpackhi_ps(weight_8_ps[4], weight_8_ps[5]); + __m256 a6 = _mm256_unpacklo_ps(weight_8_ps[6], weight_8_ps[7]); + __m256 a7 = _mm256_unpackhi_ps(weight_8_ps[6], weight_8_ps[7]); + + __m256 b0 = _mm256_shuffle_ps(a0, a2, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 b1 = _mm256_shuffle_ps(a0, a2, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 b2 = _mm256_shuffle_ps(a1, a3, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 b3 = _mm256_shuffle_ps(a1, a3, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 b4 = _mm256_shuffle_ps(a4, a6, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 b5 = _mm256_shuffle_ps(a4, a6, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 b6 = _mm256_shuffle_ps(a5, a7, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 b7 = _mm256_shuffle_ps(a5, a7, _MM_SHUFFLE(3, 2, 3, 2)); + + // next i_of_2th row + const size_t ij_offset_in_k = i_of_2 * 8 * GemmFloatKernelWidth16; + __m256 weight_transposed_8_ps = _mm256_permute2f128_ps(b0, b4, 0x20); + _mm256_storeu_ps(dst_ptr + ij_offset_in_k + 0 * GemmFloatKernelWidth16, weight_transposed_8_ps); + weight_transposed_8_ps = _mm256_permute2f128_ps(b1, b5, 0x20); + _mm256_storeu_ps(dst_ptr + ij_offset_in_k + 1 * GemmFloatKernelWidth16, weight_transposed_8_ps); + weight_transposed_8_ps = _mm256_permute2f128_ps(b2, b6, 0x20); + _mm256_storeu_ps(dst_ptr + ij_offset_in_k + 2 * GemmFloatKernelWidth16, weight_transposed_8_ps); + weight_transposed_8_ps = _mm256_permute2f128_ps(b3, b7, 0x20); + _mm256_storeu_ps(dst_ptr + ij_offset_in_k + 3 * GemmFloatKernelWidth16, weight_transposed_8_ps); + weight_transposed_8_ps = _mm256_permute2f128_ps(b0, b4, 0x31); + _mm256_storeu_ps(dst_ptr + ij_offset_in_k + 4 * GemmFloatKernelWidth16, weight_transposed_8_ps); + weight_transposed_8_ps = _mm256_permute2f128_ps(b1, b5, 0x31); + _mm256_storeu_ps(dst_ptr + ij_offset_in_k + 5 * GemmFloatKernelWidth16, weight_transposed_8_ps); + weight_transposed_8_ps = _mm256_permute2f128_ps(b2, b6, 0x31); + _mm256_storeu_ps(dst_ptr + ij_offset_in_k + 6 * GemmFloatKernelWidth16, weight_transposed_8_ps); + weight_transposed_8_ps = _mm256_permute2f128_ps(b3, b7, 0x31); + _mm256_storeu_ps(dst_ptr + 
ij_offset_in_k + 7 * GemmFloatKernelWidth16, weight_transposed_8_ps); + } + } + } +} + +template +MLAS_FORCEINLINE void +Q4BitBlkDequantBForSgemmBlkLen32AndMore_CompFp32_avx2( + const size_t BlkLen, + float* FpData, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + const size_t CountN, + const size_t CountK, + const size_t BlockCountK +) +{ + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols8 = 8; // process NCols8 columns of QuantB at a time + constexpr size_t GemmFloatKernelWidth16 = 16; // mlas GemmFloatKernel requires B with width 16 + constexpr size_t SubblkLen32 = 32; // process SubblkLen32 rows of QuantB at a time + + const size_t blk_data_size_in_bytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t subblk_data_size_in_bytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, SubblkLen32); + const size_t b_data_col_stride_in_bytes = BlockCountK * blk_data_size_in_bytes; + // TODO: constexpr use temaplte parameter + /*constexpr*/ const bool HasZeroPoint = QuantBZeroPoint != nullptr; + const size_t zp_col_stride_in_bytes = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + [[maybe_unused]] int count_half_4 = 0; + + const __m256i low_mask = _mm256_set1_epi8(0xF); + for (size_t col = 0; col < CountN; col += NCols8) { + // TODO: handle last tile with cols < NCols8 + const size_t cols = std::min(NCols8, CountN - col); + for (size_t k = 0; k < BlockCountK; k++) { + // count # of tiles plus blks of the current tile from top + const size_t tile_count = col / GemmFloatKernelWidth16; + float* dst_ptr = FpData + (tile_count * CountK + k * BlkLen) * GemmFloatKernelWidth16; + if (col % GemmFloatKernelWidth16 >= NCols8) { + // for the second half to 16 width tile + dst_ptr += NCols8; + } + const std::byte* b_data_ptr = QuantBData + col * b_data_col_stride_in_bytes + k * blk_data_size_in_bytes; + const float* scale_ptr = QuantBScale + col * BlockCountK + k; + const std::byte* zp_ptr = QuantBZeroPoint + col * zp_col_stride_in_bytes + k / 2; + bool is_lower = (k % 2) == 0; + + for (size_t subblk = 0; subblk < BlkLen / SubblkLen32; subblk++) { + __m256i weight_32_epi8[NCols8]; + __m256 scale_8_ps[NCols8]; + if constexpr (IsBlkLen64Layout) { + count_half_4 = 4 * (subblk % 2); + } + UnrolledLoop([&](size_t col_) { + if (col_ < cols) { + if constexpr (IsBlkLen64Layout) { + // dst: | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 | + // load 64 weights at once, parse to get v0 - v31 if subblk % 2 == 0, otherwise get v32 - v63 + // at the end of subblk loop, increment b_data_ptr by 2 * subblk_data_size_in_bytes if subblk % 2 == 1 + // so that all v0-64 of the pack are dequantized. + const __m256i bvi = _mm256_loadu_si256((__m256i const*)(b_data_ptr + col_ * b_data_col_stride_in_bytes)); + weight_32_epi8[col_] = _mm256_and_si256(_mm256_srli_epi16(bvi, count_half_4), low_mask); + } else { + // dst: | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | + __m128i bvi = _mm_loadu_si128((__m128i const*)(b_data_ptr + col_ * b_data_col_stride_in_bytes)); + __m128i lower = _mm_and_si128(bvi, _mm256_castsi256_si128(low_mask)); + __m128i upper = _mm_and_si128(_mm_srli_epi16(bvi, 4), _mm256_castsi256_si128(low_mask)); + weight_32_epi8[col_] = _mm256_set_m128i(upper, lower); + } + + if (HasZeroPoint) { + std::byte zp_packed = *(zp_ptr + col_ * zp_col_stride_in_bytes); + uint8_t zp = std::to_integer(is_lower ? 
(zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); + weight_32_epi8[col_] = _mm256_sub_epi8(weight_32_epi8[col_], _mm256_set1_epi8(zp)); + } else { + const __m256i eight = _mm256_set1_epi8(8); + weight_32_epi8[col_] = _mm256_sub_epi8(weight_32_epi8[col_], eight); + } + + scale_8_ps[col_] = _mm256_set1_ps(*(scale_ptr + col_ * BlockCountK)); + } else { + weight_32_epi8[col_] = _mm256_setzero_si256(); + scale_8_ps[col_] = _mm256_setzero_ps(); + } + }); + for (int i_of_4 = 0; i_of_4 < 4; i_of_4++) { + __m256 weight_8_ps[8]; + for (size_t col_ = 0; col_ < 8; col_++) { + if (col_ < cols) { + if (i_of_4 == 0) { + __m256i weight_i_16_epi16 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(weight_32_epi8[col_], 0)); + __m256i weight_i_j_8_epi32 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(weight_i_16_epi16, 0)); + weight_8_ps[col_] = _mm256_mul_ps(_mm256_cvtepi32_ps(weight_i_j_8_epi32), scale_8_ps[col_]); + } else if (i_of_4 == 1) { + __m256i weight_i_16_epi16 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(weight_32_epi8[col_], 0)); + __m256i weight_i_j_8_epi32 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(weight_i_16_epi16, 1)); + weight_8_ps[col_] = _mm256_mul_ps(_mm256_cvtepi32_ps(weight_i_j_8_epi32), scale_8_ps[col_]); + } else if (i_of_4 == 2) { + __m256i weight_i_16_epi16 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(weight_32_epi8[col_], 1)); + __m256i weight_i_j_8_epi32 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(weight_i_16_epi16, 0)); + weight_8_ps[col_] = _mm256_mul_ps(_mm256_cvtepi32_ps(weight_i_j_8_epi32), scale_8_ps[col_]); + } else if (i_of_4 == 3) { + __m256i weight_i_16_epi16 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(weight_32_epi8[col_], 1)); + __m256i weight_i_j_8_epi32 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(weight_i_16_epi16, 1)); + weight_8_ps[col_] = _mm256_mul_ps(_mm256_cvtepi32_ps(weight_i_j_8_epi32), scale_8_ps[col_]); + } + } else { + weight_8_ps[col_] = _mm256_setzero_ps(); + } + } + // transpose and store + __m256 a0 = _mm256_unpacklo_ps(weight_8_ps[0], weight_8_ps[1]); + __m256 a1 = _mm256_unpackhi_ps(weight_8_ps[0], weight_8_ps[1]); + __m256 a2 = _mm256_unpacklo_ps(weight_8_ps[2], weight_8_ps[3]); + __m256 a3 = _mm256_unpackhi_ps(weight_8_ps[2], weight_8_ps[3]); + __m256 a4 = _mm256_unpacklo_ps(weight_8_ps[4], weight_8_ps[5]); + __m256 a5 = _mm256_unpackhi_ps(weight_8_ps[4], weight_8_ps[5]); + __m256 a6 = _mm256_unpacklo_ps(weight_8_ps[6], weight_8_ps[7]); + __m256 a7 = _mm256_unpackhi_ps(weight_8_ps[6], weight_8_ps[7]); + + __m256 b0 = _mm256_shuffle_ps(a0, a2, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 b1 = _mm256_shuffle_ps(a0, a2, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 b2 = _mm256_shuffle_ps(a1, a3, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 b3 = _mm256_shuffle_ps(a1, a3, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 b4 = _mm256_shuffle_ps(a4, a6, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 b5 = _mm256_shuffle_ps(a4, a6, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 b6 = _mm256_shuffle_ps(a5, a7, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 b7 = _mm256_shuffle_ps(a5, a7, _MM_SHUFFLE(3, 2, 3, 2)); + + const size_t ij_offset_in_k = i_of_4 * 8 * GemmFloatKernelWidth16; + __m256 weight_transposed_8_ps = _mm256_permute2f128_ps(b0, b4, 0x20); + _mm256_storeu_ps(dst_ptr + ij_offset_in_k + 0 * GemmFloatKernelWidth16, weight_transposed_8_ps); + weight_transposed_8_ps = _mm256_permute2f128_ps(b1, b5, 0x20); + _mm256_storeu_ps(dst_ptr + ij_offset_in_k + 1 * GemmFloatKernelWidth16, weight_transposed_8_ps); + weight_transposed_8_ps = _mm256_permute2f128_ps(b2, b6, 0x20); + _mm256_storeu_ps(dst_ptr + 
ij_offset_in_k + 2 * GemmFloatKernelWidth16, weight_transposed_8_ps); + weight_transposed_8_ps = _mm256_permute2f128_ps(b3, b7, 0x20); + _mm256_storeu_ps(dst_ptr + ij_offset_in_k + 3 * GemmFloatKernelWidth16, weight_transposed_8_ps); + weight_transposed_8_ps = _mm256_permute2f128_ps(b0, b4, 0x31); + _mm256_storeu_ps(dst_ptr + ij_offset_in_k + 4 * GemmFloatKernelWidth16, weight_transposed_8_ps); + weight_transposed_8_ps = _mm256_permute2f128_ps(b1, b5, 0x31); + _mm256_storeu_ps(dst_ptr + ij_offset_in_k + 5 * GemmFloatKernelWidth16, weight_transposed_8_ps); + weight_transposed_8_ps = _mm256_permute2f128_ps(b2, b6, 0x31); + _mm256_storeu_ps(dst_ptr + ij_offset_in_k + 6 * GemmFloatKernelWidth16, weight_transposed_8_ps); + weight_transposed_8_ps = _mm256_permute2f128_ps(b3, b7, 0x31); + _mm256_storeu_ps(dst_ptr + ij_offset_in_k + 7 * GemmFloatKernelWidth16, weight_transposed_8_ps); + } + dst_ptr += SubblkLen32 * GemmFloatKernelWidth16; + if constexpr (IsBlkLen64Layout) { + b_data_ptr += (subblk % 2) * 2 * subblk_data_size_in_bytes; + } else { + b_data_ptr += subblk_data_size_in_bytes; + } + } // subblk + } + } +} + +MLAS_FORCEINLINE void +Q4BitBlkDequantBForSgemm_CompFp32_avx2( + const size_t BlkLen, + float* FpData, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + const size_t CountN, + const size_t CountK, + const size_t BlockStrideQuantB +) +{ + if (BlkLen == 16) { + Q4BitBlkDequantBForSgemmBlkLen16_CompFp32_avx2( + FpData, QuantBData, QuantBScale, QuantBZeroPoint, CountN, CountK, BlockStrideQuantB + ); + } else if (BlkLen == 32) { + Q4BitBlkDequantBForSgemmBlkLen32AndMore_CompFp32_avx2( + BlkLen, FpData, QuantBData, QuantBScale, QuantBZeroPoint, CountN, CountK, BlockStrideQuantB + ); + } else { + Q4BitBlkDequantBForSgemmBlkLen32AndMore_CompFp32_avx2( + BlkLen, FpData, QuantBData, QuantBScale, QuantBZeroPoint, CountN, CountK, BlockStrideQuantB + ); + } +} + +MLAS_FORCEINLINE +void +SQ4BitGemmM1Kernel_CompInt8_avx2( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias +) +{ + if (QuantBZeroPoint != nullptr) { + constexpr bool HasZeroPoint = true; + if (BlkLen == 16) { + SQ4BitGemmM1Kernel_BlkLen16_CompInt8_Impl( + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockStrideQuantB, + Bias + ); + } else if (BlkLen == 32) { + SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl>( + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + BlockStrideQuantB, + Bias + ); + } else { + SQ4BitGemmM1Kernel_BlkLen64Plus_CompInt8_Impl( + BlkLen, + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockStrideQuantB, + Bias + ); + } + } else { + constexpr bool HasZeroPoint = false; + if (BlkLen == 16) { + SQ4BitGemmM1Kernel_BlkLen16_CompInt8_Impl( + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockStrideQuantB, + Bias + ); + } else if (BlkLen == 32) { + SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl>( + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + BlockStrideQuantB, + Bias + ); + } else { + SQ4BitGemmM1Kernel_BlkLen64Plus_CompInt8_Impl( + BlkLen, + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockStrideQuantB, + Bias + ); + } + } +} + +template +MLAS_FORCEINLINE void +ComputeDotProducts_BlkLen16_CompFp32_avx2( + 
size_t BlkLen, + const float* ARowPtr, + const std::byte* QuantBDataColPtr, + const float* QuantBScaleColPtr, + const std::byte* QuantBZeroPointColPtr, + float* sum_ptr, + size_t CountK, + size_t StrideQuantBData, + size_t StrideQuantBScale, + size_t StrideQuantBZeroPoint, + const float* bias_ptr +) +{ + if constexpr (!HasZeroPoint) { + // Suppress unused variable warnings + (void)QuantBZeroPointColPtr; + (void)StrideQuantBZeroPoint; + } + + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t SubBlkLen16 = 16; + constexpr size_t SubBlkStep8 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, SubBlkLen16); + static_assert(SubBlkStep8 == 8); // 16 * 4 / 8 + + __m256 acc[NCols]; + UnrolledLoop([&](size_t i) { + acc[i] = _mm256_setzero_ps(); + }); + + const std::byte* b_blk_data_ptr = QuantBDataColPtr; + const float* s = QuantBScaleColPtr; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + // only used if HasZeroPoint == true + + for (size_t k = 0; k < CountK; k += BlkLen) { + size_t ck = std::min(CountK - k, BlkLen); + + float scale_v[NCols]; + UnrolledLoop([&](size_t i) { + scale_v[i] = *(s + StrideQuantBScale * i); + }); + + std::byte* b_blk_data_col_ptr[NCols]; + UnrolledLoop([&](size_t i) { + b_blk_data_col_ptr[i] = (std::byte*)(b_blk_data_ptr + StrideQuantBData * i); + }); + + [[maybe_unused]] uint8_t offset[NCols]; + // not ready for "Manual conversion to float" in neon yet. following neon to unpack to uint8_t. + if constexpr (HasZeroPoint) { + UnrolledLoop([&](size_t i) { + const std::byte zp_packed = + QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; + const std::byte zp = ((QuantBZeroPointIdx & 1) == 1) + ? (zp_packed >> 4) + : (zp_packed & std::byte{0x0F}); + offset[i] = std::to_integer(zp); + }); + } + + for (size_t kk = 0; kk < ck; kk += SubBlkLen16) { + int kklen = std::min((int)SubBlkLen16, (int)(ck - kk)); + + // Load A row vectors + int n_to_read = std::min(kklen, 8); + __m256 av_lo = load_float_n_avx2(ARowPtr + k + kk, n_to_read); + n_to_read = std::min(kklen - 8, 8); + __m256 av_hi = load_float_n_avx2(ARowPtr + k + kk + 8, n_to_read); + + UnrolledLoop([&](size_t i) { + // SubBlkLen = 16: | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF | + // SubBlkLen = 32: | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | + // Load B col vectors. 
get SubBlkLen(16) 4 bits quantized features from each column + __m128i bvi4 = _mm_loadl_epi64((__m128i const*)(b_blk_data_col_ptr[i])); + b_blk_data_col_ptr[i] += SubBlkStep8; + + // TODO: avoid _mm_set1_epi8 + //__m128i lower_mask_epi8 = _mm_cmpeq_epi16(bvi4, bvi4); // can use any __m128i + // lower_mask_epi8 = _mm_srli_epi16(lower_mask_epi8, 13); + // lower_mask_epi8 = _mm_packus_epi16(lower_mask_epi8, lower_mask_epi8); + __m128i lower_mask_epi8 = _mm_set1_epi8(0x0F); // Mask to isolate the lower 4 bits + + const __m128i lower = _mm_and_si128(bvi4, lower_mask_epi8); + const __m128i upper = _mm_bslli_si128(_mm_and_si128(_mm_srli_epi16(bvi4, 4), lower_mask_epi8), 8); + __m256i bv_epi16 = _mm256_cvtepi8_epi16(_mm_add_epi8(upper, lower)); // unpacked 16 weights of epi16 + + // Subtract zero-point from the integers + if constexpr (HasZeroPoint) { + // Subtract zero-point from the integers + __m256i zp = _mm256_set1_epi16(offset[i]); + bv_epi16 = _mm256_sub_epi16(bv_epi16, zp); + } else { + // Subtract 8 from the integers + const __m256i eight = _mm256_set1_epi16(8); + bv_epi16 = _mm256_sub_epi16(bv_epi16, eight); + } + + // Convert to 16 epi16 to 16 float32 + const __m128i bv_lo = _mm256_extractf128_si256(bv_epi16, 0); + const __m128i bv_hi = _mm256_extractf128_si256(bv_epi16, 1); + + __m256 bvf_lo = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(bv_lo)); + __m256 bvf_hi = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(bv_hi)); + + // multiply by scale + __m256 scale_ps = _mm256_set1_ps(scale_v[i]); + bvf_lo = _mm256_mul_ps(bvf_lo, scale_ps); + bvf_hi = _mm256_mul_ps(bvf_hi, scale_ps); + + // c[m,n] += a[m,k] * b[k,n] + acc[i] = _mm256_fmadd_ps(bvf_lo, av_lo, acc[i]); + acc[i] = _mm256_fmadd_ps(bvf_hi, av_hi, acc[i]); + }); + } // kk + + b_blk_data_ptr += MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + s++; + + if constexpr (HasZeroPoint) { + QuantBZeroPointIdx += 1; + } + } // k + + if constexpr (NCols == 4) { + __m128 acc_x = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (bias_ptr != nullptr) { + acc_x = _mm_add_ps(acc_x, _mm_loadu_ps(bias_ptr)); + } + _mm_storeu_ps(sum_ptr, acc_x); + } else { + UnrolledLoop([&](size_t i) { + __m128 vlow = _mm256_castps256_ps128(acc[i]); + __m128 vhigh = _mm256_extractf128_ps(acc[i], 1); // Extract high 128 bit + + // Add the two 128-bit vectors together + __m128 vsum = _mm_add_ps(vlow, vhigh); + // Horizontally add the elements of the resulting 128-bit vector + vsum = _mm_hadd_ps(vsum, vsum); + vsum = _mm_hadd_ps(vsum, vsum); + + _mm_store_ss(&sum_ptr[i], vsum); + sum_ptr[i] += bias_ptr == nullptr ? 
0.0f : bias_ptr[i]; + }); + } +} + +// TODO: flow MlasQ4GemmKernelBlkLen16Avx512f to improve perf +template +void +SQ4BitGemmM1Kernel_BlkLen16_CompFp32_avx2( + const float* A, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + + const float* ARowPtr = A; + float* CRowPtr = C; + + const size_t BlockCountK = BlockStrideQuantB; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + const float* BiasPtr = Bias; + + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + + float* SumPtr = CRowPtr; + + int64_t nblk = static_cast(CountN) - NCols4; + + while (nblk >= 0) { + ComputeDotProducts_BlkLen16_CompFp32_avx2( + BlkLen16, + ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + BiasPtr + ); + + // move to next `NCols` columns + + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += NCols4 * StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + + nblk -= NCols4; + } + + // left over columns less than `NCols`? + nblk += NCols4; + for (int64_t n = 0; n < nblk; ++n) { + ComputeDotProducts_BlkLen16_CompFp32_avx2<1, HasZeroPoint>( + BlkLen16, + ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + BiasPtr + ); + + // move to next column + + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } +} + +// TODO: flow MlasQ4GemmKernelBlkLen32PlusAvx512f to improve perf +template +MLAS_FORCEINLINE void +ComputeDotProducts_BlkLen32Plus_CompFp32_avx2( + size_t BlkLen, + const float* ARowPtr, + const std::byte* QuantBDataColPtr, + const float* QuantBScaleColPtr, + const std::byte* QuantBZeroPointColPtr, + float* sum_ptr, + size_t CountK, + size_t StrideQuantBData, + size_t StrideQuantBScale, + size_t StrideQuantBZeroPoint, + const float* bias_ptr +) +{ + if constexpr (!HasZeroPoint) { + // Suppress unused variable warnings + (void)QuantBZeroPointColPtr; + (void)StrideQuantBZeroPoint; + } + + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t SubBlkLen32 = 32; + constexpr size_t SubBlkStep16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, SubBlkLen32); + static_assert(SubBlkStep16 == 16); // 32 * 4 / 8 + + __m256i lowMask = _mm256_set1_epi8(0x0F); + + __m256 acc[NCols]; + UnrolledLoop([&](size_t i) { + acc[i] = _mm256_setzero_ps(); + }); + + const std::byte* b_blk_data_ptr = QuantBDataColPtr; + const float* s = QuantBScaleColPtr; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + [[maybe_unused]] int count_half_4 = 0; + // only used if HasZeroPoint == true + + for (size_t k = 0; k < CountK; k += BlkLen) { + size_t ck = std::min(CountK - k, BlkLen); + + float scale_v[NCols]; + UnrolledLoop([&](size_t i) { + scale_v[i] = *(s + StrideQuantBScale * i); + }); + + std::byte* b_blk_data_col_ptr[NCols]; + UnrolledLoop([&](size_t i) { + b_blk_data_col_ptr[i] = (std::byte*)(b_blk_data_ptr + StrideQuantBData * i); + }); + + [[maybe_unused]] uint8_t offset[NCols]; + // not ready for "Manual conversion to float" in neon yet. + if constexpr (HasZeroPoint) { + UnrolledLoop([&](size_t i) { + const std::byte zp_packed = + QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; + const std::byte zp = ((QuantBZeroPointIdx & 1) == 1) + ? (zp_packed >> 4) + : (zp_packed & std::byte{0x0F}); + offset[i] = std::to_integer(zp); + }); + } + + for (size_t kk = 0; kk < ck; kk += SubBlkLen32) { + int kklen = std::min((int)SubBlkLen32, (int)(ck - kk)); + + // Load 4 float8 from A + int n_to_read = std::min(kklen, 8); + __m256 av0_8_ps = load_float_n_avx2(ARowPtr + k + kk, n_to_read); + + n_to_read = std::min(kklen - 8, 8); + __m256 av1_8_ps = load_float_n_avx2(ARowPtr + k + kk + 8, n_to_read); + + n_to_read = std::min(kklen - 16, 8); + __m256 av2_8_ps = load_float_n_avx2(ARowPtr + k + kk + 16, n_to_read); + + n_to_read = std::min(kklen - 24, 8); + __m256 av3_8_ps = load_float_n_avx2(ARowPtr + k + kk + 24, n_to_read); + + if constexpr (IsBlkLen64Layout) { + count_half_4 = 4 * (int)((kk % (2 * SubBlkLen32)) / SubBlkLen32); + } + UnrolledLoop([&](size_t i) { + // Load B col vectors. get SubBlkLen32 4b quantized weights from each column + __m256i bv_32_epi8; + if constexpr (IsBlkLen64Layout) { + // dst: | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 | + // load 64 weights at once, parse to get v0 - v31 if subblk % 2 == 0, otherwise get v32 - v63 + // increment b_data_ptr by 2 * SubBlkStep16 if kk % (2 * SubBlkLen32) == 1 + // so that all v0-63 of the pack are processed. + const __m256i bvi4 = _mm256_loadu_si256((__m256i const*)(b_blk_data_col_ptr[i])); + bv_32_epi8 = _mm256_and_si256(_mm256_srli_epi16(bvi4, count_half_4), lowMask); + b_blk_data_col_ptr[i] += count_half_4 / 2 * SubBlkStep16; + } else { + // SubBlkLen = 32: | v0 v16 | v1 v17 | ... 
| v14 v30 | v15 v31 | + __m128i bvi4 = _mm_loadu_si128((const __m128i*)(b_blk_data_col_ptr[i])); + b_blk_data_col_ptr[i] += SubBlkStep16; + + bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bvi4, 4), bvi4); + bv_32_epi8 = _mm256_and_si256(lowMask, bv_32_epi8); + } + + // Subtract zero-point from the integers + if constexpr (HasZeroPoint) { + // Subtract zero-point from the integers + __m256i zp = _mm256_set1_epi8(offset[i]); + bv_32_epi8 = _mm256_sub_epi8(bv_32_epi8, zp); + } else { + // Subtract 8 from the integers + const __m256i eight = _mm256_set1_epi8(8); + bv_32_epi8 = _mm256_sub_epi8(bv_32_epi8, eight); + } + + // Convert to 16 float32 + const __m256i bv0_16_epi16 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bv_32_epi8, 0)); + const __m256i bv1_16_epi16 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bv_32_epi8, 1)); + + __m256 bv0_8_ps = + _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(bv0_16_epi16, 0))); + __m256 bv1_8_ps = + _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(bv0_16_epi16, 1))); + __m256 bv2_8_ps = + _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(bv1_16_epi16, 0))); + __m256 bv3_8_ps = + _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(bv1_16_epi16, 1))); + + // multiply by scale + __m256 scale_ps = _mm256_set1_ps(scale_v[i]); + bv0_8_ps = _mm256_mul_ps(bv0_8_ps, scale_ps); + bv1_8_ps = _mm256_mul_ps(bv1_8_ps, scale_ps); + bv2_8_ps = _mm256_mul_ps(bv2_8_ps, scale_ps); + bv3_8_ps = _mm256_mul_ps(bv3_8_ps, scale_ps); + + // c[m,n] += a[m,k] * b[k,n] + acc[i] = _mm256_fmadd_ps(bv0_8_ps, av0_8_ps, acc[i]); + acc[i] = _mm256_fmadd_ps(bv1_8_ps, av1_8_ps, acc[i]); + acc[i] = _mm256_fmadd_ps(bv2_8_ps, av2_8_ps, acc[i]); + acc[i] = _mm256_fmadd_ps(bv3_8_ps, av3_8_ps, acc[i]); + }); + } // kk + + b_blk_data_ptr += MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + s++; + + if constexpr (HasZeroPoint) { + QuantBZeroPointIdx += 1; + } + } // k + + if constexpr (NCols == 4) { + __m128 acc_x = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (bias_ptr != nullptr) { + acc_x = _mm_add_ps(acc_x, _mm_loadu_ps(bias_ptr)); + } + _mm_storeu_ps(sum_ptr, acc_x); + } else { + UnrolledLoop([&](size_t i) { + __m128 vlow = _mm256_castps256_ps128(acc[i]); + __m128 vhigh = _mm256_extractf128_ps(acc[i], 1); // Extract high 128 bit + + // Add the two 128-bit vectors together + __m128 vsum = _mm_add_ps(vlow, vhigh); + // Horizontally add the elements of the resulting 128-bit vector + vsum = _mm_hadd_ps(vsum, vsum); + vsum = _mm_hadd_ps(vsum, vsum); + + _mm_store_ss(&sum_ptr[i], vsum); + sum_ptr[i] += bias_ptr == nullptr ? 
0.0f : bias_ptr[i]; + }); + } +} + +// TODO: flow MlasQ4GemmKernelBlkLen16Avx512f to improve perf +template +void +SQ4BitGemmM1Kernel_BlkLen32Plus_CompFp32_avx2( + size_t BlkLen, + const float* A, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias +) +{ + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + + const float* ARowPtr = A; + float* CRowPtr = C; + + const size_t BlockCountK = BlockStrideQuantB; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + const float* BiasPtr = Bias; + + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + + float* SumPtr = CRowPtr; + + int64_t nblk = static_cast(CountN) - NCols4; + while (nblk >= 0) { + if (BlkLen >= 64) { + ComputeDotProducts_BlkLen32Plus_CompFp32_avx2( + BlkLen, + ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + BiasPtr + ); + } else { + ComputeDotProducts_BlkLen32Plus_CompFp32_avx2( + BlkLen, + ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + BiasPtr + ); + } + + // move to next `NCols` columns + + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += NCols4 * StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + + nblk -= NCols4; + } + + // left over columns less than `NCols`? + nblk += NCols4; + for (int64_t n = 0; n < nblk; ++n) { + if (BlkLen >= 64) { + ComputeDotProducts_BlkLen32Plus_CompFp32_avx2<1, HasZeroPoint, true>( + BlkLen, + ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + BiasPtr + ); + } else { + ComputeDotProducts_BlkLen32Plus_CompFp32_avx2<1, HasZeroPoint, false>( + BlkLen, + ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + BiasPtr + ); + } + + // move to next column + + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } +} + +MLAS_FORCEINLINE void +SQ4BitGemmM1Kernel_CompFp32_avx2( + size_t BlkLen, + const float* A, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias +) +{ + if (BlkLen == 16) { + if (QuantBZeroPoint != nullptr) { + SQ4BitGemmM1Kernel_BlkLen16_CompFp32_avx2( + A, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockStrideQuantB, + Bias + ); + } else { + SQ4BitGemmM1Kernel_BlkLen16_CompFp32_avx2( + A, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockStrideQuantB, + Bias + ); + } + } else { + if (QuantBZeroPoint != nullptr) { + SQ4BitGemmM1Kernel_BlkLen32Plus_CompFp32_avx2( + BlkLen, + A, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockStrideQuantB, + Bias + ); + } else { + SQ4BitGemmM1Kernel_BlkLen32Plus_CompFp32_avx2( + BlkLen, + A, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockStrideQuantB, + Bias + ); + } + } +} + +MLAS_FORCEINLINE __m128i +convert_2_ps_to_epi8(__m256 v0, __m256 v1) +{ + __m256i v0_8_epi32 = _mm256_cvtps_epi32(v0); + __m256i v1_8_epi32 = _mm256_cvtps_epi32(v1); + + __m128i v0_8_epi16 = _mm_packs_epi32(_mm256_extractf128_si256(v0_8_epi32, 0), _mm256_extractf128_si256(v0_8_epi32, 1)); + __m128i v1_8_epi16 = _mm_packs_epi32(_mm256_extractf128_si256(v1_8_epi32, 0), _mm256_extractf128_si256(v1_8_epi32, 1)); + + return _mm_packs_epi16(v0_8_epi16, v1_8_epi16); +} + +void MLASCALL +QuantizeARow_CompInt8_avx2( + size_t BlkLen, + const float* A, + size_t CountK, + std::byte* QuantA +) +{ + // port from MlasQ80BlkQuantRow + assert(BlkLen % 16 == 0); + const __m256 signBit = _mm256_set1_ps(-0.0f); + int8_t* blob = reinterpret_cast(QuantA); + for (size_t k = 0; k < CountK; k += BlkLen) { + const size_t step = std::min(BlkLen, CountK - k); + + __m256 maxAbs = _mm256_setzero_ps(); + for (size_t kk = 0; kk < step; kk += 8) { + const int klen = std::min(8, (int)(step - kk)); + + __m256 v0 = load_float_n_avx2(A + k + kk, klen); + + // Compute max(abs(e)) for the block + maxAbs = _mm256_max_ps(maxAbs, _mm256_andnot_ps(signBit, v0)); + } + + __m128 max4 = _mm_max_ps(_mm256_extractf128_ps(maxAbs, 1), _mm256_castps256_ps128(maxAbs)); + max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4)); + max4 = _mm_max_ss(max4, _mm_shuffle_ps(max4, max4, 1)); + const float maxScalar = _mm_cvtss_f32(max4); + + // Quantize these floats + const float scale = maxScalar / 127.f; + *reinterpret_cast(blob) = scale; + blob += sizeof(float); + + const float inverse_scale = (maxScalar != 0.0f) ? 
127.f / maxScalar : 0.0f; + const __m256 mul = _mm256_set1_ps(inverse_scale); + __m128i* dst = reinterpret_cast<__m128i*>(blob); + + for (size_t kk = 0; kk < step; kk += 16) { + const int klen = std::min(16, (int)(step - kk)); + + int n_to_read = std::min(klen, 8); + __m256 v0 = load_float_n_avx2(A + k + kk, n_to_read); + v0 = _mm256_mul_ps(v0, mul); + v0 = _mm256_round_ps(v0, _MM_ROUND_NEAREST); + + __m256 v1; + n_to_read = std::min(klen - 8, 8); + if (n_to_read <= 0) { + v1 = _mm256_setzero_ps(); + } else { + v1 = load_float_n_avx2(A + k + kk + 8, n_to_read); + v1 = _mm256_mul_ps(v1, mul); + v1 = _mm256_round_ps(v1, _MM_ROUND_NEAREST); + } + + __m128i i_8 = convert_2_ps_to_epi8(v0, v1); + _mm_storeu_si128(dst++, i_8); + } + if (step < BlkLen) { + memset(blob + step, 0, BlkLen - step); + } + blob += BlkLen; + } +} + +// +// Kernel dispatch structure definition. +// +const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() { + MLAS_SQNBIT_GEMM_DISPATCH d; + + d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; + d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; + + d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx2; + d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; + + d.SQ4BitGemmM1Kernel_CompInt8 = SQ4BitGemmM1Kernel_CompInt8_avx2; + d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8_avx2; + + return d; +}(); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp new file mode 100644 index 0000000000000..1eca0960cf670 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp @@ -0,0 +1,243 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + sqnbitgemm_kernel_avx512.cpp.h + +Abstract: + + This module implements the float/quantized n-bit integer matrix + multiplication kernels for x64 avx512. + +--*/ + +#include +#include +#include + +#include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_avx_common.h" +#include "sqnbitgemm_kernel_avx_common_int8.h" + +// +// CompFp32 kernel implementation. 
+// + +#include "sqnbitgemm_kernel_avx_common_fp32.h" + +MLAS_FORCEINLINE void +SQ4BitGemmM1Kernel_CompFp32_avx512( + size_t BlkLen, + const float* A, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias +) +{ + if (BlkLen == 16) { + if (QuantBZeroPoint != nullptr) { + MlasQ4GemmKernelBlkLen16Avx512f( + A, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + 1, + CountN, + CountK, + BlockStrideQuantB, + Bias, + 0, + 0 + ); + } else { + MlasQ4GemmKernelBlkLen16Avx512f( + A, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + 1, + CountN, + CountK, + BlockStrideQuantB, + Bias, + 0, + 0 + ); + } + } else if (BlkLen == 32) { + if (QuantBZeroPoint != nullptr) { + MlasQ4GemmKernelBlkLen32PlusAvx512f( + BlkLen, + A, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + 1, + CountN, + CountK, + BlockStrideQuantB, + Bias, + 0, + 0 + ); + } else { + MlasQ4GemmKernelBlkLen32PlusAvx512f( + BlkLen, + A, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + 1, + CountN, + CountK, + BlockStrideQuantB, + Bias, + 0, + 0 + ); + } + } else /*if (BlkLen >= 64)*/ { + if (QuantBZeroPoint != nullptr) { + MlasQ4GemmKernelBlkLen32PlusAvx512f( + BlkLen, + A, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + 1, + CountN, + CountK, + BlockStrideQuantB, + Bias, + 0, + 0 + ); + } else { + MlasQ4GemmKernelBlkLen32PlusAvx512f( + BlkLen, + A, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + 1, + CountN, + CountK, + BlockStrideQuantB, + Bias, + 0, + 0 + ); + } + } +} + +// +// CompInt8 kernel implementation. +// + +void MLASCALL +MlasQ80BlkQuantRow_avx512( + size_t BlkLen, + const float* A, + size_t CountK, + std::byte* QuantA +) +{ + // port from MlasQ80BlkQuantRow + assert(BlkLen % 16 == 0); + const __m512 signBit = _mm512_set1_ps(-0.0f); + int8_t* blob = reinterpret_cast(QuantA); + for (size_t k = 0; k < CountK; k += BlkLen) { + const size_t step = std::min(BlkLen, CountK - k); + + __m512 maxAbs = _mm512_setzero_ps(); + for (size_t kk = 0; kk < step; kk += 16) { + const size_t klen = std::min(size_t(16), step - kk); + + uint32_t mask = 0xffff >> (16 - klen); + __m512 v0 = _mm512_maskz_loadu_ps(__mmask16(mask), A + k + kk); + + // Compute max(abs(e)) for the block + maxAbs = _mm512_max_ps(maxAbs, _mm512_andnot_ps(signBit, v0)); + } + + __m256 max8 = + _mm256_max_ps(_mm512_extractf32x8_ps(maxAbs, 1), _mm512_extractf32x8_ps(maxAbs, 0)); + __m128 max4 = _mm_max_ps(_mm256_extractf128_ps(max8, 1), _mm256_castps256_ps128(max8)); + max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4)); + max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4)); + const float maxScalar = _mm_cvtss_f32(max4); + + // Quantize these floats + const float scale = maxScalar / 127.f; + *reinterpret_cast(blob) = scale; + blob += sizeof(float); + + const float inverse_scale = (maxScalar != 0.0f) ? 
127.f / maxScalar : 0.0f; + const __m512 mul = _mm512_set1_ps(inverse_scale); + __m128i* dst = reinterpret_cast<__m128i*>(blob); + + for (size_t kk = 0; kk < step; kk += 16) { + const size_t klen = std::min(size_t(16), step - kk); + + uint32_t mask = 0xffff >> (16 - klen); + __m512 v0 = _mm512_maskz_loadu_ps(__mmask16(mask), A + k + kk); + v0 = _mm512_mul_ps(v0, mul); + + // Round to nearest integer + v0 = _mm512_roundscale_ps(v0, _MM_ROUND_NEAREST); + + // Convert floats to integers + __m512i i0 = _mm512_cvtps_epi32(v0); + + // Convert int32 to int8 + __m128i i0_8 = _mm512_cvtepi32_epi8(i0); + _mm_storeu_si128(dst++, i0_8); + } + if (step < BlkLen) { + memset(blob + step, 0, BlkLen - step); + } + blob += BlkLen; + } +} + +void MLASCALL +QuantizeARow_CompInt8_avx512( + size_t BlkLen, + const float* A, + size_t CountK, + std::byte* QuantA +) +{ + MlasQ80BlkQuantRow_avx512(BlkLen, A, CountK, QuantA); +} + +const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() { + MLAS_SQNBIT_GEMM_DISPATCH d; + + d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; + d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; + + d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx512; + d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; + + d.SQ4BitGemmM1Kernel_CompInt8 = SQ4BitGemmM1Kernel_CompInt8_avx2; + d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8_avx512; + + return d; +}(); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp new file mode 100644 index 0000000000000..45a69c4f20603 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp @@ -0,0 +1,264 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + sqnbitgemm_kernel_avx512.cpp.h + +Abstract: + + This module implements the float/quantized n-bit integer matrix + multiplication kernels for x64 avx512vnni. 
+ +--*/ + +#include +#include +#include + +#include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_avx_common.h" +#include "sqnbitgemm_kernel_avx_common_fp32.h" +#include "sqnbitgemm_kernel_avx_common_int8.h" + +MLAS_FORCEINLINE void +SQ4BitGemmM1Kernel_CompFp32( + size_t BlkLen, + const float* A, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias +) +{ + if (BlkLen == 16) { + if (QuantBZeroPoint != nullptr) { + MlasQ4GemmKernelBlkLen16Avx512f( + A, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + 1, + CountN, + CountK, + BlockStrideQuantB, + Bias, + 0, + 0 + ); + } else { + MlasQ4GemmKernelBlkLen16Avx512f( + A, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + 1, + CountN, + CountK, + BlockStrideQuantB, + Bias, + 0, + 0 + ); + } + } else if (BlkLen == 32) { + if (QuantBZeroPoint != nullptr) { + MlasQ4GemmKernelBlkLen32PlusAvx512f( + BlkLen, + A, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + 1, + CountN, + CountK, + BlockStrideQuantB, + Bias, + 0, + 0 + ); + } else { + MlasQ4GemmKernelBlkLen32PlusAvx512f( + BlkLen, + A, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + 1, + CountN, + CountK, + BlockStrideQuantB, + Bias, + 0, + 0 + ); + } + } else /*if (BlkLen >= 64)*/ { + if (QuantBZeroPoint != nullptr) { + MlasQ4GemmKernelBlkLen32PlusAvx512f( + BlkLen, + A, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + 1, + CountN, + CountK, + BlockStrideQuantB, + Bias, + 0, + 0 + ); + } else { + MlasQ4GemmKernelBlkLen32PlusAvx512f( + BlkLen, + A, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + 1, + CountN, + CountK, + BlockStrideQuantB, + Bias, + 0, + 0 + ); + } + } +} + +MLAS_FORCEINLINE +void +SQ4BitGemmM1Kernel_CompInt8_avx512vnni( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias +) +{ + if (QuantBZeroPoint != nullptr) { + constexpr bool HasZeroPoint = true; + if (BlkLen == 16) { + SQ4BitGemmM1Kernel_BlkLen16_CompInt8_Impl( + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockStrideQuantB, + Bias + ); + } else if (BlkLen == 32) { + SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl>( + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + BlockStrideQuantB, + Bias + ); + } else { + SQ4BitGemmM1Kernel_BlkLen64Plus_CompInt8_Impl( + BlkLen, + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockStrideQuantB, + Bias + ); + } + } else { + constexpr bool HasZeroPoint = false; + if (BlkLen == 16) { + SQ4BitGemmM1Kernel_BlkLen16_CompInt8_Impl( + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockStrideQuantB, + Bias + ); + } else if (BlkLen == 32) { + SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl>( + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + BlockStrideQuantB, + Bias + ); + } else { + SQ4BitGemmM1Kernel_BlkLen64Plus_CompInt8_Impl( + BlkLen, + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockStrideQuantB, + Bias + ); + } + } +} + +void MLASCALL +MlasQ80BlkQuantRow_avx512( + size_t BlkLen, + const float* A, + size_t CountK, + std::byte* QuantA +); + +// +// Kernel dispatch structure definition. 
+// +const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni = []() { + MLAS_SQNBIT_GEMM_DISPATCH d; + + d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; + d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; + + d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32; + d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; + + d.SQ4BitGemmM1Kernel_CompInt8 = SQ4BitGemmM1Kernel_CompInt8_avx512vnni; + d.QuantizeARow_CompInt8 = MlasQ80BlkQuantRow_avx512; + + return d; +}(); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h new file mode 100644 index 0000000000000..abace949a1c5d --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h @@ -0,0 +1,371 @@ +#pragma once +#include "sqnbitgemm.h" + +// +// Quantized B data packing function implementation. +// + +static size_t +SQ4BitGemmPackQuantBDataSize( + size_t N, + size_t K, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType +) +{ + MLAS_UNREFERENCED_PARAMETER(ComputeType); // same size regardless of ComputeType + + constexpr size_t BlkBitWidth = 4; + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + return PackedQuantBDataSize; +} + +static void +SQ4BitGemmPackQuantBData( + size_t N, + size_t K, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE /* ComputeType*/, + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool +) +{ + constexpr size_t BlkBitWidth = 4; + + assert(BlkLen >= 16 && BlkLen % 16 == 0); + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + const size_t BlkDataSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t Iterations = N * BlockCountK; // one iteration per block + + size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 32 : 64); + + const size_t SubBlkDataSize = SubBlkLen / 2; + const size_t SubBlkBytePairCount = SubBlkLen / 4; + + // + // For SubBlkLen == 16, pack 16 4-bit values (8 bytes) at a time like this: + // + // src: | v0 v1 | v2 v3 | v4 v5 | v6 v7 | v8 v9 | vA vB | vC vD | vE vF | + // => + // dst: | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF | + // + + // + // For SubBlkLen == 32, pack 32 4-bit values (16 bytes) at a time like this: + // + // src: | v0 v1 | v2 v3 | ... | v28 v29 | v30 v31 | + // => + // dst: | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | + // + + // + // For SubBlkLen == 64, pack 32 4-bit values (16 bytes) at a time like this: + // + // src: | v0 v1 | v2 v3 | ... | v28 v29 | v30 v31 | v32 v33 | v34 v33 | + // => + // dst: | v0 v32 | v1 v33 | ... 
| v30 v62 | v31 v63 | + // + + MlasTrySimpleParallel( + ThreadPool, Iterations, + [&](ptrdiff_t tid) { + const size_t n = tid / BlockCountK; + const size_t k_blk = tid % BlockCountK; + + const size_t data_offset = n * BlockCountK * BlkDataSize + k_blk * BlkDataSize; + const std::byte* QuantBData = QuantBDataBegin + data_offset; + std::byte* PackedQuantBData = PackedQuantBDataBegin + data_offset; + + for (size_t kk = 0; kk < BlkLen; kk += SubBlkLen) { + for (size_t byte_pair_idx = 0; byte_pair_idx < SubBlkBytePairCount; ++byte_pair_idx) { + const std::byte src0 = QuantBData[byte_pair_idx]; + const std::byte src1 = QuantBData[byte_pair_idx + SubBlkDataSize / 2]; + + std::byte& dst0 = PackedQuantBData[2 * byte_pair_idx]; + std::byte& dst1 = PackedQuantBData[2 * byte_pair_idx + 1]; + + dst0 = (src0 & std::byte{0x0F}) | ((src1 & std::byte{0x0F}) << 4); + dst1 = (src0 >> 4) | ((src1 >> 4) << 4); + } + + QuantBData += SubBlkDataSize; + PackedQuantBData += SubBlkDataSize; + } + } + ); +} + +void +Q4BitBlkDequantBForSgemm_CompFp32_avx2( + const size_t BlkLen, + float* FpData, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + const size_t CountN, + const size_t CountK, + const size_t BlockStrideQuantB +); + +void +SQ4BitGemmM1Kernel_CompInt8_avx2( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias +); + +// +// General helpers. +// + +namespace +{ + +template +MLAS_FORCEINLINE void +UnrolledLoopIterations(IterationFn&& f, std::index_sequence /* indices */) +{ + (f(Indices), ...); +} + +template +MLAS_FORCEINLINE void +UnrolledLoop(IterationFn&& f) +{ + UnrolledLoopIterations(std::forward(f), std::make_index_sequence()); +} + +// this function is used to dot product 2 pairs of 32 epi8s. it is used with Int8 precision +// and blklen >= 64. In this case, 64 of 4b weights are filled with one load. 
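+// For reference, a scalar sketch of what the two dot_quad_* variants below compute (illustrative
+// only; this helper is not part of the kernels and the name is hypothetical). Both variants rely on
+// the same sign trick: |b| is fed as the unsigned operand and a (with b's sign applied) as the signed
+// operand, so the unsigned*signed instructions (_mm256_dpbusd_epi32 / _mm256_maddubs_epi16) yield a[i] * b[i].
+//
+// static int32_t
+// dot_quad_scalar_reference(const int8_t* a, const int8_t* b, size_t count /* 64 */)
+// {
+//     int32_t sum = 0;
+//     for (size_t i = 0; i < count; ++i) {
+//         const int32_t b_abs = b[i] < 0 ? -b[i] : b[i];                    // unsigned operand
+//         const int32_t a_adj = b[i] == 0 ? 0 : (b[i] < 0 ? -a[i] : a[i]);  // a with sign(b) applied
+//         sum += b_abs * a_adj;                                             // == a[i] * b[i]
+//     }
+//     return sum;
+// }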
+static MLAS_FORCEINLINE __m256
+dot_quad_avx512vnni(
+    const __m256i bv0_32_epi8, const __m256i bv1_32_epi8, const __m256i av0_32_epi8, const __m256i av1_32_epi8
+)
+{
+    const __m256i zero = _mm256_setzero_si256();
+    __m256i sum_8_epi32 = _mm256_dpbusd_epi32(zero, _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8));
+    sum_8_epi32 = _mm256_dpbusd_epi32(sum_8_epi32, _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8));
+    return _mm256_cvtepi32_ps(sum_8_epi32);
+}
+
+static MLAS_FORCEINLINE __m256
+dot_quad_avx2(
+    const __m256i b0, const __m256i b1, const __m256i a0, const __m256i a1
+)
+{
+    // Perform multiplication and create 16-bit values
+    const __m256i ones = _mm256_set1_epi16(1);
+    __m256i sum_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(b0, b0), _mm256_sign_epi8(a0, b0));
+    __m256i summed_pair_epi32 = _mm256_madd_epi16(ones, sum_epi16);
+
+    sum_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(b1, b1), _mm256_sign_epi8(a1, b1));
+    summed_pair_epi32 = _mm256_add_epi32(_mm256_madd_epi16(ones, sum_epi16), summed_pair_epi32);
+    return _mm256_cvtepi32_ps(summed_pair_epi32);
+}
+
+// TODO: refactor load_and_mul_sum_s8_quads_with_zp_avx512vnni, load_and_mul_sum_s8_quads_with_zp_avx2
+// and accumulate_mul_sum_avx512vnni, accumulate_mul_sum_avx2
+static MLAS_FORCEINLINE void
+load_and_mul_sum_s8_quads_with_zp_avx512vnni(
+    const __m256i av_0_epi8, const __m128i* QuantBDataPtr, const __m128i low_mask, const __m256i zero, const int8_t zp, const __m256 scale0, __m256& acc0
+)
+{
+    // load B
+    // | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 |
+    // | v32 v48 | v33 v49 | ... | v46 v62 | v47 v63 |
+    const __m128i bv_packed0 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(QuantBDataPtr));
+
+    // Surprisingly, unpacking bv_packed0 with __m128i here is 2-3% faster than the commented-out
+    // __m256i block below. Also, passing in low_mask is 2% faster than creating it here.
+    // const __m128i low_mask = _mm_set1_epi8(15);
+    const __m128i bv_lo0 = _mm_and_si128(bv_packed0, low_mask);                     // 0, 1, 2, 3,...
+    const __m128i bv_hi0 = _mm_and_si128(_mm_srli_epi16(bv_packed0, 4), low_mask);  // 16, 17, 18, 19,...
+    __m256i bv_0_epi8 = _mm256_set_m128i(bv_hi0, bv_lo0);
+
+    //__m256i bv_0_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0);
+    // const __m256i low_mask = _mm256_set1_epi8(15);
+    // bv_0_epi8 = _mm256_and_si256(low_mask, bv_0_epi8);
+
+    const __m256i bzp0 = _mm256_set1_epi8(zp);
+    bv_0_epi8 = _mm256_sub_epi8(bv_0_epi8, bzp0);
+    // quantized dot product
+    __m256i dot_0_epi32 = _mm256_dpbusd_epi32(
+        zero, _mm256_sign_epi8(bv_0_epi8, bv_0_epi8), _mm256_sign_epi8(av_0_epi8, bv_0_epi8)
+    );
+    const __m256 sum_ps = _mm256_cvtepi32_ps(dot_0_epi32);
+    acc0 = _mm256_fmadd_ps(sum_ps, scale0, acc0);
+}
+
+static MLAS_FORCEINLINE void
+load_and_mul_sum_s8_quads_with_zp_avx2(
+    const __m256i av_0_epi8, const __m128i* QuantBDataPtr, const __m128i low_mask, const __m256i, const int8_t zp, const __m256 scale0, __m256& acc0
+)
+{
+    // load B
+    // | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 |
+    // | v32 v48 | v33 v49 | ... | v46 v62 | v47 v63 |
+    const __m128i bv_packed0 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(QuantBDataPtr));
+
+    // Surprisingly, unpacking bv_packed0 with __m128i here is 2-3% faster than the commented-out
+    // __m256i block below. Also, passing in low_mask is 2% faster than creating it here.
+    // const __m128i low_mask = _mm_set1_epi8(15);
+    const __m128i bv_lo0 = _mm_and_si128(bv_packed0, low_mask);                     // 0, 1, 2, 3,...
+ const __m128i bv_hi0 = _mm_and_si128(_mm_srli_epi16(bv_packed0, 4), low_mask); // 16, 17, 18, 19,... + __m256i bv_0_epi8 = _mm256_set_m128i(bv_hi0, bv_lo0); + + //__m256i bv_0_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); + // const __m256i low_mask = _mm256_set1_epi8(15); + // bv_0_epi8 = _mm256_and_si256(low_mask, bv_0_epi8); + + const __m256i bzp0 = _mm256_set1_epi8(zp); + bv_0_epi8 = _mm256_sub_epi8(bv_0_epi8, bzp0); + // quantized dot product + __m256i dot_16_epi16 = _mm256_maddubs_epi16( + _mm256_sign_epi8(bv_0_epi8, bv_0_epi8), _mm256_sign_epi8(av_0_epi8, bv_0_epi8) + ); + __m256i sum_8_epi32 = _mm256_madd_epi16(_mm256_set1_epi16(1), dot_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + acc0 = _mm256_fmadd_ps(sum_ps, scale0, acc0); +} + +template +int8_t MLAS_FORCEINLINE +get_zp(bool is_lower_half_byte_zp, const std::byte* QuantBZeroPointPtr) +{ + if constexpr (!HasZeroPoint) { + // Suppress unused variable warnings + (void)QuantBZeroPointPtr; + } + + if constexpr (HasZeroPoint) { + return is_lower_half_byte_zp ? std::to_integer((*QuantBZeroPointPtr) & std::byte{0x0F}) : std::to_integer((*QuantBZeroPointPtr) >> 4); + } else { + return 8; + } +} + +// this function load and unpack 32 4b weights (packed for BlkLen32) and dot product it with 32 +// epi8 input. dot products are accumulated into acc0. +// This function is called for Int8 precision with BlkLen = 32. +template +using AccumulateFunctionType = void (*)( + const __m256i, const __m128i*, const __m128i, const __m256i, const std::byte*, bool, const float, __m256& +); + +template +static MLAS_FORCEINLINE void +accumulate_mul_sum_avx512vnni( + const __m256i av_0_epi8, const __m128i* QuantBDataPtr, const __m128i low_mask, const __m256i zero, const std::byte* QuantBZeroPointPtr, bool is_lower_half_byte_zp, const float combined_scale, __m256& acc0 +) +{ + const __m256 scale0 = _mm256_set1_ps(combined_scale); + const int8_t zp = get_zp(is_lower_half_byte_zp, QuantBZeroPointPtr); + load_and_mul_sum_s8_quads_with_zp_avx512vnni( + av_0_epi8, reinterpret_cast(QuantBDataPtr), + low_mask, zero, + zp, scale0, acc0 + ); +} + +template +static MLAS_FORCEINLINE void +accumulate_mul_sum_avx2( + const __m256i av_0_epi8, const __m128i* QuantBDataPtr, const __m128i low_mask, const __m256i zero, const std::byte* QuantBZeroPointPtr, bool is_lower_half_byte_zp, const float combined_scale, __m256& acc0 +) +{ + const __m256 scale0 = _mm256_set1_ps(combined_scale); + const int8_t zp = get_zp(is_lower_half_byte_zp, QuantBZeroPointPtr); + load_and_mul_sum_s8_quads_with_zp_avx2( + av_0_epi8, reinterpret_cast(QuantBDataPtr), + low_mask, zero, + zp, scale0, acc0 + ); +} + +/** + * @brief Horizontally sum 4 vectors and store + * the results in the returned vector + */ +static MLAS_FORCEINLINE __m128 +FoldAccumulators(const __m256& acc0, const __m256& acc1, const __m256& acc2, const __m256& acc3) +{ + __m256 acc_lo01 = _mm256_unpacklo_ps(acc0, acc1); + __m256 acc_hi01 = _mm256_unpackhi_ps(acc0, acc1); + __m256 acc_lo23 = _mm256_unpacklo_ps(acc2, acc3); + __m256 acc_hi23 = _mm256_unpackhi_ps(acc2, acc3); + + __m256 acc_lo0123 = _mm256_castpd_ps( + _mm256_unpacklo_pd(_mm256_castps_pd(acc_lo01), _mm256_castps_pd(acc_lo23)) + ); + __m256 acc_hi0123 = _mm256_castpd_ps( + _mm256_unpackhi_pd(_mm256_castps_pd(acc_lo01), _mm256_castps_pd(acc_lo23)) + ); + acc_lo0123 = _mm256_add_ps(acc_lo0123, acc_hi0123); + acc_hi0123 = _mm256_castpd_ps( + _mm256_unpacklo_pd(_mm256_castps_pd(acc_hi01), _mm256_castps_pd(acc_hi23)) + ); + 
acc_lo0123 = _mm256_add_ps(acc_lo0123, acc_hi0123); + acc_hi0123 = _mm256_castpd_ps( + _mm256_unpackhi_pd(_mm256_castps_pd(acc_hi01), _mm256_castps_pd(acc_hi23)) + ); + acc_lo0123 = _mm256_add_ps(acc_lo0123, acc_hi0123); + + __m128 acc_y = + _mm_add_ps(_mm256_extractf128_ps(acc_lo0123, 0), _mm256_extractf128_ps(acc_lo0123, 1)); + return acc_y; +} + +static inline float +hsum_float_8(const __m256 x) +{ + __m128 res = _mm256_extractf128_ps(x, 1); + res = _mm_add_ps(res, _mm256_castps256_ps128(x)); + res = _mm_add_ps(res, _mm_movehl_ps(res, res)); + res = _mm_add_ss(res, _mm_movehdup_ps(res)); + return _mm_cvtss_f32(res); +} + +/** + * @brief Horizontally sum 4 vectors and store + * the results in the returned vector + */ +static MLAS_FORCEINLINE __m128 +FoldAccumulators(const __m512& acc0, const __m512& acc1, const __m512& acc2, const __m512& acc3) +{ + __m512 acc_lo01 = _mm512_unpacklo_ps(acc0, acc1); + __m512 acc_hi01 = _mm512_unpackhi_ps(acc0, acc1); + __m512 acc_lo23 = _mm512_unpacklo_ps(acc2, acc3); + __m512 acc_hi23 = _mm512_unpackhi_ps(acc2, acc3); + + __m512 acc_lo0123 = _mm512_castpd_ps( + _mm512_unpacklo_pd(_mm512_castps_pd(acc_lo01), _mm512_castps_pd(acc_lo23)) + ); + __m512 acc_hi0123 = _mm512_castpd_ps( + _mm512_unpackhi_pd(_mm512_castps_pd(acc_lo01), _mm512_castps_pd(acc_lo23)) + ); + acc_lo0123 = _mm512_add_ps(acc_lo0123, acc_hi0123); + acc_hi0123 = _mm512_castpd_ps( + _mm512_unpacklo_pd(_mm512_castps_pd(acc_hi01), _mm512_castps_pd(acc_hi23)) + ); + acc_lo0123 = _mm512_add_ps(acc_lo0123, acc_hi0123); + acc_hi0123 = _mm512_castpd_ps( + _mm512_unpackhi_pd(_mm512_castps_pd(acc_hi01), _mm512_castps_pd(acc_hi23)) + ); + acc_lo0123 = _mm512_add_ps(acc_lo0123, acc_hi0123); + + __m256 acc_y = + _mm256_add_ps(_mm512_extractf32x8_ps(acc_lo0123, 0), _mm512_extractf32x8_ps(acc_lo0123, 1)); + return _mm_add_ps(_mm256_extractf32x4_ps(acc_y, 0), _mm256_extractf32x4_ps(acc_y, 1)); +} +} // namespace diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_fp32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_fp32.h new file mode 100644 index 0000000000000..5cd380e591098 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_fp32.h @@ -0,0 +1,639 @@ +#pragma once +#include "sqnbitgemm.h" + +template +MLAS_FORCEINLINE + size_t + MlasQ4GemmKernelBlkLen16Avx512f( + const float* A, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t BlockCountK, + const float* Bias, + size_t lda, + size_t ldc + ) +{ + // We process 32 quantized values in a batch. 
+ // assert(BlkLen % 32 == 0) + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols = 4; + constexpr size_t BlkLen16 = 16; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + const __m128i lowMask = _mm_set1_epi8(0xF); + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + + for (size_t m = 0; m < CountM; m++) { + //*// + ////const float* BiasPtr = Bias; + + // for each row of A, reset B pointers + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + + ////float* SumPtr = CRowPtr; + //*// + + auto* sum_ptr = C; + const auto* bias_ptr = Bias; + + int64_t nblk = (int64_t)(CountN)-4; + while (nblk >= 0) { + __m512 acc_lo0 = _mm512_setzero_ps(); + __m512 acc_lo1 = _mm512_setzero_ps(); + __m512 acc_lo2 = _mm512_setzero_ps(); + __m512 acc_lo3 = _mm512_setzero_ps(); + + //*// + const std::byte* b_blk_data_ptr = QuantBDataColPtr; + const float* s = QuantBScaleColPtr; + //*// + + if constexpr (HasZeroPoint) { + QuantBZeroPointIdx = 0; + } + + for (size_t k = 0; k < CountK; k += BlkLen16) { + size_t kklen = std::min(CountK - k, BlkLen16); + + const float scale_v0 = *(s); + const float scale_v1 = *(s + StrideQuantBScale * 1); + const float scale_v2 = *(s + StrideQuantBScale * 2); + const float scale_v3 = *(s + StrideQuantBScale * 3); + + const __m128i* b0ptr = (const __m128i*)(b_blk_data_ptr); + const __m128i* b1ptr = (const __m128i*)(b_blk_data_ptr + StrideQuantBData * 1); + const __m128i* b2ptr = (const __m128i*)(b_blk_data_ptr + StrideQuantBData * 2); + const __m128i* b3ptr = (const __m128i*)(b_blk_data_ptr + StrideQuantBData * 3); + + // Load A row vector of 16 floats + uint32_t mask = 0xffff >> (BlkLen16 - kklen); + __m512 av_lo = _mm512_maskz_loadu_ps(__mmask16(mask), A + k); + + // Load B col vectors of 16 of 4b + // SubBlkLen = 16: | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF | + const __m128i bvi4_0 = _mm_loadl_epi64(b0ptr++); + const __m128i bvi4_1 = _mm_loadl_epi64(b1ptr++); + const __m128i bvi4_2 = _mm_loadl_epi64(b2ptr++); + const __m128i bvi4_3 = _mm_loadl_epi64(b3ptr++); + + // expand 4b into byte array + __m128i lower = _mm_and_si128(bvi4_0, lowMask); + __m128i upper = _mm_bslli_si128(_mm_and_si128(_mm_srli_epi16(bvi4_0, 4), lowMask), 8); + __m128i bytes0 = _mm_add_epi8(upper, lower); + + lower = _mm_and_si128(bvi4_1, lowMask); + upper = _mm_bslli_si128(_mm_and_si128(_mm_srli_epi16(bvi4_1, 4), lowMask), 8); + __m128i bytes1 = _mm_add_epi8(upper, lower); + + lower = _mm_and_si128(bvi4_2, lowMask); + upper = _mm_bslli_si128(_mm_and_si128(_mm_srli_epi16(bvi4_2, 4), lowMask), 8); + __m128i bytes2 = _mm_add_epi8(upper, lower); + + lower = _mm_and_si128(bvi4_3, lowMask); + upper = _mm_bslli_si128(_mm_and_si128(_mm_srli_epi16(bvi4_3, 4), lowMask), 8); + __m128i bytes3 = _mm_add_epi8(upper, lower); + + // Subtract zero-point from the integers + if constexpr (HasZeroPoint) { + // Subtract zero-point from the integers + bool is_lower = (QuantBZeroPointIdx & 1) == 0; + + // TODO: void condition on is_lower + std::byte zp_packed = QuantBZeroPointColPtr[0 * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; + uint8_t zp = std::to_integer(is_lower ? 
(zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); + + bytes0 = _mm_sub_epi8(bytes0, _mm_set1_epi8(zp)); + + zp_packed = QuantBZeroPointColPtr[1 * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; + zp = std::to_integer(is_lower ? (zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); + bytes1 = _mm_sub_epi8(bytes1, _mm_set1_epi8(zp)); + + zp_packed = QuantBZeroPointColPtr[2 * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; + zp = std::to_integer(is_lower ? (zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); + bytes2 = _mm_sub_epi8(bytes2, _mm_set1_epi8(zp)); + + zp_packed = QuantBZeroPointColPtr[3 * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; + zp = std::to_integer(is_lower ? (zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); + bytes3 = _mm_sub_epi8(bytes3, _mm_set1_epi8(zp)); + } else { + // Subtract 8 from the integers + const __m128i eight = _mm_set1_epi8(8); + bytes0 = _mm_sub_epi8(bytes0, eight); + bytes1 = _mm_sub_epi8(bytes1, eight); + bytes2 = _mm_sub_epi8(bytes2, eight); + bytes3 = _mm_sub_epi8(bytes3, eight); + } + + // Convert to 16-bit int + const __m256i vx16_0 = _mm256_cvtepi8_epi16(bytes0); + const __m256i vx16_1 = _mm256_cvtepi8_epi16(bytes1); + const __m256i vx16_2 = _mm256_cvtepi8_epi16(bytes2); + const __m256i vx16_3 = _mm256_cvtepi8_epi16(bytes3); + + // Convert to 32-bit int -> float 32 + __m512 bvf_0 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_0)); + __m512 bvf_1 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_1)); + __m512 bvf_2 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_2)); + __m512 bvf_3 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_3)); + + __m512 scale_ps = _mm512_set1_ps(scale_v0); + bvf_0 = _mm512_mul_ps(bvf_0, scale_ps); + scale_ps = _mm512_set1_ps(scale_v1); + bvf_1 = _mm512_mul_ps(bvf_1, scale_ps); + scale_ps = _mm512_set1_ps(scale_v2); + bvf_2 = _mm512_mul_ps(bvf_2, scale_ps); + scale_ps = _mm512_set1_ps(scale_v3); + bvf_3 = _mm512_mul_ps(bvf_3, scale_ps); + + acc_lo0 = _mm512_fmadd_ps(bvf_0, av_lo, acc_lo0); + acc_lo1 = _mm512_fmadd_ps(bvf_1, av_lo, acc_lo1); + acc_lo2 = _mm512_fmadd_ps(bvf_2, av_lo, acc_lo2); + acc_lo3 = _mm512_fmadd_ps(bvf_3, av_lo, acc_lo3); + + //*// + b_blk_data_ptr += MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + s++; + + if constexpr (HasZeroPoint) { + QuantBZeroPointIdx += 1; + } + //*// + + } // k + + __m128 acc_x = FoldAccumulators(acc_lo0, acc_lo1, acc_lo2, acc_lo3); + if (Bias != nullptr) { + acc_x = _mm_add_ps(acc_x, _mm_loadu_ps(bias_ptr)); + } + _mm_storeu_ps(sum_ptr, acc_x); + + // move to next 4 columns + sum_ptr += 4; + bias_ptr += 4; + nblk -= 4; + + //*// + QuantBDataColPtr += NCols * StrideQuantBData; + QuantBScaleColPtr += NCols * StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += NCols * StrideQuantBZeroPoint; + } + + ////BiasPtr += BiasPtr != nullptr ? NCols : 0; + ////SumPtr += NCols; + + ////nblk -= NCols; + //*// + } + + // left over columns less than 4 ? 
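+        // nblk went negative in the 4-column loop above; adding NCols back yields the number of
+        // leftover columns (0-3), which are handled one at a time below.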
+ nblk += 4; + if (nblk > 0) { + __m512 acc_lo[4]{}; + + //*// + const std::byte* b_blk_data_ptr = QuantBDataColPtr; + const float* s = QuantBScaleColPtr; + //*// + + if constexpr (HasZeroPoint) { + QuantBZeroPointIdx = 0; + } + + for (size_t k = 0; k < CountK; k += BlkLen16) { + size_t klen = std::min(CountK - k, BlkLen16); + + float scale_v[4]; + const __m128i* b_ptr[4]; + for (int64_t nn = 0; nn < nblk; nn++) { + //*// + scale_v[nn] = *(s + StrideQuantBScale * nn); + b_ptr[nn] = (const __m128i*)(b_blk_data_ptr + StrideQuantBData * nn); + //*// + } + + uint32_t mask = 0xffff >> (BlkLen16 - klen); + __m512 av_lo = _mm512_maskz_loadu_ps(__mmask16(mask), A + k); + + for (int64_t nn = 0; nn < nblk; nn++) { + // Load B col vectors of 16 of 4b + // SubBlkLen = 16: | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF | + const __m128i bvi4_0 = _mm_loadl_epi64(b_ptr[nn]++); + + // expand 4b into byte array + __m128i lower = _mm_and_si128(bvi4_0, lowMask); + __m128i upper = _mm_bslli_si128(_mm_and_si128(_mm_srli_epi16(bvi4_0, 4), lowMask), 8); + __m128i bytes = _mm_add_epi8(upper, lower); + + if constexpr (HasZeroPoint) { + // Subtract zero-point from the integers + bool is_lower = (QuantBZeroPointIdx & 1) == 0; + + // TODO: void condition on is_lower + std::byte zp_packed = QuantBZeroPointColPtr[nn * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; + uint8_t zp = std::to_integer(is_lower ? (zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); + bytes = _mm_sub_epi8(bytes, _mm_set1_epi8(zp)); + } else { + // Subtract 8 from the integers + const __m128i eight = _mm_set1_epi8(8); + bytes = _mm_sub_epi8(bytes, eight); + } + + // Convert to 16-bit int + const __m256i vx16 = _mm256_cvtepi8_epi16(bytes); + + // Convert to 32-bit int -> float 32 + __m512 bvf = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16)); + __m512 scale_16_ps = _mm512_set1_ps(scale_v[nn]); + bvf = _mm512_mul_ps(bvf, scale_16_ps); + + acc_lo[nn] = _mm512_fmadd_ps(bvf, av_lo, acc_lo[nn]); + } + + //*// + b_blk_data_ptr += MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + s++; + + if constexpr (HasZeroPoint) { + QuantBZeroPointIdx += 1; + } + //*// + } // k + + for (int64_t nn = 0; nn < nblk; nn++) { + sum_ptr[nn] = _mm512_reduce_add_ps(acc_lo[nn]); + sum_ptr[nn] += Bias == nullptr ? 0.0f : bias_ptr[nn]; + } + } + + // Prepare pointers for the next row + C += ldc; + A += lda; + } + return CountM; +} + +template +MLAS_FORCEINLINE + size_t + MlasQ4GemmKernelBlkLen32PlusAvx512f( + size_t BlkLen, + const float* A, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t BlockCountK, + const float* Bias, + size_t lda, + size_t ldc + ) +{ + // We process 32 quantized values in a batch. 
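+    // Note (inferred from the packing in sqnbitgemm_kernel_avx_common.h): when IsBlkLen64Layout is
+    // true, the quantized B blocks use the interleaved | v0 v32 | v1 v33 | ... | v31 v63 | layout
+    // produced for BlkLen >= 64; otherwise the | v0 v16 | ... | v15 v31 | layout is expected.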
+ // assert(BlkLen % 32 == 0) + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols = 4; + constexpr size_t MLAS_QUANT4_BLK_UNIT32 = 32; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + const __m256i lowMask = _mm256_set1_epi8(0xF); + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + + for (size_t m = 0; m < CountM; m++) { + //*// + ////const float* BiasPtr = Bias; + + // for each row of A, reset B pointers + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + + ////float* SumPtr = CRowPtr; + //*// + + auto* sum_ptr = C; + const auto* bias_ptr = Bias; + + int64_t nblk = (int64_t)(CountN)-4; + while (nblk >= 0) { + __m512 acc_lo0 = _mm512_setzero_ps(); + __m512 acc_lo1 = _mm512_setzero_ps(); + __m512 acc_lo2 = _mm512_setzero_ps(); + __m512 acc_lo3 = _mm512_setzero_ps(); + + //*// + const std::byte* b_blk_data_ptr = QuantBDataColPtr; + const float* s = QuantBScaleColPtr; + //*// + + if constexpr (HasZeroPoint) { + QuantBZeroPointIdx = 0; + } + + for (size_t k = 0; k < CountK; k += BlkLen) { + size_t ck = std::min(CountK - k, BlkLen); + + const float scale_v0 = *(s); + const float scale_v1 = *(s + StrideQuantBScale * 1); + const float scale_v2 = *(s + StrideQuantBScale * 2); + const float scale_v3 = *(s + StrideQuantBScale * 3); + + const __m128i* b0ptr = (const __m128i*)(b_blk_data_ptr); + const __m128i* b1ptr = (const __m128i*)(b_blk_data_ptr + StrideQuantBData * 1); + const __m128i* b2ptr = (const __m128i*)(b_blk_data_ptr + StrideQuantBData * 2); + const __m128i* b3ptr = (const __m128i*)(b_blk_data_ptr + StrideQuantBData * 3); + + for (size_t kk = 0; kk < ck; kk += MLAS_QUANT4_BLK_UNIT32) { + size_t kklen = std::min((size_t)MLAS_QUANT4_BLK_UNIT32, ck - kk); + + // Load A row vectors + uint32_t mask = 0xffffffff >> (MLAS_QUANT4_BLK_UNIT32 - kklen); + __m512 av_lo = _mm512_maskz_loadu_ps(__mmask16(mask), A + k + kk); + + mask = mask >> 16; + __m512 av_hi = mask == 0 ? _mm512_setzero_ps() + : _mm512_maskz_loadu_ps(__mmask16(mask), A + k + kk + 16); + + // Load B col vectors + __m256i bytes0, bytes1, bytes2, bytes3; + if constexpr (IsBlkLen64Layout) { + // dst: | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 | + // load 64 weights at once, parse to get v0 - v31 if subblk is even, otherwise get v32 - v63 + // increment b_data_ptr by 2 * MLAS_QUANT4_BLK_UNIT32 if subblk is odd + // so that all v0-63 of the pack are processed. 
+ const __m256i bvi4_0 = _mm256_loadu_si256((__m256i const*)(b0ptr)); + const __m256i bvi4_1 = _mm256_loadu_si256((__m256i const*)(b1ptr)); + const __m256i bvi4_2 = _mm256_loadu_si256((__m256i const*)(b2ptr)); + const __m256i bvi4_3 = _mm256_loadu_si256((__m256i const*)(b3ptr)); + const int count_half_4 = + 4 * ((kk % (2 * MLAS_QUANT4_BLK_UNIT32)) / MLAS_QUANT4_BLK_UNIT32); + bytes0 = _mm256_and_si256(_mm256_srli_epi16(bvi4_0, count_half_4), lowMask); + bytes1 = _mm256_and_si256(_mm256_srli_epi16(bvi4_1, count_half_4), lowMask); + bytes2 = _mm256_and_si256(_mm256_srli_epi16(bvi4_2, count_half_4), lowMask); + bytes3 = _mm256_and_si256(_mm256_srli_epi16(bvi4_3, count_half_4), lowMask); + b0ptr += count_half_4 / 2; + b1ptr += count_half_4 / 2; + b2ptr += count_half_4 / 2; + b3ptr += count_half_4 / 2; + } else { + const __m128i bvi4_0 = _mm_loadu_si128(b0ptr++); + const __m128i bvi4_1 = _mm_loadu_si128(b1ptr++); + const __m128i bvi4_2 = _mm_loadu_si128(b2ptr++); + const __m128i bvi4_3 = _mm_loadu_si128(b3ptr++); + + // expand 4b into byte array + bytes0 = _mm256_set_m128i(_mm_srli_epi16(bvi4_0, 4), bvi4_0); + bytes1 = _mm256_set_m128i(_mm_srli_epi16(bvi4_1, 4), bvi4_1); + bytes2 = _mm256_set_m128i(_mm_srli_epi16(bvi4_2, 4), bvi4_2); + bytes3 = _mm256_set_m128i(_mm_srli_epi16(bvi4_3, 4), bvi4_3); + bytes0 = _mm256_and_si256(lowMask, bytes0); + bytes1 = _mm256_and_si256(lowMask, bytes1); + bytes2 = _mm256_and_si256(lowMask, bytes2); + bytes3 = _mm256_and_si256(lowMask, bytes3); + } + + // Subtract zero-point from the integers + if constexpr (HasZeroPoint) { + // Subtract zero-point from the integers + bool is_lower = (QuantBZeroPointIdx & 1) == 0; + + // TODO: void condition on is_lower + std::byte zp_packed = QuantBZeroPointColPtr[0 * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; + uint8_t zp = std::to_integer(is_lower ? (zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); + + bytes0 = _mm256_sub_epi8(bytes0, _mm256_set1_epi8(zp)); + + zp_packed = QuantBZeroPointColPtr[1 * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; + zp = std::to_integer(is_lower ? (zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); + bytes1 = _mm256_sub_epi8(bytes1, _mm256_set1_epi8(zp)); + + zp_packed = QuantBZeroPointColPtr[2 * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; + zp = std::to_integer(is_lower ? (zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); + bytes2 = _mm256_sub_epi8(bytes2, _mm256_set1_epi8(zp)); + + zp_packed = QuantBZeroPointColPtr[3 * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; + zp = std::to_integer(is_lower ? 
(zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); + bytes3 = _mm256_sub_epi8(bytes3, _mm256_set1_epi8(zp)); + } else { + // Subtract 8 from the integers + const __m256i eight = _mm256_set1_epi8(8); + bytes0 = _mm256_sub_epi8(bytes0, eight); + bytes1 = _mm256_sub_epi8(bytes1, eight); + bytes2 = _mm256_sub_epi8(bytes2, eight); + bytes3 = _mm256_sub_epi8(bytes3, eight); + } + + // Convert to 16-bit int + const __m256i vx16_lo0 = + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes0, 0)); + const __m256i vx16_hi0 = + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes0, 1)); + const __m256i vx16_lo1 = + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes1, 0)); + const __m256i vx16_hi1 = + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes1, 1)); + const __m256i vx16_lo2 = + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes2, 0)); + const __m256i vx16_hi2 = + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes2, 1)); + const __m256i vx16_lo3 = + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes3, 0)); + const __m256i vx16_hi3 = + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes3, 1)); + + // Convert to 32-bit int -> float 32 + __m512 bvf_lo0 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo0)); + __m512 bvf_hi0 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi0)); + __m512 bvf_lo1 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo1)); + __m512 bvf_hi1 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi1)); + __m512 bvf_lo2 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo2)); + __m512 bvf_hi2 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi2)); + __m512 bvf_lo3 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo3)); + __m512 bvf_hi3 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi3)); + + __m512 scale_ps = _mm512_set1_ps(scale_v0); + bvf_lo0 = _mm512_mul_ps(bvf_lo0, scale_ps); + bvf_hi0 = _mm512_mul_ps(bvf_hi0, scale_ps); + scale_ps = _mm512_set1_ps(scale_v1); + bvf_lo1 = _mm512_mul_ps(bvf_lo1, scale_ps); + bvf_hi1 = _mm512_mul_ps(bvf_hi1, scale_ps); + scale_ps = _mm512_set1_ps(scale_v2); + bvf_lo2 = _mm512_mul_ps(bvf_lo2, scale_ps); + bvf_hi2 = _mm512_mul_ps(bvf_hi2, scale_ps); + scale_ps = _mm512_set1_ps(scale_v3); + bvf_lo3 = _mm512_mul_ps(bvf_lo3, scale_ps); + bvf_hi3 = _mm512_mul_ps(bvf_hi3, scale_ps); + + acc_lo0 = _mm512_fmadd_ps(bvf_lo0, av_lo, acc_lo0); + acc_lo0 = _mm512_fmadd_ps(bvf_hi0, av_hi, acc_lo0); + acc_lo1 = _mm512_fmadd_ps(bvf_lo1, av_lo, acc_lo1); + acc_lo1 = _mm512_fmadd_ps(bvf_hi1, av_hi, acc_lo1); + acc_lo2 = _mm512_fmadd_ps(bvf_lo2, av_lo, acc_lo2); + acc_lo2 = _mm512_fmadd_ps(bvf_hi2, av_hi, acc_lo2); + acc_lo3 = _mm512_fmadd_ps(bvf_lo3, av_lo, acc_lo3); + acc_lo3 = _mm512_fmadd_ps(bvf_hi3, av_hi, acc_lo3); + } // kk + + //*// + b_blk_data_ptr += MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + s++; + + if constexpr (HasZeroPoint) { + QuantBZeroPointIdx += 1; + } + //*// + + } // k + + __m128 acc_x = FoldAccumulators(acc_lo0, acc_lo1, acc_lo2, acc_lo3); + if (Bias != nullptr) { + acc_x = _mm_add_ps(acc_x, _mm_loadu_ps(bias_ptr)); + } + _mm_storeu_ps(sum_ptr, acc_x); + + // move to next 4 columns + sum_ptr += 4; + bias_ptr += 4; + nblk -= 4; + + //*// + QuantBDataColPtr += NCols * StrideQuantBData; + QuantBScaleColPtr += NCols * StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += NCols * StrideQuantBZeroPoint; + } + + ////BiasPtr += BiasPtr != nullptr ? NCols : 0; + ////SumPtr += NCols; + + ////nblk -= NCols; + //*// + } + + // left over columns less than 4 ? 
+ nblk += 4; + if (nblk > 0) { + __m512 acc_lo[4]{}; + + //*// + const std::byte* b_blk_data_ptr = QuantBDataColPtr; + const float* s = QuantBScaleColPtr; + //*// + + if constexpr (HasZeroPoint) { + QuantBZeroPointIdx = 0; + } + + for (size_t k = 0; k < CountK; k += BlkLen) { + size_t ck = std::min(CountK - k, BlkLen); + + float scale_v[4]; + const __m128i* b_ptr[4]; + for (int64_t nn = 0; nn < nblk; nn++) { + //*// + scale_v[nn] = *(s + StrideQuantBScale * nn); + b_ptr[nn] = (const __m128i*)(b_blk_data_ptr + StrideQuantBData * nn); + //*// + } + + for (size_t kk = 0; kk < ck; kk += MLAS_QUANT4_BLK_UNIT32) { + size_t kklen = std::min((size_t)MLAS_QUANT4_BLK_UNIT32, ck - kk); + + uint32_t mask = 0xffffffff >> (MLAS_QUANT4_BLK_UNIT32 - kklen); + __m512 av_lo = _mm512_maskz_loadu_ps(__mmask16(mask), A + k + kk); + + mask = mask >> 16; + __m512 av_hi = mask == 0 + ? _mm512_setzero_ps() + : _mm512_maskz_loadu_ps(__mmask16(mask), A + k + kk + 16); + + for (int64_t nn = 0; nn < nblk; nn++) { + __m256i bytes; + if constexpr (IsBlkLen64Layout) { + // dst: | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 | + // load 64 weights at once, parse to get v0 - v31 if subblk is even, otherwise get v32 - v63 + // increment b_data_ptr by 2 * MLAS_QUANT4_BLK_UNIT32 if subblk is odd + // so that all v0-63 of the pack are processed. + const __m256i bvi4 = _mm256_loadu_si256((__m256i const*)(b_ptr[nn])); + const int count_half_4 = + 4 * ((kk % (2 * MLAS_QUANT4_BLK_UNIT32)) / MLAS_QUANT4_BLK_UNIT32); + bytes = _mm256_and_si256(_mm256_srli_epi16(bvi4, count_half_4), lowMask); + b_ptr[nn] += count_half_4 / 2; + } else { + const __m128i bvi4 = _mm_loadu_si128(b_ptr[nn]++); + bytes = _mm256_set_m128i(_mm_srli_epi16(bvi4, 4), bvi4); + bytes = _mm256_and_si256(lowMask, bytes); + } + if constexpr (HasZeroPoint) { + // Subtract zero-point from the integers + bool is_lower = (QuantBZeroPointIdx & 1) == 0; + + // TODO: void condition on is_lower + std::byte zp_packed = QuantBZeroPointColPtr[nn * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; + uint8_t zp = std::to_integer(is_lower ? (zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); + bytes = _mm256_sub_epi8(bytes, _mm256_set1_epi8(zp)); + } else { + // Subtract 8 from the integers + const __m256i eight = _mm256_set1_epi8(8); + bytes = _mm256_sub_epi8(bytes, eight); + } + + // Convert to 16-bit int + const __m256i vx16_lo = + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes, 0)); + const __m256i vx16_hi = + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes, 1)); + + // Convert to 32-bit int -> float 32 + __m512 bvf_lo = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo)); + __m512 bvf_hi = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi)); + __m512 scale_16_ps = _mm512_set1_ps(scale_v[nn]); + bvf_lo = _mm512_mul_ps(bvf_lo, scale_16_ps); + bvf_hi = _mm512_mul_ps(bvf_hi, scale_16_ps); + + acc_lo[nn] = _mm512_fmadd_ps(bvf_lo, av_lo, acc_lo[nn]); + acc_lo[nn] = _mm512_fmadd_ps(bvf_hi, av_hi, acc_lo[nn]); + } + } // kk + + //*// + b_blk_data_ptr += MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + s++; + + if constexpr (HasZeroPoint) { + QuantBZeroPointIdx += 1; + } + //*// + } // k + + for (int64_t nn = 0; nn < nblk; nn++) { + sum_ptr[nn] = _mm512_reduce_add_ps(acc_lo[nn]); + sum_ptr[nn] += Bias == nullptr ? 
0.0f : bias_ptr[nn]; + } + } + + // Prepare pointers for the next row + C += ldc; + A += lda; + } + return CountM; +} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_int8.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_int8.h new file mode 100644 index 0000000000000..8f8506cb3adaa --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_int8.h @@ -0,0 +1,744 @@ +#pragma once +#include +#include +#include + +#include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_avx_common.h" + +void +SQ4BitGemmM1Kernel_CompInt8_avx2( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias +); + +template +MLAS_FORCEINLINE void +ComputeDotProducts_BlkBitWidth4_CompInt8_SubBlkLen16( + size_t BlkLen, + const std::byte* QuantARowPtr, + const std::byte* QuantBDataColPtr, + const float* QuantBScaleColPtr, + const std::byte* QuantBZeroPointColPtr, + float* SumPtr, + size_t CountK, + size_t StrideQuantBData, + size_t StrideQuantBScale, + size_t StrideQuantBZeroPoint, + const float* BiasPtr +) +{ + if constexpr (!HasZeroPoint) { + // Suppress unused variable warnings + (void)QuantBZeroPointColPtr; + (void)StrideQuantBZeroPoint; + } + + assert(BlkLen == 16); + constexpr size_t SubBlkLen = 16; + const __m128i low_mask = _mm_set1_epi8(0xF); + + constexpr size_t BlkBitWidth = 4; + constexpr size_t SubBlkStep = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, SubBlkLen); + + __m256 acc[NCols]; + UnrolledLoop([&](size_t i) { + acc[i] = _mm256_setzero_ps(); + }); + + const std::byte* ablob = QuantARowPtr; + const auto* b = QuantBDataColPtr; + const float* s = QuantBScaleColPtr; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + // only used if HasZeroPoint == true + + for (size_t k = 0; k < CountK; k += BlkLen) { + const float a_scale = Q8BlkScale(ablob); + ablob += sizeof(float); + + float scale_v[NCols]; + UnrolledLoop([&](size_t i) { + scale_v[i] = (*(s + StrideQuantBScale * i)) * a_scale; + }); + + std::byte* bptr[NCols]; + UnrolledLoop([&](size_t i) { + bptr[i] = (std::byte*)(b + StrideQuantBData * i); + }); + + [[maybe_unused]] uint8_t offset[NCols]; + // not ready for "Manual conversion to float" in neon yet. following neon to unpack to uint8_t. + if constexpr (HasZeroPoint) { + UnrolledLoop([&](size_t i) { + const std::byte zp_packed = + QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; + const std::byte zp = ((QuantBZeroPointIdx & 1) == 1) + ? 
(zp_packed >> 4) + : (zp_packed & std::byte{0x0F}); + offset[i] = std::to_integer(zp); + }); + } + + // Load A row vector + const __m128i av_epi8 = _mm_lddqu_si128((const __m128i*)ablob); + __m256i av_epi16 = _mm256_cvtepi8_epi16(av_epi8); + ablob += BlkLen; + + // Load 4 B column vectors (quantized to int4 blobs) + __m128i bvi[NCols]; + UnrolledLoop([&](size_t i) { + bvi[i] = _mm_loadl_epi64((__m128i const*)bptr[i]); + bptr[i] += SubBlkStep; + }); + + // expand 4b into byte array + __m256i bv_epi16[NCols]; + UnrolledLoop([&](size_t i) { + const __m128i lower = _mm_and_si128(bvi[i], low_mask); + const __m128i upper = _mm_bslli_si128(_mm_and_si128(_mm_srli_epi16(bvi[i], 4), low_mask), 8); + bv_epi16[i] = _mm256_cvtepi8_epi16(_mm_add_epi8(upper, lower)); + }); + + // Subtract zero-point from the integers + if constexpr (HasZeroPoint) { + UnrolledLoop([&](size_t i) { + bv_epi16[i] = _mm256_sub_epi16(bv_epi16[i], _mm256_set1_epi16(offset[i])); + }); + } else { + const __m256i eight = _mm256_set1_epi16(8); + UnrolledLoop([&](size_t i) { + bv_epi16[i] = _mm256_sub_epi16(bv_epi16[i], eight); + }); + } + + UnrolledLoop([&](size_t i) { + __m256i prod_8_epi32 = _mm256_madd_epi16(bv_epi16[i], av_epi16); + + const __m256 prod_8_ps = _mm256_cvtepi32_ps(prod_8_epi32); + acc[i] = _mm256_fmadd_ps(_mm256_set1_ps(scale_v[i]), prod_8_ps, acc[i]); + }); + + b += MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + s++; + if constexpr (HasZeroPoint) { + QuantBZeroPointIdx += 1; + } + } + + if constexpr (NCols == 4) { + __m128 acc_x = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_x = _mm_add_ps(acc_x, _mm_loadu_ps(BiasPtr)); + } + _mm_storeu_ps(SumPtr, acc_x); + } else { + UnrolledLoop([&](size_t i) { + __m128 vlow = _mm256_castps256_ps128(acc[i]); + __m128 vhigh = _mm256_extractf128_ps(acc[i], 1); // Extract high 128 bit + + // Add the two 128-bit vectors together + __m128 vsum = _mm_add_ps(vlow, vhigh); + // Horizontally add the elements of the resulting 128-bit vector + vsum = _mm_hadd_ps(vsum, vsum); + vsum = _mm_hadd_ps(vsum, vsum); + + _mm_store_ss(&SumPtr[i], vsum); + SumPtr[i] += BiasPtr == nullptr ? 
0.0f : BiasPtr[i]; + }); + } +} + +template +void +SQ4BitGemmM1Kernel_BlkLen16_CompInt8_Impl( + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias +) +{ + constexpr size_t NCols4 = 4; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t BlkLen16 = 16; + + const std::byte* QuantARowPtr = QuantA; + float* CRowPtr = C; + + const size_t BlockCountK = BlockStrideQuantB; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + const float* BiasPtr = Bias; + + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + + float* SumPtr = CRowPtr; + + int64_t nblk = static_cast(CountN) - NCols4; + + while (nblk >= 0) { + ComputeDotProducts_BlkBitWidth4_CompInt8_SubBlkLen16( + BlkLen16, QuantARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, + SumPtr, CountK, StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, BiasPtr + ); + + // move to next `NCols` columns + + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += NCols4 * StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + + nblk -= NCols4; + } + + // left over columns less than `NCols`? + nblk += NCols4; + for (int64_t n = 0; n < nblk; ++n) { + ComputeDotProducts_BlkBitWidth4_CompInt8_SubBlkLen16<1, HasZeroPoint>( + BlkLen16, QuantARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, + SumPtr, CountK, StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, BiasPtr + ); + + // move to next column + + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } +} + +template accumulator> +void +SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl( + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t BlockCountK, + const float* Bias +) +{ + // port from neon implementation + constexpr size_t BlkBitWidth = 4; + constexpr size_t BlkLen = 32; + + float* CRowPtr = C; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + const float* BiasPtr = Bias; + + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + + float* SumPtr = CRowPtr; + + const __m256i zero = _mm256_setzero_si256(); + const __m128i low_mask = _mm_set1_epi8(0xF); + const size_t NCols = 4; + int64_t nblk = (int64_t)(CountN)-4; + while (nblk >= 0) { + const std::byte* QuantAPtr = QuantA; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 + acc0 = _mm256_setzero_ps(), + acc1 = _mm256_setzero_ps(), + acc2 = _mm256_setzero_ps(), + acc3 = _mm256_setzero_ps(); + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= 2) { + const std::byte* QuantABlk0 = QuantAPtr; + const std::byte* QuantABlk1 = QuantABlk0 + Q8BlkSize(BlkLen); + + // load A: + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + const __m256i av_1_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk1)); + + const float& scale_a0 = Q8BlkScale(QuantABlk0); + const float& scale_a1 = Q8BlkScale(QuantABlk1); + + // Col0 + const float& scale_00 = scale_a0 * QuantBScalePtr[0]; + const float& scale_01 = scale_a1 * QuantBScalePtr[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc0); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 16), low_mask, zero, QuantBZeroPointPtr, false, scale_01, acc0); + + // Col1 + const float& scale_10 = scale_a0 * (QuantBScalePtr + StrideQuantBScale)[0]; + const float& scale_11 = scale_a1 * (QuantBScalePtr + StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, true, scale_10, acc1); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, false, scale_11, acc1); + + // Col2 + const float& scale_20 = scale_a0 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + const float& scale_21 = scale_a1 * (QuantBScalePtr + 2 * StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 2 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, true, scale_20, acc2); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 2 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, false, scale_21, acc2); + + // Col3 + const float& scale_30 = scale_a0 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + const float& scale_31 = scale_a1 * (QuantBScalePtr + 3 * StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData), low_mask, 
zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, true, scale_30, acc3); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, false, scale_31, acc3); + + // increment block pointers + QuantAPtr += Q8BlkSize(BlkLen) * 2; + QuantBDataPtr += 16 * 2; + QuantBScalePtr += 2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } + + if (k_blks_remaining > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + + const float& scale_a0 = Q8BlkScale(QuantABlk0); + + // Col0 + const float& scale_00 = scale_a0 * QuantBScalePtr[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc0); + + // Col1 + const float& scale_10 = scale_a0 * (QuantBScalePtr + StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, true, scale_10, acc1); + + // Col2 + const float& scale_20 = scale_a0 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 2 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, true, scale_20, acc2); + + // Col3 + const float& scale_30 = scale_a0 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, true, scale_30, acc3); + } + + __m128 acc_x = FoldAccumulators(acc0, acc1, acc2, acc3); + if (BiasPtr != nullptr) { + acc_x = _mm_add_ps(acc_x, _mm_loadu_ps(BiasPtr)); + } + _mm_storeu_ps(SumPtr, acc_x); + + // move to next NCols columns + + QuantBDataColPtr += NCols * StrideQuantBData; + QuantBScaleColPtr += NCols * StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += NCols * StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 
NCols : 0; + SumPtr += NCols; + nblk -= NCols; + } + + nblk += NCols; + for (int64_t n = 0; n < nblk; n++) { + const std::byte* QuantAPtr = QuantA; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc0 = _mm256_setzero_ps(); + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= 2) { + const std::byte* QuantABlk0 = QuantAPtr; + const std::byte* QuantABlk1 = QuantABlk0 + Q8BlkSize(BlkLen); + + // load A: + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + const __m256i av_1_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk1)); + + const float& scale_a0 = Q8BlkScale(QuantABlk0); + const float& scale_a1 = Q8BlkScale(QuantABlk1); + + // Col0 + const float& scale_00 = scale_a0 * QuantBScalePtr[0]; + const float& scale_01 = scale_a1 * QuantBScalePtr[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc0); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 16), low_mask, zero, QuantBZeroPointPtr, false, scale_01, acc0); + + // increment block pointers + QuantAPtr += Q8BlkSize(BlkLen) * 2; + QuantBDataPtr += 16 * 2; + QuantBScalePtr += 2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } + + if (k_blks_remaining > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + + const float& scale_a0 = Q8BlkScale(QuantABlk0); + + // Col0 + const float& scale_00 = scale_a0 * QuantBScalePtr[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc0); + } + + *SumPtr = hsum_float_8(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } +} + +using DotQuadFunctionType = __m256 (*)( + const __m256i, const __m256i, const __m256i, const __m256i +); + +template +MLAS_FORCEINLINE void +ComputeDotProducts_BlkBitWidth4_CompInt8_SubBlkLen64_NCols4( + size_t BlkLen, + const std::byte* QuantARowPtr, + const std::byte* QuantBDataColPtr, + const float* QuantBScaleColPtr, + const std::byte* QuantBZeroPointColPtr, + float* SumPtr, + size_t CountK, + size_t StrideQuantBData, + size_t StrideQuantBScale, + size_t StrideQuantBZeroPoint, + const float* BiasPtr +) +{ + // TODO: make it work with all BlkLens + assert(BlkLen >= 64); + constexpr size_t SubBlkLen64 = 64; + // const __m256i zero = _mm256_setzero_si256(); + const __m256i low_mask = _mm256_set1_epi8(0xF); + + constexpr size_t BlkBitWidth = 4; + constexpr size_t SubBlkStep = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, SubBlkLen64); + + __m256 acc0 = _mm256_setzero_ps(), acc1 = _mm256_setzero_ps(), acc2 = _mm256_setzero_ps(), acc3 = _mm256_setzero_ps(); + + const std::byte* ablob = QuantARowPtr; + const std::byte* b_blk_data_ptr = QuantBDataColPtr; + const float* blk_scale_ptr = QuantBScaleColPtr; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + // only used if HasZeroPoint == true + + for (size_t k = 0; k < CountK; k += BlkLen) { + size_t ck = std::min(CountK - k, BlkLen); + + const float a_scale = Q8BlkScale(ablob); + ablob += sizeof(float); + + float + scale_v0 = (*(blk_scale_ptr + StrideQuantBScale * 0)) * a_scale, + scale_v1 = (*(blk_scale_ptr + StrideQuantBScale * 1)) * a_scale, + scale_v2 = (*(blk_scale_ptr + StrideQuantBScale * 2)) * a_scale, + scale_v3 = (*(blk_scale_ptr + StrideQuantBScale * 3)) * a_scale; + + const std::byte* bptr0 = (b_blk_data_ptr + StrideQuantBData * 0); + const std::byte* bptr1 = (b_blk_data_ptr + StrideQuantBData * 1); + const std::byte* bptr2 = (b_blk_data_ptr + StrideQuantBData * 2); + const std::byte* bptr3 = (b_blk_data_ptr + StrideQuantBData * 3); + + uint8_t zp0, zp1, zp2, zp3; + if constexpr (HasZeroPoint) { + // TODO: this block causes near 30% of the computation. + bool is_lower = (QuantBZeroPointIdx & 1) == 0; + std::byte zp_packed = QuantBZeroPointColPtr[0 * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; + zp0 = std::to_integer(is_lower ? (zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); + zp_packed = QuantBZeroPointColPtr[1 * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; + zp1 = std::to_integer(is_lower ? (zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); + zp_packed = QuantBZeroPointColPtr[2 * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; + zp2 = std::to_integer(is_lower ? (zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); + zp_packed = QuantBZeroPointColPtr[3 * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; + zp3 = std::to_integer(is_lower ? (zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); + } else { + zp0 = 8; + zp1 = 8; + zp2 = 8; + zp3 = 8; + } + + for (size_t kk = 0; kk < ck; kk += SubBlkLen64) { + // Load A row vector + const __m256i av0_32_epi8 = _mm256_loadu_si256((const __m256i*)ablob); + ablob += 32; + const __m256i av1_32_epi8 = _mm256_loadu_si256((const __m256i*)ablob); + ablob += 32; + + // Load B column vectors (quantized to int4 blobs) + // dst: | v0 v32 | v1 v33 | ... 
| v30 v62 | v31 v63 | + __m256i bv = _mm256_loadu_si256((__m256i const*)bptr0); + bptr0 += SubBlkStep; + __m256i bv0_32_epi8 = _mm256_and_si256(bv, low_mask); + __m256i bv1_32_epi8 = _mm256_and_si256(_mm256_srli_epi16(bv, 4), low_mask); + __m256i zp_epi8 = _mm256_set1_epi8(zp0); + bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, zp_epi8); + bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, zp_epi8); + __m256 sum_ps = dot_quad(bv0_32_epi8, bv1_32_epi8, av0_32_epi8, av1_32_epi8); + acc0 = _mm256_fmadd_ps(_mm256_set1_ps(scale_v0), sum_ps, acc0); + + bv = _mm256_loadu_si256((__m256i const*)bptr1); + bptr1 += SubBlkStep; + bv0_32_epi8 = _mm256_and_si256(bv, low_mask); + bv1_32_epi8 = _mm256_and_si256(_mm256_srli_epi16(bv, 4), low_mask); + zp_epi8 = _mm256_set1_epi8(zp1); + bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, zp_epi8); + bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, zp_epi8); + sum_ps = dot_quad(bv0_32_epi8, bv1_32_epi8, av0_32_epi8, av1_32_epi8); + acc1 = _mm256_fmadd_ps(_mm256_set1_ps(scale_v1), sum_ps, acc1); + + bv = _mm256_loadu_si256((__m256i const*)bptr2); + bptr2 += SubBlkStep; + bv0_32_epi8 = _mm256_and_si256(bv, low_mask); + bv1_32_epi8 = _mm256_and_si256(_mm256_srli_epi16(bv, 4), low_mask); + zp_epi8 = _mm256_set1_epi8(zp2); + bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, zp_epi8); + bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, zp_epi8); + sum_ps = dot_quad(bv0_32_epi8, bv1_32_epi8, av0_32_epi8, av1_32_epi8); + acc2 = _mm256_fmadd_ps(_mm256_set1_ps(scale_v2), sum_ps, acc2); + + bv = _mm256_loadu_si256((__m256i const*)bptr3); + bptr3 += SubBlkStep; + bv0_32_epi8 = _mm256_and_si256(bv, low_mask); + bv1_32_epi8 = _mm256_and_si256(_mm256_srli_epi16(bv, 4), low_mask); + zp_epi8 = _mm256_set1_epi8(zp3); + bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, zp_epi8); + bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, zp_epi8); + sum_ps = dot_quad(bv0_32_epi8, bv1_32_epi8, av0_32_epi8, av1_32_epi8); + acc3 = _mm256_fmadd_ps(_mm256_set1_ps(scale_v3), sum_ps, acc3); + } // kk + + b_blk_data_ptr += MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + blk_scale_ptr++; + if constexpr (HasZeroPoint) { + QuantBZeroPointIdx += 1; + } + } // k + + __m128 acc_x = FoldAccumulators(acc0, acc1, acc2, acc3); + if (BiasPtr != nullptr) { + acc_x = _mm_add_ps(acc_x, _mm_loadu_ps(BiasPtr)); + } + _mm_storeu_ps(SumPtr, acc_x); +} + +// TODO: is this able to be inlined if DotQuadFunctionType is a function pointer? 
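+// Sketch of why inlining should still be possible (an assumption, not verified here): dot_quad is
+// supplied as a non-type template parameter of type DotQuadFunctionType, so each instantiation names
+// the exact callee at compile time and the optimizer may inline it, unlike a pointer passed at runtime.
+// Hypothetical illustration (KernelSketch is not a real function in this file):
+//
+//   template <DotQuadFunctionType dot_quad>
+//   void KernelSketch(/* ... */) { /* ... dot_quad(bv0, bv1, av0, av1); ... */ }
+//
+//   KernelSketch<dot_quad_avx2>(/* ... */);  // callee known statically at this call site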
+template +MLAS_FORCEINLINE void +ComputeDotProducts_BlkBitWidth4_CompInt8_SubBlkLen64_NCols1( + size_t BlkLen, + const std::byte* QuantARowPtr, + const std::byte* QuantBDataColPtr, + const float* QuantBScaleColPtr, + const std::byte* QuantBZeroPointColPtr, + float* SumPtr, + size_t CountK, + size_t StrideQuantBData, + size_t StrideQuantBScale, + size_t StrideQuantBZeroPoint, + const float* BiasPtr +) +{ + // TODO: make it work with all BlkLens + assert(BlkLen >= 64); + constexpr size_t SubBlkLen64 = 64; + // const __m256i zero = _mm256_setzero_si256(); + const __m256i low_mask = _mm256_set1_epi8(0xF); + + constexpr size_t BlkBitWidth = 4; + constexpr size_t SubBlkStep = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, SubBlkLen64); + + __m256 acc0 = _mm256_setzero_ps(); + + const std::byte* ablob = QuantARowPtr; + const std::byte* b_blk_data_ptr = QuantBDataColPtr; + const float* blk_scale_ptr = QuantBScaleColPtr; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + // only used if HasZeroPoint == true + + for (size_t k = 0; k < CountK; k += BlkLen) { + size_t ck = std::min(CountK - k, BlkLen); + + const float a_scale = Q8BlkScale(ablob); + ablob += sizeof(float); + + float scale_v0 = (*(blk_scale_ptr + StrideQuantBScale * 0)) * a_scale; + + const std::byte* bptr0 = (b_blk_data_ptr + StrideQuantBData * 0); + + uint8_t zp0; + if constexpr (HasZeroPoint) { + // TODO: this block causes near 30% of the computation. + // The solution proposed by @yufenglee is to factor out scaleB * zp + // while packing A. Will do this in next PR. + bool is_lower = (QuantBZeroPointIdx & 1) == 0; + std::byte zp_packed = QuantBZeroPointColPtr[0 * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; + zp0 = std::to_integer(is_lower ? (zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); + } else { + zp0 = 8; + } + + for (size_t kk = 0; kk < ck; kk += SubBlkLen64) { + // Load A row vector + const __m256i a_byte_lo = _mm256_loadu_si256((const __m256i*)ablob); + ablob += 32; + const __m256i a_byte_hi = _mm256_loadu_si256((const __m256i*)ablob); + ablob += 32; + + // Load B column vectors (quantized to int4 blobs) + // dst: | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 | + __m256i bv = _mm256_loadu_si256((__m256i const*)bptr0); + bptr0 += SubBlkStep; + __m256i bv_lo_epi8 = _mm256_and_si256(bv, low_mask); + __m256i bv_hi_epi8 = _mm256_and_si256(_mm256_srli_epi16(bv, 4), low_mask); + __m256i zp_epi8 = _mm256_set1_epi8(zp0); + bv_lo_epi8 = _mm256_sub_epi8(bv_lo_epi8, zp_epi8); + bv_hi_epi8 = _mm256_sub_epi8(bv_hi_epi8, zp_epi8); + __m256 sum_ps = dot_quad(bv_lo_epi8, bv_hi_epi8, a_byte_lo, a_byte_hi); + //__m256 sum_ps = mul_sum_s8_quads_float_avx2(bv_lo_epi8, bv_hi_epi8, a_byte_lo, a_byte_hi); + acc0 = _mm256_fmadd_ps(_mm256_set1_ps(scale_v0), sum_ps, acc0); + } // kk + + b_blk_data_ptr += MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + blk_scale_ptr++; + if constexpr (HasZeroPoint) { + QuantBZeroPointIdx += 1; + } + } // k + + *SumPtr = hsum_float_8(acc0); + *SumPtr += BiasPtr == nullptr ? 
0.0f : *BiasPtr; +} + +template +void +SQ4BitGemmM1Kernel_BlkLen64Plus_CompInt8_Impl( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias +) +{ + constexpr size_t BlkBitWidth = 4; + + const std::byte* QuantARowPtr = QuantA; + float* CRowPtr = C; + + const size_t BlockCountK = BlockStrideQuantB; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + const float* BiasPtr = Bias; + + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + + float* SumPtr = CRowPtr; + + const size_t NCols = 4; + int64_t nblk = static_cast(CountN) - NCols; + + while (nblk >= 0) { + ComputeDotProducts_BlkBitWidth4_CompInt8_SubBlkLen64_NCols4( + BlkLen, QuantARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, + SumPtr, CountK, StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, BiasPtr + ); + + // move to next `NCols` columns + + QuantBDataColPtr += NCols * StrideQuantBData; + QuantBScaleColPtr += NCols * StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += NCols * StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? NCols : 0; + SumPtr += NCols; + + nblk -= NCols; + } + + // left over columns less than `NCols`? + nblk += NCols; + for (int64_t n = 0; n < nblk; ++n) { + ComputeDotProducts_BlkBitWidth4_CompInt8_SubBlkLen64_NCols1( + BlkLen, + QuantARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + BiasPtr + ); + + // move to next column + + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } +} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp index 2fc24b358becd..ffa8b79ebd799 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp @@ -1159,7 +1159,7 @@ SQ4BitGemmM1Kernel_CompInt8_Impl_BlkLen32( // load B zero point const int8x16_t bzp0 = vdupq_n_s8( - HasZeroPoint ? std::to_integer((*QuantBZeroPoint) & std::byte{0x0F}) : 8 + HasZeroPoint ? 
std::to_integer((*QuantBZeroPointPtr) & std::byte{0x0F}) : 8 ); // load A diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index 353455631a103..1b2b9e3a4b978 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -99,6 +99,15 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accura static_cast(K), static_cast(block_size)); +#if 0 + for (int i = 0; i < input1_vals.size(); i++) + { + uint8_t byte = input1_vals[i]; + uint8_t val_lo = byte & 0x0f; + uint8_t val_hi = byte >> 4; + std::cout << (int)val_lo << ", " << (int)val_hi << ", "; + } +#endif std::vector expected_vals(M * N); for (int64_t m = 0; m < M; m++) { for (int64_t n = 0; n < N; n++) { @@ -219,7 +228,7 @@ TEST(MatMulNBits, Float32) { RunTest(M, N, K, block_size, accuracy_level, true, false); } #else - for (auto accuracy_level : {0}) { + for (auto accuracy_level : {0, 1, 4}) { RunTest(M, N, K, block_size, accuracy_level, false, false); RunTest(M, N, K, block_size, accuracy_level, true, false); RunTest(M, N, K, block_size, accuracy_level, false, false, true); diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp index ed09d7ee92b2a..71a6123b868bc 100644 --- a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp @@ -197,7 +197,7 @@ class MlasSQNBitGemmTest : public MlasTestBase { public: void Test(size_t M, size_t N, size_t K, MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, - bool WithBias, bool Symmetric, bool WithThreadpool) { + bool WithThreadpool, bool Symmetric, bool WithBias) { MLAS_THREADPOOL* Threadpool = WithThreadpool ? GetMlasThreadPool() : nullptr; const float* A = BufferA.GetBuffer(K * M); @@ -210,10 +210,10 @@ class MlasSQNBitGemmTest : public MlasTestBase { } #if 0 - auto print_matrix = [](size_t ncols, size_t nrows, const float* data) { + auto print_matrix = [](size_t nrows, size_t ncols, const float* data) { for (size_t row = 0; row < nrows; ++row) { for (size_t col = 0; col < ncols; ++col) { - std::cout << data[row * nrows + col] << "\t"; + std::cout << data[row * ncols + col] << "\t"; } std::cout << "\n"; } @@ -383,6 +383,16 @@ class SQNBitGemmShortExecuteTest : public MlasTestFixture