-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
save work - got sqnbitgemm tests and a cpu impl
- Loading branch information
Showing
15 changed files
with
1,214 additions
and
63 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
#pragma once | ||
|
||
template <typename T>
class MLAS_GEMM_POSTPROCESSOR
{
   public:
    /**
     * @brief Apply post-processing to a sub-range of the result matrix.
     *
     * @param C           Address of the matrix to process.
     * @param RangeStartM Start row index of the sub-range.
     * @param RangeStartN Start column index of the sub-range.
     * @param RangeCountM Number of rows in the sub-range to process.
     * @param RangeCountN Number of columns in the sub-range to process.
     * @param ldc         Leading dimension of the matrix.
     */
    virtual void Process(T* C,
                         size_t RangeStartM,
                         size_t RangeStartN,
                         size_t RangeCountM,
                         size_t RangeCountN,
                         size_t ldc) const = 0;

    virtual ~MLAS_GEMM_POSTPROCESSOR() = default;
};
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
/*++ | ||
Copyright (c) Microsoft Corporation. All rights reserved. | ||
Licensed under the MIT License. | ||
Module Name: | ||
mlas_qnbit.h | ||
Abstract: | ||
This module contains the public data structures and procedure prototypes | ||
for blocked n-bit quantized GEMM. | ||
N-bit block quantization is used to compress weight tensors of large | ||
language models. | ||
--*/ | ||
|
||
#pragma once | ||
|
||
#include "mlas.h" | ||
#include "mlas_gemm_postprocessor.h" | ||
|
||
/** | ||
* @brief Data parameters for float/n-bit quantized int GEMM routine. | ||
*/ | ||
/**
 * @brief Per-batch data parameters for the float/n-bit quantized int GEMM
 *        routine MlasSQNBitGemmBatch. One instance describes one GEMM in
 *        the batch. All pointers are borrowed; the caller retains ownership.
 */
struct MLAS_SQNBIT_GEMM_DATA_PARAMS {
    const float* A = nullptr;                ///< address of A (float32 matrix)
    const void* PackedBData = nullptr;       ///< address of B (quantized and packed n-bit int values)
    const float* PackedBScale = nullptr;     ///< address of scale values of quantized B, one per block
    const void* PackedBZeroPoint = nullptr;  ///< optional address of zero point values of quantized B, one per block; may be nullptr
    bool IsBPacked = false;                  ///< whether B values are packed in the optimal format for the computation
    const float* Bias = nullptr;             ///< optional address of Bias, vector size N; may be nullptr
    float* C = nullptr;                      ///< address of result matrix (written by the GEMM)
    size_t lda = 0;                          ///< leading dimension of A
    size_t ldc = 0;                          ///< leading dimension of C

