Commit: draft M1 avx2

Signed-off-by: liqunfu <[email protected]>
liqunfu committed Apr 1, 2024
1 parent a605e6a commit 5ca97a7
Showing 3 changed files with 170 additions and 6 deletions.
166 changes: 164 additions & 2 deletions onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp
@@ -146,6 +146,34 @@ UnrolledLoop(IterationFn&& f)

namespace
{

/**
* @brief Horizontally sum each of the 4 given vectors and return the four
* totals packed into a single 128-bit vector.
*/
static MLAS_FORCEINLINE __m128
FoldAccumulators(const __m256& acc0, const __m256& acc1, const __m256& acc2, const __m256& acc3)
{

__m256 acc_lo01 = _mm256_unpacklo_ps(acc0, acc1);
__m256 acc_hi01 = _mm256_unpackhi_ps(acc0, acc1);
__m256 acc_lo23 = _mm256_unpacklo_ps(acc2, acc3);
__m256 acc_hi23 = _mm256_unpackhi_ps(acc2, acc3);

__m256 acc_lo0123 = _mm256_castpd_ps(
_mm256_unpacklo_pd(_mm256_castps_pd(acc_lo01), _mm256_castps_pd(acc_lo23)));
__m256 acc_hi0123 = _mm256_castpd_ps(
_mm256_unpackhi_pd(_mm256_castps_pd(acc_lo01), _mm256_castps_pd(acc_lo23)));
acc_lo0123 = _mm256_add_ps(acc_lo0123, acc_hi0123);
acc_hi0123 = _mm256_castpd_ps(
_mm256_unpacklo_pd(_mm256_castps_pd(acc_hi01), _mm256_castps_pd(acc_hi23)));
acc_lo0123 = _mm256_add_ps(acc_lo0123, acc_hi0123);
acc_hi0123 = _mm256_castpd_ps(
_mm256_unpackhi_pd(_mm256_castps_pd(acc_hi01), _mm256_castps_pd(acc_hi23)));
acc_lo0123 = _mm256_add_ps(acc_lo0123, acc_hi0123);

// Each 128-bit half of acc_lo0123 now holds the four per-column partial
// sums, so add the low and high halves to get the final result.
return _mm_add_ps(_mm256_castps256_ps128(acc_lo0123),
_mm256_extractf128_ps(acc_lo0123, 1));
}

template <size_t NCols, bool HasZeroPoint>
MLAS_FORCEINLINE void
@@ -163,8 +191,142 @@ ComputeDotProducts_BlkBitWidth4_CompFp32(
const float* BiasPtr
)
{
constexpr size_t BlkBitWidth = 4;
constexpr size_t SubBlkLen = 16;

const __m128i lowMask = _mm_set1_epi8(0x0F);

__m256 acc_lo[NCols];
UnrolledLoop<NCols>([&](size_t i) {
acc_lo[i] = _mm256_setzero_ps();
});

const auto* b = QuantBDataColPtr;
const float* s = QuantBScaleColPtr;

[[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer
// only used if HasZeroPoint == true
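// (With 4-bit zero points, two blocks share one byte: block 2j's zero point
// sits in the low nibble of byte j and block 2j+1's in the high nibble, so
// the index advances by half a byte per block.)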

for (size_t k = 0; k < CountK; k += BlkLen) {
size_t ck = std::min(CountK - k, BlkLen);

float scale_v[NCols];
UnrolledLoop<NCols>([&](size_t i) {
scale_v[i] = *(s + StrideQuantBScale * i);
});

const uint8_t* bptr[NCols];
UnrolledLoop<NCols>([&](size_t i) {
bptr[i] = reinterpret_cast<const uint8_t*>(b + StrideQuantBData * i);
});

[[maybe_unused]] uint8_t offset[NCols];
// The NEON-style "manual conversion to float" is not implemented here yet;
// follow the NEON kernel and unpack the zero points to uint8_t instead.
if constexpr (HasZeroPoint) {
UnrolledLoop<NCols>([&](size_t i) {
const std::byte zp_packed =
QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2];
const std::byte zp = ((QuantBZeroPointIdx & 1) == 1)
? (zp_packed >> 4)
: (zp_packed & std::byte{ 0x0F });
offset[i] = std::to_integer<uint8_t>(zp);
});
}

// TODO: the block size shall be a multiple of 16, but MLAS_QUANT4_BLK_UNIT is 32.
// This AVX2 path processes SubBlkLen (16) floats per iteration; it is adapted
// from MlasQ4GemmKernelAvx512f in q4gemm_avx512.cc, which only works with the
// NCols==4 case.
for (size_t kk = 0; kk < ck; kk += SubBlkLen) {
size_t kklen = std::min((size_t)SubBlkLen, ck - kk);

// Load A row vectors. _mm256_maskz_loadu_ps requires AVX-512VL, so build
// per-lane masks for AVX2 instead: lane j is loaded iff j < n.
const __m256i iota = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
const int lo_n = static_cast<int>(std::min(kklen, size_t{8}));
const int hi_n = static_cast<int>(kklen) - lo_n;
__m256 av_lo = _mm256_maskload_ps(
ARowPtr + k + kk, _mm256_cmpgt_epi32(_mm256_set1_epi32(lo_n), iota));
__m256 av_hi = hi_n == 0 ? _mm256_setzero_ps()
: _mm256_maskload_ps(ARowPtr + k + kk + 8,
_mm256_cmpgt_epi32(_mm256_set1_epi32(hi_n), iota));

// Load B col vectors: 16 4-bit quantized values (8 bytes) per column
__m128i bvi4[NCols];
UnrolledLoop<NCols>([&](size_t i) {
bvi4[i] = _mm_loadu_si64(bptr[i]);
bptr[i] += 8;  // advance by the 8 bytes (16 nibbles) just consumed
});

__m256 bvf_lo[NCols], bvf_hi[NCols];
UnrolledLoop<NCols>([&](size_t i) {
// Split each byte of the 16 quantized 4-bit features into its low and
// high nibble.
__m128i lower = _mm_and_si128(bvi4[i], lowMask);
__m128i upper = _mm_and_si128(_mm_srli_epi16(bvi4[i], 4), lowMask);

// Interleave lower and upper nibbles to restore element order: byte j
// holds elements 2j (low nibble) and 2j+1 (high nibble).
__m128i unpacked = _mm_unpacklo_epi8(lower, upper);

// Zero-extend the first 8 unpacked 8-bit integers to 32-bit integers
__m256i bv_lo = _mm256_cvtepu8_epi32(unpacked);

// Extract the second 8 8-bit integers
__m128i unpacked_next_8 = _mm_srli_si128(unpacked, 8);

// Extend the 8-bit integers to 32-bit integers
__m256i bv_hi = _mm256_cvtepu8_epi32(unpacked_next_8);

if constexpr (HasZeroPoint) {
// Subtract the column's zero point from the integers
__m256i zp = _mm256_set1_epi32(offset[i]);
bv_lo = _mm256_sub_epi32(bv_lo, zp);
bv_hi = _mm256_sub_epi32(bv_hi, zp);
} else {
// Subtract 8 from the integers
const __m256i eight = _mm256_set1_epi32(8);
bv_lo = _mm256_sub_epi32(bv_lo, eight);
bv_hi = _mm256_sub_epi32(bv_hi, eight);
}

// Convert the 32-bit integers to 32-bit floats
bvf_lo[i] = _mm256_cvtepi32_ps(bv_lo);
bvf_hi[i] = _mm256_cvtepi32_ps(bv_hi);
});

UnrolledLoop<NCols>([&](size_t i) {
// multiply by scale
__m256 scale = _mm256_set1_ps(scale_v[i]);  // avoid shadowing the pointer s
bvf_lo[i] = _mm256_mul_ps(bvf_lo[i], scale);
bvf_hi[i] = _mm256_mul_ps(bvf_hi[i], scale);

// c[m,n] += a[m,k] * b[k,n]
acc_lo[i] = _mm256_fmadd_ps(bvf_lo[i], av_lo, acc_lo[i]);
acc_lo[i] = _mm256_fmadd_ps(bvf_hi[i], av_hi, acc_lo[i]);
});
}

b += MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
s++;  // advance to the next block's scale

if constexpr (HasZeroPoint) {
QuantBZeroPointIdx += 1;
}
}

if constexpr (NCols == 4) {
__m128 acc_x = FoldAccumulators(acc_lo[0], acc_lo[1], acc_lo[2], acc_lo[3]);
if (BiasPtr != nullptr) {
acc_x = _mm_add_ps(acc_x, _mm_loadu_ps(BiasPtr));
}
_mm_storeu_ps(SumPtr, acc_x);
} else {
for (size_t i = 0; i < NCols; ++i) {
// AVX2 has no _mm256_reduce_add_ps; fold 256 bits to 128, then
// horizontally add down to a scalar.
__m128 v = _mm_add_ps(_mm256_castps256_ps128(acc_lo[i]),
_mm256_extractf128_ps(acc_lo[i], 1));
v = _mm_hadd_ps(v, v);
v = _mm_hadd_ps(v, v);
SumPtr[i] = _mm_cvtss_f32(v) + (BiasPtr == nullptr ? 0.0f : BiasPtr[i]);
}
}
}
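
// ---------------------------------------------------------------------------
// Illustrative scalar reference (not part of this commit): computes the same
// single-column dot product as the kernel above, assuming the packing that
// kernel implies -- element 2j in the low nibble of byte j, element 2j+1 in
// the high nibble, one scale and one 4-bit zero point per block, and a
// default zero point of 8 when no zero points are supplied. A sketch like
// this can be used to unit-test the vectorized path.
[[maybe_unused]] static float
ScalarDotProduct_BlkBitWidth4_Reference(
size_t BlkLen,
const float* ARowPtr,
const uint8_t* QuantBData,
const float* QuantBScale,
const uint8_t* QuantBZeroPoint,  // may be nullptr
size_t CountK
)
{
float acc = 0.0f;
for (size_t k = 0; k < CountK; k += BlkLen) {
const size_t blk = k / BlkLen;
const size_t ck = std::min(CountK - k, BlkLen);
const float scale = QuantBScale[blk];
int zp = 8;
if (QuantBZeroPoint != nullptr) {
const uint8_t zp_packed = QuantBZeroPoint[blk / 2];
zp = (blk & 1) ? (zp_packed >> 4) : (zp_packed & 0x0F);
}
const uint8_t* blk_data = QuantBData + blk * MlasQNBitBlkDataSizeInBytes(4, BlkLen);
for (size_t kk = 0; kk < ck; ++kk) {
const uint8_t b_packed = blk_data[kk / 2];
const int q = (kk & 1) ? (b_packed >> 4) : (b_packed & 0x0F);
acc += ARowPtr[k + kk] * scale * static_cast<float>(q - zp);
}
}
return acc;
}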

template <bool HasZeroPoint>
4 changes: 3 additions & 1 deletion onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp
@@ -235,7 +235,7 @@ ComputeDotProducts_BlkBitWidth4_CompFp32(
}

// TODO: the block size shall be a multiple of 16, but MLAS_QUANT4_BLK_UNIT is 32.
// The following code computes 32 floats at once, so SubBlkLen (16) is not used.
// Code copied from MlasQ4GemmKernelAvx512f in q4gemm_avx512.cc, which only
// works with the NCols==4 case.
for (size_t kk = 0; kk < ck; kk += MLAS_QUANT4_BLK_UNIT) {
@@ -249,6 +249,8 @@ ComputeDotProducts_BlkBitWidth4_CompFp32(
__m512 av_hi = mask == 0 ? _mm512_setzero_ps()
: _mm512_maskz_loadu_ps(__mmask16(mask), ARowPtr + k + kk + 16);

// TODO: following code to unpack b does not work with MatMulNBits.
//
// Load B col vectors
__m128i bvi4[NCols];
UnrolledLoop<NCols>([&](size_t i) {
6 changes: 3 additions & 3 deletions onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp
@@ -209,11 +209,11 @@ class MlasSQNBitGemmTest : public MlasTestBase {
Bias = BufferBias.GetBuffer(N);
}

#if 1
auto print_matrix = [](size_t nrows, size_t ncols, const float* data) {
for (size_t row = 0; row < nrows; ++row) {
for (size_t col = 0; col < ncols; ++col) {
std::cout << data[row * ncols + col] << "\t";
}
std::cout << "\n";
}
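// Example use (illustrative; the buffer names are assumptions, not from
// this commit): print_matrix(M, N, CReference); print_matrix(M, N, C);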
