diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 72948c74d7877..166f5c8f52f54 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -9,6 +9,7 @@ #include "core/mlas/inc/mlas_q4.h" #include "core/providers/cpu/math/matmul_helper.h" #include "core/providers/common.h" + #ifdef ORT_NEURAL_SPEED #include "contrib_ops/cpu/quantization/neural_speed_gemm.h" #endif @@ -16,6 +17,39 @@ namespace onnxruntime { namespace contrib { +namespace { +int64_t GetAccuracyLevel(size_t nbits, size_t block_size, int64_t accuracy_level_attr) { + const auto accuracy_level = std::clamp(accuracy_level_attr, + static_cast(CompMostAccurate), + static_cast(CompLeastAccurate)); + +#if defined(ORT_NEURAL_SPEED) + + ORT_UNUSED_PARAMETER(nbits); + ORT_UNUSED_PARAMETER(block_size); + + // Neural Speed APIs already expect a minimum accuracy level so just use the given value. + return accuracy_level; + +#else // defined(ORT_NEURAL_SPEED) + + // Find a supported accuracy level that is not less accurate than the one given. + // CompMostAccurate is always supported with the fallback implementation. + // Note: A higher numeric accuracy level value means lower accuracy, so the comparison order is reversed. + int64_t effective_accuracy_level = accuracy_level; + for (; effective_accuracy_level > CompMostAccurate; --effective_accuracy_level) { + const auto compute_type = static_cast(effective_accuracy_level); + if (MlasIsSQNBitGemmAvailable(nbits, block_size, compute_type)) { + break; + } + } + + return effective_accuracy_level; + +#endif // defined(ORT_NEURAL_SPEED) +} +} // namespace + class MatMulNBits final : public OpKernel { public: MatMulNBits(const OpKernelInfo& info) @@ -24,7 +58,7 @@ class MatMulNBits final : public OpKernel { N_{narrow(info.GetAttr("N"))}, block_size_{narrow(info.GetAttr("block_size"))}, nbits_{narrow(info.GetAttr("bits"))}, - accuracy_level_{info.GetAttr("accuracy_level")} { + accuracy_level_{GetAccuracyLevel(nbits_, block_size_, info.GetAttr("accuracy_level"))} { ORT_ENFORCE(nbits_ == 4, "Only 4b quantization is supported for MatMulNBits op, additional bits support is planned."); #ifdef ORT_NEURAL_SPEED @@ -58,17 +92,22 @@ class MatMulNBits final : public OpKernel { const bool column_wise_quant_{true}; IAllocatorUniquePtr packed_b_; size_t packed_b_size_{0}; -#ifdef ORT_NEURAL_SPEED + +#if defined(ORT_NEURAL_SPEED) + bool is_asym_{false}; bool all_constant_{false}; -#endif + +#endif // defined(ORT_NEURAL_SPEED) }; Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { is_packed = false; -#ifdef ORT_NEURAL_SPEED + +#if defined(ORT_NEURAL_SPEED) + if (!all_constant_) { return Status::OK(); } @@ -116,11 +155,17 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat #else // defined(ORT_NEURAL_SPEED) if (input_idx == 1) { - packed_b_size_ = MlasSQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_); - if (packed_b_size_ == 0) return Status::OK(); + const auto compute_type = static_cast(accuracy_level_); + if (!MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type)) { + return Status::OK(); + } + packed_b_size_ = MlasSQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_, compute_type); + if (packed_b_size_ == 0) { + return Status::OK(); + } auto qptr = tensor.DataRaw(); packed_b_ = 
IAllocator::MakeUniquePtr(alloc, packed_b_size_, true); - MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, qptr, packed_b_.get()); + MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, qptr, packed_b_.get()); if (prepacked_weights) { prepacked_weights->buffers_.push_back(std::move(packed_b_)); prepacked_weights->buffer_sizes_.push_back(packed_b_size_); @@ -136,7 +181,9 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, int input_idx, /*out*/ bool& used_shared_buffers) { used_shared_buffers = false; -#ifdef ORT_NEURAL_SPEED + +#if defined(ORT_NEURAL_SPEED) + // Pack three tensors into one buffer if (input_idx == 1) { used_shared_buffers = true; @@ -159,6 +206,7 @@ Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& prep } #endif // defined(ORT_NEURAL_SPEED) + return Status::OK(); } @@ -167,8 +215,10 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { const Tensor* a = ctx->Input(0); const auto* a_data = a->Data(); -#ifdef ORT_NEURAL_SPEED - if (packed_b_.get()) { + +#if defined(ORT_NEURAL_SPEED) + + if (packed_b_) { TensorShape b_shape({static_cast(N_), static_cast(K_)}); MatMulComputeHelper helper; @@ -234,37 +284,43 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { const bool has_single_b_matrix = std::all_of(helper.RightOffsets().begin(), helper.RightOffsets().end(), [](size_t offset) { return offset == 0; }); - if (has_single_b_matrix && packed_b_) { - for (int64_t accuracy_level = accuracy_level_; - accuracy_level >= static_cast(CompMostAccurate); - --accuracy_level) { - const auto compute_type = static_cast(accuracy_level); - if (MlasIsSQNBitGemmAvailable(M, N, K, nbits_, block_size_, compute_type)) { - IAllocatorUniquePtr workspace{}; - if (const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, batch_count, - nbits_, block_size_, compute_type); - workspace_size > 0) { - AllocatorPtr allocator; - ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator)); - workspace = IAllocator::MakeUniquePtr(allocator, workspace_size); - } + if (has_single_b_matrix) { + const auto compute_type = static_cast(accuracy_level_); + + if (MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type)) { + IAllocatorUniquePtr workspace{}; + if (const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, batch_count, + nbits_, block_size_, compute_type); + workspace_size > 0) { + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator)); + workspace = IAllocator::MakeUniquePtr(allocator, workspace_size); + } - InlinedVector data(batch_count); - for (size_t i = 0; i < batch_count; ++i) { - data[i].A = a_data + helper.LeftOffsets()[i]; - data[i].lda = lda; - data[i].QuantBData = packed_b_.get(); - data[i].QuantBScale = scales_data; - data[i].QuantBZeroPoint = zero_points_data; - data[i].C = y_data + helper.OutputOffsets()[i]; - data[i].ldc = N; + const void* b_data = [&]() -> const void* { + if (packed_b_) { + return packed_b_.get(); } - MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type, data.data(), workspace.get(), - thread_pool); - - return Status::OK(); + const Tensor* b = ctx->Input(1); + return b->DataRaw(); + }(); + + InlinedVector data(batch_count); + for (size_t i = 0; i < batch_count; ++i) { + data[i].A = a_data + helper.LeftOffsets()[i]; + data[i].lda = lda; + data[i].QuantBData = b_data; + data[i].QuantBScale = scales_data; + data[i].QuantBZeroPoint = 
zero_points_data; + data[i].C = y_data + helper.OutputOffsets()[i]; + data[i].ldc = N; } + + MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type, data.data(), workspace.get(), + thread_pool); + + return Status::OK(); } } diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h index 047011e70bd4d..32e9cc98106d5 100644 --- a/onnxruntime/core/mlas/inc/mlas_qnbit.h +++ b/onnxruntime/core/mlas/inc/mlas_qnbit.h @@ -37,9 +37,7 @@ typedef enum { CompMostAccurate = CompUndef, CompLeastAccurate = CompInt8, -} MLAS_SQNBIT_COMPUTE_TYPE; - -using MLAS_SQNBIT_GEMM_COMPUTE_TYPE = MLAS_SQNBIT_COMPUTE_TYPE; // TODO consolidate these +} MLAS_SQNBIT_GEMM_COMPUTE_TYPE; /** * @brief Data parameters for float/n-bit quantized int GEMM routine. @@ -102,18 +100,12 @@ MlasSQNBitGemmBatch( /** * @brief Determines whether a float32/quantized n-bit int GEMM implementation is available on the current platform. * - * @param[in] M row size of matrix A and C - * @param[in] N column size of matrix B and C - * @param[in] K column size of matrix A and row size of matrix B * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) * @param[in] BlkLen number of quantized values per block * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) */ bool MLASCALL MlasIsSQNBitGemmAvailable( - size_t M, - size_t N, - size_t K, size_t BlkBitWidth, size_t BlkLen, MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType @@ -153,13 +145,15 @@ MlasSQNBitGemmBatchWorkspaceSize( * @param[in] K column size of matrix A and row size of matrix B * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) * @param[in] BlkLen number of quantized values per block + * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) */ size_t MLASCALL MlasSQNBitGemmPackQuantBDataSize( size_t N, size_t K, size_t BlkBitWidth, - size_t BlkLen + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType ); /** @@ -169,6 +163,7 @@ MlasSQNBitGemmPackQuantBDataSize( * @param[in] K column size of matrix A and row size of matrix B * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) * @param[in] BlkLen number of quantized values per block + * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) * @param[in] QuantBData quantized B data * @param[out] PackedQuantBData packed quantized B data * @param[in] ThreadPool optional thread pool to use @@ -179,6 +174,7 @@ MlasSQNBitGemmPackQuantBData( size_t K, size_t BlkBitWidth, size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, const void* QuantBData, void* PackedQuantBData, MLAS_THREADPOOL* ThreadPool = nullptr diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp index 0d8a5692359a6..38c31c8841761 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp @@ -39,23 +39,17 @@ enum SQNBitGemmVariant { SQNBitGemmVariant GetSQNBitGemmVariant( - size_t M, - size_t N, - size_t K, size_t BlkBitWidth, size_t BlkLen, MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - MLAS_UNREFERENCED_PARAMETER(N); - MLAS_UNREFERENCED_PARAMETER(K); - if (BlkBitWidth == 4 && (BlkLen == 16 || BlkLen == 32 || BlkLen == 64 || BlkLen == 128 || BlkLen == 256)) { if (ComputeType == CompFp32 || ComputeType == CompUndef) { // treat CompUndef (undefined) as CompFp32 return SQNBitGemmVariant_BitWidth4_CompFp32; - } else if (ComputeType == CompInt8 && M == 1) { + } else if 
(ComputeType == CompInt8) { return SQNBitGemmVariant_BitWidth4_CompInt8; } } @@ -67,9 +61,6 @@ GetSQNBitGemmVariant( bool MLASCALL MlasIsSQNBitGemmAvailable( - size_t M, - size_t N, - size_t K, size_t BlkBitWidth, size_t BlkLen, MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType @@ -80,7 +71,7 @@ MlasIsSQNBitGemmAvailable( return false; } - const auto Variant = GetSQNBitGemmVariant(M, N, K, BlkBitWidth, BlkLen, ComputeType); + const auto Variant = GetSQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType); switch (Variant) { case SQNBitGemmVariant_BitWidth4_CompFp32: { @@ -164,7 +155,7 @@ MlasSQNBitGemmBatchWorkspaceSize( MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - const auto Variant = GetSQNBitGemmVariant(M, N, K, BlkBitWidth, BlkLen, ComputeType); + const auto Variant = GetSQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType); const size_t PerGemmWorkspaceStride = SQNBitGemmPerGemmWorkspaceStride(Variant, M, N, K, BlkLen); if (PerGemmWorkspaceStride == 0) { @@ -178,91 +169,24 @@ MlasSQNBitGemmBatchWorkspaceSize( return WorkspaceSize + Alignment - 1; } -namespace -{ - -void -SQ4BitGemmPackQuantBData( - size_t N, - size_t K, - size_t BlkLen, - const std::byte* QuantBDataBegin, - std::byte* PackedQuantBDataBegin, - MLAS_THREADPOOL* ThreadPool -) -{ - constexpr size_t BlkBitWidth = 4; - - assert(BlkLen % 16 == 0); - - const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - const size_t BlkDataSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - const size_t Iterations = N * BlockCountK; // one iteration per block - - MlasTrySimpleParallel( - ThreadPool, Iterations, - [&](ptrdiff_t tid) { - const size_t n = tid / BlockCountK; - const size_t k_blk = tid % BlockCountK; - - const size_t data_offset = n * BlockCountK * BlkDataSize + k_blk * BlkDataSize; - const std::byte* QuantBData = QuantBDataBegin + data_offset; - std::byte* PackedQuantBData = PackedQuantBDataBegin + data_offset; - - // - // Pack 16 4-bit values (8 bytes) at a time like this: - // - // src: | v0 v1 | v2 v3 | v4 v5 | v6 v7 | v8 v9 | vA vB | vC vD | vE vF | - // => - // dst: | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF | - // - for (size_t kk = 0; kk < BlkLen; kk += 16) { - for (size_t byte_pair_idx = 0; byte_pair_idx < 4; ++byte_pair_idx) { - const std::byte src0 = QuantBData[byte_pair_idx]; - const std::byte src1 = QuantBData[byte_pair_idx + 4]; - - std::byte& dst0 = PackedQuantBData[2 * byte_pair_idx]; - std::byte& dst1 = PackedQuantBData[2 * byte_pair_idx + 1]; - - dst0 = (src0 & std::byte{0x0F}) | ((src1 & std::byte{0x0F}) << 4); - dst1 = (src0 >> 4) | ((src1 >> 4) << 4); - } - - QuantBData += 8; - PackedQuantBData += 8; - } - } - ); -} - -} // namespace - size_t MLASCALL MlasSQNBitGemmPackQuantBDataSize( size_t N, size_t K, size_t BlkBitWidth, - size_t BlkLen + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - // Ensure that a general implementation is available on this platform. - // For now, all implementations share the same packed format. - { - // Currently, there are implementations specific to M = 1, so pick a more general M > 1. - constexpr size_t M = 2; - // A CompUndef implementation should be available if any is available. 
- constexpr MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType = CompUndef; - const bool HasGeneralImplementation = - MlasIsSQNBitGemmAvailable(M, N, K, BlkBitWidth, BlkLen, ComputeType); - if (!HasGeneralImplementation) { - return 0; - } + const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch; + if (Dispatch == nullptr) { + return 0; } - if (BlkBitWidth == 4) { - const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - return PackedQuantBDataSize; + if (BlkBitWidth == 4 && Dispatch->SQ4BitGemmPackQuantBDataSize != nullptr) { + return Dispatch->SQ4BitGemmPackQuantBDataSize( + N, K, BlkLen, ComputeType + ); } return 0; @@ -274,20 +198,28 @@ MlasSQNBitGemmPackQuantBData( size_t K, size_t BlkBitWidth, size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, const void* QuantBData, void* PackedQuantBData, MLAS_THREADPOOL* ThreadPool ) { - if (BlkBitWidth == 4) { - SQ4BitGemmPackQuantBData( + const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch; + if (Dispatch == nullptr) { + return; + } + + if (BlkBitWidth == 4 && Dispatch->SQ4BitGemmPackQuantBData != nullptr) { + Dispatch->SQ4BitGemmPackQuantBData( N, K, BlkLen, + ComputeType, static_cast(QuantBData), static_cast(PackedQuantBData), ThreadPool ); + return; } } @@ -512,7 +444,37 @@ SQ4BitGemm_CompInt8( return; } - assert(false && "not implemented for M > 1"); + // This is a naive M > 1 implementation that repeatedly calls the M=1 kernel. + // TODO Replace it with an optimized implementation. + size_t CountN; + for (size_t n = 0; n < RangeCountN; n += CountN) { + CountN = std::min(RangeCountN - n, size_t{128}); + + const std::byte* a_row = QuantA; + const std::byte* b_col = QuantBData + n * ldb; + const float* b_col_scale = QuantBScale + n * k_blks; + const std::byte* b_col_zp = + (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; + float* c_blk = C + n; + const float* bias = (Bias == nullptr) ? nullptr : Bias + n; + + for (size_t m = 0; m < RangeCountM; ++m) { + GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompInt8( + BlkLen, + a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias + ); + + if (DataParams->PostProcessor != nullptr) { + DataParams->PostProcessor->Process( + DataParams->C, RangeStartM, RangeStartN + n, + RangeCountM, CountN, ldc + ); + } + + c_blk += ldc; + a_row += lda; + } + } } typedef void(InitializeWorkspaceFn)( @@ -594,7 +556,7 @@ MlasSQNBitGemmBatch( MLAS_THREADPOOL* ThreadPool ) { - const auto Variant = GetSQNBitGemmVariant(M, N, K, BlkBitWidth, BlkLen, ComputeType); + const auto Variant = GetSQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType); assert(Variant != SQNBitGemmVariantInvalid); // diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.h b/onnxruntime/core/mlas/lib/sqnbitgemm.h index a66db79dc290a..3992bc3e452a3 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.h @@ -99,6 +99,33 @@ Q8BlkAlignment() // struct MLAS_SQNBIT_GEMM_DISPATCH { + // + // Quantized B data packing function prototypes. + // + + /** Gets size of packed quantized B data containing 4-bit integers. See MlasSQNBitGemmPackQuantBDataSize(). */ + typedef size_t(SQ4BitGemmPackQuantBDataSize_Fn)( + size_t N, + size_t K, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + ); + + SQ4BitGemmPackQuantBDataSize_Fn* SQ4BitGemmPackQuantBDataSize = nullptr; + + /** Packs quantized B data containing 4-bit integers. 
See MlasSQNBitGemmPackQuantBData(). */ + typedef void(SQ4BitGemmPackQuantBData_Fn)( + size_t N, + size_t K, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool + ); + + SQ4BitGemmPackQuantBData_Fn* SQ4BitGemmPackQuantBData = nullptr; + // // CompFp32 kernel function prototypes. // diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp index 69fd427fa574a..c4c54a9be34d8 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp @@ -15,14 +15,115 @@ Module Name: --*/ -#include "sqnbitgemm.h" - #include #include #include #include +#include "sqnbitgemm.h" + +// +// Quantized B data packing function implementation. +// + +namespace +{ + +size_t +SQ4BitGemmPackQuantBDataSize( + size_t N, + size_t K, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType +) +{ + MLAS_UNREFERENCED_PARAMETER(ComputeType); // same size regardless of ComputeType + + constexpr size_t BlkBitWidth = 4; + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + return PackedQuantBDataSize; +} + +void +SQ4BitGemmPackQuantBData( + size_t N, + size_t K, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool +) +{ + constexpr size_t BlkBitWidth = 4; + + assert(BlkLen >= 16 && BlkLen % 16 == 0); + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + const size_t BlkDataSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t Iterations = N * BlockCountK; // one iteration per block + + const size_t SubBlkLen = (ComputeType == CompInt8) + ? ((BlkLen == 16) ? 16 : 32) + : 16; + + const size_t SubBlkDataSize = SubBlkLen / 2; + const size_t SubBlkBytePairCount = SubBlkLen / 4; + + // + // For SubBlkLen == 16, pack 16 4-bit values (8 bytes) at a time like this: + // + // src: | v0 v1 | v2 v3 | v4 v5 | v6 v7 | v8 v9 | vA vB | vC vD | vE vF | + // => + // dst: | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF | + // + + // + // For SubBlkLen == 32, pack 32 4-bit values (16 bytes) at a time like this: + // + // src: | v0 v1 | v2 v3 | ... | v28 v29 | v30 v31 | + // => + // dst: | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | + // + + MlasTrySimpleParallel( + ThreadPool, Iterations, + [&](ptrdiff_t tid) { + const size_t n = tid / BlockCountK; + const size_t k_blk = tid % BlockCountK; + + const size_t data_offset = n * BlockCountK * BlkDataSize + k_blk * BlkDataSize; + const std::byte* QuantBData = QuantBDataBegin + data_offset; + std::byte* PackedQuantBData = PackedQuantBDataBegin + data_offset; + + for (size_t kk = 0; kk < BlkLen; kk += SubBlkLen) { + for (size_t byte_pair_idx = 0; byte_pair_idx < SubBlkBytePairCount; ++byte_pair_idx) { + const std::byte src0 = QuantBData[byte_pair_idx]; + const std::byte src1 = QuantBData[byte_pair_idx + SubBlkDataSize / 2]; + + std::byte& dst0 = PackedQuantBData[2 * byte_pair_idx]; + std::byte& dst1 = PackedQuantBData[2 * byte_pair_idx + 1]; + + dst0 = (src0 & std::byte{0x0F}) | ((src1 & std::byte{0x0F}) << 4); + dst1 = (src0 >> 4) | ((src1 >> 4) << 4); + } + + QuantBData += SubBlkDataSize; + PackedQuantBData += SubBlkDataSize; + } + } + ); +} + +} // namespace + +// +// General helpers. 
+// + namespace { @@ -95,7 +196,16 @@ LoadFloatData(const float* src, size_t count, float32x4_t (&dst)[Capacity / 4]) } } -template +} // namespace + +// +// CompFp32 kernel implementation. +// + +namespace +{ + +template MLAS_FORCEINLINE void ComputeDotProducts_BlkBitWidth4_CompFp32( size_t BlkLen, @@ -112,11 +222,11 @@ ComputeDotProducts_BlkBitWidth4_CompFp32( ) { constexpr size_t BlkBitWidth = 4; + constexpr size_t SubBlkLen = 16; static_assert(NCols == 1 || NCols == 4, "NCols must be 1 or 4"); - constexpr size_t SubBlkLen = 16; // number of block elements to process in a sub-block iteration - assert(BlkLen % SubBlkLen == 0); + assert(BlkLen >= SubBlkLen && BlkLen % SubBlkLen == 0); const uint8x8_t LowMask = vdup_n_u8(0x0F); @@ -137,7 +247,8 @@ ComputeDotProducts_BlkBitWidth4_CompFp32( const std::byte* QuantBData = QuantBDataColPtr; const float* QuantBScale = QuantBScaleColPtr; - size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + // only used if HasZeroPoint == true for (size_t k = 0; k < CountK; k += BlkLen) { const size_t k_blk_len = std::min(CountK - k, BlkLen); @@ -147,8 +258,9 @@ ComputeDotProducts_BlkBitWidth4_CompFp32( [&](size_t i) { scale[i] = QuantBScale[i * StrideQuantBScale]; } ); - float offset[NCols]; // Includes zero point and float conversion offset of 16. - if (QuantBZeroPointColPtr != nullptr) { + [[maybe_unused]] float offset[NCols]; // Includes zero point and float conversion offset of 16. + // only used if HasZeroPoint == true + if constexpr (HasZeroPoint) { UnrolledLoop([&](size_t i) { const std::byte zp_packed = QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; @@ -157,11 +269,6 @@ ComputeDotProducts_BlkBitWidth4_CompFp32( : (zp_packed & std::byte{0x0F}); offset[i] = 16.0f + std::to_integer(zp); }); - } else { - UnrolledLoop([&](size_t i) { - constexpr float zp = 8.0f; - offset[i] = 16.0f + zp; - }); } for (size_t k_idx_in_blk = 0; k_idx_in_blk < k_blk_len; k_idx_in_blk += SubBlkLen) { @@ -187,8 +294,6 @@ ComputeDotProducts_BlkBitWidth4_CompFp32( bv_u8[i][1] = vshr_n_u8(bv_packed[i], 4); }); - // dequantize B - // shift left 3 and widen to 16 bits uint16x8_t bv_u16[NCols][2]; UnrolledLoop([&](size_t i) { @@ -217,10 +322,17 @@ ComputeDotProducts_BlkBitWidth4_CompFp32( }); // subtract float conversion offset (16) and zero point - UnrolledLoop([&](size_t i) { - const float32x4_t offset_v = vdupq_n_f32(offset[i]); - UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); - }); + if constexpr (HasZeroPoint) { + UnrolledLoop([&](size_t i) { + const float32x4_t offset_v = vdupq_n_f32(offset[i]); + UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); + }); + } else { + const float32x4_t offset_v = vdupq_n_f32(16.0f + 8.0f); + UnrolledLoop([&](size_t i) { + UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); + }); + } // multiply by scale UnrolledLoop([&](size_t i) { @@ -237,7 +349,9 @@ ComputeDotProducts_BlkBitWidth4_CompFp32( // increment pointers to next block QuantBData += MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); QuantBScale += 1; - QuantBZeroPointIdx += 1; + if constexpr (HasZeroPoint) { + QuantBZeroPointIdx += 1; + } } if constexpr (NCols == 4) { @@ -258,8 +372,9 @@ ComputeDotProducts_BlkBitWidth4_CompFp32( } } -MLAS_FORCEINLINE void -SQ4BitGemmM1Kernel_CompFp32( +template +void 
+SQ4BitGemmM1Kernel_CompFp32_Impl( size_t BlkLen, const float* A, const std::byte* QuantBData, @@ -295,7 +410,7 @@ SQ4BitGemmM1Kernel_CompFp32( int64_t nblk = static_cast(CountN) - NCols; while (nblk >= 0) { - ComputeDotProducts_BlkBitWidth4_CompFp32( + ComputeDotProducts_BlkBitWidth4_CompFp32( BlkLen, ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, @@ -306,7 +421,7 @@ SQ4BitGemmM1Kernel_CompFp32( QuantBDataColPtr += NCols * StrideQuantBData; QuantBScaleColPtr += NCols * StrideQuantBScale; - if (QuantBZeroPointColPtr != nullptr) { + if constexpr (HasZeroPoint) { QuantBZeroPointColPtr += NCols * StrideQuantBZeroPoint; } @@ -319,7 +434,7 @@ SQ4BitGemmM1Kernel_CompFp32( // left over columns less than `NCols`? nblk += NCols; for (int64_t n = 0; n < nblk; ++n) { - ComputeDotProducts_BlkBitWidth4_CompFp32<1>( + ComputeDotProducts_BlkBitWidth4_CompFp32<1, HasZeroPoint>( BlkLen, ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, @@ -330,7 +445,7 @@ SQ4BitGemmM1Kernel_CompFp32( QuantBDataColPtr += StrideQuantBData; QuantBScaleColPtr += StrideQuantBScale; - if (QuantBZeroPointColPtr != nullptr) { + if constexpr (HasZeroPoint) { QuantBZeroPointColPtr += StrideQuantBZeroPoint; } @@ -339,6 +454,49 @@ SQ4BitGemmM1Kernel_CompFp32( } } +MLAS_FORCEINLINE void +SQ4BitGemmM1Kernel_CompFp32( + size_t BlkLen, + const float* A, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias +) +{ + if (QuantBZeroPoint != nullptr) { + SQ4BitGemmM1Kernel_CompFp32_Impl( + BlkLen, + A, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockStrideQuantB, + Bias + ); + } else { + SQ4BitGemmM1Kernel_CompFp32_Impl( + BlkLen, + A, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockStrideQuantB, + Bias + ); + } +} + MLAS_FORCEINLINE void Q4BitBlkDequantBForSgemm_CompFp32( size_t BlkLen, @@ -353,6 +511,7 @@ Q4BitBlkDequantBForSgemm_CompFp32( { auto impl0_reference = [&]() { constexpr size_t BlkBitWidth = 4; + constexpr size_t SubBlkLen = 16; float* Dst = FpData; @@ -378,11 +537,11 @@ Q4BitBlkDequantBForSgemm_CompFp32( : 8; for (size_t kk = 0; kk < kklen; ++kk) { - const size_t packed_idx = kk % 16; + const size_t packed_idx = kk % SubBlkLen; - const bool is_low_half = packed_idx < 8; - const size_t packed_byte_idx = packed_idx % 8; - const size_t packed_range_offset = (kk / 16) * 8; + const bool is_low_half = packed_idx < (SubBlkLen / 2); + const size_t packed_byte_idx = packed_idx % (SubBlkLen / 2); + const size_t packed_range_offset = (kk / SubBlkLen) * (SubBlkLen / 2); const std::byte b_packed = b_data[packed_range_offset + packed_byte_idx]; const std::byte b_byte = is_low_half ? (b_packed & std::byte{0x0F}) : (b_packed >> 4); @@ -415,7 +574,7 @@ Q4BitBlkDequantBForSgemm_CompFp32( } // -// CompInt8 kernel implementation and related helpers +// CompInt8 kernel implementation. // template @@ -431,8 +590,6 @@ QuantizeBlock( assert(BlkLen % SubBlkLen == 0); - constexpr size_t VectorCount = SubBlkLen / 4; - // // Scan block values first to determine scale. 
// @@ -443,16 +600,16 @@ QuantizeBlock( for (k = 0; k < ElementCount; k += SubBlkLen) { const size_t SubBlkElementCount = std::min(ElementCount - k, SubBlkLen); - float32x4_t a[VectorCount]{}; + float32x4_t a[SubBlkLen / 4]{}; LoadFloatData(A + k, SubBlkElementCount, a); - float32x4_t abs_a[VectorCount]; - UnrolledLoop([&](size_t i) { + float32x4_t abs_a[SubBlkLen / 4]; + UnrolledLoop([&](size_t i) { abs_a[i] = vabsq_f32(a[i]); }); // find amax of SubBlkLen elements - for (size_t interval = VectorCount / 2; interval > 0; interval /= 2) { + for (size_t interval = SubBlkLen / 4 / 2; interval > 0; interval /= 2) { for (size_t i = 0; i < interval; ++i) { abs_a[i] = vmaxq_f32(abs_a[i], abs_a[i + interval]); } @@ -477,19 +634,19 @@ QuantizeBlock( for (k = 0; k < ElementCount; k += SubBlkLen) { const size_t SubBlkElementCount = std::min(ElementCount - k, SubBlkLen); - float32x4_t a[VectorCount]{}; + float32x4_t a[SubBlkLen / 4]{}; LoadFloatData(A + k, SubBlkElementCount, a); - UnrolledLoop([&](size_t i) { + UnrolledLoop([&](size_t i) { a[i] = vmulq_n_f32(a[i], scale_reciprocal); }); - int32x4_t a_s32[VectorCount]; - UnrolledLoop([&](size_t i) { + int32x4_t a_s32[SubBlkLen / 4]; + UnrolledLoop([&](size_t i) { a_s32[i] = vcvtaq_s32_f32(a[i]); }); - UnrolledLoop([&](size_t i) { + UnrolledLoop([&](size_t i) { QuantAData[k + i * 4 + 0] = static_cast(vgetq_lane_s32(a_s32[i], 0)); QuantAData[k + i * 4 + 1] = static_cast(vgetq_lane_s32(a_s32[i], 1)); QuantAData[k + i * 4 + 2] = static_cast(vgetq_lane_s32(a_s32[i], 2)); @@ -530,7 +687,7 @@ QuantizeARow_CompInt8( } } -template +template MLAS_FORCEINLINE void ComputeDotProducts_BlkBitWidth4_CompInt8( size_t BlkLen, @@ -546,20 +703,22 @@ ComputeDotProducts_BlkBitWidth4_CompInt8( const float* BiasPtr ) { - static_assert(NCols == 1 || NCols == 4, "NCols must be 1 or 4"); - constexpr size_t BlkBitWidth = 4; - constexpr size_t SubBlkLen = 16; // number of block elements to process in a sub-block iteration - assert(BlkLen % SubBlkLen == 0); + static_assert(NCols == 1 || NCols == 4, "NCols must be 1 or 4"); + static_assert(SubBlkLen == 16 || SubBlkLen == 32, "SubBlkLen must be 16 or 32"); - const uint8x8_t LowMask = vdup_n_u8(0x0F); + assert(BlkLen >= SubBlkLen && BlkLen % SubBlkLen == 0); + + [[maybe_unused]] const uint8x8_t LowMaskU8x8 = vdup_n_u8(0x0F); // only used if SubBlkLen == 16 + [[maybe_unused]] const uint8x16_t LowMaskU8x16 = vdupq_n_u8(0x0F); // only used if SubBlkLen == 32 const std::byte* QuantA = QuantARowPtr; const std::byte* QuantBData = QuantBDataColPtr; const float* QuantBScale = QuantBScaleColPtr; - size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + // only used if HasZeroPoint == true float32x4_t acc[NCols]{}; @@ -572,8 +731,8 @@ ComputeDotProducts_BlkBitWidth4_CompInt8( float b_scale[NCols]; UnrolledLoop([&](size_t i) { b_scale[i] = QuantBScale[i * StrideQuantBScale]; }); - int8_t b_zp[NCols]; - if (QuantBZeroPointColPtr != nullptr) { + [[maybe_unused]] int8_t b_zp[NCols]; // only used if HasZeroPoint == true + if constexpr (HasZeroPoint) { UnrolledLoop([&](size_t i) { const std::byte zp_packed = QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; @@ -581,42 +740,73 @@ ComputeDotProducts_BlkBitWidth4_CompInt8( ? 
std::to_integer(zp_packed >> 4) : std::to_integer(zp_packed & std::byte{0x0F}); }); - } else { - UnrolledLoop([&](size_t i) { - b_zp[i] = 8; - }); } for (size_t k_idx_in_blk = 0; k_idx_in_blk < k_blk_len; k_idx_in_blk += SubBlkLen) { // load A row vector - int8x16_t av = vld1q_s8(a_data + k_idx_in_blk); + int8x16_t av[SubBlkLen / 16]; + UnrolledLoop([&](size_t i) { + av[i] = vld1q_s8(a_data + k_idx_in_blk + i * 16); + }); // load B column vectors - uint8x8_t bv_packed[NCols]; + int8x16_t bv[NCols][SubBlkLen / 16]; + const size_t b_data_block_offset = k_idx_in_blk * BlkBitWidth / 8; - UnrolledLoop([&](size_t i) { - bv_packed[i] = vld1_u8( - reinterpret_cast(QuantBData) + i * StrideQuantBData + b_data_block_offset - ); - }); - int8x16_t bv[NCols]; - UnrolledLoop([&](size_t i) { - const int8x8_t lo = vreinterpret_s8_u8(vand_u8(bv_packed[i], LowMask)); - const int8x8_t hi = vreinterpret_s8_u8(vshr_n_u8(bv_packed[i], 4)); - bv[i] = vcombine_s8(lo, hi); - }); + if constexpr (SubBlkLen == 16) { + uint8x8_t bv_packed[NCols]; + UnrolledLoop([&](size_t i) { + bv_packed[i] = vld1_u8( + reinterpret_cast(QuantBData) + i * StrideQuantBData + b_data_block_offset + ); + }); + + UnrolledLoop([&](size_t i) { + const int8x8_t lo = vreinterpret_s8_u8(vand_u8(bv_packed[i], LowMaskU8x8)); + const int8x8_t hi = vreinterpret_s8_u8(vshr_n_u8(bv_packed[i], 4)); + bv[i][0] = vcombine_s8(lo, hi); + }); + } else { + static_assert(SubBlkLen == 32); + + uint8x16_t bv_packed[NCols]; + UnrolledLoop([&](size_t i) { + bv_packed[i] = vld1q_u8( + reinterpret_cast(QuantBData) + i * StrideQuantBData + b_data_block_offset + ); + }); + + UnrolledLoop([&](size_t i) { + bv[i][0] = vreinterpretq_s8_u8(vandq_u8(bv_packed[i], LowMaskU8x16)); + bv[i][1] = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed[i], 4)); + }); + } // subtract B zero point - UnrolledLoop([&](size_t i) { - const int8x16_t zp_v = vdupq_n_s8(b_zp[i]); - bv[i] = vsubq_s8(bv[i], zp_v); - }); + if constexpr (HasZeroPoint) { + UnrolledLoop([&](size_t i) { + const int8x16_t zp_v = vdupq_n_s8(b_zp[i]); + UnrolledLoop([&](size_t j) { + bv[i][j] = vsubq_s8(bv[i][j], zp_v); + }); + }); + } else { + const int8x16_t zp_v = vdupq_n_s8(8); + + UnrolledLoop([&](size_t i) { + UnrolledLoop([&](size_t j) { + bv[i][j] = vsubq_s8(bv[i][j], zp_v); + }); + }); + } // compute quantized dot product int32x4_t dot[NCols]{}; UnrolledLoop([&](size_t i) { - dot[i] = vdotq_s32(dot[i], av, bv[i]); + UnrolledLoop([&](size_t j) { + dot[i] = vdotq_s32(dot[i], av[j], bv[i][j]); + }); }); // convert dot product result to float @@ -636,7 +826,9 @@ ComputeDotProducts_BlkBitWidth4_CompInt8( QuantA += Q8BlkSize(BlkLen); QuantBData += MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); QuantBScale += 1; - QuantBZeroPointIdx += 1; + if constexpr (HasZeroPoint) { + QuantBZeroPointIdx += 1; + } } if constexpr (NCols == 4) { @@ -657,9 +849,9 @@ ComputeDotProducts_BlkBitWidth4_CompInt8( } } -MLAS_FORCEINLINE +template void -SQ4BitGemmM1Kernel_CompInt8( +SQ4BitGemmM1Kernel_CompInt8_Impl( size_t BlkLen, const std::byte* QuantA, const std::byte* QuantBData, @@ -673,7 +865,6 @@ SQ4BitGemmM1Kernel_CompInt8( ) { constexpr size_t BlkBitWidth = 4; - constexpr size_t NCols = 4; const std::byte* QuantARowPtr = QuantA; float* CRowPtr = C; @@ -695,7 +886,7 @@ SQ4BitGemmM1Kernel_CompInt8( int64_t nblk = static_cast(CountN) - NCols; while (nblk >= 0) { - ComputeDotProducts_BlkBitWidth4_CompInt8( + ComputeDotProducts_BlkBitWidth4_CompInt8( BlkLen, QuantARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, 
CountK, StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, @@ -706,7 +897,7 @@ SQ4BitGemmM1Kernel_CompInt8( QuantBDataColPtr += NCols * StrideQuantBData; QuantBScaleColPtr += NCols * StrideQuantBScale; - if (QuantBZeroPointColPtr != nullptr) { + if constexpr (HasZeroPoint) { QuantBZeroPointColPtr += NCols * StrideQuantBZeroPoint; } @@ -719,7 +910,7 @@ SQ4BitGemmM1Kernel_CompInt8( // left over columns less than `NCols`? nblk += NCols; for (int64_t n = 0; n < nblk; ++n) { - ComputeDotProducts_BlkBitWidth4_CompInt8<1>( + ComputeDotProducts_BlkBitWidth4_CompInt8<1, SubBlkLen, HasZeroPoint>( BlkLen, QuantARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, @@ -730,7 +921,7 @@ SQ4BitGemmM1Kernel_CompInt8( QuantBDataColPtr += StrideQuantBData; QuantBScaleColPtr += StrideQuantBScale; - if (QuantBZeroPointColPtr != nullptr) { + if constexpr (HasZeroPoint) { QuantBZeroPointColPtr += StrideQuantBZeroPoint; } @@ -739,6 +930,94 @@ SQ4BitGemmM1Kernel_CompInt8( } } +template +MLAS_FORCEINLINE void +SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias +) +{ + if (BlkLen == 16) { + SQ4BitGemmM1Kernel_CompInt8_Impl<4, 16, HasZeroPoint>( + BlkLen, + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockStrideQuantB, + Bias + ); + } else { + SQ4BitGemmM1Kernel_CompInt8_Impl<4, 32, HasZeroPoint>( + BlkLen, + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockStrideQuantB, + Bias + ); + } +} + +MLAS_FORCEINLINE +void +SQ4BitGemmM1Kernel_CompInt8( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias +) +{ + if (QuantBZeroPoint != nullptr) { + SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen( + BlkLen, + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockStrideQuantB, + Bias + ); + } else { + SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen( + BlkLen, + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockStrideQuantB, + Bias + ); + } +} + } // namespace // @@ -748,8 +1027,12 @@ SQ4BitGemmM1Kernel_CompInt8( const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon = []() { MLAS_SQNBIT_GEMM_DISPATCH d; + d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; + d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; + d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32; d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32; + d.SQ4BitGemmM1Kernel_CompInt8 = SQ4BitGemmM1Kernel_CompInt8; d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8; diff --git a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp index 668d7a0611367..b7b453415838a 100644 --- a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp @@ -61,10 +61,11 @@ void SQNBITGEMM(benchmark::State& state) { } std::unique_ptr PackedQuantBData; - if (const auto PackedQuantBDataSize = MlasSQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen); + if (const auto PackedQuantBDataSize = 
MlasSQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen, ComputeType); PackedQuantBDataSize > 0) { PackedQuantBData = std::make_unique(PackedQuantBDataSize); - MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, QuantBData.data(), PackedQuantBData.get(), tp.get()); + MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType, QuantBData.data(), PackedQuantBData.get(), + tp.get()); } MLAS_SQNBIT_GEMM_DATA_PARAMS params{}; @@ -87,7 +88,9 @@ void SQNBITGEMM(benchmark::State& state) { } } -static void SQNBitGemmArgs(benchmark::internal::Benchmark* b) { +static void SQ4BitGemmArgs(benchmark::internal::Benchmark* b) { + constexpr size_t BlkBitWidth = 4; + b->ArgNames({"BlkLen", "M", "N", "K", "Threads", "Symmetric", "ComputeType"}); ArgsProductWithFilter(b, @@ -96,19 +99,17 @@ static void SQNBitGemmArgs(benchmark::internal::Benchmark* b) { {1, 1024, 2048}, // M {4096, 11008}, // N {4096, 11008}, // K - {8}, // Threads + {1, 8}, // Threads {int64_t{false}, int64_t{true}}, // Symmetric {int64_t{CompFp32}, int64_t{CompInt8}}}, // ComputeType - [](const std::vector& args) { + [&](const std::vector& args) { return MlasIsSQNBitGemmAvailable( - // M, N, K - narrow(args[1]), narrow(args[2]), narrow(args[3]), // BlkBitWidth, BlkLen - 4, narrow(args[0]), + BlkBitWidth, narrow(args[0]), // ComputeType static_cast(args[6])); }); } -BENCHMARK(SQNBITGEMM<4>)->Apply(SQNBitGemmArgs)->UseRealTime(); +BENCHMARK(SQNBITGEMM<4>)->Apply(SQ4BitGemmArgs)->UseRealTime(); diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp index 4fb8ab41745d5..ed09d7ee92b2a 100644 --- a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp @@ -259,10 +259,11 @@ class MlasSQNBitGemmTest : public MlasTestBase { } void* PackedQuantBData = nullptr; - if (const auto PackedQuantBDataSize = MlasSQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen); + if (const auto PackedQuantBDataSize = MlasSQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen, ComputeType); PackedQuantBDataSize > 0) { PackedQuantBData = BufferPackedQuantBData.GetBuffer(PackedQuantBDataSize); - MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, QuantBData, PackedQuantBData, GetMlasThreadPool()); + MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType, QuantBData, PackedQuantBData, + GetMlasThreadPool()); } if (ComputeType == CompFp32) { @@ -330,7 +331,7 @@ class SQNBitGemmShortExecuteTest : public MlasTestFixture
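Appended note, not part of the patch: the new GetAccuracyLevel helper clamps the node attribute and then walks toward CompMostAccurate until a supported kernel is found. The standalone sketch below mirrors that logic with stand-in enum values and a stand-in availability predicate (IsKernelAvailable is hypothetical; the real check is MlasIsSQNBitGemmAvailable), so the fallback behaviour can be exercised without MLAS.

#include <algorithm>
#include <cstdint>
#include <cstdio>

enum ComputeType : int64_t {
  CompUndef = 0,  // most accurate (treated as fp32)
  CompFp32,
  CompFp16,
  CompBf16,
  CompInt8,  // least accurate
  CompMostAccurate = CompUndef,
  CompLeastAccurate = CompInt8,
};

// Stand-in for MlasIsSQNBitGemmAvailable(): pretend only fp32 and int8 kernels exist.
bool IsKernelAvailable(ComputeType type) {
  return type == CompUndef || type == CompFp32 || type == CompInt8;
}

int64_t GetEffectiveAccuracyLevel(int64_t accuracy_level_attr) {
  // Clamp the attribute into the valid range first.
  const int64_t clamped = std::clamp(accuracy_level_attr,
                                     static_cast<int64_t>(CompMostAccurate),
                                     static_cast<int64_t>(CompLeastAccurate));

  // Walk toward CompMostAccurate (numerically smaller == more accurate) until a
  // supported kernel is found. CompMostAccurate always has a fallback.
  int64_t level = clamped;
  for (; level > CompMostAccurate; --level) {
    if (IsKernelAvailable(static_cast<ComputeType>(level))) {
      break;
    }
  }
  return level;
}

int main() {
  // Requesting level 3 (bf16 in this sketch) falls back to level 1 (fp32),
  // the closest supported level that is not less accurate.
  std::printf("requested 3 -> effective %lld\n",
              static_cast<long long>(GetEffectiveAccuracyLevel(3)));
  // Requesting level 4 (int8) is supported here and is used as-is.
  std::printf("requested 4 -> effective %lld\n",
              static_cast<long long>(GetEffectiveAccuracyLevel(4)));
  return 0;
}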
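Appended note, not part of the patch: the NEON SQ4BitGemmPackQuantBData routine interleaves each 16-value sub-block so that the low-nibble and high-nibble halves of a loaded vector line up as consecutive column elements. A minimal scalar sketch of that byte-pair shuffle, with a self-check that byte k of the packed output holds the pair (v_k, v_{k+8}):

#include <cassert>
#include <cstdint>
#include <cstdio>

void PackSubBlk16(const uint8_t src[8], uint8_t dst[8]) {
  for (int byte_pair_idx = 0; byte_pair_idx < 4; ++byte_pair_idx) {
    const uint8_t src0 = src[byte_pair_idx];      // holds v(2j),   v(2j+1)
    const uint8_t src1 = src[byte_pair_idx + 4];  // holds v(2j+8), v(2j+9)
    dst[2 * byte_pair_idx]     = (src0 & 0x0F) | ((src1 & 0x0F) << 4);  // v(2j),   v(2j+8)
    dst[2 * byte_pair_idx + 1] = (src0 >> 4)   | ((src1 >> 4) << 4);    // v(2j+1), v(2j+9)
  }
}

int main() {
  // Encode v_i = i for i = 0..15, two 4-bit values per byte, low nibble first:
  // | v0 v1 | v2 v3 | ... | vE vF |
  uint8_t src[8], dst[8];
  for (int i = 0; i < 8; ++i) src[i] = static_cast<uint8_t>((2 * i) | ((2 * i + 1) << 4));

  PackSubBlk16(src, dst);

  // After packing, byte k should hold (v_k, v_{k+8}): | v0 v8 | v1 v9 | ... | v7 vF |
  for (int k = 0; k < 8; ++k) {
    assert(dst[k] == ((k & 0x0F) | ((k + 8) << 4)));
    printf("dst[%d] = { v%d, v%d }\n", k, dst[k] & 0x0F, dst[k] >> 4);
  }
  return 0;
}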
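Appended note, not part of the patch: the kernels now test for packed zero points once, in a thin wrapper, and propagate the result as a HasZeroPoint template parameter, so the symmetric path compiles down to the implicit zero point of 8 with no per-block branch. A small sketch of that pattern, including the half-byte zero-point lookup; the function names here are stand-ins, not the MLAS kernel signatures.

#include <cstddef>
#include <cstdint>
#include <cstdio>

template <bool HasZeroPoint>
int8_t BlockZeroPoint(const uint8_t* packed_zero_points, size_t block_idx) {
  if constexpr (HasZeroPoint) {
    // Two 4-bit zero points per byte: even block index -> low nibble, odd -> high nibble.
    const uint8_t zp_packed = packed_zero_points[block_idx / 2];
    return (block_idx & 1) ? static_cast<int8_t>(zp_packed >> 4)
                           : static_cast<int8_t>(zp_packed & 0x0F);
  } else {
    return 8;  // symmetric quantization: implicit zero point
  }
}

int8_t BlockZeroPointDispatch(const uint8_t* packed_zero_points, size_t block_idx) {
  // Runtime null check done once; the hot loop would live inside the instantiation.
  return packed_zero_points != nullptr ? BlockZeroPoint<true>(packed_zero_points, block_idx)
                                       : BlockZeroPoint<false>(packed_zero_points, block_idx);
}

int main() {
  const uint8_t packed[2] = {0x3A, 0x07};  // zero points 10, 3, 7, 0 for blocks 0..3
  for (size_t b = 0; b < 4; ++b) {
    printf("block %zu: asymmetric zp = %d, symmetric zp = %d\n",
           b, BlockZeroPointDispatch(packed, b), BlockZeroPointDispatch(nullptr, b));
  }
  return 0;
}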
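Appended note, not part of the patch: for CompInt8 the A row is quantized per block by taking the block's absolute maximum, deriving scale = amax / 127, and rounding to nearest, which is what QuantizeARow_CompInt8 does with NEON intrinsics. A scalar sketch of one block follows; the real code stores the result as a per-block blob (scale followed by int8 data), which this sketch returns through separate outputs.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

void QuantizeBlockInt8(const float* a, size_t count, float& scale, int8_t* quant) {
  // Scan the block to find its absolute maximum.
  float amax = 0.0f;
  for (size_t i = 0; i < count; ++i) amax = std::max(amax, std::fabs(a[i]));

  scale = amax / 127.0f;
  const float inv_scale = (scale != 0.0f) ? 1.0f / scale : 0.0f;

  for (size_t i = 0; i < count; ++i) {
    // Round to nearest, like vcvtaq_s32_f32 in the NEON code.
    const float v = std::round(a[i] * inv_scale);
    quant[i] = static_cast<int8_t>(std::clamp(v, -128.0f, 127.0f));
  }
}

int main() {
  const std::vector<float> a = {0.5f, -1.25f, 2.0f, -0.125f};
  float scale = 0.0f;
  int8_t q[4] = {};
  QuantizeBlockInt8(a.data(), a.size(), scale, q);
  printf("scale = %f, quantized = %d %d %d %d\n", scale, q[0], q[1], q[2], q[3]);
  return 0;
}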