From b2338c26b6f8783148833e2903c78fce50867e40 Mon Sep 17 00:00:00 2001 From: liqun Fu Date: Mon, 5 Aug 2024 20:52:26 -0700 Subject: [PATCH] Fix typos so to call correct vnni functions under vnni condition (#21625) ### Description Fix 2 typos in mlas avx 4bit gemm implementation to call correct vnni functions under vnni condition ### Motivation and Context needed for 1.19.0 release Signed-off-by: liqunfu --- .../core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h | 4 ++-- .../core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h index 3cd610796a5e3..bb14babd6c2b1 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h @@ -679,9 +679,9 @@ Q4Int8GemmR1xC1BlkLen16Avx512( const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); if constexpr (vnni) { - accumulate_blklen16_r1c1blk8_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); - } else { accumulate_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + } else { + accumulate_blklen16_r1c1blk8_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); } QuantAPtr += BlkLen16 * PerAccuBlk8; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h index ca12cc14a7875..e9df6b952bd27 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h @@ -721,7 +721,7 @@ Q4Int8GemmR1xC1BlkLen32Avx512( accumulate_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); } else { - accumulate_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); } QuantAPtr += BlkLen32 * PerAccuBlk4;