[ARM] MatMulNBits Fp16 support - API change only (#22826)
### Description
A break-down PR of #22651. This change covers the op API only:
- Add a template parameter to the functions and classes that support both fp32 and fp16.
- Rename the functions, classes, and files that support fp32 and fp16 from SQNBxxx to QNBxxx (a naming sketch follows below).
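
A minimal naming sketch of what the change looks like at a call site. Using `MLAS_FP16` as the half-precision element type is an assumption based on existing MLAS headers; the exact signatures are in `mlas_qnbit.h` below.

```cpp
#include "mlas.h"        // MLAS_FP16 (assumed half-precision element type)
#include "mlas_qnbit.h"

// Before: fp32-only, "S" (single-precision) prefix:
//   MLAS_SQNBIT_GEMM_DATA_PARAMS params;  MlasSQNBitGemmBatch(...);
// After: the parameter struct and entry points are templated on the activation type:
MLAS_QNBIT_GEMM_DATA_PARAMS<float>     fp32_params{};  // A, scales, bias, C are float
MLAS_QNBIT_GEMM_DATA_PARAMS<MLAS_FP16> fp16_params{};  // A, scales, bias, C are fp16
```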


fajin-corp authored Nov 14, 2024
1 parent c645bd2 commit c02b398
Showing 31 changed files with 547 additions and 402 deletions.
4 changes: 2 additions & 2 deletions cmake/onnxruntime_mlas.cmake
@@ -36,8 +36,8 @@ onnxruntime_add_static_library(onnxruntime_mlas
${MLAS_SRC_DIR}/qpostprocessor.cpp
${MLAS_SRC_DIR}/qlgavgpool.cpp
${MLAS_SRC_DIR}/qdwconv_kernelsize.cpp
${MLAS_SRC_DIR}/sqnbitgemm.h
${MLAS_SRC_DIR}/sqnbitgemm.cpp
${MLAS_SRC_DIR}/qnbitgemm.h
${MLAS_SRC_DIR}/qnbitgemm.cpp
${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h
${MLAS_SRC_DIR}/flashattn.cpp
${MLAS_SRC_DIR}/cast.cpp
159 changes: 92 additions & 67 deletions onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc

Large diffs are not rendered by default.

82 changes: 41 additions & 41 deletions onnxruntime/core/mlas/inc/mlas_qnbit.h
@@ -27,51 +27,50 @@ Module Name:
* @brief Define compute types of block quantization, in order of decreasing accuracy.
*/
typedef enum {
CompUndef = 0, /*!< undef */
CompFp32, /*!< input fp32, accumulator fp32 */
CompFp16, /*!< input fp16, accumulator fp16 */
CompBf16, /*!< input bf16, accumulator fp32 */
CompInt8, /*!< input int8, accumulator int32 */

// special values that should be the first and last actual values

CompMostAccurate = CompUndef,
CompLeastAccurate = CompInt8,
} MLAS_SQNBIT_GEMM_COMPUTE_TYPE;
SQNBIT_CompFp32, /*!< input fp32, accumulator fp32 */
HQNBIT_CompFp16, /*!< input fp16, accumulator fp16 */
BHQNBIT_CompBf16, /*!< input bf16, accumulator fp32 */
SQNBIT_CompInt8, /*!< input int8, accumulator int32, input fp32 */
HQNBIT_CompInt8, /*!< input int8, accumulator int32, input fp16 */
} MLAS_QNBIT_GEMM_COMPUTE_TYPE;
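
To illustrate the new prefix scheme (S = fp32 activations, H = fp16 activations, BH = bf16 activations, per the comments above), a hypothetical helper that picks a default compute type from the activation type might look like the sketch below; the helper itself is not part of this PR.

```cpp
#include <type_traits>
#include "mlas_qnbit.h"

// Hypothetical helper (not part of this PR): map the activation type to a default compute type.
template <typename T>
constexpr MLAS_QNBIT_GEMM_COMPUTE_TYPE DefaultQNBitComputeType() {
    if constexpr (std::is_same_v<T, float>) {
        return SQNBIT_CompFp32;  // fp32 input, fp32 accumulator
    } else {
        return HQNBIT_CompFp16;  // fp16 input, fp16 accumulator
    }
}
```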

/**
* @brief Data parameters for float/n-bit quantized int GEMM routine.
*
* @tparam T data type of input A
*/
struct MLAS_SQNBIT_GEMM_DATA_PARAMS {
const float* A = nullptr; ///< address of A (float32 matrix)
template <typename T>
struct MLAS_QNBIT_GEMM_DATA_PARAMS {
const T* A = nullptr; ///< address of A (float32/16 matrix)
size_t lda = 0; ///< leading dimension of A
const void* QuantBDataWorkspace; ///< address of quantized B (quantized n-bit int values)
const std::byte* PackedQuantBData = nullptr; /// address of packed quantized B data
const float* QuantBScale = nullptr; ///< address of scale values of quantized B, one per block
const T* QuantBScale = nullptr; ///< address of scale values of quantized B, one per block
const void* QuantBZeroPoint = nullptr; ///< optional address of zero point values of quantized B, one per block
const float* QuantBBlkSum = nullptr; ///< optional address of scale * zp, one per block
const float* Bias = nullptr; ///< optional address of Bias, vector size N
float* C = nullptr; ///< address of result matrix
const T* QuantBBlkSum = nullptr; ///< optional address of scale * zp, one per block
const T* Bias = nullptr; ///< optional address of Bias, vector size N
T* C = nullptr; ///< address of result matrix
size_t ldc = 0; ///< leading dimension of C

///< optional post processing to apply to result matrix
MLAS_GEMM_POSTPROCESSOR<float>* PostProcessor = nullptr;
MLAS_GEMM_POSTPROCESSOR<T>* PostProcessor = nullptr;
};
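
A hedged sketch of how the fp16 instantiation enabled by this PR might be filled in at a call site. `MLAS_FP16` as the half-precision type and the helper function are assumptions for illustration; in practice the buffer pointers come from the MatMulNBits op.

```cpp
#include <cstddef>
#include "mlas.h"        // MLAS_FP16 (assumed half-precision element type)
#include "mlas_qnbit.h"

// Illustrative only: populate one entry of the DataParams array for an fp16 GEMM.
void FillFp16Params(const MLAS_FP16* a, size_t K, size_t N,
                    const std::byte* packed_b, const MLAS_FP16* b_scales, MLAS_FP16* c,
                    MLAS_QNBIT_GEMM_DATA_PARAMS<MLAS_FP16>& params) {
    params.A = a;                      // fp16 activations, M x K
    params.lda = K;
    params.PackedQuantBData = packed_b;
    params.QuantBScale = b_scales;     // fp16 scales, one per quantization block
    params.QuantBZeroPoint = nullptr;  // optional; nullptr for symmetric quantization
    params.Bias = nullptr;             // optional
    params.C = c;                      // fp16 output, M x N
    params.ldc = N;
}
```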

/**
* @brief Batched GEMM: C = A * B + Bias
* A must be a float32 matrix
* A must be a float32/16 matrix
* B must be a quantized and packed n-bit int matrix
*
* Call MlasIsSQNBitGemmAvailable() with the same parameters to determine whether this function may be called.
* Call MlasIsQNBitGemmAvailable() with the same parameters to determine whether this function may be called.
*
* Call MlasSQNBitGemmPackQuantBDataSize() with the same parameters to determine whether
* MLAS_SQNBIT_GEMM_DATA_PARAMS::QuantBData in `DataParams` should point to a buffer packed with
* MlasSQNBitGemmPackQuantBData().
* Call MlasQNBitGemmPackQuantBDataSize() with the same parameters to determine whether
* MLAS_QNBIT_GEMM_DATA_PARAMS::QuantBData in `DataParams` should point to a buffer packed with
* MlasQNBitGemmPackQuantBData().
*
* Call MlasSQNBitGemmBatchWorkspaceSize() with the same parameters to determine whether `Workspace` should
* Call MlasQNBitGemmBatchWorkspaceSize() with the same parameters to determine whether `Workspace` should
* point to an intermediate workspace buffer.
*
* @tparam T data type of input A
* @param[in] M row size of matrix A and C
* @param[in] N column size of matrix B and C
* @param[in] K column size of matrix A and row size of matrix B
@@ -81,36 +80,37 @@ struct MLAS_SQNBIT_GEMM_DATA_PARAMS {
* @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values)
* @param[inout] DataParams An array (size BatchN) of parameter blocks
* @param[in] Workspace Address of intermediate workspace buffer.
If MlasSQNBitGemmBatchWorkspaceSize() returns a non-zero value, this must be a
If MlasQNBitGemmBatchWorkspaceSize() returns a non-zero value, this must be a
buffer with at least that many bytes. Otherwise, it may be nullptr.
* @param[in] ThreadPool optional thread pool to use
*/
template <typename T>
void MLASCALL
MlasSQNBitGemmBatch(
MlasQNBitGemmBatch(
size_t M,
size_t N,
size_t K,
size_t BatchN,
size_t BlkBitWidth,
size_t BlkLen,
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType,
const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams,
MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType,
const MLAS_QNBIT_GEMM_DATA_PARAMS<T>* DataParams,
void* Workspace,
MLAS_THREADPOOL* ThreadPool = nullptr
);
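
Putting the calls from the comment above together, a hedged end-to-end sketch for the fp32 path (single batch, 4-bit blocks of length 32). It assumes B has already been packed; see the packing functions documented further down.

```cpp
#include <cstddef>
#include <vector>
#include "mlas.h"
#include "mlas_qnbit.h"

// Sketch only: orchestrate availability check, workspace query, and the batched GEMM call.
void QNBitGemmFp32(size_t M, size_t N, size_t K,
                   const float* A, const std::byte* packed_b, const float* b_scales,
                   float* C, MLAS_THREADPOOL* thread_pool) {
    constexpr size_t BlkBitWidth = 4, BlkLen = 32, BatchN = 1;
    constexpr MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType = SQNBIT_CompFp32;

    if (!MlasIsQNBitGemmAvailable(BlkBitWidth, BlkLen, ComputeType)) {
        return;  // caller would fall back to a non-MLAS implementation
    }

    // Optional intermediate workspace, sized per the contract above.
    std::vector<std::byte> workspace(
        MlasQNBitGemmBatchWorkspaceSize(M, N, K, BatchN, BlkBitWidth, BlkLen, ComputeType));

    MLAS_QNBIT_GEMM_DATA_PARAMS<float> params{};
    params.A = A;
    params.lda = K;
    params.PackedQuantBData = packed_b;
    params.QuantBScale = b_scales;
    params.C = C;
    params.ldc = N;

    MlasQNBitGemmBatch(M, N, K, BatchN, BlkBitWidth, BlkLen, ComputeType, &params,
                       workspace.empty() ? nullptr : workspace.data(), thread_pool);
}
```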

/**
* @brief Determines whether a float32/quantized n-bit int GEMM implementation is available on the current platform.
* @brief Determines whether a float32/16 quantized n-bit int GEMM implementation is available on the current platform.
*
* @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints)
* @param[in] BlkLen number of quantized values per block
* @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values)
*/
bool MLASCALL
MlasIsSQNBitGemmAvailable(
MlasIsQNBitGemmAvailable(
size_t BlkBitWidth,
size_t BlkLen,
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType
);

/**
@@ -126,22 +126,22 @@ MlasIsSQNBitGemmAvailable(
* @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values)
*/
size_t MLASCALL
MlasSQNBitGemmBatchWorkspaceSize(
MlasQNBitGemmBatchWorkspaceSize(
size_t M,
size_t N,
size_t K,
size_t BatchN,
size_t BlkBitWidth,
size_t BlkLen,
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType
);

/**
* @brief Gets the size in bytes of the packed quantized B data.
* If non-zero, the quantized B data must first be packed by calling MlasSQNBitGemmPackQuantBData() with a buffer of
* this size, and then that packed quantized B data buffer must be passed to MlasSQNBitGemmBatch().
* If zero, MlasSQNBitGemmPackQuantBData() must not be called and the quantized B data must be directly passed to
* MlasSQNBitGemmBatch().
* If non-zero, the quantized B data must first be packed by calling MlasQNBitGemmPackQuantBData() with a buffer of
* this size, and then that packed quantized B data buffer must be passed to MlasQNBitGemmBatch().
* If zero, MlasQNBitGemmPackQuantBData() must not be called and the quantized B data must be directly passed to
* MlasQNBitGemmBatch().
*
* @param[in] N column size of matrix B and C
* @param[in] K column size of matrix A and row size of matrix B
@@ -150,12 +150,12 @@ MlasSQNBitGemmBatchWorkspaceSize(
* @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values)
*/
size_t MLASCALL
MlasSQNBitGemmPackQuantBDataSize(
MlasQNBitGemmPackQuantBDataSize(
size_t N,
size_t K,
size_t BlkBitWidth,
size_t BlkLen,
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType
);
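
A short sketch of the pack-or-not contract documented above. Since the remaining parameters of `MlasQNBitGemmPackQuantBData()` are truncated in this hunk, the actual packing call is left as a comment rather than guessed at.

```cpp
#include <cstddef>
#include <vector>
#include "mlas_qnbit.h"

// Sketch only: decide whether B must be packed before calling MlasQNBitGemmBatch().
bool PreparePackedB(size_t N, size_t K, std::vector<std::byte>& packed_b) {
    const size_t packed_size = MlasQNBitGemmPackQuantBDataSize(
        N, K, /*BlkBitWidth*/ 4, /*BlkLen*/ 32, SQNBIT_CompFp32);
    if (packed_size == 0) {
        return false;  // pass the raw quantized B data directly to MlasQNBitGemmBatch()
    }
    packed_b.resize(packed_size);
    // MlasQNBitGemmPackQuantBData(N, K, 4, 32, SQNBIT_CompFp32, quant_b, packed_b.data(), ...);
    return true;       // pass packed_b.data() via PackedQuantBData instead
}
```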

/**
@@ -186,12 +186,12 @@ MlasSQNBitGemmPackQuantBDataSize(
* @param[in] ThreadPool thread pool to use (no parallel if nullptr)
*/
void MLASCALL
MlasSQNBitGemmPackQuantBData(
MlasQNBitGemmPackQuantBData(
size_t N,
size_t K,
size_t BlkBitWidth,
size_t BlkLen,
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType,
MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType,
const void* QuantBData,
void* PackedQuantBDataAndOrBlkSum,
const void* QuantBScale,
4 changes: 2 additions & 2 deletions onnxruntime/core/mlas/lib/hqnbitgemm_kernel_neon_fp16.cpp
@@ -23,7 +23,7 @@ Module Name:
#include <type_traits>

#include "fp16_common.h"
#include "sqnbitgemm.h"
#include "qnbitgemm.h"
#include "sqnbitgemm_kernel_neon.h"

namespace sqnbitgemm_neon
@@ -131,7 +131,7 @@ HQ4BitGemmPackQuantBData_CompFp16(
size_t N,
size_t K,
size_t BlkLen,
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType,
MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType,
const std::byte* QuantBDataBegin,
std::byte* PackedQuantBDataBegin,
MLAS_THREADPOOL* ThreadPool
14 changes: 7 additions & 7 deletions onnxruntime/core/mlas/lib/mlasi.h
@@ -1017,17 +1017,17 @@ extern const MLAS_FPQ4GEMM_DISPATCH MlasFpQ4GemmDispatchAvx512;
// Float/quantized n-bit integer matrix/matrix multiply dispatch structure.
//

struct MLAS_SQNBIT_GEMM_DISPATCH;
struct MLAS_QNBIT_GEMM_DISPATCH;

extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon;
extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon;

extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2;
extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2;

extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni;
extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni;

extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512;
extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512;

extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni;
extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni;

//
// Quantized depthwise convolution kernels.
@@ -1184,7 +1184,7 @@ struct MLAS_PLATFORM {
const MLAS_FPQ4GEMM_DISPATCH* FpQ4GemmDispatch{nullptr};
const MLAS_Q8Q4GEMM_DISPATCH* Q8Q4GemmDispatch{nullptr};

const MLAS_SQNBIT_GEMM_DISPATCH* SQNBitGemmDispatch{nullptr};
const MLAS_QNBIT_GEMM_DISPATCH* QNBitGemmDispatch{nullptr};

MLAS_CAST_F16_TO_F32_KERNEL* CastF16ToF32Kernel;
MLAS_CAST_F32_TO_F16_KERNEL* CastF32ToF16Kernel;
10 changes: 5 additions & 5 deletions onnxruntime/core/mlas/lib/platform.cpp
@@ -387,7 +387,7 @@ Return Value:
this->ConvDepthwiseS8S8Kernel = MlasConvDepthwiseKernelAvx2<int8_t, int8_t>;
this->ConvDepthwiseS8U8Kernel = MlasConvDepthwiseKernelAvx2<int8_t, uint8_t>;
this->ComputeSumExpF32Kernel = MlasComputeSumExpF32KernelFma3;
this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx2;
this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx2;
this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelAvx2;
this->CastF32ToF16Kernel = &MlasCastF32ToF16KernelAvx2;

@@ -417,7 +417,7 @@ Return Value:
this->GemmU8S8Kernel = MlasGemmU8S8KernelAvxVnni;
this->GemvU8S8Kernel = MlasGemvU8S8KernelAvxVnni;
this->ConvSymU8S8Dispatch = &MlasConvSymDispatchAvxVnni;
this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx2vnni;
this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx2vnni;
}

#if !defined(ORT_MINIMAL_BUILD)
@@ -458,7 +458,7 @@ Return Value:
this->GemmU8U8Kernel = MlasGemmU8U8KernelAvx512Core;
this->ConvSymU8S8Dispatch = &MlasConvSymDispatchAvx512Core;
this->FpQ4GemmDispatch = &MlasFpQ4GemmDispatchAvx512;
this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx512;
this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx512;

//
// Check if the processor supports AVX512VNNI.
@@ -471,7 +471,7 @@ Return Value:
this->GemvU8S8Kernel = MlasGemvU8S8KernelAvx512Vnni;
this->ConvSymU8S8Dispatch = &MlasConvSymDispatchAvx512Vnni;
this->Q8Q4GemmDispatch = &MlasQ8Q4GemmDispatchAvx512vnni;
this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx512vnni;
this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx512vnni;
}
}
}
@@ -562,7 +562,7 @@ Return Value:
this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchDot;

// MlasSQNBitGemmDispatchNeon has a dependency on dot product instructions
this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon;
this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon;
}

#if defined(__linux__)
