Skip to content

Commit

Permalink
format code
Browse files Browse the repository at this point in the history
  • Loading branch information
luoyu-intel committed Jan 9, 2024
1 parent 4013aa5 commit 0406f31
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 195 deletions.
240 changes: 81 additions & 159 deletions onnxruntime/contrib_ops/cpu/quantization/bestla_gemm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,28 +20,17 @@ Module Name:
using ThreadPool = onnxruntime::concurrency::ThreadPool;
namespace bestla {
// Constructs the threading adapter from an opaque thread-pool pointer: `tp` is reinterpreted as
// ThreadPool*, its ThreadPool::DegreeOfParallelism() result is forwarded to the IThreading base,
// and the raw pointer is stored in mTp.
// NOTE(review): this span comes from a diff view; the two initializer lists below appear to be the
// before/after renderings of the same code from the "format code" commit - confirm which side applies.
ORTThreading::ORTThreading(void* tp)
: IThreading(ThreadPool::DegreeOfParallelism(reinterpret_cast<ThreadPool*>(tp))), mTp(tp) {
}
: IThreading(ThreadPool::DegreeOfParallelism(reinterpret_cast<ThreadPool*>(tp))), mTp(tp) {}

// Dispatches `func` through ThreadPool::TrySimpleParallelFor on the pool stored in mTp, passing
// mThreadNum as the iteration count. The lambda captures by reference and narrows the ptrdiff_t
// index to int before invoking `func`.
// NOTE(review): this span comes from a diff view; the two TrySimpleParallelFor calls below appear to
// be the before/after renderings of a single call - confirm which side applies.
void ORTThreading::parallel_for(const parallel::thread_func& func) const {
ThreadPool::TrySimpleParallelFor(reinterpret_cast<ThreadPool*>(mTp), mThreadNum, [&](ptrdiff_t tid) {
func(static_cast<int>(tid));
});
ThreadPool::TrySimpleParallelFor(reinterpret_cast<ThreadPool*>(mTp), mThreadNum,
[&](ptrdiff_t tid) { func(static_cast<int>(tid)); });
}

template <class GemmCore_T>
static void
NSSQ4GemmCompF32(
const size_t M,
const size_t N,
const size_t K,
const float* A,
const size_t lda,
storage::gemm::StorageWeightKBlockNInteger* B,
float* C,
const size_t ldc,
int8_t* WorkSpace,
parallel::IThreading* th) {
static void NSSQ4GemmCompF32(const size_t M, const size_t N, const size_t K, const float* A, const size_t lda,
storage::gemm::StorageWeightKBlockNInteger* B, float* C, const size_t ldc,
int8_t* WorkSpace, parallel::IThreading* th) {
auto M_ = static_cast<int>(M);
auto N_ = static_cast<int>(N);
auto K_ = static_cast<int>(K);
Expand All @@ -50,10 +39,10 @@ NSSQ4GemmCompF32(
utils::GemmProblem gp(1, M_, N_, K_, B->mBlockSize);
if (M <= 16) {
using Parallel = parallel::gemm::SchedulerKBlock<GemmCore_T>;
using Launcher = wrapper::gemm::LauncherKBlock<
GemmCore_T::ISA, GemmCore_T, prologue_a::gemm::ActivationKBlockBaseF32,
prologue_b::gemm::WeightKBlockNInteger, epilogue::gemm::CompFp32BlockEpilogue,
epilogue::gemm::AccumulatorWriteBackFp32>;
using Launcher =
wrapper::gemm::LauncherKBlock<GemmCore_T::ISA, GemmCore_T, prologue_a::gemm::ActivationKBlockBaseF32,
prologue_b::gemm::WeightKBlockNInteger, epilogue::gemm::CompFp32BlockEpilogue,
epilogue::gemm::AccumulatorWriteBackFp32>;
static Launcher kernel;
auto reduceA = kernel.mProA.createStorage(M_, K_, B->mBlockSize);
if (B->IsAsym()) {
Expand All @@ -62,39 +51,31 @@ NSSQ4GemmCompF32(
kernel.mProA.reduce({A, lda_, &reduceA}, M_, K_, B->mBlockSize, &single);
}
typename Launcher::BEpiParam blkargs{
B->template SPtr<int8_t>(), B->SDtype(), B->CStep(), B->template ZPtr<int8_t>(),
B->template SPtr<int8_t>(), B->SDtype(), B->CStep(), B->template ZPtr<int8_t>(),
reduceA.template RPtr<float>(), reduceA.lda};

typename Launcher::Param args{gp, {A, lda_, &reduceA}, {B}, blkargs, {C, ldc_}};
parallel::GemmRun<Parallel>(kernel, args, th);
} else {
using Parallel = parallel::gemm::SchedulerBase<GemmCore_T>;
using Launcher = wrapper::gemm::LauncherBase<
GemmCore_T::ISA, GemmCore_T, prologue_a::gemm::ActivationBase, prologue_b::gemm::WeightKBlockNInteger,
epilogue::gemm::AccumulatorWriteBackFp32>;
using Launcher =
wrapper::gemm::LauncherBase<GemmCore_T::ISA, GemmCore_T, prologue_a::gemm::ActivationBase,
prologue_b::gemm::WeightKBlockNInteger, epilogue::gemm::AccumulatorWriteBackFp32>;
static Launcher kernel;
typename Launcher::Param args{gp, {A, lda_}, {B}, {C, ldc_}};
parallel::GemmRun<Parallel>(kernel, args, th);
}
}

template <class GemmCore_T>
static void
NSSQ4GemmCompInt8(
const size_t M,
const size_t N,
const size_t K,
const float* A,
const size_t lda,
storage::gemm::StorageWeightKBlockNInteger* B,
float* C,
const size_t ldc,
int8_t* WorkSpace,
parallel::IThreading* th) {
static void NSSQ4GemmCompInt8(const size_t M, const size_t N, const size_t K, const float* A, const size_t lda,
storage::gemm::StorageWeightKBlockNInteger* B, float* C, const size_t ldc,
int8_t* WorkSpace, parallel::IThreading* th) {
using Parallel = parallel::gemm::SchedulerKBlockS<GemmCore_T>;
using Launcher = wrapper::gemm::LauncherIntKBlock<
GemmCore_T::ISA, GemmCore_T, prologue_a::gemm::ActivationF32KBlockQuantize,
prologue_b::gemm::WeightKBlockNInteger, epilogue::gemm::AccumulatorWriteBackFp32>;
using Launcher =
wrapper::gemm::LauncherIntKBlock<GemmCore_T::ISA, GemmCore_T, prologue_a::gemm::ActivationF32KBlockQuantize,
prologue_b::gemm::WeightKBlockNInteger,
epilogue::gemm::AccumulatorWriteBackFp32>;
auto M_ = static_cast<int>(M);
auto N_ = static_cast<int>(N);
auto K_ = static_cast<int>(K);
Expand All @@ -115,16 +96,9 @@ NSSQ4GemmCompInt8(
}

template <class GemmCore_T>
static size_t
NSSQ4GemmCompF32WorkspaceSize(
const size_t M,
const size_t N,
const size_t K,
const float* A,
const size_t lda,
storage::gemm::StorageWeightKBlockNInteger* B,
float* C,
const size_t ldc) {
static size_t NSSQ4GemmCompF32WorkspaceSize(const size_t M, const size_t N, const size_t K, const float* A,
const size_t lda, storage::gemm::StorageWeightKBlockNInteger* B, float* C,
const size_t ldc) {
auto M_ = static_cast<int>(M);
auto K_ = static_cast<int>(K);
(void)(N);
Expand All @@ -146,16 +120,9 @@ NSSQ4GemmCompF32WorkspaceSize(
}

template <class GemmCore_T>
static size_t
NSSQ4GemmCompInt8WorkspaceSize(
const size_t M,
const size_t N,
const size_t K,
const float* A,
const size_t lda,
storage::gemm::StorageWeightKBlockNInteger* B,
float* C,
const size_t ldc) {
static size_t NSSQ4GemmCompInt8WorkspaceSize(const size_t M, const size_t N, const size_t K, const float* A,
const size_t lda, storage::gemm::StorageWeightKBlockNInteger* B, float* C,
const size_t ldc) {
(void)(N);
(void)(lda);
(void)(ldc);
Expand All @@ -170,52 +137,42 @@ NSSQ4GemmCompInt8WorkspaceSize(

using namespace bestla;

Check warning on line 138 in onnxruntime/contrib_ops/cpu/quantization/bestla_gemm.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Do not use namespace using-directives. Use using-declarations instead. [build/namespaces] [5] Raw Output: onnxruntime/contrib_ops/cpu/quantization/bestla_gemm.cc:138: Do not use namespace using-directives. Use using-declarations instead. [build/namespaces] [5]

bool NSSQ4GemmBatchDriver(
const size_t M,
const size_t N,
const size_t K,
const size_t BatchN,
const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams,
int8_t* WorkSpace,
void* ThreadPool) {
bool NSSQ4GemmBatchDriver(const size_t M, const size_t N, const size_t K, const size_t BatchN,
const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, int8_t* WorkSpace, void* ThreadPool) {
GetCPUDevice();
bestla::ORTThreading orth(ThreadPool);
bool processed = true;
for (size_t i = 0; i < BatchN; i++) {
auto ptr = bestla::storage::gemm::PackedWeightParser::deserialBuffer(DataParams[i].B);
auto uptr = std::unique_ptr<bestla::storage::gemm::IWeightBase>(ptr);
if (ptr) {
auto NTile =
gemm::CoreAttr::get_mask_val(ptr->mCoreId, gemm::CoreAttr::NTILE_MASK, gemm::CoreAttr::NTILE_SHIFT);
auto NTile = gemm::CoreAttr::get_mask_val(ptr->mCoreId, gemm::CoreAttr::NTILE_MASK, gemm::CoreAttr::NTILE_SHIFT);
auto PackRow = gemm::CoreAttr::get_packrow(ptr->mCoreId);
auto CType = gemm::CoreAttr::get_comp(ptr->mCoreId);
auto btype = static_cast<gemm::CompType>(gemm::CompTypeHelper::get_B(CType));
if (ptr->mPrologueID == BTLA_PROLOGUEB_IDS::WeightKBlockNInteger) {
auto kptr = reinterpret_cast<bestla::storage::gemm::StorageWeightKBlockNInteger*>(ptr);
if (btype == gemm::CompType::tFP32 && PackRow == 1) {
if (NTile == bestla::tAVX512F::NTILE && _cd->AVX512F()) {
bestla::NSSQ4GemmCompF32<bestla::tAVX512F>(
M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc,
WorkSpace, &orth);
bestla::NSSQ4GemmCompF32<bestla::tAVX512F>(M, N, K, DataParams[i].A, DataParams[i].lda, kptr,
DataParams[i].C, DataParams[i].ldc, WorkSpace, &orth);
} else if (NTile == bestla::tAVX2::NTILE && _cd->AVX2()) {
bestla::NSSQ4GemmCompF32<bestla::tAVX2>(
M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc,
WorkSpace, &orth);
bestla::NSSQ4GemmCompF32<bestla::tAVX2>(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C,
DataParams[i].ldc, WorkSpace, &orth);
}
}
if (btype == gemm::CompType::tS8 && PackRow == 4) {
if (NTile == bestla::tAMX_INT8_SS_KBlock::NTILE && _cd->AMX_INT8()) {
bestla::NSSQ4GemmCompInt8<bestla::tAMX_INT8_SS_KBlock>(
M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc,
WorkSpace, &orth);
bestla::NSSQ4GemmCompInt8<bestla::tAMX_INT8_SS_KBlock>(M, N, K, DataParams[i].A, DataParams[i].lda, kptr,
DataParams[i].C, DataParams[i].ldc, WorkSpace,
&orth);
} else if (NTile == bestla::tAVX512_VNNI_KBlock::NTILE && _cd->AVX512_VNNI()) {
bestla::NSSQ4GemmCompInt8<bestla::tAVX512_VNNI_KBlock>(
M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc,
WorkSpace, &orth);
bestla::NSSQ4GemmCompInt8<bestla::tAVX512_VNNI_KBlock>(M, N, K, DataParams[i].A, DataParams[i].lda, kptr,
DataParams[i].C, DataParams[i].ldc, WorkSpace,
&orth);
} else if (NTile == bestla::tAVX_VNNI_KBlock::NTILE && _cd->AVX_VNNI()) {
bestla::NSSQ4GemmCompInt8<bestla::tAVX_VNNI_KBlock>(
M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc,
WorkSpace, &orth);
bestla::NSSQ4GemmCompInt8<bestla::tAVX_VNNI_KBlock>(M, N, K, DataParams[i].A, DataParams[i].lda, kptr,
DataParams[i].C, DataParams[i].ldc, WorkSpace, &orth);
}
}
}
Expand All @@ -227,13 +184,8 @@ bool NSSQ4GemmBatchDriver(
return processed;
}

size_t
NSSQ4GemmBatchWorkspaceSize(
const size_t M,
const size_t N,
const size_t K,
const size_t BatchN,
const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams) {
size_t NSSQ4GemmBatchWorkspaceSize(const size_t M, const size_t N, const size_t K, const size_t BatchN,
const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams) {
GetCPUDevice();
size_t size = 0;
for (size_t i = 0; i < BatchN; i++) {
Expand All @@ -249,33 +201,28 @@ NSSQ4GemmBatchWorkspaceSize(
auto btype = static_cast<gemm::CompType>(gemm::CompTypeHelper::get_B(CType));
if (btype == gemm::CompType::tFP32 && PackRow == 1) {
if (NTile == tAVX512F::NTILE && _cd->AVX512F()) {
size = std::max(
NSSQ4GemmCompF32WorkspaceSize<tAVX512F>(
M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc),
size);
size = std::max(NSSQ4GemmCompF32WorkspaceSize<tAVX512F>(M, N, K, DataParams[i].A, DataParams[i].lda, kptr,
DataParams[i].C, DataParams[i].ldc),
size);
} else if (NTile == tAVX2::NTILE && _cd->AVX2()) {
size = std::max(
NSSQ4GemmCompF32WorkspaceSize<tAVX2>(
M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc),
size);
size = std::max(NSSQ4GemmCompF32WorkspaceSize<tAVX2>(M, N, K, DataParams[i].A, DataParams[i].lda, kptr,
DataParams[i].C, DataParams[i].ldc),
size);
}
}
if (btype == gemm::CompType::tS8 && PackRow == 4) {
if (NTile == tAMX_INT8_SS_KBlock::NTILE && _cd->AMX_INT8()) {
size = std::max(
NSSQ4GemmCompInt8WorkspaceSize<tAMX_INT8_SS_KBlock>(
M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc),
size);
size = std::max(NSSQ4GemmCompInt8WorkspaceSize<tAMX_INT8_SS_KBlock>(
M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc),
size);
} else if (NTile == tAVX512_VNNI_KBlock::NTILE && _cd->AVX512_VNNI()) {
size = std::max(
NSSQ4GemmCompInt8WorkspaceSize<tAVX512_VNNI_KBlock>(
M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc),
size);
size = std::max(NSSQ4GemmCompInt8WorkspaceSize<tAVX512_VNNI_KBlock>(
M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc),
size);
} else if (NTile == tAVX_VNNI_KBlock::NTILE && _cd->AVX_VNNI()) {
size = std::max(
NSSQ4GemmCompInt8WorkspaceSize<tAVX_VNNI_KBlock>(
M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc),
size);
size = std::max(NSSQ4GemmCompInt8WorkspaceSize<tAVX_VNNI_KBlock>(

Check warning on line 223 in onnxruntime/contrib_ops/cpu/quantization/bestla_gemm.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <algorithm> for max [build/include_what_you_use] [4] Raw Output: onnxruntime/contrib_ops/cpu/quantization/bestla_gemm.cc:223: Add #include <algorithm> for max [build/include_what_you_use] [4]
M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc),
size);
}
}
}
Expand All @@ -285,18 +232,15 @@ NSSQ4GemmBatchWorkspaceSize(
}

// Returns the mSize field of the storage object that prologue type T creates for the given
// N, K and block_size (each narrowed to int), with BTLA_DTYPE::S4_CLIP, F32 and BF16 passed as
// the dtype arguments and isAsym forwarded unchanged. The prologue instance is a function-local
// static, so one T is constructed per template instantiation.
// NOTE(review): presumably the three dtypes are weight / scale / reduction types - confirm against
// createStorage's signature.
// NOTE(review): this span comes from a diff view; the duplicated signature and createStorage call
// appear to be the before/after renderings of the same code - confirm which side applies.
template <typename T>
static size_t
NSQ4BuSize(size_t block_size, size_t N, size_t K, bool isAsym) {
static size_t NSQ4BuSize(size_t block_size, size_t N, size_t K, bool isAsym) {
static T proB;
auto stor = proB.createStorage(
static_cast<int>(N), static_cast<int>(K), static_cast<int>(block_size), BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::F32,
BTLA_DTYPE::BF16, isAsym);
auto stor = proB.createStorage(static_cast<int>(N), static_cast<int>(K), static_cast<int>(block_size),
BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::F32, BTLA_DTYPE::BF16, isAsym);
// TODO(Yu) support more scale dtype
return stor.mSize;
}

size_t
NSQ4GemmPackBSize(size_t N, size_t K, size_t BlkSize, bool isAsym, int64_t CompType) {
size_t NSQ4GemmPackBSize(size_t N, size_t K, size_t BlkSize, bool isAsym, int64_t CompType) {
GetCPUDevice();
if (K % BlkSize != 0) {
return 0;
Expand Down Expand Up @@ -333,24 +277,14 @@ NSQ4GemmPackBSize(size_t N, size_t K, size_t BlkSize, bool isAsym, int64_t CompT
}

template <typename T>
static void
NSQ4GemmPackBImpl(
void* PackedBuf,
size_t BlkSize,
const uint8_t* QData,
const float* Scale,
const uint8_t* Zp,
size_t N,
size_t K,
bool IsAsym,
bool lastCall,
size_t ldb,
void* ThreadPool) {
static void NSQ4GemmPackBImpl(void* PackedBuf, size_t BlkSize, const uint8_t* QData, const float* Scale,
const uint8_t* Zp, size_t N, size_t K, bool IsAsym, bool lastCall, size_t ldb,
void* ThreadPool) {
static T proB;
auto N_ = static_cast<int>(N);
auto K_ = static_cast<int>(K);
auto stor = proB.createStorage(
N_, K_, static_cast<int>(BlkSize), BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::F32, BTLA_DTYPE::BF16, IsAsym);
auto stor = proB.createStorage(N_, K_, static_cast<int>(BlkSize), BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::F32,
BTLA_DTYPE::BF16, IsAsym);
stor.assign(reinterpret_cast<int8_t*>(PackedBuf));
ORTThreading orth(ThreadPool);
proB.packNbitsWeightQ4(N_, K_, IsAsym, QData, static_cast<int>(ldb), Scale, Zp, &stor, &orth);
Expand All @@ -359,23 +293,12 @@ NSQ4GemmPackBImpl(
}
}

bool NSQ4GemmPackB(
void* PackedBuf,
const uint8_t* QData,
const float* Scale,
const uint8_t* Zp,
size_t N,
size_t K,
size_t ldb,
size_t BlkSize,
bool isAsym,
bool lastCall,
int64_t CompType,
void* ThreadPool) {
bool NSQ4GemmPackB(void* PackedBuf, const uint8_t* QData, const float* Scale, const uint8_t* Zp, size_t N, size_t K,
size_t ldb, size_t BlkSize, bool isAsym, bool lastCall, int64_t CompType, void* ThreadPool) {
GetCPUDevice();
// explicit statement fall through.
switch (CompType) {
case 4: // int8
case 4:
if (!isAsym) { // asym int8 is not optimized, so fall through to others.
if (_cd->AMX_INT8() && BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) {
NSQ4GemmPackBImpl<tWeiNInt<tAMX_INT8_SS_KBlock, tAMX_INT8_SS_KBlock::ISA>>(
Expand All @@ -388,23 +311,23 @@ bool NSQ4GemmPackB(
return true;
}
if (_cd->AVX_VNNI() && BlkSize % tAVX_VNNI_KBlock::KTILE == 0) {
NSQ4GemmPackBImpl<tWeiNInt<tAVX_VNNI_KBlock, tAVX_VNNI_KBlock::ISA>>(
PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool);
NSQ4GemmPackBImpl<tWeiNInt<tAVX_VNNI_KBlock, tAVX_VNNI_KBlock::ISA>>(PackedBuf, BlkSize, QData, Scale, Zp, N,
K, isAsym, lastCall, ldb, ThreadPool);
return true;
}
}
case 3: // bf16
case 2: // fp16
case 1: // fp32
case 3:
case 2:
case 1:
case 0:
if (_cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) {
NSQ4GemmPackBImpl<tWeiNInt<tAVX512F, tAVX512F::ISA>>(
PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool);
NSQ4GemmPackBImpl<tWeiNInt<tAVX512F, tAVX512F::ISA>>(PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym,
lastCall, ldb, ThreadPool);
return true;
}
if (_cd->AVX2() && BlkSize % tAVX2::KTILE == 0) {
NSQ4GemmPackBImpl<tWeiNInt<tAVX2, tAVX2::ISA>>(
PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool);
NSQ4GemmPackBImpl<tWeiNInt<tAVX2, tAVX2::ISA>>(PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall,
ldb, ThreadPool);
return true;
}
default:
Expand All @@ -422,8 +345,7 @@ bool NSQ4GemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, s
auto ldb_ = static_cast<int>(ldb);
GetCPUDevice();
if (ptr) {
auto NTile =
gemm::CoreAttr::get_mask_val(ptr->mCoreId, gemm::CoreAttr::NTILE_MASK, gemm::CoreAttr::NTILE_SHIFT);
auto NTile = gemm::CoreAttr::get_mask_val(ptr->mCoreId, gemm::CoreAttr::NTILE_MASK, gemm::CoreAttr::NTILE_SHIFT);
auto PackRow = gemm::CoreAttr::get_packrow(ptr->mCoreId);
auto CType = gemm::CoreAttr::get_comp(ptr->mCoreId);
auto btype = static_cast<gemm::CompType>(gemm::CompTypeHelper::get_B(CType));
Expand Down
Loading

0 comments on commit 0406f31

Please sign in to comment.