diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 51225a80f7906..3c0f82408179b 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -15,11 +15,11 @@ Module Name: --*/ -#include -#include - #include "mlasi.h" +#include +#include + #if defined(MLAS_TARGET_POWER) && defined(__linux__) #include #endif @@ -44,8 +44,8 @@ MLASCPUIDInfo::MLASCPUIDInfo() #elif defined(__linux__) -#include #include +#include // N.B. Support building with older versions of asm/hwcap.h that do not define // this capability bit. #ifndef HWCAP_ASIMDDP @@ -68,14 +68,14 @@ MLASCPUIDInfo::MLASCPUIDInfo() MLASCPUIDInfo::MLASCPUIDInfo() {} #endif -#endif // Windows vs Linux vs Unknown -#else // not MLAS_TARGET_ARM64 +#endif // Windows vs Linux vs Unknown +#else // not MLAS_TARGET_ARM64 #if defined(BUILD_MLAS_NO_ONNXRUNTIME) MLASCPUIDInfo::MLASCPUIDInfo() {} #endif -#endif // MLAS_TARGET_ARM64 +#endif // MLAS_TARGET_ARM64 #ifdef MLAS_TARGET_AMD64_IX86 @@ -83,15 +83,13 @@ MLASCPUIDInfo::MLASCPUIDInfo() {} // Stores a vector to build a conditional load/store mask for vmaskmovps. // -MLAS_INTERNAL_DATA -MLAS_DECLSPEC_ALIGN(const uint32_t MlasMaskMoveAvx[8], 32) = {0, 1, 2, 3, 4, 5, 6, 7}; +MLAS_INTERNAL_DATA MLAS_DECLSPEC_ALIGN(const uint32_t MlasMaskMoveAvx[8], 32) = { 0, 1, 2, 3, 4, 5, 6, 7 }; // // Stores a table of AVX vmaskmovps/vmaskmovpd load/store masks. // -MLAS_INTERNAL_DATA -MLAS_DECLSPEC_ALIGN(const uint32_t MlasMaskMoveTableAvx[16], 32) = { +MLAS_INTERNAL_DATA MLAS_DECLSPEC_ALIGN(const uint32_t MlasMaskMoveTableAvx[16], 32) = { 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, }; @@ -100,8 +98,7 @@ MLAS_DECLSPEC_ALIGN(const uint32_t MlasMaskMoveTableAvx[16], 32) = { // Stores a table of AVX512 opmask register values. // -MLAS_INTERNAL_DATA -MLAS_DECLSPEC_ALIGN(const int16_t MlasOpmask16BitTableAvx512[16], 32) = { +MLAS_INTERNAL_DATA MLAS_DECLSPEC_ALIGN(const int16_t MlasOpmask16BitTableAvx512[16], 32) = { 0x0000, 0x0001, 0x0003, 0x0007, 0x000F, 0x001F, 0x003F, 0x007F, 0x00FF, 0x01FF, 0x03FF, 0x07FF, 0x0FFF, 0x1FFF, 0x3FFF, 0x7FFF, }; @@ -123,15 +120,23 @@ MLAS_DECLSPEC_ALIGN(const int16_t MlasOpmask16BitTableAvx512[16], 32) = { #define XFEATURE_MASK_XTILE (XFEATURE_MASK_XTILECFG | XFEATURE_MASK_XTILEDATA) #endif -inline uint64_t -MlasReadExtendedControlRegister(unsigned int ext_ctrl_reg) +inline +uint64_t +MlasReadExtendedControlRegister( + unsigned int ext_ctrl_reg +) { #if defined(_WIN32) return _xgetbv(ext_ctrl_reg); #else uint32_t eax, edx; - __asm__("xgetbv" : "=a"(eax), "=d"(edx) : "c"(ext_ctrl_reg)); + __asm__ + ( + "xgetbv" + : "=a" (eax), "=d" (edx) + : "c" (ext_ctrl_reg) + ); return ((uint64_t)edx << 32) | eax; #endif @@ -167,9 +172,11 @@ MlasInitAMX() #endif } -#endif // MLAS_TARGET_AMD64_IX86 +#endif // MLAS_TARGET_AMD64_IX86 -MLAS_PLATFORM::MLAS_PLATFORM(void) +MLAS_PLATFORM::MLAS_PLATFORM( + void + ) /*++ Routine Description: @@ -186,6 +193,7 @@ Return Value: --*/ { + this->ConvDepthwiseU8S8Kernel = MlasConvDepthwiseKernel; this->ConvDepthwiseU8U8Kernel = MlasConvDepthwiseKernel; this->ConvDepthwiseS8S8Kernel = MlasConvDepthwiseKernel; @@ -259,6 +267,7 @@ Return Value: // if ((Cpuid1[2] & 0x18000000) == 0x18000000) { + // // Check if the operating system supports saving SSE and AVX states. // @@ -266,6 +275,7 @@ Return Value: uint64_t xcr0 = MlasReadExtendedControlRegister(_XCR_XFEATURE_ENABLED_MASK); if ((xcr0 & 0x6) == 0x6) { + this->GemmFloatKernel = MlasGemmFloatKernelAvx; #if defined(MLAS_TARGET_AMD64) @@ -279,10 +289,8 @@ Return Value: this->ConvDepthwiseFloatKernel = MlasConvDepthwiseFloatKernelAvx; this->ConvPointwiseFloatKernel = MlasConvPointwiseFloatKernelAvx; this->PoolFloatKernel[MlasMaximumPooling] = MlasPoolMaximumFloatKernelAvx; - this->PoolFloatKernel[MlasAveragePoolingExcludePad] = - MlasPoolAverageExcludePadFloatKernelAvx; - this->PoolFloatKernel[MlasAveragePoolingIncludePad] = - MlasPoolAverageIncludePadFloatKernelAvx; + this->PoolFloatKernel[MlasAveragePoolingExcludePad] = MlasPoolAverageExcludePadFloatKernelAvx; + this->PoolFloatKernel[MlasAveragePoolingIncludePad] = MlasPoolAverageIncludePadFloatKernelAvx; this->ComputeSoftmaxOutputF32Kernel = MlasComputeSoftmaxOutputF32KernelAvx; this->ComputeLogSoftmaxOutputF32Kernel = MlasComputeLogSoftmaxOutputF32KernelAvx; this->ReduceMaximumF32Kernel = MlasReduceMaximumF32KernelAvx; @@ -301,6 +309,7 @@ Return Value: #endif if (((Cpuid1[2] & 0x1000) != 0) && ((Cpuid7[1] & 0x20) != 0)) { + this->GemmU8S8Dispatch = &MlasGemmU8S8DispatchAvx2; this->GemmU8S8Kernel = MlasGemmU8S8KernelAvx2; this->GemvU8S8Kernel = MlasGemvU8S8KernelAvx2; @@ -346,6 +355,7 @@ Return Value: #endif if ((Cpuid7_1[0] & 0x10) != 0) { + this->GemmU8U8Dispatch = &MlasGemmU8S8DispatchAvx2; this->GemmU8S8Kernel = MlasGemmU8S8KernelAvxVnni; this->GemvU8S8Kernel = MlasGemvU8S8KernelAvxVnni; @@ -360,6 +370,7 @@ Return Value: // if (((Cpuid7[1] & 0x10000) != 0) && ((xcr0 & 0xE0) == 0xE0)) { + this->GemmFloatKernel = MlasGemmFloatKernelAvx512F; this->GemmDoubleKernel = MlasGemmDoubleKernelAvx512F; this->ConvNchwFloatKernel = MlasConvNchwFloatKernelAvx512F; @@ -367,10 +378,8 @@ Return Value: this->ConvDepthwiseFloatKernel = MlasConvDepthwiseFloatKernelAvx512F; this->ConvPointwiseFloatKernel = MlasConvPointwiseFloatKernelAvx512F; this->PoolFloatKernel[MlasMaximumPooling] = MlasPoolMaximumFloatKernelAvx512F; - this->PoolFloatKernel[MlasAveragePoolingExcludePad] = - MlasPoolAverageExcludePadFloatKernelAvx512F; - this->PoolFloatKernel[MlasAveragePoolingIncludePad] = - MlasPoolAverageIncludePadFloatKernelAvx512F; + this->PoolFloatKernel[MlasAveragePoolingExcludePad] = MlasPoolAverageExcludePadFloatKernelAvx512F; + this->PoolFloatKernel[MlasAveragePoolingIncludePad] = MlasPoolAverageIncludePadFloatKernelAvx512F; this->ComputeExpF32Kernel = MlasComputeExpF32KernelAvx512F; this->ComputeSumExpF32Kernel = MlasComputeSumExpF32KernelAvx512F; this->QuantizeLinearS8Kernel = MlasQuantizeLinearS8KernelAvx512F; @@ -384,6 +393,7 @@ Return Value: // if ((Cpuid7[1] & 0xC0020000) == 0xC0020000) { + this->GemmU8S8Kernel = MlasGemmU8S8KernelAvx512Core; this->GemvU8S8Kernel = MlasGemvU8S8KernelAvx512Core; this->GemmU8U8Kernel = MlasGemmU8U8KernelAvx512Core; @@ -395,6 +405,7 @@ Return Value: // if ((Cpuid7[2] & 0x800) != 0) { + this->GemmU8U8Dispatch = &MlasGemmU8S8DispatchAvx2; this->GemmU8S8Kernel = MlasGemmU8S8KernelAvx512Vnni; this->GemvU8S8Kernel = MlasGemvU8S8KernelAvx512Vnni; @@ -409,23 +420,26 @@ Return Value: // Check if the processor supports AMX-TILE and AMX-INT8 // features. // - if ((Cpuid7[3] & 0b1 << 24) != 0 && (Cpuid7[3] & 0b1 << 25) != 0 && + if ((Cpuid7[3] & 0b1 << 24) != 0 && + (Cpuid7[3] & 0b1 << 25) != 0 && (xcr0 & XFEATURE_MASK_XTILE) == XFEATURE_MASK_XTILE) { if (MlasInitAMX()) { this->GemmU8U8Dispatch = &MlasGemmU8S8DispatchAmx; this->GemmU8S8Dispatch = &MlasGemmU8S8DispatchAmx; } } -#endif // __APPLE__ +#endif // __APPLE__ + +#endif // ORT_MINIMAL_BUILD -#endif // ORT_MINIMAL_BUILD } -#endif // MLAS_TARGET_AMD64 +#endif // MLAS_TARGET_AMD64 + } } -#endif // MLAS_TARGET_AMD64_IX86 +#endif // MLAS_TARGET_AMD64_IX86 #if defined(MLAS_TARGET_ARM64) @@ -443,8 +457,7 @@ Return Value: bool HasDotProductInstructions; #if defined(_WIN32) - HasDotProductInstructions = - (IsProcessorFeaturePresent(PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE) != 0); + HasDotProductInstructions = (IsProcessorFeaturePresent(PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE) != 0); #else // Use the cpuinfo value which is read from sysctl and has some additional special cases. // https://github.com/pytorch/cpuinfo/blob/959002f82d7962a473d8bf301845f2af720e0aa4/src/arm/mach/init.c#L369-L379 @@ -467,7 +480,7 @@ Return Value: this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchDot; } -#endif // MLAS_TARGET_ARM64 +#endif // MLAS_TARGET_ARM64 #if defined(MLAS_TARGET_POWER) this->GemmFloatKernel = MlasSgemmKernel; this->GemmDoubleKernel = MlasDgemmKernel; @@ -486,7 +499,7 @@ Return Value: } #if defined(POWER10) -#if (defined(__GNUC__) && ((__GNUC__ > 10) || (__GNUC__ == 10 && __GNUC_MINOR__ >= 2))) || \ +#if (defined(__GNUC__) && ((__GNUC__ > 10) || (__GNUC__== 10 && __GNUC_MINOR__ >= 2))) || \ (defined(__clang__) && (__clang_major__ >= 12)) bool HasP10Instructions = ((hwcap2 & PPC_FEATURE2_MMA) && (hwcap2 & PPC_FEATURE2_ARCH_3_1)); if (HasP10Instructions) { @@ -497,12 +510,16 @@ Return Value: #endif #endif -#endif // __linux__ -#endif // MLAS_TARGET_POWER +#endif // __linux__ +#endif // MLAS_TARGET_POWER + } -size_t MLASCALL -MlasGetPreferredBufferAlignment(void) +size_t +MLASCALL +MlasGetPreferredBufferAlignment( + void + ) /*++ Routine Description: @@ -530,8 +547,11 @@ Return Value: #ifdef MLAS_TARGET_AMD64_IX86 -bool MLASCALL -MlasPlatformU8S8Overflow(void) +bool +MLASCALL +MlasPlatformU8S8Overflow( + void + ) { const auto& p = GetMlasPlatform(); return p.GemmU8U8Dispatch != p.GemmU8S8Dispatch; @@ -541,8 +561,7 @@ MlasPlatformU8S8Overflow(void) thread_local size_t ThreadedBufSize = 0; #ifdef _MSC_VER -thread_local std::unique_ptr ThreadedBufHolder(nullptr, - &_aligned_free); +thread_local std::unique_ptr ThreadedBufHolder(nullptr, &_aligned_free); #else thread_local std::unique_ptr ThreadedBufHolder(nullptr, &free); #endif