From 2f5fe4500d5b252bb64fce29cb7c1f92326c1074 Mon Sep 17 00:00:00 2001
From: liqun Fu
Date: Sat, 27 Apr 2024 00:18:53 +0000
Subject: [PATCH] mlas nbit matmul requires packed_b (#20482)

### Description
The MLAS n-bit matmul implementation requires a prepacked B matrix, so gate the MLAS code path on `packed_b_` being set. This logic needs to be updated if the MLAS requirement changes.

### Motivation and Context

---------

Signed-off-by: Liqun Fu
---
 .../contrib_ops/cpu/quantization/matmul_nbits.cc | 14 +++-----------
 1 file changed, 3 insertions(+), 11 deletions(-)

diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
index 6144ca46ac180..dbc678c9bc9c6 100644
--- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
+++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
@@ -322,7 +322,8 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
   if (has_single_b_matrix) {
     const auto compute_type = static_cast<MLAS_SQNBIT_GEMM_COMPUTE_TYPE>(accuracy_level_);
 
-    if (MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type)) {
+      // mlas nbits implementation requires packed b. update this logic if it changes.
+    if (MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type) && packed_b_) {
       IAllocatorUniquePtr<std::byte> workspace{};
       if (const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, batch_count,
                                                                          nbits_, block_size_, compute_type);
@@ -332,20 +333,11 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
         workspace = IAllocator::MakeUniquePtr<std::byte>(allocator, workspace_size);
       }
 
-      const void* b_data = [&]() -> const void* {
-        if (packed_b_) {
-          return packed_b_.get();
-        }
-
-        const Tensor* b = ctx->Input<Tensor>(1);
-        return b->DataRaw();
-      }();
-
       InlinedVector<MLAS_SQNBIT_GEMM_DATA_PARAMS> data(batch_count);
       for (size_t i = 0; i < batch_count; ++i) {
         data[i].A = a_data + helper.LeftOffsets()[i];
         data[i].lda = lda;
-        data[i].QuantBData = b_data;
+        data[i].QuantBData = packed_b_.get();
         data[i].QuantBScale = scales_data;
         data[i].QuantBZeroPoint = zero_points_data;
         data[i].C = y_data + helper.OutputOffsets()[i];
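
---

Note: below is a minimal standalone sketch (not ONNX Runtime code) of the dispatch rule this patch leaves in place. `SQNBitGemmKernelAvailable`, its hard-coded availability check, and the `main` harness are hypothetical stand-ins for `MlasIsSQNBitGemmAvailable(...)` and the member `packed_b_`:

```cpp
#include <cstddef>
#include <iostream>
#include <memory>

// Hypothetical stand-in for MlasIsSQNBitGemmAvailable(nbits, block_size, compute_type);
// the real answer depends on the CPU and the selected compute type.
bool SQNBitGemmKernelAvailable(std::size_t nbits, std::size_t block_size) {
  return nbits == 4 && block_size % 16 == 0;  // illustrative only
}

int main() {
  // Stand-in for the member packed_b_, which PrePack() fills with the
  // MLAS-packed copy of the quantized B tensor.
  std::unique_ptr<std::byte[]> packed_b;

  // The rule after this patch: take the MLAS SQNBitGemm path only when a
  // kernel is available AND B was prepacked; otherwise use the fallback path.
  const bool use_mlas =
      SQNBitGemmKernelAvailable(/*nbits=*/4, /*block_size=*/32) && packed_b != nullptr;

  std::cout << (use_mlas ? "MLAS SQNBitGemm path\n" : "fallback path\n");
}
```

Requiring `packed_b_` lets the MLAS path assume the packed B layout unconditionally, which is why the lambda that fell back to the unpacked input tensor could be deleted.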