[aarch64] Add Sbgemm kernel to accelerate fp32 tensor matmul with bfloat16 #17031

Merged · 6 commits · Jan 22, 2024
4 changes: 4 additions & 0 deletions cmake/onnxruntime_mlas.cmake
@@ -355,19 +355,23 @@ else()
${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S
${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S
${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S
${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S
${MLAS_SRC_DIR}/activate_fp16.cpp
${MLAS_SRC_DIR}/dwconv.cpp
${MLAS_SRC_DIR}/halfgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/pooling_fp16.cpp
${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp
${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp
${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp
)
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
set_source_files_properties(${MLAS_SRC_DIR}/activate_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
endif()

if(ONNXRUNTIME_MLAS_MULTI_ARCH)
@@ -249,4 +249,10 @@ static const char* const kOrtSessionOptionEpContextFilePath = "ep.context_file_p
// Flag to specify whether to dump the EP context into the Onnx model.
// "0": dump the EP context into separate file, keep the file name in the Onnx model.
// "1": dump the EP context into the Onnx model. (default).
static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed_mode";

// Gemm fastmath mode provides fp32 gemm acceleration using bfloat16-based matmul.
// Option values:
// - "0": Gemm FastMath mode is not enabled. [DEFAULT]
// - "1": Gemm FastMath mode is enabled.
static const char* const kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16 = "mlas.enable_gemm_fastmath_arm64_bfloat16";
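
For context, a caller would opt in through the standard session-option mechanism. A minimal sketch using the public C++ API (model path and logger name are placeholders, not from this PR):

#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "sbgemm-demo");
  Ort::SessionOptions so;
  // Opt in to bf16-based fp32 GEMM acceleration on supported aarch64 CPUs.
  so.AddConfigEntry("mlas.enable_gemm_fastmath_arm64_bfloat16", "1");
  Ort::Session session(env, "model.onnx", so);
  return 0;
}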
7 changes: 7 additions & 0 deletions onnxruntime/core/common/cpuid_info.cc
@@ -30,6 +30,10 @@
#define HWCAP2_SVEI8MM (1 << 9)
#endif

#ifndef HWCAP2_BF16
#define HWCAP2_BF16 (1 << 14)
#endif

#endif // ARM

#endif // Linux
@@ -148,6 +152,7 @@ void CPUIDInfo::ArmLinuxInit() {
has_fp16_ = cpuinfo_has_arm_neon_fp16_arith();
has_arm_neon_i8mm_ = cpuinfo_has_arm_i8mm();
has_arm_sve_i8mm_ = cpuinfo_has_arm_sve() && cpuinfo_has_arm_i8mm();
has_arm_neon_bf16_ = cpuinfo_has_arm_neon_bf16();

const uint32_t core_cnt = cpuinfo_get_cores_count();
core_uarchs_.resize(core_cnt, cpuinfo_uarch_unknown);
@@ -177,6 +182,7 @@ void CPUIDInfo::ArmLinuxInit() {
has_arm_neon_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_I8MM) != 0);
has_arm_sve_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_SVEI8MM) != 0);

has_arm_neon_bf16_ = ((getauxval(AT_HWCAP2) & HWCAP2_BF16) != 0);
#endif
}
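
Taken together with the HWCAP2_BF16 fallback define above, the new detection amounts to a small standalone probe. A sketch for Linux/aarch64 (helper name is illustrative):

#include <sys/auxv.h>

#ifndef HWCAP2_BF16
#define HWCAP2_BF16 (1 << 14)  // arm64 hwcap2 bit for BF16, matching the define above
#endif

static bool HasArmNeonBf16() {
  return (getauxval(AT_HWCAP2) & HWCAP2_BF16) != 0;
}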

@@ -278,6 +284,7 @@ void CPUIDInfo::ArmWindowsInit() {
/* TODO: implement them when hw+sw is available for testing these features */
has_arm_neon_i8mm_ = false;
has_arm_sve_i8mm_ = false;
has_arm_neon_bf16_ = false;
}

#endif /* (arm or arm64) and windows */
2 changes: 2 additions & 0 deletions onnxruntime/core/common/cpuid_info.h
@@ -30,6 +30,7 @@ class CPUIDInfo {
bool HasArmNeonDot() const { return has_arm_neon_dot_; }
bool HasArmNeon_I8MM() const { return has_arm_neon_i8mm_; }
bool HasArmSVE_I8MM() const { return has_arm_sve_i8mm_; }
bool HasArmNeon_BF16() const { return has_arm_neon_bf16_; }

uint32_t GetCurrentCoreIdx() const;

@@ -125,6 +126,7 @@ class CPUIDInfo {
bool has_fp16_{false};
bool has_arm_neon_i8mm_{false};
bool has_arm_sve_i8mm_{false};
bool has_arm_neon_bf16_{false};

#ifdef CPUIDINFO_ARCH_X86

113 changes: 113 additions & 0 deletions onnxruntime/core/mlas/inc/mlas.h
@@ -1614,6 +1614,119 @@
void* PackedB
);

#if defined(__aarch64__) && defined(__linux__)
/**
* @brief Whether the current CPU supports Bfloat16 (bf16) acceleration.
*/
bool MLASCALL
MlasBf16AccelerationSupported();

/**
* @brief Interface for bf16 gemm post processors.
*
* Example implementation of this interface includes activations,
* conversion from single precision to half precision, etc.
*
* SBGEMM is computed tile by tile. When a tile of result matrix
* is produced, the method Process() is called to process this tile.
* Parameters of this method describe the location and shape of the
* tile.
*/
class MLAS_SBGEMM_POSTPROCESSOR
{
public:
virtual void Process(float*, /**< the address of matrix to process */
size_t, /**< the start row index of matrix */
size_t, /**< the start col index of matrix */
size_t, /**< the element count per row to process */
size_t, /**< the element count per col to process */
size_t /**< the leading dimension of matrix */
) const = 0;

virtual ~MLAS_SBGEMM_POSTPROCESSOR() {}
};
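
To make the tile-by-tile contract concrete, here is a hypothetical post processor (illustrative, not part of this PR) that adds a per-column bias to each finished tile; the indexing assumes the float* argument points at the full result matrix, per the parameter docs above:

class BiasAddPostProcessor : public MLAS_SBGEMM_POSTPROCESSOR {
 public:
  explicit BiasAddPostProcessor(const float* bias) : bias_(bias) {}

  void Process(float* C, size_t StartM, size_t StartN,
               size_t CountM, size_t CountN, size_t ldc) const override {
    // Walk only the tile that was just produced and add the column bias.
    for (size_t m = 0; m < CountM; ++m) {
      for (size_t n = 0; n < CountN; ++n) {
        C[(StartM + m) * ldc + (StartN + n)] += bias_[StartN + n];
      }
    }
  }

 private:
  const float* bias_;
};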

/**
* @brief bfloat16 precision activation functions, with optional sum tensor.
* The supplied sum tensor must have the same layout as the GEMM output
* tensor and is added to the output before the activation is applied.
*/
class MLAS_SBGEMM_ACTIVATION_PROCESSOR : public MLAS_SBGEMM_POSTPROCESSOR
{
public:
MLAS_SBGEMM_ACTIVATION_PROCESSOR(const MLAS_ACTIVATION& Activation, const float* SumBuf = nullptr)
: Activation_(Activation), SumBuf_(SumBuf)
{
}

void Process(float* C, size_t StartM, size_t StartN, size_t CountM, size_t CountN, size_t ldc)
const override;

private:
const MLAS_ACTIVATION& Activation_;
const float* SumBuf_;
};

/**
* @brief Data parameters for bfloat16 precision GEMM routine
* All except C are [in] parameters
*/
struct MLAS_SBGEMM_DATA_PARAMS {
const void* A = nullptr; /**< address of A */
const void* B = nullptr; /**< address of B */
const float* Bias = nullptr; /**< address of Bias, vector size N */
float* C = nullptr; /**< address of result matrix */
size_t lda = 0; /**< leading dimension of A */
size_t ldb = 0; /**< leading dimension of B, 0 when B is pre-packed */
size_t ldc = 0; /**< leading dimension of C */
const MLAS_SBGEMM_POSTPROCESSOR* OutputProcessor = nullptr;
bool AIsfp32 = false; /**< matrix A is fp32, needs to be converted to bf16 */
bool BIsfp32 = false; /**< matrix B is fp32, needs to be converted to bf16 */
};

/**
* @brief Bfloat16 precision Batched GEMM: C = A * B + Bias
* B can be either fp32 or bf16
*
* Note: We only support uniform batching, so the shapes and types of the
* inputs must be the same across all parameter blocks.
*
* @param[in] M row size of matrix A and C
* @param[in] N column size of matrix B and C
* @param[in] K column size of matrix A and row size of matrix B
* @param[in] BatchN number of batches
* @param[inout] DataParams An array (size BatchN) of parameter blocks
* @param[in] ThreadPool
* @return
*/
void MLASCALL
MlasSBGemmBatch(const size_t M, const size_t N, const size_t K, const size_t BatchN, const MLAS_SBGEMM_DATA_PARAMS* DataParams, MLAS_THREADPOOL* ThreadPool = nullptr);
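
A minimal usage sketch (shapes and names are local to this example, not from the PR): one batch, both inputs fp32 and converted to bf16 inside the kernel:

#include <vector>
// "mlas.h" is assumed to be on the include path.

void RunSbgemmOnce(const std::vector<float>& A, const std::vector<float>& B,
                   std::vector<float>& C, size_t M, size_t N, size_t K) {
  MLAS_SBGEMM_DATA_PARAMS params;
  params.A = A.data();
  params.B = B.data();
  params.C = C.data();
  params.lda = K;  // row-major A: M x K
  params.ldb = N;  // row-major B: K x N
  params.ldc = N;  // row-major C: M x N
  params.AIsfp32 = true;  // let the kernel convert A from fp32 to bf16
  params.BIsfp32 = true;  // let the kernel convert B from fp32 to bf16
  if (MlasBf16AccelerationSupported()) {
    MlasSBGemmBatch(M, N, K, /*BatchN=*/1, &params, /*ThreadPool=*/nullptr);
  }
}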

/**
* @brief For bfloat16 precision GEMM, returns size of the
* packing buffer needed for right hand side
* @param[in] N Number of columns
* @param[in] K Number of rows
* @return size of the packing buffer,
* 0 if operation not supported
*/
size_t MLASCALL
MlasSBGemmPackBSize(size_t N, size_t K);

/**
* @brief For bfloat16 precision GEMM, convert the float matrix B
* to bfloat16 precision and pack it into a packing buffer
*
* @param[in] N Number of columns
* @param[in] K Number of rows
* @param[in] B Address of matrix B
* @param[in] ldb leading dimension of input matrix B
* @param[out] PackedB Address of the packed matrix
*/
void MLASCALL
MlasSBGemmConvertPackB(size_t N, size_t K, const float* B, size_t ldb, void* PackedB);
#endif
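
A companion sketch of the pre-packing path built from the two declarations above (buffer handling is illustrative; params is prepared as in the previous sketch): pack B once, then reuse it with ldb = 0 as documented in MLAS_SBGEMM_DATA_PARAMS:

#include <cstddef>
#include <vector>

void RunSbgemmPrePacked(const std::vector<float>& B, MLAS_SBGEMM_DATA_PARAMS params,
                        size_t M, size_t N, size_t K) {
  size_t packed_size = MlasSBGemmPackBSize(N, K);
  if (packed_size == 0) return;  // 0 means pre-packing is not supported
  std::vector<std::byte> packed(packed_size);
  MlasSBGemmConvertPackB(N, K, B.data(), /*ldb=*/N, packed.data());
  params.B = packed.data();
  params.ldb = 0;          // 0 signals that B is pre-packed
  params.BIsfp32 = false;  // the packed buffer already holds bf16
  MlasSBGemmBatch(M, N, K, /*BatchN=*/1, &params, /*ThreadPool=*/nullptr);
}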

/**
* @brief Indirect Depthwise convolution for fp16
* @param Input Supplies the indirect buffer for NHWC input