Skip to content

Commit

Permalink
sqnbitgemm_kernel_avx2/avx512
Browse files Browse the repository at this point in the history
Signed-off-by: Liqun Fu <[email protected]>
  • Loading branch information
liqunfu committed Mar 30, 2024
1 parent 2f82400 commit bbd7cf6
Show file tree
Hide file tree
Showing 6 changed files with 1,154 additions and 1 deletion.
2 changes: 2 additions & 0 deletions cmake/onnxruntime_mlas.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ if (NOT onnxruntime_ORT_MINIMAL_BUILD)
target_sources(onnxruntime_mlas PRIVATE
${MLAS_SRC_DIR}/q4_dq.cpp
${MLAS_SRC_DIR}/q4gemm.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512.cpp
)
endif()

Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/graph/contrib_ops/contrib_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3346,7 +3346,7 @@ MatMulNBits is a MatMul with weight quantized with N bits(e.g., 2, 3, 4, 5, 6, 7
block_size is not arbitrary: it must be a power of 2 and not smaller than 16, e.g. 16, 32, 64, 128, ...
3. Input B's scale and zero point are specified by input scales and zero_points.
Input is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which:
Input B is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which:
- n_blocks_per_col = (K + block_size - 1) / block_size
- blob_size = CeilDiv(block_size * bits, bitsof(uint8_t)), where bitsof(uint8_t) = 8
For all bit widths from 2 to 8, a row of data is stored tightly packed and represented as uint8_t.
Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/core/mlas/lib/mlasi.h
Original file line number Diff line number Diff line change
Expand Up @@ -970,6 +970,9 @@ struct MLAS_SQNBIT_GEMM_DISPATCH;

extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon;

extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2;

extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512;
//
// Quantized depthwise convolution kernels.
//
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/core/mlas/lib/platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,9 @@ Return Value:
this->ConvDepthwiseS8U8Kernel = MlasConvDepthwiseKernelAvx2<int8_t, uint8_t>;
this->ComputeSumExpF32Kernel = MlasComputeSumExpF32KernelFma3;

this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx2;


//
// Check if the processor supports Hybrid core architecture.
//
Expand Down Expand Up @@ -438,6 +441,7 @@ Return Value:
this->GemmU8U8Kernel = MlasGemmU8U8KernelAvx512Core;
this->ConvSymU8S8Dispatch = &MlasConvSymDispatchAvx512Core;
this->FpQ4GemmDispatch = &MlasFpQ4GemmDispatchAvx512;
this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx512;

//
// Check if the processor supports AVX512VNNI.
Expand Down
Loading

0 comments on commit bbd7cf6

Please sign in to comment.