Add compiler-version condition for -mavxvnni
Signed-off-by: liqunfu <[email protected]>
liqunfu committed Jul 30, 2024
1 parent 012e9c4 commit 21b9138
Showing 4 changed files with 48 additions and 5 deletions.
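The commit gates the AVX-VNNI path in two places: the -mavxvnni compile flag in cmake/onnxruntime_mlas.cmake, and the _mm256_dpbusds_avx_epi32 intrinsic calls in the MLAS kernels. Both are guarded by a check that the compiler is either not GCC or is a GCC release newer than 9, which this commit treats as the versions lacking -mavxvnni support; on older GCC the flag is rejected and the AVX-VNNI intrinsics are unavailable, so even a discarded if constexpr branch would fail to compile. A minimal sketch of the resulting pattern, using a hypothetical helper dot_u8s8 that is not part of this commit:

#include <immintrin.h>

// Hypothetical illustration of the guard pattern used throughout this commit:
// the VNNI branch is selected with if constexpr, and additionally removed by
// the preprocessor on GCC <= 9 so the unsupported intrinsic is never parsed.
template <bool vnni>
static inline __m256i dot_u8s8(__m256i a_u8, __m256i b_s8)
{
#if !defined(__GNUC__) || (__GNUC__ > 9)
    if constexpr (vnni) {
        // AVX-VNNI: u8 x s8 dot products accumulated directly into 32-bit lanes.
        return _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), a_u8, b_s8);
    } else {
#endif
        // Fallback: 16-bit pairwise products, then fold pairs into 32-bit sums.
        const __m256i prod_16 = _mm256_maddubs_epi16(a_u8, b_s8);
        return _mm256_madd_epi16(_mm256_set1_epi16(1), prod_16);
#if !defined(__GNUC__) || (__GNUC__ > 9)
    }
#endif
}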
7 changes: 5 additions & 2 deletions cmake/onnxruntime_mlas.cmake
@@ -555,8 +555,11 @@ else()
${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp
)
set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mavxvnni")

if(NOT CMAKE_CXX_COMPILER_ID STREQUAL "GNU" OR CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "9")
set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mavxvnni")
else()
set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma")
endif()
set(mlas_platform_srcs_avx512f
${MLAS_SRC_DIR}/x86_64/DgemmKernelAvx512F.S
${MLAS_SRC_DIR}/x86_64/SgemmKernelAvx512F.S
18 changes: 17 additions & 1 deletion onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h
@@ -54,6 +54,7 @@ accumulate_blklen32_r2c1blk2_avx2(
// const __m256i bv1_32_epi8 = _mm256_and_si256(_mm256_srli_epi16(bv_packed, 4), low_mask);
__m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32~63

#if !defined(__GNUC__) || (__GNUC__ > 9)
if constexpr (vnni) {
__m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b));
{
@@ -78,6 +79,7 @@ accumulate_blklen32_r2c1blk2_avx2(
acc1 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc1);
}
} else {
#endif
//{
const __m256i dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av00_32_epi8);
const __m256i dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av01_32_epi8);
@@ -106,7 +108,9 @@ accumulate_blklen32_r2c1blk2_avx2(
__m256 scale_8_ps_ = _mm256_permute_ps(_mm256_mul_ps(scale_a1_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0));
acc1 = _mm256_fmadd_ps(sum_ps_, scale_8_ps_, acc1);
//}
#if !defined(__GNUC__) || (__GNUC__ > 9)
}
#endif
}

template <bool vnni>
@@ -126,6 +130,7 @@ accumulate_blklen32_r1c1blk2_avx2(
__m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0~31
__m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32~63

#if !defined(__GNUC__) || (__GNUC__ > 9)
if constexpr (vnni) {
const __m256i dot0_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av00_32_epi8);
const __m256i dot1_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv1_32_epi8, av01_32_epi8);
@@ -139,6 +144,7 @@ accumulate_blklen32_r1c1blk2_avx2(
__m256 scale_8_ps = _mm256_permute_ps(_mm256_mul_ps(scale_a0_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0));
acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0);
} else {
#endif
const __m256i dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av00_32_epi8);
const __m256i dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av01_32_epi8);
const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16);
@@ -152,7 +158,9 @@ accumulate_blklen32_r1c1blk2_avx2(
// 1 0 1 0 1 0 1 0 -> 1 1 0 0 1 1 0 0
__m256 scale_8_ps = _mm256_permute_ps(_mm256_mul_ps(scale_a0_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0));
acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0);
#if !defined(__GNUC__) || (__GNUC__ > 9)
}
#endif
}

template<bool vnni>
@@ -171,15 +179,19 @@ accumulate_blklen32_r2c1blk1_avx2(
const __m128i bv_packed0 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(QuantBDataPtr));
__m256i bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0);
bv_32_epi8 = _mm256_and_si256(_mm256_set1_epi8(0x0F), bv_32_epi8);


#if !defined(__GNUC__) || (__GNUC__ > 9)
if constexpr (vnni) {
accumulate_1blk_dot_vnni(av00_32_epi8, bv_32_epi8, combined_scale00, acc0);
accumulate_1blk_dot_vnni(av10_32_epi8, bv_32_epi8, combined_scale10, acc1);
} else {
#endif
__m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_32_epi8, bv_32_epi8), 15);
accumulate_1blk_dot(av00_32_epi8, bv_32_epi8, combined_scale00, one_16_epi16, acc0);
accumulate_1blk_dot(av10_32_epi8, bv_32_epi8, combined_scale10, one_16_epi16, acc1);
#if !defined(__GNUC__) || (__GNUC__ > 9)
}
#endif
}

template <bool vnni>
@@ -196,12 +208,16 @@ accumulate_blklen32_r1c1blk1_avx2(
__m256i bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0);
bv_32_epi8 = _mm256_and_si256(_mm256_set1_epi8(0x0F), bv_32_epi8);

#if !defined(__GNUC__) || (__GNUC__ > 9)
if constexpr (vnni) {
accumulate_1blk_dot_vnni(av00_32_epi8, bv_32_epi8, combined_scale00, acc0);
} else {
#endif
__m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_32_epi8, bv_32_epi8), 15);
accumulate_1blk_dot(av00_32_epi8, bv_32_epi8, combined_scale00, one_16_epi16, acc0);
#if !defined(__GNUC__) || (__GNUC__ > 9)
}
#endif
}

template <bool vnni>
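Several of the non-VNNI branches in the file above reduce the 16-bit products from _mm256_maddubs_epi16 into 32-bit lanes by multiplying with a vector of ones via _mm256_madd_epi16; the ones vector is synthesized without a memory load by comparing a register with itself (all bits set) and logically shifting each 16-bit lane right by 15. A small sketch of that reduction step, with a hypothetical helper name widen_pairs_to_epi32 that is not part of this commit:

#include <immintrin.h>

// Fold adjacent signed 16-bit lanes into signed 32-bit sums: madd with a
// vector of 1s computes d0*1 + d1*1 per pair. cmpeq(x, x) yields 0xFFFF in
// every 16-bit lane; a logical right shift by 15 turns that into 1.
static inline __m256i widen_pairs_to_epi32(__m256i dot_16_epi16)
{
    const __m256i one_16_epi16 =
        _mm256_srli_epi16(_mm256_cmpeq_epi16(dot_16_epi16, dot_16_epi16), 15);
    return _mm256_madd_epi16(one_16_epi16, dot_16_epi16);
}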
12 changes: 10 additions & 2 deletions onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h
@@ -26,6 +26,7 @@ accumulate_blklen64_r2c1blk1_avx2(
__m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...30, 31
__m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32, 33,...62, 63

#if !defined(__GNUC__) || (__GNUC__ > 9)
if constexpr (vnni) {
__m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av00_32_epi8);
sum_8_epi32 = _mm256_dpbusds_avx_epi32(sum_8_epi32, bv1_32_epi8, av01_32_epi8);
@@ -43,8 +44,9 @@ accumulate_blklen64_r2c1blk1_avx2(
__m256 scale_a1_ps = _mm256_broadcast_ss(scale_a1);

acc1 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a1_ps, scale_b_ps), acc1);

} else {
#endif
__m256i dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av00_32_epi8);
__m256i dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av01_32_epi8);
__m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16);
@@ -69,7 +71,9 @@ accumulate_blklen64_r2c1blk1_avx2(
__m256 scale_a1_ps = _mm256_broadcast_ss(scale_a1);

acc1 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a1_ps, scale_b_ps), acc1);
#if !defined(__GNUC__) || (__GNUC__ > 9)
}
#endif
}

template <bool vnni>
@@ -89,6 +93,7 @@ accumulate_blklen64_r1c1blk1_avx2(
__m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...30, 31
__m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32, 33,...62, 63

#if !defined(__GNUC__) || (__GNUC__ > 9)
if constexpr (vnni) {
__m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av00_32_epi8);
sum_8_epi32 = _mm256_dpbusds_avx_epi32(sum_8_epi32, bv1_32_epi8, av01_32_epi8);
@@ -99,6 +104,7 @@ accumulate_blklen64_r1c1blk1_avx2(

acc0 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a_8_ps, scale_b_8_ps), acc0);
} else {
#endif
const __m256i dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av00_32_epi8);
const __m256i dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av01_32_epi8);
const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16);
@@ -111,7 +117,9 @@ accumulate_blklen64_r1c1blk1_avx2(
__m256 scale_b_8_ps = _mm256_broadcast_ss(scale_b);

acc0 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a_8_ps, scale_b_8_ps), acc0);
#if !defined(__GNUC__) || (__GNUC__ > 9)
}
#endif
}

template <bool vnni>
@@ -134,7 +142,7 @@ Q4Int8GemmR2xC4BlkLen64Avx2(
constexpr size_t NCols4 = 4;
constexpr size_t NRows2 = 2;
constexpr size_t SubblkLen = 64;

const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen);
const size_t PerBlkSubblkCount = BlkLen / SubblkLen;
const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount;
@@ -24,17 +24,21 @@ accumulate_blklen32_r1c1blk1_zp_avx2(

bv_32_epi8 = _mm256_sub_epi8(bv_32_epi8, _mm256_set1_epi8(get_zp<HasZeroPoint>(true, QuantBZeroPointPtr)));

#if !defined(__GNUC__) || (__GNUC__ > 9)
if constexpr (vnni) {
const __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv_32_epi8, bv_32_epi8), _mm256_sign_epi8(av_32_epi8, bv_32_epi8));
const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32);
acc = _mm256_fmadd_ps(sum_ps, _mm256_set1_ps(combined_scale), acc);
} else {
#endif
__m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_32_epi8, bv_32_epi8), 15);
const __m256i dot_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv_32_epi8, bv_32_epi8), _mm256_sign_epi8(av_32_epi8, bv_32_epi8));
const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, dot_16_epi16);
const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32);
acc = _mm256_fmadd_ps(sum_ps, _mm256_set1_ps(combined_scale), acc);
#if !defined(__GNUC__) || (__GNUC__ > 9)
}
#endif
}

template<bool vnni>
@@ -54,6 +58,7 @@ accumulate_blklen32_r1c1blk2_zp_avx2(
__m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0~31
__m256i bv1_32_epi8 = _mm256_and_si256(_mm256_srli_epi16(bv_packed, 4), low_mask); // 32~63

#if !defined(__GNUC__) || (__GNUC__ > 9)
if constexpr (vnni) {
{
bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, _mm256_set1_epi8(get_zp<true>(true, QuantBZeroPointPtr)));
@@ -71,6 +76,7 @@ accumulate_blklen32_r1c1blk2_zp_avx2(
acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0);
}
} else {
#endif
{
bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, _mm256_set1_epi8(get_zp<true>(true, QuantBZeroPointPtr)));
const __m256 scale = _mm256_set1_ps(*(scale_a) * *(scale_b));
@@ -92,7 +98,9 @@ accumulate_blklen32_r1c1blk2_zp_avx2(
const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32);
acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0);
}
#if !defined(__GNUC__) || (__GNUC__ > 9)
}
#endif
}

template<bool vnni>
@@ -122,6 +130,7 @@ accumulate_blklen32_r1c1blk2_zp_is_8_avx2(
bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, bzp8);
bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, bzp8);

#if !defined(__GNUC__) || (__GNUC__ > 9)
if constexpr (vnni) {
__m256i dot0_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8));
__m256i dot1_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8));
@@ -135,6 +144,7 @@ accumulate_blklen32_r1c1blk2_zp_is_8_avx2(

acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0);
} else {
#endif
__m256i dot0_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8));
__m256i dot1_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8));
const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16);
@@ -151,7 +161,9 @@ accumulate_blklen32_r1c1blk2_zp_is_8_avx2(
);

acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0);
#if !defined(__GNUC__) || (__GNUC__ > 9)
}
#endif
}

template <bool vnni>
@@ -177,6 +189,7 @@ accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2(
bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, bzp8);
bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, bzp8);

#if !defined(__GNUC__) || (__GNUC__ > 9)
if constexpr (vnni) {
{
__m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8));
@@ -191,6 +204,7 @@ accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2(
acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0);
}
} else {
#endif
{
__m256i dot0_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8));
__m256i sum_8_epi32 = _mm256_madd_epi16(_mm256_set1_epi16(1), dot0_16_epi16);
@@ -207,7 +221,9 @@ accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2(
const __m256 scale = _mm256_mul_ps(_mm256_set1_ps(*(scale_b + 1)), scale_a1_8_ps);
acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0);
}
#if !defined(__GNUC__) || (__GNUC__ > 9)
}
#endif
}

template <bool HasZeroPoint, bool vnni>
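The zero-point kernels in the last file subtract the zero point from the unpacked 4-bit B values before the dot product, which makes B signed, while _mm256_maddubs_epi16 and _mm256_dpbusds_avx_epi32 both require their first operand to be unsigned. The code handles this with the _mm256_sign_epi8 pair seen above: it multiplies |b| (a valid unsigned operand) by a with b's sign transferred onto it, so the per-lane product is unchanged. A hypothetical sketch of that trick, where the helper dot_signed_b is not part of this commit:

#include <immintrin.h>

// |b| * (sign(b) * a) == b * a for every byte lane, and |b| is non-negative,
// so it satisfies the unsigned-operand requirement of maddubs/dpbusds.
// Lanes where b == 0 produce 0 in both factors, which is also correct.
static inline __m256i dot_signed_b(__m256i av_32_epi8, __m256i bv_32_epi8)
{
    const __m256i abs_b    = _mm256_sign_epi8(bv_32_epi8, bv_32_epi8); // |b|
    const __m256i signed_a = _mm256_sign_epi8(av_32_epi8, bv_32_epi8); // a * sign(b)
    const __m256i prod_16  = _mm256_maddubs_epi16(abs_b, signed_a);    // pairs of b*a
    return _mm256_madd_epi16(_mm256_set1_epi16(1), prod_16);           // 32-bit sums
}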

