diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
index 6144ca46ac180..dbc678c9bc9c6 100644
--- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
+++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
@@ -322,7 +322,8 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
   if (has_single_b_matrix) {
     const auto compute_type = static_cast<MLAS_SQNBIT_GEMM_COMPUTE_TYPE>(accuracy_level_);
 
-    if (MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type)) {
+    // mlas nbits implementation requires packed b. update this logic if it changes.
+    if (MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type) && packed_b_) {
       IAllocatorUniquePtr<std::byte> workspace{};
       if (const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, batch_count,
                                                                          nbits_, block_size_, compute_type);
@@ -332,20 +333,11 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
         workspace = IAllocator::MakeUniquePtr<std::byte>(allocator, workspace_size);
       }
 
-      const void* b_data = [&]() -> const void* {
-        if (packed_b_) {
-          return packed_b_.get();
-        }
-
-        const Tensor* b = ctx->Input<Tensor>(1);
-        return b->DataRaw();
-      }();
-
       InlinedVector<MLAS_SQNBIT_GEMM_DATA_PARAMS> data(batch_count);
       for (size_t i = 0; i < batch_count; ++i) {
         data[i].A = a_data + helper.LeftOffsets()[i];
         data[i].lda = lda;
-        data[i].QuantBData = b_data;
+        data[i].QuantBData = packed_b_.get();
         data[i].QuantBScale = scales_data;
         data[i].QuantBZeroPoint = zero_points_data;
         data[i].C = y_data + helper.OutputOffsets()[i];
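
For readers skimming the patch, here is a minimal, self-contained C++ sketch of the dispatch behavior after this change. It is not ONNX Runtime code: IsNbitsKernelAvailable, RunNbitsKernel, RunGenericPath, and PackedB are hypothetical stand-ins for the real MLAS/ORT calls. The point it illustrates is the guard the diff introduces: the fast nbits path is taken only when the kernel is available AND B was prepacked, instead of choosing between packed_b_ and the raw B tensor via the removed lambda.

// Minimal sketch (not ONNX Runtime code) of the dispatch logic after this diff.
// All names here are hypothetical stand-ins, not real ORT/MLAS APIs.
#include <cstddef>
#include <iostream>
#include <memory>
#include <vector>

struct PackedB {
  std::vector<std::byte> data;  // prepacked quantized B blob
};

// Hypothetical stand-in for an availability check like MlasIsSQNBitGemmAvailable().
bool IsNbitsKernelAvailable(int nbits, int block_size) {
  return nbits == 4 && block_size % 16 == 0;
}

void RunNbitsKernel(const PackedB& b) {
  std::cout << "fast path: packed B of " << b.data.size() << " bytes\n";
}

void RunGenericPath() {
  std::cout << "generic fallback path\n";
}

void Compute(int nbits, int block_size, const std::unique_ptr<PackedB>& packed_b) {
  // After the change, both conditions must hold, mirroring
  // `MlasIsSQNBitGemmAvailable(...) && packed_b_` in the diff above.
  if (IsNbitsKernelAvailable(nbits, block_size) && packed_b) {
    RunNbitsKernel(*packed_b);
    return;
  }
  RunGenericPath();
}

int main() {
  Compute(4, 32, std::make_unique<PackedB>(PackedB{std::vector<std::byte>(128)}));
  Compute(4, 32, nullptr);  // no prepacked B -> generic path
}

The practical effect, as far as this hunk shows, is that when no prepacked B buffer exists the operator no longer hands the raw B tensor to the MLAS nbits kernel and presumably falls through to the non-MLAS code path instead, which keeps the packed-layout assumption explicit at the call site.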