[MLAS AArch64] SQNBitGemm CompInt8 kernel #18953

Merged: 36 commits, Jan 13, 2024
Changes from 30 commits
Commits (36)
8940c0a: only register q4gemm benchmarks if q4gemm is available (edgchen1, Dec 4, 2023)
a6a8ce6: some mlas cmake updates (edgchen1, Dec 5, 2023)
53a46ca: change BlkLen from template param to function param (edgchen1, Dec 12, 2023)
e2a9eee: Save work (edgchen1, Dec 14, 2023)
966a915: only enable benchmark if available (edgchen1, Dec 14, 2023)
b59e7e1: handle workspace in benchmark (edgchen1, Dec 15, 2023)
585103b: QuantizeARow neon impl1 (edgchen1, Dec 15, 2023)
c26cef4: dot compint8 neon impl (edgchen1, Dec 15, 2023)
1b7d81b: use single workspace pointer in interface, get matmul_nbits working (edgchen1, Dec 16, 2023)
f7e3db5: Merge remote-tracking branch 'origin/main' into edgchen1/sqnbitgemm_q… (edgchen1, Dec 27, 2023)
71bd3a9: renaming and cleanup (edgchen1, Dec 27, 2023)
f7127f9: try different comp types in matmulnbits (edgchen1, Dec 28, 2023)
0060f55: Merge remote-tracking branch 'origin/main' into edgchen1/sqnbitgemm_q… (edgchen1, Dec 28, 2023)
b3147c6: rename enum, add doc (edgchen1, Dec 28, 2023)
789bcdc: change quant b params from uint8_t* to std::byte* (edgchen1, Dec 28, 2023)
039dd92: handle CompUndef (edgchen1, Dec 28, 2023)
cb9f428: check if dot product instructions are available before setting SQNBit… (edgchen1, Dec 29, 2023)
437ad52: try to fix compile issue (edgchen1, Dec 29, 2023)
241ca27: move zero initialize out of unrolled loop (edgchen1, Dec 29, 2023)
53e2ae2: update comment (edgchen1, Jan 2, 2024)
d5b26b4: split out float conversion (edgchen1, Jan 2, 2024)
02cf7b3: remove impl0_reference (edgchen1, Jan 2, 2024)
5b4a86c: use thread per gemm in prepare workspace fn, reorder include (edgchen1, Jan 2, 2024)
61998ea: make pointer const (edgchen1, Jan 3, 2024)
fe7f0e7: Merge remote-tracking branch 'origin/main' into edgchen1/sqnbitgemm_q… (edgchen1, Jan 3, 2024)
d54cbd9: remove unneeded and (edgchen1, Jan 10, 2024)
7d8753c: Merge remote-tracking branch 'origin/main' into edgchen1/sqnbitgemm_q… (edgchen1, Jan 10, 2024)
6d88a0b: move code from merge conflict (edgchen1, Jan 10, 2024)
ccaa994: pack quant b data (edgchen1, Jan 11, 2024)
cff3cb4: get matmulnbits working, add docs (edgchen1, Jan 11, 2024)
f8aba0c: Merge remote-tracking branch 'origin/main' into edgchen1/sqnbitgemm_q… (edgchen1, Jan 11, 2024)
33e6dd9: use threadpool to pack b data (edgchen1, Jan 11, 2024)
4cd2474: shorten names, update docs (edgchen1, Jan 11, 2024)
9244a3f: rename another function, add check for implementation in MlasSQNBitGe… (edgchen1, Jan 11, 2024)
86f84ea: move b_data_block_offset out of unrolled loop body (edgchen1, Jan 12, 2024)
2337375: move b data offset out of unrolled loop in compfp32 kernel (edgchen1, Jan 12, 2024)
49 changes: 38 additions & 11 deletions cmake/onnxruntime_mlas.cmake
@@ -1,14 +1,17 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

-set(MLAS_SRC_DIR ${ONNXRUNTIME_ROOT}/core/mlas/lib)
+set(MLAS_ROOT ${ONNXRUNTIME_ROOT}/core/mlas)
+set(MLAS_SRC_DIR ${MLAS_ROOT}/lib)
+set(MLAS_INC_DIR ${MLAS_ROOT}/inc)

#
# All hardware agnostic source files here
# hardware specific files would cause trouble in
# multi-target build
#
onnxruntime_add_static_library(onnxruntime_mlas
+${MLAS_SRC_DIR}/mlasi.h
${MLAS_SRC_DIR}/platform.cpp
${MLAS_SRC_DIR}/threading.cpp
${MLAS_SRC_DIR}/sgemm.cpp
@@ -33,9 +36,18 @@ onnxruntime_add_static_library(onnxruntime_mlas
${MLAS_SRC_DIR}/qpostprocessor.cpp
${MLAS_SRC_DIR}/qlgavgpool.cpp
${MLAS_SRC_DIR}/qdwconv_kernelsize.cpp
+${MLAS_SRC_DIR}/sqnbitgemm.h
${MLAS_SRC_DIR}/sqnbitgemm.cpp
)

+target_sources(onnxruntime_mlas PRIVATE
+${MLAS_INC_DIR}/mlas_float16.h
+${MLAS_INC_DIR}/mlas_gemm_postprocessor.h
+${MLAS_INC_DIR}/mlas_q4.h
+${MLAS_INC_DIR}/mlas_qnbit.h
+${MLAS_INC_DIR}/mlas.h
+)

if (NOT onnxruntime_ORT_MINIMAL_BUILD)
target_sources(onnxruntime_mlas PRIVATE
${MLAS_SRC_DIR}/q4_dq.cpp
@@ -46,7 +58,7 @@ endif()
set(ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas)

function(add_jblas)
-add_subdirectory(${MLAS_SRC_DIR}/x86_64/jblas jblas)
+add_subdirectory(${MLAS_SRC_DIR}/x86_64/jblas jblas)
target_link_libraries(onnxruntime_mlas PRIVATE jblas::jblas)
target_sources(onnxruntime_mlas PRIVATE
${MLAS_SRC_DIR}/jblas_gemm.cpp
@@ -143,10 +155,6 @@ function(setup_mlas_source_for_windows)
target_sources(onnxruntime_mlas PRIVATE
${MLAS_SRC_DIR}/arm/sgemmc.cpp
)
-# it should be removed after Visual Stuio is upgraded to 17.7
-if (MSVC)
-add_compile_options("-d2SSAOptimizer-")
-endif()
elseif(onnxruntime_target_platform STREQUAL "x64")

file(GLOB_RECURSE mlas_platform_srcs_avx CONFIGURE_DEPENDS
@@ -300,8 +308,8 @@ else()
if(APPLE)
get_target_property(ONNXRUNTIME_MLAS_MACOSX_ARCH onnxruntime_mlas OSX_ARCHITECTURES)
endif()
-list(LENGTH ONNXRUNTIME_MLAS_MACOSX_ARCH ONNXRUNTIME_MLAS_MACOSX_ARCH_LENGH)
-if(ONNXRUNTIME_MLAS_MACOSX_ARCH_LENGH GREATER 1)
+list(LENGTH ONNXRUNTIME_MLAS_MACOSX_ARCH ONNXRUNTIME_MLAS_MACOSX_ARCH_LENGTH)
+if(ONNXRUNTIME_MLAS_MACOSX_ARCH_LENGTH GREATER 1)
set(ONNXRUNTIME_MLAS_MULTI_ARCH TRUE)
endif()
#If ONNXRUNTIME_MLAS_MULTI_ARCH is true, we need to go through every if branch below
@@ -348,6 +356,8 @@ else()
${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp
)
+set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp
+PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod")
if (NOT APPLE)
set(mlas_platform_srcs
${mlas_platform_srcs}
@@ -617,10 +627,12 @@ if(USE_JBLAS)
endif()

foreach(mlas_target ${ONNXRUNTIME_MLAS_LIBS})
-target_include_directories(${mlas_target} PRIVATE ${ONNXRUNTIME_ROOT}/core/mlas/inc ${MLAS_SRC_DIR})
+target_include_directories(${mlas_target} PRIVATE ${MLAS_INC_DIR} ${MLAS_SRC_DIR})
onnxruntime_add_include_to_target(${mlas_target} ${GSL_TARGET})

+set_target_properties(${mlas_target} PROPERTIES FOLDER "ONNXRuntime")
endforeach()
-set_target_properties(onnxruntime_mlas PROPERTIES FOLDER "ONNXRuntime")

if (WIN32)
target_compile_options(onnxruntime_mlas PRIVATE "$<$<COMPILE_LANGUAGE:CXX>:/wd6385>" "$<$<COMPILE_LANGUAGE:CXX>:/wd4127>")
if (onnxruntime_ENABLE_STATIC_ANALYSIS)
@@ -636,6 +648,21 @@ if (NOT onnxruntime_BUILD_SHARED_LIB)
FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR})
endif()

+# set up source group for MLAS source files
+block()
+set(source_group_srcs)
+foreach(mlas_target ${ONNXRUNTIME_MLAS_LIBS})
+get_target_property(mlas_target_srcs ${mlas_target} SOURCES)
+foreach(mlas_target_src ${mlas_target_srcs})
+cmake_path(IS_PREFIX MLAS_ROOT ${mlas_target_src} in_mlas_root)
+if(in_mlas_root)
+list(APPEND source_group_srcs ${mlas_target_src})
+endif()
+endforeach()
+endforeach()
+source_group(TREE ${MLAS_ROOT} FILES ${source_group_srcs})
+endblock()


if (NOT onnxruntime_ORT_MINIMAL_BUILD)

@@ -647,7 +674,7 @@ if (NOT onnxruntime_ORT_MINIMAL_BUILD)
onnxruntime_add_executable(onnxruntime_mlas_q4dq
${MLAS_SRC_DIR}/q4_dq_cli.cpp
)
-target_include_directories(onnxruntime_mlas_q4dq PRIVATE ${ONNXRUNTIME_ROOT}/core/mlas/inc ${MLAS_SRC_DIR})
+target_include_directories(onnxruntime_mlas_q4dq PRIVATE ${MLAS_INC_DIR} ${MLAS_SRC_DIR})
set_target_properties(onnxruntime_mlas_q4dq PROPERTIES FOLDER "ONNXRuntimeTest")

target_link_libraries(onnxruntime_mlas_q4dq PRIVATE ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common)
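Note that the `-march=armv8.2-a+dotprod` flag above only lets the compiler emit SDOT/UDOT; commit cb9f428 additionally gates kernel registration on a runtime CPU check so dispatch stays safe on cores without the extension. A hedged sketch of what such a check conventionally looks like on Linux/AArch64 (illustrative only, not the PR's platform.cpp code):

```cpp
#if defined(__linux__) && defined(__aarch64__)
#include <asm/hwcap.h>  // HWCAP_ASIMDDP
#include <sys/auxv.h>   // getauxval

// True if the CPU advertises the Armv8.2 dot product extension
// (the SDOT/UDOT instructions the CompInt8 kernel relies on).
bool HasDotProductInstructions() {
  return (getauxval(AT_HWCAP) & HWCAP_ASIMDDP) != 0;
}
#endif  // defined(__linux__) && defined(__aarch64__)
```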
107 changes: 76 additions & 31 deletions onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
@@ -64,6 +64,9 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat
if (!all_constant_) {
return Status::OK();
}

+#if defined(MLAS_JBLAS)

auto compt_type = static_cast<MLAS_SQNBIT_COMPUTE_TYPE>(accuracy_level_);
MLAS_THREADPOOL* pool = NULL;
if (input_idx == 1) {
@@ -101,12 +104,32 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat
is_packed = true;
}

+#else  // defined(MLAS_JBLAS)

+if (input_idx == 1) {
+packed_b_size_ = MlasSQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_);
+if (packed_b_size_ == 0) return Status::OK();
+auto qptr = tensor.DataRaw();
+packed_b_ = IAllocator::MakeUniquePtr<void>(alloc, packed_b_size_, true);
+MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, qptr, packed_b_.get());
+if (prepacked_weights) {
+prepacked_weights->buffers_.push_back(std::move(packed_b_));
+prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
+}
+is_packed = true;
+}

+#endif  // defined(MLAS_JBLAS)

return Status::OK();
}
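
The non-JBLAS PrePack branch above reduces to a size query followed by a pack call. A condensed sketch of that flow, with ORT's allocator plumbing swapped for a plain std::vector (an illustrative assumption; the operator itself uses IAllocator as shown in the diff):

```cpp
#include <cstddef>
#include <vector>

#include "mlas_qnbit.h"  // MlasSQNBitGemmPackQuantBData* declarations

// Sketch of the PrePack flow above. A packed size of zero means no packed
// format is available for this configuration, so the caller keeps the raw
// quantized tensor data instead.
std::vector<std::byte> PackQuantB(const void* quant_b_data,
                                  size_t N, size_t K,
                                  size_t nbits, size_t block_size) {
  const size_t packed_size = MlasSQNBitGemmPackQuantBDataSize(N, K, nbits, block_size);
  std::vector<std::byte> packed(packed_size);
  if (packed_size > 0) {
    MlasSQNBitGemmPackQuantBData(N, K, nbits, block_size, quant_b_data, packed.data());
  }
  return packed;
}
```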

Status MatMulNBits::UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers, int input_idx,
/*out*/ bool& used_shared_buffers) {
used_shared_buffers = false;

+#if defined(MLAS_JBLAS)

// Pack three tensors into one buffer
if (input_idx == 1) {
used_shared_buffers = true;
@@ -120,6 +143,15 @@ Status MatMulNBits::UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prep
used_shared_buffers = true;
packed_b_ = std::move(prepacked_buffers[0]);
}

+#else  // defined(MLAS_JBLAS)

+if (input_idx == 1) {
+used_shared_buffers = true;
+packed_b_ = std::move(prepacked_buffers[0]);
+}

+#endif  // defined(MLAS_JBLAS)
return Status::OK();
}

@@ -129,6 +161,8 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
const Tensor* a = ctx->Input<Tensor>(0);
const auto* a_data = a->Data<float>();

+#if defined(MLAS_JBLAS)

if (packed_b_.get()) {
TensorShape b_shape({static_cast<int64_t>(N_), static_cast<int64_t>(K_)});

@@ -158,18 +192,18 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
gemm_params[i].C = y_data + helper.OutputOffsets()[i];
gemm_params[i].ldc = N;
}
-auto ws_size = MlasSQNBitsGemmBatchWorkspaceSize(M, N, K, max_len, gemm_params.data());
+auto ws_size = MlasSQNBitsGemmBatchPackedBWorkspaceSize(M, N, K, max_len, gemm_params.data());
// workspace for activation process(dynamic quantization and others)
auto ws_ptr = IAllocator::MakeUniquePtr<int8_t>(allocator, ws_size);
MlasSQNBitsGemmBatchPackedB(M, N, K, max_len, gemm_params.data(), ws_ptr.get(),
thread_pool);
return Status::OK();
}

-const Tensor* b = ctx->Input<Tensor>(1);
+#endif  // defined(MLAS_JBLAS)

const Tensor* scales = ctx->Input<Tensor>(2);
const Tensor* zero_points = ctx->Input<Tensor>(3);
-const uint8_t* b_data = b->Data<uint8_t>();
const auto* scales_data = scales->Data<float>();
const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->Data<uint8_t>();

@@ -181,8 +215,9 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
Tensor* y = ctx->Output(0, helper.OutputShape());

// Bail out early if the output is going to be empty
-if (y->Shape().Size() == 0)
+if (y->Shape().Size() == 0) {
return Status::OK();
+}

auto* y_data = y->MutableData<float>();

@@ -192,36 +227,46 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
const size_t K = static_cast<size_t>(helper.K());
const size_t lda = helper.Lda(false);

-if (MlasIsSQNBitGemmAvailable(nbits_, block_size_)) {
-// number of bytes or elements between adjacent matrices
-size_t b_data_matrix_stride_in_bytes, b_scale_matrix_stride, b_zero_point_matrix_stride_in_bytes;
-MlasBlockwiseQuantizedBufferSizes(static_cast<int>(nbits_), static_cast<int>(block_size_), /* columnwise */ true,
-static_cast<int>(K), static_cast<int>(N),
-b_data_matrix_stride_in_bytes, b_scale_matrix_stride,
-&b_zero_point_matrix_stride_in_bytes);

-const size_t b_matrix_size = K * N;

-InlinedVector<MLAS_SQNBIT_GEMM_DATA_PARAMS> data(batch_count);
-for (size_t i = 0; i < batch_count; ++i) {
-const size_t b_matrix_offset = helper.RightOffsets()[i] / b_matrix_size;

-data[i].A = a_data + helper.LeftOffsets()[i];
-data[i].lda = lda;
-data[i].QuantBData = b_data + b_matrix_offset * b_data_matrix_stride_in_bytes;
-data[i].QuantBScale = scales_data + b_matrix_offset * b_scale_matrix_stride;
-data[i].QuantBZeroPoint = zero_points_data != nullptr
-? zero_points_data + b_matrix_offset * b_zero_point_matrix_stride_in_bytes
-: nullptr;
-data[i].C = y_data + helper.OutputOffsets()[i];
-data[i].ldc = N;
+const bool has_single_b_matrix = std::all_of(helper.RightOffsets().begin(), helper.RightOffsets().end(),
+[](size_t offset) { return offset == 0; });

+if (has_single_b_matrix && packed_b_) {
+for (int64_t accuracy_level = accuracy_level_;
+accuracy_level >= static_cast<int64_t>(CompMostAccurate);
+--accuracy_level) {
+const auto compute_type = static_cast<MLAS_SQNBIT_GEMM_COMPUTE_TYPE>(accuracy_level);
+if (MlasIsSQNBitGemmAvailable(M, N, K, nbits_, block_size_, compute_type)) {
+IAllocatorUniquePtr<std::byte> workspace{};
+if (const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, batch_count,
+nbits_, block_size_, compute_type);
+workspace_size > 0) {
+AllocatorPtr allocator;
+ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator));
+workspace = IAllocator::MakeUniquePtr<std::byte>(allocator, workspace_size);
+}

+InlinedVector<MLAS_SQNBIT_GEMM_DATA_PARAMS> data(batch_count);
+for (size_t i = 0; i < batch_count; ++i) {
+data[i].A = a_data + helper.LeftOffsets()[i];
+data[i].lda = lda;
+data[i].QuantBData = packed_b_.get();
+data[i].QuantBScale = scales_data;
+data[i].QuantBZeroPoint = zero_points_data;
+data[i].C = y_data + helper.OutputOffsets()[i];
+data[i].ldc = N;
+}

+MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type, data.data(), workspace.get(),
+thread_pool);

+return Status::OK();
+}
+}

-MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, data.data(), thread_pool);

-return Status::OK();
}

+const Tensor* b = ctx->Input<Tensor>(1);
+const uint8_t* b_data = b->Data<uint8_t>();

const size_t ldb = helper.Ldb(true);

AllocatorPtr allocator;
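
Taken together, the reworked Compute() path walks accuracy_level_ down toward CompMostAccurate until MlasIsSQNBitGemmAvailable accepts a compute type, sizes an optional workspace (CompInt8 uses it for the dynamically quantized A blocks), fills one MLAS_SQNBIT_GEMM_DATA_PARAMS per GEMM, and makes a single MlasSQNBitGemmBatch call. A condensed, hedged sketch of that calling convention for one shared prepacked B matrix (buffer setup simplified; names follow the diff):

```cpp
#include <cstddef>
#include <vector>

#include "mlas_qnbit.h"  // MLAS quantized n-bit GEMM API, as used in the diff

// Condensed sketch of the Compute() flow above for one prepacked B matrix.
void SQNBitGemmBatchSketch(size_t M, size_t N, size_t K, size_t batch_count,
                           size_t nbits, size_t block_size,
                           MLAS_SQNBIT_GEMM_COMPUTE_TYPE compute_type,
                           const float* a_data, const void* packed_b,
                           const float* scales, const void* zero_points,
                           float* c_data, MLAS_THREADPOOL* thread_pool) {
  if (!MlasIsSQNBitGemmAvailable(M, N, K, nbits, block_size, compute_type)) {
    return;  // the operator falls back to dequantize + SGEMM in this case
  }

  // Some compute types need scratch space; CompInt8 stores quantized A here.
  std::vector<std::byte> workspace(
      MlasSQNBitGemmBatchWorkspaceSize(M, N, K, batch_count, nbits, block_size, compute_type));

  std::vector<MLAS_SQNBIT_GEMM_DATA_PARAMS> data(batch_count);
  for (size_t i = 0; i < batch_count; ++i) {
    data[i].A = a_data + i * M * K;  // batched A against a single shared B
    data[i].lda = K;
    data[i].QuantBData = packed_b;
    data[i].QuantBScale = scales;
    data[i].QuantBZeroPoint = zero_points;  // may be nullptr
    data[i].C = c_data + i * M * N;
    data[i].ldc = N;
  }

  MlasSQNBitGemmBatch(M, N, K, batch_count, nbits, block_size, compute_type,
                      data.data(), workspace.empty() ? nullptr : workspace.data(),
                      thread_pool);
}
```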