Skip to content

Commit

Permalink
Revert "add template argument HasZeroPoint"
Browse files Browse the repository at this point in the history
This reverts commit ccc5444.
  • Loading branch information
edgchen1 committed Jan 27, 2024
1 parent adc91ca commit 47bcd5f
Showing 1 changed file with 43 additions and 146 deletions.
189 changes: 43 additions & 146 deletions onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ LoadFloatData(const float* src, size_t count, float32x4_t (&dst)[Capacity / 4])
namespace
{

template <size_t NCols, bool HasZeroPoint>
template <size_t NCols>
MLAS_FORCEINLINE void
ComputeDotProducts_BlkBitWidth4_CompFp32(
size_t BlkLen,
Expand Down Expand Up @@ -247,8 +247,7 @@ ComputeDotProducts_BlkBitWidth4_CompFp32(

const std::byte* QuantBData = QuantBDataColPtr;
const float* QuantBScale = QuantBScaleColPtr;
[[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer
// only used if HasZeroPoint == true
size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer

for (size_t k = 0; k < CountK; k += BlkLen) {
const size_t k_blk_len = std::min(CountK - k, BlkLen);
Expand All @@ -258,9 +257,8 @@ ComputeDotProducts_BlkBitWidth4_CompFp32(
[&](size_t i) { scale[i] = QuantBScale[i * StrideQuantBScale]; }
);

[[maybe_unused]] float offset[NCols]; // Includes zero point and float conversion offset of 16.
// only used if HasZeroPoint == true
if constexpr (HasZeroPoint) {
float offset[NCols]; // Includes zero point and float conversion offset of 16.
if (QuantBZeroPointColPtr != nullptr) {
UnrolledLoop<NCols>([&](size_t i) {
const std::byte zp_packed =
QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2];
Expand All @@ -269,6 +267,11 @@ ComputeDotProducts_BlkBitWidth4_CompFp32(
: (zp_packed & std::byte{0x0F});
offset[i] = 16.0f + std::to_integer<uint8_t>(zp);
});
} else {
UnrolledLoop<NCols>([&](size_t i) {
constexpr float zp = 8.0f;
offset[i] = 16.0f + zp;
});
}

for (size_t k_idx_in_blk = 0; k_idx_in_blk < k_blk_len; k_idx_in_blk += SubBlkLen) {
Expand Down Expand Up @@ -322,17 +325,10 @@ ComputeDotProducts_BlkBitWidth4_CompFp32(
});

// subtract float conversion offset (16) and zero point
if constexpr (HasZeroPoint) {
UnrolledLoop<NCols>([&](size_t i) {
const float32x4_t offset_v = vdupq_n_f32(offset[i]);
UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); });
});
} else {
const float32x4_t offset_v = vdupq_n_f32(16.0f + 8.0f);
UnrolledLoop<NCols>([&](size_t i) {
UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); });
});
}
UnrolledLoop<NCols>([&](size_t i) {
const float32x4_t offset_v = vdupq_n_f32(offset[i]);
UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); });
});

