From 2901467e97f8e1d6ac8826d9840c96a536079dac Mon Sep 17 00:00:00 2001
From: Jing Fang
Date: Wed, 11 Dec 2024 23:57:34 +0000
Subject: [PATCH] connect kernels

---
 .../contrib_ops/cpu/bert/gqa_attention_base.h |  6 +-
 onnxruntime/core/mlas/inc/mlas.h              | 58 +++++++++++++++++--
 onnxruntime/core/mlas/lib/halfgemm.cpp        | 28 +++++++++
 .../core/mlas/lib/halfgemm_kernel_neon.cpp    |  8 +++
 onnxruntime/core/mlas/lib/mlasi.h             |  4 +-
 onnxruntime/core/mlas/lib/platform.cpp        |  1 +
 6 files changed, 97 insertions(+), 8 deletions(-)

diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
index 4428ab8f844b3..3b492e4256837 100644
--- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
+++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
@@ -198,8 +198,10 @@ class GQAAttentionBase {
       math::GemmEx<float, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size, alpha, q,
                                       static_cast<int>(head_size), k, static_cast<int>(head_size), 0.0f /*bata*/, output,
                                       static_cast<int>(present_buffer_sequence_length), nullptr);
-    } else if (GetMlasPlatform().HasFP16Support()) {
-      // TODO: if kernel available, call MlasHGemmEx
+    } else if (MlasHGemmSupported(CblasNoTrans, CblasTrans)) {
+      MlasGemm(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size,
+               q, static_cast<size_t>(head_size), k, static_cast<size_t>(head_size), output,
+               static_cast<size_t>(present_buffer_sequence_length), alpha, 0.0f /*beta*/, nullptr);
     } else {
       size_t bytes = head_size * (sequence_length + total_seqlen) * sizeof(float);
       auto q_k_fp32 = allocator->Alloc(bytes);
diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h
index 4b51ac73422ef..e5c428ff68cde 100644
--- a/onnxruntime/core/mlas/inc/mlas.h
+++ b/onnxruntime/core/mlas/inc/mlas.h
@@ -1458,6 +1458,20 @@ MlasRotaryEmbedOneRow(
     T* output
 );
 
+/**
+ * @brief Supplies matrix data information to the half precision gemm functions.
+ */
+struct MLAS_HGEMM_DATA_PARAMS {
+    const MLAS_FP16* A = nullptr;      /**< Supplies the address of matrix A */
+    size_t lda = 0;                    /**< Supplies the first dimension of matrix A. */
+    const MLAS_FP16* B = nullptr;      /**< Supplies the address of matrix B */
+    size_t ldb = 0;                    /**< Supplies the first dimension of matrix B. */
+    MLAS_FP16* C = nullptr;            /**< Supplies the address of matrix C */
+    size_t ldc = 0;                    /**< Supplies the first dimension of matrix C. */
+    MLAS_FP16 alpha = MLAS_FP16(1.0f); /**< Supplies the scalar alpha multiplier (see GEMM definition) */
+    MLAS_FP16 beta = MLAS_FP16(0.0f);  /**< Supplies the scalar beta multiplier (see GEMM definition) */
+};
+
 /**
  * @brief Check whether current CPU supports half precision gemm.
  */
@@ -1465,7 +1479,34 @@
 bool
 MLASCALL
 MlasHGemmSupported(
     CBLAS_TRANSPOSE TransA,
-    CBLAS_TRANSPOSE TransB);
+    CBLAS_TRANSPOSE TransB
+    );
+
+/**
+ * @brief Batched half precision matrix/matrix multiply operation (HGEMM).
+ *
+ * @param TransA     Supplies the transpose operation for matrix A.
+ * @param TransB     Supplies the transpose operation for matrix B.
+ * @param M          Supplies the number of rows of matrix A and matrix C.
+ * @param N          Supplies the number of columns of matrix B and matrix C.
+ * @param K          Supplies the number of columns of matrix A and the number of rows of matrix B.
+ * @param Data       Supplies an array of matrix data parameters.
+ * @param BatchSize  Supplies the number of multiplications in this batch.
+ * @param ThreadPool Supplies the thread pool object to use, else nullptr if the
+ *                   base library threading support should be used.
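+ *
+ * Example (illustrative sketch only, not part of the MLAS contract): a single
+ * C = alpha * A * B^T multiplication, where A is M x K, B is N x K, and C is
+ * M x N; the buffer names and leading dimensions below are assumptions.
+ *
+ *   MLAS_HGEMM_DATA_PARAMS Params;
+ *   Params.A = A; Params.lda = K;  // alpha/beta keep their 1.0f/0.0f defaults
+ *   Params.B = B; Params.ldb = K;  // row stride of B as stored (K, since TransB)
+ *   Params.C = C; Params.ldc = N;
+ *   MlasGemmBatch(CblasNoTrans, CblasTrans, M, N, K, &Params, 1, nullptr);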
+ */
+void
+MLASCALL
+MlasGemmBatch(
+    CBLAS_TRANSPOSE TransA,
+    CBLAS_TRANSPOSE TransB,
+    size_t M,
+    size_t N,
+    size_t K,
+    const MLAS_HGEMM_DATA_PARAMS* Data,
+    size_t BatchSize,
+    MLAS_THREADPOOL* ThreadPool
+    );
 
 /**
  * @brief half precision matrix/matrix multiply operation (HGEMM)
  *
@@ -1487,8 +1528,8 @@ MlasHGemmSupported(
  * @param ThreadPool Supplies the thread pool object to use, else nullptr if the base library threading support
  *                   should be used.
  */
+inline
 void
-MLASCALL
 MlasGemm(
     CBLAS_TRANSPOSE TransA,
     CBLAS_TRANSPOSE TransB,
@@ -1505,10 +1546,19 @@ MlasGemm(
     MLAS_FP16 beta,
     MLAS_THREADPOOL* ThreadPool
     ) {
-    // TODO: call MlasGemmBatch for hgemm
+    MLAS_HGEMM_DATA_PARAMS Data;
+    Data.alpha = alpha;
+    Data.A = A;
+    Data.lda = lda;
+    Data.B = B;
+    Data.ldb = ldb;
+    Data.beta = beta;
+    Data.C = C;
+    Data.ldc = ldc;
+    MlasGemmBatch(TransA, TransB, M, N, K, &Data, 1, ThreadPool);
 }
 
- /**
+/**
  * @brief Whether current CPU supports FP16 acceleration.
  */
 bool MLASCALL
diff --git a/onnxruntime/core/mlas/lib/halfgemm.cpp b/onnxruntime/core/mlas/lib/halfgemm.cpp
index 49387d2fc998f..6f24d5305392f 100644
--- a/onnxruntime/core/mlas/lib/halfgemm.cpp
+++ b/onnxruntime/core/mlas/lib/halfgemm.cpp
@@ -324,6 +324,34 @@ MlasHalfGemmKernel(
     }
 }
 
+bool
+MLASCALL
+MlasHGemmSupported(
+    CBLAS_TRANSPOSE TransA,
+    CBLAS_TRANSPOSE TransB
+) {
+    auto* dispatch = GetMlasPlatform().HGemmDispatch;
+    // Only the NoTrans-A / Trans-B layout has a dedicated kernel slot so far.
+    if (TransA == CblasNoTrans && TransB == CblasTrans) {
+        return dispatch && dispatch->HGemmKernel_TransposeB;
+    }
+
+    return false;
+}
+
+void
+MLASCALL
+MlasGemmBatch(
+    CBLAS_TRANSPOSE TransA,
+    CBLAS_TRANSPOSE TransB,
+    size_t M,
+    size_t N,
+    size_t K,
+    const MLAS_HGEMM_DATA_PARAMS* Data,
+    size_t BatchSize,
+    MLAS_THREADPOOL* ThreadPool
+) {
+    // TODO: partition the (BatchSize, M, N) space across ThreadPool and invoke
+    // GetMlasPlatform().HGemmDispatch->HGemmKernel_TransposeB on each tile.
+}
+
 const MLAS_HALFGEMM_DISPATCH MlasHalfGemmDispatchDefault = {
     MlasHalfGemmOperation,
diff --git a/onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp
index d7f5a90b00589..bc49f8e6142de 100644
--- a/onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp
+++ b/onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp
@@ -185,3 +185,11 @@ const MLAS_HALFGEMM_DISPATCH MlasHalfGemmDispatchNeon = {
     MLAS_HALF_GEMM_KERNEL_NEON::KernelMaxM,
     32 // kernel may read beyond buffer end by 32 bytes
 };
+
+const MLAS_HGEMM_DISPATCH MlasHGemmDispatchNeon = [](){
+    MLAS_HGEMM_DISPATCH d;
+#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64)
+    d.HGemmKernel_TransposeB = nullptr;  // TODO: point at the fp16 NEON kernel once it is implemented
+#endif
+    return d;
+}();
diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h
index ba4a8d3e4abcc..cceedfc537253 100644
--- a/onnxruntime/core/mlas/lib/mlasi.h
+++ b/onnxruntime/core/mlas/lib/mlasi.h
@@ -1059,7 +1059,7 @@ extern const MLAS_ROPE_DISPATCH MlasRopeDispatchNeon;
 //
 // half gemm dispatch structure
 //
 struct MLAS_HGEMM_DISPATCH;
-extern const MLAS_HGEMM_DISPATCH MlasHgemmDispatchNeon;
+extern const MLAS_HGEMM_DISPATCH MlasHGemmDispatchNeon;
 
 //
@@ -1223,7 +1223,7 @@ struct MLAS_PLATFORM {
     MLAS_CAST_F32_TO_F16_KERNEL* CastF32ToF16Kernel;
     const MLAS_ROPE_DISPATCH* RopeDispatch{nullptr};
-    const MLAS_HGEMM_DISPATCH* HGemmDIspatch{nullptr};
+    const MLAS_HGEMM_DISPATCH* HGemmDispatch{nullptr};
 };
 
 inline
diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp
index ec572a4150292..026a954bbc6c2 100644
--- a/onnxruntime/core/mlas/lib/platform.cpp
+++ b/onnxruntime/core/mlas/lib/platform.cpp
@@ -544,6 +544,7 @@ Return Value:
     this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchNeon;
     this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon;
     this->RopeDispatch = &MlasRopeDispatchNeon;
+    this->HGemmDispatch = &MlasHGemmDispatchNeon;
 
     //
    // Check if the processor supports ASIMD dot product instructions.
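
Usage sketch (illustrative, not part of the patch): how a caller is expected to
guard the new fp16 path, mirroring the gqa_attention_base.h hunk above. The
names q, k, output, M, N, K, ldc, alpha, and thread_pool are stand-ins, and the
fp32 fallback branch is elided.

    // Query the platform dispatch for a kernel matching this transpose
    // combination; fall back to the fp32 path when none is registered.
    if (MlasHGemmSupported(CblasNoTrans, CblasTrans)) {
        // C = alpha * Q * K^T, with Q: M x K fp16, K: N x K fp16,
        // and C written with row stride ldc.
        MlasGemm(CblasNoTrans, CblasTrans, M, N, K,
                 q, /*lda*/ K, k, /*ldb*/ K, output, /*ldc*/ ldc,
                 alpha, MLAS_FP16(0.0f) /*beta*/, thread_pool);
    }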