diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
index bf43aca73ef3a..7d3376d92af43 100644
--- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
+++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
@@ -222,16 +222,28 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat
       auto qptr = tensor.DataRaw();
       packed_b_ = IAllocator::MakeUniquePtr<void>(alloc, packed_b_size_, true);
       MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, qptr, packed_b_.get(), nullptr, has_zp_input_, nullptr, nullptr);
+      if (prepacked_weights) {
+        prepacked_weights->buffers_.push_back(std::move(packed_b_));
+        prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
+      }
       is_packed = true;
     } else if (compute_type == CompInt8) {
 #ifdef MLAS_TARGET_AMD64_IX86
       if (input_idx == InputIndex::scales && packed_b_ != nullptr) {
         auto sptr = tensor.Data<float>();
         MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, nullptr, packed_b_.get(), sptr, has_zp_input_, nullptr, nullptr);
+        if (prepacked_weights) {
+          prepacked_weights->buffers_.push_back(std::move(packed_b_));
+          prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
+        }
         is_packed = false;
       } else if (input_idx == InputIndex::zero_points && packed_b_ != nullptr) {
         auto zptr = tensor.Data<uint8_t>();
         MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, nullptr, packed_b_.get(), nullptr, has_zp_input_, zptr, nullptr);
+        if (prepacked_weights) {
+          prepacked_weights->buffers_.push_back(std::move(packed_b_));
+          prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
+        }
         is_packed = false;
       }
 #endif
@@ -267,7 +279,19 @@ Status MatMulNBits::UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prep
     used_shared_buffers = true;
     packed_b_ = std::move(prepacked_buffers[0]);
   }
-
+#ifdef MLAS_TARGET_AMD64_IX86
+  const auto compute_type = static_cast<MLAS_SQNBIT_GEMM_COMPUTE_TYPE>(accuracy_level_);
+  if (compute_type == CompInt8) {
+    if (input_idx == 2) {
+      used_shared_buffers = true;
+      packed_b_ = std::move(prepacked_buffers[0]);
+    }
+    if (input_idx == 3) {
+      used_shared_buffers = true;
+      packed_b_ = std::move(prepacked_buffers[0]);
+    }
+  }
+#endif  // MLAS_TARGET_AMD64_IX86
 #endif  // defined(ORT_NEURAL_SPEED)

   return Status::OK();
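
Note: for readers less familiar with the kernel prepacking API, the following is a minimal sketch of the PrePack / UseSharedPrePackedBuffers contract this diff relies on. "MyKernel", "PackWeight()", and the weight input index 1 are hypothetical placeholders; PrePackedWeights (with its buffers_ and buffer_sizes_ members), BufferUniquePtr, and IAllocator::MakeUniquePtr are the real ONNX Runtime types already used in the diff above.

#include "core/framework/op_kernel.h"  // OpKernel, PrePackedWeights, AllocatorPtr

Status MyKernel::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
                         /*out*/ bool& is_packed,
                         /*out*/ PrePackedWeights* prepacked_weights) {
  is_packed = false;
  if (input_idx != 1) return Status::OK();  // only the weight input is packed here

  packed_b_ = IAllocator::MakeUniquePtr<void>(alloc, packed_b_size_, true);
  PackWeight(tensor, packed_b_.get());  // hypothetical packing helper

  if (prepacked_weights != nullptr) {
    // Weight sharing is enabled: hand ownership of the packed buffer to the
    // session-level container so other sessions can reuse it. packed_b_ is
    // empty afterwards and must be restored via UseSharedPrePackedBuffers
    // before Compute() runs.
    prepacked_weights->buffers_.push_back(std::move(packed_b_));
    prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
  }
  is_packed = true;
  return Status::OK();
}

Status MyKernel::UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers,
                                           int input_idx,
                                           /*out*/ bool& used_shared_buffers) {
  used_shared_buffers = false;
  if (input_idx == 1) {
    // Take back (or adopt a shared copy of) the buffer that PrePack pushed
    // into the shared container.
    packed_b_ = std::move(prepacked_buffers[0]);
    used_shared_buffers = true;
  }
  return Status::OK();
}

This pattern explains the shape of the diff: on the MLAS_TARGET_AMD64_IX86 CompInt8 path, the scales and zero_points inputs re-pack into the same packed_b_ buffer, so PrePack pushes the buffer into prepacked_weights at each of those steps, and UseSharedPrePackedBuffers correspondingly restores packed_b_ for input indices 2 and 3 (which appear to match InputIndex::scales and InputIndex::zero_points).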