// multiply by scale
UnrolledLoop<NCols>([&](size_t i) {
Expand All @@ -349,9 +345,7 @@ ComputeDotProducts_BlkBitWidth4_CompFp32(
// increment pointers to next block
QuantBData += MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
QuantBScale += 1;
if constexpr (HasZeroPoint) {
QuantBZeroPointIdx += 1;
}
QuantBZeroPointIdx += 1;
}

if constexpr (NCols == 4) {
Expand All @@ -372,9 +366,8 @@ ComputeDotProducts_BlkBitWidth4_CompFp32(
}
}

template <bool HasZeroPoint>
MLAS_FORCEINLINE void
SQ4BitGemmM1Kernel_CompFp32_Impl(
SQ4BitGemmM1Kernel_CompFp32(
size_t BlkLen,
const float* A,
const std::byte* QuantBData,
Expand Down Expand Up @@ -410,7 +403,7 @@ SQ4BitGemmM1Kernel_CompFp32_Impl(
int64_t nblk = static_cast<int64_t>(CountN) - NCols;

while (nblk >= 0) {
ComputeDotProducts_BlkBitWidth4_CompFp32<NCols, HasZeroPoint>(
ComputeDotProducts_BlkBitWidth4_CompFp32<NCols>(
BlkLen,
ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK,
StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint,
Expand All @@ -421,7 +414,7 @@ SQ4BitGemmM1Kernel_CompFp32_Impl(

QuantBDataColPtr += NCols * StrideQuantBData;
QuantBScaleColPtr += NCols * StrideQuantBScale;
if constexpr (HasZeroPoint) {
if (QuantBZeroPointColPtr != nullptr) {
QuantBZeroPointColPtr += NCols * StrideQuantBZeroPoint;
}

Expand All @@ -434,7 +427,7 @@ SQ4BitGemmM1Kernel_CompFp32_Impl(
// left over columns less than `NCols`?
nblk += NCols;
for (int64_t n = 0; n < nblk; ++n) {
ComputeDotProducts_BlkBitWidth4_CompFp32<1, HasZeroPoint>(
ComputeDotProducts_BlkBitWidth4_CompFp32<1>(
BlkLen,
ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK,
StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint,
Expand All @@ -445,7 +438,7 @@ SQ4BitGemmM1Kernel_CompFp32_Impl(

QuantBDataColPtr += StrideQuantBData;
QuantBScaleColPtr += StrideQuantBScale;
if constexpr (HasZeroPoint) {
if (QuantBZeroPointColPtr != nullptr) {
QuantBZeroPointColPtr += StrideQuantBZeroPoint;
}

Expand All @@ -454,49 +447,6 @@ SQ4BitGemmM1Kernel_CompFp32_Impl(
}
}

MLAS_FORCEINLINE void
SQ4BitGemmM1Kernel_CompFp32(
size_t BlkLen,
const float* A,
const std::byte* QuantBData,
const float* QuantBScale,
const std::byte* QuantBZeroPoint,
float* C,
size_t CountN,
size_t CountK,
size_t BlockStrideQuantB,
const float* Bias
)
{
if (QuantBZeroPoint != nullptr) {
SQ4BitGemmM1Kernel_CompFp32_Impl<true>(
BlkLen,
A,
QuantBData,
QuantBScale,
QuantBZeroPoint,
C,
CountN,
CountK,
BlockStrideQuantB,
Bias
);
} else {
SQ4BitGemmM1Kernel_CompFp32_Impl<false>(
BlkLen,
A,
QuantBData,
QuantBScale,
QuantBZeroPoint,
C,
CountN,
CountK,
BlockStrideQuantB,
Bias
);
}
}

MLAS_FORCEINLINE void
Q4BitBlkDequantBForSgemm_CompFp32(
size_t BlkLen,
Expand Down Expand Up @@ -687,7 +637,7 @@ QuantizeARow_CompInt8(
}
}

template <size_t NCols, size_t SubBlkLen, bool HasZeroPoint>
template <size_t NCols, size_t SubBlkLen>
MLAS_FORCEINLINE void
ComputeDotProducts_BlkBitWidth4_CompInt8(
size_t BlkLen,
Expand All @@ -710,15 +660,14 @@ ComputeDotProducts_BlkBitWidth4_CompInt8(

assert(BlkLen >= SubBlkLen && BlkLen % SubBlkLen == 0);

[[maybe_unused]] const uint8x8_t LowMaskU8x8 = vdup_n_u8(0x0F); // only used if SubBlkLen == 16
[[maybe_unused]] const uint8x16_t LowMaskU8x16 = vdupq_n_u8(0x0F); // only used if SubBlkLen == 32
const uint8x8_t LowMaskU8x8 = vdup_n_u8(0x0F);
const uint8x16_t LowMaskU8x16 = vdupq_n_u8(0x0F);

const std::byte* QuantA = QuantARowPtr;

const std::byte* QuantBData = QuantBDataColPtr;
const float* QuantBScale = QuantBScaleColPtr;
[[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer
// only used if HasZeroPoint == true
size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer

float32x4_t acc[NCols]{};

Expand All @@ -731,15 +680,19 @@ ComputeDotProducts_BlkBitWidth4_CompInt8(
float b_scale[NCols];
UnrolledLoop<NCols>([&](size_t i) { b_scale[i] = QuantBScale[i * StrideQuantBScale]; });

[[maybe_unused]] int8_t b_zp[NCols]; // only used if HasZeroPoint == true
if constexpr (HasZeroPoint) {
int8_t b_zp[NCols];
if (QuantBZeroPointColPtr != nullptr) {
UnrolledLoop<NCols>([&](size_t i) {
const std::byte zp_packed =
QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2];
b_zp[i] = ((QuantBZeroPointIdx & 1) == 1)
? std::to_integer<int8_t>(zp_packed >> 4)
: std::to_integer<int8_t>(zp_packed & std::byte{0x0F});
});
} else {
UnrolledLoop<NCols>([&](size_t i) {
b_zp[i] = 8;
});
}

for (size_t k_idx_in_blk = 0; k_idx_in_blk < k_blk_len; k_idx_in_blk += SubBlkLen) {
Expand Down Expand Up @@ -784,22 +737,12 @@ ComputeDotProducts_BlkBitWidth4_CompInt8(
}

// subtract B zero point
if constexpr (HasZeroPoint) {
UnrolledLoop<NCols>([&](size_t i) {
const int8x16_t zp_v = vdupq_n_s8(b_zp[i]);
UnrolledLoop<SubBlkLen / 16>([&](size_t j) {
bv[i][j] = vsubq_s8(bv[i][j], zp_v);
});
});
} else {
const int8x16_t zp_v = vdupq_n_s8(8);

UnrolledLoop<NCols>([&](size_t i) {
UnrolledLoop<SubBlkLen / 16>([&](size_t j) {
bv[i][j] = vsubq_s8(bv[i][j], zp_v);
});
UnrolledLoop<NCols>([&](size_t i) {
const int8x16_t zp_v = vdupq_n_s8(b_zp[i]);
UnrolledLoop<SubBlkLen / 16>([&](size_t j) {
bv[i][j] = vsubq_s8(bv[i][j], zp_v);
});
}
});

// compute quantized dot product
int32x4_t dot[NCols]{};
Expand All @@ -826,9 +769,7 @@ ComputeDotProducts_BlkBitWidth4_CompInt8(
QuantA += Q8BlkSize(BlkLen);
QuantBData += MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
QuantBScale += 1;
if constexpr (HasZeroPoint) {
QuantBZeroPointIdx += 1;
}
QuantBZeroPointIdx += 1;
}

if constexpr (NCols == 4) {
Expand All @@ -849,7 +790,7 @@ ComputeDotProducts_BlkBitWidth4_CompInt8(
}
}

template <size_t NCols, size_t SubBlkLen, bool HasZeroPoint>
template <size_t NCols, size_t SubBlkLen>
MLAS_FORCEINLINE void
SQ4BitGemmM1Kernel_CompInt8_Impl(
size_t BlkLen,
Expand Down Expand Up @@ -886,7 +827,7 @@ SQ4BitGemmM1Kernel_CompInt8_Impl(
int64_t nblk = static_cast<int64_t>(CountN) - NCols;

while (nblk >= 0) {
ComputeDotProducts_BlkBitWidth4_CompInt8<NCols, SubBlkLen, HasZeroPoint>(
ComputeDotProducts_BlkBitWidth4_CompInt8<NCols, SubBlkLen>(
BlkLen,
QuantARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK,
StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint,
Expand All @@ -897,7 +838,7 @@ SQ4BitGemmM1Kernel_CompInt8_Impl(

QuantBDataColPtr += NCols * StrideQuantBData;
QuantBScaleColPtr += NCols * StrideQuantBScale;
if constexpr (HasZeroPoint) {
if (QuantBZeroPointColPtr != nullptr) {
QuantBZeroPointColPtr += NCols * StrideQuantBZeroPoint;
}

Expand All @@ -910,7 +851,7 @@ SQ4BitGemmM1Kernel_CompInt8_Impl(
// left over columns less than `NCols`?
nblk += NCols;
for (int64_t n = 0; n < nblk; ++n) {
ComputeDotProducts_BlkBitWidth4_CompInt8<1, SubBlkLen, HasZeroPoint>(
ComputeDotProducts_BlkBitWidth4_CompInt8<1, SubBlkLen>(
BlkLen,
QuantARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK,
StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint,
Expand All @@ -921,7 +862,7 @@ SQ4BitGemmM1Kernel_CompInt8_Impl(

QuantBDataColPtr += StrideQuantBData;
QuantBScaleColPtr += StrideQuantBScale;
if constexpr (HasZeroPoint) {
if (QuantBZeroPointColPtr != nullptr) {
QuantBZeroPointColPtr += StrideQuantBZeroPoint;
}

Expand All @@ -930,50 +871,6 @@ SQ4BitGemmM1Kernel_CompInt8_Impl(
}
}

template <bool HasZeroPoint>
MLAS_FORCEINLINE void
SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen(
size_t BlkLen,
const std::byte* QuantA,
const std::byte* QuantBData,
const float* QuantBScale,
const std::byte* QuantBZeroPoint,
float* C,
size_t CountN,
size_t CountK,
size_t BlockStrideQuantB,
const float* Bias
)
{
if (BlkLen == 16) {
SQ4BitGemmM1Kernel_CompInt8_Impl<4, 16, HasZeroPoint>(
BlkLen,
QuantA,
QuantBData,
QuantBScale,
QuantBZeroPoint,
C,
CountN,
CountK,
BlockStrideQuantB,
Bias
);
} else {
SQ4BitGemmM1Kernel_CompInt8_Impl<4, 32, HasZeroPoint>(
BlkLen,
QuantA,
QuantBData,
QuantBScale,
QuantBZeroPoint,
C,
CountN,
CountK,
BlockStrideQuantB,
Bias
);
}
}

MLAS_FORCEINLINE
void
SQ4BitGemmM1Kernel_CompInt8(
Expand All @@ -989,8 +886,8 @@ SQ4BitGemmM1Kernel_CompInt8(
const float* Bias
)
{
if (QuantBZeroPoint != nullptr) {
SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen<true>(
if (BlkLen == 16) {
SQ4BitGemmM1Kernel_CompInt8_Impl<4, 16>(
BlkLen,
QuantA,
QuantBData,
Expand All @@ -1003,7 +900,7 @@ SQ4BitGemmM1Kernel_CompInt8(
Bias
);
} else {
SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen<false>(
SQ4BitGemmM1Kernel_CompInt8_Impl<4, 32>(
BlkLen,
QuantA,
QuantBData,
Expand Down

0 comments on commit 47bcd5f

Please sign in to comment.