From 77da2ef278a4e77cca4cef4e5d72ed1ef46fcce3 Mon Sep 17 00:00:00 2001 From: snadampal <87143774+snadampal@users.noreply.github.com> Date: Mon, 22 Jan 2024 16:43:06 -0600 Subject: [PATCH] [aarch64] Add Sbgemm kernel to accelerate fp32 tensor matmul with bfloat16 (#17031) ### Description This PR adds SbgemmKernel for aarch64. This includes Sbegmm kernel to implement matrix multiplication with bfloat16 SIMD instructions (bfmmla) and MatMul operator changes to invoke the Sbgemm kernel. To enable Sbgemm kernel, set the following session option: "kOrtSessionOptionsGemmFastMathMode" The PR also adds new test cases for mlas and ort. ### Motivation and Context This is to improve MatMul performance on aarch64 platform. I have run the below benchmarking script (bert , roberta and gpt2 model inference) on AWS Graviton3 based c7g.4xl instance and observed 1.2x -1.76x performance improvement compared to sgemm (fp32) kernel performance. ``` cd onnxruntime/python/tools/transformers python3 benchmark.py ``` And the unit test precision results are matching to sgemm kernel results. `./build.sh --config RelWithDebInfo --build_shared_lib --parallel --compile_no_warning_as_error --skip_submodule_sync ` --- cmake/onnxruntime_mlas.cmake | 4 + .../onnxruntime_session_options_config_keys.h | 8 +- onnxruntime/core/common/cpuid_info.cc | 7 + onnxruntime/core/common/cpuid_info.h | 2 + onnxruntime/core/mlas/inc/mlas.h | 113 +++ .../core/mlas/lib/aarch64/SbgemmKernelNeon.S | 907 ++++++++++++++++++ onnxruntime/core/mlas/lib/mlasi.h | 25 + onnxruntime/core/mlas/lib/platform.cpp | 6 + onnxruntime/core/mlas/lib/sbgemm.h | 399 ++++++++ .../core/mlas/lib/sbgemm_kernel_neon.cpp | 362 +++++++ onnxruntime/core/providers/cpu/math/matmul.cc | 106 +- onnxruntime/core/providers/cpu/math/matmul.h | 15 + .../test/mlas/unittest/test_sbgemm.cpp | 141 +++ onnxruntime/test/mlas/unittest/test_sbgemm.h | 281 ++++++ .../qdq_transformer_fastmath_test.cc | 730 ++++++++++++++ .../cpu/math/matmul_fastmath_test.cc | 305 ++++++ onnxruntime/test/util/compare_ortvalue.cc | 80 ++ 17 files changed, 3473 insertions(+), 18 deletions(-) create mode 100644 onnxruntime/core/mlas/lib/aarch64/SbgemmKernelNeon.S create mode 100644 onnxruntime/core/mlas/lib/sbgemm.h create mode 100644 onnxruntime/core/mlas/lib/sbgemm_kernel_neon.cpp create mode 100644 onnxruntime/test/mlas/unittest/test_sbgemm.cpp create mode 100644 onnxruntime/test/mlas/unittest/test_sbgemm.h create mode 100644 onnxruntime/test/optimizer/qdq_transformer_fastmath_test.cc create mode 100644 onnxruntime/test/providers/cpu/math/matmul_fastmath_test.cc diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index f89d2150a6830..17de2aa4aaea6 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -355,19 +355,23 @@ else() ${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S ${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S + ${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S ${MLAS_SRC_DIR}/activate_fp16.cpp ${MLAS_SRC_DIR}/dwconv.cpp ${MLAS_SRC_DIR}/halfgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/pooling_fp16.cpp ${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp ${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp + ${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp ) set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") set_source_files_properties(${MLAS_SRC_DIR}/activate_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") endif() if(ONNXRUNTIME_MLAS_MULTI_ARCH) diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index 8fd51962bf087..b282438795eb5 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -249,4 +249,10 @@ static const char* const kOrtSessionOptionEpContextFilePath = "ep.context_file_p // Flag to specify whether to dump the EP context into the Onnx model. // "0": dump the EP context into separate file, keep the file name in the Onnx model. // "1": dump the EP context into the Onnx model. (default). -static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed_mode"; \ No newline at end of file +static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed_mode"; + +// Gemm fastmath mode provides fp32 gemm acceleration with bfloat16 based matmul. +// Option values: +// - "0": Gemm FastMath mode is not enabled. [DEFAULT] +// - "1": Gemm FastMath mode is enabled. +static const char* const kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16 = "mlas.enable_gemm_fastmath_arm64_bfloat16"; diff --git a/onnxruntime/core/common/cpuid_info.cc b/onnxruntime/core/common/cpuid_info.cc index fcf9c2b03dea5..711fd595e90fd 100644 --- a/onnxruntime/core/common/cpuid_info.cc +++ b/onnxruntime/core/common/cpuid_info.cc @@ -30,6 +30,10 @@ #define HWCAP2_SVEI8MM (1 << 9) #endif +#ifndef HWCAP2_BF16 +#define HWCAP2_BF16 (1 << 14) +#endif + #endif // ARM #endif // Linux @@ -148,6 +152,7 @@ void CPUIDInfo::ArmLinuxInit() { has_fp16_ = cpuinfo_has_arm_neon_fp16_arith(); has_arm_neon_i8mm_ = cpuinfo_has_arm_i8mm(); has_arm_sve_i8mm_ = cpuinfo_has_arm_sve() && cpuinfo_has_arm_i8mm(); + has_arm_neon_bf16_ = cpuinfo_has_arm_neon_bf16(); const uint32_t core_cnt = cpuinfo_get_cores_count(); core_uarchs_.resize(core_cnt, cpuinfo_uarch_unknown); @@ -177,6 +182,7 @@ void CPUIDInfo::ArmLinuxInit() { has_arm_neon_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_I8MM) != 0); has_arm_sve_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_SVEI8MM) != 0); + has_arm_neon_bf16_ = ((getauxval(AT_HWCAP2) & HWCAP2_BF16) != 0); #endif } @@ -278,6 +284,7 @@ void CPUIDInfo::ArmWindowsInit() { /* TODO: implement them when hw+sw is available for testing these features */ has_arm_neon_i8mm_ = false; has_arm_sve_i8mm_ = false; + has_arm_neon_bf16_ = false; } #endif /* (arm or arm64) and windows */ diff --git a/onnxruntime/core/common/cpuid_info.h b/onnxruntime/core/common/cpuid_info.h index a15c75104b83a..2f8041e39f680 100644 --- a/onnxruntime/core/common/cpuid_info.h +++ b/onnxruntime/core/common/cpuid_info.h @@ -30,6 +30,7 @@ class CPUIDInfo { bool HasArmNeonDot() const { return has_arm_neon_dot_; } bool HasArmNeon_I8MM() const { return has_arm_neon_i8mm_; } bool HasArmSVE_I8MM() const { return has_arm_sve_i8mm_; } + bool HasArmNeon_BF16() const { return has_arm_neon_bf16_; } uint32_t GetCurrentCoreIdx() const; @@ -125,6 +126,7 @@ class CPUIDInfo { bool has_fp16_{false}; bool has_arm_neon_i8mm_{false}; bool has_arm_sve_i8mm_{false}; + bool has_arm_neon_bf16_{false}; #ifdef CPUIDINFO_ARCH_X86 diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index bdd4dba521eba..ce7838556fbf0 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -1614,6 +1614,119 @@ MlasHalfGemmConvertPackB( void* PackedB ); +#if defined(__aarch64__) && defined(__linux__) +/** + * @brief Whether current CPU supports Bfloat16(bf16) acceleration. + */ +bool MLASCALL +MlasBf16AccelerationSupported(); + +/** + * @brief Interface for bf16 gemm post processors. + * + * Example implementation of this interface includes activations, + * conversion from single precision to precision, etc. + * + * SBGEMM is computed tile by tile. When a tile of result matrix + * is produced, the method Process() is called to process this tile. + * Parameters of this method describe the location and shape of the + * tile. + */ +class MLAS_SBGEMM_POSTPROCESSOR +{ + public: + virtual void Process(float*, /**< the address of matrix to process */ + size_t, /**< the start row index of matrix */ + size_t, /**< the start col index of matrix */ + size_t, /**< the element count per row to process */ + size_t, /**< the element count per col to process */ + size_t /**< the leading dimension of matrix */ + ) const = 0; + + virtual ~MLAS_SBGEMM_POSTPROCESSOR() {} +}; + +/** + * @brief bfloat16 precision activation functions, with optional sum tensor. + * Supplied sum tensor must be the same layout as the GEMM output tensor. + * And the supplied sum tensor will be added to the tensor before activation. + */ +class MLAS_SBGEMM_ACTIVATION_PROCESSOR : public MLAS_SBGEMM_POSTPROCESSOR +{ + public: + MLAS_SBGEMM_ACTIVATION_PROCESSOR(const MLAS_ACTIVATION& Activation, const float* SumBuf = nullptr) + : Activation_(Activation), SumBuf_(SumBuf) + { + } + + void Process(float* C, size_t StartM, size_t StartN, size_t CountM, size_t CountN, size_t ldc) + const override; + + private: + const MLAS_ACTIVATION& Activation_; + const float* SumBuf_; +}; + +/** + * @brief Data parameters for bfloat16 precision GEMM routine + * All except C are [in] parameters + */ +struct MLAS_SBGEMM_DATA_PARAMS { + const void* A = nullptr; /**< address of A */ + const void* B = nullptr; /**< address of B */ + const float* Bias = nullptr; /**< address of Bias, vector size N */ + float* C = nullptr; /**< address of result matrix */ + size_t lda = 0; /**< leading dimension of A */ + size_t ldb = 0; /**< leading dimension of B, 0 when B is pre-packed*/ + size_t ldc = 0; /**< leading dimension of C*/ + const MLAS_SBGEMM_POSTPROCESSOR* OutputProcessor = nullptr; + bool AIsfp32 = false; /**< matrix A is fp32, needs to be converted to bf16*/ + bool BIsfp32 = false; /**< matrix B is fp32, needs to be converted to bf16*/ +}; + +/** + * @brief Bfloat16 precision Batched GEMM: C = A * B + Bias + * Either B can be either fp32 or bf16 + * + * Note: We only support uniform batching, so shapes and types of the + * input must be same across all parameter blocks. + * + * @param[in] M row size of matrix A and C + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BatchN number of batches + * @param[inout] DataParams An array (size BatchN) of parameter blocks + * @param[in] ThreadPool + * @return + */ +void MLASCALL +MlasSBGemmBatch(const size_t M, const size_t N, const size_t K, const size_t BatchN, const MLAS_SBGEMM_DATA_PARAMS* DataParams, MLAS_THREADPOOL* ThreadPool = nullptr); + +/** + * @brief For bfloat16 precision GEMM, returns size of the + * packing buffer needed for right hand side + * @param[in] N Number of columns + * @param[in] K Number of rows + * @return size of the packing buffer, + * 0 if operation not supported + */ +size_t MLASCALL +MlasSBGemmPackBSize(size_t N, size_t K); + +/** + * @brief For bfloat16 precision GEMM, convert the float matrix B + * to blfoat16 precision and pack it into a packing buffer + * + * @param[in] N Number of columns + * @param[in] K Number of rows + * @param[in] B Address of matrix B + * @param[in] ldb leading dimension of input matrix B + * @param[out] PackedB Address of the packed matrix + */ +void MLASCALL +MlasSBGemmConvertPackB(size_t N, size_t K, const float* B, size_t ldb, void* PackedB); +#endif + /** * @brief Indirect Depthwise convolution for fp16 * @param Input Supplies the indirect buffer for NHWC input diff --git a/onnxruntime/core/mlas/lib/aarch64/SbgemmKernelNeon.S b/onnxruntime/core/mlas/lib/aarch64/SbgemmKernelNeon.S new file mode 100644 index 0000000000000..e424c30515e9f --- /dev/null +++ b/onnxruntime/core/mlas/lib/aarch64/SbgemmKernelNeon.S @@ -0,0 +1,907 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the MIT License. + +Module Name: + + SbgemmKernelNeon.s + +Abstract: + + This module implements the kernels for the bfloat16 half precision matrix/matrix + multiply operation (SBGEMM). + +--*/ + +#include "asmmacro.h" + + .text + +// +// Stack frame layout for the sbgemm kernel. d8-d15, x19-x30 need save +// + .equ .LMlasSbgemmKernel_backup_x19_x20, 0 + .equ .LMlasSbgemmKernel_backup_x21_x22, 16 + .equ .LMlasSbgemmKernel_backup_x23_x24, 32 + .equ .LMlasSbgemmKernel_backup_x25_x26, 48 + .equ .LMlasSbgemmKernel_backup_x27_x28, 64 + .equ .LMlasSbgemmKernel_backup_d8_d9, 80 + .equ .LMlasSbgemmKernel_backup_d10_d11, 96 + .equ .LMlasSbgemmKernel_backup_d12_d13, 112 + .equ .LMlasSbgemmKernel_backup_d14_d15, 128 + .equ .LMlasSbgemmKernel_SavedRegisters, 144 + .equ .LMlasSbgemmKernel_SavedRegisters_Neg, -144 + + +// +// ClearRowAccumulators +// +// Generates the code to clear the accumulators for a single row of the output +// block. +// + + .macro InitRowAccumulators Columns, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg + + mov v\Vec1Reg\().16b,v0.16b +.if \Columns\() > 2 + mov v\Vec2Reg\().16b,v1.16b +.endif +.if \Columns\() > 4 + mov v\Vec3Reg\().16b,v2.16b +.endif +.if \Columns\() > 6 + mov v\Vec4Reg\().16b,v3.16b +.endif + + .endm + +// +// InitBlockAccumulators +// +// Generates the code to init the accumulators for a single row of the output +// block. +// + + .macro InitBlockAccumulators Mode, Columns, Rows + + //check if the Bias != nullptr + cbz x8,.L\Mode\().InitBlock\Columns\().x\Rows\().SkipBiasAdd + + ld1 {v14.4s},[x8],#16 // load Bias[0] + // v4~v7 will be set to matrixB after this, so, they can used now + dup v4.4s,v14.s[0] // broadcast Bias + dup v5.4s,v14.s[1] + dup v6.4s,v14.s[2] + dup v7.4s,v14.s[3] + + zip1 v0.4s, v4.4s, v5.4s + zip2 v1.4s, v6.4s, v7.4s +.if \Columns\() > 4 + ld1 {v15.4s},[x8],#16 // load Bias[4] + dup v4.4s,v15.s[0] // broadcast Bias + dup v5.4s,v15.s[1] + dup v6.4s,v15.s[2] + dup v7.4s,v15.s[3] + + zip1 v2.4s, v4.4s, v5.4s + zip2 v3.4s, v6.4s, v7.4s +.endif + + b .L\Mode\().PopulateAccumulators\Columns\().x\Rows\() + +.L\Mode\().InitBlock\Columns\().x\Rows\().SkipBiasAdd: + eor v0.16b,v0.16b,v0.16b // No bias, reset regs + eor v1.16b,v1.16b,v1.16b + eor v2.16b,v2.16b,v2.16b + eor v3.16b,v3.16b,v3.16b + +.L\Mode\().PopulateAccumulators\Columns\().x\Rows\(): + InitRowAccumulators \Columns\(),16,17,18,19 +.if \Rows\() > 2 + InitRowAccumulators \Columns\(),20,21,22,23 +.endif +.if \Rows\() > 4 + InitRowAccumulators \Columns\(),24,25,26,27 +.endif +.if \Rows\() > 6 + InitRowAccumulators \Columns\(),28,29,30,31 +.endif + + .endm + +// LoadMatrixAElementsBy8 +// +// Generates the code to load 4 or 8 elements from matrix A. +// + .macro LoadMatrixAElementsBy8 Rows + + ldr q8,[x0],#16 + bfcvtn v8.4h, v8.4s +.if \Rows\() > 1 + ldr q1,[x10],#16 + bfcvtn2 v8.8h, v1.4s +.endif + +.if \Rows\() > 2 + ldr q9,[x11],#16 + bfcvtn v9.4h, v9.4s +.endif +.if \Rows\() > 3 + ldr q1,[x12],#16 + bfcvtn2 v9.8h, v1.4s +.endif + +.if \Rows\() > 4 + ldr q10,[x20],#16 + bfcvtn v10.4h, v10.4s +.endif +.if \Rows\() > 5 + ldr q1,[x21],#16 + bfcvtn2 v10.8h, v1.4s +.endif + +.if \Rows\() > 6 + ldr q11,[x22],#16 + bfcvtn v11.4h, v11.4s +.endif +.if \Rows\() > 7 + ldr q1,[x23],#16 + bfcvtn2 v11.8h, v1.4s +.endif + + .endm + + +// +// MultiplyAccumulateRow +// +// Generates the code to multiply and accumulate a single row of the output +// block. +// + + .macro MultiplyAccumulateRow Columns, MatrixAReg, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg + + bfmmla v\Vec1Reg\().4s, \MatrixAReg\().8h, v4.8h +.if \Columns\() > 2 + bfmmla v\Vec2Reg\().4s, \MatrixAReg\().8h, v5.8h +.endif +.if \Columns\() > 4 + bfmmla v\Vec3Reg\().4s, \MatrixAReg\().8h, v6.8h +.endif +.if \Columns\() > 6 + bfmmla v\Vec4Reg\().4s, \MatrixAReg\().8h, v7.8h +.endif + + .endm + +// +// MultiplyAccumulateBlock +// +// Generates the code to multiply and accumulate into the output block. +// + + .macro MultiplyAccumulateBlock Columns, Rows + + MultiplyAccumulateRow \Columns\(),v8,16,17,18,19 +.if \Rows\() > 2 + MultiplyAccumulateRow \Columns\(),v9,20,21,22,23 +.endif +.if \Rows\() > 4 + MultiplyAccumulateRow \Columns\(),v10,24,25,26,27 +.endif +.if \Rows\() > 6 + MultiplyAccumulateRow \Columns\(),v11,28,29,30,31 +.endif + + .endm + +// +// ComputeBlockLoop +// +// Generates the code to loop over K entries of the input matrices to produce +// the output block. +// + + .macro ComputeBlockLoop Mode, Columns, Rows + + InitBlockAccumulators \Mode\(),\Columns\(),\Rows\() + + add x10,x0,x6,lsl #2 // compute matrix A plus 1 row +.if \Rows\() > 2 + add x11,x10,x6,lsl #2 // compute matrix A plus 2 rows + add x12,x11,x6,lsl #2 // compute matrix A plus 3 rows +.endif +.if \Rows\() > 4 + add x20,x12,x6,lsl #2 // compute matrix A plus 4 rows + add x21,x20,x6,lsl #2 // compute matrix A plus 5 rows +.endif +.if \Rows\() > 6 + add x22,x21,x6,lsl #2 // compute matrix A plus 6 rows + add x23,x22,x6,lsl #2 // compute matrix A plus 7 rows +.endif + sub x9,x3,#4 // block count to process + tbnz x9,#63,.L\Mode\().ProcessRemaining\Columns\().x\Rows\().Blocks + +.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4Loop: + + LoadMatrixAElementsBy8 \Rows\() + ldr q4, [x1],#16 +.if \Columns\() > 2 + ldr q5,[x1],#16 +.endif +.if \Columns\() > 4 + ldr q6,[x1],#16 +.endif +.if \Columns\() > 6 + ldr q7,[x1],#16 +.endif + MultiplyAccumulateBlock \Columns\(),\Rows\() + + sub x9,x9,#4 + tbz x9,#63,.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4Loop +.L\Mode\().ProcessRemaining\Columns\().x\Rows\().Blocks: + add x9,x9,#4 // correct for over-subtract above + cbz x9,.L\Mode\().Output\Columns\().x\Rows\().Block + +.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4PaddedLoop: + LoadMatrixAElementsBy8 \Rows\() + ldr q4, [x1],#16 +.if \Columns\() > 2 + ldr q5,[x1],#16 +.endif +.if \Columns\() > 4 + ldr q6,[x1],#16 +.endif +.if \Columns\() > 6 + ldr q7,[x1],#16 +.endif + MultiplyAccumulateBlock \Columns\(),\Rows\() + +.L\Mode\().Output\Columns\().x\Rows\().Block: + + .endm + + +// +// OutputRow2Element +// OutputRow4Element +// OutputRow6Element +// OutputRow8Element +// OutputRow10Element +// OutputRow12Element +// OutputRow14Element +// OutputRow16Element +// +// Generates the code to store elements to the output block. +// + + .macro OutputRow2Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr s8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr s9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + mov v8.S[2], v9.S[0] + + fadd v8.4s,v8.4s,v\Vec1Reg\().4s + + mov w27, v8.S[0] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov w27, v8.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + mov w27, v\Vec1Reg\().S[0] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov w27, v\Vec1Reg\().S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.endif + + .endm + + + .macro OutputRow4Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr d8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr d9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + + mov v8.D[1], v9.D[0] + + fadd v8.4s,v8.4s,v\Vec1Reg\().4s + + mov x27, v8.D[0] + mov x28, v8.D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.else + mov x27, v\Vec1Reg\().D[0] + mov x28, v\Vec1Reg\().D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.endif + + .endm + + + .macro OutputRow6Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr d8,[\AddrReg1\()],#8 + ldr w28,[\AddrReg1\()],#-8 + mov v8.S[2], w28 +.if \last_row\() == 0 + ldr d9,[\AddrReg2\()],#8 + ldr w27,[\AddrReg2\()],#-8 + mov v9.S[2], w27 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + fadd v8.4s,v8.4s,v4.4s + fadd v9.4s,v9.4s,v5.4s + + mov x27, v8.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v8.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov x27, v9.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v9.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + mov x27, v4.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v4.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov x27, v5.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v5.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.endif + + .endm + + + .macro OutputRow8Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + fadd v8.4s,v8.4s,v4.4s + fadd v9.4s,v9.4s,v5.4s + + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + +.endif + + .endm + + + .macro OutputRow10Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr w28, [\AddrReg1\()],#-16 + +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr w27,[\AddrReg2\()],#-16 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + fadd v8.4s,v8.4s,v4.4s + fadd v9.4s,v9.4s,v5.4s + + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + mov v8.S[0], w28 + mov v8.S[2], w27 + + fadd v8.4s,v8.4s,v\Vec3Reg\().4s + + mov w27, v8.S[0] + mov w28, v8.S[2] + + str w27, [\AddrReg1\()],#4 +.if \last_row\() == 0 + str w28, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + mov w27, v\Vec3Reg\().S[0] + mov w28, v\Vec3Reg\().S[2] + + str w27, [\AddrReg1\()],#4 +.if \last_row\() == 0 + str w28, [\AddrReg2\()],#4 +.endif +.endif + +.endm + + + .macro OutputRow12Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr d10,[\AddrReg1\()],#-16 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr d11,[\AddrReg2\()],#-16 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 + mov v11.D[0],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + fadd v8.4s,v8.4s,v4.4s + fadd v9.4s,v9.4s,v5.4s + + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + + mov v10.D[1], v11.D[0] + + fadd v10.4s,v10.4s,v\Vec3Reg\().4s + + mov x27, v10.D[0] + mov x28, v10.D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + mov x27, v\Vec3Reg\().D[0] + mov x28, v\Vec3Reg\().D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif +.endif + + .endm + + .macro OutputRow14Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr d10,[\AddrReg1\()],#8 + ldr w28, [\AddrReg1\()],#-24 + mov v10.S[2], w28 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr d11,[\AddrReg2\()],#8 + ldr w27,[\AddrReg2\()],#-24 + mov v11.S[2], w27 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 + + mov v11.D[0],x27 + mov v11.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + fadd v8.4s,v8.4s,v4.4s + fadd v9.4s,v9.4s,v5.4s + fadd v10.4s,v10.4s,v6.4s + fadd v11.4s,v11.4s,v7.4s + + str q8,[\AddrReg1\()],#16 + + mov x27, v10.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v10.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 + mov x27, v11.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v11.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + str q4,[\AddrReg1\()],#16 + mov x27, v6.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v6.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 + mov x27, v7.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v7.S[2] + str w27, [\AddrReg2\()],#4 +.endif +.endif + + .endm + + + .macro OutputRow16Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldp q8,q10,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldp q9,q11,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 + + mov v11.D[0],x27 + mov v11.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + fadd v8.4s,v8.4s,v4.4s + fadd v9.4s,v9.4s,v5.4s + fadd v10.4s,v10.4s,v6.4s + fadd v11.4s,v11.4s,v7.4s + + stp q8,q10,[\AddrReg1\()],#32 +.if \last_row\() == 0 + stp q9,q11,[\AddrReg2\()],#32 +.endif +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + stp q4,q6,[\AddrReg1\()],#32 +.if \last_row\() == 0 + stp q5,q7,[\AddrReg2\()],#32 +.endif +.endif + + .endm + +// +// OutputBlock +// +// Generates the code to store the output block. +// + + .macro OutputBlock Mode, Columns, Rows + + OutputRow\Columns\()Element \Mode\(),x2,x13,16,17,18,19,(\Rows\() == 1) + +.if \Rows\() > 2 + OutputRow\Columns\()Element \Mode\(),x14,x15,20,21,22,23,(\Rows\() == 3) +.endif + +.if \Rows\() > 4 + OutputRow\Columns\()Element \Mode\(),x16,x17,24,25,26,27,(\Rows\() == 5) +.endif + +.if \Rows\() > 6 + OutputRow\Columns\()Element \Mode\(),x18,x19,28,29,30,31,(\Rows\() == 7) +.endif + + .endm +// +// ProcessRows +// +// Generates the code to process a compute and store the output block for a +// fixed number of rows. +// + + .macro ProcessRows Mode, Rows + mov x4,#\Rows\() // return number of rows handled + cmp x5,#6 + ble .L\Mode\().ProcessNextColumnLoop6x\Rows\() + +.L\Mode\().ProcessNextColumnLoop8x\Rows\(): + ComputeBlockLoop \Mode\(),8,\Rows\() + + sub x5,x5,#8 + cmp x5,#0 + blt .L\Mode\().Output14ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),16,\Rows\() + mov x0,x26 // reload matrix A + cmp x5,#6 + bgt .L\Mode\().ProcessNextColumnLoop8x\Rows\() + cbz x5,.L\Mode\().ExitKernel + + +.L\Mode\().ProcessNextColumnLoop6x\Rows\(): + + cmp x5,#4 + ble .L\Mode\().ProcessNextColumnLoop4x\Rows\() + ComputeBlockLoop \Mode\(),6,\Rows\() + sub x5,x5,#6 + cmp x5,#0 + blt .L\Mode\().Output10ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),12,\Rows\() + + mov x0,x26 // reload matrix A + cmp x5,#4 + bgt .L\Mode\().ProcessNextColumnLoop6x\Rows\() + b .L\Mode\().ExitKernel + +.L\Mode\().ProcessNextColumnLoop4x\Rows\(): + cmp x5,#2 + ble .L\Mode\().ProcessNextColumnLoop2x\Rows\() + ComputeBlockLoop \Mode\(),4,\Rows\() + sub x5,x5,#4 + cmp x5,#0 + blt .L\Mode\().Output6ElementsOnlyFor\Rows\() + + OutputBlock \Mode\(),8,\Rows\() + + mov x0,x26 // reload matrix A + cmp x5,#2 + bgt .L\Mode\().ProcessNextColumnLoop4x\Rows\() + b .L\Mode\().ExitKernel + +.L\Mode\().ProcessNextColumnLoop2x\Rows\(): + ComputeBlockLoop \Mode\(),2,\Rows\() + sub x5,x5,#2 + cmp x5,#0 + blt .L\Mode\().Output2ElementsOnlyFor\Rows\() + + OutputBlock \Mode\(),4,\Rows\() + + mov x0,x26 // reload matrix A + cmp x5,#2 + b .L\Mode\().ExitKernel + +.L\Mode\().Output14ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),14,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output10ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),10,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output6ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),6,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output2ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),2,\Rows\() + b .L\Mode\().ExitKernel + + .endm + + +/*++ + +Routine Description: + + This routine is an inner kernel to compute matrix multiplication for a + set of rows. + +Arguments: + + A (x0) - Supplies the address of matrix A. + + B (x1) - Supplies the address of matrix B. The matrix data has been packed + using MlasSbgemmCopyPackB or MlasSbgemmTransposePackB. + + C (x2) - Supplies the address of matrix C. + + CountK (x3) - Supplies the number of columns from matrix A and the number + of rows from matrix B to iterate over. + + CountM (x4) - Supplies the maximum number of rows that can be processed for + matrix A and matrix C. The actual number of rows handled for this + invocation depends on the kernel implementation. + + CountN (x5) - Supplies the number of columns from matrix B and matrix C to + iterate over. + + lda (x6) - Supplies the first dimension of matrix A. + + ldc (x7) - Supplies the first dimension of matrix C. + + Bias - Supplies the address of Bias Vector [1xn] + + +Return Value: + + Returns the number of rows handled. + +--*/ + .macro SbgemmKernelNeonFunction Mode + + FUNCTION_ENTRY MlasSbgemmKernel\Mode\() + + ldr x8, [sp, #0] //Bias vector + + stp x19, x20, [sp, #.LMlasSbgemmKernel_SavedRegisters_Neg]! + stp x21, x22, [sp, #.LMlasSbgemmKernel_backup_x21_x22] + stp x23, x24, [sp, #.LMlasSbgemmKernel_backup_x23_x24] + stp x25, x26, [sp, #.LMlasSbgemmKernel_backup_x25_x26] + stp x27, x28, [sp, #.LMlasSbgemmKernel_backup_x27_x28] + stp d8, d9, [sp, #.LMlasSbgemmKernel_backup_d8_d9] + stp d10, d11, [sp, #.LMlasSbgemmKernel_backup_d10_d11] + stp d12, d13, [sp, #.LMlasSbgemmKernel_backup_d12_d13] + stp d14, d15, [sp, #.LMlasSbgemmKernel_backup_d14_d15] + + add x13,x2,x7,lsl #2 // compute matrix C plus 1 row + add x14,x13,x7,lsl #2 // compute matrix C plus 2 rows + add x15,x14,x7,lsl #2 // compute matrix C plus 3 rows + add x16,x15,x7,lsl #2 // compute matrix C plus 4 rows + add x17,x16,x7,lsl #2 // compute matrix C plus 5 rows + add x18,x17,x7,lsl #2 // compute matrix C plus 6 rows + add x19,x18,x7,lsl #2 // compute matrix C plus 7 rows + + mov x26,x0 // save matrix A +// +// Process 8 rows of the matrices. +// + cmp x4,#8 + blt .L\Mode\().ProcessCountMLessThan8 + ProcessRows \Mode\(),8 + +// +// Restore non-volatile registers and return. +// + +.L\Mode\().ExitKernel: + mov x0,x4 + + ldp d14, d15, [sp, #.LMlasSbgemmKernel_backup_d14_d15] + ldp d12, d13, [sp, #.LMlasSbgemmKernel_backup_d12_d13] + ldp d10, d11, [sp, #.LMlasSbgemmKernel_backup_d10_d11] + ldp d8, d9, [sp, #.LMlasSbgemmKernel_backup_d8_d9] + ldp x27, x28, [sp, #.LMlasSbgemmKernel_backup_x27_x28] + ldp x25, x26, [sp, #.LMlasSbgemmKernel_backup_x25_x26] + ldp x23, x24, [sp, #.LMlasSbgemmKernel_backup_x23_x24] + ldp x21, x22, [sp, #.LMlasSbgemmKernel_backup_x21_x22] + ldp x19, x20, [sp], #.LMlasSbgemmKernel_SavedRegisters + + ret + +// +// Process 4 rows of the matrix. +// + +.L\Mode\().ProcessCountMLessThan8: + cmp x4,#4 + blt .L\Mode\().ProcessCountMLessThan4 + ProcessRows \Mode\(),4 + b .L\Mode\().ExitKernel + +// +// Process 2 row of the matrix. +// + +.L\Mode\().ProcessCountMLessThan4: + cmp x4,#2 + blt .L\Mode\().ProcessCountMLessThan2 + + ProcessRows \Mode\(),2 + b .L\Mode\().ExitKernel + + +// +// Process the last row of the matrix. +// + +.L\Mode\().ProcessCountMLessThan2: + ProcessRows \Mode\(),1 + b .L\Mode\().ExitKernel + + + .endm + + SbgemmKernelNeonFunction Zero + SbgemmKernelNeonFunction Add diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 7bb8b17031a84..624eb913d5c9e 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -193,6 +193,8 @@ class MLASCPUIDInfo bool HasArmSVE_I8MM() const { return has_arm_sve_i8mm_; } + bool HasArmNeon_BF16() const { return has_arm_neon_bf16_; } + private: MLASCPUIDInfo(); @@ -200,6 +202,7 @@ class MLASCPUIDInfo bool has_fp16_{false}; bool has_arm_neon_i8mm_{false}; bool has_arm_sve_i8mm_{false}; + bool has_arm_neon_bf16_{false}; }; using MLAS_CPUIDINFO = MLASCPUIDInfo; @@ -357,6 +360,20 @@ size_t #else +#if defined(__aarch64__) && defined(__linux__) +typedef size_t(MLASCALL MLAS_SBGEMM_FLOAT_KERNEL)( + const float* A, + const bfloat16_t* B, + float* C, + size_t CountK, + size_t CountM, + size_t CountN, + size_t lda, + size_t ldc, + const float* Bias +); +#endif + typedef size_t (MLASCALL MLAS_GEMM_FLOAT_KERNEL)( @@ -727,6 +744,10 @@ extern "C" { #else MLAS_GEMM_FLOAT_KERNEL MlasSgemmKernelZero; MLAS_GEMM_FLOAT_KERNEL MlasSgemmKernelAdd; +#if defined(__aarch64__) && defined(__linux__) + MLAS_SBGEMM_FLOAT_KERNEL MlasSbgemmKernelZero; + MLAS_SBGEMM_FLOAT_KERNEL MlasSbgemmKernelAdd; +#endif MLAS_GEMM_DOUBLE_KERNEL MlasDgemmKernelZero; MLAS_GEMM_DOUBLE_KERNEL MlasDgemmKernelAdd; #endif @@ -856,6 +877,10 @@ extern "C" { #define MLAS_DGEMM_THREAD_COMPLEXITY (size_t(64) * size_t(1024)) #define MLAS_QGEMM_THREAD_COMPLEXITY 65536 +#if defined(__aarch64__) && defined(__linux__) +#define MLAS_SBGEMM_THREAD_COMPLEXITY (size_t(64) * size_t(1024)) +#endif + // // Single-threaded single precision matrix/matrix multiply operation. // diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 1310ed3f384b9..de092f7d1d350 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -60,6 +60,10 @@ MLASCPUIDInfo::MLASCPUIDInfo() #define HWCAP2_SVEI8MM (1 << 9) #endif +#ifndef HWCAP2_BF16 +#define HWCAP2_BF16 (1 << 14) +#endif + #if defined(BUILD_MLAS_NO_ONNXRUNTIME) MLASCPUIDInfo::MLASCPUIDInfo() { @@ -70,6 +74,8 @@ MLASCPUIDInfo::MLASCPUIDInfo() has_arm_neon_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_I8MM) != 0); has_arm_sve_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_SVEI8MM) != 0); + + has_arm_neon_bf16_ = ((getauxval(AT_HWCAP2) & HWCAP2_BF16) != 0); } #endif diff --git a/onnxruntime/core/mlas/lib/sbgemm.h b/onnxruntime/core/mlas/lib/sbgemm.h new file mode 100644 index 0000000000000..de7fd72fad45a --- /dev/null +++ b/onnxruntime/core/mlas/lib/sbgemm.h @@ -0,0 +1,399 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the MIT License. + +Module Name: + + sbgemm.h + +Abstract: + + This module defines the set of template functions to implement bfloat16 + precision matrix/matrix multiply operation (SBGEMM). + + To implement a new kernel, template functions below need to be specialized: + MlasSBGemmConvertPackB + MlasSBGemmPackedBOffset + MlasSBGemmPackedBLeadingDim + MlasSBGemmKernel + + MlasSBGemmOperation is the shared kernel driver. + + A kernel type should define the following constants: + bool PackNeeded; Whether B needs to be packed + size_t KernelMaxM; Max # rows the vectorized kernel can process + size_t PackedK; Packed alignment on the K dim (power of 2) + size_t PackedN; Packed alignment on the n dim (power of 2) + MLAS_SBGEMM_STRIDES Strides{128, 128, 256}; +--*/ + +#if defined(__aarch64__) && defined(__linux__) + +#pragma once + +#include +#include + +#include "mlasi.h" + +/** + * @brief Define the default striding parameters for + * the bfloat16 precision gemm operation + */ +struct MLAS_SBGEMM_STRIDES { + size_t M; + size_t N; + size_t K; +}; + +/** + * @brief Convert fp32 matrix B to bf16 and pack the data + * + * @tparam KernelType + * @param[out] D Address of packing buffer + * @param[in] B Address of source matrix B in fp32 + * @param[in] ldb Leading dimension of B + * @param[in] CountN # of column to pack + * @param[in] CountK # of rows to pack + */ +template +void +MlasSBGemmConvertPackB( + bfloat16_t* PackedB, const float* B, size_t ldb, size_t CountN, size_t CountK +); + +/** + * @brief Find the location of PackedB[StartK, StartN] + * + * @tparam KernelType + * @param PackedB + * @param DimN Total columns of the packing buffer + * @param DimK Total rows of the packing buffer + * @param StartN + * @param StartK + * @return Address of PackedB[StartK, StartN] + */ +template +MLAS_FORCEINLINE const bfloat16_t* +MlasSBGemmPackedBOffset( + const bfloat16_t* PackedB, size_t DimN, size_t DimK, size_t StartN, size_t StartK +) +{ + // By default the packed buffer is just a row major + // K row by N column buffer + MLAS_UNREFERENCED_PARAMETER(DimK); + return PackedB + StartK * DimN + StartN; +} + +/** + * @brief leading dimension of the packed B buffer + * Related to how B is packed + * @tparam KernelType + * @param DimN + * @param DimK + * @return leading dimension of the packed B buffer + */ +template +MLAS_FORCEINLINE size_t +MlasSBGemmPackedBLeadingDim(size_t DimN, size_t DimK) +{ + // By default the packed buffer is just a row major + // K row by N column buffer + MLAS_UNREFERENCED_PARAMETER(DimK); + return DimN; +} + +template +void +MlasSBGemmKernel(const size_t CountM, const size_t CountN, const size_t CountK, const float* A, const size_t lda, const bfloat16_t* B, float* C, size_t ldc, const float* Bias, const bool ZeroMode); + +template +MLAS_FORCEINLINE void +MlasSBGemmPackedOperation(size_t M, size_t RangeStartN, size_t RangeCountN, size_t AlignedN, size_t K, const float* A, size_t lda, const void* PackedB, float* C, size_t ldc, const float* Bias, void* PostProcessor) +{ + constexpr MLAS_SBGEMM_STRIDES Strides = KernelType::Strides; + size_t PackedStrideN = Strides.N; + size_t PackedStrideK = Strides.K; + + // + // Step through each slice of matrix B along the N dimension. + // + size_t CountN; + for (size_t n = 0; n < RangeCountN; n += CountN) { + const size_t SliceStartN = RangeStartN + n; + CountN = std::min(RangeCountN - n, PackedStrideN); + + // + // Step through each slice of matrix B along the K dimension. + // + size_t CountK; + for (size_t k = 0; k < K; k += CountK) { + bool ZeroMode = (k == 0); + CountK = std::min(K - k, PackedStrideK); + + const bfloat16_t* pb = (const bfloat16_t*)PackedB + AlignedN * k + CountK * SliceStartN; + float* c = C + n; + const float* pbias = ((nullptr == Bias) ? nullptr : Bias + RangeStartN + n); + MlasSBGemmKernel(M, CountN, CountK, A + k, lda, pb, c, ldc, ZeroMode ? pbias : nullptr, ZeroMode); + } + if (PostProcessor != nullptr) { + ((MLAS_SBGEMM_POSTPROCESSOR*)PostProcessor) + ->Process(C + n, M, SliceStartN, M, CountN, ldc); + } + } +} + +template +void +MlasSBGemmNonPackedOperation(size_t M, size_t N, size_t K, const float* A, size_t lda, const float* B, size_t ldb, float* C, size_t ldc, const float* Bias, void* PostProcessor) +{ + // + // Compute the strides to step through slices of the input matrices. + // + // Expand the N stride if K is small or expand the K stride if N is small + // for better utilization of the B panel. Avoid changing the K stride if + // the A panel needs to be used for transposing. + // + constexpr MLAS_SBGEMM_STRIDES Strides = KernelType::Strides; + size_t StrideN = Strides.N; + size_t StrideK = Strides.K; + + if (N >= K) { + while (StrideK / 2 >= K) { + StrideN *= 2; + StrideK /= 2; + } + } else { + while (StrideN > 16 && StrideN / 2 >= N) { + StrideK *= 2; + StrideN /= 2; + } + } + + constexpr size_t packBSize = UpAlignSize(Strides.N * Strides.K * sizeof(bfloat16_t)); + MlasThreadedBufAlloc(packBSize); + uint8_t* p = ThreadedBufHolder.get(); + auto* PanelB = reinterpret_cast(p); + + // + // Step through each slice of matrix B along the N dimension. + // + size_t CountN; + for (size_t n = 0; n < N; n += CountN) { + CountN = std::min(N - n, StrideN); + + // + // Step through each slice of matrix B along the N dimension. + // + size_t CountK; + for (size_t k = 0; k < K; k += CountK) { + CountK = std::min(K - k, StrideK); + + // + // Copy a panel of matrix B to a local packed buffer. + // + MlasSBGemmConvertPackB(PanelB, B + n + k * ldb, ldb, CountN, CountK); + + auto* c = C + n; + const float* pbias = + ((nullptr == Bias) ? nullptr : Bias + n); // TODO: check the SliceNStart + + bool ZeroMode = (k == 0); + MlasSBGemmKernel(M, CountN, CountK, A + k, lda, PanelB, c, ldc, ZeroMode ? pbias : nullptr, ZeroMode); + } + if (PostProcessor != nullptr) { + ((MLAS_SBGEMM_POSTPROCESSOR*)PostProcessor)->Process(C + n, M, N, M, CountN, ldc); + } + } +} + +template +void +MlasSBGemmOperation(const ptrdiff_t ThreadCountM, const ptrdiff_t ThreadCountN, const size_t M, const size_t N, const size_t K, const MLAS_SBGEMM_DATA_PARAMS* DataParams, ptrdiff_t ThreadId) +{ + const ptrdiff_t ThreadIdM = ThreadId / ThreadCountN; + const ptrdiff_t ThreadIdN = ThreadId % ThreadCountN; + + // + // Partition the operation along the M dimension. + // + size_t RangeStartM; + size_t RangeCountM; + + MlasPartitionWork(ThreadIdM, ThreadCountM, M, &RangeStartM, &RangeCountM); + + // + // Partition the operation along the N dimension. + // + size_t RangeStartN; + size_t RangeCountN; + + const size_t BlockedN = + (N + MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1) / MLAS_SGEMM_STRIDEN_THREAD_ALIGN; + + MlasPartitionWork(ThreadIdN, ThreadCountN, BlockedN, &RangeStartN, &RangeCountN); + + RangeStartN *= MLAS_SGEMM_STRIDEN_THREAD_ALIGN; + RangeCountN *= MLAS_SGEMM_STRIDEN_THREAD_ALIGN; + + RangeCountN = std::min(N - RangeStartN, RangeCountN); + + // + // Dispatch the partitioned operation. + // + const size_t lda = DataParams->lda; + const size_t ldc = DataParams->ldc; + const float* A = (const float*)DataParams->A + RangeStartM * lda; + float* C = DataParams->C + RangeStartM * ldc + RangeStartN; + const float* bias = DataParams->Bias; + + if (!DataParams->BIsfp32) { + MlasSBGemmPackedOperation( + RangeCountM, RangeStartN, RangeCountN, BlockedN * MLAS_SGEMM_STRIDEN_THREAD_ALIGN, K, A, + lda, DataParams->B, C, ldc, bias, (void*)DataParams->OutputProcessor + ); + } else { + const size_t ldb = DataParams->ldb; + const float* B = (const float*)DataParams->B + RangeStartN; + MlasSBGemmNonPackedOperation(RangeCountM, RangeCountN, K, A, lda, B, ldb, C, ldc, bias, (void*)DataParams->OutputProcessor); + } +} + +// +// dispatch structure. +// +typedef void(MLAS_SBGEMM_OPERATION)(const ptrdiff_t ThreadCountM, const ptrdiff_t ThreadCountN, const size_t M, const size_t N, const size_t K, const MLAS_SBGEMM_DATA_PARAMS* DataParams, ptrdiff_t ThreadId); + +typedef void(MLAS_SBGEMM_CONVERTPACKB_ROUTINE)( + bfloat16_t* D, const float* B, size_t ldb, size_t CountN, size_t CountK +); + +/** + * @brief Hardware dependent dispatch for half precision GEMM + */ +struct MLAS_SBGEMM_DISPATCH { + MLAS_SBGEMM_OPERATION* Operation; /**< HalfGemm driver */ + MLAS_SBGEMM_CONVERTPACKB_ROUTINE* ConvertPackBRoutine; /**< Convert and pack function for B */ + size_t PackedK; + size_t PackedN; + size_t StrideM; + size_t BufOverRead; +}; + +extern const MLAS_SBGEMM_DISPATCH MlasSBGemmDispatchNeon; + +MLAS_FORCEINLINE +const MLAS_SBGEMM_DISPATCH* +MlasSBGemmGetDispatch() +{ +#if defined(MLAS_TARGET_ARM64) + return &MlasSBGemmDispatchNeon; +#else + std::cerr << "SBGemm Kernel is supported only on ARM64 platform."; + exit(1); +#endif +} + +size_t MLASCALL +MlasSBGemmPackBSize(size_t N, size_t K) +{ + // + // Compute the number of bytes required to hold the packed buffer. + // + const auto* dispatch = MlasSBGemmGetDispatch(); + if (dispatch == nullptr) return 0; + + const auto padding = dispatch->BufOverRead; + const auto PackedK = dispatch->PackedK; + const auto PackedN = dispatch->PackedN; + + const size_t AlignedK = (K + PackedK - 1) & ~(PackedK - 1); + const size_t AlignedN = (N + PackedN - 1) & ~(PackedN - 1); + const size_t BytesRequired = AlignedN * AlignedK * sizeof(bfloat16_t) + padding; + const size_t BufferAlignment = MlasGetPreferredBufferAlignment(); + const size_t AlignedBytesRequired = + (BytesRequired + BufferAlignment - 1) & ~(BufferAlignment - 1); + + return AlignedBytesRequired; +} + +void MLASCALL +MlasSBGemmConvertPackB(size_t N, size_t K, const float* B, size_t ldb, void* PackedB) +{ + const auto* dispatch = MlasSBGemmGetDispatch(); + if (dispatch == nullptr) return; + + dispatch->ConvertPackBRoutine((bfloat16_t*)PackedB, B, ldb, N, K); +} + +void MLASCALL +MlasSBGemmBatch(const size_t M, const size_t N, const size_t K, const size_t BatchN, const MLAS_SBGEMM_DATA_PARAMS* Data, MLAS_THREADPOOL* ThreadPool) +{ + const MLAS_SBGEMM_DISPATCH* dispatch = MlasSBGemmGetDispatch(); + if (dispatch == nullptr) return; + + MLAS_SBGEMM_OPERATION* operation = dispatch->Operation; + + // + // Compute the number of target threads given the complexity of the SGEMM + // operation. Small requests should run using the single threaded path. + // + + const double Complexity = double(M) * double(N) * double(K); + + ptrdiff_t TargetThreadCount; + + if (Complexity < double(MLAS_SBGEMM_THREAD_COMPLEXITY * GetMlasPlatform().MaximumThreadCount)) { + TargetThreadCount = ptrdiff_t(Complexity / double(MLAS_SGEMM_THREAD_COMPLEXITY)) + 1; + } else { + TargetThreadCount = GetMlasPlatform().MaximumThreadCount; + } + + ptrdiff_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool); + + if (TargetThreadCount >= MaximumThreadCount) { + TargetThreadCount = MaximumThreadCount; + } + + // + // Segment the operation across multiple threads. + // + // N.B. Currently, the operation is segmented as a 1D partition, which + // works okay for operations involving skinny matrices. + // + ptrdiff_t ThreadsPerGemm = (TargetThreadCount + BatchN - 1) / BatchN; + ptrdiff_t ThreadCountM; + ptrdiff_t ThreadCountN; + + if (N > M) { + const size_t BlockedN = + (N + MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1) / MLAS_SGEMM_STRIDEN_THREAD_ALIGN; + + if (size_t(ThreadsPerGemm) > BlockedN) { + ThreadsPerGemm = ptrdiff_t(BlockedN); + } + + ThreadCountM = 1; + ThreadCountN = ThreadsPerGemm; + + } else { + if (size_t(ThreadsPerGemm) > M) { + ThreadsPerGemm = ptrdiff_t(M); + } + + ThreadCountM = ThreadsPerGemm; + ThreadCountN = 1; + } + + MlasTrySimpleParallel( + ThreadPool, ThreadsPerGemm * static_cast(BatchN), [=](ptrdiff_t tid) { + ptrdiff_t GemmIdx = tid / ThreadsPerGemm; + ptrdiff_t ThreadIdx = tid % ThreadsPerGemm; + operation(ThreadCountM, ThreadCountN, M, N, K, &(Data[GemmIdx]), ThreadIdx); + } + ); +} +#endif // defined(__aarch64__) && defined(__linux__) diff --git a/onnxruntime/core/mlas/lib/sbgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sbgemm_kernel_neon.cpp new file mode 100644 index 0000000000000..a6a73996c548b --- /dev/null +++ b/onnxruntime/core/mlas/lib/sbgemm_kernel_neon.cpp @@ -0,0 +1,362 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the MIT License. + +Module Name: + + sbgemm_kernel_neon.cpp + +Abstract: + + This module implements bfloat16 precision GEMM kernel for neon. + +--*/ + +#if defined(__aarch64__) && defined(__linux__) + +#include "arm_neon.h" +#include "mlasi.h" +#include "sbgemm.h" + +struct MLAS_SBGEMM_KERNEL_NEON { + static constexpr bool PackNeeded = true; + static constexpr size_t KernelMaxM = 8; // max # rows the vectorized kernel can process + static constexpr size_t PackedK = 4; + static constexpr size_t PackedN = MLAS_SGEMM_STRIDEN_THREAD_ALIGN; + static constexpr MLAS_SBGEMM_STRIDES Strides{128, 128, 256}; // M:N:K +}; + +bool MLASCALL +MlasBf16AccelerationSupported() +{ +#if defined(MLAS_TARGET_ARM64) + return MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeon_BF16(); +#else + return false; +#endif +} + +/* + This routine converts fp32 to bf16 and copies elements from the source + matrix to the destination packed buffer. + + 4x2 elements from the source matrix are unrolled to be physically + contiguous for better locality inside the SBGEMM kernels. The remaining + rows and columns are padded to 4 and 2 alignment. +*/ +MLAS_FORCEINLINE +void +MlasSBGemmConvertCopyPackB(bfloat16_t* D, const float* B, size_t ldb, size_t CountN, size_t CountK) +{ + // + // Copy data from matrix B into the destination buffer 4x2 blocks at a + // time. + // + // + while (CountN >= 8) { + const float* b = B; + int y = static_cast(CountK); + + while (y > 0) { + MLAS_FLOAT32X4 t0_l = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t0_h = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t1_l = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t1_h = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t2_l = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t2_h = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t3_l = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t3_h = MlasZeroFloat32x4(); + + if (y >= 4) { + t0_l = MlasLoadFloat32x4(&b[ldb * 0]); + t0_h = MlasLoadFloat32x4(&b[ldb * 0 + 4]); + t1_l = MlasLoadFloat32x4(&b[ldb * 1]); + t1_h = MlasLoadFloat32x4(&b[ldb * 1 + 4]); + t2_l = MlasLoadFloat32x4(&b[ldb * 2]); + t2_h = MlasLoadFloat32x4(&b[ldb * 2 + 4]); + t3_l = MlasLoadFloat32x4(&b[ldb * 3]); + t3_h = MlasLoadFloat32x4(&b[ldb * 3 + 4]); + } else { + switch (y) { + case 3: + t0_l = MlasLoadFloat32x4(&b[ldb * 0]); + t0_h = MlasLoadFloat32x4(&b[ldb * 0 + 4]); + t1_l = MlasLoadFloat32x4(&b[ldb * 1]); + t1_h = MlasLoadFloat32x4(&b[ldb * 1 + 4]); + t2_l = MlasLoadFloat32x4(&b[ldb * 2]); + t2_h = MlasLoadFloat32x4(&b[ldb * 2 + 4]); + break; + case 2: + t0_l = MlasLoadFloat32x4(&b[ldb * 0]); + t0_h = MlasLoadFloat32x4(&b[ldb * 0 + 4]); + t1_l = MlasLoadFloat32x4(&b[ldb * 1]); + t1_h = MlasLoadFloat32x4(&b[ldb * 1 + 4]); + break; + case 1: + t0_l = MlasLoadFloat32x4(&b[ldb * 0]); + t0_h = MlasLoadFloat32x4(&b[ldb * 0 + 4]); + break; + } + } + + float32x4x2_t z0_l = vzipq_f32(t0_l, t2_l); + float32x4x2_t z1_l = vzipq_f32(t1_l, t3_l); + float32x4x2_t o0_l = vzipq_f32(z0_l.val[0], z1_l.val[0]); + float32x4x2_t o1_l = vzipq_f32(z0_l.val[1], z1_l.val[1]); + t0_l = o0_l.val[0]; + t1_l = o0_l.val[1]; + t2_l = o1_l.val[0]; + t3_l = o1_l.val[1]; + + bfloat16x8_t t0t1_l_4h = vcvtq_low_bf16_f32(t0_l); + bfloat16x8_t t0t1_l_8h = vcvtq_high_bf16_f32(t0t1_l_4h, t1_l); + + bfloat16x8_t t2t3_l_4h = vcvtq_low_bf16_f32(t2_l); + bfloat16x8_t t2t3_l_8h = vcvtq_high_bf16_f32(t2t3_l_4h, t3_l); + + vst1q_bf16(&D[0], t0t1_l_8h); + vst1q_bf16(&D[8], t2t3_l_8h); + + float32x4x2_t z0_h = vzipq_f32(t0_h, t2_h); + float32x4x2_t z1_h = vzipq_f32(t1_h, t3_h); + float32x4x2_t o0_h = vzipq_f32(z0_h.val[0], z1_h.val[0]); + float32x4x2_t o1_h = vzipq_f32(z0_h.val[1], z1_h.val[1]); + t0_h = o0_h.val[0]; + t1_h = o0_h.val[1]; + t2_h = o1_h.val[0]; + t3_h = o1_h.val[1]; + + bfloat16x8_t t0t1_h_4h = vcvtq_low_bf16_f32(t0_h); + bfloat16x8_t t0t1_h_8h = vcvtq_high_bf16_f32(t0t1_h_4h, t1_h); + + bfloat16x8_t t2t3_h_4h = vcvtq_low_bf16_f32(t2_h); + bfloat16x8_t t2t3_h_8h = vcvtq_high_bf16_f32(t2t3_h_4h, t3_h); + + vst1q_bf16(&D[16], t0t1_h_8h); + vst1q_bf16(&D[24], t2t3_h_8h); + + D += 32; + b += ldb * 4; + y -= 4; + }; + B += 8; + CountN -= 8; + } + + // + // Special case the handling of the remaining columns less than 8 elements + // wide. + // + if (CountN > 0) { + int y = static_cast(CountK); + while (y > 0) { + const float* b = B; + size_t b_inc = 0; + if ((CountN & 4) != 0) { + MLAS_FLOAT32X4 t0 = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t1 = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t2 = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t3 = MlasZeroFloat32x4(); + if (y >= 4) { + t0 = MlasLoadFloat32x4(&b[ldb * 0]); + t1 = MlasLoadFloat32x4(&b[ldb * 1]); + t2 = MlasLoadFloat32x4(&b[ldb * 2]); + t3 = MlasLoadFloat32x4(&b[ldb * 3]); + } else { + switch (y) { + case 3: + t0 = MlasLoadFloat32x4(&b[ldb * 0]); + t1 = MlasLoadFloat32x4(&b[ldb * 1]); + t2 = MlasLoadFloat32x4(&b[ldb * 2]); + break; + case 2: + t0 = MlasLoadFloat32x4(&b[ldb * 0]); + t1 = MlasLoadFloat32x4(&b[ldb * 1]); + break; + case 1: + t0 = MlasLoadFloat32x4(&b[ldb * 0]); + break; + } + } + + float32x4x2_t z0 = vzipq_f32(t0, t2); + float32x4x2_t z1 = vzipq_f32(t1, t3); + float32x4x2_t o0 = vzipq_f32(z0.val[0], z1.val[0]); + float32x4x2_t o1 = vzipq_f32(z0.val[1], z1.val[1]); + + t0 = o0.val[0]; + t1 = o0.val[1]; + t2 = o1.val[0]; + t3 = o1.val[1]; + + bfloat16x8_t t0t1_4h = vcvtq_low_bf16_f32(t0); + bfloat16x8_t t0t1_8h = vcvtq_high_bf16_f32(t0t1_4h, t1); + + bfloat16x8_t t2t3_4h = vcvtq_low_bf16_f32(t2); + bfloat16x8_t t2t3_8h = vcvtq_high_bf16_f32(t2t3_4h, t3); + + vst1q_bf16(&D[0], t0t1_8h); + vst1q_bf16(&D[8], t2t3_8h); + + D += 16; + b += 4; + b_inc += 4; + } + + if ((CountN & 2) != 0) { + float32x2_t t0 = {0x0, 0x0}; + float32x2_t t1 = {0x0, 0x0}; + float32x2_t t2 = {0x0, 0x0}; + float32x2_t t3 = {0x0, 0x0}; + + if (y >= 4) { + t0 = vld1_f32(&b[ldb * 0]); + t1 = vld1_f32(&b[ldb * 1]); + t2 = vld1_f32(&b[ldb * 2]); + t3 = vld1_f32(&b[ldb * 3]); + } else { + switch (y) { + case 3: + t0 = vld1_f32(&b[ldb * 0]); + t1 = vld1_f32(&b[ldb * 1]); + t2 = vld1_f32(&b[ldb * 2]); + break; + case 2: + t0 = vld1_f32(&b[ldb * 0]); + t1 = vld1_f32(&b[ldb * 1]); + break; + case 1: + t0 = vld1_f32(&b[ldb * 0]); + break; + } + } + + float32x2x2_t z0 = vzip_f32(t0, t2); + float32x2x2_t z1 = vzip_f32(t1, t3); + float32x2x2_t o0 = vzip_f32(z0.val[0], z1.val[0]); + float32x2x2_t o1 = vzip_f32(z0.val[1], z1.val[1]); + + float32x4_t tt0 = vcombine_f32(o0.val[0], o0.val[1]); + float32x4_t tt1 = vcombine_f32(o1.val[0], o1.val[1]); + + bfloat16x8_t t_4h = vcvtq_low_bf16_f32(tt0); + bfloat16x8_t t_8h = vcvtq_high_bf16_f32(t_4h, tt1); + + vst1q_bf16(&D[0], t_8h); + + D += 8; + b += 2; + b_inc += 2; + } + if ((CountN & 1) != 0) { + float a = 0.0f; + float b = 0.0f; + float c = 0.0f; + float d = 0.0f; + + if (y >= 4) { + a = *(float*)(&B[ldb * 0 + b_inc]); + b = *(float*)(&B[ldb * 1 + b_inc]); + c = *(float*)(&B[ldb * 2 + b_inc]); + d = *(float*)(&B[ldb * 3 + b_inc]); + } else { + switch (y) { + case 3: + a = *(float*)(&B[ldb * 0 + b_inc]); + b = *(float*)(&B[ldb * 1 + b_inc]); + c = *(float*)(&B[ldb * 2 + b_inc]); + break; + case 2: + a = *(float*)(&B[ldb * 0 + b_inc]); + b = *(float*)(&B[ldb * 1 + b_inc]); + break; + case 1: + a = *(float*)(&B[ldb * 0 + b_inc]); + break; + } + } + + float32x2_t t0 = {a, 0x0}; + float32x2_t t1 = {b, 0x0}; + float32x2_t t2 = {c, 0x0}; + float32x2_t t3 = {d, 0x0}; + + float32x2x2_t z0 = vzip_f32(t0, t2); + float32x2x2_t z1 = vzip_f32(t1, t3); + float32x2x2_t o0 = vzip_f32(z0.val[0], z1.val[0]); + float32x2x2_t o1 = vzip_f32(z0.val[1], z1.val[1]); + + float32x4_t tt0 = vcombine_f32(o0.val[0], o0.val[1]); + float32x4_t tt1 = vcombine_f32(o1.val[0], o1.val[1]); + + bfloat16x8_t t_4h = vcvtq_low_bf16_f32(tt0); + bfloat16x8_t t_8h = vcvtq_high_bf16_f32(t_4h, tt1); + + vst1q_bf16(&D[0], t_8h); + + D += 8; + b += 1; + b_inc += 1; + } + B += 4 * ldb; + y -= 4; + } + } +} + +template +void +MlasSBGemmConvertPackB( + bfloat16_t* PackedB, const float* B, size_t ldb, size_t CountN, size_t CountK +) +{ + const auto* dispatch = MlasSBGemmGetDispatch(); + if (dispatch == nullptr) return; + + const auto PackedN = dispatch->PackedN; + + const size_t AlignedN = (CountN + PackedN - 1) & ~(PackedN - 1); + + // + // Step through each slice of matrix B along the K dimension. + // + size_t K_block_size; + constexpr MLAS_SBGEMM_STRIDES Strides = KernelType::Strides; + + for (size_t k = 0; k < CountK; k += K_block_size) { + K_block_size = std::min(CountK - k, Strides.K); + + MlasSBGemmConvertCopyPackB((bfloat16_t*)PackedB, B + k * ldb, ldb, CountN, K_block_size); + PackedB = (bfloat16_t*)PackedB + AlignedN * K_block_size; + } +} + +template <> +MLAS_FORCEINLINE void +MlasSBGemmKernel(size_t CountM, size_t CountN, size_t CountK, const float* A, size_t lda, const bfloat16_t* B, float* C, size_t ldc, const float* Bias, const bool ZeroMode) +{ + while (CountM > 0) { + size_t RowsHandled; + if (ZeroMode) { + RowsHandled = MlasSbgemmKernelZero(A, B, C, CountK, CountM, CountN, lda, ldc, Bias); + } else { + RowsHandled = MlasSbgemmKernelAdd(A, B, C, CountK, CountM, CountN, lda, ldc, Bias); + } + C += ldc * RowsHandled; + A += lda * RowsHandled; + CountM -= RowsHandled; + } +} + +const MLAS_SBGEMM_DISPATCH MlasSBGemmDispatchNeon = { + MlasSBGemmOperation, + MlasSBGemmConvertPackB, + MLAS_SBGEMM_KERNEL_NEON::PackedK, + MLAS_SBGEMM_KERNEL_NEON::PackedN, + MLAS_SBGEMM_KERNEL_NEON::KernelMaxM, + 32 // kernel may read beyond buffer end by 32 bytes +}; +#endif // defined(__aarch64__) && defined(__linux__) diff --git a/onnxruntime/core/providers/cpu/math/matmul.cc b/onnxruntime/core/providers/cpu/math/matmul.cc index ec395cf018f5e..583ee759cc2e6 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.cc +++ b/onnxruntime/core/providers/cpu/math/matmul.cc @@ -6,7 +6,6 @@ #include "core/providers/cpu/math/matmul_helper.h" #include "core/util/math.h" #include "core/util/math_cpuonly.h" -#include "core/mlas/inc/mlas.h" namespace onnxruntime { @@ -125,6 +124,44 @@ Status MatMul::Compute(OpKernelContext* ctx) const { return Status::OK(); } +#if defined(__aarch64__) && defined(__linux__) +bool GemmPackBBfloat16(AllocatorPtr& alloc, + const Tensor& tensor_b, + bool trans_b, + IAllocatorUniquePtr& packed_b, + size_t& packed_b_size, + TensorShape& b_shape) { + // Only handle the common case of a 2D weight matrix. Additional matrices + // could be handled by stacking the packed buffers. + if (tensor_b.Shape().NumDimensions() != 2) { + return false; + } + + b_shape = tensor_b.Shape(); + + const size_t K = trans_b ? static_cast(b_shape[1]) : static_cast(b_shape[0]); + const size_t N = trans_b ? static_cast(b_shape[0]) : static_cast(b_shape[1]); + + packed_b_size = MlasSBGemmPackBSize(N, K); + if (packed_b_size == 0) { + return false; + } + + packed_b = IAllocator::MakeUniquePtr(alloc, packed_b_size, true); + auto* packed_b_data = packed_b.get(); + + // Initialize memory to 0 as there could be some padding associated with pre-packed + // buffer memory and we don not want it uninitialized and generate different hashes + // if and when we try to cache this pre-packed buffer for sharing between sessions. + memset(packed_b_data, 0, packed_b_size); + MlasSBGemmConvertPackB(N, + K, + tensor_b.Data(), + trans_b ? K : N, + packed_b_data); + return true; +} +#endif Status MatMul::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, /*out*/ bool& is_packed, @@ -134,7 +171,24 @@ Status MatMul::PrePack(const Tensor& tensor, int input_idx, /*out*/ Alloc // only pack Matrix B if (input_idx == 1) { size_t packed_b_size; - is_packed = GemmPackBFp32(alloc, tensor, trans_b_attr_ != 0, packed_b_, packed_b_size, b_shape_); +#if defined(__aarch64__) && defined(__linux__) + size_t dim1 = 0; + size_t dim2 = 0; + TensorShape b_shape = tensor.Shape(); + + if (b_shape.NumDimensions() == 2) { + dim1 = static_cast(b_shape[0]); + dim2 = static_cast(b_shape[1]); + } + + if (use_fastmath_mode_ && (trans_b_attr_ == 0) && ((dim1 * dim2) >= kFastMathModeKernelsizeThreshold)) { + is_packed = GemmPackBBfloat16(alloc, tensor, trans_b_attr_ != 0, packed_b_, packed_b_size, b_shape_); + } else +#endif + { + is_packed = GemmPackBFp32(alloc, tensor, trans_b_attr_ != 0, packed_b_, packed_b_size, b_shape_); + } + bool share_prepacked_weights = (prepacked_weights != nullptr); if (is_packed && share_prepacked_weights) { prepacked_weights->buffers_.push_back(std::move(packed_b_)); @@ -186,22 +240,40 @@ Status MatMul::Compute(OpKernelContext* ctx) const { const size_t K = static_cast(helper.K()); const size_t lda = helper.Lda(trans_a); const size_t ldb = helper.Ldb(trans_b); - - std::vector data(max_len); - for (size_t i = 0; i < max_len; i++) { - data[i].BIsPacked = bool(packed_b_); - data[i].A = a_data + helper.LeftOffsets()[i]; - data[i].lda = lda; - data[i].B = data[i].BIsPacked ? (float*)packed_b_.get() : b_data + helper.RightOffsets()[i]; - data[i].ldb = ldb; - data[i].C = y_data + helper.OutputOffsets()[i]; - data[i].ldc = N; - data[i].alpha = alpha_attr_; - data[i].beta = 0.0f; +#if defined(__aarch64__) && defined(__linux__) + if (use_fastmath_mode_ && !trans_b && ((N * K) >= kFastMathModeKernelsizeThreshold)) { + std::vector data(max_len); + for (size_t i = 0; i < max_len; i++) { + data[i].BIsfp32 = !(bool(packed_b_)); + data[i].AIsfp32 = true; + data[i].A = a_data + helper.LeftOffsets()[i]; + data[i].lda = lda; + data[i].B = data[i].BIsfp32 ? b_data + helper.RightOffsets()[i] : (float*)packed_b_.get(); + data[i].ldb = ldb; + data[i].C = y_data + helper.OutputOffsets()[i]; + data[i].ldc = N; + data[i].Bias = nullptr; + data[i].OutputProcessor = nullptr; + } + MlasSBGemmBatch(M, N, K, max_len, data.data(), thread_pool); + } else +#endif + { + std::vector data(max_len); + for (size_t i = 0; i < max_len; i++) { + data[i].BIsPacked = bool(packed_b_); + data[i].A = a_data + helper.LeftOffsets()[i]; + data[i].lda = lda; + data[i].B = data[i].BIsPacked ? (float*)packed_b_.get() : b_data + helper.RightOffsets()[i]; + data[i].ldb = ldb; + data[i].C = y_data + helper.OutputOffsets()[i]; + data[i].ldc = N; + data[i].alpha = alpha_attr_; + data[i].beta = 0.0f; + } + MlasGemmBatch(trans_a ? CblasTrans : CblasNoTrans, trans_b ? CblasTrans : CblasNoTrans, + M, N, K, data.data(), max_len, thread_pool); } - MlasGemmBatch(trans_a ? CblasTrans : CblasNoTrans, trans_b ? CblasTrans : CblasNoTrans, - M, N, K, data.data(), max_len, thread_pool); - return Status::OK(); } diff --git a/onnxruntime/core/providers/cpu/math/matmul.h b/onnxruntime/core/providers/cpu/math/matmul.h index b960fa4fb0587..b9bbe36583879 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.h +++ b/onnxruntime/core/providers/cpu/math/matmul.h @@ -4,6 +4,8 @@ #pragma once #include "core/framework/op_kernel.h" +#include "core/mlas/inc/mlas.h" +#include "core/session/onnxruntime_session_options_config_keys.h" namespace onnxruntime { @@ -27,6 +29,11 @@ class MatMul final : public OpKernel { info.GetAttrOrDefault("transBatchB", &trans_batch_b_attr, 0); trans_batch_a_ = trans_batch_a_attr != 0; trans_batch_b_ = trans_batch_b_attr != 0; + +#if defined(__aarch64__) && defined(__linux__) + auto config_ops = info.GetConfigOptions().GetConfigEntry(kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16); + use_fastmath_mode_ = (config_ops == "1") && MlasBf16AccelerationSupported(); +#endif } Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, @@ -48,6 +55,14 @@ class MatMul final : public OpKernel { int64_t trans_b_attr_; bool trans_batch_a_; bool trans_batch_b_; + +#if defined(__aarch64__) && defined(__linux__) + // fastmath mode state + bool use_fastmath_mode_; + // sbgemm kernel is implemented as 8x8 blocks with weights pre-packed to 4 blocks of 4x2 + // so a minimum of 32 elements is defined to outweigh the additional prepacking overhead + const size_t kFastMathModeKernelsizeThreshold = 32; +#endif }; } // namespace onnxruntime diff --git a/onnxruntime/test/mlas/unittest/test_sbgemm.cpp b/onnxruntime/test/mlas/unittest/test_sbgemm.cpp new file mode 100644 index 0000000000000..941de8f05061f --- /dev/null +++ b/onnxruntime/test/mlas/unittest/test_sbgemm.cpp @@ -0,0 +1,141 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the MIT License. + +Module Name: + + test_sbgemm.cpp + +Abstract: + + Tests for MLAS bf16 precision GEMM. + +--*/ + +#if defined(__aarch64__) && defined(__linux__) + +#include "test_sbgemm.h" + +// +// Short Execute() test helper to register each test seperately by all parameters. +// +template +class SBGemmShortExecuteTest : public MlasTestFixture> { + public: + explicit SBGemmShortExecuteTest(size_t M, size_t N, size_t K, size_t Batch, bool hasBias) + : M_(M), N_(N), K_(K), Batch_(Batch), hasBias_(hasBias) {} + + void TestBody() override { + MlasTestFixture>::mlas_tester->Test(M_, N_, K_, Batch_, hasBias_); + } + + static size_t RegisterSingleTest(size_t M, size_t N, size_t K, size_t Batch, bool hasBias) { + std::stringstream ss; + ss << "Batch" << Batch << "/M" << M << "xN" << N << "xK" << K << "/" + << "hasBias" << hasBias; + auto test_name = ss.str(); + + testing::RegisterTest( + MlasSBGemmTest::GetTestSuiteName(), + test_name.c_str(), + nullptr, + test_name.c_str(), + __FILE__, + __LINE__, + // Important to use the fixture type as the return type here. + [=]() -> MlasTestFixture>* { + return new SBGemmShortExecuteTest( + M, N, K, Batch, hasBias); + }); + + return 1; + } + + static size_t RegisterShortExecuteTests() { + size_t test_registered = 0; + for (size_t b = 1; b < 16; b++) { + test_registered += RegisterSingleTest(b, b, b, 1, false); + test_registered += RegisterSingleTest(b, b, b, 1, true); + } + for (size_t b = 16; b <= 256; b <<= 1) { + test_registered += RegisterSingleTest(b, b, b, 1, false); + test_registered += RegisterSingleTest(b, b, b, 1, true); + } + for (size_t b = 256; b < 320; b += 32) { + test_registered += RegisterSingleTest(b, b, b, 1, true); + } + for (size_t b = 1; b < 96; b++) { + test_registered += RegisterSingleTest(1, b, 32, 1, false); + test_registered += RegisterSingleTest(1, 32, b, 1, true); + test_registered += RegisterSingleTest(1, b, b, 1, false); + if (!Packed) { + test_registered += RegisterSingleTest(1, b, 32, 3, true); + test_registered += RegisterSingleTest(1, 32, b, 5, false); + } + } + // TODO: check why the cosine similary is < 0.99 for this shape alone + // test_registered += RegisterSingleTest(43, 500, 401, 1, true); + test_registered += RegisterSingleTest(1001, 1027, 1031, 1, false); + if (!Packed) { + test_registered += RegisterSingleTest(43, 500, 401, 5, true); + test_registered += RegisterSingleTest(1000, 1029, 1030, 3, false); + } + + return test_registered; + } + + private: + size_t M_, N_, K_, Batch_; + bool hasBias_; +}; + +static size_t SBGemmRegistLongExecute() { + size_t count = 0; + + count += MlasLongExecuteTests>::RegisterLongExecute(); + if (MlasSBGemmPackBSize(128, 128) > 0) { + count += MlasLongExecuteTests>::RegisterLongExecute(); + } + + if (GetMlasThreadPool() != nullptr) { + count += MlasLongExecuteTests>::RegisterLongExecute(); + if (MlasSBGemmPackBSize(128, 128) > 0) { + count += MlasLongExecuteTests>::RegisterLongExecute(); + } + } + + return count; +} + +static size_t SBGemmRegistShortExecute() { + size_t count = 0; + + count += SBGemmShortExecuteTest::RegisterShortExecuteTests(); + if (MlasSBGemmPackBSize(128, 128) > 0) { + count += SBGemmShortExecuteTest::RegisterShortExecuteTests(); + } + + if (GetMlasThreadPool() != nullptr) { + count += SBGemmShortExecuteTest::RegisterShortExecuteTests(); + if (MlasSBGemmPackBSize(128, 128) > 0) { + count += SBGemmShortExecuteTest::RegisterShortExecuteTests(); + } + } + + return count; +} + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { + if (!MlasBf16AccelerationSupported()) { + return false; + } + + if (is_short_execute) { + return SBGemmRegistShortExecute() > 0; + } + return SBGemmRegistLongExecute() > 0; +}); +#endif // defined(__aarch64__) && defined(__linux__) diff --git a/onnxruntime/test/mlas/unittest/test_sbgemm.h b/onnxruntime/test/mlas/unittest/test_sbgemm.h new file mode 100644 index 0000000000000..13701e2e3de46 --- /dev/null +++ b/onnxruntime/test/mlas/unittest/test_sbgemm.h @@ -0,0 +1,281 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the MIT License. + +Module Name: + + test_sbgemm.h + +Abstract: + + Tests for MLAS bf16 precision GEMM. + +--*/ + +#if defined(__aarch64__) && defined(__linux__) + +#pragma once + +#include "test_util.h" + +template +void SmallFloatFill(T* start, size_t size) { + constexpr float MinimumFillValue = -11.0f; + auto FillAddress = start; + size_t offset = size % 23; + + for (size_t i = 0; i < size; i++) { + offset = (offset + 21) % 23; + *FillAddress++ = T((MinimumFillValue + offset) / 16.0f); + } +} + +float cosine_similarity(const float* A, const float* B, size_t Vector_Length) { + float dot = 0.0, denom_a = 0.0, denom_b = 0.0; + for (size_t i = 0u; i < Vector_Length; ++i) { + dot += A[i] * B[i]; + denom_a += A[i] * A[i]; + denom_b += B[i] * B[i]; + } + return dot / (sqrt(denom_a) * sqrt(denom_b)); +} + +/** + * @brief Test class for bf16 precision GEMM + * @tparam AType Data type of A matrix, need to be float + * @tparam BType Data type of b matrix, can be either float or prepacked bf16 + */ +template +class MlasSBGemmTest : public MlasTestBase { + private: + MatrixGuardBuffer BufferBPacked; + MatrixGuardBuffer BufferA; + MatrixGuardBuffer BufferB; + MatrixGuardBuffer BufferBias; + MatrixGuardBuffer BufferC; + MatrixGuardBuffer BufferCReference; + MatrixGuardBuffer BufferFloatC; + MLAS_THREADPOOL* threadpool_; + + void* PackB(size_t N, size_t K, const BType* B, size_t ldb) { + size_t PackedBSize = MlasSBGemmPackBSize(N, K); + if (PackedBSize == 0) { + return nullptr; + } + void* PackedB = BufferBPacked.GetBuffer(PackedBSize); + if (std::is_same::value) { + MlasSBGemmConvertPackB(N, K, (const float*)B, ldb, PackedB); + } else { + } + return PackedB; + } + + void CallSBGemm(size_t M, + size_t N, + size_t K, + size_t BatchSize, + const float* A, + size_t lda, + const BType* B, + size_t ldb, + const float* Bias, + float* C, + size_t ldc) { + std::vector GemmParameters(BatchSize); + + for (size_t i = 0; i < GemmParameters.size(); i++) { + auto& params = GemmParameters[i]; + params.A = A + (M * lda * i); + params.lda = lda; + if (nullptr != Bias) { + params.Bias = reinterpret_cast(Bias + N * i); + } else { + params.Bias = nullptr; + } + params.C = reinterpret_cast(C + (M * ldc * i)); + params.ldc = ldc; + params.AIsfp32 = true; + params.BIsfp32 = true; + + if (Packed) { + ASSERT_EQ(BatchSize, size_t(1)) << "Packing B not supported in batching yet!"; + params.B = PackB(N, K, B, ldb); + params.ldb = 0; + params.BIsfp32 = false; + } else { + params.B = B + (K * N * i); + params.ldb = ldb; + } + } + + MlasSBGemmBatch(M, N, K, BatchSize, GemmParameters.data(), threadpool_); + } + + void ReferenceSgemm(size_t M, + size_t N, + size_t K, + size_t BatchSize, + const AType* A, + const BType* B, + const float* Bias, + float* C) { + constexpr size_t KStride = 256; + + for (size_t batch = 0; batch < BatchSize; batch++) { + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++) { + const AType* a = A + M * K * batch + m * K; + const BType* b = B + K * N * batch + n; + float* c = C + (M * N * batch) + (m * N) + n; + + for (size_t k = 0; k < K; k += KStride) { + float sum = 0.0f; + if (k == 0 && Bias != nullptr) { + sum = float(Bias[n]); + } + for (size_t kk = 0; kk < std::min(KStride, K - k); kk++) { + float down(float(*b) * float(*a) + sum); + sum = float(down); + b += N; + a += 1; + } + if (k == 0) { + *c = sum; + } else { + float d(sum + *c); + *c = float(d); + } + } + } + } + if (Bias) { + Bias += N; + } + } + } + + public: + MlasSBGemmTest() : threadpool_(Threaded ? GetMlasThreadPool() : nullptr) {} + + void Test(size_t M, size_t N, size_t K, size_t BatchSize, bool withBias) { + AType* A = BufferA.GetFilledBuffer(K * M * BatchSize + 16, SmallFloatFill); + AType Atail[16]; + std::memcpy(Atail, A + K * M * BatchSize, 16 * sizeof(AType)); + + BType* B = BufferB.GetFilledBuffer(N * K * BatchSize + 16, SmallFloatFill); + BType Btail[16]; + std::memcpy(Btail, B + N * K * BatchSize, 16 * sizeof(BType)); + + float BiasTail[16]; + const float* Bias = nullptr; + if (withBias) { + Bias = BufferBias.GetFilledBuffer(N * BatchSize + 16, SmallFloatFill); + std::memcpy(BiasTail, Bias + N * BatchSize, 16 * sizeof(float)); + } + + float* C = BufferC.GetFilledBuffer(N * M * BatchSize, SmallFloatFill); + float* CReference = BufferCReference.GetFilledBuffer( + N * M * BatchSize, + [](float* start, size_t size) { + std::fill_n(start, size, -1.0f); + }); + this->CallSBGemm(M, N, K, BatchSize, A, K, B, N, Bias, C, N); + ReferenceSgemm(M, N, K, BatchSize, A, B, Bias, CReference); + const float cosine_similarity_threshold = 0.98; + + for (size_t batch = 0, f = 0; batch < BatchSize; batch++) { + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++, f++) { + if (!(CloseEnough(float(C[f]), CReference[f]))) { + float cos_sim = cosine_similarity(C, CReference, (BatchSize * M * N)); + if (abs(cos_sim) < cosine_similarity_threshold) { + ASSERT_TRUE(false) << "cosine similarity check failed" << cos_sim; + } else { + break; + } + } + } + } + } + + ASSERT_EQ(std::memcmp(Atail, A + K * M * BatchSize, 16 * sizeof(AType)), 0) << "Matrix A buffer overwritten!"; + ASSERT_EQ(std::memcmp(Btail, B + N * K * BatchSize, 16 * sizeof(BType)), 0) << "Matrix B buffer overwritten!"; + if (withBias) { + ASSERT_EQ(std::memcmp(BiasTail, Bias + N * BatchSize, 16 * sizeof(float)), 0) << "Bias buffer overwritten!"; + } + } + + private: + public: + static const char* GetTestSuiteName() { + static std::string suite_name = std::string("SBGemmFP") + + (std::is_same::value ? "32" : "16") + + (std::is_same::value ? "32" : "16") + + (Packed ? "_Packed" : "_NoPack") + + (Threaded ? "_Threaded" : "_SingleThread"); + return suite_name.c_str(); + } + + void ExecuteLong(void) override { + for (size_t M = 16; M < 160; M += 32) { + for (size_t N = 16; N < 160; N += 32) { + static const size_t ks[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 16, 20, 32, 48, 64, 118, 119, 120, 121, 122, 160, 240, 320}; + for (size_t k = 0; k < _countof(ks); k++) { + size_t K = ks[k]; + + Test(M, N, K, 1, false); + Test(M, N, K, 1, true); + Test(M + 1, N, K, 1, false); + Test(M, N + 1, K, 1, true); + Test(M + 1, N + 1, K, 1, false); + Test(M + 3, N + 2, K, 1, true); + Test(M + 4, N, K, 1, false); + Test(M, N + 4, K, 1, true); + Test(M + 4, N + 4, K, 1, false); + Test(M + 3, N + 7, K, 1, true); + Test(M + 8, N, K, 1, false); + Test(M, N + 8, K, 1, true); + Test(M + 12, N + 12, K, 1, false); + Test(M + 13, N, K, 1, true); + Test(M, N + 15, K, 1, false); + Test(M + 15, N + 15, K, 1, false); + if (!Packed) { + Test(M, N, K, 7, false); + Test(M + 3, N, K, 8, true); + Test(M, N + 1, K, 9, false); + Test(M + 12, N, K, 10, true); + Test(M, N + 15, K, 11, false); + Test(M + 15, N + 15, K, 12, true); + } + } + } + printf("M %zd\n", M); + } + + for (size_t M = 1; M < 160; M++) { + for (size_t N = 1; N < 160; N++) { + for (size_t K = 1; K < 160; K++) { + Test(M, N, K, 1, true); + } + } + printf("M %zd\n", M); + } + + for (size_t M = 160; M < 320; M += 24) { + for (size_t N = 112; N < 320; N += 24) { + for (size_t K = 1; K < 16; K++) { + Test(M, N, K, 1, true); + } + for (size_t K = 16; K < 160; K += 32) { + Test(M, N, K, 1, false); + } + } + printf("M %zd\n", M); + } + } +}; + +#endif // defined(__aarch64__) && defined(__linux__) diff --git a/onnxruntime/test/optimizer/qdq_transformer_fastmath_test.cc b/onnxruntime/test/optimizer/qdq_transformer_fastmath_test.cc new file mode 100644 index 0000000000000..ec9f78da14a75 --- /dev/null +++ b/onnxruntime/test/optimizer/qdq_transformer_fastmath_test.cc @@ -0,0 +1,730 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// Licensed under the MIT License. + +#include "core/framework/compute_capability.h" +#include "core/graph/model.h" +#include "core/graph/onnx_protobuf.h" +#include "core/mlas/inc/mlas.h" +#include "core/optimizer/qdq_transformer/qdq_final_cleanup.h" +#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" +#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h" +#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" +#include "core/optimizer/utils.h" +#include "core/providers/partitioning_utils.h" +#include "core/session/onnxruntime_session_options_config_keys.h" +#include "core/session/environment.h" +#include "core/session/inference_session.h" + +#include "test/compare_ortvalue.h" +#include "test/test_environment.h" +#include "test/framework/test_utils.h" +#include "test/util/include/asserts.h" +#include "test/util/include/inference_session_wrapper.h" + +#include "gtest/gtest.h" +#include "graph_transform_test_builder.h" + +#include "qdq_test_utils.h" + +#if defined(__aarch64__) && defined(__linux__) && !defined(DISABLE_CONTRIB_OPS) + +struct QDQOpKeys { + const char* quantize_linear; + const char* dequantize_linear; +}; + +constexpr QDQOpKeys GetQDQOpKeys(bool use_contrib_qdq) { + if (use_contrib_qdq) { + return {"com.microsoft.QuantizeLinear", "com.microsoft.DequantizeLinear"}; + } + return {"QuantizeLinear", "DequantizeLinear"}; +} + +namespace onnxruntime { +namespace test { + +#if !defined(DISABLE_CONTRIB_OPS) + +TEST(QDQTransformerTests, DQ_S8_to_U8_FastMath) { + auto test_case = [](bool use_contrib_qdq) { + const std::vector& input_shape = {19, 37}; + const std::vector& weights_shape = {37, 23}; + + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput(input_shape, -1.f, 1.f); + + // Use full range weight values to expose u8s8 overflow problems + auto* weight = builder.MakeInitializer(weights_shape, -128, 127); + auto* output_arg = builder.MakeOutput(); + + // add QDQ activation + typedef std::numeric_limits Input1Limits; + auto* dq1_output = AddQDQNodePair(builder, input1_arg, .039f, + (int8_t)((Input1Limits::max() + Input1Limits::min()) / 2 + 1), + use_contrib_qdq); + + // add DQ weight + auto* dq_w_output = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(weight, .003f, -10, dq_w_output, use_contrib_qdq); + + builder.AddNode("MatMul", {dq1_output, dq_w_output}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + EXPECT_EQ(op_to_count["com.microsoft.MatMulIntegerToFloat"], 1); + EXPECT_EQ(op_to_count["MatMul"], 0); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 1); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + }; + + auto add_session_options = [&](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "1")); + }; + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 12 /*opset_version*/, + NAN /*per_sample_tolerance*/, /*using NAN as a magic number to trigger cosine similarity*/ + NAN /*relative_per_sample_tolerance*/, + nullptr, add_session_options); + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 18 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + nullptr, add_session_options); + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 19 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + nullptr, add_session_options); + + auto add_session_options_disable_fm = [&](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "0")); + }; + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 12 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + nullptr, add_session_options_disable_fm); + }; + + test_case(false); // Use ONNX QDQ ops + test_case(true); // Use com.microsoft QDQ ops +} + +template +void QDQTransformerMatMulTests(bool has_output_q, bool disable_fastmath = false) { + auto test_case = [&](const std::vector& input1_shape, const std::vector& input2_shape, + bool use_contrib_qdq = false) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput(input1_shape, -1.f, 1.f); + auto* input2_arg = builder.MakeInput(input2_shape, -1.f, 1.f); + auto* output_arg = builder.MakeOutput(); + + typedef std::numeric_limits Input1Limits; + typedef std::numeric_limits Input2Limits; + typedef std::numeric_limits OutputTypeLimits; + + // add QDQ 1 + auto* q1_output = builder.MakeIntermediate(); + auto* dq1_output = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(input1_arg, + .039f, + (Input1Limits::max() + Input1Limits::min()) / 2 + 1, + q1_output, use_contrib_qdq); + builder.AddDequantizeLinearNode(q1_output, + .039f, + (Input2Limits::max() + Input1Limits::min()) / 2 + 1, + dq1_output, use_contrib_qdq); + + // add QDQ 2 + auto* q2_output = builder.MakeIntermediate(); + auto* dq2_output = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(input2_arg, + .04f, + (Input2Limits::max() + Input2Limits::min()) / 2 + 1, + q2_output, use_contrib_qdq); + builder.AddDequantizeLinearNode(q2_output, + .04f, + (Input2Limits::max() + Input2Limits::min()) / 2 + 1, + dq2_output, use_contrib_qdq); + + if (has_output_q) { + // add binary operator + auto* matmul_op_output = builder.MakeIntermediate(); + builder.AddNode("MatMul", {dq1_output, dq2_output}, {matmul_op_output}); + + // add QDQ output + auto* q3_output = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(matmul_op_output, + .039f, + (OutputTypeLimits::max() + OutputTypeLimits::min()) / 2 + 1, + q3_output, use_contrib_qdq); + builder.AddDequantizeLinearNode(q3_output, + .039f, + (OutputTypeLimits::max() + OutputTypeLimits::min()) / 2 + 1, + output_arg, use_contrib_qdq); + } else { + builder.AddNode("MatMul", {dq1_output, dq2_output}, {output_arg}); + } + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + if (has_output_q) { + if constexpr (std::is_same::value && + (std::is_same::value || + QDQIsInt8Allowed() && std::is_same::value)) { + EXPECT_EQ(op_to_count["QLinearMatMul"], 1); + EXPECT_EQ(op_to_count["MatMul"], 0); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 2); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 1); + } else { + EXPECT_EQ(op_to_count["QLinearMatMul"], 0); + EXPECT_EQ(op_to_count["MatMul"], 1); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 3); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 3); + } + } else { + if constexpr (std::is_same::value || + (QDQIsInt8Allowed() && std::is_same::value)) { + EXPECT_EQ(op_to_count["com.microsoft.MatMulIntegerToFloat"], 1); + EXPECT_EQ(op_to_count["MatMul"], 0); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 2); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + } else { + EXPECT_EQ(op_to_count["com.microsoft.MatMulIntegerToFloat"], 0); + EXPECT_EQ(op_to_count["MatMul"], 1); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 2); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 2); + } + } + }; + + auto add_session_options = [&](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "1")); + }; + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 12 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + std::make_unique(QDQIsInt8Allowed()), + add_session_options); + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 18 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + std::make_unique(QDQIsInt8Allowed()), + add_session_options); + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 19 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + std::make_unique(QDQIsInt8Allowed()), + add_session_options); + + if (disable_fastmath) { + auto add_session_options = [&](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "0")); + }; + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 12 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + std::make_unique(QDQIsInt8Allowed()), + add_session_options); + } + }; + + test_case({1, 2, 2}, {1, 2, 4}); + test_case({1, 23, 13, 13}, {13, 13}); + test_case({1, 22, 11, 13, 15}, {1, 22, 11, 15, 15}); + test_case({1, 2, 2}, {1, 2, 4}, true); // Use com.microsoft QDQ ops +} + +TEST(QDQTransformerTests, MatMul_U8U8U8_FastMath) { + QDQTransformerMatMulTests(false); + QDQTransformerMatMulTests(true); +} + +TEST(QDQTransformerTests, MatMul_U8S8S8_FastMath) { + QDQTransformerMatMulTests(false); + QDQTransformerMatMulTests(true); +} + +TEST(QDQTransformerTests, MatMul_U8U8S8_FastMath) { + QDQTransformerMatMulTests(false); + QDQTransformerMatMulTests(true); +} + +TEST(QDQTransformerTests, MatMul_U8S8U8_FastMath) { + QDQTransformerMatMulTests(false); + QDQTransformerMatMulTests(true); +} + +TEST(QDQTransformerTests, MatMul_S8S8S8_FastMath) { + QDQTransformerMatMulTests(false); + QDQTransformerMatMulTests(true); +} + +TEST(QDQTransformerTests, MatMul_S8U8U8_FastMath) { + QDQTransformerMatMulTests(false); + QDQTransformerMatMulTests(true); +} + +TEST(QDQTransformerTests, MatMul_S8U8S8_FastMath) { + QDQTransformerMatMulTests(false); + QDQTransformerMatMulTests(true); +} + +TEST(QDQTransformerTests, MatMul_S8S8U8_FastMath) { + QDQTransformerMatMulTests(false); + QDQTransformerMatMulTests(true); +} + +// dummy test to disable the fastmath session op +TEST(QDQTransformerTests, MatMul_S8S8U8_DisableFastMath) { + QDQTransformerMatMulTests(false, true); + QDQTransformerMatMulTests(true, true); +} + +template +void QDQTransformerGemmTests(bool has_output_q, bool has_bias, bool beta_not_one = false, bool disable_fastmath = false) { + auto test_case = [&](const std::vector& input1_shape, const std::vector& input2_shape, + bool use_contrib_qdq = false) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput(input1_shape, -1.f, 1.f); + auto* input2_arg = builder.MakeInput(input2_shape, -1.f, 1.f); + auto* output_arg = builder.MakeOutput(); + + typedef std::numeric_limits Input1Limits; + typedef std::numeric_limits Input2Limits; + typedef std::numeric_limits OutputTypeLimits; + + std::vector input_args; + + // add QDQ A + auto* q1_output = builder.MakeIntermediate(); + auto* dq1_output = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(input1_arg, + .039f, + (Input1Limits::max() + Input1Limits::min()) / 2 + 1, + q1_output, use_contrib_qdq); + builder.AddDequantizeLinearNode(q1_output, + .039f, + (Input2Limits::max() + Input1Limits::min()) / 2 + 1, + dq1_output, use_contrib_qdq); + + input_args.push_back(dq1_output); + + // add QDQ B + auto* q2_output = builder.MakeIntermediate(); + auto* dq2_output = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(input2_arg, + .04f, + (Input2Limits::max() + Input2Limits::min()) / 2 + 1, + q2_output, use_contrib_qdq); + builder.AddDequantizeLinearNode(q2_output, + .04f, + (Input2Limits::max() + Input2Limits::min()) / 2 + 1, + dq2_output, use_contrib_qdq); + input_args.push_back(dq2_output); + + if (has_bias) { + auto* dq_bias_output = builder.MakeIntermediate(); + auto* bias = builder.MakeInitializer({input2_shape[1]}, static_cast(0), static_cast(127)); + builder.AddDequantizeLinearNode(bias, 0.00156f, + 0, + dq_bias_output, use_contrib_qdq); + input_args.push_back(dq_bias_output); + } + + Node* gemm_node = nullptr; + + if (has_output_q) { + auto* gemm_op_output = builder.MakeIntermediate(); + gemm_node = &builder.AddNode("Gemm", input_args, {gemm_op_output}); + + // add QDQ output + auto* q3_output = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(gemm_op_output, + .039f, + (OutputTypeLimits::max() + OutputTypeLimits::min()) / 2 + 1, + q3_output, use_contrib_qdq); + builder.AddDequantizeLinearNode(q3_output, + .039f, + (OutputTypeLimits::max() + OutputTypeLimits::min()) / 2 + 1, + output_arg, use_contrib_qdq); + } else { + gemm_node = &builder.AddNode("Gemm", input_args, {output_arg}); + } + + if (beta_not_one) { + gemm_node->AddAttribute("beta", 2.0f); + } + }; + + auto check_binary_op_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + if ((!has_output_q || std::is_same_v)&&(!has_bias || (std::is_same_v && !beta_not_one)) && + (std::is_same_v || std::is_same_v)) { + EXPECT_EQ(op_to_count["com.microsoft.QGemm"], 1); + EXPECT_EQ(op_to_count["Gemm"], 0); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 2); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], has_output_q ? 1 : 0); + } else { + int q_count = 2; // Q for A and B + int dq_count = 2; // DQ for A and B + if (has_bias) { + dq_count++; + } + if (has_output_q) { + q_count++; + dq_count++; + } + EXPECT_EQ(op_to_count["com.microsoft.QGemm"], 0); + EXPECT_EQ(op_to_count["Gemm"], 1); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], q_count); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], dq_count); + } + }; + + auto add_session_options = [&](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "1")); + }; + + TransformerTester(build_test_case, + check_binary_op_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 12 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + std::make_unique(QDQIsInt8Allowed()), + add_session_options); + TransformerTester(build_test_case, + check_binary_op_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 18 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + std::make_unique(QDQIsInt8Allowed()), + add_session_options); + TransformerTester(build_test_case, + check_binary_op_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 19 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + std::make_unique(QDQIsInt8Allowed()), + add_session_options); + + if (disable_fastmath) { + auto add_session_options = [&](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "0")); + }; + + TransformerTester(build_test_case, + check_binary_op_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 12 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + std::make_unique(QDQIsInt8Allowed()), + add_session_options); + } + }; + + test_case({2, 2}, {2, 4}); + test_case({13, 15}, {15, 15}); + test_case({2, 2}, {2, 4}, true); // Use com.microsoft QDQ ops +} + +template +void QDQTransformerGemmTests() { + QDQTransformerGemmTests(false, false); + QDQTransformerGemmTests(false, true); + QDQTransformerGemmTests(true, false); + QDQTransformerGemmTests(true, true); + QDQTransformerGemmTests(false, false, true); + QDQTransformerGemmTests(false, true, true); + QDQTransformerGemmTests(true, false, true); + QDQTransformerGemmTests(true, true, true); + // dummy test to disable the fastmath session + QDQTransformerGemmTests(true, true, true, true); +} + +TEST(QDQTransformerTests, Gemm_U8U8U8_FastMath) { + QDQTransformerGemmTests(); + QDQTransformerGemmTests(); +} + +TEST(QDQTransformerTests, Gemm_U8S8S8_FastMath) { + QDQTransformerGemmTests(); + QDQTransformerGemmTests(); +} + +TEST(QDQTransformerTests, Gemm_U8U8S8_FastMath) { + QDQTransformerGemmTests(); + QDQTransformerGemmTests(); +} + +TEST(QDQTransformerTests, Gemm_U8S8U8_FastMath) { + QDQTransformerGemmTests(); + QDQTransformerGemmTests(); +} + +TEST(QDQTransformerTests, Gemm_S8S8S8_FastMath) { + QDQTransformerGemmTests(); + QDQTransformerGemmTests(); +} + +TEST(QDQTransformerTests, Gemm_S8U8U8_FastMath) { + QDQTransformerGemmTests(); + QDQTransformerGemmTests(); +} + +TEST(QDQTransformerTests, Gemm_S8U8S8_FastMath) { + QDQTransformerGemmTests(); + QDQTransformerGemmTests(); +} + +TEST(QDQTransformerTests, Gemm_S8S8U8_FastMath) { + QDQTransformerGemmTests(); + QDQTransformerGemmTests(); +} + +TEST(QDQTransformerTests, MatMul_No_Fusion_FastMath) { + auto test_case = [&](const std::vector& input1_shape, const std::vector& input2_shape, + bool use_contrib_qdq) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput(input1_shape, -1.f, 1.f); + auto* input2_arg = builder.MakeInput(input2_shape, -1.f, 1.f); + auto* output_arg = builder.MakeOutput(); + + // add QDQ + MatMul + auto* matmul_output = builder.MakeIntermediate(); + auto* dq_matmul_output1 = AddQDQNodePair(builder, input1_arg, .004f, 129, use_contrib_qdq); + builder.AddNode("MatMul", {dq_matmul_output1, input2_arg}, {matmul_output}); + + // add Q + builder.AddQuantizeLinearNode(matmul_output, .0039f, 135, output_arg, use_contrib_qdq); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + EXPECT_EQ(op_to_count["MatMul"], 1); + EXPECT_EQ(op_to_count["QLinearMatMul"], 0); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 2); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 1); + }; + + auto add_session_options = [&](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "1")); + }; + + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, + 12 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + nullptr, add_session_options); + + auto add_session_options_disable_fm = [&](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "0")); + }; + + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, + 12 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + nullptr, add_session_options_disable_fm); + }; + + test_case({12, 37}, {37, 12}, false /*use_contrib_qdq*/); + test_case({12, 37}, {37, 12}, true /*use_contrib_qdq*/); +} + +TEST(QDQTransformerTests, MatMul_1st_Input_Int8_FastMath) { + auto test_case = [&](const std::vector& input1_shape, const std::vector& input2_shape, + bool use_contrib_qdq) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput(input1_shape, -128, 127); + auto* input2_arg = builder.MakeInput(input2_shape, -1.f, 1.f); + auto* output_arg = builder.MakeOutput(); + + // add DQ with type int8 + auto* dq_output_1 = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(input1_arg, .004f, 1, dq_output_1, use_contrib_qdq); + + // add QDQ + MatMul + auto* matmul_output = builder.MakeIntermediate(); + auto* dq_matmul_output2 = AddQDQNodePair(builder, input2_arg, .004f, 129, use_contrib_qdq); + builder.AddNode("MatMul", {dq_output_1, dq_matmul_output2}, {matmul_output}); + + // add Q + builder.AddQuantizeLinearNode(matmul_output, .0039f, 135, output_arg, use_contrib_qdq); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + EXPECT_EQ(op_to_count["MatMul"], 1); + EXPECT_EQ(op_to_count["QLinearMatMul"], 0); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 2); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 2); + }; + + auto add_session_options = [&](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "1")); + }; + + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, + 12 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + nullptr, add_session_options); + + auto add_session_options_disable_fm = [&](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "0")); + }; + + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, + 12 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + nullptr, add_session_options_disable_fm); + }; + + test_case({12, 37}, {37, 12}, false /*use_contrib_qdq*/); + test_case({12, 37}, {37, 12}, true /*use_contrib_qdq*/); + test_case({23, 13, 13}, {13, 13}, false /*use_contrib_qdq*/); + test_case({22, 11, 13, 15}, {15, 13}, false /*use_contrib_qdq*/); +} + +TEST(QDQTransformerTests, MatMulIntegerToFloat_FastMath) { + auto test_case = [&](const std::vector& input1_shape, const std::vector& input2_shape, + bool use_contrib_qdq) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput(input1_shape, + std::numeric_limits::min(), + std::numeric_limits::max()); + auto* input2_arg = builder.MakeInput(input2_shape, + std::numeric_limits::min(), + std::numeric_limits::max()); + auto* output_arg = builder.MakeOutput(); + + // add DQ + auto* dq_output_1 = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(input1_arg, .0035f, 135, dq_output_1, use_contrib_qdq); + + auto* dq_output_2 = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(input2_arg, .0035f, 135, dq_output_2, use_contrib_qdq); + + builder.AddNode("MatMul", {dq_output_1, dq_output_2}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + EXPECT_EQ(op_to_count["com.microsoft.MatMulIntegerToFloat"], 1); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + }; + + auto add_session_options = [&](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "1")); + }; + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 12 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + nullptr, + add_session_options); + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 19 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + nullptr, + add_session_options); + + auto add_session_options_disable_fm = [&](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "0")); + }; + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 12 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + nullptr, + add_session_options_disable_fm); + }; + + test_case({12, 37}, {37, 12}, false /*use_contrib_qdq*/); + test_case({12, 37}, {37, 12}, true /*use_contrib_qdq*/); + test_case({23, 13, 13}, {13, 13}, false /*use_contrib_qdq*/); + test_case({22, 11, 13, 15}, {15, 13}, false /*use_contrib_qdq*/); +} + +#endif // !defined(DISABLE_CONTRIB_OPS) && defined(__aarch64) + +} // namespace test +} // namespace onnxruntime + +#endif // defined(__aarch64) && defined(__linux__) && !defined(DISABLE_CONTRIB_OPS) diff --git a/onnxruntime/test/providers/cpu/math/matmul_fastmath_test.cc b/onnxruntime/test/providers/cpu/math/matmul_fastmath_test.cc new file mode 100644 index 0000000000000..75e0c06b04f0d --- /dev/null +++ b/onnxruntime/test/providers/cpu/math/matmul_fastmath_test.cc @@ -0,0 +1,305 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// Licensed under the MIT License. + +#include "core/session/onnxruntime_session_options_config_keys.h" +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" +#include "test/providers/run_options_config_keys.h" +#include "test/common/dnnl_op_test_utils.h" +#include "test/common/cuda_op_test_utils.h" +#include "test/common/tensor_op_test_utils.h" +#include "default_providers.h" + +#if defined(__aarch64__) && defined(__linux__) + +namespace onnxruntime { +namespace test { + +namespace { + +const onnxruntime::RunOptions run_options = []() { + onnxruntime::RunOptions options{}; + ORT_THROW_IF_ERROR(options.config_options.AddConfigEntry(kOpTesterRunOptionsConfigTestTunableOp, "true")); + return options; +}(); + +const constexpr auto run_with_tunable_op = &run_options; + +} // namespace + +template +struct MatMulTestData { + std::string name; + std::vector input0_dims; + std::vector input1_dims; + std::vector expected_dims; + std::vector expected_vals; +}; + +template +std::vector> GenerateTestCases() { + std::vector> test_cases; + test_cases.push_back( + {"test padding and broadcast A > B", + {3, 1, 1, 6}, + {2, 6, 7}, + {3, 2, 1, 7}, + {385, 400, 415, 430, 445, 460, 475, 1015, 1030, 1045, 1060, 1075, 1090, 1105, 1015, 1066, 1117, 1168, 1219, 1270, 1321, 3157, 3208, 3259, 3310, 3361, 3412, 3463, 1645, 1732, 1819, 1906, 1993, 2080, 2167, 5299, 5386, 5473, 5560, 5647, 5734, 5821}}); + + test_cases.push_back( + {"test padding and broadcast B > A", + {2, 3, 12}, + {3, 2, 12, 3}, + {3, 2, 3, 3}, + {1518, 1584, 1650, 3894, 4104, 4314, 6270, 6624, 6978, 26574, 27072, 27570, 34134, 34776, 35418, 41694, 42480, 43266, 6270, 6336, 6402, 19014, 19224, 19434, 31758, 32112, 32466, 62430, 62928, 63426, 80358, 81000, 81642, 98286, 99072, 99858, 11022, 11088, 11154, 34134, 34344, 34554, 57246, 57600, 57954, 98286, 98784, 99282, 126582, 127224, 127866, 154878, 155664, 156450}}); + + test_cases.push_back( + {"test 2D", + {8, 6}, + {6, 6}, + {8, 6}, + {330, 345, 360, 375, 390, 405, 870, 921, 972, 1023, 1074, 1125, 1410, 1497, 1584, 1671, 1758, 1845, 1950, 2073, 2196, 2319, 2442, 2565, 2490, 2649, 2808, 2967, 3126, 3285, 3030, 3225, 3420, 3615, 3810, 4005, 3570, 3801, 4032, 4263, 4494, 4725, 4110, 4377, 4644, 4911, 5178, 5445}}); + + test_cases.push_back( + {"test 2D special", + {2, 2, 16}, + {16, 4}, + {2, 2, 4}, + {4960, 5080, 5200, 5320, 12640, 13016, 13392, 13768, 20320, 20952, 21584, 22216, 28000, 28888, 29776, 30664}}); + + test_cases.push_back( + {"test 2D special 2", + {2, 2, 9}, + {1, 9, 4}, + {2, 2, 4}, + {816, 852, 888, 924, 2112, 2229, 2346, 2463, 3408, 3606, 3804, 4002, 4704, 4983, 5262, 5541}}); + + test_cases.push_back( + {"test 2D special 3", + {2, 12}, + {1, 1, 12, 3}, + {1, 1, 2, 3}, + {1518, 1584, 1650, 3894, 4104, 4314}}); + + test_cases.push_back( + {"test 3D batch", + {3, 1, 18}, + {3, 18, 2}, + {3, 1, 2}, + { + // clang-format off + 3570, 3723, + 26250, 26727, + 72258, 73059, + // clang-format on + }}); + + test_cases.push_back( + {"test 4D batch", + {2, 2, 1, 20}, + {2, 2, 20, 2}, + {2, 2, 1, 2}, + { + // clang-format off + 4940, 5130, + 36140, 36730, + 99340, 100330, + 194540, 195930, + // clang-format on + }}); + + return test_cases; +} + +template +void RunMatMulTest(int32_t opset_version, bool is_a_constant, bool is_b_constant, bool disable_fastmath) { + for (auto t : GenerateTestCases()) { + SCOPED_TRACE("test case: " + t.name); + + OpTester test("MatMul", opset_version); + + int64_t size0 = TensorShape::FromExistingBuffer(t.input0_dims).SizeHelper(0, t.input0_dims.size()); + std::vector input0_vals = ValueRange(size0); + + test.AddInput("A", t.input0_dims, input0_vals, is_a_constant); + + int64_t size1 = TensorShape::FromExistingBuffer(t.input1_dims).SizeHelper(0, t.input1_dims.size()); + std::vector input1_vals = ValueRange(size1); + test.AddInput("B", t.input1_dims, input1_vals, is_b_constant); + + test.AddOutput("Y", t.expected_dims, t.expected_vals); + + // OpenVINO EP: Disabled temporarily matmul broadcasting not fully supported + // Disable TensorRT because of unsupported data type + std::unordered_set excluded_providers{kTensorrtExecutionProvider, kOpenVINOExecutionProvider}; + if (t.name == "test 2D empty input") { + // NNAPI: currently fails for the "test 2D empty input" case + excluded_providers.insert(kNnapiExecutionProvider); + } + + if ("test padding and broadcast A > B" == t.name || "test 2D empty input" == t.name) { + // QNN can't handle 0 shap + excluded_providers.insert(kQnnExecutionProvider); + } + + SessionOptions so; + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "1")); + + test.ConfigExcludeEps(excluded_providers) + .Config(run_with_tunable_op) + .Config(so) + .RunWithConfig(); + + if (disable_fastmath) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "0")); + + test.ConfigExcludeEps(excluded_providers) + .Config(run_with_tunable_op) + .Config(so) + .RunWithConfig(); + } + } +} + +template +void RunMatMulTest(int32_t opset_version) { + RunMatMulTest(opset_version, false, false, false); +} + +TEST(MathOpTest, MatMulFloatType_FastMath) { + // TODO: Unskip when fixed #41968513 + if (DefaultDmlExecutionProvider().get() != nullptr) { + GTEST_SKIP() << "Skipping because of the following error: Assertion failed: m_bufferTensorDesc.TotalTensorSizeInBytes >= ComputeByteSizeFromDimensions(nonBroadcastDimensions, dataType)"; + } + RunMatMulTest(7, false, false, false); +} + +TEST(MathOpTest, MatMulFloatTypeInitializer_FastMath) { + // TODO: Unskip when fixed #41968513 + if (DefaultDmlExecutionProvider().get() != nullptr) { + GTEST_SKIP() << "Skipping because of the following error: Assertion failed: m_bufferTensorDesc.TotalTensorSizeInBytes >= ComputeByteSizeFromDimensions(nonBroadcastDimensions, dataType)"; + } + RunMatMulTest(7, false, true, false); +} + +TEST(MathOpTest, MatMulInt32Type_FastMath) { + RunMatMulTest(9); +} + +TEST(MathOpTest, MatMulUint32Type_FastMath) { + RunMatMulTest(9); +} + +TEST(MathOpTest, MatMulInt64Type_FastMath) { + RunMatMulTest(9); +} + +TEST(MathOpTest, MatMulUint64Type_FastMath) { + RunMatMulTest(9); +} + +#ifndef ENABLE_TRAINING +// Prepacking is disabled in full training build so no need to test the feature in a training build. +TEST(MathOpTest, MatMulSharedPrepackedWeights_FastMath) { + OpTester test("MatMul"); + + std::vector b_init_values(32, 1.0f); + test.AddInput("A", {8, 4}, + {1.0f, 2.0f, 3.0f, 4.0f, + -1.0f, -2.0f, -3.0f, -4.0f, + 1.0f, 2.0f, 3.0f, 4.0f, + -1.0f, -2.0f, -3.0f, -4.0f, + 1.0f, 2.0f, 3.0f, 4.0f, + -1.0f, -2.0f, -3.0f, -4.0f, + 1.0f, 2.0f, 3.0f, 4.0f, + -1.0f, -2.0f, -3.0f, -4.0f}); + // B is to be an initializer for triggering pre-packing + test.AddInput("B", {4, 8}, b_init_values, true); + + test.AddOutput("Y", {8, 8}, + {10.0f, 10.0f, 10.0f, 10.0f, 10.0f, 10.0f, 10.0f, 10.0f, + -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, + 10.0f, 10.0f, 10.0f, 10.0f, 10.0f, 10.0f, 10.0f, 10.0f, + -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, + 10.0f, 10.0f, 10.0f, 10.0f, 10.0f, 10.0f, 10.0f, 10.0f, + -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, + 10.0f, 10.0f, 10.0f, 10.0f, 10.0f, 10.0f, 10.0f, 10.0f, + -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, -10.0f}); + + OrtValue b; + Tensor::InitOrtValue(DataTypeImpl::GetType(), TensorShape({4, 8}), + b_init_values.data(), OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator), b); + + SessionOptions so; + // Set up B as a shared initializer to be shared between sessions + ASSERT_EQ(so.AddInitializer("B", &b), Status::OK()); + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "1")); + + // We want all sessions running using this OpTester to be able to share pre-packed weights if applicable + test.EnableSharingOfPrePackedWeightsAcrossSessions(); + + // Pre-packing is limited just to the CPU EP for now and we will only test the CPU EP + // and we want to ensure that it is available in this build + auto cpu_ep = []() -> std::vector> { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + return execution_providers; + }; + + size_t number_of_pre_packed_weights_counter_session_1 = 0; + size_t number_of_shared_pre_packed_weights_counter = 0; + + // Session 1 + { + test.Config(so) + .Config(run_with_tunable_op) + .ConfigEps(cpu_ep()) + .RunWithConfig(&number_of_pre_packed_weights_counter_session_1, &number_of_shared_pre_packed_weights_counter); + // Assert that no pre-packed weights have been shared thus far + ASSERT_EQ(number_of_shared_pre_packed_weights_counter, static_cast(0)); + } + + auto number_of_elements_in_shared_prepacked_buffers_container = + test.GetNumPrePackedWeightsShared(); + // Assert that the number of elements in the shared container + // is the same as the number of weights that have been pre-packed + ASSERT_EQ(number_of_pre_packed_weights_counter_session_1, number_of_elements_in_shared_prepacked_buffers_container); + + // On some platforms/architectures MLAS may choose to not do any pre-packing and the number of elements + // that have been pre-packed will be zero in which case we do not continue with the testing + // of "sharing" of pre-packed weights as there are no pre-packed weights to be shared at all. + if (number_of_pre_packed_weights_counter_session_1 == 0) + return; + + // Session 2 + { + size_t number_of_pre_packed_weights_counter_session_2 = 0; + test.Config(so) + .Config(run_with_tunable_op) + .ConfigEps(cpu_ep()) + .RunWithConfig(&number_of_pre_packed_weights_counter_session_2, &number_of_shared_pre_packed_weights_counter); + + // Assert that the same number of weights were pre-packed in both sessions + ASSERT_EQ(number_of_pre_packed_weights_counter_session_1, number_of_pre_packed_weights_counter_session_2); + + // Assert that the number of pre-packed weights that were shared equals + // the number of pre-packed weights in the second session + ASSERT_EQ(number_of_pre_packed_weights_counter_session_2, + static_cast(number_of_shared_pre_packed_weights_counter)); + } +} + +#endif + +// Dummy run to disable the FastMath mode for the current session +TEST(MathOpTest, MatMulUint64Type_DisableFastMath) { + RunMatMulTest(9, false, false, true); +} + +} // namespace test +} // namespace onnxruntime +#endif // defined(__aarch64__) && defined(__linux__) diff --git a/onnxruntime/test/util/compare_ortvalue.cc b/onnxruntime/test/util/compare_ortvalue.cc index 3d53d4a3a0193..64ebe24188762 100644 --- a/onnxruntime/test/util/compare_ortvalue.cc +++ b/onnxruntime/test/util/compare_ortvalue.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. // Licensed under the MIT License. #include "test/compare_ortvalue.h" @@ -65,6 +66,54 @@ const char* ElementTypeToString(MLDataType type) { return DataTypeImpl::ToString(type); } +#if defined(__aarch64__) && defined(__linux__) +template +std::pair CheckCosineSimilarity(const Tensor& outvalue, const Tensor& expected_value) { + const size_t tensor_size = static_cast(expected_value.Shape().Size()); + const T* expected_output = expected_value.Data(); + const T* real_output = outvalue.Data(); + std::pair res = std::make_pair(COMPARE_RESULT::SUCCESS, ""); + const T cosine_similarity_threshold = 0.99f; + + T dot = 0.0f, denom_a = 0.0f, denom_b = 0.0f; + for (size_t i = 0u; i < tensor_size; ++i) { + if (isnan(expected_output[i]) && isnan(real_output[i])) + continue; + if (isinf(expected_output[i]) && isinf(real_output[i])) + continue; + dot += expected_output[i] * real_output[i]; + denom_a += expected_output[i] * expected_output[i]; + denom_b += real_output[i] * real_output[i]; + } + + T cos_factor = abs(dot / (sqrt(denom_a) * sqrt(denom_b))); + if (cos_factor < cosine_similarity_threshold) { + res.first = COMPARE_RESULT::RESULT_DIFFERS; + std::ostringstream oss; + oss << std::hex << "results differed, cosine similarity factor is " << cos_factor << "."; + res.second = oss.str(); + } + return res; +} + +template +std::pair CheckCloseMatch(const Tensor& outvalue, const Tensor& expected_value) { + const size_t size1 = static_cast(expected_value.Shape().Size()); + const T* expected_output = expected_value.Data(); + const T* real_output = outvalue.Data(); + const T close_match_threshold = 1.0; + + for (size_t di = 0; di != size1; ++di) { + const T diff = expected_output[di] - real_output[di]; + if (std::fabs(diff) > close_match_threshold) { + std::ostringstream oss; + oss << "expected " << expected_output[di] << ", got " << real_output[di]; + return std::make_pair(COMPARE_RESULT::RESULT_DIFFERS, oss.str()); + } + } + return std::make_pair(COMPARE_RESULT::SUCCESS, ""); +} +#endif /** * @brief Check if two values are closely matched with given tolerance. @@ -207,6 +256,37 @@ std::pair CompareTwoTensors(const Tensor& outvalue, oss << "shape mismatch, expect " << expected_tensor.Shape().ToString() << " got " << outvalue.Shape().ToString(); return std::make_pair(COMPARE_RESULT::SHAPE_MISMATCH, oss.str()); } + +#if defined(__aarch64__) && defined(__linux__) + if (isnan(per_sample_tolerance) || isnan(per_sample_tolerance)) { + if (outvalue.IsDataType()) { + return CheckCosineSimilarity(outvalue, expected_tensor); + } else if (outvalue.IsDataType()) { + return CheckCosineSimilarity(outvalue, expected_tensor); + } else if (outvalue.IsDataType()) { + return CheckCloseMatch(outvalue, expected_tensor); + } else if (outvalue.IsDataType()) { + return CheckCloseMatch(outvalue, expected_tensor); + } else if (outvalue.IsDataType()) { + return CheckCloseMatch(outvalue, expected_tensor); + } else if (outvalue.IsDataType()) { + return CheckCloseMatch(outvalue, expected_tensor); + } else if (outvalue.IsDataType()) { + return CheckCloseMatch(outvalue, expected_tensor); + } else if (outvalue.IsDataType()) { + return CheckCloseMatch(outvalue, expected_tensor); + } else if (outvalue.IsDataType()) { + return CheckCloseMatch(outvalue, expected_tensor); + } else if (outvalue.IsDataType()) { + return CheckCloseMatch(outvalue, expected_tensor); + } else if (outvalue.IsDataType()) { + return CheckCloseMatch(outvalue, expected_tensor); + } else { + return std::make_pair(COMPARE_RESULT::NOT_SUPPORT, ""); + } + } +#endif + if (outvalue.IsDataType()) { return CompareFloatResult(outvalue, expected_tensor, per_sample_tolerance, relative_per_sample_tolerance, post_processing);