diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 04efa5c2b4f6d..26e4380af4c23 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -284,6 +284,8 @@ else() set(X86 TRUE) elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|amd64)$") set(X86_64 TRUE) + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^loongarch64.*") + set(LOONGARCH64 TRUE) endif() endif() @@ -575,6 +577,26 @@ else() set(MLAS_SOURCE_IS_NOT_SET 0) endif() endif() + if(LOONGARCH64 AND MLAS_SOURCE_IS_NOT_SET) + set(mlas_platform_srcs + ${MLAS_SRC_DIR}/qgemm_kernel_lsx.cpp + ${MLAS_SRC_DIR}/loongarch64/SgemmKernelLasx.S + ${MLAS_SRC_DIR}/loongarch64/DgemmKernelLsx.S + ${MLAS_SRC_DIR}/loongarch64/DgemmKernelLasx.S + ${MLAS_SRC_DIR}/loongarch64/SgemmKernelLsx.S + ${MLAS_SRC_DIR}/loongarch64/SconvKernelLsx.S + ${MLAS_SRC_DIR}/loongarch64/SconvKernelLasx.S + ${MLAS_SRC_DIR}/loongarch64/SpoolKernelLSX.S + ${MLAS_SRC_DIR}/loongarch64/SpoolKernelLasx.S + ${MLAS_SRC_DIR}/loongarch64/SgemmTransposePackB16x4LSX.S + ${MLAS_SRC_DIR}/loongarch64/SgemmTransposePackB16x4Lasx.S + ${MLAS_SRC_DIR}/loongarch64/SoftmaxKernelLasx.S + ) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mlsx -mlasx") + if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH) + set(MLAS_SOURCE_IS_NOT_SET 0) + endif() + endif() if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH AND MLAS_SOURCE_IS_NOT_SET) file(GLOB_RECURSE mlas_platform_srcs "${MLAS_SRC_DIR}/scalar/*.cpp") diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index fd6b3df93444b..bdd4dba521eba 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -69,6 +69,9 @@ Module Name: #endif #endif +#if defined(__loongarch64) +#define MLAS_TARGET_LARCH64 +#endif // // Define the support levels for the target architecture. // @@ -87,7 +90,7 @@ Module Name: #define MLAS_F16VEC_INTRINSICS_SUPPORTED -#endif // +#endif // #endif // ARM64 #endif // Visual Studio 16 or earlier does not support fp16 intrinsic @@ -1619,7 +1622,7 @@ MlasHalfGemmConvertPackB( * @param Channels # of input channels * @param OutputCount # of output pixels * @param KernelSize # kernel size - * @return + * @return */ void MLASCALL @@ -1657,7 +1660,7 @@ MlasTranspose( * @param Channels C in NHWC * @param OutputCount Number of output pixels * @param KernelSize Size of the kernel - * @return + * @return */ void MLASCALL @@ -1676,7 +1679,7 @@ MlasNhwcMaxPool( * @param Channels C in NHWC * @param OutputCount Number of output pixels * @param KernelSize size of the kernel - * @return + * @return */ void MLASCALL diff --git a/onnxruntime/core/mlas/lib/activate.cpp b/onnxruntime/core/mlas/lib/activate.cpp index 6c4ab8ae118dc..df3b884a7e7c9 100644 --- a/onnxruntime/core/mlas/lib/activate.cpp +++ b/onnxruntime/core/mlas/lib/activate.cpp @@ -143,6 +143,8 @@ struct MLAS_ACTIVATION_FUNCTION return MlasBlendFloat32x4(ValueTimesAlpha, Value, _mm_cmple_ps(ZeroFloat32x4, Value)); #elif defined(MLAS_VSX_INTRINSICS) return vec_sel(ValueTimesAlpha, Value, vec_cmple(ZeroFloat32x4, Value)); +#elif defined(MLAS_LSX_INTRINSICS) + return MlasBlendFloat32x4(ValueTimesAlpha, Value, (__m128)__lsx_vfcmp_cle_s(ZeroFloat32x4, Value)); #else return MlasBlendFloat32x4(ValueTimesAlpha, Value, ZeroFloat32x4 < Value); #endif diff --git a/onnxruntime/core/mlas/lib/compute.cpp b/onnxruntime/core/mlas/lib/compute.cpp index 118351055157d..78cac2e617ff7 100644 --- a/onnxruntime/core/mlas/lib/compute.cpp +++ b/onnxruntime/core/mlas/lib/compute.cpp @@ -148,6 +148,9 @@ Return Value: // instead. 
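    // (The LSX path added below performs the same clamp with the signed 16-bit
    // __lsx_vmin_h/__lsx_vmax_h intrinsics, mirroring _mm_min_epi16/_mm_max_epi16.)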
normal = _mm_min_epi16(normal, MaximumExponent); normal = _mm_max_epi16(normal, MinimumExponent); +#elif defined(MLAS_LSX_INTRINSICS) + normal = __lsx_vmin_h(normal, MaximumExponent); + normal = __lsx_vmax_h(normal, MinimumExponent); #else normal = MlasMinimumInt32x4(normal, MaximumExponent); normal = MlasMaximumInt32x4(normal, MinimumExponent); @@ -215,6 +218,8 @@ Return Value: // N.B. SSE2 lacks a broadcast load instruction, so avoid a shuffle // and use zeroes for the upper elements. Vector = _mm_load_ss(Input); +#elif defined(MLAS_LSX_INTRINSICS) + Vector = (MLAS_FLOAT32X4)__lsx_vldrepl_w(Input, 0); #else Vector = MlasBroadcastFloat32x4(Input); #endif @@ -467,6 +472,8 @@ Return Value: // N.B. SSE2 lacks a broadcast load instruction, so avoid a shuffle and // use zeroes for the upper elements. MLAS_FLOAT32X4 Vector = _mm_load_ss(Input); +#elif defined(MLAS_LSX_INTRINSICS) + MLAS_FLOAT32X4 Vector = (MLAS_FLOAT32X4)__lsx_vldrepl_w(Input, 0); #else MLAS_FLOAT32X4 Vector = MlasBroadcastFloat32x4(Input); #endif @@ -849,7 +856,7 @@ Return Value: // Find the maximum value for the row. // -#if defined(MLAS_TARGET_AMD64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) float Maximum = GetMlasPlatform().ReduceMaximumF32Kernel(Input, D); #else float Maximum = MlasReduceMaximumF32Kernel(Input, D); @@ -874,7 +881,7 @@ Return Value: float Parameters[] = { NegativeMaximum, std::log(Accumulation)}; -#if defined(MLAS_TARGET_AMD64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) GetMlasPlatform().ComputeLogSoftmaxOutputF32Kernel(Input, Output, D, Parameters); #else MlasComputeLogSoftmaxOutputF32Kernel(Input, Output, D, Parameters); @@ -899,7 +906,7 @@ Return Value: float Parameters[] = { 1.0f / Accumulation }; -#if defined(MLAS_TARGET_AMD64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) GetMlasPlatform().ComputeSoftmaxOutputF32Kernel(Output, D, Parameters); #else MlasComputeSoftmaxOutputF32Kernel(Output, D, Parameters); diff --git a/onnxruntime/core/mlas/lib/dgemm.cpp b/onnxruntime/core/mlas/lib/dgemm.cpp index 1ef63d03c8014..50c62744f1d8e 100644 --- a/onnxruntime/core/mlas/lib/dgemm.cpp +++ b/onnxruntime/core/mlas/lib/dgemm.cpp @@ -530,7 +530,7 @@ Return Value: size_t RowsHandled; -#if defined(MLAS_TARGET_AMD64_IX86) || defined (MLAS_TARGET_POWER) +#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_LARCH64) RowsHandled = GetMlasPlatform().GemmDoubleKernel(A, B, C, CountK, CountM, CountN, lda, ldc, alpha, ZeroMode); #else if (ZeroMode) { diff --git a/onnxruntime/core/mlas/lib/loongarch64/DgemmKernelCommon.h b/onnxruntime/core/mlas/lib/loongarch64/DgemmKernelCommon.h new file mode 100644 index 0000000000000..8d812baabdf9d --- /dev/null +++ b/onnxruntime/core/mlas/lib/loongarch64/DgemmKernelCommon.h @@ -0,0 +1,27 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + DgemmKernelCommon.h + +Abstract: + + This module contains common kernel macros and structures for the double + precision matrix/matrix multiply operation (DGEMM). 
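+    With LFgemmElementShift set to 3, each element is 8 bytes wide, so one 32-byte
+    LASX vector holds LFgemmYmmElementCount = 32 / 8 = 4 double precision values.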
+ +--*/ + +#define LFgemmElementShift 3 +#define LFgemmElementSize (1 << LFgemmElementShift) +#define LFgemmYmmElementCount (32/LFgemmElementSize) + +#include "FgemmKernelCommon.h" + +FGEMM_TYPED_INSTRUCTION(xvfadd, xvfadd.d) +FGEMM_TYPED_INSTRUCTION(xvfmadd, xvfmadd.d) +FGEMM_TYPED_INSTRUCTION(xvldrepl, xvldrepl.d) +FGEMM_TYPED_INSTRUCTION(xvfmul, xvfmul.d) diff --git a/onnxruntime/core/mlas/lib/loongarch64/DgemmKernelLasx.S b/onnxruntime/core/mlas/lib/loongarch64/DgemmKernelLasx.S new file mode 100644 index 0000000000000..2f197d6891579 --- /dev/null +++ b/onnxruntime/core/mlas/lib/loongarch64/DgemmKernelLasx.S @@ -0,0 +1,32 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + DgemmKernelLasx.s + +Abstract: + + This module implements the kernels for the double precision matrix/matrix + multiply operation (DGEMM). + + This implementation uses Lasx instructions. + +--*/ + +#include "asmmacro.h" +#include "DgemmKernelCommon.h" +#include "FgemmKernelLasxCommon.h" + + .text + +// +// Generate the GEMM kernel. +// + +FgemmKernelLasxFunction MlasGemmDoubleKernelLasx + + .end diff --git a/onnxruntime/core/mlas/lib/loongarch64/DgemmKernelLsx.S b/onnxruntime/core/mlas/lib/loongarch64/DgemmKernelLsx.S new file mode 100644 index 0000000000000..63395631a9bc5 --- /dev/null +++ b/onnxruntime/core/mlas/lib/loongarch64/DgemmKernelLsx.S @@ -0,0 +1,217 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + DgemmKernelLsx.s + +Abstract: + + This module implements the kernels for the double precision matrix/matrix + multiply operation (DGEMM). + + This implementation uses Lsx instructions. + +--*/ + +#include "asmmacro.h" +#include "FgemmKernelLsxCommon.h" + +FGEMM_TYPED_INSTRUCTION(vfadd, vfadd.d) +/*++ + +Macro Description: + + This macro multiplies and accumulates for a 8xN block of the output matrix. + +Arguments: + + RowCount - Supplies the number of rows to process. + +Implicit Arguments: + + a1 (rsi) - Supplies the address into the matrix B data. + + vr0-vr1 - Supplies up to two elements loaded from matrix A and matrix A + plus one row. + + vr8-vr15 - Supplies the block accumulators. + +--*/ + + .macro ComputeBlockSseBy8 RowCount + + vld $vr4, $a1, 0 + vld $vr5, $a1, 16 +.if \RowCount\() == 2 + vmove $vr6, $vr4 + vmove $vr7, $vr5 +.endif + vfmadd.d $vr8, $vr4, $vr0, $vr8 + vfmadd.d $vr9, $vr5, $vr0, $vr9 +.if \RowCount\() == 2 + vfmadd.d $vr12, $vr6, $vr1, $vr12 + vfmadd.d $vr13, $vr7, $vr1, $vr13 +.endif + vld $vr4, $a1, 32 + vld $vr5, $a1, 48 +.if \RowCount\() == 2 + vmove $vr6, $vr4 + vmove $vr7, $vr5 +.endif + vfmadd.d $vr10, $vr4, $vr0, $vr10 + vfmadd.d $vr11, $vr5, $vr0, $vr11 +.if \RowCount\() == 2 + vfmadd.d $vr14, $vr6, $vr1, $vr14 + vfmadd.d $vr15, $vr7, $vr1, $vr15 +.endif + + .endm + +/*++ + +Macro Description: + + This macro generates code to compute matrix multiplication for a fixed set + of rows. + +Arguments: + + RowCount - Supplies the number of rows to process. + + Fallthrough - Supplies a non-blank value if the macro may fall through to + the ExitKernel label. + +Implicit Arguments: + + a0 - Supplies the address of matrix A. + + a1 - Supplies the address of matrix B. + + t8 - Supplies the address of matrix A. + + a5 - Supplies the number of columns from matrix B and matrix C to iterate + over. + + a2 - Supplies the address of matrix C. 
+ + a3 - Supplies the number of columns from matrix A and the number of rows + from matrix B to iterate over. + + t7 - Supplies the length in bytes of a row from matrix A. + + t5 - Supplies the length in bytes of a row from matrix C. + + s3 - Stores the ZeroMode argument from the stack frame. + +--*/ + + .macro ProcessCountM RowCount, Fallthrough +.LProcessNextColumnLoop8xN\@: + EmitIfCountGE \RowCount\(), 1, "vxor.v $vr8,$vr8,$vr8" + EmitIfCountGE \RowCount\(), 1, "vxor.v $vr9,$vr9,$vr9" + EmitIfCountGE \RowCount\(), 1, "vxor.v $vr10,$vr10,$vr10" + EmitIfCountGE \RowCount\(), 1, "vxor.v $vr11,$vr11,$vr11" + EmitIfCountGE \RowCount\(), 2, "vxor.v $vr12,$vr12,$vr12" + EmitIfCountGE \RowCount\(), 2, "vxor.v $vr13,$vr13,$vr13" + EmitIfCountGE \RowCount\(), 2, "vxor.v $vr14,$vr14,$vr14" + EmitIfCountGE \RowCount\(), 2, "vxor.v $vr15,$vr15,$vr15" + move $t7,$a3 # reload CountK +.LCompute8xNBlockBy1Loop\@: + EmitIfCountGE \RowCount\(), 1, "ld.d $s0, $a0, 0" + EmitIfCountGE \RowCount\(), 1, "vreplgr2vr.d $vr0, $s0" + EmitIfCountGE \RowCount\(), 2, "ldx.d $s0, $a0, $t0" + EmitIfCountGE \RowCount\(), 2, "vreplgr2vr.d $vr1, $s0" + ComputeBlockSseBy8 \RowCount\() + addi.d $a1, $a1, 8*8 # advance matrix B by 8 columns + addi.d $a0, $a0, 8 # advance matrix A by 1 column + addi.d $t7, $t7, -1 + bnez $t7, .LCompute8xNBlockBy1Loop\@ + +.LOutput8xNBlock\@: + movfr2gr.d $s0, $f24 + vreplgr2vr.d $vr2, $s0 + # multiply by alpha + EmitIfCountGE \RowCount\(), 1, "vfmul.d $vr8, $vr8, $vr2" + EmitIfCountGE \RowCount\(), 1, "vfmul.d $vr9, $vr9, $vr2" + EmitIfCountGE \RowCount\(), 1, "vfmul.d $vr10,$vr10, $vr2" + EmitIfCountGE \RowCount\(), 1, "vfmul.d $vr11,$vr11, $vr2" + EmitIfCountGE \RowCount\(), 2, "vfmul.d $vr12,$vr12, $vr2" + EmitIfCountGE \RowCount\(), 2, "vfmul.d $vr13,$vr13, $vr2" + EmitIfCountGE \RowCount\(), 2, "vfmul.d $vr14,$vr14, $vr2" + EmitIfCountGE \RowCount\(), 2, "vfmul.d $vr15,$vr15, $vr2" + li.d $s0, 8 + blt $a5, $s0, .LOutputPartial8xNBlock\@ + sub.d $a5, $a5, $s0 + AccumulateAndStoreBlock \RowCount\(), 4 + addi.d $a2, $a2, 8*8 # advance matrix C by 8 columns + move $a0, $t1 # reload matrix A + bnez $a5, .LProcessNextColumnLoop8xN\@ + b .LExitKernel + +// +// Output a partial 8xN block to the matrix. 
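+// Depending on the remaining column count, 6, 4 or 2 full columns are accumulated
+// and stored with vector operations, and any final odd column is handled as a scalar.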
+// + +.LOutputPartial8xNBlock\@: + li.d $s0, 2 + blt $a5, $s0, .LOutputPartial1xNBlock\@ + li.d $s0, 4 + blt $a5, $s0, .LOutputPartialLessThan4xNBlock\@ + li.d $s0, 6 + blt $a5, $s0, .LOutputPartialLessThan6xNBlock\@ + AccumulateAndStoreBlock \RowCount\(), 3 + andi $s0, $a5, 1 # check if remaining count is small + beqz $s0, .LExitKernel + EmitIfCountGE \RowCount\(), 1, "vmove $vr8,$vr11" + # shift remaining elements down + EmitIfCountGE \RowCount\(), 2, "vmove $vr12,$vr15" + addi.d $a2, $a2, 6*8 # advance matrix C by 6 columns + b .LOutputPartial1xNBlock\@ + +.LOutputPartialLessThan6xNBlock\@: + AccumulateAndStoreBlock \RowCount\(), 2 + andi $s0, $a5,1 # check if remaining count is small + beqz $s0, .LExitKernel + EmitIfCountGE \RowCount\(), 1, "vmove $vr8,$vr10" + # shift remaining elements down + EmitIfCountGE \RowCount\(), 2, "vmove $vr12,$vr14" + addi.d $a2, $a2, 4*8 # advance matrix C by 4 columns + b .LOutputPartial1xNBlock\@ + +.LOutputPartialLessThan4xNBlock\@: + AccumulateAndStoreBlock \RowCount\(), 1 + andi $s0, $a5,1 # check if remaining count is small + beqz $s0, .LExitKernel + EmitIfCountGE \RowCount\(), 1, "vmove $vr8,$vr9" + # shift remaining elements down + EmitIfCountGE \RowCount\(), 2, "vmove $vr12,$vr13" + addi.d $a2, $a2, 2*8 # advance matrix C by 2 columns + +.LOutputPartial1xNBlock\@: + bnez $t5, .LSkipAccumulateOutput1xN\@ # ZeroMode? + + EmitIfCountGE \RowCount\(), 1, "fld.d $f15, $a2, 0" + EmitIfCountGE \RowCount\(), 1, "fadd.d $f15, $f15, $f8" + EmitIfCountGE \RowCount\(), 2, "fldx.d $f16, $a2, $t6" + EmitIfCountGE \RowCount\(), 2, "fadd.d $f16, $f16, $f12" + +.LSkipAccumulateOutput1xN\@: + EmitIfCountGE \RowCount\(), 1, "fst.d $f15, $a2, 0" + EmitIfCountGE \RowCount\(), 2, "fstx.d $f16, $a2, $t6" +.ifb \Fallthrough\() + b .LExitKernel +.endif + + .endm + +// +// Generate the GEMM kernel. +// + +FgemmKernelLsxFunction MlasGemmDoubleKernelLSX + + .end diff --git a/onnxruntime/core/mlas/lib/loongarch64/FgemmKernelCommon.h b/onnxruntime/core/mlas/lib/loongarch64/FgemmKernelCommon.h new file mode 100644 index 0000000000000..777a592590ec4 --- /dev/null +++ b/onnxruntime/core/mlas/lib/loongarch64/FgemmKernelCommon.h @@ -0,0 +1,100 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + FgemmKernelCommon.h + +Abstract: + + This module contains common kernel macros and structures for the floating + point matrix/matrix multiply operation (SGEMM and DGEMM). + +--*/ + +// +// Define the typed instruction template. +// + +#define FGEMM_TYPED_INSTRUCTION(Untyped, Typed) \ + .macro Untyped Operand:vararg; Typed \Operand\(); .endm; + +/*++ + +Macro Description: + + This macro generates code to execute the block compute macro multiple + times and advancing the matrix A and matrix B data pointers. + +Arguments: + + ComputeBlock - Supplies the macro to compute a single block. + + RowCount - Supplies the number of rows to process. + + AdvanceMatrixAPlusRows - Supplies a non-zero value if the data pointer + in rbx should also be advanced as part of the loop. + +Implicit Arguments: + + a0 - Supplies the address into the matrix A data. + + t7 - Supplies the address into the matrix A data plus 3 rows. + + a1 - Supplies the address into the matrix B data. + + a3 - Supplies the number of columns from matrix A and the number of rows + from matrix B to iterate over. + + vr4-vr15 - Supplies the block accumulators. 
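+
+    In outline, the generated loop consumes CountK in blocks of four and then one at
+    a time (illustrative pseudo-code only, not literal kernel code):
+
+        while (k >= 4) { ComputeBlock x4; A += 4 elements; B += 4*64 bytes; k -= 4; }
+        while (k >  0) { ComputeBlock;    A += 1 element;  B += 64 bytes;   k -= 1; }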
+ +--*/ + + .macro ComputeBlockLoop ComputeBlock, RowCount, AdvanceMatrixAPlusRows + + move $t8, $a3 # reload CountK + li.d $s0, 4 + blt $t8, $s0, .LProcessRemainingBlocks\@ + +.LComputeBlockBy4Loop\@: + \ComputeBlock\() \RowCount\(), 0, LFgemmElementSize*0, 64*4 + \ComputeBlock\() \RowCount\(), 2*32, LFgemmElementSize*1, 64*4 + addi.d $a1, $a1, 2*2*32 # advance matrix B by 128 bytes + \ComputeBlock\() \RowCount\(), 0, LFgemmElementSize*2, 64*4 + \ComputeBlock\() \RowCount\(), 2*32, LFgemmElementSize*3, 64*4 + addi.d $a1, $a1, 2*2*32 # advance matrix B by 128 bytes + addi.d $a0, $a0, 4*LFgemmElementSize # advance matrix A by 4 elements +.if \RowCount\() > 3 + addi.d $t7, $t7, 4*LFgemmElementSize # advance matrix A plus rows by 4 elements +.if \RowCount\() == 12 + addi.d $t3, $t3, 4*LFgemmElementSize + addi.d $t4,, $t4, 4*LFgemmElementSize +.endif +.endif + addi.d $t8, $t8, -4 + li.d $s0, 4 + bge $t8, $s0, .LComputeBlockBy4Loop\@ + +.LProcessRemainingBlocks\@: + beqz $t8, .LOutputBlock\@ + +.LComputeBlockBy1Loop\@: + \ComputeBlock\() \RowCount\(), 0, 0 + addi.d $a1, $a1, 2*32 # advance matrix B by 64 bytes + addi.d $a0, $a0, LFgemmElementSize # advance matrix A by 1 element +.if \RowCount\() > 3 + addi.d $t7, $t7, LFgemmElementSize # advance matrix A plus rows by 1 element +.if \RowCount\() == 12 + addi.d $t3, $t3, LFgemmElementSize + addi.d $t4, $t4, LFgemmElementSize +.endif +.endif + addi.d $t8, $t8, -1 + bnez $t8, .LComputeBlockBy1Loop\@ + +.LOutputBlock\@: + + .endm diff --git a/onnxruntime/core/mlas/lib/loongarch64/FgemmKernelLasxCommon.h b/onnxruntime/core/mlas/lib/loongarch64/FgemmKernelLasxCommon.h new file mode 100644 index 0000000000000..b96db848617bf --- /dev/null +++ b/onnxruntime/core/mlas/lib/loongarch64/FgemmKernelLasxCommon.h @@ -0,0 +1,546 @@ + +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + FgemmKernelLasxCommon.h + +Abstract: + + This module implements the kernels for the floating point matrix/matrix + multiply operation (SGEMM and DGEMM). + + This implementation uses LASX instructions. + +--*/ + +/*++ + +Macro Description: + + This macro multiplies and accumulates for 2 YMMWORDs by N rows of the output + matrix. + +Arguments: + + RowCount - Supplies the number of rows to process. + + VectorOffset - Supplies the byte offset from matrix B to fetch elements. + + BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. + + PrefetchOffset - Optionally supplies the byte offset from matrix B to + prefetch elements. + +Implicit Arguments: + + a0 - Supplies the address into the matrix A data. + + t7 - Supplies the address into the matrix A data plus 2 rows. + + a1 - Supplies the address into the matrix B data. + + t0 - Supplies the length in bytes of a row from matrix A. + + xr8-xr15 - Supplies the block accumulators. 
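+
+    Each invocation performs a single k step: the matrix A element at BroadcastOffset
+    is broadcast for every row being processed and multiplied against two 32-byte
+    vectors of matrix B, accumulating 2 * LFgemmYmmElementCount output columns per row.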
+ +--*/ + + .macro ComputeBlockLasxBy16 RowCount, VectorOffset, BroadcastOffset, PrefetchOffset + +.if \RowCount\() == 1 + xvldrepl.w $xr3, $a0, \BroadcastOffset\() + xvld $xr4, $a1, \VectorOffset\() + xvfmadd $xr8, $xr4, $xr3, $xr8 + xvld $xr5, $a1, \VectorOffset\()+32 + xvfmadd $xr9, $xr5, $xr3, $xr9 +.else + xvld $xr0, $a1, \VectorOffset\() + xvld $xr1, $a1, \VectorOffset\()+32 + EmitIfCountGE \RowCount\(), 1, "xvldrepl $xr3,$a0, \BroadcastOffset\()" + EmitIfCountGE \RowCount\(), 1, "xvfmadd $xr8, $xr3, $xr0, $xr8" + EmitIfCountGE \RowCount\(), 1, "xvfmadd $xr9, $xr3, $xr1, $xr9" + EmitIfCountGE \RowCount\(), 2, "add.d $s0,$a0, $t0" + EmitIfCountGE \RowCount\(), 2, "xvldrepl $xr3,$s0, \BroadcastOffset\()" + EmitIfCountGE \RowCount\(), 2, "xvfmadd $xr10, $xr3, $xr0, $xr10" + EmitIfCountGE \RowCount\(), 2, "xvfmadd $xr11, $xr3, $xr1, $xr11" + + EmitIfCountGE \RowCount\(), 3, "xvldrepl $xr3,$t7, \BroadcastOffset\()" + EmitIfCountGE \RowCount\(), 3, "xvfmadd $xr12, $xr3, $xr0, $xr12" + EmitIfCountGE \RowCount\(), 3, "xvfmadd $xr13, $xr3, $xr1, $xr13" + EmitIfCountGE \RowCount\(), 4, "add.d $s0,$t7, $t0" + EmitIfCountGE \RowCount\(), 4, "xvldrepl $xr3,$s0, \BroadcastOffset\()" + EmitIfCountGE \RowCount\(), 4, "xvfmadd $xr14, $xr3, $xr0, $xr14" + EmitIfCountGE \RowCount\(), 4, "xvfmadd $xr15, $xr3, $xr1, $xr15" +.endif + + .endm + +/*++ + +Macro Description: + + This macro multiplies and accumulates for 1 YMMWORD by N rows of the output + matrix. + +Arguments: + + RowCount - Supplies the number of rows to process. + + VectorOffset - Supplies the byte offset from matrix B to fetch elements. + + BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. + + PrefetchOffset - Optionally supplies the byte offset from matrix B to + prefetch elements. + +Implicit Arguments: + + a0 - Supplies the address into the matrix A data. + + t7 - Supplies the address into the matrix A data plus 2 rows. + + a1 - Supplies the address into the matrix B data. + + t0 - Supplies the length in bytes of a row from matrix A. + + xr8-xr15 - Supplies the block accumulators. + +--*/ + + .macro ComputeBlockLasxBy8 RowCount, VectorOffset, BroadcastOffset, PrefetchOffset + +.if \RowCount\() == 1 + xvldrepl.w $xr3, $a0, \BroadcastOffset\() + xvld $xr5, $a1, \VectorOffset\() + xvfmadd.s $xr9, $xr5, $xr3, $xr9 +.else + xvld $xr0, $a1, \VectorOffset\() + EmitIfCountGE \RowCount\(), 1, "xvldrepl $xr3, $a0, \BroadcastOffset\()" + EmitIfCountGE \RowCount\(), 1, "xvfmadd $xr9, $xr3, $xr0, $xr9" + + EmitIfCountGE \RowCount\(), 2, "add.d $s0, $a0, $t0" + EmitIfCountGE \RowCount\(), 2, "xvldrepl $xr3, $s0, \BroadcastOffset\()" + EmitIfCountGE \RowCount\(), 2, "xvfmadd $xr11, $xr3, $xr0, $xr11" + EmitIfCountGE \RowCount\(), 3, "xvldrepl $xr3, $t7, \BroadcastOffset\()" + EmitIfCountGE \RowCount\(), 3, "xvfmadd $xr13, $xr3, $xr0, $xr13" + EmitIfCountGE \RowCount\(), 4, "add.d $s0, $t7, $t0" + EmitIfCountGE \RowCount\(), 4, "xvldrepl $xr3, $s0, \BroadcastOffset\()" + EmitIfCountGE \RowCount\(), 4, "xvfmadd $xr15, $xr3, $xr0, $xr15" +.endif + + .endm + +/*++ + +Macro Description: + + This macro generates code to execute the block compute macro multiple + times and advancing the matrix A and matrix B data pointers. + +Arguments: + + ComputeBlock - Supplies the macro to compute a single block. + + RowCount - Supplies the number of rows to process. + +Implicit Arguments: + + a0 - Supplies the address into the matrix A data. + + a1 - Supplies the address into the matrix B data. 
+ + a3 - Supplies the number of columns from matrix A and the number of rows + from matrix B to iterate over. + + t0 - Supplies the length in bytes of a row from matrix A. + + vr4-vr15 - Supplies the block accumulators. + +--*/ + + .macro ComputeBlockLasxLoop ComputeBlock, RowCount + +.if \RowCount\() > 2 + # compute matrix A plus 2 rows + slli.d $s0, $t0, 1 + add.d $t7, $a0, $s0 +.endif + ComputeBlockLoop \ComputeBlock\(), \RowCount\(), \RowCount\() > 2 +.if \RowCount\() > 2 + # compute matrix C plus 2 rows + slli.d $s0, $t6, 1 + add.d $t7, $a2, $s0 +.endif + + .endm + + .macro store_n src, num, dst + move $s2, \num\() + beqz $s2, .Lstore_exit\@ + xvstelm.w \src\(), \dst\(), 0, 0 + addi.d $s2, $s2, -1 + beqz $s2, .Lstore_exit\@ + + xvstelm.w \src\(), \dst\(), 4, 1 + addi.d $s2, $s2, -1 + beqz $s2, .Lstore_exit\@ + + xvstelm.w \src\(), \dst\(), 8, 2 + addi.d $s2, $s2, -1 + beqz $s2, .Lstore_exit\@ + + xvstelm.w \src\(), \dst\(), 12, 3 + addi.d $s2, $s2, -1 + beqz $s2, .Lstore_exit\@ + + xvstelm.w \src\(), \dst\(), 16, 4 + addi.d $s2, $s2, -1 + beqz $s2, .Lstore_exit\@ + + xvstelm.w \src\(), \dst\(), 20, 5 + addi.d $s2, $s2, -1 + beqz $s2, .Lstore_exit\@ + + xvstelm.w \src\(), \dst\(), 24, 6 + addi.d $s2, $s2, -1 + beqz $s2, .Lstore_exit\@ + +.Lstore_exit\@: + .endm +/*++ + +Macro Description: + + This macro generates code to compute matrix multiplication for a fixed set + of rows. + +Arguments: + + RowCount - Supplies the number of rows to process. + + Fallthrough - Supplies a non-blank value if the macro may fall through to + the ExitKernel label. + +Implicit Arguments: + + a0 - Supplies the address of matrix A. + + a1 - Supplies the address of matrix B. + + t1 - Supplies the address of matrix A. + + a5 - Supplies the number of columns from matrix B and matrix C to iterate + over. + + a2 - Supplies the address of matrix C. + + a3 - Supplies the number of columns from matrix A and the number of rows + from matrix B to iterate over. + + t0 - Supplies the length in bytes of a row from matrix A. + + t6 - Supplies the length in bytes of a row from matrix C. + + t5 - Stores the ZeroMode argument from the stack frame. + +--*/ + + .macro ProcessCountM RowCount, Fallthrough + + ori $s1, $r0, LFgemmYmmElementCount + bgeu $s1, $a5, .LProcessRemainingCountN\@ + +.LProcessNextColumnLoop2xN\@: + EmitIfCountGE \RowCount\(), 1, "xvxor.v $xr8, $xr8, $xr8" + EmitIfCountGE \RowCount\(), 1, "xvxor.v $xr9, $xr9, $xr9" + EmitIfCountGE \RowCount\(), 2, "xvxor.v $xr10, $xr10, $xr10" + EmitIfCountGE \RowCount\(), 2, "xvxor.v $xr11, $xr11, $xr11" + EmitIfCountGE \RowCount\(), 3, "xvxor.v $xr12, $xr12, $xr12" + EmitIfCountGE \RowCount\(), 3, "xvxor.v $xr13, $xr13, $xr13" + EmitIfCountGE \RowCount\(), 4, "xvxor.v $xr14, $xr14, $xr14" + EmitIfCountGE \RowCount\(), 4, "xvxor.v $xr15, $xr15, $xr15" + + ComputeBlockLasxLoop ComputeBlockLasxBy16, \RowCount\() + EmitIfCountGE \RowCount\(), 1, "xvfmul $xr8, $xr8, $xr2" + EmitIfCountGE \RowCount\(), 1, "xvfmul $xr9, $xr9, $xr2" + EmitIfCountGE \RowCount\(), 2, "xvfmul $xr10, $xr10, $xr2" + EmitIfCountGE \RowCount\(), 2, "xvfmul $xr11, $xr11, $xr2" + EmitIfCountGE \RowCount\(), 3, "xvfmul $xr12, $xr12, $xr2" + EmitIfCountGE \RowCount\(), 3, "xvfmul $xr13, $xr13, $xr2" + EmitIfCountGE \RowCount\(), 4, "xvfmul $xr14, $xr14, $xr2" + EmitIfCountGE \RowCount\(), 4, "xvfmul $xr15, $xr15, $xr2" + + sub.d $a5, $a5, $s1 + sub.d $a5, $a5, $s1 + blt $a5, $zero, .LOutputMasked2xNBlock\@ + andi $s0, $t5, 0xff # ZeroMode? 
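+                                        # nonzero ZeroMode: store the scaled products
+                                        # without accumulating the existing C values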
+ bnez $s0, .LStore2xNBlock\@ + EmitIfCountGE \RowCount\(), 1, "xvld $xr16, $a2, 0" + EmitIfCountGE \RowCount\(), 1, "xvfadd $xr8, $xr8, $xr16" + EmitIfCountGE \RowCount\(), 1, "xvld $xr16, $a2, 0x20" + EmitIfCountGE \RowCount\(), 1, "xvfadd $xr9, $xr9, $xr16" + EmitIfCountGE \RowCount\(), 2, "xvldx $xr16, $a2, $t6" + EmitIfCountGE \RowCount\(), 2, "xvfadd $xr10, $xr10, $xr16" + EmitIfCountGE \RowCount\(), 2, "add.d $s0, $a2, $t6" + EmitIfCountGE \RowCount\(), 2, "xvld $xr16, $s0, 0x20" + EmitIfCountGE \RowCount\(), 2, "xvfadd $xr11, $xr11, $xr16" + EmitIfCountGE \RowCount\(), 3, "xvld $xr16, $t7, 0" + EmitIfCountGE \RowCount\(), 3, "xvfadd $xr12, $xr12, $xr16" + EmitIfCountGE \RowCount\(), 3, "xvld $xr16, $t7, 0x20" + EmitIfCountGE \RowCount\(), 3, "xvfadd $xr13, $xr13, $xr16" + EmitIfCountGE \RowCount\(), 4, "xvldx $xr16, $t7, $t6" + EmitIfCountGE \RowCount\(), 4, "xvfadd $xr14, $xr14, $xr16" + EmitIfCountGE \RowCount\(), 4, "add.d $s0, $t7, $t6" + EmitIfCountGE \RowCount\(), 4, "xvld $xr16, $s0, 0x20" + EmitIfCountGE \RowCount\(), 4, "xvfadd $xr15, $xr15, $xr16" + +.LStore2xNBlock\@: + EmitIfCountGE \RowCount\(), 1, "xvst $xr8, $a2, 0" + EmitIfCountGE \RowCount\(), 1, "xvst $xr9, $a2, 0x20" + EmitIfCountGE \RowCount\(), 2, "xvstx $xr10, $a2, $t6" + EmitIfCountGE \RowCount\(), 2, "add.d $s0, $a2, $t6" + EmitIfCountGE \RowCount\(), 2, "xvst $xr11, $s0, 0x20" + EmitIfCountGE \RowCount\(), 3, "xvst $xr12, $t7, 0" + EmitIfCountGE \RowCount\(), 3, "xvst $xr13, $t7, 0x20" + EmitIfCountGE \RowCount\(), 4, "xvstx $xr14, $t7, $t6" + EmitIfCountGE \RowCount\(), 4, "add.d $s0, $t7, $t6" + EmitIfCountGE \RowCount\(), 4, "xvst $xr15, $s0, 0x20" + + addi.d $a2, $a2, 0x40 # advance matrix C by 2 XRWORDs + move $a0, $t1 # reload matrix A + bltu $s1, $a5, .LProcessNextColumnLoop2xN\@ + beqz $a5, .LExitKernel + +.LProcessRemainingCountN\@: + EmitIfCountGE \RowCount\(), 1, "xvxor.v $xr9, $xr9, $xr9" + EmitIfCountGE \RowCount\(), 2, "xvxor.v $xr11, $xr11, $xr11" + EmitIfCountGE \RowCount\(), 3, "xvxor.v $xr13, $xr13, $xr13" + EmitIfCountGE \RowCount\(), 4, "xvxor.v $xr15, $xr15, $xr15" + + + ComputeBlockLasxLoop ComputeBlockLasxBy8, \RowCount\() + EmitIfCountGE \RowCount\(), 1, "xvfmul $xr9, $xr9, $xr2" + EmitIfCountGE \RowCount\(), 2, "xvfmul $xr11, $xr11, $xr2" + EmitIfCountGE \RowCount\(), 3, "xvfmul $xr13, $xr13, $xr2" + EmitIfCountGE \RowCount\(), 4, "xvfmul $xr15, $xr15, $xr2" + bltu $a5, $s1, .LOutputMasked1xNBlock\@ + andi $s0, $t5, 0xff # ZeroMode? + bnez $s0, .LStore1xNBlock\@ + EmitIfCountGE \RowCount\(), 1, "xvld $xr16, $a2, 0" + EmitIfCountGE \RowCount\(), 1, "xvfadd $xr9, $xr9, $xr16" + EmitIfCountGE \RowCount\(), 2, "xvldx $xr16, $a2, $t6" + EmitIfCountGE \RowCount\(), 2, "xvfadd $xr11, $xr11, $xr16" + EmitIfCountGE \RowCount\(), 3, "xvld $xr16, $t7, 0" + EmitIfCountGE \RowCount\(), 3, "xvfadd $xr13, $xr13, $xr16" + EmitIfCountGE \RowCount\(), 4, "xvldx $xr16, $t7, $t6" + EmitIfCountGE \RowCount\(), 4, "xvfadd $xr15, $xr15, $xr16" + +.LStore1xNBlock\@: + EmitIfCountGE \RowCount\(), 1, "xvst $xr9, $a2, 0" + EmitIfCountGE \RowCount\(), 2, "xvstx $xr11, $a2, $t6" + EmitIfCountGE \RowCount\(), 3, "xvst $xr13, $t7, 0" + EmitIfCountGE \RowCount\(), 4, "xvstx $xr15, $t7, $t6" + b .LExitKernel + +.LOutputMasked2xNBlock\@: + andi $s0, $t5, 0xff # ZeroMode? 
+ bnez $s0, .LStoreMasked2xNBlock\@ + EmitIfCountGE \RowCount\(), 1, "xvld $xr16, $a2, 0" + EmitIfCountGE \RowCount\(), 1, "xvfadd $xr8, $xr8, $xr16" + EmitIfCountGE \RowCount\(), 2, "xvldx $xr16, $a2, $t6" + EmitIfCountGE \RowCount\(), 2, "xvfadd $xr10, $xr10, $xr16" + EmitIfCountGE \RowCount\(), 3, "xvld $xr16, $t7, 0" + EmitIfCountGE \RowCount\(), 3, "xvfadd $xr12, $xr12, $xr16" + EmitIfCountGE \RowCount\(), 4, "xvldx $xr16, $t7, $t6" + EmitIfCountGE \RowCount\(), 4, "xvfadd $xr14, $xr14, $xr16" + +.LStoreMasked2xNBlock\@: + EmitIfCountGE \RowCount\(), 1, "xvst $xr8, $a2, 0" + EmitIfCountGE \RowCount\(), 2, "xvstx $xr10, $a2, $t6" + EmitIfCountGE \RowCount\(), 3, "xvst $xr12, $t7, 0" + EmitIfCountGE \RowCount\(), 4, "xvstx $xr14, $t7, $t6" + addi.d $a2, $a2, 0x20 # advance matrix C by YMMWORD +.if \RowCount\() > 2 + addi.d $t7, $t7, 0x20 # advance matrix C plus 2 rows by YMMWORD + +.endif + addi.d $a5, $a5, LFgemmYmmElementCount # correct for over-subtract above + + +.LOutputMasked1xNBlock\@: + +.if \RowCount\() > 2 + slli.d $s0, $t0, 1 + add.d $t7, $a0, $s0 +.endif + +.if \RowCount\() == 1 +.else +.endif + +.if \RowCount\() > 2 + slli.d $s0, $t6, 1 + add.d $t7, $a2, $s0 +.endif + + sub.d $a5, $zero, $a5 + la.global $a0, MlasMaskMoveTableLasx + ori $s0, $r0, LFgemmElementSize + mul.d $s0, $a5, $s0 + addi.d $s0, $s0, 8*4 + xvldx $xr0, $a0, $s0 + andi $s0, $t5, 0xff + + sub.d $a5, $zero, $a5 + + bnez $s0, .LStoreMasked1xNBlock\@ + EmitIfCountGE \RowCount\(), 1, "xvld $xr16, $a2, 0" + EmitIfCountGE \RowCount\(), 1, "xvand.v $xr8, $xr16, $xr0" + EmitIfCountGE \RowCount\(), 2, "xvldx $xr16, $a2, $t6" + EmitIfCountGE \RowCount\(), 2, "xvand.v $xr10, $xr16, $xr0" + EmitIfCountGE \RowCount\(), 3, "xvld $xr16, $t7, 0" + EmitIfCountGE \RowCount\(), 3, "xvand.v $xr12, $xr16, $xr0" + EmitIfCountGE \RowCount\(), 4, "xvldx $xr16, $t7, $t6" + EmitIfCountGE \RowCount\(), 4, "xvand.v $xr14, $xr16, $xr0" + + EmitIfCountGE \RowCount\(), 1, "xvfadd $xr9, $xr9, $xr8" + EmitIfCountGE \RowCount\(), 2, "xvfadd $xr11, $xr11, $xr10" + EmitIfCountGE \RowCount\(), 3, "xvfadd $xr13, $xr13, $xr12" + EmitIfCountGE \RowCount\(), 4, "xvfadd $xr15, $xr15, $xr14" +.LStoreMasked1xNBlock\@: + EmitIfCountGE \RowCount\(), 1, "store_n $xr9, $a5, $a2" + + add.d $s3, $a2, $t6 + EmitIfCountGE \RowCount\(), 2, "store_n $xr11, $a5, $s3" + + EmitIfCountGE \RowCount\(), 3, "store_n $xr13, $a5, $t7" + + add.d $s3, $t7, $t6 + EmitIfCountGE \RowCount\(), 4, "store_n $xr15, $a5, $s3" + sub.d $a5, $zero, $a5 +.ifb \Fallthrough\() + b .LExitKernel +.endif + + .endm + +/*++ + +Macro Description: + + This macro generates the inner kernel to compute matrix multiplication. + +Arguments: + + FunctionName - Supplies the name for the generated function. + +--*/ + + .macro FgemmKernelLasxFunction FunctionName + +/*++ + +Routine Description: + + This routine is an inner kernel to compute matrix multiplication for a + set of rows. + +Arguments: + + A a0 - Supplies the address of matrix A. + + B a1 - Supplies the address of matrix B. The matrix data has been packed + using MlasSgemmCopyPackB or MlasSgemmTransposePackB. + + C a2 - Supplies the address of matrix C. + + CountK a3 - Supplies the number of columns from matrix A and the number + of rows from matrix B to iterate over. + + CountM a4 - Supplies the maximum number of rows that can be processed for + matrix A and matrix C. The actual number of rows handled for this + invocation depends on the kernel implementation. 
+ + CountN a5 - Supplies the number of columns from matrix B and matrix C to + iterate over. + + lda a6 - Supplies the first dimension of matrix A. + + ldc a7 - Supplies the first dimension of matrix C. + + Alpha f0 - Supplies the scalar alpha multiplier (see GEMM definition). + + ZeroMode (sp + 0)- Supplies true if the output matrix must be zero initialized, + else false if the output matrix is accumulated into. + +Return Value: + + Returns the number of rows handled. + +--*/ + + FUNCTION_ENTRY \FunctionName\() + + addi.d $sp, $sp, -64 + st.d $ra, $sp, 56 + st.d $s0, $sp, 0*8 + st.d $s1, $sp, 1*8 + fst.s $f0, $sp, 2*8 + fst.d $f16, $sp,3*8 + st.d $s2, $sp, 4*8 + st.d $s3, $sp, 5*8 + + move $t1, $a0 + slli.d $t0, $a6, 2 # convert lda to bytes + slli.d $t6, $a7, 2 # convert ldc to bytes + ld.d $t5, $sp, 64 # get zeromode + fst.s $f0, $sp, 2*8 + xvldrepl.w $xr2, $sp, 0x10 + +// +// Process 4 rows of the matrices. +// + + ori $s0, $zero, 4 + bltu $a4, $s0, .LProcessCountMLessThan4 + li.d $a4, 4 # return 4 rows handled + ProcessCountM 4, Fallthrough + +// +// Restore non-volatile registers and return. +// + +.LExitKernel: + bstrpick.d $a0, $a4, 31, 0 + ld.d $s0, $sp, 0 + ld.d $s1, $sp, 8 + fld.d $f16, $sp,3*8 + ld.d $s2, $sp, 4*8 + ld.d $s3, $sp, 5*8 + ld.d $ra, $sp, 7*8 + addi.d $sp, $sp, 64 + jr $ra + +// +// Process 2 rows of the matrices. +// + +.LProcessCountMLessThan4: + ori $s0, $r0, 2 + bltu $a4, $s0, .LProcessCountMLessThan2 + li.d $a4, 2 # return 2 rows handled + ProcessCountM 2 + +// +// Process 1 row of the matrices. +// + +.LProcessCountMLessThan2: + ProcessCountM 1 + + .endm diff --git a/onnxruntime/core/mlas/lib/loongarch64/FgemmKernelLsxCommon.h b/onnxruntime/core/mlas/lib/loongarch64/FgemmKernelLsxCommon.h new file mode 100644 index 0000000000000..0333af792ba70 --- /dev/null +++ b/onnxruntime/core/mlas/lib/loongarch64/FgemmKernelLsxCommon.h @@ -0,0 +1,170 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + FgemmKernelLsxCommon.h + +Abstract: + + This module implements the kernels for the floating point matrix/matrix + multiply operation (SGEMM and DGEMM). + + This implementation uses Lsx instructions. + +--*/ + +#include "FgemmKernelCommon.h" +/*++ + +Macro Description: + + This stores the block accumulators to the output matrix with an optional + accumulation of the existing contents of the output matrix. + +Arguments: + + RowCount - Supplies the number of rows to process. + + VectorCount - Supplies the number of vector columns to process. + +Implicit Arguments: + + t5 - Supplies the length in bytes of a row from matrix C. + + a2 - Supplies the address of matrix C. + + s3 - Stores the ZeroMode argument from the stack frame. + + vr8-vr15 - Supplies the block accumulators. + +--*/ + + .macro AccumulateAndStoreBlock RowCount, VectorCount + + and $s0, $t5,$t5 # ZeroMode? 
+ bnez $s0 , .LSkipAccumulateOutput\@ + EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 1, "vld $vr0, $a2, 0" + EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 2, "vld $vr1, $a2, 16" + EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 3, "vld $vr2, $a2, 32" + EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 4, "vld $vr3, $a2, 48" + EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 1, "vldx $vr4, $a2, $t6" + EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 2, "addi.d $s0, $t6, 16" + EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 2, "vldx $vr5, $a2, $s0" + EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 3, "addi.d $s0, $t6, 32" + EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 3, "vldx $vr6, $a2, $s0" + EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 4, "addi.d $s0, $t6, 48" + EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 4, "vldx $vr7, $a2, $s0" + EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 1, "vfadd $vr8, $vr8, $vr0" + EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 2, "vfadd $vr9, $vr9, $vr1" + EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 3, "vfadd $vr10,$vr10,$vr2" + EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 4, "vfadd $vr11,$vr11,$vr3" + EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 1, "vfadd $vr12,$vr12,$vr4" + EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 2, "vfadd $vr13,$vr13,$vr5" + EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 3, "vfadd $vr14,$vr14,$vr6" + EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 4, "vfadd $vr15,$vr15,$vr7" + +.LSkipAccumulateOutput\@: + EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 1, "vst $vr8, $a2, 0" + EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 2, "vst $vr9, $a2, 16" + EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 3, "vst $vr10, $a2, 32" + EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 4, "vst $vr11, $a2, 48" + EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 1, "vstx $vr12, $a2, $t6" + EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 2, "addi.d $s0, $t6, 16" + EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 2, "vstx $vr13, $a2, $s0" + EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 3, "addi.d $s0, $t6, 32" + EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 3, "vstx $vr14, $a2, $s0" + EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 4, "addi.d $s0, $t6, 48" + EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 4, "vstx $vr15, $a2, $s0" + + .endm +/*++ + +Macro Description: + + This macro generates the inner kernel to compute matrix multiplication. + +Arguments: + + FunctionName - Supplies the name for the generated function. + +--*/ + + .macro FgemmKernelLsxFunction FunctionName + +/*++ + +Routine Description: + + This routine is an inner kernel to compute matrix multiplication for a + set of rows. + +Arguments: + + A (a0) - Supplies the address of matrix A. + + B (a1) - Supplies the address of matrix B. The matrix data has been packed + using MlasSgemmCopyPackB or MlasSgemmTransposePackB. + + C (a2) - Supplies the address of matrix C. + + CountK (a3) - Supplies the number of columns from matrix A and the number + of rows from matrix B to iterate over. + + CountM (a4) - Supplies the maximum number of rows that can be processed for + matrix A and matrix C. The actual number of rows handled for this + invocation depends on the kernel implementation. + + CountN (a5) - Supplies the number of columns from matrix B and matrix C to + iterate over. + + lda (a6) Supplies the first dimension of matrix A. + + ldc (a7) Supplies the first dimension of matrix C. 
+ + Alpha (f0) - Supplies the scalar alpha multiplier (see GEMM definition). + + ZeroMode (sp 0) - Supplies true if the output matrix must be zero initialized, + else false if the output matrix is accumulated into. + +Return Value: + + Returns the number of rows handled. + +--*/ + +FUNCTION_ENTRY \FunctionName\() + addi.d $sp, $sp, -64 + st.d $t5, $sp, 0 + st.d $s0, $sp, 1*8 + st.d $s1, $sp, 2*8 + st.d $s2, $sp, 3*8 + st.d $s3, $sp, 4*8 + move $t1, $a0 + slli.d $t0, $a6, 2 //convert lda to bytes + slli.d $t6, $a7, 2 //convert ldc to bytes + ld.d $t5, $sp, 64 + fmov.s $f24, $f0 //f0 destroyed by lsx + + li.d $s0, 2 + blt $a4, $s0, .LProcessCountM1 + + li.d $a4, 2 + ProcessCountM 2, Fallthrough + +.LExitKernel: + ld.d $t5, $sp, 0 + ld.d $s0, $sp, 1*8 + ld.d $s1, $sp, 2*8 + ld.d $s2, $sp, 3*8 + ld.d $s3, $sp, 4*8 + addi.d $sp, $sp, 64 + move $a0, $a4 + jr $ra + +.LProcessCountM1: + ProcessCountM 1 + .endm diff --git a/onnxruntime/core/mlas/lib/loongarch64/SconvKernelLasx.S b/onnxruntime/core/mlas/lib/loongarch64/SconvKernelLasx.S new file mode 100644 index 0000000000000..e03503521912a --- /dev/null +++ b/onnxruntime/core/mlas/lib/loongarch64/SconvKernelLasx.S @@ -0,0 +1,412 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + SconvKernelLasx.S + +Abstract: + + This module implements the kernels for the single precision convolution + operation. + + This implementation uses Lasx instructions. + +--*/ + +#include "asmmacro.h" +#include "SconvKernelLasxCommon.h" + + .text + +/*++ + +Macro Description: + + This macro multiplies and accumulates for FilterCount by OutputCount block + of the output buffer. + +Arguments: + + KernelType - Supplies the type of kernel to be generated. + + FilterCount - Supplies the number of rows from the filter to process. + + OutputCount - Supplies the number of output blocks to produce. + + VectorOffset - Supplies the byte offset from the filter buffer to fetch + elements. + + BroadcastOffset - Supplies the byte offset from the input buffer to fetch + elements. + +Implicit Arguments: + + a3 - Supplies the address of the input buffer. + + a2 - Supplies the address of the filter buffer. + + a1 - Supplies the FilterStride parameter (see function description). + + t7 - Supplies the address of the filter buffer plus 2 * FilterStride. + + a5 - Supplies the StrideWidth parameter (see function description). + + xr0-xr7 - Supplies the block accumulators. 
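+
+    For the non-depthwise kernel types each invocation broadcasts one input value per
+    output block and multiplies it against up to FilterCount 8-float filter vectors,
+    accumulating into xr0-xr7; the Depthwise variant instead multiplies matching
+    8-channel input and filter blocks element-wise.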
+ +--*/ + + .macro ComputeBlock KernelType, FilterCount, OutputCount, VectorOffset, BroadcastOffset + +.ifeqs "\KernelType\()","Depthwise" + xvld $xr12, $a2, 0 + EmitIfCountGE \OutputCount\(), 1, "xvld $xr8, $a3, 0" + EmitIfCountGE \OutputCount\(), 1, "xvfmadd.s $xr0, $xr8, $xr12, $xr0" + EmitIfCountGE \OutputCount\(), 2, "xvldx $xr9, $a3, $a5" + EmitIfCountGE \OutputCount\(), 2, "xvfmadd.s $xr4, $xr9, $xr12, $xr4" + +.else + EmitIfCountGE \OutputCount\(), 1, "xvldrepl.w $xr13, $a3, \BroadcastOffset\()" + EmitIfCountGE \OutputCount\(), 2, "add.d $s0, $a3, $a5" + EmitIfCountGE \OutputCount\(), 2, "xvldrepl.w $xr14, $s0, \BroadcastOffset\()" +.if \OutputCount\() == 1 + EmitIfCountGE \FilterCount\(), 1, "xvld $xr8, $a2, \VectorOffset\()" + EmitIfCountGE \FilterCount\(), 1, "xvfmadd.s $xr0, $xr8, $xr13, $xr0" + EmitIfCountGE \FilterCount\(), 2, "add.d $s0, $a2, $a1" + EmitIfCountGE \FilterCount\(), 2, "xvld $xr9, $s0, \VectorOffset\()" + EmitIfCountGE \FilterCount\(), 2, "xvfmadd.s $xr1, $xr9, $xr13, $xr1" + EmitIfCountGE \FilterCount\(), 3, "xvld $xr10, $t7, \VectorOffset\()" + EmitIfCountGE \FilterCount\(), 3, "xvfmadd.s $xr2, $xr10, $xr13, $xr2" + EmitIfCountGE \FilterCount\(), 4, "add.d $s0, $t7, $a1" + EmitIfCountGE \FilterCount\(), 4, "xvld $xr11, $s0, \VectorOffset\()" + EmitIfCountGE \FilterCount\(), 4, "xvfmadd.s $xr3, $xr11, $xr13, $xr3" +.else + EmitIfCountGE \FilterCount\(), 1, "xvld $xr12, $a2, \VectorOffset\()" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "xvfmadd.s $xr0, $xr12, $xr13, $xr0" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 2, "xvfmadd.s $xr4, $xr12, $xr14, $xr4" + EmitIfCountGE \FilterCount\(), 2, "add.d $s0, $a2, $a1" + EmitIfCountGE \FilterCount\(), 2, "xvld $xr12, $s0, \VectorOffset\()" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "xvfmadd.s $xr1, $xr13, $xr12, $xr1" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 2, "xvfmadd.s $xr5, $xr14, $xr12, $xr5" + EmitIfCountGE \FilterCount\(), 3, "xvld $xr12, $t7, \VectorOffset\()" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "xvfmadd.s $xr2, $xr13, $xr12, $xr2" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 2, "xvfmadd.s $xr6, $xr14, $xr12, $xr6" + EmitIfCountGE \FilterCount\(), 4, "add.d $s0, $t7, $a1" + EmitIfCountGE \FilterCount\(), 4, "xvld $xr12, $s0, \VectorOffset\()" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "xvfmadd.s $xr3, $xr13, $xr12, $xr3" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 2, "xvfmadd.s $xr7, $xr14, $xr12, $xr7" +.endif +.endif + + .endm + +/*++ + +Macro Description: + + This macro generates code to compute the convolution for a specified number + of filter rows. + +Arguments: + + KernelFrame - Supplies the symbol name to access the convolution kernel + stack. + + KernelType - Supplies the type of kernel to be generated. + + FilterCount - Supplies the number of rows from the filter to process. + +Implicit Arguments: + + a0 - Supplies the address of the input buffer. + + a1 - Supplies the FilterStride parameter (see function description) when + KernelType!=Depthwise. Supplies the address of the filter buffer when + KernelType=Depthwise. + + t7 - Supplies the DilationWidth parameter (see function description). + + a4 - Supplies the address of the output buffer. + + a5 - Supplies the StrideWidth parameter (see function description). + + t5 - Supplies the InputStride parameter (see function description). 
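+
+    The output row is walked in three phases: blocks that include left padding are
+    processed one at a time by the single-output helper routine, the fully interior
+    blocks are processed two at a time, and blocks that include right padding (plus
+    any leftover interior block) are again processed one at a time.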
+ +--*/ + + .macro ProcessFilterCountN KernelFrame, KernelType, FilterCount + +// +// Process the output blocks that include left padding. +// + + ld.d $t0, $sp, OutputCountLeftPad_arg + beqz $t0, .L\KernelType\().\FilterCount\().ProcessOutputCount + bl MlasConv\KernelType\()FloatSingleLasxFilter\FilterCount\() + +// +// Process the output blocks that do not include any padding. +// + +.L\KernelType\().\FilterCount\().ProcessOutputCount: + ld.d $t0, $sp, OutputCount_arg + li.d $s0, 2 + bltu $t0, $s0, .L\KernelType\().\FilterCount\().ProcessRemainingOutputCount + +.L\KernelType\().\FilterCount\().ProcessNextOutputCountBy2: + ProcessOutputCountN Lasx, \KernelFrame\(), \KernelType\(), 8, \FilterCount\(), 2 + slli.d $s0, $a5, 1 # advance input by 2 elements + add.d $a0, $a0, $s0 + addi.d $t0, $t0, -2 + li.d $s0, 2 + bgeu $t0, $s0, .L\KernelType\().\FilterCount\().ProcessNextOutputCountBy2 + +.L\KernelType\().\FilterCount\().ProcessRemainingOutputCount: + +// +// Process the output blocks that include right padding plus any remaining output +// blocks from above. +// + +.L\KernelType\().\FilterCount\().ProcessOutputCountRightPadAndRemaining: + ld.d $s0, $sp, OutputCountRightPad_arg + add.d $t0, $t0, $s0 + beqz $t0, .L\KernelType\().ExitKernel + bl MlasConv\KernelType\()FloatSingleLasxFilter\FilterCount\() + + .endm + +/*++ + +Macro Description: + + This macro generates code to compute the convolution for a specified number + of filter rows for a pointwise convolution. + +Arguments: + + FilterCount - Supplies the number of rows from the filter to process. + +Implicit Arguments: + + a0 - Supplies the address of the input buffer. + + a1 - Supplies the FilterStride parameter (see function description). + + t8 - Supplies the InputStride parameter (see function description). + + a4 - Supplies the address of the output buffer. + + a5 - Supplies the StrideWidth parameter (see function description). + + t0 - Supplies the OutputCount parameter (see function description). + + t2 - Supplies the address of the filter buffer. + +--*/ + + .macro ProcessPointwiseFilterCountN FilterCount + li.d $s0, 2 + bltu $t0, $s0, .LPointwise.\FilterCount\().ProcessRemainingOutputCount + +.LPointwise.\FilterCount\().ProcessNextOutputCountBy2: + ProcessPointwiseOutputCountN Lasx, 8, \FilterCount\(), 2 + slli.d $s0, $a5, 1 # advance input by 2 elements + add.d $a0, $a0, $s0 + addi.d $t0, $t0, -2 + li.d $s0, 2 + bgeu $t0, $s0, .LPointwise.\FilterCount\().ProcessNextOutputCountBy2 + +.LPointwise.\FilterCount\().ProcessRemainingOutputCount: + beqz $t0, .LPointwise.ExitKernel + ProcessPointwiseOutputCountN Lasx, 8, \FilterCount\(), 1 + + .endm + +// +// Generate the convolution kernels. +// + + SconvKernelFunction Nchw, 8, Lasx + SconvKernelFunction Nchwc, 8, Lasx, BiasFilter + SconvKernelDepthwiseFunction 8, Lasx + SconvKernelPointwiseFunction Lasx, BiasFilter + +/*++ + +Macro Description: + + This macro generates code to process an output block after the inner + convolution kernel has executed and then stores the output block to the + output buffer. + +Arguments: + + FilterCount - Supplies the number of rows from the filter to process. + + OutputCount - Supplies the number of output blocks to produce. 
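+
+    The post processing is driven by the flag bits passed in a2: optionally accumulate
+    the existing output, optionally add the per-filter bias, optionally apply the ReLU
+    clamp (xvfmax.s against zero), and finally store the accumulators to the output
+    buffer.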
+ +--*/ + + .macro PostProcessBlock FilterCount, OutputCount + + .globl MlasConvPostProcessFloatLasxFilter\FilterCount\()Output\OutputCount\() + .hidden MlasConvPostProcessFloatLasxFilter\FilterCount\()Output\OutputCount\() +MlasConvPostProcessFloatLasxFilter\FilterCount\()Output\OutputCount\(): + + .globl MlasConvPostProcessFloatFma3Filter\FilterCount\()Output\OutputCount\() + .hidden MlasConvPostProcessFloatFma3Filter\FilterCount\()Output\OutputCount\() +MlasConvPostProcessFloatFma3Filter\FilterCount\()Output\OutputCount\(): + +.if \FilterCount\() > 2 + slli.d $s0, $t6, 1 # compute output plus 2 rows + add.d $t7, $a4, $s0 +.endif + +// +// Test if the existing contents of the output buffer should be accumulated +// with the output block. +// + + andi $s0, $a2, MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT + beqz $s0, .LPostProcessBlock.\FilterCount\().\OutputCount\().SkipAccumulateOutput + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "xvld $xr16, $a4, 0" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "xvfadd.s $xr0, $xr0, $xr16" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 2, "xvld $xr16, $a4, 32" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 2, "xvfadd.s $xr4, $xr4, $xr16" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 3, "xvld $xr16, $a4, 0x40" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 3, "xvfadd.s $xr8, $xr8, $xr16" + + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "xvldx $xr16, $a4, $t6" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "xvfadd.s $xr1, $xr1, $xr16" + + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 2, "add.d $s0, $a4, $t6" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 2, "xvld $xr16, $s0, 0x20" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 2, "xvfadd.s $xr5, $xr5, $xr16" + + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 3, "add.d $s0, $a4, $t6" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 3, "xvld $xr16, $s0, 0x40" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 3, "xvfadd.s $xr9, $xr9, $xr16" + + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "xvld $xr16,$t7, 0" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "xvfadd.s $xr2, $xr2, $xr16" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 2, "xvld $xr16,$t7, 0x20" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 2, "xvfadd.s $xr6, $xr6, $xr16" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 3, "xvld $xr16,$t7, 0x40" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 3, "xvfadd.s $xr10, $xr10, $xr16" + + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "xvldx $xr16,$t7, $t6" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "xvfadd.s $xr3, $xr3, $xr16" + + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 2, "add.d $s0, $t7, $t6" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 2, "xvld $xr16,$s0, 0x20" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 2, "xvfadd.s $xr7, $xr7, $xr16" + + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 3, "add.d $s0, $t7, $t6" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 3, "xvld $xr16,$s0, 0x40" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 3, "xvfadd.s $xr11, $xr11, $xr16" + + +.LPostProcessBlock.\FilterCount\().\OutputCount\().SkipAccumulateOutput: + +// +// Test if the bias buffer should be accumulated with the output block. 
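+// MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION in a2 selects this path; one 32-byte bias
+// vector per filter row is read from the pointer in a3.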
+// + + andi $s0, $a2, MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION + beqz $s0, .LPostProcessBlock.\FilterCount\().\OutputCount\().SkipBiasAddition +.if \OutputCount\() == 1 + EmitIfCountGE \FilterCount\(), 1, "xvld $xr16, $a3, 0" + EmitIfCountGE \FilterCount\(), 1, "xvfadd.s $xr0, $xr0, $xr16" + EmitIfCountGE \FilterCount\(), 2, "xvld $xr16, $a3, 0x20" + EmitIfCountGE \FilterCount\(), 2, "xvfadd.s $xr1, $xr1, $xr16" + EmitIfCountGE \FilterCount\(), 3, "xvld $xr16, $a3, 0x40" + EmitIfCountGE \FilterCount\(), 3, "xvfadd.s $xr2, $xr2, $xr16" + EmitIfCountGE \FilterCount\(), 4, "xvld $xr16, $a3, 0x60" + EmitIfCountGE \FilterCount\(), 4, "xvfadd.s $xr3, $xr3, $xr16" +.else + EmitIfCountGE \FilterCount\(), 1, "xvld $xr12, $a3, 0" + EmitIfCountGE \FilterCount\(), 2, "xvld $xr13, $a3, 0x20" + EmitIfCountGE \FilterCount\(), 3, "xvld $xr14, $a3, 0x40" + EmitIfCountGE \FilterCount\(), 4, "xvld $xr15, $a3, 0x60" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "xvfadd.s $xr0, $xr0, $xr12" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 2, "xvfadd.s $xr4, $xr4, $xr12" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 3, "xvfadd.s $xr8, $xr8, $xr12" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "xvfadd.s $xr1, $xr1, $xr13" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 2, "xvfadd.s $xr5, $xr5, $xr13" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 3, "xvfadd.s $xr9, $xr9, $xr13" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "xvfadd.s $xr2, $xr2, $xr14" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 2, "xvfadd.s $xr6, $xr6, $xr14" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 3, "xvfadd.s $xr10, $xr10, $xr14" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "xvfadd.s $xr3, $xr3, $xr15" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 2, "xvfadd.s $xr7, $xr7, $xr15" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 3, "xvfadd.s $xr11, $xr11, $xr15" + +.endif + +.LPostProcessBlock.\FilterCount\().\OutputCount\().SkipBiasAddition: + +// +// Test for fused ReLU activation. +// + + andi $s0, $a2, MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION + beqz $s0, .LPostProcessBlock.\FilterCount\().\OutputCount\().SkipReluActivation + xvxor.v $xr15, $xr15, $xr15 + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "xvfmax.s $xr0, $xr15, $xr0" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 2, "xvfmax.s $xr4, $xr15, $xr4" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 3, "xvfmax.s $xr8, $xr15, $xr8" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "xvfmax.s $xr1, $xr15, $xr1" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 2, "xvfmax.s $xr5, $xr15, $xr5" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 3, "xvfmax.s $xr9, $xr15, $xr9" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "xvfmax.s $xr2, $xr15, $xr2" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 2, "xvfmax.s $xr6, $xr15, $xr6" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 3, "xvfmax.s $xr10, $xr15, $xr10" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "xvfmax.s $xr3, $xr15, $xr3" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 2, "xvfmax.s $xr7, $xr15, $xr7" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 3, "xvfmax.s $xr11, $xr15, $xr11" + +.LPostProcessBlock.\FilterCount\().\OutputCount\().SkipReluActivation: + +// +// Store the output block in the output buffer. 
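+// Each filter row writes OutputCount 32-byte nchw8c blocks; rows two through four are
+// addressed through the OutputStride in t6 and the output-plus-two-rows pointer in t7.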
+// + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "xvst $xr0, $a4, 0" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 2, "xvst $xr4, $a4, 0x20" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 3, "xvst $xr8, $a4, 0x40" + + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "xvstx $xr1, $a4, $t6" + + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 2, "add.d $s0, $a4, $t6" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 2, "xvst $xr5, $s0, 0x20" + + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 3, "add.d $s0, $a4, $t6" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 3, "xvst $xr9, $s0, 0x40" + + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "xvst $xr2, $t7, 0" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 2, "xvst $xr6, $t7, 0x20" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 3, "xvst $xr10, $t7, 0x40" + + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "xvstx $xr3, $t7, $t6" + + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 2, "add.d $s0, $t7, $t6" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 2, "xvst $xr7, $s0, 0x20" + + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 3, "add.d $s0, $t7, $t6" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 3, "xvst $xr11, $s0, 0x40" + + add_immed $a4,\OutputCount\()*8*4 # advance output by N nchw8c blocks + jr $ra + + .endm + + .irp FilterCount, 1, 2, 3, 4 + .irp OutputCount, 1, 2, 3 + PostProcessBlock \FilterCount\(), \OutputCount\() + .endr + .endr + + .end diff --git a/onnxruntime/core/mlas/lib/loongarch64/SconvKernelLasxCommon.h b/onnxruntime/core/mlas/lib/loongarch64/SconvKernelLasxCommon.h new file mode 100644 index 0000000000000..bd2db816ed9ab --- /dev/null +++ b/onnxruntime/core/mlas/lib/loongarch64/SconvKernelLasxCommon.h @@ -0,0 +1,868 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + SconvKernelLasxCommon.h + +Abstract: + + This module contains common kernel macros and structures for the single + precision convolution operation for the Lasx kernels. + +--*/ + + +#define SP_SIZE 32*8 + +#define MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT 0x00000001 +#define MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION 0x00000002 +#define MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION 0x00000004 +#define MLAS_CONV_KERNEL_FLAG_OTHER_ACTIVATION 0x00000008 + +#define OutputStride_arg 6*8 +#define KernelHeight_arg 7*8 +#define KernelWidth_arg 8*8 +#define InputBase_arg 9*8 +#define InputWidth_arg 10*8 +#define DilatedInputWidth_arg 11*8 +#define OutputCountLeftPad_arg 12*8 +#define OutputCount_arg 13*8 +#define OutputCountRightPad_arg 14*8 +#define Bias_arg 15*8 +#define Flags_arg 16*8 +#define InputChannels_arg 17*8 +#define Filter_save_offset 18*8 + +/*++ + +Macro Description: + + This macro generates code to compute the convolution for a vector of input + blocks and a vector of filter blocks to produce a matrix of output blocks. + + OutputCount=1 generates special case code to handle padding blocks. All + other output counts assume no padding. + +Arguments: + + Isa - Supplies the instruction set architecture string for function tags. + + KernelFrame - Supplies the symbol name to access the convolution kernel + stack. + + KernelType - Supplies the type of kernel to be generated. + + BlockSize - Supplies the number of elements per block. + + FilterCount - Supplies the number of rows from the filter to process. + + OutputCount - Supplies the number of output blocks to produce. 
+ +Implicit Arguments: + + a0 - Supplies the address of the input buffer. + + a1 - Supplies the FilterStride parameter (see function description) when + KernelType!=Depthwise. Supplies the address of the filter buffer when + KernelType=Depthwise. + + s8 - Supplies the DilationWidth parameter (see function description). + + a4 - Supplies the address of the output buffer. + + a5 - Supplies the StrideWidth parameter (see function description). + + t5 - Supplies the InputStride parameter (see function description). +--*/ + .macro ProcessOutputCountN Isa, KernelFrame, KernelType, BlockSize, FilterCount, OutputCount + + move $a3, $a0 +.ifeqs "\KernelType\()","Depthwise" + move $a2, $a1 +.else + ld.d $a2, $sp, Filter_save_offset +.endif + ld.d $t1, $sp, KernelHeight_arg + ld.d $t2, $sp, KernelWidth_arg +.if \OutputCount\() == 1 + ld.d $t3, $sp, InputBase_arg + ld.d $t4, $sp, InputWidth_arg + sub.d $t3, $zero, $t3 +.endif + ClearBlock \FilterCount\(), \OutputCount\() + beqz $t1, .L\KernelType\().\FilterCount\().\OutputCount\().HandlePostProcessing + +.L\KernelType\().\FilterCount\().\OutputCount\().ProcessNextRow: + move $t6, $t2 # reload kernel width remaining + +.L\KernelType\().\FilterCount\().\OutputCount\().ProcessNextColumn: +.if \OutputCount\() == 1 + add.d $t7, $a3, $t3 # compute (Input - InputBase) + # (Input - InputBase) >= InputWidth? + bgeu $t7, $t4, .L\KernelType\().\FilterCount\().\OutputCount\().SkipOverPadding +.endif +.if \OutputCount\() > 3 + slli.d $s0, $a5, 1 + add.d $s0, $s0, $a5 + add.d $t4, $a3, $s0 # compute input plus 3 blocks +.endif +.if \FilterCount\() > 2 + slli.d $s0, $a1, 1 # compute filter plus 2 rows + add.d $t7, $a2, $s0 +.endif +.ifeqs "\KernelType\()","Nchwc" +.if \BlockSize\() == 16 + .irp Index, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 + ComputeBlock \KernelType\(), \FilterCount\(), \OutputCount\(), \Index\()*16*4, \Index\()*4 + .endr +.else + .irp Index, 0, 1, 2, 3, 4, 5, 6, 7 + ComputeBlock \KernelType\(), \FilterCount\(), \OutputCount\(), (\Index\()-4)*8*4, \Index\()*4 + .endr +.endif +.else + ComputeBlock \KernelType\(), \FilterCount\(), \OutputCount\(), 0, 0 +.endif + +.L\KernelType\().\FilterCount\().\OutputCount\().SkipOverPadding: + # advance input by dilation width + add.d $a3, $a3, $t8 +.ifeqs "\KernelType\()","Nchwc" + # advance filter by 8i8o/16i16o block + addi.d $a2, $a2, \BlockSize\()*\BlockSize\()*4 +.else + addi.d $a2, $a2, \BlockSize\()*4 # advance filter by 8o/16o block +.endif + addi.d $t6, $t6, -1 + bnez $t6, .L\KernelType\().\FilterCount\().\OutputCount\().ProcessNextColumn + add.d $a3, $a3, $t5 # advance input to next row +.if \OutputCount\() == 1 + ld.d $s0, $sp, DilatedInputWidth_arg + # advance input base to next row + sub.d $t3, $t3, $s0 +.endif + addi.d $t1, $t1, -1 # decrement rows remaining + bnez $t1, .L\KernelType\().\FilterCount\().\OutputCount\().ProcessNextRow + +// +// Handle post processing of the output block. +// + +.L\KernelType\().\FilterCount\().\OutputCount\().HandlePostProcessing: + ld.w $a2, $sp, Flags_arg +.if \FilterCount\() > 1 + ld.d $t6, $sp, OutputStride_arg +.endif + ld.d $a3, $sp, Bias_arg + bl MlasConvPostProcessFloat\Isa\()Filter\FilterCount\()Output\OutputCount\() + + .endm + +/*++ + +Macro Description: + + This macro generates code for the inner convolution kernel. + +Arguments: + + KernelType - Supplies the type of kernel to be generated. + + BlockSize - Supplies the number of elements per block. + + Isa - Supplies the instruction set architecture string for function tags. 
+ + BiasFilter - Supplies a non-blank value if the address of the filter buffer + should be biased to point to the middle of a OIhw8i8o block in order to + reduce the code size from relative byte offsets. + +--*/ + + .macro SconvKernelFunction KernelType, BlockSize, Isa, BiasFilter + +/*++ + +Routine Description: + + This routine is the inner kernel to compute a convolution for the elements + of an output row for a set of filter rows. + +Arguments: + + Input (a0) - Supplies the address of the input buffer. + + The address is biased to include padding blocks for the left width + dimension. The address is not biased to include padding rows for the + left height dimension these are accounted for in the outer kernel. + + Filter (a1) - Supplies the address of the filter buffer. + + Output (a2) - Supplies the address of the output buffer. + + StrideWidth (a3) - Supplies the length in bytes of the blocked stride width. + + DilationWidth (a4) - Supplies the length in bytes of the blocked dilation + width. + + FilterCount (a5) - Supplies the number of filters to process in this + iteration. + + InputStride (a6)- Supplies the length in bytes to advance the input buffer to + the next input row. + + FilterStride (a7) - Supplies the length in bytes to advance the filter buffer + to the next set of filters. + + OutputStride (sp + 0)- Supplies the length in bytes to advance the output buffer + to the next output address associated with the next set of filters. + + KernelHeight (sp + 8)- Supplies the height of the kernel to apply. This height may + be less than the original kernel height after removing any padding + rows. + + KernelWidth (sp + 0x10)- Supplies the width of the kernel to apply. + + InputBase (sp + 0x18)- Supplies the address of the valid input buffer. + + This parameter is similar to the Input parameter, but does not include + the padding blocks for the left width dimension. This parameter is used + with the following InputWidth parameter in order to validate that the + current input buffer address in bounds and not in the left or right + width padding region. + + InputWidth (sp + 0x20)- Supplies the length in bytes of the blocked input width. + + DilatedInputWidth (sp + 0x28)- Supplies the length in bytes to advance the input base + buffer to the next input row including dilation. + + OutputCountLeftPad (sp + 0x30)- Supplies the number of output elements that include + one or more padding elements from the left edge. + + OutputCount (sp + 0x38)- Supplies the number of output elements that do not include + any padding elements. + + OutputCountRightPad (sp + 0x40)- Supplies the number of output elements that include + one or more padding elements from the right edge. + + Bias (sp + 0x48)- Supplies the address of the bias buffer. + + Flags (sp + 0x50)- Supplies additional flags controlling the convolution operation, + especially post calculation options. + +Return Value: + + None. 
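+
+    For reference, the argument list above corresponds to a C-style prototype
+    along the following lines. This is an illustrative sketch only (the name
+    MlasConvKernelSketch is a placeholder for the generated
+    MlasConv<KernelType>FloatKernel<Isa> entry points); the authoritative
+    declaration is the internal MLAS kernel typedef, and the exact integer
+    types and qualifiers may differ:
+
+        void
+        MlasConvKernelSketch(
+            const float* Input,          // a0
+            const float* Filter,         // a1
+            float* Output,               // a2
+            size_t StrideWidth,          // a3 (bytes)
+            size_t DilationWidth,        // a4 (bytes)
+            size_t FilterCount,          // a5
+            size_t InputStride,          // a6 (bytes)
+            size_t FilterStride,         // a7 (bytes)
+            size_t OutputStride,         // sp + 0 (bytes)
+            size_t KernelHeight,         // sp + 8
+            size_t KernelWidth,          // sp + 0x10
+            const float* InputBase,      // sp + 0x18
+            size_t InputWidth,           // sp + 0x20 (bytes)
+            size_t DilatedInputWidth,    // sp + 0x28 (bytes)
+            size_t OutputCountLeftPad,   // sp + 0x30
+            size_t OutputCount,          // sp + 0x38
+            size_t OutputCountRightPad,  // sp + 0x40
+            const float* Bias,           // sp + 0x48
+            unsigned KernelFlags         // sp + 0x50
+            );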
+ +--*/ + + FUNCTION_ENTRY MlasConv\KernelType\()FloatKernel\Isa\() + + addi.d $sp, $sp, -SP_SIZE + st.d $s0, $sp, 0 + st.d $s1, $sp, 8 + st.d $s2, $sp, 2*8 + st.d $ra, $sp, 5*8 + + ld.d $t0, $sp, SP_SIZE+0*8 + ld.d $t1, $sp, SP_SIZE+1*8 + ld.d $t2, $sp, SP_SIZE+2*8 + ld.d $t3, $sp, SP_SIZE+3*8 + st.d $t0, $sp, OutputStride_arg + st.d $t1, $sp, KernelHeight_arg + st.d $t2, $sp, KernelWidth_arg + st.d $t3, $sp, InputBase_arg + ld.d $t0, $sp, SP_SIZE+4*8 + ld.d $t1, $sp, SP_SIZE+5*8 + ld.d $t2, $sp, SP_SIZE+6*8 + ld.d $t3, $sp, SP_SIZE+7*8 + st.d $t0, $sp, InputWidth_arg + st.d $t1, $sp, DilatedInputWidth_arg + st.d $t2, $sp, OutputCountLeftPad_arg + st.d $t3, $sp, OutputCount_arg + ld.d $t0, $sp, SP_SIZE+8*8 + ld.d $t1, $sp, SP_SIZE+9*8 + ld.d $t2, $sp, SP_SIZE+10*8 + st.d $t0, $sp, OutputCountRightPad_arg + st.d $t1, $sp, Bias_arg + st.d $t2, $sp, Flags_arg + +.ifeqs "\BiasFilter\()","BiasFilter" + addi.d $a1, $a1, 4*8*4 +.endif + st.d $a1, $sp, Filter_save_offset + move $a1, $a7 + move $t5, $a6 + move $t8, $a4 + move $t1, $a5 + move $a4, $a2 + move $a5, $a3 + +// +// Process the specified number of filter rows. +// + + ori $s0, $zero, 3 + beq $t1, $s0, .L\KernelType\().ProcessFilterCount3 + bltu $t1, $s0, .L\KernelType\().ProcessFilterCountLessThan3 + ProcessFilterCountN LSconvKernelFrame, \KernelType\(), 4 + b .L\KernelType\().ExitKernel + +.L\KernelType\().ProcessFilterCount3: + ProcessFilterCountN LSconvKernelFrame, \KernelType\(), 3 + b .L\KernelType\().ExitKernel + +.L\KernelType\().ProcessFilterCountLessThan3: + ori $s0, $zero, 2 + bltu $t1, $s0, .L\KernelType\().ProcessFilterCount1 + ProcessFilterCountN LSconvKernelFrame, \KernelType\(), 2 + b .L\KernelType\().ExitKernel + +.L\KernelType\().ProcessFilterCount1: + ProcessFilterCountN LSconvKernelFrame, \KernelType\(), 1 + +// +// Restore non-volatile registers and return. +// + +.L\KernelType\().ExitKernel: +.ifnes "\Isa\()","LSX" + xvinsgr2vr.d $xr0, $zero, 2 + xvinsgr2vr.d $xr0, $zero, 3 + xvinsgr2vr.d $xr1, $zero, 2 + xvinsgr2vr.d $xr1, $zero, 3 + xvinsgr2vr.d $xr2, $zero, 2 + xvinsgr2vr.d $xr2, $zero, 3 + xvinsgr2vr.d $xr3, $zero, 2 + xvinsgr2vr.d $xr3, $zero, 3 + xvinsgr2vr.d $xr4, $zero, 2 + xvinsgr2vr.d $xr4, $zero, 3 + xvinsgr2vr.d $xr5, $zero, 2 + xvinsgr2vr.d $xr5, $zero, 3 + xvinsgr2vr.d $xr6, $zero, 2 + xvinsgr2vr.d $xr6, $zero, 3 + xvinsgr2vr.d $xr7, $zero, 2 + xvinsgr2vr.d $xr7, $zero, 3 + xvinsgr2vr.d $xr8, $zero, 2 + xvinsgr2vr.d $xr8, $zero, 3 + xvinsgr2vr.d $xr9, $zero, 2 + xvinsgr2vr.d $xr9, $zero, 3 + xvinsgr2vr.d $xr10, $zero, 2 + xvinsgr2vr.d $xr10, $zero, 3 + xvinsgr2vr.d $xr11, $zero, 2 + xvinsgr2vr.d $xr11, $zero, 3 + xvinsgr2vr.d $xr12, $zero, 2 + xvinsgr2vr.d $xr12, $zero, 3 + xvinsgr2vr.d $xr13, $zero, 2 + xvinsgr2vr.d $xr13, $zero, 3 + xvinsgr2vr.d $xr14, $zero, 2 + xvinsgr2vr.d $xr14, $zero, 3 + xvinsgr2vr.d $xr15, $zero, 2 + xvinsgr2vr.d $xr15, $zero, 3 +.endif + ld.d $s0, $sp, 0 + ld.d $s1, $sp, 8 + ld.d $s2, $sp, 2*8 + ld.d $ra, $sp, 5*8 + addi.d $sp, $sp, SP_SIZE + jirl $zero, $ra, 0 + +.ifnes "\Isa\()","LSX" + +// +// Generate out-of-band helpers for handling output blocks involving padding. 
+// + + .irp FilterCount, 1, 2, 3, 4 + +MlasConv\KernelType\()FloatSingle\Isa\()Filter\FilterCount\(): + st.d $ra, $sp, 19*8 +loopMlasConv\KernelType\()FloatSingle\Isa\()Filter\FilterCount\(): + ProcessOutputCountN \Isa\(), LSconvKernelSingleFrame, \KernelType\(), \BlockSize\(), \FilterCount\(), 1 + add.d $a0, $a0, $a5 # advance input by 1 element + addi.d $t0, $t0, -1 # decrement output count remaining + bnez $t0, loopMlasConv\KernelType\()FloatSingle\Isa\()Filter\FilterCount\() + ld.d $ra, $sp, 19*8 + jr $ra + + .endr + +.endif + + .endm + +/*++ + +Macro Description: + + This macro generates code for the inner convolution kernel for the special + case of a depthwise separable convolution. + +Arguments: + + BlockSize - Supplies the number of elements per block. + + Isa - Supplies the instruction set architecture string for function tags. + +--*/ + + .macro SconvKernelDepthwiseFunction BlockSize, Isa + +/*++ + +Routine Description: + + This routine is the inner kernel to compute a convolution for the elements + of an output row for a set of filter rows. + + Depthwise separable convolutions are a form of grouped convolution where + the number of input and output channels per group are one. + +Arguments: + + Input (a0) - Supplies the address of the input buffer. + + The address is biased to include padding blocks for the left width + dimension. The address is not biased to include padding rows for the + left height dimension these are accounted for in the outer kernel. + + Filter (a1) - Supplies the address of the filter buffer. + + Output (a2) - Supplies the address of the output buffer. + + StrideWidth (a3) - Supplies the length in bytes of the blocked stride width. + + DilationWidth (a4) - Supplies the length in bytes of the blocked dilation + width. + + InputStride (a5) - Supplies the length in bytes to advance the input buffer + to the next input row. + + KernelHeight (a6)- Supplies the height of the kernel to apply. This height may + be less than the original kernel height after removing any padding + rows. + + KernelWidth (a7)- Supplies the width of the kernel to apply. + + InputBase (sp + 0 )- Supplies the address of the valid input buffer. + + This parameter is similar to the Input parameter, but does not include + the padding blocks for the left width dimension. This parameter is used + with the following InputWidth parameter in order to validate that the + current input buffer address in bounds and not in the left or right + width padding region. + + InputWidth (sp + 8 )- Supplies the length in bytes of the blocked input width. + + DilatedInputWidth (sp + 0x10)- Supplies the length in bytes to advance the input base + buffer to the next input row including dilation. + + OutputCountLeftPad (sp + 0x18)- Supplies the number of output elements that include + one or more padding elements from the left edge. + + OutputCount (sp + 0x20)- Supplies the number of output elements that do not include + any padding elements. + + OutputCountRightPad (sp + 0x28)- Supplies the number of output elements that include + one or more padding elements from the right edge. + + Bias (sp + 0x30)- Supplies the address of the bias buffer. + + Flags (sp + 0x38)- Supplies additional flags controlling the convolution operation, + especially post calculation options. + +Return Value: + + None. 
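+
+    As an illustrative scalar reference (DepthwiseReference is a hypothetical
+    helper, not the vectorized code below, and the filter layout shown is an
+    assumption), one output position of BlockSize channels is computed roughly
+    as follows, with all strides already converted from bytes to float elements:
+
+        void DepthwiseReference(const float* Input, const float* Filter,
+                                float* Output, size_t BlockSize,
+                                size_t KernelHeight, size_t KernelWidth,
+                                size_t DilationWidth, size_t InputStride)
+        {
+            for (size_t c = 0; c < BlockSize; c++) {
+                Output[c] = 0.0f;
+            }
+            for (size_t kh = 0; kh < KernelHeight; kh++) {
+                for (size_t kw = 0; kw < KernelWidth; kw++) {
+                    const float* in = Input + kh * InputStride + kw * DilationWidth;
+                    const float* flt = Filter + (kh * KernelWidth + kw) * BlockSize;
+                    for (size_t c = 0; c < BlockSize; c++) {
+                        Output[c] += flt[c] * in[c];   // one channel of the block
+                    }
+                }
+            }
+        }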
+ +--*/ + + FUNCTION_ENTRY MlasConvDepthwiseFloatKernel\Isa\() + + addi.d $sp, $sp, -SP_SIZE + st.d $s0, $sp, 0 + st.d $s1, $sp, 8 + st.d $s2, $sp, 2*8 + st.d $ra, $sp, 5*8 + + st.d $a6, $sp, KernelHeight_arg + st.d $a7, $sp, KernelWidth_arg + + ld.d $t0, $sp, SP_SIZE+0*8 + ld.d $t1, $sp, SP_SIZE+1*8 + ld.d $t2, $sp, SP_SIZE+2*8 + ld.d $t3, $sp, SP_SIZE+3*8 + st.d $t0, $sp, InputBase_arg + st.d $t1, $sp, InputWidth_arg + st.d $t2, $sp, DilatedInputWidth_arg + st.d $t3, $sp, OutputCountLeftPad_arg + ld.d $t0, $sp, SP_SIZE+4*8 + ld.d $t1, $sp, SP_SIZE+5*8 + ld.d $t2, $sp, SP_SIZE+6*8 + ld.d $t3, $sp, SP_SIZE+7*8 + st.d $t0, $sp, OutputCount_arg + st.d $t1, $sp, OutputCountRightPad_arg + st.d $t2, $sp, Bias_arg + st.d $t3, $sp, Flags_arg + + move $t8, $a4 + move $t5, $a5 + move $a4, $a2 + move $a5, $a3 + +// +// Process the specified number of filter rows. +// + + ProcessFilterCountN LSconvKernelDepthwiseFrame, Depthwise, 1 + +// +// Restore non-volatile registers and return. +// + +.LDepthwise.ExitKernel: +.ifnes "\Isa\()","LSX" + xvinsgr2vr.d $xr0, $zero, 2 + xvinsgr2vr.d $xr0, $zero, 3 + xvinsgr2vr.d $xr1, $zero, 2 + xvinsgr2vr.d $xr1, $zero, 3 + xvinsgr2vr.d $xr2, $zero, 2 + xvinsgr2vr.d $xr2, $zero, 3 + xvinsgr2vr.d $xr3, $zero, 2 + xvinsgr2vr.d $xr3, $zero, 3 + xvinsgr2vr.d $xr4, $zero, 2 + xvinsgr2vr.d $xr4, $zero, 3 + xvinsgr2vr.d $xr5, $zero, 2 + xvinsgr2vr.d $xr5, $zero, 3 + xvinsgr2vr.d $xr6, $zero, 2 + xvinsgr2vr.d $xr6, $zero, 3 + xvinsgr2vr.d $xr7, $zero, 2 + xvinsgr2vr.d $xr7, $zero, 3 + xvinsgr2vr.d $xr8, $zero, 2 + xvinsgr2vr.d $xr8, $zero, 3 + xvinsgr2vr.d $xr9, $zero, 2 + xvinsgr2vr.d $xr9, $zero, 3 + xvinsgr2vr.d $xr10, $zero, 2 + xvinsgr2vr.d $xr10, $zero, 3 + xvinsgr2vr.d $xr11, $zero, 2 + xvinsgr2vr.d $xr11, $zero, 3 + xvinsgr2vr.d $xr12, $zero, 2 + xvinsgr2vr.d $xr12, $zero, 3 + xvinsgr2vr.d $xr13, $zero, 2 + xvinsgr2vr.d $xr13, $zero, 3 + xvinsgr2vr.d $xr14, $zero, 2 + xvinsgr2vr.d $xr14, $zero, 3 + xvinsgr2vr.d $xr15, $zero, 2 + xvinsgr2vr.d $xr15, $zero, 3 +.endif + ld.d $s0, $sp, 0 + ld.d $s1, $sp, 8 + ld.d $s2, $sp, 2*8 + ld.d $ra, $sp, 5*8 + addi.d $sp, $sp, SP_SIZE + jr $ra + +.ifnes "\Isa\()","LSX" + +// +// Generate out-of-band helpers for handling output blocks involving padding. +// + +MlasConvDepthwiseFloatSingle\Isa\()Filter1: + st.d $ra, $sp, 20*8 +MlasConvDepthwiseFloatSingle\Isa\()Filter1_loop: + ProcessOutputCountN \Isa\(), LSconvKernelDepthwiseSingleFrame, Depthwise, \BlockSize\(), 1, 1 + add.d $a0, $a0, $a5 # advance input by 1 element + addi.d $t0, $t0, -1 # decrement output count remaining + + bnez $t0, MlasConvDepthwiseFloatSingle\Isa\()Filter1_loop + ld.d $ra, $sp, 20*8 + jr $ra + +.endif + + .endm + +/*++ + +Macro Description: + + This macro generates code to compute the convolution for a vector of input + blocks and a vector of filter blocks to produce a matrix of output blocks + for a pointwise convolution. + +Arguments: + + Isa - Supplies the instruction set architecture string for function tags. + + BlockSize - Supplies the number of elements per block. + + FilterCount - Supplies the number of rows from the filter to process. + + OutputCount - Supplies the number of output blocks to produce. + +Implicit Arguments: + + a0 - Supplies the address of the input buffer. + + a1 - Supplies the FilterStride parameter (see function description). + + t8 - Supplies the InputStride parameter (see function description). + + a4 - Supplies the address of the output buffer. + + a5 - Supplies the StrideWidth parameter (see function description). 
+ + t2 - Supplies the address of the filter buffer. + +--*/ + + .macro ProcessPointwiseOutputCountN Isa, BlockSize, FilterCount, OutputCount + + move $a3, $a0 + move $a2, $t2 + ld.d $t1, $sp, InputChannels_arg + ClearBlock \FilterCount\(), \OutputCount\() + +.LPointwise.\FilterCount\().\OutputCount\().ProcessNextInputBlock: +.if \OutputCount\() > 3 + slli.d $s0, $a5, 1 + add.d $s0, $s0, $a5 + add.d $t4, $s0, $a3 +.endif +.if \FilterCount\() > 2 + slli.d $s0, $a1, 1 + add.d $t7, $a2, $s0 +.endif +.if \BlockSize\() == 16 + .irp Index, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 + ComputeBlock Pointwise, \FilterCount\(), \OutputCount\(), \Index\()*16*4, \Index\()*4 + .endr +.else + .irp Index, 0, 1, 2, 3, 4, 5, 6, 7 + ComputeBlock Pointwise, \FilterCount\(), \OutputCount\(), (\Index\()-4)*8*4, \Index\()*4 + .endr +.endif + add.d $a3, $a3, $t8 # advance input to next channel block + + addi.d $a2, $a2, \BlockSize\()*\BlockSize\()*4 # advance filter by 8i8o/16i16o block + addi.d $t1, $t1, -1 # decrement input blocks remaining + + bnez $t1, .LPointwise.\FilterCount\().\OutputCount\().ProcessNextInputBlock + +// +// Handle post processing of the output block. +// + + ld.w $a2, $sp, Flags_arg +.if \FilterCount\() > 1 + ld.d $t6, $sp, OutputStride_arg +.endif + ld.d $a3, $sp, Bias_arg + bl MlasConvPostProcessFloat\Isa\()Filter\FilterCount\()Output\OutputCount\() + + .endm + +/*++ + +Macro Description: + + This macro generates code for the inner convolution kernel for the special + case where the kernel dimensions are 1. + +Arguments: + + Isa - Supplies the instruction set architecture string for function tags. + + BiasFilter - Supplies a non-blank value if the address of the filter buffer + should be biased to point to the middle of a OIhw8i8o block in order to + reduce the code size from relative byte offsets. + +--*/ + + .macro SconvKernelPointwiseFunction Isa, BiasFilter + +/*++ + +Routine Description: + + This routine is the inner kernel to compute a convolution for the elements + of an output row for a set of filter rows. + + Pointwise convolutions have a kernel size of one. To simplify this + implementation, no input padding is allowed, which matches typical usage in + models. + +Arguments: + + Input (a0) - Supplies the address of the input buffer. + + Filter (a1) - Supplies the address of the filter buffer. + + Output (a2) - Supplies the address of the output buffer. + + StrideWidth (a3) - Supplies the length in bytes of the blocked stride width. + + InputChannels (a4) - Supplies the number of input channels to process. + + FilterCount (a5) - Supplies the number of rows from the filter to process. + + InputStride (a6) - Supplies the length in bytes to advance the input buffer to + the next input channel of the same input row. + + FilterStride (a7) - Supplies the length in bytes to advance the filter buffer + to the next set of filters. + + OutputStride (sp + 0)- Supplies the length in bytes to advance the output buffer + to the next output address associated with the next set of filters. + + OutputCount (sp + 8)- Supplies the number of output elements. + + Bias (sp + 0x10)- Supplies the address of the bias buffer. + + Flags (sp + 0x18)- Supplies additional flags controlling the convolution operation, + especially post calculation options. + +Return Value: + + None. 
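+
+    Conceptually, each output block is a small matrix/vector product accumulated
+    over the InputChannels blocks. A scalar sketch of one 8i8o/16i16o block
+    (PointwiseBlockReference is a hypothetical helper, illustrative only):
+
+        // FilterBlock is laid out [inputChannel][outputChannel] within the block.
+        void PointwiseBlockReference(const float* InputBlock, const float* FilterBlock,
+                                     float* OutputBlock, size_t BlockSize)
+        {
+            for (size_t ci = 0; ci < BlockSize; ci++) {
+                for (size_t co = 0; co < BlockSize; co++) {
+                    OutputBlock[co] += FilterBlock[ci * BlockSize + co] * InputBlock[ci];
+                }
+            }
+        }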
+ +--*/ + + FUNCTION_ENTRY MlasConvPointwiseFloatKernel\Isa\() + + addi.d $sp, $sp, -SP_SIZE + st.d $s0, $sp, 0*8 + st.d $s1, $sp, 1*8 + st.d $s2, $sp, 2*8 + st.d $ra, $sp, 5*8 + + ld.d $t0, $sp, SP_SIZE+0*8 + ld.d $t1, $sp, SP_SIZE+1*8 + ld.d $t2, $sp, SP_SIZE+2*8 + ld.d $t3, $sp, SP_SIZE+3*8 + st.d $t0, $sp, OutputStride_arg + st.d $t1, $sp, OutputCount_arg + st.d $t2, $sp, Bias_arg + st.d $t3, $sp, Flags_arg + st.d $a4, $sp, InputChannels_arg + +.ifeqs "\BiasFilter\()","BiasFilter" + addi.d $t2, $a1, 4*8*4 +.else + move $t2, $a1 +.endif + ld.d $t0, $sp, OutputCount_arg + move $a1, $a7 + move $t8, $a6 + move $t1, $a5 + move $a4, $a2 + move $a5, $a3 + +// +// Process the specified number of filter rows. +// + + ori $s0, $zero, 3 + beq $t1, $s0, .LPointwise.ProcessFilterCount3 + bltu $t1, $s0, .LPointwise.ProcessFilterCountLessThan3 + ProcessPointwiseFilterCountN 4 + b .LPointwise.ExitKernel + +.LPointwise.ProcessFilterCount3: + ProcessPointwiseFilterCountN 3 + b .LPointwise.ExitKernel + +.LPointwise.ProcessFilterCountLessThan3: + ori $s0, $zero, 2 + bltu $t1, $s0, .LPointwise.ProcessFilterCount1 + ProcessPointwiseFilterCountN 2 + b .LPointwise.ExitKernel + +.LPointwise.ProcessFilterCount1: + ProcessPointwiseFilterCountN 1 + +// +// Restore non-volatile registers and return. +// + +.LPointwise.ExitKernel: +.ifnes "\Isa\()","LSX" + xvinsgr2vr.d $xr0, $zero, 2 + xvinsgr2vr.d $xr0, $zero, 3 + xvinsgr2vr.d $xr1, $zero, 2 + xvinsgr2vr.d $xr1, $zero, 3 + xvinsgr2vr.d $xr2, $zero, 2 + xvinsgr2vr.d $xr2, $zero, 3 + xvinsgr2vr.d $xr3, $zero, 2 + xvinsgr2vr.d $xr3, $zero, 3 + xvinsgr2vr.d $xr4, $zero, 2 + xvinsgr2vr.d $xr4, $zero, 3 + xvinsgr2vr.d $xr5, $zero, 2 + xvinsgr2vr.d $xr5, $zero, 3 + xvinsgr2vr.d $xr6, $zero, 2 + xvinsgr2vr.d $xr6, $zero, 3 + xvinsgr2vr.d $xr7, $zero, 2 + xvinsgr2vr.d $xr7, $zero, 3 + xvinsgr2vr.d $xr8, $zero, 2 + xvinsgr2vr.d $xr8, $zero, 3 + xvinsgr2vr.d $xr9, $zero, 2 + xvinsgr2vr.d $xr9, $zero, 3 + xvinsgr2vr.d $xr10, $zero, 2 + xvinsgr2vr.d $xr10, $zero, 3 + xvinsgr2vr.d $xr11, $zero, 2 + xvinsgr2vr.d $xr11, $zero, 3 + xvinsgr2vr.d $xr12, $zero, 2 + xvinsgr2vr.d $xr12, $zero, 3 + xvinsgr2vr.d $xr13, $zero, 2 + xvinsgr2vr.d $xr13, $zero, 3 + xvinsgr2vr.d $xr14, $zero, 2 + xvinsgr2vr.d $xr14, $zero, 3 + xvinsgr2vr.d $xr15, $zero, 2 + xvinsgr2vr.d $xr15, $zero, 3 +.endif + ld.d $s0, $sp, 0*8 + ld.d $s1, $sp, 1*8 + ld.d $s2, $sp, 2*8 + ld.d $ra, $sp, 5*8 + addi.d $sp, $sp, SP_SIZE + jr $ra + + .endm + +/*++ + +Macro Description: + + This macro generates code to clear the block accumulators. + +Arguments: + + FilterCount - Supplies the number of rows from the filter to process. + + OutputCount - Supplies the number of output blocks to produce. + +Implicit Arguments: + + xr0-xr11 - Supplies the block accumulators. 
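+
+    The accumulators form a FilterCount x OutputCount register tile. The mapping
+    used throughout these Lasx macros is xr(f + 4*o) for filter row f (0-3) and
+    output column o (0-2); as a sketch (AccumulatorRegister is a hypothetical
+    helper for illustration):
+
+        // Accumulator register index for tile position (filterRow, outputColumn).
+        inline int AccumulatorRegister(int filterRow, int outputColumn)
+        {
+            return filterRow + 4 * outputColumn;   // xr0 through xr11
+        }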
+ +--*/ + + .macro ClearBlock FilterCount, OutputCount + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "xvxor.v $xr0, $xr0, $xr0" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 2, "xvxor.v $xr4, $xr4, $xr4" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 3, "xvxor.v $xr8, $xr8, $xr8" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "xvxor.v $xr1, $xr1, $xr1" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 2, "xvxor.v $xr5, $xr5, $xr5" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 3, "xvxor.v $xr9, $xr9, $xr9" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "xvxor.v $xr2, $xr2, $xr2" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 2, "xvxor.v $xr6, $xr6, $xr6" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 3, "xvxor.v $xr10, $xr10, $xr10" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "xvxor.v $xr3, $xr3, $xr3" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 2, "xvxor.v $xr7, $xr7, $xr7" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 3, "xvxor.v $xr11, $xr11, $xr11" + + .endm diff --git a/onnxruntime/core/mlas/lib/loongarch64/SconvKernelLsx.S b/onnxruntime/core/mlas/lib/loongarch64/SconvKernelLsx.S new file mode 100644 index 0000000000000..04b8dc14d067d --- /dev/null +++ b/onnxruntime/core/mlas/lib/loongarch64/SconvKernelLsx.S @@ -0,0 +1,339 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + SconvKernelLsx.S + +Abstract: + + This module implements the kernels for the single precision convolution + operation. + + This implementation uses Lsx instructions. + +--*/ + +#include "asmmacro.h" +#include "SconvKernelLsxCommon.h" + +/*++ + +Macro Description: + + This macro generates code to clear the block accumulators. + +Arguments: + + FilterCount - Supplies the number of rows from the filter to process. + + OutputCount - Supplies the number of output blocks to produce. + +Implicit Arguments: + + vr0-vr7 - Supplies the block accumulators. + +--*/ + + .macro ClearBlock FilterCount, OutputCount + + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vxor.v $vr0,$vr0,$vr0" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vxor.v $vr1,$vr1,$vr1" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vxor.v $vr2,$vr2,$vr2" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vxor.v $vr3,$vr3,$vr3" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vxor.v $vr4,$vr4,$vr4" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vxor.v $vr5,$vr5,$vr5" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vxor.v $vr6,$vr6,$vr6" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vxor.v $vr7,$vr7,$vr7" + + .endm + +/*++ + +Macro Description: + + This macro multiplies and accumulates for FilterCount by OutputCount block + of the output buffer. + +Arguments: + + KernelType - Supplies the type of kernel to be generated. + + FilterCount - Supplies the number of rows from the filter to process. + + OutputCount - Supplies the number of output blocks to produce. + + VectorOffset - Supplies the byte offset from the filter buffer to fetch + elements. + + BroadcastOffset - Supplies the byte offset from the input buffer to fetch + elements. + +Implicit Arguments: + + a3 - Supplies the address of the input buffer. + + a2 - Supplies the address of the filter buffer. + + a1 - Supplies the FilterStride parameter (see function description). 
+ + t6 - Supplies the address of the filter buffer plus 2 * FilterStride. + + a5 - Supplies the StrideWidth parameter (see function description). + + vr0-vr7 - Supplies the block accumulators. + +--*/ + .macro ComputeBlock KernelType, FilterCount, OutputCount, VectorOffset, BroadcastOffset + +.ifeqs "\KernelType\()","Depthwise" + vld $vr8, $a2, 0 + vld $vr9, $a2, 16 + vld $vr10, $a3, 0 + vld $vr11, $a3, 16 + vfmadd.s $vr0, $vr8, $vr10, $vr0 + vfmadd.s $vr1, $vr9, $vr11, $vr1 +.else + EmitIfCountGE \OutputCount\(), 1, "ld.w $s0, $a3, \BroadcastOffset\()" + EmitIfCountGE \OutputCount\(), 1, "vreplgr2vr.w $vr12, $s0" + EmitIfCountGE \FilterCount\(), 1, "vld $vr8, $a2, \VectorOffset\()" + EmitIfCountGE \FilterCount\(), 1, "vld $vr9, $a2, \VectorOffset\()+16" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vfmadd.s $vr0, $vr8, $vr12, $vr0" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vfmadd.s $vr1, $vr9, $vr12, $vr1" + EmitIfCountGE \FilterCount\(), 2, "addi.d $s0, $a1, +\VectorOffset\()" + EmitIfCountGE \FilterCount\(), 2, "vldx $vr8, $a2, $s0" + EmitIfCountGE \FilterCount\(), 2, "addi.d $s0, $a1, +\VectorOffset\()+16" + EmitIfCountGE \FilterCount\(), 2, "vldx $vr9, $a2, $s0" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vfmadd.s $vr2, $vr8, $vr12, $vr2" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vfmadd.s $vr3, $vr9, $vr12, $vr3" + EmitIfCountGE \FilterCount\(), 3, "vld $vr8, $t7, \VectorOffset\()" + EmitIfCountGE \FilterCount\(), 3, "vld $vr9, $t7, \VectorOffset\()+16" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vfmadd.s $vr4, $vr8, $vr12, $vr4" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vfmadd.s $vr5, $vr9, $vr12, $vr5" + EmitIfCountGE \FilterCount\(), 4, "addi.d $s0, $a1, \VectorOffset\()" + EmitIfCountGE \FilterCount\(), 4, "vldx $vr8, $t7, $s0" + EmitIfCountGE \FilterCount\(), 4, "addi.d $s0, $a1, \VectorOffset\()+16" + EmitIfCountGE \FilterCount\(), 4, "vldx $vr9, $t7, $s0" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vfmadd.s $vr6, $vr8, $vr12, $vr6" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vfmadd.s $vr7, $vr9, $vr12, $vr7" +.endif + .endm +/*++ + +Macro Description: + + This macro generates code to compute the convolution for a specified number + of filter rows. + +Arguments: + + KernelFrame - Supplies the symbol name to access the convolution kernel + stack. + + KernelType - Supplies the type of kernel to be generated. + + FilterCount - Supplies the number of rows from the filter to process. + +Implicit Arguments: + + a0 - Supplies the address of the input buffer. + + a1 - Supplies the FilterStride parameter (see function description) when + KernelType!=Depthwise. Supplies the address of the filter buffer when + KernelType=Depthwise. + + s8 - Supplies the DilationWidth parameter (see function description). + + a4 - Supplies the address of the output buffer. + + a5 - Supplies the StrideWidth parameter (see function description). + + s3 - Supplies the InputStride parameter (see function description). 
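+
+    In C terms the macro body below reduces to a single loop over every output
+    position, padded or not (sketch only; ProcessOneOutputBlock stands in for the
+    ProcessOutputCountN expansion with OutputCount=1):
+
+        size_t TotalOutputCount = OutputCountLeftPad + OutputCount + OutputCountRightPad;
+        while (TotalOutputCount-- > 0) {
+            ProcessOneOutputBlock();   // handles padding positions internally
+            Input = (const float*)((const char*)Input + StrideWidth);   // byte stride
+        }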
+ +--*/ + + .macro ProcessFilterCountN KernelFrame, KernelType, FilterCount + ld.d $s0, $sp, OutputCountLeftPad_arg //OutputCountLeftPad + ld.d $s1, $sp, OutputCount_arg //OutputCount + add.d $s0, $s0, $s1 + ld.d $s1, $sp, OutputCountRightPad_arg //OutputCountRightPad + add.d $t0, $s0, $s1 +.L\KernelType\().\FilterCount\().ProcessNextOutputCount: + ProcessOutputCountN Sse, \KernelFrame\(), \KernelType\(), 8, \FilterCount\(), 1 + add.d $a0, $a0, $a5 + addi.d $t0, $t0, -1 + bnez $t0, .L\KernelType\().\FilterCount\().ProcessNextOutputCount + .endm + +/*++ + +Macro Description: + + This macro generates code to compute the convolution for a specified number + of filter rows for a pointwise convolution. + +Arguments: + + FilterCount - Supplies the number of rows from the filter to process. + +Implicit Arguments: + + a0 - Supplies the address of the input buffer. + + a1 - Supplies the FilterStride parameter (see function description). + + s8 - Supplies the InputStride parameter (see function description). + + a4 - Supplies the address of the output buffer. + + a5 - Supplies the StrideWidth parameter (see function description). + + t7 - Supplies the OutputCount parameter (see function description). + + s5 - Supplies the address of the filter buffer. + +--*/ + + .macro ProcessPointwiseFilterCountN FilterCount +.LPointwise.\FilterCount\().ProcessNextOutputCount: + ProcessPointwiseOutputCountN Sse, 8, \FilterCount\(), 1 + add.d $a0, $a0, $a5 + addi.d $t0, $t0, -1 + bnez $t0, .LPointwise.\FilterCount\().ProcessNextOutputCount + .endm + +// +// Generate the convolution kernels. +// + + SconvKernelFunction Nchw, 8, LSX + SconvKernelFunction Nchwc, 8, LSX, BiasFilter + SconvKernelDepthwiseFunction 8, LSX + SconvKernelPointwiseFunction LSX, BiasFilter + +/*++ + +Macro Description: + + This macro generates code to process an output block after the inner + convolution kernel has executed and then stores the output block to the + output buffer. + +Arguments: + + FilterCount - Supplies the number of rows from the filter to process. + + OutputCount - Supplies the number of output blocks to produce. 
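+
+    The post-processing applied to the accumulator tile is selected by the
+    KernelFlags bits defined in SconvKernelLsxCommon.h. A scalar sketch of the
+    same sequence (LoadOutputBlock, LoadBiasBlock, Max and StoreOutputBlock are
+    hypothetical helpers used only for illustration):
+
+        if (KernelFlags & MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT) {
+            Accumulator += LoadOutputBlock(Output);   // accumulate into existing output
+        }
+        if (KernelFlags & MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION) {
+            Accumulator += LoadBiasBlock(Bias);       // per-filter bias
+        }
+        if (KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION) {
+            Accumulator = Max(Accumulator, 0.0f);     // fused ReLU
+        }
+        StoreOutputBlock(Output, Accumulator);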
+--*/ + + .macro PostProcessBlock FilterCount, OutputCount + + .globl MlasConvPostProcessFloatSseFilter\FilterCount\()Output\OutputCount\() +#if !defined(__APPLE__) + .hidden MlasConvPostProcessFloatSseFilter\FilterCount\()Output\OutputCount\() +#endif +MlasConvPostProcessFloatSseFilter\FilterCount\()Output\OutputCount\(): + +.if \FilterCount\() > 2 + li.d $s0, 2 + mul.d $s0, $s0, $t6 + add.d $t7, $a4, $s0 +.endif + andi $s0, $a2, MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT + andi $s0, $s0, 0xff + beqz $s0, .LPostProcessBlock.\FilterCount\().\OutputCount\().SkipAccumulateOutput + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vld $vr8, $a4, 0" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vld $vr9, $a4, 16" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vldx $vr10, $a4, $t6" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "addi.d $s0, $t6, 16" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vldx $vr11, $a4, $s0" + + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vld $vr12, $t7, 0" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vld $vr13, $t7, 16" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vldx $vr14, $t7, $t6" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "addi.d $s0, $t6, 16" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vldx $vr15, $t7, $s0" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vfadd.s $vr0, $vr0, $vr8" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vfadd.s $vr1, $vr1, $vr9" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vfadd.s $vr2, $vr2, $vr10" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vfadd.s $vr3, $vr3, $vr11" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vfadd.s $vr4, $vr4, $vr12" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vfadd.s $vr5, $vr5, $vr13" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vfadd.s $vr6, $vr6, $vr14" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vfadd.s $vr7, $vr7, $vr15" + +.LPostProcessBlock.\FilterCount\().\OutputCount\().SkipAccumulateOutput: +// +// Test if the bias buffer should be accumulated with the output block. 
+// + + andi $s0, $a2, MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION + andi $s0, $s0, 0xff + beqz $s0, .LPostProcessBlock.\FilterCount\().\OutputCount\().SkipBiasAddition + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vld $vr8, $a3, 0" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vld $vr9, $a3, 16" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vld $vr10, $a3, 32" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vld $vr11, $a3, 48" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vld $vr12, $a3, 64" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vld $vr13, $a3, 80" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vld $vr14, $a3, 96" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vld $vr15, $a3, 112" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vfadd.s $vr0, $vr0, $vr8" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vfadd.s $vr1, $vr1, $vr9" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vfadd.s $vr2, $vr2, $vr10" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vfadd.s $vr3, $vr3, $vr11" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vfadd.s $vr4, $vr4, $vr12" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vfadd.s $vr5, $vr5, $vr13" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vfadd.s $vr6, $vr6, $vr14" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vfadd.s $vr7, $vr7, $vr15" + +.LPostProcessBlock.\FilterCount\().\OutputCount\().SkipBiasAddition: + +// +// Test for fused ReLU activation. +// + + andi $s0, $a2, MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION + andi $s0, $s0, 0xff + beqz $s0, .LPostProcessBlock.\FilterCount\().\OutputCount\().SkipReluActivation + vxor.v $vr15,$vr15, $vr15 + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vfmax.s $vr0, $vr0, $vr15" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vfmax.s $vr1, $vr1, $vr15" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vfmax.s $vr2, $vr2, $vr15" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vfmax.s $vr3, $vr3, $vr15" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vfmax.s $vr4, $vr4, $vr15" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vfmax.s $vr5, $vr5, $vr15" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vfmax.s $vr6, $vr6, $vr15" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vfmax.s $vr7, $vr7, $vr15" + +.LPostProcessBlock.\FilterCount\().\OutputCount\().SkipReluActivation: + +// +// Store the output block in the output buffer. 
+// + + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vst $vr0, $a4,0" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vst $vr1, $a4, 16" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vstx $vr2, $a4, $t6" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "addi.d $s0, $t6, 16" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vstx $vr3, $a4, $s0" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vst $vr4, $t7, 0" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vst $vr5, $t7, 16" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vstx $vr6, $t7, $t6" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "addi.d $s0, $t6, 16" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vstx $vr7, $t7, $s0" + add_immed $a4, \OutputCount\()*8*4 # advance output by N nchw8c blocks + jr $ra + + .endm + + .irp FilterCount, 1, 2, 3, 4 + .irp OutputCount, 1 + PostProcessBlock \FilterCount\(), \OutputCount\() + .endr + .endr + + .end diff --git a/onnxruntime/core/mlas/lib/loongarch64/SconvKernelLsxCommon.h b/onnxruntime/core/mlas/lib/loongarch64/SconvKernelLsxCommon.h new file mode 100644 index 0000000000000..d03714f654500 --- /dev/null +++ b/onnxruntime/core/mlas/lib/loongarch64/SconvKernelLsxCommon.h @@ -0,0 +1,669 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + SconvKernelLsxCommon.h + +Abstract: + + This module contains common kernel macros and structures for the single + precision convolution operation for the Lsx kernels. + +--*/ + +#define SP_SIZE 32*8 + +#define MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT 0x00000001 +#define MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION 0x00000002 +#define MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION 0x00000004 +#define MLAS_CONV_KERNEL_FLAG_OTHER_ACTIVATION 0x00000008 + +#define Filter_save_offset 18*8 + +#define OutputStride_arg 6*8 +#define KernelHeight_arg 7*8 +#define KernelWidth_arg 8*8 +#define InputBase_arg 9*8 +#define InputWidth_arg 10*8 +#define DilatedInputWidth_arg 11*8 +#define OutputCountLeftPad_arg 12*8 +#define OutputCount_arg 13*8 +#define OutputCountRightPad_arg 14*8 +#define Bias_arg 15*8 +#define Flags_arg 16*8 +#define InputChannels_arg 17*8 + +/*++ + +Macro Description: + + This macro generates code to compute the convolution for a vector of input + blocks and a vector of filter blocks to produce a matrix of output blocks. + + OutputCount=1 generates special case code to handle padding blocks. All + other output counts assume no padding. + +Arguments: + + Isa - Supplies the instruction set architecture string for function tags. + + KernelFrame - Supplies the symbol name to access the convolution kernel + stack. + + KernelType - Supplies the type of kernel to be generated. + + BlockSize - Supplies the number of elements per block. + + FilterCount - Supplies the number of rows from the filter to process. + + OutputCount - Supplies the number of output blocks to produce. + +Implicit Arguments: + + a0 - Supplies the address of the input buffer. + + a1 - Supplies the FilterStride parameter (see function description) when + KernelType!=Depthwise. Supplies the address of the filter buffer when + KernelType=Depthwise. + + s8 - Supplies the DilationWidth parameter (see function description). + + a4 - Supplies the address of the output buffer. + + a5 - Supplies the StrideWidth parameter (see function description). + + s3 - Supplies the InputStride parameter (see function description). 
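+
+    For the OutputCount==1 path the padding test collapses to one unsigned
+    compare, because the subtraction wraps around for addresses left of
+    InputBase; in C terms (sketch only, all quantities in bytes):
+
+        if ((uintptr_t)Input - (uintptr_t)InputBase < (uintptr_t)InputWidth) {
+            // real input element: run ComputeBlock and accumulate
+        } else {
+            // left or right padding: leave the accumulators unchanged
+        }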
+--*/ + + .macro ProcessOutputCountN Isa, KernelFrame, KernelType, BlockSize, FilterCount, OutputCount + move $a3, $a0 +.ifeqs "\KernelType\()","Depthwise" + move $a2, $a1 +.else + ld.d $a2, $sp, Filter_save_offset +.endif + ld.d $t1, $sp, KernelHeight_arg //KernelHeight + ld.d $t2, $sp, KernelWidth_arg //KernelWidth +.if \OutputCount\() == 1 + ld.d $t3, $sp, InputBase_arg //InputBase + ld.d $t4, $sp, InputWidth_arg //InputWidth + sub.d $t3, $zero, $t3 # keep negative for lea usage below +.endif + ClearBlock \FilterCount\(), \OutputCount\() + beqz $t1, .L\KernelType\().\FilterCount\().\OutputCount\().HandlePostProcessing + +.L\KernelType\().\FilterCount\().\OutputCount\().ProcessNextRow: + move $t6, $t2 # reload kernel width remaining +.L\KernelType\().\FilterCount\().\OutputCount\().ProcessNextColumn: +.if \OutputCount\() == 1 + add.d $t7, $a3, $t3 + bgeu $t7, $t4, .L\KernelType\().\FilterCount\().\OutputCount\().SkipOverPadding +.endif +.if \OutputCount\() > 3 + li.d $s2, 2 + mul.d $s2, $a5, $s2 + add.d $t4, $a5, $s2 + + add.d $t4, $t4, $a3 # compute input plus 3 blocks +.endif +.if \FilterCount\() > 2 + li.d $s2, 2 + mul.d $s2, $s2, $a1 + add.d $t7, $a2, $s2 //t6 is rbx used by ComputeBlock +.endif +.ifeqs "\KernelType\()","Nchwc" +.if \BlockSize\() == 16 + .irp Index, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 + ComputeBlock \KernelType\(), \FilterCount\(), \OutputCount\(), \Index\()*16*4, \Index\()*4 + .endr +.else + .irp Index, 0, 1, 2, 3, 4, 5, 6, 7 + ComputeBlock \KernelType\(), \FilterCount\(), \OutputCount\(), (\Index\()-4)*8*4, \Index\()*4 + .endr +.endif +.else + ComputeBlock \KernelType\(), \FilterCount\(), \OutputCount\(), 0, 0 +.endif +.L\KernelType\().\FilterCount\().\OutputCount\().SkipOverPadding: + add.d $a3, $a3, $t8 # advance input by dilation width +.ifeqs "\KernelType\()","Nchwc" + addi.d $a2, $a2, \BlockSize\()*\BlockSize\()*4 + # advance filter by 8i8o/16i16o block +.else + addi.d $a2, $a2, \BlockSize\()*4 # advance filter by 8o/16o block +.endif + addi.d $t6, $t6, -1 # decrement columns remaining + bnez $t6, .L\KernelType\().\FilterCount\().\OutputCount\().ProcessNextColumn + add.d $a3, $a3, $t5 +.if \OutputCount\() == 1 + ld.d $s0, $sp, DilatedInputWidth_arg #DilatedInputWidth + sub.d $t3, $t3, $s0 + # advance input base to next row +.endif + addi.d $t1, $t1, -1 # decrement rows remaining + bnez $t1, .L\KernelType\().\FilterCount\().\OutputCount\().ProcessNextRow + +// +// Handle post processing of the output block. +// +.L\KernelType\().\FilterCount\().\OutputCount\().HandlePostProcessing: + ld.w $a2, $sp, Flags_arg + +.if \FilterCount\() > 1 + ld.d $t6, $sp, OutputStride_arg +.endif + ld.d $a3, $sp, Bias_arg + bl MlasConvPostProcessFloat\Isa\()Filter\FilterCount\()Output\OutputCount\() +.endm +/*++ + +Macro Description: + + This macro generates code for the inner convolution kernel. + +Arguments: + + KernelType - Supplies the type of kernel to be generated. + + BlockSize - Supplies the number of elements per block. + + Isa - Supplies the instruction set architecture string for function tags. + + BiasFilter - Supplies a non-blank value if the address of the filter buffer + should be biased to point to the middle of a OIhw8i8o block in order to + reduce the code size from relative byte offsets. + +--*/ + + .macro SconvKernelFunction KernelType, BlockSize, Isa, BiasFilter + +/*++ + +Routine Description: + + This routine is the inner kernel to compute a convolution for the elements + of an output row for a set of filter rows. 
+ +Arguments: + + Input (a0) - Supplies the address of the input buffer. + + The address is biased to include padding blocks for the left width + dimension. The address is not biased to include padding rows for the + left height dimension these are accounted for in the outer kernel. + + Filter (a1) - Supplies the address of the filter buffer. + + Output (a2) - Supplies the address of the output buffer. + + StrideWidth (a3) - Supplies the length in bytes of the blocked stride width. + + DilationWidth (a4) - Supplies the length in bytes of the blocked dilation + width. + + FilterCount (a5) - Supplies the number of filters to process in this + iteration. + + InputStride (a6) - Supplies the length in bytes to advance the input buffer to + the next input row. + + FilterStride (a7)- Supplies the length in bytes to advance the filter buffer + to the next set of filters. + + OutputStride (sp,8*0) - Supplies the length in bytes to advance the output buffer + to the next output address associated with the next set of filters. + + KernelHeight (sp,8*1)- Supplies the height of the kernel to apply. This height may + be less than the original kernel height after removing any padding + rows. + + KernelWidth (sp, 8*2)- Supplies the width of the kernel to apply. + + InputBase (sp, 8*3)- Supplies the address of the valid input buffer. + + This parameter is similar to the Input parameter, but does not include + the padding blocks for the left width dimension. This parameter is used + with the following InputWidth parameter in order to validate that the + current input buffer address in bounds and not in the left or right + width padding region. + + InputWidth (sp, 8*4)- Supplies the length in bytes of the blocked input width. + + DilatedInputWidth (sp, 8*5)- Supplies the length in bytes to advance the input base + buffer to the next input row including dilation. + + OutputCountLeftPad (sp, 8*6)- Supplies the number of output elements that include + one or more padding elements from the left edge. + + OutputCount (sp, 8*7)- Supplies the number of output elements that do not include + any padding elements. + + OutputCountRightPad (sp, 8*8)- Supplies the number of output elements that include + one or more padding elements from the right edge. + + Bias (sp, 8*9)- Supplies the address of the bias buffer. + + Flags (sp, 8*10)- Supplies additional flags controlling the convolution operation, + especially post calculation options. + +Return Value: + + None. 
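+
+    The body below dispatches once on FilterCount so that the per-row macros
+    expand with a constant count; in C terms (sketch only, ProcessFilterRows<N>
+    stands in for the ProcessFilterCountN expansions):
+
+        switch (FilterCount) {
+        case 1:  ProcessFilterRows<1>(); break;
+        case 2:  ProcessFilterRows<2>(); break;
+        case 3:  ProcessFilterRows<3>(); break;
+        default: ProcessFilterRows<4>(); break;   // FilterCount >= 4
+        }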
+ +--*/ + + FUNCTION_ENTRY MlasConv\KernelType\()FloatKernel\Isa\() + addi.d $sp, $sp, -SP_SIZE + st.d $s0, $sp, 0*8 + st.d $s1, $sp, 1*8 + st.d $s2, $sp, 2*8 + st.d $s3, $sp, 3*8 + st.d $s4, $sp, 4*8 + st.d $ra, $sp, 5*8 + ld.d $s0, $sp, SP_SIZE+0*8 + ld.d $s1, $sp, SP_SIZE+1*8 + ld.d $s2, $sp, SP_SIZE+2*8 + ld.d $s3, $sp, SP_SIZE+3*8 + st.d $s0, $sp, OutputStride_arg + st.d $s1, $sp, KernelHeight_arg + st.d $s2, $sp, KernelWidth_arg + st.d $s3, $sp, InputBase_arg + ld.d $s0, $sp, SP_SIZE+4*8 + ld.d $s1, $sp, SP_SIZE+5*8 + ld.d $s2, $sp, SP_SIZE+6*8 + ld.d $s3, $sp, SP_SIZE+7*8 + st.d $s0, $sp, InputWidth_arg + st.d $s1, $sp, DilatedInputWidth_arg + st.d $s2, $sp, OutputCountLeftPad_arg + st.d $s3, $sp, OutputCount_arg + ld.d $s0, $sp, SP_SIZE+8*8 + ld.d $s1, $sp, SP_SIZE+9*8 + ld.d $s2, $sp, SP_SIZE+10*8 + st.d $s0, $sp, OutputCountRightPad_arg + st.d $s1, $sp, Bias_arg + st.d $s2, $sp, Flags_arg + +.ifeqs "\BiasFilter\()","BiasFilter" + addi.d $a1, $a1,4*8*4 +.endif + st.d $a1, $sp, Filter_save_offset //store Filter + move $a1, $a7 + move $t5, $a6 + move $t8, $a4 # shuffle to Win64 register usage + move $t1, $a5 + move $a4, $a2 + move $a5, $a3 + + li.d $s0, 3 + beq $t1, $s0, .L\KernelType\().ProcessFilterCount3 + blt $t1, $s0, .L\KernelType\().ProcessFilterCountLessThan3 + ProcessFilterCountN SconvKernelFrame, \KernelType\(), 4 + b .L\KernelType\().ExitKernel + +.L\KernelType\().ProcessFilterCount3: + ProcessFilterCountN SconvKernelFrame, \KernelType\(), 3 + b .L\KernelType\().ExitKernel + +.L\KernelType\().ProcessFilterCountLessThan3: + li.d $s0,2 + blt $t1, $s0, .L\KernelType\().ProcessFilterCount1 + ProcessFilterCountN SconvKernelFrame, \KernelType\(), 2 + b .L\KernelType\().ExitKernel + +.L\KernelType\().ProcessFilterCount1: + ProcessFilterCountN SconvKernelFrame, \KernelType\(), 1 + +// +// Restore non-volatile registers and return. +// + +.L\KernelType\().ExitKernel: + ld.d $a1, $sp, Filter_save_offset //restore Filter + ld.d $s0, $sp, 0*8 + ld.d $s1, $sp, 1*8 + ld.d $s2, $sp, 2*8 + ld.d $s3, $sp, 3*8 + ld.d $s4, $sp, 4*8 + ld.d $ra, $sp, 5*8 + + addi.d $sp, $sp, SP_SIZE + jr $ra +.endm + +/*++ + +Macro Description: + + This macro generates code for the inner convolution kernel for the special + case of a depthwise separable convolution. + +Arguments: + + BlockSize - Supplies the number of elements per block. + + Isa - Supplies the instruction set architecture string for function tags. + +--*/ + + .macro SconvKernelDepthwiseFunction BlockSize, Isa + +/*++ + +Routine Description: + + This routine is the inner kernel to compute a convolution for the elements + of an output row for a set of filter rows. + + Depthwise separable convolutions are a form of grouped convolution where + the number of input and output channels per group are one. + +Arguments: + + Input a0 - Supplies the address of the input buffer. + + The address is biased to include padding blocks for the left width + dimension. The address is not biased to include padding rows for the + left height dimension these are accounted for in the outer kernel. + + Filter a1 - Supplies the address of the filter buffer. + + Output a2 - Supplies the address of the output buffer. + + StrideWidth a3 - Supplies the length in bytes of the blocked stride width. + + DilationWidth a4 - Supplies the length in bytes of the blocked dilation + width. + + InputStride a5 - Supplies the length in bytes to advance the input buffer + to the next input row. + + KernelHeight a6 - Supplies the height of the kernel to apply. 
This height may + be less than the original kernel height after removing any padding + rows. + + KernelWidth a7- Supplies the width of the kernel to apply. + + InputBase (sp, 0*8)- Supplies the address of the valid input buffer. + + This parameter is similar to the Input parameter, but does not include + the padding blocks for the left width dimension. This parameter is used + with the following InputWidth parameter in order to validate that the + current input buffer address in bounds and not in the left or right + width padding region. + + InputWidth (sp, 1*8)- Supplies the length in bytes of the blocked input width. + + DilatedInputWidth (sp, 2*8)- Supplies the length in bytes to advance the input base + buffer to the next input row including dilation. + + OutputCountLeftPad (sp, 3*8)- Supplies the number of output elements that include + one or more padding elements from the left edge. + + OutputCount (sp, 4*8)- Supplies the number of output elements that do not include + any padding elements. + + OutputCountRightPad (sp, 5*8)- Supplies the number of output elements that include + one or more padding elements from the right edge. + + Bias (sp, 6*8)- Supplies the address of the bias buffer. + + Flags (sp, 7*8)- Supplies additional flags controlling the convolution operation, + especially post calculation options. + +Return Value: + + None. + +--*/ + + FUNCTION_ENTRY MlasConvDepthwiseFloatKernel\Isa\() + addi.d $sp, $sp, -SP_SIZE + st.d $s0, $sp, 0*8 + st.d $s1, $sp, 1*8 + st.d $s2, $sp, 2*8 + st.d $s3, $sp, 3*8 + st.d $s4, $sp, 4*8 + st.d $ra, $sp, 5*8 + + st.d $a6, $sp, KernelHeight_arg + st.d $a7, $sp, KernelWidth_arg + + ld.d $s0, $sp, SP_SIZE+0*8 + ld.d $s1, $sp, SP_SIZE+1*8 + ld.d $s2, $sp, SP_SIZE+2*8 + ld.d $s3, $sp, SP_SIZE+3*8 + st.d $s0, $sp, InputBase_arg + st.d $s1, $sp, InputWidth_arg + st.d $s2, $sp, DilatedInputWidth_arg + st.d $s3, $sp, OutputCountLeftPad_arg + ld.d $s0, $sp, SP_SIZE+4*8 + ld.d $s1, $sp, SP_SIZE+5*8 + ld.d $s2, $sp, SP_SIZE+6*8 + ld.d $s3, $sp, SP_SIZE+7*8 + st.d $s0, $sp, OutputCount_arg + st.d $s1, $sp, OutputCountRightPad_arg + st.d $s2, $sp, Bias_arg + st.d $s3, $sp, Flags_arg +// +// Process the specified number of filter rows. +// + move $t8, $a4 // shuffle to Win64 register usage + move $t5, $a5 + move $a4, $a2 + move $a5, $a3 + ProcessFilterCountN SconvKernelDepthwiseFrame, Depthwise, 1 + +// +// Restore non-volatile registers and return. + ld.d $s0, $sp, 0*8 + ld.d $s1, $sp, 1*8 + ld.d $s2, $sp, 2*8 + ld.d $s3, $sp, 3*8 + ld.d $s4, $sp, 4*8 + ld.d $ra, $sp, 5*8 + addi.d $sp, $sp, SP_SIZE +// + jr $ra +.endm + +/*++ + +Macro Description: + + This macro generates code to compute the convolution for a vector of input + blocks and a vector of filter blocks to produce a matrix of output blocks + for a pointwise convolution. + +Arguments: + + Isa - Supplies the instruction set architecture string for function tags. + + BlockSize - Supplies the number of elements per block. + + FilterCount - Supplies the number of rows from the filter to process. + + OutputCount - Supplies the number of output blocks to produce. + +Implicit Arguments: + + (a0) - Supplies the address of the input buffer. + + (a1) - Supplies the FilterStride parameter (see function description). + + (s8) - Supplies the InputStride parameter (see function description). + + (a4) - Supplies the address of the output buffer. + + (a5) - Supplies the StrideWidth parameter (see function description). + + (s5) - Supplies the address of the filter buffer. 
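+
+    Each iteration of the input-channel loop consumes one blocked channel group:
+    the input pointer advances by InputStride bytes and the filter pointer by one
+    complete 8i8o block (BlockSize is 8 for the Lsx kernels). As a sketch
+    (ComputeBlock here is shorthand for the macro of the same name):
+
+        const unsigned char* input = (const unsigned char*)Input;
+        const unsigned char* filter = (const unsigned char*)Filter;
+        for (size_t ic = InputChannels; ic > 0; ic--) {
+            ComputeBlock(input, filter);                       // BlockSize FMAs per accumulator
+            input  += InputStride;                             // bytes to the next channel block
+            filter += BlockSize * BlockSize * sizeof(float);   // one 8i8o filter block
+        }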
+ +--*/ + + .macro ProcessPointwiseOutputCountN Isa, BlockSize, FilterCount, OutputCount + + move $a3, $a0 + move $a2, $t2 + ld.d $t1, $sp, InputChannels_arg + ClearBlock \FilterCount\(), \OutputCount\() + +.LPointwise.\FilterCount\().\OutputCount\().ProcessNextInputBlock: +.if \OutputCount\() > 3 + li.d $s0, 2 + mul $s0, $s0, $a5 + add.d $t4, $a5, $s0 + add.d $t4, $t4, $a3 # compute input plus 3 blocks +.endif +.if \FilterCount\() > 2 + li.d $s0, 2 # compute filter plus 2 rows + mul.d $s0, $s0, $a1 + add.d $t7, $a2, $s0 +.endif + +.if \BlockSize\() == 16 + .irp Index, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 + ComputeBlock Pointwise, \FilterCount\(), \OutputCount\(), \Index\()*16*4, \Index\()*4 + .endr +.else + .irp Index, 0, 1, 2, 3, 4, 5, 6, 7 + ComputeBlock Pointwise, \FilterCount\(), \OutputCount\(), (\Index\()-4)*8*4, \Index\()*4 + .endr +.endif + add.d $a3, $a3, $t8 # advance input to next channel block + addi.d $a2, $a2, \BlockSize\()*\BlockSize\()*4 + # advance filter by 8i8o/16i16o block + addi.d $t1, $t1, -1 //InputChannels decrement input blocks remaining + bnez $t1, .LPointwise.\FilterCount\().\OutputCount\().ProcessNextInputBlock + +// +// Handle post processing of the output block. +// + ld.w $a2, $sp, Flags_arg #load flag +.if \FilterCount\() > 1 + ld.d $t6 ,$sp, OutputStride_arg #load .LSconvKernelPointwiseFrame_OutputStride +.endif + ld.d $a3, $sp, Bias_arg # load .LSconvKernelPointwiseFrame_Bias + bl MlasConvPostProcessFloat\Isa\()Filter\FilterCount\()Output\OutputCount\() +.endm + + .macro SconvKernelPointwiseFunction Isa, BiasFilter + +/*++ + +Routine Description: + + This routine is the inner kernel to compute a convolution for the elements + of an output row for a set of filter rows. + + Pointwise convolutions have a kernel size of one. To simplify this + implementation, no input padding is allowed, which matches typical usage in + models. + +Arguments: + + Input (a0) - Supplies the address of the input buffer. + + Filter (a1) - Supplies the address of the filter buffer. + + Output (a2) - Supplies the address of the output buffer. + + StrideWidth (a3) - Supplies the length in bytes of the blocked stride width. + + InputChannels (a4) - Supplies the number of input channels to process. + + FilterCount (a5) - Supplies the number of rows from the filter to process. + + InputStride (a6) - Supplies the length in bytes to advance the input buffer to + the next input channel of the same input row. + + FilterStride (a7) - Supplies the length in bytes to advance the filter buffer + to the next set of filters. + + OutputStride (sp+0) - Supplies the length in bytes to advance the output buffer + to the next output address associated with the next set of filters. + + OutputCount (sp+8) - Supplies the number of output elements. + + Bias (sp+16) - Supplies the address of the bias buffer. + + Flags (sp+24) - Supplies additional flags controlling the convolution operation, + especially post calculation options. + +Return Value: + + None. 
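+
+    When the BiasFilter variant is generated, the filter pointer is pre-advanced
+    by half of an 8i8o block (4*8*4 = 128 bytes) so that ComputeBlock can address
+    the block with small signed offsets of the form (i - 4)*8*4; as a sketch:
+
+        // 4 input channels * 8 output channels * sizeof(float) = 128 bytes,
+        // i.e. the middle of one 8i8o (8x8 float) filter block.
+        const float* BiasedFilter = Filter + 4 * 8;   // offset in float elements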
+ +--*/ + + FUNCTION_ENTRY MlasConvPointwiseFloatKernel\Isa\() + addi.d $sp, $sp, -SP_SIZE + st.d $s0, $sp, 0*8 + st.d $s1, $sp, 1*8 + st.d $s2, $sp, 2*8 + st.d $s3, $sp, 3*8 + st.d $s4, $sp, 4*8 + st.d $ra, $sp, 5*8 + + ld.d $s0, $sp, SP_SIZE+0*8 + ld.d $s1, $sp, SP_SIZE+1*8 + ld.d $s2, $sp, SP_SIZE+2*8 + ld.d $s3, $sp, SP_SIZE+3*8 + st.d $s0, $sp, OutputStride_arg + st.d $s1, $sp, OutputCount_arg + st.d $s2, $sp, Bias_arg + st.d $s3, $sp, Flags_arg + st.d $a4, $sp, InputChannels_arg + +.ifeqs "\BiasFilter\()","BiasFilter" + addi.d $t2, $a1, 4*8*4 +.else + move $t2, $a1 +.endif + + ld.d $t0, $sp, OutputCount_arg //OutputCount + move $a1, $a7 // FilterStride + move $t8, $a6 // InputStride + move $t1, $a5 // shuffle to Win64 register usage + move $a4, $a2 + move $a5, $a3 + +// +// Process the specified number of filter rows. +// + li.d $s0, 3 + beq $t1, $s0, .LPointwise.ProcessFilterCount3 + blt $t1, $s0, .LPointwise.ProcessFilterCountLessThan3 + ProcessPointwiseFilterCountN 4 + b .LPointwise.ExitKernel + +.LPointwise.ProcessFilterCount3: + ProcessPointwiseFilterCountN 3 + b .LPointwise.ExitKernel + +.LPointwise.ProcessFilterCountLessThan3: + li.d $s0, 2 + blt $t1, $s0, .LPointwise.ProcessFilterCount1 + ProcessPointwiseFilterCountN 2 + b .LPointwise.ExitKernel + +.LPointwise.ProcessFilterCount1: + ProcessPointwiseFilterCountN 1 + +// +// Restore non-volatile registers and return. +// +.LPointwise.ExitKernel: + + ld.d $s0, $sp, 0*8 + ld.d $s1, $sp, 1*8 + ld.d $s2, $sp, 2*8 + ld.d $s3, $sp, 3*8 + ld.d $s4, $sp, 4*8 + ld.d $ra, $sp, 5*8 + addi.d $sp, $sp, SP_SIZE + jr $ra +.endm diff --git a/onnxruntime/core/mlas/lib/loongarch64/SgemmKernelCommon.h b/onnxruntime/core/mlas/lib/loongarch64/SgemmKernelCommon.h new file mode 100644 index 0000000000000..93b109c90ae4f --- /dev/null +++ b/onnxruntime/core/mlas/lib/loongarch64/SgemmKernelCommon.h @@ -0,0 +1,35 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + SgemmKernelCommon.h + +Abstract: + + This module contains common kernel macros and structures for the single + precision matrix/matrix multiply operation (SGEMM). + +--*/ + +// +// Define the single precision parameters. +// + +#define LFgemmElementShift 2 +#define LFgemmElementSize (1 << LFgemmElementShift) +#define LFgemmYmmElementCount (32/LFgemmElementSize) + +#include "FgemmKernelCommon.h" + +// +// Define the typed instructions for single precision. +// + +FGEMM_TYPED_INSTRUCTION(xvfadd, xvfadd.s) +FGEMM_TYPED_INSTRUCTION(xvfmadd, xvfmadd.s) +FGEMM_TYPED_INSTRUCTION(xvldrepl, xvldrepl.w) +FGEMM_TYPED_INSTRUCTION(xvfmul, xvfmul.s) diff --git a/onnxruntime/core/mlas/lib/loongarch64/SgemmKernelLasx.S b/onnxruntime/core/mlas/lib/loongarch64/SgemmKernelLasx.S new file mode 100644 index 0000000000000..d537742016d01 --- /dev/null +++ b/onnxruntime/core/mlas/lib/loongarch64/SgemmKernelLasx.S @@ -0,0 +1,33 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + SgemmKernelLasx.s + +Abstract: + + This module implements the kernels for the single precision matrix/matrix + multiply operation (SGEMM). + + This implementation uses LASX instructions. + +--*/ + +#include "asmmacro.h" +#include "SgemmKernelCommon.h" +#include "FgemmKernelLasxCommon.h" + + + .text + +// +// Generate the GEMM kernel. 
+// + +FgemmKernelLasxFunction MlasGemmFloatKernelLasx + + .end diff --git a/onnxruntime/core/mlas/lib/loongarch64/SgemmKernelLsx.S b/onnxruntime/core/mlas/lib/loongarch64/SgemmKernelLsx.S new file mode 100644 index 0000000000000..86b5ef8b51b00 --- /dev/null +++ b/onnxruntime/core/mlas/lib/loongarch64/SgemmKernelLsx.S @@ -0,0 +1,267 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + SgemmKernelLsx.s + +Abstract: + + This module implements the kernels for the single precision matrix/matrix + multiply operation (SGEMM). + + This implementation uses Lsx instructions. + +--*/ + +#include "asmmacro.h" +#include "FgemmKernelLsxCommon.h" + +FGEMM_TYPED_INSTRUCTION(vfadd, vfadd.s) + +/*++ + +Macro Description: + + This macro multiplies and accumulates for a 16xN block of the output matrix. + +Arguments: + + RowCount - Supplies the number of rows to process. + + VectorOffset - Supplies the byte offset from matrix B to fetch elements. + + Shuffle - Supplies the shuffle mask to extract the element from matrix A. + +Implicit Arguments: + + a1 - Supplies the address into the matrix B data. + + vr0-vr1 - Supplies up to four elements loaded from matrix A and matrix A + plus one row. + + vr8-vr15 - Supplies the block accumulators. + +--*/ + + .macro ComputeBlockSseBy16 RowCount, VectorOffset, Shuffle + vld $vr4, $a1, \VectorOffset + vld $vr5, $a1, \VectorOffset + 16 + vreplvei.w $vr2, $vr0, \Shuffle +.if \RowCount\() == 2 + vreplvei.w $vr3, $vr1, \Shuffle + vmove $vr6, $vr4 + vmove $vr7, $vr5 +.endif + vfmadd.s $vr8, $vr4, $vr2, $vr8 + vfmadd.s $vr9, $vr5, $vr2, $vr9 +.if \RowCount\() == 2 + vfmadd.s $vr12, $vr6, $vr3, $vr12 + vfmadd.s $vr13, $vr7, $vr3, $vr13 +.endif + vld $vr4, $a1, \VectorOffset + 32 + vld $vr5, $a1, \VectorOffset + 48 +.if \RowCount\() == 2 + vmove $vr6, $vr4 + vmove $vr7, $vr5 +.endif + vfmadd.s $vr10, $vr4, $vr2, $vr10 + vfmadd.s $vr11, $vr5, $vr2, $vr11 +.if \RowCount\() == 2 + vfmadd.s $vr14, $vr6, $vr3, $vr14 + vfmadd.s $vr15, $vr7, $vr3, $vr15 +.endif + .endm + + +/*++ + +Macro Description: + + This macro generates code to compute matrix multiplication for a fixed set + of rows. + +Arguments: + + RowCount - Supplies the number of rows to process. + + Fallthrough - Supplies a non-blank value if the macro may fall through to + the ExitKernel label. + +Implicit Arguments: + + a0 - Supplies the address of matrix A. + + a1 - Supplies the address of matrix B. + + t8 - Supplies the address of matrix A. + + a5 - Supplies the number of columns from matrix B and matrix C to iterate + over. + + a2 - Supplies the address of matrix C. + + a3 - Supplies the number of columns from matrix A and the number of rows + from matrix B to iterate over. + + t7 - Supplies the length in bytes of a row from matrix A. + + t5 - Supplies the length in bytes of a row from matrix C. + + s3 - Stores the ZeroMode argument from the stack frame. 
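[Editor's note] The accumulator layout used here (vr8-vr11 for the first row of C, vr12-vr15 for the second, with single elements of A broadcast against 16 consecutive columns of B) corresponds to the scalar picture below. It is a sketch only: the row-major B indexing is an assumption standing in for the packed-B buffer the real kernel reads, and the function name is illustrative.

```cpp
#include <cstddef>

// Scalar model of one 2x16 SGEMM tile update: for each k, broadcast A[row][k]
// and accumulate it against 16 consecutive columns of B (vfmadd.s into vr8..vr15).
void Sgemm2x16TileReference(const float* A, std::size_t lda,
                            const float* B, std::size_t ldb,
                            float Acc[2][16], std::size_t CountK)
{
    for (std::size_t k = 0; k < CountK; k++) {
        for (std::size_t row = 0; row < 2; row++) {
            const float a = A[row * lda + k];          // vreplvei.w broadcast into vr2/vr3
            for (std::size_t col = 0; col < 16; col++) {
                Acc[row][col] += a * B[k * ldb + col]; // four 4-wide FMAs per row
            }
        }
    }
}
```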
+ +--*/ + + .macro ProcessCountM RowCount, Fallthrough +.LProcessNextColumnLoop16xN\@: + EmitIfCountGE \RowCount\(), 1, "vxor.v $vr8, $vr8,$vr8" + EmitIfCountGE \RowCount\(), 1, "vxor.v $vr9, $vr9,$vr9" + EmitIfCountGE \RowCount\(), 1, "vxor.v $vr10, $vr10,$vr10" + EmitIfCountGE \RowCount\(), 1, "vxor.v $vr11, $vr11,$vr11" + EmitIfCountGE \RowCount\(), 2, "vxor.v $vr12, $vr12,$vr12" + EmitIfCountGE \RowCount\(), 2, "vxor.v $vr13, $vr13,$vr13" + EmitIfCountGE \RowCount\(), 2, "vxor.v $vr14, $vr14,$vr14" + EmitIfCountGE \RowCount\(), 2, "vxor.v $vr15, $vr15,$vr15" + move $t8, $a3 + li.d $s0, 4 + blt $t8, $s0, .LProcessRemaining16xNBlocks\@ +.LCompute16xNBlockBy4Loop\@: + EmitIfCountGE \RowCount\(), 1, "vld $vr0, $a0, 0" + EmitIfCountGE \RowCount\(), 2, "vldx $vr1, $a0, $t0" #second line of A + ComputeBlockSseBy16 2, 0, 0x0 + ComputeBlockSseBy16 2, 16*4, 0x1 + addi.d $a1, $a1, 32*4 # advance matrix B by 32 columns + ComputeBlockSseBy16 2, 0, 0x2 + ComputeBlockSseBy16 2, 16*4, 0x3 + addi.d $a1, $a1, 32*4 # advance matrix B by 32 columns + addi.d $a0, $a0, 4*4 # advance matrix A by 4 columns + addi.d $t8, $t8, -4 + li.d $s0, 4 #check matrix A remaining less than 4 + bge $t8, $s0, .LCompute16xNBlockBy4Loop\@ + +.LProcessRemaining16xNBlocks\@: + beqz $t8, .LOutput16xNBlock\@ + +.LCompute16xNBlockBy1Loop\@: + EmitIfCountGE \RowCount\(), 1, "ld.w $s0, $a0, 0" + EmitIfCountGE \RowCount\(), 1, "vinsgr2vr.w $vr0, $s0, 0" + EmitIfCountGE \RowCount\(), 2, "ldx.w $s0,$a0, $t0" + EmitIfCountGE \RowCount\(), 2, "vinsgr2vr.w $vr1,$s0, 0" + ComputeBlockSseBy16 2, 0, 0x00 + addi.d $a1, $a1, 16*4 #advance matrix B by 16 columns + addi.d $a0, $a0, 1*4 #advance matrix A by 1 column + addi.d $t8, $t8, -1 + bnez $t8, .LCompute16xNBlockBy1Loop\@ + +.LOutput16xNBlock\@: + movfr2gr.s $s0, $f24 + vreplgr2vr.w $vr2, $s0 + EmitIfCountGE \RowCount\(), 1, "vfmul.s $vr8,$vr8,$vr2" + # multiply by alpha + EmitIfCountGE \RowCount\(), 1, "vfmul.s $vr9,$vr9,$vr2" + EmitIfCountGE \RowCount\(), 1, "vfmul.s $vr10,$vr10,$vr2" + EmitIfCountGE \RowCount\(), 1, "vfmul.s $vr11,$vr11,$vr2" + EmitIfCountGE \RowCount\(), 2, "vfmul.s $vr12,$vr12,$vr2" + EmitIfCountGE \RowCount\(), 2, "vfmul.s $vr13,$vr13,$vr2" + EmitIfCountGE \RowCount\(), 2, "vfmul.s $vr14,$vr14,$vr2" + EmitIfCountGE \RowCount\(), 2, "vfmul.s $vr15,$vr15,$vr2" + li.d $s0, 16 + blt $a5, $s0, .LOutputPartial16xNBlock\@ + sub.d $a5, $a5, $s0 + AccumulateAndStoreBlock \RowCount\(), 4 + addi.d $a2, $a2, 16*4 # advance matrix C by 16 columns + move $a0, $t1 # reload matrix A + bnez $a5, .LProcessNextColumnLoop16xN\@ + b .LExitKernel + +// +// Output a partial 16xN block to the matrix. 
+// + +.LOutputPartial16xNBlock\@: + li.d $s0, 4 + blt $a5, $s0, .LOutputPartialLessThan4xNBlock\@ + li.d $s0, 8 + blt $a5, $s0, .LOutputPartialLessThan8xNBlock\@ + li.d $s0, 12 + blt $a5, $s0, .LOutputPartialLessThan12xNBlock\@ + AccumulateAndStoreBlock \RowCount\(), 3 + andi $a5, $a5, 3 + beqz $a5, .LExitKernel + EmitIfCountGE \RowCount\(), 1, "vmove $vr8, $vr11" + # shift remaining elements down + EmitIfCountGE \RowCount\(), 2, "vmove $vr12, $vr15" + addi.d $a2, $a2,12*4 # advance matrix C by 12 columns + b .LOutputPartialLessThan4xNBlock\@ + +.LOutputPartialLessThan12xNBlock\@: + AccumulateAndStoreBlock \RowCount\(), 2 + andi $a5, $a5, 3 + beqz $a5, .LExitKernel + EmitIfCountGE \RowCount\(), 1, "vmove $vr8, $vr10" + # shift remaining elements down + EmitIfCountGE \RowCount\(), 2, "vmove $vr12, $vr14" + addi.d $a2, $a2,8*4 # advance matrix C by 8 columns + b .LOutputPartialLessThan4xNBlock\@ + +.LOutputPartialLessThan8xNBlock\@: + AccumulateAndStoreBlock \RowCount\(), 1 + andi $a5, $a5, 3 + beqz $a5, .LExitKernel + EmitIfCountGE \RowCount\(), 1, "vmove $vr8, $vr9" + # shift remaining elements down + EmitIfCountGE \RowCount\(), 2, "vmove $vr12, $vr13" + addi.d $a2, $a2, 4*4 # advance matrix C by 4 columns + +.LOutputPartialLessThan4xNBlock\@: + andi $s0, $a5, 2 + beqz $s0, .LOutputPartial1xNBlock\@ + and $s0, $t5, $t5 # ZeroMode? + bnez $s0, .LSkipAccumulateOutput2xN\@ + EmitIfCountGE \RowCount\(), 1, "vxor.v $vr0, $vr0, $vr0" + EmitIfCountGE \RowCount\(), 1, "ld.d $s0, $a2, 0" + EmitIfCountGE \RowCount\(), 1, "vinsgr2vr.d $vr0, $s0, 0" + EmitIfCountGE \RowCount\(), 2, "vxor.v $vr1, $vr1, $vr1" + EmitIfCountGE \RowCount\(), 2, "ldx.d $s0, $a2, $t6" + EmitIfCountGE \RowCount\(), 2, "vinsgr2vr.d $vr1, $s0, 0" + EmitIfCountGE \RowCount\(), 1, "vfadd.s $vr8, $vr8, $vr0" + EmitIfCountGE \RowCount\(), 2, "vfadd.s $vr12, $vr12, $vr1" + +.LSkipAccumulateOutput2xN\@: + EmitIfCountGE \RowCount\(), 1, "vstelm.d $vr8, $a2, 0, 0" + EmitIfCountGE \RowCount\(), 2, "vpickve2gr.d $s0, $vr12, 0" + EmitIfCountGE \RowCount\(), 2, "stx.d $s0, $a2, $t6" + andi $s0, $a5, 1 + beqz $s0, .LExitKernel + EmitIfCountGE \RowCount\(), 1, "vpermi.w $vr8, $vr8, 0xee" + # shift third element down + EmitIfCountGE \RowCount\(), 2, "vpermi.w $vr12, $vr12, 0xee" + addi.d $a2, $a2, 2*4 # advance matrix C by 2 columns + +.LOutputPartial1xNBlock\@: + and $s0, $t5, $t5 # ZeroMode? + bnez $s0, .LSkipAccumulateOutput1xN\@ + + EmitIfCountGE \RowCount\(), 1, "fld.s $f16, $a2, 0" + EmitIfCountGE \RowCount\(), 1, "fadd.s $f8, $f16, $f8" + EmitIfCountGE \RowCount\(), 2, "fldx.s $f17, $a2, $t6" + EmitIfCountGE \RowCount\(), 2, "fadd.s $f12, $f12, $f17" + +.LSkipAccumulateOutput1xN\@: + EmitIfCountGE \RowCount\(), 1, "fst.s $f8, $a2, 0" + EmitIfCountGE \RowCount\(), 2, "fstx.s $f12, $a2, $t6" +.ifb \Fallthrough\() + b .LExitKernel +.endif + .endm + +// +// Generate the GEMM kernel. +// + +FgemmKernelLsxFunction MlasGemmFloatKernelLSX + + .end diff --git a/onnxruntime/core/mlas/lib/loongarch64/SgemmTransposePackB16x4LSX.S b/onnxruntime/core/mlas/lib/loongarch64/SgemmTransposePackB16x4LSX.S new file mode 100644 index 0000000000000..cd1747745d2a4 --- /dev/null +++ b/onnxruntime/core/mlas/lib/loongarch64/SgemmTransposePackB16x4LSX.S @@ -0,0 +1,89 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. + +Licensed under the MIT License. 
+ +Module Name: + + SgemmTransposePackB16x4LSX.s + +Abstract: + + This module implements routines for packing buffers for the single precision + matrix/matrix multiply operation (SGEMM). + + This implementation uses Lsx instructions. + +--*/ + +#include "asmmacro.h" + + .text + +/*++ + +Routine Description: + + This routine transposes elements from the source matrix to the destination + packed buffer. + + 4 columns of 16 rows from the source matrix are transposed to 16 columns of 4 + rows in the destination packed buffer. + +Arguments: + + D (a0) - Supplies the address of the destination packed buffer. + + B (a1) - Supplies the address of the source matrix. + + ldb (a2) - Supplies the number of elements per row of the source matrix. + +Return Value: + + None. + +--*/ + + FUNCTION_ENTRY MlasSgemmTransposePackB16x4LSX + addi.d $sp, $sp, -64 + st.d $s0, $sp, 0*8 + st.d $s1, $sp, 1*8 + slli.d $a2, $a2, 2 # convert ldb to bytes + ori $a3, $zero, 4 # transpose four 4x4 blocks + vxor.v $vr7, $vr7, $vr7 +.LTransposeBlockLoop: + slli.d $s0, $a2, 1 + add.d $s1, $a1, $s0 + vld $vr0, $a1, 0 + vldx $vr1, $a1, $a2 + vld $vr2, $s1, 0 + vldx $vr3, $s1, $a2 + + vor.v $vr4, $vr0, $vr7 + vilvl.w $vr4, $vr1, $vr4 + vilvh.w $vr0, $vr1, $vr0 + vor.v $vr5, $vr2, $vr7 + vilvl.w $vr5, $vr3, $vr5 + vilvh.w $vr2, $vr3, $vr2 + vor.v $vr1, $vr4, $vr7 + vilvl.d $vr1, $vr5, $vr1 + vilvh.d $vr4, $vr5, $vr4 + vor.v $vr3, $vr0, $vr7 + vilvl.d $vr3, $vr2, $vr3 + vilvh.d $vr0, $vr2, $vr0 + vst $vr1, $a0, 0 + vst $vr4, $a0, 0x40 + vst $vr3, $a0, 0x80 + vst $vr0, $a0, 0xc0 + addi.d $a0, $a0, 0x10 + slli.d $s0, $a2, 1 + add.d $a1, $s0, $s1 + addi.d $a3, $a3, -1 + bnez $a3, .LTransposeBlockLoop + ld.d $s0, $sp, 0*8 + ld.d $s1, $sp, 1*8 + addi.d $sp, $sp, 64 + jr $ra + + .end diff --git a/onnxruntime/core/mlas/lib/loongarch64/SgemmTransposePackB16x4Lasx.S b/onnxruntime/core/mlas/lib/loongarch64/SgemmTransposePackB16x4Lasx.S new file mode 100644 index 0000000000000..e617419989c4d --- /dev/null +++ b/onnxruntime/core/mlas/lib/loongarch64/SgemmTransposePackB16x4Lasx.S @@ -0,0 +1,126 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + SgemmTransposePackB16x4Lasx.s + +Abstract: + + This module implements routines for packing buffers for the single precision + matrix/matrix multiply operation (SGEMM). + + This implementation uses Lasx instructions. + +--*/ + +#include "asmmacro.h" + + .text + +/*++ + +Macro Description: + + 4 columns of 8 rows from the source matrix are transposed to 8 columns of 4 + rows in the destination packed buffer. + +Arguments: + + StoreOffset - Supplies the relative byte offset into the destination packed + buffer. + +Implicit Arguments: + + a0 - Supplies the address of the destination packed buffer. + + a1 - Supplies the address of the source matrix. + + a2 - Supplies the number of elements per row of the source matrix. + +--*/ + + .macro TransposePackB8x4BlockLasx StoreOffset + +// +// Load 4 columns from 8 rows of the source matrix into the lower and upper +// halves of 4 XR registers. 
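[Editor's note] As a plain C++ cross-check of the packing performed by the LSX routine above and by the LASX macro that follows: a 16-row by 4-column panel of B is written transposed into the packed buffer, whose rows are 16 floats long, i.e. D[k][r] = B[r][k]. A minimal sketch assuming ldb is given in elements; the function name is illustrative.

```cpp
#include <cstddef>

// Scalar reference for the 16x4 transpose/pack used by the SGEMM B-packing:
// 4 columns of 16 rows become 16 columns of 4 rows, with packed rows of 16 floats.
void TransposePackB16x4Reference(float* D, const float* B, std::size_t ldb)
{
    for (std::size_t row = 0; row < 16; row++) {
        for (std::size_t col = 0; col < 4; col++) {
            D[col * 16 + row] = B[row * ldb + col];
        }
    }
}
```

The vector implementations realize the same mapping with 4x4 (LSX) or 8x4 (LASX) interleave steps instead of scalar stores.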
+// + + add.d $t0, $a2, $a2 + add.d $t6, $a1, $t0 + vld $vr0, $a1, 0 + vldx $vr1, $a1, $a2 + add.d $t0, $a2, $a2 + add.d $a1, $t6, $t0 + vld $vr2, $t6, 0 + vldx $vr3, $t6, $a2 + add.d $t0, $a2, $a2 + add.d $t6, $a1, $t0 + + vld $vr4, $a1, 0 + xvpermi.q $xr0, $xr4, 0x2 + vldx $vr5, $a1, $a2 + xvpermi.q $xr1, $xr5, 0x2 + vld $vr4, $t6, 0 + xvpermi.q $xr2, $xr4, 0x2 + vldx $vr5, $t6, $a2 + xvpermi.q $xr3, $xr5, 0x2 + +// +// Transpose the lower and upper halves of the 4 XR registers as two 4x4 +// matrices and store the output to the destination packed buffer. +// + + xvilvl.w $xr4, $xr1, $xr0 + xvilvh.w $xr5, $xr1, $xr0 + xvilvl.w $xr0, $xr3, $xr2 + xvilvh.w $xr1, $xr3, $xr2 + xvilvl.d $xr2, $xr0, $xr4 + xvilvh.d $xr3, $xr0, $xr4 + xvst $xr2, $a0, \StoreOffset\() + xvst $xr3, $a0, 0x40+\StoreOffset\() + xvilvl.d $xr0, $xr1, $xr5 + xvilvh.d $xr4, $xr1, $xr5 + xvst $xr0, $a0, 0x80+\StoreOffset\() + xvst $xr4, $a0, 0xc0+\StoreOffset\() + + .endm + +/*++ + +Routine Description: + + This routine transposes elements from the source matrix to the destination + packed buffer. + + 4 columns of 16 rows from the source matrix are transposed to 16 columns of 4 + rows in the destination packed buffer. + +Arguments: + + D (a0) - Supplies the address of the destination packed buffer. + + B (a1) - Supplies the address of the source matrix. + + ldb (a2) - Supplies the number of elements per row of the source matrix. + +Return Value: + + None. + +--*/ + + FUNCTION_ENTRY MlasSgemmTransposePackB16x4Lasx + + slli.d $a2, $a2, 2 # convert ldb to bytes + TransposePackB8x4BlockLasx 0*4 + add.d $t0, $a2, $a2 + add.d $a1, $t0, $t6 + TransposePackB8x4BlockLasx 8*4 + jr $ra + + .end diff --git a/onnxruntime/core/mlas/lib/loongarch64/SoftmaxKernelLasx.S b/onnxruntime/core/mlas/lib/loongarch64/SoftmaxKernelLasx.S new file mode 100644 index 0000000000000..aaaa3cbf9138d --- /dev/null +++ b/onnxruntime/core/mlas/lib/loongarch64/SoftmaxKernelLasx.S @@ -0,0 +1,357 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + SoftmaxKernelLasx.s + +Abstract: + + This module implements the kernels for the single precision softmax + operation. + + This implementation uses Lasx instructions. + +--*/ + +#include "asmmacro.h" + + .text + +/*++ + +Routine Description: + + This routine implements a vectorized kernel to find the maximum value of + the supplied buffer. + +Arguments: + + Input (a0) - Supplies the input buffer. + + N (a1) - Supplies the number of elements to process. + +Return Value: + + Returns the maximum value of the supplied buffer. 
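[Editor's note] A behavioral C++ sketch of this reduction, mirroring the 32/8/1-element loop structure of the kernel that follows; the accumulator seed mirrors the minimum-float constant the kernel loads. The function name and the plain scalar inner loops are illustrative simplifications.

```cpp
#include <algorithm>
#include <cfloat>
#include <cstddef>

// Scalar model of the blocked maximum reduction: a wide 32-element loop, an
// 8-element loop, then an element-at-a-time tail, as in the kernel below.
float ReduceMaximumReference(const float* Input, std::size_t N)
{
    float maximum = -FLT_MAX;                    // analogue of the MlasMinimumF32Value seed
    std::size_t i = 0;
    for (; i + 32 <= N; i += 32) {               // ProcessRemainingCountBy32
        for (std::size_t j = 0; j < 32; j++) {
            maximum = std::max(maximum, Input[i + j]);
        }
    }
    for (; i + 8 <= N; i += 8) {                 // ProcessRemainingCountBy8
        for (std::size_t j = 0; j < 8; j++) {
            maximum = std::max(maximum, Input[i + j]);
        }
    }
    for (; i < N; i++) {                         // ProcessRemainingCountBy1
        maximum = std::max(maximum, Input[i]);
    }
    return maximum;
}
```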
+ +--*/ + + FUNCTION_ENTRY MlasReduceMaximumF32KernelLasx + addi.d $sp, $sp, -32 + + la.global $t0, MlasMinimumF32Value + ld.w $t0, $t0, 0 + xvreplgr2vr.w $xr0, $t0 + beqz $a1, .LReduceMaximum.ExitKernel + ori $t0, $zero, 8 + bltu $a1, $t0, .LReduceMaximum.ProcessRemainingCountBy1 + ori $t1, $zero, 32 + bltu $a1, $t1, .LReduceMaximum.ProcessRemainingCountBy8 + xvreplgr2vr.w $xr16, $zero + xvor.v $xr1, $xr0, $xr16 + xvor.v $xr2, $xr0, $xr16 + xvor.v $xr3, $xr0, $xr16 + +.LReduceMaximum.ProcessRemainingCountBy32: + xvld $xr16, $a0, 0 + xvfmax.s $xr0, $xr0, $xr16 + xvld $xr16, $a0, 8*4 + xvfmax.s $xr1, $xr1, $xr16 + addi.d $a1, $a1, -0x20 + xvld $xr16, $a0, 16*4 + xvfmax.s $xr2, $xr2, $xr16 + xvld $xr16, $a0, 24*4 + xvfmax.s $xr3, $xr3, $xr16 + addi.d $a0, $a0, 32*4 # advance input by 32 elements + ori $t1, $zero, 32 + bgeu $a1, $t1, .LReduceMaximum.ProcessRemainingCountBy32 + xvfmax.s $xr0, $xr0, $xr1 + xvfmax.s $xr2, $xr2, $xr3 + xvfmax.s $xr0, $xr0, $xr2 + +.LReduceMaximum.ProcessRemainingCountBy8: + ori $t1, $zero, 8 + bltu $a1, $t1, .LReduceMaximum.ProcessRemainingCountLessThan8 + xvld $xr16, $a0, 0 + xvfmax.s $xr0, $xr0, $xr16 + addi.d $a1, $a1, -8 + addi.d $a0, $a0, 8*4 + b .LReduceMaximum.ProcessRemainingCountBy8 + +.LReduceMaximum.ProcessRemainingCountLessThan8: + xvst $xr0, $sp, 0 + vld $vr1, $sp, 0x10 + vld $vr0, $sp, 0 + vfmax.s $vr0, $vr0, $vr1 + vshuf4i.w $vr1, $vr0, 0xee + vfmax.s $vr0, $vr0, $vr1 + vshuf4i.w $vr1, $vr0, 0x55 + vfmax.s $vr0, $vr0, $vr1 + beqz $a1, .LReduceMaximum.ExitKernel + +.LReduceMaximum.ProcessRemainingCountBy1: + vld $vr16, $a0, 0 + vfmax.s $vr0, $vr0, $vr16 + addi.d $a0, $a0, 4 # advance input by 1 element + addi.d $a1, $a1, -1 + bnez $a1, .LReduceMaximum.ProcessRemainingCountBy1 + +.LReduceMaximum.ExitKernel: + xvinsgr2vr.d $xr0, $zero, 2 + xvinsgr2vr.d $xr0, $zero, 3 + xvinsgr2vr.d $xr1, $zero, 2 + xvinsgr2vr.d $xr1, $zero, 3 + xvinsgr2vr.d $xr2, $zero, 2 + xvinsgr2vr.d $xr2, $zero, 3 + xvinsgr2vr.d $xr3, $zero, 2 + xvinsgr2vr.d $xr3, $zero, 3 + xvinsgr2vr.d $xr4, $zero, 2 + xvinsgr2vr.d $xr4, $zero, 3 + xvinsgr2vr.d $xr5, $zero, 2 + xvinsgr2vr.d $xr5, $zero, 3 + xvinsgr2vr.d $xr6, $zero, 2 + xvinsgr2vr.d $xr6, $zero, 3 + xvinsgr2vr.d $xr7, $zero, 2 + xvinsgr2vr.d $xr7, $zero, 3 + xvinsgr2vr.d $xr8, $zero, 2 + xvinsgr2vr.d $xr8, $zero, 3 + xvinsgr2vr.d $xr9, $zero, 2 + xvinsgr2vr.d $xr9, $zero, 3 + xvinsgr2vr.d $xr10, $zero, 2 + xvinsgr2vr.d $xr10, $zero, 3 + xvinsgr2vr.d $xr11, $zero, 2 + xvinsgr2vr.d $xr11, $zero, 3 + xvinsgr2vr.d $xr12, $zero, 2 + xvinsgr2vr.d $xr12, $zero, 3 + xvinsgr2vr.d $xr13, $zero, 2 + xvinsgr2vr.d $xr13, $zero, 3 + xvinsgr2vr.d $xr14, $zero, 2 + xvinsgr2vr.d $xr14, $zero, 3 + xvinsgr2vr.d $xr15, $zero, 2 + xvinsgr2vr.d $xr15, $zero, 3 + addi.d $sp, $sp, 32 + jr $ra + +/*++ + +Routine Description: + + This routine implements a vectorized kernel to produce the final output for + the softmax operation. + +Arguments: + + Output (a0) - Supplies the output buffer. + + N (a1) - Supplies the number of elements to process. + + Parameters (a2) - Supplies an array containing the scale value. + +Return Value: + + None. 
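[Editor's note] Functionally, this kernel is an in-place scale of the buffer by Parameters[0]; the caller in compute.cpp passes 1.0f / Accumulation so each exp(x - max) value already in the buffer is normalized by the exponential sum. A minimal scalar equivalent (function name illustrative):

```cpp
#include <cstddef>

// Final softmax step: scale every exponential already stored in Output by the
// reciprocal of their sum so the row sums to one.
void ComputeSoftmaxOutputReference(float* Output, std::size_t N, const float* Parameters)
{
    const float scale = Parameters[0];   // 1.0f / Accumulation from the caller
    for (std::size_t i = 0; i < N; i++) {
        Output[i] *= scale;
    }
}
```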
+ +--*/ + + FUNCTION_ENTRY MlasComputeSoftmaxOutputF32KernelLasx + + ld.w $t0, $a2, 0 + xvreplgr2vr.w $xr4, $t0 + ori $t1, $zero, 0x20 + bltu $a1, $t1, .LComputeSoftmaxOutput.ProcessRemainingCountBy8 + +.LComputeSoftmaxOutput.ProcessRemainingCountBy32: + xvld $xr16, $a0, 0 + xvfmul.s $xr0, $xr4, $xr16 + xvld $xr16, $a0, 8*4 + xvfmul.s $xr1, $xr4, $xr16 + addi.d $a1, $a1, -0x20 + xvld $xr16, $a0, 16*4 + xvfmul.s $xr2, $xr4, $xr16 + xvld $xr16, $a0, 24*4 + xvfmul.s $xr3, $xr4, $xr16 + xvst $xr0, $a0, 0 + xvst $xr1, $a0, 8*4 + xvst $xr2, $a0, 16*4 + xvst $xr3, $a0, 24*4 + addi.d $a0, $a0, 0x80 # advance output by 32 elements + bgeu $a1, $t1, .LComputeSoftmaxOutput.ProcessRemainingCountBy32 + +.LComputeSoftmaxOutput.ProcessRemainingCountBy8: + ori $t2, $zero, 8 + bltu $a1, $t2, .LComputeSoftmaxOutput.ProcessRemainingCountLessThan8 + xvld $xr16, $a0, 0 + xvfmul.s $xr0, $xr4, $xr16 + addi.d $a1, $a1, -8 + xvst $xr0, $a0, 0 + addi.d $a0, $a0, 8*4 # advance output by 8 elements + b .LComputeSoftmaxOutput.ProcessRemainingCountBy8 + +.LComputeSoftmaxOutput.ProcessRemainingCountLessThan8: + beqz $a1, .LComputeSoftmaxOutput.ExitKernel + +.LComputeSoftmaxOutput.ProcessRemainingCountBy1: + fld.s $f16, $a0, 0 + fmul.s $f0, $f4, $f16 + fst.s $f0, $a0, 0 + addi.d $a0, $a0, 4 # advance output by 1 element + addi.d $a1, $a1, -1 + bnez $a1, .LComputeSoftmaxOutput.ProcessRemainingCountBy1 + +.LComputeSoftmaxOutput.ExitKernel: + xvinsgr2vr.d $xr0, $zero, 2 + xvinsgr2vr.d $xr0, $zero, 3 + xvinsgr2vr.d $xr1, $zero, 2 + xvinsgr2vr.d $xr1, $zero, 3 + xvinsgr2vr.d $xr2, $zero, 2 + xvinsgr2vr.d $xr2, $zero, 3 + xvinsgr2vr.d $xr3, $zero, 2 + xvinsgr2vr.d $xr3, $zero, 3 + xvinsgr2vr.d $xr4, $zero, 2 + xvinsgr2vr.d $xr4, $zero, 3 + xvinsgr2vr.d $xr5, $zero, 2 + xvinsgr2vr.d $xr5, $zero, 3 + xvinsgr2vr.d $xr6, $zero, 2 + xvinsgr2vr.d $xr6, $zero, 3 + xvinsgr2vr.d $xr7, $zero, 2 + xvinsgr2vr.d $xr7, $zero, 3 + xvinsgr2vr.d $xr8, $zero, 2 + xvinsgr2vr.d $xr8, $zero, 3 + xvinsgr2vr.d $xr9, $zero, 2 + xvinsgr2vr.d $xr9, $zero, 3 + xvinsgr2vr.d $xr10, $zero, 2 + xvinsgr2vr.d $xr10, $zero, 3 + xvinsgr2vr.d $xr11, $zero, 2 + xvinsgr2vr.d $xr11, $zero, 3 + xvinsgr2vr.d $xr12, $zero, 2 + xvinsgr2vr.d $xr12, $zero, 3 + xvinsgr2vr.d $xr13, $zero, 2 + xvinsgr2vr.d $xr13, $zero, 3 + xvinsgr2vr.d $xr14, $zero, 2 + xvinsgr2vr.d $xr14, $zero, 3 + xvinsgr2vr.d $xr15, $zero, 2 + xvinsgr2vr.d $xr15, $zero, 3 + jr $ra + +/*++ + +Routine Description: + + This routine implements a vectorized kernel to produce the final output for + the log softmax operation. + +Arguments: + + Input (a0) - Supplies the output buffer. + + Output (a1) - Supplies the output buffer. + + N (a2) - Supplies the number of elements to process. + + Parameters (a3) - Supplies an array containing the negative maximum and + logarithm values. + +Return Value: + + None. 
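[Editor's note] The equivalent scalar form of the log softmax output kernel described above: each output is (input + NegativeMaximum) - log(SumExp), applied as two separate operations just as the vector code notes for numeric stability. The Parameters layout matches the caller in compute.cpp; the function name is illustrative.

```cpp
#include <cstddef>

// Final log-softmax step: log_softmax(x[i]) = (x[i] - max) - log(sum(exp(x - max))).
void ComputeLogSoftmaxOutputReference(const float* Input, float* Output, std::size_t N,
                                      const float* Parameters)
{
    const float negative_maximum = Parameters[0];   // -max(x) from the reduction kernel
    const float log_sum_exp = Parameters[1];        // log of the accumulated exponential sum
    for (std::size_t i = 0; i < N; i++) {
        const float shifted = Input[i] + negative_maximum;  // first step: remove the maximum
        Output[i] = shifted - log_sum_exp;                  // second step: subtract log(SumExp)
    }
}
```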
+ +--*/ + + FUNCTION_ENTRY MlasComputeLogSoftmaxOutputF32KernelLasx + + ld.w $t0, $a3, 0 + ld.w $t1, $a3, 4 + ori $t2, $zero, 0x20 + xvreplgr2vr.w $xr4, $t0 # broadcast negative minimum value + xvreplgr2vr.w $xr5, $t1 # broadcast log(SumExp) + bltu $a2, $t2, .LComputeLogSoftmaxOutput.ProcessRemainingCountBy8 + +.LComputeLogSoftmaxOutput.ProcessRemainingCountBy32: + xvld $xr16, $a0, 0 + xvfadd.s $xr0, $xr4, $xr16 + xvld $xr16, $a0, 0x20 + xvfadd.s $xr1, $xr4, $xr16 + addi.d $a2, $a2, -0x20 + xvld $xr16, $a0, 0x40 + xvfadd.s $xr2, $xr4, $xr16 + xvld $xr16, $a0, 0x60 + xvfadd.s $xr3, $xr4, $xr16 + addi.d $a0, $a0, 0x80 # advance input by 32 elements + xvfsub.s $xr0, $xr0, $xr5 # do as two steps for numeric stability + xvfsub.s $xr1, $xr1, $xr5 # do as two steps for numeric stability + xvfsub.s $xr2, $xr2, $xr5 # do as two steps for numeric stability + xvfsub.s $xr3, $xr3, $xr5 # do as two steps for numeric stability + xvst $xr0, $a1, 0 + xvst $xr1, $a1, 0x20 + xvst $xr2, $a1, 0x40 + xvst $xr3, $a1, 0x60 + addi.d $a1, $a1, 0x80 # advance output by 32 elements + bgeu $a2, $t2, .LComputeLogSoftmaxOutput.ProcessRemainingCountBy32 + +.LComputeLogSoftmaxOutput.ProcessRemainingCountBy8: + ori $t3, $zero, 8 + bltu $a2, $t3, .LComputeLogSoftmaxOutput.ProcessRemainingCountLessThan8 + xvld $xr16, $a0, 0 + xvfadd.s $xr0, $xr4, $xr16 + addi.d $a0, $a0, 0x20 + xvfsub.s $xr0, $xr0, $xr5 + addi.d $a2, $a2, -8 + xvst $xr0, $a1, 0 + addi.d $a1, $a1, 0x20 # advance output by 8 elements + b .LComputeLogSoftmaxOutput.ProcessRemainingCountBy8 + +.LComputeLogSoftmaxOutput.ProcessRemainingCountLessThan8: + beqz $a2, .LComputeLogSoftmaxOutput.ExitKernel + +.LComputeLogSoftmaxOutput.ProcessRemainingCountBy1: + fld.s $f16, $a0, 0 + fadd.s $f0, $f4, $f16 + + addi.d $a0, $a0, 4 + fsub.s $f0, $f0, $f5 + fst.s $f0, $a1, 0 + + addi.d $a1, $a1, 4 + addi.d $a2, $a2, -1 + bnez $a2, .LComputeLogSoftmaxOutput.ProcessRemainingCountBy1 + +.LComputeLogSoftmaxOutput.ExitKernel: + xvinsgr2vr.d $xr0, $zero, 2 + xvinsgr2vr.d $xr0, $zero, 3 + xvinsgr2vr.d $xr1, $zero, 2 + xvinsgr2vr.d $xr1, $zero, 3 + xvinsgr2vr.d $xr2, $zero, 2 + xvinsgr2vr.d $xr2, $zero, 3 + xvinsgr2vr.d $xr3, $zero, 2 + xvinsgr2vr.d $xr3, $zero, 3 + xvinsgr2vr.d $xr4, $zero, 2 + xvinsgr2vr.d $xr4, $zero, 3 + xvinsgr2vr.d $xr5, $zero, 2 + xvinsgr2vr.d $xr5, $zero, 3 + xvinsgr2vr.d $xr6, $zero, 2 + xvinsgr2vr.d $xr6, $zero, 3 + xvinsgr2vr.d $xr7, $zero, 2 + xvinsgr2vr.d $xr7, $zero, 3 + xvinsgr2vr.d $xr8, $zero, 2 + xvinsgr2vr.d $xr8, $zero, 3 + xvinsgr2vr.d $xr9, $zero, 2 + xvinsgr2vr.d $xr9, $zero, 3 + xvinsgr2vr.d $xr10, $zero, 2 + xvinsgr2vr.d $xr10, $zero, 3 + xvinsgr2vr.d $xr11, $zero, 2 + xvinsgr2vr.d $xr11, $zero, 3 + xvinsgr2vr.d $xr12, $zero, 2 + xvinsgr2vr.d $xr12, $zero, 3 + xvinsgr2vr.d $xr13, $zero, 2 + xvinsgr2vr.d $xr13, $zero, 3 + xvinsgr2vr.d $xr14, $zero, 2 + xvinsgr2vr.d $xr14, $zero, 3 + xvinsgr2vr.d $xr15, $zero, 2 + xvinsgr2vr.d $xr15, $zero, 3 + jr $ra + + .end diff --git a/onnxruntime/core/mlas/lib/loongarch64/SpoolKernelLSX.S b/onnxruntime/core/mlas/lib/loongarch64/SpoolKernelLSX.S new file mode 100644 index 0000000000000..96bda3bb12c6f --- /dev/null +++ b/onnxruntime/core/mlas/lib/loongarch64/SpoolKernelLSX.S @@ -0,0 +1,460 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + SpoolKernelLSX.s + +Abstract: + + This module implements the kernels for the single precision pooling + operation. + + This implementation uses LSX instructions. 
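[Editor's note] For orientation before the macros below: each pooling kernel slides a KernelHeight x KernelWidth window over channel-blocked input and keeps, per channel of the block, either a running maximum or a running sum that is later divided by a block count or by the actual kernel size. A scalar sketch under simplifying assumptions: an 8-channel block, a contiguous window with no padding, and illustrative names.

```cpp
#include <algorithm>
#include <cstddef>

enum class PoolingKind { Maximum, AverageExcludePad, AverageIncludePad };

// One output block of channel-blocked pooling. Assumed input layout for the
// sketch: [KernelHeight][KernelWidth][BlockSize] with BlockSize == 8, no padding.
void PoolOneBlockReference(const float* Input, float* Output,
                           std::size_t KernelHeight, std::size_t KernelWidth,
                           std::size_t ActualKernelSize, PoolingKind Kind)
{
    constexpr std::size_t BlockSize = 8;
    float acc[BlockSize];
    for (std::size_t c = 0; c < BlockSize; c++) {
        acc[c] = (Kind == PoolingKind::Maximum) ? -3.402823466e+38f : 0.0f;
    }
    for (std::size_t kh = 0; kh < KernelHeight; kh++) {
        for (std::size_t kw = 0; kw < KernelWidth; kw++) {
            const float* block = Input + (kh * KernelWidth + kw) * BlockSize;
            for (std::size_t c = 0; c < BlockSize; c++) {
                acc[c] = (Kind == PoolingKind::Maximum) ? std::max(acc[c], block[c])
                                                        : acc[c] + block[c];
            }
        }
    }
    for (std::size_t c = 0; c < BlockSize; c++) {
        if (Kind == PoolingKind::AverageExcludePad) {
            acc[c] /= static_cast<float>(KernelHeight * KernelWidth);  // equals the valid count when unpadded
        } else if (Kind == PoolingKind::AverageIncludePad) {
            acc[c] /= static_cast<float>(ActualKernelSize);
        }
        Output[c] = acc[c];
    }
}
```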
+ +--*/ + +#define SP_SIZE 32*8 +#define InputBase_arg SP_SIZE+0*8 +#define InputWidth_arg SP_SIZE+1*8 +#define DilatedInputWidth_arg SP_SIZE+2*8 +#define OutputCountLeftPad_arg SP_SIZE+3*8 +#define OutputCount_arg SP_SIZE+4*8 +#define OutputCountRightPad_arg SP_SIZE+5*8 + + .macro FUNCTION_ENTRY FunctionName + + .p2align 4 + .globl \FunctionName\() + .type \FunctionName\(),@function +\FunctionName\(): + + .endm + + + .text + +/*++ + +Macro Description: + + This macro generates code to initialize registers used across the kernel. + +Arguments: + + PoolingType - Supplies the pooling type string. + +--*/ + + .macro InitializeKernel PoolingType + +.ifeqs "\PoolingType\()","Maximum" + li.w $s0, 0xFF7FFFFF + vreplgr2vr.w $vr5, $s0 +.endif + +.ifeqs "\PoolingType\()","AverageIncludePad" + vreplgr2vr.w $vr5, $a5 + vffint.s.w $vr5, $vr5 +.endif + + .endm +/*++ + +Macro Description: + + This macro generates the common prologue code for the pooling kernels. + +Arguments: + + PoolingType - Supplies the pooling type string. + +--*/ + + .macro SpoolKernelEntry PoolingType + + addi.d $sp, $sp, -SP_SIZE + st.d $s0, $sp, 0*8 + st.d $s1, $sp, 1*8 + st.d $s2, $sp, 2*8 + st.d $s3, $sp, 3*8 + st.d $s4, $sp, 4*8 + st.d $ra, $sp, 5*8 + fst.d $f24,$sp, 6*8 + + InitializeKernel \PoolingType\() + # move InputStride to s8 + or $t8, $a4, $r0 + # move StrideWidth to a4 + or $a4, $a2, $r0 + # move DilationWidth to a5 + or $a5, $a3, $r0 + # move Output to a2 + or $a2, $a1, $r0 + + .endm + +/*++ + +Macro Description: + + This macro generates the common epilogue code for the pooling kernels. + +Arguments: + + None. + +--*/ + + .macro SpoolKernelExit + + ld.d $s0, $sp, 0*8 + ld.d $s1, $sp, 1*8 + ld.d $s2, $sp, 2*8 + ld.d $s3, $sp, 3*8 + ld.d $s4, $sp, 4*8 + ld.d $ra, $sp, 5*8 + fld.d $f24,$sp, 6*8 + + addi.d $sp, $sp, SP_SIZE + jr $ra + + .endm + + +/*++ + +Macro Description: + + This macro generates code to clear the pooling intermediates. + + For PoolingType==Maximum, the pooling intermediates are set to the minimum + float value. Otherwise, the pooling intermediates are cleared to zero. + +Arguments: + + PoolingType - Supplies the pooling type string. + + OutputCount - Supplies the number of output blocks to produce. + +Implicit Arguments: + + a1 - Supplies the number of blocks accessed by ComputeBlock, if + PoolingType=AverageExcludePad and OutputCount=1. + + vr0-vr1 - Supplies the pooling intermediates. + + vr2 - Supplies a vector containing the minimum float value broadcasted, + if PoolingType==Maximum. + +--*/ + + .macro ClearBlock PoolingType, OutputCount + +.ifeqs "\PoolingType\()","Maximum" + vor.v $vr0, $vr5, $vr5 + vor.v $vr1, $vr5, $vr5 +.else + vxor.v $vr0, $vr0, $vr0 + vxor.v $vr1, $vr1, $vr1 +.endif + +.ifeqs "\PoolingType\()","AverageExcludePad" + xor $a1, $a1, $a1 # reset valid block counter +.endif + + .endm + +/*++ + +Macro Description: + + This macro generates code to sample the input buffer and update the pooling + intermediates as appropriate. + +Arguments: + + PoolingType - Supplies the pooling type string. + + OutputCount - Supplies the number of output blocks to produce. + +Implicit Arguments: + + a3 - Supplies the address of the input buffer. + + a1 - Supplies the number of blocks accessed by ComputeBlock, if + PoolingType=AverageExcludePad and OutputCount=1. + + a4 - Supplies the StrideWidth parameter (see function description). + + vr0-vr1 - Supplies the pooling intermediates. 
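[Editor's note] For AverageExcludePad in particular, the valid-block counter mentioned above is incremented once per window position that ComputeBlock actually samples (padding positions are skipped by the caller), and PostProcessBlock divides by that counter rather than by the nominal kernel size. A hedged standalone illustration of that bookkeeping, where the boolean mask stands in for the bounds check done in the outer loop:

```cpp
#include <cstddef>

// AverageExcludePad bookkeeping: average only over the window positions that
// landed inside the image; `inside` models the caller's padding check.
float AverageExcludePadReference(const float* values, const bool* inside, std::size_t count)
{
    float sum = 0.0f;
    std::size_t valid = 0;            // role of the a1 valid-block counter
    for (std::size_t i = 0; i < count; i++) {
        if (inside[i]) {              // ComputeBlock only runs for in-bounds positions
            sum += values[i];
            valid++;
        }
    }
    return (valid != 0) ? sum / static_cast<float>(valid) : 0.0f;
}
```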
+ +--*/ + + .macro ComputeBlock PoolingType, OutputCount + +.ifeqs "\PoolingType\()","Maximum" + vld $vr24, $a3, 0 + vfmax.s $vr0, $vr0, $vr24 + vld $vr24, $a3, 16 + vfmax.s $vr1, $vr1, $vr24 +.else + vld $vr24, $a3, 0 + vfadd.s $vr0, $vr0, $vr24 + vld $vr24, $a3, 16 + vfadd.s $vr1, $vr1, $vr24 +.endif + +.ifeqs "\PoolingType\()","AverageExcludePad" + # increment valid block counter + addi.d $a1, $a1, 1 +.endif + + .endm + +/*++ + +Macro Description: + + This macro generates code to process and store the pooling intermediates. + +Arguments: + + PoolingType - Supplies the pooling type string. + + OutputCount - Supplies the number of output blocks to produce. + +Implicit Arguments: + + a2 - Supplies the address of the output buffer. + + a1 - Supplies the number of blocks accessed by ComputeBlock, if + PoolingType=AverageExcludePad and OutputCount=1. + + vr0-vr1 - Supplies the pooling intermediates. + + vr5 - Supplies the kernel size computed by InitializeKernel, if + PoolingType=AverageExcludePad, else the actual kernel size, if + PoolingType=AverageIncludePad. + +--*/ + + .macro PostProcessBlock PoolingType, OutputCount + +// +// If PoolingType=AverageExcludePad, divide the sum by the number of non-padding +// blocks. +// + +.ifeqs "\PoolingType\()","AverageExcludePad" + # convert valid block counter + vreplgr2vr.w $vr4, $a1 + vffint.s.w $vr4, $vr4 + vfdiv.s $vr0, $vr0, $vr4 + vfdiv.s $vr1, $vr1, $vr4 +.endif + +// +// If PoolingType=AverageIncludePad, divide the sum by the actual kernel size. +// + +.ifeqs "\PoolingType\()","AverageIncludePad" + vfdiv.s $vr0, $vr0, $vr5 + vfdiv.s $vr1, $vr1, $vr5 +.endif + +// +// Store the output block in the output buffer. +// + + vst $vr0, $a2, 0 + vst $vr1, $a2, 16 + # advance output by 1 nchw8c block + addi.d $a2, $a2, 8*4 + + .endm + +/*++ + +Macro Description: + + This macro generates code to compute pooling for a vector of input blocks + to produce a matrix of output blocks. + + OutputCount=1 generates special case code to handle padding blocks. All + other output counts assume no padding. + +Arguments: + + KernelFrame - Supplies the symbol name to access the convolution kernel + stack. + + OutputCount - Supplies the number of output blocks to produce. + +Implicit Arguments: + + a0 - Supplies the address of the input buffer. + + a2 - Supplies the address of the output buffer. + + a4 - Supplies the StrideWidth parameter (see function description). + + a5 - Supplies the DilationWidth parameter (see function description). + + s8 - Supplies the InputStride parameter (see function description). + +--*/ + + .macro ProcessOutputCountN KernelFrame, PoolingType, OutputCount + + move $a3, $a0 + move $t1, $a6 + move $t2, $a7 +.if \OutputCount\() == 1 + ld.d $t3, $sp, InputBase_arg + ld.d $t4, $sp, InputWidth_arg + sub.d $t3, $r0, $t3 # keep negative for lea usage below +.endif + ClearBlock \PoolingType\(), \OutputCount\() + beqz $t1, .L\PoolingType\().\OutputCount\().HandlePostProcessing + +.L\PoolingType\().\OutputCount\().ProcessNextRow: + or $t6, $t2, $t2 + +.L\PoolingType\().\OutputCount\().ProcessNextColumn: +.if \OutputCount\() == 1 + # (Input - InputBase) >= InputWidth? 
+ add.d $t7, $a3, $t3 + bgeu $t7, $t4, .L\PoolingType\().\OutputCount\().SkipOverPadding +.endif + ComputeBlock \PoolingType\(), \OutputCount\() + +.L\PoolingType\().\OutputCount\().SkipOverPadding: + add.d $a3, $a3, $a5 # advance input by dilation width + # decrement columns remaining + addi.d $t6, $t6, -1 + bnez $t6, .L\PoolingType\().\OutputCount\().ProcessNextColumn + add.d $a3, $a3, $t8 # advance input to next row +.if \OutputCount\() == 1 + ld.d $s0, $sp, DilatedInputWidth_arg + # advance input base to next row + sub.d $t3, $t3, $s0 +.endif + addi.d $t1, $t1, -1 + bnez $t1, .L\PoolingType\().\OutputCount\().ProcessNextRow + +.L\PoolingType\().\OutputCount\().HandlePostProcessing: + PostProcessBlock \PoolingType\(), \OutputCount\() + + .endm +/*++ + +Macro Description: + + This macro generates code for the inner pooling kernel. + +Arguments: + + PoolingType - Supplies the pooling type string. + + Isa - Supplies the instruction set architecture string for function tags. + +--*/ + + .macro SpoolKernelFunction PoolingType, Isa + +/*++ + +Routine Description: + + This routine is the inner kernel to compute pooling for the elements of an + output row for a set of filter rows. + +Arguments: + + Input (a0) - Supplies the address of the input buffer. + + The address is biased to include padding blocks for the left width + dimension. The address is not biased to include padding rows for the + left height dimension these are accounted for in the outer kernel. + + Output (a1) - Supplies the address of the output buffer. + + StrideWidth (a2) - Supplies the length in bytes of the blocked stride width. + + DilationWidth (a3) - Supplies the length in bytes of the blocked dilation + width. + + InputStride (a4) - Supplies the length in bytes to advance the input buffer to + the next input row. + + ActualKernelSize (a5) - Supplies the size of the kernel based on the original + kernel dimensions, used for PoolingType=AverageIncludePad. + + KernelHeight (a6) - Supplies the height of the kernel to apply. This height may + be less than the original kernel height after removing any padding + rows. + + KernelWidth (a7) - Supplies the width of the kernel to apply. + + InputBase (0)- Supplies the address of the valid input buffer. + + This parameter is similar to the Input parameter, but does not include + the padding blocks for the left width dimension. This parameter is used + with the following InputWidth parameter in order to validate that the + current input buffer address in bounds and not in the left or right + width padding region. + + InputWidth (1*8)- Supplies the length in bytes of the blocked input width. + + DilatedInputWidth (2*8)- Supplies the length in bytes to advance the input base + buffer to the next input row including dilation. + + OutputCountLeftPad (3*8)- Supplies the number of output elements that include + one or more padding elements from the left edge. + + OutputCount (4*8)- Supplies the number of output elements that do not include + any padding elements. + + OutputCountRightPad (5*8)- Supplies the number of output elements that include + one or more padding elements from the right edge. + +Return Value: + + None. 
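[Editor's note] The InputBase/InputWidth pair in the argument list above allows a single unsigned comparison to reject both padding edges: the kernel keeps the negated base and tests whether (Input - InputBase) is below InputWidth, so a tap left of the row underflows and fails the same test as a tap right of it. A small C++ sketch of that check, in byte offsets as the kernel uses them; the names are illustrative.

```cpp
#include <cstdint>

// True when the current input pointer lies inside the valid (unpadded) part of
// the row. Left padding makes the unsigned difference wrap to a huge value, so
// one `<` comparison rejects both the left and right padding regions.
inline bool InputTapIsValid(const std::uint8_t* input, const std::uint8_t* input_base,
                            std::uintptr_t input_width_bytes)
{
    const std::uintptr_t offset = reinterpret_cast<std::uintptr_t>(input) -
                                  reinterpret_cast<std::uintptr_t>(input_base);
    return offset < input_width_bytes;
}
```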
+ +--*/ + + FUNCTION_ENTRY MlasPool\PoolingType\()FloatKernel\Isa\() + SpoolKernelEntry \PoolingType\() + + ld.d $s0, $sp, OutputCountLeftPad_arg + ld.d $s1, $sp, OutputCount_arg + add.d $t0, $s0, $s1 + ld.d $s0, $sp, OutputCountRightPad_arg + add.d $t0, $t0, $s0 + beqz $t0, .L\PoolingType\().ExitKernel + +.L\PoolingType\().ProcessNextOutputCount: + ProcessOutputCountN .LSpoolKernelFrame, \PoolingType\(), 1 + add.d $a0, $a0, $a4 + addi.d $t0, $t0, -1 + bnez $t0, .L\PoolingType\().ProcessNextOutputCount + +.L\PoolingType\().ExitKernel: + SpoolKernelExit + + .endm + +// +// Generate the pooling kernels. +// + + SpoolKernelFunction Maximum, LSX + SpoolKernelFunction AverageExcludePad, LSX + SpoolKernelFunction AverageIncludePad, LSX + + .end diff --git a/onnxruntime/core/mlas/lib/loongarch64/SpoolKernelLasx.S b/onnxruntime/core/mlas/lib/loongarch64/SpoolKernelLasx.S new file mode 100644 index 0000000000000..6e5f0136cd4ab --- /dev/null +++ b/onnxruntime/core/mlas/lib/loongarch64/SpoolKernelLasx.S @@ -0,0 +1,238 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + SpoolKernelLasx.s + +Abstract: + + This module implements the kernels for the single precision pooling + operation. + + This implementation uses Lasx instructions. + +--*/ + +#include "asmmacro.h" +#include "SpoolKernelLasxCommon.h" + + .text + +/*++ + +Macro Description: + + This macro generates code to initialize registers used across the kernel. + +Arguments: + + PoolingType - Supplies the pooling type string. + +Implicit Arguments: + + a5 - Supplies the ActualKernelSize parameter (see function description). + +--*/ + + .macro InitializeKernel PoolingType + +.ifeqs "\PoolingType\()","Maximum" + li.w $s0, 0xFF7FFFFF + xvreplgr2vr.w $xr5, $s0 +.else + xvxor.v $xr5, $xr5, $xr5 +.ifeqs "\PoolingType\()","AverageExcludePad" + move $t6, $a6 + mul.d $t6, $t6, $a7 + xvreplgr2vr.w $xr5, $t6 +.else + xvreplgr2vr.w $xr5, $a5 +.endif + xvffint.s.w $xr5, $xr5 +.endif + + .endm + +/*++ + +Macro Description: + + This macro generates code to clear the pooling intermediates. + + For PoolingType==Maximum, the pooling intermediates are set to the minimum + float value. Otherwise, the pooling intermediates are cleared to zero. + +Arguments: + + PoolingType - Supplies the pooling type string. + + OutputCount - Supplies the number of output blocks to produce. + +Implicit Arguments: + + a1 - Supplies the number of blocks accessed by ComputeBlock, if + PoolingType=AverageExcludePad and OutputCount=1. + + xr0-xr2 - Supplies the pooling intermediates. + + xr5 - Supplies a vector containing the minimum float value broadcasted, + if PoolingType==Maximum. + +--*/ + + .macro ClearBlock PoolingType, OutputCount + +.ifeqs "\PoolingType\()","Maximum" + EmitIfCountGE \OutputCount\(), 1, "xvor.v $xr0, $xr5, $xr5" + EmitIfCountGE \OutputCount\(), 2, "xvor.v $xr1, $xr5, $xr5" + EmitIfCountGE \OutputCount\(), 3, "xvor.v $xr2, $xr5, $xr5" +.else + EmitIfCountGE \OutputCount\(), 1, "xvxor.v $xr0, $xr0, $xr0" + EmitIfCountGE \OutputCount\(), 2, "xvxor.v $xr1, $xr1, $xr1" + EmitIfCountGE \OutputCount\(), 3, "xvxor.v $xr2, $xr2, $xr2" +.endif + +.ifeqs "\PoolingType\()","AverageExcludePad" +.if \OutputCount\() == 1 + xor $a1, $a1, $a1 # reset valid block counter +.endif +.endif + + .endm + +/*++ + +Macro Description: + + This macro generates code to sample the input buffer and update the pooling + intermediates as appropriate. 
+ +Arguments: + + PoolingType - Supplies the pooling type string. + + OutputCount - Supplies the number of output blocks to produce. + +Implicit Arguments: + + a3 - Supplies the address of the input buffer. + + a1 - Supplies the number of blocks accessed by ComputeBlock, if + PoolingType=AverageExcludePad and OutputCount=1. + + a4 - Supplies the StrideWidth parameter (see function description). + + xr0-xr2 - Supplies the pooling intermediates. + +--*/ + + .macro ComputeBlock PoolingType, OutputCount + +.ifeqs "\PoolingType\()","Maximum" + EmitIfCountGE \OutputCount\(), 1, "xvld $xr16, $a3, 0" + EmitIfCountGE \OutputCount\(), 1, "xvfmax.s $xr0, $xr0, $xr16" + EmitIfCountGE \OutputCount\(), 2, "xvldx $xr16, $a3, $a4" + EmitIfCountGE \OutputCount\(), 2, "xvfmax.s $xr1, $xr1, $xr16" + EmitIfCountGE \OutputCount\(), 3, "slli.d $s0, $a4, 1" + EmitIfCountGE \OutputCount\(), 3, "xvldx $xr16, $a3, $s0" + EmitIfCountGE \OutputCount\(), 3, "xvfmax.s $xr2, $xr2, $xr16" +.else + EmitIfCountGE \OutputCount\(), 1, "xvld $xr16, $a3, 0" + EmitIfCountGE \OutputCount\(), 1, "xvfadd.s $xr0, $xr0, $xr16" + EmitIfCountGE \OutputCount\(), 2, "xvldx $xr16, $a3, $a4" + EmitIfCountGE \OutputCount\(), 2, "xvfadd.s $xr1, $xr1, $xr16" + EmitIfCountGE \OutputCount\(), 3, "slli.d $s0, $a4, 1" + EmitIfCountGE \OutputCount\(), 3, "xvldx $xr16, $a3, $s0" + EmitIfCountGE \OutputCount\(), 3, "xvfadd.s $xr2, $xr2, $xr16" +.endif + +.ifeqs "\PoolingType\()","AverageExcludePad" +.if \OutputCount\() == 1 + addi.d $a1, $a1, 1 # increment valid block counter +.endif +.endif + + .endm + +/*++ + +Macro Description: + + This macro generates code to process and store the pooling intermediates. + +Arguments: + + PoolingType - Supplies the pooling type string. + + OutputCount - Supplies the number of output blocks to produce. + +Implicit Arguments: + + a2 - Supplies the address of the output buffer. + + a1 - Supplies the number of blocks accessed by ComputeBlock, if + PoolingType=AverageExcludePad and OutputCount=1. + + xr0-xr2 - Supplies the pooling intermediates. + + xr5 - Supplies the kernel size computed by InitializeKernel, if + PoolingType=AverageExcludePad, else the actual kernel size, if + PoolingType=AverageIncludePad. + +--*/ + + .macro PostProcessBlock PoolingType, OutputCount + +// +// If PoolingType=AverageExcludePad, divide the sum by the number of non-padding +// blocks. OutputCount=1 generates code to count the number of blocks accessed by +// ComputeBlock. Other cases use the kernel size computed by InitializeKernel. +// + +.ifeqs "\PoolingType\()","AverageExcludePad" +.if \OutputCount\() == 1 + xvxor.v $xr4, $xr4, $xr4 + xvreplgr2vr.w $xr4, $a1 + xvffint.s.w $xr4, $xr4 + xvfdiv.s $xr0, $xr0, $xr4 +.else + EmitIfCountGE \OutputCount\(), 1, "xvfdiv.s $xr0, $xr0, $xr5" + EmitIfCountGE \OutputCount\(), 2, "xvfdiv.s $xr1, $xr1, $xr5" + EmitIfCountGE \OutputCount\(), 3, "xvfdiv.s $xr2, $xr2, $xr5" +.endif +.endif + +// +// If PoolingType=AverageIncludePad, divide the sum by the actual kernel size. +// + +.ifeqs "\PoolingType\()","AverageIncludePad" + EmitIfCountGE \OutputCount\(), 1, "xvfdiv.s $xr0, $xr0, $xr5" + EmitIfCountGE \OutputCount\(), 2, "xvfdiv.s $xr1, $xr1, $xr5" + EmitIfCountGE \OutputCount\(), 3, "xvfdiv.s $xr2, $xr2, $xr5" +.endif + +// +// Store the output block in the output buffer. 
+// + + EmitIfCountGE \OutputCount\(), 1, "xvst $xr0, $a2, 0" + EmitIfCountGE \OutputCount\(), 2, "xvst $xr1, $a2, 0x20" + EmitIfCountGE \OutputCount\(), 3, "xvst $xr2, $a2, 0x40" + add_immed $a2,\OutputCount\()*8*4 # advance output by N nchw8c blocks + + .endm + +// +// Generate the pooling kernels. +// + + SpoolKernelFunction Maximum, Lasx + SpoolKernelFunction AverageExcludePad, Lasx + SpoolKernelFunction AverageIncludePad, Lasx + + .end diff --git a/onnxruntime/core/mlas/lib/loongarch64/SpoolKernelLasxCommon.h b/onnxruntime/core/mlas/lib/loongarch64/SpoolKernelLasxCommon.h new file mode 100644 index 0000000000000..066c75d34f3f9 --- /dev/null +++ b/onnxruntime/core/mlas/lib/loongarch64/SpoolKernelLasxCommon.h @@ -0,0 +1,311 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + SpoolKernelasxCommon.h + +Abstract: + + This module contains common kernel macros and structures for the single + precision pooling operation for the Lasx kernels. + +--*/ + +// +// Stack frame layout for the pooling kernels. +// + +#define SP_SIZE 8*8 +#define InputBase_arg SP_SIZE+0*8 +#define InputWidth_arg SP_SIZE+1*8 +#define DilatedInputWidth_arg SP_SIZE+2*8 +#define OutputCountLeftPad_arg SP_SIZE+3*8 +#define OutputCount_arg SP_SIZE+4*8 +#define OutputCountRightPad_arg SP_SIZE+5*8 +/*++ + +Macro Description: + + This macro generates the common prologue code for the pooling kernels. + +Arguments: + + PoolingType - Supplies the pooling type string. + +--*/ + + .macro SpoolKernelEntry PoolingType + + addi.d $sp, $sp, -SP_SIZE + st.d $s0, $sp, 0 + st.d $s1, $sp, 1*8 + fst.d $f16, $sp, 2*8 + st.d $ra, $sp, 5*8 + + InitializeKernel \PoolingType\() + move $t8, $a4 + move $a4, $a2 + move $a5, $a3 + move $a2, $a1 + + .endm + +/*++ + +Macro Description: + + This macro generates the common epilogue code for the pooling kernels. + +Arguments: + + None. + +--*/ + + .macro SpoolKernelExit + + ld.d $s0, $sp, 0 + ld.d $s1, $sp, 1*8 + fld.d $f16, $sp, 2*8 + ld.d $ra, $sp, 5*8 + addi.d $sp, $sp, SP_SIZE + jr $ra + + .endm + +/*++ + +Macro Description: + + This macro generates code to compute pooling for a vector of input blocks + to produce a matrix of output blocks. + + OutputCount=1 generates special case code to handle padding blocks. All + other output counts assume no padding. + +Arguments: + + KernelFrame - Supplies the symbol name to access the convolution kernel + stack. + + OutputCount - Supplies the number of output blocks to produce. + +Implicit Arguments: + + a0 - Supplies the address of the input buffer. + + a2 - Supplies the address of the output buffer. + + a4 - Supplies the StrideWidth parameter (see function description). + + a5 - Supplies the DilationWidth parameter (see function description). + + t8 - Supplies the InputStride parameter (see function description). + +--*/ + + .macro ProcessOutputCountN KernelFrame, PoolingType, OutputCount + + move $a3, $a0 + move $t1, $a6 + move $t2, $a7 +.if \OutputCount\() == 1 + ld.d $t3, $sp, InputBase_arg + ld.d $t4, $sp, InputWidth_arg + sub.d $t3, $zero, $t3 +.endif + ClearBlock \PoolingType\(), \OutputCount\() + beqz $t1, .L\PoolingType\().\OutputCount\().HandlePostProcessing + +.L\PoolingType\().\OutputCount\().ProcessNextRow: + move $t6, $t2 + +.L\PoolingType\().\OutputCount\().ProcessNextColumn: +.if \OutputCount\() == 1 + add.d $t7, $a3, $t3 # compute (Input - InputBase) + # (Input - InputBase) >= InputWidth? 
+ bgeu $t7, $t4, .L\PoolingType\().\OutputCount\().SkipOverPadding +.endif + ComputeBlock \PoolingType\(), \OutputCount\() + +.L\PoolingType\().\OutputCount\().SkipOverPadding: + add.d $a3, $a3, $a5 # advance input by dilation width + addi.d $t6, $t6, -1 # decrement columns remaining + bnez $t6, .L\PoolingType\().\OutputCount\().ProcessNextColumn + add.d $a3, $a3, $t8 # advance input to next row +.if \OutputCount\() == 1 + ld.d $s0, $sp, DilatedInputWidth_arg + sub.d $t3, $t3, $s0 + # advance input base to next row +.endif + addi.d $t1, $t1, -1 + bnez $t1, .L\PoolingType\().\OutputCount\().ProcessNextRow + +.L\PoolingType\().\OutputCount\().HandlePostProcessing: + PostProcessBlock \PoolingType\(), \OutputCount\() + + .endm +/*++ + +Macro Description: + + This macro generates code for the inner pooling kernel. + +Arguments: + + PoolingType - Supplies the pooling type string. + + Isa - Supplies the instruction set architecture string for function tags. + +--*/ + + .macro SpoolKernelFunction PoolingType, Isa + +/*++ + +Routine Description: + + This routine is the inner kernel to compute pooling for the elements of an + output row for a set of filter rows. + +Arguments: + + Input (a0) - Supplies the address of the input buffer. + + The address is biased to include padding blocks for the left width + dimension. The address is not biased to include padding rows for the + left height dimension these are accounted for in the outer kernel. + + Output (a1) - Supplies the address of the output buffer. + + StrideWidth (a2) - Supplies the length in bytes of the blocked stride width. + + DilationWidth (a3) - Supplies the length in bytes of the blocked dilation + width. + + InputStride (a4) - Supplies the length in bytes to advance the input buffer to + the next input row. + + ActualKernelSize (a5) - Supplies the size of the kernel based on the original + kernel dimensions, used for PoolingType=AverageIncludePad. + + KernelHeight (a6) - Supplies the height of the kernel to apply. This height may + be less than the original kernel height after removing any padding + rows. + + KernelWidth (a7)- Supplies the width of the kernel to apply. + + InputBase (sp + 0)- Supplies the address of the valid input buffer. + + This parameter is similar to the Input parameter, but does not include + the padding blocks for the left width dimension. This parameter is used + with the following InputWidth parameter in order to validate that the + current input buffer address in bounds and not in the left or right + width padding region. + + InputWidth (sp + 0x8)- Supplies the length in bytes of the blocked input width. + + DilatedInputWidth (sp + 0x10)- Supplies the length in bytes to advance the input base + buffer to the next input row including dilation. + + OutputCountLeftPad (sp + 0x18)- Supplies the number of output elements that include + one or more padding elements from the left edge. + + OutputCount (sp + 0x20)- Supplies the number of output elements that do not include + any padding elements. + + OutputCountRightPad (sp + 0x28)- Supplies the number of output elements that include + one or more padding elements from the right edge. + +Return Value: + + None. 
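[Editor's note] The kernel body that follows turns the counts above into three phases: left-padded outputs handled one at a time through the bounds-checked path, the interior unrolled three outputs at a time, and the remainder plus right-padded outputs again one at a time. A schematic C++ driver under those assumptions, where process_n stands in for ProcessOutputCountN:

```cpp
#include <cstddef>
#include <functional>

// Schematic of the Lasx pooling driver: padded edges go through the 1-output
// path, the interior is unrolled three outputs at a time.
void PoolDriverSketch(std::size_t left_pad, std::size_t interior, std::size_t right_pad,
                      const std::function<void(std::size_t)>& process_n)
{
    for (std::size_t i = 0; i < left_pad; i++) {
        process_n(1);                       // bounds-checked single-output path
    }
    std::size_t remaining = interior;
    while (remaining >= 3) {
        process_n(3);                       // ProcessNextOutputCountBy3
        remaining -= 3;
    }
    for (std::size_t i = 0; i < remaining + right_pad; i++) {
        process_n(1);                       // remainder and right padding
    }
}
```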
+ +--*/ + + FUNCTION_ENTRY MlasPool\PoolingType\()FloatKernel\Isa\() + + SpoolKernelEntry \PoolingType\() + +.L\PoolingType\().ProcessOutputCountLeftPad: + ld.d $t0, $sp, OutputCountLeftPad_arg + + beqz $t0, .L\PoolingType\().ProcessOutputCount + bl MlasPool\PoolingType\()FloatSingle\Isa\() + +.L\PoolingType\().ProcessOutputCount: + ld.d $t0, $sp, OutputCount_arg + li.d $s0, 3 + bltu $t0, $s0, .L\PoolingType\().ProcessRemainingOutputCount + +.L\PoolingType\().ProcessNextOutputCountBy3: + ProcessOutputCountN .LSpoolKernelFrame, \PoolingType\(), 3 + slli.d $s0, $a4, 1 + add.d $t6, $s0, $a4 + add.d $a0, $a0, $t6 # advance input by 3 elements + addi.d $t0, $t0, -3 + li.d $s0, 3 + bgeu $t0, $s0, .L\PoolingType\().ProcessNextOutputCountBy3 + +.L\PoolingType\().ProcessRemainingOutputCount: + +.L\PoolingType\().ProcessOutputCountRightPad: + ld.d $s0, $sp, OutputCountRightPad_arg + add.d $t0, $t0, $s0 + beqz $t0, .L\PoolingType\().ExitKernel + bl MlasPool\PoolingType\()FloatSingle\Isa\() + +.L\PoolingType\().ExitKernel: + xvinsgr2vr.d $xr0, $zero, 2 + xvinsgr2vr.d $xr0, $zero, 3 + xvinsgr2vr.d $xr1, $zero, 2 + xvinsgr2vr.d $xr1, $zero, 3 + xvinsgr2vr.d $xr2, $zero, 2 + xvinsgr2vr.d $xr2, $zero, 3 + xvinsgr2vr.d $xr3, $zero, 2 + xvinsgr2vr.d $xr3, $zero, 3 + xvinsgr2vr.d $xr4, $zero, 2 + xvinsgr2vr.d $xr4, $zero, 3 + xvinsgr2vr.d $xr5, $zero, 2 + xvinsgr2vr.d $xr5, $zero, 3 + xvinsgr2vr.d $xr6, $zero, 2 + xvinsgr2vr.d $xr6, $zero, 3 + xvinsgr2vr.d $xr7, $zero, 2 + xvinsgr2vr.d $xr7, $zero, 3 + xvinsgr2vr.d $xr8, $zero, 2 + xvinsgr2vr.d $xr8, $zero, 3 + xvinsgr2vr.d $xr9, $zero, 2 + xvinsgr2vr.d $xr9, $zero, 3 + xvinsgr2vr.d $xr10, $zero, 2 + xvinsgr2vr.d $xr10, $zero, 3 + xvinsgr2vr.d $xr11, $zero, 2 + xvinsgr2vr.d $xr11, $zero, 3 + xvinsgr2vr.d $xr12, $zero, 2 + xvinsgr2vr.d $xr12, $zero, 3 + xvinsgr2vr.d $xr13, $zero, 2 + xvinsgr2vr.d $xr13, $zero, 3 + xvinsgr2vr.d $xr14, $zero, 2 + xvinsgr2vr.d $xr14, $zero, 3 + xvinsgr2vr.d $xr15, $zero, 2 + xvinsgr2vr.d $xr15, $zero, 3 + SpoolKernelExit + +// +// Generate out-of-band helpers for handling output blocks involving padding. +// + +MlasPool\PoolingType\()FloatSingle\Isa\(): + st.d $ra, $sp, 6*8 +loopMlasPool\PoolingType\()FloatSingle\Isa\(): + ProcessOutputCountN .LSpoolKernelSingleFrame, \PoolingType\(), 1 + add.d $a0, $a0, $a4 # advance input by 1 element + addi.d $t0, $t0, -1 # decrement output count remaining + bnez $t0, loopMlasPool\PoolingType\()FloatSingle\Isa\() + ld.d $ra, $sp, 6*8 + jr $ra + + .endm diff --git a/onnxruntime/core/mlas/lib/loongarch64/asmmacro.h b/onnxruntime/core/mlas/lib/loongarch64/asmmacro.h new file mode 100644 index 0000000000000..837aca77dd883 --- /dev/null +++ b/onnxruntime/core/mlas/lib/loongarch64/asmmacro.h @@ -0,0 +1,144 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + asmmacro.h + +Abstract: + + This module implements common macros for the assembly modules. + +--*/ + +#define C_UNDERSCORE(symbol) symbol + +.macro vmove dst src + vand.v \dst, \src, \src +.endm + +/*++ + +Macro Description: + + This macro emits the assembler directives to annotate a new function. + +Arguments: + + FunctionName - Supplies the name of the function. 
+ +--*/ + + .macro FUNCTION_ENTRY FunctionName + .align 2 + .globl \FunctionName\() + .type \FunctionName\(),@function +\FunctionName\(): + + .endm + +/*++ + +Macro Description: + + This macro generates an optimization for "add reg,128" which can instead + be encoded as "sub reg,-128" to reduce code size by using a signed 8-bit + value. + +Arguments: + + Register - Supplies the register to be added to. + + Immediate - Supplies the immediate to add to the register. + +--*/ + + .macro add_immed Register, Immediate + +.if (\Immediate\() != 128) + addi.d \Register\(),\Register\(),\Immediate\() +.else + addi.d \Register\(),\Register\(),\Immediate\() # smaller encoding +.endif + + .endm + +/*++ + +Macro Description: + + This macro conditionally emits the statement if Count is greater than or + equal to Value. + +Arguments: + + Count - Supplies the variable used in the comparison. + + Value - Supplies the static used in the comparison. + + Statement - Supplies the statement to conditionally emit. + +--*/ + + .macro EmitIfCountGE Count1, Value1, Statement + +.if (\Count1\() >= \Value1\()) + \Statement\() +.endif + + .endm + +/*++ + +Macro Description: + + This macro conditionally emits the statement if Count1 is greater than or + equal to Value1 and Count2 is greater than or equal to Value2. + +Arguments: + + Count1 - Supplies the variable used in the comparison. + + Value1 - Supplies the static used in the comparison. + + Count2 - Supplies the variable used in the comparison. + + Value2 - Supplies the static used in the comparison. + + Statement - Supplies the statement to conditionally emit. + +--*/ + + .macro EmitIfCount2GE Count1, Value1, Count2, Value2, Statement + +.if (\Count1\() >= \Value1\()) && (\Count2\() >= \Value2\()) + \Statement\() +.endif + + .endm + +/*++ + +Macro Description: + + This macro emits the statement for each register listed in the register + list. The statement can use RegItem to access the current register. + +Arguments: + + RegList - Supplies the list of registers. + + Statement - Supplies the statement to emit. + +--*/ + + .macro EmitForEachRegister RegList, Statement + + .irp RegItem, \RegList\() + \Statement\() + .endr + + .endm diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 6c859e4e4f44b..7bda1bb504173 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -67,6 +67,9 @@ Module Name: #undef pixel #undef bool #endif +#if defined(__loongarch64) +#include +#endif #if defined(MLAS_TARGET_WASM_SIMD) #include #endif @@ -317,7 +320,8 @@ static_assert(sizeof(MLAS_FP16) == FP16_SIZE); // Define the prototypes of the platform optimized routines. 
// -#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) +#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || \ + defined(MLAS_TARGET_LARCH64) typedef size_t @@ -694,6 +698,30 @@ extern "C" { MLAS_GEMM_DOUBLE_KERNEL MlasDgemmKernelPOWER10; MLAS_QUANTIZE_LINEAR_S8_KERNEL MlasQuantizeLinearS8KernelVSX; MLAS_QUANTIZE_LINEAR_U8_KERNEL MlasQuantizeLinearU8KernelVSX; +#elif defined(MLAS_TARGET_LARCH64) + MLAS_GEMM_FLOAT_KERNEL MlasGemmFloatKernelLSX; + MLAS_GEMM_FLOAT_KERNEL MlasGemmFloatKernelLasx; + MLAS_GEMM_DOUBLE_KERNEL MlasGemmDoubleKernelLSX; + MLAS_GEMM_DOUBLE_KERNEL MlasGemmDoubleKernelLasx; + MLAS_CONV_FLOAT_KERNEL MlasConvNchwFloatKernelLSX; + MLAS_CONV_FLOAT_KERNEL MlasConvNchwcFloatKernelLSX; + MLAS_CONV_DEPTHWISE_FLOAT_KERNEL MlasConvDepthwiseFloatKernelLSX; + MLAS_CONV_POINTWISE_FLOAT_KERNEL MlasConvPointwiseFloatKernelLSX; + MLAS_CONV_FLOAT_KERNEL MlasConvNchwFloatKernelLasx; + MLAS_CONV_FLOAT_KERNEL MlasConvNchwcFloatKernelLasx; + MLAS_CONV_DEPTHWISE_FLOAT_KERNEL MlasConvDepthwiseFloatKernelLasx; + MLAS_CONV_POINTWISE_FLOAT_KERNEL MlasConvPointwiseFloatKernelLasx; + MLAS_POOL_FLOAT_KERNEL MlasPoolMaximumFloatKernelLSX; + MLAS_POOL_FLOAT_KERNEL MlasPoolAverageExcludePadFloatKernelLSX; + MLAS_POOL_FLOAT_KERNEL MlasPoolAverageIncludePadFloatKernelLSX; + MLAS_POOL_FLOAT_KERNEL MlasPoolMaximumFloatKernelLasx; + MLAS_POOL_FLOAT_KERNEL MlasPoolAverageExcludePadFloatKernelLasx; + MLAS_POOL_FLOAT_KERNEL MlasPoolAverageIncludePadFloatKernelLasx; + MLAS_SGEMM_TRANSPOSE_PACKB_BLOCK_ROUTINE MlasSgemmTransposePackB16x4LSX; + MLAS_SGEMM_TRANSPOSE_PACKB_BLOCK_ROUTINE MlasSgemmTransposePackB16x4Lasx; + MLAS_REDUCE_MAXIMUM_FLOAT_KERNEL MlasReduceMaximumF32KernelLasx; + MLAS_COMPUTE_SOFTMAX_OUTPUT_FLOAT_KERNEL MlasComputeSoftmaxOutputF32KernelLasx; + MLAS_COMPUTE_LOGSOFTMAX_OUTPUT_FLOAT_KERNEL MlasComputeLogSoftmaxOutputF32KernelLasx; #else MLAS_GEMM_FLOAT_KERNEL MlasSgemmKernelZero; MLAS_GEMM_FLOAT_KERNEL MlasSgemmKernelAdd; @@ -854,6 +882,7 @@ MlasSgemmOperation( struct MLAS_GEMM_QUANT_DISPATCH; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchSse; +extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchLSX; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8S8DispatchSse41; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8S8DispatchAvx2; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8U8DispatchAvx2; @@ -979,7 +1008,22 @@ struct MLAS_PLATFORM { #if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) MLAS_GEMM_FLOAT_KERNEL* GemmFloatKernel; #endif - +#if defined(MLAS_TARGET_LARCH64) + const MLAS_GEMM_QUANT_DISPATCH* GemmU8S8Dispatch; + const MLAS_GEMM_QUANT_DISPATCH* GemmU8U8Dispatch; + MLAS_GEMM_FLOAT_KERNEL* GemmFloatKernel; + MLAS_GEMM_DOUBLE_KERNEL* GemmDoubleKernel; + MLAS_CONV_FLOAT_KERNEL* ConvNchwFloatKernel; + MLAS_CONV_FLOAT_KERNEL* ConvNchwcFloatKernel; + MLAS_CONV_DEPTHWISE_FLOAT_KERNEL* ConvDepthwiseFloatKernel; + MLAS_CONV_POINTWISE_FLOAT_KERNEL* ConvPointwiseFloatKernel; + MLAS_POOL_FLOAT_KERNEL* PoolFloatKernel[MlasPoolingKindCount]; + MLAS_SGEMM_TRANSPOSE_PACKB_BLOCK_ROUTINE* TransposePackB16x4Routine; + MLAS_REDUCE_MAXIMUM_FLOAT_KERNEL* ReduceMaximumF32Kernel; + MLAS_COMPUTE_SOFTMAX_OUTPUT_FLOAT_KERNEL* ComputeSoftmaxOutputF32Kernel; + MLAS_COMPUTE_LOGSOFTMAX_OUTPUT_FLOAT_KERNEL* ComputeLogSoftmaxOutputF32Kernel; + uint32_t NchwcBlockSize; +#endif #if defined(MLAS_TARGET_AMD64_IX86) const MLAS_GEMM_QUANT_DISPATCH* GemmU8S8Dispatch; const MLAS_GEMM_QUANT_DISPATCH* GemmU8U8Dispatch; @@ -1256,6 +1300,8 @@ MlasConvDepthwiseFloat_CHW( 
#endif #elif defined(MLAS_TARGET_WASM_SIMD) #define MLAS_WASM_SIMD_INTRINSICS +#elif defined(MLAS_TARGET_LARCH64) +#define MLAS_LSX_INTRINSICS #endif #if defined(MLAS_NEON_INTRINSICS) @@ -1271,6 +1317,9 @@ typedef __vector unsigned MLAS_UINT32X4; #elif defined(MLAS_WASM_SIMD_INTRINSICS) typedef v128_t MLAS_FLOAT32X4; typedef v128_t MLAS_INT32X4; +#elif defined(MLAS_LSX_INTRINSICS) +typedef __m128 MLAS_FLOAT32X4; +typedef __m128i MLAS_INT32X4; #else typedef float MLAS_FLOAT32X4 __attribute__ ((vector_size(16))); typedef int32_t MLAS_INT32X4 __attribute__ ((vector_size(16))); @@ -1284,6 +1333,8 @@ MlasReinterpretAsInt32x4(MLAS_FLOAT32X4 Vector) return vreinterpretq_s32_f32(Vector); #elif defined(MLAS_SSE2_INTRINSICS) return _mm_castps_si128(Vector); +#elif defined(MLAS_LSX_INTRINSICS) + return (MLAS_INT32X4)Vector; #else return MLAS_INT32X4(Vector); #endif @@ -1299,6 +1350,8 @@ MlasCastToInt32x4(MLAS_FLOAT32X4 Vector) return _mm_cvttps_epi32(Vector); #elif defined(MLAS_VSX_INTRINSICS) return vec_cts(Vector, 0); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vftint_w_s(Vector); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return (MLAS_INT32X4)__builtin_convertvector((__f32x4)Vector, __i32x4); #else @@ -1318,6 +1371,8 @@ MlasCastToFloat32x4(MLAS_INT32X4 Vector) return vec_ctf(Vector, 0); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_f32x4_convert_i32x4(Vector); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vffint_s_w(Vector); #else return MLAS_FLOAT32X4{float(Vector[0]), float(Vector[1]), float(Vector[2]), float(Vector[3])}; #endif @@ -1335,6 +1390,8 @@ MlasBroadcastInt32x4(int32_t Value) return wasm_i32x4_splat(Value); #elif defined(MLAS_VSX_INTRINSICS) return vec_splats(Value); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vreplgr2vr_w(Value); #else return MLAS_INT32X4{Value, Value, Value, Value}; #endif @@ -1352,6 +1409,8 @@ MlasLoadInt32x4(const int32_t* Buffer) return vec_vsx_ld(0, Buffer); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_v128_load(Buffer); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vld((const MLAS_INT32X4*)Buffer, 0); #else return *((MLAS_INT32X4*)Buffer); #endif @@ -1369,6 +1428,8 @@ MlasStoreInt32x4(int32_t* Buffer, MLAS_INT32X4 Vector) vec_vsx_st(Vector, 0, Buffer); #elif defined(MLAS_WASM_SIMD_INTRINSICS) wasm_v128_store(Buffer, Vector); +#elif defined(MLAS_LSX_INTRINSICS) + __lsx_vst(Vector, (MLAS_INT32X4 *)Buffer, 0); #else *((MLAS_INT32X4*)Buffer) = Vector; #endif @@ -1386,6 +1447,8 @@ MlasAddInt32x4(MLAS_INT32X4 Vector1, MLAS_INT32X4 Vector2) return wasm_i32x4_add(Vector1, Vector2); #elif defined(MLAS_VSX_INTRINSICS) return vec_add(Vector1, Vector2); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vadd_w(Vector1, Vector2); #else return Vector1 + Vector2; #endif @@ -1401,6 +1464,8 @@ MlasSubtractInt32x4(MLAS_INT32X4 Vector1, MLAS_INT32X4 Vector2) return _mm_sub_epi32(Vector1, Vector2); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_i32x4_sub(Vector1, Vector2); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vsub_w(Vector1, Vector2); #else return Vector1 - Vector2; #endif @@ -1416,6 +1481,8 @@ MlasAndInt32x4(MLAS_INT32X4 Vector1, MLAS_INT32X4 Vector2) return _mm_and_si128(Vector1, Vector2); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_v128_and(Vector1, Vector2); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vand_v(Vector1, Vector2); #else return Vector1 & Vector2; #endif @@ -1431,6 +1498,8 @@ MlasOrInt32x4(MLAS_INT32X4 Vector1, MLAS_INT32X4 Vector2) return _mm_or_si128(Vector1, Vector2); #elif 
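// Illustrative sketch, not part of this patch: code written against the
// portable wrappers declared in this header picks up the new LSX lowering
// automatically. For example, a hypothetical y = a*x + y helper:
//
//   void ScaleAccumulate(float* y, const float* x, float a, size_t n)
//   {
//       const MLAS_FLOAT32X4 Alpha = MlasBroadcastFloat32x4(a);
//       while (n >= 4) {
//           MLAS_FLOAT32X4 Y = MlasLoadFloat32x4(y);
//           Y = MlasMultiplyAddFloat32x4(Alpha, MlasLoadFloat32x4(x), Y);
//           MlasStoreFloat32x4(y, Y);
//           x += 4; y += 4; n -= 4;
//       }
//       while (n > 0) { *y++ += a * *x++; n--; }
//   }
//
// On LoongArch the loads, fused multiply-add and stores lower to __lsx_vld,
// __lsx_vfmadd_s and __lsx_vst through the MLAS_LSX_INTRINSICS branches added
// in this hunk and the ones that follow.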
defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_v128_or(Vector1, Vector2); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vor_v(Vector1, Vector2); #else return Vector1 | Vector2; #endif @@ -1446,6 +1515,8 @@ MlasAndNotInt32x4(MLAS_INT32X4 VectorNot, MLAS_INT32X4 Vector) return _mm_andnot_si128(VectorNot, Vector); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_v128_andnot(Vector, VectorNot); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vandn_v(VectorNot, Vector); #else return (~VectorNot) & Vector; #endif @@ -1463,6 +1534,8 @@ MlasXorInt32x4(MLAS_INT32X4 Vector1, MLAS_INT32X4 Vector2) return wasm_v128_xor(Vector1, Vector2); #elif defined(MLAS_VSX_INTRINSICS) return vec_xor(Vector1, Vector2); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vxor_v(Vector1, Vector2); #else return Vector1 ^ Vector2; #endif @@ -1486,6 +1559,8 @@ MlasShiftLeftInt32x4(MLAS_INT32X4 Vector) return _mm_slli_epi32(Vector, ShiftCount); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_i32x4_shl(Vector, ShiftCount); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vslli_w(Vector, ShiftCount); #else return Vector << ShiftCount; #endif @@ -1505,6 +1580,8 @@ MlasMaximumInt32x4(MLAS_INT32X4 Vector1, MLAS_INT32X4 Vector2) return vec_vmaxsw(Vector1, Vector2); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_i32x4_max(Vector1, Vector2); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vmax_w(Vector1, Vector2); #else return MlasBlendInt32x4(Vector2, Vector1, Vector1 > Vector2); #endif @@ -1524,6 +1601,8 @@ MlasMinimumInt32x4(MLAS_INT32X4 Vector1, MLAS_INT32X4 Vector2) return vec_vminsw(Vector1, Vector2); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_i32x4_min(Vector1, Vector2); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vmin_w(Vector1, Vector2); #else return MlasBlendInt32x4(Vector2, Vector1, Vector2 > Vector1); #endif @@ -1537,6 +1616,8 @@ MlasReinterpretAsFloat32x4(MLAS_INT32X4 Vector) return vreinterpretq_f32_s32(Vector); #elif defined(MLAS_SSE2_INTRINSICS) return _mm_castsi128_ps(Vector); +#elif defined(MLAS_LSX_INTRINSICS) + return MLAS_FLOAT32X4(Vector); #else return MLAS_FLOAT32X4(Vector); #endif @@ -1556,6 +1637,8 @@ MlasBroadcastFloat32x4(float Value) // Suppress wrong GCC warnings MLAS_UNREFERENCED_PARAMETER(Value); return vec_splats(Value); +#elif defined(MLAS_LSX_INTRINSICS) + return MLAS_FLOAT32X4{Value, Value, Value, Value}; #else return MLAS_FLOAT32X4{Value, Value, Value, Value}; #endif @@ -1573,6 +1656,8 @@ MlasBroadcastFloat32x4(const float* Value) return wasm_v128_load32_splat(Value); #elif defined(MLAS_VSX_INTRINSICS) return vec_splats(*Value); +#elif defined(MLAS_LSX_INTRINSICS) + return MLAS_FLOAT32X4{*Value, *Value, *Value, *Value}; #else return MLAS_FLOAT32X4{*Value, *Value, *Value, *Value}; #endif @@ -1588,6 +1673,8 @@ MlasZeroFloat32x4(void) return _mm_setzero_ps(); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_f32x4_const(0.0f, 0.0f, 0.0f, 0.0f); +#elif defined(MLAS_LSX_INTRINSICS) + return MlasBroadcastFloat32x4(0.0f); #else return MlasBroadcastFloat32x4(0.0f); #endif @@ -1605,6 +1692,9 @@ MlasLoadFloat32x4(const float* Buffer) return vec_vsx_ld(0, Buffer); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_v128_load(Buffer); +#elif defined(MLAS_LSX_INTRINSICS) + // return MlasReinterpretAsFloat32x4(__lsx_vld((const MLAS_INT32X4 *)Buffer, 0)); + return (MLAS_FLOAT32X4)__lsx_vld((const MLAS_INT32X4 *)Buffer, 0); #else return *((MLAS_FLOAT32X4*)Buffer); #endif @@ -1622,6 +1712,8 @@ MlasStoreFloat32x4(float* Buffer, MLAS_FLOAT32X4 Vector) 
vec_vsx_st(Vector, 0, Buffer); #elif defined(MLAS_WASM_SIMD_INTRINSICS) wasm_v128_store(Buffer, Vector); +#elif defined(MLAS_LSX_INTRINSICS) + __lsx_vst(MlasReinterpretAsInt32x4(Vector), Buffer, 0); #else *((MLAS_FLOAT32X4*)Buffer) = Vector; #endif @@ -1642,6 +1734,8 @@ MlasStoreAlignedFloat32x4(float* Buffer, MLAS_FLOAT32X4 Vector) vec_st(Vector, 0, Buffer); #elif defined(MLAS_WASM_SIMD_INTRINSICS) wasm_v128_store(Buffer, Vector); +#elif defined(MLAS_LSX_INTRINSICS) + MlasStoreFloat32x4(Buffer, Vector); #else MlasStoreFloat32x4(Buffer, Vector); #endif @@ -1660,6 +1754,8 @@ MlasStoreLaneFloat32x4(float* Buffer, MLAS_FLOAT32X4 Vector) _mm_store_ss(Buffer, _mm_shuffle_ps(Vector, Vector, _MM_SHUFFLE(Lane, Lane, Lane, Lane))); #elif defined(MLAS_WASM_SIMD_INTRINSICS) *Buffer = ((__f32x4)(Vector))[Lane]; +#elif defined(MLAS_LSX_INTRINSICS) + *Buffer = Vector[Lane]; #else *Buffer = Vector[Lane]; #endif @@ -1675,6 +1771,9 @@ MlasStoreLowHalfFloat32x4(float* Buffer, MLAS_FLOAT32X4 Vector) _mm_storel_pi((__m64*)Buffer, Vector); #elif defined(MLAS_VSX_INTRINSICS) *((long long*)Buffer) = ((__vector long long)Vector)[0]; +#elif defined(MLAS_LSX_INTRINSICS) + MlasStoreLaneFloat32x4<0>(&Buffer[0], Vector); + MlasStoreLaneFloat32x4<1>(&Buffer[1], Vector); #else MlasStoreLaneFloat32x4<0>(&Buffer[0], Vector); MlasStoreLaneFloat32x4<1>(&Buffer[1], Vector); @@ -1692,6 +1791,8 @@ MlasExtractLaneFloat32x4(MLAS_FLOAT32X4 Vector) return _mm_cvtss_f32(_mm_shuffle_ps(Vector, Vector, _MM_SHUFFLE(Lane, Lane, Lane, Lane))); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_f32x4_extract_lane(Vector, Lane); +#elif defined(MLAS_LSX_INTRINSICS) + return Vector[Lane]; #else return Vector[Lane]; #endif @@ -1736,6 +1837,9 @@ MlasShuffleFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return wasm_i32x4_shuffle(Vector1, Vector2, Index0, Index1, Index2, Index3); #elif defined(__clang__) return __builtin_shufflevector(Vector1, Vector2, Index0, Index1, Index2, Index3); +#elif defined(MLAS_LSX_INTRINSICS) + typedef int32_t GEN_INT32X4 __attribute__ ((vector_size(16))); + return __builtin_shuffle(Vector1, Vector2, GEN_INT32X4{Index0, Index1, Index2, Index3}); #else return __builtin_shuffle(Vector1, Vector2, MLAS_INT32X4{Index0, Index1, Index2, Index3}); #endif @@ -1764,6 +1868,8 @@ MlasInterleaveLowFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return _mm_unpacklo_ps(Vector1, Vector2); #elif defined(MLAS_VSX_INTRINSICS) return vec_mergeh(Vector1, Vector2); +#elif defined(MLAS_LSX_INTRINSICS) + return (MLAS_FLOAT32X4)__lsx_vilvl_w(MlasReinterpretAsInt32x4(Vector2), MlasReinterpretAsInt32x4(Vector1)); #else return MlasShuffleFloat32x4<0, 4, 1, 5>(Vector1, Vector2); #endif @@ -1782,6 +1888,8 @@ MlasInterleaveHighFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return _mm_unpackhi_ps(Vector1, Vector2); #elif defined(MLAS_VSX_INTRINSICS) return vec_mergel(Vector1, Vector2); +#elif defined(MLAS_LSX_INTRINSICS) + return (MLAS_FLOAT32X4)__lsx_vilvh_w(MlasReinterpretAsInt32x4(Vector2), MlasReinterpretAsInt32x4(Vector1)); #else return MlasShuffleFloat32x4<2, 6, 3, 7>(Vector1, Vector2); #endif @@ -1799,6 +1907,8 @@ MlasAddFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return wasm_f32x4_add(Vector1, Vector2); #elif defined(MLAS_VSX_INTRINSICS) return vec_add(Vector1, Vector2); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vfadd_s(Vector1, Vector2); #else return Vector1 + Vector2; #endif @@ -1816,6 +1926,8 @@ MlasSubtractFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return 
wasm_f32x4_sub(Vector1, Vector2); #elif defined(MLAS_VSX_INTRINSICS) return vec_sub(Vector1, Vector2); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vfsub_s(Vector1, Vector2); #else return Vector1 - Vector2; #endif @@ -1836,6 +1948,8 @@ MlasMultiplyFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) MLAS_UNREFERENCED_PARAMETER(Vector1); MLAS_UNREFERENCED_PARAMETER(Vector2); return vec_mul(Vector1, Vector2); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vfmul_s(Vector1, Vector2); #else return Vector1 * Vector2; #endif @@ -1855,6 +1969,8 @@ MlasMultiplyAddFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2, MLAS_FL return vec_madd(Vector1, Vector2, Vector3); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_f32x4_add(wasm_f32x4_mul(Vector1, Vector2), Vector3); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vfmadd_s(Vector1, Vector2, Vector3); #else return Vector1 * Vector2 + Vector3; #endif @@ -1890,6 +2006,8 @@ MlasDivideFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return _mm_div_ps(Vector1, Vector2); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_f32x4_div(Vector1, Vector2); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vfdiv_s(Vector1, Vector2); #else return Vector1 / Vector2; #endif @@ -1907,6 +2025,8 @@ MlasGreaterThanFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return wasm_f32x4_gt(Vector1, Vector2); #elif defined(MLAS_VSX_INTRINSICS) return MLAS_FLOAT32X4(vec_cmpgt(Vector1, Vector2)); +#elif defined(MLAS_LSX_INTRINSICS) + return (MLAS_FLOAT32X4)__lsx_vfcmp_clt_s(Vector2, Vector1); #else return Vector1 > Vector2; #endif @@ -1920,6 +2040,8 @@ MlasAndFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return _mm_and_ps(Vector1, Vector2); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_v128_and(Vector1, Vector2); +#elif defined(MLAS_LSX_INTRINSICS) + return MlasReinterpretAsFloat32x4(MlasAndInt32x4(MlasReinterpretAsInt32x4(Vector1), MlasReinterpretAsInt32x4(Vector2))); #else return MlasReinterpretAsFloat32x4(MlasAndInt32x4(MlasReinterpretAsInt32x4(Vector1), MlasReinterpretAsInt32x4(Vector2))); #endif @@ -1933,6 +2055,8 @@ MlasOrFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return _mm_or_ps(Vector1, Vector2); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_v128_or(Vector1, Vector2); +#elif defined(MLAS_LSX_INTRINSICS) + return MlasReinterpretAsFloat32x4(MlasOrInt32x4(MlasReinterpretAsInt32x4(Vector1), MlasReinterpretAsInt32x4(Vector2))); #else return MlasReinterpretAsFloat32x4(MlasOrInt32x4(MlasReinterpretAsInt32x4(Vector1), MlasReinterpretAsInt32x4(Vector2))); #endif @@ -1946,6 +2070,8 @@ MlasAndNotFloat32x4(MLAS_FLOAT32X4 VectorNot, MLAS_FLOAT32X4 Vector) return _mm_andnot_ps(VectorNot, Vector); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_v128_andnot(Vector, VectorNot); +#elif defined(MLAS_LSX_INTRINSICS) + return MlasReinterpretAsFloat32x4(MlasAndNotInt32x4(MlasReinterpretAsInt32x4(VectorNot), MlasReinterpretAsInt32x4(Vector))); #else return MlasReinterpretAsFloat32x4(MlasAndNotInt32x4(MlasReinterpretAsInt32x4(VectorNot), MlasReinterpretAsInt32x4(Vector))); #endif @@ -1959,6 +2085,8 @@ MlasXorFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return _mm_xor_ps(Vector1, Vector2); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_v128_xor(Vector1, Vector2); +#elif defined(MLAS_LSX_INTRINSICS) + return MlasReinterpretAsFloat32x4(MlasXorInt32x4(MlasReinterpretAsInt32x4(Vector1), MlasReinterpretAsInt32x4(Vector2))); #else return 
MlasReinterpretAsFloat32x4(MlasXorInt32x4(MlasReinterpretAsInt32x4(Vector1), MlasReinterpretAsInt32x4(Vector2))); #endif @@ -1984,6 +2112,8 @@ MlasMaximumFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return vec_sel(Vector2, Vector1, vec_cmpgt(Vector1, Vector2)); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_f32x4_max(Vector1, Vector2); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vfmax_s(Vector1, Vector2); #else return MlasBlendFloat32x4(Vector2, Vector1, Vector1 > Vector2); #endif @@ -2002,6 +2132,8 @@ MlasMinimumFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return vec_sel(Vector2, Vector1, vec_cmpgt(Vector2, Vector1)); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_f32x4_min(Vector1, Vector2); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vfmin_s(Vector1, Vector2); #else return MlasBlendFloat32x4(Vector2, Vector1, Vector2 > Vector1); #endif @@ -2108,6 +2240,8 @@ MlasPowerOf2Float32x4(MLAS_FLOAT32X4 Vector) typedef __m128d MLAS_FLOAT64X2; #elif defined(MLAS_VSX_INTRINSICS) typedef __vector double MLAS_FLOAT64X2; +#elif defined(MLAS_LSX_INTRINSICS) +typedef __m128d MLAS_FLOAT64X2; #else #define MLAS_FLOAT64X2_UNSUPPORTED #endif @@ -2129,6 +2263,27 @@ MlasMultiplyAddFloat64x2(MLAS_FLOAT64X2 Vector1, MLAS_FLOAT64X2 Vector2, MLAS_FL return vec_madd(Vector1, Vector2, Vector3); } +MLAS_FORCEINLINE +MLAS_FLOAT64X2 +MlasBroadcastFloat64x2(const double *Value) +{ + return MLAS_FLOAT64X2{*Value, *Value}; +} +#elif defined(MLAS_LSX_INTRINSICS) +template +MLAS_FORCEINLINE +double +MlasExtractLaneFloat64x2(MLAS_FLOAT64X2 Vector) +{ + return Vector[Lane]; +} +MLAS_FORCEINLINE +MLAS_FLOAT64X2 +MlasMultiplyAddFloat64x2(MLAS_FLOAT64X2 Vector1, MLAS_FLOAT64X2 Vector2, MLAS_FLOAT64X2 Vector3) +{ + return __lsx_vfmadd_d(Vector1, Vector2, Vector3); +} + MLAS_FORCEINLINE MLAS_FLOAT64X2 MlasBroadcastFloat64x2(const double *Value) @@ -2144,6 +2299,8 @@ MlasBroadcastFloat64x2(double Value) return _mm_set1_pd(Value); #elif defined(MLAS_VSX_INTRINSICS) return MLAS_FLOAT64X2{Value, Value}; +#elif defined(MLAS_LSX_INTRINSICS) + return MLAS_FLOAT64X2{Value, Value}; #endif } @@ -2155,6 +2312,8 @@ MlasZeroFloat64x2(void) return _mm_setzero_pd(); #elif defined(MLAS_VSX_INTRINSICS) return MlasBroadcastFloat64x2(0.0f); +#elif defined(MLAS_LSX_INTRINSICS) + return MlasBroadcastFloat64x2(0.0f); #endif } @@ -2166,6 +2325,8 @@ MlasLoadFloat64x2(const double* Buffer) return _mm_loadu_pd(Buffer); #elif defined(MLAS_VSX_INTRINSICS) return vec_vsx_ld(0, Buffer); +#elif defined(MLAS_LSX_INTRINSICS) + return MLAS_FLOAT64X2(__lsx_vld((const MLAS_INT32X4 *)Buffer, 0)); #endif } @@ -2177,6 +2338,8 @@ MlasStoreFloat64x2(double* Buffer, MLAS_FLOAT64X2 Vector) _mm_storeu_pd(Buffer, Vector); #elif defined(MLAS_VSX_INTRINSICS) vec_vsx_st(Vector, 0, Buffer); +#elif defined(MLAS_LSX_INTRINSICS) + (__lsx_vst(MLAS_INT32X4(Vector), Buffer, 0)); #endif } @@ -2188,6 +2351,8 @@ MlasStoreAlignedFloat64x2(double* Buffer, MLAS_FLOAT64X2 Vector) _mm_store_pd(Buffer, Vector); #elif defined(MLAS_VSX_INTRINSICS) *((MLAS_FLOAT64X2*)Buffer) = Vector; +#elif defined(MLAS_LSX_INTRINSICS) + (__lsx_vst(MLAS_INT32X4(Vector), Buffer, 0)); #endif } @@ -2199,6 +2364,8 @@ MlasMultiplyFloat64x2(MLAS_FLOAT64X2 Vector1, MLAS_FLOAT64X2 Vector2) return _mm_mul_pd(Vector1, Vector2); #elif defined(MLAS_VSX_INTRINSICS) return Vector1 * Vector2; +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vfmul_d(Vector1, Vector2); #endif } @@ -2233,6 +2400,17 @@ MlasReadTimeStampCounter(void) ); return ((uint64_t)edx << 32) | eax; 
+#elif defined(MLAS_TARGET_LARCH64) + uint64_t time_cnt, id; + + __asm__ __volatile__ + ( + "rdtime.d %0, %1\n\t" + : "=r" (time_cnt), "=r" (id) + :: + ); + + return time_cnt; #else return 0; #endif diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index fec56c6ee063f..8329a34f1338f 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -185,6 +185,28 @@ MlasInitAMX() #endif // MLAS_TARGET_AMD64_IX86 +#ifdef MLAS_TARGET_LARCH64 + +#if defined(__linux__) +#include +#include +#endif +// +// Stores a vector to build a conditional load/store mask for vmaskmovps. +// + +MLAS_INTERNAL_DATA MLAS_DECLSPEC_ALIGN(const uint32_t MlasMaskMoveLasx[8], 32) = { 0, 1, 2, 3, 4, 5, 6, 7 }; + +// +// Stores a table of AVX vmaskmovps/vmaskmovpd load/store masks. +// + +MLAS_INTERNAL_DATA MLAS_DECLSPEC_ALIGN(const uint32_t MlasMaskMoveTableLasx[16], 32) = { + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, +}; + +#endif MLAS_PLATFORM::MLAS_PLATFORM( void ) @@ -536,6 +558,63 @@ Return Value: #endif // __linux__ #endif // MLAS_TARGET_POWER +#if defined(MLAS_TARGET_LARCH64) + + // + // Default to the baseline LSX support. + // + + int hwcap = getauxval(AT_HWCAP); + bool cap_lasx = hwcap & HWCAP_LOONGARCH_LASX; + bool cap_lsx = hwcap & HWCAP_LOONGARCH_LSX; + + if( cap_lasx ){ + this->GemmFloatKernel = MlasGemmFloatKernelLasx; + this->GemmDoubleKernel = MlasGemmDoubleKernelLasx; + this->ConvNchwFloatKernel = MlasConvNchwFloatKernelLasx; + this->ConvNchwcFloatKernel = MlasConvNchwcFloatKernelLasx; + this->ConvDepthwiseFloatKernel = MlasConvDepthwiseFloatKernelLasx; + this->ConvPointwiseFloatKernel = MlasConvPointwiseFloatKernelLasx; + this->PoolFloatKernel[MlasMaximumPooling] = MlasPoolMaximumFloatKernelLasx; + this->PoolFloatKernel[MlasAveragePoolingExcludePad] = MlasPoolAverageExcludePadFloatKernelLasx; + this->PoolFloatKernel[MlasAveragePoolingIncludePad] = MlasPoolAverageIncludePadFloatKernelLasx; + this->ReduceMaximumF32Kernel = MlasReduceMaximumF32KernelLasx; + this->ComputeSoftmaxOutputF32Kernel = MlasComputeSoftmaxOutputF32KernelLasx; + this->ComputeLogSoftmaxOutputF32Kernel = MlasComputeLogSoftmaxOutputF32KernelLasx; + this->TransposePackB16x4Routine = MlasSgemmTransposePackB16x4Lasx; + + this->GemmU8S8Dispatch = &MlasGemmU8X8DispatchLSX; + this->GemmU8U8Dispatch = &MlasGemmU8X8DispatchLSX; + }else if( cap_lsx ){ + this->GemmFloatKernel = MlasGemmFloatKernelLSX; + this->GemmU8S8Dispatch = &MlasGemmU8X8DispatchLSX; + this->GemmU8U8Dispatch = &MlasGemmU8X8DispatchLSX; + this->TransposePackB16x4Routine = MlasSgemmTransposePackB16x4LSX; + this->GemmDoubleKernel = MlasGemmDoubleKernelLSX; + this->ConvNchwFloatKernel = MlasConvNchwFloatKernelLSX; + this->ConvNchwcFloatKernel = MlasConvNchwcFloatKernelLSX; + this->ConvDepthwiseFloatKernel = MlasConvDepthwiseFloatKernelLSX; + this->ConvPointwiseFloatKernel = MlasConvPointwiseFloatKernelLSX; + + this->PoolFloatKernel[MlasMaximumPooling] = MlasPoolMaximumFloatKernelLSX; + this->PoolFloatKernel[MlasAveragePoolingExcludePad] = MlasPoolAverageExcludePadFloatKernelLSX; + this->PoolFloatKernel[MlasAveragePoolingIncludePad] = MlasPoolAverageIncludePadFloatKernelLSX; + this->ReduceMaximumF32Kernel = MlasReduceMaximumF32Kernel; + this->ComputeSoftmaxOutputF32Kernel = MlasComputeSoftmaxOutputF32Kernel; + this->ComputeLogSoftmaxOutputF32Kernel = 
MlasComputeLogSoftmaxOutputF32Kernel; + }else{ + this->ReduceMaximumF32Kernel = MlasReduceMaximumF32Kernel; + this->ComputeSoftmaxOutputF32Kernel = MlasComputeSoftmaxOutputF32Kernel; + this->ComputeLogSoftmaxOutputF32Kernel = MlasComputeLogSoftmaxOutputF32Kernel; + } + + this->NchwcBlockSize = 8; + // this->PreferredBufferAlignment = MLAS_DEFAULT_PREFERRED_BUFFER_ALIGNMENT; + + // this->MaximumThreadCount = MLAS_MAXIMUM_THREAD_COUNT; + +#endif // MLAS_TARGET_LARCH64 + } size_t diff --git a/onnxruntime/core/mlas/lib/pooling.cpp b/onnxruntime/core/mlas/lib/pooling.cpp index 12128f6c700fd..50dcf19224510 100644 --- a/onnxruntime/core/mlas/lib/pooling.cpp +++ b/onnxruntime/core/mlas/lib/pooling.cpp @@ -1569,6 +1569,96 @@ Return Value: c -= 16; } +#elif defined(MLAS_LSX_INTRINSICS) + uint32_t val = 0x80808080; + const __m128i BitFlipVector = __lsx_vreplgr2vr_w(val); + if constexpr (std::is_unsigned::value) { + MLAS_UNREFERENCED_PARAMETER(BitFlipVector); + } + + while (c >= 32) { + + __m128i MaximumVector0 = __lsx_vldi(0); + __m128i MaximumVector1 = __lsx_vldi(0); + + for (size_t k = 0; k < KernelSize; k++) { + + __m128i InputVector0 = __lsx_vld((const __m128i*)&Input[k][ChannelOffset], 0); + __m128i InputVector1 = __lsx_vld((const __m128i*)&Input[k][ChannelOffset + 16], 0); + + if constexpr (std::is_signed::value) { + InputVector0 = __lsx_vxor_v(InputVector0, BitFlipVector); + InputVector1 = __lsx_vxor_v(InputVector1, BitFlipVector); + } + + MaximumVector0 = __lsx_vmax_bu(MaximumVector0, InputVector0); + MaximumVector1 = __lsx_vmax_bu(MaximumVector1, InputVector1); + } + + if constexpr (std::is_signed::value) { + MaximumVector0 = __lsx_vxor_v(MaximumVector0, BitFlipVector); + MaximumVector1 = __lsx_vxor_v(MaximumVector1, BitFlipVector); + } + + __lsx_vst(MaximumVector0, (__m128i*)&Output[0], 0); + __lsx_vst(MaximumVector1, (__m128i*)&Output[16], 0); + Output += 32; + + ChannelOffset += 32; + c -= 32; + } + + while (c >= 16) { + + __m128i MaximumVector0 = __lsx_vldi(0); + + for (size_t k = 0; k < KernelSize; k++) { + + __m128i InputVector0 = __lsx_vld((const __m128i*)&Input[k][ChannelOffset], 0); + + if constexpr (std::is_signed::value){ + InputVector0 = __lsx_vxor_v(InputVector0, BitFlipVector); + } + + MaximumVector0 = __lsx_vmax_bu(MaximumVector0, InputVector0); + } + + if constexpr (std::is_signed::value) { + MaximumVector0 = __lsx_vxor_v(MaximumVector0, BitFlipVector); + } + + __lsx_vst(MaximumVector0, (__m128i*)&Output[0], 0); + Output += 16; + + ChannelOffset += 16; + c -= 16; + } + + if (c >= 8) { + + __m128i MaximumVector0 = __lsx_vldi(0); + + for (size_t k = 0; k < KernelSize; k++) { + + __m128i InputVector0 = __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)&Input[k][ChannelOffset], 0), 0, 1); + + if constexpr (std::is_signed::value){ + InputVector0 = __lsx_vxor_v(InputVector0, BitFlipVector); + } + + MaximumVector0 = __lsx_vmax_bu(MaximumVector0, InputVector0); + } + + if constexpr (std::is_signed::value) { + MaximumVector0 = __lsx_vxor_v(MaximumVector0, BitFlipVector); + } + + __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i*)&Output[0] , 0), __lsx_vpickve2gr_d(MaximumVector0, 0), 0), (__m128i*)&Output[0], 0); + Output += 8; + + ChannelOffset += 8; + c -= 8; + } #endif while (c > 0) { diff --git a/onnxruntime/core/mlas/lib/q4gemm.h b/onnxruntime/core/mlas/lib/q4gemm.h index b1b51dd53c4fc..d16798eb8945f 100644 --- a/onnxruntime/core/mlas/lib/q4gemm.h +++ b/onnxruntime/core/mlas/lib/q4gemm.h @@ -126,7 +126,7 @@ MlasQ4GemmOperation( size_t RowsRemaining = RangeCountM; while (RowsRemaining > 0) 
{ -#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) +#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_LARCH64) auto RowsHandled = GetMlasPlatform().GemmFloatKernel( a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f, true); #else diff --git a/onnxruntime/core/mlas/lib/qdwconv.cpp b/onnxruntime/core/mlas/lib/qdwconv.cpp index 924009ab5ccf4..59f6877f70d56 100644 --- a/onnxruntime/core/mlas/lib/qdwconv.cpp +++ b/onnxruntime/core/mlas/lib/qdwconv.cpp @@ -41,6 +41,10 @@ MlasConvDepthwiseKernel( #elif defined(MLAS_NEON_INTRINSICS) const uint8x8_t InputZeroPointVector = vdup_n_u8(uint8_t(InputZeroPoint)); const uint8x8_t FilterZeroPointVector = vdup_n_u8(uint8_t(FilterZeroPoint)); +#elif defined(MLAS_LSX_INTRINSICS) + const __m128i ZeroVector = __lsx_vldi(0); + const __m128i InputZeroPointVector = __lsx_vreplgr2vr_h(InputZeroPoint); + const __m128i FilterZeroPointVector = __lsx_vreplgr2vr_h(FilterZeroPoint); #endif while (OutputCount > 0) { @@ -141,6 +145,54 @@ MlasConvDepthwiseKernel( vst1q_s32(&Output[4], Accumulator1); Output += 8; + ChannelOffset += 8; + c -= 8; + } +#elif defined(MLAS_LSX_INTRINSICS) + + while (c >= 8) { + __m128i Accumulator0 = __lsx_vldi(0); + __m128i Accumulator1 = __lsx_vldi(0); + size_t ChannelKernelOffset = ChannelOffset; + + for (size_t k = 0; k < KernelSize; k++) { + __m128i InputVector = __lsx_vld((const __m128i*)&Input[k][ChannelOffset], 0); + __lsx_vinsgr2vr_d(InputVector, 0, 1); + __m128i FilterVector = + __lsx_vld((const __m128i*)&Filter[ChannelKernelOffset], 0); + __lsx_vinsgr2vr_d(FilterVector, 0, 1); + + if (std::is_signed::value) { + InputVector = __lsx_vsrai_h(__lsx_vilvl_b(InputVector, ZeroVector), 8); + } else { + InputVector = __lsx_vilvl_b(ZeroVector, InputVector ); + } + + if (std::is_signed::value) { + FilterVector = __lsx_vsrai_h(__lsx_vilvl_b(FilterVector, ZeroVector), 8); + } else { + FilterVector = __lsx_vilvl_b(ZeroVector, FilterVector); + } + + InputVector = __lsx_vsub_h(InputVector, InputZeroPointVector); + FilterVector = __lsx_vsub_h(FilterVector, FilterZeroPointVector); + + // N.B. Emulate PMULLD functionality on LSX by computing the low + // and high parts of the result and interleaving the results. 
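// Scalar sketch (illustrative only, not part of this patch) of what the
// vmul_h/vmuh_h/vilvl_h/vilvh_h sequence below computes per 16-bit lane,
// assuming a little-endian target:
//
//   int32_t WideningMultiply(int16_t a, int16_t b)
//   {
//       int32_t full = (int32_t)a * (int32_t)b;
//       uint16_t lo = (uint16_t)full;          // __lsx_vmul_h lane
//       int16_t  hi = (int16_t)(full >> 16);   // __lsx_vmuh_h lane
//       return (int32_t)(((uint32_t)(uint16_t)hi << 16) | lo);  // == full
//   }
//
// Interleaving the low-word and high-word vectors therefore reproduces the
// full 32-bit products that a single widening multiply would give.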
+ __m128i MultiplyLowWords = __lsx_vmul_h(InputVector, FilterVector); + __m128i MultiplyHighWords = __lsx_vmuh_h(InputVector, FilterVector); + __m128i Multiply0 = __lsx_vilvl_h(MultiplyHighWords, MultiplyLowWords); + __m128i Multiply1 = __lsx_vilvh_h(MultiplyHighWords, MultiplyLowWords); + + Accumulator0 = __lsx_vadd_w(Accumulator0, Multiply0); + Accumulator1 = __lsx_vadd_w(Accumulator1, Multiply1); + ChannelKernelOffset += Channels; + } + + __lsx_vst(Accumulator0, (__m128i*)&Output[0], 0); + __lsx_vst(Accumulator1, (__m128i*)&Output[4], 0); + Output += 8; + ChannelOffset += 8; c -= 8; } @@ -322,4 +374,4 @@ Return Value: ); } } -} \ No newline at end of file +} diff --git a/onnxruntime/core/mlas/lib/qgemm.h b/onnxruntime/core/mlas/lib/qgemm.h index 1fcd44e78a28c..75c17a6b5a177 100644 --- a/onnxruntime/core/mlas/lib/qgemm.h +++ b/onnxruntime/core/mlas/lib/qgemm.h @@ -871,7 +871,7 @@ MlasGemmQuantGetDispatch( GemmQuantDispatch = &MlasGemmQuantDispatchDefault; } -#if defined(MLAS_TARGET_AMD64_IX86) +#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_LARCH64) if (!AIsSigned) { if (BIsSigned) { GemmQuantDispatch = GetMlasPlatform().GemmU8S8Dispatch; diff --git a/onnxruntime/core/mlas/lib/qgemm_kernel_lsx.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_lsx.cpp new file mode 100644 index 0000000000000..7d5817335bd77 --- /dev/null +++ b/onnxruntime/core/mlas/lib/qgemm_kernel_lsx.cpp @@ -0,0 +1,531 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. + +Licensed under the MIT License. + +Module Name: + + qgemm_kernel_lsx.cpp + +Abstract: + + This module implements QGEMM kernels for LSX. + +--*/ + +#include "mlasi.h" +#include "qgemm.h" +#include + +struct MLAS_GEMM_U8X8_KERNEL_LSX +{ + typedef int16_t PackedAType; + typedef int16_t PackedBType; + typedef uint8_t OffsetAType; + typedef int8_t OffsetBType; + + static constexpr size_t PackedK = 2; + static constexpr MLAS_GEMM_QUANT_STRIDES Strides{ 12, 128, 128 }; + static constexpr MLAS_GEMM_QUANT_STRIDES PackedStrides{0, 0, 0}; +}; + +constexpr size_t MLAS_GEMM_U8X8_KERNEL_LSX::PackedK; +constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_U8X8_KERNEL_LSX::Strides; + +template<> +MLAS_FORCEINLINE constexpr +int32_t +MlasGemmQuantFixupZeroPointB( + int32_t ZeroPointB, + bool BIsSigned + ) +{ + if (!BIsSigned) { + ZeroPointB = MLAS_GEMM_U8X8_KERNEL_LSX::OffsetBType(ZeroPointB ^ 0x80); + } + + return ZeroPointB; +} + +template<> +void +MlasGemmQuantCopyPackA( + MLAS_GEMM_U8X8_KERNEL_LSX::PackedAType* D, + const uint8_t* A, + size_t lda, + size_t CountM, + size_t CountK, + int32_t* RowSumBuffer, + bool AIsSigned + ) +{ + MLAS_UNREFERENCED_PARAMETER(AIsSigned); + const __m128i ZeroVector = __lsx_vrepli_d(0); + uint16_t val = 1; + const __m128i OnesWordBroadcast = __lsx_vreplgr2vr_h(val); + uint8_t PaddedMatrixAData[8] = { 0 }; + + // + // Process a single row of matrix A in a loop. + // + + while (CountM > 0) { + + const uint8_t* a = A; + size_t k = CountK; + __m128i ReductionVector = ZeroVector; + + // + // Zero extend the source bytes to 16-bits and write to the packed + // buffer. + // + // The packed buffer has the same data ordering as the source bytes, + // but CountK is aligned up to a multiple of 2 to maintain 32-bit + // alignment. All extra bytes are zero-padded. + // + // These 16-bit values are also accumulated into an intermediate per-row + // accumulator. CountK cannot be greater than 128 to avoid overflowing + // these signed 16-bit accumulators. 
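// Worked bound for the comment above (illustrative): each signed 16-bit lane
// of ReductionVector accumulates at most ceil(CountK / 8) zero-extended
// bytes, and the block stride limits CountK to 128
// (MLAS_GEMM_U8X8_KERNEL_LSX::Strides), so the largest possible lane value is
//
//   (128 / 8) * 255 = 16 * 255 = 4080  <  INT16_MAX = 32767
//
// which is why the 16-bit partial sums are safe until the widening
// vmaddwev/vmaddwod reduction at the end of the row.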
+ // + + while (k >= 8) { + + __m128i Bytes = __lsx_vld((const __m128i*) & a[0], 0); + __lsx_vinsgr2vr_d(Bytes, 0, 1); + __m128i Words = __lsx_vilvl_b(ZeroVector, Bytes); + + ReductionVector = __lsx_vadd_h(ReductionVector, Words); + + __lsx_vst(Words, (__m128i*) & D[0], 0); + + a += 8; + D += 8; + k -= 8; + } + + if (k > 0) { + + // + // Copy the remaining bytes to the zero padded stack buffer. + // + + uint8_t* padded = PaddedMatrixAData; + uint8_t* padded_end = padded + k; + + do { + padded[0] = a[0]; + padded++; + a++; + } while (padded < padded_end); + + __m128i Bytes = __lsx_vld((__m128i*)PaddedMatrixAData, 0); + __lsx_vinsgr2vr_d(Bytes, 0, 1); + __m128i Words = __lsx_vilvl_b(ZeroVector, Bytes); + + ReductionVector = __lsx_vadd_h(ReductionVector, Words); + + // + // Copy pairs of 16-bit values from the vector to the packed + // buffer and rotate the vector for the next iteration. + // + + for (size_t pairs = (k + 1) / 2; pairs > 0; pairs--) { + __lsx_vstelm_w(Words, (int32_t*)D, 0 , 0); + D += 2; + Words = __lsx_vshuf4i_w(Words, 0x39); //(0, 3, 2, 1) + } + } + + // + // Reduce the partial accumulators. + // + __m128i tmp1 = ZeroVector, tmp2 = ZeroVector; + tmp1 = __lsx_vmaddwev_w_h(tmp1, ReductionVector, OnesWordBroadcast); + tmp2 = __lsx_vmaddwod_w_h(tmp2, ReductionVector, OnesWordBroadcast); + ReductionVector = __lsx_vadd_w(tmp1, tmp2); + ReductionVector = __lsx_vadd_w(ReductionVector, + __lsx_vshuf4i_w(ReductionVector, 0xee)); + ReductionVector = __lsx_vadd_w(ReductionVector, + __lsx_vshuf4i_w(ReductionVector, 0x11)); + + __lsx_vstelm_w(ReductionVector, RowSumBuffer++, 0 , 0); + + A += lda; + CountM -= 1; + } +} + +MLAS_FORCEINLINE +void +MlasGemmU8X8CopyPackBProcessLSX( + MLAS_GEMM_U8X8_KERNEL_LSX::PackedBType* D, + __m128i BytesRow0, + __m128i BytesRow1, + __m128i BitFlipVector, + __m128i ColumnSums[2] +) +{ + __m128i BytesInterleaved = __lsx_vilvl_b(BytesRow1, BytesRow0); + + BytesInterleaved = __lsx_vxor_v(BytesInterleaved, BitFlipVector); + + __m128i WordsInterleaved0 = __lsx_vsrai_h(__lsx_vilvl_b(BytesInterleaved, BytesInterleaved), 8); + __m128i WordsInterleaved1 = __lsx_vsrai_h(__lsx_vilvh_b(BytesInterleaved, BytesInterleaved), 8); + + ColumnSums[0] = __lsx_vadd_h(ColumnSums[0], WordsInterleaved0); + ColumnSums[1] = __lsx_vadd_h(ColumnSums[1], WordsInterleaved1); + + __lsx_vst(WordsInterleaved0, (__m128i*) & D[0], 0); + __lsx_vst(WordsInterleaved1, (__m128i*) & D[8], 0); +} + +template<> +void +MlasGemmQuantCopyPackB( + MLAS_GEMM_U8X8_KERNEL_LSX::PackedBType* D, + const uint8_t* B, + size_t ldb, + size_t CountN, + size_t CountK, + int32_t* ColumnSumBuffer, + bool BIsSigned + ) +{ + uint16_t val = 1; + const __m128i OnesWordBroadcast = __lsx_vreplgr2vr_h(val); + const __m128i BitFlipVector = __lsx_vreplgr2vr_w(BIsSigned ? 0 : 0x80808080); + + // + // Process 8 columns of matrix B in a loop. + // + + while (CountN >= 8) { + + const uint8_t* b = B; + size_t k = CountK; + __m128i ColumnSums[2]; + + ColumnSums[0] = __lsx_vldi(0); + ColumnSums[1] = __lsx_vldi(0); + + // + // Interleave rows of matrix B and write to the packed buffer. + // + // These values are also zero-extended and accumulated into an + // intermediate per-column accumulator. CountK cannot be greater than + // 128 to avoid overflowing these signed 16-bit accumulators. 
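// Scalar sketch (illustrative, not part of this patch) of the BitFlipVector
// trick used by MlasGemmU8X8CopyPackBProcessLSX above: when B is unsigned,
// each byte is XORed with 0x80, which equals subtracting 128 and
// reinterpreting the result as int8_t:
//
//   uint8_t b  = ...;
//   int8_t  bs = (int8_t)(b ^ 0x80);    // bs == (int)b - 128
//
// MlasGemmQuantFixupZeroPointB applies the same shift to ZeroPointB
// (ZeroPointB ^ 0x80), so the quantity the kernel effectively computes is
// unchanged:  (b - 128) - (ZeroPointB - 128) == b - ZeroPointB.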
+ // + + while (k >= MLAS_GEMM_U8X8_KERNEL_LSX::PackedK) { + + __m128i BytesRow0 = __lsx_vld((const __m128i*) & b[0], 0); + __lsx_vinsgr2vr_d(BytesRow0, 0, 1); + __m128i BytesRow1 = __lsx_vld((const __m128i*) & b[ldb], 0); + __lsx_vinsgr2vr_d(BytesRow1, 0, 1); + + MlasGemmU8X8CopyPackBProcessLSX(D, BytesRow0, BytesRow1, BitFlipVector, ColumnSums); + + b += ldb * 2; + D += 16; + k -= 2; + } + + if (k > 0) { + + __m128i BytesRow0 = __lsx_vld((const __m128i*) & b[0], 0); + __lsx_vinsgr2vr_d(BytesRow0, 0, 1); + + MlasGemmU8X8CopyPackBProcessLSX(D, BytesRow0, BitFlipVector, BitFlipVector, ColumnSums); + + D += 16; + } + + __m128i tmp1, tmp2; + tmp1 = tmp2 = __lsx_vldi(0); + tmp1 = __lsx_vmaddwev_w_h(tmp1, ColumnSums[0], OnesWordBroadcast); + tmp2 = __lsx_vmaddwod_w_h(tmp2, ColumnSums[0], OnesWordBroadcast); + ColumnSums[0]= __lsx_vadd_w(tmp1, tmp2); + tmp1 = tmp2 = __lsx_vldi(0); + tmp1 = __lsx_vmaddwev_w_h(tmp1, ColumnSums[1], OnesWordBroadcast); + tmp2 = __lsx_vmaddwod_w_h(tmp2, ColumnSums[1], OnesWordBroadcast); + ColumnSums[1]= __lsx_vadd_w(tmp1, tmp2); + + __lsx_vst(ColumnSums[0], (__m128i*) & ColumnSumBuffer[0], 0); + __lsx_vst(ColumnSums[1], (__m128i*) & ColumnSumBuffer[4], 0); + ColumnSumBuffer += 8; + + B += 8; + CountN -= 8; + } + + // + // Process the remaining columns of matrix B. + // + + if (CountN > 0) { + + const uint8_t* b = B; + size_t k = CountK; + __m128i ColumnSums[2]; + uint8_t PaddedMatrixBData[16]; + + __lsx_vst(BitFlipVector, (__m128i*)PaddedMatrixBData, 0); + + ColumnSums[0] = __lsx_vldi(0); + ColumnSums[1] = __lsx_vldi(0); + + // + // Interleave rows of matrix B using an intermediate zero padded stack + // buffer and write to the packed buffer. + // + + while (k >= MLAS_GEMM_U8X8_KERNEL_LSX::PackedK) { + + const uint8_t* bcopy = b; + uint8_t* padded = PaddedMatrixBData; + uint8_t* padded_end = padded + CountN; + + do { + padded[0] = bcopy[0]; + padded[8] = bcopy[ldb]; + padded++; + bcopy++; + } while (padded < padded_end); + + __m128i BytesRow0 = __lsx_vld((__m128i*) & PaddedMatrixBData[0], 0); + __lsx_vinsgr2vr_d(BytesRow0, 0, 1); + __m128i BytesRow1 = __lsx_vld((__m128i*) & PaddedMatrixBData[8], 0); + __lsx_vinsgr2vr_d(BytesRow1, 0, 1); + + MlasGemmU8X8CopyPackBProcessLSX(D, BytesRow0, BytesRow1, BitFlipVector, ColumnSums); + + b += ldb * 2; + D += 16; + k -= 2; + } + + if (k > 0) { + + const uint8_t* bcopy = b; + uint8_t* padded = PaddedMatrixBData; + uint8_t* padded_end = padded + CountN; + + do { + padded[0] = bcopy[0]; + padded++; + bcopy++; + } while (padded < padded_end); + + __m128i BytesRow0 = __lsx_vld((__m128i*) & PaddedMatrixBData[0], 0); + __lsx_vinsgr2vr_d(BytesRow0, 0, 1); + + MlasGemmU8X8CopyPackBProcessLSX(D, BytesRow0, BitFlipVector, BitFlipVector, ColumnSums); + } + + __m128i tmp1, tmp2; + tmp1 = tmp2 = __lsx_vldi(0); + tmp1 = __lsx_vmaddwev_w_h(tmp1, ColumnSums[0], OnesWordBroadcast); + tmp2 = __lsx_vmaddwod_w_h(tmp2, ColumnSums[0], OnesWordBroadcast); + ColumnSums[0]= __lsx_vadd_w(tmp1, tmp2); + tmp1 = tmp2 = __lsx_vldi(0); + tmp1 = __lsx_vmaddwev_w_h(tmp1, ColumnSums[1], OnesWordBroadcast); + tmp2 = __lsx_vmaddwod_w_h(tmp2, ColumnSums[1], OnesWordBroadcast); + ColumnSums[1]= __lsx_vadd_w(tmp1, tmp2); + + __lsx_vst(ColumnSums[0], (__m128i*) & ColumnSumBuffer[0], 0); + __lsx_vst(ColumnSums[1], (__m128i*) & ColumnSumBuffer[4], 0); + } +} + +MLAS_FORCEINLINE +void +MlasGemmU8X8MultiplyAccumulateRowLSX( + __m128i ABroadcast, + const int16_t* B, + __m128i Accumulators[2] +) +{ + __m128i BElements0 = __lsx_vld((__m128i*) & B[0], 0); + __m128i 
BElements1 = __lsx_vld((__m128i*) & B[8], 0); + + __m128i tmp1, tmp2; + tmp1 = tmp2 = __lsx_vldi(0); + tmp1 = __lsx_vmaddwev_w_h(tmp1, BElements0, ABroadcast); + tmp2 = __lsx_vmaddwod_w_h(tmp2, BElements0, ABroadcast); + Accumulators[0] = __lsx_vadd_w(Accumulators[0], __lsx_vadd_w(tmp1, tmp2)); + tmp1 = tmp2 = __lsx_vldi(0); + tmp1 = __lsx_vmaddwev_w_h(tmp1, BElements1, ABroadcast); + tmp2 = __lsx_vmaddwod_w_h(tmp2, BElements1, ABroadcast); + Accumulators[1] = __lsx_vadd_w(Accumulators[1], __lsx_vadd_w(tmp1, tmp2)); +} + +template<> +size_t +MlasGemmQuantKernel( + const MLAS_GEMM_U8X8_KERNEL_LSX::PackedAType* A, + const MLAS_GEMM_U8X8_KERNEL_LSX::PackedBType* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumBuffer, + const int32_t* ColumnSumBuffer, + const int32_t* ZeroPointB, + bool ZeroMode + ) +{ + MLAS_UNREFERENCED_PARAMETER(CountM); + MLAS_UNREFERENCED_PARAMETER(ldc); + + while (CountN > 0) { + + __m128i Accumulators[2]; + + // + // Initialize the accumulators with the row and column sums. + // + + int32_t RowSumValue = RowSumBuffer[0]; + + if (ZeroPointB != nullptr) { + + int32_t ScaledRowSumBuffer[8]; + + for (size_t i = 0; i < 8; i++) { + ScaledRowSumBuffer[i] = RowSumValue * ZeroPointB[i]; + } + + ZeroPointB += 8; + + Accumulators[0] = __lsx_vld((__m128i*) & ScaledRowSumBuffer[0], 0); + Accumulators[1] = __lsx_vld((__m128i*) & ScaledRowSumBuffer[4], 0); + + } + else { + + Accumulators[0] = __lsx_vreplgr2vr_w(RowSumValue); + Accumulators[1] = Accumulators[0]; + } + + Accumulators[0] = __lsx_vadd_w(Accumulators[0], __lsx_vld((const __m128i*) & ColumnSumBuffer[0], 0)); + Accumulators[1] = __lsx_vadd_w(Accumulators[1], __lsx_vld((const __m128i*) & ColumnSumBuffer[4], 0)); + ColumnSumBuffer += 8; + + // + // Broadcast each pair of 16-bit values from the matrix A and multiply + // with the pair of 16-bit values from matrix B, and add the 32-bit + // intermediate into the accumulator registers. + // + + const int16_t* a = A; + size_t k = PackedCountK; + + while (k >= 4) { + + __m128i AElements = __lsx_vld((__m128i*)a, 0); + __m128i ABroadcast; + + ABroadcast = __lsx_vreplvei_w(AElements, 0); + MlasGemmU8X8MultiplyAccumulateRowLSX(ABroadcast, &B[0], Accumulators); + + ABroadcast = __lsx_vreplvei_w(AElements, 1); + MlasGemmU8X8MultiplyAccumulateRowLSX(ABroadcast, &B[16], Accumulators); + + ABroadcast = __lsx_vreplvei_w(AElements, 2); + MlasGemmU8X8MultiplyAccumulateRowLSX(ABroadcast, &B[32], Accumulators); + + ABroadcast = __lsx_vreplvei_w(AElements, 3); + MlasGemmU8X8MultiplyAccumulateRowLSX(ABroadcast, &B[48], Accumulators); + + a += 4 * 2; + B += 4 * 16; + k -= 4; + } + + while (k > 0) { + + __m128i ABroadcast = __lsx_vldrepl_w((int32_t*)a, 0); + MlasGemmU8X8MultiplyAccumulateRowLSX(ABroadcast, &B[0], Accumulators); + + a += 2; + B += 16; + k -= 1; + } + + // + // Output the accumulator block after optionally accumulating the values + // from matrix C. + // + + if (CountN >= 8) { + + if (!ZeroMode) { + Accumulators[0] = __lsx_vadd_w(Accumulators[0], __lsx_vld((__m128i*) & C[0], 0)); + Accumulators[1] = __lsx_vadd_w(Accumulators[1], __lsx_vld((__m128i*) & C[4], 0)); + } + + __lsx_vst(Accumulators[0], (__m128i*) & C[0], 0); + __lsx_vst(Accumulators[1], (__m128i*) & C[4], 0); + + C += 8; + CountN -= 8; + + } + else { + + // + // Output the remaining partial output block. 
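// For reference (a sketch of the underlying algebra, not part of this patch):
// the row-sum/column-sum initialization at the top of this kernel follows
// from expanding the quantized dot product,
//
//   sum_k (a[k] - za) * (b[k] - zb)
//     = sum_k a[k]*b[k] - zb * sum_k a[k] - za * sum_k b[k] + K*za*zb
//
// The inner loop accumulates only sum_k a[k]*b[k]; the remaining correction
// terms are pre-folded into RowSumBuffer, ColumnSumBuffer and ZeroPointB by
// the shared qgemm.h driver, which is why each accumulator starts from
// RowSumValue * ZeroPointB[i] (or plain RowSumValue) plus ColumnSumBuffer[i]
// rather than from zero.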
+ // + + if ((CountN & 4) != 0) { + + if (!ZeroMode) { + Accumulators[0] = __lsx_vadd_w(Accumulators[0], __lsx_vld((__m128i*) & C[0], 0)); + } + + __lsx_vst(Accumulators[0], (__m128i*) & C[0], 0); + C += 4; + + Accumulators[0] = Accumulators[1]; + } + + if ((CountN & 2) != 0) { + + if (!ZeroMode) { + Accumulators[0] = __lsx_vadd_w(Accumulators[0], __lsx_vinsgr2vr_d(__lsx_vld((__m128i*) & C[0], 0), 0, 1)); + } + + *((uint64_t *)&C[0]) = __lsx_vpickve2gr_d(Accumulators[0], 0); + C += 2; + + Accumulators[0] = __lsx_vshuf4i_w(Accumulators[0], 0xee); + } + + if ((CountN & 1) != 0) { + + int32_t AccumulatorValue = __lsx_vpickve2gr_w(Accumulators[0], 0); + + if (!ZeroMode) { + AccumulatorValue += C[0]; + } + + C[0] = AccumulatorValue; + } + + CountN = 0; + } + } + + return 1; +} + +const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchLSX = { + MlasGemmQuantOperation, + nullptr, + nullptr, + MLAS_GEMM_U8X8_KERNEL_LSX::PackedK, + 0, + 1 // aLSXmbly kernel M stride +}; diff --git a/onnxruntime/core/mlas/lib/qladd.cpp b/onnxruntime/core/mlas/lib/qladd.cpp index 971ea0161d7af..5dafa17c2ae66 100644 --- a/onnxruntime/core/mlas/lib/qladd.cpp +++ b/onnxruntime/core/mlas/lib/qladd.cpp @@ -552,6 +552,119 @@ MlasQLinearAddKernelHelper( InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N); } } +#elif defined(MLAS_LSX_INTRINSICS) + +template +static +void +MlasQLinearAddKernelHelper( + const DataType* InputA, + float ScaleA, + int32_t ZeroPointA, + const DataType* InputB, + float ScaleB, + int32_t ZeroPointB, + float ScaleC, + int32_t ZeroPointC, + DataType* OutputC, + size_t N + ) +{ + const float ScaleRatio_AC = ScaleA / ScaleC; + const float ScaleRatio_BC = ScaleB / ScaleC; + const auto VectorScaleRatio_AC = MlasBroadcastFloat32x4(ScaleRatio_AC); + const auto VectorScaleRatio_BC = MlasBroadcastFloat32x4(ScaleRatio_BC); + auto VectorFixedPart = MlasBroadcastFloat32x4((float)ZeroPointC - (ScaleRatio_AC * ZeroPointA + ScaleRatio_BC * ZeroPointB)); + + MLAS_FLOAT32X4 va_lo, va_hi, vb_lo, vb_hi; + if (IsScalarB) { + float tmp_f = (float)*InputB; + uint32_t *tmp_p = (uint32_t *)&tmp_f; + vb_lo = MlasReinterpretAsFloat32x4(__lsx_vreplgr2vr_w(*tmp_p)); + VectorFixedPart = __lsx_vfmadd_s(vb_lo, VectorScaleRatio_BC, VectorFixedPart); + } + + __m128i tmp, tmp1; + + while (N >= 8) { + const auto va_low_half = __lsx_vinsgr2vr_d(__lsx_vld((const MLAS_INT32X4*)InputA, 0), 0 ,1); + const auto va_i16x8 = __lsx_vilvl_b(va_low_half, va_low_half); + InputA += 8; + va_lo = __lsx_vffint_s_w(MlasShiftRightInt32(__lsx_vilvl_h(va_i16x8, va_i16x8), 24)); + va_hi = __lsx_vffint_s_w(MlasShiftRightInt32(__lsx_vilvh_h(va_i16x8, va_i16x8), 24)); + + if (!IsScalarB) { + const auto vb_low_half = __lsx_vinsgr2vr_d(__lsx_vld((const MLAS_INT32X4*)InputB, 0), 0 ,1); + const auto vb_i16x8 = __lsx_vilvl_b(vb_low_half, vb_low_half); + InputB += 8; + vb_lo = __lsx_vffint_s_w(MlasShiftRightInt32(__lsx_vilvl_h(vb_i16x8, vb_i16x8), 24)); + vb_hi = __lsx_vffint_s_w(MlasShiftRightInt32(__lsx_vilvh_h(vb_i16x8, vb_i16x8), 24)); + } + + MLAS_INT32X4 r_lo, r_hi; + if (IsScalarB) { + r_lo = __lsx_vftint_w_s(__lsx_vfmadd_s(va_lo, VectorScaleRatio_AC, VectorFixedPart)); + r_hi = __lsx_vftint_w_s(__lsx_vfmadd_s(va_hi, VectorScaleRatio_AC, VectorFixedPart)); + } else { + r_lo = __lsx_vftint_w_s(__lsx_vfadd_s(__lsx_vfmadd_s(va_lo, VectorScaleRatio_AC, VectorFixedPart), __lsx_vfmul_s(vb_lo, VectorScaleRatio_BC))); + r_hi = __lsx_vftint_w_s(__lsx_vfadd_s(__lsx_vfmadd_s(va_hi, VectorScaleRatio_AC, VectorFixedPart), 
__lsx_vfmul_s(vb_hi, VectorScaleRatio_BC))); + } + tmp = __lsx_vsat_w(r_lo, 15); + tmp1 = __lsx_vsat_w(r_hi, 15); + const auto vc_i16x8 = __lsx_vpickev_h(tmp1, tmp); + + MLAS_INT32X4 vc = MlasPackS16_128(vc_i16x8, vc_i16x8); + + N -= 8; + __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((MLAS_INT32X4*)OutputC, 0), __lsx_vpickve2gr_d(vc, 0), 0), (MLAS_INT32X4*)OutputC, 0); + OutputC += 8; + } + + if (N > 0) { + uint8_t TailData[8] = { 0 }; + + MlasCopyTailBytes(TailData, (const uint8_t*)InputA, N); + const auto va_low_half = __lsx_vinsgr2vr_d(__lsx_vld((const MLAS_INT32X4*)TailData, 0), 0 ,1); + const auto va_i16x8 = __lsx_vilvl_b(va_low_half, va_low_half); + va_lo = __lsx_vffint_s_w(MlasShiftRightInt32(__lsx_vilvl_h(va_i16x8, va_i16x8), 24)); + va_hi = __lsx_vffint_s_w(MlasShiftRightInt32(__lsx_vilvh_h(va_i16x8, va_i16x8), 24)); + + if (!IsScalarB) { + MlasCopyTailBytes(TailData, (const uint8_t*)InputB, N); + const auto vb_low_half = __lsx_vinsgr2vr_d(__lsx_vld((const MLAS_INT32X4*)TailData, 0), 0 ,1); + const auto vb_i16x8 = __lsx_vilvl_b(vb_low_half, vb_low_half); + vb_lo = __lsx_vffint_s_w(MlasShiftRightInt32(__lsx_vilvl_h(vb_i16x8, vb_i16x8), 24)); + vb_hi = __lsx_vffint_s_w(MlasShiftRightInt32(__lsx_vilvh_h(vb_i16x8, vb_i16x8), 24)); + } + + MLAS_INT32X4 r_lo, r_hi; + if (IsScalarB) { + r_lo = __lsx_vftint_w_s(__lsx_vfmadd_s(va_lo, VectorScaleRatio_AC, VectorFixedPart)); + r_hi = __lsx_vftint_w_s(__lsx_vfmadd_s(va_hi, VectorScaleRatio_AC, VectorFixedPart)); + } else { + r_lo = __lsx_vftint_w_s(__lsx_vfadd_s(__lsx_vfmadd_s(va_lo, VectorScaleRatio_AC, VectorFixedPart), __lsx_vfmul_s(vb_lo, VectorScaleRatio_BC))); + r_hi = __lsx_vftint_w_s(__lsx_vfadd_s(__lsx_vfmadd_s(va_hi, VectorScaleRatio_AC, VectorFixedPart), __lsx_vfmul_s(vb_hi, VectorScaleRatio_BC))); + } + tmp = __lsx_vsat_w(r_lo, 15); + tmp1 = __lsx_vsat_w(r_hi, 15); + const auto vc_i16x8 = __lsx_vpickev_h(tmp1, tmp); + + MLAS_INT32X4 vc = MlasPackS16_128(vc_i16x8, vc_i16x8); + + if (N & 4) { + __lsx_vstelm_w(vc, (int*)OutputC, 0, 0); + N -= 4; + OutputC += 4; + vc = __lsx_vshuf4i_w(vc, 0x39); //_MM_SHUFFLE(0, 3, 2, 1) + } + + uint32_t PackedValueC = (uint32_t)__lsx_vpickve2gr_w(vc, 0); + for (size_t i = 0; i < N; ++i) { + *((uint8_t*)OutputC + i) = (uint8_t)PackedValueC; + PackedValueC >>= 8; + } + } +} #else template diff --git a/onnxruntime/core/mlas/lib/qladd.h b/onnxruntime/core/mlas/lib/qladd.h index 8c05a6185324a..94568941a5660 100644 --- a/onnxruntime/core/mlas/lib/qladd.h +++ b/onnxruntime/core/mlas/lib/qladd.h @@ -463,5 +463,132 @@ MlasPackS16_128( { return reinterpret_cast(vec_packs(a, b)); } +#elif defined(MLAS_LSX_INTRINSICS) +#define LSX_DBG 1 +template +MLAS_FORCEINLINE +MLAS_INT32X4 +MlasShiftRightInt32( + MLAS_INT32X4 v, + int imm + ); + +template<> +MLAS_FORCEINLINE +MLAS_INT32X4 +MlasShiftRightInt32( + MLAS_INT32X4 v, + int imm + ) +{ +#if LSX_DBG + MLAS_INT32X4 imm_v = __lsx_vreplgr2vr_w(imm); + return __lsx_vsra_w(v, imm_v); +#else + return __lsx_vsrai_w(v, imm); +#endif +} + +template<> +MLAS_FORCEINLINE +MLAS_INT32X4 +MlasShiftRightInt32( + MLAS_INT32X4 v, + int imm + ) +{ +#if LSX_DBG + MLAS_INT32X4 imm_v = __lsx_vreplgr2vr_w(imm); + return __lsx_vsrl_w(v, imm_v); +#else + return __lsx_vsrli_w(v, imm); +#endif +} + +template +MLAS_FORCEINLINE +MLAS_INT32X4 +MlasShiftRightInt16( + MLAS_INT32X4 v, + int imm + ); + +template<> +MLAS_FORCEINLINE +MLAS_INT32X4 +MlasShiftRightInt16( + MLAS_INT32X4 v, + int imm + ) +{ +#if LSX_DBG + MLAS_INT32X4 imm_v = __lsx_vreplgr2vr_h(imm); + return __lsx_vsra_h(v, imm_v); +#else + 
return __lsx_vsrai_h(v, imm); +#endif +} + +template<> +MLAS_FORCEINLINE +MLAS_INT32X4 +MlasShiftRightInt16( + MLAS_INT32X4 v, + int imm + ) +{ +#if LSX_DBG + MLAS_INT32X4 imm_v = __lsx_vreplgr2vr_h(imm); + return __lsx_vsrl_h(v, imm_v); +#else + return __lsx_vsrli_h(v, imm); +#endif +} + +template +MLAS_FORCEINLINE +MLAS_INT32X4 +MlasPackS16_128( + MLAS_INT32X4 a, + MLAS_INT32X4 b + ); + +template <> +MLAS_FORCEINLINE +MLAS_INT32X4 +MlasPackS16_128( + MLAS_INT32X4 a, + MLAS_INT32X4 b + ) +{ + // return _mm_packus_epi16(a, b); + __m128i zero = __lsx_vldi(0); + __m128i tmp, tmp2, tmp3; + + tmp = __lsx_vmax_h(zero, a); + tmp2 = __lsx_vsat_hu(tmp, 7); + + tmp = __lsx_vmax_h(zero, b); + tmp3 = __lsx_vsat_hu(tmp, 7); + return __lsx_vpickev_b(tmp3, tmp2); + +} + +template <> +MLAS_FORCEINLINE +MLAS_INT32X4 +MlasPackS16_128( + MLAS_INT32X4 a, + MLAS_INT32X4 b + ) +{ + // return _mm_packs_epi16(a, b); + __m128i tmp, tmp1; + + tmp = __lsx_vsat_h(a, 7); + tmp1 = __lsx_vsat_h(b, 7); + return __lsx_vpickev_b(tmp1, tmp); + +} #endif diff --git a/onnxruntime/core/mlas/lib/qlgavgpool.cpp b/onnxruntime/core/mlas/lib/qlgavgpool.cpp index 1c2be0a833a3e..e44d7ad25c446 100644 --- a/onnxruntime/core/mlas/lib/qlgavgpool.cpp +++ b/onnxruntime/core/mlas/lib/qlgavgpool.cpp @@ -689,6 +689,316 @@ MlasQLinearGlobalAveragePoolNhwcSingleBatch( Output_zero_point, 0, 0, 1, Channels); } +#elif defined(MLAS_LSX_INTRINSICS) + +template +void MLASCALL +MlasQLinearGlobalAveragePoolNchw( + const T8Bits* Input, + float ScaleInput, + int32_t ZeroPointInput, + T8Bits* Output, + float ScaleOutput, + int32_t ZeroPointOutput, + size_t Channels, + size_t ImageSize, + int32_t* AccumulateBuffer + ) +{ + float scale = CheckQLinearGlobalAveragePoolScaleAndSize(ScaleInput, ScaleOutput, ImageSize); + const int32_t bias[] = {-ZeroPointInput * static_cast(ImageSize), 0, 0, 0}; + const auto vbias = __lsx_vld((const __m128i*)&bias, 0); + const auto vzero = __lsx_vldi(0); + uint8_t buffer[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + + int32_t* sum_buffer = AccumulateBuffer; + for (size_t c = Channels; c > 0; c--) { + + __m128i vacc_lo = vbias; + __m128i vacc_hi = vzero; + auto Len = ImageSize; + for (; Len >= 32; Len -= 32) { + + const __m128i vi0 = __lsx_vld((const __m128i*)Input, 0); + __lsx_vinsgr2vr_d(vi0, 0, 1); + const __m128i vi1 = __lsx_vld((const __m128i*)(Input + 8), 0); + __lsx_vinsgr2vr_d(vi1, 0, 1); + const __m128i vi2 = __lsx_vld((const __m128i*)(Input + 16), 0); + __lsx_vinsgr2vr_d(vi2, 0, 1); + const __m128i vi3 = __lsx_vld((const __m128i*)(Input + 24), 0); + __lsx_vinsgr2vr_d(vi3, 0, 1); + + if constexpr (std::is_signed::value) { + + const __m128i vxi0 = __lsx_vsrai_h(__lsx_vilvl_b(vi0, vzero), 8); + const __m128i vxi1 = __lsx_vsrai_h(__lsx_vilvl_b(vi1, vzero), 8); + const __m128i vxi2 = __lsx_vsrai_h(__lsx_vilvl_b(vi2, vzero), 8); + const __m128i vxi3 = __lsx_vsrai_h(__lsx_vilvl_b(vi3, vzero), 8); + const __m128i vsum = __lsx_vadd_h(__lsx_vadd_h(vxi0, vxi1), + __lsx_vadd_h(vxi2, vxi3)); + vacc_lo = __lsx_vadd_w(vacc_lo, __lsx_vsrai_w(__lsx_vilvl_h(vsum, vzero), 16)); + vacc_hi = __lsx_vadd_w(vacc_hi, __lsx_vsrai_w(__lsx_vilvh_h(vsum, vzero), 16)); + } else { + + const __m128i vxi0 = __lsx_vilvl_b(vzero, vi0); + const __m128i vxi1 = __lsx_vilvl_b(vzero, vi1); + const __m128i vxi2 = __lsx_vilvl_b(vzero, vi2); + const __m128i vxi3 = __lsx_vilvl_b(vzero, vi3); + const __m128i vsum = __lsx_vadd_h(__lsx_vadd_h(vxi0, vxi1), + __lsx_vadd_h(vxi2, vxi3)); + vacc_lo = __lsx_vadd_w(vacc_lo, __lsx_vilvl_h(vzero, vsum)); + vacc_hi = 
__lsx_vadd_w(vacc_hi, __lsx_vilvh_h(vzero, vsum)); + } + + Input += 32; + } + for (; Len >= 8; Len -= 8) { + + if constexpr (std::is_signed::value) { + + const __m128i vsum = __lsx_vsrai_h(__lsx_vilvl_b(__lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)Input, 0), 0, 1), vzero), 8); + vacc_lo = __lsx_vadd_w(vacc_lo, __lsx_vsrai_w(__lsx_vilvl_h(vsum, vzero), 16)); + vacc_hi = __lsx_vadd_w(vacc_hi, __lsx_vsrai_w(__lsx_vilvh_h(vsum, vzero), 16)); + } else { + + const __m128i vsum = __lsx_vilvl_b(vzero, __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)Input, 0), 0, 1)); + vacc_lo = __lsx_vadd_w(vacc_lo, __lsx_vilvl_h(vzero, vsum)); + vacc_hi = __lsx_vadd_w(vacc_hi, __lsx_vilvh_h(vzero, vsum)); + } + + Input += 8; + } + if (Len > 0) { + + memcpy(buffer, Input, Len); + + if constexpr (std::is_signed::value) { + + const __m128i vsum = __lsx_vsrai_h(__lsx_vilvl_b(__lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)buffer, 0), 0, 1), vzero), 8); + vacc_lo = __lsx_vadd_w(vacc_lo, __lsx_vsrai_w(__lsx_vilvl_h(vsum, vzero), 16)); + vacc_hi = __lsx_vadd_w(vacc_hi, __lsx_vsrai_w(__lsx_vilvh_h(vsum, vzero), 16)); + } else { + + const __m128i vsum = __lsx_vilvl_b(vzero, __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)buffer, 0), 0, 1)); + vacc_lo = __lsx_vadd_w(vacc_lo, __lsx_vilvl_h(vzero, vsum)); + vacc_hi = __lsx_vadd_w(vacc_hi, __lsx_vilvh_h(vzero, vsum)); + } + + Input += Len; + } + + __m128i vacc = __lsx_vadd_w(vacc_lo, vacc_hi); // [ D C | B A ] + __m128i vshuf = __lsx_vshuf4i_w(vacc, 0xb1); // [ C D | A B ] _MM_SHUFFLE(2, 3, 0, 1) + __m128i vsums = __lsx_vadd_w(vacc, vshuf); // [ D+C C+D | B+A A+B ] + vshuf = __lsx_vshuf4i_w(vsums, 0x4e); // [ B+A A+B | D+C C+D ] _MM_SHUFFLE(1, 0, 3, 2) + vsums = __lsx_vadd_w(vsums, vshuf); + __lsx_vstelm_w(vsums, sum_buffer++, 0 , 0); + } + + MlasRequantizeOutput(AccumulateBuffer, Channels, Output, Channels, nullptr, &scale, false, + static_cast(ZeroPointOutput), 0, 0, 1, Channels); +} + +template +MLAS_FORCEINLINE +void +MlasQLinearGlobalAveragePoolNhwcSingleBatch( + const T8Bits* Input, + T8Bits* Output, + const T8Bits* LastOf8, + size_t ImageSize, + size_t Channels, + size_t Stride, + int32_t Bias, + float Scale, + T8Bits Output_zero_point, + int32_t* AccumulateBuffer, + const T8Bits* ZeroBuffer + ) +{ + + constexpr size_t PixelsPerIteration = 7; +#define LOAD_FULL_CHANNELS() \ + const __m128i vi0 = __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)i0, 0), 0 , 1); \ + i0 += 8; \ + const __m128i vi1 = __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)i1, 0), 0 , 1); \ + i1 += 8; \ + const __m128i vi2 = __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)i2, 0), 0 , 1); \ + i2 += 8; \ + const __m128i vi3 = __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)i3, 0), 0 , 1); \ + i3 += 8; \ + const __m128i vi4 = __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)i4, 0), 0 , 1); \ + i4 += 8; \ + const __m128i vi5 = __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)i5, 0), 0 , 1); \ + i5 += 8; \ + const __m128i vi6 = __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)i6, 0), 0 , 1); \ + i6 += 8 + +#define CALCULATE_ACCUMULATE_VECTORS() \ + __m128i vacc_lo = finish_one_pass ? __lsx_vld((__m128i*)acc, 0) : vbias; \ + __m128i vacc_hi = finish_one_pass ? 
__lsx_vld(((__m128i*)acc) + 1, 0) : vbias; \ + __m128i vxi0; \ + __m128i vxi1; \ + __m128i vxi2; \ + __m128i vxi3; \ + __m128i vxi4; \ + __m128i vxi5; \ + __m128i vxi6; \ + if constexpr (std::is_signed::value) { \ + vxi0 = __lsx_vsrai_h(__lsx_vilvl_b(vi0, vzero), 8); \ + vxi1 = __lsx_vsrai_h(__lsx_vilvl_b(vi1, vzero), 8); \ + vxi2 = __lsx_vsrai_h(__lsx_vilvl_b(vi2, vzero), 8); \ + vxi3 = __lsx_vsrai_h(__lsx_vilvl_b(vi3, vzero), 8); \ + vxi4 = __lsx_vsrai_h(__lsx_vilvl_b(vi4, vzero), 8); \ + vxi5 = __lsx_vsrai_h(__lsx_vilvl_b(vi5, vzero), 8); \ + vxi6 = __lsx_vsrai_h(__lsx_vilvl_b(vi6, vzero), 8); \ + } else { \ + vxi0 = __lsx_vilvl_b(vzero, vi0); \ + vxi1 = __lsx_vilvl_b(vzero, vi1); \ + vxi2 = __lsx_vilvl_b(vzero, vi2); \ + vxi3 = __lsx_vilvl_b(vzero, vi3); \ + vxi4 = __lsx_vilvl_b(vzero, vi4); \ + vxi5 = __lsx_vilvl_b(vzero, vi5); \ + vxi6 = __lsx_vilvl_b(vzero, vi6); \ + } \ + const __m128i vsum01 = __lsx_vadd_h(vxi0, vxi1); \ + const __m128i vsum23 = __lsx_vadd_h(vxi2, vxi3); \ + const __m128i vsum45 = __lsx_vadd_h(vxi4, vxi5); \ + const __m128i vsum016 = __lsx_vadd_h(vsum01, vxi6); \ + const __m128i vsum2345 = __lsx_vadd_h(vsum23, vsum45); \ + const __m128i vsum = __lsx_vadd_h(vsum016, vsum2345); \ + if constexpr (std::is_signed::value) { \ + vacc_lo = __lsx_vadd_w(vacc_lo, __lsx_vsrai_w(__lsx_vilvl_h(vsum, vzero), 16)); \ + vacc_hi = __lsx_vadd_w(vacc_hi, __lsx_vsrai_w(__lsx_vilvh_h(vsum, vzero), 16)); \ + } else { \ + vacc_lo = __lsx_vadd_w(vacc_lo, __lsx_vilvl_h(vzero, vsum)); \ + vacc_hi = __lsx_vadd_w(vacc_hi, __lsx_vilvh_h(vzero, vsum)); \ + } + + + T8Bits tail[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + bool finish_one_pass = false; + const __m128i vbias = __lsx_vreplgr2vr_w(Bias); + const __m128i vzero = __lsx_vldi(0); + size_t step_next_group = PixelsPerIteration * Stride - (Channels & ~size_t{7}); + + const T8Bits* i0 = Input; + const T8Bits* i1 = i0 + Stride; + const T8Bits* i2 = i1 + Stride; + const T8Bits* i3 = i2 + Stride; + const T8Bits* i4 = i0 + Stride * 4; + const T8Bits* i5 = i4 + Stride; + const T8Bits* i6 = i5 + Stride; + + for (; ImageSize > PixelsPerIteration; ImageSize -= PixelsPerIteration) { + + int32_t* acc = AccumulateBuffer; + size_t c = Channels; + for (; c >= 8; c -= 8) { + + LOAD_FULL_CHANNELS(); + + CALCULATE_ACCUMULATE_VECTORS(); + + __lsx_vst(vacc_lo, (__m128i*)acc, 0); + __lsx_vst(vacc_hi, ((__m128i*)acc) + 1, 0); + acc += 8; + } + if (c > 0) { + const __m128i vi0 = + __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)(i0 >= LastOf8 ? memcpy(tail, i0, c) : i0), 0), 0 ,1); + const __m128i vi1 = + __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)(i1 >= LastOf8 ? memcpy(tail, i1, c) : i1), 0), 0 ,1); + const __m128i vi2 = + __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)(i2 >= LastOf8 ? memcpy(tail, i2, c) : i2), 0), 0 ,1); + const __m128i vi3 = + __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)(i3 >= LastOf8 ? memcpy(tail, i3, c) : i3), 0), 0 ,1); + const __m128i vi4 = + __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)(i4 >= LastOf8 ? memcpy(tail, i4, c) : i4), 0), 0 ,1); + const __m128i vi5 = + __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)(i5 >= LastOf8 ? memcpy(tail, i5, c) : i5), 0), 0 ,1); + const __m128i vi6 = + __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)(i6 >= LastOf8 ? 
memcpy(tail, i6, c) : i6), 0), 0 ,1); + + CALCULATE_ACCUMULATE_VECTORS(); + + __lsx_vst(vacc_lo, (__m128i*)acc, 0); + __lsx_vst(vacc_hi, ((__m128i*)acc) + 1, 0); + } + finish_one_pass = true; + + i0 += step_next_group; + i1 += step_next_group; + i2 += step_next_group; + i3 += step_next_group; + i4 += step_next_group; + i5 += step_next_group; + i6 += step_next_group; + } + + if (ImageSize > 0) { + switch (ImageSize) { + case 1: + i1 = ZeroBuffer; + [[fallthrough]]; + case 2: + i2 = ZeroBuffer; + [[fallthrough]]; + case 3: + i3 = ZeroBuffer; + [[fallthrough]]; + case 4: + i4 = ZeroBuffer; + [[fallthrough]]; + case 5: + i5 = ZeroBuffer; + [[fallthrough]]; + case 6: + i6 = ZeroBuffer; + [[fallthrough]]; + default: + break; + } + + int32_t* acc = AccumulateBuffer; + size_t c = Channels; + for (; c >= 8; c -= 8) { + + LOAD_FULL_CHANNELS(); + + CALCULATE_ACCUMULATE_VECTORS(); + + __lsx_vst(vacc_lo, (__m128i*)acc, 0); + __lsx_vst(vacc_hi, ((__m128i*)acc) + 1, 0); + acc += 8; + } + + if (c > 0) { + const __m128i vi0 = + __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)(i0 >= LastOf8 ? memcpy(tail, i0, c) : i0), 0), 0 ,1); + const __m128i vi1 = __lsx_vinsgr2vr_d(__lsx_vld( + (const __m128i*)(1 < ImageSize && i1 >= LastOf8 ? memcpy(tail, i1, c) : i1), 0), 0, 1); + const __m128i vi2 = __lsx_vinsgr2vr_d(__lsx_vld( + (const __m128i*)(2 < ImageSize && i2 >= LastOf8 ? memcpy(tail, i2, c) : i2), 0), 0, 1); + const __m128i vi3 = __lsx_vinsgr2vr_d(__lsx_vld( + (const __m128i*)(3 < ImageSize && i3 >= LastOf8 ? memcpy(tail, i3, c) : i3), 0), 0, 1); + const __m128i vi4 = __lsx_vinsgr2vr_d(__lsx_vld( + (const __m128i*)(4 < ImageSize && i4 >= LastOf8 ? memcpy(tail, i4, c) : i4), 0), 0, 1); + const __m128i vi5 = __lsx_vinsgr2vr_d(__lsx_vld( + (const __m128i*)(5 < ImageSize && i5 >= LastOf8 ? memcpy(tail, i5, c) : i5), 0), 0, 1); + const __m128i vi6 = __lsx_vinsgr2vr_d(__lsx_vld( + (const __m128i*)(6 < ImageSize && i6 >= LastOf8 ? 
memcpy(tail, i6, c) : i6), 0), 0, 1); + + CALCULATE_ACCUMULATE_VECTORS(); + + __lsx_vst(vacc_lo, (__m128i*)acc, 0); + __lsx_vst(vacc_hi, ((__m128i*)acc) + 1, 0); + } + } + MlasRequantizeOutput(AccumulateBuffer, Channels, Output, Channels, nullptr, &Scale, false, + Output_zero_point, 0, 0, 1, Channels); +} + #else // Pure C++ Implementation @@ -771,7 +1081,7 @@ MlasQLinearGlobalAveragePoolNhwc( #endif -#if defined(MLAS_NEON_INTRINSICS) || defined(MLAS_SSE2_INTRINSICS) +#if defined(MLAS_NEON_INTRINSICS) || defined(MLAS_SSE2_INTRINSICS) || defined(MLAS_LSX_INTRINSICS) template void diff --git a/onnxruntime/core/mlas/lib/qlmul.cpp b/onnxruntime/core/mlas/lib/qlmul.cpp index 4b8537f2b378f..38818e1190d21 100644 --- a/onnxruntime/core/mlas/lib/qlmul.cpp +++ b/onnxruntime/core/mlas/lib/qlmul.cpp @@ -377,6 +377,170 @@ MlasQLinearMulKernel( MLAS_UNREFERENCED_PARAMETER(ValueBVector); } +#elif defined(MLAS_LSX_INTRINSICS) + +template +MLAS_FORCEINLINE +static +__m128i +MlasExtendToS16( + __m128i Int8Vector, + __m128i ZeroVector + ); + +template <> +MLAS_FORCEINLINE +__m128i +MlasExtendToS16( + __m128i Int8Vector, + __m128i ZeroVector + ) +{ + return __lsx_vilvl_b(ZeroVector, Int8Vector); +} + +template <> +MLAS_FORCEINLINE +__m128i +MlasExtendToS16( + __m128i Int8Vector, + __m128i ZeroVector + ) +{ + return __lsx_vilvh_b(ZeroVector, Int8Vector); +} + +template <> +MLAS_FORCEINLINE +__m128i +MlasExtendToS16( + __m128i Int8Vector, + __m128i ZeroVector + ) +{ + MLAS_UNREFERENCED_PARAMETER(ZeroVector); + return __lsx_vsrai_h(__lsx_vilvl_b(Int8Vector, Int8Vector), 8); +} + +template <> +MLAS_FORCEINLINE +__m128i +MlasExtendToS16( + __m128i Int8Vector, + __m128i ZeroVector + ) +{ + MLAS_UNREFERENCED_PARAMETER(ZeroVector); + return __lsx_vsrai_h(__lsx_vilvh_b(Int8Vector, Int8Vector), 8); +} + +template +MLAS_FORCEINLINE +static +__m128i +MlasExtendToS16Debias( + __m128i Int8Vector, + __m128i ZeroVector, + __m128i VectorBias + ) +{ + return __lsx_vsub_h(MlasExtendToS16(Int8Vector, ZeroVector), VectorBias); +} + +MLAS_FORCEINLINE +static +__m128i +MlasQLinearMulVectorS16( + __m128i va_s16x8, + __m128i vb_s16x8, + __m128 VectorScaleRatio, + __m128 VectorZeroPointC + ) +{ + __m128i tmp, tmp1; + + const auto ab_lo = __lsx_vmul_h(va_s16x8, vb_s16x8); + const auto ab_hi = __lsx_vmuh_h(va_s16x8, vb_s16x8); + auto r_lo = __lsx_vilvl_h(ab_hi, ab_lo); + auto r_hi = __lsx_vilvh_h(ab_hi, ab_lo); + r_lo = __lsx_vftint_w_s(__lsx_vfmadd_s(__lsx_vffint_s_w(r_lo), VectorScaleRatio, VectorZeroPointC)); + r_hi = __lsx_vftint_w_s(__lsx_vfmadd_s(__lsx_vffint_s_w(r_hi), VectorScaleRatio, VectorZeroPointC)); + + tmp = __lsx_vsat_w(r_lo, 15); + tmp1 = __lsx_vsat_w(r_hi, 15); + return __lsx_vpickev_h(tmp1, tmp); +} + +template +static +void +MlasQLinearMulKernel( + const DataType* InputA, + float ScaleA, + int32_t ZeroPointA, + const DataType* InputB, + float ScaleB, + int32_t ZeroPointB, + float ScaleC, + int32_t ZeroPointC, + DataType* OutputC, + size_t N + ) +{ + const auto VectorZeroPointA = __lsx_vreplgr2vr_h((int16_t)ZeroPointA); + const auto VectorZeroPointB = __lsx_vreplgr2vr_h((int16_t)ZeroPointB); + const auto VectorZeroPointC = MlasBroadcastFloat32x4((float)ZeroPointC); + const auto VectorScaleRatio = MlasBroadcastFloat32x4(ScaleA * ScaleB / ScaleC); + const auto ZeroVector = __lsx_vldi(0); + + uint8_t TailDataA[16] = { 0 }; + uint8_t TailDataB[16] = { 0 }; + __m128i vb_lo_s16x8, vb_hi_s16x8; + + if (IsScalarB) { + vb_lo_s16x8 = __lsx_vsub_h(__lsx_vreplgr2vr_h((int16_t)*InputB), VectorZeroPointB); + vb_hi_s16x8 = 
vb_lo_s16x8; + } + + while (N > 0) { + if (N < 16) { + MlasCopyTailBytes(TailDataA, (const uint8_t*)InputA, N); + InputA = (const DataType*)TailDataA; + if (!IsScalarB) { + MlasCopyTailBytes(TailDataB, (const uint8_t*)InputB, N); + InputB = (const DataType*)TailDataB; + } + } + + const auto va_i8x16 = __lsx_vld((const MLAS_INT32X4*)InputA, 0); + InputA += 16; + const auto va_lo_s16x8 = MlasExtendToS16Debias(va_i8x16, ZeroVector, VectorZeroPointA); + const auto va_hi_s16x8 = MlasExtendToS16Debias(va_i8x16, ZeroVector, VectorZeroPointA); + + if (!IsScalarB) { + const auto vb_i8x16 = __lsx_vld((const MLAS_INT32X4*)InputB, 0); + InputB += 16; + vb_lo_s16x8 = MlasExtendToS16Debias(vb_i8x16, ZeroVector, VectorZeroPointB); + vb_hi_s16x8 = MlasExtendToS16Debias(vb_i8x16, ZeroVector, VectorZeroPointB); + } + + const auto vc_lo_s16x8 = MlasQLinearMulVectorS16(va_lo_s16x8, vb_lo_s16x8, VectorScaleRatio, VectorZeroPointC); + const auto vc_hi_s16x8 = MlasQLinearMulVectorS16(va_hi_s16x8, vb_hi_s16x8, VectorScaleRatio, VectorZeroPointC); + auto vc = MlasPackS16_128(vc_lo_s16x8, vc_hi_s16x8); + + if (N >= 16) { + __lsx_vst(vc, (__m128i*)OutputC, 0); + OutputC += 16; + N -= 16; + } else { + __lsx_vst(vc, (__m128i*)TailDataA, 0); + MlasCopyTailBytes((uint8_t*)OutputC, TailDataA, N); + N = 0; + } + } +} + + #else // Pure C++ implementation. diff --git a/onnxruntime/core/mlas/lib/quantize.cpp b/onnxruntime/core/mlas/lib/quantize.cpp index 133ad79594c55..ffecc2dbeff9e 100644 --- a/onnxruntime/core/mlas/lib/quantize.cpp +++ b/onnxruntime/core/mlas/lib/quantize.cpp @@ -20,7 +20,9 @@ Module Name: #include "mlasi.h" -#if defined(MLAS_NEON64_INTRINSICS) || defined(MLAS_SSE2_INTRINSICS) +#if defined(MLAS_NEON64_INTRINSICS) || defined(MLAS_SSE2_INTRINSICS) || \ + defined(MLAS_LSX_INTRINSICS) + #include // @@ -49,6 +51,9 @@ MlasQuantizeLinearVector( // is a NaN. FloatVector = vmaxnmq_f32(FloatVector, MinimumValueVector); FloatVector = vminnmq_f32(FloatVector, MaximumValueVector); +#elif defined(MLAS_LSX_INTRINSICS) + FloatVector = __lsx_vfmax_s(FloatVector, MinimumValueVector); + FloatVector = __lsx_vfmin_s(FloatVector, MaximumValueVector); #else // N.B. MINPS and MAXPS returns the value from the second vector if the // value from the first vector is a NaN. @@ -64,6 +69,9 @@ MlasQuantizeLinearVector( #if defined(MLAS_NEON64_INTRINSICS) auto IntegerVector = vcvtnq_s32_f32(FloatVector); IntegerVector = vaddq_s32(IntegerVector, ZeroPointVector); +#elif defined(MLAS_LSX_INTRINSICS) + auto IntegerVector = __lsx_vftint_w_s(FloatVector); + IntegerVector = __lsx_vadd_w(IntegerVector, ZeroPointVector); #else // N.B. Assumes MXCSR has been configured with the default rounding mode of // "round to nearest even". 
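The LSX additions to MlasQuantizeLinearVector above follow the same per-element recipe as the NEON and SSE2 paths: clamp in the float domain, convert to int32 with the default round-to-nearest mode, then add the zero point. As a reading aid, here is a scalar sketch of that computation; the helper name and signature are illustrative only (not part of the MLAS API), and it assumes the clamp bounds already have the zero point folded in, the way the requantize code later in this patch computes them from `std::numeric_limits<OutputType>::lowest() - ZeroPoint`.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// Hypothetical scalar equivalent of one lane of the vector code above.
int32_t QuantizeOneValue(float Value, float Scale, int32_t ZeroPoint,
                         float MinimumValue, float MaximumValue)
{
    float Scaled = Value / Scale;              // division happens earlier in the full routine
    Scaled = std::max(Scaled, MinimumValue);   // __lsx_vfmax_s
    Scaled = std::min(Scaled, MaximumValue);   // __lsx_vfmin_s
    // __lsx_vftint_w_s converts using the current floating-point rounding mode,
    // which is round-to-nearest-even by default; nearbyint() does the same here.
    return static_cast<int32_t>(std::nearbyint(Scaled)) + ZeroPoint;  // __lsx_vftint_w_s + __lsx_vadd_w
}
```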
@@ -213,6 +221,121 @@ MlasQuantizeLinearStoreSingleValue(
     vst1q_lane_s16(Output, vreinterpretq_s16_s32(IntegerVector), 0);
 }
 
+#elif defined(MLAS_LSX_INTRINSICS)
+template<>
+MLAS_FORCEINLINE
+MLAS_INT32X4
+MlasQuantizeLinearPackBytes<uint8_t>(
+    MLAS_INT32X4 integervector
+    )
+{
+
+    __m128i zero = __lsx_vldi(0);
+    __m128i tmp, tmp2;
+
+    tmp = __lsx_vmax_h(integervector, zero);
+    tmp2 = __lsx_vsat_hu(tmp, 7);
+
+    integervector = __lsx_vpickev_b(tmp2, tmp2);
+
+    tmp = __lsx_vmax_h(integervector, zero);
+    tmp2 = __lsx_vsat_hu(tmp, 7);
+
+    integervector = __lsx_vpickev_b(tmp2, tmp2);
+    return integervector;
+}
+
+template<>
+MLAS_FORCEINLINE
+MLAS_INT32X4
+MlasQuantizeLinearPackBytes<int8_t>(
+    MLAS_INT32X4 integervector
+    )
+{
+
+    __m128i tmp, tmp1;
+
+    tmp = __lsx_vsat_h(integervector, 7);
+    tmp1 = __lsx_vsat_h(integervector, 7);
+    integervector = __lsx_vpickev_b(tmp1, tmp);
+
+    tmp = __lsx_vsat_h(integervector, 7);
+    tmp1 = __lsx_vsat_h(integervector, 7);
+    integervector = __lsx_vpickev_b(tmp1, tmp);
+    return integervector;
+}
+
+template <typename OutputType>
+MLAS_FORCEINLINE
+void
+MlasQuantizeLinearStore4PackedValues(
+    MLAS_INT32X4 IntegerVector,
+    OutputType* Output
+    )
+{
+    // Copies the lower 4 packed elements of the vector into memory (Output).
+
+    if constexpr (std::is_same_v<OutputType, uint8_t> || std::is_same_v<OutputType, int8_t>) {
+        __lsx_vstelm_w(IntegerVector, reinterpret_cast<int32_t*>(Output), 0, 0);
+    } else {
+        static_assert(std::is_same_v<OutputType, uint16_t> || std::is_same_v<OutputType, int16_t>);
+
+        __lsx_vstelm_d(IntegerVector, reinterpret_cast<int64_t*>(Output), 0, 0);
+    }
+}
+
+template <typename OutputType>
+MLAS_FORCEINLINE
+void
+MlasQuantizeLinearStoreSingleValue(
+    MLAS_INT32X4 IntegerVector,
+    OutputType* Output
+    )
+{
+    static_assert(std::is_same_v<OutputType, uint8_t> ||
+                  std::is_same_v<OutputType, int8_t> ||
+                  std::is_same_v<OutputType, uint16_t> ||
+                  std::is_same_v<OutputType, int16_t>);
+
+    // Copies the lower element of the vector into memory (Output).
+    // Expects that the 32-bit element in lane 0 is already within the valid numerical
+    // range of the OutputType.
+    *Output = static_cast<OutputType>(__lsx_vpickve2gr_w(IntegerVector, 0));
+}
+
+template<>
+MLAS_FORCEINLINE
+MLAS_INT32X4
+MlasQuantizeLinearPackBytes<uint16_t>(
+    MLAS_INT32X4 IntegerVector
+    )
+{
+    __m128i zero = __lsx_vldi(0);
+    __m128i tmp, tmp2;
+
+    tmp = __lsx_vmax_w(IntegerVector, zero);
+    tmp2 = __lsx_vsat_wu(tmp, 15);
+
+    IntegerVector = __lsx_vpickev_h(tmp2, tmp2);
+    return IntegerVector;
+}
+
+template<>
+MLAS_FORCEINLINE
+MLAS_INT32X4
+MlasQuantizeLinearPackBytes<int16_t>(
+    MLAS_INT32X4 IntegerVector
+    )
+{
+    __m128i tmp, tmp1;
+
+    tmp = __lsx_vsat_w(IntegerVector, 15);
+    tmp1 = __lsx_vsat_w(IntegerVector, 15);
+    IntegerVector = __lsx_vpickev_h(tmp1, tmp);
+    return IntegerVector;
+}
 #else
 
 template<>
@@ -384,6 +507,8 @@ Return Value:
 
 #if defined(MLAS_NEON64_INTRINSICS)
         auto FloatVector = vld1q_dup_f32(Input + n);
+#elif defined(MLAS_LSX_INTRINSICS)
+        MLAS_FLOAT32X4 FloatVector = (MLAS_FLOAT32X4)__lsx_vldrepl_w(Input + n, 0);
 #else
         auto FloatVector = _mm_load_ss(Input + n);
 #endif
@@ -1362,6 +1487,286 @@ MlasRequantizeOutput(
     }
 }
 
+#elif defined(MLAS_LSX_INTRINSICS)
+
+template <typename OutputType>
+void
+MlasRequantizeOutput(
+    const int32_t* Input,
+    size_t InputLeadingDimension,
+    OutputType* Output,
+    size_t OutputLeadingDimension,
+    const int32_t* Bias,
+    const float* Scale,
+    bool PerColumnScale,
+    OutputType ZeroPoint,
+    size_t StartM,
+    size_t StartN,
+    size_t CountM,
+    size_t CountN
+    )
+{
+    // TODO: double-check that this clamping range matches the other targets.
+    float min_f = float(std::numeric_limits<OutputType>::lowest() - ZeroPoint);
+    float max_f = float(std::numeric_limits<OutputType>::max() - ZeroPoint);
+    const __m128 PerMatrixScaleVector = PerColumnScale ?
MlasReinterpretAsFloat32x4(__lsx_vldi(0)) : MlasReinterpretAsFloat32x4(__lsx_vldrepl_w(Scale, 0)); + const __m128 MinimumValueVector = MlasReinterpretAsFloat32x4(__lsx_vreplgr2vr_w( *((uint32_t*)&min_f))); + const __m128 MaximumValueVector = MlasReinterpretAsFloat32x4(__lsx_vreplgr2vr_w( *((uint32_t*)&max_f))); + const __m128i ZeroPointVector = __lsx_vreplgr2vr_w(ZeroPoint); + + if (nullptr != Bias) { + Bias += StartN; + } + if (PerColumnScale) { + Scale += StartN; + } + + Input += StartM * InputLeadingDimension + StartN; + Output += StartM * OutputLeadingDimension + StartN; + // + // Step through each row of the output matrix. + // + + while (CountM-- > 0) { + + const int32_t* bias = Bias; + const float* scale = PerColumnScale ? Scale : nullptr; + size_t n = CountN; + + auto* RowInput = Input; + auto* RowOutput = Output; + + // + // Process 16 columns of the matrices at a time. + // + + while (n >= 16) { + + // + // Load the input data and optionally add the per-column bias. + // + + __m128i IntegerVector0 = __lsx_vld((const __m128i*)&RowInput[0], 0); + __m128i IntegerVector1 = __lsx_vld((const __m128i*)&RowInput[4], 0); + __m128i IntegerVector2 = __lsx_vld((const __m128i*)&RowInput[8], 0); + __m128i IntegerVector3 = __lsx_vld((const __m128i*)&RowInput[12], 0); + RowInput += 16; + + if (bias != nullptr) { + IntegerVector0 = __lsx_vadd_w(IntegerVector0, __lsx_vld((const __m128i *)&bias[0], 0)); + IntegerVector1 = __lsx_vadd_w(IntegerVector1, __lsx_vld((const __m128i *)&bias[4], 0)); + IntegerVector2 = __lsx_vadd_w(IntegerVector2, __lsx_vld((const __m128i *)&bias[8], 0)); + IntegerVector3 = __lsx_vadd_w(IntegerVector3, __lsx_vld((const __m128i *)&bias[12], 0)); + bias += 16; + } + + // + // Convert to integer values to float and apply the per-tensor or + // per-column scaling. 
+ // + + __m128 FloatVector0 = __lsx_vffint_s_w(IntegerVector0); + __m128 FloatVector1 = __lsx_vffint_s_w(IntegerVector1); + __m128 FloatVector2 = __lsx_vffint_s_w(IntegerVector2); + __m128 FloatVector3 = __lsx_vffint_s_w(IntegerVector3); + + if (scale != nullptr) { + + FloatVector0 = __lsx_vfmul_s(FloatVector0, MlasReinterpretAsFloat32x4(__lsx_vld((__m128i *)&scale[0], 0))); + FloatVector1 = __lsx_vfmul_s(FloatVector1, MlasReinterpretAsFloat32x4(__lsx_vld((__m128i *)&scale[4], 0))); + FloatVector2 = __lsx_vfmul_s(FloatVector2, MlasReinterpretAsFloat32x4(__lsx_vld((__m128i *)&scale[8], 0))); + FloatVector3 = __lsx_vfmul_s(FloatVector3, MlasReinterpretAsFloat32x4(__lsx_vld((__m128i *)&scale[12], 0))); + scale += 16; + + } else { + + FloatVector0 = __lsx_vfmul_s(FloatVector0, PerMatrixScaleVector); + FloatVector1 = __lsx_vfmul_s(FloatVector1, PerMatrixScaleVector); + FloatVector2 = __lsx_vfmul_s(FloatVector2, PerMatrixScaleVector); + FloatVector3 = __lsx_vfmul_s(FloatVector3, PerMatrixScaleVector); + } + FloatVector0 = __lsx_vfmax_s(FloatVector0, MinimumValueVector); + FloatVector1 = __lsx_vfmax_s(FloatVector1, MinimumValueVector); + FloatVector2 = __lsx_vfmax_s(FloatVector2, MinimumValueVector); + FloatVector3 = __lsx_vfmax_s(FloatVector3, MinimumValueVector); + + FloatVector0 = __lsx_vfmin_s(FloatVector0, MaximumValueVector); + FloatVector1 = __lsx_vfmin_s(FloatVector1, MaximumValueVector); + FloatVector2 = __lsx_vfmin_s(FloatVector2, MaximumValueVector); + FloatVector3 = __lsx_vfmin_s(FloatVector3, MaximumValueVector); + + IntegerVector0 = __lsx_vftint_w_s(FloatVector0); + IntegerVector1 = __lsx_vftint_w_s(FloatVector1); + IntegerVector2 = __lsx_vftint_w_s(FloatVector2); + IntegerVector3 = __lsx_vftint_w_s(FloatVector3); + + IntegerVector0 = __lsx_vadd_w(IntegerVector0, ZeroPointVector); + IntegerVector1 = __lsx_vadd_w(IntegerVector1, ZeroPointVector); + IntegerVector2 = __lsx_vadd_w(IntegerVector2, ZeroPointVector); + IntegerVector3 = __lsx_vadd_w(IntegerVector3, ZeroPointVector); + + __m128i WordVector0; + __m128i WordVector1; + __m128i ByteVector; + + if (std::is_signed::value) { + + __m128i tmp, tmp1; + tmp = __lsx_vsat_w(IntegerVector0, 15); + tmp1 = __lsx_vsat_w(IntegerVector1, 15); + WordVector0 = __lsx_vpickev_h(tmp1, tmp); + + tmp = __lsx_vsat_w(IntegerVector2, 15); + tmp1 = __lsx_vsat_w(IntegerVector3, 15); + WordVector1 = __lsx_vpickev_h(tmp1, tmp); + + tmp = __lsx_vsat_h(WordVector0, 7); + tmp1 = __lsx_vsat_h(WordVector1, 7); + ByteVector = __lsx_vpickev_b(tmp1, tmp); + + + } else { + + __m128i zero = __lsx_vldi(0); + __m128i tmp, tmp2, tmp3; + + tmp = __lsx_vmax_h(IntegerVector0, zero); + tmp2 = __lsx_vsat_hu(tmp, 7); + + tmp = __lsx_vmax_h(IntegerVector1, zero); + tmp3 = __lsx_vsat_hu(tmp, 7); + WordVector0 = __lsx_vpickev_b(tmp3, tmp2); + + tmp = __lsx_vmax_h(IntegerVector2, zero); + tmp2 = __lsx_vsat_hu(tmp, 7); + + tmp = __lsx_vmax_h(IntegerVector3, zero); + tmp3 = __lsx_vsat_hu(tmp, 7); + WordVector1 = __lsx_vpickev_b(tmp3, tmp2); + + tmp = __lsx_vmax_h(WordVector0, zero); + tmp2 = __lsx_vsat_hu(tmp, 7); + + tmp = __lsx_vmax_h(WordVector1, zero); + tmp3 = __lsx_vsat_hu(tmp, 7); + ByteVector = __lsx_vpickev_b(tmp3, tmp2); + + } + + __lsx_vst(ByteVector, (__m128i*)RowOutput, 0); + RowOutput += 16; + + n -= 16; + } + + // + // Process the remaining columns of the matrices. + // + + while (n > 0) { + + // + // Load the input data and optionally add the per-column bias. 
+ // + + __m128i IntegerVector; + + if (n >= 4) { + + IntegerVector = __lsx_vld((const __m128i*)&RowInput[0], 0); + RowInput += 4; + + if (bias != nullptr) { + IntegerVector = __lsx_vadd_w(IntegerVector, __lsx_vld((const __m128i*)&bias[0], 0)); + bias += 4; + } + + } else { + + int32_t IntegerValue = *RowInput++; + + if (bias != nullptr) { + IntegerValue += *bias++; + } + IntegerVector = __lsx_vldrepl_w(&IntegerValue, 0); + } + + // + // Convert to integer values to float and apply the per-tensor or + // per-column scaling. + // + __m128 FloatVector = __lsx_vffint_s_w(IntegerVector); + __m128 ScaleVector; + + if (scale != nullptr) { + + if (n >= 4) { + ScaleVector = MlasReinterpretAsFloat32x4(__lsx_vld((__m128i *)scale, 0)); + scale += 4; + } else { + ScaleVector = (__m128)__lsx_vldrepl_w(scale, 0); + scale += 1; + } + + } else { + ScaleVector = PerMatrixScaleVector; + } + FloatVector = __lsx_vfmul_s(FloatVector, ScaleVector); + + FloatVector = __lsx_vfmax_s(FloatVector, MinimumValueVector); + FloatVector = __lsx_vfmin_s(FloatVector, MaximumValueVector); + + IntegerVector = __lsx_vftint_w_s(FloatVector); + IntegerVector = __lsx_vadd_w(IntegerVector, ZeroPointVector); + + if (std::is_signed::value) { + + __m128i tmp; + tmp = __lsx_vsat_w(IntegerVector, 15); + IntegerVector = __lsx_vpickev_h(tmp, tmp); + + tmp = __lsx_vsat_h(IntegerVector, 7); + IntegerVector = __lsx_vpickev_b(tmp, tmp); + + } else { + + __m128i zero = __lsx_vldi(0); + __m128i tmp, tmp2; + + tmp = __lsx_vmax_h(IntegerVector, zero); + tmp2 = __lsx_vsat_hu(tmp, 7); + IntegerVector = __lsx_vpickev_b(tmp2, tmp2); + + tmp = __lsx_vmax_h(IntegerVector, zero); + tmp2 = __lsx_vsat_hu(tmp, 7); + IntegerVector = __lsx_vpickev_b(tmp2, tmp2); + + } + + uint32_t OutputValue = uint32_t(__lsx_vpickve2gr_w(IntegerVector, 0)); + + if (n >= 4) { + + *reinterpret_cast(RowOutput) = OutputValue; + RowOutput += 4; + + n -= 4; + + } else { + + *RowOutput = uint8_t(OutputValue); + RowOutput += 1; + + n -= 1; + } + } + + // Next Row + Input += InputLeadingDimension; + Output += OutputLeadingDimension; + } +} + #else template diff --git a/onnxruntime/core/mlas/lib/reorder.cpp b/onnxruntime/core/mlas/lib/reorder.cpp index 99c1dbac3b692..b329ea2ffb149 100644 --- a/onnxruntime/core/mlas/lib/reorder.cpp +++ b/onnxruntime/core/mlas/lib/reorder.cpp @@ -180,6 +180,31 @@ Return Value: v[2] = _mm_movelh_ps(t[2], t[3]); v[3] = _mm_movehl_ps(t[3], t[2]); + MlasStoreFloat32x4(&D[ScatterStride * 0], v[0]); + MlasStoreFloat32x4(&D[ScatterStride * 1], v[1]); + MlasStoreFloat32x4(&D[ScatterStride * 2], v[2]); + MlasStoreFloat32x4(&D[ScatterStride * 3], v[3]); +#elif defined(MLAS_LSX_INTRINSICS) + + MLAS_FLOAT32X4 v[4]; + MLAS_FLOAT32X4 t[4]; + + v[0] = MlasLoadFloat32x4(&S[GatherStride * 0]); + v[1] = MlasLoadFloat32x4(&S[GatherStride * 1]); + v[2] = MlasLoadFloat32x4(&S[GatherStride * 2]); + v[3] = MlasLoadFloat32x4(&S[GatherStride * 3]); + + t[0] = (__m128)__lsx_vilvl_w((__m128i)v[1], (__m128i)v[0]); + t[2] = (__m128)__lsx_vilvh_w((__m128i)v[1], (__m128i)v[0]); + t[1] = (__m128)__lsx_vilvl_w((__m128i)v[3], (__m128i)v[2]); + t[3] = (__m128)__lsx_vilvh_w((__m128i)v[3], (__m128i)v[2]); + + + v[0] = (__m128)__lsx_vpickev_d((__m128i) t[1],(__m128i) t[0]); + v[1] = (__m128)__lsx_vpickod_d((__m128i) t[1],(__m128i) t[0]); + v[2] = (__m128)__lsx_vpickev_d((__m128i) t[3],(__m128i) t[2]); + v[3] = (__m128)__lsx_vpickod_d((__m128i) t[3],(__m128i) t[2]); + MlasStoreFloat32x4(&D[ScatterStride * 0], v[0]); MlasStoreFloat32x4(&D[ScatterStride * 1], v[1]); 
MlasStoreFloat32x4(&D[ScatterStride * 2], v[2]); @@ -456,7 +481,6 @@ Return Value: &TaskStart, &TasksRemaining); size_t TaskEnd = TaskStart + TasksRemaining; - // // Rebase the pointers to the source and destination buffers for this thread. // @@ -567,18 +591,17 @@ Return Value: WorkBlock.S = S; WorkBlock.D = D; - WorkBlock.OutputChannels = size_t(OutputShape[1]); WorkBlock.OutputSize = size_t(OutputShape[2]) * size_t(OutputShape[3]); const size_t BlockSize = MlasNchwcGetBlockSize(); const size_t TasksPerBatch = size_t(ceil(((float)WorkBlock.OutputChannels) / BlockSize)); const size_t BatchCount = size_t(OutputShape[0]); - const size_t TasksCount = BatchCount * TasksPerBatch; + const size_t TasksCount = BatchCount * TasksPerBatch; WorkBlock.TasksCount = TasksCount; // - // Schedule the operation across a set of worker threads if the output + // Schedule the operation across a set of worker threads if the output // tensor is sufficienly large. Limit the number of threads to at least // the number of available tasks. // @@ -590,7 +613,7 @@ Return Value: if (size_t(TargetThreadCount) > TasksCount) { TargetThreadCount = ptrdiff_t(TasksCount); } - } + } WorkBlock.TargetThreadCount = TargetThreadCount; MlasExecuteThreaded(MlasReorderOutputNchwThreaded, &WorkBlock, TargetThreadCount, ThreadPool); diff --git a/onnxruntime/core/mlas/lib/sgemm.cpp b/onnxruntime/core/mlas/lib/sgemm.cpp index 1ce64712d63dc..4d7a1ceb4eee7 100644 --- a/onnxruntime/core/mlas/lib/sgemm.cpp +++ b/onnxruntime/core/mlas/lib/sgemm.cpp @@ -472,7 +472,7 @@ Return Value: const float* b = B; size_t x = CountX; -#if defined(MLAS_TARGET_AMD64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) MLAS_SGEMM_TRANSPOSE_PACKB_BLOCK_ROUTINE* SgemmTransposePackB16x4Routine = GetMlasPlatform().TransposePackB16x4Routine; @@ -1061,7 +1061,7 @@ Return Value: size_t RowsHandled; -#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) +#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_LARCH64) RowsHandled = GetMlasPlatform().GemmFloatKernel(A, B, C, CountK, CountM, CountN, lda, ldc, alpha, ZeroMode); #else if (ZeroMode) { diff --git a/onnxruntime/core/mlas/lib/snchwc.cpp b/onnxruntime/core/mlas/lib/snchwc.cpp index 74d65f934aaf5..f9cf1605787aa 100644 --- a/onnxruntime/core/mlas/lib/snchwc.cpp +++ b/onnxruntime/core/mlas/lib/snchwc.cpp @@ -101,7 +101,7 @@ Return Value: --*/ { -#if defined(MLAS_TARGET_AMD64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) return GetMlasPlatform().NchwcBlockSize; #else return 1; @@ -674,7 +674,7 @@ struct MLAS_NCHWC_CONV_NCHWC_ALGORITHM : MLAS_NCHWC_GROUPED_CONV_ALGORITHM const size_t BlockedOutputWidth = BlockSize * OutputWidth; -#if defined(MLAS_TARGET_AMD64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) MLAS_CONV_FLOAT_KERNEL* Kernel = GetMlasPlatform().ConvNchwcFloatKernel; #else MLAS_CONV_FLOAT_KERNEL* Kernel = MlasConvNchwcFloatKernel; @@ -784,7 +784,7 @@ struct MLAS_NCHWC_CONV_NCHW_ALGORITHM : MLAS_NCHWC_GROUPED_CONV_ALGORITHM const size_t BlockedOutputWidth = BlockSize * OutputWidth; -#if defined(MLAS_TARGET_AMD64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) MLAS_CONV_FLOAT_KERNEL* Kernel = GetMlasPlatform().ConvNchwFloatKernel; #else MLAS_CONV_FLOAT_KERNEL* Kernel = MlasConvNchwFloatKernel; @@ -879,7 +879,7 @@ struct MLAS_NCHWC_CONV_POINTWISE_ALGORITHM : MLAS_NCHWC_GROUPED_CONV_ALGORITHM const size_t FilterStrideBytes = BlockSize * InputChannels * sizeof(float); const size_t 
OutputStrideBytes = BlockSize * OutputSize * sizeof(float); -#if defined(MLAS_TARGET_AMD64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) MLAS_CONV_POINTWISE_FLOAT_KERNEL* Kernel = GetMlasPlatform().ConvPointwiseFloatKernel; #else MLAS_CONV_POINTWISE_FLOAT_KERNEL* Kernel = MlasConvPointwiseFloatKernel; @@ -1016,7 +1016,7 @@ struct MLAS_NCHWC_CONV_DEPTHWISE_ALGORITHM : MLAS_NCHWC_CONV_ALGORITHM const size_t BlockedOutputWidth = BlockSize * OutputWidth; -#if defined(MLAS_TARGET_AMD64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) MLAS_CONV_DEPTHWISE_FLOAT_KERNEL* Kernel = GetMlasPlatform().ConvDepthwiseFloatKernel; #else MLAS_CONV_DEPTHWISE_FLOAT_KERNEL* Kernel = MlasConvDepthwiseFloatKernel; @@ -1093,7 +1093,7 @@ struct MLAS_NCHWC_CONV_DEPTHWISE_ALGORITHM : MLAS_NCHWC_CONV_ALGORITHM struct MLAS_NCHWC_POOL_ALGORITHM : MLAS_NCHWC_NN_ALGORITHM { -#if !defined(MLAS_TARGET_AMD64) +#if !defined(MLAS_TARGET_AMD64) && !defined(MLAS_TARGET_LARCH64) static MLAS_POOL_FLOAT_KERNEL* const PoolKernels[]; #endif @@ -1131,7 +1131,7 @@ struct MLAS_NCHWC_POOL_ALGORITHM : MLAS_NCHWC_NN_ALGORITHM const size_t DilatedInputWidthBytes = BlockSize * DilationHeight * InputWidth * sizeof(float); const size_t InputStrideBytes = DilatedInputWidthBytes - KernelWidth * DilationWidthBytes; -#if defined(MLAS_TARGET_AMD64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) MLAS_POOL_FLOAT_KERNEL* Kernel = GetMlasPlatform().PoolFloatKernel[WorkBlock->PoolingKind]; #else MLAS_POOL_FLOAT_KERNEL* Kernel = PoolKernels[WorkBlock->PoolingKind]; @@ -1197,7 +1197,7 @@ struct MLAS_NCHWC_POOL_ALGORITHM : MLAS_NCHWC_NN_ALGORITHM } }; -#if !defined(MLAS_TARGET_AMD64) +#if !defined(MLAS_TARGET_AMD64) && !defined(MLAS_TARGET_LARCH64) MLAS_POOL_FLOAT_KERNEL* const MLAS_NCHWC_POOL_ALGORITHM::PoolKernels[] = { @@ -1621,7 +1621,7 @@ Return Value: } } -#if !defined(MLAS_TARGET_AMD64) +#if !defined(MLAS_TARGET_AMD64) && !defined(MLAS_TARGET_LARCH64) // // Convolution and pooling kernel stubs for architectures that do not yet have diff --git a/onnxruntime/core/mlas/lib/transpose.cpp b/onnxruntime/core/mlas/lib/transpose.cpp index 86b0897bb91ec..a758a0e59fb4f 100644 --- a/onnxruntime/core/mlas/lib/transpose.cpp +++ b/onnxruntime/core/mlas/lib/transpose.cpp @@ -371,6 +371,121 @@ MlasTranspose16x16Block( vec_vsx_st(e0, 0, &Output[OutputStride * 14]); vec_vsx_st(e1, 0, &Output[OutputStride * 15]); } + +#elif defined(MLAS_LSX_INTRINSICS) + +MLAS_FORCEINLINE +void +MlasTranspose4x4Block( + const uint32_t* Input, + size_t InputStride, + uint32_t* Output, + size_t OutputStride + ) +{ + __m128i a0 = __lsx_vld((const __m128i*)&Input[InputStride * 0], 0); + __m128i a1 = __lsx_vld((const __m128i*)&Input[InputStride * 1], 0); + __m128i a2 = __lsx_vld((const __m128i*)&Input[InputStride * 2], 0); + __m128i a3 = __lsx_vld((const __m128i*)&Input[InputStride * 3], 0); + + __m128i b0 = __lsx_vilvl_w(a2, a0); + __m128i b1 = __lsx_vilvh_w(a2, a0); + __m128i b2 = __lsx_vilvl_w(a3, a1); + __m128i b3 = __lsx_vilvh_w(a3, a1); + __m128i c0 = __lsx_vilvl_w(b2, b0); + __m128i c1 = __lsx_vilvh_w(b2, b0); + __m128i c2 = __lsx_vilvl_w(b3, b1); + __m128i c3 = __lsx_vilvh_w(b3, b1); + + __lsx_vst(c0, (__m128i*)&Output[OutputStride * 0], 0); + __lsx_vst(c1, (__m128i*)&Output[OutputStride * 1], 0); + __lsx_vst(c2, (__m128i*)&Output[OutputStride * 2], 0); + __lsx_vst(c3, (__m128i*)&Output[OutputStride * 3], 0); +} + +MLAS_FORCEINLINE +void +MlasTranspose4x4Block( + const uint16_t* Input, + size_t InputStride, + uint16_t* Output, + 
size_t OutputStride + ) +{ + __m128i a0 = __lsx_vld((const __m128i*)&Input[InputStride * 0], 0); + __lsx_vinsgr2vr_d(a0, 0 , 1); + __m128i a1 = __lsx_vld((const __m128i*)&Input[InputStride * 1], 0); + __lsx_vinsgr2vr_d(a1, 0 , 1); + __m128i a2 = __lsx_vld((const __m128i*)&Input[InputStride * 2], 0); + __lsx_vinsgr2vr_d(a2, 0 , 1); + __m128i a3 = __lsx_vld((const __m128i*)&Input[InputStride * 3], 0); + __lsx_vinsgr2vr_d(a3, 0 , 1); + + __m128i b0 = __lsx_vilvl_h(a2, a0); + __m128i b1 = __lsx_vilvl_h(a3, a1); + __m128i c0 = __lsx_vilvl_h(b1, b0); + __m128i c1 = __lsx_vilvh_h(b1, b0); + + __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 0], 0), __lsx_vpickve2gr_d(c0, 0), 0), (__m128i *)&Output[OutputStride * 0], 0); + __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 1], 0), __lsx_vpickve2gr_d(c0, 1), 0), (__m128i *)&Output[OutputStride * 1], 0); + __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 2], 0), __lsx_vpickve2gr_d(c1, 0), 0), (__m128i *)&Output[OutputStride * 2], 0); + __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 3], 0), __lsx_vpickve2gr_d(c1, 1), 0), (__m128i *)&Output[OutputStride * 3], 0); +} + +MLAS_FORCEINLINE +void +MlasTranspose8x8Block( + const uint8_t* Input, + size_t InputStride, + uint8_t* Output, + size_t OutputStride + ) +{ + __m128i a0 = __lsx_vld((const __m128i*)&Input[InputStride * 0], 0); + __lsx_vinsgr2vr_d(a0, 0, 1); + __m128i a1 = __lsx_vld((const __m128i*)&Input[InputStride * 1], 0); + __lsx_vinsgr2vr_d(a1, 0, 1); + __m128i b0 = __lsx_vilvl_b(a1, a0); + + __m128i a2 = __lsx_vld((const __m128i*)&Input[InputStride * 2], 0); + __lsx_vinsgr2vr_d(a2, 0, 1); + __m128i a3 = __lsx_vld((const __m128i*)&Input[InputStride * 3], 0); + __lsx_vinsgr2vr_d(a3, 0, 1); + __m128i b1 = __lsx_vilvl_b(a3, a2); + + __m128i a4 = __lsx_vld((const __m128i*)&Input[InputStride * 4], 0); + __lsx_vinsgr2vr_d(a4, 0, 1); + __m128i a5 = __lsx_vld((const __m128i*)&Input[InputStride * 5], 0); + __lsx_vinsgr2vr_d(a5, 0, 1); + __m128i b2 = __lsx_vilvl_b(a5, a4); + + __m128i a6 = __lsx_vld((const __m128i*)&Input[InputStride * 6], 0); + __lsx_vinsgr2vr_d(a6, 0, 1); + __m128i a7 = __lsx_vld((const __m128i*)&Input[InputStride * 7], 0); + __lsx_vinsgr2vr_d(a7, 0, 1); + __m128i b3 = __lsx_vilvl_b(a7, a6); + __m128i c0 = __lsx_vilvl_h(b1, b0); + __m128i c1 = __lsx_vilvh_h(b1, b0); + __m128i c2 = __lsx_vilvl_h(b3, b2); + __m128i c3 = __lsx_vilvh_h(b3, b2); + + __m128 d0 = (__m128)(__lsx_vilvl_w(c2, c0)); + __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 0], 0), __lsx_vpickve2gr_d(d0, 0), 0), (__m128i *)&Output[OutputStride * 0], 0); + __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 1], 0), __lsx_vpickve2gr_d(d0, 1), 0), (__m128i *)&Output[OutputStride * 1], 0); + + __m128 d1 = (__m128)(__lsx_vilvh_w(c2, c0)); + __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 2], 0), __lsx_vpickve2gr_d(d1, 0), 0), (__m128i *)&Output[OutputStride * 2], 0); + __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 3], 0), __lsx_vpickve2gr_d(d1, 1), 0), (__m128i *)&Output[OutputStride * 3], 0); + + __m128 d2 = (__m128)(__lsx_vilvl_w(c3, c1)); + __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 4], 0), __lsx_vpickve2gr_d(d2, 0), 0), (__m128i *)&Output[OutputStride * 4], 0); + __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 5], 0), __lsx_vpickve2gr_d(d2, 1), 0), (__m128i *)&Output[OutputStride * 5], 0); + 
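+    // NOTE: at this point the byte -> halfword -> word interleaves have completed the
+    // 8x8 transpose: each of d0..d3 holds two output rows, one per 64-bit half
+    // (d0 -> rows 0/1, d1 -> rows 2/3, d2 -> rows 4/5). The d3 vector computed below
+    // carries input columns 6 and 7, and the load/insert/store sequence writes each
+    // 64-bit half into the low eight bytes of its destination row while preserving
+    // that row's upper eight bytes.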
+ __m128 d3 = (__m128)(__lsx_vilvh_w(c3, c1)); + __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 6], 0), __lsx_vpickve2gr_d(d3, 0), 0), (__m128i *)&Output[OutputStride * 6], 0); + __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 7], 0), __lsx_vpickve2gr_d(d3, 1), 0), (__m128i *)&Output[OutputStride * 7], 0); +} + #endif template @@ -472,7 +587,8 @@ Return Value: uint32_t* d = Output; size_t m = M; -#if defined(MLAS_SSE2_INTRINSICS) || defined(MLAS_NEON_INTRINSICS) || defined(MLAS_TARGET_POWER) +#if defined(MLAS_SSE2_INTRINSICS) || defined(MLAS_NEON_INTRINSICS) || defined(MLAS_TARGET_POWER) || \ + defined(MLAS_LSX_INTRINSICS) while (m >= 4) { @@ -597,7 +713,7 @@ Return Value: uint16_t* d = Output; size_t m = M; -#if defined(MLAS_SSE2_INTRINSICS) || defined(MLAS_NEON_INTRINSICS) +#if defined(MLAS_SSE2_INTRINSICS) || defined(MLAS_NEON_INTRINSICS) || defined(MLAS_LSX_INTRINSICS) while (m >= 4) { @@ -734,7 +850,7 @@ Return Value: uint8_t* d = Output; size_t m = M; -#if defined(MLAS_SSE2_INTRINSICS) || defined(MLAS_NEON_INTRINSICS) +#if defined(MLAS_SSE2_INTRINSICS) || defined(MLAS_NEON_INTRINSICS) || defined(MLAS_LSX_INTRINSICS) while (m >= 8) {