Skip to content

Commit

Permalink
avx512 port from q4gemm - pass M1N1K1 w/o bias
Browse files Browse the repository at this point in the history
Signed-off-by: liqunfu <[email protected]>
  • Loading branch information
liqunfu committed Mar 31, 2024
1 parent bbd7cf6 commit a605e6a
Showing 1 changed file with 161 additions and 2 deletions.
163 changes: 161 additions & 2 deletions onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ Module Name:
#include <utility>

#include "sqnbitgemm.h"
#include "q4Common.h"

//
// Quantized B data packing function implementation.
Expand Down Expand Up @@ -146,6 +147,34 @@ UnrolledLoop(IterationFn&& f)

namespace
{
/**
 * @brief Reduce four 16-lane float accumulators to a single 4-lane vector.
 *
 * Lane i of the returned vector holds the sum of all 16 elements of the
 * i-th input accumulator (acc0 -> lane 0, ..., acc3 -> lane 3).
 *
 * The reduction first transposes the inputs inside every 128-bit lane so that
 * matching elements of the four accumulators end up side by side, adds the
 * four transposed vectors, and finally folds the 512-bit result down to
 * 256 and then 128 bits.
 */
static MLAS_FORCEINLINE __m128
FoldAccumulators(const __m512& acc0, const __m512& acc1, const __m512& acc2, const __m512& acc3)
{
    // Interleave two float vectors at 64-bit granularity (low / high halves of each 128-bit lane).
    const auto interleave_lo_pd = [](const __m512& x, const __m512& y) {
        return _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(x), _mm512_castps_pd(y)));
    };
    const auto interleave_hi_pd = [](const __m512& x, const __m512& y) {
        return _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(x), _mm512_castps_pd(y)));
    };

    // Per 128-bit lane: {a0[0], a1[0], a0[1], a1[1]} and {a0[2], a1[2], a0[3], a1[3]}, same for a2/a3.
    const __m512 pair01_lo = _mm512_unpacklo_ps(acc0, acc1);
    const __m512 pair01_hi = _mm512_unpackhi_ps(acc0, acc1);
    const __m512 pair23_lo = _mm512_unpacklo_ps(acc2, acc3);
    const __m512 pair23_hi = _mm512_unpackhi_ps(acc2, acc3);

    // Per 128-bit lane: elemN = {a0[N], a1[N], a2[N], a3[N]}.
    const __m512 elem0 = interleave_lo_pd(pair01_lo, pair23_lo);
    const __m512 elem1 = interleave_hi_pd(pair01_lo, pair23_lo);
    const __m512 elem2 = interleave_lo_pd(pair01_hi, pair23_hi);
    const __m512 elem3 = interleave_hi_pd(pair01_hi, pair23_hi);

    // Accumulate left to right: ((elem0 + elem1) + elem2) + elem3.
    __m512 sum512 = _mm512_add_ps(elem0, elem1);
    sum512 = _mm512_add_ps(sum512, elem2);
    sum512 = _mm512_add_ps(sum512, elem3);

    // Fold 512 -> 256 -> 128 bits by adding the upper half onto the lower half.
    const __m256 lower256 = _mm512_extractf32x8_ps(sum512, 0);
    const __m256 upper256 = _mm512_extractf32x8_ps(sum512, 1);
    const __m256 sum256 = _mm256_add_ps(lower256, upper256);

    const __m128 lower128 = _mm256_extractf32x4_ps(sum256, 0);
    const __m128 upper128 = _mm256_extractf32x4_ps(sum256, 1);
    return _mm_add_ps(lower128, upper128);
}

