diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 992908392c946..06237c8010fbc 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -325,7 +325,9 @@ else() ${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelNeon.S ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelNeon.S ${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUdot.S + ${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSdot.S + ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S ${MLAS_SRC_DIR}/aarch64/SgemmKernelNeon.S ${MLAS_SRC_DIR}/aarch64/SgemvKernelNeon.S ${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelNeon.S @@ -334,6 +336,8 @@ else() ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp ) if (NOT APPLE) set(mlas_platform_srcs @@ -348,6 +352,8 @@ else() set_source_files_properties(${MLAS_SRC_DIR}/activate_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") endif() if(ONNXRUNTIME_MLAS_MULTI_ARCH) diff --git a/onnxruntime/core/common/cpuid_info.cc b/onnxruntime/core/common/cpuid_info.cc index 6a82b3fcc734d..655d5014f3d60 100644 --- a/onnxruntime/core/common/cpuid_info.cc +++ b/onnxruntime/core/common/cpuid_info.cc @@ -22,6 +22,14 @@ #define HWCAP_ASIMDDP (1 << 20) #endif +#ifndef HWCAP2_I8MM +#define HWCAP2_I8MM (1 << 13) +#endif + +#ifndef HWCAP2_SVEI8MM +#define HWCAP2_SVEI8MM (1 << 9) +#endif + #endif // ARM #endif // Linux @@ -138,6 +146,9 @@ void CPUIDInfo::ArmLinuxInit() { is_hybrid_ = cpuinfo_get_uarchs_count() > 1; has_arm_neon_dot_ = cpuinfo_has_arm_neon_dot(); has_fp16_ = cpuinfo_has_arm_neon_fp16_arith(); + has_arm_neon_i8mm_ = cpuinfo_has_arm_i8mm(); + has_arm_sve_i8mm_ = cpuinfo_has_arm_sve() && cpuinfo_has_arm_i8mm(); + const uint32_t core_cnt = cpuinfo_get_cores_count(); core_uarchs_.resize(core_cnt, cpuinfo_uarch_unknown); is_armv8_narrow_ld_.resize(core_cnt, false); @@ -162,6 +173,10 @@ void CPUIDInfo::ArmLinuxInit() { pytorch_cpuinfo_init_ = false; has_arm_neon_dot_ = ((getauxval(AT_HWCAP) & HWCAP_ASIMDDP) != 0); has_fp16_ |= has_arm_neon_dot_; + + has_arm_neon_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_I8MM) != 0); + has_arm_sve_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_SVEI8MM) != 0); + #endif } @@ -256,6 +271,9 @@ void CPUIDInfo::ArmWindowsInit() { has_arm_neon_dot_ = (IsProcessorFeaturePresent(PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE) != 0); has_fp16_ |= has_arm_neon_dot_; + /* TODO: implement them when hw+sw is available for testing these features */ + has_arm_neon_i8mm_ = false; + has_arm_sve_i8mm_ = false; } #endif /* (arm or arm64) and windows */ diff --git a/onnxruntime/core/common/cpuid_info.h b/onnxruntime/core/common/cpuid_info.h index 386db347c669d..a15c75104b83a 100644 --- a/onnxruntime/core/common/cpuid_info.h +++ b/onnxruntime/core/common/cpuid_info.h @@ -28,6 +28,8 @@ class CPUIDInfo { // ARM bool HasArmNeonDot() const { return has_arm_neon_dot_; } + bool HasArmNeon_I8MM() const { return has_arm_neon_i8mm_; } + bool HasArmSVE_I8MM() const { return has_arm_sve_i8mm_; } uint32_t GetCurrentCoreIdx() const; @@ -121,6 +123,8 @@ class CPUIDInfo { bool has_arm_neon_dot_{false}; bool has_fp16_{false}; + bool has_arm_neon_i8mm_{false}; + bool has_arm_sve_i8mm_{false}; #ifdef CPUIDINFO_ARCH_X86 diff --git a/onnxruntime/core/mlas/lib/aarch64/QgemmS8S8KernelSmmla.S b/onnxruntime/core/mlas/lib/aarch64/QgemmS8S8KernelSmmla.S new file mode 100644 index 0000000000000..e18846c89030e --- /dev/null +++ b/onnxruntime/core/mlas/lib/aarch64/QgemmS8S8KernelSmmla.S @@ -0,0 +1,922 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the MIT License. + +Module Name: + + QgemmS8S8KernelSmmla.s + +Abstract: + + This module implements the kernels for the Int8 precision matrix/matrix + multiply operation (QGEMM). + +--*/ + +#include "asmmacro.h" + + .text + +// +// Stack frame layout for the smmla kernel. d8-d15, x19-x30 need save +// + .equ .LMlasQgemmKernel_backup_x19_x20, 0 + .equ .LMlasQgemmKernel_backup_x21_x22, 16 + .equ .LMlasQgemmKernel_backup_x23_x24, 32 + .equ .LMlasQgemmKernel_backup_x25_x26, 48 + .equ .LMlasQgemmKernel_backup_x27_x28, 64 + .equ .LMlasQgemmKernel_backup_d8_d9, 80 + .equ .LMlasQgemmKernel_backup_d10_d11, 96 + .equ .LMlasQgemmKernel_backup_d12_d13, 112 + .equ .LMlasQgemmKernel_backup_d14_d15, 128 + .equ .LMlasQgemmKernel_SavedRegisters, 144 + .equ .LMlasQgemmKernel_SavedRegisters_Neg, -144 + + +// +// Init Row Accumulators +// +// Generates the code to initialize the accumulators for a single row of the output +// block. +// +// +// Accumulators are initialized to ZeroPointB * RowSum + ColumnSum +// x7 for RowSumsBuffer pointer +// x10 for ColumnSumBuffer pointer +// x11 for ZeroPointB buffer pointer +// +// v12~v13 for RowSums values +// v14~v15 for ColumnSums values +// v0~v3 for ZeroPointB values +// + .macro InitRowAccumulators Columns, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, RowSumReg + + mul v7.4s, v\RowSumReg\().4s, v8.4s + mov v\Vec1Reg\().16b, v7.16b + add v\Vec1Reg\().4s, v\Vec1Reg\().4s, v0.4s +.if \Columns\() > 2 + mul v7.4s, v\RowSumReg\().4s, v9.4s + mov v\Vec2Reg\().16b, v7.16b + add v\Vec2Reg\().4s, v\Vec2Reg\().4s, v1.4s +.endif +.if \Columns\() > 4 + mul v7.4s, v\RowSumReg\().4s, v10.4s + mov v\Vec3Reg\().16b, v7.16b + add v\Vec3Reg\().4s, v\Vec3Reg\().4s, v2.4s +.endif +.if \Columns\() > 6 + mul v7.4s, v\RowSumReg\().4s, v11.4s + mov v\Vec4Reg\().16b, v7.16b + add v\Vec4Reg\().4s, v\Vec4Reg\().4s, v3.4s +.endif + + .endm + + +// +// InitBlockAccumulators +// +// Generates the code to initialize the accumulators for 8x8 output +// block. +// + .macro InitBlockAccumulators Mode, Columns, Rows + + ld1 {v14.4s},[x10],#16 // load ColumnSumBuffer[0] +.if \Columns\() > 4 + ld1 {v15.4s},[x10],#16 // load ColumnSumBuffer[4] +.endif + // v4~v7 will be set to matrixB after this, so, they can used now + dup v4.4s,v14.s[0] // broadcast column + dup v5.4s,v14.s[1] + dup v6.4s,v14.s[2] + dup v7.4s,v14.s[3] + + zip1 v0.4s, v4.4s, v5.4s + zip2 v1.4s, v6.4s, v7.4s +.if \Columns\() > 4 + dup v4.4s,v15.s[0] // broadcast column + dup v5.4s,v15.s[1] + dup v6.4s,v15.s[2] + dup v7.4s,v15.s[3] + + zip1 v2.4s, v4.4s, v5.4s + zip2 v3.4s, v6.4s, v7.4s +.endif + + // v8~v11 will anyway get set in MatrixA loading, so they are free to use now + movi v8.4s, #1 + movi v9.4s, #1 + movi v10.4s, #1 + movi v11.4s, #1 + + cbz x11,.L\Mode\().InitBlock\Columns\().x\Rows\().SkipScaleByZeroPointB + + ld1 {v4.4s},[x11],#16 // load ZeroPointB[0] + ld1 {v5.4s},[x11],#16 // load ZeroPointB[4] + + dup v6.4s, v4.s[0] + dup v7.4s, v4.s[1] + zip1 v8.4s, v6.4s, v7.4s + + dup v6.4s, v4.s[2] + dup v7.4s, v4.s[3] + zip1 v9.4s, v6.4s, v7.4s + + dup v6.4s, v5.s[0] + dup v7.4s, v5.s[1] + zip1 v10.4s, v6.4s, v7.4s + + dup v6.4s, v5.s[2] + dup v7.4s, v5.s[3] + zip1 v11.4s, v6.4s, v7.4s + +.L\Mode\().InitBlock\Columns\().x\Rows\().SkipScaleByZeroPointB: + dup v4.4s, v12.s[0] //boardcast RowSums + dup v5.4s, v12.s[1] + + uzp1 v6.2d, v4.2d, v5.2d + + InitRowAccumulators \Columns\(),16,17,18,19,6 +.if \Rows\() > 2 + dup v4.4s, v12.s[2] //boardcast RowSums + dup v5.4s, v12.s[3] + + uzp1 v6.2d, v4.2d, v5.2d + + InitRowAccumulators \Columns\(),20,21,22,23,6 +.endif +.if \Rows\() > 4 + dup v4.4s,v13.s[0] // broadcast row sums + dup v5.4s,v13.s[1] + + uzp1 v6.2d, v4.2d, v5.2d + + InitRowAccumulators \Columns\(),24,25,26,27,6 +.endif +.if \Rows\() > 6 + dup v4.4s,v13.s[2] // broadcast row sums + dup v5.4s,v13.s[3] + + uzp1 v6.2d, v4.2d, v5.2d + InitRowAccumulators \Columns\(),28,29,30,31,6 +.endif + + .endm + + +// LoadPackedMatrixABy16Elements +// +// Generates the code to load 16 elements from matrix A. +// + .macro LoadPackedMatrixABy16Elements Rows +.if \Rows\() == 1 + ldr q8,[x0],#8 +.else + ldr q8,[x0],#16 + +.if \Rows\() > 2 + ldr q9,[x0],#16 +.endif + +.if \Rows\() > 4 + ldr q10,[x0],#16 +.endif + +.if \Rows\() > 6 + ldr q11,[x0],#16 +.endif +.endif + .endm + + +// +// MultiplyAccumulateRow +// +// Generates the code to multiply and accumulate a single row of the output +// block. +// + + .macro MultiplyAccumulateRow Columns, MatrixAReg, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg + + smmla v\Vec1Reg\().4s, \MatrixAReg\().16b, v4.16b +.if \Columns\() > 2 + smmla v\Vec2Reg\().4s, \MatrixAReg\().16b, v5.16b +.endif +.if \Columns\() > 4 + smmla v\Vec3Reg\().4s, \MatrixAReg\().16b, v6.16b +.endif +.if \Columns\() > 6 + smmla v\Vec4Reg\().4s, \MatrixAReg\().16b, v7.16b +.endif + + .endm + +// +// MultiplyAccumulateBlock +// +// Generates the code to multiply and accumulate into the output block. +// + + .macro MultiplyAccumulateBlock Columns, Rows + + MultiplyAccumulateRow \Columns\(),v8,16,17,18,19 +.if \Rows\() > 2 + MultiplyAccumulateRow \Columns\(),v9,20,21,22,23 +.endif +.if \Rows\() > 4 + MultiplyAccumulateRow \Columns\(),v10,24,25,26,27 +.endif +.if \Rows\() > 6 + MultiplyAccumulateRow \Columns\(),v11,28,29,30,31 +.endif + + .endm + +// +// ComputeBlockLoop +// +// Generates the code to loop over K entries of the input matrices to produce +// the output block. +// + + .macro ComputeBlockLoop Mode, Columns, Rows + + InitBlockAccumulators \Mode\(), \Columns\(),\Rows\() + + sub x9,x3,#1 // block count to process + tbnz x9,#63,.L\Mode\().ProcessRemaining\Columns\().x\Rows\().Blocks + +.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4Loop: + + LoadPackedMatrixABy16Elements \Rows\() + ld1 {v4.16b - v7.16b}, [x1], #64 + MultiplyAccumulateBlock \Columns\(),\Rows\() + + sub x9,x9,#1 + tbz x9,#63,.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4Loop +.L\Mode\().ProcessRemaining\Columns\().x\Rows\().Blocks: + add x9,x9,#1 // correct for over-subtract above + cbz x9,.L\Mode\().Output\Columns\().x\Rows\().Block + +.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4PaddedLoop: + LoadPackedMatrixABy16Elements \Rows\() + ld1 {v4.16b - v7.16b}, [x1], #64 + MultiplyAccumulateBlock \Columns\(),\Rows\() + +.L\Mode\().Output\Columns\().x\Rows\().Block: + + .endm + + +// +// OutputRow2Element +// OutputRow4Element +// OutputRow6Element +// OutputRow8Element +// OutputRow10Element +// OutputRow12Element +// OutputRow14Element +// OutputRow16Element +// +// Generates the code to store elements to the output block. +// + + .macro OutputRow2Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr s8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr s9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + mov v8.S[2], v9.S[0] + add v8.4s,v8.4s,v\Vec1Reg\().4s + + mov w27, v8.S[0] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov w27, v8.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + mov w27, v\Vec1Reg\().S[0] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov w27, v\Vec1Reg\().S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.endif + + .endm + + + .macro OutputRow4Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr d8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr d9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + + mov v8.D[1], v9.D[0] + + add v8.4s,v8.4s,v\Vec1Reg\().4s + + mov x27, v8.D[0] + mov x28, v8.D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.else + mov x27, v\Vec1Reg\().D[0] + mov x28, v\Vec1Reg\().D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.endif + + .endm + + + .macro OutputRow6Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr d8,[\AddrReg1\()],#8 + ldr w28,[\AddrReg1\()],#-8 + mov v8.S[2], w28 +.if \last_row\() == 0 + ldr d9,[\AddrReg2\()],#8 + ldr w27,[\AddrReg2\()],#-8 + mov v9.S[2], w27 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + + mov x27, v8.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v8.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov x27, v9.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v9.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + mov x27, v4.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v4.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov x27, v5.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v5.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.endif + + .endm + + + .macro OutputRow8Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + +.endif + + .endm + + + .macro OutputRow10Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr w28, [\AddrReg1\()],#-16 + +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr w27,[\AddrReg2\()],#-16 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + mov v8.S[0], w28 + mov v8.S[2], w27 + + add v8.4s,v8.4s,v\Vec3Reg\().4s + + mov w27, v8.S[0] + mov w28, v8.S[2] + + str w27, [\AddrReg1\()],#4 +.if \last_row\() == 0 + str w28, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + mov w27, v\Vec3Reg\().S[0] + mov w28, v\Vec3Reg\().S[2] + + str w27, [\AddrReg1\()],#4 +.if \last_row\() == 0 + str w28, [\AddrReg2\()],#4 +.endif +.endif + +.endm + + + .macro OutputRow12Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr d10,[\AddrReg1\()],#-16 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr d11,[\AddrReg2\()],#-16 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 + mov v11.D[0],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + + mov v10.D[1], v11.D[0] + + add v10.4s,v10.4s,v\Vec3Reg\().4s + + mov x27, v10.D[0] + mov x28, v10.D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + mov x27, v\Vec3Reg\().D[0] + mov x28, v\Vec3Reg\().D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif +.endif + + .endm + + .macro OutputRow14Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr d10,[\AddrReg1\()],#8 + ldr w28, [\AddrReg1\()],#-24 + mov v10.S[2], w28 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr d11,[\AddrReg2\()],#8 + ldr w27,[\AddrReg2\()],#-24 + mov v11.S[2], w27 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 + + mov v11.D[0],x27 + mov v11.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + add v10.4s,v10.4s,v6.4s + add v11.4s,v11.4s,v7.4s + + str q8,[\AddrReg1\()],#16 + + mov x27, v10.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v10.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 + mov x27, v11.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v11.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + str q4,[\AddrReg1\()],#16 + mov x27, v6.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v6.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 + mov x27, v7.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v7.S[2] + str w27, [\AddrReg2\()],#4 +.endif +.endif + + .endm + + + .macro OutputRow16Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldp q8,q10,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldp q9,q11,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 + + mov v11.D[0],x27 + mov v11.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + add v10.4s,v10.4s,v6.4s + add v11.4s,v11.4s,v7.4s + + stp q8,q10,[\AddrReg1\()],#32 +.if \last_row\() == 0 + stp q9,q11,[\AddrReg2\()],#32 +.endif +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + stp q4,q6,[\AddrReg1\()],#32 +.if \last_row\() == 0 + stp q5,q7,[\AddrReg2\()],#32 +.endif +.endif + + .endm + +// +// OutputBlock +// +// Generates the code to store the output block. +// + + .macro OutputBlock Mode, Columns, Rows + + OutputRow\Columns\()Element \Mode\(),x2,x13,16,17,18,19,(\Rows\() == 1) + +.if \Rows\() > 2 + OutputRow\Columns\()Element \Mode\(),x14,x15,20,21,22,23,(\Rows\() == 3) +.endif + +.if \Rows\() > 4 + OutputRow\Columns\()Element \Mode\(),x16,x17,24,25,26,27,(\Rows\() == 5) +.endif + +.if \Rows\() > 6 + OutputRow\Columns\()Element \Mode\(),x18,x19,28,29,30,31,(\Rows\() == 7) +.endif + + .endm +// +// ProcessRows +// +// Generates the code to process a compute and store the output block for a +// fixed number of rows. +// + + .macro ProcessRows Mode, Rows + mov x4,#\Rows\() // return number of rows handled + cmp x5,#6 + ble .L\Mode\().ProcessNextColumnLoop6x\Rows\() + +.L\Mode\().ProcessNextColumnLoop8x\Rows\(): + ComputeBlockLoop \Mode\(),8,\Rows\() + + sub x5,x5,#8 + cmp x5,#0 + blt .L\Mode\().Output14ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),16,\Rows\() + mov x0,x8 // reload matrix A + cmp x5,#6 + bgt .L\Mode\().ProcessNextColumnLoop8x\Rows\() + cbz x5,.L\Mode\().ExitKernel + +.L\Mode\().ProcessNextColumnLoop6x\Rows\(): + + cmp x5,#4 + ble .L\Mode\().ProcessNextColumnLoop4x\Rows\() + ComputeBlockLoop \Mode\(),6,\Rows\() + sub x5,x5,#6 + cmp x5,#0 + blt .L\Mode\().Output10ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),12,\Rows\() + mov x0,x8 // reload matrix A + cmp x5,#4 + bgt .L\Mode\().ProcessNextColumnLoop6x\Rows\() + b .L\Mode\().ExitKernel + +.L\Mode\().ProcessNextColumnLoop4x\Rows\(): + cmp x5,#2 + ble .L\Mode\().ProcessNextColumnLoop2x\Rows\() + ComputeBlockLoop \Mode\(),4,\Rows\() + sub x5,x5,#4 + cmp x5,#0 + blt .L\Mode\().Output6ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),8,\Rows\() + mov x0,x8 // reload matrix A + cmp x5,#2 + bgt .L\Mode\().ProcessNextColumnLoop4x\Rows\() + b .L\Mode\().ExitKernel + +.L\Mode\().ProcessNextColumnLoop2x\Rows\(): + ComputeBlockLoop \Mode\(),2,\Rows\() + sub x5,x5,#2 + cmp x5,#0 + blt .L\Mode\().Output2ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),4,\Rows\() + mov x0,x8 // reload matrix A + cmp x5,#2 + b .L\Mode\().ExitKernel + +.L\Mode\().Output14ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),14,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output10ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),10,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output6ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),6,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output2ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),2,\Rows\() + b .L\Mode\().ExitKernel + + .endm + + +/*++ + +Routine Description: + + This routine is an inner kernel to compute matrix multiplication for a + set of rows. + +Arguments: + + A (x0) - Supplies the address of matrix A. The matrix data has been packed + using MlasGemmQuantCopyPackA. + + B (x1) - Supplies the address of matrix B. The matrix data has been packed + using MlasGemmQuantCopyPackB. + + C (x2) - Supplies the address of matrix C. + + PackedCountK (x3) - Supplies the number of packed columns from matrix A and + the number of packed rows from matrix B to iterate over. + + CountM (x4) - Supplies the maximum number of rows that can be processed for + matrix A and matrix C. The actual number of rows handled for this + invocation depends on the kernel implementation. + + CountN (x5) - Supplies the number of columns from matrix B and matrix C to + iterate over. + + ldc (x6) - Supplies the first dimension of matrix C. + + RowSumBuffer (x7) - Supplies the sum of each row from matrix A. These values + have been pre-scaled by the zero point offset of matrix B if the offset + is per-tensor (ZeroPointB is nullptr). Otherwise, these values must be + scaled by the per-column zero point offsets of matrix B. These values are + accumulated into every row of matrix C. + + ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied + by the zero point offset of matrix A. These values are accumulated into + every column of matrix C. + + ZeroPointB - Optionally supplies the per-column zero point offsets of matrix + B, else nullptr if the matrix B is using per-tensor quantization. + +Return Value: + + Returns the number of rows handled. + +--*/ + + .macro QgemmS8S8KernelSmmlaFunction Mode + + FUNCTION_ENTRY MlasGemmS8S8KernelSmmla\Mode\() + + ldr x10,[sp, #0] + ldr x11,[sp,#8] + + stp x19, x20, [sp, #.LMlasQgemmKernel_SavedRegisters_Neg]! + stp x21, x22, [sp, #.LMlasQgemmKernel_backup_x21_x22] + stp x23, x24, [sp, #.LMlasQgemmKernel_backup_x23_x24] + stp x25, x26, [sp, #.LMlasQgemmKernel_backup_x25_x26] + stp x27, x28, [sp, #.LMlasQgemmKernel_backup_x27_x28] + stp d8, d9, [sp, #.LMlasQgemmKernel_backup_d8_d9] + stp d10, d11, [sp, #.LMlasQgemmKernel_backup_d10_d11] + stp d12, d13, [sp, #.LMlasQgemmKernel_backup_d12_d13] + stp d14, d15, [sp, #.LMlasQgemmKernel_backup_d14_d15] + + add x13,x2,x6,lsl #2 // compute matrix C plus 1 row + add x14,x13,x6,lsl #2 // compute matrix C plus 2 rows + add x15,x14,x6,lsl #2 // compute matrix C plus 3 rows + add x16,x15,x6,lsl #2 // compute matrix C plus 4 rows + add x17,x16,x6,lsl #2 // compute matrix C plus 5 rows + add x18,x17,x6,lsl #2 // compute matrix C plus 6 rows + add x19,x18,x6,lsl #2 // compute matrix C plus 7 rows + + mov x8,x0 // save matrix A + +// +// Process 8 rows of the matrices. +// + ld1 {v12.4s},[x7],#16 // load row sum 1 ~ 4 + cmp x4,#8 + blt .L\Mode\().ProcessCountMLessThan8 + ld1 {v13.4s},[x7],#16 // load row sum 5 ~ 8 + ProcessRows \Mode\(),8 + +// +// Restore non-volatile registers and return. +// + +.L\Mode\().ExitKernel: + mov x0,x4 + + ldp d14, d15, [sp, #.LMlasQgemmKernel_backup_d14_d15] + ldp d12, d13, [sp, #.LMlasQgemmKernel_backup_d12_d13] + ldp d10, d11, [sp, #.LMlasQgemmKernel_backup_d10_d11] + ldp d8, d9, [sp, #.LMlasQgemmKernel_backup_d8_d9] + ldp x27, x28, [sp, #.LMlasQgemmKernel_backup_x27_x28] + ldp x25, x26, [sp, #.LMlasQgemmKernel_backup_x25_x26] + ldp x23, x24, [sp, #.LMlasQgemmKernel_backup_x23_x24] + ldp x21, x22, [sp, #.LMlasQgemmKernel_backup_x21_x22] + ldp x19, x20, [sp], #.LMlasQgemmKernel_SavedRegisters + + ret + +// +// Process 4 rows of the matrix. +// + +.L\Mode\().ProcessCountMLessThan8: + cmp x4,#4 + blt .L\Mode\().ProcessCountMLessThan4 + ProcessRows \Mode\(),4 + b .L\Mode\().ExitKernel + +// +// Process 2 row of the matrix. +// + +.L\Mode\().ProcessCountMLessThan4: + cmp x4,#2 + blt .L\Mode\().ProcessCountMLessThan2 + + ProcessRows \Mode\(),2 + b .L\Mode\().ExitKernel + + +// +// Process the last row of the matrix. +// + +.L\Mode\().ProcessCountMLessThan2: + ProcessRows \Mode\(),1 + b .L\Mode\().ExitKernel + + + .endm + + QgemmS8S8KernelSmmlaFunction Zero + QgemmS8S8KernelSmmlaFunction Add + + .end diff --git a/onnxruntime/core/mlas/lib/aarch64/QgemmU8X8KernelUmmla.S b/onnxruntime/core/mlas/lib/aarch64/QgemmU8X8KernelUmmla.S new file mode 100644 index 0000000000000..baf6e21e6ff06 --- /dev/null +++ b/onnxruntime/core/mlas/lib/aarch64/QgemmU8X8KernelUmmla.S @@ -0,0 +1,922 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the MIT License. + +Module Name: + + QgemmU8X8KernelUmmla.s + +Abstract: + + This module implements the kernels for the Int8 precision matrix/matrix + multiply operation (QGEMM). + +--*/ + +#include "asmmacro.h" + + .text + +// +// Stack frame layout for the ummla kernel. d8-d15, x19-x30 need save +// + .equ .LMlasQgemmKernel_backup_x19_x20, 0 + .equ .LMlasQgemmKernel_backup_x21_x22, 16 + .equ .LMlasQgemmKernel_backup_x23_x24, 32 + .equ .LMlasQgemmKernel_backup_x25_x26, 48 + .equ .LMlasQgemmKernel_backup_x27_x28, 64 + .equ .LMlasQgemmKernel_backup_d8_d9, 80 + .equ .LMlasQgemmKernel_backup_d10_d11, 96 + .equ .LMlasQgemmKernel_backup_d12_d13, 112 + .equ .LMlasQgemmKernel_backup_d14_d15, 128 + .equ .LMlasQgemmKernel_SavedRegisters, 144 + .equ .LMlasQgemmKernel_SavedRegisters_Neg, -144 + + +// +// Init Row Accumulators +// +// Generates the code to initialize the accumulators for a single row of the output +// block. +// +// +// Accumulators are initialized to ZeroPointB * RowSum + ColumnSum +// x7 for RowSumsBuffer pointer +// x10 for ColumnSumBuffer pointer +// x11 for ZeroPointB buffer pointer +// +// v12~v13 for RowSums values +// v14~v15 for ColumnSums values +// v0~v3 for ZeroPointB values +// + .macro InitRowAccumulators Columns, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, RowSumReg + + mul v7.4s, v\RowSumReg\().4s, v8.4s + mov v\Vec1Reg\().16b, v7.16b + add v\Vec1Reg\().4s, v\Vec1Reg\().4s, v0.4s +.if \Columns\() > 2 + mul v7.4s, v\RowSumReg\().4s, v9.4s + mov v\Vec2Reg\().16b, v7.16b + add v\Vec2Reg\().4s, v\Vec2Reg\().4s, v1.4s +.endif +.if \Columns\() > 4 + mul v7.4s, v\RowSumReg\().4s, v10.4s + mov v\Vec3Reg\().16b, v7.16b + add v\Vec3Reg\().4s, v\Vec3Reg\().4s, v2.4s +.endif +.if \Columns\() > 6 + mul v7.4s, v\RowSumReg\().4s, v11.4s + mov v\Vec4Reg\().16b, v7.16b + add v\Vec4Reg\().4s, v\Vec4Reg\().4s, v3.4s +.endif + + .endm + + +// +// InitBlockAccumulators +// +// Generates the code to initialize the accumulators for 8x8 output +// block. +// + .macro InitBlockAccumulators Mode, Columns, Rows + + ld1 {v14.4s},[x10],#16 // load ColumnSumBuffer[0] +.if \Columns\() > 4 + ld1 {v15.4s},[x10],#16 // load ColumnSumBuffer[4] +.endif + // v4~v7 will be set to matrixB after this, so, they can used now + dup v4.4s,v14.s[0] // broadcast column + dup v5.4s,v14.s[1] + dup v6.4s,v14.s[2] + dup v7.4s,v14.s[3] + + zip1 v0.4s, v4.4s, v5.4s + zip2 v1.4s, v6.4s, v7.4s +.if \Columns\() > 4 + dup v4.4s,v15.s[0] // broadcast column + dup v5.4s,v15.s[1] + dup v6.4s,v15.s[2] + dup v7.4s,v15.s[3] + + zip1 v2.4s, v4.4s, v5.4s + zip2 v3.4s, v6.4s, v7.4s +.endif + + // v8~v11 will anyway get set in MatrixA loading, so they are free to use now + movi v8.4s, #1 + movi v9.4s, #1 + movi v10.4s, #1 + movi v11.4s, #1 + + cbz x11,.L\Mode\().InitBlock\Columns\().x\Rows\().SkipScaleByZeroPointB + + ld1 {v4.4s},[x11],#16 // load ZeroPointB[0] + ld1 {v5.4s},[x11],#16 // load ZeroPointB[4] + + dup v6.4s, v4.s[0] + dup v7.4s, v4.s[1] + zip1 v8.4s, v6.4s, v7.4s + + dup v6.4s, v4.s[2] + dup v7.4s, v4.s[3] + zip1 v9.4s, v6.4s, v7.4s + + dup v6.4s, v5.s[0] + dup v7.4s, v5.s[1] + zip1 v10.4s, v6.4s, v7.4s + + dup v6.4s, v5.s[2] + dup v7.4s, v5.s[3] + zip1 v11.4s, v6.4s, v7.4s + +.L\Mode\().InitBlock\Columns\().x\Rows\().SkipScaleByZeroPointB: + dup v4.4s, v12.s[0] //boardcast RowSums + dup v5.4s, v12.s[1] + + uzp1 v6.2d, v4.2d, v5.2d + + InitRowAccumulators \Columns\(),16,17,18,19,6 +.if \Rows\() > 2 + dup v4.4s, v12.s[2] //boardcast RowSums + dup v5.4s, v12.s[3] + + uzp1 v6.2d, v4.2d, v5.2d + + InitRowAccumulators \Columns\(),20,21,22,23,6 +.endif +.if \Rows\() > 4 + dup v4.4s,v13.s[0] // broadcast row sums + dup v5.4s,v13.s[1] + + uzp1 v6.2d, v4.2d, v5.2d + + InitRowAccumulators \Columns\(),24,25,26,27,6 +.endif +.if \Rows\() > 6 + dup v4.4s,v13.s[2] // broadcast row sums + dup v5.4s,v13.s[3] + + uzp1 v6.2d, v4.2d, v5.2d + InitRowAccumulators \Columns\(),28,29,30,31,6 +.endif + + .endm + + +// LoadPackedMatrixABy16Elements +// +// Generates the code to load 16 elements from matrix A. +// + .macro LoadPackedMatrixABy16Elements Rows +.if \Rows\() == 1 + ldr q8,[x0],#8 +.else + ldr q8,[x0],#16 + +.if \Rows\() > 2 + ldr q9,[x0],#16 +.endif + +.if \Rows\() > 4 + ldr q10,[x0],#16 +.endif + +.if \Rows\() > 6 + ldr q11,[x0],#16 +.endif +.endif + .endm + + +// +// MultiplyAccumulateRow +// +// Generates the code to multiply and accumulate a single row of the output +// block. +// + + .macro MultiplyAccumulateRow Columns, MatrixAReg, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg + + ummla v\Vec1Reg\().4s, \MatrixAReg\().16b, v4.16b +.if \Columns\() > 2 + ummla v\Vec2Reg\().4s, \MatrixAReg\().16b, v5.16b +.endif +.if \Columns\() > 4 + ummla v\Vec3Reg\().4s, \MatrixAReg\().16b, v6.16b +.endif +.if \Columns\() > 6 + ummla v\Vec4Reg\().4s, \MatrixAReg\().16b, v7.16b +.endif + + .endm + +// +// MultiplyAccumulateBlock +// +// Generates the code to multiply and accumulate into the output block. +// + + .macro MultiplyAccumulateBlock Columns, Rows + + MultiplyAccumulateRow \Columns\(),v8,16,17,18,19 +.if \Rows\() > 2 + MultiplyAccumulateRow \Columns\(),v9,20,21,22,23 +.endif +.if \Rows\() > 4 + MultiplyAccumulateRow \Columns\(),v10,24,25,26,27 +.endif +.if \Rows\() > 6 + MultiplyAccumulateRow \Columns\(),v11,28,29,30,31 +.endif + + .endm + +// +// ComputeBlockLoop +// +// Generates the code to loop over K entries of the input matrices to produce +// the output block. +// + + .macro ComputeBlockLoop Mode, Columns, Rows + + InitBlockAccumulators \Mode\(), \Columns\(),\Rows\() + + sub x9,x3,#1 // block count to process + tbnz x9,#63,.L\Mode\().ProcessRemaining\Columns\().x\Rows\().Blocks + +.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4Loop: + + LoadPackedMatrixABy16Elements \Rows\() + ld1 {v4.16b - v7.16b}, [x1], #64 + MultiplyAccumulateBlock \Columns\(),\Rows\() + + sub x9,x9,#1 + tbz x9,#63,.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4Loop +.L\Mode\().ProcessRemaining\Columns\().x\Rows\().Blocks: + add x9,x9,#1 // correct for over-subtract above + cbz x9,.L\Mode\().Output\Columns\().x\Rows\().Block + +.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4PaddedLoop: + LoadPackedMatrixABy16Elements \Rows\() + ld1 {v4.16b - v7.16b}, [x1], #64 + MultiplyAccumulateBlock \Columns\(),\Rows\() + +.L\Mode\().Output\Columns\().x\Rows\().Block: + + .endm + + +// +// OutputRow2Element +// OutputRow4Element +// OutputRow6Element +// OutputRow8Element +// OutputRow10Element +// OutputRow12Element +// OutputRow14Element +// OutputRow16Element +// +// Generates the code to store elements to the output block. +// + + .macro OutputRow2Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr s8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr s9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + mov v8.S[2], v9.S[0] + add v8.4s,v8.4s,v\Vec1Reg\().4s + + mov w27, v8.S[0] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov w27, v8.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + mov w27, v\Vec1Reg\().S[0] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov w27, v\Vec1Reg\().S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.endif + + .endm + + + .macro OutputRow4Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr d8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr d9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + + mov v8.D[1], v9.D[0] + + add v8.4s,v8.4s,v\Vec1Reg\().4s + + mov x27, v8.D[0] + mov x28, v8.D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.else + mov x27, v\Vec1Reg\().D[0] + mov x28, v\Vec1Reg\().D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.endif + + .endm + + + .macro OutputRow6Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr d8,[\AddrReg1\()],#8 + ldr w28,[\AddrReg1\()],#-8 + mov v8.S[2], w28 +.if \last_row\() == 0 + ldr d9,[\AddrReg2\()],#8 + ldr w27,[\AddrReg2\()],#-8 + mov v9.S[2], w27 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + + mov x27, v8.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v8.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov x27, v9.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v9.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + mov x27, v4.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v4.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov x27, v5.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v5.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.endif + + .endm + + + .macro OutputRow8Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + +.endif + + .endm + + + .macro OutputRow10Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr w28, [\AddrReg1\()],#-16 + +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr w27,[\AddrReg2\()],#-16 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + mov v8.S[0], w28 + mov v8.S[2], w27 + + add v8.4s,v8.4s,v\Vec3Reg\().4s + + mov w27, v8.S[0] + mov w28, v8.S[2] + + str w27, [\AddrReg1\()],#4 +.if \last_row\() == 0 + str w28, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + mov w27, v\Vec3Reg\().S[0] + mov w28, v\Vec3Reg\().S[2] + + str w27, [\AddrReg1\()],#4 +.if \last_row\() == 0 + str w28, [\AddrReg2\()],#4 +.endif +.endif + +.endm + + + .macro OutputRow12Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr d10,[\AddrReg1\()],#-16 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr d11,[\AddrReg2\()],#-16 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 + mov v11.D[0],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + + mov v10.D[1], v11.D[0] + + add v10.4s,v10.4s,v\Vec3Reg\().4s + + mov x27, v10.D[0] + mov x28, v10.D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + mov x27, v\Vec3Reg\().D[0] + mov x28, v\Vec3Reg\().D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif +.endif + + .endm + + .macro OutputRow14Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr d10,[\AddrReg1\()],#8 + ldr w28, [\AddrReg1\()],#-24 + mov v10.S[2], w28 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr d11,[\AddrReg2\()],#8 + ldr w27,[\AddrReg2\()],#-24 + mov v11.S[2], w27 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 + + mov v11.D[0],x27 + mov v11.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + add v10.4s,v10.4s,v6.4s + add v11.4s,v11.4s,v7.4s + + str q8,[\AddrReg1\()],#16 + + mov x27, v10.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v10.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 + mov x27, v11.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v11.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + str q4,[\AddrReg1\()],#16 + mov x27, v6.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v6.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 + mov x27, v7.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v7.S[2] + str w27, [\AddrReg2\()],#4 +.endif +.endif + + .endm + + + .macro OutputRow16Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldp q8,q10,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldp q9,q11,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 + + mov v11.D[0],x27 + mov v11.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + add v10.4s,v10.4s,v6.4s + add v11.4s,v11.4s,v7.4s + + stp q8,q10,[\AddrReg1\()],#32 +.if \last_row\() == 0 + stp q9,q11,[\AddrReg2\()],#32 +.endif +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + stp q4,q6,[\AddrReg1\()],#32 +.if \last_row\() == 0 + stp q5,q7,[\AddrReg2\()],#32 +.endif +.endif + + .endm + +// +// OutputBlock +// +// Generates the code to store the output block. +// + + .macro OutputBlock Mode, Columns, Rows + + OutputRow\Columns\()Element \Mode\(),x2,x13,16,17,18,19,(\Rows\() == 1) + +.if \Rows\() > 2 + OutputRow\Columns\()Element \Mode\(),x14,x15,20,21,22,23,(\Rows\() == 3) +.endif + +.if \Rows\() > 4 + OutputRow\Columns\()Element \Mode\(),x16,x17,24,25,26,27,(\Rows\() == 5) +.endif + +.if \Rows\() > 6 + OutputRow\Columns\()Element \Mode\(),x18,x19,28,29,30,31,(\Rows\() == 7) +.endif + + .endm +// +// ProcessRows +// +// Generates the code to process a compute and store the output block for a +// fixed number of rows. +// + + .macro ProcessRows Mode, Rows + mov x4,#\Rows\() // return number of rows handled + cmp x5,#6 + ble .L\Mode\().ProcessNextColumnLoop6x\Rows\() + +.L\Mode\().ProcessNextColumnLoop8x\Rows\(): + ComputeBlockLoop \Mode\(),8,\Rows\() + + sub x5,x5,#8 + cmp x5,#0 + blt .L\Mode\().Output14ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),16,\Rows\() + mov x0,x8 // reload matrix A + cmp x5,#6 + bgt .L\Mode\().ProcessNextColumnLoop8x\Rows\() + cbz x5,.L\Mode\().ExitKernel + +.L\Mode\().ProcessNextColumnLoop6x\Rows\(): + + cmp x5,#4 + ble .L\Mode\().ProcessNextColumnLoop4x\Rows\() + ComputeBlockLoop \Mode\(),6,\Rows\() + sub x5,x5,#6 + cmp x5,#0 + blt .L\Mode\().Output10ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),12,\Rows\() + mov x0,x8 // reload matrix A + cmp x5,#4 + bgt .L\Mode\().ProcessNextColumnLoop6x\Rows\() + b .L\Mode\().ExitKernel + +.L\Mode\().ProcessNextColumnLoop4x\Rows\(): + cmp x5,#2 + ble .L\Mode\().ProcessNextColumnLoop2x\Rows\() + ComputeBlockLoop \Mode\(),4,\Rows\() + sub x5,x5,#4 + cmp x5,#0 + blt .L\Mode\().Output6ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),8,\Rows\() + mov x0,x8 // reload matrix A + cmp x5,#2 + bgt .L\Mode\().ProcessNextColumnLoop4x\Rows\() + b .L\Mode\().ExitKernel + +.L\Mode\().ProcessNextColumnLoop2x\Rows\(): + ComputeBlockLoop \Mode\(),2,\Rows\() + sub x5,x5,#2 + cmp x5,#0 + blt .L\Mode\().Output2ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),4,\Rows\() + mov x0,x8 // reload matrix A + cmp x5,#2 + b .L\Mode\().ExitKernel + +.L\Mode\().Output14ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),14,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output10ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),10,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output6ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),6,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output2ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),2,\Rows\() + b .L\Mode\().ExitKernel + + .endm + + +/*++ + +Routine Description: + + This routine is an inner kernel to compute matrix multiplication for a + set of rows. + +Arguments: + + A (x0) - Supplies the address of matrix A. The matrix data has been packed + using MlasGemmQuantCopyPackA. + + B (x1) - Supplies the address of matrix B. The matrix data has been packed + using MlasGemmQuantCopyPackB. + + C (x2) - Supplies the address of matrix C. + + PackedCountK (x3) - Supplies the number of packed columns from matrix A and + the number of packed rows from matrix B to iterate over. + + CountM (x4) - Supplies the maximum number of rows that can be processed for + matrix A and matrix C. The actual number of rows handled for this + invocation depends on the kernel implementation. + + CountN (x5) - Supplies the number of columns from matrix B and matrix C to + iterate over. + + ldc (x6) - Supplies the first dimension of matrix C. + + RowSumBuffer (x7) - Supplies the sum of each row from matrix A. These values + have been pre-scaled by the zero point offset of matrix B if the offset + is per-tensor (ZeroPointB is nullptr). Otherwise, these values must be + scaled by the per-column zero point offsets of matrix B. These values are + accumulated into every row of matrix C. + + ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied + by the zero point offset of matrix A. These values are accumulated into + every column of matrix C. + + ZeroPointB - Optionally supplies the per-column zero point offsets of matrix + B, else nullptr if the matrix B is using per-tensor quantization. + +Return Value: + + Returns the number of rows handled. + +--*/ + + .macro QgemmU8X8KernelUmmlaFunction Mode + + FUNCTION_ENTRY MlasGemmU8X8KernelUmmla\Mode\() + + ldr x10,[sp, #0] + ldr x11,[sp,#8] + + stp x19, x20, [sp, #.LMlasQgemmKernel_SavedRegisters_Neg]! + stp x21, x22, [sp, #.LMlasQgemmKernel_backup_x21_x22] + stp x23, x24, [sp, #.LMlasQgemmKernel_backup_x23_x24] + stp x25, x26, [sp, #.LMlasQgemmKernel_backup_x25_x26] + stp x27, x28, [sp, #.LMlasQgemmKernel_backup_x27_x28] + stp d8, d9, [sp, #.LMlasQgemmKernel_backup_d8_d9] + stp d10, d11, [sp, #.LMlasQgemmKernel_backup_d10_d11] + stp d12, d13, [sp, #.LMlasQgemmKernel_backup_d12_d13] + stp d14, d15, [sp, #.LMlasQgemmKernel_backup_d14_d15] + + add x13,x2,x6,lsl #2 // compute matrix C plus 1 row + add x14,x13,x6,lsl #2 // compute matrix C plus 2 rows + add x15,x14,x6,lsl #2 // compute matrix C plus 3 rows + add x16,x15,x6,lsl #2 // compute matrix C plus 4 rows + add x17,x16,x6,lsl #2 // compute matrix C plus 5 rows + add x18,x17,x6,lsl #2 // compute matrix C plus 6 rows + add x19,x18,x6,lsl #2 // compute matrix C plus 7 rows + + mov x8,x0 // save matrix A + +// +// Process 8 rows of the matrices. +// + ld1 {v12.4s},[x7],#16 // load row sum 1 ~ 4 + cmp x4,#8 + blt .L\Mode\().ProcessCountMLessThan8 + ld1 {v13.4s},[x7],#16 // load row sum 5 ~ 8 + ProcessRows \Mode\(),8 + +// +// Restore non-volatile registers and return. +// + +.L\Mode\().ExitKernel: + mov x0,x4 + + ldp d14, d15, [sp, #.LMlasQgemmKernel_backup_d14_d15] + ldp d12, d13, [sp, #.LMlasQgemmKernel_backup_d12_d13] + ldp d10, d11, [sp, #.LMlasQgemmKernel_backup_d10_d11] + ldp d8, d9, [sp, #.LMlasQgemmKernel_backup_d8_d9] + ldp x27, x28, [sp, #.LMlasQgemmKernel_backup_x27_x28] + ldp x25, x26, [sp, #.LMlasQgemmKernel_backup_x25_x26] + ldp x23, x24, [sp, #.LMlasQgemmKernel_backup_x23_x24] + ldp x21, x22, [sp, #.LMlasQgemmKernel_backup_x21_x22] + ldp x19, x20, [sp], #.LMlasQgemmKernel_SavedRegisters + + ret + +// +// Process 4 rows of the matrix. +// + +.L\Mode\().ProcessCountMLessThan8: + cmp x4,#4 + blt .L\Mode\().ProcessCountMLessThan4 + ProcessRows \Mode\(),4 + b .L\Mode\().ExitKernel + +// +// Process 2 row of the matrix. +// + +.L\Mode\().ProcessCountMLessThan4: + cmp x4,#2 + blt .L\Mode\().ProcessCountMLessThan2 + + ProcessRows \Mode\(),2 + b .L\Mode\().ExitKernel + + +// +// Process the last row of the matrix. +// + +.L\Mode\().ProcessCountMLessThan2: + ProcessRows \Mode\(),1 + b .L\Mode\().ExitKernel + + + .endm + + QgemmU8X8KernelUmmlaFunction Zero + QgemmU8X8KernelUmmlaFunction Add + + .end diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index b6ac4a1ca1d6c..2fdb0dda5d25c 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -184,11 +184,17 @@ class MLASCPUIDInfo bool IsCurrentCoreArmv8NarrowLd() const { return false; } + bool HasArmNeon_I8MM() const { return has_arm_neon_i8mm_; } + + bool HasArmSVE_I8MM() const { return has_arm_sve_i8mm_; } + private: MLASCPUIDInfo(); bool has_arm_neon_dot_{false}; bool has_fp16_{false}; + bool has_arm_neon_i8mm_{false}; + bool has_arm_sve_i8mm_{false}; }; using MLAS_CPUIDINFO = MLASCPUIDInfo; @@ -856,6 +862,8 @@ extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchNeon; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmX8S8DispatchNeon; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchUdot; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmS8S8DispatchSdot; +extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchUmmla; +extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmS8S8DispatchSmmla; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchWasmSimd; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmQuantDispatchDefault; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemm8X8DispatchPOWER10; diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 3c0f82408179b..39586282e00ad 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -52,6 +52,14 @@ MLASCPUIDInfo::MLASCPUIDInfo() #define HWCAP_ASIMDDP (1 << 20) #endif +#ifndef HWCAP2_I8MM +#define HWCAP2_I8MM (1 << 13) +#endif + +#ifndef HWCAP2_SVEI8MM +#define HWCAP2_SVEI8MM (1 << 9) +#endif + #if defined(BUILD_MLAS_NO_ONNXRUNTIME) MLASCPUIDInfo::MLASCPUIDInfo() { @@ -59,6 +67,9 @@ MLASCPUIDInfo::MLASCPUIDInfo() // raw hack! Need CPUIDInfo implementation for more precise detection has_fp16_ = has_arm_neon_dot_; + + has_arm_neon_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_I8MM) != 0); + has_arm_sve_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_SVEI8MM) != 0); } #endif @@ -480,6 +491,17 @@ Return Value: this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchDot; } +#if defined(__linux__) + // + // Check if the processor supports ASIMD I8MM instructions. + // + if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeon_I8MM()) { + this->GemmU8U8Dispatch = &MlasGemmU8X8DispatchUmmla; + this->GemmU8S8Dispatch = &MlasGemmU8X8DispatchUmmla; + this->GemmS8S8Dispatch = &MlasGemmS8S8DispatchSmmla; + } +#endif + #endif // MLAS_TARGET_ARM64 #if defined(MLAS_TARGET_POWER) this->GemmFloatKernel = MlasSgemmKernel; diff --git a/onnxruntime/core/mlas/lib/qgemm_kernel_smmla.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_smmla.cpp new file mode 100644 index 0000000000000..c41f43ca22d18 --- /dev/null +++ b/onnxruntime/core/mlas/lib/qgemm_kernel_smmla.cpp @@ -0,0 +1,964 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the MIT License. + +Module Name: + + qgemm_kernel_smmla.cpp + +Abstract: + + This module implements smmla QGEMM kernel. + +--*/ + +#include "mlasi.h" +#include "qgemm.h" + +// +// Define the prototypes of the NEON SMMLA routines written in assembly. +// + +extern "C" { + +size_t MLASCALL +MlasGemmS8S8KernelSmmlaZero(const uint8_t* A, + const uint8_t* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumVector, + const int32_t* ColumnSumVector, + const int32_t* ZeroPointB); + +size_t MLASCALL +MlasGemmS8S8KernelSmmlaAdd(const uint8_t* A, + const uint8_t* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumVector, + const int32_t* ColumnSumVector, + const int32_t* ZeroPointB); +} + +struct MLAS_GEMM_S8S8_KERNEL_SMMLA { + typedef uint8_t PackedAType; + typedef uint8_t PackedBType; + typedef int8_t OffsetAType; + typedef int8_t OffsetBType; + + static constexpr size_t PackedK = 8; + static constexpr MLAS_GEMM_QUANT_STRIDES Strides{24, 128, 256}; + static constexpr MLAS_GEMM_QUANT_STRIDES PackedStrides{24, 128, 384}; +}; + +constexpr size_t MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedK; +constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_S8S8_KERNEL_SMMLA::Strides; +constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedStrides; + +template <> +MLAS_FORCEINLINE int32_t +MlasGemmQuantFixupZeroPointB(int32_t ZeroPointB, bool BIsSigned) +{ + MLAS_UNREFERENCED_PARAMETER(BIsSigned); + return ZeroPointB; +} + +template <> +void +MlasGemmQuantCopyPackA( + MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedAType* D_uint8_t, + const uint8_t* A, + size_t lda, + size_t CountM, + size_t CountK, + int32_t* RowSumBuffer, + bool AIsSigned) +{ + int8_t* D = reinterpret_cast(D_uint8_t); + MLAS_UNREFERENCED_PARAMETER(AIsSigned); + int8_t PaddedMatrixAData[64]; + + // + // Process 8 rows of matrix A. + // + // MMLA kernels load 8x8 block of A with four vector registers. So A is packed + // a series of 64 byte vectors where eight rows are interleaved with the + // following pattern: + // + // [ A0 A1 A2 A3 A4 A5 A6 A7 ] + // [ B0 B1 B2 B3 B4 B5 B6 B7 ] + // [ C0 C1 C2 C3 C4 C5 C6 C7 ] + // [ D0 D1 D2 D3 D4 D5 D6 D7 ] + // [ E0 E1 E2 E3 E4 E5 E6 E7 ] + // [ F0 F1 F2 F3 F4 F5 F6 F7 ] + // [ G0 G1 G2 G3 G4 G5 G6 G7 ] + // [ H0 H1 H2 H3 H4 H5 H6 H7 ] + // + // ... + // + // This pattern is repeated (CountK / 8) times. + // + // If CountK is not aligned to a multiple of eight, then the vector is padded + // with zeroes. + // + + while (CountM >= 8) { + const int8_t* a0 = reinterpret_cast(A); + const int8_t* a1 = a0 + lda; + const int8_t* a2 = a0 + lda * 2; + const int8_t* a3 = a0 + lda * 3; + const int8_t* a4 = a0 + lda * 4; + const int8_t* a5 = a0 + lda * 5; + const int8_t* a6 = a0 + lda * 6; + const int8_t* a7 = a0 + lda * 7; + + size_t k = CountK; + int32x4_t RowSums0 = vmovq_n_s32(0); + int32x4_t RowSums1 = vmovq_n_s32(0); + + while (k >= 16) { + int64x2_t v0 = vld1q_s64(reinterpret_cast(a0)); + a0 += 16; + int64x2_t v1 = vld1q_s64(reinterpret_cast(a1)); + a1 += 16; + int64x2_t v2 = vld1q_s64(reinterpret_cast(a2)); + a2 += 16; + int64x2_t v3 = vld1q_s64(reinterpret_cast(a3)); + a3 += 16; + int64x2_t v4 = vld1q_s64(reinterpret_cast(a4)); + a4 += 16; + int64x2_t v5 = vld1q_s64(reinterpret_cast(a5)); + a5 += 16; + int64x2_t v6 = vld1q_s64(reinterpret_cast(a6)); + a6 += 16; + int64x2_t v7 = vld1q_s64(reinterpret_cast(a7)); + a7 += 16; + + int64x2_t z0 = vzip1q_s64(v0, v1); + int64x2_t z1 = vzip2q_s64(v0, v1); + int64x2_t z2 = vzip1q_s64(v2, v3); + int64x2_t z3 = vzip2q_s64(v2, v3); + + int64x2_t z4 = vzip1q_s64(v4, v5); + int64x2_t z5 = vzip2q_s64(v4, v5); + int64x2_t z6 = vzip1q_s64(v6, v7); + int64x2_t z7 = vzip2q_s64(v6, v7); + + vst1q_s8(&D[0], vreinterpretq_s8_s64(z0)); + vst1q_s8(&D[16], vreinterpretq_s8_s64(z2)); + vst1q_s8(&D[32], vreinterpretq_s8_s64(z4)); + vst1q_s8(&D[48], vreinterpretq_s8_s64(z6)); + vst1q_s8(&D[64], vreinterpretq_s8_s64(z1)); + vst1q_s8(&D[80], vreinterpretq_s8_s64(z3)); + vst1q_s8(&D[96], vreinterpretq_s8_s64(z5)); + vst1q_s8(&D[112], vreinterpretq_s8_s64(z7)); + + int32x4_t RowSums0L_pada = vmovq_n_s32(0); + RowSums0L_pada = vpadalq_s16(RowSums0L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z0))); + RowSums0L_pada = vpadalq_s16(RowSums0L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z1))); + + int32x4_t RowSums0L_ext = vextq_s32(RowSums0L_pada, RowSums0L_pada, 1); + int32x4_t RowSums0L_add = vaddq_s32(RowSums0L_pada, RowSums0L_ext); + int32x2_t RowSums0L = {vdups_laneq_s32(RowSums0L_add, 0), + vdups_laneq_s32(RowSums0L_add, 2)}; + + int32x4_t RowSums0H_pada = vmovq_n_s32(0); + RowSums0H_pada = vpadalq_s16(RowSums0H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z2))); + RowSums0H_pada = vpadalq_s16(RowSums0H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z3))); + + int32x4_t RowSums0H_ext = vextq_s32(RowSums0H_pada, RowSums0H_pada, 1); + int32x4_t RowSums0H_add = vaddq_s32(RowSums0H_pada, RowSums0H_ext); + int32x2_t RowSums0H = {vdups_laneq_s32(RowSums0H_add, 0), + vdups_laneq_s32(RowSums0H_add, 2)}; + + RowSums0 = vaddq_s32(RowSums0, vcombine_s32(RowSums0L, RowSums0H)); + + int32x4_t RowSums1L_pada = vmovq_n_s32(0); + RowSums1L_pada = vpadalq_s16(RowSums1L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z4))); + RowSums1L_pada = vpadalq_s16(RowSums1L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z5))); + + int32x4_t RowSums1L_ext = vextq_s32(RowSums1L_pada, RowSums1L_pada, 1); + int32x4_t RowSums1L_add = vaddq_s32(RowSums1L_pada, RowSums1L_ext); + int32x2_t RowSums1L = {vdups_laneq_s32(RowSums1L_add, 0), + vdups_laneq_s32(RowSums1L_add, 2)}; + + int32x4_t RowSums1H_pada = vmovq_n_s32(0); + RowSums1H_pada = vpadalq_s16(RowSums1H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z6))); + RowSums1H_pada = vpadalq_s16(RowSums1H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z7))); + + int32x4_t RowSums1H_ext = vextq_s32(RowSums1H_pada, RowSums1H_pada, 1); + int32x4_t RowSums1H_add = vaddq_s32(RowSums1H_pada, RowSums1H_ext); + int32x2_t RowSums1H = {vdups_laneq_s32(RowSums1H_add, 0), + vdups_laneq_s32(RowSums1H_add, 2)}; + + RowSums1 = vaddq_s32(RowSums1, vcombine_s32(RowSums1L, RowSums1H)); + + D += 128; + k -= 16; + } + + while (k >= 8) { + int64x1_t v0 = *reinterpret_cast(a0); + a0 += 8; + int64x1_t v1 = *reinterpret_cast(a1); + a1 += 8; + int64x1_t v2 = *reinterpret_cast(a2); + a2 += 8; + int64x1_t v3 = *reinterpret_cast(a3); + a3 += 8; + int64x1_t v4 = *reinterpret_cast(a4); + a4 += 8; + int64x1_t v5 = *reinterpret_cast(a5); + a5 += 8; + int64x1_t v6 = *reinterpret_cast(a6); + a6 += 8; + int64x1_t v7 = *reinterpret_cast(a7); + a7 += 8; + + *reinterpret_cast(&D[0]) = v0; + *reinterpret_cast(&D[8]) = v1; + *reinterpret_cast(&D[16]) = v2; + *reinterpret_cast(&D[24]) = v3; + *reinterpret_cast(&D[32]) = v4; + *reinterpret_cast(&D[40]) = v5; + *reinterpret_cast(&D[48]) = v6; + *reinterpret_cast(&D[56]) = v7; + + int64x2_t z01 = vcombine_s64(v0, v1); + int64x2_t z23 = vcombine_s64(v2, v3); + int64x2_t z45 = vcombine_s64(v4, v5); + int64x2_t z67 = vcombine_s64(v6, v7); + + int32x4_t RowSums0L_pada = vmovq_n_s32(0); + RowSums0L_pada = vpadalq_s16(RowSums0L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z01))); + + int32x4_t RowSums0L_ext = vextq_s32(RowSums0L_pada, RowSums0L_pada, 1); + int32x4_t RowSums0L_add = vaddq_s32(RowSums0L_pada, RowSums0L_ext); + int32x2_t RowSums0L = {vdups_laneq_s32(RowSums0L_add, 0), + vdups_laneq_s32(RowSums0L_add, 2)}; + + int32x4_t RowSums0H_pada = vmovq_n_s32(0); + RowSums0H_pada = vpadalq_s16(RowSums0H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z23))); + + int32x4_t RowSums0H_ext = vextq_s32(RowSums0H_pada, RowSums0H_pada, 1); + int32x4_t RowSums0H_add = vaddq_s32(RowSums0H_pada, RowSums0H_ext); + int32x2_t RowSums0H = {vdups_laneq_s32(RowSums0H_add, 0), + vdups_laneq_s32(RowSums0H_add, 2)}; + + RowSums0 = vaddq_s32(RowSums0, vcombine_s32(RowSums0L, RowSums0H)); + + int32x4_t RowSums1L_pada = vmovq_n_s32(0); + RowSums1L_pada = vpadalq_s16(RowSums1L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z45))); + + int32x4_t RowSums1L_ext = vextq_s32(RowSums1L_pada, RowSums1L_pada, 1); + int32x4_t RowSums1L_add = vaddq_s32(RowSums1L_pada, RowSums1L_ext); + int32x2_t RowSums1L = {vdups_laneq_s32(RowSums1L_add, 0), + vdups_laneq_s32(RowSums1L_add, 2)}; + + int32x4_t RowSums1H_pada = vmovq_n_s32(0); + RowSums1H_pada = vpadalq_s16(RowSums1H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z67))); + + int32x4_t RowSums1H_ext = vextq_s32(RowSums1H_pada, RowSums1H_pada, 1); + int32x4_t RowSums1H_add = vaddq_s32(RowSums1H_pada, RowSums1H_ext); + int32x2_t RowSums1H = {vdups_laneq_s32(RowSums1H_add, 0), + vdups_laneq_s32(RowSums1H_add, 2)}; + + RowSums1 = vaddq_s32(RowSums1, vcombine_s32(RowSums1L, RowSums1H)); + + D += 64; + k -= 8; + } + + if (k > 0) { + // + // zero pad the remaining columns to 8 + // + int8_t* d = D; + + vst1q_s8(d, vmovq_n_s8(0)); + vst1q_s8(&d[16], vmovq_n_s8(0)); + vst1q_s8(&d[32], vmovq_n_s8(0)); + vst1q_s8(&d[48], vmovq_n_s8(0)); + + while (k > 0) { + d[0] = *a0++; + d[8] = *a1++; + d[16] = *a2++; + d[24] = *a3++; + d[32] = *a4++; + d[40] = *a5++; + d[48] = *a6++; + d[56] = *a7++; + d += 1; + k -= 1; + } + d = D; + int64x1_t v0 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v1 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v2 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v3 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v4 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v5 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v6 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v7 = *reinterpret_cast(d); + d = d + 8; + + int64x2_t z01 = vcombine_s64(v0, v1); + int64x2_t z23 = vcombine_s64(v2, v3); + int64x2_t z45 = vcombine_s64(v4, v5); + int64x2_t z67 = vcombine_s64(v6, v7); + + int32x4_t RowSums0L_pada = vmovq_n_s32(0); + RowSums0L_pada = vpadalq_s16(RowSums0L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z01))); + + int32x4_t RowSums0L_ext = vextq_s32(RowSums0L_pada, RowSums0L_pada, 1); + int32x4_t RowSums0L_add = vaddq_s32(RowSums0L_pada, RowSums0L_ext); + int32x2_t RowSums0L = {vdups_laneq_s32(RowSums0L_add, 0), + vdups_laneq_s32(RowSums0L_add, 2)}; + + int32x4_t RowSums0H_pada = vmovq_n_s32(0); + RowSums0H_pada = vpadalq_s16(RowSums0H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z23))); + + int32x4_t RowSums0H_ext = vextq_s32(RowSums0H_pada, RowSums0H_pada, 1); + int32x4_t RowSums0H_add = vaddq_s32(RowSums0H_pada, RowSums0H_ext); + int32x2_t RowSums0H = {vdups_laneq_s32(RowSums0H_add, 0), + vdups_laneq_s32(RowSums0H_add, 2)}; + + RowSums0 = vaddq_s32(RowSums0, vcombine_s32(RowSums0L, RowSums0H)); + + int32x4_t RowSums1L_pada = vmovq_n_s32(0); + RowSums1L_pada = vpadalq_s16(RowSums1L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z45))); + + int32x4_t RowSums1L_ext = vextq_s32(RowSums1L_pada, RowSums1L_pada, 1); + int32x4_t RowSums1L_add = vaddq_s32(RowSums1L_pada, RowSums1L_ext); + int32x2_t RowSums1L = {vdups_laneq_s32(RowSums1L_add, 0), + vdups_laneq_s32(RowSums1L_add, 2)}; + + int32x4_t RowSums1H_pada = vmovq_n_s32(0); + RowSums1H_pada = vpadalq_s16(RowSums1H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z67))); + + int32x4_t RowSums1H_ext = vextq_s32(RowSums1H_pada, RowSums1H_pada, 1); + int32x4_t RowSums1H_add = vaddq_s32(RowSums1H_pada, RowSums1H_ext); + int32x2_t RowSums1H = {vdups_laneq_s32(RowSums1H_add, 0), + vdups_laneq_s32(RowSums1H_add, 2)}; + + RowSums1 = vaddq_s32(RowSums1, vcombine_s32(RowSums1L, RowSums1H)); + + D += 64; + } + + vst1q_s32(RowSumBuffer, RowSums0); + vst1q_s32(&RowSumBuffer[4], RowSums1); + + RowSumBuffer += 8; + + A = A + lda * 8; + CountM -= 8; + } + + // + // Process four rows of matrix A. + // + // The buffer is packed as a series of 32 byte vectors where four rows are + // interleaved with the following pattern: + // + // [ A0 A1 A2 A3 A4 A5 A6 A7 ] + // [ B0 B1 B2 B3 B4 B5 B6 B7 ] + // [ C0 C1 C2 C3 C4 C5 C6 C7 ] + // [ D0 D1 D2 D3 D4 D5 D6 D7 ] + // + // This pattern is repeated (CountK / 8) times. + // + // If CountK is not aligned to a multiple of eight, then the vector is padded + // with zeroes. + // + + if (CountM >= 4) { + const int8_t* a0 = reinterpret_cast(A); + const int8_t* a1 = a0 + lda; + const int8_t* a2 = a1 + lda; + const int8_t* a3 = a2 + lda; + + size_t k = CountK; + int32x4_t RowSums = vmovq_n_s32(0); + + while (k >= 16) { + int64x2_t v0 = vld1q_s64(reinterpret_cast(a0)); + a0 += 16; + int64x2_t v1 = vld1q_s64(reinterpret_cast(a1)); + a1 += 16; + int64x2_t v2 = vld1q_s64(reinterpret_cast(a2)); + a2 += 16; + int64x2_t v3 = vld1q_s64(reinterpret_cast(a3)); + a3 += 16; + + int64x2_t z0 = vzip1q_s64(v0, v1); + int64x2_t z1 = vzip2q_s64(v0, v1); + int64x2_t z2 = vzip1q_s64(v2, v3); + int64x2_t z3 = vzip2q_s64(v2, v3); + + vst1q_s8(&D[0], vreinterpretq_s8_s64(z0)); + vst1q_s8(&D[16], vreinterpretq_s8_s64(z2)); + vst1q_s8(&D[32], vreinterpretq_s8_s64(z1)); + vst1q_s8(&D[48], vreinterpretq_s8_s64(z3)); + + int32x4_t RowSumsL_pada = vmovq_n_s32(0); + RowSumsL_pada = vpadalq_s16(RowSumsL_pada, vpaddlq_s8(vreinterpretq_s8_s64(z0))); + RowSumsL_pada = vpadalq_s16(RowSumsL_pada, vpaddlq_s8(vreinterpretq_s8_s64(z1))); + + int32x4_t RowSumsL_ext = vextq_s32(RowSumsL_pada, RowSumsL_pada, 1); + int32x4_t RowSumsL_add = vaddq_s32(RowSumsL_pada, RowSumsL_ext); + int32x2_t RowSumsL = {vdups_laneq_s32(RowSumsL_add, 0), + vdups_laneq_s32(RowSumsL_add, 2)}; + + int32x4_t RowSumsH_pada = vmovq_n_s32(0); + RowSumsH_pada = vpadalq_s16(RowSumsH_pada, vpaddlq_s8(vreinterpretq_s8_s64(z2))); + RowSumsH_pada = vpadalq_s16(RowSumsH_pada, vpaddlq_s8(vreinterpretq_s8_s64(z3))); + + int32x4_t RowSumsH_ext = vextq_s32(RowSumsH_pada, RowSumsH_pada, 1); + int32x4_t RowSumsH_add = vaddq_s32(RowSumsH_pada, RowSumsH_ext); + int32x2_t RowSumsH = {vdups_laneq_s32(RowSumsH_add, 0), + vdups_laneq_s32(RowSumsH_add, 2)}; + + RowSums = vaddq_s32(RowSums, vcombine_s32(RowSumsL, RowSumsH)); + + D += 64; + k -= 16; + } + + while (k >= 8) { + int64x1_t v0 = *reinterpret_cast(a0); + a0 += 8; + int64x1_t v1 = *reinterpret_cast(a1); + a1 += 8; + int64x1_t v2 = *reinterpret_cast(a2); + a2 += 8; + int64x1_t v3 = *reinterpret_cast(a3); + a3 += 8; + + *reinterpret_cast(&D[0]) = v0; + *reinterpret_cast(&D[8]) = v1; + *reinterpret_cast(&D[16]) = v2; + *reinterpret_cast(&D[24]) = v3; + + int64x2_t z01 = vcombine_s64(v0, v1); + int64x2_t z23 = vcombine_s64(v2, v3); + + int32x4_t RowSumsL_pada = vmovq_n_s32(0); + RowSumsL_pada = vpadalq_s16(RowSumsL_pada, vpaddlq_s8(vreinterpretq_s8_s64(z01))); + + int32x4_t RowSumsL_ext = vextq_s32(RowSumsL_pada, RowSumsL_pada, 1); + int32x4_t RowSumsL_add = vaddq_s32(RowSumsL_pada, RowSumsL_ext); + int32x2_t RowSumsL = {vdups_laneq_s32(RowSumsL_add, 0), + vdups_laneq_s32(RowSumsL_add, 2)}; + + int32x4_t RowSumsH_pada = vmovq_n_s32(0); + RowSumsH_pada = vpadalq_s16(RowSumsH_pada, vpaddlq_s8(vreinterpretq_s8_s64(z23))); + + int32x4_t RowSumsH_ext = vextq_s32(RowSumsH_pada, RowSumsH_pada, 1); + int32x4_t RowSumsH_add = vaddq_s32(RowSumsH_pada, RowSumsH_ext); + int32x2_t RowSumsH = {vdups_laneq_s32(RowSumsH_add, 0), + vdups_laneq_s32(RowSumsH_add, 2)}; + + RowSums = vaddq_s32(RowSums, vcombine_s32(RowSumsL, RowSumsH)); + + D += 32; + k -= 8; + } + + if (k > 0) { + // + // Copy the remaining bytes with zero padding. + // + int8_t* d = D; + + vst1q_s8(d, vmovq_n_s8(0)); + vst1q_s8(&d[16], vmovq_n_s8(0)); + + while (k > 0) { + d[0] = *a0++; + d[8] = *a1++; + d[16] = *a2++; + d[24] = *a3++; + d += 1; + k -= 1; + } + + d = D; + int64x1_t v0 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v1 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v2 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v3 = *reinterpret_cast(d); + d = d + 8; + + int64x2_t z01 = vcombine_s64(v0, v1); + int64x2_t z23 = vcombine_s64(v2, v3); + + int32x4_t RowSums0L_pada = vmovq_n_s32(0); + RowSums0L_pada = vpadalq_s16(RowSums0L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z01))); + + int32x4_t RowSums0L_ext = vextq_s32(RowSums0L_pada, RowSums0L_pada, 1); + int32x4_t RowSums0L_add = vaddq_s32(RowSums0L_pada, RowSums0L_ext); + int32x2_t RowSums0L = {vdups_laneq_s32(RowSums0L_add, 0), + vdups_laneq_s32(RowSums0L_add, 2)}; + + int32x4_t RowSums0H_pada = vmovq_n_s32(0); + RowSums0H_pada = vpadalq_s16(RowSums0H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z23))); + + int32x4_t RowSums0H_ext = vextq_s32(RowSums0H_pada, RowSums0H_pada, 1); + int32x4_t RowSums0H_add = vaddq_s32(RowSums0H_pada, RowSums0H_ext); + int32x2_t RowSums0H = {vdups_laneq_s32(RowSums0H_add, 0), + vdups_laneq_s32(RowSums0H_add, 2)}; + + RowSums = vaddq_s32(RowSums, vcombine_s32(RowSums0L, RowSums0H)); + + D += 32; + } + + vst1q_s32(RowSumBuffer, RowSums); + RowSumBuffer += 4; + + A = A + lda * 4; + CountM -= 4; + } + + // + // Process two rows of matrix A. + // + // The buffer is packed as a series of 16 byte vectors where two rows are + // interleaved with the following pattern: + // + // [ A0 A1 A2 A3 A4 A5 A6 A7 ] + // [ B0 B1 B2 B3 B4 B5 B6 B7 ] + // + // This pattern is repeated (CountK / 8) times. + // + // If CountK is not aligned to a multiple of eight, then the vector is padded + // with zeroes. + // + + if (CountM >= 2) { + const int8_t* a0 = reinterpret_cast(A); + const int8_t* a1 = a0 + lda; + + size_t k = CountK; + int32x2_t RowSums = vmov_n_s32(0); + + while (k >= 16) { + int64x2_t v0 = vld1q_s64(reinterpret_cast(a0)); + a0 += 16; + int64x2_t v1 = vld1q_s64(reinterpret_cast(a1)); + a1 += 16; + + int64x2_t z0 = vzip1q_s64(v0, v1); + int64x2_t z1 = vzip2q_s64(v0, v1); + + vst1q_s8(&D[0], vreinterpretq_s8_s64(z0)); + vst1q_s8(&D[16], vreinterpretq_s8_s64(z1)); + + int32x4_t RowSumsL_pada = vmovq_n_s32(0); + RowSumsL_pada = vpadalq_s16(RowSumsL_pada, vpaddlq_s8(vreinterpretq_s8_s64(z0))); + RowSumsL_pada = vpadalq_s16(RowSumsL_pada, vpaddlq_s8(vreinterpretq_s8_s64(z1))); + + int32x4_t RowSumsL_ext = vextq_s32(RowSumsL_pada, RowSumsL_pada, 1); + int32x4_t RowSumsL_add = vaddq_s32(RowSumsL_pada, RowSumsL_ext); + int32x2_t RowSumsL = {vdups_laneq_s32(RowSumsL_add, 0), + vdups_laneq_s32(RowSumsL_add, 2)}; + + RowSums = vadd_s32(RowSums, RowSumsL); + + D += 32; + k -= 16; + } + + while (k >= 8) { + int64x1_t v0 = *reinterpret_cast(a0); + a0 += 8; + int64x1_t v1 = *reinterpret_cast(a1); + a1 += 8; + + *reinterpret_cast(&D[0]) = v0; + *reinterpret_cast(&D[8]) = v1; + + int64x2_t z01 = vcombine_s64(v0, v1); + int32x4_t RowSumsL_pada = vmovq_n_s32(0); + RowSumsL_pada = vpadalq_s16(RowSumsL_pada, vpaddlq_s8(vreinterpretq_s8_s64(z01))); + + int32x4_t RowSumsL_ext = vextq_s32(RowSumsL_pada, RowSumsL_pada, 1); + int32x4_t RowSumsL_add = vaddq_s32(RowSumsL_pada, RowSumsL_ext); + int32x2_t RowSumsL = {vdups_laneq_s32(RowSumsL_add, 0), + vdups_laneq_s32(RowSumsL_add, 2)}; + + RowSums = vadd_s32(RowSums, RowSumsL); + + D += 16; + k -= 8; + } + + if (k > 0) { + // + // Zero pad the remaining elements to make 8 columns. + // + + int8_t* d = PaddedMatrixAData; + vst1q_s8(PaddedMatrixAData, vmovq_n_s8(0)); + + while (k > 0) { + d[0] = *a0++; + d[8] = *a1++; + + d += 1; + k -= 1; + } + + d = PaddedMatrixAData; + int64x1_t v0 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v1 = *reinterpret_cast(d); + d = d + 8; + + int64x2_t z01 = vcombine_s64(v0, v1); + int32x4_t RowSumsL_pada = vmovq_n_s32(0); + RowSumsL_pada = vpadalq_s16(RowSumsL_pada, vpaddlq_s8(vreinterpretq_s8_s64(z01))); + + int32x4_t RowSumsL_ext = vextq_s32(RowSumsL_pada, RowSumsL_pada, 1); + int32x4_t RowSumsL_add = vaddq_s32(RowSumsL_pada, RowSumsL_ext); + int32x2_t RowSumsL = {vdups_laneq_s32(RowSumsL_add, 0), + vdups_laneq_s32(RowSumsL_add, 2)}; + + RowSums = vadd_s32(RowSums, RowSumsL); + + int8x16_t PackedVector = vld1q_s8(PaddedMatrixAData); + vst1q_s8(D, PackedVector); + + D += 16; + } + + vst1_s32(RowSumBuffer, RowSums); + RowSumBuffer += 2; + + A = A + lda * 2; + CountM -= 2; + } + + // + // Process one row of matrix A. + // + // The buffer is packed as a series of 8 byte with the following pattern: + // + // [ A0 A1 A2 A3 A4 A5 A6 A7 ] + // + // This pattern is repeated (CountK / 8) times. + // + // If CountK is not aligned to a multiple of 8, then the vector is padded + // with zeroes. + // + + if (CountM > 0) { + // No need to pad the rows to 2, the .S takes care of zero pdding + const int8_t* a = reinterpret_cast(A); + size_t k = CountK; + int32x4_t RowSums = vmovq_n_s32(0); + + while (k >= 16) { + int8x16_t v = vld1q_s8(a); + a += 16; + + vst1q_s8(D, v); + + RowSums = vpadalq_s16(RowSums, vpaddlq_s8(v)); + + D += 16; + k -= 16; + } + + if (k > 0) { + // + // Copy the remaining bytes to the zero padded stack buffer. + // + + vst1q_s8(PaddedMatrixAData, vmovq_n_s8(0)); + + for (size_t kk = 0; kk < k; kk++) { + PaddedMatrixAData[kk] = a[kk]; + } + + int8x16_t v = vld1q_s8(PaddedMatrixAData); + vst1q_s8(D, v); + + RowSums = vpadalq_s16(RowSums, vpaddlq_s8(v)); + } + + *RowSumBuffer = int32_t(vaddvq_s32(RowSums)); + } +} + +MLAS_FORCEINLINE +void +MlasGemmS8S8CopyPackBProcessSmmla(int8_t* D, int8x8_t BytesRow[8], int32x4_t ColumnSums[2]) +{ + int8x16_t v02 = vcombine_s8(BytesRow[0], BytesRow[2]); + int8x16_t v13 = vcombine_s8(BytesRow[1], BytesRow[3]); + + int8x16_t v46 = vcombine_s8(BytesRow[4], BytesRow[6]); + int8x16_t v57 = vcombine_s8(BytesRow[5], BytesRow[7]); + + int8x16x2_t zw1 = vzipq_s8(v02, v13); + int16x8x2_t zd1 = vzipq_s16(vreinterpretq_s16_s8(zw1.val[0]), vreinterpretq_s16_s8(zw1.val[1])); + + int8x16x2_t zw2 = vzipq_s8(v46, v57); + int16x8x2_t zd2 = vzipq_s16(vreinterpretq_s16_s8(zw2.val[0]), vreinterpretq_s16_s8(zw2.val[1])); + + int32x4x2_t zd3 = + vzipq_s32(vreinterpretq_s32_s16(zd1.val[0]), vreinterpretq_s32_s16(zd2.val[0])); + int32x4x2_t zd4 = + vzipq_s32(vreinterpretq_s32_s16(zd1.val[1]), vreinterpretq_s32_s16(zd2.val[1])); + + vst1q_s8(&D[0], vreinterpretq_s8_s32(zd3.val[0])); + vst1q_s8(&D[16], vreinterpretq_s8_s32(zd3.val[1])); + vst1q_s8(&D[32], vreinterpretq_s8_s32(zd4.val[0])); + vst1q_s8(&D[48], vreinterpretq_s8_s32(zd4.val[1])); + + int32x4_t ColSums0L_pada = vmovq_n_s32(0); + ColSums0L_pada = vpadalq_s16(ColSums0L_pada, vpaddlq_s8(vreinterpretq_s8_s32(zd3.val[0]))); + int32x4_t ColSums0L_ext = vextq_s32(ColSums0L_pada, ColSums0L_pada, 1); + int32x4_t ColSums0L_add = vaddq_s32(ColSums0L_pada, ColSums0L_ext); + int32x2_t ColSums0L = {vdups_laneq_s32(ColSums0L_add, 0), vdups_laneq_s32(ColSums0L_add, 2)}; + + int32x4_t ColSums0H_pada = vmovq_n_s32(0); + ColSums0H_pada = vpadalq_s16(ColSums0H_pada, vpaddlq_s8(vreinterpretq_s8_s32(zd3.val[1]))); + int32x4_t ColSums0H_ext = vextq_s32(ColSums0H_pada, ColSums0H_pada, 1); + int32x4_t ColSums0H_add = vaddq_s32(ColSums0H_pada, ColSums0H_ext); + int32x2_t ColSums0H = {vdups_laneq_s32(ColSums0H_add, 0), vdups_laneq_s32(ColSums0H_add, 2)}; + + ColumnSums[0] = vaddq_s32(ColumnSums[0], vcombine_s32(ColSums0L, ColSums0H)); + + int32x4_t ColSums1L_pada = vmovq_n_s32(0); + ColSums1L_pada = vpadalq_s16(ColSums1L_pada, vpaddlq_s8(vreinterpretq_s8_s32(zd4.val[0]))); + int32x4_t ColSums1L_ext = vextq_s32(ColSums1L_pada, ColSums1L_pada, 1); + int32x4_t ColSums1L_add = vaddq_s32(ColSums1L_pada, ColSums1L_ext); + int32x2_t ColSums1L = {vdups_laneq_s32(ColSums1L_add, 0), vdups_laneq_s32(ColSums1L_add, 2)}; + + int32x4_t ColSums1H_pada = vmovq_n_s32(0); + ColSums1H_pada = vpadalq_s16(ColSums1H_pada, vpaddlq_s8(vreinterpretq_s8_s32(zd4.val[1]))); + int32x4_t ColSums1H_ext = vextq_s32(ColSums1H_pada, ColSums1H_pada, 1); + int32x4_t ColSums1H_add = vaddq_s32(ColSums1H_pada, ColSums1H_ext); + int32x2_t ColSums1H = {vdups_laneq_s32(ColSums1H_add, 0), vdups_laneq_s32(ColSums1H_add, 2)}; + + ColumnSums[1] = vaddq_s32(ColumnSums[1], vcombine_s32(ColSums1L, ColSums1H)); +} + +template <> +void +MlasGemmQuantCopyPackB(MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedBType* Dst, + const uint8_t* B, + size_t ldb, + size_t CountN, + size_t CountK, + int32_t* ColumnSumBuffer, + bool BIsSigned) +{ + MLAS_UNREFERENCED_PARAMETER(BIsSigned); + int8_t* D = reinterpret_cast(Dst); + const int8x16_t ZeroVector = vmovq_n_s8(0); + int8x8_t BytesRow[8]; + + // + // Copy data from matrix B into the destination buffer 8x2 blocks at a + // time. + // + // + while (CountN >= 8) { + const int8_t* b = reinterpret_cast(B); + size_t k = CountK; + int32x4_t ColumnSums[2]; + + ColumnSums[0] = vmovq_n_s32(0); + ColumnSums[1] = vmovq_n_s32(0); + + while (k >= 8) { + BytesRow[0] = vld1_s8(&b[ldb * 0]); + BytesRow[1] = vld1_s8(&b[ldb * 1]); + BytesRow[2] = vld1_s8(&b[ldb * 2]); + BytesRow[3] = vld1_s8(&b[ldb * 3]); + BytesRow[4] = vld1_s8(&b[ldb * 4]); + BytesRow[5] = vld1_s8(&b[ldb * 5]); + BytesRow[6] = vld1_s8(&b[ldb * 6]); + BytesRow[7] = vld1_s8(&b[ldb * 7]); + + MlasGemmS8S8CopyPackBProcessSmmla(D, BytesRow, ColumnSums); + + D += 64; + b += ldb * 8; + k -= 8; + } + + if (k > 0) { + // Pad k to 8 + + BytesRow[0] = vld1_s8(&b[ldb * 0]); + BytesRow[1] = (k >= 2) ? vld1_s8(&b[ldb * 1]) : vget_low_s8(ZeroVector); + BytesRow[2] = (k >= 3) ? vld1_s8(&b[ldb * 2]) : vget_low_s8(ZeroVector); + BytesRow[3] = (k >= 4) ? vld1_s8(&b[ldb * 3]) : vget_low_s8(ZeroVector); + BytesRow[4] = (k >= 5) ? vld1_s8(&b[ldb * 4]) : vget_low_s8(ZeroVector); + BytesRow[5] = (k >= 6) ? vld1_s8(&b[ldb * 5]) : vget_low_s8(ZeroVector); + BytesRow[6] = (k >= 7) ? vld1_s8(&b[ldb * 6]) : vget_low_s8(ZeroVector); + BytesRow[7] = vget_low_s8(ZeroVector); + + MlasGemmS8S8CopyPackBProcessSmmla(D, BytesRow, ColumnSums); + + D += 64; + } + + // Zero pad the output buffer to a multiple of PackedK if the above + // processed an odd number of four row bundles. + // + vst1q_s32(&ColumnSumBuffer[0], ColumnSums[0]); + vst1q_s32(&ColumnSumBuffer[4], ColumnSums[1]); + + ColumnSumBuffer += 8; + + B += 8; + CountN -= 8; + } + + // + // Process the remaining columns of matrix B. + // + + if (CountN > 0) { + const int8_t* b = reinterpret_cast(B); + size_t k = CountK; + int8_t PaddedMatrixBData[64]; + int32x4_t ColumnSums[2]; + + vst1q_s8(&PaddedMatrixBData[0], ZeroVector); + vst1q_s8(&PaddedMatrixBData[16], ZeroVector); + vst1q_s8(&PaddedMatrixBData[32], ZeroVector); + vst1q_s8(&PaddedMatrixBData[48], ZeroVector); + + ColumnSums[0] = vmovq_n_s32(0); + ColumnSums[1] = vmovq_n_s32(0); + + // + // Interleave rows of matrix B using an intermediate zero padded stack + // buffer and write to the packed buffer. + // + + while (k > 0) { + const int8_t* bcopy0 = &b[ldb * 0]; + const int8_t* bcopy1 = &b[ldb * 1]; + const int8_t* bcopy2 = &b[ldb * 2]; + const int8_t* bcopy3 = &b[ldb * 3]; + const int8_t* bcopy4 = &b[ldb * 4]; + const int8_t* bcopy5 = &b[ldb * 5]; + const int8_t* bcopy6 = &b[ldb * 6]; + const int8_t* bcopy7 = &b[ldb * 7]; + + if (k >= 8) { + b += ldb * 8; + k -= 8; + + } else { + vst1q_s8(&PaddedMatrixBData[0], ZeroVector); + vst1q_s8(&PaddedMatrixBData[16], ZeroVector); + vst1q_s8(&PaddedMatrixBData[32], ZeroVector); + vst1q_s8(&PaddedMatrixBData[48], ZeroVector); + + bcopy1 = (k >= 2) ? bcopy1 : &PaddedMatrixBData[56]; + bcopy2 = (k >= 3) ? bcopy2 : &PaddedMatrixBData[56]; + bcopy3 = (k >= 4) ? bcopy3 : &PaddedMatrixBData[56]; + bcopy4 = (k >= 5) ? bcopy4 : &PaddedMatrixBData[56]; + bcopy5 = (k >= 6) ? bcopy5 : &PaddedMatrixBData[56]; + bcopy6 = (k >= 7) ? bcopy6 : &PaddedMatrixBData[56]; + bcopy7 = &PaddedMatrixBData[56]; + + k = 0; + } + + int8_t* padded = PaddedMatrixBData; + int8_t* padded_end = padded + CountN; + do { + padded[0] = *bcopy0++; + padded[8] = *bcopy1++; + padded[16] = *bcopy2++; + padded[24] = *bcopy3++; + padded[32] = *bcopy4++; + padded[40] = *bcopy5++; + padded[48] = *bcopy6++; + padded[56] = *bcopy7++; + + } while (++padded < padded_end); + + BytesRow[0] = vld1_s8(&PaddedMatrixBData[0]); + BytesRow[1] = vld1_s8(&PaddedMatrixBData[8]); + BytesRow[2] = vld1_s8(&PaddedMatrixBData[16]); + BytesRow[3] = vld1_s8(&PaddedMatrixBData[24]); + BytesRow[4] = vld1_s8(&PaddedMatrixBData[32]); + BytesRow[5] = vld1_s8(&PaddedMatrixBData[40]); + BytesRow[6] = vld1_s8(&PaddedMatrixBData[48]); + BytesRow[7] = vld1_s8(&PaddedMatrixBData[56]); + + MlasGemmS8S8CopyPackBProcessSmmla(D, BytesRow, ColumnSums); + + D += 64; + } + + vst1q_s32(&ColumnSumBuffer[0], ColumnSums[0]); + vst1q_s32(&ColumnSumBuffer[4], ColumnSums[1]); + } +} + +template <> +MLAS_FORCEINLINE size_t +MlasGemmQuantKernel(const MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedAType* A, + const MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedBType* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumBuffer, + const int32_t* ColumnSumBuffer, + const int32_t* ZeroPointB, + bool ZeroMode) +{ + size_t RowsHandled; + + if (ZeroMode) { + RowsHandled = MlasGemmS8S8KernelSmmlaZero(A, B, C, PackedCountK, CountM, CountN, ldc, + RowSumBuffer, ColumnSumBuffer, ZeroPointB); + } else { + RowsHandled = MlasGemmS8S8KernelSmmlaAdd(A, B, C, PackedCountK, CountM, CountN, ldc, + RowSumBuffer, ColumnSumBuffer, ZeroPointB); + } + + return RowsHandled; +} + +const MLAS_GEMM_QUANT_DISPATCH MlasGemmS8S8DispatchSmmla = { + MlasGemmQuantOperation, + MlasGemmQuantPackedOperation, + MlasGemmQuantCopyPackB, + MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedK, + MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedStrides.K, + 8}; diff --git a/onnxruntime/core/mlas/lib/qgemm_kernel_ummla.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_ummla.cpp new file mode 100644 index 0000000000000..3936154432ac7 --- /dev/null +++ b/onnxruntime/core/mlas/lib/qgemm_kernel_ummla.cpp @@ -0,0 +1,967 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the MIT License. + +Module Name: + + qgemm_kernel_ummla.cpp + +Abstract: + + This module implements ummla QGEMM kernel. + +--*/ + +#include "mlasi.h" +#include "qgemm.h" + +// +// Define the prototypes of the NEON UMMLA routines written in assembly. +// + +extern "C" { + +size_t MLASCALL +MlasGemmU8X8KernelUmmlaZero(const uint8_t* A, + const uint8_t* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumVector, + const int32_t* ColumnSumVector, + const int32_t* ZeroPointB); + +size_t MLASCALL +MlasGemmU8X8KernelUmmlaAdd(const uint8_t* A, + const uint8_t* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumVector, + const int32_t* ColumnSumVector, + const int32_t* ZeroPointB); +} + +struct MLAS_GEMM_U8X8_KERNEL_UMMLA { + typedef uint8_t PackedAType; + typedef uint8_t PackedBType; + typedef uint8_t OffsetAType; + typedef uint8_t OffsetBType; + + static constexpr size_t PackedK = 8; + static constexpr MLAS_GEMM_QUANT_STRIDES Strides{24, 128, 256}; + static constexpr MLAS_GEMM_QUANT_STRIDES PackedStrides{24, 128, 384}; +}; + +constexpr size_t MLAS_GEMM_U8X8_KERNEL_UMMLA::PackedK; +constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_U8X8_KERNEL_UMMLA::Strides; +constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_U8X8_KERNEL_UMMLA::PackedStrides; + +template <> +MLAS_FORCEINLINE int32_t +MlasGemmQuantFixupZeroPointB(int32_t ZeroPointB, bool BIsSigned) +{ + if (BIsSigned) { + ZeroPointB = MLAS_GEMM_U8X8_KERNEL_UMMLA::OffsetBType(ZeroPointB ^ 0x80); + } + + return ZeroPointB; +} + +template <> +void +MlasGemmQuantCopyPackA(MLAS_GEMM_U8X8_KERNEL_UMMLA::PackedAType* D, + const uint8_t* A, + size_t lda, + size_t CountM, + size_t CountK, + int32_t* RowSumBuffer, + bool AIsSigned) +{ + MLAS_UNREFERENCED_PARAMETER(AIsSigned); + uint8_t PaddedMatrixAData[64]; + + // + // Process 8 rows of matrix A. + // + // MMLA kernels load 8x8 block of A with four vector registers. So A is packed + // a series of 64 byte vectors where eight rows are interleaved with the + // following pattern: + // + // [ A0 A1 A2 A3 A4 A5 A6 A7 ] + // [ B0 B1 B2 B3 B4 B5 B6 B7 ] + // [ C0 C1 C2 C3 C4 C5 C6 C7 ] + // [ D0 D1 D2 D3 D4 D5 D6 D7 ] + // [ E0 E1 E2 E3 E4 E5 E6 E7 ] + // [ F0 F1 F2 F3 F4 F5 F6 F7 ] + // [ G0 G1 G2 G3 G4 G5 G6 G7 ] + // [ H0 H1 H2 H3 H4 H5 H6 H7 ] + // + // ... + // + // This pattern is repeated (CountK / 8) times. + // + // If CountK is not aligned to a multiple of eight, then the vector is padded + // with zeroes. + // + + while (CountM >= 8) { + const uint8_t* a0 = A; + const uint8_t* a1 = a0 + lda; + const uint8_t* a2 = a0 + lda * 2; + const uint8_t* a3 = a0 + lda * 3; + const uint8_t* a4 = a0 + lda * 4; + const uint8_t* a5 = a0 + lda * 5; + const uint8_t* a6 = a0 + lda * 6; + const uint8_t* a7 = a0 + lda * 7; + + size_t k = CountK; + uint32x4_t RowSums0 = vmovq_n_u32(0); + uint32x4_t RowSums1 = vmovq_n_u32(0); + + while (k >= 16) { + uint64x2_t v0 = vld1q_u64(reinterpret_cast(a0)); + a0 += 16; + uint64x2_t v1 = vld1q_u64(reinterpret_cast(a1)); + a1 += 16; + uint64x2_t v2 = vld1q_u64(reinterpret_cast(a2)); + a2 += 16; + uint64x2_t v3 = vld1q_u64(reinterpret_cast(a3)); + a3 += 16; + uint64x2_t v4 = vld1q_u64(reinterpret_cast(a4)); + a4 += 16; + uint64x2_t v5 = vld1q_u64(reinterpret_cast(a5)); + a5 += 16; + uint64x2_t v6 = vld1q_u64(reinterpret_cast(a6)); + a6 += 16; + uint64x2_t v7 = vld1q_u64(reinterpret_cast(a7)); + a7 += 16; + + uint64x2_t z0 = vzip1q_u64(v0, v1); + uint64x2_t z1 = vzip2q_u64(v0, v1); + uint64x2_t z2 = vzip1q_u64(v2, v3); + uint64x2_t z3 = vzip2q_u64(v2, v3); + + uint64x2_t z4 = vzip1q_u64(v4, v5); + uint64x2_t z5 = vzip2q_u64(v4, v5); + uint64x2_t z6 = vzip1q_u64(v6, v7); + uint64x2_t z7 = vzip2q_u64(v6, v7); + + vst1q_u8(&D[0], vreinterpretq_u8_u64(z0)); + vst1q_u8(&D[16], vreinterpretq_u8_u64(z2)); + vst1q_u8(&D[32], vreinterpretq_u8_u64(z4)); + vst1q_u8(&D[48], vreinterpretq_u8_u64(z6)); + vst1q_u8(&D[64], vreinterpretq_u8_u64(z1)); + vst1q_u8(&D[80], vreinterpretq_u8_u64(z3)); + vst1q_u8(&D[96], vreinterpretq_u8_u64(z5)); + vst1q_u8(&D[112], vreinterpretq_u8_u64(z7)); + + uint32x4_t RowSums0L_pada = vmovq_n_u32(0); + RowSums0L_pada = vpadalq_u16(RowSums0L_pada, vpaddlq_u8(vreinterpretq_u8_u64(z0))); + RowSums0L_pada = vpadalq_u16(RowSums0L_pada, vpaddlq_u8(vreinterpretq_u8_u64(z1))); + + uint32x4_t RowSums0L_ext = vextq_u32(RowSums0L_pada, RowSums0L_pada, 1); + uint32x4_t RowSums0L_add = vaddq_u32(RowSums0L_pada, RowSums0L_ext); + uint32x2_t RowSums0L = {vdups_laneq_u32(RowSums0L_add, 0), + vdups_laneq_u32(RowSums0L_add, 2)}; + + uint32x4_t RowSums0H_pada = vmovq_n_u32(0); + RowSums0H_pada = vpadalq_u16(RowSums0H_pada, vpaddlq_u8(vreinterpretq_u8_u64(z2))); + RowSums0H_pada = vpadalq_u16(RowSums0H_pada, vpaddlq_u8(vreinterpretq_u8_u64(z3))); + + uint32x4_t RowSums0H_ext = vextq_u32(RowSums0H_pada, RowSums0H_pada, 1); + uint32x4_t RowSums0H_add = vaddq_u32(RowSums0H_pada, RowSums0H_ext); + uint32x2_t RowSums0H = {vdups_laneq_u32(RowSums0H_add, 0), + vdups_laneq_u32(RowSums0H_add, 2)}; + + RowSums0 = vaddq_u32(RowSums0, vcombine_u32(RowSums0L, RowSums0H)); + + uint32x4_t RowSums1L_pada = vmovq_n_u32(0); + RowSums1L_pada = vpadalq_u16(RowSums1L_pada, vpaddlq_u8(vreinterpretq_u8_u64(z4))); + RowSums1L_pada = vpadalq_u16(RowSums1L_pada, vpaddlq_u8(vreinterpretq_u8_u64(z5))); + + uint32x4_t RowSums1L_ext = vextq_u32(RowSums1L_pada, RowSums1L_pada, 1); + uint32x4_t RowSums1L_add = vaddq_u32(RowSums1L_pada, RowSums1L_ext); + uint32x2_t RowSums1L = {vdups_laneq_u32(RowSums1L_add, 0), + vdups_laneq_u32(RowSums1L_add, 2)}; + + uint32x4_t RowSums1H_pada = vmovq_n_u32(0); + RowSums1H_pada = vpadalq_u16(RowSums1H_pada, vpaddlq_u8(vreinterpretq_u8_u64(z6))); + RowSums1H_pada = vpadalq_u16(RowSums1H_pada, vpaddlq_u8(vreinterpretq_u8_u64(z7))); + + uint32x4_t RowSums1H_ext = vextq_u32(RowSums1H_pada, RowSums1H_pada, 1); + uint32x4_t RowSums1H_add = vaddq_u32(RowSums1H_pada, RowSums1H_ext); + uint32x2_t RowSums1H = {vdups_laneq_u32(RowSums1H_add, 0), + vdups_laneq_u32(RowSums1H_add, 2)}; + + RowSums1 = vaddq_u32(RowSums1, vcombine_u32(RowSums1L, RowSums1H)); + + D += 128; + k -= 16; + } + + while (k >= 8) { + uint64x1_t v0 = *reinterpret_cast(a0); + a0 += 8; + uint64x1_t v1 = *reinterpret_cast(a1); + a1 += 8; + uint64x1_t v2 = *reinterpret_cast(a2); + a2 += 8; + uint64x1_t v3 = *reinterpret_cast(a3); + a3 += 8; + uint64x1_t v4 = *reinterpret_cast(a4); + a4 += 8; + uint64x1_t v5 = *reinterpret_cast(a5); + a5 += 8; + uint64x1_t v6 = *reinterpret_cast(a6); + a6 += 8; + uint64x1_t v7 = *reinterpret_cast(a7); + a7 += 8; + + *reinterpret_cast(&D[0]) = v0; + *reinterpret_cast(&D[8]) = v1; + *reinterpret_cast(&D[16]) = v2; + *reinterpret_cast(&D[24]) = v3; + *reinterpret_cast(&D[32]) = v4; + *reinterpret_cast(&D[40]) = v5; + *reinterpret_cast(&D[48]) = v6; + *reinterpret_cast(&D[56]) = v7; + + uint64x2_t z01 = vcombine_u64(v0, v1); + uint64x2_t z23 = vcombine_u64(v2, v3); + uint64x2_t z45 = vcombine_u64(v4, v5); + uint64x2_t z67 = vcombine_u64(v6, v7); + + uint32x4_t RowSums0L_pada = vmovq_n_u32(0); + RowSums0L_pada = vpadalq_u16(RowSums0L_pada, vpaddlq_u8(vreinterpretq_u8_u64(z01))); + + uint32x4_t RowSums0L_ext = vextq_u32(RowSums0L_pada, RowSums0L_pada, 1); + uint32x4_t RowSums0L_add = vaddq_u32(RowSums0L_pada, RowSums0L_ext); + uint32x2_t RowSums0L = {vdups_laneq_u32(RowSums0L_add, 0), + vdups_laneq_u32(RowSums0L_add, 2)}; + + uint32x4_t RowSums0H_pada = vmovq_n_u32(0); + RowSums0H_pada = vpadalq_u16(RowSums0H_pada, vpaddlq_u8(vreinterpretq_u8_u64(z23))); + + uint32x4_t RowSums0H_ext = vextq_u32(RowSums0H_pada, RowSums0H_pada, 1); + uint32x4_t RowSums0H_add = vaddq_u32(RowSums0H_pada, RowSums0H_ext); + uint32x2_t RowSums0H = {vdups_laneq_u32(RowSums0H_add, 0), + vdups_laneq_u32(RowSums0H_add, 2)}; + + RowSums0 = vaddq_u32(RowSums0, vcombine_u32(RowSums0L, RowSums0H)); + + uint32x4_t RowSums1L_pada = vmovq_n_u32(0); + RowSums1L_pada = vpadalq_u16(RowSums1L_pada, vpaddlq_u8(vreinterpretq_u8_u64(z45))); + + uint32x4_t RowSums1L_ext = vextq_u32(RowSums1L_pada, RowSums1L_pada, 1); + uint32x4_t RowSums1L_add = vaddq_u32(RowSums1L_pada, RowSums1L_ext); + uint32x2_t RowSums1L = {vdups_laneq_u32(RowSums1L_add, 0), + vdups_laneq_u32(RowSums1L_add, 2)}; + + uint32x4_t RowSums1H_pada = vmovq_n_u32(0); + RowSums1H_pada = vpadalq_u16(RowSums1H_pada, vpaddlq_u8(vreinterpretq_u8_u64(z67))); + + uint32x4_t RowSums1H_ext = vextq_u32(RowSums1H_pada, RowSums1H_pada, 1); + uint32x4_t RowSums1H_add = vaddq_u32(RowSums1H_pada, RowSums1H_ext); + uint32x2_t RowSums1H = {vdups_laneq_u32(RowSums1H_add, 0), + vdups_laneq_u32(RowSums1H_add, 2)}; + + RowSums1 = vaddq_u32(RowSums1, vcombine_u32(RowSums1L, RowSums1H)); + + D += 64; + k -= 8; + } + + if (k > 0) { + // + // zero pad the remaining columns to 8 + // + uint8_t* d = D; + + vst1q_u8(d, vmovq_n_u8(0)); + vst1q_u8(&d[16], vmovq_n_u8(0)); + vst1q_u8(&d[32], vmovq_n_u8(0)); + vst1q_u8(&d[48], vmovq_n_u8(0)); + + while (k > 0) { + d[0] = *a0++; + d[8] = *a1++; + d[16] = *a2++; + d[24] = *a3++; + d[32] = *a4++; + d[40] = *a5++; + d[48] = *a6++; + d[56] = *a7++; + d += 1; + k -= 1; + } + d = D; + uint64x1_t v0 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v1 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v2 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v3 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v4 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v5 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v6 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v7 = *reinterpret_cast(d); + d = d + 8; + + uint64x2_t z01 = vcombine_u64(v0, v1); + uint64x2_t z23 = vcombine_u64(v2, v3); + uint64x2_t z45 = vcombine_u64(v4, v5); + uint64x2_t z67 = vcombine_u64(v6, v7); + + uint32x4_t RowSums0L_pada = vmovq_n_u32(0); + RowSums0L_pada = vpadalq_u16(RowSums0L_pada, vpaddlq_u8(vreinterpretq_u8_u64(z01))); + + uint32x4_t RowSums0L_ext = vextq_u32(RowSums0L_pada, RowSums0L_pada, 1); + uint32x4_t RowSums0L_add = vaddq_u32(RowSums0L_pada, RowSums0L_ext); + uint32x2_t RowSums0L = {vdups_laneq_u32(RowSums0L_add, 0), + vdups_laneq_u32(RowSums0L_add, 2)}; + + uint32x4_t RowSums0H_pada = vmovq_n_u32(0); + RowSums0H_pada = vpadalq_u16(RowSums0H_pada, vpaddlq_u8(vreinterpretq_u8_u64(z23))); + + uint32x4_t RowSums0H_ext = vextq_u32(RowSums0H_pada, RowSums0H_pada, 1); + uint32x4_t RowSums0H_add = vaddq_u32(RowSums0H_pada, RowSums0H_ext); + uint32x2_t RowSums0H = {vdups_laneq_u32(RowSums0H_add, 0), + vdups_laneq_u32(RowSums0H_add, 2)}; + + RowSums0 = vaddq_u32(RowSums0, vcombine_u32(RowSums0L, RowSums0H)); + + uint32x4_t RowSums1L_pada = vmovq_n_u32(0); + RowSums1L_pada = vpadalq_u16(RowSums1L_pada, vpaddlq_u8(vreinterpretq_u8_u64(z45))); + + uint32x4_t RowSums1L_ext = vextq_u32(RowSums1L_pada, RowSums1L_pada, 1); + uint32x4_t RowSums1L_add = vaddq_u32(RowSums1L_pada, RowSums1L_ext); + uint32x2_t RowSums1L = {vdups_laneq_u32(RowSums1L_add, 0), + vdups_laneq_u32(RowSums1L_add, 2)}; + + uint32x4_t RowSums1H_pada = vmovq_n_u32(0); + RowSums1H_pada = vpadalq_u16(RowSums1H_pada, vpaddlq_u8(vreinterpretq_u8_u64(z67))); + + uint32x4_t RowSums1H_ext = vextq_u32(RowSums1H_pada, RowSums1H_pada, 1); + uint32x4_t RowSums1H_add = vaddq_u32(RowSums1H_pada, RowSums1H_ext); + uint32x2_t RowSums1H = {vdups_laneq_u32(RowSums1H_add, 0), + vdups_laneq_u32(RowSums1H_add, 2)}; + + RowSums1 = vaddq_u32(RowSums1, vcombine_u32(RowSums1L, RowSums1H)); + + D += 64; + } + + vst1q_s32(RowSumBuffer, vreinterpretq_s32_u32(RowSums0)); + vst1q_s32(&RowSumBuffer[4], vreinterpretq_s32_u32(RowSums1)); + + RowSumBuffer += 8; + + A = A + lda * 8; + CountM -= 8; + } + + // + // Process four rows of matrix A. + // + // The buffer is packed as a series of 32 byte vectors where four rows are + // interleaved with the following pattern: + // + // [ A0 A1 A2 A3 A4 A5 A6 A7 ] + // [ B0 B1 B2 B3 B4 B5 B6 B7 ] + // [ C0 C1 C2 C3 C4 C5 C6 C7 ] + // [ D0 D1 D2 D3 D4 D5 D6 D7 ] + // + // This pattern is repeated (CountK / 8) times. + // + // If CountK is not aligned to a multiple of eight, then the vector is padded + // with zeroes. + // + + if (CountM >= 4) { + const uint8_t* a0 = A; + const uint8_t* a1 = a0 + lda; + const uint8_t* a2 = a1 + lda; + const uint8_t* a3 = a2 + lda; + + size_t k = CountK; + uint32x4_t RowSums = vmovq_n_u32(0); + + while (k >= 16) { + uint64x2_t v0 = vld1q_u64(reinterpret_cast(a0)); + a0 += 16; + uint64x2_t v1 = vld1q_u64(reinterpret_cast(a1)); + a1 += 16; + uint64x2_t v2 = vld1q_u64(reinterpret_cast(a2)); + a2 += 16; + uint64x2_t v3 = vld1q_u64(reinterpret_cast(a3)); + a3 += 16; + + uint64x2_t z0 = vzip1q_u64(v0, v1); + uint64x2_t z1 = vzip2q_u64(v0, v1); + uint64x2_t z2 = vzip1q_u64(v2, v3); + uint64x2_t z3 = vzip2q_u64(v2, v3); + + vst1q_u8(&D[0], vreinterpretq_u8_u64(z0)); + vst1q_u8(&D[16], vreinterpretq_u8_u64(z2)); + vst1q_u8(&D[32], vreinterpretq_u8_u64(z1)); + vst1q_u8(&D[48], vreinterpretq_u8_u64(z3)); + + uint32x4_t RowSumsL_pada = vmovq_n_u32(0); + RowSumsL_pada = vpadalq_u16(RowSumsL_pada, vpaddlq_u8(vreinterpretq_u8_u64(z0))); + RowSumsL_pada = vpadalq_u16(RowSumsL_pada, vpaddlq_u8(vreinterpretq_u8_u64(z1))); + + uint32x4_t RowSumsL_ext = vextq_u32(RowSumsL_pada, RowSumsL_pada, 1); + uint32x4_t RowSumsL_add = vaddq_u32(RowSumsL_pada, RowSumsL_ext); + uint32x2_t RowSumsL = {vdups_laneq_u32(RowSumsL_add, 0), + vdups_laneq_u32(RowSumsL_add, 2)}; + + uint32x4_t RowSumsH_pada = vmovq_n_u32(0); + RowSumsH_pada = vpadalq_u16(RowSumsH_pada, vpaddlq_u8(vreinterpretq_u8_u64(z2))); + RowSumsH_pada = vpadalq_u16(RowSumsH_pada, vpaddlq_u8(vreinterpretq_u8_u64(z3))); + + uint32x4_t RowSumsH_ext = vextq_u32(RowSumsH_pada, RowSumsH_pada, 1); + uint32x4_t RowSumsH_add = vaddq_u32(RowSumsH_pada, RowSumsH_ext); + uint32x2_t RowSumsH = {vdups_laneq_u32(RowSumsH_add, 0), + vdups_laneq_u32(RowSumsH_add, 2)}; + + RowSums = vaddq_u32(RowSums, vcombine_u32(RowSumsL, RowSumsH)); + + D += 64; + k -= 16; + } + + while (k >= 8) { + uint64x1_t v0 = *reinterpret_cast(a0); + a0 += 8; + uint64x1_t v1 = *reinterpret_cast(a1); + a1 += 8; + uint64x1_t v2 = *reinterpret_cast(a2); + a2 += 8; + uint64x1_t v3 = *reinterpret_cast(a3); + a3 += 8; + + *reinterpret_cast(&D[0]) = v0; + *reinterpret_cast(&D[8]) = v1; + *reinterpret_cast(&D[16]) = v2; + *reinterpret_cast(&D[24]) = v3; + + uint64x2_t z01 = vcombine_u64(v0, v1); + uint64x2_t z23 = vcombine_u64(v2, v3); + + uint32x4_t RowSumsL_pada = vmovq_n_u32(0); + RowSumsL_pada = vpadalq_u16(RowSumsL_pada, vpaddlq_u8(vreinterpretq_u8_u64(z01))); + + uint32x4_t RowSumsL_ext = vextq_u32(RowSumsL_pada, RowSumsL_pada, 1); + uint32x4_t RowSumsL_add = vaddq_u32(RowSumsL_pada, RowSumsL_ext); + uint32x2_t RowSumsL = {vdups_laneq_u32(RowSumsL_add, 0), + vdups_laneq_u32(RowSumsL_add, 2)}; + + uint32x4_t RowSumsH_pada = vmovq_n_u32(0); + RowSumsH_pada = vpadalq_u16(RowSumsH_pada, vpaddlq_u8(vreinterpretq_u8_u64(z23))); + + uint32x4_t RowSumsH_ext = vextq_u32(RowSumsH_pada, RowSumsH_pada, 1); + uint32x4_t RowSumsH_add = vaddq_u32(RowSumsH_pada, RowSumsH_ext); + uint32x2_t RowSumsH = {vdups_laneq_u32(RowSumsH_add, 0), + vdups_laneq_u32(RowSumsH_add, 2)}; + + RowSums = vaddq_u32(RowSums, vcombine_u32(RowSumsL, RowSumsH)); + + D += 32; + k -= 8; + } + + if (k > 0) { + // + // Copy the remaining bytes with zero padding. + // + uint8_t* d = D; + + vst1q_u8(d, vmovq_n_u8(0)); + vst1q_u8(&d[16], vmovq_n_u8(0)); + + while (k > 0) { + d[0] = *a0++; + d[8] = *a1++; + d[16] = *a2++; + d[24] = *a3++; + d += 1; + k -= 1; + } + + d = D; + uint64x1_t v0 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v1 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v2 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v3 = *reinterpret_cast(d); + d = d + 8; + + uint64x2_t z01 = vcombine_u64(v0, v1); + uint64x2_t z23 = vcombine_u64(v2, v3); + + uint32x4_t RowSums0L_pada = vmovq_n_u32(0); + RowSums0L_pada = vpadalq_u16(RowSums0L_pada, vpaddlq_u8(vreinterpretq_u8_u64(z01))); + + uint32x4_t RowSums0L_ext = vextq_u32(RowSums0L_pada, RowSums0L_pada, 1); + uint32x4_t RowSums0L_add = vaddq_u32(RowSums0L_pada, RowSums0L_ext); + uint32x2_t RowSums0L = {vdups_laneq_u32(RowSums0L_add, 0), + vdups_laneq_u32(RowSums0L_add, 2)}; + + uint32x4_t RowSums0H_pada = vmovq_n_u32(0); + RowSums0H_pada = vpadalq_u16(RowSums0H_pada, vpaddlq_u8(vreinterpretq_u8_u64(z23))); + + uint32x4_t RowSums0H_ext = vextq_u32(RowSums0H_pada, RowSums0H_pada, 1); + uint32x4_t RowSums0H_add = vaddq_u32(RowSums0H_pada, RowSums0H_ext); + uint32x2_t RowSums0H = {vdups_laneq_u32(RowSums0H_add, 0), + vdups_laneq_u32(RowSums0H_add, 2)}; + + RowSums = vaddq_u32(RowSums, vcombine_u32(RowSums0L, RowSums0H)); + + D += 32; + } + + vst1q_s32(RowSumBuffer, vreinterpretq_s32_u32(RowSums)); + RowSumBuffer += 4; + + A = A + lda * 4; + CountM -= 4; + } + + // + // Process two rows of matrix A. + // + // The buffer is packed as a series of 16 byte vectors where two rows are + // interleaved with the following pattern: + // + // [ A0 A1 A2 A3 A4 A5 A6 A7 ] + // [ B0 B1 B2 B3 B4 B5 B6 B7 ] + // + // This pattern is repeated (CountK / 8) times. + // + // If CountK is not aligned to a multiple of eight, then the vector is padded + // with zeroes. + // + + if (CountM >= 2) { + const uint8_t* a0 = A; + const uint8_t* a1 = a0 + lda; + + size_t k = CountK; + uint32x2_t RowSums = vmov_n_u32(0); + + while (k >= 16) { + uint64x2_t v0 = vld1q_u64(reinterpret_cast(a0)); + a0 += 16; + uint64x2_t v1 = vld1q_u64(reinterpret_cast(a1)); + a1 += 16; + + uint64x2_t z0 = vzip1q_u64(v0, v1); + uint64x2_t z1 = vzip2q_u64(v0, v1); + + vst1q_u8(&D[0], vreinterpretq_u8_u64(z0)); + vst1q_u8(&D[16], vreinterpretq_u8_u64(z1)); + + uint32x4_t RowSumsL_pada = vmovq_n_u32(0); + RowSumsL_pada = vpadalq_u16(RowSumsL_pada, vpaddlq_u8(vreinterpretq_u8_u64(z0))); + RowSumsL_pada = vpadalq_u16(RowSumsL_pada, vpaddlq_u8(vreinterpretq_u8_u64(z1))); + + uint32x4_t RowSumsL_ext = vextq_u32(RowSumsL_pada, RowSumsL_pada, 1); + uint32x4_t RowSumsL_add = vaddq_u32(RowSumsL_pada, RowSumsL_ext); + uint32x2_t RowSumsL = {vdups_laneq_u32(RowSumsL_add, 0), + vdups_laneq_u32(RowSumsL_add, 2)}; + + RowSums = vadd_u32(RowSums, RowSumsL); + + D += 32; + k -= 16; + } + + while (k >= 8) { + uint64x1_t v0 = *reinterpret_cast(a0); + a0 += 8; + uint64x1_t v1 = *reinterpret_cast(a1); + a1 += 8; + + *reinterpret_cast(&D[0]) = v0; + *reinterpret_cast(&D[8]) = v1; + + uint64x2_t z01 = vcombine_u64(v0, v1); + uint32x4_t RowSumsL_pada = vmovq_n_u32(0); + RowSumsL_pada = vpadalq_u16(RowSumsL_pada, vpaddlq_u8(vreinterpretq_u8_u64(z01))); + + uint32x4_t RowSumsL_ext = vextq_u32(RowSumsL_pada, RowSumsL_pada, 1); + uint32x4_t RowSumsL_add = vaddq_u32(RowSumsL_pada, RowSumsL_ext); + uint32x2_t RowSumsL = {vdups_laneq_u32(RowSumsL_add, 0), + vdups_laneq_u32(RowSumsL_add, 2)}; + + RowSums = vadd_u32(RowSums, RowSumsL); + + D += 16; + k -= 8; + } + + if (k > 0) { + // + // Zero pad the remaining elements to make 8 columns. + // + + uint8_t* d = PaddedMatrixAData; + vst1q_u8(PaddedMatrixAData, vmovq_n_u8(0)); + + while (k > 0) { + d[0] = *a0++; + d[8] = *a1++; + + d += 1; + k -= 1; + } + + d = PaddedMatrixAData; + uint64x1_t v0 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v1 = *reinterpret_cast(d); + d = d + 8; + + uint64x2_t z01 = vcombine_u64(v0, v1); + uint32x4_t RowSumsL_pada = vmovq_n_u32(0); + RowSumsL_pada = vpadalq_u16(RowSumsL_pada, vpaddlq_u8(vreinterpretq_u8_u64(z01))); + + uint32x4_t RowSumsL_ext = vextq_u32(RowSumsL_pada, RowSumsL_pada, 1); + uint32x4_t RowSumsL_add = vaddq_u32(RowSumsL_pada, RowSumsL_ext); + uint32x2_t RowSumsL = {vdups_laneq_u32(RowSumsL_add, 0), + vdups_laneq_u32(RowSumsL_add, 2)}; + + RowSums = vadd_u32(RowSums, RowSumsL); + + uint8x16_t PackedVector = vld1q_u8(PaddedMatrixAData); + vst1q_u8(D, PackedVector); + + D += 16; + } + + vst1_s32(RowSumBuffer, vreinterpret_s32_u32(RowSums)); + RowSumBuffer += 2; + + A = A + lda * 2; + CountM -= 2; + } + + // + // Process one row of matrix A. + // + // The buffer is packed as a series of 8 byte with the following pattern: + // + // [ A0 A1 A2 A3 A4 A5 A6 A7 ] + // + // This pattern is repeated (CountK / 8) times. + // + // If CountK is not aligned to a multiple of 8, then the vector is padded + // with zeroes. + // + + if (CountM > 0) { + // No need to pad the rows to 2, the .S takes care of zero pdding + const uint8_t* a = A; + size_t k = CountK; + uint32x4_t RowSums = vmovq_n_u32(0); + + while (k >= 16) { + uint8x16_t v = vld1q_u8(a); + a += 16; + + vst1q_u8(D, v); + + RowSums = vpadalq_u16(RowSums, vpaddlq_u8(v)); + + D += 16; + k -= 16; + } + + if (k > 0) { + // + // Copy the remaining bytes to the zero padded stack buffer. + // + + vst1q_u8(PaddedMatrixAData, vmovq_n_u8(0)); + + for (size_t kk = 0; kk < k; kk++) { + PaddedMatrixAData[kk] = a[kk]; + } + + uint8x16_t v = vld1q_u8(PaddedMatrixAData); + vst1q_u8(D, v); + + RowSums = vpadalq_u16(RowSums, vpaddlq_u8(v)); + } + + *RowSumBuffer = int32_t(vaddvq_u32(RowSums)); + } +} + +MLAS_FORCEINLINE +void +MlasGemmU8X8CopyPackBProcessUmmla(MLAS_GEMM_U8X8_KERNEL_UMMLA::PackedBType* D, + uint8x8_t BytesRow[8], + uint8x16_t BitFlipVector, + uint32x4_t ColumnSums[2]) +{ + uint8x16_t v02 = veorq_u8(vcombine_u8(BytesRow[0], BytesRow[2]), BitFlipVector); + uint8x16_t v13 = veorq_u8(vcombine_u8(BytesRow[1], BytesRow[3]), BitFlipVector); + + uint8x16_t v46 = veorq_u8(vcombine_u8(BytesRow[4], BytesRow[6]), BitFlipVector); + uint8x16_t v57 = veorq_u8(vcombine_u8(BytesRow[5], BytesRow[7]), BitFlipVector); + + uint8x16x2_t zw1 = vzipq_u8(v02, v13); + uint16x8x2_t zd1 = + vzipq_u16(vreinterpretq_u16_u8(zw1.val[0]), vreinterpretq_u16_u8(zw1.val[1])); + + uint8x16x2_t zw2 = vzipq_u8(v46, v57); + uint16x8x2_t zd2 = + vzipq_u16(vreinterpretq_u16_u8(zw2.val[0]), vreinterpretq_u16_u8(zw2.val[1])); + + uint32x4x2_t zd3 = + vzipq_u32(vreinterpretq_u32_u16(zd1.val[0]), vreinterpretq_u32_u16(zd2.val[0])); + uint32x4x2_t zd4 = + vzipq_u32(vreinterpretq_u32_u16(zd1.val[1]), vreinterpretq_u32_u16(zd2.val[1])); + + vst1q_u8(&D[0], vreinterpretq_u8_u32(zd3.val[0])); + vst1q_u8(&D[16], vreinterpretq_u8_u32(zd3.val[1])); + vst1q_u8(&D[32], vreinterpretq_u8_u32(zd4.val[0])); + vst1q_u8(&D[48], vreinterpretq_u8_u32(zd4.val[1])); + + uint32x4_t ColSums0L_pada = vmovq_n_u32(0); + ColSums0L_pada = vpadalq_u16(ColSums0L_pada, vpaddlq_u8(vreinterpretq_u8_u32(zd3.val[0]))); + uint32x4_t ColSums0L_ext = vextq_u32(ColSums0L_pada, ColSums0L_pada, 1); + uint32x4_t ColSums0L_add = vaddq_u32(ColSums0L_pada, ColSums0L_ext); + uint32x2_t ColSums0L = {vdups_laneq_u32(ColSums0L_add, 0), vdups_laneq_u32(ColSums0L_add, 2)}; + + uint32x4_t ColSums0H_pada = vmovq_n_u32(0); + ColSums0H_pada = vpadalq_u16(ColSums0H_pada, vpaddlq_u8(vreinterpretq_u8_u32(zd3.val[1]))); + uint32x4_t ColSums0H_ext = vextq_u32(ColSums0H_pada, ColSums0H_pada, 1); + uint32x4_t ColSums0H_add = vaddq_u32(ColSums0H_pada, ColSums0H_ext); + uint32x2_t ColSums0H = {vdups_laneq_u32(ColSums0H_add, 0), vdups_laneq_u32(ColSums0H_add, 2)}; + + ColumnSums[0] = vaddq_u32(ColumnSums[0], vcombine_u32(ColSums0L, ColSums0H)); + + uint32x4_t ColSums1L_pada = vmovq_n_u32(0); + ColSums1L_pada = vpadalq_u16(ColSums1L_pada, vpaddlq_u8(vreinterpretq_u8_u32(zd4.val[0]))); + uint32x4_t ColSums1L_ext = vextq_u32(ColSums1L_pada, ColSums1L_pada, 1); + uint32x4_t ColSums1L_add = vaddq_u32(ColSums1L_pada, ColSums1L_ext); + uint32x2_t ColSums1L = {vdups_laneq_u32(ColSums1L_add, 0), vdups_laneq_u32(ColSums1L_add, 2)}; + + uint32x4_t ColSums1H_pada = vmovq_n_u32(0); + ColSums1H_pada = vpadalq_u16(ColSums1H_pada, vpaddlq_u8(vreinterpretq_u8_u32(zd4.val[1]))); + uint32x4_t ColSums1H_ext = vextq_u32(ColSums1H_pada, ColSums1H_pada, 1); + uint32x4_t ColSums1H_add = vaddq_u32(ColSums1H_pada, ColSums1H_ext); + uint32x2_t ColSums1H = {vdups_laneq_u32(ColSums1H_add, 0), vdups_laneq_u32(ColSums1H_add, 2)}; + + ColumnSums[1] = vaddq_u32(ColumnSums[1], vcombine_u32(ColSums1L, ColSums1H)); +} + +template <> +void +MlasGemmQuantCopyPackB(MLAS_GEMM_U8X8_KERNEL_UMMLA::PackedBType* D, + const uint8_t* B, + size_t ldb, + size_t CountN, + size_t CountK, + int32_t* ColumnSumBuffer, + bool BIsSigned) +{ + const uint8x16_t BitFlipVector = vdupq_n_u8(BIsSigned ? 0x80 : 0); + uint8x8_t BytesRow[8]; + + // + // Copy data from matrix B into the destination buffer 8x2 blocks at a + // time. + // + // + while (CountN >= 8) { + const uint8_t* b = B; + size_t k = CountK; + uint32x4_t ColumnSums[2]; + ColumnSums[0] = vmovq_n_u32(0); + ColumnSums[1] = vmovq_n_u32(0); + + while (k >= 8) { + BytesRow[0] = vld1_u8(&b[ldb * 0]); + BytesRow[1] = vld1_u8(&b[ldb * 1]); + BytesRow[2] = vld1_u8(&b[ldb * 2]); + BytesRow[3] = vld1_u8(&b[ldb * 3]); + BytesRow[4] = vld1_u8(&b[ldb * 4]); + BytesRow[5] = vld1_u8(&b[ldb * 5]); + BytesRow[6] = vld1_u8(&b[ldb * 6]); + BytesRow[7] = vld1_u8(&b[ldb * 7]); + + MlasGemmU8X8CopyPackBProcessUmmla(D, BytesRow, BitFlipVector, ColumnSums); + + D += 64; + b += ldb * 8; + k -= 8; + } + + if (k > 0) { + // Pad k to 8 + + BytesRow[0] = vld1_u8(&b[ldb * 0]); + BytesRow[1] = (k >= 2) ? vld1_u8(&b[ldb * 1]) : vget_low_u8(BitFlipVector); + BytesRow[2] = (k >= 3) ? vld1_u8(&b[ldb * 2]) : vget_low_u8(BitFlipVector); + BytesRow[3] = (k >= 4) ? vld1_u8(&b[ldb * 3]) : vget_low_u8(BitFlipVector); + BytesRow[4] = (k >= 5) ? vld1_u8(&b[ldb * 4]) : vget_low_u8(BitFlipVector); + BytesRow[5] = (k >= 6) ? vld1_u8(&b[ldb * 5]) : vget_low_u8(BitFlipVector); + BytesRow[6] = (k >= 7) ? vld1_u8(&b[ldb * 6]) : vget_low_u8(BitFlipVector); + BytesRow[7] = vget_low_u8(BitFlipVector); + + MlasGemmU8X8CopyPackBProcessUmmla(D, BytesRow, BitFlipVector, ColumnSums); + + D += 64; + } + + // Zero pad the output buffer to a multiple of PackedK if the above + // processed an odd number of four row bundles. + // + vst1q_s32(&ColumnSumBuffer[0], vreinterpretq_s32_u32(ColumnSums[0])); + vst1q_s32(&ColumnSumBuffer[4], vreinterpretq_s32_u32(ColumnSums[1])); + + ColumnSumBuffer += 8; + + B += 8; + CountN -= 8; + } + + // + // Process the remaining columns of matrix B. + // + + if (CountN > 0) { + const uint8_t* b = B; + size_t k = CountK; + uint8_t PaddedMatrixBData[64]; + uint32x4_t ColumnSums[2]; + + vst1q_u8(&PaddedMatrixBData[0], BitFlipVector); + vst1q_u8(&PaddedMatrixBData[16], BitFlipVector); + vst1q_u8(&PaddedMatrixBData[32], BitFlipVector); + vst1q_u8(&PaddedMatrixBData[48], BitFlipVector); + + ColumnSums[0] = vmovq_n_u32(0); + ColumnSums[1] = vmovq_n_u32(0); + + // + // Interleave rows of matrix B using an intermediate zero padded stack + // buffer and write to the packed buffer. + // + + while (k > 0) { + const uint8_t* bcopy0 = &b[ldb * 0]; + const uint8_t* bcopy1 = &b[ldb * 1]; + const uint8_t* bcopy2 = &b[ldb * 2]; + const uint8_t* bcopy3 = &b[ldb * 3]; + const uint8_t* bcopy4 = &b[ldb * 4]; + const uint8_t* bcopy5 = &b[ldb * 5]; + const uint8_t* bcopy6 = &b[ldb * 6]; + const uint8_t* bcopy7 = &b[ldb * 7]; + + if (k >= 8) { + b += ldb * 8; + k -= 8; + + } else { + vst1q_u8(&PaddedMatrixBData[0], BitFlipVector); + vst1q_u8(&PaddedMatrixBData[16], BitFlipVector); + vst1q_u8(&PaddedMatrixBData[32], BitFlipVector); + vst1q_u8(&PaddedMatrixBData[48], BitFlipVector); + + bcopy1 = (k >= 2) ? bcopy1 : &PaddedMatrixBData[56]; + bcopy2 = (k >= 3) ? bcopy2 : &PaddedMatrixBData[56]; + bcopy3 = (k >= 4) ? bcopy3 : &PaddedMatrixBData[56]; + bcopy4 = (k >= 5) ? bcopy4 : &PaddedMatrixBData[56]; + bcopy5 = (k >= 6) ? bcopy5 : &PaddedMatrixBData[56]; + bcopy6 = (k >= 7) ? bcopy6 : &PaddedMatrixBData[56]; + bcopy7 = &PaddedMatrixBData[56]; + + k = 0; + } + + uint8_t* padded = PaddedMatrixBData; + uint8_t* padded_end = padded + CountN; + do { + padded[0] = *bcopy0++; + padded[8] = *bcopy1++; + padded[16] = *bcopy2++; + padded[24] = *bcopy3++; + padded[32] = *bcopy4++; + padded[40] = *bcopy5++; + padded[48] = *bcopy6++; + padded[56] = *bcopy7++; + + } while (++padded < padded_end); + + BytesRow[0] = vld1_u8(&PaddedMatrixBData[0]); + BytesRow[1] = vld1_u8(&PaddedMatrixBData[8]); + BytesRow[2] = vld1_u8(&PaddedMatrixBData[16]); + BytesRow[3] = vld1_u8(&PaddedMatrixBData[24]); + BytesRow[4] = vld1_u8(&PaddedMatrixBData[32]); + BytesRow[5] = vld1_u8(&PaddedMatrixBData[40]); + BytesRow[6] = vld1_u8(&PaddedMatrixBData[48]); + BytesRow[7] = vld1_u8(&PaddedMatrixBData[56]); + + MlasGemmU8X8CopyPackBProcessUmmla(D, BytesRow, BitFlipVector, ColumnSums); + + D += 64; + } + + vst1q_s32(&ColumnSumBuffer[0], vreinterpretq_s32_u32(ColumnSums[0])); + vst1q_s32(&ColumnSumBuffer[4], vreinterpretq_s32_u32(ColumnSums[1])); + } +} + +template <> +MLAS_FORCEINLINE size_t +MlasGemmQuantKernel(const MLAS_GEMM_U8X8_KERNEL_UMMLA::PackedAType* A, + const MLAS_GEMM_U8X8_KERNEL_UMMLA::PackedBType* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumBuffer, + const int32_t* ColumnSumBuffer, + const int32_t* ZeroPointB, + bool ZeroMode) +{ + size_t RowsHandled; + + if (ZeroMode) { + RowsHandled = MlasGemmU8X8KernelUmmlaZero(A, B, C, PackedCountK, CountM, CountN, ldc, + RowSumBuffer, ColumnSumBuffer, ZeroPointB); + } else { + RowsHandled = MlasGemmU8X8KernelUmmlaAdd(A, B, C, PackedCountK, CountM, CountN, ldc, + RowSumBuffer, ColumnSumBuffer, ZeroPointB); + } + + return RowsHandled; +} + +const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchUmmla = { + MlasGemmQuantOperation, + MlasGemmQuantPackedOperation, + MlasGemmQuantCopyPackB, + MLAS_GEMM_U8X8_KERNEL_UMMLA::PackedK, + MLAS_GEMM_U8X8_KERNEL_UMMLA::PackedStrides.K, + 8};