diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h
index 2494039134ba3..4a67f29714940 100644
--- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h
+++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h
@@ -243,8 +243,8 @@ Q4Int8GemmR2xC4BlkLen64Avx512(
     constexpr size_t PerAccuBlk2 = 2;

     const size_t lda = BlockCountK * BlkLen64;
-    const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes;
-    const size_t StrideQuantBScale = BlockCountK;
+    const size_t StrideQuantBData = PerAccuBlk2 * BlkDataSizeInBytes;
+    //const size_t StrideQuantBScale = BlockCountK;

     assert(CountM % NRows2 == 0);
     assert(CountN % NCols4 == 0);
@@ -275,20 +275,16 @@ Q4Int8GemmR2xC4BlkLen64Avx512(
                 const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda));
                 const __m512i av_11_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + 64));

-                accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8,
-                    QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]);
-                accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8,
-                    QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + StrideQuantBScale, acc[1], acc[NCols4 + 1]);
-                accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8,
-                    QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * StrideQuantBScale, acc[2], acc[NCols4 + 2]);
-                accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8,
-                    QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * StrideQuantBScale, acc[3], acc[NCols4 + 3]);
+                accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]);
+                accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + PerAccuBlk2, acc[1], acc[NCols4 + 1]);
+                accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * PerAccuBlk2, acc[2], acc[NCols4 + 2]);
+                accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * PerAccuBlk2, acc[3], acc[NCols4 + 3]);

                 // increment block pointers
                 QuantAPtr += BlkLen64 * PerAccuBlk2;
                 QuantAScalePtr += PerAccuBlk2;
-                QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk2;
-                QuantBScalePtr += PerAccuBlk2;
+                QuantBDataPtr += StrideQuantBData * NCols4;
+                QuantBScalePtr += PerAccuBlk2 * NCols4;
             }  // k_blks_remaining

             while (k_blks_remaining-- > 0) {
@@ -298,16 +294,16 @@ Q4Int8GemmR2xC4BlkLen64Avx512(
                 accumulate_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]);
                 accumulate_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8,
-                    QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + StrideQuantBScale, acc[1], acc[NCols4 + 1]);
+                    QuantBDataPtr + BlkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 1, acc[1], acc[NCols4 + 1]);
                 accumulate_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8,
-                    QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * StrideQuantBScale, acc[2], acc[NCols4 + 2]);
+                    QuantBDataPtr + 2 * BlkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2, acc[2], acc[NCols4 + 2]);
                 accumulate_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8,
-                    QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * StrideQuantBScale, acc[3], acc[NCols4 + 3]);
+                    QuantBDataPtr + 3 * BlkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3, acc[3], acc[NCols4 + 3]);

                 QuantAPtr += BlkLen64;
                 QuantAScalePtr++;
-                QuantBDataPtr += BlkDataSizeInBytes;
-                QuantBScalePtr++;
+                QuantBDataPtr += BlkDataSizeInBytes * NCols4;
+                QuantBScalePtr += NCols4;
             }

 #if 1
@@ -341,8 +337,8 @@ Q4Int8GemmR2xC4BlkLen64Avx512(
             _mm_storeu_ps(SumPtr + ldc, acc_r1);
 #endif
             // move to next NCols columns
-            QuantBDataColPtr += NCols4 * StrideQuantBData;
-            QuantBScaleColPtr += NCols4 * StrideQuantBScale;
+            QuantBDataColPtr += NCols4 * BlockCountK * BlkDataSizeInBytes;
+            QuantBScaleColPtr += NCols4 * BlockCountK;
             BiasPtr += BiasPtr != nullptr ? NCols4 : 0;
             SumPtr += NCols4;
         }
@@ -465,8 +461,8 @@ Q4Int8GemmR1xC4BlkLen64Avx512(
     constexpr size_t PerAccuBlk2 = 2;

     const size_t lda = BlockCountK * BlkLen;
-    const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen);
-    const size_t StrideQuantBScale = BlockCountK;
+    //const size_t StrideQuantBData = PerAccuBlk2 * BlkDataSizeInBytes;
+    //const size_t StrideQuantBScale = BlockCountK;

     //assert(CountM < NRows2);
     //assert(CountN % NCols4 == 0);
@@ -490,34 +486,30 @@ Q4Int8GemmR1xC4BlkLen64Avx512(
         for (; k_blks_remaining >= PerAccuBlk2; k_blks_remaining -= PerAccuBlk2) {
             const __m512i av0_64_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr);
             const __m512i av1_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64));

-            accumulate_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8,
-                QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]);
-            accumulate_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8,
-                QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale, acc[1]);
-            accumulate_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8,
-                QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale, acc[2]);
-            accumulate_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8,
-                QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale, acc[3]);
+            accumulate_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]);
+            accumulate_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + PerAccuBlk2 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + PerAccuBlk2, acc[1]);
+            accumulate_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + 2 * PerAccuBlk2 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk2, acc[2]);
+            accumulate_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + 3 * PerAccuBlk2 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk2, acc[3]);

            // increment block pointers
             QuantAPtr += BlkLen64 * PerAccuBlk2;
             QuantAScalePtr += PerAccuBlk2;
-            QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk2;
-            QuantBScalePtr += PerAccuBlk2;
+            QuantBDataPtr += PerAccuBlk2 * BlkDataSizeInBytes * NCols4;
+            QuantBScalePtr += PerAccuBlk2 * NCols4;
         }

         while (k_blks_remaining-- > 0) {
             const __m512i av_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr);
             accumulate_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]);
-            accumulate_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale, acc[1]);
-            accumulate_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale, acc[2]);
-            accumulate_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale, acc[3]);
+            accumulate_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr + BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 1, acc[1]);
+            accumulate_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 2, acc[2]);
+            accumulate_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 3, acc[3]);

             QuantAPtr += BlkLen64;
             QuantAScalePtr++;
-            QuantBDataPtr += BlkDataSizeInBytes;
-            QuantBScalePtr++;
+            QuantBDataPtr += BlkDataSizeInBytes * NCols4;
+            QuantBScalePtr += NCols4;
         }

         __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]);
@@ -528,8 +520,8 @@ Q4Int8GemmR1xC4BlkLen64Avx512(
         _mm_storeu_ps(SumPtr, acc_r0);

         // move to next NCols columns
-        QuantBDataColPtr += NCols4 * StrideQuantBData;
-        QuantBScaleColPtr += NCols4 * StrideQuantBScale;
+        QuantBDataColPtr += NCols4 * BlockCountK * BlkDataSizeInBytes;
+        QuantBScaleColPtr += NCols4 * BlockCountK;
         BiasPtr += BiasPtr != nullptr ? NCols4 : 0;
         SumPtr += NCols4;
     }
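Note on the kernel change above: the packed B buffer is now laid out in compute order, so the PerAccuBlk2-block chunks of the four columns that Q4Int8GemmR2xC4BlkLen64Avx512 processes together sit next to each other. StrideQuantBData therefore shrinks from a whole column (BlockCountK blocks) to one chunk, the per-column scale offset becomes PerAccuBlk2, and the loop pointers advance by NCols4 chunks per iteration. The standalone sketch below mimics that pointer arithmetic; it is not part of the patch, and the buffer sizes in it are illustrative assumptions.

// Sketch of the new R2xC4 main-loop addressing (assumed sizes, not PR values).
#include <cstddef>
#include <cstdio>

int main() {
    constexpr std::size_t NCols4 = 4;
    constexpr std::size_t PerAccuBlk2 = 2;              // blocks consumed per accumulate call
    constexpr std::size_t BlkDataSizeInBytes = 64 / 2;  // 64 4-bit weights -> 32 bytes
    constexpr std::size_t BlockCountK = 8;              // assumed number of K blocks
    constexpr std::size_t StrideQuantBData = PerAccuBlk2 * BlkDataSizeInBytes;

    std::size_t QuantBDataOffset = 0;  // stands in for QuantBDataPtr - QuantBDataColPtr
    for (std::size_t k = 0; k + PerAccuBlk2 <= BlockCountK; k += PerAccuBlk2) {
        for (std::size_t col = 0; col < NCols4; ++col) {
            // matches QuantBDataPtr + col * StrideQuantBData in the kernel
            std::printf("k_blk=%zu col=%zu reads bytes [%zu, %zu)\n", k, col,
                        QuantBDataOffset + col * StrideQuantBData,
                        QuantBDataOffset + (col + 1) * StrideQuantBData);
        }
        QuantBDataOffset += StrideQuantBData * NCols4;  // QuantBDataPtr += StrideQuantBData * NCols4
    }
    return 0;
}

The printed byte ranges are contiguous and strictly increasing: the kernel streams through the packed buffer front to back instead of jumping a full column stride for every column, which is the point of the layout change. The scale pointer follows the same pattern with PerAccuBlk2 scales per column chunk.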
diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h
index 3e556b8432b84..9b804e38d2c96 100644
--- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h
+++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h
@@ -33,7 +33,7 @@ SQ4BitGemmPackQuantBDataSize(
 }

 static size_t
-GetContinueLayoutOffset64(size_t N, const size_t n, const size_t SubOrBlkCountK, const size_t k_sub_or_blk)
+GetContinueLayoutOffsetSubBlk(size_t N, const size_t n, const size_t SubOrBlkCountK, const size_t k_sub_or_blk)
 {
     size_t T = n / 4, t = n % 4;
     bool te = T == N / 4;
@@ -47,19 +47,19 @@ GetContinueLayoutOffset64(size_t N, const size_t n, const size_t SubOrBlkCountK,
 }
 static size_t
-GetContinueLayoutOffset32(size_t N, const size_t n, const size_t BlockCountK, const size_t k_blk)
+GetContinueLayoutOffsetBlkInSubBlk(size_t N, const size_t n, const size_t BlockCountK, const size_t k_blk, const int blks_per_sub)
 {
-    size_t T = n / 4, t = n % 4, k_subblk = k_blk / 2, b = k_blk % 2;
-    bool te = T == N / 4, be = k_subblk == BlockCountK / 2;
+    size_t T = n / 4, t = n % 4, k_subblk = k_blk / blks_per_sub, b = k_blk % blks_per_sub;
+    bool te = T == N / 4, be = k_subblk == BlockCountK / blks_per_sub;
     size_t scale_dst_offset = T * 4 * BlockCountK;
     if (te) {
         scale_dst_offset += t * BlockCountK + k_blk;
     } else {
-        scale_dst_offset += k_subblk * 2 * 4;
+        scale_dst_offset += k_subblk * blks_per_sub * 4;
         if (be) {
             scale_dst_offset += b * 4 + t;
         } else {
-            scale_dst_offset += t * 2 + b;
+            scale_dst_offset += t * blks_per_sub + b;
         }
     }
     return scale_dst_offset;
 }
@@ -132,37 +132,35 @@ PackQuantB(
           const size_t k_blks_remaining = BlockCountK - (SubBlkCountK - 1) * SubBlkLen / BlkLen;
           for (size_t k = 0; k < k_blks_remaining; k++) {
             const size_t k_blk = k_subblk * SubBlkLen / BlkLen + k;
-            if (BlkLen == 16 || SubBlkLen == 128) { // TODO:
+            if (BlkLen == 16) {  // not to do the compute order layout yet
               std::byte* PackedQuantBData = PackedQuantBDataBegin + src_data_offset;
               pack_subblk(QuantBData + k * BlkLen / 2, PackedQuantBData + k * BlkLen / 2,
                           PackBytePairCount, PackDataSize);
-            } else if (BlkLen == 32) {
-              const size_t dst_data_offset = GetContinueLayoutOffset32(N, n, BlockCountK, k_blk);
+            } else if (BlkLen >= SubBlkLen) {
+              // shall not reach here with avx2
+              assert(SubBlkLen == 128);
+            } else {
+              int blks_per_sub = (int)(SubBlkLen / BlkLen);
+              const size_t dst_data_offset = GetContinueLayoutOffsetBlkInSubBlk(N, n, BlockCountK, k_blk, blks_per_sub);
               std::byte* PackedQuantBData = PackedQuantBDataBegin + dst_data_offset * BlkLen / 2;
               pack_subblk(QuantBData + k * BlkLen / 2, PackedQuantBData + k * BlkLen / 2,
                           PackBytePairCount, PackDataSize);
-            } else {
-              // shall not reach here (with avx2?)
-              assert(false);
             }
           }
-        }
-        else
-        {
-          if (BlkLen == 16 || SubBlkLen == 128) {
+        } else {
+          if (BlkLen == 16) {  // not to do the compute order layout yet
             std::byte* PackedQuantBData = PackedQuantBDataBegin + src_data_offset;
             pack_subblk(QuantBData, PackedQuantBData, PackBytePairCount, PackDataSize);
-          }
-          else if (BlkLen == 32) {
-            const size_t k_blk = k_subblk * SubBlkLen / BlkLen;
-            const size_t dst_data_offset = GetContinueLayoutOffset32(N, n, BlockCountK, k_blk);
-            std::byte* PackedQuantBData = PackedQuantBDataBegin + dst_data_offset * BlkLen / 2;
-            pack_subblk(QuantBData, PackedQuantBData, PackBytePairCount, PackDataSize);
-          }
-          else { // if (BlkLen > 32)
-            const size_t dst_data_offset = GetContinueLayoutOffset64(N, n, SubBlkCountK, k_subblk);
+          } else if (BlkLen >= SubBlkLen) {
+            const size_t dst_data_offset = GetContinueLayoutOffsetSubBlk(N, n, SubBlkCountK, k_subblk);
             std::byte* PackedQuantBData = PackedQuantBDataBegin + dst_data_offset * SubBlkDataSize;
             pack_subblk(QuantBData, PackedQuantBData, PackBytePairCount, PackDataSize);
+          } else {
+            int blks_per_sub = (int)(SubBlkLen / BlkLen);
+            const size_t k_blk = k_subblk * blks_per_sub;
+            const size_t dst_data_offset = GetContinueLayoutOffsetBlkInSubBlk(N, n, BlockCountK, k_blk, blks_per_sub);
+            std::byte* PackedQuantBData = PackedQuantBDataBegin + dst_data_offset * BlkLen / 2;
+            pack_subblk(QuantBData, PackedQuantBData, PackBytePairCount, PackDataSize);
           }
         }
       }
@@ -173,7 +171,8 @@

 static void
 ComputePackBlkSum(
-    size_t Blklen,
+    size_t BlkLen,
+    size_t SubBlkLen,
     size_t N,
     float* QuantBScaleBegin,
     const std::byte* QuantBZPBegin,
@@ -208,13 +207,14 @@ ComputePackBlkSum(
         const size_t dst_offset = ((n / 16) * BlockCountK + k_blk) * 16 + n % 16;
 #endif
         *(BlockSumBegin + dst_offset) = -QuantBScale * zp;
-        if (true || Blklen == 16) { // TODO
+        if (BlkLen == 16) { // TODO

-        } else if (Blklen == 32) {
-          size_t scale_dst_offset = GetContinueLayoutOffset32(N, n, BlockCountK, k_blk);
+        } else if (BlkLen >= SubBlkLen) {
+          const size_t scale_dst_offset = GetContinueLayoutOffsetSubBlk(N, n, BlockCountK, k_blk);
           *(QuantBScaleBegin + scale_dst_offset) = QuantBScale;
-        } else if (Blklen > 32) {
-          const size_t scale_dst_offset = GetContinueLayoutOffset64(N, n, BlockCountK, k_blk);
+        } else {
+          int blks_per_sub = (int)(SubBlkLen / BlkLen);
+          size_t scale_dst_offset = GetContinueLayoutOffsetBlkInSubBlk(N, n, BlockCountK, k_blk, blks_per_sub);
           *(QuantBScaleBegin + scale_dst_offset) = QuantBScale;
         }
       }
@@ -248,7 +248,7 @@ PackQuantBDataAndBlkSum(
     }

     if ((QuantBScaleBegin && !has_zp_input) || QuantBZPBegin) {
-        ComputePackBlkSum(BlkLen, N, packed_quant_b.PackedQuantBScale, QuantBZPBegin, packed_quant_b.QuantBBlkSum, ThreadPool, BlockCountK);
+        ComputePackBlkSum(BlkLen, SubBlkLen, N, packed_quant_b.PackedQuantBScale, QuantBZPBegin, packed_quant_b.QuantBBlkSum, ThreadPool, BlockCountK);
     }
 }
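To make the repacking arithmetic above concrete: the harness below reuses GetContinueLayoutOffsetBlkInSubBlk verbatim from this patch and prints the destination slot of every (column n, block k_blk) scale. It is a standalone sketch, not ORT code; N, BlockCountK, and blks_per_sub are made-up example values.

// Harness around the patch's remapping function (assumed driver values).
#include <cstddef>
#include <cstdio>

using std::size_t;

// Copied verbatim from the patch above.
static size_t
GetContinueLayoutOffsetBlkInSubBlk(size_t N, const size_t n, const size_t BlockCountK, const size_t k_blk, const int blks_per_sub)
{
    size_t T = n / 4, t = n % 4, k_subblk = k_blk / blks_per_sub, b = k_blk % blks_per_sub;
    bool te = T == N / 4, be = k_subblk == BlockCountK / blks_per_sub;
    size_t scale_dst_offset = T * 4 * BlockCountK;
    if (te) {
        scale_dst_offset += t * BlockCountK + k_blk;
    } else {
        scale_dst_offset += k_subblk * blks_per_sub * 4;
        if (be) {
            scale_dst_offset += b * 4 + t;
        } else {
            scale_dst_offset += t * blks_per_sub + b;
        }
    }
    return scale_dst_offset;
}

int main() {
    const size_t N = 8, BlockCountK = 4;  // assumed small shape
    const int blks_per_sub = 2;           // e.g. SubBlkLen=128 with BlkLen=64
    for (size_t n = 0; n < N; ++n) {
        for (size_t k_blk = 0; k_blk < BlockCountK; ++k_blk) {
            std::printf("n=%zu k_blk=%zu -> dst slot %zu\n", n, k_blk,
                        GetContinueLayoutOffsetBlkInSubBlk(N, n, BlockCountK, k_blk, blks_per_sub));
        }
    }
    return 0;
}

For this shape the first sub-block of a 4-column group occupies slots 0..7 as col0:{0,1}, col1:{2,3}, col2:{4,5}, col3:{6,7}: exactly the QuantBScalePtr + c * PerAccuBlk2 addressing used by the blklen64 kernels in the first file.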
diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc
index 3c3eff81bad79..f54a0fd03679b 100644
--- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc
+++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc
@@ -267,7 +267,7 @@ TEST(MatMulNBits, Float32) {
   for (auto M : {1, 2, 100}) {
     for (auto N : {/*2560, */1, 2, 32, 288}) {
       for (auto K : {/*2560, */16, 32, 64, 128, 256, 1024, 93, 1234 }) {
-        for (auto block_size : {16, 32, 64, 128 }) {
+        for (auto block_size : {/*16, 32, */64/*, 128*/ }) {
           for (auto accuracy_level : {/*0, 1, */4}) {
             TestOptions base_opts{};
             base_opts.M = M, base_opts.N = N, base_opts.K = K;
diff --git a/onnxruntime/test/mlas/bench/bench_q4dq.cpp b/onnxruntime/test/mlas/bench/bench_q4dq.cpp
index 00234ecfd2ce2..eb0727207b837 100644
--- a/onnxruntime/test/mlas/bench/bench_q4dq.cpp
+++ b/onnxruntime/test/mlas/bench/bench_q4dq.cpp
@@ -9,10 +9,10 @@
 #include "core/util/thread_utils.h"

 static void BM_QDQBlockwiseQuantizer_QuantizeColumnwise(benchmark::State& state) {
-  int M = state.range(0);
-  int N = state.range(1);
-  int quant_block_size = state.range(2);
-  int threads = state.range(3);
+  int M = (int)state.range(0);
+  int N = (int)state.range(1);
+  int quant_block_size = (int)state.range(2);
+  int threads = (int)state.range(3);
   size_t scale_size = (M + quant_block_size - 1) / quant_block_size * N;

   auto src = RandomVectorUniform(M * N, -16.0f, 14.0f);
@@ -37,10 +37,10 @@ static void BM_QDQBlockwiseQuantizer_QuantizeColumnwise(benchmark::State& state)
 }

 static void BM_MlasQuantizeBlockwise(benchmark::State& state) {
-  int M = state.range(0);
-  int N = state.range(1);
-  int quant_block_size = state.range(2);
-  int threads = state.range(3);
+  int M = (int)state.range(0);
+  int N = (int)state.range(1);
+  int quant_block_size = (int)state.range(2);
+  int threads = (int)state.range(3);
   size_t scale_size = (M + quant_block_size - 1) / quant_block_size * N;

   auto src = RandomVectorUniform(M * N, -16.0f, 14.0f);
@@ -65,10 +65,10 @@ static void BM_MlasQuantizeBlockwise(benchmark::State& state) {
 }

 static void BM_QDQBlockwiseQuantizer_TransposeColumnwise(benchmark::State& state) {
-  int M = state.range(0);
-  int N = state.range(1);
-  int quant_block_size = state.range(2);
-  int threads = state.range(3);
+  int M = (int)state.range(0);
+  int N = (int)state.range(1);
+  int quant_block_size = (int)state.range(2);
+  int threads = (int)state.range(3);
   int quant_num_M = (M + quant_block_size - 1) / quant_block_size;
   int blob_size = (quant_block_size + 1) / 2;
   size_t scale_size = quant_num_M * N;
diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp
index 1157cd4c868f1..b994a21ff9797 100644
--- a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp
+++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp
@@ -461,7 +461,11 @@ class SQNBitGemmShortExecuteTest : public MlasTestFixture