From 6b6a62fb40d78383fa6e56faf9c8b7a12a08d1b5 Mon Sep 17 00:00:00 2001
From: Yi-Hong Lyu
Date: Tue, 16 Apr 2024 13:52:43 -0700
Subject: [PATCH] Add vectorized AVX512F kernel for ReduceMaximumF32Kernel (#20268)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

### Description
This commit introduces a new vectorized AVX512F kernel, MlasReduceMaximumF32KernelAvx512F, which efficiently computes the maximum value of the supplied buffer. Additionally, microbenchmarks have been added for MlasComputeSoftmax (inplace), MlasReduceMaximumF32KernelAvx, MlasComputeSumExpF32KernelAvx512F, and MlasComputeSoftmaxOutputF32KernelAvx.

### Motivation and Context
The goal of this commit is to enhance the performance of ReduceMaximumF32Kernel on CPUs with AVX512F instruction support.

&nbsp; | AVX | &nbsp; | &nbsp; | AVX512 | &nbsp; | &nbsp; | &nbsp;
-- | -- | -- | -- | -- | -- | -- | --
name | iterations | real_time | cpu_time | iterations | real_time | cpu_time | time_unit
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:4/D:3/real_time | 271277304 | 2.58095 | 2.58091 | 263338132 | 2.65661 | 2.65661 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:8/D:3/real_time | 271220477 | 2.58095 | 2.58095 | 263509929 | 2.65652 | 2.65649 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:16/D:3/real_time | 271240587 | 2.58064 | 2.58064 | 263479542 | 2.65671 | 2.65665 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:32/D:3/real_time | 271227745 | 2.58083 | 2.58079 | 263402506 | 2.65657 | 2.65657 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:64/D:3/real_time | 271255069 | 2.58073 | 2.58071 | 263463858 | 2.65682 | 2.65682 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:128/D:3/real_time | 271257174 | 2.58058 | 2.58052 | 263460120 | 2.65682 | 2.65682 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:4/D:4/real_time | 174395051 | 4.01401 | 4.01401 | 197330481 | 3.5465 | 3.54636 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:8/D:4/real_time | 174645502 | 3.99691 | 3.99691 | 197474831 | 3.54298 | 3.54278 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:16/D:4/real_time | 174523308 | 4.01391 | 4.01386 | 197389981 | 3.54518 | 3.54506 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:32/D:4/real_time | 174779200 | 3.99874 | 3.99874 | 197519075 | 3.54227 | 3.54209 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:64/D:4/real_time | 174642874 | 4.00645 | 4.00641 | 197642101 | 3.54195 | 3.54188 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:128/D:4/real_time | 174546754 | 4.0061 | 4.00608 | 197621033 | 3.54296 | 3.54281 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:4/D:5/real_time | 162752651 | 4.30119 | 4.30114 | 215552503 | 3.24767 | 3.24752 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:8/D:5/real_time | 162717463 | 4.30123 | 4.30116 | 215541082 | 3.24711 | 3.24695 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:16/D:5/real_time | 162718819 | 4.3016 | 4.30153 | 215589239 | 3.24725 | 3.24708 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:32/D:5/real_time | 162719596 | 4.30151 | 4.30145 | 215563846 | 3.24956 | 3.24949 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:64/D:5/real_time | 162753333 | 4.30125 | 4.30125 | 215537315 | 3.24924 | 3.24908 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:128/D:5/real_time | 162752258 | 4.3014 | 4.30141 | 215526482 | 3.24744 | 3.24735 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:4/D:7/real_time | 143579660 | 4.87526 | 4.87516 | 100000000 | 5.25767 | 5.25752 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:8/D:7/real_time | 143585097 | 4.87476 | 4.87467 | 100000000 | 5.41583 | 5.41567 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:16/D:7/real_time | 143571011 | 4.87506 | 4.87503 | 182359467 | 3.83773 | 3.83764 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:32/D:7/real_time | 143587142 | 4.87487 | 4.8748 | 182397261 | 3.83807 | 3.8379 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:64/D:7/real_time | 143578465 | 4.87525 | 4.87521 | 182428602 | 3.83777 | 3.83768 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:128/D:7/real_time | 143588555 | 4.87491 | 4.87488 | 125280452 | 5.59791 | 5.59766 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:4/D:9/real_time | 284851058 | 2.43476 | 2.43476 | 156879863 | 4.42895 | 4.42884 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:8/D:9/real_time | 270700898 | 2.59031 | 2.59024 | 157953114 | 4.42995 | 4.42968 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:16/D:9/real_time | 282871172 | 2.45385 | 2.45385 | 157801156 | 4.42817 | 4.42804 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:32/D:9/real_time | 285307738 | 2.47009 | 2.47005 | 158058507 | 4.4279 | 4.42786 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:64/D:9/real_time | 285709536 | 2.45481 | 2.45476 | 158070961 | 4.42809 | 4.42799 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:128/D:9/real_time | 285449733 | 2.47495 | 2.47491 | 158069718 | 4.45026 | 4.45017 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:4/D:11/real_time | 189213618 | 3.79684 | 3.79676 | 139459497 | 5.01882 | 5.01871 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:8/D:11/real_time | 185600468 | 3.76394 | 3.76376 | 139444892 | 5.01922 | 5.01905 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:16/D:11/real_time | 184968668 | 3.80636 | 3.80636 | 139470834 | 5.01948 | 5.01936 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:32/D:11/real_time | 183867226 | 3.80432 | 3.80427 | 139481986 | 5.01975 | 5.01944 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:64/D:11/real_time | 184301650 | 3.81634 | 3.81634 | 139452846 | 5.01983 | 5.01972 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:128/D:11/real_time | 186215795 | 3.82659 | 3.82654 | 139497736 | 5.02119 | 5.02113 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:4/D:13/real_time | 135622415 | 5.16256 | 5.16252 | 124661337 | 5.61227 | 5.61194 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:8/D:13/real_time | 135618907 | 5.15967 | 5.1596 | 124805224 | 5.6088 | 5.60854 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:16/D:13/real_time | 135612192 | 5.15506 | 5.15501 | 124803221 | 5.60901 | 5.60869 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:32/D:13/real_time | 135906082 | 5.15818 | 5.15818 | 124776601 | 5.60898 | 5.60886 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:64/D:13/real_time | 135369523 | 5.15709 | 5.15682 | 124790370 | 5.60927 | 5.60902 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:128/D:13/real_time | 135596827 | 5.1603 | 5.1603 | 124792145 | 5.61637 | 5.61614 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:4/D:15/real_time | 110947137 | 5.96511 | 5.96495 | 112861522 | 6.20035 | 6.20014 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:8/D:15/real_time | 118004792 | 6.22645 | 6.22628 | 112909900 | 6.20073 | 6.20073 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:16/D:15/real_time | 112630319 | 6.25564 | 6.25552 | 112874563 | 6.19932 | 6.19924 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:32/D:15/real_time | 117403034 | 6.17263 | 6.17258 | 112927318 | 6.19866 | 6.19842 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:64/D:15/real_time | 108921863 | 6.48624 | 6.48612 | 112927746 | 6.20057 | 6.20026 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:128/D:15/real_time | 110358148 | 6.66805 | 6.66789 | 112907312 | 6.19938 | 6.19908 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:4/D:16/real_time | 203419574 | 3.4415 | 3.44137 | 237134525 | 2.95649 | 2.95638 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:8/D:16/real_time | 203414035 | 3.4411 | 3.44099 | 237129564 | 2.95178 | 2.95171 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:16/D:16/real_time | 203404068 | 3.44157 | 3.44151 | 236981704 | 2.9518 | 2.95167 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:32/D:16/real_time | 203391471 | 3.44146 | 3.44137 | 237108807 | 2.95203 | 2.95196 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:64/D:16/real_time | 203393801 | 3.44131 | 3.44127 | 237126460 | 2.95278 | 2.95272 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:128/D:16/real_time | 203407476 | 3.44181 | 3.44162 | 237154444 | 2.95293 | 2.9528 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:4/D:500/real_time | 37551439 | 18.6407 | 18.6407 | 39222534 | 17.858 | 17.8571 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:8/D:500/real_time | 37544097 | 18.6404 | 18.6401 | 39174151 | 17.8539 | 17.8536 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:16/D:500/real_time | 37549837 | 18.6391 | 18.6391 | 39233956 | 17.8507 | 17.8505 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:32/D:500/real_time | 45996345 | 15.2157 | 15.2153 | 39285929 | 17.848 | 17.8474 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:64/D:500/real_time | 46012429 | 15.2184 | 15.2179 | 65664865 | 10.7366 | 10.7364 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:128/D:500/real_time | 45912375 | 15.2349 | 15.2346 | 65205908 | 10.8498 | 10.8492 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:4/D:2000/real_time | 9493955 | 73.7232 | 73.7203 | 10188090 | 68.7931 | 68.7908 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:8/D:2000/real_time | 9495562 | 73.7173 | 73.7173 | 10180895 | 68.7533 | 68.7511 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:16/D:2000/real_time | 9487371 | 73.7852 | 73.7831 | 10164473 | 68.7279 | 68.725 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:32/D:2000/real_time | 10816047 | 64.7322 | 64.7287 | 10168481 | 68.8109 | 68.8096 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:64/D:2000/real_time | 10808802 | 64.7232 | 64.721 | 19478320 | 36.1471 | 36.1461 | ns
REDUCEMAXIMUMF32KERNEL[]/ByteAligned:128/D:2000/real_time | 10818192 | 64.7304 | 64.728 | 19419672 | 35.9635 | 35.9635 | ns
---
 cmake/onnxruntime_mlas.cmake                  |   2 +
 .../mlas/lib/amd64/SoftmaxKernelAvx512F.asm   | 103 ++++++++
 onnxruntime/core/mlas/lib/mlasi.h             |   1 +
 onnxruntime/core/mlas/lib/platform.cpp        |   1 +
 .../mlas/lib/x86_64/SoftmaxKernelAvx512F.S    | 101 ++++++++
 .../test/mlas/bench/bench_computesoftmax.cpp  | 233 ++++++++++++++++++
 6 files changed, 441 insertions(+)
 create mode 100644 onnxruntime/core/mlas/lib/amd64/SoftmaxKernelAvx512F.asm
 create mode 100644 onnxruntime/core/mlas/lib/x86_64/SoftmaxKernelAvx512F.S
 create mode 100644 onnxruntime/test/mlas/bench/bench_computesoftmax.cpp

diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake
index 6b7d4402be8eb..f7103c3b00a37 100644
--- a/cmake/onnxruntime_mlas.cmake
+++ b/cmake/onnxruntime_mlas.cmake
@@ -197,6 +197,7 @@ function(setup_mlas_source_for_windows)
         ${MLAS_SRC_DIR}/amd64/sgemma.asm
         ${MLAS_SRC_DIR}/amd64/cvtfp16a.asm
         ${MLAS_SRC_DIR}/amd64/SoftmaxKernelAvx.asm
+        ${MLAS_SRC_DIR}/amd64/SoftmaxKernelAvx512F.asm
         ${MLAS_SRC_DIR}/amd64/TransKernelFma3.asm
         ${MLAS_SRC_DIR}/amd64/TransKernelAvx512F.asm
         ${MLAS_SRC_DIR}/amd64/LogisticKernelFma3.asm
@@ -536,6 +537,7 @@ else()
           ${MLAS_SRC_DIR}/x86_64/DgemmKernelAvx512F.S
           ${MLAS_SRC_DIR}/x86_64/SgemmKernelAvx512F.S
           ${MLAS_SRC_DIR}/x86_64/SconvKernelAvx512F.S
+          ${MLAS_SRC_DIR}/x86_64/SoftmaxKernelAvx512F.S
           ${MLAS_SRC_DIR}/x86_64/SpoolKernelAvx512F.S
           ${MLAS_SRC_DIR}/x86_64/TransKernelAvx512F.S
           ${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp
diff --git a/onnxruntime/core/mlas/lib/amd64/SoftmaxKernelAvx512F.asm b/onnxruntime/core/mlas/lib/amd64/SoftmaxKernelAvx512F.asm
new file mode 100644
index 0000000000000..3e83bc852f558
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/amd64/SoftmaxKernelAvx512F.asm
@@ -0,0 +1,103 @@
+;++
+;
+;Copyright (c) Microsoft Corporation. All rights reserved.
+;
+;Licensed under the MIT License.
+;
+;Module Name:
+;
+;   SoftmaxKernelAvx512F.asm
+;
+;Abstract:
+;
+;   This module implements the kernels for the single precision softmax
+;   operation.
+;
+;   This implementation uses AVX512F instructions.
+;
+;--
+
+        .xlist
+INCLUDE mlasi.inc
+        .list
+
+        EXTERN  MlasMinimumF32Value:NEAR
+
+;++
+;
+;Routine Description:
+;
+;   This routine implements a vectorized kernel to find the maximum value of
+;   the supplied buffer.
+;
+;Arguments:
+;
+;   Input (rcx) - Supplies the input buffer.
+;
+;   N (rdx) - Supplies the number of elements to process.
+;
+;Return Value:
+;
+;   Returns the maximum value of the supplied buffer.
+;
+;--
+
+        LEAF_ENTRY MlasReduceMaximumF32KernelAvx512F, _TEXT
+
+        vbroadcastss zmm0,DWORD PTR [MlasMinimumF32Value] ; seed all 16 accumulator lanes
+        test    rdx,rdx
+        jz      ExitKernel                  ; N == 0: xmm0 still holds the seed value
+        cmp     rdx,16
+        jb      ProcessRemainingCountBy1    ; N < 16: scalar loop only
+        cmp     rdx,64
+        jb      ProcessRemainingCountBy16   ; N < 64: skip the 4x unrolled loop
+        vmovaps zmm1,zmm0                   ; three more accumulators for the unrolled loop
+        vmovaps zmm2,zmm0
+        vmovaps zmm3,zmm0
+
+ProcessRemainingCountBy64:
+        vmaxps  zmm0,zmm0,ZMMWORD PTR [rcx]
+        vmaxps  zmm1,zmm1,ZMMWORD PTR [rcx+16*4]
+        sub     rdx,64
+        vmaxps  zmm2,zmm2,ZMMWORD PTR [rcx+32*4]
+        vmaxps  zmm3,zmm3,ZMMWORD PTR [rcx+48*4]
+        add     rcx,64*4                    ; advance input by 64 elements
+        cmp     rdx,64
+        jae     ProcessRemainingCountBy64
+        vmaxps  zmm0,zmm0,zmm1              ; reduce to single vector
+        vmaxps  zmm2,zmm2,zmm3
+        vmaxps  zmm0,zmm0,zmm2
+
+ProcessRemainingCountBy16:
+        cmp     rdx,16
+        jb      ProcessRemainingCountLessThan16
+        vmaxps  zmm0,zmm0,ZMMWORD PTR [rcx]
+        sub     rdx,16
+        add     rcx,16*4                    ; advance input by 16 elements
+        jmp     ProcessRemainingCountBy16
+
+ProcessRemainingCountLessThan16:
+        vextractf64x4 ymm1,zmm0,1           ; reduce to single scalar (AVX512F form; vextractf32x8 requires AVX512DQ)
+        vmaxps  ymm0,ymm0,ymm1
+        vextractf128 xmm1,ymm0,1            ; upper 128 bits of ymm0
+        vmaxps  xmm0,xmm0,xmm1
+        vshufps xmm1,xmm0,xmm0,0EEh         ; elements 2,3 -> positions 0,1
+        vmaxps  xmm0,xmm0,xmm1
+        vshufps xmm1,xmm0,xmm0,055h         ; element 1 -> position 0
+        vmaxss  xmm0,xmm0,xmm1
+        test    rdx,rdx
+        jz      ExitKernel                  ; no tail elements left
+
+ProcessRemainingCountBy1:
+        vmaxss  xmm0,xmm0,DWORD PTR [rcx]
+        add     rcx,4                       ; advance input by 1 element
+        dec     edx                         ; count is below 16 here, so the 32-bit decrement suffices
+        jnz     ProcessRemainingCountBy1
+
+ExitKernel:
+        vzeroupper
+        ret
+
+        LEAF_END MlasReduceMaximumF32KernelAvx512F, _TEXT
+
+        END
diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h
index 624eb913d5c9e..4b93dde1bcef9 100644
--- a/onnxruntime/core/mlas/lib/mlasi.h
+++ b/onnxruntime/core/mlas/lib/mlasi.h
@@ -846,6 +846,7 @@ extern "C" {
     MLAS_REDUCE_MINIMUM_MAXIMUM_FLOAT_KERNEL MlasReduceMinimumMaximumF32Kernel;
 #if defined(MLAS_TARGET_AMD64)
     MLAS_REDUCE_MAXIMUM_FLOAT_KERNEL MlasReduceMaximumF32KernelAvx;
+    MLAS_REDUCE_MAXIMUM_FLOAT_KERNEL MlasReduceMaximumF32KernelAvx512F;
     MLAS_REDUCE_MINIMUM_MAXIMUM_FLOAT_KERNEL MlasReduceMinimumMaximumF32KernelAvx;
 #endif
 
diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp
index de092f7d1d350..a53c5085b10cf 100644
--- a/onnxruntime/core/mlas/lib/platform.cpp
+++ b/onnxruntime/core/mlas/lib/platform.cpp
@@ -421,6 +421,7 @@ Return Value:
                 this->PoolFloatKernel[MlasAveragePoolingIncludePad] = MlasPoolAverageIncludePadFloatKernelAvx512F;
                 this->ComputeExpF32Kernel = MlasComputeExpF32KernelAvx512F;
                 this->ComputeSumExpF32Kernel = MlasComputeSumExpF32KernelAvx512F;
+                this->ReduceMaximumF32Kernel = MlasReduceMaximumF32KernelAvx512F;
                 this->QuantizeLinearS8Kernel = MlasQuantizeLinearS8KernelAvx512F;
                 this->QuantizeLinearU8Kernel = MlasQuantizeLinearU8KernelAvx512F;
                 this->NchwcBlockSize = 16;
diff --git a/onnxruntime/core/mlas/lib/x86_64/SoftmaxKernelAvx512F.S b/onnxruntime/core/mlas/lib/x86_64/SoftmaxKernelAvx512F.S
new file mode 100644
index 0000000000000..db97286046567
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/x86_64/SoftmaxKernelAvx512F.S
@@ -0,0 +1,101 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+    SoftmaxKernelAvx512F.s
+
+Abstract:
+
+    This module implements the kernels for the single precision softmax
+    operation.
+
+    This implementation uses AVX512F instructions.
+
+--*/
+
+#include "asmmacro.h"
+
+        .intel_syntax noprefix
+
+        .text
+
+/*++
+
+Routine Description:
+
+    This routine implements a vectorized kernel to find the maximum value of
+    the supplied buffer.
+
+Arguments:
+
+    Input (rdi) - Supplies the input buffer.
+
+    N (rsi) - Supplies the number of elements to process.
+
+Return Value:
+
+    Returns the maximum value of the supplied buffer.
+
+--*/
+
+        FUNCTION_ENTRY MlasReduceMaximumF32KernelAvx512F
+
+        vbroadcastss zmm0,DWORD PTR C_UNDERSCORE(MlasMinimumF32Value)[rip] # seed all 16 accumulator lanes
+        test    rsi,rsi
+        jz      .LReduceMaximum.ExitKernel  # N == 0: xmm0 still holds the seed value
+        cmp     rsi,16
+        jb      .LReduceMaximum.ProcessRemainingCountBy1    # N < 16: scalar loop only
+        cmp     rsi,64
+        jb      .LReduceMaximum.ProcessRemainingCountBy16   # N < 64: skip the 4x unrolled loop
+        vmovaps zmm1,zmm0                   # three more accumulators for the unrolled loop
+        vmovaps zmm2,zmm0
+        vmovaps zmm3,zmm0
+
+.LReduceMaximum.ProcessRemainingCountBy64:
+        vmaxps  zmm0,zmm0,ZMMWORD PTR [rdi]
+        vmaxps  zmm1,zmm1,ZMMWORD PTR [rdi+16*4]
+        sub     rsi,64
+        vmaxps  zmm2,zmm2,ZMMWORD PTR [rdi+32*4]
+        vmaxps  zmm3,zmm3,ZMMWORD PTR [rdi+48*4]
+        add     rdi,64*4                    # advance input by 64 elements
+        cmp     rsi,64
+        jae     .LReduceMaximum.ProcessRemainingCountBy64
+        vmaxps  zmm0,zmm0,zmm1              # reduce to single vector
+        vmaxps  zmm2,zmm2,zmm3
+        vmaxps  zmm0,zmm0,zmm2
+
+.LReduceMaximum.ProcessRemainingCountBy16:
+        cmp     rsi,16
+        jb      .LReduceMaximum.ProcessRemainingCountLessThan16
+        vmaxps  zmm0,zmm0,ZMMWORD PTR [rdi]
+        sub     rsi,16
+        add     rdi,16*4                    # advance input by 16 elements
+        jmp     .LReduceMaximum.ProcessRemainingCountBy16
+
+.LReduceMaximum.ProcessRemainingCountLessThan16:
+        vextractf64x4 ymm1,zmm0,1           # reduce to single scalar (AVX512F form; vextractf32x8 requires AVX512DQ)
+        vmaxps  ymm0,ymm0,ymm1
+        vextractf128 xmm1,ymm0,1            # upper 128 bits of ymm0
+        vmaxps  xmm0,xmm0,xmm1
+        vshufps xmm1,xmm0,xmm0,0xEE         # elements 2,3 -> positions 0,1
+        vmaxps  xmm0,xmm0,xmm1
+        vshufps xmm1,xmm0,xmm0,0x55         # element 1 -> position 0
+        vmaxss  xmm0,xmm0,xmm1
+        test    rsi,rsi
+        jz      .LReduceMaximum.ExitKernel  # no tail elements left
+
+.LReduceMaximum.ProcessRemainingCountBy1:
+        vmaxss  xmm0,xmm0,DWORD PTR [rdi]
+        add     rdi,4                       # advance input by 1 element
+        dec     esi                         # count is below 16 here, so the 32-bit decrement suffices
+        jnz     .LReduceMaximum.ProcessRemainingCountBy1
+
+.LReduceMaximum.ExitKernel:
+        vzeroupper
+        ret
+
+        .end
diff --git a/onnxruntime/test/mlas/bench/bench_computesoftmax.cpp b/onnxruntime/test/mlas/bench/bench_computesoftmax.cpp
new file mode 100644
index 0000000000000..f777a7cfc4302
--- /dev/null
+++ b/onnxruntime/test/mlas/bench/bench_computesoftmax.cpp
@@ -0,0 +1,233 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/mlas/lib/mlasi.h"
+#include "core/util/thread_utils.h"
+#include "test/mlas/bench/bench_util.h"
+
+using onnxruntime::narrow;
+
+struct RestrictAlignedPtr {
+  float* ptr;               // Aligned pointer within the underlying buffer
+  void* underlying_buffer;  // Underlying buffer (including extra space for alignment)
+};
+
+// Return a RestrictAlignedPtr where the ptr is aligned to byte_aligned, but not to byte_aligned * 2
+RestrictAlignedPtr restrict_aligned_alloc(int D, int byte_aligned) {
+  if (byte_aligned <= 0 || (byte_aligned & (byte_aligned - 1)) != 0) {
+    throw std::invalid_argument("Alignment must be a power of 2");
+  }
+
+  const int byte_alignedx2 = byte_aligned << 1;
+
+  void* buffer = malloc(D * sizeof(float) + byte_alignedx2 * 2);  // slack covers the worst-case shift of 3 * byte_aligned - 1 bytes
+  if (buffer == nullptr) {
+    ORT_THROW_EX(std::bad_alloc);
+  }
+
+  uintptr_t address = reinterpret_cast<uintptr_t>(buffer);
+  uintptr_t aligned_address = ((address + byte_alignedx2 - 1) & ~(byte_alignedx2 - 1)) + byte_aligned;  // round up to 2x, then step by 1x
+  ORT_ENFORCE((aligned_address % byte_aligned == 0) && (aligned_address % byte_alignedx2 != 0));
+  float* aligned_ptr = reinterpret_cast<float*>(aligned_address);
+
+  return {aligned_ptr, buffer};  // caller releases underlying_buffer with free()
+}
+
+void COMPUTESOFTMAXINPLACE(benchmark::State& state) {  // Args: {ByteAligned, N, D, Threads}
+  const auto byte_aligned = narrow<int>(state.range(0));
+  const auto N = narrow<int>(state.range(1));
+  const auto D = narrow<int>(state.range(2));
+  const auto threads = narrow<int>(state.range(3));
+
+  if (N <= 0 || D <= 0 || threads <= 0) {
+    throw std::invalid_argument("N, D, and Threads must be greater than 0!");
+  }
+
+  OrtThreadPoolParams tpo;
+  tpo.thread_pool_size = threads;
+  tpo.auto_set_affinity = true;
+
+  std::unique_ptr<onnxruntime::concurrency::ThreadPool> tp(
+      onnxruntime::concurrency::CreateThreadPool(
+          &onnxruntime::Env::Default(), tpo, onnxruntime::concurrency::ThreadPoolType::INTRA_OP));
+
+  auto data = RandomVectorUniform(static_cast<size_t>(N * D), -1.0f, 1.0f);
+  RestrictAlignedPtr ptr = restrict_aligned_alloc(N * D, byte_aligned);
+  float* input = ptr.ptr;
+  float* output = input;  // in-place: output aliases input
+  std::copy(data.begin(), data.end(), input);  // Copy the data to the aligned memory
+
+  // warming up run
+  MlasComputeSoftmax(input, output, N, D, false, tp.get());
+
+  for (auto _ : state) {
+    MlasComputeSoftmax(input, output, N, D, false, tp.get());
+  }
+
+  free(ptr.underlying_buffer);
+}
+
+void REDUCEMAXIMUMF32KERNELAVX(benchmark::State& state) {  // Args: {ByteAligned, D}
+  const auto byte_aligned = narrow<int>(state.range(0));
+  const auto D = narrow<int>(state.range(1));
+
+  if (D <= 0) {
+    throw std::invalid_argument("D must be greater than 0!");
+  }
+
+  auto data = RandomVectorUniform(static_cast<size_t>(D), -1.0f, 1.0f);
+  RestrictAlignedPtr ptr = restrict_aligned_alloc(D, byte_aligned);
+  float* input = ptr.ptr;
+  std::copy(data.begin(), data.end(), input);  // Copy the data to the aligned memory
+
+  // warming up run
+  float Maximum = MlasReduceMaximumF32KernelAvx(input, D);  // NOTE(review): called directly, with no CPU feature check visible here
+
+  for (auto _ : state) {
+    Maximum = MlasReduceMaximumF32KernelAvx(input, D);
+  }
+
+  free(ptr.underlying_buffer);
+  (void)Maximum;
+}
+
+void REDUCEMAXIMUMF32KERNELAVX512F(benchmark::State& state) {  // Args: {ByteAligned, D}
+  const auto byte_aligned = narrow<int>(state.range(0));
+  const auto D = narrow<int>(state.range(1));
+
+  if (D <= 0) {
+    throw std::invalid_argument("D must be greater than 0!");
+  }
+
+  auto data = RandomVectorUniform(static_cast<size_t>(D), -1.0f, 1.0f);
+  RestrictAlignedPtr ptr = restrict_aligned_alloc(D, byte_aligned);
+  float* input = ptr.ptr;
+  std::copy(data.begin(), data.end(), input);  // Copy the data to the aligned memory
+
+  // warming up run
+  float Maximum = MlasReduceMaximumF32KernelAvx512F(input, D);  // NOTE(review): called directly, with no CPU feature check visible here
+
+  for (auto _ : state) {
+    Maximum = MlasReduceMaximumF32KernelAvx512F(input, D);
+  }
+
+  free(ptr.underlying_buffer);
+  (void)Maximum;
+}
+
+void COMPUTESUMEXPF32KERNELAVX512F(benchmark::State& state) {  // Args: {ByteAligned, D}
+  const auto byte_aligned = narrow<int>(state.range(0));
+  const auto D = narrow<int>(state.range(1));
+
+  if (D <= 0) {
+    throw std::invalid_argument("D must be greater than 0!");
+  }
+
+  auto data = RandomVectorUniform(static_cast<size_t>(D), -1.0f, 1.0f);
+  RestrictAlignedPtr ptr = restrict_aligned_alloc(D, byte_aligned);
+  float* input = ptr.ptr;
+  float* output = input;  // in-place: output aliases input
+  std::copy(data.begin(), data.end(), input);  // Copy the data to the aligned memory
+
+  float Maximum = MlasReduceMaximumF32KernelAvx(input, D);
+  float NegativeMaximum = -Maximum;
+
+  // warming up run
+  float Accumulation = MlasComputeSumExpF32KernelAvx512F(input, output, D, &NegativeMaximum);
+
+  for (auto _ : state) {  // NOTE(review): output aliases input, so if the kernel stores to output each pass reads the previous results - confirm intended
+    Accumulation = MlasComputeSumExpF32KernelAvx512F(input, output, D, &NegativeMaximum);
+  }
+
+  free(ptr.underlying_buffer);
+  (void)Accumulation;
+}
+
+void COMPUTESOFTMAXOUTPUTF32KERNELAVX(benchmark::State& state) {  // Args: {ByteAligned, D}
+  const auto byte_aligned = narrow<int>(state.range(0));
+  const auto D = narrow<int>(state.range(1));
+
+  if (D <= 0) {
+    throw std::invalid_argument("D must be greater than 0!");
+  }
+
+  auto data = RandomVectorUniform(static_cast<size_t>(D), -1.0f, 1.0f);
+  RestrictAlignedPtr ptr = restrict_aligned_alloc(D, byte_aligned);
+  float* input = ptr.ptr;
+  float* output = input;  // in-place: output aliases input
+  std::copy(data.begin(), data.end(), input);  // Copy the data to the aligned memory
+
+  float Maximum = MlasReduceMaximumF32KernelAvx(input, D);
+  float NegativeMaximum = -Maximum;
+
+  float Accumulation = MlasComputeSumExpF32KernelAvx512F(input, output, D, &NegativeMaximum);
+
+  float Parameters[] = {1.0f / Accumulation};  // reciprocal of the sum returned above
+
+  // warming up run
+  MlasComputeSoftmaxOutputF32KernelAvx(output, D, Parameters);
+
+  for (auto _ : state) {
+    MlasComputeSoftmaxOutputF32KernelAvx(output, D, Parameters);
+  }
+
+  free(ptr.underlying_buffer);
+}
+
+static void ComputeSoftmaxInplaceArgs(benchmark::internal::Benchmark* b) {
+  b->ArgNames({"ByteAligned", "N", "D", "Threads"});
+  for (int threads : {1, 8}) {
+    for (int byte_aligned : {64}) {  // MLAS_DEFAULT_PREFERRED_BUFFER_ALIGNMENT is 64
+      b->Args({byte_aligned, 16000, 4, threads});
+      b->Args({byte_aligned, 16000, 500, threads});
+      b->Args({byte_aligned, 48000, 3, threads});
+      b->Args({byte_aligned, 48000, 2000, threads});
+      b->Args({byte_aligned, 80000, 5, threads});
+      b->Args({byte_aligned, 80000, 2000, threads});
+      b->Args({byte_aligned, 112000, 7, threads});
+      b->Args({byte_aligned, 112000, 2000, threads});
+      b->Args({byte_aligned, 144000, 9, threads});
+      b->Args({byte_aligned, 144000, 2000, threads});
+      b->Args({byte_aligned, 176000, 11, threads});
+      b->Args({byte_aligned, 176000, 2000, threads});
+      b->Args({byte_aligned, 208000, 13, threads});
+      b->Args({byte_aligned, 208000, 2000, threads});
+      b->Args({byte_aligned, 240000, 15, threads});
+      b->Args({byte_aligned, 240000, 2000, threads});
+    }
+  }
+}
+
+BENCHMARK(COMPUTESOFTMAXINPLACE)->Apply(ComputeSoftmaxInplaceArgs)->UseRealTime();
+
+BENCHMARK(REDUCEMAXIMUMF32KERNELAVX)
+    ->ArgNames({"ByteAligned", "D"})
+    ->ArgsProduct({
+        {4, 8, 16, 32, 64, 128},                     // ByteAligned
+        {3, 4, 5, 7, 9, 11, 13, 15, 16, 500, 2000},  // D
+    })
+    ->UseRealTime();
+
+BENCHMARK(REDUCEMAXIMUMF32KERNELAVX512F)
+    ->ArgNames({"ByteAligned", "D"})
+    ->ArgsProduct({
+        {4, 8, 16, 32, 64, 128},                     // ByteAligned
+        {3, 4, 5, 7, 9, 11, 13, 15, 16, 500, 2000},  // D
+    })
+    ->UseRealTime();
+
+BENCHMARK(COMPUTESUMEXPF32KERNELAVX512F)
+    ->ArgNames({"ByteAligned", "D"})
+    ->ArgsProduct({
+        {4, 8, 16, 32, 64, 128},                 // ByteAligned
+        {3, 4, 5, 7, 9, 11, 13, 15, 500, 2000},  // D (NOTE(review): 16 is absent here, unlike the other kernel lists)
+    })
+    ->UseRealTime();
+
+BENCHMARK(COMPUTESOFTMAXOUTPUTF32KERNELAVX)
+    ->ArgNames({"ByteAligned", "D"})
+    ->ArgsProduct({
+        {4, 8, 16, 32, 64, 128},                     // ByteAligned
+        {3, 4, 5, 7, 9, 11, 13, 15, 16, 500, 2000},  // D
+    })
+    ->UseRealTime();