Skip to content

Commit

Permalink
comments
Browse files Browse the repository at this point in the history
Signed-off-by: Liqun Fu <[email protected]>
  • Loading branch information
liqunfu committed Apr 25, 2024
1 parent 3b2a4e9 commit 8a69e1c
Showing 1 changed file with 3 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ SQ4BitGemmM1Kernel_CompInt8_avx2(
const float* Bias
);

#include <iostream>
template <size_t NCols, bool HasZeroPoint>
MLAS_FORCEINLINE void
ComputeDotProducts_BlkBitWidth4_CompInt8_SubBlkLen16(
Expand Down Expand Up @@ -575,6 +574,7 @@ ComputeDotProducts_BlkBitWidth4_CompInt8_SubBlkLen64_NCols4(
_mm_storeu_ps(SumPtr, acc_x);
}

// TODO: can this be inlined if DotQuadFunctionType is a function pointer?
template <bool HasZeroPoint, DotQuadFunctionType dot_quad>
MLAS_FORCEINLINE void
ComputeDotProducts_BlkBitWidth4_CompInt8_SubBlkLen64_NCols1(
Expand Down Expand Up @@ -622,6 +622,8 @@ ComputeDotProducts_BlkBitWidth4_CompInt8_SubBlkLen64_NCols1(
uint8_t zp0;
if constexpr (HasZeroPoint) {
// TODO: this block accounts for nearly 30% of the computation.
// The solution proposed by @yufenglee is to factor out scaleB * zp
// while packing A. This will be done in the next PR.
bool is_lower = (QuantBZeroPointIdx & 1) == 0;
std::byte zp_packed = QuantBZeroPointColPtr[0 * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2];
zp0 = std::to_integer<int8_t>(is_lower ? (zp_packed & std::byte{0x0F}) : (zp_packed >> 4));
Expand Down

0 comments on commit 8a69e1c

Please sign in to comment.