From 58f80c16ff50a10d59f911059388a860d4571fa9 Mon Sep 17 00:00:00 2001 From: Chen Fu <1316708+chenfucn@users.noreply.github.com> Date: Mon, 14 Feb 2022 15:16:20 -0800 Subject: [PATCH] Create branch according to cpu core uarch (#10521) This is a preparation change for a bigger goal. On ARM64 CPUs with Big.Little, different cores are always the same architecture but different micro-architecture. Specifically, it is often that the little core has narrow memory buses that makes 128b load very slow. While if we always use 64b load in our kernels, the code will run slower on big cores. As a result, we need to run different code on different cores to achieve better performance. This change constructs a manifold that pivot based on the core micro-architecture of the current core, so that we can develop and call different kernels accordingly. Co-authored-by: Chen Fu --- onnxruntime/core/common/cpuid_info.cc | 17 +++++ onnxruntime/core/common/cpuid_info.h | 5 ++ onnxruntime/core/mlas/lib/mlasi.h | 69 +++++++++++++++++++ onnxruntime/core/mlas/lib/qgemm.cpp | 8 ++- onnxruntime/core/mlas/lib/qgemm.h | 3 +- .../core/mlas/lib/qgemm_kernel_neon.cpp | 1 + .../core/mlas/lib/qgemm_kernel_sdot.cpp | 1 + 7 files changed, 102 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/common/cpuid_info.cc b/onnxruntime/core/common/cpuid_info.cc index f9b084b682e00..f94720f05bad8 100644 --- a/onnxruntime/core/common/cpuid_info.cc +++ b/onnxruntime/core/common/cpuid_info.cc @@ -118,4 +118,21 @@ CPUIDInfo::CPUIDInfo() { } +int32_t CPUIDInfo::GetCurrentUarch() const { +#if (defined(CPUIDINFO_ARCH_X86) || defined(CPUIDINFO_ARCH_ARM)) && defined(CPUINFO_SUPPORTED) + if (!pytorch_cpuinfo_init_) { + return -1; + } + const auto uarchIdx = cpuinfo_get_current_uarch_index(); + const struct cpuinfo_uarch_info* uarch_info = cpuinfo_get_uarch(uarchIdx); + if (uarch_info == NULL) { + return -1; + } + return uarch_info->uarch; + +#else + return -1; +#endif +} + } // namespace onnxruntime diff --git a/onnxruntime/core/common/cpuid_info.h b/onnxruntime/core/common/cpuid_info.h index 66fc21ff55562..aa0cc485e9d84 100644 --- a/onnxruntime/core/common/cpuid_info.h +++ b/onnxruntime/core/common/cpuid_info.h @@ -34,6 +34,11 @@ class CPUIDInfo { // ARM bool HasArmNeonDot() const { return has_arm_neon_dot_; } + /** + * @return CPU core micro-architecture running the current thread + */ + int32_t GetCurrentUarch() const; + private: CPUIDInfo(); bool has_avx_{false}; diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index ca9fdef577f7b..c8959e29818dd 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -1987,3 +1987,72 @@ MlasReadTimeStampCounter(void) #endif #endif } + +/** + * @brief IDs for cpu microarchitectures. + * + * Copied from python cpuinfo package. Can't use the definition + * from cpuinfo directly as it causes lots of compilation issues + * in many platforms that we support. + */ +enum MlasUArch { + cpuinfo_uarch_unknown = 0, + + /** ARM Cortex-A32. */ + cpuinfo_uarch_cortex_a32 = 0x00300332, + /** ARM Cortex-A35. */ + cpuinfo_uarch_cortex_a35 = 0x00300335, + /** ARM Cortex-A53. */ + cpuinfo_uarch_cortex_a53 = 0x00300353, + /** ARM Cortex-A55 revision 0 (restricted dual-issue capabilities compared to revision 1+). */ + cpuinfo_uarch_cortex_a55r0 = 0x00300354, + /** ARM Cortex-A55. */ + cpuinfo_uarch_cortex_a55 = 0x00300355, + /** ARM Cortex-A57. */ + cpuinfo_uarch_cortex_a57 = 0x00300357, + /** ARM Cortex-A65. */ + cpuinfo_uarch_cortex_a65 = 0x00300365, + /** ARM Cortex-A72. */ + cpuinfo_uarch_cortex_a72 = 0x00300372, + /** ARM Cortex-A73. */ + cpuinfo_uarch_cortex_a73 = 0x00300373, + /** ARM Cortex-A75. */ + cpuinfo_uarch_cortex_a75 = 0x00300375, + /** ARM Cortex-A76. */ + cpuinfo_uarch_cortex_a76 = 0x00300376, + /** ARM Cortex-A77. */ + cpuinfo_uarch_cortex_a77 = 0x00300377, + /** ARM Cortex-A78. */ + cpuinfo_uarch_cortex_a78 = 0x00300378, +}; + +enum MlasCoreType { mlas_core_unknown = 0, mlas_core_little = 2, mlas_core_big = 3 }; + +/** + * @return 2 current core is little core with narrow memory load (e.g. ARMv8 a53) + * 3 current core is big core with wider load (e.g. ARMv8 a72) + */ +MLAS_FORCEINLINE +int32_t +MlasGetCoreUArch() +{ + thread_local int32_t core_type = mlas_core_unknown; + if (core_type == mlas_core_unknown) { + // initialization needed +#if defined(MLAS_TARGET_ARM64) && defined(__linux__) + auto uarch = MLAS_CPUIDINFO::GetCPUIDInfo().GetCurrentUarch(); + if (uarch == cpuinfo_uarch_cortex_a53 || uarch == cpuinfo_uarch_cortex_a55r0 || + uarch == cpuinfo_uarch_cortex_a55) { + core_type = mlas_core_little; + } else { + core_type = mlas_core_big; + } +#else + core_type = mlas_core_big; +#endif // MLAS_TARGET_ARM64 + + } + return core_type; +} + + diff --git a/onnxruntime/core/mlas/lib/qgemm.cpp b/onnxruntime/core/mlas/lib/qgemm.cpp index 33ac5cfc27466..772d28cb9875a 100644 --- a/onnxruntime/core/mlas/lib/qgemm.cpp +++ b/onnxruntime/core/mlas/lib/qgemm.cpp @@ -205,11 +205,13 @@ MlasSymmQgemmBatch( const size_t N = Shape.N; const size_t K = Shape.K; const MLAS_SYMM_QGEMM_DISPATCH* dispatch = GetMlasPlatform().SymmQgemmDispatch; - MLAS_SYMM_QGEMM_OPERATION* operation = dispatch->Operation; if (ThreadPool == nullptr) { // So our caller handles threaded job partition. // Call single threaded operation directly + auto uarch = MlasGetCoreUArch(); + MLAS_SYMM_QGEMM_OPERATION* operation = + uarch == mlas_core_little ? dispatch->LitOperation : dispatch->BigOperation; for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) { auto Data = &DataParams[gemm_i]; @@ -258,6 +260,10 @@ MlasSymmQgemmBatch( ThreadsPerGemm = ThreadCountM * ThreadCountN; MlasTrySimpleParallel(ThreadPool, ThreadsPerGemm * BatchN, [&](ptrdiff_t tid) { + auto uarch = MlasGetCoreUArch(); + MLAS_SYMM_QGEMM_OPERATION* operation = + uarch == mlas_core_little ? dispatch->LitOperation : dispatch->BigOperation; + const auto gemm_i = tid / ThreadsPerGemm; const auto blk_i = tid % ThreadsPerGemm; auto Data = &DataParams[gemm_i]; diff --git a/onnxruntime/core/mlas/lib/qgemm.h b/onnxruntime/core/mlas/lib/qgemm.h index ab566c5b50645..2f6168c527050 100644 --- a/onnxruntime/core/mlas/lib/qgemm.h +++ b/onnxruntime/core/mlas/lib/qgemm.h @@ -802,7 +802,8 @@ struct MLAS_GEMM_QUANT_DISPATCH { }; struct MLAS_SYMM_QGEMM_DISPATCH { - MLAS_SYMM_QGEMM_OPERATION* Operation; + MLAS_SYMM_QGEMM_OPERATION* LitOperation; /// running on little cores with narrow memory load + MLAS_SYMM_QGEMM_OPERATION* BigOperation; /// running on big cores with wider memory load MLAS_GEMM_QUANT_COPY_PACKB_ROUTINE* CopyPackBRoutine; size_t StrideM; /**< num of rows processed by kernel at a time */ size_t PackedK; diff --git a/onnxruntime/core/mlas/lib/qgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_neon.cpp index ce9460be9b501..0b747bc7cc84b 100644 --- a/onnxruntime/core/mlas/lib/qgemm_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/qgemm_kernel_neon.cpp @@ -1217,6 +1217,7 @@ const MLAS_GEMM_QUANT_DISPATCH MlasGemmX8S8DispatchNeon = { }; const MLAS_SYMM_QGEMM_DISPATCH MlasSymmQgemmS8DispatchNeon = { + MlasSymmQGemmPackedOperation, MlasSymmQGemmPackedOperation, MlasGemmQuantCopyPackB, 4, // StrideM diff --git a/onnxruntime/core/mlas/lib/qgemm_kernel_sdot.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_sdot.cpp index 5b2d9a6e6cce3..604986cf9f662 100644 --- a/onnxruntime/core/mlas/lib/qgemm_kernel_sdot.cpp +++ b/onnxruntime/core/mlas/lib/qgemm_kernel_sdot.cpp @@ -1027,6 +1027,7 @@ size_t MlasSymmQGemmKernel( } const MLAS_SYMM_QGEMM_DISPATCH MlasSymmQgemmS8DispatchSdot = { + MlasSymmQGemmPackedOperation, MlasSymmQGemmPackedOperation, MlasGemmQuantCopyPackB, 4, // StrideM