avx512 blklen64 pass
Signed-off-by: liqunfu <[email protected]>
liqunfu committed Jul 4, 2024
1 parent 283fd2d · commit d035939
Showing 5 changed files with 80 additions and 83 deletions.
68 changes: 30 additions & 38 deletions onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h
@@ -243,8 +243,8 @@ Q4Int8GemmR2xC4BlkLen64Avx512(
constexpr size_t PerAccuBlk2 = 2;

const size_t lda = BlockCountK * BlkLen64;
const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes;
const size_t StrideQuantBScale = BlockCountK;
const size_t StrideQuantBData = PerAccuBlk2 * BlkDataSizeInBytes;

Code scanning / PREfast warning: The const variable 'StrideQuantBData' can be computed at compile-time. Consider using constexpr (con.5). A sketch of this change follows the end of this hunk.
//const size_t StrideQuantBScale = BlockCountK;

assert(CountM % NRows2 == 0);
assert(CountN % NCols4 == 0);
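
A minimal, self-contained sketch of the constexpr change suggested by the con.5 warning above. The concrete value of BlkDataSizeInBytes is an assumption here (32 bytes, i.e. 64 4-bit elements), chosen only so the snippet compiles on its own; in the real kernel the block data size comes from MlasQNBitBlkDataSizeInBytes.

#include <cstddef>

// con.5: a value that is known at compile time should be constexpr, not const.
constexpr std::size_t BlkDataSizeInBytes = 32;  // assumed: 64 elements * 4 bits / 8
constexpr std::size_t PerAccuBlk2 = 2;
constexpr std::size_t StrideQuantBData = PerAccuBlk2 * BlkDataSizeInBytes;  // was: const size_t

static_assert(StrideQuantBData == 64, "computed at compile time");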
@@ -275,20 +275,16 @@ Q4Int8GemmR2xC4BlkLen64Avx512(
const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda));
const __m512i av_11_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + 64));

accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8,
QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]);
accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8,
QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + StrideQuantBScale, acc[1], acc[NCols4 + 1]);
accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8,
QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * StrideQuantBScale, acc[2], acc[NCols4 + 2]);
accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8,
QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * StrideQuantBScale, acc[3], acc[NCols4 + 3]);
accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]);
accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + PerAccuBlk2, acc[1], acc[NCols4 + 1]);
accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * PerAccuBlk2, acc[2], acc[NCols4 + 2]);
accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * PerAccuBlk2, acc[3], acc[NCols4 + 3]);

// increment block pointers
QuantAPtr += BlkLen64 * PerAccuBlk2;
QuantAScalePtr += PerAccuBlk2;
QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk2;
QuantBScalePtr += PerAccuBlk2;
QuantBDataPtr += StrideQuantBData * NCols4;
QuantBScalePtr += PerAccuBlk2 * NCols4;
} // k_blks_remaining

while (k_blks_remaining-- > 0) {
@@ -298,16 +294,16 @@ Q4Int8GemmR2xC4BlkLen64Avx512(
accumulate_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8,
QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]);
accumulate_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8,
QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + StrideQuantBScale, acc[1], acc[NCols4 + 1]);
QuantBDataPtr + BlkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 1, acc[1], acc[NCols4 + 1]);
accumulate_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8,
QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * StrideQuantBScale, acc[2], acc[NCols4 + 2]);
QuantBDataPtr + 2 * BlkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2, acc[2], acc[NCols4 + 2]);
accumulate_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8,
QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * StrideQuantBScale, acc[3], acc[NCols4 + 3]);
QuantBDataPtr + 3 * BlkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3, acc[3], acc[NCols4 + 3]);

QuantAPtr += BlkLen64;
QuantAScalePtr++;
QuantBDataPtr += BlkDataSizeInBytes;
QuantBScalePtr++;
QuantBDataPtr += BlkDataSizeInBytes * NCols4;
QuantBScalePtr += NCols4;
}

#if 1
@@ -341,8 +337,8 @@ Q4Int8GemmR2xC4BlkLen64Avx512(
_mm_storeu_ps(SumPtr + ldc, acc_r1);
#endif
// move to next NCols columns
QuantBDataColPtr += NCols4 * StrideQuantBData;
QuantBScaleColPtr += NCols4 * StrideQuantBScale;
QuantBDataColPtr += NCols4 * BlockCountK * BlkDataSizeInBytes;
QuantBScaleColPtr += NCols4 * BlockCountK;
BiasPtr += BiasPtr != nullptr ? NCols4 : 0;
SumPtr += NCols4;
}
@@ -465,8 +461,8 @@ Q4Int8GemmR1xC4BlkLen64Avx512(
constexpr size_t PerAccuBlk2 = 2;

const size_t lda = BlockCountK * BlkLen;
const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen);
const size_t StrideQuantBScale = BlockCountK;
//const size_t StrideQuantBData = PerAccuBlk2 * BlkDataSizeInBytes;
//const size_t StrideQuantBScale = BlockCountK;

//assert(CountM < NRows2);
//assert(CountN % NCols4 == 0);
@@ -490,34 +486,30 @@ Q4Int8GemmR1xC4BlkLen64Avx512(
for (; k_blks_remaining >= PerAccuBlk2; k_blks_remaining -= PerAccuBlk2) {
const __m512i av0_64_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr);
const __m512i av1_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64));
accumulate_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8,
QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]);
accumulate_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8,
QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale, acc[1]);
accumulate_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8,
QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale, acc[2]);
accumulate_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8,
QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale, acc[3]);
accumulate_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]);
accumulate_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + PerAccuBlk2 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + PerAccuBlk2, acc[1]);
accumulate_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + 2 * PerAccuBlk2 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk2, acc[2]);
accumulate_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + 3 * PerAccuBlk2 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk2, acc[3]);

// increment block pointers
QuantAPtr += BlkLen64 * PerAccuBlk2;
QuantAScalePtr += PerAccuBlk2;
QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk2;
QuantBScalePtr += PerAccuBlk2;
QuantBDataPtr += PerAccuBlk2 * BlkDataSizeInBytes * NCols4;
QuantBScalePtr += PerAccuBlk2 * NCols4;
}

while (k_blks_remaining-- > 0) {
const __m512i av_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr);

accumulate_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]);
accumulate_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale, acc[1]);
accumulate_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale, acc[2]);
accumulate_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale, acc[3]);
accumulate_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr + BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 1, acc[1]);
accumulate_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 2, acc[2]);
accumulate_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 3, acc[3]);

QuantAPtr += BlkLen64;
QuantAScalePtr++;
QuantBDataPtr += BlkDataSizeInBytes;
QuantBScalePtr++;
QuantBDataPtr += BlkDataSizeInBytes * NCols4;
QuantBScalePtr += NCols4;
}

__m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]);
@@ -528,8 +520,8 @@ Q4Int8GemmR1xC4BlkLen64Avx512(
_mm_storeu_ps(SumPtr, acc_r0);

// move to next NCols columns
QuantBDataColPtr += NCols4 * StrideQuantBData;
QuantBScaleColPtr += NCols4 * StrideQuantBScale;
QuantBDataColPtr += NCols4 * BlockCountK * BlkDataSizeInBytes;
QuantBScaleColPtr += NCols4 * BlockCountK;
BiasPtr += BiasPtr != nullptr ? NCols4 : 0;
SumPtr += NCols4;
}
64 changes: 32 additions & 32 deletions onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h
@@ -33,7 +33,7 @@ SQ4BitGemmPackQuantBDataSize(
}

static size_t
GetContinueLayoutOffset64(size_t N, const size_t n, const size_t SubOrBlkCountK, const size_t k_sub_or_blk)
GetContinueLayoutOffsetSubBlk(size_t N, const size_t n, const size_t SubOrBlkCountK, const size_t k_sub_or_blk)

Code scanning / PREfast warning: You can attempt to make 'GetContinueLayoutOffsetBlkInSubBlk' constexpr unless it contains any undefined behavior (f.4).
{
size_t T = n / 4, t = n % 4;
bool te = T == N / 4;
@@ -47,19 +47,19 @@ GetContinueLayoutOffset64(size_t N, const size_t n, const size_t SubOrBlkCountK,
}

static size_t
GetContinueLayoutOffset32(size_t N, const size_t n, const size_t BlockCountK, const size_t k_blk)
GetContinueLayoutOffsetBlkInSubBlk(size_t N, const size_t n, const size_t BlockCountK, const size_t k_blk, const int blks_per_sub)

Code scanning / PREfast warning: You can attempt to make 'GetContinueLayoutOffsetBlkInSubBlk' constexpr unless it contains any undefined behavior (f.4). A constexpr sketch follows the function body below.
{
size_t T = n / 4, t = n % 4, k_subblk = k_blk / 2, b = k_blk % 2;
bool te = T == N / 4, be = k_subblk == BlockCountK / 2;
size_t T = n / 4, t = n % 4, k_subblk = k_blk / blks_per_sub, b = k_blk % blks_per_sub;
bool te = T == N / 4, be = k_subblk == BlockCountK / blks_per_sub;
size_t scale_dst_offset = T * 4 * BlockCountK;
if (te) {
scale_dst_offset += t * BlockCountK + k_blk;
} else {
scale_dst_offset += k_subblk * 2 * 4;
scale_dst_offset += k_subblk * blks_per_sub * 4;
if (be) {
scale_dst_offset += b * 4 + t;
} else {
scale_dst_offset += t * 2 + b;
scale_dst_offset += t * blks_per_sub + b;
}
}
return scale_dst_offset;
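
If the f.4 hint is taken, the helper could also be declared constexpr. A sketch, assuming the body (copied verbatim from the new version above) contains no undefined behavior for the arguments the packer actually passes:

static constexpr size_t
GetContinueLayoutOffsetBlkInSubBlk(size_t N, const size_t n, const size_t BlockCountK, const size_t k_blk, const int blks_per_sub)
{
    // Identical arithmetic to the function above; constexpr only adds the option of compile-time evaluation.
    size_t T = n / 4, t = n % 4, k_subblk = k_blk / blks_per_sub, b = k_blk % blks_per_sub;
    bool te = T == N / 4, be = k_subblk == BlockCountK / blks_per_sub;
    size_t scale_dst_offset = T * 4 * BlockCountK;
    if (te) {
        scale_dst_offset += t * BlockCountK + k_blk;
    } else {
        scale_dst_offset += k_subblk * blks_per_sub * 4;
        if (be) {
            scale_dst_offset += b * 4 + t;
        } else {
            scale_dst_offset += t * blks_per_sub + b;
        }
    }
    return scale_dst_offset;
}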
@@ -132,37 +132,35 @@ PackQuantB(
const size_t k_blks_remaining = BlockCountK - (SubBlkCountK - 1) * SubBlkLen / BlkLen;
for (size_t k = 0; k < k_blks_remaining; k++) {
const size_t k_blk = k_subblk * SubBlkLen / BlkLen + k;
if (BlkLen == 16 || SubBlkLen == 128) { // TODO:
if (BlkLen == 16) {
// not to do the compute order layout yet
std::byte* PackedQuantBData = PackedQuantBDataBegin + src_data_offset;
pack_subblk(QuantBData + k * BlkLen / 2, PackedQuantBData + k * BlkLen / 2, PackBytePairCount, PackDataSize);
} else if (BlkLen == 32) {
const size_t dst_data_offset = GetContinueLayoutOffset32(N, n, BlockCountK, k_blk);
} else if (BlkLen >= SubBlkLen) {
// shall not reach here with avx2
assert(SubBlkLen == 128);
} else {
int blks_per_sub = (int)(SubBlkLen / BlkLen);
const size_t dst_data_offset = GetContinueLayoutOffsetBlkInSubBlk(N, n, BlockCountK, k_blk, blks_per_sub);
std::byte* PackedQuantBData = PackedQuantBDataBegin + dst_data_offset * BlkLen / 2;
pack_subblk(QuantBData + k * BlkLen / 2, PackedQuantBData + k * BlkLen / 2, PackBytePairCount, PackDataSize);
} else {
// shall not reach here (with avx2?)
assert(false);
}
}
}
else
{
if (BlkLen == 16 || SubBlkLen == 128) {
} else {
if (BlkLen == 16) {
// not to do the compute order layout yet
std::byte* PackedQuantBData = PackedQuantBDataBegin + src_data_offset;
pack_subblk(QuantBData, PackedQuantBData, PackBytePairCount, PackDataSize);
}
else if (BlkLen == 32) {
const size_t k_blk = k_subblk * SubBlkLen / BlkLen;
const size_t dst_data_offset = GetContinueLayoutOffset32(N, n, BlockCountK, k_blk);
std::byte* PackedQuantBData = PackedQuantBDataBegin + dst_data_offset * BlkLen / 2;
pack_subblk(QuantBData, PackedQuantBData, PackBytePairCount, PackDataSize);
}
else { // if (BlkLen > 32)
const size_t dst_data_offset = GetContinueLayoutOffset64(N, n, SubBlkCountK, k_subblk);
} else if (BlkLen >= SubBlkLen) {
const size_t dst_data_offset = GetContinueLayoutOffsetSubBlk(N, n, SubBlkCountK, k_subblk);
std::byte* PackedQuantBData = PackedQuantBDataBegin + dst_data_offset * SubBlkDataSize;
pack_subblk(QuantBData, PackedQuantBData, PackBytePairCount, PackDataSize);
} else {
int blks_per_sub = (int)(SubBlkLen / BlkLen);
const size_t k_blk = k_subblk * blks_per_sub;
const size_t dst_data_offset = GetContinueLayoutOffsetBlkInSubBlk(N, n, BlockCountK, k_blk, blks_per_sub);
std::byte* PackedQuantBData = PackedQuantBDataBegin + dst_data_offset * BlkLen / 2;
pack_subblk(QuantBData, PackedQuantBData, PackBytePairCount, PackDataSize);
}
}
}
@@ -173,7 +171,8 @@

static void
ComputePackBlkSum(
size_t Blklen,
size_t BlkLen,
size_t SubBlkLen,
size_t N,
float* QuantBScaleBegin,
const std::byte* QuantBZPBegin,
@@ -208,13 +207,14 @@
const size_t dst_offset = ((n / 16) * BlockCountK + k_blk) * 16 + n % 16;
#endif
*(BlockSumBegin + dst_offset) = -QuantBScale * zp;
if (true || Blklen == 16) { // TODO
if (BlkLen == 16) { // TODO

} else if (Blklen == 32) {
size_t scale_dst_offset = GetContinueLayoutOffset32(N, n, BlockCountK, k_blk);
} else if (BlkLen >= SubBlkLen) {
const size_t scale_dst_offset = GetContinueLayoutOffsetSubBlk(N, n, BlockCountK, k_blk);
*(QuantBScaleBegin + scale_dst_offset) = QuantBScale;
} else if (Blklen > 32) {
const size_t scale_dst_offset = GetContinueLayoutOffset64(N, n, BlockCountK, k_blk);
} else {
int blks_per_sub = (int)(SubBlkLen / BlkLen);
size_t scale_dst_offset = GetContinueLayoutOffsetBlkInSubBlk(N, n, BlockCountK, k_blk, blks_per_sub);
*(QuantBScaleBegin + scale_dst_offset) = QuantBScale;
}
}
@@ -248,7 +248,7 @@ PackQuantBDataAndBlkSum(
}

if ((QuantBScaleBegin && !has_zp_input) || QuantBZPBegin) {
ComputePackBlkSum(BlkLen, N, packed_quant_b.PackedQuantBScale, QuantBZPBegin, packed_quant_b.QuantBBlkSum, ThreadPool, BlockCountK);
ComputePackBlkSum(BlkLen, SubBlkLen, N, packed_quant_b.PackedQuantBScale, QuantBZPBegin, packed_quant_b.QuantBBlkSum, ThreadPool, BlockCountK);
}
}

2 changes: 1 addition & 1 deletion onnxruntime/test/contrib_ops/matmul_4bits_test.cc
@@ -267,7 +267,7 @@ TEST(MatMulNBits, Float32) {
for (auto M : {1, 2, 100}) {
for (auto N : {/*2560, */1, 2, 32, 288}) {
for (auto K : {/*2560, */16, 32, 64, 128, 256, 1024, 93, 1234 }) {
for (auto block_size : {16, 32, 64, 128 }) {
for (auto block_size : {/*16, 32, */64/*, 128*/ }) {
for (auto accuracy_level : {/*0, 1, */4}) {
TestOptions base_opts{};
base_opts.M = M, base_opts.N = N, base_opts.K = K;
24 changes: 12 additions & 12 deletions onnxruntime/test/mlas/bench/bench_q4dq.cpp
@@ -9,10 +9,10 @@
#include "core/util/thread_utils.h"

static void BM_QDQBlockwiseQuantizer_QuantizeColumnwise(benchmark::State& state) {
int M = state.range(0);
int N = state.range(1);
int quant_block_size = state.range(2);
int threads = state.range(3);
int M = (int)state.range(0);
int N = (int)state.range(1);
int quant_block_size = (int)state.range(2);
int threads = (int)state.range(3);
size_t scale_size = (M + quant_block_size - 1) / quant_block_size * N;

auto src = RandomVectorUniform(M * N, -16.0f, 14.0f);
@@ -37,10 +37,10 @@ static void BM_QDQBlockwiseQuantizer_QuantizeColumnwise(benchmark::State& state)
}

static void BM_MlasQuantizeBlockwise(benchmark::State& state) {
int M = state.range(0);
int N = state.range(1);
int quant_block_size = state.range(2);
int threads = state.range(3);
int M = (int)state.range(0);
int N = (int)state.range(1);
int quant_block_size = (int)state.range(2);
int threads = (int)state.range(3);
size_t scale_size = (M + quant_block_size - 1) / quant_block_size * N;

auto src = RandomVectorUniform(M * N, -16.0f, 14.0f);
@@ -65,10 +65,10 @@ static void BM_MlasQuantizeBlockwise(benchmark::State& state) {
}

static void BM_QDQBlockwiseQuantizer_TransposeColumnwise(benchmark::State& state) {
int M = state.range(0);
int N = state.range(1);
int quant_block_size = state.range(2);
int threads = state.range(3);
int M = (int)state.range(0);
int N = (int)state.range(1);
int quant_block_size = (int)state.range(2);
int threads = (int)state.range(3);
int quant_num_M = (M + quant_block_size - 1) / quant_block_size;
int blob_size = (quant_block_size + 1) / 2;
size_t scale_size = quant_num_M * N;
