Commit

Reapply "add template argument HasZeroPoint"
This reverts commit 47bcd5f.
edgchen1 committed Jan 29, 2024
1 parent 47bcd5f commit a4b197c
Showing 1 changed file with 146 additions and 43 deletions.
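
The change replaces runtime QuantBZeroPoint != nullptr checks inside the block loops with a compile-time HasZeroPoint template parameter: each kernel entry point branches once on the pointer and dispatches to a <true> or <false> specialization, so the per-block zero-point work can be guarded by if constexpr and compiled out entirely when no zero points are stored. A minimal sketch of the dispatch pattern, under illustrative names (ProcessBlocks and Kernel are not from this file):

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Hot loop templated on whether per-block zero points exist. The runtime
// nullptr check happens once, outside; inside, if constexpr discards the
// unused branch at compile time.
template <bool HasZeroPoint>
void ProcessBlocks([[maybe_unused]] const uint8_t* zero_points, size_t block_count)
{
    // Only advanced when HasZeroPoint == true, mirroring QuantBZeroPointIdx.
    [[maybe_unused]] size_t zp_idx = 0;

    for (size_t blk = 0; blk < block_count; ++blk) {
        int32_t zp;
        if constexpr (HasZeroPoint) {
            zp = zero_points[zp_idx];  // load the stored per-block zero point
            ++zp_idx;
        } else {
            zp = 8;  // fixed default: no load and no branch in the loop body
        }
        std::printf("block %zu: zero point %d\n", blk, zp);
    }
}

// Single runtime branch selecting the specialization, like
// SQ4BitGemmM1Kernel_CompFp32 dispatching to SQ4BitGemmM1Kernel_CompFp32_Impl below.
void Kernel(const uint8_t* zero_points, size_t block_count)
{
    if (zero_points != nullptr) {
        ProcessBlocks<true>(zero_points, block_count);
    } else {
        ProcessBlocks<false>(nullptr, block_count);
    }
}

The payoff is that the false specialization contains no zero-point loads at all; the diff below applies the same split to both the fp32 and int8 compute paths of the NEON kernel.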
189 changes: 146 additions & 43 deletions onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp
@@ -205,7 +205,7 @@ LoadFloatData(const float* src, size_t count, float32x4_t (&dst)[Capacity / 4])
 namespace
 {
 
-template <size_t NCols>
+template <size_t NCols, bool HasZeroPoint>
 MLAS_FORCEINLINE void
 ComputeDotProducts_BlkBitWidth4_CompFp32(
     size_t BlkLen,
@@ -247,7 +247,8 @@ ComputeDotProducts_BlkBitWidth4_CompFp32(
 
     const std::byte* QuantBData = QuantBDataColPtr;
     const float* QuantBScale = QuantBScaleColPtr;
-    size_t QuantBZeroPointIdx = 0;  // track half byte increments with this index instead of a pointer
+    [[maybe_unused]] size_t QuantBZeroPointIdx = 0;  // track half byte increments with this index instead of a pointer
+                                                     // only used if HasZeroPoint == true
 
     for (size_t k = 0; k < CountK; k += BlkLen) {
         const size_t k_blk_len = std::min(CountK - k, BlkLen);
@@ -257,8 +258,9 @@ ComputeDotProducts_BlkBitWidth4_CompFp32(
             [&](size_t i) { scale[i] = QuantBScale[i * StrideQuantBScale]; }
         );
 
-        float offset[NCols];  // Includes zero point and float conversion offset of 16.
-        if (QuantBZeroPointColPtr != nullptr) {
+        [[maybe_unused]] float offset[NCols];  // Includes zero point and float conversion offset of 16.
+                                               // only used if HasZeroPoint == true
+        if constexpr (HasZeroPoint) {
             UnrolledLoop<NCols>([&](size_t i) {
                 const std::byte zp_packed =
                     QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2];
@@ -267,11 +269,6 @@ ComputeDotProducts_BlkBitWidth4_CompFp32(
                                          : (zp_packed & std::byte{0x0F});
                 offset[i] = 16.0f + std::to_integer<uint8_t>(zp);
             });
-        } else {
-            UnrolledLoop<NCols>([&](size_t i) {
-                constexpr float zp = 8.0f;
-                offset[i] = 16.0f + zp;
-            });
         }
 
         for (size_t k_idx_in_blk = 0; k_idx_in_blk < k_blk_len; k_idx_in_blk += SubBlkLen) {
@@ -325,10 +322,17 @@ ComputeDotProducts_BlkBitWidth4_CompFp32(
             });
 
             // subtract float conversion offset (16) and zero point
-            UnrolledLoop<NCols>([&](size_t i) {
-                const float32x4_t offset_v = vdupq_n_f32(offset[i]);
-                UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); });
-            });
+            if constexpr (HasZeroPoint) {
+                UnrolledLoop<NCols>([&](size_t i) {
+                    const float32x4_t offset_v = vdupq_n_f32(offset[i]);
+                    UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); });
+                });
+            } else {
+                const float32x4_t offset_v = vdupq_n_f32(16.0f + 8.0f);
+                UnrolledLoop<NCols>([&](size_t i) {
+                    UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); });
+                });
+            }
 
             // multiply by scale
             UnrolledLoop<NCols>([&](size_t i) {
@@ -345,7 +349,9 @@ ComputeDotProducts_BlkBitWidth4_CompFp32(
         // increment pointers to next block
        QuantBData += MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
         QuantBScale += 1;
-        QuantBZeroPointIdx += 1;
+        if constexpr (HasZeroPoint) {
+            QuantBZeroPointIdx += 1;
+        }
     }
 
     if constexpr (NCols == 4) {
@@ -366,8 +372,9 @@ ComputeDotProducts_BlkBitWidth4_CompFp32(
     }
 }
 
+template <bool HasZeroPoint>
 MLAS_FORCEINLINE void
-SQ4BitGemmM1Kernel_CompFp32(
+SQ4BitGemmM1Kernel_CompFp32_Impl(
     size_t BlkLen,
     const float* A,
     const std::byte* QuantBData,
@@ -403,7 +410,7 @@ SQ4BitGemmM1Kernel_CompFp32(
     int64_t nblk = static_cast<int64_t>(CountN) - NCols;
 
     while (nblk >= 0) {
-        ComputeDotProducts_BlkBitWidth4_CompFp32<NCols>(
+        ComputeDotProducts_BlkBitWidth4_CompFp32<NCols, HasZeroPoint>(
             BlkLen,
             ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK,
             StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint,
@@ -414,7 +421,7 @@ SQ4BitGemmM1Kernel_CompFp32(
 
         QuantBDataColPtr += NCols * StrideQuantBData;
         QuantBScaleColPtr += NCols * StrideQuantBScale;
-        if (QuantBZeroPointColPtr != nullptr) {
+        if constexpr (HasZeroPoint) {
             QuantBZeroPointColPtr += NCols * StrideQuantBZeroPoint;
         }
 
@@ -427,7 +434,7 @@ SQ4BitGemmM1Kernel_CompFp32(
     // left over columns less than `NCols`?
     nblk += NCols;
     for (int64_t n = 0; n < nblk; ++n) {
-        ComputeDotProducts_BlkBitWidth4_CompFp32<1>(
+        ComputeDotProducts_BlkBitWidth4_CompFp32<1, HasZeroPoint>(
             BlkLen,
             ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK,
             StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint,
@@ -438,7 +445,7 @@ SQ4BitGemmM1Kernel_CompFp32(
 
         QuantBDataColPtr += StrideQuantBData;
         QuantBScaleColPtr += StrideQuantBScale;
-        if (QuantBZeroPointColPtr != nullptr) {
+        if constexpr (HasZeroPoint) {
             QuantBZeroPointColPtr += StrideQuantBZeroPoint;
         }
 
@@ -447,6 +454,49 @@ SQ4BitGemmM1Kernel_CompFp32(
     }
 }
 
+MLAS_FORCEINLINE void
+SQ4BitGemmM1Kernel_CompFp32(
+    size_t BlkLen,
+    const float* A,
+    const std::byte* QuantBData,
+    const float* QuantBScale,
+    const std::byte* QuantBZeroPoint,
+    float* C,
+    size_t CountN,
+    size_t CountK,
+    size_t BlockStrideQuantB,
+    const float* Bias
+)
+{
+    if (QuantBZeroPoint != nullptr) {
+        SQ4BitGemmM1Kernel_CompFp32_Impl<true>(
+            BlkLen,
+            A,
+            QuantBData,
+            QuantBScale,
+            QuantBZeroPoint,
+            C,
+            CountN,
+            CountK,
+            BlockStrideQuantB,
+            Bias
+        );
+    } else {
+        SQ4BitGemmM1Kernel_CompFp32_Impl<false>(
+            BlkLen,
+            A,
+            QuantBData,
+            QuantBScale,
+            QuantBZeroPoint,
+            C,
+            CountN,
+            CountK,
+            BlockStrideQuantB,
+            Bias
+        );
+    }
+}
+
 MLAS_FORCEINLINE void
 Q4BitBlkDequantBForSgemm_CompFp32(
     size_t BlkLen,
@@ -637,7 +687,7 @@ QuantizeARow_CompInt8(
     }
 }
 
-template <size_t NCols, size_t SubBlkLen>
+template <size_t NCols, size_t SubBlkLen, bool HasZeroPoint>
 MLAS_FORCEINLINE void
 ComputeDotProducts_BlkBitWidth4_CompInt8(
     size_t BlkLen,
@@ -660,14 +710,15 @@ ComputeDotProducts_BlkBitWidth4_CompInt8(
 
     assert(BlkLen >= SubBlkLen && BlkLen % SubBlkLen == 0);
 
-    const uint8x8_t LowMaskU8x8 = vdup_n_u8(0x0F);
-    const uint8x16_t LowMaskU8x16 = vdupq_n_u8(0x0F);
+    [[maybe_unused]] const uint8x8_t LowMaskU8x8 = vdup_n_u8(0x0F);     // only used if SubBlkLen == 16
+    [[maybe_unused]] const uint8x16_t LowMaskU8x16 = vdupq_n_u8(0x0F);  // only used if SubBlkLen == 32
 
     const std::byte* QuantA = QuantARowPtr;
 
     const std::byte* QuantBData = QuantBDataColPtr;
     const float* QuantBScale = QuantBScaleColPtr;
-    size_t QuantBZeroPointIdx = 0;  // track half byte increments with this index instead of a pointer
+    [[maybe_unused]] size_t QuantBZeroPointIdx = 0;  // track half byte increments with this index instead of a pointer
+                                                     // only used if HasZeroPoint == true
 
     float32x4_t acc[NCols]{};
 
@@ -680,19 +731,15 @@ ComputeDotProducts_BlkBitWidth4_CompInt8(
         float b_scale[NCols];
         UnrolledLoop<NCols>([&](size_t i) { b_scale[i] = QuantBScale[i * StrideQuantBScale]; });
 
-        int8_t b_zp[NCols];
-        if (QuantBZeroPointColPtr != nullptr) {
+        [[maybe_unused]] int8_t b_zp[NCols];  // only used if HasZeroPoint == true
+        if constexpr (HasZeroPoint) {
             UnrolledLoop<NCols>([&](size_t i) {
                 const std::byte zp_packed =
                     QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2];
                 b_zp[i] = ((QuantBZeroPointIdx & 1) == 1)
                               ? std::to_integer<int8_t>(zp_packed >> 4)
                              : std::to_integer<int8_t>(zp_packed & std::byte{0x0F});
             });
-        } else {
-            UnrolledLoop<NCols>([&](size_t i) {
-                b_zp[i] = 8;
-            });
         }
 
         for (size_t k_idx_in_blk = 0; k_idx_in_blk < k_blk_len; k_idx_in_blk += SubBlkLen) {
@@ -737,12 +784,22 @@ ComputeDotProducts_BlkBitWidth4_CompInt8(
             }
 
             // subtract B zero point
-            UnrolledLoop<NCols>([&](size_t i) {
-                const int8x16_t zp_v = vdupq_n_s8(b_zp[i]);
-                UnrolledLoop<SubBlkLen / 16>([&](size_t j) {
-                    bv[i][j] = vsubq_s8(bv[i][j], zp_v);
+            if constexpr (HasZeroPoint) {
+                UnrolledLoop<NCols>([&](size_t i) {
+                    const int8x16_t zp_v = vdupq_n_s8(b_zp[i]);
+                    UnrolledLoop<SubBlkLen / 16>([&](size_t j) {
+                        bv[i][j] = vsubq_s8(bv[i][j], zp_v);
+                    });
                 });
-            });
+            } else {
+                const int8x16_t zp_v = vdupq_n_s8(8);
+
+                UnrolledLoop<NCols>([&](size_t i) {
+                    UnrolledLoop<SubBlkLen / 16>([&](size_t j) {
+                        bv[i][j] = vsubq_s8(bv[i][j], zp_v);
+                    });
+                });
+            }
 
             // compute quantized dot product
             int32x4_t dot[NCols]{};
@@ -769,7 +826,9 @@ ComputeDotProducts_BlkBitWidth4_CompInt8(
         QuantA += Q8BlkSize(BlkLen);
         QuantBData += MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
         QuantBScale += 1;
-        QuantBZeroPointIdx += 1;
+        if constexpr (HasZeroPoint) {
+            QuantBZeroPointIdx += 1;
+        }
     }
 
     if constexpr (NCols == 4) {
@@ -790,7 +849,7 @@ ComputeDotProducts_BlkBitWidth4_CompInt8(
     }
 }
 
-template <size_t NCols, size_t SubBlkLen>
+template <size_t NCols, size_t SubBlkLen, bool HasZeroPoint>
 MLAS_FORCEINLINE void
 SQ4BitGemmM1Kernel_CompInt8_Impl(
     size_t BlkLen,
@@ -827,7 +886,7 @@ SQ4BitGemmM1Kernel_CompInt8_Impl(
     int64_t nblk = static_cast<int64_t>(CountN) - NCols;
 
     while (nblk >= 0) {
-        ComputeDotProducts_BlkBitWidth4_CompInt8<NCols, SubBlkLen>(
+        ComputeDotProducts_BlkBitWidth4_CompInt8<NCols, SubBlkLen, HasZeroPoint>(
            BlkLen,
             QuantARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK,
             StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint,
@@ -838,7 +897,7 @@ SQ4BitGemmM1Kernel_CompInt8_Impl(
 
         QuantBDataColPtr += NCols * StrideQuantBData;
         QuantBScaleColPtr += NCols * StrideQuantBScale;
-        if (QuantBZeroPointColPtr != nullptr) {
+        if constexpr (HasZeroPoint) {
             QuantBZeroPointColPtr += NCols * StrideQuantBZeroPoint;
         }
 
@@ -851,7 +910,7 @@ SQ4BitGemmM1Kernel_CompInt8_Impl(
     // left over columns less than `NCols`?
     nblk += NCols;
     for (int64_t n = 0; n < nblk; ++n) {
-        ComputeDotProducts_BlkBitWidth4_CompInt8<1, SubBlkLen>(
+        ComputeDotProducts_BlkBitWidth4_CompInt8<1, SubBlkLen, HasZeroPoint>(
             BlkLen,
             QuantARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK,
             StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint,
@@ -862,7 +921,7 @@ SQ4BitGemmM1Kernel_CompInt8_Impl(
 
         QuantBDataColPtr += StrideQuantBData;
         QuantBScaleColPtr += StrideQuantBScale;
-        if (QuantBZeroPointColPtr != nullptr) {
+        if constexpr (HasZeroPoint) {
             QuantBZeroPointColPtr += StrideQuantBZeroPoint;
         }
 
@@ -871,6 +930,50 @@ SQ4BitGemmM1Kernel_CompInt8_Impl(
     }
 }
 
+template <bool HasZeroPoint>
+MLAS_FORCEINLINE void
+SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen(
+    size_t BlkLen,
+    const std::byte* QuantA,
+    const std::byte* QuantBData,
+    const float* QuantBScale,
+    const std::byte* QuantBZeroPoint,
+    float* C,
+    size_t CountN,
+    size_t CountK,
+    size_t BlockStrideQuantB,
+    const float* Bias
+)
+{
+    if (BlkLen == 16) {
+        SQ4BitGemmM1Kernel_CompInt8_Impl<4, 16, HasZeroPoint>(
+            BlkLen,
+            QuantA,
+            QuantBData,
+            QuantBScale,
+            QuantBZeroPoint,
+            C,
+            CountN,
+            CountK,
+            BlockStrideQuantB,
+            Bias
+        );
+    } else {
+        SQ4BitGemmM1Kernel_CompInt8_Impl<4, 32, HasZeroPoint>(
+            BlkLen,
+            QuantA,
+            QuantBData,
+            QuantBScale,
+            QuantBZeroPoint,
+            C,
+            CountN,
+            CountK,
+            BlockStrideQuantB,
+            Bias
+        );
+    }
+}
+
 MLAS_FORCEINLINE
 void
 SQ4BitGemmM1Kernel_CompInt8(
@@ -886,8 +989,8 @@ SQ4BitGemmM1Kernel_CompInt8(
     const float* Bias
 )
 {
-    if (BlkLen == 16) {
-        SQ4BitGemmM1Kernel_CompInt8_Impl<4, 16>(
+    if (QuantBZeroPoint != nullptr) {
+        SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen<true>(
             BlkLen,
             QuantA,
             QuantBData,
@@ -900,7 +1003,7 @@ SQ4BitGemmM1Kernel_CompInt8(
             Bias
         );
     } else {
-        SQ4BitGemmM1Kernel_CompInt8_Impl<4, 32>(
+        SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen<false>(
             BlkLen,
             QuantA,
             QuantBData,
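
Both kernels step through zero points in half-byte increments: QuantBZeroPointIdx / 2 selects the byte and the low bit of the index selects the nibble, since 4-bit zero points are packed two per byte. A standalone sketch of that unpacking, using uint8_t in place of the file's std::byte (UnpackZeroPoint is an illustrative name, not from this file):

#include <cassert>
#include <cstddef>
#include <cstdint>

// Two 4-bit zero points are packed per byte: even indices read the low
// nibble, odd indices the high nibble, matching the (idx & 1) test above.
int8_t UnpackZeroPoint(const uint8_t* packed, size_t idx)
{
    const uint8_t zp_packed = packed[idx / 2];
    return (idx & 1) ? static_cast<int8_t>(zp_packed >> 4)
                     : static_cast<int8_t>(zp_packed & 0x0F);
}

int main()
{
    const uint8_t packed[] = {0x28, 0x3F};     // holds zero points 8, 2, 15, 3
    assert(UnpackZeroPoint(packed, 0) == 8);   // low nibble of byte 0
    assert(UnpackZeroPoint(packed, 1) == 2);   // high nibble of byte 0
    assert(UnpackZeroPoint(packed, 2) == 15);  // low nibble of byte 1
    assert(UnpackZeroPoint(packed, 3) == 3);   // high nibble of byte 1
    return 0;
}

When no zero points are stored, both kernels instead use the fixed value 8 from the else branches above: 16.0f + 8.0f in the fp32 path and vdupq_n_s8(8) in the int8 path.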
