From ed9396bd8f8b5c8694040d772bffadbc852a711f Mon Sep 17 00:00:00 2001 From: Jing Fang <126209182+fajin-corp@users.noreply.github.com> Date: Thu, 14 Nov 2024 18:38:59 +0000 Subject: [PATCH] [ARM] MatMulNBits Fp16 support - API change only (#22826) ### Description A break-down PR of https://github.com/microsoft/onnxruntime/pull/22651 Op API change only. - add template to functions and classes that support fp32 and fp16 - rename functions, classes and files that support fp32 and fp16 from SQNBxxx to QNBxxx ### Motivation and Context --- cmake/onnxruntime_mlas.cmake | 4 +- .../cpu/quantization/matmul_nbits.cc | 159 +++++---- onnxruntime/core/mlas/inc/mlas_qnbit.h | 82 ++--- .../mlas/lib/hqnbitgemm_kernel_neon_fp16.cpp | 4 +- onnxruntime/core/mlas/lib/mlasi.h | 14 +- onnxruntime/core/mlas/lib/platform.cpp | 10 +- .../lib/{sqnbitgemm.cpp => qnbitgemm.cpp} | 313 ++++++++++++------ .../mlas/lib/{sqnbitgemm.h => qnbitgemm.h} | 45 +-- .../core/mlas/lib/sqnbitgemm_kernel_avx2.cpp | 30 +- .../sqnbitgemm_kernel_avx2_int8_blklen16.h | 2 +- .../sqnbitgemm_kernel_avx2_int8_blklen32.h | 2 +- .../sqnbitgemm_kernel_avx2_int8_blklen64.h | 2 +- .../mlas/lib/sqnbitgemm_kernel_avx512.cpp | 22 +- .../mlas/lib/sqnbitgemm_kernel_avx512_int8.h | 12 +- .../sqnbitgemm_kernel_avx512_int8_blklen128.h | 2 +- .../sqnbitgemm_kernel_avx512_int8_blklen16.h | 2 +- .../sqnbitgemm_kernel_avx512_int8_blklen32.h | 2 +- .../sqnbitgemm_kernel_avx512_int8_blklen64.h | 2 +- .../mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp | 18 +- .../mlas/lib/sqnbitgemm_kernel_avx_common.h | 24 +- .../lib/sqnbitgemm_kernel_avx_common_fp32.h | 2 +- .../lib/sqnbitgemm_kernel_avx_common_int8.h | 2 +- .../core/mlas/lib/sqnbitgemm_kernel_neon.cpp | 32 +- .../core/mlas/lib/sqnbitgemm_kernel_neon.h | 8 +- .../mlas/lib/sqnbitgemm_kernel_neon_fp32.cpp | 6 +- .../mlas/lib/sqnbitgemm_kernel_neon_int8.cpp | 6 +- ...bitgemm_m1_sym_kernel_avx2_int8_blklen32.h | 2 +- ...bitgemm_m1_sym_kernel_avx2_int8_blklen64.h | 2 +- .../vsinpu/patches/mlas_crosscompiling.patch | 56 ++-- .../test/mlas/bench/bench_sqnbitgemm.cpp | 42 +-- .../test/mlas/unittest/test_sqnbitgemm.cpp | 40 +-- 31 files changed, 547 insertions(+), 402 deletions(-) rename onnxruntime/core/mlas/lib/{sqnbitgemm.cpp => qnbitgemm.cpp} (70%) rename onnxruntime/core/mlas/lib/{sqnbitgemm.h => qnbitgemm.h} (91%) diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 1d578d09b1c03..a85ea942c42a3 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -36,8 +36,8 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/qpostprocessor.cpp ${MLAS_SRC_DIR}/qlgavgpool.cpp ${MLAS_SRC_DIR}/qdwconv_kernelsize.cpp - ${MLAS_SRC_DIR}/sqnbitgemm.h - ${MLAS_SRC_DIR}/sqnbitgemm.cpp + ${MLAS_SRC_DIR}/qnbitgemm.h + ${MLAS_SRC_DIR}/qnbitgemm.cpp ${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h ${MLAS_SRC_DIR}/flashattn.cpp ${MLAS_SRC_DIR}/cast.cpp diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 89e96543c4729..473ec51524f22 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -32,24 +32,46 @@ constexpr size_t A = 0, bias = 5; }; -int64_t GetAccuracyLevel(size_t nbits, size_t block_size, int64_t accuracy_level_attr) { - const auto accuracy_level = std::clamp(accuracy_level_attr, - static_cast(CompMostAccurate), - static_cast(CompLeastAccurate)); - - // Find a supported accuracy level that is not less accurate than the one given. - // CompMostAccurate is always supported with the fallback implementation. - // Note: A higher numeric accuracy level value means lower accuracy, so the comparison order is reversed. - int64_t effective_accuracy_level = accuracy_level; - for (; effective_accuracy_level > CompMostAccurate; --effective_accuracy_level) { - const auto compute_type = static_cast(effective_accuracy_level); - if (MlasIsSQNBitGemmAvailable(nbits, block_size, compute_type)) { - break; - } +typedef enum { + Level1, /*!< input fp32, accumulator fp32 */ + Level2, /*!< input fp16, accumulator fp16 */ + Level3, /*!< input bf16, accumulator fp32 */ + Level4, /*!< input int8, accumulator int32 */ +} ACCURACY_LEVEL; + +// T: A data type. +template +MLAS_QNBIT_GEMM_COMPUTE_TYPE +GetComputeType(size_t nbits, size_t block_size, int64_t accuracy_level_attr) { + // For Fp32, only accuracy level 1 or 4 makes sense. + // non-ARM CPU converts Fp16 to Fp32. + // By converting Fp32 to Fp16, precision becomes worse. And due to the casting, + // there is no performance gain. + if (accuracy_level_attr == static_cast(Level4) && + MlasIsQNBitGemmAvailable(nbits, block_size, SQNBIT_CompInt8)) { + return SQNBIT_CompInt8; } - return effective_accuracy_level; + return SQNBIT_CompFp32; } + +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) +template <> +MLAS_QNBIT_GEMM_COMPUTE_TYPE +GetComputeType(size_t nbits, size_t block_size, int64_t accuracy_level_attr) { + // For Fp16, only accuracy level 2 or 4 makes sense. + // By converting Fp16 to Fp32, there is not precision increase, and the performance + // becomes worse. + if (accuracy_level_attr == static_cast(Level4) && + MlasIsQNBitGemmAvailable(nbits, block_size, HQNBIT_CompInt8)) { + return HQNBIT_CompInt8; + } + + // if HQNBIT_CompFp16 is not supported, will fallback to unpacked computation. + return HQNBIT_CompFp16; +} +#endif // !MLAS_F16VEC_INTRINSICS_SUPPORTED || !MLAS_TARGET_ARM64 + } // namespace bool GetType(const NodeArg& node_arg, int32_t& type) { @@ -74,10 +96,9 @@ class MatMulNBits final : public OpKernel { N_{narrow(info.GetAttr("N"))}, block_size_{narrow(info.GetAttr("block_size"))}, nbits_{narrow(info.GetAttr("bits"))}, - accuracy_level_{GetAccuracyLevel(nbits_, block_size_, info.GetAttr("accuracy_level"))}, has_g_idx_{info.GetInputCount() > InputIndex::g_idx && info.node().InputDefs()[InputIndex::g_idx]->Exists()}, has_bias_{info.GetInputCount() > InputIndex::bias && info.node().InputDefs()[InputIndex::bias]->Exists()}, - compute_type_{static_cast(accuracy_level_)} { + compute_type_{GetComputeType(nbits_, block_size_, info.GetAttr("accuracy_level"))} { const auto& node = info.node(); auto input_defs = node.InputDefs(); const NodeArg* zero_point_arg = @@ -109,10 +130,9 @@ class MatMulNBits final : public OpKernel { const size_t N_; const size_t block_size_; const size_t nbits_; - const int64_t accuracy_level_; const bool has_g_idx_; const bool has_bias_; - const MLAS_SQNBIT_GEMM_COMPUTE_TYPE compute_type_; + const MLAS_QNBIT_GEMM_COMPUTE_TYPE compute_type_; bool has_unquantized_zero_point_{false}; const bool column_wise_quant_{true}; IAllocatorUniquePtr packed_b_{}; @@ -143,9 +163,7 @@ class MatMulNBits final : public OpKernel { Tensor* y, AllocatorPtr& allocator, concurrency::ThreadPool* thread_pool, - const MatMulComputeHelper& helper) const { - ORT_THROW("ComputeBPacked is not supported for T1 type."); - } + const MatMulComputeHelper& helper) const; }; template @@ -158,28 +176,28 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All return Status::OK(); } - if (!MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type_)) { + if (!MlasIsQNBitGemmAvailable(nbits_, block_size_, compute_type_)) { return Status::OK(); } if (input_idx == InputIndex::B) { - packed_b_size_ = MlasSQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_, compute_type_); + packed_b_size_ = MlasQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_, compute_type_); if (packed_b_size_ == 0) { return Status::OK(); } auto qptr = tensor.DataRaw(); packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size_, true); - MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, qptr, packed_b_.get(), nullptr, has_zp_input_, nullptr, nullptr); + MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, qptr, packed_b_.get(), nullptr, has_zp_input_, nullptr, nullptr); is_packed = true; - } else if (compute_type_ == CompInt8) { + } else if (compute_type_ == SQNBIT_CompInt8) { #ifdef MLAS_TARGET_AMD64_IX86 if (input_idx == InputIndex::scales && packed_b_ != nullptr) { auto sptr = tensor.Data(); - MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), sptr, - has_zp_input_, nullptr, nullptr); + MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), sptr, + has_zp_input_, nullptr, nullptr); is_packed = false; } else if (input_idx == InputIndex::zero_points && packed_b_ != nullptr) { auto zptr = tensor.Data(); - MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), nullptr, has_zp_input_, zptr, nullptr); + MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), nullptr, has_zp_input_, zptr, nullptr); is_packed = false; } #endif // MLAS_TARGET_AMD64_IX86 @@ -188,6 +206,8 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All return Status::OK(); } +#if !defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || !defined(MLAS_TARGET_ARM64) +// Non-ARM-with-fp16-intrinsics fall back fp16 to fp32. template <> Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, /*out*/ bool& is_packed, @@ -211,29 +231,29 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*ou return Status::OK(); } - if (!MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type_)) { + if (!MlasIsQNBitGemmAvailable(nbits_, block_size_, compute_type_)) { return Status::OK(); } if (input_idx == InputIndex::B) { - packed_b_size_ = MlasSQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_, compute_type_); + packed_b_size_ = MlasQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_, compute_type_); if (packed_b_size_ == 0) { return Status::OK(); } auto qptr = tensor.DataRaw(); packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size_, true); - MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, qptr, packed_b_.get(), - nullptr, has_zp_input_, nullptr, nullptr); + MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, qptr, packed_b_.get(), + nullptr, has_zp_input_, nullptr, nullptr); is_packed = true; - } else if (compute_type_ == CompInt8) { + } else if (compute_type_ == SQNBIT_CompInt8) { #ifdef MLAS_TARGET_AMD64_IX86 if (input_idx == InputIndex::scales && packed_b_ != nullptr) { - MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), - scales_fp32_.get(), has_zp_input_, nullptr, nullptr); + MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), + scales_fp32_.get(), has_zp_input_, nullptr, nullptr); is_packed = false; } else if (input_idx == InputIndex::zero_points && packed_b_ != nullptr) { auto zptr = tensor.Data(); - MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), - nullptr, has_zp_input_, zptr, nullptr); + MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), + nullptr, has_zp_input_, zptr, nullptr); is_packed = false; } #endif // MLAS_TARGET_AMD64_IX86 @@ -241,6 +261,7 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*ou return Status::OK(); } +#endif // end !MLAS_F16VEC_INTRINSICS_SUPPORTED || !MLAS_TARGET_ARM64 template Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, int input_idx, @@ -255,20 +276,20 @@ Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& return Status::OK(); } -template <> -Status MatMulNBits::ComputeBPacked(const Tensor* a, - const Tensor* scales, - const Tensor* zero_points, - const Tensor* bias, - Tensor* y, - AllocatorPtr& allocator, - concurrency::ThreadPool* thread_pool, - const MatMulComputeHelper& helper) const { - const auto* a_data = a->Data(); - const auto* scales_data = scales->Data(); +template +Status MatMulNBits::ComputeBPacked(const Tensor* a, + const Tensor* scales, + const Tensor* zero_points, + const Tensor* bias, + Tensor* y, + AllocatorPtr& allocator, + concurrency::ThreadPool* thread_pool, + const MatMulComputeHelper& helper) const { + const auto* a_data = a->Data(); + const auto* scales_data = scales->Data(); const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw(); - const auto* bias_data = bias == nullptr ? nullptr : bias->Data(); - auto* y_data = y->MutableData(); + const auto* bias_data = bias == nullptr ? nullptr : bias->Data(); + auto* y_data = y->MutableData(); const size_t batch_count = helper.OutputOffsets().size(); const size_t M = static_cast(helper.M()); @@ -277,19 +298,19 @@ Status MatMulNBits::ComputeBPacked(const Tensor* a, const size_t lda = helper.Lda(false); IAllocatorUniquePtr workspace{}; - const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize( + const size_t workspace_size = MlasQNBitGemmBatchWorkspaceSize( M, N, K, batch_count, nbits_, block_size_, compute_type_); if (workspace_size > 0) { // Use reserve since no caching is needed workspace = IAllocator::MakeUniquePtr(allocator, workspace_size, true); } - InlinedVector data(batch_count); + InlinedVector> data(batch_count); for (size_t i = 0; i < batch_count; ++i) { data[i].A = a_data + helper.LeftOffsets()[i]; data[i].lda = lda; #ifdef MLAS_TARGET_AMD64_IX86 - if (compute_type_ == CompInt8) { + if (compute_type_ == SQNBIT_CompInt8) { data[i].QuantBDataWorkspace = packed_b_.get(); } #endif @@ -300,11 +321,12 @@ Status MatMulNBits::ComputeBPacked(const Tensor* a, data[i].C = y_data + helper.OutputOffsets()[i]; data[i].ldc = N; } - MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type_, data.data(), workspace.get(), - thread_pool); + MlasQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type_, data.data(), workspace.get(), + thread_pool); return Status::OK(); } +#if !defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || !defined(MLAS_TARGET_ARM64) template <> Status MatMulNBits::ComputeBPacked(const Tensor* a, const Tensor* scales, @@ -327,7 +349,7 @@ Status MatMulNBits::ComputeBPacked(const Tensor* a, const size_t lda = helper.Lda(false); IAllocatorUniquePtr workspace{}; - const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize( + const size_t workspace_size = MlasQNBitGemmBatchWorkspaceSize( M, N, K, batch_count, nbits_, block_size_, compute_type_); if (workspace_size > 0) { // Use reserve since no caching is needed @@ -361,12 +383,12 @@ Status MatMulNBits::ComputeBPacked(const Tensor* a, size_t c_size = static_cast(y->Shape().Size()); std::vector c_v(c_size); - InlinedVector data(batch_count); + InlinedVector> data(batch_count); for (size_t i = 0; i < batch_count; ++i) { data[i].A = tmp_a_data_ptr.get() + helper.LeftOffsets()[i]; data[i].lda = lda; #ifdef MLAS_TARGET_AMD64_IX86 - if (compute_type_ == CompInt8) { + if (compute_type_ == SQNBIT_CompInt8) { data[i].QuantBDataWorkspace = packed_b_.get(); } #endif @@ -377,11 +399,12 @@ Status MatMulNBits::ComputeBPacked(const Tensor* a, data[i].C = c_v.data() + helper.OutputOffsets()[i]; data[i].ldc = N; } - MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type_, data.data(), workspace.get(), - thread_pool); + MlasQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type_, data.data(), workspace.get(), + thread_pool); MlasConvertFloatToHalfBuffer(c_v.data(), y_data, c_size); return Status::OK(); } +#endif // end of !MLAS_F16VEC_INTRINSICS_SUPPORTED || !MLAS_TARGET_AMD64 template <> Status MatMulNBits::ComputeBUnpacked(const Tensor* a, @@ -517,9 +540,10 @@ Status MatMulNBits::ComputeBUnpacked(const Tensor* a, const size_t ldb = helper.Ldb(true); float* scales_ptr = nullptr; + IAllocatorUniquePtr temp_scales; if (!scales_fp32_) { auto scales_size = static_cast(scales->Shape().Size()); - auto temp_scales = IAllocator::MakeUniquePtr(allocator, scales_size, true); + temp_scales = IAllocator::MakeUniquePtr(allocator, scales_size, true); MlasConvertHalfToFloatBuffer(scales_data, temp_scales.get(), scales_size); scales_ptr = temp_scales.get(); } else { @@ -600,8 +624,9 @@ Status MatMulNBits::ComputeBUnpacked(const Tensor* a, if (bias) { float* bias_ptr = nullptr; const size_t bias_size = static_cast(bias->Shape().Size()); + IAllocatorUniquePtr bias_temp; if (!bias_fp32_) { - auto bias_temp = IAllocator::MakeUniquePtr(allocator, bias_size, true); + bias_temp = IAllocator::MakeUniquePtr(allocator, bias_size, true); MlasConvertHalfToFloatBuffer(bias->Data(), bias_temp.get(), bias_size); bias_ptr = bias_temp.get(); } else { @@ -654,11 +679,11 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { // clang-format on if (has_single_b_matrix && - packed_b_) { // Assume that MlasSQNBitGemmBatch() always requires packed B. - // If this changes, i.e., if MlasIsSQNBitGemmAvailable() can return true while - // MlasSQNBitGemmPackQuantBDataSize() returns 0, we can consider calling MlasSQNBitGemmBatch() + packed_b_) { // Assume that MlasQNBitGemmBatch() always requires packed B. + // If this changes, i.e., if MlasIsQNBitGemmAvailable() can return true while + // MlasQNBitGemmPackQuantBDataSize() returns 0, we can consider calling MlasQNBitGemmBatch() // with B directly too. - if (MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type_)) { + if (MlasIsQNBitGemmAvailable(nbits_, block_size_, compute_type_)) { return ComputeBPacked(a, scales, zero_points, bias, y, allocator, thread_pool, helper); } } diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h index 232bf2261ef4c..9608644a22523 100644 --- a/onnxruntime/core/mlas/inc/mlas_qnbit.h +++ b/onnxruntime/core/mlas/inc/mlas_qnbit.h @@ -27,51 +27,50 @@ Module Name: * @brief Define compute types of block quantization, in order of decreasing accuracy. */ typedef enum { - CompUndef = 0, /*!< undef */ - CompFp32, /*!< input fp32, accumulator fp32 */ - CompFp16, /*!< input fp16, accumulator fp16 */ - CompBf16, /*!< input bf16, accumulator fp32 */ - CompInt8, /*!< input int8, accumulator int32 */ - - // special values that should be the first and last actual values - - CompMostAccurate = CompUndef, - CompLeastAccurate = CompInt8, -} MLAS_SQNBIT_GEMM_COMPUTE_TYPE; + SQNBIT_CompFp32, /*!< input fp32, accumulator fp32 */ + HQNBIT_CompFp16, /*!< input fp16, accumulator fp16 */ + BHQNBIT_CompBf16, /*!< input bf16, accumulator fp32 */ + SQNBIT_CompInt8, /*!< input int8, accumulator int32, input fp32 */ + HQNBIT_CompInt8, /*!< input int8, accumulator int32, input fp16 */ +} MLAS_QNBIT_GEMM_COMPUTE_TYPE; /** * @brief Data parameters for float/n-bit quantized int GEMM routine. + * + * @tparam T data type of input A */ -struct MLAS_SQNBIT_GEMM_DATA_PARAMS { - const float* A = nullptr; ///< address of A (float32 matrix) +template +struct MLAS_QNBIT_GEMM_DATA_PARAMS { + const T* A = nullptr; ///< address of A (float32/16 matrix) size_t lda = 0; ///< leading dimension of A const void* QuantBDataWorkspace; ///< address of quantized B (quantized n-bit int values) const std::byte* PackedQuantBData = nullptr; /// address of packed quantized B data - const float* QuantBScale = nullptr; ///< address of scale values of quantized B, one per block + const T* QuantBScale = nullptr; ///< address of scale values of quantized B, one per block const void* QuantBZeroPoint = nullptr; ///< optional address of zero point values of quantized B, one per block - const float* QuantBBlkSum = nullptr; ///< optional address of scale * zp, one per block - const float* Bias = nullptr; ///< optional address of Bias, vector size N - float* C = nullptr; ///< address of result matrix + const T* QuantBBlkSum = nullptr; ///< optional address of scale * zp, one per block + const T* Bias = nullptr; ///< optional address of Bias, vector size N + T* C = nullptr; ///< address of result matrix size_t ldc = 0; ///< leading dimension of C ///< optional post processing to apply to result matrix - MLAS_GEMM_POSTPROCESSOR* PostProcessor = nullptr; + MLAS_GEMM_POSTPROCESSOR* PostProcessor = nullptr; }; /** * @brief Batched GEMM: C = A * B + Bias - * A must be a float32 matrix + * A must be a float32/16 matrix * B must be a quantized and packed n-bit int matrix * - * Call MlasIsSQNBitGemmAvailable() with the same parameters to determine whether this function may be called. + * Call MlasIsQNBitGemmAvailable() with the same parameters to determine whether this function may be called. * - * Call MlasSQNBitGemmPackQuantBDataSize() with the same parameters to determine whether - * MLAS_SQNBIT_GEMM_DATA_PARAMS::QuantBData in `DataParams` should point to a buffer packed with - * MlasSQNBitGemmPackQuantBData(). + * Call MlasQNBitGemmPackQuantBDataSize() with the same parameters to determine whether + * MLAS_QNBIT_GEMM_DATA_PARAMS::QuantBData in `DataParams` should point to a buffer packed with + * MlasQNBitGemmPackQuantBData(). * - * Call MlasSQNBitGemmBatchWorkspaceSize() with the same parameters to determine whether `Workspace` should + * Call MlasQNBitGemmBatchWorkspaceSize() with the same parameters to determine whether `Workspace` should * point to an intermediate workspace buffer. * + * @tparam T data type of input A * @param[in] M row size of matrix A and C * @param[in] N column size of matrix B and C * @param[in] K column size of matrix A and row size of matrix B @@ -81,36 +80,37 @@ struct MLAS_SQNBIT_GEMM_DATA_PARAMS { * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) * @param[inout] DataParams An array (size BatchN) of parameter blocks * @param[in] Workspace Address of intermediate workspace buffer. - If MlasSQNBitGemmBatchWorkspaceSize() returns a non-zero value, this must be a + If MlasQNBitGemmBatchWorkspaceSize() returns a non-zero value, this must be a buffer with at least that many bytes. Otherwise, it may be nullptr. * @param[in] ThreadPool optional thread pool to use */ +template void MLASCALL -MlasSQNBitGemmBatch( +MlasQNBitGemmBatch( size_t M, size_t N, size_t K, size_t BatchN, size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, - const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const MLAS_QNBIT_GEMM_DATA_PARAMS* DataParams, void* Workspace, MLAS_THREADPOOL* ThreadPool = nullptr ); /** - * @brief Determines whether a float32/quantized n-bit int GEMM implementation is available on the current platform. + * @brief Determines whether a float32/16 quantized n-bit int GEMM implementation is available on the current platform. * * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) * @param[in] BlkLen number of quantized values per block * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) */ bool MLASCALL -MlasIsSQNBitGemmAvailable( +MlasIsQNBitGemmAvailable( size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ); /** @@ -126,22 +126,22 @@ MlasIsSQNBitGemmAvailable( * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) */ size_t MLASCALL -MlasSQNBitGemmBatchWorkspaceSize( +MlasQNBitGemmBatchWorkspaceSize( size_t M, size_t N, size_t K, size_t BatchN, size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ); /** * @brief Gets the size in bytes of the packed quantized B data. - * If non-zero, the quantized B data must first be packed by calling MlasSQNBitGemmPackQuantBData() with a buffer of - * this size, and then that packed quantized B data buffer must be passed to MlasSQNBitGemmBatch(). - * If zero, MlasSQNBitGemmPackQuantBData() must not be called and the quantized B data must be directly passed to - * MlasSQNBitGemmBatch(). + * If non-zero, the quantized B data must first be packed by calling MlasQNBitGemmPackQuantBData() with a buffer of + * this size, and then that packed quantized B data buffer must be passed to MlasQNBitGemmBatch(). + * If zero, MlasQNBitGemmPackQuantBData() must not be called and the quantized B data must be directly passed to + * MlasQNBitGemmBatch(). * * @param[in] N column size of matrix B and C * @param[in] K column size of matrix A and row size of matrix B @@ -150,12 +150,12 @@ MlasSQNBitGemmBatchWorkspaceSize( * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) */ size_t MLASCALL -MlasSQNBitGemmPackQuantBDataSize( +MlasQNBitGemmPackQuantBDataSize( size_t N, size_t K, size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ); /** @@ -186,12 +186,12 @@ MlasSQNBitGemmPackQuantBDataSize( * @param[in] ThreadPool thread pool to use (no parallel if nullptr) */ void MLASCALL -MlasSQNBitGemmPackQuantBData( +MlasQNBitGemmPackQuantBData( size_t N, size_t K, size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, const void* QuantBData, void* PackedQuantBDataAndOrBlkSum, const void* QuantBScale, diff --git a/onnxruntime/core/mlas/lib/hqnbitgemm_kernel_neon_fp16.cpp b/onnxruntime/core/mlas/lib/hqnbitgemm_kernel_neon_fp16.cpp index d7d99899d544a..f1bc013a469d9 100644 --- a/onnxruntime/core/mlas/lib/hqnbitgemm_kernel_neon_fp16.cpp +++ b/onnxruntime/core/mlas/lib/hqnbitgemm_kernel_neon_fp16.cpp @@ -23,7 +23,7 @@ Module Name: #include #include "fp16_common.h" -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_neon.h" namespace sqnbitgemm_neon @@ -131,7 +131,7 @@ HQ4BitGemmPackQuantBData_CompFp16( size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, const std::byte* QuantBDataBegin, std::byte* PackedQuantBDataBegin, MLAS_THREADPOOL* ThreadPool diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 13ea8d96c20e4..9bc574a845a3e 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -1017,17 +1017,17 @@ extern const MLAS_FPQ4GEMM_DISPATCH MlasFpQ4GemmDispatchAvx512; // Float/quantized n-bit integer matrix/matrix multiply dispatch structure. // -struct MLAS_SQNBIT_GEMM_DISPATCH; +struct MLAS_QNBIT_GEMM_DISPATCH; -extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon; +extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon; -extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2; +extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2; -extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni; +extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni; -extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512; +extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512; -extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni; +extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni; // // Quantized depthwise convolution kernels. @@ -1184,7 +1184,7 @@ struct MLAS_PLATFORM { const MLAS_FPQ4GEMM_DISPATCH* FpQ4GemmDispatch{nullptr}; const MLAS_Q8Q4GEMM_DISPATCH* Q8Q4GemmDispatch{nullptr}; - const MLAS_SQNBIT_GEMM_DISPATCH* SQNBitGemmDispatch{nullptr}; + const MLAS_QNBIT_GEMM_DISPATCH* QNBitGemmDispatch{nullptr}; MLAS_CAST_F16_TO_F32_KERNEL* CastF16ToF32Kernel; MLAS_CAST_F32_TO_F16_KERNEL* CastF32ToF16Kernel; diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 23d29fd02fa5a..12f7dd3e74dbc 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -387,7 +387,7 @@ Return Value: this->ConvDepthwiseS8S8Kernel = MlasConvDepthwiseKernelAvx2; this->ConvDepthwiseS8U8Kernel = MlasConvDepthwiseKernelAvx2; this->ComputeSumExpF32Kernel = MlasComputeSumExpF32KernelFma3; - this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx2; + this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx2; this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelAvx2; this->CastF32ToF16Kernel = &MlasCastF32ToF16KernelAvx2; @@ -417,7 +417,7 @@ Return Value: this->GemmU8S8Kernel = MlasGemmU8S8KernelAvxVnni; this->GemvU8S8Kernel = MlasGemvU8S8KernelAvxVnni; this->ConvSymU8S8Dispatch = &MlasConvSymDispatchAvxVnni; - this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx2vnni; + this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx2vnni; } #if !defined(ORT_MINIMAL_BUILD) @@ -458,7 +458,7 @@ Return Value: this->GemmU8U8Kernel = MlasGemmU8U8KernelAvx512Core; this->ConvSymU8S8Dispatch = &MlasConvSymDispatchAvx512Core; this->FpQ4GemmDispatch = &MlasFpQ4GemmDispatchAvx512; - this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx512; + this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx512; // // Check if the processor supports AVX512VNNI. @@ -471,7 +471,7 @@ Return Value: this->GemvU8S8Kernel = MlasGemvU8S8KernelAvx512Vnni; this->ConvSymU8S8Dispatch = &MlasConvSymDispatchAvx512Vnni; this->Q8Q4GemmDispatch = &MlasQ8Q4GemmDispatchAvx512vnni; - this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx512vnni; + this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx512vnni; } } } @@ -562,7 +562,7 @@ Return Value: this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchDot; // MlasSQNBitGemmDispatchNeon has a dependency on dot product instructions - this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon; + this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon; } #if defined(__linux__) diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/qnbitgemm.cpp similarity index 70% rename from onnxruntime/core/mlas/lib/sqnbitgemm.cpp rename to onnxruntime/core/mlas/lib/qnbitgemm.cpp index a45494ef2e04f..635a3b47a23fa 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/qnbitgemm.cpp @@ -6,16 +6,16 @@ Licensed under the MIT License. Module Name: - sqnbitgemm.cpp + qnbitgemm.cpp Abstract: This module implements the float/quantized n-bit integer matrix - multiplication hardware agnostic entrypoint, MlasSQNBitGemmBatch, + multiplication hardware agnostic entrypoint, MlasQNBitGemmBatch, as well as some SQNBitGemm-related query functions. --*/ -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_q8_block.h" #include @@ -23,35 +23,40 @@ Module Name: namespace { -enum SQNBitGemmVariant { +enum QNBitGemmVariant { SQNBitGemmVariantInvalid = -1, // Valid variants SQNBitGemmVariant_BitWidth4_CompFp32 = 0, SQNBitGemmVariant_BitWidth4_CompInt8, + HQNBitGemmVariant_BitWidth4_CompFp16, + HQNBitGemmVariant_BitWidth4_CompInt8, // End of valid variants - // Keep this element last and ensure that its value is the number of valid SQNBitGemmVariant values. + // Keep this element last and ensure that its value is the number of valid QNBitGemmVariant values. // Its value is used as an array size. SQNBitGemmVariantCount, }; -SQNBitGemmVariant -GetSQNBitGemmVariant( +QNBitGemmVariant +GetQNBitGemmVariant( size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { if (BlkBitWidth == 4 && (BlkLen == 16 || BlkLen == 32 || BlkLen == 64 || BlkLen == 128 || BlkLen == 256)) { - if (ComputeType == CompFp32 || - ComputeType == CompUndef) { // treat CompUndef (undefined) as CompFp32 + if (ComputeType == SQNBIT_CompFp32) { return SQNBitGemmVariant_BitWidth4_CompFp32; - } else if (ComputeType == CompInt8) { + } else if (ComputeType == HQNBIT_CompFp16) { + return HQNBitGemmVariant_BitWidth4_CompFp16; + } else if (ComputeType == SQNBIT_CompInt8) { return SQNBitGemmVariant_BitWidth4_CompInt8; + } else if (ComputeType == HQNBIT_CompInt8) { + return HQNBitGemmVariant_BitWidth4_CompInt8; } } @@ -61,18 +66,18 @@ GetSQNBitGemmVariant( } // namespace bool MLASCALL -MlasIsSQNBitGemmAvailable( +MlasIsQNBitGemmAvailable( size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch; + const auto* Dispatch = GetMlasPlatform().QNBitGemmDispatch; if (Dispatch == nullptr) { return false; } - const auto Variant = GetSQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType); + const auto Variant = GetQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType); switch (Variant) { case SQNBitGemmVariant_BitWidth4_CompFp32: { @@ -94,80 +99,80 @@ namespace { size_t -SQNBitGemmPerGemmWorkspaceSize( +QNBitGemmPerGemmWorkspaceSize( size_t M, size_t N, size_t K, size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch; + const auto* Dispatch = GetMlasPlatform().QNBitGemmDispatch; if (Dispatch == nullptr) { return 0; } - if (BlkBitWidth == 4 && Dispatch->SQ4BitGemmPerGemmWorkspaceSize != nullptr) { - return Dispatch->SQ4BitGemmPerGemmWorkspaceSize(M, N, K, BlkLen, ComputeType); + if (BlkBitWidth == 4 && Dispatch->Q4BitGemmPerGemmWorkspaceSize != nullptr) { + return Dispatch->Q4BitGemmPerGemmWorkspaceSize(M, N, K, BlkLen, ComputeType); } return 0; } size_t -SQNBitGemmPerGemmWorkspaceAlignment( +QNBitGemmPerGemmWorkspaceAlignment( size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch; + const auto* Dispatch = GetMlasPlatform().QNBitGemmDispatch; if (Dispatch == nullptr) { return 1; } - if (BlkBitWidth == 4 && Dispatch->SQ4BitGemmPerGemmWorkspaceAlignment != nullptr) { - return Dispatch->SQ4BitGemmPerGemmWorkspaceAlignment(BlkLen, ComputeType); + if (BlkBitWidth == 4 && Dispatch->Q4BitGemmPerGemmWorkspaceAlignment != nullptr) { + return Dispatch->Q4BitGemmPerGemmWorkspaceAlignment(BlkLen, ComputeType); } return 1; } size_t -SQNBitGemmPerGemmWorkspaceStride( +QNBitGemmPerGemmWorkspaceStride( size_t M, size_t N, size_t K, size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - const auto Size = SQNBitGemmPerGemmWorkspaceSize(M, N, K, BlkBitWidth, BlkLen, ComputeType); - const auto Alignment = SQNBitGemmPerGemmWorkspaceAlignment(BlkBitWidth, BlkLen, ComputeType); + const auto Size = QNBitGemmPerGemmWorkspaceSize(M, N, K, BlkBitWidth, BlkLen, ComputeType); + const auto Alignment = QNBitGemmPerGemmWorkspaceAlignment(BlkBitWidth, BlkLen, ComputeType); return MlasDivRoundup(Size, Alignment) * Alignment; } } // namespace size_t MLASCALL -MlasSQNBitGemmBatchWorkspaceSize( +MlasQNBitGemmBatchWorkspaceSize( size_t M, size_t N, size_t K, size_t BatchN, size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - const size_t PerGemmWorkspaceStride = SQNBitGemmPerGemmWorkspaceStride(M, N, K, BlkBitWidth, BlkLen, ComputeType); + const size_t PerGemmWorkspaceStride = QNBitGemmPerGemmWorkspaceStride(M, N, K, BlkBitWidth, BlkLen, ComputeType); if (PerGemmWorkspaceStride == 0) { return 0; } - const size_t Alignment = SQNBitGemmPerGemmWorkspaceAlignment(BlkBitWidth, BlkLen, ComputeType); + const size_t Alignment = QNBitGemmPerGemmWorkspaceAlignment(BlkBitWidth, BlkLen, ComputeType); const size_t WorkspaceSize = BatchN * PerGemmWorkspaceStride; @@ -175,21 +180,21 @@ MlasSQNBitGemmBatchWorkspaceSize( } size_t MLASCALL -MlasSQNBitGemmPackQuantBDataSize( +MlasQNBitGemmPackQuantBDataSize( size_t N, size_t K, size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch; + const auto* Dispatch = GetMlasPlatform().QNBitGemmDispatch; if (Dispatch == nullptr) { return 0; } - if (BlkBitWidth == 4 && Dispatch->SQ4BitGemmPackQuantBDataSize != nullptr) { - return Dispatch->SQ4BitGemmPackQuantBDataSize( + if (BlkBitWidth == 4 && Dispatch->Q4BitGemmPackQuantBDataSize != nullptr) { + return Dispatch->Q4BitGemmPackQuantBDataSize( N, K, BlkLen, ComputeType ); } @@ -213,12 +218,12 @@ struct PerGemmQuantAWorkspace { }; void MLASCALL -MlasSQNBitGemmPackQuantBData( +MlasQNBitGemmPackQuantBData( size_t N, size_t K, size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, const void* QuantBData, void* PackedQuantBDataAndOrBlkSumWorkspace, const void* QuantBScale, @@ -227,15 +232,15 @@ MlasSQNBitGemmPackQuantBData( MLAS_THREADPOOL* ThreadPool ) { - const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch; + const auto* Dispatch = GetMlasPlatform().QNBitGemmDispatch; if (Dispatch == nullptr) { return; } if (BlkBitWidth == 4) { - if (ComputeType == CompInt8 && Dispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { + if (ComputeType == SQNBIT_CompInt8 && Dispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - PackedQuantBDataStruct packed_quant_b(PackedQuantBDataAndOrBlkSumWorkspace, N, BlockCountK, BlkLen); + PackedQuantBDataStruct packed_quant_b(PackedQuantBDataAndOrBlkSumWorkspace, N, BlockCountK, BlkLen); Dispatch->SQ4BitGemmPackQuantBDataAndBlkSum( N, K, @@ -295,22 +300,11 @@ AddBiasForGemm(const float* Bias, float* C, size_t CountM, size_t CountN, size_t } } -typedef void(SQNBitGemmFn)( - size_t BlkLen, - size_t K, - const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, - void* PerGemmWorkspace, - size_t RangeStartM, - size_t RangeCountM, - size_t RangeStartN, - size_t RangeCountN -); - void SQ4BitGemm_CompFp32( const size_t BlkLen, const size_t K, - const MLAS_SQNBIT_GEMM_DATA_PARAMS* const DataParams, + const MLAS_QNBIT_GEMM_DATA_PARAMS* const DataParams, void* const PerGemmWorkspace, const size_t RangeStartM, const size_t RangeCountM, @@ -355,7 +349,7 @@ SQ4BitGemm_CompFp32( float* c_blk = C + n; const float* bias = (Bias == nullptr) ? nullptr : Bias + n; - GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompFp32( + GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmM1Kernel_CompFp32( BlkLen, a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias ); @@ -393,7 +387,7 @@ SQ4BitGemm_CompFp32( float* c_blk = C + n; const float* bias = (Bias == nullptr) ? nullptr : Bias + n; - GetMlasPlatform().SQNBitGemmDispatch->Q4BitBlkDequantBForSgemm_CompFp32( + GetMlasPlatform().QNBitGemmDispatch->Q4BitBlkDequantBForSgemm_CompFp32( BlkLen, dequant_b, b_col, b_col_scale, b_col_zp, CountN, K, k_blks ); @@ -429,7 +423,7 @@ void SQ4BitGemm_CompInt8( const size_t BlkLen, const size_t K, - const MLAS_SQNBIT_GEMM_DATA_PARAMS* const DataParams, + const MLAS_QNBIT_GEMM_DATA_PARAMS* const DataParams, void* const PerGemmWorkspace, const size_t RangeStartM, const size_t RangeCountM, @@ -500,10 +494,10 @@ SQ4BitGemm_CompInt8( float* c_blk = C + n; const float* bias = (Bias == nullptr) ? nullptr : Bias + n; - if (GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8 != nullptr) { + if (GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmKernel_CompInt8 != nullptr) { size_t RowsRemaining = RangeCountM; while (RowsRemaining > 0) { - const auto RowsHandled = GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8( + const auto RowsHandled = GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmKernel_CompInt8( BlkLen, a_row, b_col, b_col_scale, b_col_zp, c_blk, RowsRemaining, CountN, K, k_blks, ldc, bias ); @@ -522,10 +516,10 @@ SQ4BitGemm_CompInt8( } } #ifdef MLAS_TARGET_AMD64_IX86 - else if (GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_BlkSum_CompInt8 != nullptr) + else if (GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmKernel_BlkSum_CompInt8 != nullptr) { const float* b_blk_sum = QuantBBlkSum + n * k_blks; - GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_BlkSum_CompInt8( + GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmKernel_BlkSum_CompInt8( BlkLen, QuantA, QuantAScale, @@ -554,26 +548,29 @@ SQ4BitGemm_CompInt8( } } -typedef void(InitializeWorkspaceFn)( +template +void +InitializeWorkspace_CompInt8( size_t M, size_t N, size_t K, size_t BatchN, size_t BlkLen, - const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, + const MLAS_QNBIT_GEMM_DATA_PARAMS* DataParams, void* Workspace, size_t PerGemmWorkspaceStride, MLAS_THREADPOOL* ThreadPool ); +template <> void -InitializeWorkspace_CompInt8( +InitializeWorkspace_CompInt8( size_t M, size_t N, size_t K, size_t BatchN, size_t BlkLen, - const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, + const MLAS_QNBIT_GEMM_DATA_PARAMS* DataParams, void* Workspace, size_t PerGemmWorkspaceStride, MLAS_THREADPOOL* ThreadPool @@ -581,8 +578,8 @@ InitializeWorkspace_CompInt8( { MLAS_UNREFERENCED_PARAMETER(N); - const auto QuantizeARow = GetMlasPlatform().SQNBitGemmDispatch->QuantizeARow_CompInt8; - const auto QuantizeARow2 = GetMlasPlatform().SQNBitGemmDispatch->QuantizeARowComputeBlkSum_CompInt8; + const auto QuantizeARow = GetMlasPlatform().QNBitGemmDispatch->QuantizeARow_CompInt8; + const auto QuantizeARow2 = GetMlasPlatform().QNBitGemmDispatch->QuantizeARowComputeBlkSum_CompInt8; const size_t BlockCountK = MlasDivRoundup(K, BlkLen); const size_t QuantAStride = BlockCountK * Q8BlkSize(BlkLen); @@ -622,61 +619,153 @@ InitializeWorkspace_CompInt8( } } -struct Operations { - InitializeWorkspaceFn* InitializeWorkspace = nullptr; - SQNBitGemmFn* SQNBitGemm = nullptr; -}; +template <> +void +InitializeWorkspace_CompInt8( + size_t M, + size_t N, + size_t K, + size_t BatchN, + size_t BlkLen, + const MLAS_QNBIT_GEMM_DATA_PARAMS* DataParams, + void* Workspace, + size_t PerGemmWorkspaceStride, + MLAS_THREADPOOL* ThreadPool +) { + MLAS_UNREFERENCED_PARAMETER(M); + MLAS_UNREFERENCED_PARAMETER(N); + MLAS_UNREFERENCED_PARAMETER(K); + MLAS_UNREFERENCED_PARAMETER(BatchN); + MLAS_UNREFERENCED_PARAMETER(BlkLen); + MLAS_UNREFERENCED_PARAMETER(DataParams); + MLAS_UNREFERENCED_PARAMETER(Workspace); + MLAS_UNREFERENCED_PARAMETER(PerGemmWorkspaceStride); + MLAS_UNREFERENCED_PARAMETER(ThreadPool); +} + +template +using InitializeWorkspaceFn = std::function* DataParams, + void* Workspace, + size_t PerGemmWorkspaceStride, + MLAS_THREADPOOL* ThreadPool +)>; + +template +InitializeWorkspaceFn +GetInitializeWorkspace(QNBitGemmVariant variant); + +template <> +InitializeWorkspaceFn +GetInitializeWorkspace(QNBitGemmVariant variant) +{ + switch (variant) { + case SQNBitGemmVariant_BitWidth4_CompInt8: + return InitializeWorkspace_CompInt8; + default: + return nullptr; + } +} + +template <> +InitializeWorkspaceFn +GetInitializeWorkspace(QNBitGemmVariant variant) +{ + switch (variant) { + case HQNBitGemmVariant_BitWidth4_CompInt8: + return InitializeWorkspace_CompInt8; + default: + return nullptr; + } +} -constexpr auto OperationMap = []() { - std::array ops; +template +using QNBitGemmFn = std::function* const DataParams, + void* const PerGemmWorkspace, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN +)>; - ops[SQNBitGemmVariant_BitWidth4_CompFp32].SQNBitGemm = SQ4BitGemm_CompFp32; +template +QNBitGemmFn +GetQNBitGemm(QNBitGemmVariant variant); - ops[SQNBitGemmVariant_BitWidth4_CompInt8].InitializeWorkspace = InitializeWorkspace_CompInt8; - ops[SQNBitGemmVariant_BitWidth4_CompInt8].SQNBitGemm = SQ4BitGemm_CompInt8; +template <> +QNBitGemmFn +GetQNBitGemm(QNBitGemmVariant variant) +{ + switch (variant) { + case SQNBitGemmVariant_BitWidth4_CompFp32: + return SQ4BitGemm_CompFp32; + case SQNBitGemmVariant_BitWidth4_CompInt8: + return SQ4BitGemm_CompInt8; + default: + return nullptr; + } +} - return ops; -}(); +template <> +QNBitGemmFn +GetQNBitGemm(QNBitGemmVariant variant) +{ + switch (variant) { + case HQNBitGemmVariant_BitWidth4_CompFp16: + return nullptr; + default: + return nullptr; + } +} } // namespace +template void MLASCALL -MlasSQNBitGemmBatch( +MlasQNBitGemmBatch( const size_t M, const size_t N, const size_t K, const size_t BatchN, const size_t BlkBitWidth, const size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, - const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const MLAS_QNBIT_GEMM_DATA_PARAMS* DataParams, void* Workspace, MLAS_THREADPOOL* ThreadPool ) { - const auto Variant = GetSQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType); + const auto Variant = GetQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType); assert(Variant != SQNBitGemmVariantInvalid); // // Ensure `Workspace` has correct alignment. // if (Workspace != nullptr) { - const size_t Alignment = SQNBitGemmPerGemmWorkspaceAlignment(BlkBitWidth, BlkLen, ComputeType); + const size_t Alignment = QNBitGemmPerGemmWorkspaceAlignment(BlkBitWidth, BlkLen, ComputeType); const uintptr_t WorkspaceAddress = reinterpret_cast(Workspace); Workspace = reinterpret_cast( (WorkspaceAddress + Alignment - 1) & (~(Alignment - 1)) ); } - const size_t PerGemmWorkspaceStride = SQNBitGemmPerGemmWorkspaceStride(M, N, K, BlkBitWidth, BlkLen, ComputeType); + const size_t PerGemmWorkspaceStride = QNBitGemmPerGemmWorkspaceStride(M, N, K, BlkBitWidth, BlkLen, ComputeType); - if (const auto InitializeWorkspaceOperation = OperationMap[Variant].InitializeWorkspace; + if (const auto InitializeWorkspaceOperation = GetInitializeWorkspace(Variant); InitializeWorkspaceOperation != nullptr) { InitializeWorkspaceOperation( M, N, K, BatchN, BlkLen, DataParams, Workspace, PerGemmWorkspaceStride, ThreadPool ); } - const auto ComputeOperation = OperationMap[Variant].SQNBitGemm; + const auto ComputeOperation = GetQNBitGemm(Variant); const size_t BlockCountK = MlasDivRoundup(K, BlkLen); @@ -685,11 +774,11 @@ MlasSQNBitGemmBatch( const auto* Data = &DataParams[gemm_i]; void* PerGemmWorkspace = reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride; - if (ComputeType == CompInt8 && GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { - PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); - const_cast(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; - const_cast(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; - const_cast(Data)->QuantBScale = packed_quant_b.PackedQuantBScale; + if (ComputeType == SQNBIT_CompInt8 && GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { + PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); + const_cast*>(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; + const_cast*>(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; + const_cast*>(Data)->QuantBScale = packed_quant_b.PackedQuantBScale; PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen); ComputeOperation(BlkLen, K, Data, &per_gemm_quant_a_workspace, 0, M, 0, N); } else { @@ -756,11 +845,11 @@ MlasSQNBitGemmBatch( void* PerGemmWorkspace = reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride; - if (ComputeType == CompInt8 && GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { - PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); - const_cast(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; - const_cast(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; - const_cast(Data)->QuantBScale = packed_quant_b.PackedQuantBScale; + if (ComputeType == SQNBIT_CompInt8 && GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { + PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); + const_cast*>(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; + const_cast*>(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; + const_cast*>(Data)->QuantBScale = packed_quant_b.PackedQuantBScale; PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen); ComputeOperation(BlkLen, K, Data, &per_gemm_quant_a_workspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN); @@ -769,3 +858,33 @@ MlasSQNBitGemmBatch( } }); } + +template +void MLASCALL +MlasQNBitGemmBatch( + const size_t M, + const size_t N, + const size_t K, + const size_t BatchN, + const size_t BlkBitWidth, + const size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const MLAS_QNBIT_GEMM_DATA_PARAMS* DataParams, + void* Workspace, + MLAS_THREADPOOL* ThreadPool +); + +template +void MLASCALL +MlasQNBitGemmBatch( + const size_t M, + const size_t N, + const size_t K, + const size_t BatchN, + const size_t BlkBitWidth, + const size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const MLAS_QNBIT_GEMM_DATA_PARAMS* DataParams, + void* Workspace, + MLAS_THREADPOOL* ThreadPool +); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.h b/onnxruntime/core/mlas/lib/qnbitgemm.h similarity index 91% rename from onnxruntime/core/mlas/lib/sqnbitgemm.h rename to onnxruntime/core/mlas/lib/qnbitgemm.h index 2da336ca2f0ec..28e17f14b02c9 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.h +++ b/onnxruntime/core/mlas/lib/qnbitgemm.h @@ -6,7 +6,7 @@ Licensed under the MIT License. Module Name: - sqnbitgemm.h + qnbitgemm.h Abstract: @@ -46,24 +46,25 @@ MlasAlignAddress(void* addr, const size_t alignment) return addr; } +template struct PackedQuantBDataStruct { PackedQuantBDataStruct(void* PackedQuantBWorkspace, size_t N, size_t BlockCountK, size_t BlkLen) : QuantBWorkspace_(PackedQuantBWorkspace), N_(N), BlockCountK_(BlockCountK), BlkLen_(BlkLen) { - // TODO: duplicate code from SQ4BitGemmPackQuantBDataSize + // TODO: duplicate code from Q4BitGemmPackQuantBDataSize constexpr size_t BlkBitWidth = 4; const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - size_t BlkSumSize = MlasDivRoundup(N, 16) * BlockCountK * 16 * sizeof(float); + size_t BlkSumSize = MlasDivRoundup(N, 16) * BlockCountK * 16 * sizeof(T); // _mm256_load_si256 requires alignment on a 32-byte boundary PackedQuantBData = (std::byte*)MlasAlignAddress(PackedQuantBWorkspace, 32); - QuantBBlkSum = (float*)(PackedQuantBData + PackedQuantBDataSize); - QuantBBlkSum = (float*)MlasAlignAddress(QuantBBlkSum, MlasQNBitQuantBBlkSumAlignment()); - PackedQuantBScale = (float*)((std::byte*)QuantBBlkSum + BlkSumSize); + QuantBBlkSum = (T*)(PackedQuantBData + PackedQuantBDataSize); + QuantBBlkSum = (T*)MlasAlignAddress(QuantBBlkSum, MlasQNBitQuantBBlkSumAlignment()); + PackedQuantBScale = (T*)((std::byte*)QuantBBlkSum + BlkSumSize); } std::byte* PackedQuantBData; - float* PackedQuantBScale; - float* QuantBBlkSum; + T* PackedQuantBScale; + T* QuantBBlkSum; void* QuantBWorkspace_; size_t N_, BlockCountK_, BlkLen_; @@ -84,27 +85,27 @@ MlasQNBitZeroPointsForBlksSizeInBytes(size_t BlkCount) // Kernel dispatch structure. // -struct MLAS_SQNBIT_GEMM_DISPATCH { +struct MLAS_QNBIT_GEMM_DISPATCH { // // Quantized B data packing function prototypes. // - /** Gets size of packed quantized B data containing 4-bit integers. See MlasSQNBitGemmPackQuantBDataSize(). */ + /** Gets size of packed quantized B data containing 4-bit integers. See MlasQNBitGemmPackQuantBDataSize(). */ typedef size_t(SQ4BitGemmPackQuantBDataSize_Fn)( size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ); - SQ4BitGemmPackQuantBDataSize_Fn* SQ4BitGemmPackQuantBDataSize = nullptr; + SQ4BitGemmPackQuantBDataSize_Fn* Q4BitGemmPackQuantBDataSize = nullptr; - /** Packs quantized B data containing 4-bit integers. See MlasSQNBitGemmPackQuantBData(). */ + /** Packs quantized B data containing 4-bit integers. See MlasQNBitGemmPackQuantBData(). */ typedef void(SQ4BitGemmPackQuantBData_Fn)( size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, const std::byte* QuantBDataBegin, std::byte* PackedQuantBDataBegin, MLAS_THREADPOOL* ThreadPool @@ -116,12 +117,12 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, const std::byte* QuantBDataBegin, const float* QuantBScaleBegin, bool has_zp_input, const std::byte* QuantBZPBegin, - PackedQuantBDataStruct& packed_quant_b, + PackedQuantBDataStruct& packed_quant_b, MLAS_THREADPOOL* ThreadPool ); @@ -146,10 +147,10 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ); - SQ4BitGemmPerGemmWorkspaceSize_Fn* SQ4BitGemmPerGemmWorkspaceSize = nullptr; + SQ4BitGemmPerGemmWorkspaceSize_Fn* Q4BitGemmPerGemmWorkspaceSize = nullptr; /** * @brief Gets the required byte alignment of the per-GEMM intermediate workspace. @@ -159,13 +160,13 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { */ typedef size_t(SQ4BitGemmPerGemmWorkspaceAlignment_Fn)( size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ); - SQ4BitGemmPerGemmWorkspaceAlignment_Fn* SQ4BitGemmPerGemmWorkspaceAlignment = nullptr; + SQ4BitGemmPerGemmWorkspaceAlignment_Fn* Q4BitGemmPerGemmWorkspaceAlignment = nullptr; // - // CompFp32 kernel function prototypes. + // SQNBIT_CompFp32 kernel function prototypes. // /** @@ -231,7 +232,7 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { Q4BitBlkDequantBForSgemm_CompFp32_Fn* Q4BitBlkDequantBForSgemm_CompFp32 = nullptr; // - // CompInt8 kernel function prototypes. + // SQNBIT_CompInt8 kernel function prototypes. // /** diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp index baaa4ba1a3b1f..01443e2ff077f 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp @@ -19,7 +19,7 @@ Module Name: #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" #include "sqnbitgemm_kernel_avx_common_int8.h" #include "sqnbitgemm_kernel_avx2_int8_blklen16.h" @@ -1306,12 +1306,12 @@ SQ4BitGemmPackQuantBDataAndBlkSum( size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, const std::byte* QuantBDataBegin, const float* QuantBScaleBegin, bool has_zp_input, const std::byte* QuantBZPBegin, - PackedQuantBDataStruct& packed_quant_b, + PackedQuantBDataStruct& packed_quant_b, MLAS_THREADPOOL* ThreadPool ) { @@ -1319,9 +1319,9 @@ SQ4BitGemmPackQuantBDataAndBlkSum( const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - // TODO: always use SubBlkLen = 64 in CompInt8 + // TODO: always use SubBlkLen = 64 in SQNBIT_CompInt8 size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 32 : 64); - if (BlkLen == 32 && ComputeType == CompInt8) { + if (BlkLen == 32 && ComputeType == SQNBIT_CompInt8) { SubBlkLen = 64; } PackQuantBDataAndBlkSum(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, QuantBScaleBegin, has_zp_input, QuantBZPBegin, packed_quant_b, ThreadPool); @@ -1330,15 +1330,15 @@ SQ4BitGemmPackQuantBDataAndBlkSum( // // Kernel dispatch structure definition. // -const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() { - MLAS_SQNBIT_GEMM_DISPATCH d; +const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() { + MLAS_QNBIT_GEMM_DISPATCH d; - d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; + d.Q4BitGemmPackQuantBDataSize = Q4BitGemmPackQuantBDataSize; d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum; - d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; - d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment; + d.Q4BitGemmPerGemmWorkspaceSize = Q4BitGemmPerGemmWorkspaceSize; + d.Q4BitGemmPerGemmWorkspaceAlignment = Q4BitGemmPerGemmWorkspaceAlignment; d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx2; d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; @@ -1349,15 +1349,15 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() { return d; }(); -const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni = []() { - MLAS_SQNBIT_GEMM_DISPATCH d; +const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni = []() { + MLAS_QNBIT_GEMM_DISPATCH d; - d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; + d.Q4BitGemmPackQuantBDataSize = Q4BitGemmPackQuantBDataSize; d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum; - d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; - d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment; + d.Q4BitGemmPerGemmWorkspaceSize = Q4BitGemmPerGemmWorkspaceSize; + d.Q4BitGemmPerGemmWorkspaceAlignment = Q4BitGemmPerGemmWorkspaceAlignment; d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx2; d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h index 80d67806ea6e8..445ead329acf8 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h @@ -3,7 +3,7 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h index af6f52090adcb..5dab8091ce760 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h @@ -3,7 +3,7 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h index 4cb9dddd4ecf9..d4b89bd9bad2d 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h @@ -3,7 +3,7 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" template diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp index 13bd369a065bb..425dfbe87c982 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp @@ -19,7 +19,7 @@ Module Name: #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" #include "sqnbitgemm_kernel_avx_common_int8.h" #include "sqnbitgemm_kernel_avx512_int8_blklen16.h" @@ -28,7 +28,7 @@ Module Name: #include "sqnbitgemm_kernel_avx512_int8_blklen128.h" // -// CompFp32 kernel implementation. +// SQNBIT_CompFp32 kernel implementation. // #include "sqnbitgemm_kernel_avx_common_fp32.h" @@ -151,7 +151,7 @@ SQ4BitGemmM1Kernel_CompFp32_avx512( } // -// CompInt8 kernel implementation. +// SQNBIT_CompInt8 kernel implementation. // MLAS_FORCEINLINE @@ -332,12 +332,12 @@ SQ4BitGemmPackQuantBDataAndBlkSum512( size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, const std::byte* QuantBDataBegin, const float* QuantBScaleBegin, bool has_zp_input, const std::byte* QuantBZPBegin, - PackedQuantBDataStruct& packed_quant_b, + PackedQuantBDataStruct& packed_quant_b, MLAS_THREADPOOL* ThreadPool ) { @@ -346,21 +346,21 @@ SQ4BitGemmPackQuantBDataAndBlkSum512( const size_t BlockCountK = MlasDivRoundup(K, BlkLen); size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 32 : 64); - if (ComputeType == CompInt8) { + if (ComputeType == SQNBIT_CompInt8) { SubBlkLen = 128; } PackQuantBDataAndBlkSum(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, QuantBScaleBegin, has_zp_input, QuantBZPBegin, packed_quant_b, ThreadPool); } -const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() { - MLAS_SQNBIT_GEMM_DISPATCH d; +const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() { + MLAS_QNBIT_GEMM_DISPATCH d; - d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; + d.Q4BitGemmPackQuantBDataSize = Q4BitGemmPackQuantBDataSize; d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum512; - d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; - d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment; + d.Q4BitGemmPerGemmWorkspaceSize = Q4BitGemmPerGemmWorkspaceSize; + d.Q4BitGemmPerGemmWorkspaceAlignment = Q4BitGemmPerGemmWorkspaceAlignment; d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx512; d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8.h index 7d9dc36854621..8f1ea6676b788 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8.h @@ -3,7 +3,7 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" @@ -81,7 +81,7 @@ accumulate_blklen32_r2c1blk2_avx2( _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av01_32_epi8, bv1_32_epi8) ); const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); - + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); @@ -143,7 +143,7 @@ accumulate_blklen32_r2c1blk2_avx2( // const __m256i bv1 = _mm256_and_si256(_mm256_srli_epi16(bv_packed, 4), low_mask); // 16, 17,...30, 31, 48, 49,...,62, 63 __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 16, 17,...30, 31, 48, 49,...,62, 63 - //__m256i bv0_32_epi8 = _mm256_set_m128i(_mm256_castsi256_si128(bv1), _mm256_castsi256_si128(bv0)); + //__m256i bv0_32_epi8 = _mm256_set_m128i(_mm256_castsi256_si128(bv1), _mm256_castsi256_si128(bv0)); //// This (the second line below) saves one _mm256_extracti128_si256 against using _mm256_set_m128i. ////__m256i bv1_32_epi8 = _mm256_set_m128i(_mm256_extracti128_si256(bv1, 1), _mm256_extracti128_si256(bv0, 1)); @@ -184,7 +184,7 @@ accumulate_blklen32_r2c1blk1_avx2( const __m128i bv_packed0 = _mm_loadu_si128(reinterpret_cast(QuantBDataPtr)); __m256i bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); bv_32_epi8 = _mm256_and_si256(_mm256_set1_epi8(0x0F), bv_32_epi8); - + const int8_t zp = get_zp(true, QuantBZeroPointPtr); const __m256i bzp = _mm256_set1_epi8(zp); bv_32_epi8 = _mm256_sub_epi8(bv_32_epi8, bzp); @@ -435,7 +435,7 @@ Q4Int8Gemm2x4BlkLen32Avx2( } } -template +template void MLAS_FORCEINLINE Q4Int8Gemm2xXBlkLen32Avx2( const std::byte* QuantA, const std::byte* QuantBData, @@ -877,7 +877,7 @@ MLAS_FORCEINLINE QuantBZeroPoint + multipleCols * StrideQuantBZeroPoint, C + multipleRows * ldc + multipleCols, remainingRows, - remainingCols, + remainingCols, BlockCountK, Bias ? Bias + multipleCols : nullptr, lda, diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h index 60a887345d0e0..d79554c34c108 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h @@ -3,7 +3,7 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" #include "sqnbitgemm_kernel_avx512_int8_blklen64.h" diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h index bb14babd6c2b1..03064886caf24 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h @@ -3,7 +3,7 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" #include "sqnbitgemm_kernel_avx2_int8_blklen16.h" #include "sqnbitgemm_kernel_avx512_int8_blklen32.h" diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h index e9df6b952bd27..3b1096ac05ba7 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h @@ -3,7 +3,7 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" #include "sqnbitgemm_kernel_avx2_int8_blklen32.h" #include "sqnbitgemm_kernel_avx512_int8_blklen64.h" diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h index 2a65ac4af0c1d..72ce28d834199 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h @@ -3,7 +3,7 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" static MLAS_FORCEINLINE __m256 diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp index 6a5c01162c51b..777d4609ef5d4 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp @@ -19,7 +19,7 @@ Module Name: #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" #include "sqnbitgemm_kernel_avx_common_fp32.h" #include "sqnbitgemm_kernel_avx_common_int8.h" @@ -314,12 +314,12 @@ SQ4BitGemmPackQuantBDataAndBlkSum512vnni( size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, const std::byte* QuantBDataBegin, const float* QuantBScaleBegin, bool has_zp_input, const std::byte* QuantBZPBegin, - PackedQuantBDataStruct& packed_quant_b, + PackedQuantBDataStruct& packed_quant_b, MLAS_THREADPOOL* ThreadPool ) { @@ -328,7 +328,7 @@ SQ4BitGemmPackQuantBDataAndBlkSum512vnni( const size_t BlockCountK = MlasDivRoundup(K, BlkLen); size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 32 : 64); - if (ComputeType == CompInt8) { + if (ComputeType == SQNBIT_CompInt8) { SubBlkLen = 128; } PackQuantBDataAndBlkSum(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, QuantBScaleBegin, has_zp_input, QuantBZPBegin, packed_quant_b, ThreadPool); @@ -337,15 +337,15 @@ SQ4BitGemmPackQuantBDataAndBlkSum512vnni( // // Kernel dispatch structure definition. // -const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni = []() { - MLAS_SQNBIT_GEMM_DISPATCH d; +const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni = []() { + MLAS_QNBIT_GEMM_DISPATCH d; - d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; + d.Q4BitGemmPackQuantBDataSize = Q4BitGemmPackQuantBDataSize; d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum512vnni; - d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; - d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment; + d.Q4BitGemmPerGemmWorkspaceSize = Q4BitGemmPerGemmWorkspaceSize; + d.Q4BitGemmPerGemmWorkspaceAlignment = Q4BitGemmPerGemmWorkspaceAlignment; d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32; d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h index 177f5518bb891..b0367b7fb9a15 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h @@ -1,5 +1,5 @@ #pragma once -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_q8_block.h" // @@ -7,16 +7,16 @@ // static size_t -SQ4BitGemmPackQuantBDataSize( +Q4BitGemmPackQuantBDataSize( size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { constexpr size_t BlkBitWidth = 4; const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - if (ComputeType == CompInt8) { + if (ComputeType == SQNBIT_CompInt8) { size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); const size_t ScaleSize = N * BlockCountK * sizeof(float); size_t BlkSumSize = MlasDivRoundup(N, 16) * BlockCountK * 16 * sizeof(float); @@ -39,7 +39,7 @@ SQ4BitGemmPackQuantBData( size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE /* ComputeType*/, + MLAS_QNBIT_GEMM_COMPUTE_TYPE /* ComputeType*/, const std::byte* QuantBDataBegin, std::byte* PackedQuantBDataBegin, MLAS_THREADPOOL* ThreadPool @@ -304,7 +304,7 @@ PackQuantBDataAndBlkSum( const float* QuantBScaleBegin, bool has_zp_input, const std::byte* QuantBZPBegin, - PackedQuantBDataStruct& packed_quant_b, + PackedQuantBDataStruct& packed_quant_b, MLAS_THREADPOOL* ThreadPool ) { @@ -326,18 +326,18 @@ PackQuantBDataAndBlkSum( // static size_t -SQ4BitGemmPerGemmWorkspaceSize( +Q4BitGemmPerGemmWorkspaceSize( size_t M, size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { MLAS_UNREFERENCED_PARAMETER(N); switch(ComputeType) { - case CompInt8: { + case SQNBIT_CompInt8: { // workspace buffer is used for block quantization of A to int8 const size_t BlockCountK = MlasDivRoundup(K, BlkLen); // QuantData + Scale + BlkSum @@ -351,15 +351,15 @@ SQ4BitGemmPerGemmWorkspaceSize( } static size_t -SQ4BitGemmPerGemmWorkspaceAlignment( +Q4BitGemmPerGemmWorkspaceAlignment( size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { MLAS_UNREFERENCED_PARAMETER(BlkLen); switch (ComputeType) { - case CompInt8: { + case SQNBIT_CompInt8: { return Q8BlkAlignment(); } default: { diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_fp32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_fp32.h index 5cd380e591098..d15cfc782e125 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_fp32.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_fp32.h @@ -1,5 +1,5 @@ #pragma once -#include "sqnbitgemm.h" +#include "qnbitgemm.h" template MLAS_FORCEINLINE diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_int8.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_int8.h index 895ce6cd091c2..2e96082968866 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_int8.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_int8.h @@ -3,7 +3,7 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" #include "sqnbitgemm_q8_block.h" diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp index 3f32cc6c5312d..03c8ce264c846 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp @@ -19,7 +19,7 @@ Module Name: #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_neon.h" #include "sqnbitgemm_q8_block.h" @@ -34,11 +34,11 @@ namespace // size_t -SQ4BitGemmPackQuantBDataSize( +Q4BitGemmPackQuantBDataSize( size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { MLAS_UNREFERENCED_PARAMETER(ComputeType); // same size regardless of ComputeType @@ -55,7 +55,7 @@ SQ4BitGemmPackQuantBData( size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, const std::byte* QuantBDataBegin, std::byte* PackedQuantBDataBegin, MLAS_THREADPOOL* ThreadPool @@ -69,7 +69,7 @@ SQ4BitGemmPackQuantBData( const size_t BlkDataSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); const size_t Iterations = N * BlockCountK; // one iteration per block - const size_t SubBlkLen = (ComputeType == CompInt8) + const size_t SubBlkLen = (ComputeType == SQNBIT_CompInt8) ? ((BlkLen == 16) ? 16 : 32) : 16; @@ -126,18 +126,18 @@ SQ4BitGemmPackQuantBData( // size_t -SQ4BitGemmPerGemmWorkspaceSize( +Q4BitGemmPerGemmWorkspaceSize( size_t M, size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { MLAS_UNREFERENCED_PARAMETER(N); switch (ComputeType) { - case CompInt8: { + case SQNBIT_CompInt8: { // workspace buffer is used for block quantization of A to int8 const size_t BlockCountK = MlasDivRoundup(K, BlkLen); const size_t PerGemmWorkspaceSize = M * BlockCountK * Q8BlkSize(BlkLen); @@ -150,15 +150,15 @@ SQ4BitGemmPerGemmWorkspaceSize( } size_t -SQ4BitGemmPerGemmWorkspaceAlignment( +Q4BitGemmPerGemmWorkspaceAlignment( size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { MLAS_UNREFERENCED_PARAMETER(BlkLen); switch (ComputeType) { - case CompInt8: { + case SQNBIT_CompInt8: { return Q8BlkAlignment(); } default: { @@ -175,14 +175,14 @@ SQ4BitGemmPerGemmWorkspaceAlignment( // Kernel dispatch structure definition. // -const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon = []() { - MLAS_SQNBIT_GEMM_DISPATCH d; +const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon = []() { + MLAS_QNBIT_GEMM_DISPATCH d; - d.SQ4BitGemmPackQuantBDataSize = sqnbitgemm_neon::SQ4BitGemmPackQuantBDataSize; + d.Q4BitGemmPackQuantBDataSize = sqnbitgemm_neon::Q4BitGemmPackQuantBDataSize; d.SQ4BitGemmPackQuantBData = sqnbitgemm_neon::SQ4BitGemmPackQuantBData; - d.SQ4BitGemmPerGemmWorkspaceSize = sqnbitgemm_neon::SQ4BitGemmPerGemmWorkspaceSize; - d.SQ4BitGemmPerGemmWorkspaceAlignment = sqnbitgemm_neon::SQ4BitGemmPerGemmWorkspaceAlignment; + d.Q4BitGemmPerGemmWorkspaceSize = sqnbitgemm_neon::Q4BitGemmPerGemmWorkspaceSize; + d.Q4BitGemmPerGemmWorkspaceAlignment = sqnbitgemm_neon::Q4BitGemmPerGemmWorkspaceAlignment; d.SQ4BitGemmM1Kernel_CompFp32 = sqnbitgemm_neon::SQ4BitGemmM1Kernel_CompFp32; d.Q4BitBlkDequantBForSgemm_CompFp32 = sqnbitgemm_neon::Q4BitBlkDequantBForSgemm_CompFp32; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.h index ef9345d7ac484..247d885615393 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.h @@ -30,13 +30,13 @@ namespace sqnbitgemm_neon // // Function declarations for SQNBitGemm ARM NEON kernel entry points. -// Refer to the prototypes in sqnbitgemm.h for documentation. +// Refer to the prototypes in qnbitgemm.h for documentation. // These are declared here so they can be used to initialize the -// MLAS_SQNBIT_GEMM_DISPATCH structure and also be implemented in separate +// MLAS_QNBIT_GEMM_DISPATCH structure and also be implemented in separate // files. // -// CompFp32 declarations +// SQNBIT_CompFp32 declarations void SQ4BitGemmM1Kernel_CompFp32( @@ -64,7 +64,7 @@ Q4BitBlkDequantBForSgemm_CompFp32( size_t BlockCountK ); -// CompInt8 declarations +// SQNBIT_CompInt8 declarations void QuantizeARow_CompInt8( diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_fp32.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_fp32.cpp index 12ddc42506e98..7b9f05a9c385d 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_fp32.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_fp32.cpp @@ -13,7 +13,7 @@ Module Name: This module implements the float/quantized n-bit integer matrix multiplication kernels for ARM NEON specific to input type T1 as float32 and - MLAS_SQNBIT_GEMM_COMPUTE_TYPE CompFp32. + MLAS_QNBIT_GEMM_COMPUTE_TYPE SQNBIT_CompFp32. --*/ @@ -21,7 +21,7 @@ Module Name: #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_neon.h" namespace sqnbitgemm_neon @@ -31,7 +31,7 @@ namespace { // -// CompFp32 kernel implementation. +// SQNBIT_CompFp32 kernel implementation. // MLAS_FORCEINLINE void diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp index 0d62ea37b7e26..f1acd99c7b693 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp @@ -13,7 +13,7 @@ Module Name: This module implements the float/quantized n-bit integer matrix multiplication kernels for ARM NEON specific to input type T1 as float32 and - MLAS_SQNBIT_GEMM_COMPUTE_TYPE CompInt8. + MLAS_QNBIT_GEMM_COMPUTE_TYPE SQNBIT_CompInt8. --*/ @@ -21,7 +21,7 @@ Module Name: #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_neon.h" #include "sqnbitgemm_q8_block.h" @@ -29,7 +29,7 @@ namespace sqnbitgemm_neon { // -// CompInt8 kernel implementation. +// SQNBIT_CompInt8 kernel implementation. // namespace diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h index 45c3963365e6b..941b884d0b9d2 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h @@ -3,7 +3,7 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" template diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h b/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h index e9c3812bde899..ed78dfa67042d 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h @@ -3,7 +3,7 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" diff --git a/onnxruntime/core/providers/vsinpu/patches/mlas_crosscompiling.patch b/onnxruntime/core/providers/vsinpu/patches/mlas_crosscompiling.patch index 45de47f3e5128..f55b593fd8a8c 100644 --- a/onnxruntime/core/providers/vsinpu/patches/mlas_crosscompiling.patch +++ b/onnxruntime/core/providers/vsinpu/patches/mlas_crosscompiling.patch @@ -16,7 +16,7 @@ index e46105324a..414c46a1ce 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -82,6 +82,9 @@ Abstract: - + #if (!defined(_MSC_VER)) || (_MSC_VER >= 1930) #if defined(MLAS_TARGET_ARM64) || defined(MLAS_TARGET_ARM64EC) +#if !defined(USE_VSINPU) @@ -26,16 +26,16 @@ index e46105324a..414c46a1ce 100644 // Had to temporary disable fp16 under APPLE ARM64, as compiling // the source files require a hardware specific compilation flag. @@ -90,6 +93,7 @@ Abstract: - + #define MLAS_F16VEC_INTRINSICS_SUPPORTED - + +#endif // #endif // #endif // ARM64 #endif // Visual Studio 16 or earlier does not support fp16 intrinsic @@ -1635,6 +1639,7 @@ MlasHalfGemmConvertPackB( ); - + #if defined(__aarch64__) && defined(__linux__) +#if !defined(USE_VSINPU) /** @@ -46,7 +46,7 @@ index e46105324a..414c46a1ce 100644 MlasSBGemmConvertPackB(size_t N, size_t K, const float* B, size_t ldb, void* PackedB); #endif +#endif - + /** * @brief Indirect Depthwise convolution for fp16 diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h @@ -55,7 +55,7 @@ index 4239e2ecae..3df7e5573d 100644 +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -361,6 +361,7 @@ size_t #else - + #if defined(__aarch64__) && defined(__linux__) +#if !defined(USE_VSINPU) typedef size_t(MLASCALL MLAS_SBGEMM_FLOAT_KERNEL)( @@ -66,7 +66,7 @@ index 4239e2ecae..3df7e5573d 100644 ); #endif +#endif - + typedef size_t @@ -763,8 +765,10 @@ extern "C" { @@ -82,13 +82,13 @@ index 4239e2ecae..3df7e5573d 100644 MLAS_GEMM_DOUBLE_KERNEL MlasDgemmKernelAdd; @@ -899,8 +903,10 @@ extern "C" { #define MLAS_QGEMM_THREAD_COMPLEXITY 65536 - + #if defined(__aarch64__) && defined(__linux__) +#if !defined(USE_VSINPU) #define MLAS_SBGEMM_THREAD_COMPLEXITY (size_t(64) * size_t(1024)) #endif +#endif - + // // Single-threaded single precision matrix/matrix multiply operation. @@ -2570,4 +2576,3 @@ MlasPackInt4Elements(uint8_t* Output, UnpackedType ValueLow, UnpackedType ValueH @@ -103,16 +103,16 @@ index ed437f20f7..8c9d0a75fd 100644 @@ -20,7 +20,7 @@ Abstract: #include #include - --#if defined(MLAS_TARGET_POWER) + +-#if defined(MLAS_TARGET_POWER) +#if defined(MLAS_TARGET_POWER) #if defined(__linux__) #include #elif defined(_AIX) @@ -536,7 +536,7 @@ Return Value: - this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon; + this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon; } - + -#if defined(__linux__) +#if defined(__linux__) && !defined(USE_VSINPU) // @@ -124,12 +124,12 @@ index de7fd72fad..4f75dbd6fa 100644 +++ b/onnxruntime/core/mlas/lib/sbgemm.h @@ -31,6 +31,7 @@ Abstract: --*/ - + #if defined(__aarch64__) && defined(__linux__) +#if !defined(USE_VSINPU) - + #pragma once - + @@ -396,4 +397,5 @@ MlasSBGemmBatch(const size_t M, const size_t N, const size_t K, const size_t Bat } ); @@ -141,7 +141,7 @@ index 6a71283f9d..d8bd348854 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.cc +++ b/onnxruntime/core/providers/cpu/math/matmul.cc @@ -132,7 +132,7 @@ Status MatMul::Compute(OpKernelContext* ctx) const { - + return Status::OK(); } -#if defined(__aarch64__) && defined(__linux__) @@ -187,7 +187,7 @@ index b9bbe36583..2f570502d2 100644 +++ b/onnxruntime/core/providers/cpu/math/matmul.h @@ -31,8 +31,10 @@ class MatMul final : public OpKernel { trans_batch_b_ = trans_batch_b_attr != 0; - + #if defined(__aarch64__) && defined(__linux__) +#if !defined(USE_VSINPU) auto config_ops = info.GetConfigOptions().GetConfigEntry(kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16); @@ -195,10 +195,10 @@ index b9bbe36583..2f570502d2 100644 +#endif #endif } - + @@ -57,12 +59,14 @@ class MatMul final : public OpKernel { bool trans_batch_b_; - + #if defined(__aarch64__) && defined(__linux__) +#if !defined(USE_VSINPU) // fastmath mode state @@ -209,7 +209,7 @@ index b9bbe36583..2f570502d2 100644 #endif +#endif }; - + } // namespace onnxruntime diff --git a/onnxruntime/test/mlas/unittest/test_sbgemm.cpp b/onnxruntime/test/mlas/unittest/test_sbgemm.cpp index f85fe97776..6039b7fa9e 100644 @@ -217,12 +217,12 @@ index f85fe97776..6039b7fa9e 100644 +++ b/onnxruntime/test/mlas/unittest/test_sbgemm.cpp @@ -16,6 +16,7 @@ Abstract: --*/ - + #if defined(__aarch64__) && defined(__linux__) +#if !defined(USE_VSINPU) - + #include "test_sbgemm.h" - + @@ -138,4 +139,5 @@ static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_exe } return SBGemmRegistLongExecute() > 0; @@ -235,15 +235,15 @@ index 13701e2e3d..7e432f53c2 100644 +++ b/onnxruntime/test/mlas/unittest/test_sbgemm.h @@ -16,6 +16,7 @@ Abstract: --*/ - + #if defined(__aarch64__) && defined(__linux__) +#if !defined(USE_VSINPU) - + #pragma once - + @@ -278,4 +279,5 @@ class MlasSBGemmTest : public MlasTestBase { } }; - + +#endif #endif // defined(__aarch64__) && defined(__linux__) diff --git a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp index 71db7d81075b5..df543d8eca1fc 100644 --- a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp @@ -22,9 +22,9 @@ void RunSQNBitGemmBenchmark(size_t BlkLen, size_t Threads, bool Symmetric, bool HasBias, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, benchmark::State& state) { - if (!MlasIsSQNBitGemmAvailable(BlkBitWidth, BlkLen, ComputeType)) { + if (!MlasIsQNBitGemmAvailable(BlkBitWidth, BlkLen, ComputeType)) { state.SkipWithMessage("SQNBitGemm is not available with the given configuration on the current machine."); return; } @@ -62,21 +62,21 @@ void RunSQNBitGemmBenchmark(size_t BlkLen, tp.get()); std::unique_ptr Workspace; - if (const auto WorkspaceSize = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType); + if (const auto WorkspaceSize = MlasQNBitGemmBatchWorkspaceSize(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType); WorkspaceSize > 0) { Workspace = std::make_unique(WorkspaceSize); } std::unique_ptr PackedQuantBData; - if (const auto PackedQuantBDataSize = MlasSQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen, ComputeType); + if (const auto PackedQuantBDataSize = MlasQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen, ComputeType); PackedQuantBDataSize > 0) { PackedQuantBData = std::make_unique(PackedQuantBDataSize); - MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType, QuantBData.data(), PackedQuantBData.get(), - QuantBScale.data(), has_zp_input, QuantBZeroPoint.data(), - tp.get()); + MlasQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType, QuantBData.data(), PackedQuantBData.get(), + QuantBScale.data(), has_zp_input, QuantBZeroPoint.data(), + tp.get()); } - MLAS_SQNBIT_GEMM_DATA_PARAMS params{}; + MLAS_QNBIT_GEMM_DATA_PARAMS params{}; params.A = A.data(); params.lda = K; if (PackedQuantBData != nullptr) @@ -92,10 +92,10 @@ void RunSQNBitGemmBenchmark(size_t BlkLen, params.ldc = N; // warm up run - MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, ¶ms, Workspace.get(), tp.get()); + MlasQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, ¶ms, Workspace.get(), tp.get()); for (auto _ : state) { - MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, ¶ms, Workspace.get(), tp.get()); + MlasQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, ¶ms, Workspace.get(), tp.get()); } } @@ -110,7 +110,7 @@ void SQNBITGEMM(benchmark::State& state) { const auto Threads = narrow(state.range(4)); const auto Symmetric = narrow(state.range(5)); const bool HasBias = narrow(state.range(6)); - const auto ComputeType = static_cast(state.range(7)); + const auto ComputeType = static_cast(state.range(7)); RunSQNBitGemmBenchmark(BlkLen, M, N, K, Threads, Symmetric, HasBias, ComputeType, state); } @@ -119,14 +119,14 @@ static void SQNBitGemmArgs(benchmark::internal::Benchmark* b) { b->ArgNames({"BlkLen", "M", "N", "K", "Threads", "Symmetric", "HasBias", "ComputeType"}); b->ArgsProduct({ - {128}, // BlkLen - {1}, // M - {4096, 11008}, // N - {4096, 11008}, // K - {1, 8}, // Threads - {int64_t{false}, int64_t{true}}, // Symmetric - {int64_t{false}, int64_t{true}}, // HasBias - {int64_t{CompFp32}, int64_t{CompInt8}}, // ComputeType + {128}, // BlkLen + {1}, // M + {4096, 11008}, // N + {4096, 11008}, // K + {1, 8}, // Threads + {int64_t{false}, int64_t{true}}, // Symmetric + {int64_t{false}, int64_t{true}}, // HasBias + {int64_t{SQNBIT_CompFp32}, int64_t{SQNBIT_CompInt8}}, // ComputeType }); } @@ -145,10 +145,10 @@ void SQNBITGEMM_ENV(benchmark::State& state) { const auto Symmetric = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_SYMMETRIC", true); const auto HasBias = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_HAS_BIAS", false); const auto ComputeType = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_COMPUTE_TYPE", - static_cast(CompFp32)); + static_cast(SQNBIT_CompFp32)); RunSQNBitGemmBenchmark(BlkLen, M, N, K, Threads, Symmetric, HasBias, - static_cast(ComputeType), + static_cast(ComputeType), state); std::ostringstream s; diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp index 0710981fa17c6..e22018ae2877f 100644 --- a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp @@ -18,11 +18,11 @@ Module Name: #include "mlas_q4.h" #include "mlas_qnbit.h" -static constexpr const char* ComputeTypeName(MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType) { +static constexpr const char* ComputeTypeName(MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType) { switch (ComputeType) { - case CompFp32: + case SQNBIT_CompFp32: return "Fp32"; - case CompInt8: + case SQNBIT_CompInt8: return "Int8"; default: return "unknown"; @@ -63,16 +63,16 @@ class MlasSQNBitGemmTest : public MlasTestBase { float* C, size_t ldc, void* Workspace, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, MLAS_THREADPOOL* Threadpool) { - MLAS_SQNBIT_GEMM_DATA_PARAMS params; + MLAS_QNBIT_GEMM_DATA_PARAMS params; params.A = A; params.lda = lda; params.Bias = Bias; params.C = C; params.ldc = ldc; #ifdef MLAS_TARGET_AMD64_IX86 - if (ComputeType == CompInt8) { + if (ComputeType == SQNBIT_CompInt8) { params.QuantBDataWorkspace = PackedQuantBDataWorkspace; } #endif @@ -81,7 +81,7 @@ class MlasSQNBitGemmTest : public MlasTestBase { params.QuantBZeroPoint = QuantBZeroPoint; params.PostProcessor = nullptr; - MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, ¶ms, Workspace, Threadpool); + MlasQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, ¶ms, Workspace, Threadpool); } void QuantizeA(size_t M, size_t K, const float* A, int8_t* QuantAData, float* QuantAScale) { @@ -201,7 +201,7 @@ class MlasSQNBitGemmTest : public MlasTestBase { public: void Test(size_t M, size_t N, size_t K, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, bool WithThreadpool, bool Symmetric, bool WithBias) { MLAS_THREADPOOL* Threadpool = WithThreadpool ? GetMlasThreadPool() : nullptr; @@ -265,19 +265,19 @@ class MlasSQNBitGemmTest : public MlasTestBase { } void* Workspace = nullptr; - if (const auto WorkspaceSize = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType); + if (const auto WorkspaceSize = MlasQNBitGemmBatchWorkspaceSize(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType); WorkspaceSize > 0) { Workspace = BufferWorkspace.GetBuffer(WorkspaceSize); } void* PackedQuantBDataWorkspace = nullptr; - if (const auto PackedQuantBDataSize = MlasSQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen, ComputeType); + if (const auto PackedQuantBDataSize = MlasQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen, ComputeType); PackedQuantBDataSize > 0) { PackedQuantBDataWorkspace = BufferPackedQuantBData.GetBuffer(PackedQuantBDataSize); bool has_zp_input = QuantBZeroPoint != nullptr; - MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType, QuantBData, PackedQuantBDataWorkspace, - QuantBScale, has_zp_input, QuantBZeroPoint, - GetMlasThreadPool()); + MlasQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType, QuantBData, PackedQuantBDataWorkspace, + QuantBScale, has_zp_input, QuantBZeroPoint, + GetMlasThreadPool()); } CallGemm(M, N, K, @@ -289,9 +289,9 @@ class MlasSQNBitGemmTest : public MlasTestBase { ComputeType, Threadpool); - if (ComputeType == CompFp32) { + if (ComputeType == SQNBIT_CompFp32) { CallReferenceGemm_CompFp32(M, N, K, A, QuantBData, QuantBScale, QuantBZeroPoint, Bias, CReference); - } else if (ComputeType == CompInt8) { + } else if (ComputeType == SQNBIT_CompInt8) { CallReferenceGemm_CompInt8(M, N, K, A, QuantBData, QuantBScale, QuantBZeroPoint, Bias, CReference); } else { FAIL() << "Test is not implemented for compute type " @@ -324,7 +324,7 @@ template class SQNBitGemmShortExecuteTest : public MlasTestFixture> { public: explicit SQNBitGemmShortExecuteTest(size_t M, size_t N, size_t K, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, bool WithThreadpool, bool Symmetric, bool WithBias) : M_(M), N_(N), @@ -341,11 +341,11 @@ class SQNBitGemmShortExecuteTest : public MlasTestFixture