# mlas nbit matmul requires packed_b (#20482)
### Description

The MLAS MatMulNBits implementation requires a packed B matrix, so the MLAS code path is now gated on `packed_b_` being set.

This logic needs to be updated if that requirement changes.
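
For reference, the resulting dispatch has roughly this shape (a minimal sketch, not the actual kernel code; `RunMlasSQNBitGemm` and `RunDequantizedGemm` are hypothetical stand-ins for the two existing paths):

```cpp
// Sketch of the guarded dispatch in MatMulNBits::Compute. Assumes the op's
// member variables nbits_, block_size_, accuracy_level_, and packed_b_.
const auto compute_type = static_cast<MLAS_SQNBIT_GEMM_COMPUTE_TYPE>(accuracy_level_);
if (MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type) && packed_b_) {
  // Fast path: the MLAS n-bit GEMM consumes B prepacked during PrePack(),
  // so it must only be taken when packed_b_ is non-null.
  RunMlasSQNBitGemm();    // hypothetical helper for the MLAS path
} else {
  // Fallback: dequantize B and run an ordinary float GEMM.
  RunDequantizedGemm();   // hypothetical helper for the non-MLAS path
}
```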


### Motivation and Context

---------

Signed-off-by: Liqun Fu <[email protected]>
liqunfu authored Apr 27, 2024
1 parent 619ceee commit 2f5fe45
Showing 1 changed file with 3 additions and 11 deletions.
onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc (3 additions, 11 deletions)
```diff
@@ -322,7 +322,8 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
   if (has_single_b_matrix) {
     const auto compute_type = static_cast<MLAS_SQNBIT_GEMM_COMPUTE_TYPE>(accuracy_level_);
 
-    if (MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type)) {
+    // mlas nbits implementation requires packed b. update this logic if it changes.
+    if (MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type) && packed_b_) {
       IAllocatorUniquePtr<std::byte> workspace{};
       if (const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, batch_count,
                                                                          nbits_, block_size_, compute_type);
@@ -332,20 +333,11 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
         workspace = IAllocator::MakeUniquePtr<std::byte>(allocator, workspace_size);
       }
 
-      const void* b_data = [&]() -> const void* {
-        if (packed_b_) {
-          return packed_b_.get();
-        }
-
-        const Tensor* b = ctx->Input<Tensor>(1);
-        return b->DataRaw();
-      }();
-
       InlinedVector<MLAS_SQNBIT_GEMM_DATA_PARAMS> data(batch_count);
       for (size_t i = 0; i < batch_count; ++i) {
         data[i].A = a_data + helper.LeftOffsets()[i];
         data[i].lda = lda;
-        data[i].QuantBData = b_data;
+        data[i].QuantBData = packed_b_.get();
         data[i].QuantBScale = scales_data;
         data[i].QuantBZeroPoint = zero_points_data;
         data[i].C = y_data + helper.OutputOffsets()[i];
```
