Skip to content

Commit

Permalink
Reduce formatting diffs
Browse files Browse the repository at this point in the history
  • Loading branch information
skottmckay committed Oct 11, 2023
1 parent 3907118 commit f49afe7
Showing 1 changed file with 62 additions and 43 deletions.
105 changes: 62 additions & 43 deletions onnxruntime/core/mlas/lib/platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ Module Name:
--*/

#include <mutex>
#include <thread>

#include "mlasi.h"

#include <thread>
#include <mutex>

#if defined(MLAS_TARGET_POWER) && defined(__linux__)
#include <sys/auxv.h>
#endif
Expand All @@ -44,8 +44,8 @@ MLASCPUIDInfo::MLASCPUIDInfo()

#elif defined(__linux__)

#include <asm/hwcap.h>
#include <sys/auxv.h>
#include <asm/hwcap.h>
// N.B. Support building with older versions of asm/hwcap.h that do not define
// this capability bit.
#ifndef HWCAP_ASIMDDP
Expand All @@ -68,30 +68,28 @@ MLASCPUIDInfo::MLASCPUIDInfo()
MLASCPUIDInfo::MLASCPUIDInfo() {}
#endif

#endif // Windows vs Linux vs Unknown
#else // not MLAS_TARGET_ARM64
#endif // Windows vs Linux vs Unknown
#else // not MLAS_TARGET_ARM64

#if defined(BUILD_MLAS_NO_ONNXRUNTIME)
MLASCPUIDInfo::MLASCPUIDInfo() {}
#endif

#endif // MLAS_TARGET_ARM64
#endif // MLAS_TARGET_ARM64

#ifdef MLAS_TARGET_AMD64_IX86

//
// Stores a vector to build a conditional load/store mask for vmaskmovps.
//

MLAS_INTERNAL_DATA
MLAS_DECLSPEC_ALIGN(const uint32_t MlasMaskMoveAvx[8], 32) = {0, 1, 2, 3, 4, 5, 6, 7};
MLAS_INTERNAL_DATA MLAS_DECLSPEC_ALIGN(const uint32_t MlasMaskMoveAvx[8], 32) = { 0, 1, 2, 3, 4, 5, 6, 7 };

//
// Stores a table of AVX vmaskmovps/vmaskmovpd load/store masks.
//

