clean up and add doc comments
edgchen1 committed Oct 26, 2023
1 parent 01ac345 commit d7bd709
Showing 5 changed files with 193 additions and 126 deletions.
41 changes: 26 additions & 15 deletions onnxruntime/core/mlas/inc/mlas_qnbit.h
@@ -28,7 +28,7 @@ Module Name:
*/
struct MLAS_SQNBIT_GEMM_DATA_PARAMS {
const float* A = nullptr; ///< address of A (float32 matrix)
const void* PackedBData = nullptr;          ///< address of B (quantized and packed n-bit int values)
const float* PackedBScale = nullptr; ///< address of scale values of quantized B, one per block
const void* PackedBZeroPoint = nullptr; ///< optional address of zero point values of quantized B, one per block
bool IsBPacked = false; ///< whether B values are packed in the optimal format for the computation
@@ -46,23 +46,34 @@ struct MLAS_SQNBIT_GEMM_DATA_PARAMS {
* A must be a float32 matrix
* B must be a quantized and packed n-bit int matrix
*
* @param[in] M row size of matrix A and C
* @param[in] N column size of matrix B and C
* @param[in] K column size of matrix A and row size of matrix B
* @param[in] BatchN number of batches
* @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints)
* @param[in] BlkSize number of quantized values per block
* @param[inout] DataParams An array (size BatchN) of parameter blocks
* @param[in] ThreadPool optional thread pool to use
* @param[in] M row size of matrix A and C
* @param[in] N column size of matrix B and C
* @param[in] K column size of matrix A and row size of matrix B
* @param[in] BatchN number of batches
* @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints)
* @param[in] BlkLen number of quantized values per block
* @param[inout] DataParams An array (size BatchN) of parameter blocks
* @param[in] ThreadPool optional thread pool to use
*/
void MLASCALL
MlasSQNBitGemmBatch(
const size_t M,
const size_t N,
const size_t K,
const size_t BatchN,
const size_t BlkBitWidth,
const size_t BlkSize,
size_t M,
size_t N,
size_t K,
size_t BatchN,
size_t BlkBitWidth,
size_t BlkLen,
const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams,
MLAS_THREADPOOL* ThreadPool = nullptr
);

/**
* @brief Determines whether a float32/quantized n-bit int GEMM implementation is available on the current platform.
* @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints)
* @param[in] BlkLen number of quantized values per block
*/
bool MLASCALL
MlasIsSQNBitGemmAvailable(
size_t BlkBitWidth,
size_t BlkLen
);
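
Taken together, the two declarations above suggest a guard-then-call pattern at the call site. The following is a minimal sketch, not code from this commit: the buffers and sizes (M, N, K, BatchN, a_data, packed_b_data, packed_b_scale, thread_pool) are illustrative assumptions, and the struct members not shown in this hunk (output buffer, leading dimensions) must be filled in per the full header.

    constexpr size_t BlkBitWidth = 4;  // 4-bit quantized values in B
    constexpr size_t BlkLen = 32;      // 32 quantized values per block

    if (MlasIsSQNBitGemmAvailable(BlkBitWidth, BlkLen)) {
        std::vector<MLAS_SQNBIT_GEMM_DATA_PARAMS> params(BatchN);
        for (size_t i = 0; i < BatchN; ++i) {
            params[i].A = a_data[i];                    // float32 input matrix
            params[i].PackedBData = packed_b_data[i];   // quantized, packed B
            params[i].PackedBScale = packed_b_scale[i]; // one scale per block
            params[i].PackedBZeroPoint = nullptr;       // zero points are optional
            // ... remaining members elided from this hunk, set per the full header
        }
        MlasSQNBitGemmBatch(M, N, K, BatchN, BlkBitWidth, BlkLen, params.data(), thread_pool);
    } else {
        // fall back to dequantizing B and running a float32 GEMM
    }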
34 changes: 23 additions & 11 deletions onnxruntime/core/mlas/lib/sqnbitgemm.cpp
@@ -29,6 +29,8 @@ GetDispatchQuantVariant(size_t BlkBitWidth, size_t BlkLen)
type = QuantVariant_BitWidth4_BlockSize16;
} else if (BlkBitWidth == 4 && BlkLen == 32) {
type = QuantVariant_BitWidth4_BlockSize32;
} else if (BlkBitWidth == 4 && BlkLen == 64) {
type = QuantVariant_BitWidth4_BlockSize64;
}

return type;
@@ -48,17 +50,8 @@ MlasSQNBitGemmBatch(
MLAS_THREADPOOL* ThreadPool
)
{
const int32_t QuantVariant = GetDispatchQuantVariant(BlkLen, BlkBitWidth);
if (QuantVariant == -1) {
MLAS_THROW_EX(std::invalid_argument, "Unsupported quantization block size / bit width.");
}

MLAS_SQNBIT_GEMM_OPERATION* const Operation =
GetMlasPlatform().SQNBitGemmDispatch->Operations[QuantVariant];

if (Operation == nullptr) {
MLAS_THROW_EX(std::invalid_argument, "FpQNBitGemm is not implemented on this platform.");
}
const int32_t QuantVariant = GetDispatchQuantVariant(BlkBitWidth, BlkLen);
MLAS_SQNBIT_GEMM_OPERATION* const Operation = GetMlasPlatform().SQNBitGemmDispatch->Operations[QuantVariant];

if (ThreadPool == nullptr) {
for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) {
@@ -126,3 +119,22 @@ MlasSQNBitGemmBatch(
Operation(K, Data, RangeStartM, RangeCountM, RangeStartN, RangeCountN);
});
}

bool MLASCALL
MlasIsSQNBitGemmAvailable(
size_t BlkBitWidth,
size_t BlkLen
)
{
const int32_t QuantVariant = GetDispatchQuantVariant(BlkBitWidth, BlkLen);
if (QuantVariant == -1) {
return false;
}

if (GetMlasPlatform().SQNBitGemmDispatch == nullptr ||
GetMlasPlatform().SQNBitGemmDispatch->Operations[QuantVariant] == nullptr) {
return false;
}

return true;
}
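
Both entry points reduce to the same lookup: a QuantVariant value indexes a per-platform array of operation function pointers, with -1 as the "unsupported" sentinel. A self-contained sketch of that dispatch pattern (names are illustrative, not the MLAS internals):

    #include <cstddef>
    #include <cstdint>

    using Operation = void (*)();  // stand-in for MLAS_SQNBIT_GEMM_OPERATION

    enum Variant : int32_t {
        Variant_BitWidth4_BlockSize16,
        Variant_BitWidth4_BlockSize32,
        Variant_BitWidth4_BlockSize64,
        VariantCount,  // keep this element last
    };

    int32_t GetVariant(std::size_t BlkBitWidth, std::size_t BlkLen)
    {
        if (BlkBitWidth == 4 && BlkLen == 16) return Variant_BitWidth4_BlockSize16;
        if (BlkBitWidth == 4 && BlkLen == 32) return Variant_BitWidth4_BlockSize32;
        if (BlkBitWidth == 4 && BlkLen == 64) return Variant_BitWidth4_BlockSize64;
        return -1;  // unsupported combination
    }

    struct Dispatch {
        Operation Operations[VariantCount] = {};  // null slot => not implemented
    };

    bool IsAvailable(const Dispatch* dispatch, std::size_t BlkBitWidth, std::size_t BlkLen)
    {
        const int32_t v = GetVariant(BlkBitWidth, BlkLen);
        return v != -1 && dispatch != nullptr && dispatch->Operations[v] != nullptr;
    }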
64 changes: 41 additions & 23 deletions onnxruntime/core/mlas/lib/sqnbitgemm.h
@@ -35,27 +35,28 @@ Module Name:
// Kernel implementation template declarations
//

/// <summary>
/// Multiply float matrix A with quantized n-bit integer matrix B.
/// </summary>
/// <typeparam name="KernelType">Hardware-specific kernel type.</typeparam>
/// <typeparam name="BlkLen">Number of values in a block.</typeparam>
/// <typeparam name="BlkBitWidth">Bit width of each value in a block.</typeparam>
/// <param name="A">Supplies the A matrix.</param>
/// <param name="PackedBData">Supplies the packed B matrix block data.</param>
/// <param name="PackedBScale">Supplies the packed B matrix block scale values.</param>
/// <param name="PackedBZeroPoint">Supplies the packed B matrix block zero point values. Optional.</param>
/// <param name="C">Supplies the output C matrix.</param>
/// <param name="CountM">Number of rows of A and C.</param>
/// <param name="CountN">Number of columns of B and C.</param>
/// <param name="CountK">Number of columns of A and rows of B.</param>
/// <param name="lda">Leading dimension of A.</param>
/// <param name="BlockStridePackedB">
/// Number of blocks between adjacent columns of B (packed B values are transposed).
/// </param>
/// <param name="ldc">Leading dimension of C.</param>
/// <param name="Bias">Bias vector of length N. Optional.</param>
/// <returns>Number of rows of A handled.</returns>
/**
* @brief Multiply float matrix A with quantized n-bit integer matrix B.
*
* @tparam BlkBitWidth Bit width of each value in a block.
* @tparam BlkLen Number of values in a block.
* @tparam KernelType Hardware-specific kernel type.
*
* @param A Supplies the A matrix.
* @param PackedBData Supplies the packed B matrix block data.
* @param PackedBScale Supplies the packed B matrix block scale values.
* @param PackedBZeroPoint Supplies the packed B matrix block zero point values. Optional.
* @param[out] C Supplies the output C matrix.
* @param CountM Number of rows of A and C.
* @param CountN Number of columns of B and C.
* @param CountK Number of columns of A and rows of B.
* @param lda Leading dimension of A.
* @param BlockStridePackedB Number of blocks between adjacent columns of B (packed B values are transposed).
* @param ldc Leading dimension of C.
* @param Bias Bias vector of length N.
*
* @return Number of rows of A handled.
*/
template <size_t BlkBitWidth, size_t BlkLen, typename KernelType>
MLAS_FORCEINLINE size_t
MlasSQNBitGemmKernel(
@@ -73,7 +74,23 @@ MlasSQNBitGemmKernel(
const float* Bias
);
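
BlockStridePackedB is documented as the number of blocks between adjacent columns of packed B, which under that description equals the per-column block count. A hedged helper illustrating the arithmetic (an inference from the doc comment, not a function in this header):

    // Blocks per column of B when CountK rows are split into BlkLen-sized blocks;
    // ceiling division because the last block may be partial (the dequant helper
    // below clamps with kklen for exactly that case).
    constexpr size_t BlockCountK(size_t CountK, size_t BlkLen)
    {
        return (CountK + BlkLen - 1) / BlkLen;
    }
    // e.g. CountK = 100, BlkLen = 32 -> 4 blocks per column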

// dequantize B into the format expected by MlasSgemmKernelZero
/**
* @brief Dequantize B into the format expected by the Sgemm kernel.
* This is equivalent to unpacking and dequantizing B and then running
* MlasSgemmCopyPackB.
*
* @tparam BlkBitWidth Bit width of each value in a block.
* @tparam BlkLen Number of values in a block.
* @tparam KernelType Hardware-specific kernel type.
*
* @param[out] FpData Supplies the output buffer for the B float data.
* @param PackedBData Supplies the packed B matrix block data.
* @param PackedBScale Supplies the packed B matrix block scale values.
* @param PackedBZeroPoint Supplies the packed B matrix block zero point values. Optional.
* @param CountN Number of columns of B.
* @param CountK Number of rows of B.
* @param BlockStridePackedB Number of blocks between adjacent columns of B (packed B values are transposed).
*/
template <size_t BlkBitWidth, size_t BlkLen, typename KernelType>
MLAS_FORCEINLINE void
MlasQNBitBlkDequantBForSgemm(
@@ -96,7 +113,7 @@ MlasQNBitBlkDataSizeInBytes(size_t BlkBitWidth, size_t BlkLen)
return BlkLen * BlkBitWidth / 8;
}
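
A quick worked check of the block-size formula, written as a self-contained mirror of the constexpr above so it can be compiled outside the MLAS build:

    #include <cstddef>

    constexpr std::size_t BlkDataSizeInBytes(std::size_t BlkBitWidth, std::size_t BlkLen)
    {
        return BlkLen * BlkBitWidth / 8;  // n values at b bits pack into n*b/8 bytes
    }

    static_assert(BlkDataSizeInBytes(4, 16) == 8, "16 4-bit values -> 8 bytes");
    static_assert(BlkDataSizeInBytes(4, 32) == 16, "32 4-bit values -> 16 bytes");
    static_assert(BlkDataSizeInBytes(4, 64) == 32, "64 4-bit values -> 32 bytes");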

template <size_t BlkBitWidth>
constexpr MLAS_FORCEINLINE size_t
MlasQNBitZeroPointsForBlksSizeInBytes(size_t BlkCount)
{
@@ -267,6 +284,7 @@ typedef void(MLASCALL MLAS_SQNBIT_GEMM_OPERATION)(
enum QuantVariant {
QuantVariant_BitWidth4_BlockSize16,
QuantVariant_BitWidth4_BlockSize32,
QuantVariant_BitWidth4_BlockSize64,
QuantVariantCount, // keep this element last
};

93 changes: 54 additions & 39 deletions onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp
@@ -88,28 +88,35 @@ MlasSQNBitGemmKernelNeon(
return impl0_reference();
}

template <>
MLAS_FORCEINLINE size_t
MlasSQNBitGemmKernel<4, 32, MLAS_SQNBIT_GEMM_KERNEL_NEON>(
const float* A,
const uint8_t* PackedBData,
const float* PackedBScale,
const uint8_t* PackedBZeroPoint,
float* C,
size_t CountM,
size_t CountN,
size_t CountK,
size_t lda,
size_t BlockStridePackedB,
size_t ldc,
const float* Bias
)
{
return MlasSQNBitGemmKernelNeon<4, 32>(
A, PackedBData, PackedBScale, PackedBZeroPoint, C, CountM, CountN, CountK, lda,
BlockStridePackedB, ldc, Bias
);
}
#define SPECIALIZE_SQNBIT_GEMM_KERNEL(BlkBitWidth, BlkLen) \
template <> \
MLAS_FORCEINLINE size_t \
MlasSQNBitGemmKernel<BlkBitWidth, BlkLen, MLAS_SQNBIT_GEMM_KERNEL_NEON>( \
const float* A, \
const uint8_t* PackedBData, \
const float* PackedBScale, \
const uint8_t* PackedBZeroPoint, \
float* C, \
size_t CountM, \
size_t CountN, \
size_t CountK, \
size_t lda, \
size_t BlockStridePackedB, \
size_t ldc, \
const float* Bias \
)                                                                              \
{ \
return MlasSQNBitGemmKernelNeon<BlkBitWidth, BlkLen>( \
A, PackedBData, PackedBScale, PackedBZeroPoint, C, CountM, CountN, CountK, lda, \
BlockStridePackedB, ldc, Bias \
);                                                                             \
}

SPECIALIZE_SQNBIT_GEMM_KERNEL(4, 16)
SPECIALIZE_SQNBIT_GEMM_KERNEL(4, 32)
SPECIALIZE_SQNBIT_GEMM_KERNEL(4, 64)

#undef SPECIALIZE_SQNBIT_GEMM_KERNEL
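
Generating the specializations through a macro keeps the (BlkBitWidth, BlkLen) variants in lockstep: supporting a new block length is one SPECIALIZE_SQNBIT_GEMM_KERNEL line here plus the matching QuantVariant and dispatch-table entries, rather than another hand-copied explicit specialization.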

//
// MlasQNBitBlkDequantBForSgemm and helpers.
@@ -140,7 +147,6 @@ MlasQNBitBlkDequantBForSgemmNeon(
const size_t nnlen = std::min(CountN - n, size_t{16});

for (size_t nn = 0; nn < nnlen; ++nn) {

for (size_t k = 0, k_blk_idx = 0; k < CountK; k += BlkLen, k_blk_idx += 1) {
const size_t kklen = std::min(CountK - k, BlkLen);

@@ -185,29 +191,38 @@ MlasQNBitBlkDequantBForSgemmNeon(
impl0_reference();
}

template <>
MLAS_FORCEINLINE void
MlasQNBitBlkDequantBForSgemm<4, 32, MLAS_SQNBIT_GEMM_KERNEL_NEON>(
float* FpData,
const uint8_t* PackedBData,
const float* PackedBScale,
const uint8_t* PackedBZeroPoint,
size_t CountN,
size_t CountK,
size_t BlockStridePackedB
)
{
MlasQNBitBlkDequantBForSgemmNeon<4, 32>(
FpData, PackedBData, PackedBScale, PackedBZeroPoint, CountN, CountK, BlockStridePackedB
);
}
#define SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(BlkBitWidth, BlkLen) \
template <> \
MLAS_FORCEINLINE void \
MlasQNBitBlkDequantBForSgemm<BlkBitWidth, BlkLen, MLAS_SQNBIT_GEMM_KERNEL_NEON>( \
float* FpData, \
const uint8_t* PackedBData, \
const float* PackedBScale, \
const uint8_t* PackedBZeroPoint, \
size_t CountN, \
size_t CountK, \
size_t BlockStridePackedB \
)                                                                              \
{ \
MlasQNBitBlkDequantBForSgemmNeon<BlkBitWidth, BlkLen>( \
FpData, PackedBData, PackedBScale, PackedBZeroPoint, CountN, CountK, BlockStridePackedB \
);                                                                             \
}

SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(4, 16)
SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(4, 32)
SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(4, 64)

#undef SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM
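
For reference, the math these dequant helpers implement is the usual per-block affine scheme: each byte of PackedBData packs two 4-bit values, and fp = scale * (q - zero_point). A minimal single-block sketch under assumed conventions (low nibble first, caller-supplied zero point; see impl0_reference above for the authoritative layout):

    #include <cstddef>
    #include <cstdint>

    // Dequantize one 4-bit block of blk_len values; illustrative only,
    // not the MLAS kernel.
    inline void DequantBlk4(const uint8_t* data, float scale, int zero_point,
                            std::size_t blk_len, float* out)
    {
        for (std::size_t i = 0; i < blk_len; i += 2) {
            const uint8_t b = data[i / 2];
            out[i] = scale * (static_cast<int>(b & 0x0F) - zero_point);
            if (i + 1 < blk_len) {
                out[i + 1] = scale * (static_cast<int>(b >> 4) - zero_point);
            }
        }
    }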

//
// Kernel dispatch structure definition.
//

const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon = []() {
MLAS_SQNBIT_GEMM_DISPATCH d;
d.Operations[QuantVariant_BitWidth4_BlockSize16] = MlasSQNBitGemmOperation<4, 16, MLAS_SQNBIT_GEMM_KERNEL_NEON>;
d.Operations[QuantVariant_BitWidth4_BlockSize32] = MlasSQNBitGemmOperation<4, 32, MLAS_SQNBIT_GEMM_KERNEL_NEON>;
d.Operations[QuantVariant_BitWidth4_BlockSize64] = MlasSQNBitGemmOperation<4, 64, MLAS_SQNBIT_GEMM_KERNEL_NEON>;
return d;
}();
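
The table itself is built once by an immediately invoked lambda, which lets MlasSQNBitGemmDispatchNeon remain const while being filled in entry by entry; a platform that lacks a variant simply leaves its Operations slot null, which is exactly the case MlasIsSQNBitGemmAvailable reports as unavailable.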