Mlas Gemm 4bit avx2, avx512, and avx512vnni kernels #20163

Merged: 67 commits, Apr 26, 2024
Commits (67)
bbd7cf6
sqnbitgemm_kernel_avx2/avx512
liqunfu Mar 30, 2024
a605e6a
avx512 port from q4gemm - pass M1N1K1 w/o bias
liqunfu Mar 31, 2024
5ca97a7
draft M1 avx2
liqunfu Apr 1, 2024
f3b1298
pass avx2 M1 tests
liqunfu Apr 2, 2024
868a4dc
port draft for M1Int8
liqunfu Apr 2, 2024
fe3f8fa
pass avx512 M1 int8 except blklen16
liqunfu Apr 3, 2024
63eaea6
draft dequant B. still layout is not ready for GemmFloatKernel
liqunfu Apr 4, 2024
b86157e
pass M* Fp32 tests
liqunfu Apr 4, 2024
b5cd8f8
pass blklen16/Int8
liqunfu Apr 4, 2024
652c172
Merge branch 'main' into liqun/mlas-4bit-cpu
liqunfu Apr 4, 2024
649f899
fix FoldAccumulators
liqunfu Apr 5, 2024
3b43b67
try subblk 64
liqunfu Apr 6, 2024
e01570c
subblk len=64 works with avx512 (20% improvement), avx2 (much slower an…
liqunfu Apr 7, 2024
563eeec
int8 Both USE_NCOLs true and false are passing. in USE_NCOLs=true, bl…
liqunfu Apr 9, 2024
ce285aa
having USE_NCOLs options for NCols=4 and NCols=1: former is 5-10% faster
liqunfu Apr 10, 2024
949bdbe
rename avx2 to avx512 because avx512 ops are used. avx2 not supported
liqunfu Apr 10, 2024
021ed71
experiment load int4 and comments
liqunfu Apr 12, 2024
4b02e45
fp32 avx512 M1 5 times improvement by porting existing q4 code, fp32 …
liqunfu Apr 15, 2024
23e981c
MlasQ80BlkQuantRow_avx2
liqunfu Apr 16, 2024
6620aab
support avx2, refactor for avx2, avx512, vnni kernels
liqunfu Apr 17, 2024
c40bc23
fix avx2
liqunfu Apr 17, 2024
549ea04
avx2 fp32 blklen16 reduce extra register use to fix perf
liqunfu Apr 18, 2024
85707f8
refactor to pass linux
liqunfu Apr 18, 2024
c5ff56d
fix M1 avx512 compute bug. trying fix matmul_nbits.cc, no success
liqunfu Apr 19, 2024
3aa5fa0
make ns test to pass
liqunfu Apr 20, 2024
f6f819f
fix MatMulNBits to work for both NS and NO_NS
liqunfu Apr 20, 2024
2c626ea
fix unused with (void)QuantBZeroPointColPtr;
liqunfu Apr 20, 2024
a64b40c
replace _mm_loadu_si64 with _mm_loadl_epi64
liqunfu Apr 20, 2024
638dcad
replace _mm_loadu_si64 with _mm_loadl_epi64 2
liqunfu Apr 20, 2024
4f5a97c
Merge branch 'main' into liqun/mlas-4bit-cpu
liqunfu Apr 20, 2024
b24f4ae
lint
liqunfu Apr 20, 2024
e007dc4
only enable avx2 to find cause of mlas failure
liqunfu Apr 21, 2024
d8dc403
only run blklen16 M1 with avx2
liqunfu Apr 21, 2024
5e36aa7
likely SQNBitGemmBlkBitWidth4BlkLen16.Threaded/isSymmetric0/M2xN2xK2/…
liqunfu Apr 21, 2024
35690d8
SQNBitGemmBlkBitWidth4BlkLen16.Threaded/isSymmetric0/M2xN2xK2/hasBias…
liqunfu Apr 22, 2024
15f1855
Q4BitBlkDequantBForSgemmBlkLen16_CompFp32 is writing to FpData passin…
liqunfu Apr 22, 2024
75d07d4
skip _mm256_storeu_ps
liqunfu Apr 22, 2024
e85c761
skip _mm256_storeu_ps 2
liqunfu Apr 22, 2024
72ddfc5
skip _mm256_storeu_ps 3
liqunfu Apr 22, 2024
bc83828
skip _mm256_storeu_ps 4
liqunfu Apr 22, 2024
39a5b2e
skip _mm256_storeu_ps 5
liqunfu Apr 22, 2024
9c37281
skip _mm256_storeu_ps 6
liqunfu Apr 22, 2024
a9a2aa6
skip _mm256_storeu_ps 6
liqunfu Apr 22, 2024
7fb99d8
skip _mm256_storeu_ps 6
liqunfu Apr 23, 2024
ddd5e10
skip _mm256_storeu_ps 6
liqunfu Apr 23, 2024
dc38f74
skip _mm256_storeu_ps 6
liqunfu Apr 23, 2024
95e7e77
skip _mm256_storeu_ps 6
liqunfu Apr 23, 2024
13f6c97
skip _mm256_storeu_ps 6
liqunfu Apr 23, 2024
b64e615
skip _mm256_storeu_ps 6
liqunfu Apr 23, 2024
6d1321f
move avx2 from avx512.cpp
liqunfu Apr 23, 2024
7f5fa06
remove a std::cout
liqunfu Apr 23, 2024
1d88398
bring back avx512 and vnni
liqunfu Apr 24, 2024
64d777e
onnxruntime_USE_NEURAL_SPEED OFF
liqunfu Apr 24, 2024
4a173d9
use fp32 for M>1 cases for int8 compute
liqunfu Apr 24, 2024
d3a00a1
Merge branch 'liqun/mlas-4bit-cpu' of https://github.com/microsoft/on…
liqunfu Apr 24, 2024
4a09bf3
incorrect use of _cvtepu8_epi16
liqunfu Apr 25, 2024
919dd76
avoid condition in loops
liqunfu Apr 25, 2024
ddebcb0
Merge branch 'main' into liqun/mlas-4bit-cpu
liqunfu Apr 25, 2024
7e1fef1
update doc
liqunfu Apr 25, 2024
81b865f
fix b0ptr += count_half_4 / 2; missing
liqunfu Apr 25, 2024
53577a9
fix bug - used wrong pointer in SQ4BitGemmM1Kernel_CompInt8_Impl_BlkL…
edgchen1 Apr 25, 2024
13aa5c1
fix parameter ordering in test
edgchen1 Apr 25, 2024
9228531
lint
liqunfu Apr 25, 2024
21c7896
lint
liqunfu Apr 25, 2024
3b2a4e9
move count_half_4 out of loop
liqunfu Apr 25, 2024
8a69e1c
comments
liqunfu Apr 25, 2024
20649bc
reduce dml test to original
liqunfu Apr 26, 2024
2 changes: 1 addition & 1 deletion cmake/CMakeLists.txt
@@ -88,7 +88,7 @@ option(onnxruntime_USE_QNN "Build with QNN support" OFF)
option(onnxruntime_USE_SNPE "Build with SNPE support" OFF)
option(onnxruntime_USE_RKNPU "Build with RKNPU support" OFF)
option(onnxruntime_USE_DNNL "Build with DNNL support" OFF)
option(onnxruntime_USE_NEURAL_SPEED "Build with Neural Speed support" ON)
option(onnxruntime_USE_NEURAL_SPEED "Build with Neural Speed support" OFF)
option(onnxruntime_USE_JSEP "Build with JavaScript implemented kernels support" OFF)
option(onnxruntime_BUILD_UNIT_TESTS "Build ONNXRuntime unit tests" ON)
option(onnxruntime_BUILD_CSHARP "Build C# library" OFF)
11 changes: 11 additions & 0 deletions cmake/onnxruntime_mlas.cmake
@@ -167,6 +167,9 @@ function(setup_mlas_source_for_windows)
${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp
${MLAS_SRC_DIR}/qgemm_kernel_sse41.cpp
${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512vnni.cpp
${MLAS_SRC_DIR}/amd64/QgemmU8S8KernelAmx.asm
${MLAS_SRC_DIR}/amd64/QgemmU8S8KernelAvx2.asm
${MLAS_SRC_DIR}/amd64/QgemmU8U8KernelAvx2.asm
@@ -530,6 +533,7 @@ else()
${MLAS_SRC_DIR}/x86_64/ErfKernelFma3.S
${MLAS_SRC_DIR}/intrinsics/avx2/qladd_avx2.cpp
${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp
)
set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma")

@@ -549,9 +553,15 @@ else()
${MLAS_SRC_DIR}/x86_64/QgemvU8S8KernelAvx512Vnni.S
${MLAS_SRC_DIR}/x86_64/QgemmU8X8KernelAvx512Core.S
${MLAS_SRC_DIR}/x86_64/ConvSymKernelAvx512Core.S
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512.cpp
)
set_source_files_properties(${mlas_platform_srcs_avx512core} PROPERTIES COMPILE_FLAGS "-mavx512bw -mavx512dq -mavx512vl")

set(mlas_platform_srcs_avx512vnni
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512vnni.cpp
)
set_source_files_properties(${mlas_platform_srcs_avx512vnni} PROPERTIES COMPILE_FLAGS "-mfma -mavx512vnni -mavx512bw -mavx512dq -mavx512vl -mavx512f")

set(mlas_platform_srcs
${MLAS_SRC_DIR}/activate_fp16.cpp
${MLAS_SRC_DIR}/dwconv.cpp
@@ -563,6 +573,7 @@ else()
${mlas_platform_srcs_avx2}
${mlas_platform_srcs_avx512f}
${mlas_platform_srcs_avx512core}
${mlas_platform_srcs_avx512vnni}
)

if (NOT onnxruntime_ORT_MINIMAL_BUILD)
2 changes: 1 addition & 1 deletion docs/ContribOperators.md
@@ -2876,7 +2876,7 @@ This version of the operator has been available since version 1 of the 'com.micr
And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,..
3. Input B's scale and zero point are specified by input scales and zero_points.

Input is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which:
Input B is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which:
- n_blocks_per_col = (K + block_size - 1) / block_size
- blob_size = CeilDiv(block_size * bits, bitsof(uint8_t)<8>)
For all bits from 2-8, a row of data is stored squeezely and represented by uint8_t.
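To make the packed layout described in the doc change above concrete, here is a minimal C++ sketch that evaluates the two formulas from the operator description (n_blocks_per_col and blob_size). The helper names and example sizes are hypothetical; only the arithmetic comes from the documentation in this diff.

```cpp
#include <cstddef>
#include <cstdio>

// Hypothetical helpers: packed-B extents for MatMulNBits, using the formulas
// from the operator doc above (n_blocks_per_col and blob_size).
constexpr std::size_t CeilDiv(std::size_t a, std::size_t b) { return (a + b - 1) / b; }

struct PackedBShape {
    std::size_t n_blocks_per_col;  // quantization blocks along K for one column of B
    std::size_t blob_size;         // bytes per quantized block
};

constexpr PackedBShape ComputePackedBShape(std::size_t K, std::size_t block_size, std::size_t bits) {
    return {CeilDiv(K, block_size), CeilDiv(block_size * bits, 8 /* bits per uint8_t */)};
}

int main() {
    // Example: K=4096, block_size=32, 4-bit weights -> 128 blocks per column, 16 bytes per blob.
    const PackedBShape s = ComputePackedBShape(4096, 32, 4);
    std::printf("n_blocks_per_col=%zu blob_size=%zu\n", s.n_blocks_per_col, s.blob_size);
    return 0;
}
```

For 4-bit weights with block_size 32, each block of a column packs into 16 bytes, so one column of B occupies n_blocks_per_col * 16 bytes in the [N][n_blocks_per_col][blob_size] layout.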
5 changes: 3 additions & 2 deletions onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
@@ -184,7 +184,7 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat
is_packed = true;
}

#else // defined(ORT_NEURAL_SPEED)
#else // defined(ORT_NEURAL_SPEED)

if (input_idx == 1) {
const auto compute_type = static_cast<MLAS_SQNBIT_GEMM_COMPUTE_TYPE>(accuracy_level_);
@@ -204,7 +204,6 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat
}
is_packed = true;
}

#endif // defined(ORT_NEURAL_SPEED)

return Status::OK();
@@ -315,6 +314,7 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
const size_t K = static_cast<size_t>(helper.K());
const size_t lda = helper.Lda(false);

#ifndef ORT_NEURAL_SPEED
const bool has_single_b_matrix =
(!act_order_) && (!zero_point_is_not_quant_) &&
std::all_of(helper.RightOffsets().begin(), helper.RightOffsets().end(), [](size_t offset) { return offset == 0; });
@@ -358,6 +358,7 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
return Status::OK();
}
}
#endif // not defined(ORT_NEURAL_SPEED)

const Tensor* b = ctx->Input<Tensor>(1);
const uint8_t* b_data = b->Data<uint8_t>();
2 changes: 1 addition & 1 deletion onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -3407,7 +3407,7 @@ MatMulNBits is a MatMul with weight quantized with N bits(e.g., 2, 3, 4, 5, 6, 7
And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,..
3. Input B's scale and zero point are specified by input scales and zero_points.

Input is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which:
Input B is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which:
- n_blocks_per_col = (K + block_size - 1) / block_size
- blob_size = CeilDiv(block_size * bits, bitsof(uint8_t)<8>)
For all bits from 2-8, a row of data is stored squeezely and represented by uint8_t.
6 changes: 6 additions & 0 deletions onnxruntime/core/mlas/lib/mlasi.h
@@ -971,6 +971,12 @@ struct MLAS_SQNBIT_GEMM_DISPATCH;

extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon;

extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2;

extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512;

extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni;

//
// Quantized depthwise convolution kernels.
//
3 changes: 3 additions & 0 deletions onnxruntime/core/mlas/lib/platform.cpp
@@ -373,6 +373,7 @@ Return Value:
this->ConvDepthwiseS8S8Kernel = MlasConvDepthwiseKernelAvx2<int8_t, int8_t>;
this->ConvDepthwiseS8U8Kernel = MlasConvDepthwiseKernelAvx2<int8_t, uint8_t>;
this->ComputeSumExpF32Kernel = MlasComputeSumExpF32KernelFma3;
this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx2;

//
// Check if the processor supports Hybrid core architecture.
@@ -439,6 +440,7 @@ Return Value:
this->GemmU8U8Kernel = MlasGemmU8U8KernelAvx512Core;
this->ConvSymU8S8Dispatch = &MlasConvSymDispatchAvx512Core;
this->FpQ4GemmDispatch = &MlasFpQ4GemmDispatchAvx512;
this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx512;

//
// Check if the processor supports AVX512VNNI.
@@ -451,6 +453,7 @@
this->GemvU8S8Kernel = MlasGemvU8S8KernelAvx512Vnni;
this->ConvSymU8S8Dispatch = &MlasConvSymDispatchAvx512Vnni;
this->Q8Q4GemmDispatch = &MlasQ8Q4GemmDispatchAvx512vnni;
this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx512vnni;
}
}
}
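The platform.cpp hunk above assigns SQNBitGemmDispatch three times inside nested CPU-feature checks, so the most capable supported kernel is the one left standing (AVX2, then AVX512, then AVX512VNNI). Below is a minimal sketch of that last-writer-wins selection pattern; the dispatch objects and boolean feature flags are hypothetical stand-ins for the real MLAS dispatch tables and CPUID queries.

```cpp
#include <cstdio>

// Hypothetical stand-ins for MLAS dispatch tables; only the selection
// pattern mirrors the platform.cpp change above.
struct SQNBitGemmDispatch { const char* name; };

const SQNBitGemmDispatch DispatchAvx2{"avx2"};
const SQNBitGemmDispatch DispatchAvx512{"avx512"};
const SQNBitGemmDispatch DispatchAvx512vnni{"avx512vnni"};

const SQNBitGemmDispatch* SelectDispatch(bool has_avx2, bool has_avx512_core, bool has_avx512_vnni) {
    const SQNBitGemmDispatch* dispatch = nullptr;
    if (has_avx2) {
        dispatch = &DispatchAvx2;                // baseline x64 kernel
        if (has_avx512_core) {
            dispatch = &DispatchAvx512;          // overwrite with the wider AVX512 kernel
            if (has_avx512_vnni) {
                dispatch = &DispatchAvx512vnni;  // overwrite again with the VNNI int8 dot-product kernel
            }
        }
    }
    return dispatch;  // most capable supported kernel, or nullptr if AVX2 is absent
}

int main() {
    std::printf("%s\n", SelectDispatch(true, true, false)->name);  // prints "avx512"
    return 0;
}
```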
11 changes: 11 additions & 0 deletions onnxruntime/core/mlas/lib/sqnbitgemm.cpp
@@ -394,6 +394,17 @@
const size_t RangeCountN
)
{
#ifdef MLAS_TARGET_AMD64_IX86
if (RangeCountM != 1) {
// perf experiment shows fp32 is faster than int8 in M > 1 cases.
// route to fp32 compute before int8 compute is improved.
SQ4BitGemm_CompFp32(
BlkLen,
K, DataParams, PerGemmWorkspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN
);

[cpplint annotation on line 404: Closing ) should be moved to the previous line (whitespace/parens)]
return;
}
#endif
constexpr size_t BlkBitWidth = 4;

const size_t k_blks = MlasDivRoundup(K, BlkLen);
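The sqnbitgemm.cpp hunk above routes CompInt8 work back to the fp32 kernel whenever the row range covers more than one row, since the PR's measurements found fp32 faster for M > 1. Here is a minimal sketch of that guard with hypothetical callbacks standing in for SQ4BitGemm_CompFp32 / SQ4BitGemm_CompInt8 and without the real parameter list; in the actual code the check is also wrapped in MLAS_TARGET_AMD64_IX86.

```cpp
#include <cstddef>
#include <cstdio>
#include <functional>

// Hypothetical compute callbacks standing in for SQ4BitGemm_CompFp32 / SQ4BitGemm_CompInt8;
// only the routing guard mirrors the sqnbitgemm.cpp change above.
using GemmKernel = std::function<void(std::size_t RangeStartM, std::size_t RangeCountM)>;

void RunSQ4BitGemmCompInt8(std::size_t RangeStartM, std::size_t RangeCountM,
                           const GemmKernel& compute_fp32, const GemmKernel& compute_int8) {
    if (RangeCountM != 1) {
        // Perf experiments in the PR found fp32 faster than int8 for M > 1,
        // so multi-row ranges fall back to the fp32 path until the int8 path improves.
        compute_fp32(RangeStartM, RangeCountM);
        return;
    }
    compute_int8(RangeStartM, RangeCountM);  // M == 1: int8 kernel
}

int main() {
    auto fp32 = [](std::size_t m0, std::size_t mc) { std::printf("fp32 rows [%zu, %zu)\n", m0, m0 + mc); };
    auto int8 = [](std::size_t m0, std::size_t mc) { std::printf("int8 rows [%zu, %zu)\n", m0, m0 + mc); };
    RunSQ4BitGemmCompInt8(0, 4, fp32, int8);  // M > 1  -> fp32 path
    RunSQ4BitGemmCompInt8(0, 1, fp32, int8);  // M == 1 -> int8 path
    return 0;
}
```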