MLAS_INTERNAL_DATA
MLAS_DECLSPEC_ALIGN(const uint32_t MlasMaskMoveTableAvx[16], 32) = {
MLAS_INTERNAL_DATA MLAS_DECLSPEC_ALIGN(const uint32_t MlasMaskMoveTableAvx[16], 32) = {
0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
};
Expand All @@ -100,8 +98,7 @@ MLAS_DECLSPEC_ALIGN(const uint32_t MlasMaskMoveTableAvx[16], 32) = {
// Stores a table of AVX512 opmask register values.
//

MLAS_INTERNAL_DATA
MLAS_DECLSPEC_ALIGN(const int16_t MlasOpmask16BitTableAvx512[16], 32) = {
MLAS_INTERNAL_DATA MLAS_DECLSPEC_ALIGN(const int16_t MlasOpmask16BitTableAvx512[16], 32) = {
0x0000, 0x0001, 0x0003, 0x0007, 0x000F, 0x001F, 0x003F, 0x007F,
0x00FF, 0x01FF, 0x03FF, 0x07FF, 0x0FFF, 0x1FFF, 0x3FFF, 0x7FFF,
};
Expand All @@ -123,15 +120,23 @@ MLAS_DECLSPEC_ALIGN(const int16_t MlasOpmask16BitTableAvx512[16], 32) = {
#define XFEATURE_MASK_XTILE (XFEATURE_MASK_XTILECFG | XFEATURE_MASK_XTILEDATA)
#endif

inline uint64_t
MlasReadExtendedControlRegister(unsigned int ext_ctrl_reg)
inline
uint64_t
MlasReadExtendedControlRegister(
unsigned int ext_ctrl_reg
)
{
#if defined(_WIN32)
return _xgetbv(ext_ctrl_reg);
#else
uint32_t eax, edx;

__asm__("xgetbv" : "=a"(eax), "=d"(edx) : "c"(ext_ctrl_reg));
__asm__
(
"xgetbv"
: "=a" (eax), "=d" (edx)
: "c" (ext_ctrl_reg)
);

return ((uint64_t)edx << 32) | eax;
#endif
Expand Down Expand Up @@ -167,9 +172,11 @@ MlasInitAMX()
#endif
}

#endif // MLAS_TARGET_AMD64_IX86
#endif // MLAS_TARGET_AMD64_IX86

MLAS_PLATFORM::MLAS_PLATFORM(void)
MLAS_PLATFORM::MLAS_PLATFORM(
void
)
/*++
Routine Description:
Expand All @@ -186,6 +193,7 @@ Return Value:
--*/
{

this->ConvDepthwiseU8S8Kernel = MlasConvDepthwiseKernel<uint8_t, int8_t>;
this->ConvDepthwiseU8U8Kernel = MlasConvDepthwiseKernel<uint8_t, uint8_t>;
this->ConvDepthwiseS8S8Kernel = MlasConvDepthwiseKernel<int8_t, int8_t>;
Expand Down Expand Up @@ -259,13 +267,15 @@ Return Value:
//

if ((Cpuid1[2] & 0x18000000) == 0x18000000) {

//
// Check if the operating system supports saving SSE and AVX states.
//

uint64_t xcr0 = MlasReadExtendedControlRegister(_XCR_XFEATURE_ENABLED_MASK);

if ((xcr0 & 0x6) == 0x6) {

this->GemmFloatKernel = MlasGemmFloatKernelAvx;

#if defined(MLAS_TARGET_AMD64)
Expand All @@ -279,10 +289,8 @@ Return Value:
this->ConvDepthwiseFloatKernel = MlasConvDepthwiseFloatKernelAvx;
this->ConvPointwiseFloatKernel = MlasConvPointwiseFloatKernelAvx;
this->PoolFloatKernel[MlasMaximumPooling] = MlasPoolMaximumFloatKernelAvx;
this->PoolFloatKernel[MlasAveragePoolingExcludePad] =
MlasPoolAverageExcludePadFloatKernelAvx;
this->PoolFloatKernel[MlasAveragePoolingIncludePad] =
MlasPoolAverageIncludePadFloatKernelAvx;
this->PoolFloatKernel[MlasAveragePoolingExcludePad] = MlasPoolAverageExcludePadFloatKernelAvx;
this->PoolFloatKernel[MlasAveragePoolingIncludePad] = MlasPoolAverageIncludePadFloatKernelAvx;
this->ComputeSoftmaxOutputF32Kernel = MlasComputeSoftmaxOutputF32KernelAvx;
this->ComputeLogSoftmaxOutputF32Kernel = MlasComputeLogSoftmaxOutputF32KernelAvx;
this->ReduceMaximumF32Kernel = MlasReduceMaximumF32KernelAvx;
Expand All @@ -301,6 +309,7 @@ Return Value:
#endif

if (((Cpuid1[2] & 0x1000) != 0) && ((Cpuid7[1] & 0x20) != 0)) {

this->GemmU8S8Dispatch = &MlasGemmU8S8DispatchAvx2;
this->GemmU8S8Kernel = MlasGemmU8S8KernelAvx2;
this->GemvU8S8Kernel = MlasGemvU8S8KernelAvx2;
Expand Down Expand Up @@ -346,6 +355,7 @@ Return Value:
#endif

if ((Cpuid7_1[0] & 0x10) != 0) {

this->GemmU8U8Dispatch = &MlasGemmU8S8DispatchAvx2;
this->GemmU8S8Kernel = MlasGemmU8S8KernelAvxVnni;
this->GemvU8S8Kernel = MlasGemvU8S8KernelAvxVnni;
Expand All @@ -360,17 +370,16 @@ Return Value:
//

if (((Cpuid7[1] & 0x10000) != 0) && ((xcr0 & 0xE0) == 0xE0)) {

this->GemmFloatKernel = MlasGemmFloatKernelAvx512F;
this->GemmDoubleKernel = MlasGemmDoubleKernelAvx512F;
this->ConvNchwFloatKernel = MlasConvNchwFloatKernelAvx512F;
this->ConvNchwcFloatKernel = MlasConvNchwcFloatKernelAvx512F;
this->ConvDepthwiseFloatKernel = MlasConvDepthwiseFloatKernelAvx512F;
this->ConvPointwiseFloatKernel = MlasConvPointwiseFloatKernelAvx512F;
this->PoolFloatKernel[MlasMaximumPooling] = MlasPoolMaximumFloatKernelAvx512F;
this->PoolFloatKernel[MlasAveragePoolingExcludePad] =
MlasPoolAverageExcludePadFloatKernelAvx512F;
this->PoolFloatKernel[MlasAveragePoolingIncludePad] =
MlasPoolAverageIncludePadFloatKernelAvx512F;
this->PoolFloatKernel[MlasAveragePoolingExcludePad] = MlasPoolAverageExcludePadFloatKernelAvx512F;
this->PoolFloatKernel[MlasAveragePoolingIncludePad] = MlasPoolAverageIncludePadFloatKernelAvx512F;
this->ComputeExpF32Kernel = MlasComputeExpF32KernelAvx512F;
this->ComputeSumExpF32Kernel = MlasComputeSumExpF32KernelAvx512F;
this->QuantizeLinearS8Kernel = MlasQuantizeLinearS8KernelAvx512F;
Expand All @@ -384,6 +393,7 @@ Return Value:
//

if ((Cpuid7[1] & 0xC0020000) == 0xC0020000) {

this->GemmU8S8Kernel = MlasGemmU8S8KernelAvx512Core;
this->GemvU8S8Kernel = MlasGemvU8S8KernelAvx512Core;
this->GemmU8U8Kernel = MlasGemmU8U8KernelAvx512Core;
Expand All @@ -395,6 +405,7 @@ Return Value:
//

if ((Cpuid7[2] & 0x800) != 0) {

this->GemmU8U8Dispatch = &MlasGemmU8S8DispatchAvx2;
this->GemmU8S8Kernel = MlasGemmU8S8KernelAvx512Vnni;
this->GemvU8S8Kernel = MlasGemvU8S8KernelAvx512Vnni;
Expand All @@ -409,23 +420,26 @@ Return Value:
// Check if the processor supports AMX-TILE and AMX-INT8
// features.
//
if ((Cpuid7[3] & 0b1 << 24) != 0 && (Cpuid7[3] & 0b1 << 25) != 0 &&
if ((Cpuid7[3] & 0b1 << 24) != 0 &&
(Cpuid7[3] & 0b1 << 25) != 0 &&
(xcr0 & XFEATURE_MASK_XTILE) == XFEATURE_MASK_XTILE) {
if (MlasInitAMX()) {
this->GemmU8U8Dispatch = &MlasGemmU8S8DispatchAmx;
this->GemmU8S8Dispatch = &MlasGemmU8S8DispatchAmx;
}
}
#endif // __APPLE__
#endif // __APPLE__

#endif // ORT_MINIMAL_BUILD

#endif // ORT_MINIMAL_BUILD
}

#endif // MLAS_TARGET_AMD64
#endif // MLAS_TARGET_AMD64

}
}

#endif // MLAS_TARGET_AMD64_IX86
#endif // MLAS_TARGET_AMD64_IX86

#if defined(MLAS_TARGET_ARM64)

Expand All @@ -443,8 +457,7 @@ Return Value:
bool HasDotProductInstructions;

#if defined(_WIN32)
HasDotProductInstructions =
(IsProcessorFeaturePresent(PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE) != 0);
HasDotProductInstructions = (IsProcessorFeaturePresent(PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE) != 0);
#else
// Use the cpuinfo value which is read from sysctl and has some additional special cases.
// https://github.com/pytorch/cpuinfo/blob/959002f82d7962a473d8bf301845f2af720e0aa4/src/arm/mach/init.c#L369-L379
Expand All @@ -467,7 +480,7 @@ Return Value:
this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchDot;
}

#endif // MLAS_TARGET_ARM64
#endif // MLAS_TARGET_ARM64
#if defined(MLAS_TARGET_POWER)
this->GemmFloatKernel = MlasSgemmKernel;
this->GemmDoubleKernel = MlasDgemmKernel;
Expand All @@ -486,7 +499,7 @@ Return Value:
}

#if defined(POWER10)
#if (defined(__GNUC__) && ((__GNUC__ > 10) || (__GNUC__ == 10 && __GNUC_MINOR__ >= 2))) || \
#if (defined(__GNUC__) && ((__GNUC__ > 10) || (__GNUC__== 10 && __GNUC_MINOR__ >= 2))) || \
(defined(__clang__) && (__clang_major__ >= 12))
bool HasP10Instructions = ((hwcap2 & PPC_FEATURE2_MMA) && (hwcap2 & PPC_FEATURE2_ARCH_3_1));
if (HasP10Instructions) {
Expand All @@ -497,12 +510,16 @@ Return Value:
#endif
#endif

#endif // __linux__
#endif // MLAS_TARGET_POWER
#endif // __linux__
#endif // MLAS_TARGET_POWER

}

size_t MLASCALL
MlasGetPreferredBufferAlignment(void)
size_t
MLASCALL
MlasGetPreferredBufferAlignment(
void
)
/*++
Routine Description:
Expand Down Expand Up @@ -530,8 +547,11 @@ Return Value:

#ifdef MLAS_TARGET_AMD64_IX86

bool MLASCALL
MlasPlatformU8S8Overflow(void)
bool
MLASCALL
MlasPlatformU8S8Overflow(
void
)
{
const auto& p = GetMlasPlatform();
return p.GemmU8U8Dispatch != p.GemmU8S8Dispatch;
Expand All @@ -541,8 +561,7 @@ MlasPlatformU8S8Overflow(void)

thread_local size_t ThreadedBufSize = 0;
#ifdef _MSC_VER
thread_local std::unique_ptr<uint8_t, decltype(&_aligned_free)> ThreadedBufHolder(nullptr,
&_aligned_free);
thread_local std::unique_ptr<uint8_t, decltype(&_aligned_free)> ThreadedBufHolder(nullptr, &_aligned_free);
#else
thread_local std::unique_ptr<uint8_t, decltype(&free)> ThreadedBufHolder(nullptr, &free);
#endif

0 comments on commit f49afe7

Please sign in to comment.