template <size_t NCols, bool HasZeroPoint>
MLAS_FORCEINLINE void
Expand All @@ -163,8 +192,138 @@ ComputeDotProducts_BlkBitWidth4_CompFp32(
const float* BiasPtr
)
{
    // Computes, for each of NCols columns of 4-bit block-quantized B data, the dot product of
    // one row of A (CountK floats starting at ARowPtr) with the dequantized column, adds the
    // optional bias (BiasPtr may be nullptr) and writes the NCols results to SumPtr.
    //
    // K is walked one quantization block (BlkLen elements) at a time. Each block has one scale
    // and, when HasZeroPoint is true, one 4-bit zero point (two zero points packed per byte).
    // Without zero points a fixed offset of 8 is subtracted instead.
    constexpr size_t BlkBitWidth = 4;

    const __m256i lowMask = _mm256_set1_epi8(0xF);

    // One 16-lane partial-sum accumulator per column; reduced horizontally at the end.
    __m512 acc_lo[NCols];
    UnrolledLoop<NCols>([&](size_t i) {
        acc_lo[i] = _mm512_setzero_ps();
    });

    const auto* b = QuantBDataColPtr;
    const float* s = QuantBScaleColPtr;

    [[maybe_unused]] size_t QuantBZeroPointIdx = 0;  // track half byte increments with this index instead of a pointer
                                                     // only used if HasZeroPoint == true

    for (size_t k = 0; k < CountK; k += BlkLen) {
        // Number of valid K elements in this block (the last block may be partial).
        const size_t ck = std::min(CountK - k, BlkLen);

        // Scale of the current block for every column.
        float scale_v[NCols];
        UnrolledLoop<NCols>([&](size_t i) {
            scale_v[i] = *(s + StrideQuantBScale * i);
        });

        // Read cursor into the current block's packed 4-bit data for every column.
        const __m128i* bptr[NCols];
        UnrolledLoop<NCols>([&](size_t i) {
            bptr[i] = reinterpret_cast<const __m128i*>(b + StrideQuantBData * i);
        });

        [[maybe_unused]] uint8_t offset[NCols];
        // not ready for "Manual conversion to float" in neon yet. following neon to unpack to uint8_t.
        if constexpr (HasZeroPoint) {
            UnrolledLoop<NCols>([&](size_t i) {
                const std::byte zp_packed =
                    QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2];
                // Odd block index -> high nibble, even block index -> low nibble.
                const std::byte zp = ((QuantBZeroPointIdx & 1) == 1)
                                         ? (zp_packed >> 4)
                                         : (zp_packed & std::byte{0x0F});
                offset[i] = std::to_integer<uint8_t>(zp);
            });
        }

        // TODO: block size shall be multiple of 16 but MLAS_QUANT4_BLK_UNIT is 32
        // following code computes 32 floats at once so let's not use a sub-block length of 16
        // code copied from MlasQ4GemmKernelAvx512f in q4gemm_avx512.cc which only works
        // with the NCols==4 case.
        // NOTE(review): each iteration unconditionally loads 16 bytes (32 nibbles) of B per column,
        // even when fewer than 32 elements remain; the surplus products are zeroed only because
        // the A loads are masked. Confirm the packed buffer tolerates the over-read.
        for (size_t kk = 0; kk < ck; kk += MLAS_QUANT4_BLK_UNIT) {
            const size_t kklen = std::min(static_cast<size_t>(MLAS_QUANT4_BLK_UNIT), ck - kk);

            // Load A row vectors; lanes past the valid length are zero-filled.
            uint32_t mask = 0xffffffff >> (MLAS_QUANT4_BLK_UNIT - kklen);
            const __m512 av_lo = _mm512_maskz_loadu_ps(__mmask16(mask), ARowPtr + k + kk);

            mask = mask >> 16;
            const __m512 av_hi = mask == 0 ? _mm512_setzero_ps()
                                           : _mm512_maskz_loadu_ps(__mmask16(mask), ARowPtr + k + kk + 16);

            // Load B col vectors
            __m128i bvi4[NCols];
            UnrolledLoop<NCols>([&](size_t i) {
                bvi4[i] = _mm_loadu_si128(bptr[i]++);
            });

            // expand 4b into byte array: low nibbles fill the lower 128 bits (paired with av_lo),
            // high nibbles fill the upper 128 bits (paired with av_hi)
            __m256i bytes[NCols];
            UnrolledLoop<NCols>([&](size_t i) {
                bytes[i] = _mm256_set_m128i(_mm_srli_epi16(bvi4[i], 4), bvi4[i]);
                bytes[i] = _mm256_and_si256(lowMask, bytes[i]);
            });

            if constexpr (HasZeroPoint) {
                // Subtract zero-point from the integers
                UnrolledLoop<NCols>([&](size_t i) {
                    bytes[i] = _mm256_sub_epi8(bytes[i], _mm256_set1_epi8(static_cast<char>(offset[i])));
                });
            } else {
                // Subtract 8 from the integers
                const __m256i eight = _mm256_set1_epi8(8);
                UnrolledLoop<NCols>([&](size_t i) {
                    bytes[i] = _mm256_sub_epi8(bytes[i], eight);
                });
            }

            // TODO: converting, scale and fma in one unroll loop
            // Convert to 16-bit int
            __m256i vx16_lo[NCols], vx16_hi[NCols];
            UnrolledLoop<NCols>([&](size_t i) {
                vx16_lo[i] =
                    _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes[i], 0));
                vx16_hi[i] =
                    _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes[i], 1));
            });

            __m512 bvf_lo[NCols], bvf_hi[NCols];
            UnrolledLoop<NCols>([&](size_t i) {
                // Convert to 32-bit int -> float 32
                bvf_lo[i] = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo[i]));
                bvf_hi[i] = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi[i]));

                // multiply by scale (named so it does not shadow the scale pointer `s`)
                const __m512 scale_ps = _mm512_set1_ps(scale_v[i]);
                bvf_lo[i] = _mm512_mul_ps(bvf_lo[i], scale_ps);
                bvf_hi[i] = _mm512_mul_ps(bvf_hi[i], scale_ps);

                // c[m,n] += a[m,k] * b[k,n]
                acc_lo[i] = _mm512_fmadd_ps(bvf_lo[i], av_lo, acc_lo[i]);
                acc_lo[i] = _mm512_fmadd_ps(bvf_hi[i], av_hi, acc_lo[i]);
            });
        }

        // Advance all per-block cursors to the next quantization block.
        b += MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);

        // Without this step every block would be dequantized with the first block's scale.
        // NOTE(review): assumes a column's scales are stored contiguously, one float per block,
        // mirroring the one-per-block step of the zero-point index below; confirm against the
        // B packing code.
        s += 1;

        if constexpr (HasZeroPoint) {
            QuantBZeroPointIdx += 1;
        }
    }

    if constexpr (NCols == 4) {
        // Reduce the four accumulators together and emit all four sums with a single store.
        __m128 acc_x = FoldAccumulators(acc_lo[0], acc_lo[1], acc_lo[2], acc_lo[3]);
        if (BiasPtr != nullptr) {
            acc_x = _mm_add_ps(acc_x, _mm_loadu_ps(BiasPtr));
        }
        _mm_storeu_ps(SumPtr, acc_x);
    } else {
        for (size_t i = 0; i < NCols; ++i) {
            SumPtr[i] = _mm512_reduce_add_ps(acc_lo[i]);
            SumPtr[i] += BiasPtr == nullptr ? 0.0f : BiasPtr[i];
        }
    }
}

template <bool HasZeroPoint>
Expand Down

0 comments on commit a605e6a

Please sign in to comment.