MLAS Gemm 4-bit AVX2, AVX512, and AVX512VNNI kernels (#20163)
### Description

Throughput comparison of Neural Speed (NS) and the new MLAS kernels for prompt processing and token generation at each block length. Gain/Loss = (MLAS - NS) / NS, so higher is better.

**AVX2, Int8**

| BlkLen | NS Prompt | MLAS Prompt | Prompt Gain/Loss | NS TokenGen | MLAS TokenGen | TokenGen Gain/Loss |
| --- | --- | --- | --- | --- | --- | --- |
| 16 | 90.96 | 25.15 | -72% | 7.65 | 11.71 | 53% |
| 32 | 90.73 | 48.55 | -46% | 7.86 | 14.28 | 81% |
| 64 | 89.49 | 68.84 | -23% | 8.30 | 15.78 | 90% |
| 128 | 87.38 | 78.37 | -10% | 7.90 | 16.05 | 103% |
| 256 | 89.45 | 82.36 | -7% | 8.30 | 16.56 | 99% |

**AVX2, Fp32**

| BlkLen | NS Prompt | MLAS Prompt | Prompt Gain/Loss | NS TokenGen | MLAS TokenGen | TokenGen Gain/Loss |
| --- | --- | --- | --- | --- | --- | --- |
| 16 | 91.36 | 105.18 | 15% | 7.57 | 9.52 | 25% |
| 32 | 89.30 | 105.99 | 18% | 7.65 | 9.68 | 26% |
| 64 | 89.53 | 101.41 | 13% | 7.97 | 9.84 | 23% |
| 128 | 85.23 | 99.71 | 16% | 7.86 | 10.39 | 32% |
| 256 | 88.46 | 97.94 | 10% | 8.32 | 10.23 | 22% |

**AVX512VNNI, Int8**

| BlkLen | NS Prompt | MLAS Prompt | Prompt Gain/Loss | NS TokenGen | MLAS TokenGen | TokenGen Gain/Loss |
| --- | --- | --- | --- | --- | --- | --- |
| 16 | 132.18 | 21.56 | -83% | 10.34 | 11.48 | 11% |
| 32 | 168.28 | 43.69 | -74% | 11.85 | 14.73 | 24% |
| 64 | 201.81 | 60.29 | -70% | 12.36 | 15.47 | 25% |
| 128 | 194.92 | 57.04 | -71% | 13.03 | 14.67 | 12% |
| 256 | 218.76 | 70.20 | -68% | 13.33 | 16.31 | 22% |

**AVX512VNNI, Fp32**

| BlkLen | NS Prompt | MLAS Prompt | Prompt Gain/Loss | NS TokenGen | MLAS TokenGen | TokenGen Gain/Loss |
| --- | --- | --- | --- | --- | --- | --- |
| 16 | 102.81 | 92.74 | -9% | 8.41 | 9.18 | 9% |
| 32 | 109.49 | 97.08 | -11% | 8.83 | 11.51 | 30% |
| 64 | 104.13 | 101.57 | -2% | 9.32 | 12.00 | 28% |
| 128 | 108.45 | 103.69 | -4% | 9.58 | 12.45 | 29% |
| 256 | 109.43 | 106.43 | -2% | 9.19 | 12.20 | 32% |

---------

Signed-off-by: Liqun Fu <[email protected]>
Signed-off-by: liqunfu <[email protected]>
Co-authored-by: edgchen1 <[email protected]>
2 people authored and yihonglyu committed May 4, 2024
1 parent 204f1f5 commit d9ba4f4
Showing 17 changed files with 3,435 additions and 10 deletions.
2 changes: 1 addition & 1 deletion cmake/CMakeLists.txt
@@ -88,7 +88,7 @@ option(onnxruntime_USE_QNN "Build with QNN support" OFF)
option(onnxruntime_USE_SNPE "Build with SNPE support" OFF)
option(onnxruntime_USE_RKNPU "Build with RKNPU support" OFF)
option(onnxruntime_USE_DNNL "Build with DNNL support" OFF)
option(onnxruntime_USE_NEURAL_SPEED "Build with Neural Speed support" ON)
option(onnxruntime_USE_NEURAL_SPEED "Build with Neural Speed support" OFF)
option(onnxruntime_USE_JSEP "Build with JavaScript implemented kernels support" OFF)
option(onnxruntime_BUILD_UNIT_TESTS "Build ONNXRuntime unit tests" ON)
option(onnxruntime_BUILD_CSHARP "Build C# library" OFF)
11 changes: 11 additions & 0 deletions cmake/onnxruntime_mlas.cmake
@@ -167,6 +167,9 @@ function(setup_mlas_source_for_windows)
${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp
${MLAS_SRC_DIR}/qgemm_kernel_sse41.cpp
${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512vnni.cpp
${MLAS_SRC_DIR}/amd64/QgemmU8S8KernelAmx.asm
${MLAS_SRC_DIR}/amd64/QgemmU8S8KernelAvx2.asm
${MLAS_SRC_DIR}/amd64/QgemmU8U8KernelAvx2.asm
@@ -530,6 +533,7 @@ else()
${MLAS_SRC_DIR}/x86_64/ErfKernelFma3.S
${MLAS_SRC_DIR}/intrinsics/avx2/qladd_avx2.cpp
${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp
)
set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma")

@@ -549,9 +553,15 @@
${MLAS_SRC_DIR}/x86_64/QgemvU8S8KernelAvx512Vnni.S
${MLAS_SRC_DIR}/x86_64/QgemmU8X8KernelAvx512Core.S
${MLAS_SRC_DIR}/x86_64/ConvSymKernelAvx512Core.S
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512.cpp
)
set_source_files_properties(${mlas_platform_srcs_avx512core} PROPERTIES COMPILE_FLAGS "-mavx512bw -mavx512dq -mavx512vl")

set(mlas_platform_srcs_avx512vnni
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512vnni.cpp
)
set_source_files_properties(${mlas_platform_srcs_avx512vnni} PROPERTIES COMPILE_FLAGS "-mfma -mavx512vnni -mavx512bw -mavx512dq -mavx512vl -mavx512f")

set(mlas_platform_srcs
${MLAS_SRC_DIR}/activate_fp16.cpp
${MLAS_SRC_DIR}/dwconv.cpp
@@ -563,6 +573,7 @@ else()
${mlas_platform_srcs_avx2}
${mlas_platform_srcs_avx512f}
${mlas_platform_srcs_avx512core}
${mlas_platform_srcs_avx512vnni}
)

if (NOT onnxruntime_ORT_MINIMAL_BUILD)
2 changes: 1 addition & 1 deletion docs/ContribOperators.md
@@ -2876,7 +2876,7 @@ This version of the operator has been available since version 1 of the 'com.micr
And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,..
3. Input B's scale and zero point are specified by input scales and zero_points.

Input is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which:
Input B is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which:
- n_blocks_per_col = (K + block_size - 1) / block_size
- blob_size = CeilDiv(block_size * bits, bitsof(uint8_t)<8>)
For all bits from 2-8, a row of data is stored squeezely and represented by uint8_t.
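As a reading aid for the layout described above (not part of the diff): a minimal sketch of the packing arithmetic, where `CeilDiv` and `BlobSizeBytes` are illustrative names rather than helpers from the codebase.

```cpp
#include <cstddef>

// Sketch of the MatMulNBits weight-blob sizing described in the docs above.
constexpr size_t CeilDiv(size_t a, size_t b) { return (a + b - 1) / b; }

// n_blocks_per_col = CeilDiv(K, block_size)
// blob_size        = CeilDiv(block_size * bits, 8)   // bitsof(uint8_t) == 8
constexpr size_t BlobSizeBytes(size_t block_size, size_t bits) {
    return CeilDiv(block_size * bits, 8);
}

// Worked example: K = 4096, block_size = 32, bits = 4.
static_assert(CeilDiv(4096, 32) == 128, "n_blocks_per_col");
static_assert(BlobSizeBytes(32, 4) == 16, "bytes per block");
```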
5 changes: 3 additions & 2 deletions onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
@@ -184,7 +184,7 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat
is_packed = true;
}

#else // defined(ORT_NEURAL_SPEED)
#else // defined(ORT_NEURAL_SPEED)

if (input_idx == 1) {
const auto compute_type = static_cast<MLAS_SQNBIT_GEMM_COMPUTE_TYPE>(accuracy_level_);
@@ -204,7 +204,6 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat
}
is_packed = true;
}

#endif // defined(ORT_NEURAL_SPEED)

return Status::OK();
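Side note on the cast above: `accuracy_level_` holds the operator's `accuracy_level` attribute and is converted directly to the MLAS compute type. A hedged sketch of the assumed correspondence follows; the values come from the MatMulNBits attribute documentation, not from mlas.h.

```cpp
// Illustrative stand-in for MLAS_SQNBIT_GEMM_COMPUTE_TYPE; the one-to-one
// mapping with the accuracy_level attribute is an assumption here.
enum SQNBitGemmComputeType {
    CompUndef = 0,  // accuracy_level 0: unset, implementation picks a type
    CompFp32  = 1,  // accuracy_level 1: compute in fp32
    CompFp16  = 2,  // accuracy_level 2: compute in fp16
    CompBf16  = 3,  // accuracy_level 3: compute in bf16
    CompInt8  = 4,  // accuracy_level 4: compute in int8
};
```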
@@ -315,6 +314,7 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
const size_t K = static_cast<size_t>(helper.K());
const size_t lda = helper.Lda(false);

#ifndef ORT_NEURAL_SPEED
const bool has_single_b_matrix =
(!act_order_) && (!zero_point_is_not_quant_) &&
std::all_of(helper.RightOffsets().begin(), helper.RightOffsets().end(), [](size_t offset) { return offset == 0; });
@@ -358,6 +358,7 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
return Status::OK();
}
}
#endif // not defined(ORT_NEURAL_SPEED)

const Tensor* b = ctx->Input<Tensor>(1);
const uint8_t* b_data = b->Data<uint8_t>();
2 changes: 1 addition & 1 deletion onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -3407,7 +3407,7 @@ MatMulNBits is a MatMul with weight quantized with N bits(e.g., 2, 3, 4, 5, 6, 7
And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,..
3. Input B's scale and zero point are specified by input scales and zero_points.
Input is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which:
Input B is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which:
- n_blocks_per_col = (K + block_size - 1) / block_size
- blob_size = CeilDiv(block_size * bits, bitsof(uint8_t)<8>)
For all bits from 2-8, a row of data is stored squeezely and represented by uint8_t.
6 changes: 6 additions & 0 deletions onnxruntime/core/mlas/lib/mlasi.h
@@ -971,6 +971,12 @@ struct MLAS_SQNBIT_GEMM_DISPATCH;

extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon;

extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2;

extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512;

extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni;

//
// Quantized depthwise convolution kernels.
//
3 changes: 3 additions & 0 deletions onnxruntime/core/mlas/lib/platform.cpp
@@ -373,6 +373,7 @@ Return Value:
this->ConvDepthwiseS8S8Kernel = MlasConvDepthwiseKernelAvx2<int8_t, int8_t>;
this->ConvDepthwiseS8U8Kernel = MlasConvDepthwiseKernelAvx2<int8_t, uint8_t>;
this->ComputeSumExpF32Kernel = MlasComputeSumExpF32KernelFma3;
this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx2;

//
// Check if the processor supports Hybrid core architecture.
@@ -439,6 +440,7 @@ Return Value:
this->GemmU8U8Kernel = MlasGemmU8U8KernelAvx512Core;
this->ConvSymU8S8Dispatch = &MlasConvSymDispatchAvx512Core;
this->FpQ4GemmDispatch = &MlasFpQ4GemmDispatchAvx512;
this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx512;

//
// Check if the processor supports AVX512VNNI.
@@ -451,6 +453,7 @@
this->GemvU8S8Kernel = MlasGemvU8S8KernelAvx512Vnni;
this->ConvSymU8S8Dispatch = &MlasConvSymDispatchAvx512Vnni;
this->Q8Q4GemmDispatch = &MlasQ8Q4GemmDispatchAvx512vnni;
this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx512vnni;
}
}
}
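The three `SQNBitGemmDispatch` assignments above follow the usual MLAS platform-init pattern: each nested CPU-feature check overwrites the pointer with a more specialized kernel table, so the most capable supported ISA wins. A condensed sketch of that pattern, with illustrative names standing in for the actual platform.cpp probes:

```cpp
struct DispatchTable;  // stand-in for MLAS_SQNBIT_GEMM_DISPATCH

extern const DispatchTable DispatchAvx2;
extern const DispatchTable DispatchAvx512;
extern const DispatchTable DispatchAvx512Vnni;

// Hypothetical feature probes standing in for the CPUID checks in platform.cpp.
bool HasAvx2();
bool HasAvx512Core();
bool HasAvx512Vnni();

const DispatchTable* SelectSQNBitGemmDispatch() {
    const DispatchTable* dispatch = nullptr;
    if (HasAvx2()) {
        dispatch = &DispatchAvx2;                // baseline AVX2 kernels
        if (HasAvx512Core()) {
            dispatch = &DispatchAvx512;          // wider AVX512 kernels
            if (HasAvx512Vnni()) {
                dispatch = &DispatchAvx512Vnni;  // VNNI int8 dot-product kernels
            }
        }
    }
    return dispatch;  // last assignment wins
}
```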
11 changes: 11 additions & 0 deletions onnxruntime/core/mlas/lib/sqnbitgemm.cpp
@@ -394,6 +394,17 @@ SQ4BitGemm_CompInt8(
const size_t RangeCountN
)
{
#ifdef MLAS_TARGET_AMD64_IX86
if (RangeCountM != 1) {
// perf experiment shows fp32 is faster than int8 in M > 1 cases.
// route to fp32 compute before int8 compute is improved.
SQ4BitGemm_CompFp32(
BlkLen,
K, DataParams, PerGemmWorkspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN
);
return;
}
#endif
constexpr size_t BlkBitWidth = 4;

const size_t k_blks = MlasDivRoundup(K, BlkLen);
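The `RangeCountM != 1` early-out above means the new int8 path currently serves only single-row (token-generation) GEMMs on x64, while prompt processing (M > 1) is routed back to the fp32 kernels. That matches the benchmark tables in the description, where MLAS int8 trails NS on prompt throughput but wins on token generation. A minimal sketch of the routing, with illustrative function names:

```cpp
#include <cstddef>

// Illustrative stand-ins for the two compute paths in sqnbitgemm.cpp.
void ComputeFp32(size_t RangeCountM);
void ComputeInt8(size_t RangeCountM);

// Sketch of the x64 routing added above: fp32 measured faster for M > 1,
// so only the M == 1 (token generation) case takes the int8 kernels.
void SQ4BitGemmCompInt8Sketch(size_t RangeCountM) {
    if (RangeCountM != 1) {
        ComputeFp32(RangeCountM);  // prompt processing path
        return;
    }
    ComputeInt8(RangeCountM);      // token generation path
}
```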