# Integrate high-performance x64 gemm library to MLAS (#17669)
### Description
Improve MLAS to support high-performance x64 INT4 kernels



### Motivation and Context
1. Improve LLM inference performance on Intel CPUs.
2. Support more 4-bit quantization types: NF4 and FP4.
3. Support dynamic block sizes: block sizes aligned with the kernel's tiling size (e.g. 4 for the VNNI kernel), as well as per-channel quantization along the N dimension.
4. Support most Intel ISAs: avx2, avx_vnni, avx512f, avx512_vnni, amx_bf16, amx_int8, avx512_fp16.
5. Support the MatMulNBits data format.

### Tasks
- [x] Support block_size values of 32, 128, and -1 (per channel)
- [x] Query the weight pack size without allocating memory
- [x] Use ORT's thread pool for parallelism
- [x] Support ISAs: avx2, avx512f, avx_vnni, avx512_vnni, amx_int8

### Benchmark
Ubuntu 20.22 + Intel(R) Xeon(R) Platinum 8480+, 56 cores:

Benchmark | Time | CPU | Iterations
-- | -- | -- | --
Q4GEMM_Jblas/Q4G32SymInt8/M:1/N:4096/K:4096/Threads:56/real_time | 47613 | 47401 | 12970
Q4GEMM_Jblas/Q4G32SymInt8/M:1024/N:4096/K:4096/Threads:56/real_time | 6347792 | 6317562 | 109
Q4GEMM_Jblas/Q4G32SymInt8/M:2048/N:4096/K:4096/Threads:56/real_time | 11814014 | 11757847 | 59
Q4GEMM_Jblas/Q4G128SymInt8/M:1/N:4096/K:4096/Threads:56/real_time | 50222 | 50031 | 13759
Q4GEMM_Jblas/Q4G128SymInt8/M:1024/N:4096/K:4096/Threads:56/real_time | 2038222 | 2028743 | 341
Q4GEMM_Jblas/Q4G128SymInt8/M:2048/N:4096/K:4096/Threads:56/real_time | 3792832 | 3774485 | 191
Q4GEMM_Jblas/Q4GPerNSymInt8/M:1/N:4096/K:4096/Threads:56/real_time | 58717 | 58501 | 11467
Q4GEMM_Jblas/Q4GPerNSymInt8/M:1024/N:4096/K:4096/Threads:56/real_time | 1360846 | 1354598 | 543
Q4GEMM_Jblas/Q4GPerNSymInt8/M:2048/N:4096/K:4096/Threads:56/real_time | 2564232 | 2551365 | 266
Q4GEMM_Jblas/Q4G32SymFp32/M:1/N:4096/K:4096/Threads:56/real_time | 57929 | 57694 | 12047
Q4GEMM_Jblas/Q4G32SymFp32/M:1024/N:4096/K:4096/Threads:56/real_time | 5495330 | 5465810 | 126
Q4GEMM_Jblas/Q4G32SymFp32/M:2048/N:4096/K:4096/Threads:56/real_time | 10676240 | 10617817 | 66
Q4GEMM_Jblas/Q4G128SymFp32/M:1/N:4096/K:4096/Threads:56/real_time | 68305 | 68047 | 10026
Q4GEMM_Jblas/Q4G128SymFp32/M:1024/N:4096/K:4096/Threads:56/real_time | 5504862 | 5476215 | 126
Q4GEMM_Jblas/Q4G128SymFp32/M:2048/N:4096/K:4096/Threads:56/real_time | 11758623 | 11697337 | 66
Q4GEMM_Jblas/Q4GPerNSymFp32/M:1/N:4096/K:4096/Threads:56/real_time | 67713 | 67451 | 10298
Q4GEMM_Jblas/Q4GPerNSymFp32/M:1024/N:4096/K:4096/Threads:56/real_time | 5508325 | 5480237 | 126
Q4GEMM_Jblas/Q4GPerNSymFp32/M:2048/N:4096/K:4096/Threads:56/real_time | 10738528 | 10681656 | 64
Q4GEMM_Jblas/Q4G32AsymFp32/M:1/N:4096/K:4096/Threads:56/real_time | 60708 | 60486 | 11321
Q4GEMM_Jblas/Q4G32AsymFp32/M:1024/N:4096/K:4096/Threads:56/real_time | 5523784 | 5495736 | 126
Q4GEMM_Jblas/Q4G32AsymFp32/M:2048/N:4096/K:4096/Threads:56/real_time | 10829633 | 10772161 | 67


Reference:

Benchmark | Time | CPU | Iterations
-- | -- | -- | --
Q4GEMM/Q4Sym/M:1/N:4096/K:4096/Threads:56/real_time | 53088 | 52911 | 13364
Q4GEMM/Q4Sym/M:1024/N:4096/K:4096/Threads:56/real_time | 6268981 | 6230335 | 110
Q4GEMM/Q4Sym/M:2048/N:4096/K:4096/Threads:56/real_time | 11701237 | 11632339 | 59
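
As a rough reading of the two Xeon tables above (Time column, N:4096/K:4096), dividing the reference Q4Sym time by the corresponding time from this PR gives:

$$
\frac{53088}{47613} \approx 1.11 \;\; (\text{Q4G32SymInt8},\ M{=}1), \qquad
\frac{11701237}{2564232} \approx 4.56 \;\; (\text{Q4GPerNSymInt8},\ M{=}2048)
$$

The second figure trades quantization granularity for speed (per-channel is much coarser than block-wise quantization), so it should be read as a best-case ratio rather than a like-for-like one.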

Win11 + 12900K, 8 cores:

Benchmark | Time | CPU | Iterations
-- | -- | -- | --
Q4GEMM_Jblas/Q4G32SymInt8/M:1/N:4096/K:4096/Threads:8/real_time | 215976 | 211295 | 2884
Q4GEMM_Jblas/Q4G32SymInt8/M:1024/N:4096/K:4096/Threads:8/real_time | 60960590 | 60937500 | 10
Q4GEMM_Jblas/Q4G32SymInt8/M:2048/N:4096/K:4096/Threads:8/real_time | 1.18E+08 | 1.19E+08 | 5
Q4GEMM_Jblas/Q4G32SymInt8/M:1/N:11008/K:4096/Threads:8/real_time | 470377 | 453059 | 1414
Q4GEMM_Jblas/Q4G32SymInt8/M:1024/N:11008/K:4096/Threads:8/real_time | 1.54E+08 | 1.53E+08 | 5
Q4GEMM_Jblas/Q4G32SymInt8/M:2048/N:11008/K:4096/Threads:8/real_time | 3.18E+08 | 3.13E+08 | 2
Q4GEMM_Jblas/Q4G32SymInt8/M:1/N:4096/K:11008/Threads:8/real_time | 569072 | 559398 | 1229
Q4GEMM_Jblas/Q4G32SymInt8/M:1024/N:4096/K:11008/Threads:8/real_time | 1.54E+08 | 1.52E+08 | 4
Q4GEMM_Jblas/Q4G32SymInt8/M:2048/N:4096/K:11008/Threads:8/real_time | 3.22E+08 | 3.28E+08 | 2
Q4GEMM_Jblas/Q4G32SymInt8/M:1/N:11008/K:11008/Threads:8/real_time | 1486055 | 1473325 | 403
Q4GEMM_Jblas/Q4G32SymInt8/M:1024/N:11008/K:11008/Threads:8/real_time | 4.14E+08 | 4.14E+08 | 2
Q4GEMM_Jblas/Q4G32SymInt8/M:2048/N:11008/K:11008/Threads:8/real_time | 8.88E+08 | 8.59E+08 | 1

---------

Signed-off-by: Mengni Wang <[email protected]>
Co-authored-by: Mengni Wang <[email protected]>
luoyu-intel and mengniwang95 authored Dec 19, 2023
1 parent 4dff154 commit 5f00bc9
Showing 37 changed files with 24,902 additions and 10 deletions.
12 changes: 12 additions & 0 deletions cmake/CMakeLists.txt
@@ -87,6 +87,7 @@ option(onnxruntime_USE_QNN "Build with QNN support" OFF)
option(onnxruntime_USE_SNPE "Build with SNPE support" OFF)
option(onnxruntime_USE_RKNPU "Build with RKNPU support" OFF)
option(onnxruntime_USE_DNNL "Build with DNNL support" OFF)
option(onnxruntime_USE_JBLAS "Build MLAS with JBLAS support" ON)
option(onnxruntime_USE_JSEP "Build with JavaScript implemented kernels support" OFF)
option(onnxruntime_BUILD_UNIT_TESTS "Build ONNXRuntime unit tests" ON)
option(onnxruntime_BUILD_CSHARP "Build C# library" OFF)
@@ -1166,6 +1167,17 @@ if (onnxruntime_USE_DNNL)
add_compile_definitions(DNNL_OPENMP)
endif()

set(USE_JBLAS FALSE)
if (onnxruntime_USE_JBLAS AND NOT onnxruntime_MINIMAL_BUILD)
if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND onnxruntime_target_platform STREQUAL "x86_64")
add_compile_definitions(MLAS_JBLAS)
set(USE_JBLAS TRUE)
elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC" AND onnxruntime_target_platform STREQUAL "x64")
add_compile_definitions(MLAS_JBLAS)
set(USE_JBLAS TRUE)
endif()
endif()

# TVM EP
if (onnxruntime_USE_TVM)
if (NOT TARGET tvm)
16 changes: 14 additions & 2 deletions cmake/onnxruntime_mlas.cmake
@@ -45,6 +45,15 @@ endif()

set(ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas)

function(add_jblas)
add_subdirectory(${MLAS_SRC_DIR}/x86_64/jblas jblas)
target_link_libraries(onnxruntime_mlas PRIVATE jblas::jblas)
target_sources(onnxruntime_mlas PRIVATE
${MLAS_SRC_DIR}/jblas_gemm.cpp
)
set_target_properties(${target_name} PROPERTIES COMPILE_WARNING_AS_ERROR OFF)
endfunction()

#TODO: set MASM flags properly
function(setup_mlas_source_for_windows)

@@ -200,7 +209,6 @@ function(setup_mlas_source_for_windows)
${MLAS_SRC_DIR}/q4gemm_avx512.cpp
)
endif()

else()
target_sources(onnxruntime_mlas PRIVATE
${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp
Expand Down Expand Up @@ -566,7 +574,7 @@ else()
)
set_source_files_properties(${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f")
set_source_files_properties(${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmx.S PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f")
endif()
endif()

if(ONNXRUNTIME_MLAS_MULTI_ARCH)
onnxruntime_add_static_library(onnxruntime_mlas_x86_64 ${mlas_platform_srcs})
@@ -604,6 +612,10 @@ else()
target_sources(onnxruntime_mlas PRIVATE ${mlas_platform_srcs})
endif()

if(USE_JBLAS)
add_jblas()
endif()

foreach(mlas_target ${ONNXRUNTIME_MLAS_LIBS})
target_include_directories(${mlas_target} PRIVATE ${ONNXRUNTIME_ROOT}/core/mlas/inc ${MLAS_SRC_DIR})
onnxruntime_add_include_to_target(${mlas_target} ${GSL_TARGET})
2 changes: 2 additions & 0 deletions docs/ContribOperators.md
@@ -2824,6 +2824,8 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>size of each input feature</dd>
<dt><tt>N</tt> : int (required)</dt>
<dd>size of each output feature</dd>
<dt><tt>accuracy_level</tt> : int</dt>
<dd>The minimum accuracy level of input A, can be: 0(unset), 1(fp32), 2(fp16), 3(bf16), or 4(int8) (default unset). It is used to control how input A is quantized or downcast internally while doing computation, for example: 0 means input A will not be quantized or downcast while doing computation. 4 means input A can be quantized with the same block_size to int8 internally from type T1.</dd>
<dt><tt>bits</tt> : int (required)</dt>
<dd>number of bits used for weight quantization (default 4)</dd>
<dt><tt>block_size</tt> : int (required)</dt>
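
To make the new attribute concrete, here is a minimal sketch (not part of the diff) of how the `accuracy_level` values line up with the `MLAS_SQNBIT_COMPUTE_TYPE` enum introduced in `mlas_qnbit.h` further down; the kernel itself simply `static_cast`s the attribute, and the helper function and include path below are hypothetical.

```cpp
#include <cstdint>

#include "core/mlas/inc/mlas_qnbit.h"  // assumed include path; declares MLAS_SQNBIT_COMPUTE_TYPE

// Hypothetical helper: map the MatMulNBits accuracy_level attribute
// (0 = unset, 1 = fp32, 2 = fp16, 3 = bf16, 4 = int8) onto the MLAS compute type.
// The enum values were chosen to match the attribute, so matmul_nbits.cc can
// simply static_cast; the switch just makes the correspondence explicit.
inline MLAS_SQNBIT_COMPUTE_TYPE AccuracyLevelToComputeType(int64_t accuracy_level) {
  switch (accuracy_level) {
    case 1: return CompFp32;   // keep A in fp32
    case 2: return CompFp16;   // allow downcast of A to fp16
    case 3: return CompBf16;   // allow downcast of A to bf16
    case 4: return CompInt8;   // allow dynamic int8 quantization of A (same block_size)
    default: return CompUndef; // 0 / unset: do not quantize or downcast A
  }
}
```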
134 changes: 131 additions & 3 deletions onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
@@ -20,30 +20,158 @@ class MatMulNBits final : public OpKernel {
K_{narrow<size_t>(info.GetAttr<int64_t>("K"))},
N_{narrow<size_t>(info.GetAttr<int64_t>("N"))},
block_size_{narrow<size_t>(info.GetAttr<int64_t>("block_size"))},
nbits_{narrow<size_t>(info.GetAttr<int64_t>("bits"))} {
nbits_{narrow<size_t>(info.GetAttr<int64_t>("bits"))},
accuracy_level_{info.GetAttr<int64_t>("accuracy_level")} {
ORT_ENFORCE(nbits_ == 4,
"Only 4b quantization is supported for MatMulNBits op, additional bits support is planned.");
is_asym_ = info.GetInputCount() >= 4;
const Tensor* tensor_B = nullptr;
const Tensor* tensor_scale = nullptr;
const Tensor* tensor_zero_point = nullptr;
bool B_constant = info.TryGetConstantInput(1, &tensor_B);
bool scale_constant = info.TryGetConstantInput(2, &tensor_scale);
bool zero_point_constant = info.TryGetConstantInput(3, &tensor_zero_point);
all_constant_ = B_constant && scale_constant;
all_constant_ = is_asym_ ? all_constant_ && zero_point_constant : all_constant_;
}

Status Compute(OpKernelContext* context) const override;

Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* prepacked_weights) override;

Status UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers, int input_idx,
/*out*/ bool& used_shared_buffers) override;

private:
const size_t K_;
const size_t N_;
const size_t block_size_;
const size_t nbits_;
const int64_t accuracy_level_;
const bool column_wise_quant_{true};
IAllocatorUniquePtr<void> packed_b_;
size_t packed_b_size_{0};
bool is_asym_{false};
bool all_constant_{false};
};

Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc,
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* prepacked_weights) {
is_packed = false;
if (!all_constant_) {
return Status::OK();
}
auto compt_type = static_cast<MLAS_SQNBIT_COMPUTE_TYPE>(accuracy_level_);
MLAS_THREADPOOL* pool = NULL;
if (input_idx == 1) {
packed_b_size_ = MlasNBitsGemmPackBSize(N_, K_, block_size_, static_cast<int>(nbits_), is_asym_, compt_type);
if (packed_b_size_ == 0) return Status::OK();
auto qptr = tensor.Data<uint8_t>();
packed_b_ = IAllocator::MakeUniquePtr<void>(alloc, packed_b_size_, true);
if (packed_b_ == nullptr) {
return Status::OK();
}
std::memset(packed_b_.get(), 0, packed_b_size_);
MlasNBitsGemmPackB(packed_b_.get(), qptr, nullptr, nullptr, N_, K_, K_, block_size_, static_cast<int>(nbits_),
is_asym_, false, compt_type, pool);
if (prepacked_weights) {
prepacked_weights->buffers_.push_back(std::move(packed_b_));
prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
}
is_packed = true;
}
if (input_idx == 2 && packed_b_ != nullptr) {
auto sptr = tensor.Data<float>();
MlasNBitsGemmPackB(packed_b_.get(), nullptr, sptr, nullptr, N_, K_, K_, block_size_, static_cast<int>(nbits_),
is_asym_, !is_asym_, compt_type, pool);
if (prepacked_weights) {
prepacked_weights->buffers_.push_back(std::move(packed_b_));
prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
}
is_packed = true;
}
if (input_idx == 3 && packed_b_ != nullptr) {
auto zptr = tensor.Data<uint8_t>();
MlasNBitsGemmPackB(packed_b_.get(), nullptr, nullptr, zptr, N_, K_, K_, block_size_, static_cast<int>(nbits_),
is_asym_, is_asym_, compt_type, pool);
if (prepacked_weights) {
prepacked_weights->buffers_.push_back(std::move(packed_b_));
prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
}
is_packed = true;
}

return Status::OK();
}

Status MatMulNBits::UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers, int input_idx,
/*out*/ bool& used_shared_buffers) {
used_shared_buffers = false;
// Pack three tensors into one buffer
if (input_idx == 1) {
used_shared_buffers = true;
packed_b_ = std::move(prepacked_buffers[0]);
}
if (input_idx == 2) {
used_shared_buffers = true;
packed_b_ = std::move(prepacked_buffers[0]);
}
if (input_idx == 3) {
used_shared_buffers = true;
packed_b_ = std::move(prepacked_buffers[0]);
}
return Status::OK();
}

Status MatMulNBits::Compute(OpKernelContext* ctx) const {
concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool();

const Tensor* a = ctx->Input<Tensor>(0);
const auto* a_data = a->Data<float>();

if (packed_b_.get()) {
TensorShape b_shape({static_cast<int64_t>(N_), static_cast<int64_t>(K_)});

MatMulComputeHelper helper;
ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true));

Tensor* y = ctx->Output(0, helper.OutputShape());

// Bail out early if the output is going to be empty
if (y->Shape().Size() == 0) return Status::OK();

auto* y_data = y->MutableData<float>();

const size_t max_len = helper.OutputOffsets().size();
const size_t M = static_cast<size_t>(helper.M());
const size_t N = static_cast<size_t>(helper.N());
const size_t K = static_cast<size_t>(helper.K());
const size_t lda = helper.Lda(false);
std::vector<MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS> gemm_params(max_len);
AllocatorPtr allocator;
auto status = ctx->GetTempSpaceAllocator(&allocator);
ORT_RETURN_IF_ERROR(status);
for (size_t i = 0; i < max_len; i++) {
gemm_params[i].A = a_data + helper.LeftOffsets()[i];
gemm_params[i].lda = lda;
gemm_params[i].B = packed_b_.get();
gemm_params[i].C = y_data + helper.OutputOffsets()[i];
gemm_params[i].ldc = N;
}
auto ws_size = MlasSQNBitsGemmBatchWorkspaceSize(M, N, K, max_len, gemm_params.data());
// workspace for activation process(dynamic quantization and others)
auto ws_ptr = IAllocator::MakeUniquePtr<int8_t>(allocator, ws_size);
MlasSQNBitsGemmBatchPackedB(M, N, K, max_len, gemm_params.data(), ws_ptr.get(),
thread_pool);
return Status::OK();
}

const Tensor* b = ctx->Input<Tensor>(1);
const Tensor* scales = ctx->Input<Tensor>(2);
const Tensor* zero_points = ctx->Input<Tensor>(3);

const auto* a_data = a->Data<float>();
const uint8_t* b_data = b->Data<uint8_t>();
const auto* scales_data = scales->Data<float>();
const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->Data<uint8_t>();
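
The staged `PrePack` flow above can be exercised outside the operator as well. The sketch below is illustrative only: the buffer shapes, the bare `mlas_qnbit.h` include, and the helper name are assumptions, while the entry points, argument order, and the `last_call` convention follow the code above and the declarations in `mlas_qnbit.h` later in this commit.

```cpp
#include <cstddef>
#include <vector>

#include "mlas_qnbit.h"  // assumed include path for the new MLAS n-bit GEMM API

// Sketch: pack int4 weights, scales and (optionally) zero points into one blob,
// feeding the tensors one at a time the way OpKernel::PrePack does. last_call
// must be true only for the final tensor so the kernel can run its packing
// epilogue, which needs the complete blob.
void PackWeightsSketch(const uint8_t* q_data,    // int4 weights, two values per byte
                       const float* scales,      // per-block scales
                       const uint8_t* zero_pts,  // per-block zero points (asymmetric only)
                       size_t N, size_t K, size_t block_size, bool is_asym,
                       std::vector<int8_t>& packed_b) {
  const MLAS_SQNBIT_COMPUTE_TYPE comp_type = CompInt8;  // accuracy_level == 4
  // 1) Query the packed size; 0 means this configuration is not supported.
  const size_t packed_size =
      MlasNBitsGemmPackBSize(N, K, block_size, /*nbits*/ 4, is_asym, comp_type);
  if (packed_size == 0) return;
  packed_b.assign(packed_size, 0);  // zero-initialize, as PrePack does with memset

  // 2) Quantized data first, last_call == false.
  MlasNBitsGemmPackB(packed_b.data(), q_data, nullptr, nullptr, N, K, /*ldb*/ K,
                     block_size, 4, is_asym, /*last_call*/ false, comp_type, nullptr);
  // 3) Scales next; this is the last tensor when the weights are symmetric.
  MlasNBitsGemmPackB(packed_b.data(), nullptr, scales, nullptr, N, K, K,
                     block_size, 4, is_asym, /*last_call*/ !is_asym, comp_type, nullptr);
  // 4) Zero points only exist for asymmetric quantization and close the sequence.
  if (is_asym) {
    MlasNBitsGemmPackB(packed_b.data(), nullptr, nullptr, zero_pts, N, K, K,
                       block_size, 4, is_asym, /*last_call*/ true, comp_type, nullptr);
  }
}
```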
7 changes: 7 additions & 0 deletions onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -3359,6 +3359,13 @@ Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored
.Attr("N", "size of each output feature", AttributeProto::INT)
.Attr("bits", "number of bits used for weight quantization (default 4)", AttributeProto::INT)
.Attr("block_size", "number of groupsize used for weight quantization,(default 128). It needs to be a power of 2 and not smaller than 16.", AttributeProto::INT)
.Attr("accuracy_level",
"The minimum accuracy level of input A, can be: 0(unset), 1(fp32), 2(fp16), 3(bf16), or 4(int8) "
"(default unset). It is used to control how input A is quantized or downcast internally while "
"doing computation, for example: 0 means input A will not be quantized or downcast while doing "
"computation. 4 means input A can be quantized with the same block_size to int8 internally from "
"type T1.",
AttributeProto::INT, static_cast<int64_t>(0))
.Input(0, "A", "The input tensor, not quantized", "T1")
.Input(1, "B", "1-dimensional data blob", "T2")
.Input(2, "scales", "quantization scale", "T1")
141 changes: 141 additions & 0 deletions onnxruntime/core/mlas/inc/mlas_qnbit.h
@@ -77,3 +77,144 @@ MlasIsSQNBitGemmAvailable(
size_t BlkBitWidth,
size_t BlkLen
);

/**
* @brief Define compute types of block quantization
*/
typedef enum {
CompUndef = 0, /*!< undef */
CompFp32 = 1, /*!< input fp32, accumulator fp32 */
CompFp16 = 2, /*!< input fp16, accumulator fp16 */
CompBf16 = 3, /*!< input bf16, accumulator fp32 */
CompInt8 = 4 /*!< input int8, accumulator int32 */
} MLAS_SQNBIT_COMPUTE_TYPE;

/**
* @brief Data parameters for NBits GEMM routine
* C = A * B
* A, C must be a float32 matrix
* B must be a packed nbits blob
* All except C are [in] parameters
*/
struct MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS {
const float* A = nullptr; /**< address of A (float32 matrix)*/
const void* B = nullptr; /**< address of B (packed nbits blob)*/
float* C = nullptr; /**< address of result matrix */
size_t lda = 0; /**< leading dimension of A */
size_t ldc = 0; /**< leading dimension of C*/
};

/**
* @brief Compute the byte size of the parameter combination
*
* @param N the number of columns of matrix B.
* @param K the number of rows of matrix B.
* @param block_size size of the block to quantize, elements from the same block share the same
* scale and zero point
* @param nbits number of bits used for weight quantization
* @param is_asym flag for asymmetric quantization
* @param comp_type specify input data type and accumulator data type
* @return size of the packing buffer, 0 if the operation is not yet supported.
*/
size_t MLASCALL
MlasNBitsGemmPackBSize(
size_t N, size_t K, size_t block_size, int nbits, bool is_asym, MLAS_SQNBIT_COMPUTE_TYPE comp_type
);

/**
* @brief Prepack tensor data from n-bit quantized data, scale and zero point buffers.
*
* @param PackedBuf packed data buffer
* @param QData quantized data buffer
* @param Scale scale pointer
* @param Zp zero point pointer
* @param N the number of columns of matrix B.
* @param K the number of rows of matrix B.
* @param ldb leading dimension of B
* @param block_size size of the block to quantize, elements from the same block share the same
* scale and zero point
* @param nbits number of bits used for weight quantization (default 4)
* @param is_asym flag for asymmetric quantization
* @param comp_type specify input data type and accumulator data type
* @param last_call flag to activate the epilogue process of packB. OpKernel::PrePack will query input tensor
* one by one: QData, Scale, Zp (if is_asym is true). But kernel prefers to pack all tensors into one blob data where
* they can share the common attributes like: block_size. Meanwhile, kernel has some pre-computations to speed up
* inference which require that all blob data are ready. So, you need to set this flag to true when passing Scale
* (is_asym is false) and Zp(is_asym is true).
* @param thread_pool
*/
void MLASCALL
MlasNBitsGemmPackB(
void* PackedBuf,
const uint8_t* QData,
const float* Scale,
const uint8_t* Zp,
size_t N,
size_t K,
size_t ldb,
size_t block_size,
int nbits,
bool is_asym,
bool last_call,
MLAS_SQNBIT_COMPUTE_TYPE comp_type,
MLAS_THREADPOOL* thread_pool
);

/**
* @brief Unpack and dequantize to fp32
*
* @param FpData unpacked float32 data
* @param PackedBuf quantized and packed data
* @param N the number of columns of matrix B.
* @param K the number of rows of matrix B.
* @param ldb leading dimension of B
* @param thread_pool
*/
void MLASCALL
MlasNBitsGemmUnPackB(
float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, MLAS_THREADPOOL* thread_pool
);

/**
* @brief Get the workspace size required by computation.
*
* @param[in] M row size of matrix A and C
* @param[in] N column size of matrix B and C
* @param[in] K column size of matrix A and row size of matrix B
* @param[in] BatchN number of batches
* @param[inout] DataParams An array (size BatchN) of parameter blocks
* @return Workspace size in bytes
*/
size_t MLASCALL
MlasSQNBitsGemmBatchWorkspaceSize(
const size_t M,
const size_t N,
const size_t K,
const size_t BatchN,
const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams
);

/**
* @brief Batched GEMM: C = A * B
* A, C must be a float32 matrix
* B must be a packed nbits blob
*
* @param[in] M row size of matrix A and C
* @param[in] N column size of matrix B and C
* @param[in] K column size of matrix A and row size of matrix B
* @param[in] BatchN number of batches
* @param[inout] DataParams An array (size BatchN) of parameter blocks
* @param[in] WorkSpace temporary buffer
* @param[in] ThreadPool
* @return
*/
void MLASCALL
MlasSQNBitsGemmBatchPackedB(
const size_t M,
const size_t N,
const size_t K,
const size_t BatchN,
const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams,
void* WorkSpace,
MLAS_THREADPOOL* ThreadPool = nullptr
);
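
To round out the API above, here is a hedged sketch of driving the batched GEMM once B has been packed. It mirrors the packed path of `MatMulNBits::Compute` shown earlier; the single-batch setup, the plain `std::vector` workspace, and the bare include path are assumptions of the sketch, not requirements of the API.

```cpp
#include <cstddef>
#include <vector>

#include "mlas_qnbit.h"  // assumed include path for the new MLAS n-bit GEMM API

// Sketch: run C = A * B for one batch, where B is the blob produced by
// MlasNBitsGemmPackB and A/C are row-major float32 matrices.
void RunPackedGemmSketch(const float* A, const void* packed_b, float* C,
                         size_t M, size_t N, size_t K,
                         MLAS_THREADPOOL* thread_pool) {
  MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS params{};
  params.A = A;
  params.lda = K;       // row-major A: leading dimension is K
  params.B = packed_b;  // packed n-bit blob
  params.C = C;
  params.ldc = N;       // row-major C: leading dimension is N

  // The workspace covers activation-side processing, e.g. dynamic int8
  // quantization of A when B was packed with CompInt8.
  const size_t batch = 1;
  const size_t ws_size = MlasSQNBitsGemmBatchWorkspaceSize(M, N, K, batch, &params);
  std::vector<int8_t> workspace(ws_size);

  MlasSQNBitsGemmBatchPackedB(M, N, K, batch, &params, workspace.data(), thread_pool);
}
```

If the original float32 weights are needed again, `MlasNBitsGemmUnPackB` declared above dequantizes the packed blob back to fp32.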