    ///< optional post processing to apply to result matrix; may be nullptr
    MLAS_GEMM_POSTPROCESSOR<float>* PostProcessor = nullptr;
};
|
||
/** | ||
* @brief Batched GEMM: C = A * B + Bias | ||
* A must be a float32 matrix | ||
* B must be a quantized and packed n-bit int matrix | ||
* | ||
* @param[in] M row size of matrix A and C | ||
* @param[in] N column size of matrix B and C | ||
* @param[in] K column size of matrix A and row size of matrix B | ||
* @param[in] BatchN number of batches | ||
* @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) | ||
* @param[in] BlkSize number of quantized values per block | ||
* @param[inout] DataParams An array (size BatchN) of parameter blocks | ||
* @param[in] ThreadPool optional thread pool to use | ||
*/ | ||
void MLASCALL | ||
MlasSQNBitGemmBatch( | ||
const size_t M, | ||
const size_t N, | ||
const size_t K, | ||
const size_t BatchN, | ||
const size_t BlkBitWidth, | ||
const size_t BlkSize, | ||
const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, | ||
MLAS_THREADPOOL* ThreadPool = nullptr | ||
); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
/*++ | ||
Copyright (c) Microsoft Corporation. All rights reserved. | ||
Licensed under the MIT License. | ||
Module Name: | ||
sqnbitgemm.cpp | ||
Abstract: | ||
This module implements the float/quantized n-bit integer matrix | ||
multiplication hardware agnostic entrypoint, MlasSQNBitGemmBatch. | ||
--*/ | ||
|
||
#include "sqnbitgemm.h" | ||
|
||
namespace | ||
{ | ||
|
||
// Get quantization variant based on `BlkBitWidth` and `BlkLen`. | ||
// Return -1 if the input values are unsupported. | ||
int32_t | ||
GetDispatchQuantVariant(size_t BlkBitWidth, size_t BlkLen) | ||
{ | ||
int32_t type = -1; | ||
if (BlkBitWidth == 4 && BlkLen == 16) { | ||
type = QuantVariant_BitWidth4_BlockSize16; | ||
} else if (BlkBitWidth == 4 && BlkLen == 32) { | ||
type = QuantVariant_BitWidth4_BlockSize32; | ||
} | ||
|
||
return type; | ||
} | ||
|
||
} // namespace | ||
|
||
void MLASCALL | ||
MlasSQNBitGemmBatch( | ||
const size_t M, | ||
const size_t N, | ||
const size_t K, | ||
const size_t BatchN, | ||
const size_t BlkBitWidth, | ||
const size_t BlkLen, | ||
const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, | ||
MLAS_THREADPOOL* ThreadPool | ||
) | ||
{ | ||
const int32_t QuantVariant = GetDispatchQuantVariant(BlkLen, BlkBitWidth); | ||
if (QuantVariant == -1) { | ||
MLAS_THROW_EX(std::invalid_argument, "Unsupported quantization block size / bit width."); | ||
} | ||
|
||
MLAS_SQNBIT_GEMM_OPERATION* const Operation = | ||
GetMlasPlatform().SQNBitGemmDispatch->Operations[QuantVariant]; | ||
|
||
if (Operation == nullptr) { | ||
MLAS_THROW_EX(std::invalid_argument, "FpQNBitGemm is not implemented on this platform."); | ||
} | ||
|
||
if (ThreadPool == nullptr) { | ||
for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) { | ||
auto Data = &DataParams[gemm_i]; | ||
Operation(K, Data, 0, M, 0, N); | ||
} | ||
return; | ||
} | ||
|
||
// | ||
// Compute the number of target threads given the complexity of the SGEMM | ||
// operation. Small requests should run using the single threaded path. | ||
// | ||
|
||
const double Complexity = double(M) * double(N) * double(K) * double(BatchN); | ||
|
||
ptrdiff_t TargetThreadCount = ptrdiff_t(Complexity / double(MLAS_QGEMM_THREAD_COMPLEXITY)) + 1; | ||
|
||
ptrdiff_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool) * 8; | ||
|
||
if (TargetThreadCount >= MaximumThreadCount) { | ||
TargetThreadCount = MaximumThreadCount; | ||
} | ||
|
||
ptrdiff_t ThreadsPerGemm = TargetThreadCount / BatchN; | ||
if (ThreadsPerGemm < 1) { | ||
ThreadsPerGemm = 1; | ||
} | ||
|
||
constexpr size_t StrideM = 128; | ||
|
||
size_t nc = N; | ||
if (ThreadsPerGemm > 1) { | ||
// more than one thread per GEMM | ||
|
||
const size_t BlockedM = MlasDivRoundup(M, StrideM); | ||
const size_t max_nc = MlasDivRoundup(N * BlockedM, ThreadsPerGemm); | ||
if (max_nc < nc) { | ||
nc = std::min( | ||
nc, MlasDivRoundup(max_nc, MLAS_QGEMM_STRIDEN_THREAD_ALIGN) * | ||
MLAS_QGEMM_STRIDEN_THREAD_ALIGN | ||
); | ||
} | ||
} | ||
const size_t StrideN = nc; | ||
|
||
const size_t ThreadCountM = MlasDivRoundup(M, StrideM); | ||
const size_t ThreadCountN = MlasDivRoundup(N, StrideN); | ||
ThreadsPerGemm = ThreadCountM * ThreadCountN; | ||
|
||
MlasTrySimpleParallel(ThreadPool, ThreadsPerGemm * BatchN, [&](ptrdiff_t tid) { | ||
const auto gemm_i = tid / ThreadsPerGemm; | ||
const auto blk_i = tid % ThreadsPerGemm; | ||
auto Data = &DataParams[gemm_i]; | ||
|
||
const ptrdiff_t ThreadIdN = blk_i / ThreadCountM; | ||
const ptrdiff_t ThreadIdM = blk_i % ThreadCountM; | ||
|
||
const size_t RangeStartM = ThreadIdM * StrideM; | ||
const size_t RangeCountM = std::min(M - RangeStartM, (size_t)StrideM); | ||
|
||
const size_t RangeStartN = ThreadIdN * StrideN; | ||
const size_t RangeCountN = std::min(N - RangeStartN, (size_t)StrideN); | ||
|
||
Operation(K, Data, RangeStartM, RangeCountM, RangeStartN, RangeCountN); | ||
}); | ||
} |
Oops, something went wrong.