diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index cc62d36ebfa3e..079067a85bfcb 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -555,9 +555,11 @@ else() ${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp ) -if(NOT CMAKE_CXX_COMPILER_ID STREQUAL "GNU" OR CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "9") +if(NOT "${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" OR CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "10") + message(STATUS "Using -mavx2 -mfma -mavxvnni flags") set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mavxvnni") else() + message(STATUS "Using -mavx2 -mfma flags") set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma") endif() set(mlas_platform_srcs_avx512f diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h index 42fd1131d9a4c..7c9828c6b9795 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h @@ -24,7 +24,7 @@ accumulate_blklen32_r1c1blk1_zp_avx2( bv_32_epi8 = _mm256_sub_epi8(bv_32_epi8, _mm256_set1_epi8(get_zp(true, QuantBZeroPointPtr))); -#if !defined(__GNUC__) || (__GNUC__ > 9) +#if !defined(__GNUC__) || (__GNUC__ > 10) if constexpr (vnni) { const __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv_32_epi8, bv_32_epi8), _mm256_sign_epi8(av_32_epi8, bv_32_epi8)); const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); @@ -36,7 +36,7 @@ accumulate_blklen32_r1c1blk1_zp_avx2( const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, dot_16_epi16); const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); acc = _mm256_fmadd_ps(sum_ps, _mm256_set1_ps(combined_scale), acc); -#if !defined(__GNUC__) || (__GNUC__ > 9) +#if !defined(__GNUC__) || (__GNUC__ > 10) } #endif } @@ -58,7 +58,7 @@ accumulate_blklen32_r1c1blk2_zp_avx2( __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0~31 __m256i bv1_32_epi8 = _mm256_and_si256(_mm256_srli_epi16(bv_packed, 4), low_mask); // 32~63 -#if !defined(__GNUC__) || (__GNUC__ > 9) +#if !defined(__GNUC__) || (__GNUC__ > 10) if constexpr (vnni) { { bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, _mm256_set1_epi8(get_zp(true, QuantBZeroPointPtr))); @@ -98,7 +98,7 @@ accumulate_blklen32_r1c1blk2_zp_avx2( const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); } -#if !defined(__GNUC__) || (__GNUC__ > 9) +#if !defined(__GNUC__) || (__GNUC__ > 10) } #endif } @@ -130,7 +130,7 @@ accumulate_blklen32_r1c1blk2_zp_is_8_avx2( bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, bzp8); bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, bzp8); -#if !defined(__GNUC__) || (__GNUC__ > 9) +#if !defined(__GNUC__) || (__GNUC__ > 10) if constexpr (vnni) { __m256i dot0_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8)); __m256i dot1_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8)); @@ -161,7 +161,7 @@ accumulate_blklen32_r1c1blk2_zp_is_8_avx2( ); acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); -#if !defined(__GNUC__) || (__GNUC__ > 9) +#if !defined(__GNUC__) || (__GNUC__ > 10) } #endif } @@ -189,7 +189,7 @@ accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2( bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, bzp8); bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, bzp8); -#if !defined(__GNUC__) || (__GNUC__ > 9) +#if !defined(__GNUC__) || (__GNUC__ > 10) if constexpr (vnni) { { __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8)); @@ -221,7 +221,7 @@ accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2( const __m256 scale = _mm256_mul_ps(_mm256_set1_ps(*(scale_b + 1)), scale_a1_8_ps); acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); } -#if !defined(__GNUC__) || (__GNUC__ > 9) +#if !defined(__GNUC__) || (__GNUC__ > 10) } #endif }