Integrate high-performance x64 gemm library to MLAS #17669

Merged
115 commits merged on Dec 19, 2023

Commits
41a8f2a
add jblas; add q4 perchannel of jblas to kernel, benchmark and test; …
luoyu-intel Sep 22, 2023
82720d0
workaround the compilation on gcc9
luoyu-intel Sep 22, 2023
c71490a
sync compiler flags
luoyu-intel Sep 27, 2023
0ee5377
get memory size without memory allocation
luoyu-intel Oct 11, 2023
44846fe
add amx_int8 kernel for computation
luoyu-intel Oct 11, 2023
2f9b26a
fix path error
luoyu-intel Oct 11, 2023
8a30f65
no need for AMX syscall
luoyu-intel Oct 11, 2023
483818c
Add MatMulNBitsCPU op
mengniwang95 Oct 30, 2023
e48430a
pass compile
luoyu-intel Oct 30, 2023
a15442f
update jblas. add comp_dtype=fp32
luoyu-intel Oct 31, 2023
438fb72
add more test cases
luoyu-intel Oct 31, 2023
b122560
add ut for mamtmul_nbits_cpu
luoyu-intel Oct 31, 2023
c758d53
fix compile errors
luoyu-intel Oct 31, 2023
20f2bb2
clang-format. update UT
luoyu-intel Oct 31, 2023
4b81a1d
update UT reference value.
luoyu-intel Oct 31, 2023
cb12f02
update jblas for MatMulNBits prepack
luoyu-intel Nov 2, 2023
a31a24d
add jblas execute path in matmul_nbits.cc
luoyu-intel Nov 2, 2023
0e5c032
add UT for matmul_nbits
luoyu-intel Nov 2, 2023
9aea379
pass pre-pack UT of matmul_nbits
luoyu-intel Nov 3, 2023
37a800c
pass UT for comp_int8
luoyu-intel Nov 3, 2023
31df05b
revert matmul_nbits_cpu
luoyu-intel Nov 3, 2023
35db583
revert code
luoyu-intel Nov 3, 2023
438a9e0
revert code
luoyu-intel Nov 3, 2023
c77403d
fix benchmark error
luoyu-intel Nov 3, 2023
a0044b2
update jblas
luoyu-intel Nov 14, 2023
3a689a4
add threading dispatch
luoyu-intel Nov 14, 2023
5ce5b58
update nbits attr
luoyu-intel Nov 14, 2023
b711f75
revert BlkQ4SymPerN
luoyu-intel Nov 15, 2023
2d6012d
fix test err
luoyu-intel Nov 15, 2023
0df5767
add AMX_INT8 kernels. move jblas option to the top.
luoyu-intel Nov 15, 2023
03559fe
typo
luoyu-intel Nov 15, 2023
5554365
remove warnings without jblas
luoyu-intel Nov 15, 2023
bdc7444
fix format
luoyu-intel Nov 15, 2023
c55d9fa
use global macro
luoyu-intel Nov 15, 2023
913cba2
fixed the perchannel block
luoyu-intel Nov 15, 2023
aafa2df
mark accuracy_level as optional.
luoyu-intel Nov 15, 2023
4ef9e46
benchmark fix
luoyu-intel Nov 15, 2023
57e16e7
use GetMlasPlatform to enable AMX
luoyu-intel Nov 15, 2023
66d2532
remove useless changes
luoyu-intel Nov 16, 2023
ac5e863
update descriptions for new functions
luoyu-intel Nov 16, 2023
2d0668c
sync compute_type with attribute order
luoyu-intel Nov 16, 2023
0180053
fix err
luoyu-intel Nov 16, 2023
bbf7bf3
check prepack pointer is nil
luoyu-intel Nov 16, 2023
7ca86ac
fix bug
luoyu-intel Nov 16, 2023
10fbfb9
Merge branch 'microsoft:main' into main
luoyu-intel Nov 16, 2023
8f7616c
add default value to accuracy_level
luoyu-intel Nov 16, 2023
c2a3be0
Merge branch 'microsoft:main' into main
luoyu-intel Nov 16, 2023
f5ada45
fix warnings of MSVC
luoyu-intel Nov 16, 2023
4c07742
resolve matmul_nbits.cc
luoyu-intel Nov 17, 2023
05ed728
Merge branch 'microsoft:main' into main
luoyu-intel Nov 17, 2023
26d9bd7
update doc of accuracy_level
luoyu-intel Nov 17, 2023
38a1054
move codes to jblas_gemm.cpp. pass UT
luoyu-intel Nov 17, 2023
89a31ec
revert clang-format
luoyu-intel Nov 17, 2023
bf85c35
add license
luoyu-intel Nov 17, 2023
f9bc5a3
fix warning and issue on windows
luoyu-intel Nov 20, 2023
a19c7b9
Merge branch 'microsoft:main' into main
luoyu-intel Nov 20, 2023
bed07c4
fix UT warning on windows
luoyu-intel Nov 20, 2023
c317006
fix bench warning
luoyu-intel Nov 20, 2023
4ee4946
fix bug of block_size=128
luoyu-intel Nov 20, 2023
6d26a1e
Fix bench bug of ldb
luoyu-intel Nov 20, 2023
99f84d8
reduce memory latency
luoyu-intel Nov 20, 2023
a732cc2
fix of lint and review
luoyu-intel Nov 21, 2023
747965a
fix cpp lint
luoyu-intel Nov 21, 2023
5cfec05
disable verbose
luoyu-intel Nov 21, 2023
b8540ca
Merge branch 'microsoft:main' into main
luoyu-intel Nov 21, 2023
e5fc97c
add auto-dispatch of accuray_level
luoyu-intel Nov 21, 2023
bf3d908
revert format
luoyu-intel Nov 21, 2023
b07ba34
bug and Lint fix
luoyu-intel Nov 21, 2023
5805f33
set USE_JBLAS on as default. fix some lint
luoyu-intel Nov 22, 2023
aba1835
Merge branch 'microsoft:main' into main
luoyu-intel Nov 22, 2023
c3c6663
Merge branch 'microsoft:main' into main
luoyu-intel Nov 22, 2023
2ad762a
add file endline
luoyu-intel Nov 22, 2023
5e18e3b
Merge branch 'microsoft:main' into main
luoyu-intel Nov 23, 2023
7cd6b78
set CompUndef equals CompFp32
luoyu-intel Nov 23, 2023
3a72778
Merge branch 'microsoft:main' into main
luoyu-intel Nov 23, 2023
5a67493
Merge branch 'microsoft:main' into main
luoyu-intel Nov 24, 2023
66f14e5
Merge branch 'microsoft:main' into main
luoyu-intel Nov 24, 2023
9669dfc
Merge branch 'microsoft:main' into main
luoyu-intel Nov 30, 2023
8f45be4
Merge branch 'microsoft:main' into main
luoyu-intel Dec 1, 2023
482cffe
use GetAttr
luoyu-intel Dec 1, 2023
025c572
change MLAS_NBITS to MLAS_SQNBITS
luoyu-intel Dec 1, 2023
810bad1
add MlasSQNBitsGemmBatchWorkspaceSize
luoyu-intel Dec 1, 2023
f41c5d3
renamed to MLAS_SQNBIT_COMPUTE_TYPE
luoyu-intel Dec 1, 2023
111d96d
add doc for last_call flag
luoyu-intel Dec 1, 2023
5842572
typo fix
luoyu-intel Dec 1, 2023
11ca78f
revise doc
luoyu-intel Dec 1, 2023
6521bba
fix benchmark compile
luoyu-intel Dec 1, 2023
d7601b5
Merge branch 'microsoft:main' into main
luoyu-intel Dec 5, 2023
e5aa4ec
add K dimension check
luoyu-intel Dec 5, 2023
bf915fb
bug fix, support padded zeropoints.
luoyu-intel Dec 5, 2023
b9aeaa2
enable jblas on x64 platform only
luoyu-intel Dec 6, 2023
46ae61a
Merge branch 'microsoft:main' into main
luoyu-intel Dec 6, 2023
d94b6bb
fix x64
luoyu-intel Dec 6, 2023
f5dc76e
fix return value
luoyu-intel Dec 6, 2023
6770c5b
support gcc before 9
luoyu-intel Dec 6, 2023
894cf11
enable jblas only with gcc and msvc
luoyu-intel Dec 6, 2023
4e32453
update doc from generated one.
luoyu-intel Dec 6, 2023
ae627d9
no jblas with minimal build
luoyu-intel Dec 6, 2023
c25eaa1
add two-session shared weight check
luoyu-intel Dec 8, 2023
f13d1a7
Merge branch 'microsoft:main' into main
luoyu-intel Dec 8, 2023
6cfa0ab
Merge branch 'microsoft:main' into main
luoyu-intel Dec 8, 2023
5ae9224
add SharedPrepackedWeights test.
luoyu-intel Dec 8, 2023
5f0e6aa
add small k size branch.
luoyu-intel Dec 8, 2023
6465e27
revert format
luoyu-intel Dec 8, 2023
48ba857
Merge branch 'microsoft:main' into main
luoyu-intel Dec 14, 2023
3d85292
fix val names
luoyu-intel Dec 14, 2023
2abed6d
add CompUndef to UT case
luoyu-intel Dec 14, 2023
291ef3b
use assert zps==nullptr
luoyu-intel Dec 14, 2023
a146fa5
change assert unrollk
luoyu-intel Dec 14, 2023
4bd9d04
use size_t for M, N ,K ...
luoyu-intel Dec 14, 2023
c64ccee
add name conversion explaination for launcher templates
luoyu-intel Dec 14, 2023
9536f15
use unique_ptr instead of raw poitner
luoyu-intel Dec 14, 2023
1f25140
remove const_cast
luoyu-intel Dec 14, 2023
f3872b8
change threshold to 16
luoyu-intel Dec 14, 2023
474d2c8
fix typo
luoyu-intel Dec 14, 2023
12 changes: 12 additions & 0 deletions cmake/CMakeLists.txt
@@ -87,6 +87,7 @@ option(onnxruntime_USE_QNN "Build with QNN support" OFF)
option(onnxruntime_USE_SNPE "Build with SNPE support" OFF)
option(onnxruntime_USE_RKNPU "Build with RKNPU support" OFF)
option(onnxruntime_USE_DNNL "Build with DNNL support" OFF)
option(onnxruntime_USE_JBLAS "Build MLAS with JBLAS support" ON)
option(onnxruntime_USE_JSEP "Build with JavaScript implemented kernels support" OFF)
option(onnxruntime_BUILD_UNIT_TESTS "Build ONNXRuntime unit tests" ON)
option(onnxruntime_BUILD_CSHARP "Build C# library" OFF)
@@ -1166,6 +1167,17 @@ if (onnxruntime_USE_DNNL)
add_compile_definitions(DNNL_OPENMP)
endif()

set(USE_JBLAS FALSE)
if (onnxruntime_USE_JBLAS AND NOT onnxruntime_MINIMAL_BUILD)
if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND onnxruntime_target_platform STREQUAL "x86_64")
add_compile_definitions(MLAS_JBLAS)
set(USE_JBLAS TRUE)
elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC" AND onnxruntime_target_platform STREQUAL "x64")
add_compile_definitions(MLAS_JBLAS)
set(USE_JBLAS TRUE)
endif()
endif()

# TVM EP
if (onnxruntime_USE_TVM)
if (NOT TARGET tvm)
16 changes: 14 additions & 2 deletions cmake/onnxruntime_mlas.cmake
@@ -45,6 +45,15 @@ endif()

set(ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas)

function(add_jblas)
add_subdirectory(${MLAS_SRC_DIR}/x86_64/jblas jblas)
target_link_libraries(onnxruntime_mlas PRIVATE jblas::jblas)
target_sources(onnxruntime_mlas PRIVATE
${MLAS_SRC_DIR}/jblas_gemm.cpp
)
set_target_properties(${target_name} PROPERTIES COMPILE_WARNING_AS_ERROR OFF)
endfunction()

#TODO: set MASM flags properly
function(setup_mlas_source_for_windows)

@@ -200,7 +209,6 @@ function(setup_mlas_source_for_windows)
${MLAS_SRC_DIR}/q4gemm_avx512.cpp
)
endif()

else()
target_sources(onnxruntime_mlas PRIVATE
${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp
@@ -566,7 +574,7 @@ else()
)
set_source_files_properties(${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f")
set_source_files_properties(${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmx.S PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f")
endif()
endif()

if(ONNXRUNTIME_MLAS_MULTI_ARCH)
onnxruntime_add_static_library(onnxruntime_mlas_x86_64 ${mlas_platform_srcs})
@@ -604,6 +612,10 @@ else()
target_sources(onnxruntime_mlas PRIVATE ${mlas_platform_srcs})
endif()

if(USE_JBLAS)
add_jblas()
endif()

foreach(mlas_target ${ONNXRUNTIME_MLAS_LIBS})
target_include_directories(${mlas_target} PRIVATE ${ONNXRUNTIME_ROOT}/core/mlas/inc ${MLAS_SRC_DIR})
onnxruntime_add_include_to_target(${mlas_target} ${GSL_TARGET})
2 changes: 2 additions & 0 deletions docs/ContribOperators.md
@@ -2824,6 +2824,8 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>size of each input feature</dd>
<dt><tt>N</tt> : int (required)</dt>
<dd>size of each output feature</dd>
<dt><tt>accuracy_level</tt> : int</dt>
<dd>The minimum accuracy level of input A, can be: 0(unset), 1(fp32), 2(fp16), 3(bf16), or 4(int8) (default unset). It is used to control how input A is quantized or downcast internally while doing computation, for example: 0 means input A will not be quantized or downcast while doing computation. 4 means input A can be quantized with the same block_size to int8 internally from type T1.</dd>
<dt><tt>bits</tt> : int (required)</dt>
<dd>number of bits used for weight quantization (default 4)</dd>
<dt><tt>block_size</tt> : int (required)</dt>
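The accuracy_level values documented above map one-to-one onto the MLAS_SQNBIT_COMPUTE_TYPE enum this PR adds to mlas_qnbit.h (see the header diff further down); the kernel simply casts the attribute value, and per the commit "set CompUndef equals CompFp32" an unset level falls back to the fp32 path. A minimal sketch of that mapping follows; the enum body is copied from the header below, while the helper name ComputeTypeFromAccuracyLevel is hypothetical and used only for illustration.

```cpp
// Sketch: how the accuracy_level attribute selects the MLAS compute path.
// The enum mirrors onnxruntime/core/mlas/inc/mlas_qnbit.h in this PR;
// the cast mirrors what MatMulNBits::PrePack does further down in this diff.
#include <cstdint>

typedef enum {
  CompUndef = 0,  // 0 (unset): treated like CompFp32 by this PR
  CompFp32 = 1,   // 1: fp32 input, fp32 accumulator
  CompFp16 = 2,   // 2: fp16 input, fp16 accumulator
  CompBf16 = 3,   // 3: bf16 input, fp32 accumulator
  CompInt8 = 4    // 4: A dynamically quantized per block to int8, int32 accumulator
} MLAS_SQNBIT_COMPUTE_TYPE;

// Hypothetical helper; the operator performs this cast inline.
inline MLAS_SQNBIT_COMPUTE_TYPE ComputeTypeFromAccuracyLevel(int64_t accuracy_level) {
  return static_cast<MLAS_SQNBIT_COMPUTE_TYPE>(accuracy_level);
}
```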
134 changes: 131 additions & 3 deletions onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
@@ -20,30 +20,158 @@ class MatMulNBits final : public OpKernel {
K_{narrow<size_t>(info.GetAttr<int64_t>("K"))},
N_{narrow<size_t>(info.GetAttr<int64_t>("N"))},
block_size_{narrow<size_t>(info.GetAttr<int64_t>("block_size"))},
nbits_{narrow<size_t>(info.GetAttr<int64_t>("bits"))} {
nbits_{narrow<size_t>(info.GetAttr<int64_t>("bits"))},
accuracy_level_{info.GetAttr<int64_t>("accuracy_level")} {
ORT_ENFORCE(nbits_ == 4,
"Only 4b quantization is supported for MatMulNBits op, additional bits support is planned.");
is_asym_ = info.GetInputCount() >= 4;
const Tensor* tensor_B = nullptr;
const Tensor* tensor_scale = nullptr;
const Tensor* tensor_zero_point = nullptr;
bool B_constant = info.TryGetConstantInput(1, &tensor_B);
bool scale_constant = info.TryGetConstantInput(2, &tensor_scale);
bool zero_point_constant = info.TryGetConstantInput(3, &tensor_zero_point);
all_constant_ = B_constant && scale_constant;
all_constant_ = is_asym_ ? all_constant_ && zero_point_constant : all_constant_;
}

Status Compute(OpKernelContext* context) const override;

Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* prepacked_weights) override;

Status UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers, int input_idx,
/*out*/ bool& used_shared_buffers) override;

private:
const size_t K_;
const size_t N_;
const size_t block_size_;
const size_t nbits_;
const int64_t accuracy_level_;
const bool column_wise_quant_{true};
IAllocatorUniquePtr<void> packed_b_;
size_t packed_b_size_{0};
bool is_asym_{false};
bool all_constant_{false};
};

Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc,
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* prepacked_weights) {
is_packed = false;
if (!all_constant_) {
return Status::OK();
}
auto compt_type = static_cast<MLAS_SQNBIT_COMPUTE_TYPE>(accuracy_level_);
MLAS_THREADPOOL* pool = NULL;
if (input_idx == 1) {
packed_b_size_ = MlasNBitsGemmPackBSize(N_, K_, block_size_, static_cast<int>(nbits_), is_asym_, compt_type);
if (packed_b_size_ == 0) return Status::OK();
auto qptr = tensor.Data<uint8_t>();
packed_b_ = IAllocator::MakeUniquePtr<void>(alloc, packed_b_size_, true);
if (packed_b_ == nullptr) {
return Status::OK();
}
std::memset(packed_b_.get(), 0, packed_b_size_);
MlasNBitsGemmPackB(packed_b_.get(), qptr, nullptr, nullptr, N_, K_, K_, block_size_, static_cast<int>(nbits_),
is_asym_, false, compt_type, pool);
if (prepacked_weights) {
prepacked_weights->buffers_.push_back(std::move(packed_b_));
prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
}
is_packed = true;
}
if (input_idx == 2 && packed_b_ != nullptr) {
auto sptr = tensor.Data<float>();
MlasNBitsGemmPackB(packed_b_.get(), nullptr, sptr, nullptr, N_, K_, K_, block_size_, static_cast<int>(nbits_),
is_asym_, !is_asym_, compt_type, pool);
if (prepacked_weights) {
prepacked_weights->buffers_.push_back(std::move(packed_b_));
prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
}
is_packed = true;
}
if (input_idx == 3 && packed_b_ != nullptr) {
auto zptr = tensor.Data<uint8_t>();
MlasNBitsGemmPackB(packed_b_.get(), nullptr, nullptr, zptr, N_, K_, K_, block_size_, static_cast<int>(nbits_),
is_asym_, is_asym_, compt_type, pool);
if (prepacked_weights) {
prepacked_weights->buffers_.push_back(std::move(packed_b_));
prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
}
is_packed = true;
}

return Status::OK();
}

Status MatMulNBits::UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers, int input_idx,
/*out*/ bool& used_shared_buffers) {
used_shared_buffers = false;
// Pack three tensors into one buffer
if (input_idx == 1) {
used_shared_buffers = true;
packed_b_ = std::move(prepacked_buffers[0]);
}
if (input_idx == 2) {
used_shared_buffers = true;
packed_b_ = std::move(prepacked_buffers[0]);
}
if (input_idx == 3) {
used_shared_buffers = true;
packed_b_ = std::move(prepacked_buffers[0]);
}
return Status::OK();
}

Status MatMulNBits::Compute(OpKernelContext* ctx) const {
concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool();

const Tensor* a = ctx->Input<Tensor>(0);
const auto* a_data = a->Data<float>();

if (packed_b_.get()) {
TensorShape b_shape({static_cast<int64_t>(N_), static_cast<int64_t>(K_)});

MatMulComputeHelper helper;
ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true));

Tensor* y = ctx->Output(0, helper.OutputShape());

// Bail out early if the output is going to be empty
if (y->Shape().Size() == 0) return Status::OK();

auto* y_data = y->MutableData<float>();

const size_t max_len = helper.OutputOffsets().size();
const size_t M = static_cast<size_t>(helper.M());
const size_t N = static_cast<size_t>(helper.N());
const size_t K = static_cast<size_t>(helper.K());
const size_t lda = helper.Lda(false);
std::vector<MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS> gemm_params(max_len);
AllocatorPtr allocator;
auto status = ctx->GetTempSpaceAllocator(&allocator);
ORT_RETURN_IF_ERROR(status);
for (size_t i = 0; i < max_len; i++) {
gemm_params[i].A = a_data + helper.LeftOffsets()[i];
gemm_params[i].lda = lda;
gemm_params[i].B = packed_b_.get();
gemm_params[i].C = y_data + helper.OutputOffsets()[i];
gemm_params[i].ldc = N;
}
auto ws_size = MlasSQNBitsGemmBatchWorkspaceSize(M, N, K, max_len, gemm_params.data());
// workspace for activation process(dynamic quantization and others)
auto ws_ptr = IAllocator::MakeUniquePtr<int8_t>(allocator, ws_size);
MlasSQNBitsGemmBatchPackedB(M, N, K, max_len, gemm_params.data(), ws_ptr.get(),
thread_pool);
return Status::OK();
}

const Tensor* b = ctx->Input<Tensor>(1);
const Tensor* scales = ctx->Input<Tensor>(2);
const Tensor* zero_points = ctx->Input<Tensor>(3);

const auto* a_data = a->Data<float>();
const uint8_t* b_data = b->Data<uint8_t>();
const auto* scales_data = scales->Data<float>();
const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->Data<uint8_t>();
7 changes: 7 additions & 0 deletions onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -3359,6 +3359,13 @@ Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored
.Attr("N", "size of each output feature", AttributeProto::INT)
.Attr("bits", "number of bits used for weight quantization (default 4)", AttributeProto::INT)
.Attr("block_size", "number of groupsize used for weight quantization,(default 128). It needs to be a power of 2 and not smaller than 16.", AttributeProto::INT)
.Attr("accuracy_level",
"The minimum accuracy level of input A, can be: 0(unset), 1(fp32), 2(fp16), 3(bf16), or 4(int8) "
"(default unset). It is used to control how input A is quantized or downcast internally while "
"doing computation, for example: 0 means input A will not be quantized or downcast while doing "
"computation. 4 means input A can be quantized with the same block_size to int8 internally from "
"type T1.",
AttributeProto::INT, static_cast<int64_t>(0))
.Input(0, "A", "The input tensor, not quantized", "T1")
.Input(1, "B", "1-dimensional data blob", "T2")
.Input(2, "scales", "quantization scale", "T1")
141 changes: 141 additions & 0 deletions onnxruntime/core/mlas/inc/mlas_qnbit.h
@@ -77,3 +77,144 @@ MlasIsSQNBitGemmAvailable(
size_t BlkBitWidth,
size_t BlkLen
);

/**
* @brief Define compute types of block quantization
*/
typedef enum {
CompUndef = 0, /*!< undef */
CompFp32 = 1, /*!< input fp32, accumulator fp32 */
CompFp16 = 2, /*!< input fp16, accumulator fp16 */
CompBf16 = 3, /*!< input bf16, accumulator fp32 */
CompInt8 = 4 /*!< input int8, accumulator int32 */
} MLAS_SQNBIT_COMPUTE_TYPE;

/**
* @brief Data parameters for NBits GEMM routine
* C = A * B
* A, C must be a float32 matrix
* B must be a packed nbits blob
* All except C are [in] parameters
*/
struct MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS {
const float* A = nullptr; /**< address of A (float32 matrix)*/
const void* B = nullptr; /**< address of B (packed nbits blob)*/
float* C = nullptr; /**< address of result matrix */
size_t lda = 0; /**< leading dimension of A */
size_t ldc = 0; /**< leading dimension of C*/
};

/**
* @brief Compute the byte size of the parameter combination
*
* @param N the number of columns of matrix B.
* @param K the number of rows of matrix B.
* @param block_size size of the block to quantize, elements from the same block share the same
* scale and zero point
* @param nbits number of bits used for weight quantization
* @param is_asym flag for asymmetric quantization
* @param comp_type specify input data type and accumulator data type
* @return size of the packing buffer, 0 if the operation is not yet supported.
*/
size_t MLASCALL
MlasNBitsGemmPackBSize(
size_t N, size_t K, size_t block_size, int nbits, bool is_asym, MLAS_SQNBIT_COMPUTE_TYPE comp_type
);

/**
* @brief Prepack tensor data from n-bit quantized data, scale and zero point buffers.
*
* @param PackedBuf packed data buffer
* @param QData quantized data buffer
* @param Scale scale pointer
* @param Zp zero point pointer
* @param N the number of columns of matrix B.
* @param K the number of rows of matrix B.
* @param ldb leading dimension of B
* @param block_size size of the block to quantize, elements from the same block share the same
* scale and zero point
* @param nbits number of bits used for weight quantization (default 4)
* @param is_asym flag for asymmetric quantization
* @param comp_type specify input data type and accumulator data type
* @param last_call flag to activate the epilogue process of packB. OpKernel::PrePack will query input tensor
* one by one: QData, Scale, Zp (if is_asym is true). But kernel prefers to pack all tensors into one blob data where
* they can share the common attributes like: block_size. Meanwhile, kernel has some pre-computations to speed up
* inference which require that all blob data are ready. So, you need to set this flag to true when passing Scale
* (is_asym is false) and Zp(is_asym is true).
* @param thread_pool
*/
void MLASCALL
MlasNBitsGemmPackB(
void* PackedBuf,
const uint8_t* QData,
const float* Scale,
const uint8_t* Zp,
size_t N,
size_t K,
size_t ldb,
size_t block_size,
int nbits,
bool is_asym,
bool last_call,
MLAS_SQNBIT_COMPUTE_TYPE comp_type,
MLAS_THREADPOOL* thread_pool
);
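To make the last_call contract above concrete, here is a minimal usage sketch (editorial, not part of this header) that packs a symmetric 4-bit weight the way MatMulNBits::PrePack does earlier in this diff: the quantized data is packed first with last_call=false, then the scales with last_call=true because no zero-point tensor follows. The include paths, buffer handling, and the choice of CompInt8 as comp_type are assumptions for illustration.

```cpp
// Minimal sketch: packing a symmetric (no zero point) 4-bit weight matrix B.
// Mirrors the call order used by MatMulNBits::PrePack in this PR.
#include <vector>
#include "mlas.h"        // MLAS_THREADPOOL (include paths depend on the build setup)
#include "mlas_qnbit.h"

std::vector<int8_t> PackSymmetric4BitB(const uint8_t* q_data, const float* scales,
                                       size_t N, size_t K, size_t block_size) {
  const int nbits = 4;
  const bool is_asym = false;       // no zero points, so Scale is the final tensor
  const auto comp_type = CompInt8;  // e.g. accuracy_level = 4

  const size_t packed_size =
      MlasNBitsGemmPackBSize(N, K, block_size, nbits, is_asym, comp_type);
  std::vector<int8_t> packed(packed_size);  // zero-initialized, like the memset in PrePack
  if (packed_size == 0) return packed;      // 0 means this combination is not supported

  // First call: quantized data only; the packing epilogue must not run yet.
  MlasNBitsGemmPackB(packed.data(), q_data, nullptr, nullptr, N, K, /*ldb=*/K,
                     block_size, nbits, is_asym, /*last_call=*/false, comp_type,
                     /*thread_pool=*/nullptr);

  // Second call: scales; this is the last tensor, so last_call triggers the epilogue
  // (the pre-computations that require the complete blob).
  MlasNBitsGemmPackB(packed.data(), nullptr, scales, nullptr, N, K, /*ldb=*/K,
                     block_size, nbits, is_asym, /*last_call=*/true, comp_type,
                     /*thread_pool=*/nullptr);
  return packed;
}
```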

/**
* @brief Unpack and dequantize to fp32
*
* @param FpData unpacked float32 data
* @param PackedBuf quantized and packed data
* @param N the number of columns of matrix B.
* @param K the number of rows of matrix B.
* @param ldb leading dimension of B
* @param thread_pool
*/
void MLASCALL
MlasNBitsGemmUnPackB(
float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, MLAS_THREADPOOL* thread_pool
);

/**
* @brief Get the workspace size required by computation.
*
* @param[in] M row size of matrix A and C
* @param[in] N column size of matrix B and C
* @param[in] K column size of matrix A and row size of matrix B
* @param[in] BatchN number of batches
* @param[inout] DataParams An array (size BatchN) of parameter blocks
* @return Workspace size in bytes
*/
size_t MLASCALL
MlasSQNBitsGemmBatchWorkspaceSize(
const size_t M,
const size_t N,
const size_t K,
const size_t BatchN,
const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams
);

/**
* @brief Batched GEMM: C = A * B
* A, C must be a float32 matrix
* B must be a packed nbits blob
*
* @param[in] M row size of matrix A and C
* @param[in] N column size of matrix B and C
* @param[in] K column size of matrix A and row size of matrix B
* @param[in] BatchN number of batches
* @param[inout] DataParams An array (size BatchN) of parameter blocks
* @param[in] WorkSpace temporary buffer
* @param[in] ThreadPool
* @return
*/
void MLASCALL
MlasSQNBitsGemmBatchPackedB(
const size_t M,
const size_t N,
const size_t K,
const size_t BatchN,
const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams,
void* WorkSpace,
MLAS_THREADPOOL* ThreadPool = nullptr
);
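A matching sketch (also editorial, not part of this header) of running the batched GEMM once B is packed, following the same sequence as MatMulNBits::Compute above: fill one parameter block per batch, size the workspace, then launch. The single-batch setup and buffer names are assumptions for illustration.

```cpp
// Minimal sketch: single-batch C = A * packed_B via the batched interface above.
// In onnxruntime the workspace comes from the kernel's temp-space allocator.
#include <vector>
#include "mlas.h"
#include "mlas_qnbit.h"

void RunPackedGemm(const float* A, const void* packed_b, float* C,
                   size_t M, size_t N, size_t K, MLAS_THREADPOOL* thread_pool) {
  MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS params;
  params.A = A;
  params.lda = K;       // row-major A
  params.B = packed_b;  // blob produced by MlasNBitsGemmPackB
  params.C = C;
  params.ldc = N;       // row-major C

  // Workspace for activation processing (e.g. dynamic quantization of A under CompInt8).
  const size_t ws_size = MlasSQNBitsGemmBatchWorkspaceSize(M, N, K, /*BatchN=*/1, &params);
  std::vector<int8_t> workspace(ws_size);

  MlasSQNBitsGemmBatchPackedB(M, N, K, /*BatchN=*/1, &params, workspace.data(), thread_pool);
}
```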