Skip to content

Commit

Permalink
[aarch64] Implement QGEMM kernels with UMMLA/SMMLA instructions (#17160)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->
This PR adds UMMLA and SMMLA based QGEMM kernels for aarch64. This
covers
(i) symmetric quantization (zero point is Zero)
(ii) asymmetric quantization (zero point is non zero)
(iii) per channel as well as per tensor quantization
(iv) Signed weights (U8S8 Gemm)
(v) Unsigned weights (U8U8 Gemm) and 
(vi) Signed activations and weights (S8S8 Gemm) scenarios

I've enabled the ummla/smmla kernels based on cpuinfo check for `I8MM`
support
MMLA QGEMM kernels are enabled for all the devices that support I8MM
instructions.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This is to improve INT8 quantized MatMul performance on aarch64
platform.
I have run the below benchmarking script (bert , roberta and gpt2 model
inference) on AWS Graviton3 based c7g.4xl instance and observed up to
1.33x performance improvement compared to the optimized UDOT qgemm
kernel performance.

```
cd onnxruntime/python/tools/transformers
python3 benchmark.py
```
I have also run the unit tests, and made sure all are passing

```
./build.sh --config RelWithDebInfo --build_shared_lib --parallel --compile_no_warning_as_error --skip_submodule_sync 

```
  • Loading branch information
snadampal authored Oct 23, 2023
1 parent 2a17d5c commit 780ee18
Show file tree
Hide file tree
Showing 9 changed files with 3,833 additions and 0 deletions.
6 changes: 6 additions & 0 deletions cmake/onnxruntime_mlas.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,9 @@ else()
${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelNeon.S
${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelNeon.S
${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUdot.S
${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S
${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSdot.S
${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S
${MLAS_SRC_DIR}/aarch64/SgemmKernelNeon.S
${MLAS_SRC_DIR}/aarch64/SgemvKernelNeon.S
${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelNeon.S
Expand All @@ -334,6 +336,8 @@ else()
${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp
${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp
${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp
${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp
)
if (NOT APPLE)
set(mlas_platform_srcs
Expand All @@ -348,6 +352,8 @@ else()
set_source_files_properties(${MLAS_SRC_DIR}/activate_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
endif()

if(ONNXRUNTIME_MLAS_MULTI_ARCH)
Expand Down
18 changes: 18 additions & 0 deletions onnxruntime/core/common/cpuid_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,14 @@
#define HWCAP_ASIMDDP (1 << 20)
#endif

#ifndef HWCAP2_I8MM
#define HWCAP2_I8MM (1 << 13)
#endif

#ifndef HWCAP2_SVEI8MM
#define HWCAP2_SVEI8MM (1 << 9)
#endif

#endif // ARM

#endif // Linux
Expand Down Expand Up @@ -138,6 +146,9 @@ void CPUIDInfo::ArmLinuxInit() {
is_hybrid_ = cpuinfo_get_uarchs_count() > 1;
has_arm_neon_dot_ = cpuinfo_has_arm_neon_dot();
has_fp16_ = cpuinfo_has_arm_neon_fp16_arith();
has_arm_neon_i8mm_ = cpuinfo_has_arm_i8mm();
has_arm_sve_i8mm_ = cpuinfo_has_arm_sve() && cpuinfo_has_arm_i8mm();

const uint32_t core_cnt = cpuinfo_get_cores_count();
core_uarchs_.resize(core_cnt, cpuinfo_uarch_unknown);
is_armv8_narrow_ld_.resize(core_cnt, false);
Expand All @@ -162,6 +173,10 @@ void CPUIDInfo::ArmLinuxInit() {
pytorch_cpuinfo_init_ = false;
has_arm_neon_dot_ = ((getauxval(AT_HWCAP) & HWCAP_ASIMDDP) != 0);
has_fp16_ |= has_arm_neon_dot_;

has_arm_neon_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_I8MM) != 0);
has_arm_sve_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_SVEI8MM) != 0);

#endif
}

Expand Down Expand Up @@ -256,6 +271,9 @@ void CPUIDInfo::ArmWindowsInit() {

has_arm_neon_dot_ = (IsProcessorFeaturePresent(PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE) != 0);
has_fp16_ |= has_arm_neon_dot_;
/* TODO: implement them when hw+sw is available for testing these features */
has_arm_neon_i8mm_ = false;
has_arm_sve_i8mm_ = false;
}

#endif /* (arm or arm64) and windows */
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/core/common/cpuid_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ class CPUIDInfo {

// ARM
bool HasArmNeonDot() const { return has_arm_neon_dot_; }
bool HasArmNeon_I8MM() const { return has_arm_neon_i8mm_; }
bool HasArmSVE_I8MM() const { return has_arm_sve_i8mm_; }

uint32_t GetCurrentCoreIdx() const;

Expand Down Expand Up @@ -121,6 +123,8 @@ class CPUIDInfo {

bool has_arm_neon_dot_{false};
bool has_fp16_{false};
bool has_arm_neon_i8mm_{false};
bool has_arm_sve_i8mm_{false};

#ifdef CPUIDINFO_ARCH_X86

Expand Down
Loading

0 comments on commit 780ee18

Please sign in to comment.