[aarch64] Add Sbgemm kernel to accelerate fp32 tensor matmul with bfloat16 (#17031)

### Description
This PR adds an SbgemmKernel for aarch64. It includes the Sbgemm kernel, which
implements matrix multiplication with bfloat16 SIMD instructions (bfmmla),
and MatMul operator changes to invoke the Sbgemm kernel. To enable the
Sbgemm kernel, set the session option
`kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16`
("mlas.enable_gemm_fastmath_arm64_bfloat16") to "1".

The PR also adds new test cases for mlas and ort.

### Motivation and Context

This is to improve MatMul performance on the aarch64 platform.
I ran the benchmarking script below (BERT, RoBERTa, and GPT-2 model
inference) on an AWS Graviton3 based c7g.4xl instance and observed a
1.2x-1.76x performance improvement over the sgemm (fp32) kernel.

```
cd onnxruntime/python/tools/transformers
python3 benchmark.py
```
The unit test precision results match the sgemm kernel results.
Build command used:
`./build.sh --config RelWithDebInfo --build_shared_lib --parallel --compile_no_warning_as_error --skip_submodule_sync`
snadampal authored and rachguo committed Jan 23, 2024
1 parent 6aa7f79 commit 041dfd5
Showing 17 changed files with 3,473 additions and 18 deletions.
4 changes: 4 additions & 0 deletions cmake/onnxruntime_mlas.cmake
@@ -355,19 +355,23 @@ else()
${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S
${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S
${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S
${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S
${MLAS_SRC_DIR}/activate_fp16.cpp
${MLAS_SRC_DIR}/dwconv.cpp
${MLAS_SRC_DIR}/halfgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/pooling_fp16.cpp
${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp
${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp
${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp
)
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
set_source_files_properties(${MLAS_SRC_DIR}/activate_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
endif()

if(ONNXRUNTIME_MLAS_MULTI_ARCH)
@@ -249,4 +249,10 @@ static const char* const kOrtSessionOptionEpContextFilePath = "ep.context_file_p
// Flag to specify whether to dump the EP context into the Onnx model.
// "0": dump the EP context into separate file, keep the file name in the Onnx model.
// "1": dump the EP context into the Onnx model. (default).
static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed_mode";

// Gemm fastmath mode provides fp32 gemm acceleration with bfloat16 based matmul.
// Option values:
// - "0": Gemm FastMath mode is not enabled. [DEFAULT]
// - "1": Gemm FastMath mode is enabled.
static const char* const kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16 = "mlas.enable_gemm_fastmath_arm64_bfloat16";
7 changes: 7 additions & 0 deletions onnxruntime/core/common/cpuid_info.cc
@@ -30,6 +30,10 @@
#define HWCAP2_SVEI8MM (1 << 9)
#endif

#ifndef HWCAP2_BF16
#define HWCAP2_BF16 (1 << 14)
#endif

#endif // ARM

#endif // Linux
@@ -148,6 +152,7 @@ void CPUIDInfo::ArmLinuxInit() {
has_fp16_ = cpuinfo_has_arm_neon_fp16_arith();
has_arm_neon_i8mm_ = cpuinfo_has_arm_i8mm();
has_arm_sve_i8mm_ = cpuinfo_has_arm_sve() && cpuinfo_has_arm_i8mm();
has_arm_neon_bf16_ = cpuinfo_has_arm_neon_bf16();

const uint32_t core_cnt = cpuinfo_get_cores_count();
core_uarchs_.resize(core_cnt, cpuinfo_uarch_unknown);
@@ -177,6 +182,7 @@ void CPUIDInfo::ArmLinuxInit() {
has_arm_neon_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_I8MM) != 0);
has_arm_sve_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_SVEI8MM) != 0);

has_arm_neon_bf16_ = ((getauxval(AT_HWCAP2) & HWCAP2_BF16) != 0);
#endif
}

@@ -278,6 +284,7 @@ void CPUIDInfo::ArmWindowsInit() {
/* TODO: implement them when hw+sw is available for testing these features */
has_arm_neon_i8mm_ = false;
has_arm_sve_i8mm_ = false;
has_arm_neon_bf16_ = false;
}

#endif /* (arm or arm64) and windows */
2 changes: 2 additions & 0 deletions onnxruntime/core/common/cpuid_info.h
@@ -30,6 +30,7 @@ class CPUIDInfo {
bool HasArmNeonDot() const { return has_arm_neon_dot_; }
bool HasArmNeon_I8MM() const { return has_arm_neon_i8mm_; }
bool HasArmSVE_I8MM() const { return has_arm_sve_i8mm_; }
bool HasArmNeon_BF16() const { return has_arm_neon_bf16_; }

uint32_t GetCurrentCoreIdx() const;

@@ -125,6 +126,7 @@ class CPUIDInfo {
bool has_fp16_{false};
bool has_arm_neon_i8mm_{false};
bool has_arm_sve_i8mm_{false};
bool has_arm_neon_bf16_{false};

#ifdef CPUIDINFO_ARCH_X86

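As a rough illustration (a hypothetical call site, not code from this PR), a dispatch path could consult the new capability flag like this:

```cpp
#include "core/common/cpuid_info.h"

// Hypothetical helper: take the bf16 fastmath GEMM path only when the CPU
// reports BF16 support; otherwise fall back to the regular fp32 sgemm.
bool UseBf16FastMathGemm() {
  const auto& cpu_info = onnxruntime::CPUIDInfo::GetCPUIDInfo();
  return cpu_info.HasArmNeon_BF16();
}
```
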
113 changes: 113 additions & 0 deletions onnxruntime/core/mlas/inc/mlas.h
@@ -1614,6 +1614,119 @@ MlasHalfGemmConvertPackB(
void* PackedB
);

#if defined(__aarch64__) && defined(__linux__)
/**
* @brief Whether the current CPU supports bfloat16 (bf16) acceleration.
*/
bool MLASCALL
MlasBf16AccelerationSupported();

/**
* @brief Interface for bf16 gemm post processors.
*
* Example implementations of this interface include activations,
* conversion from single precision to other precisions, etc.
*
* SBGEMM is computed tile by tile. When a tile of result matrix
* is produced, the method Process() is called to process this tile.
* Parameters of this method describe the location and shape of the
* tile.
*/
class MLAS_SBGEMM_POSTPROCESSOR
{
public:
virtual void Process(float*, /**< the address of matrix to process */
size_t, /**< the start row index of matrix */
size_t, /**< the start col index of matrix */
size_t, /**< the element count per row to process */
size_t, /**< the element count per col to process */
size_t /**< the leading dimension of matrix */
) const = 0;

virtual ~MLAS_SBGEMM_POSTPROCESSOR() {}
};

/**
* @brief bfloat16 precision activation functions, with an optional sum tensor.
* The supplied sum tensor must have the same layout as the GEMM output tensor,
* and it is added to the output before the activation is applied.
*/
class MLAS_SBGEMM_ACTIVATION_PROCESSOR : public MLAS_SBGEMM_POSTPROCESSOR
{
public:
MLAS_SBGEMM_ACTIVATION_PROCESSOR(const MLAS_ACTIVATION& Activation, const float* SumBuf = nullptr)
: Activation_(Activation), SumBuf_(SumBuf)
{
}

void Process(float* C, size_t StartM, size_t StartN, size_t CountM, size_t CountN, size_t ldc)
const override;

private:
const MLAS_ACTIVATION& Activation_;
const float* SumBuf_;
};

/**
* @brief Data parameters for bfloat16 precision GEMM routine
* All except C are [in] parameters
*/
struct MLAS_SBGEMM_DATA_PARAMS {
const void* A = nullptr; /**< address of A */
const void* B = nullptr; /**< address of B */
const float* Bias = nullptr; /**< address of Bias, vector size N */
float* C = nullptr; /**< address of result matrix */
size_t lda = 0; /**< leading dimension of A */
size_t ldb = 0; /**< leading dimension of B, 0 when B is pre-packed*/
size_t ldc = 0; /**< leading dimension of C*/
const MLAS_SBGEMM_POSTPROCESSOR* OutputProcessor = nullptr;
bool AIsfp32 = false; /**< matrix A is fp32, needs to be converted to bf16*/
bool BIsfp32 = false; /**< matrix B is fp32, needs to be converted to bf16*/
};

/**
* @brief Bfloat16 precision Batched GEMM: C = A * B + Bias
* A and B can each be either fp32 or bf16
*
* Note: Only uniform batching is supported, so the shapes and types of the
* inputs must be the same across all parameter blocks.
*
* @param[in] M row size of matrix A and C
* @param[in] N column size of matrix B and C
* @param[in] K column size of matrix A and row size of matrix B
* @param[in] BatchN number of batches
* @param[inout] DataParams An array (size BatchN) of parameter blocks
* @param[in] ThreadPool
* @return
*/
void MLASCALL
MlasSBGemmBatch(const size_t M, const size_t N, const size_t K, const size_t BatchN, const MLAS_SBGEMM_DATA_PARAMS* DataParams, MLAS_THREADPOOL* ThreadPool = nullptr);

/**
* @brief For bfloat16 precision GEMM, returns the size of the
* packing buffer needed for the right-hand side
* @param[in] N Number of columns
* @param[in] K Number of rows
* @return size of the packing buffer,
* 0 if operation not supported
*/
size_t MLASCALL
MlasSBGemmPackBSize(size_t N, size_t K);

/**
* @brief For bfloat16 precision GEMM, convert the float matrix B
* to bfloat16 precision and pack it into a packing buffer
*
* @param[in] N Number of columns
* @param[in] K Number of rows
* @param[in] B Address of matrix B
* @param[in] ldb leading dimension of input matrix B
* @param[out] PackedB Address of the packed matrix
*/
void MLASCALL
MlasSBGemmConvertPackB(size_t N, size_t K, const float* B, size_t ldb, void* PackedB);
#endif
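
A rough usage sketch based only on the declarations above (the buffer names,
leading dimensions, and single-batch setup are illustrative; the actual MatMul
integration in this PR may differ):

```cpp
#include <vector>
#include "mlas.h"

// Compute C[MxN] = A[MxK] * B[KxN] via the bf16 fastmath path, converting and
// packing B once so the packed copy can be reused across calls.
void SbgemmExample(const float* A, const float* B, float* C,
                   size_t M, size_t N, size_t K, MLAS_THREADPOOL* pool) {
  const size_t packed_size = MlasSBGemmPackBSize(N, K);
  if (packed_size == 0) return;  // packing not supported on this build/CPU

  std::vector<uint8_t> packed_b(packed_size);
  MlasSBGemmConvertPackB(N, K, B, /*ldb=*/N, packed_b.data());

  MLAS_SBGEMM_DATA_PARAMS params;
  params.A = A;
  params.AIsfp32 = true;       // A is fp32 and is converted to bf16 internally
  params.B = packed_b.data();
  params.BIsfp32 = false;      // B was already converted/packed to bf16 above
  params.lda = K;
  params.ldb = 0;              // 0 signals that B is pre-packed
  params.C = C;
  params.ldc = N;

  MlasSBGemmBatch(M, N, K, /*BatchN=*/1, &params, pool);
}
```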

/**
* @brief Indirect Depthwise convolution for fp16
* @param Input Supplies the indirect buffer for NHWC input
