diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake
index a62b1b259d109..04efa5c2b4f6d 100644
--- a/cmake/onnxruntime_mlas.cmake
+++ b/cmake/onnxruntime_mlas.cmake
@@ -33,6 +33,7 @@ onnxruntime_add_static_library(onnxruntime_mlas
   ${MLAS_SRC_DIR}/qpostprocessor.cpp
   ${MLAS_SRC_DIR}/qlgavgpool.cpp
   ${MLAS_SRC_DIR}/qdwconv_kernelsize.cpp
+  ${MLAS_SRC_DIR}/sqnbitgemm.cpp
 )

 if (NOT onnxruntime_ORT_MINIMAL_BUILD)
@@ -68,6 +69,7 @@ function(setup_mlas_source_for_windows)
     ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp
     ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp
     ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp
+    ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp
   )

   set(mlas_platform_preprocess_srcs
@@ -334,6 +336,7 @@ else()
       ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp
       ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp
       ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp
+      ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp
     )
     if (NOT APPLE)
       set(mlas_platform_srcs
diff --git a/include/onnxruntime/core/framework/op_node_proto_helper.h b/include/onnxruntime/core/framework/op_node_proto_helper.h
index 700e1edc0cb7d..e7ac01947af41 100644
--- a/include/onnxruntime/core/framework/op_node_proto_helper.h
+++ b/include/onnxruntime/core/framework/op_node_proto_helper.h
@@ -10,20 +10,6 @@
 #include "core/common/gsl.h"
 #endif

-#ifdef __has_attribute
-#define ORT_HAVE_ATTRIBUTE(x) __has_attribute(x)
-#else
-#define ORT_HAVE_ATTRIBUTE(x) 0
-#endif
-
-#if ORT_HAVE_ATTRIBUTE(nodiscard)
-#define MUST_USE_RESULT [[nodiscard]]
-#elif defined(__clang__) && ORT_HAVE_ATTRIBUTE(warn_unused_result)
-#define MUST_USE_RESULT __attribute__((warn_unused_result))
-#else
-#define MUST_USE_RESULT
-#endif
-
 class IMLOpKernel;

 namespace onnxruntime {
@@ -43,14 +29,26 @@ class OpNodeProtoHelper {
      Call this function for a required attribute or when a default value for an optional attribute is specified in the op schema
   */
   template <typename T>
-  MUST_USE_RESULT Status GetAttr(const std::string& name, T* value) const;
+  Status GetAttr(const std::string& name, T* value) const;
+
+  /**
+     Get a single attribute
+     Call this function for a required attribute or when a default value for an optional attribute is specified in the op schema
+     Throws if an attribute with the specified type doesn't exist
+  */
+  template <typename T>
+  [[nodiscard]] T GetAttr(const std::string& name) const {
+    T value;
+    ORT_THROW_IF_ERROR(GetAttr(name, &value));
+    return value;
+  }

   /**
      Get a single attribute
      Call this function only when a default value for an optional attribute isn't specified in the op schema
   */
   template <typename T>
-  T GetAttrOrDefault(const std::string& name, const T& default_value) const {
+  [[nodiscard]] T GetAttrOrDefault(const std::string& name, const T& default_value) const {
     T tmp;
     return GetAttr(name, &tmp).IsOK() ? tmp : default_value;
   }
@@ -70,7 +68,8 @@ class OpNodeProtoHelper {
      Call this function only when a default value for an optional attribute isn't specified in the op schema
   */
   template <typename T>
-  MUST_USE_RESULT std::vector<T> GetAttrsOrDefault(const std::string& name, const std::vector<T>& default_value = std::vector<T>{}) const {
+  [[nodiscard]] std::vector<T> GetAttrsOrDefault(const std::string& name,
+                                                 const std::vector<T>& default_value = {}) const {
     std::vector<T> tmp;
     return GetAttrs(name, tmp).IsOK() ? tmp : default_value;
   }
@@ -87,11 +86,12 @@ class OpNodeProtoHelper {
   /// Attribute data in a span, out parameter
   /// Status
   template <typename T>
-  MUST_USE_RESULT Status GetAttrsAsSpan(const std::string& name, gsl::span<const T>& values) const;
+  Status GetAttrsAsSpan(const std::string& name, gsl::span<const T>& values) const;

-  MUST_USE_RESULT Status GetAttrs(const std::string& name, TensorShapeVector& out) const;
+  Status GetAttrs(const std::string& name, TensorShapeVector& out) const;

-  MUST_USE_RESULT TensorShapeVector GetAttrsOrDefault(const std::string& name, const TensorShapeVector& default_value = TensorShapeVector{}) const {
+  [[nodiscard]] TensorShapeVector GetAttrsOrDefault(const std::string& name,
+                                                    const TensorShapeVector& default_value = {}) const {
     TensorShapeVector tmp;
     return GetAttrs(name, tmp).IsOK() ? tmp : default_value;
   }
@@ -100,43 +100,43 @@ class OpNodeProtoHelper {
      Get repeated attributes
   */
   template <typename T>
-  MUST_USE_RESULT Status GetAttrs(const std::string& name, std::vector<T>& values) const;
+  Status GetAttrs(const std::string& name, std::vector<T>& values) const;

   template <typename T>
-  MUST_USE_RESULT Status GetAttrs(const std::string& name, gsl::span<T> values) const;
+  Status GetAttrs(const std::string& name, gsl::span<T> values) const;

-  MUST_USE_RESULT Status GetAttrsStringRefs(const std::string& name,
-                                            std::vector<std::reference_wrapper<const std::string>>& refs) const;
+  Status GetAttrsStringRefs(const std::string& name,
+                            std::vector<std::reference_wrapper<const std::string>>& refs) const;

-  uint32_t GetPrimitiveAttrElementCount(ONNX_NAMESPACE::AttributeProto_AttributeType type,
-                                        const std::string& name) const noexcept;
+  [[nodiscard]] uint32_t GetPrimitiveAttrElementCount(ONNX_NAMESPACE::AttributeProto_AttributeType type,
+                                                      const std::string& name) const noexcept;

-  bool HasPrimitiveAttribute(ONNX_NAMESPACE::AttributeProto_AttributeType type,
-                             const std::string& name) const noexcept;
+  [[nodiscard]] bool HasPrimitiveAttribute(ONNX_NAMESPACE::AttributeProto_AttributeType type,
+                                           const std::string& name) const noexcept;

-  uint32_t GetInputCount() const {
+  [[nodiscard]] uint32_t GetInputCount() const {
     return gsl::narrow_cast<uint32_t>(impl_->getNumInputs());
   }

-  uint32_t GetOutputCount() const {
+  [[nodiscard]] uint32_t GetOutputCount() const {
     return gsl::narrow_cast<uint32_t>(impl_->getNumOutputs());
   }

-  const ONNX_NAMESPACE::TypeProto* GetInputType(size_t index) const {
+  [[nodiscard]] const ONNX_NAMESPACE::TypeProto* GetInputType(size_t index) const {
     return impl_->getInputType(index);
   }

-  const ONNX_NAMESPACE::TypeProto* GetOutputType(size_t index) const {
+  [[nodiscard]] const ONNX_NAMESPACE::TypeProto* GetOutputType(size_t index) const {
     // Work around lack of a const method from the onnx InferenceContext interface
     return const_cast<Impl_t*>(impl_)->getOutputType(index);
   }

   // Try to query an attribute, returning nullptr if it doesn't exist
-  const ONNX_NAMESPACE::AttributeProto* TryGetAttribute(const std::string& name) const {
+  [[nodiscard]] const ONNX_NAMESPACE::AttributeProto* TryGetAttribute(const std::string& name) const {
     return impl_->getAttribute(name);
   }

-  const ONNX_NAMESPACE::AttributeProto* GetAttribute(const std::string& name) const {
+  [[nodiscard]] const ONNX_NAMESPACE::AttributeProto* GetAttribute(const std::string& name) const {
     const ONNX_NAMESPACE::AttributeProto* attr = TryGetAttribute(name);
     ORT_ENFORCE(attr != nullptr);
     return attr;
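For reference, the new throwing overload composes with the existing helpers like this (illustrative snippet only; ExampleKernel is hypothetical and not part of this change):

class ExampleKernel : public OpKernel {
 public:
  explicit ExampleKernel(const OpKernelInfo& info)
      : OpKernel(info),
        // required attribute: GetAttr<T>(name) throws if "axis" is missing,
        // so const members can be initialized directly in the initializer list
        axis_{info.GetAttr<int64_t>("axis")},
        // optional attribute: falls back to the supplied default
        epsilon_{info.GetAttrOrDefault<float>("epsilon", 1e-5f)} {}

 private:
  const int64_t axis_;
  const float epsilon_;
};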
diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
index c72d811170a27..320a05bb97dac 100644
--- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
+++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
@@ -1,35 +1,38 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.

+#include "core/common/narrow.h"
 #include "core/common/safeint.h"
 #include "core/framework/op_kernel.h"
+#include "core/mlas/inc/mlas.h"
+#include "core/mlas/inc/mlas_qnbit.h"
+#include "core/mlas/inc/mlas_q4.h"
 #include "core/providers/cpu/math/matmul_helper.h"
 #include "core/providers/common.h"
-#include "core/mlas/inc/mlas_q4.h"

 namespace onnxruntime {
 namespace contrib {

 class MatMulNBits final : public OpKernel {
  public:
-  MatMulNBits(const OpKernelInfo& info) : OpKernel(info) {
-    ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("K", &K_));
-    ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("N", &N_));
-    ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("block_size", &block_size_));
-    ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("bits", &nbits_));
+  MatMulNBits(const OpKernelInfo& info)
+      : OpKernel(info),
+        K_{narrow<size_t>(info.GetAttr<int64_t>("K"))},
+        N_{narrow<size_t>(info.GetAttr<int64_t>("N"))},
+        block_size_{narrow<size_t>(info.GetAttr<int64_t>("block_size"))},
+        nbits_{narrow<size_t>(info.GetAttr<int64_t>("bits"))} {
     ORT_ENFORCE(nbits_ == 4,
-                "Only 4b quantization is supported for MatMulNBits op,"
-                " additional bits support is planned.");
+                "Only 4b quantization is supported for MatMulNBits op, additional bits support is planned.");
   }

   Status Compute(OpKernelContext* context) const override;

  private:
-  int64_t K_;
-  int64_t N_;
-  int64_t block_size_;
-  int64_t nbits_;
-  bool column_wise_quant_{true};
+  const size_t K_;
+  const size_t N_;
+  const size_t block_size_;
+  const size_t nbits_;
+  const bool column_wise_quant_{true};
 };

 Status MatMulNBits::Compute(OpKernelContext* ctx) const {
@@ -45,11 +48,60 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
   const auto* scales_data = scales->Data<float>();
   const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->Data<uint8_t>();

+  TensorShape b_shape({static_cast<int64_t>(N_), static_cast<int64_t>(K_)});
+
+  MatMulComputeHelper helper;
+  ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true));
+
+  Tensor* y = ctx->Output(0, helper.OutputShape());
+
+  // Bail out early if the output is going to be empty
+  if (y->Shape().Size() == 0)
+    return Status::OK();
+
+  auto* y_data = y->MutableData<float>();
+
+  const size_t batch_count = helper.OutputOffsets().size();
+  const size_t M = static_cast<size_t>(helper.M());
+  const size_t N = static_cast<size_t>(helper.N());
+  const size_t K = static_cast<size_t>(helper.K());
+  const size_t lda = helper.Lda(false);
+
+  if (MlasIsSQNBitGemmAvailable(nbits_, block_size_)) {
+    // number of bytes or elements between adjacent matrices
+    size_t b_data_matrix_stride_in_bytes, b_scale_matrix_stride, b_zero_point_matrix_stride_in_bytes;
+    MlasBlockwiseQuantizedBufferSizes(static_cast<int>(nbits_), static_cast<int>(block_size_), /* columnwise */ true,
+                                      static_cast<int>(K), static_cast<int>(N),
+                                      b_data_matrix_stride_in_bytes, b_scale_matrix_stride,
+                                      &b_zero_point_matrix_stride_in_bytes);
+
+    const size_t b_matrix_size = K * N;
+
+    InlinedVector<MLAS_SQNBIT_GEMM_DATA_PARAMS> data(batch_count);
+    for (size_t i = 0; i < batch_count; ++i) {
+      const size_t b_matrix_offset = helper.RightOffsets()[i] / b_matrix_size;
+
+      data[i].A = a_data + helper.LeftOffsets()[i];
+      data[i].lda = lda;
+      data[i].QuantBData = b_data + b_matrix_offset * b_data_matrix_stride_in_bytes;
+      data[i].QuantBScale = scales_data + b_matrix_offset * b_scale_matrix_stride;
+      data[i].QuantBZeroPoint = zero_points_data != nullptr
+                                    ? zero_points_data + b_matrix_offset * b_zero_point_matrix_stride_in_bytes
+                                    : nullptr;
+      data[i].C = y_data + helper.OutputOffsets()[i];
+      data[i].ldc = N;
+    }
+
+    MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, data.data(), thread_pool);
+
+    return Status::OK();
+  }
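// Illustration of the stride math above (editor's example, not part of the change):
// for nbits_ = 4, block_size_ = 32, K = N = 4096, MlasBlockwiseQuantizedBufferSizes
// reports, per B matrix in the batch:
//   b_data_matrix_stride_in_bytes       = (4096 * 4096) / 2          = 8,388,608 bytes
//   b_scale_matrix_stride               = (4096 / 32) * 4096         = 524,288 floats
//   b_zero_point_matrix_stride_in_bytes = ((128 * 4 + 7) / 8) * 4096 = 262,144 bytes
// and b_matrix_offset (= RightOffsets()[i] / (K * N)) selects which stacked B matrix
// a given batch entry reads from.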
+
+  const size_t ldb = helper.Ldb(true);
+
   AllocatorPtr allocator;
-  auto status = ctx->GetTempSpaceAllocator(&allocator);
-  ORT_RETURN_IF_ERROR(status);
+  ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator));
   auto tmp_b_data_ptr = IAllocator::MakeUniquePtr<float>(allocator, SafeInt<size_t>(K_) * N_);
-  // dequantize b, only 4b quantization is supported for now
   MlasDequantizeBlockwise<float, 4>(
       tmp_b_data_ptr.get(),  // dequantized output
@@ -67,29 +119,8 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
   MlasTranspose(tmp_b_data_ptr.get(), tm_b_data_ptr_trans.get(), N_, K_);
 #endif

-  TensorShape b_shape({N_, K_});
-
-  MatMulComputeHelper helper;
-  ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true));
-
-  Tensor* y = ctx->Output(0, helper.OutputShape());
-
-  // Bail out early if the output is going to be empty
-  if (y->Shape().Size() == 0)
-    return Status::OK();
-
-  auto* y_data = y->MutableData<float>();
-
-  const size_t max_len = helper.OutputOffsets().size();
-  const size_t M = static_cast<size_t>(helper.M());
-  const size_t N = static_cast<size_t>(helper.N());
-  const size_t K = static_cast<size_t>(helper.K());
-  const size_t lda = helper.Lda(false);
-  const size_t ldb = helper.Ldb(true);
-
-  // TODO: implement with native kernel
-  std::vector<MLAS_SGEMM_DATA_PARAMS> data(max_len);
-  for (size_t i = 0; i < max_len; i++) {
+  std::vector<MLAS_SGEMM_DATA_PARAMS> data(batch_count);
+  for (size_t i = 0; i < batch_count; i++) {
     data[i].BIsPacked = false;
     data[i].A = a_data + helper.LeftOffsets()[i];
     data[i].lda = lda;
@@ -101,7 +132,7 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
     data[i].beta = 0.0f;
   }
   MlasGemmBatch(CblasNoTrans, CblasTrans,
-                M, N, K, data.data(), max_len, thread_pool);
+                M, N, K, data.data(), batch_count, thread_pool);

   return Status::OK();
 }
diff --git a/onnxruntime/core/common/cpuid_uarch.cc b/onnxruntime/core/common/cpuid_uarch.cc
index 52baad739441b..16634b2bc8744 100644
--- a/onnxruntime/core/common/cpuid_uarch.cc
+++ b/onnxruntime/core/common/cpuid_uarch.cc
@@ -3,7 +3,8 @@

 #include "core/common/cpuid_uarch.h"

-#include "core/common/logging/logging.h"
+#include <iostream>  // For std::cerr.
+// Writing to stderr instead of logging because logger may not be initialized yet.
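// For orientation (editor's note): the decodeMIDR switch below keys off fields of the
// Arm MIDR_EL1 register, laid out as
//   [31:24] implementer | [23:20] variant | [19:16] architecture | [15:4] part | [3:0] revision
// so the accessors used in the warnings are assumed to be of the form:
//   midr_get_part(midr)    -> (midr >> 4) & 0xFFF
//   midr_get_variant(midr) -> (midr >> 20) & 0xF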

 namespace onnxruntime {
@@ -137,7 +138,7 @@ void decodeMIDR(
           break;
           // #endif /* ARM */
         default:
-          LOGS_DEFAULT(WARNING) << "unknown ARM CPU part 0x" << std::hex << midr_get_part(midr) << " ignored";
+          std::cerr << "unknown ARM CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n";
       }
     } break;
@@ -156,7 +157,7 @@
           break;
           // #endif
         default:
-          LOGS_DEFAULT(WARNING) << "unknown Broadcom CPU part 0x" << std::hex << midr_get_part(midr) << " ignored";
+          std::cerr << "unknown Broadcom CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n";
       }
       break;
       // #if (defined(_M_ARM64) || defined(__aarch64__)) && !defined(__ANDROID__)
@@ -172,7 +173,7 @@
           *uarch = cpuinfo_uarch_thunderx2;
           break;
         default:
-          LOGS_DEFAULT(WARNING) << "unknown Cavium CPU part 0x" << std::hex << midr_get_part(midr) << " ignored";
+          std::cerr << "unknown Cavium CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n";
       }
       break;
       // #endif
@@ -187,7 +188,7 @@
           *uarch = cpuinfo_uarch_cortex_a76;
           break;
         default:
-          LOGS_DEFAULT(WARNING) << "unknown Huawei CPU part 0x" << std::hex << midr_get_part(midr) << " ignored";
+          std::cerr << "unknown Huawei CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n";
       }
       break;
       // #if defined(_M_ARM) || defined(__arm__)
@@ -199,7 +200,7 @@
           *uarch = cpuinfo_uarch_xscale;
           break;
         default:
-          LOGS_DEFAULT(WARNING) << "unknown Intel CPU part 0x" << std::hex << midr_get_part(midr) << " ignored";
+          std::cerr << "unknown Intel CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n";
       }
       break;
       // #endif /* ARM */
@@ -215,7 +216,7 @@
           *uarch = cpuinfo_uarch_carmel;
           break;
         default:
-          LOGS_DEFAULT(WARNING) << "unknown Nvidia CPU part 0x" << std::hex << midr_get_part(midr) << " ignored";
+          std::cerr << "unknown Nvidia CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n";
       }
       break;
 #if !defined(__ANDROID__)
@@ -225,7 +226,7 @@
           *uarch = cpuinfo_uarch_xgene;
           break;
         default:
-          LOGS_DEFAULT(WARNING) << "unknown Applied Micro CPU part 0x" << std::hex << midr_get_part(midr) << " ignored";
+          std::cerr << "unknown Applied Micro CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n";
       }
       break;
 #endif
@@ -297,7 +298,7 @@
           break;
           // #endif /* ARM64 && !defined(__ANDROID__) */
         default:
-          LOGS_DEFAULT(WARNING) << "unknown Qualcomm CPU part 0x" << std::hex << midr_get_part(midr) << " ignored";
+          std::cerr << "unknown Qualcomm CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n";
       }
       break;
     case 'S':
@@ -343,8 +344,9 @@
           *uarch = cpuinfo_uarch_exynos_m5;
           break;
         default:
-          LOGS_DEFAULT(WARNING) << "unknown Samsung CPU variant 0x"
-                                << std::hex << midr_get_variant(midr) << " part 0x" << std::hex << midr_get_part(midr) << " ignored";
+          std::cerr << "unknown Samsung CPU variant 0x"
+                    << std::hex << midr_get_variant(midr) << " part 0x" << std::hex << midr_get_part(midr)
+                    << " ignored\n";
       }
       break;
       // #if defined(_M_ARM) || defined(__arm__)
@@ -355,12 +357,12 @@
           *uarch = cpuinfo_uarch_pj4;
           break;
         default:
-          LOGS_DEFAULT(WARNING) << "unknown Marvell CPU part 0x" << std::hex << midr_get_part(midr) << " ignored";
+          std::cerr << "unknown Marvell CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n";
       }
       break;
       // #endif /* ARM */
     default:
-      LOGS_DEFAULT(WARNING) << "unknown CPU uarch from MIDR value: 0x" << std::hex << midr;
+      std::cerr << "unknown CPU uarch from MIDR value: 0x" << std::hex << midr << "\n";
   }
 }
diff --git a/onnxruntime/core/framework/op_node_proto_helper.cc b/onnxruntime/core/framework/op_node_proto_helper.cc
index 38d67eb0e0c72..c3deb94300e78 100644
--- a/onnxruntime/core/framework/op_node_proto_helper.cc
+++ b/onnxruntime/core/framework/op_node_proto_helper.cc
@@ -182,7 +182,7 @@ ORT_DEFINE_GET_ATTRS_SPAN_SPECIALIZATION(float, floats)
 ORT_DEFINE_GET_ATTRS_SPAN_SPECIALIZATION(int64_t, ints)

 template <class Impl_t>
-MUST_USE_RESULT Status OpNodeProtoHelper<Impl_t>::GetAttrs(const std::string& name, TensorShapeVector& out) const {
+Status OpNodeProtoHelper<Impl_t>::GetAttrs(const std::string& name, TensorShapeVector& out) const {
   gsl::span<const int64_t> span;
   Status status = this->GetAttrsAsSpan(name, span);
   if (status.IsOK()) {
@@ -193,7 +193,7 @@ MUST_USE_RESULT Status OpNodeProtoHelper<Impl_t>::GetAttrs(const std::string& na
 }

 template <class Impl_t>
-MUST_USE_RESULT Status OpNodeProtoHelper<Impl_t>::GetAttrsStringRefs(
+Status OpNodeProtoHelper<Impl_t>::GetAttrsStringRefs(
     const std::string& name,
     std::vector<std::reference_wrapper<const std::string>>& refs) const {
   const AttributeProto* attr = TryGetAttribute(name);
diff --git a/onnxruntime/core/mlas/.clang-format b/onnxruntime/core/mlas/.clang-format
index 4a89ef98cf049..16ad8bd8a7234 100644
--- a/onnxruntime/core/mlas/.clang-format
+++ b/onnxruntime/core/mlas/.clang-format
@@ -2,10 +2,12 @@

 BasedOnStyle: Google
 IndentWidth: 4
-ColumnLimit: 100
+# Setting ColumnLimit to 0 so developer choices about where to break lines are maintained.
+# Developers are responsible for adhering to the 120 character maximum.
+ColumnLimit: 0
+AlignAfterOpenBracket: BlockIndent
 AlwaysBreakAfterReturnType: TopLevel
 AlwaysBreakTemplateDeclarations: Yes
 BinPackParameters: false
 BreakBeforeBraces: Linux
 ...
-
diff --git a/onnxruntime/core/mlas/inc/mlas_gemm_postprocessor.h b/onnxruntime/core/mlas/inc/mlas_gemm_postprocessor.h
new file mode 100644
index 0000000000000..7ea29eb091318
--- /dev/null
+++ b/onnxruntime/core/mlas/inc/mlas_gemm_postprocessor.h
@@ -0,0 +1,33 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+    mlas_gemm_postprocessor.h
+
+Abstract:
+
+    This module contains a base class for custom postprocessing following a
+    GEMM.
+
+--*/
+
+#pragma once
+
+template <typename T>
+class MLAS_GEMM_POSTPROCESSOR
+{
+   public:
+    virtual void Process(T* C,               /**< the address of matrix to process */
+                         size_t RangeStartM, /**< the start row index of matrix */
+                         size_t RangeStartN, /**< the start col index of matrix */
+                         size_t RangeCountM, /**< the element count per row to process */
+                         size_t RangeCountN, /**< the element count per col to process */
+                         size_t ldc          /**< the leading dimension of matrix */
+    ) const = 0;
+
+    virtual ~MLAS_GEMM_POSTPROCESSOR() {}
+};
diff --git a/onnxruntime/core/mlas/inc/mlas_q4.h b/onnxruntime/core/mlas/inc/mlas_q4.h
index 7c7b729117e4a..316344ad8c214 100644
--- a/onnxruntime/core/mlas/inc/mlas_q4.h
+++ b/onnxruntime/core/mlas/inc/mlas_q4.h
@@ -21,6 +21,7 @@ Module Name:
 #pragma once

 #include "mlas.h"
+#include "mlas_gemm_postprocessor.h"

 #include <math.h>
 #include <algorithm>
@@ -95,22 +96,6 @@ MlasQ4GemmUnPackB(
     );


-template <typename T>
-class MLAS_GEMM_POSTPROCESSOR
-{
-   public:
-    virtual void Process(T*,     /**< the address of matrix to process */
-                         size_t, /**< the start row index of matrix */
-                         size_t, /**< the start col index of matrix */
-                         size_t, /**< the element count per row to process */
-                         size_t, /**< the element count per col to process */
-                         size_t  /**< the leading dimension of matrix */
-    ) const = 0;
-
-    virtual ~MLAS_GEMM_POSTPROCESSOR() {}
-};
-
-
 /**
  * @brief Data parameters for Q4 GEMM routine
  *        C = A * B + Bias
@@ -241,7 +226,7 @@ MlasQ8Q4GemmBatch(
  *        matrix shape [rows, columns], compute the shape of the
  *        quantization parameter matrix [meta_rows, meta_cols]
  */
-template <typename T>
+template <typename T, int qbits>
 void
 MlasBlockwiseQuantMetaShape(
     int block_size,
@@ -259,6 +244,7 @@ MlasBlockwiseQuantMetaShape(
  *        is in column major layout, with bits packed on the column.
  *
  * @tparam T
+ * @tparam qbits
  * @param block_size
 * @param columnwise
  * @param rows
@@ -266,7 +252,7 @@ MlasBlockwiseQuantizedShape(
  * @param q_rows
  * @param q_cols
  */
-template <typename T>
+template <typename T, int qbits>
 void
 MlasBlockwiseQuantizedShape(
     int block_size,
     bool columnwise,
     int rows,
     int columns,
     int& q_rows,
     int& q_cols
     );

+/**
+ * @brief Compute the sizes of the quantized data and quantization parameter buffers.
+ *
+ * @param qbits                     The bit width of each quantized value.
+ * @param block_size                The number of quantized values in a block.
+ * @param columnwise                Whether a block contains values from a matrix column (true) or row (false).
+ * @param rows                      Number of matrix rows.
+ * @param columns                   Number of matrix columns.
+ * @param[out] q_data_size_in_bytes The size in bytes of the quantized data.
+ * @param[out] q_scale_num_elements The size in elements of the scale quantization parameters.
+ * @param[out] q_zero_point_size_in_bytes The size in bytes of the zero point quantization parameters. Optional.
+ *
+ * If the qbits or block_size values are unsupported the output sizes will be zero.
+ */
+void MLASCALL
+MlasBlockwiseQuantizedBufferSizes(
+    int qbits,
+    int block_size,
+    bool columnwise,
+    int rows,
+    int columns,
+    size_t& q_data_size_in_bytes,
+    size_t& q_scale_num_elements,
+    size_t* q_zero_point_size_in_bytes
+);
+
 /**
  * @brief Blockwise 4 bits quantization, resulting elements and quantization
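As a quick sanity check of the size math (illustrative snippet, not part of the change; the numbers follow directly from the 4-bit packing described above):

// qbits = 4, block_size = 32, columnwise quantization of a 4096 x 4096 matrix:
size_t data_bytes, scale_elems, zp_bytes;
MlasBlockwiseQuantizedBufferSizes(4, 32, /*columnwise*/ true, 4096, 4096,
                                  data_bytes, scale_elems, &zp_bytes);
// data_bytes  == 4096 * 4096 / 2            == 8388608  (two 4-bit values per byte)
// scale_elems == (4096 / 32) * 4096         == 524288   (one scale per block)
// zp_bytes    == ((128 * 4 + 7) / 8) * 4096 == 262144   (4-bit zero points, byte-packed per column)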
diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h
new file mode 100644
index 0000000000000..9620dd42d1da9
--- /dev/null
+++ b/onnxruntime/core/mlas/inc/mlas_qnbit.h
@@ -0,0 +1,79 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+    mlas_qnbit.h
+
+Abstract:
+
+    This module contains the public data structures and procedure prototypes
+    for blocked n-bit quantized GEMM.
+
+    N-bit block quantization is used to compress weight tensors of large
+    language models.
+
+--*/
+
+#pragma once
+
+#include "mlas.h"
+#include "mlas_gemm_postprocessor.h"
+
+/**
+ * @brief Data parameters for float/n-bit quantized int GEMM routine.
+ */
+struct MLAS_SQNBIT_GEMM_DATA_PARAMS {
+    const float* A = nullptr;               ///< address of A (float32 matrix)
+    size_t lda = 0;                         ///< leading dimension of A
+    const void* QuantBData = nullptr;       ///< address of quantized B (quantized n-bit int values)
+    const float* QuantBScale = nullptr;     ///< address of scale values of quantized B, one per block
+    const void* QuantBZeroPoint = nullptr;  ///< optional address of zero point values of quantized B, one per block
+    bool IsBPacked = false;                 ///< whether B values are packed in an optimized format for the computation
+    const float* Bias = nullptr;            ///< optional address of Bias, vector size N
+    float* C = nullptr;                     ///< address of result matrix
+    size_t ldc = 0;                         ///< leading dimension of C
+
+    ///< optional post processing to apply to result matrix
+    MLAS_GEMM_POSTPROCESSOR<float>* PostProcessor = nullptr;
+};
+
+/**
+ * @brief Batched GEMM:  C = A * B + Bias
+ *        A must be a float32 matrix
+ *        B must be a quantized and packed n-bit int matrix
+ *
+ * @param[in]  M               row size of matrix A and C
+ * @param[in]  N               column size of matrix B and C
+ * @param[in]  K               column size of matrix A and row size of matrix B
+ * @param[in]  BatchN          number of batches
+ * @param[in]  BlkBitWidth     quantized value bit width (e.g., 4 means 4 bit ints)
+ * @param[in]  BlkLen          number of quantized values per block
+ * @param[inout] DataParams    An array (size BatchN) of parameter blocks
+ * @param[in]  ThreadPool      optional thread pool to use
+ */
+void MLASCALL
+MlasSQNBitGemmBatch(
+    size_t M,
+    size_t N,
+    size_t K,
+    size_t BatchN,
+    size_t BlkBitWidth,
+    size_t BlkLen,
+    const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams,
+    MLAS_THREADPOOL* ThreadPool = nullptr
+);
+
+/**
+ * @brief Determines whether a float32/quantized n-bit int GEMM implementation is available on the current platform.
+ * @param[in]  BlkBitWidth     quantized value bit width (e.g., 4 means 4 bit ints)
+ * @param[in]  BlkLen          number of quantized values per block
+ */
+bool MLASCALL
+MlasIsSQNBitGemmAvailable(
+    size_t BlkBitWidth,
+    size_t BlkLen
+);
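Putting the two entry points together, a caller is expected to probe availability first (minimal sketch assuming float buffers A and C and quantized B buffers packed as described above):

MLAS_SQNBIT_GEMM_DATA_PARAMS params{};
params.A = A;                         // M x K float32
params.lda = K;
params.QuantBData = QuantB;           // blockwise-quantized 4-bit B
params.QuantBScale = Scales;          // one float per block
params.QuantBZeroPoint = ZeroPoints;  // may stay nullptr for a default zero point
params.C = C;                         // M x N float32
params.ldc = N;

if (MlasIsSQNBitGemmAvailable(/*BlkBitWidth*/ 4, /*BlkLen*/ 32)) {
    MlasSQNBitGemmBatch(M, N, K, /*BatchN*/ 1, 4, 32, &params, ThreadPool);
}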
diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h
index e0c2772cbb719..6c859e4e4f44b 100644
--- a/onnxruntime/core/mlas/lib/mlasi.h
+++ b/onnxruntime/core/mlas/lib/mlasi.h
@@ -890,14 +890,30 @@ extern const MLAS_CONV_SYM_DISPATCH MlasConvSymS8DispatchNeon;
 extern const MLAS_CONV_SYM_DISPATCH MlasConvSymU8DispatchDot;
 extern const MLAS_CONV_SYM_DISPATCH MlasConvSymS8DispatchDot;

+//
+// Quantized 8-bit integer/quantized 4-bit integer matrix/matrix multiply dispatch structure.
+//
+
 struct MLAS_Q8Q4GEMM_DISPATCH;

 extern const MLAS_Q8Q4GEMM_DISPATCH MlasQ8Q4GemmDispatchAvx512vnni;

+//
+// Float/quantized 4-bit integer matrix/matrix multiply dispatch structure.
+//
+
 struct MLAS_FPQ4GEMM_DISPATCH;

 extern const MLAS_FPQ4GEMM_DISPATCH MlasFpQ4GemmDispatchAvx512;

+//
+// Float/quantized n-bit integer matrix/matrix multiply dispatch structure.
+//
+
+struct MLAS_SQNBIT_GEMM_DISPATCH;
+
+extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon;
+
 //
 // Quantized depthwise convolution kernels.
 //
@@ -1029,6 +1045,8 @@ struct MLAS_PLATFORM {

     const MLAS_FPQ4GEMM_DISPATCH* FpQ4GemmDispatch{nullptr};
     const MLAS_Q8Q4GEMM_DISPATCH* Q8Q4GemmDispatch{nullptr};
+
+    const MLAS_SQNBIT_GEMM_DISPATCH* SQNBitGemmDispatch{nullptr};
 };

 inline
diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp
index 39586282e00ad..fec56c6ee063f 100644
--- a/onnxruntime/core/mlas/lib/platform.cpp
+++ b/onnxruntime/core/mlas/lib/platform.cpp
@@ -460,6 +460,7 @@ Return Value:
     this->SymmQgemmDispatch = &MlasSymmQgemmS8DispatchNeon;
     this->ConvSymU8S8Dispatch = &MlasConvSymU8DispatchNeon;
     this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchNeon;
+    this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon;

     //
     // Check if the processor supports ASIMD dot product instructions.
diff --git a/onnxruntime/core/mlas/lib/q4_dq.cpp b/onnxruntime/core/mlas/lib/q4_dq.cpp
index fbd1030de8ab7..48d975a7fd26d 100644
--- a/onnxruntime/core/mlas/lib/q4_dq.cpp
+++ b/onnxruntime/core/mlas/lib/q4_dq.cpp
@@ -422,6 +422,24 @@ struct BlockwiseQuantizer {
         q_cols = meta_cols * QuantBlk::kColumn;
     }

+    static MLAS_FORCEINLINE void quantizedBufferSizes(
+        int rows, int columns, size_t& data_bytes, size_t& scale_num_elements, size_t* zero_point_bytes
+    )
+    {
+        int meta_rows, meta_cols;
+        quantizeMetaShape(rows, columns, meta_rows, meta_cols);
+        int q_rows, q_cols;
+        quantizedShape(rows, columns, q_rows, q_cols);
+
+        data_bytes = q_rows * q_cols;
+        scale_num_elements = meta_rows * meta_cols;
+
+        if (zero_point_bytes) {
+            // this works for qbits == 4 but may need to be updated for other qbits values
+            *zero_point_bytes = ((meta_rows * qbits + 7) / 8) * meta_cols;
+        }
+    }
+
     /**
      * @brief Quantized a Matrix shape [rows, columns], resulting quantized
      *        and packed data are stored in column major (transposed)
@@ -621,7 +639,7 @@ struct BlockwiseQuantizer {
 };


-template <typename T>
+template <typename T, int qbits>
 void
 MlasBlockwiseQuantMetaShape(
     int block_size,
     bool columnwise,
     int rows,
     int columns,
     int& meta_rows,
     int& meta_cols
     )
 {
     switch (block_size) {
         case 16: {
             if (columnwise) {
-                BlockwiseQuantizer<T, 16, 4, true>::quantizeMetaShape(rows, columns, meta_rows, meta_cols);
+                BlockwiseQuantizer<T, 16, qbits, true>::quantizeMetaShape(rows, columns, meta_rows, meta_cols);
             } else {
-                BlockwiseQuantizer<T, 16, 4, false>::quantizeMetaShape(rows, columns, meta_rows, meta_cols);
+                BlockwiseQuantizer<T, 16, qbits, false>::quantizeMetaShape(rows, columns, meta_rows, meta_cols);
             }
             break;
         }
         case 32: {
             if (columnwise) {
-                BlockwiseQuantizer<T, 32, 4, true>::quantizeMetaShape(rows, columns, meta_rows, meta_cols);
+                BlockwiseQuantizer<T, 32, qbits, true>::quantizeMetaShape(rows, columns, meta_rows, meta_cols);
             } else {
-                BlockwiseQuantizer<T, 32, 4, false>::quantizeMetaShape(
+                BlockwiseQuantizer<T, 32, qbits, false>::quantizeMetaShape(
                     rows, columns, meta_rows, meta_cols);
             }
             break;
         }
         case 64: {
             if (columnwise) {
-                BlockwiseQuantizer<T, 64, 4, true>::quantizeMetaShape(rows, columns, meta_rows,
+                BlockwiseQuantizer<T, 64, qbits, true>::quantizeMetaShape(rows, columns, meta_rows,
                                                                       meta_cols);
             } else {
-                BlockwiseQuantizer<T, 64, 4, false>::quantizeMetaShape(rows, columns, meta_rows,
+                BlockwiseQuantizer<T, 64, qbits, false>::quantizeMetaShape(rows, columns, meta_rows,
                                                                        meta_cols);
             }
             break;
         }
         case 128: {
             if (columnwise) {
-                BlockwiseQuantizer<T, 128, 4, true>::quantizeMetaShape(rows, columns, meta_rows,
+                BlockwiseQuantizer<T, 128, qbits, true>::quantizeMetaShape(rows, columns, meta_rows,
                                                                        meta_cols);
             } else {
-                BlockwiseQuantizer<T, 128, 4, false>::quantizeMetaShape(rows, columns, meta_rows,
+                BlockwiseQuantizer<T, 128, qbits, false>::quantizeMetaShape(rows, columns, meta_rows,
                                                                         meta_cols);
             }
             break;
         }
         case 256: {
             if (columnwise) {
-                BlockwiseQuantizer<T, 256, 4, true>::quantizeMetaShape(rows, columns, meta_rows,
+                BlockwiseQuantizer<T, 256, qbits, true>::quantizeMetaShape(rows, columns, meta_rows,
                                                                        meta_cols);
             } else {
-                BlockwiseQuantizer<T, 256, 4, false>::quantizeMetaShape(rows, columns, meta_rows,
+                BlockwiseQuantizer<T, 256, qbits, false>::quantizeMetaShape(rows, columns, meta_rows,
                                                                         meta_cols);
             }
             break;
@@ -689,7 +707,7 @@ MlasBlockwiseQuantMetaShape(



-template <typename T>
+template <typename T, int qbits>
 void
 MlasBlockwiseQuantizedShape(
     int block_size,
     bool columnwise,
     int rows,
     int columns,
     int& q_rows,
     int& q_cols
     )
 {
     switch (block_size) {
         case 16: {
             if (columnwise) {
-                BlockwiseQuantizer<T, 16, 4, true>::quantizedShape(rows, columns, q_rows, q_cols);
+                BlockwiseQuantizer<T, 16, qbits, true>::quantizedShape(rows, columns, q_rows, q_cols);
             } else {
-                BlockwiseQuantizer<T, 16, 4, false>::quantizedShape(rows, columns, q_rows, q_cols);
+                BlockwiseQuantizer<T, 16, qbits, false>::quantizedShape(rows, columns, q_rows, q_cols);
             }
             break;
         }
         case 32: {
             if (columnwise) {
-                BlockwiseQuantizer<T, 32, 4, true>::quantizedShape(rows, columns, q_rows, q_cols);
+                BlockwiseQuantizer<T, 32, qbits, true>::quantizedShape(rows, columns, q_rows, q_cols);
             } else {
-                BlockwiseQuantizer<T, 32, 4, false>::quantizedShape(
+                BlockwiseQuantizer<T, 32, qbits, false>::quantizedShape(
                     rows, columns, q_rows, q_cols);
             }
             break;
         }
         case 64: {
             if (columnwise) {
-                BlockwiseQuantizer<T, 64, 4, true>::quantizedShape(rows, columns, q_rows, q_cols);
+                BlockwiseQuantizer<T, 64, qbits, true>::quantizedShape(rows, columns, q_rows, q_cols);
             } else {
-                BlockwiseQuantizer<T, 64, 4, false>::quantizedShape(rows, columns, q_rows, q_cols);
+                BlockwiseQuantizer<T, 64, qbits, false>::quantizedShape(rows, columns, q_rows, q_cols);
             }
             break;
         }
         case 128: {
             if (columnwise) {
-                BlockwiseQuantizer<T, 128, 4, true>::quantizedShape(rows, columns, q_rows, q_cols);
+                BlockwiseQuantizer<T, 128, qbits, true>::quantizedShape(rows, columns, q_rows, q_cols);
             } else {
-                BlockwiseQuantizer<T, 128, 4, false>::quantizedShape(rows, columns, q_rows, q_cols);
+                BlockwiseQuantizer<T, 128, qbits, false>::quantizedShape(rows, columns, q_rows, q_cols);
             }
             break;
         }
         case 256: {
             if (columnwise) {
-                BlockwiseQuantizer<T, 256, 4, true>::quantizedShape(rows, columns, q_rows, q_cols);
+                BlockwiseQuantizer<T, 256, qbits, true>::quantizedShape(rows, columns, q_rows, q_cols);
             } else {
-                BlockwiseQuantizer<T, 256, 4, false>::quantizedShape(rows, columns, q_rows, q_cols);
+                BlockwiseQuantizer<T, 256, qbits, false>::quantizedShape(rows, columns, q_rows, q_cols);
             }
             break;
         }
@@ -752,7 +770,7 @@ MlasBlockwiseQuantizedShape(

 template
 void
-MlasBlockwiseQuantMetaShape<float>(
+MlasBlockwiseQuantMetaShape<float, 4>(
     int block_size,
     bool columnwise,
     int rows,
     int columns,
     int& meta_rows,
     int& meta_cols
     );

 template
 void
-MlasBlockwiseQuantizedShape<float>(
+MlasBlockwiseQuantizedShape<float, 4>(
     int block_size,
     bool columnwise,
     int rows,
     int columns,
     int& q_rows,
     int& q_cols
     );


+void MLASCALL
+MlasBlockwiseQuantizedBufferSizes(
+    int qbits,
+    int block_size,
+    bool columnwise,
+    int rows,
+    int columns,
+    size_t& q_data_size_in_bytes,
+    size_t& q_scale_num_elements,
+    size_t* q_zero_point_size_in_bytes
+)
+{
+    q_data_size_in_bytes = q_scale_num_elements = 0;
+    if (q_zero_point_size_in_bytes) {
+        *q_zero_point_size_in_bytes = 0;
+    }
+
+    if (qbits == 4) {
+        switch (block_size) {
+            case 16:
+                if (columnwise) {
+                    BlockwiseQuantizer<float, 16, 4, true>::quantizedBufferSizes(
+                        rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes
+                    );
+                } else {
+                    BlockwiseQuantizer<float, 16, 4, false>::quantizedBufferSizes(
+                        rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes
+                    );
+                }
+                break;
+
+            case 32:
+                if (columnwise) {
+                    BlockwiseQuantizer<float, 32, 4, true>::quantizedBufferSizes(
+                        rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes
+                    );
+                } else {
+                    BlockwiseQuantizer<float, 32, 4, false>::quantizedBufferSizes(
+                        rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes
+                    );
+                }
+                break;
+
+            case 64:
+                if (columnwise) {
+                    BlockwiseQuantizer<float, 64, 4, true>::quantizedBufferSizes(
+                        rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes
+                    );
+                } else {
+                    BlockwiseQuantizer<float, 64, 4, false>::quantizedBufferSizes(
+                        rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes
+                    );
+                }
+                break;
+
+            case 128:
+                if (columnwise) {
+                    BlockwiseQuantizer<float, 128, 4, true>::quantizedBufferSizes(
+                        rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes
+                    );
+                } else {
+                    BlockwiseQuantizer<float, 128, 4, false>::quantizedBufferSizes(
+                        rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes
+                    );
+                }
+                break;
+
+            case 256:
+                if (columnwise) {
+                    BlockwiseQuantizer<float, 256, 4, true>::quantizedBufferSizes(
+                        rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes
+                    );
+                } else {
+                    BlockwiseQuantizer<float, 256, 4, false>::quantizedBufferSizes(
+                        rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes
+                    );
+                }
+                break;
+
+            default:
+                // Only block size 16, 32, 64, 128, 256 are supported.
+                break;
+        }
+    }
+}
+

 template
 void
 MlasQuantizeBlockwise<float, 4>(
diff --git a/onnxruntime/core/mlas/lib/q4gemm.h b/onnxruntime/core/mlas/lib/q4gemm.h
index 1562f9c0b4236..b1b51dd53c4fc 100644
--- a/onnxruntime/core/mlas/lib/q4gemm.h
+++ b/onnxruntime/core/mlas/lib/q4gemm.h
@@ -90,7 +90,7 @@ MlasQ4GemmOperation(

         if (DataParams->OutputProcessor != nullptr) {
             DataParams->OutputProcessor->Process(
-                DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN,
+                DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN + n,
                 RowsHandled, CountN, ldc);
         }
diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp
new file mode 100644
index 0000000000000..f964b1affec31
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp
@@ -0,0 +1,144 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+    sqnbitgemm.cpp
+
+Abstract:
+
+    This module implements the float/quantized n-bit integer matrix
+    multiplication hardware agnostic entrypoint, MlasSQNBitGemmBatch.
+--*/
+
+#include "sqnbitgemm.h"
+
+namespace
+{
+
+// Get quantization variant based on `BlkBitWidth` and `BlkLen`.
+// Return -1 if the input values are unsupported.
+int32_t
+GetDispatchQuantVariant(size_t BlkBitWidth, size_t BlkLen)
+{
+    int32_t type = -1;
+    if (BlkBitWidth == 4 && BlkLen == 16) {
+        type = QuantVariant_BitWidth4_BlockSize16;
+    } else if (BlkBitWidth == 4 && BlkLen == 32) {
+        type = QuantVariant_BitWidth4_BlockSize32;
+    } else if (BlkBitWidth == 4 && BlkLen == 64) {
+        type = QuantVariant_BitWidth4_BlockSize64;
+    } else if (BlkBitWidth == 4 && BlkLen == 128) {
+        type = QuantVariant_BitWidth4_BlockSize128;
+    } else if (BlkBitWidth == 4 && BlkLen == 256) {
+        type = QuantVariant_BitWidth4_BlockSize256;
+    }
+
+    return type;
+}
+
+}  // namespace
+
+void MLASCALL
+MlasSQNBitGemmBatch(
+    const size_t M,
+    const size_t N,
+    const size_t K,
+    const size_t BatchN,
+    const size_t BlkBitWidth,
+    const size_t BlkLen,
+    const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams,
+    MLAS_THREADPOOL* ThreadPool
+)
+{
+    const int32_t QuantVariant = GetDispatchQuantVariant(BlkBitWidth, BlkLen);
+    MLAS_SQNBIT_GEMM_OPERATION* const Operation = GetMlasPlatform().SQNBitGemmDispatch->Operations[QuantVariant];
+
+    if (ThreadPool == nullptr) {
+        for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) {
+            auto Data = &DataParams[gemm_i];
+            Operation(K, Data, 0, M, 0, N);
+        }
+        return;
+    }
+
+    //
+    // Compute the number of target threads given the complexity of the SGEMM
+    // operation. Small requests should run using the single threaded path.
+    //
+
+    const double Complexity = double(M) * double(N) * double(K) * double(BatchN);
+
+    ptrdiff_t TargetThreadCount = ptrdiff_t(Complexity / double(MLAS_QGEMM_THREAD_COMPLEXITY)) + 1;
+
+    ptrdiff_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool) * 8;
+
+    if (TargetThreadCount >= MaximumThreadCount) {
+        TargetThreadCount = MaximumThreadCount;
+    }
+
+    ptrdiff_t ThreadsPerGemm = TargetThreadCount / BatchN;
+    if (ThreadsPerGemm < 1) {
+        ThreadsPerGemm = 1;
+    }
+
+    constexpr size_t StrideM = 128;
+
+    size_t nc = N;
+    if (ThreadsPerGemm > 1) {
+        // more than one thread per GEMM
+
+        const size_t BlockedM = MlasDivRoundup(M, StrideM);
+        const size_t max_nc = MlasDivRoundup(N * BlockedM, ThreadsPerGemm);
+        if (max_nc < nc) {
+            nc = std::min(
+                nc, MlasDivRoundup(max_nc, MLAS_QGEMM_STRIDEN_THREAD_ALIGN) *
+                        MLAS_QGEMM_STRIDEN_THREAD_ALIGN
+            );
+        }
+    }
+    const size_t StrideN = nc;
+
+    const size_t ThreadCountM = MlasDivRoundup(M, StrideM);
+    const size_t ThreadCountN = MlasDivRoundup(N, StrideN);
+    ThreadsPerGemm = ThreadCountM * ThreadCountN;
+
+    MlasTrySimpleParallel(ThreadPool, ThreadsPerGemm * BatchN, [&](ptrdiff_t tid) {
+        const auto gemm_i = tid / ThreadsPerGemm;
+        const auto blk_i = tid % ThreadsPerGemm;
+        auto Data = &DataParams[gemm_i];
+
+        const ptrdiff_t ThreadIdN = blk_i / ThreadCountM;
+        const ptrdiff_t ThreadIdM = blk_i % ThreadCountM;
+
+        const size_t RangeStartM = ThreadIdM * StrideM;
+        const size_t RangeCountM = std::min(M - RangeStartM, (size_t)StrideM);
+
+        const size_t RangeStartN = ThreadIdN * StrideN;
+        const size_t RangeCountN = std::min(N - RangeStartN, (size_t)StrideN);
+
+        Operation(K, Data, RangeStartM, RangeCountM, RangeStartN, RangeCountN);
+    });
+}
+
+bool MLASCALL
+MlasIsSQNBitGemmAvailable(
+    size_t BlkBitWidth,
+    size_t BlkLen
+)
+{
+    const int32_t QuantVariant = GetDispatchQuantVariant(BlkBitWidth, BlkLen);
+    if (QuantVariant == -1) {
+        return false;
+    }
+
+    if (GetMlasPlatform().SQNBitGemmDispatch == nullptr ||
+        GetMlasPlatform().SQNBitGemmDispatch->Operations[QuantVariant] == nullptr) {
+        return false;
+    }
+
+    return true;
+}
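To make the work partitioning concrete (editor's example; assumes the usual MLAS values MLAS_QGEMM_THREAD_COMPLEXITY = 65536 and MLAS_QGEMM_STRIDEN_THREAD_ALIGN = 16):

// M = 256, N = 4096, K = 4096, BatchN = 1, 16-thread pool:
//   Complexity     = 256 * 4096 * 4096 ~ 4.3e9 -> TargetThreadCount clamps to 16 * 8 = 128
//   ThreadsPerGemm = 128
//   BlockedM       = ceil(256 / 128)      = 2   (StrideM = 128)
//   max_nc         = ceil(4096 * 2 / 128) = 64 -> StrideN = 64 (already 16-aligned)
//   ThreadCountM   = 2, ThreadCountN = 64 -> 128 work items handed to MlasTrySimpleParallel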
diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.h b/onnxruntime/core/mlas/lib/sqnbitgemm.h
new file mode 100644
index 0000000000000..f8f7dcd43699f
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/sqnbitgemm.h
@@ -0,0 +1,287 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+    sqnbitgemm.h
+
+Abstract:
+
+    This module includes:
+
+    - Declaration of the set of template functions used to implement a kernel
+      for a matrix/matrix multiplication, A*B, where A is a float matrix and B is
+      a n-bit quantized integer matrix (QNBitGemm).
+
+    - A shared kernel driver function template, MlasSQNBitGemmOperation.
+
+    - Kernel dispatch structure.
+
+    The B matrix is block quantized, which means that its values are grouped
+    into blocks which each have one scale and optional zero point. Each
+    quantized value in B is n-bits wide.
+
+--*/
+
+#pragma once
+
+#include "mlas_qnbit.h"
+#include "mlasi.h"
+
+//
+// Kernel implementation template declarations
+//
+
+/**
+ * @brief Multiply float matrix A with quantized n-bit integer matrix B.
+ *        B is block quantized and column major.
+ *        This kernel handles the special case where M, the number of rows of A and C, is 1.
+ *
+ * @tparam BlkBitWidth  Bit width of each value in a block.
+ * @tparam BlkLen       Number of values in a block.
+ * @tparam KernelType   Hardware-specific kernel type.
+ *
+ * @param       A                   Supplies the A matrix.
+ * @param       QuantBData          Supplies the quantized B matrix block data.
+ * @param       QuantBScale         Supplies the quantized B matrix block scale values.
+ * @param       QuantBZeroPoint     Supplies the quantized B matrix block zero point values. Optional.
+ * @param[out]  C                   Supplies the output C matrix.
+ * @param       CountN              Number of columns of B and C.
+ * @param       CountK              Number of columns of A and rows of B.
+ * @param       BlockStrideQuantB   Number of blocks between adjacent columns of the quantized B matrix.
+ * @param       Bias                Bias vector of length N.
+ */
+template <size_t BlkBitWidth, size_t BlkLen, typename KernelType>
+MLAS_FORCEINLINE void
+MlasSQNBitGemmM1Kernel(
+    const float* A,
+    const uint8_t* QuantBData,
+    const float* QuantBScale,
+    const uint8_t* QuantBZeroPoint,
+    float* C,
+    size_t CountN,
+    size_t CountK,
+    size_t BlockStrideQuantB,
+    const float* Bias
+);
+
+/**
+ * @brief Dequantize B into the format expected by the Sgemm kernel.
+ *        B is block quantized and column major.
+ *        This is equivalent to dequantizing B and then running
+ *        MlasSgemmCopyPackB.
+ *
+ * @tparam BlkBitWidth  Bit width of each value in a block.
+ * @tparam BlkLen       Number of values in a block.
+ * @tparam KernelType   Hardware-specific kernel type.
+ *
+ * @param[out]  FpData              Supplies the output buffer for the dequantized B float data.
+ * @param       QuantBData          Supplies the quantized B matrix block data.
+ * @param       QuantBScale         Supplies the quantized B matrix block scale values.
+ * @param       QuantBZeroPoint     Supplies the quantized B matrix block zero point values. Optional.
+ * @param       CountN              Number of columns of B.
+ * @param       CountK              Number of rows of B.
+ * @param       BlockStrideQuantB   Number of blocks between adjacent columns of the quantized B matrix.
+ */
+template <size_t BlkBitWidth, size_t BlkLen, typename KernelType>
+MLAS_FORCEINLINE void
+MlasQNBitBlkDequantBForSgemm(
+    float* FpData,
+    const uint8_t* QuantBData,
+    const float* QuantBScale,
+    const uint8_t* QuantBZeroPoint,
+    size_t CountN,
+    size_t CountK,
+    size_t BlockStrideQuantB
+);
+
+//
+// MlasQNBitGemmOperation and helpers
+//
+
+constexpr MLAS_FORCEINLINE size_t
+MlasQNBitBlkDataSizeInBytes(size_t BlkBitWidth, size_t BlkLen)
+{
+    return BlkLen * BlkBitWidth / 8;
+}
+
+template <size_t BlkBitWidth>
+constexpr MLAS_FORCEINLINE size_t
+MlasQNBitZeroPointsForBlksSizeInBytes(size_t BlkCount)
+{
+    if constexpr (BlkBitWidth <= 4) {
+        return MlasDivRoundup(BlkCount, 2);  // 2 blocks per byte
+    } else {
+        return BlkCount;
+    }
+}
+
+MLAS_FORCEINLINE void
+MlasAddBiasForGemm(const float* Bias, float* C, size_t CountM, size_t CountN, size_t ldc)
+{
+    for (size_t m = 0; m < CountM; m++) {
+        const float* bias = Bias;
+        float* sum = C;
+        for (size_t n = 0; n < CountN; n += 4) {
+            if (CountN - n < 4) {
+                for (size_t nn = n; nn < CountN; nn++) {
+                    *sum += *bias;
+                    sum++;
+                    bias++;
+                }
+                break;
+            }
+
+            MLAS_FLOAT32X4 acc_x = MlasLoadFloat32x4(sum);
+            acc_x = MlasAddFloat32x4(acc_x, MlasLoadFloat32x4(bias));
+            MlasStoreFloat32x4(sum, acc_x);
+            bias += 4;
+            sum += 4;
+        }
+        C += ldc;
+    }
+}
+
+template <size_t BlkBitWidth, size_t BlkLen, typename KernelType>
+MLAS_FORCEINLINE void MLASCALL
+MlasSQNBitGemmOperation(
+    const size_t K,
+    const MLAS_SQNBIT_GEMM_DATA_PARAMS* const DataParams,
+    const size_t RangeStartM,
+    const size_t RangeCountM,
+    const size_t RangeStartN,
+    const size_t RangeCountN
+)
+{
+    const size_t lda = DataParams->lda;
+    const size_t ldc = DataParams->ldc;
+
+    const size_t k_blks = MlasDivRoundup(K, BlkLen);
+    const size_t ldb = k_blks * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
+    const size_t k_blks_zp_bytes = MlasQNBitZeroPointsForBlksSizeInBytes<BlkBitWidth>(k_blks);
+
+    const float* A = DataParams->A + RangeStartM * lda;
+
+    const uint8_t* QuantBData = static_cast<const uint8_t*>(DataParams->QuantBData) + RangeStartN * ldb;
+    const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks;
+    const uint8_t* QuantBZeroPoint =
+        (DataParams->QuantBZeroPoint == nullptr)
+            ? nullptr
+            : static_cast<const uint8_t*>(DataParams->QuantBZeroPoint) + RangeStartN * k_blks_zp_bytes;
+
+    float* C = DataParams->C + RangeStartM * ldc + RangeStartN;
+
+    const float* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias + RangeStartN;
+
+    if (RangeCountM == 1) {
+        size_t CountN;
+        for (size_t n = 0; n < RangeCountN; n += CountN) {
+            CountN = std::min(RangeCountN - n, size_t{128});
+
+            const float* a_row = A;
+            const uint8_t* b_col = QuantBData + n * ldb;
+            const float* b_col_scale = QuantBScale + n * k_blks;
+            const uint8_t* b_col_zp =
+                (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes;
+            float* c_blk = C + n;
+            const float* bias = (Bias == nullptr) ? nullptr : Bias + n;
+
+            MlasSQNBitGemmM1Kernel<BlkBitWidth, BlkLen, KernelType>(
+                a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias
+            );
+
+            if (DataParams->PostProcessor != nullptr) {
+                DataParams->PostProcessor->Process(
+                    DataParams->C, RangeStartM, RangeStartN + n,
+                    RangeCountM, CountN, ldc
+                );
+            }
+        }
+        return;
+    }
+
+    constexpr size_t StrideN = 32;
+    size_t bufsize = k_blks * BlkLen * StrideN * sizeof(float);
+    MlasThreadedBufAlloc(bufsize);
+    auto* dequant_b = reinterpret_cast<float*>(ThreadedBufHolder.get());
+    //
+    // Step through each slice of matrix B along the N dimension.
+    //
+
+    size_t CountN;
+    for (size_t n = 0; n < RangeCountN; n += CountN) {
+        CountN = std::min(RangeCountN - n, StrideN);
+
+        //
+        // Step through each slice of matrix A along the M dimension.
+        //
+        const float* a_row = A;
+        const uint8_t* b_col = QuantBData + n * ldb;
+        const float* b_col_scale = QuantBScale + n * k_blks;
+        const uint8_t* b_col_zp =
+            (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes;
+        float* c_blk = C + n;
+        const float* bias = (Bias == nullptr) ? nullptr : Bias + n;
+
+        MlasQNBitBlkDequantBForSgemm<BlkBitWidth, BlkLen, KernelType>(
+            dequant_b, b_col, b_col_scale, b_col_zp, CountN, K, k_blks
+        );
+
+        size_t RowsRemaining = RangeCountM;
+        while (RowsRemaining > 0) {
+#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER)
+            auto RowsHandled = GetMlasPlatform().GemmFloatKernel(
+                a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f, true
+            );
+#else
+            auto RowsHandled = MlasSgemmKernelZero(a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f);
+#endif
+
+            if (bias) {
+                MlasAddBiasForGemm(bias, c_blk, RowsHandled, CountN, ldc);
+            }
+            if (DataParams->PostProcessor != nullptr) {
+                DataParams->PostProcessor->Process(
+                    DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN,
+                    RowsHandled, CountN, ldc
+                );
+            }
+
+            c_blk += ldc * RowsHandled;
+            a_row += lda * RowsHandled;
+            RowsRemaining -= RowsHandled;
+        }
+    }
+}
+
+//
+// Kernel dispatch structure.
+//
+
+typedef void(MLASCALL MLAS_SQNBIT_GEMM_OPERATION)(
+    size_t K,
+    const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams,
+    size_t RangeStartM,
+    size_t RangeCountM,
+    size_t RangeStartN,
+    size_t RangeCountN
+);
+
+enum QuantVariant {
+    QuantVariant_BitWidth4_BlockSize16,
+    QuantVariant_BitWidth4_BlockSize32,
+    QuantVariant_BitWidth4_BlockSize64,
+    QuantVariant_BitWidth4_BlockSize128,
+    QuantVariant_BitWidth4_BlockSize256,
+    QuantVariantCount,  // Keep this element last and ensure that its value is the number of other QuantVariant values.
+                        // Its value is used as an array size.
+};
+
+struct MLAS_SQNBIT_GEMM_DISPATCH {
+    MLAS_SQNBIT_GEMM_OPERATION* Operations[QuantVariantCount] = {
+        // Initialized to nullptrs. Overwrite in hardware-specific kernel implementation.
+    };
+};
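The helper arithmetic above is easiest to see with numbers (editor's illustration, not part of the change):

// BlkBitWidth = 4, BlkLen = 32, K = 4096:
//   k_blks          = ceil(4096 / 32) = 128 blocks per column
//   bytes per block = MlasQNBitBlkDataSizeInBytes(4, 32) = 32 * 4 / 8 = 16
//   ldb             = 128 * 16 = 2048 bytes between adjacent B columns
//   zero points     = MlasQNBitZeroPointsForBlksSizeInBytes<4>(128) = 64 bytes per column
//                     (two 4-bit zero points per byte: block 2i in the low nibble,
//                      block 2i+1 in the high nibble)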
diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp
new file mode 100644
index 0000000000000..63afe57dd9137
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp
@@ -0,0 +1,489 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+    sqnbitgemm_kernel_neon.cpp
+
+Abstract:
+
+    This module implements the float/quantized n-bit integer matrix
+    multiplication kernels for ARM NEON.
+
+--*/
+
+#include "sqnbitgemm.h"
+
+#include <arm_neon.h>
+
+#include <algorithm>
+#include <cassert>
+#include <utility>
+
+//
+// Hardware-specific kernel type.
+//
+struct MLAS_SQNBIT_GEMM_KERNEL_NEON {
+};
+
+namespace
+{
+
+template <typename IterationFn, size_t... Indices>
+MLAS_FORCEINLINE void
+UnrolledLoopIterations(IterationFn&& f, std::index_sequence<Indices...> /* indices */)
+{
+    (f(Indices), ...);
+}
+
+template <size_t Iterations, typename IterationFn>
+MLAS_FORCEINLINE void
+UnrolledLoop(IterationFn&& f)
+{
+    UnrolledLoopIterations(std::forward<IterationFn>(f), std::make_index_sequence<Iterations>());
+}
+
+MLAS_FORCEINLINE float32x4_t
+FoldAccumulators(float32x4_t a0, float32x4_t a1, float32x4_t a2, float32x4_t a3)
+{
+    // aN: aN_0 aN_1 aN_2 aN_3
+
+    float32x4_t b0 = vzip1q_f32(a0, a1);  // a0_0 a1_0 a0_1 a1_1
+    float32x4_t b1 = vzip2q_f32(a0, a1);  // a0_2 a1_2 a0_3 a1_3
+    float32x4_t b2 = vzip1q_f32(a2, a3);  // a2_0 a3_0 a2_1 a3_1
+    float32x4_t b3 = vzip2q_f32(a2, a3);  // a2_2 a3_2 a2_3 a3_3
+
+    // a0_0 a1_0 a2_0 a3_0
+    a0 = vreinterpretq_f32_f64(vzip1q_f64(vreinterpretq_f64_f32(b0), vreinterpretq_f64_f32(b2)));
+    // a0_1 a1_1 a2_1 a3_1
+    a1 = vreinterpretq_f32_f64(vzip2q_f64(vreinterpretq_f64_f32(b0), vreinterpretq_f64_f32(b2)));
+    // a0_2 a1_2 a3_2 a3_2
+    a2 = vreinterpretq_f32_f64(vzip1q_f64(vreinterpretq_f64_f32(b1), vreinterpretq_f64_f32(b3)));
+    // a0_3 a1_3 a2_3 a3_3
+    a3 = vreinterpretq_f32_f64(vzip2q_f64(vreinterpretq_f64_f32(b1), vreinterpretq_f64_f32(b3)));
+
+    return vaddq_f32(vaddq_f32(a0, a1), vaddq_f32(a2, a3));
+}
+
+template <size_t Capacity>
+MLAS_FORCEINLINE void
+LoadData(const float* src, size_t count, float32x4_t (&dst)[Capacity / 4])
+{
+    static_assert(Capacity % 4 == 0, "Capacity must be divisible by 4.");
+
+    assert(count <= Capacity);
+
+    size_t vi = 0;  // vector index
+
+    // handle 4 values at a time
+    while (count > 3) {
+        dst[vi] = vld1q_f32(src);
+
+        vi += 1;
+        src += 4;
+        count -= 4;
+    }
+
+    // handle remaining values
+    if (count > 0) {
+        dst[vi] = vsetq_lane_f32(src[0], dst[vi], 0);
+
+        if (count > 1) {
+            dst[vi] = vsetq_lane_f32(src[1], dst[vi], 1);
+
+            if (count > 2) {
+                dst[vi] = vsetq_lane_f32(src[2], dst[vi], 2);
+            }
+        }
+    }
+}
+
+template <size_t NCols, size_t BlkBitWidth, size_t BlkLen>
+MLAS_FORCEINLINE void
+ComputeDotProducts(
+    const float* ARowPtr,
+    const uint8_t* QuantBDataColPtr,
+    const float* QuantBScaleColPtr,
+    const uint8_t* QuantBZeroPointColPtr,
+    float* SumPtr,
+    size_t CountK,
+    size_t StrideQuantBData,
+    size_t StrideQuantBScale,
+    size_t StrideQuantBZeroPoint,
+    const float* BiasPtr
+)
+{
+    static_assert(NCols == 1 || NCols == 4, "NCols must be 1 or 4");
+
+    const uint8x8_t LowMask = vdup_n_u8(0x0F);
+
+    // Manual conversion to float takes place in two steps:
+    // 1. Map 4-bit values from [0, 15] to float values from [16.0f, 31.0f].
+    //    This target float range is convenient because the 4-bit source values can be placed directly into the
+    //    target float bits.
+    // 2. Subtract the conversion offset of 16 from the float result.
+
+    // The high 16 bits of an IEEE 754 32-bit float used as a template for creating float values.
+    constexpr uint16_t float_high_half_template = 0b0'10000011'0000000;
+    //                                           sign|exponent|partial mantissa
+    //                                              +|131: 2^4|~~~~ <- 4 bits go here
+
+    const uint16x8_t float_high_half_template_v = vdupq_n_u16(float_high_half_template);
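    // Editor's worked example of the conversion, for the quantized value 0b1010 (10):
    //   high half = float_high_half_template | (10 << 3) = 0x41D0, i.e. the bits of 26.0f
    //   (exponent 131 -> 2^4 = 16; the mantissa contributes 10/16, so 16 * (1 + 10/16) = 26.0f)
    //   subtracting the offset 16.0f then recovers 10.0f.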
+
+    float32x4_t acc[NCols]{};
+
+    const uint8_t* QuantBData = QuantBDataColPtr;
+    const float* QuantBScale = QuantBScaleColPtr;
+    size_t QuantBZeroPointIdx = 0;  // track half byte increments with this index instead of a pointer
+
+    for (size_t k = 0; k < CountK; k += BlkLen) {
+        const size_t k_blk_len = std::min(CountK - k, BlkLen);
+
+        float scale[NCols];
+        UnrolledLoop<NCols>(
+            [&](size_t i) { scale[i] = QuantBScale[i * StrideQuantBScale]; }
+        );
+
+        float offset[NCols];  // Includes zero point and float conversion offset of 16.
+        if (QuantBZeroPointColPtr != nullptr) {
+            UnrolledLoop<NCols>([&](size_t i) {
+                const uint8_t zp_packed =
+                    QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2];
+                const uint8_t zp = ((QuantBZeroPointIdx & 1) == 1) ? (zp_packed >> 4) : (zp_packed & 0x0F);
+                offset[i] = 16.0f + zp;
+            });
+        } else {
+            UnrolledLoop<NCols>([&](size_t i) {
+                constexpr float zp = 8.0f;
+                offset[i] = 16.0f + zp;
+            });
+        }
+
+        constexpr size_t SubBlkLen = 16;  // number of block elements to process in one iteration
+
+        for (size_t k_idx_in_blk = 0; k_idx_in_blk < k_blk_len; k_idx_in_blk += SubBlkLen) {
+            // load A row vector elements
+
+            // load `SubBlkLen` elements from A, padded with 0's if there aren't enough
+            const size_t k_subblk_len = std::min(k_blk_len - k_idx_in_blk, SubBlkLen);
+            float32x4_t av[4]{};
+            LoadData<SubBlkLen>(ARowPtr + k + k_idx_in_blk, k_subblk_len, av);
+
+            // load B column vectors
+            uint8x8_t bv_packed[NCols];
+            UnrolledLoop<NCols>([&](size_t i) {
+                const size_t b_data_block_offset = k_idx_in_blk * BlkBitWidth / 8;
+                bv_packed[i] = vld1_u8(QuantBData + i * StrideQuantBData + b_data_block_offset);
+            });
+
+            uint8x8_t bv_u8_unzipped[NCols][2];
+            UnrolledLoop<NCols>([&](size_t i) {
+                bv_u8_unzipped[i][0] = vand_u8(bv_packed[i], LowMask);
+                bv_u8_unzipped[i][1] = vand_u8(vshr_n_u8(bv_packed[i], 4), LowMask);
+            });
+
+            uint8x8_t bv_u8[NCols][2];
+            UnrolledLoop<NCols>([&](size_t i) {
+                bv_u8[i][0] = vzip1_u8(bv_u8_unzipped[i][0], bv_u8_unzipped[i][1]);
+                bv_u8[i][1] = vzip2_u8(bv_u8_unzipped[i][0], bv_u8_unzipped[i][1]);
+            });
+
+            // dequantize B
+
+            // shift left 3 and widen to 16 bits
+            uint16x8_t bv_u16[NCols][2];
+            UnrolledLoop<NCols>([&](size_t i) {
+                constexpr int shift = 3;
+                bv_u16[i][0] = vshll_n_u8(bv_u8[i][0], shift);
+                bv_u16[i][1] = vshll_n_u8(bv_u8[i][1], shift);
+            });
+
+            // combine 4 bits with float high half template
+            UnrolledLoop<NCols>([&](size_t i) {
+                bv_u16[i][0] = vorrq_u16(bv_u16[i][0], float_high_half_template_v);
+                bv_u16[i][1] = vorrq_u16(bv_u16[i][1], float_high_half_template_v);
+            });
+
+            // `SubBlkLen` floats of B
+            float32x4_t bv[NCols][4];
+
+            // shift left 16, widen to 32 bits, and reinterpret as float
+            UnrolledLoop<NCols>([&](size_t i) {
+                constexpr int shift = 16;
+                bv[i][0] = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(bv_u16[i][0]), shift));
+                bv[i][1] = vreinterpretq_f32_u32(vshll_high_n_u16(bv_u16[i][0], shift));
+
+                bv[i][2] = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(bv_u16[i][1]), shift));
+                bv[i][3] = vreinterpretq_f32_u32(vshll_high_n_u16(bv_u16[i][1], shift));
+            });
+
+            // subtract float conversion offset (16) and zero point
+            UnrolledLoop<NCols>([&](size_t i) {
+                const float32x4_t offset_v = vdupq_n_f32(offset[i]);
+                UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); });
+            });
+
+            // multiply by scale
+            UnrolledLoop<NCols>([&](size_t i) {
+                const float32x4_t scale_v = vdupq_n_f32(scale[i]);
+                UnrolledLoop<4>([&](size_t j) { bv[i][j] = vmulq_f32(bv[i][j], scale_v); });
+            });
+
+            // c[m,n] += a[m,k] * b[k,n]
+            UnrolledLoop<4>([&](size_t j) {
+                UnrolledLoop<NCols>([&](size_t i) { acc[i] = vfmaq_f32(acc[i], av[j], bv[i][j]); });
+            });
+        }
+
+        // increment pointers to next block
+        QuantBData += MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
+        QuantBScale += 1;
+        QuantBZeroPointIdx += 1;
+    }
+
+    if constexpr (NCols == 4) {
+        float32x4_t sum = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]);
+
+        if (BiasPtr != nullptr) {
+            sum = vaddq_f32(sum, vld1q_f32(BiasPtr));
+        }
+
+        vst1q_f32(SumPtr, sum);
+    } else {
+        for (size_t i = 0; i < NCols; ++i) {
+            SumPtr[i] = vaddvq_f32(acc[i]);
+            if (BiasPtr != nullptr) {
+                SumPtr[i] += BiasPtr[i];
+            }
+        }
+    }
+}
+
+}  // namespace
+
+//
+// MlasSQNBitGemmKernel and helpers.
+//
+
+template <size_t BlkBitWidth, size_t BlkLen>
+MLAS_FORCEINLINE void
+MlasSQNBitGemmM1KernelNeon(
+    const float* A,
+    const uint8_t* QuantBData,
+    const float* QuantBScale,
+    const uint8_t* QuantBZeroPoint,
+    float* C,
+    size_t CountN,
+    size_t CountK,
+    size_t BlockStrideQuantB,
+    const float* Bias
+)
+{
+    constexpr size_t NCols = 4;
+
+    const float* ARowPtr = A;
+    float* CRowPtr = C;
+
+    const size_t BlockCountK = BlockStrideQuantB;
+
+    const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
+    const size_t StrideQuantBScale = BlockCountK;
+    const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes<BlkBitWidth>(BlockCountK);
+
+    const float* BiasPtr = Bias;
+
+    const uint8_t* QuantBDataColPtr = QuantBData;
+    const float* QuantBScaleColPtr = QuantBScale;
+    const uint8_t* QuantBZeroPointColPtr = QuantBZeroPoint;
+
+    float* SumPtr = CRowPtr;
+
+    int64_t nblk = static_cast<int64_t>(CountN) - NCols;
+
+    while (nblk >= 0) {
+        ComputeDotProducts<NCols, BlkBitWidth, BlkLen>(
+            ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK,
+            StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint,
+            BiasPtr
+        );
+
+        // move to next `NCols` columns
+
+        QuantBDataColPtr += NCols * StrideQuantBData;
+        QuantBScaleColPtr += NCols * StrideQuantBScale;
+        if (QuantBZeroPointColPtr != nullptr) {
+            QuantBZeroPointColPtr += NCols * StrideQuantBZeroPoint;
+        }
+
+        BiasPtr += BiasPtr != nullptr ? NCols : 0;
+        SumPtr += NCols;
+
+        nblk -= NCols;
+    }
+
+    // left over columns less than `NCols`?
+    nblk += NCols;
+    for (int64_t n = 0; n < nblk; ++n) {
+        ComputeDotProducts<1, BlkBitWidth, BlkLen>(
+            ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK,
+            StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint,
+            BiasPtr
+        );
+
+        // move to next column
+
+        QuantBDataColPtr += StrideQuantBData;
+        QuantBScaleColPtr += StrideQuantBScale;
+        if (QuantBZeroPointColPtr != nullptr) {
+            QuantBZeroPointColPtr += StrideQuantBZeroPoint;
+        }
+
+        BiasPtr += BiasPtr != nullptr ? 1 : 0;
+        SumPtr += 1;
+    }
+}
+
+#define SPECIALIZE_SQNBIT_GEMM_M1_KERNEL(BlkBitWidth, BlkLen)                          \
+    template <>                                                                        \
+    MLAS_FORCEINLINE void                                                              \
+    MlasSQNBitGemmM1Kernel<BlkBitWidth, BlkLen, MLAS_SQNBIT_GEMM_KERNEL_NEON>(         \
+        const float* A,                                                                \
+        const uint8_t* QuantBData,                                                     \
+        const float* QuantBScale,                                                      \
+        const uint8_t* QuantBZeroPoint,                                                \
+        float* C,                                                                      \
+        size_t CountN,                                                                 \
+        size_t CountK,                                                                 \
+        size_t BlockStrideQuantB,                                                      \
+        const float* Bias                                                              \
+    )                                                                                  \
+    {                                                                                  \
+        return MlasSQNBitGemmM1KernelNeon<BlkBitWidth, BlkLen>(                        \
+            A, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, CountK,            \
+            BlockStrideQuantB, Bias                                                    \
+        );                                                                             \
+    }
+
+SPECIALIZE_SQNBIT_GEMM_M1_KERNEL(4, 16)
+SPECIALIZE_SQNBIT_GEMM_M1_KERNEL(4, 32)
+SPECIALIZE_SQNBIT_GEMM_M1_KERNEL(4, 64)
+SPECIALIZE_SQNBIT_GEMM_M1_KERNEL(4, 128)
+SPECIALIZE_SQNBIT_GEMM_M1_KERNEL(4, 256)
+
+#undef SPECIALIZE_SQNBIT_GEMM_M1_KERNEL
+
+//
+// MlasQNBitBlkDequantBForSgemm and helpers.
+//
+
+template <size_t BlkBitWidth, size_t BlkLen>
+MLAS_FORCEINLINE void
+MlasQNBitBlkDequantBForSgemmNeon(
+    float* FpData,
+    const uint8_t* QuantBData,
+    const float* QuantBScale,
+    const uint8_t* QuantBZeroPoint,
+    size_t CountN,
+    size_t CountK,
+    size_t BlockStrideQuantB
+)
+{
+    auto impl0_reference = [&]() {
+        static_assert(BlkBitWidth == 4);
+
+        float* Dst = FpData;
+
+        const uint8_t* QuantBDataCol = QuantBData;
+        const float* QuantBScaleCol = QuantBScale;
+        const uint8_t* QuantBZeroPointCol = QuantBZeroPoint;
+
+        for (size_t n = 0; n < CountN; n += 16) {
+            const size_t nnlen = std::min(CountN - n, size_t{16});
+
+            for (size_t nn = 0; nn < nnlen; ++nn) {
+                for (size_t k = 0, k_blk_idx = 0; k < CountK; k += BlkLen, k_blk_idx += 1) {
+                    const size_t kklen = std::min(CountK - k, BlkLen);
+
+                    const uint8_t* b_data =
+                        QuantBDataCol + k_blk_idx * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
+                    const float b_s = QuantBScaleCol[k_blk_idx];
+                    const uint8_t b_z =
+                        (QuantBZeroPointCol != nullptr)
+                            ? ((k_blk_idx & 1) == 1)
+                                  ? QuantBZeroPointCol[k_blk_idx / 2] >> 4
+                                  : QuantBZeroPointCol[k_blk_idx / 2] & 0x0F
+                            : 8;
+
+                    for (size_t kk = 0; kk < kklen; ++kk) {
+                        const uint8_t b_packed = b_data[kk / 2];
+                        const uint8_t b_byte = ((kk & 1) == 1) ? b_packed >> 4 : b_packed & 0x0F;
+                        const float b_value = (b_byte - b_z) * b_s;
+
+                        Dst[(k + kk) * 16 + nn] = b_value;
+                    }
+                }
+
+                QuantBDataCol += BlockStrideQuantB * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
+                QuantBScaleCol += BlockStrideQuantB;
+                if (QuantBZeroPointCol != nullptr) {
+                    QuantBZeroPointCol += MlasQNBitZeroPointsForBlksSizeInBytes<BlkBitWidth>(BlockStrideQuantB);
+                }
+            }
+
+            // zero out any remaining columns
+
+            if (nnlen < 16) {
+                for (size_t k = 0; k < CountK; ++k) {
+                    std::fill_n(Dst + (k * 16) + nnlen, 16 - nnlen, 0.0f);
+                }
+            }
+
+            Dst += CountK * 16;
+        }
+    };
+
+    impl0_reference();
+}
+
+#define SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(BlkBitWidth, BlkLen)                        \
+    template <>                                                                              \
+    MLAS_FORCEINLINE void                                                                    \
+    MlasQNBitBlkDequantBForSgemm<BlkBitWidth, BlkLen, MLAS_SQNBIT_GEMM_KERNEL_NEON>(         \
+        float* FpData,                                                                       \
+        const uint8_t* QuantBData,                                                           \
+        const float* QuantBScale,                                                            \
+        const uint8_t* QuantBZeroPoint,                                                      \
+        size_t CountN,                                                                       \
+        size_t CountK,                                                                       \
+        size_t BlockStrideQuantB                                                             \
+    )                                                                                        \
+    {                                                                                        \
+        MlasQNBitBlkDequantBForSgemmNeon<BlkBitWidth, BlkLen>(                               \
+            FpData, QuantBData, QuantBScale, QuantBZeroPoint, CountN, CountK, BlockStrideQuantB \
+        );                                                                                   \
+    }
+
+SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(4, 16)
+SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(4, 32)
+SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(4, 64)
+SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(4, 128)
+SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(4, 256)
+
+#undef SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM
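// Layout note (editor's): the reference dequantization above writes B in the 16-wide
// column-panel order the SGEMM kernels consume, i.e. element (k, n) of a panel lands at
// Dst[k * 16 + (n % 16)], one CountK x 16 panel per group of 16 columns - matching what
// MlasSgemmCopyPackB would produce from a fully dequantized B, per the header's contract.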
diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc
index 918ee0e6eb976..3c6217915bef0 100644
--- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc
+++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc
@@ -72,21 +72,17 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, bool has_zerop
   MlasTranspose(input1_f_vals.data(), input1_f_vals_trans.data(), K, N);
 #endif
 
-  int meta_rows;
-  int meta_cols;
-  MlasBlockwiseQuantMetaShape<float>((int)block_size, true, (int)K, (int)N, meta_rows, meta_cols);
+  int q_rows, q_cols;
+  MlasBlockwiseQuantizedShape<float, 4>((int)block_size, true, (int)K, (int)N, q_rows, q_cols);
 
-  int q_rows;
-  int q_cols;
-  MlasBlockwiseQuantizedShape<float>((int)block_size, true, (int)K, (int)N, q_rows, q_cols);
+  size_t q_data_size_in_bytes, q_scale_size, q_zp_size_in_bytes;
+  MlasBlockwiseQuantizedBufferSizes(4, static_cast<int>(block_size), /* columnwise */ true,
+                                    static_cast<int>(K), static_cast<int>(N),
+                                    q_data_size_in_bytes, q_scale_size, &q_zp_size_in_bytes);
 
-  std::vector<uint8_t> input1_vals(q_rows * q_cols);
-  std::vector<float> scales(meta_rows * meta_cols);
-
-  // TODO!! THIS SHOULD BE PROVIDED BY MLAS
-  // sub 8b packing always happen on the column dimension
-  const int packed_meta_rows = (meta_rows * QBits + 7) / 8;
-  std::vector<uint8_t> zp(packed_meta_rows * meta_cols);
+  std::vector<uint8_t> input1_vals(q_data_size_in_bytes);
+  std::vector<float> scales(q_scale_size);
+  std::vector<uint8_t> zp(q_zp_size_in_bytes);
 
   QuantizeDequantize(input1_f_vals,
                      input1_vals,
@@ -115,9 +111,9 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, bool has_zerop
   if (use_float16) {
     test.AddInput<MLFloat16>("A", {M, K}, ToFloat16(input0_vals), false);
     test.AddInput<uint8_t>("B", {q_cols, q_rows}, input1_vals, true);
-    test.AddInput<MLFloat16>("scales", {meta_cols * meta_rows}, ToFloat16(scales), true);
+    test.AddInput<MLFloat16>("scales", {static_cast<int64_t>(q_scale_size)}, ToFloat16(scales), true);
     if (has_zeropoint) {
-      test.AddInput<uint8_t>("zero_points", {meta_cols * packed_meta_rows}, zp, true);
+      test.AddInput<uint8_t>("zero_points", {static_cast<int64_t>(q_zp_size_in_bytes)}, zp, true);
     }
 
     test.AddOutput<MLFloat16>("Y", {M, N}, ToFloat16(expected_vals));
@@ -129,9 +125,9 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, bool has_zerop
   } else {
     test.AddInput<float>("A", {M, K}, input0_vals, false);
     test.AddInput<uint8_t>("B", {q_cols, q_rows}, input1_vals, true);
-    test.AddInput<float>("scales", {meta_cols * meta_rows}, scales, true);
+    test.AddInput<float>("scales", {static_cast<int64_t>(q_scale_size)}, scales, true);
     if (has_zeropoint) {
-      test.AddInput<uint8_t>("zero_points", {meta_cols * packed_meta_rows}, zp, true);
+      test.AddInput<uint8_t>("zero_points", {static_cast<int64_t>(q_zp_size_in_bytes)}, zp, true);
     }
 
     test.AddOutput<float>("Y", {M, N}, expected_vals);
diff --git a/onnxruntime/test/mlas/bench/bench_q4gemm.cpp b/onnxruntime/test/mlas/bench/bench_q4gemm.cpp
index cf02d4f3628f9..87e3601612761 100644
--- a/onnxruntime/test/mlas/bench/bench_q4gemm.cpp
+++ b/onnxruntime/test/mlas/bench/bench_q4gemm.cpp
@@ -33,7 +33,7 @@ void Q4GEMM(benchmark::State& state, MLAS_BLK_QUANT_TYPE qtype) {
   auto B1 = RandomVectorUniform(static_cast<size_t>(N * K), -1.0f, 1.0f);
   std::vector<float> C1(static_cast<size_t>(M * N));
 
-  std::vector<float> B1_packed(pack_b_size);
+  std::vector<int8_t> B1_packed(pack_b_size);
 
   MlasQ4GemmPackB(qtype, B1_packed.data(), B1.data(), N, K, N);
 
   MLAS_Q4_GEMM_DATA_PARAMS params1;
diff --git a/onnxruntime/test/mlas/bench/bench_sgemm.cpp b/onnxruntime/test/mlas/bench/bench_sgemm.cpp
index baa8f1a830ea1..e6e34bc88ad59 100644
--- a/onnxruntime/test/mlas/bench/bench_sgemm.cpp
+++ b/onnxruntime/test/mlas/bench/bench_sgemm.cpp
@@ -128,7 +128,7 @@ BENCHMARK_CAPTURE(SGEMM, PACKB_TransA, true, true, false)->Apply(GemmSizeProduct
 
 static void GemmLLMSizeProducts(benchmark::internal::Benchmark* b) {
   b->ArgNames(sgemm_bench_arg_names);
-  ArgsProduct(b, {{1, 1024, 2048}, {4096}, {4096}});
+  ArgsProduct(b, {{1, 1024, 2048}, {4096, 11008}, {4096, 11008}});
 }
 
 BENCHMARK_CAPTURE(SGEMM, LLM, false, false, true)->Apply(GemmLLMSizeProducts)->UseRealTime();
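The MlasBlockwiseQuantizedBufferSizes call adopted in the test changes above removes the hand-rolled size math the contrib-op test used to carry (including the packed zero-point row count it computed itself). For 4-bit, column-wise blocking, the returned quantities follow the packing scheme visible in these tests; the sketch below restates that arithmetic under those assumptions and is not the MLAS implementation:

    #include <cstddef>

    // Illustrative only: sizes for qbits == 4, column-wise blocking, where each
    // column of K elements is split into ceil(K / block_size) blocks.
    void BufferSizes4Bit(size_t block_size, size_t K, size_t N,
                         size_t& data_bytes, size_t& scale_count, size_t& zp_bytes) {
      const size_t blocks_per_col = (K + block_size - 1) / block_size;
      data_bytes = N * blocks_per_col * (block_size / 2);  // two 4-bit values per byte
      scale_count = N * blocks_per_col;                    // one float scale per block
      zp_bytes = N * ((blocks_per_col + 1) / 2);           // two 4-bit zero points per byte
    }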
diff --git a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp
new file mode 100644
index 0000000000000..2f2635dab0512
--- /dev/null
+++ b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp
@@ -0,0 +1,86 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "mlas_q4.h"
+#include "mlas_qnbit.h"
+
+#include <memory>
+
+#include "benchmark/benchmark.h"
+
+#include "bench_util.h"
+#include "core/util/thread_utils.h"
+
+template <size_t BlkBitWidth, size_t BlkLen, bool Symmetric>
+void SQNBITGEMM(benchmark::State& state) {
+  if (state.range(0) <= 0) throw std::invalid_argument("M must be greater than 0!");
+  if (state.range(1) <= 0) throw std::invalid_argument("N must be greater than 0!");
+  if (state.range(2) <= 0) throw std::invalid_argument("K must be greater than 0!");
+  if (state.range(3) <= 0) throw std::invalid_argument("Threads must be greater than 0!");
+
+  const size_t M = static_cast<size_t>(state.range(0));
+  const size_t N = static_cast<size_t>(state.range(1));
+  const size_t K = static_cast<size_t>(state.range(2));
+  const size_t threads = static_cast<size_t>(state.range(3));
+
+  size_t QuantBDataSizeInBytes, QuantBScaleSize, QuantBZeroPointSizeInBytes;
+  MlasBlockwiseQuantizedBufferSizes(
+      BlkBitWidth, BlkLen, /* columnwise */ true,
+      static_cast<int>(K), static_cast<int>(N),
+      QuantBDataSizeInBytes, QuantBScaleSize, &QuantBZeroPointSizeInBytes);
+
+  OrtThreadPoolParams tpo;
+  tpo.thread_pool_size = static_cast<int>(threads);
+  tpo.auto_set_affinity = true;
+
+  std::unique_ptr<onnxruntime::concurrency::ThreadPool> tp(
+      onnxruntime::concurrency::CreateThreadPool(&onnxruntime::Env::Default(),
+                                                 tpo, onnxruntime::concurrency::ThreadPoolType::INTRA_OP));
+
+  auto A = RandomVectorUniform(static_cast<size_t>(M * K), -1.0f, 1.0f);
+  auto B = RandomVectorUniform(static_cast<size_t>(K * N), -1.0f, 1.0f);
+  std::vector<float> C(static_cast<size_t>(M * N));
+
+  std::vector<uint8_t> QuantBData(QuantBDataSizeInBytes);
+  std::vector<float> QuantBScale(QuantBScaleSize);
+  std::vector<uint8_t> QuantBZeroPoint(Symmetric ? 0 : QuantBZeroPointSizeInBytes);
+
+  MlasQuantizeBlockwise<float, BlkBitWidth>(QuantBData.data(), QuantBScale.data(),
+                                            Symmetric ? nullptr : QuantBZeroPoint.data(),
+                                            B.data(), BlkLen, /* columnwise */ true,
+                                            static_cast<int>(K), static_cast<int>(N), static_cast<int>(N),
+                                            tp.get());
+
+  MLAS_SQNBIT_GEMM_DATA_PARAMS params{};
+  params.A = A.data();
+  params.lda = K;
+  params.QuantBData = QuantBData.data();
+  params.QuantBScale = QuantBScale.data();
+  params.QuantBZeroPoint = Symmetric ? nullptr : QuantBZeroPoint.data();
+  params.Bias = nullptr;
+  params.C = C.data();
+  params.ldc = N;
+
+  // warm up run
+  MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, &params, tp.get());
+
+  for (auto _ : state) {
+    MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, &params, tp.get());
+  }
+}
+
+static void GemmSizeProducts(benchmark::internal::Benchmark* b) {
+  b->ArgNames({"M", "N", "K", "Threads"});
+  ArgsProduct(b, {{1, 1024, 2048}, {4096, 11008}, {4096, 11008}, {8}});
+}
+
+BENCHMARK(SQNBITGEMM<4, 16, false>)->Apply(GemmSizeProducts)->UseRealTime();
+BENCHMARK(SQNBITGEMM<4, 16, true>)->Apply(GemmSizeProducts)->UseRealTime();
+BENCHMARK(SQNBITGEMM<4, 32, false>)->Apply(GemmSizeProducts)->UseRealTime();
+BENCHMARK(SQNBITGEMM<4, 32, true>)->Apply(GemmSizeProducts)->UseRealTime();
+BENCHMARK(SQNBITGEMM<4, 64, false>)->Apply(GemmSizeProducts)->UseRealTime();
+BENCHMARK(SQNBITGEMM<4, 64, true>)->Apply(GemmSizeProducts)->UseRealTime();
+BENCHMARK(SQNBITGEMM<4, 128, false>)->Apply(GemmSizeProducts)->UseRealTime();
+BENCHMARK(SQNBITGEMM<4, 128, true>)->Apply(GemmSizeProducts)->UseRealTime();
+BENCHMARK(SQNBITGEMM<4, 256, false>)->Apply(GemmSizeProducts)->UseRealTime();
+BENCHMARK(SQNBITGEMM<4, 256, true>)->Apply(GemmSizeProducts)->UseRealTime();
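bench_sqnbitgemm.cpp registers one benchmark per (bit width, block length, symmetric) combination by passing a function template instantiation straight to the variadic BENCHMARK macro. A minimal self-contained example of the same registration pattern (BM_Example is made up for illustration):

    #include <cstdint>

    #include "benchmark/benchmark.h"

    template <int BlkLen>
    void BM_Example(benchmark::State& state) {
      const int64_t n = state.range(0);
      for (auto _ : state) {
        int64_t acc = 0;
        for (int64_t i = 0; i < n; i += BlkLen) {
          acc += i;  // stand-in work whose cost depends on the template parameter
        }
        benchmark::DoNotOptimize(acc);
      }
    }

    BENCHMARK(BM_Example<32>)->Arg(1 << 20)->UseRealTime();
    BENCHMARK(BM_Example<64>)->Arg(1 << 20)->UseRealTime();

    BENCHMARK_MAIN();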
diff --git a/onnxruntime/test/mlas/unittest/test_blockq4.cpp b/onnxruntime/test/mlas/unittest/test_blockq4.cpp
index f836da8277bb8..07f0748fb7ed1 100644
--- a/onnxruntime/test/mlas/unittest/test_blockq4.cpp
+++ b/onnxruntime/test/mlas/unittest/test_blockq4.cpp
@@ -38,13 +38,17 @@ class MlasBlockwiseQdqTest : public MlasTestBase {
     int meta_rows;
     int meta_cols;
-    MlasBlockwiseQuantMetaShape<float>(block_size, columnwise, rows, columns, meta_rows, meta_cols);
+    MlasBlockwiseQuantMetaShape<float, 4>(block_size, columnwise, rows, columns, meta_rows, meta_cols);
 
     int q_rows;
     int q_cols;
-    MlasBlockwiseQuantizedShape<float>(block_size, columnwise, rows, columns, q_rows, q_cols);
+    MlasBlockwiseQuantizedShape<float, 4>(block_size, columnwise, rows, columns, q_rows, q_cols);
 
-    uint8_t* elements = InputElements.GetBuffer(q_rows * q_cols, true);
+    size_t q_data_size_in_bytes, q_scale_size, q_zp_size_in_bytes;
+    MlasBlockwiseQuantizedBufferSizes(4, block_size, columnwise, rows, columns,
+                                      q_data_size_in_bytes, q_scale_size, &q_zp_size_in_bytes);
+
+    uint8_t* elements = InputElements.GetBuffer(q_data_size_in_bytes, true);
 
     int v = 7;
     for (int c = 0; c < columns; c++) {
@@ -70,8 +74,8 @@ class MlasBlockwiseQdqTest : public MlasTestBase {
       }
     }
 
-    float* scales = InputScales.GetBuffer(meta_rows * meta_cols);
-    uint8_t* zp = symmetric ? nullptr : InputOffsets.GetBuffer(((meta_rows + 1) / 2) * meta_cols, true);
+    float* scales = InputScales.GetBuffer(q_scale_size);
+    uint8_t* zp = symmetric ? nullptr : InputOffsets.GetBuffer(q_zp_size_in_bytes, true);
     if (zp) {
       for (int c = 0; c < meta_cols; c++) {
         for (int r = 0; r < meta_rows; r += 2) {
diff --git a/onnxruntime/test/mlas/unittest/test_halfgemm.h b/onnxruntime/test/mlas/unittest/test_halfgemm.h
index 2861b0e746fdc..4db5c2bebca40 100644
--- a/onnxruntime/test/mlas/unittest/test_halfgemm.h
+++ b/onnxruntime/test/mlas/unittest/test_halfgemm.h
@@ -18,20 +18,6 @@ Module Name:
 
 #include "test_fp16.h"
 
-inline bool
-CloseEnough(float actual, float expected) {
-  if (std::isnan(actual)) {
-    return std::isnan(expected);
-  }
-  float diff = std::abs(actual - expected);
-  float top = std::max(std::abs(actual), std::abs(expected));
-  float ratio = 0;
-  if (top > 0.0001) {
-    ratio = diff / top;
-  }
-  return ratio < 0.005;
-}
-
 /**
  * @brief Test class for half precision GEMM
  * @tparam AType Data type of A matrix, can be either float or MLFp16
diff --git a/onnxruntime/test/mlas/unittest/test_main.cpp b/onnxruntime/test/mlas/unittest/test_main.cpp
index 66b5a6a15db2b..505c0c01dfa90 100644
--- a/onnxruntime/test/mlas/unittest/test_main.cpp
+++ b/onnxruntime/test/mlas/unittest/test_main.cpp
@@ -1,17 +1,18 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
-#include "test_util.h"
-
-#include <thread>
 #include <iostream>
+#include <memory>
+#include <thread>
+
+#include "test_util.h"
 
 #if !defined(BUILD_MLAS_NO_ONNXRUNTIME)
 MLAS_THREADPOOL* GetMlasThreadPool(void) {
-  static MLAS_THREADPOOL* threadpool = new onnxruntime::concurrency::ThreadPool(
+  static auto threadpool = std::make_unique<onnxruntime::concurrency::ThreadPool>(
       &onnxruntime::Env::Default(), onnxruntime::ThreadOptions(), nullptr, 2, true);
-  return threadpool;
+  return threadpool.get();
 }
 
 #else
diff --git a/onnxruntime/test/mlas/unittest/test_q4gemm.h b/onnxruntime/test/mlas/unittest/test_q4gemm.h
index 58a64491ae80b..97c6969b5bf91 100644
--- a/onnxruntime/test/mlas/unittest/test_q4gemm.h
+++ b/onnxruntime/test/mlas/unittest/test_q4gemm.h
@@ -19,20 +19,6 @@ Module Name:
 
 #include "test_util.h"
 #include "mlas_q4.h"
 
-inline bool
-CloseEnough(float actual, float expected) {
-  if (std::isnan(actual)) {
-    return std::isnan(expected);
-  }
-  float diff = std::abs(actual - expected);
-  float top = std::max(std::abs(actual), std::abs(expected));
-  float ratio = 0;
-  if (top > 0.0001) {
-    ratio = diff / top;
-  }
-  return ratio < 0.005;
-}
-
 /**
  * @brief Test class for int4 block quantized GEMM
  * Note: only 2-D matmul supported for now
diff --git a/onnxruntime/test/mlas/unittest/test_q8q4gemm.cpp b/onnxruntime/test/mlas/unittest/test_q8q4gemm.cpp
index a78a3261d1f2a..d3f601793a970 100644
--- a/onnxruntime/test/mlas/unittest/test_q8q4gemm.cpp
+++ b/onnxruntime/test/mlas/unittest/test_q8q4gemm.cpp
@@ -19,20 +19,6 @@ Module Name:
 
 #include "test_util.h"
 #include "mlas_q4.h"
 
-inline bool
-CloseEnough(float actual, float expected) {
-  if (std::isnan(actual)) {
-    return std::isnan(expected);
-  }
-  float diff = std::abs(actual - expected);
-  float top = std::max(std::abs(actual), std::abs(expected));
-  float ratio = 0;
-  if (top > 0.0001) {
-    ratio = diff / top;
-  }
-  return ratio < 0.005;
-}
-
 template <size_t QBlkLen>
 static void blkq8_dequant_reference(const int8_t* src, float* dst, size_t M, size_t K) {
   const size_t num_blks = K / QBlkLen;
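The test_main.cpp change above replaces a deliberately leaked ThreadPool with a function-local static std::unique_ptr: construction still happens lazily on first call, but the pool is now destroyed at process exit. The pattern in isolation (Pool is a stand-in type, not the ORT thread pool):

    #include <memory>

    struct Pool {
      explicit Pool(int thread_count) : threads(thread_count) {}
      int threads;
    };

    Pool* GetPool() {
      // Constructed once, on first use; destroyed during static teardown.
      static auto pool = std::make_unique<Pool>(2);
      return pool.get();
    }

    int main() { return GetPool()->threads == 2 ? 0 : 1; }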
diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp
new file mode 100644
index 0000000000000..6c97d60301573
--- /dev/null
+++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp
@@ -0,0 +1,270 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+    test_sqnbitgemm.cpp
+
+Abstract:
+
+    Tests for MLAS n-bit int block quantized GEMM.
+
+--*/
+
+#include "test_util.h"
+#include "mlas_q4.h"
+#include "mlas_qnbit.h"
+
+/**
+ * @brief Test class for n-bit int block quantized GEMM
+ * Note: only 2-D matmul supported for now
+ */
+template <size_t BlkBitWidth, size_t BlkLen>
+class MlasSQNBitGemmTest : public MlasTestBase {
+ private:
+  MatrixGuardBuffer<float> BufferA;
+  MatrixGuardBuffer<float> BufferB;
+  MatrixGuardBuffer<uint8_t> BufferQuantBData;
+  MatrixGuardBuffer<uint8_t> BufferQuantBZeroPoint;
+  MatrixGuardBuffer<float> BufferQuantBScale;
+  MatrixGuardBuffer<float> BufferDequantizedB;
+  MatrixGuardBuffer<float> BufferBias;
+  MatrixGuardBuffer<float> BufferC;
+  MatrixGuardBuffer<float> BufferCReference;
+
+  void CallGemm(size_t M,
+                size_t N,
+                size_t K,
+                const float* A,
+                size_t lda,
+                const uint8_t* QuantBData,
+                const float* QuantBScale,
+                const uint8_t* QuantBZeroPoint,
+                const float* Bias,
+                float* C,
+                size_t ldc,
+                MLAS_THREADPOOL* Threadpool) {
+    MLAS_SQNBIT_GEMM_DATA_PARAMS params;
+    params.A = A;
+    params.lda = lda;
+    params.Bias = Bias;
+    params.C = C;
+    params.ldc = ldc;
+    params.QuantBData = QuantBData;
+    params.QuantBScale = QuantBScale;
+    params.QuantBZeroPoint = QuantBZeroPoint;
+    params.PostProcessor = nullptr;
+
+    MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, &params, Threadpool);
+  }
+
+  void CallReferenceGemm(size_t M,
+                         size_t N,
+                         size_t K,
+                         const float* A,
+                         const uint8_t* QuantBData,
+                         const float* QuantBScale,
+                         const uint8_t* QuantBZeroPoint,
+                         const float* Bias,
+                         float* C) {
+    float* DequantizedBData = BufferDequantizedB.GetBuffer(K * N);
+    MlasDequantizeBlockwise<float, BlkBitWidth>(
+        DequantizedBData, QuantBData, QuantBScale, QuantBZeroPoint, BlkLen, /* columnwise */ true,
+        static_cast<int>(K), static_cast<int>(N), GetMlasThreadPool());
+    // Note: DequantizedBData is in column major layout.
+
+    for (size_t m = 0; m < M; m++) {
+      for (size_t n = 0; n < N; n++) {
+        const float* a = A + m * K;
+        const float* b = DequantizedBData + n * K;
+        float* c = C + (m * N) + n;
+
+        float sum = Bias == nullptr ? 0.0f : Bias[n];
+        for (size_t k = 0; k < K; k++) {
+          sum += (*a) * (*b);
+          b += 1;
+          a += 1;
+        }
+        *c = sum;
+      }
+    }
+  }
+
+ public:
+  void Test(size_t M, size_t N, size_t K,
+            bool WithThreadpool, bool Symmetric, bool WithBias) {
+    MLAS_THREADPOOL* Threadpool = WithThreadpool ? GetMlasThreadPool() : nullptr;
+
+    const float* A = BufferA.GetBuffer(K * M);
+
+    const float* B = BufferB.GetBuffer(N * K);
+
+    const float* Bias = nullptr;
+    if (WithBias) {
+      Bias = BufferBias.GetBuffer(N);
+    }
+
+#if 0
+    auto print_matrix = [](size_t ncols, size_t nrows, const float* data) {
+      for (size_t row = 0; row < nrows; ++row) {
+        for (size_t col = 0; col < ncols; ++col) {
+          std::cout << data[row * ncols + col] << "\t";
+        }
+        std::cout << "\n";
+      }
+    };
+
+    std::cout << "A:\n";
+    print_matrix(M, K, A);
+    std::cout << "B:\n";
+    print_matrix(K, N, B);
+#endif
+
+    float* C = BufferC.GetBuffer(N * M, true);
+    float* CReference = BufferCReference.GetBuffer(N * M, true);
+
+    // pack B
+    uint8_t* QuantBData = nullptr;
+    float* QuantBScale = nullptr;
+    uint8_t* QuantBZeroPoint = nullptr;
+    {
+      size_t QuantBDataSizeInBytes, QuantBScaleSize, QuantBZeroPointSizeInBytes;
+      MlasBlockwiseQuantizedBufferSizes(BlkBitWidth, BlkLen, /* columnwise */ true,
+                                        static_cast<int>(K), static_cast<int>(N),
+                                        QuantBDataSizeInBytes, QuantBScaleSize, &QuantBZeroPointSizeInBytes);
+
+      QuantBData = BufferQuantBData.GetBuffer(QuantBDataSizeInBytes);
+      QuantBScale = BufferQuantBScale.GetBuffer(QuantBScaleSize);
+      if (!Symmetric) {
+        QuantBZeroPoint = BufferQuantBZeroPoint.GetBuffer(QuantBZeroPointSizeInBytes);
+      }
+
+      MlasQuantizeBlockwise<float, BlkBitWidth>(QuantBData, QuantBScale, QuantBZeroPoint,
+                                                B, BlkLen,
+                                                /* columnwise */ true,
+                                                static_cast<int>(K), static_cast<int>(N),
+                                                static_cast<int>(N),
+                                                GetMlasThreadPool());
+    }
+
+    CallGemm(M, N, K, A, /* lda */ K, QuantBData, QuantBScale, QuantBZeroPoint, Bias, C, /* ldc */ N, Threadpool);
+    CallReferenceGemm(M, N, K, A, QuantBData, QuantBScale, QuantBZeroPoint, Bias, CReference);
+
+    size_t f = 0;
+    for (size_t m = 0; m < M; m++) {
+      for (size_t n = 0; n < N; n++, f++) {
+        ASSERT_TRUE(CloseEnough(C[f], CReference[f]))
+            << "Expected: " << CReference[f] << " Actual: " << C[f] << "@[" << m << "x" << n << "], "
+            << "M=" << M << ", N=" << N << ", K=" << K;
+      }
+    }
+  }
+
+ public:
+  static const char* GetTestSuiteName() {
+    static std::string suite_name = std::string("SQNBitGemm") +
+                                    "BlkBitWidth" + std::to_string(BlkBitWidth) +
+                                    "BlkLen" + std::to_string(BlkLen);
+    return suite_name.c_str();
+  }
+};
+
+//
+// Short Execute() test helper to register each test separately by all parameters.
+//
+template <size_t BlkBitWidth, size_t BlkLen>
+class SQNBitGemmShortExecuteTest : public MlasTestFixture<MlasSQNBitGemmTest<BlkBitWidth, BlkLen>> {
+ public:
+  explicit SQNBitGemmShortExecuteTest(size_t M, size_t N, size_t K,
+                                      bool WithThreadpool, bool Symmetric, bool WithBias)
+      : M_(M), N_(N), K_(K), WithThreadpool_(WithThreadpool), Symmetric_(Symmetric), WithBias_(WithBias) {
+  }
+
+  void TestBody() override {
+    MlasTestFixture<MlasSQNBitGemmTest<BlkBitWidth, BlkLen>>::mlas_tester->Test(
+        M_, N_, K_, WithThreadpool_, Symmetric_, WithBias_);
+  }
+
+  static size_t RegisterSingleTest(size_t M, size_t N, size_t K,
+                                   bool WithThreadpool, bool Symmetric, bool WithBias) {
+    std::stringstream ss;
+    ss << (WithThreadpool ? "Threaded" : "SingleThread")
+       << "/isSymmetric" << Symmetric
+       << "/M" << M << "xN" << N << "xK" << K
+       << "/hasBias" << WithBias;
+    auto test_name = ss.str();
+
+    testing::RegisterTest(
+        MlasSQNBitGemmTest<BlkBitWidth, BlkLen>::GetTestSuiteName(),
+        test_name.c_str(),
+        nullptr,
+        test_name.c_str(),
+        __FILE__,
+        __LINE__,
+        // Important to use the fixture type as the return type here.
+        [=]() -> MlasTestFixture<MlasSQNBitGemmTest<BlkBitWidth, BlkLen>>* {
+          return new SQNBitGemmShortExecuteTest<BlkBitWidth, BlkLen>(
+              M, N, K, WithThreadpool, Symmetric, WithBias);
+        });
+
+    return 1;
+  }
+
+  static size_t RegisterShortExecuteTests() {
+    size_t test_registered = 0;
+
+    if (MlasIsSQNBitGemmAvailable(BlkBitWidth, BlkLen)) {
+      for (bool WithThreadpool : {false, true}) {
+        for (bool Symmetric : {false, true}) {
+          for (size_t b = 1; b < 16; b++) {
+            test_registered += RegisterSingleTest(b, b, b, WithThreadpool, Symmetric, false);
+            test_registered += RegisterSingleTest(b, b, b, WithThreadpool, Symmetric, true);
+          }
+          for (size_t b = 16; b <= 256; b <<= 1) {
+            test_registered += RegisterSingleTest(b, b, b, WithThreadpool, Symmetric, false);
+            test_registered += RegisterSingleTest(b, b, b, WithThreadpool, Symmetric, true);
+          }
+          for (size_t b = 256; b < 320; b += 32) {
+            test_registered += RegisterSingleTest(b, b, b, WithThreadpool, Symmetric, true);
+          }
+          for (size_t b = 1; b < 96; b++) {
+            test_registered += RegisterSingleTest(1, b, 32, WithThreadpool, Symmetric, false);
+            test_registered += RegisterSingleTest(1, 32, b, WithThreadpool, Symmetric, true);
+            test_registered += RegisterSingleTest(1, b, b, WithThreadpool, Symmetric, false);
+          }
+          test_registered += RegisterSingleTest(43, 500, 401, WithThreadpool, Symmetric, true);
+
+          // test_registered += RegisterSingleTest(1001, 1027, 1031, WithThreadpool, Symmetric, false);
+        }
+      }
+    }
+
+    return test_registered;
+  }
+
+ private:
+  size_t M_, N_, K_;
+  bool WithThreadpool_, Symmetric_, WithBias_;
+};
+
+static size_t SQNBitGemmRegisterAllShortExecuteTests() {
+  size_t count = 0;
+
+  count += SQNBitGemmShortExecuteTest<4, 16>::RegisterShortExecuteTests();
+  count += SQNBitGemmShortExecuteTest<4, 32>::RegisterShortExecuteTests();
+  count += SQNBitGemmShortExecuteTest<4, 64>::RegisterShortExecuteTests();
+  count += SQNBitGemmShortExecuteTest<4, 128>::RegisterShortExecuteTests();
+  count += SQNBitGemmShortExecuteTest<4, 256>::RegisterShortExecuteTests();
+
+  return count;
+}
+
+static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) {
+  if (is_short_execute) {
+    return SQNBitGemmRegisterAllShortExecuteTests() > 0;
+  }
+  return false;
+});
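In CallReferenceGemm above, MlasDequantizeBlockwise writes B in column-major order, so column n of the logical K x N matrix is a contiguous run of K floats and each C[m,n] is a dot product of two contiguous vectors. The same loop restated as a free function:

    #include <cstddef>

    // Mirrors the test's reference path: A is M x K row-major, B is K x N
    // column-major, Bias (length N) may be null.
    void ReferenceGemmColMajorB(size_t M, size_t N, size_t K,
                                const float* A, const float* BColMajor,
                                const float* Bias, float* C) {
      for (size_t m = 0; m < M; ++m) {
        for (size_t n = 0; n < N; ++n) {
          float sum = (Bias != nullptr) ? Bias[n] : 0.0f;
          for (size_t k = 0; k < K; ++k) {
            sum += A[m * K + k] * BColMajor[n * K + k];
          }
          C[m * N + n] = sum;
        }
      }
    }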
diff --git a/onnxruntime/test/mlas/unittest/test_util.h b/onnxruntime/test/mlas/unittest/test_util.h
index db528ef7291cc..8eecda900ff27 100644
--- a/onnxruntime/test/mlas/unittest/test_util.h
+++ b/onnxruntime/test/mlas/unittest/test_util.h
@@ -9,6 +9,7 @@
 #include 
 #include 
 #include 
+#include <cmath>
 #include 
 #include 
 #include 
@@ -253,3 +254,16 @@ inline void ReorderInputNchw(const int64_t* input_shape, const float* S, float*
     D += spatial_count * nchwc_channel_count;
   }
 }
+
+inline bool CloseEnough(float actual, float expected) {
+  if (std::isnan(actual)) {
+    return std::isnan(expected);
+  }
+  float diff = std::abs(actual - expected);
+  float top = std::max(std::abs(actual), std::abs(expected));
+  float ratio = 0;
+  if (top > 0.0001) {
+    ratio = diff / top;
+  }
+  return ratio < 0.005;
+}
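CloseEnough, removed from the three GEMM test headers earlier in this change, is now defined once in test_util.h. It compares with a 0.5% tolerance relative to the larger magnitude, accepts anything when both magnitudes are at most 1e-4, and matches NaN only with NaN. A few worked values (assumes the in-tree include paths):

    #include <cassert>
    #include <cmath>

    #include "test_util.h"  // provides CloseEnough

    int main() {
      assert(CloseEnough(100.4f, 100.0f));        // ~0.4% relative difference: accepted
      assert(!CloseEnough(101.0f, 100.0f));       // ~1% relative difference: rejected
      assert(CloseEnough(5e-5f, -5e-5f));         // both magnitudes <= 1e-4: accepted
      assert(!CloseEnough(std::nanf(""), 1.0f));  // NaN only matches NaN
      return 0;
    }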
diff --git a/onnxruntime/test/onnx/microbenchmark/reduceminmax.cc b/onnxruntime/test/onnx/microbenchmark/reduceminmax.cc
index bd2abadf49b81..d866045ba4962 100644
--- a/onnxruntime/test/onnx/microbenchmark/reduceminmax.cc
+++ b/onnxruntime/test/onnx/microbenchmark/reduceminmax.cc
@@ -91,6 +91,8 @@ BENCHMARK(BM_FindMinMaxMlasSSE2)
     ->Arg(98304)
     ->Arg(160000);
 
+#ifdef MLAS_TARGET_AMD64
+
 // MLAS avx implementation
 static void BM_FindMinMaxMlasAvx(benchmark::State& state) {
   const size_t batch_size = static_cast<size_t>(state.range(0));
@@ -115,3 +117,5 @@ BENCHMARK(BM_FindMinMaxMlasAvx)
     ->Arg(80000)
     ->Arg(98304)
     ->Arg(160000);
+
+#endif  // MLAS_TARGET_AMD64