save work - got sqnbitgemm tests and a cpu impl
edgchen1 committed Oct 26, 2023
1 parent 6de6339 commit 01ac345
Showing 15 changed files with 1,214 additions and 63 deletions.
3 changes: 3 additions & 0 deletions cmake/onnxruntime_mlas.cmake
@@ -33,6 +33,7 @@ onnxruntime_add_static_library(onnxruntime_mlas
${MLAS_SRC_DIR}/qpostprocessor.cpp
${MLAS_SRC_DIR}/qlgavgpool.cpp
${MLAS_SRC_DIR}/qdwconv_kernelsize.cpp
${MLAS_SRC_DIR}/sqnbitgemm.cpp
)

if (NOT onnxruntime_ORT_MINIMAL_BUILD)
@@ -69,6 +70,7 @@ function(setup_mlas_source_for_windows)
${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp
${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp
)

set(mlas_platform_preprocess_srcs
@@ -336,6 +338,7 @@ else()
${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp
${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp
)
if (NOT APPLE)
set(mlas_platform_srcs
6 changes: 4 additions & 2 deletions onnxruntime/core/mlas/.clang-format
@@ -2,10 +2,12 @@

BasedOnStyle: Google
IndentWidth: 4
ColumnLimit: 100
# Setting ColumnLimit to 0 so developer choices about where to break lines are maintained.
# Developers are responsible for adhering to the 120 character maximum.
ColumnLimit: 0
AlignAfterOpenBracket: BlockIndent
AlwaysBreakAfterReturnType: TopLevel
AlwaysBreakTemplateDeclarations: Yes
BinPackParameters: false
BreakBeforeBraces: Linux
...

16 changes: 16 additions & 0 deletions onnxruntime/core/mlas/inc/mlas_gemm_postprocessor.h
@@ -0,0 +1,16 @@
#pragma once

template<typename T>
class MLAS_GEMM_POSTPROCESSOR
{
public:
virtual void Process(T* C, /**< the address of matrix to process */
size_t RangeStartM, /**< the start row index of matrix */
size_t RangeStartN, /**< the start col index of matrix */
size_t RangeCountM, /**< the number of rows to process */
size_t RangeCountN, /**< the number of columns to process */
size_t ldc /**< the leading dimension of matrix */
) const = 0;

virtual ~MLAS_GEMM_POSTPROCESSOR() {}
};
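
The interface above only fixes the tile addressing convention; a concrete postprocessor supplies the per-element work. For illustration, a hypothetical ReLU postprocessor over a float result tile might look like the sketch below (not part of this commit; the class name and behavior are invented for the example):

#include <algorithm>
#include <cstddef>

#include "mlas_gemm_postprocessor.h"

// Hypothetical example: clamps negative values in the processed tile to zero.
// C is the base address of the whole result matrix; the Range* arguments and
// ldc follow the conventions documented on Process() above.
class ReluPostprocessor : public MLAS_GEMM_POSTPROCESSOR<float> {
   public:
    void Process(float* C,
                 size_t RangeStartM,
                 size_t RangeStartN,
                 size_t RangeCountM,
                 size_t RangeCountN,
                 size_t ldc) const override
    {
        for (size_t m = 0; m < RangeCountM; ++m) {
            float* row = C + (RangeStartM + m) * ldc + RangeStartN;
            for (size_t n = 0; n < RangeCountN; ++n) {
                row[n] = std::max(row[n], 0.0f);
            }
        }
    }
};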
23 changes: 4 additions & 19 deletions onnxruntime/core/mlas/inc/mlas_q4.h
@@ -21,6 +21,7 @@ Module Name:
#pragma once

#include "mlas.h"
#include "mlas_gemm_postprocessor.h"

#include <math.h>
#include <algorithm>
@@ -39,7 +40,7 @@ typedef enum {
* @brief Computes the number of bytes required to pack and int4-quantize
* a weight matrix
* @param QType type of block quantization
* @param N the number of columns of matrix B.
* @param K the number of rows of matrix B.
* @return size of the packing buffer, 0 if the operation is not yet supported.
*/
@@ -53,11 +54,11 @@ MlasQ4GemmPackBSize(

/**
* @brief Prepack and Quantize fp32 weight tensor to int4 blocks
*
* @param QType type of block quantization
* @param PackedBuf destination buffer
* @param FpData the pointer to fp32 matrix
* @param N the number of columns of matrix B.
* @param K the number of rows of matrix B.
* @param ldb leading dimension of B
*/
@@ -95,22 +96,6 @@ MlasQ4GemmUnPackB(
);


template<typename T>
class MLAS_GEMM_POSTPROCESSOR
{
public:
virtual void Process(T*, /**< the address of matrix to process */
size_t, /**< the start row index of matrix */
size_t, /**< the start col index of matrix */
size_t, /**< the element count per row to process */
size_t, /**< the element count per col to process */
size_t /**< the leading dimension of matrix */
) const = 0;

virtual ~MLAS_GEMM_POSTPROCESSOR() {}
};


/**
* @brief Data parameters for Q4 GEMM routine
* C = A * B + Bias
68 changes: 68 additions & 0 deletions onnxruntime/core/mlas/inc/mlas_qnbit.h
@@ -0,0 +1,68 @@
/*++
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
mlas_qnbit.h
Abstract:
This module contains the public data structures and procedure prototypes
for blocked n-bit quantized GEMM.
N-bit block quantization is used to compress weight tensors of large
language models.
--*/

#pragma once

#include "mlas.h"
#include "mlas_gemm_postprocessor.h"

/**
* @brief Data parameters for float/n-bit quantized int GEMM routine.
*/
struct MLAS_SQNBIT_GEMM_DATA_PARAMS {
const float* A = nullptr; ///< address of A (float32 matrix)
const void* PackedBData = nullptr; ///< address of B (quantized and packed n-bit int values)
const float* PackedBScale = nullptr; ///< address of scale values of quantized B, one per block
const void* PackedBZeroPoint = nullptr; ///< optional address of zero point values of quantized B, one per block
bool IsBPacked = false; ///< whether B values are packed in the optimal format for the computation
const float* Bias = nullptr; ///< optional address of Bias, vector size N
float* C = nullptr; ///< address of result matrix
size_t lda = 0; ///< leading dimension of A
size_t ldc = 0; ///< leading dimension of C

MLAS_GEMM_POSTPROCESSOR<float>* PostProcessor = nullptr;  ///< optional post processing to apply to result matrix
};

/**
* @brief Batched GEMM: C = A * B + Bias
* A must be a float32 matrix
* B must be a quantized and packed n-bit int matrix
*
* @param[in] M row size of matrix A and C
* @param[in] N column size of matrix B and C
* @param[in] K column size of matrix A and row size of matrix B
* @param[in] BatchN number of batches
* @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4-bit ints)
* @param[in] BlkLen number of quantized values per block
* @param[inout] DataParams An array (size BatchN) of parameter blocks
* @param[in] ThreadPool optional thread pool to use
*/
void MLASCALL
MlasSQNBitGemmBatch(
const size_t M,
const size_t N,
const size_t K,
const size_t BatchN,
const size_t BlkBitWidth,
const size_t BlkLen,
const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams,
MLAS_THREADPOOL* ThreadPool = nullptr
);
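
For context, a minimal call sequence against this API might look like the following sketch. It assumes 4-bit quantization with a block size of 32 (one of the two variants the NEON dispatch in this commit handles) and that the packed B data and scales have already been produced by separate packing code, which is not part of this commit; the wrapper function name is invented for the example.

#include "mlas_qnbit.h"

// Illustrative only: runs a single float/4-bit-int GEMM, C = A * B.
// PackedBData and PackedBScale are assumed to be prepared elsewhere.
void RunSQNBitGemm(
    size_t M, size_t N, size_t K,
    const float* A,
    const void* PackedBData,
    const float* PackedBScale,
    float* C,
    MLAS_THREADPOOL* ThreadPool
)
{
    MLAS_SQNBIT_GEMM_DATA_PARAMS Params;
    Params.A = A;
    Params.lda = K;  // A is M x K, row-major
    Params.PackedBData = PackedBData;
    Params.PackedBScale = PackedBScale;
    Params.PackedBZeroPoint = nullptr;  // optional; omitted here
    Params.C = C;
    Params.ldc = N;  // C is M x N, row-major

    // Single GEMM in the batch (BatchN = 1), 4-bit values in blocks of 32.
    MlasSQNBitGemmBatch(M, N, K, /*BatchN=*/1,
                        /*BlkBitWidth=*/4, /*BlkLen=*/32,
                        &Params, ThreadPool);
}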
18 changes: 18 additions & 0 deletions onnxruntime/core/mlas/lib/mlasi.h
@@ -882,15 +882,31 @@ extern const MLAS_CONV_SYM_DISPATCH MlasConvSymS8DispatchNeon;
extern const MLAS_CONV_SYM_DISPATCH MlasConvSymU8DispatchDot;
extern const MLAS_CONV_SYM_DISPATCH MlasConvSymS8DispatchDot;

//
// Quantized 8-bit integer/quantized 4-bit integer matrix/matrix multiply dispatch structure.
//

struct MLAS_Q8Q4GEMM_DISPATCH;

extern const MLAS_Q8Q4GEMM_DISPATCH MlasQ8Q4GemmDispatchAvx512vnni;

//
// Float/quantized 4-bit integer matrix/matrix multiply dispatch structure.
//

struct MLAS_FPQ4GEMM_DISPATCH;

extern const MLAS_FPQ4GEMM_DISPATCH MlasFpQ4GemmDispatchAvx512;
extern const MLAS_FPQ4GEMM_DISPATCH MlasFpQ4GemmDispatchNeon;

//
// Float/quantized n-bit integer matrix/matrix multiply dispatch structure.
//

struct MLAS_SQNBIT_GEMM_DISPATCH;

extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon;

//
// Quantized depthwise convolution kernels.
//
@@ -1022,6 +1038,8 @@ struct MLAS_PLATFORM {

const MLAS_FPQ4GEMM_DISPATCH* FpQ4GemmDispatch{nullptr};
const MLAS_Q8Q4GEMM_DISPATCH* Q8Q4GemmDispatch{nullptr};

const MLAS_SQNBIT_GEMM_DISPATCH* SQNBitGemmDispatch{nullptr};
};

inline
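The MLAS_SQNBIT_GEMM_DISPATCH definition itself lives in sqnbitgemm.h, which this diff does not show. Based on how sqnbitgemm.cpp below indexes SQNBitGemmDispatch->Operations[QuantVariant] and invokes the result, its shape is presumably along these lines; this is a reconstruction under that assumption, not the committed code:

#include "mlas_qnbit.h"

// Reconstructed sketch; the actual definition is in sqnbitgemm.h.
// Signature inferred from the call Operation(K, Data, RangeStartM,
// RangeCountM, RangeStartN, RangeCountN) in sqnbitgemm.cpp.
typedef void(MLAS_SQNBIT_GEMM_OPERATION)(
    size_t K,
    const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams,
    size_t RangeStartM,
    size_t RangeCountM,
    size_t RangeStartN,
    size_t RangeCountN
);

// Variant indices returned by GetDispatchQuantVariant() below.
enum QuantVariant {
    QuantVariant_BitWidth4_BlockSize16,
    QuantVariant_BitWidth4_BlockSize32,
    QuantVariantCount,
};

struct MLAS_SQNBIT_GEMM_DISPATCH {
    // One kernel entry point per supported (bit width, block size) variant;
    // nullptr where a platform provides no implementation.
    MLAS_SQNBIT_GEMM_OPERATION* Operations[QuantVariantCount] = {};
};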
1 change: 1 addition & 0 deletions onnxruntime/core/mlas/lib/platform.cpp
@@ -450,6 +450,7 @@ Return Value:
this->ConvSymU8S8Dispatch = &MlasConvSymU8DispatchNeon;
this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchNeon;
this->FpQ4GemmDispatch = &MlasFpQ4GemmDispatchNeon;
this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon;

//
// Check if the processor supports ASIMD dot product instructions.
128 changes: 128 additions & 0 deletions onnxruntime/core/mlas/lib/sqnbitgemm.cpp
@@ -0,0 +1,128 @@
/*++
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
sqnbitgemm.cpp
Abstract:
This module implements the float/quantized n-bit integer matrix
multiplication hardware-agnostic entrypoint, MlasSQNBitGemmBatch.
--*/

#include "sqnbitgemm.h"

namespace
{

// Get quantization variant based on `BlkBitWidth` and `BlkLen`.
// Return -1 if the input values are unsupported.
int32_t
GetDispatchQuantVariant(size_t BlkBitWidth, size_t BlkLen)
{
int32_t type = -1;
if (BlkBitWidth == 4 && BlkLen == 16) {
type = QuantVariant_BitWidth4_BlockSize16;
} else if (BlkBitWidth == 4 && BlkLen == 32) {
type = QuantVariant_BitWidth4_BlockSize32;
}

return type;
}

} // namespace

void MLASCALL
MlasSQNBitGemmBatch(
const size_t M,
const size_t N,
const size_t K,
const size_t BatchN,
const size_t BlkBitWidth,
const size_t BlkLen,
const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams,
MLAS_THREADPOOL* ThreadPool
)
{
const int32_t QuantVariant = GetDispatchQuantVariant(BlkBitWidth, BlkLen);
if (QuantVariant == -1) {
MLAS_THROW_EX(std::invalid_argument, "Unsupported quantization block size / bit width.");
}

MLAS_SQNBIT_GEMM_OPERATION* const Operation =
GetMlasPlatform().SQNBitGemmDispatch->Operations[QuantVariant];

if (Operation == nullptr) {
MLAS_THROW_EX(std::invalid_argument, "FpQNBitGemm is not implemented on this platform.");
}

if (ThreadPool == nullptr) {
for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) {
auto Data = &DataParams[gemm_i];
Operation(K, Data, 0, M, 0, N);
}
return;
}

//
// Compute the number of target threads given the complexity of the GEMM
// operation. Small requests should run using the single-threaded path.
//

const double Complexity = double(M) * double(N) * double(K) * double(BatchN);

ptrdiff_t TargetThreadCount = ptrdiff_t(Complexity / double(MLAS_QGEMM_THREAD_COMPLEXITY)) + 1;

ptrdiff_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool) * 8;

if (TargetThreadCount >= MaximumThreadCount) {
TargetThreadCount = MaximumThreadCount;
}

ptrdiff_t ThreadsPerGemm = TargetThreadCount / BatchN;
if (ThreadsPerGemm < 1) {
ThreadsPerGemm = 1;
}

constexpr size_t StrideM = 128;

size_t nc = N;
if (ThreadsPerGemm > 1) {
// more than one thread per GEMM

const size_t BlockedM = MlasDivRoundup(M, StrideM);
const size_t max_nc = MlasDivRoundup(N * BlockedM, ThreadsPerGemm);
if (max_nc < nc) {
nc = std::min(
nc, MlasDivRoundup(max_nc, MLAS_QGEMM_STRIDEN_THREAD_ALIGN) *
MLAS_QGEMM_STRIDEN_THREAD_ALIGN
);
}
}
const size_t StrideN = nc;

const size_t ThreadCountM = MlasDivRoundup(M, StrideM);
const size_t ThreadCountN = MlasDivRoundup(N, StrideN);
ThreadsPerGemm = ThreadCountM * ThreadCountN;

MlasTrySimpleParallel(ThreadPool, ThreadsPerGemm * BatchN, [&](ptrdiff_t tid) {
const auto gemm_i = tid / ThreadsPerGemm;
const auto blk_i = tid % ThreadsPerGemm;
auto Data = &DataParams[gemm_i];

const ptrdiff_t ThreadIdN = blk_i / ThreadCountM;
const ptrdiff_t ThreadIdM = blk_i % ThreadCountM;

const size_t RangeStartM = ThreadIdM * StrideM;
const size_t RangeCountM = std::min(M - RangeStartM, (size_t)StrideM);

const size_t RangeStartN = ThreadIdN * StrideN;
const size_t RangeCountN = std::min(N - RangeStartN, (size_t)StrideN);

Operation(K, Data, RangeStartM, RangeCountM, RangeStartN, RangeCountN);
});
}
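
To make the partitioning concrete, here is a small standalone sketch of the same tiling arithmetic with example sizes. MlasDivRoundup is reimplemented locally for illustration, and the alignment constant is a stand-in for MLAS_QGEMM_STRIDEN_THREAD_ALIGN; the numbers in the comments are just one worked case, not measurements.

#include <algorithm>
#include <cstddef>
#include <cstdio>

// Local stand-ins for the MLAS helpers and constants used above.
static size_t DivRoundup(size_t n, size_t d) { return (n + d - 1) / d; }
constexpr size_t StrideM = 128;
constexpr size_t StrideNAlign = 16;  // stand-in for MLAS_QGEMM_STRIDEN_THREAD_ALIGN

int main()
{
    const size_t M = 512, N = 4096;
    const size_t ThreadsPerGemm = 8;

    // Same logic as MlasSQNBitGemmBatch: split M into 128-row strips, then
    // pick an N stride so the tile count roughly matches the thread target.
    const size_t BlockedM = DivRoundup(M, StrideM);                  // 4
    const size_t max_nc = DivRoundup(N * BlockedM, ThreadsPerGemm);  // 2048
    size_t nc = N;
    if (max_nc < nc) {
        nc = std::min(nc, DivRoundup(max_nc, StrideNAlign) * StrideNAlign);  // 2048
    }
    const size_t ThreadCountM = DivRoundup(M, StrideM);  // 4
    const size_t ThreadCountN = DivRoundup(N, nc);       // 2
    printf("tiles: %zu x %zu = %zu work items\n",
           ThreadCountM, ThreadCountN, ThreadCountM * ThreadCountN);  // 4 x 2 = 8
    return 0;
}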