
Commit

Merge remote-tracking branch 'origin/main' into snnn/fix_333
snnn committed Jan 22, 2024
2 parents a00c5bf + 77da2ef commit 60549f4
Showing 30 changed files with 4,024 additions and 136 deletions.
4 changes: 4 additions & 0 deletions cmake/onnxruntime_mlas.cmake
@@ -355,19 +355,23 @@ else()
${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S
${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S
${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S
${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S
${MLAS_SRC_DIR}/activate_fp16.cpp
${MLAS_SRC_DIR}/dwconv.cpp
${MLAS_SRC_DIR}/halfgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/pooling_fp16.cpp
${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp
${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp
${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp
)
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
set_source_files_properties(${MLAS_SRC_DIR}/activate_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
endif()

if(ONNXRUNTIME_MLAS_MULTI_ARCH)
8 changes: 8 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -3608,6 +3608,14 @@ struct OrtApi {
* - "1": Faster preparation time, less optimal graph.
* - "2": Longer preparation time, more optimal graph.
* - "3": Longest preparation time, most likely even more optimal graph. See QNN SDK documentation for specific details.
* "soc_model": The SoC model number. Refer to the QNN SDK documentation for valid values. Defaults to "0" (unknown).
* "htp_arch": The minimum HTP architecture the driver will use to select compatible QNN operators. Available options:
* - "0": Default (none).
* - "68"
* - "69"
* - "73"
* - "75"
* "device_id": The ID of the device to use when setting 'htp_arch'. Defaults to "0" (for single device).
*
* SNPE supported keys:
* "runtime": SNPE runtime engine, options: "CPU", "CPU_FLOAT32", "GPU", "GPU_FLOAT32_16_HYBRID", "GPU_FLOAT16",
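The three new QNN keys above are ordinary provider options. A minimal sketch of setting them through the public C++ wrapper (Ort::SessionOptions::AppendExecutionProvider); the backend path and option values here are illustrative assumptions, not taken from this commit:

#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env;
  Ort::SessionOptions so;
  so.AppendExecutionProvider("QNN", {
      {"backend_path", "libQnnHtp.so"},  // assumed HTP backend library
      {"soc_model", "0"},                // "0" = unknown (default)
      {"htp_arch", "73"},                // minimum HTP architecture to target
      {"device_id", "0"},                // device used when setting htp_arch
  });
  // Ort::Session session(env, "model.onnx", so);  // hypothetical model path
  return 0;
}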
include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
@@ -249,4 +249,10 @@ static const char* const kOrtSessionOptionEpContextFilePath = "ep.context_file_p
// Flag to specify whether to dump the EP context into the Onnx model.
// "0": dump the EP context into separate file, keep the file name in the Onnx model.
// "1": dump the EP context into the Onnx model. (default).
static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed_mode";

// Gemm fastmath mode provides fp32 gemm acceleration with bfloat16 based matmul.
// Option values:
// - "0": Gemm FastMath mode is not enabled. [DEFAULT]
// - "1": Gemm FastMath mode is enabled.
static const char* const kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16 = "mlas.enable_gemm_fastmath_arm64_bfloat16";
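A short sketch of enabling the new flag at session creation; it relies only on Ort::SessionOptions::AddConfigEntry from the public C++ wrapper:

#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env;
  Ort::SessionOptions so;
  // Opt in to bf16-backed fp32 GEMM on arm64; the default is "0" (off).
  so.AddConfigEntry("mlas.enable_gemm_fastmath_arm64_bfloat16", "1");
  return 0;
}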
7 changes: 7 additions & 0 deletions onnxruntime/core/common/cpuid_info.cc
@@ -30,6 +30,10 @@
#define HWCAP2_SVEI8MM (1 << 9)
#endif

#ifndef HWCAP2_BF16
#define HWCAP2_BF16 (1 << 14)
#endif

#endif // ARM

#endif // Linux
@@ -148,6 +152,7 @@ void CPUIDInfo::ArmLinuxInit() {
has_fp16_ = cpuinfo_has_arm_neon_fp16_arith();
has_arm_neon_i8mm_ = cpuinfo_has_arm_i8mm();
has_arm_sve_i8mm_ = cpuinfo_has_arm_sve() && cpuinfo_has_arm_i8mm();
has_arm_neon_bf16_ = cpuinfo_has_arm_neon_bf16();

const uint32_t core_cnt = cpuinfo_get_cores_count();
core_uarchs_.resize(core_cnt, cpuinfo_uarch_unknown);
@@ -177,6 +182,7 @@ void CPUIDInfo::ArmLinuxInit() {
has_arm_neon_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_I8MM) != 0);
has_arm_sve_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_SVEI8MM) != 0);

has_arm_neon_bf16_ = ((getauxval(AT_HWCAP2) & HWCAP2_BF16) != 0);
#endif
}

@@ -278,6 +284,7 @@ void CPUIDInfo::ArmWindowsInit() {
/* TODO: implement them when hw+sw is available for testing these features */
has_arm_neon_i8mm_ = false;
has_arm_sve_i8mm_ = false;
has_arm_neon_bf16_ = false;
}

#endif /* (arm or arm64) and windows */
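The getauxval branch above is self-contained enough to reproduce standalone; a minimal sketch of the same Linux/aarch64 BF16 probe:

#include <sys/auxv.h>  // getauxval, AT_HWCAP2

#ifndef HWCAP2_BF16
#define HWCAP2_BF16 (1 << 14)  // same fallback definition as in the diff
#endif

// True when the kernel reports BF16 support for this aarch64 CPU.
bool HasArmNeonBf16() {
  return (getauxval(AT_HWCAP2) & HWCAP2_BF16) != 0;
}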
2 changes: 2 additions & 0 deletions onnxruntime/core/common/cpuid_info.h
@@ -30,6 +30,7 @@ class CPUIDInfo {
bool HasArmNeonDot() const { return has_arm_neon_dot_; }
bool HasArmNeon_I8MM() const { return has_arm_neon_i8mm_; }
bool HasArmSVE_I8MM() const { return has_arm_sve_i8mm_; }
bool HasArmNeon_BF16() const { return has_arm_neon_bf16_; }

uint32_t GetCurrentCoreIdx() const;

@@ -125,6 +126,7 @@ class CPUIDInfo {
bool has_fp16_{false};
bool has_arm_neon_i8mm_{false};
bool has_arm_sve_i8mm_{false};
bool has_arm_neon_bf16_{false};

#ifdef CPUIDINFO_ARCH_X86

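Kernel-selection code can then gate bf16 paths on the new accessor; a sketch, assuming the CPUIDInfo::GetCPUIDInfo() singleton accessor the class already exposes elsewhere in the codebase:

#include "core/common/cpuid_info.h"

bool UseBf16FastMath() {
  const auto& cpuid = onnxruntime::CPUIDInfo::GetCPUIDInfo();  // assumed accessor
  return cpuid.HasArmNeon_BF16();
}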
113 changes: 113 additions & 0 deletions onnxruntime/core/mlas/inc/mlas.h
@@ -1614,6 +1614,119 @@ MlasHalfGemmConvertPackB(
void* PackedB
);

#if defined(__aarch64__) && defined(__linux__)
/**
 * @brief Whether the current CPU supports bfloat16 (bf16) acceleration.
*/
bool MLASCALL
MlasBf16AccelerationSupported();

/**
* @brief Interface for bf16 gemm post processors.
*
 * Example implementations of this interface include activations,
 * conversion from single precision to bfloat16 precision, etc.
*
* SBGEMM is computed tile by tile. When a tile of result matrix
* is produced, the method Process() is called to process this tile.
* Parameters of this method describe the location and shape of the
* tile.
*/
class MLAS_SBGEMM_POSTPROCESSOR
{
public:
virtual void Process(float*, /**< the address of matrix to process */
size_t, /**< the start row index of matrix */
size_t, /**< the start col index of matrix */
size_t, /**< the element count per row to process */
size_t, /**< the element count per col to process */
size_t /**< the leading dimension of matrix */
) const = 0;

virtual ~MLAS_SBGEMM_POSTPROCESSOR() {}
};

/**
 * @brief bfloat16 precision activation functions, with an optional sum tensor.
 * The supplied sum tensor must have the same layout as the GEMM output
 * tensor, and it is added to the output before the activation is applied.
*/
class MLAS_SBGEMM_ACTIVATION_PROCESSOR : public MLAS_SBGEMM_POSTPROCESSOR
{
public:
MLAS_SBGEMM_ACTIVATION_PROCESSOR(const MLAS_ACTIVATION& Activation, const float* SumBuf = nullptr)
: Activation_(Activation), SumBuf_(SumBuf)
{
}

void Process(float* C, size_t StartM, size_t StartN, size_t CountM, size_t CountN, size_t ldc)
const override;

private:
const MLAS_ACTIVATION& Activation_;
const float* SumBuf_;
};

/**
* @brief Data parameters for bfloat16 precision GEMM routine
* All except C are [in] parameters
*/
struct MLAS_SBGEMM_DATA_PARAMS {
const void* A = nullptr; /**< address of A */
const void* B = nullptr; /**< address of B */
const float* Bias = nullptr; /**< address of Bias, vector size N */
float* C = nullptr; /**< address of result matrix */
size_t lda = 0; /**< leading dimension of A */
size_t ldb = 0; /**< leading dimension of B, 0 when B is pre-packed */
size_t ldc = 0; /**< leading dimension of C */
const MLAS_SBGEMM_POSTPROCESSOR* OutputProcessor = nullptr;
bool AIsfp32 = false; /**< matrix A is fp32, needs to be converted to bf16 */
bool BIsfp32 = false; /**< matrix B is fp32, needs to be converted to bf16 */
};

/**
 * @brief Bfloat16 precision batched GEMM: C = A * B + Bias.
 * Either A or B can be fp32 or bf16.
 *
 * Note: Only uniform batching is supported, so the shapes and types of the
 * inputs must be the same across all parameter blocks.
*
* @param[in] M row size of matrix A and C
* @param[in] N column size of matrix B and C
* @param[in] K column size of matrix A and row size of matrix B
* @param[in] BatchN number of batches
* @param[inout] DataParams An array (size BatchN) of parameter blocks
* @param[in] ThreadPool
* @return
*/
void MLASCALL
MlasSBGemmBatch(const size_t M, const size_t N, const size_t K, const size_t BatchN, const MLAS_SBGEMM_DATA_PARAMS* DataParams, MLAS_THREADPOOL* ThreadPool = nullptr);

/**
* @brief For bfloat16 precision GEMM, returns size of the
* packing buffer needed for right hand side
* @param[in] N Number of columns
* @param[in] K Number of rows
* @return size of the packing buffer,
* 0 if operation not supported
*/
size_t MLASCALL
MlasSBGemmPackBSize(size_t N, size_t K);

/**
* @brief For bfloat16 precision GEMM, convert the float matrix B
 * to bfloat16 precision and pack it into a packing buffer
*
* @param[in] N Number of columns
* @param[in] K Number of rows
* @param[in] B Address of matrix B
* @param[in] ldb leading dimension of input matrix B
* @param[out] PackedB Address of the packed matrix
*/
void MLASCALL
MlasSBGemmConvertPackB(size_t N, size_t K, const float* B, size_t ldb, void* PackedB);
#endif

/**
* @brief Indirect Depthwise convolution for fp16
* @param Input Supplies the indirect buffer for NHWC input
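Putting the new declarations together, a hypothetical single-batch call could look like this; the fp32 A / pre-packed bf16 B combination and all shapes are illustrative assumptions, not code from this commit:

#include <vector>
#include "mlas.h"

#if defined(__aarch64__) && defined(__linux__)
// C (MxN) = A (MxK, fp32) * B (KxN, fp32, converted and packed to bf16).
void SBGemmOnce(const float* A, const float* B, float* C,
                size_t M, size_t N, size_t K) {
  if (!MlasBf16AccelerationSupported()) {
    return;  // caller falls back to the plain fp32 GEMM path
  }
  // Convert B to bf16 and pack it once; the packed buffer is reusable.
  const size_t packed_bytes = MlasSBGemmPackBSize(N, K);
  if (packed_bytes == 0) {
    return;  // packing not supported for this shape
  }
  std::vector<char> packed_b(packed_bytes);
  MlasSBGemmConvertPackB(N, K, B, /*ldb=*/N, packed_b.data());

  MLAS_SBGEMM_DATA_PARAMS params;
  params.A = A;
  params.AIsfp32 = true;   // A is fp32; converted to bf16 internally
  params.B = packed_b.data();
  params.BIsfp32 = false;  // already bf16 after ConvertPackB
  params.C = C;
  params.lda = K;
  params.ldb = 0;          // 0 signals a pre-packed B
  params.ldc = N;
  MlasSBGemmBatch(M, N, K, /*BatchN=*/1, &params, /*ThreadPool=*/nullptr);
}
#endif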