connect kernels
fajin-corp committed Dec 11, 2024
1 parent 8716319 commit 2901467
Showing 6 changed files with 97 additions and 8 deletions.
6 changes: 4 additions & 2 deletions onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
@@ -198,8 +198,10 @@ class GQAAttentionBase {
math::GemmEx<float, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size, alpha, q,
static_cast<int>(head_size), k, static_cast<int>(head_size), 0.0f /*beta*/,
output, static_cast<int>(present_buffer_sequence_length), nullptr);
-      } else if (GetMlasPlatform().HasFP16Support()) {
-        // TODO: if kernel available, call MlasHGemmEx
+      } else if (MlasHGemmSupported(CblasNoTrans, CblasTrans)) {
+        MlasGemm(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size,
+                 q, static_cast<int>(head_size), k, static_cast<int>(head_size), output,
+                 static_cast<int>(present_buffer_sequence_length), alpha, 0.0f /*beta*/, nullptr);
} else {
size_t bytes = head_size * (sequence_length + total_seqlen) * sizeof(float);
auto q_k_fp32 = allocator->Alloc(bytes);
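
The rewritten branch routes the Q·Kᵀ product through MLAS's half-precision GEMM whenever a kernel is available, instead of first widening Q and K to fp32. A minimal sketch of that selection pattern, assuming dense row-major buffers; ComputeQK and its parameter names are illustrative, not part of the ONNX Runtime API:

#include "core/mlas/inc/mlas.h"

// Sketch: scores = scale * Q * K^T, staying in fp16 when a kernel exists.
// Q is [M x K] row-major, K is stored row-major as [N x K] (hence CblasTrans),
// and scores is [M x N]. Buffers are assumed dense: lda = K, ldb = K, ldc = N.
void ComputeQK(const MLAS_FP16* q, const MLAS_FP16* k, MLAS_FP16* scores,
               size_t M, size_t N, size_t K, float scale,
               MLAS_THREADPOOL* thread_pool) {
  if (MlasHGemmSupported(CblasNoTrans, CblasTrans)) {
    MlasGemm(CblasNoTrans, CblasTrans, M, N, K,
             q, K, k, K, scores, N,
             MLAS_FP16(scale), MLAS_FP16(0.0f), thread_pool);
  }
  // Otherwise the caller falls back to the fp32 path, converting q and k to
  // float and calling math::GemmEx, as gqa_attention_base.h does above.
}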
58 changes: 54 additions & 4 deletions onnxruntime/core/mlas/inc/mlas.h
@@ -1458,14 +1458,55 @@ MlasRotaryEmbedOneRow(
T* output
);

+/**
+ * @brief Supplies matrix data to the half precision gemm functions
+ */
+struct MLAS_HGEMM_DATA_PARAMS {
+    const MLAS_FP16* A = nullptr; /**< Supplies the address of matrix A */
+    size_t lda = 0;               /**< Supplies the first dimension of matrix A. */
+    const MLAS_FP16* B = nullptr; /**< Supplies the address of matrix B */
+    size_t ldb = 0;               /**< Supplies the first dimension of matrix B. */
+    MLAS_FP16* C = nullptr;       /**< Supplies the address of matrix C */
+    size_t ldc = 0;               /**< Supplies the first dimension of matrix C. */
+    MLAS_FP16 alpha = MLAS_FP16(1.0f); /**< Supplies the scalar alpha multiplier (see GEMM definition) */
+    MLAS_FP16 beta = MLAS_FP16(0.0f);  /**< Supplies the scalar beta multiplier (see GEMM definition) */
+};
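
MLAS GEMMs consume row-major data, so each "first dimension" above is the stride between consecutive rows. A hedged fragment showing how the struct might be filled for the NoTrans × Trans case this commit targets (a_fp16, b_fp16, c_fp16 and the dense strides are assumptions):

// C[M x N] = alpha * A[M x K] * B^T + beta * C, with B stored as [N x K].
MLAS_HGEMM_DATA_PARAMS params;
params.A = a_fp16;   // M x K, row-major
params.lda = K;      // row stride of A; >= K for a dense buffer
params.B = b_fp16;   // N x K, row-major, consumed as B^T
params.ldb = K;      // row stride of B; >= K when B is transposed
params.C = c_fp16;   // M x N, row-major
params.ldc = N;      // row stride of C; >= N
// alpha and beta keep their defaults of 1.0f and 0.0f.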

/**
* @brief Check whether current CPU supports half precision gemm.
*/
bool
MLASCALL
MlasHGemmSupported(
CBLAS_TRANSPOSE TransA,
-CBLAS_TRANSPOSE TransB);
+CBLAS_TRANSPOSE TransB
+);

+/**
+ * @brief Batched half precision matrix/matrix multiply operation (HGEMM)
+ *
+ * @param TransA     Supplies the transpose operation for matrix A.
+ * @param TransB     Supplies the transpose operation for matrix B.
+ * @param M          Supplies the number of rows of matrix A and matrix C.
+ * @param N          Supplies the number of columns of matrix B and matrix C.
+ * @param K          Supplies the number of columns of matrix A and the number of rows of matrix B.
+ * @param Data       Supplies an array of matrix data parameters, one entry per multiplication.
+ * @param BatchSize  Supplies the number of multiplications in this batch.
+ * @param ThreadPool Supplies the thread pool object to use, else nullptr if the
+ *                   base library threading support should be used.
+ */
+void
+MLASCALL
+MlasGemmBatch(
+    CBLAS_TRANSPOSE TransA,
+    CBLAS_TRANSPOSE TransB,
+    size_t M,
+    size_t N,
+    size_t K,
+    const MLAS_HGEMM_DATA_PARAMS* Data,
+    size_t BatchSize,
+    MLAS_THREADPOOL* ThreadPool
+);
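
A hedged usage sketch for the batched entry point: BatchSize independent products of identical shape, one MLAS_HGEMM_DATA_PARAMS entry each, e.g. one GEMM per attention head. RunBatchedQK and its dense back-to-back layout are illustrative assumptions:

#include <vector>
#include "core/mlas/inc/mlas.h"

void RunBatchedQK(const MLAS_FP16* a_base, const MLAS_FP16* b_base,
                  MLAS_FP16* c_base, size_t M, size_t N, size_t K,
                  size_t batch_size, MLAS_THREADPOOL* thread_pool) {
  std::vector<MLAS_HGEMM_DATA_PARAMS> data(batch_size);
  for (size_t i = 0; i < batch_size; ++i) {
    data[i].A = a_base + i * M * K;  // each entry owns dense slices
    data[i].lda = K;
    data[i].B = b_base + i * N * K;
    data[i].ldb = K;
    data[i].C = c_base + i * M * N;
    data[i].ldc = N;
    // alpha/beta keep their defaults of 1.0f and 0.0f.
  }
  MlasGemmBatch(CblasNoTrans, CblasTrans, M, N, K,
                data.data(), batch_size, thread_pool);
}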

/**
* @brief half precision matrix/matrix multiply operation (HGEMM)
@@ -1487,8 +1528,8 @@ MlasHGemmSupported(
* @param ThreadPool Supplies the thread pool object to use, else nullptr if the base library threading support
* should be used.
*/
+inline
void
-MLASCALL
MlasGemm(
CBLAS_TRANSPOSE TransA,
CBLAS_TRANSPOSE TransB,
@@ -1505,10 +1546,19 @@ MlasGemm(
MLAS_FP16 beta,
MLAS_THREADPOOL* ThreadPool
) {
-    // TODO: call MlasGemmBatch for hgemm
+    MLAS_HGEMM_DATA_PARAMS Data;
+    Data.alpha = alpha;
+    Data.A = A;
+    Data.lda = lda;
+    Data.B = B;
+    Data.ldb = ldb;
+    Data.beta = beta;
+    Data.C = C;
+    Data.ldc = ldc;
+    MlasGemmBatch(TransA, TransB, M, N, K, &Data, 1, ThreadPool);
}
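
The fp16 MlasGemm overload is thus a thin inline wrapper: it packs its arguments into a single MLAS_HGEMM_DATA_PARAMS and forwards to MlasGemmBatch with a batch size of 1, the same funnel-into-batch pattern the other MLAS GEMM entry points use.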

/**
* @brief Whether current CPU supports FP16 acceleration.
*/
bool MLASCALL
28 changes: 28 additions & 0 deletions onnxruntime/core/mlas/lib/halfgemm.cpp
@@ -324,6 +324,34 @@ MlasHalfGemmKernel<MLAS_HALF_GEMM_KERNEL_DEFAULT>(
}
}

+bool
+MLASCALL
+MlasHGemmSupported(
+    CBLAS_TRANSPOSE TransA,
+    CBLAS_TRANSPOSE TransB
+) {
+    auto* dispatch = GetMlasPlatform().HGemmDispatch;
+    if (TransA == CblasNoTrans && TransB == CblasTrans) {
+        return dispatch && dispatch->HGemmKernel_TransposeB;
+    }
+
+    return false;
+}
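
Only the NoTrans × Trans combination can report support, since HGemmKernel_TransposeB is the only kernel slot referenced here; it is also exactly the layout the GQA attention caller needs for Q·Kᵀ. Every other combination returns false and leaves callers on their existing fp32 paths.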

+void
+MLASCALL
+MlasGemmBatch(
+    CBLAS_TRANSPOSE TransA,
+    CBLAS_TRANSPOSE TransB,
+    size_t M,
+    size_t N,
+    size_t K,
+    const MLAS_HGEMM_DATA_PARAMS* Data,
+    size_t BatchSize,
+    MLAS_THREADPOOL* ThreadPool
+) {
+}
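
The body of MlasGemmBatch is left as an empty stub in this commit. A minimal single-threaded sketch of the shape an implementation could take, assuming (the diff does not define this) that the dispatch slot is callable with the usual GEMM operands; a real version would also split the work across ThreadPool:

// Hypothetical sketch only: the actual MLAS_HGEMM_DISPATCH contract is not
// fleshed out by this commit, so the kernel signature below is an assumption.
void
MlasGemmBatchSketch(
    CBLAS_TRANSPOSE TransA,
    CBLAS_TRANSPOSE TransB,
    size_t M, size_t N, size_t K,
    const MLAS_HGEMM_DATA_PARAMS* Data,
    size_t BatchSize,
    MLAS_THREADPOOL* ThreadPool
) {
    MLAS_UNREFERENCED_PARAMETER(TransA);     // only NoTrans x Trans is gated
    MLAS_UNREFERENCED_PARAMETER(TransB);     // in by MlasHGemmSupported today
    MLAS_UNREFERENCED_PARAMETER(ThreadPool); // single-threaded sketch
    const auto* dispatch = GetMlasPlatform().HGemmDispatch;
    for (size_t i = 0; i < BatchSize; ++i) {
        const MLAS_HGEMM_DATA_PARAMS& d = Data[i];
        // Assumed kernel contract: C = alpha * A * B^T + beta * C.
        dispatch->HGemmKernel_TransposeB(d.A, d.B, d.C, M, N, K,
                                         d.lda, d.ldb, d.ldc, d.alpha, d.beta);
    }
}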

const MLAS_HALFGEMM_DISPATCH MlasHalfGemmDispatchDefault = {
MlasHalfGemmOperation<MLAS_HALF_GEMM_KERNEL_DEFAULT>,
8 changes: 8 additions & 0 deletions onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp
@@ -185,3 +185,11 @@ const MLAS_HALFGEMM_DISPATCH MlasHalfGemmDispatchNeon = {
MLAS_HALF_GEMM_KERNEL_NEON::KernelMaxM,
32 // kernel may read beyond buffer end by 32 bytes
};

+const MLAS_HGEMM_DISPATCH MlasHGemmDispatchNeon = [](){
+    MLAS_HGEMM_DISPATCH d;
+#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64)
+    d.HGemmKernel_TransposeB = nullptr;
+#endif
+    return d;
+}();
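
Note that the NEON dispatch still assigns nullptr to HGemmKernel_TransposeB, so MlasHGemmSupported(CblasNoTrans, CblasTrans) continues to return false even on ARM64 builds with fp16 intrinsics; the new fp16 branch in gqa_attention_base.h stays dormant until a real kernel is plugged into this slot.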
4 changes: 2 additions & 2 deletions onnxruntime/core/mlas/lib/mlasi.h
@@ -1059,7 +1059,7 @@ extern const MLAS_ROPE_DISPATCH MlasRopeDispatchNeon;
// half gemm dispatch structure
//
struct MLAS_HGEMM_DISPATCH;
-extern const MLAS_HGEMM_DISPATCH MlasHgemmDispatchNeon;
+extern const MLAS_HGEMM_DISPATCH MlasHGemmDispatchNeon;


//
@@ -1223,7 +1223,7 @@ struct MLAS_PLATFORM {
MLAS_CAST_F32_TO_F16_KERNEL* CastF32ToF16Kernel;

const MLAS_ROPE_DISPATCH* RopeDispatch{nullptr};
-    const MLAS_HGEMM_DISPATCH* HGemmDIspatch{nullptr};
+    const MLAS_HGEMM_DISPATCH* HGemmDispatch{nullptr};
};

inline
1 change: 1 addition & 0 deletions onnxruntime/core/mlas/lib/platform.cpp
@@ -544,6 +544,7 @@ Return Value:
this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchNeon;
this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon;
this->RopeDispatch = &MlasRopeDispatchNeon;
+this->HGemmDispatch = &MlasHGemmDispatchNeon;

//
// Check if the processor supports ASIMD dot product instructions.
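
Registering the dispatch here makes GetMlasPlatform().HGemmDispatch non-null on ARM64, which satisfies the first half of the MlasHGemmSupported check; once MlasHGemmDispatchNeon also carries a non-null HGemmKernel_TransposeB, the fp16 attention path is enabled end to end.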
