diff --git a/.github/stale.yml b/.github/stale.yml deleted file mode 100644 index d89f0cdd91e52..0000000000000 --- a/.github/stale.yml +++ /dev/null @@ -1,22 +0,0 @@ -# Number of days of inactivity before an issue becomes stale -daysUntilStale: 30 - -# Number of days of inactivity before a stale issue is closed -daysUntilClose: 7 - -# Issues with these labels will never be considered stale -exemptLabels: - - contributions welcome - - feature request - - regression - -# Label to use when marking an issue as stale -staleLabel: stale - -# Comment to post when marking an issue as stale. Set to `false` to disable -markComment: > - This issue has been automatically marked as stale due to inactivity and will be closed in 7 days if no further activity occurs. If further support is needed, please provide an update and/or more details. - -# Comment to post when closing a stale issue. Set to `false` to disable -closeComment: > - This issue has been automatically closed due to inactivity. Please reactivate if further support is needed. diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml new file mode 100644 index 0000000000000..67d8550d44204 --- /dev/null +++ b/.github/workflows/stale.yml @@ -0,0 +1,34 @@ +name: Close stale issues +on: + # Allows you to dictate when you want this workflow to run using cron syntax (times in UTC) + schedule: + - cron: "0 15 * * *" + # Allows you to run this workflow manually from the Actions tab + # workflow_dispatch: + +jobs: + close-stale-issues: + runs-on: ubuntu-latest + permissions: + issues: write + pull-requests: write + steps: + - uses: actions/stale@v4.1.1 + with: + # Comma separated list of labels that can be assigned to issues to exclude them from being marked as stale + exempt-issue-labels: contributions welcome, feature request, regression + # Number of days without activity before the actions/stale action labels an issue + days-before-issue-stale: 30 + # Number of days without activity before the actions/stale action closes an issue + days-before-issue-close: 7 + # Label you want to apply to issues that have been inactive for the amount of time specified by days-before-issue-stale + stale-issue-label: "stale" + # Comment that you want to add to issues that are labeled by the actions/stale action + stale-issue-message: "This issue has been automatically marked as stale due to inactivity and will be closed in 7 days if no further activity occurs. If further support is needed, please provide an update and/or more details." + # Comment that you want to add to issues that are closed by the actions/stale action + close-issue-message: "This issue has been automatically closed due to inactivity. Please reactivate if further support is needed." 
+ # If you never want this action to label PRs, set this value to -1 + days-before-pr-stale: -1 + # If you never want this action to close PRs, set this value to -1 + days-before-pr-close: -1 + repo-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.lintrunner.toml b/.lintrunner.toml index c44a66200ad1b..4e5d077b08ff4 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -45,6 +45,7 @@ exclude_patterns = [ 'cmake/external/**', # ignore generated flatbuffers code 'onnxruntime/core/flatbuffers/ort_flatbuffers_py/**', + 'orttraining/orttraining/python/training/optim/_ds_code_store.py', ] command = [ 'python', @@ -76,6 +77,7 @@ exclude_patterns = [ 'cmake/**', 'orttraining/*', 'onnxruntime/core/flatbuffers/**', + 'orttraining/orttraining/python/training/optim/_ds_code_store.py', ] command = [ 'python', diff --git a/cgmanifests/generated/cgmanifest.json b/cgmanifests/generated/cgmanifest.json index 08ca90d7c3b7f..6b0e3659bd234 100644 --- a/cgmanifests/generated/cgmanifest.json +++ b/cgmanifests/generated/cgmanifest.json @@ -2,6 +2,36 @@ "$schema": "https://json.schemastore.org/component-detection-manifest.json", "Version": 1, "Registrations": [ + { + "component": { + "type": "git", + "git": { + "commitHash": "a896e3d066448b3530dbcaa48869fafefd738f57", + "repositoryUrl": "https://github.com/emscripten-core/emsdk.git" + }, + "comments": "git submodule at cmake/external/emsdk" + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "7a2ed51a6b682a83e345ff49fc4cfd7ca47550db", + "repositoryUrl": "https://github.com/google/libprotobuf-mutator.git" + }, + "comments": "git submodule at cmake/external/libprotobuf-mutator" + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "b86cc54efce19530fb953e4b21f57e6b3888534c", + "repositoryUrl": "https://github.com/onnx/onnx.git" + }, + "comments": "git submodule at cmake/external/onnx" + } + }, { "component": { "type": "git", @@ -166,17 +196,7 @@ "component": { "type": "git", "git": { - "commitHash": "fdefbe85ed9c362b95b9b401cd19db068a76141f", - "repositoryUrl": "https://github.com/onnx/onnx.git" - }, - "comments": "onnx" - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "0462dc31ae78f48744b6141ae376df1f96d3f459", + "commitHash": "a43ce67187bab219520fd80f21af8bbd4354bc8c", "repositoryUrl": "https://github.com/onnx/onnx-tensorrt.git" }, "comments": "onnx_tensorrt" diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 1b230f9557984..94181448fd21c 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -78,6 +78,7 @@ option(onnxruntime_USE_CUDA "Build with CUDA support" OFF) # use. If you hit any problem with that, please do not report it to GTest. Turn OFF the following build option instead. cmake_dependent_option(onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS "Build with CUDA unit tests" OFF "onnxruntime_USE_CUDA;onnxruntime_BUILD_UNIT_TESTS;LINUX" OFF) +option(onnxruntime_USE_CUDA_NHWC_OPS "Build CUDA with NHWC op support" OFF) option(onnxruntime_ENABLE_CUDA_LINE_NUMBER_INFO "When building with CUDA support, generate device code line number information." 
OFF) option(onnxruntime_USE_OPENVINO "Build with OpenVINO support" OFF) option(onnxruntime_USE_COREML "Build with CoreML support" OFF) @@ -671,6 +672,9 @@ set(ORT_PROVIDER_FLAGS) set(ORT_PROVIDER_CMAKE_FLAGS) if (onnxruntime_USE_CUDA) + if (onnxruntime_USE_CUDA_NHWC_OPS) + add_compile_definitions(ENABLE_CUDA_NHWC_OPS) + endif() enable_language(CUDA) message( STATUS "CMAKE_CUDA_COMPILER_VERSION: ${CMAKE_CUDA_COMPILER_VERSION}") @@ -1278,14 +1282,6 @@ if (onnxruntime_USE_OPENVINO) add_definitions(-DOPENVINO_CONFIG_CPU_FP16=1) endif() - if (onnxruntime_USE_OPENVINO_VPUX_FP16) - add_definitions(-DOPENVINO_CONFIG_VPUX_FP16=1) - endif() - - if (onnxruntime_USE_OPENVINO_VPUX_U8) - add_definitions(-DOPENVINO_CONFIG_VPUX_U8=1) - endif() - if (onnxruntime_USE_OPENVINO_GPU_FP32_NP) add_definitions(-DOPENVINO_CONFIG_GPU_FP32=1) add_definitions(-DOPENVINO_DISABLE_GRAPH_PARTITION=1) @@ -1306,16 +1302,6 @@ if (onnxruntime_USE_OPENVINO) add_definitions(-DOPENVINO_DISABLE_GRAPH_PARTITION=1) endif() - if (onnxruntime_USE_OPENVINO_VPUX_FP32_NP) - add_definitions(-DOPENVINO_CONFIG_VPUX_FP32=1) - add_definitions(-DOPENVINO_DISABLE_GRAPH_PARTITION=1) - endif() - - if (onnxruntime_USE_OPENVINO_VPUX_FP16_NP) - add_definitions(-DOPENVINO_CONFIG_VPUX_FP16=1) - add_definitions(-DOPENVINO_DISABLE_GRAPH_PARTITION=1) - endif() - if (onnxruntime_USE_OPENVINO_HETERO) add_definitions(-DOPENVINO_CONFIG_HETERO=1) add_definitions(-DDEVICE_NAME="${onnxruntime_USE_OPENVINO_DEVICE}") diff --git a/cmake/deps.txt b/cmake/deps.txt index 7cf49f02333a4..4aab8c974d62f 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -3,30 +3,35 @@ #The columns are separated by ";" because a list in cmake is just a ";" separated group of strings. #Names should be in lower case. They will be used as variable names in cmake. #URLs can be either https URLs or local file paths in cmake-style(directory separator is a forward slash character). -#SHA1 hashes can be generated by running sha1sum command. +#SHA1 hashes can be generated by running sha1sum command on linux. PowerShell can also be used: +# (Get-FileHash -Algorithm SHA1 ).Hash.ToLower() #If you need to change abseil's version to a different one, you may also want to update external\abseil-cpp.natvis #since the file contains a version string: "lts_20230802". However, the file is for debugging purposes only and would #not affect built binaries. +# +# NOTE: You must run deps_update_and_upload.py when ready to test your changes in a CI. 
+# See https://microsoft.sharepoint.com/teams/ONNX2/_layouts/OneNote.aspx?id=%2Fteams%2FONNX2%2FShared%20Documents%2FNotebooks%2FONNX%20Ecosystem%20Team%20Notebook&wd=target%28Development.one%7C63D3AB47-51D1-4A62-9965-66882234BD44%2FAdd%20or%20update%20a%20dependency%20in%20deps.txt%7C0E9ED71D-89D5-40FA-B05F-C0123289C591%2F%29 +# abseil_cpp;https://github.com/abseil/abseil-cpp/archive/refs/tags/20230802.0.zip;04271dfbfac59269b6939e1e9d5faf0d18a7ba91 cxxopts;https://github.com/jarro2783/cxxopts/archive/3c73d91c0b04e2b59462f0a741be8c07024c1bc0.zip;6c6ca7f8480b26c8d00476e0e24b7184717fe4f0 date;https://github.com/HowardHinnant/date/archive/refs/tags/v3.0.1.zip;2dac0c81dc54ebdd8f8d073a75c053b04b56e159 dlpack;https://github.com/dmlc/dlpack/archive/refs/tags/v0.6.zip;4d565dd2e5b31321e5549591d78aa7f377173445 -eigen;https://gitlab.com/libeigen/eigen/-/archive/3.4/eigen-3.4.zip;ee201b07085203ea7bd8eb97cbcb31b07cfa3efb +eigen;https://gitlab.com/libeigen/eigen/-/archive/3.4.0/eigen-3.4.0.zip;ef24286b7ece8737c99fa831b02941843546c081 flatbuffers;https://github.com/google/flatbuffers/archive/refs/tags/v1.12.0.zip;ba0a75fd12dbef8f6557a74e611b7a3d0c5fe7bf fp16;https://github.com/Maratyszcza/FP16/archive/0a92994d729ff76a58f692d3028ca1b64b145d91.zip;b985f6985a05a1c03ff1bb71190f66d8f98a1494 fxdiv;https://github.com/Maratyszcza/FXdiv/archive/63058eff77e11aa15bf531df5dd34395ec3017c8.zip;a5658f4036402dbca7cebee32be57fb8149811e1 google_benchmark;https://github.com/google/benchmark/archive/refs/tags/v1.7.0.zip;e97c368b176e8614e3f1bf13dd9abcf6a7ad9908 google_nsync;https://github.com/google/nsync/archive/refs/tags/1.26.0.zip;5e7c00ef6bf5b787386fc040067903ec774e2752 googletest;https://github.com/google/googletest/archive/refs/tags/v1.14.0.zip;0ac421f2ec11af38b0fff0f1992184032731a8bc -googlexnnpack;https://github.com/google/XNNPACK/archive/003c580e696a774afdc984996ee909b7c8d8128c.zip;9f192e3f15e1e37ae9c78d53eeea47e45c5eb31c +googlexnnpack;https://github.com/google/XNNPACK/archive/0da379fc4808f9601faef392352018c741c0f297.zip;663883491e380b628e0a5b162b5f2658032fae73 json;https://github.com/nlohmann/json/archive/refs/tags/v3.10.5.zip;f257f8dc27c5b8c085dc887b40cddd18ae1f725c microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf368104cd22a87b4dd0c80228919bb2df3e2a14 microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5 mimalloc;https://github.com/microsoft/mimalloc/archive/refs/tags/v2.1.1.zip;d5ee7d34223d0567892db5179849939c8769dc41 mp11;https://github.com/boostorg/mp11/archive/refs/tags/boost-1.82.0.zip;9bc9e01dffb64d9e0773b2e44d2f22c51aace063 -onnx;https://github.com/onnx/onnx/archive/14303de049144035dfd94ace5f7a3b44773b1aad.zip;250eab9690392b248d75b56e605fb49eca373442 +onnx;https://github.com/onnx/onnx/archive/refs/tags/v1.15.0.zip;54c3f960a0541c5d8d3e60c2933e11f5d3688a11 #use the commit of supporting all the plugins and TRT 8.6-GA (https://github.com/onnx/onnx-tensorrt/commit/0462dc31ae78f48744b6141ae376df1f96d3f459) -onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/0462dc31ae78f48744b6141ae376df1f96d3f459.zip;5ff086361956cceb81ed17453a1fd8db2aa4328d +onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/a43ce67187bab219520fd80f21af8bbd4354bc8c.zip;572535aefef477050f86744dfab1fef840198035 protobuf;https://github.com/protocolbuffers/protobuf/archive/refs/tags/v21.12.zip;7cf2733949036c7d52fda017badcab093fe73bfa 
protoc_win64;https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-win64.zip;b4521f7ada5b260380f94c4bd7f1b7684c76969a protoc_win32;https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-win32.zip;3688010318192c46ce73213cdfb6b3e5656da874 @@ -35,7 +40,7 @@ protoc_linux_x86;https://github.com/protocolbuffers/protobuf/releases/download/v protoc_linux_aarch64;https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-linux-aarch_64.zip;df9d45470b0b8cf939dd2f0ec6b88e9cafc4d617 protoc_mac_universal;https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-osx-universal_binary.zip;23710c3d1c2036d8d65a6a22234372fa2d7af9ef psimd;https://github.com/Maratyszcza/psimd/archive/072586a71b55b7f8c584153d223e95687148a900.zip;1f5454b01f06f9656b77e4a5e2e31d7422487013 -pthreadpool;https://github.com/Maratyszcza/pthreadpool/archive/1787867f6183f056420e532eec640cba25efafea.zip;e43e80781560c5ab404a4da20f34d846f5f5d101 +pthreadpool;https://github.com/Maratyszcza/pthreadpool/archive/4fe0e1e183925bf8cfa6aae24237e724a96479b8.zip;07a0aa91dd9bf86f31b95497e00f31d8a261a4bd pybind11;https://github.com/pybind/pybind11/archive/refs/tags/v2.10.1.zip;769b6aa67a77f17a770960f604b727645b6f6a13 pytorch_cpuinfo;https://github.com/pytorch/cpuinfo/archive/959002f82d7962a473d8bf301845f2af720e0aa4.zip;85da3caa60eb2b148613b443fbc2bfdc30689965 re2;https://github.com/google/re2/archive/refs/tags/2022-06-01.zip;aa77313b76e91b531ee7f3e45f004c6a502a5374 diff --git a/cmake/external/onnx b/cmake/external/onnx index e2525550194ce..b86cc54efce19 160000 --- a/cmake/external/onnx +++ b/cmake/external/onnx @@ -1 +1 @@ -Subproject commit e2525550194ce3d8a2c4a3af451c9d9b3ae6650e +Subproject commit b86cc54efce19530fb953e4b21f57e6b3888534c diff --git a/cmake/external/xnnpack.cmake b/cmake/external/xnnpack.cmake index 7455584f1a625..e661aa51bfc17 100644 --- a/cmake/external/xnnpack.cmake +++ b/cmake/external/xnnpack.cmake @@ -25,17 +25,23 @@ set(FXDIV_SOURCE_DIR ${fxdiv_SOURCE_DIR}) FetchContent_Declare(pthreadpool URL ${DEP_URL_pthreadpool} URL_HASH SHA1=${DEP_SHA1_pthreadpool}) onnxruntime_fetchcontent_makeavailable(pthreadpool) -FetchContent_Declare(googlexnnpack URL ${DEP_URL_googlexnnpack} URL_HASH SHA1=${DEP_SHA1_googlexnnpack} -PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/xnnpack/AddEmscriptenAndIosSupport.patch) +FetchContent_Declare(googlexnnpack URL ${DEP_URL_googlexnnpack} URL_HASH SHA1=${DEP_SHA1_googlexnnpack} + PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/xnnpack/AddEmscriptenAndIosSupport.patch + ) onnxruntime_fetchcontent_makeavailable(googlexnnpack) set(XNNPACK_DIR ${googlexnnpack_SOURCE_DIR}) set(XNNPACK_INCLUDE_DIR ${XNNPACK_DIR}/include) set(onnxruntime_EXTERNAL_LIBRARIES_XNNPACK XNNPACK pthreadpool) + # the XNNPACK CMake setup doesn't include the WASM kernels so we have to manually set those up if(CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + # See source lists in _deps/googlexnnpack-src/BUILD.bazel for wasm_prod_microkernels + message("Adding WebAssembly Source Files to XNNPACK") + set(wasm_srcs "") + file(READ "${XNNPACK_DIR}/BUILD.bazel" xnnpack_bazel_config) # Replace newlines with semicolon so that it is treated as a list by CMake @@ -70,25 +76,23 @@ if(CMAKE_SYSTEM_NAME STREQUAL "Emscripten") set(${target_srcs} ${bazel_srcs} PARENT_SCOPE) endfunction() - GetSrcListFromBazel("PROD_SCALAR_WASM_MICROKERNEL_SRCS" 
prod_scalar_wasm_srcs) - GetSrcListFromBazel("ALL_WASM_MICROKERNEL_SRCS" all_wasm_srcs) - GetSrcListFromBazel("WASM32_ASM_MICROKERNEL_SRCS" wasm32_asm_srcs) + GetSrcListFromBazel("OPERATOR_SRCS" operator_srcs) + GetSrcListFromBazel("TABLE_SRCS" table_srcs) + list(APPEND wasm_srcs ${operator_srcs} ${table_srcs}) - message(DEBUG "prod_scalar_wasm_srcs: ${prod_scalar_wasm_srcs}\n") - message(DEBUG "all_wasm_srcs: ${all_wasm_srcs}\n") - message(DEBUG "wasm32_asm_srcs: ${wasm32_asm_srcs}\n") + # kernels + list(APPEND wasm_srcs ${XNNPACK_DIR}/src/amalgam/gen/scalar.c) + list(APPEND wasm_srcs ${XNNPACK_DIR}/src/amalgam/gen/wasm.c) - message("Adding WebAssembly Source Files to XNNPACK") - set(wasm_srcs "") - list(APPEND wasm_srcs ${prod_scalar_wasm_srcs}) - list(APPEND wasm_srcs ${all_wasm_srcs}) - list(APPEND wasm_srcs ${wasm32_asm_srcs}) + if(onnxruntime_ENABLE_WEBASSEMBLY_SIMD) + list(APPEND wasm_srcs ${XNNPACK_DIR}/src/amalgam/gen/wasmsimd.c) + target_compile_options(XNNPACK PRIVATE "-msimd128") + endif() + message(DEBUG "wasm_srcs: ${wasm_srcs}\n") target_sources(XNNPACK PRIVATE ${wasm_srcs}) - if(onnxruntime_ENABLE_WEBASSEMBLY_SIMD) - GetSrcListFromBazel("ALL_WASMSIMD_MICROKERNEL_SRCS" all_wasmsimd_srcs) - message(DEBUG "all_wasmsimd_srcs: ${all_wasmsimd_srcs}") - target_sources(XNNPACK PRIVATE ${all_wasmsimd_srcs}) - endif() + # add flags from BAZEL.build + target_compile_options(XNNPACK PRIVATE "-fno-fast-math") + target_compile_options(XNNPACK PRIVATE "-fno-math-errno") endif() diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake index 6ccaf00499e95..9d9b006c595bb 100644 --- a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ -282,44 +282,77 @@ endif() # Assemble the Apple static framework (iOS and macOS) if(onnxruntime_BUILD_APPLE_FRAMEWORK) + if(${CMAKE_SYSTEM_NAME} STREQUAL "iOS") + set(STATIC_FRAMEWORK_OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_BUILD_TYPE}-${CMAKE_OSX_SYSROOT}) + else() # macOS + set(STATIC_FRAMEWORK_OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}) + endif() + + # Setup the various directories required. Remove any existing ones so we start with a clean directory. set(STATIC_LIB_DIR ${CMAKE_CURRENT_BINARY_DIR}/static_libraries) - file(MAKE_DIRECTORY ${STATIC_LIB_DIR}) + set(STATIC_LIB_TEMP_DIR ${STATIC_LIB_DIR}/temp) + add_custom_command(TARGET onnxruntime PRE_BUILD COMMAND ${CMAKE_COMMAND} -E rm -rf ${STATIC_LIB_DIR}) + add_custom_command(TARGET onnxruntime PRE_BUILD COMMAND ${CMAKE_COMMAND} -E make_directory ${STATIC_LIB_DIR}) + add_custom_command(TARGET onnxruntime PRE_BUILD COMMAND ${CMAKE_COMMAND} -E make_directory ${STATIC_LIB_TEMP_DIR}) - # Remove the existing files in the STATIC_LIB_DIR folder - file(GLOB _OLD_STATIC_LIBS ${STATIC_LIB_DIR}/*.a) - file(REMOVE "${_OLD_STATIC_LIBS}") + set(STATIC_FRAMEWORK_DIR ${STATIC_FRAMEWORK_OUTPUT_DIR}/static_framework/onnxruntime.framework) + add_custom_command(TARGET onnxruntime PRE_BUILD COMMAND ${CMAKE_COMMAND} -E rm -rf ${STATIC_FRAMEWORK_DIR}) + add_custom_command(TARGET onnxruntime PRE_BUILD COMMAND ${CMAKE_COMMAND} -E make_directory ${STATIC_FRAMEWORK_DIR}) + + # replicate XCode's Single Object Pre-Link + # link the internal onnxruntime .o files with the external .a files into a single relocatable object + # to enforce symbol visibility. doing it this way limits the symbols included from the .a files to symbols used + # by the ORT .o files. 
- # Go through all the static libraries, and create symbolic links - foreach(_LIB ${onnxruntime_INTERNAL_LIBRARIES} ${onnxruntime_EXTERNAL_LIBRARIES}) + # If it's an onnxruntime library, extract .o files to a separate directory for each library to avoid any clashes + # with filenames (e.g. utils.o) + foreach(_LIB ${onnxruntime_INTERNAL_LIBRARIES} ) GET_TARGET_PROPERTY(_LIB_TYPE ${_LIB} TYPE) if(_LIB_TYPE STREQUAL "STATIC_LIBRARY") - add_custom_command(TARGET onnxruntime POST_BUILD COMMAND ${CMAKE_COMMAND} -E create_symlink $ ${STATIC_LIB_DIR}/$) + set(CUR_STATIC_LIB_OBJ_DIR ${STATIC_LIB_TEMP_DIR}/$) + add_custom_command(TARGET onnxruntime POST_BUILD + COMMAND ${CMAKE_COMMAND} -E make_directory ${CUR_STATIC_LIB_OBJ_DIR}) + + add_custom_command(TARGET onnxruntime POST_BUILD + COMMAND ar ARGS -x $ + WORKING_DIRECTORY ${CUR_STATIC_LIB_OBJ_DIR}) endif() endforeach() - if(${CMAKE_SYSTEM_NAME} STREQUAL "iOS") - set(STATIC_FRAMEWORK_OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_BUILD_TYPE}-${CMAKE_OSX_SYSROOT}) - else() # macOS - set(STATIC_FRAMEWORK_OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}) - endif() + # for external libraries we create a symlink to the .a file + foreach(_LIB ${onnxruntime_EXTERNAL_LIBRARIES}) + GET_TARGET_PROPERTY(_LIB_TYPE ${_LIB} TYPE) + if(_LIB_TYPE STREQUAL "STATIC_LIBRARY") + add_custom_command(TARGET onnxruntime POST_BUILD + COMMAND ${CMAKE_COMMAND} -E create_symlink + $ ${STATIC_LIB_DIR}/$) + endif() + endforeach() - # Assemble the static framework - set(STATIC_FRAMEWORK_DIR ${STATIC_FRAMEWORK_OUTPUT_DIR}/static_framework/onnxruntime.framework) - set(STATIC_FRAMEWORK_HEADER_DIR ${STATIC_FRAMEWORK_DIR}/Headers) - file(MAKE_DIRECTORY ${STATIC_FRAMEWORK_DIR}) - # Remove all files under STATIC_FRAMEWORK_DIR (if any) - file(GLOB_RECURSE _OLD_STATIC_FRAMEWORK ${STATIC_FRAMEWORK_DIR}/*.*) - file(REMOVE "${_OLD_STATIC_FRAMEWORK}") + # do the pre-link with `ld -r` to create a single relocatable object with correct symbol visibility + add_custom_command(TARGET onnxruntime POST_BUILD + COMMAND ld ARGS -r -o ${STATIC_LIB_DIR}/prelinked_objects.o */*.o ../*.a + WORKING_DIRECTORY ${STATIC_LIB_TEMP_DIR}) + + # create the static library + add_custom_command(TARGET onnxruntime POST_BUILD + COMMAND libtool -static -o ${STATIC_FRAMEWORK_DIR}/onnxruntime prelinked_objects.o + WORKING_DIRECTORY ${STATIC_LIB_DIR}) + # Assemble the other pieces of the static framework + add_custom_command(TARGET onnxruntime POST_BUILD + COMMAND ${CMAKE_COMMAND} -E + copy_if_different ${INFO_PLIST_PATH} ${STATIC_FRAMEWORK_DIR}/Info.plist) + + # add the framework header files + set(STATIC_FRAMEWORK_HEADER_DIR ${STATIC_FRAMEWORK_DIR}/Headers) file(MAKE_DIRECTORY ${STATIC_FRAMEWORK_HEADER_DIR}) - # copy the header files one by one, and the Info.plist foreach(h_ ${ONNXRUNTIME_PUBLIC_HEADERS}) get_filename_component(HEADER_NAME_ ${h_} NAME) - add_custom_command(TARGET onnxruntime POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different ${h_} ${STATIC_FRAMEWORK_HEADER_DIR}/${HEADER_NAME_}) + add_custom_command(TARGET onnxruntime POST_BUILD + COMMAND ${CMAKE_COMMAND} -E + copy_if_different ${h_} ${STATIC_FRAMEWORK_HEADER_DIR}/${HEADER_NAME_}) endforeach() - add_custom_command(TARGET onnxruntime POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different ${INFO_PLIST_PATH} ${STATIC_FRAMEWORK_DIR}/Info.plist) - # link the static library - add_custom_command(TARGET onnxruntime POST_BUILD COMMAND libtool -static -o ${STATIC_FRAMEWORK_DIR}/onnxruntime *.a WORKING_DIRECTORY ${STATIC_LIB_DIR}) endif() diff --git 
a/cmake/onnxruntime_graph.cmake b/cmake/onnxruntime_graph.cmake index 735c86956ec4f..3f532ec2c3261 100644 --- a/cmake/onnxruntime_graph.cmake +++ b/cmake/onnxruntime_graph.cmake @@ -20,6 +20,8 @@ if (onnxruntime_MINIMAL_BUILD) "${ONNXRUNTIME_ROOT}/core/graph/contrib_ops/onnx_deprecated_operators.cc" "${ONNXRUNTIME_ROOT}/core/graph/contrib_ops/onnx_function_util.h" "${ONNXRUNTIME_ROOT}/core/graph/contrib_ops/onnx_function_util.cc" + "${ONNXRUNTIME_ROOT}/core/graph/contrib_ops/shape_inference_functions.h" + "${ONNXRUNTIME_ROOT}/core/graph/contrib_ops/shape_inference_functions.cc" "${ONNXRUNTIME_ROOT}/core/graph/function_template.h" "${ONNXRUNTIME_ROOT}/core/graph/function_utils.h" "${ONNXRUNTIME_ROOT}/core/graph/function_utils.cc" diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 992908392c946..a62b1b259d109 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -339,12 +339,18 @@ else() set(mlas_platform_srcs ${mlas_platform_srcs} ${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S + ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S + ${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S ${MLAS_SRC_DIR}/activate_fp16.cpp ${MLAS_SRC_DIR}/dwconv.cpp ${MLAS_SRC_DIR}/halfgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/pooling_fp16.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp ) set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") set_source_files_properties(${MLAS_SRC_DIR}/activate_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index 003012f8da071..f2a16fb29dc62 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -38,6 +38,11 @@ "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/sharding.cc" "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_matmul.cc" "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_slice.cc" + "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_reshape.cc" + "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_expand.cc" + "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_reduce.cc" + "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_unsqueeze.cc" + "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_squeeze.cc" ) endif() # add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio @@ -246,4 +251,4 @@ install(TARGETS onnxruntime_providers_cuda ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) \ No newline at end of file + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) diff --git a/cmake/onnxruntime_providers_migraphx.cmake b/cmake/onnxruntime_providers_migraphx.cmake index 632600288bec9..91ac66a40721d 100644 --- a/cmake/onnxruntime_providers_migraphx.cmake +++ b/cmake/onnxruntime_providers_migraphx.cmake @@ -42,7 +42,7 @@ 
onnxruntime_add_include_to_target(onnxruntime_providers_migraphx onnxruntime_common onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface) add_dependencies(onnxruntime_providers_migraphx onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES}) target_link_libraries(onnxruntime_providers_migraphx PRIVATE ${migraphx_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface) - target_include_directories(onnxruntime_providers_migraphx PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR}) + target_include_directories(onnxruntime_providers_migraphx PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/onnxruntime) set_target_properties(onnxruntime_providers_migraphx PROPERTIES LINKER_LANGUAGE CXX) set_target_properties(onnxruntime_providers_migraphx PROPERTIES FOLDER "ONNXRuntime") target_compile_definitions(onnxruntime_providers_migraphx PRIVATE ONNXIFI_BUILD_LIBRARY=1) @@ -72,4 +72,4 @@ ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} - ) \ No newline at end of file + ) diff --git a/cmake/onnxruntime_providers_rocm.cmake b/cmake/onnxruntime_providers_rocm.cmake index 2bb4c7d600a09..b66268291579c 100644 --- a/cmake/onnxruntime_providers_rocm.cmake +++ b/cmake/onnxruntime_providers_rocm.cmake @@ -10,6 +10,7 @@ find_package(hiprand REQUIRED) find_package(rocblas REQUIRED) find_package(MIOpen REQUIRED) + find_package(hipfft REQUIRED) # MIOpen version if(NOT DEFINED ENV{MIOPEN_PATH}) @@ -48,7 +49,7 @@ find_library(RCCL_LIB rccl REQUIRED) find_library(ROCTRACER_LIB roctracer64 REQUIRED) - set(ONNXRUNTIME_ROCM_LIBS roc::rocblas MIOpen ${RCCL_LIB} ${ROCTRACER_LIB}) + set(ONNXRUNTIME_ROCM_LIBS roc::rocblas MIOpen hip::hipfft ${RCCL_LIB} ${ROCTRACER_LIB}) file(GLOB_RECURSE onnxruntime_providers_rocm_cc_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/core/providers/rocm/*.h" @@ -219,4 +220,4 @@ install(TARGETS onnxruntime_providers_rocm ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) \ No newline at end of file + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) diff --git a/cmake/onnxruntime_providers_tensorrt.cmake b/cmake/onnxruntime_providers_tensorrt.cmake index ea94734afe332..686a993de3a4a 100644 --- a/cmake/onnxruntime_providers_tensorrt.cmake +++ b/cmake/onnxruntime_providers_tensorrt.cmake @@ -57,12 +57,14 @@ URL ${DEP_URL_onnx_tensorrt} URL_HASH SHA1=${DEP_SHA1_onnx_tensorrt} ) + if (NOT CUDA_INCLUDE_DIR) + set(CUDA_INCLUDE_DIR ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) # onnx-tensorrt repo needs this variable to build + endif() # The onnx_tensorrt repo contains a test program, getSupportedAPITest, which doesn't support Windows. It uses # unistd.h. So we must exclude it from our build. onnxruntime_fetchcontent_makeavailable is for the purpose. 
onnxruntime_fetchcontent_makeavailable(onnx_tensorrt) include_directories(${onnx_tensorrt_SOURCE_DIR}) set(CMAKE_CXX_FLAGS ${OLD_CMAKE_CXX_FLAGS}) - set(CUDA_INCLUDE_DIR ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) # onnx-tensorrt repo needs this variable to build if ( CMAKE_COMPILER_IS_GNUCC ) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter") endif() @@ -110,7 +112,6 @@ endif() # ${CMAKE_CURRENT_BINARY_DIR} is so that #include "onnxruntime_config.h" inside tensor_shape.h is found - set_target_properties(onnxruntime_providers_tensorrt PROPERTIES PUBLIC_HEADER ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.h) set_target_properties(onnxruntime_providers_tensorrt PROPERTIES LINKER_LANGUAGE CUDA) set_target_properties(onnxruntime_providers_tensorrt PROPERTIES FOLDER "ONNXRuntime") target_compile_definitions(onnxruntime_providers_tensorrt PRIVATE ONNXIFI_BUILD_LIBRARY=1) @@ -141,7 +142,6 @@ endif() install(TARGETS onnxruntime_providers_tensorrt - PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) \ No newline at end of file + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) diff --git a/cmake/onnxruntime_providers_xnnpack.cmake b/cmake/onnxruntime_providers_xnnpack.cmake index 30ae90e6564cc..9c00703ca0846 100644 --- a/cmake/onnxruntime_providers_xnnpack.cmake +++ b/cmake/onnxruntime_providers_xnnpack.cmake @@ -15,7 +15,8 @@ source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_providers_xnnpack_cc_srcs}) onnxruntime_add_static_library(onnxruntime_providers_xnnpack ${onnxruntime_providers_xnnpack_cc_srcs}) onnxruntime_add_include_to_target(onnxruntime_providers_xnnpack - onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} XNNPACK pthreadpool flatbuffers::flatbuffers Boost::mp11 safeint_interface + onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} XNNPACK pthreadpool + flatbuffers::flatbuffers Boost::mp11 safeint_interface ) add_dependencies(onnxruntime_providers_xnnpack onnx ${onnxruntime_EXTERNAL_DEPENDENCIES}) @@ -35,4 +36,4 @@ # there are some in builds where sizeof(size_t) != sizeof(int64_t), e.g., in 'ONNX Runtime Web CI Pipeline' if (HAS_SHORTEN_64_TO_32 AND NOT CMAKE_SIZEOF_VOID_P EQUAL 8) target_compile_options(onnxruntime_providers_xnnpack PRIVATE -Wno-error=shorten-64-to-32) - endif() \ No newline at end of file + endif() diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index bf9adbaefabcc..a9a78668b4810 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -387,6 +387,9 @@ if (onnxruntime_ENABLE_TRAINING) file(GLOB onnxruntime_python_ortmodule_torch_cpp_ext_fused_ops_srcs CONFIGURE_DEPENDS "${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/*" ) + file(GLOB onnxruntime_python_ortmodule_graph_optimizers_srcs CONFIGURE_DEPENDS + "${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/graph_optimizers/*" + ) file(GLOB onnxruntime_python_ort_triton_srcs CONFIGURE_DEPENDS "${ORTTRAINING_SOURCE_DIR}/python/training/ort_triton/*.py" ) @@ -741,6 +744,7 @@ if (onnxruntime_ENABLE_TRAINING) COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/ortmodule/torch_cpp_extensions/cuda/torch_gpu_allocator COMMAND ${CMAKE_COMMAND} -E 
make_directory $/onnxruntime/training/ortmodule/torch_cpp_extensions/cuda/fused_ops + COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/ortmodule/graph_optimizers COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/ort_triton COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/ort_triton/kernel COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/utils @@ -794,6 +798,9 @@ if (onnxruntime_ENABLE_TRAINING) COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_ortmodule_torch_cpp_ext_fused_ops_srcs} $/onnxruntime/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/ + COMMAND ${CMAKE_COMMAND} -E copy + ${onnxruntime_python_ortmodule_graph_optimizers_srcs} + $/onnxruntime/training/ortmodule/graph_optimizers/ COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_ort_triton_srcs} $/onnxruntime/training/ort_triton/ diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index 55d03c14270d3..ce27bf756f247 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -48,19 +48,23 @@ set(contrib_ops_excluded_files "diffusion/group_norm_impl.cu" "diffusion/group_norm_impl.h" "diffusion/nhwc_conv.cc" - "math/complex_mul.cc" - "math/complex_mul.h" - "math/complex_mul_impl.cu" - "math/complex_mul_impl.h" - "math/cufft_plan_cache.h" - "math/fft_ops.cc" - "math/fft_ops.h" - "math/fft_ops_impl.cu" - "math/fft_ops_impl.h" + "math/gemm_float8.cc" + "math/gemm_float8.cu" + "math/gemm_float8.h" "quantization/attention_quantization.cc" "quantization/attention_quantization.h" "quantization/attention_quantization_impl.cu" "quantization/attention_quantization_impl.cuh" + "quantization/dequantize_blockwise.cuh" + "quantization/dequantize_blockwise.cu" + "quantization/dequantize_blockwise_bnb4.cuh" + "quantization/dequantize_blockwise_bnb4.cu" + "quantization/matmul_bnb4.cc" + "quantization/matmul_bnb4.cuh" + "quantization/matmul_bnb4.cu" + "quantization/matmul_nbits.cc" + "quantization/matmul_nbits.cuh" + "quantization/matmul_nbits.cu" "quantization/quantize_dequantize_linear.cc" "quantization/qordered_ops/qordered_attention_impl.cu" "quantization/qordered_ops/qordered_attention_impl.h" @@ -86,23 +90,15 @@ set(contrib_ops_excluded_files "quantization/qordered_ops/qordered_unary_ops.cc" "quantization/qordered_ops/qordered_unary_ops_impl.h" "quantization/qordered_ops/qordered_unary_ops_impl.cu" - "tensor/crop.cc" - "tensor/crop.h" - "tensor/crop_impl.cu" - "tensor/crop_impl.h" - "tensor/dynamicslice.cc" - "tensor/image_scaler.cc" - "tensor/image_scaler.h" - "tensor/image_scaler_impl.cu" - "tensor/image_scaler_impl.h" - "transformers/greedy_search.cc" - "transformers/greedy_search.h" - "conv_transpose_with_dynamic_pads.cc" - "conv_transpose_with_dynamic_pads.h" "cuda_contrib_kernels.cc" "cuda_contrib_kernels.h" "inverse.cc" "fused_conv.cc" + "bert/group_query_attention_helper.h" + "bert/group_query_attention.h" + "bert/group_query_attention.cc" + "bert/group_query_attention_impl.h" + "bert/group_query_attention_impl.cu" ) if (NOT onnxruntime_ENABLE_ATEN) @@ -115,14 +111,16 @@ if (NOT onnxruntime_USE_NCCL) list(APPEND contrib_ops_excluded_files "collective/sharding.cc") list(APPEND contrib_ops_excluded_files "collective/sharding_spec.cc") list(APPEND contrib_ops_excluded_files "collective/distributed_matmul.cc") + list(APPEND contrib_ops_excluded_files "collective/distributed_slice.cc") + list(APPEND contrib_ops_excluded_files "collective/distributed_reshape.cc") + list(APPEND 
contrib_ops_excluded_files "collective/distributed_expand.cc") + list(APPEND contrib_ops_excluded_files "collective/distributed_reduce.cc") + list(APPEND contrib_ops_excluded_files "collective/distributed_unsqueeze.cc") + list(APPEND contrib_ops_excluded_files "collective/distributed_squeeze.cc") endif() set(provider_excluded_files "atomic/common.cuh" - "controlflow/loop.cc" - "controlflow/loop.h" - "controlflow/scan.cc" - "controlflow/scan.h" "cu_inc/common.cuh" "math/einsum_utils/einsum_auxiliary_ops.cc" "math/einsum_utils/einsum_auxiliary_ops.h" @@ -170,7 +168,6 @@ set(provider_excluded_files "cuda_memory_check.h" "cuda_fence.cc" "cuda_fence.h" - "cuda_fwd.h" "cuda_kernel.h" "cuda_pch.cc" "cuda_pch.h" @@ -190,6 +187,8 @@ set(provider_excluded_files "gpu_data_transfer.h" "integer_gemm.cc" "tunable/*" + "cuda_nhwc_kernels.cc" + "cuda_nhwc_kernels.h" ) set(training_ops_excluded_files diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index ac9770d7cedf8..bdb0230a8ebd0 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -41,7 +41,7 @@ function(AddTest) if (MSVC) target_compile_options(${_UT_TARGET} PRIVATE "$<$:SHELL:--compiler-options /wd6330>" "$<$>:/wd6330>") - #Abseil has a lot of C4127/C4324 warnings. + #Abseil has a lot of C4127/C4324 warnings. target_compile_options(${_UT_TARGET} PRIVATE "$<$:SHELL:--compiler-options /wd4127>" "$<$>:/wd4127>") target_compile_options(${_UT_TARGET} PRIVATE "$<$:SHELL:--compiler-options /wd4324>" @@ -201,8 +201,18 @@ function(AddTest) list(APPEND TEST_NODE_FLAGS "--experimental-wasm-simd") endif() + # prefer Node from emsdk so the version is more deterministic + if (DEFINED ENV{EMSDK_NODE}) + set(NODE_EXECUTABLE $ENV{EMSDK_NODE}) + else() + # warning as we don't know what node version is being used and whether things like the TEST_NODE_FLAGS + # will be valid. e.g. "--experimental-wasm-simd" is not valid with node v20 or later. + message(WARNING "EMSDK_NODE environment variable was not set. 
Falling back to system `node`.") + set(NODE_EXECUTABLE node) + endif() + add_test(NAME ${_UT_TARGET} - COMMAND node ${TEST_NODE_FLAGS} ${_UT_TARGET}.js ${TEST_ARGS} + COMMAND ${NODE_EXECUTABLE} ${TEST_NODE_FLAGS} ${_UT_TARGET}.js ${TEST_ARGS} WORKING_DIRECTORY $ ) endif() @@ -374,6 +384,13 @@ if (onnxruntime_USE_CUDA AND NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_R "${TEST_SRC_DIR}/providers/cuda/*" ) list(APPEND onnxruntime_test_providers_src ${onnxruntime_test_providers_cuda_src}) + + if (onnxruntime_USE_CUDA_NHWC_OPS) + file(GLOB onnxruntime_test_providers_cuda_nhwc_src CONFIGURE_DEPENDS + "${TEST_SRC_DIR}/providers/cuda/nhwc/*.cc" + ) + list(APPEND onnxruntime_test_providers_src ${onnxruntime_test_providers_cuda_nhwc_src}) + endif() endif() if (onnxruntime_USE_CANN) @@ -851,7 +868,7 @@ if (HAS_SHORTEN_64_TO_32 AND NOT CMAKE_SIZEOF_VOID_P EQUAL 8) endif() if (UNIX AND onnxruntime_USE_TENSORRT) - # The test_main.cc includes NvInfer.h where it has many deprecated declarations + # The test_main.cc includes NvInfer.h where it has many deprecated declarations # simply ignore them for TensorRT EP build set_property(TARGET onnxruntime_test_all APPEND_STRING PROPERTY COMPILE_FLAGS "-Wno-deprecated-declarations") endif() @@ -1294,7 +1311,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) endif() if (UNIX AND onnxruntime_USE_TENSORRT) - # The test_main.cc includes NvInfer.h where it has many deprecated declarations + # The test_main.cc includes NvInfer.h where it has many deprecated declarations # simply ignore them for TensorRT EP build set_property(TARGET onnxruntime_shared_lib_test APPEND_STRING PROPERTY COMPILE_FLAGS "-Wno-deprecated-declarations") endif() @@ -1583,7 +1600,7 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") endif() if (UNIX AND onnxruntime_USE_TENSORRT) - # The test_main.cc includes NvInfer.h where it has many deprecated declarations + # The test_main.cc includes NvInfer.h where it has many deprecated declarations # simply ignore them for TensorRT EP build set_property(TARGET onnxruntime_customopregistration_test APPEND_STRING PROPERTY COMPILE_FLAGS "-Wno-deprecated-declarations") endif() diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake index c6510c97a617e..9014089cb6112 100644 --- a/cmake/onnxruntime_webassembly.cmake +++ b/cmake/onnxruntime_webassembly.cmake @@ -192,8 +192,13 @@ else() onnxruntime_util re2::re2 ) + + set(EXPORTED_RUNTIME_METHODS "'stackAlloc','stackRestore','stackSave','UTF8ToString','stringToUTF8','lengthBytesUTF8'") + if (onnxruntime_USE_XNNPACK) target_link_libraries(onnxruntime_webassembly PRIVATE XNNPACK) + string(APPEND EXPORTED_RUNTIME_METHODS ",'addFunction'") + target_link_options(onnxruntime_webassembly PRIVATE "SHELL:-s ALLOW_TABLE_GROWTH=1") endif() if(onnxruntime_USE_WEBNN) @@ -204,7 +209,6 @@ else() target_link_libraries(onnxruntime_webassembly PRIVATE tensorboard) endif() - set(EXPORTED_RUNTIME_METHODS "['stackAlloc','stackRestore','stackSave','UTF8ToString','stringToUTF8','lengthBytesUTF8']") if (onnxruntime_USE_JSEP) set(EXPORTED_FUNCTIONS "_malloc,_free,_JsepOutput,_JsepGetNodeName") else() @@ -212,7 +216,7 @@ else() endif() target_link_options(onnxruntime_webassembly PRIVATE - "SHELL:-s EXPORTED_RUNTIME_METHODS=${EXPORTED_RUNTIME_METHODS}" + "SHELL:-s EXPORTED_RUNTIME_METHODS=[${EXPORTED_RUNTIME_METHODS}]" "SHELL:-s EXPORTED_FUNCTIONS=${EXPORTED_FUNCTIONS}" "SHELL:-s MAXIMUM_MEMORY=4294967296" "SHELL:-s EXIT_RUNTIME=0" diff --git a/cmake/patches/xnnpack/AddEmscriptenAndIosSupport.patch 
b/cmake/patches/xnnpack/AddEmscriptenAndIosSupport.patch index 37bdbf9fb53f6..460b4d97c499b 100644 --- a/cmake/patches/xnnpack/AddEmscriptenAndIosSupport.patch +++ b/cmake/patches/xnnpack/AddEmscriptenAndIosSupport.patch @@ -1,66 +1,27 @@ diff --git a/CMakeLists.txt b/CMakeLists.txt -index d53c48aa1..77c3cf983 100755 +index dba9b4687..bcaa18ad7 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt -@@ -105,22 +105,12 @@ ENDIF() - +@@ -122,7 +122,7 @@ ENDIF() + # ---[ Build flags IF(NOT CMAKE_SYSTEM_NAME) MESSAGE(FATAL_ERROR "CMAKE_SYSTEM_NAME not defined") --ELSEIF(NOT CMAKE_SYSTEM_NAME MATCHES "^(Darwin|Linux|Android|Windows|CYGWIN|MSYS)$") -+ELSEIF(NOT CMAKE_SYSTEM_NAME MATCHES "^(Darwin|Linux|Android|Windows|CYGWIN|MSYS|Emscripten|iOS)$") - MESSAGE(FATAL_ERROR "Unrecognized CMAKE_SYSTEM_NAME = ${CMAKE_SYSTEM_NAME}") +-ELSEIF(NOT CMAKE_SYSTEM_NAME MATCHES "^(Android|Darwin|iOS|Linux|Windows|CYGWIN|MSYS|QURT)$") ++ELSEIF(NOT CMAKE_SYSTEM_NAME MATCHES "^(Android|Darwin|iOS|Linux|Windows|CYGWIN|MSYS|QURT|Emscripten|iOS)$") + MESSAGE(FATAL_ERROR "Unrecognized CMAKE_SYSTEM_NAME value \"${CMAKE_SYSTEM_NAME}\"") ENDIF() - - # ---[ Download deps - IF(NOT XNNPACK_USE_SYSTEM_LIBS) -- IF(NOT DEFINED CLOG_SOURCE_DIR) -- MESSAGE(STATUS "Downloading clog to ${CMAKE_BINARY_DIR}/clog-source (define CLOG_SOURCE_DIR to avoid it)") -- CONFIGURE_FILE(cmake/DownloadCLog.cmake "${CMAKE_BINARY_DIR}/clog-download/CMakeLists.txt") -- EXECUTE_PROCESS(COMMAND "${CMAKE_COMMAND}" -G "${CMAKE_GENERATOR}" . -- WORKING_DIRECTORY "${CMAKE_BINARY_DIR}/clog-download") -- EXECUTE_PROCESS(COMMAND "${CMAKE_COMMAND}" --build . -- WORKING_DIRECTORY "${CMAKE_BINARY_DIR}/clog-download") -- SET(CLOG_SOURCE_DIR "${CMAKE_BINARY_DIR}/clog-source" CACHE STRING "clog source directory") -- ENDIF() -- - IF(NOT DEFINED CPUINFO_SOURCE_DIR) - MESSAGE(STATUS "Downloading cpuinfo to ${CMAKE_BINARY_DIR}/cpuinfo-source (define CPUINFO_SOURCE_DIR to avoid it)") - CONFIGURE_FILE(cmake/DownloadCpuinfo.cmake "${CMAKE_BINARY_DIR}/cpuinfo-download/CMakeLists.txt") -@@ -7108,6 +7098,10 @@ IF(MSVC) - SET_PROPERTY(SOURCE ${ALL_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS "$<$>: /O2 >") - SET_PROPERTY(SOURCE ${HOT_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS "$<$>: /O2 >") - SET_PROPERTY(SOURCE ${COLD_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS "$<$>: /O1 >") -+ELSEIF(CMAKE_GENERATOR STREQUAL Xcode) -+ TARGET_COMPILE_OPTIONS(all_microkernels PRIVATE $<$>: -O2 >) -+ TARGET_COMPILE_OPTIONS(XNNPACK PRIVATE $<$>: -O2 >) -+ TARGET_COMPILE_OPTIONS(XNNPACK PRIVATE $<$>: -Os >) - ELSE() - SET_PROPERTY(SOURCE ${ALL_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS "$<$>: -O2 >") - SET_PROPERTY(SOURCE ${HOT_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS "$<$>: -O2 >") -@@ -7142,26 +7136,6 @@ IF(LIBM) - TARGET_LINK_LIBRARIES(indirection PRIVATE ${LIBM}) + IF(CMAKE_SYSTEM_NAME MATCHES "Windows") +@@ -534,7 +534,12 @@ IF(XNNPACK_BUILD_LIBRARY) + TARGET_LINK_LIBRARIES(operator-utils PRIVATE logging) + TARGET_LINK_LIBRARIES(post-operation PRIVATE logging) + TARGET_LINK_LIBRARIES(subgraph PRIVATE allocator logging memory mutex operators operator-run) +- TARGET_LINK_LIBRARIES(XNNPACK PRIVATE allocator cache hardware-config indirection jit logging memory microkernel-utils microparams-init mutex normalization operators operator-run operator-utils packing post-operation microkernels-prod subgraph) ++ IF(CMAKE_SYSTEM_NAME STREQUAL "Emscripten") ++ # omit microkernels-prod as the list is manually created by ORT in cmake/external/xnnpack.cmake ++ 
TARGET_LINK_LIBRARIES(XNNPACK PRIVATE allocator cache hardware-config indirection jit logging memory microkernel-utils microparams-init mutex normalization operators operator-run operator-utils packing post-operation subgraph) ++ ELSE() ++ TARGET_LINK_LIBRARIES(XNNPACK PRIVATE allocator cache hardware-config indirection jit logging memory microkernel-utils microparams-init mutex normalization operators operator-run operator-utils packing post-operation microkernels-prod subgraph) ++ ENDIF() + SET_TARGET_PROPERTIES(XNNPACK PROPERTIES C_EXTENSIONS YES) ENDIF() - --# ---[ Configure clog --IF(NOT TARGET clog) -- IF(NOT XNNPACK_USE_SYSTEM_LIBS) -- SET(CLOG_BUILD_TESTS OFF CACHE BOOL "") -- SET(CLOG_RUNTIME_TYPE "${CPUINFO_RUNTIME_TYPE}" CACHE STRING "") -- ADD_SUBDIRECTORY( -- "${CLOG_SOURCE_DIR}/deps/clog" -- "${CMAKE_BINARY_DIR}/clog") -- # We build static version of clog but a dynamic library may indirectly depend on it -- SET_PROPERTY(TARGET clog PROPERTY POSITION_INDEPENDENT_CODE ON) -- ELSE() -- ADD_LIBRARY(clog STATIC IMPORTED) -- FIND_LIBRARY(CLOG_LIBRARY clog) -- IF(NOT CLOG_LIBRARY) -- MESSAGE(FATAL_ERROR "Cannot find clog") -- ENDIF() -- SET_PROPERTY(TARGET clog PROPERTY IMPORTED_LOCATION "${CLOG_LIBRARY}") -- ENDIF() --ENDIF() -- - # ---[ Configure cpuinfo - IF(NOT TARGET cpuinfo) - IF(NOT XNNPACK_USE_SYSTEM_LIBS) + IF(NOT MSVC) diff --git a/csharp/OnnxRuntime.CSharp.proj b/csharp/OnnxRuntime.CSharp.proj index 0288d752d8749..69bfd9896f1e4 100644 --- a/csharp/OnnxRuntime.CSharp.proj +++ b/csharp/OnnxRuntime.CSharp.proj @@ -17,9 +17,13 @@ CMake creates a target to this project x64 false - false + true + true None + + true + .. ..\tools\nuget\generate_nuspec_for_native_nuget.py @@ -30,13 +34,15 @@ CMake creates a target to this project ..\build\Linux $(OnnxRuntimeBuildDirectory)\packages $(OnnxRuntimeBuildDirectory)\$(Configuration) + python3 - + ..\build\Windows $(OnnxRuntimeBuildDirectory)\packages $(OnnxRuntimeBuildDirectory)\$(Configuration)\$(Configuration) + python @@ -86,28 +92,48 @@ CMake creates a target to this project - + + + Properties="NoBuild=true;Platform=AnyCPU;PackageVersion=$(PackageVersion);OrtPackageId=$(OrtPackageId);IncludeMobileTargets=$(IncludeMobileTargets)"/> - - + + + - - + + + - + + + + - diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj b/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj index 29ccf55f081d5..0c74a23204d4f 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj @@ -4,66 +4,53 @@ Microsoft.ML.OnnxRuntime - - PreNet6 - netstandard2.0;netcoreapp3.1;net6.0 + true + netstandard2.0 + - - - xamarinios10;monoandroid11.0 + + + false - - monoandroid11.0 - + + NOTE: We include in a build of the managed package when creating Microsoft.ML.OnnxRuntime.Gpu as both + the CPU and GPU packaging pipelines can publish Microsoft.ML.OnnxRuntime.Managed, and we need the targets + to be consistent in both. 
+ --> - net6.0;net6.0-android;net6.0-ios;net6.0-macos + '$(OrtPackageId)' == 'Microsoft.ML.OnnxRuntime.Gpu') AND + '$(IncludeMobileTargets)' == 'true' AND + Exists('$(MSBuildExtensionsPath)\Xamarin\Android') AND + Exists('$(MSBuildExtensionsPath)\Xamarin\iOS')"> + xamarinios10;monoandroid11.0 - - net6.0;net6.0-android + + monoandroid11.0 - - $(BaseTargets);$(XamarinTargets);$(XamarinTargetsForTraining) + + + $(MobileTargets);net6.0-android;net6.0-ios - - $(Net6Targets);$(Net6TargetsForTrainingPackage) + + $(MobileTargets);net6.0-android - - - $(BaseTargets);$(XamarinTargets);$(XamarinTargetsForTraining);$(Net6Targets);$(Net6TargetsForTrainingPackage) + + $(BaseTargets);$(MobileTargets) - AnyCPU;x86 default @@ -204,8 +191,9 @@ $(DefineConstants);$(OrtConstants) - + @@ -214,7 +202,6 @@ - --> + /// \param length_in_bytes Length to resize the string to. + /// \param buffer The resized buffer. + [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtGetResizedStringTensorElementBuffer( - IntPtr /* OrtValue */ value, - UIntPtr /* size_t */ index, - UIntPtr /* size_t */ length_in_bytes, - out IntPtr /* char** */ buffer - ); + IntPtr /* OrtValue */ value, + UIntPtr /* size_t */ index, + UIntPtr /* size_t */ length_in_bytes, + out IntPtr /* char** */ buffer); public static DOrtGetResizedStringTensorElementBuffer OrtGetResizedStringTensorElementBuffer; [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtGetStringTensorContent( - IntPtr /*(OrtValue*)*/ value, - byte[] /*(void*)*/ dst_buffer, - UIntPtr dst_buffer_len, - UIntPtr[] offsets, - UIntPtr offsets_len); + IntPtr /*(OrtValue*)*/ value, + byte[] /*(void*)*/ dst_buffer, + UIntPtr dst_buffer_len, + UIntPtr[] offsets, + UIntPtr offsets_len); public static DOrtGetStringTensorContent OrtGetStringTensorContent; [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtGetStringTensorDataLength(IntPtr /*(OrtValue*)*/ value, - out UIntPtr /*(size_t*)*/ len); + out UIntPtr /*(size_t*)*/ len); public static DOrtGetStringTensorDataLength OrtGetStringTensorDataLength; [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtGetStringTensorElementLength(IntPtr /*(OrtValue*)*/ value, - UIntPtr /*(size_t)*/ index, - out UIntPtr /*(size_t*)*/ len); + UIntPtr /*(size_t)*/ index, + out UIntPtr /*(size_t*)*/ len); public static DOrtGetStringTensorElementLength OrtGetStringTensorElementLength; [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtGetStringTensorElement(IntPtr /*(OrtValue*)*/ value, - UIntPtr /*(size_t)*/ bufferLength, - UIntPtr /*(size_t)*/ elementIndex, - byte[] buffer); + UIntPtr /*(size_t)*/ bufferLength, + UIntPtr /*(size_t)*/ elementIndex, + byte[] buffer); public static DOrtGetStringTensorElement OrtGetStringTensorElement; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/ - DOrtCastTypeInfoToTensorInfo(IntPtr /*(struct OrtTypeInfo*)*/ typeInfo, out IntPtr /*(const struct OrtTensorTypeAndShapeInfo**)*/ typeAndShapeInfo); + public delegate IntPtr /*(OrtStatus*)*/ DOrtCastTypeInfoToTensorInfo( + IntPtr /*(struct OrtTypeInfo*)*/ typeInfo, + out IntPtr /*(const struct OrtTensorTypeAndShapeInfo**)*/ typeAndShapeInfo); public static DOrtCastTypeInfoToTensorInfo OrtCastTypeInfoToTensorInfo; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/ 
DOrtGetTensorTypeAndShape(IntPtr /*(OrtValue*)*/ value, out IntPtr /*(struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo); + public delegate IntPtr /*(OrtStatus*)*/ DOrtGetTensorTypeAndShape( + IntPtr /*(OrtValue*)*/ value, + out IntPtr /*(struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo); public static DOrtGetTensorTypeAndShape OrtGetTensorTypeAndShape; @@ -1917,12 +1924,16 @@ out IntPtr /* char** */ buffer public static DOrtReleaseTensorTypeAndShapeInfo OrtReleaseTensorTypeAndShapeInfo; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/ DOrtGetTensorElementType(IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo, out IntPtr /*(TensorElementType*)*/ output); + public delegate IntPtr /*(OrtStatus*)*/ DOrtGetTensorElementType( + IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo, + out IntPtr /*(TensorElementType*)*/ output); public static DOrtGetTensorElementType OrtGetTensorElementType; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/ DOrtGetDimensionsCount(IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo, out UIntPtr output); + public delegate IntPtr /*(OrtStatus*)*/ DOrtGetDimensionsCount( + IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo, + out UIntPtr output); public static DOrtGetDimensionsCount OrtGetDimensionsCount; diff --git a/csharp/tools/linux_pack/LinuxPackNativeNuget.csproj b/csharp/tools/linux_pack/LinuxPackNativeNuget.csproj new file mode 100644 index 0000000000000..098078d2e3683 --- /dev/null +++ b/csharp/tools/linux_pack/LinuxPackNativeNuget.csproj @@ -0,0 +1,15 @@ + + + + netstandard2.0 + $(OnnxRuntimeBuildDirectory)/NativeNuget.nuspec + + diff --git a/dockerfiles/Dockerfile.source b/dockerfiles/Dockerfile.source index 110e484e77d21..5822a805c674e 100644 --- a/dockerfiles/Dockerfile.source +++ b/dockerfiles/Dockerfile.source @@ -8,13 +8,14 @@ FROM mcr.microsoft.com/cbl-mariner/base/python:3 MAINTAINER Changming Sun "chasun@microsoft.com" ADD . 
/code -RUN tdnf install -y tar ca-certificates build-essential python3-numpy cmake python3-setuptools python3-wheel python3-pip curl python3-devel +RUN tdnf install -y tar ca-certificates build-essential cmake curl python3-devel python3-setuptools python3-wheel python3-pip python3-numpy python3-flatbuffers python3-packaging python3-protobuf +# The latest cmake version in Mariner2 is 3.21, but we need 3.26+ RUN /code/dockerfiles/scripts/install_cmake.sh # Prepare onnxruntime repository & build onnxruntime -RUN cd /code && python3 -m pip install -r tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/requirements.txt && /bin/bash ./build.sh --allow_running_as_root --skip_submodule_sync --config Release --build_wheel --update --build --parallel --cmake_extra_defines ONNXRUNTIME_VERSION=$(cat ./VERSION_NUMBER) +RUN cd /code && /bin/bash ./build.sh --allow_running_as_root --skip_submodule_sync --config Release --build_wheel --update --build --parallel --cmake_extra_defines ONNXRUNTIME_VERSION=$(cat ./VERSION_NUMBER) FROM mcr.microsoft.com/cbl-mariner/base/python:3 COPY --from=0 /code/build/Linux/Release/dist /root COPY --from=0 /code/dockerfiles/LICENSE-IMAGE.txt /code/LICENSE-IMAGE.txt -RUN tdnf install -y ca-certificates python3-setuptools python3-wheel python3-pip && python3 -m pip install /root/*.whl && rm -rf /root/*.whl +RUN tdnf install -y ca-certificates python3-setuptools python3-wheel python3-pip python3-numpy python3-flatbuffers python3-packaging python3-protobuf python3-mpmath python3-sympy && python3 -m pip install coloredlogs humanfriendly && python3 -m pip install --no-index --find-links /root onnxruntime && rm -rf /root/*.whl diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 2a16bdbf7b55d..646465ef8b56f 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -27,6 +27,7 @@ Do not modify directly.* * com.microsoft.DequantizeWithOrder * com.microsoft.DynamicQuantizeLSTM * com.microsoft.DynamicQuantizeMatMul + * com.microsoft.DynamicTimeWarping * com.microsoft.EPContext * com.microsoft.EmbedLayerNormalization * com.microsoft.ExpandDims @@ -39,6 +40,7 @@ Do not modify directly.* * com.microsoft.GatherND * com.microsoft.Gelu * com.microsoft.GemmFastGelu + * com.microsoft.GemmFloat8 * com.microsoft.GreedySearch * com.microsoft.GridSample * com.microsoft.GroupNorm @@ -46,9 +48,11 @@ Do not modify directly.* * com.microsoft.Inverse * com.microsoft.Irfft * com.microsoft.LongformerAttention + * com.microsoft.MatMulBnb4 * com.microsoft.MatMulFpQ4 * com.microsoft.MatMulInteger16 * com.microsoft.MatMulIntegerToFloat + * com.microsoft.MatMulNBits * com.microsoft.MaxpoolWithMask * com.microsoft.MulInteger * com.microsoft.MultiHeadAttention @@ -88,8 +92,10 @@ Do not modify directly.* * com.microsoft.RemovePadding * com.microsoft.RestorePadding * com.microsoft.Rfft + * com.microsoft.RotaryEmbedding * com.microsoft.SampleOp * com.microsoft.Sampling + * com.microsoft.SkipGroupNorm * com.microsoft.SkipLayerNormalization * com.microsoft.SkipSimplifiedLayerNormalization * com.microsoft.Snpe @@ -98,7 +104,9 @@ Do not modify directly.* * com.microsoft.TorchEmbedding * com.microsoft.TransposeMatMul * com.microsoft.Trilu + * com.microsoft.UnfoldTensor * com.microsoft.Unique + * com.microsoft.WhisperBeamSearch * com.microsoft.WordConvEmbedding * experimental com.microsoft.IsAllFinite * experimental com.microsoft.QEmbedLayerNormalization @@ -1132,6 +1140,8 @@ This version of the operator has been available since version 1 of the 'com.micr
The value to be filled in the attention mask. Default value is -10000.0f
num_heads : int (required)
Number of attention heads
+
output_qk : int
+
Whether to output the cross attention MatMul(Q, K)
past_present_share_buffer : int
Corresponding past and present are same tensor, its size is (batch_size, num_heads, max_sequence_length, head_size)
scale : float
@@ -1165,7 +1175,7 @@ This version of the operator has been available since version 1 of the 'com.micr
Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection
-#### Outputs (1 - 3) +#### Outputs (1 - 4)
output : T
@@ -1174,11 +1184,15 @@ This version of the operator has been available since version 1 of the 'com.micr
present state for key with shape (batch_size, num_heads, total_sequence_length, head_size). If past_present_share_buffer is set, its shape is (batch_size, num_heads, max_sequence_length, head_size), while effective_seq_length = (past_sequence_length + kv_sequence_length).
present_value (optional) : T
present state for value with shape (batch_size, num_heads, total_sequence_length, head_size). If past_present_share_buffer is set, its shape is (batch_size, num_heads, max_sequence_length, head_size), while effective_seq_length = (past_sequence_length + kv_sequence_length).
+
qk (optional) : V
+
normalized Q * K, of shape (batch_size, num_heads, 1, head_size).
#### Type Constraints
+
V : tensor(float)
+
Constrain qk output types to float32 tensors.
T : tensor(float), tensor(float16)
Constrain input and output types to float tensors.
M : tensor(int32)
@@ -1522,6 +1536,38 @@ This version of the operator has been available since version 1 of the 'com.micr
+### **com.microsoft.DynamicTimeWarping** + + Input is a cost matrix where each value input[r][c] is the cost to pass through point (r, c). From the current point (r, c), the points (r+1, c), (r+1, c+1) or (r, c+1) can be reached in the next move. Given such a cost matrix, return the dynamic time warping path of shape [2, x], where the path made by all points (output[0][t], output[1][t]) has the lowest cost among all paths from (0, 0) to (M-1, N-1). + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Inputs + +
+
input : F
+
Input cost tensor; it must be a 2D tensor of shape M x N, or 1 x M x N
+
+ +#### Outputs + +
+
output : I
+
Output tensor. Shape is [2, x], where max(M, N) <= x < M + N
+
+ +#### Type Constraints + +
+
F : tensor(float)
+
Constrain to float tensors.
+
I : tensor(int32)
+
Constrain to integer types.
+
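For illustration, the following is a minimal NumPy sketch of the dynamic-programming path search that DynamicTimeWarping describes. It is a reference of the semantics only, not the ONNX Runtime kernel, and the tie-breaking order among the three moves is an assumption.

```python
import numpy as np

def dynamic_time_warping(cost: np.ndarray) -> np.ndarray:
    """Return the lowest-cost path from (0, 0) to (M-1, N-1) as an int32 array of shape [2, x]."""
    cost = cost.reshape(cost.shape[-2], cost.shape[-1])  # accept M x N or 1 x M x N
    m, n = cost.shape
    acc = np.full((m + 1, n + 1), np.inf)                # accumulated cost, padded with inf
    acc[0, 0] = 0.0
    for r in range(1, m + 1):
        for c in range(1, n + 1):
            # allowed predecessors: (r-1, c), (r, c-1), (r-1, c-1) in the padded grid
            acc[r, c] = cost[r - 1, c - 1] + min(acc[r - 1, c], acc[r, c - 1], acc[r - 1, c - 1])
    # backtrack from (M-1, N-1) to (0, 0)
    r, c = m, n
    path = []
    while r > 1 or c > 1:
        path.append((r - 1, c - 1))
        step = np.argmin([acc[r - 1, c - 1], acc[r - 1, c], acc[r, c - 1]])
        if step == 0:
            r, c = r - 1, c - 1
        elif step == 1:
            r -= 1
        else:
            c -= 1
    path.append((0, 0))
    return np.array(path[::-1], dtype=np.int32).T        # shape [2, x]

print(dynamic_time_warping(np.array([[0., 1., 2.], [1., 0., 1.], [2., 1., 0.]])))
```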
+ + ### **com.microsoft.EPContext** Onnx node container for EP context. @@ -2093,6 +2139,71 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.GemmFloat8** + + Generic Gemm for float and float 8. + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
activation : string
+
Activation function, RELU or GELU or NONE (default).
+
alpha : float
+
Scalar multiplier for the product of input tensors A * B.
+
beta : float
+
Scalar multiplier for the input bias C.
+
dtype : int
+
Output Type. Same definition as attribute 'to' for operator Cast.
+
transA : int
+
Whether A should be transposed. Float 8 only supports transA=0.
+
transB : int
+
Whether B should be transposed. Float 8 only supports transB=1.
+
+ +#### Inputs (2 - 6) + +
+
A : TA
+
Input tensor A. The shape of A should be (M, K) if transA is 0, or (K, M) if transA is non-zero.
+
B : TB
+
Input tensor B. The shape of B should be (K, N) if transB is 0, or (N, K) if transB is non-zero.
+
C (optional) : TC
+
Input tensor C.
+
scaleA (optional) : TS
+
Scale of tensor A if A is a float 8 tensor
+
scaleB (optional) : TS
+
Scale of tensor B if B is a float 8 tensor
+
scaleY (optional) : TS
+
Scale of the output tensor if A or B is float 8.
+
+ +#### Outputs + +
+
Y : TR
+
Output tensor of shape (M, N).
+
+ +#### Type Constraints + +
+
TA : tensor(float8e4m3fn), tensor(float8e5m2), tensor(float16), tensor(bfloat16), tensor(float)
+
Constrain type to input A.
+
TB : tensor(float8e4m3fn), tensor(float8e5m2), tensor(float16), tensor(bfloat16), tensor(float)
+
Constrain type to input B.
+
TC : tensor(float16), tensor(bfloat16), tensor(float)
+
Constrain type to input C.
+
TR : tensor(float8e4m3fn), tensor(float8e5m2), tensor(float16), tensor(bfloat16), tensor(float)
+
Constrain type to result type.
+
TS : tensor(float)
+
Constrain type for all input scales (scaleA, scaleB, scaleY).
+
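As a rough illustration of the math only (not the CUDA kernel, and ignoring the float 8 encodings and scale handling entirely), a plain-float NumPy sketch of the scaled GEMM with optional activation might look like this. The tanh-based GELU approximation is an assumption.

```python
import numpy as np

def gemm_float8_reference(A, B, C=None, *, alpha=1.0, beta=0.0,
                          trans_a=0, trans_b=0, activation="NONE"):
    """Plain-float reference of the GEMM math: Y = act(alpha * op(A) @ op(B) + beta * C)."""
    a = A.T if trans_a else A
    b = B.T if trans_b else B
    y = alpha * (a @ b)
    if C is not None:
        y = y + beta * C
    if activation == "RELU":
        y = np.maximum(y, 0.0)
    elif activation == "GELU":
        # tanh approximation of GELU
        y = 0.5 * y * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (y + 0.044715 * y ** 3)))
    return y

A = np.random.rand(4, 8).astype(np.float32)
B = np.random.rand(3, 8).astype(np.float32)   # stored as (N, K), used with trans_b=1
print(gemm_float8_reference(A, B, trans_b=1, activation="RELU").shape)  # (4, 3)
```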
+ + ### **com.microsoft.GreedySearch** Greedy Search for text generation. @@ -2232,7 +2343,7 @@ This version of the operator has been available since version 1 of the 'com.micr
activation : int (required)
-
Activation after group normalization: 0 for None, 1 for Swish
+
Activation after group normalization: 0 for None, 1 for SiLU
channels_last : int
1 if the input and output are in the NHWC layout, 0 if it is in the NCHW layout. Defaults to 1.
epsilon : float
@@ -2311,14 +2422,14 @@ This version of the operator has been available since version 1 of the 'com.micr
When buffered past_key and past_value are used (present_key uses the same tensor as past_key), it is required to specify past_sequence_length (could be 0). Otherwise, past_sequence_length is inferred from past_key.
-#### Outputs (1 - 3) +#### Outputs
output : T
3D output tensor with shape (batch_size, sequence_length, hidden_size)
-
present_key (optional) : T
+
present_key : T
present state key with support for format BSNH or BNSH. When past_key uses the same tensor as present_key (k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length + kv_sequence_length.
-
present_value (optional) : T
+
present_value : T
present state value with support for format BSNH or BNSH. When past_value uses the same tensor as present_value (k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length + kv_sequence_length.
@@ -2461,6 +2572,89 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.MatMulBnb4** + + MatMulBnb4 is a MatMul with weight quantized with 4 bits using either FP4 or NF4 data type (https://arxiv.org/pdf/2305.14314.pdf). It does Matrix Multiplication like MatMul (https://github.com/onnx/onnx/blob/main/docs/Operators.md#matmul) with differences: + 1. Input B is a 2D constant Matrix. Its input feature count and output feature count are specified by attribute 'K' and 'N'. + 2. Input B is quantized with 4 bits with quantization data type specified by attribute 'quant_type'. It is transposed, flattened and quantized blockwisely with block size specified by attribute 'block_size'. + And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,.. + 3. Input B's quantization constants or scales are specified by input 'absmax'. + + Input B is stored as uint8_t with shape: [(N * K + 1) / 2]. + Input absmax is stored in same type as original type of B(float32, float16) with shape like: [(N * K + block_size - 1) / block_size]. + + + 1. (Default value) transB=True (Majorly used for forward pass) + Shape of A: [D0, D1, ..., Dn, K] + Shape of Dequanted B: [N, K], this is aligned with how PyTorch defined the linear weight, .e.g [out_features, in_features]. + + The computation math: + dequant_B = dequant(B, absmax, quant_type, block_size) + transposed_dequant_B = dequant_B^T + output = A @ transposed_dequant_B + + Shape of output: [D0, D1, ..., Dn, N] + + 2. transB=False (Majorly used for backward pass) + Shape of A: [D0, D1, ..., Dn, N] + Shape of Dequanted B: [N, K], this is aligned with how PyTorch defined the linear weight, .e.g [out_features, in_features]. + + The computation math: + dequant_B = dequant(B, absmax, quant_type, block_size) + output = A @ dequant_B + + Shape of output: [D0, D1, ..., Dn, K] + + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
K : int (required)
+
size of each input feature
+
N : int (required)
+
size of each output feature
+
block_size : int (required)
+
Group size used for weight quantization. It needs to be a power of 2 and not smaller than 16.
+
quant_type : int (required)
+
quantization data type. 0 for FP4, 1 for NF4.
+
training_mode : int
+
Indicates whether the op runs in training mode. Default is False.
+
transB : int
+
Whether B should be transposed on the last two dimensions before doing multiplication. Defaults to 1.
+
+ +#### Inputs + +
+
A : T1
+
The input tensor, not quantized
+
B : T2
+
1-dimensional quantized data for weight
+
absmax : T1
+
quantization constants
+
+ +#### Outputs + +
+
Y : T1
+
The output tensor; it has the same rank as the input.
+
+ +#### Type Constraints + +
+
T1 : tensor(float), tensor(float16)
+
Constrain input and output types to float/half_float tensors.
+
T2 : tensor(uint8)
+
Constrain quantized weight types to uint8.
+
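For intuition, here is a minimal NumPy sketch of the blockwise 4-bit dequantization followed by the matmul that MatMulBnb4 describes. The 16-entry codebook below is a placeholder (the real FP4/NF4 tables from the bitsandbytes paper are not reproduced), and the nibble order and flattening layout are assumptions, not the kernel's exact format.

```python
import numpy as np

def dequantize_bnb4(packed, absmax, codebook, N, K, block_size):
    """Each 4-bit code indexes a 16-entry codebook and is scaled by its block's absmax."""
    codes = np.empty(packed.size * 2, dtype=np.uint8)
    codes[0::2] = packed >> 4          # assumed: first element in the high nibble
    codes[1::2] = packed & 0x0F
    codes = codes[: N * K]             # drop padding if N*K is odd
    values = codebook[codes]
    scales = np.repeat(absmax, block_size)[: N * K]
    return (values * scales).reshape(N, K)

def matmul_bnb4(A, packed_B, absmax, codebook, N, K, block_size, trans_b=True):
    deq_B = dequantize_bnb4(packed_B, absmax, codebook, N, K, block_size)
    return A @ (deq_B.T if trans_b else deq_B)

N, K, block_size = 4, 32, 16
codebook = np.linspace(-1.0, 1.0, 16, dtype=np.float32)   # placeholder, NOT the real FP4/NF4 table
packed_B = np.random.randint(0, 256, size=(N * K + 1) // 2, dtype=np.uint8)
absmax = np.random.rand((N * K + block_size - 1) // block_size).astype(np.float32)
A = np.random.rand(2, K).astype(np.float32)
print(matmul_bnb4(A, packed_B, absmax, codebook, N, K, block_size).shape)  # (2, 4)
```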
+ + ### **com.microsoft.MatMulFpQ4** Matrix product with right hand matrix being pre-packed and quantized int4 data blob. @@ -2593,6 +2787,78 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.MatMulNBits** + + MatMulNBits is a MatMul with weight quantized with N bits(e.g., 2, 3, 4, 5, 6, 7).It does Matrix Multiplication like MatMul (https://github.com/onnx/onnx/blob/main/docs/Operators.md#matmul) with differences: + 1. Input B is a 2D constant Matrix. Its input feature count and output feature count are specified by attribute 'K' and 'N'. + 2. Input B is quantized with x bits which is specified by attribute 'bits'. It is quantized blockwisely along dimension 0 (e.g. column) with block size specified by attribute block_size. + And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,.. + 3. Input B's scale and zero point are specified by input scales and zero_points. + + Input B is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which: + - n_blocks_per_col = (K + block_size - 1) / block_size + - blob_size = block_size / 8 * bits + + For a block blob. It is stored in format: + struct Blob { + uint8 one_bits[(bits & 0x1) * 1 * block_size / 8]; // highest 1 bit for 3, 5, 7 bits quantization + uint8 two_bits[(bits & 0x2) * 2 * block_size / 8]; // high 2 bits for 2, 6, 7 bits quantization + uint8 four_bits[(bits & 0x4) * 4 * block_size / 8]; // low 4 bits for 4, 5, 6 bits quantization + } + + Input scales is stored in same type as original type of B(float32, float16) with shape like: [N * n_blocks_per_col] + Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored as one unit8_t. If bits > 4, one zero point is stored with one unit8_t. Thus, its shape is: + - [(N * n_blocks_per_col + 1) / 2] if bits <=4 + - [N * n_blocks_per_col] if bits > 4 + + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
K : int (required)
+
size of each input feature
+
N : int (required)
+
size of each output feature
+
bits : int (required)
+
number of bits used for weight quantization (default 4)
+
block_size : int (required)
+
Group size used for weight quantization (default 128). It needs to be a power of 2 and not smaller than 16.
+
+ +#### Inputs (3 - 4) + +
+
A : T1
+
The input tensor, not quantized
+
B : T2
+
1-dimensional data blob
+
scales : T1
+
quantization scale
+
zero_points (optional) : T2
+
quantization zero points
+
+ +#### Outputs + +
+
Y : T1
+
The output tensor; it has the same rank as the input.
+
+ +#### Type Constraints + +
+
T1 : tensor(float), tensor(float16)
+
Constrain input and output types to float/half_float tensors.
+
T2 : tensor(uint8)
+
Constrain quantized weight types to uint8.
+
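To make the packing description above more concrete, here is a NumPy sketch of the dequantization for the common bits=4 case only. Low-nibble-first unpacking, the zero-point packing order, and the default zero point of 8 when zero_points is absent are assumptions; this is a reference of the semantics, not the kernel.

```python
import numpy as np

def dequantize_matmul_nbits_4bit(B, scales, zero_points, N, K, block_size):
    """4-bit blockwise dequantization sketch. Returns a [K, N] matrix so that Y = A @ result."""
    n_blocks_per_col = (K + block_size - 1) // block_size
    blob_size = block_size // 2
    B = B.reshape(N, n_blocks_per_col, blob_size)
    # unpack two 4-bit values from each byte (assumed: low nibble first)
    q = np.empty((N, n_blocks_per_col, block_size), dtype=np.float32)
    q[..., 0::2] = (B & 0x0F).astype(np.float32)
    q[..., 1::2] = (B >> 4).astype(np.float32)

    scales = scales.reshape(N, n_blocks_per_col, 1)
    if zero_points is None:
        zp = np.full((N, n_blocks_per_col, 1), 8.0, dtype=np.float32)   # assumed default
    else:
        zp_unpacked = np.empty(zero_points.size * 2, dtype=np.float32)
        zp_unpacked[0::2] = (zero_points & 0x0F)
        zp_unpacked[1::2] = (zero_points >> 4)
        zp = zp_unpacked[: N * n_blocks_per_col].reshape(N, n_blocks_per_col, 1)

    deq = (q - zp) * scales                                # [N, n_blocks_per_col, block_size]
    deq = deq.reshape(N, n_blocks_per_col * block_size)[:, :K]
    return deq.T                                           # [K, N]

N, K, block_size = 8, 64, 32
n_blocks = (K + block_size - 1) // block_size
B = np.random.randint(0, 256, size=N * n_blocks * (block_size // 2), dtype=np.uint8)
scales = np.random.rand(N * n_blocks).astype(np.float32)
A = np.random.rand(2, K).astype(np.float32)
print((A @ dequantize_matmul_nbits_4bit(B, scales, None, N, K, block_size)).shape)  # (2, 8)
```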
+ + ### **com.microsoft.MaxpoolWithMask** For internal use. @@ -2720,7 +2986,7 @@ This version of the operator has been available since version 1 of the 'com.micr
bias (optional) : T
Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection
key_padding_mask (optional) : M
-
Key padding mask with shape (batch_size) or (3 * batch_size + 2) or (batch_size, kv_sequence_length)
+
Key padding mask with shape (batch_size), (3 * batch_size + 2), (batch_size, kv_sequence_length), (batch_size, total_sequence_length), or (batch_size, sequence_length, total_sequence_length)
relative_position_bias (optional) : T
relative position bias: addition to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length) or (1, num_heads, sequence_length, total_sequence_length)
past_key (optional) : T
@@ -4682,6 +4948,54 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.RotaryEmbedding** + + RotaryEmbedding is the implementation of rotary positional embeddings (RoPE). The positions are represented as rotation matrices + that are multiplied to query and key before the inner product of query and key is taken. + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
interleaved : int
+
Rotate using interleaved pattern. Default value is 0 (False).
+
scale : float
+
Custom scale will be used if specified. Default value is 1.0.
+
+ +#### Inputs + +
+
input : T
+
3D tensor with shape (batch_size, sequence_length, hidden_size)
+
position_ids : M
+
1D tensor with shape (1) or 2D tensor with shape (batch_size, sequence_length)
+
cos_cache : T
+
2D tensor with shape (max_sequence_length, head_size / 2).
+
sin_cache : T
+
2D tensor with shape (max_sequence_length, head_size / 2).
+
+ +#### Outputs + +
+
output : T
+
3D tensor with shape (batch_size, sequence_length, hidden_size)
+
+ +#### Type Constraints + +
+
T : tensor(float), tensor(float16)
+
Constrain input and output types to float tensors.
+
M : tensor(int64)
+
Constrain input and output types to integer tensors
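For illustration, a minimal NumPy sketch of the non-interleaved (interleaved=0) rotation using the cos/sin caches follows. The rotate-half pairing convention and the cache construction shown are assumptions for the sketch, not a statement of the kernel's exact layout.

```python
import numpy as np

def rotary_embedding(x, position_ids, cos_cache, sin_cache, num_heads):
    """Apply RoPE to a (batch, seq, hidden) tensor in the rotate-half form."""
    batch, seq, hidden = x.shape
    head = hidden // num_heads
    x = x.reshape(batch, seq, num_heads, head)
    cos = cos_cache[position_ids][:, :, None, :]   # (batch, seq, 1, head/2)
    sin = sin_cache[position_ids][:, :, None, :]
    x1, x2 = x[..., : head // 2], x[..., head // 2:]
    out = np.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
    return out.reshape(batch, seq, hidden)

batch, seq, num_heads, head = 2, 5, 4, 8
x = np.random.rand(batch, seq, num_heads * head).astype(np.float32)
pos = np.tile(np.arange(seq), (batch, 1))                     # (batch, seq) position_ids
inv_freq = 1.0 / (10000 ** (np.arange(0, head, 2) / head))
angles = np.outer(np.arange(16), inv_freq)                    # max_sequence_length = 16
print(rotary_embedding(x, pos, np.cos(angles), np.sin(angles), num_heads).shape)  # (2, 5, 32)
```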
+
+ + ### **com.microsoft.SampleOp** Sample echo operator. @@ -4797,6 +5111,72 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.SkipGroupNorm** + + This operator element-wise adds x, skip and bias, then apply group normalization and optional activation. + + This operator transforms input according to + s = x + skip + bias + y = gamma * (s - mean) / sqrt(variance + epsilon) + beta + + The input channels are separated into num_groups groups, each containing num_channels / num_groups channels. + The num_channels must be divisible by num_groups. + The mean and standard-deviation of s are calculated separately over the each group. + The weight and bias are per-channel affine transform parameter vectors of size num_channels. + + The activation attribute can be used to enable activation after group normalization. + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
activation : int (required)
+
Activation after group normalization: 0 for None, 1 for SiLU
+
channels_last : int
+
1 if the input and output are in the NHWC layout, 0 if it is in the NCHW layout. Defaults to 1.
+
epsilon : float
+
The epsilon value to use to avoid division by zero
+
groups : int (required)
+
The number of groups of channels. It should be a divisor of the number of channels C
+
+ +#### Inputs (4 - 5) + +
+
X : T
+
Input data tensor. Dimensions are (N x H x W x C) when channels_last is 1 or (N x C x H x W) otherwise, where N is the batch size, C is the number of channels, and H and W are the height and width of the data
+
gamma : M
+
1D gamma tensor for normalization with shape (C), where C is number of channels
+
beta : M
+
1D beta tensor for normalization with shape (C), where C is number of channels
+
skip : T
+
4D or 2D skip tensor. The shape can be (N x H x W x C) or (N x 1 x 1 x C) or (N x C)
+
bias (optional) : T
+
1D bias tensor. Dimensions are (C), where C is number of channels
+
+ +#### Outputs (1 - 2) + +
+
Y : T
+
The output tensor of the same shape as X
+
S (optional) : T
+
The element-wise sum of input x, skip and bias tensors. It has the same shape as X
+
+ +#### Type Constraints + +
+
T : tensor(float16), tensor(float)
+
Constrain input X, skip, bias and output Y, S types to float tensors.
+
M : tensor(float16), tensor(float)
+
Constrain gamma and beta to float tensors.
+
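A minimal NumPy sketch of the math described above (channels-last layout, add + group normalization + optional SiLU) is shown below. It is a reference of the formula only, not the fused GPU kernel.

```python
import numpy as np

def skip_group_norm(x, gamma, beta, skip, bias=None, *, groups, epsilon=1e-5, activation=0):
    """Channels-last (N, H, W, C) sketch of s = x + skip + bias, then group norm and optional SiLU."""
    s = x + skip
    if bias is not None:
        s = s + bias                                   # bias has shape (C,)
    n, h, w, c = s.shape
    g = s.reshape(n, h, w, groups, c // groups)
    mean = g.mean(axis=(1, 2, 4), keepdims=True)
    var = g.var(axis=(1, 2, 4), keepdims=True)
    y = ((g - mean) / np.sqrt(var + epsilon)).reshape(n, h, w, c)
    y = y * gamma + beta
    if activation == 1:                                # SiLU
        y = y * (1.0 / (1.0 + np.exp(-y)))
    return y, s                                        # Y and the optional sum output S

x = np.random.rand(2, 4, 4, 8).astype(np.float32)
y, s = skip_group_norm(x, np.ones(8), np.zeros(8), skip=x, groups=4, activation=1)
print(y.shape, s.shape)
```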
+ + ### **com.microsoft.SkipLayerNormalization** Skip and Layer Normalization Fusion @@ -5192,6 +5572,47 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.UnfoldTensor** + + Returns a tensor which contains all slices of size size from input tensor in the dimension dim. Step between two slices is given by step. If sizedim is the size of dimension dim for input tensor, the size of dimension dim in the returned tensor will be (sizedim - size) / step + 1. An additional dimension of size size is appended in the returned tensor. + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
dim : int
+
specify the dimension to unfold
+
size : int (required)
+
specify the size
+
step : int
+
specify the step.
+
+ +#### Inputs + +
+
input : T
+
input tensor
+
+ +#### Outputs + +
+
output : T
+
Output tensor.
+
+ +#### Type Constraints + +
+
T : tensor(uint8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(int8), tensor(int16), tensor(int32), tensor(int64), tensor(bfloat16), tensor(float16), tensor(float), tensor(double), tensor(string), tensor(bool), tensor(complex64), tensor(complex128)
+
Allow inputs and outputs to be any kind of tensor.
+
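For intuition, a small NumPy sketch of the sliding-window semantics described above (mirroring torch.Tensor.unfold) follows; it is an illustration only, not the ONNX Runtime implementation.

```python
import numpy as np

def unfold_tensor(x, dim, size, step=1):
    """Slices of length `size` along `dim`, taken every `step` elements;
    the slice index becomes a new trailing dimension."""
    x = np.moveaxis(x, dim, -1)
    length = x.shape[-1]
    n = (length - size) // step + 1
    windows = np.stack([x[..., i * step: i * step + size] for i in range(n)], axis=-2)
    return np.moveaxis(windows, -2, dim)

x = np.arange(10)
print(unfold_tensor(x, dim=0, size=4, step=2))
# [[0 1 2 3]
#  [2 3 4 5]
#  [4 5 6 7]
#  [6 7 8 9]]
```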
+ + ### **com.microsoft.Unique** Finds all the unique values (deduped list) present in the given input tensor. @@ -5238,6 +5659,107 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.WhisperBeamSearch** + + Beam Search for the Whisper model, especially with cross_qk features etc. + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
decoder : graph (required)
+
Decoder subgraph to execute in a loop.
+
decoder_output_cross_qk : int
+
If nonzero, the decoder subgraph contains the Q*K output from cross attentions. Default 0.
+
decoder_start_token_id : int
+
The id of the token that indicates decoding starts.
+
early_stopping : int
+
Whether to stop the search early.
+
encoder : graph
+
The subgraph for initialization of encoder and decoder. It will be called once before decoder subgraph.
+
eos_token_id : int (required)
+
The id of the end-of-sequence token
+
init_decoder : graph
+
The subgraph for the first decoding run. It will be called once before the `decoder` subgraph. This is relevant only for the GPT2 model. If this attribute is missing, the `decoder` subgraph will be used for all decoding runs.
+
model_type : int
+
Must be 2 for Whisper.
+
no_repeat_ngram_size : int
+
no repeat ngrams size
+
no_speech_token : int
+
The token in the Whisper model that marks the whole sequence as empty (no speech). With this token, Whisper can output no_speech_prob afterwards. Default -1.
+
pad_token_id : int (required)
+
The id of the padding token
+
vocab_size : int
+
Size of the vocabulary. If not provided, it will be inferred from the decoder subgraph's output shape.
+
+ +#### Inputs (5 - 14) + +
+
input_ids : F
+
The sequence used as a prompt for the generation in the encoder subgraph. Shape is (batch_size, sequence_length)
+
max_length : I
+
The maximum length of the sequence to be generated. Shape is (1)
+
min_length (optional) : I
+
The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)
+
num_beams : I
+
Number of beams for beam search. 1 means no beam search. Shape is (1)
+
num_return_sequences : I
+
The number of returned sequences in the batch. Shape is (1)
+
length_penalty (optional) : T
+
Exponential penalty to the length. Default value 1.0 means no penalty. Values > 1.0 encourage longer sequences, while values < 1.0 produce shorter sequences. Shape is (1,)
+
repetition_penalty (optional) : T
+
The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)
+
vocab_mask (optional) : M
+
Mask of vocabulary. Words masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)
+
prefix_vocab_mask (optional) : M
+
Mask of vocabulary for the first step. Words masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)
+
attention_mask (optional) : I
+
Custom attention mask. Shape is (batch_size, sequence_length)
+
decoder_input_ids (optional) : I
+
The forced input id sequence for the decoder subgraph. Shape is (batch_size, initial_sequence_length)
+
logits_processor (optional) : I
+
Specific logits processor for different types of beamsearch models. Default value 0 means no specific logit processor. Accepts value >= 0. Shape is (1)
+
cross_qk_layer_head (optional) : I
+
Only keep this list of (layer, head) pairs of QK in the final cross_qk output when use_cross_qk is set; by default all are collected. Its shape is (number of (layer, head) to keep, 2), i.e., [[layer_id1, head_id1], [layer_id2, head_id2], ...]
+
extra_decoding_ids (optional) : I
+
The part of decoder_input_ids for which cross QK is needed, of shape (batch_size, extra_decoding_ids_len). In that case, these ids should be removed from the tail of decoder_input_ids and provided here instead. Ids < 0 (for a multiple batch) are treated as the end of extra_decoding_ids for the corresponding batch.
+
+ +#### Outputs (1 - 5) + +
+
sequences : I
+
Word IDs of generated sequences. Shape is (batch_size, num_return_sequences, max_sequence_length)
+
sequences_scores (optional) : T
+
Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)
+
scores (optional) : T
+
Processed beam scores for each vocabulary token at each generation step. Beam scores consist of the log softmax score for each vocabulary token plus the sum of log softmax scores of previously generated tokens in this beam. Shape is (max_length - sequence_length, batch_size, num_beams, vocab_size)
+
cross_qk (optional) : V
+
Output the accumulated stacked Q*K from cross attentions. Let H = number of cross-attention heads, F = the frames or kv-seq-len of the cross-attention input, T = real decoded token length, L = number of layers, B = batch size, R = num_return_sequences. The returned tensor then has shape [B, R, L*H, T, F]. If cross_qk_layer_head is given, the shape is [B, R, cross_qk_layer_head.shape[0], T, F]
+
non_speech_probs (optional) : T
+
For the Whisper model, output the probabilities for the no_speech_token from the logits after encoder and context decoding. Currently the last token's logits are used; in the future, extra graph logic may be added to the encoder/context-decoder subgraph. The probability is saved before the logits may be updated by extra_decoding_ids. The shape of non_speech_probs is [B]
+
+ +#### Type Constraints + +
+
T : tensor(float), tensor(float16)
+
Constrain to float tensors.
+
F : tensor(float), tensor(int32), tensor(float16)
+
Constrain input type to float or int tensors.
+
I : tensor(int32)
+
Constrain to integer types
+
M : tensor(int32)
+
Constrain mask to integer types
+
V : tensor(float)
+
Constrain cross_qk to float32 tensors.
+
+ + ### **com.microsoft.WordConvEmbedding** The WordConvEmbedding takes in a batch of sequence words and embed each word to a vector. diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index ce9d8aabfede3..e60da111afd36 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -25,6 +25,7 @@ Do not modify directly.* |||13|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |||[7, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |Affine|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)| +|AffineGrid|*in* theta:**T1**
*in* size:**T2**
*out* grid:**T1**|20+|**T1** = tensor(double), tensor(float)
**T2** = tensor(int64)| |And|*in* A:**T**
*in* B:**T**
*out* C:**T1**|7+|**T** = tensor(bool)
**T1** = tensor(bool)| |ArgMax|*in* data:**T**
*out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int8), tensor(uint8)| |||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int8), tensor(uint8)| @@ -137,7 +138,8 @@ Do not modify directly.* |||[7, 8]|**T** = tensor(double), tensor(float)
**T1** = tensor(bool)| |GreaterOrEqual|*in* A:**T**
*in* B:**T**
*out* C:**T1**|16+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)| |||[12, 15]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)| -|GridSample|*in* X:**T1**
*in* grid:**T2**
*out* Y:**T1**|16+|**T1** = tensor(float)
**T2** = tensor(float)| +|GridSample|*in* X:**T1**
*in* grid:**T2**
*out* Y:**T1**|20+|**T1** = tensor(double), tensor(float)
**T2** = tensor(double), tensor(float)| +|||[16, 19]|**T1** = tensor(float)
**T2** = tensor(float)| |HammingWindow|*in* size:**T1**
*out* output:**T2**|17+|**T1** = tensor(int32), tensor(int64)
**T2** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |HannWindow|*in* size:**T1**
*out* output:**T2**|17+|**T1** = tensor(int32), tensor(int64)
**T2** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |HardSigmoid|*in* X:**T**
*out* Y:**T**|6+|**T** = tensor(float)| @@ -156,8 +158,10 @@ Do not modify directly.* |||[1, 10]|**B** = tensor(bool)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |ImageScaler|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(float)| |InstanceNormalization|*in* input:**T**
*in* scale:**T**
*in* B:**T**
*out* output:**T**|6+|**T** = tensor(float)| -|IsInf|*in* X:**T1**
*out* Y:**T2**|10+|**T1** = tensor(double), tensor(float)
**T2** = tensor(bool)| -|IsNaN|*in* X:**T1**
*out* Y:**T2**|13+|**T1** = tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)| +|IsInf|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(double), tensor(float), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**T2** = tensor(bool)| +|||[10, 19]|**T1** = tensor(double), tensor(float)
**T2** = tensor(bool)| +|IsNaN|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**T2** = tensor(bool)| +|||[13, 19]|**T1** = tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)| |||[9, 12]|**T1** = tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)| |LRN|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(float)| |||[1, 12]|**T** = tensor(float)| @@ -454,9 +458,11 @@ Do not modify directly.* |GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float)| |GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)| |Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| +|MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)| |MatMulFpQ4|*in* A:**T1**
*in* B:**T2**
*in* B_shape:**T3**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)
**T3** = tensor(int64)| |MatMulInteger16|*in* A:**T1**
*in* B:**T2**
*out* Y:**T3**|1+|**T1** = tensor(int16)
**T2** = tensor(int16)
**T3** = tensor(int32)| |MatMulIntegerToFloat|*in* A:**T1**
*in* B:**T2**
*in* a_scale:**T3**
*in* b_scale:**T3**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*in* bias:**T3**
*out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float)| +|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T2**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)| |MaxpoolWithMask|*in* X:**T**
*in* M:**tensor(int32)**
*out* Y:**T**|1+|**T** = tensor(float)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float)| |MurmurHash3|*in* X:**T1**
*out* Y:**T2**|1+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(string), tensor(uint32), tensor(uint64)
**T2** = tensor(int32), tensor(uint32)| @@ -476,14 +482,17 @@ Do not modify directly.* |QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(int16), tensor(int8), tensor(uint16), tensor(uint8)| |QuickGelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |Range|*in* start:**T**
*in* limit:**T**
*in* delta:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)| +|RotaryEmbedding|*in* input:**T**
*in* position_ids:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**|1+|**M** = tensor(int64)
**T** = tensor(float)| |SampleOp|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |Sampling|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*in* presence_mask:**I**
*in* seed:**I**
*out* sequences:**I**
*out* filtered_logits:**T**|1+|**T** = tensor(float)| |SkipLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* beta:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float)| +|SkipSimplifiedLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float)| |SparseToDenseMatMul|*in* A:**T**
*in* B:**T1**
*out* Y:**T1**|1+|**T** = sparse_tensor(double), sparse_tensor(float), sparse_tensor(int32), sparse_tensor(int64), sparse_tensor(uint32), sparse_tensor(uint64)
**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| |Tokenizer|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(string)| |TransposeMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |Trilu|*in* X:**T**
*in* k:**tensor(int64)**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int64)| |Unique|*in* x:**T**
*out* y:**T**
*out* idx:**tensor(int64)**
*out* counts:**tensor(int64)**|1+|**T** = tensor(float)| +|WhisperBeamSearch|*in* input_ids:**F**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*in* decoder_input_ids:**I**
*in* logits_processor:**I**
*in* cross_qk_layer_head:**I**
*in* extra_decoding_ids:**I**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**
*out* cross_qk:**V**
*out* non_speech_probs:**T**|1+|**T** = tensor(float)| |WordConvEmbedding|*in* Sequence:**T**
*in* W:**T1**
*in* B:**T1**
*in* C:**T1**
*out* Y:**T1**|1+|**T** = tensor(int32)
**T1** = tensor(float)| | | | | @@ -514,10 +523,8 @@ Do not modify directly.* |||[7, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| |Affine|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |And|*in* A:**T**
*in* B:**T**
*out* C:**T1**|7+|**T** = tensor(bool)
**T1** = tensor(bool)| -|ArgMax|*in* data:**T**
*out* reduced:**tensor(int64)**|11|**T** = tensor(double), tensor(float), tensor(float16)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)| -|ArgMin|*in* data:**T**
*out* reduced:**tensor(int64)**|11|**T** = tensor(double), tensor(float), tensor(float16)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)| +|ArgMax|*in* data:**T**
*out* reduced:**tensor(int64)**|[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)| +|ArgMin|*in* data:**T**
*out* reduced:**tensor(int64)**|[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)| |AveragePool|*in* X:**T**
*out* Y:**T**|11+|**T** = tensor(double), tensor(float), tensor(float16)| |||10|**T** = tensor(double), tensor(float), tensor(float16)| |||[7, 9]|**T** = tensor(double), tensor(float), tensor(float16)| @@ -688,39 +695,26 @@ Do not modify directly.* |Range|*in* start:**T**
*in* limit:**T**
*in* delta:**T**
*out* output:**T**|11+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)| |Reciprocal|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| |||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)| -|ReduceL1|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|ReduceL2|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|ReduceLogSum|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| -|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)| -|ReduceLogSumExp|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| -|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)| -|ReduceMax|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| -|||12|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| -|||11|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| -|ReduceMean|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|ReduceMin|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|14+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| -|||13|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| -|||12|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| -|||11|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|ReduceProd|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| +|ReduceL1|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| +|||[1, 17]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| +|ReduceL2|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| +|||[1, 17]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| +|ReduceLogSum|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(float16)| +|||[1, 17]|**T** = tensor(double), tensor(float), tensor(float16)| +|ReduceLogSumExp|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(float16)| +|||[1, 17]|**T** = tensor(double), tensor(float), tensor(float16)| +|ReduceMax|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| +|||[1, 17]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| +|ReduceMean|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| +|||[1, 17]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| +|ReduceMin|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| +|||[1, 17]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| +|ReduceProd|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| +|||[1, 17]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| |ReduceSum|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| -|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| -|ReduceSumSquare|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| -|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)| +|||[1, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| +|ReduceSumSquare|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(float16)| +|||[1, 17]|**T** = tensor(double), tensor(float), tensor(float16)| |Relu|*in* X:**T**
*out* Y:**T**|14+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |||13|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)| @@ -808,7 +802,7 @@ Do not modify directly.* |||[1, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Upsample|*in* X:**T**
*in* scales:**tensor(float)**
*out* Y:**T**

or

*in* X:**T**
*out* Y:**T**|9|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)| |||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)| -|Where|*in* condition:**B**
*in* X:**T**
*in* Y:**T**
*out* output:**T**|16+|**B** = tensor(bool)
**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint8)| +|Where|*in* condition:**B**
*in* X:**T**
*in* Y:**T**
*out* output:**T**|16+|**B** = tensor(bool)
**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint8)| |||[9, 15]|**B** = tensor(bool)
**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint8)| |Xor|*in* A:**T**
*in* B:**T**
*out* C:**T1**|7+|**T** = tensor(bool)
**T1** = tensor(bool)| | | @@ -827,16 +821,18 @@ Do not modify directly.* |ComplexMulConj|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(float), tensor(float16)| |ConvTransposeWithDynamicPads|*in* X:**T**
*in* W:**T**
*in* Pads:**tensor(int64)**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |DecoderAttention|*in* query:**T**
*in* key:**T**
*in* q_weight:**T**
*in* kv_weight:**T**
*in* bias:**T**
*in* key_padding_mask:**B**
*in* key_cache:**T**
*in* value_cache:**T**
*in* static_kv:**B**
*in* use_past:**B**
*in* has_layer_state:**B**
*in* has_key_padding_mask:**B**
*out* output:**T**
*out* new_key_cache:**T**
*out* new_value_cache:**T**|1+|**T** = tensor(float), tensor(float16)| -|DecoderMaskedMultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* mask_index:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* past_sequence_length:**M**
*in* beam_width:**M**
*in* cache_indirection:**M**
*in* bias:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)| +|DecoderMaskedMultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* mask_index:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* past_sequence_length:**M**
*in* beam_width:**M**
*in* cache_indirection:**M**
*in* bias:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**
*out* qk:**V**|1+|**T** = tensor(float), tensor(float16)| |DecoderMaskedSelfAttention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* relative_position_bias:**T**
*in* past_sequence_length:**M**
*in* beam_width:**M**
*in* cache_indirection:**M**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float), tensor(float16)| |DequantizeLinear|*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(float16)| |DequantizeWithOrder|*in* input:**Q**
*in* scale_input:**S**
*out* output:**F**|1+|**F** = tensor(float), tensor(float16)
**Q** = tensor(int8)
**S** = tensor(float)| +|DynamicTimeWarping|*in* input:**F**
*out* output:**I**|1+|**F** = tensor(float)
**I** = tensor(int32)| |EmbedLayerNormalization|*in* input_ids:**T1**
*in* segment_ids:**T1**
*in* word_embedding:**T**
*in* position_embedding:**T**
*in* segment_embedding:**T**
*in* gamma:**T**
*in* beta:**T**
*in* mask:**T1**
*in* position_ids:**T1**
*out* output:**T**
*out* mask_index:**T1**
*out* embedding_sum:**T**|1+|**T** = tensor(float), tensor(float16)| |FastGelu|*in* X:**T**
*in* bias:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(float), tensor(float16)| |FusedConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*in* Z:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |FusedMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |GatedRelativePositionBias|*in* query_layer:**T**
*in* query_bias:**T**
*in* rel_pos:**T**
*in* weight:**T**
*in* bias:**T**
*in* eco_a:**T**
*in* token_offset:**M**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| +|GemmFloat8|*in* A:**TA**
*in* B:**TB**
*in* C:**TC**
*in* scaleA:**TS**
*in* scaleB:**TS**
*in* scaleY:**TS**
*out* Y:**TR**|1+|**TA** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)
**TB** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)
**TR** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)
**TS** = tensor(float)| |GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)| |GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)| |GroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| @@ -844,6 +840,8 @@ Do not modify directly.* |Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |Irfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| +|MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)| +|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T2**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)| |NGramRepeatBlock|*in* input_ids:**Tid**
*in* scores:**T**
*out* scores_out:**T**|1+|**T** = tensor(float)
**Tid** = tensor(int64)| |NhwcConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| @@ -862,11 +860,15 @@ Do not modify directly.* |RemovePadding|*in* input:**T**
*in* sequence_token_count:**M**
*out* output:**T**
*out* token_offset:**M**
*out* cumulated_seq_len:**M**
*out* max_seq_len:**M**|1+|**T** = tensor(float), tensor(float16)| |RestorePadding|*in* input:**T**
*in* token_offset:**M**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |Rfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| +|RotaryEmbedding|*in* input:**T**
*in* position_ids:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**|1+|**M** = tensor(int64)
**T** = tensor(float), tensor(float16)| |Sampling|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*in* presence_mask:**I**
*in* seed:**I**
*out* sequences:**I**
*out* filtered_logits:**T**|1+|**T** = tensor(float), tensor(float16)| +|SkipGroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*in* skip:**T**
*in* bias:**T**
*out* Y:**T**
*out* S:**T**|1+|**T** = tensor(float), tensor(float16)| |SkipLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* beta:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)| |SkipSimplifiedLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)| |TransposeMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |Trilu|*in* X:**T**
*in* k:**tensor(int64)**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|UnfoldTensor|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|WhisperBeamSearch|*in* input_ids:**F**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*in* decoder_input_ids:**I**
*in* logits_processor:**I**
*in* cross_qk_layer_head:**I**
*in* extra_decoding_ids:**I**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**
*out* cross_qk:**V**
*out* non_speech_probs:**T**|1+|**T** = tensor(float), tensor(float16)| | | | | @@ -1246,6 +1248,7 @@ Do not modify directly.* |QLinearSigmoid|*in* X:**T**
*in* X_scale:**tensor(float)**
*in* X_zero_point:**T**
*in* Y_scale:**tensor(float)**
*in* Y_zero_point:**T**
*out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)| |QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**|1+|**T1** = tensor(float), tensor(float16), tensor(int32)
**T2** = tensor(int8), tensor(uint8)| |QuickGelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| +|RotaryEmbedding|*in* input:**T**
*in* position_ids:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**|1+|**M** = tensor(int64)
**T** = tensor(float), tensor(float16)| |SkipLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* beta:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)| | | | | diff --git a/docs/python/ReadMeOV.rst b/docs/python/ReadMeOV.rst index f12c01d278dca..6ef16e1378139 100644 --- a/docs/python/ReadMeOV.rst +++ b/docs/python/ReadMeOV.rst @@ -7,7 +7,6 @@ OpenVINO™ Execution Provider for ONNX Runtime accelerates inference across man - Intel® CPUs - Intel® integrated GPUs - Intel® discrete GPUs - - Intel® integrated VPUs Installation ------------ @@ -22,7 +21,6 @@ This package supports: - Intel® CPUs - Intel® integrated GPUs - Intel® discrete GPUs - - Intel® integrated VPUs ``pip3 install onnxruntime-openvino`` diff --git a/include/onnxruntime/core/framework/float8.h b/include/onnxruntime/core/framework/float8.h index 0fd04f28d44b7..dd607cbbc6952 100644 --- a/include/onnxruntime/core/framework/float8.h +++ b/include/onnxruntime/core/framework/float8.h @@ -208,9 +208,10 @@ struct Float8E4M3FNUZ { val = static_cast((b & 0x80000000) >> 24); // sign if ((b & 0x7fffffff) == 0x7f800000) { // infinity if (saturate) { + // the highest available value val |= 0x7F; } else { - // infinity + // NaN val = 0x80; } } else if ((b & 0x7F800000) == 0x7F800000) { // NaN @@ -362,8 +363,10 @@ struct Float8E5M2 { val = (b & 0x80000000) >> 24; // sign if ((b & 0x7FFFFFFF) == 0x7F800000) { // inf if (saturate) { + // the highest available value val |= 0x7B; } else { + // the infinity val |= 0x7C; } } else if ((b & 0x7F800000) == 0x7F800000) { // NaN diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index f153e88909b8d..462d410e13769 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -3,6 +3,7 @@ #pragma once +#include #include #include #include @@ -83,10 +84,10 @@ class Node { gsl::span output_args, const NodeAttributes* attributes, std::string_view domain) { - Init(std::string{name}, std::string{op_type}, std::string{description}, - std::vector{input_args.begin(), input_args.end()}, - std::vector{output_args.begin(), output_args.end()}, - attributes, std::string{domain}); + Init(name, op_type, description, + input_args, + output_args, + attributes, domain); } #endif @@ -563,13 +564,13 @@ class Node { ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Node); #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) - void Init(const std::string& name, - const std::string& op_type, - const std::string& description, - const std::vector& input_args, - const std::vector& output_args, + void Init(std::string_view name, + std::string_view op_type, + std::string_view description, + gsl::span input_args, + gsl::span output_args, const NodeAttributes* attributes, - const std::string& domain); + std::string_view domain); #endif #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) @@ -1141,8 +1142,22 @@ class Graph { */ Status InlineFunction(Node& node); + /** + Directly insert the nodes in the function proto provided into the graph. + The function converts Constant nodes into the initializers in the graph. + It then creates a node in the graph for each of the function nodes. + All of the names are expected to be specialized, and, therefore unique. + See function_utils::Specialize(). + + The Graph needs to be Resolve()d after this call. + @param func_to_inline + @returns Status indicating success or providing an error message. 
+ */ + + Status InlineFunctionProto(const ONNX_NAMESPACE::FunctionProto& func_to_inline); + /** Mark a NodeArg name as coming from the outer scope when programmatically constructing a Graph that will - be used as a GraphProto attribute in another Node.. + be used as a GraphProto attribute in another Node. e.g. when creating a Graph instance that will be used as a subgraph in a control flow operator, it is necessary to define placeholder NodeArgs for outer scope values. This prevents these values from becoming explicit graph inputs when the Graph is resolved. @@ -1391,6 +1406,13 @@ class Graph { Node& AddNode(const ONNX_NAMESPACE::NodeProto& node_proto, const ArgNameToTypeMap& name_to_type); + /** Helper that converts and adds constant node proto to an initializer in the graph. + @param constant_node_proto Constant node to convert + @param new_name use the new name for the initializer. + */ + Status AddConstantProtoAsInitializer(const ONNX_NAMESPACE::NodeProto& constant_node_proto, + std::optional new_name); + #endif Version IrVersion() const noexcept { diff --git a/include/onnxruntime/core/providers/cuda/cuda_context.h b/include/onnxruntime/core/providers/cuda/cuda_context.h index 13c176dad3cc5..646f33ed952a4 100644 --- a/include/onnxruntime/core/providers/cuda/cuda_context.h +++ b/include/onnxruntime/core/providers/cuda/cuda_context.h @@ -19,6 +19,7 @@ struct CudaContext : public CustomOpContext { cudaStream_t cuda_stream = {}; cudnnHandle_t cudnn_handle = {}; cublasHandle_t cublas_handle = {}; + OrtAllocator* deferred_cpu_allocator = {}; void Init(const OrtKernelContext& kernel_ctx) override { const auto& ort_api = Ort::GetApi(); @@ -44,6 +45,36 @@ struct CudaContext : public CustomOpContext { ORT_CXX_API_THROW("failed to fetch cublas handle", OrtErrorCode::ORT_RUNTIME_EXCEPTION); } cublas_handle = reinterpret_cast(resource); + + resource = {}; + status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_CUDA_RESOUCE_VERSION, CudaResource::deferred_cpu_allocator_t, &resource); + if (status) { + ORT_CXX_API_THROW("failed to fetch deferred cpu allocator", OrtErrorCode::ORT_RUNTIME_EXCEPTION); + } + deferred_cpu_allocator = reinterpret_cast(resource); + } + + void* AllocDeferredCpuMem(size_t size) const { + if (0 == size) { + return {}; + } + const auto& ort_api = Ort::GetApi(); + void* mem = {}; + auto status = ort_api.AllocatorAlloc(deferred_cpu_allocator, size, &mem); + if (status) { + ORT_CXX_API_THROW("failed to allocate deferred cpu memory", OrtErrorCode::ORT_RUNTIME_EXCEPTION); + } + return mem; + } + + void FreeDeferredCpuMem(void* mem) const { + if (mem) { + const auto& ort_api = Ort::GetApi(); + auto status = ort_api.AllocatorFree(deferred_cpu_allocator, mem); + if (status) { + ORT_CXX_API_THROW("failed to free deferred cpu memory", OrtErrorCode::ORT_RUNTIME_EXCEPTION); + } + } } }; diff --git a/include/onnxruntime/core/providers/cuda/cuda_provider_options.h b/include/onnxruntime/core/providers/cuda/cuda_provider_options.h index eaf0e5337b8b6..82bb8ba83be4a 100644 --- a/include/onnxruntime/core/providers/cuda/cuda_provider_options.h +++ b/include/onnxruntime/core/providers/cuda/cuda_provider_options.h @@ -1,8 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. // Licensed under the MIT License. 
#pragma once +#include + #include "onnxruntime_c_api.h" #include "core/framework/arena_extend_strategy.h" @@ -32,4 +35,6 @@ struct OrtCUDAProviderOptionsV2 { int tunable_op_max_tuning_duration_ms = 0; // Max tuning duration time limit for TunableOp. int enable_skip_layer_norm_strict_mode = 0; // flag specifying if SkipLayerNorm is in strict mode. If true, use LayerNormalization kernel. // The strict mode has better accuracy but lower performance. + int prefer_nhwc = 0; // make the CUDA EP NHWC preferred + int use_ep_level_unified_stream = 0; // flag specifying if ep level stream is used or not }; diff --git a/include/onnxruntime/core/providers/cuda/cuda_resource.h b/include/onnxruntime/core/providers/cuda/cuda_resource.h index e46fc5b4219dd..8c3ed46ade6a1 100644 --- a/include/onnxruntime/core/providers/cuda/cuda_resource.h +++ b/include/onnxruntime/core/providers/cuda/cuda_resource.h @@ -3,10 +3,11 @@ #include "core/providers/resource.h" -#define ORT_CUDA_RESOUCE_VERSION 1 +#define ORT_CUDA_RESOUCE_VERSION 2 enum CudaResource : int { cuda_stream_t = cuda_resource_offset, cudnn_handle_t, - cublas_handle_t + cublas_handle_t, + deferred_cpu_allocator_t, }; \ No newline at end of file diff --git a/include/onnxruntime/core/providers/dml/dml_provider_factory.h b/include/onnxruntime/core/providers/dml/dml_provider_factory.h index dd4ffb835d51c..cf3ddc3f125f9 100644 --- a/include/onnxruntime/core/providers/dml/dml_provider_factory.h +++ b/include/onnxruntime/core/providers/dml/dml_provider_factory.h @@ -128,7 +128,7 @@ struct OrtDmlApi { /** * SessionOptionsAppendExecutionProvider_DML2 * Creates a DirectML Execution Provider given the supplied device options that contain a performance preference - * (high power, low power, or defult) and a device filter (None, GPU, or NPU). + * (high power, low power, or default) and a device filter (None, GPU, or NPU). */ ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_DML2, _In_ OrtSessionOptions* options, OrtDmlDeviceOptions* device_opts); }; diff --git a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h index 8f2b5af870506..680ce1cc5b9a2 100644 --- a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h +++ b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h @@ -25,13 +25,14 @@ struct OrtTensorRTProviderOptionsV2 { int trt_dla_core{0}; // DLA core number. Default 0 int trt_dump_subgraphs{0}; // dump TRT subgraph. Default 0 = false, nonzero = true int trt_engine_cache_enable{0}; // enable engine caching. Default 0 = false, nonzero = true - const char* trt_engine_cache_path{nullptr}; // specify engine cache path + const char* trt_engine_cache_path{nullptr}; // specify engine cache path, defaults to the working directory int trt_engine_decryption_enable{0}; // enable engine decryption. Default 0 = false, nonzero = true const char* trt_engine_decryption_lib_path{nullptr}; // specify engine decryption library path int trt_force_sequential_engine_build{0}; // force building TensorRT engine sequentially. Default 0 = false, nonzero = true int trt_context_memory_sharing_enable{0}; // enable context memory sharing between subgraphs. Default 0 = false, nonzero = true int trt_layer_norm_fp32_fallback{0}; // force Pow + Reduce ops in layer norm to FP32. Default 0 = false, nonzero = true int trt_timing_cache_enable{0}; // enable TensorRT timing cache. 
Default 0 = false, nonzero = true + const char* trt_timing_cache_path{nullptr}; // specify timing cache path, if none is provided the trt_engine_cache_path is used int trt_force_timing_cache{0}; // force the TensorRT cache to be used even if device profile does not match. Default 0 = false, nonzero = true int trt_detailed_build_log{0}; // Enable detailed build step logging on TensorRT EP with timing for each engine build. Default 0 = false, nonzero = true int trt_build_heuristics_enable{0}; // Build engine using heuristics to reduce build time. Default 0 = false, nonzero = true diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 4be49bdaeaa2d..729a302f3dd0f 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -299,6 +299,7 @@ ORT_RUNTIME_CLASS(DnnlProviderOptions); ORT_RUNTIME_CLASS(Op); ORT_RUNTIME_CLASS(OpAttr); ORT_RUNTIME_CLASS(Logger); +ORT_RUNTIME_CLASS(ShapeInferContext); #ifdef _WIN32 typedef _Return_type_success_(return == 0) OrtStatus* OrtStatusPtr; @@ -610,7 +611,7 @@ typedef struct OrtMIGraphXProviderOptions { typedef struct OrtOpenVINOProviderOptions { #ifdef __cplusplus OrtOpenVINOProviderOptions() : device_type{}, - enable_vpu_fast_compile{}, + enable_npu_fast_compile{}, device_id{}, num_of_threads{}, cache_dir{}, @@ -623,7 +624,7 @@ typedef struct OrtOpenVINOProviderOptions { * Valid settings are one of: "CPU_FP32", "CPU_FP16", "GPU_FP32", "GPU_FP16" */ const char* device_type; - unsigned char enable_vpu_fast_compile; ///< 0 = disabled, nonzero = enabled + unsigned char enable_npu_fast_compile; ///< 0 = disabled, nonzero = enabled const char* device_id; size_t num_of_threads; ///< 0 = Use default number of threads const char* cache_dir; // path is set to empty by default @@ -4444,6 +4445,73 @@ struct OrtApi { */ ORT_API2_STATUS(SetUserLoggingFunction, _Inout_ OrtSessionOptions* options, _In_ OrtLoggingFunction user_logging_function, _In_opt_ void* user_logging_param); + + /** + * Get number of input from OrtShapeInferContext + * + * \param[in] context + * \param[out] out The number of inputs + * + * \since Version 1.17. + */ + ORT_API2_STATUS(ShapeInferContext_GetInputCount, _In_ const OrtShapeInferContext* context, _Out_ size_t* out); + + /** + * Get type and shape info of an input + * + * \param[in] context + * \param[in] index The index of the input + * \param[out] info Type shape info of the input + * + * \since Version 1.17. + */ + ORT_API2_STATUS(ShapeInferContext_GetInputTypeShape, _In_ const OrtShapeInferContext* context, _In_ size_t index, _Outptr_ OrtTensorTypeAndShapeInfo** info); + + /** + * Get attribute from OrtShapeInferContext. Note that OrtShapeInferContext is a per-node context, one could only read attribute from current node. + * + * \param[in] context + * \param[in] attr_name Name of the attribute + * \param[out] attr Handle of the attribute fetched + * + * \since Version 1.17. + */ + ORT_API2_STATUS(ShapeInferContext_GetAttribute, _In_ const OrtShapeInferContext* context, _In_ const char* attr_name, _Outptr_ const OrtOpAttr** attr); + + /** + * Set type and shape info of an ouput + * + * \param[in] context + * \param[in] index The index of the ouput + * \param[out] info Type shape info of the output + * + * \since Version 1.17. 
+ */ + ORT_API2_STATUS(ShapeInferContext_SetOutputTypeShape, _In_ const OrtShapeInferContext* context, _In_ size_t index, _In_ const OrtTensorTypeAndShapeInfo* info); + + /** + * Set symbolic shape to type shape info + * + * \param[in] info Type shape info + * \param[in] dim_params Symbolic strings + * \param[in] dim_params_length Number of strings + * + * \since Version 1.17. + */ + ORT_API2_STATUS(SetSymbolicDimensions, _In_ OrtTensorTypeAndShapeInfo* info, _In_ const char* dim_params[], _In_ size_t dim_params_length); + + /** + * Read contents of an attribute to data + * + * \param[in] op_attr + * \param[in] type Attribute type + * \param[out] data Memory address to save raw content of the attribute + * \param[in] len Number of bytes allowed to store in data + * \param[out] out Number of bytes required to save the data when the call failed, or the real number of bytes saved to data on success + * + * \since Version 1.17. + */ + ORT_API2_STATUS(ReadOpAttr, _In_ const OrtOpAttr* op_attr, _In_ OrtOpAttrType type, _Inout_ void* data, _In_ size_t len, _Out_ size_t* out); }; /* @@ -4535,6 +4603,12 @@ struct OrtCustomOp { // Perform the computation step. OrtStatusPtr(ORT_API_CALL* KernelComputeV2)(_In_ void* op_kernel, _In_ OrtKernelContext* context); + + OrtStatusPtr(ORT_API_CALL* InferOutputShapeFn)(_In_ const struct OrtCustomOp* op, _In_ OrtShapeInferContext*); + + // Get start range + int(ORT_API_CALL* GetStartVersion)(_In_ const struct OrtCustomOp* op); + int(ORT_API_CALL* GetEndVersion)(_In_ const struct OrtCustomOp* op); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 45f81783421e0..92c25d8688b66 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -2156,6 +2156,80 @@ struct Op : detail::Base { size_t output_count); }; +/// +/// Provide access to per-node attributes and input shapes, so one could compute and set output shapes. 
+/// +struct ShapeInferContext { + struct SymbolicInteger { + SymbolicInteger(int64_t i) : i_(i), is_int_(true){}; + SymbolicInteger(const char* s) : s_(s), is_int_(false){}; + SymbolicInteger(const SymbolicInteger&) = default; + SymbolicInteger(SymbolicInteger&&) = default; + + SymbolicInteger& operator=(const SymbolicInteger&) = default; + SymbolicInteger& operator=(SymbolicInteger&&) = default; + + bool operator==(const SymbolicInteger& dim) const { + if (is_int_ == dim.is_int_) { + if (is_int_) { + return i_ == dim.i_; + } else { + return std::string{s_} == std::string{dim.s_}; + } + } + return false; + } + + bool IsInt() const { return is_int_; } + int64_t AsInt() const { return i_; } + const char* AsSym() const { return s_; } + + static constexpr int INVALID_INT_DIM = -2; + + private: + union { + int64_t i_; + const char* s_; + }; + bool is_int_; + }; + + using Shape = std::vector; + + ShapeInferContext(const OrtApi* ort_api, OrtShapeInferContext* ctx); + + const Shape& GetInputShape(size_t indice) const { return input_shapes_.at(indice); } + + size_t GetInputCount() const { return input_shapes_.size(); } + + Status SetOutputShape(size_t indice, const Shape& shape); + + int64_t GetAttrInt(const char* attr_name); + + using Ints = std::vector; + Ints GetAttrInts(const char* attr_name); + + float GetAttrFloat(const char* attr_name); + + using Floats = std::vector; + Floats GetAttrFloats(const char* attr_name); + + std::string GetAttrString(const char* attr_name); + + using Strings = std::vector; + Strings GetAttrStrings(const char* attr_name); + + private: + const OrtOpAttr* GetAttrHdl(const char* attr_name) const; + const OrtApi* ort_api_; + OrtShapeInferContext* ctx_; + std::vector input_shapes_; +}; + +using ShapeInferFn = Ort::Status (*)(Ort::ShapeInferContext&); + +#define MAX_CUSTOM_OP_END_VER (1UL << 31) - 1 + template struct CustomOpBase : OrtCustomOp { CustomOpBase() { @@ -2206,6 +2280,16 @@ struct CustomOpBase : OrtCustomOp { static_cast(op_kernel)->Compute(context); }; } + + SetShapeInferFn(0); + + OrtCustomOp::GetStartVersion = [](const OrtCustomOp* this_) { + return static_cast(this_)->start_ver_; + }; + + OrtCustomOp::GetEndVersion = [](const OrtCustomOp* this_) { + return static_cast(this_)->end_ver_; + }; } // Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider @@ -2257,9 +2341,26 @@ struct CustomOpBase : OrtCustomOp { return std::vector{}; } + template + decltype(&C::InferOutputShape) SetShapeInferFn(decltype(&C::InferOutputShape)) { + OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp*, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr { + ShapeInferContext ctx(&GetApi(), ort_ctx); + return C::InferOutputShape(ctx); + }; + return {}; + } + + template + void SetShapeInferFn(...) { + OrtCustomOp::InferOutputShapeFn = {}; + } + protected: // Helper function that returns a map of session config entries specified by CustomOpBase::GetSessionConfigKeys. void GetSessionConfigs(std::unordered_map& out, ConstSessionOptions options) const; + + int start_ver_ = 1; + int end_ver_ = MAX_CUSTOM_OP_END_VER; }; } // namespace Ort diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 22172832cde8e..860a27fc73f79 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -8,6 +8,15 @@ // the main C++ file with implementation details. 
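Putting the new declarations together: a custom op built on CustomOpBase can publish shape inference simply by defining a static InferOutputShape; SetShapeInferFn detects it at compile time and installs it as OrtCustomOp::InferOutputShapeFn. A minimal sketch, where the op, kernel and attribute names are hypothetical:

struct MyCustomOp : Ort::CustomOpBase<MyCustomOp, MyCustomKernel> {
  // Detected by CustomOpBase::SetShapeInferFn(0) and exported through
  // OrtCustomOp::InferOutputShapeFn.
  static Ort::Status InferOutputShape(Ort::ShapeInferContext& ctx) {
    // Propagate the first input's (possibly symbolic) shape and override the last
    // dimension from a node attribute.
    Ort::ShapeInferContext::Shape shape = ctx.GetInputShape(0);
    if (!shape.empty()) {
      shape.back() = ctx.GetAttrInt("output_last_dim");  // hypothetical attribute name
    }
    return ctx.SetOutputShape(0, shape);
  }
  // ... the usual CustomOpBase members (GetName, CreateKernel, input/output types) ...
};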
#include +#include + +#define RETURN_ON_API_FAIL(expression) \ + { \ + auto err = (expression); \ + if (err) { \ + return Status(err); \ + } \ + } namespace Ort { @@ -1883,4 +1892,154 @@ void CustomOpBase::GetSessionConfigs(std::unordered_ma } } +inline ShapeInferContext::ShapeInferContext(const OrtApi* ort_api, + OrtShapeInferContext* ctx) : ort_api_(ort_api), ctx_(ctx) { + size_t input_count = 0; + Ort::ThrowOnError(ort_api_->ShapeInferContext_GetInputCount(ctx_, &input_count)); + for (size_t ith_input = 0; ith_input < input_count; ++ith_input) { + OrtTensorTypeAndShapeInfo* info{}; + Ort::ThrowOnError(ort_api_->ShapeInferContext_GetInputTypeShape(ctx, ith_input, &info)); + TensorTypeAndShapeInfo type_shape_info(info); + auto integer_shape = type_shape_info.GetShape(); + std::vector symbolic_shape(integer_shape.size(), {}); + type_shape_info.GetSymbolicDimensions(&symbolic_shape[0], integer_shape.size()); + Shape shape; + for (size_t ith = 0; ith < integer_shape.size(); ++ith) { + if (symbolic_shape[ith] && std::string{symbolic_shape[ith]}.size() > 0) { + shape.emplace_back(symbolic_shape[ith]); + } else { + shape.emplace_back(integer_shape[ith]); + } + } + input_shapes_.push_back(std::move(shape)); + type_shape_info.release(); + } +} + +inline Status ShapeInferContext::SetOutputShape(size_t indice, const Shape& shape) { + OrtTensorTypeAndShapeInfo* info = {}; + RETURN_ON_API_FAIL(ort_api_->CreateTensorTypeAndShapeInfo(&info)); + + using InfoPtr = std::unique_ptr>; + + InfoPtr info_ptr(info, [this](OrtTensorTypeAndShapeInfo* obj) { + ort_api_->ReleaseTensorTypeAndShapeInfo(obj); + }); + + std::vector integer_dims; + std::vector symbolic_dims; + + for (const auto dim : shape) { + if (dim.IsInt()) { + integer_dims.push_back(dim.IsInt()); + symbolic_dims.push_back(""); + } else { + if (!dim.AsSym() || std::string{dim.AsSym()}.empty()) { + ORT_CXX_API_THROW("Symbolic dim must not be an empty string", ORT_INVALID_ARGUMENT); + } + integer_dims.push_back(SymbolicInteger::INVALID_INT_DIM); + symbolic_dims.push_back(dim.AsSym()); + } + } + + RETURN_ON_API_FAIL(ort_api_->SetDimensions(info, integer_dims.data(), integer_dims.size())); + RETURN_ON_API_FAIL(ort_api_->SetSymbolicDimensions(info, symbolic_dims.data(), symbolic_dims.size())); + RETURN_ON_API_FAIL(ort_api_->ShapeInferContext_SetOutputTypeShape(ctx_, indice, info)); + return Status{nullptr}; +} + +inline int64_t ShapeInferContext::GetAttrInt(const char* attr_name) { + const auto* attr = GetAttrHdl(attr_name); + int64_t i = {}; + size_t out = {}; + Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_INT, &i, sizeof(i), &out)); + return i; +} + +inline ShapeInferContext::Ints ShapeInferContext::GetAttrInts(const char* attr_name) { + const auto* attr = GetAttrHdl(attr_name); + int64_t i = {}; + size_t out = {}; + // first call to get the bytes needed + auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_INTS, &i, sizeof(i), &out); + if (status) { + size_t num_i = out / sizeof(int64_t); + ShapeInferContext::Ints ints(num_i, 0); + Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_INTS, ints.data(), out, &out)); + return ints; + } else { + return {i}; + } +} + +inline float ShapeInferContext::GetAttrFloat(const char* attr_name) { + const auto* attr = GetAttrHdl(attr_name); + float f = {}; + size_t out = {}; + Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_FLOAT, &f, sizeof(f), &out)); + return f; +} + +inline ShapeInferContext::Floats ShapeInferContext::GetAttrFloats(const char* attr_name) { + const auto* attr = 
GetAttrHdl(attr_name); + float f = {}; + size_t out = {}; + // first call to get the bytes needed + auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_FLOATS, &f, sizeof(f), &out); + if (status) { + size_t num_f = out / sizeof(float); + ShapeInferContext::Floats floats(num_f, 0); + Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_FLOATS, floats.data(), out, &out)); + return floats; + } else { + return {f}; + } +} + +inline std::string ShapeInferContext::GetAttrString(const char* attr_name) { + const auto* attr = GetAttrHdl(attr_name); + char c = {}; + size_t out = {}; + // first call to get the bytes needed + auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRING, &c, sizeof(char), &out); + if (status) { + std::vector chars(out, '\0'); + Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRING, chars.data(), out, &out)); + return {chars.data()}; + } else { + return {c}; + } +} + +inline ShapeInferContext::Strings ShapeInferContext::GetAttrStrings(const char* attr_name) { + const auto* attr = GetAttrHdl(attr_name); + char c = {}; + size_t out = {}; + // first call to get the bytes needed + auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRINGS, &c, sizeof(char), &out); + if (status) { + std::vector chars(out, '\0'); + Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRINGS, chars.data(), out, &out)); + ShapeInferContext::Strings strings; + char* char_st = chars.data(); + char* char_ed = char_st + out; + while (char_st < char_ed) { + strings.emplace_back(char_st); + while (*char_st != '\0') { + char_st++; + } + char_st++; + } + return strings; + } else { + return {std::string{c}}; + } +} + +inline const OrtOpAttr* ShapeInferContext::GetAttrHdl(const char* attr_name) const { + const OrtOpAttr* attr_hdl = {}; + Ort::ThrowOnError(ort_api_->ShapeInferContext_GetAttribute(ctx_, attr_name, &attr_hdl)); + return attr_hdl; +} + } // namespace Ort diff --git a/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h b/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h index 887339ebc50d6..443710884743a 100644 --- a/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h +++ b/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h @@ -773,8 +773,11 @@ struct OrtLiteCustomOp : public OrtCustomOp { PARSE_ARGS(Ort::Float8E5M2FNUZ_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ) OrtLiteCustomOp(const char* op_name, - const char* execution_provider) : op_name_(op_name), - execution_provider_(execution_provider) { + const char* execution_provider, + int start_ver = 1, int end_ver = MAX_CUSTOM_OP_END_VER) : op_name_(op_name), + execution_provider_(execution_provider), + start_ver_(start_ver), + end_ver_(end_ver) { OrtCustomOp::version = ORT_API_VERSION; OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast(op)->op_name_.c_str(); }; @@ -835,6 +838,18 @@ struct OrtLiteCustomOp : public OrtCustomOp { OrtCustomOp::CreateKernelV2 = {}; OrtCustomOp::KernelComputeV2 = {}; OrtCustomOp::KernelCompute = {}; + + OrtCustomOp::InferOutputShapeFn = {}; + + OrtCustomOp::GetStartVersion = [](const OrtCustomOp* op) { + auto self = reinterpret_cast(op); + return self->start_ver_; + }; + + OrtCustomOp::GetEndVersion = [](const OrtCustomOp* op) { + auto self = reinterpret_cast(op); + return self->end_ver_; + }; } const std::string op_name_; @@ -842,6 +857,9 @@ struct OrtLiteCustomOp : public OrtCustomOp { std::vector input_types_; std::vector output_types_; + + int start_ver_ = 1; + int end_ver_ = MAX_CUSTOM_OP_END_VER; }; 
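The same hooks reach the "lite" custom-op API: OrtLiteCustomOp now carries an opset version range, and the CreateLiteCustomOp overloads further down in this diff accept an optional shape-inference function plus start/end versions. A minimal registration sketch; the op, function and provider names are illustrative and the compute body is omitted:

#include "onnxruntime_lite_custom_op.h"

// Hypothetical compute function (definition omitted); the lite framework deduces the
// op's input/output schema from this signature.
void MyIdentity(const Ort::Custom::Tensor<float>& input, Ort::Custom::Tensor<float>& output);

// Shape inference: the single output mirrors the single input.
Ort::Status MyIdentityShape(Ort::ShapeInferContext& ctx) {
  return ctx.SetOutputShape(0, ctx.GetInputShape(0));
}

void RegisterMyIdentity(Ort::CustomOpDomain& domain) {
  // Valid from opset 14 onwards; MAX_CUSTOM_OP_END_VER leaves the range open-ended.
  Ort::Custom::OrtLiteCustomOp* op = Ort::Custom::CreateLiteCustomOp(
      "MyIdentity", "CPUExecutionProvider", MyIdentity, MyIdentityShape,
      /*start_ver=*/14, /*end_ver=*/MAX_CUSTOM_OP_END_VER);
  domain.Add(op);
}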
//////////////////////////// OrtLiteCustomFunc //////////////////////////////// @@ -870,8 +888,12 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp { OrtLiteCustomFunc(const char* op_name, const char* execution_provider, - ComputeFn compute_fn) : OrtLiteCustomOp(op_name, execution_provider), - compute_fn_(compute_fn) { + ComputeFn compute_fn, + ShapeInferFn shape_infer_fn = {}, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, start_ver, end_ver), + compute_fn_(compute_fn), + shape_infer_fn_(shape_infer_fn) { ParseArgs(input_types_, output_types_); OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { @@ -894,12 +916,24 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp { OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete reinterpret_cast(op_kernel); }; + + if (shape_infer_fn_) { + OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp* op, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr { + auto shape_info_fn = static_cast(op)->shape_infer_fn_; + ShapeInferContext ctx(&GetApi(), ort_ctx); + return shape_info_fn(ctx); + }; + } } OrtLiteCustomFunc(const char* op_name, const char* execution_provider, - ComputeFnReturnStatus compute_fn_return_status) : OrtLiteCustomOp(op_name, execution_provider), - compute_fn_return_status_(compute_fn_return_status) { + ComputeFnReturnStatus compute_fn_return_status, + ShapeInferFn shape_infer_fn = {}, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, start_ver, end_ver), + compute_fn_return_status_(compute_fn_return_status), + shape_infer_fn_(shape_infer_fn) { ParseArgs(input_types_, output_types_); OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr { @@ -922,10 +956,19 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp { OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete reinterpret_cast(op_kernel); }; + + if (shape_infer_fn_) { + OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp* op, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr { + auto shape_info_fn = static_cast(op)->shape_infer_fn_; + ShapeInferContext ctx(&GetApi(), ort_ctx); + return shape_info_fn(ctx); + }; + } } ComputeFn compute_fn_ = {}; ComputeFnReturnStatus compute_fn_return_status_ = {}; + ShapeInferFn shape_infer_fn_ = {}; }; // struct OrtLiteCustomFunc /////////////////////////// OrtLiteCustomStruct /////////////////////////// @@ -962,9 +1005,10 @@ struct OrtLiteCustomStruct : public OrtLiteCustomOp { }; OrtLiteCustomStruct(const char* op_name, - const char* execution_provider) : OrtLiteCustomOp(op_name, - execution_provider) { - init(&CustomOp::Compute); + const char* execution_provider, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, start_ver, end_ver) { + SetCompute(&CustomOp::Compute); OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) { auto kernel = std::make_unique(); @@ -979,10 +1023,12 @@ struct OrtLiteCustomStruct : public OrtLiteCustomOp { OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete reinterpret_cast(op_kernel); }; + + SetShapeInfer(0); } template - void init(CustomComputeFn) { + void SetCompute(CustomComputeFn) { ParseArgs(input_types_, output_types_); OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { auto kernel = reinterpret_cast(op_kernel); @@ -993,7 +1039,7 @@ struct OrtLiteCustomStruct : public OrtLiteCustomOp { } 
template - void init(CustomComputeFnReturnStatus) { + void SetCompute(CustomComputeFnReturnStatus) { ParseArgs(input_types_, output_types_); OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr { auto kernel = reinterpret_cast(op_kernel); @@ -1002,6 +1048,20 @@ struct OrtLiteCustomStruct : public OrtLiteCustomOp { return std::apply([kernel](Args const&... t_args) { Status status = kernel->custom_op_->Compute(t_args...); return status.release(); }, t); }; } + + template + decltype(&C::InferOutputShape) SetShapeInfer(decltype(&C::InferOutputShape)) { + OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp*, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr { + ShapeInferContext ctx(&GetApi(), ort_ctx); + return C::InferOutputShape(ctx); + }; + return {}; + } + + template + void SetShapeInfer(...) { + OrtCustomOp::InferOutputShapeFn = {}; + } }; // struct OrtLiteCustomStruct /////////////////////////// CreateLiteCustomOp //////////////////////////// @@ -1009,24 +1069,32 @@ struct OrtLiteCustomStruct : public OrtLiteCustomOp { template OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name, const char* execution_provider, - void (*custom_compute_fn)(Args...)) { + void (*custom_compute_fn)(Args...), + Status (*shape_infer_fn)(ShapeInferContext&) = {}, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) { using LiteOp = OrtLiteCustomFunc; - return std::make_unique(op_name, execution_provider, custom_compute_fn).release(); + return std::make_unique(op_name, execution_provider, custom_compute_fn, shape_infer_fn, start_ver, end_ver).release(); } template OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name, const char* execution_provider, - Status (*custom_compute_fn_v2)(Args...)) { + Status (*custom_compute_fn_v2)(Args...), + Status (*shape_infer_fn)(ShapeInferContext&) = {}, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) { using LiteOp = OrtLiteCustomFunc; - return std::make_unique(op_name, execution_provider, custom_compute_fn_v2).release(); + return std::make_unique(op_name, execution_provider, custom_compute_fn_v2, shape_infer_fn, start_ver, end_ver).release(); } template OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name, - const char* execution_provider) { + const char* execution_provider, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) { using LiteOp = OrtLiteCustomStruct; - return std::make_unique(op_name, execution_provider).release(); + return std::make_unique(op_name, execution_provider, start_ver, end_ver).release(); } } // namespace Custom diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index 37545f41b43dd..831def24e4f5e 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -67,6 +67,16 @@ static const char* const kOrtSessionOptionsEnableQuantQDQCleanup = "session.enab // GeluApproximation has side effects which may change the inference results. It is disabled by default due to this. static const char* const kOrtSessionOptionsEnableGeluApproximation = "optimization.enable_gelu_approximation"; +// This setting controls whether to enable AheadOfTime function inlining. +// AOT function inlining examines the graph and attempts to inline as many locally defined functions in the model +// as possible with the help of enabled execution providers. 
+// This can reduce the number of function calls and improve performance because it is done before +// Level1 optimizers and constant folding. However, under some circumstances, when the EPs are not available, +// one can disable the AOT inlining, produce an optimized model and postpone AOT until run time. +// "0": enable; "1": disable. +// Its default value is "0". +static const char* const kOrtSessionOptionsDisableAheadOfTimeFunctionInlining = "session.disable_aot_function_inlining"; + #ifdef ENABLE_TRAINING // Specifies a list of op types for memory footprint reduction. // The value should be a ","-delimited list of pair of diff --git a/java/src/main/java/ai/onnxruntime/OnnxTensor.java b/java/src/main/java/ai/onnxruntime/OnnxTensor.java index 09d2cefbb8224..0078adb6402f8 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxTensor.java +++ b/java/src/main/java/ai/onnxruntime/OnnxTensor.java @@ -13,6 +13,7 @@ import java.nio.IntBuffer; import java.nio.LongBuffer; import java.nio.ShortBuffer; +import java.util.Optional; /** * A Java object wrapping an OnnxTensor. Tensors are the main input to the library, and can also be @@ -21,18 +22,60 @@ public class OnnxTensor extends OnnxTensorLike { /** - * This reference is held for OnnxTensors backed by a Java nio buffer to ensure the buffer does + * This reference is held for OnnxTensors backed by a java.nio.Buffer to ensure the buffer does * not go out of scope while the OnnxTensor exists. */ private final Buffer buffer; + /** + * Denotes if the OnnxTensor made a copy of the buffer on construction (i.e. it may have the only + * reference). + */ + private final boolean ownsBuffer; + OnnxTensor(long nativeHandle, long allocatorHandle, TensorInfo info) { - this(nativeHandle, allocatorHandle, info, null); + this(nativeHandle, allocatorHandle, info, null, false); } - OnnxTensor(long nativeHandle, long allocatorHandle, TensorInfo info, Buffer buffer) { + OnnxTensor( + long nativeHandle, long allocatorHandle, TensorInfo info, Buffer buffer, boolean ownsBuffer) { super(nativeHandle, allocatorHandle, info); this.buffer = buffer; + this.ownsBuffer = ownsBuffer; + } + + /** + * Returns true if the buffer in this OnnxTensor was created on construction of this tensor, i.e., + * it is a copy of a user supplied buffer or array and may hold the only reference to that buffer. + * + *

When this is true the backing buffer was copied from the user input, so users cannot mutate + * the state of this buffer without first getting the reference via {@link #getBufferRef()}. + * + * @return True if the buffer in this OnnxTensor was allocated by it on construction (i.e., it is + * a copy of a user buffer.) + */ + public boolean ownsBuffer() { + return this.ownsBuffer; + } + + /** + * Returns a reference to the buffer which backs this {@code OnnxTensor}. If the tensor is not + * backed by a buffer (i.e., it was created from a Java array, or is backed by memory allocated by + * ORT) this method returns an empty {@link Optional}. + * + *

Changes to the buffer elements will be reflected in the native {@code OrtValue}, this can be + * used to repeatedly update a single tensor for multiple different inferences without allocating + * new tensors, though the inputs must remain the same size and shape. + * + *

Note: the tensor could refer to a contiguous range of elements in this buffer, not the whole + * buffer. It is up to the user to manage this information by respecting the position and limit. + * As a consequence, accessing this reference should be considered problematic when multiple + * threads hold references to the buffer. + * + * @return A reference to the buffer. + */ + public Optional getBufferRef() { + return Optional.ofNullable(buffer); } @Override @@ -45,7 +88,8 @@ public OnnxValueType getType() { * primitives if it has multiple dimensions. * *

Java multidimensional arrays are quite slow for more than 2 dimensions, in that case it is - * recommended you use the java.nio.Buffer extractors below (e.g. {@link #getFloatBuffer}). + * recommended you use the {@link java.nio.Buffer} extractors below (e.g., {@link + * #getFloatBuffer}). * * @return A Java value. * @throws OrtException If the value could not be extracted as the Tensor is invalid, or if the @@ -283,6 +327,12 @@ private native void getArray(long apiHandle, long nativeHandle, Object carrier) * multidimensional array. The shape is inferred from the object using reflection. The default * allocator is used. * + *

Note: Java multidimensional arrays are not dense and this method requires traversing a large + * number of pointers for high dimensional arrays. For types other than Strings it is recommended + * to use one of the {@code createTensor} methods which accepts a {@link java.nio.Buffer}, e.g. + * {@link #createTensor(OrtEnvironment, FloatBuffer, long[])} as those methods are zero copy to + * transfer data into ORT when using direct buffers. + * * @param env The current OrtEnvironment. * @param data The data to store in a tensor. * @return An OnnxTensor storing the data. @@ -700,7 +750,8 @@ private static OnnxTensor createTensor( info.onnxType.value), allocator.handle, info, - tuple.data); + tuple.data, + tuple.isCopy); } private static native long createTensor( diff --git a/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java b/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java index fcb4590717fea..a5f285ba86a14 100644 --- a/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java +++ b/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java @@ -88,6 +88,91 @@ public void testScalarCreation() throws OrtException { } } + @Test + public void testBufferCreation() throws OrtException { + OrtEnvironment env = OrtEnvironment.getEnvironment(); + + // Test creating a value from an array + // Arrays result in tensors allocated by ORT, so they do not have a backing java.nio.Buffer + float[] arrValues = new float[] {0, 1, 2, 3, 4}; + try (OnnxTensor t = OnnxTensor.createTensor(env, arrValues)) { + // array creation isn't backed by buffers + Assertions.assertFalse(t.ownsBuffer()); + Assertions.assertFalse(t.getBufferRef().isPresent()); + FloatBuffer buf = t.getFloatBuffer(); + float[] output = new float[arrValues.length]; + buf.get(output); + Assertions.assertArrayEquals(arrValues, output); + + // Can't modify the tensor through this buffer. + buf.put(0, 25); + Assertions.assertArrayEquals(arrValues, output); + } + + // Test creating a value from a non-direct byte buffer + // Non-direct byte buffers are allocated on the Java heap and must be copied into off-heap + // direct byte buffers + // which can be directly passed to ORT + FloatBuffer nonDirectBuffer = FloatBuffer.allocate(5); + nonDirectBuffer.put(arrValues); + nonDirectBuffer.rewind(); + try (OnnxTensor t = OnnxTensor.createTensor(env, nonDirectBuffer, new long[] {1, 5})) { + // non-direct buffers trigger a copy + Assertions.assertTrue(t.ownsBuffer()); + // tensors backed by buffers can get the buffer ref back out + Assertions.assertTrue(t.getBufferRef().isPresent()); + FloatBuffer buf = t.getFloatBuffer(); + float[] output = new float[arrValues.length]; + buf.get(output); + Assertions.assertArrayEquals(arrValues, output); + + // Can't modify the tensor through getFloatBuffer. + buf.put(0, 25); + Assertions.assertArrayEquals(arrValues, output); + + // Can modify the tensor through getBufferRef. 
+ FloatBuffer ref = (FloatBuffer) t.getBufferRef().get(); + ref.put(0, 25); + buf = t.getFloatBuffer(); + buf.get(output); + Assertions.assertEquals(25, output[0]); + } + + // Test creating a value from a direct byte buffer + // Direct byte buffers can be passed into ORT without additional copies or processing + FloatBuffer directBuffer = + ByteBuffer.allocateDirect(5 * 4).order(ByteOrder.nativeOrder()).asFloatBuffer(); + directBuffer.put(arrValues); + directBuffer.rewind(); + try (OnnxTensor t = OnnxTensor.createTensor(env, directBuffer, new long[] {1, 5})) { + // direct buffers don't trigger a copy + Assertions.assertFalse(t.ownsBuffer()); + // tensors backed by buffers can get the buffer ref back out + Assertions.assertTrue(t.getBufferRef().isPresent()); + FloatBuffer buf = t.getFloatBuffer(); + float[] output = new float[arrValues.length]; + buf.get(output); + Assertions.assertArrayEquals(arrValues, output); + + // Can't modify the tensor through getFloatBuffer. + buf.put(0, 25); + Assertions.assertArrayEquals(arrValues, output); + + // Can modify the tensor through getBufferRef. + FloatBuffer ref = (FloatBuffer) t.getBufferRef().get(); + ref.put(0, 25); + buf = t.getFloatBuffer(); + buf.get(output); + Assertions.assertEquals(25, output[0]); + + // Can modify the tensor through our original ref to the direct byte buffer + directBuffer.put(1, 15); + buf = t.getFloatBuffer(); + buf.get(output); + Assertions.assertEquals(15, output[1]); + } + } + @Test public void testStringCreation() throws OrtException { OrtEnvironment env = OrtEnvironment.getEnvironment(); diff --git a/js/.vscode/settings.json b/js/.vscode/settings.json index 4948899ec671b..9c2fe646d728d 100644 --- a/js/.vscode/settings.json +++ b/js/.vscode/settings.json @@ -36,7 +36,8 @@ "node/lib/**/*.js.map": true, "node/lib/**/*.js": true, "web/lib/**/*.js.map": true, - "web/lib/**/*.js": true + "web/lib/**/*.js": true, + "web/lib/**/*.d.ts": true }, "files.insertFinalNewline": true, "files.trimTrailingWhitespace": true, diff --git a/js/common/lib/training-session-impl.ts b/js/common/lib/training-session-impl.ts index f06d06bda035f..ee6d26b22b1f6 100644 --- a/js/common/lib/training-session-impl.ts +++ b/js/common/lib/training-session-impl.ts @@ -1,11 +1,21 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {TrainingSessionHandler} from './backend.js'; +import {resolveBackend} from './backend-impl.js'; +import {SessionHandler, TrainingSessionHandler} from './backend.js'; import {InferenceSession as InferenceSession} from './inference-session.js'; +import {OnnxValue} from './onnx-value.js'; +import {Tensor} from './tensor.js'; import {TrainingSession as TrainingSessionInterface, TrainingSessionCreateOptions} from './training-session.js'; type SessionOptions = InferenceSession.SessionOptions; +type FeedsType = InferenceSession.FeedsType; +type FetchesType = InferenceSession.FetchesType; +type ReturnType = InferenceSession.ReturnType; +type RunOptions = InferenceSession.RunOptions; + +const noBackendErrMsg: string = 'Training backend could not be resolved. 
' + + 'Make sure you\'re using the correct configuration & WebAssembly files.'; export class TrainingSession implements TrainingSessionInterface { private constructor(handler: TrainingSessionHandler) { @@ -20,26 +30,157 @@ export class TrainingSession implements TrainingSessionInterface { return this.handler.outputNames; } - static async create(_trainingOptions: TrainingSessionCreateOptions, _sessionOptions?: SessionOptions): + static async create(trainingOptions: TrainingSessionCreateOptions, sessionOptions?: SessionOptions): Promise { - throw new Error('Method not implemented'); + const evalModel: string|Uint8Array = trainingOptions.evalModel || ''; + const optimizerModel: string|Uint8Array = trainingOptions.optimizerModel || ''; + const options: SessionOptions = sessionOptions || {}; + + // get backend hints + const eps = options.executionProviders || []; + const backendHints = eps.map(i => typeof i === 'string' ? i : i.name); + const backend = await resolveBackend(backendHints); + if (backend.createTrainingSessionHandler) { + const handler = await backend.createTrainingSessionHandler( + trainingOptions.checkpointState, trainingOptions.trainModel, evalModel, optimizerModel, options); + return new TrainingSession(handler); + } else { + throw new Error(noBackendErrMsg); + } } - async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise { - throw new Error('Method not implemented.'); + /** + * Helper function for runTrainStep and future runStep methods that handles the type-narrowing conversion from + * the given parameters to SessionHandler.FetchesType and RunOptions. + * + * @param feeds the required input + * @param arg1 narrowed & converted into the SessionHandler.FetchesType or RunOptions object + * @param arg2 optional RunOptions object. 
+ * @returns + */ + typeNarrowingForRunStep(feeds: FeedsType, arg1?: FetchesType|RunOptions, arg2?: RunOptions): + [SessionHandler.FetchesType, RunOptions] { + const fetches: {[name: string]: OnnxValue|null} = {}; + let options: RunOptions = {}; + // check inputs + if (typeof feeds !== 'object' || feeds === null || feeds instanceof Tensor || Array.isArray(feeds)) { + throw new TypeError( + '\'feeds\' must be an object that use input names as keys and OnnxValue as corresponding values.'); + } + + let isFetchesEmpty = true; + // determine which override is being used + if (typeof arg1 === 'object') { + if (arg1 === null) { + throw new TypeError('Unexpected argument[1]: cannot be null.'); + } + if (arg1 instanceof Tensor) { + throw new TypeError('\'fetches\' cannot be a Tensor'); + } + + if (Array.isArray(arg1)) { + if (arg1.length === 0) { + throw new TypeError('\'fetches\' cannot be an empty array.'); + } + isFetchesEmpty = false; + // output names + for (const name of arg1) { + if (typeof name !== 'string') { + throw new TypeError('\'fetches\' must be a string array or an object.'); + } + if (this.outputNames.indexOf(name) === -1) { + throw new RangeError(`'fetches' contains invalid output name: ${name}.`); + } + fetches[name] = null; + } + + if (typeof arg2 === 'object' && arg2 !== null) { + options = arg2; + } else if (typeof arg2 !== 'undefined') { + throw new TypeError('\'options\' must be an object.'); + } + } else { + // decide whether arg1 is fetches or options + // if any output name is present and its value is valid OnnxValue, we consider it fetches + let isFetches = false; + const arg1Keys = Object.getOwnPropertyNames(arg1); + for (const name of this.outputNames) { + if (arg1Keys.indexOf(name) !== -1) { + const v = (arg1 as InferenceSession.NullableOnnxValueMapType)[name]; + if (v === null || v instanceof Tensor) { + isFetches = true; + isFetchesEmpty = false; + fetches[name] = v; + } + } + } + + if (isFetches) { + if (typeof arg2 === 'object' && arg2 !== null) { + options = arg2; + } else if (typeof arg2 !== 'undefined') { + throw new TypeError('\'options\' must be an object.'); + } + } else { + options = arg1 as RunOptions; + } + } + } else if (typeof arg1 !== 'undefined') { + throw new TypeError('Unexpected argument[1]: must be \'fetches\' or \'options\'.'); + } + + // check if all inputs are in feed + for (const name of this.inputNames) { + if (typeof feeds[name] === 'undefined') { + throw new Error(`input '${name}' is missing in 'feeds'.`); + } + } + + // if no fetches is specified, we use the full output names list + if (isFetchesEmpty) { + for (const name of this.outputNames) { + fetches[name] = null; + } + } + + return [fetches, options]; } - async getContiguousParameters(_trainableOnly: boolean): Promise { + /** + * Helper method for runTrainStep and any other runStep methods. Takes the ReturnType result from the SessionHandler + * and changes it into a map of Tensors. 
+ * + * @param results + * @returns + */ + convertHandlerReturnTypeToMapOfTensors(results: SessionHandler.ReturnType): ReturnType { + const returnValue: {[name: string]: OnnxValue} = {}; + for (const key in results) { + if (Object.hasOwnProperty.call(results, key)) { + const result = results[key]; + if (result instanceof Tensor) { + returnValue[key] = result; + } else { + returnValue[key] = new Tensor(result.type, result.data, result.dims); + } + } + } + return returnValue; + } + + runTrainStep(feeds: FeedsType, options?: RunOptions): Promise; + runTrainStep(feeds: FeedsType, fetches: FetchesType, options?: RunOptions): Promise; + async runTrainStep(feeds: FeedsType, arg1?: FetchesType|RunOptions, arg2?: RunOptions): Promise { + const [fetches, options] = this.typeNarrowingForRunStep(feeds, arg1, arg2); + const results = await this.handler.runTrainStep(feeds, fetches, options); + return this.convertHandlerReturnTypeToMapOfTensors(results); + } + + async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise { throw new Error('Method not implemented.'); } - runTrainStep(feeds: InferenceSession.OnnxValueMapType, options?: InferenceSession.RunOptions|undefined): - Promise; - runTrainStep( - feeds: InferenceSession.OnnxValueMapType, fetches: InferenceSession.FetchesType, - options?: InferenceSession.RunOptions|undefined): Promise; - async runTrainStep(_feeds: unknown, _fetches?: unknown, _options?: unknown): - Promise { + async getContiguousParameters(_trainableOnly: boolean): Promise { throw new Error('Method not implemented.'); } diff --git a/js/react_native/e2e/yarn.lock b/js/react_native/e2e/yarn.lock index aaa35ae7895b9..9e20a286c4e27 100644 --- a/js/react_native/e2e/yarn.lock +++ b/js/react_native/e2e/yarn.lock @@ -39,6 +39,14 @@ dependencies: "@babel/highlight" "^7.18.6" +"@babel/code-frame@^7.22.13": + version "7.22.13" + resolved "https://registry.yarnpkg.com/@babel/code-frame/-/code-frame-7.22.13.tgz#e3c1c099402598483b7a8c46a721d1038803755e" + integrity sha512-XktuhWlJ5g+3TJXc5upd9Ks1HutSArik6jf2eAjYFyIOf4ej3RN+184cZbzDvbPnuTJIUhPKKJE3cIsYTiAT3w== + dependencies: + "@babel/highlight" "^7.22.13" + chalk "^2.4.2" + "@babel/compat-data@^7.13.11", "@babel/compat-data@^7.17.10": version "7.18.5" resolved "https://registry.yarnpkg.com/@babel/compat-data/-/compat-data-7.18.5.tgz#acac0c839e317038c73137fbb6ef71a1d6238471" @@ -100,15 +108,6 @@ "@jridgewell/gen-mapping" "^0.3.0" jsesc "^2.5.1" -"@babel/generator@^7.18.7": - version "7.18.7" - resolved "https://registry.yarnpkg.com/@babel/generator/-/generator-7.18.7.tgz#2aa78da3c05aadfc82dbac16c99552fc802284bd" - integrity sha512-shck+7VLlY72a2w9c3zYWuE1pwOKEiQHV7GTUbSnhyl5eu3i04t30tBY82ZRWrDfo3gkakCFtevExnxbkf2a3A== - dependencies: - "@babel/types" "^7.18.7" - "@jridgewell/gen-mapping" "^0.3.2" - jsesc "^2.5.1" - "@babel/generator@^7.21.4", "@babel/generator@^7.7.2": version "7.21.4" resolved "https://registry.yarnpkg.com/@babel/generator/-/generator-7.21.4.tgz#64a94b7448989f421f919d5239ef553b37bb26bc" @@ -119,6 +118,16 @@ "@jridgewell/trace-mapping" "^0.3.17" jsesc "^2.5.1" +"@babel/generator@^7.23.0": + version "7.23.0" + resolved "https://registry.yarnpkg.com/@babel/generator/-/generator-7.23.0.tgz#df5c386e2218be505b34837acbcb874d7a983420" + integrity sha512-lN85QRR+5IbYrMWM6Y4pE/noaQtg4pNiqeNGX60eqOfo6gtEj6uw/JagelB8vVztSd7R6M5n1+PQkDbHbBRU4g== + dependencies: + "@babel/types" "^7.23.0" + "@jridgewell/gen-mapping" "^0.3.2" + "@jridgewell/trace-mapping" "^0.3.17" + jsesc "^2.5.1" + 
"@babel/helper-annotate-as-pure@^7.16.7": version "7.16.7" resolved "https://registry.yarnpkg.com/@babel/helper-annotate-as-pure/-/helper-annotate-as-pure-7.16.7.tgz#bb2339a7534a9c128e3102024c60760a3a7f3862" @@ -220,6 +229,11 @@ resolved "https://registry.yarnpkg.com/@babel/helper-environment-visitor/-/helper-environment-visitor-7.18.9.tgz#0c0cee9b35d2ca190478756865bb3528422f51be" integrity sha512-3r/aACDJ3fhQ/EVgFy0hpj8oHyHpQc+LPtJoY9SzTThAsStm4Ptegq92vqKoE3vD706ZVFWITnMnxucw+S9Ipg== +"@babel/helper-environment-visitor@^7.22.20": + version "7.22.20" + resolved "https://registry.yarnpkg.com/@babel/helper-environment-visitor/-/helper-environment-visitor-7.22.20.tgz#96159db61d34a29dba454c959f5ae4a649ba9167" + integrity sha512-zfedSIzFhat/gFhWfHtgWvlec0nqB9YEIVrpuwjruLlXfUSnA8cJB0miHKwqDnQ7d32aKo2xt88/xZptwxbfhA== + "@babel/helper-explode-assignable-expression@^7.16.7": version "7.16.7" resolved "https://registry.yarnpkg.com/@babel/helper-explode-assignable-expression/-/helper-explode-assignable-expression-7.16.7.tgz#12a6d8522fdd834f194e868af6354e8650242b7a" @@ -243,27 +257,20 @@ "@babel/template" "^7.18.6" "@babel/types" "^7.18.6" -"@babel/helper-function-name@^7.21.0": - version "7.21.0" - resolved "https://registry.yarnpkg.com/@babel/helper-function-name/-/helper-function-name-7.21.0.tgz#d552829b10ea9f120969304023cd0645fa00b1b4" - integrity sha512-HfK1aMRanKHpxemaY2gqBmL04iAPOPRj7DxtNbiDOrJK+gdwkiNRVpCpUJYbUT+aZyemKN8brqTOxzCaG6ExRg== - dependencies: - "@babel/template" "^7.20.7" - "@babel/types" "^7.21.0" - -"@babel/helper-hoist-variables@^7.16.7": - version "7.16.7" - resolved "https://registry.yarnpkg.com/@babel/helper-hoist-variables/-/helper-hoist-variables-7.16.7.tgz#86bcb19a77a509c7b77d0e22323ef588fa58c246" - integrity sha512-m04d/0Op34H5v7pbZw6pSKP7weA6lsMvfiIAMeIvkY/R4xQtBSMFEigu9QTZ2qB/9l22vsxtM8a+Q8CzD255fg== +"@babel/helper-function-name@^7.23.0": + version "7.23.0" + resolved "https://registry.yarnpkg.com/@babel/helper-function-name/-/helper-function-name-7.23.0.tgz#1f9a3cdbd5b2698a670c30d2735f9af95ed52759" + integrity sha512-OErEqsrxjZTJciZ4Oo+eoZqeW9UIiOcuYKRJA4ZAgV9myA+pOXhhmpfNCKjEH/auVfEYVFJ6y1Tc4r0eIApqiw== dependencies: - "@babel/types" "^7.16.7" + "@babel/template" "^7.22.15" + "@babel/types" "^7.23.0" -"@babel/helper-hoist-variables@^7.18.6": - version "7.18.6" - resolved "https://registry.yarnpkg.com/@babel/helper-hoist-variables/-/helper-hoist-variables-7.18.6.tgz#d4d2c8fb4baeaa5c68b99cc8245c56554f926678" - integrity sha512-UlJQPkFqFULIcyW5sbzgbkxn2FKRgwWiRexcuaR8RNJRy8+LLveqPjwZV/bwrLZCN0eUHD/x8D0heK1ozuoo6Q== +"@babel/helper-hoist-variables@^7.22.5": + version "7.22.5" + resolved "https://registry.yarnpkg.com/@babel/helper-hoist-variables/-/helper-hoist-variables-7.22.5.tgz#c01a007dac05c085914e8fb652b339db50d823bb" + integrity sha512-wGjk9QZVzvknA6yKIUURb8zY3grXCcOZt+/7Wcy8O2uctxhplmUPkOdlgoNhmdVee2c92JXbf1xpMtVNbfoxRw== dependencies: - "@babel/types" "^7.18.6" + "@babel/types" "^7.22.5" "@babel/helper-member-expression-to-functions@^7.17.7": version "7.17.7" @@ -401,11 +408,23 @@ dependencies: "@babel/types" "^7.18.6" +"@babel/helper-split-export-declaration@^7.22.6": + version "7.22.6" + resolved "https://registry.yarnpkg.com/@babel/helper-split-export-declaration/-/helper-split-export-declaration-7.22.6.tgz#322c61b7310c0997fe4c323955667f18fcefb91c" + integrity sha512-AsUnxuLhRYsisFiaJwvp1QF+I3KjD5FOxut14q/GzovUe6orHLesW2C7d754kRm53h5gqrz6sFl6sxc4BVtE/g== + dependencies: + "@babel/types" "^7.22.5" + "@babel/helper-string-parser@^7.19.4": version "7.19.4" 
resolved "https://registry.yarnpkg.com/@babel/helper-string-parser/-/helper-string-parser-7.19.4.tgz#38d3acb654b4701a9b77fb0615a96f775c3a9e63" integrity sha512-nHtDoQcuqFmwYNYPz3Rah5ph2p8PFeFCsZk9A/48dPc/rGocJ5J3hAAZ7pb76VWX3fZKu+uEr/FhH5jLx7umrw== +"@babel/helper-string-parser@^7.22.5": + version "7.22.5" + resolved "https://registry.yarnpkg.com/@babel/helper-string-parser/-/helper-string-parser-7.22.5.tgz#533f36457a25814cf1df6488523ad547d784a99f" + integrity sha512-mM4COjgZox8U+JcXQwPijIZLElkgEpO5rsERVDJTc2qfCDfERyob6k5WegS14SX18IIjv+XD+GrqNumY5JRCDw== + "@babel/helper-validator-identifier@^7.16.7": version "7.16.7" resolved "https://registry.yarnpkg.com/@babel/helper-validator-identifier/-/helper-validator-identifier-7.16.7.tgz#e8c602438c4a8195751243da9031d1607d247cad" @@ -421,6 +440,11 @@ resolved "https://registry.yarnpkg.com/@babel/helper-validator-identifier/-/helper-validator-identifier-7.19.1.tgz#7eea834cf32901ffdc1a7ee555e2f9c27e249ca2" integrity sha512-awrNfaMtnHUr653GgGEs++LlAvW6w+DcPrOliSMXWCKo597CwL5Acf/wWdNkf/tfEQE3mjkeD1YOVZOUV/od1w== +"@babel/helper-validator-identifier@^7.22.20": + version "7.22.20" + resolved "https://registry.yarnpkg.com/@babel/helper-validator-identifier/-/helper-validator-identifier-7.22.20.tgz#c4ae002c61d2879e724581d96665583dbc1dc0e0" + integrity sha512-Y4OZ+ytlatR8AI+8KZfKuL5urKp7qey08ha31L8b3BwewJAoJamTzyvxPR/5D+KkdJCGPq/+8TukHBlY10FX9A== + "@babel/helper-validator-option@^7.16.7": version "7.16.7" resolved "https://registry.yarnpkg.com/@babel/helper-validator-option/-/helper-validator-option-7.16.7.tgz#b203ce62ce5fe153899b617c08957de860de4d23" @@ -487,6 +511,15 @@ chalk "^2.0.0" js-tokens "^4.0.0" +"@babel/highlight@^7.22.13": + version "7.22.20" + resolved "https://registry.yarnpkg.com/@babel/highlight/-/highlight-7.22.20.tgz#4ca92b71d80554b01427815e06f2df965b9c1f54" + integrity sha512-dkdMCN3py0+ksCgYmGG8jKeGA/8Tk+gJwSYYlFGxG5lmhfKNoAy004YpLxpS1W2J8m/EK2Ew+yOs9pVRwO89mg== + dependencies: + "@babel/helper-validator-identifier" "^7.22.20" + chalk "^2.4.2" + js-tokens "^4.0.0" + "@babel/parser@^7.1.0", "@babel/parser@^7.14.7", "@babel/parser@^7.20.7", "@babel/parser@^7.21.4": version "7.21.4" resolved "https://registry.yarnpkg.com/@babel/parser/-/parser-7.21.4.tgz#94003fdfc520bbe2875d4ae557b43ddb6d880f17" @@ -497,11 +530,16 @@ resolved "https://registry.yarnpkg.com/@babel/parser/-/parser-7.18.5.tgz#337062363436a893a2d22faa60be5bb37091c83c" integrity sha512-YZWVaglMiplo7v8f1oMQ5ZPQr0vn7HPeZXxXWsxXJRjGVrzUFn9OxFQl1sb5wzfootjA/yChhW84BV+383FSOw== -"@babel/parser@^7.18.6", "@babel/parser@^7.18.8": +"@babel/parser@^7.18.6": version "7.18.8" resolved "https://registry.yarnpkg.com/@babel/parser/-/parser-7.18.8.tgz#822146080ac9c62dac0823bb3489622e0bc1cbdf" integrity sha512-RSKRfYX20dyH+elbJK2uqAkVyucL+xXzhqlMD5/ZXx+dAAwpyB7HsvnHe/ZUGOF+xLr5Wx9/JoXVTj6BQE2/oA== +"@babel/parser@^7.22.15", "@babel/parser@^7.23.0": + version "7.23.0" + resolved "https://registry.yarnpkg.com/@babel/parser/-/parser-7.23.0.tgz#da950e622420bf96ca0d0f2909cdddac3acd8719" + integrity sha512-vvPKKdMemU85V9WE/l5wZEmImpCtLqbnTvqDS2U1fJ96KrxoW7KrXhNsNCblQlg8Ck4b85yxdTyelsMUgFUXiw== + "@babel/plugin-proposal-async-generator-functions@^7.0.0": version "7.18.6" resolved "https://registry.yarnpkg.com/@babel/plugin-proposal-async-generator-functions/-/plugin-proposal-async-generator-functions-7.18.6.tgz#aedac81e6fc12bb643374656dd5f2605bf743d17" @@ -1016,51 +1054,28 @@ "@babel/parser" "^7.20.7" "@babel/types" "^7.20.7" -"@babel/traverse@^7.13.0", "@babel/traverse@^7.14.0", 
"@babel/traverse@^7.16.8", "@babel/traverse@^7.18.0", "@babel/traverse@^7.18.2", "@babel/traverse@^7.18.5": - version "7.18.5" - resolved "https://registry.yarnpkg.com/@babel/traverse/-/traverse-7.18.5.tgz#94a8195ad9642801837988ab77f36e992d9a20cd" - integrity sha512-aKXj1KT66sBj0vVzk6rEeAO6Z9aiiQ68wfDgge3nHhA/my6xMM/7HGQUNumKZaoa2qUPQ5whJG9aAifsxUKfLA== - dependencies: - "@babel/code-frame" "^7.16.7" - "@babel/generator" "^7.18.2" - "@babel/helper-environment-visitor" "^7.18.2" - "@babel/helper-function-name" "^7.17.9" - "@babel/helper-hoist-variables" "^7.16.7" - "@babel/helper-split-export-declaration" "^7.16.7" - "@babel/parser" "^7.18.5" - "@babel/types" "^7.18.4" - debug "^4.1.0" - globals "^11.1.0" - -"@babel/traverse@^7.18.6": - version "7.18.8" - resolved "https://registry.yarnpkg.com/@babel/traverse/-/traverse-7.18.8.tgz#f095e62ab46abf1da35e5a2011f43aee72d8d5b0" - integrity sha512-UNg/AcSySJYR/+mIcJQDCv00T+AqRO7j/ZEJLzpaYtgM48rMg5MnkJgyNqkzo88+p4tfRvZJCEiwwfG6h4jkRg== - dependencies: - "@babel/code-frame" "^7.18.6" - "@babel/generator" "^7.18.7" - "@babel/helper-environment-visitor" "^7.18.6" - "@babel/helper-function-name" "^7.18.6" - "@babel/helper-hoist-variables" "^7.18.6" - "@babel/helper-split-export-declaration" "^7.18.6" - "@babel/parser" "^7.18.8" - "@babel/types" "^7.18.8" - debug "^4.1.0" - globals "^11.1.0" - -"@babel/traverse@^7.21.0", "@babel/traverse@^7.21.2", "@babel/traverse@^7.21.4", "@babel/traverse@^7.7.2": - version "7.21.4" - resolved "https://registry.yarnpkg.com/@babel/traverse/-/traverse-7.21.4.tgz#a836aca7b116634e97a6ed99976236b3282c9d36" - integrity sha512-eyKrRHKdyZxqDm+fV1iqL9UAHMoIg0nDaGqfIOd8rKH17m5snv7Gn4qgjBoFfLz9APvjFU/ICT00NVCv1Epp8Q== - dependencies: - "@babel/code-frame" "^7.21.4" - "@babel/generator" "^7.21.4" - "@babel/helper-environment-visitor" "^7.18.9" - "@babel/helper-function-name" "^7.21.0" - "@babel/helper-hoist-variables" "^7.18.6" - "@babel/helper-split-export-declaration" "^7.18.6" - "@babel/parser" "^7.21.4" - "@babel/types" "^7.21.4" +"@babel/template@^7.22.15": + version "7.22.15" + resolved "https://registry.yarnpkg.com/@babel/template/-/template-7.22.15.tgz#09576efc3830f0430f4548ef971dde1350ef2f38" + integrity sha512-QPErUVm4uyJa60rkI73qneDacvdvzxshT3kksGqlGWYdOTIUOwJ7RDUL8sGqslY1uXWSL6xMFKEXDS3ox2uF0w== + dependencies: + "@babel/code-frame" "^7.22.13" + "@babel/parser" "^7.22.15" + "@babel/types" "^7.22.15" + +"@babel/traverse@^7.13.0", "@babel/traverse@^7.14.0", "@babel/traverse@^7.16.8", "@babel/traverse@^7.18.0", "@babel/traverse@^7.18.2", "@babel/traverse@^7.18.5", "@babel/traverse@^7.18.6", "@babel/traverse@^7.21.0", "@babel/traverse@^7.21.2", "@babel/traverse@^7.21.4", "@babel/traverse@^7.7.2": + version "7.23.2" + resolved "https://registry.yarnpkg.com/@babel/traverse/-/traverse-7.23.2.tgz#329c7a06735e144a506bdb2cad0268b7f46f4ad8" + integrity sha512-azpe59SQ48qG6nu2CzcMLbxUudtN+dOM9kDbUqGq3HXUJRlo7i8fvPoxQUzYgLZ4cMVmuZgm8vvBpNeRhd6XSw== + dependencies: + "@babel/code-frame" "^7.22.13" + "@babel/generator" "^7.23.0" + "@babel/helper-environment-visitor" "^7.22.20" + "@babel/helper-function-name" "^7.23.0" + "@babel/helper-hoist-variables" "^7.22.5" + "@babel/helper-split-export-declaration" "^7.22.6" + "@babel/parser" "^7.23.0" + "@babel/types" "^7.23.0" debug "^4.1.0" globals "^11.1.0" @@ -1072,7 +1087,7 @@ "@babel/helper-validator-identifier" "^7.16.7" to-fast-properties "^2.0.0" -"@babel/types@^7.18.6", "@babel/types@^7.18.7", "@babel/types@^7.18.8": +"@babel/types@^7.18.6": version "7.18.8" resolved 
"https://registry.yarnpkg.com/@babel/types/-/types-7.18.8.tgz#c5af199951bf41ba4a6a9a6d0d8ad722b30cd42f" integrity sha512-qwpdsmraq0aJ3osLJRApsc2ouSJCdnMeZwB0DhbtHAtRpZNZCdlbRnHIgcRKzdE1g0iOGg644fzjOBcdOz9cPw== @@ -1089,6 +1104,15 @@ "@babel/helper-validator-identifier" "^7.19.1" to-fast-properties "^2.0.0" +"@babel/types@^7.22.15", "@babel/types@^7.22.5", "@babel/types@^7.23.0": + version "7.23.0" + resolved "https://registry.yarnpkg.com/@babel/types/-/types-7.23.0.tgz#8c1f020c9df0e737e4e247c0619f58c68458aaeb" + integrity sha512-0oIyUfKoI3mSqMvsxBdclDwxXKXAUA8v/apZbc+iSyARYou1o8ZGDxbUYyLFoW2arqS2jDGqJuZvv1d/io1axg== + dependencies: + "@babel/helper-string-parser" "^7.22.5" + "@babel/helper-validator-identifier" "^7.22.20" + to-fast-properties "^2.0.0" + "@bcoe/v8-coverage@^0.2.3": version "0.2.3" resolved "https://registry.yarnpkg.com/@bcoe/v8-coverage/-/v8-coverage-0.2.3.tgz#75a2e8b51cb758a7553d6804a5932d7aace75c39" @@ -2050,6 +2074,11 @@ balanced-match@^1.0.0: resolved "https://registry.yarnpkg.com/balanced-match/-/balanced-match-1.0.2.tgz#e83e3a7e3f300b34cb9d87f615fa0cbf357690ee" integrity sha512-3oSeUO0TMV67hN1AmbXsK4yaqU7tjiHlbxRDZOpH0KW9+CeX4bRAaX0Anxt0tx2MrpRpWwQaPwIlISEJhYU5Pw== +base-64@^0.1.0: + version "0.1.0" + resolved "https://registry.yarnpkg.com/base-64/-/base-64-0.1.0.tgz#780a99c84e7d600260361511c4877613bf24f6bb" + integrity sha512-Y5gU45svrR5tI2Vt/X9GPd3L0HNIKzGu202EjxrXMpuc2V2CiKgemAbUUsqYmZJvPtCXoUKjNZwBJzsNScUbXA== + base64-js@^1.1.2, base64-js@^1.3.1, base64-js@^1.5.1: version "1.5.1" resolved "https://registry.yarnpkg.com/base64-js/-/base64-js-1.5.1.tgz#1b1b440160a5bf7ad40b650f095963481903930a" @@ -2260,7 +2289,7 @@ caniuse-lite@^1.0.30001449: resolved "https://registry.yarnpkg.com/caniuse-lite/-/caniuse-lite-1.0.30001478.tgz#0ef8a1cf8b16be47a0f9fc4ecfc952232724b32a" integrity sha512-gMhDyXGItTHipJj2ApIvR+iVB5hd0KP3svMWWXDvZOmjzJJassGLMfxRkQCSYgGd2gtdL/ReeiyvMSFD1Ss6Mw== -chalk@^2.0.0: +chalk@^2.0.0, chalk@^2.4.2: version "2.4.2" resolved "https://registry.yarnpkg.com/chalk/-/chalk-2.4.2.tgz#cd42541677a54333cf541a49108c1432b44c9424" integrity sha512-Mti+f9lpJNcwF4tWV8/OrTTtF1gZi+f8FqlyAdouralcFWFQWF2+NgCHShjkCb+IFBLq9buZwE1xckQU4peSuQ== @@ -5205,6 +5234,14 @@ react-native-codegen@^0.69.2: jscodeshift "^0.13.1" nullthrows "^1.1.1" +react-native-fs@^2.20.0: + version "2.20.0" + resolved "https://registry.yarnpkg.com/react-native-fs/-/react-native-fs-2.20.0.tgz#05a9362b473bfc0910772c0acbb73a78dbc810f6" + integrity sha512-VkTBzs7fIDUiy/XajOSNk0XazFE9l+QlMAce7lGuebZcag5CnjszB+u4BdqzwaQOdcYb5wsJIsqq4kxInIRpJQ== + dependencies: + base-64 "^0.1.0" + utf8 "^3.0.0" + react-native-gradle-plugin@^0.0.7: version "0.0.7" resolved "https://registry.yarnpkg.com/react-native-gradle-plugin/-/react-native-gradle-plugin-0.0.7.tgz#96602f909745239deab7b589443f14fce5da2056" @@ -6167,6 +6204,11 @@ utf8-byte-length@^1.0.1: resolved "https://registry.yarnpkg.com/utf8-byte-length/-/utf8-byte-length-1.0.4.tgz#f45f150c4c66eee968186505ab93fcbb8ad6bf61" integrity sha512-4+wkEYLBbWxqTahEsWrhxepcoVOJ+1z5PGIjPZxRkytcdSUaNjIjBM7Xn8E+pdSuV7SzvWovBFA54FO0JSoqhA== +utf8@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/utf8/-/utf8-3.0.0.tgz#f052eed1364d696e769ef058b183df88c87f69d1" + integrity sha512-E8VjFIQ/TyQgp+TZfS6l8yp/xWppSAHzidGiRrqe4bK4XP9pTRyKFgGJpO3SN7zdX4DeomTrwaseCHovfpFcqQ== + util-deprecate@^1.0.1, util-deprecate@~1.0.1: version "1.0.2" resolved "https://registry.yarnpkg.com/util-deprecate/-/util-deprecate-1.0.2.tgz#450d4dc9fa70de732762fbd2d4a28981419a0ccf" diff --git 
a/js/react_native/yarn.lock b/js/react_native/yarn.lock index 0b47158ab705b..ff9be7fbe3a5b 100644 --- a/js/react_native/yarn.lock +++ b/js/react_native/yarn.lock @@ -24,6 +24,14 @@ dependencies: "@babel/highlight" "^7.18.6" +"@babel/code-frame@^7.22.13": + version "7.22.13" + resolved "https://registry.yarnpkg.com/@babel/code-frame/-/code-frame-7.22.13.tgz#e3c1c099402598483b7a8c46a721d1038803755e" + integrity sha512-XktuhWlJ5g+3TJXc5upd9Ks1HutSArik6jf2eAjYFyIOf4ej3RN+184cZbzDvbPnuTJIUhPKKJE3cIsYTiAT3w== + dependencies: + "@babel/highlight" "^7.22.13" + chalk "^2.4.2" + "@babel/code-frame@~7.10.4": version "7.10.4" resolved "https://registry.yarnpkg.com/@babel/code-frame/-/code-frame-7.10.4.tgz#168da1a36e90da68ae8d49c0f1b48c7c6249213a" @@ -66,13 +74,14 @@ "@jridgewell/gen-mapping" "^0.3.0" jsesc "^2.5.1" -"@babel/generator@^7.18.7": - version "7.18.7" - resolved "https://registry.yarnpkg.com/@babel/generator/-/generator-7.18.7.tgz#2aa78da3c05aadfc82dbac16c99552fc802284bd" - integrity sha512-shck+7VLlY72a2w9c3zYWuE1pwOKEiQHV7GTUbSnhyl5eu3i04t30tBY82ZRWrDfo3gkakCFtevExnxbkf2a3A== +"@babel/generator@^7.23.0": + version "7.23.0" + resolved "https://registry.yarnpkg.com/@babel/generator/-/generator-7.23.0.tgz#df5c386e2218be505b34837acbcb874d7a983420" + integrity sha512-lN85QRR+5IbYrMWM6Y4pE/noaQtg4pNiqeNGX60eqOfo6gtEj6uw/JagelB8vVztSd7R6M5n1+PQkDbHbBRU4g== dependencies: - "@babel/types" "^7.18.7" + "@babel/types" "^7.23.0" "@jridgewell/gen-mapping" "^0.3.2" + "@jridgewell/trace-mapping" "^0.3.17" jsesc "^2.5.1" "@babel/helper-annotate-as-pure@^7.16.7": @@ -160,6 +169,11 @@ resolved "https://registry.yarnpkg.com/@babel/helper-environment-visitor/-/helper-environment-visitor-7.18.6.tgz#b7eee2b5b9d70602e59d1a6cad7dd24de7ca6cd7" integrity sha512-8n6gSfn2baOY+qlp+VSzsosjCVGFqWKmDF0cCWOybh52Dw3SEyoWR1KrhMJASjLwIEkkAufZ0xvr+SxLHSpy2Q== +"@babel/helper-environment-visitor@^7.22.20": + version "7.22.20" + resolved "https://registry.yarnpkg.com/@babel/helper-environment-visitor/-/helper-environment-visitor-7.22.20.tgz#96159db61d34a29dba454c959f5ae4a649ba9167" + integrity sha512-zfedSIzFhat/gFhWfHtgWvlec0nqB9YEIVrpuwjruLlXfUSnA8cJB0miHKwqDnQ7d32aKo2xt88/xZptwxbfhA== + "@babel/helper-explode-assignable-expression@^7.16.7": version "7.16.7" resolved "https://registry.yarnpkg.com/@babel/helper-explode-assignable-expression/-/helper-explode-assignable-expression-7.16.7.tgz#12a6d8522fdd834f194e868af6354e8650242b7a" @@ -183,6 +197,14 @@ "@babel/template" "^7.18.6" "@babel/types" "^7.18.6" +"@babel/helper-function-name@^7.23.0": + version "7.23.0" + resolved "https://registry.yarnpkg.com/@babel/helper-function-name/-/helper-function-name-7.23.0.tgz#1f9a3cdbd5b2698a670c30d2735f9af95ed52759" + integrity sha512-OErEqsrxjZTJciZ4Oo+eoZqeW9UIiOcuYKRJA4ZAgV9myA+pOXhhmpfNCKjEH/auVfEYVFJ6y1Tc4r0eIApqiw== + dependencies: + "@babel/template" "^7.22.15" + "@babel/types" "^7.23.0" + "@babel/helper-hoist-variables@^7.16.7": version "7.16.7" resolved "https://registry.yarnpkg.com/@babel/helper-hoist-variables/-/helper-hoist-variables-7.16.7.tgz#86bcb19a77a509c7b77d0e22323ef588fa58c246" @@ -190,12 +212,12 @@ dependencies: "@babel/types" "^7.16.7" -"@babel/helper-hoist-variables@^7.18.6": - version "7.18.6" - resolved "https://registry.yarnpkg.com/@babel/helper-hoist-variables/-/helper-hoist-variables-7.18.6.tgz#d4d2c8fb4baeaa5c68b99cc8245c56554f926678" - integrity sha512-UlJQPkFqFULIcyW5sbzgbkxn2FKRgwWiRexcuaR8RNJRy8+LLveqPjwZV/bwrLZCN0eUHD/x8D0heK1ozuoo6Q== +"@babel/helper-hoist-variables@^7.22.5": + version "7.22.5" + resolved 
"https://registry.yarnpkg.com/@babel/helper-hoist-variables/-/helper-hoist-variables-7.22.5.tgz#c01a007dac05c085914e8fb652b339db50d823bb" + integrity sha512-wGjk9QZVzvknA6yKIUURb8zY3grXCcOZt+/7Wcy8O2uctxhplmUPkOdlgoNhmdVee2c92JXbf1xpMtVNbfoxRw== dependencies: - "@babel/types" "^7.18.6" + "@babel/types" "^7.22.5" "@babel/helper-member-expression-to-functions@^7.17.7": version "7.17.7" @@ -293,12 +315,17 @@ dependencies: "@babel/types" "^7.16.7" -"@babel/helper-split-export-declaration@^7.18.6": - version "7.18.6" - resolved "https://registry.yarnpkg.com/@babel/helper-split-export-declaration/-/helper-split-export-declaration-7.18.6.tgz#7367949bc75b20c6d5a5d4a97bba2824ae8ef075" - integrity sha512-bde1etTx6ZyTmobl9LLMMQsaizFVZrquTEHOqKeQESMKo4PlObf+8+JA25ZsIpZhT/WEd39+vOdLXAFG/nELpA== +"@babel/helper-split-export-declaration@^7.22.6": + version "7.22.6" + resolved "https://registry.yarnpkg.com/@babel/helper-split-export-declaration/-/helper-split-export-declaration-7.22.6.tgz#322c61b7310c0997fe4c323955667f18fcefb91c" + integrity sha512-AsUnxuLhRYsisFiaJwvp1QF+I3KjD5FOxut14q/GzovUe6orHLesW2C7d754kRm53h5gqrz6sFl6sxc4BVtE/g== dependencies: - "@babel/types" "^7.18.6" + "@babel/types" "^7.22.5" + +"@babel/helper-string-parser@^7.22.5": + version "7.22.5" + resolved "https://registry.yarnpkg.com/@babel/helper-string-parser/-/helper-string-parser-7.22.5.tgz#533f36457a25814cf1df6488523ad547d784a99f" + integrity sha512-mM4COjgZox8U+JcXQwPijIZLElkgEpO5rsERVDJTc2qfCDfERyob6k5WegS14SX18IIjv+XD+GrqNumY5JRCDw== "@babel/helper-validator-identifier@^7.16.7": version "7.19.1" @@ -310,6 +337,11 @@ resolved "https://registry.yarnpkg.com/@babel/helper-validator-identifier/-/helper-validator-identifier-7.18.6.tgz#9c97e30d31b2b8c72a1d08984f2ca9b574d7a076" integrity sha512-MmetCkz9ej86nJQV+sFCxoGGrUbU3q02kgLciwkrt9QqEB7cP39oKEY0PakknEO0Gu20SskMRi+AYZ3b1TpN9g== +"@babel/helper-validator-identifier@^7.22.20": + version "7.22.20" + resolved "https://registry.yarnpkg.com/@babel/helper-validator-identifier/-/helper-validator-identifier-7.22.20.tgz#c4ae002c61d2879e724581d96665583dbc1dc0e0" + integrity sha512-Y4OZ+ytlatR8AI+8KZfKuL5urKp7qey08ha31L8b3BwewJAoJamTzyvxPR/5D+KkdJCGPq/+8TukHBlY10FX9A== + "@babel/helper-validator-option@^7.16.7": version "7.16.7" resolved "https://registry.yarnpkg.com/@babel/helper-validator-option/-/helper-validator-option-7.16.7.tgz#b203ce62ce5fe153899b617c08957de860de4d23" @@ -362,16 +394,30 @@ chalk "^2.0.0" js-tokens "^4.0.0" +"@babel/highlight@^7.22.13": + version "7.22.20" + resolved "https://registry.yarnpkg.com/@babel/highlight/-/highlight-7.22.20.tgz#4ca92b71d80554b01427815e06f2df965b9c1f54" + integrity sha512-dkdMCN3py0+ksCgYmGG8jKeGA/8Tk+gJwSYYlFGxG5lmhfKNoAy004YpLxpS1W2J8m/EK2Ew+yOs9pVRwO89mg== + dependencies: + "@babel/helper-validator-identifier" "^7.22.20" + chalk "^2.4.2" + js-tokens "^4.0.0" + "@babel/parser@^7.1.0", "@babel/parser@^7.13.16", "@babel/parser@^7.14.0", "@babel/parser@^7.14.7", "@babel/parser@^7.16.7", "@babel/parser@^7.18.0": version "7.18.3" resolved "https://registry.yarnpkg.com/@babel/parser/-/parser-7.18.3.tgz#39e99c7b0c4c56cef4d1eed8de9f506411c2ebc2" integrity sha512-rL50YcEuHbbauAFAysNsJA4/f89fGTOBRNs9P81sniKnKAr4xULe5AecolcsKbi88xu0ByWYDj/S1AJ3FSFuSQ== -"@babel/parser@^7.18.6", "@babel/parser@^7.18.8": +"@babel/parser@^7.18.6": version "7.18.8" resolved "https://registry.yarnpkg.com/@babel/parser/-/parser-7.18.8.tgz#822146080ac9c62dac0823bb3489622e0bc1cbdf" integrity 
sha512-RSKRfYX20dyH+elbJK2uqAkVyucL+xXzhqlMD5/ZXx+dAAwpyB7HsvnHe/ZUGOF+xLr5Wx9/JoXVTj6BQE2/oA== +"@babel/parser@^7.22.15", "@babel/parser@^7.23.0": + version "7.23.0" + resolved "https://registry.yarnpkg.com/@babel/parser/-/parser-7.23.0.tgz#da950e622420bf96ca0d0f2909cdddac3acd8719" + integrity sha512-vvPKKdMemU85V9WE/l5wZEmImpCtLqbnTvqDS2U1fJ96KrxoW7KrXhNsNCblQlg8Ck4b85yxdTyelsMUgFUXiw== + "@babel/plugin-bugfix-safari-id-destructuring-collision-in-function-expression@^7.17.12": version "7.17.12" resolved "https://registry.yarnpkg.com/@babel/plugin-bugfix-safari-id-destructuring-collision-in-function-expression/-/plugin-bugfix-safari-id-destructuring-collision-in-function-expression-7.17.12.tgz#1dca338caaefca368639c9ffb095afbd4d420b1e" @@ -1182,35 +1228,28 @@ "@babel/parser" "^7.18.6" "@babel/types" "^7.18.6" -"@babel/traverse@^7.13.0", "@babel/traverse@^7.14.0", "@babel/traverse@^7.16.8", "@babel/traverse@^7.18.0", "@babel/traverse@^7.18.2", "@babel/traverse@^7.7.2": - version "7.18.2" - resolved "https://registry.yarnpkg.com/@babel/traverse/-/traverse-7.18.2.tgz#b77a52604b5cc836a9e1e08dca01cba67a12d2e8" - integrity sha512-9eNwoeovJ6KH9zcCNnENY7DMFwTU9JdGCFtqNLfUAqtUHRCOsTOqWoffosP8vKmNYeSBUv3yVJXjfd8ucwOjUA== - dependencies: - "@babel/code-frame" "^7.16.7" - "@babel/generator" "^7.18.2" - "@babel/helper-environment-visitor" "^7.18.2" - "@babel/helper-function-name" "^7.17.9" - "@babel/helper-hoist-variables" "^7.16.7" - "@babel/helper-split-export-declaration" "^7.16.7" - "@babel/parser" "^7.18.0" - "@babel/types" "^7.18.2" - debug "^4.1.0" - globals "^11.1.0" - -"@babel/traverse@^7.18.6": - version "7.18.8" - resolved "https://registry.yarnpkg.com/@babel/traverse/-/traverse-7.18.8.tgz#f095e62ab46abf1da35e5a2011f43aee72d8d5b0" - integrity sha512-UNg/AcSySJYR/+mIcJQDCv00T+AqRO7j/ZEJLzpaYtgM48rMg5MnkJgyNqkzo88+p4tfRvZJCEiwwfG6h4jkRg== - dependencies: - "@babel/code-frame" "^7.18.6" - "@babel/generator" "^7.18.7" - "@babel/helper-environment-visitor" "^7.18.6" - "@babel/helper-function-name" "^7.18.6" - "@babel/helper-hoist-variables" "^7.18.6" - "@babel/helper-split-export-declaration" "^7.18.6" - "@babel/parser" "^7.18.8" - "@babel/types" "^7.18.8" +"@babel/template@^7.22.15": + version "7.22.15" + resolved "https://registry.yarnpkg.com/@babel/template/-/template-7.22.15.tgz#09576efc3830f0430f4548ef971dde1350ef2f38" + integrity sha512-QPErUVm4uyJa60rkI73qneDacvdvzxshT3kksGqlGWYdOTIUOwJ7RDUL8sGqslY1uXWSL6xMFKEXDS3ox2uF0w== + dependencies: + "@babel/code-frame" "^7.22.13" + "@babel/parser" "^7.22.15" + "@babel/types" "^7.22.15" + +"@babel/traverse@^7.13.0", "@babel/traverse@^7.14.0", "@babel/traverse@^7.16.8", "@babel/traverse@^7.18.0", "@babel/traverse@^7.18.2", "@babel/traverse@^7.18.6", "@babel/traverse@^7.7.2": + version "7.23.2" + resolved "https://registry.yarnpkg.com/@babel/traverse/-/traverse-7.23.2.tgz#329c7a06735e144a506bdb2cad0268b7f46f4ad8" + integrity sha512-azpe59SQ48qG6nu2CzcMLbxUudtN+dOM9kDbUqGq3HXUJRlo7i8fvPoxQUzYgLZ4cMVmuZgm8vvBpNeRhd6XSw== + dependencies: + "@babel/code-frame" "^7.22.13" + "@babel/generator" "^7.23.0" + "@babel/helper-environment-visitor" "^7.22.20" + "@babel/helper-function-name" "^7.23.0" + "@babel/helper-hoist-variables" "^7.22.5" + "@babel/helper-split-export-declaration" "^7.22.6" + "@babel/parser" "^7.23.0" + "@babel/types" "^7.23.0" debug "^4.1.0" globals "^11.1.0" @@ -1222,7 +1261,7 @@ "@babel/helper-validator-identifier" "^7.16.7" to-fast-properties "^2.0.0" -"@babel/types@^7.18.6", "@babel/types@^7.18.7", "@babel/types@^7.18.8": 
+"@babel/types@^7.18.6": version "7.18.8" resolved "https://registry.yarnpkg.com/@babel/types/-/types-7.18.8.tgz#c5af199951bf41ba4a6a9a6d0d8ad722b30cd42f" integrity sha512-qwpdsmraq0aJ3osLJRApsc2ouSJCdnMeZwB0DhbtHAtRpZNZCdlbRnHIgcRKzdE1g0iOGg644fzjOBcdOz9cPw== @@ -1230,6 +1269,15 @@ "@babel/helper-validator-identifier" "^7.18.6" to-fast-properties "^2.0.0" +"@babel/types@^7.22.15", "@babel/types@^7.22.5", "@babel/types@^7.23.0": + version "7.23.0" + resolved "https://registry.yarnpkg.com/@babel/types/-/types-7.23.0.tgz#8c1f020c9df0e737e4e247c0619f58c68458aaeb" + integrity sha512-0oIyUfKoI3mSqMvsxBdclDwxXKXAUA8v/apZbc+iSyARYou1o8ZGDxbUYyLFoW2arqS2jDGqJuZvv1d/io1axg== + dependencies: + "@babel/helper-string-parser" "^7.22.5" + "@babel/helper-validator-identifier" "^7.22.20" + to-fast-properties "^2.0.0" + "@bcoe/v8-coverage@^0.2.3": version "0.2.3" resolved "https://registry.yarnpkg.com/@bcoe/v8-coverage/-/v8-coverage-0.2.3.tgz#75a2e8b51cb758a7553d6804a5932d7aace75c39" @@ -1530,6 +1578,11 @@ resolved "https://registry.yarnpkg.com/@jridgewell/resolve-uri/-/resolve-uri-3.0.7.tgz#30cd49820a962aff48c8fffc5cd760151fca61fe" integrity sha512-8cXDaBBHOr2pQ7j77Y6Vp5VDT2sIqWyWQ56TjEq4ih/a4iST3dItRe8Q9fp0rrIl9DoKhWQtUQz/YpOxLkXbNA== +"@jridgewell/resolve-uri@^3.1.0": + version "3.1.1" + resolved "https://registry.yarnpkg.com/@jridgewell/resolve-uri/-/resolve-uri-3.1.1.tgz#c08679063f279615a3326583ba3a90d1d82cc721" + integrity sha512-dSYZh7HhCDtCKm4QakX0xFpsRDqjjtZf/kjI/v3T3Nwt5r8/qz/M19F9ySyOqU94SXBmeG9ttTul+YnR4LOxFA== + "@jridgewell/set-array@^1.0.0": version "1.1.1" resolved "https://registry.yarnpkg.com/@jridgewell/set-array/-/set-array-1.1.1.tgz#36a6acc93987adcf0ba50c66908bd0b70de8afea" @@ -1545,6 +1598,19 @@ resolved "https://registry.yarnpkg.com/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.4.13.tgz#b6461fb0c2964356c469e115f504c95ad97ab88c" integrity sha512-GryiOJmNcWbovBxTfZSF71V/mXbgcV3MewDe3kIMCLyIh5e7SKAeUZs+rMnJ8jkMolZ/4/VsdBmMrw3l+VdZ3w== +"@jridgewell/sourcemap-codec@^1.4.14": + version "1.4.15" + resolved "https://registry.yarnpkg.com/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.4.15.tgz#d7c6e6755c78567a951e04ab52ef0fd26de59f32" + integrity sha512-eF2rxCRulEKXHTRiDrDy6erMYWqNw4LPdQ8UQA4huuxaQsVeRPFl2oM8oDGxMFhJUWZf9McpLtJasDDZb/Bpeg== + +"@jridgewell/trace-mapping@^0.3.17": + version "0.3.19" + resolved "https://registry.yarnpkg.com/@jridgewell/trace-mapping/-/trace-mapping-0.3.19.tgz#f8a3249862f91be48d3127c3cfe992f79b4b8811" + integrity sha512-kf37QtfW+Hwx/buWGMPcR60iF9ziHa6r/CZJIHbmcm4+0qrXiVdxegAH0F6yddEVQ7zdkjcGCgCzUu+BcbhQxw== + dependencies: + "@jridgewell/resolve-uri" "^3.1.0" + "@jridgewell/sourcemap-codec" "^1.4.14" + "@jridgewell/trace-mapping@^0.3.9": version "0.3.13" resolved "https://registry.yarnpkg.com/@jridgewell/trace-mapping/-/trace-mapping-0.3.13.tgz#dcfe3e95f224c8fe97a87a5235defec999aa92ea" @@ -2470,7 +2536,7 @@ caniuse-lite@^1.0.30001332: resolved "https://registry.yarnpkg.com/caniuse-lite/-/caniuse-lite-1.0.30001342.tgz#87152b1e3b950d1fbf0093e23f00b6c8e8f1da96" integrity sha512-bn6sOCu7L7jcbBbyNhLg0qzXdJ/PMbybZTH/BA6Roet9wxYRm6Tr9D0s0uhLkOZ6MSG+QU6txUgdpr3MXIVqjA== -chalk@^2.0.0: +chalk@^2.0.0, chalk@^2.4.2: version "2.4.2" resolved "https://registry.yarnpkg.com/chalk/-/chalk-2.4.2.tgz#cd42541677a54333cf541a49108c1432b44c9424" integrity sha512-Mti+f9lpJNcwF4tWV8/OrTTtF1gZi+f8FqlyAdouralcFWFQWF2+NgCHShjkCb+IFBLq9buZwE1xckQU4peSuQ== diff --git a/js/web/.npmignore b/js/web/.npmignore index 16f487f5ff4cd..0f018f525a8d6 100644 --- a/js/web/.npmignore +++ 
b/js/web/.npmignore @@ -4,6 +4,26 @@ /dist/**/*.report.html +# We remove some of the files in NPM packages because restrictions in jsdelivr: +# +# "Packages larger than 150 MB or single files larger than 20 MB (in the case of GitHub) are not supported" +# +# from https://www.jsdelivr.com/documentation +# +# We only include development build in the NPM package for the following targets: +# - /dist/ort.js +# - /dist/ort.all.js +# +/dist/cjs/ort.js +/dist/esm/ort.js +/dist/cjs/ort.all.js +/dist/esm/ort.all.js +/dist/**/ort.wasm.js +/dist/**/ort.wasm-core.js +/dist/**/ort.webgl.js +/dist/**/ort.webgpu.js +/dist/**/ort.training.wasm.js + /types/ karma.conf.js diff --git a/js/web/docs/webgl-operators.md b/js/web/docs/webgl-operators.md index de84134ddbb3f..7c129b66bfa3d 100644 --- a/js/web/docs/webgl-operators.md +++ b/js/web/docs/webgl-operators.md @@ -12,6 +12,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat | [Acos](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Acos) | [7+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Acos-7) | | [Acosh](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Acosh) | | | [Add](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Add) | [7-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Add-7), [13](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Add-13), [14+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Add-14) | +| [AffineGrid](https://github.com/onnx/onnx/blob/main/docs/Operators.md#AffineGrid) | | | [And](https://github.com/onnx/onnx/blob/main/docs/Operators.md#And) | [7+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#And-7) | | [ArgMax](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ArgMax) | | | [ArgMin](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ArgMin) | | @@ -67,6 +68,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat | [Gather](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Gather) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Gather-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Gather-11), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Gather-13) | | [GatherElements](https://github.com/onnx/onnx/blob/main/docs/Operators.md#GatherElements) | | | [GatherND](https://github.com/onnx/onnx/blob/main/docs/Operators.md#GatherND) | | +| [Gelu](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Gelu) | | | [Gemm](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Gemm) | [7-8](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Gemm-7), [9-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Gemm-9), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Gemm-11), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Gemm-13) | | [GlobalAveragePool](https://github.com/onnx/onnx/blob/main/docs/Operators.md#GlobalAveragePool) | [1+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#GlobalAveragePool-1) | | [GlobalLpPool](https://github.com/onnx/onnx/blob/main/docs/Operators.md#GlobalLpPool) | | @@ -82,6 +84,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat | [Hardmax](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Hardmax) | | | [Identity](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Identity) | [1-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-1), 
[13](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-13), [14-15](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-14), [16-18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-16), [19+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-19) | | [If](https://github.com/onnx/onnx/blob/main/docs/Operators.md#If) | | +| [ImageDecoder](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ImageDecoder) | | | [InstanceNormalization](https://github.com/onnx/onnx/blob/main/docs/Operators.md#InstanceNormalization) | [6+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#InstanceNormalization-6) | | [IsInf](https://github.com/onnx/onnx/blob/main/docs/Operators.md#IsInf) | | | [IsNaN](https://github.com/onnx/onnx/blob/main/docs/Operators.md#IsNaN) | | @@ -137,12 +140,13 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat | [ReduceL2](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceL2) | | | [ReduceLogSum](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceLogSum) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceLogSum-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceLogSum-11), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceLogSum-13), [18+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceLogSum-18) | | [ReduceLogSumExp](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceLogSumExp) | | -| [ReduceMax](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceMax) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-1), [11](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-11), [12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-12), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-13), [18+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-18) | +| [ReduceMax](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceMax) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-1), [11](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-11), [12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-12), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-13), [18-19](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-18), [20+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-20) | | [ReduceMean](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceMean) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMean-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMean-11), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMean-13), [18+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMean-18) | -| [ReduceMin](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceMin) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-1), [11](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-11), [12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-12), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-13), [18+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-18) | +| [ReduceMin](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceMin) | 
[1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-1), [11](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-11), [12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-12), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-13), [18-19](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-18), [20+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-20) | | [ReduceProd](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceProd) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceProd-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceProd-11), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceProd-13), [18+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceProd-18) | | [ReduceSum](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceSum) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceSum-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceSum-11) | | [ReduceSumSquare](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceSumSquare) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceSumSquare-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceSumSquare-11), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceSumSquare-13), [18+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceSumSquare-18) | +| [RegexFullMatch](https://github.com/onnx/onnx/blob/main/docs/Operators.md#RegexFullMatch) | | | [Relu](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Relu) | [6-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Relu-6), [13](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Relu-13), [14+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Relu-14) | | [Reshape](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Reshape) | [5-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-5), [13](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-13), [14-18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-14), [19+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-19) | | [Resize](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Resize) | [10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-10), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-11), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-13), [18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-18), [19+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-19) | @@ -179,7 +183,9 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat | [SplitToSequence](https://github.com/onnx/onnx/blob/main/docs/Operators.md#SplitToSequence) | | | [Sqrt](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Sqrt) | [6-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sqrt-6), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sqrt-13) | | [Squeeze](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Squeeze) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Squeeze-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Squeeze-11), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Squeeze-13) | +| 
[StringConcat](https://github.com/onnx/onnx/blob/main/docs/Operators.md#StringConcat) | | | [StringNormalizer](https://github.com/onnx/onnx/blob/main/docs/Operators.md#StringNormalizer) | | +| [StringSplit](https://github.com/onnx/onnx/blob/main/docs/Operators.md#StringSplit) | | | [Sub](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Sub) | [7-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sub-7), [13](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sub-13), [14+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sub-14) | | [Sum](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Sum) | [6-7](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sum-6), [8-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sum-8), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sum-13) | | [Tan](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Tan) | [7+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Tan-7) | diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index 13244853c6c50..b246e19137888 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -21,15 +21,15 @@ Do not modify directly.* | Atan | ai.onnx(7+) | | | Atanh | ai.onnx(9+) | | | Attention | com.microsoft(1+) | need implementing mask and past/present | -| AveragePool | ai.onnx(7-9,10,11+); com.ms.internal.nhwc(11+) | need perf optimization; need implementing activation | +| AveragePool | ai.onnx(7-9,10,11+); com.ms.internal.nhwc(7-9,10,11+) | need perf optimization; need implementing activation | | BiasAdd | com.microsoft(1+) | | | BiasSplitGelu | com.microsoft(1+) | | | Cast | ai.onnx(6-8,9-12,13-18,19+) | | | Ceil | ai.onnx(6-12,13+) | | | Clip | ai.onnx(6-10,11,12,13+) | | | Concat | ai.onnx(1-3,4-10,11-12,13+) | | -| Conv | ai.onnx(1-10,11+); com.ms.internal.nhwc(11+) | need perf optimization; conv3d is not supported; need implementing activation | -| ConvTranspose | ai.onnx(1-10,11+); com.ms.internal.nhwc(11+) | need perf optimization; ConvTranspose3d is not supported; need implementing activation | +| Conv | ai.onnx(1-10,11+); com.ms.internal.nhwc(1-10,11+) | need perf optimization; conv3d is not supported; need implementing activation | +| ConvTranspose | ai.onnx(1-10,11+); com.ms.internal.nhwc(1-10,11+) | need perf optimization; ConvTranspose3d is not supported; need implementing activation | | Cos | ai.onnx(7+) | | | Cosh | ai.onnx(9+) | | | Div | ai.onnx(7-12,13,14+) | | @@ -41,6 +41,7 @@ Do not modify directly.* | Expand | ai.onnx(8-12,13+) | | | Flatten | ai.onnx(1-8,9-10,11-12,13+) | | | Floor | ai.onnx(6-12,13+) | | +| FusedConv | com.microsoft(1+) | | | Gather | ai.onnx(1-10,11-12,13+) | | | GatherElements | ai.onnx(11-12,13+) | | | Gelu | com.microsoft(1+) | | @@ -57,7 +58,7 @@ Do not modify directly.* | LessOrEqual | ai.onnx(12-15,16+) | | | Log | ai.onnx(6-12,13+) | | | MatMul | ai.onnx(1-12,13+) | | -| MaxPool | ai.onnx(1-7,8-9,10,11,12+); com.ms.internal.nhwc(11,12+) | need perf optimization; need implementing activation | +| MaxPool | ai.onnx(1-7,8-9,10,11,12+); com.ms.internal.nhwc(1-7,8-9,10,11,12+) | need perf optimization; need implementing activation | | MemcpyFromHost | ai.onnx(1+) | | | MemcpyToHost | ai.onnx(1+) | | | Mul | ai.onnx(7-12,13,14+) | | @@ -80,7 +81,7 @@ Do not modify directly.* | ReduceSumSquare | ai.onnx(1-10,11-12,13-17,18+) | | | Relu | ai.onnx(6-12,13,14+) | | | Reshape | ai.onnx(5-12,13,14+) | no GPU kernel | -| Resize | 
ai.onnx(10,11-12,13-17,18,19+); com.ms.internal.nhwc(11-12,13-17,18,19+) | CoordinateTransformMode align_corners is not supported with downsampling | +| Resize | ai.onnx(10,11-12,13-17,18,19+); com.ms.internal.nhwc(10,11-12,13-17,18,19+) | CoordinateTransformMode align_corners is not supported with downsampling | | Shape | ai.onnx(1-12,13-14,15+) | no GPU kernel; an ORT warning is generated - need to fix | | Sigmoid | ai.onnx(6-12,13+) | | | Sin | ai.onnx(7+) | | diff --git a/js/web/lib/backend-onnxjs.ts b/js/web/lib/backend-onnxjs.ts index 5ea7de809a495..7176823c9bf13 100644 --- a/js/web/lib/backend-onnxjs.ts +++ b/js/web/lib/backend-onnxjs.ts @@ -5,7 +5,7 @@ import {Backend, InferenceSession, InferenceSessionHandler} from 'onnxruntime-common'; import {Session} from './onnxjs/session'; -import {OnnxjsSessionHandler} from './onnxjs/session-handler'; +import {OnnxjsSessionHandler} from './onnxjs/session-handler-inference'; class OnnxjsBackend implements Backend { // eslint-disable-next-line @typescript-eslint/no-empty-function diff --git a/js/web/lib/backend-wasm-training.ts b/js/web/lib/backend-wasm-training.ts index af5b575c87a7f..09dac3a85311c 100644 --- a/js/web/lib/backend-wasm-training.ts +++ b/js/web/lib/backend-wasm-training.ts @@ -4,13 +4,17 @@ import {InferenceSession, TrainingSessionHandler} from 'onnxruntime-common'; import {OnnxruntimeWebAssemblyBackend} from './backend-wasm'; +import {OnnxruntimeWebAssemblyTrainingSessionHandler} from './wasm/session-handler-training'; class OnnxruntimeTrainingWebAssemblyBackend extends OnnxruntimeWebAssemblyBackend { async createTrainingSessionHandler( - _checkpointStateUriOrBuffer: string|Uint8Array, _trainModelUriOrBuffer: string|Uint8Array, - _evalModelUriOrBuffer: string|Uint8Array, _optimizerModelUriOrBuffer: string|Uint8Array, - _options: InferenceSession.SessionOptions): Promise { - throw new Error('Method not implemented yet.'); + checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: string|Uint8Array, + evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array, + options: InferenceSession.SessionOptions): Promise { + const handler = new OnnxruntimeWebAssemblyTrainingSessionHandler(); + await handler.createTrainingSession( + checkpointStateUriOrBuffer, trainModelUriOrBuffer, evalModelUriOrBuffer, optimizerModelUriOrBuffer, options); + return Promise.resolve(handler); } } diff --git a/js/web/lib/backend-wasm.ts b/js/web/lib/backend-wasm.ts index 5740263583031..78edcc90f55f9 100644 --- a/js/web/lib/backend-wasm.ts +++ b/js/web/lib/backend-wasm.ts @@ -5,7 +5,7 @@ import {cpus} from 'node:os'; import {Backend, env, InferenceSession, InferenceSessionHandler} from 'onnxruntime-common'; import {initializeWebAssemblyInstance} from './wasm/proxy-wrapper'; -import {OnnxruntimeWebAssemblySessionHandler} from './wasm/session-handler'; +import {OnnxruntimeWebAssemblySessionHandler} from './wasm/session-handler-inference'; /** * This function initializes all flags for WebAssembly. diff --git a/js/web/lib/index.ts b/js/web/lib/index.ts index c5c27a4318049..6060271ced156 100644 --- a/js/web/lib/index.ts +++ b/js/web/lib/index.ts @@ -7,6 +7,9 @@ // So we import code inside the if-clause to allow bundler remove the code safely. 
export * from 'onnxruntime-common'; +import * as ort from 'onnxruntime-common'; +export default ort; + import {registerBackend, env} from 'onnxruntime-common'; import {version} from './version'; diff --git a/js/web/lib/onnxjs/session-handler.ts b/js/web/lib/onnxjs/session-handler-inference.ts similarity index 100% rename from js/web/lib/onnxjs/session-handler.ts rename to js/web/lib/onnxjs/session-handler-inference.ts diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts index b7b2ff4537095..00431a4e86d5b 100644 --- a/js/web/lib/wasm/binding/ort-wasm.d.ts +++ b/js/web/lib/wasm/binding/ort-wasm.d.ts @@ -102,6 +102,11 @@ export interface OrtWasmModule extends EmscriptenModule { _OrtTrainingCopyParametersFromBuffer? (trainingHandle: number, parametersBuffer: number, parameterCount: number, trainableOnly: boolean): number; + _OrtTrainingGetModelInputOutputCount? + (trainingHandle: number, inputCount: number, outputCount: number, isEvalModel: boolean): number; + _OrtTrainingGetModelInputOutputName? + (trainingHandle: number, index: number, isInput: boolean, isEvalModel: boolean): number; + _OrtTrainingReleaseSession?(trainingHandle: number): void; // #endregion diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index 5d66caf77f08f..eb40da048835e 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -126,14 +126,14 @@ export class WebGpuBackend { */ kernels: Map unknown) | undefined, unknown]]>; - commandEncoder: GPUCommandEncoder|null = null; - computePassEncoder: GPUComputePassEncoder|null = null; + private commandEncoder: GPUCommandEncoder|null = null; + private computePassEncoder: GPUComputePassEncoder|null = null; pendingDispatchNumber = 0; - supportTimestampQuery = false; - profilingQuerySet: GPUQuerySet; - profilingQueryData: GpuData; - profilingTimeBase?: bigint; + queryData?: GpuData; + querySet?: GPUQuerySet; + querySetCount = 2; + queryTimeBase?: bigint; env: Env; @@ -168,11 +168,9 @@ export class WebGpuBackend { }, requiredFeatures, }; - // WebGPU Spec: Timestamp Queries Inside Passes - // https://github.com/gpuweb/gpuweb/blob/main/proposals/timestamp-query-inside-passes.md - if (adapter.features.has('timestamp-query-inside-passes')) { - this.supportTimestampQuery = true; - requiredFeatures.push('timestamp-query-inside-passes' as GPUFeatureName); + + if (adapter.features.has('timestamp-query')) { + requiredFeatures.push('timestamp-query'); } if (adapter.features.has('shader-f16')) { requiredFeatures.push('shader-f16'); @@ -197,21 +195,14 @@ export class WebGpuBackend { } }; - if (this.supportTimestampQuery) { - this.profilingQuerySet = this.device.createQuerySet({ - type: 'timestamp', - count: 2, - }); - } - Object.defineProperty(this.env.webgpu, 'device', {value: this.device}); } dispose(): void { - // currently, we do not do anything in this function. In all known use cases, we don't have the requirement to - // actually dispose the WebGpuBackend instance, because it's always used as a singleton. - // - // revisit this place if we get real requirement to dispose the instance. 
+ if (typeof this.querySet !== 'undefined') { + this.querySet.destroy(); + } + this.gpuDataManager.dispose(); } getCommandEncoder(): GPUCommandEncoder { @@ -223,7 +214,22 @@ export class WebGpuBackend { getComputePassEncoder(): GPUComputePassEncoder { if (!this.computePassEncoder) { - this.computePassEncoder = this.getCommandEncoder().beginComputePass(); + const computePassDescriptor: GPUComputePassDescriptor = {}; + if (this.isQueryEnabled()) { + if (typeof this.querySet === 'undefined') { + this.querySet = this.device.createQuerySet({ + type: 'timestamp', + count: this.querySetCount, + }); + } + computePassDescriptor.timestampWrites = { + querySet: this.querySet, + beginningOfPassWriteIndex: 0, + endOfPassWriteIndex: 1, + }; + } + + this.computePassEncoder = this.getCommandEncoder().beginComputePass(computePassDescriptor); } return this.computePassEncoder; } @@ -245,6 +251,14 @@ export class WebGpuBackend { } } + isQueryEnabled(): boolean { + if (this.device.features.has('timestamp-query') && this.env.webgpu.profilingMode === 'default') { + return true; + } else { + return false; + } + } + /** * run a WebGPU program. * @param program a ProgramInfo instance diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index 2184d6280189a..9f5dceb8f4726 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -70,6 +70,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Exp', [unaryOps.exp]], ['Expand', [expand]], ['Floor', [unaryOps.floor]], + ['FusedConv', [conv, parseConvAttributes]], ['Gather', [gather, parseGatherAttributes]], ['GatherElements', [gatherElements, parseGatherElementsAttributes]], ['Gelu', [unaryOps.gelu]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts index 22b91d680a9b4..6481a6b21d723 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts @@ -41,12 +41,12 @@ export const activationFnSnippet = if (!activation) { return ''; } - // TODO: add implementations return ''; }; export const biasActivationSnippet = (hasBias: boolean, activation?: Activation): string => ` ${hasBias ? 'value = value + getBiasByOutputCoords(coords);' : ''} - ${activation ? 'value = activation(value, coords);' : ''} + // TODO uncomment the following line when activation is supported above. + // ${activation ? 'value = activation(value, coords);' : ''} `; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts index 01ddca520deed..fbb936a045b9c 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts @@ -242,8 +242,9 @@ export const createConv2DMatMulProgramInfo = ${declareFunctions} ${ conv2dCommonSnippet( - isChannelsLast, fitAOuter, fitBOuter, fitInner, hasBias, undefined, false, elementsSize[0], - elementsSize[1], elementsSize[2], t)} + isChannelsLast, fitAOuter, fitBOuter, fitInner, hasBias, + attributes.activation.toLowerCase() as Activation, false, elementsSize[0], elementsSize[1], + elementsSize[2], t)} ${ isVec4 ? 
makeMatMulPackedVec4Source(elementsPerThread, workGroupSize, t, undefined, !isChannelsLast, tileInner) : diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts index 840360223c75a..a95d3830f34eb 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -236,7 +236,9 @@ export const createConv2DTransposeMatMulProgramInfo = const dimBOuter : i32 = ${dimBOuter}; const dimInner : i32 = ${dimInner}; ${declareFunctions} - ${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, undefined, false, innerElementSize)} + ${ + conv2dTransposeCommonSnippet( + isChannelsLast, hasBias, attributes.activation.toLowerCase() as Activation, false, innerElementSize)} ${ isVec4 ? makeMatMulPackedVec4Source( elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner) : diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts index 1032869412462..0a0f29db6a494 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts @@ -23,7 +23,7 @@ import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; import {ProgramInfo} from '../../types'; import {getBroadcastDims, IndicesHelper, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common'; -import {getActicationSnippet, InternalActivationAttributes} from '../fuse-utils'; +import {getActivationSnippet, InternalActivationAttributes} from '../fuse-utils'; import {typeSnippet} from './activation_util'; @@ -440,7 +440,7 @@ export const createMatmulProgramInfo = const dimInner = aShape[aShape.length - 1]; const dimBOuter = bShape[bShape.length - 1]; const isVec4 = dimInner % 4 === 0 && dimBOuter % 4 === 0; - const {activationFunction, applyActivation} = getActicationSnippet(activationAttributes); + const {activationFunction, applyActivation} = getActivationSnippet(activationAttributes, isVec4); // TODO: fine tune size const elementsPerThread = dimAOuter <= 8 ? [4, 1, 1] : [4, 4, 1]; @@ -473,8 +473,8 @@ export const createMatmulProgramInfo = const dimBOuter: i32 = ${dimBOuter}; const dimInner: i32 = ${dimInner}; ${shaderHelper.declareVariables(...inputVariables, output)} - ${declareFunctions} ${activationFunction} + ${declareFunctions} ${ isVec4 ? 
makeMatMulPackedVec4Source(elementsPerThread, workgroupSize, dataType, batchDims) : makeMatMulPackedSource(elementsPerThread, workgroupSize, dataType, batchDims)} diff --git a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts index eab571e87f5f5..0992552ebaf25 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts @@ -6,7 +6,7 @@ import {TensorView} from '../../tensor-view'; import {BroadcastUtil, ShapeUtil} from '../../util'; import {ComputeContext, ProgramInfo} from '../types'; -import {inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from './common'; type BuiltinFunctionName = string; type BinaryCustomExpression = (expressionA: string, expressionB: string) => string; @@ -18,10 +18,7 @@ type BinaryFunctionCall = BuiltinFunctionName|BinaryCustomExpression|{ const createBinaryOpProgramShader = (shaderHelper: ShaderHelper, dimsA: readonly number[], dimsB: readonly number[], dimsOutput: readonly number[], vectorize: boolean, doBroadcast: boolean, funcCall: BinaryFunctionCall, typeA: number, typeB: number, - typeOutput: number, additionalImplementation?: string) => { - const outputSize = ShapeUtil.size(dimsOutput); - const vecSize = Math.ceil(outputSize / 4); - + typeOutput: number, useShapesUniforms: boolean, additionalImplementation?: string) => { let expressionScalar: BinaryCustomExpression; let expressionVector: BinaryCustomExpression; if (typeof funcCall === 'string') { @@ -33,31 +30,12 @@ const createBinaryOpProgramShader = expressionVector = funcCall.vector; } - let broadcastImpl = ''; - const output = outputVariable('outputData', typeOutput, dimsOutput, 4); - const a = inputVariable('aData', typeA, dimsA, 4); - const b = inputVariable('bData', typeB, dimsB, 4); - if (doBroadcast) { - const calcOffsetImpl = (dims: readonly number[]) => { - const strides = ShapeUtil.computeStrides(dims); - const offsets: string[] = []; - for (let i = dims.length - 1; i >= 0; i--) { - const idx = output.indicesGet('outputIndices', i + dimsOutput.length - dims.length); - offsets.push(`${strides[i]}u * (${idx} % ${dims[i]}u)`); - } - return offsets.length > 0 ? offsets.join('+') : '0u'; - }; - - broadcastImpl = ` - fn calcOffsetA(outputIndices: ${output.type.indices}) -> u32 { - return ${calcOffsetImpl(dimsA)}; - } - - fn calcOffsetB(outputIndices: ${output.type.indices}) -> u32 { - return ${calcOffsetImpl(dimsB)}; - } - `; - } + const inputAShapeOrRank = useShapesUniforms ? dimsA.length : dimsA; + const inputBShapeOrRank = useShapesUniforms ? dimsB.length : dimsB; + const outputShapeOrRank = useShapesUniforms ? 
dimsOutput.length : dimsOutput; + const output = outputVariable('outputData', typeOutput, outputShapeOrRank, 4); + const a = inputVariable('aData', typeA, inputAShapeOrRank, 4); + const b = inputVariable('bData', typeB, inputBShapeOrRank, 4); let assignment: string; if (vectorize) { @@ -73,8 +51,8 @@ const createBinaryOpProgramShader = } else { assignment = ` let outputIndices = ${output.offsetToIndices('global_idx * 4u')}; - let offsetA = calcOffsetA(outputIndices); - let offsetB = calcOffsetB(outputIndices); + let offsetA = ${a.broadcastedIndicesToOffset('outputIndices', output)}; + let offsetB = ${b.broadcastedIndicesToOffset('outputIndices', output)}; ${ output.setByOffset( 'global_idx', expressionVector(a.getByOffset('offsetA / 4u'), b.getByOffset('offsetB / 4u')))} @@ -94,8 +72,8 @@ const createBinaryOpProgramShader = const expressionB = `bData[indexB${x}][componentB${x}]`; return ` let outputIndices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)}; - let offsetA${x} = calcOffsetA(outputIndices${x}); - let offsetB${x} = calcOffsetB(outputIndices${x}); + let offsetA${x} = ${a.broadcastedIndicesToOffset(`outputIndices${x}`, output)}; + let offsetB${x} = ${b.broadcastedIndicesToOffset(`outputIndices${x}`, output)}; let indexA${x} = offsetA${x} / 4u; let indexB${x} = offsetB${x} / 4u; let componentA${x} = offsetA${x} % 4u; @@ -122,13 +100,12 @@ const createBinaryOpProgramShader = } return ` - ${shaderHelper.declareVariables(a, b, output)} + ${shaderHelper.registerUniform('vec_size', 'u32').declareVariables(a, b, output)} ${additionalImplementation ?? ''} - ${broadcastImpl} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(vecSize)} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.vec_size')} ${assignment} }`; }; @@ -144,6 +121,7 @@ const createBinaryOpProgramInfo = // TODO: deal with zero-sized tensors (eg. dims=[1,0]) + const cacheKeyAux = [isBroadcast]; if (isBroadcast) { const calculatedShape = BroadcastUtil.calcShape(a.dims, b.dims, false); if (!calculatedShape) { @@ -153,7 +131,8 @@ const createBinaryOpProgramInfo = outputSize = ShapeUtil.size(outputShape); const isAOneElement = ShapeUtil.size(a.dims) === 1; const isBOneElement = ShapeUtil.size(b.dims) === 1; - + cacheKeyAux.push(isAOneElement); + cacheKeyAux.push(isBOneElement); // check whether vectorize can be enabled let sharedDimension = 1; for (let i = 1; i < outputShape.length; i++) { @@ -172,16 +151,34 @@ const createBinaryOpProgramInfo = // element-wise vectorize = true; } - + cacheKeyAux.push(vectorize); + const useShapesUniforms = enableShapesUniforms(a.dims.length) && enableShapesUniforms(b.dims.length) && + enableShapesUniforms(outputShape.length); return { name, - shaderCache: {hint: cacheKey}, + shaderCache: { + hint: cacheKey + cacheKeyAux.map((x) => x.toString()).join('_'), + // If the input is scalar then use type instead of dims because useShapesUniforms is false. + inputDependencies: useShapesUniforms ? + ['rank', 'rank'] : + [a.dims.length > 0 ? 'dims' : 'type', b.dims.length > 0 ? 
'dims' : 'type'], + }, getShaderSource: (shaderHelper) => createBinaryOpProgramShader( shaderHelper, a.dims, b.dims, outputShape, vectorize, isBroadcast, funcCall, a.dataType, b.dataType, - outputDataType, additionalImplementation), + outputDataType, useShapesUniforms, additionalImplementation), getRunData: () => ({ outputs: [{dims: outputShape, dataType: outputDataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* component size */)} + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* component size */)}, + programUniforms: useShapesUniforms ? + [ + {type: 'uint32', data: Math.ceil(ShapeUtil.size(outputShape) / 4)}, + ...createTensorShapeVariables(a.dims), + ...createTensorShapeVariables(b.dims), + ...createTensorShapeVariables(outputShape), + ] : + [ + {type: 'uint32', data: Math.ceil(ShapeUtil.size(outputShape) / 4)}, + ], }), }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index 0a64d1ad1792a..7352517733c3e 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -803,3 +803,6 @@ export const getBroadcastDims = (inShape: readonly number[], outShape: readonly } return dims; }; + +// TODO: remove this limitation once >4D dims are supported by uniform. +export const enableShapesUniforms = (rank: number): boolean => rank <= 4 && rank > 0; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts index 7abf022928ade..8bfa722dd0909 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts @@ -7,7 +7,7 @@ import {ProgramInfo} from '../types'; import {inputVariable, outputVariable, ShaderHelper} from './common'; import {calculateOutputShape, ConvAttributes} from './conv'; -import {getActicationSnippet} from './fuse-utils'; +import {getActivationSnippet} from './fuse-utils'; /** * naive grouped conv implementation, supports 1d/2d conv @@ -22,7 +22,7 @@ export const createGroupedConvProgramInfo = const wShape = inputs[1].dims; const outputChannelsPerGroup = wShape[0] / attributes.group; - const {activationFunction, applyActivation} = getActicationSnippet(attributes); + const {activationFunction, applyActivation} = getActivationSnippet(attributes); const isChannelLast = attributes.format === 'NHWC'; const outputShape = calculateOutputShape( diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts index d241b8b92a669..e880afe09a5d8 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts @@ -232,7 +232,7 @@ const convTranspose2d = // STEP.1: transpose weight const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ?? context.compute( - createTransposeProgramInfo(inputs[1].dataType, inputs[1].dims.length, weightTransposePerm), + createTransposeProgramInfo(inputs[1], weightTransposePerm), {inputs: [1], outputs: [attributes.wIsConst ? 
-2 : -1]})[0]; if (attributes.wIsConst && !context.kernelCustomData.wT) { context.kernelCustomData.wT = transposedWeight; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts index b323a36cee5c8..c7ea0cffe51c3 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts @@ -168,7 +168,7 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut if (isChannelsLast) { const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ?? context.compute( - createTransposeProgramInfo(inputs[1].dataType, inputs[1].dims.length, weightTransposeAttribute), + createTransposeProgramInfo(inputs[1], weightTransposeAttribute), {inputs: [1], outputs: [attributes.wIsConst ? -2 : -1]})[0]; if (attributes.wIsConst && !context.kernelCustomData.wT) { context.kernelCustomData.wT = transposedWeight; @@ -208,7 +208,7 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut // STEP.1: transpose weight const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ?? context.compute( - createTransposeProgramInfo(inputs[1].dataType, inputs[1].dims.length, weightTransposeAttribute), + createTransposeProgramInfo(inputs[1], weightTransposeAttribute), {inputs: [1], outputs: [attributes.wIsConst ? -2 : -1]})[0]; if (attributes.wIsConst && !context.kernelCustomData.wT) { context.kernelCustomData.wT = transposedWeight; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts index 92105859a8c0e..956ef18eb5cfb 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts @@ -10,24 +10,25 @@ export interface InternalActivationAttributes { readonly activationCacheKey: string; } -export const getActicationSnippet = - (attributes: InternalActivationAttributes): {activationFunction: string; applyActivation: string} => { - switch (attributes.activation) { - case 'Relu': - return {activationFunction: '', applyActivation: 'value = max(value, 0.0);'}; - case 'Sigmoid': - return {activationFunction: '', applyActivation: 'value = (1.0 / (1.0 + exp(-value)));'}; - case 'Clip': - return { - activationFunction: - `const clip_min_=f32(${attributes.clipMin!});const clip_max_=f32(${attributes.clipMax!});`, - applyActivation: 'value = clamp(value, clip_min_, clip_max_);' - }; - // TODO: adding other activations that can be fused. - default: - return {activationFunction: '', applyActivation: ''}; - } - }; +export const getActivationSnippet = (attributes: InternalActivationAttributes, isVec4 = false): { + activationFunction: string; applyActivation: string; +} => { + switch (attributes.activation) { + case 'Relu': + return {activationFunction: '', applyActivation: 'value = max(value, 0.0);'}; + case 'Sigmoid': + return {activationFunction: '', applyActivation: 'value = (1.0 / (1.0 + exp(-value)));'}; + case 'Clip': + return { + activationFunction: `const clip_min_=f32(${attributes.clipMin!});const clip_max_=f32(${attributes.clipMax!});`, + applyActivation: isVec4 ? 'value = clamp(value, vec4(clip_min_), vec4(clip_max_));' : + 'value = clamp(value, clip_min_, clip_max_);' + }; + // TODO: adding other activations that can be fused. 
+ default: + return {activationFunction: '', applyActivation: ''}; + } +}; export const parseInternalActivationAttributes = (attributes: Record|undefined): InternalActivationAttributes => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts index 0c39152f56dad..97f633c7cf47e 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts @@ -1,12 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, ProgramInfo} from '../types'; -import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from './common'; +import {fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from './common'; export interface InstanceNormAttributes extends AttributeWithCacheKey { epsilon: number; @@ -111,69 +112,159 @@ const createInstanceNormProgramInfo = }; }; +const computeMean = + (context: ComputeContext, input: TensorView, scale: TensorView, bias: TensorView, n: number, h: number, c: number, + epsilon: number) => { + const components = getMaxComponents(c); + const inputHelper = inputVariable('input', input.dataType, input.dims, components); + const scaleHelper = inputVariable('scale', scale.dataType, scale.dims, components); + const biasHelper = inputVariable('bias', bias.dataType, bias.dims, components); + + const WG = 64; + // we will store channel scale and channel shift in [2, components] matrix + // or in vec2 when components == 1 + const outputType = components === 1 ? 'vec2f' : `mat2x${components}f`; + const sumCastType = components === 1 ? 
'f32' : `vec${components}f`; + const setOutputValue = (var1: string, var2: string) => `${outputType}(${var1}, ${var2})`; + const unitsOfWork = n * c / components; + const wgSize = Math.ceil(h / WG); + + const getMeanShaderSource = (shaderHelper: ShaderHelper) => ` + const H: u32 = ${h}; + const C: u32 = ${c / components}; + const imageSize: u32 = ${h * c / components}; + + ${shaderHelper.declareVariables(inputHelper)} + @group(0) @binding(1) var output : array<${outputType}>; + + ${shaderHelper.mainStart(WG)} + let currentImageNumber = global_idx / ${WG} / C; + let currentChannelNumber = (global_idx / ${WG}) % C; + let wgId = global_idx % ${WG}; + let wgOffset = wgId * ${wgSize}; + if (wgOffset >= H) { + return; + } + let wgMax = min(wgOffset + ${wgSize}, H); + + let offset = currentImageNumber * imageSize + currentChannelNumber; + var sum = ${fillVector('f32', components)}; + var squaredSum = ${fillVector('f32', components)}; + for (var i: u32 = wgOffset; i < wgMax; i++) { + let value = ${sumCastType}(input[offset + i * C]); + sum += value; + squaredSum += value * value; + } + output[global_idx] = ${setOutputValue('sum', 'squaredSum')}; + }`; + + const meanValues = context.compute( + { + name: 'InstanceNormComputeMean', + shaderCache: {hint: JSON.stringify({components, n, h, c})}, + getRunData: () => ({ + outputs: [ + {dims: [n, c, WG, 2], dataType: DataType.float}, + ], + dispatchGroup: {x: n * c / components}, + }), + getShaderSource: getMeanShaderSource, + }, + {inputs: [input], outputs: [-1]})[0]; + const getShaderSource = (shaderHelper: ShaderHelper) => ` + const H: u32 = ${h}; + const C: u32 = ${c / components}; + const imageSize: u32 = ${WG * c / components}; + const epsilon: f32 = ${epsilon}; + + @group(0) @binding(0) var input : array<${outputType}>; + @group(0) @binding(1) var scale : array<${scaleHelper.type.storage}>; + @group(0) @binding(2) var bias : array<${biasHelper.type.storage}>; + @group(0) @binding(3) var output : array<${outputType}>; + + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(unitsOfWork)} + let currentImageNumber = global_idx / C; + let currentChannelNumber = global_idx % C; + + let offset = currentImageNumber * imageSize; + var sum = ${fillVector('f32', components)}; + var squaredSum = ${fillVector('f32', components)}; + for (var i: u32 = 0; i < ${WG}; i++) { + let value = input[offset + i + currentChannelNumber * ${WG}]; + sum += value[0]; + squaredSum += value[1]; + } + sum = sum / f32(H); + squaredSum = squaredSum / f32(H); + let invStdDev = 1 / sqrt(squaredSum - sum * sum + epsilon); + let channelScale = invStdDev * ${sumCastType}(scale[currentChannelNumber]); + let channelShift = ${sumCastType}(bias[currentChannelNumber]) - sum * channelScale; + + output[global_idx] = ${setOutputValue('channelScale', 'channelShift')}; + }`; + + return context.compute( + { + name: 'InstanceNormComputeChannelScaleShift', + shaderCache: {hint: JSON.stringify({components, n, h, c, epsilon})}, + getRunData: () => ({ + outputs: [ + {dims: [n, c, 2], dataType: DataType.float}, + ], + dispatchGroup: {x: Math.ceil(unitsOfWork / 64 /* workgroup size */)}, + }), + getShaderSource, + }, + {inputs: [meanValues, scale, bias], outputs: [-1]})[0]; + }; + const createInstanceNormNHWCProgramInfo = - (inputs: readonly TensorView[], attributes: InstanceNormAttributes): ProgramInfo => { + (context: ComputeContext, inputs: readonly TensorView[], attributes: InstanceNormAttributes) => { const xShape = inputs[0].dims; const outputShape = xShape; - const 
outputSize = ShapeUtil.size(outputShape); const N = xShape[0]; const C = xShape[xShape.length - 1]; const H = ShapeUtil.sizeFromDimension(xShape, 1) / C; + const components = getMaxComponents(C); + const outputSize = ShapeUtil.size(outputShape) / components; + const inputHelper = inputVariable('input', inputs[0].dataType, inputs[0].dims, components); + const outputHelper = outputVariable('output', inputs[0].dataType, outputShape, components); + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + const scaleType = components === 1 ? 'vec2f' : `mat2x${components}f`; + const scaleCastType = components === 1 ? dataType : `vec${components}<${dataType}>`; + // first compute mean + const channelScaleShift = computeMean(context, inputs[0], inputs[1], inputs[2], N, H, C, attributes.epsilon); - const normCount = C * N; const getShaderSource = (shaderHelper: ShaderHelper) => ` - const N: u32 = ${N}; const H: u32 = ${H}; - const C: u32 = ${C}; - const normSizeTyped: ${dataType} = ${H}; - const imageSize: u32 = ${H * C}; - const epsilon: f32 = ${attributes.epsilon}; + const C: u32 = ${C / components}; - @group(0) @binding(0) var x : array<${dataType}>; - @group(0) @binding(1) var scale : array<${dataType}>; - @group(0) @binding(2) var bias : array<${dataType}>; - @group(0) @binding(3) var output : array<${dataType}>; + @group(0) @binding(0) var input : array<${inputHelper.type.storage}>; + @group(0) @binding(1) var scaleInput : array<${scaleType}>; + @group(0) @binding(2) var output : array<${outputHelper.type.storage}>; ${shaderHelper.mainStart()} - let currentImageNumber = global_idx / C; + let currentImageNumber = global_idx / (C * H); let currentChannelNumber = global_idx % C; - // offset is channel num * N - let offset = currentImageNumber * imageSize; - if (offset >= ${outputSize}) { return; } - var mean: ${dataType} = 0; - - for (var i: u32 = 0u; i < H; i++) { - mean = mean + x[offset + i * C + currentChannelNumber]; - } - mean = mean / normSizeTyped; - - var squaredNorm: ${dataType} = 0; - for (var i: u32 = 0u; i < H; i++) { - let deviation: f32 = x[offset + i * C + currentChannelNumber] - mean; - squaredNorm = squaredNorm + deviation * deviation; - } - let invStdDev = 1 / sqrt(squaredNorm / normSizeTyped + epsilon); - let channelScale = invStdDev * scale[currentChannelNumber]; - let channelShift = bias[currentChannelNumber] - mean * channelScale; - for (var i: u32 = 0u; i < H; i++) { - let currentOffset = offset + i * C + currentChannelNumber; - output[currentOffset] = x[currentOffset] * channelScale + channelShift; - } + let scaleOffset = currentImageNumber * C + currentChannelNumber; + let scale = scaleInput[scaleOffset]; + output[global_idx] = fma(input[global_idx], ${scaleCastType}(scale[0]), ${scaleCastType}(scale[1])); }`; - return { - ...metadata, - shaderCache: {hint: attributes.cacheKey}, - getRunData: () => ({ - outputs: [ - {dims: outputShape, dataType: inputs[0].dataType}, - ], - dispatchGroup: {x: Math.ceil(normCount / 64 /* workgroup size */)} - }), - getShaderSource, - }; + context.compute( + { + name: 'InstanceNormalization', + shaderCache: {hint: `${attributes.cacheKey}`}, + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)} + }), + getShaderSource, + }, + {inputs: [inputs[0], channelScaleShift]}); }; export const parseInstanceNormAttributes = (attributes: InstanceNormAttributes): InstanceNormAttributes => @@ -181,7 +272,7 @@ export const 
parseInstanceNormAttributes = (attributes: InstanceNormAttributes): export const instanceNorm = (context: ComputeContext, attributes: InstanceNormAttributes): void => { if (attributes.format === 'NHWC') { - context.compute(createInstanceNormNHWCProgramInfo(context.inputs, attributes)); + createInstanceNormNHWCProgramInfo(context, context.inputs, attributes); } else { context.compute(createInstanceNormProgramInfo(context.inputs, attributes)); } diff --git a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts index 186a9999e53f2..8a9eeecf2c68d 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts @@ -7,7 +7,7 @@ import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, ProgramInfo} from '../types'; -import {ShaderHelper, tensorTypeToWsglStorageType} from './common'; +import {castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType,} from './common'; export interface LayerNormAttributes extends AttributeWithCacheKey { axis: number; @@ -18,10 +18,6 @@ const validateInputs = (inputs: readonly TensorView[]): void => { if (!inputs || inputs.length < 2) { throw new Error('layerNorm requires at least 2 inputs.'); } - - if (inputs[0].dataType !== DataType.float || inputs[1].dataType !== DataType.float) { - throw new Error('inputs should be float type'); - } }; const createLayerNormProgramInfo = @@ -31,7 +27,6 @@ const createLayerNormProgramInfo = const bias = inputs[2]; const outputShape = xShape; - const outputSize = ShapeUtil.size(outputShape); const axis = ShapeUtil.normalizeAxis(attributes.axis, xShape.length); const normCount = ShapeUtil.sizeToDimension(xShape, axis); const normSize = ShapeUtil.sizeFromDimension(xShape, axis); @@ -53,44 +48,54 @@ const createLayerNormProgramInfo = } } + const components = getMaxComponents(normSize); const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + const variables = [ + inputVariable('x', inputs[0].dataType, inputs[0].dims, components), + inputVariable('scale', scale.dataType, scale.dims, components), + ]; + if (bias) { + variables.push(inputVariable('bias', bias.dataType, bias.dims, components)); + } + variables.push(outputVariable('output', inputs[0].dataType, outputShape, components)); const hasMeanDataOutput = outputCount > 1; const hasInvStdOutput = outputCount > 2; - let bindingIndex = 0; + + if (hasMeanDataOutput) { + variables.push(outputVariable('meanDataOutput', DataType.float, meanInvStdDevDim)); + } + if (hasInvStdOutput) { + variables.push(outputVariable('invStdOutput', DataType.float, meanInvStdDevDim)); + } + const getShaderSource = (shaderHelper: ShaderHelper) => ` - const normSize: u32 = ${normSize}; - const normSizeTyped: ${dataType} = ${normSize}; + const normSize: f32 = ${normSize}; + const normSizeVectorized: u32 = ${normSize / components}; const epsilon: f32 = ${attributes.epsilon}; - @group(0) @binding(${bindingIndex++}) var x : array<${dataType}>; - @group(0) @binding(${bindingIndex++}) var scale : array<${dataType}>; - ${bias ? `@group(0) @binding(${bindingIndex++}) var bias : array<${dataType}>;` : ''} - @group(0) @binding(${bindingIndex++}) var output : array<${dataType}>; - ${ - hasMeanDataOutput ? - `@group(0) @binding(${bindingIndex++}) var meanDataOutput : array<${dataType}>` : - ''}; - ${ - hasInvStdOutput ? 
- `@group(0) @binding(${bindingIndex++}) var invStdOutput : array<${dataType}>` : - ''}; - + ${shaderHelper.declareVariables(...variables)} ${shaderHelper.mainStart()} - let offset = global_idx * normSize; - if (offset >= ${outputSize}) { return; } - var mean: ${dataType} = 0; - var meanSquare: ${dataType} = 0; - - for (var h: u32 = 0u; h < normSize; h++) { - mean = mean + x[h + offset]; - meanSquare = meanSquare + x[h + offset] * x[h + offset]; + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(normCount)} + let offset = global_idx * normSizeVectorized; + var meanVector = ${fillVector('f32', components)}; + var meanSquareVector = ${fillVector('f32', components)}; + + for (var h: u32 = 0u; h < normSizeVectorized; h++) { + let value = ${castToF32(dataType, components, 'x[h + offset]')}; + meanVector += value; + meanSquareVector += value * value; } - mean = mean / normSizeTyped; - meanSquare = sqrt(meanSquare / normSizeTyped - mean * mean + epsilon); - - for (var j: u32 = 0; j < normSize; j++) { - output[j + offset] = (x[j + offset] - mean) / meanSquare * scale[j] ${bias ? '+ bias[j]' : ''}; + let mean = ${sumVector('meanVector', components)} / normSize; + let meanSquare = sqrt(${sumVector('meanSquareVector', components)} + / normSize - mean * mean + epsilon); + + for (var j: u32 = 0; j < normSizeVectorized; j++) { + let f32input = ${castToF32(dataType, components, 'x[j + offset]')}; + let f32scale = ${castToF32(dataType, components, 'scale[j]')}; + output[j + offset] = ${variables[0].type.value}((f32input - mean) / meanSquare * f32scale + ${bias ? `+ ${castToF32(dataType, components, 'bias[j]')}` : ''} + ); } ${hasMeanDataOutput ? 'meanDataOutput[global_idx] = mean' : ''}; @@ -98,14 +103,10 @@ const createLayerNormProgramInfo = }`; const outputs = [{dims: outputShape, dataType: inputs[0].dataType}]; if (hasMeanDataOutput) { - outputs.push( - {dims: meanInvStdDevDim, dataType: inputs[0].dataType}, - ); + outputs.push({dims: meanInvStdDevDim, dataType: DataType.float}); } if (hasInvStdOutput) { - outputs.push( - {dims: meanInvStdDevDim, dataType: inputs[0].dataType}, - ); + outputs.push({dims: meanInvStdDevDim, dataType: DataType.float}); } return { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts index 05f02b07c4d89..1538644412afd 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts @@ -18,16 +18,18 @@ const validateInputs = (inputs: readonly TensorView[]): void => { if (!inputs || inputs.length !== 1) { throw new Error('Pool ops requires 1 input.'); } - if (inputs[0].dims.length !== 4) { - throw new Error('Pool ops supports 2-D inputs only for now.'); + if (inputs[0].dims.length !== 4 && inputs[0].dims.length !== 3) { + throw new Error('Pool ops supports 1-D or 2-D inputs only for now.'); } }; const getAdjustedPoolAttributesAndOutputShape = ( input: TensorView, attributes: AttributeType, isGlobalOperator: boolean): [AttributeType, number[]] => { const isChannelsLast = attributes.format === 'NHWC'; - const inputShapeAsChannelFirst = - isChannelsLast ? [input.dims[0], input.dims[3], input.dims[1], input.dims[2]] : input.dims.slice(); + const inputShapeAsChannelFirst = input.dims.slice(); + if (isChannelsLast) { + inputShapeAsChannelFirst.splice(1, 0, inputShapeAsChannelFirst.pop()!); // Move channel to the second position. 
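// Illustrative sketch, not taken from the change above: a standalone restatement of the
// channels-last adjustment used here, runnable on its own. For an NHWC (or NWC) input the
// channel dim is popped off the end and spliced back in at index 1, giving the
// channels-first shape the rest of the pooling shape math expects.
const toChannelsFirstDims = (dims: readonly number[], isChannelsLast: boolean): number[] => {
  const shape = dims.slice();
  if (isChannelsLast) {
    shape.splice(1, 0, shape.pop()!);  // move C from the last axis to axis 1
  }
  return shape;
};
// toChannelsFirstDims([2, 8, 8, 3], true) -> [2, 3, 8, 8]; a 3-D [N, W, C] input becomes [N, C, W].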
+ } const hasDilations = Object.hasOwnProperty.call(attributes, 'dilations'); const kernelShape = attributes.kernelShape.slice(); const strides = attributes.strides.slice(); @@ -44,15 +46,9 @@ const getAdjustedPoolAttributesAndOutputShape = ( @@ -76,22 +72,22 @@ const generatePoolingCode = = ${inputDims[dimIdxW]}) { - pad++; - continue; - } - let x_val = x[${x.indicesToOffset('xIndices')}]; - ${op1} - }`; + for (var i: u32 = 0u; i < ${kw}u; i++) { + xIndices[${dimIdxW}] = indices[${dimIdxW}] * ${sw} - ${pwStart} + i; + if (xIndices[${dimIdxW}] < 0 || xIndices[${dimIdxW}] >= ${inputDims[dimIdxW]}) { + pad++; + continue; + } + let x_val = x[${x.indicesToOffset('xIndices')}]; + ${op1} + }`; } else { codeW = ` - for (var i: u32 = 0u; i < ${kw}u; i++) { - xIndices[${dimIdxW}] = indices[${dimIdxW}] * ${sw} - ${pwStart} + i; - let x_val = x[${x.indicesToOffset('xIndices')}]; - ${op1} - }`; + for (var i: u32 = 0u; i < ${kw}u; i++) { + xIndices[${dimIdxW}] = indices[${dimIdxW}] * ${sw} - ${pwStart} + i; + let x_val = x[${x.indicesToOffset('xIndices')}]; + ${op1} + }`; } if (attributes.kernelShape.length === 2) { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts b/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts new file mode 100644 index 0000000000000..54a7414360c4e --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts @@ -0,0 +1,266 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {DataType} from '../../../wasm-common'; +import {TensorView} from '../../tensor-view'; +import {ShapeUtil} from '../../util'; +import {ComputeContext, ProgramInfo, ProgramShaderCacheInfo} from '../types'; + +import {inputVariable, outputVariable, ShaderHelper} from './common'; +import {createReduceAttributesFromInputs, ReduceAttributes} from './reduce'; +import {createTransposeProgramInfo} from './transpose'; + +const reduceOps: {[key: string]: string} = { + max: 'select(bestValue, candidate, candidate > bestValue)', + min: 'select(bestValue, candidate, candidate < bestValue)', + mean: 'bestValue + candidate', + sum: 'bestValue + candidate', + prod: 'bestValue * candidate', + sumSquare: 'bestValue + candidate * candidate', + logSumExp: 'bestValue + exp(candidate)', + l1: 'bestValue + abs(candidate)', + l2: 'bestValue + candidate * candidate', + logSum: 'bestValue + candidate' +}; + +const reduceSharedOps: {[key: string]: string} = { + max: 'select(bestValue, candidate, candidate > bestValue)', + min: 'select(bestValue, candidate, candidate < bestValue)', + mean: 'bestValue + candidate', + sum: 'bestValue + candidate', + prod: 'bestValue * candidate', + sumSquare: 'bestValue + candidate', + logSumExp: 'bestValue + candidate', + l1: 'bestValue + candidate', + l2: 'bestValue + candidate', + logSum: 'bestValue + candidate' +}; + +const reduceInitValues: {[key: string]: string} = { + max: '_A[offset]', + min: '_A[offset]', + mean: '0', + sum: '0', + prod: '1', + sumSquare: '0', + logSumExp: '0', + l1: '0', + l2: '0', + logSum: '0' +}; + +const reduceOutputValues: {[key: string]: string} = { + max: 'bestValue', + min: 'bestValue', + sum: 'bestValue', + prod: 'bestValue', + sumSquare: 'bestValue', + logSumExp: 'log(bestValue)', + l1: 'bestValue', + l2: 'sqrt(bestValue)', + logSum: 'log(bestValue)' +}; + +const getInnerMostAxes = (numInnerAxes: number, rank: number): number[] => { + const res = []; + for (let i = rank - numInnerAxes; i < rank; ++i) { + res.push(i); + } + return res; +}; + +const computeOutAndReduceShapes = (shape: readonly 
number[], axes: readonly number[]): [number[], number[]] => { + const outputShape = []; + const rank = shape.length; + for (let dim = 0; dim < rank; dim++) { + if (axes.indexOf(dim) === -1) { + outputShape.push(shape[dim]); + } + } + const reduceShape = axes.map(dim => shape[dim]); + return [outputShape, reduceShape]; +}; + +const expandShapeToKeepDim = (shape: number[], axes: number[]): number[] => { + const rank = shape.length + axes.length; + const expandShape = []; + let shapeIdx = 0; + for (let dim = 0; dim < rank; dim++) { + if (axes.indexOf(dim) === -1) { + expandShape.push(shape[shapeIdx++]); + } else { + expandShape.push(1); + } + } + return expandShape; +}; + +const areAxesInnerMostDims = (axes: number[], rank: number): boolean => { + for (let i = 0; i < axes.length; ++i) { + if (axes[axes.length - i - 1] !== rank - 1 - i) { + return false; + } + } + return true; +}; + +const getAxesPermutation = (axes: number[], rank: number): number[] => { + const res = []; + if (!areAxesInnerMostDims(axes, rank)) { + for (let i = 0; i < rank; ++i) { + if (axes.indexOf(i) === -1) { + res.push(i); + } + } + axes.forEach(axis => res.push(axis)); + } + return res; +}; + +export const createReduceSharedProgramInfo = + (name: string, shaderCache: ProgramShaderCacheInfo, inputs: readonly TensorView[], reduceType: string, + outputDataType: DataType, outputShape: number[], reduceShape: number[]): ProgramInfo => { + const inputShape = inputs[0].dims; + + const outputSize = ShapeUtil.size(outputShape); + const reduceSize = ShapeUtil.size(reduceShape); + + const input = inputVariable('_A', inputs[0].dataType, inputShape); + const output = outputVariable('output', outputDataType, outputShape); + + const workgroupSize = 32; + + const sharedMemorySnippet = ` + var aBestValues : array<${output.type.storage}, ${workgroupSize}>; + `; + + const getShaderSource = (shaderHelper: ShaderHelper) => ` + ${shaderHelper.registerUniform('reduceSize', 'u32').declareVariables(input, output)} + ${sharedMemorySnippet} + fn DIV_CEIL(a : u32, b : u32) -> u32 { + return ((a - 1u) / b + 1u); + } + ${shaderHelper.mainStart(workgroupSize)} + let local_idx = local_id.x; + + let outputIndex = global_idx / ${workgroupSize}; + let offset = outputIndex * uniforms.reduceSize; + + var bestValue = ${output.type.storage}(${reduceInitValues[reduceType]}); + let Length = uniforms.reduceSize; + for (var k = local_idx; k < Length; k = k + ${workgroupSize}) { + let candidate = ${output.type.storage}(${input.getByOffset('offset + k')}); + bestValue = ${reduceOps[reduceType]}; + } + aBestValues[local_idx] = bestValue; + workgroupBarrier(); + + var reduceSize = min(Length, ${workgroupSize}u); + for (var currentSize = reduceSize / 2u; reduceSize > 1u; + currentSize = reduceSize / 2u) { + let interval = DIV_CEIL(reduceSize, 2u); + if (local_idx < currentSize) { + let candidate = aBestValues[local_idx + interval]; + bestValue = ${reduceSharedOps[reduceType]}; + aBestValues[local_idx] = bestValue; + } + reduceSize = interval; + workgroupBarrier(); + } + + if (local_idx == 0u) { + ${ + output.setByOffset( + 'outputIndex', + `${ + reduceType === 'mean' ? `bestValue / ${output.type.storage}(uniforms.reduceSize)` : + `${reduceOutputValues[reduceType]}`}`)}; + } + }`; + + // One work group is responsible for only one element of output. 
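// Illustrative sketch, not taken from the change above: a CPU-side restatement of the
// reduction shape the shader encodes, shown for 'sum'. Each of `workgroupSize` "threads"
// first accumulates a strided partial over the reduce axis, then the partials are folded
// pairwise (the DIV_CEIL interval halving guarded by workgroupBarrier) until one value is
// left for the single output element this workgroup owns.
const workgroupReduceSum = (values: readonly number[], workgroupSize = 32): number => {
  const partials = new Array<number>(workgroupSize).fill(0);
  // phase 1: strided accumulation, one partial per "thread"
  for (let t = 0; t < workgroupSize; t++) {
    for (let k = t; k < values.length; k += workgroupSize) {
      partials[t] += values[k];
    }
  }
  // phase 2: pairwise fold, mirroring the interval-halving loop in the shader
  let size = Math.min(values.length, workgroupSize);
  while (size > 1) {
    const interval = Math.ceil(size / 2);
    for (let t = 0; t < Math.floor(size / 2); t++) {
      partials[t] += partials[t + interval];
    }
    size = interval;
  }
  return partials[0];
};
// workgroupReduceSum([1, 2, 3, 4, 5]) === 15; the 'mean' variant divides the result by values.length.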
+ return { + name, + shaderCache, + getShaderSource, + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: outputDataType}], + dispatchGroup: {x: outputSize}, + programUniforms: [{type: 'uint32', data: reduceSize}] + }), + }; + }; + +const reduceCommon = + (context: ComputeContext, name: string, attributes: ReduceAttributes, + reduceType: 'sum'|'sumSquare'|'prod'|'min'|'max'|'mean'|'logSumExp'|'l1'|'l2'|'logSum'): void => { + const updatedAttributes: ReduceAttributes = + context.inputs.length === 1 ? attributes : createReduceAttributesFromInputs(context.inputs, attributes); + + let updatedAxes = updatedAttributes.axes; + if (updatedAxes.length === 0 && !updatedAttributes.noopWithEmptyAxes) { + updatedAxes = context.inputs[0].dims.map((s, i) => i); + } + const normalizeAxes = ShapeUtil.normalizeAxes(updatedAxes, context.inputs[0].dims.length); + + let axes = normalizeAxes; + let input = context.inputs[0]; + const permutedAxes = getAxesPermutation(axes, context.inputs[0].dims.length); + if (permutedAxes.length > 0) { + input = context.compute( + createTransposeProgramInfo(context.inputs[0], permutedAxes), {inputs: [0], outputs: [-1]})[0]; + axes = getInnerMostAxes(axes.length, input.dims.length); + } + + const [outputShape, reduceShape] = computeOutAndReduceShapes(input.dims, axes); + let finalOutputShape = outputShape; + if (updatedAttributes.keepDims) { + finalOutputShape = expandShapeToKeepDim(outputShape, normalizeAxes); + } + + context.compute( + createReduceSharedProgramInfo( + name, {hint: updatedAttributes.cacheKey, inputDependencies: ['type']}, [input], reduceType, + context.inputs[0].dataType, finalOutputShape, reduceShape), + {inputs: [input]}); + }; + +export const reduceMeanShared = (context: ComputeContext, attributes: ReduceAttributes): void => { + reduceCommon(context, 'ReduceMeanShared', attributes, 'mean'); +}; + +export const reduceL1Shared = (context: ComputeContext, attributes: ReduceAttributes): void => { + reduceCommon(context, 'ReduceL1Shared', attributes, 'l1'); +}; + +export const reduceL2Shared = (context: ComputeContext, attributes: ReduceAttributes): void => { + reduceCommon(context, 'ReduceL2Shared', attributes, 'l2'); +}; + +export const reduceLogSumExpShared = (context: ComputeContext, attributes: ReduceAttributes): void => { + reduceCommon(context, 'ReduceLogSumExpShared', attributes, 'logSumExp'); +}; + +export const reduceMaxShared = (context: ComputeContext, attributes: ReduceAttributes): void => { + reduceCommon(context, 'ReduceMaxShared', attributes, 'max'); +}; + +export const reduceMinShared = (context: ComputeContext, attributes: ReduceAttributes): void => { + reduceCommon(context, 'ReduceMinShared', attributes, 'min'); +}; + +export const reduceProdShared = (context: ComputeContext, attributes: ReduceAttributes): void => { + reduceCommon(context, 'ReduceProdShared', attributes, 'prod'); +}; + +export const reduceSumShared = (context: ComputeContext, attributes: ReduceAttributes): void => { + reduceCommon(context, 'ReduceSumShared', attributes, 'sum'); +}; + +export const reduceSumSquareShared = (context: ComputeContext, attributes: ReduceAttributes): void => { + reduceCommon(context, 'ReduceSumSquareShared', attributes, 'sumSquare'); +}; + +export const reduceLogSumShared = (context: ComputeContext, attributes: ReduceAttributes): void => { + reduceCommon(context, 'ReduceLogSumShared', attributes, 'logSum'); +}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts b/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts index 
44d6332852d2a..b5c956e57a9b1 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts @@ -8,6 +8,7 @@ import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-w import {ComputeContext, ProgramInfo, ProgramShaderCacheInfo} from '../types'; import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; +import {reduceL1Shared, reduceL2Shared, reduceLogSumExpShared, reduceLogSumShared, reduceMaxShared, reduceMeanShared, reduceMinShared, reduceProdShared, reduceSumShared, reduceSumSquareShared} from './reduce-shared'; const validateInputs = (inputs: readonly TensorView[]): void => { if (!inputs || inputs.length === 0 || inputs.length > 2) { @@ -106,7 +107,7 @@ export const createReduceProgramInfo = }; }; -const createReduceAttributesFromInputs = +export const createReduceAttributesFromInputs = (inputs: readonly TensorView[], attributes: ReduceAttributes): ReduceAttributes => { const axes: number[] = []; if (inputs[1].dims[0] > 0) { @@ -131,7 +132,7 @@ const runReduceProgram = {inputs: [0]}); }; -export const reduceLogSum = (context: ComputeContext, attributes: ReduceAttributes): void => { +const reduceLogSumNaive = (context: ComputeContext, attributes: ReduceAttributes): void => { validateInputs(context.inputs); const reduceOp: ReduceOp = (input, output) => [`var value = ${output.type.storage}(0);`, @@ -142,7 +143,7 @@ export const reduceLogSum = (context: ComputeContext, attributes: ReduceAttribut runReduceProgram(context, 'ReduceLogSum', attributes, reduceOp); }; -export const reduceL1 = (context: ComputeContext, attributes: ReduceAttributes): void => { +const reduceL1Naive = (context: ComputeContext, attributes: ReduceAttributes): void => { validateInputs(context.inputs); const reduceOp: ReduceOp = (input, output) => [`var value = ${output.type.storage}(0);`, @@ -153,7 +154,7 @@ export const reduceL1 = (context: ComputeContext, attributes: ReduceAttributes): runReduceProgram(context, 'ReduceL1', attributes, reduceOp); }; -export const reduceL2 = (context: ComputeContext, attributes: ReduceAttributes): void => { +const reduceL2Naive = (context: ComputeContext, attributes: ReduceAttributes): void => { validateInputs(context.inputs); const reduceOp: ReduceOp = (input, output) => [`var t = ${output.type.value}(0); var value = ${output.type.value}(0);`, @@ -164,7 +165,7 @@ export const reduceL2 = (context: ComputeContext, attributes: ReduceAttributes): runReduceProgram(context, 'ReduceL2', attributes, reduceOp); }; -export const reduceLogSumExp = (context: ComputeContext, attributes: ReduceAttributes): void => { +const reduceLogSumExpNaive = (context: ComputeContext, attributes: ReduceAttributes): void => { validateInputs(context.inputs); const reduceOp: ReduceOp = (input, output) => [`var value = ${output.type.storage}(0);`, @@ -175,7 +176,7 @@ export const reduceLogSumExp = (context: ComputeContext, attributes: ReduceAttri runReduceProgram(context, 'ReduceLogSumExp', attributes, reduceOp); }; -export const reduceMax = (context: ComputeContext, attributes: ReduceAttributes): void => { +const reduceMaxNaive = (context: ComputeContext, attributes: ReduceAttributes): void => { validateInputs(context.inputs); const reduceOp: ReduceOp = (input, _output, axes) => { const idxZero = []; @@ -195,7 +196,7 @@ export const reduceMax = (context: ComputeContext, attributes: ReduceAttributes) runReduceProgram(context, 'ReduceMax', attributes, reduceOp); }; -export const reduceMean = (context: ComputeContext, attributes: 
ReduceAttributes): void => { +const reduceMeanNaive = (context: ComputeContext, attributes: ReduceAttributes): void => { validateInputs(context.inputs); const reduceOp: ReduceOp = (input, output, axes) => { let size = 1.0; @@ -216,7 +217,7 @@ export const reduceMean = (context: ComputeContext, attributes: ReduceAttributes runReduceProgram(context, 'ReduceMean', attributes, reduceOp); }; -export const reduceMin = (context: ComputeContext, attributes: ReduceAttributes): void => { +const reduceMinNaive = (context: ComputeContext, attributes: ReduceAttributes): void => { validateInputs(context.inputs); const reduceOp: ReduceOp = (input, _output, axes) => { const idxZero = []; @@ -236,7 +237,7 @@ export const reduceMin = (context: ComputeContext, attributes: ReduceAttributes) runReduceProgram(context, 'ReduceMin', attributes, reduceOp); }; -export const reduceProd = (context: ComputeContext, attributes: ReduceAttributes): void => { +const reduceProdNaive = (context: ComputeContext, attributes: ReduceAttributes): void => { validateInputs(context.inputs); const reduceOp: ReduceOp = (input, output) => [`var value = ${output.type.storage}(1);`, @@ -247,7 +248,7 @@ export const reduceProd = (context: ComputeContext, attributes: ReduceAttributes runReduceProgram(context, 'ReduceProd', attributes, reduceOp); }; -export const reduceSum = (context: ComputeContext, attributes: ReduceAttributes): void => { +const reduceSumNaive = (context: ComputeContext, attributes: ReduceAttributes): void => { validateInputs(context.inputs); const reduceOp: ReduceOp = (input, output) => [`var value = ${output.type.storage}(0);`, @@ -258,7 +259,7 @@ export const reduceSum = (context: ComputeContext, attributes: ReduceAttributes) runReduceProgram(context, 'ReduceSum', attributes, reduceOp); }; -export const reduceSumSquare = (context: ComputeContext, attributes: ReduceAttributes): void => { +const reduceSumSquareNaive = (context: ComputeContext, attributes: ReduceAttributes): void => { validateInputs(context.inputs); const reduceOp: ReduceOp = (input, output) => [`var t = ${output.type.value}(0); var value = ${output.type.value}(0);`, @@ -269,5 +270,107 @@ export const reduceSumSquare = (context: ComputeContext, attributes: ReduceAttri runReduceProgram(context, 'ReduceSumSquare', attributes, reduceOp); }; +const useNaiveReduceMethod = + (shape: readonly number[], axes: readonly number[], noopWithEmptyAxes: boolean): boolean => { + if (axes.length === 0) { + return noopWithEmptyAxes ? true : false; + } + + let outputSize = 1; + let reduceSize = 1; + for (let dim = 0; dim < axes.length; dim++) { + if (axes.indexOf(dim) === -1) { + outputSize *= shape[dim]; + } else { + reduceSize *= shape[dim]; + } + } + + // The condition data is very rough, although considering the count of Execution Unit (EU), the potential + // work groups in a EU and the counts of loops in the naive and shared methods, also doing experiments + // on some machines. + return reduceSize < 32 && outputSize > 1024 ? 
true : false; + }; + +export const reduceMean = (context: ComputeContext, attributes: ReduceAttributes): void => { + if (useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) { + reduceMeanNaive(context, attributes); + } else { + reduceMeanShared(context, attributes); + } +}; + +export const reduceL1 = (context: ComputeContext, attributes: ReduceAttributes): void => { + if (useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) { + reduceL1Naive(context, attributes); + } else { + reduceL1Shared(context, attributes); + } +}; + +export const reduceL2 = (context: ComputeContext, attributes: ReduceAttributes): void => { + if (useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) { + reduceL2Naive(context, attributes); + } else { + reduceL2Shared(context, attributes); + } +}; + +export const reduceLogSumExp = (context: ComputeContext, attributes: ReduceAttributes): void => { + if (useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) { + reduceLogSumExpNaive(context, attributes); + } else { + reduceLogSumExpShared(context, attributes); + } +}; + +export const reduceMax = (context: ComputeContext, attributes: ReduceAttributes): void => { + if (useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) { + reduceMaxNaive(context, attributes); + } else { + reduceMaxShared(context, attributes); + } +}; + +export const reduceMin = (context: ComputeContext, attributes: ReduceAttributes): void => { + if (useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) { + reduceMinNaive(context, attributes); + } else { + reduceMinShared(context, attributes); + } +}; + +export const reduceProd = (context: ComputeContext, attributes: ReduceAttributes): void => { + if (useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) { + reduceProdNaive(context, attributes); + } else { + reduceProdShared(context, attributes); + } +}; + +export const reduceSum = (context: ComputeContext, attributes: ReduceAttributes): void => { + if (useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) { + reduceSumNaive(context, attributes); + } else { + reduceSumShared(context, attributes); + } +}; + +export const reduceSumSquare = (context: ComputeContext, attributes: ReduceAttributes): void => { + if (useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) { + reduceSumSquareNaive(context, attributes); + } else { + reduceSumSquareShared(context, attributes); + } +}; + +export const reduceLogSum = (context: ComputeContext, attributes: ReduceAttributes): void => { + if (useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) { + reduceLogSumNaive(context, attributes); + } else { + reduceLogSumShared(context, attributes); + } +}; + export const parseReduceAttributes = (attributes: Record): ReduceAttributes => createAttributeWithCacheKey(attributes as Omit); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts index fed1dbcf51e9b..07cfefb8f191b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts @@ -454,6 +454,7 @@ const createResizeProgramInfo = const noScale = inputShape.length === outputShape.length && inputShape.every((d, i) => d === outputShape[i]); const 
useExtrapolation = attributes.coordinateTransformMode === 'tf_crop_and_resize'; const getShaderSource = (shaderHelper: ShaderHelper) => ` + ${noScale ? '' : ` ${getOriginalCoordinateFromResizedCoordinate(attributes.coordinateTransformMode)}; ${(() => { switch (attributes.mode) { @@ -483,23 +484,22 @@ const createResizeProgramInfo = throw Error('Invalid resize mode'); } })()}; + `} ${shaderHelper.declareVariables(input, output)} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} - if (${noScale}) { - output[global_idx] = input[global_idx]; - } else { - let outputIndices = ${output.offsetToIndices('global_idx')}; - var inputIndices: ${input.type.indices}; - ${(() => { + ${noScale ? 'output[global_idx] = input[global_idx];' : ` + let outputIndices = ${output.offsetToIndices('global_idx')}; + var inputIndices: ${input.type.indices}; + ${(() => { switch (attributes.mode) { case 'nearest': return `inputIndices = calculateInputIndicesFromOutputIndices(outputIndices); - if (checkInputIndices(inputIndices)) { - output[global_idx] = input[${input.indicesToOffset('inputIndices')}]; - } else { - output[global_idx] = ${attributes.extrapolationValue}; - }`; + if (checkInputIndices(inputIndices)) { + output[global_idx] = input[${input.indicesToOffset('inputIndices')}]; + } else { + output[global_idx] = ${attributes.extrapolationValue}; + }`; case 'linear': return 'output[global_idx] = bilinearInterpolation(outputIndices);'; case 'cubic': @@ -508,14 +508,14 @@ const createResizeProgramInfo = throw Error(`Unsupported resize mode: ${attributes.mode}`); } })()}; - } + `} }`; return { name: 'Resize', shaderCache: { hint: `${attributes.cacheKey}|${opsetVersion}|${scales.length > 0 ? scales : ''}|${ - sizes.length > 0 ? sizes : ''}` + sizes.length > 0 ? sizes : ''}|${noScale}` }, getShaderSource, getRunData: () => ({ diff --git a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts index 75e6a84cd6fcd..7e500f865c19b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts @@ -7,7 +7,7 @@ import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, ProgramInfo} from '../types'; -import {ShaderHelper, tensorTypeToWsglStorageType} from './common'; +import {castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType,} from './common'; export interface SkipLayerNormAttributes extends AttributeWithCacheKey { epsilon: number; @@ -18,9 +18,6 @@ const validateInputs = (inputs: readonly TensorView[]): void => { throw new Error('layerNorm requires at least 3 inputs.'); } - if (inputs[0].dataType !== DataType.float || inputs[1].dataType !== DataType.float) { - throw new Error('inputs should be float type'); - } const input: TensorView = inputs[0]; const skip: TensorView = inputs[1]; const gamma: TensorView = inputs[2]; @@ -74,85 +71,93 @@ const validateInputs = (inputs: readonly TensorView[]): void => { }; const createSkipLayerNormProgramInfo = - (inputs: readonly TensorView[], attributes: SkipLayerNormAttributes, outputCount: number, - isTraining: boolean): ProgramInfo => { - const inputShape = inputs[0].dims; - const inputSize = ShapeUtil.size(inputShape); - const outputShape = inputShape; - const outputSize = inputSize; - const hiddenSize = inputShape.slice(-1)[0]; - const meanInvStdDevDim = isTraining ? 
inputShape.slice(0, -1).concat(1) : []; - const hasBetaInput = inputs.length > 3; - const hasBiasInput = inputs.length > 4; - const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); - const hasMeanOutput = isTraining && outputCount > 1; - const hasInvStdDevOutput = isTraining && outputCount > 2; - const hasInputSkipBiasSumOutput = outputCount > 3; - let bindingNumber = 0; - const getShaderSource = (shaderHelper: ShaderHelper) => ` - const hiddenSize: u32 = ${hiddenSize}; + (inputs: readonly TensorView[], attributes: SkipLayerNormAttributes, outputCount: number, isTraining: boolean): + ProgramInfo => { + const inputShape = inputs[0].dims; + const inputSize = ShapeUtil.size(inputShape); + const outputShape = inputShape; + const outputSize = inputSize; + const hiddenSize = inputShape.slice(-1)[0]; + const meanInvStdDevDim = isTraining ? inputShape.slice(0, -1).concat(1) : []; + const hasBetaInput = inputs.length > 3; + const hasBiasInput = inputs.length > 4; + const hasMeanOutput = isTraining && outputCount > 1; + const hasInvStdDevOutput = isTraining && outputCount > 2; + const hasInputSkipBiasSumOutput = outputCount > 3; + + const components = getMaxComponents(hiddenSize); + const variables = [ + inputVariable('x', inputs[0].dataType, inputs[0].dims, components), + inputVariable('skip', inputs[1].dataType, inputs[1].dims, components), + inputVariable('gamma', inputs[2].dataType, inputs[2].dims, components), + ]; + if (hasBetaInput) { + variables.push(inputVariable('beta', inputs[3].dataType, inputs[3].dims, components)); + } + if (hasBiasInput) { + variables.push(inputVariable('bias', inputs[4].dataType, inputs[4].dims, components)); + } + variables.push(outputVariable('output', inputs[0].dataType, outputShape, components)); + if (hasMeanOutput) { + variables.push(outputVariable('meanOutput', DataType.float, meanInvStdDevDim)); + } + if (hasInvStdDevOutput) { + variables.push(outputVariable('invStdOutput', DataType.float, meanInvStdDevDim)); + } + if (hasInputSkipBiasSumOutput) { + variables.push(outputVariable('inputSkipBiasSum', inputs[0].dataType, outputShape, components)); + } + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + const getShaderSource = (shaderHelper: ShaderHelper) => ` + const hiddenSize: f32 = ${hiddenSize}; + const hiddenSizeVectorized: u32 = ${hiddenSize / components}; const epsilon: f32 = ${attributes.epsilon}; - @group(0) @binding(${bindingNumber++}) var x : array<${dataType}>; - @group(0) @binding(${bindingNumber++}) var skip : array<${dataType}>; - @group(0) @binding(${bindingNumber++}) var gamma : array<${dataType}>; - ${hasBetaInput ? `@group(0) @binding(${bindingNumber++}) var beta : array<${dataType}>;` : ''} - ${hasBiasInput ? `@group(0) @binding(${bindingNumber++}) var bias : array<${dataType}>;` : ''} - @group(0) @binding(${bindingNumber++}) var output : array<${dataType}>; - ${ - hasMeanOutput ? - `@group(0) @binding(${bindingNumber++}) var meanOutput : array<${dataType}>;` : - ''} - ${ - hasInvStdDevOutput ? - `@group(0) @binding(${bindingNumber++}) var invStdOutput : array<${dataType}>;` : - ''} - ${ - hasInputSkipBiasSumOutput ? 
- `@group(0) @binding(${bindingNumber++}) var inputSkipBiasSum : array<${dataType}>;` : - ''} + ${shaderHelper.declareVariables(...variables)} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize / hiddenSize)} - let offset = global_idx * hiddenSize; - var sum: f32 = 0.0; - var squareSum: f32 = 0.0; - for (var i: u32 = 0; i < hiddenSize; i++) { + let offset = global_idx * hiddenSizeVectorized; + var sum = ${fillVector('f32', components)}; + var squareSum = ${fillVector('f32', components)}; + for (var i: u32 = 0; i < hiddenSizeVectorized; i++) { let skipValue = skip[offset + i]; let biasValue = ${hasBiasInput ? 'bias[i]' : '0.0'}; let inputValue = x[offset + i]; let value = inputValue + skipValue + biasValue; ${hasInputSkipBiasSumOutput ? 'inputSkipBiasSum[offset + i] = value;' : ''} output[offset + i] = value; - sum += value; - squareSum += value * value; + let f32Value = ${castToF32(dataType, components, 'value')}; + sum += f32Value; + squareSum += f32Value * f32Value; } - let mean: f32 = sum / f32(hiddenSize); - let variance: f32 = sqrt(squareSum / f32(hiddenSize) - mean * mean + epsilon); + let mean = ${sumVector('sum', components)} / hiddenSize; + let variance = sqrt(${sumVector('squareSum', components)} / hiddenSize - mean * mean + epsilon); ${hasMeanOutput ? 'meanOutput[global_idx] = mean;' : ''} ${hasInvStdDevOutput ? 'invStdOutput[global_idx] = 1.0 / variance;' : ''} - for (var i: u32 = 0; i < hiddenSize; i++) { - output[offset + i] = (output[offset + i] - mean) / variance * gamma[i] + ${hasBetaInput ? 'beta[i]' : '0.0'}; + for (var i: u32 = 0; i < hiddenSizeVectorized; i++) { + output[offset + i] = (output[offset + i] - ${dataType}(mean)) / ${dataType}(variance) * gamma[i] + + ${hasBetaInput ? 'beta[i]' : '0.0'}; } }`; - const outputs = [{dims: outputShape, dataType: inputs[0].dataType}]; - if (outputCount > 1) { - outputs.push({dims: meanInvStdDevDim, dataType: inputs[0].dataType}); - } - if (outputCount > 2) { - outputs.push({dims: meanInvStdDevDim, dataType: inputs[0].dataType}); - } - if (outputCount > 3) { - outputs.push({dims: inputShape, dataType: inputs[0].dataType}); - } - - return { - name: 'SkipLayerNormalization', - shaderCache: {hint: attributes.cacheKey}, - getShaderSource, - getRunData: () => ({outputs, dispatchGroup: {x: Math.ceil(outputSize / hiddenSize / 64)}}), - }; - }; + const outputs = [{dims: outputShape, dataType: inputs[0].dataType}]; + if (outputCount > 1) { + outputs.push({dims: meanInvStdDevDim, dataType: DataType.float}); + } + if (outputCount > 2) { + outputs.push({dims: meanInvStdDevDim, dataType: DataType.float}); + } + if (outputCount > 3) { + outputs.push({dims: inputShape, dataType: inputs[0].dataType}); + } + + return { + name: 'SkipLayerNormalization', + shaderCache: {hint: attributes.cacheKey}, + getShaderSource, + getRunData: () => ({outputs, dispatchGroup: {x: Math.ceil(outputSize / hiddenSize / 64)}}), + }; + }; export const skipLayerNorm = (context: ComputeContext, attributes: SkipLayerNormAttributes): void => { // TODO: initialize isTraining from ComputeContext diff --git a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts index d4dbad79e613e..ec651ce34e8c3 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts @@ -10,7 +10,7 @@ import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, ProgramInfo} from '../types'; 
-import {ShaderHelper, tensorTypeToWsglStorageType} from './common'; +import {getMaxComponents, ShaderHelper, sumVector, tensorTypeToWsglStorageType} from './common'; const validateInputs = (inputs: readonly TensorView[]): void => { if (!inputs || inputs.length !== 1) { @@ -37,23 +37,39 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut const cols = shape[axis]; const rows = outputSize / cols; + const components = getMaxComponents(cols); + const packedCols = cols / components; + const valueType = components === 1 ? dataType : `vec${components}<${dataType}>`; + + const maxVector = (name: string, components: number) => { + if (components === 4) { + return `max(max(${name}.x, ${name}.y), max(${name}.z, ${name}.w))`; + } else if (components === 2) { + return `max(${name}.x, ${name}.y)`; + } else if (components === 3) { + return `max(max(${name}.x, ${name}.y), ${name}.z)`; + } + + return name; + }; // 6.2.4 in wgsl spec - const threadMaxDecl = dataType === 'f32' ? 'var threadMax: f32 = -3.402823e+38f;' : 'var threadMax: f16 = -65504.0h;'; + const threadMaxDecl = + dataType === 'f32' ? `var threadMax = ${valueType}(-3.402823e+38f);` : `var threadMax = ${valueType}(-65504.0h);`; const getShaderSource = (_shaderHelper: ShaderHelper) => ` - var rowMaxShared : ${dataType}; - var rowSumShared : ${dataType}; - var threadShared : array<${dataType}, ${WG}>; + var rowMaxShared : ${valueType}; + var rowSumShared : ${valueType}; + var threadShared : array<${valueType}, ${WG}>; - @group(0) @binding(0) var x : array<${dataType}>; - @group(0) @binding(1) var result : array<${dataType}>; + @group(0) @binding(0) var x : array<${valueType}>; + @group(0) @binding(1) var result : array<${valueType}>; - fn getValue(row: i32, col: i32, row_stride: i32) -> ${dataType} { + fn getValue(row: i32, col: i32, row_stride: i32) -> ${valueType} { let index = row * row_stride + col; return x[index]; } - fn setValue(row: i32, col: i32, row_stride: i32, value: ${dataType}) { + fn setValue(row: i32, col: i32, row_stride: i32, value: ${valueType}) { let index = row * row_stride + col; result[index] = value; } @@ -64,8 +80,8 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut let lindex = i32(local_id.x); const wg = ${WG}; let row = gindex / wg; - let cols = ${cols}; - let row_stride : i32 = ${cols}; + let cols = ${packedCols}; + let row_stride : i32 = ${packedCols}; // find the rows max ${threadMaxDecl} @@ -87,12 +103,12 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut workgroupBarrier(); } if (lindex == 0) { - rowMaxShared = threadShared[0]; + rowMaxShared = ${valueType}(${maxVector('threadShared[0]', components)}); } workgroupBarrier(); // find the rows sum - var threadSum: ${dataType} = 0.0; + var threadSum = ${valueType}(0.0); for (var col = lindex; col < cols; col += wg) { let subExp = exp(getValue(row, col, row_stride) - rowMaxShared); threadSum += subExp; @@ -107,7 +123,7 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut workgroupBarrier(); } if (lindex == 0) { - rowSumShared = threadShared[0]; + rowSumShared = ${valueType}(${sumVector('threadShared[0]', components)}); } workgroupBarrier(); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts index fe556a7fd8552..c4d43e9f466f5 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts @@ -6,7 +6,7 @@ import {ShapeUtil} from '../../util'; 
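// Illustrative sketch, not taken from the change above: the softmax and layer-norm edits
// rely on a component count that must evenly divide the reduced dimension so rows can be
// read as vec2/vec4 instead of scalars. Below is a standalone sketch of that choice and of
// the horizontal add used to fold a vector accumulator back to a scalar, assuming the
// getMaxComponents helper prefers 4, then 2, then 1; the names here are illustrative only.
const pickComponents = (size: number): 1 | 2 | 4 =>
    size % 4 === 0 ? 4 : size % 2 === 0 ? 2 : 1;

// builds a WGSL expression analogous to the sumVector helper referenced in these diffs
const horizontalSum = (name: string, components: number): string => {
  switch (components) {
    case 4: return `(${name}.x + ${name}.y + ${name}.z + ${name}.w)`;
    case 2: return `(${name}.x + ${name}.y)`;
    default: return name;
  }
};
// e.g. pickComponents(768) === 4, so a 768-wide row is loaded as 192 vec4 values and
// horizontalSum('sum', 4) collapses the per-lane partial sums into the row total.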
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, ProgramInfo} from '../types'; -import {createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, enableShapesUniforms, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; export interface TransposeAttributes extends AttributeWithCacheKey { readonly perm: number[]; @@ -35,13 +35,18 @@ const permFunctionBody = (perm: number[], rank: number, input: IndicesHelper, ou return reverseFunc.join('\n'); }; -export const createTransposeProgramInfo = - (inputDataType: number, inputRank: number, permAttr: number[]): ProgramInfo => { - const perm = getAdjustedPerm(inputRank, permAttr); - const output = outputVariable('output', inputDataType, (permAttr && permAttr.length) || inputRank); - const input = inputVariable('a', inputDataType, inputRank); +export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: number[]): ProgramInfo => { + const inputDataType = inputTensor.dataType; + const inputRank = inputTensor.dims.length; + const perm = getAdjustedPerm(inputRank, permAttr); + const useShapesUniforms = enableShapesUniforms(inputRank); + const outputShape = getOutputShape(inputTensor.dims, perm); + const outShapeOrRank = useShapesUniforms ? outputShape.length : outputShape; + const inShapeOrRank = useShapesUniforms ? inputRank : inputTensor.dims; + const output = outputVariable('output', inputDataType, outShapeOrRank); + const input = inputVariable('a', inputDataType, inShapeOrRank); - const getShaderSource = (shaderHelper: ShaderHelper) => ` + const getShaderSource = (shaderHelper: ShaderHelper) => ` ${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)} ${permFunctionBody(perm, inputRank, input, output)} @@ -54,30 +59,32 @@ export const createTransposeProgramInfo = ${output.setByOffset('global_idx', input.getByIndices('aIndices'))} }`; + return { + name: 'Transpose', + shaderCache: {hint: `${permAttr}`, inputDependencies: useShapesUniforms ? ['rank'] : ['dims']}, + getRunData: (inputs) => { + const outputSize = ShapeUtil.size(outputShape); return { - name: 'Transpose', - shaderCache: {hint: `${permAttr}`, inputDependencies: ['rank']}, - getRunData: (inputs) => { - const outputShape = getOutputShape(inputs[0].dims, perm); - const outputSize = ShapeUtil.size(outputShape); - return { - outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, - programUniforms: [ + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms: useShapesUniforms ? 
+ [ {type: 'uint32', data: outputSize}, ...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(outputShape), + ] : + [ + {type: 'uint32', data: outputSize}, ], - }; - }, - getShaderSource, }; - }; + }, + getShaderSource, + }; +}; export const transpose = (context: ComputeContext, attributes: TransposeAttributes): void => { validateInputs(context.inputs); - context.compute( - createTransposeProgramInfo(context.inputs[0].dataType, context.inputs[0].dims.length, attributes.perm)); + context.compute(createTransposeProgramInfo(context.inputs[0], attributes.perm)); }; export const parseTransposeAttributes = (attributes: Record): TransposeAttributes => diff --git a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts index bead3e72f63c7..4238449f9246f 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts @@ -29,12 +29,12 @@ const createElementwiseProgramShader = const output = outputVariable('outputData', outputDataType, [vecSize], 4); return ` - ${shaderHelper.declareVariables(input, output)} + ${shaderHelper.registerUniform('vec_size', 'u32').declareVariables(input, output)} ${additionalImplementation ?? ''} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(vecSize)} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.vec_size')} let a = ${input.getByOffset('global_idx')}; ${output.setByOffset('global_idx', expression)} @@ -45,13 +45,16 @@ const createElementwiseProgramInfo = (input: TensorView, name: string, funcCall: ElementwiseFunctionCall, additionalImplementation?: string, cacheKey?: string, outputDataType: number = input.dataType): ProgramInfo => ({ name, - shaderCache: {hint: cacheKey}, + shaderCache: {hint: cacheKey, inputDependencies: ['type']}, getShaderSource: shaderHelper => createElementwiseProgramShader( shaderHelper, ShapeUtil.size(input.dims), input.dataType, outputDataType, funcCall, additionalImplementation), getRunData: (inputTensors) => ({ outputs: [{dims: input.dims, dataType: outputDataType}], dispatchGroup: - {x: Math.ceil(ShapeUtil.size(inputTensors[0].dims) / 64 /* workgroup size */ / 4 /* vec size */)} + {x: Math.ceil(ShapeUtil.size(inputTensors[0].dims) / 64 /* workgroup size */ / 4 /* vec size */)}, + programUniforms: [ + {type: 'uint32', data: Math.ceil(ShapeUtil.size(input.dims) / 4)}, + ], }) }); diff --git a/js/web/lib/wasm/jsep/webgpu/program-manager.ts b/js/web/lib/wasm/jsep/webgpu/program-manager.ts index 5c5a07d90d34a..341e6edf26cc8 100644 --- a/js/web/lib/wasm/jsep/webgpu/program-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/program-manager.ts @@ -38,14 +38,6 @@ export class ProgramManager { const device = this.backend.device; const computePassEncoder = this.backend.getComputePassEncoder(); - const profilingEnabled = this.backend.supportTimestampQuery && this.backend.env.webgpu.profilingMode === 'default'; - if (profilingEnabled) { - // profiling write start timestamp - - // eslint-disable-next-line @typescript-eslint/no-explicit-any - (computePassEncoder as any).writeTimestamp(this.backend.profilingQuerySet, 0); - } - computePassEncoder.setPipeline(buildArtifact.computePipeline); const entries = []; for (const input of inputs) { @@ -65,24 +57,20 @@ export class ProgramManager { this.backend.pendingDispatchNumber++; - if (profilingEnabled) { - // profiling write end timestamp - - // eslint-disable-next-line @typescript-eslint/no-explicit-any - (computePassEncoder as 
any).writeTimestamp(this.backend.profilingQuerySet, 1); - if (this.backend.profilingQueryData == null) { - this.backend.profilingQueryData = + if (this.backend.isQueryEnabled()) { + if (typeof this.backend.queryData === 'undefined') { + this.backend.queryData = this.backend.gpuDataManager.create( // eslint-disable-next-line no-bitwise - this.backend.gpuDataManager.create(16, GPUBufferUsage.COPY_SRC | GPUBufferUsage.QUERY_RESOLVE); + this.backend.querySetCount * 8, GPUBufferUsage.COPY_SRC | GPUBufferUsage.QUERY_RESOLVE); } - // eslint-disable-next-line no-bitwise - const syncData = this.backend.gpuDataManager.create(16, GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST); + const syncData = this.backend.gpuDataManager.create( + // eslint-disable-next-line no-bitwise + this.backend.querySetCount * 8, GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST); this.backend.endComputePass(); - this.backend.getCommandEncoder().resolveQuerySet( - this.backend.profilingQuerySet, 0, 2, this.backend.profilingQueryData.buffer, 0); + this.backend.getCommandEncoder().resolveQuerySet(this.backend.querySet, 0, 2, this.backend.queryData.buffer, 0); this.backend.getCommandEncoder().copyBufferToBuffer( - this.backend.profilingQueryData.buffer, 0, syncData.buffer, 0, 16); + this.backend.queryData.buffer, 0, syncData.buffer, 0, this.backend.querySetCount * 8); this.backend.flush(); const kernelId = this.backend.currentKernelId!; @@ -96,12 +84,12 @@ export class ProgramManager { syncData.buffer.unmap(); - if (typeof this.backend.profilingTimeBase === 'undefined') { - this.backend.profilingTimeBase = startTimeU64; + if (typeof this.backend.queryTimeBase === 'undefined') { + this.backend.queryTimeBase = startTimeU64; } - const startTime = Number(startTimeU64 - this.backend.profilingTimeBase); - const endTime = Number(endTimeU64 - this.backend.profilingTimeBase); + const startTime = Number(startTimeU64 - this.backend.queryTimeBase); + const endTime = Number(endTimeU64 - this.backend.queryTimeBase); if (!Number.isSafeInteger(startTime) || !Number.isSafeInteger(endTime)) { throw new RangeError('incorrect timestamp range'); diff --git a/js/web/lib/wasm/proxy-messages.ts b/js/web/lib/wasm/proxy-messages.ts index 7aa866773bcb1..efeb086256cf3 100644 --- a/js/web/lib/wasm/proxy-messages.ts +++ b/js/web/lib/wasm/proxy-messages.ts @@ -73,5 +73,10 @@ interface MesssageEndProfiling extends MessageError { in ?: number; } +interface MessageIsOrtEnvInitialized extends MessageError { + type: 'is-ort-env-initialized'; + out?: boolean; +} + export type OrtWasmMessage = MessageInitWasm|MessageInitOrt|MessageCreateSessionAllocate|MessageCreateSessionFinalize| - MessageCreateSession|MessageReleaseSession|MessageRun|MesssageEndProfiling; + MessageCreateSession|MessageReleaseSession|MessageRun|MesssageEndProfiling|MessageIsOrtEnvInitialized; diff --git a/js/web/lib/wasm/proxy-worker/main.ts b/js/web/lib/wasm/proxy-worker/main.ts index fe8bd9b11b191..1f4595818e5c0 100644 --- a/js/web/lib/wasm/proxy-worker/main.ts +++ b/js/web/lib/wasm/proxy-worker/main.ts @@ -4,7 +4,7 @@ /// import {OrtWasmMessage} from '../proxy-messages'; -import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, extractTransferableBuffers, initRuntime, releaseSession, run} from '../wasm-core-impl'; +import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, extractTransferableBuffers, initRuntime, isOrtEnvInitialized, releaseSession, run} from '../wasm-core-impl'; import {initializeWebAssembly} from '../wasm-factory'; 
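// Illustrative sketch, not taken from the change above: the new 'is-ort-env-initialized'
// message follows the same request/response convention as the other proxy messages: the
// main thread queues a promise's resolve/reject handlers, posts the request to the worker,
// and the worker's reply settles the oldest pending promise for that message type. The
// names below are hypothetical, not the library's API.
type PendingCallbacks<T> = [resolve: (value: T) => void, reject: (reason: unknown) => void];

class ProxyRequestQueue<T> {
  private readonly pending: Array<PendingCallbacks<T>> = [];

  // wrap a postMessage call in a promise and remember how to settle it later
  request(post: () => void): Promise<T> {
    return new Promise<T>((resolve, reject) => {
      this.pending.push([resolve, reject]);
      post();  // e.g. proxyWorker.postMessage({type: 'is-ort-env-initialized'})
    });
  }

  // invoked from the onmessage handler when a reply of this type arrives
  settle(out?: T, err?: unknown): void {
    const [resolve, reject] = this.pending.shift()!;
    if (err) {
      reject(err);
    } else {
      resolve(out as T);
    }
  }
}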
self.onmessage = (ev: MessageEvent): void => { @@ -89,6 +89,14 @@ self.onmessage = (ev: MessageEvent): void => { postMessage({type: 'end-profiling', err} as OrtWasmMessage); } break; + case 'is-ort-env-initialized': + try { + const ortEnvInitialized = isOrtEnvInitialized(); + postMessage({type: 'is-ort-env-initialized', out: ortEnvInitialized} as OrtWasmMessage); + } catch (err) { + postMessage({type: 'is-ort-env-initialized', err} as OrtWasmMessage); + } + break; default: } }; diff --git a/js/web/lib/wasm/proxy-wrapper.ts b/js/web/lib/wasm/proxy-wrapper.ts index a3e4a1ef1fc75..069a1fa452dbc 100644 --- a/js/web/lib/wasm/proxy-wrapper.ts +++ b/js/web/lib/wasm/proxy-wrapper.ts @@ -24,6 +24,7 @@ const createSessionCallbacks: Array> = []; const runCallbacks: Array> = []; const endProfilingCallbacks: Array> = []; +const isOrtEnvInitializedCallbacks: Array> = []; const ensureWorker = (): void => { if (initializing || !initialized || aborted || !proxyWorker) { @@ -92,6 +93,13 @@ const onProxyWorkerMessage = (ev: MessageEvent): void => { endProfilingCallbacks.shift()![0](); } break; + case 'is-ort-env-initialized': + if (ev.data.err) { + isOrtEnvInitializedCallbacks.shift()![1](ev.data.err); + } else { + isOrtEnvInitializedCallbacks.shift()![0](ev.data.out!); + } + break; default: } }; @@ -251,3 +259,16 @@ export const endProfiling = async(sessionId: number): Promise => { core.endProfiling(sessionId); } }; + +export const isOrtEnvInitialized = async(): Promise => { + if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) { + ensureWorker(); + return new Promise((resolve, reject) => { + isOrtEnvInitializedCallbacks.push([resolve, reject]); + const message: OrtWasmMessage = {type: 'is-ort-env-initialized'}; + proxyWorker!.postMessage(message); + }); + } else { + return core.isOrtEnvInitialized(); + } +}; diff --git a/js/web/lib/wasm/session-handler.ts b/js/web/lib/wasm/session-handler-inference.ts similarity index 93% rename from js/web/lib/wasm/session-handler.ts rename to js/web/lib/wasm/session-handler-inference.ts index d1760e37c93f7..3ca34d957c572 100644 --- a/js/web/lib/wasm/session-handler.ts +++ b/js/web/lib/wasm/session-handler-inference.ts @@ -5,13 +5,12 @@ import {readFile} from 'node:fs/promises'; import {env, InferenceSession, InferenceSessionHandler, SessionHandler, Tensor} from 'onnxruntime-common'; import {SerializableModeldata, TensorMetadata} from './proxy-messages'; -import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, initializeRuntime, releaseSession, run} from './proxy-wrapper'; +import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, initializeRuntime, isOrtEnvInitialized, releaseSession, run} from './proxy-wrapper'; import {isGpuBufferSupportedType} from './wasm-common'; -let runtimeInitialized: boolean; let runtimeInitializationPromise: Promise|undefined; -const encodeTensorMetadata = (tensor: Tensor, getName: () => string): TensorMetadata => { +export const encodeTensorMetadata = (tensor: Tensor, getName: () => string): TensorMetadata => { switch (tensor.location) { case 'cpu': return [tensor.type, tensor.dims, tensor.data, 'cpu']; @@ -22,7 +21,7 @@ const encodeTensorMetadata = (tensor: Tensor, getName: () => string): TensorMeta } }; -const decodeTensorMetadata = (tensor: TensorMetadata): Tensor => { +export const decodeTensorMetadata = (tensor: TensorMetadata): Tensor => { switch (tensor[3]) { case 'cpu': return new Tensor(tensor[0], tensor[2], tensor[1]); @@ -57,13 +56,12 @@ export class 
OnnxruntimeWebAssemblySessionHandler implements InferenceSessionHan } async loadModel(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): Promise { - if (!runtimeInitialized) { + if (!(await isOrtEnvInitialized())) { if (!runtimeInitializationPromise) { runtimeInitializationPromise = initializeRuntime(env); } await runtimeInitializationPromise; runtimeInitializationPromise = undefined; - runtimeInitialized = true; } if (typeof pathOrBuffer === 'string') { diff --git a/js/web/lib/wasm/session-handler-training.ts b/js/web/lib/wasm/session-handler-training.ts new file mode 100644 index 0000000000000..09d91591128d1 --- /dev/null +++ b/js/web/lib/wasm/session-handler-training.ts @@ -0,0 +1,131 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {env, InferenceSession, SessionHandler, Tensor, TrainingSessionHandler} from 'onnxruntime-common'; + +import {SerializableModeldata, TensorMetadata} from './proxy-messages'; +import {decodeTensorMetadata, encodeTensorMetadata} from './session-handler-inference'; +import {createSessionAllocate, initRuntime, isOrtEnvInitialized} from './wasm-core-impl'; +import {createCheckpointHandle, createTrainingSessionHandle, releaseTrainingSessionAndCheckpoint, runTrainStep} from './wasm-training-core-impl'; + +export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler { + async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise { + throw new Error('Method not implemented.'); + } + async getContiguousParameters(_trainableOnly: boolean): Promise { + throw new Error('Method not implemented.'); + } + private sessionId: number; + private checkpointId: number; + + inputNames: string[]; + outputNames: string[]; + + inputEncodedNames: number[]; + outputEncodedNames: number[]; + + async uriOrBufferToHeap(uriOrBuffer: string|Uint8Array): Promise { + let buffer: Uint8Array; + if (typeof uriOrBuffer === 'string') { + const response = await fetch(uriOrBuffer); + const arrayBuffer = await response.arrayBuffer(); + buffer = new Uint8Array(arrayBuffer); + } else { + buffer = uriOrBuffer; + } + return createSessionAllocate(buffer); + } + + async createTrainingSession( + checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: string|Uint8Array, + evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array, + options: InferenceSession.SessionOptions) { + if (!isOrtEnvInitialized()) { + await initRuntime(env); + } + const checkpointData: SerializableModeldata = await this.uriOrBufferToHeap(checkpointStateUriOrBuffer); + const trainModelData: SerializableModeldata = await this.uriOrBufferToHeap(trainModelUriOrBuffer); + // 0 is supposed to be the nullptr + let evalModelData: SerializableModeldata = [0, 0]; + let optimizerModelData: SerializableModeldata = [0, 0]; + + if (evalModelUriOrBuffer !== '') { + evalModelData = await this.uriOrBufferToHeap(evalModelUriOrBuffer); + } + if (optimizerModelUriOrBuffer !== '') { + optimizerModelData = await this.uriOrBufferToHeap(optimizerModelUriOrBuffer); + } + + this.checkpointId = createCheckpointHandle(checkpointData); + [[this.sessionId, this.inputNames, this.outputNames], this.inputEncodedNames, this.outputEncodedNames] = + createTrainingSessionHandle(this.checkpointId, trainModelData, evalModelData, optimizerModelData, options); + } + + /** + * Helper method that converts a feeds or fetches datatype to two arrays, one of values and one that stores the + * 
corresponding name as a number referring to the index in the list of names provided. + * + * @param feeds meant to match either SessionHandler.FeedsType or SessionHandler.FetchesType + * @param names either inputNames or outputNames + * @returns a tuple of a list of values and a list of indices. + */ + convertMapIntoValuesArrayAndIndicesArray( + feeds: {[name: string]: T}, names: string[], mapFunc: (val: T, index: number) => U): [T[], number[], U[]] { + const values: T[] = []; + const indices: number[] = []; + Object.entries(feeds).forEach(kvp => { + const name = kvp[0]; + const tensor = kvp[1]; + const index = names.indexOf(name); + if (index === -1) { + throw new Error(`invalid input '${name}`); + } + values.push(tensor); + indices.push(index); + }); + + const uList = values.map(mapFunc); + return [values, indices, uList]; + } + + /** + * Helper method that converts the TensorMetadata that the wasm-core functions return to the + * SessionHandler.ReturnType. Any outputs in the provided outputArray that are falsy will be populated with the + * corresponding result. + * + * @param results used to populate the resultMap if there is no value for that outputName already + * @param outputArray used to populate the resultMap. If null or undefined, use the corresponding result from results + * @param outputIndices specifies which outputName the corresponding value for outputArray refers to. + * @returns a map of output names and OnnxValues. + */ + convertTensorMetadataToReturnType( + results: TensorMetadata[], outputArray: Array, outputIndices: number[]): SessionHandler.ReturnType { + const resultMap: SessionHandler.ReturnType = {}; + for (let i = 0; i < results.length; i++) { + resultMap[this.outputNames[outputIndices[i]]] = outputArray[i] ?? decodeTensorMetadata(results[i]); + } + return resultMap; + } + + async runTrainStep( + feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType, + options: InferenceSession.RunOptions): Promise { + const [, inputIndices, inputs] = this.convertMapIntoValuesArrayAndIndicesArray( + feeds, this.inputNames, + (t, i): TensorMetadata => encodeTensorMetadata(t, () => `input "${this.inputNames[inputIndices[i]]}"`)); + + const [outputArray, outputIndices, outputs] = + this.convertMapIntoValuesArrayAndIndicesArray( + fetches, this.outputNames, + (t, i): TensorMetadata|null => + t ? encodeTensorMetadata(t, () => `output "${this.outputNames[outputIndices[i]]}"`) : null); + + const results = await runTrainStep(this.sessionId, inputIndices, inputs, outputIndices, outputs, options); + return this.convertTensorMetadataToReturnType(results, outputArray, outputIndices); + } + + async dispose(): Promise { + return releaseTrainingSessionAndCheckpoint( + this.checkpointId, this.sessionId, this.inputEncodedNames, this.outputEncodedNames); + } +} diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 5b49a1d4202e3..3aacf8f4d90e0 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -10,6 +10,8 @@ import {dataLocationStringToEnum, getTensorElementSize, isGpuBufferSupportedType import {getInstance} from './wasm-factory'; import {allocWasmString, checkLastError} from './wasm-utils'; +let ortEnvInitialized = false; + /** * get the input/output count of the session. * @param sessionHandle the handle representing the session. should be non-zero. 
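The training session handler introduced above keeps parallel arrays of tensor values and name indices so feeds and fetches can be handed to the wasm layer by position. A simplified sketch of that conversion, assuming only a name list and a feeds map (the helper name here is illustrative, not the diff's exact method):

    // Illustrative, simplified version of the feeds-to-indices conversion used by
    // the training handler; `names` is the session's inputNames (or outputNames).
    const toValuesAndIndices = <T>(feeds: {[name: string]: T}, names: string[]): [T[], number[]] => {
      const values: T[] = [];
      const indices: number[] = [];
      for (const [name, value] of Object.entries(feeds)) {
        const index = names.indexOf(name);
        if (index === -1) {
          throw new Error(`invalid input '${name}'`);
        }
        values.push(value);
        indices.push(index);
      }
      return [values, indices];
    };

    // Example: names = ['input', 'labels'] and feeds = {labels: t1, input: t0}
    // yields values [t1, t0] and indices [1, 0], so names can be looked back up by position.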
@@ -57,6 +59,8 @@ export const initRuntime = async(env: Env): Promise => { const initJsep = require('./jsep/init').init; await initJsep(getInstance(), env); } + + ortEnvInitialized = true; }; /** @@ -93,6 +97,8 @@ type SessionMetadata = [ const activeSessions = new Map(); +export const isOrtEnvInitialized = (): boolean => ortEnvInitialized; + /** * allocate the memory and memcpy the model bytes, preparing for creating an instance of InferenceSession. * @returns a 2-elements tuple - the pointer and size of the allocated buffer @@ -234,7 +240,7 @@ export const releaseSession = (sessionId: number): void => { activeSessions.delete(sessionId); }; -const prepareInputOutputTensor = +export const prepareInputOutputTensor = (tensor: TensorMetadata|null, tensorHandles: number[], allocs: number[], sessionId: number, index: number): void => { if (!tensor) { diff --git a/js/web/lib/wasm/wasm-training-core-impl.ts b/js/web/lib/wasm/wasm-training-core-impl.ts new file mode 100644 index 0000000000000..a35d285346db4 --- /dev/null +++ b/js/web/lib/wasm/wasm-training-core-impl.ts @@ -0,0 +1,329 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {InferenceSession, Tensor} from 'onnxruntime-common'; + +import {SerializableModeldata, SerializableSessionMetadata, TensorMetadata} from './proxy-messages'; +import {setRunOptions} from './run-options'; +import {setSessionOptions} from './session-options'; +import {tensorDataTypeEnumToString, tensorTypeToTypedArrayConstructor} from './wasm-common'; +import {prepareInputOutputTensor} from './wasm-core-impl'; +import {getInstance} from './wasm-factory'; +import {checkLastError} from './wasm-utils'; + +const NO_TRAIN_FUNCS_MSG = + 'Built without training API\'s enabled. Use the onnxruntime-web/training import for training ' + + 'functionality, and make sure that all the correct artifacts are built & moved to the correct folder if ' + + 'using a custom build. 
Check https://onnxruntime.ai/docs/build/web.html for more information.'; + +export const createCheckpointHandle = (checkpointData: SerializableModeldata): number => { + const wasm = getInstance(); + + const [checkpointDataOffset, checkpointDataLength] = checkpointData; + let checkpointHandle = 0; + + try { + if (wasm._OrtTrainingLoadCheckpoint) { + checkpointHandle = wasm._OrtTrainingLoadCheckpoint(checkpointDataOffset, checkpointDataLength); + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + + if (checkpointHandle === 0) { + checkLastError('Error occurred when trying to create a CheckpointState.'); + } + return checkpointHandle; + } catch (e) { + if (wasm._OrtTrainingReleaseCheckpoint && checkpointHandle !== 0) { + wasm._OrtTrainingReleaseCheckpoint(checkpointHandle); + } + throw e; + } finally { + // free buffer from wasm heap + wasm._OrtFree(checkpointData[0]); + } +}; + +const getModelInputOutputCount = (trainingSessionId: number, isEvalModel: boolean): [number, number] => { + const wasm = getInstance(); + const stack = wasm.stackSave(); + try { + const dataOffset = wasm.stackAlloc(8); + if (wasm._OrtTrainingGetModelInputOutputCount) { + const errorCode = + wasm._OrtTrainingGetModelInputOutputCount(trainingSessionId, dataOffset, dataOffset + 4, isEvalModel); + if (errorCode !== 0) { + checkLastError('Can\'t get session input/output count.'); + } + return [wasm.HEAP32[dataOffset / 4], wasm.HEAP32[dataOffset / 4 + 1]]; + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + } finally { + wasm.stackRestore(stack); + } +}; + +const getModelInputOutputNamesLoop = + (trainingSessionId: number, count: number, isInput: boolean, isEvalModel: boolean): [string[], number[]] => { + const names = []; + const wasm = getInstance(); + + const namesUTF8Encoded = []; + + for (let i = 0; i < count; i++) { + if (wasm._OrtTrainingGetModelInputOutputName) { + const name = wasm._OrtTrainingGetModelInputOutputName(trainingSessionId, i, isInput, isEvalModel); + if (name === 0) { + checkLastError('Can\'t get input or output name'); + } + + namesUTF8Encoded.push(name); + names.push(wasm.UTF8ToString(name)); + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + } + return [names, namesUTF8Encoded]; + }; + +const getTrainingModelInputOutputNames = (trainingSessionId: number): [string[], number[], string[], number[]] => { + const [inputCount, outputCount] = getModelInputOutputCount(trainingSessionId, false); + + const [inputNames, inputNamesUTF8Encoded] = getModelInputOutputNamesLoop(trainingSessionId, inputCount, true, false); + const [outputNames, outputNamesUTF8Encoded] = + getModelInputOutputNamesLoop(trainingSessionId, outputCount, false, false); + + return [inputNames, inputNamesUTF8Encoded, outputNames, outputNamesUTF8Encoded]; +}; + +export const createTrainingSessionHandle = + (checkpointHandle: number, trainModelData: SerializableModeldata, evalModelData: SerializableModeldata, + optimizerModelData: SerializableModeldata, + options: InferenceSession.SessionOptions): [SerializableSessionMetadata, number[], number[]] => { + const wasm = getInstance(); + + let trainingSessionHandle = 0; + let sessionOptionsHandle = 0; + let allocs: number[] = []; + let inputNamesUTF8Encoded: number[] = []; + let outputNamesUTF8Encoded: number[] = []; + + let inputNames: string[] = []; + let outputNames: string[] = []; + + try { + [sessionOptionsHandle, allocs] = setSessionOptions(options); + if (wasm._OrtTrainingCreateSession) { + trainingSessionHandle = wasm._OrtTrainingCreateSession( + sessionOptionsHandle, 
checkpointHandle, trainModelData[0], trainModelData[1], evalModelData[0], + evalModelData[1], optimizerModelData[0], optimizerModelData[1]); + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + + if (trainingSessionHandle === 0) { + checkLastError('Error occurred when trying to create a TrainingSession.'); + } + + [inputNames, inputNamesUTF8Encoded, outputNames, outputNamesUTF8Encoded] = + getTrainingModelInputOutputNames(trainingSessionHandle); + return [[trainingSessionHandle, inputNames, outputNames], inputNamesUTF8Encoded, outputNamesUTF8Encoded]; + + } catch (e) { + if (wasm._OrtTrainingReleaseSession && trainingSessionHandle !== 0) { + wasm._OrtTrainingReleaseSession(trainingSessionHandle); + } + throw e; + } finally { + wasm._free(trainModelData[0]); + wasm._free(evalModelData[0]); + wasm._free(optimizerModelData[0]); + + if (sessionOptionsHandle !== 0) { + wasm._OrtReleaseSessionOptions(sessionOptionsHandle); + } + allocs.forEach(alloc => wasm._free(alloc)); + inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); + outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); + } + }; + +/** + * Prepares input and output tensors by creating the tensors in the WASM side then creates a list of the handles of the + * WASM tensors. + * + * @param trainingSessionId + * @param indices for each tensor, the index of the input or output name that the tensor corresponds with + * @param tensors list of TensorMetaData + * @param tensorHandles should pass in an empty list of numbers; modified in-place by this method & stores the resulting + * handles of the allocated tensors on the heap + * @param inputOutputAllocs modified in-place by this method + * @param indexAdd constant to add to the index that is passed to prepareInputOutputTensor + */ +const createAndAllocateTensors = + (trainingSessionId: number, indices: number[], tensors: Array, tensorHandles: number[], + inputOutputAllocs: number[], indexAdd: number) => { + const count = indices.length; + + // creates the tensors + for (let i = 0; i < count; i++) { + prepareInputOutputTensor( + tensors[i], tensorHandles, inputOutputAllocs, trainingSessionId, indexAdd + indices[i]); + } + + // moves to heap + const wasm = getInstance(); + const valuesOffset = wasm.stackAlloc(count * 4); + let valuesIndex = valuesOffset / 4; + for (let i = 0; i < count; i++) { + wasm.HEAPU32[valuesIndex++] = tensorHandles[i]; + } + + return valuesOffset; + }; + +/** + * Retrieves the information from the output tensor handles, copies to an array, and frees the WASM information + * associated with the tensor handle. + * + * @param outputValuesOffset + * @param outputCount + * @returns list of TensorMetadata retrieved from the output handles. + */ +const moveOutputToTensorMetadataArr = + (outputValuesOffset: number, outputCount: number, outputTensorHandles: number[], + outputTensors: Array) => { + const wasm = getInstance(); + const output: TensorMetadata[] = []; + + for (let i = 0; i < outputCount; i++) { + const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i]; + if (tensor === outputTensorHandles[i]) { + // output tensor is pre-allocated. no need to copy data. 
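As the createAndAllocateTensors comment above notes, runTrainStep binds inputs and outputs into one flat index space: inputs are prepared with an offset of 0 and outputs with an offset of inputCount. A small TypeScript sketch of that convention (bindTensor is a hypothetical stand-in for prepareInputOutputTensor):

    // Hypothetical stand-in showing the indexAdd convention: inputs use offset 0,
    // outputs use offset inputCount, so both share one flat index space.
    const bindAllTensors =
        (inputIndices: number[], outputIndices: number[], bindTensor: (flatIndex: number) => void): void => {
          const inputCount = inputIndices.length;
          inputIndices.forEach(idx => bindTensor(idx));                 // indexAdd = 0
          outputIndices.forEach(idx => bindTensor(inputCount + idx));   // indexAdd = inputCount
        };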
+ output.push(outputTensors[i]!); + continue; + } + + const beforeGetTensorDataStack = wasm.stackSave(); + // stack allocate 4 pointer value + const tensorDataOffset = wasm.stackAlloc(4 * 4); + + let type: Tensor.Type|undefined, dataOffset = 0; + try { + const errorCode = wasm._OrtGetTensorData( + tensor, tensorDataOffset, tensorDataOffset + 4, tensorDataOffset + 8, tensorDataOffset + 12); + if (errorCode !== 0) { + checkLastError(`Can't access output tensor data on index ${i}.`); + } + let tensorDataIndex = tensorDataOffset / 4; + const dataType = wasm.HEAPU32[tensorDataIndex++]; + dataOffset = wasm.HEAPU32[tensorDataIndex++]; + const dimsOffset = wasm.HEAPU32[tensorDataIndex++]; + const dimsLength = wasm.HEAPU32[tensorDataIndex++]; + const dims = []; + for (let i = 0; i < dimsLength; i++) { + dims.push(wasm.HEAPU32[dimsOffset / 4 + i]); + } + wasm._OrtFree(dimsOffset); + + const size = dims.reduce((a, b) => a * b, 1); + type = tensorDataTypeEnumToString(dataType); + + if (type === 'string') { + const stringData: string[] = []; + let dataIndex = dataOffset / 4; + for (let i = 0; i < size; i++) { + const offset = wasm.HEAPU32[dataIndex++]; + const maxBytesToRead = i === size - 1 ? undefined : wasm.HEAPU32[dataIndex] - offset; + stringData.push(wasm.UTF8ToString(offset, maxBytesToRead)); + } + output.push([type, dims, stringData, 'cpu']); + } else { + const typedArrayConstructor = tensorTypeToTypedArrayConstructor(type); + const data = new typedArrayConstructor(size); + new Uint8Array(data.buffer, data.byteOffset, data.byteLength) + .set(wasm.HEAPU8.subarray(dataOffset, dataOffset + data.byteLength)); + output.push([type, dims, data, 'cpu']); + } + } finally { + wasm.stackRestore(beforeGetTensorDataStack); + if (type === 'string' && dataOffset) { + wasm._free(dataOffset); + } + wasm._OrtReleaseTensor(tensor); + } + } + + return output; + }; + +export const runTrainStep = async( + trainingSessionId: number, inputIndices: number[], inputTensors: TensorMetadata[], outputIndices: number[], + outputTensors: Array, options: InferenceSession.RunOptions): Promise => { + const wasm = getInstance(); + + const inputCount = inputIndices.length; + const outputCount = outputIndices.length; + + let runOptionsHandle = 0; + let runOptionsAllocs: number[] = []; + + const inputTensorHandles: number[] = []; + const outputTensorHandles: number[] = []; + const inputOutputAllocs: number[] = []; + + const beforeRunStack = wasm.stackSave(); + + try { + // prepare parameters by moving them to heap + [runOptionsHandle, runOptionsAllocs] = setRunOptions(options); + + // handle inputs -- you don't want anything added to the index + const inputValuesOffset = createAndAllocateTensors( + trainingSessionId, inputIndices, inputTensors, inputTensorHandles, inputOutputAllocs, 0); + // handle outputs + // you want inputCount to be added to the index of every output tensor passed to prepareInputOutputTensor + const outputValuesOffset = createAndAllocateTensors( + trainingSessionId, outputIndices, outputTensors, outputTensorHandles, inputOutputAllocs, inputCount); + + if (wasm._OrtTrainingRunTrainStep) { + const errorCode = wasm._OrtTrainingRunTrainStep( + trainingSessionId, inputValuesOffset, inputCount, outputValuesOffset, outputCount, runOptionsHandle); + + if (errorCode !== 0) { + checkLastError('failed to call OrtTrainingRunTrainStep in the WebAssembly layer'); + } + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + + return moveOutputToTensorMetadataArr(outputValuesOffset, outputCount, outputTensorHandles, 
outputTensors); + } finally { + wasm.stackRestore(beforeRunStack); + + inputTensorHandles.forEach(v => wasm._OrtReleaseTensor(v)); + outputTensorHandles.forEach(v => wasm._OrtReleaseTensor(v)); + inputOutputAllocs.forEach(p => wasm._free(p)); + + if (runOptionsHandle !== 0) { + wasm._OrtReleaseRunOptions(runOptionsHandle); + } + runOptionsAllocs.forEach(p => wasm._free(p)); + } +}; + +export const releaseTrainingSessionAndCheckpoint = + (checkpointId: number, sessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[]): + void => { + const wasm = getInstance(); + inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); + outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); + + if (wasm._OrtTrainingReleaseSession) { + wasm._OrtTrainingReleaseSession(sessionId); + } + if (wasm._OrtTrainingReleaseCheckpoint) { + wasm._OrtTrainingReleaseCheckpoint(checkpointId); + } + }; diff --git a/js/web/package.json b/js/web/package.json index 0bdc529df704d..7271fed99d709 100644 --- a/js/web/package.json +++ b/js/web/package.json @@ -20,7 +20,7 @@ }, "scripts": { "preprepare": "node -e \"require('node:fs').copyFileSync('./node_modules/long/index.d.ts', './node_modules/long/umd/index.d.ts')\"", - "prepare": "tsc", + "prepare": "tsc --build ./script", "build:doc": "node ./script/generate-webgl-operator-md && node ./script/generate-webgpu-operator-md", "pull:wasm": "node ./script/pull-prebuilt-wasm-artifacts", "test:e2e": "node ./test/e2e/run", @@ -68,14 +68,8 @@ ".": { "node": "./dist/ort.node.min.js", "default": { - "import": { - "development": "./dist/esm/ort.js", - "default": "./dist/esm/ort.min.js" - }, - "require": { - "development": "./dist/cjs/ort.js", - "default": "./dist/cjs/ort.min.js" - }, + "import": "./dist/esm/ort.min.js", + "require": "./dist/cjs/ort.min.js", "default": { "development": "./dist/ort.js", "default": "./dist/ort.min.js" @@ -83,88 +77,37 @@ } }, "./experimental": { - "import": { - "development": "./dist/esm/ort.all.js", - "default": "./dist/esm/ort.all.min.js" - }, - "require": { - "development": "./dist/cjs/ort.all.js", - "default": "./dist/cjs/ort.all.min.js" - }, + "import": "./dist/esm/ort.all.min.js", + "require": "./dist/cjs/ort.all.min.js", "default": { "development": "./dist/ort.all.js", "default": "./dist/ort.all.min.js" } }, "./wasm": { - "import": { - "development": "./dist/esm/ort.wasm.js", - "default": "./dist/esm/ort.wasm.min.js" - }, - "require": { - "development": "./dist/cjs/ort.wasm.js", - "default": "./dist/cjs/ort.wasm.min.js" - }, - "default": { - "development": "./dist/ort.wasm.js", - "default": "./dist/ort.wasm.min.js" - } + "import": "./dist/esm/ort.wasm.min.js", + "require": "./dist/cjs/ort.wasm.min.js", + "default": "./dist/ort.wasm.min.js" }, "./wasm-core": { - "import": { - "development": "./dist/esm/ort.wasm-core.js", - "default": "./dist/esm/ort.wasm-core.min.js" - }, - "require": { - "development": "./dist/cjs/ort.wasm-core.js", - "default": "./dist/cjs/ort.wasm-core.min.js" - }, - "default": { - "development": "./dist/ort.wasm-core.js", - "default": "./dist/ort.wasm-core.min.js" - } + "import": "./dist/esm/ort.wasm-core.min.js", + "require": "./dist/cjs/ort.wasm-core.min.js", + "default": "./dist/ort.wasm-core.min.js" }, "./webgl": { - "import": { - "development": "./dist/esm/ort.webgl.js", - "default": "./dist/esm/ort.webgl.min.js" - }, - "require": { - "development": "./dist/cjs/ort.webgl.js", - "default": "./dist/cjs/ort.webgl.min.js" - }, - "default": { - "development": "./dist/ort.webgl.js", - "default": 
"./dist/ort.webgl.min.js" - } + "import": "./dist/esm/ort.webgl.min.js", + "require": "./dist/cjs/ort.webgl.min.js", + "default": "./dist/ort.webgl.min.js" }, "./webgpu": { - "import": { - "development": "./dist/esm/ort.webgpu.js", - "default": "./dist/esm/ort.webgpu.min.js" - }, - "require": { - "development": "./dist/cjs/ort.webgpu.js", - "default": "./dist/cjs/ort.webgpu.min.js" - }, - "default": { - "development": "./dist/ort.webgpu.js", - "default": "./dist/ort.webgpu.min.js" - } + "import": "./dist/esm/ort.webgpu.min.js", + "require": "./dist/cjs/ort.webgpu.min.js", + "default": "./dist/ort.webgpu.min.js" }, "./training": { - "import": { - "development": "./dist/esm/ort.training.wasm.js", - "default": "./dist/esm/ort.training.wasm.min.js" - }, - "require": { - "development": "./dist/cjs/ort.training.wasm.js", - "default": "./dist/cjs/ort.training.wasm.min.js" - }, - "default": { - "development": "./dist/ort.training.wasm.js", - "default": "./dist/ort.training.wasm.min.js" - } + "import": "./dist/esm/ort.training.wasm.min.js", + "require": "./dist/cjs/ort.training.wasm.min.js", + "default": "./dist/ort.training.wasm.min.js" } }, "types": "./types.d.ts", diff --git a/js/web/script/tsconfig.json b/js/web/script/tsconfig.json new file mode 100644 index 0000000000000..23b1fc96eb558 --- /dev/null +++ b/js/web/script/tsconfig.json @@ -0,0 +1,6 @@ +{ + "extends": "../../tsconfig.tools.json", + "compilerOptions": { + "sourceMap": true + } +} diff --git a/js/web/test/data/ops/fused-conv.jsonc b/js/web/test/data/ops/fused-conv.jsonc new file mode 100644 index 0000000000000..812e9d7c2def0 --- /dev/null +++ b/js/web/test/data/ops/fused-conv.jsonc @@ -0,0 +1,112 @@ +[ + { + "name": "conv without bias addition A", + "operator": "FusedConv", + "attributes": [ + { "name": "activation", "data": "Relu", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" } + ], + "opset": { "domain": "com.microsoft", "version": 1 }, + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [10, 20, 30, 40, 50, 60, 70, 80, 90], + "dims": [1, 1, 3, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [370, 470, 670, 770], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ] + }, + { + "name": "T[1]", + "inputs": [ + { + "data": [10, 20, -30, -40, -50, -60, 70, 80, 90], + "dims": [1, 1, 3, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0, 0, 390, 430], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "NHWC conv without bias addition A", + "operator": "Conv", + "attributes": [ + { "name": "activation", "data": "Relu", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" } + ], + "opset": { "domain": "com.ms.internal.nhwc", "version": 11 }, + "cases": [ + { + "name": "T[2]", + "inputs": [ + { + "data": [10, 20, 30, 40, 50, 60, 70, 80, 90], + "dims": [1, 3, 3, 1], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [370, 470, 670, 770], + "dims": [1, 2, 2, 1], + "type": "float32" + } + ] + }, + { + "name": "T[3]", + "inputs": [ + { + "data": [10, 20, -30, -40, -50, -60, 70, 80, 90], + "dims": [1, 3, 3, 1], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0, 0, 390, 430], + "dims": [1, 2, 2, 1], + 
"type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/data/ops/transpose.jsonc b/js/web/test/data/ops/transpose.jsonc index 285d14018e74d..e1edfa7e41513 100644 --- a/js/web/test/data/ops/transpose.jsonc +++ b/js/web/test/data/ops/transpose.jsonc @@ -166,5 +166,29 @@ ] } ] + }, + { + "name": "Transpose 5D - perms:[4, 3, 1, 0, 2]", + "operator": "Transpose", + "attributes": [{ "name": "perm", "data": [4, 3, 1, 0, 2], "type": "ints" }], + "cases": [ + { + "name": "T[3, 1, 2, 1, 4]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24], + "dims": [3, 1, 2, 1, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 5, 9, 13, 17, 21, 2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23, 4, 8, 12, 16, 20, 24], + "dims": [4, 1, 1, 3, 2], + "type": "float32" + } + ] + } + ] } ] diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts index 628e5408150f8..29acc07e118f9 100644 --- a/js/web/test/test-runner.ts +++ b/js/web/test/test-runner.ts @@ -164,7 +164,10 @@ async function initializeSession( session = await ort.InferenceSession.create(modelFilePath, sessionConfig); } } catch (e) { - Logger.error('TestRunner', `Failed to load model from file: ${modelFilePath}. Error: ${inspect(e)}`); + Logger.error( + 'TestRunner', + `Failed to load model from file: ${modelFilePath}. ` + + `Error: ${e.message} @ ${e.fileName}:${e.lineNumber}`); throw e; } diff --git a/js/web/tsconfig.json b/js/web/tsconfig.json index f7f5f961249bd..eb4674fab3ca9 100644 --- a/js/web/tsconfig.json +++ b/js/web/tsconfig.json @@ -4,9 +4,8 @@ "module": "Node16", "downlevelIteration": true, "declaration": true, - "declarationDir": "./types", "typeRoots": ["./node_modules/@webgpu/types", "./node_modules/@types", "../node_modules/@types"] }, - "include": ["lib", "script", "test"], + "include": ["lib", "test"], "exclude": ["lib/wasm/proxy-worker"] } diff --git a/onnxruntime/contrib_ops/cpu/aten_ops/aten_op.cc b/onnxruntime/contrib_ops/cpu/aten_ops/aten_op.cc index 945c3aebce579..d0abf58922f88 100644 --- a/onnxruntime/contrib_ops/cpu/aten_ops/aten_op.cc +++ b/onnxruntime/contrib_ops/cpu/aten_ops/aten_op.cc @@ -32,8 +32,10 @@ Status ATen::Compute(OpKernelContext* p_ctx) const { aten_ops::ATenOperatorExecutor::Instance()(op_name_, overload_name_, input_size, dlpack_inputs.get(), output_size, dlpack_outputs.get()); for (size_t i = 0; i < output_size; ++i) { - ORT_RETURN_IF_ERROR( - p_ctx_internal->SetOutputMLValue(static_cast(i), dlpack::DlpackToOrtValue(dlpack_outputs[i]))); + if (dlpack_outputs[i]) { + ORT_RETURN_IF_ERROR( + p_ctx_internal->SetOutputMLValue(static_cast(i), dlpack::DlpackToOrtValue(dlpack_outputs[i]))); + } } return Status::OK(); diff --git a/onnxruntime/contrib_ops/cpu/aten_ops/aten_op_executor.h b/onnxruntime/contrib_ops/cpu/aten_ops/aten_op_executor.h index be9650d96b004..d72868cd8fa9f 100644 --- a/onnxruntime/contrib_ops/cpu/aten_ops/aten_op_executor.h +++ b/onnxruntime/contrib_ops/cpu/aten_ops/aten_op_executor.h @@ -10,7 +10,7 @@ namespace onnxruntime { namespace contrib { namespace aten_ops { -typedef bool (*IsTensorArgumentFunc)(const char* op_name, const char* overload_name, size_t index); +typedef bool (*IsCpuArgumentFunc)(const char* op_name, const char* overload_name, size_t index, bool is_input); typedef void (*ExecuteATenOperatorFunc)(const char* op_name, const char* overload_name, size_t input_size, DLManagedTensor** dlpack_inputs, size_t output_size, DLManagedTensor** dlpack_outputs); @@ -22,17 +22,17 @@ class 
ATenOperatorExecutor { return instance; } - void Initialize(void* p_is_tensor_argument_func_raw, void* p_execute_aten_op_func_raw) { - ORT_ENFORCE(p_is_tensor_argument_func_raw && p_execute_aten_op_func_raw); - p_is_tensor_argument_func_ = reinterpret_cast(p_is_tensor_argument_func_raw); + void Initialize(void* p_is_cpu_argument_func_raw, void* p_execute_aten_op_func_raw) { + ORT_ENFORCE(p_is_cpu_argument_func_raw && p_execute_aten_op_func_raw); + p_is_cpu_argument_func_ = reinterpret_cast(p_is_cpu_argument_func_raw); p_execute_aten_op_func_ = reinterpret_cast(p_execute_aten_op_func_raw); } bool IsInitialized() { return p_execute_aten_op_func_ != nullptr; } - bool IsTensorArgument(const std::string& op_name, const std::string& overload_name, size_t index) { - ORT_ENFORCE(p_is_tensor_argument_func_, "ATenOperatorExecutor is not initialized."); - return p_is_tensor_argument_func_(op_name.c_str(), overload_name.c_str(), index); + bool IsCpuArgument(const std::string& op_name, const std::string& overload_name, size_t index, bool is_input) { + ORT_ENFORCE(p_is_cpu_argument_func_, "ATenOperatorExecutor is not initialized."); + return p_is_cpu_argument_func_(op_name.c_str(), overload_name.c_str(), index, is_input); } void operator()(const std::string& op_name, const std::string& overload_name, size_t input_size, @@ -43,7 +43,7 @@ class ATenOperatorExecutor { } private: - IsTensorArgumentFunc p_is_tensor_argument_func_ = nullptr; + IsCpuArgumentFunc p_is_cpu_argument_func_ = nullptr; ExecuteATenOperatorFunc p_execute_aten_op_func_ = nullptr; }; diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index 5184dd99309b1..0fd8790e0d29d 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -55,6 +55,7 @@ struct AttentionParameters { int v_hidden_size; // hidden size of V int v_head_size; // hidden size per head of V int num_heads; + int num_splits; bool is_unidirectional; bool past_present_share_buffer; bool do_rotary; @@ -95,9 +96,9 @@ struct GroupQueryAttentionParameters { int head_size; int kv_hidden_size; int kv_num_heads; + int num_splits; // number of splits for splitkv bool is_unidirectional; // causal float scale; - int num_splits; // number of splits for splitkv AttentionQkvFormat qkv_format; AttentionQkvFormat past_kv_format; }; diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc index 0b55cb7804c61..694c40bf3eda6 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc @@ -16,7 +16,6 @@ #include #include -#include using onnxruntime::concurrency::ThreadPool; diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h index 1dc85e6d345d7..00e82c9844b3d 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h @@ -103,7 +103,8 @@ Status CheckInputs(const T* query, } if (past_key_dims[2] != past_value_dims[2]) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' and 'past_value' shall have same dim 2 (past_sequence_length)"); + "Input 'past_key' and 'past_value' shall have same dim 2 (past_sequence_length). 
", + past_key_dims[2], " vs ", past_value_dims[2]); } if (past_key_dims[3] != head_size) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, @@ -205,6 +206,7 @@ Status CheckInputs(const T* query, } } + int total_sequence_length = past_sequence_length + kv_sequence_length; AttentionMaskType mask_type = AttentionMaskType::MASK_NONE; if (key_padding_mask != nullptr) { mask_type = AttentionMaskType::MASK_UNKNOWN; @@ -215,13 +217,21 @@ Status CheckInputs(const T* query, } else if (mask_dims[0] == static_cast(3) * static_cast(batch_size) + static_cast(2)) { mask_type = AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START; } - } else if (mask_dims.size() == 2 && mask_dims[0] == static_cast(batch_size) && mask_dims[1] == static_cast(kv_sequence_length)) { + } else if (mask_dims.size() == 2 && mask_dims[0] == static_cast(batch_size) && + mask_dims[1] == static_cast(kv_sequence_length)) { + mask_type = AttentionMaskType::MASK_2D_KEY_PADDING; + } else if (mask_dims.size() == 2 && mask_dims[0] == static_cast(batch_size) && + mask_dims[1] == static_cast(total_sequence_length)) { mask_type = AttentionMaskType::MASK_2D_KEY_PADDING; + } else if (mask_dims.size() == 3 && mask_dims[0] == static_cast(batch_size) && + mask_dims[1] == static_cast(sequence_length) && + mask_dims[2] == static_cast(total_sequence_length)) { + mask_type = AttentionMaskType::MASK_3D_ATTENTION; } if (mask_type == AttentionMaskType::MASK_UNKNOWN) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'key_padding_mask' shape shall be (batch_size) or (batch_size, kv_sequence_length)"); + "Input 'key_padding_mask' shape shall be 1D, 2D, or 3D"); } } @@ -256,7 +266,6 @@ Status CheckInputs(const T* query, } } - int total_sequence_length = past_sequence_length + kv_sequence_length; bool broadcast_res_pos_bias = false; if (relative_position_bias != nullptr) { const auto& relative_position_bias_dims = relative_position_bias->Shape().GetDims(); diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc new file mode 100644 index 0000000000000..4a266af789250 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc @@ -0,0 +1,115 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "contrib_ops/cpu/bert/rotary_embedding.h" +#include "contrib_ops/cpu/bert/rotary_embedding_helper.h" + +#include "core/platform/threadpool.h" + +using onnxruntime::concurrency::ThreadPool; +using namespace onnxruntime::contrib::rotary_embedding_helper; + +namespace onnxruntime { +namespace contrib { + +// These ops are internal-only, so register outside of onnx +ONNX_OPERATOR_TYPED_KERNEL_EX( + RotaryEmbedding, + kMSDomain, + 1, + float, + kCpuExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("M", DataTypeImpl::GetTensorType()), + RotaryEmbedding); + +template +RotaryEmbedding::RotaryEmbedding(const OpKernelInfo& info) : OpKernel(info) { + scale = info.GetAttrOrDefault("scale", 1.0); + interleaved = (info.GetAttrOrDefault("interleaved", 0) == 1); +} + +template +Status RotaryEmbedding::Compute(OpKernelContext* context) const { + const Tensor* input = context->Input(0); + const Tensor* position_ids = context->Input(1); + const Tensor* cos_cache = context->Input(2); + const Tensor* sin_cache = context->Input(3); + + RotaryParameters parameters = {}; + ORT_RETURN_IF_ERROR(rotary_embedding_helper::CheckInputs(input, + position_ids, + cos_cache, + sin_cache, + ¶meters)); + + Tensor* output = context->Output(0, input->Shape()); + + if (parameters.sequence_length > parameters.max_sequence_length) { + // Launch update_cos_sin_cache kernel with scale + ORT_NOT_IMPLEMENTED("Updating cos_cache and sin_cache in RotaryEmbedding is not currently supported"); + } + + const T* input_src = input->Data(); + const int64_t* pos_ids_data = position_ids->Data(); + const T* cos_cache_data = cos_cache->Data(); + const T* sin_cache_data = sin_cache->Data(); + T* output_dest = output->MutableData(); + + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int num_heads = parameters.num_heads; + const int head_size = parameters.head_size; + const int position_ids_format = parameters.position_ids_format; + const int half_head_size = head_size / 2; + + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); + auto* tp = context->GetOperatorThreadPool(); + + const int loop_len = batch_size * sequence_length * num_heads; + const double cost = static_cast(head_size); + ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) { + const int b = static_cast((ptr / num_heads) / sequence_length); + const int s = static_cast((ptr / num_heads) % sequence_length); + const int n = static_cast(ptr % num_heads); + + const int block_offset = b * sequence_length * num_heads + s * num_heads + n; + const int data_offset = block_offset * head_size; + + const T* input_data = input_src + data_offset; + T* output_data = output_dest + data_offset; + + // Cache is (M, H/2) + const int position_id = (position_ids_format == 0) + ? static_cast(pos_ids_data[0]) + s + : static_cast(pos_ids_data[b * sequence_length + s]); + const int cache_offset = position_id * half_head_size; + const T* cos_data = cos_cache_data + cache_offset; + const T* sin_data = sin_cache_data + cache_offset; + + int cache_idx = 0; + T sign = 0; + int j = 0; + for (int i = 0; i < head_size; i++) { + if (interleaved) { + cache_idx = (i / 2) % half_head_size; + sign = (i % 2 == 0) ? static_cast(-1) : static_cast(1); + j = (i % 2 == 0) ? 
i + 1 : i - 1; // i - sign + } else { + cache_idx = i % half_head_size; + sign = (i < half_head_size) ? static_cast(-1) : static_cast(1); + j = (i + half_head_size) % head_size; + } + output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx]; + } + } + }); + + return Status::OK(); +} + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h new file mode 100644 index 0000000000000..be834a66cdc69 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/framework/op_kernel.h" + +namespace onnxruntime { +namespace contrib { + +template +class RotaryEmbedding final : public OpKernel { + public: + RotaryEmbedding(const OpKernelInfo& info); + Status Compute(OpKernelContext* context) const override; + + protected: + float scale; + bool interleaved; +}; + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h new file mode 100644 index 0000000000000..cf8080800e072 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h @@ -0,0 +1,121 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/providers/common.h" + +namespace onnxruntime { +namespace contrib { +namespace rotary_embedding_helper { + +// Parameters deduced from node attributes and inputs/outputs. 
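The CPU RotaryEmbedding kernel above rotates each head element-wise against the cached cos/sin rows for its position: interleaved mode pairs neighbouring elements, while non-interleaved mode pairs element i with element i + head_size/2. A TypeScript sketch of that rotation for a single head (rotateHead is an illustrative name; head_size is assumed even, and cos/sin are the cached rows of length head_size/2 for this position):

    // Sketch of the per-head rotation applied by the kernel above.
    const rotateHead = (x: number[], cos: number[], sin: number[], interleaved: boolean): number[] => {
      const headSize = x.length;
      const half = headSize / 2;
      const out = new Array<number>(headSize);
      for (let i = 0; i < headSize; i++) {
        const cacheIdx = interleaved ? Math.floor(i / 2) % half : i % half;
        const sign = interleaved ? (i % 2 === 0 ? -1 : 1) : (i < half ? -1 : 1);
        const j = interleaved ? (i % 2 === 0 ? i + 1 : i - 1) : (i + half) % headSize;
        out[i] = x[i] * cos[cacheIdx] + sign * x[j] * sin[cacheIdx];
      }
      return out;
    };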
+struct RotaryParameters { + int batch_size; // Batch size used by input + int sequence_length; // Sequence length used by input + int hidden_size; // Hidden size used by input + int head_size; // Head size used by cos/sin cache * 2 + int num_heads; // num_heads = hidden_size / head_size + int max_sequence_length; // Sequence length used by cos/sin cache + int position_ids_format; // Format of position ids - 0 is (1), 1 is (batch_size, sequence_length) +}; + +template +Status CheckInputs(const T* input, + const T* position_ids, + const T* cos_cache, + const T* sin_cache, + void* parameters) { + // input : (batch_size, sequence_length, hidden_size) + // position ids : (1) or (batch_size, sequence_length) + // cos cache : (max_sequence_length, head_size / 2) + // sin cache : (max_sequence_length, head_size / 2) + + // Check input + const auto& input_dims = input->Shape().GetDims(); + if (input_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'x' is expected to have 3 dimensions, got ", + input_dims.size()); + } + // Check position_ids + const auto& position_ids_dims = position_ids->Shape().GetDims(); + if (!onnxruntime::IsScalarOr1ElementVector(position_ids) && position_ids_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'position_ids' is expected to have 0, 1, or 2 ", + "dimensions, got ", position_ids_dims.size()); + } + // Check cos_cache and sin_cache + const auto& cos_cache_dims = cos_cache->Shape().GetDims(); + if (cos_cache_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' is expected to have 2 dimensions, got ", + cos_cache_dims.size()); + } + const auto& sin_cache_dims = sin_cache->Shape().GetDims(); + if (sin_cache_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' is expected to have 2 dimensions, got ", + sin_cache_dims.size()); + } + if (cos_cache_dims[0] != sin_cache_dims[0] || cos_cache_dims[1] != sin_cache_dims[1]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'cos_cache' and 'sin_cache' are expected to have ", + "the same shape"); + } + + // Get attributes from inputs + int batch_size = static_cast(input_dims[0]); + int sequence_length = static_cast(input_dims[1]); + int hidden_size = static_cast(input_dims[2]); + int max_sequence_length = static_cast(cos_cache_dims[0]); + int head_size = static_cast(cos_cache_dims[1]) * 2; + int num_heads = hidden_size / head_size; + int position_ids_format = -1; + + // Check position_ids input shapes + if (!onnxruntime::IsScalarOr1ElementVector(position_ids)) { + if (batch_size != static_cast(position_ids_dims[0])) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'position_ids' dimension 0 should be of size ", + "batch_size, got ", position_ids_dims[0]); + } + if (sequence_length != static_cast(position_ids_dims[1])) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'position_ids' dimension 1 should be of size ", + "sequence_length, got ", position_ids_dims[1]); + } + position_ids_format = 1; + } else { + position_ids_format = 0; + } + // Check cos_cache input shapes + if (max_sequence_length != static_cast(cos_cache_dims[0])) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 0 should be same as ", + "max_sequence_length, got ", cos_cache_dims[0]); + } + if ((head_size / 2) != static_cast(cos_cache_dims[1])) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 1 
should be same as ", + "head_size / 2, got ", cos_cache_dims[1]); + } + // Check sin_cache input shapes + if (max_sequence_length != static_cast(sin_cache_dims[0])) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' dimension 0 should be same as ", + "max_sequence_length, got ", sin_cache_dims[0]); + } + if ((head_size / 2) != static_cast(sin_cache_dims[1])) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' dimension 1 should be same as ", + "head_size / 2, got ", sin_cache_dims[1]); + } + + // Set rotary parameters + if (parameters != nullptr) { + RotaryParameters* output_parameters = reinterpret_cast(parameters); + output_parameters->batch_size = batch_size; + output_parameters->sequence_length = sequence_length; + output_parameters->hidden_size = hidden_size; + output_parameters->head_size = head_size; + output_parameters->num_heads = num_heads; + output_parameters->max_sequence_length = max_sequence_length; + output_parameters->position_ids_format = position_ids_format; + } + + return Status::OK(); +} + +} // namespace rotary_embedding_helper +} // namespace contrib +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index 0ec5088808656..f9d9b13f0fedc 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -13,12 +13,14 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GridSample); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Attention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, BeamSearch); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, WhisperBeamSearch); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, EmbedLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, ExpandDims); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedConv); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedGemm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GreedySearch); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MultiHeadAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, RotaryEmbedding); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Sampling); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer); @@ -27,6 +29,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, WordC class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GatherND); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TransposeMatMul); // backward compatibility class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FusedMatMul); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulNBits); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, 
kMSDomain, 1, MatMulBnb4); #ifndef ORT_MINIMAL_BUILD class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulFpQ4); #endif @@ -122,6 +126,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, SimplifiedLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, SkipLayerNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, SkipSimplifiedLayerNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Inverse); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Trilu); @@ -244,12 +250,14 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -262,6 +270,8 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, // backward compatibility BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, #ifndef ORT_MINIMAL_BUILD BuildKernelCreateInfo, #endif @@ -295,6 +305,8 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cpu/quantization/blockwise_quant_block_bnb4.h b/onnxruntime/contrib_ops/cpu/quantization/blockwise_quant_block_bnb4.h new file mode 100644 index 0000000000000..cb8e97a592d8c --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/blockwise_quant_block_bnb4.h @@ -0,0 +1,202 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +namespace onnxruntime { +namespace contrib { + +#if defined(_MSC_VER) +#define FORCEINLINE __forceinline +#else +#define FORCEINLINE __attribute__((always_inline)) inline +#endif + +typedef enum Bnb_DataType_t { + FP4 = 0, + NF4 = 1, +} Bnb_DataType_t; + +FORCEINLINE uint8_t QuantizeOneFP4(float x) { + // FP4 with bias of 3 + // first bit is a sign + // subnormals + // 0b000 = 0 + // 0b001 = 0.0625 + // 0b110 = 2 + // 0b111 = 3 + // 0b100 = 4 + // 0b101 = 6 + // 0b010 = 8 + // 0b011 = 12 + + // we do a binary search + // the pivots are divided by 12 (the FP4 absmax) + // since we assum input data is in [-1.0, 1.0] + + // !be careful here, its easy to make a mistake + // that is difficult to noice if you add an extra + // zero somewhere! + + uint8_t sign = x < 0 ? 
0b1000 : 0b0000; + x = fabsf(x); + if (x > 0.29166667f) { + if (x > 0.583333f) { + if (x > 0.8333333f) { + return 0b0011 + sign; + } else { + return 0b0010 + sign; + } + } else if (x > 0.4166667f) { + return 0b101 + sign; + } else { + return 0b100 + sign; + } + } else if (x > 0.0859375f) { + if (x > 0.20833333f) { + return 0b0111 + sign; + } else { + return 0b0110 + sign; + } + } else if (x > 0.00260417f) { + return 0b0001 + sign; + } else { + return 0b0000 + sign; + } +} + +FORCEINLINE uint8_t QuantizeOneNF4(float x) { + if (x > 0.03979014977812767f) { + if (x > 0.3893125355243683f) { // 1 + if (x > 0.6427869200706482f) { // 11 + if (x > 0.8614784181118011f) { // 111 + return 0b1111; + } else { + return 0b1110; + } + } else if (x > 0.5016634166240692f) { // 110 + return 0b1101; + } else { + return 0b1100; + } + } else if (x > 0.2035212516784668f) { // 10 + if (x > 0.2920137718319893f) { // 101 + return 0b1011; + } else { + return 0b1010; + } + } else if (x > 0.1202552504837513f) { // 100 + return 0b1001; + } else { + return 0b1000; + } + } else if (x > -0.33967943489551544f) { // 0 + if (x > -0.13791173323988914f) { // 01 + if (x > -0.045525018125772476f) { // 011 + return 0b0111; + } else { + return 0b0110; + } + } else if (x > -0.23460740596055984f) { // 010 + return 0b0101; + } else { + return 0b0100; + } + } else if (x > -0.6106329262256622f) { // 00 + if (x > -0.4599952697753906f) { // 001 + return 0b0011; + } else { + return 0b0010; + } + } else if (x > -0.8480964004993439f) { // 000 + return 0b0001; + } else { + return 0b0000; + } +} + +template +FORCEINLINE uint8_t QuantizeOneBnb4(float x) { + if constexpr (DATA_TYPE == FP4) + return QuantizeOneFP4(x); + else + return QuantizeOneNF4(x); +} + +template +FORCEINLINE void QuantizeBlockBnb4(const T* src, uint8_t* dst, T& absmax_block, int32_t block_idx, int32_t numel) { + float local_absmax = 0.0f; + + int32_t block_len = std::min(block_size, numel - block_idx * block_size); + int32_t src_offset = block_idx * block_size; + int32_t dst_offset = block_idx * block_size / 2; + + for (int32_t idx = 0; idx < block_len; idx++) { + const float v = static_cast(src[src_offset + idx]); + local_absmax = fmaxf(local_absmax, fabsf(v)); + } + + absmax_block = static_cast(local_absmax); + const float reciprocal_absmax = local_absmax ? 1.0f / local_absmax : 0.0f; + + for (int32_t idx = 0; idx < block_len; idx += 2) { + const float v0 = static_cast(src[src_offset + idx]) * reciprocal_absmax; + const uint8_t vi0 = QuantizeOneBnb4(v0); + + const float v1 = (idx + 1 < block_len) ? 
static_cast(src[src_offset + idx + 1]) * reciprocal_absmax : 0; + const uint8_t vi1 = QuantizeOneBnb4(v1); + + dst[dst_offset + idx / 2] = (vi0 << 4) | vi1; + } +} + +static float fp4_qaunt_map[16] = {0.00000000f, 5.208333333e-03f, 0.66666667f, 1.00000000f, + 0.33333333f, 0.50000000f, 0.16666667f, 0.25000000f, + -0.00000000f, -5.208333333e-03f, -0.66666667f, -1.00000000f, + -0.33333333f, -0.50000000f, -0.16666667f, -0.25000000f}; + +static float nf4_qaunt_map[16] = {-1.0f, + -0.6961928009986877f, + -0.5250730514526367f, + -0.39491748809814453f, + -0.28444138169288635f, + -0.18477343022823334f, + -0.09105003625154495f, + 0.0f, + 0.07958029955625534f, + 0.16093020141124725f, + 0.24611230194568634f, + 0.33791524171829224f, + 0.44070982933044434f, + 0.5626170039176941f, + 0.7229568362236023f, + 1.0f}; + +template +FORCEINLINE T DequantizeOneBnb4(uint8_t x) { + if constexpr (DATA_TYPE == FP4) + return static_cast(fp4_qaunt_map[x]); + else + return static_cast(nf4_qaunt_map[x]); +} + +template +FORCEINLINE void DequantizeBlockBnb4(const uint8_t* src, T* dst, T absmax_block, int32_t block_idx, int32_t numel) { + int32_t block_len = std::min(block_size, numel - block_idx * block_size); + int32_t src_offset = block_idx * block_size / 2; + int32_t dst_offset = block_idx * block_size; + + for (int32_t idx = 0; idx < block_len; idx += 2) { + const uint8_t val = src[src_offset + idx / 2]; + + dst[dst_offset + idx] = DequantizeOneBnb4(val >> 4) * absmax_block; + if (idx + 1 < block_len) dst[dst_offset + idx + 1] = DequantizeOneBnb4(val & 0xF) * absmax_block; + } +} + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise_bnb4.h b/onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise_bnb4.h new file mode 100644 index 0000000000000..5ddb77e5b5ee3 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise_bnb4.h @@ -0,0 +1,143 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
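The blockwise Bnb4 scheme above stores one absmax per block, normalizes each value by it, encodes the result to a 4-bit FP4/NF4 code, and packs two codes per byte; dequantization is a 16-entry table lookup scaled by the block's absmax. A rough TypeScript sketch of one block's round trip (quantizeBlock/dequantizeBlock and the encode/table parameters are illustrative stand-ins for the C++ helpers):

    // One block: compute absmax, normalize, encode to 4-bit codes, pack two per byte.
    const quantizeBlock = (src: number[], encode: (v: number) => number): {packed: Uint8Array; absmax: number} => {
      const absmax = src.reduce((m, v) => Math.max(m, Math.abs(v)), 0);
      const inv = absmax ? 1 / absmax : 0;
      const packed = new Uint8Array(Math.ceil(src.length / 2));
      for (let i = 0; i < src.length; i += 2) {
        const hi = encode(src[i] * inv);
        const lo = encode(i + 1 < src.length ? src[i + 1] * inv : 0);
        packed[i >> 1] = (hi << 4) | lo;
      }
      return {packed, absmax};
    };

    // One block: look each 4-bit code up in a 16-entry table and rescale by absmax.
    const dequantizeBlock = (packed: Uint8Array, absmax: number, table: number[], count: number): number[] => {
      const out: number[] = [];
      for (let i = 0; i < count; i += 2) {
        const byte = packed[i >> 1];
        out.push(table[byte >> 4] * absmax);
        if (i + 1 < count) {
          out.push(table[byte & 0xF] * absmax);
        }
      }
      return out;
    };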
+ +#pragma once + +#include "blockwise_quant_block_bnb4.h" + +#include + +#include "core/common/safeint.h" +#include "core/framework/float16.h" +#include "core/platform/threadpool.h" +#include + +namespace onnxruntime { +namespace contrib { + +template +void QuantizeBlockwiseBnb4( + uint8_t* dst, // shape: [(N * K + 1) / 2] + const T* src, // shape: [N, K] + T* absmax, // shape: [(N * K + block_size - 1) / block_size] + int32_t N, + int32_t K, + onnxruntime::concurrency::ThreadPool* thread_pool) { + int32_t numel = N * K; + int32_t total_block_count = (numel + block_size - 1) / block_size; + + concurrency::ThreadPool::TryBatchParallelFor( + thread_pool, + total_block_count, + [&](ptrdiff_t block_idx) { + QuantizeBlockBnb4( + src, + dst, + absmax[block_idx], + static_cast(block_idx), + numel); + }, + 0); +} + +#define QuantizeBlockwiseBn4DataTyped(block_size, quant_type) \ + if (quant_type == FP4) \ + QuantizeBlockwiseBnb4(dst, src, absmax, N, K, thread_pool); \ + else \ + QuantizeBlockwiseBnb4(dst, src, absmax, N, K, thread_pool); + +template +void QuantizeBlockwiseBnb4( + uint8_t* dst, // shape: [(N * K + 1) / 2] + const T* src, // shape: [N, K] + T* absmax, // shape: [(N * K + block_size - 1) / block_size] + int32_t block_size, + int32_t quant_type, + int32_t N, + int32_t K, + onnxruntime::concurrency::ThreadPool* thread_pool) { + ORT_ENFORCE( + quant_type == FP4 || quant_type == NF4, + "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported."); + + if (block_size == 16) { + QuantizeBlockwiseBn4DataTyped(16, quant_type); + } else if (block_size == 32) { + QuantizeBlockwiseBn4DataTyped(32, quant_type); + } else if (block_size == 64) { + QuantizeBlockwiseBn4DataTyped(64, quant_type); + } else if (block_size == 128) { + QuantizeBlockwiseBn4DataTyped(128, quant_type); + } else if (block_size == 256) { + QuantizeBlockwiseBn4DataTyped(256, quant_type); + } else { + ORT_NOT_IMPLEMENTED("only block size 16, 32, 64, 128, 256 are supported."); + } +} + +#undef QuantizeBlockwiseBn4DataTyped + +template +void DequantizeBlockwiseBnb4( + T* dst, // shape: [N, K] + const uint8_t* src, // shape: [(N * K + 1) / 2)] + const T* absmax, // shape: [(N * K + block_size - 1) / block_size] + int32_t N, + int32_t K, + onnxruntime::concurrency::ThreadPool* thread_pool) { + int32_t numel = N * K; + int32_t total_block_count = (numel + block_size - 1) / block_size; + + concurrency::ThreadPool::TryBatchParallelFor( + thread_pool, + total_block_count, + [&](ptrdiff_t block_idx) { + DequantizeBlockBnb4( + src, + dst, + absmax[block_idx], + static_cast(block_idx), + numel); + }, + 0); +} + +#define DequantizeBlockwiseBn4DataTyped(block_size, quant_type) \ + if (quant_type == FP4) \ + DequantizeBlockwiseBnb4(dst, src, absmax, N, K, thread_pool); \ + else \ + DequantizeBlockwiseBnb4(dst, src, absmax, N, K, thread_pool); + +template +void DequantizeBlockwiseBnb4( + T* dst, // shape: [N, K] + const uint8_t* src, // shape: [(N * K + 1) / 2)] + const T* absmax, // shape: [(N * K + block_size - 1) / block_size] + int32_t block_size, + int32_t quant_type, + int32_t N, + int32_t K, + onnxruntime::concurrency::ThreadPool* thread_pool) { + ORT_ENFORCE( + quant_type == FP4 || quant_type == NF4, + "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported."); + + if (block_size == 16) { + DequantizeBlockwiseBn4DataTyped(16, quant_type); + } else if (block_size == 32) { + DequantizeBlockwiseBn4DataTyped(32, quant_type); + } else if (block_size == 64) { + DequantizeBlockwiseBn4DataTyped(64, quant_type); + } else if (block_size 
== 128) { + DequantizeBlockwiseBn4DataTyped(128, quant_type); + } else if (block_size == 256) { + DequantizeBlockwiseBn4DataTyped(256, quant_type); + } else { + ORT_NOT_IMPLEMENTED("only block size 16, 32, 64, 128, 256 are supported."); + } +} + +#undef DequantizeBlockwiseBn4DataTyped + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_bnb4.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_bnb4.cc new file mode 100644 index 0000000000000..b898c956b6e6a --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_bnb4.cc @@ -0,0 +1,113 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/safeint.h" +#include "core/framework/op_kernel.h" +#include "core/providers/cpu/math/matmul_helper.h" +#include "core/providers/common.h" +#include "dequantize_blockwise_bnb4.h" +#include "core/mlas/inc/mlas.h" + +namespace onnxruntime { +namespace contrib { + +class MatMulBnb4 final : public OpKernel { + public: + MatMulBnb4(const OpKernelInfo& info) : OpKernel(info) { + ORT_ENFORCE(Status::OK() == info.GetAttr("K", &K_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("N", &N_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("block_size", &block_size_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("quant_type", &quant_type_)); + ORT_ENFORCE( + quant_type_ == FP4 || quant_type_ == NF4, + "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported."); + is_training_mode_ = static_cast(info.GetAttrOrDefault("training_mode", static_cast(0))); + transB_ = static_cast(info.GetAttrOrDefault("transB", static_cast(1))); + } + + Status Compute(OpKernelContext* context) const override; + + private: + int64_t K_; + int64_t N_; + int64_t block_size_; + int64_t quant_type_; + bool is_training_mode_; + bool transB_; +}; + +Status MatMulBnb4::Compute(OpKernelContext* ctx) const { + concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); + + const Tensor* a = ctx->Input(0); + const Tensor* b_quant = ctx->Input(1); + const Tensor* absmax = ctx->Input(2); + + const float* a_data = a->Data(); + const uint8_t* b_quant_data = b_quant->Data(); + const float* absmax_data = absmax->Data(); + + AllocatorPtr allocator; + auto status = ctx->GetTempSpaceAllocator(&allocator); + ORT_RETURN_IF_ERROR(status); + auto tmp_b_data_ptr = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_); + DequantizeBlockwiseBnb4( + tmp_b_data_ptr.get(), + b_quant_data, + absmax_data, + static_cast(block_size_), + static_cast(quant_type_), + static_cast(N_), + static_cast(K_), + thread_pool); + + constexpr bool transa = false; + const bool transb = transB_; + TensorShape b_shape({N_, K_}); + MatMulComputeHelper helper; + ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, transa, transb)); + + Tensor* y = ctx->Output(0, helper.OutputShape()); + + // Bail out early if the output is going to be empty + if (y->Shape().Size() == 0) return Status::OK(); + + auto* y_data = y->MutableData(); + + const size_t max_len = helper.OutputOffsets().size(); + const size_t M = static_cast(helper.M()); + const size_t N = static_cast(helper.N()); + const size_t K = static_cast(helper.K()); + const size_t lda = helper.Lda(transa); + const size_t ldb = helper.Ldb(transb); + + // TODO: implement with native kernel + std::vector data(max_len); + for (size_t i = 0; i < max_len; i++) { + data[i].BIsPacked = false; + data[i].A = a_data + helper.LeftOffsets()[i]; + data[i].lda = lda; + data[i].B = tmp_b_data_ptr.get() + 
helper.RightOffsets()[i]; + data[i].ldb = ldb; + data[i].C = y_data + helper.OutputOffsets()[i]; + data[i].ldc = N; + data[i].alpha = 1.f; + data[i].beta = 0.0f; + } + MlasGemmBatch(CblasNoTrans, CblasTrans, M, N, K, data.data(), max_len, thread_pool); + + return Status::OK(); +} + +ONNX_OPERATOR_KERNEL_EX( + MatMulBnb4, + kMSDomain, + 1, + kCpuExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + MatMulBnb4); + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc new file mode 100644 index 0000000000000..c72d811170a27 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -0,0 +1,120 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/safeint.h" +#include "core/framework/op_kernel.h" +#include "core/providers/cpu/math/matmul_helper.h" +#include "core/providers/common.h" +#include "core/mlas/inc/mlas_q4.h" + +namespace onnxruntime { +namespace contrib { + +class MatMulNBits final : public OpKernel { + public: + MatMulNBits(const OpKernelInfo& info) : OpKernel(info) { + ORT_ENFORCE(Status::OK() == info.GetAttr("K", &K_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("N", &N_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("block_size", &block_size_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("bits", &nbits_)); + ORT_ENFORCE(nbits_ == 4, + "Only 4b quantization is supported for MatMulNBits op," + " additional bits support is planned."); + } + + Status Compute(OpKernelContext* context) const override; + + private: + int64_t K_; + int64_t N_; + int64_t block_size_; + int64_t nbits_; + bool column_wise_quant_{true}; +}; + +Status MatMulNBits::Compute(OpKernelContext* ctx) const { + concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); + + const Tensor* a = ctx->Input(0); + const Tensor* b = ctx->Input(1); + const Tensor* scales = ctx->Input(2); + const Tensor* zero_points = ctx->Input(3); + + const auto* a_data = a->Data(); + const uint8_t* b_data = b->Data(); + const auto* scales_data = scales->Data(); + const auto* zero_points_data = zero_points == nullptr ? 
nullptr : zero_points->Data(); + + AllocatorPtr allocator; + auto status = ctx->GetTempSpaceAllocator(&allocator); + ORT_RETURN_IF_ERROR(status); + auto tmp_b_data_ptr = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_); + + // dequantize b, only 4b quantization is supported for now + MlasDequantizeBlockwise( + tmp_b_data_ptr.get(), // dequantized output + b_data, // quantized input + scales_data, // quantization scales + zero_points_data, // quantization zero points + static_cast(block_size_), // quantization block size + column_wise_quant_, // columnwise quantization or row-wise + static_cast(K_), // number of rows in quantized input + static_cast(N_), // number of columns in quantized input + thread_pool); + +#if 0 // for debug + auto tm_b_data_ptr_trans = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_); + MlasTranspose(tmp_b_data_ptr.get(), tm_b_data_ptr_trans.get(), N_, K_); +#endif + + TensorShape b_shape({N_, K_}); + + MatMulComputeHelper helper; + ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true)); + + Tensor* y = ctx->Output(0, helper.OutputShape()); + + // Bail out early if the output is going to be empty + if (y->Shape().Size() == 0) + return Status::OK(); + + auto* y_data = y->MutableData(); + + const size_t max_len = helper.OutputOffsets().size(); + const size_t M = static_cast(helper.M()); + const size_t N = static_cast(helper.N()); + const size_t K = static_cast(helper.K()); + const size_t lda = helper.Lda(false); + const size_t ldb = helper.Ldb(true); + + // TODO: implement with native kernel + std::vector data(max_len); + for (size_t i = 0; i < max_len; i++) { + data[i].BIsPacked = false; + data[i].A = a_data + helper.LeftOffsets()[i]; + data[i].lda = lda; + data[i].B = tmp_b_data_ptr.get() + helper.RightOffsets()[i]; + data[i].ldb = ldb; + data[i].C = y_data + helper.OutputOffsets()[i]; + data[i].ldc = N; + data[i].alpha = 1.f; + data[i].beta = 0.0f; + } + MlasGemmBatch(CblasNoTrans, CblasTrans, + M, N, K, data.data(), max_len, thread_pool); + + return Status::OK(); +} + +ONNX_OPERATOR_KERNEL_EX( + MatMulNBits, + kMSDomain, + 1, + kCpuExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + MatMulNBits); + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc b/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc index e86a12d9fb873..4e103c2556a7a 100644 --- a/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc @@ -20,20 +20,29 @@ namespace contrib { kCpuExecutionProvider, \ KernelDefBuilder() \ .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - SkipLayerNorm); + SkipLayerNorm); \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + SkipSimplifiedLayerNormalization, \ + kMSDomain, \ + 1, \ + T, \ + kCpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + SkipLayerNorm); REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(double) -template -SkipLayerNorm::SkipLayerNorm(const OpKernelInfo& op_kernel_info) +template +SkipLayerNorm::SkipLayerNorm(const OpKernelInfo& op_kernel_info) : OpKernel(op_kernel_info) { ORT_ENFORCE(op_kernel_info.GetAttr("epsilon", &epsilon_).IsOK()); ORT_ENFORCE(epsilon_ >= 0); } -template -Status SkipLayerNorm::Compute(OpKernelContext* p_ctx) const { +template +Status SkipLayerNorm::Compute(OpKernelContext* p_ctx) const { const Tensor* input = p_ctx->Input(0); const Tensor* skip = 
p_ctx->Input(1); const Tensor* gamma = p_ctx->Input(2); @@ -102,10 +111,16 @@ Status SkipLayerNorm::Compute(OpKernelContext* p_ctx) const { } mean = mean / hidden_size; - mean_square = sqrt(mean_square / hidden_size - mean * mean + epsilon_); + if (simplified) { + mean_square = sqrt(mean_square / hidden_size + epsilon_); + } else { + mean_square = sqrt(mean_square / hidden_size - mean * mean + epsilon_); + } for (int64_t h = 0; h < hidden_size; h++) { - if (nullptr == beta_data) { + if (simplified) { + p_output[h] = p_output[h] / mean_square * gamma_data[h]; + } else if (nullptr == beta_data) { p_output[h] = (p_output[h] - mean) / mean_square * gamma_data[h]; } else { p_output[h] = (p_output[h] - mean) / mean_square * gamma_data[h] + beta_data[h]; diff --git a/onnxruntime/contrib_ops/cpu/skip_layer_norm.h b/onnxruntime/contrib_ops/cpu/skip_layer_norm.h index 7723541cb6b18..69edf4609e340 100644 --- a/onnxruntime/contrib_ops/cpu/skip_layer_norm.h +++ b/onnxruntime/contrib_ops/cpu/skip_layer_norm.h @@ -10,7 +10,7 @@ namespace onnxruntime { namespace contrib { -template +template class SkipLayerNorm final : public OpKernel { public: SkipLayerNorm(const OpKernelInfo& op_kernel_info); diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc index c391f47e1927b..93cda00e5a3c3 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc @@ -52,28 +52,38 @@ namespace contrib { kCpuExecutionProvider, \ (*KernelDefBuilder::Create()) \ .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - transformers::BeamSearch); + transformers::BeamSearch); \ + \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + WhisperBeamSearch, \ + kMSDomain, \ + 1, \ + T, \ + kCpuExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + transformers::WhisperBeamSearch); REGISTER_KERNEL_TYPED(float) namespace transformers { void BeamSearch::Init(const OpKernelInfo& info) { - parameters_.ParseFromAttributes(info); + parameters_->ParseFromAttributes(info); // Model_type could be either 0 (GPT-2), 1 (encoder-decoder like T5), or 2 (Whisper) - ORT_ENFORCE(parameters_.model_type == IGenerationParameters::kModelTypeGpt || - parameters_.model_type == IGenerationParameters::kModelTypeT5 || - parameters_.model_type == IGenerationParameters::kModelTypeWhisper); + ORT_ENFORCE(parameters_->model_type == IGenerationParameters::kModelTypeGpt || + parameters_->model_type == IGenerationParameters::kModelTypeT5 || + parameters_->model_type == IGenerationParameters::kModelTypeWhisper); ONNX_NAMESPACE::GraphProto proto; - if (parameters_.model_type != IGenerationParameters::kModelTypeGpt) { + if (parameters_->model_type != IGenerationParameters::kModelTypeGpt) { // Make sure the encoder sub-graph attribute is present for the T5 and Whisper models. ORT_ENFORCE(info.GetAttr("encoder", &proto).IsOK()); } - if (parameters_.model_type == IGenerationParameters::kModelTypeGpt) { + if (parameters_->model_type == IGenerationParameters::kModelTypeGpt) { // Check if the init_decoder sub-graph attribute is present for the GPT2 model. 
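Editor's note: the SkipLayerNorm hunk above adds a `simplified` template flag so SkipSimplifiedLayerNormalization can reuse the same kernel; the simplified path skips the mean subtraction and beta term and divides by sqrt(mean-of-squares + epsilon). A small sketch of the per-row math follows, assuming the row already holds input + skip (+ bias), as the kernel computes before this loop.

#include <cmath>
#include <cstdio>
#include <vector>

// Sketch of the regular vs. simplified (RMS-style) normalization of one row.
void NormalizeRow(std::vector<float>& row, const std::vector<float>& gamma,
                  const float* beta, float epsilon, bool simplified) {
  const size_t hidden_size = row.size();
  double mean = 0.0, mean_square = 0.0;
  for (float v : row) {
    mean += v;
    mean_square += static_cast<double>(v) * v;
  }
  mean /= hidden_size;
  const double denom = simplified
                           ? std::sqrt(mean_square / hidden_size + epsilon)
                           : std::sqrt(mean_square / hidden_size - mean * mean + epsilon);
  for (size_t h = 0; h < hidden_size; ++h) {
    if (simplified) {
      row[h] = static_cast<float>(row[h] / denom) * gamma[h];  // no mean, no beta
    } else {
      row[h] = static_cast<float>((row[h] - mean) / denom) * gamma[h] +
               (beta ? beta[h] : 0.0f);
    }
  }
}

int main() {
  std::vector<float> row{1.0f, -2.0f, 3.0f, -4.0f};
  std::vector<float> gamma(4, 1.0f);
  NormalizeRow(row, gamma, /*beta=*/nullptr, 1e-5f, /*simplified=*/true);
  for (float v : row) std::printf("%f ", v);
  std::printf("\n");
  return 0;
}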
if (info.GetAttr("init_decoder", &proto).IsOK()) { has_init_decoder_ = true; @@ -90,11 +100,11 @@ Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) { const auto& node = Node(); - if (parameters_.model_type == IGenerationParameters::kModelTypeGpt) { + if (parameters_->model_type == IGenerationParameters::kModelTypeGpt) { if (attribute_name == "decoder") { ORT_ENFORCE(gpt_subgraph_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph."); auto res = gpt_details::CreateGptSubgraphAndUpdateParameters(node, session_state, attribute_name, - subgraph_session_state, parameters_); + subgraph_session_state, *parameters_); auto status = res.first; if (!status.IsOK()) { @@ -109,7 +119,7 @@ Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state, // updated once for the 'decoder' attribute). In future, find a way to update 'parameters' only once based on only one subgraph // attribute. auto res = gpt_details::CreateGptSubgraphAndUpdateParameters(node, session_state, attribute_name, - subgraph_session_state, parameters_); + subgraph_session_state, *parameters_); auto status = res.first; if (!status.IsOK()) { @@ -119,7 +129,7 @@ Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state, init_run_gpt_subgraph_ = std::move(res.second); init_run_decoder_feeds_fetches_manager_ = init_run_gpt_subgraph_->GetFeedsFetchesManager(); } - } else if (parameters_.model_type == IGenerationParameters::kModelTypeT5) { + } else if (parameters_->model_type == IGenerationParameters::kModelTypeT5) { if (attribute_name == "encoder") { ORT_ENFORCE(t5_encoder_subgraph_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph."); @@ -129,7 +139,7 @@ Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state, ORT_RETURN_IF_ERROR(t5_encoder_subgraph_->Setup(session_state, subgraph_session_state)); encoder_feeds_fetches_manager_ = t5_encoder_subgraph_->GetFeedsFetchesManager(); - if (parameters_.decoder_start_token_id < 0) { + if (parameters_->decoder_start_token_id < 0) { ORT_RETURN_IF(t5_encoder_subgraph_->num_subgraph_inputs != 2, "Encoder subgraph shall have 2 inputs when decoder_start_token_id attribute is empty"); } else { @@ -144,12 +154,12 @@ Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state, subgraph_session_state.GetGraphViewer()); ORT_RETURN_IF_ERROR(t5_decoder_subgraph_->Setup(session_state, subgraph_session_state)); decoder_feeds_fetches_manager_ = t5_decoder_subgraph_->GetFeedsFetchesManager(); - parameters_.SetSubgraphParameters(t5_decoder_subgraph_->vocab_size, - t5_decoder_subgraph_->num_heads, - t5_decoder_subgraph_->head_size, - t5_decoder_subgraph_->num_layers); + parameters_->SetSubgraphParameters(t5_decoder_subgraph_->vocab_size, + t5_decoder_subgraph_->num_heads, + t5_decoder_subgraph_->head_size, + t5_decoder_subgraph_->num_layers); } - } else if (parameters_.model_type == IGenerationParameters::kModelTypeWhisper) { + } else if (parameters_->model_type == IGenerationParameters::kModelTypeWhisper) { if (attribute_name == "encoder") { ORT_ENFORCE(whisper_encoder_subgraph_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph."); @@ -169,10 +179,10 @@ Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state, subgraph_session_state.GetGraphViewer()); 
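Editor's note: the `parameters_.` to `parameters_->` changes above follow from BeamSearch now owning its parameters through a unique_ptr, which lets the new WhisperBeamSearch kernel inject a WhisperBeamSearchParameters subclass whose ParseFromAttributes override is resolved virtually. A hedged sketch of that ownership pattern, with illustrative stand-in names rather than the real kernel classes:

#include <iostream>
#include <memory>

struct Params {
  virtual ~Params() = default;
  virtual void ParseFromAttributes() { std::cout << "base attributes\n"; }
};

struct WhisperParams : Params {
  void ParseFromAttributes() override {
    std::cout << "whisper attributes (cross QK ids, no_speech_token)\n";
  }
};

class Search {
 public:
  explicit Search(std::unique_ptr<Params> p = std::make_unique<Params>())
      : parameters_(std::move(p)) {
    parameters_->ParseFromAttributes();  // virtual dispatch picks the right parser
  }

 protected:
  std::unique_ptr<Params> parameters_;  // owned polymorphically, hence "->" everywhere
};

class WhisperSearch : public Search {
 public:
  WhisperSearch() : Search(std::make_unique<WhisperParams>()) {}
};

int main() {
  Search s;          // prints "base attributes"
  WhisperSearch ws;  // prints "whisper attributes ..."
  return 0;
}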
ORT_RETURN_IF_ERROR(whisper_decoder_subgraph_->Setup(session_state, subgraph_session_state)); decoder_feeds_fetches_manager_ = whisper_decoder_subgraph_->GetFeedsFetchesManager(); - parameters_.SetSubgraphParameters(whisper_decoder_subgraph_->vocab_size, - whisper_decoder_subgraph_->num_heads, - whisper_decoder_subgraph_->head_size, - whisper_decoder_subgraph_->num_layers); + parameters_->SetSubgraphParameters(whisper_decoder_subgraph_->vocab_size, + whisper_decoder_subgraph_->num_heads, + whisper_decoder_subgraph_->head_size, + whisper_decoder_subgraph_->num_layers); } } @@ -197,9 +207,9 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const { concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); // Make a copy of parameters since we will update it based on inputs later - BeamSearchParameters parameters = parameters_; + BeamSearchParameters parameters = *parameters_; - if (parameters_.model_type == IGenerationParameters::kModelTypeGpt) { + if (parameters.model_type == IGenerationParameters::kModelTypeGpt) { if (!gpt_subgraph_->IsOutputFloat16()) { // Output float32 BeamSearchGpt impl{ *ctx_internal, @@ -253,7 +263,7 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const { ORT_ENFORCE(encoder_session_state, "Subgraph SessionState was not found for 'encoder' attribute."); ORT_ENFORCE(encoder_feeds_fetches_manager_, "CreateFeedsFetchesManager must be called prior to execution of graph."); - if (parameters_.model_type == IGenerationParameters::kModelTypeT5) { + if (parameters.model_type == IGenerationParameters::kModelTypeT5) { // Subgraph has constraint that the output is either float or float16 if (!t5_decoder_subgraph_->IsOutputFloat16()) { BeamSearchT5 impl{ @@ -303,7 +313,7 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const { } // Change the CreateEncoderInputs function for Whisper shapes - if (parameters_.model_type == IGenerationParameters::kModelTypeWhisper) { + if (parameters.model_type == IGenerationParameters::kModelTypeWhisper) { // Subgraph has constraint that the output is either float or float16 if (!whisper_decoder_subgraph_->IsOutputFloat16()) { BeamSearchWhisper impl{ @@ -319,7 +329,10 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const { update_decoder_feeds_func_ ? update_decoder_feeds_func_ : GenerationCpuDeviceHelper::UpdateDecoderFeeds, expand_buffer_float_func_ ? expand_buffer_float_func_ : GenerationCpuDeviceHelper::ExpandBuffer, expand_buffer_float16_func_ ? expand_buffer_float16_func_ : GenerationCpuDeviceHelper::ExpandBuffer, - create_beam_scorer_func_}; + create_beam_scorer_func_, + update_decoder_cross_qk_func_ ? update_decoder_cross_qk_func_ : GenerationCpuDeviceHelper::UpdateDecoderCrossQK, + finalize_decoder_cross_qk_func_ ? finalize_decoder_cross_qk_func_ : GenerationCpuDeviceHelper::FinalizeDecoderCrossQK}; + #ifdef USE_CUDA ORT_RETURN_IF_ERROR(impl.InitializeCuda(reorder_past_state_func_, init_cache_indir_func_, cuda_device_prop_, cuda_device_arch_)); #endif @@ -340,7 +353,10 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const { update_decoder_feeds_fp16_func_ ? update_decoder_feeds_fp16_func_ : GenerationCpuDeviceHelper::UpdateDecoderFeeds, expand_buffer_float_func_, expand_buffer_float16_func_, - create_beam_scorer_func_}; + create_beam_scorer_func_, + update_decoder_cross_qk_func_ ? update_decoder_cross_qk_func_ : GenerationCpuDeviceHelper::UpdateDecoderCrossQK, + finalize_decoder_cross_qk_func_ ? 
finalize_decoder_cross_qk_func_ : GenerationCpuDeviceHelper::FinalizeDecoderCrossQK}; + #ifdef USE_CUDA ORT_RETURN_IF_ERROR(impl.InitializeCuda(reorder_past_state_func_, init_cache_indir_func_, cuda_device_prop_, cuda_device_arch_)); #endif @@ -354,6 +370,10 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const { ORT_THROW("Model type is not supported."); } +Status WhisperBeamSearch::Compute(OpKernelContext* ctx) const { + return BeamSearch::Compute(ctx); +} + } // namespace transformers } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search.h index 93b7e08fabf94..fad7dcc75bcab 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.h @@ -25,11 +25,12 @@ using namespace onnxruntime::controlflow; // namespace of IControlFlowKernel class BeamSearch : public IControlFlowKernel { public: - BeamSearch(const OpKernelInfo& info) + BeamSearch(const OpKernelInfo& info, std::unique_ptr param = std::make_unique()) : IControlFlowKernel(info), encoder_feeds_fetches_manager_(nullptr), decoder_feeds_fetches_manager_(nullptr), dumper_(nullptr) { + parameters_.swap(param); Init(info); } @@ -88,12 +89,16 @@ class BeamSearch : public IControlFlowKernel { const GenerationDeviceHelper::UpdateDecoderFeedsFunc& update_decoder_feeds_fp16_func, const GenerationDeviceHelper::ExpandBufferFunc& expand_buffer_int32_func, const GenerationDeviceHelper::ExpandBufferFunc& expand_buffer_float_func, - const GenerationDeviceHelper::ExpandBufferFunc& expand_buffer_float16_func) { + const GenerationDeviceHelper::ExpandBufferFunc& expand_buffer_float16_func, + const GenerationDeviceHelper::UpdateDecoderCrossQKFunc& update_decoder_cross_qk_func, + const GenerationDeviceHelper::FinalizeDecoderCrossQKFunc& finalize_decoder_cross_qk_func) { update_decoder_feeds_func_ = update_decoder_feeds_func; update_decoder_feeds_fp16_func_ = update_decoder_feeds_fp16_func; expand_buffer_int32_func_ = expand_buffer_int32_func; expand_buffer_float_func_ = expand_buffer_float_func; expand_buffer_float16_func_ = expand_buffer_float16_func; + update_decoder_cross_qk_func_ = update_decoder_cross_qk_func; + finalize_decoder_cross_qk_func_ = finalize_decoder_cross_qk_func; } #ifdef USE_CUDA @@ -101,7 +106,7 @@ class BeamSearch : public IControlFlowKernel { int cuda_device_arch_ = 0; #endif - private: + protected: // Device specific functions GenerationDeviceHelper::AddToFeedsFunc add_to_feeds_func_; GenerationDeviceHelper::TopkFunc topk_func_; @@ -172,9 +177,21 @@ class BeamSearch : public IControlFlowKernel { IConsoleDumper* dumper_; - BeamSearchParameters parameters_; + std::unique_ptr parameters_; bool has_init_decoder_ = false; + + GenerationDeviceHelper::UpdateDecoderCrossQKFunc update_decoder_cross_qk_func_; + + GenerationDeviceHelper::FinalizeDecoderCrossQKFunc finalize_decoder_cross_qk_func_; +}; + +class WhisperBeamSearch : public BeamSearch { + public: + WhisperBeamSearch(const OpKernelInfo& info) + : BeamSearch(info, std::unique_ptr(new WhisperBeamSearchParameters())) {} + + Status Compute(OpKernelContext* ctx) const override; }; } // namespace transformers diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h index 8832b4314bad3..29b38fc234de5 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h +++ 
b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h @@ -17,34 +17,35 @@ struct BeamSearchState : IBeamSearchState { BeamSearchState(const IGenerationParameters& parameters, AllocatorPtr allocator, int has_decoder_masked_attention, - bool use_position) { + bool use_position, + Stream* stream) { size_t batch_beam_size = SafeInt(parameters.batch_size) * parameters.num_beams; size_t next_token_size = SafeInt(batch_beam_size) * parameters.vocab_size; - this->next_token_logits = AllocateBuffer(allocator, next_token_logits_buffer_, next_token_size); - this->next_token_scores = AllocateBuffer(allocator, next_token_scores_buffer_, next_token_size); - this->next_tokens = AllocateBuffer(allocator, next_tokens_buffer_, SafeInt(2) * batch_beam_size); - this->next_indices = AllocateBuffer(allocator, next_indices_buffer_, SafeInt(2) * batch_beam_size); - this->next_scores = AllocateBuffer(allocator, next_scores_buffer_, SafeInt(2) * batch_beam_size); + this->next_token_logits = AllocateBuffer(allocator, next_token_logits_buffer_, next_token_size, stream); + this->next_token_scores = AllocateBuffer(allocator, next_token_scores_buffer_, next_token_size, stream); + this->next_tokens = AllocateBuffer(allocator, next_tokens_buffer_, SafeInt(2) * batch_beam_size, stream); + this->next_indices = AllocateBuffer(allocator, next_indices_buffer_, SafeInt(2) * batch_beam_size, stream); + this->next_scores = AllocateBuffer(allocator, next_scores_buffer_, SafeInt(2) * batch_beam_size, stream); constexpr size_t max_parts_of_vocab = 128; size_t topk_buffer_size = SafeInt(batch_beam_size) * (max_parts_of_vocab + 1) * parameters.num_beams * 2 * 2; - this->topk_buffer = AllocateBuffer(allocator, topk_temp_buffer_, topk_buffer_size); + this->topk_buffer = AllocateBuffer(allocator, topk_temp_buffer_, topk_buffer_size, stream); if (allocator->Info().device.Type() == OrtDevice::GPU) { size_t sequences_elements = SafeInt(2) * batch_beam_size * parameters.max_length; - this->sequences_device = AllocateBuffer(allocator, sequences_device_buffer_, sequences_elements); + this->sequences_device = AllocateBuffer(allocator, sequences_device_buffer_, sequences_elements, stream); } if (use_position) { - this->next_positions = AllocateBuffer(allocator, next_positions_buffer_, batch_beam_size); + this->next_positions = AllocateBuffer(allocator, next_positions_buffer_, batch_beam_size, stream); } - this->beam_scores = AllocateBuffer(allocator, beam_scores_buffer_, batch_beam_size); + this->beam_scores = AllocateBuffer(allocator, beam_scores_buffer_, batch_beam_size, stream); if (parameters.output_scores) { size_t elements = SafeInt(parameters.max_length - parameters.sequence_length) * parameters.batch_size * parameters.num_beams * parameters.vocab_size; - this->scores = AllocateBuffer(allocator, scores_buffer_, elements); + this->scores = AllocateBuffer(allocator, scores_buffer_, elements, stream); this->remaining_scores = this->scores; } @@ -68,35 +69,38 @@ struct BeamSearchState : IBeamSearchState { } private: - BufferUniquePtr next_token_logits_buffer_; - BufferUniquePtr next_token_scores_buffer_; - BufferUniquePtr next_tokens_buffer_; - BufferUniquePtr next_indices_buffer_; - BufferUniquePtr next_scores_buffer_; - BufferUniquePtr next_positions_buffer_; - BufferUniquePtr beam_scores_buffer_; - BufferUniquePtr scores_buffer_; - BufferUniquePtr topk_temp_buffer_; - BufferUniquePtr sequences_device_buffer_; + IAllocatorUniquePtr next_token_logits_buffer_; + IAllocatorUniquePtr next_token_scores_buffer_; + IAllocatorUniquePtr 
next_tokens_buffer_; + IAllocatorUniquePtr next_indices_buffer_; + IAllocatorUniquePtr next_scores_buffer_; + IAllocatorUniquePtr next_positions_buffer_; + IAllocatorUniquePtr beam_scores_buffer_; + IAllocatorUniquePtr scores_buffer_; + IAllocatorUniquePtr topk_temp_buffer_; + IAllocatorUniquePtr sequences_device_buffer_; }; struct BeamSearchCpuState : IBeamSearchCpuState { Sequences sequences; - BeamSearchCpuState(const IGenerationParameters& parameters, AllocatorPtr allocator, bool is_cuda) + BeamSearchCpuState(const IGenerationParameters& parameters, AllocatorPtr allocator, bool is_cuda, Stream* stream) : parameters_{parameters} { - sequence_lengths = AllocateBuffer(allocator, sequence_lengths_buffer_, batch_beam_size_); + sequence_lengths = AllocateBuffer(allocator, sequence_lengths_buffer_, batch_beam_size_, stream); size_t sequences_bytes = SafeInt(2) * batch_beam_size_ * parameters.max_length; - sequences_space = AllocateBuffer(allocator, sequences_space_buffer_, sequences_bytes, true /* fill */); + sequences_space = AllocateBuffer(allocator, sequences_space_buffer_, sequences_bytes, stream, true /* fill */); sequences.Init(sequences_space, batch_beam_size_, parameters.sequence_length, parameters.max_length); if (is_cuda) { // buffers used by CUDA operator but not by CPU operator. - topk_scores = AllocateBuffer(allocator, topk_scores_buffer_, 2 * static_cast(batch_beam_size_)); - topk_tokens = AllocateBuffer(allocator, topk_tokens_buffer_, 2 * static_cast(batch_beam_size_)); - topk_indices = AllocateBuffer(allocator, topk_indices_buffer_, 2 * static_cast(batch_beam_size_)); - final_beam_scores = AllocateBuffer(allocator, final_beam_scores_buffer_, batch_beam_size_); + topk_scores = AllocateBuffer(allocator, topk_scores_buffer_, 2 * static_cast(batch_beam_size_), stream); + topk_tokens = AllocateBuffer(allocator, topk_tokens_buffer_, 2 * static_cast(batch_beam_size_), stream); + topk_indices = AllocateBuffer(allocator, topk_indices_buffer_, 2 * static_cast(batch_beam_size_), stream); + final_beam_scores = AllocateBuffer(allocator, final_beam_scores_buffer_, batch_beam_size_, stream); + + size_t next_token_size = SafeInt(batch_beam_size_) * parameters.vocab_size; + next_token_scores = AllocateBuffer(allocator, next_token_scores_buffer_, next_token_size, stream); } } @@ -124,12 +128,13 @@ struct BeamSearchCpuState : IBeamSearchCpuState { const IGenerationParameters& parameters_; const int batch_beam_size_{parameters_.batch_size * parameters_.num_beams}; - BufferUniquePtr final_beam_scores_buffer_; - BufferUniquePtr sequence_lengths_buffer_; - BufferUniquePtr topk_scores_buffer_; - BufferUniquePtr topk_tokens_buffer_; - BufferUniquePtr topk_indices_buffer_; - BufferUniquePtr sequences_space_buffer_; + IAllocatorUniquePtr final_beam_scores_buffer_; + IAllocatorUniquePtr sequence_lengths_buffer_; + IAllocatorUniquePtr topk_scores_buffer_; + IAllocatorUniquePtr topk_tokens_buffer_; + IAllocatorUniquePtr topk_indices_buffer_; + IAllocatorUniquePtr sequences_space_buffer_; + IAllocatorUniquePtr next_token_scores_buffer_; }; // Base class of beam search implementation that is common for GPT-2, T5, and Whisper. 
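Editor's note: the state structs above switch from BufferUniquePtr to stream-aware IAllocatorUniquePtr backing buffers, and AllocateBuffer gains a Stream* argument while still handing back a typed span over the raw allocation. The following is a simplified sketch of that shape; the Allocator and Stream types are stand-ins for the ORT interfaces (assumptions, not the real API), and std::span plays the role of gsl::span.

#include <algorithm>
#include <cstddef>
#include <memory>
#include <span>  // C++20; gsl::span plays this role in the real code

struct Stream {};  // stand-in: the real allocation can be tied to a device stream

struct Allocator {
  // Stand-in for IAllocator::MakeUniquePtr(allocator, bytes, false, stream).
  std::unique_ptr<std::byte[]> Alloc(size_t bytes, Stream* /*stream*/) {
    return std::make_unique<std::byte[]>(bytes);
  }
};

// Allocate `elements` values of T, keep the raw buffer alive in `buffer`,
// and return a typed span over it, optionally filled with a value.
template <typename T>
std::span<T> AllocateBuffer(Allocator& allocator, std::unique_ptr<std::byte[]>& buffer,
                            size_t elements, Stream* stream,
                            bool fill = false, T fill_value = T{}) {
  buffer = allocator.Alloc(sizeof(T) * elements, stream);
  T* first = reinterpret_cast<T*>(buffer.get());
  std::span<T> s{first, elements};
  if (fill) std::fill(s.begin(), s.end(), fill_value);
  return s;
}

int main() {
  Allocator alloc;
  Stream stream;
  std::unique_ptr<std::byte[]> backing;  // plays the IAllocatorUniquePtr role
  auto scores = AllocateBuffer<float>(alloc, backing, 16, &stream, /*fill=*/true, 0.0f);
  return static_cast<int>(scores.size()) == 16 ? 0 : 1;
}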
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h index 205d94fae9fab..56d950ca2f41e 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h @@ -215,7 +215,8 @@ Status BeamSearchGpt::Execute(const FeedsFetchesManager* init_run_feeds_fetch BeamSearchCpuState cpu_state{*parameters, this->cpu_allocator_, - this->IsCuda()}; + this->IsCuda(), + this->ort_stream_}; // buffer in GPU for input_ids, position_ids and attention_mask IAllocatorUniquePtr buffer; @@ -240,7 +241,8 @@ Status BeamSearchGpt::Execute(const FeedsFetchesManager* init_run_feeds_fetch BeamSearchState beam_state{*parameters, this->temp_space_allocator_, gpt_subgraph_.has_decoder_masked_attention_, - true /* use_position */}; + true /* use_position */, + this->ort_stream_}; init_beam_state_func_(&beam_state, cpu_state.sequence_lengths, diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h index 14a0db57c45de..94547887d3a90 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h @@ -144,7 +144,8 @@ Status BeamSearchT5::Execute(const FeedsFetchesManager& encoder_feeds_fetches BeamSearchCpuState cpu_state{*parameters, this->cpu_allocator_, - this->IsCuda()}; + this->IsCuda(), + this->ort_stream_}; IAllocatorUniquePtr buffer; @@ -195,7 +196,8 @@ Status BeamSearchT5::Execute(const FeedsFetchesManager& encoder_feeds_fetches BeamSearchState beam_state{*parameters, this->temp_space_allocator_, decoder_subgraph_.has_decoder_masked_attention_, - false /* use_position */}; + false /* use_position */, + this->ort_stream_}; init_beam_state_func_(&beam_state, cpu_state.sequence_lengths, diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h index 198dec011c56f..91b93a125ad7a 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h @@ -36,7 +36,9 @@ class BeamSearchWhisper : public BeamSearchBase { const GenerationDeviceHelper::UpdateDecoderFeedsFunc& update_decoder_feeds_func, const GenerationDeviceHelper::ExpandBufferFunc& expand_buffer_float_func, const GenerationDeviceHelper::ExpandBufferFunc& expand_buffer_float16_func, - const GenerationDeviceHelper::CreateBeamScorer& create_beam_scorer_func) + const GenerationDeviceHelper::CreateBeamScorer& create_beam_scorer_func, + const GenerationDeviceHelper::UpdateDecoderCrossQKFunc& update_decoder_cross_qk_func, + const GenerationDeviceHelper::FinalizeDecoderCrossQKFunc& finalize_decoder_cross_qk_func) : BeamSearchBase(context, decoder_session_state, thread_pool, ort_stream, cuda_dumper, params, topk_func, process_logits_func, device_copy_func, device_copy_int32_func), @@ -49,7 +51,11 @@ class BeamSearchWhisper : public BeamSearchBase { update_decoder_feeds_func_(update_decoder_feeds_func), expand_buffer_float_func_(expand_buffer_float_func), expand_buffer_float16_func_(expand_buffer_float16_func), - create_beam_scorer_func_(create_beam_scorer_func) {} + create_beam_scorer_func_(create_beam_scorer_func), + update_decoder_cross_qk_func_(update_decoder_cross_qk_func), + finalize_decoder_cross_qk_func_(finalize_decoder_cross_qk_func), + 
cuda_device_prop_(nullptr), + cuda_device_arch_(0) {} #ifdef USE_CUDA Status InitializeCuda( @@ -95,6 +101,8 @@ class BeamSearchWhisper : public BeamSearchBase { GenerationDeviceHelper::ExpandBufferFunc expand_buffer_float16_func_; GenerationDeviceHelper::CreateBeamScorer create_beam_scorer_func_; + const GenerationDeviceHelper::UpdateDecoderCrossQKFunc update_decoder_cross_qk_func_; + const GenerationDeviceHelper::FinalizeDecoderCrossQKFunc finalize_decoder_cross_qk_func_; const void* cuda_device_prop_ = nullptr; int cuda_device_arch_ = 0; }; @@ -122,6 +130,17 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe TensorShape scores_shape(&scores_dims[0], sizeof(scores_dims) / sizeof(scores_dims[0])); Tensor* output_scores = this->context_.Output(2, scores_shape); + if (parameters->no_speech_probs_output_id > 0) { + TensorShape no_speech_probs_shape{parameters->batch_size}; + Tensor* no_speech_probs = this->context_.Output(parameters->no_speech_probs_output_id, no_speech_probs_shape); + if (no_speech_probs && no_speech_probs->MutableData()) { + ORT_ENFORCE(parameters->no_speech_token >= 0 && parameters->no_speech_token < parameters->vocab_size, + "no_speech_token id out of range, it is ", parameters->no_speech_token, + ", vocab_size is ", parameters->vocab_size); + this->parameters_->no_speech_probs = (void*)no_speech_probs->MutableData(); + } + } + // Update the flag to indicate whether scores exists in output this->parameters_->output_scores = (output_scores != nullptr); @@ -136,7 +155,8 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe BeamSearchCpuState cpu_state{*parameters, this->cpu_allocator_, - this->IsCuda()}; + this->IsCuda(), + this->ort_stream_}; IAllocatorUniquePtr buffer; @@ -188,7 +208,8 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe BeamSearchState beam_state{*parameters, this->temp_space_allocator_, decoder_subgraph_.has_decoder_masked_attention_, - false /* use_position */}; + false /* use_position */, + this->ort_stream_}; init_beam_state_func_(&beam_state, cpu_state.sequence_lengths, @@ -222,6 +243,16 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe std::vector decoder_feeds; int current_length = parameters->sequence_length; + // for decoder subgraph output cross qk + int64_t frames_of_k = 0LL; + Tensor* cross_qk_output = nullptr; // output tensor + int64_t cross_qk_layer_head_pair_count = 0LL; + OrtValue cross_qk_buffer_value; + float* cross_qk_buffer_data = nullptr; + std::vector cross_qk_all_layer_heads; + const int32_t* cross_qk_layer_head_pairs = nullptr; + IAllocatorUniquePtr qk_layer_pointers; // if needed, device array hold the cross qk data pointers, shape of [num_layers] + std::vector decoder_fetches; if (current_length + 1 < parameters->max_length) { @@ -265,6 +296,41 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe } } + if (decoder_subgraph_.output_cross_qk_) { + ORT_ENFORCE(decoder_subgraph_.has_decoder_masked_attention_, "decoder subgraph: output_cross_qk could only work with has_decoder_masked_attention"); + ORT_ENFORCE(decoder_subgraph_.past_present_share_buffer_, "decoder subgraph: output_cross_qk could only work with past_present_share_buffer"); + + cross_qk_layer_head_pair_count = parameters->num_layers * parameters->num_heads; + const auto* input_tensor_cross_qk_layer_head = this->context_.template Input(parameters->cross_qk_layer_head_input_id); + ORT_ENFORCE(input_tensor_cross_qk_layer_head != 
nullptr, "Must specify input cross_qk_layer_head"); + cross_qk_layer_head_pair_count = input_tensor_cross_qk_layer_head->Shape()[0]; + cross_qk_layer_head_pairs = input_tensor_cross_qk_layer_head->template Data(); // it is on GPU + + size_t decoder_input_first_cross_key = static_cast(decoder_subgraph_.GetFirstPastInputIndex()) + (2 * decoder_subgraph_.num_layers); + auto first_cross_attention_key = decoder_feeds[decoder_input_first_cross_key].GetMutable(); + frames_of_k = first_cross_attention_key->Shape()[2]; + + TensorShape layer_cross_qk_shape{ + static_cast(parameters->BatchBeamSize()), + static_cast(parameters->num_heads), + 1LL, + static_cast(frames_of_k)}; + for (int layer = 0; layer < decoder_subgraph_.num_layers; layer++) { + OrtValue cross_qk_value; + Tensor::InitOrtValue(DataTypeImpl::GetType(), layer_cross_qk_shape, this->temp_space_allocator_, cross_qk_value); + decoder_fetches.emplace_back(cross_qk_value); + } + + TensorShape cross_qk_shape{ + static_cast(parameters->batch_size), + static_cast(parameters->num_beams), + cross_qk_layer_head_pair_count, + static_cast(parameters->max_length), + frames_of_k}; + Tensor::InitOrtValue(DataTypeImpl::GetType(), cross_qk_shape, this->temp_space_allocator_, cross_qk_buffer_value); + cross_qk_buffer_data = cross_qk_buffer_value.GetMutable()->MutableData(); + } + if (decoder_subgraph_.has_decoder_masked_attention_) { size_t offset = static_cast(decoder_subgraph_.GetFirstPastInputIndex()); // Need to check cross attention's past key tensor size, suppose all layers cross attention key size are same @@ -316,6 +382,21 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe ORT_RETURN_IF_ERROR(status); + if (decoder_subgraph_.output_cross_qk_) { + int decoder_output_first_cross_qk = decoder_subgraph_.GetFirstPresentOutputIndex() + (2 * decoder_subgraph_.num_layers); + ORT_RETURN_IF_ERROR(this->update_decoder_cross_qk_func_( + iteration_counter, + this->ort_stream_, + &decoder_fetches[decoder_output_first_cross_qk], + qk_layer_pointers, + parameters->num_layers, + static_cast(cross_qk_layer_head_pair_count), + cross_qk_layer_head_pairs, + cross_qk_buffer_data, + parameters->max_length, + this->temp_space_allocator_)); + } + #ifdef DEBUG_GENERATION for (int i = 0; i <= decoder_subgraph_.GetFirstPresentOutputIndex(); i++) { dumper->Print("decoder_fetches", i, true); @@ -383,6 +464,35 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe } } + if (decoder_subgraph_.output_cross_qk_) { + TensorShape cross_qk_shape{ + static_cast(parameters->batch_size), + static_cast(parameters->num_return_sequences), + cross_qk_layer_head_pair_count, + static_cast(iteration_counter - 1), + frames_of_k}; + cross_qk_output = this->context_.Output(parameters->cross_qk_output_id, cross_qk_shape); + + size_t cache_indir_input_offset = static_cast(decoder_subgraph_.GetFirstPastInputIndex()) + 4 * static_cast(decoder_subgraph_.num_layers) + 2; + const int* cache_indir_data = decoder_feeds[cache_indir_input_offset].GetMutable()->Data(); + auto beam_indices = this->beam_scorer_->GetNextIndicesGPU(); // currently only support on GPU + ORT_RETURN_IF_ERROR(this->finalize_decoder_cross_qk_func_( + this->ort_stream_, + iteration_counter, + parameters->sequence_length, + parameters->batch_size, + parameters->num_beams, + parameters->max_length, + static_cast(cross_qk_layer_head_pair_count), + cross_qk_layer_head_pairs, + static_cast(frames_of_k), + cross_qk_buffer_data, + cross_qk_output->MutableData(), + 
parameters->num_return_sequences, + cache_indir_data, + beam_indices)); + } + gsl::span final_beam_scores = beam_state.beam_scores; this->beam_scorer_->Finalize(cpu_state.sequences, final_beam_scores, diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc index 76011a5c89b66..3962486d5b5eb 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc @@ -47,6 +47,23 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) { } batch_size = static_cast(dims[0]); + extra_decoding_ids = gsl::span(); + if (this->model_type == IGenerationParameters::kModelTypeWhisper && extra_decoding_ids_input_id > 0) { + const Tensor* extra_decoder_tensor = context->Input(extra_decoding_ids_input_id); + if (extra_decoder_tensor != nullptr) { + const auto& extra_decoder_tensor_dims = extra_decoder_tensor->Shape().GetDims(); + ORT_ENFORCE(extra_decoder_tensor_dims.size() == 2, + "extra_decoder_tensor shall have 2 dimensions. Got ", + extra_decoder_tensor_dims.size()); + ORT_ENFORCE(extra_decoder_tensor_dims[0] == batch_size, + "extra_decoder_tensor first dim not same as batch_size. Got ", + extra_decoder_tensor_dims[0], ", expecting ", batch_size); + if (extra_decoder_tensor->Shape().Size() > 0) { + extra_decoding_ids = gsl::span(extra_decoder_tensor->Data(), (size_t)extra_decoder_tensor->Shape().Size()); + } + } + } + if (this->model_type == IGenerationParameters::kModelTypeGpt) { sequence_length = static_cast(dims[1]); } else if (this->model_type == IGenerationParameters::kModelTypeWhisper) { @@ -119,6 +136,18 @@ void BeamSearchParameters::SetSubgraphParameters(int vocabulary_size, int heads, num_layers = layers; } +void WhisperBeamSearchParameters::ParseFromAttributes(const OpKernelInfo& info) { + BeamSearchParameters::ParseFromAttributes(info); + model_type = static_cast(info.GetAttrOrDefault("model_type", IGenerationParameters::kModelTypeWhisper)); + ORT_ENFORCE(model_type == IGenerationParameters::kModelTypeWhisper); + + no_speech_token = static_cast(info.GetAttrOrDefault("no_speech_token", -1LL)); + cross_qk_layer_head_input_id = 12; + extra_decoding_ids_input_id = 13; + cross_qk_output_id = 3; + no_speech_probs_output_id = 4; +} + } // namespace transformers } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h index 0cb2b39976cc3..87bc6cdbfe72c 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h @@ -11,17 +11,23 @@ namespace contrib { namespace transformers { struct BeamSearchParameters : public IGenerationParameters { + virtual ~BeamSearchParameters() {} + Status Validate() const; int BatchBeamSize() const { return batch_size * num_beams; } - void ParseFromAttributes(const OpKernelInfo& info); + virtual void ParseFromAttributes(const OpKernelInfo& info); void ParseFromInputs(OpKernelContext* context); void SetSubgraphParameters(int vocab_size, int num_heads, int head_size, int num_layers); }; +struct WhisperBeamSearchParameters : public BeamSearchParameters { + void ParseFromAttributes(const OpKernelInfo& info) override; +}; + } // namespace transformers } // namespace contrib } // namespace onnxruntime diff --git 
a/onnxruntime/contrib_ops/cpu/transformers/generate_impl_base.h b/onnxruntime/contrib_ops/cpu/transformers/generate_impl_base.h index e889281abb023..680cb23fd887a 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generate_impl_base.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generate_impl_base.h @@ -33,24 +33,43 @@ gsl::span AllocateBuffer(AllocatorPtr allocator, return span; } +template +gsl::span AllocateBuffer(AllocatorPtr allocator, + IAllocatorUniquePtr& buffer, + size_t elements, + Stream* stream, + bool fill = false, + T fill_value = T{}) { + size_t bytes = SafeInt(sizeof(T)) * elements; + buffer = IAllocator::MakeUniquePtr(allocator, bytes, false, stream); + T* first = reinterpret_cast(buffer.get()); + auto span = gsl::make_span(first, elements); + + if (fill) { + std::fill_n(first, elements, fill_value); + } + + return span; +} + template inline void AllocateTempBufferForGetGreedySearchTopOne( int32_t batch_size, AllocatorPtr allocator, - BufferUniquePtr& buffer, + IAllocatorUniquePtr& buffer, gsl::span& stage_1_scores, // shape (batch_size, parts_of_vocab) gsl::span& stage_1_tokens, // shape (batch_size, parts_of_vocab) gsl::span& output_scores, // shape (batch_size) - gsl::span& output_tokens // shape (batch_size) -) { + gsl::span& output_tokens, // shape (batch_size) + Stream* stream) { constexpr size_t kMaxPartsPerVocab = 128; const size_t stage_1_element_size = kMaxPartsPerVocab * batch_size; const size_t output_element_size = batch_size; // Note: use float to allocate buffer for temporary value buffer to avoid unalignment - void* topk_data = allocator->Alloc((stage_1_element_size + output_element_size) * (sizeof(float) + sizeof(int32_t))); - BufferUniquePtr temp_buffer(topk_data, BufferDeleter(allocator)); - buffer = std::move(temp_buffer); + size_t bytes = (stage_1_element_size + output_element_size) * (sizeof(float) + sizeof(int32_t)); + buffer = IAllocator::MakeUniquePtr(allocator, bytes, false, stream); + void* topk_data = buffer.get(); ElementType* stage_1_scores_data = reinterpret_cast(topk_data); stage_1_scores = gsl::make_span(stage_1_scores_data, stage_1_element_size); diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc index 88348ad88dc27..927d3a58e5a6f 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc @@ -1084,6 +1084,38 @@ template Status CreateWhisperEncoderInputs( OrtValue& encoder_input_features, OrtValue& decoder_input_ids); +Status UpdateDecoderCrossQK( + [[maybe_unused]] int iteration_number, + [[maybe_unused]] Stream* tream, + [[maybe_unused]] OrtValue* cross_qks, + [[maybe_unused]] IAllocatorUniquePtr& qk_layer_pointers, + [[maybe_unused]] int num_layers, + [[maybe_unused]] int cross_qk_layer_head_pair_count, + [[maybe_unused]] const int* cross_qk_layer_head_pairs, + [[maybe_unused]] float* cross_qk_buffer_data, + [[maybe_unused]] int max_length, + [[maybe_unused]] AllocatorPtr allocator) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "CPU beam search current not support output cross QK."); +} + +Status FinalizeDecoderCrossQK( + [[maybe_unused]] Stream* stream, + [[maybe_unused]] int iteration_number, + [[maybe_unused]] int context_decoding_len, + [[maybe_unused]] int batch_size, + [[maybe_unused]] int num_beams, + [[maybe_unused]] int max_length, + [[maybe_unused]] int cross_qk_layer_head_pair_count, + [[maybe_unused]] const int* 
cross_qk_layer_head_pairs, + [[maybe_unused]] int frames_of_k, + [[maybe_unused]] const float* cross_qk_buffer_data, + [[maybe_unused]] float* cross_qk_output, + [[maybe_unused]] int num_return_sequences, + [[maybe_unused]] const int* cache_indir_data, + [[maybe_unused]] gsl::span beam_indices) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "CPU beam search current not support output cross QK."); +} + } // namespace GenerationCpuDeviceHelper } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h index ba1b0b662f1a5..6dfdc6b027671 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h @@ -204,6 +204,35 @@ using ExpandBufferFunc = std::function; + +using UpdateDecoderCrossQKFunc = std::function& qk_layer_pointers, + int num_layers, + int cross_qk_layer_head_pair_count, + const int* cross_qk_layer_head_pairs, + float* cross_qk_buffer_data, + int max_length, + AllocatorPtr allocator)>; + +using FinalizeDecoderCrossQKFunc = std::function beam_indices)>; + } // namespace GenerationDeviceHelper // These are CPU specific device helper implementations @@ -368,6 +397,34 @@ Status ExpandBuffer( bool only_copy_shape, int max_sequence_length); +Status UpdateDecoderCrossQK( + int iteration_number, + Stream* stream, + OrtValue* cross_qks, + IAllocatorUniquePtr& qk_layer_pointers, + int num_layers, + int cross_qk_layer_head_pair_count, + const int* cross_qk_layer_head_pairs, + float* cross_qk_buffer_data, + int max_length, + AllocatorPtr allocator); + +Status FinalizeDecoderCrossQK( + Stream* stream, + int iteration_number, + int context_decoding_len, + int batch_size, + int num_beams, + int max_length, + int cross_qk_layer_head_pair_count, + const int* cross_qk_layer_head_pairs, + int frames_of_k, + const float* cross_qk_buffer_data, + float* cross_qk_output, + int num_return_sequences, + const int* cache_indir_data, + gsl::span beam_indices); + } // namespace GenerationCpuDeviceHelper } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h index 719dd302d274d..f6faf2e325f8f 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h @@ -53,6 +53,7 @@ struct IBeamSearchCpuState { gsl::span topk_tokens; // shape (batch_size, 2*num_beams), tokens of topk candidates. gsl::span topk_indices; // shape (batch_size, 2*num_beams), beam indices of topk candidates. 
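Editor's note: the new UpdateDecoderCrossQK/FinalizeDecoderCrossQK entries above follow the existing device-helper convention: the kernel stores a std::function that a device provider (e.g. CUDA) may override, and otherwise falls back to a CPU default that reports the feature as unsupported. A small sketch of that hook-with-fallback pattern, using simplified stand-in types rather than the real Status and helper signatures:

#include <functional>
#include <iostream>
#include <string>

struct Status {
  bool ok;
  std::string message;
};

using UpdateCrossQKFunc = std::function<Status(int iteration, float* buffer)>;

// CPU default: feature not implemented on this provider.
Status CpuUpdateCrossQK(int /*iteration*/, float* /*buffer*/) {
  return {false, "CPU beam search does not support cross-QK output yet."};
}

struct Kernel {
  UpdateCrossQKFunc update_cross_qk_func_;  // empty unless a device provider sets it

  Status Run(float* buffer) {
    // Same "override if set, else CPU fallback" selection as in the diff above.
    UpdateCrossQKFunc fn =
        update_cross_qk_func_ ? update_cross_qk_func_ : UpdateCrossQKFunc(CpuUpdateCrossQK);
    return fn(/*iteration=*/1, buffer);
  }
};

int main() {
  Kernel k;
  float buf[4] = {};
  std::cout << k.Run(buf).message << "\n";  // CPU fallback path
  k.update_cross_qk_func_ = [](int, float*) { return Status{true, "device path"}; };
  std::cout << k.Run(buf).message << "\n";  // device override path
  return 0;
}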
gsl::span final_beam_scores; // shape (batch_size, num_beams) + gsl::span next_token_scores; // shape (batch_size, num_beams * vocab_size) }; template @@ -175,6 +176,17 @@ struct IGenerationParameters { int seed = 0; int min_tokens_to_keep = 1; bool custom_sampling = false; + + // Parameters for whisper model + bool decoder_output_cross_qk = false; + gsl::span extra_decoding_ids; + int32_t no_speech_token = -1; + void* no_speech_probs = nullptr; + + int cross_qk_layer_head_input_id = -1; + int extra_decoding_ids_input_id = -1; + int cross_qk_output_id = -1; + int no_speech_probs_output_id = -1; }; } // namespace transformers diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h index be974ed2159d9..9f372e5b3a673 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h +++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h @@ -20,26 +20,27 @@ struct SamplingState : public ISamplingState { int vocab_size, int max_iter, int seed, - bool is_cuda) { + bool is_cuda, + Stream* stream) { int total_count = batch_size * vocab_size; - this->h_softmaxed_score = AllocateBuffer(cpu_allocator, h_softmaxed_score_buffer_, SafeInt(total_count)); + this->h_softmaxed_score = AllocateBuffer(cpu_allocator, h_softmaxed_score_buffer_, SafeInt(total_count), stream); this->generator = std::default_random_engine{gsl::narrow_cast(seed)}; if (is_cuda) { - this->d_index_in = AllocateBuffer(allocator, d_index_in_buffer_, SafeInt(total_count)); - this->d_index_out = AllocateBuffer(allocator, d_index_out_buffer_, SafeInt(total_count)); - this->d_offset = AllocateBuffer(allocator, d_offset_buffer_, SafeInt(batch_size + 1)); - this->d_sorted_score = AllocateBuffer(allocator, d_sorted_score_buffer_, SafeInt(total_count)); - this->d_sorted_softmaxed_score = AllocateBuffer(allocator, d_sorted_softmaxed_score_buffer_, SafeInt(total_count)); - this->d_softmaxed_score = AllocateBuffer(allocator, d_softmaxed_score_buffer_, SafeInt(total_count)); - this->d_sampled = AllocateBuffer(allocator, d_sampled_buffer_, SafeInt(batch_size)); - this->h_sampled_all = AllocateBuffer(cpu_allocator, h_sampled_all_buffer_, SafeInt(batch_size * max_iter)); - this->d_indices = AllocateBuffer(allocator, d_indices_buffer_, SafeInt(batch_size)); + this->d_index_in = AllocateBuffer(allocator, d_index_in_buffer_, SafeInt(total_count), stream); + this->d_index_out = AllocateBuffer(allocator, d_index_out_buffer_, SafeInt(total_count), stream); + this->d_offset = AllocateBuffer(allocator, d_offset_buffer_, SafeInt(batch_size + 1), stream); + this->d_sorted_score = AllocateBuffer(allocator, d_sorted_score_buffer_, SafeInt(total_count), stream); + this->d_sorted_softmaxed_score = AllocateBuffer(allocator, d_sorted_softmaxed_score_buffer_, SafeInt(total_count), stream); + this->d_softmaxed_score = AllocateBuffer(allocator, d_softmaxed_score_buffer_, SafeInt(total_count), stream); + this->d_sampled = AllocateBuffer(allocator, d_sampled_buffer_, SafeInt(batch_size), stream); + this->h_sampled_all = AllocateBuffer(cpu_allocator, h_sampled_all_buffer_, SafeInt(batch_size * max_iter), stream); + this->d_indices = AllocateBuffer(allocator, d_indices_buffer_, SafeInt(batch_size), stream); this->temp_storage_bytes = 0; // TODO: Do not allocate this buffer if there's no presence_mask - this->d_presence_mask = AllocateBuffer(allocator, d_presence_mask_buffer_, SafeInt(total_count)); + this->d_presence_mask = AllocateBuffer(allocator, 
d_presence_mask_buffer_, SafeInt(total_count), stream); std::uniform_real_distribution distribution(0.0, 1.0); static_cast(distribution(this->generator)); @@ -48,25 +49,25 @@ struct SamplingState : public ISamplingState { } } else { // TODO: Some buffer can be reused for CPU - this->sorted_scores = AllocateBuffer(cpu_allocator, sorted_scores_buffer_, SafeInt(total_count)); - this->cumulative_probs = AllocateBuffer(cpu_allocator, cumulative_probs_buffer_, SafeInt(total_count)); + this->sorted_scores = AllocateBuffer(cpu_allocator, sorted_scores_buffer_, SafeInt(total_count), stream); + this->cumulative_probs = AllocateBuffer(cpu_allocator, cumulative_probs_buffer_, SafeInt(total_count), stream); } } private: - BufferUniquePtr d_index_in_buffer_; - BufferUniquePtr d_index_out_buffer_; - BufferUniquePtr d_offset_buffer_; - BufferUniquePtr d_sorted_score_buffer_; - BufferUniquePtr d_sorted_softmaxed_score_buffer_; - BufferUniquePtr d_softmaxed_score_buffer_; - BufferUniquePtr h_softmaxed_score_buffer_; - BufferUniquePtr d_sampled_buffer_; - BufferUniquePtr h_sampled_all_buffer_; - BufferUniquePtr d_indices_buffer_; - BufferUniquePtr d_presence_mask_buffer_; - BufferUniquePtr sorted_scores_buffer_; - BufferUniquePtr cumulative_probs_buffer_; + IAllocatorUniquePtr d_index_in_buffer_; + IAllocatorUniquePtr d_index_out_buffer_; + IAllocatorUniquePtr d_offset_buffer_; + IAllocatorUniquePtr d_sorted_score_buffer_; + IAllocatorUniquePtr d_sorted_softmaxed_score_buffer_; + IAllocatorUniquePtr d_softmaxed_score_buffer_; + IAllocatorUniquePtr h_softmaxed_score_buffer_; + IAllocatorUniquePtr d_sampled_buffer_; + IAllocatorUniquePtr h_sampled_all_buffer_; + IAllocatorUniquePtr d_indices_buffer_; + IAllocatorUniquePtr d_presence_mask_buffer_; + IAllocatorUniquePtr sorted_scores_buffer_; + IAllocatorUniquePtr cumulative_probs_buffer_; }; template @@ -82,24 +83,25 @@ struct GreedySearchState : public IGreedySearchState { int num_heads, int head_size, bool has_decoder_masked_self_attention, - bool is_cuda) { + bool is_cuda, + Stream* stream) { // below buffers are on cpu this->sequences_space = AllocateBuffer(cpu_allocator, sequences_space_buffer_, - SafeInt(2) * batch_size * max_length); + SafeInt(2) * batch_size * max_length, stream); memset(this->sequences_space.data(), 0, this->sequences_space.size_bytes()); this->sequences.Init(this->sequences_space, static_cast(batch_size), sequence_length, max_length); - this->sequence_lengths = AllocateBuffer(cpu_allocator, sequence_lengths_buffer_, batch_size); - this->eos_meet = AllocateBuffer(cpu_allocator, eos_meet_buffer_, batch_size); + this->sequence_lengths = AllocateBuffer(cpu_allocator, sequence_lengths_buffer_, batch_size, stream); + this->eos_meet = AllocateBuffer(cpu_allocator, eos_meet_buffer_, batch_size, stream); memset(this->eos_meet.data(), 0, this->eos_meet.size_bytes()); - this->next_tokens = AllocateBuffer(cpu_allocator, next_tokens_buffer_, SafeInt(batch_size)); + this->next_tokens = AllocateBuffer(cpu_allocator, next_tokens_buffer_, SafeInt(batch_size), stream); // below buffers are on cpu or cuda size_t next_token_size = SafeInt(batch_size) * vocab_size; - this->next_token_scores = AllocateBuffer(allocator, next_token_scores_buffer_, next_token_size); - this->next_positions = AllocateBuffer(allocator, next_positions_buffer_, batch_size); + this->next_token_scores = AllocateBuffer(allocator, next_token_scores_buffer_, next_token_size, stream); + this->next_positions = AllocateBuffer(allocator, next_positions_buffer_, batch_size, stream); if 
(is_cuda) { AllocateTempBufferForGetGreedySearchTopOne( @@ -109,7 +111,8 @@ struct GreedySearchState : public IGreedySearchState { this->temp_topk_scores_buffer, this->temp_topk_tokens_buffer, this->topk_scores_buffer, - this->topk_tokens_buffer); + this->topk_tokens_buffer, + stream); // If at all we need to, we only need to re-order past state for CUDA as //`DecoderMaskedSelfAttention` is only supported on CUDA @@ -137,14 +140,14 @@ struct GreedySearchState : public IGreedySearchState { } private: - BufferUniquePtr sequences_space_buffer_; - BufferUniquePtr sequence_lengths_buffer_; - BufferUniquePtr next_token_scores_buffer_; - BufferUniquePtr next_tokens_buffer_; - BufferUniquePtr next_positions_buffer_; - BufferUniquePtr eos_meet_buffer_; - BufferUniquePtr temp_topk_buffer_; - BufferUniquePtr staging_for_past_state_reorder_buffer_; + IAllocatorUniquePtr sequences_space_buffer_; + IAllocatorUniquePtr sequence_lengths_buffer_; + IAllocatorUniquePtr next_token_scores_buffer_; + IAllocatorUniquePtr next_tokens_buffer_; + IAllocatorUniquePtr next_positions_buffer_; + IAllocatorUniquePtr eos_meet_buffer_; + IAllocatorUniquePtr temp_topk_buffer_; + IAllocatorUniquePtr staging_for_past_state_reorder_buffer_; }; // Base class of gready search implementation that is common for both GPT-2 and Bart/T5. diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h index 4504b099e32bd..69d25eaabbe02 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h +++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h @@ -211,7 +211,8 @@ Status GreedySearchGpt::Execute(const FeedsFetchesManager* init_ static_cast(parameters->num_heads), static_cast(parameters->head_size), gpt_subgraph_.has_decoder_masked_attention_, - this->IsCuda()); + this->IsCuda(), + this->ort_stream_); SamplingState sampling_state; if (std::is_same::value) { @@ -221,7 +222,8 @@ Status GreedySearchGpt::Execute(const FeedsFetchesManager* init_ static_cast(parameters->vocab_size), static_cast(parameters->max_length - parameters->sequence_length), parameters->seed, - this->IsCuda()); + this->IsCuda(), + this->ort_stream_); } IAllocatorUniquePtr buffer; diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_parameters.h b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_parameters.h index f1150fdba00e8..4ef0c180eba34 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_parameters.h +++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_parameters.h @@ -13,7 +13,7 @@ namespace transformers { struct GreedySearchParameters : public BeamSearchParameters { int BatchBeamSize() const { return batch_size; } - void ParseFromAttributes(const OpKernelInfo& info); + void ParseFromAttributes(const OpKernelInfo& info) override; void ParseFromInputs(OpKernelContext* context); }; diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc index 9f77c32f0c7cc..f39f090c78b0c 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc @@ -17,20 +17,6 @@ namespace onnxruntime { namespace contrib { namespace transformers { -template -gsl::span NextTokenScores::GetScores(int batch_beam_index) { - assert(batch_beam_index >= 0 && batch_beam_index < batch_beam_size); - return scores.subspan(static_cast(batch_beam_index) * vocab_size, vocab_size); -} 
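Note on the buffer-allocation changes above: every AllocateBuffer call now receives the compute Stream* and the owning members switch from BufferUniquePtr to IAllocatorUniquePtr. A minimal sketch of what such a stream-aware helper could look like, assuming it simply forwards to IAllocator::MakeUniquePtr; the template parameters and defaults of the real ORT helper may differ.

template <typename T>
gsl::span<T> AllocateBuffer(AllocatorPtr allocator,
                            IAllocatorUniquePtr<T>& buffer,
                            size_t elements,
                            Stream* stream = nullptr) {
  // Forwarding the stream lets stream-aware allocators (for example CUDA
  // async arenas) tie the allocation's lifetime to the compute stream.
  buffer = IAllocator::MakeUniquePtr<T>(allocator, elements, /*use_reserve=*/false, stream);
  return gsl::span<T>(buffer.get(), elements);
}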
- -template -void NextTokenScores::SetScore(int token_id, T score) { - assert(token_id >= 0 && token_id < vocab_size); - for (int i = 0; i < batch_beam_size; i++) { - scores[static_cast(i) * vocab_size + token_id] = score; - } -} - #ifdef DEBUG_GENERATION template void DumpScores(const char* name, const NextTokenScores& next_token_scores) { @@ -238,128 +224,6 @@ void PresencePenaltyLogitsProcessor::Process(const ISequences*, #endif } -template -TimestampLogitsProcessor::TimestampLogitsProcessor(int eos_token_id, int max_initial_timestamp_index) - : eos_token_id_(eos_token_id), max_initial_timestamp_index_(max_initial_timestamp_index) {} - -template -void TimestampLogitsProcessor::Process(const ISequences* sequences, - NextTokenScores& next_token_scores) { - const int beg_token_id_ = eos_token_id_ + 107; - const int not_token_id_ = eos_token_id_ + 106; - const int solm_token_id_ = eos_token_id_ + 105; - const int sot_token_id_ = eos_token_id_ + 1; - constexpr int translate_token_id_ = 50358; - constexpr int transcribe_token_id_ = 50359; - - const int batch_beam_size = next_token_scores.batch_beam_size; - const int vocab_size = next_token_scores.vocab_size; - for (int i = 0; i < batch_beam_size; i++) { - gsl::span beam_token_scores = next_token_scores.GetScores(i); - gsl::span sequence = sequences->GetSequence(i); - const size_t seq_length = sequence.size(); - - // Find first timestamp - size_t sample_begin = 0; - for (size_t j = 0; j < seq_length; j++) { - sample_begin++; - if (sequence[j] >= beg_token_id_) { - break; - } - } - - // Suppress tokens - for (int j = 0; j < vocab_size; j++) { - // Suppress notimestamps and solm tokens - if (j == not_token_id_ || j == solm_token_id_) { - beam_token_scores[j] = std::numeric_limits::lowest(); - } - - // Suppress sot, translate and transcribe tokens - if (seq_length > sample_begin) { - if (j == sot_token_id_ || j == translate_token_id_ || j == transcribe_token_id_) { - beam_token_scores[j] = std::numeric_limits::lowest(); - } - } - } - - // Timestamps should be in pair except the first one - const bool last_was_timestamp = seq_length > 0 && sequence.back() >= beg_token_id_; - const bool penultimate_was_timestamp = seq_length <= sample_begin || sequence[seq_length - 2] >= beg_token_id_; - if (last_was_timestamp) { - if (penultimate_was_timestamp) { - // If timestamps show up in pair, or it's the first timestamp, no more timestamp is generated - for (int j = beg_token_id_; j < vocab_size; j++) { - beam_token_scores[j] = std::numeric_limits::lowest(); - } - } else { - // If timestamp doesn't show up in pair, generate timestamp - for (int j = 0; j < eos_token_id_; j++) { - beam_token_scores[j] = std::numeric_limits::lowest(); - } - } - } - - // Find timestamp tokens - std::vector timestamps; - for (const auto& word_id : sequence) { - if (word_id >= beg_token_id_) { - timestamps.push_back(word_id); - } - } - - // Timestamps will not decrease - const size_t timestamps_len = timestamps.size(); - if (timestamps_len > 0) { - int timestamp_last = 0; - if (last_was_timestamp && !penultimate_was_timestamp) { - // For single timestamp at the end, next timestamp must not be smaller - timestamp_last = timestamps.back(); - } else { - // For paired timestamp at the end, next timestamp must be greater - timestamp_last = timestamps.back() + 1; - } - - for (int j = beg_token_id_; j < timestamp_last; j++) { - beam_token_scores[j] = std::numeric_limits::lowest(); - } - } - - if (seq_length == sample_begin) { - const int last_allowed = beg_token_id_ + 
max_initial_timestamp_index_; - for (int j = last_allowed + 1; j < vocab_size; j++) { - beam_token_scores[j] = std::numeric_limits::lowest(); - } - } - - // Caculate logsumexp on timestamps - float timestamp_logprob = std::numeric_limits::lowest(); - { - float logsumexp = 0.0f; - const float logprob_max = *std::max_element(beam_token_scores.begin() + beg_token_id_, beam_token_scores.end()); - for (int j = beg_token_id_; j < vocab_size; ++j) { - if (beam_token_scores[j] > std::numeric_limits::lowest()) { - logsumexp += expf(beam_token_scores[j] - logprob_max); - } - } - if (logsumexp > 0.0f) { - timestamp_logprob = logf(logsumexp) + logprob_max; - } - } - - const float max_text_token_logprob = *std::max_element(beam_token_scores.begin(), beam_token_scores.begin() + beg_token_id_); - if (timestamp_logprob > max_text_token_logprob) { - for (int j = 0; j < beg_token_id_; ++j) { - beam_token_scores[j] = std::numeric_limits::lowest(); - } - } - } - -#ifdef DEBUG_GENERATION - DumpScores("TimestampLogitsProcessor", next_token_scores); -#endif -} - void LogitsProcessorList::Init(const BeamSearchParameters& parameters) { LogitsProcessorInitImpl(parameters); } diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h index 664c497a106d4..4688ff272cee9 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h @@ -6,6 +6,7 @@ #include "core/common/inlined_containers.h" #include "contrib_ops/cpu/transformers/sequences.h" #include "contrib_ops/cpu/transformers/beam_search_parameters.h" +#include "contrib_ops/cpu/transformers/dump_tensor.h" #include "contrib_ops/cpu/transformers/greedy_search_parameters.h" #include "contrib_ops/cpu/transformers/sampling_parameters.h" #include "contrib_ops/cpu/transformers/generation_shared.h" @@ -20,9 +21,17 @@ struct NextTokenScores { int batch_beam_size; int vocab_size; - gsl::span GetScores(int batch_beam_index); + gsl::span GetScores(int batch_beam_index) { + assert(batch_beam_index >= 0 && batch_beam_index < batch_beam_size); + return scores.subspan(static_cast(batch_beam_index) * vocab_size, vocab_size); + } - void SetScore(int token_id, T score); + void SetScore(int token_id, T score) { + assert(token_id >= 0 && token_id < vocab_size); + for (int i = 0; i < batch_beam_size; i++) { + scores[static_cast(i) * vocab_size + token_id] = score; + } + } }; // Interface for all scorers for beam search or beam sample. @@ -141,10 +150,126 @@ class PresencePenaltyLogitsProcessor : public ILogitsProcessor { template class TimestampLogitsProcessor : public ILogitsProcessor { public: - TimestampLogitsProcessor(int eos_token_id, int max_initial_timestamp_index); + TimestampLogitsProcessor(int eos_token_id, int max_initial_timestamp_index) + : eos_token_id_(eos_token_id), max_initial_timestamp_index_(max_initial_timestamp_index) {} void Process(const ISequences* sequences, - NextTokenScores& next_token_scores) override; + NextTokenScores& next_token_scores) override { + // TODO: translate_token_id_ and transcribe_token_id_ need to support both multilingual and English-only models. 
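For orientation, the token-id offsets used here can be made concrete under the assumption of the multilingual Whisper vocabulary, where <|endoftext|> is 50257; English-only models use different ids, which is what the TODO above refers to.

// Illustrative values only; the real ids come from the tokenizer in use.
constexpr int eos_token_id = 50257;               // assumed multilingual <|endoftext|>
constexpr int sot_token_id = eos_token_id + 1;    // 50258, <|startoftranscript|>
constexpr int not_token_id = eos_token_id + 106;  // 50363, <|notimestamps|>
constexpr int beg_token_id = eos_token_id + 107;  // 50364, first timestamp token <|0.00|>
static_assert(beg_token_id == 50364, "offset arithmetic check");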
+ const int beg_token_id_ = eos_token_id_ + 107; + const int not_token_id_ = eos_token_id_ + 106; + const int solm_token_id_ = eos_token_id_ + 105; + const int sot_token_id_ = eos_token_id_ + 1; + constexpr int translate_token_id_ = 50358; + constexpr int transcribe_token_id_ = 50359; + + const int batch_beam_size = next_token_scores.batch_beam_size; + const int vocab_size = next_token_scores.vocab_size; + for (int i = 0; i < batch_beam_size; i++) { + gsl::span beam_token_scores = next_token_scores.GetScores(i); + gsl::span sequence = sequences->GetSequence(i); + const size_t seq_length = sequence.size(); + + // Find first timestamp + size_t sample_begin = 0; + for (size_t j = 0; j < seq_length; j++) { + sample_begin++; + if (sequence[j] >= beg_token_id_) { + break; + } + } + + // Suppress tokens + for (int j = 0; j < vocab_size; j++) { + // Suppress notimestamps and solm tokens + if (j == not_token_id_ || j == solm_token_id_) { + beam_token_scores[j] = std::numeric_limits::lowest(); + } + + // Suppress sot, translate and transcribe tokens + if (seq_length > sample_begin) { + if (j == sot_token_id_ || j == translate_token_id_ || j == transcribe_token_id_) { + beam_token_scores[j] = std::numeric_limits::lowest(); + } + } + } + + // Timestamps should be in pair except the first one + const bool last_was_timestamp = seq_length > 0 && sequence.back() >= beg_token_id_; + const bool penultimate_was_timestamp = seq_length <= sample_begin || sequence[seq_length - 2] >= beg_token_id_; + if (last_was_timestamp) { + if (penultimate_was_timestamp) { + // If timestamps show up in pair, or it's the first timestamp, no more timestamp is generated + for (int j = beg_token_id_; j < vocab_size; j++) { + beam_token_scores[j] = std::numeric_limits::lowest(); + } + } else { + // If timestamp doesn't show up in pair, generate timestamp + for (int j = 0; j < eos_token_id_; j++) { + beam_token_scores[j] = std::numeric_limits::lowest(); + } + } + } + + // Find timestamp tokens + std::vector timestamps; + for (const auto& word_id : sequence) { + if (word_id >= beg_token_id_) { + timestamps.push_back(word_id); + } + } + + // Timestamps will not decrease + const size_t timestamps_len = timestamps.size(); + if (timestamps_len > 0) { + int timestamp_last = 0; + if (last_was_timestamp && !penultimate_was_timestamp) { + // For single timestamp at the end, next timestamp must not be smaller + timestamp_last = timestamps.back(); + } else { + // For paired timestamp at the end, next timestamp must be greater + timestamp_last = timestamps.back() + 1; + } + + for (int j = beg_token_id_; j < timestamp_last; j++) { + beam_token_scores[j] = std::numeric_limits::lowest(); + } + } + + if (seq_length == sample_begin) { + const int last_allowed = beg_token_id_ + max_initial_timestamp_index_; + for (int j = last_allowed + 1; j < vocab_size; j++) { + beam_token_scores[j] = std::numeric_limits::lowest(); + } + } + + // Caculate logsumexp on timestamps + float timestamp_logprob = std::numeric_limits::lowest(); + { + float logsumexp = 0.0f; + const float logprob_max = *std::max_element(beam_token_scores.begin() + beg_token_id_, beam_token_scores.end()); + for (int j = beg_token_id_; j < vocab_size; ++j) { + if (beam_token_scores[j] > std::numeric_limits::lowest()) { + logsumexp += expf(beam_token_scores[j] - logprob_max); + } + } + if (logsumexp > 0.0f) { + timestamp_logprob = logf(logsumexp) + logprob_max; + } + } + + const float max_text_token_logprob = *std::max_element(beam_token_scores.begin(), beam_token_scores.begin() + 
beg_token_id_); + if (timestamp_logprob > max_text_token_logprob) { + for (int j = 0; j < beg_token_id_; ++j) { + beam_token_scores[j] = std::numeric_limits::lowest(); + } + } + } + +#ifdef DEBUG_GENERATION + DumpScores("TimestampLogitsProcessor", next_token_scores); +#endif + } private: int eos_token_id_; diff --git a/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.h b/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.h index af203abc15d01..7779a73feffba 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.h +++ b/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.h @@ -11,7 +11,7 @@ namespace contrib { namespace transformers { struct SamplingParameters : public GreedySearchParameters { - void ParseFromAttributes(const OpKernelInfo& info); + void ParseFromAttributes(const OpKernelInfo& info) override; void ParseFromInputs(OpKernelContext* context); }; diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.h b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.h index 3c11d2d324a85..487a35c55a85f 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.h +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.h @@ -45,6 +45,7 @@ class Subgraph { int num_layers; bool past_present_share_buffer_; bool has_decoder_masked_attention_; + bool output_cross_qk_ = false; // Setup execution Status Setup(const SessionState& session_state, diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc index 28acd81ae95fd..4d61ce71c69be 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc @@ -172,7 +172,7 @@ Status T5DecoderSubgraph::CreateInitialFeeds( int32_t* input_ids_data = input_ids.GetMutable()->MutableData(); AllocatorPtr buffer_allocator = std::make_shared(); size_t total_size = static_cast(static_cast(cur_len) * batch_beam_size * sizeof(int)); - auto seq_copy = IAllocator::MakeUniquePtr(buffer_allocator, total_size); + auto seq_copy = IAllocator::MakeUniquePtr(buffer_allocator, total_size, false, stream); int* seq_copy_ptr = seq_copy.get(); if (!use_sequence_as_input_ids_) { diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h index 085d8f3903976..83dae49c7dcbd 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h @@ -5,6 +5,7 @@ #include "contrib_ops/cpu/transformers/subgraph_base.h" #include "contrib_ops/cpu/transformers/sequences.h" +#include "core/framework/op_kernel.h" namespace onnxruntime { namespace contrib { @@ -20,6 +21,13 @@ class T5DecoderSubgraph : public Subgraph { has_hidden_state_(false), use_sequence_as_input_ids_(true) { first_present_output_index_ = 1; + + // Currently just using parent node's attribute. Maybe better to find it purely in subgraph. + const auto& attributes = node_in.GetAttributes(); + if (attributes.find("decoder_output_cross_qk") != attributes.end()) { + auto& attr = attributes.at("decoder_output_cross_qk"); + output_cross_qk_ = (attr.i() != 0LL); + } } // Create inputs for first inference of decoder subgraph. 
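The constructor above reads the parent node's decoder_output_cross_qk attribute straight out of the attribute map. A hypothetical helper (not an existing ORT utility) showing the same optional-attribute lookup pattern:

// Returns default_value when the attribute is absent; ONNX encodes booleans as int64.
inline bool ReadOptionalBoolAttr(const Node& node, const std::string& name, bool default_value) {
  const auto& attributes = node.GetAttributes();
  const auto it = attributes.find(name);
  return it == attributes.end() ? default_value : (it->second.i() != 0LL);
}

// Usage mirroring the constructor above (illustrative):
//   output_cross_qk_ = ReadOptionalBoolAttr(node_in, "decoder_output_cross_qk", false);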
@@ -62,7 +70,7 @@ class T5DecoderSubgraph : public Subgraph { return first_present_output_index_; } - bool UseSequenceAsInputIds() const { + inline bool UseSequenceAsInputIds() const { return use_sequence_as_input_ids_; } diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_decoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_decoder.cc index 887a6a8984b83..7d0c62b618ee2 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_decoder.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_decoder.cc @@ -70,8 +70,15 @@ Status WhisperDecoderSubgraph::Validate(const std::vector& subgr "number of inputs expected to be kFirstPastInputIndex + 4 * layers + 1, got:", num_subgraph_inputs); } - ORT_RETURN_IF(num_subgraph_outputs < 3 || (num_subgraph_outputs - first_present_output_index_) % 2 != 0, - "number of outputs expected to be 1 + 2 * layers, got:", num_subgraph_outputs); + if (!output_cross_qk_) { + ORT_RETURN_IF(num_subgraph_outputs < 3 || (num_subgraph_outputs - first_present_output_index_) % 2 != 0, + "number of outputs expected to be first_present_output_index_", + first_present_output_index_, " + 2 * layers, got:", num_subgraph_outputs); + } else { + ORT_RETURN_IF(num_subgraph_outputs < 4 || (num_subgraph_outputs - first_present_output_index_) % 3 != 0, + "When outputing cross qk, number of outputs expected to be first_present_output_index_", + first_present_output_index_, " + 3 * layers, got:", num_subgraph_outputs); + } ORT_RETURN_IF(subgraph_inputs[0]->Name() != "input_ids", "decoder subgraph input 0 shall be named as input_ids, got: ", subgraph_inputs[0]->Name()); @@ -90,7 +97,8 @@ Status WhisperDecoderSubgraph::Validate(const std::vector& subgr // Save parameters related to the subgraph. ORT_RETURN_IF_ERROR(GetParameters(past_shape, logits_shape, false)); - num_layers = (static_cast(subgraph_outputs.size()) - first_present_output_index_) / 2; + + num_layers = (static_cast(subgraph_outputs.size()) - first_present_output_index_) / (output_cross_qk_ ? 3 : 2); // If input_ids's shape is ['batch_size', 1] then use next token as input_ids. // Otherwise in the case of shape ['batch_size', 'sequence'], use sequence as input_ids. 
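To make the output-count check above concrete, a small worked example; the layer count and the first_present_output_index_ value of 1 are illustrative assumptions, and the real values come from the subgraph being validated.

// With first_present_output_index_ == 1 (logits first) and 6 decoder layers:
//   without cross QK: 1 + 2 * 6 = 13 outputs (present key/value per layer)
//   with    cross QK: 1 + 3 * 6 = 19 outputs (present key/value plus cross QK per layer)
inline int NumLayersFromOutputs(int num_subgraph_outputs,
                                int first_present_output_index,
                                bool output_cross_qk) {
  return (num_subgraph_outputs - first_present_output_index) / (output_cross_qk ? 3 : 2);
}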
@@ -112,12 +120,7 @@ Status WhisperDecoderSubgraph::Validate(const std::vector& subgr for (int i = first_past_input_index_; i < first_past_input_index_ + 4 * num_layers; i++) { ORT_RETURN_IF(subgraph_inputs[i]->TypeAsProto()->tensor_type().elem_type() != float_type, - "decoder subgraph past inputs shall have same data type as that of encoder_hidden_states"); - } - - for (int i = 0; i < num_subgraph_outputs; i++) { - ORT_RETURN_IF(subgraph_outputs[i]->TypeAsProto()->tensor_type().elem_type() != float_type, - "decoder subgraph output shall have same data type as that of encoder_hidden_states"); + "decoder subgraph past inputs shall have same data type as that of encoder_hidden_states."); } is_output_float16_ = (subgraph_outputs[0]->TypeAsProto()->tensor_type().elem_type() == float16_type); @@ -166,7 +169,7 @@ Status WhisperDecoderSubgraph::CreateInitialFeeds( AllocatorPtr buffer_allocator = std::make_shared(); size_t total_size = static_cast(static_cast(cur_len) * batch_beam_size * sizeof(int)); - auto seq_copy = IAllocator::MakeUniquePtr(buffer_allocator, total_size); + auto seq_copy = IAllocator::MakeUniquePtr(buffer_allocator, total_size, false, stream); int* seq_copy_ptr = seq_copy.get(); if (!use_sequence_as_input_ids_) { diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index 48af0a9a6adec..bf6431cf1afb2 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -135,8 +135,24 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { if (use_flash_attention && parameters.sequence_length < min_seq_len_for_flash_attention_packed_qkv_) { use_flash_attention = false; } + // Allocate buffers + size_t softmax_lse_accum_bytes = 0; + size_t out_accum_bytes = 0; + if (use_flash_attention) { + using namespace std; + auto [num_splits, slse_accum_bytes, o_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes( + parameters.batch_size, parameters.sequence_length, parameters.kv_sequence_length, parameters.num_heads, + parameters.head_size, device_prop.multiProcessorCount); + parameters.num_splits = num_splits; + softmax_lse_accum_bytes = slse_accum_bytes; + out_accum_bytes = o_accum_bytes; + } + auto softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, context->GetComputeStream()); + auto out_accum_buffer = GetScratchBuffer(out_accum_bytes, context->GetComputeStream()); #else constexpr bool use_flash_attention = false; + auto softmax_lse_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr + auto out_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr #endif if (!use_flash_attention) { @@ -156,8 +172,10 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { if (use_causal_fused_runner) { // Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node. if (nullptr == fused_fp16_runner_.get()) { - fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional_, - enable_trt_flash_attention_, parameters.scale); + std::call_once(fused_fp16_runner_created_, [&]() { + fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional_, + enable_trt_flash_attention_, parameters.scale); + }); } // Here we assume all causal kernels can be loaded into shared memory. TODO: add a function to check. 
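The switch to std::call_once above addresses thread safety of the lazily created fused runner. A minimal, generic sketch of the pattern with placeholder types (not ORT code):

#include <memory>
#include <mutex>

struct FusedRunner { /* placeholder for the real runner type */ };

struct KernelLike {
  // mutable so lazy initialization can happen from a const compute method
  mutable std::unique_ptr<FusedRunner> runner_;
  mutable std::once_flag runner_created_;

  FusedRunner* GetRunner() const {
    // Exactly one thread constructs the runner, even if several threads race
    // past a plain "if (runner_ == nullptr)" check at the same time.
    std::call_once(runner_created_, [&]() { runner_ = std::make_unique<FusedRunner>(); });
    return runner_.get();
  }
};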
@@ -177,8 +195,10 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { if (use_fused_runner) { // Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node. if (nullptr == fused_fp16_runner_.get()) { - fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional_, - enable_trt_flash_attention_, parameters.scale); + std::call_once(fused_fp16_runner_created_, [&]() { + fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional_, + enable_trt_flash_attention_, parameters.scale); + }); } // In case some kernel not loaded due to shared memory limit, we need to double check here. @@ -215,11 +235,12 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { typedef typename ToCudaType::MappedType CudaT; - IAllocatorUniquePtr gemm_buffer; + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); int m = batch_size * sequence_length; int n = (parameters.hidden_size + parameters.hidden_size + parameters.v_hidden_size); int k = parameters.input_hidden_size; - gemm_buffer = GetScratchBuffer(static_cast(m) * n, context->GetComputeStream()); + IAllocatorUniquePtr gemm_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(m * n) * sizeof(T), false, context->GetComputeStream()); CudaT one = ToCudaType::FromFloat(1.0f); CudaT zero = ToCudaType::FromFloat(0.0f); @@ -246,7 +267,8 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { use_flash_attention, use_fused_cross_attention, use_memory_efficient_attention); - auto work_space = GetScratchBuffer(workSpaceSize, context->GetComputeStream()); + IAllocatorUniquePtr work_space = IAllocator::MakeUniquePtr(allocator, workSpaceSize, false, context->GetComputeStream()); + ; typedef typename ToCudaType::MappedType CudaT; AttentionData data; @@ -273,6 +295,12 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { data.fused_runner = reinterpret_cast(fused_runner); data.use_flash_attention = use_flash_attention; data.use_memory_efficient_attention = use_memory_efficient_attention; + if (softmax_lse_accum_buffer != nullptr) { + data.softmax_lse_accum = reinterpret_cast(softmax_lse_accum_buffer.get()); + } + if (out_accum_buffer != nullptr) { + data.out_accum = reinterpret_cast(out_accum_buffer.get()); + } return QkvToContext(device_prop, cublas, context->GetComputeStream(), parameters, data); } diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.h b/onnxruntime/contrib_ops/cuda/bert/attention.h index 455e55ba05a66..acafb379d713f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention.h @@ -4,6 +4,7 @@ #pragma once #include +#include #include "core/providers/cuda/cuda_kernel.h" #include "contrib_ops/cpu/bert/attention_base.h" #include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h" @@ -28,6 +29,7 @@ class Attention final : public CudaKernel, public AttentionBase { bool disable_memory_efficient_attention_; int min_seq_len_for_flash_attention_packed_qkv_; mutable std::unique_ptr fused_fp16_runner_; + mutable std::once_flag fused_fp16_runner_created_; }; } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index b4a4ae208ceb1..16ce3a899fb5e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -316,7 
+316,9 @@ Status FlashAttention( ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd( device_prop, stream, query, key, value, data.output, reinterpret_cast(data.scratch), parameters.batch_size, parameters.num_heads, parameters.num_heads, parameters.head_size, - parameters.sequence_length, parameters.total_sequence_length, scale, parameters.is_unidirectional)); + parameters.sequence_length, parameters.total_sequence_length, scale, parameters.is_unidirectional, + parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), reinterpret_cast(data.out_accum), + true)); DUMP_TENSOR("flash attention output", data.output, parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.v_head_size); @@ -372,6 +374,7 @@ Status EfficientAttention( p.num_heads = parameters.num_heads; p.sequence_length = parameters.sequence_length; p.kv_sequence_length = parameters.total_sequence_length; + p.max_sequence_length = parameters.total_sequence_length; p.qk_head_size = parameters.head_size; p.v_head_size = parameters.v_head_size; p.causal = parameters.is_unidirectional; @@ -393,6 +396,7 @@ Status EfficientAttention( p.attn_bias = nullptr == data.relative_position_bias ? nullptr : data.relative_position_bias; p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias; p.output = data.output; + p.is_kv_bsnh = true; p.workspace = MemoryEfficientAttentionParams::need_workspace(parameters.v_head_size, sizeof(T) == sizeof(float)) ? data.scratch : nullptr; diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h index d0a5fb51a25d6..3e78978c3cc43 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h @@ -88,6 +88,11 @@ struct AttentionData { T* v = nullptr; T* scratch = nullptr; AttentionQkvFormat qkv_format = AttentionQkvFormat::Q_K_V_BSNH; + + // Flash buffers + T* softmax_lse = nullptr; + T* softmax_lse_accum = nullptr; + T* out_accum = nullptr; }; template diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h index ed330b0fca332..51c3d3d3a458b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h @@ -51,25 +51,45 @@ void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) { p.num_keys = params.kv_sequence_length; if (params.causal) { - p.custom_mask_type = Attention::CausalFromTopLeft; + p.custom_mask_type = Attention::CausalFromBottomRight; } - // Input format is BxSxNxH, output is BxSxNxH - p.q_strideH = params.qk_head_size; - p.k_strideH = params.qk_head_size; - p.v_strideH = params.v_head_size; - p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys; - - p.q_strideM = params.num_heads * params.qk_head_size; - p.k_strideM = params.num_heads * params.qk_head_size; - p.v_strideM = params.num_heads * params.v_head_size; - p.o_strideM = params.num_heads * params.v_head_size; - p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys; - - p.q_strideB = static_cast(p.q_strideM) * params.sequence_length; - p.k_strideB = static_cast(p.k_strideM) * params.kv_sequence_length; - p.v_strideB = static_cast(p.v_strideM) * params.kv_sequence_length; - p.bias_strideB = params.is_attn_bias_batched ? 
static_cast(p.bias_strideH) * params.num_heads : 0; + // We use max_sequence_length to calculate KV stride + if (params.is_kv_bsnh) { + // Input Q, K, V format is BxSxNxH, output is BxSxNxH + p.q_strideH = params.qk_head_size; + p.k_strideH = params.qk_head_size; + p.v_strideH = params.v_head_size; + p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys; + + p.q_strideM = params.num_heads * params.qk_head_size; + p.k_strideM = params.num_heads * params.qk_head_size; + p.v_strideM = params.num_heads * params.v_head_size; + p.o_strideM = params.num_heads * params.v_head_size; + p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys; + + p.q_strideB = static_cast(p.q_strideM) * params.sequence_length; + p.k_strideB = static_cast(p.k_strideM) * params.max_sequence_length; + p.v_strideB = static_cast(p.v_strideM) * params.max_sequence_length; + p.bias_strideB = params.is_attn_bias_batched ? static_cast(p.bias_strideH) * params.num_heads : 0; + } else { + // Input K, V format is BxNxSxH, Input Q is BxSxNxH, output is BxSxNxH + p.q_strideH = params.qk_head_size; + p.k_strideH = params.max_sequence_length * params.qk_head_size; + p.v_strideH = params.max_sequence_length * params.v_head_size; + p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys; + + p.q_strideM = params.num_heads * params.qk_head_size; + p.k_strideM = params.qk_head_size; + p.v_strideM = params.v_head_size; + p.o_strideM = params.num_heads * params.v_head_size; + p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys; + + p.q_strideB = params.num_heads * params.qk_head_size * params.sequence_length; + p.k_strideB = params.num_heads * params.qk_head_size * params.max_sequence_length; + p.v_strideB = params.num_heads * params.v_head_size * params.max_sequence_length; + p.bias_strideB = params.is_attn_bias_batched ? 
static_cast(p.bias_strideH) * params.num_heads : 0; + } } constexpr auto kernel_fn = attention_kernel_batched_impl; diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h index f725be8d7cf89..f16567bb6f2b7 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h @@ -14,10 +14,12 @@ namespace cuda { struct MemoryEfficientAttentionParams { int32_t sm; bool is_half; + bool is_kv_bsnh = true; int32_t batch_size; int32_t num_heads; int32_t sequence_length; int32_t kv_sequence_length; + int32_t max_sequence_length; int32_t qk_head_size; int32_t v_head_size; bool causal; diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc index 4bdc6db30b036..54aad9cbaf387 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc @@ -22,6 +22,7 @@ static constexpr int kBeamWidthInputIndex = 8; static constexpr int kCacheIndirectionInputIndex = 9; static constexpr int kPastInputIndex = 5; static constexpr int kPresentOutputIndex = 1; +static constexpr int kQKOutputIndex = 3; static constexpr int kBiasIndex = 10; #define REGISTER_KERNEL_TYPED(T1, T2) \ @@ -50,6 +51,7 @@ DecoderMaskedMultiHeadAttention::DecoderMaskedMultiHeadAttention(const O mask_filter_value_ = info.GetAttrOrDefault("mask_filter_value", -10000.0f); scale_ = info.GetAttrOrDefault("scale", 0.0f); past_present_share_buffer_ = info.GetAttrOrDefault("past_present_share_buffer", 0LL); + output_qk_ = info.GetAttrOrDefault("output_qk", 0LL); } template @@ -98,7 +100,7 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext* // This kernel is for decoding only (i.e.) sequence length has to be 1 if (sequence_length != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input sequence length should be 1 to use DecoderMaskedMultiHeadAttention"); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input sequence length should be 1 to use DecoderMaskedMultiHeadAttention. 
Actual length is ", sequence_length); } if (parameters.head_size != parameters.v_head_size) { @@ -125,6 +127,7 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext* TensorShape present_shape(present_dims); Tensor* present_key = context->Output(kPresentOutputIndex, present_shape); Tensor* present_value = context->Output(kPresentOutputIndex + 1, present_shape); + Tensor* cross_qk = nullptr; auto cuda_stream = Stream(context); @@ -191,6 +194,13 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext* parameters.v_cache = present_value_data; } + if (output_qk_) { + int64_t qk_dims[] = {parameters.batch_size, parameters.num_heads, 1, parameters.total_sequence_length}; + TensorShape qk_shape(&qk_dims[0], sizeof(qk_dims) / sizeof(qk_dims[0])); + cross_qk = context->Output(kQKOutputIndex, qk_shape); + parameters.out_qk = cross_qk->MutableData(); + } + parameters.out = output->MutableDataRaw(); // Scale diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.h b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.h index 8200a66db383f..b5476e6b54c44 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.h @@ -22,6 +22,7 @@ class DecoderMaskedMultiHeadAttention final : public CudaKernel { float mask_filter_value_; float scale_; bool past_present_share_buffer_; + bool output_qk_; }; } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu index c8877a5e3f872..33e7a33494778 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu @@ -427,6 +427,15 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio // Compute the logits and start the sum. float sum = 0.f; int sum_tlength = params.is_cross_attention ? 
tlength - 1 : tlength; + + if (params.out_qk != nullptr) { + // store cross qk before softmax, out_qk has shape [B(batchxbeam), #Head, 1, total_sequence_length] + float* target = ((float*)params.out_qk) + ((int64_t)bhi * tlength); + for (int ti = tidx; ti <= sum_tlength; ti += THREADS_PER_BLOCK) { + target[ti] = (float)(qk_smem[ti]); + } + } + for (int ti = tidx; ti <= sum_tlength; ti += THREADS_PER_BLOCK) { // This is a deviation from FasterTransformer kernel implementation // but this aligns with ORT's other Attention kernels which strives to diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h index 6d7f368db4dd4..4b408dafa2d81 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h @@ -37,6 +37,7 @@ struct DecoderMaskedMultiHeadAttentionParams : AttentionParameters { void* v_cache = nullptr; void* out = nullptr; + void* out_qk = nullptr; const int32_t* cache_indir = nullptr; const int32_t* mask = nullptr; // [B, total_sequence_length] diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h index 0aaf5e5f1ba28..89e2351428d40 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h @@ -18,81 +18,89 @@ constexpr int D_DIM = 2; struct Qkv_params { using index_t = uint32_t; // The QKV matrices. - void* __restrict__ q_ptr; - void* __restrict__ k_ptr; - void* __restrict__ v_ptr; + void* __restrict__ q_ptr = nullptr; + void* __restrict__ k_ptr = nullptr; + void* __restrict__ v_ptr = nullptr; // The stride between rows of the Q, K and V matrices. - index_t q_batch_stride; - index_t k_batch_stride; - index_t v_batch_stride; - index_t q_row_stride; - index_t k_row_stride; - index_t v_row_stride; - index_t q_head_stride; - index_t k_head_stride; - index_t v_head_stride; + index_t q_batch_stride = 0; + index_t k_batch_stride = 0; + index_t v_batch_stride = 0; + index_t q_row_stride = 0; + index_t k_row_stride = 0; + index_t v_row_stride = 0; + index_t q_head_stride = 0; + index_t k_head_stride = 0; + index_t v_head_stride = 0; // The number of heads. - int h, h_k; + int h = 0; + int h_k = 0; // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be // different from nheads (query). - int h_h_k_ratio; // precompute h / h_k, + int h_h_k_ratio = 0; // precompute h / h_k, }; //////////////////////////////////////////////////////////////////////////////////////////////////// struct Flash_fwd_params : public Qkv_params { // The O matrix (output). - void* __restrict__ o_ptr; - void* __restrict__ oaccum_ptr; + void* __restrict__ o_ptr = nullptr; + void* __restrict__ oaccum_ptr = nullptr; // The stride between rows of O. - index_t o_batch_stride; - index_t o_row_stride; - index_t o_head_stride; + index_t o_batch_stride = 0; + index_t o_row_stride = 0; + index_t o_head_stride = 0; // The pointer to the P matrix. - void* __restrict__ p_ptr; + void* __restrict__ p_ptr = nullptr; // The pointer to the softmax sum. 
- void* __restrict__ softmax_lse_ptr; - void* __restrict__ softmax_lseaccum_ptr; + void* __restrict__ softmax_lse_ptr = nullptr; + void* __restrict__ softmax_lseaccum_ptr = nullptr; // The dimensions. - int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded; + int b = 0; + int seqlen_q = 0; + int seqlen_k = 0; + int seqlen_knew = 0; + int d = 0; + int seqlen_q_rounded = 0; + int seqlen_k_rounded = 0; + int d_rounded = 0; // The scaling factors for the kernel. - float scale_softmax; - float scale_softmax_log2; + float scale_softmax = 0.0; + float scale_softmax_log2 = 0.0; // array of length b+1 holding starting offset of each sequence. - int* __restrict__ cu_seqlens_q; - int* __restrict__ cu_seqlens_k; + int* __restrict__ cu_seqlens_q = nullptr; + int* __restrict__ cu_seqlens_k = nullptr; - int* __restrict__ blockmask; + int* __restrict__ blockmask = nullptr; // The K_new and V_new matrices. - void* __restrict__ knew_ptr; - void* __restrict__ vnew_ptr; + void* __restrict__ knew_ptr = nullptr; + void* __restrict__ vnew_ptr = nullptr; // The stride between rows of the Q, K and V matrices. - index_t knew_batch_stride; - index_t vnew_batch_stride; - index_t knew_row_stride; - index_t vnew_row_stride; - index_t knew_head_stride; - index_t vnew_head_stride; + index_t knew_batch_stride = 0; + index_t vnew_batch_stride = 0; + index_t knew_row_stride = 0; + index_t vnew_row_stride = 0; + index_t knew_head_stride = 0; + index_t vnew_head_stride = 0; bool is_bf16 = false; - bool is_causal; + bool is_causal = false; // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. - bool is_seqlens_k_cumulative; - int num_splits; // For split-KV version + bool is_seqlens_k_cumulative = true; + int num_splits = 0; // For split-KV version - const cudaDeviceProp* dprops; + const cudaDeviceProp* dprops = nullptr; }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc index 805a73be96778..89a27c4d2b0d3 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc @@ -140,11 +140,10 @@ void run_mha_fwd(Flash_fwd_params& params, cudaStream_t stream, bool force_split // So we find the best efficiency, then find the smallest number of splits that gets 85% // of the best efficiency. int num_splits_heuristic(int batch_size, int seqlen_q, int seqlen_k, int num_heads, int head_size, int num_SMs, - int max_splits, bool new_kv, bool is_sm8x) { + int max_splits) { // This needs to match with run_mha_fwd_splitkv_dispatch - const int block_n = is_sm8x ? (head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64)) - : (head_size <= 64 ? 256 : (head_size <= 160 ? 128 : 64)); - const int num_n_blocks = (seqlen_k + (!new_kv ? 0 : seqlen_q) + block_n - 1) / block_n; + const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64); + const int num_n_blocks = (seqlen_k + block_n - 1) / block_n; // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel. // In any case we don't expect seqlen_q to be larger than 64 for inference. 
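A worked instance of the block-count arithmetic above, using illustrative sizes (one decoded query token against a 2048-token KV cache with head size 128); the real values come from the attention parameters.

constexpr int head_size = 128;
constexpr int seqlen_q = 1;     // decoding: a single new query token
constexpr int seqlen_k = 2048;  // total key/value length
constexpr int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);  // 128
constexpr int num_n_blocks = (seqlen_k + block_n - 1) / block_n;                // 16
constexpr int num_m_blocks = (seqlen_q + 64 - 1) / 64;                          // 1
static_assert(block_n == 128 && num_n_blocks == 16 && num_m_blocks == 1, "sanity check");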
const int num_m_blocks = (seqlen_q + 64 - 1) / 64; @@ -190,6 +189,26 @@ int num_splits_heuristic(int batch_size, int seqlen_q, int seqlen_k, int num_hea return 1; } +// Returns (num_splits, softmax_lse_accum bytes, out_accum bytes) +std::tuple get_num_splits_and_buffer_sizes(int batch_size, int seqlen_q, int seqlen_k, int num_heads, + int head_size, int num_SMs) { + int max_splits = 128; + // split kv buffers + int num_splits = num_splits_heuristic(batch_size, seqlen_q, seqlen_k, num_heads, head_size, + num_SMs, max_splits); + if (num_splits > 1) { + // softmax_lse_accum buffer + int softmax_lse_accum_bytes = get_softmax_lse_accum_size(num_splits, batch_size, num_heads, seqlen_q); + // out_accum buffer + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = round_multiple(head_size, 32); + int out_accum_bytes = get_out_accum_size(num_splits, batch_size, num_heads, seqlen_q, head_size_rounded); + return {num_splits, softmax_lse_accum_bytes, out_accum_bytes}; + } else { + return {0, 0, 0}; + } +} + Status mha_fwd(const cudaDeviceProp& dprops, cudaStream_t stream, void* q, // batch_size x seqlen_q x num_heads x head_size @@ -215,7 +234,6 @@ Status mha_fwd(const cudaDeviceProp& dprops, const int seqlen_k_rounded = round_multiple(seqlen_k, 128); Flash_fwd_params params; - params.dprops = &dprops; set_params_fprop(params, batch_size, seqlen_q, seqlen_k, @@ -230,7 +248,7 @@ Status mha_fwd(const cudaDeviceProp& dprops, softmax_scale, is_causal, kv_bsnh); - + params.dprops = &dprops; params.knew_ptr = nullptr; params.vnew_ptr = nullptr; params.knew_batch_stride = 0; @@ -276,7 +294,6 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops, const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128); Flash_fwd_params params; - params.dprops = &dprops; set_params_fprop(params, batch_size, max_seqlen_q, max_seqlen_k, @@ -290,6 +307,12 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops, softmax_lse, softmax_scale, is_causal); + params.dprops = &dprops; + params.num_splits = 0; + params.softmax_lseaccum_ptr = nullptr; + params.oaccum_ptr = nullptr; + params.knew_ptr = nullptr; + params.vnew_ptr = nullptr; run_mha_fwd(params, stream); return Status::OK(); } @@ -336,7 +359,6 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, const int seqlen_k_rounded = round_multiple(seqlen_k, 128); Flash_fwd_params params; - params.dprops = &dprops; set_params_fprop(params, batch_size, seqlen_q, seqlen_k, @@ -351,6 +373,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, softmax_scale, is_causal, past_bsnh); + params.dprops = &dprops; if (k != nullptr && v != nullptr) { params.seqlen_knew = seqlen_k_new; diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h index 0a0328edb0059..58f4304251872 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h @@ -31,6 +31,7 @@ #if USE_FLASH_ATTENTION #include "core/providers/cuda/cuda_common.h" +#include namespace onnxruntime { namespace flash { @@ -99,10 +100,9 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, ); size_t get_softmax_lse_size(int max_seqlen_q, int batch_size, int num_heads); -size_t get_softmax_lse_accum_size(int num_splits, int batch_size, int num_heads, int seqlen_q); -size_t get_out_accum_size(int num_splits, int batch_size, int num_heads, int seqlen_q, int head_size_rounded); -int num_splits_heuristic(int batch_size, int 
seqlen_q, int seqlen_k, int num_heads, int head_size, int num_SMs, int max_splits, bool new_kv, bool is_sm8x); +std::tuple get_num_splits_and_buffer_sizes(int batch_size, int seqlen_q, int seqlen_k, int num_heads, + int head_size, int num_SMs); bool is_supported(const cudaDeviceProp& dprops, int head_size, int num_heads, int num_heads_k); diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h index e0be6b828f85d..82dfa59b8f8e7 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h @@ -12,18 +12,30 @@ namespace flash { template __global__ void flash_fwd_kernel(Flash_fwd_params params) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 flash::compute_attn(params); +#else + (void)params; +#endif } template __global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 flash::compute_attn_splitkv(params); +#else + (void)params; +#endif } template __global__ void flash_fwd_splitkv_combine_kernel(Flash_fwd_params params) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 static_assert(Log_max_splits >= 1); flash::combine_attn_seqk_parallel(params); +#else + (void)params; +#endif } template @@ -111,17 +123,9 @@ void run_flash_splitkv_fwd(Flash_fwd_params& params, cudaStream_t stream) { template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream) { - bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0; constexpr int kBlockM = 64; // Fixed for all head dimensions - if (!is_sm8x) { // A100, H100 - // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256, - // and for headdim 192 with block size 64 x 128. - constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 160 ? 128 : 64); - run_flash_splitkv_fwd>(params, stream); - } else { // Only 99KB of smem, so we have to set kBlockN smaller for Headdim 160 and above - constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64); - run_flash_splitkv_fwd>(params, stream); - } + constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 
128 : 64); + run_flash_splitkv_fwd>(params, stream); } template diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index 65d19d4473872..8694dc998c7a8 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -6,9 +6,8 @@ #include "contrib_ops/cuda/bert/group_query_attention_impl.h" #include "contrib_ops/cuda/bert/group_query_attention.h" #include "contrib_ops/cuda/bert/group_query_attention_helper.h" +#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" #include "contrib_ops/cuda/bert/flash_attention/flash_api.h" -// #include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" -// #include "contrib_ops/cpu/utils/console_dumper.h" using namespace onnxruntime::cuda; using namespace ::onnxruntime::common; @@ -55,6 +54,13 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) #else disable_flash_attention_ = true; #endif + +#if USE_MEMORY_EFFICIENT_ATTENTION + disable_memory_efficient_attention_ = sizeof(T) != 2 || + ParseEnvironmentVariableWithDefault(attention::kDisableMemoryEfficientAttention, false); +#else + disable_memory_efficient_attention_ = true; +#endif } template @@ -92,18 +98,6 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { output_shape[2] = static_cast(parameters.hidden_size); Tensor* output = context->Output(0, output_shape); - std::vector present_dims; - if (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BSNH) { - present_dims = { - parameters.batch_size, parameters.present_sequence_length, parameters.kv_num_heads, parameters.head_size}; - } else { // BNSH - present_dims = { - parameters.batch_size, parameters.kv_num_heads, parameters.present_sequence_length, parameters.head_size}; - } - TensorShape present_shape(present_dims); - Tensor* present_key = context->Output(1, present_shape); - Tensor* present_value = context->Output(2, present_shape); - #if USE_FLASH_ATTENTION bool use_flash_attention = !disable_flash_attention_ && onnxruntime::flash::is_supported(device_prop, @@ -116,22 +110,16 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { size_t out_accum_bytes = 0; size_t seqlens_k_bytes = 0; if (use_flash_attention) { + // softmax buffer softmax_lse_bytes = onnxruntime::flash::get_softmax_lse_size(parameters.sequence_length, parameters.batch_size, parameters.num_heads); - // split kv buffers - parameters.num_splits = onnxruntime::flash::num_splits_heuristic( + // split kv buffer + using namespace std; + auto [num_splits, slse_accum_bytes, o_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes( parameters.batch_size, parameters.sequence_length, parameters.kv_sequence_length, parameters.num_heads, - parameters.head_size, device_prop.multiProcessorCount, 128, false, - device_prop.major == 8 && device_prop.minor > 0); - if (parameters.num_splits > 1) { - // softmax_lse_accum buffer - softmax_lse_accum_bytes = onnxruntime::flash::get_softmax_lse_accum_size( - parameters.num_splits, parameters.batch_size, parameters.num_heads, parameters.sequence_length); - // out_accum buffer - auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; - const int head_size_rounded = round_multiple(parameters.head_size, 32); - out_accum_bytes = onnxruntime::flash::get_out_accum_size( - parameters.num_splits, parameters.batch_size, parameters.num_heads, parameters.sequence_length, head_size_rounded); - } 
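Regarding the __CUDA_ARCH__ >= 800 guards added to the flash forward kernels in flash_fwd_launch_template.h above: a minimal sketch of the pattern, with a placeholder kernel rather than the real flash attention body.

__global__ void guarded_kernel(float* out) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  // Real work is compiled only for the SM80+ device pass (Ampere and newer).
  out[threadIdx.x] = 1.0f;
#else
  // Older architecture passes get an empty body; referencing the parameter
  // avoids unused-parameter warnings while keeping the build green.
  (void)out;
#endif
}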
+ parameters.head_size, device_prop.multiProcessorCount); + parameters.num_splits = num_splits; + softmax_lse_accum_bytes = slse_accum_bytes; + out_accum_bytes = o_accum_bytes; // seqlens_k buffer if (past_key != nullptr) { seqlens_k_bytes = sizeof(int) * parameters.batch_size; @@ -149,8 +137,47 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { auto seqlens_k_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr #endif - // only kernel implemented for gqa right now - ORT_ENFORCE(use_flash_attention); +#if USE_MEMORY_EFFICIENT_ATTENTION + int sm = (device_prop.major * 10) + device_prop.minor; + bool use_memory_efficient_attention = + !use_flash_attention && + !disable_memory_efficient_attention_ && + (parameters.head_size & 7) == 0 && + parameters.sequence_length <= parameters.past_sequence_length + parameters.kv_sequence_length && + (sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) && + has_memory_efficient_attention(sm, sizeof(T) == 2); + // allocate buffers + size_t kv_buffer_bytes = 0; + // need a buffer if we must ungroup kv + const bool needs_buff = (parameters.num_heads != parameters.kv_num_heads); + if (use_memory_efficient_attention && needs_buff) { + kv_buffer_bytes = (sizeof(T) * parameters.batch_size * parameters.num_heads * (parameters.past_sequence_length + parameters.kv_sequence_length) * parameters.head_size); + } + size_t fmha_buffer_bytes = 0; + if (use_memory_efficient_attention && MemoryEfficientAttentionParams::need_workspace(parameters.head_size, sizeof(T) == sizeof(float))) { + fmha_buffer_bytes = (parameters.batch_size * parameters.sequence_length * parameters.num_heads * parameters.head_size * sizeof(float)); + } + auto k_buffer = GetScratchBuffer(kv_buffer_bytes, context->GetComputeStream()); + auto v_buffer = GetScratchBuffer(kv_buffer_bytes, context->GetComputeStream()); + auto fmha_buffer = GetScratchBuffer(fmha_buffer_bytes, context->GetComputeStream()); +#else + constexpr bool use_memory_efficient_attention = false; + auto k_buffer = GetScratchBuffer(0, context->GetComputeStream()); + auto v_buffer = GetScratchBuffer(0, context->GetComputeStream()); + auto fmha_buffer = GetScratchBuffer(0, context->GetComputeStream()); +#endif + + std::vector present_dims; + if (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BSNH) { + present_dims = { + parameters.batch_size, parameters.present_sequence_length, parameters.kv_num_heads, parameters.head_size}; + } else { // BNSH + present_dims = { + parameters.batch_size, parameters.kv_num_heads, parameters.present_sequence_length, parameters.head_size}; + } + TensorShape present_shape(present_dims); + Tensor* present_key = context->Output(1, present_shape); + Tensor* present_value = context->Output(2, present_shape); data.query = reinterpret_cast(query->Data()); data.key = reinterpret_cast(key->Data()); @@ -161,6 +188,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { data.present_key = (nullptr == present_key) ? nullptr : reinterpret_cast(present_key->MutableData()); data.present_value = (nullptr == present_value) ? 
nullptr : reinterpret_cast(present_value->MutableData()); data.use_flash_attention = use_flash_attention; + data.use_memory_efficient_attention = use_memory_efficient_attention; if (softmax_lse_buffer != nullptr) { data.softmax_lse = reinterpret_cast(softmax_lse_buffer.get()); } @@ -173,6 +201,13 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { if (seqlens_k_buffer != nullptr) { data.seqlens_k = reinterpret_cast(seqlens_k_buffer.get()); } + if (k_buffer != nullptr) { + data.k = reinterpret_cast(k_buffer.get()); + data.v = reinterpret_cast(v_buffer.get()); + } + if (fmha_buffer != nullptr) { + data.fmha_buffer = reinterpret_cast(fmha_buffer.get()); + } cublasHandle_t cublas = GetCublasHandle(context); diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h index 72c9814fad670..a90418ec2243a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h @@ -27,6 +27,7 @@ class GroupQueryAttention final : public CudaKernel { bool is_past_bsnh_; float scale_; bool disable_flash_attention_; + bool disable_memory_efficient_attention_; }; } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h index be8f5ca0ae3e9..8c21de9ced058 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h @@ -29,13 +29,13 @@ Status CheckInputs(const Tensor* query, // query (Q) : (B, S, D) // key (K) : (B, S+, D_kv) // value (V) : (B, S+, D_kv) + ORT_UNUSED_PARAMETER(value); AttentionQkvFormat qkv_format = Q_K_V_BSNH; AttentionQkvFormat past_kv_format = Q_K_V_BSNH; const auto& query_dims = query->Shape().GetDims(); const auto& key_dims = key->Shape().GetDims(); - const auto& value_dims = value->Shape().GetDims(); if (query_dims.size() != 3) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ", @@ -47,10 +47,8 @@ Status CheckInputs(const Tensor* query, int q_hidden_size = static_cast(query_dims[2]); int head_size = static_cast(q_hidden_size) / num_heads; - int kv_sequence_length = sequence_length; - int kv_hidden_size = (key_dims.size() == 3) - ? static_cast(key_dims[2]) - : (kv_num_heads * static_cast(key_dims[3])); + int kv_sequence_length = static_cast(key_dims[1]); + int kv_hidden_size = static_cast(key_dims[2]); int max_sequence_length = 0; if (past_key != nullptr && past_value != nullptr) { @@ -134,63 +132,49 @@ Status CheckInputs(const Tensor* query, "Input 'past_key' and 'past_value' shall be both present or both absent"); } - if (key != nullptr) { - const auto& key_dims = key->Shape().GetDims(); - if (key_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ", - key_dims.size()); - } - if (query_dims[0] != key_dims[0]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'key' shall have same dim 0 (batch size)"); - } - - if (num_heads % kv_num_heads != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "num_heads must be a multiple of kv_num_heads. 
Got num_heads % kv_num_heads == ", - num_heads % kv_num_heads); - } - if (key_dims[2] != value_dims[2]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'key' and 'value' shall have same dim 2 (kv_hidden_size)"); - } - - qkv_format = Q_K_V_BSNH; - kv_sequence_length = static_cast(key_dims[1]); - } else { + if (key_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ", + key_dims.size()); + } + if (query_dims[0] != key_dims[0]) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Missing key tensor."); + "Input 'query' and 'key' shall have same dim 0 (batch size)"); } - if (value != nullptr) { - const auto& value_dims = value->Shape().GetDims(); - if (value_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ", - value_dims.size()); - } + if (num_heads % kv_num_heads != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "num_heads must be a multiple of kv_num_heads. Got num_heads % kv_num_heads == ", + num_heads % kv_num_heads); + } - if (query_dims[0] != value_dims[0]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'value' shall have same dim 0 (batch_size)"); - } + const auto& value_dims = value->Shape().GetDims(); + if (value_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ", + value_dims.size()); + } - if (static_cast(kv_sequence_length) != value_dims[1]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'key' and 'value' shall have the same dim 1 (kv_sequence_length)"); - } + if (query_dims[0] != value_dims[0]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'value' shall have same dim 0 (batch_size)"); + } - if (value_dims[2] != kv_hidden_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key."); - } - } else { + if (static_cast(kv_sequence_length) != value_dims[1]) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Missing value tensor."); + "Input 'key' and 'value' shall have the same dim 1 (kv_sequence_length)"); + } + + if (value_dims[2] != kv_hidden_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key."); } // When kv-cache, we take past_seq_len as an argument... otherwise we use sequence length of past kv directly. int32_t past_sequence_length = 0; - int present_sequence_length = 0; + int present_sequence_length = kv_sequence_length; if (past_seq_len != nullptr) { + if (past_key == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Past KV must be present as share-buffer when using past_seq_len pointer."); + } if (!onnxruntime::IsScalarOr1ElementVector(past_seq_len)) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "past_sequence_length tensor must be of one element when using past kv."); @@ -200,6 +184,10 @@ Status CheckInputs(const Tensor* query, } else { past_sequence_length = static_cast(*((*past_seq_len).template Data())); } + if (past_sequence_length + kv_sequence_length > max_sequence_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "KV buffer too small... 
shall be that max_sequence_length >= past_sequence_length + kv_sequence_length"); + } present_sequence_length = max_sequence_length; } else if (past_key != nullptr) { past_sequence_length = max_sequence_length; // this is the length of past_key tensor diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index ab3029ca34886..0455825c364a2 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -37,6 +37,7 @@ limitations under the License. #include "contrib_ops/cpu/bert/attention_base.h" #include "contrib_ops/cuda/bert/bert_padding.h" #include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" +#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" #include "contrib_ops/cuda/bert/flash_attention/flash_api.h" #include "contrib_ops/cuda/bert/group_query_attention_impl.h" #include "contrib_ops/cuda/bert/attention_impl.h" @@ -47,6 +48,8 @@ namespace onnxruntime { namespace contrib { namespace cuda { +////////// Auxiliary Kernels for KV prep + // Kernel for seqlens_k __global__ void repeat_seqlen(int32_t* seqlens_k, int32_t seqlen, int batch_size) { int id = blockDim.x * blockIdx.x + threadIdx.x; @@ -75,7 +78,7 @@ __global__ void ConcatNewToPastKV(const int new_seqlen, const int present_head_stride = is_bsnh ? H : present_seqlen * H; // past_kv: BPNH or BNPH - // new_kv: BLNH or BNLH + // new_kv: BLNH // present_kv: BTNH or BNTH, where T = P + L const int past_seqlen = present_seqlen - new_seqlen; @@ -95,33 +98,32 @@ __global__ void ConcatNewToPastKV(const int new_seqlen, } } +// Use when (H*)*num_heads > 1024 template __global__ void ConcatNewToPastKVLarge(const int new_seqlen, const int H, + const int num_heads, const T* past_kv, const T* new_kv, T* present_kv, const bool is_bsnh) { - // Use when (H*)*num_heads > 1024 - int h = threadIdx.x; - const int n = threadIdx.y; - const int s = blockIdx.x; - const int b = blockIdx.y; + int i = threadIdx.x + (blockDim.x * blockIdx.x); + if (i < H * num_heads) { + const int h = i % H; + const int n = i / H; + const int s = blockIdx.y; + const int b = blockIdx.z; + const int present_seqlen = gridDim.y; + + const int present_batch_stride = present_seqlen * num_heads * H; + const int row_stride = is_bsnh ? num_heads * H : H; + const int present_head_stride = is_bsnh ? H : present_seqlen * H; + + // past_kv: BPNH or BNPH + // new_kv: BLNH + // present_kv: BTNH or BNTH, where T = P + L + const int past_seqlen = present_seqlen - new_seqlen; - const int present_seqlen = gridDim.x; - const int num_heads = blockDim.y; - const int thread_stride = blockDim.x; - - const int present_batch_stride = present_seqlen * num_heads * H; - const int row_stride = is_bsnh ? num_heads * H : H; - const int present_head_stride = is_bsnh ? 
H : present_seqlen * H; - - // past_kv: BPNH or BNPH - // new_kv: BLNH or BNLH - // present_kv: BTNH or BNTH, where T = P + L - const int past_seqlen = present_seqlen - new_seqlen; - - while (h < H) { int out_offset = b * present_batch_stride + s * row_stride + n * present_head_stride + h; if (s < past_seqlen) { const int past_batch_stride = past_seqlen * num_heads * H; @@ -135,133 +137,477 @@ __global__ void ConcatNewToPastKVLarge(const int new_seqlen, const int in_offset = b * new_batch_stride + (s - past_seqlen) * new_row_stride + n * new_head_stride + h; present_kv[out_offset] = new_kv[in_offset]; } - h += thread_stride; } } +// Concat new to past in present. Supports past BSNH or past BNSH template -Status QkvToContext( +Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data, + cudaStream_t stream, + const int max_threads_per_block) { + const int batch_size = parameters.batch_size; + const int kv_sequence_length = parameters.kv_sequence_length; + const int present_sequence_length = parameters.present_sequence_length; + const int kv_num_heads = parameters.kv_num_heads; + const int head_size = parameters.head_size; + AttentionQkvFormat past_kv_format = parameters.past_kv_format; + + assert(past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || past_kv_format == AttentionQkvFormat::Q_K_V_BNSH); + const int H = head_size / 4; // divide by 4 so kernel can operate on 4 float16 elements at a time. + if (H * kv_num_heads <= max_threads_per_block) { + const dim3 grid(present_sequence_length, batch_size, 1); + const dim3 block(H, kv_num_heads, 1); + ConcatNewToPastKV<<>>(kv_sequence_length, + reinterpret_cast(data.past_key), + reinterpret_cast(data.key), + reinterpret_cast(data.present_key), + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + ConcatNewToPastKV<<>>(kv_sequence_length, + reinterpret_cast(data.past_value), + reinterpret_cast(data.value), + reinterpret_cast(data.present_value), + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + } else { + int steps = (H * kv_num_heads + 255) / 256; + const dim3 grid(steps, present_sequence_length, batch_size); + const dim3 block(256, 1, 1); + ConcatNewToPastKVLarge<<>>(kv_sequence_length, + H, + kv_num_heads, + reinterpret_cast(data.past_key), + reinterpret_cast(data.key), + reinterpret_cast(data.present_key), + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + ConcatNewToPastKVLarge<<>>(kv_sequence_length, + H, + kv_num_heads, + reinterpret_cast(data.past_value), + reinterpret_cast(data.value), + reinterpret_cast(data.present_value), + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + } + return CUDA_CALL(cudaGetLastError()); +} + +// Kernel to append new kv to kv buffer in place +template +__global__ void ConcatKVInPlace(const int past_seqlen, + const int present_seqlen, + T* kv_buff, + const T* new_kv, + const bool is_bsnh) { // refers to kv buff; otherwise bnsh + const int h = threadIdx.x; + const int n = threadIdx.y; + const int s = blockIdx.x; + const int b = blockIdx.y; + + const int new_seqlen = gridDim.x; + const int num_heads = blockDim.y; + const int H = blockDim.x; + + const int present_batch_stride = present_seqlen * num_heads * H; + const int present_row_stride = is_bsnh ? num_heads * H : H; + const int present_head_stride = is_bsnh ? 
H : present_seqlen * H; + + // kv_buff: BTNH or BNTH with buffered memory for new + // new_kv: BLNH + + int out_offset = b * present_batch_stride + (s + past_seqlen) * present_row_stride + n * present_head_stride + h; + // Note: new KV always BSNH + const int new_batch_stride = new_seqlen * num_heads * H; + const int new_row_stride = num_heads * H; + const int new_head_stride = H; + const int in_offset = b * new_batch_stride + s * new_row_stride + n * new_head_stride + h; + kv_buff[out_offset] = new_kv[in_offset]; +} + +template +__global__ void ConcatKVInPlaceLarge(const int past_seqlen, + const int present_seqlen, + const int H, + const int num_heads, + T* kv_buff, + const T* new_kv, + const bool is_bsnh) { // refers to kv buff; otherwise bnsh + int i = threadIdx.x + (blockDim.x * blockIdx.x); + if (i < H * num_heads) { + const int h = i % H; + const int n = i / H; + const int s = blockIdx.y; + const int b = blockIdx.z; + const int new_seqlen = gridDim.y; + + const int present_batch_stride = present_seqlen * num_heads * H; + const int present_row_stride = is_bsnh ? num_heads * H : H; + const int present_head_stride = is_bsnh ? H : present_seqlen * H; + + // kv_buff: BTNH or BNTH with buffered memory for new + // new_kv: BLNH + + int out_offset = b * present_batch_stride + (s + past_seqlen) * present_row_stride + n * present_head_stride + h; + // Note: new KV always BSNH + const int new_batch_stride = new_seqlen * num_heads * H; + const int new_row_stride = num_heads * H; + const int new_head_stride = H; + const int in_offset = b * new_batch_stride + s * new_row_stride + n * new_head_stride + h; + kv_buff[out_offset] = new_kv[in_offset]; + } +} + +// Concat new to kv buffer in place +template +Status LaunchConcatKVInPlace(contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data, + cudaStream_t stream, + const int max_threads_per_block) { + const int batch_size = parameters.batch_size; + const int kv_sequence_length = parameters.kv_sequence_length; + const int present_sequence_length = parameters.present_sequence_length; + const int past_sequence_length = parameters.past_sequence_length; + const int kv_num_heads = parameters.kv_num_heads; + const int head_size = parameters.head_size; + AttentionQkvFormat past_kv_format = parameters.past_kv_format; + assert(past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || past_kv_format == AttentionQkvFormat::Q_K_V_BNSH); + const int H = head_size / 4; + if (H * kv_num_heads <= max_threads_per_block) { + const dim3 grid(kv_sequence_length, batch_size, 1); + const dim3 block(H, kv_num_heads, 1); + ConcatKVInPlace<<>>(past_sequence_length, + present_sequence_length, + reinterpret_cast(data.present_key), + reinterpret_cast(data.key), + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + ConcatKVInPlace<<>>(past_sequence_length, + present_sequence_length, + reinterpret_cast(data.present_value), + reinterpret_cast(data.value), + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + } else { + int steps = int(ceil(float(H * kv_num_heads) / 256.0)); + const dim3 grid(steps, kv_sequence_length, batch_size); + const dim3 block(256, 1, 1); + ConcatKVInPlaceLarge<<>>(past_sequence_length, + present_sequence_length, + H, + kv_num_heads, + reinterpret_cast(data.present_key), + reinterpret_cast(data.key), + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + ConcatKVInPlaceLarge<<>>(past_sequence_length, + present_sequence_length, + H, + kv_num_heads, + reinterpret_cast(data.present_value), + reinterpret_cast(data.value), + past_kv_format 
== AttentionQkvFormat::Q_K_V_BSNH); + } + return CUDA_CALL(cudaGetLastError()); +} + +// Kernel for use with memory efficient kernel... kv_in is grouped and of bnsh or bsnh... kv_out is ungrouped and bsnh +template +__global__ void Ungroup(const T* kv_in, + T* kv_out, + const int in_seqlen, + const int kv_num_heads, + const bool is_bsnh) { + const int h = threadIdx.x; + const int out_n = threadIdx.y; + const int s = blockIdx.x; + const int b = blockIdx.y; + + const int out_seqlen = gridDim.x; + const int q_num_heads = blockDim.y; + const int H = blockDim.x; + + const int q_kv_head_ratio = q_num_heads / kv_num_heads; + const int out_batch_stride = out_seqlen * q_num_heads * H; + const int out_row_stride = is_bsnh ? q_num_heads * H : H; + const int out_head_stride = is_bsnh ? H : out_seqlen * H; + + const int in_batch_stride = in_seqlen * kv_num_heads * H; + const int in_row_stride = is_bsnh ? kv_num_heads * H : H; + const int in_head_stride = is_bsnh ? H : in_seqlen * H; + const int in_n = out_n / q_kv_head_ratio; + + const int out_offset = out_batch_stride * b + out_row_stride * s + out_head_stride * out_n + h; + const int in_offset = in_batch_stride * b + in_row_stride * s + in_head_stride * in_n + h; + kv_out[out_offset] = kv_in[in_offset]; +} + +template +__global__ void UngroupLarge(const T* kv_in, + T* kv_out, + const int H, + const int in_seqlen, + const int q_num_heads, + const int kv_num_heads, + const bool is_bsnh) { + int i = threadIdx.x + (blockDim.x * blockIdx.x); // index along H * q_num_heads elements + if (i < H * q_num_heads) { + const int out_seqlen = gridDim.y; + const int s = blockIdx.y; + const int b = blockIdx.z; + + const int q_kv_head_ratio = q_num_heads / kv_num_heads; + const int out_batch_stride = out_seqlen * q_num_heads * H; + const int out_row_stride = is_bsnh ? q_num_heads * H : H; + const int out_head_stride = is_bsnh ? H : out_seqlen * H; + + const int in_batch_stride = in_seqlen * kv_num_heads * H; + const int in_row_stride = is_bsnh ? kv_num_heads * H : H; + const int in_head_stride = is_bsnh ? H : in_seqlen * H; + + const int h = i % H; + const int out_n = i / H; + const int in_n = out_n / q_kv_head_ratio; + const int out_offset = out_batch_stride * b + out_row_stride * s + out_head_stride * out_n + h; + const int in_offset = in_batch_stride * b + in_row_stride * s + in_head_stride * in_n + h; + kv_out[out_offset] = kv_in[in_offset]; + } +} + +// Ungroup kv or present kv for use in Memory Efficient kernel. If present kv is not null and is BNSH, transposes it. 
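Editor's note (not part of the patch): the head "ungrouping" performed by the Ungroup/UngroupLarge kernels and the LaunchUngroup helper below amounts to replicating each KV head for every query head grouped onto it. A minimal host-side sketch of that index mapping, assuming BSNH layout and num_heads % kv_num_heads == 0 (all names here are illustrative, not from the patch):

// Hedged sketch of the grouped -> ungrouped KV head mapping: every query head
// out_n reads the KV head it is grouped with (in_n = out_n / ratio).
void UngroupKVSketch(const float* kv_in, float* kv_out,
                     int batch, int seqlen, int num_heads, int kv_num_heads, int head_size) {
  const int ratio = num_heads / kv_num_heads;  // query heads per KV head
  for (int b = 0; b < batch; ++b)
    for (int s = 0; s < seqlen; ++s)
      for (int out_n = 0; out_n < num_heads; ++out_n) {
        const int in_n = out_n / ratio;  // KV head backing this query head
        const float* src = kv_in + ((b * seqlen + s) * kv_num_heads + in_n) * head_size;
        float* dst = kv_out + ((b * seqlen + s) * num_heads + out_n) * head_size;
        for (int h = 0; h < head_size; ++h) dst[h] = src[h];
      }
}

The CUDA kernels that follow implement the same mapping, but parallelized over (h, n, s, b) and supporting both BSNH and BNSH strides.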
+Status LaunchUngroup(contrib::GroupQueryAttentionParameters& parameters, + float2* k_buff, float2* v_buff, + const float2* k_og, const float2* v_og, + const int buff_seqlen, const int og_seqlen, + const bool is_bsnh, + cudaStream_t stream, + const int max_threads_per_block) { + const int batch_size = parameters.batch_size; + const int num_heads = parameters.num_heads; + const int kv_num_heads = parameters.kv_num_heads; + const int head_size = parameters.head_size; + + const int H = head_size / 4; + if (H * num_heads <= max_threads_per_block) { + const dim3 grid(buff_seqlen, batch_size, 1); + const dim3 block(H, num_heads, 1); + Ungroup<<>>(k_og, + k_buff, + og_seqlen, + kv_num_heads, + is_bsnh); + Ungroup<<>>(v_og, + v_buff, + og_seqlen, + kv_num_heads, + is_bsnh); + } else { + int steps = int(ceil(float(H * num_heads) / 256.0)); + const dim3 grid(steps, buff_seqlen, batch_size); + const dim3 block(256, 1, 1); + UngroupLarge<<>>(k_og, + k_buff, + H, + og_seqlen, + num_heads, + kv_num_heads, + is_bsnh); + UngroupLarge<<>>(v_og, + v_buff, + H, + og_seqlen, + num_heads, + kv_num_heads, + is_bsnh); + } + return CUDA_CALL(cudaGetLastError()); +} + +////////// Launch Kernels + +#if USE_FLASH_ATTENTION +template +Status FlashAttention( const cudaDeviceProp& device_prop, - cublasHandle_t& cublas, - Stream* ort_stream, + cudaStream_t stream, contrib::GroupQueryAttentionParameters& parameters, - GroupQueryAttentionData& data) { - assert(data.use_flash_attention); + GroupQueryAttentionData& data, + float scale) { + const int max_threads_per_block = device_prop.maxThreadsPerBlock; + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int kv_sequence_length = parameters.kv_sequence_length; + const int present_sequence_length = parameters.present_sequence_length; + const int num_heads = parameters.num_heads; + const int kv_num_heads = parameters.kv_num_heads; + const int head_size = parameters.head_size; + AttentionQkvFormat past_kv_format = parameters.past_kv_format; -#if USE_FLASH_ATTENTION - auto stream = static_cast(ort_stream->GetHandle()); + void* query = reinterpret_cast(const_cast(data.query)); + void* key = reinterpret_cast(const_cast(data.key)); + void* value = reinterpret_cast(const_cast(data.value)); + + bool is_causal = parameters.is_unidirectional; + + if (data.past_key != nullptr && data.past_key == data.present_key) { + // Share buffer case + void* present_key = reinterpret_cast(const_cast(data.present_key)); + void* present_value = reinterpret_cast(const_cast(data.present_value)); + + // Launch kernel to copy seqlen + int thr_per_blk = 256; + int blk_in_grid = ceil(float(batch_size) / thr_per_blk); + repeat_seqlen<<>>(data.seqlens_k, parameters.past_sequence_length, batch_size); + + bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; + ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache( + device_prop, stream, query, present_key, present_value, key, value, data.output, reinterpret_cast(data.softmax_lse), + reinterpret_cast(data.seqlens_k), batch_size, num_heads, kv_num_heads, + head_size, sequence_length, present_sequence_length, kv_sequence_length, + scale, is_causal, past_bsnh, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), + reinterpret_cast(data.out_accum))); + + } else { + // Not share buffer or no past (prompt generation) + // Note that Flash Attention kv-caching operates in place on a buffer... 
therefore this path is inefficient + ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block)); + + void* present_key = reinterpret_cast(const_cast(data.present_key)); + void* present_value = reinterpret_cast(const_cast(data.present_value)); + + bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; + ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd( + device_prop, stream, query, present_key, present_value, data.output, reinterpret_cast(data.softmax_lse), + batch_size, num_heads, kv_num_heads, head_size, + sequence_length, present_sequence_length, scale, is_causal, parameters.num_splits, + reinterpret_cast(data.softmax_lse_accum), reinterpret_cast(data.out_accum), past_bsnh)); + } + + DUMP_TENSOR_INIT(); + DUMP_TENSOR("flash attention output", data.output, batch_size, sequence_length, num_heads, head_size); + + return Status::OK(); +} +#endif + +#if USE_MEMORY_EFFICIENT_ATTENTION +template +Status EfficientAttention( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data, + float scale) { const int max_threads_per_block = device_prop.maxThreadsPerBlock; const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; const int kv_sequence_length = parameters.kv_sequence_length; + const int past_sequence_length = parameters.past_sequence_length; const int present_sequence_length = parameters.present_sequence_length; const int num_heads = parameters.num_heads; const int kv_num_heads = parameters.kv_num_heads; const int head_size = parameters.head_size; AttentionQkvFormat past_kv_format = parameters.past_kv_format; - const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(head_size)) : parameters.scale; - if (data.use_flash_attention) { - assert(parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH); - assert(parameters.num_heads % parameters.kv_num_heads == 0); - - void* query = reinterpret_cast(const_cast(data.query)); - void* key = reinterpret_cast(const_cast(data.key)); - void* value = reinterpret_cast(const_cast(data.value)); - - bool is_causal = parameters.is_unidirectional; - - if (data.past_key == nullptr && data.present_key == nullptr) { - ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd( - device_prop, stream, query, key, value, data.output, reinterpret_cast(data.softmax_lse), - parameters.batch_size, parameters.num_heads, parameters.kv_num_heads, head_size, - parameters.sequence_length, parameters.kv_sequence_length, scale, is_causal, parameters.num_splits, - reinterpret_cast(data.softmax_lse_accum), reinterpret_cast(data.out_accum))); - - } else if (data.past_key == data.present_key) { - // Assume past and present kv share buffer.
- assert(past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || past_kv_format == AttentionQkvFormat::Q_K_V_BNSH); - assert(parameters.past_sequence_length >= 0); - assert(data.past_value != nullptr); - - void* present_key = reinterpret_cast(const_cast(data.present_key)); - void* present_value = reinterpret_cast(const_cast(data.present_value)); - - // Launch kernel to copy seqlen - int thr_per_blk = 256; - int blk_in_grid = ceil(float(batch_size) / thr_per_blk); - repeat_seqlen<<>>(data.seqlens_k, parameters.past_sequence_length, batch_size); - - DUMP_TENSOR_INIT(); - DUMP_TENSOR("seqlens_k", data.seqlens_k, 1, batch_size); - - bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; - ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache( - device_prop, stream, query, present_key, present_value, key, value, data.output, reinterpret_cast(data.softmax_lse), - reinterpret_cast(data.seqlens_k), batch_size, num_heads, kv_num_heads, - head_size, sequence_length, present_sequence_length, kv_sequence_length, - scale, is_causal, past_bsnh, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), - reinterpret_cast(data.out_accum))); - - } else if (data.present_key != nullptr && (data.past_key != nullptr || kv_sequence_length == present_sequence_length)) { - assert(past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || past_kv_format == AttentionQkvFormat::Q_K_V_BNSH); - // Note that Flash Attention kv-caching operates in place on a buffer... therefore this path is inneficient - if (head_size % 4 != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "requires head_size be divisible by 4"); - } - const int H = head_size / 4; - if (H * kv_num_heads <= max_threads_per_block) { - const dim3 grid(present_sequence_length, batch_size, 1); - const dim3 block(H, kv_num_heads, 1); - ConcatNewToPastKV<<>>(kv_sequence_length, - reinterpret_cast(data.past_key), - reinterpret_cast(data.key), - reinterpret_cast(data.present_key), - past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); - ConcatNewToPastKV<<>>(kv_sequence_length, - reinterpret_cast(data.past_value), - reinterpret_cast(data.value), - reinterpret_cast(data.present_value), - past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); - } else { - const dim3 grid(present_sequence_length, batch_size, 1); - const dim3 block(max_threads_per_block / kv_num_heads, kv_num_heads, 1); - ConcatNewToPastKVLarge<<>>(kv_sequence_length, - H, - reinterpret_cast(data.past_key), - reinterpret_cast(data.key), - reinterpret_cast(data.present_key), - past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); - ConcatNewToPastKVLarge<<>>(kv_sequence_length, - H, - reinterpret_cast(data.past_value), - reinterpret_cast(data.value), - reinterpret_cast(data.present_value), - past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); - } - - void* present_key = reinterpret_cast(const_cast(data.present_key)); - void* present_value = reinterpret_cast(const_cast(data.present_value)); - - // Launch kernel to copy seqlen - int thr_per_blk = 256; - int blk_in_grid = ceil(float(batch_size) / thr_per_blk); - repeat_seqlen<<>>(data.seqlens_k, parameters.past_sequence_length, batch_size); - - bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; - ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd( - device_prop, stream, query, present_key, present_value, data.output, reinterpret_cast(data.softmax_lse), - batch_size, num_heads, kv_num_heads, head_size, - sequence_length, present_sequence_length, scale, is_causal, parameters.num_splits, - reinterpret_cast(data.softmax_lse_accum), 
reinterpret_cast(data.out_accum), past_bsnh)); + const void* query = reinterpret_cast(data.query); + const void* key = reinterpret_cast(data.key); + const void* value = reinterpret_cast(data.value); + if (data.past_key != nullptr) { + // Past key case + // concatenate new kv to past kv + if (data.past_key == data.present_key) { + ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace(parameters, data, stream, max_threads_per_block)); + } else { + ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block)); } + const bool is_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; + if (num_heads == kv_num_heads) { + // Use present kv directly if not grouped + key = reinterpret_cast(data.present_key); + value = reinterpret_cast(data.present_value); + } else { + // Otherwise we use intermediate buffers to run memory efficient attention... best avoid this path + float2* k_buff = reinterpret_cast(data.k); + float2* v_buff = reinterpret_cast(data.v); + const float2* k_og = reinterpret_cast(data.present_key); + const float2* v_og = reinterpret_cast(data.present_value); + ORT_RETURN_IF_ERROR(LaunchUngroup(parameters, k_buff, v_buff, k_og, v_og, past_sequence_length + kv_sequence_length, + present_sequence_length, is_bsnh, stream, max_threads_per_block)); + key = reinterpret_cast(data.k); + value = reinterpret_cast(data.v); + } + } else if (num_heads == kv_num_heads) { + // no past or present and no need to ungroup... still copy kv into present buffer + ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block)); + key = reinterpret_cast(data.present_key); + value = reinterpret_cast(data.present_value); + } else { + // intermediate buffer so q and kv have same num heads... still copy kv into present buffer + ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block)); + float2* k_buff = reinterpret_cast(data.k); + float2* v_buff = reinterpret_cast(data.v); + const float2* k_og = reinterpret_cast(data.present_key); + const float2* v_og = reinterpret_cast(data.present_value); + ORT_RETURN_IF_ERROR(LaunchUngroup(parameters, k_buff, v_buff, k_og, v_og, kv_sequence_length, + kv_sequence_length, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH, stream, + max_threads_per_block)); + key = reinterpret_cast(data.k); + value = reinterpret_cast(data.v); + } + + MemoryEfficientAttentionParams p; + p.sm = device_prop.major * 10 + device_prop.minor; + p.is_half = sizeof(T) == 2; + p.batch_size = batch_size; + p.num_heads = num_heads; + p.sequence_length = sequence_length; + p.kv_sequence_length = past_sequence_length + kv_sequence_length; + p.max_sequence_length = (num_heads == kv_num_heads) ? present_sequence_length : past_sequence_length + kv_sequence_length; + p.qk_head_size = head_size; + p.v_head_size = head_size; + p.causal = parameters.is_unidirectional; + p.scale = scale; + p.seqlen_k_ptr = nullptr; + p.seqstart_q_ptr = nullptr; + p.seqstart_k_ptr = nullptr; + p.query = query; + p.key = key; + p.value = value; + p.attn_bias = nullptr; + p.is_attn_bias_batched = false; + p.is_kv_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; + p.output = data.output; + p.workspace = MemoryEfficientAttentionParams::need_workspace(p.v_head_size, sizeof(T) == sizeof(float)) + ? 
data.fmha_buffer + : nullptr; + p.stream = stream; + run_memory_efficient_attention(p); + + DUMP_TENSOR_INIT(); + DUMP_TENSOR("efficient attention output", data.output, batch_size, sequence_length, num_heads, head_size); + + return Status::OK(); +} +#endif + +////////// API Functions + +template +Status QkvToContext( + const cudaDeviceProp& device_prop, + cublasHandle_t& cublas, + Stream* ort_stream, + contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data) { + auto stream = static_cast(ort_stream->GetHandle()); + const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(parameters.head_size)) : parameters.scale; - DUMP_TENSOR_INIT(); - DUMP_TENSOR("flash attention output", data.output, batch_size, sequence_length, num_heads, head_size); +#if USE_FLASH_ATTENTION + if (data.use_flash_attention) { + return FlashAttention(device_prop, stream, parameters, data, scale); + } +#endif - return Status::OK(); +#if USE_MEMORY_EFFICIENT_ATTENTION + if (data.use_memory_efficient_attention) { + return EfficientAttention(device_prop, stream, parameters, data, scale); } #endif + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unfused Group Query Attention not implemented yet."); } diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h index 0bad9eeb61231..8412631078e6a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h @@ -14,19 +14,28 @@ namespace cuda { template struct GroupQueryAttentionData { + // Input Tensors const T* query = nullptr; const T* key = nullptr; const T* value = nullptr; const T* past_key = nullptr; const T* past_value = nullptr; + // Flash buffers T* softmax_lse = nullptr; T* softmax_lse_accum = nullptr; T* out_accum = nullptr; int* seqlens_k = nullptr; + // Memory Efficient buffers + T* fmha_buffer = nullptr; + T* k = nullptr; + T* v = nullptr; + // Output Tensors T* output = nullptr; T* present_key = nullptr; T* present_value = nullptr; + // Kernel Flags bool use_flash_attention = false; + bool use_memory_efficient_attention = false; }; template diff --git a/onnxruntime/contrib_ops/cuda/bert/layer_norm.cuh b/onnxruntime/contrib_ops/cuda/bert/layer_norm.cuh index 5c083d64ee542..ff3178b56c2a6 100644 --- a/onnxruntime/contrib_ops/cuda/bert/layer_norm.cuh +++ b/onnxruntime/contrib_ops/cuda/bert/layer_norm.cuh @@ -147,14 +147,16 @@ __device__ inline void LayerNormSmall(const T* input_v, const cub::KeyValuePair< __shared__ T rsigma; // 1 / std.dev. 
T beta_v[ILP], gamma_v[ILP], output_v[ILP]; - if (beta != nullptr) { - VecT* beta_val = reinterpret_cast(&beta_v); - *beta_val = *reinterpret_cast(&beta[threadIdx.x * ILP]); - } - VecT* gamma_val = reinterpret_cast(&gamma_v); - *gamma_val = *reinterpret_cast(&gamma[threadIdx.x * ILP]); + const bool is_valid = ILP * threadIdx.x < ld; + if (is_valid) { + if (beta != nullptr) { + VecT* beta_val = reinterpret_cast(&beta_v); + *beta_val = *reinterpret_cast(&beta[threadIdx.x * ILP]); + } - VecT* output_val = reinterpret_cast(&output_v); + VecT* gamma_val = reinterpret_cast(&gamma_v); + *gamma_val = *reinterpret_cast(&gamma[threadIdx.x * ILP]); + } KeyValuePairSum pair_sum; const cub::KeyValuePair sum_kv = BlockReduce(temp_storage).Reduce(thread_data, pair_sum); @@ -165,13 +167,15 @@ __device__ inline void LayerNormSmall(const T* input_v, const cub::KeyValuePair< } __syncthreads(); - if (ILP * threadIdx.x < ld) { + if (is_valid) { #pragma unroll for (int i = 0; i < ILP; i++) { output_v[i] = (beta != nullptr) ? gamma_v[i] * (input_v[i] - mu) * rsigma + beta_v[i] : gamma_v[i] * (input_v[i] - mu) * rsigma; } + + VecT* output_val = reinterpret_cast(&output_v); *(reinterpret_cast(&output[idx])) = *output_val; } } @@ -186,12 +190,15 @@ __device__ inline void SimplifiedLayerNormSmall(const T* input_v, const T& threa using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; __shared__ T rsigma; // 1 / std.dev. - T gamma_v[ILP], output_v[ILP]; - VecT* gamma_val = reinterpret_cast(&gamma_v); - *gamma_val = *reinterpret_cast(&gamma[threadIdx.x * ILP]); + const bool is_valid = ILP * threadIdx.x < ld; - VecT* output_val = reinterpret_cast(&output_v); + T gamma_v[ILP], output_v[ILP]; + + if (is_valid) { + VecT* gamma_val = reinterpret_cast(&gamma_v); + *gamma_val = *reinterpret_cast(&gamma[threadIdx.x * ILP]); + } const T sum = BlockReduce(temp_storage).Sum(thread_data); @@ -200,11 +207,13 @@ __device__ inline void SimplifiedLayerNormSmall(const T* input_v, const T& threa } __syncthreads(); - if (ILP * threadIdx.x < ld) { + if (is_valid) { #pragma unroll for (int i = 0; i < ILP; i++) { output_v[i] = gamma_v[i] * input_v[i] * rsigma; } + + VecT* output_val = reinterpret_cast(&output_v); *(reinterpret_cast(&output[idx])) = *output_val; } } diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index 25f3f59165e43..ebd66d8c6528e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -153,8 +153,24 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { parameters.sequence_length < min_seq_len_for_flash_attention_packed_qkv_) { use_flash_attention = false; } + // Allocate buffers + size_t softmax_lse_accum_bytes = 0; + size_t out_accum_bytes = 0; + if (use_flash_attention) { + using namespace std; + auto [num_splits, slse_accum_bytes, o_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes( + parameters.batch_size, parameters.sequence_length, parameters.kv_sequence_length, parameters.num_heads, + parameters.head_size, device_prop.multiProcessorCount); + parameters.num_splits = num_splits; + softmax_lse_accum_bytes = slse_accum_bytes; + out_accum_bytes = o_accum_bytes; + } + auto softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, context->GetComputeStream()); + auto out_accum_buffer = GetScratchBuffer(out_accum_bytes, context->GetComputeStream()); #else constexpr bool 
use_flash_attention = false; + auto softmax_lse_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr + auto out_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr #endif bool use_fused_cross_attention = !use_flash_attention && @@ -194,8 +210,10 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { // Here we assume that num_heads and head_size does not change for a MultiHeadAttention node. if (nullptr == fused_fp16_runner_.get()) { constexpr bool is_unidirectional = false; - fused_fp16_runner_ = FusedMHARunnerFP16v2::Create( - num_heads_, parameters.head_size, sm, is_unidirectional, enable_trt_flash_attention_, parameters.scale); + std::call_once(fused_fp16_runner_created_, [&]() { + fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional, + enable_trt_flash_attention_, parameters.scale); + }); } // In case some kernel not loaded due to shared memory limit, we need to double check here. @@ -289,6 +307,12 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { data.use_memory_efficient_attention = use_memory_efficient_attention; data.cumulated_sequence_length_q_cache = &(this->cumulated_sequence_length_q_cache_); data.cumulated_sequence_length_kv_cache = &(this->cumulated_sequence_length_kv_cache_); + if (softmax_lse_accum_buffer != nullptr) { + data.softmax_lse_accum = reinterpret_cast(softmax_lse_accum_buffer.get()); + } + if (out_accum_buffer != nullptr) { + data.out_accum = reinterpret_cast(out_accum_buffer.get()); + } cublasHandle_t cublas = GetCublasHandle(context); diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h index 33fa3d50e4564..c162f7133cc1c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h @@ -32,6 +32,7 @@ class MultiHeadAttention final : public CudaKernel { bool disable_memory_efficient_attention_; int min_seq_len_for_flash_attention_packed_qkv_; mutable std::unique_ptr fused_fp16_runner_; + mutable std::once_flag fused_fp16_runner_created_; mutable const FusedMultiHeadCrossAttentionKernel* fused_fp16_cross_attention_kernel_; mutable CumulatedSequenceLengthCache cumulated_sequence_length_q_cache_; mutable CumulatedSequenceLengthCache cumulated_sequence_length_kv_cache_; diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention.h b/onnxruntime/contrib_ops/cuda/bert/packed_attention.h index 0cdd8021de4a1..f00c112fc73d2 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention.h @@ -24,10 +24,11 @@ class TrtFusedAttention { protected: MHARunner* GetFusedRunner(const cudaDeviceProp& device_prop, const PackedAttentionParameters& parameters) const; - private: + protected: bool disable_fused_runner_; bool enable_trt_flash_attention_; mutable std::unique_ptr fused_fp16_runner_; + mutable std::once_flag fused_fp16_runner_created_; }; template diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu index aba0efdbd7d5f..d7aeef1501cd6 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu @@ -507,10 +507,12 @@ Status FusedScaledDotProductAttentionCutlass( MemoryEfficientAttentionParams p; p.sm = device_prop.major * 10 + device_prop.minor; p.is_half = sizeof(T) 
== 2; + p.is_kv_bsnh = true; p.batch_size = parameters.batch_size; p.num_heads = parameters.num_heads; p.sequence_length = parameters.sequence_length; p.kv_sequence_length = parameters.sequence_length; + p.max_sequence_length = parameters.sequence_length; p.qk_head_size = parameters.head_size; p.v_head_size = parameters.v_head_size; p.causal = false; diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu index e09fd9e6b36e5..3fe9dbf8ed34a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu @@ -688,6 +688,7 @@ Status FusedAttentionCutlass( p.num_heads = parameters.num_heads; p.sequence_length = parameters.sequence_length; p.kv_sequence_length = parameters.sequence_length; + p.max_sequence_length = parameters.sequence_length; p.qk_head_size = parameters.head_size; p.v_head_size = parameters.v_head_size; p.causal = false; @@ -702,6 +703,7 @@ Status FusedAttentionCutlass( p.attn_bias = data.relative_position_bias; p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias; p.output = data.output; + p.is_kv_bsnh = true; p.workspace = MemoryEfficientAttentionParams::need_workspace(v_head_size, sizeof(T) == sizeof(float)) ? (data.workspace + (data.no_qkv_workspace ? 0 : (elements_qk + elements_qk + elements_v))) : nullptr; diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc new file mode 100644 index 0000000000000..b4b5dac1fbe19 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc @@ -0,0 +1,84 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/providers/cuda/cuda_common.h" +#include "contrib_ops/cpu/bert/rotary_embedding_helper.h" +#include "contrib_ops/cuda/bert/rotary_embedding.h" +#include "contrib_ops/cuda/bert/rotary_embedding_impl.h" + +using namespace onnxruntime::cuda; +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; +using namespace onnxruntime::contrib::rotary_embedding_helper; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + RotaryEmbedding, \ + kMSDomain, \ + 1, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("M", DataTypeImpl::GetTensorType()), \ + RotaryEmbedding); + +REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) + +template +RotaryEmbedding::RotaryEmbedding(const OpKernelInfo& info) : CudaKernel(info) { + scale = info.GetAttrOrDefault("scale", 1.0); + interleaved = (info.GetAttrOrDefault("interleaved", 0) == 1); +} + +template +Status RotaryEmbedding::ComputeInternal(OpKernelContext* context) const { + const Tensor* input = context->Input(0); + const Tensor* position_ids = context->Input(1); + const Tensor* cos_cache = context->Input(2); + const Tensor* sin_cache = context->Input(3); + + RotaryParameters parameters = {}; + ORT_RETURN_IF_ERROR(rotary_embedding_helper::CheckInputs(input, + position_ids, + cos_cache, + sin_cache, + ¶meters)); + + Tensor* output = context->Output(0, input->Shape()); + + if (parameters.sequence_length > parameters.max_sequence_length) { + // Launch update_cos_sin_cache kernel with scale + ORT_NOT_IMPLEMENTED("Updating cos_cache and sin_cache in RotaryEmbedding is not currently supported"); + } + + // Launch rotary embedding kernel + typedef typename ToCudaType::MappedType CudaT; + auto& device_prop = GetDeviceProp(); + return LaunchRotaryEmbeddingKernel( + Stream(context), + reinterpret_cast(output->template MutableData()), + reinterpret_cast(input->template Data()), + position_ids->Data(), + reinterpret_cast(cos_cache->template Data()), + reinterpret_cast(sin_cache->template Data()), + parameters.batch_size, + parameters.sequence_length, + parameters.num_heads, + parameters.head_size, + parameters.max_sequence_length, + parameters.position_ids_format, + interleaved, + device_prop.maxThreadsPerBlock); + + return Status::OK(); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h new file mode 100644 index 0000000000000..6dab2ad56749e --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once +#include "core/common/common.h" +#include "core/providers/cuda/cuda_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +using namespace onnxruntime::cuda; + +template +class RotaryEmbedding final : public CudaKernel { + public: + RotaryEmbedding(const OpKernelInfo& info); + Status ComputeInternal(OpKernelContext* context) const override; + + protected: + float scale; + bool interleaved; +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu new file mode 100644 index 0000000000000..c54e72dcfce13 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu @@ -0,0 +1,141 @@ +/* +Copyright (c) Microsoft Corporation. +Licensed under the MIT License. +*/ + +/* +Kernel implementation for rotary embeddings. +*/ + +#include +#include "core/providers/cuda/cu_inc/common.cuh" +#include "contrib_ops/cuda/bert/rotary_embedding_impl.h" + +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +__global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH + const T* input, // BxSxNxH + const T* cos_cache, // Mx(H/2) + const T* sin_cache, // Mx(H/2) + const int64_t* position_ids, // (1) or BxS + const int sequence_length, + const int num_heads, + const int head_size, + const int position_ids_format, + const bool interleaved) { + // B = batch size, S = sequence length, N = num heads, H = head size, M = max sequence length + // Use .x in innermost loop to access global memory efficiently + + const int b = blockIdx.z; + const int s = blockIdx.y; + const int n = blockIdx.x; + + const int i = threadIdx.x; + + const int block_offset = b * sequence_length * num_heads + s * num_heads + n; + const int data_offset = block_offset * head_size; + + const T* input_data = input + data_offset; + T* output_data = output + data_offset; + + // Cache is (M, H/2) + const int half_head_size = head_size / 2; + const int position_id = (position_ids_format == 0) ? \ + static_cast(position_ids[0]) + s \ + : static_cast(position_ids[b * sequence_length + s]); + const int cache_offset = position_id * half_head_size; + const T* cos_data = cos_cache + cache_offset; + const T* sin_data = sin_cache + cache_offset; + + int cache_idx = 0; + T sign = 0; + int j = 0; + if (interleaved) { + cache_idx = (i / 2) % half_head_size; + sign = (i % 2 == 0) ? -1 : 1; + j = (i % 2 == 0) ? i+1 : i-1; // i - sign + } else { + cache_idx = i % half_head_size; + sign = (i < half_head_size) ? -1 : 1; + j = (i + half_head_size) % head_size; + } + output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx]; +} + + +template +Status LaunchRotaryEmbeddingKernel( + cudaStream_t stream, + T* output, + const T* input, + const int64_t* position_ids, + const T* cos_cache, + const T* sin_cache, + const int batch_size, + const int sequence_length, + const int num_heads, + const int head_size, + const int max_sequence_length, + const int position_ids_format, + const bool interleaved, + const int max_threads_per_block) { + + constexpr int smem_size = 0; + const dim3 grid(num_heads, sequence_length, batch_size); + const dim3 block(head_size, 1, 1); + + // Note: Current implementation assumes head_size <= max_threads_per_block + // because head_size is currently large for LLaMA-2. 
For smaller head_size + // and num_heads values, we can create a block as `block(num_heads, head_size, 1)` + // instead. This will require kernel changes to support. + + assert(head_size <= max_threads_per_block); + RotaryEmbeddingBSNH<<>>( + output, input, cos_cache, sin_cache, position_ids, + sequence_length, num_heads, head_size, position_ids_format, interleaved + ); + + return CUDA_CALL(cudaGetLastError()); +} + +template Status LaunchRotaryEmbeddingKernel( + cudaStream_t stream, + float* output, + const float* input, + const int64_t* position_ids, + const float* cos_cache, + const float* sin_cache, + const int batch_size, + const int sequence_length, + const int num_heads, + const int head_size, + const int max_sequence_length, + const int position_ids_format, + const bool interleaved, + const int max_threads_per_block); + +template Status LaunchRotaryEmbeddingKernel( + cudaStream_t stream, + half* output, + const half* input, + const int64_t* position_ids, + const half* cos_cache, + const half* sin_cache, + const int batch_size, + const int sequence_length, + const int num_heads, + const int head_size, + const int max_sequence_length, + const int position_ids_format, + const bool interleaved, + const int max_threads_per_block); + + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h new file mode 100644 index 0000000000000..29ff48a8ad0fb --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +Status LaunchRotaryEmbeddingKernel( + cudaStream_t stream, + T* output, + const T* input, + const int64_t* position_ids, + const T* cos_cache, + const T* sin_cache, + const int batch_size, + const int sequence_length, + const int num_heads, + const int head_size, + const int max_sequence_length, + const int position_ids_format, + const bool interleaved, + const int max_threads_per_block); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc index 78174181acdc8..3299bc2cb11de 100644 --- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc @@ -3,6 +3,7 @@ #include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/nn/layer_norm_impl.h" +#include "core/common/narrow.h" #include "skip_layer_norm.h" #include "skip_layer_norm_impl.h" #include "contrib_ops/cpu/skip_layer_norm_helper.h" @@ -50,6 +51,11 @@ template Status SkipLayerNorm::ComputeInternal(OpKernelContext* ctx) const { const Tensor* input = ctx->Input(0); const Tensor* skip = ctx->Input(1); + if (strict_ && skip->Shape() != input->Shape()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "'input' and 'skip' shall have same shape when enable_skip_layer_norm_strict_mode is True"); + } + const Tensor* gamma = ctx->Input(2); const Tensor* beta = Simplified ? 
nullptr : ctx->Input(3); @@ -57,16 +63,13 @@ Status SkipLayerNorm::ComputeInternal(OpKernelContext* ctx) const Tensor* output = ctx->Output(0, input->Shape()); - // For inferencing, we support one more optional output which is the sum - // of the input and skip tensors - Tensor* skip_input_bias_add_output = ctx->Output(3, input->Shape()); + // Optional output for the sum of skip, input and bias tensors (It is also the input of Layer Normalization). + Tensor* sum_output = ctx->Output(3, input->Shape()); const auto& input_dims = input->Shape().GetDims(); size_t input_dims_size = input_dims.size(); - const auto& skip_dims = skip->Shape().GetDims(); - size_t skip_dims_size = skip_dims.size(); - int hidden_size = static_cast(input_dims[input_dims_size - 1]); + int hidden_size = onnxruntime::narrow(input_dims[input_dims_size - 1]); ORT_RETURN_IF_ERROR(onnxruntime::contrib::skip_layer_norm_helper::CheckInputs(input, skip, @@ -76,12 +79,15 @@ Status SkipLayerNorm::ComputeInternal(OpKernelContext* ctx) const hidden_size, input_dims_size)); - const bool skip_broadcasted = (skip_dims[0] == 1 || skip_dims_size == 2) ? true : false; - const int skip_size = static_cast(skip_dims[skip_dims_size - 1] * skip_dims[skip_dims_size - 2]); + int row_count = onnxruntime::narrow(input->Shape().SizeToDimension(input_dims_size - 1)); + if (row_count == 0) { + return Status::OK(); + } - int row_count = gsl::narrow(input->Shape().SizeToDimension(input_dims_size - 1)); typedef typename ToCudaType::MappedType CudaT; + const int skip_size = onnxruntime::narrow(skip->Shape().Size()); + if (strict_) { HostApplyLayerNorm( GetDeviceProp(), @@ -97,21 +103,20 @@ Status SkipLayerNorm::ComputeInternal(OpKernelContext* ctx) const (beta != nullptr) ? reinterpret_cast(beta->Data()) : nullptr, // beta reinterpret_cast(skip->Data()), // skip or residual to add (bias != nullptr) ? reinterpret_cast(bias->Data()) : nullptr, // bias to add - skip_input_bias_add_output != nullptr ? reinterpret_cast(skip_input_bias_add_output->MutableData()) : nullptr); + sum_output != nullptr ? reinterpret_cast(sum_output->MutableData()) : nullptr); } else { LaunchSkipLayerNormKernel( Stream(ctx), reinterpret_cast(output->MutableData()), - skip_input_bias_add_output != nullptr ? reinterpret_cast(skip_input_bias_add_output->MutableData()) : nullptr, + sum_output != nullptr ? reinterpret_cast(sum_output->MutableData()) : nullptr, reinterpret_cast(input->Data()), reinterpret_cast(skip->Data()), + (bias != nullptr) ? reinterpret_cast(bias->Data()) : nullptr, reinterpret_cast(gamma->Data()), (beta != nullptr) ? reinterpret_cast(beta->Data()) : nullptr, - (bias != nullptr) ? reinterpret_cast(bias->Data()) : nullptr, epsilon_, hidden_size, row_count, - skip_broadcasted, skip_size); } diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu index bfecacf4fb717..50c8e4b5e0398 100644 --- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu @@ -51,60 +51,68 @@ half maybe2half(float x) { // Using only power of 2 numbers will lead to waste of compute for same size such as 768, which is a very common case // in BERT. Ideally we can step by wrap_size * num_unroll, but listing too many steps will cause long compile time. 
-constexpr int kSizes[] = {32, 64, 128, 384, 768, 1024, 2048}; +constexpr int kSizes[] = {128, 320, 384, 640, 768, 1024, 1280, 2048, 4096, 5120, 8192}; +constexpr size_t kNumOfSizes = sizeof(kSizes) / sizeof(kSizes[0]); +constexpr int kMaxSize = kSizes[kNumOfSizes - 1]; constexpr int kMinBlockSize = 32; -constexpr int kMaxBlockSize = 256; +constexpr int kMaxBlockSize = 1024; int NextSize(int x) { - size_t len = sizeof(kSizes) / sizeof(kSizes[0]); - for (size_t i = 0; i < len; ++i) { + for (size_t i = 0; i < kNumOfSizes; ++i) { if (x <= kSizes[i]) { return kSizes[i]; } } - return kSizes[len - 1]; + return kMaxSize + 1; } -template -bool CanVectorized(T* output, T* skip_input_bias_add_output, const T* input, const T* skip, const T* gamma, - const T* beta, const T* bias, const int ld, const int next_size) { - constexpr int alignment = std::alignment_of>::value; - return ld % NumUnroll == 0 && reinterpret_cast(output) % alignment == 0 && - reinterpret_cast(skip_input_bias_add_output) % alignment == 0 && - reinterpret_cast(input) % alignment == 0 && reinterpret_cast(skip) % alignment == 0 && - reinterpret_cast(gamma) % alignment == 0 && reinterpret_cast(beta) % alignment == 0 && - reinterpret_cast(bias) % alignment == 0 && next_size / NumUnroll >= kMinBlockSize && - next_size / NumUnroll <= kMaxBlockSize; +bool CanVectorized(void* output, void* sum_output, const void* input, const void* skip, const void* bias, + const void* gamma, const void* beta, const int ld, const int next_size, int num_unroll, int element_size) { + int alignment = element_size * num_unroll; + return ld % num_unroll == 0 && + reinterpret_cast(output) % alignment == 0 && + reinterpret_cast(sum_output) % alignment == 0 && + reinterpret_cast(input) % alignment == 0 && + reinterpret_cast(skip) % alignment == 0 && + reinterpret_cast(bias) % alignment == 0 && + reinterpret_cast(gamma) % alignment == 0 && + reinterpret_cast(beta) % alignment == 0 && + next_size / num_unroll >= kMinBlockSize && + next_size / num_unroll <= kMaxBlockSize; } } // namespace template __global__ void SkipLayerNormKernel( - const int ld, const T* input, const T* skip, - const T* beta, const T* gamma, const T* bias, - const T epsilon, T* output, T* skip_input_bias_add_output, const bool skip_broadcasted, int skip_size) { + T* output, T* sum_output, const T* input, const T* skip, const T* bias, const T* gamma, const T* beta, T epsilon, + const int ld, int skip_size) { const T reverse_ld = T(1.f / ld); const int offset = blockIdx.x * ld; + const bool has_bias = (bias != nullptr); + // Reduce sum of x and x^2, and the results are divided by ld. KeyValuePairSum pair_sum; - // reduce x and x^2 cub::KeyValuePair thread_data(0, 0); for (int i = threadIdx.x; i < ld; i += TPB) { const int idx = offset + i; - const T skip_data = skip_broadcasted ? skip[idx % skip_size] : skip[idx]; - const T val = (bias == nullptr) ? 
input[idx] + skip_data : input[idx] + skip_data + bias[i]; + T val = input[idx]; + if (has_bias) { + val += bias[i]; + } + val += skip[idx % skip_size]; const T rldval = reverse_ld * val; thread_data = pair_sum(thread_data, cub::KeyValuePair(rldval, rldval * val)); - if (skip_input_bias_add_output != nullptr) { - skip_input_bias_add_output[idx] = val; + if (sum_output != nullptr) { + sum_output[idx] = val; } output[idx] = val; } + if (Simplified) { SimplifiedLayerNorm(thread_data.value, ld, offset, gamma, epsilon, output); return; @@ -115,106 +123,114 @@ __global__ void SkipLayerNormKernel( // Vectorized kernel template __global__ void SkipLayerNormKernelSmall( - const int ld, const T* input, const T* skip, const T* beta, const T* gamma, - const T* bias, const T epsilon, T* output, T* skip_input_bias_add_output, - bool hasBias, bool hasSkipInputBiasAdditionOutput, const bool skip_broadcasted, const int skip_size) { + T* output, T* sum_output, const T* input, const T* skip, const T* bias, const T* gamma, const T* beta, T epsilon, + int ld, int skip_size) { const T rld = T(1.f / ld); - const int idx = blockIdx.x * ld + threadIdx.x * ILP; // grid_size = n / ld + const int idx = blockIdx.x * ld + threadIdx.x * ILP; using VecT = aligned_vector; + T sum_v[ILP]; - T input_v[ILP], skip_v[ILP], bias_v[ILP], skip_input_bias_add_output_v[ILP]; + cub::KeyValuePair thread_data(T(0.f), T(0.f)); - VecT* input_val = reinterpret_cast(&input_v); - *input_val = *reinterpret_cast(&input[idx]); + if (ILP * threadIdx.x < ld) { // load data under this guard to avoid reading out-of-bounds + T skip_v[ILP], bias_v[ILP]; - VecT* skip_val = reinterpret_cast(&skip_v); - if (skip_broadcasted) { - *skip_val = *reinterpret_cast(&skip[idx % skip_size]); - } else { - *skip_val = *reinterpret_cast(&skip[idx]); - } + // load input to sum_v + VecT* sum_val = reinterpret_cast(&sum_v); + *sum_val = *reinterpret_cast(&input[idx]); - if (hasBias) { - VecT* bias_val = reinterpret_cast(&bias_v); - *bias_val = *reinterpret_cast(&bias[threadIdx.x * ILP]); - } + VecT* skip_val = reinterpret_cast(&skip_v); + *skip_val = *reinterpret_cast(&skip[idx % skip_size]); - cub::KeyValuePair thread_data(T(0.f), T(0.f)); + const bool has_bias = (bias != nullptr); + if (has_bias) { + VecT* bias_val = reinterpret_cast(&bias_v); + *bias_val = *reinterpret_cast(&bias[threadIdx.x * ILP]); + } - if (ILP * threadIdx.x < ld) { T rldval_sum = T(0.f); T rldvalsq_sum = T(0.f); + const bool has_sum_output = (sum_output != nullptr); + #pragma unroll for (int i = 0; i < ILP; i++) { - input_v[i] += hasBias ? 
skip_v[i] + bias_v[i] : skip_v[i]; - - if (hasSkipInputBiasAdditionOutput) { - skip_input_bias_add_output_v[i] = input_v[i]; + if (has_bias) { + sum_v[i] += bias_v[i]; } + sum_v[i] += skip_v[i]; - const T rldval = rld * input_v[i]; + const T rldval = rld * sum_v[i]; rldval_sum += rldval; - rldvalsq_sum += rldval * input_v[i]; + rldvalsq_sum += rldval * sum_v[i]; } - if (hasSkipInputBiasAdditionOutput) { - *(reinterpret_cast(&skip_input_bias_add_output[idx])) = *reinterpret_cast(&skip_input_bias_add_output_v); + if (has_sum_output) { + *(reinterpret_cast(&sum_output[idx])) = *reinterpret_cast(&sum_v); } thread_data = cub::KeyValuePair(rldval_sum, rldvalsq_sum); } if (Simplified) { - SimplifiedLayerNormSmall(input_v, thread_data.value, ld, idx, gamma, epsilon, output); + SimplifiedLayerNormSmall(sum_v, thread_data.value, ld, idx, gamma, epsilon, output); return; } - LayerNormSmall(input_v, thread_data, ld, idx, beta, gamma, epsilon, output); + LayerNormSmall(sum_v, thread_data, ld, idx, beta, gamma, epsilon, output); } template void LaunchSkipLayerNormKernel( - cudaStream_t stream, T* output, T* skip_input_bias_add_output, const T* input, const T* skip, const T* gamma, - const T* beta, const T* bias, float epsilon, int ld, int row_count, bool skip_broadcasted, int skip_size) { - if (row_count == 0) { - return; - } - - bool hasBias = (bias == nullptr) ? false : true; - bool hasSkipInputBiasAdditionOutput = (skip_input_bias_add_output == nullptr) ? false : true; - + cudaStream_t stream, T* output, T* sum_output, + const T* input, const T* skip, const T* bias, const T* gamma, const T* beta, float epsilon, + int ld, int row_count, int skip_size) { const int next_size = NextSize(ld); const int grid_size = row_count; - bool flag_vec2 = - CanVectorized(output, skip_input_bias_add_output, input, skip, gamma, beta, bias, ld, next_size); - bool flag_vec4 = - CanVectorized(output, skip_input_bias_add_output, input, skip, gamma, beta, bias, ld, next_size); + bool can_unroll_vec4 = CanVectorized(output, sum_output, input, + skip, bias, gamma, + beta, ld, next_size, + 4, sizeof(T)); + bool can_unroll_vec8 = CanVectorized(output, sum_output, input, + skip, bias, gamma, + beta, ld, next_size, + 8, sizeof(T)); + +#define LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(num_unroll) \ + SkipLayerNormKernelSmall<<>>( \ + output, sum_output, input, skip, bias, gamma, beta, maybe2half(epsilon), ld, skip_size) + +#define LAUNCH_SKIP_LAYER_NORM_KERNEL() \ + SkipLayerNormKernel<<>>( \ + output, sum_output, input, skip, bias, gamma, beta, maybe2half(epsilon), ld, skip_size) + +#define CASE_NEXT_SIZE(next_size_value) \ + case next_size_value: { \ + static_assert(next_size_value >= kSizes[0] && next_size_value <= kMaxSize); \ + if constexpr (next_size_value >= 320) { \ + if (can_unroll_vec8) { \ + constexpr int block_size = next_size_value / 8; \ + LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(8); \ + } else { \ + constexpr int block_size = 256; \ + LAUNCH_SKIP_LAYER_NORM_KERNEL(); \ + } \ + } else { \ + if (can_unroll_vec4) { \ + constexpr int block_size = next_size_value / 4; \ + LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(4); \ + } else { \ + if (next_size_value <= kMaxBlockSize) { \ + constexpr int block_size = next_size_value; \ + LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(1); \ + } else { \ + constexpr int block_size = 256; \ + LAUNCH_SKIP_LAYER_NORM_KERNEL(); \ + } \ + } \ + } \ + } break switch (next_size) { -#define LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(num_unroll) \ - SkipLayerNormKernelSmall \ - <<>>(ld, input, skip, beta, gamma, bias, 
maybe2half(epsilon), output, \ - skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput, skip_broadcasted, skip_size) -#define LAUNCH_SKIP_LAYER_NORM_KERNEL() \ - SkipLayerNormKernel<<>>( \ - ld, input, skip, beta, gamma, bias, maybe2half(epsilon), output, skip_input_bias_add_output, skip_broadcasted, skip_size) -#define CASE_NEXT_SIZE(next_size_value) \ - case next_size_value: { \ - if (flag_vec4) { \ - constexpr int block_size = next_size_value / 4; \ - LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(4); \ - } else if (flag_vec2) { \ - constexpr int block_size = next_size_value / 2; \ - LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(2); \ - } else { \ - if (next_size_value <= kMaxBlockSize) { \ - constexpr int block_size = next_size_value; \ - LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(1); \ - } else { \ - LAUNCH_SKIP_LAYER_NORM_KERNEL(); \ - } \ - } \ - } break CASE_NEXT_SIZE(kSizes[0]); CASE_NEXT_SIZE(kSizes[1]); CASE_NEXT_SIZE(kSizes[2]); @@ -222,18 +238,27 @@ void LaunchSkipLayerNormKernel( CASE_NEXT_SIZE(kSizes[4]); CASE_NEXT_SIZE(kSizes[5]); CASE_NEXT_SIZE(kSizes[6]); + CASE_NEXT_SIZE(kSizes[7]); + CASE_NEXT_SIZE(kSizes[8]); + CASE_NEXT_SIZE(kSizes[9]); + CASE_NEXT_SIZE(kSizes[10]); + default: { + constexpr int block_size = 256; + LAUNCH_SKIP_LAYER_NORM_KERNEL(); + break; + } + } + #undef CASE_NEXT_SIZE #undef LAUNCH_SKIP_LAYER_NORM_KERNEL #undef LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL - } } -#define SKIPLAYERNORM_IMPL(T, Simplified) \ - template void LaunchSkipLayerNormKernel(cudaStream_t stream, T * output, \ - T * skip_input_bias_add_output, \ - const T* input, const T* skip, const T* gamma, \ - const T* beta, const T* bias, float epsilon, \ - int ld, int row_count, bool skip_broadcasted, int skip_size); +#define SKIPLAYERNORM_IMPL(T, Simplified) \ + template void LaunchSkipLayerNormKernel(cudaStream_t stream, T * output, T * sum_output, \ + const T* input, const T* skip, const T* bias, \ + const T* gamma, const T* beta, float epsilon, \ + int ld, int row_count, int skip_size); SKIPLAYERNORM_IMPL(float, true); SKIPLAYERNORM_IMPL(float, false); SKIPLAYERNORM_IMPL(half, true); diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.h b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.h index ffb5850c827fe..9727dd6236ec8 100644 --- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.h @@ -11,18 +11,17 @@ namespace cuda { template void LaunchSkipLayerNormKernel( cudaStream_t stream, - T* output, // normalized output tensor - T* skip_input_bias_add_output, // sum of the input and skip (and bias if it exists) tensors output - const T* input, // input tensor - const T* skip, // skip tensor - const T* gamma, // Layer normalization gamma tensor - const T* beta, // Layer normalization beta tensor - const T* bias, // Layer normalization beta tensor - float epsilon, // Layer normalization epsilon - int hidden_size, // hidden size, it is the leading dimension (ld) - int row_count, // number of rows. That is total number of elements divided by hidden size. 
- bool skip_broadcasted, // determines if broadcasting should be implemented - int skip_size); // determines size of the skip tensor + T* output, // normalized output tensor + T* sum_output, // sum of the input and skip (and bias if it exists) tensors output + const T* input, // input tensor + const T* skip, // skip tensor + const T* bias, // bias tensor + const T* gamma, // Layer normalization gamma tensor + const T* beta, // Layer normalization beta tensor + float epsilon, // Layer normalization epsilon + int hidden_size, // hidden size, it is the leading dimension (ld) + int row_count, // number of rows. That is total number of elements divided by hidden size. + int skip_size); // number of elements of the skip tensor } // namespace cuda } // namespace contrib diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_expand.cc b/onnxruntime/contrib_ops/cuda/collective/distributed_expand.cc new file mode 100644 index 0000000000000..3cfa3ab959343 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_expand.cc @@ -0,0 +1,110 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Distributed computation. +#include "distributed_expand.h" +#include "sharding.h" +#include "sharding_spec.h" +#include "nccl_kernels.h" +#include "mpi_include.h" + +// ORT system. +#include "core/providers/cuda/tensor/expand.h" + +// std C++. +#include + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +template +DistributedExpand::DistributedExpand(const OpKernelInfo& info) : DistributedKernel(info) {} + +template +Status DistributedExpand::ComputeInternal(OpKernelContext* context) const { + ORT_ENFORCE(context != nullptr); + // Assumptions. + // - Shape is not sharded. + // Algorithm. + // - Compute logical output shape. + // - Compute local output shape. + // - Expand from local input to local output. + + auto input_tensor = context->Input(0); + auto shape_tensor = context->Input(1); + const auto& input_sharding_spec = input_shard_specs_.at(0); + const auto& shape_sharding_spec = input_shard_specs_.at(1); + const auto& output_sharding_spec = output_shard_specs_.at(0); + + ORT_ENFORCE(shape_sharding_spec.HasNoShard(), + "It's not worth to shard Shape tensor. " + "If sharding shape is needed, please submit a feature request."); + // Compute logical input shape. + const auto original_input_shape = ComputeOriginShape(input_tensor->Shape(), input_sharding_spec); + + // Compute logical output shape. + // This `shape_tensor` stores the logical output shape. + const auto* p_shape = shape_tensor->Data(); + TensorShapeVector original_output_dims{p_shape, p_shape + shape_tensor->Shape().Size()}; + TensorShape original_output_shape(original_output_dims); + ORT_ENFORCE( + onnxruntime::cuda::ComputeOutputShape( + Node().Name(), + original_input_shape, + original_output_dims, original_output_shape) + .IsOK()); + + // Compute local output shape. 
+ const auto local_output_shape = ComputeShardShape(original_output_shape, output_sharding_spec); + + auto output_tensor = context->Output(0, local_output_shape); + + return FuncExpand( + this, + context, + input_tensor, + shape_tensor, + output_tensor); +} + +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedExpand, + kMSDomain, + 1, + int64_t, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedExpand); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedExpand, + kMSDomain, + 1, + float, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedExpand); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedExpand, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedExpand); + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_expand.h b/onnxruntime/contrib_ops/cuda/collective/distributed_expand.h new file mode 100644 index 0000000000000..dedb1bdc5aa36 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_expand.h @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "sharding_spec.h" +#include "sharding.h" +#include "core/providers/cuda/cuda_kernel.h" + +#include +#include +#include +#include +#include +#include + +#pragma once + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +template +class DistributedExpand final : public DistributedKernel { + public: + explicit DistributedExpand(const OpKernelInfo& info); + + Status ComputeInternal(OpKernelContext* context) const override; +}; + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_reduce.cc b/onnxruntime/contrib_ops/cuda/collective/distributed_reduce.cc new file mode 100644 index 0000000000000..967f30a304ac2 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_reduce.cc @@ -0,0 +1,175 @@ + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Distributed computation. +#include "distributed_reduce.h" +#include "sharding.h" +#include "sharding_spec.h" +#include "nccl_kernels.h" +#include "mpi_include.h" + +// ORT system. +#include "core/providers/cuda/cudnn_common.h" +#include "core/providers/cuda/reduction/reduction_ops.h" + +// std C++. +#include + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +template +DistributedReduceBase::DistributedReduceBase( + const OpKernelInfo& info, + cudnnReduceTensorOp_t cudnn_reduce_op) : DistributedKernel(info) { + keepdims_ = info.GetAttrOrDefault("keepdims", 1); + cudnn_reduce_op_ = cudnn_reduce_op; +}; + +template +Status DistributedReduceBase::ComputeInternal(OpKernelContext* context) const { + const auto& input_sharding_spec = input_shard_specs_.at(0); + const auto& axes_sharding_spec = input_shard_specs_.at(1); + const auto& output_sharding_spec = output_shard_specs_.at(0); + + ORT_ENFORCE(axes_sharding_spec.HasNoShard(), + "It's not worthy to shard axes tensor. 
" + "If sharding axes is needed, please submit a feature request."); + + const Tensor* input_tensor = context->Input(0); + const Tensor* axes_tensor = context->Input(1); + ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, "Axes tensor must be an 1-D tensor."); + auto axes_span = axes_tensor->DataAsSpan(); + + // Case 1: empty axes means treating this reduction as an identity. + if (axes_span.empty()) { + ORT_ENFORCE( + input_sharding_spec == output_sharding_spec, + "Input and output sharding specs should be the same. Otherwise, resharding is needed."); + auto* output_tensor = context->Output(0, input_tensor->Shape()); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_tensor->MutableData(), input_tensor->Data(), input_tensor->SizeInBytes(), + cudaMemcpyDeviceToDevice, Stream(context))); + return Status::OK(); + } + + // Case 2: this is a valid reduction. Let's prepare for it. + + bool sharding_on_reduced_axes = false; + for (auto axis_it = axes_span.begin(); input_sharding_spec.HasShard() && axis_it != axes_span.end(); ++axis_it) { + if (*axis_it == input_sharding_spec.GetPartitionAxis()) { + sharding_on_reduced_axes = true; + break; + } + } + + if (sharding_on_reduced_axes) { + // Case 2-1: sharding on reduced axes. + ORT_THROW(onnxruntime::common::ONNXRUNTIME, onnxruntime::common::FAIL, "Not implemented. Resharding is required to make reduced axes replica."); + } else { + // Case 2-2: sharding on passing-through axes or no shard. + ORT_ENFORCE( + input_sharding_spec == output_sharding_spec, + "Input and output sharding specs should be the same. Otherwise, resharding is needed."); + onnxruntime::cuda::PrepareReduceMetadata metadata; + ORT_RETURN_IF_ERROR( + onnxruntime::cuda::PrepareForReduce(input_tensor, keepdims_, axes_span, metadata)); + auto output_tensor = context->Output(0, metadata.squeezed_output_dims); + + // Fast reduction is not deterministic, so sometimes we want to turn it off. 
+ const bool enable_fast_but_non_deterministic_reduction = !context->GetUseDeterministicCompute(); + return onnxruntime::cuda::ReduceComputeCore( + /* GPU allocator */ Info().GetAllocator(OrtMemType::OrtMemTypeDefault), + *input_tensor, metadata, *output_tensor, cudnn_reduce_op_, axes_span, + /* calculate_log */ false, /* calculate_sqt */ false, /* log_sum_exp_ */ false, + enable_fast_but_non_deterministic_reduction, context->GetComputeStream()); + } + return Status::OK(); +} + +template +DistributedReduceSum::DistributedReduceSum( + const OpKernelInfo& info) : DistributedReduceBase(info, CUDNN_REDUCE_TENSOR_ADD){}; + +template +DistributedReduceMean::DistributedReduceMean( + const OpKernelInfo& info) : DistributedReduceBase(info, CUDNN_REDUCE_TENSOR_AVG){}; + +template +DistributedReduceMax::DistributedReduceMax( + const OpKernelInfo& info) : DistributedReduceBase(info, CUDNN_REDUCE_TENSOR_MAX){}; + +// ReduceSum +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReduceSum, + kMSDomain, + 1, + float, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReduceSum); +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReduceSum, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReduceSum); + +// ReduceMean +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReduceMean, + kMSDomain, + 1, + float, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReduceMean); +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReduceMean, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReduceMean); + +// ReduceMax +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReduceMax, + kMSDomain, + 1, + float, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReduceMax); +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReduceMax, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReduceMax); + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_reduce.h b/onnxruntime/contrib_ops/cuda/collective/distributed_reduce.h new file mode 100644 index 0000000000000..2939852c75c60 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_reduce.h @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
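The reduce implementation above only runs locally when the reduced axes are replicated; reducing along the sharded axis is rejected because each device holds only a slice of that axis. A minimal sketch of that dispatch rule, using hypothetical names (`ReducePlan`, `PlanDistributedReduce`) and a plain optional partition axis in place of the real `TensorPartitionSpec`:

```cpp
#include <cstdint>
#include <optional>
#include <vector>

// Illustrative restatement of DistributedReduceBase::ComputeInternal's dispatch.
// `partition_axis` is std::nullopt when the input tensor is fully replicated.
enum class ReducePlan { Identity, LocalReduce, NeedsResharding };

ReducePlan PlanDistributedReduce(const std::vector<int64_t>& reduced_axes,
                                 std::optional<int64_t> partition_axis) {
  if (reduced_axes.empty()) {
    // Case 1: empty axes -> the reduction degenerates to an identity copy.
    return ReducePlan::Identity;
  }
  if (partition_axis.has_value()) {
    for (int64_t axis : reduced_axes) {
      if (axis == *partition_axis) {
        // Case 2-1: the sharded axis is reduced; local shards do not contain
        // the full axis, so resharding would be required first.
        return ReducePlan::NeedsResharding;
      }
    }
  }
  // Case 2-2: only replicated/pass-through axes are reduced, so each device can
  // reduce its local shard independently and keep the input sharding spec.
  return ReducePlan::LocalReduce;
}
```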
+ +#include "sharding_spec.h" +#include "sharding.h" +#include "core/providers/cuda/cuda_kernel.h" + +#include +#include +#include +#include +#include +#include + +#pragma once + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +template +class DistributedReduceBase : public DistributedKernel { + public: + explicit DistributedReduceBase(const OpKernelInfo& info, cudnnReduceTensorOp_t cudnn_reduce_op); + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + // ONNX attribute. If true, reduced axes are retained as dimensions with size one. + // Otherwise, drop reduced axes. + bool keepdims_; + cudnnReduceTensorOp_t cudnn_reduce_op_; +}; + +template +class DistributedReduceSum final : public DistributedReduceBase { + public: + explicit DistributedReduceSum(const OpKernelInfo& info); +}; + +template +class DistributedReduceMean final : public DistributedReduceBase { + public: + explicit DistributedReduceMean(const OpKernelInfo& info); +}; + +template +class DistributedReduceMax final : public DistributedReduceBase { + public: + explicit DistributedReduceMax(const OpKernelInfo& info); +}; + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_reshape.cc b/onnxruntime/contrib_ops/cuda/collective/distributed_reshape.cc new file mode 100644 index 0000000000000..e413ccf580870 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_reshape.cc @@ -0,0 +1,861 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Distributed computation. +#include "distributed_reshape.h" +#include "sharding.h" +#include "sharding_spec.h" +#include "nccl_kernels.h" +#include "mpi_include.h" + +// ORT system. +#include "core/providers/cuda/tensor/transpose.h" +#include "core/providers/cuda/cuda_check_memory.h" + +// std C++. +#include + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +// Return true if src_shape[src_begin:src_end] is the same as +// dst_shape[dst_begin:dst_end]. Otherwise, return false. +// TODO: replace std::vector with gsl::span. +bool CompareSubVectors( + const std::vector& src_shape, + const std::vector& dst_shape, + size_t src_begin, size_t src_end, + size_t dst_begin, size_t dst_end) { + if (src_end - src_begin != dst_end - dst_begin) { + // Sub-vectors have different lengths. + return false; + } + for (size_t src_index = src_begin, dst_index = dst_begin; + src_index < src_end && dst_index < dst_end; + ++src_index, ++dst_index) { + if (src_shape[src_index] != dst_shape[dst_index]) { + // Sub-vectors have different elements. + return false; + } + } + // Sub-vectors have same length and same elements. + return true; +} + +// TODO: replace std::vector with gsl::span. +std::tuple IsTwoAxisFusion( + const std::vector& src_shape, + const std::vector& dst_shape) { + // Return values: + // - bool: whether two consecutive axes are fused. + // - size_t: the axis in destination shape formed by fusing two source axes. + // - size_t: the first axis fused. + // - size_t: the length of fusion. In two-axis fusion considered by this + // function, the length of fusion is always 2. 
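To make the return convention just documented concrete, here is a hypothetical check of the fusion detector on a few shapes. It assumes the shape vectors hold `int64_t` and that `IsTwoAxisFusion` is declared with the tuple elements described above (the function is file-local in this patch, so a test would need it exposed or linked in):

```cpp
#include <cassert>
#include <cstdint>
#include <tuple>
#include <vector>

// Assumed signature, reconstructed from the documented return values.
std::tuple<bool, size_t, size_t, size_t> IsTwoAxisFusion(const std::vector<int64_t>& src_shape,
                                                         const std::vector<int64_t>& dst_shape);

void CheckTwoAxisFusionExamples() {
  // [2, 3, 5] -> [6, 5]: source axes 0 and 1 fuse into destination axis 0.
  auto [fused_a, dst_axis_a, src_axis_a, len_a] = IsTwoAxisFusion({2, 3, 5}, {6, 5});
  assert(fused_a && dst_axis_a == 0 && src_axis_a == 0 && len_a == 2);

  // [2, 3, 5] -> [2, 15]: source axes 1 and 2 fuse into destination axis 1.
  auto [fused_b, dst_axis_b, src_axis_b, len_b] = IsTwoAxisFusion({2, 3, 5}, {2, 15});
  assert(fused_b && dst_axis_b == 1 && src_axis_b == 1 && len_b == 2);

  // Ranks must differ by exactly one, so a plain copy is not reported as a fusion.
  auto [fused_c, dst_axis_c, src_axis_c, len_c] = IsTwoAxisFusion({2, 3, 5}, {2, 3, 5});
  assert(!fused_c);
  (void)dst_axis_c; (void)src_axis_c; (void)len_c;
}
```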
+ const size_t src_rank = src_shape.size(); + const size_t dst_rank = dst_shape.size(); + if (src_rank < 2 || dst_rank < 1) { + return std::make_tuple(false, -1, -1, -1); + } + if (src_rank - 1 != dst_rank) { + return std::make_tuple(false, -1, -1, -1); + } + for (size_t i_src = 0; i_src < src_rank; ++i_src) { + if (i_src + 1 > src_rank - 1) { + // We are at src_shape[i] and we need + // src_shape[i + 1] to fuse. + // If we are at the last axis, we cannot fuse. + break; + } + const int64_t prod = src_shape[i_src] * src_shape[i_src + 1]; + + for (size_t i_dst = 0; i_dst < dst_rank; ++i_dst) { + // Check if shape[i_src:i_src+2] (i.e., shape[i_src] and shape[i_src+1]) + // for source tensor are fused into shape[i_dst] for destination tensor. + if (prod != dst_shape[i_dst]) { + continue; + } + // Check if corresponding dimensions before fusion area + // are the same. + const bool prefix_shape_match = CompareSubVectors( + src_shape, + dst_shape, + // Represent src_shape[0:i_src]. + 0, i_src, + // Represent dst_shape[0:i_dst]. + 0, i_dst); + const bool suffix_shape_match = CompareSubVectors( + src_shape, + dst_shape, + // Represent src_shape[i_src+2:]. + i_src + 2, src_rank, + // Represent dst_shape[i_dst+1:]. + i_dst + 1, dst_rank); + if (prefix_shape_match && suffix_shape_match) { + return std::make_tuple( + true, i_dst, i_src, 2); + } + } + } + return std::make_tuple(false, 0, 0, 0); +} + +std::tuple IsTwoAxisDecomposition( + const std::vector& src_shape, + const std::vector& dst_shape) { + // Return values: + // - bool: whether one source axis is decomposed into two consecutive destination axes. + // - size_t: the axis in source shape decomposed into two consecutive destination axes. + // - size_t: the first axis the source axis decomposed into. + // - size_t: the number of decomposed axes. It's always 2 in this function. + return IsTwoAxisFusion(dst_shape, src_shape); +} + +std::vector RepeatVector(const std::vector& vec, int64_t repeat) { + std::vector new_vec; + for (int64_t i = 0; i < repeat; ++i) { + new_vec.insert(new_vec.end(), vec.begin(), vec.end()); + } + return new_vec; +} + +DeviceMesh CreateInterleaveDeviceMesh( + const DeviceMesh& source_mesh, const int64_t repeat) { + // Given a 1-D device mesh [0, 1] and repeat=2, + // return 1-D device mesh [0, 1, 0, 1]. + if (source_mesh.device_mesh_shape.size() != 1) { + throw std::runtime_error("Source mesh shape 1-D."); + } + + // Mesh to return. + DeviceMesh new_mesh; + + std::vector& elements = new_mesh.device_mesh_elements; + for (int64_t i = 0; i < repeat; ++i) { + elements.insert( + elements.end(), + source_mesh.device_mesh_elements.begin(), + source_mesh.device_mesh_elements.end()); + } + + // source mesh must be 1-D so we only care its 1st dimension. + new_mesh.device_mesh_shape.push_back(source_mesh.device_mesh_shape[0] * repeat); + + return new_mesh; +} + +std::tuple ComputeNativeSpecForTwoAxisFusion( + const TensorPartitionSpec& src_spec, + const std::vector& src_shape, + const std::vector& dst_shape, + const int64_t fused_axis_in_src, + const int64_t fusion_axis_in_dst) { + // TODO(wechi): use device mesh stride to support non-1 stride. 
+ // Example: S[0]R, shape=[2, 3], device_mesh=[0, 1] -> S[0], shape = [6], device_mesh=[0, 1] + // Example: RS[0], shape=[2, 3], device_mesh=[0, 1] -> S[0], shape = [6], device_mesh=[0, 1, 0, 1] + // Example: S[0]RR, shape=[2, 3, 5], device_mesh=[0, 1] -> S[0]R, shape = [2, 15], device_mesh=[0, 1] + ORT_ENFORCE(src_spec.CountShardingAxes() == 1, "Tensor to be reshaped has too many sharding axes."); + ORT_ENFORCE(src_spec.device_mesh.device_mesh_shape.size() == 1, "Source device mesh be 1-D."); + + if (src_spec.HasNoShard()) { + return std::make_tuple(true, TensorPartitionSpec::CreateAllReplica(dst_shape.size(), src_spec.device_mesh)); + } else if (src_spec.HasShard() && src_spec.OnlyShardAxis(fused_axis_in_src)) { + // Example: S[0]R, shape=[2, 3], device_mesh=[0, 1] -> S[0], shape = [6], device_mesh=[0, 1] + // Example 1: + // - logical input shape: [2, 8] + // - logical output shape: [16] + // - input sharding spec: S[0]R, device_mesh=[0, 1] + // 1. Device allocation of the original input tensor: + // - Logical tensor. + // [[0, 0, 0, 0, 0, 0, 0, 0], (device assignment) + // [1, 1, 1, 1, 1, 1, 1, 1]] + // [[ 0, 1, 2, 3, 4, 5, 6, 7], (values) + // [ 8, 9, 10, 11, 12, 13, 14, 15]] + // - Device 0's local tensor (shape: [2, 4]). + // [[ 0, 1, 2, 3, 4, 5, 6, 7]] + // - Device 1's local tensor (shape: [2, 4]). + // [[ 8, 9, 10, 11, 12, 13, 14, 15]] + // 2. Deduce local output shape: + // - In the logical Reshape, the 1st and 2nd logical axes are fused, + // so are the corresponding local axes. + // - Local output shape: [8] by fusing both axes in shape [2, 4]. + // 3. Run local reshape (reshape from shape [2, 4] to shape [8]): + // - Device 0's local output tensor. + // [ 0, 1, 2, 3, 4, 5, 6, 7] + // - Device 1's local output tensor. + // [ 8, 9, 10, 11, 12, 13, 14, 15] + // 4. Determine native output sharding spec from local output tensors. + // - Logical output tensor: + // [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + // - Device assignment by comparing local tensors and logical output tensor: + // [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1] + // - S[0] with device_mesh = [0, 1] = input device mesh. + // 5. Native output sharding spec: + // - S[0] with device_mesh [0, 1] + // + // Example 2: + // - logical input shape: [8, 2] + // - logical output shape: [16] + // - input sharding spec: S[0]R, device_mesh=[0, 1] + // 1. Device allocation of the original input tensor: + // - Logical tensor. + // [[0, 0], (device assignment) + // [0, 0], + // [0, 0], + // [0, 0], + // [1, 1], + // [1, 1], + // [1, 1], + // [1, 1]] + // [[ 0, 1], (values) + // [ 2, 3], + // [ 4, 5], + // [ 6, 7], + // [ 8, 9], + // [10, 11], + // [12, 13], + // [14, 15]] + // - Device 0's local tensor (shape: [4, 2]). + // [[ 0, 1], + // [ 2, 3], + // [ 4, 5], + // [ 6, 7]] + // - Device 1's local tensor (shape: [4, 2]). + // [[ 8, 9], + // [10, 11], + // [12, 13], + // [14, 15]] + // 2. Deduce local output shape: + // - In the logical Reshape, the 1st and 2nd logical axes are fused, + // so are the corresponding local axes. + // - Local output shape: [8] by fusing both axes in shape [4, 2]. + // 3. Run local reshape (reshape from shape [4, 2] to shape [8]): + // - Device 0's local output tensor. + // [ 0, 1, 2, 3, 4, 5, 6, 7] + // - Device 1's local output tensor. + // [ 8, 9, 10, 11, 12, 13, 14, 15] + // 4. Determine native output sharding spec from local output tensors. 
+ // - Logical output tensor: + // [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + // - Device assignment by comparing local tensors and logical output tensor: + // [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1] + // - S[0] with device_mesh = [0, 1] = input device mesh. + // 5. Native output sharding spec: + // - S[0] with device_mesh [0, 1] + // + // Example 3: + // - logical input shape: [8, 2] + // - logical output shape: [16] + // - input sharding spec: S[0]R, device_mesh=[0, 1, 0, 1] + // 1. Device allocation of the original input tensor: + // - Logical tensor. + // [[0, 0], (device assignment) + // [0, 0], + // [1, 1], + // [1, 1], + // [0, 0], + // [0, 0], + // [1, 1], + // [1, 1]] + // [[ 0, 1], (values) + // [ 2, 3], + // [ 4, 5], + // [ 6, 7], + // [ 8, 9], + // [10, 11], + // [12, 13], + // [14, 15]] + // - Device 0's local tensor (shape: [4, 2]). + // [[ 0, 1], + // [ 2, 3], + // [ 8, 9], + // [10, 11]] + // - Device 1's local tensor (shape: [4, 2]). + // [[ 4, 5], + // [ 6, 7], + // [12, 13], + // [14, 15]] + // 2. Deduce local output shape: + // - In the logical Reshape, the 1st and 2nd logical axes are fused, + // so are the corresponding local axes. + // - Local output shape: [8] by fusing both axes in shape [4, 2]. + // 3. Run local reshape (reshape from shape [4, 2] to shape [8]): + // - Device 0's local output tensor. + // [ 0, 1, 2, 3, 8, 9, 10, 11] + // - Device 1's local output tensor. + // [ 4, 5, 6, 7, 12, 13, 14, 15] + // 4. Determine native output sharding spec from local output tensors. + // - Logical output tensor: + // [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + // - Device assignment by comparing local tensors and logical output tensor: + // [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1] + // - S[0] with device_mesh = [0, 1] = input device mesh. + // 5. Native output sharding spec: + // - S[0] with device_mesh [0, 1, 0, 1] + + // Reuse original device mesh but shard the fusion axis in output tensor. + auto dst_spec = TensorPartitionSpec::CreateOneTensorAxisOneDeviceMeshAxisSharding( + dst_shape.size(), src_spec.device_mesh, fusion_axis_in_dst, /* 1-D mesh */ 0); + return std::make_tuple(true, dst_spec); + } else if (src_spec.HasShard() && src_spec.OnlyShardAxis(fused_axis_in_src + 1)) { + // Example 1 of determining native output sharding spec: + // - logical input shape: [3, 4] + // - logical output shape: [12] + // - input sharding spec: RS[0], device_mesh=[0, 1, 0, 1] + // 1. Device allocation of the original input tensor: + // - Logical tensor. + // [[0, 1, 0, 1], (device assignment) + // [0, 1, 0, 1], + // [0, 1, 0, 1]] + // [[0, 1, 2, 3], (values) + // [4, 5, 6, 7], + // [8, 9, 10, 11]], + // - Device 0's local tensor. + // [[0, 0], + // [0, 0], + // [0, 0]] + // [[0, 2], + // [4, 6], + // [8, 10]], + // - Device 1's local tensor. + // [[1, 1], + // [1, 1], + // [1, 1]] + // [[1, 3], + // [5, 7], + // [9, 11]], + // 2. Deduce local output shape: + // - In the logical Reshape, the 1st and 2nd logical axes are fused, + // so are the corresponding local axes. + // - Local output shape: [6] by fusing both axes in shape [3, 2]. + // 3. Run local reshape (reshape from [3, 2] to [6]): + // - Device 0's local output tensor. + // [0, 0, 0, 0, 0, 0] + // [0, 2, 4, 6, 8, 10] + // - Device 1's local output tensor. + // [1, 1, 1, 1, 1, 1] + // [1, 3, 5, 7, 9, 11] + // 4. Determine native output sharding spec by comparing local output tensors and logical tensor. 
+ // - Logical output tensor: + // [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] + // - S[0] with device_mesh = [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] = [0, 1, 0, 1] * (first fused dimension). + // 5. Native output sharding spec: + // - S[0] with device_mesh = [0, 1, 0, 1] * (first fused dimension) = [0, 1, 0, 1] * 3 + // + // Example 2 of determining native output sharding spec: + // - logical input shape: [3, 8] + // - logical output shape: [24] + // - input sharding spec: RS[0], device_mesh=[0, 1, 0, 1] + // 1. Device allocation of the original input tensor: + // - Logical tensor. + // [[0, 0, 1, 1, 0, 0, 1, 1], (device assignment) + // [0, 0, 1, 1, 0, 0, 1, 1], + // [0, 0, 1, 1, 0, 0, 1, 1]] + // [[ 0, 1, 2, 3, 4, 5, 6, 7], (values) + // [ 8, 9, 10, 11, 12, 13, 14, 15], + // [16, 17, 18, 19, 20, 21, 22, 23]] + // - Device 0's local tensor (shape: [3, 4]). + // [[0, 0, 0, 0], + // [0, 0, 0, 0], + // [0, 0, 0, 0]] + // [[ 0, 1, 4, 5], + // [ 8, 9, 12, 13], + // [16, 17, 20, 21]] + // - Device 1's local tensor (shape: [3, 4]). + // [[1, 1, 1, 1], + // [1, 1, 1, 1], + // [1, 1, 1, 1]] + // [[ 2, 3, 6, 7], + // [10, 11, 14, 15], + // [18, 19, 22, 23]] + // 2. Deduce local output shape: + // - In the logical Reshape, the 1st and 2nd logical axes are fused, + // so are the corresponding local axes. + // - Local output shape: [12] by fusing both axes in shape [3, 4]. + // 3. Run local reshape (reshape from [3, 4] to [12]): + // - Device 0's local output tensor . + // [0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21] + // - Device 1's local output tensor . + // [2, 3, 6, 7, 10, 11, 14, 15, 18, 19, 22, 23] + // 4. Determine native output sharding spec from local output tensors. + // - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23] + // - [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1] + // - S[0] with device_mesh = [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] = . + // 5. Native output sharding spec: + // - S[0] with device_mesh = [0, 1, 0, 1] * (first fused dimension) = [0, 1, 0, 1] * 3 + // + // Example 3: + // - logical input shape: [2, 8] + // - logical output shape: [16] + // - input sharding spec: RS[0], device_mesh=[0, 1, 0, 1] + // 1. Device allocation of the original input tensor: + // - Logical tensor. + // [[0, 0, 1, 1, 0, 0, 1, 1], (device assignment) + // [0, 0, 1, 1, 0, 0, 1, 1]] + // [[ 0, 1, 2, 3, 4, 5, 6, 7], (values) + // [ 8, 9, 10, 11, 12, 13, 14, 15]] + // - Device 0's local tensor (shape: [2, 4]). + // [[0, 0, 0, 0], + // [0, 0, 0, 0]] + // [[ 0, 1, 4, 5], + // [ 8, 9, 12, 13]] + // - Device 1's local tensor (shape: [2, 4]). + // [[1, 1, 1, 1], + // [1, 1, 1, 1]] + // [[ 2, 3, 6, 7], + // [10, 11, 14, 15]] + // 2. Deduce local output shape: + // - In the logical Reshape, the 1st and 2nd logical axes are fused, + // so are the corresponding local axes. + // - Local output shape: [8] by fusing both axes in shape [2, 4]. + // 3. Run local reshape (reshape from [2, 4] to [8]): + // - Device 0's local output tensor . + // [ 0, 1, 4, 5, 8, 9, 12, 13] + // - Device 1's local output tensor . + // [ 2, 3, 6, 7, 10, 11, 14, 15] + // 4. Determine native output sharding spec from local output tensors. + // - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + // - [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1] + // - S[0] with device_mesh = [0, 1, 0, 1, 0, 1, 0, 1] = [0, 1, 0, 1] * (first fused dimension). + // 5. 
Native output sharding spec: + // - S[0] with device_mesh = [0, 1, 0, 1] * (first fused dimension) = [0, 1, 0, 1] * 2 + // + // Example 4: + // - logical input shape: [2, 8] + // - logical output shape: [16] + // - input sharding spec: RS[0], device_mesh=[0, 1] + // 1. Device allocation of the original input tensor: + // - Logical tensor. + // [[0, 0, 0, 0, 1, 1, 1, 1], (device assignment) + // [0, 0, 0, 0, 1, 1, 1, 1]] + // [[ 0, 1, 2, 3, 4, 5, 6, 7], (values) + // [ 8, 9, 10, 11, 12, 13, 14, 15]] + // - Device 0's local tensor (shape: [2, 4]). + // [[0, 0, 0, 0], + // [0, 0, 0, 0]] + // [[ 0, 1, 2, 3], + // [ 8, 9, 10, 11]] + // - Device 1's local tensor (shape: [2, 4]). + // [[1, 1, 1, 1], + // [1, 1, 1, 1]] + // [[ 4, 5, 6, 7], + // [12, 13, 14, 15]] + // 2. Deduce local output shape: + // - In the logical Reshape, the 1st and 2nd logical axes are fused, + // so are the corresponding local axes. + // - Local output shape: [8] by fusing both axes in shape [2, 4]. + // 3. Run local reshape (reshape from [2, 4] to [8]): + // - Device 0's local output tensor . + // [ 0, 1, 2, 3, 8, 9, 10, 11] + // - Device 1's local output tensor . + // [ 4, 5, 6, 7, 12, 13, 14, 15] + // 4. Determine native output sharding spec from local output tensors. + // - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + // - [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1] + // - S[0] with device_mesh = [0, 1, 0, 1] = [0, 1] * (first fused dimension). + // 5. Native output sharding spec: + // - S[0] with device_mesh = [0, 1] * (first fused dimension) = [0, 1] * 2 = [0, 1, 0, 1] + + // The output device mesh is the repeats of the original device. + // Let's use Python syntax. If the original device mesh is [0, 1, 0, 1], and + // the first fused dimension is 3, then the output device mesh is [0, 1, 0, 1] * 3. + auto dst_device_mesh = DeviceMesh::Create1D( + src_spec.device_mesh.device_mesh_elements, + src_shape[fused_axis_in_src]); + // Sharding happens in the fusion axis with the new device mesh. + auto dst_spec = TensorPartitionSpec::CreateOneTensorAxisOneDeviceMeshAxisSharding( + dst_shape.size(), dst_device_mesh, fusion_axis_in_dst, /* 1-D mesh */ 0); + return std::make_tuple(true, dst_spec); + } else if (src_spec.HasShard() && (src_spec.GetPartitionAxis() < fused_axis_in_src || src_spec.GetPartitionAxis() > fused_axis_in_src + 1)) { + // It's two-axis fusion but the fused axes is not sharded. + // Example: S[0]RR, shape=[2, 3, 5], device_mesh=[0, 1] -> S[0]R, shape = [2, 15], device_mesh=[0, 1] + auto dst_spec = TensorPartitionSpec::CreateByDropAxes( + src_spec, {fused_axis_in_src + 1}); + return std::make_tuple(true, dst_spec); + } else { + return std::make_tuple(false, TensorPartitionSpec()); + } +} + +// Arguments: +// - device_elements: a vector of device IDs. +// It should only contain unique device IDs or +// repeats of a list of unique device IDs. Otherwise, +// (0, 0) is returned. +// Returns: +// - count per device ID (all device IDs should have the same count) +// - number of unique device IDs +// Examples: +// - [0, 1] -> (2, 1) +// - [0, 1, 2, 0, 1, 2] -> (2, 3) +std::tuple ComputeRepeatAndRepeatStride( + const std::vector& device_elements) { + int64_t first_device_id = device_elements.at(0); + int64_t first_device_id_count = 0; + for (size_t i = 0; i < device_elements.size(); ++i) { + if (device_elements.at(i) == first_device_id) { + ++first_device_id_count; + } + } + size_t repeat_stride = device_elements.size() / first_device_id_count; + + // Check if the device mesh pattern is supported. 
+ // Supported examples: [0, 1, 2] and [0, 1, 0, 1, 0, 1]. + // Unsupported examples: [0, 1, 2, 1, 2, 0] and [0, 1, 2, 0]. + for (size_t repeat = 0; repeat < first_device_id_count; ++repeat) { + for (size_t device_id = 0; device_id < repeat_stride; ++device_id) { + ORT_ENFORCE( + device_elements.at(repeat * repeat_stride + device_id) == device_elements.at(device_id), + "Unsupported device mesh pattern."); + } + } + + // If device_mesh=[0, 1, 2, 0, 1, 2], returns (2, 3), which means + // - each device repeats twice for "2" in (2, 3). + // - there are 3 unique devices for "3" in (2, 3). + return std::make_tuple(first_device_id_count, repeat_stride); +} + +std::tuple ComputeNativeSpecForTwoAxisDecomposition( + const TensorPartitionSpec& src_spec, + const std::vector& src_shape, + const std::vector& dst_shape, + const int64_t decomposed_axis_in_src, + const int64_t decomposition_axis_in_dst) { + // TODO(wechi): use device mesh stride to support non-1 stride. + // Example: S[0], shape=[8], device_mesh=[0, 1] -> S[0]R + // Example: S[0], shape=[8], device_mesh=[0, 1] -> RS[0] + // Example: S[0], shape=[8], device_mesh=[0, 1, 0, 1] -> S[0]R + // Example: S[0], shape=[8], device_mesh=[0, 1, 0, 1] -> RS[0] + // Example: RS[0]R, shape=[8], device_mesh=[0, 1] -> RS[0]RR + // Example: RS[0]R, shape=[8], device_mesh=[0, 1] -> RRS[0]R + if (src_spec.CountShardingAxes() != 1) { + throw std::runtime_error("Too many sharding axes."); + } + if (src_spec.device_mesh.device_mesh_shape.size() != 1) { + throw std::runtime_error("Source device mesh be 1-D."); + } + + if (src_spec.HasNoShard()) { + return std::make_tuple(true, TensorPartitionSpec::CreateAllReplica(dst_shape.size(), src_spec.device_mesh)); + } else if (src_spec.OnlyShardAxis(decomposed_axis_in_src)) { + const int64_t device_stride = src_shape[decomposed_axis_in_src] / src_spec.device_mesh.device_mesh_shape[0]; + if (device_stride >= dst_shape[decomposition_axis_in_dst + 1] && device_stride % dst_shape[decomposition_axis_in_dst + 1] == 0) { + // Since 2nd decomposition dimension is a factor of device stride, + // Sharding happens at 1st decomposition axis in dst. + // device_stride = 10 + // S[0], shape=[20], device=[0, 1] -> S[0]R, shape=[2, 10], device=[0, 1] + // + // device_stride = 8 + // S[0], shape=[16], device=[0, 1] -> RS[0], shape=[1, 16], device=[0, 1] + // + // device_stride = 8 + // S[0], shape=[16], device=[0, 1] -> S[0]R, shape=[4, 4], device=[0, 1] + std::vector dst_axis_specs; + for (size_t src_axis = 0; src_axis < src_shape.size(); ++src_axis) { + if (src_axis != decomposed_axis_in_src) { + // Sharding spec is copied if the axis is not decomposed. + // E.g, shape [5, 6] -> Reshape -> shape [5, 3, 2] + // The spec for "5" is copied. + dst_axis_specs.push_back(AxisPartitionSpec::CreateCopy(src_spec.GetAxisSpec(src_axis))); + } else if (dst_shape[decomposition_axis_in_dst] == 1) { + // S[0] -> RS[0] + // E.g., shape [5] -> Reshape -> shape [1, 5] + // The spec for "5" is copied and "1" is replica. + // This reshape only adds a dummy new axis without affecting + // the underlying sharding status. + dst_axis_specs.push_back(AxisPartitionSpec::CreateReplica()); + dst_axis_specs.push_back(AxisPartitionSpec::CreateShard(0)); + } else { + // S[0] -> S[0]R + // E.g., shape [5] -> Reshape -> shape [5, 1] + dst_axis_specs.push_back(AxisPartitionSpec::CreateShard(0)); + dst_axis_specs.push_back(AxisPartitionSpec::CreateReplica()); + } + } + // Now, we know sharding happens at decomposed_axis_in_src axis in destination tensor. 
+ // - effective_device_stride along decomposed_axis_in_src: device_stride / dst_shape[decomposed_axis_in_src + 1] + // - The original device patterns repeats: dst_shape[decomposed_axis_in_src] / effective_device_stride times. + const int64_t effective_device_stride = device_stride / dst_shape[decomposed_axis_in_src + 1]; + // How many times a device ID changes along decomposed_axis_in_src axis in destination tensor. + const int64_t number_of_device_changes = dst_shape[decomposed_axis_in_src] / effective_device_stride; + if ((size_t)number_of_device_changes != src_spec.device_mesh.device_mesh_elements.size()) { + throw std::runtime_error("Not supported. Resharding is required."); + } + auto dst_device_mesh = CreateInterleaveDeviceMesh( + src_spec.device_mesh, 1); + return std::make_tuple(true, TensorPartitionSpec::Create(dst_axis_specs, dst_device_mesh)); + } else if (dst_shape[decomposition_axis_in_dst + 1] > device_stride && dst_shape[decomposition_axis_in_dst + 1] % device_stride == 0) { + // Since 2nd decomposition dimension is a multiple of device stride, + // sharding happens at 2nd decomposition axis in dst. + // stride = 4 + // S[0], shape=[8], device=[0, 1] -> S[0]R, shape=[4, 2], device=[0, 1] + // + // stride = 8 + // S[0], shape=[32], device=[0, 1, 0, 1] -> RS[0], shape=[2, 16], device=[0, 1] + std::vector dst_axis_specs; + // How many times a device ID appears. + // E.g., [0, 1, 0, 1, 0, 1] -> 3 + int64_t repeats = 0; + // Number of unique devices. + // E.g., [0, 1, 0, 1, 0, 1] -> 2 + int64_t repeat_stride = 0; + DeviceMesh dst_device_mesh; + std::tie(repeats, repeat_stride) = ComputeRepeatAndRepeatStride(src_spec.device_mesh.device_mesh_elements); + for (size_t src_axis = 0; src_axis < src_shape.size(); ++src_axis) { + if (src_axis != decomposed_axis_in_src) { + dst_axis_specs.push_back(AxisPartitionSpec::CreateCopy(src_spec.GetAxisSpec(src_axis))); + } else if (dst_shape[decomposition_axis_in_dst] == 1) { + // S[0] -> RS[0] + // E.g., shape [5] -> Reshape -> shape [1, 5] + // In this case "1" is added as a dummy axis without affecting + // the underlying sharding status, so we just copy the spec + // for input "5" to output "5". + dst_axis_specs.push_back(AxisPartitionSpec::CreateReplica()); + dst_axis_specs.push_back(AxisPartitionSpec::CreateShard(0)); + dst_device_mesh = src_spec.device_mesh; + } else if (dst_shape[decomposition_axis_in_dst + 1] == 1) { + // S[0] -> S[0]R + // E.g., shape [5] -> Reshape -> shape [5, 1] + // In this case "1" is added as a dummy axis without affecting + // the underlying sharding status, so we just copy the spec + // for input "5" to output "5". + dst_axis_specs.push_back(AxisPartitionSpec::CreateShard(0)); + dst_axis_specs.push_back(AxisPartitionSpec::CreateReplica()); + dst_device_mesh = src_spec.device_mesh; + } else if (repeats == 1 && dst_shape[decomposition_axis_in_dst + 1] == device_stride * repeat_stride) { + // S[0] -> RS[0] + dst_axis_specs.push_back(AxisPartitionSpec::CreateReplica()); + dst_axis_specs.push_back(AxisPartitionSpec::CreateShard(0)); + dst_device_mesh = src_spec.device_mesh; + } else if (repeats != 1 && dst_shape[decomposition_axis_in_dst + 1] % (device_stride * repeat_stride) == 0) { + // S[0] -> RS[0] + dst_axis_specs.push_back(AxisPartitionSpec::CreateReplica()); + dst_axis_specs.push_back(AxisPartitionSpec::CreateShard(0)); + // Extract [0, 1] from [0, 1, 0, 1]. 
+ std::vector unique_device_mesh_elements( + src_spec.device_mesh.device_mesh_elements.begin(), + src_spec.device_mesh.device_mesh_elements.begin() + repeat_stride); + // Compute new repeats. + // Example of repeats change from 2 to 1: + // [16]-shape tensor [2, 8]-shape tensor + // with 1-D device mesh -> Reshape -> with 1-D device mesh + // [0, 1, 0, 1] (repeats=2) [0, 1] (repeats=1) + const int64_t new_repeat = dst_shape[decomposition_axis_in_dst + 1] / (device_stride * repeat_stride); + dst_device_mesh.device_mesh_shape.push_back(repeat_stride); + dst_device_mesh.device_mesh_elements = RepeatVector(unique_device_mesh_elements, new_repeat); + } else { + throw std::runtime_error("Not supported. Resharding is required."); + } + } + return std::make_tuple(true, TensorPartitionSpec::Create(dst_axis_specs, dst_device_mesh)); + } else { + // Not supported. Resharding is required. + return std::make_tuple(false, TensorPartitionSpec()); + } + } else { + // Source tensor is sharded on non-decomposed axis. + std::vector dst_axis_specs; + for (size_t src_axis = 0; src_axis < src_shape.size(); ++src_axis) { + if (src_axis != decomposed_axis_in_src) { + dst_axis_specs.push_back(AxisPartitionSpec::CreateCopy(src_spec.GetAxisSpec(src_axis))); + } else { + // R -> RR + dst_axis_specs.push_back(AxisPartitionSpec::CreateReplica()); + dst_axis_specs.push_back(AxisPartitionSpec::CreateReplica()); + } + } + + return std::make_tuple(true, TensorPartitionSpec::Create(dst_axis_specs, src_spec.device_mesh)); + } +} + +// Arguments: +// global_data_shape: logical shape of Reshape's 1st input. +// global_shape_span: logical content of Reshape's 2nd input. +// Returns: +// logical shape of Reshape's output. +inline TensorShape InferDistributedReshapeLogicalOutputShape( + const TensorShape& global_data_shape, + const gsl::span& global_shape_span, + const int64_t allow_zero) { + return onnxruntime::cuda::InferReshapeOutputShape( + global_data_shape, + global_shape_span, + allow_zero); +} + +template +DistributedReshape::DistributedReshape(const OpKernelInfo& info) : DistributedKernel(info) { + allow_zero_ = info.GetAttrOrDefault("allowzero", static_cast(0)); +} + +template +Status DistributedReshape::ComputeInternal(OpKernelContext* context) const { + ORT_ENFORCE(context != nullptr); + auto data_tensor = context->Input(0); + auto shape_tensor = context->Input(1); + const auto& data_sharding_spec = input_shard_specs_.at(0); + const auto& shape_sharding_spec = input_shard_specs_.at(1); + const auto& output_sharding_spec = output_shard_specs_.at(0); + + if (data_sharding_spec.HasNoShard() && shape_sharding_spec.HasNoShard() && output_sharding_spec.HasNoShard()) { + // Case: all inputs and outputs are not sharded. + const auto target_shape = onnxruntime::cuda::InferReshapeOutputShape( + data_tensor, + shape_tensor, + allow_zero_); + + auto output_tensor = context->Output(0, target_shape); + + // Copy data from input from output. + return FuncReshape( + this, + context, + data_tensor, + shape_tensor, + allow_zero_, + output_tensor); + } else { + ORT_ENFORCE(shape_sharding_spec.HasNoShard(), + "Shape tensor should not be sharded because it will trigger communication. 
" + "If sharding shape is needed, please request this feature on Github."); + ORT_ENFORCE(shape_tensor->Shape().NumDimensions() == 1, "Shape must be a 1-D tensor."); + const auto original_data_shape = ComputeOriginShape(data_tensor->Shape(), data_sharding_spec); + const auto original_output_shape = InferDistributedReshapeLogicalOutputShape( + original_data_shape, + shape_tensor->template DataAsSpan(), + allow_zero_); + + // TODO: remove below code after replacing std::vector with TensorShape in other APIs. + std::vector src_shape(original_data_shape.GetDims().begin(), original_data_shape.GetDims().end()); + std::vector dst_shape(original_output_shape.GetDims().begin(), original_output_shape.GetDims().end()); + + // Case: Two axis fusion + bool is_two_axis_fusion = false; + size_t two_axis_fusion_axis_in_dst = 0; + size_t two_axis_fusion_first_fused_axis_in_src = 0; + size_t two_axis_fusion_fused_axis_count = 0; + std::tie( + is_two_axis_fusion, + two_axis_fusion_axis_in_dst, + two_axis_fusion_first_fused_axis_in_src, + two_axis_fusion_fused_axis_count) = IsTwoAxisFusion(src_shape, dst_shape); + + if (is_two_axis_fusion) { + bool is_supported = false; + TensorPartitionSpec native_dst_spec; + std::tie(is_supported, native_dst_spec) = ComputeNativeSpecForTwoAxisFusion( + data_sharding_spec, + src_shape, + dst_shape, + two_axis_fusion_first_fused_axis_in_src, + two_axis_fusion_axis_in_dst); + + if (is_supported && native_dst_spec == output_sharding_spec) { + // In this case, we can apply Reshape with local shape on local tensor without resharding. + // Those local output tensors match the output tensors defined by + // sharding the logical tensor following the native sharding spec. + TensorShape local_shape = ComputeShardShape(original_output_shape, native_dst_spec); + auto output_tensor = context->Output(0, local_shape); + return FuncReshape( + this, + context, + data_tensor, + shape_tensor, + allow_zero_, + output_tensor); + } else { + // TODO: Reshape outputs from `native_dst_spec` to `output_sharding_spec`. + return Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED, "Encounter unsupported reshape pattern."); + } + } + + // Case: Two axis decomposition + bool is_two_axis_decomposition = false; + size_t two_axis_decomposition_decomposed_axis_in_src = 0; + size_t two_axis_decomposition_first_factor_axis_in_dst = 0; + size_t two_axis_decomposition_factor_axis_count_in_dst = 0; + std::tie( + is_two_axis_decomposition, + two_axis_decomposition_decomposed_axis_in_src, + two_axis_decomposition_first_factor_axis_in_dst, + two_axis_decomposition_factor_axis_count_in_dst) = IsTwoAxisDecomposition(src_shape, dst_shape); + + if (is_two_axis_decomposition) { + bool is_supported = false; + TensorPartitionSpec native_dst_spec; + std::tie(is_supported, native_dst_spec) = ComputeNativeSpecForTwoAxisDecomposition( + data_sharding_spec, + src_shape, + dst_shape, + two_axis_decomposition_decomposed_axis_in_src, + two_axis_decomposition_first_factor_axis_in_dst); + + if (is_supported && native_dst_spec == output_sharding_spec) { + // In this case, we can apply Reshape with local shape on local tensor without resharding. + // Those local output tensors match the output tensors defined by + // sharding the logical tensor following the native sharding spec. 
+ TensorShape local_shape = ComputeShardShape(original_output_shape, native_dst_spec); + auto output_tensor = context->Output(0, local_shape); + return FuncReshape( + this, + context, + data_tensor, + shape_tensor, + allow_zero_, + output_tensor); + } else { + // TODO: Reshape outputs from `native_dst_spec` to `output_sharding_spec`. + return Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED, "Encounter unsupported reshape pattern."); + } + } + } + + return Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED, "Encounter unsupported reshape pattern."); +} + +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReshape, + kMSDomain, + 1, + int64_t, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .AllocateInputsContiguously() + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReshape); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReshape, + kMSDomain, + 1, + float, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .AllocateInputsContiguously() + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReshape); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReshape, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .AllocateInputsContiguously() + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReshape); + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_reshape.h b/onnxruntime/contrib_ops/cuda/collective/distributed_reshape.h new file mode 100644 index 0000000000000..e251c3cdc38d7 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_reshape.h @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "sharding_spec.h" +#include "sharding.h" +#include "core/framework/tensor_shape.h" +#include "core/providers/cuda/tensor/reshape.h" +#include "core/providers/cuda/cuda_kernel.h" + +#include +#include +#include +#include +#include +#include + +#pragma once + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +template +class DistributedReshape final : public DistributedKernel { + public: + explicit DistributedReshape(const OpKernelInfo& info); + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + int64_t allow_zero_; +}; + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_squeeze.cc b/onnxruntime/contrib_ops/cuda/collective/distributed_squeeze.cc new file mode 100644 index 0000000000000..c3cae2d0bf8ca --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_squeeze.cc @@ -0,0 +1,96 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Distributed computation. +#include "distributed_squeeze.h" +#include "mpi_include.h" + +// ORT system. 
+#include "core/providers/cuda/cuda_check_memory.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) +template +DistributedSqueeze::DistributedSqueeze(const OpKernelInfo& info) : DistributedKernel(info) { +} + +template +Status DistributedSqueeze::ComputeInternal(OpKernelContext* context) const { + auto input_tensor = context->Input(0); + auto axes_tensor = context->Input(1); + auto axes_span = axes_tensor->DataAsSpan(); + + const TensorPartitionSpec& input_spec = input_shard_specs_[0]; + const TensorPartitionSpec& axes_spec = input_shard_specs_[1]; + const TensorPartitionSpec& output_spec = output_shard_specs_[0]; + + ORT_ENFORCE(axes_spec.HasNoShard(), "Axes tensor cannot be sharded."); + + // Non-negative collection of axes to drop. + std::vector axes; + for (const auto axis : axes_span) { + axes.push_back(axis >= 0 ? axis : axis + input_tensor->Shape().NumDimensions()); + } + // Shape after dropping axes. + auto dims = input_tensor->Shape().AsShapeVector(); + // Sort in descending order so that we can drop axes from the end. + std::sort(axes.begin(), axes.end(), [](Tind a, Tind b) { return a > b; }); + for (const auto axis : axes) { + ORT_ENFORCE(input_tensor->Shape()[axis] == 1, "Cannot squeeze non-singleton dimension."); + dims.erase(dims.begin() + axis); + } + auto native_output_spec = TensorPartitionSpec::CreateByDropAxes( + input_spec, + axes); + ORT_ENFORCE( + output_spec == native_output_spec, + "Re-sharding is required but not supported yet for this case."); + auto output_tensor = context->Output(0, dims); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync( + output_tensor->MutableDataRaw(), + input_tensor->DataRaw(), + input_tensor->SizeInBytes(), cudaMemcpyDeviceToDevice, Stream(context))); + return Status::OK(); +} + +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedSqueeze, + kMSDomain, + 1, + float, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 1) + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + DistributedSqueeze); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedSqueeze, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 1) + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + DistributedSqueeze); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedSqueeze, + kMSDomain, + 1, + int64_t, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 1) + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + DistributedSqueeze); + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_squeeze.h b/onnxruntime/contrib_ops/cuda/collective/distributed_squeeze.h new file mode 100644 index 0000000000000..5b81d9c4792bd --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_squeeze.h @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
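For reference, the shape bookkeeping in `DistributedSqueeze::ComputeInternal` above amounts to normalizing negative axes and erasing singleton dimensions from the back. A standalone sketch under that reading follows; the real kernel additionally drops the same axes from the input sharding spec via `TensorPartitionSpec::CreateByDropAxes` and requires the result to equal the requested output spec:

```cpp
#include <algorithm>
#include <cstdint>
#include <functional>
#include <stdexcept>
#include <vector>

// Illustrative restatement of the squeeze shape computation.
std::vector<int64_t> SqueezeDims(std::vector<int64_t> dims, std::vector<int64_t> axes) {
  const int64_t rank = static_cast<int64_t>(dims.size());
  // Normalize negative axes, then erase from the back so earlier indices stay valid.
  for (auto& axis : axes) {
    if (axis < 0) axis += rank;
  }
  std::sort(axes.begin(), axes.end(), std::greater<int64_t>());
  for (int64_t axis : axes) {
    if (dims.at(axis) != 1) {
      throw std::invalid_argument("Cannot squeeze non-singleton dimension.");
    }
    dims.erase(dims.begin() + axis);
  }
  return dims;
}

// Example: squeezing axes {0, -1} of shape [1, 4, 8, 1] yields [4, 8].
```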
+#pragma once + +#include +#include +#include +#include +#include +#include + +#include "sharding.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +template +class DistributedSqueeze final : public DistributedKernel { + public: + explicit DistributedSqueeze(const OpKernelInfo& info); + + Status ComputeInternal(OpKernelContext* context) const override; +}; + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_unsqueeze.cc b/onnxruntime/contrib_ops/cuda/collective/distributed_unsqueeze.cc new file mode 100644 index 0000000000000..a78f19101b0da --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_unsqueeze.cc @@ -0,0 +1,95 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Distributed computation. +#include "distributed_unsqueeze.h" +#include "mpi_include.h" + +// ORT system. +#include "core/providers/cuda/cuda_check_memory.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) +template +DistributedUnsqueeze::DistributedUnsqueeze(const OpKernelInfo& info) : DistributedKernel(info) { +} + +template +Status DistributedUnsqueeze::ComputeInternal(OpKernelContext* context) const { + auto input_tensor = context->Input(0); + auto axes_tensor = context->Input(1); + auto axes_span = axes_tensor->DataAsSpan(); + + const TensorPartitionSpec& input_spec = input_shard_specs_[0]; + const TensorPartitionSpec& axes_spec = input_shard_specs_[1]; + const TensorPartitionSpec& output_spec = output_shard_specs_[0]; + + ORT_ENFORCE(axes_spec.HasNoShard(), "Axes tensor cannot be sharded."); + + std::vector axes(axes_span.begin(), axes_span.end()); + std::sort(axes.begin(), axes.end()); + auto dims = input_tensor->Shape().AsShapeVector(); + auto native_output_spec = input_spec; + for (auto axis : axes) { + if (axis < 0) { + axis += input_tensor->Shape().NumDimensions() + 1; + } + dims.insert(dims.begin() + axis, 1); + native_output_spec = TensorPartitionSpec::CreateByInsertOneAxis( + native_output_spec, + axis); + } + ORT_ENFORCE( + output_spec == native_output_spec, + "Re-sharding is required but not supported yet for this case. 
", + "Specified: ", output_spec.ToString(), + " Actual: ", native_output_spec.ToString()); + auto output_tensor = context->Output(0, dims); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync( + output_tensor->MutableDataRaw(), + input_tensor->DataRaw(), + input_tensor->SizeInBytes(), cudaMemcpyDeviceToDevice, Stream(context))); + return Status::OK(); +} + +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedUnsqueeze, + kMSDomain, + 1, + float, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 1) + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + DistributedUnsqueeze); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedUnsqueeze, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 1) + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + DistributedUnsqueeze); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedUnsqueeze, + kMSDomain, + 1, + int64_t, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 1) + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + DistributedUnsqueeze); + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_unsqueeze.h b/onnxruntime/contrib_ops/cuda/collective/distributed_unsqueeze.h new file mode 100644 index 0000000000000..005093ef78fb9 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_unsqueeze.h @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "sharding.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +template +class DistributedUnsqueeze final : public DistributedKernel { + public: + explicit DistributedUnsqueeze(const OpKernelInfo& info); + + Status ComputeInternal(OpKernelContext* context) const override; +}; + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/sharding.cc b/onnxruntime/contrib_ops/cuda/collective/sharding.cc index 7d106fd75e2d0..b6b509023a1a9 100644 --- a/onnxruntime/contrib_ops/cuda/collective/sharding.cc +++ b/onnxruntime/contrib_ops/cuda/collective/sharding.cc @@ -5,6 +5,8 @@ #include "mpi_include.h" #include "sharding_spec.h" +#include +#include #include "core/providers/cpu/tensor/slice.h" #include "core/providers/cuda/tensor/slice.h" #include "core/providers/cuda/math/matmul.h" @@ -28,7 +30,7 @@ void GatherTensor( const Tensor* tensor, Tensor* gathered) { const int64_t shard_axis = spec.GetPartitionAxis(); - const int64_t shard_count = spec.GetPartitionCount(shard_axis); + const int64_t shard_count = spec.GetUniqueDeviceCount(shard_axis); FuncAllGather( nccl_kernel, @@ -49,7 +51,7 @@ std::unique_ptr GatherTensor( const TensorPartitionSpec& spec, const Tensor* tensor) { const int64_t shard_axis = spec.GetPartitionAxis(); - const int64_t shard_count = spec.GetPartitionCount(shard_axis); + const int64_t shard_count = spec.GetUniqueDeviceCount(shard_axis); TensorShape gathered_shape(tensor->Shape()); gathered_shape[shard_axis] *= shard_count; @@ -80,7 +82,7 @@ void ShardTensor( const Tensor* tensor, Tensor* shard_tensor) { const int64_t shard_axis = spec.GetPartitionAxis(); - const int64_t shard_count = spec.GetPartitionCount(shard_axis); + const int64_t shard_count = 
spec.GetUniqueDeviceCount(shard_axis); TensorShape shard_shape = ComputeShardShape( tensor->Shape(), shard_axis, @@ -116,7 +118,7 @@ std::unique_ptr ShardTensor( TensorShape shard_shape = ComputeShardShape( tensor->Shape(), spec.GetPartitionAxis(), - spec.GetPartitionCount(spec.GetPartitionAxis())); + spec.GetUniqueDeviceCount(spec.GetPartitionAxis())); auto shard_buffer = Tensor::Create(tensor->DataType(), shard_shape, alloc); // Shard with pre-allocated buffer. @@ -237,16 +239,55 @@ void ReshardTensor( } DistributedKernel::DistributedKernel(const OpKernelInfo& info) : NcclKernel(info) { - std::vector device_mesh_elements = info.GetAttrsOrDefault("device_mesh_elements"); - std::vector device_mesh_shape = info.GetAttrsOrDefault("device_mesh_shape"); - std::vector input_shard_specs = info.GetAttrsOrDefault("input_shard_specs"); - std::vector output_shard_specs = info.GetAttrsOrDefault("output_shard_specs"); + // input_device_mesh_shapes[i] is the shape of device mesh for the i-th input. + // E.g., device_mesh_shapes = ["[2]", "[1]"] means the first input is + // stored on a 1-D mesh with 2 devices and the second input on another 1-D + // mesh with 1 device. + std::vector attr_input_device_mesh_shapes; + ORT_ENFORCE(info.GetAttrs("input_device_mesh_shapes", attr_input_device_mesh_shapes).IsOK()); + // input_device_mesh_elements[i] is the flattened device mesh for the i-th input. + // Note that its actual shape is input_device_mesh_shapes[i]. + // Example: + // Assume + // device_mesh_shapes = ["[2]", "[1]"] + // device_mesh_elements = ["[0,1]", "[0]"] + // Then the first input is stored on a 1-D mesh with 2 devices and the second + // input on another 1-D mesh with 1 device. + std::vector attr_input_device_mesh_elements; + ORT_ENFORCE(info.GetAttrs("input_device_mesh_elements", attr_input_device_mesh_elements).IsOK()); + + // input_shard_specs[i] is the sharding spec of the i-th input; e.g., + // "RR" if the i-th input is not sharded. + std::vector input_shard_specs; + ORT_ENFORCE(info.GetAttrs("input_shard_specs", input_shard_specs).IsOK()); + + ORT_ENFORCE(attr_input_device_mesh_shapes.size() == attr_input_device_mesh_elements.size()); + ORT_ENFORCE(attr_input_device_mesh_shapes.size() == input_shard_specs.size()); + + // Begin parsing sharding metadata for inputs. 
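For illustration before the parsing loop below, here is a standalone sketch of what the per-input metadata described above looks like once parsed; ParseMeshString is a hypothetical stand-in for the ParseStringAsInt64Vector helper added later in this diff, and the example values mirror the comment above:

#include <cstdint>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical stand-in for ParseStringAsInt64Vector: "[0, 1]" -> {0, 1}.
std::vector<int64_t> ParseMeshString(const std::string& s) {
  std::vector<int64_t> out;
  std::istringstream iss(s.substr(1, s.size() - 2));  // strip '[' and ']'
  for (int64_t v = 0; iss >> v;) {
    out.push_back(v);
    if (iss.peek() == ',') iss.ignore();
  }
  return out;
}

// First input lives on a 1-D mesh of two devices, second on a 1-D mesh of one device:
// ParseMeshString("[2]")   -> {2}     (mesh shape of input 0)
// ParseMeshString("[0,1]") -> {0, 1}  (mesh elements of input 0)
// ParseMeshString("[1]")   -> {1}     (mesh shape of input 1)
// ParseMeshString("[0]")   -> {0}     (mesh elements of input 1)
// A real kernel pairs these with per-input shard specs such as "S[0]R" or "RR".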
for (size_t i = 0; i < input_shard_specs.size(); ++i) { + auto device_mesh_shape = ParseStringAsInt64Vector(attr_input_device_mesh_shapes[i]); + auto device_mesh_elements = ParseStringAsInt64Vector(attr_input_device_mesh_elements[i]); auto spec = CreateTensorPartitionSpec(input_shard_specs[i], device_mesh_shape, device_mesh_elements); input_shard_specs_.push_back(spec); } + + std::vector attr_output_device_mesh_shapes; + ORT_ENFORCE(info.GetAttrs("output_device_mesh_shapes", attr_output_device_mesh_shapes).IsOK()); + + std::vector attr_output_device_mesh_elements; + ORT_ENFORCE(info.GetAttrs("output_device_mesh_elements", attr_output_device_mesh_elements).IsOK()); + + std::vector output_shard_specs; + ORT_ENFORCE(info.GetAttrs("output_shard_specs", output_shard_specs).IsOK()); + + ORT_ENFORCE(attr_output_device_mesh_shapes.size() == attr_output_device_mesh_elements.size()); + ORT_ENFORCE(attr_output_device_mesh_shapes.size() == output_shard_specs.size()); + for (size_t i = 0; i < output_shard_specs.size(); ++i) { + auto device_mesh_shape = ParseStringAsInt64Vector(attr_output_device_mesh_shapes[i]); + auto device_mesh_elements = ParseStringAsInt64Vector(attr_output_device_mesh_elements[i]); auto spec = CreateTensorPartitionSpec(output_shard_specs[i], device_mesh_shape, device_mesh_elements); output_shard_specs_.push_back(spec); } diff --git a/onnxruntime/contrib_ops/cuda/collective/sharding_spec.cc b/onnxruntime/contrib_ops/cuda/collective/sharding_spec.cc index f1d399077e37b..20c936e1b6718 100644 --- a/onnxruntime/contrib_ops/cuda/collective/sharding_spec.cc +++ b/onnxruntime/contrib_ops/cuda/collective/sharding_spec.cc @@ -27,6 +27,28 @@ void ValidateAxisIndex(const int64_t axis, const int64_t rank) { ORT_ENFORCE(adjusted_axis >= 0 && adjusted_axis < rank, "axis,", axis, ", should be in [", -rank, ",", rank, ")."); } +std::vector ParseStringAsInt64Vector(const std::string& str) { + if (str.empty() || str.front() != '[' || str.back() != ']') { + throw std::invalid_argument("Invalid input string format"); + } + // Parsed vector. + // If input is "[0, 1, 2]", result should be {0, 1, 2}. + std::vector result; + // Skip '[' and ']' + std::istringstream iss(str.substr(1, str.size() - 2)); + + // Extract integers separated by ',' or whitespaces. + int64_t num = -1; + while (/* Read one number at a time */ iss >> num) { + result.push_back(num); + // Skip the comma + if (iss.peek() == ',') { + iss.ignore(); + } + } + return result; +} + DeviceMesh CreateDeviceMesh( std::vector device_mesh_shape, std::vector device_mesh_elements) { @@ -107,7 +129,7 @@ TensorShape ComputeOriginShape(const TensorShape& shard_shape, const TensorParti } TensorShape shape(shard_shape); const int64_t axis = spec.GetPartitionAxis(); - shape[axis] *= spec.GetPartitionCount(axis); + shape[axis] *= spec.GetUniqueDeviceCount(axis); return shape; } @@ -118,7 +140,15 @@ TensorShape ComputeShardShape(const TensorShape& shape, const TensorPartitionSpe return shard_shape; } const int64_t axis = spec.GetPartitionAxis(); - shard_shape[axis] /= spec.GetPartitionCount(axis); + const int64_t unique_device_count = spec.GetUniqueDeviceCount(axis); + ORT_ENFORCE(shard_shape[axis] % unique_device_count == 0, "Sharded axis' dimension must be divisible by the number of shards."); + // If a [8, 16]-tensor is sharded by device mesh [0, 1, 0, 1] along axis=1 (2nd axis), + // the local tensors on device 0 & 1 have the same shape [8, 8 (from 16/2)] instead of + // [8, 4 (from 16/4)].
The reason is that + // - First, the original tensor is split into 4 sub-tensors [8, 4] along the 2nd axis. + // - The 1st and 3rd sub-tensors are concatenated along axis=1 into one tensor on device 0. + // - The 2nd and 4th sub-tensors are concatenated along axis=1 into one tensor on device 1. + shard_shape[axis] /= unique_device_count; return shard_shape; } @@ -180,7 +210,7 @@ bool CanShard(const TensorShape& shape, const TensorPartitionSpec& spec) { if (axis < 0 || gsl::narrow(axis) >= shape.NumDimensions()) { return false; } - if (shape[axis] % spec.GetPartitionCount(axis) != 0) { + if (shape[axis] % spec.GetDeviceCount(axis) != 0) { return false; } return true; diff --git a/onnxruntime/contrib_ops/cuda/collective/sharding_spec.h b/onnxruntime/contrib_ops/cuda/collective/sharding_spec.h index 0f5ef6927a545..5abc50a61c9a3 100644 --- a/onnxruntime/contrib_ops/cuda/collective/sharding_spec.h +++ b/onnxruntime/contrib_ops/cuda/collective/sharding_spec.h @@ -76,6 +76,43 @@ class DeviceMesh { void Print() const { std::cout << ToString() << std::endl; } + + static DeviceMesh Create1D(std::vector device_mesh_elements, size_t repeats = 1) { + DeviceMesh device_mesh; + device_mesh.device_mesh_shape.push_back(device_mesh_elements.size() * repeats); + for (size_t i = 0; i < repeats; ++i) { + device_mesh.device_mesh_elements.insert( + device_mesh.device_mesh_elements.end(), + device_mesh_elements.begin(), + device_mesh_elements.end()); + } + return device_mesh; + } + + // If the two meshes have the same shape and elements, return true. + // Otherwise, return false. + bool operator==(const DeviceMesh& other) const { + if (device_mesh_shape.size() != other.device_mesh_shape.size() || + device_mesh_elements.size() != other.device_mesh_elements.size()) { + return false; + } + + for (size_t i = 0; i < device_mesh_elements.size(); ++i) { + if (device_mesh_elements.at(i) != other.device_mesh_elements.at(i)) { + return false; + } + } + for (size_t i = 0; i < device_mesh_shape.size(); ++i) { + if (device_mesh_shape.at(i) != other.device_mesh_shape.at(i)) { + return false; + } + } + return true; + } + + bool operator!=(const DeviceMesh& other) const { + return !(*this == other); + } }; class AxisPartitionSpec { @@ -114,10 +151,14 @@ class AxisPartitionSpec { return AxisPartitionSpec(Condition::Shard, device_mesh_axis); } + static AxisPartitionSpec CreateCopy(const AxisPartitionSpec& spec) { + return AxisPartitionSpec(spec.cond, spec.device_mesh_axis); + } + // A normal ctor. // TODO(wechi): Consider to hide it and revise the `public` members/functions // exposed to the user. - AxisPartitionSpec(Condition cond_, int device_mesh_axis_) : device_mesh_axis(device_mesh_axis_), cond(cond_) {} + AxisPartitionSpec(Condition cond_, int device_mesh_axis_) : cond(cond_), device_mesh_axis(device_mesh_axis_) {} // Helper to debug and generate error message; e.g., // "RS[0]". @@ -132,6 +173,14 @@ class AxisPartitionSpec { void Print() const { std::cout << ToString() << std::endl; } + + bool operator==(const AxisPartitionSpec& other) const { + return cond == other.cond && device_mesh_axis == other.device_mesh_axis; + } + + bool operator!=(const AxisPartitionSpec& other) const { + return !(*this == other); + } }; // Return true if `axis` is a valid axis index for a tensor of rank `rank`. @@ -193,6 +242,42 @@ class TensorPartitionSpec { // const TensorPartitionSpec& spec, int64_t new_shard_axis) { // } + // Copy-construct `spec` but with all tensor axes replicated.
+ // The new spec have the same number of axis specs and the same device mesh. + static TensorPartitionSpec CreateAllReplica( + const size_t rank, const DeviceMesh& device_mesh) { + std::vector axis_specs(rank, AxisPartitionSpec::CreateReplica()); + return TensorPartitionSpec::Create(axis_specs, device_mesh); + } + + static TensorPartitionSpec CreateOneTensorAxisOneDeviceMeshAxisSharding( + const size_t rank, const DeviceMesh& device_mesh, const size_t tensor_axis, const size_t device_mesh_axis) { + std::vector axis_specs(rank, AxisPartitionSpec::CreateReplica()); + axis_specs[tensor_axis] = AxisPartitionSpec::CreateShard(device_mesh_axis); + return TensorPartitionSpec::Create(axis_specs, device_mesh); + } + + static TensorPartitionSpec CreateByDropAxes( + const TensorPartitionSpec& spec, const std::vector& axes_to_drop) { + std::vector axis_specs; + for (size_t i = 0; i < spec.axis_specs.size(); ++i) { + if (std::find(axes_to_drop.begin(), axes_to_drop.end(), i) != axes_to_drop.end()) { + // This axis, i, is in axes_to_drop. Let's not copy its spec. + continue; + } + axis_specs.push_back(spec.axis_specs[i]); + } + return TensorPartitionSpec::Create(axis_specs, spec.device_mesh); + } + + static TensorPartitionSpec CreateByInsertOneAxis( + const TensorPartitionSpec& spec, + const size_t axis_to_insert) { + std::vector axis_specs(spec.axis_specs); + axis_specs.insert(axis_specs.begin() + axis_to_insert, AxisPartitionSpec::CreateReplica()); + return TensorPartitionSpec::Create(axis_specs, spec.device_mesh); + } + // Helper to debug and generate error message; e.g., // "TensorPartitionSpec{RS[0], Device Mesh: DeviceMesh{Shape: [4,], Elements: [0,1,2,3,]}}". std::string ToString() const { @@ -303,7 +388,7 @@ class TensorPartitionSpec { // Return the number of shards along the first sharded tensor axis. // This value matches the number of devices along the associated mesh axis. // Return 1 if there is no sharding. - int64_t GetPartitionCount(int64_t axis) const { + int64_t GetDeviceCount(int64_t axis) const { ValidateAxisIndex(axis, Rank()); auto axis_spec = GetAxisSpec(axis); if (axis_spec.cond == AxisPartitionSpec::Condition::Replica) { @@ -312,8 +397,42 @@ class TensorPartitionSpec { return device_mesh.device_mesh_shape.at(axis_spec.device_mesh_axis); } } + + // Similar to GetDeviceCount(), but returns the number of unique devices + // along the first sharded tensor axis. + int64_t GetUniqueDeviceCount(int64_t axis) const { + ValidateAxisIndex(axis, Rank()); + auto axis_spec = GetAxisSpec(axis); + if (axis_spec.cond == AxisPartitionSpec::Condition::Replica) { + return 1; + } else { + std::set device_ids( + device_mesh.device_mesh_elements.begin(), + device_mesh.device_mesh_elements.end()); + return device_ids.size(); + } + } + + bool operator==(const TensorPartitionSpec& other) const { + if (axis_specs.size() != other.axis_specs.size()) { + return false; + } + for (size_t i = 0; i < axis_specs.size(); ++i) { + if (!(axis_specs.at(i) == other.axis_specs.at(i))) { + return false; + } + } + return device_mesh == other.device_mesh; + } + + bool operator!=(const TensorPartitionSpec& other) const { + return !(*this == other); + } }; +// Parse "[0, 1, 2, 3]" as std::vector{0, 1, 2, 3}. 
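A standalone sketch of the shard-shape rule that GetUniqueDeviceCount enables (plain C++ with illustrative names, not the ORT API; the string-parsing helper it pairs with is declared just below): with mesh elements [0, 1, 0, 1] the unique device count is 2, so an [8, 16] tensor sharded along axis 1 yields local shards of shape [8, 8].

#include <cassert>
#include <cstdint>
#include <set>
#include <vector>

// Distinct devices in a (possibly repeating) 1-D mesh, e.g. {0, 1, 0, 1} -> 2.
int64_t UniqueDeviceCount(const std::vector<int64_t>& mesh_elements) {
  return static_cast<int64_t>(std::set<int64_t>(mesh_elements.begin(), mesh_elements.end()).size());
}

// Local shard shape when `shape` is sharded along `axis` across the mesh.
std::vector<int64_t> ShardShape(std::vector<int64_t> shape, size_t axis,
                                const std::vector<int64_t>& mesh_elements) {
  const int64_t devices = UniqueDeviceCount(mesh_elements);
  assert(shape[axis] % devices == 0 && "sharded dimension must be divisible by the device count");
  shape[axis] /= devices;
  return shape;
}

// Example from the comment in sharding_spec.cc: ShardShape({8, 16}, 1, {0, 1, 0, 1}) == {8, 8}.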
+std::vector ParseStringAsInt64Vector(const std::string& str); + DeviceMesh CreateDeviceMesh( std::vector device_mesh_shape, std::vector device_mesh_elements); diff --git a/onnxruntime/contrib_ops/cuda/conv_transpose_with_dynamic_pads.h b/onnxruntime/contrib_ops/cuda/conv_transpose_with_dynamic_pads.h index 6f7a04d059034..a768b2a7d8a24 100644 --- a/onnxruntime/contrib_ops/cuda/conv_transpose_with_dynamic_pads.h +++ b/onnxruntime/contrib_ops/cuda/conv_transpose_with_dynamic_pads.h @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. // Licensed under the MIT License. #pragma once @@ -10,12 +11,12 @@ namespace contrib { namespace cuda { template -class ConvTransposeWithDynamicPads : public ::onnxruntime::cuda::ConvTranspose { +class ConvTransposeWithDynamicPads : public ::onnxruntime::cuda::ConvTranspose { public: - ConvTransposeWithDynamicPads(const OpKernelInfo& info) : ::onnxruntime::cuda::ConvTranspose(info) {} + ConvTransposeWithDynamicPads(const OpKernelInfo& info) : ::onnxruntime::cuda::ConvTranspose(info) {} Status ComputeInternal(OpKernelContext* context) const override { - return ::onnxruntime::cuda::ConvTranspose::DoConvTranspose(context, true); + return ::onnxruntime::cuda::ConvTranspose::DoConvTranspose(context, true); } }; } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 3e440a091870a..6162895d00d79 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -65,6 +65,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, PackedMultiHeadAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, PackedMultiHeadAttention); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BeamSearch); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, WhisperBeamSearch); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ConvTransposeWithDynamicPads); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, Crop); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, Crop); @@ -90,10 +91,13 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ParametricSoftplus); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ParametricSoftplus); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ParametricSoftplus); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RotaryEmbedding); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RotaryEmbedding); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Sampling); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ScaledTanh); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ScaledTanh); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh); +class 
ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, SkipGroupNorm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization); @@ -113,7 +117,13 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float_float_MLFloat16, SimplifiedLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16_float_float, SimplifiedLayerNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Inverse); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulNBits); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulBnb4); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulBnb4); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Trilu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, UnfoldTensor); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, DynamicTimeWarping); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int8_t_MLFloat16, QuantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, uint8_t_MLFloat16, QuantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int8_t_MLFloat16, DequantizeLinear); @@ -135,6 +145,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedSelfAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DecoderMaskedMultiHeadAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedMultiHeadAttention); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GemmFloat8); #ifdef ENABLE_ATEN class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kPytorchAtenDomain, 1, ATen); @@ -156,6 +167,31 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedSlice); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedSlice); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedReshape); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReshape); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReshape); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedExpand); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedExpand); +class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedExpand); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReduceSum); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReduceSum); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReduceMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReduceMax); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReduceMean); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReduceMean); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedUnsqueeze); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedUnsqueeze); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedUnsqueeze); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedSqueeze); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedSqueeze); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedSqueeze); #endif template <> @@ -219,6 +255,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -244,10 +281,13 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -267,6 +307,10 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -278,6 +322,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, // TransposedMatMul is still here for backward compatibility @@ -295,6 +341,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, #ifdef ENABLE_ATEN BuildKernelCreateInfo, @@ -316,6 +363,31 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + 
BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, #endif }; diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc index 301b2e76b1b2d..87e88ac31c998 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc @@ -1,6 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. - #include "core/providers/cuda/cuda_common.h" #include "contrib_ops/cuda/diffusion/group_norm.h" #include "contrib_ops/cuda/diffusion/group_norm_impl.h" @@ -15,14 +14,22 @@ ONNX_OPERATOR_KERNEL_EX( GroupNorm, kMSDomain, 1, kCudaExecutionProvider, (*KernelDefBuilder::Create()).TypeConstraint("T", BuildKernelDefConstraints()), GroupNorm); +ONNX_OPERATOR_KERNEL_EX( + SkipGroupNorm, kMSDomain, 1, kCudaExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", BuildKernelDefConstraints()), GroupNorm); + using namespace ONNX_NAMESPACE; namespace { + template struct DispatchGroupNorm { Status operator()(cudaStream_t stream, Tensor* output, + Tensor* add_out, const Tensor* input, + const Tensor* skip, + const Tensor* bias, const Tensor* gamma, const Tensor* beta, void* workspace, @@ -32,12 +39,17 @@ struct DispatchGroupNorm { int height, int width, int num_groups, - bool use_swish_activation) { + bool use_swish_activation, + bool broadcast_skip, + int channels_per_block) { typedef typename ToCudaType::MappedType CudaT; return LaunchGroupNormKernel( stream, reinterpret_cast(output->MutableData()), + add_out == nullptr ? nullptr : reinterpret_cast(add_out->MutableData()), reinterpret_cast(input->Data()), + skip == nullptr ? nullptr : reinterpret_cast(skip->Data()), + bias == nullptr ? nullptr : reinterpret_cast(bias->Data()), gamma->Data(), beta->Data(), workspace, @@ -47,13 +59,21 @@ struct DispatchGroupNorm { height, width, num_groups, - use_swish_activation); + use_swish_activation, + broadcast_skip, + channels_per_block); } }; } // namespace GroupNorm::GroupNorm(const OpKernelInfo& op_info) : CudaKernel(op_info) { + has_skip_ = false; + const std::string& op_name = op_info.GetKernelDef().OpName(); + if (op_name == "SkipGroupNorm") { + has_skip_ = true; + } + epsilon_ = op_info.GetAttrOrDefault("epsilon", 1e-5f); ORT_ENFORCE(epsilon_ >= 0); @@ -68,6 +88,23 @@ GroupNorm::GroupNorm(const OpKernelInfo& op_info) : CudaKernel(op_info) { use_swish_activation_ = (activation == 1); channels_last_ = (op_info.GetAttrOrDefault("channels_last", static_cast(1)) != 0); + + channels_per_block_ = 0; +} + +Status GroupNorm::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr /*alloc*/, + bool& is_packed, PrePackedWeights* /*prepacked_weights*/) { + is_packed = false; + + // Compute and cache cPerBlock using number of channels from gamma tensor shape. 
+ if (input_idx == 1) { + auto gamma_shape = tensor.Shape(); + if (gamma_shape.NumDimensions() == 1) { + channels_per_block_ = GetChannelsPerBlock(static_cast(gamma_shape[0]), num_groups_); + } + } + + return Status::OK(); } Status GroupNorm::ComputeInternal(OpKernelContext* context) const { @@ -77,22 +114,38 @@ Status GroupNorm::ComputeInternal(OpKernelContext* context) const { Tensor* output = context->Output(0, input->Shape()); if (!channels_last_) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "only the channels_last layout is supported"); } + if (!gamma->IsDataType() || !beta->IsDataType()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "GroupNorm only supports gamma and beta in float type"); + } + const auto& input_dims = input->Shape().GetDims(); if (input_dims.size() != 4) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "input is expected to have 4 dimensions, got ", input_dims.size()); } + // Only support NHWC format right now. + int batch_size = static_cast(input_dims[0]); + int height = static_cast(input_dims[1]); + int width = static_cast(input_dims[2]); + int num_channels = static_cast(input_dims[3]); + + if (num_channels % num_groups_ != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "number of channels should be divisible by num_groups"); + } + const auto& gamma_dims = gamma->Shape().GetDims(); if (gamma_dims.size() != 1) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "gamma is expected to have 1 dimension, got ", gamma_dims.size()); } - if (gamma_dims[0] != input_dims[3]) { + if (gamma_dims[0] != num_channels) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Number of channels in gamma and input does not match"); } @@ -102,22 +155,11 @@ Status GroupNorm::ComputeInternal(OpKernelContext* context) const { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "beta is expected to have 1 dimension, got ", beta_dims.size()); } - if (beta_dims[0] != input_dims[3]) { + if (beta_dims[0] != num_channels) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Number of channels in beta and input does not match"); } - // Input and output format is NHWC - int batch_size = static_cast(input_dims[0]); - int num_channels = static_cast(input_dims[3]); - int height = static_cast(input_dims[1]); - int width = static_cast(input_dims[2]); - - if (num_channels % num_groups_ != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "number of channels should be divisiable by num_groups"); - } - if (context->GetUseDeterministicCompute()) { static std::once_flag log_warning; std::call_once(log_warning, []() { @@ -125,17 +167,59 @@ Status GroupNorm::ComputeInternal(OpKernelContext* context) const { }); } - auto workspace = GetScratchBuffer(GetGroupNormWorkspaceSizeInBytes(), context->GetComputeStream()); + const Tensor* skip = nullptr; + const Tensor* bias = nullptr; + Tensor* add_out = nullptr; + + bool broadcast_skip = false; + if (has_skip_) { + skip = context->Input(3); + bias = context->Input(4); + add_out = context->Output(1, input->Shape()); + + if (bias != nullptr) { // Bias is optional + // If provided, bias has shape (C).
+ const auto& bias_dims = bias->Shape().GetDims(); + if (bias_dims.size() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "bias is expected to have 1 dimension, got ", bias_dims.size()); + } + if (bias_dims[0] != num_channels) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Number of channels in bias and input does not match"); + } + } + + // Check whether skip can be broadcasted to input shape. + if (skip->Shape() != input->Shape()) { + const auto& dims = skip->Shape().GetDims(); + // The shape of ship can be (N, C) or (N, 1, 1, C) for broadcast. + const bool b2 = (dims.size() == 2 && dims[0] == batch_size && dims[1] == num_channels); + const bool b4 = (dims.size() == 4 && dims[0] == batch_size && + dims[1] == 1 && dims[2] == 1 && dims[3] == num_channels); + broadcast_skip = b2 || b4; + if (!broadcast_skip) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "skip shape is expected to be (N, H, W, C) or (N, 1, 1, C) or (N, C)"); + } + } + } + + auto workspace = GetScratchBuffer(GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups_), + context->GetComputeStream()); utils::MLTypeCallDispatcher dispatcher(input->GetElementType()); - return dispatcher.InvokeRet(Stream(context), output, input, gamma, beta, workspace.get(), + return dispatcher.InvokeRet(Stream(context), output, add_out, input, skip, bias, + gamma, beta, workspace.get(), epsilon_, batch_size, num_channels, height, width, num_groups_, - use_swish_activation_); + use_swish_activation_, + broadcast_skip, + channels_per_block_); } } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h index 52c006e6bdb96..b408b3c1ee79b 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h @@ -16,11 +16,16 @@ class GroupNorm final : public CudaKernel { GroupNorm(const OpKernelInfo& op_kernel_info); Status ComputeInternal(OpKernelContext* context) const override; + Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool& is_packed, PrePackedWeights* prepacked_weights) override; + private: - bool use_swish_activation_; + bool use_swish_activation_; // use SiLU (also known as Swish) activation after group normalization? float epsilon_; int num_groups_; bool channels_last_; + bool has_skip_; // true for SkipGroupNorm operator; false for GroupNorm + int channels_per_block_; }; } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu index 01ba078b4be77..48b161552ce0c 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu @@ -16,18 +16,45 @@ */ // The CUDA kernel is modified from GroupNorm plugin of TensorRT 8.5 +// Modifications: heuristic channels per block; support epsilon; support skip and bias; update coding style. +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ #include #include #include #include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/cu_inc/common.cuh" #include "contrib_ops/cuda/diffusion/group_norm_impl.h" #include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" +using namespace onnxruntime::cuda; + namespace onnxruntime { namespace contrib { namespace cuda { -static inline int32_t divUp(int32_t m, int32_t n) { +namespace { + +// TODO: Similar to SkipLayerNorm kernel, read/write up to 8 channels at same time. +constexpr static int32_t CHANNELS_PER_THREAD = 2; + +constexpr static int kSizes[] = {128, 256, 320, 384, 512}; +constexpr static size_t kNumOfSizes = sizeof(kSizes) / sizeof(kSizes[0]); +constexpr static int kMaxSize = kSizes[kNumOfSizes - 1]; + +int NextSize(int x) { + for (size_t i = 0; i < kNumOfSizes; ++i) { + if (x <= kSizes[i]) { + return kSizes[i]; + } + } + + return x; +} +} // namespace + +static inline int32_t DivUp(int32_t m, int32_t n) { return (m + n - 1) / n; } @@ -41,14 +68,14 @@ struct GroupSums { // The sum. float sum; // The sum of squares. - float sumSq; + float sum_sq; }; struct GroupSumsOp { inline __device__ GroupSums operator()(GroupSums const& a, GroupSums const& b) { GroupSums dst; dst.sum = b.flag ? b.sum : (a.sum + b.sum); - dst.sumSq = b.flag ? b.sumSq : (a.sumSq + b.sumSq); + dst.sum_sq = b.flag ? b.sum_sq : (a.sum_sq + b.sum_sq); dst.flag = a.flag + b.flag; return dst; } @@ -56,54 +83,85 @@ struct GroupSumsOp { template struct GroupNormNHWCParams { - // The output buffer. Layout NHWC. + // The output buffer. Shape is (n, h, w, c). T* dst; - // The input buffer. Layout NHWC. + + // Optional output of element-wise add result of src, skip and bias. Shape is (n, h, w, c). + T* add_out; + + // The input buffer. Shape is (n, h, w, c). T const* src; + + // Optional input buffer for skip tensor. Shape is (n, h, w, c) or (n, 1, 1, c) or (n, c). + T const* skip; + + // Optional input buffer for bias tensor. Shape is (c). + T const* bias; + // The gamma scaling factor. float const* gamma; + // The beta term to add in GN. float const* beta; - // The temporary buffer to do the global parallel reduction. Size: - // BLOCKS_PER_BATCH x C x 2. - float* redBuffer; + + // The temporary buffer to do the global parallel reduction. Shape is (n, 2, g), where g is number of groups. + float* group_sum_buffer; // The number of instances in the batch. int32_t n; + // The height and width of each activation map. int32_t h; int32_t w; - // The number of channels. + + // Number of channels. int32_t c; - // The number of groups. + + // Number of groups. int32_t groups; - // Do we apply the Swish activation function? - bool withSwish; + + // Do we apply the SiLU activation function? + bool use_silu; // Precomputed values and parameters to control the execution of the kernels. - // The number of activations per instance (h * w) and the number of - // activations per block. + // Number of activations per instance (h * w) int32_t hw; - int32_t hwPerBlock; - // The number of channels per group and blocks per activation in the C - // dimension. - int32_t cPerBlock; - int32_t cPerGroup; + + // Number of activations per block + int32_t hw_per_block; + + // Number of channels per block in the C dimension. + int32_t channels_per_block; + + // Number of channels per group in the C dimension. + int32_t channels_per_group; // The precomputed stride between instances. int32_t hwc; - // The inverse of hwc in floats (to compute mean/var). 
- float invHWC; + // The inverse of hw*channels_per_group to compute mean of a group. + float inv_hw_channels_per_group; // The precomputed number of groups per block. - int32_t groupsPerBlock; + int32_t groups_per_block; + + // Number of threads per block + int32_t threads_per_block; + + // Epsilon to get stable variance in normalization. + float epsilon; + + // Whether skip need broadcast. True if shape of skip is (N, C) or (N, 1, 1, C); False otherwise. + bool broadcast_skip; + + // For SkipGroupNorm, it points to the intermediate result of adding skip and bias. + T* skip_workspace; }; template -inline __device__ void UpdateSum(const T* src, int64_t offset, float& sum, float& sumSq); +inline __device__ void UpdateSum(const T* src, int64_t offset, float& sum, float& sum_sq); template <> -inline __device__ void UpdateSum(const half* src, int64_t offset, float& sum, float& sumSq) { +inline __device__ void UpdateSum(const half* src, int64_t offset, float& sum, float& sum_sq) { // Fetch two channels per thread. __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); @@ -113,11 +171,11 @@ inline __device__ void UpdateSum(const half* src, int64_t offset, float& sum, fl sum += f2.x + f2.y; // Update the sum of squares. - sumSq += f2.x * f2.x + f2.y * f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; } template <> -inline __device__ void UpdateSum(const float* src, int64_t offset, float& sum, float& sumSq) { +inline __device__ void UpdateSum(const float* src, int64_t offset, float& sum, float& sum_sq) { // Fetch two channels per thread. float2 f2 = *reinterpret_cast(&src[offset]); @@ -125,119 +183,220 @@ inline __device__ void UpdateSum(const float* src, int64_t offset, float& sum, f sum += f2.x + f2.y; // Update the sum of squares. - sumSq += f2.x * f2.x + f2.y * f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; +} + +// Sum for SkipGroupNorm: add_out[offset] = src[offset] + skip[skip_offset] + bias[bias_offset] +template +inline __device__ void AddSkipBias(T* add_out, const T* src, const T* skip, const T* bias, + int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq); + +template <> +inline __device__ void AddSkipBias(half* add_out, const half* src, const half* skip, const half* bias, + int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) { + // Fetch two channels per thread. 
+ __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); + __half2 s = *reinterpret_cast<__half2 const*>(&skip[skip_offset]); + __half2 b = *reinterpret_cast<__half2 const*>(&bias[bias_offset]); + h2 = h2 + b; + h2 = h2 + s; + + *reinterpret_cast<__half2*>(&add_out[offset]) = h2; + + float2 f2 = __half22float2(h2); + sum += f2.x + f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; +} + +template <> +inline __device__ void AddSkipBias(float* add_out, const float* src, const float* skip, const float* bias, + int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) { + float2 f2 = *reinterpret_cast(&src[offset]); + float2 s = *reinterpret_cast(&skip[skip_offset]); + float2 b = *reinterpret_cast(&bias[bias_offset]); + f2.x += s.x + b.x; + f2.y += s.y + b.y; + + *reinterpret_cast(&add_out[offset]) = f2; + + sum += f2.x + f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; +} + +// Sum for SkipGroupNorm without bias: add_out[offset] = src[offset] + skip[skip_offset] +template +inline __device__ void AddSkip(T* add_out, const T* src, const T* skip, + int64_t offset, int64_t skip_offset, float& sum, float& sum_sq); + +template <> +inline __device__ void AddSkip(half* add_out, const half* src, const half* skip, + int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) { + __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); + __half2 s = *reinterpret_cast<__half2 const*>(&skip[skip_offset]); + h2 = h2 + s; + + *reinterpret_cast<__half2*>(&add_out[offset]) = h2; + + float2 f2 = __half22float2(h2); + sum += f2.x + f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; +} + +template <> +inline __device__ void AddSkip(float* add_out, const float* src, const float* skip, + int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) { + float2 f2 = *reinterpret_cast(&src[offset]); + float2 s = *reinterpret_cast(&skip[skip_offset]); + f2.x += s.x; + f2.y += s.y; + *reinterpret_cast(&add_out[offset]) = f2; + sum += f2.x + f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; } -template -__global__ void groupNormNHWCSumKernel(GroupNormNHWCParams params) { +template +__global__ void GroupNormNHWCSumKernel(GroupNormNHWCParams params) { // The object in charge of doing the sums for the different blocks. - typedef cub::BlockScan BlockScan; + typedef cub::BlockScan BlockScan; // Allocate shared memory for BlockScan. - __shared__ typename BlockScan::TempStorage tempStorage; - // Allocate shared memory for the groups. We could reduce the amount of shared - // memory reserved. - __shared__ float2 smem[tTHREADS_PER_BLOCK]; + __shared__ typename BlockScan::TempStorage temp_storage; + + // Allocate shared memory for the groups. We could reduce the amount of shared memory reserved. + __shared__ float2 smem[THREADS_PER_BLOCK]; // The instance in the batch. int32_t ni = blockIdx.z; - // The channel loaded by that thread (2 channels per thread for F16x2). - int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2; + + // The channel loaded by that thread. + int32_t ci = blockIdx.x * params.channels_per_block + threadIdx.x * CHANNELS_PER_THREAD; + + if (ci >= params.c || threadIdx.x * CHANNELS_PER_THREAD >= params.channels_per_block) { + return; + } // The first activation loaded by that block. - int32_t hwBegin = blockIdx.y * params.hwPerBlock; + int32_t hw_begin = blockIdx.y * params.hw_per_block; // The last activation loaded by that block. - int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw); + int32_t hw_end = min(hw_begin + params.hw_per_block, params.hw); // The sums. 
float sum = 0.F; - float sumSq = 0.F; + float sum_sq = 0.F; // Iterate over the activations to compute the sums. - if (ci < params.c) { - for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) { - // The offset. - int64_t offset = static_cast(ni) * params.hwc + static_cast(hwi) * params.c + ci; - UpdateSum(params.src, offset, sum, sumSq); + int64_t offset = static_cast(ni) * params.hwc + static_cast(hw_begin) * params.c + ci; + if (params.skip != nullptr) { + // SkipGroupNorm: skip is (n, h, w, c) or (n, 1, 1, c) or (n, c), bias is (c), and add_out is (n, h, w, c) + const int64_t bias_offset = static_cast(ci); + T* add_out = params.skip_workspace; + if (params.broadcast_skip) { + const int64_t skip_offset = static_cast(ni) * params.c + ci; + + if (params.bias != nullptr) { + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + AddSkipBias(add_out, params.src, params.skip, params.bias, offset, skip_offset, bias_offset, sum, sum_sq); + } + } else { + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + AddSkip(add_out, params.src, params.skip, offset, skip_offset, sum, sum_sq); + } + } + } else { + if (params.bias != nullptr) { + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + AddSkipBias(add_out, params.src, params.skip, params.bias, offset, offset, bias_offset, sum, sum_sq); + } + } else { + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + AddSkip(add_out, params.src, params.skip, offset, offset, sum, sum_sq); + } + } + } + } else { // GroupNorm + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + UpdateSum(params.src, offset, sum, sum_sq); } } - // The group that thread works on and the channel in the group (modulus). - int32_t gi = threadIdx.x * 2 / params.cPerGroup; - int32_t cj = threadIdx.x * 2 - params.cPerGroup * gi; + // The group index relative to the first group within the same block. + int32_t gi = threadIdx.x * CHANNELS_PER_THREAD / params.channels_per_group; + // The channel in the group. + int32_t cj = ci % params.channels_per_group; // The data for the summations. - GroupSums inp{cj == 0 ? 1 : 0, sum, sumSq}; + GroupSums inp{cj == 0 ? 1 : 0, sum, sum_sq}; - // Do the segmented scan. + // Do the segmented scan. InclusiveScan is not deterministic. GroupSums out; - BlockScan(tempStorage).InclusiveScan(inp, out, GroupSumsOp()); + BlockScan(temp_storage).InclusiveScan(inp, out, GroupSumsOp()); - // Store the results for the groups in shared memory (to produce coalesced - // stores later). - if (cj == params.cPerGroup - 2) { //2 channels per thread - smem[gi] = make_float2(out.sum, out.sumSq); + // Store the results for the groups in shared memory (to produce coalesced stores later). + // For each group, only the last thread of that group is picked to save sum to shared memory. + if (cj == params.channels_per_group - CHANNELS_PER_THREAD) { + smem[gi] = make_float2(out.sum, out.sum_sq); } // Make sure the data is in shared memory. __syncthreads(); - // The global group index. - int32_t gj = blockIdx.x * params.groupsPerBlock + threadIdx.x; - // Threads that have nothing left to do, exit. - if (threadIdx.x >= params.groupsPerBlock || gj >= params.groups) { + if (threadIdx.x >= params.groups_per_block) { return; } - // The first threads (those storing to global memory, load the values). - float2 sums = smem[threadIdx.x]; - - // Store to global memory. 
- atomicAdd(¶ms.redBuffer[(2 * ni + 0) * params.groups + gj], sums.x); - atomicAdd(¶ms.redBuffer[(2 * ni + 1) * params.groups + gj], sums.y); + // The global group index. + // Use neighboring threads for coalesced write. + int32_t gj = blockIdx.x * params.groups_per_block + threadIdx.x; + + if (gj < params.groups) { + float2 sums = smem[threadIdx.x]; + const int index = (2 * ni) * params.groups + gj; + atomicAdd(¶ms.group_sum_buffer[index], sums.x); + atomicAdd(¶ms.group_sum_buffer[index + params.groups], sums.y); + } } template -void groupNormNHWCSum(GroupNormNHWCParams const& params, cudaStream_t stream) { - // Make sure the values are as we expect. - ORT_ENFORCE(params.c % params.cPerBlock == 0 && params.hw % params.hwPerBlock == 0); - // Make sure a group does not span multiple blocks. - ORT_ENFORCE(params.cPerBlock % params.cPerGroup == 0); - +void GroupNormNHWCSum(GroupNormNHWCParams const& params, cudaStream_t stream) { dim3 grid; // The number of blocks to compute all the channels. - grid.x = params.c / params.cPerBlock; + grid.x = DivUp(params.c, params.channels_per_block); + // The number of blocks to compute all the activations in a given instance. - grid.y = divUp(params.hw, params.hwPerBlock); + grid.y = DivUp(params.hw, params.hw_per_block); + // The number of instances. grid.z = params.n; - switch (params.cPerBlock) { - case 320: - groupNormNHWCSumKernel<<>>(params); + // Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2. + switch (params.threads_per_block) { + case 256: + GroupNormNHWCSumKernel<<>>(params); break; - case 480: - groupNormNHWCSumKernel<<>>(params); + case 192: + GroupNormNHWCSumKernel<<>>(params); break; - case 256: - groupNormNHWCSumKernel<<>>(params); + case 160: + GroupNormNHWCSumKernel<<>>(params); break; case 128: - groupNormNHWCSumKernel<<>>(params); + GroupNormNHWCSumKernel<<>>(params); + break; + case 64: + GroupNormNHWCSumKernel<<>>(params); break; - default: - ORT_NOT_IMPLEMENTED("Not implemented"); } } template -__device__ void computeGroupNorm(const T* src, T* dst, int64_t offset, float mean, float invStdDev, float2& gammaF2, float2& betaF2, bool swish); +__device__ void ComputeGroupNorm(const T* src, T* dst, int64_t offset, float mean, float inv_std_dev, + float2& gamma_f2, float2& beta_f2, bool silu); template <> -__device__ void computeGroupNorm(const half* src, half* dst, int64_t offset, float mean, float invStdDev, - float2& gammaF2, float2& betaF2, bool swish) { +__device__ void ComputeGroupNorm(const half* src, half* dst, int64_t offset, float mean, float inv_std_dev, + float2& gamma_f2, float2& beta_f2, bool silu) { // Fetch two channels per thread. __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); @@ -245,15 +404,15 @@ __device__ void computeGroupNorm(const half* src, half* dst, int64_t offset, flo float2 f2 = __half22float2(h2); // Normalize the channels. - f2.x = (f2.x - mean) * invStdDev; - f2.y = (f2.y - mean) * invStdDev; + f2.x = (f2.x - mean) * inv_std_dev; + f2.y = (f2.y - mean) * inv_std_dev; // Scale by gamma and add beta. - f2.x = gammaF2.x * f2.x + betaF2.x; - f2.y = gammaF2.y * f2.y + betaF2.y; + f2.x = gamma_f2.x * f2.x + beta_f2.x; + f2.y = gamma_f2.y * f2.y + beta_f2.y; - // Apply Swish if needed. - if (swish) { + // Apply SiLU activation if needed. 
+ if (silu) { f2.x = f2.x * sigmoid(f2.x); f2.y = f2.y * sigmoid(f2.y); } @@ -262,21 +421,21 @@ __device__ void computeGroupNorm(const half* src, half* dst, int64_t offset, flo } template <> -__device__ void computeGroupNorm(const float* src, float* dst, int64_t offset, float mean, float invStdDev, - float2& gammaF2, float2& betaF2, bool swish) { +__device__ void ComputeGroupNorm(const float* src, float* dst, int64_t offset, float mean, float inv_std_dev, + float2& gamma_f2, float2& beta_f2, bool silu) { // Fetch two channels per thread. float2 f2 = *reinterpret_cast(&src[offset]); // Normalize the channels. - f2.x = (f2.x - mean) * invStdDev; - f2.y = (f2.y - mean) * invStdDev; + f2.x = (f2.x - mean) * inv_std_dev; + f2.y = (f2.y - mean) * inv_std_dev; // Scale by gamma and add beta. - f2.x = gammaF2.x * f2.x + betaF2.x; - f2.y = gammaF2.y * f2.y + betaF2.y; + f2.x = gamma_f2.x * f2.x + beta_f2.x; + f2.y = gamma_f2.y * f2.y + beta_f2.y; - // Apply Swish if needed. - if (swish) { + // Apply SiLU activation if needed. + if (silu) { f2.x = f2.x * sigmoid(f2.x); f2.y = f2.y * sigmoid(f2.y); } @@ -284,110 +443,142 @@ __device__ void computeGroupNorm(const float* src, float* dst, int64_t offset, f *reinterpret_cast(&dst[offset]) = f2; } -template -__global__ void groupNormNHWCScaleKernel(GroupNormNHWCParams params) { - // The channel loaded by that thread (2 channels per thread for F16x2). - int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2; - if (ci >= params.c) { +template +__global__ void GroupNormNHWCScaleKernel(GroupNormNHWCParams params) { + // The channel loaded by that thread. + int32_t ci = blockIdx.x * params.channels_per_block + threadIdx.x * CHANNELS_PER_THREAD; + if (ci >= params.c || threadIdx.x * CHANNELS_PER_THREAD >= params.channels_per_block) { return; } // The instance in the batch. int32_t ni = blockIdx.z; - // The group that thread works on and the channel in the group (modulus). - int32_t gi = ci / params.cPerGroup; + // The group that thread works on. + int32_t gi = ci / params.channels_per_group; // Load the sum and sum of squares for the group. - float sum = 0.F, sumSq = 0.F; + float sum = 0.F, sum_sq = 0.F; if (gi < params.groups) { - sum = params.redBuffer[(2 * ni + 0) * params.groups + gi]; - sumSq = params.redBuffer[(2 * ni + 1) * params.groups + gi]; + const int index = (2 * ni) * params.groups + gi; + sum = params.group_sum_buffer[index]; + sum_sq = params.group_sum_buffer[index + params.groups]; } - // Load gamma/beta. - float2 gammaF2 = *reinterpret_cast(¶ms.gamma[ci]); - float2 betaF2 = *reinterpret_cast(¶ms.beta[ci]); + // Load gamma/beta. Fetch two per thread. + float2 gamma_f2 = *reinterpret_cast(¶ms.gamma[ci]); + float2 beta_f2 = *reinterpret_cast(¶ms.beta[ci]); // Compute the mean. - float mean = sum * params.invHWC; + float mean = sum * params.inv_hw_channels_per_group; // Compute the variance. - float var = sumSq * params.invHWC - (mean * mean); + float var = sum_sq * params.inv_hw_channels_per_group - (mean * mean); // Compute the inverse of the stddev. - float invStdDev = var <= 0.F ? 1.F : rsqrtf(var); + float inv_std_dev = rsqrtf(var + params.epsilon); - // The first activation loaded by that block. - int32_t hwBegin = blockIdx.y * params.hwPerBlock; - // The last activation loaded by that block. - int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw); + int32_t hw_begin = blockIdx.y * params.hw_per_block; + int32_t hw_end = min(hw_begin + params.hw_per_block, params.hw); - // Iterate over the activations to compute the sums. 
- for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) { - // The src/dst offset. - int64_t offset = (int64_t)ni * params.hwc + hwi * params.c + ci; - - // Fetch two channels per thread. - computeGroupNorm(params.src, params.dst, offset, mean, invStdDev, gammaF2, betaF2, params.withSwish); + const T* input = (params.skip != nullptr) ? params.skip_workspace : params.src; + int64_t offset = static_cast(ni) * params.hwc + static_cast(hw_begin) * params.c + ci; + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + ComputeGroupNorm(input, params.dst, offset, mean, inv_std_dev, gamma_f2, beta_f2, params.use_silu); } } template -void groupNormNHWCScale(GroupNormNHWCParams const& params, cudaStream_t stream) { - // Make sure the dimensions are aligned with what we expect. - ORT_ENFORCE(params.c % params.cPerBlock == 0); - // Make sure a group does not span multiple blocks. - ORT_ENFORCE(params.cPerBlock % params.cPerGroup == 0); - +void GroupNormNHWCScale(GroupNormNHWCParams const& params, cudaStream_t stream) { dim3 grid; // The number of blocks to compute all the channels. - grid.x = params.c / params.cPerBlock; + grid.x = DivUp(params.c, params.channels_per_block); // The number of blocks to compute all the activations in a given instance. - grid.y = divUp(params.hw, params.hwPerBlock); + grid.y = DivUp(params.hw, params.hw_per_block); // The number of instances. grid.z = params.n; - switch (params.cPerBlock) { - case 320: - groupNormNHWCScaleKernel<<>>(params); + // Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2. + switch (params.threads_per_block) { + case 256: + GroupNormNHWCScaleKernel<<>>(params); break; - case 480: - groupNormNHWCScaleKernel<<>>(params); + case 192: + GroupNormNHWCScaleKernel<<>>(params); break; - case 256: - groupNormNHWCScaleKernel<<>>(params); + case 160: + GroupNormNHWCScaleKernel<<>>(params); break; case 128: - groupNormNHWCScaleKernel<<>>(params); + GroupNormNHWCScaleKernel<<>>(params); + break; + case 64: + GroupNormNHWCScaleKernel<<>>(params); break; - default: - ORT_NOT_IMPLEMENTED("Not implemented"); } } -int32_t findMaxDivisor(int32_t n, int32_t maxAllowedDivisor) { - int32_t maxDivisor = -1; +int32_t FindMaxDivisor(int32_t n, int32_t max_allowed_divisor) { + int32_t max_divisor = -1; for (int32_t i = 1; i <= std::sqrt(n); i++) { if (n % i == 0) { int32_t divisor1 = n / i; int32_t divisor2 = i; - if (divisor1 > maxDivisor && divisor1 < maxAllowedDivisor) { - maxDivisor = divisor1; + if (divisor1 > max_divisor && divisor1 < max_allowed_divisor) { + max_divisor = divisor1; } - if (divisor2 > maxDivisor && divisor2 < maxAllowedDivisor) { - maxDivisor = divisor2; + if (divisor2 > max_divisor && divisor2 < max_allowed_divisor) { + max_divisor = divisor2; } } } - return maxDivisor; + return max_divisor; +} + +// Find proper channels per block based on a cost function: The cost is number of channels corresponding to +// extra threads allocated but no channels assigned to them to work on. If cost is zero, every thread has +// work to do so it is ideal case. 
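// Editor's illustration (not part of the change), assuming kSizes is {64, 128, 160, 192, 256} as the
// switch statements above suggest: for a Stable Diffusion layer with num_channels = 320 and 32 groups,
// channels_per_group = 10 and the loop below evaluates
//   candidate 256 -> channels_per_block = 250, blocks = DivUp(320, 250) = 2, cost = 2 * 256 - 320 = 192
//   candidate 192 -> channels_per_block = 190, blocks = DivUp(320, 190) = 2, cost = 2 * 192 - 320 =  64
//   candidate 160 -> channels_per_block = 160, blocks = DivUp(320, 160) = 2, cost = 2 * 160 - 320 =   0
// The zero-cost candidate is returned immediately, so both blocks keep every thread busy.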
+int FindChannelsPerBlock(int num_channels, int channels_per_group) {
+  int min_cost = -1;
+  int best_candidate = -1;
+  for (size_t i = kNumOfSizes; i > 0; --i) {
+    if (kSizes[i - 1] < channels_per_group) {
+      break;
+    }
+
+    int channels_per_block = kSizes[i - 1] / channels_per_group * channels_per_group;
+    int blocks = (num_channels + channels_per_block - 1) / channels_per_block;
+    int cost = blocks * kSizes[i - 1] - num_channels;
+    if (cost == 0) {
+      return channels_per_block;
+    }
+
+    if (min_cost == -1 || cost < min_cost) {
+      min_cost = cost;
+      best_candidate = channels_per_block;
+    }
+  }
+
+  return best_candidate;
+}
+
+int GetChannelsPerBlock(int num_channels, int num_groups) {
+  int32_t channels_per_group = num_channels / num_groups;
+  int32_t channels_per_block = channels_per_group;
+  if (channels_per_group < kMaxSize / 2) {
+    channels_per_block = FindChannelsPerBlock(num_channels, channels_per_group);
+  }
+  return channels_per_block;
+}
+
 template <typename T>
 Status LaunchGroupNormKernel(
     cudaStream_t stream,
     T* output,
+    T* add_out,
     const T* input,
+    const T* skip,
+    const T* bias,
     const float* gamma,
     const float* beta,
     void* workspace,
@@ -397,79 +588,94 @@ Status LaunchGroupNormKernel(
     int height,
     int width,
     int num_groups,
-    bool use_swish_activation) {
-  if (batch_size > static_cast<int>(kMaxGroupNormBatchSize)) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED,
-                           "only support batch_size <= 32. Got", batch_size);
-  }
+    bool use_silu,
+    bool broadcast_skip,
+    int channels_per_block) {
+  GroupNormNHWCParams<T> params;
 
-  if (num_groups != static_cast<int>(kGroupNormNumberOfGroups)) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED,
-                           "only num_groups=32 is supported. Got", num_groups);
+  int32_t channels_per_group = num_channels / num_groups;
+  // channels_per_block is computed in PrePack.
+  // If gamma is not an initializer, channels_per_block might be zero after PrePack. If that happens, compute it here.
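// Editor's sketch (assumption, not part of the change): when gamma arrives as a runtime input rather than a
// constant initializer, PrePack cannot run, so a caller may pass 0 and rely on the fallback below:
//   int channels_per_block = 0;  // not known at session initialization
//   ORT_RETURN_IF_ERROR(LaunchGroupNormKernel<half>(
//       stream, output, add_out, input, skip, bias, gamma, beta, workspace, epsilon,
//       batch_size, num_channels, height, width, num_groups, use_silu, broadcast_skip,
//       channels_per_block));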
+ if (channels_per_block < channels_per_group) { + channels_per_block = GetChannelsPerBlock(num_channels, num_groups); } - GroupNormNHWCParams params; - int32_t cPerBlock = 320; - int32_t maxBlocksPerHW = 1024; - switch (num_channels) { - case 960: - case 1920: - cPerBlock = 480; - break; - case 512: - case 256: - cPerBlock = 256; - break; - case 128: - cPerBlock = 128; - break; - default: - cPerBlock = 320; + // TODO: Update the kernel to support CHANNELS_PER_THREAD==1 and other corner cases + if (channels_per_block % channels_per_group != 0 || + channels_per_block > kMaxSize || + (channels_per_group % CHANNELS_PER_THREAD != 0)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "GroupNorm in CUDA does not support the input: n=", batch_size, + " h=", height, + " w=", width, + " c=", num_channels, + " groups=", num_groups); } - params.withSwish = use_swish_activation; + params.use_silu = use_silu; params.dst = output; + params.add_out = add_out; params.src = input; + params.skip = skip; + params.bias = bias; params.gamma = gamma; params.beta = beta; - params.redBuffer = reinterpret_cast(workspace); + params.group_sum_buffer = reinterpret_cast(workspace); params.n = batch_size; params.h = height; params.w = width; params.c = num_channels; params.groups = num_groups; params.hw = params.h * params.w; - const int32_t blocksPerHW = findMaxDivisor(params.hw, maxBlocksPerHW); - params.hwPerBlock = divUp(params.hw, blocksPerHW); - params.cPerBlock = cPerBlock; - params.cPerGroup = params.c / params.groups; + + // This will allocate as many blocks as possible to partition HW. + // For Stable Diffusion, latent hw is 4K ~ 16K. This will allocate 1024 blocks, and each handles 4~16 hw. + // TODO: tune this logic to find proper blocks when hw is small. + constexpr int32_t max_blocks_per_hw = 1024; + const int32_t blocks_per_hw = FindMaxDivisor(params.hw, max_blocks_per_hw); + params.hw_per_block = DivUp(params.hw, blocks_per_hw); + + params.channels_per_block = channels_per_block; + params.channels_per_group = channels_per_group; params.hwc = params.hw * params.c; - params.invHWC = 1.F / (float)(params.hw * params.cPerGroup); - params.groupsPerBlock = cPerBlock / params.cPerGroup; + params.inv_hw_channels_per_group = 1.F / (float)(params.hw * params.channels_per_group); + params.groups_per_block = channels_per_block / params.channels_per_group; + params.epsilon = epsilon; + params.broadcast_skip = broadcast_skip; - DUMP_TENSOR_INIT(); - DUMP_TENSOR("input", input, batch_size, num_channels, height * width); - DUMP_TENSOR("gamma", gamma, 1, num_channels); - DUMP_TENSOR("beta", beta, 1, num_channels); - cudaMemsetAsync(params.redBuffer, 0, GetGroupNormWorkspaceSizeInBytes(), stream); - groupNormNHWCSum(params, stream); - DUMP_TENSOR("workspace", params.redBuffer, batch_size, num_groups, 2); + // Workspace for SkipGroupNorm to store intermediate results of src+skip+bias. + params.skip_workspace = (params.add_out != nullptr) ? 
params.add_out : params.dst; + + params.threads_per_block = NextSize(channels_per_block) / CHANNELS_PER_THREAD; + + CUDA_RETURN_IF_ERROR(cudaMemsetAsync( + params.group_sum_buffer, 0, GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups), stream)); + + GroupNormNHWCSum(params, stream); CUDA_RETURN_IF_ERROR(cudaGetLastError()); - groupNormNHWCScale(params, stream); + + DUMP_TENSOR_INIT(); + DUMP_TENSOR("workspace", params.group_sum_buffer, batch_size, 2, num_groups); + + GroupNormNHWCScale(params, stream); CUDA_RETURN_IF_ERROR(cudaGetLastError()); - DUMP_TENSOR("output", output, batch_size, num_channels, height * width); + return Status::OK(); } -template Status LaunchGroupNormKernel(cudaStream_t stream, half* output, - const half* input, const float* gamma, const float* beta, void* workspace, +template Status LaunchGroupNormKernel(cudaStream_t stream, half* output, half* add_out, + const half* input, const half* skip, const half* bias, + const float* gamma, const float* beta, void* workspace, float epsilon, int batch_size, int num_channels, - int height, int width, int num_groups, bool swish); + int height, int width, int num_groups, bool silu, + bool broadcast_skip, int channels_per_block); -template Status LaunchGroupNormKernel(cudaStream_t stream, float* output, - const float* input, const float* gamma, const float* beta, void* workspace, +template Status LaunchGroupNormKernel(cudaStream_t stream, float* output, float* add_out, + const float* input, const float* skip, const float* bias, + const float* gamma, const float* beta, void* workspace, float epsilon, int batch_size, int num_channels, - int height, int width, int num_groups, bool swish); + int height, int width, int num_groups, bool silu, + bool broadcast_skip, int channels_per_block); } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h index c7e9245050ee6..9532aeecb2f57 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h @@ -12,29 +12,33 @@ namespace onnxruntime { namespace contrib { namespace cuda { -constexpr size_t kMaxGroupNormBatchSize = 32; -constexpr size_t kGroupNormNumberOfGroups = 32; - -constexpr size_t GetGroupNormWorkspaceSizeInBytes() { +constexpr size_t GetGroupNormWorkspaceSizeInBytes(size_t batch_size, size_t num_groups) { // Two buffers for sum and squared sum - return (sizeof(float) * 2) * kMaxGroupNormBatchSize * kGroupNormNumberOfGroups; + return (sizeof(float) * 2) * batch_size * num_groups; } +int GetChannelsPerBlock(int num_channels, int num_groups); + template Status LaunchGroupNormKernel( cudaStream_t stream, - T* output, // normalized output tensor - const T* input, // input tensor - const float* gamma, // gamma (also known as weight or scale) - const float* beta, // beta (also known as bias) - void* workspace, // Work space - float epsilon, // epsilon used normalization - int batch_size, // N - int num_channels, // C - int height, // H - int width, // W - int num_groups, // number of groups - bool use_swish_activation // Whether there is Swish activation after group normalization + T* output, // normalized output tensor. Shape is (n, h, w, c) + T* add_out, // optional output tensor for element-wise sum of input + skip + bias. Shape is (n, h, w, c) + const T* input, // input tensor. Shape is (n, h, w, c) + const T* skip, // optional skip tensor. 
Shape is (n, h, w, c) + const T* bias, // optional bias tensor. Shape is (c) for SkipGroupNorm or (n, c) for BiasGroupNorm + const float* gamma, // gamma (also known as weight or scale). Shape is (c) + const float* beta, // beta (also known as bias). Shape is (c) + void* workspace, // Work space + float epsilon, // epsilon used normalization + int batch_size, // N + int num_channels, // C + int height, // H + int width, // W + int num_groups, // number of groups + bool use_silu, // Whether there is Sigmoid Linear Unit (SiLU) activation after group normalization + bool broadcast_skip, // Whether skip need broadcast. When skip has shape (n, c) or (n, 1, 1, c), it need broadcast. + int channels_per_block // Pre-computed channels per block. ); } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cc b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cc new file mode 100644 index 0000000000000..251850f621361 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cc @@ -0,0 +1,70 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include "core/providers/cuda/math/gemm.h" +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/shared_inc/fpgeneric.h" +#include "core/providers/cpu/math/gemm_helper.h" +#include "contrib_ops/cuda/math/gemm_float8.h" + +using namespace ONNX_NAMESPACE; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#define REGISTER_KERNEL() \ + ONNX_OPERATOR_KERNEL_EX( \ + GemmFloat8, \ + kMSDomain, \ + 1, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("TA", BuildKernelDefConstraints()) \ + .TypeConstraint("TB", BuildKernelDefConstraints()) \ + .TypeConstraint("TR", BuildKernelDefConstraints()) \ + .TypeConstraint("TS", BuildKernelDefConstraints()), \ + GemmFloat8); + +REGISTER_KERNEL() + +GemmFloat8::GemmFloat8(const OpKernelInfo& info) : CudaKernel(info) { + transA_ = info.GetAttrOrDefault("transA", 0); + transB_ = info.GetAttrOrDefault("transB", 0); + dtype_ = info.GetAttrOrDefault("dtype", ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + auto& device_prop = GetDeviceProp(); + sm_count_ = device_prop.multiProcessorCount; + alpha_ = info.GetAttrOrDefault("alpha", 1); + beta_ = info.GetAttrOrDefault("beta", 0); + +#if (CUDA_VERSION <= 12000) + ORT_ENFORCE(beta_ == 0, "CUDA < 12.0 does not support bias, beta must be 0."); +#endif + + std::string stemp = info.GetAttrOrDefault("activation", "NONE"); + if (stemp == "NONE") { + epilogue_ = CUBLASLT_EPILOGUE_DEFAULT; + } else if (stemp == "RELU") { + epilogue_ = CUBLASLT_EPILOGUE_RELU; + } else if (stemp == "GELU") { + epilogue_ = CUBLASLT_EPILOGUE_GELU; + } else { + ORT_THROW("Unexpected value for activation: '", stemp, "'."); + } +} + +Status GemmFloat8::SetCheck(const TensorShape& a_shape, const TensorShape& b_shape, int& M, int& N, int& K) const { + GemmHelper helper(a_shape, transA_, b_shape, transB_, TensorShape({})); + if (!helper.State().IsOK()) + return helper.State(); + + M = gsl::narrow_cast(helper.M()); + N = gsl::narrow_cast(helper.N()); + K = gsl::narrow_cast(helper.K()); + return helper.State(); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu new file mode 100644 index 0000000000000..df25342342cd5 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu @@ -0,0 +1,402 @@ +// Copyright (c) Microsoft 
Corporation. All rights reserved. +// Licensed under the MIT License. +// +// The operator calls function 'cublasLtMatmul' +// (https://docs.nvidia.com/cuda/cublas/index.html?highlight=cublasLtMatmul#cublasltmatmul). +// It lets the function checks what configuration is valid or not. If not, the error message +// shows the error message 'CUBLAS_STATUS_NOT_SUPPORTED'. NVIDIA documentation provides +// information on what attribute or type must be modified. +// This operator requires CUDA_VERSION >= 11.8 for float 8 and CUDA_VERSION >= 12.0 +// for beta != 0. + +#include +#include +#include +#include "contrib_ops/cuda/math/gemm_float8.h" +#include "core/providers/cuda/cu_inc/common.cuh" +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +// It must exist somewhere already. +int32_t TypeSize(int32_t element_type) { + switch (element_type) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + return 4; + case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + return 2; +#if (!defined(DISABLE_FLOAT8_TYPES) && (CUDA_VERSION >= 11080)) + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN: + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2: + return 1; +#endif + default: + ORT_THROW("Unexpected element_type=", element_type, "."); + } +} + +void GemmFloat8::SetParams(const TensorShape& a_shape, const TensorShape& b_shape, + int& M, int& N, int& K, int& lda, int& ldb, int& ldd) const { + int m_idx = transA_ ? 1 : 0; + int k_idx = 1 - m_idx; + int n_idx = transB_ ? 0 : 1; + + M = static_cast(a_shape[m_idx]); + K = static_cast(a_shape[k_idx]); + N = static_cast(b_shape[n_idx]); + lda = static_cast(a_shape[1]); + ldb = static_cast(b_shape[1]); + ldd = static_cast(b_shape[n_idx]); +} + +template +int32_t GetTypeAndShape(const TValue* input, + TensorShape& shape, + bool swap = false) { + shape = input->Shape(); + ORT_ENFORCE(shape.NumDimensions() == 2); + if (swap) { + std::swap(shape[0], shape[1]); + } + return input->GetElementType(); +} + +Status GemmFloat8::ComputeInternal(OpKernelContext* ctx) const { + const Tensor* input_A = nullptr; + const Tensor* input_B = nullptr; + const Tensor* input_C = nullptr; + const Tensor* scale_A = nullptr; + const Tensor* scale_B = nullptr; + const Tensor* scale_Y = nullptr; + bool has_scales = false; + bool has_bias = false; + int n_inputs = ctx->InputCount(); + + input_A = ctx->Input(0); + input_B = ctx->Input(1); + if (n_inputs == 3) { + input_C = ctx->Input(2); + has_bias = true; + } else if (n_inputs > 3) { + ORT_ENFORCE(n_inputs >= 5, "Unexpected number of inputs=", n_inputs, "."); + has_scales = true; + scale_A = ctx->Input(3); + scale_B = ctx->Input(4); + scale_Y = n_inputs < 6 ? 
nullptr : ctx->Input(5); + ORT_ENFORCE(scale_A->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + ORT_ENFORCE(scale_B->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + ORT_ENFORCE(scale_Y == nullptr || scale_Y->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + if (ctx->Input(2) != nullptr) { + input_C = ctx->Input(2); + has_bias = true; + ORT_ENFORCE(input_C->GetElementType() == dtype_, "Bias type must be equal to dtype."); + } + } + + auto first_type = input_A->GetElementType(); + bool is_float8 = first_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN || first_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2; + if (!is_float8) + return ComputeRowMajor(ctx, n_inputs, has_bias, has_scales, input_A, input_B, + input_C, scale_A, scale_B, scale_Y); + return ComputeColMajor(ctx, n_inputs, has_bias, has_scales, input_A, input_B, + input_C, scale_A, scale_B, scale_Y); +} + +Status GemmFloat8::ComputeRowMajor( + OpKernelContext* ctx, int n_inputs, bool has_bias, bool has_scales, + const Tensor* input_A, const Tensor* input_B, + const Tensor* input_C, const Tensor* scale_A, + const Tensor* scale_B, const Tensor* scale_Y) const { + TensorShape shape_A, shape_B, shape_C, shape_Y; + int32_t dtype_A, dtype_B, dtype_C, dtype_Y; + dtype_A = GetTypeAndShape(input_A, shape_A); + dtype_B = GetTypeAndShape(input_B, shape_B); + + int M, N, K, lda, ldb, ldd; + SetParams(shape_A, shape_B, M, N, K, lda, ldb, ldd); + + TensorShape dimensions{M, N}; + Tensor* Y = ctx->Output(0, dimensions); + dtype_Y = GetTypeAndShape(Y, shape_Y); + dtype_C = has_bias ? GetTypeAndShape(input_C, shape_C) + : ONNX_NAMESPACE::TensorProto_DataType_FLOAT; + return ComputeGemm(ctx, n_inputs, has_bias, has_scales, dtype_A, dtype_B, dtype_C, + dtype_Y, shape_A, shape_B, shape_C, shape_Y, transA_, transB_, + input_A->DataRaw(), input_B->DataRaw(), + has_bias ? input_C->DataRaw() : nullptr, + has_scales ? scale_A->DataRaw() : nullptr, + has_scales ? scale_B->DataRaw() : nullptr, + has_scales && scale_Y != nullptr ? scale_Y->DataRaw() : nullptr, + Y->MutableDataRaw(), M, N, K, lda, ldb, ldd, true); +} + +Status GemmFloat8::ComputeColMajor( + OpKernelContext* ctx, int n_inputs, bool has_bias, bool has_scales, + const Tensor* input_A, const Tensor* input_B, + const Tensor* input_C, const Tensor* scale_A, + const Tensor* scale_B, const Tensor* scale_Y) const { + TensorShape shape_A, shape_B, shape_C, shape_Y; + int32_t dtype_A, dtype_B, dtype_C, dtype_Y; + dtype_A = GetTypeAndShape(input_A, shape_A); + dtype_B = GetTypeAndShape(input_B, shape_B); + + int M, N, K, lda, ldb, ldd; + SetParams(shape_A, shape_B, M, N, K, lda, ldb, ldd); + + std::swap(shape_A[0], shape_A[1]); + std::swap(shape_B[0], shape_B[1]); + + TensorShape dimensions{M, N}; + Tensor* Y = ctx->Output(0, dimensions); + dtype_Y = GetTypeAndShape(Y, shape_Y); + dtype_C = has_bias ? GetTypeAndShape(input_C, shape_C, true) + : ONNX_NAMESPACE::TensorProto_DataType_FLOAT; + + return ComputeGemm(ctx, n_inputs, has_bias, has_scales, dtype_B, dtype_A, dtype_C, + dtype_Y, shape_B, shape_A, shape_C, shape_Y, transB_, transA_, + input_B->DataRaw(), input_A->DataRaw(), + has_bias ? input_C->DataRaw() : nullptr, + has_scales ? scale_B->DataRaw() : nullptr, + has_scales ? scale_A->DataRaw() : nullptr, + has_scales && scale_Y != nullptr ? 
scale_Y->DataRaw() : nullptr, + Y->MutableDataRaw(), N, M, K, ldb, lda, ldd, false); +} + +Status GemmFloat8::ComputeGemm( + OpKernelContext* ctx, int n_inputs, bool has_bias, bool has_scales, + int32_t dtype_A, int32_t dtype_B, + int32_t dtype_C, int32_t dtype_Y, + const TensorShape& shape_A, const TensorShape& shape_B, + const TensorShape& shape_C, const TensorShape& shape_Y, + bool trans_A, bool trans_B, const void* p_input_a, const void* p_input_b, + const void* p_input_c, const void* p_scale_a, const void* p_scale_b, + const void* p_scale_y, void* p_output_y, int M, int N, int K, int lda, + int ldb, int ldd, bool row_major_compute) const { + cudaStream_t stream = Stream(ctx); + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); + + cublasLtHandle_t cublasLt; + CUBLAS_RETURN_IF_ERROR(cublasLtCreate(&cublasLt)); + + cublasLtMatmulDesc_t operationDesc = nullptr; + cublasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Cdesc = nullptr, + Ddesc = nullptr; + + // Create matrix descriptors. Not setting any extra attributes. + cudaDataType_t a_cuda_type = onnxruntime::cuda::ToCudaDataType(dtype_A); + cudaDataType_t b_cuda_type = onnxruntime::cuda::ToCudaDataType(dtype_B); + cudaDataType_t d_cuda_type = onnxruntime::cuda::ToCudaDataType(dtype_Y); + cudaDataType_t scale_cuda_type = + onnxruntime::cuda::ToCudaDataType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); + cudaDataType_t bias_cuda_type = onnxruntime::cuda::ToCudaDataType(dtype_C); + + cublasComputeType_t compute_type; + switch (d_cuda_type) { + case CUDA_R_16F: + switch (a_cuda_type) { + case CUDA_R_8F_E4M3: + case CUDA_R_8F_E5M2: + compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; + break; + default: + compute_type = CUBLAS_COMPUTE_32F_FAST_16F; + break; + } + break; + case CUDA_R_16BF: + compute_type = CUBLAS_COMPUTE_32F_FAST_16BF; + break; + case CUDA_R_32F: + compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; + break; + default: + ORT_THROW("Unable to determine computeType in operator GemmFloat8."); + } + + CUBLAS_RETURN_IF_ERROR(cublasLtMatrixLayoutCreate( + &Adesc, a_cuda_type, trans_A ? K : M, trans_A ? M : K, lda)); + CUBLAS_RETURN_IF_ERROR(cublasLtMatrixLayoutCreate( + &Bdesc, b_cuda_type, trans_B ? N : K, trans_B ? K : N, ldb)); + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutCreate(&Ddesc, d_cuda_type, M, N, ldd)); + + if (row_major_compute) { + cublasLtOrder_t matrixOrder = CUBLASLT_ORDER_ROW; + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, + &matrixOrder, sizeof(matrixOrder))); + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, + &matrixOrder, sizeof(matrixOrder))); + } + + CUBLAS_RETURN_IF_ERROR( + cublasLtMatmulDescCreate(&operationDesc, compute_type, scale_cuda_type)); + cublasOperation_t ctransa = trans_A ? CUBLAS_OP_T : CUBLAS_OP_N; + cublasOperation_t ctransb = trans_B ? 
CUBLAS_OP_T : CUBLAS_OP_N; + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &ctransa, sizeof(ctransa))); + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &ctransb, sizeof(ctransb))); + + if (sm_count_ != 0) { + int math_sm_count = static_cast(sm_count_); + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET, &math_sm_count, + sizeof(math_sm_count))); + } + + if (has_scales) { + // gemm float 8 + const int8_t ifast_accumulation_mode = 1; + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, + cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_FAST_ACCUM, + &ifast_accumulation_mode, sizeof(ifast_accumulation_mode))); + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &p_scale_a, + sizeof(p_scale_a))); + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &p_scale_b, + sizeof(p_scale_b))); + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &p_scale_y, + sizeof(p_scale_b))); + + // float 8 +#if CUDA_VERSION >= 11080 + if (dtype_Y == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN || + dtype_Y == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2) { + // For FP8 output, cuBLAS requires C_type to be same as bias_type + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutCreate(&Cdesc, bias_cuda_type, M, N, ldd)); + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, &bias_cuda_type, + sizeof(bias_cuda_type))); + } else { + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutCreate(&Cdesc, d_cuda_type, M, N, ldd)); + } + } else { + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutCreate(&Cdesc, d_cuda_type, M, N, ldd)); + } +#else + // An output is still needed but it is not initialized. + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutCreate(&Cdesc, d_cuda_type, M, N, ldd)); +#endif + + if (row_major_compute) { + cublasLtOrder_t matrixOrder = CUBLASLT_ORDER_ROW; + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, + &matrixOrder, sizeof(matrixOrder))); + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutSetAttribute(Ddesc, CUBLASLT_MATRIX_LAYOUT_ORDER, + &matrixOrder, sizeof(matrixOrder))); + } + + cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, + &epilogue_, sizeof(epilogue_)); + + // See + // https://docs.nvidia.com/cuda/cublas/index.html?highlight=cublasLtMatmulPreferenceAttributes_t#cublasltmatmulpreferenceattributes-t + // The workspace should be allocated once from OpKernelContext assuming + // only one cuda function is running at a time (which is not necessarily true + // with H100). 
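// Editor's sketch (assumption, not part of the change): as the comment above hints, the fixed cudaMalloc
// further below could be replaced by the kernel's scratch allocator so the buffer is tied to the compute stream:
//   IAllocatorUniquePtr<void> scratch = GetScratchBuffer<void>(workspaceSize, ctx->GetComputeStream());
//   void* workspace = scratch.get();  // released automatically when scratch goes out of scope
// The preference attributes and the heuristic query that follow would stay unchanged.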
+ size_t workspaceSize = static_cast(1 << 25); // suggested fixed value 32Mb + cublasLtMatmulPreference_t preference = nullptr; + cublasLtMatmulPreferenceCreate(&preference); + cublasLtMatmulPreferenceSetAttribute(preference, + CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &workspaceSize, sizeof(workspaceSize)); + + // https://docs.nvidia.com/cuda/cublas/index.html?highlight=cublasLtMatmulAlgoGetHeuristic#cublasltmatmulalgogetheuristic + cublasLtMatmulHeuristicResult_t heuristicResult = {}; + int returnedResults = 0; + cublasStatus_t cuda_status = cublasLtMatmulAlgoGetHeuristic( + cublasLt, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, preference, 1, + &heuristicResult, &returnedResults); + ORT_ENFORCE( + returnedResults > 0 && cuda_status == CUBLAS_STATUS_SUCCESS, + " Unable to find any suitable algorithm due to ", + onnxruntime::cuda::cublasGetErrorEnum(cuda_status), + ", returnedResults=", returnedResults, + ", alpha=", alpha_, ", beta=", beta_, ", n_inputs=", n_inputs, + ", A_type=", onnxruntime::cuda::CudaDataTypeToString(a_cuda_type), + ", B_type=", onnxruntime::cuda::CudaDataTypeToString(b_cuda_type), + ", C_type=", onnxruntime::cuda::CudaDataTypeToString(bias_cuda_type), + ", result_type=", onnxruntime::cuda::CudaDataTypeToString(d_cuda_type), + ", bias_type=", onnxruntime::cuda::CudaDataTypeToString(bias_cuda_type), + ", scale_type=", onnxruntime::cuda::CudaDataTypeToString(scale_cuda_type), + ", computeType=", onnxruntime::cuda::CublasComputeTypeToString(compute_type), + ", epilogue=", epilogue_, ", smCount=", sm_count_, ", transA=", trans_A, + ", transB=", trans_B, + ", fastAccumulationMode=", 1, + ", shape_A=", shape_A[0], "x", shape_A[1], ", shape_B=", shape_B[0], "x", + shape_B[1], ", shape_C=", (shape_C.NumDimensions() > 0 ? shape_C[0] : 0), "x", + (shape_C.NumDimensions() > 1 ? shape_C[1] : 0), ", M=", M, ", N=", N, ", K=", K, + ", lda=", lda, ", ldb=", ldb, ", ldd=", ldd, + ", workspaceSize=", workspaceSize, ", rowMajorCompute=", (row_major_compute ? 1 : 0), + ". Check NVIDIA documentation to see what combination is valid: ", + "https://docs.nvidia.com/cuda/cublas/" + "index.html?highlight=cublasLtMatmulAlgoGetHeuristic#" + "cublasltmatmulalgogetheuristic."); + + void* workspace = nullptr; + if (workspaceSize > 0) { + CUDA_RETURN_IF_ERROR(cudaMalloc(reinterpret_cast(&workspace), workspaceSize)); + } + // https://docs.nvidia.com/cuda/cublas/index.html?highlight=cublasLtMatmul#cublasltmatmul + const void* bias = has_bias ? 
p_input_c : p_output_y; + cuda_status = cublasLtMatmul( + cublasLt, operationDesc, static_cast(&alpha_), /* alpha */ + p_input_a, /* A */ + Adesc, p_input_b, /* B */ + Bdesc, static_cast(&beta_), /* beta */ + bias, /* C */ + Cdesc, p_output_y, /* Y */ + Ddesc, &heuristicResult.algo, /* algo */ + workspace, /* workspace */ + workspaceSize, stream); /* stream */ + ORT_ENFORCE( + cuda_status == CUBLAS_STATUS_SUCCESS, + " Unable to run cublasLtMatmul due to ", + onnxruntime::cuda::cublasGetErrorEnum(cuda_status), + ", returnedResults=", returnedResults, ", alpha=", alpha_, + ", n_inputs=", n_inputs, ", A_type=", + onnxruntime::cuda::CudaDataTypeToString(a_cuda_type), + ", B_type=", onnxruntime::cuda::CudaDataTypeToString(b_cuda_type), + ", result_type=", onnxruntime::cuda::CudaDataTypeToString(d_cuda_type), + ", bias_type=", onnxruntime::cuda::CudaDataTypeToString(bias_cuda_type), + ", scale_type=", onnxruntime::cuda::CudaDataTypeToString(scale_cuda_type), + ", computeType=", onnxruntime::cuda::CublasComputeTypeToString(compute_type), + ", epilogue=", epilogue_, ", smCount=", sm_count_, ", transA=", trans_A, + ", transB=", trans_B, + ", fastAccumulationMode=", 1, + ", shape_A=", shape_A[0], "x", shape_A[1], ", shape_B=", shape_B[0], "x", + shape_B[1], ", M=", M, ", N=", N, ", K=", K, ", lda=", lda, ", ldb=", ldb, + ", ldd=", ldd, ", workspaceSize=", workspaceSize, + ", rowMajorCompute=", (row_major_compute ? 1 : 0), "."); + + if (workspaceSize > 0) { + CUDA_RETURN_IF_ERROR(cudaFree(workspace)); + } + + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulPreferenceDestroy(preference)); + CUBLAS_RETURN_IF_ERROR(cublasLtMatrixLayoutDestroy(Ddesc)); + CUBLAS_RETURN_IF_ERROR(cublasLtMatrixLayoutDestroy(Cdesc)); + CUBLAS_RETURN_IF_ERROR(cublasLtMatrixLayoutDestroy(Bdesc)); + CUBLAS_RETURN_IF_ERROR(cublasLtMatrixLayoutDestroy(Adesc)); + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescDestroy(operationDesc)); + CUBLAS_RETURN_IF_ERROR(cublasLtDestroy(cublasLt)); + return Status::OK(); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/math/gemm_float8.h b/onnxruntime/contrib_ops/cuda/math/gemm_float8.h new file mode 100644 index 0000000000000..e84ccd55b2003 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/math/gemm_float8.h @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "cublas_v2.h" +#include "core/providers/cuda/cuda_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +// Calls https://docs.nvidia.com/cuda/cublas/index.html#cublasltmatmul. 
+// D = alpha*(A*B) +class GemmFloat8 final : public onnxruntime::cuda::CudaKernel { + public: + GemmFloat8(const OpKernelInfo& info); + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + void SetParams(const TensorShape& shape_a, + const TensorShape& shape_b, + int& M, int& N, int& K, + int& lda, int& ldb, int& ldd) const; + Status SetCheck(const TensorShape& shape_a, + const TensorShape& shape_b, + int& M, int& N, int& K) const; + + Status ComputeRowMajor(OpKernelContext* ctx, int n_inputs, bool has_bias, + bool has_scales, const Tensor* input_A, + const Tensor* input_B, const Tensor* input_C, + const Tensor* scale_A, const Tensor* scale_B, + const Tensor* scale_Y) const; + Status ComputeColMajor(OpKernelContext* ctx, int n_inputs, bool has_bias, + bool has_scales, const Tensor* input_A, + const Tensor* input_B, const Tensor* input_C, + const Tensor* scale_A, const Tensor* scale_B, + const Tensor* scale_Y) const; + + Status ComputeGemm( + OpKernelContext* ctx, int n_inputs, bool has_bias, bool has_scales, + int32_t dtype_A, int32_t dtype_b, + int32_t dtype_c, int32_t dtype_Y, + const TensorShape& shape_A, const TensorShape& shape_B, + const TensorShape& shape_C, const TensorShape& shape_Y, + bool transa, bool transb, const void* p_input_a, const void* p_input_b, + const void* p_input_c, const void* p_scale_a, const void* p_scale_b, + const void* p_scale_y, void* p_output_y, int M, int N, int K, int lda, + int ldb, int ldd, bool row_major_compute) const; + + float alpha_; + float beta_; + bool transA_; + bool transB_; + int64_t sm_count_; + int64_t dtype_; + cublasLtEpilogue_t epilogue_; + + // TODO(xadupre): add epilogue (= activation function, Relu or Gelu are available). +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu new file mode 100644 index 0000000000000..7921315ab52e1 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu @@ -0,0 +1,375 @@ +// Modifications: scaling is moved from masked softmax to the gemm before that. +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
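// Editor's note (illustrative, not part of the change): the kernels in this file unpack eight 4-bit weights
// from each uint32_t and apply
//   dequantized = float(q) * scale - float(zero_point) * scale = (q - zero_point) * scale
// For example, with q = 0xF, scale = 0.5 and the default zero_point of 8 (used when no zero points are
// provided), the recovered value is (15 - 8) * 0.5 = 3.5.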
+ +#include +#include +#include +#include +#include +#include "core/providers/cuda/cu_inc/common.cuh" +#include "core/providers/cuda/cuda_common.h" +#include "dequantize_blockwise.cuh" + +using namespace onnxruntime::cuda; +using namespace cub; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + + +__device__ __forceinline__ void DequantizeEightElements(uint32_t values_quant, half scale, half zp, half* output) { + half2 scale_half2 = {scale, scale}; + half zp_adjust = -scale * __short2half_rn(zp); + half2 zp_adjust2 = {zp_adjust, zp_adjust}; + + alignas(16) half2 results[4]; + half v0 = __uint2half_rn(values_quant & 0xF); + half v1 = __uint2half_rn((values_quant >> 4) & 0xF); + results[0] = __halves2half2(v0, v1) * scale_half2 + zp_adjust2; + + half v2 = __uint2half_rn((values_quant >> 8) & 0xF); + half v3 = __uint2half_rn((values_quant >> 12) & 0xF); + results[1] = __halves2half2(v2, v3) * scale_half2 + zp_adjust2; + + half v4 = __uint2half_rn((values_quant >> 16) & 0xF); + half v5 = __uint2half_rn((values_quant >> 20) & 0xF); + results[2] = __halves2half2(v4, v5) * scale_half2 + zp_adjust2; + + half v6 = __uint2half_rn((values_quant >> 24) & 0xF); + half v7 = __uint2half_rn((values_quant >> 28) & 0xF); + results[3] = __halves2half2(v6, v7) * scale_half2 + zp_adjust2; + *(reinterpret_cast(output)) = *(reinterpret_cast(results)); +} + +__device__ __forceinline__ void DequantizeEightElements(uint32_t values_quant, float scale, float zp, float* output) { + float zp_adjust = -scale * zp; + output[0] = float(values_quant & 0xF) * scale + zp_adjust; + output[1] = float((values_quant >> 4) & 0xF) * scale + zp_adjust; + output[2] = float((values_quant >> 8) & 0xF) * scale + zp_adjust; + output[3] = float((values_quant >> 12) & 0xF) * scale + zp_adjust; + output[4] = float((values_quant >> 16) & 0xF) * scale + zp_adjust; + output[5] = float((values_quant >> 20) & 0xF) * scale + zp_adjust; + output[6] = float((values_quant >> 24) & 0xF) * scale + zp_adjust; + output[7] = float((values_quant >> 28) & 0xF) * scale + zp_adjust; +} + +template +__global__ void Dequantize4BitsKernel( + T* output, + const uint8_t* quant_data, + const T* scale_data, + const uint8_t* zero_points, + int block_size, + int blocks_per_K, + int blocks_per_threadblock, + int shift) { + int block_id = blockIdx.x * blocks_per_threadblock + ((threadIdx.x * 8) >> shift); + int n_idx = block_id / blocks_per_K; + int kb_idx = block_id % blocks_per_K; + int element_offset = block_id * block_size + ((threadIdx.x * 8) & ((1 << shift) - 1)); + uint32_t quant_value = *(reinterpret_cast(quant_data + element_offset / 2)); + T scale = *(scale_data + block_id); + uint8_t zp = 8; + if (zero_points) { + zp = zero_points[n_idx * ((blocks_per_K + 1)/2) + kb_idx / 2]; + zp = (kb_idx & 0x01) ? 
(zp >> 4) : (zp & 0x0f); + } + + output = output + element_offset; + DequantizeEightElements(quant_value, scale, static_cast(zp), output); +} + +template +Status Dequantize4Bits( + T* output, + const uint8_t* quant_data, + const T* scales_data, + const uint8_t* zero_points, // shape: [N, (block_per_K + 1)/2] + int k, + int n, + int block_size, + cudaStream_t stream) { + // k is padded and equal to block_per_K * block_size + ORT_ENFORCE(k % block_size == 0, "k must be a multiplier of block_size"); + constexpr int element_per_thread = 8; + int blocks_per_threadblock = GridDim::maxThreadsPerBlock * element_per_thread / block_size; + int blocks_per_K = k / block_size; + int blocks_per_grid = static_cast(CeilDiv(n * blocks_per_K, blocks_per_threadblock)); + int shift = static_cast(log2f(float(block_size))); + + Dequantize4BitsKernel<<>>( + output, + quant_data, + scales_data, + zero_points, + block_size, + blocks_per_K, + blocks_per_threadblock, + shift); + + return Status::OK(); +} + +template Status Dequantize4Bits( + float* output, + const uint8_t* quant_data, + const float* scales_data, + const uint8_t* zero_points, + int k, + int n, + int block_size, + cudaStream_t stream); + +template Status Dequantize4Bits( + half* output, + const uint8_t* quant_data, + const half* scales_data, + const uint8_t* zero_points, + int k, + int n, + int block_size, + cudaStream_t stream); + + +/////////////////////////////////////////////////////////////////////////////// +// A more general block-wise dequantization implementation that supports +// different block sizes and block orientations (row-wise/column-wise). + +template < + int Row_, ///< rows of a matrix + int Column_ ///< columns of a matrix + > +struct Shape2D { + static int const kRow = Row_; ///< rows of a matrix + static int const kColumn = Column_; ///< columns of a matrix + static int const kCount = Row_ * Column_; ///< total number of elements in a matrix +}; + +/** + * @brief Blockwise quantization constants + * @tparam ElementT source data type, e.g. fp32/fp16 + * @tparam block_size number of elemenets quantized together + * @tparam qbits number of bits in each quantized element + * @tparam Columnwise true: elements in a block come from one single column + * false: elements in a block come from one single row + */ +template < + typename ElementT, + int32_t block_size, + int32_t qbits, + bool Columnwise> +struct BlkQuantTraits { + // number of qbit elements to pack into whole bytes + static constexpr int kPackSize = (qbits == 8) ? 1 : (qbits == 4) ? 2 : (qbits == 2) ? 4 : 0; + static_assert(kPackSize != 0, "Packing to whole bytes not supported for this qbits!"); + + using QuantBlk = std::conditional_t, Shape2D<1, block_size>>; + using ThreadBlk = Shape2D; +}; + +template < + typename ElementT, + int32_t block_size, + int32_t qbits, + bool Columnwise> +__global__ +void dequantizeThread(ElementT* dst, + const uint8_t* weights, + const ElementT* scales, + const uint8_t* zero_points, + int rows, + int columns, + int thrd_row_blks) { + using QuantBlk = typename BlkQuantTraits::QuantBlk; + using ThreadBlk = typename BlkQuantTraits::ThreadBlk; + + // !! 
4b specific code + static_assert(qbits == 4, "Only 4b block quantization is supported!"); + + const auto block_idx = blockIdx.x * blockDim.x + threadIdx.x; + const auto row_blks = (rows + QuantBlk::kRow - 1) / QuantBlk::kRow; + + const auto meta_rows = (rows + QuantBlk::kRow - 1) / QuantBlk::kRow; + + // quantized matrix is stored in column major, packed by column + const auto q_rows = (meta_rows * QuantBlk::kRow * qbits + 7) / 8; + + int32_t r_blk_idx = static_cast(block_idx % thrd_row_blks); + int32_t c_blk_idx = static_cast(block_idx / thrd_row_blks); + + int32_t r = r_blk_idx * ThreadBlk::kRow; + int32_t c = c_blk_idx * ThreadBlk::kColumn; + + int32_t r_end = std::min(r + ThreadBlk::kRow, rows); + int32_t c_end = std::min(c + ThreadBlk::kColumn, columns); + + // for 4b quant, kPackSize = 2, so we have 2 scales and 2 offsets + const ElementT scale_buf[2] = { + scales[(c / QuantBlk::kColumn) * row_blks + r / QuantBlk::kRow], + ((r/QuantBlk::kRow) < (meta_rows - 1)) + ? scales[(c / QuantBlk::kColumn) * row_blks + r / QuantBlk::kRow + 1] + : static_cast(0.0f)}; + const uint8_t zp_pair = (zero_points == nullptr) + ? 0x88 + : zero_points[(c / QuantBlk::kColumn) * ((row_blks + 1) / 2) + (r / QuantBlk::kRow) / 2]; + const uint16_t zp_buf[2] = {(uint16_t)(zp_pair & 0x0f), (uint16_t)((zp_pair >> 4) & 0x0f)}; + const ElementT adjust_buf[2] = {(-scale_buf[0]) * static_cast(zp_buf[0]), + (-scale_buf[1]) * static_cast(zp_buf[1])}; + + for (int32_t j = c; j < c_end; ++j) { + const uint8_t* q_ptr = weights + j * q_rows; + for (int32_t i = r; i < (r_end - 1); i += 2) { + const auto scale0 = scale_buf[(i - r) / QuantBlk::kRow]; + const auto adjust0 = adjust_buf[(i - r) / QuantBlk::kRow]; + + const auto scale1 = scale_buf[(i + 1 - r) / QuantBlk::kRow];; + const auto adjust1 = adjust_buf[(i + 1 - r) / QuantBlk::kRow]; + + const auto vi = q_ptr[i / 2]; + + if constexpr (std::is_same::value) { + half2 scale_half2 = {scale0, scale1}; + half2 zp_adjust2 = {adjust0, adjust1}; + + half2 v = {__ushort2half_rn(vi & 0xF), __ushort2half_rn((vi >> 4) & 0xF)}; + half2 results = v * scale_half2 + zp_adjust2; + + dst[j * rows + i] = results.x; + dst[j * rows + (i + 1)] = results.y; + } else { + static_assert(std::is_same::value, "Only float and half are supported!"); + const uint8_t vi0 = vi & 0xf; + const uint8_t vi1 = vi >> 4; + dst[j * rows + i] = static_cast(vi0) * scale0 + adjust0;; + dst[j * rows + (i + 1)] = static_cast(vi1) * scale1 + adjust1; + } + } + + if ((r_end & 1) && (r_end > r)) { + const auto scale0 = scale_buf[(r_end - 1 - r) / QuantBlk::kRow]; + const auto adjust0 = adjust_buf[(r_end - 1 - r) / QuantBlk::kRow]; + + const auto vi = q_ptr[(r_end - 1) / 2]; + const uint8_t vi0 = vi & 0xf; + + dst[j * rows + (r_end - 1)] = static_cast(vi0) * scale0 + adjust0; + } + } +} + +template < + typename ElementT, + int32_t block_size, + int32_t qbits, + bool Columnwise> +static void dequantize(ElementT* dst, const uint8_t* weights, const ElementT* scales, + const uint8_t* zero_points, int32_t rows, int32_t columns, + cudaStream_t stream) { + using QuantBlk = typename BlkQuantTraits::QuantBlk; + using ThreadBlk = typename BlkQuantTraits::ThreadBlk; + + // Thread partitioning + const auto thrd_row_blks = (rows + ThreadBlk::kRow - 1) / ThreadBlk::kRow; + const auto thrd_col_blks = (columns + ThreadBlk::kColumn - 1) / ThreadBlk::kColumn; + const auto total_thrd_blks = thrd_row_blks * thrd_col_blks; + + const auto grids = (total_thrd_blks + GridDim::maxThreadsPerBlock - 1) / GridDim::maxThreadsPerBlock; + 
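// Editor's illustration (assumption, not part of the change): one thread dequantizes one ThreadBlk tile.
// Assuming each thread covers kPackSize = 2 adjacent quantization blocks along a column (so a packed 4-bit
// zero-point pair is consumed whole), a column-wise layout with rows = columns = 4096 and block_size = 32
// gives thrd_row_blks = 4096 / 64 = 64, thrd_col_blks = 4096, total_thrd_blks = 262144, and with
// GridDim::maxThreadsPerBlock = 256 the launch below uses grids = 1024 thread blocks.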
dequantizeThread<<>>( + dst, + weights, + scales, + zero_points, + rows, + columns, + thrd_row_blks); +} + + +template +Status +DequantizeBlockwise4b( + T* dst, + const uint8_t* src, + const T* scales, + const uint8_t* zero_points, + int block_size, + bool columnwise, + int rows, + int columns, + cudaStream_t stream) { + switch (block_size) { + case 16: + if (columnwise) { + dequantize(dst, src, scales, zero_points, rows, columns, stream); + } else { + dequantize(dst, src, scales, zero_points, rows, columns, stream); + } + return Status::OK(); + case 32: + if (columnwise) { + dequantize(dst, src, scales, zero_points, rows, columns, stream); + } else { + dequantize(dst, src, scales, zero_points, rows, columns, stream); + } + return Status::OK(); + case 64: + if (columnwise) { + dequantize(dst, src, scales, zero_points, rows, columns, stream); + } else { + dequantize(dst, src, scales, zero_points, rows, columns, stream); + } + return Status::OK(); + case 128: + if (columnwise) { + dequantize(dst, src, scales, zero_points, rows, + columns, stream); + } else { + dequantize(dst, src, scales, zero_points, + rows, columns, stream); + } + return Status::OK(); + case 256: + if (columnwise) { + dequantize(dst, src, scales, zero_points, rows, + columns, stream); + } else { + dequantize(dst, src, scales, zero_points, + rows, columns, stream); + } + return Status::OK(); + default: + // Only block size 16, 32, 64, 128, 256 are supported. + return Status(::onnxruntime::common::ONNXRUNTIME, ::onnxruntime::common::FAIL, + "Unsupported block size for blockwise quantization."); + } +} + +template +Status DequantizeBlockwise4b( + float* dst, + const uint8_t* src, + const float* scales, + const uint8_t* zero_points, + int block_size, + bool columnwise, + int rows, + int columns, + cudaStream_t stream); + +template +Status DequantizeBlockwise4b( + half* dst, + const uint8_t* src, + const half* scales, + const uint8_t* zero_points, + int block_size, + bool columnwise, + int rows, + int columns, + cudaStream_t stream); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh new file mode 100644 index 0000000000000..f9c09c55fd893 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh @@ -0,0 +1,50 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { +template +Status Dequantize4Bits( + T* output, + const uint8_t* quant_data, + const T* scales_data, + const uint8_t* zero_points, + int k, + int n, + int block_size, + cudaStream_t stream); + + +/** + * @brief Dequantize a block-wise quantized matrix, and store the result in a + * column major matrix for use in subsequent GEMM. This implementation supports + * columnwise and rowwise block orientation. 
+ * @param[out] dst pointer to the dequantized matrix, column major: [columns, rows] + * @param[in] qelements pointer to the quantized elements, column major: [columns, rows] + * @param[in] scales pointer to the scales of quantized blocks, column major layout + * @param[in] zero_points pointer to the zero points of quantized blocks, packed column major + * scales + * @param[in] block_size size of the quantized block + * @param[in] columnwise whether the quantized matrix is columnwise or rowwise quantized + * @param[in] rows + * @param[in] columns + */ +template +Status DequantizeBlockwise4b( + T* dst, + const uint8_t* qelements, + const T* scales, + const uint8_t* zero_points, + int block_size, + bool columnwise, + int rows, + int columns, + cudaStream_t stream); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cu b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cu new file mode 100644 index 0000000000000..e58723f0b31e1 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cu @@ -0,0 +1,129 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include "core/providers/cuda/cuda_common.h" +#include "contrib_ops/cpu/quantization/blockwise_quant_block_bnb4.h" +#include "dequantize_blockwise_bnb4.cuh" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +Status SetBnbQuantMap(int quant_type, T* quant_map_buffer, cudaStream_t stream) { + ORT_ENFORCE( + quant_type == FP4 || quant_type == NF4, + "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported."); + + T host_quant_map[16]; + switch (quant_type) { + case FP4: + for (int i = 0; i < 16; i++) host_quant_map[i] = static_cast(fp4_qaunt_map[i]); + break; + case NF4: + for (int i = 0; i < 16; i++) host_quant_map[i] = static_cast(nf4_qaunt_map[i]); + break; + } + CUDA_CALL_THROW(cudaMemcpyAsync(quant_map_buffer, host_quant_map, sizeof(T) * 16, cudaMemcpyHostToDevice, stream)); + + return Status::OK(); +} + +template Status SetBnbQuantMap(int quant_type, float* quant_map_buffer, cudaStream_t stream); + +template Status SetBnbQuantMap(int quant_type, half* quant_map_buffer, cudaStream_t stream); + +template +__global__ void kDequantizeBlockwise( + const T* quant_map, + T* output, + const uint8_t* quant_data, + const T* absmax, + const int block_size, + const int n) { + const int n_load = (gridDim.x * TILE_SIZE); + int valid_items_load = 0; + int valid_items_store = 0; + const int base_idx = (blockIdx.x * TILE_SIZE); + + T vals[NUM_PER_TH * 2]; + uint8_t qvals[NUM_PER_TH]; + T local_abs_max = T(0.0f); + + typedef cub::BlockLoad LoadChar; + typedef cub::BlockStore StoreT; + + __shared__ typename LoadChar::TempStorage loadchar; + __shared__ typename StoreT::TempStorage storet; + + for (unsigned int i = base_idx; i < n_load; i += gridDim.x * TILE_SIZE) { + valid_items_load = (n + 1) / 2 - i > TILE_SIZE ? TILE_SIZE : (n + 1) / 2 - i; + valid_items_store = n - i * 2 > TILE_SIZE * 2 ? 
TILE_SIZE * 2 : n - i * 2; + + local_abs_max = __ldg(&absmax[(i + threadIdx.x * NUM_PER_TH) / (block_size)]); + + __syncthreads(); + LoadChar(loadchar).Load(&(quant_data[i]), qvals, valid_items_load, 128); + + #pragma unroll NUM_PER_TH + for (int j = 0; j < NUM_PER_TH; j++) { + #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 + vals[j * 2] = quant_map[qvals[j] >> 4] * local_abs_max; + vals[j * 2 + 1] = quant_map[qvals[j] & 0x0F] * local_abs_max; + #else + // half multiplication not supported + vals[j * 2] = static_cast(static_cast(quant_map[qvals[j] >> 4]) * static_cast(local_abs_max)); + vals[j * 2 + 1] = + static_cast(static_cast(quant_map[qvals[j] & 0x0F]) * static_cast(local_abs_max)); + #endif + } + + __syncthreads(); + StoreT(storet).Store(&(output[i * 2]), vals, valid_items_store); + } +} + +template +Status DequantizeBnb4( + const T* quant_map, + T* output, + const uint8_t* quant_data, + const T* absmax, + int block_size, + int numel, + cudaStream_t stream) { + int tile_size = 1024; + kDequantizeBlockwise<<<(numel + tile_size - 1) / tile_size, 64, 0, stream>>>( + quant_map, + output, + quant_data, + absmax, + block_size / 2, + numel); + + return Status::OK(); +} + +template Status DequantizeBnb4( + const float* quant_map, + float* output, + const uint8_t* quant_data, + const float* absmax, + int block_size, + int numel, + cudaStream_t stream); + +template Status DequantizeBnb4( + const half* quant_map, + half* output, + const uint8_t* quant_data, + const half *absmax, + int block_size, + int numel, + cudaStream_t stream); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh new file mode 100644 index 0000000000000..4aef3ab699f9c --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +Status SetBnbQuantMap(int quant_type, T* quant_map_buffer, cudaStream_t stream); + +template +Status DequantizeBnb4( + const T* quant_map, + T* output, + const uint8_t* quant_data, + const T* absmax, + int block_size, + int numel, + cudaStream_t stream); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc new file mode 100644 index 0000000000000..ecf332715d470 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc @@ -0,0 +1,150 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
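// Editor's note (illustrative, not part of the change): MatMulBnb4 multiplies a float/half activation A of
// shape [M, K] with a bitsandbytes-style 4-bit weight B stored as packed nibbles plus one absmax scale per
// block_size elements. Each weight is recovered as quant_map[nibble] * absmax_of_its_block, either inside
// the fused kernel (TryMatMulBnb4) or by dequantizing B and falling back to a regular GEMM, so with the
// default transB = 1 the op computes Y = A x dequant(B)^T.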
+ +#include "core/common/safeint.h" +#include "core/providers/cuda/cuda_kernel.h" +#include "core/providers/cuda/shared_inc/fpgeneric.h" +#include "core/providers/cpu/math/matmul_helper.h" +#include "contrib_ops/cpu/quantization/blockwise_quant_block_bnb4.h" +#include "matmul_bnb4.cuh" +#include "dequantize_blockwise_bnb4.cuh" + +namespace onnxruntime { +namespace contrib { +namespace cuda { +using namespace onnxruntime::cuda; + +template +class MatMulBnb4 final : public CudaKernel { + public: + MatMulBnb4(const OpKernelInfo& info) : CudaKernel(info) { + ORT_ENFORCE(Status::OK() == info.GetAttr("K", &K_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("N", &N_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("block_size", &block_size_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("quant_type", &quant_type_)); + ORT_ENFORCE( + quant_type_ == FP4 || quant_type_ == NF4, + "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported."); + + is_training_mode_ = static_cast(info.GetAttrOrDefault("training_mode", static_cast(0))); + transB_ = static_cast(info.GetAttrOrDefault("transB", static_cast(1))); + } + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + int64_t K_; + int64_t N_; + int64_t block_size_; + int64_t quant_type_; + bool is_training_mode_; + bool transB_; +}; + +template +Status MatMulBnb4::ComputeInternal(OpKernelContext* ctx) const { + const Tensor* a = ctx->Input(0); + const Tensor* b_quant = ctx->Input(1); + const Tensor* absmax = ctx->Input(2); + + const auto* a_data = a->Data(); + const uint8_t* b_quant_data = b_quant->Data(); + const auto* absmax_data = absmax->Data(); + + typedef typename ToCudaType::MappedType CudaT; + + // TODO: find a better way to create the quant_map without using a buffer + // don't want to use malloc directly so asking from the caller + // can create a __device__ static array for float but doesn't work for half + IAllocatorUniquePtr quant_map_buffer = GetScratchBuffer(16, ctx->GetComputeStream()); + auto* quant_map_buffer_data = quant_map_buffer.get(); + ORT_RETURN_IF_ERROR(SetBnbQuantMap( + SafeInt(quant_type_), + reinterpret_cast(quant_map_buffer_data), + static_cast(ctx->GetComputeStream()->GetHandle()))); + + constexpr bool transa = false; + const bool transb = transB_; + MatMulComputeHelper helper; + TensorShape b_shape({N_, K_}); + ORT_RETURN_IF_ERROR( + helper.Compute(a->Shape(), b_shape, transa, transb)); + + Tensor* Y = ctx->Output(0, helper.OutputShape()); + // Bail out early if the output is going to be empty + if (Y->Shape().Size() == 0) return Status::OK(); + + bool is_4bit_done = !is_training_mode_ // skip inference specific handle if in training mode + && TryMatMulBnb4( + reinterpret_cast(quant_map_buffer_data), + reinterpret_cast(Y->MutableData()), + reinterpret_cast(a_data), + b_quant_data, + reinterpret_cast(absmax_data), + SafeInt(helper.M()), + SafeInt(helper.N()), + SafeInt(helper.K()), + SafeInt(block_size_), + static_cast(ctx->GetComputeStream()->GetHandle())); + + if (!is_4bit_done) { + IAllocatorUniquePtr b_dequant_ptr = GetScratchBuffer(N_ * K_, ctx->GetComputeStream()); + auto* b_dequant_data = b_dequant_ptr.get(); + ORT_RETURN_IF_ERROR(DequantizeBnb4( + reinterpret_cast(quant_map_buffer_data), + reinterpret_cast(b_dequant_data), + b_quant_data, + reinterpret_cast(absmax_data), + SafeInt(block_size_), + SafeInt(N_ * K_), + static_cast(ctx->GetComputeStream()->GetHandle()))); + + const CudaT alpha = ToCudaType::FromFloat(1.f); + const CudaT zero = ToCudaType::FromFloat(0.f); + + 
CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( + GetCublasHandle(ctx), + transb ? CUBLAS_OP_T : CUBLAS_OP_N, // transB + CUBLAS_OP_N, // transA + SafeInt(helper.N()), + SafeInt(helper.M()), + SafeInt(helper.K()), + &alpha, + reinterpret_cast(b_dequant_data), + helper.Ldb(transb), // ldb + reinterpret_cast(a_data), + helper.Lda(transa), // lda + &zero, + reinterpret_cast(Y->MutableData()), + helper.Ldc(), + GetDeviceProp())); + } + + return Status::OK(); +} + +ONNX_OPERATOR_TYPED_KERNEL_EX( + MatMulBnb4, + kMSDomain, + 1, + float, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + MatMulBnb4); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + MatMulBnb4, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + MatMulBnb4); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cu b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cu new file mode 100644 index 0000000000000..1d9aa75ff3701 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cu @@ -0,0 +1,192 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include +#include +#include +#include "matmul_bnb4.cuh" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#define num_values_4bit 32 +template +__global__ void kgemm_4bit_inference_naive( + int M, + int N, + int K, + const T* __restrict__ A, + const uint8_t* B, + const T* absmax, + const T* datatype, + T* out, + int lda, + int ldb, + int ldc, + int block_size) { + // per threadblock: + // load step-by-step in chunks of [32,warps]: 1x32 * [32,warps] -> [1,warps] + // 4 warps -> 4 loads per iter + // 1x32 * 32x4 -> 1x4 outputs per thread block + typedef cub::WarpReduce WarpReduce; + __shared__ typename WarpReduce::TempStorage temp_storage[THREADS / 32]; + + const int warp_idx = threadIdx.x / 32; + const int warp_lane = threadIdx.x % 32; + const int row_B = (THREADS / 32) * blockIdx.x + warp_idx; + const int num_values_8bit = num_values_4bit / 2; + float local_C = 0.0f; + + uint8_t local_B_4bit[num_values_8bit]; + T local_B[num_values_4bit / 4]; + T local_A[num_values_4bit / 4]; + __shared__ T quant_map[16]; + T local_absmax = T(0.0f); + + for (int i = threadIdx.x; i < 16; i++) quant_map[i] = T(datatype[i]); + __syncthreads(); + + // A: [1, K] + // B: [N, K] + for (int inner_idx = warp_lane * num_values_4bit; inner_idx < K; inner_idx += 32 * num_values_4bit) { + int inner_idx_halved = inner_idx / 2; + int offset_B = ldb * row_B; + int absidx = ((2 * offset_B) + inner_idx) / block_size; + local_absmax = __ldg(&(absmax[absidx])); + + if (row_B < N) { + if ((inner_idx_halved + num_values_8bit) < (K / 2)) { + // this is the most important for performance considerations + reinterpret_cast(local_B_4bit)[0] = + reinterpret_cast(B)[(offset_B + (inner_idx_halved)) / (num_values_8bit)]; + } else { + #pragma unroll + for (int j = 0; j < (num_values_8bit); j++) + if ((inner_idx_halved) + j < (K / 2)) + local_B_4bit[j] = B[offset_B + inner_idx_halved + j]; + else + local_B_4bit[j] = 0b01110111; + } + } else { + #pragma unroll + for (int j = 0; j < (num_values_8bit); j++) local_B_4bit[j] = 0b01110111; + } + + for (int i = 0; i < 4; i++) { + #pragma unroll + for (int k = 
0; k < num_values_8bit / 4; k++) { + #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 + local_B[k * 2] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4] * local_absmax; + local_B[k * 2 + 1] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F] * local_absmax; + #else + // half multiplication not supported + local_B[k * 2] = + static_cast(static_cast(quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4]) * + static_cast(local_absmax)); + local_B[k * 2 + 1] = + static_cast(static_cast(quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F]) * + static_cast(local_absmax)); + #endif + } + + if (inner_idx + (num_values_4bit / 4) + (i * num_values_4bit / 4) < K) { + // this is also relatively important for performance + if (BITS == 16) { + reinterpret_cast(local_A)[0] = + reinterpret_cast(A)[inner_idx / (num_values_4bit / 4) + i]; + } else { + reinterpret_cast(local_A)[0] = + reinterpret_cast(A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 0]; + reinterpret_cast(local_A)[1] = + reinterpret_cast(A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 1]; + } + } else { + #pragma unroll + for (int k = 0; k < num_values_4bit / 4; k++) { + if (inner_idx + (i * num_values_4bit / 4) + k < K) + local_A[k] = A[inner_idx + k + (i * num_values_4bit / 4)]; + else + local_A[k] = T(0.0f); + } + } + + // accumulate in float; small performance hit for Ampere, but lower error for outputs + #pragma unroll + for (int k = 0; k < num_values_4bit / 4; k++) { + #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 + local_C += static_cast(local_A[k] * local_B[k]); + #else + // half multiplication not supported + local_C += static_cast(local_A[k]) * static_cast(local_B[k]); + #endif + } + } + } + + local_C = WarpReduce(temp_storage[warp_idx]).Sum(local_C); + + if (row_B < N && warp_lane == 0) out[row_B] = T(local_C); +} + +template +bool TryMatMulBnb4( + const T* quant_map, + T* output, + const T* a_data, + const uint8_t* b_data_quant, + const T* absmax, + int m, + int n, + int k, + int block_size, + cudaStream_t stream) { + if (k % block_size != 0 || m > 1) { + return false; + } + // supported block_sizes are [4096, 2048, 1024, 512, 256, 128, 64, 32] + if (block_size % 32 != 0 || block_size > 4096) { + return false; + } + + int lda = k; + int ldb = (k + 1) / 2; + int ldc = n; + int num_blocks = (n + 3) / 4; + + constexpr int bits = std::is_same_v ? 16 : 32; + kgemm_4bit_inference_naive<<>>( + m, n, k, a_data, b_data_quant, absmax, quant_map, output, lda, ldb, ldc, block_size); + + return true; +} + +template bool TryMatMulBnb4( + const float* quant_map, + float* output, + const float* a_data, + const uint8_t* b_data_quant, + const float* absmax, + int m, + int n, + int k, + int block_size, + cudaStream_t stream); + +template bool TryMatMulBnb4( + const half* quant_map, + half* output, + const half* a_data, + const uint8_t* b_data_quant, + const half* absmax, + int m, + int n, + int k, + int block_size, + cudaStream_t stream); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cuh b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cuh new file mode 100644 index 0000000000000..743234282fbf3 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cuh @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
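For orientation, the per-element dequantization performed by kgemm_4bit_inference_naive above is quant_map[nibble] * absmax[element_index / block_size], decoding the high nibble of each packed byte before the low one. A minimal CPU-side sketch of that mapping, with illustrative names that are not part of the diff:

#include <cstdint>

// Dequantize one packed row of B the way kgemm_4bit_inference_naive does: two 4-bit codes
// per byte, looked up in the 16-entry quant_map and scaled by the per-block absmax.
// row_element_offset corresponds to 2 * ldb * row_B in the kernel above.
inline void DequantizeBnb4Row(const uint8_t* b_row_packed, const float* quant_map,
                              const float* absmax, int64_t row_element_offset,
                              int k, int block_size, float* out) {
  for (int i = 0; i < k; i += 2) {
    const uint8_t byte = b_row_packed[i / 2];
    const float scale = absmax[(row_element_offset + i) / block_size];
    out[i] = quant_map[byte >> 4] * scale;        // high nibble first
    out[i + 1] = quant_map[byte & 0x0F] * scale;  // then low nibble
  }
}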
+ +#pragma once +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +bool TryMatMulBnb4( + const T* quant_map, + T* output, + const T* a_data, + const uint8_t* b_data_quant, + const T* absmax, + int m, + int n, + int k, + int block_size, + cudaStream_t stream); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc new file mode 100644 index 0000000000000..5b0e61e197014 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc @@ -0,0 +1,169 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// +// This module define MatMulFp32Q4 operator, it is basically +// matmul float32 with right hand side being a 2-D matrix +// pre-packed and block-compacted into int4 +// + +#include "core/common/safeint.h" +#include "core/providers/cuda/cuda_kernel.h" +#include "core/providers/cuda/shared_inc/fpgeneric.h" +#include "core/providers/cpu/math/matmul_helper.h" +#include "matmul_nbits.cuh" +#include "dequantize_blockwise.cuh" + +namespace onnxruntime { +namespace contrib { +namespace cuda { +using namespace onnxruntime::cuda; + +template +class MatMulNBits final : public CudaKernel { + public: + MatMulNBits(const OpKernelInfo& info) : CudaKernel(info) { + ORT_ENFORCE(Status::OK() == info.GetAttr("K", &K_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("N", &N_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("block_size", &block_size_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("bits", &nbits_)); + ORT_ENFORCE(nbits_ == 4, + "Only 4b quantization is supported for MatMulNBits op," + " additional bits support is planned."); + } + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + int64_t K_; + int64_t N_; + int64_t block_size_; + int64_t nbits_; + bool column_wise_quant_blk_{true}; +}; + +template +Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { + const Tensor* a = ctx->Input(0); + const Tensor* b = ctx->Input(1); + const Tensor* scales = ctx->Input(2); + const Tensor* zero_points = ctx->Input(3); + + const auto* a_data = a->Data(); + const uint8_t* blob_data = b->Data(); + const auto* scales_data = scales->Data(); + const auto* zero_points_data = zero_points == nullptr ? 
nullptr : zero_points->Data(); + + typedef typename ToCudaType::MappedType CudaT; + + constexpr bool transa = false; + constexpr bool transb = true; + MatMulComputeHelper helper; + TensorShape b_shape({N_, K_}); + ORT_RETURN_IF_ERROR( + helper.Compute(a->Shape(), b_shape, transa, transb)); + + Tensor* Y = ctx->Output(0, helper.OutputShape()); + // Bail out early if the output is going to be empty + if (Y->Shape().Size() == 0) return Status::OK(); + + bool is_4bit_done = TryMatMul4Bits( + reinterpret_cast(Y->MutableData()), + reinterpret_cast(a_data), + blob_data, + reinterpret_cast(scales_data), + zero_points_data, + SafeInt(helper.M()), + SafeInt(helper.N()), + SafeInt(helper.K()), + SafeInt(block_size_), + SafeInt(GetDeviceProp().sharedMemPerBlock), + static_cast(ctx->GetComputeStream()->GetHandle())); + if (!is_4bit_done) { + int64_t K_padded = (K_ + block_size_ - 1) / block_size_ * block_size_; + IAllocatorUniquePtr b_data_ptr = GetScratchBuffer(N_ * K_padded, ctx->GetComputeStream()); + auto* b_data = b_data_ptr.get(); + if (column_wise_quant_blk_) { + // column-wise block + ORT_RETURN_IF_ERROR(Dequantize4Bits( + reinterpret_cast(b_data), + blob_data, + reinterpret_cast(scales_data), + zero_points_data, + SafeInt(K_padded), + SafeInt(N_), + SafeInt(block_size_), + static_cast(ctx->GetComputeStream()->GetHandle()))); + } else { + // row-wise block + K_padded = K_; + + ORT_RETURN_IF_ERROR(DequantizeBlockwise4b( + reinterpret_cast(b_data), + blob_data, + reinterpret_cast(scales_data), + zero_points_data, + SafeInt(block_size_), + column_wise_quant_blk_, + SafeInt(K_), + SafeInt(N_), + static_cast(ctx->GetComputeStream()->GetHandle()))); + } +#if 0 + cudaStreamSynchronize(static_cast(ctx->GetComputeStream()->GetHandle())); + T* b_data_cpu = new T[K_ * N_]; + cudaMemcpy(b_data_cpu, b_data, K_ * N_ * sizeof(T), cudaMemcpyDeviceToHost); + delete[] b_data_cpu; +#endif + + const CudaT alpha = ToCudaType::FromFloat(1.f); + const CudaT zero = ToCudaType::FromFloat(0.f); + + if (helper.OutputOffsets().size() == 1) { + CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( + GetCublasHandle(ctx), + CUBLAS_OP_T, + CUBLAS_OP_N, + SafeInt(helper.N()), + SafeInt(helper.M()), + SafeInt(helper.K()), + &alpha, + reinterpret_cast(b_data), + SafeInt(K_padded), + reinterpret_cast(a_data), + helper.Lda(transa), + &zero, + reinterpret_cast(Y->MutableData()), + helper.Ldc(), + GetDeviceProp())); + } + } + + return Status::OK(); +} + +ONNX_OPERATOR_TYPED_KERNEL_EX( + MatMulNBits, + kMSDomain, + 1, + float, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + MatMulNBits); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + MatMulNBits, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + MatMulNBits); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu new file mode 100644 index 0000000000000..f2600a506285d --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu @@ -0,0 +1,229 @@ +// Modifications: scaling is moved from masked softmax to the gemm before that. +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
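For reference, MatMulNBits::ComputeInternal above and the kernels that follow imply a per-output-column layout for the packed weights: blocks_per_K = ceil(K / block_size) quantization blocks per column, block_size / 2 packed bytes and one scale per block, and two 4-bit zero points per byte. A small sketch of the buffer sizes under that assumed layout (helper name is illustrative):

#include <cstddef>
#include <cstdint>

struct MatMulNBitsSizes {
  int64_t blocks_per_k;     // number of quantization blocks along K
  size_t packed_b_bytes;    // uint8 blob: N * blocks_per_k * (block_size / 2)
  size_t scales_elements;   // N * blocks_per_k, same element type as A
  size_t zero_point_bytes;  // N * ceil(blocks_per_k / 2), two 4-bit zero points per byte
};

inline MatMulNBitsSizes ComputeMatMulNBitsSizes(int64_t n, int64_t k, int64_t block_size) {
  MatMulNBitsSizes s;
  s.blocks_per_k = (k + block_size - 1) / block_size;
  s.packed_b_bytes = static_cast<size_t>(n * s.blocks_per_k * (block_size / 2));
  s.scales_elements = static_cast<size_t>(n * s.blocks_per_k);
  s.zero_point_bytes = static_cast<size_t>(n * ((s.blocks_per_k + 1) / 2));
  return s;
}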
+ +#include +#include +#include +#include +#include "core/providers/cuda/cu_inc/common.cuh" +#include "core/providers/cuda/cuda_common.h" +#include "matmul_nbits.cuh" + +using namespace onnxruntime::cuda; +using namespace cub; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +__device__ __forceinline__ float AccumulateEightElements(uint32_t values_quant, half scale, uint8_t zp, const half* a) { + half2 scale_half2 = {scale, scale}; + half zp_adjust = -scale * __short2half_rn(zp); + half2 zp_adjust2 = {zp_adjust, zp_adjust}; + uint4 vec_a = *(reinterpret_cast(a)); + + half2 element01 = __halves2half2(__uint2half_rn(values_quant & 0xF), __uint2half_rn((values_quant >> 4) & 0xF)); + half2 v0 = element01 * scale_half2 + zp_adjust2; + + half2 element23 = __halves2half2(__uint2half_rn((values_quant >> 8) & 0xF), __uint2half_rn((values_quant >> 12) & 0xF)); + half2 v1 = element23 * scale_half2 + zp_adjust2; + + half2 element45 = __halves2half2(__uint2half_rn((values_quant >> 16) & 0xF), __uint2half_rn((values_quant >> 20) & 0xF)); + half2 v2 = element45 * scale_half2 + zp_adjust2; + + half2 element67 = __halves2half2(__uint2half_rn((values_quant >> 24) & 0xF), __uint2half_rn((values_quant >> 28) & 0xF)); + half2 v3 = element67 * scale_half2 + zp_adjust2; + + v0 = v0 * (*(reinterpret_cast(&(vec_a.x)))); + v1 = v1 * (*(reinterpret_cast(&(vec_a.y)))); + v2 = v2 * (*(reinterpret_cast(&(vec_a.z)))) + v0; + v3 = v3 * (*(reinterpret_cast(&(vec_a.w)))) + v1; + v3 = v2 + v3; + return float(v3.x) + float(v3.y); +} + +__device__ __forceinline__ float AccumulateEightElements(uint32_t values_quant, float scale, uint8_t zp, const float* a) { + float4 a_vec_0 = *(reinterpret_cast(a)); + float4 a_vec_1 = *(reinterpret_cast(a + 4)); + + float zp_adjust = -scale * zp; + float v0 = float(values_quant & 0xF) * scale + zp_adjust; + float v1 = float((values_quant >> 4) & 0xF) * scale + zp_adjust; + float v2 = float((values_quant >> 8) & 0xF) * scale + zp_adjust; + float v3 = float((values_quant >> 12) & 0xF) * scale + zp_adjust; + float v4 = float((values_quant >> 16) & 0xF) * scale + zp_adjust; + float v5 = float((values_quant >> 20) & 0xF) * scale + zp_adjust; + float v6 = float((values_quant >> 24) & 0xF) * scale + zp_adjust; + float v7 = float((values_quant >> 28) & 0xF) * scale + zp_adjust; + + v0 = v0 * a_vec_0.x; + v1 = v1 * a_vec_0.y; + v2 = v2 * a_vec_0.z; + v3 = v3 * a_vec_0.w; + v4 = v4 * a_vec_1.x + v0; + v5 = v5 * a_vec_1.y + v1; + v6 = v6 * a_vec_1.z + v2; + v7 = v7 * a_vec_1.w + v3; + return v4 + v5 + v6 + v7; +} + +constexpr int kColsPerThreadBlock = 8; +constexpr int kWarpSize = 32; + +// kernel for 4bits quantized gemv, i.e., computing A(1,K) x B(K, N) +// B(K, N) is quantized blockwise with 4bits and stored as [N, (K + block_size - 1)/block_size, blob] +// The thread block size is (kWarpSize, kColsPerThreadBlock) and grid size is (N/kColsPerThreadBlock, 1) +// Each thread block computes [1, K] x [kColsPerThreadBlock, (K + block_size - 1)/block_size, blob], +// i.e., computing kColsPerThreadBlock per block and a warp reduce (1, K) x (K) +template +__global__ void MatMulFloatInt4Kernel( + T* output, + const T* a_data, + const uint8_t* b_data_quant, + const T* scales_data, + const uint8_t* zero_points, + int m, + int n, + int k, + int blocks_per_K) { + int n_block_id = blockIdx.x; + int m_id = blockIdx.y; + int lane_id = threadIdx.x; + int warp_id = threadIdx.y; + int n_id = n_block_id * kColsPerThreadBlock + warp_id; + int thread_id = warp_id * kWarpSize + lane_id; + constexpr int 
k_per_iter = 256; + int k_iter = k / k_per_iter; + + // blocks_per_K is the number of scales and zero points on the k dim + const int b_zp_k = (blocks_per_K + 1)/ 2; + + extern __shared__ char shared_buffer[]; + + // load scale to shared buffer + T* b_scale_vec = (T*)shared_buffer; + uint8_t* b_zp_vec = reinterpret_cast(b_scale_vec + kColsPerThreadBlock * blocks_per_K); + int offset = n_block_id * kColsPerThreadBlock * blocks_per_K; + for (int i = thread_id; i < kColsPerThreadBlock * blocks_per_K; i += kColsPerThreadBlock * kWarpSize) { + b_scale_vec[i] = scales_data[offset + i]; + } + + int zp_offset = n_block_id * kColsPerThreadBlock * b_zp_k; + for (int i = thread_id; i < kColsPerThreadBlock * b_zp_k; i += kColsPerThreadBlock * kWarpSize) { + b_zp_vec[i] = zero_points != nullptr ? zero_points[zp_offset + i] : uint8_t(0x88); + } + __syncthreads(); + + a_data += m_id * k; + b_data_quant += n_id * blocks_per_K * (block_size / 2); + + const int scale_col_offset = warp_id * blocks_per_K; + const int zp_col_offset = warp_id * b_zp_k; + + float sum = 0.f; + int k_id = 0; + for (; k_id < (k & 0xffffff00); k_id += k_per_iter) { + const int t_k = k_id + (lane_id << 3); // k index for this thread + const int t_meta_k = t_k / block_size; // k index for this thread, points to the scale and zero point + uint32_t value = *(reinterpret_cast(b_data_quant + (t_k >> 1))); + T scale = b_scale_vec[scale_col_offset + t_meta_k]; + uint8_t zp = b_zp_vec[zp_col_offset + t_meta_k/2]; + zp = (t_meta_k & 0x01) ? (zp >> 4) : (zp & 0x0f); + sum += AccumulateEightElements(value, scale, zp, a_data + k_id + (lane_id << 3)); + } + + // handle remainder + if (k_id + lane_id * 8 < k) { + const int t_k = k_id + (lane_id << 3); // k index for this thread + const int t_meta_k = t_k / block_size; // k index for this thread, points to the scale and zero point + uint32_t value = *(reinterpret_cast(b_data_quant + k_iter * 128 + lane_id * 4)); + T scale = b_scale_vec[scale_col_offset + t_meta_k]; + uint8_t zp = b_zp_vec[zp_col_offset + t_meta_k/2]; + zp = (t_meta_k & 0x01) ?
(zp >> 4) : (zp & 0x0f); + sum += AccumulateEightElements(value, scale, zp, a_data + k_id + (lane_id << 3)); + } + + // warp reduction + for (int i = 16; i > 0; i = i / 2) { + sum += __shfl_down_sync(0xffffffff, sum, i); + } + + if (lane_id == 0) { + output[m_id * n + n_id] = sum; + } +} + +template +bool TryMatMul4Bits( + T* output, + const T* a_data, + const uint8_t* b_data_quant, + const T* scales_data, + const uint8_t* zero_points, + int m, + int n, + int k, + int block_size, + int shared_mem_per_block, + cudaStream_t stream) { + if (n % kColsPerThreadBlock != 0 || k % 8 != 0 || m > 1) { + return false; + } + dim3 blocks((n + kColsPerThreadBlock - 1) / kColsPerThreadBlock, m); + dim3 threads(kWarpSize, kColsPerThreadBlock); + int blocks_per_K = (k + block_size - 1) / block_size; + int blocks_per_thread_block = blocks_per_K * kColsPerThreadBlock; + int shared_mem_size = sizeof(T) * blocks_per_thread_block + blocks_per_thread_block / 2; + if (shared_mem_size > shared_mem_per_block) { + return false; + } + + if (16 == block_size) { + MatMulFloatInt4Kernel<<>>( + output, a_data, b_data_quant, scales_data, zero_points, m, n, k, blocks_per_K); + } else if (32 == block_size) { + MatMulFloatInt4Kernel<<>>( + output, a_data, b_data_quant, scales_data, zero_points, m, n, k, blocks_per_K); + } else if (64 == block_size) { + MatMulFloatInt4Kernel<<>>( + output, a_data, b_data_quant, scales_data, zero_points, m, n, k, blocks_per_K); + } else if (128 == block_size) { + MatMulFloatInt4Kernel<<>>( + output, a_data, b_data_quant, scales_data, zero_points, m, n, k, blocks_per_K); + } else { + ORT_THROW("block size ", block_size, " is not supported"); + } + + return true; +} + +template bool TryMatMul4Bits( + float* output, + const float* a_data, + const uint8_t* b_data_quant, + const float* scales_data, + const uint8_t* zero_points, + int m, + int n, + int k, + int block_size, + int shared_mem_per_block, + cudaStream_t stream); + +template bool TryMatMul4Bits( + half* output, + const half* a_data, + const uint8_t* b_data_quant, + const half* scales_data, + const uint8_t* zero_points, + int m, + int n, + int k, + int block_size, + int shared_mem_per_block, + cudaStream_t stream); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cuh b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cuh new file mode 100644 index 0000000000000..9ccbe4c4d97a8 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cuh @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +bool TryMatMul4Bits( + T* output, + const T* a_data, + const uint8_t* b_data_quant, + const T* scales_data, + const uint8_t* zero_points, + int m, + int n, + int k, + int block_size, + int shared_mem_per_block, + cudaStream_t stream); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping.cc b/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping.cc new file mode 100644 index 0000000000000..381316f605fc9 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping.cc @@ -0,0 +1,56 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
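The float overload of AccumulateEightElements above dequantizes eight 4-bit values packed little-endian in one uint32 as q * scale - zp * scale and dots them with eight activations. A scalar reference of the same arithmetic, as a sketch:

#include <cstdint>

// Scalar equivalent of the float AccumulateEightElements used by MatMulFloatInt4Kernel:
// nibble i of `q` pairs with a[i].
inline float AccumulateEightElementsRef(uint32_t q, float scale, uint8_t zp, const float* a) {
  const float zp_adjust = -scale * static_cast<float>(zp);
  float sum = 0.0f;
  for (int i = 0; i < 8; ++i) {
    const float v = static_cast<float>((q >> (4 * i)) & 0xF) * scale + zp_adjust;
    sum += v * a[i];
  }
  return sum;
}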
+ +#include "contrib_ops/cuda/tensor/dynamic_time_warping.h" +#include "contrib_ops/cuda/tensor/dynamic_time_warping_impl.h" +#include "core/providers/cpu/tensor/utils.h" + +#include +#include + +using namespace onnxruntime::common; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +ONNX_OPERATOR_KERNEL_EX( + DynamicTimeWarping, + kMSDomain, + 1, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("F", DataTypeImpl::GetTensorType()) + .TypeConstraint("I", DataTypeImpl::GetTensorType()), + DynamicTimeWarping); + +Status DynamicTimeWarping::ComputeInternal(OpKernelContext* ctx) const { + const Tensor& input_tensor = *ctx->Input(0); + const auto& input_dims = input_tensor.Shape().GetDims(); + int rank = SafeInt(input_dims.size()); + ORT_ENFORCE(rank == 2 || (rank == 3 && input_dims[0] == 1), "Currently input rank must be 2, or (3 with first dim equal to 1), but got:", rank); + + const size_t rows = SafeInt(input_dims[rank == 3 ? 1 : 0]); + const size_t cols = SafeInt(input_dims[rank == 3 ? 2 : 1]); + size_t max_index_len = 0; + + size_t buffer_size_in_bytes = GetDynamicTimeWarpingBufferSize(1, rows, cols, max_index_len); + IAllocatorUniquePtr buffer = GetScratchBuffer(buffer_size_in_bytes, ctx->GetComputeStream()); + + size_t result_len = 0; + ORT_RETURN_IF_ERROR(LaunchDynamicTimeWarping( + this->Stream(ctx), this->GetDeviceProp(), 1, rows, cols, + input_tensor.Data(), buffer.get(), result_len)); + + Tensor* output_tensor = ctx->Output(0, TensorShape{2LL, SafeInt(result_len)}); + + return CUDA_CALL(cudaMemcpy2DAsync( + output_tensor->MutableData(), result_len * sizeof(int32_t), + buffer.get() + ((max_index_len - result_len) * sizeof(int32_t)), max_index_len * sizeof(int32_t), + result_len * sizeof(int32_t), 2, + cudaMemcpyDeviceToDevice, this->Stream(ctx))); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping.h b/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping.h new file mode 100644 index 0000000000000..3083e19aff6f2 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping.h @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/providers/cuda/cuda_kernel.h" +#include + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +using onnxruntime::OpKernelContext; +using onnxruntime::OpKernelInfo; +using onnxruntime::cuda::CudaKernel; +class DynamicTimeWarping final : public CudaKernel { + public: + DynamicTimeWarping(const OpKernelInfo& info) : CudaKernel(info) {} + + ~DynamicTimeWarping() = default; + + Status ComputeInternal(OpKernelContext* context) const override; +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping_impl.cu b/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping_impl.cu new file mode 100644 index 0000000000000..7c3f2963207e6 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping_impl.cu @@ -0,0 +1,142 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "contrib_ops/cuda/tensor/dynamic_time_warping_impl.h" +#include "core/providers/cuda/cu_inc/common.cuh" +#include "core/common/common.h" +#include +#include + +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +__global__ void DynamicTimeWarpingInitCost(float* cost_buffer, int8_t* trace_buffer, size_t cols_plus_1) { + int r = blockIdx.x; + cost_buffer += cols_plus_1 * r; + for (size_t i = threadIdx.x; i < cols_plus_1; i += blockDim.x) { + cost_buffer[i] = FLT_MAX; + } + if (r == 0) { + for (size_t i = threadIdx.x; i < cols_plus_1; i += blockDim.x) { + trace_buffer[i] = 2; + } + } + if (threadIdx.x == 0) trace_buffer[cols_plus_1 * r] = 1; + if (threadIdx.x == 0 && r == 0) *cost_buffer = 0.0f; +} + +__global__ void DynamicTimeWarpingKernel( + size_t rows, + size_t cols, + size_t max_index_len, + const float* input, + float* cost_buffer, + int8_t* trace_buffer, + int32_t* result_buffer, + size_t* result_len_device +) { + const int diag_max = static_cast(rows + cols); + for (int d = 1; d <= diag_max; d++) { + for (int c = threadIdx.x + 1; c <= cols; c += blockDim.x) { + int r = d - c; + if (r >= 1 && r <= rows) { + int cost_idx = ((r - 1) * (cols + 1) + (c - 1)); //[r - 1, c - 1] + const float c0 = cost_buffer[cost_idx]; + const float c1 = cost_buffer[cost_idx + 1]; // [r - 1, c] + const float c2 = cost_buffer[cost_idx + cols + 1]; // [r, c - 1] + + float cost; + int8_t t; + if (c0 < c1 && c0 < c2) { + cost = c0; + t = 0; + } else if (c1 < c0 && c1 < c2) { + cost = c1; + t = 1; + } else { + cost = c2; + t = 2; + } + cost_idx += ((cols + 1) + 1); + cost_buffer[cost_idx] = cost + input[(r - 1) * cols + (c - 1)]; + trace_buffer[cost_idx] = t; + } + } + __syncthreads(); + } + + //back tracing, reverse append to result buffer + if (threadIdx.x == 0) { + int r = rows - 1; + int c = cols - 1; + int pos = static_cast(max_index_len); // reverse put + while (r >= 0 && c >= 0) { + --pos; + result_buffer[pos] = r; + result_buffer[max_index_len + pos] = c; + const int trace_index = (r + 1) * (cols + 1) + (c + 1); + int8_t t = trace_buffer[trace_index]; + switch (t) { + case 0: r -= 1; c -= 1; break; + case 1: r -= 1; break; + default: c -= 1; break; + } + } + *result_len_device = max_index_len - static_cast(pos); + } +} + +size_t GetDynamicTimeWarpingBufferSize(size_t batch, size_t rows, size_t cols, size_t& max_index_len) { + max_index_len = rows + cols + 1; + size_t cost_buffer_size = ((rows + 1) * (cols + 1)); + return batch * max_index_len * 2 * sizeof(int32_t) + // two index arrays + sizeof(int64_t) + // final index array length + batch* cost_buffer_size * sizeof(float) + // cost buffer + batch* cost_buffer_size * sizeof(int8_t); // trace buffer +} + +Status LaunchDynamicTimeWarping( + cudaStream_t stream, + const cudaDeviceProp& device_prop, + size_t batch, + size_t rows, + size_t cols, + const float* input, + void* buffer, + size_t& result_len +) { + ORT_ENFORCE(batch == 1); + size_t max_index_len = rows + cols + 1; + int32_t* result_buffer = (int32_t*)buffer; + size_t* result_len_device_buf = (size_t*)(result_buffer + (batch * max_index_len * 2)); + float* cost_buffer = (float*)(result_len_device_buf + 1); + int8_t* trace_buffer = (int8_t*)(cost_buffer + ((rows + 1) * (cols + 1))); + + dim3 block(device_prop.maxThreadsPerBlock); + dim3 grid_init((unsigned)SafeInt(rows + 1), (unsigned)SafeInt(batch)); + DynamicTimeWarpingInitCost<<>>(cost_buffer, trace_buffer, cols+1); + ORT_RETURN_IF_ERROR(CUDA_CALL(cudaGetLastError())); + + dim3 
grid(1, (unsigned)SafeInt(batch)); + DynamicTimeWarpingKernel<<>>( + rows, + cols, + max_index_len, + input, + cost_buffer, + trace_buffer, + result_buffer, + result_len_device_buf); + ORT_RETURN_IF_ERROR(CUDA_CALL(cudaGetLastError())); + + ORT_RETURN_IF_ERROR(CUDA_CALL(cudaMemcpyAsync(&result_len, result_len_device_buf, sizeof(size_t), cudaMemcpyDeviceToHost, stream))); + ORT_RETURN_IF_ERROR(CUDA_CALL(cudaGetLastError())); + return CUDA_CALL(cudaStreamSynchronize(stream)); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping_impl.h b/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping_impl.h new file mode 100644 index 0000000000000..cb4a0dfb16807 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping_impl.h @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +size_t GetDynamicTimeWarpingBufferSize(size_t batch, size_t rows, size_t cols, size_t& max_index_len); + +Status LaunchDynamicTimeWarping( + cudaStream_t stream, + const cudaDeviceProp& device_prop, + size_t batch, + size_t rows, + size_t cols, + const float* input, + void* buffer, + size_t& result_len); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/tensor/unfold.cc b/onnxruntime/contrib_ops/cuda/tensor/unfold.cc new file mode 100644 index 0000000000000..c38c8c5317f0a --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/tensor/unfold.cc @@ -0,0 +1,55 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
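Taken together, the pieces above mean the DynamicTimeWarping op emits an int32 tensor of shape [2, result_len]: row 0 holds row indices and row 1 holds column indices of the aligned path, read in increasing order, since the kernel backtracks from the bottom-right corner and writes the path right-aligned before it is copied out. A small consumer-side sketch under that assumption:

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Walk the [2, result_len] alignment produced by DynamicTimeWarping (sketch only).
inline void PrintDtwAlignment(const int32_t* output, size_t result_len) {
  const int32_t* row_indices = output;
  const int32_t* col_indices = output + result_len;
  for (size_t i = 0; i < result_len; ++i) {
    std::printf("(%d, %d)\n", row_indices[i], col_indices[i]);
  }
}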
+ +#include "contrib_ops/cuda/tensor/unfold.h" +#include "contrib_ops/cuda/tensor/unfold_impl.h" +#include "core/providers/cpu/tensor/utils.h" + +#include +#include + +using namespace onnxruntime::common; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +ONNX_OPERATOR_KERNEL_EX( + UnfoldTensor, + kMSDomain, + 1, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::AllTensorTypes()), + UnfoldTensor); + +Status UnfoldTensor::ComputeInternal(OpKernelContext* ctx) const { + const Tensor& input_tensor = *ctx->Input(0); + const auto& input_dims = input_tensor.Shape().GetDims(); + int rank = SafeInt(input_dims.size()); + + int dim = SafeInt(HandleNegativeAxis(dim_, rank)); + ORT_ENFORCE(dim < rank, "input rank:", rank, " is not bigger than attribut specified dim: ", dim); + ORT_ENFORCE(input_dims[dim] >= size_, "dimsize:", input_dims[dim], " is less than unfold size:", size_); + + int64_t leading_dims = std::accumulate(input_dims.begin(), input_dims.begin() + dim, 1LL, std::multiplies()); + int64_t tailing_dims = std::accumulate(input_dims.begin() + (dim + 1), input_dims.end(), 1LL, std::multiplies()); + + std::vector output_dims(rank + 1, 0); + std::copy(input_dims.begin(), input_dims.end(), output_dims.begin()); + output_dims[dim] = (input_dims[dim] - size_) / step_ + 1; + output_dims.back() = size_; + TensorShape output_shape(output_dims); + Tensor* output_tensor = ctx->Output(0, output_shape); + + cudaStream_t stream = this->Stream(ctx); + const cudaDeviceProp& device_prop = this->GetDeviceProp(); + size_t element_size = input_tensor.DataType()->Size(); + return LaunchUnfoldTensor( + stream, device_prop, element_size, input_tensor.DataRaw(), output_tensor->MutableDataRaw(), + leading_dims, input_dims[dim], tailing_dims, size_, step_); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/tensor/unfold.h b/onnxruntime/contrib_ops/cuda/tensor/unfold.h new file mode 100644 index 0000000000000..1717687593470 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/tensor/unfold.h @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/providers/cuda/cuda_kernel.h" +#include + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +using onnxruntime::OpKernelContext; +using onnxruntime::OpKernelInfo; +using onnxruntime::cuda::CudaKernel; +class UnfoldTensor final : public CudaKernel { + public: + UnfoldTensor(const OpKernelInfo& info) : CudaKernel(info) { + dim_ = SafeInt(info.GetAttrOrDefault("dim", -1LL)); + step_ = SafeInt(info.GetAttrOrDefault("step", 1LL)); + ORT_ENFORCE(step_ > 0, "step must greater than zero!"); + + int64_t temp_size; + ORT_ENFORCE(info.GetAttr("size", &temp_size).IsOK()); + size_ = SafeInt(temp_size); + } + + ~UnfoldTensor() = default; + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + int dim_; + int size_; + int step_; +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/tensor/unfold_impl.cu b/onnxruntime/contrib_ops/cuda/tensor/unfold_impl.cu new file mode 100644 index 0000000000000..a3c93ceb33c46 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/tensor/unfold_impl.cu @@ -0,0 +1,101 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "contrib_ops/cuda/tensor/unfold_impl.h" +#include "core/providers/cuda/cu_inc/common.cuh" +#include "core/common/common.h" +#include + +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +__global__ void UnfoldTensorKernel( + const T* input, + T* output, + int64_t N, + int64_t unfold_size, // stride_tailing_dim_dst + int64_t tailing_dims_size, // stride_fold_dim_dst = tailing_dims_size * unfold_size, stride_append_dim_src = tailing_dims_size + int64_t stride_leading_dst, + int64_t stride_fold_dim_src, + int64_t stride_leading_src +) { + int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= N) return; + + const int64_t idx_leading = idx / stride_leading_dst; + int64_t n = idx % stride_leading_dst; + const int64_t stride_fold_dim_dst = tailing_dims_size * unfold_size; + const int64_t idx_fold = n / stride_fold_dim_dst; + n %= stride_fold_dim_dst; + const int64_t idx_tailing = n / unfold_size; + const int64_t idx_append = n % unfold_size; + + int64_t idx_src = idx_leading * stride_leading_src + idx_fold * stride_fold_dim_src + idx_tailing + idx_append * tailing_dims_size; + output[idx] = input[idx_src]; +} + + +Status LaunchUnfoldTensor( + cudaStream_t stream, + const cudaDeviceProp& device_prop, + size_t element_size, + const void* input, + void* output, + int64_t leading_dims_size, + int64_t unfold_dim_size, + int64_t tailing_dims_size, + int64_t unfold_size, + int64_t step_size +) { + int64_t TPB = device_prop.maxThreadsPerBlock; + int64_t unfold_dim_size_dst = (unfold_dim_size - unfold_size) / step_size + 1; + int64_t N = leading_dims_size * unfold_dim_size_dst * tailing_dims_size * unfold_size; + int64_t num_blocks = (N + TPB - 1) / TPB; + + int64_t stride_leading_dst = unfold_size * tailing_dims_size * unfold_dim_size_dst; + + int64_t stride_fold_dim_src = tailing_dims_size * step_size; + int64_t stride_leading_src = tailing_dims_size * unfold_dim_size; + + dim3 block((unsigned)SafeInt(TPB)); + dim3 grid((unsigned)SafeInt(num_blocks)); + switch (element_size) { + case 1: + UnfoldTensorKernel<<>>( + (const int8_t*)input, (int8_t*)output, N, unfold_size, + tailing_dims_size, stride_leading_dst, stride_fold_dim_src, stride_leading_src); + break; + case 2: + UnfoldTensorKernel<<>>( + (const int16_t*)input, (int16_t*)output, N, unfold_size, + tailing_dims_size, stride_leading_dst, stride_fold_dim_src, stride_leading_src); + break; + case 4: + UnfoldTensorKernel<<>>( + (const int32_t*)input, (int32_t*)output, N, unfold_size, + tailing_dims_size, stride_leading_dst, stride_fold_dim_src, stride_leading_src); + break; + case 8: + UnfoldTensorKernel<<>>( + (const int64_t*)input, (int64_t*)output, N, unfold_size, + tailing_dims_size, stride_leading_dst, stride_fold_dim_src, stride_leading_src); + break; + case 16: + UnfoldTensorKernel<<>>( + (const float4*)input, (float4*)output, N, unfold_size, + tailing_dims_size, stride_leading_dst, stride_fold_dim_src, stride_leading_src); + break; + default: + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Unsupported element_size"); + } + + return CUDA_CALL(cudaGetLastError()); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/tensor/unfold_impl.h b/onnxruntime/contrib_ops/cuda/tensor/unfold_impl.h new file mode 100644 index 0000000000000..9e82dccdec23c --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/tensor/unfold_impl.h @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. 
All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +Status LaunchUnfoldTensor( + cudaStream_t stream, + const cudaDeviceProp& device_prop, + size_t element_size, + const void* input, + void* output, + int64_t leading_dims_size, + int64_t tailing_dims_size, + int64_t dim_size, + int64_t unfold_size, + int64_t step_size); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc b/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc index d18460e016444..2a90e4911f286 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc @@ -33,6 +33,28 @@ ONNX_OPERATOR_KERNEL_EX( DataTypeImpl::GetTensorType()}), BeamSearch); +ONNX_OPERATOR_KERNEL_EX( + WhisperBeamSearch, + kMSDomain, + 1, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) // 'input_ids' needs to be on CPU + .InputMemoryType(OrtMemTypeCPUInput, 1) // 'max_length' needs to be on CPU + .InputMemoryType(OrtMemTypeCPUInput, 2) // 'min_length' needs to be on CPU + .InputMemoryType(OrtMemTypeCPUInput, 3) // 'num_beams' needs to be on CPU + .InputMemoryType(OrtMemTypeCPUInput, 4) // 'num_return_sequences' needs to be on CPU + .InputMemoryType(OrtMemTypeCPUInput, 5) // 'length_penalty' needs to be on CPU + .InputMemoryType(OrtMemTypeCPUInput, 6) // 'repetition_penalty' needs to be on CPU + .InputMemoryType(OrtMemTypeCPUInput, 9) // 'attention_mask' needs to be on CPU + .InputMemoryType(OrtMemTypeCPUInput, 10) // 'decoder_input_ids' needs to be on CPU + .InputMemoryType(OrtMemTypeCPUInput, 11) // 'logits_processor' needs to be on CPU + .OutputMemoryType(OrtMemTypeCPUOutput, 0) // 'sequences' output on CPU + .OutputMemoryType(OrtMemTypeCPUOutput, 1) // 'sequences_scores' output on CPU + .TypeConstraint("T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), + WhisperBeamSearch); + transformers::CudaTensorConsoleDumper g_cuda_dumper; BeamSearch::BeamSearch(const OpKernelInfo& info) @@ -58,7 +80,9 @@ BeamSearch::BeamSearch(const OpKernelInfo& info) GenerationCudaDeviceHelper::UpdateDecoderFeeds, GenerationCudaDeviceHelper::ExpandBuffer, GenerationCudaDeviceHelper::ExpandBuffer, - GenerationCudaDeviceHelper::ExpandBuffer); + GenerationCudaDeviceHelper::ExpandBuffer, + GenerationCudaDeviceHelper::UpdateDecoderCrossQK, + GenerationCudaDeviceHelper::FinalizeDecoderCrossQK); SetConsoleDumper(&g_cuda_dumper); @@ -87,6 +111,60 @@ Status BeamSearch::Compute(OpKernelContext* context) const { return s; } +WhisperBeamSearch::WhisperBeamSearch(const OpKernelInfo& info) + : onnxruntime::contrib::transformers::WhisperBeamSearch(info) { + SetDeviceHelpers(GenerationCudaDeviceHelper::AddToFeeds, + GenerationCudaDeviceHelper::TopK, + GenerationCudaDeviceHelper::DeviceCopy, + GenerationCudaDeviceHelper::DeviceCopy, + GenerationCudaDeviceHelper::ProcessLogits, + GenerationCudaDeviceHelper::ProcessLogits, + GenerationCudaDeviceHelper::InitBeamState, + GenerationCudaDeviceHelper::InitBeamState, + GenerationCudaDeviceHelper::CreateBeamScorer); + +#ifndef USE_ROCM + SetDeviceHelpers_Cuda(GenerationCudaDeviceHelper::ReorderPastState, GenerationCudaDeviceHelper::InitCacheIndir); +#endif + + SetDeviceHelpers_Gpt(GenerationCudaDeviceHelper::UpdateGptFeeds, + GenerationCudaDeviceHelper::UpdateGptFeeds); + + 
SetDeviceHelpers_EncoderDecoder(GenerationCudaDeviceHelper::UpdateDecoderFeeds, + GenerationCudaDeviceHelper::UpdateDecoderFeeds, + GenerationCudaDeviceHelper::ExpandBuffer, + GenerationCudaDeviceHelper::ExpandBuffer, + GenerationCudaDeviceHelper::ExpandBuffer, + GenerationCudaDeviceHelper::UpdateDecoderCrossQK, + GenerationCudaDeviceHelper::FinalizeDecoderCrossQK); + + SetConsoleDumper(&g_cuda_dumper); + +#ifndef USE_ROCM + cuda_device_prop_ = &reinterpret_cast(info.GetExecutionProvider())->GetDeviceProp(); + + cuda_device_arch_ = static_cast(cuda_device_prop_)->major * 100 + + static_cast(cuda_device_prop_)->minor * 10; +#endif +} + +Status WhisperBeamSearch::ComputeInternal(OpKernelContext* context) const { + return onnxruntime::contrib::transformers::WhisperBeamSearch::Compute(context); +} + +Status WhisperBeamSearch::Compute(OpKernelContext* context) const { + auto s = ComputeInternal(context); + + if (s.IsOK()) { + auto err = cudaGetLastError(); + if (err != cudaSuccess) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "CUDA error ", cudaGetErrorName(err), ":", cudaGetErrorString(err)); + } + } + + return s; +} + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search.h b/onnxruntime/contrib_ops/cuda/transformers/beam_search.h index dda8271e3a6a0..a4370abd8af46 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/beam_search.h +++ b/onnxruntime/contrib_ops/cuda/transformers/beam_search.h @@ -21,6 +21,16 @@ class BeamSearch final : public onnxruntime::contrib::transformers::BeamSearch { Status ComputeInternal(OpKernelContext* context) const; }; +class WhisperBeamSearch final : public onnxruntime::contrib::transformers::WhisperBeamSearch { + public: + WhisperBeamSearch(const OpKernelInfo& info); + + Status Compute(OpKernelContext* context) const override; + + private: + Status ComputeInternal(OpKernelContext* context) const; +}; + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu index 67b52b466f1c7..dbd7fb010462d 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu @@ -1376,6 +1376,235 @@ void ReorderPastStatesKernelLauncher(void* out_buffer, } } +template +__global__ void CopyCrossQKSingleDecodeStepKernel( + T* target, // shape [batchxbeam, layer_head_pair_count, max_length, frame] + T** qk_layer_pointers, + int token_index, + int num_layers, + int num_heads, + const int* cross_qk_layer_head_pairs, + int frames, + int max_length +) { + const int pair = blockIdx.x; + const int layer_head_pair_count = gridDim.x; + const int bbm = blockIdx.y; + cross_qk_layer_head_pairs += (pair * 2); + const int layer = *cross_qk_layer_head_pairs; + const int head = *(cross_qk_layer_head_pairs + 1); + + target += ((int64_t)bbm * layer_head_pair_count + pair) * max_length * frames + ((int64_t)token_index * frames); + T* src = qk_layer_pointers[layer] + ((int64_t)bbm * num_heads + head) * frames; + + for (int tid = threadIdx.x; tid < frames; tid += blockDim.x) { + target[tid] = src[tid]; // use vectorized read write in future if needed + } +} + +void LaunchCopyCrossQKSingleDecodeStep( + cudaStream_t stream, + float* cross_qk_buffer_data, + float** qk_layer_pointers, + int token_index, + int batchxbeam, + int num_layers, + int num_heads, + int 
cross_qk_layer_head_pair_count, + const int* cross_qk_layer_head_pairs, + int frames, + int max_length +) { + dim3 block(512); + dim3 grid(cross_qk_layer_head_pair_count, batchxbeam); + typedef typename ToCudaType::MappedType CudaT; + + CopyCrossQKSingleDecodeStepKernel<<>>( + (CudaT*)cross_qk_buffer_data, + (CudaT**)qk_layer_pointers, + token_index, + num_layers, + num_heads, + cross_qk_layer_head_pairs, + frames, + max_length + ); +} + + +template +__global__ void CopyDecoderCrossQKAllStepsKernel( + int context_decoding_len, + int num_beams, + int num_return_sequences, + int max_length, + int frames_of_k, + const T* cross_qk_buffer_data, // [batch, num_beams, layer_head_pair_count, max_length, frames] + T* cross_qk_output, // [batch, num_return_sequences, layer_head_pair_count, total_decoding_length, frames] + const int* cache_indir_data, // [batch, num_beams, max_length] + const int32_t* beam_indices +) { + const int pair = blockIdx.y; + const int layer_head_pair_count = gridDim.y; + const int total_decoding_length = gridDim.x; + const int token_decoding_index = blockIdx.x; + const int br = blockIdx.z; + const int batch = br / num_return_sequences; + const int ret_seq_id = br % num_return_sequences; + + // get the real beam index, as the cache_indir_data did not updated in last token + const int src_beam = beam_indices[batch * num_beams + ret_seq_id] % num_beams; + + const int64_t offset_in_cache = ((int64_t)batch * num_beams + src_beam) * max_length + token_decoding_index + context_decoding_len; + int bm_mapped = ((num_beams <= 1) ? 0: ((token_decoding_index == total_decoding_length - 1) ? ret_seq_id : cache_indir_data[offset_in_cache])); + int bi_src = batch * num_beams + bm_mapped; + + T* target = cross_qk_output + + (((int64_t)br * layer_head_pair_count + (int64_t)pair) * total_decoding_length + token_decoding_index) * frames_of_k; + const T* src = cross_qk_buffer_data + + ((int64_t)bi_src * layer_head_pair_count * max_length + (int64_t)pair * max_length + token_decoding_index) * frames_of_k; + for (int tid = threadIdx.x; tid < frames_of_k; tid += blockDim.x) { + target[tid] = src[tid]; // use vectorized read write in future if needed + } +} + +void LaunchFinalizeCrossQK( + cudaStream_t stream, + int iteration_number, + int context_decoding_len, + int batch_size, + int num_beams, + int max_length, + int cross_qk_layer_head_pair_count, + [[maybe_unused]] const int* cross_qk_layer_head_pairs, + int frames_of_k, + const float* cross_qk_buffer_data, + float* cross_qk_output, + int num_return_sequences, + const int* cache_indir_data, + const int32_t* beam_indices +) { + int64_t br = (int64_t)batch_size * num_return_sequences; + ORT_ENFORCE(br < 65536L && cross_qk_layer_head_pair_count < 65536); + const int total_decoding_length = iteration_number - 1; + dim3 block(512); + dim3 grid(total_decoding_length, cross_qk_layer_head_pair_count, (unsigned)br); + typedef typename ToCudaType::MappedType CudaT; + + CopyDecoderCrossQKAllStepsKernel<<>>( + context_decoding_len, + num_beams, + num_return_sequences, + max_length, + frames_of_k, + (const CudaT*)cross_qk_buffer_data, + (CudaT*)cross_qk_output, + cache_indir_data, + beam_indices); +} + +template +__global__ void ForceDecodingIdsKernel( + float* beam_scores, + const int vocab_size, + const int32_t* force_ids, + int id_len, + int step +) { + const int num_beams = gridDim.y; + const int beam = blockIdx.y; + const int batch = blockIdx.z; + beam_scores += (((int64_t)batch * num_beams + beam)* vocab_size); // move to (batch, beam) + const int32_t 
id_wanted = force_ids[((int64_t)batch * id_len) + step]; + if (id_wanted < 0 || id_wanted >= vocab_size) return; + + const int32_t elements_per_block = (int32_t)blockDim.x * ElementsPerThreads; + const int32_t block_start_id = blockIdx.x * elements_per_block; + + int32_t token_id = block_start_id + (int)threadIdx.x; + #pragma unroll + for (int elem = 0; elem < ElementsPerThreads; elem++) { + if (token_id < vocab_size) { + beam_scores[token_id] = ((token_id == id_wanted) ? 0.0f : cub::FpLimits::Lowest()); + } + token_id += (int)blockDim.x; + } +} + + +void LaunchForceDecodingIds( + float* beam_scores, + const int batch_size, + const int num_beams, + const int vocab_size, + const int32_t* force_ids, + int id_len, + int step, + cudaStream_t stream +) { + dim3 blocks(512); + constexpr int ElementsPerThreads = 4; + unsigned gridx = static_cast((vocab_size + 512 * ElementsPerThreads - 1) / (512 * ElementsPerThreads)); + dim3 grids(gridx, num_beams, batch_size); + ForceDecodingIdsKernel<<>>( + beam_scores, vocab_size, force_ids, id_len, step + ); +} + +template +__global__ void SaveNoSpeechProbsKernel( + T* result_no_speech_probs, + const float* probs, + const int batch_size, + const int num_beams, + const int vocab_size, + const int no_speech_token_id +) { + int b = blockIdx.x * blockDim.x + threadIdx.x; + if (b < batch_size) { + int64_t src_offset = b * num_beams * vocab_size + no_speech_token_id; + result_no_speech_probs[b] = (T)(probs[src_offset]); + } +} + +template +void LaunchSaveNoSpeechProbs( + T* result_no_speech_probs, /* [batch]*/ + const float* probs, /* [batch, num_beams, vocab_size]*/ + const int batch_size, + const int num_beams, + const int vocab_size, + const int no_speech_token_id, + cudaStream_t stream +) { + int tpb = 256; + int bpg = (batch_size + 255) / 256; + + typedef typename ToCudaType::MappedType CudaT; + SaveNoSpeechProbsKernel<<>>( + (CudaT*)result_no_speech_probs, probs, batch_size, num_beams, vocab_size, no_speech_token_id); +} + +template void LaunchSaveNoSpeechProbs( + float* result_no_speech_probs, + const float* probs, + const int batch_size, + const int num_beams, + const int vocab_size, + const int no_speech_token_id, + cudaStream_t stream +); + +template void LaunchSaveNoSpeechProbs( + MLFloat16* result_no_speech_probs, + const float* probs, + const int batch_size, + const int num_beams, + const int vocab_size, + const int no_speech_token_id, + cudaStream_t stream +); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h index 2c3662fb18edd..5ed5949196b29 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h @@ -221,6 +221,56 @@ void ReorderPastStatesKernelLauncher(void* out_buffer, int head_size, int chunk_size, cudaStream_t stream); + +void LaunchCopyCrossQKSingleDecodeStep( + cudaStream_t stream, + float* cross_qk_buffer_data, + float** qk_layer_pointers, + int token_index, + int batchxbeam, + int num_layers, + int num_heads, + int cross_qk_layer_head_pair_count, + const int* cross_qk_layer_head_pairs, + int frames, + int max_length); + +void LaunchFinalizeCrossQK( + cudaStream_t stream, + int iteration_number, + int context_decoding_len, + int batch_size, + int num_beams, + int max_length, + int cross_qk_layer_head_pair_count, + const int* cross_qk_layer_head_pairs, + int frames_of_k, + const float* 
cross_qk_buffer_data, + float* cross_qk_output, + int num_return_sequences, + const int* cache_indir_data, + const int32_t* beam_indices); + +void LaunchForceDecodingIds( + float* beam_scores, + const int batch_size, + const int num_beams, + const int vocab_size, + const int32_t* force_ids, + int id_len, + int step, + cudaStream_t stream); + +template +void LaunchSaveNoSpeechProbs( + T* result_no_speech_probs, /* [batch]*/ + const float* probs, /* [batch, num_beams, vocab_size]*/ + const int batch_size, + const int num_beams, + const int vocab_size, + const int no_speech_token_id, + cudaStream_t stream); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index 121cd05956d5c..380d561bbb23c 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -13,6 +13,8 @@ #include #include "contrib_ops/cuda/transformers/generation_cuda_impl.h" #include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" +#include "contrib_ops/cpu/transformers/logits_processor.h" +#include "contrib_ops/cpu/transformers/generation_shared.h" #include "contrib_ops/cpu/transformers/subgraph_t5_decoder.h" #include "contrib_ops/cpu/transformers/subgraph_gpt.h" #include "contrib_ops/cuda/transformers/beam_search_topk.h" @@ -203,7 +205,7 @@ Status AddToFeeds(Stream* ort_stream, ORT_ENFORCE(total_bytes > 0); cudaStream_t stream = ort_stream ? static_cast(ort_stream->GetHandle()) : nullptr; - auto pinned_buffer = IAllocator::MakeUniquePtr(host_allocator, total_bytes); + auto pinned_buffer = IAllocator::MakeUniquePtr(host_allocator, total_bytes, false, ort_stream); char* pinned_data = static_cast(pinned_buffer.get()); // Copy tensors to one pinned memory buffer (so that we only need copy to GPU once) char* destination = pinned_data; @@ -419,11 +421,21 @@ Status ProcessLogits(const OrtValue& logits, // dumper->Print("next_token_scores after softmax", next_token_scores.data(), batch_size, num_beams, vocab_size); #endif + const bool is_whisper_model = (parameters->model_type == onnxruntime::contrib::transformers::IGenerationParameters::kModelTypeWhisper); + if (step == 1 && is_whisper_model && parameters->no_speech_probs) { + cuda::LaunchSaveNoSpeechProbs( + (T*)parameters->no_speech_probs, Y_data, batch_size, num_beams, vocab_size, parameters->no_speech_token, cuda_stream); + } + + // NOTE: currently we treat extra decoding ids are same + int extra_decoding_len = static_cast(parameters->extra_decoding_ids.size() / parameters->batch_size); + const bool need_handle_extra_decoding_ids = is_whisper_model && (!parameters->extra_decoding_ids.empty()) && (extra_decoding_len >= step); + cuda::LaunchLogitsProcessKernel( next_token_scores.data(), parameters->vocab_mask.data(), - step > 1 ? nullptr : parameters->prefix_vocab_mask.data(), // prefix vocab mask is applied to first step only. - nullptr, // parameters->presence_mask.data(), + (step > extra_decoding_len + 1) ? nullptr : parameters->prefix_vocab_mask.data(), // prefix vocab mask is applied to first step only. + nullptr, // parameters->presence_mask.data(), parameters->presence_penalty, parameters->temperature, parameters->batch_size, @@ -438,6 +450,50 @@ Status ProcessLogits(const OrtValue& logits, // parameters->no_repeat_ngram_size, cuda_stream); + // Whisper time stamp generation. 
+ // TODO: implement it on GPU + bool gen_timestamp = is_whisper_model && + (parameters->logits_processor == onnxruntime::contrib::transformers::IGenerationParameters::kLogitsProcessorTypeWhisper); + if (gen_timestamp) { + // Copy next token scores to cpu memory, copy Sequences to cpu + std::vector cpu_next_token_scores(next_token_scores.size()); + gsl::span cpu_next_token_scores_span(cpu_next_token_scores.data(), cpu_next_token_scores.size()); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cpu_next_token_scores.data(), + next_token_scores.data(), + next_token_scores.size_bytes(), + cudaMemcpyDeviceToHost, + cuda_stream)); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(const_cast(sequences->GetSequence(0).data()), + sequences->GetCurrentDeviceSequences().data(), + sequences->GetSequence(0).size_bytes() * batch_beam_size, + cudaMemcpyDeviceToHost, + cuda_stream)); + constexpr int max_initial_timestamp_index = 50; + onnxruntime::contrib::transformers::TimestampLogitsProcessor time_logit_processor(parameters->eos_token_id, max_initial_timestamp_index); + onnxruntime::contrib::transformers::NextTokenScores next_token_scores_timestamp({cpu_next_token_scores_span, batch_beam_size, vocab_size}); + + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream)); + time_logit_processor.Process(sequences, next_token_scores_timestamp); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(next_token_scores.data(), + cpu_next_token_scores.data(), + next_token_scores.size_bytes(), + cudaMemcpyHostToDevice, + cuda_stream)); + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream)); + } + + if (need_handle_extra_decoding_ids && !parameters->extra_decoding_ids.empty()) { + cuda::LaunchForceDecodingIds( + next_token_scores.data(), + parameters->batch_size, + parameters->num_beams, + parameters->vocab_size, + parameters->extra_decoding_ids.data(), + static_cast(parameters->extra_decoding_ids.size() / parameters->batch_size), + step - 1, + cuda_stream); + } + #ifdef DEBUG_GENERATION dumper->Print("next_token_scores after logits process", next_token_scores.data(), batch_size, num_beams, vocab_size); #endif @@ -800,13 +856,11 @@ Status GreedySearchProcessLogits( // Sequences generated by beam scorer is currently stored in CPU. 
// Copy sequences to device only when repetition penalty or no repeat ngram is used in kernel - BufferUniquePtr sequences_buffer; + IAllocatorUniquePtr sequences_buffer; int current_sequence_length = sequences->GetSequenceLength(); if (parameters->repetition_penalty != 1.0f) { size_t bytes = SafeInt(sizeof(int32_t)) * batch_beam_size * parameters->max_length; - void* data = allocator->Alloc(bytes); - BufferUniquePtr temp_buffer(data, BufferDeleter(allocator)); - sequences_buffer = std::move(temp_buffer); + sequences_buffer = IAllocator::MakeUniquePtr(allocator, bytes, false, stream); CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(sequences_buffer.get(), sequences->GetSequence(0).data(), bytes, cudaMemcpyHostToDevice, cuda_stream)); } @@ -1189,14 +1243,14 @@ Status UpdateDecoderFeeds( if (past_present_share_buffer) { // Update past sequence length input - const ptrdiff_t past_sequence_length_idx = 2 * (static_cast(last_outputs.size()) - t5_decoder_first_present_output_idx) + t5_decoder_first_past_input_idx; + const ptrdiff_t past_sequence_length_idx = 2 * num_present_tensors + t5_decoder_first_past_input_idx; *(next_inputs[past_sequence_length_idx].GetMutable()->MutableData()) = current_length - 1; // Update beam search specific input for DecoderMaskedSelfAttention (cache indirection) if present // If the last input is not `past_sequence_length`, then the beam search specific inputs // for `DecoderMaskedSelfAttention` is present - if (need_cache_indir) { + if (need_cache_indir && num_beams > 1) { ORT_ENFORCE(!beam_indices_gpu.empty(), "Beam indices must be present on CUDA while using DecoderMaskedMultiHeadAttention with BeamSearch"); // The cache indirection feed comes 2 feeds after the `past_sequence_length` feed @@ -1521,6 +1575,93 @@ template Status ExpandBuffer( OrtValue& expanded, bool only_copy_shape, int max_sequence_length); + +Status UpdateDecoderCrossQK( + int iteration_number, + Stream* stream, + OrtValue* cross_qks, + IAllocatorUniquePtr& qk_layer_pointers, + int num_layers, + int cross_qk_layer_head_pair_count, + const int* cross_qk_layer_head_pairs, + float* cross_qk_buffer_data, + int max_length, + AllocatorPtr allocator) { + cudaStream_t cuda_stream = stream ? 
static_cast(stream->GetHandle()) : nullptr; + + if (qk_layer_pointers.get() == nullptr) { + // Put all the qk pointers into gpu, as they did not change in following decoding steps + // also this help to use single kernel to process each step + qk_layer_pointers = IAllocator::MakeUniquePtr(allocator, static_cast(num_layers), false, stream); + std::vector qk_layer_data(num_layers, nullptr); + for (int layer = 0; layer < num_layers; layer++) { + qk_layer_data[layer] = cross_qks[layer].GetMutable()->MutableData(); + } + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync((void*)qk_layer_pointers.get(), qk_layer_data.data(), sizeof(qk_layer_data[0]) * num_layers, + cudaMemcpyHostToDevice, cuda_stream)); + } + + auto cross_qk_layer_shape = cross_qks[0].GetMutable()->Shape(); + int64_t batchxbeam = cross_qk_layer_shape[0]; + int64_t num_heads = cross_qk_layer_shape[1]; + int64_t frames = cross_qk_layer_shape[3]; + + cuda::LaunchCopyCrossQKSingleDecodeStep( + cuda_stream, + cross_qk_buffer_data, + qk_layer_pointers.get(), + iteration_number - 2, + static_cast(batchxbeam), + num_layers, + static_cast(num_heads), + cross_qk_layer_head_pair_count, + cross_qk_layer_head_pairs, + static_cast(frames), + max_length); + + CUDA_RETURN_IF_ERROR(cudaGetLastError()); + + return Status::OK(); +} + +Status FinalizeDecoderCrossQK( + Stream* stream, + int iteration_number, + int context_decoding_len, + int batch_size, + int num_beams, + int max_length, + int cross_qk_layer_head_pair_count, + const int* cross_qk_layer_head_pairs, + int frames_of_k, + const float* cross_qk_buffer_data, + float* cross_qk_output, + int num_return_sequences, + const int* cache_indir_data, + gsl::span beam_indices_gpu) { + cudaStream_t cuda_stream = stream ? static_cast(stream->GetHandle()) : nullptr; + + cuda::LaunchFinalizeCrossQK( + cuda_stream, + iteration_number, + context_decoding_len, + batch_size, + num_beams, + max_length, + cross_qk_layer_head_pair_count, + cross_qk_layer_head_pairs, + frames_of_k, + cross_qk_buffer_data, + cross_qk_output, + num_return_sequences, + cache_indir_data, + beam_indices_gpu.data()); + + CUDA_RETURN_IF_ERROR(cudaGetLastError()); + + return Status::OK(); +} + } // namespace GenerationCudaDeviceHelper } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h index f5f062d7a101b..7a718eb9f66c1 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h @@ -150,6 +150,34 @@ Status ExpandBuffer( bool only_copy_shape, int max_sequence_length = 0); +Status UpdateDecoderCrossQK( + int iteration_number, + Stream* stream, + OrtValue* cross_qks, + IAllocatorUniquePtr& qk_layer_pointers, + int num_layers, + int cross_qk_layer_head_pair_count, + const int* cross_qk_layer_head_pairs, + float* cross_qk_buffer_data, + int max_length, + AllocatorPtr allocator); + +Status FinalizeDecoderCrossQK( + Stream* stream, + int iteration_number, + int context_decoding_len, + int batch_size, + int num_beams, + int max_length, + int cross_qk_layer_head_pair_count, + const int* cross_qk_layer_head_pairs, + int frames_of_k, + const float* cross_qk_buffer_data, + float* cross_qk_output, + int num_return_sequences, + const int* cache_indir_data, + gsl::span beam_indices); + } // namespace GenerationCudaDeviceHelper } // namespace contrib } // namespace onnxruntime diff --git 
a/onnxruntime/contrib_ops/cuda/transformers/greedy_search.cc b/onnxruntime/contrib_ops/cuda/transformers/greedy_search.cc index d9014ca8f5c24..812ab0b1bcae6 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/greedy_search.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/greedy_search.cc @@ -48,10 +48,12 @@ GreedySearch::GreedySearch(const OpKernelInfo& info) SetConsoleDumper(&g_cuda_dumper_greedysearch); +#ifndef USE_ROCM cuda_device_prop_ = &reinterpret_cast(info.GetExecutionProvider())->GetDeviceProp(); cuda_device_arch_ = static_cast(cuda_device_prop_)->major * 100 + static_cast(cuda_device_prop_)->minor * 10; +#endif } Status GreedySearch::ComputeInternal(OpKernelContext* context) const { diff --git a/onnxruntime/contrib_ops/js/fused_conv.cc b/onnxruntime/contrib_ops/js/fused_conv.cc new file mode 100644 index 0000000000000..76402f0681976 --- /dev/null +++ b/onnxruntime/contrib_ops/js/fused_conv.cc @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/js/operators/conv.h" +namespace onnxruntime { +namespace contrib { +namespace js { + +ONNX_OPERATOR_KERNEL_EX( + FusedConv, + kMSDomain, + 1, + kJsExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + onnxruntime::js::Conv); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc index b0540d785badd..498a9f5679eb5 100644 --- a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc @@ -13,6 +13,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MultiH class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasSplitGelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasAdd); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, SkipLayerNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FusedConv); template <> KernelCreateInfo BuildKernelCreateInfo() { @@ -27,7 +28,9 @@ Status RegisterJsContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo}; + BuildKernelCreateInfo, + BuildKernelCreateInfo}; for (auto& function_table_entry : function_table) { KernelCreateInfo info = function_table_entry(); diff --git a/onnxruntime/contrib_ops/js/skip_layer_norm.cc b/onnxruntime/contrib_ops/js/skip_layer_norm.cc index ee315f9b31e3b..f949326e1dc95 100644 --- a/onnxruntime/contrib_ops/js/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/js/skip_layer_norm.cc @@ -7,14 +7,16 @@ namespace onnxruntime { namespace contrib { namespace js { +using onnxruntime::js::JsepSupportedFloatTypes; + ONNX_OPERATOR_KERNEL_EX( SkipLayerNormalization, kMSDomain, 1, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", DataTypeImpl::GetTensorType()) - .TypeConstraint("U", DataTypeImpl::GetTensorType()), + .TypeConstraint("T", JsepSupportedFloatTypes()) + .TypeConstraint("U", JsepSupportedFloatTypes()), SkipLayerNorm); } // namespace js diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh index 246b66078537a..78983ac95e672 100644 --- 
a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh +++ b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh @@ -838,7 +838,7 @@ auto GetCKGemmSoftmaxGemmPermuteTypeStringAndOps() { Nop{}); TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support ", params->Signature()); + impl->GetTypeString(), " does not support the params"); if constexpr (USE_MASK) { ORT_RETURN_IF_ERROR(GemmSoftmaxGemmPermuteTunableOp::LaunchConvertToFilledMaskValue(params)); diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh index cbf24ee2f5487..ea9040aa7875f 100644 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh +++ b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh @@ -58,7 +58,7 @@ auto GetCKGemmAddFastGeluTypeStringAndOps() { auto zero = ToHipType::FromFloat(0.0f); TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( params->alpha != one || params->beta != zero || params->bias == nullptr, - impl->GetTypeString(), " only supports alpha == 1 and beta == 0 and bias != nullptr", params->Signature()); + impl->GetTypeString(), " only supports alpha == 1 and beta == 0 and bias != nullptr"); auto nop = Nop{}; auto addfastgelu = AddFastGelu{}; @@ -67,7 +67,7 @@ auto GetCKGemmAddFastGeluTypeStringAndOps() { params->lda, params->ldb, std::array{0}, params->ldc, nop, nop, addfastgelu); TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support ", params->Signature()); + impl->GetTypeString(), " does not support the params"); invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); return Status::OK(); }; @@ -95,7 +95,7 @@ auto GetCKGemmFastGeluTypeStringAndOps() { TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( params->alpha != one || params->beta != zero || params->bias != nullptr, - impl->GetTypeString(), " only supports alpha == 1 and beta == 0 and bias == nullptr", params->Signature()); + impl->GetTypeString(), " only supports alpha == 1 and beta == 0 and bias == nullptr"); auto nop = Nop{}; auto fastgelu = FastGelu{}; @@ -108,7 +108,7 @@ auto GetCKGemmFastGeluTypeStringAndOps() { params->ldc, nop, nop, fastgelu); TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support ", params->Signature()); + impl->GetTypeString(), " does not support the params"); invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); return Status::OK(); }; diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc b/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc index c665da89af36c..e82e15a304f4c 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc @@ -72,6 +72,12 @@ GroupNorm::GroupNorm(const OpKernelInfo& op_info) : RocmKernel(op_info) { channels_last_ = (op_info.GetAttrOrDefault("channels_last", static_cast(1)) != 0); } +Status GroupNorm::PrePack(const Tensor& /*tensor*/, int /*input_idx*/, AllocatorPtr /*alloc*/, + bool& is_packed, PrePackedWeights* /*prepacked_weights*/) { + is_packed = false; + return Status::OK(); +} + Status GroupNorm::ComputeInternal(OpKernelContext* context) const { const Tensor* input = context->Input(0); const Tensor* gamma = context->Input(1); diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh index 
e87813fb19956..0146e81c6cf8c 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh @@ -79,7 +79,7 @@ auto GetCKGroupNormNHWCTypeStringAndOps() { nullptr, activation); TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support ", params->Signature()); + impl->GetTypeString(), " does not support the params"); invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); return Status::OK(); }; diff --git a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc index 7bc0f99081169..0f8fe68de717a 100644 --- a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc @@ -29,6 +29,14 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, FusedMatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, FusedMatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, FusedMatMul); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, RelativePositionBias); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, RelativePositionBias); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GatedRelativePositionBias); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, GatedRelativePositionBias); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, RemovePadding); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, RemovePadding); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, RestorePadding); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, RestorePadding); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Rfft); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, Rfft); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Rfft); @@ -52,6 +60,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, Affine); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Attention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Attention); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, PackedAttention); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, PackedAttention); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, PackedMultiHeadAttention); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, PackedMultiHeadAttention); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BeamSearch); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, ConvTransposeWithDynamicPads); class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, Crop); @@ -61,12 +73,11 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, MultiHeadAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, DecoderAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, DecoderAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, DecoderMaskedMultiHeadAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedMultiHeadAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, int32_t, DynamicSlice); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, int64_t, DynamicSlice); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, EmbedLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, EmbedLayerNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, GreedySearch); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, GroupNorm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, NhwcConv); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, NhwcConv); @@ -113,6 +124,17 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FastGelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, TransposeMatMul); // backward compatibility class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FusedMatMul); +// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedMatMul); +// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedLayerNormalization); +// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedGelu); +// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QuantizeWithOrder); +// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, DequantizeWithOrder); +// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedAttention); +// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedLongformerAttention); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, DecoderMaskedSelfAttention); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedSelfAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, DecoderMaskedMultiHeadAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedMultiHeadAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GemmFastGelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, GemmFastGelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, 
kMSDomain, 1, BFloat16, GemmFastGelu); @@ -139,6 +161,7 @@ KernelCreateInfo BuildKernelCreateInfo() { return info; } +// clang-format off Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing @@ -162,70 +185,73 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // These ops were experimental ops in onnx domain which have been removed now. We add them here as // contrib ops to maintain backward compatibility - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, BuildKernelCreateInfo, - - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -238,7 +264,6 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, // BuildKernelCreateInfo, - 
BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -249,16 +274,25 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { // BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo - + // BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, // TransposedMatMul is still here for backward compatibility BuildKernelCreateInfo, // backward compatibility BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -278,6 +312,7 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, #endif + }; for (auto& function_table_entry : function_table) { @@ -289,6 +324,7 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { return Status::OK(); } +// clang-format on } // namespace rocm } // namespace contrib diff --git a/onnxruntime/core/common/cpuid_info.cc b/onnxruntime/core/common/cpuid_info.cc index 6a82b3fcc734d..655d5014f3d60 100644 --- a/onnxruntime/core/common/cpuid_info.cc +++ b/onnxruntime/core/common/cpuid_info.cc @@ -22,6 +22,14 @@ #define HWCAP_ASIMDDP (1 << 20) #endif +#ifndef HWCAP2_I8MM +#define HWCAP2_I8MM (1 << 13) +#endif + +#ifndef HWCAP2_SVEI8MM +#define HWCAP2_SVEI8MM (1 << 9) +#endif + #endif // ARM #endif // Linux @@ -138,6 +146,9 @@ void CPUIDInfo::ArmLinuxInit() { is_hybrid_ = cpuinfo_get_uarchs_count() > 1; has_arm_neon_dot_ = cpuinfo_has_arm_neon_dot(); has_fp16_ = cpuinfo_has_arm_neon_fp16_arith(); + has_arm_neon_i8mm_ = cpuinfo_has_arm_i8mm(); + has_arm_sve_i8mm_ = cpuinfo_has_arm_sve() && cpuinfo_has_arm_i8mm(); + const uint32_t core_cnt = cpuinfo_get_cores_count(); core_uarchs_.resize(core_cnt, cpuinfo_uarch_unknown); is_armv8_narrow_ld_.resize(core_cnt, false); @@ -162,6 +173,10 @@ void CPUIDInfo::ArmLinuxInit() { pytorch_cpuinfo_init_ = false; has_arm_neon_dot_ = ((getauxval(AT_HWCAP) & HWCAP_ASIMDDP) != 0); has_fp16_ |= has_arm_neon_dot_; + + has_arm_neon_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_I8MM) != 0); + has_arm_sve_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_SVEI8MM) != 0); + #endif } @@ -256,6 +271,9 @@ void CPUIDInfo::ArmWindowsInit() { has_arm_neon_dot_ = (IsProcessorFeaturePresent(PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE) != 0); has_fp16_ |= has_arm_neon_dot_; + /* TODO: implement them when hw+sw is available for testing these features */ + has_arm_neon_i8mm_ = false; + has_arm_sve_i8mm_ = false; } #endif /* (arm or arm64) and windows */ diff --git a/onnxruntime/core/common/cpuid_info.h b/onnxruntime/core/common/cpuid_info.h index 386db347c669d..a15c75104b83a 100644 --- a/onnxruntime/core/common/cpuid_info.h +++ b/onnxruntime/core/common/cpuid_info.h @@ -28,6 +28,8 @@ class CPUIDInfo { // ARM bool HasArmNeonDot() const { return has_arm_neon_dot_; } + bool HasArmNeon_I8MM() const { return has_arm_neon_i8mm_; } + bool HasArmSVE_I8MM() const { return has_arm_sve_i8mm_; } uint32_t GetCurrentCoreIdx() const; @@ -121,6 +123,8 @@ class CPUIDInfo { bool has_arm_neon_dot_{false}; bool has_fp16_{false}; + bool has_arm_neon_i8mm_{false}; + bool 
has_arm_sve_i8mm_{false}; #ifdef CPUIDINFO_ARCH_X86 diff --git a/onnxruntime/core/framework/fallback_cpu_capability.cc b/onnxruntime/core/framework/fallback_cpu_capability.cc index 3d971e6aa29a2..ef68b88187e08 100644 --- a/onnxruntime/core/framework/fallback_cpu_capability.cc +++ b/onnxruntime/core/framework/fallback_cpu_capability.cc @@ -9,6 +9,7 @@ #include "onnx/defs/data_type_utils.h" #include "core/framework/op_kernel.h" +#include "core/framework/utils.h" using namespace ONNX_NAMESPACE::Utils; @@ -77,7 +78,7 @@ std::unordered_set GetCpuPreferredNodes(const onnxruntime::GraphViewe ORT_THROW_IF_ERROR(node->ForEachWithIndex( node->OutputDefs(), [&](const NodeArg& node_arg, size_t out_index) { - if (kernel_info->kernel_def->IsOutputOnCpu(out_index)) { + if (utils::IsOutputOnCpu(*node, kernel_info, out_index)) { cpu_output_args.insert(&node_arg); auto consumer_nodes = graph.GetConsumerNodes(node_arg.Name()); for (auto& consumer_node : consumer_nodes) { diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index 1b492a3561396..e4fe0c7564548 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -13,7 +13,9 @@ #include "core/framework/kernel_registry_manager.h" #include "core/framework/kernel_registry.h" #include "core/graph/function.h" +#include "core/graph/function_utils.h" #include "core/graph/graph_viewer.h" +#include "core/graph/model.h" // uncomment this line to count non-CUDA ops in ONNX domain // #define COUNT_NON_CUDA_OPS @@ -129,6 +131,21 @@ struct GetCapabilityForEPParams { std::reference_wrapper debug_graph_fn; #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) }; + +auto get_capabilities = [](const IExecutionProvider& ep, + const GraphViewer& graph_viewer, + const IExecutionProvider::IKernelLookup& kernel_lookup) { + auto capabilities = ep.GetCapability(graph_viewer, kernel_lookup); + + // In theory an EP could return an empty capability. Remove those. + capabilities.erase(std::remove_if(capabilities.begin(), capabilities.end(), + [](const std::unique_ptr& capability) { + return !capability || !capability->sub_graph; + }), + capabilities.end()); + + return capabilities; +}; } // namespace static Status GetCapabilityForEP(const GetCapabilityForEPParams& params) { @@ -143,21 +160,6 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params) { } #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - auto get_capabilities = [](const IExecutionProvider& ep, - const GraphViewer& graph_viewer, - const IExecutionProvider::IKernelLookup& kernel_lookup) { - auto capabilities = ep.GetCapability(graph_viewer, kernel_lookup); - - // In theory an EP could return an empty capability. Remove those. - capabilities.erase(std::remove_if(capabilities.begin(), capabilities.end(), - [](const std::unique_ptr& capability) { - return !capability || !capability->sub_graph; - }), - capabilities.end()); - - return capabilities; - }; - const auto& kernel_registry_mgr = params.kernel_registry_mgr.get(); const auto kernel_registries_for_ep = kernel_registry_mgr.GetKernelRegistriesByProviderType(ep_type); const KernelLookup kernel_lookup{ep_type, @@ -239,6 +241,26 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params) { } #if !defined(ORT_MINIMAL_BUILD) + +// This function queries the capabilities for a given EP, but it does not assign the nodes. +// It also does not perform layout transformation. 
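Editorial aside on the cpuid_info changes earlier in this hunk: ARM I8MM support is probed through cpuinfo when that library initializes successfully, and otherwise through the Linux auxiliary vector. A self-contained sketch of the hwcap fallback path for aarch64 Linux (illustrative only; the function names are hypothetical and the fallback bit definitions simply mirror the ones this patch adds for older libc headers):

#include <sys/auxv.h>

#ifndef HWCAP2_I8MM
#define HWCAP2_I8MM (1 << 13)
#endif
#ifndef HWCAP2_SVEI8MM
#define HWCAP2_SVEI8MM (1 << 9)
#endif

// Query the AT_HWCAP2 word exposed by the kernel and test the
// matrix-multiply feature bits.
bool HasNeonI8MM() { return (getauxval(AT_HWCAP2) & HWCAP2_I8MM) != 0; }
bool HasSveI8MM() { return (getauxval(AT_HWCAP2) & HWCAP2_SVEI8MM) != 0; }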
This will be done during normal partitioning. +static Status GetCapabilityForEPForAotInlining(const GraphViewer& graph_viewer, + const KernelRegistryManager& kernel_registry_mgr, + const IExecutionProvider& current_ep, + std::vector>& capabilities) { + const auto& ep_type = current_ep.Type(); + + const auto kernel_registries_for_ep = kernel_registry_mgr.GetKernelRegistriesByProviderType(ep_type); + const KernelLookup kernel_lookup{ep_type, + kernel_registries_for_ep, + kernel_registry_mgr.GetKernelTypeStrResolver()}; + + // TODO: Provide EP with a capability to look inside the functions. + capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup); + + return Status::OK(); +} + /** * Check if a node can be placed on a specific provider. * Do nothing if the node is already assigned @@ -518,7 +540,7 @@ static Status InlineNodes(Graph& graph, bool& modified_graph) { // successfully inlined, we re-run the partitioner on the modified graph. // NOTE: Inlining the function will change the nodes in the Graph instance, so we can't do that while iterating // using graph.Nodes(). - std::vector nodes_to_inline; + InlinedVector nodes_to_inline; for (auto& node : graph.Nodes()) { if (node.GetExecutionProviderType().empty() && node.CanBeInlined()) { nodes_to_inline.push_back(&node); @@ -533,6 +555,85 @@ static Status InlineNodes(Graph& graph, bool& modified_graph) { return Status::OK(); } +static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_providers, + const KernelRegistryManager& kernel_registry_mgr, + Graph& graph, + InlinedHashSet& not_inlined, + size_t& inlined_count) { + // handle testing edge case where optimizers or constant lifting results in graph with no nodes. + // doing it here saves all providers checking for this in GetCapability + if (graph.NumberOfNodes() == 0) { + return Status::OK(); + } + + for (auto& node : graph.Nodes()) { + for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) { + Graph* subgraph = entry.second; + // we pass through the FuncManager from the top level graph + ORT_RETURN_IF_ERROR(InlineFunctionsAOTImpl(execution_providers, + kernel_registry_mgr, + *subgraph, + not_inlined, + inlined_count)); + } + } + + // Gather the candidates + InlinedVector inline_candidates; + for (auto& node : graph.Nodes()) { + if (node.CanBeInlined()) { + inline_candidates.push_back(node.Index()); + } + } + + if (inline_candidates.empty()) { + return Status::OK(); + } + + // Find out all the nodes that are already taken + const GraphViewer graph_viewer(graph); + + InlinedHashSet claimed_by_ep; + for (const auto& ep : execution_providers) { + std::vector> capabilities; + ORT_RETURN_IF_ERROR(GetCapabilityForEPForAotInlining(graph_viewer, kernel_registry_mgr, *ep, capabilities)); + for (auto& capability : capabilities) { + const auto& nodes = capability->sub_graph->nodes; + if (nodes.size() == 1) { + // Single node capability. + ORT_IGNORE_RETURN_VALUE(claimed_by_ep.insert(nodes[0])); + } else { + // Make sure none is claimed by other EPs mirroring the logic in PartitionOnnxFormatModelImpl. + if (std::all_of(nodes.cbegin(), nodes.cend(), [&claimed_by_ep](NodeIndex node_index) { + return claimed_by_ep.count(node_index) == 0; + })) { + claimed_by_ep.insert(nodes.cbegin(), nodes.cend()); + } + } + } + } + + // TODO: Insert version check. We need to collect all the versions + // that imported by the model. If the version is not supported by + // the model, we can not inline it. 
+ + for (auto node_index : inline_candidates) { + auto* node = graph.GetNode(node_index); + if (node != nullptr) { + if (claimed_by_ep.count(node_index) == 0) { + ORT_RETURN_IF_ERROR(graph.InlineFunction(*node)); + ++inlined_count; + } else { + // OpType is the same as function name. + auto function_id = function_utils::GetFunctionIdentifier(node->Domain(), node->OpType()); + ORT_IGNORE_RETURN_VALUE(not_inlined.insert(std::move(function_id))); + } + } + } + + return Status::OK(); +} + static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, GraphPartitioner::Mode mode, const ExecutionProviders& execution_providers, KernelRegistryManager& kernel_registry_manager) { @@ -693,6 +794,50 @@ static Status PartitionOrtFormatModel(const PartitionParams& partition_params, return Status::OK(); } +#ifndef ORT_MINIMAL_BUILD + +Status GraphPartitioner::InlineFunctionsAOT(Model& model, + const ExecutionProviders& execution_providers, + const KernelRegistryManager& kernel_registry_manager, + const logging::Logger& logger) const { + const auto local_functions_num = model.GetModelLocalFunctionTemplates().size(); + const bool is_there_local_functions = local_functions_num > 0; + + if (!is_there_local_functions) { + LOGS(logger, INFO) << "This model does not have any local functions defined. AOT Inlining is not performed"; + return Status::OK(); + } + + auto& graph = model.MainGraph(); + InlinedHashSet not_inlined; + do { + size_t inlined_count = 0; + ORT_RETURN_IF_ERROR(InlineFunctionsAOTImpl(execution_providers, + kernel_registry_manager, + graph, + not_inlined, + inlined_count)); + + if (inlined_count == 0) { + break; + } + + ORT_RETURN_IF_ERROR(graph.Resolve()); + } while (true); + + model.RemoveLocalFunctionsProtos(not_inlined); + + LOGS(logger, INFO) + << "AOT inlining completed. (" << (local_functions_num - model.GetModelLocalFunctionTemplates().size()) + << ") functions of (" + << local_functions_num + << ") pruned."; + + return Status::OK(); +} + +#endif + Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, const layout_transformation::TransformLayoutFunction& transform_layout_function, Mode mode, diff --git a/onnxruntime/core/framework/graph_partitioner.h b/onnxruntime/core/framework/graph_partitioner.h index 36a27e906c651..4fc85c2588260 100644 --- a/onnxruntime/core/framework/graph_partitioner.h +++ b/onnxruntime/core/framework/graph_partitioner.h @@ -12,6 +12,7 @@ namespace onnxruntime { class ExecutionProviders; class KernelRegistryManager; +class Model; class GraphPartitioner { public: @@ -33,6 +34,28 @@ class GraphPartitioner { Mode mode = Mode::kNormal, const layout_transformation::DebugGraphFn& debug_graph_fn = {}) const; +#ifndef ORT_MINIMAL_BUILD + ///

+ // Ahead of Time Function inlining. The main purpose of the function is to inline as many + // functions as possible and delete locally defined functions to reduce the size of the model. + // This makes other optimizations more effective. + // + // This function performs GetCapability on the graph and its subgraphs bottom up + // and inlines any functions that are not claimed by any of the execution providers. + // This function does not attempt to run layout transformation, and it does not assign EPs. + // The latter will be done by graph partitioning after Level1 optimizations are done. + /// + /// model instance + /// execution providers considered + /// registry manager + /// session logger + /// + Status InlineFunctionsAOT(Model& model, + const ExecutionProviders& execution_providers, + const KernelRegistryManager& kernel_registry_manager, + const logging::Logger& logger) const; +#endif + private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GraphPartitioner); diff --git a/onnxruntime/core/framework/kernel_registry_manager.cc b/onnxruntime/core/framework/kernel_registry_manager.cc index c4eef5b27c1bb..b2ef853119588 100644 --- a/onnxruntime/core/framework/kernel_registry_manager.cc +++ b/onnxruntime/core/framework/kernel_registry_manager.cc @@ -62,8 +62,13 @@ Status KernelRegistryManager::SearchKernelRegistry(const Node& node, auto create_error_message = [&node, &status](const std::string& prefix) { std::ostringstream errormsg; - errormsg << prefix << node.OpType() << "(" << node.SinceVersion() << ")"; - errormsg << " (node:'" << node.Name() << "' ep:'" << node.GetExecutionProviderType() << "'). "; + errormsg << prefix; + const auto& domain = node.Domain(); + if (!domain.empty()) { + errormsg << domain << "."; + } + errormsg << node.OpType() << "(" << node.SinceVersion() << ")" + << " (node:'" << node.Name() << "' ep:'" << node.GetExecutionProviderType() << "'). "; if (!status.IsOK()) errormsg << status.ErrorMessage(); diff --git a/onnxruntime/core/framework/tensor_type_and_shape.cc b/onnxruntime/core/framework/tensor_type_and_shape.cc index f3e1acbbe523d..6e11bfe1ac8ea 100644 --- a/onnxruntime/core/framework/tensor_type_and_shape.cc +++ b/onnxruntime/core/framework/tensor_type_and_shape.cc @@ -85,6 +85,16 @@ ORT_API_STATUS_IMPL(OrtApis::GetSymbolicDimensions, return nullptr; } +ORT_API_STATUS_IMPL(OrtApis::SetSymbolicDimensions, + _In_ struct OrtTensorTypeAndShapeInfo* info, + _In_ const char** names, _In_ size_t dim_params_length) { + info->dim_params.clear(); + for (size_t idx = 0; idx < dim_params_length; ++idx) { + info->dim_params.push_back(names[idx]); + } + return nullptr; +} + ORT_API_STATUS_IMPL(OrtApis::GetTensorShapeElementCount, _In_ const OrtTensorTypeAndShapeInfo* this_ptr, _Out_ size_t* out) { API_IMPL_BEGIN diff --git a/onnxruntime/core/framework/tunable.h b/onnxruntime/core/framework/tunable.h index 96b4cc53a022c..6d2dd641f6bc6 100644 --- a/onnxruntime/core/framework/tunable.h +++ b/onnxruntime/core/framework/tunable.h @@ -232,14 +232,15 @@ class TunableOp { return timer.Duration() / num_iter; } - static bool IsSupported(Op& op, const ParamsT* param) { - Status status = op.IsSupported(param); + // Filter the Status: only OK and TUNABLE_OP_UNSUPPORTED are returned; any other error status is thrown and + // handled by onnxruntime. We return the Status so the caller can log the reason, avoiding construction of the op and params signature strings here.
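Editorial aside on the comment just above: a standalone sketch of that filtering contract, using stand-in types (the real code uses onnxruntime::common::Status and ORT_THROW_IF_ERROR; the enum, struct, and function names here are illustrative):

#include <stdexcept>

enum class Code { OK, INVALID_ARGUMENT, FAIL };

struct Status {
  Code code = Code::OK;
  bool IsOK() const { return code == Code::OK; }
};

// Pass through OK and "unsupported argument" so the tuning loop can skip the
// candidate and log the reason lazily; anything else is a real error and is
// surfaced immediately.
Status FilterIsSupportedStatus(Status status) {
  if (status.IsOK() || status.code == Code::INVALID_ARGUMENT) {
    return status;
  }
  throw std::runtime_error("unexpected failure while probing a tunable-op candidate");
}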
+ static Status IsSupported(Op& op, const ParamsT* params) { + Status status = op.IsSupported(params); if (status.Category() == common::StatusCategory::NONE && status.Code() == common::StatusCode::INVALID_ARGUMENT) { - LOGS_DEFAULT(VERBOSE) << "unsupported reason: " << status.ErrorMessage(); - return false; + return status; } ORT_THROW_IF_ERROR(status); - return true; + return status; } protected: @@ -250,9 +251,9 @@ class TunableOp { int FindFastestImpl(const ParamsT* params, const std::vector>& candidates) { ITuningContext* ctx = params->TuningContext(); auto op_sig = Signature(); - auto param_sig = params->Signature(); - LOGS_DEFAULT(VERBOSE) << "FindFastestImpl for " << op_sig << '(' << param_sig << ')'; - auto min_time = std::numeric_limits::infinity(); + auto params_sig = params->Signature(); + LOGS_DEFAULT(VERBOSE) << "finding fastest for " << op_sig << '(' << params_sig << ')'; + auto min_duration_ms = std::numeric_limits::infinity(); int id = -1; constexpr const int max_tuning_iter = 100; @@ -260,30 +261,32 @@ class TunableOp { for (size_t i = 0; i < candidates.size(); i++) { auto& candidate = const_cast&>(candidates[i]); - if (!IsSupported(candidate, params)) { - LOGS_DEFAULT(VERBOSE) << "FindFastestImpl found unsupported " << op_sig << '(' << param_sig << ") id=" << i; + auto status = IsSupported(candidate, params); + if (!status.IsOK()) { + LOGS_DEFAULT(VERBOSE) << "├──unsupported id=" << i << ", " << op_sig << '(' << params_sig << ")"; + LOGS_DEFAULT(VERBOSE) << "│ reason: " << status.ErrorMessage(); continue; } WarmUp(candidate, params); auto approx_duration = Profile(candidate, params, approx_num_iter); - if (approx_duration > 2 * min_time) { - LOGS_DEFAULT(VERBOSE) << "FindFastestImpl skip slow instance " << op_sig << '(' << param_sig << ") id=" << i; + if (approx_duration > 2 * min_duration_ms) { + LOGS_DEFAULT(VERBOSE) << "├──skip slow instance id=" << i; continue; } int tuning_iter = std::max(1, int(std::min(double(max_tuning_iter), ctx->GetMaxTuningDurationMs() / approx_duration))); - LOGS_DEFAULT(VERBOSE) << "FindFastestImpl run instance " << op_sig << '(' << param_sig << ") id=" << i << " " << tuning_iter << " times."; - - auto time = Profile(candidate, params, tuning_iter); - if (time < min_time) { - min_time = time; + auto duration_ms = Profile(candidate, params, tuning_iter); + if (duration_ms < min_duration_ms) { + LOGS_DEFAULT(VERBOSE) << "├──found better instance, new best id=" << i << ", old id=" << id << ". 
" + << duration_ms << "ms, " << tuning_iter << " iters."; + min_duration_ms = duration_ms; id = static_cast(i); } } ORT_ENFORCE(id >= 0, "Could not find viable op"); - LOGS_DEFAULT(VERBOSE) << "FindFastestImpl for " << op_sig << '(' << param_sig << ") found fastest with id=" << id; + LOGS_DEFAULT(VERBOSE) << "└──found fastest with id=" << id << " for " << op_sig << '(' << params_sig << ")"; std::this_thread::sleep_for(std::chrono::milliseconds(50)); return id; } diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index d63881ab4ff04..23fe5e1cd3d96 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -1025,7 +1025,32 @@ bool IsInputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index) overload_name = attrs.at("overload_name").s(); } - return !contrib::aten_ops::ATenOperatorExecutor::Instance().IsTensorArgument(op_name, overload_name, index); + return contrib::aten_ops::ATenOperatorExecutor::Instance().IsCpuArgument(op_name, overload_name, index, true); + } +#else + ORT_UNUSED_PARAMETER(node); +#endif + + return false; +} + +bool IsOutputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index) { + if (p_kci && p_kci->kernel_def->IsOutputOnCpu(index)) { + return true; + } + +#ifdef ENABLE_ATEN + if (node.GetExecutionProviderType() == kCudaExecutionProvider && node.OpType() == "ATen" && + node.Domain() == kPytorchAtenDomain) { + const auto& attrs = node.GetAttributes(); + ORT_ENFORCE(utils::HasString(attrs.at("operator"))); + std::string op_name = attrs.at("operator").s(); + std::string overload_name = ""; + if (attrs.find("overload_name") != attrs.end() && utils::HasString(attrs.at("overload_name"))) { + overload_name = attrs.at("overload_name").s(); + } + + return contrib::aten_ops::ATenOperatorExecutor::Instance().IsCpuArgument(op_name, overload_name, index, false); } #else ORT_UNUSED_PARAMETER(node); diff --git a/onnxruntime/core/framework/utils.h b/onnxruntime/core/framework/utils.h index ea6a629f87cb8..f0b1b9109d405 100644 --- a/onnxruntime/core/framework/utils.h +++ b/onnxruntime/core/framework/utils.h @@ -121,6 +121,7 @@ common::Status ExecuteSubgraph(const SessionState& session_state, const FeedsFet bool sync_subgraph_fetches = false); bool IsInputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index); +bool IsOutputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index); template constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() { diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index e8d4785adf429..5bc18a4e69b47 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -171,10 +171,7 @@ void MultiHeadAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& c *output_shape.add_dim() = query_dims[1]; *output_shape.add_dim() = query_dims[2] * query_dims[4]; updateOutputShape(ctx, 0, output_shape); - return; - } - - if (hasInputShape(ctx, 2)) { + } else if (hasInputShape(ctx, 2)) { auto& value_shape = getInputShape(ctx, 2); auto& value_dims = value_shape.dim(); if (value_dims.size() != 3 && value_dims.size() != 4) { @@ -192,10 +189,7 @@ void MultiHeadAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& c ? (dmmha_packing ? 
value_dims[2] / 3 : value_dims[2]) : value_dims[1] * value_dims[3]; updateOutputShape(ctx, 0, output_shape); - return; - } - - if (hasInputShape(ctx, 1)) { + } else if (hasInputShape(ctx, 1)) { auto& key_shape = getInputShape(ctx, 1); if (key_shape.dim().size() == 5) { // packed KV ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput(ctx); @@ -217,7 +211,7 @@ void MultiHeadAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& c propagateElemTypeFromInputToOutput(ctx, static_cast(past_key_index) + 1, 2); } else { if (sequence_length > 0 && past_dims[2].has_dim_value()) { - int64_t total_sequence_length = sequence_length + past_shape.dim(3).dim_value(); + int64_t total_sequence_length = sequence_length + past_dims[2].dim_value(); ONNX_NAMESPACE::TensorShapeProto present_shape; for (auto& dim : past_dims) { @@ -799,6 +793,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "The value to be filled in the attention mask. Default value is -10000.0f", AttributeProto::FLOAT, OPTIONAL_VALUE) + .Attr("output_qk", + "Need output the cross attention MatMul(Q, K)", + AttributeProto::INT, + OPTIONAL_VALUE) .Input(0, "query", "Query with shape (batch_size, 1, hidden_size) or packed QKV with shape " @@ -890,6 +888,12 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "while effective_seq_length = (past_sequence_length + kv_sequence_length).", "T", OpSchema::Optional) + .Output(3, + "qk", + "normalized Q * K, of shape (batch_size, num_heads, 1, head_size). ", + "V", + OpSchema::Optional) + .TypeConstraint("V", {"tensor(float)"}, "Constrain qk output types to float32 tensors.") .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.") @@ -942,7 +946,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA( OpSchema::Optional) .Input(4, "key_padding_mask", - "Key padding mask with shape (batch_size) or (3 * batch_size + 2) or (batch_size, kv_sequence_length)", + "Key padding mask with shape (batch_size), (3 * batch_size + 2), (batch_size, kv_sequence_length), (batch_size, total_sequence_length), " + "or (batch_size, sequence_length, total_sequence_length)", "M", OpSchema::Optional) .Input(5, @@ -1046,15 +1051,13 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "present state key with support for format BSNH or BNSH. When past_key uses same tensor as present_key" "(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +" "kv_sequence_length.", - "T", - OpSchema::Optional) + "T") .Output(2, "present_value", "present state value with support for format BSNH or BNSH. When past_value uses same tensor as present_value" "(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +" "kv_sequence_length.", - "T", - OpSchema::Optional) + "T") .TypeConstraint("T", {"tensor(float16)"}, "Constrain input and output to float tensors.") .TypeConstraint("M", {"tensor(int32)", "tensor(int64)"}, "Constrain past sequence length to int tensor.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { @@ -1125,6 +1128,49 @@ ONNX_MS_OPERATOR_SET_SCHEMA( DecoderAttentionTypeAndShapeInference(ctx); })); +constexpr const char* RotaryEmbedding_ver1_doc = R"DOC( +RotaryEmbedding is the implementation of rotary positional embeddings (RoPE). The positions are represented as rotation matrices +that are multiplied to query and key before the inner product of query and key is taken. +)DOC"; +ONNX_MS_OPERATOR_SET_SCHEMA( + RotaryEmbedding, 1, + OpSchema() + .SetDoc(RotaryEmbedding_ver1_doc) + .Attr("scale", + "Custom scale will be used if specified. 
Default value is 1.0", + AttributeProto::FLOAT, + OPTIONAL_VALUE) + .Attr("interleaved", + "Rotate using interleaved pattern. Default value is 0 (False).", + AttributeProto::INT, + OPTIONAL_VALUE) + .Input(0, + "input", + "3D tensor with shape (batch_size, sequence_length, hidden_size)", + "T") + .Input(1, + "position_ids", + "1D tensor with shape (1) or 2D tensor with shape (batch_size, sequence_length)", + "M") + .Input(2, + "cos_cache", + "2D tensor with shape (max_sequence_length, head_size / 2).", + "T") + .Input(3, + "sin_cache", + "2D tensor with shape (max_sequence_length, head_size / 2).", + "T") + .Output(0, + "output", + "3D tensor with shape (batch_size, sequence_length, hidden_size)", + "T") + .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.") + .TypeConstraint("M", {"tensor(int64)"}, "Constrain input and output types to integer tensors") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 0, 0); + propagateShapeFromInputToOutput(ctx, 0, 0); + })); + constexpr const char* EmbedLayerNormalization_ver1_doc = R"DOC( EmbedLayerNormalization is the fusion of embedding layer in BERT model, with optional mask processing. The embedding layer takes input_ids (word IDs) and segment_ids (sentence IDs) to look up word_embedding, position_embedding, @@ -1496,4 +1542,4 @@ ONNX_MS_OPERATOR_SET_SCHEMA( })); } // namespace contrib -} // namespace onnxruntime +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/graph/contrib_ops/collective_defs.cc b/onnxruntime/core/graph/contrib_ops/collective_defs.cc index 7cdd71014c02e..59adfc523c860 100644 --- a/onnxruntime/core/graph/contrib_ops/collective_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/collective_defs.cc @@ -83,17 +83,26 @@ void RegisterCollectiveOps() { ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedMatMul) .SetDomain(kMSDomain) .SinceVersion(1) - .Attr("device_mesh_elements", - "", - AttributeProto::INTS) - .Attr("device_mesh_shape", - "", - AttributeProto::INTS) + .Attr("input_device_mesh_elements", + "device_mesh_elements[i] defines the device mesh's value for the i-th input. " + "E.g., device_mesh_elements=[\"[0, 1]\", \"[0, 1]\"] means the 1st and the 2nd " + " inputs are stored on the 0-th and the 1st devices, respectively.", + AttributeProto::STRINGS) + .Attr("input_device_mesh_shapes", + "device_mesh_shape[i] defines the device mesh's shape for the i-th input.", + AttributeProto::STRINGS) .Attr("input_shard_specs", - "The sharding spec of \"Y\"; e.g., \"RRR\" if Y is not sharded.", + "The sharding spec of inputs. 
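Editorial aside, stepping back to the RotaryEmbedding schema introduced above: the cos_cache and sin_cache rows hold the rotation coefficients for each position, and each query/key head vector is rotated pairwise before the attention dot product. A minimal scalar sketch of that rotation for the non-interleaved layout (illustrative only; the function name is hypothetical and the actual kernel layout, batching, and interleaved variant differ):

// Rotate one head vector of length head_size using the cached coefficients for
// its position. Pairs are (x[j], x[j + head_size/2]) in the non-interleaved
// layout; the interleaved layout would instead pair (x[2j], x[2j+1]).
void ApplyRotaryEmbedding(const float* x, float* out, int head_size,
                          const float* cos_row, const float* sin_row) {
  const int half = head_size / 2;
  for (int j = 0; j < half; ++j) {
    const float x1 = x[j];
    const float x2 = x[j + half];
    out[j] = x1 * cos_row[j] - x2 * sin_row[j];
    out[j + half] = x1 * sin_row[j] + x2 * cos_row[j];
  }
}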
" + "E.g., if input_shard_specs[i] is \"RRR\", the i-th input is a unsharded 3-D tensor.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_elements", + "Similar to input_device_mesh_elments but for outputs.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_shapes", + "Similar to input_device_mesh_shapes but for outputs.", AttributeProto::STRINGS) .Attr("output_shard_specs", - "The sharding spec of \"Y\"; e.g., \"RRR\" if Y is not sharded.", + "Similar to input_shard_specs but for outputs.", AttributeProto::STRINGS) .Input(0, "A", "N-dimensional matrix A", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) .Input(1, "B", "N-dimensional matrix B", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) @@ -109,17 +118,26 @@ void RegisterCollectiveOps() { ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedSlice) .SetDomain(kMSDomain) .SinceVersion(1) - .Attr("device_mesh_elements", - "", - AttributeProto::INTS) - .Attr("device_mesh_shape", - "", - AttributeProto::INTS) + .Attr("input_device_mesh_elements", + "device_mesh_elements[i] defines the device mesh's value for the i-th input. " + "E.g., device_mesh_elements=[\"[0, 1]\", \"[0, 1]\"] means the 1st and the 2nd " + " inputs are stored on the 0-th and the 1st devices, respectively.", + AttributeProto::STRINGS) + .Attr("input_device_mesh_shapes", + "device_mesh_shape[i] defines the device mesh's shape for the i-th input.", + AttributeProto::STRINGS) .Attr("input_shard_specs", - "The sharding spec of \"Y\"; e.g., \"RRR\" if Y is not sharded.", + "The sharding spec of inputs. " + "E.g., if input_shard_specs[i] is \"RRR\", the i-th input is a unsharded 3-D tensor.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_elements", + "Similar to input_device_mesh_elments but for outputs.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_shapes", + "Similar to input_device_mesh_shapes but for outputs.", AttributeProto::STRINGS) .Attr("output_shard_specs", - "The sharding spec of \"Y\"; e.g., \"RRR\" if Y is not sharded.", + "Similar to input_shard_specs but for outputs.", AttributeProto::STRINGS) .Input( 0, @@ -173,6 +191,285 @@ void RegisterCollectiveOps() { .Output(0, "output", "Sliced data tensor.", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensor types.") .TypeConstraint("Tind", {"tensor(int32)", "tensor(int64)"}, "Constrain indices to integer types"); + + ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedReshape) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Attr("input_device_mesh_elements", + "device_mesh_elements[i] defines the device mesh's value for the i-th input. " + "E.g., device_mesh_elements=[\"[0, 1]\", \"[0, 1]\"] means the 1st and the 2nd " + " inputs are stored on the 0-th and the 1st devices, respectively.", + AttributeProto::STRINGS) + .Attr("input_device_mesh_shapes", + "device_mesh_shape[i] defines the device mesh's shape for the i-th input.", + AttributeProto::STRINGS) + .Attr("input_shard_specs", + "The sharding spec of inputs. 
" + "E.g., if input_shard_specs[i] is \"RRR\", the i-th input is a unsharded 3-D tensor.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_elements", + "Similar to input_device_mesh_elments but for outputs.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_shapes", + "Similar to input_device_mesh_shapes but for outputs.", + AttributeProto::STRINGS) + .Attr("output_shard_specs", + "Similar to input_shard_specs but for outputs.", + AttributeProto::STRINGS) + .Attr( + "allowzero", + "(Optional) By default, when any value in the 'shape' input is equal to zero " + "the corresponding dimension value is copied from the input tensor dynamically. " + "allowzero=1 indicates that if any value in the 'shape' input is set to zero, " + "the zero value is honored, similar to NumPy.", + AttributeProto::INT, + static_cast(0)) + .Input(0, "data", "An input tensor.", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .Input( + 1, + "shape", + "Specified shape for output.", + "tensor(int64)", + OpSchema::Single, + true, + 1, + OpSchema::NonDifferentiable) + .Output(0, "reshaped", "Reshaped data.", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensor types."); + + ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedExpand) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Attr("input_device_mesh_elements", + "device_mesh_elements[i] defines the device mesh's value for the i-th input. " + "E.g., device_mesh_elements=[\"[0, 1]\", \"[0, 1]\"] means the 1st and the 2nd " + " inputs are stored on the 0-th and the 1st devices, respectively.", + AttributeProto::STRINGS) + .Attr("input_device_mesh_shapes", + "device_mesh_shape[i] defines the device mesh's shape for the i-th input.", + AttributeProto::STRINGS) + .Attr("input_shard_specs", + "The sharding spec of inputs. " + "E.g., if input_shard_specs[i] is \"RRR\", the i-th input is a unsharded 3-D tensor.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_elements", + "Similar to input_device_mesh_elments but for outputs.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_shapes", + "Similar to input_device_mesh_shapes but for outputs.", + AttributeProto::STRINGS) + .Attr("output_shard_specs", + "Similar to input_shard_specs but for outputs.", + AttributeProto::STRINGS) + .Input(0, "input", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .Input( + 1, + "shape", + "A 1-D tensor indicates the shape you want to expand to, following the broadcast rule", + "tensor(int64)", + OpSchema::Single, + true, + 1, + OpSchema::NonDifferentiable) + .Output(0, "output", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensors."); + + ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedReduceSum) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Attr("input_device_mesh_elements", + "device_mesh_elements[i] defines the device mesh's value for the i-th input. " + "E.g., device_mesh_elements=[\"[0, 1]\", \"[0, 1]\"] means the 1st and the 2nd " + " inputs are stored on the 0-th and the 1st devices, respectively.", + AttributeProto::STRINGS) + .Attr("input_device_mesh_shapes", + "device_mesh_shape[i] defines the device mesh's shape for the i-th input.", + AttributeProto::STRINGS) + .Attr("input_shard_specs", + "The sharding spec of inputs. 
" + "E.g., if input_shard_specs[i] is \"RRR\", the i-th input is a unsharded 3-D tensor.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_elements", + "Similar to input_device_mesh_elments but for outputs.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_shapes", + "Similar to input_device_mesh_shapes but for outputs.", + AttributeProto::STRINGS) + .Attr("output_shard_specs", + "Similar to input_shard_specs but for outputs.", + AttributeProto::STRINGS) + .Attr("keepdims", + "Keep the reduced dimension or not, default 1 mean keep reduced dimension.", + AttributeProto::INT, + static_cast(1)) + .Input(0, "input", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .Input( + 1, + "shape", + "A 1-D tensor indicates the shape you want to expand to, following the broadcast rule", + "tensor(int64)", + OpSchema::Single, + true, + 1, + OpSchema::NonDifferentiable) + .Output(0, "output", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensors."); + + ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedReduceMax) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Attr("input_device_mesh_elements", + "device_mesh_elements[i] defines the device mesh's value for the i-th input. " + "E.g., device_mesh_elements=[\"[0, 1]\", \"[0, 1]\"] means the 1st and the 2nd " + " inputs are stored on the 0-th and the 1st devices, respectively.", + AttributeProto::STRINGS) + .Attr("input_device_mesh_shapes", + "device_mesh_shape[i] defines the device mesh's shape for the i-th input.", + AttributeProto::STRINGS) + .Attr("input_shard_specs", + "The sharding spec of inputs. " + "E.g., if input_shard_specs[i] is \"RRR\", the i-th input is a unsharded 3-D tensor.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_elements", + "Similar to input_device_mesh_elments but for outputs.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_shapes", + "Similar to input_device_mesh_shapes but for outputs.", + AttributeProto::STRINGS) + .Attr("output_shard_specs", + "Similar to input_shard_specs but for outputs.", + AttributeProto::STRINGS) + .Attr("keepdims", + "Keep the reduced dimension or not, default 1 mean keep reduced dimension.", + AttributeProto::INT, + static_cast(1)) + .Input(0, "input", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .Input( + 1, + "shape", + "A 1-D tensor indicates the shape you want to expand to, following the broadcast rule", + "tensor(int64)", + OpSchema::Single, + true, + 1, + OpSchema::NonDifferentiable) + .Output(0, "output", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensors."); + + ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedReduceMean) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Attr("input_device_mesh_elements", + "device_mesh_elements[i] defines the device mesh's value for the i-th input. " + "E.g., device_mesh_elements=[\"[0, 1]\", \"[0, 1]\"] means the 1st and the 2nd " + " inputs are stored on the 0-th and the 1st devices, respectively.", + AttributeProto::STRINGS) + .Attr("input_device_mesh_shapes", + "device_mesh_shape[i] defines the device mesh's shape for the i-th input.", + AttributeProto::STRINGS) + .Attr("input_shard_specs", + "The sharding spec of inputs. 
" + "E.g., if input_shard_specs[i] is \"RRR\", the i-th input is a unsharded 3-D tensor.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_elements", + "Similar to input_device_mesh_elments but for outputs.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_shapes", + "Similar to input_device_mesh_shapes but for outputs.", + AttributeProto::STRINGS) + .Attr("output_shard_specs", + "Similar to input_shard_specs but for outputs.", + AttributeProto::STRINGS) + .Attr("keepdims", + "Keep the reduced dimension or not, default 1 mean keep reduced dimension.", + AttributeProto::INT, + static_cast(1)) + .Input(0, "input", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .Input( + 1, + "shape", + "A 1-D tensor indicates the shape you want to expand to, following the broadcast rule", + "tensor(int64)", + OpSchema::Single, + true, + 1, + OpSchema::NonDifferentiable) + .Output(0, "output", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensors."); + + ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedUnsqueeze) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Attr("input_device_mesh_elements", + "device_mesh_elements[i] defines the device mesh's value for the i-th input. " + "E.g., device_mesh_elements=[\"[0, 1]\", \"[0, 1]\"] means the 1st and the 2nd " + " inputs are stored on the 0-th and the 1st devices, respectively.", + AttributeProto::STRINGS) + .Attr("input_device_mesh_shapes", + "device_mesh_shape[i] defines the device mesh's shape for the i-th input.", + AttributeProto::STRINGS) + .Attr("input_shard_specs", + "The sharding spec of inputs. " + "E.g., if input_shard_specs[i] is \"RRR\", the i-th input is a unsharded 3-D tensor.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_elements", + "Similar to input_device_mesh_elments but for outputs.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_shapes", + "Similar to input_device_mesh_shapes but for outputs.", + AttributeProto::STRINGS) + .Attr("output_shard_specs", + "Similar to input_shard_specs but for outputs.", + AttributeProto::STRINGS) + .Input(0, "input", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .Input( + 1, + "axes", + "A 1-D tensor indicates the axes to add.", + "tensor(int64)", + OpSchema::Single, + true, + 1, + OpSchema::NonDifferentiable) + .Output(0, "output", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensors."); + + ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedSqueeze) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Attr("input_device_mesh_elements", + "device_mesh_elements[i] defines the device mesh's value for the i-th input. " + "E.g., device_mesh_elements=[\"[0, 1]\", \"[0, 1]\"] means the 1st and the 2nd " + " inputs are stored on the 0-th and the 1st devices, respectively.", + AttributeProto::STRINGS) + .Attr("input_device_mesh_shapes", + "device_mesh_shape[i] defines the device mesh's shape for the i-th input.", + AttributeProto::STRINGS) + .Attr("input_shard_specs", + "The sharding spec of inputs. 
" + "E.g., if input_shard_specs[i] is \"RRR\", the i-th input is a unsharded 3-D tensor.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_elements", + "Similar to input_device_mesh_elments but for outputs.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_shapes", + "Similar to input_device_mesh_shapes but for outputs.", + AttributeProto::STRINGS) + .Attr("output_shard_specs", + "Similar to input_shard_specs but for outputs.", + AttributeProto::STRINGS) + .Input(0, "input", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .Input( + 1, + "axes", + "A 1-D tensor indicates the axes to add.", + "tensor(int64)", + OpSchema::Single, + true, + 1, + OpSchema::NonDifferentiable) + .Output(0, "output", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensors."); } } // namespace contrib diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 21f9db7e486be..39449bea6303a 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -415,6 +415,7 @@ void BeamSearchShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) { // output 0 (sequences) shape: (batch_size, num_return_sequences, max_length) // output 1 (sequences_scores) shape: (batch_size, num_return_sequences) // output 2 (scores) shape: (max_length - sequence_length, batch_size, num_beams, vocab_size) + // output 3 (cross_attention): shape: (batch_size, num_return_sequences, Layers, Heads, max_length, Frames) if (!hasInputShape(ctx, 0)) { return; } @@ -1060,6 +1061,78 @@ ONNX_MS_OPERATOR_SET_SCHEMA(GridSample, 1, updateOutputShape(ctx, 0, {N, C, H_out, W_out}); })); +ONNX_MS_OPERATOR_SET_SCHEMA( + UnfoldTensor, 1, + OpSchema() + .SetDoc("Returns a tensor which contains all slices of size size from input tensor in the dimension dim. " + "Step between two slices is given by step. " + "If sizedim is the size of dimension dim for input tensor, the size of dimension dim in " + "the returned tensor will be (sizedim - size) / step + 1. 
" + "An additional dimension of size size is appended in the returned tensor.") + .Attr("dim", "specify the dimension to unfold", AttributeProto::INT, static_cast(-1)) + .Attr("size", "specify the size", AttributeProto::INT) + .Attr("step", "specify the step.", AttributeProto::INT, static_cast(1)) + .Input(0, "input", "input tensor", "T") + .Output(0, "output", "Output tensor.", "T") + .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Allow inputs and outputs to be any kind of tensor.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 0, 0); + + if (!hasInputShape(ctx, 0)) return; + auto& input_shape = getInputShape(ctx, 0); + const int rank = input_shape.dim_size(); + int64_t dim = getAttribute(ctx, "dim", -1); + dim = HandleNegativeAxis(dim, rank); + if (!input_shape.dim(static_cast(dim)).has_dim_value()) { + return; + } + int64_t dim_size = input_shape.dim(static_cast(dim)).dim_value(); + + const int64_t step = getAttribute(ctx, "step", -1); + if (step <= 0) { + fail_shape_inference("size attribute in UnfoldTensor must greater than 0.") + } + int64_t size = -1; + auto size_proto = ctx.getAttribute("size"); + if (!(size_proto)) { + fail_shape_inference("size attribute in UnfoldTensor not specified!") + } + size = size_proto->i(); + if (size > dim_size || size <= 0) { + fail_shape_inference("size attribute in UnfoldTensor not positive and less than the dim size!") + } + + ONNX_NAMESPACE::TensorShapeProto output_shape; + for (int d = 0; d < rank; d++) { + if (d == dim) { + output_shape.add_dim()->set_dim_value((dim_size - size) / step + 1); + } else { + *output_shape.add_dim() = input_shape.dim(d); + } + } + output_shape.add_dim()->set_dim_value(size); + updateOutputShape(ctx, 0, output_shape); + })); + +ONNX_MS_OPERATOR_SET_SCHEMA( + DynamicTimeWarping, 1, + OpSchema() + .SetDoc("Input is cost matrix where each value in input[r][c] is the cost for pass the point (r, c). From current point" + "(r, c), points (r+1, c), (r+1, c+1) or (r, c+1) could be arrived in next move. Given such cost matrix, return " + "dynamic time wrapping of shape [2, x], where the path made by all points (output[0][t], output[1][t])" + "have the lowest cost among all paths from (0, 0) to (M-1, N-1).") + .Input(0, "input", "Input cost tensor, it must be 2D tensor of shape M x N, or 1 x M x N", "F") + .Output(0, "output", "Output tensor. shape is [2, x], where max(M, N) <= x < M + N", "I") + .TypeConstraint("F", {"tensor(float)"}, "Constrain to float tensors.") + .TypeConstraint("I", {"tensor(int32)"}, "Constrain to integer types.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::INT32); + ONNX_NAMESPACE::TensorShapeProto resultShape; + resultShape.add_dim()->set_dim_value(2); + resultShape.add_dim(); + updateOutputShape(ctx, 0, resultShape); + })); + ONNX_MS_OPERATOR_SET_SCHEMA(BeamSearch, 1, OpSchema() .SetDoc("Beam Search for text generation. 
Supports GPT-2 decoder.") @@ -1110,6 +1183,122 @@ ONNX_MS_OPERATOR_SET_SCHEMA(BeamSearch, 1, BeamSearchShapeInference(ctx); })); +ONNX_MS_OPERATOR_SET_SCHEMA(WhisperBeamSearch, 1, + OpSchema() + .SetDoc("Beam Search for whisper model, especiall with cross_qk features etc.") + .Attr("eos_token_id", "The id of the end-of-sequence token", AttributeProto::INT) + .Attr("pad_token_id", "The id of the padding token", AttributeProto::INT) + .Attr("decoder_start_token_id", "The id of the token that indicates decoding starts.", AttributeProto::INT, static_cast(-1)) + .Attr("no_repeat_ngram_size", "no repeat ngrams size", AttributeProto::INT, static_cast(0)) + .Attr("early_stopping", "early stop or not", AttributeProto::INT, static_cast(0)) + .Attr("model_type", "Must be 2 for whisper", AttributeProto::INT, static_cast(2)) + .Attr("encoder", "The subgraph for initialization of encoder and decoder. It will be called once before decoder subgraph.", AttributeProto::GRAPH, OPTIONAL_VALUE) + .Attr("init_decoder", + "The subgraph for the first decoding run. It will be called once before `decoder` subgraph. " + "This is relevant only for the GPT2 model. If this attribute is missing, the `decoder` subgraph will be used for all decoding runs", + AttributeProto::GRAPH, OPTIONAL_VALUE) + .Attr("decoder", "Decoder subgraph to execute in a loop.", AttributeProto::GRAPH) + .Attr("vocab_size", + "Size of the vocabulary. " + "If not provided, it will be inferred from the decoder subgraph's output shape", + AttributeProto::INT, static_cast(-1)) + .Attr("decoder_output_cross_qk", "If nozero, decoder subgraph contains output Q*K from cross attentions. Default 0.", AttributeProto::INT, OPTIONAL_VALUE) + .Attr("no_speech_token", + "The token in whisper model that marks all sequence empty. With this model, whisper could output no_speech_prob after. Default -1.", + AttributeProto::INT, OPTIONAL_VALUE) + .Input(0, "input_ids", "The sequence used as a prompt for the generation in the encoder subgraph. Shape is (batch_size, sequence_length)", "F") + .Input(1, "max_length", "The maximum length of the sequence to be generated. Shape is (1)", "I") + .Input(2, "min_length", "The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)", "I", OpSchema::Optional) + .Input(3, "num_beams", "Number of beams for beam search. 1 means no beam search. Shape is (1)", "I") + .Input(4, "num_return_sequences", "The number of returned sequences in the batch. Shape is (1)", "I") + .Input(5, "length_penalty", + "Exponential penalty to the length. Default value 1.0 means no penalty." + "Value > 1.0 encourages longer sequences, while values < 1.0 produces shorter sequences." + "Shape is (1,)", + "T", OpSchema::Optional) + .Input(6, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional) + .Input(7, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "M", OpSchema::Optional) + .Input(8, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "M", OpSchema::Optional) + .Input(9, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional) + .Input(10, "decoder_input_ids", "The forced input id sequence for the decoder subgraph. 
Shape is (batch_size, initial_sequence_length)", "I", OpSchema::Optional) + .Input(11, "logits_processor", "Specific logits processor for different types of beamsearch models. Default value 0 means no specific logit processor. Accepts value >= 0. Shape is (1)", "I", OpSchema::Optional) + .Input(12, "cross_qk_layer_head", + "Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. Default collect all" + "its shape is (number of (layer, head) to keep, 2), i.e., [[layer_id1, head_id1], [layer_id2, head_id2]......]", + "I", OpSchema::Optional) + .Input(13, "extra_decoding_ids", + "Part of the decoder_input_ids that we need cross qk for it. it is of shape (batch_size, extra_decoding_ids_len)." + "In such case, we should remove this from the tail of the decoder_input_ids, and put it here. ids < 0 in it (for multiple batch) " + "are treated as stop of the extra_decoding_ids for corresponding batch.", + "I", OpSchema::Optional) + .Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, num_return_sequences, max_sequence_length)", "I") + .Output(1, "sequences_scores", "Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)", "T", OpSchema::Optional) + .Output(2, "scores", + "Processed beam scores for each vocabulary token at each generation step." + "Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam." + "Shape is (max_length - sequence_length, batch_size, num_beams, vocab_size)", + "T", OpSchema::Optional) + .Output(3, "cross_qk", + "Output the accumulated stacked Q*K in cross attentions. Let H = number of Head of cross attention, " + "F = the frames or kv-seq-len of the cross attention input, T = real decoded token length, L = number of layers," + "B = batch size, R = num_return_sequences. It then should return tensor of shape [B, R, L*H, T, F]." + "If cross_qk_layer_head is given, shape is [B, R, cross_qk_layer_head.shape[0], T, F]", + "V", OpSchema::Optional) + .Output(4, "non_speech_probs", + "For whisper model, output the probabilities from logits after encoder and context decoding for the no_speech_token." + "Currently we treat the last token's logits is what we need, in future extra graph logic may be add to the encoder/context-decoder subgraph." + "The prob is save before logits may be updated by extra-decoding-ids. 
The shape of non_speech_probs is [B]", + "T", OpSchema::Optional) + .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain to float tensors.") + .TypeConstraint("F", {"tensor(float)", "tensor(int32)", "tensor(float16)"}, "Constrain input type to float or int tensors.") + .TypeConstraint("I", {"tensor(int32)"}, "Constrain to integer types") + .TypeConstraint("M", {"tensor(int32)"}, "Constrain mask to integer types") + .TypeConstraint("V", {"tensor(float)"}, "Constrain cross_qk to float32 tensors.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + BeamSearchShapeInference(ctx); + if (ctx.getNumOutputs() > 3) { + ONNX_NAMESPACE::updateOutputElemType(ctx, 3, ONNX_NAMESPACE::TensorProto::FLOAT); + } + if (!hasInputShape(ctx, 0)) { + return; + } + auto& input_ids_shape = getInputShape(ctx, 0); + auto& input_ids_dims = input_ids_shape.dim(); + int64_t batch_size = input_ids_dims[0].dim_value(); + int64_t sequence_length = input_ids_dims[1].dim_value(); + + const auto max_length = ctx.getInputData(1); + const auto num_return_sequences = ctx.getInputData(4); + if (max_length == nullptr || num_return_sequences == nullptr) { // not initializer + return; + } + int max_length_value = 0; + if (!ParseScalar(max_length, max_length_value) || max_length_value <= 0) { + fail_shape_inference("Failed to parse max_length or it is not positive integer scalar"); + } + + int num_return_sequences_value = 0; + if (!ParseScalar(num_return_sequences, num_return_sequences_value) || num_return_sequences_value <= 0) { + fail_shape_inference("Failed to parse num_return_sequences or it is not positive integer scalar"); + } + + if (ctx.getNumOutputs() > 3) { + ONNX_NAMESPACE::TensorShapeProto cross_attn_shape; + cross_attn_shape.add_dim()->set_dim_value(batch_size); + cross_attn_shape.add_dim()->set_dim_value(num_return_sequences_value); + cross_attn_shape.add_dim(); // num of layer is unknown, no need to calc it from subgraph here + cross_attn_shape.add_dim(); // num of head is unknown, no need to calc it from subgraph here + cross_attn_shape.add_dim()->set_dim_value(max_length_value); + cross_attn_shape.add_dim()->set_dim_value(sequence_length); + updateOutputShape(ctx, 3, cross_attn_shape); + } + if (ctx.getNumOutputs() > 4) { + ONNX_NAMESPACE::TensorShapeProto non_speech_probs_shape; + non_speech_probs_shape.add_dim()->set_dim_value(batch_size); + updateOutputShape(ctx, 4, non_speech_probs_shape); + } + })); + ONNX_MS_OPERATOR_SET_SCHEMA(GreedySearch, 1, OpSchema() .SetDoc("Greedy Search for text generation.") @@ -2384,6 +2573,154 @@ ONNX_MS_OPERATOR_SET_SCHEMA(CropAndResize, 1, a fixed size = [crop_height, crop_width]. The result is a 4-D tensor [num_boxes, crop_height, crop_width, depth]. The resizing is corner aligned.)DOC")); +#if !defined(DISABLE_FLOAT8_TYPES) +#define GEMM_FLOAT8_TYPES \ + { "tensor(float8e4m3fn)", "tensor(float8e5m2)", "tensor(float16)", "tensor(bfloat16)", "tensor(float)" } +#else +#define GEMM_FLOAT8_TYPES \ + { "tensor(float16)", "tensor(bfloat16)", "tensor(float)" } +#endif + +ONNX_MS_OPERATOR_SET_SCHEMA(GemmFloat8, 1, + OpSchema() + .SetDoc(R"DOC(Generic Gemm for float and float 8.)DOC") + .Attr( + "transA", + "Whether A should be transposed. Float 8 only supprted transA=0.", + AttributeProto::INT, + static_cast(0)) + .Attr( + "transB", + "Whether B should be transposed. 
Float 8 only supprted transB=1.", + AttributeProto::INT, + static_cast(0)) + .Attr( + "alpha", + "Scalar multiplier for the product of input tensors A * B.", + AttributeProto::FLOAT, + 1.0f) + .Attr( + "beta", + "Scalar multiplier for the product of input bias C.", + AttributeProto::FLOAT, + 0.0f) + .Attr( + "dtype", + "Output Type. Same definition as attribute 'to' for operator Cast.", + AttributeProto::INT, + static_cast(1)) + .Attr( + "activation", + "Activation function, RELU or GELU or NONE (default).", + AttributeProto::STRING, + OPTIONAL_VALUE) + .Input( + 0, + "A", + "Input tensor A. " + "The shape of A should be (M, K) if transA is 0, " + "or (K, M) if transA is non-zero.", + "TA") + .Input( + 1, + "B", + "Input tensor B. " + "The shape of B should be (K, N) if transB is 0, " + "or (N, K) if transB is non-zero.", + "TB") + .Input( + 2, + "C", + "Input tensor C.", + "TC", + OpSchema::Optional) + .Input( + 3, + "scaleA", + "Scale of tensor A if A is float 8 tensor", + "TS", + OpSchema::Optional) + .Input( + 4, + "scaleB", + "Scale of tensor B if B is float 8 tensor", + "TS", + OpSchema::Optional) + .Input( + 5, + "scaleY", + "Scale of the output tensor if A or B is float 8.", + "TS", + OpSchema::Optional) + .Output(0, "Y", "Output tensor of shape (M, N).", "TR") + .TypeConstraint( + "TA", + GEMM_FLOAT8_TYPES, + "Constrain type to input A.") + .TypeConstraint( + "TB", + GEMM_FLOAT8_TYPES, + "Constrain type to input B.") + .TypeConstraint( + "TC", + {"tensor(float16)", "tensor(bfloat16)", "tensor(float)"}, + "Constrain type to input C.") + .TypeConstraint( + "TR", + GEMM_FLOAT8_TYPES, + "Constrain type to result type.") + .TypeConstraint("TS", {"tensor(float)"}, + "Constrain type for all input scales (scaleA, scaleB, scaleY).") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromAttributeToOutput(ctx, "dtype", 0, TensorProto::FLOAT); + if (!hasNInputShapes(ctx, 2)) { + return; + } + auto transAAttr = ctx.getAttribute("transA"); + bool transA = transAAttr ? static_cast(transAAttr->i()) != 0 : false; + auto transBAttr = ctx.getAttribute("transB"); + bool transB = transBAttr ? static_cast(transBAttr->i()) != 0 : false; + auto& first_input_shape = getInputShape(ctx, 0); + auto& second_input_shape = getInputShape(ctx, 1); + if (first_input_shape.dim_size() != 2) { + fail_shape_inference("First input does not have rank 2"); + } + if (second_input_shape.dim_size() != 2) { + fail_shape_inference("Second input does not have rank 2"); + } + updateOutputShape(ctx, 0, {first_input_shape.dim(transA ? 1 : 0), second_input_shape.dim(transB ? 0 : 1)}); + })); + +static void MatmulWithQuantWeightShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, + int64_t K, + int64_t N, + bool transB) { + int input_a_idx = 0; + if (!hasInputShape(ctx, input_a_idx)) { + return; + } + + const auto& a_shape = ctx.getInputType(input_a_idx)->tensor_type().shape(); + if (a_shape.dim_size() == 0) { + fail_shape_inference("Input tensors of wrong rank (0)."); + } + + // TODO: check B shape + + const auto& dim_last = a_shape.dim(a_shape.dim_size() - 1); + ONNX_NAMESPACE::TensorShapeProto resultShape; + if (dim_last.has_dim_value() && dim_last.dim_value() != (transB ? K : N)) { + fail_shape_inference("Incompatible dimensions for matrix multiplication"); + } + + for (int i = 0; i < a_shape.dim_size() - 1; ++i) { + *resultShape.add_dim() = a_shape.dim(i); + } + resultShape.add_dim()->set_dim_value(transB ? 
N : K); + + *ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape() = resultShape; +} + void RegisterContribSchemas() { ONNX_CONTRIB_OPERATOR_SCHEMA_ELSEWHERE(AttnLSTM, RegisterAttnLSTMContribOpSchema); ONNX_CONTRIB_OPERATOR_SCHEMA_ELSEWHERE(Range, RegisterRangeOpSchema); @@ -2972,6 +3309,119 @@ This op functions in much the same was as Dropout-11 and Dropout-13 do, execpt t } }); + static const char* MatMulNBits_ver1_doc = R"DOC( +MatMulNBits is a MatMul with weight quantized with N bits(e.g., 2, 3, 4, 5, 6, 7).It does Matrix Multiplication like MatMul (https://github.com/onnx/onnx/blob/main/docs/Operators.md#matmul) with differences: + 1. Input B is a 2D constant Matrix. Its input feature count and output feature count are specified by attribute 'K' and 'N'. + 2. Input B is quantized with x bits which is specified by attribute 'bits'. It is quantized blockwisely along dimension 0 (e.g. column) with block size specified by attribute block_size. + And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,.. + 3. Input B's scale and zero point are specified by input scales and zero_points. + +Input B is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which: +- n_blocks_per_col = (K + block_size - 1) / block_size +- blob_size = block_size / 8 * bits + + For a block blob. It is stored in format: + struct Blob { + uint8 one_bits[(bits & 0x1) * 1 * block_size / 8]; // highest 1 bit for 3, 5, 7 bits quantization + uint8 two_bits[(bits & 0x2) * 2 * block_size / 8]; // high 2 bits for 2, 6, 7 bits quantization + uint8 four_bits[(bits & 0x4) * 4 * block_size / 8]; // low 4 bits for 4, 5, 6 bits quantization + } + +Input scales is stored in same type as original type of B(float32, float16) with shape like: [N * n_blocks_per_col] +Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored as one unit8_t. If bits > 4, one zero point is stored with one unit8_t. Thus, its shape is: + - [(N * n_blocks_per_col + 1) / 2] if bits <=4 + - [N * n_blocks_per_col] if bits > 4 + +)DOC"; + + ONNX_CONTRIB_OPERATOR_SCHEMA(MatMulNBits) + .SetDomain(kMSDomain) + .SinceVersion(1) + .SetDoc(MatMulNBits_ver1_doc) + .Attr("K", "size of each input feature", AttributeProto::INT) + .Attr("N", "size of each output feature", AttributeProto::INT) + .Attr("bits", "number of bits used for weight quantization (default 4)", AttributeProto::INT) + .Attr("block_size", "number of groupsize used for weight quantization,(default 128). It needs to be a power of 2 and not smaller than 16.", AttributeProto::INT) + .Input(0, "A", "The input tensor, not quantized", "T1") + .Input(1, "B", "1-dimensional data blob", "T2") + .Input(2, "scales", "quantization scale", "T1") + .Input(3, "zero_points", "quantization zero points", "T2", OpSchema::Optional) + .Output(0, "Y", "tensor. The output tensor has the same rank as the input. 
", "T1") + .TypeConstraint("T1", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float/half_float tensors.") + .TypeConstraint("T2", {"tensor(uint8)"}, "Constrain quantized weight types to uint8.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + // Type inference + propagateElemTypeFromInputToOutput(ctx, 0, 0); + // Shape inference + int64_t in_features = getAttribute(ctx, "K", -1); + int64_t out_features = getAttribute(ctx, "N", -1); + MatmulWithQuantWeightShapeInference(ctx, in_features, out_features, true); + }); + + static const char* MatMulBnb4_ver1_doc = R"DOC( +MatMulBnb4 is a MatMul with weight quantized with 4 bits using either FP4 or NF4 data type (https://arxiv.org/pdf/2305.14314.pdf). It does Matrix Multiplication like MatMul (https://github.com/onnx/onnx/blob/main/docs/Operators.md#matmul) with differences: + 1. Input B is a 2D constant Matrix. Its input feature count and output feature count are specified by attribute 'K' and 'N'. + 2. Input B is quantized with 4 bits with quantization data type specified by attribute 'quant_type'. It is transposed, flattened and quantized blockwisely with block size specified by attribute 'block_size'. + And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,.. + 3. Input B's quantization constants or scales are specified by input 'absmax'. + + Input B is stored as uint8_t with shape: [(N * K + 1) / 2]. + Input absmax is stored in same type as original type of B(float32, float16) with shape like: [(N * K + block_size - 1) / block_size]. + + + 1. (Default value) transB=True (Majorly used for forward pass) + Shape of A: [D0, D1, ..., Dn, K] + Shape of Dequanted B: [N, K], this is aligned with how PyTorch defined the linear weight, .e.g [out_features, in_features]. + + The computation math: + dequant_B = dequant(B, absmax, quant_type, block_size) + transposed_dequant_B = dequant_B^T + output = A @ transposed_dequant_B + + Shape of output: [D0, D1, ..., Dn, N] + + 2. transB=False (Majorly used for backward pass) + Shape of A: [D0, D1, ..., Dn, N] + Shape of Dequanted B: [N, K], this is aligned with how PyTorch defined the linear weight, .e.g [out_features, in_features]. + + The computation math: + dequant_B = dequant(B, absmax, quant_type, block_size) + output = A @ dequant_B + + Shape of output: [D0, D1, ..., Dn, K] + +)DOC"; + + ONNX_CONTRIB_OPERATOR_SCHEMA(MatMulBnb4) + .SetDomain(kMSDomain) + .SinceVersion(1) + .SetDoc(MatMulBnb4_ver1_doc) + .Attr("K", "size of each input feature", AttributeProto::INT) + .Attr("N", "size of each output feature", AttributeProto::INT) + .Attr("block_size", "number of groupsize used for weight quantization. It needs to be a power of 2 and not smaller than 16.", AttributeProto::INT) + .Attr("quant_type", "quantization data type. 0 for FP4, 1 for NF4.", AttributeProto::INT) + .Attr("training_mode", + "Indicate if the ops run in training_mode, by default, False.", + AttributeProto::INT, + static_cast(0)) + .Attr("transB", "Whether B should be transposed on the last two dimensions before doing multiplication. Default to be 1.", + AttributeProto::INT, static_cast(1)) + .Input(0, "A", "The input tensor, not quantized", "T1") + .Input(1, "B", "1-dimensional quantized data for weight", "T2") + .Input(2, "absmax", "quantization constants", "T1") + .Output(0, "Y", "tensor. The output tensor has the same rank as the input. 
", "T1") + .TypeConstraint("T1", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float/half_float tensors.") + .TypeConstraint("T2", {"tensor(uint8)"}, "Constrain quantized weight types to uint8.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + // Type inference + propagateElemTypeFromInputToOutput(ctx, 0, 0); + // Shape inference + int64_t in_features = getAttribute(ctx, "K", -1); + int64_t out_features = getAttribute(ctx, "N", -1); + bool transB = getAttribute(ctx, "transB", 1) != 0; + MatmulWithQuantWeightShapeInference(ctx, in_features, out_features, transB); + }); + #ifdef ENABLE_ATEN ONNX_CONTRIB_OPERATOR_SCHEMA(ATen) .SetDomain(kPytorchAtenDomain) diff --git a/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc b/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc index c2f5edaa6149b..f81c3b8e0182c 100644 --- a/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc @@ -42,7 +42,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "The number of groups of channels. It should be a divisor of the number of channels C", AttributeProto::INT) .Attr("activation", - "Activation after group normalization: 0 for None, 1 for Swish", + "Activation after group normalization: 0 for None, 1 for SiLU", AttributeProto::INT) .Attr("channels_last", "1 if the input and output are in the NHWC layout, 0 if it is in the NCHW layout. Defaults to 1.", @@ -68,6 +68,85 @@ ONNX_MS_OPERATOR_SET_SCHEMA( .TypeConstraint("M", {"tensor(float16)", "tensor(float)"}, "Constrain gamma and beta to float tensors.") .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)); +constexpr const char* SkipGroupNorm_ver1_doc = R"DOC( +This operator element-wise adds x, skip and bias, then apply group normalization and optional activation. + +This operator transforms input according to + s = x + skip + bias + y = gamma * (s - mean) / sqrt(variance + epsilon) + beta + +The input channels are separated into num_groups groups, each containing num_channels / num_groups channels. +The num_channels must be divisible by num_groups. +The mean and standard-deviation of s are calculated separately over the each group. +The weight and bias are per-channel affine transform parameter vectors of size num_channels. + +The activation attribute can be used to enable activation after group normalization. +)DOC"; + +ONNX_MS_OPERATOR_SET_SCHEMA( + SkipGroupNorm, 1, + OpSchema() + .SetDoc(SkipGroupNorm_ver1_doc) + .Attr("epsilon", "The epsilon value to use to avoid division by zero", + AttributeProto::FLOAT, static_cast(1e-5)) + .Attr("groups", + "The number of groups of channels. It should be a divisor of the number of channels C", + AttributeProto::INT) + .Attr("activation", + "Activation after group normalization: 0 for None, 1 for SiLU", + AttributeProto::INT) + .Attr("channels_last", + "1 if the input and output are in the NHWC layout, 0 if it is in the NCHW layout. Defaults to 1.", + AttributeProto::INT, + static_cast(1)) + .Input(0, + "X", + "Input data tensor. 
Dimensions are (N x H x W x C) when channels_last is 1 " + " or (N x C x H x W) otherwise, where N is the batch size, C is the number of channels," + " and H and W are the height and width of the data", + "T") + .Input(1, + "gamma", + "1D gamma tensor for normalization with shape (C), where C is number of channels", + "M") + .Input(2, + "beta", + "1D beta tensor for normalization with shape (C), where C is number of channels", + "M") + .Input(3, + "skip", + "4D or 2D skip tensor. The shape can be (N x H x W x C) or (N x 1 x 1 x C) or (N x C)", + "T") + .Input(4, + "bias", + "1D bias tensor. Dimensions are (C), where C is number of channels", + "T", + OpSchema::Optional) + .Output(0, + "Y", + "The output tensor of the same shape as X", + "T") + .Output(1, + "S", + "The element-wise sum of input x, skip and bias tensors. It has the same shape as X", + "T", + OpSchema::Optional) + .TypeConstraint("T", {"tensor(float16)", "tensor(float)"}, "Constrain input X, skip, bias and output Y, S types to float tensors.") + .TypeConstraint("M", {"tensor(float16)", "tensor(float)"}, "Constrain gamma and beta to float tensors.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 0, 0); + if (ctx.getNumOutputs() > 1) { + propagateElemTypeFromInputToOutput(ctx, 0, 1); + } + + if (hasInputShape(ctx, 0)) { + propagateShapeFromInputToOutput(ctx, 0, 0); + if (ctx.getNumOutputs() > 1) { + propagateShapeFromInputToOutput(ctx, 0, 1); + } + } + })); + constexpr const char* BiasSplitGelu_ver1_doc = R"DOC( A fusion used in diffusion model that after adding bias, hidden state is sliced into two tensors of same size, then left tensor multiplies the Gelu activation result of right tensor. diff --git a/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc b/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc index d3fc5873cb274..03ad95260c0ad 100644 --- a/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc +++ b/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc @@ -90,12 +90,16 @@ void RegisterNHWCSchemaWithActivation(const RegistrationFunc& f, ::ONNX_NAMESPAC void OpSet_Internal_NHWC_ONNX::ForEachSchema(const std::function& fn) { // if the operator may be fused with an activation, use the WITH_ACTIVATION variant to add optional attributes // for the activation parameters. - // For now we only register operators from opset 11 on. Models can easily have their opset updated using ONNX tools + // We mainly register operators from opset 11 on . Models can easily have their opset updated using ONNX tools // so supporting older opsets is unnecessary. + // Older opsets are included on a per-operator basis as needed. 
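+  // e.g. AveragePool is registered below for opsets 7, 10, 11 and 19, so the NHWC layout
+  // transformation also covers models that still use those older (and newer) opsets.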
// NOTE: This should be in sync with GetLayoutSensitiveOps in // /onnxruntime/core/optimizer/transpose_optimization/transpose_optimizer.cc + REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, AveragePool, 7); + REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, AveragePool, 10); REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, AveragePool, 11); + REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, AveragePool, 19); REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, BatchNormalization, 9); REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, BatchNormalization, 14); @@ -106,16 +110,18 @@ void OpSet_Internal_NHWC_ONNX::ForEachSchema(const std::function()); fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); @@ -196,8 +203,10 @@ class OpSet_Microsoft_ver1 { fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); @@ -205,11 +214,14 @@ class OpSet_Microsoft_ver1 { fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); + fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); + fn(GetOpSchema()); } }; } // namespace contrib diff --git a/onnxruntime/core/graph/function_utils.cc b/onnxruntime/core/graph/function_utils.cc index 7477f48088a15..7b0a834a7ffc0 100644 --- a/onnxruntime/core/graph/function_utils.cc +++ b/onnxruntime/core/graph/function_utils.cc @@ -373,7 +373,8 @@ class Inliner { // Replace given name with a unique version of the name, and cache the // renaming-binding in current scope. void make_unique(std::string& name) { - auto new_name = prefix_ + name; + auto new_name{prefix_}; + new_name.append("_").append(name); auto& current_scope = rename_scopes_.back(); current_scope[name] = new_name; name = std::move(new_name); @@ -410,7 +411,7 @@ class Inliner { std::string rename_as = actuals.Get(i); if constexpr (isOutput) { if (rename_as.empty()) - rename_as.assign(prefix_).append(formal); + rename_as.assign(prefix_).append("_").append(formal); } current_scope[formal] = rename_as; if (!rename_as.empty()) @@ -420,7 +421,7 @@ class Inliner { std::string& formal = *formals.Mutable(i); std::string rename_as; if constexpr (isOutput) { - rename_as.assign(prefix_).append(formal); + rename_as.assign(prefix_).append("_").append(formal); } current_scope[formal] = rename_as; if (!rename_as.empty()) diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 383c1d689d3c3..4b3cafcb39b78 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -860,18 +860,18 @@ Status Node::LoadEdgesFromOrtFormat(const onnxruntime::fbs::NodeEdge& fbs_node_e } #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) -void Node::Init(const std::string& name, - const std::string& op_type, - const std::string& description, - const std::vector& input_args, - const std::vector& output_args, +void Node::Init(std::string_view name, + std::string_view op_type, + std::string_view description, + gsl::span input_args, + gsl::span output_args, const NodeAttributes* attributes, - const std::string& domain) { + std::string_view domain) { name_ = name; op_type_ = op_type; description_ = description; - definitions_.input_defs = input_args; - definitions_.output_defs = output_args; + definitions_.input_defs.assign(input_args.begin(), input_args.end()); + definitions_.output_defs.assign(output_args.begin(), output_args.end()); 
domain_ = domain; can_be_saved_ = true; priority_ = 0; @@ -1145,7 +1145,8 @@ Graph::Graph(const Model& owning_model, IOnnxRuntimeOpSchemaCollectionPtr schema_registry, const logging::Logger& logger, bool strict_shape_type_inference) - : Graph(owning_model, graph_proto, domain_to_version, ir_version, schema_registry, nullptr, nullptr, logger, strict_shape_type_inference) {} + : Graph(owning_model, graph_proto, domain_to_version, ir_version, + schema_registry, nullptr, nullptr, logger, strict_shape_type_inference) {} Graph::Graph(const Model& owning_model, GraphProto* graph_proto, const std::unordered_map& domain_to_version, Version ir_version, @@ -3261,8 +3262,8 @@ Node& Graph::AddNode(const std::string& name, gsl::span output_args, const NodeAttributes* attributes, const std::string& domain) { - std::vector inputs; - std::vector outputs; + InlinedVector inputs; + InlinedVector outputs; inputs.resize(input_args.size()); outputs.resize(output_args.size()); int i = 0; @@ -3907,7 +3908,8 @@ Node& Graph::CreateFusedSubGraphNode(const IndexedSubGraph& sub_graph, const std // kernel lookup works as per usual, if not using an existing schema. if (sub_graph.schema_source == IndexedSubGraph::SourceOfSchema::EXISTING) { ORT_ENFORCE(SetOpSchemaFromRegistryForNode(fused_node), - "Schema was not found for fused node. Domain:", fused_node.Domain(), " OpType:", fused_node.OpType()); + "Schema was not found for fused node. Domain:", fused_node.Domain(), " OpType:", fused_node.OpType(), + " SinceVersion:", fused_node.SinceVersion()); } else if (IndexedSubGraph::SourceOfSchema::REUSE_OR_CREATE == sub_graph.schema_source) { auto schema_key = GenerateSchemaKey(sub_graph); if (reusable_fused_schema_map_.count(schema_key) == 0) { @@ -4019,69 +4021,100 @@ Node& Graph::FuseSubGraph(const IndexedSubGraph& sub_graph, return fused_node; } +Status Graph::AddConstantProtoAsInitializer(const ONNX_NAMESPACE::NodeProto& node_proto, + std::optional new_name) { + const gsl::not_null tensor{graph_proto_->add_initializer()}; + ORT_RETURN_IF_ERROR(utils::ConstantNodeProtoToTensorProto(node_proto, ModelPath(), *tensor, node_proto.output(0))); + + if (new_name.has_value()) { + tensor->set_name(std::string(new_name.value())); + } + + auto insert_result = name_to_initial_tensor_.emplace(tensor->name(), tensor); + ORT_ENFORCE(insert_result.second, "Constant node name: ", tensor->name(), + " conflicts with graph initializer. Check that the node names have been made unique."); + if (GetNodeArg(tensor->name()) == nullptr) { + TypeProto t{TypeProtoFromTensorProto(*tensor)}; + ORT_IGNORE_RETURN_VALUE(GetOrCreateNodeArg(tensor->name(), &t)); + } + +#if !defined(DISABLE_SPARSE_TENSORS) + if (node_proto.attribute(0).type() == AttributeProto_AttributeType_SPARSE_TENSOR) { + ORT_IGNORE_RETURN_VALUE(sparse_tensor_names_.emplace(tensor->name())); + } +#endif + + return Status::OK(); +} + +Status Graph::InlineFunctionProto(const ONNX_NAMESPACE::FunctionProto& func_to_inline) { + auto to_node_arg = [this](const std::string& name) { + return &this->GetOrCreateNodeArg(name, nullptr); + }; + + // Process constant nodes first and create NodeArg for these as they become initializers + // It is important for the initializers to have NodeArg created, first they are needed + // if the initializer is unused and removed, second if the node depends on the initializer, + // we can have Type attached to it. 
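+  // Hypothetical illustration: for a function body containing
+  //   c = Constant()   ->  becomes a graph initializer "c" with a typed NodeArg (this first pass)
+  //   y = Add(x, c)    ->  added as a regular graph node (second pass below)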
+ InlinedVector non_constant_nodes; + non_constant_nodes.reserve(func_to_inline.node_size()); + for (const auto& inlined_node : func_to_inline.node()) { + if (inlined_node.op_type() == kConstant) { + // Copy constant nodes _value to name_to_initial_tensor_ + ORT_RETURN_IF_ERROR(AddConstantProtoAsInitializer(inlined_node, std::nullopt)); + } else { + non_constant_nodes.push_back(&inlined_node); + } + } + + for (const auto* inlined_node : non_constant_nodes) { + InlinedVector inputs; + InlinedVector outputs; + + for (const auto& tensor_name : inlined_node->input()) + inputs.push_back(to_node_arg(tensor_name)); + + for (const auto& tensor_name : inlined_node->output()) + outputs.push_back(to_node_arg(tensor_name)); + + onnxruntime::NodeAttributes new_attr_map; + new_attr_map.reserve(inlined_node->attribute_size()); + for (const auto& node_attr : inlined_node->attribute()) { + new_attr_map.insert_or_assign(node_attr.name(), node_attr); + } + ORT_IGNORE_RETURN_VALUE(AddNode(inlined_node->name(), inlined_node->op_type(), + inlined_node->doc_string(), inputs, outputs, + &new_attr_map, inlined_node->domain())); + } + + return Status::OK(); +} + Status Graph::InlineFunction(Node& callnode) { - const auto& model_path = ModelPath(); - auto output_edges = callnode.GetRelationships().output_edges; + // Remove output edges. Requirement for RemoveNode() below. + auto output_edges = callnode.GetRelationships().output_edges; // copy so RemoveEdge doesn't invalidate iterator for (const auto& output_edge : output_edges) { RemoveEdge(callnode.Index(), output_edge.GetNode().Index(), output_edge.GetSrcArgIndex(), output_edge.GetDstArgIndex()); } // create a uniq_identifier to append to every node name and intermediate input\outputs // to make sure there are no unintended duplicates - std::stringstream ss; - ss << "_inline_" << callnode.OpType(); - auto uniq_identifier = GenerateNodeName(ss.str()); + std::string base_uniq_identifier{"_inlfunc_"}; + base_uniq_identifier.append(callnode.OpType()); + const auto uniq_identifier = GenerateNodeName(base_uniq_identifier); + // Replace a (function-call) node by an inlined graph. if (!callnode.GetFunctionBody()) { // This is the normal use-case: inlining a FunctionProto (representing // a model-local function or a schema-defined function). - FunctionProto inlined_fp; + ONNX_NAMESPACE::FunctionProto inlined_fp; ORT_ENFORCE(callnode.TryGetFunctionProto(inlined_fp), "Node has no function body and cannot be inlined."); - function_utils::Specialize(inlined_fp, callnode, uniq_identifier); - auto to_node_arg = [this](const std::string& name) { - return &this->GetOrCreateNodeArg(name, nullptr); - }; - - // Process constant nodes first and create NodeArg for these as they become initializers - // It is important for the initializers to have NodeArg created, first they are needed - // if the initializer is unused and removed, second if the node depends on the initializer, - // we can have Type attached to it. - for (const auto& inlined_node : inlined_fp.node()) { - if (inlined_node.op_type() == kConstant) { - // Copy constant nodes _value to name_to_initial_tensor_ - const gsl::not_null tensor{graph_proto_->add_initializer()}; - ORT_RETURN_IF_ERROR(utils::ConstantNodeProtoToTensorProto(inlined_node, model_path, *tensor, inlined_node.output(0))); - auto insert_result = name_to_initial_tensor_.emplace(tensor->name(), tensor); - ORT_ENFORCE(insert_result.second, "Constant node name: ", tensor->name(), " in inlined function: ", - inlined_fp.name(), " conflicts with graph initializer. 
Check Specializing code."); - TypeProto t{TypeProtoFromTensorProto(*tensor)}; - ORT_IGNORE_RETURN_VALUE(GetOrCreateNodeArg(tensor->name(), &t)); - } - } - - for (const auto& inlined_node : inlined_fp.node()) { - if (inlined_node.op_type() != kConstant) { - InlinedVector inputs; - InlinedVector outputs; - - for (const auto& tensor_name : inlined_node.input()) - inputs.push_back(to_node_arg(tensor_name)); - - for (const auto& tensor_name : inlined_node.output()) - outputs.push_back(to_node_arg(tensor_name)); - - onnxruntime::NodeAttributes new_attr_map; - new_attr_map.reserve(inlined_node.attribute_size()); - for (const auto& node_attr : inlined_node.attribute()) { - onnx::AttributeProto attr_copy = node_attr; - new_attr_map[node_attr.name()] = std::move(attr_copy); - } - AddNode(inlined_node.name(), inlined_node.op_type(), - inlined_node.doc_string(), inputs, outputs, &new_attr_map, inlined_node.domain()); - } - } + // Make all the names unique and resolve nested graphs inputs to the outer scope. + function_utils::Specialize(inlined_fp, callnode, uniq_identifier); + // In this case, global Resolve() will take care of everything. + ORT_RETURN_IF_ERROR(InlineFunctionProto(inlined_fp)); } else { // Uncommon scenario. Inlining a node representing a fused sub-graph. // TODO: Unclear that this feature is needed. Can this be removed? @@ -4115,15 +4148,7 @@ Status Graph::InlineFunction(Node& callnode) { // Copy constant nodes _value to name_to_initial_tensor_ ONNX_NAMESPACE::NodeProto subgraph_node_proto{}; subgraph_node.ToProto(subgraph_node_proto); - const gsl::not_null tensor{graph_proto_->add_initializer()}; - ORT_RETURN_IF_ERROR(utils::ConstantNodeProtoToTensorProto(subgraph_node_proto, model_path, *tensor, subgraph_node_proto.output(0))); - auto insert_result = name_to_initial_tensor_.emplace(tensor->name(), tensor); - ORT_ENFORCE(insert_result.second, "Constant node name: ", tensor->name(), " in inlined subgraph: ", - subgraph.Name(), " conflicts with graph initializer. 
Check Specializing code."); - if (GetNodeArg(tensor->name()) == nullptr) { - TypeProto t{TypeProtoFromTensorProto(*tensor)}; - ORT_IGNORE_RETURN_VALUE(GetOrCreateNodeArg(tensor->name(), &t)); - } + ORT_RETURN_IF_ERROR(AddConstantProtoAsInitializer(subgraph_node_proto, std::nullopt)); } } diff --git a/onnxruntime/core/graph/model.cc b/onnxruntime/core/graph/model.cc index 05747a7e5124d..b3935e69ad7b1 100644 --- a/onnxruntime/core/graph/model.cc +++ b/onnxruntime/core/graph/model.cc @@ -41,6 +41,35 @@ namespace onnxruntime { #if !defined(ORT_MINIMAL_BUILD) +void Model::RemoveLocalFunctionsProtos(const InlinedHashSet& retained) { + auto* local_functions = model_proto_.mutable_functions(); + if (retained.empty()) { + model_local_function_templates_maps_.clear(); + model_local_functions_.clear(); + local_functions->erase(local_functions->begin(), local_functions->end()); + } else { + const auto retained_end = retained.cend(); + for (auto it = model_local_functions_.begin(); + it != model_local_functions_.end();) { + if (retained.find(it->first) == retained_end) { + model_local_function_templates_maps_.erase(it->first); + it = model_local_functions_.erase(it); + } else { + ++it; + } + } + + for (auto it = local_functions->begin(); it != local_functions->end();) { + const auto function_id = function_utils::GetFunctionIdentifier(it->domain(), it->name()); + if (retained.find(function_id) == retained_end) { + it = local_functions->erase(it); + } else { + ++it; + } + } + } +} + static constexpr int DEFAULT_PROTOBUF_BLOCK_SIZE = 4 * 1024 * 1024; Model::Model(const std::string& graph_name, @@ -95,10 +124,10 @@ Model::Model(const std::string& graph_name, for (auto& func : model_local_functions) { auto func_ptr = model_proto_.add_functions(); func_ptr->CopyFrom(func); - model_local_functions_.insert_or_assign(function_utils::GetFunctionIdentifier(func_ptr->domain(), func_ptr->name()), func_ptr); + model_local_functions_.insert_or_assign(function_utils::GetFunctionIdentifier(func_ptr->domain(), func_ptr->name()), + func_ptr); } - model_local_function_templates_.reserve(model_proto_.functions().size()); model_local_function_templates_maps_.reserve(model_proto_.functions().size()); for (auto& func : model_proto_.functions()) { auto func_schema_ptr = function_utils::CreateSchema(func.domain(), @@ -111,8 +140,8 @@ Model::Model(const std::string& graph_name, auto func_template_ptr = std::make_unique(); func_template_ptr->op_schema_ = std::move(func_schema_ptr); func_template_ptr->onnx_func_proto_ = &func; - model_local_function_templates_.push_back(std::move(func_template_ptr)); - model_local_function_templates_maps_[function_utils::GetFunctionIdentifier(func.domain(), func.name())] = model_local_function_templates_.back().get(); + model_local_function_templates_maps_.insert_or_assign(function_utils::GetFunctionIdentifier(func.domain(), func.name()), + std::move(func_template_ptr)); } // need to call private ctor so can't use make_shared @@ -203,6 +232,14 @@ Model::Model(ModelProto&& model_proto, const PathString& model_path, } } + // special-case the internal NHWC domain as it must match the ONNX opset if not explicitly imported + if (domain_to_version.find(kMSInternalNHWCDomain) == domain_to_version.end()) { + auto onnx_version = domain_to_version.find(kOnnxDomain); + if (onnx_version != domain_to_version.end()) { + domain_to_version[kMSInternalNHWCDomain] = onnx_version->second; + } + } + auto domain_map = allow_official_onnx_release_only_final ? 
schema_registry->GetLastReleasedOpsetVersions(false) : schema_registry->GetLatestOpsetVersions(false); @@ -220,7 +257,6 @@ Model::Model(ModelProto&& model_proto, const PathString& model_path, model_local_functions_.insert_or_assign(function_utils::GetFunctionIdentifier(func.domain(), func.name()), &func); } - model_local_function_templates_.reserve(model_proto_.functions().size()); model_local_function_templates_maps_.reserve(model_proto_.functions().size()); for (auto& func : model_proto_.functions()) { auto func_schema_ptr = function_utils::CreateSchema(func.domain(), @@ -233,9 +269,7 @@ Model::Model(ModelProto&& model_proto, const PathString& model_path, auto func_template_ptr = std::make_unique(); func_template_ptr->op_schema_ = std::move(func_schema_ptr); func_template_ptr->onnx_func_proto_ = &func; - model_local_function_templates_.push_back(std::move(func_template_ptr)); - model_local_function_templates_maps_[function_utils::GetFunctionIdentifier(func.domain(), func.name())] = - model_local_function_templates_.back().get(); + model_local_function_templates_maps_.insert_or_assign(function_utils::GetFunctionIdentifier(func.domain(), func.name()), std::move(func_template_ptr)); } // create instance. need to call private ctor so can't use make_unique @@ -244,7 +278,7 @@ Model::Model(ModelProto&& model_proto, const PathString& model_path, logger, options.strict_shape_type_inference)); } -const InlinedHashMap& Model::GetModelLocalFunctionTemplates() const { +const NodeHashMap>& Model::GetModelLocalFunctionTemplates() const { return model_local_function_templates_maps_; } @@ -332,7 +366,7 @@ const Graph& Model::MainGraph() const noexcept { } #if !defined(ORT_MINIMAL_BUILD) -ModelProto Model::ToProto() { +ModelProto Model::ToProto() const { // We want to return back the original proto // To that end invoke const overload of ToGraphProto() // that returns by value and, therefore, allows us to filter @@ -346,7 +380,7 @@ ModelProto Model::ToProto() { ModelProto Model::ToGraphProtoWithExternalInitializers(const std::string& external_file_name, const PathString& file_path, - size_t initializer_size_threshold) { + size_t initializer_size_threshold) const { ModelProto result(model_proto_); const auto& graph = *graph_; *(result.mutable_graph()) = graph.ToGraphProtoWithExternalInitializers(external_file_name, diff --git a/onnxruntime/core/graph/model.h b/onnxruntime/core/graph/model.h index 6bdb68dd734f0..4ce6660b794bc 100644 --- a/onnxruntime/core/graph/model.h +++ b/onnxruntime/core/graph/model.h @@ -139,7 +139,7 @@ class Model { // Returns empty string if not specified. const std::string GraphDocString() const; - const InlinedHashMap& GetModelLocalFunctionTemplates() const; + const NodeHashMap>& GetModelLocalFunctionTemplates() const; #else // Get model's IR version. @@ -182,14 +182,14 @@ class Model { #if !defined(ORT_MINIMAL_BUILD) // Get model's serialization proto data. - ONNX_NAMESPACE::ModelProto ToProto(); + ONNX_NAMESPACE::ModelProto ToProto() const; // Get model's serialization proto data. // Save initializer larger than the given threshold (in bytes) into an external binary file // with the given name. This function is useful to avoid hitting the size limit of protobuf files. 
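  // Illustrative call (file name and threshold are hypothetical):
  //   ONNX_NAMESPACE::ModelProto proto = model.ToGraphProtoWithExternalInitializers("weights.bin", model_path, 1024);
  // Initializers larger than 1024 bytes are written to "weights.bin"; the returned proto references them externally.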
ONNX_NAMESPACE::ModelProto ToGraphProtoWithExternalInitializers(const std::string& external_file_name, const PathString& file_path, - size_t initializer_size_threshold); + size_t initializer_size_threshold) const; #ifdef _WIN32 static common::Status Save(Model& model, const std::wstring& file_path); @@ -291,6 +291,13 @@ class Model { common::Status SaveToOrtFormat(flatbuffers::FlatBufferBuilder& builder, flatbuffers::Offset& model) const; + /// + /// Frees local function definitions in the model, excluding those in the `retained` set. + /// Called from GraphPartitioner::InlineFunctionsAOT. + /// + /// contains function IDs that should not be removed. + void RemoveLocalFunctionsProtos(const InlinedHashSet& retained); + #endif // !defined(ORT_MINIMAL_BUILD) static common::Status LoadFromOrtFormat(const onnxruntime::fbs::Model& fbs_model, @@ -312,14 +319,12 @@ class Model { // this map will be used for the local functions' schema's type/shape inference. // This container is used by ONNX code and must be an std::unordered_map. std::unordered_map model_local_functions_; - // this is the container that host the generated schemas for model local functions. - // the generated schemare will be used for graph resolving and type/shape inference. - // those schemas' type/shape inference will reference to the model_local_functions_ as context, - // so need to keep them with same lifetime. - InlinedVector> model_local_function_templates_; // this is the map from function id to the local function template. // this map will be used by graph to instantiate the function body. - InlinedHashMap model_local_function_templates_maps_; + // Defined as a node based map so the memory is released when not all of the functions + // are inlined and removed. + NodeHashMap> model_local_function_templates_maps_; + #else // properties that would normally come from ModelProto std::string producer_version_; diff --git a/onnxruntime/core/mlas/inc/mlas_q4.h b/onnxruntime/core/mlas/inc/mlas_q4.h index 65b48a3009e72..7c7b729117e4a 100644 --- a/onnxruntime/core/mlas/inc/mlas_q4.h +++ b/onnxruntime/core/mlas/inc/mlas_q4.h @@ -39,7 +39,7 @@ typedef enum { * @brief Computes the number of bytes required to pack and int4-quantize * a weight matrix * @param QType type of block quantization - * @param N the number of columns of matrix B. + * @param N the number of columns of matrix B. * @param K the number of rows of matrix B. * @return size of the packing buffer, 0 if the operation is not yet supported. */ @@ -53,11 +53,11 @@ MlasQ4GemmPackBSize( /** * @brief Prepack and Quantize fp32 weight tensor to int4 blocks - * + * * @param QType type of block quantization * @param PackedBuf destination buffer * @param FpData the pointer to fp32 matrix - * @param N the number of columns of matrix B. + * @param N the number of columns of matrix B. * @param K the number of rows of matrix B. * @param ldb leading dimension of B */ @@ -229,3 +229,120 @@ MlasQ8Q4GemmBatch( const MLAS_Q8Q4_GEMM_DATA_PARAMS* DataParams, MLAS_THREADPOOL* ThreadPool ); + + +//////////////////////////////////////////////////////////// +// Blockwise quantization and dequantization where quantization +// parameters are packed into separate buffers. 
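+//
+// Minimal usage sketch (assumed typical call sequence; buffer sizing and allocation omitted) for
+// 4-bit column-wise quantization of a row-major float matrix [rows, columns] with block_size 32:
+//
+//   int meta_rows, meta_cols, q_rows, q_cols;
+//   MlasBlockwiseQuantMetaShape<float>(32, /*columnwise*/ true, rows, columns, meta_rows, meta_cols);
+//   MlasBlockwiseQuantizedShape<float>(32, /*columnwise*/ true, rows, columns, q_rows, q_cols);
+//   MlasQuantizeBlockwise<float, 4>(quant, scales, zero_points, src, 32, true, rows, columns,
+//                                   columns /* leading_dimension */, thread_pool);
+//   MlasDequantizeBlockwise<float, 4>(dst, quant, scales, zero_points, 32, true, rows, columns,
+//                                     thread_pool);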
+// + +/** + * @brief For quantization type , and + * matrix shape [rows, columns], compute the shape of the + * quantization parameter matrix [meta_rows, meta_cols] +*/ +template +void +MlasBlockwiseQuantMetaShape( + int block_size, + bool columnwise, + int rows, + int columns, + int& meta_rows, + int& meta_cols + ); + +/** + * @brief For quantization type , and + * matrix shape [rows, columns], compute the shape of the + * quantized matrix [q_rows, q_cols]. The quantized matrix + * is in column major layout, with bits packed on the column. + * + * @tparam T + * @param block_size + * @param columnwise + * @param rows + * @param columns + * @param q_rows + * @param q_cols +*/ +template +void +MlasBlockwiseQuantizedShape( + int block_size, + bool columnwise, + int rows, + int columns, + int& q_rows, + int& q_cols + ); + + +/** + * @brief Blockwise 4 bits quantization, resulting elements and quantization + * parameters (scales, zero points) are packed into separate matrices + * all in column major layout for faster access during subsequent matrix + * multiplication. + * + * @tparam ElementT type of the input matrix element, usually floating point + * @tparam qbits number of bits used for quantization, 4 for int4 + * + * @param dst points to the quantized matrix, shape [rows, columns] column major + * @param scales points to the scales matrix, column major + * @param zero_points points to the zero_points matrix, column major + * @param src points to the floating point matrix, to be quantized, row major shape [rows, columns] + * @param block_size size of the block to quantize, elements from the same block share the same scale and zero point + * @param columnwise true when elements in a block are from the same column, false when elements in a block are from the same row + * @param rows + * @param columns + * @param leading_dimension + * @param thread_pool +*/ +template +void +MlasQuantizeBlockwise( + uint8_t* dst, + ElementT* scales, + uint8_t* zero_points, + const ElementT* src, + int block_size, + bool columnwise, + int rows, + int columns, + int leading_dimension, + MLAS_THREADPOOL* thread_pool + ); + + +/** + * @brief Blockwise 4 bits dequantization, quantized elements and quantization + * parameters (scales, zero points) are from separate matrices packed + * in column major layout. Output is a floating point matrix in column + * major layout for faster access during subsequent matrix multiplication. 
+ * + * @tparam ElementT type of the dequantized matrix element, usually floating point + * @tparam qbits number of bits used for quantization, 4 for int4 + * + * @param dst points to dequantized matrix shape [rows, columns] column major + * @param src points to quantized matrix, column major + * @param scales points to quantization scales, column major + * @param zero_points points to quantization zero points, column major + * @param block_size size of the block to quantize, elements from the same block share the same scale and zero point + * @param columnwise true when elements in a block are from the same column, false when elements in a block are from the same row + * @param rows + * @param columns + * @param thread_pool +*/ +template +void +MlasDequantizeBlockwise( + ElementT* dst, + const uint8_t* src, + const ElementT* scales, + const uint8_t* zero_points, + int block_size, + bool columnwise, + int rows, + int columns, + MLAS_THREADPOOL* thread_pool + ); diff --git a/onnxruntime/core/mlas/lib/aarch64/QgemmS8S8KernelSmmla.S b/onnxruntime/core/mlas/lib/aarch64/QgemmS8S8KernelSmmla.S new file mode 100644 index 0000000000000..e18846c89030e --- /dev/null +++ b/onnxruntime/core/mlas/lib/aarch64/QgemmS8S8KernelSmmla.S @@ -0,0 +1,922 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the MIT License. + +Module Name: + + QgemmS8S8KernelSmmla.s + +Abstract: + + This module implements the kernels for the Int8 precision matrix/matrix + multiply operation (QGEMM). + +--*/ + +#include "asmmacro.h" + + .text + +// +// Stack frame layout for the smmla kernel. d8-d15, x19-x30 need save +// + .equ .LMlasQgemmKernel_backup_x19_x20, 0 + .equ .LMlasQgemmKernel_backup_x21_x22, 16 + .equ .LMlasQgemmKernel_backup_x23_x24, 32 + .equ .LMlasQgemmKernel_backup_x25_x26, 48 + .equ .LMlasQgemmKernel_backup_x27_x28, 64 + .equ .LMlasQgemmKernel_backup_d8_d9, 80 + .equ .LMlasQgemmKernel_backup_d10_d11, 96 + .equ .LMlasQgemmKernel_backup_d12_d13, 112 + .equ .LMlasQgemmKernel_backup_d14_d15, 128 + .equ .LMlasQgemmKernel_SavedRegisters, 144 + .equ .LMlasQgemmKernel_SavedRegisters_Neg, -144 + + +// +// Init Row Accumulators +// +// Generates the code to initialize the accumulators for a single row of the output +// block. +// +// +// Accumulators are initialized to ZeroPointB * RowSum + ColumnSum +// x7 for RowSumsBuffer pointer +// x10 for ColumnSumBuffer pointer +// x11 for ZeroPointB buffer pointer +// +// v12~v13 for RowSums values +// v14~v15 for ColumnSums values +// v0~v3 for ZeroPointB values +// + .macro InitRowAccumulators Columns, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, RowSumReg + + mul v7.4s, v\RowSumReg\().4s, v8.4s + mov v\Vec1Reg\().16b, v7.16b + add v\Vec1Reg\().4s, v\Vec1Reg\().4s, v0.4s +.if \Columns\() > 2 + mul v7.4s, v\RowSumReg\().4s, v9.4s + mov v\Vec2Reg\().16b, v7.16b + add v\Vec2Reg\().4s, v\Vec2Reg\().4s, v1.4s +.endif +.if \Columns\() > 4 + mul v7.4s, v\RowSumReg\().4s, v10.4s + mov v\Vec3Reg\().16b, v7.16b + add v\Vec3Reg\().4s, v\Vec3Reg\().4s, v2.4s +.endif +.if \Columns\() > 6 + mul v7.4s, v\RowSumReg\().4s, v11.4s + mov v\Vec4Reg\().16b, v7.16b + add v\Vec4Reg\().4s, v\Vec4Reg\().4s, v3.4s +.endif + + .endm + + +// +// InitBlockAccumulators +// +// Generates the code to initialize the accumulators for 8x8 output +// block. 
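+// Each accumulator element is seeded as
+//   acc[row][col] = ColumnSum[col] + ZeroPointB[col] * RowSum[row]
+// (ZeroPointB is treated as 1 per column when the ZeroPointB pointer in x11 is null).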
+// + .macro InitBlockAccumulators Mode, Columns, Rows + + ld1 {v14.4s},[x10],#16 // load ColumnSumBuffer[0] +.if \Columns\() > 4 + ld1 {v15.4s},[x10],#16 // load ColumnSumBuffer[4] +.endif + // v4~v7 will be set to matrixB after this, so, they can used now + dup v4.4s,v14.s[0] // broadcast column + dup v5.4s,v14.s[1] + dup v6.4s,v14.s[2] + dup v7.4s,v14.s[3] + + zip1 v0.4s, v4.4s, v5.4s + zip2 v1.4s, v6.4s, v7.4s +.if \Columns\() > 4 + dup v4.4s,v15.s[0] // broadcast column + dup v5.4s,v15.s[1] + dup v6.4s,v15.s[2] + dup v7.4s,v15.s[3] + + zip1 v2.4s, v4.4s, v5.4s + zip2 v3.4s, v6.4s, v7.4s +.endif + + // v8~v11 will anyway get set in MatrixA loading, so they are free to use now + movi v8.4s, #1 + movi v9.4s, #1 + movi v10.4s, #1 + movi v11.4s, #1 + + cbz x11,.L\Mode\().InitBlock\Columns\().x\Rows\().SkipScaleByZeroPointB + + ld1 {v4.4s},[x11],#16 // load ZeroPointB[0] + ld1 {v5.4s},[x11],#16 // load ZeroPointB[4] + + dup v6.4s, v4.s[0] + dup v7.4s, v4.s[1] + zip1 v8.4s, v6.4s, v7.4s + + dup v6.4s, v4.s[2] + dup v7.4s, v4.s[3] + zip1 v9.4s, v6.4s, v7.4s + + dup v6.4s, v5.s[0] + dup v7.4s, v5.s[1] + zip1 v10.4s, v6.4s, v7.4s + + dup v6.4s, v5.s[2] + dup v7.4s, v5.s[3] + zip1 v11.4s, v6.4s, v7.4s + +.L\Mode\().InitBlock\Columns\().x\Rows\().SkipScaleByZeroPointB: + dup v4.4s, v12.s[0] //boardcast RowSums + dup v5.4s, v12.s[1] + + uzp1 v6.2d, v4.2d, v5.2d + + InitRowAccumulators \Columns\(),16,17,18,19,6 +.if \Rows\() > 2 + dup v4.4s, v12.s[2] //boardcast RowSums + dup v5.4s, v12.s[3] + + uzp1 v6.2d, v4.2d, v5.2d + + InitRowAccumulators \Columns\(),20,21,22,23,6 +.endif +.if \Rows\() > 4 + dup v4.4s,v13.s[0] // broadcast row sums + dup v5.4s,v13.s[1] + + uzp1 v6.2d, v4.2d, v5.2d + + InitRowAccumulators \Columns\(),24,25,26,27,6 +.endif +.if \Rows\() > 6 + dup v4.4s,v13.s[2] // broadcast row sums + dup v5.4s,v13.s[3] + + uzp1 v6.2d, v4.2d, v5.2d + InitRowAccumulators \Columns\(),28,29,30,31,6 +.endif + + .endm + + +// LoadPackedMatrixABy16Elements +// +// Generates the code to load 16 elements from matrix A. +// + .macro LoadPackedMatrixABy16Elements Rows +.if \Rows\() == 1 + ldr q8,[x0],#8 +.else + ldr q8,[x0],#16 + +.if \Rows\() > 2 + ldr q9,[x0],#16 +.endif + +.if \Rows\() > 4 + ldr q10,[x0],#16 +.endif + +.if \Rows\() > 6 + ldr q11,[x0],#16 +.endif +.endif + .endm + + +// +// MultiplyAccumulateRow +// +// Generates the code to multiply and accumulate a single row of the output +// block. +// + + .macro MultiplyAccumulateRow Columns, MatrixAReg, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg + + smmla v\Vec1Reg\().4s, \MatrixAReg\().16b, v4.16b +.if \Columns\() > 2 + smmla v\Vec2Reg\().4s, \MatrixAReg\().16b, v5.16b +.endif +.if \Columns\() > 4 + smmla v\Vec3Reg\().4s, \MatrixAReg\().16b, v6.16b +.endif +.if \Columns\() > 6 + smmla v\Vec4Reg\().4s, \MatrixAReg\().16b, v7.16b +.endif + + .endm + +// +// MultiplyAccumulateBlock +// +// Generates the code to multiply and accumulate into the output block. +// + + .macro MultiplyAccumulateBlock Columns, Rows + + MultiplyAccumulateRow \Columns\(),v8,16,17,18,19 +.if \Rows\() > 2 + MultiplyAccumulateRow \Columns\(),v9,20,21,22,23 +.endif +.if \Rows\() > 4 + MultiplyAccumulateRow \Columns\(),v10,24,25,26,27 +.endif +.if \Rows\() > 6 + MultiplyAccumulateRow \Columns\(),v11,28,29,30,31 +.endif + + .endm + +// +// ComputeBlockLoop +// +// Generates the code to loop over K entries of the input matrices to produce +// the output block. 
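Each smmla instruction used above multiplies a 2x8 block of signed 8-bit values from its first source register by an 8x2 block from its second source (held as two 8-byte halves, one per output column) and accumulates a 2x2 block of 32-bit results. That is why matrix A is packed as interleaved row pairs and why the output macros later use uzp1/uzp2 to separate the paired rows back into rows of C. A scalar rendering of one smmla (my sketch, not code from the diff):

    // acc: 2x2 int32 accumulator; a: 2x8 int8 (two rows of A); b: 2x8 int8 where
    // b[j] holds the eight K-values of output column j (a column of B).
    #include <cstdint>
    static void SmmlaReference(int32_t acc[2][2], const int8_t a[2][8], const int8_t b[2][8])
    {
        for (int i = 0; i < 2; ++i)
            for (int j = 0; j < 2; ++j)
                for (int k = 0; k < 8; ++k)
                    acc[i][j] += int32_t{a[i][k]} * int32_t{b[j][k]};
    }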
+// + + .macro ComputeBlockLoop Mode, Columns, Rows + + InitBlockAccumulators \Mode\(), \Columns\(),\Rows\() + + sub x9,x3,#1 // block count to process + tbnz x9,#63,.L\Mode\().ProcessRemaining\Columns\().x\Rows\().Blocks + +.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4Loop: + + LoadPackedMatrixABy16Elements \Rows\() + ld1 {v4.16b - v7.16b}, [x1], #64 + MultiplyAccumulateBlock \Columns\(),\Rows\() + + sub x9,x9,#1 + tbz x9,#63,.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4Loop +.L\Mode\().ProcessRemaining\Columns\().x\Rows\().Blocks: + add x9,x9,#1 // correct for over-subtract above + cbz x9,.L\Mode\().Output\Columns\().x\Rows\().Block + +.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4PaddedLoop: + LoadPackedMatrixABy16Elements \Rows\() + ld1 {v4.16b - v7.16b}, [x1], #64 + MultiplyAccumulateBlock \Columns\(),\Rows\() + +.L\Mode\().Output\Columns\().x\Rows\().Block: + + .endm + + +// +// OutputRow2Element +// OutputRow4Element +// OutputRow6Element +// OutputRow8Element +// OutputRow10Element +// OutputRow12Element +// OutputRow14Element +// OutputRow16Element +// +// Generates the code to store elements to the output block. +// + + .macro OutputRow2Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr s8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr s9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + mov v8.S[2], v9.S[0] + add v8.4s,v8.4s,v\Vec1Reg\().4s + + mov w27, v8.S[0] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov w27, v8.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + mov w27, v\Vec1Reg\().S[0] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov w27, v\Vec1Reg\().S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.endif + + .endm + + + .macro OutputRow4Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr d8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr d9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + + mov v8.D[1], v9.D[0] + + add v8.4s,v8.4s,v\Vec1Reg\().4s + + mov x27, v8.D[0] + mov x28, v8.D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.else + mov x27, v\Vec1Reg\().D[0] + mov x28, v\Vec1Reg\().D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.endif + + .endm + + + .macro OutputRow6Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr d8,[\AddrReg1\()],#8 + ldr w28,[\AddrReg1\()],#-8 + mov v8.S[2], w28 +.if \last_row\() == 0 + ldr d9,[\AddrReg2\()],#8 + ldr w27,[\AddrReg2\()],#-8 + mov v9.S[2], w27 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + + mov x27, v8.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v8.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov x27, v9.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v9.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + mov x27, v4.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v4.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov x27, v5.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v5.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.endif + + .endm + + + .macro OutputRow8Element Mode, AddrReg1, 
AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + +.endif + + .endm + + + .macro OutputRow10Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr w28, [\AddrReg1\()],#-16 + +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr w27,[\AddrReg2\()],#-16 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + mov v8.S[0], w28 + mov v8.S[2], w27 + + add v8.4s,v8.4s,v\Vec3Reg\().4s + + mov w27, v8.S[0] + mov w28, v8.S[2] + + str w27, [\AddrReg1\()],#4 +.if \last_row\() == 0 + str w28, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + mov w27, v\Vec3Reg\().S[0] + mov w28, v\Vec3Reg\().S[2] + + str w27, [\AddrReg1\()],#4 +.if \last_row\() == 0 + str w28, [\AddrReg2\()],#4 +.endif +.endif + +.endm + + + .macro OutputRow12Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr d10,[\AddrReg1\()],#-16 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr d11,[\AddrReg2\()],#-16 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 + mov v11.D[0],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + + mov v10.D[1], v11.D[0] + + add v10.4s,v10.4s,v\Vec3Reg\().4s + + mov x27, v10.D[0] + mov x28, v10.D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + mov x27, v\Vec3Reg\().D[0] + mov x28, v\Vec3Reg\().D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif +.endif + + .endm + + .macro OutputRow14Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr d10,[\AddrReg1\()],#8 + ldr w28, [\AddrReg1\()],#-24 + mov v10.S[2], w28 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr d11,[\AddrReg2\()],#8 + ldr w27,[\AddrReg2\()],#-24 + mov v11.S[2], w27 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 + + mov v11.D[0],x27 + mov v11.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, 
v\Vec3Reg\().2d,v\Vec4Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + add v10.4s,v10.4s,v6.4s + add v11.4s,v11.4s,v7.4s + + str q8,[\AddrReg1\()],#16 + + mov x27, v10.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v10.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 + mov x27, v11.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v11.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + str q4,[\AddrReg1\()],#16 + mov x27, v6.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v6.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 + mov x27, v7.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v7.S[2] + str w27, [\AddrReg2\()],#4 +.endif +.endif + + .endm + + + .macro OutputRow16Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldp q8,q10,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldp q9,q11,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 + + mov v11.D[0],x27 + mov v11.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + add v10.4s,v10.4s,v6.4s + add v11.4s,v11.4s,v7.4s + + stp q8,q10,[\AddrReg1\()],#32 +.if \last_row\() == 0 + stp q9,q11,[\AddrReg2\()],#32 +.endif +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + stp q4,q6,[\AddrReg1\()],#32 +.if \last_row\() == 0 + stp q5,q7,[\AddrReg2\()],#32 +.endif +.endif + + .endm + +// +// OutputBlock +// +// Generates the code to store the output block. +// + + .macro OutputBlock Mode, Columns, Rows + + OutputRow\Columns\()Element \Mode\(),x2,x13,16,17,18,19,(\Rows\() == 1) + +.if \Rows\() > 2 + OutputRow\Columns\()Element \Mode\(),x14,x15,20,21,22,23,(\Rows\() == 3) +.endif + +.if \Rows\() > 4 + OutputRow\Columns\()Element \Mode\(),x16,x17,24,25,26,27,(\Rows\() == 5) +.endif + +.if \Rows\() > 6 + OutputRow\Columns\()Element \Mode\(),x18,x19,28,29,30,31,(\Rows\() == 7) +.endif + + .endm +// +// ProcessRows +// +// Generates the code to process a compute and store the output block for a +// fixed number of rows. 
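In the macro that follows, the N dimension is peeled in tiles of 8, 6, 4, and 2 columns. Each OutputRow{2k}Element macro stores k columns across a pair of C rows, so the even element counts 16/12/8/4 correspond to full tiles and the 14/10/6/2 paths handle the odd leftovers (7, 5, 3, or 1 columns). A compact C rendering of that column loop, as a sketch under this reading rather than code from the diff:

    // count_n plays the role of CountN (x5); the compute/output calls are placeholders.
    static void ProcessColumnsSketch(int count_n)
    {
        while (count_n > 0) {
            const int tile = (count_n > 6) ? 8 : (count_n > 4) ? 6 : (count_n > 2) ? 4 : 2; // columns computed
            const int stored = (count_n < tile) ? count_n : tile;                           // columns written back
            // ComputeBlockLoop(tile); OutputBlock(2 * stored);
            count_n -= tile;
        }
    }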
+// + + .macro ProcessRows Mode, Rows + mov x4,#\Rows\() // return number of rows handled + cmp x5,#6 + ble .L\Mode\().ProcessNextColumnLoop6x\Rows\() + +.L\Mode\().ProcessNextColumnLoop8x\Rows\(): + ComputeBlockLoop \Mode\(),8,\Rows\() + + sub x5,x5,#8 + cmp x5,#0 + blt .L\Mode\().Output14ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),16,\Rows\() + mov x0,x8 // reload matrix A + cmp x5,#6 + bgt .L\Mode\().ProcessNextColumnLoop8x\Rows\() + cbz x5,.L\Mode\().ExitKernel + +.L\Mode\().ProcessNextColumnLoop6x\Rows\(): + + cmp x5,#4 + ble .L\Mode\().ProcessNextColumnLoop4x\Rows\() + ComputeBlockLoop \Mode\(),6,\Rows\() + sub x5,x5,#6 + cmp x5,#0 + blt .L\Mode\().Output10ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),12,\Rows\() + mov x0,x8 // reload matrix A + cmp x5,#4 + bgt .L\Mode\().ProcessNextColumnLoop6x\Rows\() + b .L\Mode\().ExitKernel + +.L\Mode\().ProcessNextColumnLoop4x\Rows\(): + cmp x5,#2 + ble .L\Mode\().ProcessNextColumnLoop2x\Rows\() + ComputeBlockLoop \Mode\(),4,\Rows\() + sub x5,x5,#4 + cmp x5,#0 + blt .L\Mode\().Output6ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),8,\Rows\() + mov x0,x8 // reload matrix A + cmp x5,#2 + bgt .L\Mode\().ProcessNextColumnLoop4x\Rows\() + b .L\Mode\().ExitKernel + +.L\Mode\().ProcessNextColumnLoop2x\Rows\(): + ComputeBlockLoop \Mode\(),2,\Rows\() + sub x5,x5,#2 + cmp x5,#0 + blt .L\Mode\().Output2ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),4,\Rows\() + mov x0,x8 // reload matrix A + cmp x5,#2 + b .L\Mode\().ExitKernel + +.L\Mode\().Output14ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),14,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output10ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),10,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output6ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),6,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output2ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),2,\Rows\() + b .L\Mode\().ExitKernel + + .endm + + +/*++ + +Routine Description: + + This routine is an inner kernel to compute matrix multiplication for a + set of rows. + +Arguments: + + A (x0) - Supplies the address of matrix A. The matrix data has been packed + using MlasGemmQuantCopyPackA. + + B (x1) - Supplies the address of matrix B. The matrix data has been packed + using MlasGemmQuantCopyPackB. + + C (x2) - Supplies the address of matrix C. + + PackedCountK (x3) - Supplies the number of packed columns from matrix A and + the number of packed rows from matrix B to iterate over. + + CountM (x4) - Supplies the maximum number of rows that can be processed for + matrix A and matrix C. The actual number of rows handled for this + invocation depends on the kernel implementation. + + CountN (x5) - Supplies the number of columns from matrix B and matrix C to + iterate over. + + ldc (x6) - Supplies the first dimension of matrix C. + + RowSumBuffer (x7) - Supplies the sum of each row from matrix A. These values + have been pre-scaled by the zero point offset of matrix B if the offset + is per-tensor (ZeroPointB is nullptr). Otherwise, these values must be + scaled by the per-column zero point offsets of matrix B. These values are + accumulated into every row of matrix C. + + ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied + by the zero point offset of matrix A. These values are accumulated into + every column of matrix C. + + ZeroPointB - Optionally supplies the per-column zero point offsets of matrix + B, else nullptr if the matrix B is using per-tensor quantization. 
+ +Return Value: + + Returns the number of rows handled. + +--*/ + + .macro QgemmS8S8KernelSmmlaFunction Mode + + FUNCTION_ENTRY MlasGemmS8S8KernelSmmla\Mode\() + + ldr x10,[sp, #0] + ldr x11,[sp,#8] + + stp x19, x20, [sp, #.LMlasQgemmKernel_SavedRegisters_Neg]! + stp x21, x22, [sp, #.LMlasQgemmKernel_backup_x21_x22] + stp x23, x24, [sp, #.LMlasQgemmKernel_backup_x23_x24] + stp x25, x26, [sp, #.LMlasQgemmKernel_backup_x25_x26] + stp x27, x28, [sp, #.LMlasQgemmKernel_backup_x27_x28] + stp d8, d9, [sp, #.LMlasQgemmKernel_backup_d8_d9] + stp d10, d11, [sp, #.LMlasQgemmKernel_backup_d10_d11] + stp d12, d13, [sp, #.LMlasQgemmKernel_backup_d12_d13] + stp d14, d15, [sp, #.LMlasQgemmKernel_backup_d14_d15] + + add x13,x2,x6,lsl #2 // compute matrix C plus 1 row + add x14,x13,x6,lsl #2 // compute matrix C plus 2 rows + add x15,x14,x6,lsl #2 // compute matrix C plus 3 rows + add x16,x15,x6,lsl #2 // compute matrix C plus 4 rows + add x17,x16,x6,lsl #2 // compute matrix C plus 5 rows + add x18,x17,x6,lsl #2 // compute matrix C plus 6 rows + add x19,x18,x6,lsl #2 // compute matrix C plus 7 rows + + mov x8,x0 // save matrix A + +// +// Process 8 rows of the matrices. +// + ld1 {v12.4s},[x7],#16 // load row sum 1 ~ 4 + cmp x4,#8 + blt .L\Mode\().ProcessCountMLessThan8 + ld1 {v13.4s},[x7],#16 // load row sum 5 ~ 8 + ProcessRows \Mode\(),8 + +// +// Restore non-volatile registers and return. +// + +.L\Mode\().ExitKernel: + mov x0,x4 + + ldp d14, d15, [sp, #.LMlasQgemmKernel_backup_d14_d15] + ldp d12, d13, [sp, #.LMlasQgemmKernel_backup_d12_d13] + ldp d10, d11, [sp, #.LMlasQgemmKernel_backup_d10_d11] + ldp d8, d9, [sp, #.LMlasQgemmKernel_backup_d8_d9] + ldp x27, x28, [sp, #.LMlasQgemmKernel_backup_x27_x28] + ldp x25, x26, [sp, #.LMlasQgemmKernel_backup_x25_x26] + ldp x23, x24, [sp, #.LMlasQgemmKernel_backup_x23_x24] + ldp x21, x22, [sp, #.LMlasQgemmKernel_backup_x21_x22] + ldp x19, x20, [sp], #.LMlasQgemmKernel_SavedRegisters + + ret + +// +// Process 4 rows of the matrix. +// + +.L\Mode\().ProcessCountMLessThan8: + cmp x4,#4 + blt .L\Mode\().ProcessCountMLessThan4 + ProcessRows \Mode\(),4 + b .L\Mode\().ExitKernel + +// +// Process 2 row of the matrix. +// + +.L\Mode\().ProcessCountMLessThan4: + cmp x4,#2 + blt .L\Mode\().ProcessCountMLessThan2 + + ProcessRows \Mode\(),2 + b .L\Mode\().ExitKernel + + +// +// Process the last row of the matrix. +// + +.L\Mode\().ProcessCountMLessThan2: + ProcessRows \Mode\(),1 + b .L\Mode\().ExitKernel + + + .endm + + QgemmS8S8KernelSmmlaFunction Zero + QgemmS8S8KernelSmmlaFunction Add + + .end diff --git a/onnxruntime/core/mlas/lib/aarch64/QgemmU8X8KernelUmmla.S b/onnxruntime/core/mlas/lib/aarch64/QgemmU8X8KernelUmmla.S new file mode 100644 index 0000000000000..baf6e21e6ff06 --- /dev/null +++ b/onnxruntime/core/mlas/lib/aarch64/QgemmU8X8KernelUmmla.S @@ -0,0 +1,922 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the MIT License. + +Module Name: + + QgemmU8X8KernelUmmla.s + +Abstract: + + This module implements the kernels for the Int8 precision matrix/matrix + multiply operation (QGEMM). + +--*/ + +#include "asmmacro.h" + + .text + +// +// Stack frame layout for the ummla kernel. 
d8-d15, x19-x30 need save +// + .equ .LMlasQgemmKernel_backup_x19_x20, 0 + .equ .LMlasQgemmKernel_backup_x21_x22, 16 + .equ .LMlasQgemmKernel_backup_x23_x24, 32 + .equ .LMlasQgemmKernel_backup_x25_x26, 48 + .equ .LMlasQgemmKernel_backup_x27_x28, 64 + .equ .LMlasQgemmKernel_backup_d8_d9, 80 + .equ .LMlasQgemmKernel_backup_d10_d11, 96 + .equ .LMlasQgemmKernel_backup_d12_d13, 112 + .equ .LMlasQgemmKernel_backup_d14_d15, 128 + .equ .LMlasQgemmKernel_SavedRegisters, 144 + .equ .LMlasQgemmKernel_SavedRegisters_Neg, -144 + + +// +// Init Row Accumulators +// +// Generates the code to initialize the accumulators for a single row of the output +// block. +// +// +// Accumulators are initialized to ZeroPointB * RowSum + ColumnSum +// x7 for RowSumsBuffer pointer +// x10 for ColumnSumBuffer pointer +// x11 for ZeroPointB buffer pointer +// +// v12~v13 for RowSums values +// v14~v15 for ColumnSums values +// v0~v3 for ZeroPointB values +// + .macro InitRowAccumulators Columns, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, RowSumReg + + mul v7.4s, v\RowSumReg\().4s, v8.4s + mov v\Vec1Reg\().16b, v7.16b + add v\Vec1Reg\().4s, v\Vec1Reg\().4s, v0.4s +.if \Columns\() > 2 + mul v7.4s, v\RowSumReg\().4s, v9.4s + mov v\Vec2Reg\().16b, v7.16b + add v\Vec2Reg\().4s, v\Vec2Reg\().4s, v1.4s +.endif +.if \Columns\() > 4 + mul v7.4s, v\RowSumReg\().4s, v10.4s + mov v\Vec3Reg\().16b, v7.16b + add v\Vec3Reg\().4s, v\Vec3Reg\().4s, v2.4s +.endif +.if \Columns\() > 6 + mul v7.4s, v\RowSumReg\().4s, v11.4s + mov v\Vec4Reg\().16b, v7.16b + add v\Vec4Reg\().4s, v\Vec4Reg\().4s, v3.4s +.endif + + .endm + + +// +// InitBlockAccumulators +// +// Generates the code to initialize the accumulators for 8x8 output +// block. +// + .macro InitBlockAccumulators Mode, Columns, Rows + + ld1 {v14.4s},[x10],#16 // load ColumnSumBuffer[0] +.if \Columns\() > 4 + ld1 {v15.4s},[x10],#16 // load ColumnSumBuffer[4] +.endif + // v4~v7 will be set to matrixB after this, so, they can used now + dup v4.4s,v14.s[0] // broadcast column + dup v5.4s,v14.s[1] + dup v6.4s,v14.s[2] + dup v7.4s,v14.s[3] + + zip1 v0.4s, v4.4s, v5.4s + zip2 v1.4s, v6.4s, v7.4s +.if \Columns\() > 4 + dup v4.4s,v15.s[0] // broadcast column + dup v5.4s,v15.s[1] + dup v6.4s,v15.s[2] + dup v7.4s,v15.s[3] + + zip1 v2.4s, v4.4s, v5.4s + zip2 v3.4s, v6.4s, v7.4s +.endif + + // v8~v11 will anyway get set in MatrixA loading, so they are free to use now + movi v8.4s, #1 + movi v9.4s, #1 + movi v10.4s, #1 + movi v11.4s, #1 + + cbz x11,.L\Mode\().InitBlock\Columns\().x\Rows\().SkipScaleByZeroPointB + + ld1 {v4.4s},[x11],#16 // load ZeroPointB[0] + ld1 {v5.4s},[x11],#16 // load ZeroPointB[4] + + dup v6.4s, v4.s[0] + dup v7.4s, v4.s[1] + zip1 v8.4s, v6.4s, v7.4s + + dup v6.4s, v4.s[2] + dup v7.4s, v4.s[3] + zip1 v9.4s, v6.4s, v7.4s + + dup v6.4s, v5.s[0] + dup v7.4s, v5.s[1] + zip1 v10.4s, v6.4s, v7.4s + + dup v6.4s, v5.s[2] + dup v7.4s, v5.s[3] + zip1 v11.4s, v6.4s, v7.4s + +.L\Mode\().InitBlock\Columns\().x\Rows\().SkipScaleByZeroPointB: + dup v4.4s, v12.s[0] //boardcast RowSums + dup v5.4s, v12.s[1] + + uzp1 v6.2d, v4.2d, v5.2d + + InitRowAccumulators \Columns\(),16,17,18,19,6 +.if \Rows\() > 2 + dup v4.4s, v12.s[2] //boardcast RowSums + dup v5.4s, v12.s[3] + + uzp1 v6.2d, v4.2d, v5.2d + + InitRowAccumulators \Columns\(),20,21,22,23,6 +.endif +.if \Rows\() > 4 + dup v4.4s,v13.s[0] // broadcast row sums + dup v5.4s,v13.s[1] + + uzp1 v6.2d, v4.2d, v5.2d + + InitRowAccumulators \Columns\(),24,25,26,27,6 +.endif +.if \Rows\() > 6 + dup v4.4s,v13.s[2] // broadcast row sums + dup v5.4s,v13.s[3] + + 
uzp1 v6.2d, v4.2d, v5.2d + InitRowAccumulators \Columns\(),28,29,30,31,6 +.endif + + .endm + + +// LoadPackedMatrixABy16Elements +// +// Generates the code to load 16 elements from matrix A. +// + .macro LoadPackedMatrixABy16Elements Rows +.if \Rows\() == 1 + ldr q8,[x0],#8 +.else + ldr q8,[x0],#16 + +.if \Rows\() > 2 + ldr q9,[x0],#16 +.endif + +.if \Rows\() > 4 + ldr q10,[x0],#16 +.endif + +.if \Rows\() > 6 + ldr q11,[x0],#16 +.endif +.endif + .endm + + +// +// MultiplyAccumulateRow +// +// Generates the code to multiply and accumulate a single row of the output +// block. +// + + .macro MultiplyAccumulateRow Columns, MatrixAReg, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg + + ummla v\Vec1Reg\().4s, \MatrixAReg\().16b, v4.16b +.if \Columns\() > 2 + ummla v\Vec2Reg\().4s, \MatrixAReg\().16b, v5.16b +.endif +.if \Columns\() > 4 + ummla v\Vec3Reg\().4s, \MatrixAReg\().16b, v6.16b +.endif +.if \Columns\() > 6 + ummla v\Vec4Reg\().4s, \MatrixAReg\().16b, v7.16b +.endif + + .endm + +// +// MultiplyAccumulateBlock +// +// Generates the code to multiply and accumulate into the output block. +// + + .macro MultiplyAccumulateBlock Columns, Rows + + MultiplyAccumulateRow \Columns\(),v8,16,17,18,19 +.if \Rows\() > 2 + MultiplyAccumulateRow \Columns\(),v9,20,21,22,23 +.endif +.if \Rows\() > 4 + MultiplyAccumulateRow \Columns\(),v10,24,25,26,27 +.endif +.if \Rows\() > 6 + MultiplyAccumulateRow \Columns\(),v11,28,29,30,31 +.endif + + .endm + +// +// ComputeBlockLoop +// +// Generates the code to loop over K entries of the input matrices to produce +// the output block. +// + + .macro ComputeBlockLoop Mode, Columns, Rows + + InitBlockAccumulators \Mode\(), \Columns\(),\Rows\() + + sub x9,x3,#1 // block count to process + tbnz x9,#63,.L\Mode\().ProcessRemaining\Columns\().x\Rows\().Blocks + +.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4Loop: + + LoadPackedMatrixABy16Elements \Rows\() + ld1 {v4.16b - v7.16b}, [x1], #64 + MultiplyAccumulateBlock \Columns\(),\Rows\() + + sub x9,x9,#1 + tbz x9,#63,.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4Loop +.L\Mode\().ProcessRemaining\Columns\().x\Rows\().Blocks: + add x9,x9,#1 // correct for over-subtract above + cbz x9,.L\Mode\().Output\Columns\().x\Rows\().Block + +.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4PaddedLoop: + LoadPackedMatrixABy16Elements \Rows\() + ld1 {v4.16b - v7.16b}, [x1], #64 + MultiplyAccumulateBlock \Columns\(),\Rows\() + +.L\Mode\().Output\Columns\().x\Rows\().Block: + + .endm + + +// +// OutputRow2Element +// OutputRow4Element +// OutputRow6Element +// OutputRow8Element +// OutputRow10Element +// OutputRow12Element +// OutputRow14Element +// OutputRow16Element +// +// Generates the code to store elements to the output block. 
+// + + .macro OutputRow2Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr s8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr s9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + mov v8.S[2], v9.S[0] + add v8.4s,v8.4s,v\Vec1Reg\().4s + + mov w27, v8.S[0] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov w27, v8.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + mov w27, v\Vec1Reg\().S[0] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov w27, v\Vec1Reg\().S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.endif + + .endm + + + .macro OutputRow4Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr d8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr d9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + + mov v8.D[1], v9.D[0] + + add v8.4s,v8.4s,v\Vec1Reg\().4s + + mov x27, v8.D[0] + mov x28, v8.D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.else + mov x27, v\Vec1Reg\().D[0] + mov x28, v\Vec1Reg\().D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.endif + + .endm + + + .macro OutputRow6Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr d8,[\AddrReg1\()],#8 + ldr w28,[\AddrReg1\()],#-8 + mov v8.S[2], w28 +.if \last_row\() == 0 + ldr d9,[\AddrReg2\()],#8 + ldr w27,[\AddrReg2\()],#-8 + mov v9.S[2], w27 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + + mov x27, v8.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v8.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov x27, v9.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v9.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + mov x27, v4.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v4.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov x27, v5.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v5.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.endif + + .endm + + + .macro OutputRow8Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + +.endif + + .endm + + + .macro OutputRow10Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr w28, [\AddrReg1\()],#-16 + +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr w27,[\AddrReg2\()],#-16 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + + str 
q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + mov v8.S[0], w28 + mov v8.S[2], w27 + + add v8.4s,v8.4s,v\Vec3Reg\().4s + + mov w27, v8.S[0] + mov w28, v8.S[2] + + str w27, [\AddrReg1\()],#4 +.if \last_row\() == 0 + str w28, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + mov w27, v\Vec3Reg\().S[0] + mov w28, v\Vec3Reg\().S[2] + + str w27, [\AddrReg1\()],#4 +.if \last_row\() == 0 + str w28, [\AddrReg2\()],#4 +.endif +.endif + +.endm + + + .macro OutputRow12Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr d10,[\AddrReg1\()],#-16 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr d11,[\AddrReg2\()],#-16 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 + mov v11.D[0],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + + mov v10.D[1], v11.D[0] + + add v10.4s,v10.4s,v\Vec3Reg\().4s + + mov x27, v10.D[0] + mov x28, v10.D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + mov x27, v\Vec3Reg\().D[0] + mov x28, v\Vec3Reg\().D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif +.endif + + .endm + + .macro OutputRow14Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr d10,[\AddrReg1\()],#8 + ldr w28, [\AddrReg1\()],#-24 + mov v10.S[2], w28 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr d11,[\AddrReg2\()],#8 + ldr w27,[\AddrReg2\()],#-24 + mov v11.S[2], w27 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 + + mov v11.D[0],x27 + mov v11.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + add v10.4s,v10.4s,v6.4s + add v11.4s,v11.4s,v7.4s + + str q8,[\AddrReg1\()],#16 + + mov x27, v10.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v10.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 + mov x27, v11.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v11.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + str q4,[\AddrReg1\()],#16 + mov x27, v6.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v6.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 + mov x27, v7.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v7.S[2] + str w27, [\AddrReg2\()],#4 +.endif +.endif + + .endm + + + .macro OutputRow16Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldp q8,q10,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldp q9,q11,[\AddrReg2\()],#0 +.else + mov x27,#0 + 
mov v9.D[0],x27 + mov v9.D[1],x27 + + mov v11.D[0],x27 + mov v11.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + add v10.4s,v10.4s,v6.4s + add v11.4s,v11.4s,v7.4s + + stp q8,q10,[\AddrReg1\()],#32 +.if \last_row\() == 0 + stp q9,q11,[\AddrReg2\()],#32 +.endif +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + stp q4,q6,[\AddrReg1\()],#32 +.if \last_row\() == 0 + stp q5,q7,[\AddrReg2\()],#32 +.endif +.endif + + .endm + +// +// OutputBlock +// +// Generates the code to store the output block. +// + + .macro OutputBlock Mode, Columns, Rows + + OutputRow\Columns\()Element \Mode\(),x2,x13,16,17,18,19,(\Rows\() == 1) + +.if \Rows\() > 2 + OutputRow\Columns\()Element \Mode\(),x14,x15,20,21,22,23,(\Rows\() == 3) +.endif + +.if \Rows\() > 4 + OutputRow\Columns\()Element \Mode\(),x16,x17,24,25,26,27,(\Rows\() == 5) +.endif + +.if \Rows\() > 6 + OutputRow\Columns\()Element \Mode\(),x18,x19,28,29,30,31,(\Rows\() == 7) +.endif + + .endm +// +// ProcessRows +// +// Generates the code to process a compute and store the output block for a +// fixed number of rows. +// + + .macro ProcessRows Mode, Rows + mov x4,#\Rows\() // return number of rows handled + cmp x5,#6 + ble .L\Mode\().ProcessNextColumnLoop6x\Rows\() + +.L\Mode\().ProcessNextColumnLoop8x\Rows\(): + ComputeBlockLoop \Mode\(),8,\Rows\() + + sub x5,x5,#8 + cmp x5,#0 + blt .L\Mode\().Output14ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),16,\Rows\() + mov x0,x8 // reload matrix A + cmp x5,#6 + bgt .L\Mode\().ProcessNextColumnLoop8x\Rows\() + cbz x5,.L\Mode\().ExitKernel + +.L\Mode\().ProcessNextColumnLoop6x\Rows\(): + + cmp x5,#4 + ble .L\Mode\().ProcessNextColumnLoop4x\Rows\() + ComputeBlockLoop \Mode\(),6,\Rows\() + sub x5,x5,#6 + cmp x5,#0 + blt .L\Mode\().Output10ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),12,\Rows\() + mov x0,x8 // reload matrix A + cmp x5,#4 + bgt .L\Mode\().ProcessNextColumnLoop6x\Rows\() + b .L\Mode\().ExitKernel + +.L\Mode\().ProcessNextColumnLoop4x\Rows\(): + cmp x5,#2 + ble .L\Mode\().ProcessNextColumnLoop2x\Rows\() + ComputeBlockLoop \Mode\(),4,\Rows\() + sub x5,x5,#4 + cmp x5,#0 + blt .L\Mode\().Output6ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),8,\Rows\() + mov x0,x8 // reload matrix A + cmp x5,#2 + bgt .L\Mode\().ProcessNextColumnLoop4x\Rows\() + b .L\Mode\().ExitKernel + +.L\Mode\().ProcessNextColumnLoop2x\Rows\(): + ComputeBlockLoop \Mode\(),2,\Rows\() + sub x5,x5,#2 + cmp x5,#0 + blt .L\Mode\().Output2ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),4,\Rows\() + mov x0,x8 // reload matrix A + cmp x5,#2 + b .L\Mode\().ExitKernel + +.L\Mode\().Output14ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),14,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output10ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),10,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output6ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),6,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output2ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),2,\Rows\() + b .L\Mode\().ExitKernel + + .endm + + +/*++ + +Routine Description: + + This routine is an inner kernel to compute matrix multiplication for a + set of rows. + +Arguments: + + A (x0) - Supplies the address of matrix A. 
The matrix data has been packed + using MlasGemmQuantCopyPackA. + + B (x1) - Supplies the address of matrix B. The matrix data has been packed + using MlasGemmQuantCopyPackB. + + C (x2) - Supplies the address of matrix C. + + PackedCountK (x3) - Supplies the number of packed columns from matrix A and + the number of packed rows from matrix B to iterate over. + + CountM (x4) - Supplies the maximum number of rows that can be processed for + matrix A and matrix C. The actual number of rows handled for this + invocation depends on the kernel implementation. + + CountN (x5) - Supplies the number of columns from matrix B and matrix C to + iterate over. + + ldc (x6) - Supplies the first dimension of matrix C. + + RowSumBuffer (x7) - Supplies the sum of each row from matrix A. These values + have been pre-scaled by the zero point offset of matrix B if the offset + is per-tensor (ZeroPointB is nullptr). Otherwise, these values must be + scaled by the per-column zero point offsets of matrix B. These values are + accumulated into every row of matrix C. + + ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied + by the zero point offset of matrix A. These values are accumulated into + every column of matrix C. + + ZeroPointB - Optionally supplies the per-column zero point offsets of matrix + B, else nullptr if the matrix B is using per-tensor quantization. + +Return Value: + + Returns the number of rows handled. + +--*/ + + .macro QgemmU8X8KernelUmmlaFunction Mode + + FUNCTION_ENTRY MlasGemmU8X8KernelUmmla\Mode\() + + ldr x10,[sp, #0] + ldr x11,[sp,#8] + + stp x19, x20, [sp, #.LMlasQgemmKernel_SavedRegisters_Neg]! + stp x21, x22, [sp, #.LMlasQgemmKernel_backup_x21_x22] + stp x23, x24, [sp, #.LMlasQgemmKernel_backup_x23_x24] + stp x25, x26, [sp, #.LMlasQgemmKernel_backup_x25_x26] + stp x27, x28, [sp, #.LMlasQgemmKernel_backup_x27_x28] + stp d8, d9, [sp, #.LMlasQgemmKernel_backup_d8_d9] + stp d10, d11, [sp, #.LMlasQgemmKernel_backup_d10_d11] + stp d12, d13, [sp, #.LMlasQgemmKernel_backup_d12_d13] + stp d14, d15, [sp, #.LMlasQgemmKernel_backup_d14_d15] + + add x13,x2,x6,lsl #2 // compute matrix C plus 1 row + add x14,x13,x6,lsl #2 // compute matrix C plus 2 rows + add x15,x14,x6,lsl #2 // compute matrix C plus 3 rows + add x16,x15,x6,lsl #2 // compute matrix C plus 4 rows + add x17,x16,x6,lsl #2 // compute matrix C plus 5 rows + add x18,x17,x6,lsl #2 // compute matrix C plus 6 rows + add x19,x18,x6,lsl #2 // compute matrix C plus 7 rows + + mov x8,x0 // save matrix A + +// +// Process 8 rows of the matrices. +// + ld1 {v12.4s},[x7],#16 // load row sum 1 ~ 4 + cmp x4,#8 + blt .L\Mode\().ProcessCountMLessThan8 + ld1 {v13.4s},[x7],#16 // load row sum 5 ~ 8 + ProcessRows \Mode\(),8 + +// +// Restore non-volatile registers and return. +// + +.L\Mode\().ExitKernel: + mov x0,x4 + + ldp d14, d15, [sp, #.LMlasQgemmKernel_backup_d14_d15] + ldp d12, d13, [sp, #.LMlasQgemmKernel_backup_d12_d13] + ldp d10, d11, [sp, #.LMlasQgemmKernel_backup_d10_d11] + ldp d8, d9, [sp, #.LMlasQgemmKernel_backup_d8_d9] + ldp x27, x28, [sp, #.LMlasQgemmKernel_backup_x27_x28] + ldp x25, x26, [sp, #.LMlasQgemmKernel_backup_x25_x26] + ldp x23, x24, [sp, #.LMlasQgemmKernel_backup_x23_x24] + ldp x21, x22, [sp, #.LMlasQgemmKernel_backup_x21_x22] + ldp x19, x20, [sp], #.LMlasQgemmKernel_SavedRegisters + + ret + +// +// Process 4 rows of the matrix. 
+// + +.L\Mode\().ProcessCountMLessThan8: + cmp x4,#4 + blt .L\Mode\().ProcessCountMLessThan4 + ProcessRows \Mode\(),4 + b .L\Mode\().ExitKernel + +// +// Process 2 row of the matrix. +// + +.L\Mode\().ProcessCountMLessThan4: + cmp x4,#2 + blt .L\Mode\().ProcessCountMLessThan2 + + ProcessRows \Mode\(),2 + b .L\Mode\().ExitKernel + + +// +// Process the last row of the matrix. +// + +.L\Mode\().ProcessCountMLessThan2: + ProcessRows \Mode\(),1 + b .L\Mode\().ExitKernel + + + .endm + + QgemmU8X8KernelUmmlaFunction Zero + QgemmU8X8KernelUmmlaFunction Add + + .end diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index b6ac4a1ca1d6c..e0c2772cbb719 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -184,11 +184,17 @@ class MLASCPUIDInfo bool IsCurrentCoreArmv8NarrowLd() const { return false; } + bool HasArmNeon_I8MM() const { return has_arm_neon_i8mm_; } + + bool HasArmSVE_I8MM() const { return has_arm_sve_i8mm_; } + private: MLASCPUIDInfo(); bool has_arm_neon_dot_{false}; bool has_fp16_{false}; + bool has_arm_neon_i8mm_{false}; + bool has_arm_sve_i8mm_{false}; }; using MLAS_CPUIDINFO = MLASCPUIDInfo; @@ -856,6 +862,8 @@ extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchNeon; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmX8S8DispatchNeon; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchUdot; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmS8S8DispatchSdot; +extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchUmmla; +extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmS8S8DispatchSmmla; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchWasmSimd; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmQuantDispatchDefault; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemm8X8DispatchPOWER10; @@ -1069,6 +1077,23 @@ MlasTrySimpleParallel( const std::function& Work ); + +/** + * @brief Distribute many iterations of work over a thread pool if supported. + * This function is for small workloads in non-performance critical situation. + * + * @param ThreadPool [IN] Optional thread pool. Ignored when using OpenMP + * @param Iterations [IN] Total number of iterations + * @param Work [IN] Logic for computing a range of iterations [begin, end) + */ +void +MlasTryBatchParallel( + MLAS_THREADPOOL * ThreadPool, + const std::ptrdiff_t Iterations, + const std::function& Work + ); + + inline ptrdiff_t MlasGetMaximumThreadCount( diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 3c0f82408179b..39586282e00ad 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -52,6 +52,14 @@ MLASCPUIDInfo::MLASCPUIDInfo() #define HWCAP_ASIMDDP (1 << 20) #endif +#ifndef HWCAP2_I8MM +#define HWCAP2_I8MM (1 << 13) +#endif + +#ifndef HWCAP2_SVEI8MM +#define HWCAP2_SVEI8MM (1 << 9) +#endif + #if defined(BUILD_MLAS_NO_ONNXRUNTIME) MLASCPUIDInfo::MLASCPUIDInfo() { @@ -59,6 +67,9 @@ MLASCPUIDInfo::MLASCPUIDInfo() // raw hack! Need CPUIDInfo implementation for more precise detection has_fp16_ = has_arm_neon_dot_; + + has_arm_neon_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_I8MM) != 0); + has_arm_sve_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_SVEI8MM) != 0); } #endif @@ -480,6 +491,17 @@ Return Value: this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchDot; } +#if defined(__linux__) + // + // Check if the processor supports ASIMD I8MM instructions. 
+ // + if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeon_I8MM()) { + this->GemmU8U8Dispatch = &MlasGemmU8X8DispatchUmmla; + this->GemmU8S8Dispatch = &MlasGemmU8X8DispatchUmmla; + this->GemmS8S8Dispatch = &MlasGemmS8S8DispatchSmmla; + } +#endif + #endif // MLAS_TARGET_ARM64 #if defined(MLAS_TARGET_POWER) this->GemmFloatKernel = MlasSgemmKernel; diff --git a/onnxruntime/core/mlas/lib/q4_dq.cpp b/onnxruntime/core/mlas/lib/q4_dq.cpp index 85c0d13006126..fbd1030de8ab7 100644 --- a/onnxruntime/core/mlas/lib/q4_dq.cpp +++ b/onnxruntime/core/mlas/lib/q4_dq.cpp @@ -294,3 +294,664 @@ MlasQ4GemmUnPackB( return MlasQ4GemmUnPackBImpl(FpData, PackedBuf, N, K, ldb); } } + + + +/*************************************************************** + * The quantization format that pack data and quantization + * parameters into separate buffers. + */ + + +template < + int Row_, ///< rows of a matrix + int Column_ ///< columns of a matrix + > +struct Shape2D { + static int const kRow = Row_; ///< rows of a matrix + static int const kColumn = Column_; ///< columns of a matrix + static int const kCount = Row_ * Column_; ///< total number of elements in a matrix +}; + + +template +struct BitsTraits { + static_assert(qbits <= 8, "Only BitsTraits are for small number of bits!"); + + static constexpr int kBits = qbits; + static constexpr int kMax = (1 << qbits) - 1; + static constexpr int kMid = 1 << (qbits - 1); + static constexpr float kMaxFp = static_cast(kMax); + + // number of qbit elements to pack into whole bytes + static constexpr int kPackSize = (qbits == 8) ? 1 : (qbits == 4) ? 2 : (qbits == 2) ? 4 : 0; + static_assert(kPackSize != 0, "Packing to whole bytes not supported for this qbits!"); +}; + + +/** + * @brief Rectify min/max from a set of weights, and convert to scale and zero point + * for quantization + * @tparam ScaleT type of scale, usually floating point of various bits + * @tparam qbits number of int bits used for zero point value + * @param[in] min + * @param[in] max + * @param[out] scale + * @param[out] zp + */ +template +MLAS_FORCEINLINE +void +range2scalezp(float min, float max, ScaleT& scale, uint8_t& zp) +{ + constexpr int zp_max = BitsTraits::kMax; + constexpr float zp_max_fp = BitsTraits::kMaxFp; + + min = std::min(min, 0.0f); + max = std::max(max, 0.0f); + + float scale_f = (max - min) / zp_max; + + float zero_point_fp = min; + if (scale_f != 0.0f) { + zero_point_fp = 0.f - min / scale_f; + } + + if (zero_point_fp < 0.0f) { + zp = 0; + } else if (zero_point_fp > zp_max_fp) { + zp = zp_max; + } else { + zp = (uint8_t)roundf(zero_point_fp); + } + scale = ScaleT(scale_f); +} + +template +MLAS_FORCEINLINE +void +range2scale(float min, float max, ScaleT& scale) +{ + constexpr int mid_v = BitsTraits::kMid; + constexpr float mid_fp = static_cast(-mid_v); + + max = fabsf(max) > fabsf(min) ? max : min; + + scale = ScaleT(max / mid_fp); +}; + + +/** + * @brief Blockwise quantization methods + * @tparam ElementT source data type, e.g. 
fp32/fp16 + * @tparam block_size number of elemenets quantized together + * @tparam qbits number of bits in each quantized element + * @tparam Columnwise true: elements in a block come from one single column + * false: elements in a block come from one single row + */ +template < + typename ElementT, + int32_t block_size, + int32_t qbits, + bool Columnwise> +struct BlockwiseQuantizer { + // To support other qbits, need to add bit packing code for + // storing to dst and zero points + static_assert(qbits == 4, "Only 4b block quantization is supported!"); + + using QuantBlk = std::conditional_t, Shape2D<1, block_size>>; + using ThreadBlk = Shape2D::kPackSize, QuantBlk::kColumn>; + + static + MLAS_FORCEINLINE + void quantizeMetaShape(int rows, int columns, int& meta_rows, int& meta_cols) + { + meta_rows = (rows + QuantBlk::kRow - 1) / QuantBlk::kRow; + meta_cols = (columns + QuantBlk::kColumn - 1) / QuantBlk::kColumn; + } + + static + MLAS_FORCEINLINE + void quantizedShape(int rows, int columns, int& q_rows, int& q_cols) { + int meta_rows; + int meta_cols; + quantizeMetaShape(rows, columns, meta_rows, meta_cols); + + // quantized matrix is stored in column major, packed by column + q_rows = (meta_rows * QuantBlk::kRow * qbits + 7) / 8; + q_cols = meta_cols * QuantBlk::kColumn; + } + + /** + * @brief Quantized a Matrix shape [rows, columns], resulting quantized + * and packed data are stored in column major (transposed) + * @param[out] dst pointer to the quantized weights, column major: [columns, rows] + * @param[out] scale pointer to the scales, column major: [columns/QuantBlk::kColumn, rows/QuantBlk::kRow] + * @param[out] zero_points pointer to the zero points, same shape as scale + * @param[in] src pointer to the source matrix, row major: [rows, columns] + * @param rows + * @param columns + * @param leadingDimension stride of the source matrix, i.e. 
distance from one row to the next + */ + static void quantizeAndTranspose( + uint8_t* dst, + ElementT* scales, + uint8_t* zero_points, + const ElementT* src, + int32_t rows, + int32_t columns, + int32_t leadingDimension, + MLAS_THREADPOOL* thread_pool) + { + // Thread partitioning + const auto thrd_row_blks = (rows + ThreadBlk::kRow - 1) / ThreadBlk::kRow; + const auto thrd_col_blks = (columns + ThreadBlk::kColumn - 1) / ThreadBlk::kColumn; + const auto total_thrd_blks = thrd_row_blks * thrd_col_blks; + + const auto row_blks = (rows + QuantBlk::kRow - 1) / QuantBlk::kRow; + + int q_rows, q_cols; + quantizedShape(rows, columns, q_rows, q_cols); + + MlasTryBatchParallel( + thread_pool, total_thrd_blks, + [&](ptrdiff_t block_idx) { + uint8_t zp_bytes[BitsTraits::kPackSize]; + std::fill_n(zp_bytes, BitsTraits::kPackSize, (uint8_t)8); + + const int32_t r_blk_idx = static_cast(block_idx / thrd_col_blks); + const int32_t c_blk_idx = static_cast(block_idx % thrd_col_blks); + + const int32_t r = r_blk_idx * ThreadBlk::kRow; + const int32_t c = c_blk_idx * ThreadBlk::kColumn; + + const int32_t r_end = std::min(r + ThreadBlk::kRow, rows); + const int32_t c_end = std::min(c + ThreadBlk::kColumn, columns); + + const int meta_row = r / QuantBlk::kRow; + const int meta_col = c / QuantBlk::kColumn; + + // compute scale and zero point + for (int kpack = 0; kpack < BitsTraits::kPackSize; kpack++) { + + // scan a single block to extract range [min, max] + float min = std::numeric_limits::max(); + float max = -min; + const int row_start = r + kpack * QuantBlk::kRow; + const int row_end = std::min(row_start + QuantBlk::kRow, r_end); + for (int i = row_start; i < row_end; ++i) { + for (int j = c; j < c_end; ++j) { + const float v = static_cast(src[i * leadingDimension + j]); + if (v < min) min = v; + if (v > max) max = v; + } + } + + // store scale and zero point at quant parameter matrix position + if (row_start < row_end) { + const int32_t meta_idx = meta_col * row_blks + meta_row + kpack; + if (zero_points == nullptr) { + range2scale(min, max, scales[meta_idx]); + } else { + range2scalezp(min, max, scales[meta_idx], zp_bytes[kpack]); + } + } + } + + // !! 4b specific code as we need to pack 2 4b numbers into one byte + if (zero_points != nullptr) { + const int32_t meta_idx = meta_col * ((row_blks + 1) / 2) + meta_row / 2; + zero_points[meta_idx] = (zp_bytes[0] & 0xf) | (zp_bytes[1] << 4); + } + + for (int32_t j = c; j < c_end; ++j) { + const int32_t meta_c = j / QuantBlk::kColumn; + for (int32_t i = r; i < r_end; i += 2) { + const int32_t meta_r = i / QuantBlk::kRow; + const float scale = static_cast(scales[meta_c * row_blks + meta_r]); + const float reciprocal_scale = scale ? 1.0f / scale : 0.0f; + const int8_t zp = zp_bytes[meta_r & 1]; + const int8_t zp1 = zp_bytes[((i + 1) / QuantBlk::kRow) & 1]; + + const float v0 = static_cast(src[i * leadingDimension + j]); + const uint8_t vi0 = (uint8_t)std::clamp(roundf(v0 * reciprocal_scale + zp), + 0.0f, BitsTraits::kMaxFp); + + uint8_t vi1 = (uint8_t)zp; + if (i + 1 < r_end) { + float reciprocal_scale1 = reciprocal_scale; + if constexpr (QuantBlk::kRow == 1) { + const float scale1 = + static_cast(scales[meta_c * row_blks + meta_r + 1]); + reciprocal_scale1 = scale1 ? 1.0f / scale1 : 0.0f; + } + const float v1 = static_cast(src[(i + 1) * leadingDimension + j]); + vi1 = (uint8_t)std::clamp(roundf(v1 * reciprocal_scale1 + zp1), 0.0f, + BitsTraits::kMaxFp); + } + + // !! 
4b specific code + dst[j * q_rows + i / 2] = (vi0 & 0xf) | (vi1 << 4); + } + } + }); + } + + /** + * @brief Dequantize a column major quantized matrix, and store the result in a column major + * matrix for use in GEMM + * @param[out] dst pointer to the dequantized matrix, column major: [columns, rows] + * @param[in] weights pointer to the quantized weights, column major: [columns, rows] + * @param[in] scales pointer to the scales of quantized blocks, column major layout + * @param[in] zero_points pointer to the zero points of quantized blocks, packed column major + * scales + * @param[in] rows + * @param[in] columns + */ + static void dequantize( + ElementT* dst, + const uint8_t* weights, + const ElementT* scales, + const uint8_t* zero_points, + int32_t rows, + int32_t columns, + MLAS_THREADPOOL* thread_pool) + { + // Thread partitioning + const auto thrd_row_blks = (rows + ThreadBlk::kRow - 1) / ThreadBlk::kRow; + const auto thrd_col_blks = (columns + ThreadBlk::kColumn - 1) / ThreadBlk::kColumn; + const auto total_thrd_blks = thrd_row_blks * thrd_col_blks; + + const auto row_blks = (rows + QuantBlk::kRow - 1) / QuantBlk::kRow; + + int q_rows, q_cols; + quantizedShape(rows, columns, q_rows, q_cols); + + MlasTryBatchParallel( + thread_pool, total_thrd_blks, + [&](ptrdiff_t block_idx) { + int32_t r_blk_idx = static_cast(block_idx / thrd_col_blks); + int32_t c_blk_idx = static_cast(block_idx % thrd_col_blks); + + int32_t r = r_blk_idx * ThreadBlk::kRow; + int32_t c = c_blk_idx * ThreadBlk::kColumn; + + int32_t r_end = std::min(r + ThreadBlk::kRow, rows); + int32_t c_end = std::min(c + ThreadBlk::kColumn, columns); + + for (int32_t j = c; j < c_end; ++j) { + const int32_t meta_col = j / QuantBlk::kColumn; + + // !! 4b specific code + // the whole loop is 4b specific due to sub 8 bit packing + // and unpacking. We can potentially make this qbits generic + // by wraping the packing/unpacking code like cutlass::Array + for (int32_t i = r; i < r_end; i += 2) { + const int32_t meta_row = i / QuantBlk::kRow; + + const float scale0 = + static_cast(scales[meta_col * row_blks + meta_row]); + + const int zp_pair = + (zero_points == nullptr) + ? 0x88 + : zero_points[meta_col * ((row_blks + 1) / 2) + meta_row / 2]; + const int zp0 = (meta_row & 1) ? 
(zp_pair >> 4) : (zp_pair & 0xf); + + const uint8_t vi0 = weights[j * q_rows + i / 2] & 0xf; + const float v0 = (static_cast(vi0) - zp0) * scale0; + + dst[j * rows + i] = static_cast(v0); + if ((i + 1) < r_end) { + float scale1 = scale0; + int zp1 = zp0; + if constexpr (QuantBlk::kRow == 1) { + scale1 = + static_cast(scales[meta_col * row_blks + meta_row + 1]); + zp1 = (zp_pair >> 4) & 0xf; + } + const uint8_t vi1 = weights[j * q_rows + i / 2] >> 4; + const float v1 = (static_cast(vi1) - zp1) * scale1; + dst[j * rows + (i + 1)] = static_cast(v1); + } + } + } + }); + } +}; + + +template +void +MlasBlockwiseQuantMetaShape( + int block_size, + bool columnwise, + int rows, + int columns, + int& meta_rows, + int& meta_cols + ) +{ + switch (block_size) { + case 16: { + if (columnwise) { + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); + } else { + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); + } + break; + } + case 32: { + if (columnwise) { + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); + } else { + BlockwiseQuantizer::quantizeMetaShape( + rows, columns, meta_rows, meta_cols); + } + break; + } + case 64: { + if (columnwise) { + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, + meta_cols); + } else { + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, + meta_cols); + } + break; + } + case 128: { + if (columnwise) { + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, + meta_cols); + } else { + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, + meta_cols); + } + break; + } + case 256: { + if (columnwise) { + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, + meta_cols); + } else { + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, + meta_cols); + } + break; + } + default: + meta_rows = 0; + meta_cols = 0; + break; + } +} + + + +template +void +MlasBlockwiseQuantizedShape( + int block_size, + bool columnwise, + int rows, + int columns, + int& q_rows, + int& q_cols + ) +{ + switch (block_size) { + case 16: { + if (columnwise) { + BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + } else { + BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + } + break; + } + case 32: { + if (columnwise) { + BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + } else { + BlockwiseQuantizer::quantizedShape( + rows, columns, q_rows, q_cols); + } + break; + } + case 64: { + if (columnwise) { + BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + } else { + BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + } + break; + } + case 128: { + if (columnwise) { + BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + } else { + BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + } + break; + } + case 256: { + if (columnwise) { + BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + } else { + BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + } + break; + } + default: + q_rows = 0; + q_cols = 0; + break; + } +} + + +template +void +MlasBlockwiseQuantMetaShape( + int block_size, + bool columnwise, + int rows, + int columns, + int& meta_rows, + int& meta_cols + ); + +template +void +MlasBlockwiseQuantizedShape( + int block_size, + bool columnwise, + int rows, + int columns, + int& q_rows, + int& q_cols + ); + + +template +void +MlasQuantizeBlockwise( + uint8_t* dst, + T* scales, + uint8_t* zero_points, + 
const T* src, + int block_size, + bool columnwise, + int rows, + int columns, + int leading_dimension, + MLAS_THREADPOOL* thread_pool + ) +{ + switch (block_size) { + case 16: + if (columnwise) { + BlockwiseQuantizer::quantizeAndTranspose( + dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); + } else { + BlockwiseQuantizer::quantizeAndTranspose( + dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); + } + break; + + case 32: + if (columnwise) { + BlockwiseQuantizer::quantizeAndTranspose( + dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); + } else { + BlockwiseQuantizer::quantizeAndTranspose( + dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); + } + break; + + case 64: + if (columnwise) { + BlockwiseQuantizer::quantizeAndTranspose( + dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); + } else { + BlockwiseQuantizer::quantizeAndTranspose( + dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); + } + break; + + case 128: + if (columnwise) { + BlockwiseQuantizer::quantizeAndTranspose( + dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); + } else { + BlockwiseQuantizer::quantizeAndTranspose( + dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); + } + break; + + case 256: + if (columnwise) { + BlockwiseQuantizer::quantizeAndTranspose( + dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); + } else { + BlockwiseQuantizer::quantizeAndTranspose( + dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); + } + break; + + default: + // Only block size 16, 32, 64, 128, 256 are supported. + break; + } +} + +template +void +MlasQuantizeBlockwise( + uint8_t* dst, + float* scales, + uint8_t* zero_points, + const float* src, + int block_size, + bool columnwise, + int rows, + int columns, + int leading_dimension, + MLAS_THREADPOOL* thread_pool + ); + +template +void +MlasQuantizeBlockwise( + uint8_t* dst, + MLAS_FP16* scales, + uint8_t* zero_points, + const MLAS_FP16* src, + int block_size, + bool columnwise, + int rows, + int columns, + int leading_dimension, + MLAS_THREADPOOL* thread_pool + ); + + +template +void +MlasDequantizeBlockwise( + T* dst, + const uint8_t* src, + const T* scales, + const uint8_t* zero_points, + int block_size, + bool columnwise, + int rows, + int columns, + MLAS_THREADPOOL* thread_pool + ) +{ + switch (block_size) { + case 16: + if (columnwise) { + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + columns, thread_pool); + } else { + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + columns, thread_pool); + } + break; + case 32: + if (columnwise) { + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + columns, thread_pool); + } else { + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + columns, thread_pool); + } + break; + case 64: + if (columnwise) { + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + columns, thread_pool); + } else { + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + columns, thread_pool); + } + break; + case 128: + if (columnwise) { + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + columns, thread_pool); + } else { + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, + rows, columns, thread_pool); + } + break; + case 256: + if (columnwise) { + 
BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + columns, thread_pool); + } else { + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, + rows, columns, thread_pool); + } + break; + default: + // Only block size 16, 32, 64, 128, 256 are supported. + break; + } +} + +template +void +MlasDequantizeBlockwise( + float* dst, + const uint8_t* src, + const float* scales, + const uint8_t* zero_points, + int block_size, + bool columnwise, + int rows, + int columns, + MLAS_THREADPOOL* thread_pool + ); diff --git a/onnxruntime/core/mlas/lib/qgemm_kernel_smmla.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_smmla.cpp new file mode 100644 index 0000000000000..c41f43ca22d18 --- /dev/null +++ b/onnxruntime/core/mlas/lib/qgemm_kernel_smmla.cpp @@ -0,0 +1,964 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the MIT License. + +Module Name: + + qgemm_kernel_smmla.cpp + +Abstract: + + This module implements smmla QGEMM kernel. + +--*/ + +#include "mlasi.h" +#include "qgemm.h" + +// +// Define the prototypes of the NEON SMMLA routines written in assembly. +// + +extern "C" { + +size_t MLASCALL +MlasGemmS8S8KernelSmmlaZero(const uint8_t* A, + const uint8_t* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumVector, + const int32_t* ColumnSumVector, + const int32_t* ZeroPointB); + +size_t MLASCALL +MlasGemmS8S8KernelSmmlaAdd(const uint8_t* A, + const uint8_t* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumVector, + const int32_t* ColumnSumVector, + const int32_t* ZeroPointB); +} + +struct MLAS_GEMM_S8S8_KERNEL_SMMLA { + typedef uint8_t PackedAType; + typedef uint8_t PackedBType; + typedef int8_t OffsetAType; + typedef int8_t OffsetBType; + + static constexpr size_t PackedK = 8; + static constexpr MLAS_GEMM_QUANT_STRIDES Strides{24, 128, 256}; + static constexpr MLAS_GEMM_QUANT_STRIDES PackedStrides{24, 128, 384}; +}; + +constexpr size_t MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedK; +constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_S8S8_KERNEL_SMMLA::Strides; +constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedStrides; + +template <> +MLAS_FORCEINLINE int32_t +MlasGemmQuantFixupZeroPointB(int32_t ZeroPointB, bool BIsSigned) +{ + MLAS_UNREFERENCED_PARAMETER(BIsSigned); + return ZeroPointB; +} + +template <> +void +MlasGemmQuantCopyPackA( + MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedAType* D_uint8_t, + const uint8_t* A, + size_t lda, + size_t CountM, + size_t CountK, + int32_t* RowSumBuffer, + bool AIsSigned) +{ + int8_t* D = reinterpret_cast(D_uint8_t); + MLAS_UNREFERENCED_PARAMETER(AIsSigned); + int8_t PaddedMatrixAData[64]; + + // + // Process 8 rows of matrix A. + // + // MMLA kernels load 8x8 block of A with four vector registers. So A is packed + // a series of 64 byte vectors where eight rows are interleaved with the + // following pattern: + // + // [ A0 A1 A2 A3 A4 A5 A6 A7 ] + // [ B0 B1 B2 B3 B4 B5 B6 B7 ] + // [ C0 C1 C2 C3 C4 C5 C6 C7 ] + // [ D0 D1 D2 D3 D4 D5 D6 D7 ] + // [ E0 E1 E2 E3 E4 E5 E6 E7 ] + // [ F0 F1 F2 F3 F4 F5 F6 F7 ] + // [ G0 G1 G2 G3 G4 G5 G6 G7 ] + // [ H0 H1 H2 H3 H4 H5 H6 H7 ] + // + // ... + // + // This pattern is repeated (CountK / 8) times. + // + // If CountK is not aligned to a multiple of eight, then the vector is padded + // with zeroes. 
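Reviewer note: the interleaving described above can be stated as a scalar reference. This is an illustrative sketch only (the function name, the std::vector return, and the simplified single 8-row slice are assumptions, not part of this change); the NEON loops below are the real implementation.

// Scalar reference for the 8-row interleaved packing: each group of eight K
// elements becomes one 64-byte block holding row 0's eight bytes, then row 1's,
// ... through row 7, with the final group zero padded when CountK % 8 != 0.
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

std::vector<int8_t> PackEightRowsReference(const int8_t* A, size_t lda, size_t CountK)
{
    const size_t KBlocks = (CountK + 7) / 8;
    std::vector<int8_t> packed(KBlocks * 64, 0);   // zero fill covers the K tail
    for (size_t kb = 0; kb < KBlocks; ++kb) {
        const size_t valid = std::min<size_t>(8, CountK - kb * 8);
        for (size_t row = 0; row < 8; ++row) {
            std::memcpy(&packed[kb * 64 + row * 8], &A[row * lda + kb * 8], valid);
        }
    }
    return packed;
}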
+ // + + while (CountM >= 8) { + const int8_t* a0 = reinterpret_cast(A); + const int8_t* a1 = a0 + lda; + const int8_t* a2 = a0 + lda * 2; + const int8_t* a3 = a0 + lda * 3; + const int8_t* a4 = a0 + lda * 4; + const int8_t* a5 = a0 + lda * 5; + const int8_t* a6 = a0 + lda * 6; + const int8_t* a7 = a0 + lda * 7; + + size_t k = CountK; + int32x4_t RowSums0 = vmovq_n_s32(0); + int32x4_t RowSums1 = vmovq_n_s32(0); + + while (k >= 16) { + int64x2_t v0 = vld1q_s64(reinterpret_cast(a0)); + a0 += 16; + int64x2_t v1 = vld1q_s64(reinterpret_cast(a1)); + a1 += 16; + int64x2_t v2 = vld1q_s64(reinterpret_cast(a2)); + a2 += 16; + int64x2_t v3 = vld1q_s64(reinterpret_cast(a3)); + a3 += 16; + int64x2_t v4 = vld1q_s64(reinterpret_cast(a4)); + a4 += 16; + int64x2_t v5 = vld1q_s64(reinterpret_cast(a5)); + a5 += 16; + int64x2_t v6 = vld1q_s64(reinterpret_cast(a6)); + a6 += 16; + int64x2_t v7 = vld1q_s64(reinterpret_cast(a7)); + a7 += 16; + + int64x2_t z0 = vzip1q_s64(v0, v1); + int64x2_t z1 = vzip2q_s64(v0, v1); + int64x2_t z2 = vzip1q_s64(v2, v3); + int64x2_t z3 = vzip2q_s64(v2, v3); + + int64x2_t z4 = vzip1q_s64(v4, v5); + int64x2_t z5 = vzip2q_s64(v4, v5); + int64x2_t z6 = vzip1q_s64(v6, v7); + int64x2_t z7 = vzip2q_s64(v6, v7); + + vst1q_s8(&D[0], vreinterpretq_s8_s64(z0)); + vst1q_s8(&D[16], vreinterpretq_s8_s64(z2)); + vst1q_s8(&D[32], vreinterpretq_s8_s64(z4)); + vst1q_s8(&D[48], vreinterpretq_s8_s64(z6)); + vst1q_s8(&D[64], vreinterpretq_s8_s64(z1)); + vst1q_s8(&D[80], vreinterpretq_s8_s64(z3)); + vst1q_s8(&D[96], vreinterpretq_s8_s64(z5)); + vst1q_s8(&D[112], vreinterpretq_s8_s64(z7)); + + int32x4_t RowSums0L_pada = vmovq_n_s32(0); + RowSums0L_pada = vpadalq_s16(RowSums0L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z0))); + RowSums0L_pada = vpadalq_s16(RowSums0L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z1))); + + int32x4_t RowSums0L_ext = vextq_s32(RowSums0L_pada, RowSums0L_pada, 1); + int32x4_t RowSums0L_add = vaddq_s32(RowSums0L_pada, RowSums0L_ext); + int32x2_t RowSums0L = {vdups_laneq_s32(RowSums0L_add, 0), + vdups_laneq_s32(RowSums0L_add, 2)}; + + int32x4_t RowSums0H_pada = vmovq_n_s32(0); + RowSums0H_pada = vpadalq_s16(RowSums0H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z2))); + RowSums0H_pada = vpadalq_s16(RowSums0H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z3))); + + int32x4_t RowSums0H_ext = vextq_s32(RowSums0H_pada, RowSums0H_pada, 1); + int32x4_t RowSums0H_add = vaddq_s32(RowSums0H_pada, RowSums0H_ext); + int32x2_t RowSums0H = {vdups_laneq_s32(RowSums0H_add, 0), + vdups_laneq_s32(RowSums0H_add, 2)}; + + RowSums0 = vaddq_s32(RowSums0, vcombine_s32(RowSums0L, RowSums0H)); + + int32x4_t RowSums1L_pada = vmovq_n_s32(0); + RowSums1L_pada = vpadalq_s16(RowSums1L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z4))); + RowSums1L_pada = vpadalq_s16(RowSums1L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z5))); + + int32x4_t RowSums1L_ext = vextq_s32(RowSums1L_pada, RowSums1L_pada, 1); + int32x4_t RowSums1L_add = vaddq_s32(RowSums1L_pada, RowSums1L_ext); + int32x2_t RowSums1L = {vdups_laneq_s32(RowSums1L_add, 0), + vdups_laneq_s32(RowSums1L_add, 2)}; + + int32x4_t RowSums1H_pada = vmovq_n_s32(0); + RowSums1H_pada = vpadalq_s16(RowSums1H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z6))); + RowSums1H_pada = vpadalq_s16(RowSums1H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z7))); + + int32x4_t RowSums1H_ext = vextq_s32(RowSums1H_pada, RowSums1H_pada, 1); + int32x4_t RowSums1H_add = vaddq_s32(RowSums1H_pada, RowSums1H_ext); + int32x2_t RowSums1H = {vdups_laneq_s32(RowSums1H_add, 0), + vdups_laneq_s32(RowSums1H_add, 2)}; + + RowSums1 = 
vaddq_s32(RowSums1, vcombine_s32(RowSums1L, RowSums1H)); + + D += 128; + k -= 16; + } + + while (k >= 8) { + int64x1_t v0 = *reinterpret_cast(a0); + a0 += 8; + int64x1_t v1 = *reinterpret_cast(a1); + a1 += 8; + int64x1_t v2 = *reinterpret_cast(a2); + a2 += 8; + int64x1_t v3 = *reinterpret_cast(a3); + a3 += 8; + int64x1_t v4 = *reinterpret_cast(a4); + a4 += 8; + int64x1_t v5 = *reinterpret_cast(a5); + a5 += 8; + int64x1_t v6 = *reinterpret_cast(a6); + a6 += 8; + int64x1_t v7 = *reinterpret_cast(a7); + a7 += 8; + + *reinterpret_cast(&D[0]) = v0; + *reinterpret_cast(&D[8]) = v1; + *reinterpret_cast(&D[16]) = v2; + *reinterpret_cast(&D[24]) = v3; + *reinterpret_cast(&D[32]) = v4; + *reinterpret_cast(&D[40]) = v5; + *reinterpret_cast(&D[48]) = v6; + *reinterpret_cast(&D[56]) = v7; + + int64x2_t z01 = vcombine_s64(v0, v1); + int64x2_t z23 = vcombine_s64(v2, v3); + int64x2_t z45 = vcombine_s64(v4, v5); + int64x2_t z67 = vcombine_s64(v6, v7); + + int32x4_t RowSums0L_pada = vmovq_n_s32(0); + RowSums0L_pada = vpadalq_s16(RowSums0L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z01))); + + int32x4_t RowSums0L_ext = vextq_s32(RowSums0L_pada, RowSums0L_pada, 1); + int32x4_t RowSums0L_add = vaddq_s32(RowSums0L_pada, RowSums0L_ext); + int32x2_t RowSums0L = {vdups_laneq_s32(RowSums0L_add, 0), + vdups_laneq_s32(RowSums0L_add, 2)}; + + int32x4_t RowSums0H_pada = vmovq_n_s32(0); + RowSums0H_pada = vpadalq_s16(RowSums0H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z23))); + + int32x4_t RowSums0H_ext = vextq_s32(RowSums0H_pada, RowSums0H_pada, 1); + int32x4_t RowSums0H_add = vaddq_s32(RowSums0H_pada, RowSums0H_ext); + int32x2_t RowSums0H = {vdups_laneq_s32(RowSums0H_add, 0), + vdups_laneq_s32(RowSums0H_add, 2)}; + + RowSums0 = vaddq_s32(RowSums0, vcombine_s32(RowSums0L, RowSums0H)); + + int32x4_t RowSums1L_pada = vmovq_n_s32(0); + RowSums1L_pada = vpadalq_s16(RowSums1L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z45))); + + int32x4_t RowSums1L_ext = vextq_s32(RowSums1L_pada, RowSums1L_pada, 1); + int32x4_t RowSums1L_add = vaddq_s32(RowSums1L_pada, RowSums1L_ext); + int32x2_t RowSums1L = {vdups_laneq_s32(RowSums1L_add, 0), + vdups_laneq_s32(RowSums1L_add, 2)}; + + int32x4_t RowSums1H_pada = vmovq_n_s32(0); + RowSums1H_pada = vpadalq_s16(RowSums1H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z67))); + + int32x4_t RowSums1H_ext = vextq_s32(RowSums1H_pada, RowSums1H_pada, 1); + int32x4_t RowSums1H_add = vaddq_s32(RowSums1H_pada, RowSums1H_ext); + int32x2_t RowSums1H = {vdups_laneq_s32(RowSums1H_add, 0), + vdups_laneq_s32(RowSums1H_add, 2)}; + + RowSums1 = vaddq_s32(RowSums1, vcombine_s32(RowSums1L, RowSums1H)); + + D += 64; + k -= 8; + } + + if (k > 0) { + // + // zero pad the remaining columns to 8 + // + int8_t* d = D; + + vst1q_s8(d, vmovq_n_s8(0)); + vst1q_s8(&d[16], vmovq_n_s8(0)); + vst1q_s8(&d[32], vmovq_n_s8(0)); + vst1q_s8(&d[48], vmovq_n_s8(0)); + + while (k > 0) { + d[0] = *a0++; + d[8] = *a1++; + d[16] = *a2++; + d[24] = *a3++; + d[32] = *a4++; + d[40] = *a5++; + d[48] = *a6++; + d[56] = *a7++; + d += 1; + k -= 1; + } + d = D; + int64x1_t v0 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v1 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v2 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v3 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v4 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v5 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v6 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v7 = *reinterpret_cast(d); + d = d + 8; + + int64x2_t z01 = vcombine_s64(v0, v1); + int64x2_t z23 = vcombine_s64(v2, v3); + int64x2_t z45 
= vcombine_s64(v4, v5); + int64x2_t z67 = vcombine_s64(v6, v7); + + int32x4_t RowSums0L_pada = vmovq_n_s32(0); + RowSums0L_pada = vpadalq_s16(RowSums0L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z01))); + + int32x4_t RowSums0L_ext = vextq_s32(RowSums0L_pada, RowSums0L_pada, 1); + int32x4_t RowSums0L_add = vaddq_s32(RowSums0L_pada, RowSums0L_ext); + int32x2_t RowSums0L = {vdups_laneq_s32(RowSums0L_add, 0), + vdups_laneq_s32(RowSums0L_add, 2)}; + + int32x4_t RowSums0H_pada = vmovq_n_s32(0); + RowSums0H_pada = vpadalq_s16(RowSums0H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z23))); + + int32x4_t RowSums0H_ext = vextq_s32(RowSums0H_pada, RowSums0H_pada, 1); + int32x4_t RowSums0H_add = vaddq_s32(RowSums0H_pada, RowSums0H_ext); + int32x2_t RowSums0H = {vdups_laneq_s32(RowSums0H_add, 0), + vdups_laneq_s32(RowSums0H_add, 2)}; + + RowSums0 = vaddq_s32(RowSums0, vcombine_s32(RowSums0L, RowSums0H)); + + int32x4_t RowSums1L_pada = vmovq_n_s32(0); + RowSums1L_pada = vpadalq_s16(RowSums1L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z45))); + + int32x4_t RowSums1L_ext = vextq_s32(RowSums1L_pada, RowSums1L_pada, 1); + int32x4_t RowSums1L_add = vaddq_s32(RowSums1L_pada, RowSums1L_ext); + int32x2_t RowSums1L = {vdups_laneq_s32(RowSums1L_add, 0), + vdups_laneq_s32(RowSums1L_add, 2)}; + + int32x4_t RowSums1H_pada = vmovq_n_s32(0); + RowSums1H_pada = vpadalq_s16(RowSums1H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z67))); + + int32x4_t RowSums1H_ext = vextq_s32(RowSums1H_pada, RowSums1H_pada, 1); + int32x4_t RowSums1H_add = vaddq_s32(RowSums1H_pada, RowSums1H_ext); + int32x2_t RowSums1H = {vdups_laneq_s32(RowSums1H_add, 0), + vdups_laneq_s32(RowSums1H_add, 2)}; + + RowSums1 = vaddq_s32(RowSums1, vcombine_s32(RowSums1L, RowSums1H)); + + D += 64; + } + + vst1q_s32(RowSumBuffer, RowSums0); + vst1q_s32(&RowSumBuffer[4], RowSums1); + + RowSumBuffer += 8; + + A = A + lda * 8; + CountM -= 8; + } + + // + // Process four rows of matrix A. + // + // The buffer is packed as a series of 32 byte vectors where four rows are + // interleaved with the following pattern: + // + // [ A0 A1 A2 A3 A4 A5 A6 A7 ] + // [ B0 B1 B2 B3 B4 B5 B6 B7 ] + // [ C0 C1 C2 C3 C4 C5 C6 C7 ] + // [ D0 D1 D2 D3 D4 D5 D6 D7 ] + // + // This pattern is repeated (CountK / 8) times. + // + // If CountK is not aligned to a multiple of eight, then the vector is padded + // with zeroes. 
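Reviewer note: the RowSumBuffer values produced here (and the ColumnSumBuffer values produced by the B packer) feed the usual zero-point correction. As a reminder of the identity, a scalar sketch follows; the function name and the single-row, single-column form are illustrative assumptions, not the MLAS code path.

// sum_k (a[k] - zpA) * (b[k] - zpB)
//   == sum_k a[k] * b[k]  -  zpB * RowSum  -  zpA * ColSum  +  K * zpA * zpB
#include <cstddef>
#include <cstdint>

int32_t QuantizedDotReference(const int8_t* a, const int8_t* b, size_t K,
                              int32_t zpA, int32_t zpB)
{
    int32_t acc = 0;
    int32_t rowSum = 0;   // analogue of one RowSumBuffer entry
    int32_t colSum = 0;   // analogue of one ColumnSumBuffer entry
    for (size_t k = 0; k < K; ++k) {
        acc += int32_t(a[k]) * int32_t(b[k]);
        rowSum += a[k];
        colSum += b[k];
    }
    return acc - zpB * rowSum - zpA * colSum + int32_t(K) * zpA * zpB;
}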
+ // + + if (CountM >= 4) { + const int8_t* a0 = reinterpret_cast(A); + const int8_t* a1 = a0 + lda; + const int8_t* a2 = a1 + lda; + const int8_t* a3 = a2 + lda; + + size_t k = CountK; + int32x4_t RowSums = vmovq_n_s32(0); + + while (k >= 16) { + int64x2_t v0 = vld1q_s64(reinterpret_cast(a0)); + a0 += 16; + int64x2_t v1 = vld1q_s64(reinterpret_cast(a1)); + a1 += 16; + int64x2_t v2 = vld1q_s64(reinterpret_cast(a2)); + a2 += 16; + int64x2_t v3 = vld1q_s64(reinterpret_cast(a3)); + a3 += 16; + + int64x2_t z0 = vzip1q_s64(v0, v1); + int64x2_t z1 = vzip2q_s64(v0, v1); + int64x2_t z2 = vzip1q_s64(v2, v3); + int64x2_t z3 = vzip2q_s64(v2, v3); + + vst1q_s8(&D[0], vreinterpretq_s8_s64(z0)); + vst1q_s8(&D[16], vreinterpretq_s8_s64(z2)); + vst1q_s8(&D[32], vreinterpretq_s8_s64(z1)); + vst1q_s8(&D[48], vreinterpretq_s8_s64(z3)); + + int32x4_t RowSumsL_pada = vmovq_n_s32(0); + RowSumsL_pada = vpadalq_s16(RowSumsL_pada, vpaddlq_s8(vreinterpretq_s8_s64(z0))); + RowSumsL_pada = vpadalq_s16(RowSumsL_pada, vpaddlq_s8(vreinterpretq_s8_s64(z1))); + + int32x4_t RowSumsL_ext = vextq_s32(RowSumsL_pada, RowSumsL_pada, 1); + int32x4_t RowSumsL_add = vaddq_s32(RowSumsL_pada, RowSumsL_ext); + int32x2_t RowSumsL = {vdups_laneq_s32(RowSumsL_add, 0), + vdups_laneq_s32(RowSumsL_add, 2)}; + + int32x4_t RowSumsH_pada = vmovq_n_s32(0); + RowSumsH_pada = vpadalq_s16(RowSumsH_pada, vpaddlq_s8(vreinterpretq_s8_s64(z2))); + RowSumsH_pada = vpadalq_s16(RowSumsH_pada, vpaddlq_s8(vreinterpretq_s8_s64(z3))); + + int32x4_t RowSumsH_ext = vextq_s32(RowSumsH_pada, RowSumsH_pada, 1); + int32x4_t RowSumsH_add = vaddq_s32(RowSumsH_pada, RowSumsH_ext); + int32x2_t RowSumsH = {vdups_laneq_s32(RowSumsH_add, 0), + vdups_laneq_s32(RowSumsH_add, 2)}; + + RowSums = vaddq_s32(RowSums, vcombine_s32(RowSumsL, RowSumsH)); + + D += 64; + k -= 16; + } + + while (k >= 8) { + int64x1_t v0 = *reinterpret_cast(a0); + a0 += 8; + int64x1_t v1 = *reinterpret_cast(a1); + a1 += 8; + int64x1_t v2 = *reinterpret_cast(a2); + a2 += 8; + int64x1_t v3 = *reinterpret_cast(a3); + a3 += 8; + + *reinterpret_cast(&D[0]) = v0; + *reinterpret_cast(&D[8]) = v1; + *reinterpret_cast(&D[16]) = v2; + *reinterpret_cast(&D[24]) = v3; + + int64x2_t z01 = vcombine_s64(v0, v1); + int64x2_t z23 = vcombine_s64(v2, v3); + + int32x4_t RowSumsL_pada = vmovq_n_s32(0); + RowSumsL_pada = vpadalq_s16(RowSumsL_pada, vpaddlq_s8(vreinterpretq_s8_s64(z01))); + + int32x4_t RowSumsL_ext = vextq_s32(RowSumsL_pada, RowSumsL_pada, 1); + int32x4_t RowSumsL_add = vaddq_s32(RowSumsL_pada, RowSumsL_ext); + int32x2_t RowSumsL = {vdups_laneq_s32(RowSumsL_add, 0), + vdups_laneq_s32(RowSumsL_add, 2)}; + + int32x4_t RowSumsH_pada = vmovq_n_s32(0); + RowSumsH_pada = vpadalq_s16(RowSumsH_pada, vpaddlq_s8(vreinterpretq_s8_s64(z23))); + + int32x4_t RowSumsH_ext = vextq_s32(RowSumsH_pada, RowSumsH_pada, 1); + int32x4_t RowSumsH_add = vaddq_s32(RowSumsH_pada, RowSumsH_ext); + int32x2_t RowSumsH = {vdups_laneq_s32(RowSumsH_add, 0), + vdups_laneq_s32(RowSumsH_add, 2)}; + + RowSums = vaddq_s32(RowSums, vcombine_s32(RowSumsL, RowSumsH)); + + D += 32; + k -= 8; + } + + if (k > 0) { + // + // Copy the remaining bytes with zero padding. 
+ // + int8_t* d = D; + + vst1q_s8(d, vmovq_n_s8(0)); + vst1q_s8(&d[16], vmovq_n_s8(0)); + + while (k > 0) { + d[0] = *a0++; + d[8] = *a1++; + d[16] = *a2++; + d[24] = *a3++; + d += 1; + k -= 1; + } + + d = D; + int64x1_t v0 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v1 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v2 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v3 = *reinterpret_cast(d); + d = d + 8; + + int64x2_t z01 = vcombine_s64(v0, v1); + int64x2_t z23 = vcombine_s64(v2, v3); + + int32x4_t RowSums0L_pada = vmovq_n_s32(0); + RowSums0L_pada = vpadalq_s16(RowSums0L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z01))); + + int32x4_t RowSums0L_ext = vextq_s32(RowSums0L_pada, RowSums0L_pada, 1); + int32x4_t RowSums0L_add = vaddq_s32(RowSums0L_pada, RowSums0L_ext); + int32x2_t RowSums0L = {vdups_laneq_s32(RowSums0L_add, 0), + vdups_laneq_s32(RowSums0L_add, 2)}; + + int32x4_t RowSums0H_pada = vmovq_n_s32(0); + RowSums0H_pada = vpadalq_s16(RowSums0H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z23))); + + int32x4_t RowSums0H_ext = vextq_s32(RowSums0H_pada, RowSums0H_pada, 1); + int32x4_t RowSums0H_add = vaddq_s32(RowSums0H_pada, RowSums0H_ext); + int32x2_t RowSums0H = {vdups_laneq_s32(RowSums0H_add, 0), + vdups_laneq_s32(RowSums0H_add, 2)}; + + RowSums = vaddq_s32(RowSums, vcombine_s32(RowSums0L, RowSums0H)); + + D += 32; + } + + vst1q_s32(RowSumBuffer, RowSums); + RowSumBuffer += 4; + + A = A + lda * 4; + CountM -= 4; + } + + // + // Process two rows of matrix A. + // + // The buffer is packed as a series of 16 byte vectors where two rows are + // interleaved with the following pattern: + // + // [ A0 A1 A2 A3 A4 A5 A6 A7 ] + // [ B0 B1 B2 B3 B4 B5 B6 B7 ] + // + // This pattern is repeated (CountK / 8) times. + // + // If CountK is not aligned to a multiple of eight, then the vector is padded + // with zeroes. 
+ // + + if (CountM >= 2) { + const int8_t* a0 = reinterpret_cast(A); + const int8_t* a1 = a0 + lda; + + size_t k = CountK; + int32x2_t RowSums = vmov_n_s32(0); + + while (k >= 16) { + int64x2_t v0 = vld1q_s64(reinterpret_cast(a0)); + a0 += 16; + int64x2_t v1 = vld1q_s64(reinterpret_cast(a1)); + a1 += 16; + + int64x2_t z0 = vzip1q_s64(v0, v1); + int64x2_t z1 = vzip2q_s64(v0, v1); + + vst1q_s8(&D[0], vreinterpretq_s8_s64(z0)); + vst1q_s8(&D[16], vreinterpretq_s8_s64(z1)); + + int32x4_t RowSumsL_pada = vmovq_n_s32(0); + RowSumsL_pada = vpadalq_s16(RowSumsL_pada, vpaddlq_s8(vreinterpretq_s8_s64(z0))); + RowSumsL_pada = vpadalq_s16(RowSumsL_pada, vpaddlq_s8(vreinterpretq_s8_s64(z1))); + + int32x4_t RowSumsL_ext = vextq_s32(RowSumsL_pada, RowSumsL_pada, 1); + int32x4_t RowSumsL_add = vaddq_s32(RowSumsL_pada, RowSumsL_ext); + int32x2_t RowSumsL = {vdups_laneq_s32(RowSumsL_add, 0), + vdups_laneq_s32(RowSumsL_add, 2)}; + + RowSums = vadd_s32(RowSums, RowSumsL); + + D += 32; + k -= 16; + } + + while (k >= 8) { + int64x1_t v0 = *reinterpret_cast(a0); + a0 += 8; + int64x1_t v1 = *reinterpret_cast(a1); + a1 += 8; + + *reinterpret_cast(&D[0]) = v0; + *reinterpret_cast(&D[8]) = v1; + + int64x2_t z01 = vcombine_s64(v0, v1); + int32x4_t RowSumsL_pada = vmovq_n_s32(0); + RowSumsL_pada = vpadalq_s16(RowSumsL_pada, vpaddlq_s8(vreinterpretq_s8_s64(z01))); + + int32x4_t RowSumsL_ext = vextq_s32(RowSumsL_pada, RowSumsL_pada, 1); + int32x4_t RowSumsL_add = vaddq_s32(RowSumsL_pada, RowSumsL_ext); + int32x2_t RowSumsL = {vdups_laneq_s32(RowSumsL_add, 0), + vdups_laneq_s32(RowSumsL_add, 2)}; + + RowSums = vadd_s32(RowSums, RowSumsL); + + D += 16; + k -= 8; + } + + if (k > 0) { + // + // Zero pad the remaining elements to make 8 columns. + // + + int8_t* d = PaddedMatrixAData; + vst1q_s8(PaddedMatrixAData, vmovq_n_s8(0)); + + while (k > 0) { + d[0] = *a0++; + d[8] = *a1++; + + d += 1; + k -= 1; + } + + d = PaddedMatrixAData; + int64x1_t v0 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v1 = *reinterpret_cast(d); + d = d + 8; + + int64x2_t z01 = vcombine_s64(v0, v1); + int32x4_t RowSumsL_pada = vmovq_n_s32(0); + RowSumsL_pada = vpadalq_s16(RowSumsL_pada, vpaddlq_s8(vreinterpretq_s8_s64(z01))); + + int32x4_t RowSumsL_ext = vextq_s32(RowSumsL_pada, RowSumsL_pada, 1); + int32x4_t RowSumsL_add = vaddq_s32(RowSumsL_pada, RowSumsL_ext); + int32x2_t RowSumsL = {vdups_laneq_s32(RowSumsL_add, 0), + vdups_laneq_s32(RowSumsL_add, 2)}; + + RowSums = vadd_s32(RowSums, RowSumsL); + + int8x16_t PackedVector = vld1q_s8(PaddedMatrixAData); + vst1q_s8(D, PackedVector); + + D += 16; + } + + vst1_s32(RowSumBuffer, RowSums); + RowSumBuffer += 2; + + A = A + lda * 2; + CountM -= 2; + } + + // + // Process one row of matrix A. + // + // The buffer is packed as a series of 8 byte with the following pattern: + // + // [ A0 A1 A2 A3 A4 A5 A6 A7 ] + // + // This pattern is repeated (CountK / 8) times. + // + // If CountK is not aligned to a multiple of 8, then the vector is padded + // with zeroes. + // + + if (CountM > 0) { + // No need to pad the rows to 2, the .S takes care of zero pdding + const int8_t* a = reinterpret_cast(A); + size_t k = CountK; + int32x4_t RowSums = vmovq_n_s32(0); + + while (k >= 16) { + int8x16_t v = vld1q_s8(a); + a += 16; + + vst1q_s8(D, v); + + RowSums = vpadalq_s16(RowSums, vpaddlq_s8(v)); + + D += 16; + k -= 16; + } + + if (k > 0) { + // + // Copy the remaining bytes to the zero padded stack buffer. 
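Reviewer note: the one-, two-, four- and eight-row paths all handle the K tail the same way: copy it into a zero-initialized stack buffer first, then issue full-width vector loads against that buffer. A minimal sketch of the staging, with assumed names:

#include <cstddef>
#include <cstdint>
#include <cstring>

// Illustrative only: stage a k < 16 tail into a zeroed buffer so a full
// 16-byte vector load afterwards sees well-defined zero padding.
void CopyTailPadded(int8_t (&padded)[16], const int8_t* src, size_t k)
{
    std::memset(padded, 0, sizeof(padded));
    std::memcpy(padded, src, k);
}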
+ // + + vst1q_s8(PaddedMatrixAData, vmovq_n_s8(0)); + + for (size_t kk = 0; kk < k; kk++) { + PaddedMatrixAData[kk] = a[kk]; + } + + int8x16_t v = vld1q_s8(PaddedMatrixAData); + vst1q_s8(D, v); + + RowSums = vpadalq_s16(RowSums, vpaddlq_s8(v)); + } + + *RowSumBuffer = int32_t(vaddvq_s32(RowSums)); + } +} + +MLAS_FORCEINLINE +void +MlasGemmS8S8CopyPackBProcessSmmla(int8_t* D, int8x8_t BytesRow[8], int32x4_t ColumnSums[2]) +{ + int8x16_t v02 = vcombine_s8(BytesRow[0], BytesRow[2]); + int8x16_t v13 = vcombine_s8(BytesRow[1], BytesRow[3]); + + int8x16_t v46 = vcombine_s8(BytesRow[4], BytesRow[6]); + int8x16_t v57 = vcombine_s8(BytesRow[5], BytesRow[7]); + + int8x16x2_t zw1 = vzipq_s8(v02, v13); + int16x8x2_t zd1 = vzipq_s16(vreinterpretq_s16_s8(zw1.val[0]), vreinterpretq_s16_s8(zw1.val[1])); + + int8x16x2_t zw2 = vzipq_s8(v46, v57); + int16x8x2_t zd2 = vzipq_s16(vreinterpretq_s16_s8(zw2.val[0]), vreinterpretq_s16_s8(zw2.val[1])); + + int32x4x2_t zd3 = + vzipq_s32(vreinterpretq_s32_s16(zd1.val[0]), vreinterpretq_s32_s16(zd2.val[0])); + int32x4x2_t zd4 = + vzipq_s32(vreinterpretq_s32_s16(zd1.val[1]), vreinterpretq_s32_s16(zd2.val[1])); + + vst1q_s8(&D[0], vreinterpretq_s8_s32(zd3.val[0])); + vst1q_s8(&D[16], vreinterpretq_s8_s32(zd3.val[1])); + vst1q_s8(&D[32], vreinterpretq_s8_s32(zd4.val[0])); + vst1q_s8(&D[48], vreinterpretq_s8_s32(zd4.val[1])); + + int32x4_t ColSums0L_pada = vmovq_n_s32(0); + ColSums0L_pada = vpadalq_s16(ColSums0L_pada, vpaddlq_s8(vreinterpretq_s8_s32(zd3.val[0]))); + int32x4_t ColSums0L_ext = vextq_s32(ColSums0L_pada, ColSums0L_pada, 1); + int32x4_t ColSums0L_add = vaddq_s32(ColSums0L_pada, ColSums0L_ext); + int32x2_t ColSums0L = {vdups_laneq_s32(ColSums0L_add, 0), vdups_laneq_s32(ColSums0L_add, 2)}; + + int32x4_t ColSums0H_pada = vmovq_n_s32(0); + ColSums0H_pada = vpadalq_s16(ColSums0H_pada, vpaddlq_s8(vreinterpretq_s8_s32(zd3.val[1]))); + int32x4_t ColSums0H_ext = vextq_s32(ColSums0H_pada, ColSums0H_pada, 1); + int32x4_t ColSums0H_add = vaddq_s32(ColSums0H_pada, ColSums0H_ext); + int32x2_t ColSums0H = {vdups_laneq_s32(ColSums0H_add, 0), vdups_laneq_s32(ColSums0H_add, 2)}; + + ColumnSums[0] = vaddq_s32(ColumnSums[0], vcombine_s32(ColSums0L, ColSums0H)); + + int32x4_t ColSums1L_pada = vmovq_n_s32(0); + ColSums1L_pada = vpadalq_s16(ColSums1L_pada, vpaddlq_s8(vreinterpretq_s8_s32(zd4.val[0]))); + int32x4_t ColSums1L_ext = vextq_s32(ColSums1L_pada, ColSums1L_pada, 1); + int32x4_t ColSums1L_add = vaddq_s32(ColSums1L_pada, ColSums1L_ext); + int32x2_t ColSums1L = {vdups_laneq_s32(ColSums1L_add, 0), vdups_laneq_s32(ColSums1L_add, 2)}; + + int32x4_t ColSums1H_pada = vmovq_n_s32(0); + ColSums1H_pada = vpadalq_s16(ColSums1H_pada, vpaddlq_s8(vreinterpretq_s8_s32(zd4.val[1]))); + int32x4_t ColSums1H_ext = vextq_s32(ColSums1H_pada, ColSums1H_pada, 1); + int32x4_t ColSums1H_add = vaddq_s32(ColSums1H_pada, ColSums1H_ext); + int32x2_t ColSums1H = {vdups_laneq_s32(ColSums1H_add, 0), vdups_laneq_s32(ColSums1H_add, 2)}; + + ColumnSums[1] = vaddq_s32(ColumnSums[1], vcombine_s32(ColSums1L, ColSums1H)); +} + +template <> +void +MlasGemmQuantCopyPackB(MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedBType* Dst, + const uint8_t* B, + size_t ldb, + size_t CountN, + size_t CountK, + int32_t* ColumnSumBuffer, + bool BIsSigned) +{ + MLAS_UNREFERENCED_PARAMETER(BIsSigned); + int8_t* D = reinterpret_cast(Dst); + const int8x16_t ZeroVector = vmovq_n_s8(0); + int8x8_t BytesRow[8]; + + // + // Copy data from matrix B into the destination buffer 8x2 blocks at a + // time. 
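Reviewer note: independent of the exact byte order produced by the zip sequence above, the B packer must also report per-column sums over the full K extent, and zero padding the short K tail is what keeps those sums (and the products) unaffected. A scalar reference for that invariant; the function name is an assumption:

#include <cstddef>
#include <cstdint>

void ColumnSumsReference(const int8_t* B, size_t ldb, size_t CountN, size_t CountK,
                         int32_t* ColumnSumBuffer)
{
    // B is row major with leading dimension ldb; padded rows contribute zero.
    for (size_t n = 0; n < CountN; ++n) {
        int32_t sum = 0;
        for (size_t k = 0; k < CountK; ++k) {
            sum += B[k * ldb + n];
        }
        ColumnSumBuffer[n] = sum;
    }
}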
+ // + // + while (CountN >= 8) { + const int8_t* b = reinterpret_cast(B); + size_t k = CountK; + int32x4_t ColumnSums[2]; + + ColumnSums[0] = vmovq_n_s32(0); + ColumnSums[1] = vmovq_n_s32(0); + + while (k >= 8) { + BytesRow[0] = vld1_s8(&b[ldb * 0]); + BytesRow[1] = vld1_s8(&b[ldb * 1]); + BytesRow[2] = vld1_s8(&b[ldb * 2]); + BytesRow[3] = vld1_s8(&b[ldb * 3]); + BytesRow[4] = vld1_s8(&b[ldb * 4]); + BytesRow[5] = vld1_s8(&b[ldb * 5]); + BytesRow[6] = vld1_s8(&b[ldb * 6]); + BytesRow[7] = vld1_s8(&b[ldb * 7]); + + MlasGemmS8S8CopyPackBProcessSmmla(D, BytesRow, ColumnSums); + + D += 64; + b += ldb * 8; + k -= 8; + } + + if (k > 0) { + // Pad k to 8 + + BytesRow[0] = vld1_s8(&b[ldb * 0]); + BytesRow[1] = (k >= 2) ? vld1_s8(&b[ldb * 1]) : vget_low_s8(ZeroVector); + BytesRow[2] = (k >= 3) ? vld1_s8(&b[ldb * 2]) : vget_low_s8(ZeroVector); + BytesRow[3] = (k >= 4) ? vld1_s8(&b[ldb * 3]) : vget_low_s8(ZeroVector); + BytesRow[4] = (k >= 5) ? vld1_s8(&b[ldb * 4]) : vget_low_s8(ZeroVector); + BytesRow[5] = (k >= 6) ? vld1_s8(&b[ldb * 5]) : vget_low_s8(ZeroVector); + BytesRow[6] = (k >= 7) ? vld1_s8(&b[ldb * 6]) : vget_low_s8(ZeroVector); + BytesRow[7] = vget_low_s8(ZeroVector); + + MlasGemmS8S8CopyPackBProcessSmmla(D, BytesRow, ColumnSums); + + D += 64; + } + + // Zero pad the output buffer to a multiple of PackedK if the above + // processed an odd number of four row bundles. + // + vst1q_s32(&ColumnSumBuffer[0], ColumnSums[0]); + vst1q_s32(&ColumnSumBuffer[4], ColumnSums[1]); + + ColumnSumBuffer += 8; + + B += 8; + CountN -= 8; + } + + // + // Process the remaining columns of matrix B. + // + + if (CountN > 0) { + const int8_t* b = reinterpret_cast(B); + size_t k = CountK; + int8_t PaddedMatrixBData[64]; + int32x4_t ColumnSums[2]; + + vst1q_s8(&PaddedMatrixBData[0], ZeroVector); + vst1q_s8(&PaddedMatrixBData[16], ZeroVector); + vst1q_s8(&PaddedMatrixBData[32], ZeroVector); + vst1q_s8(&PaddedMatrixBData[48], ZeroVector); + + ColumnSums[0] = vmovq_n_s32(0); + ColumnSums[1] = vmovq_n_s32(0); + + // + // Interleave rows of matrix B using an intermediate zero padded stack + // buffer and write to the packed buffer. + // + + while (k > 0) { + const int8_t* bcopy0 = &b[ldb * 0]; + const int8_t* bcopy1 = &b[ldb * 1]; + const int8_t* bcopy2 = &b[ldb * 2]; + const int8_t* bcopy3 = &b[ldb * 3]; + const int8_t* bcopy4 = &b[ldb * 4]; + const int8_t* bcopy5 = &b[ldb * 5]; + const int8_t* bcopy6 = &b[ldb * 6]; + const int8_t* bcopy7 = &b[ldb * 7]; + + if (k >= 8) { + b += ldb * 8; + k -= 8; + + } else { + vst1q_s8(&PaddedMatrixBData[0], ZeroVector); + vst1q_s8(&PaddedMatrixBData[16], ZeroVector); + vst1q_s8(&PaddedMatrixBData[32], ZeroVector); + vst1q_s8(&PaddedMatrixBData[48], ZeroVector); + + bcopy1 = (k >= 2) ? bcopy1 : &PaddedMatrixBData[56]; + bcopy2 = (k >= 3) ? bcopy2 : &PaddedMatrixBData[56]; + bcopy3 = (k >= 4) ? bcopy3 : &PaddedMatrixBData[56]; + bcopy4 = (k >= 5) ? bcopy4 : &PaddedMatrixBData[56]; + bcopy5 = (k >= 6) ? bcopy5 : &PaddedMatrixBData[56]; + bcopy6 = (k >= 7) ? 
bcopy6 : &PaddedMatrixBData[56]; + bcopy7 = &PaddedMatrixBData[56]; + + k = 0; + } + + int8_t* padded = PaddedMatrixBData; + int8_t* padded_end = padded + CountN; + do { + padded[0] = *bcopy0++; + padded[8] = *bcopy1++; + padded[16] = *bcopy2++; + padded[24] = *bcopy3++; + padded[32] = *bcopy4++; + padded[40] = *bcopy5++; + padded[48] = *bcopy6++; + padded[56] = *bcopy7++; + + } while (++padded < padded_end); + + BytesRow[0] = vld1_s8(&PaddedMatrixBData[0]); + BytesRow[1] = vld1_s8(&PaddedMatrixBData[8]); + BytesRow[2] = vld1_s8(&PaddedMatrixBData[16]); + BytesRow[3] = vld1_s8(&PaddedMatrixBData[24]); + BytesRow[4] = vld1_s8(&PaddedMatrixBData[32]); + BytesRow[5] = vld1_s8(&PaddedMatrixBData[40]); + BytesRow[6] = vld1_s8(&PaddedMatrixBData[48]); + BytesRow[7] = vld1_s8(&PaddedMatrixBData[56]); + + MlasGemmS8S8CopyPackBProcessSmmla(D, BytesRow, ColumnSums); + + D += 64; + } + + vst1q_s32(&ColumnSumBuffer[0], ColumnSums[0]); + vst1q_s32(&ColumnSumBuffer[4], ColumnSums[1]); + } +} + +template <> +MLAS_FORCEINLINE size_t +MlasGemmQuantKernel(const MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedAType* A, + const MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedBType* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumBuffer, + const int32_t* ColumnSumBuffer, + const int32_t* ZeroPointB, + bool ZeroMode) +{ + size_t RowsHandled; + + if (ZeroMode) { + RowsHandled = MlasGemmS8S8KernelSmmlaZero(A, B, C, PackedCountK, CountM, CountN, ldc, + RowSumBuffer, ColumnSumBuffer, ZeroPointB); + } else { + RowsHandled = MlasGemmS8S8KernelSmmlaAdd(A, B, C, PackedCountK, CountM, CountN, ldc, + RowSumBuffer, ColumnSumBuffer, ZeroPointB); + } + + return RowsHandled; +} + +const MLAS_GEMM_QUANT_DISPATCH MlasGemmS8S8DispatchSmmla = { + MlasGemmQuantOperation, + MlasGemmQuantPackedOperation, + MlasGemmQuantCopyPackB, + MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedK, + MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedStrides.K, + 8}; diff --git a/onnxruntime/core/mlas/lib/qgemm_kernel_ummla.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_ummla.cpp new file mode 100644 index 0000000000000..3936154432ac7 --- /dev/null +++ b/onnxruntime/core/mlas/lib/qgemm_kernel_ummla.cpp @@ -0,0 +1,967 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the MIT License. + +Module Name: + + qgemm_kernel_ummla.cpp + +Abstract: + + This module implements ummla QGEMM kernel. + +--*/ + +#include "mlasi.h" +#include "qgemm.h" + +// +// Define the prototypes of the NEON UMMLA routines written in assembly. 
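Reviewer note: as in the smmla file above, the Zero/Add kernel pair appears to differ only in how C is written: with ZeroMode the output tile is overwritten, otherwise the kernel accumulates into the existing C contents. A scalar sketch of that distinction (an inference from the kernel names and the dispatch, stated here as an assumption):

#include <cstddef>
#include <cstdint>

void StoreTileReference(const int32_t* partial, int32_t* C, size_t n, bool ZeroMode)
{
    for (size_t i = 0; i < n; ++i) {
        C[i] = ZeroMode ? partial[i] : C[i] + partial[i];
    }
}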
+// + +extern "C" { + +size_t MLASCALL +MlasGemmU8X8KernelUmmlaZero(const uint8_t* A, + const uint8_t* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumVector, + const int32_t* ColumnSumVector, + const int32_t* ZeroPointB); + +size_t MLASCALL +MlasGemmU8X8KernelUmmlaAdd(const uint8_t* A, + const uint8_t* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumVector, + const int32_t* ColumnSumVector, + const int32_t* ZeroPointB); +} + +struct MLAS_GEMM_U8X8_KERNEL_UMMLA { + typedef uint8_t PackedAType; + typedef uint8_t PackedBType; + typedef uint8_t OffsetAType; + typedef uint8_t OffsetBType; + + static constexpr size_t PackedK = 8; + static constexpr MLAS_GEMM_QUANT_STRIDES Strides{24, 128, 256}; + static constexpr MLAS_GEMM_QUANT_STRIDES PackedStrides{24, 128, 384}; +}; + +constexpr size_t MLAS_GEMM_U8X8_KERNEL_UMMLA::PackedK; +constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_U8X8_KERNEL_UMMLA::Strides; +constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_U8X8_KERNEL_UMMLA::PackedStrides; + +template <> +MLAS_FORCEINLINE int32_t +MlasGemmQuantFixupZeroPointB(int32_t ZeroPointB, bool BIsSigned) +{ + if (BIsSigned) { + ZeroPointB = MLAS_GEMM_U8X8_KERNEL_UMMLA::OffsetBType(ZeroPointB ^ 0x80); + } + + return ZeroPointB; +} + +template <> +void +MlasGemmQuantCopyPackA(MLAS_GEMM_U8X8_KERNEL_UMMLA::PackedAType* D, + const uint8_t* A, + size_t lda, + size_t CountM, + size_t CountK, + int32_t* RowSumBuffer, + bool AIsSigned) +{ + MLAS_UNREFERENCED_PARAMETER(AIsSigned); + uint8_t PaddedMatrixAData[64]; + + // + // Process 8 rows of matrix A. + // + // MMLA kernels load 8x8 block of A with four vector registers. So A is packed + // a series of 64 byte vectors where eight rows are interleaved with the + // following pattern: + // + // [ A0 A1 A2 A3 A4 A5 A6 A7 ] + // [ B0 B1 B2 B3 B4 B5 B6 B7 ] + // [ C0 C1 C2 C3 C4 C5 C6 C7 ] + // [ D0 D1 D2 D3 D4 D5 D6 D7 ] + // [ E0 E1 E2 E3 E4 E5 E6 E7 ] + // [ F0 F1 F2 F3 F4 F5 F6 F7 ] + // [ G0 G1 G2 G3 G4 G5 G6 G7 ] + // [ H0 H1 H2 H3 H4 H5 H6 H7 ] + // + // ... + // + // This pattern is repeated (CountK / 8) times. + // + // If CountK is not aligned to a multiple of eight, then the vector is padded + // with zeroes. 
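Reviewer note: for this unsigned kernel, signed B data is handled by XOR-ing the bytes with 0x80 (the BitFlipVector used in MlasGemmU8X8CopyPackBProcessUmmla below) and applying the same flip to the zero point in MlasGemmQuantFixupZeroPointB above. XOR with 0x80 maps int8 x to uint8 x + 128, so the centered values entering the GEMM are unchanged. A small sketch of the invariant, illustrative only:

#include <cassert>
#include <cstdint>

void BitFlipInvariant(int8_t b, int8_t zp)
{
    const uint8_t bu = uint8_t(uint8_t(b) ^ 0x80);    // what the packer does to B bytes
    const uint8_t zpu = uint8_t(uint8_t(zp) ^ 0x80);  // what the zero-point fixup does
    assert(int32_t(bu) - int32_t(zpu) == int32_t(b) - int32_t(zp));
}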
+ // + + while (CountM >= 8) { + const uint8_t* a0 = A; + const uint8_t* a1 = a0 + lda; + const uint8_t* a2 = a0 + lda * 2; + const uint8_t* a3 = a0 + lda * 3; + const uint8_t* a4 = a0 + lda * 4; + const uint8_t* a5 = a0 + lda * 5; + const uint8_t* a6 = a0 + lda * 6; + const uint8_t* a7 = a0 + lda * 7; + + size_t k = CountK; + uint32x4_t RowSums0 = vmovq_n_u32(0); + uint32x4_t RowSums1 = vmovq_n_u32(0); + + while (k >= 16) { + uint64x2_t v0 = vld1q_u64(reinterpret_cast(a0)); + a0 += 16; + uint64x2_t v1 = vld1q_u64(reinterpret_cast(a1)); + a1 += 16; + uint64x2_t v2 = vld1q_u64(reinterpret_cast(a2)); + a2 += 16; + uint64x2_t v3 = vld1q_u64(reinterpret_cast(a3)); + a3 += 16; + uint64x2_t v4 = vld1q_u64(reinterpret_cast(a4)); + a4 += 16; + uint64x2_t v5 = vld1q_u64(reinterpret_cast(a5)); + a5 += 16; + uint64x2_t v6 = vld1q_u64(reinterpret_cast(a6)); + a6 += 16; + uint64x2_t v7 = vld1q_u64(reinterpret_cast(a7)); + a7 += 16; + + uint64x2_t z0 = vzip1q_u64(v0, v1); + uint64x2_t z1 = vzip2q_u64(v0, v1); + uint64x2_t z2 = vzip1q_u64(v2, v3); + uint64x2_t z3 = vzip2q_u64(v2, v3); + + uint64x2_t z4 = vzip1q_u64(v4, v5); + uint64x2_t z5 = vzip2q_u64(v4, v5); + uint64x2_t z6 = vzip1q_u64(v6, v7); + uint64x2_t z7 = vzip2q_u64(v6, v7); + + vst1q_u8(&D[0], vreinterpretq_u8_u64(z0)); + vst1q_u8(&D[16], vreinterpretq_u8_u64(z2)); + vst1q_u8(&D[32], vreinterpretq_u8_u64(z4)); + vst1q_u8(&D[48], vreinterpretq_u8_u64(z6)); + vst1q_u8(&D[64], vreinterpretq_u8_u64(z1)); + vst1q_u8(&D[80], vreinterpretq_u8_u64(z3)); + vst1q_u8(&D[96], vreinterpretq_u8_u64(z5)); + vst1q_u8(&D[112], vreinterpretq_u8_u64(z7)); + + uint32x4_t RowSums0L_pada = vmovq_n_u32(0); + RowSums0L_pada = vpadalq_u16(RowSums0L_pada, vpaddlq_u8(vreinterpretq_u8_u64(z0))); + RowSums0L_pada = vpadalq_u16(RowSums0L_pada, vpaddlq_u8(vreinterpretq_u8_u64(z1))); + + uint32x4_t RowSums0L_ext = vextq_u32(RowSums0L_pada, RowSums0L_pada, 1); + uint32x4_t RowSums0L_add = vaddq_u32(RowSums0L_pada, RowSums0L_ext); + uint32x2_t RowSums0L = {vdups_laneq_u32(RowSums0L_add, 0), + vdups_laneq_u32(RowSums0L_add, 2)}; + + uint32x4_t RowSums0H_pada = vmovq_n_u32(0); + RowSums0H_pada = vpadalq_u16(RowSums0H_pada, vpaddlq_u8(vreinterpretq_u8_u64(z2))); + RowSums0H_pada = vpadalq_u16(RowSums0H_pada, vpaddlq_u8(vreinterpretq_u8_u64(z3))); + + uint32x4_t RowSums0H_ext = vextq_u32(RowSums0H_pada, RowSums0H_pada, 1); + uint32x4_t RowSums0H_add = vaddq_u32(RowSums0H_pada, RowSums0H_ext); + uint32x2_t RowSums0H = {vdups_laneq_u32(RowSums0H_add, 0), + vdups_laneq_u32(RowSums0H_add, 2)}; + + RowSums0 = vaddq_u32(RowSums0, vcombine_u32(RowSums0L, RowSums0H)); + + uint32x4_t RowSums1L_pada = vmovq_n_u32(0); + RowSums1L_pada = vpadalq_u16(RowSums1L_pada, vpaddlq_u8(vreinterpretq_u8_u64(z4))); + RowSums1L_pada = vpadalq_u16(RowSums1L_pada, vpaddlq_u8(vreinterpretq_u8_u64(z5))); + + uint32x4_t RowSums1L_ext = vextq_u32(RowSums1L_pada, RowSums1L_pada, 1); + uint32x4_t RowSums1L_add = vaddq_u32(RowSums1L_pada, RowSums1L_ext); + uint32x2_t RowSums1L = {vdups_laneq_u32(RowSums1L_add, 0), + vdups_laneq_u32(RowSums1L_add, 2)}; + + uint32x4_t RowSums1H_pada = vmovq_n_u32(0); + RowSums1H_pada = vpadalq_u16(RowSums1H_pada, vpaddlq_u8(vreinterpretq_u8_u64(z6))); + RowSums1H_pada = vpadalq_u16(RowSums1H_pada, vpaddlq_u8(vreinterpretq_u8_u64(z7))); + + uint32x4_t RowSums1H_ext = vextq_u32(RowSums1H_pada, RowSums1H_pada, 1); + uint32x4_t RowSums1H_add = vaddq_u32(RowSums1H_pada, RowSums1H_ext); + uint32x2_t RowSums1H = {vdups_laneq_u32(RowSums1H_add, 0), + vdups_laneq_u32(RowSums1H_add, 2)}; + + 
RowSums1 = vaddq_u32(RowSums1, vcombine_u32(RowSums1L, RowSums1H)); + + D += 128; + k -= 16; + } + + while (k >= 8) { + uint64x1_t v0 = *reinterpret_cast(a0); + a0 += 8; + uint64x1_t v1 = *reinterpret_cast(a1); + a1 += 8; + uint64x1_t v2 = *reinterpret_cast(a2); + a2 += 8; + uint64x1_t v3 = *reinterpret_cast(a3); + a3 += 8; + uint64x1_t v4 = *reinterpret_cast(a4); + a4 += 8; + uint64x1_t v5 = *reinterpret_cast(a5); + a5 += 8; + uint64x1_t v6 = *reinterpret_cast(a6); + a6 += 8; + uint64x1_t v7 = *reinterpret_cast(a7); + a7 += 8; + + *reinterpret_cast(&D[0]) = v0; + *reinterpret_cast(&D[8]) = v1; + *reinterpret_cast(&D[16]) = v2; + *reinterpret_cast(&D[24]) = v3; + *reinterpret_cast(&D[32]) = v4; + *reinterpret_cast(&D[40]) = v5; + *reinterpret_cast(&D[48]) = v6; + *reinterpret_cast(&D[56]) = v7; + + uint64x2_t z01 = vcombine_u64(v0, v1); + uint64x2_t z23 = vcombine_u64(v2, v3); + uint64x2_t z45 = vcombine_u64(v4, v5); + uint64x2_t z67 = vcombine_u64(v6, v7); + + uint32x4_t RowSums0L_pada = vmovq_n_u32(0); + RowSums0L_pada = vpadalq_u16(RowSums0L_pada, vpaddlq_u8(vreinterpretq_u8_u64(z01))); + + uint32x4_t RowSums0L_ext = vextq_u32(RowSums0L_pada, RowSums0L_pada, 1); + uint32x4_t RowSums0L_add = vaddq_u32(RowSums0L_pada, RowSums0L_ext); + uint32x2_t RowSums0L = {vdups_laneq_u32(RowSums0L_add, 0), + vdups_laneq_u32(RowSums0L_add, 2)}; + + uint32x4_t RowSums0H_pada = vmovq_n_u32(0); + RowSums0H_pada = vpadalq_u16(RowSums0H_pada, vpaddlq_u8(vreinterpretq_u8_u64(z23))); + + uint32x4_t RowSums0H_ext = vextq_u32(RowSums0H_pada, RowSums0H_pada, 1); + uint32x4_t RowSums0H_add = vaddq_u32(RowSums0H_pada, RowSums0H_ext); + uint32x2_t RowSums0H = {vdups_laneq_u32(RowSums0H_add, 0), + vdups_laneq_u32(RowSums0H_add, 2)}; + + RowSums0 = vaddq_u32(RowSums0, vcombine_u32(RowSums0L, RowSums0H)); + + uint32x4_t RowSums1L_pada = vmovq_n_u32(0); + RowSums1L_pada = vpadalq_u16(RowSums1L_pada, vpaddlq_u8(vreinterpretq_u8_u64(z45))); + + uint32x4_t RowSums1L_ext = vextq_u32(RowSums1L_pada, RowSums1L_pada, 1); + uint32x4_t RowSums1L_add = vaddq_u32(RowSums1L_pada, RowSums1L_ext); + uint32x2_t RowSums1L = {vdups_laneq_u32(RowSums1L_add, 0), + vdups_laneq_u32(RowSums1L_add, 2)}; + + uint32x4_t RowSums1H_pada = vmovq_n_u32(0); + RowSums1H_pada = vpadalq_u16(RowSums1H_pada, vpaddlq_u8(vreinterpretq_u8_u64(z67))); + + uint32x4_t RowSums1H_ext = vextq_u32(RowSums1H_pada, RowSums1H_pada, 1); + uint32x4_t RowSums1H_add = vaddq_u32(RowSums1H_pada, RowSums1H_ext); + uint32x2_t RowSums1H = {vdups_laneq_u32(RowSums1H_add, 0), + vdups_laneq_u32(RowSums1H_add, 2)}; + + RowSums1 = vaddq_u32(RowSums1, vcombine_u32(RowSums1L, RowSums1H)); + + D += 64; + k -= 8; + } + + if (k > 0) { + // + // zero pad the remaining columns to 8 + // + uint8_t* d = D; + + vst1q_u8(d, vmovq_n_u8(0)); + vst1q_u8(&d[16], vmovq_n_u8(0)); + vst1q_u8(&d[32], vmovq_n_u8(0)); + vst1q_u8(&d[48], vmovq_n_u8(0)); + + while (k > 0) { + d[0] = *a0++; + d[8] = *a1++; + d[16] = *a2++; + d[24] = *a3++; + d[32] = *a4++; + d[40] = *a5++; + d[48] = *a6++; + d[56] = *a7++; + d += 1; + k -= 1; + } + d = D; + uint64x1_t v0 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v1 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v2 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v3 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v4 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v5 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v6 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v7 = *reinterpret_cast(d); + d = d + 8; + + uint64x2_t z01 = vcombine_u64(v0, v1); + 
uint64x2_t z23 = vcombine_u64(v2, v3); + uint64x2_t z45 = vcombine_u64(v4, v5); + uint64x2_t z67 = vcombine_u64(v6, v7); + + uint32x4_t RowSums0L_pada = vmovq_n_u32(0); + RowSums0L_pada = vpadalq_u16(RowSums0L_pada, vpaddlq_u8(vreinterpretq_u8_u64(z01))); + + uint32x4_t RowSums0L_ext = vextq_u32(RowSums0L_pada, RowSums0L_pada, 1); + uint32x4_t RowSums0L_add = vaddq_u32(RowSums0L_pada, RowSums0L_ext); + uint32x2_t RowSums0L = {vdups_laneq_u32(RowSums0L_add, 0), + vdups_laneq_u32(RowSums0L_add, 2)}; + + uint32x4_t RowSums0H_pada = vmovq_n_u32(0); + RowSums0H_pada = vpadalq_u16(RowSums0H_pada, vpaddlq_u8(vreinterpretq_u8_u64(z23))); + + uint32x4_t RowSums0H_ext = vextq_u32(RowSums0H_pada, RowSums0H_pada, 1); + uint32x4_t RowSums0H_add = vaddq_u32(RowSums0H_pada, RowSums0H_ext); + uint32x2_t RowSums0H = {vdups_laneq_u32(RowSums0H_add, 0), + vdups_laneq_u32(RowSums0H_add, 2)}; + + RowSums0 = vaddq_u32(RowSums0, vcombine_u32(RowSums0L, RowSums0H)); + + uint32x4_t RowSums1L_pada = vmovq_n_u32(0); + RowSums1L_pada = vpadalq_u16(RowSums1L_pada, vpaddlq_u8(vreinterpretq_u8_u64(z45))); + + uint32x4_t RowSums1L_ext = vextq_u32(RowSums1L_pada, RowSums1L_pada, 1); + uint32x4_t RowSums1L_add = vaddq_u32(RowSums1L_pada, RowSums1L_ext); + uint32x2_t RowSums1L = {vdups_laneq_u32(RowSums1L_add, 0), + vdups_laneq_u32(RowSums1L_add, 2)}; + + uint32x4_t RowSums1H_pada = vmovq_n_u32(0); + RowSums1H_pada = vpadalq_u16(RowSums1H_pada, vpaddlq_u8(vreinterpretq_u8_u64(z67))); + + uint32x4_t RowSums1H_ext = vextq_u32(RowSums1H_pada, RowSums1H_pada, 1); + uint32x4_t RowSums1H_add = vaddq_u32(RowSums1H_pada, RowSums1H_ext); + uint32x2_t RowSums1H = {vdups_laneq_u32(RowSums1H_add, 0), + vdups_laneq_u32(RowSums1H_add, 2)}; + + RowSums1 = vaddq_u32(RowSums1, vcombine_u32(RowSums1L, RowSums1H)); + + D += 64; + } + + vst1q_s32(RowSumBuffer, vreinterpretq_s32_u32(RowSums0)); + vst1q_s32(&RowSumBuffer[4], vreinterpretq_s32_u32(RowSums1)); + + RowSumBuffer += 8; + + A = A + lda * 8; + CountM -= 8; + } + + // + // Process four rows of matrix A. + // + // The buffer is packed as a series of 32 byte vectors where four rows are + // interleaved with the following pattern: + // + // [ A0 A1 A2 A3 A4 A5 A6 A7 ] + // [ B0 B1 B2 B3 B4 B5 B6 B7 ] + // [ C0 C1 C2 C3 C4 C5 C6 C7 ] + // [ D0 D1 D2 D3 D4 D5 D6 D7 ] + // + // This pattern is repeated (CountK / 8) times. + // + // If CountK is not aligned to a multiple of eight, then the vector is padded + // with zeroes. 
+ // + + if (CountM >= 4) { + const uint8_t* a0 = A; + const uint8_t* a1 = a0 + lda; + const uint8_t* a2 = a1 + lda; + const uint8_t* a3 = a2 + lda; + + size_t k = CountK; + uint32x4_t RowSums = vmovq_n_u32(0); + + while (k >= 16) { + uint64x2_t v0 = vld1q_u64(reinterpret_cast(a0)); + a0 += 16; + uint64x2_t v1 = vld1q_u64(reinterpret_cast(a1)); + a1 += 16; + uint64x2_t v2 = vld1q_u64(reinterpret_cast(a2)); + a2 += 16; + uint64x2_t v3 = vld1q_u64(reinterpret_cast(a3)); + a3 += 16; + + uint64x2_t z0 = vzip1q_u64(v0, v1); + uint64x2_t z1 = vzip2q_u64(v0, v1); + uint64x2_t z2 = vzip1q_u64(v2, v3); + uint64x2_t z3 = vzip2q_u64(v2, v3); + + vst1q_u8(&D[0], vreinterpretq_u8_u64(z0)); + vst1q_u8(&D[16], vreinterpretq_u8_u64(z2)); + vst1q_u8(&D[32], vreinterpretq_u8_u64(z1)); + vst1q_u8(&D[48], vreinterpretq_u8_u64(z3)); + + uint32x4_t RowSumsL_pada = vmovq_n_u32(0); + RowSumsL_pada = vpadalq_u16(RowSumsL_pada, vpaddlq_u8(vreinterpretq_u8_u64(z0))); + RowSumsL_pada = vpadalq_u16(RowSumsL_pada, vpaddlq_u8(vreinterpretq_u8_u64(z1))); + + uint32x4_t RowSumsL_ext = vextq_u32(RowSumsL_pada, RowSumsL_pada, 1); + uint32x4_t RowSumsL_add = vaddq_u32(RowSumsL_pada, RowSumsL_ext); + uint32x2_t RowSumsL = {vdups_laneq_u32(RowSumsL_add, 0), + vdups_laneq_u32(RowSumsL_add, 2)}; + + uint32x4_t RowSumsH_pada = vmovq_n_u32(0); + RowSumsH_pada = vpadalq_u16(RowSumsH_pada, vpaddlq_u8(vreinterpretq_u8_u64(z2))); + RowSumsH_pada = vpadalq_u16(RowSumsH_pada, vpaddlq_u8(vreinterpretq_u8_u64(z3))); + + uint32x4_t RowSumsH_ext = vextq_u32(RowSumsH_pada, RowSumsH_pada, 1); + uint32x4_t RowSumsH_add = vaddq_u32(RowSumsH_pada, RowSumsH_ext); + uint32x2_t RowSumsH = {vdups_laneq_u32(RowSumsH_add, 0), + vdups_laneq_u32(RowSumsH_add, 2)}; + + RowSums = vaddq_u32(RowSums, vcombine_u32(RowSumsL, RowSumsH)); + + D += 64; + k -= 16; + } + + while (k >= 8) { + uint64x1_t v0 = *reinterpret_cast(a0); + a0 += 8; + uint64x1_t v1 = *reinterpret_cast(a1); + a1 += 8; + uint64x1_t v2 = *reinterpret_cast(a2); + a2 += 8; + uint64x1_t v3 = *reinterpret_cast(a3); + a3 += 8; + + *reinterpret_cast(&D[0]) = v0; + *reinterpret_cast(&D[8]) = v1; + *reinterpret_cast(&D[16]) = v2; + *reinterpret_cast(&D[24]) = v3; + + uint64x2_t z01 = vcombine_u64(v0, v1); + uint64x2_t z23 = vcombine_u64(v2, v3); + + uint32x4_t RowSumsL_pada = vmovq_n_u32(0); + RowSumsL_pada = vpadalq_u16(RowSumsL_pada, vpaddlq_u8(vreinterpretq_u8_u64(z01))); + + uint32x4_t RowSumsL_ext = vextq_u32(RowSumsL_pada, RowSumsL_pada, 1); + uint32x4_t RowSumsL_add = vaddq_u32(RowSumsL_pada, RowSumsL_ext); + uint32x2_t RowSumsL = {vdups_laneq_u32(RowSumsL_add, 0), + vdups_laneq_u32(RowSumsL_add, 2)}; + + uint32x4_t RowSumsH_pada = vmovq_n_u32(0); + RowSumsH_pada = vpadalq_u16(RowSumsH_pada, vpaddlq_u8(vreinterpretq_u8_u64(z23))); + + uint32x4_t RowSumsH_ext = vextq_u32(RowSumsH_pada, RowSumsH_pada, 1); + uint32x4_t RowSumsH_add = vaddq_u32(RowSumsH_pada, RowSumsH_ext); + uint32x2_t RowSumsH = {vdups_laneq_u32(RowSumsH_add, 0), + vdups_laneq_u32(RowSumsH_add, 2)}; + + RowSums = vaddq_u32(RowSums, vcombine_u32(RowSumsL, RowSumsH)); + + D += 32; + k -= 8; + } + + if (k > 0) { + // + // Copy the remaining bytes with zero padding. 
+ // + uint8_t* d = D; + + vst1q_u8(d, vmovq_n_u8(0)); + vst1q_u8(&d[16], vmovq_n_u8(0)); + + while (k > 0) { + d[0] = *a0++; + d[8] = *a1++; + d[16] = *a2++; + d[24] = *a3++; + d += 1; + k -= 1; + } + + d = D; + uint64x1_t v0 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v1 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v2 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v3 = *reinterpret_cast(d); + d = d + 8; + + uint64x2_t z01 = vcombine_u64(v0, v1); + uint64x2_t z23 = vcombine_u64(v2, v3); + + uint32x4_t RowSums0L_pada = vmovq_n_u32(0); + RowSums0L_pada = vpadalq_u16(RowSums0L_pada, vpaddlq_u8(vreinterpretq_u8_u64(z01))); + + uint32x4_t RowSums0L_ext = vextq_u32(RowSums0L_pada, RowSums0L_pada, 1); + uint32x4_t RowSums0L_add = vaddq_u32(RowSums0L_pada, RowSums0L_ext); + uint32x2_t RowSums0L = {vdups_laneq_u32(RowSums0L_add, 0), + vdups_laneq_u32(RowSums0L_add, 2)}; + + uint32x4_t RowSums0H_pada = vmovq_n_u32(0); + RowSums0H_pada = vpadalq_u16(RowSums0H_pada, vpaddlq_u8(vreinterpretq_u8_u64(z23))); + + uint32x4_t RowSums0H_ext = vextq_u32(RowSums0H_pada, RowSums0H_pada, 1); + uint32x4_t RowSums0H_add = vaddq_u32(RowSums0H_pada, RowSums0H_ext); + uint32x2_t RowSums0H = {vdups_laneq_u32(RowSums0H_add, 0), + vdups_laneq_u32(RowSums0H_add, 2)}; + + RowSums = vaddq_u32(RowSums, vcombine_u32(RowSums0L, RowSums0H)); + + D += 32; + } + + vst1q_s32(RowSumBuffer, vreinterpretq_s32_u32(RowSums)); + RowSumBuffer += 4; + + A = A + lda * 4; + CountM -= 4; + } + + // + // Process two rows of matrix A. + // + // The buffer is packed as a series of 16 byte vectors where two rows are + // interleaved with the following pattern: + // + // [ A0 A1 A2 A3 A4 A5 A6 A7 ] + // [ B0 B1 B2 B3 B4 B5 B6 B7 ] + // + // This pattern is repeated (CountK / 8) times. + // + // If CountK is not aligned to a multiple of eight, then the vector is padded + // with zeroes. 
+ // + + if (CountM >= 2) { + const uint8_t* a0 = A; + const uint8_t* a1 = a0 + lda; + + size_t k = CountK; + uint32x2_t RowSums = vmov_n_u32(0); + + while (k >= 16) { + uint64x2_t v0 = vld1q_u64(reinterpret_cast(a0)); + a0 += 16; + uint64x2_t v1 = vld1q_u64(reinterpret_cast(a1)); + a1 += 16; + + uint64x2_t z0 = vzip1q_u64(v0, v1); + uint64x2_t z1 = vzip2q_u64(v0, v1); + + vst1q_u8(&D[0], vreinterpretq_u8_u64(z0)); + vst1q_u8(&D[16], vreinterpretq_u8_u64(z1)); + + uint32x4_t RowSumsL_pada = vmovq_n_u32(0); + RowSumsL_pada = vpadalq_u16(RowSumsL_pada, vpaddlq_u8(vreinterpretq_u8_u64(z0))); + RowSumsL_pada = vpadalq_u16(RowSumsL_pada, vpaddlq_u8(vreinterpretq_u8_u64(z1))); + + uint32x4_t RowSumsL_ext = vextq_u32(RowSumsL_pada, RowSumsL_pada, 1); + uint32x4_t RowSumsL_add = vaddq_u32(RowSumsL_pada, RowSumsL_ext); + uint32x2_t RowSumsL = {vdups_laneq_u32(RowSumsL_add, 0), + vdups_laneq_u32(RowSumsL_add, 2)}; + + RowSums = vadd_u32(RowSums, RowSumsL); + + D += 32; + k -= 16; + } + + while (k >= 8) { + uint64x1_t v0 = *reinterpret_cast(a0); + a0 += 8; + uint64x1_t v1 = *reinterpret_cast(a1); + a1 += 8; + + *reinterpret_cast(&D[0]) = v0; + *reinterpret_cast(&D[8]) = v1; + + uint64x2_t z01 = vcombine_u64(v0, v1); + uint32x4_t RowSumsL_pada = vmovq_n_u32(0); + RowSumsL_pada = vpadalq_u16(RowSumsL_pada, vpaddlq_u8(vreinterpretq_u8_u64(z01))); + + uint32x4_t RowSumsL_ext = vextq_u32(RowSumsL_pada, RowSumsL_pada, 1); + uint32x4_t RowSumsL_add = vaddq_u32(RowSumsL_pada, RowSumsL_ext); + uint32x2_t RowSumsL = {vdups_laneq_u32(RowSumsL_add, 0), + vdups_laneq_u32(RowSumsL_add, 2)}; + + RowSums = vadd_u32(RowSums, RowSumsL); + + D += 16; + k -= 8; + } + + if (k > 0) { + // + // Zero pad the remaining elements to make 8 columns. + // + + uint8_t* d = PaddedMatrixAData; + vst1q_u8(PaddedMatrixAData, vmovq_n_u8(0)); + + while (k > 0) { + d[0] = *a0++; + d[8] = *a1++; + + d += 1; + k -= 1; + } + + d = PaddedMatrixAData; + uint64x1_t v0 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v1 = *reinterpret_cast(d); + d = d + 8; + + uint64x2_t z01 = vcombine_u64(v0, v1); + uint32x4_t RowSumsL_pada = vmovq_n_u32(0); + RowSumsL_pada = vpadalq_u16(RowSumsL_pada, vpaddlq_u8(vreinterpretq_u8_u64(z01))); + + uint32x4_t RowSumsL_ext = vextq_u32(RowSumsL_pada, RowSumsL_pada, 1); + uint32x4_t RowSumsL_add = vaddq_u32(RowSumsL_pada, RowSumsL_ext); + uint32x2_t RowSumsL = {vdups_laneq_u32(RowSumsL_add, 0), + vdups_laneq_u32(RowSumsL_add, 2)}; + + RowSums = vadd_u32(RowSums, RowSumsL); + + uint8x16_t PackedVector = vld1q_u8(PaddedMatrixAData); + vst1q_u8(D, PackedVector); + + D += 16; + } + + vst1_s32(RowSumBuffer, vreinterpret_s32_u32(RowSums)); + RowSumBuffer += 2; + + A = A + lda * 2; + CountM -= 2; + } + + // + // Process one row of matrix A. + // + // The buffer is packed as a series of 8 byte with the following pattern: + // + // [ A0 A1 A2 A3 A4 A5 A6 A7 ] + // + // This pattern is repeated (CountK / 8) times. + // + // If CountK is not aligned to a multiple of 8, then the vector is padded + // with zeroes. + // + + if (CountM > 0) { + // No need to pad the rows to 2, the .S takes care of zero pdding + const uint8_t* a = A; + size_t k = CountK; + uint32x4_t RowSums = vmovq_n_u32(0); + + while (k >= 16) { + uint8x16_t v = vld1q_u8(a); + a += 16; + + vst1q_u8(D, v); + + RowSums = vpadalq_u16(RowSums, vpaddlq_u8(v)); + + D += 16; + k -= 16; + } + + if (k > 0) { + // + // Copy the remaining bytes to the zero padded stack buffer. 
+ // + + vst1q_u8(PaddedMatrixAData, vmovq_n_u8(0)); + + for (size_t kk = 0; kk < k; kk++) { + PaddedMatrixAData[kk] = a[kk]; + } + + uint8x16_t v = vld1q_u8(PaddedMatrixAData); + vst1q_u8(D, v); + + RowSums = vpadalq_u16(RowSums, vpaddlq_u8(v)); + } + + *RowSumBuffer = int32_t(vaddvq_u32(RowSums)); + } +} + +MLAS_FORCEINLINE +void +MlasGemmU8X8CopyPackBProcessUmmla(MLAS_GEMM_U8X8_KERNEL_UMMLA::PackedBType* D, + uint8x8_t BytesRow[8], + uint8x16_t BitFlipVector, + uint32x4_t ColumnSums[2]) +{ + uint8x16_t v02 = veorq_u8(vcombine_u8(BytesRow[0], BytesRow[2]), BitFlipVector); + uint8x16_t v13 = veorq_u8(vcombine_u8(BytesRow[1], BytesRow[3]), BitFlipVector); + + uint8x16_t v46 = veorq_u8(vcombine_u8(BytesRow[4], BytesRow[6]), BitFlipVector); + uint8x16_t v57 = veorq_u8(vcombine_u8(BytesRow[5], BytesRow[7]), BitFlipVector); + + uint8x16x2_t zw1 = vzipq_u8(v02, v13); + uint16x8x2_t zd1 = + vzipq_u16(vreinterpretq_u16_u8(zw1.val[0]), vreinterpretq_u16_u8(zw1.val[1])); + + uint8x16x2_t zw2 = vzipq_u8(v46, v57); + uint16x8x2_t zd2 = + vzipq_u16(vreinterpretq_u16_u8(zw2.val[0]), vreinterpretq_u16_u8(zw2.val[1])); + + uint32x4x2_t zd3 = + vzipq_u32(vreinterpretq_u32_u16(zd1.val[0]), vreinterpretq_u32_u16(zd2.val[0])); + uint32x4x2_t zd4 = + vzipq_u32(vreinterpretq_u32_u16(zd1.val[1]), vreinterpretq_u32_u16(zd2.val[1])); + + vst1q_u8(&D[0], vreinterpretq_u8_u32(zd3.val[0])); + vst1q_u8(&D[16], vreinterpretq_u8_u32(zd3.val[1])); + vst1q_u8(&D[32], vreinterpretq_u8_u32(zd4.val[0])); + vst1q_u8(&D[48], vreinterpretq_u8_u32(zd4.val[1])); + + uint32x4_t ColSums0L_pada = vmovq_n_u32(0); + ColSums0L_pada = vpadalq_u16(ColSums0L_pada, vpaddlq_u8(vreinterpretq_u8_u32(zd3.val[0]))); + uint32x4_t ColSums0L_ext = vextq_u32(ColSums0L_pada, ColSums0L_pada, 1); + uint32x4_t ColSums0L_add = vaddq_u32(ColSums0L_pada, ColSums0L_ext); + uint32x2_t ColSums0L = {vdups_laneq_u32(ColSums0L_add, 0), vdups_laneq_u32(ColSums0L_add, 2)}; + + uint32x4_t ColSums0H_pada = vmovq_n_u32(0); + ColSums0H_pada = vpadalq_u16(ColSums0H_pada, vpaddlq_u8(vreinterpretq_u8_u32(zd3.val[1]))); + uint32x4_t ColSums0H_ext = vextq_u32(ColSums0H_pada, ColSums0H_pada, 1); + uint32x4_t ColSums0H_add = vaddq_u32(ColSums0H_pada, ColSums0H_ext); + uint32x2_t ColSums0H = {vdups_laneq_u32(ColSums0H_add, 0), vdups_laneq_u32(ColSums0H_add, 2)}; + + ColumnSums[0] = vaddq_u32(ColumnSums[0], vcombine_u32(ColSums0L, ColSums0H)); + + uint32x4_t ColSums1L_pada = vmovq_n_u32(0); + ColSums1L_pada = vpadalq_u16(ColSums1L_pada, vpaddlq_u8(vreinterpretq_u8_u32(zd4.val[0]))); + uint32x4_t ColSums1L_ext = vextq_u32(ColSums1L_pada, ColSums1L_pada, 1); + uint32x4_t ColSums1L_add = vaddq_u32(ColSums1L_pada, ColSums1L_ext); + uint32x2_t ColSums1L = {vdups_laneq_u32(ColSums1L_add, 0), vdups_laneq_u32(ColSums1L_add, 2)}; + + uint32x4_t ColSums1H_pada = vmovq_n_u32(0); + ColSums1H_pada = vpadalq_u16(ColSums1H_pada, vpaddlq_u8(vreinterpretq_u8_u32(zd4.val[1]))); + uint32x4_t ColSums1H_ext = vextq_u32(ColSums1H_pada, ColSums1H_pada, 1); + uint32x4_t ColSums1H_add = vaddq_u32(ColSums1H_pada, ColSums1H_ext); + uint32x2_t ColSums1H = {vdups_laneq_u32(ColSums1H_add, 0), vdups_laneq_u32(ColSums1H_add, 2)}; + + ColumnSums[1] = vaddq_u32(ColumnSums[1], vcombine_u32(ColSums1L, ColSums1H)); +} + +template <> +void +MlasGemmQuantCopyPackB(MLAS_GEMM_U8X8_KERNEL_UMMLA::PackedBType* D, + const uint8_t* B, + size_t ldb, + size_t CountN, + size_t CountK, + int32_t* ColumnSumBuffer, + bool BIsSigned) +{ + const uint8x16_t BitFlipVector = vdupq_n_u8(BIsSigned ? 
0x80 : 0); + uint8x8_t BytesRow[8]; + + // + // Copy data from matrix B into the destination buffer 8x2 blocks at a + // time. + // + // + while (CountN >= 8) { + const uint8_t* b = B; + size_t k = CountK; + uint32x4_t ColumnSums[2]; + ColumnSums[0] = vmovq_n_u32(0); + ColumnSums[1] = vmovq_n_u32(0); + + while (k >= 8) { + BytesRow[0] = vld1_u8(&b[ldb * 0]); + BytesRow[1] = vld1_u8(&b[ldb * 1]); + BytesRow[2] = vld1_u8(&b[ldb * 2]); + BytesRow[3] = vld1_u8(&b[ldb * 3]); + BytesRow[4] = vld1_u8(&b[ldb * 4]); + BytesRow[5] = vld1_u8(&b[ldb * 5]); + BytesRow[6] = vld1_u8(&b[ldb * 6]); + BytesRow[7] = vld1_u8(&b[ldb * 7]); + + MlasGemmU8X8CopyPackBProcessUmmla(D, BytesRow, BitFlipVector, ColumnSums); + + D += 64; + b += ldb * 8; + k -= 8; + } + + if (k > 0) { + // Pad k to 8 + + BytesRow[0] = vld1_u8(&b[ldb * 0]); + BytesRow[1] = (k >= 2) ? vld1_u8(&b[ldb * 1]) : vget_low_u8(BitFlipVector); + BytesRow[2] = (k >= 3) ? vld1_u8(&b[ldb * 2]) : vget_low_u8(BitFlipVector); + BytesRow[3] = (k >= 4) ? vld1_u8(&b[ldb * 3]) : vget_low_u8(BitFlipVector); + BytesRow[4] = (k >= 5) ? vld1_u8(&b[ldb * 4]) : vget_low_u8(BitFlipVector); + BytesRow[5] = (k >= 6) ? vld1_u8(&b[ldb * 5]) : vget_low_u8(BitFlipVector); + BytesRow[6] = (k >= 7) ? vld1_u8(&b[ldb * 6]) : vget_low_u8(BitFlipVector); + BytesRow[7] = vget_low_u8(BitFlipVector); + + MlasGemmU8X8CopyPackBProcessUmmla(D, BytesRow, BitFlipVector, ColumnSums); + + D += 64; + } + + // Zero pad the output buffer to a multiple of PackedK if the above + // processed an odd number of four row bundles. + // + vst1q_s32(&ColumnSumBuffer[0], vreinterpretq_s32_u32(ColumnSums[0])); + vst1q_s32(&ColumnSumBuffer[4], vreinterpretq_s32_u32(ColumnSums[1])); + + ColumnSumBuffer += 8; + + B += 8; + CountN -= 8; + } + + // + // Process the remaining columns of matrix B. + // + + if (CountN > 0) { + const uint8_t* b = B; + size_t k = CountK; + uint8_t PaddedMatrixBData[64]; + uint32x4_t ColumnSums[2]; + + vst1q_u8(&PaddedMatrixBData[0], BitFlipVector); + vst1q_u8(&PaddedMatrixBData[16], BitFlipVector); + vst1q_u8(&PaddedMatrixBData[32], BitFlipVector); + vst1q_u8(&PaddedMatrixBData[48], BitFlipVector); + + ColumnSums[0] = vmovq_n_u32(0); + ColumnSums[1] = vmovq_n_u32(0); + + // + // Interleave rows of matrix B using an intermediate zero padded stack + // buffer and write to the packed buffer. + // + + while (k > 0) { + const uint8_t* bcopy0 = &b[ldb * 0]; + const uint8_t* bcopy1 = &b[ldb * 1]; + const uint8_t* bcopy2 = &b[ldb * 2]; + const uint8_t* bcopy3 = &b[ldb * 3]; + const uint8_t* bcopy4 = &b[ldb * 4]; + const uint8_t* bcopy5 = &b[ldb * 5]; + const uint8_t* bcopy6 = &b[ldb * 6]; + const uint8_t* bcopy7 = &b[ldb * 7]; + + if (k >= 8) { + b += ldb * 8; + k -= 8; + + } else { + vst1q_u8(&PaddedMatrixBData[0], BitFlipVector); + vst1q_u8(&PaddedMatrixBData[16], BitFlipVector); + vst1q_u8(&PaddedMatrixBData[32], BitFlipVector); + vst1q_u8(&PaddedMatrixBData[48], BitFlipVector); + + bcopy1 = (k >= 2) ? bcopy1 : &PaddedMatrixBData[56]; + bcopy2 = (k >= 3) ? bcopy2 : &PaddedMatrixBData[56]; + bcopy3 = (k >= 4) ? bcopy3 : &PaddedMatrixBData[56]; + bcopy4 = (k >= 5) ? bcopy4 : &PaddedMatrixBData[56]; + bcopy5 = (k >= 6) ? bcopy5 : &PaddedMatrixBData[56]; + bcopy6 = (k >= 7) ? 
bcopy6 : &PaddedMatrixBData[56]; + bcopy7 = &PaddedMatrixBData[56]; + + k = 0; + } + + uint8_t* padded = PaddedMatrixBData; + uint8_t* padded_end = padded + CountN; + do { + padded[0] = *bcopy0++; + padded[8] = *bcopy1++; + padded[16] = *bcopy2++; + padded[24] = *bcopy3++; + padded[32] = *bcopy4++; + padded[40] = *bcopy5++; + padded[48] = *bcopy6++; + padded[56] = *bcopy7++; + + } while (++padded < padded_end); + + BytesRow[0] = vld1_u8(&PaddedMatrixBData[0]); + BytesRow[1] = vld1_u8(&PaddedMatrixBData[8]); + BytesRow[2] = vld1_u8(&PaddedMatrixBData[16]); + BytesRow[3] = vld1_u8(&PaddedMatrixBData[24]); + BytesRow[4] = vld1_u8(&PaddedMatrixBData[32]); + BytesRow[5] = vld1_u8(&PaddedMatrixBData[40]); + BytesRow[6] = vld1_u8(&PaddedMatrixBData[48]); + BytesRow[7] = vld1_u8(&PaddedMatrixBData[56]); + + MlasGemmU8X8CopyPackBProcessUmmla(D, BytesRow, BitFlipVector, ColumnSums); + + D += 64; + } + + vst1q_s32(&ColumnSumBuffer[0], vreinterpretq_s32_u32(ColumnSums[0])); + vst1q_s32(&ColumnSumBuffer[4], vreinterpretq_s32_u32(ColumnSums[1])); + } +} + +template <> +MLAS_FORCEINLINE size_t +MlasGemmQuantKernel(const MLAS_GEMM_U8X8_KERNEL_UMMLA::PackedAType* A, + const MLAS_GEMM_U8X8_KERNEL_UMMLA::PackedBType* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumBuffer, + const int32_t* ColumnSumBuffer, + const int32_t* ZeroPointB, + bool ZeroMode) +{ + size_t RowsHandled; + + if (ZeroMode) { + RowsHandled = MlasGemmU8X8KernelUmmlaZero(A, B, C, PackedCountK, CountM, CountN, ldc, + RowSumBuffer, ColumnSumBuffer, ZeroPointB); + } else { + RowsHandled = MlasGemmU8X8KernelUmmlaAdd(A, B, C, PackedCountK, CountM, CountN, ldc, + RowSumBuffer, ColumnSumBuffer, ZeroPointB); + } + + return RowsHandled; +} + +const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchUmmla = { + MlasGemmQuantOperation, + MlasGemmQuantPackedOperation, + MlasGemmQuantCopyPackB, + MLAS_GEMM_U8X8_KERNEL_UMMLA::PackedK, + MLAS_GEMM_U8X8_KERNEL_UMMLA::PackedStrides.K, + 8}; diff --git a/onnxruntime/core/mlas/lib/threading.cpp b/onnxruntime/core/mlas/lib/threading.cpp index ecdc5250ebf0e..dc5daf998d3be 100644 --- a/onnxruntime/core/mlas/lib/threading.cpp +++ b/onnxruntime/core/mlas/lib/threading.cpp @@ -93,3 +93,41 @@ MlasTrySimpleParallel( MLAS_THREADPOOL::TrySimpleParallelFor(ThreadPool, Iterations, Work); #endif } + + +void +MlasTryBatchParallel( + MLAS_THREADPOOL * ThreadPool, + const std::ptrdiff_t Iterations, + const std::function& Work) +{ + // + // Execute the routine directly if only one iteration is specified. + // + if (Iterations == 1) { + Work(0); + return; + } + +#if defined(BUILD_MLAS_NO_ONNXRUNTIME) + MLAS_UNREFERENCED_PARAMETER(ThreadPool); + + // + // Fallback to OpenMP or a serialized implementation. + // + + // + // Execute the routine for the specified number of iterations. + // + for (ptrdiff_t tid = 0; tid < Iterations; tid++) { + Work(tid); + } +#else + // + // Schedule the threaded iterations using the thread pool object. 
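// With BUILD_MLAS_NO_ONNXRUNTIME defined (handled above) the same work items run
// serially on the calling thread; here they are handed to the ONNX Runtime thread
// pool instead. A typical call site would look like the following sketch, where
// BatchCount and ProcessBatch are hypothetical names used only for illustration:
//
//   MlasTryBatchParallel(ThreadPool, BatchCount,
//                        [&](std::ptrdiff_t batch) { ProcessBatch(batch); });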
+ // + + MLAS_THREADPOOL::TryBatchParallelFor(ThreadPool, Iterations, Work, 0); +#endif + +} \ No newline at end of file diff --git a/onnxruntime/core/optimizer/conv_activation_fusion.cc b/onnxruntime/core/optimizer/conv_activation_fusion.cc index c090ab2a6cc9b..d27603e4ab3a1 100644 --- a/onnxruntime/core/optimizer/conv_activation_fusion.cc +++ b/onnxruntime/core/optimizer/conv_activation_fusion.cc @@ -4,7 +4,7 @@ #include "core/optimizer/conv_activation_fusion.h" #include - +#include #include "core/common/inlined_containers.h" #include "core/framework/tensorprotoutils.h" #include "core/mlas/inc/mlas.h" @@ -174,9 +174,29 @@ using NTO = NodesToOptimize; class FuseConvActivationAction : public ReplaceWithNew { private: - std::string OpType(const RuntimeState&) const override { return "FusedConv"; } + std::string OpType(const RuntimeState& runtime_state) const override { + const auto& domain = runtime_state.selected_nodes.Target().Domain(); + const auto& op_type = runtime_state.selected_nodes.Target().OpType(); + if (domain == kOnnxDomain) { + if (op_type == "Conv") { + return "FusedConv"; + } + } else if (domain == kMSDomain) { + if (op_type == "NhwcConv") { + return "NhwcFusedConv"; + } + } else if (domain == kMSInternalNHWCDomain) { + if (op_type == "Conv") { + return "Conv"; + } + } + ORT_THROW("Unsupported operator: ", op_type, " and domain: ", domain); + } - std::string Domain(const RuntimeState&) const override { return kMSDomain; } + std::string Domain(const RuntimeState& runtime_state) const override { + auto domain = runtime_state.selected_nodes.Target().Domain(); + return domain == kOnnxDomain ? kMSDomain : domain; + } NodeAttributes ExtraAttributes(const RuntimeState& state) const override { NodeAttributes extra_fused_conv_attributes; @@ -260,8 +280,11 @@ void RegisterConvActivationFusionRules(SelectorActionRegistry& registry) { const auto name = "ConvAct"; auto action = std::make_unique(); #if !defined(ORT_MINIMAL_BUILD) + const std::string msInternalNHWCDomainConv = SelectorActionRegistry::OpVersionsMapKey("Conv", kMSInternalNHWCDomain); + const std::string msDomainConv = SelectorActionRegistry::OpVersionsMapKey("NhwcConv", kMSDomain); auto selector = std::make_unique(); - registry.RegisterSelectorAndAction(name, {{"Conv", {1, 11}}}, + + registry.RegisterSelectorAndAction(name, {{"Conv", {1, 11}}, {msInternalNHWCDomainConv, {11}}, {msDomainConv, {1}}}, std::move(selector), std::move(action)); #else registry.RegisterAction(name, std::move(action)); diff --git a/onnxruntime/core/optimizer/conv_add_act_fusion.cc b/onnxruntime/core/optimizer/conv_add_act_fusion.cc index 7c8bfeaec5f0f..6f90eaf07ef4d 100644 --- a/onnxruntime/core/optimizer/conv_add_act_fusion.cc +++ b/onnxruntime/core/optimizer/conv_add_act_fusion.cc @@ -287,12 +287,9 @@ class FuseConvAddActivationAction : public ReplaceWithNew { void RegisterConvAddActivationFusionRules(SelectorActionRegistry& registry) { auto action = std::make_unique(); auto selector = std::make_unique(); - registry.RegisterSelectorAndAction("ConvAddAct", {{"Conv", {1, 11}}}, + std::string msDomainNhwcFusedConv = SelectorActionRegistry::OpVersionsMapKey("NhwcFusedConv", kMSDomain); + registry.RegisterSelectorAndAction("ConvAddAct", {{"Conv", {1, 11}}, {msDomainNhwcFusedConv, {1, 11}}}, std::move(selector), std::move(action)); - auto action_nhwc = std::make_unique(); - auto selector_nhwc = std::make_unique(); - registry.RegisterSelectorAndAction("NhwcFusedConvAct", {{"NhwcFusedConv", {1, 11}}}, - std::move(selector_nhwc), std::move(action_nhwc)); } 
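// For reference, the registrations above depend on the domain-qualified map keys
// introduced by SelectorActionRegistry::OpVersionsMapKey later in this change:
// ops in the default ONNX domain keep their plain op type, while ops in other
// domains are keyed as "<domain>:<op_type>". A sketch of the resulting entries:
//
//   SelectorActionRegistry::OpVersionsMap ops_and_versions = {
//       {"Conv", {1, 11}},                                                      // kOnnxDomain
//       {SelectorActionRegistry::OpVersionsMapKey("NhwcFusedConv", kMSDomain),  // "com.microsoft:NhwcFusedConv"
//        {1, 11}},
//   };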
SelectorActionRegistry CreateSelectorActionRegistry() { diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index c4416068e2457..c1397e92d9d26 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -50,6 +50,8 @@ #include "core/optimizer/matmul_integer_to_float.h" #include "core/optimizer/matmul_scale_fusion.h" #include "core/optimizer/matmul_transpose_fusion.h" +#include "core/optimizer/matmul_bn_fusion.h" +#include "core/optimizer/pad_fusion.h" #include "core/optimizer/nchwc_transformer.h" #include "core/optimizer/noop_elimination.h" #include "core/optimizer/not_where_fusion.h" @@ -127,6 +129,8 @@ InlinedVector> GenerateRewriteRules( rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); + rules.push_back(std::make_unique()); + rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); break; @@ -268,11 +272,12 @@ InlinedVector> GenerateTransformers( onnxruntime::kCudaExecutionProvider, onnxruntime::kRocmExecutionProvider, onnxruntime::kDmlExecutionProvider}; - const InlinedHashSet cpu_cuda_rocm_acl_armnn_eps = {onnxruntime::kCpuExecutionProvider, - onnxruntime::kCudaExecutionProvider, - onnxruntime::kRocmExecutionProvider, - onnxruntime::kAclExecutionProvider, - onnxruntime::kArmNNExecutionProvider}; + const InlinedHashSet cpu_cuda_rocm_acl_armnn_js_eps = {onnxruntime::kCpuExecutionProvider, + onnxruntime::kCudaExecutionProvider, + onnxruntime::kRocmExecutionProvider, + onnxruntime::kAclExecutionProvider, + onnxruntime::kArmNNExecutionProvider, + onnxruntime::kJsExecutionProvider}; #ifdef MLAS_TARGET_AMD64_IX86 const bool avx2_precision_mode = @@ -294,7 +299,7 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::make_unique(cpu_ep)); transformers.emplace_back(std::make_unique(cpu_ep)); - transformers.emplace_back(std::make_unique(cpu_cuda_rocm_acl_armnn_eps)); + transformers.emplace_back(std::make_unique(cpu_cuda_rocm_acl_armnn_js_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); diff --git a/onnxruntime/core/optimizer/initializer.cc b/onnxruntime/core/optimizer/initializer.cc index c8da15f65a6d7..9e807ddc7be59 100644 --- a/onnxruntime/core/optimizer/initializer.cc +++ b/onnxruntime/core/optimizer/initializer.cc @@ -291,7 +291,11 @@ Initializer& Initializer::sqrt() { namespace { template struct ScaleByAxis { - void operator()(Tensor& data, const Tensor& scalers, const size_t block_size, const size_t num_blocks) const { + void operator()(Tensor& data, + const Tensor& scalers, + const size_t block_size, + const size_t num_blocks, + const bool column_major) const { ToNumeric to_numeric; const auto scaler_size = scalers.Shape().Size(); T* dst = data.MutableData(); @@ -303,24 +307,32 @@ struct ScaleByAxis { } } else { for (size_t block_offset = 0, i = 0; i < num_blocks; i++) { - const auto numeric_scaler = to_numeric(scalers_data[i]); - for (size_t j = 0; j < block_size; ++j, ++block_offset) { - dst[block_offset] = T(to_numeric(dst[block_offset]) * numeric_scaler); + if (column_major) { + for (size_t j = 0; j < block_size; ++j, ++block_offset) { + const auto numeric_scaler = to_numeric(scalers_data[j]); + dst[block_offset] = T(to_numeric(dst[block_offset]) * numeric_scaler); + } + } else { + const auto numeric_scaler = 
to_numeric(scalers_data[i]); + for (size_t j = 0; j < block_size; ++j, ++block_offset) { + dst[block_offset] = T(to_numeric(dst[block_offset]) * numeric_scaler); + } } } } } }; - } // namespace -void Initializer::scale_by_axis(const Initializer& scalers, int axis) { +void Initializer::scale_by_axis(const Initializer& scalers, int axis, bool column_major) { ORT_ENFORCE(axis >= 0, "Axis must be non-negative"); const size_t block_size = narrow(data_.Shape().SizeFromDimension(gsl::narrow_cast(axis))); const size_t num_blocks = size() / block_size; - ORT_ENFORCE(scalers.size() == 1 || scalers.size() == num_blocks, "Invalid other(scalers) size"); + ORT_ENFORCE(scalers.size() == 1 || + (column_major ? scalers.size() == block_size : scalers.size() == num_blocks), + "Invalid other(scalers) size"); utils::MLTypeCallDispatcher t_disp(data_.GetElementType()); - t_disp.Invoke(data_, scalers.data_, block_size, num_blocks); + t_disp.Invoke(data_, scalers.data_, block_size, num_blocks, column_major); } #endif // ORT_EXTENDED_MINIMAL_BUILD } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/initializer.h b/onnxruntime/core/optimizer/initializer.h index dfe054ba1aced..78e3fd6a3d24e 100644 --- a/onnxruntime/core/optimizer/initializer.h +++ b/onnxruntime/core/optimizer/initializer.h @@ -86,7 +86,7 @@ class Initializer final { Initializer& sqrt(); - void scale_by_axis(const Initializer& other, int axis); + void scale_by_axis(const Initializer& other, int axis, bool column_major = false); #endif // ORT_EXTENDED_MINIMAL_BUILD private: std::string name_; diff --git a/onnxruntime/core/optimizer/insert_cast_transformer.cc b/onnxruntime/core/optimizer/insert_cast_transformer.cc index 7c087ec77d9fe..959fcd6efdc3c 100644 --- a/onnxruntime/core/optimizer/insert_cast_transformer.cc +++ b/onnxruntime/core/optimizer/insert_cast_transformer.cc @@ -32,7 +32,7 @@ onnxruntime::NodeArg* AddCastNode(onnxruntime::Graph& graph, int64_t to_type, onnxruntime::ProviderType providerType) { // insert cast op to cast input - std::string node_name = graph.GenerateNodeName("InsertedCast_" + old_arg->Name()); + std::string node_name = graph.GenerateNodeName("InsertedPrecisionFreeCast_" + old_arg->Name()); auto* new_arg = &graph.GetOrCreateNodeArg(node_name, new_type); @@ -235,7 +235,8 @@ enum TypeGroup { Unknown = -1, Bool = 0, Integer = 1, - Float = 2, + Unsigned = 2, + Float = 3, }; TypeGroup GetTypeGroup(DataType type) { @@ -243,11 +244,14 @@ TypeGroup GetTypeGroup(DataType type) { return Bool; } - if (*type == "tensor(int16)" || *type == "tensor(int32)" || *type == "tensor(int64)" || *type == "tensor(int8)" || - *type == "tensor(uint16)" || *type == "tensor(uint32)" || *type == "tensor(uint64)" || *type == "tensor(uint8)") { + if (*type == "tensor(int16)" || *type == "tensor(int32)" || *type == "tensor(int64)" || *type == "tensor(int8)") { return Integer; } + if (*type == "tensor(uint16)" || *type == "tensor(uint32)" || *type == "tensor(uint64)" || *type == "tensor(uint8)") { + return Unsigned; + } + if (*type == "tensor(bfloat16)" || *type == "tensor(double)" || *type == "tensor(float)" || *type == "tensor(float16)") { return Float; } @@ -255,6 +259,22 @@ TypeGroup GetTypeGroup(DataType type) { return Unknown; } +int BitLength(DataType type) { + if (*type == "tensor(bool)") { + return 1; + } else if (*type == "tensor(uint8)" || *type == "tensor(int8)") { + return 8; + } else if (*type == "tensor(int16)" || *type == "tensor(uint16)" || *type == "tensor(bfloat16)" || *type == "tensor(float16)") { + return 16; + } else if 
(*type == "tensor(int32)" || *type == "tensor(uint32)" || *type == "tensor(float)") { + return 32; + } else if (*type == "tensor(int64)" || *type == "tensor(uint64)" || *type == "tensor(double)") { + return 64; + } else { + return -1; + } +} + /** Transformer to remove duplicate Cast nodes. */ class RemoveDuplicateCastTransformer : public GraphTransformer { public: @@ -262,6 +282,48 @@ class RemoveDuplicateCastTransformer : public GraphTransformer { } private: + static bool UnsafeCast(DataType src_type, DataType dst_type, const Node& node) { + // This is not a complete cast optimisation pass, and is more conservative than it could be. + // For instance, certain integral -> floating point casts could be optimised but this is left to an explicit cast optimisation pass. + + // The comparison with "InsertedPrecisionFreeCast_" reflects cast nodes that are inserted by InsertCastTransformer. + // Such casts should not be considered as loss of precision - the inserted upcasts (f16 -> f32) and downcasts (f32 -> f16) are inserted to support kernels when on a CPU EP without F16 support. + auto src_type_group = GetTypeGroup(src_type); + auto dst_type_group = GetTypeGroup(dst_type); + if (Unknown == src_type_group || Unknown == dst_type_group) { + return true; + } + + // Do not remove any signed -> unsigned cast. + if ((src_type_group != Bool && src_type_group != Unsigned) && Unsigned == dst_type_group) { + return true; + } + + // Do not remove any floating point -> non floating point cast. + if (Float == src_type_group && Float != dst_type_group) { + return true; + } + + auto src_bit_length = BitLength(src_type); + auto dst_bit_length = BitLength(dst_type); + + // unsigned integer -> integer cast may overflow if the destination integer is smaller or equal to the source integer. + if (Unsigned == src_type_group && Integer == dst_type_group) { + return dst_bit_length <= src_bit_length; + } + + // integral -> floating cast may overflow if integer cannot be encoded in the mantissa. This check could be more precise. + if ((Integer == src_type_group || Unsigned == src_type_group) && Float == dst_type_group) { + return dst_bit_length <= src_bit_length; + } + + if ((*src_type == "tensor(float16)" && *dst_type == "tensor(bfloat16)") || (*src_type == "tensor(bfloat16)" && *dst_type == "tensor(float16)")) { + return true; + } + + return src_bit_length > dst_bit_length && (node.Name().compare(0, 26, "InsertedPrecisionFreeCast_")); + } + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override { auto output_args = graph.GetOutputs(); InlinedHashSet graph_outputs; @@ -293,17 +355,8 @@ class RemoveDuplicateCastTransformer : public GraphTransformer { // - for each consumer cast node, it meets above condition for this optimization. 
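// A few concrete outcomes of UnsafeCast() above, derived from the type groups and
// BitLength() values defined earlier:
//   int32  -> int64      : safe   (same group, widening, every value representable)
//   uint32 -> int32      : unsafe (unsigned -> signed with dst bits <= src bits)
//   int32  -> float      : unsafe (32-bit values may not fit float's 24-bit mantissa)
//   int16  -> float      : safe   (every int16 is exactly representable in float)
//   float  -> int32      : unsafe (floating point -> non floating point is never removed)
//   float16 <-> bfloat16 : unsafe (explicitly excluded, neither is a superset of the other)
// Narrowing casts whose node name starts with "InsertedPrecisionFreeCast_" are still
// treated as precision free, matching the prefix check at the end of UnsafeCast().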
auto src_type = node.InputDefs()[0]->Type(); auto dst_type = node.OutputDefs()[0]->Type(); - TypeGroup src_type_group = GetTypeGroup(src_type); - TypeGroup dst_type_group = GetTypeGroup(dst_type); - if (src_type_group == Unknown || dst_type_group == Unknown) { - continue; - } - - bool loss_precision_cast = false; - if (src_type_group > dst_type_group) { - loss_precision_cast = true; - } + bool loss_precision_cast = UnsafeCast(src_type, dst_type, node); size_t num_children = node.GetOutputEdgesCount(); bool inconsistent_casts = false; @@ -312,10 +365,7 @@ class RemoveDuplicateCastTransformer : public GraphTransformer { if (output_node.OpType() == "Cast") { auto src_type1 = output_node.InputDefs()[0]->Type(); auto dst_type1 = output_node.OutputDefs()[0]->Type(); - TypeGroup src_type_group1 = GetTypeGroup(src_type1); - TypeGroup dst_type_group1 = GetTypeGroup(dst_type1); - if (src_type_group1 == Unknown || dst_type_group1 == Unknown || - (loss_precision_cast && dst_type_group1 > src_type_group1)) { + if (loss_precision_cast && UnsafeCast(dst_type1, src_type1, output_node)) { inconsistent_casts = true; break; } diff --git a/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc b/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc index 6c91949e467ae..4505d4afdf1e0 100644 --- a/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc +++ b/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc @@ -30,6 +30,23 @@ CostCheckResult PostLayoutTransformCostCheck(const api::GraphRef& graph, const a return OrtEPCostCheck(graph, node, perm, outputs_leading_to_transpose); } +#if defined(USE_CUDA) && ENABLE_CUDA_NHWC_OPS +const std::unordered_set& GetCUDALayoutSensitiveOps() { + static std::unordered_set cuda_nhwc_ops = []() { + return std::unordered_set{ + "BatchNormalization", + "Conv", + "ConvTranspose", + "GlobalMaxPool", + "MaxPool", + "GlobalAveragePool", + "AveragePool", + }; + }(); + return cuda_nhwc_ops; +} +#endif + /// /// Default function for checking if a node should have its layout changed. Allows EP specific adjustments to the /// default set of layout sensitive operators if required. @@ -49,17 +66,6 @@ bool ConvertNodeLayout(const api::NodeRef& node) { const auto& layout_sensitive_ops = GetORTLayoutSensitiveOps(); // handle special cases -#if defined(USE_XNNPACK) - if (node.GetExecutionProviderType() == kXnnpackExecutionProvider) { - if (node.OpType() == "Resize") { - // XNNPACK supports NCHW and NHWC for Resize so we don't need to use the internal NHWC domain and wrap the Resize - // with Transpose nodes. EPAwareHandleResize will allow an NCHW <-> NHWC Transpose to be pushed through - // the Resize during transpose optimization. 
- return false; - } - } -#endif - #if defined(USE_JSEP) // TODO(fs-eire): Remove special case handing of JSEP once NHWC Resize implementation is fixed if (node.GetExecutionProviderType() == kJsExecutionProvider) { @@ -71,11 +77,16 @@ bool ConvertNodeLayout(const api::NodeRef& node) { } #endif - // #if defined(USE_CUDA) - // if (node.GetExecutionProviderType() == kCudaExecutionProvider) { - // Update as per https://github.com/microsoft/onnxruntime/pull/17200 with CUDA ops that support NHWC - // } - // #endif +#if defined(USE_CUDA) && ENABLE_CUDA_NHWC_OPS + if (node.GetExecutionProviderType() == kCudaExecutionProvider) { + if (layout_sensitive_ops.count(node.OpType())) { + const auto& cuda_nhwc_ops = GetCUDALayoutSensitiveOps(); + if (!cuda_nhwc_ops.count(node.OpType())) { + return false; + } + } + } +#endif return layout_sensitive_ops.count(node.OpType()) != 0; } diff --git a/onnxruntime/core/optimizer/matmul_bn_fusion.cc b/onnxruntime/core/optimizer/matmul_bn_fusion.cc new file mode 100644 index 0000000000000..e944522c9c338 --- /dev/null +++ b/onnxruntime/core/optimizer/matmul_bn_fusion.cc @@ -0,0 +1,230 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/optimizer/matmul_bn_fusion.h" +#include "core/graph/graph_utils.h" +#include "core/optimizer/initializer.h" +#include "core/optimizer/utils.h" + +namespace onnxruntime { + +namespace { +const std::vector>> ignorable_nodes{ + {"Reshape", {1, 5, 13, 14, 19}}, + {"Transpose", {1, 13}}}; +const std::pair> dest = {"BatchNormalization", {1, 6, 7, 9, 14, 15}}; +} // namespace + +bool NodeIsIgnorable(const Graph& graph, const Node& root_node, NodeIndex curr_node_index) { + const Node* curr_node = graph.GetNode(curr_node_index); + + // curr_node has different execution provider then it's parent or + // has output edge != 1 (this condition will handle the case when ignorable node + // is graph output i.e. a graph like this "MatMul->Transpose") + if (curr_node->GetExecutionProviderType() != root_node.GetExecutionProviderType() || + curr_node->GetOutputEdgesCount() != 1) { + return false; + } + + // curr_node can be any of the ignorable_nodes. + for (size_t index = 0; index < ignorable_nodes.size(); index++) { + if (graph_utils::IsSupportedOptypeVersionAndDomain(*curr_node, ignorable_nodes[index].first, ignorable_nodes[index].second)) { + return true; + } + } + + return false; +} + +std::optional MatchPath(const Graph& graph, const Node& root_node, NodeIndex curr_node_index) { + while (NodeIsIgnorable(graph, root_node, curr_node_index)) { + curr_node_index = graph.GetNode(curr_node_index)->OutputNodesBegin()->Index(); + } + + // curr_node is neither ignorable nor dest + const Node* curr_node = graph.GetNode(curr_node_index); + if (curr_node->OpType() != dest.first) { + return std::nullopt; + } + + if (curr_node->GetExecutionProviderType() == root_node.GetExecutionProviderType() && + graph_utils::IsSupportedOptypeVersionAndDomain(*curr_node, dest.first, dest.second)) { + return curr_node_index; + } + + // either curr_node has different execution provider or + // has invalid opset. + return std::nullopt; +} + +/* + * Given a MatMul node, it will verify the following pattern. + * MatMul GEMM + * | | + * Reshape ^ ---> Reshape ^ + * | | + * Transpose ^ Transpose ^ + * | + * BatchNormalization + * Note: ^ means there can be 0 or any occurrences of that node. 
+ * Few example fusable pattern: + * - MatMul -> Reshape -> Transpose -> BatchNormalization ---> GEMM -> Reshape -> Transpose + * - MatMul -> Reshape -> BatchNormalization ---> GEMM -> Reshape + * - MatMul -> Transpose -> BatchNormalization ---> GEMM -> Transpose + * - MatMul -> Reshape -> Reshape -> BatchNormalization ---> GEMM -> Reshape -> Reshape + * - MatMul -> Reshape -> Transpose -> Reshape -> BatchNormalization ---> GEMM -> Reshape -> Transpose -> Reshape + * - MatMul -> BatchNormalization ---> GEMM + * Other Conditions: + * - B tensor of MatMul should be constant. + * - scale, B, mean, var tensors of BatchNormalization should be constant. + * - Every node in the path, except the BatchNormalization, should have only 1 output edge. + */ +bool MatmulBNFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const { + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", {1, 9, 13}) || + node.GetOutputEdgesCount() != 1) { + return false; + } + + if (graph.NodeProducesGraphOutput(node)) { + return false; + } + + // because is not producing graph output, it means it will have a child node + NodeIndex child_node_index = node.OutputNodesBegin()->Index(); + std::optional batch_norm_index = MatchPath(graph, node, child_node_index); + if (!batch_norm_index.has_value()) { + return false; + } + + const Node* batch_norm_node = graph.GetNode(*batch_norm_index); + + // Check that the appropriate inputs to the Matmul and BN nodes are constants. + if (!graph_utils::NodeArgIsConstant(graph, *node.InputDefs()[1]) || + !graph_utils::NodeArgIsConstant(graph, *batch_norm_node->InputDefs()[1]) || + !graph_utils::NodeArgIsConstant(graph, *batch_norm_node->InputDefs()[2]) || + !graph_utils::NodeArgIsConstant(graph, *batch_norm_node->InputDefs()[3]) || + !graph_utils::NodeArgIsConstant(graph, *batch_norm_node->InputDefs()[4])) { + return false; + } + + // First output from BN is required. Others are optional. If any optional outputs exist we can't fuse. + const auto& output_defs = batch_norm_node->OutputDefs(); + if (output_defs.size() > 1) { + for (size_t i = 1, end = output_defs.size(); i < end; ++i) { + if (output_defs[i] != nullptr && output_defs[i]->Exists()) { + return false; + } + } + } + + return true; +} + +/* + * BatchNormalization: [https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_batch_normalization_operator_desc] + * Scale * ((Input - Mean) / sqrt(Variance + Epsilon)) + Bias // ignore the FusedActivation in the above definition, that's very specific to DML + * Expanding out the terms: + * Output = (Scale / sqrt(Variance + Epsilon)) * Input + (Scale / sqrt(Variance + Epsilon)) * -Mean + Bias + * Here, + * [Scale/sqrt(Variance + Epsilon)] is constant, and let's call it `alpha` + * [(Scale / sqrt(Variance + Epsilon)) * -Mean + Bias] is also constant, and let's call it `beta` + * Output = alpha * Input + beta, Input = B tensor of MatMul. 
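 * A small numeric check of the folding, using hypothetical values for illustration:
 *   Scale = 2, Variance = 3, Epsilon = 1, Mean = 5, Bias = 7
 *   alpha = 2 / sqrt(3 + 1) = 1
 *   beta  = alpha * -5 + 7  = 2
 * so the BatchNormalization computes 1 * Input + 2; the fusion scales the MatMul's
 * B initializer column-wise by alpha and passes beta as the C input of the new Gemm.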
+ * + */ +Status MatmulBNFusion::Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect& rule_effect, const logging::Logger&) const { + NodeIndex child_node_index = matmul_node.OutputNodesBegin()->Index(); + NodeIndex batch_norm_node_index = MatchPath(graph, matmul_node, child_node_index).value(); + + Node& batch_norm_node = *graph.GetNode(batch_norm_node_index); // need mutable node, that's why extracting node from graph + + // only perform fusion if epsilon is present and is of float_32 type + auto epsilon_attribute = batch_norm_node.GetAttributes().find("epsilon"); + if (epsilon_attribute == batch_norm_node.GetAttributes().end() || + epsilon_attribute->second.type() != ONNX_NAMESPACE::AttributeProto_AttributeType_FLOAT) { + return Status::OK(); + } + const float epsilon = epsilon_attribute->second.f(); + + const onnx::TensorProto* scale_tensor = graph_utils::GetConstantInitializer(graph, batch_norm_node.InputDefs()[1]->Name()); + ORT_ENFORCE(scale_tensor); + const onnx::TensorProto* bias_tensor = graph_utils::GetConstantInitializer(graph, batch_norm_node.InputDefs()[2]->Name()); + ORT_ENFORCE(bias_tensor); + const onnx::TensorProto* mean_tensor = graph_utils::GetConstantInitializer(graph, batch_norm_node.InputDefs()[3]->Name()); + ORT_ENFORCE(mean_tensor); + const onnx::TensorProto* var_tensor = graph_utils::GetConstantInitializer(graph, batch_norm_node.InputDefs()[4]->Name()); + ORT_ENFORCE(var_tensor); + const onnx::TensorProto* matmul_b_tensor = graph_utils::GetConstantInitializer(graph, matmul_node.InputDefs()[1]->Name()); + ORT_ENFORCE(matmul_b_tensor); + + if (!optimizer_utils::IsFloatingPointDataType(*matmul_b_tensor) || + !optimizer_utils::IsFloatingPointDataType(*scale_tensor) || + !optimizer_utils::IsFloatingPointDataType(*bias_tensor) || + !optimizer_utils::IsFloatingPointDataType(*mean_tensor) || + !optimizer_utils::IsFloatingPointDataType(*var_tensor) || + scale_tensor->dims_size() != 1 || + bias_tensor->dims_size() != 1 || + mean_tensor->dims_size() != 1 || + var_tensor->dims_size() != 1 || + scale_tensor->dims(0) != matmul_b_tensor->dims(1) || + bias_tensor->dims(0) != matmul_b_tensor->dims(1) || + mean_tensor->dims(0) != matmul_b_tensor->dims(1) || + var_tensor->dims(0) != matmul_b_tensor->dims(1)) { + return Status::OK(); + } + + /* + * temp = scale / sqrt(var + epsilon) + * output = (temp * Input) - ((temp * mean) + bias) + */ + Initializer scale(*scale_tensor, graph.ModelPath()); + Initializer bias(*bias_tensor, graph.ModelPath()); + Initializer mean(*mean_tensor, graph.ModelPath()); + Initializer var(*var_tensor, graph.ModelPath()); + Initializer matmul_b(*matmul_b_tensor, graph.ModelPath()); + + var.add(epsilon); + var.sqrt(); + scale.div(var); // this is the temp + matmul_b.scale_by_axis(scale, 1, true); + + mean.mul(scale); + bias.sub(mean); + + // create B tensorProto for new Gemm node from initializer. + ONNX_NAMESPACE::TensorProto new_gemm_b_tensor(*matmul_b_tensor); + matmul_b.ToProto(new_gemm_b_tensor); + const std::string new_gemm_b_name = graph.GenerateNodeArgName("MatMulBnFusion_GemmB_" + matmul_b_tensor->name()); + new_gemm_b_tensor.set_name(new_gemm_b_name); + NodeArg& new_gemm_b_node_arg = graph_utils::AddInitializer(graph, new_gemm_b_tensor); + + // create bias tensorProto for new Gemm node from initializer. 
+ ONNX_NAMESPACE::TensorProto new_gemm_bias_tensor(*bias_tensor); + bias.ToProto(new_gemm_bias_tensor); + const std::string new_gemm_bias_name = graph.GenerateNodeArgName("MatMulBnFusion_GemmBias"); + new_gemm_bias_tensor.set_name(new_gemm_bias_name); + NodeArg& new_gemm_bias_node_arg = graph_utils::AddInitializer(graph, new_gemm_bias_tensor); + + Node& gemm_node = graph.AddNode( + graph.GenerateNodeArgName("MatMulBnFusion_Gemm"), + "Gemm", + "Generated from Matmul BatchNormalization fusion", + {matmul_node.MutableInputDefs()[0], &new_gemm_b_node_arg, &new_gemm_bias_node_arg}, + matmul_node.MutableOutputDefs(), + nullptr, + kOnnxDomain); + + // Remove MatMul node. + Node* node = graph.GetNode(matmul_node.Index()); + graph_utils::RemoveNodeOutputEdges(graph, *node); + graph.RemoveNode(matmul_node.Index()); + + // Delete optional empty output defs. + // Delete BatchNormalization node and update the input of the child of BatchNormalization + batch_norm_node.MutableOutputDefs().resize(1); + NodeIndex batch_norm_parent_index = graph.GetNode(child_node_index)->OpType() == "BatchNormalization" ? gemm_node.Index() : batch_norm_node.InputNodesBegin()->Index(); + graph_utils::FinalizeNodeFusion(graph, *graph.GetNode(batch_norm_parent_index), batch_norm_node); + + rule_effect = RewriteRuleEffect::kRemovedCurrentNode; + return Status::OK(); +} +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/optimizer/matmul_bn_fusion.h b/onnxruntime/core/optimizer/matmul_bn_fusion.h new file mode 100644 index 0000000000000..7a43483cf37d4 --- /dev/null +++ b/onnxruntime/core/optimizer/matmul_bn_fusion.h @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/rewrite_rule.h" + +namespace onnxruntime { +/* + * This fusion submerges a BatchNormalization operator to it's super + * precedding MatMul operator, if and only if MatmulBNFusion::SatisfyCondition() + * is true. + */ +class MatmulBNFusion : public RewriteRule { + public: + MatmulBNFusion() : RewriteRule("MatMul_BatchNormalization_Fusion") {} + + std::vector TargetOpTypes() const noexcept override { + return {"MatMul"}; + } + + private: + bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override; + + Status Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override; +}; +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/optimizer/pad_fusion.cc b/onnxruntime/core/optimizer/pad_fusion.cc new file mode 100644 index 0000000000000..b25e7618802dd --- /dev/null +++ b/onnxruntime/core/optimizer/pad_fusion.cc @@ -0,0 +1,128 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/optimizer/pad_fusion.h" +#include "core/graph/graph_utils.h" +#include "core/optimizer/initializer.h" +#include "core/optimizer/utils.h" + +namespace onnxruntime { + +/* + * It matches following pattern: + * Pad + * | + * Conv/MaxPool + */ +bool PadFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const { + // if Pad has input axis, don't fuse it. 
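// (The optional `axes` input was added to Pad in opset 18. When it is present the
//  pads values no longer cover every input dimension in order, which would break the
//  per-dimension arithmetic performed in Apply(), so the InputDefs() size check below
//  conservatively rejects that case.)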
+ if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Pad", {1, 2, 11, 13, 18, 19}) || + node.GetOutputEdgesCount() != 1 || + node.InputDefs().size() > 3) { + return false; + } + + if (graph.NodeProducesGraphOutput(node)) { + return false; + } + + const Node& child_node = *node.OutputNodesBegin(); + if (!graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "Conv", {1, 11}) && + !graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "MaxPool", {1, 8, 10, 11, 12})) { + return false; + } + + // Don't fuse if MaxPool has optional output indices tensor because output indices tensor + // does not incorporate pad values. Basically if we allow the fusion, then dimension values + // of input tensor < dimension values of input tensor without fusion. + // This will cause the range of values for output indices tensor to be less than what it + // should have been. + + if (child_node.OutputDefs().size() > 1) { + return false; + } + + // conv or maxpool node must use explicit padding to perform this fusion. + if (child_node.GetAttributes().find("auto_pad") != child_node.GetAttributes().end() && + child_node.GetAttributes().at("auto_pad").s() != "NOTSET") { + return false; + } + + const NodeAttributes& pad_attributes = node.GetAttributes(); + if (pad_attributes.find("mode") != pad_attributes.end() && + pad_attributes.at("mode").s() != "constant") { + return false; + } + + // Since opset 11, and moved to inputs. + // Both of these should be initializer because we have to verify the values. + if (node.SinceVersion() >= 11) { + if (!graph_utils::NodeArgIsConstant(graph, *node.InputDefs()[1]) || + (node.InputDefs().size() > 2 && !graph_utils::NodeArgIsConstant(graph, *node.InputDefs()[2]))) { + return false; + } + + // constant_value should be zero because Conv and MaxPool allow only 0 as padding value. 
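// Comparing raw bytes works for any element type: an integer or IEEE floating point
// zero is stored as all-zero bytes, so any non-zero byte means the constant is not
// (positive) zero. A -0.0f constant sets the sign bit and is rejected, which is
// conservative but still correct.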
+ if (node.InputDefs().size() > 2) { + const auto* pad_constant_value_proto = graph_utils::GetConstantInitializer(graph, node.InputDefs()[2]->Name()); + Initializer pad_constant_value{*pad_constant_value_proto, graph.ModelPath()}; + if (std::any_of(pad_constant_value.DataAsByteSpan().begin(), pad_constant_value.DataAsByteSpan().end(), [](const uint8_t byte) { return byte != 0; })) { + return false; + } + } + } else { + if (pad_attributes.find("value") != pad_attributes.end() && + pad_attributes.at("value").f() != 0.0) { + return false; + } + } + + return true; +} + +/* + * - For 1st two dimension Pads array's value should be zero and for rest of them values should >= 0 + */ +Status PadFusion::Apply(Graph& graph, Node& pad_node, RewriteRuleEffect& rule_effect, const logging::Logger&) const { + std::vector pads_values; + + if (pad_node.SinceVersion() >= 11) { + const auto* pads_proto = graph_utils::GetConstantInitializer(graph, pad_node.InputDefs()[1]->Name()); + Initializer pads{*pads_proto, graph.ModelPath()}; + pads_values.assign(pads.DataAsSpan().begin(), pads.DataAsSpan().end()); + } else { + pads_values.assign(pad_node.GetAttributes().at("pads").ints().begin(), pad_node.GetAttributes().at("pads").ints().end()); + } + + assert(static_cast(pads_values.size()) == (2 * static_cast(pad_node.InputDefs()[0]->Shape()->dim_size()))); + + uint32_t pads_size = static_cast(pads_values.size()); + // check if padding is applied only on feature dims + if (pads_values[0] != 0 || pads_values[1] != 0 || pads_values[pads_size / 2] != 0 || + pads_values[pads_size / 2 + 1] != 0) { + return Status::OK(); + } + + // check if padding is only positive + if (std::any_of(pads_values.begin(), pads_values.end(), [](int64_t value) { return value < 0; })) { + return Status::OK(); + } + + Node& child_node = *graph.GetNode(pad_node.OutputNodesBegin()->Index()); + auto child_pads = child_node.GetMutableAttributes()["pads"].mutable_ints(); + uint32_t child_pads_size = static_cast(child_pads->size()); + + for (uint32_t pads_index = 2, child_index = 0; pads_index < pads_size / 2; pads_index++, child_index++) { + child_pads->Set(child_index, child_pads->Get(child_index) + pads_values[pads_index]); + uint32_t mirrored_child_index = child_index + (child_pads_size / 2); + uint32_t mirrored_pad_index = pads_index + (pads_size / 2); + child_pads->Set(mirrored_child_index, child_pads->Get(mirrored_child_index) + pads_values[mirrored_pad_index]); + } + + graph_utils::RemoveNodeOutputEdges(graph, pad_node); + graph_utils::ReplaceNodeInput(child_node, 0, *pad_node.MutableInputDefs()[0]); + graph.RemoveNode(pad_node.Index()); + rule_effect = RewriteRuleEffect::kRemovedCurrentNode; + return Status::OK(); +} +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/optimizer/pad_fusion.h b/onnxruntime/core/optimizer/pad_fusion.h new file mode 100644 index 0000000000000..a1b6978a83d1e --- /dev/null +++ b/onnxruntime/core/optimizer/pad_fusion.h @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/rewrite_rule.h" + +namespace onnxruntime { +/* + * This fusion submerges a Pad operator to it's child + * Conv or MaxPool operator, if and only if PadFusion::SatisfyCondition() + * is true. 
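 * For example (illustrative values), with a 4-D NCHW input a Pad node carrying
 * pads = [0, 0, 1, 2,  0, 0, 3, 4] (no padding on N or C) fused into a Conv whose
 * pads attribute is [p0, p1, p2, p3] produces Conv pads [p0 + 1, p1 + 2, p2 + 3, p3 + 4].
 * Negative pads, or padding on the N/C dimensions, leave the two nodes unfused.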
+ */ +class PadFusion : public RewriteRule { + public: + PadFusion() : RewriteRule("Pad_Fusion") {} + + std::vector TargetOpTypes() const noexcept override { + return {"Pad"}; + } + + private: + bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override; + + Status Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override; +}; +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc index 3f1b2f0458bc0..1a4d3a0c18151 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc @@ -30,6 +30,7 @@ void Selectors::RegisterSelector(const OpVersionsAndSelector::OpVersionsMap& ops static const OpVersionsAndSelector::OpVersionsMap GetMiscOpVersionsMap() { return {{"Gather", {}}, {"Reshape", {}}, + {"Expand", {}}, {"Flatten", {}}, {"Transpose", {}}, {"MaxPool", {12}}, diff --git a/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.cc b/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.cc index e182b6c695d2f..546d52b6f1682 100644 --- a/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.cc +++ b/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.cc @@ -3,9 +3,10 @@ #include "core/optimizer/selectors_actions/selector_action_transformer.h" -#include #include +#include #include +#include #include #include "core/graph/op_identifier_utils.h" @@ -56,9 +57,9 @@ const SelectorActionRegistry::Entry* SelectorActionRegistry::LookUp(const std::s } #if !defined(ORT_MINIMAL_BUILD) -auto SelectorActionRegistry::LookUpByOpType(const std::string& op_type) const +auto SelectorActionRegistry::LookUpByOpTypeAndDomain(const std::string& op_type, const std::string& domain) const -> std::vector> { - const auto [range_begin, range_end] = op_type_to_entry_.equal_range(op_type); + const auto [range_begin, range_end] = op_type_to_entry_.equal_range(OpVersionsMapKey(op_type, domain)); std::vector> result{}; result.reserve(std::distance(range_begin, range_end)); std::transform(range_begin, range_end, std::back_inserter(result), @@ -93,20 +94,15 @@ static Status MatchAndProcess( Status status = Status::OK(); do { - // TODO: for now this just needs to support ONNX and Micrsoft Domain ops. 
- // If we ever had a transformer that was going to target non-ONNX ops, - // we'd need to rework a few things to include the op domain in the matches - if (node.Domain() != kOnnxDomain && node.Domain() != kMSDomain) { - break; - } - std::optional node_selection_opt{}; const SelectorActionRegistry::Entry* selector_action_entry_ptr = nullptr; - const auto selector_action_entries = selector_action_registry.LookUpByOpType(node.OpType()); + const auto selector_action_entries = + selector_action_registry.LookUpByOpTypeAndDomain(node.OpType(), node.Domain()); + std::string key = SelectorActionRegistry::OpVersionsMapKey(node.OpType(), node.Domain()); for (const auto& entry : selector_action_entries) { // check the supported versions if specified - const auto& versions = entry->ops_and_versions.find(node.OpType())->second; + const auto& versions = entry->ops_and_versions.find(key)->second; if (!versions.empty()) { if (std::find(versions.cbegin(), versions.cend(), node.SinceVersion()) == versions.cend()) { continue; diff --git a/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.h b/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.h index 7eb162cc693f1..5caa949ebbe93 100644 --- a/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.h +++ b/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.h @@ -38,8 +38,20 @@ struct NodeSelector { // class to manage a set of selector and associated actions class SelectorActionRegistry { public: + // The key is a string representing the op, optionally specifying the domain using ':' as the + // separator with domain as the first part and operator as the second part, ":" or "". + // For ops in kOnnxDomain, the domain should be left unspecified (""). + // For ops in other domains, the domain should be specified (":"). + // Ex: "Conv", "com.microsoft:Conv", "com.ms.internal.nhwc:Conv" using OpVersionsMap = std::unordered_map>; + // Helper function to create a key to OpVersionsMap using domain and op_type. + static std::string OpVersionsMapKey(std::string_view op_type, std::string_view domain = kOnnxDomain) { + return (domain == kOnnxDomain) + ? std::string{op_type} + : std::string{domain} + ":" + std::string{op_type}; + } + struct Entry { Entry(const std::string& name_in, #if !defined(ORT_MINIMAL_BUILD) @@ -95,14 +107,15 @@ class SelectorActionRegistry { #if !defined(ORT_MINIMAL_BUILD) // return registered Entry or nullptr if not found - auto LookUpByOpType(const std::string& op_type) const -> std::vector>; + auto LookUpByOpTypeAndDomain(const std::string& op_type, + const std::string& domain) const -> std::vector>; #endif // !defined(ORT_MINIMAL_BUILD) private: std::unordered_map name_to_entry_; #if !defined(ORT_MINIMAL_BUILD) - // auxiliary mapping to enable lookup by op type + // auxiliary mapping to enable lookup by op type or "domain:op type" std::unordered_multimap op_type_to_entry_; #endif // !defined(ORT_MINIMAL_BUILD) }; diff --git a/onnxruntime/core/optimizer/transformer_memcpy.cc b/onnxruntime/core/optimizer/transformer_memcpy.cc index 07f391f2ae430..0d7ab70eba613 100644 --- a/onnxruntime/core/optimizer/transformer_memcpy.cc +++ b/onnxruntime/core/optimizer/transformer_memcpy.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. 
#include "transformer_memcpy.h" +#include "core/common/logging/logging.h" #include "core/framework/kernel_registry_manager.h" #include "core/framework/execution_providers.h" #include "core/framework/utils.h" @@ -16,12 +17,12 @@ class TransformerMemcpyImpl { TransformerMemcpyImpl(onnxruntime::Graph& graph, const std::string& provider) : graph_(graph), provider_(provider) {} - bool ModifyGraph(const KernelRegistryManager& schema_registries); + bool ModifyGraph(const KernelRegistryManager& schema_registries, const logging::Logger& logger, int& copy_node_counter); private: void ProcessDefs(onnxruntime::Node& node, const KernelRegistryManager& kernel_registries, InitializedTensorSet& initializers_consumed); void BuildDefsMapping(const onnxruntime::NodeArg* arg, const KernelRegistryManager& kernel_registries); - void AddCopyNode(onnxruntime::NodeArg* arg, bool is_input); + void AddCopyNode(onnxruntime::NodeArg* arg, bool is_input, const logging::Logger& logger); bool ProcessInitializers(const KernelRegistryManager& kernel_registries, const InitializedTensorSet& initializers_consumed); private: @@ -61,11 +62,21 @@ static const onnx::TensorProto* GetInitializer(const Graph& graph, const std::st // very simple GraphTransformer that uses TransformerMemcpyImpl for each graph // and mainly provides the subgraph recursion functionality -common::Status MemcpyTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { +common::Status MemcpyTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level, + const logging::Logger& logger) const { for (auto& provider : provider_types_) { if (!utils::ProviderIsCpuBased(provider)) { TransformerMemcpyImpl copy_impl(graph, provider); - auto current_modified = copy_impl.ModifyGraph(registry_manager_); + + int copy_node_counter = 0; + auto current_modified = copy_impl.ModifyGraph(registry_manager_, logger, copy_node_counter); + if (copy_node_counter > 0 && provider == kCudaExecutionProvider) { + LOGS(logger, WARNING) << copy_node_counter << " Memcpy nodes are added to the graph " << graph.Name() + << " for " << provider + << ". It might have negative impact on performance (including unable to run CUDA graph). " + << "Set session_options.log_severity_level=1 to see the detail logs before this message."; + } + modified = modified || current_modified; break; } @@ -111,7 +122,9 @@ This transformer does not currently optimize copies between, e.g., two different */ -bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_registries) { +bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_registries, + const logging::Logger& logger, + int& copy_node_counter) { bool modified = false; InitializedTensorSet initializers_consumed; // find defs that require copy @@ -137,19 +150,22 @@ bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_regi // For inputs we need to create a copy node only when the input is connected to both provider // and non-provider nodes. Otherwise utils::CopyInputsAcrossDevices() will do the job. 
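// Concretely: a graph input consumed by both a CUDA-assigned node and a CPU-only node
// gets a MemcpyFromHost node inserted so the provider node reads a device-side copy.
// Each insertion bumps copy_node_counter below, which feeds the new per-graph warning
// logged by MemcpyTransformer when any copy nodes were added for the CUDA execution
// provider.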
if (provider_input_defs_.count(arg) && non_provider_input_defs_.count(arg)) { - AddCopyNode(const_cast(arg), true); + AddCopyNode(const_cast(arg), true, logger); + copy_node_counter++; modified = true; } for (auto arg : non_provider_output_defs_) if (provider_input_defs_.count(arg)) { - AddCopyNode(arg, true); + AddCopyNode(arg, true, logger); + copy_node_counter++; modified = true; } for (auto arg : provider_output_defs_) if (non_provider_input_defs_.count(arg)) { - AddCopyNode(arg, false); + AddCopyNode(arg, false, logger); + copy_node_counter++; modified = true; } @@ -176,7 +192,8 @@ bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_regi // (the name will be the same as the parent node's implicit input) const auto* node_arg_in_current_graph_level = *provider_input_defs_.find(arg); - AddCopyNode(const_cast(node_arg_in_current_graph_level), true); + AddCopyNode(const_cast(node_arg_in_current_graph_level), true, logger); + copy_node_counter++; modified = true; } } @@ -232,7 +249,7 @@ void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelReg if (!arg->Exists()) continue; - if (kci && kci->kernel_def->IsOutputOnCpu(i)) + if (utils::IsOutputOnCpu(node, kci, i)) non_provider_output_defs_.insert(arg); else provider_output_defs_.insert(arg); @@ -291,13 +308,13 @@ void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg, co if (!kci || !utils::IsInputOnCpu(it, kci, arg_input_index)) provider_input_nodes_[arg].insert(&it); } if (arg_output_index != -1) { - if (!kci || !kci->kernel_def->IsOutputOnCpu(arg_output_index)) provider_output_nodes_[arg].insert(&it); + if (!kci || !utils::IsOutputOnCpu(it, kci, arg_output_index)) provider_output_nodes_[arg].insert(&it); } } } } -void TransformerMemcpyImpl::AddCopyNode(onnxruntime::NodeArg* arg, bool is_input) { +void TransformerMemcpyImpl::AddCopyNode(onnxruntime::NodeArg* arg, bool is_input, const logging::Logger& logger) { // create unique name for new def std::string new_def_name = graph_.GenerateNodeArgName(arg->Name() + "_" + provider_); @@ -309,6 +326,9 @@ void TransformerMemcpyImpl::AddCopyNode(onnxruntime::NodeArg* arg, bool is_input std::string new_node_name = graph_.GenerateNodeName("Memcpy"); const auto op_name = is_input ? "MemcpyFromHost" : "MemcpyToHost"; + LOGS(logger, INFO) << "Add " << op_name << (is_input ? 
" after " : " before ") << arg->Name() + << " for " << provider_; + auto& new_node = graph_.AddNode(new_node_name, op_name, "Copy from/to host memory", std::vector{src_arg}, std::vector{dst_arg}); @@ -384,8 +404,8 @@ bool TransformerMemcpyImpl::ProcessInitializers(const KernelRegistryManager& ker // normally initializers are only inputs, but things may change with ops like assign ORT_THROW_IF_ERROR(Node::ForEachWithIndex( p_node->OutputDefs(), - [kci, &dup_replacements](const onnxruntime::NodeArg& arg, size_t index) { - if (kci->kernel_def->IsOutputOnCpu(index)) { + [kci, &p_node, &dup_replacements](const onnxruntime::NodeArg& arg, size_t index) { + if (utils::IsOutputOnCpu(*p_node, kci, index)) { ORT_ENFORCE(dup_replacements.find(&arg) == dup_replacements.end()); } return Status::OK(); diff --git a/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.cc b/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.cc index f4f3505128737..8eaac3d34c3af 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.cc @@ -17,7 +17,7 @@ static bool EPAwareHandleResize(HandlerArgs& args) { // layout. Due to that, only push a Transpose through a Resize once it is assigned and we know it's being handled // by an EP that supports multiple layouts. Currently that's the CPU and XNNPACK EPs. const auto ep_type = args.node.GetExecutionProviderType(); - if (ep_type == kCpuExecutionProvider || ep_type == kXnnpackExecutionProvider) { + if (ep_type == kCpuExecutionProvider) { // allow NCHW <-> NHWC for now. not clear any other sort of transpose has a valid usage in a real model int64_t rank_int = gsl::narrow_cast(args.perm.size()); if (rank_int == 4) { diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 3d03abf5b7ebc..4553e7ee18913 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -365,7 +365,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, Slice); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 11, Dropout); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, NonMaxSuppression); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, IsInf); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 19, IsInf); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 15, float, RoiAlign); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 15, double, RoiAlign); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, ReverseSequence); @@ -682,9 +682,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Ga class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 15, ScatterND); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 15, ScatterElements); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 13, Identity); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 
float, IsNaN); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, IsNaN); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, MLFloat16, IsNaN); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, float, IsNaN); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, double, IsNaN); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, MLFloat16, IsNaN); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, bool, NonZero); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, NonZero); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, NonZero); @@ -798,7 +798,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDoma class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, 18, If); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, float, RoiAlign); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, double, RoiAlign); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, float, GridSample); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, 19, float, GridSample); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, 17, ScatterElements); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, 17, ScatterND); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, string, Where); @@ -960,6 +960,20 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, Sh // Opset 20 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, ConstantOfShape); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, GridSample); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, GridSample); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, AffineGrid); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, AffineGrid); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, IsNaN); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, IsNaN); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, MLFloat16, IsNaN); +#if !defined(DISABLE_FLOAT8_TYPES) +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FN, IsNaN); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FNUZ, IsNaN); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E5M2, IsNaN); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E5M2FNUZ, IsNaN); +#endif +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, IsInf); // !!PLEASE READ BELOW!! 
Following that, add new entries above this comment @@ -1492,7 +1506,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { Dropout)>, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2389,6 +2403,20 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { // Opset 20 BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, +#if !defined(DISABLE_FLOAT8_TYPES) + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, +#endif + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc index b40f0edb2a8e1..9c55d37f550f4 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc @@ -245,6 +245,26 @@ struct ProviderHostCPUImpl : ProviderHostCPU { subgraph_session_state); } + Status WhisperBeamSearch__Compute(const contrib::transformers::WhisperBeamSearch* p, OpKernelContext* ctx) override { + return p->contrib::transformers::WhisperBeamSearch::Compute(ctx); + } + + void BeamSearchParameters__ParseFromAttributes(contrib::transformers::BeamSearchParameters* p, const OpKernelInfo& info) override { + p->contrib::transformers::BeamSearchParameters::ParseFromAttributes(info); + } + + void GreedySearchParameters__ParseFromAttributes(contrib::transformers::GreedySearchParameters* p, const OpKernelInfo& info) override { + p->contrib::transformers::GreedySearchParameters::ParseFromAttributes(info); + } + + void SamplingParameters__ParseFromAttributes(contrib::transformers::SamplingParameters* p, const OpKernelInfo& info) override { + p->contrib::transformers::SamplingParameters::ParseFromAttributes(info); + } + + void WhisperBeamSearchParameters__ParseFromAttributes(contrib::transformers::WhisperBeamSearchParameters* p, const OpKernelInfo& info) override { + p->contrib::transformers::WhisperBeamSearchParameters::ParseFromAttributes(info); + } + void GreedySearch__Init(contrib::transformers::GreedySearch* p, const OpKernelInfo& info) override { p->contrib::transformers::GreedySearch::Init(info); } diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.h b/onnxruntime/core/providers/cpu/cpu_provider_shared.h index 7dc80f44b53cd..8dee1cd620282 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.h +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.h @@ -8,8 +8,13 @@ class LongformerAttentionBase; class AttentionBase; namespace transformers { class BeamSearch; +class WhisperBeamSearch; class GreedySearch; class Sampling; +struct BeamSearchParameters; +struct GreedySearchParameters; +struct SamplingParameters; +struct WhisperBeamSearchParameters; } // namespace transformers } // namespace contrib @@ -169,6 +174,15 @@ struct ProviderHostCPU { const SessionState& session_state, const std::string& attribute_name, 
const SessionState& subgraph_session_state) = 0; + virtual Status WhisperBeamSearch__Compute(const contrib::transformers::WhisperBeamSearch* p, OpKernelContext* ctx) = 0; + + virtual void BeamSearchParameters__ParseFromAttributes(contrib::transformers::BeamSearchParameters* p, const OpKernelInfo& info) = 0; + + virtual void GreedySearchParameters__ParseFromAttributes(contrib::transformers::GreedySearchParameters* p, const OpKernelInfo& info) = 0; + + virtual void SamplingParameters__ParseFromAttributes(contrib::transformers::SamplingParameters* p, const OpKernelInfo& info) = 0; + + virtual void WhisperBeamSearchParameters__ParseFromAttributes(contrib::transformers::WhisperBeamSearchParameters* p, const OpKernelInfo& info) = 0; // GreedySearch virtual void GreedySearch__Init(contrib::transformers::GreedySearch* p, const OpKernelInfo& info) = 0; diff --git a/onnxruntime/core/providers/cpu/generator/random.cc b/onnxruntime/core/providers/cpu/generator/random.cc index b63c0d2161ad5..dfa27f1f44d5a 100644 --- a/onnxruntime/core/providers/cpu/generator/random.cc +++ b/onnxruntime/core/providers/cpu/generator/random.cc @@ -428,4 +428,14 @@ template Status MultinomialComputeShared(AllocatorPtr& alloc, std::default_random_engine& generator, Tensor& Y); +#if !defined(DISABLE_CONTRIB_OPS) +// used by onnxruntime/contrib_ops/cpu/transformers/sampling_cpu_helper.h +template Status MultinomialComputeShared(AllocatorPtr& alloc, + const Tensor& X, + const int64_t batch_size, + const int64_t num_classes, + const int64_t num_samples, + std::default_random_engine& generator, + Tensor& Y); +#endif } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc index 3192c8573c5c0..1d524a90302e7 100644 --- a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc +++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc @@ -967,7 +967,7 @@ Status Xor::Compute(OpKernelContext* context) const { }, [](BroadcastHelper& per_iter_bh) { per_iter_bh.OutputEigen() = - per_iter_bh.EigenInput0().array() ^ per_iter_bh.EigenInput1().array(); + per_iter_bh.EigenInput0().array() != per_iter_bh.EigenInput1().array(); }}; UntypedBroadcastTwo(*context, funcs, 1.0); diff --git a/onnxruntime/core/providers/cpu/nn/batch_norm_helper.h b/onnxruntime/core/providers/cpu/nn/batch_norm_helper.h index 8507d87fd2442..a5d46aff83b50 100644 --- a/onnxruntime/core/providers/cpu/nn/batch_norm_helper.h +++ b/onnxruntime/core/providers/cpu/nn/batch_norm_helper.h @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. // Licensed under the MIT License. #pragma once @@ -22,11 +23,17 @@ class BatchNormHelper { const Tensor* B, const Tensor* mean, const Tensor* var, - bool is_spatial = true) { + bool is_spatial = true, + bool is_nhwc = false) { const auto& x_dims = X->Shape().GetDims(); // If x_dims size < 2, num_channels defaults to 1. - int64_t num_channels = x_dims.size() > 1 ? x_dims[1] : 1; + int64_t num_channels; + if (is_nhwc) { + num_channels = x_dims.size() > 1 ? x_dims[x_dims.size() - 1] : 1; + } else { + num_channels = x_dims.size() > 1 ? x_dims[1] : 1; + } // the first 2 are respectively - N and C. int num_feature_dims = x_dims.size() > 1 ? 
static_cast(x_dims.size() - 2) : 0; @@ -109,7 +116,7 @@ class BatchNormHelper { return common::Status::OK(); } - static void NormalizeDims(const TensorShape& x_shape, std::vector& new_dims) { + static void NormalizeDims(const TensorShape& x_shape, std::vector& new_dims, bool is_nhwc = false) { new_dims.clear(); auto orig_dims = x_shape.GetDims(); ORT_ENFORCE(orig_dims.size() < 6, @@ -122,13 +129,19 @@ class BatchNormHelper { auto rank = x_shape.NumDimensions(); auto num_samples = rank > 0 ? orig_dims[0] : 1; // NCHW - auto num_channels = rank > 1 ? orig_dims[1] : 1; - auto height = rank > 2 ? orig_dims[2] : 1; + const size_t channel_dim = is_nhwc ? rank - 1 : 1; + const size_t height_dim = is_nhwc ? 1 : 2; + auto num_channels = rank > 1 ? orig_dims[channel_dim] : 1; + auto height = rank > 2 ? orig_dims[height_dim] : 1; int64_t width = 1; - new_dims = {num_samples, num_channels, height, width}; + if (is_nhwc) { + new_dims = {num_samples, height, width, num_channels}; + } else { + new_dims = {num_samples, num_channels, height, width}; + } } }; } // namespace onnxruntime #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(pop) -#endif \ No newline at end of file +#endif diff --git a/onnxruntime/core/providers/cpu/nn/conv_transpose_attributes.h b/onnxruntime/core/providers/cpu/nn/conv_transpose_attributes.h index a4d67ec63f0c2..4b3b934834ac8 100644 --- a/onnxruntime/core/providers/cpu/nn/conv_transpose_attributes.h +++ b/onnxruntime/core/providers/cpu/nn/conv_transpose_attributes.h @@ -14,6 +14,7 @@ * limitations under the License. */ /* Modifications Copyright (c) Microsoft. */ +// Copyright (c) 2023 NVIDIA Corporation. #pragma once @@ -44,17 +45,19 @@ struct ConvTransposeAttributes : public ConvAttributes { }; Status PrepareForCompute(OpKernelContext* context, bool has_bias, Prepare& p, - bool dynamic_padding = false, const TensorShape* filter_shape = nullptr) const { + bool dynamic_padding = false, const TensorShape* filter_shape = nullptr, + bool is_nhwc = false) const { const Tensor* X = context->Input(0); const Tensor* F = (filter_shape != nullptr) ? nullptr : context->Input(1); const TensorShape& F_Shape = (filter_shape != nullptr) ? *filter_shape : F->Shape(); const Tensor* Pads = dynamic_padding ? context->Input(2) : nullptr; const Tensor* B = has_bias ? (dynamic_padding ? context->Input(3) : context->Input(2)) : nullptr; - TensorShape input_shape = X->Shape().Slice(2); - const int64_t num_input_channels = X->Shape()[1]; + const int rank = static_cast(X->Shape().NumDimensions()); + TensorShape input_shape = X->Shape().Slice(is_nhwc ? 1 : 2, is_nhwc ? rank - 1 : rank); + const int64_t num_input_channels = is_nhwc ? X->Shape()[rank - 1] : X->Shape()[1]; const int64_t N = X->Shape()[0]; - const int64_t num_output_channels_multiplier = F_Shape[1]; + const int64_t num_output_channels_multiplier = is_nhwc ? 
F_Shape[3] : F_Shape[1]; const int64_t num_output_channels = num_output_channels_multiplier * group; // input validations @@ -85,7 +88,7 @@ struct ConvTransposeAttributes : public ConvAttributes { } TensorShapeVector kernel_shape; - ORT_RETURN_IF_ERROR(ComputeKernelShape(F_Shape, kernel_shape)); + ORT_RETURN_IF_ERROR(ComputeKernelShape(F_Shape, kernel_shape, is_nhwc)); TensorShapeVector local_output_padding(output_padding); if (local_output_padding.empty()) { @@ -115,7 +118,7 @@ struct ConvTransposeAttributes : public ConvAttributes { TensorShapeVector Y_dims; ComputePadsAndOutputShape(input_shape, num_output_channels, kernel_shape, - local_strides, local_dilations, local_output_padding, N, &local_pads, &Y_dims); + local_strides, local_dilations, local_output_padding, N, &local_pads, &Y_dims, is_nhwc); TensorShape Yshape(Y_dims); Tensor* Y = context->Output(0, Yshape); @@ -137,9 +140,14 @@ struct ConvTransposeAttributes : public ConvAttributes { void ComputePadsAndOutputShape(TensorShape input_shape, int64_t output_channel, const TensorShapeVector& kernel_shape, const TensorShapeVector& p_strides, const TensorShapeVector& p_dilations, const TensorShapeVector& p_output_padding, const int64_t N, - ConvPadVector* p_pads, TensorShapeVector* output_shape_p) const { + ConvPadVector* p_pads, TensorShapeVector* output_shape_p, + bool is_nhwc = false) const { size_t output_shape_size = output_shape.size(); - output_shape_p->insert(output_shape_p->begin(), {N, output_channel}); + if (is_nhwc) { + output_shape_p->insert(output_shape_p->begin(), {N}); + } else { + output_shape_p->insert(output_shape_p->begin(), {N, output_channel}); + } size_t rank = input_shape.NumDimensions(); for (size_t dim = 0; dim < rank; ++dim) { @@ -163,6 +171,9 @@ struct ConvTransposeAttributes : public ConvAttributes { ORT_ENFORCE(dim_size > 0, "Invalid input shape: ", input_shape.ToString()); output_shape_p->push_back(dim_size); } + if (is_nhwc) { + output_shape_p->push_back(output_channel); + } } TensorShapeVector output_padding; diff --git a/onnxruntime/core/providers/cpu/nn/instance_norm_helper.h b/onnxruntime/core/providers/cpu/nn/instance_norm_helper.h index 48e54ac7eeefb..9a2a710fd291a 100644 --- a/onnxruntime/core/providers/cpu/nn/instance_norm_helper.h +++ b/onnxruntime/core/providers/cpu/nn/instance_norm_helper.h @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. // Licensed under the MIT License. #pragma once @@ -8,13 +9,16 @@ #include "core/framework/tensor.h" #endif #include +#include namespace onnxruntime { class InstanceNormHelper { public: - static common::Status ValidateInputs(const Tensor* input, const Tensor* scale, const Tensor* B) { - if (input->Shape().NumDimensions() < 3) { + static common::Status ValidateInputs(const Tensor* input, const Tensor* scale, const Tensor* B, + bool is_nhwc = false) { + const auto rank = input->Shape().NumDimensions(); + if (rank < 3) { std::ostringstream ostr; ostr << "Invalid input data: number of dimensions is less than 3: " << input->Shape().NumDimensions(); return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, ostr.str()); @@ -24,10 +28,13 @@ class InstanceNormHelper { ostr << "Invalid input scale: number of dimensions is not 1: " << scale->Shape().NumDimensions(); return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, ostr.str()); } - if (scale->Shape().Size() != input->Shape().GetDims()[1]) { + auto in_dims = input->Shape().GetDims(); + auto in_channels = is_nhwc ? 
in_dims[rank - 1] : in_dims[1]; + + if (scale->Shape().Size() != in_channels) { std::ostringstream ostr; - ostr << "Mismatch between input data and scale: size of scale != input channel count " - << scale->Shape().Size() << " vs. " << input->Shape().GetDims()[1]; + ostr << "Mismatch between input data and scale: size of scale != input channel count " << scale->Shape().Size() + << " vs. " << in_channels; return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, ostr.str()); } @@ -37,10 +44,10 @@ class InstanceNormHelper { return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, ostr.str()); } - if (B->Shape().Size() != input->Shape().GetDims()[1]) { + if (B->Shape().Size() != in_channels) { std::ostringstream ostr; - ostr << "Mismatch between input data and B: size of B != input channel count " - << B->Shape().Size() << " vs. " << input->Shape().GetDims()[1]; + ostr << "Mismatch between input data and B: size of B != input channel count " << B->Shape().Size() << " vs. " + << in_channels; return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, ostr.str()); } diff --git a/onnxruntime/core/providers/cpu/nn/pool_attributes.h b/onnxruntime/core/providers/cpu/nn/pool_attributes.h index 54f41f09f4b24..118cb4a3ba4bd 100644 --- a/onnxruntime/core/providers/cpu/nn/pool_attributes.h +++ b/onnxruntime/core/providers/cpu/nn/pool_attributes.h @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. // Licensed under the MIT License. #pragma once @@ -98,28 +99,34 @@ struct PoolAttributes { TensorShapeVector SetOutputSize(const TensorShape& input_shape, int64_t output_channel, - TensorShapeVector* actual_pads) const { + TensorShapeVector* actual_pads, + bool is_nhwc = false) const { ORT_ENFORCE(input_shape.Size() > 0 || input_shape[0] == 0, "Invalid input shape. Only N can be zero. Got:", input_shape); TensorShapeVector output_dims; int64_t N = input_shape[0]; - InferOutputSize(input_shape.GetDims(), &output_dims, actual_pads); - - output_dims.insert(output_dims.begin(), {N, output_channel}); - + InferOutputSize(input_shape.GetDims(), &output_dims, actual_pads, is_nhwc); + if (is_nhwc) { + output_dims.insert(output_dims.begin(), N); + output_dims.push_back(output_channel); + } else { + output_dims.insert(output_dims.begin(), {N, output_channel}); + } return output_dims; } void InferOutputSize(gsl::span input_dims, TensorShapeVector* output_dims, - TensorShapeVector* actual_pads) const { + TensorShapeVector* actual_pads, + bool is_nhwc = false) const { ORT_ENFORCE(input_dims.size() >= 2); if (global_pooling) { output_dims->assign(input_dims.size() - 2, 1); } else { for (size_t dim = 0; dim < input_dims.size() - 2; ++dim) { int64_t dim_size = 0; - ComputeSizePadDilations(static_cast(input_dims[dim + 2]), + auto spatial_dim = is_nhwc ? input_dims[dim + 1] : input_dims[dim + 2]; + ComputeSizePadDilations(static_cast(spatial_dim), strides[dim], kernel_shape[dim], &actual_pads->at(dim), diff --git a/onnxruntime/core/providers/cpu/tensor/affine_grid.cc b/onnxruntime/core/providers/cpu/tensor/affine_grid.cc new file mode 100644 index 0000000000000..15900ba553983 --- /dev/null +++ b/onnxruntime/core/providers/cpu/tensor/affine_grid.cc @@ -0,0 +1,151 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
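The is_nhwc flags threaded through batch_norm_helper, instance_norm_helper, conv_transpose_attributes and pool_attributes above all rely on the same convention: in NCHW the channel count lives at dims[1], in NHWC it is the last dimension, and the spatial dimensions shift by one accordingly. A minimal standalone sketch of that convention (ChannelCount is a hypothetical helper used only for illustration, not part of this patch); the new AffineGrid CPU kernel continues below.

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical helper: pick the channel axis for NCHW vs. NHWC layouts.
int64_t ChannelCount(const std::vector<int64_t>& dims, bool is_nhwc) {
  if (dims.size() < 2) return 1;            // degenerate shapes default to one channel
  return is_nhwc ? dims.back() : dims[1];   // NHWC: last axis; NCHW: second axis
}

int main() {
  const std::vector<int64_t> nchw{2, 3, 8, 8};  // N, C, H, W
  const std::vector<int64_t> nhwc{2, 8, 8, 3};  // N, H, W, C
  std::cout << ChannelCount(nchw, false) << " " << ChannelCount(nhwc, true) << "\n";  // prints: 3 3
  return 0;
}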
+ +#include "core/providers/cpu/tensor/affine_grid.h" + +#include "core/common/common.h" +#include "core/providers/op_kernel_type_control.h" +#include "core/util/math_cpuonly.h" +#include +#include "Eigen/src/Core/Map.h" +#include +#include "core/common/eigen_common_wrapper.h" + +namespace onnxruntime { + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_CPU_OPERATOR_TYPED_KERNEL( \ + AffineGrid, \ + 20, \ + T, \ + KernelDefBuilder() \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ + AffineGrid); + +REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(double) + +template +void generate_base_grid_2d(int64_t H, int64_t W, bool align_corners, Eigen::Matrix& base_grid) { + Eigen::VectorXf row_vec = Eigen::VectorXf::LinSpaced(static_cast(W), -1, 1); + if (!align_corners) { + row_vec = row_vec * (W - 1) / W; + } + Eigen::VectorXf col_vec = Eigen::VectorXf::LinSpaced(static_cast(H), -1, 1); + if (!align_corners) { + col_vec = col_vec * (H - 1) / H; + } + + base_grid.resize(static_cast(H * W), 2); + for (Eigen::Index j = 0; j < H; j++) { + for (Eigen::Index i = 0; i < W; i++) { + base_grid.row(j * static_cast(W) + i) << row_vec(i), col_vec(j); + } + } +} + +template +void generate_base_grid_3d(int64_t D, int64_t H, int64_t W, bool align_corners, Eigen::Matrix& base_grid) { + Eigen::VectorXf row_vec = Eigen::VectorXf::LinSpaced(static_cast(W), -1, 1); + if (!align_corners) { + row_vec = row_vec * (W - 1) / W; + } + Eigen::VectorXf col_vec = Eigen::VectorXf::LinSpaced(static_cast(H), -1, 1); + if (!align_corners) { + col_vec = col_vec * (H - 1) / H; + } + Eigen::VectorXf slice_vec = Eigen::VectorXf::LinSpaced(static_cast(D), -1, 1); + if (!align_corners) { + slice_vec = slice_vec * (D - 1) / D; + } + + base_grid.resize(static_cast(D * H * W), 3); + for (Eigen::Index k = 0; k < D; k++) { + for (Eigen::Index j = 0; j < H; j++) { + for (Eigen::Index i = 0; i < W; i++) { + base_grid.row(k * static_cast(H * W) + j * static_cast(W) + i) << row_vec(i), col_vec(j), slice_vec(k); + } + } + } +} + +template +void affine_grid_generator_2d(const Tensor* theta, const Eigen::Matrix& base_grid_transposed, int64_t batch_num, int64_t H, int64_t W, Tensor* grid) { + const Eigen::StorageOptions option = Eigen::RowMajor; + auto theta_batch_offset = batch_num * 2 * 3; + const T* theta_data = theta->Data() + theta_batch_offset; + const Eigen::Matrix theta_R{{theta_data[0], theta_data[1]}, {theta_data[3], theta_data[4]}}; + const Eigen::Array theta_T(theta_data[2], theta_data[5]); + + auto grid_batch_offset = batch_num * H * W * 2; + T* grid_data = grid->MutableData() + grid_batch_offset; + Eigen::Map> grid_matrix(grid_data, narrow(H * W), 2); + grid_matrix = ((theta_R * base_grid_transposed).array().colwise() + theta_T).matrix().transpose(); +} + +template +void affine_grid_generator_3d(const Tensor* theta, const Eigen::Matrix& base_grid_transposed, int64_t batch_num, int64_t D, int64_t H, int64_t W, Tensor* grid) { + const Eigen::StorageOptions option = Eigen::RowMajor; + auto theta_batch_offset = batch_num * 3 * 4; + const T* theta_data = theta->Data() + theta_batch_offset; + const Eigen::Matrix theta_R{ + {theta_data[0], theta_data[1], theta_data[2]}, + {theta_data[4], theta_data[5], theta_data[6]}, + {theta_data[8], theta_data[9], theta_data[10]}}; + const Eigen::Array theta_T(theta_data[3], theta_data[7], theta_data[11]); + + auto grid_batch_offset = batch_num * D * H * W * 3; + T* grid_data = grid->MutableData() + grid_batch_offset; + Eigen::Map> 
grid_matrix(grid_data, narrow(D * H * W), 3); + grid_matrix = ((theta_R * base_grid_transposed).array().colwise() + theta_T).matrix().transpose(); +} + +template +Status AffineGrid::Compute(OpKernelContext* context) const { + const Tensor* theta = context->Input(0); + const TensorShape& theta_shape = theta->Shape(); + if (theta_shape.NumDimensions() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "AffineGrid : Input theta tensor dimension is not 3"); + } + + const Tensor* size = context->Input(1); + const TensorShape& size_shape = size->Shape(); + const int64_t* size_data = size->Data(); + + if (size_shape.GetDims()[0] == 4 /*&& get_check_2d_grid_sample_consistency(theta_shape, size_shape, N, C, H, W)*/) { + int64_t N = size_data[0], H = size_data[2], W = size_data[3]; + + TensorShape grid_shape{N, H, W, 2}; + auto grid = context->Output(0, grid_shape); + + Eigen::Matrix base_grid; + generate_base_grid_2d(H, W, align_corners_, base_grid); + Eigen::Matrix base_grid_transposed = base_grid.transpose(); + + std::function fn = [theta, base_grid_transposed, H, W, grid](ptrdiff_t batch_num) { + affine_grid_generator_2d(theta, base_grid_transposed, batch_num, H, W, grid); + }; + + concurrency::ThreadPool::TryBatchParallelFor(context->GetOperatorThreadPool(), narrow(N), std::move(fn), 0); + } else if (size_shape.GetDims()[0] == 5 /*&& get_check_2d_grid_sample_consistency(theta_shape, size_shape, N, C, H, W)*/) { + int64_t N = size_data[0], D = size_data[2], H = size_data[3], W = size_data[4]; + + TensorShape grid_shape{N, D, H, W, 3}; + auto grid = context->Output(0, grid_shape); + + Eigen::Matrix base_grid; + generate_base_grid_3d(D, H, W, align_corners_, base_grid); + Eigen::Matrix base_grid_transposed = base_grid.transpose(); + + std::function fn = [theta, base_grid_transposed, D, H, W, grid](ptrdiff_t batch_num) { + affine_grid_generator_3d(theta, base_grid_transposed, batch_num, D, H, W, grid); + }; + + concurrency::ThreadPool::TryBatchParallelFor(context->GetOperatorThreadPool(), narrow(N), std::move(fn), 0); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "AffineGrid : Invalid size - length of size should be 4 or 5."); + } + return Status::OK(); +} +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/tensor/affine_grid.h b/onnxruntime/core/providers/cpu/tensor/affine_grid.h new file mode 100644 index 0000000000000..5ffe660e986f2 --- /dev/null +++ b/onnxruntime/core/providers/cpu/tensor/affine_grid.h @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License.
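The AffineGrid::Compute implementation above builds a base grid of normalized coordinates in [-1, 1] and applies each batch's theta to it. The following is a rough standalone sketch of the 2-D math only, without Eigen, tensors, or thread-pool parallelism; the tiny 2x2 size and identity theta are illustrative assumptions, not values from this patch.

#include <array>
#include <cstdint>
#include <iostream>
#include <vector>

// Evenly spaced normalized coordinates; with align_corners == false they are
// scaled by (n - 1) / n so samples land on pixel centers.
std::vector<float> Normalized(int64_t n, bool align_corners) {
  std::vector<float> v(static_cast<size_t>(n));
  for (int64_t i = 0; i < n; ++i) {
    float x = (n == 1) ? 0.f : -1.f + 2.f * static_cast<float>(i) / static_cast<float>(n - 1);
    if (!align_corners) x *= static_cast<float>(n - 1) / static_cast<float>(n);
    v[static_cast<size_t>(i)] = x;
  }
  return v;
}

int main() {
  const int64_t H = 2, W = 2;
  // theta = [[a, b, tx], [c, d, ty]]; identity transform for the demo.
  const std::array<float, 6> theta{1.f, 0.f, 0.f, 0.f, 1.f, 0.f};
  const bool align_corners = false;
  const auto xs = Normalized(W, align_corners);
  const auto ys = Normalized(H, align_corners);
  for (int64_t j = 0; j < H; ++j) {      // output grid has shape (H, W, 2)
    for (int64_t i = 0; i < W; ++i) {
      const float x = xs[static_cast<size_t>(i)], y = ys[static_cast<size_t>(j)];
      const float gx = theta[0] * x + theta[1] * y + theta[2];  // theta_R * (x, y) + theta_T
      const float gy = theta[3] * x + theta[4] * y + theta[5];
      std::cout << "(" << gx << ", " << gy << ") ";
    }
    std::cout << "\n";
  }
  return 0;
}

With align_corners left at 0 and an identity theta, the 2x2 grid comes out as the four pixel centers (+-0.5, +-0.5), matching the (W - 1) / W scaling the kernel applies to its LinSpaced base vectors.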
+ +#pragma once + +#include "core/common/common.h" +#include "core/framework/op_kernel.h" + +namespace onnxruntime { + +template +class AffineGrid final : public OpKernel { + public: + AffineGrid(const OpKernelInfo& info) : OpKernel(info) { + int64_t align_corners = info.GetAttrOrDefault("align_corners", 0); + align_corners_ = (align_corners != 0); + } + + Status Compute(OpKernelContext* context) const override; + + private: + bool align_corners_; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/tensor/grid_sample.cc b/onnxruntime/core/providers/cpu/tensor/grid_sample.cc index c58a7d8337114..a83ba378d7f1e 100644 --- a/onnxruntime/core/providers/cpu/tensor/grid_sample.cc +++ b/onnxruntime/core/providers/cpu/tensor/grid_sample.cc @@ -11,17 +11,23 @@ namespace onnxruntime { -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_CPU_OPERATOR_TYPED_KERNEL( \ - GridSample, \ - 16, \ - T, \ - KernelDefBuilder() \ - .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ - GridSample); +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(GridSample, kOnnxDomain, 16, 19, T, kCpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ + GridSample); + +#define REGISTER_KERNEL_TYPED_20(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX(GridSample, kOnnxDomain, 20, T, kCpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ + GridSample); REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED_20(float) +REGISTER_KERNEL_TYPED_20(double) // Restore normalized location to actual image location // When align_corners is true: @@ -44,16 +50,15 @@ T GsDenormalize(T n, int64_t length, bool align_corners) { } // Reflect by the near border till within the borders -// Use float for borders to avoid potential issues with integer T template -T GsReflect(T x, float x_min, float x_max) { - float dx = {}; - float fx = static_cast(x); - float range = x_max - x_min; +T GsReflect(T x, T x_min, T x_max) { + T dx = {}; + T fx = static_cast(x); + T range = x_max - x_min; if (fx < x_min) { dx = x_min - fx; int n = static_cast(dx / range); - float r = dx - n * range; + T r = dx - n * range; if (n % 2 == 0) { fx = x_min + r; } else { @@ -62,7 +67,7 @@ T GsReflect(T x, float x_min, float x_max) { } else if (fx > x_max) { dx = fx - x_max; int n = static_cast(dx / range); - float r = dx - n * range; + T r = dx - n * range; if (n % 2 == 0) { fx = x_max - r; } else { @@ -75,9 +80,9 @@ T GsReflect(T x, float x_min, float x_max) { // Calculate cubic convolution interpolation coefficients // ROBERT G. 
KEYS https://ieeexplore.ieee.org/document/1163711 -// Use float to avoid potential issues with integer T -void GsGetCubicCoeffs(float x, float coeffs[4]) { - constexpr float cubic_alpha = -0.75f; +template +void GsGetCubicCoeffs(T x, T coeffs[4]) { + constexpr T cubic_alpha = -0.75f; x = std::abs(x); coeffs[0] = ((cubic_alpha * (x + 1) - 5 * cubic_alpha) * (x + 1) + 8 * cubic_alpha) * (x + 1) - 4 * cubic_alpha; coeffs[1] = ((cubic_alpha + 2) * x - (cubic_alpha + 3)) * x * x + 1; @@ -86,9 +91,9 @@ void GsGetCubicCoeffs(float x, float coeffs[4]) { } template -T GsBicubicInterpolate(T p[4][4], float x, float y) { - float v[4] = {}; - float coeffs[4] = {}; +T GsBicubicInterpolate(T p[4][4], T x, T y) { + T v[4] = {}; + T coeffs[4] = {}; GsGetCubicCoeffs(x, coeffs); for (int64_t i = 0; i < 4; i++) { v[i] = coeffs[0] * p[i][0] + coeffs[1] * p[i][1] + coeffs[2] * p[i][2] + coeffs[3] * p[i][3]; @@ -98,7 +103,7 @@ T GsBicubicInterpolate(T p[4][4], float x, float y) { } template -T GridSample::PixelAtGrid(const T* image, int64_t r, int64_t c, int64_t H, int64_t W, float border[/* 4 */]) const { +T GridSample::PixelAtGrid(const T* image, int64_t r, int64_t c, int64_t H, int64_t W, T border[/* 4 */]) const { T pixel = {}; // default 0 if (padding_mode_ == Zeros) { if (c >= 0 && c < W && r >= 0 && r < H) { @@ -116,6 +121,27 @@ T GridSample::PixelAtGrid(const T* image, int64_t r, int64_t c, int64_t H, in return pixel; } +template +T GridSample::PixelAtGrid3D(const T* image, int64_t d, int64_t h, int64_t w, int64_t D, int64_t H, int64_t W, T border[/* 6 */]) const { + T pixel = {}; // default 0 + if (padding_mode_ == Zeros) { + if (w >= 0 && w < W && h >= 0 && h < H && d >= 0 && d < D) { + pixel = image[d * H * W + h * W + w]; + } + } else if (padding_mode_ == Border) { + w = std::clamp(w, 0, W - 1); + h = std::clamp(h, 0, H - 1); + d = std::clamp(d, 0, D - 1); + pixel = image[d * H * W + h * W + w]; + } else { // (padding_mode_ == Reflection) + w = static_cast(GsReflect(static_cast(w), border[0], border[3])); + h = static_cast(GsReflect(static_cast(h), border[1], border[4])); + d = static_cast(GsReflect(static_cast(d), border[2], border[5])); + pixel = image[d * H * W + h * W + w]; + } + return pixel; +} + // When grid sampling, padding is applied before interpolation. 
// For instance, in bilinear mode and zeros padding-mode, pixel p at actual // image location (-0.5, -0.5) @@ -134,113 +160,203 @@ Status GridSample::Compute(OpKernelContext* context) const { const auto& input_dims = input->Shape(); const auto& grid_dims = grid->Shape(); - if (input_dims.NumDimensions() != 4 || grid_dims.NumDimensions() != 4) { - return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Only 4-D tensor is supported"); - } + int64_t data_dims = input_dims.NumDimensions() - 2; + ORT_ENFORCE(static_cast(grid_dims.NumDimensions()) == data_dims + 2, + "grid dimensions must be ", data_dims + 2, "for input dimension of ", data_dims); + + ORT_ENFORCE(grid_dims[grid_dims.NumDimensions() - 1] == data_dims, + "Last dimension of grid: ", grid_dims[grid_dims.NumDimensions() - 1], ", expect ", data_dims); + + ORT_ENFORCE(input_dims.NumDimensions() == 4 || input_dims.NumDimensions() == 5, "Only 4-D or 5-D tensor is supported"); auto N = input_dims[0]; auto C = input_dims[1]; - auto H_in = input_dims[2]; - auto W_in = input_dims[3]; - auto H_out = grid_dims[1]; - auto W_out = grid_dims[2]; ORT_ENFORCE(grid_dims[0] == N, "Grid batch size ", grid_dims[0], " does not match input batch size ", N); - ORT_ENFORCE(grid_dims[3] == 2, "Last dimension of grid: ", grid_dims[3], ", expect 2"); - TensorShape Y_shape = {N, C, H_out, W_out}; - auto& Y = *context->Output(0, Y_shape); - // Return early if the output tensor is going to be of size 0 - if (Y.Shape().Size() == 0) { - return Status::OK(); + if (input_dims.NumDimensions() == 5) { + ORT_ENFORCE(mode_ != Cubic, "Only support GridSample Cubic mode in 4-D cases."); } - // Force float here to avoid possible issue in integer T case - float x_min = -0.5f; - float x_max = W_in - 0.5f; - float y_min = -0.5f; - float y_max = H_in - 0.5f; - - if (align_corners_) { - x_min = 0.f; - x_max = W_in - 1.f; - y_min = 0.f; - y_max = H_in - 1.f; - } - float border[] = {x_min, y_min, x_max, y_max}; // l-t-r-b - - concurrency::ThreadPool* tp = H_out * W_out > 64 ? 
context->GetOperatorThreadPool() : nullptr; - for (int64_t n = 0; n < N; n++) { - const T* grid_data = grid->Data() + n * (H_out * W_out) * 2; - concurrency::ThreadPool::TrySimpleParallelFor( - tp, onnxruntime::narrow(C), - [&](std::ptrdiff_t c) { - const T* X_data = input->Data() + (n * C + c) * (H_in * W_in); - T* Y_data = Y.MutableData() + (n * C + c) * (H_out * W_out); - - for (int64_t oy = 0; oy < H_out; oy++) { - for (int64_t ox = 0; ox < W_out; ox++) { - const T* gridpoint = grid_data + (oy * W_out + ox) * 2; - T* Y_gridpoint = Y_data + oy * W_out + ox; - auto nx = gridpoint[0]; // normalized location - auto ny = gridpoint[1]; - auto x = GsDenormalize(nx, W_in, align_corners_); // actual location - auto y = GsDenormalize(ny, H_in, align_corners_); - - if (mode_ == Nearest) { - x = static_cast(std::nearbyintf(static_cast(x))); - y = static_cast(std::nearbyintf(static_cast(y))); - } + if (data_dims == 2) { + // sample 2d; + auto H_in = input_dims[2]; + auto W_in = input_dims[3]; + auto H_out = grid_dims[1]; + auto W_out = grid_dims[2]; + TensorShape Y_shape = {N, C, H_out, W_out}; + auto& Y = *context->Output(0, Y_shape); + // Return early if the output tensor is going to be of size 0 + if (Y.Shape().Size() == 0) { + return Status::OK(); + } - if (x < x_min || x > x_max || y < y_min || y > y_max) { // out of bound - if (padding_mode_ == Border) { - // use original border in both align_corner cases - x = std::clamp(x, static_cast(0), static_cast(W_in - 1)); - y = std::clamp(y, static_cast(0), static_cast(H_in - 1)); - } else if (padding_mode_ == Reflection) { - x = GsReflect(x, x_min, x_max); - y = GsReflect(y, y_min, y_max); - } - } // out of bound + T x_min = -0.5f; + T x_max = W_in - 0.5f; + T y_min = -0.5f; + T y_max = H_in - 0.5f; - if (mode_ == Nearest) { - // x, y are integers in all padding modes - *Y_gridpoint = PixelAtGrid(X_data, static_cast(y), static_cast(x), H_in, W_in, border); - continue; - } + if (align_corners_) { + x_min = 0.f; + x_max = W_in - 1.f; + y_min = 0.f; + y_max = H_in - 1.f; + } + T border[] = {x_min, y_min, x_max, y_max}; // l-t-r-b - if (mode_ == Bilinear) { - int64_t x1 = static_cast(std::floor(x)); - int64_t y1 = static_cast(std::floor(y)); - int64_t x2 = x1 + 1; - int64_t y2 = y1 + 1; - - T p11 = PixelAtGrid(X_data, y1, x1, H_in, W_in, border); - T p12 = PixelAtGrid(X_data, y1, x2, H_in, W_in, border); - T p21 = PixelAtGrid(X_data, y2, x1, H_in, W_in, border); - T p22 = PixelAtGrid(X_data, y2, x2, H_in, W_in, border); - - T dx2 = static_cast(x2) - x; - T dx1 = x - static_cast(x1); - T dy2 = static_cast(y2) - y; - T dy1 = y - static_cast(y1); - *Y_gridpoint = dy2 * (dx2 * p11 + dx1 * p12) + dy1 * (dx2 * p21 + dx1 * p22); + concurrency::ThreadPool* tp = H_out * W_out > 64 ? 
context->GetOperatorThreadPool() : nullptr; + for (int64_t n = 0; n < N; n++) { + const T* grid_data = grid->Data() + n * (H_out * W_out) * 2; + concurrency::ThreadPool::TrySimpleParallelFor( + tp, onnxruntime::narrow(C), + [&](std::ptrdiff_t c) { + const T* X_data = input->Data() + (n * C + c) * (H_in * W_in); + T* Y_data = Y.MutableData() + (n * C + c) * (H_out * W_out); + + for (int64_t oy = 0; oy < H_out; oy++) { + for (int64_t ox = 0; ox < W_out; ox++) { + const T* gridpoint = grid_data + (oy * W_out + ox) * 2; + T* Y_gridpoint = Y_data + oy * W_out + ox; + auto nx = gridpoint[0]; // normalized location + auto ny = gridpoint[1]; + auto x = GsDenormalize(nx, W_in, align_corners_); // actual location + auto y = GsDenormalize(ny, H_in, align_corners_); + + if (mode_ == Nearest) { + x = static_cast(std::nearbyint(static_cast(x))); + y = static_cast(std::nearbyint(static_cast(y))); + // x, y are integers in all padding modes + *Y_gridpoint = PixelAtGrid(X_data, static_cast(y), static_cast(x), H_in, W_in, border); + } else if (mode_ == Linear) { + int64_t x1 = static_cast(std::floor(x)); + int64_t y1 = static_cast(std::floor(y)); + int64_t x2 = x1 + 1; + int64_t y2 = y1 + 1; + + T p11 = PixelAtGrid(X_data, y1, x1, H_in, W_in, border); + T p12 = PixelAtGrid(X_data, y1, x2, H_in, W_in, border); + T p21 = PixelAtGrid(X_data, y2, x1, H_in, W_in, border); + T p22 = PixelAtGrid(X_data, y2, x2, H_in, W_in, border); + + T dx2 = static_cast(x2) - x; + T dx1 = x - static_cast(x1); + T dy2 = static_cast(y2) - y; + T dy1 = y - static_cast(y1); + *Y_gridpoint = dy2 * (dx2 * p11 + dx1 * p12) + dy1 * (dx2 * p21 + dx1 * p22); + } else if (mode_ == Cubic) { + int64_t x0 = static_cast(std::floor(x)) - 1; // top-left corner of the bbox + int64_t y0 = static_cast(std::floor(y)) - 1; + + T p[4][4] = {}; // [H][W] + for (int64_t h = 0; h < 4; h++) { + for (int64_t w = 0; w < 4; w++) { + p[h][w] = PixelAtGrid(X_data, h + y0, w + x0, H_in, W_in, border); + } + } + T dx = static_cast(x - x0 - 1); + T dy = static_cast(y - y0 - 1); + *Y_gridpoint = GsBicubicInterpolate(p, dx, dy); + } } - if (mode_ == Bicubic) { - int64_t x0 = static_cast(std::floor(x)) - 1; // top-left corner of the bbox - int64_t y0 = static_cast(std::floor(y)) - 1; - T p[4][4] = {}; // [H][W] - for (int64_t h = 0; h < 4; h++) { - for (int64_t w = 0; w < 4; w++) { - p[h][w] = PixelAtGrid(X_data, h + y0, w + x0, H_in, W_in, border); + } + }); + } + } else if (data_dims == 3) { + // sample 3d; + auto D_in = input_dims[2]; + auto H_in = input_dims[3]; + auto W_in = input_dims[4]; + auto D_out = grid_dims[1]; + auto H_out = grid_dims[2]; + auto W_out = grid_dims[3]; + TensorShape Y_shape = {N, C, D_out, H_out, W_out}; + auto& Y = *context->Output(0, Y_shape); + // Return early if the output tensor is going to be of size 0 + if (Y.Shape().Size() == 0) { + return Status::OK(); + } + + T x_min = -0.5f; + T x_max = W_in - 0.5f; + T y_min = -0.5f; + T y_max = H_in - 0.5f; + T z_min = -0.5f; + T z_max = D_in - 0.5f; + + if (align_corners_) { + x_min = 0.f; + x_max = W_in - 1.f; + y_min = 0.f; + y_max = H_in - 1.f; + z_min = 0.f; + z_max = D_in - 1.f; + } + T border[] = {x_min, y_min, z_min, x_max, y_max, z_max}; + + concurrency::ThreadPool* tp = D_out * H_out * W_out > 64 ? 
context->GetOperatorThreadPool() : nullptr; + for (int64_t n = 0; n < N; n++) { + const T* grid_data = grid->Data() + n * (D_out * H_out * W_out) * 3; + concurrency::ThreadPool::TrySimpleParallelFor( + tp, onnxruntime::narrow(C), + [&](std::ptrdiff_t c) { + const T* X_data = input->Data() + (n * C + c) * (D_in * H_in * W_in); + T* Y_data = Y.MutableData() + (n * C + c) * (D_out * H_out * W_out); + + for (int64_t oz = 0; oz < D_out; oz++) { + for (int64_t oy = 0; oy < H_out; oy++) { + for (int64_t ox = 0; ox < W_out; ox++) { + const T* gridpoint = grid_data + (oz * H_out * W_out + oy * W_out + ox) * 3; + T* Y_gridpoint = Y_data + oz * H_out * W_out + oy * W_out + ox; + auto nx = gridpoint[0]; // normalized location + auto ny = gridpoint[1]; + auto nz = gridpoint[2]; + auto x = GsDenormalize(nx, W_in, align_corners_); // actual location + auto y = GsDenormalize(ny, H_in, align_corners_); + auto z = GsDenormalize(nz, D_in, align_corners_); + + if (mode_ == Nearest) { + x = static_cast(std::nearbyint(static_cast(x))); + y = static_cast(std::nearbyint(static_cast(y))); + z = static_cast(std::nearbyint(static_cast(z))); + + // x, y are integers in all padding modes + *Y_gridpoint = PixelAtGrid3D(X_data, static_cast(z), static_cast(y), static_cast(x), + D_in, H_in, W_in, border); + } else if (mode_ == Linear) { + int64_t x1 = static_cast(std::floor(x)); + int64_t y1 = static_cast(std::floor(y)); + int64_t z1 = static_cast(std::floor(z)); + int64_t x2 = x1 + 1; + int64_t y2 = y1 + 1; + int64_t z2 = z1 + 1; + + T dx2 = static_cast(x2) - x; + T dx1 = x - static_cast(x1); + T dy2 = static_cast(y2) - y; + T dy1 = y - static_cast(y1); + T dz2 = static_cast(z2) - z; + T dz1 = z - static_cast(z1); + + T p111 = PixelAtGrid3D(X_data, z1, y1, x1, D_in, H_in, W_in, border); + T p112 = PixelAtGrid3D(X_data, z1, y1, x2, D_in, H_in, W_in, border); + T p121 = PixelAtGrid3D(X_data, z1, y2, x1, D_in, H_in, W_in, border); + T p122 = PixelAtGrid3D(X_data, z1, y2, x2, D_in, H_in, W_in, border); + T Y_gridpoint_z1 = dy2 * (dx2 * p111 + dx1 * p112) + dy1 * (dx2 * p121 + dx1 * p122); + + T p211 = PixelAtGrid3D(X_data, z2, y1, x1, D_in, H_in, W_in, border); + T p212 = PixelAtGrid3D(X_data, z2, y1, x2, D_in, H_in, W_in, border); + T p221 = PixelAtGrid3D(X_data, z2, y2, x1, D_in, H_in, W_in, border); + T p222 = PixelAtGrid3D(X_data, z2, y2, x2, D_in, H_in, W_in, border); + T Y_gridpoint_z2 = dy2 * (dx2 * p211 + dx1 * p212) + dy1 * (dx2 * p221 + dx1 * p222); + *Y_gridpoint = dz2 * Y_gridpoint_z1 + dz1 * Y_gridpoint_z2; } } - T dx = static_cast(x - x0 - 1); - T dy = static_cast(y - y0 - 1); - *Y_gridpoint = GsBicubicInterpolate(p, static_cast(dx), static_cast(dy)); } } - } - }); + }); + } + } else { + // shall not reach here due to above checks + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Only support GirdSample in 4-D or 5-D cases."); } return Status::OK(); } diff --git a/onnxruntime/core/providers/cpu/tensor/grid_sample.h b/onnxruntime/core/providers/cpu/tensor/grid_sample.h index 2dd828b3ae3f1..dee0c4701ee21 100644 --- a/onnxruntime/core/providers/cpu/tensor/grid_sample.h +++ b/onnxruntime/core/providers/cpu/tensor/grid_sample.h @@ -15,37 +15,52 @@ template class GridSample final : public OpKernel { public: explicit GridSample(const OpKernelInfo& info) : OpKernel(info) { - std::string mode_str = info.GetAttrOrDefault("mode", "bilinear"); - std::string padding_mode_str = info.GetAttrOrDefault("padding_mode", "zeros"); - align_corners_ = static_cast(info.GetAttrOrDefault("align_corners", 0)); - 
ORT_ENFORCE(mode_str == "bilinear" || mode_str == "nearest" || mode_str == "bicubic", - "mode \"", mode_str, "\" not supported, expect bilinear, nearest or bicubic"); - ORT_ENFORCE(padding_mode_str == "zeros" || padding_mode_str == "border" || padding_mode_str == "reflection", - "padding_mode \"", padding_mode_str, "\" not supported, expect zeros, border or reflection"); - if (mode_str == "bicubic") { - mode_ = Bicubic; - } else if (mode_str == "nearest") { - mode_ = Nearest; + int start_version = info.node().SinceVersion(); + if (start_version >= 20) { + std::string mode_str = info.GetAttrOrDefault("mode", "linear"); + if (mode_str == "cubic") { + mode_ = Cubic; + } else if (mode_str == "nearest") { + mode_ = Nearest; + } else if (mode_str == "linear") { + mode_ = Linear; + } else { + ORT_THROW("mode \"", mode_str, "\" not supported, expect linear, nearest or cubic"); + } } else { - mode_ = Bilinear; + std::string mode_str = info.GetAttrOrDefault("mode", "bilinear"); + if (mode_str == "bicubic") { + mode_ = Cubic; + } else if (mode_str == "nearest") { + mode_ = Nearest; + } else if (mode_str == "bilinear") { + mode_ = Linear; + } else { + ORT_THROW("mode \"", mode_str, "\" not supported, expect bilinear, nearest or bicubic"); + } } + + std::string padding_mode_str = info.GetAttrOrDefault("padding_mode", "zeros"); + align_corners_ = static_cast(info.GetAttrOrDefault("align_corners", 0)); if (padding_mode_str == "reflection") { padding_mode_ = Reflection; } else if (padding_mode_str == "border") { padding_mode_ = Border; - } else { + } else if (padding_mode_str == "zeros") { padding_mode_ = Zeros; + } else { + ORT_THROW("padding_mode \"", padding_mode_str, "\" not supported, expect zeros, border or reflection"); } } Status Compute(OpKernelContext* context) const override; private: - enum GridSampleInterpolationMode { - Bilinear, + typedef enum { + Linear, + Cubic, Nearest, - Bicubic - }; + } GridSampleInterpolationMode; enum GridSamplePaddingMode { Zeros, @@ -53,9 +68,10 @@ class GridSample final : public OpKernel { Reflection }; - T PixelAtGrid(const T* image, int64_t r, int64_t c, int64_t H, int64_t W, float border[/* 4 */]) const; + T PixelAtGrid(const T* image, int64_t r, int64_t c, int64_t H, int64_t W, T border[/* 4 */]) const; + T PixelAtGrid3D(const T* image, int64_t d, int64_t h, int64_t w, int64_t D, int64_t H, int64_t W, T border[/* 6 */]) const; - GridSampleInterpolationMode mode_{Bilinear}; + GridSampleInterpolationMode mode_{Linear}; GridSamplePaddingMode padding_mode_{Zeros}; bool align_corners_{0}; }; diff --git a/onnxruntime/core/providers/cpu/tensor/isinf.cc b/onnxruntime/core/providers/cpu/tensor/isinf.cc index bc99caa8036cf..1b449f46927a2 100644 --- a/onnxruntime/core/providers/cpu/tensor/isinf.cc +++ b/onnxruntime/core/providers/cpu/tensor/isinf.cc @@ -14,15 +14,38 @@ namespace onnxruntime { // https://github.com/onnx/onnx/blob/main/docs/Operators.md#IsInf namespace op_kernel_type_control { -ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES_ALL_OPSETS( - kCpuExecutionProvider, kOnnxDomain, IsInf, Input, 0, - float, double); +using IsInfTypesOpset10 = TypeList; + +ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST( + kCpuExecutionProvider, kOnnxDomain, IsInf, 10, Input, 0, + IsInfTypesOpset10); + +using IsInfTypesOpset20 = + TypeList< + float, + double +#if !defined(DISABLE_FLOAT8_TYPES) + , + Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, Float8E5M2FNUZ +#endif + >; + +ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST( + kCpuExecutionProvider, + kOnnxDomain, + IsInf, + 20, + Input, + 0, + 
IsInfTypesOpset20); } // namespace op_kernel_type_control class IsInf final : public OpKernel { public: - using EnabledDataTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(kCpuExecutionProvider, kOnnxDomain, - IsInf, Input, 0); + using EnabledDataTypes10 = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, + IsInf, 10, Input, 0); + using EnabledDataTypes20 = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, + IsInf, 20, Input, 0); explicit IsInf(const OpKernelInfo& info); Status Compute(OpKernelContext* context) const override; @@ -30,14 +53,25 @@ class IsInf final : public OpKernel { private: int64_t detect_positive_{1}; int64_t detect_negative_{1}; + int opset_; }; -ONNX_CPU_OPERATOR_KERNEL( +ONNX_CPU_OPERATOR_VERSIONED_KERNEL( IsInf, 10, + 19, KernelDefBuilder() .TypeConstraint("T1", - BuildKernelDefConstraintsFromTypeList()) + BuildKernelDefConstraintsFromTypeList()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + IsInf); + +ONNX_CPU_OPERATOR_KERNEL( + IsInf, + 20, + KernelDefBuilder() + .TypeConstraint("T1", + BuildKernelDefConstraintsFromTypeList()) .TypeConstraint("T2", DataTypeImpl::GetTensorType()), IsInf); @@ -46,6 +80,7 @@ IsInf::IsInf(const OpKernelInfo& info) : OpKernel(info) { ORT_ENFORCE(status.IsOK(), "Failed to obtain detect_positive"); status = info.GetAttr("detect_negative", &detect_negative_); ORT_ENFORCE(status.IsOK(), "Failed to obtain detect_negative"); + opset_ = info.node().SinceVersion(); } namespace isinf_internal { @@ -78,6 +113,49 @@ struct ComputeDispatchTarget { } } }; + +#if !defined(DISABLE_FLOAT8_TYPES) +template <> +struct ComputeDispatchTarget { + void operator()(const Tensor&, Tensor& Y, bool, bool) const { + EigenMap(Y).array() = false; + } +}; + +template <> +struct ComputeDispatchTarget { + void operator()(const Tensor&, Tensor& Y, bool, bool) const { + EigenMap(Y).array() = false; + } +}; + +template <> +struct ComputeDispatchTarget { + void operator()(const Tensor& X, Tensor& Y, bool detect_positive, bool detect_negative) const { + auto& dims = X.Shape(); + auto input = ConstEigenVectorMap(static_cast(static_cast(X.Data())), onnxruntime::narrow(dims.Size())); + auto output = EigenMap(Y); + + // S.11111.00 + if (detect_positive && detect_negative) { + output.array() = input.array() == 0b01111100 || input.array() == 0b11111100; + } else if (detect_positive) { + output.array() = input.array() == 0b01111100; + } else if (detect_negative) { + output.array() = input.array() == 0b11111100; + } else { + output.array() = false; + } + } +}; + +template <> +struct ComputeDispatchTarget { + void operator()(const Tensor&, Tensor& Y, bool, bool) const { + EigenMap(Y).array() = false; + } +}; +#endif } // namespace isinf_internal Status IsInf::Compute(OpKernelContext* context) const { @@ -88,8 +166,13 @@ Status IsInf::Compute(OpKernelContext* context) const { using namespace isinf_internal; - utils::MLTypeCallDispatcherFromTypeList dispatcher{X.GetElementType()}; - dispatcher.Invoke(X, Y, detect_positive_ != 0, detect_negative_ != 0); + if (opset_ < 20) { + utils::MLTypeCallDispatcherFromTypeList dispatcher{X.GetElementType()}; + dispatcher.Invoke(X, Y, detect_positive_ != 0, detect_negative_ != 0); + } else { + utils::MLTypeCallDispatcherFromTypeList dispatcher{X.GetElementType()}; + dispatcher.Invoke(X, Y, detect_positive_ != 0, detect_negative_ != 0); + } return Status::OK(); } diff --git a/onnxruntime/core/providers/cpu/tensor/isnan.cc b/onnxruntime/core/providers/cpu/tensor/isnan.cc index 
33d0f8eb6c1ae..34495e382278a 100644 --- a/onnxruntime/core/providers/cpu/tensor/isnan.cc +++ b/onnxruntime/core/providers/cpu/tensor/isnan.cc @@ -20,10 +20,20 @@ namespace onnxruntime { .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ IsNaN); +#define ADD_TYPED_ISNAN_OP_13(data_type) \ + ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( \ + IsNaN, \ + 13, 19, \ + data_type, \ + KernelDefBuilder() \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ + IsNaN); + #define ADD_TYPED_ISNAN_OP(data_type) \ ONNX_CPU_OPERATOR_TYPED_KERNEL( \ IsNaN, \ - 13, \ + 20, \ data_type, \ KernelDefBuilder() \ .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ @@ -33,10 +43,20 @@ namespace onnxruntime { ADD_TYPED_ISNAN_OP_9(float); ADD_TYPED_ISNAN_OP_9(double); ADD_TYPED_ISNAN_OP_9(MLFloat16); +ADD_TYPED_ISNAN_OP_13(float); +ADD_TYPED_ISNAN_OP_13(double); +ADD_TYPED_ISNAN_OP_13(MLFloat16); ADD_TYPED_ISNAN_OP(float); ADD_TYPED_ISNAN_OP(double); ADD_TYPED_ISNAN_OP(MLFloat16); +#if !defined(DISABLE_FLOAT8_TYPES) +ADD_TYPED_ISNAN_OP(Float8E4M3FN); +ADD_TYPED_ISNAN_OP(Float8E4M3FNUZ); +ADD_TYPED_ISNAN_OP(Float8E5M2); +ADD_TYPED_ISNAN_OP(Float8E5M2FNUZ); +#endif + template Status IsNaN::Compute(OpKernelContext* context) const { const auto* X_ptr = context->Input(0); @@ -70,4 +90,63 @@ Status IsNaN::Compute(OpKernelContext* context) const { return Status::OK(); } + +#if !defined(DISABLE_FLOAT8_TYPES) +template <> +Status IsNaN::Compute(OpKernelContext* context) const { + const auto* X = context->Input(0); + auto& dims = X->Shape(); + auto& Y = *context->Output(0, dims); + + auto input = ConstEigenVectorMap(static_cast(static_cast(X->Data())), onnxruntime::narrow(dims.Size())); + auto output = EigenMap(Y); + + // S.1111.111 + std::transform(input.begin(), input.end(), output.begin(), [](uint8_t c) { return (c & 0x7f) == 0x7f; }); + return Status::OK(); +} + +template <> +Status IsNaN::Compute(OpKernelContext* context) const { + const auto* X = context->Input(0); + auto X_data = X->Data(); + auto& dims = X->Shape(); + auto shape_size = dims.Size(); + auto& Y = *context->Output(0, dims); + + // 1.0000.000 + EigenMap(Y) = + ConstEigenVectorMap(static_cast(static_cast(X_data)), onnxruntime::narrow(shape_size)).array() == 0x80; + + return Status::OK(); +} + +template <> +Status IsNaN::Compute(OpKernelContext* context) const { + const auto* X = context->Input(0); + auto& dims = X->Shape(); + auto& Y = *context->Output(0, dims); + + auto input = ConstEigenVectorMap(static_cast(static_cast(X->Data())), onnxruntime::narrow(dims.Size())); + auto output = EigenMap(Y); + + // S.11111.{01, 10, 11} + std::transform(input.begin(), input.end(), output.begin(), [](uint8_t c) { return ((c & 0x7c) == 0x7c) && ((c & 0x03) != 0x00); }); + return Status::OK(); +} + +template <> +Status IsNaN::Compute(OpKernelContext* context) const { + const auto* X = context->Input(0); + auto X_data = X->Data(); + auto& dims = X->Shape(); + auto shape_size = dims.Size(); + auto& Y = *context->Output(0, dims); + + // 1.0000.000 + EigenMap(Y) = ConstEigenVectorMap(static_cast(static_cast(X_data)), onnxruntime::narrow(shape_size)).array() == 0x80; + + return Status::OK(); +} +#endif } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/tensor/upsamplebase.h b/onnxruntime/core/providers/cpu/tensor/upsamplebase.h index 99522cdf0759a..0b3ce6f477843 100644 --- a/onnxruntime/core/providers/cpu/tensor/upsamplebase.h +++ b/onnxruntime/core/providers/cpu/tensor/upsamplebase.h @@ 
-352,7 +352,7 @@ class UpsampleBase { (scales.size() == 4 && scales[0] == 1 && scales[3] == 1) || scales.size() == 3 || (scales.size() == 5 && scales[0] == 1 && scales[1] == 1), - "'Linear' mode only support:\n" + "'Linear' mode only supports:\n" " * 2-D inputs or\n" " * 3-D inputs ('Bilinear', 'Trilinear') or\n" " * 4-D inputs with the corresponding outermost 2 scale values being 1" diff --git a/onnxruntime/core/providers/cuda/cuda_common.cc b/onnxruntime/core/providers/cuda/cuda_common.cc index 57477f167c555..288ca8e97e34d 100644 --- a/onnxruntime/core/providers/cuda/cuda_common.cc +++ b/onnxruntime/core/providers/cuda/cuda_common.cc @@ -27,5 +27,90 @@ const HalfGemmOptions* HalfGemmOptions::GetInstance() { return &instance; } +const char* cublasGetErrorEnum(cublasStatus_t error) { + switch (error) { + case CUBLAS_STATUS_SUCCESS: + return "CUBLAS_STATUS_SUCCESS"; + case CUBLAS_STATUS_NOT_INITIALIZED: + return "CUBLAS_STATUS_NOT_INITIALIZED"; + case CUBLAS_STATUS_ALLOC_FAILED: + return "CUBLAS_STATUS_ALLOC_FAILED"; + case CUBLAS_STATUS_INVALID_VALUE: + return "CUBLAS_STATUS_INVALID_VALUE"; + case CUBLAS_STATUS_ARCH_MISMATCH: + return "CUBLAS_STATUS_ARCH_MISMATCH"; + case CUBLAS_STATUS_MAPPING_ERROR: + return "CUBLAS_STATUS_MAPPING_ERROR"; + case CUBLAS_STATUS_EXECUTION_FAILED: + return "CUBLAS_STATUS_EXECUTION_FAILED"; + case CUBLAS_STATUS_INTERNAL_ERROR: + return "CUBLAS_STATUS_INTERNAL_ERROR"; + case CUBLAS_STATUS_NOT_SUPPORTED: + return "CUBLAS_STATUS_NOT_SUPPORTED"; + case CUBLAS_STATUS_LICENSE_ERROR: + return "CUBLAS_STATUS_LICENSE_ERROR"; + default: + return ""; + } +} + +const char* CudaDataTypeToString(cudaDataType_t dt) { + switch (dt) { + case CUDA_R_16F: + return "CUDA_R_16F"; + case CUDA_R_16BF: + return "CUDA_R_16BF"; + case CUDA_R_32F: + return "CUDA_R_32F"; +#if (CUDA_VERSION >= 11080) + case CUDA_R_8F_E4M3: + return "CUDA_R_8F_E4M3"; + case CUDA_R_8F_E5M2: + return "CUDA_R_8F_E5M2"; +#endif + default: + return ""; + } +} + +const char* CublasComputeTypeToString(cublasComputeType_t ct) { + switch (ct) { + case CUBLAS_COMPUTE_16F: + return "CUBLAS_COMPUTE_16F"; + case CUBLAS_COMPUTE_32F: + return "CUBLAS_COMPUTE_32F"; + case CUBLAS_COMPUTE_32F_FAST_16F: + return "CUBLAS_COMPUTE_32F_FAST_16F"; + case CUBLAS_COMPUTE_32F_FAST_16BF: + return "CUBLAS_COMPUTE_32F_FAST_16BF"; + case CUBLAS_COMPUTE_32F_FAST_TF32: + return "CUBLAS_COMPUTE_32F_FAST_TF32"; + case CUBLAS_COMPUTE_64F: + return "CUBLAS_COMPUTE_64F"; + default: + return ""; + } +} + +// It must exist somewhere already. 
+cudaDataType_t ToCudaDataType(int32_t element_type) { + switch (element_type) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + return CUDA_R_32F; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + return CUDA_R_16F; + case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: + return CUDA_R_16BF; +#if (!defined(DISABLE_FLOAT8_TYPES) && (CUDA_VERSION >= 11080)) + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN: + return CUDA_R_8F_E4M3; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2: + return CUDA_R_8F_E5M2; +#endif + default: + ORT_THROW("Unexpected element_type=", element_type, "."); + } +} + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/cuda_common.h b/onnxruntime/core/providers/cuda/cuda_common.h index fa258961f1155..9cd4e721ccab8 100644 --- a/onnxruntime/core/providers/cuda/cuda_common.h +++ b/onnxruntime/core/providers/cuda/cuda_common.h @@ -11,6 +11,7 @@ #include "core/providers/shared_library/provider_api.h" #include "core/common/status.h" +#include "core/framework/float8.h" #include "core/framework/float16.h" #include "core/providers/cuda/cuda_pch.h" #include "core/providers/cuda/shared_inc/cuda_call.h" @@ -48,6 +49,33 @@ class ToCudaType { } }; +template <> +class ToCudaType { + public: + typedef BFloat16 MappedType; + static MappedType FromFloat(float f) { + return MappedType(f); + } +}; + +template <> +class ToCudaType { + public: + typedef Float8E4M3FN MappedType; + static MappedType FromFloat(float f) { + return MappedType(f); + } +}; + +template <> +class ToCudaType { + public: + typedef Float8E5M2 MappedType; + static MappedType FromFloat(float f) { + return MappedType(f); + } +}; + inline bool CalculateFdmStrides(gsl::span p, const std::vector& dims) { int stride = 1; if (dims.empty() || p.size() < dims.size()) @@ -152,5 +180,13 @@ class HalfGemmOptions { static HalfGemmOptions instance; }; +const char* cublasGetErrorEnum(cublasStatus_t error); + +const char* CudaDataTypeToString(cudaDataType_t dt); + +const char* CublasComputeTypeToString(cublasComputeType_t ct); + +cudaDataType_t ToCudaDataType(int32_t element_type); + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index de01e240a06c7..2d242d7d6fb12 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. // Licensed under the MIT License. 
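The helpers added to cuda_common above centralize the mapping from ONNX element types to CUDA/cuBLAS type enums, plus string names for logging. Below is a standalone sketch of that dispatch pattern with a local enum standing in for cudaDataType_t, so it builds without the CUDA toolkit; every name in it is illustrative rather than the real API. The CUDA execution provider changes continue after it.

#include <iostream>
#include <stdexcept>

// Stand-ins for cudaDataType_t and the ONNX element-type codes; illustration only.
enum class FakeCudaType { R_32F, R_16F, R_16BF, R_8F_E4M3, R_8F_E5M2 };
enum class ElementType { Float, Float16, BFloat16, Float8E4M3FN, Float8E5M2 };

// Same shape as the ToCudaDataType switch above: map each supported element type, throw on anything else.
FakeCudaType ToFakeCudaType(ElementType t) {
  switch (t) {
    case ElementType::Float:        return FakeCudaType::R_32F;
    case ElementType::Float16:      return FakeCudaType::R_16F;
    case ElementType::BFloat16:     return FakeCudaType::R_16BF;
    case ElementType::Float8E4M3FN: return FakeCudaType::R_8F_E4M3;
    case ElementType::Float8E5M2:   return FakeCudaType::R_8F_E5M2;
    default: throw std::runtime_error("unexpected element type");
  }
}

int main() {
  std::cout << (ToFakeCudaType(ElementType::Float16) == FakeCudaType::R_16F) << "\n";  // prints: 1
  return 0;
}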
#include "core/common/inlined_containers.h" @@ -15,6 +16,10 @@ #include "contrib_ops/cuda/cuda_contrib_kernels.h" #endif +#ifdef ENABLE_CUDA_NHWC_OPS +#include "core/providers/cuda/cuda_nhwc_kernels.h" +#endif + #ifdef ENABLE_TRAINING_OPS #include "orttraining/training_ops/cuda/cuda_training_kernels.h" #endif @@ -233,6 +238,10 @@ CUDAExecutionProvider::CUDAExecutionProvider(const CUDAExecutionProviderInfo& in : IExecutionProvider{onnxruntime::kCudaExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, info.device_id)}, info_{info}, tuning_context_(this, &info_.tunable_op) { +#ifndef ENABLE_CUDA_NHWC_OPS + ORT_ENFORCE(info_.prefer_nhwc == 0, "This build does not support NHWC layout"); +#endif + CUDA_CALL_THROW(cudaSetDevice(info_.device_id)); // must wait GPU idle, otherwise cudaGetDeviceProperties might fail @@ -250,7 +259,7 @@ CUDAExecutionProvider::CUDAExecutionProvider(const CUDAExecutionProviderInfo& in if (info.external_allocator_info.UseExternalAllocator()) { use_ep_level_unified_stream_ = true; stream_ = nullptr; - } else if (info.enable_cuda_graph) { + } else if (info.enable_cuda_graph || info.use_ep_level_unified_stream) { // current cuda graph implementation only works with single stream // use EP level unified stream for all the requests CUDA_CALL_THROW(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking)); @@ -271,6 +280,10 @@ CUDAExecutionProvider::CUDAExecutionProvider(const CUDAExecutionProviderInfo& in #endif } +DataLayout CUDAExecutionProvider::GetPreferredLayout() const { + return this->IsNHWCPreferred() ? DataLayout::NHWC : DataLayout::NCHW; +} + CUDAExecutionProvider::~CUDAExecutionProvider() { // clean up thread local context caches { @@ -632,51 +645,54 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, GlobalMaxPool); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, GlobalMaxPool); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, GlobalMaxPool); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, ArgMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, ArgMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ArgMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, ArgMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, ArgMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ArgMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, ReduceL2); -class
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, int64_t, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, int64_t, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, ReduceSumSquare); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, ReduceSumSquare); -class 
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceSumSquare); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, ReduceLogSumExp); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, ReduceLogSumExp); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceLogSumExp); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, float, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, double, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, MLFloat16, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, float, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, double, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, MLFloat16, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, float, ReduceL1); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, double, ReduceL1); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceL1); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, int32_t, ReduceL1); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, float, ReduceL2); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, double, ReduceL2); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceL2); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, int32_t, ReduceL2); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, float, ReduceMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, double, ReduceMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, int32_t, ReduceMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, int64_t, ReduceMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, float, ReduceMean); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, double, ReduceMean); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceMean); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, int32_t, ReduceMean); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, float, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, double, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, 
kOnnxDomain, 1, 17, MLFloat16, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, int32_t, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, int64_t, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, int8_t, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, uint8_t, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, float, ReduceProd); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, double, ReduceProd); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceProd); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, int32_t, ReduceProd); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 12, float, ReduceSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 12, double, ReduceSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 12, MLFloat16, ReduceSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 12, int32_t, ReduceSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 12, int64_t, ReduceSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, float, ReduceLogSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, double, ReduceLogSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceLogSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, float, ReduceSumSquare); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, double, ReduceSumSquare); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceSumSquare); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, float, ReduceLogSumExp); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, double, ReduceLogSumExp); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceLogSumExp); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 8, float, Cast); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 8, double, Cast); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 8, MLFloat16, Cast); @@ -811,12 +827,6 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDom class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 12, Mod); // opset 11 -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, float, ArgMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, double, ArgMax); -class 
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, MLFloat16, ArgMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, float, ArgMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, double, ArgMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, MLFloat16, ArgMin); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, Compress); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, Concat); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, Flatten); @@ -830,45 +840,6 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDom class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, Loop); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, NonMaxSuppression); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, Range); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, float, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, double, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, float, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, double, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, float, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, double, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, float, ReduceLogSumExp); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, double, ReduceLogSumExp); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, ReduceLogSumExp); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, float, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, double, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, MLFloat16, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, int32_t, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, int64_t, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 
11, 12, float, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, double, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, float, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, double, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, MLFloat16, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, int32_t, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, float, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, double, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, float, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, double, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, int64_t, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, float, ReduceSumSquare); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, double, ReduceSumSquare); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, ReduceSumSquare); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 15, Scan); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, ScatterElements); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, int32_t, Slice); @@ -946,22 +917,6 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, Pow); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, float, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, double, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, int32_t, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, int64_t, ReduceMax); -class 
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, int8_t, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, uint8_t, ReduceMax); - -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, float, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, double, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, int32_t, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, int64_t, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, int8_t, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, uint8_t, ReduceMin); - class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, int64_t, GatherND); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, Dropout); @@ -1115,50 +1070,36 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Gemm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Gemm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Gemm); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ReduceL1); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ReduceL1); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceL1); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, ReduceL1); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ReduceL2); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ReduceL2); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceL2); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, ReduceL2); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ReduceLogSum); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ReduceLogSum); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceLogSum); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ReduceLogSumExp); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ReduceLogSumExp); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceLogSumExp); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceMax); -class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int8_t, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint8_t, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ReduceMean); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ReduceMean); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceMean); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, float, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, double, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, MLFloat16, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, int32_t, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, int64_t, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, int8_t, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, uint8_t, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ReduceProd); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ReduceProd); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceProd); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, ReduceProd); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, ReduceL1); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, ReduceL1); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceL1); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceL1); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, ReduceL2); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, ReduceL2); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceL2); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceL2); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, ReduceLogSum); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, ReduceLogSum); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceLogSum); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, ReduceLogSumExp); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, ReduceLogSumExp); +class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceLogSumExp); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, ReduceMean); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, ReduceMean); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceMean); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMean); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, ReduceProd); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, ReduceProd); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceProd); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceProd); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ReduceSum); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ReduceSum); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceSum); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, ReduceSum); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, ReduceSum); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ReduceSumSquare); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ReduceSumSquare); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceSumSquare); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, ReduceSumSquare); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, ReduceSumSquare); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceSumSquare); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, GatherND); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Dropout); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Resize); @@ -1257,13 +1198,13 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, float, BatchNormalization); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, double, BatchNormalization); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, MLFloat16, BatchNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int32_t, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int8_t, ReduceMin); -class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint8_t, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int64_t, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int8_t, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, uint8_t, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMin); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, Trilu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, BFloat16, Add); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, BFloat16, Sub); @@ -1287,6 +1228,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, MLFloat16, PRelu); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 18, Scan); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, MLFloat16, Where); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, BFloat16, Where); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, float, Where); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, double_t, Where); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, int32_t, Where); @@ -1316,6 +1258,12 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, // Opset 18 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, Split); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, ReduceMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, ReduceMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMax); + // Opset 19 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, float, Cast); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, double, Cast); @@ -1581,51 +1529,51 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - 
BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1764,12 +1712,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, // opset 11 - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1783,45 +1728,6 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, 
BuildKernelCreateInfo, @@ -1895,22 +1801,6 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2064,50 +1954,36 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2206,13 +2082,13 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2236,6 +2112,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2264,6 +2141,11 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { // Opset 18 BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + 
BuildKernelCreateInfo, // Opset 19 BuildKernelCreateInfo, @@ -2330,6 +2212,10 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { ORT_RETURN_IF_ERROR(::onnxruntime::contrib::cuda::RegisterCudaContribKernels(kernel_registry)); #endif +#ifdef ENABLE_CUDA_NHWC_OPS + ORT_RETURN_IF_ERROR(::onnxruntime::cuda::RegisterCudaNhwcKernels(kernel_registry)); +#endif + #ifdef ENABLE_TRAINING_OPS ORT_RETURN_IF_ERROR(::onnxruntime::cuda::RegisterCudaTrainingKernels(kernel_registry)); #endif diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.h b/onnxruntime/core/providers/cuda/cuda_execution_provider.h index c9e510b7f472b..d0bb2321edf0a 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.h @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. // Licensed under the MIT License. #pragma once @@ -32,6 +33,8 @@ class CUDAExecutionProvider : public IExecutionProvider { Status OnRunEnd(bool sync_stream) override; + DataLayout GetPreferredLayout() const override; + const void* GetExecutionHandle() const noexcept override { // The CUDA interface does not return anything interesting. return nullptr; @@ -49,6 +52,12 @@ class CUDAExecutionProvider : public IExecutionProvider { return GetPerThreadContext().CudnnHandle(); } + cudaStream_t ComputeStream() { + // this will return the CUDA EP level stream which can differ from the actual compute tasks stream + // the compute task stream is supplied within OpKernelContext during inference + return stream_; + } + template const T* GetConstOnes(size_t count, cudaStream_t stream) { return GetPerThreadContext().template GetConstOnes(count, stream); @@ -68,6 +77,7 @@ class CUDAExecutionProvider : public IExecutionProvider { bool GetCudnnConvUseMaxWorkspace() const { return info_.cudnn_conv_use_max_workspace; } bool GetCudnnConv1dPadToNc1d() const { return info_.cudnn_conv1d_pad_to_nc1d; } bool IsSkipLayerNormInStrictMode() const { return info_.enable_skip_layer_norm_strict_mode; } + bool IsNHWCPreferred() const { return info_.prefer_nhwc; } ProviderOptions GetProviderOptions() const override { return CUDAExecutionProviderInfo::ToProviderOptions(info_); diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc index ca88b3474b758..daa3b5ff3d72f 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. // Licensed under the MIT License. 
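+// Illustrative usage sketch for the two provider options introduced below (assuming an existing
+// Ort::SessionOptions named session_options); the key strings must match the provider_option_names
+// constants defined in this file:
+//
+//   OrtCUDAProviderOptionsV2* cuda_options = nullptr;
+//   Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&cuda_options));
+//   const char* keys[] = {"prefer_nhwc", "use_ep_level_unified_stream"};
+//   const char* values[] = {"1", "1"};
+//   Ort::ThrowOnError(Ort::GetApi().UpdateCUDAProviderOptions(cuda_options, keys, values, 2));
+//   session_options.AppendExecutionProvider_CUDA_V2(*cuda_options);
+//   Ort::GetApi().ReleaseCUDAProviderOptions(cuda_options);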
#include "core/providers/shared_library/provider_api.h" @@ -29,6 +30,8 @@ constexpr const char* kTunableOpEnable = "tunable_op_enable"; constexpr const char* kTunableOpTuningEnable = "tunable_op_tuning_enable"; constexpr const char* kTunableOpMaxTuningDurationMs = "tunable_op_max_tuning_duration_ms"; constexpr const char* kEnableSkipLayerNormStrictMode = "enable_skip_layer_norm_strict_mode"; +constexpr const char* kPreferNCHWMode = "prefer_nhwc"; +constexpr const char* KUseEPLevelUnifiedStream = "use_ep_level_unified_stream"; } // namespace provider_option_names } // namespace cuda @@ -99,6 +102,8 @@ CUDAExecutionProviderInfo CUDAExecutionProviderInfo::FromProviderOptions(const P .AddAssignmentToReference(cuda::provider_option_names::kEnableCudaGraph, info.enable_cuda_graph) .AddAssignmentToReference(cuda::provider_option_names::kCudnnConv1dPadToNc1d, info.cudnn_conv1d_pad_to_nc1d) .AddAssignmentToReference(cuda::provider_option_names::kEnableSkipLayerNormStrictMode, info.enable_skip_layer_norm_strict_mode) + .AddAssignmentToReference(cuda::provider_option_names::kPreferNCHWMode, info.prefer_nhwc) + .AddAssignmentToReference(cuda::provider_option_names::KUseEPLevelUnifiedStream, info.use_ep_level_unified_stream) .AddValueParser( cuda::provider_option_names::kTunableOpEnable, [&info](const std::string& value_str) -> Status { @@ -144,6 +149,8 @@ ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const CUDAExecution {cuda::provider_option_names::kTunableOpTuningEnable, MakeStringWithClassicLocale(info.tunable_op.tuning_enable)}, {cuda::provider_option_names::kTunableOpMaxTuningDurationMs, MakeStringWithClassicLocale(info.tunable_op.max_tuning_duration_ms)}, {cuda::provider_option_names::kEnableSkipLayerNormStrictMode, MakeStringWithClassicLocale(info.enable_skip_layer_norm_strict_mode)}, + {cuda::provider_option_names::kPreferNCHWMode, MakeStringWithClassicLocale(info.prefer_nhwc)}, + {cuda::provider_option_names::KUseEPLevelUnifiedStream, MakeStringWithClassicLocale(info.use_ep_level_unified_stream)}, }; return options; @@ -162,6 +169,8 @@ ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const OrtCUDAProvid {cuda::provider_option_names::kTunableOpEnable, MakeStringWithClassicLocale(info.tunable_op_enable)}, {cuda::provider_option_names::kTunableOpTuningEnable, MakeStringWithClassicLocale(info.tunable_op_tuning_enable)}, {cuda::provider_option_names::kTunableOpMaxTuningDurationMs, MakeStringWithClassicLocale(info.tunable_op_max_tuning_duration_ms)}, + {cuda::provider_option_names::kPreferNCHWMode, MakeStringWithClassicLocale(info.prefer_nhwc)}, + {cuda::provider_option_names::KUseEPLevelUnifiedStream, MakeStringWithClassicLocale(info.use_ep_level_unified_stream)}, }; return options; diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h index 789b02b0e1d8c..b286f5a9161b0 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. // Licensed under the MIT License. 
#pragma once @@ -71,6 +72,9 @@ struct CUDAExecutionProviderInfo { cuda::TunableOpInfo tunable_op{}; bool enable_skip_layer_norm_strict_mode{false}; + bool prefer_nhwc{false}; + + bool use_ep_level_unified_stream{false}; static CUDAExecutionProviderInfo FromProviderOptions(const ProviderOptions& options); static ProviderOptions ToProviderOptions(const CUDAExecutionProviderInfo& info); diff --git a/onnxruntime/core/providers/cuda/cuda_kernel.h b/onnxruntime/core/providers/cuda/cuda_kernel.h index 58517c2850baf..e3106e41e77c8 100644 --- a/onnxruntime/core/providers/cuda/cuda_kernel.h +++ b/onnxruntime/core/providers/cuda/cuda_kernel.h @@ -170,10 +170,10 @@ class CudaKernel : public OpKernel { return provider_->PerThreadDefaultCudnnHandle(); } - protected: - template - inline const T* GetConstOnes(size_t count, cudaStream_t stream) const { - return provider_->template GetConstOnes(count, stream); + inline cudaStream_t DefaultCudaStream() const { + // this will return the CUDA EP level stream which can differ from the actual compute tasks stream + // the compute task stream is supplied within OpKernelContext during inference + return provider_->ComputeStream(); } inline Status CopyTensor(const Tensor& src, Tensor& dst, onnxruntime::Stream& stream) const { @@ -181,6 +181,12 @@ class CudaKernel : public OpKernel { return gpu_data_transfer->CopyTensorAsync(src, dst, stream); } + protected: + template + inline const T* GetConstOnes(size_t count, cudaStream_t stream) const { + return provider_->template GetConstOnes(count, stream); + } + inline int GetDeviceId() const { return provider_->GetDeviceId(); } private: diff --git a/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc b/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc new file mode 100644 index 0000000000000..f416caecd115f --- /dev/null +++ b/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc @@ -0,0 +1,169 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. +// Licensed under the MIT License. 
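+// This translation unit registers the CUDA kernels that operate on NHWC (channels-last) tensors
+// under the internal kMSInternalNHWCDomain. It is compiled only when ENABLE_CUDA_NHWC_OPS is
+// defined, and the kernels are reachable at runtime only when the CUDA execution provider is
+// created with the "prefer_nhwc" option, which makes GetPreferredLayout() report DataLayout::NHWC
+// so the layout transformer can rewrite eligible nodes into the NHWC domain.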
+ +#ifdef ENABLE_CUDA_NHWC_OPS + +#include + +#include "core/providers/shared_library/provider_api.h" +#include "core/providers/cuda/cuda_fwd.h" + +#include "core/providers/cuda/cuda_nhwc_kernels.h" + +namespace onnxruntime::cuda { + +// When adding new supported NHWC operations make sure to also integrate them into: ConvertNodeLayout +// in onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc + +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, float, + BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, MLFloat16, + BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, float, + BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, MLFloat16, + BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 10, float, + Conv); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 10, MLFloat16, + Conv); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 10, float, + ConvTranspose); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 10, MLFloat16, + ConvTranspose); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 9, float, + AveragePool); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 9, MLFloat16, + AveragePool); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, float, GlobalAveragePool); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, MLFloat16, + GlobalAveragePool); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 7, float, + MaxPool); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 7, MLFloat16, + MaxPool); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 8, 9, float, + MaxPool); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 8, 9, MLFloat16, + MaxPool); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, float, GlobalMaxPool); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, MLFloat16, GlobalMaxPool); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 10, 10, float, + AveragePool); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 10, 10, MLFloat16, + AveragePool); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 10, 10, float, + MaxPool); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 10, 10, MLFloat16, + MaxPool); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11, float, Conv); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11, MLFloat16, Conv); +class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11, float, ConvTranspose); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11, MLFloat16, + ConvTranspose); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11, float, AveragePool); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11, MLFloat16, AveragePool); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11, 11, float, + MaxPool); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11, 11, MLFloat16, + MaxPool); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12, float, MaxPool); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12, MLFloat16, MaxPool); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, float, + BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, MLFloat16, + BatchNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15, float, + BatchNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15, MLFloat16, + BatchNormalization); + +Status RegisterCudaNhwcKernels(KernelRegistry& kernel_registry) { + static const BuildKernelCreateInfoFn nhwc_function_table[] = { + BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + }; + + for (auto& function_table_entry : nhwc_function_table) { + KernelCreateInfo info = function_table_entry(); + if (info.kernel_def != nullptr) { // filter disabled entries where type is void + ORT_RETURN_IF_ERROR(kernel_registry.Register(std::move(info))); + } + } + return Status::OK(); +} +} // namespace onnxruntime::cuda +#endif diff --git a/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.h b/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.h new file mode 100644 index 0000000000000..0b3a6d5cff0c7 --- /dev/null +++ b/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.h @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. +// Licensed under the MIT License. 
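+// Declares RegisterCudaNhwcKernels(), the registration entry point for the NHWC kernels implemented
+// in cuda_nhwc_kernels.cc; the implementation is only built when ENABLE_CUDA_NHWC_OPS is defined.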
+ +#pragma once + +#include "core/common/status.h" + +namespace onnxruntime::cuda { + +onnxruntime::common::Status RegisterCudaNhwcKernels(onnxruntime::KernelRegistry& kernel_registry); + +} // namespace onnxruntime::cuda diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc index 5a11f2529f38e..892e8d5329eba 100644 --- a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc +++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. // Licensed under the MIT License. #include "core/providers/shared_library/provider_api.h" @@ -217,11 +218,13 @@ struct CUDA_Provider : Provider { info.default_memory_arena_cfg = params->default_memory_arena_cfg; info.cudnn_conv_use_max_workspace = params->cudnn_conv_use_max_workspace != 0; info.enable_cuda_graph = params->enable_cuda_graph != 0; + info.prefer_nhwc = params->prefer_nhwc; info.cudnn_conv1d_pad_to_nc1d = params->cudnn_conv1d_pad_to_nc1d != 0; info.tunable_op.enable = params->tunable_op_enable; info.tunable_op.tuning_enable = params->tunable_op_tuning_enable; info.tunable_op.max_tuning_duration_ms = params->tunable_op_max_tuning_duration_ms; info.enable_skip_layer_norm_strict_mode = params->enable_skip_layer_norm_strict_mode != 0; + info.use_ep_level_unified_stream = params->use_ep_level_unified_stream != 0; return std::make_shared(info); } @@ -243,7 +246,7 @@ struct CUDA_Provider : Provider { cuda_options.arena_extend_strategy = internal_options.arena_extend_strategy; cuda_options.do_copy_in_default_stream = internal_options.do_copy_in_default_stream; cuda_options.has_user_compute_stream = internal_options.has_user_compute_stream; - // The 'has_user_compute_stream' of the OrtCUDAProviderOptionsV2 instance can be set byC API UpdateCUDAProviderOptionsWithValue() as well. + // The 'has_user_compute_stream' of the OrtCUDAProviderOptionsV2 instance can be set by C API UpdateCUDAProviderOptionsWithValue() as well. 
// We only set the 'has_user_compute_stream' of the OrtCUDAProviderOptionsV2 instance if it is provided in options if (options.find("has_user_compute_stream") != options.end()) { cuda_options.user_compute_stream = internal_options.user_compute_stream; @@ -253,6 +256,8 @@ struct CUDA_Provider : Provider { cuda_options.enable_cuda_graph = internal_options.enable_cuda_graph; cuda_options.cudnn_conv1d_pad_to_nc1d = internal_options.cudnn_conv1d_pad_to_nc1d; cuda_options.enable_skip_layer_norm_strict_mode = internal_options.enable_skip_layer_norm_strict_mode; + cuda_options.prefer_nhwc = internal_options.prefer_nhwc; + cuda_options.use_ep_level_unified_stream = internal_options.use_ep_level_unified_stream; } ProviderOptions GetProviderOptions(const void* provider_options) override { diff --git a/onnxruntime/core/providers/cuda/cuda_stream_handle.cc b/onnxruntime/core/providers/cuda/cuda_stream_handle.cc index e855a515f445a..5f1dbd30f6a3e 100644 --- a/onnxruntime/core/providers/cuda/cuda_stream_handle.cc +++ b/onnxruntime/core/providers/cuda/cuda_stream_handle.cc @@ -7,6 +7,25 @@ namespace onnxruntime { +DeferredCpuAllocator::DeferredCpuAllocator(CudaStream& cuda_stream) : cuda_stream_(cuda_stream) { + OrtAllocator::version = ORT_API_VERSION; + OrtAllocator::Alloc = + [](OrtAllocator* this_, size_t size) { + auto self = reinterpret_cast(this_); + return self->cuda_stream_.GetCpuAllocator()->Alloc(size); + }; + OrtAllocator::Free = + [](OrtAllocator* this_, void* p) { + auto self = reinterpret_cast(this_); + self->cuda_stream_.EnqueDeferredCPUBuffer(p); + }; + OrtAllocator::Info = + [](const OrtAllocator* this_) { + auto self = reinterpret_cast(this_); + return &self->cuda_stream_.GetCpuAllocator()->Info(); + }; +} + struct CudaNotification : public synchronize::Notification { CudaNotification(Stream& s) : Notification(s) { CUDA_CALL_THROW(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming)); @@ -46,7 +65,8 @@ CudaStream::CudaStream(cudaStream_t stream, cublasHandle_t external_cublas_handle) : Stream(stream, device), own_stream_(own_flag), cpu_allocator_(cpu_allocator), - release_cpu_buffer_on_cuda_stream_(release_cpu_buffer_on_cuda_stream) { + release_cpu_buffer_on_cuda_stream_(release_cpu_buffer_on_cuda_stream), + deferred_cpu_allocator_(*this) { if (own_flag) { CUBLAS_CALL_THROW(cublasCreate(&cublas_handle_)); CUBLAS_CALL_THROW(cublasSetStream(cublas_handle_, stream)); @@ -162,6 +182,9 @@ void* CudaStream::GetResource(int version, int id) const { case CudaResource::cublas_handle_t: return reinterpret_cast(cublas_handle_); break; + case CudaResource::deferred_cpu_allocator_t: + return const_cast(&deferred_cpu_allocator_); + break; default: break; } diff --git a/onnxruntime/core/providers/cuda/cuda_stream_handle.h b/onnxruntime/core/providers/cuda/cuda_stream_handle.h index 9c62b029b7a36..917702fae08f1 100644 --- a/onnxruntime/core/providers/cuda/cuda_stream_handle.h +++ b/onnxruntime/core/providers/cuda/cuda_stream_handle.h @@ -9,6 +9,13 @@ namespace onnxruntime { +struct CudaStream; + +struct DeferredCpuAllocator : public OrtAllocator { + DeferredCpuAllocator(CudaStream&); + CudaStream& cuda_stream_; +}; + struct CudaStream : Stream { CudaStream(cudaStream_t stream, const OrtDevice& device, @@ -36,10 +43,13 @@ struct CudaStream : Stream { void* GetResource(int version, int id) const override; + onnxruntime::IAllocator* GetCpuAllocator() const { return cpu_allocator_.get(); } + private: std::vector deferred_cpu_buffers_; AllocatorPtr cpu_allocator_; bool 
release_cpu_buffer_on_cuda_stream_{true}; + DeferredCpuAllocator deferred_cpu_allocator_; }; void RegisterCudaStreamHandles(IStreamCommandHandleRegistry& stream_handle_registry, diff --git a/onnxruntime/core/providers/cuda/cudnn_common.cc b/onnxruntime/core/providers/cuda/cudnn_common.cc index fc02a6509bf24..4df59a98b12e5 100644 --- a/onnxruntime/core/providers/cuda/cudnn_common.cc +++ b/onnxruntime/core/providers/cuda/cudnn_common.cc @@ -1,7 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. // Licensed under the MIT License. -#include "cudnn_common.h" +#include + +#include "core/providers/cuda/cudnn_common.h" #include "core/common/inlined_containers.h" #include "core/common/gsl.h" #include "shared_inc/cuda_call.h" @@ -27,7 +30,7 @@ Status CudnnTensor::CreateTensorIfNeeded() { return Status::OK(); } -Status CudnnTensor::Set(gsl::span input_dims, cudnnDataType_t dataType) { +Status CudnnTensor::Set(gsl::span input_dims, cudnnDataType_t dataType, bool is_nhwc) { ORT_RETURN_IF_ERROR(CreateTensorIfNeeded()); int rank = gsl::narrow_cast(input_dims.size()); @@ -38,6 +41,10 @@ Status CudnnTensor::Set(gsl::span input_dims, cudnnDataType_t dat dims[i] = gsl::narrow_cast(input_dims[i]); strides[i] = gsl::narrow_cast(pitches[i]); } + if (is_nhwc) { + std::swap(dims[1], dims[rank - 1]); + std::swap(strides[1], strides[rank - 1]); + } CUDNN_RETURN_IF_ERROR(cudnnSetTensorNdDescriptor(tensor_, dataType, static_cast(rank), dims.data(), strides.data())); return Status::OK(); } diff --git a/onnxruntime/core/providers/cuda/cudnn_common.h b/onnxruntime/core/providers/cuda/cudnn_common.h index ba75ab4f2c029..8a94a334ee688 100644 --- a/onnxruntime/core/providers/cuda/cudnn_common.h +++ b/onnxruntime/core/providers/cuda/cudnn_common.h @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. // Licensed under the MIT License. #pragma once @@ -16,7 +17,7 @@ class CudnnTensor final { ~CudnnTensor(); ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(CudnnTensor); - Status Set(gsl::span input_dims, cudnnDataType_t dataType); + Status Set(gsl::span input_dims, cudnnDataType_t dataType, bool is_nhwc = false); Status Set(const CudnnTensor& x_desc, cudnnBatchNormMode_t mode); // Set 4D tensor format (for NHWC) Status Set(cudnnTensorFormat_t format, cudnnDataType_t dataType, int n, int c, int h, int w); diff --git a/onnxruntime/core/providers/cuda/nn/batch_norm.cc b/onnxruntime/core/providers/cuda/nn/batch_norm.cc index 4f22b5298a30a..c468971e1e426 100644 --- a/onnxruntime/core/providers/cuda/nn/batch_norm.cc +++ b/onnxruntime/core/providers/cuda/nn/batch_norm.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. // Licensed under the MIT License. 
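// NHWC support in this file is a template flag on the existing kernel rather than a separate
// implementation. A minimal sketch of the layout handling, assuming a 4-D input {N, H, W, C}:
//   const int64_t C = x_shape.GetDims()[NHWC ? 3 : 1];  // channel index is 3 in NHWC, 1 in NCHW
// The data descriptor is then built with the extra NHWC flag; with is_nhwc set, CudnnTensor::Set
// swaps dims[1] with dims[rank - 1] (and the matching strides) so cuDNN receives a channel-first
// descriptor that still addresses the NHWC-laid-out buffer.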
#include "batch_norm.h" @@ -11,38 +12,38 @@ using namespace std; namespace onnxruntime { namespace cuda { -#define REGISTER_KERNEL_TYPED(T) \ +#define REGISTER_KERNEL_TYPED(T, DOMAIN, NHWC) \ ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ BatchNormalization, \ - kOnnxDomain, \ + DOMAIN, \ 7, 8, \ T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()) \ .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - BatchNorm); \ + BatchNorm); \ ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ BatchNormalization, \ - kOnnxDomain, \ + DOMAIN, \ 9, 13, \ T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()) \ .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - BatchNorm); \ + BatchNorm); \ ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ BatchNormalization, \ - kOnnxDomain, \ + DOMAIN, \ 14, 14, \ T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()) \ .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ .TypeConstraint("U", DataTypeImpl::GetTensorType()), \ - BatchNorm); \ + BatchNorm); \ ONNX_OPERATOR_TYPED_KERNEL_EX( \ BatchNormalization, \ - kOnnxDomain, \ + DOMAIN, \ 15, \ T, \ kCudaExecutionProvider, \ @@ -50,10 +51,10 @@ namespace cuda { .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ - BatchNorm); + BatchNorm); -template -Status BatchNorm::ComputeInternal(OpKernelContext* p_op_kernel_context) const { +template +Status BatchNorm::ComputeInternal(OpKernelContext* p_op_kernel_context) const { typedef typename ToCudaType::MappedType CudaT; const Tensor* X = p_op_kernel_context->Input(0); @@ -62,7 +63,7 @@ Status BatchNorm::ComputeInternal(OpKernelContext* p_op_kernel_context) const const Tensor* mean = p_op_kernel_context->Input(3); const Tensor* var = p_op_kernel_context->Input(4); - ORT_RETURN_IF_ERROR(BatchNormHelper::ValidateInputs(X, scale, B, mean, var, spatial_ == 1)); + ORT_RETURN_IF_ERROR(BatchNormHelper::ValidateInputs(X, scale, B, mean, var, spatial_ == 1, NHWC)); const TensorShape& x_shape = X->Shape(); const TensorShape& channel_shape = mean->Shape(); @@ -87,7 +88,7 @@ Status BatchNorm::ComputeInternal(OpKernelContext* p_op_kernel_context) const CudnnTensor data_desc; vector new_dims; BatchNormHelper::NormalizeDims(x_shape, new_dims); - ORT_RETURN_IF_ERROR(data_desc.Set(new_dims, CudnnTensor::GetDataType())); + ORT_RETURN_IF_ERROR(data_desc.Set(new_dims, CudnnTensor::GetDataType(), NHWC)); // For half data type, the alpha, beta, scale, B, mean, var need to be float type if (X->IsDataType()) { @@ -97,7 +98,7 @@ Status BatchNorm::ComputeInternal(OpKernelContext* p_op_kernel_context) const ORT_RETURN_IF_ERROR(bn_tensor_desc.Set(data_desc, cudnn_batch_norm_mode_)); // Convert the scale, B, mean, var to float - const int64_t C = x_shape.GetDims()[1]; + const int64_t C = x_shape.GetDims()[NHWC ? 
3 : 1]; auto f_scale = GetScratchBuffer(C, p_op_kernel_context->GetComputeStream()); auto f_B = GetScratchBuffer(C, p_op_kernel_context->GetComputeStream()); auto f_mean = GetScratchBuffer(C, p_op_kernel_context->GetComputeStream()); @@ -175,13 +176,17 @@ Status BatchNorm::ComputeInternal(OpKernelContext* p_op_kernel_context) const return Status::OK(); } -#define SPECIALIZED_COMPUTE(T) \ - REGISTER_KERNEL_TYPED(T) \ - template Status BatchNorm::ComputeInternal(OpKernelContext* ctx) const; +#define SPECIALIZED_COMPUTE(T, DOMAIN, NHWC) \ + REGISTER_KERNEL_TYPED(T, DOMAIN, NHWC) \ + template Status BatchNorm::ComputeInternal(OpKernelContext* ctx) const; -SPECIALIZED_COMPUTE(float) -SPECIALIZED_COMPUTE(double) -SPECIALIZED_COMPUTE(MLFloat16) +SPECIALIZED_COMPUTE(float, kOnnxDomain, false) +SPECIALIZED_COMPUTE(double, kOnnxDomain, false) +SPECIALIZED_COMPUTE(MLFloat16, kOnnxDomain, false) +#ifdef ENABLE_CUDA_NHWC_OPS +SPECIALIZED_COMPUTE(float, kMSInternalNHWCDomain, true) +SPECIALIZED_COMPUTE(MLFloat16, kMSInternalNHWCDomain, true) +#endif } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/nn/batch_norm.h b/onnxruntime/core/providers/cuda/nn/batch_norm.h index 99da7652a1d24..4eb9fb74d3761 100644 --- a/onnxruntime/core/providers/cuda/nn/batch_norm.h +++ b/onnxruntime/core/providers/cuda/nn/batch_norm.h @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. // Licensed under the MIT License. #pragma once @@ -9,7 +10,7 @@ namespace onnxruntime { namespace cuda { -template +template class BatchNorm final : public CudaKernel { public: BatchNorm(const OpKernelInfo& op_kernel_info) diff --git a/onnxruntime/core/providers/cuda/nn/conv.cc b/onnxruntime/core/providers/cuda/nn/conv.cc index 81db3c4186282..82f3503919237 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.cc +++ b/onnxruntime/core/providers/cuda/nn/conv.cc @@ -1,38 +1,47 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. // Licensed under the MIT License. +#include + #include "core/providers/cuda/nn/conv.h" #include "core/common/span_utils.h" #include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/shared_inc/fpgeneric.h" #include "core/providers/cuda/tensor/slice.h" +#include "core/providers/cuda/tensor/transpose.h" namespace onnxruntime { namespace cuda { // Op Set 11 for Conv only update document to clearify default dilations and strides value. // which are already convered by op set 11 cpu versoin, so simply add declaration. 
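// The registration macro below gains DOMAIN and NHWC parameters so the same Conv kernel class can
// be registered twice: once under kOnnxDomain with NHWC=false (NCHW layout) and, when
// ENABLE_CUDA_NHWC_OPS is defined, once under kMSInternalNHWCDomain with NHWC=true. A hedged
// sketch of a single NHWC registration (abbreviated, not a literal macro expansion):
//   ONNX_OPERATOR_TYPED_KERNEL_EX(Conv, kMSInternalNHWCDomain, 11, float, kCudaExecutionProvider,
//                                 ..., Conv<float, true>);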
-#define REGISTER_KERNEL_TYPED(T) \ +#define REGISTER_KERNEL_TYPED(T, DOMAIN, NHWC) \ ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ Conv, \ - kOnnxDomain, \ + DOMAIN, \ 1, 10, \ T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - Conv); \ + Conv); \ ONNX_OPERATOR_TYPED_KERNEL_EX( \ Conv, \ - kOnnxDomain, \ + DOMAIN, \ 11, \ T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - Conv); + Conv); + +REGISTER_KERNEL_TYPED(float, kOnnxDomain, false) +REGISTER_KERNEL_TYPED(double, kOnnxDomain, false) +REGISTER_KERNEL_TYPED(MLFloat16, kOnnxDomain, false) -REGISTER_KERNEL_TYPED(float) -REGISTER_KERNEL_TYPED(double) -REGISTER_KERNEL_TYPED(MLFloat16) +#ifdef ENABLE_CUDA_NHWC_OPS +REGISTER_KERNEL_TYPED(float, kMSInternalNHWCDomain, true) +REGISTER_KERNEL_TYPED(MLFloat16, kMSInternalNHWCDomain, true) +#endif template const cudnnConvolutionFwdAlgo_t Conv::kAllAlgos[] = { @@ -86,6 +95,39 @@ Status SliceOutUnwantedOutputSection(cudaStream_t stream, return SliceCuda::Impl(stream, input_data, input_dims, output_data, compute_metadata, element_size); } +template +Status Conv::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool& is_packed, [[maybe_unused]] PrePackedWeights* prepacked_weights) { + is_packed = false; + // only layout of weight input is adjusted via PrePack + if (NHWC && is_nhwc_domain_) { // InputTensors::IN_W + if (input_idx == 1) { + // Transpose from {M, C/group, kH, kW} to {M, kH, kW, C/group} + auto orig_shape = tensor.Shape(); + + InlinedVector perm{0, 2, 3, 1}; + gsl::span permutation(perm.data(), 4); + TensorShapeVector new_dims{orig_shape[0], + orig_shape[2], + orig_shape[3], + orig_shape[1]}; + W_ = Tensor::Create(tensor.DataType(), TensorShape(new_dims), std::move(alloc)); + + auto status = cuda::Transpose::DoTranspose(GetDeviceProp(), + DefaultCudaStream(), + DefaultCublasHandle(), + permutation, tensor, *W_); + if (!status.IsOK()) { + return status; + } + CUDA_CALL_THROW(cudaStreamSynchronize(DefaultCudaStream())); + is_packed = true; + } + } + + return Status::OK(); +} + template Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) const { // set X @@ -95,7 +137,12 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) s_.x_data = reinterpret_cast(X->Data()); s_.element_size = X->DataType()->Size(); // set W - const Tensor* W = context->Input(1); + const Tensor* W; + if (!W_) { + W = context->Input(1); + } else { + W = W_.get(); + } const TensorShape& w_shape = W->Shape(); auto w_dims = w_shape.AsShapeVector(); s_.w_data = reinterpret_cast(W->Data()); diff --git a/onnxruntime/core/providers/cuda/nn/conv.h b/onnxruntime/core/providers/cuda/nn/conv.h index 07825b93204ca..bcaa4d855b81e 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.h +++ b/onnxruntime/core/providers/cuda/nn/conv.h @@ -1,13 +1,16 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. // Licensed under the MIT License. 
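// For the NHWC variant, Conv::PrePack (defined in conv.cc) transposes the weight once at session
// initialization instead of on every Run. A minimal sketch of the layout change, assuming a 4-D
// weight of shape {M, C/group, kH, kW}:
//   perm = {0, 2, 3, 1};                              // NCHW -> NHWC ordering of the weight dims
//   {M, C/group, kH, kW}  ->  {M, kH, kW, C/group}    // shape of the cached copy held in W_
// During compute the kernel uses W_ (when populated) in place of weight input index 1.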
#pragma once +#include +#include + #include "core/platform/ort_mutex.h" #include "core/providers/cuda/cuda_kernel.h" #include "core/providers/cuda/cudnn_common.h" #include "core/providers/cpu/nn/conv_attributes.h" -#include namespace onnxruntime { @@ -187,8 +190,12 @@ class Conv : public CudaKernel { Conv(const OpKernelInfo& info) : CudaKernel(info), conv_attrs_(info) { auto pads_size = conv_attrs_.pads.size(); ORT_ENFORCE(pads_size % 2 == 0); + is_nhwc_domain_ = info.node().Domain() == kMSInternalNHWCDomain; } + Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool& is_packed, [[maybe_unused]] PrePackedWeights* prepacked_weights) override; + Status ComputeInternal(OpKernelContext* context) const override; protected: @@ -201,6 +208,8 @@ class Conv : public CudaKernel { mutable CudnnConvState s_; constexpr static auto kDefaultConvAlgo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; static const cudnnConvolutionFwdAlgo_t kAllAlgos[]; + std::unique_ptr W_; + bool is_nhwc_domain_; // prepack is only needed for the Conv in kMSInternalNHWCDomain }; Status SliceOutUnwantedOutputSection(cudaStream_t stream, diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc index 2b3326c528659..55dceaa2698e8 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc @@ -1,7 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. // Licensed under the MIT License. +#include + #include "conv_transpose.h" +#include "core/providers/cuda/tensor/transpose.h" // To suppress FP static analyzer warnings: // https://msdata.visualstudio.com/Vienna/_workitems/edit/1944928 and @@ -17,35 +21,59 @@ namespace cuda { // Op Set 11 for ConvTranspose only update document to clarify default dilations and strides value. // which are already covered by op set 11 cpu version, so simply add declaration. 
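// As with Conv, the ConvTranspose registration below is parameterized on DOMAIN and NHWC. In the
// NHWC path the weight is likewise transposed once in PrePack ({M, C/group, kH, kW} ->
// {M, kH, kW, C/group}) and the x/y/w cuDNN descriptors are created with CUDNN_TENSOR_NHWC,
// reading N, C, H, W from dims[0], dims[3], dims[1], dims[2] of the NHWC-ordered shapes.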
-#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - ConvTranspose, \ - kOnnxDomain, \ - 1, 10, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - ConvTranspose); \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - ConvTranspose, \ - kOnnxDomain, \ - 11, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - ConvTranspose); - -REGISTER_KERNEL_TYPED(float) -REGISTER_KERNEL_TYPED(double) -REGISTER_KERNEL_TYPED(MLFloat16) - -template -Status ConvTranspose::ComputeInternal(OpKernelContext* context) const { +#define REGISTER_KERNEL_TYPED(T, DOMAIN, NHWC) \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + ConvTranspose, DOMAIN, 1, 10, T, kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), ConvTranspose); \ + ONNX_OPERATOR_TYPED_KERNEL_EX(ConvTranspose, DOMAIN, 11, T, kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + ConvTranspose); + +REGISTER_KERNEL_TYPED(float, kOnnxDomain, false) +REGISTER_KERNEL_TYPED(double, kOnnxDomain, false) +REGISTER_KERNEL_TYPED(MLFloat16, kOnnxDomain, false) + +#ifdef ENABLE_CUDA_NHWC_OPS +REGISTER_KERNEL_TYPED(float, kMSInternalNHWCDomain, true) +REGISTER_KERNEL_TYPED(MLFloat16, kMSInternalNHWCDomain, true) +#endif + +template +Status ConvTranspose::ComputeInternal(OpKernelContext* context) const { return DoConvTranspose(context, false); } -template -Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_padding) const { +template +Status ConvTranspose::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool& is_packed, + [[maybe_unused]] PrePackedWeights* prepacked_weights) { + is_packed = false; + // only layout of weight input is adjusted via PrePack + if (NHWC) { // InputTensors::IN_W + if (input_idx == 1) { + // Transpose from {M, C/group, kH, kW} to {M, kH, kW, C/group} + auto orig_shape = tensor.Shape(); + + InlinedVector perm{0, 2, 3, 1}; + gsl::span permutation(perm.data(), 4); + TensorShapeVector new_dims{orig_shape[0], orig_shape[2], orig_shape[3], orig_shape[1]}; + W_ = Tensor::Create(tensor.DataType(), TensorShape(new_dims), std::move(alloc)); + + auto status = cuda::Transpose::DoTranspose(GetDeviceProp(), DefaultCudaStream(), DefaultCublasHandle(), + permutation, tensor, *W_); + + if (!status.IsOK()) { + return status; + } + CUDA_CALL_THROW(cudaStreamSynchronize(DefaultCudaStream())); + is_packed = true; + } + } + + return Status::OK(); +} + +template +Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_padding) const { typedef typename ToCudaType::MappedType CudaT; const Tensor* X = context->Input(0); @@ -59,7 +87,12 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input X must be 3-, 4- or 5-dimensional.", " X: ", X->Shape().ToString().c_str()); } - const Tensor* W = context->Input(1); + const Tensor* W; + if (!W_) { + W = context->Input(1); + } else { + W = W_.get(); + } const TensorShape& w_shape = W->Shape(); TensorShapeVector w_dims = w_shape.AsShapeVector(); auto w_data = reinterpret_cast(W->Data()); @@ -80,8 +113,7 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_ bool input_dims_changed = (s_.last_x_dims.AsShapeVector() != x_dims); bool w_dims_changed = (s_.last_w_dims.AsShapeVector() != 
w_dims); if (input_dims_changed || w_dims_changed) { - if (input_dims_changed) - s_.last_x_dims = gsl::make_span(x_dims); + if (input_dims_changed) s_.last_x_dims = gsl::make_span(x_dims); if (w_dims_changed) { s_.last_w_dims = gsl::make_span(w_dims); @@ -89,7 +121,8 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_ } ConvTransposeAttributes::Prepare p; - ORT_RETURN_IF_ERROR(conv_transpose_attrs_.PrepareForCompute(context, has_bias, p, dynamic_padding)); + ORT_RETURN_IF_ERROR( + conv_transpose_attrs_.PrepareForCompute(context, has_bias, p, dynamic_padding, &w_shape, NHWC)); auto y_dims = p.Y->Shape().AsShapeVector(); if (x_dimensions == 3) { @@ -102,8 +135,15 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_ } s_.y_dims = gsl::make_span(y_dims); - if (w_dims_changed) - ORT_RETURN_IF_ERROR(s_.w_desc.Set(w_dims, CudnnTensor::GetDataType())); + if (w_dims_changed) { + if (NHWC) { + ORT_RETURN_IF_ERROR(s_.w_desc.Set(CUDNN_TENSOR_NHWC, CudnnTensor::GetDataType(), + static_cast(w_dims[0]), static_cast(w_dims[3]), + static_cast(w_dims[1]), static_cast(w_dims[2]))); + } else { + ORT_RETURN_IF_ERROR(s_.w_desc.Set(w_dims, CudnnTensor::GetDataType())); + } + } // Special case when there is a dim value of 0 in the shape. // Return only after we have cached the following for subsequent runs : @@ -112,31 +152,39 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_ if (p.Y->Shape().Size() == 0) { return Status::OK(); } - - ORT_RETURN_IF_ERROR(s_.x_tensor.Set(x_dims, CudnnTensor::GetDataType())); - ORT_RETURN_IF_ERROR(s_.y_tensor.Set(y_dims, CudnnTensor::GetDataType())); + if (NHWC) { + ORT_RETURN_IF_ERROR(s_.x_tensor.Set(CUDNN_TENSOR_NHWC, CudnnTensor::GetDataType(), + static_cast(x_dims[0]), static_cast(x_dims[3]), + static_cast(x_dims[1]), static_cast(x_dims[2]))); + ORT_RETURN_IF_ERROR(s_.y_tensor.Set(CUDNN_TENSOR_NHWC, CudnnTensor::GetDataType(), + static_cast(y_dims[0]), static_cast(y_dims[3]), + static_cast(y_dims[1]), static_cast(y_dims[2]))); + } else { + ORT_RETURN_IF_ERROR(s_.x_tensor.Set(x_dims, CudnnTensor::GetDataType())); + ORT_RETURN_IF_ERROR(s_.y_tensor.Set(y_dims, CudnnTensor::GetDataType())); + } cudnnConvolutionMode_t mode = CUDNN_CROSS_CORRELATION; ORT_RETURN_IF_ERROR(s_.conv_desc.Set(p.kernel_shape.size(), p.pads, p.strides, p.dilations, - gsl::narrow_cast(conv_transpose_attrs_.group), - mode, CudnnTensor::GetDataType())); + gsl::narrow_cast(conv_transpose_attrs_.group), mode, + CudnnTensor::GetDataType())); if (has_bias) { const auto& b_shape = p.B->Shape(); ORT_RETURN_IF_NOT(b_shape.NumDimensions() == 1, "bias should be 1D"); TensorShapeVector b_dims(2 + p.kernel_shape.size()); - b_dims[0] = 1; // N - b_dims[1] = b_shape[0]; // C - for (size_t i = 0; i < p.kernel_shape.size(); i++) - b_dims[2 + i] = 1; + b_dims[0] = 1; // N + b_dims[NHWC ? 3 : 1] = b_shape[0]; // C + for (size_t i = 0; i < p.kernel_shape.size(); i++) b_dims[(NHWC ? 
1 : 2) + i] = 1; - ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_dims, CudnnTensor::GetDataType())); + ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_dims, CudnnTensor::GetDataType(), NHWC)); } y_data = reinterpret_cast(p.Y->MutableData()); if (!s_.cached_benchmark_results.contains(x_dims)) { - IAllocatorUniquePtr algo_search_workspace = GetScratchBuffer(AlgoSearchWorkspaceSize, context->GetComputeStream()); + IAllocatorUniquePtr algo_search_workspace = + GetScratchBuffer(AlgoSearchWorkspaceSize, context->GetComputeStream()); // set math type to tensor core before algorithm search if constexpr (std::is_same::value) @@ -145,19 +193,8 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_ cudnnConvolutionBwdDataAlgoPerf_t perf; int algo_count = 1; CUDNN_RETURN_IF_ERROR(cudnnFindConvolutionBackwardDataAlgorithmEx( - GetCudnnHandle(context), - s_.w_desc, - w_data, - s_.x_tensor, - x_data, - s_.conv_desc, - s_.y_tensor, - y_data, - 1, - &algo_count, - &perf, - algo_search_workspace.get(), - AlgoSearchWorkspaceSize)); + GetCudnnHandle(context), s_.w_desc, w_data, s_.x_tensor, x_data, s_.conv_desc, s_.y_tensor, y_data, 1, + &algo_count, &perf, algo_search_workspace.get(), AlgoSearchWorkspaceSize)); s_.cached_benchmark_results.insert(x_dims, {perf.algo, perf.memory, perf.mathType}); } @@ -188,26 +225,15 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_ IAllocatorUniquePtr workspace = GetScratchBuffer(s_.workspace_bytes, context->GetComputeStream()); - CUDNN_RETURN_IF_ERROR( - cudnnConvolutionBackwardData( - GetCudnnHandle(context), - &alpha, - s_.w_desc, - w_data, - s_.x_tensor, - x_data, - s_.conv_desc, - s_.algo, - workspace.get(), - s_.workspace_bytes, - &beta, - s_.y_tensor, - y_data)); + CUDNN_RETURN_IF_ERROR(cudnnConvolutionBackwardData(GetCudnnHandle(context), &alpha, s_.w_desc, w_data, s_.x_tensor, + x_data, s_.conv_desc, s_.algo, workspace.get(), + s_.workspace_bytes, &beta, s_.y_tensor, y_data)); if (has_bias) { const Tensor* B = dynamic_padding ? context->Input(3) : context->Input(2); auto b_data = reinterpret_cast(B->Data()); - CUDNN_RETURN_IF_ERROR(cudnnAddTensor(GetCudnnHandle(context), &alpha, s_.b_tensor, b_data, &alpha, s_.y_tensor, y_data)); + CUDNN_RETURN_IF_ERROR( + cudnnAddTensor(GetCudnnHandle(context), &alpha, s_.b_tensor, b_data, &alpha, s_.y_tensor, y_data)); } } diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose.h b/onnxruntime/core/providers/cuda/nn/conv_transpose.h index 165d548d27fa2..77c9d94162b6b 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_transpose.h +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose.h @@ -1,8 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. // Licensed under the MIT License. 
#pragma once +#include + #include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/cuda_kernel.h" #include "core/providers/cuda/cudnn_common.h" @@ -12,10 +15,12 @@ namespace onnxruntime { namespace cuda { -template +template class ConvTranspose : public CudaKernel { public: ConvTranspose(const OpKernelInfo& info) : CudaKernel(info), conv_transpose_attrs_(info){}; + Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool& is_packed, [[maybe_unused]] PrePackedWeights* prepacked_weights) override; Status ComputeInternal(OpKernelContext* context) const override; Status DoConvTranspose(OpKernelContext* context, bool dynamic_padding) const; @@ -23,6 +28,7 @@ class ConvTranspose : public CudaKernel { ConvTransposeAttributes conv_transpose_attrs_; mutable CudnnConvState s_; + std::unique_ptr W_; }; } // namespace cuda diff --git a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu index 4cc560a1178ef..679b8b6b78886 100644 --- a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu +++ b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu @@ -104,17 +104,17 @@ __device__ void cuWelfordMuSigma2( const int numx = blockDim.x * blockDim.y; const int thrx = threadIdx.x + threadIdx.y * blockDim.x; const T* lvals = vals + i1 * n2; - const T* skip_vals = (skip != NULL) ? skip + i1 * n2 : NULL; + const T* skip_vals = (skip != nullptr) ? skip + i1 * n2 : nullptr; int l = 4 * thrx; for (; l + 3 < n2; l += 4 * numx) { for (int k = 0; k < 4; ++k) { U curr = static_cast(lvals[l + k]); - if (bias != NULL) { + if (bias != nullptr) { curr += static_cast(bias[l + k]); } - if (skip_vals != NULL) { + if (skip_vals != nullptr) { curr += static_cast(skip_vals[l + k]); } @@ -124,11 +124,11 @@ __device__ void cuWelfordMuSigma2( for (; l < n2; ++l) { U curr = static_cast(lvals[l]); - if (bias != NULL) { + if (bias != nullptr) { curr += static_cast(bias[l]); } - if (skip_vals != NULL) { + if (skip_vals != nullptr) { curr += static_cast(skip_vals[l]); } @@ -301,7 +301,7 @@ namespace { // { // extern __device__ void error(void); // error(); -// return NULL; +// return nullptr; // } // }; // https://github.com/NVIDIA/apex/issues/246 @@ -338,9 +338,7 @@ __global__ void cuApplyLayerNorm( const V* __restrict__ beta, const T* __restrict__ skip, const T* __restrict__ bias, - T* __restrict__ skip_input_bias_add_output, - const bool skip_broadcasted, - const int skip_size) { + T* __restrict__ skip_input_bias_add_output) { // Assumptions: // 1) blockDim.x == GPU_WARP_SIZE // 2) Tensors are contiguous @@ -350,38 +348,35 @@ __global__ void cuApplyLayerNorm( U* buf = shared.getPointer(); U mu, sigma2; cuWelfordMuSigma2(vals, n1, n2, i1, mu, sigma2, buf, skip, bias); - const T* lvals = vals + i1 * n2; - const T* skip_vals = (skip != NULL) ? skip + i1 * n2 : NULL; - V* ovals = output_vals + i1 * n2; - T* skip_input_bias_add_ovals = (skip_input_bias_add_output != NULL) ? skip_input_bias_add_output + i1 * n2 : NULL; + const int offset = i1 * n2; + const T* lvals = vals + offset; + const T* skip_vals = (skip != nullptr) ? skip + offset : nullptr; + + V* ovals = output_vals + offset; + T* skip_input_bias_add_ovals = (skip_input_bias_add_output != nullptr) ? 
skip_input_bias_add_output + offset : nullptr; U c_inv_std_dev = rsqrt(sigma2 + epsilon); const int numx = blockDim.x * blockDim.y; const int thrx = threadIdx.x + threadIdx.y * blockDim.x; for (int i = thrx; i < n2; i += numx) { U curr = static_cast(lvals[i]); - - - if (bias != NULL) { + if (bias != nullptr) { curr += static_cast(bias[i]); } - if (skip_vals != NULL && skip_broadcasted) { - int skip_i = i % skip_size; - curr += static_cast(skip_vals[skip_i]); //Calculates index for the second dimension of the skip tensor - }else if (skip_vals != NULL){ + if (skip_vals != nullptr) { curr += static_cast(skip_vals[i]); } - U gamma_i = (gamma != NULL) ? (U)gamma[i] : (U)1; - U beta_i = (beta != NULL) ? (U)beta[i] : (U)0; + U gamma_i = (gamma != nullptr) ? (U)gamma[i] : (U)1; + U beta_i = (beta != nullptr) ? (U)beta[i] : (U)0; if (simplified) { ovals[i] = static_cast(gamma_i * c_inv_std_dev * curr); } else { ovals[i] = static_cast(gamma_i * c_inv_std_dev * (curr - mu) + beta_i); } - if (skip_input_bias_add_ovals != NULL) { + if (skip_input_bias_add_ovals != nullptr) { skip_input_bias_add_ovals[i] = static_cast(curr); } } @@ -418,9 +413,7 @@ void HostApplyLayerNorm( const V* beta, const T* skip, const T* bias, - T* skip_input_bias_add_output, - const bool skip_broadcasted, - const int skip_size) { + T* skip_input_bias_add_output) { const int maxGridY = prop.maxGridSize[1]; const int warp_size = prop.warpSize; ORT_ENFORCE(warp_size == GPU_WARP_SIZE_HOST); @@ -452,17 +445,14 @@ void HostApplyLayerNorm( n1, n2, U(epsilon), gamma, beta, - skip, bias, skip_input_bias_add_output, - skip_broadcasted, - skip_size); + skip, bias, skip_input_bias_add_output); } #define LAYERNORM_LINEAR_IMPL(T, U, V, simplified) \ template void HostApplyLayerNorm(const cudaDeviceProp& prop, cudaStream_t stream, V* output, \ U* mean, U* inv_std_dev, const T* input, int n1, int n2, \ double epsilon, const V* gamma, const V* beta, const T* skip, \ - const T* bias, T* skip_input_bias_add_output, const bool skip_broadcasted, \ - const int skip_size); + const T* bias, T* skip_input_bias_add_output); LAYERNORM_LINEAR_IMPL(float, float, float, true) LAYERNORM_LINEAR_IMPL(half, float, half, true) diff --git a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.h b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.h index d0d5db8ba3587..e3952eefae35d 100644 --- a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.h +++ b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.h @@ -43,9 +43,7 @@ void HostApplyLayerNorm( const V* beta, const T* skip = nullptr, const T* bias = nullptr, - T* skip_input_bias_add_output = nullptr, - const bool skip_broadcasted = false, - const int skip_size = 0); + T* skip_input_bias_add_output = nullptr); } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/nn/pool.cc b/onnxruntime/core/providers/cuda/nn/pool.cc index e632ef20bce43..8bc96958693bc 100644 --- a/onnxruntime/core/providers/cuda/nn/pool.cc +++ b/onnxruntime/core/providers/cuda/nn/pool.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. // Licensed under the MIT License. 
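// A brief sketch of the NHWC handling added to the pooling kernels below, assuming a 4-D input
// {N, H, W, C}: the channel count is read from x_shape[3] instead of x_shape[1], the cuDNN x/y
// descriptors are built with the is_nhwc flag, and 1-D inputs are padded to the 4-D form cuDNN
// expects by inserting the dummy spatial dim at index 1 (NHWC) rather than appending it (NCHW).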
#include "core/providers/shared_library/provider_api.h" @@ -11,92 +12,99 @@ using namespace onnxruntime::common; namespace onnxruntime { namespace cuda { -#define POOLING_KERNEL(op_name, data_type, pool_type, since_version) \ +#define POOLING_KERNEL(op_name, data_type, pool_type, since_version, op_domain, nhwc) \ ONNX_OPERATOR_TYPED_KERNEL_EX( \ - op_name, \ - kOnnxDomain, \ - since_version, \ - data_type, \ - kCudaExecutionProvider, \ + op_name, op_domain, since_version, data_type, kCudaExecutionProvider, \ (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - Pool); - -#define POOLING_KERNEL_VERSIONED(op_name, data_type, pool_type, since_version, end_version) \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - op_name, \ - kOnnxDomain, \ - since_version, \ - end_version, \ - data_type, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - Pool); - -#define POOLING_KERNEL_WITH_INDICES(op_name, data_type, pool_type, since_version) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - op_name, \ - kOnnxDomain, \ - since_version, \ - data_type, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("I", DataTypeImpl::GetTensorType()), \ - Pool); - -#define POOLING_KERNEL_VERSIONED_WITH_INDICES(op_name, data_type, pool_type, since_version, end_version) \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - op_name, \ - kOnnxDomain, \ - since_version, \ - end_version, \ - data_type, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("I", DataTypeImpl::GetTensorType()), \ - Pool); - -POOLING_KERNEL_VERSIONED(AveragePool, float, AveragePool, 7, 9) -POOLING_KERNEL_VERSIONED(AveragePool, double, AveragePool, 7, 9) -POOLING_KERNEL_VERSIONED(AveragePool, MLFloat16, AveragePool, 7, 9) -POOLING_KERNEL_VERSIONED(AveragePool, float, AveragePool, 10, 10) -POOLING_KERNEL_VERSIONED(AveragePool, double, AveragePool, 10, 10) -POOLING_KERNEL_VERSIONED(AveragePool, MLFloat16, AveragePool, 10, 10) + Pool); + +#define POOLING_KERNEL_VERSIONED(op_name, data_type, pool_type, since_version, end_version, op_domain, nhwc) \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + op_name, op_domain, since_version, end_version, data_type, kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + Pool); + +#define POOLING_KERNEL_WITH_INDICES(op_name, data_type, pool_type, since_version, op_domain, nhwc) \ + ONNX_OPERATOR_TYPED_KERNEL_EX(op_name, op_domain, since_version, data_type, kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("I", DataTypeImpl::GetTensorType()), \ + Pool); + +#define POOLING_KERNEL_VERSIONED_WITH_INDICES(op_name, data_type, pool_type, since_version, end_version, op_domain, \ + nhwc) \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(op_name, op_domain, since_version, end_version, data_type, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("I", DataTypeImpl::GetTensorType()), \ + Pool); + +POOLING_KERNEL_VERSIONED(AveragePool, float, AveragePool, 7, 9, kOnnxDomain, false) +POOLING_KERNEL_VERSIONED(AveragePool, double, AveragePool, 7, 9, kOnnxDomain, false) +POOLING_KERNEL_VERSIONED(AveragePool, MLFloat16, AveragePool, 7, 9, kOnnxDomain, false) 
+POOLING_KERNEL_VERSIONED(AveragePool, float, AveragePool, 10, 10, kOnnxDomain, false) +POOLING_KERNEL_VERSIONED(AveragePool, double, AveragePool, 10, 10, kOnnxDomain, false) +POOLING_KERNEL_VERSIONED(AveragePool, MLFloat16, AveragePool, 10, 10, kOnnxDomain, false) // AveragePool and MaxPool op set 11 only update spec document on default value for dilations and strides. -POOLING_KERNEL(AveragePool, float, AveragePool, 11) -POOLING_KERNEL(AveragePool, double, AveragePool, 11) -POOLING_KERNEL(AveragePool, MLFloat16, AveragePool, 11) -POOLING_KERNEL(GlobalAveragePool, float, AveragePool, 1) -POOLING_KERNEL(GlobalAveragePool, double, AveragePool, 1) -POOLING_KERNEL(GlobalAveragePool, MLFloat16, AveragePool, 1) -POOLING_KERNEL_VERSIONED(MaxPool, float, MaxPool<1>, 1, 7) -POOLING_KERNEL_VERSIONED(MaxPool, double, MaxPool<1>, 1, 7) -POOLING_KERNEL_VERSIONED(MaxPool, MLFloat16, MaxPool<1>, 1, 7) -POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, float, MaxPool<8>, 8, 9) -POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, double, MaxPool<8>, 8, 9) -POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, MLFloat16, MaxPool<8>, 8, 9) -POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, float, MaxPool<8>, 10, 10) -POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, double, MaxPool<8>, 10, 10) -POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, MLFloat16, MaxPool<8>, 10, 10) -POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, float, MaxPool<8>, 11, 11) -POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, double, MaxPool<8>, 11, 11) -POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, MLFloat16, MaxPool<8>, 11, 11) -POOLING_KERNEL_WITH_INDICES(MaxPool, float, MaxPool<8>, 12) -POOLING_KERNEL_WITH_INDICES(MaxPool, double, MaxPool<8>, 12) -POOLING_KERNEL_WITH_INDICES(MaxPool, MLFloat16, MaxPool<8>, 12) -POOLING_KERNEL_WITH_INDICES(MaxPool, int8_t, MaxPool<8>, 12) -POOLING_KERNEL_WITH_INDICES(MaxPool, uint8_t, MaxPool<8>, 12) - -POOLING_KERNEL(GlobalMaxPool, float, MaxPool<1>, 1) -POOLING_KERNEL(GlobalMaxPool, double, MaxPool<1>, 1) -POOLING_KERNEL(GlobalMaxPool, MLFloat16, MaxPool<1>, 1) +POOLING_KERNEL(AveragePool, float, AveragePool, 11, kOnnxDomain, false) +POOLING_KERNEL(AveragePool, double, AveragePool, 11, kOnnxDomain, false) +POOLING_KERNEL(AveragePool, MLFloat16, AveragePool, 11, kOnnxDomain, false) +POOLING_KERNEL(GlobalAveragePool, float, AveragePool, 1, kOnnxDomain, false) +POOLING_KERNEL(GlobalAveragePool, double, AveragePool, 1, kOnnxDomain, false) +POOLING_KERNEL(GlobalAveragePool, MLFloat16, AveragePool, 1, kOnnxDomain, false) +POOLING_KERNEL_VERSIONED(MaxPool, float, MaxPool<1>, 1, 7, kOnnxDomain, false) +POOLING_KERNEL_VERSIONED(MaxPool, double, MaxPool<1>, 1, 7, kOnnxDomain, false) +POOLING_KERNEL_VERSIONED(MaxPool, MLFloat16, MaxPool<1>, 1, 7, kOnnxDomain, false) +POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, float, MaxPool<8>, 8, 9, kOnnxDomain, false) +POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, double, MaxPool<8>, 8, 9, kOnnxDomain, false) +POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, MLFloat16, MaxPool<8>, 8, 9, kOnnxDomain, false) +POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, float, MaxPool<8>, 10, 10, kOnnxDomain, false) +POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, double, MaxPool<8>, 10, 10, kOnnxDomain, false) +POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, MLFloat16, MaxPool<8>, 10, 10, kOnnxDomain, false) +POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, float, MaxPool<8>, 11, 11, kOnnxDomain, false) +POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, double, MaxPool<8>, 11, 11, kOnnxDomain, false) 
+POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, MLFloat16, MaxPool<8>, 11, 11, kOnnxDomain, false) +POOLING_KERNEL_WITH_INDICES(MaxPool, float, MaxPool<8>, 12, kOnnxDomain, false) +POOLING_KERNEL_WITH_INDICES(MaxPool, double, MaxPool<8>, 12, kOnnxDomain, false) +POOLING_KERNEL_WITH_INDICES(MaxPool, MLFloat16, MaxPool<8>, 12, kOnnxDomain, false) +POOLING_KERNEL_WITH_INDICES(MaxPool, int8_t, MaxPool<8>, 12, kOnnxDomain, false) +POOLING_KERNEL_WITH_INDICES(MaxPool, uint8_t, MaxPool<8>, 12, kOnnxDomain, false) + +POOLING_KERNEL(GlobalMaxPool, float, MaxPool<1>, 1, kOnnxDomain, false) +POOLING_KERNEL(GlobalMaxPool, double, MaxPool<1>, 1, kOnnxDomain, false) +POOLING_KERNEL(GlobalMaxPool, MLFloat16, MaxPool<1>, 1, kOnnxDomain, false) + +// NHWC variants +#ifdef ENABLE_CUDA_NHWC_OPS +POOLING_KERNEL_VERSIONED(MaxPool, float, MaxPool<1>, 1, 7, kMSInternalNHWCDomain, true) +POOLING_KERNEL_VERSIONED(MaxPool, MLFloat16, MaxPool<1>, 1, 7, kMSInternalNHWCDomain, true) +POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, float, MaxPool<8>, 8, 9, kMSInternalNHWCDomain, true) +POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, MLFloat16, MaxPool<8>, 8, 9, kMSInternalNHWCDomain, true) +POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, float, MaxPool<8>, 10, 10, kMSInternalNHWCDomain, true) +POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, MLFloat16, MaxPool<8>, 10, 10, kMSInternalNHWCDomain, true) +POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, float, MaxPool<8>, 11, 11, kMSInternalNHWCDomain, true) +POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, MLFloat16, MaxPool<8>, 11, 11, kMSInternalNHWCDomain, true) +POOLING_KERNEL_WITH_INDICES(MaxPool, float, MaxPool<8>, 12, kMSInternalNHWCDomain, true) +POOLING_KERNEL_WITH_INDICES(MaxPool, MLFloat16, MaxPool<8>, 12, kMSInternalNHWCDomain, true) + +POOLING_KERNEL(GlobalMaxPool, float, MaxPool<1>, 1, kMSInternalNHWCDomain, true) +POOLING_KERNEL(GlobalMaxPool, MLFloat16, MaxPool<1>, 1, kMSInternalNHWCDomain, true) + +POOLING_KERNEL_VERSIONED(AveragePool, float, AveragePool, 7, 9, kMSInternalNHWCDomain, true) +POOLING_KERNEL_VERSIONED(AveragePool, MLFloat16, AveragePool, 7, 9, kMSInternalNHWCDomain, true) +POOLING_KERNEL_VERSIONED(AveragePool, float, AveragePool, 10, 10, kMSInternalNHWCDomain, true) +POOLING_KERNEL_VERSIONED(AveragePool, MLFloat16, AveragePool, 10, 10, kMSInternalNHWCDomain, true) +// AveragePool and MaxPool op set 11 only update spec document on default value for dilations +POOLING_KERNEL(AveragePool, float, AveragePool, 11, kMSInternalNHWCDomain, true) +POOLING_KERNEL(AveragePool, MLFloat16, AveragePool, 11, kMSInternalNHWCDomain, true) +POOLING_KERNEL(GlobalAveragePool, float, AveragePool, 1, kMSInternalNHWCDomain, true) +POOLING_KERNEL(GlobalAveragePool, MLFloat16, AveragePool, 1, kMSInternalNHWCDomain, true) +#endif class CudnnPoolingDescriptor final { public: - CudnnPoolingDescriptor() : desc_(nullptr) { - } + CudnnPoolingDescriptor() : desc_(nullptr) {} ~CudnnPoolingDescriptor() { if (desc_ != nullptr) { @@ -108,12 +116,9 @@ class CudnnPoolingDescriptor final { CudnnPoolingDescriptor(const CudnnPoolingDescriptor&) = delete; CudnnPoolingDescriptor& operator=(const CudnnPoolingDescriptor&) = delete; - Status Set(cudnnPoolingMode_t mode, - const gsl::span& kernel_shape, - const gsl::span& pads, - const gsl::span& strides) { - if (!desc_) - CUDNN_RETURN_IF_ERROR(cudnnCreatePoolingDescriptor(&desc_)); + Status Set(cudnnPoolingMode_t mode, const gsl::span& kernel_shape, + const gsl::span& pads, const gsl::span& strides) { + if (!desc_) 
CUDNN_RETURN_IF_ERROR(cudnnCreatePoolingDescriptor(&desc_)); int rank = gsl::narrow_cast(kernel_shape.size()); InlinedVector window(rank); @@ -128,14 +133,8 @@ class CudnnPoolingDescriptor final { for (int i = 0; i < rank; i++) { stride[i] = gsl::narrow_cast(strides[i]); } - CUDNN_RETURN_IF_ERROR(SetPoolingNdDescriptorHelper( - desc_, - mode, - CUDNN_PROPAGATE_NAN, - rank, - window.data(), - padding.data(), - stride.data())); + CUDNN_RETURN_IF_ERROR(SetPoolingNdDescriptorHelper(desc_, mode, CUDNN_PROPAGATE_NAN, rank, window.data(), + padding.data(), stride.data())); return Status::OK(); } @@ -146,8 +145,8 @@ class CudnnPoolingDescriptor final { cudnnPoolingDescriptor_t desc_; }; -template -Status Pool::ComputeInternal(OpKernelContext* context) const { +template +Status Pool::ComputeInternal(OpKernelContext* context) const { typedef typename ToCudaType::MappedType CudaT; const Tensor* X = context->Input(0); const TensorShape& x_shape = X->Shape(); @@ -166,13 +165,12 @@ Status Pool::ComputeInternal(OpKernelContext* context) const { pads.assign(kernel_shape.size(), 0); strides.assign(kernel_shape.size(), 1); } - - auto y_dims = pool_attrs_.SetOutputSize(x_shape, x_shape[1], &pads); + auto out_channel = NHWC ? x_shape[3] : x_shape[1]; + auto y_dims = pool_attrs_.SetOutputSize(x_shape, out_channel, &pads, NHWC); TensorShape y_shape(y_dims); Tensor* Y = context->Output(0, y_shape); // special case when there is a dim value of 0 in the shape. - if (y_shape.Size() == 0) - return Status::OK(); + if (y_shape.Size() == 0) return Status::OK(); auto x_data = reinterpret_cast(X->Data()); auto y_data = reinterpret_cast(Y->MutableData()); @@ -181,12 +179,19 @@ Status Pool::ComputeInternal(OpKernelContext* context) const { TensorShapeVector y_dims_cudnn(y_dims); if (kernel_shape.size() < 2) { // cudnn only takes 4D or 5D input, so pad dimensions if needed - x_dims_cudnn.push_back(1); - y_dims_cudnn.push_back(1); + if (NHWC) { + x_dims_cudnn.insert(x_dims_cudnn.begin() + 1, 1); + y_dims_cudnn.insert(y_dims_cudnn.begin() + 1, 1); + kernel_shape.insert(kernel_shape.begin() + 1, 1); + strides.insert(strides.begin() + 1, 1); + } else { + x_dims_cudnn.push_back(1); + y_dims_cudnn.push_back(1); + kernel_shape.push_back(1); + strides.push_back(1); + } pads.insert(pads.begin() + kernel_shape.size(), 0); pads.insert(pads.end(), 0); - kernel_shape.push_back(1); - strides.push_back(1); } cudnnPoolingMode_t mode = CUDNN_POOLING_MAX; @@ -203,8 +208,8 @@ Status Pool::ComputeInternal(OpKernelContext* context) const { const auto beta = Consts::Zero; CudnnTensor x_tensor; CudnnTensor y_tensor; - ORT_RETURN_IF_ERROR(x_tensor.Set(x_dims_cudnn, CudnnTensor::GetDataType())); - ORT_RETURN_IF_ERROR(y_tensor.Set(y_dims_cudnn, CudnnTensor::GetDataType())); + ORT_RETURN_IF_ERROR(x_tensor.Set(x_dims_cudnn, CudnnTensor::GetDataType(), NHWC)); + ORT_RETURN_IF_ERROR(y_tensor.Set(y_dims_cudnn, CudnnTensor::GetDataType(), NHWC)); const auto input_count = x_shape.Size(); const auto output_count = y_shape.Size(); @@ -212,24 +217,26 @@ Status Pool::ComputeInternal(OpKernelContext* context) const { IAllocatorUniquePtr temp_X = GetScratchBuffer(input_count, context->GetComputeStream()); auto temp_Y = GetScratchBuffer(output_count, context->GetComputeStream()); Impl_Cast(Stream(context), reinterpret_cast(x_data), temp_X.get(), input_count); - CUDNN_RETURN_IF_ERROR(PoolingForwardHelper(GetCudnnHandle(context), pooling_desc, &alpha, x_tensor, temp_X.get(), &beta, y_tensor, temp_Y.get())); + 
CUDNN_RETURN_IF_ERROR(PoolingForwardHelper(GetCudnnHandle(context), pooling_desc, &alpha, x_tensor, temp_X.get(), + &beta, y_tensor, temp_Y.get())); Impl_Cast(Stream(context), temp_Y.get(), y_data, output_count); } else { const auto alpha = Consts::One; const auto beta = Consts::Zero; CudnnTensor x_tensor; CudnnTensor y_tensor; - ORT_RETURN_IF_ERROR(x_tensor.Set(x_dims_cudnn, CudnnTensor::GetDataType())); - ORT_RETURN_IF_ERROR(y_tensor.Set(y_dims_cudnn, CudnnTensor::GetDataType())); + ORT_RETURN_IF_ERROR(x_tensor.Set(x_dims_cudnn, CudnnTensor::GetDataType(), NHWC)); + ORT_RETURN_IF_ERROR(y_tensor.Set(y_dims_cudnn, CudnnTensor::GetDataType(), NHWC)); - CUDNN_RETURN_IF_ERROR(PoolingForwardHelper(GetCudnnHandle(context), pooling_desc, &alpha, x_tensor, x_data, &beta, y_tensor, y_data)); + CUDNN_RETURN_IF_ERROR( + PoolingForwardHelper(GetCudnnHandle(context), pooling_desc, &alpha, x_tensor, x_data, &beta, y_tensor, y_data)); } return Status::OK(); } -template -Status Pool>::ComputeInternal(OpKernelContext* context) const { +template +Status Pool, NHWC>::ComputeInternal(OpKernelContext* context) const { typedef typename ToCudaType::MappedType CudaT; const Tensor* X = context->Input(0); const TensorShape& x_shape = X->Shape(); @@ -248,13 +255,12 @@ Status Pool>::ComputeInternal(OpKernelContext* context) const { pads.assign(kernel_shape.size(), 0); strides.assign(kernel_shape.size(), 1); } - - auto y_dims = this->pool_attrs_.SetOutputSize(x_shape, x_shape[1], &pads); + auto out_channel = NHWC ? x_shape[3] : x_shape[1]; + auto y_dims = this->pool_attrs_.SetOutputSize(x_shape, out_channel, &pads, NHWC); Tensor* Y = context->Output(0, TensorShape(y_dims)); // special case when there is a dim value of 0 in the shape. - if (Y->Shape().Size() == 0) - return Status::OK(); + if (Y->Shape().Size() == 0) return Status::OK(); auto x_data = reinterpret_cast(X->Data()); auto y_data = reinterpret_cast(Y->MutableData()); @@ -262,20 +268,10 @@ Status Pool>::ComputeInternal(OpKernelContext* context) const { Tensor* I = context->Output(1, TensorShape(y_dims)); if (nullptr != I || !this->pool_attrs_.default_dilations) { auto i_data = nullptr == I ? nullptr : I->MutableData(); - MaxPoolWithIndex( - this->Stream(context), - x_shape, - TensorShape(y_dims), - kernel_shape, - strides, - pads, - this->pool_attrs_.dilations, - this->pool_attrs_.storage_order, - x_data, - y_data, - i_data); + MaxPoolWithIndex(this->Stream(context), x_shape, TensorShape(y_dims), kernel_shape, strides, pads, + this->pool_attrs_.dilations, this->pool_attrs_.storage_order, x_data, y_data, i_data); } else { - ORT_RETURN_IF_ERROR((Pool>::ComputeInternal(context))); + ORT_RETURN_IF_ERROR((Pool, NHWC>::ComputeInternal(context))); } return Status::OK(); } diff --git a/onnxruntime/core/providers/cuda/nn/pool.h b/onnxruntime/core/providers/cuda/nn/pool.h index fb223c18d2625..8b5152a1565a9 100644 --- a/onnxruntime/core/providers/cuda/nn/pool.h +++ b/onnxruntime/core/providers/cuda/nn/pool.h @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. // Licensed under the MIT License. 
#pragma once @@ -10,7 +11,7 @@ namespace onnxruntime { namespace cuda { -template +template class Pool : public CudaKernel, public PoolBase { public: Pool(const OpKernelInfo& info) : CudaKernel(info), PoolBase(info) {} @@ -18,10 +19,10 @@ class Pool : public CudaKernel, public PoolBase { Status ComputeInternal(OpKernelContext* context) const override; }; -template -class Pool> final : public Pool> { +template +class Pool, NHWC> final : public Pool, NHWC> { public: - Pool(const OpKernelInfo& info) : Pool>(info) {} + explicit Pool(const OpKernelInfo& info) : Pool, NHWC>(info) {} Status ComputeInternal(OpKernelContext* context) const override; }; diff --git a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc index 2f057d53d5607..bc78e577c5052 100644 --- a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc @@ -16,140 +16,29 @@ using namespace onnxruntime::common; namespace onnxruntime { namespace cuda { -// opset 11 explicitly added support for negative axis. implementation already allowed it. -#define REGISTER_KERNEL_TYPED(name, T) \ +#define REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(name, T, end) \ ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ name, \ kOnnxDomain, \ - 1, 10, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 11, 12, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 13, \ + 1, end, \ T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ name); -#define REGISTER_KERNEL_VERSIONED_TYPED_12(name, T) \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 1, 10, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 11, 11, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 12, 12, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ +#define REGISTER_KERNEL_TYPED_AXES_INPUT(name, T, version) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + name, \ + kOnnxDomain, \ + version, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()).InputMemoryType(OrtMemTypeCPUInput, 1), \ name); -// Register those with changes in OpSet12. 
-#define REGISTER_KERNEL_TYPED_13_WITH_VERSIONED_12(name, T) \ - REGISTER_KERNEL_VERSIONED_TYPED_12(name, T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 13, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); - -#define REGISTER_KERNEL_VERSIONED_TYPED_13(name, T) \ - REGISTER_KERNEL_VERSIONED_TYPED_12(name, T) \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 13, 13, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); - -// Register ReduceMin int64_t support in OpSet14. -#define REGISTER_KERNEL_TYPED_14(name, T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 14, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); - -// CUDA ArgMax/ArgMin doesn't have OpSet12+ implementation (with select_last_index attr) yet -#define REGISTER_KERNEL_VERSIONED_TYPED_11(name, T) \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 1, 10, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 11, 11, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); - -// Register with the latest version 13 -#define REGISTER_KERNEL_TYPED_13(name, T) \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 1, 10, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 11, 12, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 13, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .InputMemoryType(OrtMemTypeCPUInput, 1) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); +#define REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(name, T, last, cur) \ + REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(name, T, last) \ + REGISTER_KERNEL_TYPED_AXES_INPUT(name, T, cur) // TODO ReduceKernel::ReduceKernelShared() is still used by some other training classes though it's not used here - this should be refactored. 
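// The macros above replace the per-opset registration macros that were removed: the versioned
// macro covers opsets where "axes" is an attribute, while REGISTER_KERNEL_TYPED_AXES_INPUT covers
// the newer form where "axes" arrives as a second input. InputMemoryType(OrtMemTypeCPUInput, 1)
// requests that this axes input be placed in CPU memory so the kernel can read the values on the
// host without a device-to-host copy at compute time.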
template @@ -725,6 +614,30 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, return Status::OK(); } +template Status ReduceComputeCore( + const AllocatorPtr& gpu_allocator, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata, + /*out*/ Tensor& output, cudnnReduceTensorOp_t cudnn_reduce_op, + gsl::span axes, + bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction, + Stream* ort_stream, + const TensorShape* input_shape_override); + +template Status ReduceComputeCore( + const AllocatorPtr& gpu_allocator, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata, + /*out*/ Tensor& output, cudnnReduceTensorOp_t cudnn_reduce_op, + gsl::span axes, + bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction, + Stream* ort_stream, + const TensorShape* input_shape_override); + +template Status ReduceComputeCore( + const AllocatorPtr& gpu_allocator, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata, + /*out*/ Tensor& output, cudnnReduceTensorOp_t cudnn_reduce_op, + gsl::span axes, + bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction, + Stream* ort_stream, + const TensorShape* input_shape_override); + template template Status ReduceKernel::ComputeImpl(OpKernelContext* ctx, cudnnReduceTensorOp_t cudnn_reduce_op) const { @@ -917,69 +830,76 @@ template std::unique_ptr ReduceComputeShape(); + +#ifdef ENABLE_STRIDED_TENSORS + // Strided output. + if (input_data_tensor->DataRaw() == output_tensor->DataRaw()) { + gsl::span input_strides = input_data_tensor->Strides(); + TensorShapeVector output_strides = + ComputeOutputStrides(input_data_tensor->Shape(), input_strides, output_shape); + output_tensor->SetShapeAndStrides(output_shape, output_strides); + return Status::OK(); + } +#endif + + auto output_dims = output_shape.AsShapeVector(); + auto input_dims = input_data_tensor->Shape().AsShapeVector(); + + CalcEffectiveDims(input_dims, output_dims); + int rank = gsl::narrow_cast(output_dims.size()); + + TensorPitches original_input_strides(input_dims); + TensorPitches original_output_strides(output_dims); + + TArray input_strides(rank); + for (auto i = 0; i < rank; i++) { + input_strides[i] = input_dims[i] == 1 ? 0 : original_input_strides[i]; + } + + TArray output_strides(rank); + for (auto i = 0; i < rank; i++) { + output_strides[i] = fast_divmod(static_cast(original_output_strides[i])); + } + + return ExpandImpl( + cuda_kernel->Stream(ctx), + input_data_tensor->DataType()->Size(), + gsl::narrow_cast(output_shape.Size()), + gsl::narrow_cast(input_data_tensor->Shape().Size()), + input_data_tensor->DataRaw(), + output_tensor->MutableDataRaw(), + output_strides, + input_strides); +} + +std::unique_ptr FuncExpand( + const CudaKernel* cuda_kernel, + OpKernelContext* ctx, + const Tensor* input_data_tensor, + const Tensor* input_shape_tensor) { + // new shape to be expanded to + const auto* p_shape = input_shape_tensor->Data(); + TensorShapeVector output_dims{p_shape, p_shape + input_shape_tensor->Shape().Size()}; + TensorShape output_shape(output_dims); + + ORT_ENFORCE( + ComputeOutputShape( + cuda_kernel->Node().Name(), + input_data_tensor->Shape(), + output_dims, output_shape) + .IsOK()); + + // Pre-allocate output. 
+ AllocatorPtr alloc; + ORT_ENFORCE(ctx->GetTempSpaceAllocator(&alloc).IsOK()); + auto output_tensor = Tensor::Create(input_data_tensor->DataType(), output_shape, alloc); + + // Only assign output values when output tensor is non-empty + // because empty tensor doesn't own any data. + if (output_shape.Size() > 0) { + ORT_ENFORCE(FuncExpand(cuda_kernel, ctx, input_data_tensor, input_shape_tensor, output_tensor.get()).IsOK()); + } + + return output_tensor; +} + #ifdef ENABLE_STRIDED_TENSORS #define CREATE_EXPAND_KERNEL_DEF (*KernelDefBuilder::Create()).MayStridedOutput(0, 0) #else diff --git a/onnxruntime/core/providers/cuda/tensor/expand.h b/onnxruntime/core/providers/cuda/tensor/expand.h index 4cf4c14e61058..a0b12790017f6 100644 --- a/onnxruntime/core/providers/cuda/tensor/expand.h +++ b/onnxruntime/core/providers/cuda/tensor/expand.h @@ -20,5 +20,18 @@ Status ComputeOutputShape( const TensorShape& rhs_shape, TensorShape& out_shape); +Status FuncExpand( + const CudaKernel* cuda_kernel, + OpKernelContext* ctx, + const Tensor* input_data_tensor, + const Tensor* /*input_shape_tensor*/, + Tensor* output_tensor); + +std::unique_ptr FuncExpand( + const CudaKernel* cuda_kernel, + OpKernelContext* ctx, + const Tensor* input_data_tensor, + const Tensor* input_shape_tensor); + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/reshape.cc b/onnxruntime/core/providers/cuda/tensor/reshape.cc index 3c6d900cee9a4..ab364c274a32d 100644 --- a/onnxruntime/core/providers/cuda/tensor/reshape.cc +++ b/onnxruntime/core/providers/cuda/tensor/reshape.cc @@ -6,6 +6,81 @@ namespace onnxruntime { namespace cuda { +TensorShape InferReshapeOutputShape( + const TensorShape& data_tensor_shape, // Data tensor's shape. + const gsl::span& shape_span, // Shape that data tensor reshape to. + bool allow_zero) { + TensorShapeVector shape_vector(shape_span.begin(), shape_span.end()); + ReshapeHelper helper(data_tensor_shape, shape_vector, allow_zero); + return TensorShape(shape_vector); +} + +TensorShape InferReshapeOutputShape(const Tensor* src, const Tensor* shape, bool allow_zero) { + ORT_ENFORCE(shape != nullptr, "Cannot reshape to a null shape."); + ORT_ENFORCE(shape->Shape().NumDimensions() == 1, "Shape must be an 1-D tensor."); + ORT_ENFORCE(shape->Location().device.Type() == OrtDevice::CPU, "Shape must be on CPU."); + + return InferReshapeOutputShape( + src->Shape(), + shape->template DataAsSpan(), + allow_zero); +} + +Status FuncReshape( + const CudaKernel* cuda_kernel, + OpKernelContext* ctx, + const Tensor* X, + const Tensor* shape, + const bool /*allow_zero*/, + Tensor* Y) { + if (!X) return Status(common::ONNXRUNTIME, common::FAIL, "Missing data tensor to be reshaped."); + if (!shape) return Status(common::ONNXRUNTIME, common::FAIL, "Missing shape tensor for reshaping."); + if (shape->Shape().NumDimensions() != 1) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, FAIL, "The shape tensor for reshaping must be a vector, but got ", shape->Shape(), "."); + } + if (shape->Location().device.Type() != OrtDevice::CPU) { + return Status(common::ONNXRUNTIME, common::FAIL, "Shape tensor must be on CPU."); + } + + const void* src_data = X->DataRaw(); + void* dst_data = Y->MutableDataRaw(); + // If source and target pointers are not equal (non-inplace operation), we need to copy the data. 
+ if (src_data != dst_data) { + ORT_ENFORCE(ctx->GetComputeStream()); + ORT_RETURN_IF_ERROR(cuda_kernel->CopyTensor(*X, *Y, *ctx->GetComputeStream())); + } + + return Status::OK(); +} + +std::unique_ptr FuncReshape( + const CudaKernel* cuda_kernel, + OpKernelContext* ctx, + const Tensor* X, + const Tensor* shape, + const bool allow_zero) { + // TODO(wechi): Study if Tensor can be created as view to existing tensor. + // This feature can refine code for re-sharding and shape broadcasting. + + ORT_ENFORCE(X != nullptr, "Missing data tensor to be reshaped."); + ORT_ENFORCE(shape != nullptr, "Missing shape tensor for reshaping."); + ORT_ENFORCE(shape->Shape().NumDimensions() == 1, "The shape tensor for reshaping must be a vector, but got ", shape->Shape(), "."); + ORT_ENFORCE(shape->Location().device.Type() == OrtDevice::CPU, "Shape tensor must be on CPU."); + + // Calculate output's shape. + auto dst_shape = InferReshapeOutputShape(X, shape, allow_zero); + + // Pre-allocate output. + AllocatorPtr alloc; + ORT_ENFORCE(ctx->GetTempSpaceAllocator(&alloc).IsOK()); + auto Y = Tensor::Create(X->DataType(), dst_shape, alloc); + + // Do reshape. It's equivalent to memcpy. + ORT_ENFORCE(FuncReshape(cuda_kernel, ctx, X, shape, allow_zero, Y.get()).IsOK()); + return Y; +} + ONNX_OPERATOR_KERNEL_EX( Reshape, kOnnxDomain, diff --git a/onnxruntime/core/providers/cuda/tensor/reshape.h b/onnxruntime/core/providers/cuda/tensor/reshape.h index 01e933e65888f..8f33265071ed3 100644 --- a/onnxruntime/core/providers/cuda/tensor/reshape.h +++ b/onnxruntime/core/providers/cuda/tensor/reshape.h @@ -10,6 +10,39 @@ namespace onnxruntime { namespace cuda { +// Deduce output shape from ONNX Reshape's inputs. +// +// Arguments: +// data_tensor_shape: The shape of the data tensor (i.e., 1st input). +// shape_span: Elements in the shape tensor (i.e., 2nd input). +// +// Returns: +// The output shape of this Reshape. No symbolic values such as "-1" or "0". 
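+//
+// Worked example (standard ONNX Reshape semantics, allow_zero == false): a data tensor of shape
+// [2, 3, 4] reshaped with shape input [0, -1] yields [2, 12]: 0 copies the corresponding input
+// dimension and -1 is inferred from the remaining element count.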
+TensorShape InferReshapeOutputShape( + const TensorShape& data_tensor_shape, + const gsl::span& shape_span, + bool allow_zero); + +TensorShape InferReshapeOutputShape( + const Tensor* src, + const Tensor* shape, + bool allow_zero); + +Status FuncReshape( + const CudaKernel* cuda_kernel, + OpKernelContext* ctx, + const Tensor* X, + const Tensor* shape, + const bool /*allow_zero*/, + Tensor* Y); + +std::unique_ptr FuncReshape( + const CudaKernel* cuda_kernel, + OpKernelContext* ctx, + const Tensor* X, + const Tensor* shape, + const bool allow_zero); + class Reshape final : public CudaKernel { public: Reshape(const OpKernelInfo& info) : CudaKernel(info), @@ -18,27 +51,11 @@ class Reshape final : public CudaKernel { Status ComputeInternal(OpKernelContext* context) const override { // Copy the second input tensor into the shape vector - const Tensor* shapeTensor = context->Input(1); - if (shapeTensor == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); - if (shapeTensor->Shape().NumDimensions() != 1) return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "A shape tensor must be a vector tensor, got ", shapeTensor->Shape().NumDimensions(), " dimensions"); - auto data_span = shapeTensor->template DataAsSpan(); - TensorShapeVector shape(data_span.begin(), data_span.end()); - const Tensor* X = context->Input(0); - if (X == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); - const TensorShape& X_shape = X->Shape(); - - ReshapeHelper helper(X_shape, shape, allow_zero_); - - Tensor* Y = context->Output(0, TensorShape(shape)); - const void* source = X->DataRaw(); - void* target = Y->MutableDataRaw(); - // If source and target pointers are not equal (non-inplace operation), we need to copy the data. - if (target != source) { - ORT_ENFORCE(context->GetComputeStream()); - ORT_RETURN_IF_ERROR(CopyTensor(*X, *Y, *context->GetComputeStream())); - } - - return Status::OK(); + const Tensor* data_tensor = context->Input(0); + const Tensor* shape_tensor = context->Input(1); + const auto target_shape = InferReshapeOutputShape(data_tensor, shape_tensor, allow_zero_); + Tensor* output_tensor = context->Output(0, target_shape); + return FuncReshape(this, context, data_tensor, shape_tensor, allow_zero_, output_tensor); } private: diff --git a/onnxruntime/core/providers/cuda/tensor/where.cc b/onnxruntime/core/providers/cuda/tensor/where.cc index b3f92c913a84b..4d98b3a8a145e 100644 --- a/onnxruntime/core/providers/cuda/tensor/where.cc +++ b/onnxruntime/core/providers/cuda/tensor/where.cc @@ -216,5 +216,6 @@ SPECIALIZED_COMPUTE(int64_t) SPECIALIZED_COMPUTE(float) SPECIALIZED_COMPUTE(double_t) SPECIALIZED_COMPUTE(MLFloat16) +SPECIALIZED_COMPUTE(BFloat16) } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/where_impl.cu b/onnxruntime/core/providers/cuda/tensor/where_impl.cu index 0fbd2062ca1b2..d7909454e922c 100644 --- a/onnxruntime/core/providers/cuda/tensor/where_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/where_impl.cu @@ -238,6 +238,7 @@ SPECIALIZED_IMPL(int64_t) SPECIALIZED_IMPL(float) SPECIALIZED_IMPL(double_t) SPECIALIZED_IMPL(half) +SPECIALIZED_IMPL(BFloat16) } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h index 52018500b134c..cdb0338157561 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h +++ 
b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h @@ -3,6 +3,9 @@ #pragma once interface IMLOperatorRegistry; +interface IDMLDevice; +interface ID3D12CommandQueue; +interface ID3D12Resource; #include "core/common/status.h" #include "core/framework/data_transfer.h" @@ -28,7 +31,8 @@ namespace Dml std::unique_ptr CreateExecutionProvider( IDMLDevice* dmlDevice, ID3D12CommandQueue* commandQueue, - bool enableMetacommands = true); + bool enableMetacommands, + bool enableDynamicGraphFusion); ID3D12Resource* GetD3D12ResourceFromAllocation(onnxruntime::IAllocator* allocator, void* ptr); void FlushContext(onnxruntime::IExecutionProvider* provider); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h index 04381b6ce355c..074f13b309181 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h @@ -7,11 +7,14 @@ #include #include #include +#include #include "core/framework/op_kernel.h" +#include "core/providers/dml/DmlExecutionProvider/src/DmlEdgeShapes.h" struct AbstractOperatorDesc; interface IMLOperatorTensor; +interface IDMLOperator; struct DML_INPUT_GRAPH_EDGE_DESC; struct DML_OUTPUT_GRAPH_EDGE_DESC; struct DML_INTERMEDIATE_GRAPH_EDGE_DESC; @@ -92,6 +95,8 @@ namespace Windows::AI::MachineLearning::Adapter const onnxruntime::Node& node, MLOperatorTensorGetter& constantInputGetter, const void* executionHandle, + const EdgeShapes* inputShapesOverrides, + /*out*/ EdgeShapes* outputShapes, /*out*/ DmlGraphNodeCreateInfo* graphNodeCreateInfo )>; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp index ede3e7f2c2257..eb068087de4ad 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp @@ -491,6 +491,8 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel( const onnxruntime::Node& node, MLOperatorTensorGetter& constantInputGetter, const void* executionHandle, + const EdgeShapes* inputShapesOverrides, + /*out*/ EdgeShapes* outputShapes, /*out*/ DmlGraphNodeCreateInfo* graphNodeCreateInfo ) { @@ -498,15 +500,15 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel( onnxruntime::OpNodeProtoHelper protoHelper(&nodeContext); // Use the same list of required constant inputs for the shape inferrer and the kernel. 
- EdgeShapes outputShapes; - InferAndVerifyOutputSizes(node, &defaultAttributesCapture, shapeInferrerCapture.Get(), constantCpuInputCapture, constantInputGetter, nullptr, outputShapes); + InferAndVerifyOutputSizes(node, &defaultAttributesCapture, shapeInferrerCapture.Get(), constantCpuInputCapture, constantInputGetter, inputShapesOverrides, *outputShapes); // Create the kernel while allowing input shape and output shape queries according to options ComPtr kernelInfoWrapper = wil::MakeOrThrow( &protoHelper, executionHandle, true, - &outputShapes, + inputShapesOverrides, + outputShapes, &defaultAttributesCapture, graphNodeCreateInfo, constantCpuInputCapture, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlEdgeShapes.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlEdgeShapes.h new file mode 100644 index 0000000000000..5ff70493252bd --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlEdgeShapes.h @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +namespace Windows::AI::MachineLearning::Adapter +{ + // edges and unused edges have an empty array of dimensions. + class EdgeShapes + { + public: + EdgeShapes() = default; + + EdgeShapes(size_t count) : m_shapes(count) {} + + const std::vector& GetShape(size_t edgeIndex) const + { + return m_shapes[edgeIndex]; + } + + std::vector& GetMutableShape(size_t edgeIndex) + { + return m_shapes[edgeIndex]; + } + + size_t EdgeCount() const { return m_shapes.size(); } + + void Reset(size_t edge_count) + { + m_shapes.clear(); + m_shapes.resize(edge_count); + } + + bool operator!=(const EdgeShapes& other) const noexcept + { + return (m_shapes != other.m_shapes); + } + + private: + std::vector> m_shapes; + }; +} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp index 51b93efb3a646..4f7ec188140b5 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp @@ -1,7 +1,7 @@ #pragma once #include "DmlGraphFusionHelper.h" - +#include "DmlRuntimeFusedGraphKernel.h" namespace Dml { @@ -103,6 +103,36 @@ namespace DmlGraphFusionHelper ORT_THROW_IF_FAILED(resourceUnk->QueryInterface(resource)); } + std::tuple, std::vector, std::byte*, size_t> UnpackInitializer( + const onnxruntime::Graph& graph, + const ONNX_NAMESPACE::TensorProto* initializer) + { + std::unique_ptr unpackedTensor; + std::vector unpackedExternalTensor; + std::byte* tensorPtr = nullptr; + size_t tensorByteSize = 0; + + // The tensor may be stored as raw data or in typed fields. 
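+        // Three storage cases are handled below: externally stored data is loaded via
+        // UnpackInitializerData, raw_data is referenced in place without copying, and typed fields
+        // are unpacked into a freshly allocated buffer whose ownership is returned to the caller.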
+ if (initializer->data_location() == onnx::TensorProto_DataLocation_EXTERNAL) + { + THROW_IF_NOT_OK(onnxruntime::utils::UnpackInitializerData(*initializer, graph.ModelPath(), unpackedExternalTensor)); + tensorPtr = reinterpret_cast(unpackedExternalTensor.data()); + tensorByteSize = unpackedExternalTensor.size(); + } + else if (initializer->has_raw_data()) + { + tensorPtr = (std::byte*)(initializer->raw_data().c_str()); + tensorByteSize = initializer->raw_data().size(); + } + else + { + std::tie(unpackedTensor, tensorByteSize) = Windows::AI::MachineLearning::Adapter::UnpackTensor(*initializer, graph.ModelPath()); + tensorPtr = unpackedTensor.get(); + } + + return std::make_tuple(std::move(unpackedTensor), std::move(unpackedExternalTensor), tensorPtr, tensorByteSize); + } + void ProcessInputData( const ExecutionProviderImpl* providerImpl, const std::vector& isInputsUploadedByDmlEP, @@ -161,32 +191,11 @@ namespace DmlGraphFusionHelper auto iter = initializerNameToInitializerMap.find(subGraphInputArgNames[i]); if (iter != initializerNameToInitializerMap.end()) { - std::byte* tensorPtr = nullptr; - size_t tensorByteSize = 0; - std::vector unpackedExternalTensor; - - std::unique_ptr unpackedTensor; - - //auto& initializer = iter->second; auto* initializer = iter->second.first; + auto [unpackedTensor, unpackedExternalTensor, tensorPtr, tensorByteSize] = UnpackInitializer(graph, initializer); - // The tensor may be stored as raw data or in typed fields. - if (initializer->data_location() == onnx::TensorProto_DataLocation_EXTERNAL) - { - THROW_IF_NOT_OK(onnxruntime::utils::UnpackInitializerData(*initializer, graph.ModelPath(), unpackedExternalTensor)); - tensorPtr = reinterpret_cast(unpackedExternalTensor.data()); - tensorByteSize = unpackedExternalTensor.size(); - } - else if (initializer->has_raw_data()) + if (initializer->data_location() != onnx::TensorProto_DataLocation_EXTERNAL && !initializer->has_raw_data()) { - tensorPtr = (std::byte*)(initializer->raw_data().c_str()); - tensorByteSize = initializer->raw_data().size(); - } - else - { - std::tie(unpackedTensor, tensorByteSize) = Windows::AI::MachineLearning::Adapter::UnpackTensor(*initializer, graph.ModelPath()); - tensorPtr = unpackedTensor.get(); - // Free the initializer if this is the last usage of it. 
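+                    // After the UnpackInitializer refactoring above, this early-free path applies
+                    // only to initializers unpacked from typed fields; externally stored and
+                    // raw_data initializers are left untouched here.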
if (initializerToLastInputIndexMap[initializer] == i) { @@ -501,5 +510,173 @@ namespace DmlGraphFusionHelper graph.FinalizeFuseSubGraph(indexedSubGraph, fusedNode); } + + void RegisterDynamicKernel( + onnxruntime::Graph& graph, + onnxruntime::KernelRegistry* registryForPartitionKernels, + const ExecutionProviderImpl* providerImpl, + std::unordered_map graphNodePropertyMap, + const std::unordered_set& dynamicCpuInputMap, + std::shared_ptr indexedSubGraph, + std::unordered_map>&& isInitializerTransferable) + { + struct NodeInfo + { + std::string name; + std::string opType; + std::string description; + std::string domain; + onnxruntime::NodeAttributes attributes; + std::vector inputDefPointers; + std::vector outputDefPointers; + }; + + auto partitionNodePropsMap = DmlGraphFusionHelper::CreatePartitionNodePropsMap( + graph, + *indexedSubGraph, + std::move(graphNodePropertyMap)); + + auto modelPath = graph.ModelPath(); + + const gsl::span subGraphInputArgNames = indexedSubGraph->GetMetaDef()->inputs; + const gsl::span subGraphOutputArgNames = indexedSubGraph->GetMetaDef()->outputs; + + std::vector nodesInfo; + nodesInfo.reserve(indexedSubGraph->nodes.size()); + + std::vector subgraphInputs; + subgraphInputs.reserve(subGraphInputArgNames.size()); + + std::vector subgraphOutputs; + subgraphOutputs.reserve(subGraphOutputArgNames.size()); + + std::vector nodeAttributes; + nodeAttributes.reserve(indexedSubGraph->nodes.size()); + + std::vector> intermediateNodeArgs; + + for (size_t sortedNodeIndex : indexedSubGraph->nodes) + { + auto node = graph.GetNode(sortedNodeIndex); + + nodeAttributes.push_back(node->GetAttributes()); + + NodeInfo nodeInfo{}; + nodeInfo.name = node->Name(); + nodeInfo.opType = node->OpType(); + nodeInfo.description = node->Description(); + nodeInfo.domain = node->Domain(); + nodeInfo.attributes = node->GetAttributes(); + nodeInfo.inputDefPointers.reserve(node->InputDefs().size()); + nodeInfo.outputDefPointers.reserve(node->OutputDefs().size()); + + for (const onnxruntime::NodeArg* inputDef : node->InputDefs()) + { + intermediateNodeArgs.emplace_back(std::make_shared(inputDef->Name(), inputDef->TypeAsProto())); + nodeInfo.inputDefPointers.push_back(intermediateNodeArgs.back().get()); + } + + for (const onnxruntime::NodeArg* outputDef : node->OutputDefs()) + { + intermediateNodeArgs.emplace_back(std::make_shared(outputDef->Name(), outputDef->TypeAsProto())); + nodeInfo.outputDefPointers.push_back(intermediateNodeArgs.back().get()); + } + + nodesInfo.push_back(std::move(nodeInfo)); + } + + for (const std::string& graphInputName : subGraphInputArgNames) + { + subgraphInputs.push_back(graph.GetNodeArg(graphInputName)); + } + + for (const std::string& graphOutputName : subGraphOutputArgNames) + { + subgraphOutputs.push_back(graph.GetNodeArg(graphOutputName)); + } + + // We need to keep the initializers alive since they will be freed once the nodes are removed from the graph + std::vector ownedInitializers; + ownedInitializers.reserve(isInitializerTransferable.size()); + + for (auto& kvp : isInitializerTransferable) + { + auto [unpackedTensor, unpackedExternalTensor, tensorPtr, tensorByteSize] = UnpackInitializer(graph, kvp.second.first); + + ONNX_NAMESPACE::TensorProto tensorProto; + tensorProto.set_data_type(kvp.second.first->data_type()); + tensorProto.set_raw_data(tensorPtr, tensorByteSize); + tensorProto.set_name(kvp.second.first->name()); + + for (int i = 0; i < kvp.second.first->dims_size(); ++i) + { + tensorProto.add_dims(kvp.second.first->dims(i)); + } + 
ownedInitializers.push_back(std::move(tensorProto)); + kvp.second.first = &ownedInitializers.back(); + } + + // lamda captures for the kernel registration + auto fused_kernel_func = [ + indexedSubGraph, + &modelPath, + nodesInfo = std::move(nodesInfo), + intermediateNodeArgs = std::move(intermediateNodeArgs), + subgraphInputs = std::move(subgraphInputs), + subgraphOutputs = std::move(subgraphOutputs), + partitionNodePropsMap = std::move(partitionNodePropsMap), + ownedInitializers = std::move(ownedInitializers)] (onnxruntime::FuncManager& func_mgr, const onnxruntime::OpKernelInfo& info, std::unique_ptr& out) mutable ->onnxruntime::Status + { + std::vector> subgraphNodes; + subgraphNodes.reserve(nodesInfo.size()); + + for (const NodeInfo& nodeInfo : nodesInfo) + { + subgraphNodes.emplace_back(std::make_shared( + nodeInfo.name, + nodeInfo.opType, + nodeInfo.description, + nodeInfo.inputDefPointers, + nodeInfo.outputDefPointers, + &nodeInfo.attributes, + nodeInfo.domain)); + } + + out.reset(CreateRuntimeFusedGraphKernel( + info, + indexedSubGraph, + modelPath, + std::move(subgraphNodes), + std::move(subgraphInputs), + std::move(subgraphOutputs), + std::move(intermediateNodeArgs), + std::move(partitionNodePropsMap), + std::move(ownedInitializers))); + return Status::OK(); + }; + + // build the kernel definition on the fly, and register it to the fused_kernel_regisitry. + onnxruntime::KernelDefBuilder builder; + builder.SetName(indexedSubGraph->GetMetaDef()->name) + .SetDomain(indexedSubGraph->GetMetaDef()->domain) + .SinceVersion(indexedSubGraph->GetMetaDef()->since_version) + .Provider(onnxruntime::kDmlExecutionProvider); + + // Force the CPU inputs to be allocated on the CPU + for (int i = 0; i < subGraphInputArgNames.size(); ++i) + { + if (dynamicCpuInputMap.find(subGraphInputArgNames[i]) != dynamicCpuInputMap.end()) + { + builder.InputMemoryType(OrtMemTypeCPUInput, i); + } + } + + ORT_THROW_IF_ERROR(registryForPartitionKernels->Register(builder, fused_kernel_func)); + + auto& fusedNode = graph.BeginFuseSubGraph(*indexedSubGraph, indexedSubGraph->GetMetaDef()->name); + fusedNode.SetExecutionProviderType(onnxruntime::kDmlExecutionProvider); + + graph.FinalizeFuseSubGraph(*indexedSubGraph, fusedNode); + } } } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h index 030cffc2a8794..f8f6162aaa1e0 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h @@ -80,5 +80,14 @@ namespace DmlGraphFusionHelper std::vector&& isInputsUploadedByDmlEP, const GraphDescBuilder::GraphDesc& graphDesc, Microsoft::WRL::ComPtr compiledExecutionPlanOperator); + + void RegisterDynamicKernel( + onnxruntime::Graph& graph, + onnxruntime::KernelRegistry* registryForPartitionKernels, + const ExecutionProviderImpl* providerImpl, + std::unordered_map graphNodePropertyMap, + const std::unordered_set& dynamicCpuInputMap, + std::shared_ptr indexedSubGraph, + std::unordered_map>&& isInitializerTransferable); } } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp index 4813707cdf50c..679738b639ec9 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp +++ 
b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp @@ -15,6 +15,18 @@ namespace Dml { + namespace + { + struct CompiledPartitionInfo + { + Microsoft::WRL::ComPtr compiledOperator; + onnxruntime::IndexedSubGraph indexedSubGraph; + std::vector isInputsUploadedByDmlEP; + GraphDescBuilder::GraphDesc graphDesc; + std::unordered_map> isInitializerTransferable; + }; + } + DmlGraphFusionTransformer::DmlGraphFusionTransformer( const std::string& name, const onnxruntime::IExecutionProvider* provider @@ -24,15 +36,6 @@ namespace Dml { } - struct CompiledPartitionInfo - { - Microsoft::WRL::ComPtr compiledOperator; - onnxruntime::IndexedSubGraph indexedSubGraph; - std::vector isInputsUploadedByDmlEP; - GraphDescBuilder::GraphDesc graphDesc; - std::unordered_map> isInitializerTransferable; - }; - onnxruntime::common::Status DmlGraphFusionTransformer::ApplyImpl( onnxruntime::Graph& graph, bool& modified, @@ -87,6 +90,7 @@ namespace Dml { // Initializers needed by any graph partition std::unordered_set requiredInitializerMap; + std::unordered_set dynamicCpuInputMap; std::unordered_map graphNodePropertyMap; onnxruntime::GraphViewer graphViewer(graph); std::vector> partitions = BuildPartitions( @@ -96,8 +100,10 @@ namespace Dml m_providerImpl->GetSupportedDeviceDataTypeMask(), graphNodePropertyMap, requiredInitializerMap, + dynamicCpuInputMap, additionalSplittingNodes, - implicitInputDefs); + implicitInputDefs, + false); // Reset the splitting nodes for the current iteration additionalSplittingNodes.clear(); @@ -190,17 +196,48 @@ namespace Dml std::move(graphNodePropertyMap)); // Convert partitionONNXGraph into DML EP GraphDesc + auto modelPath = graph.ModelPath(); + + const gsl::span subGraphInputArgNames = indexedSubGraph.GetMetaDef()->inputs; + const gsl::span subGraphOutputArgNames = indexedSubGraph.GetMetaDef()->outputs; + + std::vector subgraphNodes; + subgraphNodes.reserve(indexedSubGraph.nodes.size()); + + std::vector subgraphInputs; + subgraphInputs.reserve(subGraphInputArgNames.size()); + + std::vector subgraphOutputs; + subgraphOutputs.reserve(subGraphOutputArgNames.size()); + + for (size_t sortedNodeIndex : indexedSubGraph.nodes) + { + subgraphNodes.push_back(graph.GetNode(sortedNodeIndex)); + } + + for (const std::string& graphInputName : subGraphInputArgNames) + { + subgraphInputs.push_back(graph.GetNodeArg(graphInputName)); + } + + for (const std::string& graphOutputName : subGraphOutputArgNames) + { + subgraphOutputs.push_back(graph.GetNodeArg(graphOutputName)); + } + ComPtr device; ORT_THROW_IF_FAILED(m_providerImpl->GetDmlDevice(device.GetAddressOf())); GraphDescBuilder::GraphDesc graphDesc = GraphDescBuilder::BuildGraphDesc( isInputsUploadedByDmlEP.data(), isInputsUploadedByDmlEP.size(), isInitializerTransferable, - graph, - indexedSubGraph, partitionNodePropsMap, device.Get(), - m_providerImpl); + m_providerImpl, + modelPath, + subgraphNodes, + subgraphInputs, + subgraphOutputs); // Compile the operator auto compiledPartition = DmlGraphFusionHelper::TryCreateCompiledOperator( diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp new file mode 100644 index 0000000000000..5c7b7bff1e370 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp @@ -0,0 +1,369 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
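+//
+// DmlRuntimeFusedGraphKernel executes a DML partition whose shapes are only known at run time:
+// Compute() records the actual input shapes (and the values of CPU-resident inputs), rebuilds the
+// DML graph description and recompiles the fused operator whenever either changes, and otherwise
+// reuses the cached compiled operator.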
+ +#include "precomp.h" + +#include "core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h" +#include "core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.h" +#include "core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h" + +using namespace Windows::AI::MachineLearning::Adapter; + +namespace Dml +{ + class DmlRuntimeFusedGraphKernel : public onnxruntime::OpKernel + { + public: + DmlRuntimeFusedGraphKernel() = delete; + + DmlRuntimeFusedGraphKernel( + const onnxruntime::OpKernelInfo& kernelInfo, + std::shared_ptr indexedSubGraph, + const onnxruntime::Path& modelPath, + std::vector>&& subgraphNodes, + std::vector&& subgraphInputs, + std::vector&& subgraphOutputs, + std::vector>&& intermediateNodeArgs, + std::unordered_map&& partitionNodePropsMap, + std::vector&& ownedInitializers) + : OpKernel(kernelInfo), + m_indexedSubGraph(std::move(indexedSubGraph)), + m_modelPath(modelPath), + m_subgraphNodes(std::move(subgraphNodes)), + m_subgraphInputs(std::move(subgraphInputs)), + m_subgraphOutputs(std::move(subgraphOutputs)), + m_intermediateNodeArgs(std::move(intermediateNodeArgs)), + m_partitionNodePropsMap(std::move(partitionNodePropsMap)), + m_ownedInitializers(std::move(ownedInitializers)) + { + for (const auto& initializer : m_ownedInitializers) + { + m_isInitializerTransferable[initializer.name()] = std::make_pair(&initializer, false); + } + + // Get the execution provider interfaces + auto executionHandle = kernelInfo.GetExecutionProvider()->GetExecutionHandle(); + if (executionHandle) + { + // We assume the execution object inherits IUnknown as its first base + ComPtr providerExecutionObject = const_cast(static_cast(executionHandle)); + + // Get the WinML-specific execution provider interface from the execution object. + ORT_THROW_IF_FAILED(providerExecutionObject.As(&m_provider)); + ORT_THROW_IF_FAILED(providerExecutionObject.As(&m_winmlProvider)); + } + + m_subgraphNodePointers.reserve(m_subgraphNodes.size()); + + for (auto& subgraphNode : m_subgraphNodes) + { + m_subgraphNodePointers.push_back(subgraphNode.get()); + } + } + + void TranslateAndCompileGraph( + const onnxruntime::OpKernelInfo& kernelInfo, + std::vector>& initializeResourceRefs, + std::vector initInputBindings) const + { + // Allocate a persistent resource and initialize the operator + UINT64 persistentResourceSize = m_compiledExecutionPlanOperator->GetBindingProperties().PersistentResourceSize; + if (persistentResourceSize > 0) + { + ORT_THROW_IF_FAILED(m_provider->AllocatePooledResource( + static_cast(persistentResourceSize), + AllocatorRoundingMode::Disabled, + m_persistentResource.ReleaseAndGetAddressOf(), + m_persistentResourceAllocatorUnk.ReleaseAndGetAddressOf())); + + m_persistentResourceBinding = DML_BUFFER_BINDING { m_persistentResource.Get(), 0, persistentResourceSize }; + } + + ORT_THROW_IF_FAILED(m_provider->InitializeOperator( + m_compiledExecutionPlanOperator.Get(), + m_persistentResourceBinding ? 
&*m_persistentResourceBinding : nullptr, + gsl::make_span(initInputBindings))); + + std::for_each( + initializeResourceRefs.begin(), + initializeResourceRefs.end(), + [&](ComPtr& resource){ m_winmlProvider->QueueReference(WRAP_GRAPHICS_UNKNOWN(resource).Get()); } + ); + } + + onnxruntime::Status Compute(onnxruntime::OpKernelContext* kernelContext) const override + { + ORT_THROW_HR_IF(E_UNEXPECTED, static_cast(m_subgraphInputs.size()) != kernelContext->InputCount()); + + bool recompileNeeded = m_compiledExecutionPlanOperator == nullptr; + + for (int inputIndex = 0; inputIndex < kernelContext->InputCount(); ++inputIndex) + { + const auto& input = kernelContext->RequiredInput(inputIndex); + const std::string& inputName = m_subgraphInputs[inputIndex]->Name(); + auto shapeIter = m_inferredInputShapes.find(inputName); + + if (shapeIter == m_inferredInputShapes.end()) + { + m_inferredInputShapes[inputName] = input.Shape(); + recompileNeeded = true; + } + else if (shapeIter->second != input.Shape()) + { + shapeIter->second = input.Shape(); + recompileNeeded = true; + } + + // If we have CPU inputs that are not initializers (i.e. they were computed at runtime), add them to the initializer list + if (input.Location().device.Type() == OrtDevice::CPU) + { + auto inputProto = onnxruntime::utils::TensorToTensorProto(input, inputName); + + // We can only avoid recompiling the graph when all CPU inputs are identical + auto initializerIter = m_isInitializerTransferable.find(inputName); + + if (initializerIter != m_isInitializerTransferable.end()) + { + if (initializerIter->second.first->raw_data().length() == inputProto.raw_data().length()) + { + for (int i = 0; i < inputProto.raw_data().length(); ++i) + { + if (initializerIter->second.first->raw_data()[i] != inputProto.raw_data()[i]) + { + recompileNeeded = true; + break; + } + } + } + else + { + recompileNeeded = true; + } + } + else + { + recompileNeeded = true; + } + + m_ownedCpuInputs.push_back(std::make_unique(std::move(inputProto))); + m_isInitializerTransferable[inputName] = std::make_pair(m_ownedCpuInputs.back().get(), false); + } + } + + if (recompileNeeded) + { + // Go through all the node args and replace their shapes with the real ones + for (auto& nodeArg : m_intermediateNodeArgs) + { + auto iter = m_inferredInputShapes.find(nodeArg->Name()); + if (iter != m_inferredInputShapes.end()) + { + auto tensorShape = *nodeArg->Shape(); + ORT_THROW_HR_IF(E_UNEXPECTED, tensorShape.dim_size() != static_cast(iter->second.NumDimensions())); + + for (int i = 0; i < tensorShape.dim_size(); ++i) + { + tensorShape.mutable_dim(i)->set_dim_value(iter->second.GetDims()[i]); + } + + nodeArg->SetShape(tensorShape); + } + } + + // Populate input bindings for operator initialization + const uint32_t fusedNodeInputCount = gsl::narrow_cast(m_indexedSubGraph->GetMetaDef()->inputs.size()); + std::vector> initializeResourceRefs; // For lifetime control + std::vector initInputBindings(fusedNodeInputCount); + std::vector isInputsUploadedByDmlEP(fusedNodeInputCount); + auto providerImpl = static_cast(Info().GetExecutionProvider())->GetImpl(); + + // Convert partitionONNXGraph into DML EP GraphDesc + ComPtr device; + ORT_THROW_IF_FAILED(providerImpl->GetDmlDevice(device.GetAddressOf())); + GraphDescBuilder::GraphDesc graphDesc = GraphDescBuilder::BuildGraphDesc( + isInputsUploadedByDmlEP.data(), + isInputsUploadedByDmlEP.size(), + m_isInitializerTransferable, + m_partitionNodePropsMap, + device.Get(), + providerImpl, + m_modelPath, + m_subgraphNodePointers, + 
m_subgraphInputs, + m_subgraphOutputs); + + m_outputShapes = graphDesc.outputShapes; + + // Walk through each graph edge and mark used inputs + m_inputsUsed.resize(fusedNodeInputCount, false); + for (const DML_INPUT_GRAPH_EDGE_DESC& edge : graphDesc.inputEdges) + { + m_inputsUsed[edge.GraphInputIndex] = true; + } + + // Compile the operator + m_compiledExecutionPlanOperator = DmlGraphFusionHelper::TryCreateCompiledOperator( + graphDesc, + *m_indexedSubGraph, + providerImpl); + + // Queue references to objects which must be kept alive until resulting GPU work completes + m_winmlProvider->QueueReference(m_compiledExecutionPlanOperator.Get()); + + TranslateAndCompileGraph( + Info(), + initializeResourceRefs, + initInputBindings); + } + + // Wrap tensors as required by Dml::IExecutionProvider::ExecuteOperator + OpKernelContextWrapper contextWrapper( + kernelContext, + Info().GetExecutionProvider(), + true, + nullptr); + + ORT_THROW_IF_FAILED(m_provider->AddUAVBarrier()); + + // Get input resources for execution, excluding those which were specified as owned by DML and provided + // at initialization instead. + std::vector> inputTensors(kernelContext->InputCount()); + std::vector inputPtrs(kernelContext->InputCount()); + + for (int i = 0; i < kernelContext->InputCount(); ++i) + { + if (!m_inputsUsed[i]) + { + continue; + } + + ORT_THROW_IF_FAILED(contextWrapper.GetInputTensor(i, inputTensors[i].GetAddressOf())); + inputPtrs[i] = m_provider->DecodeResource(MLOperatorTensor(inputTensors[i].Get()).GetDataInterface().Get()); + } + + auto outputTensors = contextWrapper.GetOutputTensors(m_outputShapes); + ExecuteOperator( + m_compiledExecutionPlanOperator.Get(), + m_persistentResourceBinding ? &*m_persistentResourceBinding : nullptr, + inputPtrs, + outputTensors); + + ORT_THROW_IF_FAILED(m_provider->AddUAVBarrier()); + + return onnxruntime::Status::OK(); + } + + void ExecuteOperator( + IDMLCompiledOperator* op, + _In_opt_ const DML_BUFFER_BINDING* persistentResourceBinding, + gsl::span inputTensors, + gsl::span outputTensors) const + { + auto FillBindingsFromTensors = [this](auto& bufferBindings, auto& bindingDescs, gsl::span& tensors) + { + for (IMLOperatorTensor* tensor : tensors) + { + if (tensor) + { + assert(tensor->IsDataInterface()); + ID3D12Resource* resource = m_provider->DecodeResource(MLOperatorTensor(tensor).GetDataInterface().Get()); + D3D12_RESOURCE_DESC resourceDesc = resource->GetDesc(); + bufferBindings.push_back({ resource, 0, resourceDesc.Width }); + bindingDescs.push_back({ DML_BINDING_TYPE_BUFFER, &bufferBindings.back() }); + } + else + { + bufferBindings.push_back({ nullptr, 0, 0 }); + bindingDescs.push_back({ DML_BINDING_TYPE_NONE, nullptr }); + } + } + }; + + auto FillBindingsFromBuffers = [](auto& bufferBindings, auto& bindingDescs, gsl::span& resources) + { + for (ID3D12Resource* resource : resources) + { + if (resource) + { + D3D12_RESOURCE_DESC resourceDesc = resource->GetDesc(); + bufferBindings.push_back({ resource, 0, resourceDesc.Width }); + bindingDescs.push_back({ DML_BINDING_TYPE_BUFFER, &bufferBindings.back() }); + } + else + { + bufferBindings.push_back({ nullptr, 0, 0 }); + bindingDescs.push_back({ DML_BINDING_TYPE_NONE, nullptr }); + } + } + }; + + std::vector inputBufferBindings; + inputBufferBindings.reserve(inputTensors.size()); + std::vector inputBindings; + inputBindings.reserve(inputTensors.size()); + FillBindingsFromBuffers(inputBufferBindings, inputBindings, inputTensors); + + std::vector outputBufferBindings; + 
outputBufferBindings.reserve(outputTensors.size()); + std::vector outputBindings; + outputBindings.reserve(outputTensors.size()); + FillBindingsFromTensors(outputBufferBindings, outputBindings, outputTensors); + + ORT_THROW_IF_FAILED(m_provider->ExecuteOperator( + op, + persistentResourceBinding, + inputBindings, + outputBindings)); + } + + private: + ComPtr m_winmlProvider; + ComPtr m_provider; + + mutable std::optional m_persistentResourceBinding; + std::shared_ptr m_indexedSubGraph; + const onnxruntime::Path& m_modelPath; + + std::vector> m_subgraphNodes; + std::vector m_subgraphInputs; + std::vector m_subgraphOutputs; + mutable std::vector> m_intermediateNodeArgs; + std::unordered_map m_partitionNodePropsMap; + std::vector m_ownedInitializers; + mutable std::unordered_map> m_isInitializerTransferable; + std::vector m_subgraphNodePointers; + + // Bindings from previous executions of a re-used command list + mutable std::vector> m_ownedCpuInputs; + mutable ComPtr m_compiledExecutionPlanOperator; + mutable std::vector m_inputsUsed; + mutable ComPtr m_persistentResource; + mutable ComPtr m_persistentResourceAllocatorUnk; // Controls when the persistent resource is returned to the allocator + mutable Windows::AI::MachineLearning::Adapter::EdgeShapes m_outputShapes; + mutable std::unordered_map m_inferredInputShapes; + }; + + onnxruntime::OpKernel* CreateRuntimeFusedGraphKernel( + const onnxruntime::OpKernelInfo& info, + std::shared_ptr indexedSubGraph, + const onnxruntime::Path& modelPath, + std::vector>&& subgraphNodes, + std::vector&& subgraphInputs, + std::vector&& subgraphOutputs, + std::vector>&& intermediateNodeArgs, + std::unordered_map&& partitionNodePropsMap, + std::vector&& ownedInitializers) + { + return new DmlRuntimeFusedGraphKernel( + info, + std::move(indexedSubGraph), + modelPath, + std::move(subgraphNodes), + std::move(subgraphInputs), + std::move(subgraphOutputs), + std::move(intermediateNodeArgs), + std::move(partitionNodePropsMap), + std::move(ownedInitializers) + ); + } +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.h new file mode 100644 index 0000000000000..d679c5aa5667c --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.h @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
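+//
+// Factory entry point for the runtime-compiled fused graph kernel implemented in
+// DmlRuntimeFusedGraphKernel.cpp.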
+ +#include "core/framework/op_kernel.h" +#include "GraphDescBuilder.h" +#include "DmlRuntimeGraphFusionTransformer.h" + +namespace Dml +{ + onnxruntime::OpKernel* CreateRuntimeFusedGraphKernel( + const onnxruntime::OpKernelInfo& info, + std::shared_ptr indexedSubGraph, + const onnxruntime::Path& modelPath, + std::vector>&& subgraphNodes, + std::vector&& subgraphInputs, + std::vector&& subgraphOutputs, + std::vector>&& intermediateNodeArgs, + std::unordered_map&& partitionNodePropsMap, + std::vector&& ownedInitializers + ); +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp new file mode 100644 index 0000000000000..6318b0d5e2865 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp @@ -0,0 +1,161 @@ +#pragma once + +#include "precomp.h" +#include "GraphDescBuilder.h" +#include "ExecutionProvider.h" +#include "DmlRuntimeGraphFusionTransformer.h" +#include "GraphPartitioner.h" +#include "core/framework/kernel_type_str_resolver.h" +#include "core/framework/kernel_lookup.h" +#include "core/optimizer/constant_sharing.h" +#include "DmlRuntimeFusedGraphKernel.h" +#include "MLOperatorAuthorImpl.h" +#include "DmlGraphFusionHelper.h" + +namespace Dml +{ + namespace + { + struct CompiledPartitionInfo + { + std::shared_ptr indexedSubGraph; + std::unordered_map> isInitializerTransferable; + }; + } + + DmlRuntimeGraphFusionTransformer::DmlRuntimeGraphFusionTransformer( + const std::string& name, + const onnxruntime::IExecutionProvider* provider + ) + :onnxruntime::GraphTransformer(name), + m_providerImpl(static_cast(provider)->GetImpl()) + { + } + + onnxruntime::common::Status DmlRuntimeGraphFusionTransformer::ApplyImpl( + onnxruntime::Graph& graph, + bool& modified, + int graphLevel, + const onnxruntime::logging::Logger& logger) const + { + return ApplyImplHelper(graph, modified, graphLevel, logger, {}); + } + + onnxruntime::common::Status DmlRuntimeGraphFusionTransformer::ApplyImplHelper( + onnxruntime::Graph& graph, + bool& modified, + int graphLevel, + const onnxruntime::logging::Logger& logger, + const std::unordered_map& implicitInputDefs) const + { + onnxruntime::ProviderType providerType = onnxruntime::kDmlExecutionProvider; + const gsl::not_null registry = m_providerImpl->GetKernelRegistry().get(); + const auto kernelTypeStrResolver = onnxruntime::OpSchemaKernelTypeStrResolver{}; + const auto kernelLookup = onnxruntime::KernelLookup( + providerType, + gsl::make_span(®istry, 1), + kernelTypeStrResolver); + + onnxruntime::GraphViewer graphViewer(graph); + const auto& nodeTopologyList = graphViewer.GetNodesInTopologicalOrder(); + + for (auto nodeIndex : nodeTopologyList) + { + auto* node = graph.GetNode(nodeIndex); + if (!node) + { + continue; // node was removed + } + + std::unordered_map subgraphImplicitInputDefs; + for (const onnxruntime::NodeArg* inputDef : node->ImplicitInputDefs()) + { + subgraphImplicitInputDefs[inputDef->Name()] = inputDef; + } + + for (auto& entry : node->GetAttributeNameToMutableSubgraphMap()) + { + auto& subgraph = *entry.second; + ORT_RETURN_IF_ERROR(ApplyImplHelper(subgraph, modified, graphLevel + 1, logger, subgraphImplicitInputDefs)); + } + } + + // Initializers needed by any graph partition + std::vector additionalSplittingNodes; + std::unordered_map graphNodePropertyMap; + std::unordered_set requiredInitializerMap; + std::unordered_set 
dynamicCpuInputMap; + std::vector> partitions = BuildPartitions( + graphViewer, + *m_providerImpl->GetInternalRegistrationInfoMap(), + kernelLookup, + m_providerImpl->GetSupportedDeviceDataTypeMask(), + graphNodePropertyMap, + requiredInitializerMap, + dynamicCpuInputMap, + additionalSplittingNodes, + implicitInputDefs, + true); + + // Reset the splitting nodes for the current iteration + additionalSplittingNodes.clear(); + + // Reset the compiled operators for the current iteration + std::vector> compiledPartitionInfos(partitions.size()); + + // Create a map between each initialized tensor and the partition(s) it is part of. + auto initializerPartitionMap = DmlGraphFusionHelper::GetInitializerToPartitionMap(graphViewer, partitions); + + for (uint32_t partitionIndex = 0; partitionIndex < partitions.size(); ++partitionIndex) + { + auto& partition = partitions[partitionIndex]; + + if (partition->GetRootMergedPartition() != partition.get() || + !partition->IsDmlPartition()) + { + continue; + } + + if (partition->IsDmlGraphPartition()) + { + std::unordered_map> isInitializerTransferable; + + std::string partitionKernelPrefix = std::to_string(m_providerImpl->GetPartitionKernelPrefixVal()) + "_"; + m_providerImpl->IncreasePartitionKernelPrefixVal(); + + // populate isInitializerTransferable + for (const auto& input : partition->GetInputs()) + { + const onnx::TensorProto* tensor = nullptr; + if (graph.GetInitializedTensor(input, tensor) && requiredInitializerMap.find(input) != requiredInitializerMap.end()) + { + isInitializerTransferable[input] = {tensor, false}; + } + } + + compiledPartitionInfos[partitionIndex] = std::make_shared(); + compiledPartitionInfos[partitionIndex]->indexedSubGraph = std::make_shared( + DmlGraphFusionHelper::CreateIndexedSubGraph(partition.get(), partitionIndex, partitionKernelPrefix)); + compiledPartitionInfos[partitionIndex]->isInitializerTransferable = std::move(isInitializerTransferable); + } + } + + for (auto&& compiledPartitionInfo : compiledPartitionInfos) + { + // Null compiled operators were not DML partitions + if (compiledPartitionInfo) + { + DmlGraphFusionHelper::RegisterDynamicKernel( + graph, + m_providerImpl->GetKernelRegistry().get(), + m_providerImpl, + graphNodePropertyMap, + dynamicCpuInputMap, + std::move(compiledPartitionInfo->indexedSubGraph), + std::move(compiledPartitionInfo->isInitializerTransferable)); + } + } + + return onnxruntime::common::Status::OK(); + } +} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.h new file mode 100644 index 0000000000000..cfa743e1f2b85 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.h @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
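+//
+// Runtime counterpart of DmlGraphFusionTransformer: DML partitions are fused into
+// DmlRuntimeFusedNode_* nodes without being compiled at transform time, deferring compilation to
+// DmlRuntimeFusedGraphKernel so that inputs with dynamic shapes can be supported.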
+#pragma once + +#include +#include +#include "core/optimizer/graph_transformer.h" +#include "core/framework/execution_providers.h" + +namespace Dml +{ +class ExecutionProviderImpl; + +class DmlRuntimeGraphFusionTransformer : public onnxruntime::GraphTransformer +{ +public: + DmlRuntimeGraphFusionTransformer( + const std::string& name, + const onnxruntime::IExecutionProvider* provider + ); + +public: + static inline const char* const DML_GRAPH_FUSION_NODE_NAME_PREFIX = "DmlRuntimeFusedNode_"; + static inline const char* const DML_GRAPH_FUSION_NODE_DOMAIN = "DmlRuntimeFusedNodeDomain"; + +private: + onnxruntime::common::Status ApplyImpl(onnxruntime::Graph& graph, + bool& modified, + int graphLevel, + const onnxruntime::logging::Logger& logger) const final; + + onnxruntime::common::Status ApplyImplHelper( + onnxruntime::Graph& graph, + bool& modified, + int graphLevel, + const onnxruntime::logging::Logger& logger, + const std::unordered_map& implicitInputDefs) const; + +private: + const ExecutionProviderImpl* m_providerImpl = nullptr; +}; +} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp index 5f6bd178aaa15..8644b8d56a426 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp @@ -67,7 +67,8 @@ namespace Dml ExecutionProvider::ExecutionProvider( IDMLDevice* dmlDevice, ID3D12CommandQueue* commandQueue, - bool enableMetacommands) : + bool enableMetacommands, + bool enableDynamicGraphFusion) : IExecutionProvider(onnxruntime::kDmlExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0)) { D3D12_COMMAND_LIST_TYPE queueType = commandQueue->GetDesc().Type; @@ -80,7 +81,7 @@ namespace Dml ComPtr device; GRAPHICS_THROW_IF_FAILED(commandQueue->GetDevice(IID_GRAPHICS_PPV_ARGS(device.GetAddressOf()))); - m_impl = wil::MakeOrThrow(dmlDevice, device.Get(), commandQueue, enableMetacommands); + m_impl = wil::MakeOrThrow(dmlDevice, device.Get(), commandQueue, enableMetacommands, enableDynamicGraphFusion); } std::vector> @@ -147,12 +148,12 @@ namespace Dml // Task 24384515: Update ORT AIInfra release agent pool to install 19H1 SDK on VM bootstrap #define D3D_FEATURE_LEVEL_1_0_CORE_PRIVATE ((D3D_FEATURE_LEVEL)0x1000) - ExecutionProviderImpl::ExecutionProviderImpl(IDMLDevice* dmlDevice, ID3D12Device* d3d12Device, ID3D12CommandQueue* queue, bool enableMetacommands) + ExecutionProviderImpl::ExecutionProviderImpl(IDMLDevice* dmlDevice, ID3D12Device* d3d12Device, ID3D12CommandQueue* queue, bool enableMetacommands, bool enableDynamicGraphFusion) : m_d3d12Device(d3d12Device), m_dmlDevice(dmlDevice), - m_areMetacommandsEnabled(enableMetacommands) + m_areMetacommandsEnabled(enableMetacommands), + m_dynamicGraphFusionEnabled(enableDynamicGraphFusion) { - D3D12_FEATURE_DATA_FEATURE_LEVELS featureLevels = {}; D3D_FEATURE_LEVEL featureLevelsList[] = { @@ -1093,6 +1094,11 @@ namespace Dml return m_areMetacommandsEnabled; } + bool ExecutionProviderImpl::DynamicGraphFusionEnabled() const noexcept + { + return m_dynamicGraphFusionEnabled; + } + std::shared_ptr ExecutionProviderImpl::GetInternalRegistrationInfoMap() const { @@ -1129,9 +1135,10 @@ namespace Dml std::unique_ptr CreateExecutionProvider( IDMLDevice* dmlDevice, ID3D12CommandQueue* commandQueue, - bool enableMetacommands) + bool enableMetacommands, + bool enableDynamicGraphFusion) { - return 
std::make_unique(dmlDevice, commandQueue, enableMetacommands); + return std::make_unique(dmlDevice, commandQueue, enableMetacommands, enableDynamicGraphFusion); } ID3D12Resource* GetD3D12ResourceFromAllocation(onnxruntime::IAllocator* allocator, void* ptr) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h index 31b893a2f25d7..3aaa11cdee479 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h @@ -5,6 +5,7 @@ #include "GraphTransformer.h" #include "core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h" +#include "core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h" #include #include @@ -34,7 +35,8 @@ namespace Dml IDMLDevice* dmlDevice, ID3D12Device* d3d12Device, ID3D12CommandQueue* queue, - bool enableMetacommands = true); + bool enableMetacommands, + bool enableDynamicGraphFusion); void ReleaseCompletedReferences(); @@ -150,6 +152,7 @@ namespace Dml STDMETHOD_(bool, IsMcdmDevice)() const noexcept final; STDMETHOD_(bool, MetacommandsEnabled)() const noexcept final; + bool DynamicGraphFusionEnabled() const noexcept; std::shared_ptr GetGpuAllocator(); std::shared_ptr GetCpuInputAllocator(); @@ -184,6 +187,7 @@ namespace Dml ComPtr m_dmlDevice; bool m_isMcdmDevice = false; bool m_areMetacommandsEnabled = true; + bool m_dynamicGraphFusionEnabled = false; bool m_native16BitShaderOpsSupported = false; std::shared_ptr m_context; std::unique_ptr m_uploadHeap; @@ -236,7 +240,8 @@ namespace Dml explicit ExecutionProvider( IDMLDevice* dmlDevice, ID3D12CommandQueue* commandQueue, - bool enableMetacommands = true + bool enableMetacommands, + bool enableDynamicGraphFusion ); std::unique_ptr GetDataTransfer() const final override @@ -299,9 +304,9 @@ namespace Dml return m_impl.Get(); } - void MetacommandsEnabled() + bool DynamicGraphFusionEnabled() const { - m_impl->MetacommandsEnabled(); + return m_impl->DynamicGraphFusionEnabled(); } virtual std::vector CreatePreferredAllocators() override diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp index 636f46428ce99..3fc8f415e5a58 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp @@ -147,14 +147,14 @@ namespace Dml::GraphDescBuilder const uint8_t* isConstGpuGraphInput, const size_t isConstGpuGraphInputCount, const std::unordered_map>& isInitializerTransferable, - const onnxruntime::Graph& graph, - const onnxruntime::IndexedSubGraph& indexedSubGraph, const std::unordered_map& graphNodePropertyMap, IDMLDevice* device, - const void* executionHandle) + const void* executionHandle, + const onnxruntime::Path& modelPath, + gsl::span subgraphNodes, + gsl::span subgraphInputs, + gsl::span subgraphOutputs) { - const gsl::span subGraphInputArgNames = indexedSubGraph.GetMetaDef()->inputs; - const gsl::span subGraphOutputArgNames = indexedSubGraph.GetMetaDef()->outputs; struct NodeAndIndex { uint32_t nodeIndex; // The index of the node itself @@ -164,12 +164,14 @@ namespace Dml::GraphDescBuilder // Map from Lotus node argument names to the new node and index where it will be produced std::unordered_map nameToNodeAndIndexMap; + std::unordered_map nodeOutputShapes; + // Map from 
Lotus node argument names to input indices of the fused kernel node. std::unordered_map nameToDmlFusedNodeInputIndex; - for (size_t inputIndex = 0; inputIndex < subGraphInputArgNames.size(); ++inputIndex) + for (size_t inputIndex = 0; inputIndex < subgraphInputs.size(); ++inputIndex) { - const onnxruntime::NodeArg* graphInput = graph.GetNodeArg(subGraphInputArgNames[inputIndex]); + const onnxruntime::NodeArg* graphInput = subgraphInputs[inputIndex]; if (!graphInput) { @@ -196,13 +198,11 @@ namespace Dml::GraphDescBuilder const uint32_t minNodeCountToReuseCommandList = 5; bool reuseCommandList = false; - if (indexedSubGraph.nodes.size() >= minNodeCountToReuseCommandList) + if (subgraphNodes.size() >= minNodeCountToReuseCommandList) { reuseCommandList = true; } - auto modelPath = graph.ModelPath(); - auto constantCpuGraphInputGetter = [&isInitializerTransferable, &modelPath](const std::string& argName) { ComPtr tensorWrapper; @@ -219,9 +219,11 @@ namespace Dml::GraphDescBuilder // Iterate through each node and create a corresponding node in the new graph // We can iterate the nodes in any order because the edge connectivity will take care of the topological order - for (size_t sortedNodeIndex : indexedSubGraph.nodes) + std::unordered_map> inferredOutputShapes; + + for (const onnxruntime::Node* subgraphNode : subgraphNodes) { - const onnxruntime::Node& node = *graph.GetNode(sortedNodeIndex); + const onnxruntime::Node& node = *subgraphNode; const GraphNodeProperties& graphNodeProps = graphNodePropertyMap.find(GetUniqueNodeName(node))->second; const auto& requiredConstantCpuInputs = graphNodeProps.internalRegInfo->requiredConstantCpuInputs; @@ -244,14 +246,45 @@ namespace Dml::GraphDescBuilder return tensor; }; + EdgeShapes inputShapesOverrides(node.InputDefs().size()); + + // Override the input shapes with shapes that were previously inferred + for (int inputIndex = 0; inputIndex < node.InputDefs().size(); ++inputIndex) + { + auto inputDef = node.InputDefs()[inputIndex]; + + auto outputShapesIter = inferredOutputShapes.find(inputDef->Name()); + if (outputShapesIter != inferredOutputShapes.end()) + { + inputShapesOverrides.GetMutableShape(inputIndex) = outputShapesIter->second; + } + else if (inputDef->HasTensorOrScalarShape()) + { + for (int i = 0; i < inputDef->Shape()->dim_size(); ++i) + { + ORT_THROW_HR_IF(E_INVALIDARG, !inputDef->Shape()->dim(i).has_dim_value()); + inputShapesOverrides.GetMutableShape(inputIndex).push_back(gsl::narrow_cast(inputDef->Shape()->dim(i).dim_value())); + } + } + } + + EdgeShapes outputShapes; DmlGraphNodeCreateInfo graphNodeCreateInfo; graphNodeProps.internalRegInfo->graphNodeFactoryRegistration->factory( node, constantCpuNodeInputGetter, executionHandle, + &inputShapesOverrides, + /*out*/ &outputShapes, /*out*/ &graphNodeCreateInfo ); + ORT_THROW_HR_IF(E_UNEXPECTED, outputShapes.EdgeCount() != node.OutputDefs().size()); + for (int i = 0; i < node.OutputDefs().size(); ++i) + { + inferredOutputShapes[node.OutputDefs()[i]->Name()] = outputShapes.GetShape(i); + } + // Create a map between operatorGraphNodeIndex to mainGraphNodeIndex. 
std::unordered_map operatorGraphNodeIndexToMainGraphNodeIndexMap; uint32_t graphNodeCount = gsl::narrow_cast(graphNodes.size()); @@ -347,6 +380,8 @@ namespace Dml::GraphDescBuilder operatorGraphNodeIndexToMainGraphNodeIndexMap[operatorGraphOutputEdge.FromNodeIndex], operatorGraphOutputEdge.FromNodeOutputIndex }; + + nodeOutputShapes[arg->Name()] = outputShapes; } } @@ -367,10 +402,12 @@ namespace Dml::GraphDescBuilder } } + EdgeShapes graphOutputShapes(subgraphOutputs.size()); + // Add graph output nodes, which might be in a different order from the encapsulating node - for (size_t outputIndex = 0; outputIndex < subGraphOutputArgNames.size(); ++outputIndex) + for (size_t outputIndex = 0; outputIndex < subgraphOutputs.size(); ++outputIndex) { - const onnxruntime::NodeArg* graphOutput = graph.GetNodeArg(subGraphOutputArgNames[outputIndex]); + const onnxruntime::NodeArg* graphOutput = subgraphOutputs[outputIndex]; ORT_THROW_HR_IF_NULL_MSG(E_POINTER, graphOutput, "FusedNode's nodeArgList does not contain one of the nodeArg"); const auto& outputNodeAndIndex = nameToNodeAndIndexMap.at(graphOutput->Name()); @@ -380,6 +417,7 @@ namespace Dml::GraphDescBuilder edge.FromNodeOutputIndex = outputNodeAndIndex.targetIndex; edge.GraphOutputIndex = gsl::narrow_cast(outputIndex); graphOutputEdges.push_back(edge); + graphOutputShapes.GetMutableShape(outputIndex) = nodeOutputShapes[graphOutput->Name()].GetShape(outputNodeAndIndex.targetIndex); } RemoveUnconnectedNodes(graphNodes, graphInputEdges, graphIntermediateEdges, graphOutputEdges); @@ -390,6 +428,7 @@ namespace Dml::GraphDescBuilder graphDesc.outputEdges = std::move(graphOutputEdges); graphDesc.intermediateEdges = std::move(graphIntermediateEdges); graphDesc.reuseCommandList = reuseCommandList; + graphDesc.outputShapes = std::move(graphOutputShapes); return graphDesc; } } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h index 5c04962e55557..0039678c00e59 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h @@ -9,10 +9,10 @@ namespace Dml { struct GraphNodeProperties { - std::shared_ptr + std::shared_ptr internalRegInfo; - // These are currently passed from the partitioning step since the only DML operators current + // These are currently passed from the partitioning step since the only DML operators current // supporting graph nodes don't customize the order of edges or shapes, other than coercing // dimension count. This will change as the supported set of operators as graph nodes increases. 
Windows::AI::MachineLearning::Adapter::EdgeShapes inputShapes; @@ -38,16 +38,19 @@ namespace Dml std::vector outputEdges; std::vector intermediateEdges; bool reuseCommandList; + Windows::AI::MachineLearning::Adapter::EdgeShapes outputShapes; }; GraphDesc BuildGraphDesc( const uint8_t* isConstGpuGraphInput, const size_t isConstGpuGraphInputCount, const std::unordered_map>& isInitializerTransferable, - const onnxruntime::Graph& graph, - const onnxruntime::IndexedSubGraph& indexedSubGraph, const std::unordered_map& graphNodePropertyMap, IDMLDevice* device, - const void* executionHandle); + const void* executionHandle, + const onnxruntime::Path& modelPath, + gsl::span subgraphNodes, + gsl::span subgraphInputs, + gsl::span subgraphOutputs); } -} \ No newline at end of file +} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp index 18943878ccedc..f7a4743801d81 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp @@ -151,6 +151,8 @@ namespace Dml _In_opt_ const std::unordered_map* nodeNameToPartitionMap, _Inout_ std::unordered_map& dmlNodePropertyMap, _Inout_ std::unordered_set& requiredInitializerMap, + _Inout_ std::unordered_set& dynamicCpuInputMap, + bool allowDmlGraphDynamicShapes, _Out_ bool* isDmlGraphNode ) { @@ -172,36 +174,68 @@ namespace Dml if (internalRegInfo && internalRegInfo->graphNodeFactoryRegistration) { - bool requiredCpuInputsConstant = true; - for (uint32_t inputIndex : internalRegInfo->requiredConstantCpuInputs) + if (allowDmlGraphDynamicShapes) { - if (inputIndex >= node.InputDefs().size() || !node.InputDefs()[inputIndex]->Exists()) + for (uint32_t inputIndex : internalRegInfo->requiredConstantCpuInputs) { - continue; - } + if (inputIndex >= node.InputDefs().size() || !node.InputDefs()[inputIndex]->Exists()) + { + continue; + } - const onnx::TensorProto* tensor = nullptr; - const std::string& inputName = node.InputDefs()[inputIndex]->Name(); + const onnx::TensorProto* tensor = nullptr; + const std::string& inputName = node.InputDefs()[inputIndex]->Name(); - if (!graph.GetInitializedTensor(inputName, tensor)) - { - requiredCpuInputsConstant = false; - break; + if (graph.GetInitializedTensor(inputName, tensor)) + { + requiredInitializerMap.insert(inputName); + } + else + { + dynamicCpuInputMap.insert(inputName); + } } - requiredInitializerMap.insert(inputName); + std::optional requiredInputCount = internalRegInfo->graphNodeFactoryRegistration->requiredInputCount; + if (requiredInputCount == std::nullopt || *requiredInputCount == node.InputDefs().size()) + { + *isDmlGraphNode = true; + graphNodeProperty.first->second.internalRegInfo = internalRegInfo; + } } - - std::optional requiredInputCount = internalRegInfo->graphNodeFactoryRegistration->requiredInputCount; - if (requiredCpuInputsConstant && - TryGetStaticInputShapes( node, graphNodeProperty.first->second.inputShapes) && - !ContainsEmptyDimensions(graphNodeProperty.first->second.inputShapes, internalRegInfo->requiredConstantCpuInputs) && - TryGetStaticOutputShapes(node, graphNodeProperty.first->second.outputShapes) && - !ContainsEmptyDimensions(graphNodeProperty.first->second.outputShapes, internalRegInfo->requiredConstantCpuInputs) && - (requiredInputCount == std::nullopt || *requiredInputCount == node.InputDefs().size())) + else { - *isDmlGraphNode = true; - 
graphNodeProperty.first->second.internalRegInfo = internalRegInfo; + bool requiredCpuInputsConstant = true; + for (uint32_t inputIndex : internalRegInfo->requiredConstantCpuInputs) + { + if (inputIndex >= node.InputDefs().size() || !node.InputDefs()[inputIndex]->Exists()) + { + continue; + } + + const onnx::TensorProto* tensor = nullptr; + const std::string& inputName = node.InputDefs()[inputIndex]->Name(); + + if (!graph.GetInitializedTensor(inputName, tensor)) + { + requiredCpuInputsConstant = false; + break; + } + + requiredInitializerMap.insert(inputName); + } + + std::optional requiredInputCount = internalRegInfo->graphNodeFactoryRegistration->requiredInputCount; + if (requiredCpuInputsConstant && + TryGetStaticInputShapes( node, graphNodeProperty.first->second.inputShapes) && + !ContainsEmptyDimensions(graphNodeProperty.first->second.inputShapes, internalRegInfo->requiredConstantCpuInputs) && + TryGetStaticOutputShapes(node, graphNodeProperty.first->second.outputShapes) && + !ContainsEmptyDimensions(graphNodeProperty.first->second.outputShapes, internalRegInfo->requiredConstantCpuInputs) && + (requiredInputCount == std::nullopt || *requiredInputCount == node.InputDefs().size())) + { + *isDmlGraphNode = true; + graphNodeProperty.first->second.internalRegInfo = internalRegInfo; + } } } } @@ -379,8 +413,10 @@ namespace Dml uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE. std::unordered_map& graphNodePropertyMap, std::unordered_set& requiredInitializerMap, + std::unordered_set& dynamicCpuInputMap, gsl::span additionalSplittingNodes, - const std::unordered_map& implicitInputs) + const std::unordered_map& implicitInputs, + bool allowDmlGraphDynamicShapes) { // Nodes are uniquely identified by the name of their first output argument std::vector> partitions; @@ -443,6 +479,8 @@ namespace Dml &nodeNameToPartitionMap, graphNodePropertyMap, requiredInitializerMap, + dynamicCpuInputMap, + allowDmlGraphDynamicShapes, /*out*/ &isDmlGraphNode ); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h index 37d577f647fb5..3bddb5ae16086 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h @@ -50,6 +50,8 @@ namespace Dml uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE. std::unordered_map& graphNodePropertyMap, std::unordered_set& requiredInitializerMap, + std::unordered_set& dynamicCpuInputMap, gsl::span additionalSplittingNodes, - const std::unordered_map& implicitInputs); + const std::unordered_map& implicitInputs, + bool allowDmlGraphDynamicShapes); } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h index d7a0a607cdec9..a8a6d6745e908 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h @@ -2,8 +2,15 @@ // Licensed under the MIT License. 
#pragma once + +#include + #include "core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h" +interface IDMLCompiledOperator; +struct DML_BUFFER_BINDING; +struct DML_BINDING_DESC; + namespace Dml { struct Binding diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp index 6cd10e14e08d2..4deec620fe5fb 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp @@ -1356,13 +1356,14 @@ namespace Windows::AI::MachineLearning::Adapter const onnxruntime::OpNodeProtoHelper* protoHelper, const void* executionHandle, bool isInternalOperator, + const EdgeShapes* inputShapesOverrides, const EdgeShapes* inferredOutputShapes, const AttributeMap* defaultAttributes, DmlGraphNodeCreateInfo* graphNodeCreateInfo, gsl::span requiredConstantCpuInputs, MLOperatorTensorGetter& constantInputGetter ) - : OpNodeInfoWrapper(protoHelper, nullptr, defaultAttributes, requiredConstantCpuInputs, constantInputGetter, nullptr), + : OpNodeInfoWrapper(protoHelper, inputShapesOverrides, defaultAttributes, requiredConstantCpuInputs, constantInputGetter, nullptr), m_inferredOutputShapes(inferredOutputShapes), m_internalOperator(isInternalOperator), m_graphNodeCreateInfo(graphNodeCreateInfo) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h index a7f8bebb2de78..913997ff4ad49 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h @@ -4,6 +4,7 @@ #pragma once #include "core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h" #include "core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h" +#include "core/providers/dml/DmlExecutionProvider/src/DmlEdgeShapes.h" #include "core/framework/op_kernel.h" #include "core/framework/customregistry.h" #include "core/framework/tensorprotoutils.h" @@ -93,42 +94,6 @@ struct AttributeValue using AttributeMap = std::map; -// Encapsulation of shapes across different edges of an operator. Non-tensor -// edges and unused edges have an empty array of dimensions. -class EdgeShapes -{ -public: - EdgeShapes() = default; - - EdgeShapes(size_t count) : m_shapes(count) {} - - const std::vector& GetShape(size_t edgeIndex) const - { - return m_shapes[edgeIndex]; - } - - std::vector& GetMutableShape(size_t edgeIndex) - { - return m_shapes[edgeIndex]; - } - - size_t EdgeCount() const { return m_shapes.size(); } - - void Reset(size_t edge_count) - { - m_shapes.clear(); - m_shapes.resize(edge_count); - } - - bool operator!=(const EdgeShapes& other) const noexcept - { - return (m_shapes != other.m_shapes); - } - - private: - std::vector> m_shapes; -}; - // Base class for ABI objects which may be "Closed", at which point calls will predictably // fail or return a dummy value. 
This is used for transient ABI context objects which // are passed to methods on kernel or inferencers, and which wrap Lotus objects whose lifetimes @@ -434,6 +399,7 @@ class DmlGraphOpKernelInfoWrapper : public OpNodeInfoWrapper< const onnxruntime::OpNodeProtoHelper * protoHelper, const void* executionHandle, bool isInternalOperator, + const EdgeShapes* inputShapesOverrides, const EdgeShapes* inferredOutputShapes, const AttributeMap* defaultAttributes, DmlGraphNodeCreateInfo* graphNodeCreateInfo, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlGridSample.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlGridSample.h index c63863853fb4e..4bbc8a4b718da 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlGridSample.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlGridSample.h @@ -51,16 +51,6 @@ namespace GridSample_float_float #include "GeneratedShaders/grid_sample_float_float.h" } -namespace GridSample_double_float -{ - #include "GeneratedShaders/grid_sample_double_float.h" -} - -namespace GridSample_bool_float -{ - #include "GeneratedShaders/grid_sample_bool_float.h" -} - namespace GridSample_uint16_fp16 { #include "GeneratedShaders/grid_sample_uint16_fp16.h" @@ -101,66 +91,6 @@ namespace GridSample_float_fp16 #include "GeneratedShaders/grid_sample_float_fp16.h" } -namespace GridSample_double_fp16 -{ - #include "GeneratedShaders/grid_sample_double_fp16.h" -} - -namespace GridSample_bool_fp16 -{ - #include "GeneratedShaders/grid_sample_bool_fp16.h" -} - -namespace GridSample_uint16_double -{ - #include "GeneratedShaders/grid_sample_uint16_double.h" -} - -namespace GridSample_uint_double -{ - #include "GeneratedShaders/grid_sample_uint_double.h" -} - -namespace GridSample_uint64_double -{ - #include "GeneratedShaders/grid_sample_uint64_double.h" -} - -namespace GridSample_int16_double -{ - #include "GeneratedShaders/grid_sample_int16_double.h" -} - -namespace GridSample_int_double -{ - #include "GeneratedShaders/grid_sample_int_double.h" -} - -namespace GridSample_int64_double -{ - #include "GeneratedShaders/grid_sample_int64_double.h" -} - -namespace GridSample_fp16_double -{ - #include "GeneratedShaders/grid_sample_fp16_double.h" -} - -namespace GridSample_float_double -{ - #include "GeneratedShaders/grid_sample_float_double.h" -} - -namespace GridSample_double_double -{ - #include "GeneratedShaders/grid_sample_double_double.h" -} - -namespace GridSample_bool_double -{ - #include "GeneratedShaders/grid_sample_bool_double.h" -} - #include #include @@ -471,14 +401,6 @@ class DmlGridSampleOperator : public WRL::Base computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_float_float::g_GridSample, sizeof(GridSample_float_float::g_GridSample)); break; - case MLOperatorTensorDataType::Double: - computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_double_float::g_GridSample, sizeof(GridSample_double_float::g_GridSample)); - break; - - case MLOperatorTensorDataType::Bool: - computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_bool_float::g_GridSample, sizeof(GridSample_bool_float::g_GridSample)); - break; - default: ORT_THROW_HR(E_INVALIDARG); } @@ -520,63 +442,6 @@ class DmlGridSampleOperator : public WRL::Base computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_float_fp16::g_GridSample, sizeof(GridSample_float_fp16::g_GridSample)); break; - case MLOperatorTensorDataType::Double: - computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_double_fp16::g_GridSample, 
sizeof(GridSample_double_fp16::g_GridSample)); - break; - - case MLOperatorTensorDataType::Bool: - computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_bool_fp16::g_GridSample, sizeof(GridSample_bool_fp16::g_GridSample)); - break; - - default: - ORT_THROW_HR(E_INVALIDARG); - } - break; - } - case MLOperatorTensorDataType::Double: - { - switch (inputDataType) - { - case MLOperatorTensorDataType::UInt16: - computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_uint16_double::g_GridSample, sizeof(GridSample_uint16_double::g_GridSample)); - break; - - case MLOperatorTensorDataType::UInt32: - computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_uint_double::g_GridSample, sizeof(GridSample_uint_double::g_GridSample)); - break; - - case MLOperatorTensorDataType::UInt64: - computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_uint64_double::g_GridSample, sizeof(GridSample_uint64_double::g_GridSample)); - break; - - case MLOperatorTensorDataType::Int16: - computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_int16_double::g_GridSample, sizeof(GridSample_int16_double::g_GridSample)); - break; - - case MLOperatorTensorDataType::Int32: - computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_int_double::g_GridSample, sizeof(GridSample_int_double::g_GridSample)); - break; - - case MLOperatorTensorDataType::Int64: - computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_int64_double::g_GridSample, sizeof(GridSample_int64_double::g_GridSample)); - break; - - case MLOperatorTensorDataType::Float16: - computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_fp16_double::g_GridSample, sizeof(GridSample_fp16_double::g_GridSample)); - break; - - case MLOperatorTensorDataType::Float: - computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_float_double::g_GridSample, sizeof(GridSample_float_double::g_GridSample)); - break; - - case MLOperatorTensorDataType::Double: - computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_double_double::g_GridSample, sizeof(GridSample_double_double::g_GridSample)); - break; - - case MLOperatorTensorDataType::Bool: - computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_bool_double::g_GridSample, sizeof(GridSample_bool_double::g_GridSample)); - break; - default: ORT_THROW_HR(E_INVALIDARG); } @@ -901,6 +766,7 @@ class DmlGridSampleOperatorFactory : public WRL::Base kernelDescription.executionType = MLOperatorExecutionType::D3D12; // T1: tensor(float16), tensor(float), tensor(double), tensor(bfloat16) + // tensor(double) is not supported for GPU MLOperatorEdgeTypeConstrant t1Constraint; t1Constraint.typeLabel = "T1"; std::vector t1AllowedEdges diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMultiHeadAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMultiHeadAttention.cpp index 9c1a7baeaa8df..03500d0ee86a9 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMultiHeadAttention.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMultiHeadAttention.cpp @@ -205,12 +205,34 @@ class DmlOperatorMultiHeadAttention : public DmlOperator else { const auto keyPaddingMaskTensorShape = m_inputTensorDescs[dmlMaskIndex].GetSizes(); - ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape.size() == 2); + size_t maskDimCount = keyPaddingMaskTensorShape.size(); + ML_CHECK_VALID_ARGUMENT(maskDimCount >= 2 || maskDimCount <= 4); ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[0] == batchSize); - 
ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[1] == kvSequenceLength); - const uint32_t actualShape[4] = {batchSize, 1, 1, kvSequenceLength}; - const uint32_t desiredShape[4] = {batchSize, numHeads, sequenceLength, kvSequenceLength}; + std::array actualShape{}; + std::array desiredShape{}; + + if (maskDimCount == 2) + { + ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[1] == kvSequenceLength); + actualShape = {batchSize, 1, 1, kvSequenceLength}; + desiredShape = {batchSize, numHeads, sequenceLength, kvSequenceLength}; + } + else if (maskDimCount == 3) + { + ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[1] == sequenceLength); + ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[2] == totalSequenceLength); + actualShape = {batchSize, 1, sequenceLength, totalSequenceLength}; + desiredShape = {batchSize, numHeads, sequenceLength, totalSequenceLength}; + } + else if (maskDimCount == 4) + { + ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[1] == numHeads); + ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[2] == sequenceLength); + ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[3] == totalSequenceLength); + actualShape = {batchSize, numHeads, sequenceLength, totalSequenceLength}; + desiredShape = {batchSize, numHeads, sequenceLength, totalSequenceLength}; + } m_inputTensorDescs[dmlMaskIndex] = TensorDesc::ConstructBroadcastedTensorDesc( m_inputTensorDescs[dmlMaskIndex].GetMlOperatorDataType(), diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp new file mode 100644 index 0000000000000..30c339b845b36 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp @@ -0,0 +1,436 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "precomp.h" + +// This operator is easier to understand by looking at a python implementation of the non-interleaved version: +// +// def rotate_half(x): +// """Rotates half the hidden dims of the input.""" +// half_dim = x.shape[-1] // 2 +// x1 = x[..., :half_dim] +// x2 = x[..., half_dim:] +// return np.concatenate((-x2, x1), dim=-1) +// +// +// def apply_rope(x, cos, sin, position_ids): +// cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] +// sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] +// x_embed = (x * cos) + (rotate_half(x) * sin) +// return x_embed +// +// For the non-interleaved version, we multiply the cos cache by the non-rotated input tensor while we multiply the sin cache +// by the rotated input tensor. Rotating the tensor means slicing it in half on the head dimension and swapping the 2 halves. +// +// The interleaved version is very similar but instead of swapping 2 halves, we swap every pair of adjacent elements and we swap +// the sign of every adjacent element. 
+ +namespace Dml +{ +class DmlOperatorRotaryEmbedding : public DmlOperator +{ +public: + DmlOperatorRotaryEmbedding(const MLOperatorKernelCreationContext& kernelInfo) : DmlOperator(kernelInfo) + { + enum InputIndex : uint32_t + { + inputDataIndex, + positionIdsIndex, + cosCacheIndex, + sinCacheIndex, + }; + + ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 4); + ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1); + + // When positionIds is a scalar, it represents the start offset for each sequence + const bool positionIdsIsOffset = kernelInfo.GetInputTensorDimensionCount(positionIdsIndex) == 1; + + Initialize(kernelInfo); + + ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[inputDataIndex].GetDimensionCount() == 4); + ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[positionIdsIndex].GetDimensionCount() == 4); + ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[cosCacheIndex].GetDimensionCount() == 4); + ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[sinCacheIndex].GetDimensionCount() == 4); + + ML_CHECK_VALID_ARGUMENT(m_outputTensorDescs[0].GetDimensionCount() == 4); + + ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[cosCacheIndex].GetSizes() == m_inputTensorDescs[sinCacheIndex].GetSizes()); + const uint32_t headSize = m_inputTensorDescs[cosCacheIndex].GetSizes().back() * 2; + + // The last dimension of the data is the hidden size, so it must be divisible by the head size + ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[inputDataIndex].GetSizes().back() % headSize == 0); + + // We resize the data to be of shape [batchSize, sequenceLength, numHeads, headSize] + const auto inputDataSizes = m_inputTensorDescs[inputDataIndex].GetSizes(); + const uint32_t batchSize = inputDataSizes[1]; + const uint32_t sequenceLength = inputDataSizes[2]; + const uint32_t numHeads = inputDataSizes[3] / headSize; + + const auto cosCacheSizes = m_inputTensorDescs[cosCacheIndex].GetSizes(); + const uint32_t maxSequenceLength = cosCacheSizes[cosCacheSizes.size() - 2]; + + if (sequenceLength > maxSequenceLength) + { + ORT_NOT_IMPLEMENTED("Updating cos_cache and sin_cache in RotaryEmbedding is not currently supported"); + } + + const bool interleaved = gsl::narrow_cast(kernelInfo.GetOptionalAttribute(AttrName::Interleaved, 0)); + + std::vector inputDescs = GetDmlInputDescs(); + const MLOperatorTensorDataType dataType = kernelInfo.GetInputEdgeDescription(inputDataIndex).tensorDataType; + + // Splitting the hiddenSize into numHeads and headSize dimensions makes it easier for DML to handle + const std::array inputOutputShape = {batchSize, sequenceLength, numHeads, headSize}; + TensorDesc inputOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputOutputShape); + const DML_TENSOR_DESC inputOutputDmlTensorDesc = inputOutputTensorDesc.GetDmlDesc(); + + // Copy the input to preserve its real input shape in the graph without reshaping it. This will disappear during DML's graph compilation phase. + DML_SCALE_BIAS scaleBias = {1.0f, 0.0f}; + + DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC copyInputDesc{}; + copyInputDesc.InputTensor = &inputOutputDmlTensorDesc; + copyInputDesc.OutputTensor = &inputOutputDmlTensorDesc; + copyInputDesc.ScaleBias = &scaleBias; + const DML_OPERATOR_DESC copyInputDmlDesc = {DML_OPERATOR_ELEMENT_WISE_IDENTITY, ©InputDesc}; + + // Split the input data into 2 equal parts + const std::vector inputDataTensorShape = interleaved + ? 
std::vector({batchSize, sequenceLength, numHeads, headSize / 2, 2}) + : std::vector({batchSize, sequenceLength, numHeads, 2, headSize / 2}); + + const std::vector splitInputDataTensorShape = interleaved + ? std::vector({batchSize, sequenceLength, numHeads, headSize / 2, 1}) + : std::vector({batchSize, sequenceLength, numHeads, 1, headSize / 2}); + + TensorDesc inputDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputDataTensorShape); + const DML_TENSOR_DESC inputDataDmlTensorDesc = inputDataTensorDesc.GetDmlDesc(); + + TensorDesc splitInputDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, splitInputDataTensorShape); + const std::array splitInputDataDmlTensorDescs = {splitInputDataTensorDesc.GetDmlDesc(), splitInputDataTensorDesc.GetDmlDesc()}; + + DML_SPLIT_OPERATOR_DESC splitInputDesc{}; + splitInputDesc.InputTensor = &inputDataDmlTensorDesc; + splitInputDesc.OutputTensors = splitInputDataDmlTensorDescs.data(); + splitInputDesc.OutputCount = gsl::narrow_cast(splitInputDataDmlTensorDescs.size()); + splitInputDesc.Axis = interleaved + ? gsl::narrow_cast(splitInputDataTensorShape.size()) - 1 + : gsl::narrow_cast(splitInputDataTensorShape.size()) - 2; + + const DML_OPERATOR_DESC splitInputDmlDesc = {DML_OPERATOR_SPLIT, &splitInputDesc}; + + // Swap the 2 halves and join them together + DML_JOIN_OPERATOR_DESC joinInputDesc{}; + joinInputDesc.InputTensors = splitInputDataDmlTensorDescs.data(); + joinInputDesc.OutputTensor = &inputDataDmlTensorDesc; + joinInputDesc.Axis = splitInputDesc.Axis; + joinInputDesc.InputCount = gsl::narrow_cast(splitInputDataDmlTensorDescs.size()); + const DML_OPERATOR_DESC joinInputDmlDesc = {DML_OPERATOR_JOIN, &joinInputDesc}; + + // We generate a sequence from 0 to sequenceLength and add the offset to it + const std::array positionIdsRangeShape = {1, 1, 1, sequenceLength}; + auto positionIdsDataType = kernelInfo.GetInputEdgeDescription(positionIdsIndex).tensorDataType; + TensorDesc positionIdsRangeTensorDesc = TensorDesc::ConstructDefaultTensorDesc(positionIdsDataType, positionIdsRangeShape); + const DML_TENSOR_DESC positionIdsRangeDmlTensorDesc = positionIdsRangeTensorDesc.GetDmlDesc(); + + const std::array broadcastedPositionIdsRangeShape = {1, 1, batchSize, sequenceLength}; + TensorDesc broadcastedPositionIdsRangeTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(positionIdsDataType, broadcastedPositionIdsRangeShape, positionIdsRangeShape); + const DML_TENSOR_DESC broadcastedPositionIdsRangeDmlTensorDesc = broadcastedPositionIdsRangeTensorDesc.GetDmlDesc(); + + const std::array broadcastedOffsetShape = {1, 1, batchSize, sequenceLength}; + TensorDesc broadcastedOffsetTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(positionIdsDataType, broadcastedOffsetShape, m_inputTensorDescs[positionIdsIndex].GetSizes()); + const DML_TENSOR_DESC broadcastedOffsetDmlTensorDesc = broadcastedOffsetTensorDesc.GetDmlDesc(); + + TensorDesc offsetPositionIdsTensorDesc = TensorDesc::ConstructDefaultTensorDesc(positionIdsDataType, broadcastedOffsetShape); + const DML_TENSOR_DESC offsetPositionIdsRangeDmlTensorDesc = offsetPositionIdsTensorDesc.GetDmlDesc(); + + DML_FILL_VALUE_SEQUENCE_OPERATOR_DESC positionIdsRange{}; + DML_ELEMENT_WISE_ADD_OPERATOR_DESC positionIdsAddOffset{}; + if (positionIdsIsOffset) + { + ML_CHECK_VALID_ARGUMENT(positionIdsDataType == MLOperatorTensorDataType::Int64); + positionIdsRange.ValueDataType = DML_TENSOR_DATA_TYPE_INT64; + positionIdsRange.ValueDelta.Int64 = 1; + positionIdsRange.OutputTensor = 
&positionIdsRangeDmlTensorDesc; + + positionIdsAddOffset.ATensor = &broadcastedPositionIdsRangeDmlTensorDesc; + positionIdsAddOffset.BTensor = &broadcastedOffsetDmlTensorDesc; + positionIdsAddOffset.OutputTensor = &offsetPositionIdsRangeDmlTensorDesc; + } + const DML_OPERATOR_DESC positionIdsRangeDmlDesc = {DML_OPERATOR_FILL_VALUE_SEQUENCE, &positionIdsRange}; + const DML_OPERATOR_DESC positionIdsAddOffsetDmlDesc = {DML_OPERATOR_ELEMENT_WISE_ADD, &positionIdsAddOffset}; + + // Gather the cos/sin values based on the position ids + const std::array gatheredCosSinShape = {1, batchSize, sequenceLength, headSize / 2}; + TensorDesc gatheredCosSinTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, gatheredCosSinShape); + const DML_TENSOR_DESC gatheredCosSinDmlTensorDesc = gatheredCosSinTensorDesc.GetDmlDesc(); + + DML_GATHER_OPERATOR_DESC gatherCosSinDesc{}; + gatherCosSinDesc.InputTensor = &inputDescs[cosCacheIndex]; + gatherCosSinDesc.IndicesTensor = positionIdsIsOffset ? &offsetPositionIdsRangeDmlTensorDesc : &inputDescs[positionIdsIndex]; + gatherCosSinDesc.OutputTensor = &gatheredCosSinDmlTensorDesc; + gatherCosSinDesc.Axis = 2; + gatherCosSinDesc.IndexDimensions = 2; + const DML_OPERATOR_DESC gatherCosSinDmlDesc {DML_OPERATOR_GATHER, &gatherCosSinDesc}; + + // After gathering cos/sin, reshape and broadcast them to match the number of heads of the input data + const std::vector reshapedCosSinShape = interleaved + ? std::vector({batchSize, sequenceLength, 1, headSize / 2, 1}) + : std::vector({batchSize, sequenceLength, 1, 1, headSize / 2}); + TensorDesc broadcastedCosSinTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(dataType, inputDataTensorShape, reshapedCosSinShape); + const DML_TENSOR_DESC broadcastedCosSinDmlTensorDesc = broadcastedCosSinTensorDesc.GetDmlDesc(); + + // Create a vector that contains the sign values {-1, 1} + const std::array signTensorShape = {2}; + TensorDesc signTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, signTensorShape); + const DML_TENSOR_DESC signDmlTensorDesc = signTensorDesc.GetDmlDesc(); + + DML_FILL_VALUE_SEQUENCE_OPERATOR_DESC signRange{}; + signRange.OutputTensor = &signDmlTensorDesc; + if (dataType == MLOperatorTensorDataType::Float16) + { + const auto valueStart = static_cast(-1.0f); + const auto valueDelta = static_cast(2.0f); + memcpy(signRange.ValueStart.Bytes, reinterpret_cast(&valueStart), sizeof(valueStart)); + memcpy(signRange.ValueDelta.Bytes, reinterpret_cast(&valueDelta), sizeof(valueDelta)); + signRange.ValueDataType = DML_TENSOR_DATA_TYPE_FLOAT16; + } + else + { + ML_CHECK_VALID_ARGUMENT(dataType == MLOperatorTensorDataType::Float); + signRange.ValueStart.Float32 = -1.0f; + signRange.ValueDelta.Float32 = 2.0f; + signRange.ValueDataType = DML_TENSOR_DATA_TYPE_FLOAT32; + } + const DML_OPERATOR_DESC signRangeDmlDesc = {DML_OPERATOR_FILL_VALUE_SEQUENCE, &signRange}; + + // Multiply the broadcasted sign values with the rotated input + const std::vector reshapedSignShape = interleaved + ? 
std::vector({1, 1, 1, 1, 2}) + : std::vector({1, 1, 1, 2, 1}); + TensorDesc broadcastedSignCosSinTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(dataType, inputDataTensorShape, reshapedSignShape); + const DML_TENSOR_DESC broadcastedSignDmlTensorDesc = broadcastedSignCosSinTensorDesc.GetDmlDesc(); + + DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC mulSignDesc{}; + mulSignDesc.ATensor = &inputDataDmlTensorDesc; + mulSignDesc.BTensor = &broadcastedSignDmlTensorDesc; + mulSignDesc.OutputTensor = &inputDataDmlTensorDesc; + const DML_OPERATOR_DESC mulSignDmlDesc = {DML_OPERATOR_ELEMENT_WISE_MULTIPLY, &mulSignDesc}; + + // Multiply the non-rotated data with the cos and the rotated data with the sin + DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC mulCosSinDesc{}; + mulCosSinDesc.ATensor = &inputDataDmlTensorDesc; + mulCosSinDesc.BTensor = &broadcastedCosSinDmlTensorDesc; + mulCosSinDesc.OutputTensor = &inputDataDmlTensorDesc; + const DML_OPERATOR_DESC mulCosSinDmlDesc = {DML_OPERATOR_ELEMENT_WISE_MULTIPLY, &mulCosSinDesc}; + + // Add the multiplied cos and sin values together + DML_ELEMENT_WISE_ADD_OPERATOR_DESC addDesc{}; + addDesc.ATensor = &inputOutputDmlTensorDesc; + addDesc.BTensor = &inputOutputDmlTensorDesc; + addDesc.OutputTensor = &inputOutputDmlTensorDesc; + const DML_OPERATOR_DESC addDmlDesc = {DML_OPERATOR_ELEMENT_WISE_ADD, &addDesc}; + + // Construct the graph + std::vector inputEdges; + std::vector intermediateEdges; + std::vector outputEdges; + + std::vector opDescs = { + ©InputDmlDesc, // Copy the input data to preseve the real input shape + &splitInputDmlDesc, // Split the input data + &gatherCosSinDmlDesc, // Gather cos + &gatherCosSinDmlDesc, // Gather sin + &signRangeDmlDesc, // Generate the signs + + &joinInputDmlDesc, // Join the split data + &mulCosSinDmlDesc, // Multiply cos with the non-rotated data + &mulCosSinDmlDesc, // Multiply sin with the rotated data + &mulSignDmlDesc, // Multiply the sign with the rotated data + &addDmlDesc, // Add the rotated cos and non-rotated sin parts together + }; + + enum NodeIndex : uint32_t + { + copyInputOpIndex, + splitInputOpIndex, + gatherCosOpIndex, + gatherSinOpIndex, + signRangeOpIndex, + + joinInputOpIndex, + mulCosOpIndex, + mulSinOpIndex, + mulSignOpIndex, + addOpIndex, + + // The following indices are optional + positionIdsRangeOpIndex, + positionIdsAddOffsetOpIndex, + }; + + if (positionIdsIsOffset) + { + opDescs.push_back(&positionIdsRangeDmlDesc); + opDescs.push_back(&positionIdsAddOffsetDmlDesc); + + DML_INPUT_GRAPH_EDGE_DESC positionIdsToAddOffsetEdge = {}; + positionIdsToAddOffsetEdge.GraphInputIndex = positionIdsIndex; + positionIdsToAddOffsetEdge.ToNodeIndex = positionIdsAddOffsetOpIndex; + positionIdsToAddOffsetEdge.ToNodeInputIndex = 1; + inputEdges.push_back(positionIdsToAddOffsetEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC positionIdsOffsetToAddOffsetEdge = {}; + positionIdsOffsetToAddOffsetEdge.FromNodeIndex = positionIdsRangeOpIndex; + positionIdsOffsetToAddOffsetEdge.FromNodeOutputIndex = 0; + positionIdsOffsetToAddOffsetEdge.ToNodeIndex = positionIdsAddOffsetOpIndex; + positionIdsOffsetToAddOffsetEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(positionIdsOffsetToAddOffsetEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC positionIdsAddOffsetToGatherCosEdge = {}; + positionIdsAddOffsetToGatherCosEdge.FromNodeIndex = positionIdsAddOffsetOpIndex; + positionIdsAddOffsetToGatherCosEdge.FromNodeOutputIndex = 0; + positionIdsAddOffsetToGatherCosEdge.ToNodeIndex = gatherCosOpIndex; + 
positionIdsAddOffsetToGatherCosEdge.ToNodeInputIndex = 1; + intermediateEdges.push_back(positionIdsAddOffsetToGatherCosEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC positionIdsAddOffsetToGatherSinEdge = {}; + positionIdsAddOffsetToGatherSinEdge.FromNodeIndex = positionIdsAddOffsetOpIndex; + positionIdsAddOffsetToGatherSinEdge.FromNodeOutputIndex = 0; + positionIdsAddOffsetToGatherSinEdge.ToNodeIndex = gatherSinOpIndex; + positionIdsAddOffsetToGatherSinEdge.ToNodeInputIndex = 1; + intermediateEdges.push_back(positionIdsAddOffsetToGatherSinEdge); + } + else + { + DML_INPUT_GRAPH_EDGE_DESC positionIdsToGatherCosEdge = {}; + positionIdsToGatherCosEdge.GraphInputIndex = positionIdsIndex; + positionIdsToGatherCosEdge.ToNodeIndex = gatherCosOpIndex; + positionIdsToGatherCosEdge.ToNodeInputIndex = 1; + inputEdges.push_back(positionIdsToGatherCosEdge); + + DML_INPUT_GRAPH_EDGE_DESC positionIdsToGatherSinEdge = {}; + positionIdsToGatherSinEdge.GraphInputIndex = positionIdsIndex; + positionIdsToGatherSinEdge.ToNodeIndex = gatherSinOpIndex; + positionIdsToGatherSinEdge.ToNodeInputIndex = 1; + inputEdges.push_back(positionIdsToGatherSinEdge); + } + + DML_INPUT_GRAPH_EDGE_DESC inputToCopyInputEdge = {}; + inputToCopyInputEdge.GraphInputIndex = inputDataIndex; + inputToCopyInputEdge.ToNodeIndex = copyInputOpIndex; + inputToCopyInputEdge.ToNodeInputIndex = 0; + inputEdges.push_back(inputToCopyInputEdge); + + DML_INPUT_GRAPH_EDGE_DESC cosToGatherEdge = {}; + cosToGatherEdge.GraphInputIndex = cosCacheIndex; + cosToGatherEdge.ToNodeIndex = gatherCosOpIndex; + cosToGatherEdge.ToNodeInputIndex = 0; + inputEdges.push_back(cosToGatherEdge); + + DML_INPUT_GRAPH_EDGE_DESC sinToGatherEdge = {}; + sinToGatherEdge.GraphInputIndex = sinCacheIndex; + sinToGatherEdge.ToNodeIndex = gatherSinOpIndex; + sinToGatherEdge.ToNodeInputIndex = 0; + inputEdges.push_back(sinToGatherEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC inputToSplitEdge = {}; + inputToSplitEdge.FromNodeIndex = copyInputOpIndex; + inputToSplitEdge.FromNodeOutputIndex = 0; + inputToSplitEdge.ToNodeIndex = splitInputOpIndex; + inputToSplitEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(inputToSplitEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC nonRotatedDataToMulEdge = {}; + nonRotatedDataToMulEdge.FromNodeIndex = copyInputOpIndex; + nonRotatedDataToMulEdge.FromNodeOutputIndex = 0; + nonRotatedDataToMulEdge.ToNodeIndex = mulCosOpIndex; + nonRotatedDataToMulEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(nonRotatedDataToMulEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC secondHalfDataToJoinEdge = {}; + secondHalfDataToJoinEdge.FromNodeIndex = splitInputOpIndex; + secondHalfDataToJoinEdge.FromNodeOutputIndex = 1; + secondHalfDataToJoinEdge.ToNodeIndex = joinInputOpIndex; + secondHalfDataToJoinEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(secondHalfDataToJoinEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC firstHalfDataToJoinEdge = {}; + firstHalfDataToJoinEdge.FromNodeIndex = splitInputOpIndex; + firstHalfDataToJoinEdge.FromNodeOutputIndex = 0; + firstHalfDataToJoinEdge.ToNodeIndex = joinInputOpIndex; + firstHalfDataToJoinEdge.ToNodeInputIndex = 1; + intermediateEdges.push_back(firstHalfDataToJoinEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC cosToMulEdge = {}; + cosToMulEdge.FromNodeIndex = gatherCosOpIndex; + cosToMulEdge.FromNodeOutputIndex = 0; + cosToMulEdge.ToNodeIndex = mulCosOpIndex; + cosToMulEdge.ToNodeInputIndex = 1; + intermediateEdges.push_back(cosToMulEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC rotatedDataToMulEdge = {}; + 
rotatedDataToMulEdge.FromNodeIndex = joinInputOpIndex; + rotatedDataToMulEdge.FromNodeOutputIndex = 0; + rotatedDataToMulEdge.ToNodeIndex = mulSinOpIndex; + rotatedDataToMulEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(rotatedDataToMulEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC sinToMulEdge = {}; + sinToMulEdge.FromNodeIndex = gatherSinOpIndex; + sinToMulEdge.FromNodeOutputIndex = 0; + sinToMulEdge.ToNodeIndex = mulSinOpIndex; + sinToMulEdge.ToNodeInputIndex = 1; + intermediateEdges.push_back(sinToMulEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC rotatedSinToMulEdge = {}; + rotatedSinToMulEdge.FromNodeIndex = mulSinOpIndex; + rotatedSinToMulEdge.FromNodeOutputIndex = 0; + rotatedSinToMulEdge.ToNodeIndex = mulSignOpIndex; + rotatedSinToMulEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(rotatedSinToMulEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC signToMulEdge = {}; + signToMulEdge.FromNodeIndex = signRangeOpIndex; + signToMulEdge.FromNodeOutputIndex = 0; + signToMulEdge.ToNodeIndex = mulSignOpIndex; + signToMulEdge.ToNodeInputIndex = 1; + intermediateEdges.push_back(signToMulEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC nonRotatedCosToAddEdge = {}; + nonRotatedCosToAddEdge.FromNodeIndex = mulCosOpIndex; + nonRotatedCosToAddEdge.FromNodeOutputIndex = 0; + nonRotatedCosToAddEdge.ToNodeIndex = addOpIndex; + nonRotatedCosToAddEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(nonRotatedCosToAddEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC rotatedSinToAddEdge = {}; + rotatedSinToAddEdge.FromNodeIndex = mulSignOpIndex; + rotatedSinToAddEdge.FromNodeOutputIndex = 0; + rotatedSinToAddEdge.ToNodeIndex = addOpIndex; + rotatedSinToAddEdge.ToNodeInputIndex = 1; + intermediateEdges.push_back(rotatedSinToAddEdge); + + DML_OUTPUT_GRAPH_EDGE_DESC addToOutputEdge = {}; + addToOutputEdge.FromNodeIndex = addOpIndex; + addToOutputEdge.FromNodeOutputIndex = 0; + addToOutputEdge.GraphOutputIndex = 0; + outputEdges.push_back(addToOutputEdge); + + MLOperatorGraphDesc operatorGraphDesc = {}; + operatorGraphDesc.inputEdgeCount = gsl::narrow_cast(inputEdges.size()); + operatorGraphDesc.inputEdges = inputEdges.data(); + operatorGraphDesc.intermediateEdgeCount = gsl::narrow_cast(intermediateEdges.size()); + operatorGraphDesc.intermediateEdges = intermediateEdges.data(); + operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); + operatorGraphDesc.outputEdges = outputEdges.data(); + operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); + operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + + SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelInfo); + } +}; + +DML_OP_DEFINE_CREATION_FUNCTION(RotaryEmbedding, DmlOperatorRotaryEmbedding); + +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/GeneratedShaders/GenerateShaders.bat b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/GeneratedShaders/GenerateShaders.bat index fb087bd800ff0..c5580ee103595 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/GeneratedShaders/GenerateShaders.bat +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/GeneratedShaders/GenerateShaders.bat @@ -16,8 +16,6 @@ if "%1" == "DEBUG" ( dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int64_t -DTBUFFER2=float -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_int64_float.h dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=float16_t -DTBUFFER2=float -enable-16bit-types -Zi -Od 
-Qembed_debug -Fh grid_sample_fp16_float.h fxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_5_0 /DTBUFFER1=float /DTBUFFER2=float /Zi /Od /Fh grid_sample_float_float.h - fxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_5_0 /DTBUFFER1=double /DTBUFFER2=float /Zi /Od /Fh grid_sample_double_float.h - fxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_5_0 /DTBUFFER1=bool /DTBUFFER2=float /Zi /Od /Fh grid_sample_bool_float.h dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint16_t -DTBUFFER2=float16_t -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_uint16_fp16.h dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint -DTBUFFER2=float16_t -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_uint_fp16.h @@ -27,20 +25,6 @@ if "%1" == "DEBUG" ( dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int64_t -DTBUFFER2=float16_t -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_int64_fp16.h dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=float16_t -DTBUFFER2=float16_t -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_fp16_fp16.h dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=float -DTBUFFER2=float16_t -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_float_fp16.h - dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=double -DTBUFFER2=float16_t -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_double_fp16.h - dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=bool -DTBUFFER2=float16_t -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_bool_fp16.h - - dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint16_t -DTBUFFER2=double -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_uint16_double.h - dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint -DTBUFFER2=double -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_uint_double.h - dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint64_t -DTBUFFER2=double -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_uint64_double.h - dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int16_t -DTBUFFER2=double -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_int16_double.h - dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int -DTBUFFER2=double -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_int_double.h - dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int64_t -DTBUFFER2=double -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_int64_double.h - dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=float16_t -DTBUFFER2=double -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_fp16_double.h - dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=float -DTBUFFER2=double -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_float_double.h - dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=double -DTBUFFER2=double -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_double_double.h - dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=bool -DTBUFFER2=double -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_bool_double.h - ) else ( fxc.exe ..\Shaders\stockham.hlsl -E DFT -T cs_5_0 /DTBUFFER=float /O3 /Qstrip_reflect /Qstrip_debug /Qstrip_rootsignature /Qstrip_priv /Fh stockham.h 
dxc.exe ..\Shaders\stockham.hlsl -E DFT -T cs_6_2 -DTBUFFER=float16_t -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh stockham_fp16.h @@ -56,8 +40,6 @@ if "%1" == "DEBUG" ( dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int64_t -DTBUFFER2=float -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_int64_float.h dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=float16_t -DTBUFFER2=float -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_fp16_float.h fxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_5_0 /DTBUFFER1=float /DTBUFFER2=float /O3 /Qstrip_reflect /Qstrip_debug /Qstrip_rootsignature /Qstrip_priv /Fh grid_sample_float_float.h - fxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_5_0 /DTBUFFER1=double /DTBUFFER2=float /O3 /Qstrip_reflect /Qstrip_debug /Qstrip_rootsignature /Qstrip_priv /Fh grid_sample_double_float.h - fxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_5_0 /DTBUFFER1=bool /DTBUFFER2=float /O3 /Qstrip_reflect /Qstrip_debug /Qstrip_rootsignature /Qstrip_priv /Fh grid_sample_bool_float.h dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint16_t -DTBUFFER2=float16_t -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_uint16_fp16.h dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint -DTBUFFER2=float16_t -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_uint_fp16.h @@ -67,18 +49,5 @@ if "%1" == "DEBUG" ( dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int64_t -DTBUFFER2=float16_t -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_int64_fp16.h dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=float16_t -DTBUFFER2=float16_t -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_fp16_fp16.h dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=float -DTBUFFER2=float16_t -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_float_fp16.h - dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=double -DTBUFFER2=float16_t -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_double_fp16.h - dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=bool -DTBUFFER2=float16_t -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_bool_fp16.h - - dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint16_t -DTBUFFER2=double -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_uint16_double.h - dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint -DTBUFFER2=double -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_uint_double.h - dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint64_t -DTBUFFER2=double -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_uint64_double.h - dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int16_t -DTBUFFER2=double -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_int16_double.h - dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 
-DTBUFFER1=int -DTBUFFER2=double -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_int_double.h - dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int64_t -DTBUFFER2=double -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_int64_double.h - dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=float16_t -DTBUFFER2=double -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_fp16_double.h - dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=float -DTBUFFER2=double -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_float_double.h - dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=double -DTBUFFER2=double -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_double_double.h - dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=bool -DTBUFFER2=double -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_bool_double.h ) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/GeneratedShaders/grid_sample_bool_double.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/GeneratedShaders/grid_sample_bool_double.h deleted file mode 100644 index 6d83865ecb7a2..0000000000000 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/GeneratedShaders/grid_sample_bool_double.h +++ /dev/null @@ -1,6398 +0,0 @@ -#if 0 -; -; Note: shader requires additional functionality: -; Double-precision floating point -; -; -; Input signature: -; -; Name Index Mask Register SysValue Format Used -; -------------------- ----- ------ -------- -------- ------- ------ -; no parameters -; -; Output signature: -; -; Name Index Mask Register SysValue Format Used -; -------------------- ----- ------ -------- -------- ------- ------ -; no parameters -; shader hash: eff86a5dd3f8ca652b3700c52570e535 -; -; Pipeline Runtime Information: -; -; -; -; Buffer Definitions: -; -; cbuffer -; { -; -; [116 x i8] (type annotation not present) -; -; } -; -; Resource bind info for -; { -; -; [4 x i8] (type annotation not present) -; -; } -; -; Resource bind info for -; { -; -; [8 x i8] (type annotation not present) -; -; } -; -; Resource bind info for -; { -; -; [4 x i8] (type annotation not present) -; -; } -; -; -; Resource Bindings: -; -; Name Type Format Dim ID HLSL Bind Count -; ------------------------------ ---------- ------- ----------- ------- -------------- ------ -; cbuffer NA NA CB0 cb0 1 -; UAV struct r/w U0 u0 1 -; UAV struct r/w U1 u1 1 -; UAV struct r/w U2 u2 1 -; -target datalayout = "e-m:e-p:32:32-i1:32-i8:8-i16:16-i32:32-i64:64-f16:16-f32:32-f64:64-n8:16:32:64" -target triple = "dxil-ms-dx" - -%dx.types.Handle = type { i8* } -%dx.types.CBufRet.i32 = type { i32, i32, i32, i32 } -%dx.types.ResRet.i32 = type { i32, i32, i32, i32, i32 } -%"class.RWStructuredBuffer" = type { i32 } -%"class.RWStructuredBuffer" = type { double } -%Constants = type { i32, i32, i32, i32, <4 x i32>, <4 x i32>, <4 x i32>, <4 x i32>, <4 x i32>, <4 x i32>, i32 } - -define void @GridSample() { - %1 = call %dx.types.Handle @dx.op.createHandle(i32 57, i8 1, i32 2, i32 2, i1 false) ; CreateHandle(resourceClass,rangeId,index,nonUniformIndex) - %2 = call %dx.types.Handle @dx.op.createHandle(i32 57, i8 1, i32 1, i32 1, i1 false) ; 
CreateHandle(resourceClass,rangeId,index,nonUniformIndex) - %3 = call %dx.types.Handle @dx.op.createHandle(i32 57, i8 1, i32 0, i32 0, i1 false) ; CreateHandle(resourceClass,rangeId,index,nonUniformIndex) - %4 = call %dx.types.Handle @dx.op.createHandle(i32 57, i8 2, i32 0, i32 0, i1 false) ; CreateHandle(resourceClass,rangeId,index,nonUniformIndex) - %5 = call i32 @dx.op.threadId.i32(i32 93, i32 0) ; ThreadId(component) - %6 = call %dx.types.CBufRet.i32 @dx.op.cbufferLoadLegacy.i32(i32 59, %dx.types.Handle %4, i32 0) ; CBufferLoadLegacy(handle,regIndex) - %7 = extractvalue %dx.types.CBufRet.i32 %6, 0 - %8 = add i32 %7, %5 - %9 = extractvalue %dx.types.CBufRet.i32 %6, 1 - %10 = icmp ult i32 %8, %9 - br i1 %10, label %11, label %3389 - -;