diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp
index 8724b58ce2961..d7431e95048d3 100644
--- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp
+++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp
@@ -146,6 +146,34 @@ UnrolledLoop(IterationFn&& f)
 namespace
 {
 
+/**
+ * @brief Horizontally sum the 4 accumulator vectors and pack the 4 sums
+ * into the lanes of the returned vector.
+ */
+static MLAS_FORCEINLINE __m128
+FoldAccumulators(const __m256& acc0, const __m256& acc1, const __m256& acc2, const __m256& acc3)
+{
+    __m256 acc_lo01 = _mm256_unpacklo_ps(acc0, acc1);
+    __m256 acc_hi01 = _mm256_unpackhi_ps(acc0, acc1);
+    __m256 acc_lo23 = _mm256_unpacklo_ps(acc2, acc3);
+    __m256 acc_hi23 = _mm256_unpackhi_ps(acc2, acc3);
+
+    __m256 acc_lo0123 = _mm256_castpd_ps(
+        _mm256_unpacklo_pd(_mm256_castps_pd(acc_lo01), _mm256_castps_pd(acc_lo23)));
+    __m256 acc_hi0123 = _mm256_castpd_ps(
+        _mm256_unpackhi_pd(_mm256_castps_pd(acc_lo01), _mm256_castps_pd(acc_lo23)));
+    acc_lo0123 = _mm256_add_ps(acc_lo0123, acc_hi0123);
+    acc_hi0123 = _mm256_castpd_ps(
+        _mm256_unpacklo_pd(_mm256_castps_pd(acc_hi01), _mm256_castps_pd(acc_hi23)));
+    acc_lo0123 = _mm256_add_ps(acc_lo0123, acc_hi0123);
+    acc_hi0123 = _mm256_castpd_ps(
+        _mm256_unpackhi_pd(_mm256_castps_pd(acc_hi01), _mm256_castps_pd(acc_hi23)));
+    acc_lo0123 = _mm256_add_ps(acc_lo0123, acc_hi0123);
+
+    // Fold the upper and lower 128-bit lanes; lane i of the result holds the
+    // horizontal sum of acc{i}.
+    return _mm_add_ps(_mm256_extractf128_ps(acc_lo0123, 0), _mm256_extractf128_ps(acc_lo0123, 1));
+}
 
 template <size_t NCols, bool HasZeroPoint>
 MLAS_FORCEINLINE void
@@ -163,8 +191,142 @@ ComputeDotProducts_BlkBitWidth4_CompFp32(
     const float* BiasPtr
 )
 {
-    BlkLen, ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK;
-    StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, BiasPtr;
+    constexpr size_t BlkBitWidth = 4;
+    constexpr size_t SubBlkLen = 16;
+
+    const __m128i lowMask = _mm_set1_epi8(0x0F);
+
+    __m256 acc_lo[NCols];
+    UnrolledLoop<NCols>([&](size_t i) {
+        acc_lo[i] = _mm256_setzero_ps();
+    });
+
+    const std::byte* b = QuantBDataColPtr;
+    const float* s = QuantBScaleColPtr;
+
+    [[maybe_unused]] size_t QuantBZeroPointIdx = 0;  // track half-byte increments with this index instead of a pointer
+                                                     // only used if HasZeroPoint == true
+
+    for (size_t k = 0; k < CountK; k += BlkLen) {
+        size_t ck = std::min(CountK - k, BlkLen);
+
+        float scale_v[NCols];
+        UnrolledLoop<NCols>([&](size_t i) {
+            scale_v[i] = *(s + StrideQuantBScale * i);
+        });
+
+        const std::byte* bptr[NCols];
+        UnrolledLoop<NCols>([&](size_t i) {
+            bptr[i] = b + StrideQuantBData * i;
+        });
+
+        [[maybe_unused]] uint8_t offset[NCols];
+        // "Manual conversion to float" as done in the NEON kernel is not ready yet;
+        // following the NEON kernel, unpack the zero points to uint8_t.
+        if constexpr (HasZeroPoint) {
+            UnrolledLoop<NCols>([&](size_t i) {
+                const std::byte zp_packed =
+                    QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2];
+                const std::byte zp = ((QuantBZeroPointIdx & 1) == 1)
+                                         ? (zp_packed >> 4)
+                                         : (zp_packed & std::byte{0x0F});
+                offset[i] = std::to_integer<uint8_t>(zp);
+            });
+        }
+
+        // TODO: the block size shall be a multiple of 16, but MLAS_QUANT4_BLK_UNIT is 32.
+        // The AVX512 version computes 32 floats at once; here we process SubBlkLen (16)
+        // values per iteration with 256-bit registers. Code adapted from
+        // MlasQ4GemmKernelAvx512f in q4gemm_avx512.cc, which only works with the
+        // NCols == 4 case.
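+        //
+        // Each SubBlkLen iteration below loads up to 16 A floats as two 8-wide
+        // halves and 16 packed 4-bit B values (8 bytes) per column, splits the
+        // nibbles, widens them to int32, subtracts the zero point, scales, and
+        // accumulates with FMA. NOTE: this assumes the q4gemm layout where each
+        // byte packs two adjacent values (low nibble first); the MatMulNBits
+        // layout may differ (see the TODO added in the AVX512 kernel below).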
+        for (size_t kk = 0; kk < ck; kk += SubBlkLen) {
+            size_t kklen = std::min(SubBlkLen, ck - kk);
+
+            // Load the A row vector halves, masking off any tail past the block end.
+            // NOTE: the masked loads below require AVX512VL; a plain AVX2 build will
+            // need _mm256_maskload_ps or a scalar tail instead.
+            uint32_t mask = 0xffff >> (SubBlkLen - kklen);
+            __m256 av_lo = _mm256_maskz_loadu_ps(__mmask8(mask), ARowPtr + k + kk);
+
+            mask = mask >> 8;
+            __m256 av_hi = mask == 0 ? _mm256_setzero_ps()
+                                     : _mm256_maskz_loadu_ps(__mmask8(mask), ARowPtr + k + kk + 8);
+
+            // Load and dequantize the B column vectors.
+            __m256 bvf_lo[NCols], bvf_hi[NCols];
+            UnrolledLoop<NCols>([&](size_t i) {
+                // Get 16 4-bit quantized values (8 bytes) from column i.
+                __m128i bvi4 = _mm_loadu_si64(bptr[i]);
+                bptr[i] += 8;
+
+                // Mask out the low and high nibbles of each byte.
+                __m128i lower = _mm_and_si128(bvi4, lowMask);
+                __m128i upper = _mm_and_si128(_mm_srli_epi16(bvi4, 4), lowMask);
+
+                // Interleave lower and upper to form 8-bit unpacked values.
+                __m128i unpacked = _mm_unpacklo_epi8(lower, upper);
+
+                // Zero-extend the first 8 unpacked 8-bit integers to 32-bit integers.
+                __m256i bv_lo = _mm256_cvtepu8_epi32(unpacked);
+
+                // Extract the second 8 8-bit integers...
+                __m128i unpacked_next_8 = _mm_srli_si128(unpacked, 8);
+
+                // ...and zero-extend them to 32-bit integers as well.
+                __m256i bv_hi = _mm256_cvtepu8_epi32(unpacked_next_8);
+
+                if constexpr (HasZeroPoint) {
+                    // Subtract the column's zero point from the integers.
+                    __m256i zp = _mm256_set1_epi32(offset[i]);
+                    bv_lo = _mm256_sub_epi32(bv_lo, zp);
+                    bv_hi = _mm256_sub_epi32(bv_hi, zp);
+                } else {
+                    // Subtract the default zero point of 8 from the integers.
+                    const __m256i eight = _mm256_set1_epi32(8);
+                    bv_lo = _mm256_sub_epi32(bv_lo, eight);
+                    bv_hi = _mm256_sub_epi32(bv_hi, eight);
+                }
+
+                // Convert 32-bit int -> float32.
+                bvf_lo[i] = _mm256_cvtepi32_ps(bv_lo);
+                bvf_hi[i] = _mm256_cvtepi32_ps(bv_hi);
+            });
+
+            UnrolledLoop<NCols>([&](size_t i) {
+                // Multiply by the block scale.
+                __m256 scale = _mm256_set1_ps(scale_v[i]);
+                bvf_lo[i] = _mm256_mul_ps(bvf_lo[i], scale);
+                bvf_hi[i] = _mm256_mul_ps(bvf_hi[i], scale);
+
+                // c[m,n] += a[m,k] * b[k,n]
+                acc_lo[i] = _mm256_fmadd_ps(bvf_lo[i], av_lo, acc_lo[i]);
+                acc_lo[i] = _mm256_fmadd_ps(bvf_hi[i], av_hi, acc_lo[i]);
+            });
+        }
+
+        b += MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
+        s++;  // advance to the next block's scale
+
+        if constexpr (HasZeroPoint) {
+            QuantBZeroPointIdx += 1;
+        }
+    }
+
+    if constexpr (NCols == 4) {
+        __m128 acc_x = FoldAccumulators(acc_lo[0], acc_lo[1], acc_lo[2], acc_lo[3]);
+        if (BiasPtr != nullptr) {
+            acc_x = _mm_add_ps(acc_x, _mm_loadu_ps(BiasPtr));
+        }
+        _mm_storeu_ps(SumPtr, acc_x);
+    } else {
+        for (size_t i = 0; i < NCols; ++i) {
+            // AVX2 has no single horizontal-reduce intrinsic; fold 256 -> 128 bits
+            // and finish with two horizontal adds.
+            __m128 lo = _mm_add_ps(_mm256_castps256_ps128(acc_lo[i]),
+                                   _mm256_extractf128_ps(acc_lo[i], 1));
+            lo = _mm_hadd_ps(lo, lo);
+            lo = _mm_hadd_ps(lo, lo);
+            SumPtr[i] = _mm_cvtss_f32(lo);
+            SumPtr[i] += BiasPtr == nullptr ? 0.0f : BiasPtr[i];
+        }
+    }
 }
 
 template
diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp
index dcbee4b1e85ef..d4eccd0f3b7e3 100644
--- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp
+++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp
@@ -235,7 +235,7 @@ ComputeDotProducts_BlkBitWidth4_CompFp32(
         }
 
         // TODO: block size shall be multiple of 16 but MLAS_QUANT4_BLK_UNIT is 32
-        // follwing code compute 32 float as once so lets not to use SubBlkLen(16)
+        // following code computes 32 floats at once so let's not use SubBlkLen(16)
         // code copied from MlasQ4GemmKernelAvx512f in q4gemm_avx512.cc which only works
         // with the NCols==4 case.
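+        // Each MLAS_QUANT4_BLK_UNIT iteration handles up to 32 floats of A as two
+        // masked 16-wide loads; the mask zeroes out the tail when fewer than 32
+        // values remain in the block.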
         for (size_t kk = 0; kk < ck; kk += MLAS_QUANT4_BLK_UNIT) {
@@ -249,6 +249,8 @@ ComputeDotProducts_BlkBitWidth4_CompFp32(
             __m512 av_hi = mask == 0 ? _mm512_setzero_ps()
                                      : _mm512_maskz_loadu_ps(__mmask16(mask), ARowPtr + k + kk + 16);
 
+            // TODO: the following code to unpack B does not work with MatMulNBits.
+            //
             // Load B col vectors
             __m128i bvi4[NCols];
             UnrolledLoop<NCols>([&](size_t i) {
diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp
index ed09d7ee92b2a..688fd0673f4a5 100644
--- a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp
+++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp
@@ -209,11 +209,11 @@ class MlasSQNBitGemmTest : public MlasTestBase {
       Bias = BufferBias.GetBuffer(N);
     }
 
-#if 0
-    auto print_matrix = [](size_t ncols, size_t nrows, const float* data) {
+#if 1
+    auto print_matrix = [](size_t nrows, size_t ncols, const float* data) {
       for (size_t row = 0; row < nrows; ++row) {
         for (size_t col = 0; col < ncols; ++col) {
-          std::cout << data[row * nrows + col] << "\t";
+          std::cout << data[row * ncols + col] << "\t";
         }
         std::cout << "\n";
       }