diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index bee83ff07c74b..b995b27123218 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -1,7 +1,9 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -set(MLAS_SRC_DIR ${ONNXRUNTIME_ROOT}/core/mlas/lib) +set(MLAS_ROOT ${ONNXRUNTIME_ROOT}/core/mlas) +set(MLAS_SRC_DIR ${MLAS_ROOT}/lib) +set(MLAS_INC_DIR ${MLAS_ROOT}/inc) # # All hardware agnostic source files here @@ -9,6 +11,7 @@ set(MLAS_SRC_DIR ${ONNXRUNTIME_ROOT}/core/mlas/lib) # multi-target build # onnxruntime_add_static_library(onnxruntime_mlas + ${MLAS_SRC_DIR}/mlasi.h ${MLAS_SRC_DIR}/platform.cpp ${MLAS_SRC_DIR}/threading.cpp ${MLAS_SRC_DIR}/sgemm.cpp @@ -33,9 +36,18 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/qpostprocessor.cpp ${MLAS_SRC_DIR}/qlgavgpool.cpp ${MLAS_SRC_DIR}/qdwconv_kernelsize.cpp + ${MLAS_SRC_DIR}/sqnbitgemm.h ${MLAS_SRC_DIR}/sqnbitgemm.cpp ) +target_sources(onnxruntime_mlas PRIVATE + ${MLAS_INC_DIR}/mlas_float16.h + ${MLAS_INC_DIR}/mlas_gemm_postprocessor.h + ${MLAS_INC_DIR}/mlas_q4.h + ${MLAS_INC_DIR}/mlas_qnbit.h + ${MLAS_INC_DIR}/mlas.h +) + if (NOT onnxruntime_ORT_MINIMAL_BUILD) target_sources(onnxruntime_mlas PRIVATE ${MLAS_SRC_DIR}/q4_dq.cpp @@ -46,7 +58,7 @@ endif() set(ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas) function(add_jblas) - add_subdirectory(${MLAS_SRC_DIR}/x86_64/jblas jblas) + add_subdirectory(${MLAS_SRC_DIR}/x86_64/jblas jblas) target_link_libraries(onnxruntime_mlas PRIVATE jblas::jblas) target_sources(onnxruntime_mlas PRIVATE ${MLAS_SRC_DIR}/jblas_gemm.cpp @@ -143,10 +155,6 @@ function(setup_mlas_source_for_windows) target_sources(onnxruntime_mlas PRIVATE ${MLAS_SRC_DIR}/arm/sgemmc.cpp ) - # it should be removed after Visual Stuio is upgraded to 17.7 - if (MSVC) - add_compile_options("-d2SSAOptimizer-") - endif() elseif(onnxruntime_target_platform STREQUAL "x64") file(GLOB_RECURSE mlas_platform_srcs_avx CONFIGURE_DEPENDS @@ -300,8 +308,8 @@ else() if(APPLE) get_target_property(ONNXRUNTIME_MLAS_MACOSX_ARCH onnxruntime_mlas OSX_ARCHITECTURES) endif() - list(LENGTH ONNXRUNTIME_MLAS_MACOSX_ARCH ONNXRUNTIME_MLAS_MACOSX_ARCH_LENGH) - if(ONNXRUNTIME_MLAS_MACOSX_ARCH_LENGH GREATER 1) + list(LENGTH ONNXRUNTIME_MLAS_MACOSX_ARCH ONNXRUNTIME_MLAS_MACOSX_ARCH_LENGTH) + if(ONNXRUNTIME_MLAS_MACOSX_ARCH_LENGTH GREATER 1) set(ONNXRUNTIME_MLAS_MULTI_ARCH TRUE) endif() #If ONNXRUNTIME_MLAS_MULTI_ARCH is true, we need to go through every if branch below @@ -348,6 +356,8 @@ else() ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp ) + set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp + PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod") if (NOT APPLE) set(mlas_platform_srcs ${mlas_platform_srcs} @@ -617,10 +627,12 @@ if(USE_JBLAS) endif() foreach(mlas_target ${ONNXRUNTIME_MLAS_LIBS}) - target_include_directories(${mlas_target} PRIVATE ${ONNXRUNTIME_ROOT}/core/mlas/inc ${MLAS_SRC_DIR}) + target_include_directories(${mlas_target} PRIVATE ${MLAS_INC_DIR} ${MLAS_SRC_DIR}) onnxruntime_add_include_to_target(${mlas_target} ${GSL_TARGET}) + + set_target_properties(${mlas_target} PROPERTIES FOLDER "ONNXRuntime") endforeach() -set_target_properties(onnxruntime_mlas PROPERTIES FOLDER "ONNXRuntime") + if (WIN32) target_compile_options(onnxruntime_mlas PRIVATE "$<$:/wd6385>" "$<$:/wd4127>") if (onnxruntime_ENABLE_STATIC_ANALYSIS) @@ -636,6 +648,21 @@ if (NOT onnxruntime_BUILD_SHARED_LIB) FRAMEWORK DESTINATION 
${CMAKE_INSTALL_BINDIR}) endif() +# set up source group for MLAS source files +block() + set(source_group_srcs) + foreach(mlas_target ${ONNXRUNTIME_MLAS_LIBS}) + get_target_property(mlas_target_srcs ${mlas_target} SOURCES) + foreach(mlas_target_src ${mlas_target_srcs}) + cmake_path(IS_PREFIX MLAS_ROOT ${mlas_target_src} in_mlas_root) + if(in_mlas_root) + list(APPEND source_group_srcs ${mlas_target_src}) + endif() + endforeach() + endforeach() + source_group(TREE ${MLAS_ROOT} FILES ${source_group_srcs}) +endblock() + if (NOT onnxruntime_ORT_MINIMAL_BUILD) @@ -647,7 +674,7 @@ if (NOT onnxruntime_ORT_MINIMAL_BUILD) onnxruntime_add_executable(onnxruntime_mlas_q4dq ${MLAS_SRC_DIR}/q4_dq_cli.cpp ) - target_include_directories(onnxruntime_mlas_q4dq PRIVATE ${ONNXRUNTIME_ROOT}/core/mlas/inc ${MLAS_SRC_DIR}) + target_include_directories(onnxruntime_mlas_q4dq PRIVATE ${MLAS_INC_DIR} ${MLAS_SRC_DIR}) set_target_properties(onnxruntime_mlas_q4dq PROPERTIES FOLDER "ONNXRuntimeTest") target_link_libraries(onnxruntime_mlas_q4dq PRIVATE ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common) diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index a9703dc68dd26..406c73c95d444 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -64,6 +64,9 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat if (!all_constant_) { return Status::OK(); } + +#if defined(MLAS_JBLAS) + auto compt_type = static_cast(accuracy_level_); MLAS_THREADPOOL* pool = NULL; if (input_idx == 1) { @@ -101,12 +104,32 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat is_packed = true; } +#else // defined(MLAS_JBLAS) + + if (input_idx == 1) { + packed_b_size_ = MlasSQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_); + if (packed_b_size_ == 0) return Status::OK(); + auto qptr = tensor.DataRaw(); + packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size_, true); + MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, qptr, packed_b_.get()); + if (prepacked_weights) { + prepacked_weights->buffers_.push_back(std::move(packed_b_)); + prepacked_weights->buffer_sizes_.push_back(packed_b_size_); + } + is_packed = true; + } + +#endif // defined(MLAS_JBLAS) + return Status::OK(); } Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, int input_idx, /*out*/ bool& used_shared_buffers) { used_shared_buffers = false; + +#if defined(MLAS_JBLAS) + // Pack three tensors into one buffer if (input_idx == 1) { used_shared_buffers = true; @@ -120,6 +143,15 @@ Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& prep used_shared_buffers = true; packed_b_ = std::move(prepacked_buffers[0]); } + +#else // defined(MLAS_JBLAS) + + if (input_idx == 1) { + used_shared_buffers = true; + packed_b_ = std::move(prepacked_buffers[0]); + } + +#endif // defined(MLAS_JBLAS) return Status::OK(); } @@ -129,6 +161,8 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { const Tensor* a = ctx->Input(0); const auto* a_data = a->Data(); +#if defined(MLAS_JBLAS) + if (packed_b_.get()) { TensorShape b_shape({static_cast(N_), static_cast(K_)}); @@ -158,7 +192,7 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { gemm_params[i].C = y_data + helper.OutputOffsets()[i]; gemm_params[i].ldc = N; } - auto ws_size = MlasSQNBitsGemmBatchWorkspaceSize(M, N, K, max_len, gemm_params.data()); + auto ws_size = 
MlasSQNBitsGemmBatchPackedBWorkspaceSize(M, N, K, max_len, gemm_params.data()); // workspace for activation process(dynamic quantization and others) auto ws_ptr = IAllocator::MakeUniquePtr(allocator, ws_size); MlasSQNBitsGemmBatchPackedB(M, N, K, max_len, gemm_params.data(), ws_ptr.get(), @@ -166,10 +200,10 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { return Status::OK(); } - const Tensor* b = ctx->Input(1); +#endif // defined(MLAS_JBLAS) + const Tensor* scales = ctx->Input(2); const Tensor* zero_points = ctx->Input(3); - const uint8_t* b_data = b->Data(); const auto* scales_data = scales->Data(); const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->Data(); @@ -181,8 +215,9 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { Tensor* y = ctx->Output(0, helper.OutputShape()); // Bail out early if the output is going to be empty - if (y->Shape().Size() == 0) + if (y->Shape().Size() == 0) { return Status::OK(); + } auto* y_data = y->MutableData(); @@ -192,36 +227,46 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { const size_t K = static_cast(helper.K()); const size_t lda = helper.Lda(false); - if (MlasIsSQNBitGemmAvailable(nbits_, block_size_)) { - // number of bytes or elements between adjacent matrices - size_t b_data_matrix_stride_in_bytes, b_scale_matrix_stride, b_zero_point_matrix_stride_in_bytes; - MlasBlockwiseQuantizedBufferSizes(static_cast(nbits_), static_cast(block_size_), /* columnwise */ true, - static_cast(K), static_cast(N), - b_data_matrix_stride_in_bytes, b_scale_matrix_stride, - &b_zero_point_matrix_stride_in_bytes); - - const size_t b_matrix_size = K * N; - - InlinedVector data(batch_count); - for (size_t i = 0; i < batch_count; ++i) { - const size_t b_matrix_offset = helper.RightOffsets()[i] / b_matrix_size; - - data[i].A = a_data + helper.LeftOffsets()[i]; - data[i].lda = lda; - data[i].QuantBData = b_data + b_matrix_offset * b_data_matrix_stride_in_bytes; - data[i].QuantBScale = scales_data + b_matrix_offset * b_scale_matrix_stride; - data[i].QuantBZeroPoint = zero_points_data != nullptr - ? 
zero_points_data + b_matrix_offset * b_zero_point_matrix_stride_in_bytes - : nullptr; - data[i].C = y_data + helper.OutputOffsets()[i]; - data[i].ldc = N; + const bool has_single_b_matrix = std::all_of(helper.RightOffsets().begin(), helper.RightOffsets().end(), + [](size_t offset) { return offset == 0; }); + + if (has_single_b_matrix && packed_b_) { + for (int64_t accuracy_level = accuracy_level_; + accuracy_level >= static_cast(CompMostAccurate); + --accuracy_level) { + const auto compute_type = static_cast(accuracy_level); + if (MlasIsSQNBitGemmAvailable(M, N, K, nbits_, block_size_, compute_type)) { + IAllocatorUniquePtr workspace{}; + if (const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, batch_count, + nbits_, block_size_, compute_type); + workspace_size > 0) { + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator)); + workspace = IAllocator::MakeUniquePtr(allocator, workspace_size); + } + + InlinedVector data(batch_count); + for (size_t i = 0; i < batch_count; ++i) { + data[i].A = a_data + helper.LeftOffsets()[i]; + data[i].lda = lda; + data[i].QuantBData = packed_b_.get(); + data[i].QuantBScale = scales_data; + data[i].QuantBZeroPoint = zero_points_data; + data[i].C = y_data + helper.OutputOffsets()[i]; + data[i].ldc = N; + } + + MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type, data.data(), workspace.get(), + thread_pool); + + return Status::OK(); + } } - - MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, data.data(), thread_pool); - - return Status::OK(); } + const Tensor* b = ctx->Input(1); + const uint8_t* b_data = b->Data(); + const size_t ldb = helper.Ldb(true); AllocatorPtr allocator; diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h index 1e83dd1cec400..bc0bfc92c85a0 100644 --- a/onnxruntime/core/mlas/inc/mlas_qnbit.h +++ b/onnxruntime/core/mlas/inc/mlas_qnbit.h @@ -23,19 +23,36 @@ Module Name: #include "mlas.h" #include "mlas_gemm_postprocessor.h" +/** + * @brief Define compute types of block quantization, in order of decreasing accuracy. + */ +typedef enum { + CompUndef = 0, /*!< undef */ + CompFp32, /*!< input fp32, accumulator fp32 */ + CompFp16, /*!< input fp16, accumulator fp16 */ + CompBf16, /*!< input bf16, accumulator fp32 */ + CompInt8, /*!< input int8, accumulator int32 */ + + // special values that should be the first and last actual values + + CompMostAccurate = CompUndef, + CompLeastAccurate = CompInt8, +} MLAS_SQNBIT_COMPUTE_TYPE; + +using MLAS_SQNBIT_GEMM_COMPUTE_TYPE = MLAS_SQNBIT_COMPUTE_TYPE; // TODO consolidate these + /** * @brief Data parameters for float/n-bit quantized int GEMM routine. 
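 *
 * Illustrative sketch of filling in this struct (buffer names such as `a_data`, `quant_b_data`,
 * `b_scales`, `b_zero_points`, and `c_data` are hypothetical caller-owned buffers, not MLAS
 * names; the leading dimensions shown assume contiguous row-major matrices):
 *
 *   MLAS_SQNBIT_GEMM_DATA_PARAMS params{};
 *   params.A = a_data;                       // M x K float matrix
 *   params.lda = K;
 *   params.QuantBData = quant_b_data;        // block quantized B data, or the packed form
 *                                            // when MlasSQNBitGemmPackQuantBDataSize() != 0
 *   params.QuantBScale = b_scales;           // one float scale per block
 *   params.QuantBZeroPoint = b_zero_points;  // optional; may remain nullptr
 *   params.C = c_data;                       // M x N float output matrix
 *   params.ldc = N;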
*/ struct MLAS_SQNBIT_GEMM_DATA_PARAMS { - const float* A = nullptr; ///< address of A (float32 matrix) - size_t lda = 0; ///< leading dimension of A - const void* QuantBData = nullptr; ///< address of quantized B (quantized n-bit int values) - const float* QuantBScale = nullptr; ///< address of scale values of quantized B, one per block - const void* QuantBZeroPoint = nullptr; ///< optional address of zero point values of quantized B, one per block - bool IsBPacked = false; ///< whether B values are packed in an optimized format for the computation - const float* Bias = nullptr; ///< optional address of Bias, vector size N - float* C = nullptr; ///< address of result matrix - size_t ldc = 0; ///< leading dimension of C + const float* A = nullptr; ///< address of A (float32 matrix) + size_t lda = 0; ///< leading dimension of A + const void* QuantBData = nullptr; ///< address of quantized B (quantized n-bit int values) + const float* QuantBScale = nullptr; ///< address of scale values of quantized B, one per block + const void* QuantBZeroPoint = nullptr; ///< optional address of zero point values of quantized B, one per block + const float* Bias = nullptr; ///< optional address of Bias, vector size N + float* C = nullptr; ///< address of result matrix + size_t ldc = 0; ///< leading dimension of C ///< optional post processing to apply to result matrix MLAS_GEMM_POSTPROCESSOR* PostProcessor = nullptr; @@ -46,13 +63,26 @@ struct MLAS_SQNBIT_GEMM_DATA_PARAMS { * A must be a float32 matrix * B must be a quantized and packed n-bit int matrix * + * Call MlasIsSQNBitGemmAvailable() with the same parameters to determine whether this function may be called. + * + * Call MlasSQNBitGemmPackQuantBDataSize() with the same parameters to determine whether + * MLAS_SQNBIT_GEMM_DATA_PARAMS::QuantBData in `DataParams` should point to a buffer packed with + * MlasSQNBitGemmPackQuantBData(). + * + * Call MlasSQNBitGemmBatchWorkspaceSize() with the same parameters to determine whether `Workspace` should + * point to an intermediate workspace buffer. + * * @param[in] M row size of matrix A and C * @param[in] N column size of matrix B and C * @param[in] K column size of matrix A and row size of matrix B * @param[in] BatchN number of batches * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) * @param[in] BlkLen number of quantized values per block + * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) * @param[inout] DataParams An array (size BatchN) of parameter blocks + * @param[in] Workspace Address of intermediate workspace buffer. + If MlasSQNBitGemmBatchWorkspaceSize() returns a non-zero value, this must be a + buffer with at least that many bytes. Otherwise, it may be nullptr. * @param[in] ThreadPool optional thread pool to use */ void MLASCALL @@ -63,31 +93,96 @@ MlasSQNBitGemmBatch( size_t BatchN, size_t BlkBitWidth, size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, + void* Workspace, MLAS_THREADPOOL* ThreadPool = nullptr ); /** * @brief Determines whether a float32/quantized n-bit int GEMM implementation is available on the current platform. 
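 *
 * Availability depends on the compute type as well as the block configuration. One possible
 * calling pattern, mirroring the MatMulNBits operator in this change, starts from the requested
 * accuracy level and falls back toward more accurate compute types until an available one is
 * found (sketch only; `requested_accuracy` is a hypothetical caller-provided value):
 *
 *   for (int64_t level = requested_accuracy; level >= static_cast<int64_t>(CompMostAccurate); --level) {
 *     const auto compute_type = static_cast<MLAS_SQNBIT_GEMM_COMPUTE_TYPE>(level);
 *     if (MlasIsSQNBitGemmAvailable(M, N, K, BlkBitWidth, BlkLen, compute_type)) {
 *       // ... proceed to call MlasSQNBitGemmBatch() with compute_type ...
 *       break;
 *     }
 *   }
 *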
+ * + * @param[in] M row size of matrix A and C + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) * @param[in] BlkLen number of quantized values per block + * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) */ bool MLASCALL MlasIsSQNBitGemmAvailable( + size_t M, + size_t N, + size_t K, + size_t BlkBitWidth, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType +); + +/** + * @brief Gets the size in bytes of the intermediate workspace buffer required by the float32/quantized n-bit int GEMM + * implementation. If zero, no intermediate workspace is required. + * + * @param[in] M row size of matrix A and C + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BatchN number of batches + * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) + * @param[in] BlkLen number of quantized values per block + * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) + */ +size_t MLASCALL +MlasSQNBitGemmBatchWorkspaceSize( + size_t M, + size_t N, + size_t K, + size_t BatchN, + size_t BlkBitWidth, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType +); + +/** + * @brief Gets the size in bytes of the packed quantized B data. + * If non-zero, the quantized B data must first be packed by calling MlasSQNBitGemmPackQuantBData() with a buffer of + * this size, and then that packed quantized B data buffer must be passed to MlasSQNBitGemmBatch(). + * If zero, MlasSQNBitGemmPackQuantBData() must not be called and the quantized B data must be directly passed to + * MlasSQNBitGemmBatch(). + * + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) + * @param[in] BlkLen number of quantized values per block + */ +size_t MLASCALL +MlasSQNBitGemmPackQuantBDataSize( + size_t N, + size_t K, size_t BlkBitWidth, size_t BlkLen ); /** - * @brief Define compute types of block quantization + * @brief Packs the quantized B data in a format that the kernel expects. + * + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) + * @param[in] BlkLen number of quantized values per block + * @param[in] QuantBData quantized B data + * @param[out] PackedQuantBData packed quantized B data + * @param[in] ThreadPool optional thread pool to use */ -typedef enum { - CompUndef = 0, /*!< undef */ - CompFp32 = 1, /*!< input fp32, accumulator fp32 */ - CompFp16 = 2, /*!< input fp16, accumulator fp16 */ - CompBf16 = 3, /*!< input bf16, accumulator fp32 */ - CompInt8 = 4 /*!< input int8, accumulator int32 */ -} MLAS_SQNBIT_COMPUTE_TYPE; +void MLASCALL +MlasSQNBitGemmPackQuantBData( + size_t N, + size_t K, + size_t BlkBitWidth, + size_t BlkLen, + const void* QuantBData, + void* PackedQuantBData, + MLAS_THREADPOOL* ThreadPool = nullptr +); /** * @brief Data parameters for NBits GEMM routine @@ -139,7 +234,7 @@ MlasNBitsGemmPackBSize( * @param last_call flag to activate the epilogue process of packB. OpKernel::PrePack will query input tensor * one by one: QData, Scale, Zp (if is_asym is true). 
But kernel prefers to pack all tensors into one blob data where * they can share the common attributes like: block_size. Meanwhile, kernel has some pre-computations to speed up - * inference which require that all blob data are ready. So, you need to set this flag to true when passing Scale + * inference which require that all blob data are ready. So, you need to set this flag to true when passing Scale * (is_asym is false) and Zp(is_asym is true). * @param thread_pool */ @@ -186,7 +281,7 @@ MlasNBitsGemmUnPackB( * @return Workspace size in bytes */ size_t MLASCALL -MlasSQNBitsGemmBatchWorkspaceSize( +MlasSQNBitsGemmBatchPackedBWorkspaceSize( const size_t M, const size_t N, const size_t K, diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 8329a34f1338f..1310ed3f384b9 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -482,7 +482,6 @@ Return Value: this->SymmQgemmDispatch = &MlasSymmQgemmS8DispatchNeon; this->ConvSymU8S8Dispatch = &MlasConvSymU8DispatchNeon; this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchNeon; - this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon; // // Check if the processor supports ASIMD dot product instructions. @@ -512,6 +511,9 @@ Return Value: this->SymmQgemmDispatch = &MlasSymmQgemmS8DispatchSdot; this->ConvSymU8S8Dispatch = &MlasConvSymU8DispatchDot; this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchDot; + + // MlasSQNBitGemmDispatchNeon has a dependency on dot product instructions + this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon; } #if defined(__linux__) diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp index 7f1d1b084aec0..7d877848017fe 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp @@ -11,10 +11,14 @@ Module Name: Abstract: This module implements the float/quantized n-bit integer matrix - multiplication hardware agnostic entrypoint, MlasSQNBitGemmBatch. + multiplication hardware agnostic entrypoint, MlasSQNBitGemmBatch, + as well as some SQNBitGemm-related query functions. --*/ #include "sqnbitgemm.h" + +#include + #ifdef MLAS_JBLAS #include "jblas_gemm.h" #endif @@ -22,29 +26,564 @@ Module Name: namespace { -// Get quantization variant based on `BlkBitWidth` and `BlkLen`. -// Return -1 if the input values are unsupported. -int32_t -GetDispatchQuantVariant(size_t BlkBitWidth, size_t BlkLen) +enum SQNBitGemmVariant { + SQNBitGemmVariantInvalid = -1, + + // Valid variants + + SQNBitGemmVariant_BitWidth4_CompFp32 = 0, + SQNBitGemmVariant_BitWidth4_CompInt8, + + // End of valid variants + + // Keep this element last and ensure that its value is the number of valid SQNBitGemmVariant values. + // Its value is used as an array size. 
+ SQNBitGemmVariantCount, +}; + +SQNBitGemmVariant +GetSQNBitGemmVariant( + size_t M, + size_t N, + size_t K, + size_t BlkBitWidth, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType +) +{ + MLAS_UNREFERENCED_PARAMETER(N); + MLAS_UNREFERENCED_PARAMETER(K); + + if (BlkBitWidth == 4 && + (BlkLen == 16 || BlkLen == 32 || BlkLen == 64 || BlkLen == 128 || BlkLen == 256)) { + if (ComputeType == CompFp32 || + ComputeType == CompUndef) { // treat CompUndef (undefined) as CompFp32 + return SQNBitGemmVariant_BitWidth4_CompFp32; + } else if (ComputeType == CompInt8 && M == 1) { + return SQNBitGemmVariant_BitWidth4_CompInt8; + } + } + + return SQNBitGemmVariantInvalid; +} + +} // namespace + +bool MLASCALL +MlasIsSQNBitGemmAvailable( + size_t M, + size_t N, + size_t K, + size_t BlkBitWidth, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType +) +{ + const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch; + if (Dispatch == nullptr) { + return false; + } + + const auto Variant = GetSQNBitGemmVariant(M, N, K, BlkBitWidth, BlkLen, ComputeType); + + switch (Variant) { + case SQNBitGemmVariant_BitWidth4_CompFp32: { + return Dispatch->SQ4BitGemmM1Kernel_CompFp32 != nullptr && + Dispatch->Q4BitBlkDequantBForSgemm_CompFp32 != nullptr; + } + case SQNBitGemmVariant_BitWidth4_CompInt8: { + return Dispatch->SQ4BitGemmM1Kernel_CompInt8 != nullptr && + Dispatch->QuantizeARow_CompInt8 != nullptr; + } + default: { + return false; + } + } +} + +namespace +{ + +size_t +SQNBitGemmWorkspaceAlignment(SQNBitGemmVariant Variant) +{ + switch (Variant) { + case SQNBitGemmVariant_BitWidth4_CompInt8: { + return Q8BlkAlignment(); + } + default: { + return 1; + } + } +} + +size_t +SQNBitGemmPerGemmWorkspaceSize( + SQNBitGemmVariant Variant, + size_t M, + size_t N, + size_t K, + size_t BlkLen +) +{ + MLAS_UNREFERENCED_PARAMETER(N); + + switch (Variant) { + case SQNBitGemmVariant_BitWidth4_CompInt8: { + // workspace buffer is used for block quantization of A to int8 + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + const size_t PerGemmWorkspaceSize = M * BlockCountK * Q8BlkSize(BlkLen); + return PerGemmWorkspaceSize; + } + default: { + return 0; + } + } +} + +size_t +SQNBitGemmPerGemmWorkspaceStride( + SQNBitGemmVariant Variant, + size_t M, + size_t N, + size_t K, + size_t BlkLen +) +{ + const auto Size = SQNBitGemmPerGemmWorkspaceSize(Variant, M, N, K, BlkLen); + const auto Alignment = SQNBitGemmWorkspaceAlignment(Variant); + return MlasDivRoundup(Size, Alignment) * Alignment; +} + +} // namespace + +size_t MLASCALL +MlasSQNBitGemmBatchWorkspaceSize( + size_t M, + size_t N, + size_t K, + size_t BatchN, + size_t BlkBitWidth, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType +) { - int32_t type = -1; - if (BlkBitWidth == 4 && BlkLen == 16) { - type = QuantVariant_BitWidth4_BlockSize16; - } else if (BlkBitWidth == 4 && BlkLen == 32) { - type = QuantVariant_BitWidth4_BlockSize32; - } else if (BlkBitWidth == 4 && BlkLen == 64) { - type = QuantVariant_BitWidth4_BlockSize64; - } else if (BlkBitWidth == 4 && BlkLen == 128) { - type = QuantVariant_BitWidth4_BlockSize128; - } else if (BlkBitWidth == 4 && BlkLen == 256) { - type = QuantVariant_BitWidth4_BlockSize256; + const auto Variant = GetSQNBitGemmVariant(M, N, K, BlkBitWidth, BlkLen, ComputeType); + + const size_t PerGemmWorkspaceStride = SQNBitGemmPerGemmWorkspaceStride(Variant, M, N, K, BlkLen); + if (PerGemmWorkspaceStride == 0) { + return 0; } - return type; + const size_t Alignment = SQNBitGemmWorkspaceAlignment(Variant); + + const size_t 
WorkspaceSize = BatchN * PerGemmWorkspaceStride; + + return WorkspaceSize + Alignment - 1; +} + +namespace +{ + +void +SQ4BitGemmPackQuantBData( + size_t N, + size_t K, + size_t BlkLen, + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool +) +{ + constexpr size_t BlkBitWidth = 4; + + assert(BlkLen % 16 == 0); + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + const size_t BlkDataSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t Iterations = N * BlockCountK; // one iteration per block + + MlasTrySimpleParallel( + ThreadPool, Iterations, + [&](ptrdiff_t tid) { + const size_t n = tid / BlockCountK; + const size_t k_blk = tid % BlockCountK; + + const size_t data_offset = n * BlockCountK * BlkDataSize + k_blk * BlkDataSize; + const std::byte* QuantBData = QuantBDataBegin + data_offset; + std::byte* PackedQuantBData = PackedQuantBDataBegin + data_offset; + + // + // Pack 16 4-bit values (8 bytes) at a time like this: + // + // src: | v0 v1 | v2 v3 | v4 v5 | v6 v7 | v8 v9 | vA vB | vC vD | vE vF | + // => + // dst: | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF | + // + for (size_t kk = 0; kk < BlkLen; kk += 16) { + for (size_t byte_pair_idx = 0; byte_pair_idx < 4; ++byte_pair_idx) { + const std::byte src0 = QuantBData[byte_pair_idx]; + const std::byte src1 = QuantBData[byte_pair_idx + 4]; + + std::byte& dst0 = PackedQuantBData[2 * byte_pair_idx]; + std::byte& dst1 = PackedQuantBData[2 * byte_pair_idx + 1]; + + dst0 = (src0 & std::byte{0x0F}) | ((src1 & std::byte{0x0F}) << 4); + dst1 = (src0 >> 4) | ((src1 >> 4) << 4); + } + + QuantBData += 8; + PackedQuantBData += 8; + } + } + ); } } // namespace +size_t MLASCALL +MlasSQNBitGemmPackQuantBDataSize( + size_t N, + size_t K, + size_t BlkBitWidth, + size_t BlkLen +) +{ + // Ensure that a general implementation is available on this platform. + // For now, all implementations share the same packed format. + { + // Currently, there are implementations specific to M = 1, so pick a more general M > 1. + constexpr size_t M = 2; + // A CompUndef implementation should be available if any is available. 
+ constexpr MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType = CompUndef; + const bool HasGeneralImplementation = + MlasIsSQNBitGemmAvailable(M, N, K, BlkBitWidth, BlkLen, ComputeType); + if (!HasGeneralImplementation) { + return 0; + } + } + + if (BlkBitWidth == 4) { + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + return PackedQuantBDataSize; + } + + return 0; +} + +void MLASCALL +MlasSQNBitGemmPackQuantBData( + size_t N, + size_t K, + size_t BlkBitWidth, + size_t BlkLen, + const void* QuantBData, + void* PackedQuantBData, + MLAS_THREADPOOL* ThreadPool +) +{ + if (BlkBitWidth == 4) { + SQ4BitGemmPackQuantBData( + N, + K, + BlkLen, + static_cast(QuantBData), + static_cast(PackedQuantBData), + ThreadPool + ); + } +} + +namespace +{ + +MLAS_FORCEINLINE void +AddBiasForGemm(const float* Bias, float* C, size_t CountM, size_t CountN, size_t ldc) +{ + for (size_t m = 0; m < CountM; m++) { + const float* bias = Bias; + float* sum = C; + for (size_t n = 0; n < CountN; n += 4) { + if (CountN - n < 4) { + for (size_t nn = n; nn < CountN; nn++) { + *sum += *bias; + sum++; + bias++; + } + break; + } + + MLAS_FLOAT32X4 acc_x = MlasLoadFloat32x4(sum); + acc_x = MlasAddFloat32x4(acc_x, MlasLoadFloat32x4(bias)); + MlasStoreFloat32x4(sum, acc_x); + bias += 4; + sum += 4; + } + C += ldc; + } +} + +typedef void(SQNBitGemmFn)( + size_t BlkLen, + size_t K, + const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, + void* PerGemmWorkspace, + size_t RangeStartM, + size_t RangeCountM, + size_t RangeStartN, + size_t RangeCountN +); + +void +SQ4BitGemm_CompFp32( + const size_t BlkLen, + const size_t K, + const MLAS_SQNBIT_GEMM_DATA_PARAMS* const DataParams, + void* const PerGemmWorkspace, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN +) +{ + constexpr size_t BlkBitWidth = 4; + + MLAS_UNREFERENCED_PARAMETER(PerGemmWorkspace); + + const size_t lda = DataParams->lda; + const size_t ldc = DataParams->ldc; + + const size_t k_blks = MlasDivRoundup(K, BlkLen); + const size_t ldb = k_blks * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t k_blks_zp_bytes = MlasQNBitZeroPointsForBlksSizeInBytes(k_blks); + + const float* A = DataParams->A + RangeStartM * lda; + + const std::byte* QuantBData = static_cast(DataParams->QuantBData) + RangeStartN * ldb; + const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks; + const std::byte* QuantBZeroPoint = + (DataParams->QuantBZeroPoint == nullptr) + ? nullptr + : static_cast(DataParams->QuantBZeroPoint) + RangeStartN * k_blks_zp_bytes; + + float* C = DataParams->C + RangeStartM * ldc + RangeStartN; + + const float* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias + RangeStartN; + + if (RangeCountM == 1) { + size_t CountN; + for (size_t n = 0; n < RangeCountN; n += CountN) { + CountN = std::min(RangeCountN - n, size_t{128}); + + const float* a_row = A; + const std::byte* b_col = QuantBData + n * ldb; + const float* b_col_scale = QuantBScale + n * k_blks; + const std::byte* b_col_zp = + (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; + float* c_blk = C + n; + const float* bias = (Bias == nullptr) ? 
nullptr : Bias + n; + + GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompFp32( + BlkLen, + a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias + ); + + if (DataParams->PostProcessor != nullptr) { + DataParams->PostProcessor->Process( + DataParams->C, RangeStartM, RangeStartN + n, + RangeCountM, CountN, ldc + ); + } + } + return; + } + + constexpr size_t StrideN = 32; + size_t bufsize = k_blks * BlkLen * StrideN * sizeof(float); + MlasThreadedBufAlloc(bufsize); + auto* dequant_b = reinterpret_cast(ThreadedBufHolder.get()); + + // + // Step through each slice of matrix B along the N dimension. + // + size_t CountN; + for (size_t n = 0; n < RangeCountN; n += CountN) { + CountN = std::min(RangeCountN - n, StrideN); + + // + // Step through each slice of matrix A along the M dimension. + // + const float* a_row = A; + const std::byte* b_col = QuantBData + n * ldb; + const float* b_col_scale = QuantBScale + n * k_blks; + const std::byte* b_col_zp = + (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; + float* c_blk = C + n; + const float* bias = (Bias == nullptr) ? nullptr : Bias + n; + + GetMlasPlatform().SQNBitGemmDispatch->Q4BitBlkDequantBForSgemm_CompFp32( + BlkLen, + dequant_b, b_col, b_col_scale, b_col_zp, CountN, K, k_blks + ); + + size_t RowsRemaining = RangeCountM; + while (RowsRemaining > 0) { +#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_LARCH64) + auto RowsHandled = GetMlasPlatform().GemmFloatKernel( + a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f, true + ); +#else + auto RowsHandled = MlasSgemmKernelZero(a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f); +#endif + + if (bias) { + AddBiasForGemm(bias, c_blk, RowsHandled, CountN, ldc); + } + if (DataParams->PostProcessor != nullptr) { + DataParams->PostProcessor->Process( + DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN, + RowsHandled, CountN, ldc + ); + } + + c_blk += ldc * RowsHandled; + a_row += lda * RowsHandled; + RowsRemaining -= RowsHandled; + } + } +} + +void +SQ4BitGemm_CompInt8( + const size_t BlkLen, + const size_t K, + const MLAS_SQNBIT_GEMM_DATA_PARAMS* const DataParams, + void* const PerGemmWorkspace, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN +) +{ + constexpr size_t BlkBitWidth = 4; + + const size_t k_blks = MlasDivRoundup(K, BlkLen); + + const size_t lda = k_blks * Q8BlkSize(BlkLen); + const size_t ldc = DataParams->ldc; + const size_t ldb = k_blks * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t k_blks_zp_bytes = MlasQNBitZeroPointsForBlksSizeInBytes(k_blks); + + const std::byte* QuantA = static_cast(PerGemmWorkspace) + RangeStartM * lda; + + const std::byte* QuantBData = static_cast(DataParams->QuantBData) + RangeStartN * ldb; + const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks; + const std::byte* QuantBZeroPoint = + (DataParams->QuantBZeroPoint == nullptr) + ? nullptr + : static_cast(DataParams->QuantBZeroPoint) + RangeStartN * k_blks_zp_bytes; + + float* C = DataParams->C + RangeStartM * ldc + RangeStartN; + + const float* Bias = (DataParams->Bias == nullptr) ? 
nullptr : DataParams->Bias + RangeStartN; + + if (RangeCountM == 1) { + size_t CountN; + for (size_t n = 0; n < RangeCountN; n += CountN) { + CountN = std::min(RangeCountN - n, size_t{128}); + + const std::byte* a_row = QuantA; + const std::byte* b_col = QuantBData + n * ldb; + const float* b_col_scale = QuantBScale + n * k_blks; + const std::byte* b_col_zp = + (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; + float* c_blk = C + n; + const float* bias = (Bias == nullptr) ? nullptr : Bias + n; + + GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompInt8( + BlkLen, + a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias + ); + + if (DataParams->PostProcessor != nullptr) { + DataParams->PostProcessor->Process( + DataParams->C, RangeStartM, RangeStartN + n, + RangeCountM, CountN, ldc + ); + } + } + return; + } + + assert(false && "not implemented for M > 1"); +} + +typedef void(InitializeWorkspaceFn)( + size_t M, + size_t N, + size_t K, + size_t BatchN, + size_t BlkLen, + const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, + void* Workspace, + size_t PerGemmWorkspaceStride, + MLAS_THREADPOOL* ThreadPool +); + +void +InitializeWorkspace_CompInt8( + size_t M, + size_t N, + size_t K, + size_t BatchN, + size_t BlkLen, + const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, + void* Workspace, + size_t PerGemmWorkspaceStride, + MLAS_THREADPOOL* ThreadPool +) +{ + MLAS_UNREFERENCED_PARAMETER(N); + + const auto QuantizeARow = GetMlasPlatform().SQNBitGemmDispatch->QuantizeARow_CompInt8; + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + const size_t QuantAStride = BlockCountK * Q8BlkSize(BlkLen); + + MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) { + const auto& data = DataParams[gemm_idx]; + + const float* ARowPtr = data.A; + std::byte* QuantARowPtr = static_cast(Workspace) + gemm_idx * PerGemmWorkspaceStride; + + for (size_t m = 0; m < M; ++m) { + QuantizeARow(BlkLen, ARowPtr, K, QuantARowPtr); + + ARowPtr += data.lda; + QuantARowPtr += QuantAStride; + } + }); +} + +struct Operations { + InitializeWorkspaceFn* InitializeWorkspace = nullptr; + SQNBitGemmFn* SQNBitGemm = nullptr; +}; + +constexpr auto OperationMap = []() { + std::array ops; + + ops[SQNBitGemmVariant_BitWidth4_CompFp32].SQNBitGemm = SQ4BitGemm_CompFp32; + + ops[SQNBitGemmVariant_BitWidth4_CompInt8].InitializeWorkspace = InitializeWorkspace_CompInt8; + ops[SQNBitGemmVariant_BitWidth4_CompInt8].SQNBitGemm = SQ4BitGemm_CompInt8; + + return ops; +}(); + +} // namespace + void MLASCALL MlasSQNBitGemmBatch( const size_t M, @@ -53,17 +592,43 @@ MlasSQNBitGemmBatch( const size_t BatchN, const size_t BlkBitWidth, const size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, + void* Workspace, MLAS_THREADPOOL* ThreadPool ) { - const int32_t QuantVariant = GetDispatchQuantVariant(BlkBitWidth, BlkLen); - MLAS_SQNBIT_GEMM_OPERATION* const Operation = GetMlasPlatform().SQNBitGemmDispatch->Operations[QuantVariant]; + const auto Variant = GetSQNBitGemmVariant(M, N, K, BlkBitWidth, BlkLen, ComputeType); + assert(Variant != SQNBitGemmVariantInvalid); + + // + // Ensure `Workspace` has correct alignment. 
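    // Note on the rounding below: MlasSQNBitGemmBatchWorkspaceSize() reserves Alignment - 1 extra
    // bytes, so rounding the caller's pointer up to the next multiple of Alignment cannot step past
    // the end of the buffer. The mask expression assumes Alignment is a power of two, which holds
    // for the current variants (SQNBitGemmWorkspaceAlignment() returns either 1 or alignof(float)).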
+ // + if (Workspace != nullptr) { + const size_t Alignment = SQNBitGemmWorkspaceAlignment(Variant); + const uintptr_t WorkspaceAddress = reinterpret_cast(Workspace); + Workspace = reinterpret_cast( + (WorkspaceAddress + Alignment - 1) & (~(Alignment - 1)) + ); + } + + const size_t PerGemmWorkspaceStride = SQNBitGemmPerGemmWorkspaceStride(Variant, M, N, K, BlkLen); + + if (const auto InitializeWorkspaceOperation = OperationMap[Variant].InitializeWorkspace; + InitializeWorkspaceOperation != nullptr) { + InitializeWorkspaceOperation( + M, N, K, BatchN, BlkLen, DataParams, Workspace, PerGemmWorkspaceStride, ThreadPool + ); + } + + const auto ComputeOperation = OperationMap[Variant].SQNBitGemm; if (ThreadPool == nullptr) { for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) { - auto Data = &DataParams[gemm_i]; - Operation(K, Data, 0, M, 0, N); + const auto* Data = &DataParams[gemm_i]; + void* PerGemmWorkspace = + reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride; + ComputeOperation(BlkLen, K, Data, PerGemmWorkspace, 0, M, 0, N); } return; } @@ -112,7 +677,10 @@ MlasSQNBitGemmBatch( MlasTrySimpleParallel(ThreadPool, ThreadsPerGemm * BatchN, [&](ptrdiff_t tid) { const auto gemm_i = tid / ThreadsPerGemm; const auto blk_i = tid % ThreadsPerGemm; - auto Data = &DataParams[gemm_i]; + const auto* Data = &DataParams[gemm_i]; + void* PerGemmWorkspace = reinterpret_cast( + reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride + ); const ptrdiff_t ThreadIdN = blk_i / ThreadCountM; const ptrdiff_t ThreadIdM = blk_i % ThreadCountM; @@ -123,29 +691,10 @@ MlasSQNBitGemmBatch( const size_t RangeStartN = ThreadIdN * StrideN; const size_t RangeCountN = std::min(N - RangeStartN, (size_t)StrideN); - Operation(K, Data, RangeStartM, RangeCountM, RangeStartN, RangeCountN); + ComputeOperation(BlkLen, K, Data, PerGemmWorkspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN); }); } -bool MLASCALL -MlasIsSQNBitGemmAvailable( - size_t BlkBitWidth, - size_t BlkLen -) -{ - const int32_t QuantVariant = GetDispatchQuantVariant(BlkBitWidth, BlkLen); - if (QuantVariant == -1) { - return false; - } - - if (GetMlasPlatform().SQNBitGemmDispatch == nullptr || - GetMlasPlatform().SQNBitGemmDispatch->Operations[QuantVariant] == nullptr) { - return false; - } - - return true; -} - size_t MLASCALL MlasNBitsGemmPackBSize( size_t N, size_t K, size_t BlkSize, int nbits, bool isAsym, MLAS_SQNBIT_COMPUTE_TYPE CompType @@ -224,7 +773,7 @@ MlasNBitsGemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, s } size_t MLASCALL -MlasSQNBitsGemmBatchWorkspaceSize( +MlasSQNBitsGemmBatchPackedBWorkspaceSize( const size_t M, const size_t N, const size_t K, diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.h b/onnxruntime/core/mlas/lib/sqnbitgemm.h index 90fdd710e2773..a66db79dc290a 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.h @@ -10,98 +10,23 @@ Module Name: Abstract: - This module includes: + This module includes kernel function prototypes and helper functions for + implementing SQNBitGemm. - - Declaration of the set of template functions used to implement a kernel - for a matrix/matrix multiplication, A*B, where A is a float matrix and B is - a n-bit quantized integer matrix (QNBitGemm). - - - A shared kernel driver function template, MlasSQNBitGemmOperation. - - - Kernel dispatch structure. - - The B matrix is block quantized, which means that its values are grouped - into blocks which each have one scale and optional zero point. 
Each - quantized value in B is n-bits wide. + SQNBitGemm is a matrix/matrix multiplication, A*B, where A is a float + matrix and B is a n-bit quantized integer matrix. B is block quantized, + meaning values of B are divided into blocks and each block has its own + scale and optional zero point. --*/ #pragma once +#include + #include "mlas_qnbit.h" #include "mlasi.h" -// -// Kernel implementation template declarations -// - -/** - * @brief Multiply float matrix A with quantized n-bit integer matrix B. - * B is block quantized and column major. - * This kernel handles the special case where M, the number of rows of A and C, is 1. - * - * @tparam BlkBitWidth Bit width of each value in a block. - * @tparam BlkLen Number of values in a block. - * @tparam KernelType Hardware-specific kernel type. - * - * @param A Supplies the A matrix. - * @param QuantBData Supplies the quantized B matrix block data. - * @param QuantBScale Supplies the quantized B matrix block scale values. - * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. - * @param[out] C Supplies the output C matrix. - * @param CountN Number of columns of B and C. - * @param CountK Number of columns of A and rows of B. - * @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix. - * @param Bias Bias vector of length N. - */ -template -MLAS_FORCEINLINE void -MlasSQNBitGemmM1Kernel( - const float* A, - const uint8_t* QuantBData, - const float* QuantBScale, - const uint8_t* QuantBZeroPoint, - float* C, - size_t CountN, - size_t CountK, - size_t BlockStrideQuantB, - const float* Bias -); - -/** - * @brief Dequantize B into the format expected by the Sgemm kernel. - * B is block quantized and column major. - * This is equivalent to dequantizing B and then running - * MlasSgemmCopyPackB. - * - * @tparam BlkBitWidth Bit width of each value in a block. - * @tparam BlkLen Number of values in a block. - * @tparam KernelType Hardware-specific kernel type. - * - * @param[out] FpData Supplies the output buffer for the dequantized B float data. - * @param QuantBData Supplies the quantized B matrix block data. - * @param QuantBScale Supplies the quantized B matrix block scale values. - * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. - * @param CountN Number of columns of B. - * @param CountK Number of rows of B. - * @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix. - */ -template -MLAS_FORCEINLINE void -MlasQNBitBlkDequantBForSgemm( - float* FpData, - const uint8_t* QuantBData, - const float* QuantBScale, - const uint8_t* QuantBZeroPoint, - size_t CountN, - size_t CountK, - size_t BlockStrideQuantB -); - -// -// MlasQNBitGemmOperation and helpers -// - constexpr MLAS_FORCEINLINE size_t MlasQNBitBlkDataSizeInBytes(size_t BlkBitWidth, size_t BlkLen) { @@ -119,169 +44,174 @@ MlasQNBitZeroPointsForBlksSizeInBytes(size_t BlkCount) } } -MLAS_FORCEINLINE void -MlasAddBiasForGemm(const float* Bias, float* C, size_t CountM, size_t CountN, size_t ldc) -{ - for (size_t m = 0; m < CountM; m++) { - const float* bias = Bias; - float* sum = C; - for (size_t n = 0; n < CountN; n += 4) { - if (CountN - n < 4) { - for (size_t nn = n; nn < CountN; nn++) { - *sum += *bias; - sum++; - bias++; - } - break; - } +// +// Quantized int8 block helpers. 
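//
// Layout note: per the helpers below, a quantized int8 block (e.g., a block of quantized A for
// CompInt8) is stored as a float scale followed by BlkLen int8 values:
//
//     | float scale | int8_t data[BlkLen] |
//
// so Q8BlkSize(BlkLen) == sizeof(float) + BlkLen * sizeof(int8_t), and Q8BlkData() points
// sizeof(float) bytes past the start of the block.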
+// - MLAS_FLOAT32X4 acc_x = MlasLoadFloat32x4(sum); - acc_x = MlasAddFloat32x4(acc_x, MlasLoadFloat32x4(bias)); - MlasStoreFloat32x4(sum, acc_x); - bias += 4; - sum += 4; - } - C += ldc; - } +MLAS_FORCEINLINE +const float& +Q8BlkScale(const std::byte* BlkPtr) +{ + return *reinterpret_cast(BlkPtr); } -template -MLAS_FORCEINLINE void MLASCALL -MlasSQNBitGemmOperation( - const size_t K, - const MLAS_SQNBIT_GEMM_DATA_PARAMS* const DataParams, - const size_t RangeStartM, - const size_t RangeCountM, - const size_t RangeStartN, - const size_t RangeCountN -) +MLAS_FORCEINLINE +float& +Q8BlkScale(std::byte* BlkPtr) { - const size_t lda = DataParams->lda; - const size_t ldc = DataParams->ldc; - - const size_t k_blks = MlasDivRoundup(K, BlkLen); - const size_t ldb = k_blks * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - const size_t k_blks_zp_bytes = MlasQNBitZeroPointsForBlksSizeInBytes(k_blks); - - const float* A = DataParams->A + RangeStartM * lda; - - const uint8_t* QuantBData = static_cast(DataParams->QuantBData) + RangeStartN * ldb; - const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks; - const uint8_t* QuantBZeroPoint = - (DataParams->QuantBZeroPoint == nullptr) - ? nullptr - : static_cast(DataParams->QuantBZeroPoint) + RangeStartN * k_blks_zp_bytes; - - float* C = DataParams->C + RangeStartM * ldc + RangeStartN; - - const float* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias + RangeStartN; - - if (RangeCountM == 1) { - size_t CountN; - for (size_t n = 0; n < RangeCountN; n += CountN) { - CountN = std::min(RangeCountN - n, size_t{128}); - - const float* a_row = A; - const uint8_t* b_col = QuantBData + n * ldb; - const float* b_col_scale = QuantBScale + n * k_blks; - const uint8_t* b_col_zp = - (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; - float* c_blk = C + n; - const float* bias = (Bias == nullptr) ? nullptr : Bias + n; - - MlasSQNBitGemmM1Kernel( - a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias - ); - - if (DataParams->PostProcessor != nullptr) { - DataParams->PostProcessor->Process( - DataParams->C, RangeStartM, RangeStartN + n, - RangeCountM, CountN, ldc - ); - } - } - return; - } - - constexpr size_t StrideN = 32; - size_t bufsize = k_blks * BlkLen * StrideN * sizeof(float); - MlasThreadedBufAlloc(bufsize); - auto* dequant_b = reinterpret_cast(ThreadedBufHolder.get()); - // - // Step through each slice of matrix B along the N dimension. - // - - size_t CountN; - for (size_t n = 0; n < RangeCountN; n += CountN) { - CountN = std::min(RangeCountN - n, StrideN); - - // - // Step through each slice of matrix A along the M dimension. - // - const float* a_row = A; - const uint8_t* b_col = QuantBData + n * ldb; - const float* b_col_scale = QuantBScale + n * k_blks; - const uint8_t* b_col_zp = - (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; - float* c_blk = C + n; - const float* bias = (Bias == nullptr) ? 
nullptr : Bias + n; + return *reinterpret_cast(BlkPtr); +} - MlasQNBitBlkDequantBForSgemm( - dequant_b, b_col, b_col_scale, b_col_zp, CountN, K, k_blks - ); +MLAS_FORCEINLINE +const int8_t* +Q8BlkData(const std::byte* BlkPtr) +{ + return reinterpret_cast(BlkPtr + sizeof(float)); +} - size_t RowsRemaining = RangeCountM; - while (RowsRemaining > 0) { -#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_LARCH64) - auto RowsHandled = GetMlasPlatform().GemmFloatKernel( - a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f, true - ); -#else - auto RowsHandled = MlasSgemmKernelZero(a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f); -#endif +MLAS_FORCEINLINE +int8_t* +Q8BlkData(std::byte* BlkPtr) +{ + return reinterpret_cast(BlkPtr + sizeof(float)); +} - if (bias) { - MlasAddBiasForGemm(bias, c_blk, RowsHandled, CountN, ldc); - } - if (DataParams->PostProcessor != nullptr) { - DataParams->PostProcessor->Process( - DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN, - RowsHandled, CountN, ldc - ); - } +MLAS_FORCEINLINE +constexpr size_t +Q8BlkSize(size_t BlkLen) +{ + const size_t BlkSize = sizeof(float) + BlkLen * sizeof(int8_t); + // Currently, the strictest alignment requirement of a block is for a float. + // Ensure contiguous blocks are suitably aligned. + assert(BlkSize % alignof(float) == 0); + return BlkSize; +} - c_blk += ldc * RowsHandled; - a_row += lda * RowsHandled; - RowsRemaining -= RowsHandled; - } - } +MLAS_FORCEINLINE +constexpr size_t +Q8BlkAlignment() +{ + return alignof(float); } // // Kernel dispatch structure. // -typedef void(MLASCALL MLAS_SQNBIT_GEMM_OPERATION)( - size_t K, - const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, - size_t RangeStartM, - size_t RangeCountM, - size_t RangeStartN, - size_t RangeCountN -); +struct MLAS_SQNBIT_GEMM_DISPATCH { + // + // CompFp32 kernel function prototypes. + // + + /** + * @brief Multiply float matrix A with quantized 4-bit integer matrix B. + * B is block quantized and column major. + * This kernel handles the special case where M, the number of rows of A and C, is 1. + * + * @param BlkLen Number of values in a block. + * @param A Supplies the A matrix. + * @param QuantBData Supplies the quantized B matrix block data. + * @param QuantBScale Supplies the quantized B matrix block scale values. + * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. + * @param[out] C Supplies the output C matrix. + * @param CountN Number of columns of B and C. + * @param CountK Number of columns of A and rows of B. + * @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix. + * @param Bias Bias vector of length N. + */ + typedef void(SQ4BitGemmM1Kernel_CompFp32_Fn)( + size_t BlkLen, + const float* A, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias + ); + + SQ4BitGemmM1Kernel_CompFp32_Fn* SQ4BitGemmM1Kernel_CompFp32 = nullptr; + + /** + * @brief Dequantize B into the format expected by the Sgemm kernel. + * B is a quantized 4-bit integer matrix that is block quantized and column major. + * This is equivalent to dequantizing B and then running MlasSgemmCopyPackB. + * + * @param BlkLen Number of values in a block. + * @param[out] FpData Supplies the output buffer for the dequantized B float data. + * @param QuantBData Supplies the quantized B matrix block data. 
+ * @param QuantBScale Supplies the quantized B matrix block scale values. + * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. + * @param CountN Number of columns of B. + * @param CountK Number of rows of B. + * @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix. + */ + typedef void(Q4BitBlkDequantBForSgemm_CompFp32_Fn)( + size_t BlkLen, + float* FpData, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB + ); + + Q4BitBlkDequantBForSgemm_CompFp32_Fn* Q4BitBlkDequantBForSgemm_CompFp32 = nullptr; -enum QuantVariant { - QuantVariant_BitWidth4_BlockSize16, - QuantVariant_BitWidth4_BlockSize32, - QuantVariant_BitWidth4_BlockSize64, - QuantVariant_BitWidth4_BlockSize128, - QuantVariant_BitWidth4_BlockSize256, - QuantVariantCount, // Keep this element last and ensure that its value is the number of other QuantVariant values. - // Its value is used as an array size. -}; + // + // CompInt8 kernel function prototypes. + // -struct MLAS_SQNBIT_GEMM_DISPATCH { - MLAS_SQNBIT_GEMM_OPERATION* Operations[QuantVariantCount] = { - // Initialized to nullptrs. Overwrite in hardware-specific kernel implementation. - }; + /** + * @brief Multiply quantized 8-bit integer matrix A with quantized 4-bit integer matrix B. + * A and B are block quantized and B is column major. + * This kernel handles the special case where M, the number of rows of A and C, is 1. + * + * @param BlkLen Number of values in a block. + * @param QuantA Supplies the quantized A matrix. + Binary data containing block quantized int8 data and scale values. + * @param QuantBData Supplies the quantized B matrix block data. + * @param QuantBScale Supplies the quantized B matrix block scale values. + * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. + * @param[out] C Supplies the output C matrix. + * @param CountN Number of columns of B and C. + * @param CountK Number of columns of A and rows of B. + * @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix. + * @param Bias Bias vector of length N. + */ + typedef void(SQ4BitGemmM1Kernel_CompInt8_Fn)( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias + ); + + SQ4BitGemmM1Kernel_CompInt8_Fn* SQ4BitGemmM1Kernel_CompInt8 = nullptr; + + /** + * @brief Block quantize values from one row of matrix A from floats to quantized 8-bit integers. + * + * @param BlkLen Number of values in a block. + * @param A Supplies the A matrix. + * @param CountK Number of columns of A. + * @param[out] QuantA Supplies the output quantized A matrix. + * Binary data containing block quantized int8 data and scale values. + */ + typedef void(QuantizeARow_CompInt8_Fn)( + size_t BlkLen, + const float* A, + size_t CountK, + std::byte* QuantA + ); + + QuantizeARow_CompInt8_Fn* QuantizeARow_CompInt8 = nullptr; }; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp index 63afe57dd9137..69fd427fa574a 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp @@ -23,12 +23,6 @@ Module Name: #include #include -// -// Hardware-specific kernel type. 
-// -struct MLAS_SQNBIT_GEMM_KERNEL_NEON { -}; - namespace { @@ -70,7 +64,7 @@ FoldAccumulators(float32x4_t a0, float32x4_t a1, float32x4_t a2, float32x4_t a3) template MLAS_FORCEINLINE void -LoadData(const float* src, size_t count, float32x4_t (& dst)[Capacity / 4]) +LoadFloatData(const float* src, size_t count, float32x4_t (&dst)[Capacity / 4]) { static_assert(Capacity % 4 == 0, "Capacity must be divisible by 4."); @@ -101,13 +95,14 @@ LoadData(const float* src, size_t count, float32x4_t (& dst)[Capacity / 4]) } } -template +template MLAS_FORCEINLINE void -ComputeDotProducts( +ComputeDotProducts_BlkBitWidth4_CompFp32( + size_t BlkLen, const float* ARowPtr, - const uint8_t* QuantBDataColPtr, + const std::byte* QuantBDataColPtr, const float* QuantBScaleColPtr, - const uint8_t* QuantBZeroPointColPtr, + const std::byte* QuantBZeroPointColPtr, float* SumPtr, size_t CountK, size_t StrideQuantBData, @@ -116,8 +111,13 @@ ComputeDotProducts( const float* BiasPtr ) { + constexpr size_t BlkBitWidth = 4; + static_assert(NCols == 1 || NCols == 4, "NCols must be 1 or 4"); + constexpr size_t SubBlkLen = 16; // number of block elements to process in a sub-block iteration + assert(BlkLen % SubBlkLen == 0); + const uint8x8_t LowMask = vdup_n_u8(0x0F); // Manual conversion to float takes place in two steps: @@ -135,7 +135,7 @@ ComputeDotProducts( float32x4_t acc[NCols]{}; - const uint8_t* QuantBData = QuantBDataColPtr; + const std::byte* QuantBData = QuantBDataColPtr; const float* QuantBScale = QuantBScaleColPtr; size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer @@ -150,10 +150,12 @@ ComputeDotProducts( float offset[NCols]; // Includes zero point and float conversion offset of 16. if (QuantBZeroPointColPtr != nullptr) { UnrolledLoop([&](size_t i) { - const uint8_t zp_packed = + const std::byte zp_packed = QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; - const uint8_t zp = ((QuantBZeroPointIdx & 1) == 1) ? (zp_packed >> 4) : (zp_packed & 0x0F); - offset[i] = 16.0f + zp; + const std::byte zp = ((QuantBZeroPointIdx & 1) == 1) + ? 
(zp_packed >> 4) + : (zp_packed & std::byte{0x0F}); + offset[i] = 16.0f + std::to_integer(zp); }); } else { UnrolledLoop([&](size_t i) { @@ -162,33 +164,27 @@ ComputeDotProducts( }); } - constexpr size_t SubBlkLen = 16; // number of block elements to process in one iteration - for (size_t k_idx_in_blk = 0; k_idx_in_blk < k_blk_len; k_idx_in_blk += SubBlkLen) { // load A row vector elements // load `SubBlkLen` elements from A, padded with 0's if there aren't enough const size_t k_subblk_len = std::min(k_blk_len - k_idx_in_blk, SubBlkLen); float32x4_t av[4]{}; - LoadData(ARowPtr + k + k_idx_in_blk, k_subblk_len, av); + LoadFloatData(ARowPtr + k + k_idx_in_blk, k_subblk_len, av); // load B column vectors uint8x8_t bv_packed[NCols]; + const size_t b_data_block_offset = k_idx_in_blk * BlkBitWidth / 8; UnrolledLoop([&](size_t i) { - const size_t b_data_block_offset = k_idx_in_blk * BlkBitWidth / 8; - bv_packed[i] = vld1_u8(QuantBData + i * StrideQuantBData + b_data_block_offset); - }); - - uint8x8_t bv_u8_unzipped[NCols][2]; - UnrolledLoop([&](size_t i) { - bv_u8_unzipped[i][0] = vand_u8(bv_packed[i], LowMask); - bv_u8_unzipped[i][1] = vand_u8(vshr_n_u8(bv_packed[i], 4), LowMask); + bv_packed[i] = vld1_u8( + reinterpret_cast(QuantBData) + i * StrideQuantBData + b_data_block_offset + ); }); uint8x8_t bv_u8[NCols][2]; UnrolledLoop([&](size_t i) { - bv_u8[i][0] = vzip1_u8(bv_u8_unzipped[i][0], bv_u8_unzipped[i][1]); - bv_u8[i][1] = vzip2_u8(bv_u8_unzipped[i][0], bv_u8_unzipped[i][1]); + bv_u8[i][0] = vand_u8(bv_packed[i], LowMask); + bv_u8[i][1] = vshr_n_u8(bv_packed[i], 4); }); // dequantize B @@ -262,19 +258,13 @@ ComputeDotProducts( } } -} // namespace - -// -// MlasSQNBitGemmKernel and helpers. -// - -template MLAS_FORCEINLINE void -MlasSQNBitGemmM1KernelNeon( +SQ4BitGemmM1Kernel_CompFp32( + size_t BlkLen, const float* A, - const uint8_t* QuantBData, + const std::byte* QuantBData, const float* QuantBScale, - const uint8_t* QuantBZeroPoint, + const std::byte* QuantBZeroPoint, float* C, size_t CountN, size_t CountK, @@ -282,6 +272,7 @@ MlasSQNBitGemmM1KernelNeon( const float* Bias ) { + constexpr size_t BlkBitWidth = 4; constexpr size_t NCols = 4; const float* ARowPtr = A; @@ -295,16 +286,17 @@ MlasSQNBitGemmM1KernelNeon( const float* BiasPtr = Bias; - const uint8_t* QuantBDataColPtr = QuantBData; + const std::byte* QuantBDataColPtr = QuantBData; const float* QuantBScaleColPtr = QuantBScale; - const uint8_t* QuantBZeroPointColPtr = QuantBZeroPoint; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; float* SumPtr = CRowPtr; int64_t nblk = static_cast(CountN) - NCols; while (nblk >= 0) { - ComputeDotProducts( + ComputeDotProducts_BlkBitWidth4_CompFp32( + BlkLen, ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, BiasPtr @@ -327,7 +319,8 @@ MlasSQNBitGemmM1KernelNeon( // left over columns less than `NCols`? 
nblk += NCols; for (int64_t n = 0; n < nblk; ++n) { - ComputeDotProducts( + ComputeDotProducts_BlkBitWidth4_CompFp32<1>( + BlkLen, ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, BiasPtr @@ -346,59 +339,26 @@ MlasSQNBitGemmM1KernelNeon( } } -#define SPECIALIZE_SQNBIT_GEMM_M1_KERNEL(BlkBitWidth, BlkLen) \ - template <> \ - MLAS_FORCEINLINE void \ - MlasSQNBitGemmM1Kernel( \ - const float* A, \ - const uint8_t* QuantBData, \ - const float* QuantBScale, \ - const uint8_t* QuantBZeroPoint, \ - float* C, \ - size_t CountN, \ - size_t CountK, \ - size_t BlockStrideQuantB, \ - const float* Bias \ - ) \ - { \ - return MlasSQNBitGemmM1KernelNeon( \ - A, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, CountK, \ - BlockStrideQuantB, Bias \ - ); \ - } - -SPECIALIZE_SQNBIT_GEMM_M1_KERNEL(4, 16) -SPECIALIZE_SQNBIT_GEMM_M1_KERNEL(4, 32) -SPECIALIZE_SQNBIT_GEMM_M1_KERNEL(4, 64) -SPECIALIZE_SQNBIT_GEMM_M1_KERNEL(4, 128) -SPECIALIZE_SQNBIT_GEMM_M1_KERNEL(4, 256) - -#undef SPECIALIZE_SQNBIT_GEMM_M1_KERNEL - -// -// MlasQNBitBlkDequantBForSgemm and helpers. -// - -template MLAS_FORCEINLINE void -MlasQNBitBlkDequantBForSgemmNeon( +Q4BitBlkDequantBForSgemm_CompFp32( + size_t BlkLen, float* FpData, - const uint8_t* QuantBData, + const std::byte* QuantBData, const float* QuantBScale, - const uint8_t* QuantBZeroPoint, + const std::byte* QuantBZeroPoint, size_t CountN, size_t CountK, size_t BlockStrideQuantB ) { auto impl0_reference = [&]() { - static_assert(BlkBitWidth == 4); + constexpr size_t BlkBitWidth = 4; float* Dst = FpData; - const uint8_t* QuantBDataCol = QuantBData; + const std::byte* QuantBDataCol = QuantBData; const float* QuantBScaleCol = QuantBScale; - const uint8_t* QuantBZeroPointCol = QuantBZeroPoint; + const std::byte* QuantBZeroPointCol = QuantBZeroPoint; for (size_t n = 0; n < CountN; n += 16) { const size_t nnlen = std::min(CountN - n, size_t{16}); @@ -407,20 +367,26 @@ MlasQNBitBlkDequantBForSgemmNeon( for (size_t k = 0, k_blk_idx = 0; k < CountK; k += BlkLen, k_blk_idx += 1) { const size_t kklen = std::min(CountK - k, BlkLen); - const uint8_t* b_data = + const std::byte* b_data = QuantBDataCol + k_blk_idx * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); const float b_s = QuantBScaleCol[k_blk_idx]; const uint8_t b_z = (QuantBZeroPointCol != nullptr) ? ((k_blk_idx & 1) == 1) - ? QuantBZeroPointCol[k_blk_idx / 2] >> 4 - : QuantBZeroPointCol[k_blk_idx / 2] & 0x0F + ? std::to_integer(QuantBZeroPointCol[k_blk_idx / 2] >> 4) + : std::to_integer(QuantBZeroPointCol[k_blk_idx / 2] & std::byte{0x0F}) : 8; for (size_t kk = 0; kk < kklen; ++kk) { - const uint8_t b_packed = b_data[kk / 2]; - const uint8_t b_byte = ((kk & 1) == 1) ? b_packed >> 4 : b_packed & 0x0F; - const float b_value = (b_byte - b_z) * b_s; + const size_t packed_idx = kk % 16; + + const bool is_low_half = packed_idx < 8; + const size_t packed_byte_idx = packed_idx % 8; + const size_t packed_range_offset = (kk / 16) * 8; + + const std::byte b_packed = b_data[packed_range_offset + packed_byte_idx]; + const std::byte b_byte = is_low_half ? 
(b_packed & std::byte{0x0F}) : (b_packed >> 4); + const float b_value = (std::to_integer(b_byte) - b_z) * b_s; Dst[(k + kk) * 16 + nn] = b_value; } @@ -448,31 +414,332 @@ MlasQNBitBlkDequantBForSgemmNeon( impl0_reference(); } -#define SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(BlkBitWidth, BlkLen) \ - template <> \ - MLAS_FORCEINLINE void \ - MlasQNBitBlkDequantBForSgemm( \ - float* FpData, \ - const uint8_t* QuantBData, \ - const float* QuantBScale, \ - const uint8_t* QuantBZeroPoint, \ - size_t CountN, \ - size_t CountK, \ - size_t BlockStrideQuantB \ - ) \ - { \ - MlasQNBitBlkDequantBForSgemmNeon( \ - FpData, QuantBData, QuantBScale, QuantBZeroPoint, CountN, CountK, BlockStrideQuantB \ - ); \ +// +// CompInt8 kernel implementation and related helpers +// + +template +MLAS_FORCEINLINE void +QuantizeBlock( + size_t BlkLen, + const float* A, + size_t ElementCount, + std::byte* QuantA +) +{ + static_assert(SubBlkLen >= 16 && SubBlkLen % 16 == 0); + + assert(BlkLen % SubBlkLen == 0); + + constexpr size_t VectorCount = SubBlkLen / 4; + + // + // Scan block values first to determine scale. + // + + float amax = 0.0f; // max of absolute values of A block + + size_t k; + for (k = 0; k < ElementCount; k += SubBlkLen) { + const size_t SubBlkElementCount = std::min(ElementCount - k, SubBlkLen); + + float32x4_t a[VectorCount]{}; + LoadFloatData(A + k, SubBlkElementCount, a); + + float32x4_t abs_a[VectorCount]; + UnrolledLoop([&](size_t i) { + abs_a[i] = vabsq_f32(a[i]); + }); + + // find amax of SubBlkLen elements + for (size_t interval = VectorCount / 2; interval > 0; interval /= 2) { + for (size_t i = 0; i < interval; ++i) { + abs_a[i] = vmaxq_f32(abs_a[i], abs_a[i + interval]); + } + } + + // update existing amax + amax = std::max(amax, vmaxvq_f32(abs_a[0])); + } + + constexpr float range_max = (1 << 7) - 1; + const float scale = amax / range_max; + const float scale_reciprocal = scale != 0.0f ? 1.0f / scale : 0.0f; + + Q8BlkScale(QuantA) = scale; + + // + // Compute quantized block values. + // + + int8_t* QuantAData = Q8BlkData(QuantA); + + for (k = 0; k < ElementCount; k += SubBlkLen) { + const size_t SubBlkElementCount = std::min(ElementCount - k, SubBlkLen); + + float32x4_t a[VectorCount]{}; + LoadFloatData(A + k, SubBlkElementCount, a); + + UnrolledLoop([&](size_t i) { + a[i] = vmulq_n_f32(a[i], scale_reciprocal); + }); + + int32x4_t a_s32[VectorCount]; + UnrolledLoop([&](size_t i) { + a_s32[i] = vcvtaq_s32_f32(a[i]); + }); + + UnrolledLoop([&](size_t i) { + QuantAData[k + i * 4 + 0] = static_cast(vgetq_lane_s32(a_s32[i], 0)); + QuantAData[k + i * 4 + 1] = static_cast(vgetq_lane_s32(a_s32[i], 1)); + QuantAData[k + i * 4 + 2] = static_cast(vgetq_lane_s32(a_s32[i], 2)); + QuantAData[k + i * 4 + 3] = static_cast(vgetq_lane_s32(a_s32[i], 3)); + }); + } + + // + // Zero out any remaining sub-block elements. 
+ // + + for (; k < BlkLen; k += SubBlkLen) { + const int8x16_t Zeros = vdupq_n_s8(0); + UnrolledLoop([&](size_t i) { + vst1q_s8(QuantAData + k + i * 16, Zeros); + }); + } +} + +void MLASCALL +QuantizeARow_CompInt8( + size_t BlkLen, + const float* A, + size_t CountK, + std::byte* QuantA +) +{ + const float* ADataBlkPtr = A; + std::byte* QuantABlkPtr = QuantA; + + for (size_t k = 0; k < CountK; k += BlkLen) { + const size_t k_blk_len = std::min(CountK - k, BlkLen); + + QuantizeBlock<16>(BlkLen, ADataBlkPtr, k_blk_len, QuantABlkPtr); + + ADataBlkPtr += BlkLen; + QuantABlkPtr += Q8BlkSize(BlkLen); + } +} + +template +MLAS_FORCEINLINE void +ComputeDotProducts_BlkBitWidth4_CompInt8( + size_t BlkLen, + const std::byte* QuantARowPtr, + const std::byte* QuantBDataColPtr, + const float* QuantBScaleColPtr, + const std::byte* QuantBZeroPointColPtr, + float* SumPtr, + size_t CountK, + size_t StrideQuantBData, + size_t StrideQuantBScale, + size_t StrideQuantBZeroPoint, + const float* BiasPtr +) +{ + static_assert(NCols == 1 || NCols == 4, "NCols must be 1 or 4"); + + constexpr size_t BlkBitWidth = 4; + + constexpr size_t SubBlkLen = 16; // number of block elements to process in a sub-block iteration + assert(BlkLen % SubBlkLen == 0); + + const uint8x8_t LowMask = vdup_n_u8(0x0F); + + const std::byte* QuantA = QuantARowPtr; + + const std::byte* QuantBData = QuantBDataColPtr; + const float* QuantBScale = QuantBScaleColPtr; + size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + + float32x4_t acc[NCols]{}; + + for (size_t k = 0; k < CountK; k += BlkLen) { + const size_t k_blk_len = std::min(CountK - k, BlkLen); + + const float a_scale = Q8BlkScale(QuantA); + const int8_t* a_data = Q8BlkData(QuantA); + + float b_scale[NCols]; + UnrolledLoop([&](size_t i) { b_scale[i] = QuantBScale[i * StrideQuantBScale]; }); + + int8_t b_zp[NCols]; + if (QuantBZeroPointColPtr != nullptr) { + UnrolledLoop([&](size_t i) { + const std::byte zp_packed = + QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; + b_zp[i] = ((QuantBZeroPointIdx & 1) == 1) + ? 
std::to_integer(zp_packed >> 4) + : std::to_integer(zp_packed & std::byte{0x0F}); + }); + } else { + UnrolledLoop([&](size_t i) { + b_zp[i] = 8; + }); + } + + for (size_t k_idx_in_blk = 0; k_idx_in_blk < k_blk_len; k_idx_in_blk += SubBlkLen) { + // load A row vector + int8x16_t av = vld1q_s8(a_data + k_idx_in_blk); + + // load B column vectors + uint8x8_t bv_packed[NCols]; + const size_t b_data_block_offset = k_idx_in_blk * BlkBitWidth / 8; + UnrolledLoop([&](size_t i) { + bv_packed[i] = vld1_u8( + reinterpret_cast(QuantBData) + i * StrideQuantBData + b_data_block_offset + ); + }); + + int8x16_t bv[NCols]; + UnrolledLoop([&](size_t i) { + const int8x8_t lo = vreinterpret_s8_u8(vand_u8(bv_packed[i], LowMask)); + const int8x8_t hi = vreinterpret_s8_u8(vshr_n_u8(bv_packed[i], 4)); + bv[i] = vcombine_s8(lo, hi); + }); + + // subtract B zero point + UnrolledLoop([&](size_t i) { + const int8x16_t zp_v = vdupq_n_s8(b_zp[i]); + bv[i] = vsubq_s8(bv[i], zp_v); + }); + + // compute quantized dot product + int32x4_t dot[NCols]{}; + UnrolledLoop([&](size_t i) { + dot[i] = vdotq_s32(dot[i], av, bv[i]); + }); + + // convert dot product result to float + float32x4_t dot_f32[NCols]; + UnrolledLoop([&](size_t i) { + dot_f32[i] = vcvtq_f32_s32(dot[i]); + }); + + // multiply dot product result by scale and update accumulator + UnrolledLoop([&](size_t i) { + const float32x4_t scale_v = vdupq_n_f32(a_scale * b_scale[i]); + acc[i] = vfmaq_f32(acc[i], dot_f32[i], scale_v); + }); + } + + // increment pointers to next block + QuantA += Q8BlkSize(BlkLen); + QuantBData += MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + QuantBScale += 1; + QuantBZeroPointIdx += 1; + } + + if constexpr (NCols == 4) { + float32x4_t sum = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + + if (BiasPtr != nullptr) { + sum = vaddq_f32(sum, vld1q_f32(BiasPtr)); + } + + vst1q_f32(SumPtr, sum); + } else { + for (size_t i = 0; i < NCols; ++i) { + SumPtr[i] = vaddvq_f32(acc[i]); + if (BiasPtr != nullptr) { + SumPtr[i] += BiasPtr[i]; + } + } + } +} + +MLAS_FORCEINLINE +void +SQ4BitGemmM1Kernel_CompInt8( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias +) +{ + constexpr size_t BlkBitWidth = 4; + constexpr size_t NCols = 4; + + const std::byte* QuantARowPtr = QuantA; + float* CRowPtr = C; + + const size_t BlockCountK = BlockStrideQuantB; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + const float* BiasPtr = Bias; + + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + + float* SumPtr = CRowPtr; + + int64_t nblk = static_cast(CountN) - NCols; + + while (nblk >= 0) { + ComputeDotProducts_BlkBitWidth4_CompInt8( + BlkLen, + QuantARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + BiasPtr + ); + + // move to next `NCols` columns + + QuantBDataColPtr += NCols * StrideQuantBData; + QuantBScaleColPtr += NCols * StrideQuantBScale; + if (QuantBZeroPointColPtr != nullptr) { + QuantBZeroPointColPtr += NCols * StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 
NCols : 0; + SumPtr += NCols; + + nblk -= NCols; } -SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(4, 16) -SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(4, 32) -SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(4, 64) -SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(4, 128) -SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(4, 256) + // left over columns less than `NCols`? + nblk += NCols; + for (int64_t n = 0; n < nblk; ++n) { + ComputeDotProducts_BlkBitWidth4_CompInt8<1>( + BlkLen, + QuantARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + BiasPtr + ); -#undef SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM + // move to next column + + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if (QuantBZeroPointColPtr != nullptr) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } +} + +} // namespace // // Kernel dispatch structure definition. @@ -480,10 +747,11 @@ SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(4, 256) const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon = []() { MLAS_SQNBIT_GEMM_DISPATCH d; - d.Operations[QuantVariant_BitWidth4_BlockSize16] = MlasSQNBitGemmOperation<4, 16, MLAS_SQNBIT_GEMM_KERNEL_NEON>; - d.Operations[QuantVariant_BitWidth4_BlockSize32] = MlasSQNBitGemmOperation<4, 32, MLAS_SQNBIT_GEMM_KERNEL_NEON>; - d.Operations[QuantVariant_BitWidth4_BlockSize64] = MlasSQNBitGemmOperation<4, 64, MLAS_SQNBIT_GEMM_KERNEL_NEON>; - d.Operations[QuantVariant_BitWidth4_BlockSize128] = MlasSQNBitGemmOperation<4, 128, MLAS_SQNBIT_GEMM_KERNEL_NEON>; - d.Operations[QuantVariant_BitWidth4_BlockSize256] = MlasSQNBitGemmOperation<4, 256, MLAS_SQNBIT_GEMM_KERNEL_NEON>; + + d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32; + d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32; + d.SQ4BitGemmM1Kernel_CompInt8 = SQ4BitGemmM1Kernel_CompInt8; + d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8; + return d; }(); diff --git a/onnxruntime/test/mlas/bench/bench_q4gemm.cpp b/onnxruntime/test/mlas/bench/bench_q4gemm.cpp index 87e3601612761..61b3f57d8daac 100644 --- a/onnxruntime/test/mlas/bench/bench_q4gemm.cpp +++ b/onnxruntime/test/mlas/bench/bench_q4gemm.cpp @@ -109,12 +109,19 @@ void Q8Q4GEMM(benchmark::State& state, MLAS_BLK_QUANT_TYPE qtype) { static void GemmSizeProducts(benchmark::internal::Benchmark* b) { b->ArgNames(q4gemm_bench_arg_names); - ArgsProduct(b, {{1, 1024, 2048}, {4096}, {4096}, {8}}); + b->ArgsProduct({{1, 1024, 2048}, {4096}, {4096}, {8}}); } -BENCHMARK_CAPTURE(Q4GEMM, Q4Sym, BlkQ4Sym)->Apply(GemmSizeProducts)->UseRealTime(); -BENCHMARK_CAPTURE(Q4GEMM, Q4Zp8, BlkQ4Zp8)->Apply(GemmSizeProducts)->UseRealTime(); -BENCHMARK_CAPTURE(Q4GEMM, Q4Sym128, BlkQ4Sym)->Apply(GemmSizeProducts)->UseRealTime(); -BENCHMARK_CAPTURE(Q8Q4GEMM, Q4Sym, BlkQ4Sym)->Apply(GemmSizeProducts)->UseRealTime(); -BENCHMARK_CAPTURE(Q8Q4GEMM, Q4Zp8, BlkQ4Zp8)->Apply(GemmSizeProducts)->UseRealTime(); -BENCHMARK_CAPTURE(Q8Q4GEMM, Q4Sym128, BlkQ4Zp8)->Apply(GemmSizeProducts)->UseRealTime(); +[[maybe_unused]] static const bool benchmarks_registered = []() { + const bool is_q4gemm_supported = MlasQ4GemmPackBSize(BlkQ4Sym, 1, 1) > 0; + if (is_q4gemm_supported) { + BENCHMARK_CAPTURE(Q4GEMM, Q4Sym, BlkQ4Sym)->Apply(GemmSizeProducts)->UseRealTime(); + BENCHMARK_CAPTURE(Q4GEMM, Q4Zp8, BlkQ4Zp8)->Apply(GemmSizeProducts)->UseRealTime(); + BENCHMARK_CAPTURE(Q4GEMM, Q4Sym128, BlkQ4Sym)->Apply(GemmSizeProducts)->UseRealTime(); + 
BENCHMARK_CAPTURE(Q8Q4GEMM, Q4Sym, BlkQ4Sym)->Apply(GemmSizeProducts)->UseRealTime(); + BENCHMARK_CAPTURE(Q8Q4GEMM, Q4Zp8, BlkQ4Zp8)->Apply(GemmSizeProducts)->UseRealTime(); + BENCHMARK_CAPTURE(Q8Q4GEMM, Q4Sym128, BlkQ4Zp8)->Apply(GemmSizeProducts)->UseRealTime(); + return true; + } + return false; +}(); diff --git a/onnxruntime/test/mlas/bench/bench_sconv.cpp b/onnxruntime/test/mlas/bench/bench_sconv.cpp index 115641f6a6efb..39d135236b89c 100644 --- a/onnxruntime/test/mlas/bench/bench_sconv.cpp +++ b/onnxruntime/test/mlas/bench/bench_sconv.cpp @@ -224,8 +224,7 @@ BENCHMARK_CAPTURE(SCONV_NCHW, TeamsModel, "")->Apply(TeamsModel)->UseRealTime(); static void General_Conv2d(benchmark::internal::Benchmark* b) { b->ArgNames(ArgNamesForConv(2)); - ArgsProduct( - b, + b->ArgsProduct( {{2}, // Rank, {1}, // N {1, 2}, // Groups diff --git a/onnxruntime/test/mlas/bench/bench_sgemm.cpp b/onnxruntime/test/mlas/bench/bench_sgemm.cpp index e6e34bc88ad59..a94d33cd77f63 100644 --- a/onnxruntime/test/mlas/bench/bench_sgemm.cpp +++ b/onnxruntime/test/mlas/bench/bench_sgemm.cpp @@ -103,14 +103,14 @@ void SGEMM(benchmark::State& state, bool pack_b, bool trans_a, bool trans_b, flo static void GemmSizeWithOne(benchmark::internal::Benchmark* b) { b->ArgNames(sgemm_bench_arg_names); - ArgsProduct(b, {{1}, {63, 255, 1023}, {63, 255, 1023}}); - ArgsProduct(b, {{63, 255, 1023}, {1}, {63, 255, 1023}}); - ArgsProduct(b, {{63, 255, 1023}, {63, 255, 1023}, {1}}); + b->ArgsProduct({{1}, {63, 255, 1023}, {63, 255, 1023}}); + b->ArgsProduct({{63, 255, 1023}, {1}, {63, 255, 1023}}); + b->ArgsProduct({{63, 255, 1023}, {63, 255, 1023}, {1}}); } static void GemmSizeProducts(benchmark::internal::Benchmark* b) { b->ArgNames(sgemm_bench_arg_names); - ArgsProduct(b, {{63, 255, 1023}, {63, 255, 1023}, {63, 255, 1023}}); + b->ArgsProduct({{63, 255, 1023}, {63, 255, 1023}, {63, 255, 1023}}); } BENCHMARK_CAPTURE(SGEMM, NORMAL_NoTrans, false, false, false)->Apply(GemmSizeProducts)->UseRealTime(); @@ -128,7 +128,7 @@ BENCHMARK_CAPTURE(SGEMM, PACKB_TransA, true, true, false)->Apply(GemmSizeProduct static void GemmLLMSizeProducts(benchmark::internal::Benchmark* b) { b->ArgNames(sgemm_bench_arg_names); - ArgsProduct(b, {{1, 1024, 2048}, {4096, 11008}, {4096, 11008}}); + b->ArgsProduct({{1, 1024, 2048}, {4096, 11008}, {4096, 11008}}); } BENCHMARK_CAPTURE(SGEMM, LLM, false, false, true)->Apply(GemmLLMSizeProducts)->UseRealTime(); diff --git a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp index cf67ef6f82051..2a56d37b899f8 100644 --- a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp @@ -4,33 +4,36 @@ #include "mlas_q4.h" #include "mlas_qnbit.h" +#include #include +#include #include "benchmark/benchmark.h" #include "bench_util.h" #include "core/util/thread_utils.h" +#include "core/common/narrow.h" -template -void SQNBITGEMM(benchmark::State& state) { - if (state.range(0) <= 0) throw std::invalid_argument("M must greater than 0!"); - if (state.range(1) <= 0) throw std::invalid_argument("N must greater than 0!"); - if (state.range(2) <= 0) throw std::invalid_argument("K must greater than 0!"); - if (state.range(3) <= 0) throw std::invalid_argument("Threads must greater than 0!"); +using onnxruntime::narrow; - const size_t M = static_cast(state.range(0)); - const size_t N = static_cast(state.range(1)); - const size_t K = static_cast(state.range(2)); - const size_t threads = static_cast(state.range(3)); +template +void 
SQNBITGEMM(benchmark::State& state) { + const auto BlkLen = narrow(state.range(0)); + const auto M = narrow(state.range(1)); + const auto N = narrow(state.range(2)); + const auto K = narrow(state.range(3)); + const auto Threads = narrow(state.range(4)); + const auto Symmetric = narrow(state.range(5)); + const auto ComputeType = static_cast(state.range(6)); size_t QuantBDataSizeInBytes, QuantBScaleSize, QuantBZeroPointSizeInBytes; MlasBlockwiseQuantizedBufferSizes( - BlkBitWidth, BlkLen, /* columnwise */ true, + BlkBitWidth, static_cast(BlkLen), /* columnwise */ true, static_cast(K), static_cast(N), QuantBDataSizeInBytes, QuantBScaleSize, &QuantBZeroPointSizeInBytes); OrtThreadPoolParams tpo; - tpo.thread_pool_size = static_cast(threads); + tpo.thread_pool_size = static_cast(Threads); tpo.auto_set_affinity = true; std::unique_ptr tp( @@ -47,14 +50,29 @@ void SQNBITGEMM(benchmark::State& state) { MlasQuantizeBlockwise(QuantBData.data(), QuantBScale.data(), Symmetric ? nullptr : QuantBZeroPoint.data(), - B.data(), BlkLen, /* columnwise */ true, + B.data(), static_cast(BlkLen), /* columnwise */ true, static_cast(K), static_cast(N), static_cast(N), tp.get()); + std::unique_ptr Workspace; + if (const auto WorkspaceSize = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType); + WorkspaceSize > 0) { + Workspace = std::make_unique(WorkspaceSize); + } + + std::unique_ptr PackedQuantBData; + if (const auto PackedQuantBDataSize = MlasSQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen); + PackedQuantBDataSize > 0) { + PackedQuantBData = std::make_unique(PackedQuantBDataSize); + MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, QuantBData.data(), PackedQuantBData.get(), tp.get()); + } + MLAS_SQNBIT_GEMM_DATA_PARAMS params{}; params.A = A.data(); params.lda = K; - params.QuantBData = QuantBData.data(); + params.QuantBData = PackedQuantBData != nullptr + ? static_cast(PackedQuantBData.get()) + : static_cast(QuantBData.data()); params.QuantBScale = QuantBScale.data(); params.QuantBZeroPoint = Symmetric ? 
nullptr : QuantBZeroPoint.data(); params.Bias = nullptr; @@ -62,30 +80,41 @@ void SQNBITGEMM(benchmark::State& state) { params.ldc = N; // warm up run - MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ¶ms, tp.get()); + MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, ¶ms, Workspace.get(), tp.get()); for (auto _ : state) { - MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ¶ms, tp.get()); + MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, ¶ms, Workspace.get(), tp.get()); } } -static void GemmSizeProducts(benchmark::internal::Benchmark* b) { - b->ArgNames({"M", "N", "K", "Threads"}); - ArgsProduct(b, {{1, 1024, 2048}, {4096, 11008}, {4096, 11008}, {8}}); +static void SQNBitGemmArgs(benchmark::internal::Benchmark* b) { + b->ArgNames({"BlkLen", "M", "N", "K", "Threads", "Symmetric", "ComputeType"}); + + ArgsProductWithFilter(b, + + {{16, 32, 64, 128, 256}, // BlkLen + {1, 1024, 2048}, // M + {4096, 11008}, // N + {4096, 11008}, // K + {8}, // Threads + {int64_t{false}, int64_t{true}}, // Symmetric + {int64_t{CompFp32}, int64_t{CompInt8}}}, // ComputeType + + [](const std::vector& args) { + return MlasIsSQNBitGemmAvailable( + // M, N, K + narrow(args[1]), narrow(args[2]), narrow(args[3]), + // BlkBitWidth, BlkLen + 4, narrow(args[0]), + // ComputeType + static_cast(args[6])); + }); } -BENCHMARK(SQNBITGEMM<4, 16, false>)->Apply(GemmSizeProducts)->UseRealTime(); -BENCHMARK(SQNBITGEMM<4, 16, true>)->Apply(GemmSizeProducts)->UseRealTime(); -BENCHMARK(SQNBITGEMM<4, 32, false>)->Apply(GemmSizeProducts)->UseRealTime(); -BENCHMARK(SQNBITGEMM<4, 32, true>)->Apply(GemmSizeProducts)->UseRealTime(); -BENCHMARK(SQNBITGEMM<4, 64, false>)->Apply(GemmSizeProducts)->UseRealTime(); -BENCHMARK(SQNBITGEMM<4, 64, true>)->Apply(GemmSizeProducts)->UseRealTime(); -BENCHMARK(SQNBITGEMM<4, 128, false>)->Apply(GemmSizeProducts)->UseRealTime(); -BENCHMARK(SQNBITGEMM<4, 128, true>)->Apply(GemmSizeProducts)->UseRealTime(); -BENCHMARK(SQNBITGEMM<4, 256, false>)->Apply(GemmSizeProducts)->UseRealTime(); -BENCHMARK(SQNBITGEMM<4, 256, true>)->Apply(GemmSizeProducts)->UseRealTime(); - -#ifdef MLAS_JBLAS +BENCHMARK(SQNBITGEMM<4>)->Apply(SQNBitGemmArgs)->UseRealTime(); + +#if defined(MLAS_JBLAS) + void Q4GEMM_Jblas(benchmark::State& state, int block_size, bool is_asym, MLAS_SQNBIT_COMPUTE_TYPE cmp_type) { if (state.range(0) <= 0) throw std::invalid_argument("M must greater than 0!"); if (state.range(1) <= 0) throw std::invalid_argument("N must greater than 0!"); @@ -130,6 +159,11 @@ void Q4GEMM_Jblas(benchmark::State& state, int block_size, bool is_asym, MLAS_SQ } } +static void GemmSizeProducts(benchmark::internal::Benchmark* b) { + b->ArgNames({"M", "N", "K", "Threads"}); + b->ArgsProduct({{1, 1024, 2048}, {4096, 11008}, {4096, 11008}, {8}}); +} + BENCHMARK_CAPTURE(Q4GEMM_Jblas, Q4G32SymInt8, 32, false, CompInt8)->Apply(GemmSizeProducts)->UseRealTime(); BENCHMARK_CAPTURE(Q4GEMM_Jblas, Q4G128SymInt8, 128, false, CompInt8)->Apply(GemmSizeProducts)->UseRealTime(); BENCHMARK_CAPTURE(Q4GEMM_Jblas, Q4GPerNSymInt8, -1, false, CompInt8)->Apply(GemmSizeProducts)->UseRealTime(); @@ -137,4 +171,5 @@ BENCHMARK_CAPTURE(Q4GEMM_Jblas, Q4G32SymFp32, 32, false, CompFp32)->Apply(GemmSi BENCHMARK_CAPTURE(Q4GEMM_Jblas, Q4G128SymFp32, 128, false, CompFp32)->Apply(GemmSizeProducts)->UseRealTime(); BENCHMARK_CAPTURE(Q4GEMM_Jblas, Q4GPerNSymFp32, -1, false, CompFp32)->Apply(GemmSizeProducts)->UseRealTime(); BENCHMARK_CAPTURE(Q4GEMM_Jblas, Q4G32AsymFp32, 32, true, 
CompFp32)->Apply(GemmSizeProducts)->UseRealTime(); -#endif + +#endif // defined(MLAS_JBLAS) diff --git a/onnxruntime/test/mlas/bench/bench_util.cpp b/onnxruntime/test/mlas/bench/bench_util.cpp index b79cd3a2a40aa..d57564615b04e 100644 --- a/onnxruntime/test/mlas/bench/bench_util.cpp +++ b/onnxruntime/test/mlas/bench/bench_util.cpp @@ -23,10 +23,9 @@ std::vector RandomVectorUniform(std::vector shape, float min_val return RandomVectorUniform(static_cast(sz), min_value, max_value); } -// The Benchmark used here do not contains this as in newer version. -// Use the code from newer version. -void ArgsProduct(benchmark::internal::Benchmark* bench, - const std::vector>& arglists) { +void ArgsProductWithFilter(benchmark::internal::Benchmark* bench, + const std::vector>& arglists, + std::function& args)> include_filter) { std::vector indices(arglists.size(), 0); const std::size_t total = std::accumulate( std::begin(arglists), std::end(arglists), std::size_t{1}, @@ -39,7 +38,9 @@ void ArgsProduct(benchmark::internal::Benchmark* bench, for (std::size_t arg = 0; arg < arglists.size(); arg++) { args.push_back(arglists[arg][indices[arg]]); } - bench->Args(args); + if (include_filter(args)) { + bench->Args(args); + } args.clear(); std::size_t arg = 0; diff --git a/onnxruntime/test/mlas/bench/bench_util.h b/onnxruntime/test/mlas/bench/bench_util.h index a2b49e117da38..ee2ec42d0f755 100644 --- a/onnxruntime/test/mlas/bench/bench_util.h +++ b/onnxruntime/test/mlas/bench/bench_util.h @@ -5,10 +5,14 @@ #include +#include #include -void ArgsProduct(benchmark::internal::Benchmark* bench, - const std::vector>& arglists); +// Specifies benchmark arguments from the cartesian product of `arglists`, like Benchmark::ArgsProduct(). +// `include_filter` is called to determine whether a given set of arguments should be included. 
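// A hypothetical usage sketch (the benchmark name, argument lists, and filter below are
// illustrative only, not taken from this change):
//
//   static void ExampleGemmArgs(benchmark::internal::Benchmark* b) {
//     b->ArgNames({"M", "N"});
//     ArgsProductWithFilter(b,
//                           {{1, 64, 256}, {64, 256}},  // cartesian product of M x N
//                           [](const std::vector<int64_t>& args) {
//                             return args[0] <= args[1];  // keep only combinations with M <= N
//                           });
//   }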
+void ArgsProductWithFilter(benchmark::internal::Benchmark* bench, + const std::vector>& arglists, + std::function& args)> include_filter); template std::vector RandomVectorUniform( diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp index 6c97d60301573..4fb8ab41745d5 100644 --- a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp @@ -18,6 +18,17 @@ Module Name: #include "mlas_q4.h" #include "mlas_qnbit.h" +static constexpr const char* ComputeTypeName(MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType) { + switch (ComputeType) { + case CompFp32: + return "Fp32"; + case CompInt8: + return "Int8"; + default: + return "unknown"; + } +} + /** * @brief Test class for n-bit int block quantized GEMM * Note: only 2-D matmul supported for now @@ -26,12 +37,16 @@ template class MlasSQNBitGemmTest : public MlasTestBase { private: MatrixGuardBuffer BufferA; + MatrixGuardBuffer BufferQuantAData; + MatrixGuardBuffer BufferQuantAScale; MatrixGuardBuffer BufferB; MatrixGuardBuffer BufferQuantBData; + MatrixGuardBuffer BufferPackedQuantBData; MatrixGuardBuffer BufferQuantBZeroPoint; MatrixGuardBuffer BufferQuantBScale; MatrixGuardBuffer BufferDequantizedB; MatrixGuardBuffer BufferBias; + MatrixGuardBuffer BufferWorkspace; MatrixGuardBuffer BufferC; MatrixGuardBuffer BufferCReference; @@ -40,12 +55,15 @@ class MlasSQNBitGemmTest : public MlasTestBase { size_t K, const float* A, size_t lda, - const uint8_t* QuantBData, + const void* QuantBData, + const void* PackedQuantBData, const float* QuantBScale, - const uint8_t* QuantBZeroPoint, + const void* QuantBZeroPoint, const float* Bias, float* C, size_t ldc, + void* Workspace, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, MLAS_THREADPOOL* Threadpool) { MLAS_SQNBIT_GEMM_DATA_PARAMS params; params.A = A; @@ -53,23 +71,106 @@ class MlasSQNBitGemmTest : public MlasTestBase { params.Bias = Bias; params.C = C; params.ldc = ldc; - params.QuantBData = QuantBData; + params.QuantBData = PackedQuantBData != nullptr ? PackedQuantBData : QuantBData; params.QuantBScale = QuantBScale; params.QuantBZeroPoint = QuantBZeroPoint; params.PostProcessor = nullptr; - MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ¶ms, Threadpool); + MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, ¶ms, Workspace, Threadpool); + } + + void QuantizeA(size_t M, size_t K, const float* A, int8_t* QuantAData, float* QuantAScale) { + const size_t BlockCountK = (K + BlkLen - 1) / BlkLen; + const size_t lda = K; + for (size_t m = 0; m < M; ++m) { + for (size_t k = 0, k_blk = 0; k < K; k += BlkLen, ++k_blk) { + const size_t local_blk_len = std::min(K - k, BlkLen); + float blk_a[BlkLen]{}; + std::copy_n(A + m * lda + k, local_blk_len, blk_a); + + float amax = 0.0f; // max of absolute values of A block + for (size_t kk = 0; kk < local_blk_len; ++kk) { + float a = blk_a[kk]; + amax = std::max(amax, fabsf(a)); + } + + constexpr float range_max = (1 << 7) - 1; + const float scale = amax / range_max; + const float scale_reciprocal = scale != 0.0f ? 
1.0f / scale : 0.0f; + + QuantAScale[m * BlockCountK + k_blk] = scale; + + for (size_t kk = 0; kk < BlkLen; ++kk) { + const float q = roundf(blk_a[kk] * scale_reciprocal); + QuantAData[m * BlockCountK * BlkLen + k + kk] = + static_cast( + std::clamp(q, + static_cast(std::numeric_limits::min()), + static_cast(std::numeric_limits::max()))); + } + } + } + } + + void CallReferenceGemm_CompInt8(size_t M, + size_t N, + size_t K, + const float* A, + const uint8_t* QuantBData, + const float* QuantBScale, + const uint8_t* QuantBZeroPoint, + const float* Bias, + float* C) { + const size_t BlockCountK = (K + BlkLen - 1) / BlkLen; + + int8_t* QuantAData = BufferQuantAData.GetBuffer(M * BlockCountK * BlkLen); + float* QuantAScale = BufferQuantAScale.GetBuffer(M * BlockCountK); + QuantizeA(M, K, A, QuantAData, QuantAScale); + + for (size_t m = 0; m < M; ++m) { + for (size_t n = 0; n < N; ++n) { + float sum = Bias == nullptr ? 0.0f : Bias[n]; + for (size_t k = 0, k_blk = 0; k < K; k += BlkLen, ++k_blk) { + const size_t k_blk_len = std::min(K - k, BlkLen); + + const float a_scale = QuantAScale[m * BlockCountK + k_blk]; + + const float b_scale = QuantBScale[n * BlockCountK + k_blk]; + + static_assert(BlkBitWidth == 4, "only implemented for 4-bit quantized B"); + + uint8_t b_zp = 8; + if (QuantBZeroPoint != nullptr) { + const uint8_t b_zp_byte = QuantBZeroPoint[n * ((BlockCountK + 1) / 2) + k_blk / 2]; + b_zp = (k_blk & 1) ? (b_zp_byte >> 4) : (b_zp_byte & 0x0F); + } + + int32_t qsum = 0; + + for (size_t kk = 0; kk < k_blk_len; ++kk) { + const int8_t qa = QuantAData[m * BlockCountK * BlkLen + k + kk]; + const uint8_t qb_byte = QuantBData[(n * BlockCountK * BlkLen + k + kk) / 2]; + const int8_t qb = ((kk & 1) == 1 ? (qb_byte >> 4) : (qb_byte & 0x0F)) - b_zp; + qsum += qa * qb; + } + + sum += static_cast(qsum) * a_scale * b_scale; + } + + C[m * N + n] = sum; + } + } } - void CallReferenceGemm(size_t M, - size_t N, - size_t K, - const float* A, - const uint8_t* QuantBData, - const float* QuantBScale, - const uint8_t* QuantBZeroPoint, - const float* Bias, - float* C) { + void CallReferenceGemm_CompFp32(size_t M, + size_t N, + size_t K, + const float* A, + const uint8_t* QuantBData, + const float* QuantBScale, + const uint8_t* QuantBZeroPoint, + const float* Bias, + float* C) { float* DequantizedBData = BufferDequantizedB.GetBuffer(K * N); MlasDequantizeBlockwise( DequantizedBData, QuantBData, QuantBScale, QuantBZeroPoint, BlkLen, /* columnwise */ true, @@ -95,6 +196,7 @@ class MlasSQNBitGemmTest : public MlasTestBase { public: void Test(size_t M, size_t N, size_t K, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, bool WithBias, bool Symmetric, bool WithThreadpool) { MLAS_THREADPOOL* Threadpool = WithThreadpool ? 
GetMlasThreadPool() : nullptr; @@ -126,7 +228,7 @@ class MlasSQNBitGemmTest : public MlasTestBase { float* C = BufferC.GetBuffer(N * M, true); float* CReference = BufferCReference.GetBuffer(N * M, true); - // pack B + // quantize B uint8_t* QuantBData = nullptr; float* QuantBScale = nullptr; uint8_t* QuantBZeroPoint = nullptr; @@ -138,20 +240,48 @@ class MlasSQNBitGemmTest : public MlasTestBase { QuantBData = BufferQuantBData.GetBuffer(QuantBDataSizeInBytes); QuantBScale = BufferQuantBScale.GetBuffer(QuantBScaleSize); - if (Symmetric) { + if (!Symmetric) { QuantBZeroPoint = BufferQuantBZeroPoint.GetBuffer(QuantBZeroPointSizeInBytes); } - MlasQuantizeBlockwise(QuantBData, QuantBScale, QuantBZeroPoint, - B, BlkLen, - /* columnwise */ true, - static_cast(K), static_cast(N), - static_cast(N), - GetMlasThreadPool()); + MlasQuantizeBlockwise(QuantBData, QuantBScale, QuantBZeroPoint, + B, BlkLen, + /* columnwise */ true, + static_cast(K), static_cast(N), + static_cast(N), + GetMlasThreadPool()); } - CallGemm(M, N, K, A, /* lda */ K, QuantBData, QuantBScale, QuantBZeroPoint, Bias, C, /* ldc */ N, Threadpool); - CallReferenceGemm(M, N, K, A, QuantBData, QuantBScale, QuantBZeroPoint, Bias, CReference); + void* Workspace = nullptr; + if (const auto WorkspaceSize = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType); + WorkspaceSize > 0) { + Workspace = BufferWorkspace.GetBuffer(WorkspaceSize); + } + + void* PackedQuantBData = nullptr; + if (const auto PackedQuantBDataSize = MlasSQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen); + PackedQuantBDataSize > 0) { + PackedQuantBData = BufferPackedQuantBData.GetBuffer(PackedQuantBDataSize); + MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, QuantBData, PackedQuantBData, GetMlasThreadPool()); + } + + if (ComputeType == CompFp32) { + CallReferenceGemm_CompFp32(M, N, K, A, QuantBData, QuantBScale, QuantBZeroPoint, Bias, CReference); + } else if (ComputeType == CompInt8) { + CallReferenceGemm_CompInt8(M, N, K, A, QuantBData, QuantBScale, QuantBZeroPoint, Bias, CReference); + } else { + FAIL() << "Test is not implemented for compute type " + << ComputeType << " (" << ComputeTypeName(ComputeType) << ")"; + } + + CallGemm(M, N, K, + A, /* lda */ K, + QuantBData, PackedQuantBData, QuantBScale, QuantBZeroPoint, + Bias, + C, /* ldc */ N, + Workspace, + ComputeType, + Threadpool); size_t f = 0; for (size_t m = 0; m < M; m++) { @@ -179,74 +309,90 @@ template class SQNBitGemmShortExecuteTest : public MlasTestFixture> { public: explicit SQNBitGemmShortExecuteTest(size_t M, size_t N, size_t K, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, bool WithThreadpool, bool Symmetric, bool WithBias) - : M_(M), N_(N), K_(K), WithThreadpool_(WithThreadpool), Symmetric_(Symmetric), WithBias_(WithBias) { + : M_(M), + N_(N), + K_(K), + ComputeType_(ComputeType), + WithThreadpool_(WithThreadpool), + Symmetric_(Symmetric), + WithBias_(WithBias) { } void TestBody() override { MlasTestFixture>::mlas_tester->Test( - M_, N_, K_, WithThreadpool_, Symmetric_, WithBias_); + M_, N_, K_, ComputeType_, WithThreadpool_, Symmetric_, WithBias_); } static size_t RegisterSingleTest(size_t M, size_t N, size_t K, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, bool WithThreadpool, bool Symmetric, bool WithBias) { - std::stringstream ss; - ss << (WithThreadpool ? 
"SingleThread" : "Threaded") - << "/isSymmetric" << Symmetric - << "/M" << M << "xN" << N << "xK" << K - << "/hasBias" << WithBias; - auto test_name = ss.str(); - - testing::RegisterTest( - MlasSQNBitGemmTest::GetTestSuiteName(), - test_name.c_str(), - nullptr, - test_name.c_str(), - __FILE__, - __LINE__, - // Important to use the fixture type as the return type here. - [=]() -> MlasTestFixture>* { - return new SQNBitGemmShortExecuteTest( - M, N, K, WithThreadpool, Symmetric, WithBias); - }); - - return 1; + size_t tests_registered = 0; + + if (MlasIsSQNBitGemmAvailable(M, N, K, BlkBitWidth, BlkLen, ComputeType)) { + std::stringstream ss; + ss << (WithThreadpool ? "SingleThread" : "Threaded") + << "/isSymmetric" << Symmetric + << "/M" << M << "xN" << N << "xK" << K + << "/hasBias" << WithBias + << "/computeType" << ComputeTypeName(ComputeType); + auto test_name = ss.str(); + + testing::RegisterTest( + MlasSQNBitGemmTest::GetTestSuiteName(), + test_name.c_str(), + nullptr, + test_name.c_str(), + __FILE__, + __LINE__, + // Important to use the fixture type as the return type here. + [=]() -> MlasTestFixture>* { + return new SQNBitGemmShortExecuteTest( + M, N, K, ComputeType, WithThreadpool, Symmetric, WithBias); + }); + + tests_registered += 1; + } + + return tests_registered; } static size_t RegisterShortExecuteTests() { - size_t test_registered = 0; + size_t tests_registered = 0; - if (MlasIsSQNBitGemmAvailable(BlkBitWidth, BlkLen)) { + for (MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType : {CompFp32, CompInt8}) { for (bool WithThreadpool : {false, true}) { for (bool Symmetric : {false, true}) { for (size_t b = 1; b < 16; b++) { - test_registered += RegisterSingleTest(b, b, b, WithThreadpool, Symmetric, false); - test_registered += RegisterSingleTest(b, b, b, WithThreadpool, Symmetric, true); + tests_registered += RegisterSingleTest(b, b, b, ComputeType, WithThreadpool, Symmetric, false); + tests_registered += RegisterSingleTest(b, b, b, ComputeType, WithThreadpool, Symmetric, true); } for (size_t b = 16; b <= 256; b <<= 1) { - test_registered += RegisterSingleTest(b, b, b, WithThreadpool, Symmetric, false); - test_registered += RegisterSingleTest(b, b, b, WithThreadpool, Symmetric, true); + tests_registered += RegisterSingleTest(b, b, b, ComputeType, WithThreadpool, Symmetric, false); + tests_registered += RegisterSingleTest(b, b, b, ComputeType, WithThreadpool, Symmetric, true); } for (size_t b = 256; b < 320; b += 32) { - test_registered += RegisterSingleTest(b, b, b, WithThreadpool, Symmetric, true); + tests_registered += RegisterSingleTest(b, b, b, ComputeType, WithThreadpool, Symmetric, true); } for (size_t b = 1; b < 96; b++) { - test_registered += RegisterSingleTest(1, b, 32, WithThreadpool, Symmetric, false); - test_registered += RegisterSingleTest(1, 32, b, WithThreadpool, Symmetric, true); - test_registered += RegisterSingleTest(1, b, b, WithThreadpool, Symmetric, false); + tests_registered += RegisterSingleTest(1, b, 32, ComputeType, WithThreadpool, Symmetric, false); + tests_registered += RegisterSingleTest(1, 32, b, ComputeType, WithThreadpool, Symmetric, true); + tests_registered += RegisterSingleTest(1, b, b, ComputeType, WithThreadpool, Symmetric, false); } - test_registered += RegisterSingleTest(43, 500, 401, WithThreadpool, Symmetric, true); + tests_registered += RegisterSingleTest(43, 500, 401, ComputeType, WithThreadpool, Symmetric, true); - // test_registered += RegisterSingleTest(1001, 1027, 1031, WithThreadpool, Symmetric, false); + // tests_registered += 
RegisterSingleTest(1001, 1027, 1031, ComputeType, WithThreadpool, Symmetric, false); } } } - return test_registered; + return tests_registered; } private: size_t M_, N_, K_; + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType_; bool WithThreadpool_, Symmetric_, WithBias_; };
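The CompInt8 reference check above quantizes each block of A with a single scale (scale = amax / 127, q = round(a / scale)) and accumulates an integer dot product against the 4-bit B values before rescaling by both block scales. A minimal scalar sketch of that math; the helper names are illustrative, the test's actual reference code is QuantizeA() and CallReferenceGemm_CompInt8():

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>

// Quantize `len` floats with one scale for the whole block: scale = amax / 127, q = round(a / scale).
inline float QuantizeBlockInt8Ref(const float* a, size_t len, int8_t* q) {
    float amax = 0.0f;
    for (size_t i = 0; i < len; ++i) {
        amax = std::max(amax, std::fabs(a[i]));
    }
    const float scale = amax / 127.0f;
    const float scale_reciprocal = scale != 0.0f ? 1.0f / scale : 0.0f;
    for (size_t i = 0; i < len; ++i) {
        const float v = std::round(a[i] * scale_reciprocal);
        q[i] = static_cast<int8_t>(std::clamp(v, -128.0f, 127.0f));
    }
    return scale;
}

// Accumulate one block's contribution: (sum_k qa[k] * (qb[k] - b_zp)) * a_scale * b_scale,
// where qb[k] is the k-th 4-bit value of B and b_zp defaults to 8 for symmetric quantization.
inline float BlockDotInt8Q4Ref(const int8_t* qa, const uint8_t* qb_packed, size_t len,
                               float a_scale, float b_scale, uint8_t b_zp = 8) {
    int32_t qsum = 0;
    for (size_t k = 0; k < len; ++k) {
        const uint8_t b_byte = qb_packed[k / 2];
        const int qb = ((k & 1) == 1 ? (b_byte >> 4) : (b_byte & 0x0F)) - b_zp;
        qsum += qa[k] * qb;
    }
    return static_cast<float>(qsum) * a_scale * b_scale;
}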
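Earlier in this change, bench_q4gemm.cpp moves benchmark registration into a static initializer so benchmarks are registered only when the kernels are available at runtime. A minimal sketch of that pattern, with a hypothetical capability probe standing in for the MlasQ4GemmPackBSize() check:

#include "benchmark/benchmark.h"

static void ExampleBench(benchmark::State& state) {
    for (auto _ : state) {
        benchmark::DoNotOptimize(state.range(0));
    }
}

static bool ExampleKernelAvailable() {
    return true;  // stand-in for a real runtime capability probe
}

// Registration happens as a side effect of initializing this variable, so unsupported
// benchmarks never appear in the run at all.
[[maybe_unused]] static const bool example_benchmarks_registered = []() {
    if (!ExampleKernelAvailable()) {
        return false;
    }
    benchmark::RegisterBenchmark("ExampleBench", ExampleBench)->Arg(64)->UseRealTime();
    return true;
}();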