Skip to content

Commit

Permalink
add post processing
Browse files (browse the repository at this point in the history)
  • Loading branch information
fajin-corp committed Oct 30, 2024
1 parent d00c95d commit 98b1e5f
Showing 1 changed file with 18 additions and 0 deletions.
18 changes: 18 additions & 0 deletions onnxruntime/core/mlas/lib/sqnbitgemm.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -644,6 +644,12 @@ SQ4BitGemm_CompFp16(
a, dequant_b, Bias, c, K, lda, ldb, ldc
);

if (DataParams->PostProcessor != nullptr) {
DataParams->PostProcessor->Process(
DataParams->C, RangeStartM + m, RangeStartN + n, StrideM, StrideN, ldc
);
}

a += StrideM * lda;
c += StrideM * ldc;
}
Expand All @@ -652,6 +658,12 @@ SQ4BitGemm_CompFp16(
GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompFp16_Remainder(
a, dequant_b, Bias, c, 1, StrideN, K, lda, ldb, ldc
);

if (DataParams->PostProcessor != nullptr) {
DataParams->PostProcessor->Process(
DataParams->C, RangeStartM + m, RangeStartN + n, 1, StrideN, ldc
);
}
}

QuantBData += StrideN * qldb;
Expand All @@ -674,6 +686,12 @@ SQ4BitGemm_CompFp16(
a, dequant_b, Bias, c, countM, RangeCountN - n, K, lda, ldb, ldc
);

if (DataParams->PostProcessor != nullptr) {
DataParams->PostProcessor->Process(
DataParams->C, RangeStartM + m, RangeStartN + n, countM, RangeCountN - n, ldc
);
}

a += countM * lda;
c += countM * ldc;
}
Expand Down

0 comments on commit 98b1e5f

Please sign in to comment.