diff --git a/.github/workflows/gradle-wrapper-validation.yml b/.github/workflows/gradle-wrapper-validation.yml index 03ea773a25130..bc2d8117930bc 100644 --- a/.github/workflows/gradle-wrapper-validation.yml +++ b/.github/workflows/gradle-wrapper-validation.yml @@ -11,4 +11,4 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - uses: gradle/wrapper-validation-action@v1 + - uses: gradle/wrapper-validation-action@v2 diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml index ce8fb3160954e..a196226a4b836 100644 --- a/.github/workflows/labeler.yml +++ b/.github/workflows/labeler.yml @@ -3,11 +3,14 @@ on: issues: types: [opened, edited] +permissions: + issues: write + jobs: triage: runs-on: ubuntu-latest steps: - - uses: github/issue-labeler@v3.3 + - uses: github/issue-labeler@v3.4 with: repo-token: "${{ secrets.GITHUB_TOKEN }}" configuration-path: .github/labeler.yml diff --git a/.github/workflows/publish-csharp-apidocs.yml b/.github/workflows/publish-csharp-apidocs.yml index c03399f4693be..5bc21595bf882 100644 --- a/.github/workflows/publish-csharp-apidocs.yml +++ b/.github/workflows/publish-csharp-apidocs.yml @@ -37,7 +37,7 @@ jobs: wget https://github.com/dotnet/docfx/releases/download/v${DOCFXVERSION}/docfx-linux-x64-v${DOCFXVERSION}.zip -O build/docfx/docfx.zip unzip build/docfx/docfx.zip -d build/docfx - name: Install NuGet - uses: nuget/setup-nuget@v1 + uses: nuget/setup-nuget@v2 - name: Build Documentation run: | build/docfx/docfx metadata csharp/ApiDocs/docfx.json diff --git a/.github/workflows/publish-java-apidocs.yml b/.github/workflows/publish-java-apidocs.yml index 708842e59f9f2..3e553049a186e 100644 --- a/.github/workflows/publish-java-apidocs.yml +++ b/.github/workflows/publish-java-apidocs.yml @@ -30,7 +30,7 @@ jobs: java-version: '11' distribution: 'adopt' - name: Build with Gradle - uses: gradle/gradle-build-action@v2 + uses: gradle/gradle-build-action@v3 with: build-root-directory: java gradle-executable: 
java/gradlew diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index c94e3fa5bcb8c..181f3fb17d332 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -13,7 +13,7 @@ jobs: issues: write pull-requests: write steps: - - uses: actions/stale@v9.0.0 + - uses: actions/stale@v8 with: # Comma separated list of labels that can be assigned to issues to exclude them from being marked as stale exempt-issue-labels: contributions welcome, feature request, regression diff --git a/.lintrunner.toml b/.lintrunner.toml index 4e5d077b08ff4..be95e03479cf9 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -132,6 +132,7 @@ exclude_patterns = [ 'onnxruntime/core/flatbuffers/schema/*.fbs.h', # Generated code 'onnxruntime/core/graph/contrib_ops/quantization_defs.cc', 'onnxruntime/core/mlas/**', # Contains assembly code + 'onnxruntime/core/mickey/cutlass_ext/**', # CUTLASS lib recommends NO automatic code formatting 'winml/lib/Api.Image/shaders/**', # Contains data chunks ] command = [ diff --git a/.pipelines/OneBranch.Nuget-WindowsAI-Pipeline.Official.yml b/.pipelines/OneBranch.Nuget-WindowsAI-Pipeline.Official.yml index 67f9d8b0ce392..fd3b7266d30f7 100644 --- a/.pipelines/OneBranch.Nuget-WindowsAI-Pipeline.Official.yml +++ b/.pipelines/OneBranch.Nuget-WindowsAI-Pipeline.Official.yml @@ -29,6 +29,8 @@ extends: git: submodules: false globalSdl: # https://aka.ms/obpipelines/sdl + asyncSdl: + enabled: false tsa: enabled: true prefast: diff --git a/.pipelines/windowsai-steps.yml b/.pipelines/windowsai-steps.yml index 6e551d8187171..855573de753b0 100644 --- a/.pipelines/windowsai-steps.yml +++ b/.pipelines/windowsai-steps.yml @@ -80,11 +80,11 @@ jobs: # must call vsdevcmd first to add cmake to PATH - script: | - curl -O -L https://github.com/Kitware/CMake/releases/download/v3.26.3/cmake-3.26.3-windows-x86_64.zip - 7z x cmake-3.26.3-windows-x86_64.zip + curl -O -L 
https://github.com/Kitware/CMake/releases/download/v3.28.3/cmake-3.28.3-windows-x86_64.zip + 7z x cmake-3.28.3-windows-x86_64.zip set PYTHONHOME=$(Build.BinariesDirectory)\${{ parameters.PythonPackageName }}.3.9.7\tools set PYTHONPATH=$(Build.BinariesDirectory)\${{ parameters.PythonPackageName }}.3.9.7\tools - $(Build.BinariesDirectory)\${{ parameters.PythonPackageName }}.3.9.7\tools\python.exe "$(Build.SourcesDirectory)\tools\ci_build\build.py" --build_dir $(Build.BinariesDirectory) --build_shared_lib --enable_onnx_tests --ms_experimental --use_dml --use_winml --cmake_generator "Visual Studio 17 2022" --update --config RelWithDebInfo --enable_qspectre --enable_lto --use_telemetry --disable_rtti --enable_wcos $(BuildFlags) --cmake_extra_defines "CMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFO=/PROFILE" "CMAKE_SHARED_LINKER_FLAGS_RELWITHDEBINFO=/PROFILE" CMAKE_SYSTEM_VERSION=10.0.19041.0 --cmake_path $(Build.BinariesDirectory)\cmake-3.26.3-windows-x86_64\bin\cmake.exe --ctest_path $(Build.BinariesDirectory)\cmake-3.26.3-windows-x86_64\bin\ctest.exe + $(Build.BinariesDirectory)\${{ parameters.PythonPackageName }}.3.9.7\tools\python.exe "$(Build.SourcesDirectory)\tools\ci_build\build.py" --build_dir $(Build.BinariesDirectory) --parallel --use_binskim_compliant_compile_flags --build_shared_lib --enable_onnx_tests --ms_experimental --use_dml --use_winml --cmake_generator "Visual Studio 17 2022" --update --config RelWithDebInfo --enable_lto --use_telemetry --disable_rtti --enable_wcos --windows_sdk_version "10.0.22621.0" $(BuildFlags) --cmake_extra_defines "CMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFO=/PROFILE" "CMAKE_SHARED_LINKER_FLAGS_RELWITHDEBINFO=/PROFILE" --cmake_path $(Build.BinariesDirectory)\cmake-3.28.3-windows-x86_64\bin\cmake.exe --ctest_path $(Build.BinariesDirectory)\cmake-3.28.3-windows-x86_64\bin\ctest.exe workingDirectory: '$(Build.BinariesDirectory)' displayName: 'Generate cmake config' diff --git a/.vscode/settings.json b/.vscode/settings.json index 
3e2b1f31dd6cf..98d23090fd474 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -21,5 +21,8 @@ "cpplint.filters": [ "-build/include_subdir", "-runtime/references" - ] + ], + "files.associations": { + "span": "cpp" + } } diff --git a/CITATION.cff b/CITATION.cff index 82bcac5a7b750..10b7290022aef 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -3,8 +3,7 @@ title: ONNX Runtime message: "Please use this information to cite ONNX Runtime in research or other publications." authors: - - affiliation: Microsoft Corporation - given-names: ONNX Runtime developers + - name: ONNX Runtime developers date-released: 2018-11-29 url: "https://onnxruntime.ai" repository-code: "https://github.com/microsoft/onnxruntime" diff --git a/ThirdPartyNotices.txt b/ThirdPartyNotices.txt index 700206180decd..30894903ec8d2 100644 --- a/ThirdPartyNotices.txt +++ b/ThirdPartyNotices.txt @@ -6299,3 +6299,210 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +_____ + +neural-speed + +https://github.com/intel/neural-speed + + Apache License + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. 
You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. 
Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + ============================================================================ + + Copyright 2016-2019 Intel Corporation + Copyright 2018 YANDEX LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + This distribution includes third party software ("third party programs"). + This third party software, even if included with the distribution of + the Intel software, may be governed by separate license terms, including + without limitation, third party license terms, other Intel software license + terms, and open source software license terms. These separate license terms + govern your use of the third party programs as set forth in the + "THIRD-PARTY-PROGRAMS" file. diff --git a/cgmanifests/generate_cgmanifest.py b/cgmanifests/generate_cgmanifest.py index 81181d3ccfb20..3cecbb0cc977f 100644 --- a/cgmanifests/generate_cgmanifest.py +++ b/cgmanifests/generate_cgmanifest.py @@ -115,8 +115,8 @@ def normalize_path_separators(path): submodule_lines = proc.stdout.splitlines() for submodule_line in submodule_lines: (absolute_path, url, commit) = submodule_line.split(" ") - git_deps[GitDep(commit, url)] = "git submodule at {}".format( - normalize_path_separators(os.path.relpath(absolute_path, REPO_DIR)) + git_deps[GitDep(commit, url)] = ( + f"git submodule at {normalize_path_separators(os.path.relpath(absolute_path, REPO_DIR))}" ) with open(os.path.join(SCRIPT_DIR, "..", "cmake", "deps.txt")) as f: diff --git a/cgmanifests/generated/cgmanifest.json b/cgmanifests/generated/cgmanifest.json index 03e3f84547a68..dc7e9c3fddb2f 100644 --- a/cgmanifests/generated/cgmanifest.json +++ b/cgmanifests/generated/cgmanifest.json @@ -42,6 +42,16 @@ "comments": "abseil_cpp" } }, + { + "component": { + "type": "git", + "git": { + "commitHash": 
"dbb0094fd0cb936469e35320bf37e866ef7a1da4", + "repositoryUrl": "https://github.com/apple/coremltools.git" + }, + "comments": "coremltools" + } + }, { "component": { "type": "git", @@ -76,7 +86,7 @@ "component": { "type": "git", "git": { - "commitHash": "6df40a2471737b27271bdd9b900ab5f3aec746c7", + "commitHash": "0100f6a5779831fa7a651e4b67ef389a8752bd9b", "repositoryUrl": "https://github.com/google/flatbuffers.git" }, "comments": "flatbuffers" @@ -106,7 +116,7 @@ "component": { "type": "git", "git": { - "commitHash": "361e8d1cfe0c6c36d30b39f1b61302ece5507320", + "commitHash": "344117638c8ff7e239044fd0fa7085839fc03021", "repositoryUrl": "https://github.com/google/benchmark.git" }, "comments": "google_benchmark" diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 0eb224623f678..02b568abdf8da 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -88,7 +88,7 @@ option(onnxruntime_USE_QNN "Build with QNN support" OFF) option(onnxruntime_USE_SNPE "Build with SNPE support" OFF) option(onnxruntime_USE_RKNPU "Build with RKNPU support" OFF) option(onnxruntime_USE_DNNL "Build with DNNL support" OFF) -option(onnxruntime_USE_NEURAL_SPEED "Build with Neural Speed support" ON) +option(onnxruntime_USE_NEURAL_SPEED "Build with Neural Speed support" OFF) option(onnxruntime_USE_JSEP "Build with JavaScript implemented kernels support" OFF) option(onnxruntime_BUILD_UNIT_TESTS "Build ONNXRuntime unit tests" ON) option(onnxruntime_BUILD_CSHARP "Build C# library" OFF) @@ -117,9 +117,7 @@ option(onnxruntime_CROSS_COMPILING "Cross compiling onnx runtime" OFF) option(onnxruntime_GCOV_COVERAGE "Compile with options necessary to run code coverage" OFF) option(onnxruntime_DONT_VECTORIZE "Do not vectorize operations in Eigen" OFF) -#It's preferred to turn it OFF when onnxruntime is dynamically linked to PROTOBUF. But Tensort always required the full version of protobuf. 
-cmake_dependent_option(onnxruntime_USE_FULL_PROTOBUF "Link to libprotobuf instead of libprotobuf-lite when this option is ON" OFF "NOT onnxruntime_USE_TENSORRT" ON) -option(tensorflow_C_PACKAGE_PATH "Path to tensorflow C package installation dir") +option(onnxruntime_USE_FULL_PROTOBUF "Link to libprotobuf instead of libprotobuf-lite when this option is ON" OFF) option(onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS "Enable operator implemented in language other than cpp" OFF) option(onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS "Dump debug information about node inputs and outputs when executing the model." OFF) cmake_dependent_option(onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS_ENABLE_DUMP_TO_SQLDB "Build dump debug information about node inputs and outputs with support for sql database." OFF "onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS" OFF) @@ -325,15 +323,27 @@ if (onnxruntime_USE_ROCM) endif() # replicate strategy used by pytorch to get ROCM_VERSION - # https://github.com/pytorch/pytorch/blob/8eb21488fdcdb8b0e6fa2e46179b5fa6c42e75af/cmake/public/LoadHIP.cmake#L153-L173 - file(READ "${onnxruntime_ROCM_HOME}/.info/version-dev" ROCM_VERSION_DEV_RAW) - string(REGEX MATCH "^([0-9]+)\.([0-9]+)\.([0-9]+)-.*$" ROCM_VERSION_DEV_MATCH ${ROCM_VERSION_DEV_RAW}) - if (ROCM_VERSION_DEV_MATCH) + # https://github.com/pytorch/pytorch/blob/5c5b71b6eebae76d744261715231093e62f0d090/cmake/public/LoadHIP.cmake + # with modification + if (EXISTS "${onnxruntime_ROCM_HOME}/.info/version-dev") + file(READ "${onnxruntime_ROCM_HOME}/.info/version-dev" ROCM_VERSION_DEV_RAW) + string(REGEX MATCH "^([0-9]+)\.([0-9]+)\.([0-9]+)-.*$" ROCM_VERSION_MATCH ${ROCM_VERSION_DEV_RAW}) + elseif (EXISTS "${onnxruntime_ROCM_HOME}/include/rocm_version.h") + file(READ "${onnxruntime_ROCM_HOME}/include/rocm_version.h" ROCM_VERSION_H_RAW) + string(REGEX MATCH "\"([0-9]+)\.([0-9]+)\.([0-9]+).*\"" ROCM_VERSION_MATCH ${ROCM_VERSION_H_RAW}) + elseif (EXISTS "${onnxruntime_ROCM_HOME}/include/rocm-core/rocm_version.h") + file(READ 
"${onnxruntime_ROCM_HOME}/include/rocm-core/rocm_version.h" ROCM_VERSION_H_RAW) + string(REGEX MATCH "\"([0-9]+)\.([0-9]+)\.([0-9]+).*\"" ROCM_VERSION_MATCH ${ROCM_VERSION_H_RAW}) + endif() + + if (ROCM_VERSION_MATCH) set(ROCM_VERSION_DEV_MAJOR ${CMAKE_MATCH_1}) set(ROCM_VERSION_DEV_MINOR ${CMAKE_MATCH_2}) set(ROCM_VERSION_DEV_PATCH ${CMAKE_MATCH_3}) set(ROCM_VERSION_DEV "${ROCM_VERSION_DEV_MAJOR}.${ROCM_VERSION_DEV_MINOR}.${ROCM_VERSION_DEV_PATCH}") math(EXPR ROCM_VERSION_DEV_INT "(${ROCM_VERSION_DEV_MAJOR}*10000) + (${ROCM_VERSION_DEV_MINOR}*100) + ${ROCM_VERSION_DEV_PATCH}") + else() + message(FATAL_ERROR "Cannot determine ROCm version string") endif() message("\n***** ROCm version from ${onnxruntime_ROCM_HOME}/.info/version-dev ****\n") message("ROCM_VERSION_DEV: ${ROCM_VERSION_DEV}") @@ -641,6 +651,7 @@ else() check_cxx_compiler_flag(-Wunused-but-set-variable HAS_UNUSED_BUT_SET_VARIABLE) check_cxx_compiler_flag(-Wunused-variable HAS_UNUSED_VARIABLE) check_cxx_compiler_flag(-Wuseless-cast HAS_USELESS_CAST) + check_cxx_compiler_flag(-Wstringop-overflow HAS_STRINGOP_OVERFLOW) check_function_exists(reallocarray HAS_REALLOCARRAY) if (NOT APPLE AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND onnxruntime_target_platform STREQUAL "aarch64") check_cxx_compiler_flag(-march=armv8.2-a+bf16 HAS_ARM64_BFLOAT16) @@ -715,6 +726,9 @@ if (onnxruntime_USE_CUDA) set(onnxruntime_USE_FLASH_ATTENTION OFF) set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF) endif() + if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.4) + message( FATAL_ERROR "Failed build due to CUDA compiler version < 11.4") + endif() else() set(onnxruntime_USE_FLASH_ATTENTION OFF) set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF) @@ -735,8 +749,8 @@ if (onnxruntime_USE_CUDA) list(APPEND ORT_PROVIDER_FLAGS -DUSE_MEMORY_EFFICIENT_ATTENTION=1) list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_MEMORY_EFFICIENT_ATTENTION=1) endif() - endif() + if (onnxruntime_USE_VITISAI) list(APPEND ORT_PROVIDER_FLAGS 
-DUSE_VITISAI=1) list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_VITISAI=1) @@ -984,9 +998,12 @@ function(onnxruntime_set_compile_flags target_name) foreach(FLAG ${ORT_WARNING_FLAGS}) target_compile_options(${target_name} PRIVATE "$<$:SHELL:--compiler-options ${FLAG}>") endforeach() - if ((NVCC_HAS_STRICT_ALIASING AND "${target_name}" MATCHES "cuda") OR (HAS_STRICT_ALIASING AND NOT "${target_name}" MATCHES "cuda")) + if (NVCC_HAS_STRICT_ALIASING AND "${target_name}" MATCHES "cuda") target_compile_options(${target_name} PRIVATE "$<$:-Wno-strict-aliasing>") endif() + if (HAS_STRICT_ALIASING AND NOT "${target_name}" MATCHES "cuda") + target_compile_options(${target_name} PRIVATE "$<$:-Wno-strict-aliasing>") + endif() endif() if (onnxruntime_USE_ROCM) # flags are detected with CXX language mode, some flags are not supported with hipclang @@ -1236,17 +1253,15 @@ if (onnxruntime_USE_TVM) $) set(onnxruntime_tvm_libs onnxruntime_providers_tvm) - - # needs to link with stdc++fs in Linux - if (UNIX) - if (NOT APPLE) - set(FS_STDLIB stdc++fs) - endif() - endif() - list(APPEND onnxruntime_EXTERNAL_LIBRARIES tvm ${FS_STDLIB}) + list(APPEND onnxruntime_EXTERNAL_LIBRARIES tvm) list(APPEND onnxruntime_EXTERNAL_DEPENDENCIES tvm) endif() +# needs to link with stdc++fs in Linux +if (UNIX AND "${CMAKE_C_COMPILER_ID}" STREQUAL "GNU" AND CMAKE_C_COMPILER_VERSION VERSION_LESS 9) + set(FS_STDLIB stdc++fs) +endif() +list(APPEND onnxruntime_EXTERNAL_LIBRARIES ${FS_STDLIB}) # onnxruntime-extensions if (onnxruntime_USE_EXTENSIONS) @@ -1256,11 +1271,7 @@ endif() #Dependencies end. 
In the next we'll enable "treat warning as error" #Adjust warning flags -if (onnxruntime_USE_CUDA) - set_msvc_c_cpp_compiler_warning_level(3) -else() - set_msvc_c_cpp_compiler_warning_level(4) -endif() +set_msvc_c_cpp_compiler_warning_level(4) set(onnxruntime_DELAYLOAD_FLAGS "") @@ -1397,6 +1408,10 @@ endif() if (onnxruntime_USE_CUDA) set(CMAKE_CUDA_RUNTIME_LIBRARY Shared) set(CMAKE_CUDA_STANDARD 17) + if(onnxruntime_CUDA_HOME) + file(TO_CMAKE_PATH ${onnxruntime_CUDA_HOME} CUDAToolkit_ROOT) + endif() + find_package(CUDAToolkit REQUIRED) if(onnxruntime_CUDNN_HOME) file(TO_CMAKE_PATH ${onnxruntime_CUDNN_HOME} onnxruntime_CUDNN_HOME) endif() @@ -1597,7 +1612,7 @@ if (UNIX AND onnxruntime_USE_NCCL) else() set(onnxruntime_USE_NCCL OFF) set(onnxruntime_USE_MPI OFF) -message( WARNING "MPI and NCCL disabled on Win build." ) + message( WARNING "MPI and NCCL are disabled because build is on Windows or USE_NCCL is set to OFF." ) endif() if (onnxruntime_USE_MPI) @@ -1726,14 +1741,12 @@ if(onnxruntime_BUILD_KERNEL_EXPLORER) endif() # When GDK_PLATFORM is set then WINAPI_FAMILY is defined in gdk_toolchain.cmake (along with other relevant flags/definitions). -if (WIN32 AND NOT GDK_PLATFORM) +if (WIN32 AND NOT GDK_PLATFORM AND NOT CMAKE_CROSSCOMPILING) if (NOT CMAKE_CXX_STANDARD_LIBRARIES MATCHES kernel32.lib) # On onecore, link to the onecore build of the MSVC runtime get_filename_component(msvc_path "${CMAKE_C_COMPILER}/../../../.." ABSOLUTE) link_directories(BEFORE "${msvc_path}/lib/onecore/${onnxruntime_target_platform}") - # The .lib files in the MSVC runtime have a DEFAULITLIB entry for onecore.lib, which in turn links to reverse forwarders. - # We ignore that entry and use onecore_apiset.lib instead, since system components must not rely on reverse forwarders.
- add_link_options("/NODEFAULTLIB:onecore.lib") + # The .lib files in the MSVC runtime have a DEFAULTLIB entry for onecore.lib, but it should not cause any conflict with onecoreuap.lib endif() endif() diff --git a/cmake/adjust_global_compile_flags.cmake b/cmake/adjust_global_compile_flags.cmake index 2c7bf9f1c2f5c..d3f9256105127 100644 --- a/cmake/adjust_global_compile_flags.cmake +++ b/cmake/adjust_global_compile_flags.cmake @@ -92,7 +92,7 @@ if (onnxruntime_MINIMAL_BUILD) endif() endif() -# enable stream for all the non-minimal build +# Enable stream for all the non-minimal build if (NOT onnxruntime_MINIMAL_BUILD) add_compile_definitions(ORT_ENABLE_STREAM) endif() @@ -205,7 +205,7 @@ endif() macro(check_nvcc_compiler_flag _FLAG _RESULT) - execute_process(COMMAND ${onnxruntime_CUDA_HOME}/bin/nvcc "${_FLAG}" RESULT_VARIABLE NVCC_OUT ERROR_VARIABLE NVCC_ERROR) + execute_process(COMMAND ${CUDAToolkit_BIN_DIR}/nvcc "${_FLAG}" RESULT_VARIABLE NVCC_OUT ERROR_VARIABLE NVCC_ERROR) message("NVCC_ERROR = ${NVCC_ERROR}") message("NVCC_OUT = ${NVCC_OUT}") if ("${NVCC_OUT}" MATCHES "0") diff --git a/cmake/deps.txt b/cmake/deps.txt index ba9c2bb73cf7a..4111689c5def9 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -13,6 +13,7 @@ # See https://microsoft.sharepoint.com/teams/ONNX2/_layouts/OneNote.aspx?id=%2Fteams%2FONNX2%2FShared%20Documents%2FNotebooks%2FONNX%20Ecosystem%20Team%20Notebook&wd=target%28Development.one%7C63D3AB47-51D1-4A62-9965-66882234BD44%2FAdd%20or%20update%20a%20dependency%20in%20deps.txt%7C0E9ED71D-89D5-40FA-B05F-C0123289C591%2F%29 # abseil_cpp;https://github.com/abseil/abseil-cpp/archive/refs/tags/20240116.0.zip;bc2cec6baaad67fcb6c0c38972b687d4797927e9 +coremltools;https://github.com/apple/coremltools/archive/refs/tags/7.1.zip;f1bab0f30966f2e217d8e01207d518f230a1641a cxxopts;https://github.com/jarro2783/cxxopts/archive/3c73d91c0b04e2b59462f0a741be8c07024c1bc0.zip;6c6ca7f8480b26c8d00476e0e24b7184717fe4f0
date;https://github.com/HowardHinnant/date/archive/refs/tags/v3.0.1.zip;2dac0c81dc54ebdd8f8d073a75c053b04b56e159 dlpack;https://github.com/dmlc/dlpack/archive/refs/tags/v0.6.zip;4d565dd2e5b31321e5549591d78aa7f377173445 @@ -22,10 +23,10 @@ dlpack;https://github.com/dmlc/dlpack/archive/refs/tags/v0.6.zip;4d565dd2e5b3132 # Until the 3.4.1 release this is the best option we have. # Issue link: https://gitlab.com/libeigen/eigen/-/issues/2744 eigen;https://gitlab.com/libeigen/eigen/-/archive/e7248b26a1ed53fa030c5c459f7ea095dfd276ac/eigen-e7248b26a1ed53fa030c5c459f7ea095dfd276ac.zip;be8be39fdbc6e60e94fa7870b280707069b5b81a -flatbuffers;https://github.com/google/flatbuffers/archive/refs/tags/v1.12.0.zip;ba0a75fd12dbef8f6557a74e611b7a3d0c5fe7bf +flatbuffers;https://github.com/google/flatbuffers/archive/refs/tags/v23.5.26.zip;59422c3b5e573dd192fead2834d25951f1c1670c fp16;https://github.com/Maratyszcza/FP16/archive/0a92994d729ff76a58f692d3028ca1b64b145d91.zip;b985f6985a05a1c03ff1bb71190f66d8f98a1494 fxdiv;https://github.com/Maratyszcza/FXdiv/archive/63058eff77e11aa15bf531df5dd34395ec3017c8.zip;a5658f4036402dbca7cebee32be57fb8149811e1 -google_benchmark;https://github.com/google/benchmark/archive/refs/tags/v1.7.0.zip;e97c368b176e8614e3f1bf13dd9abcf6a7ad9908 +google_benchmark;https://github.com/google/benchmark/archive/refs/tags/v1.8.3.zip;bf9870756ee3f8d2d3b346b24ee3600a41c74d3d google_nsync;https://github.com/google/nsync/archive/refs/tags/1.26.0.zip;5e7c00ef6bf5b787386fc040067903ec774e2752 googletest;https://github.com/google/googletest/archive/530d5c8c84abd2a46f38583ee817743c9b3a42b4.zip;5e3a61db2aa975cfd0f97ba92c818744e7fa7034 googlexnnpack;https://github.com/google/XNNPACK/archive/0da379fc4808f9601faef392352018c741c0f297.zip;663883491e380b628e0a5b162b5f2658032fae73 @@ -36,8 +37,8 @@ mimalloc;https://github.com/microsoft/mimalloc/archive/refs/tags/v2.1.1.zip;d5ee 
mp11;https://github.com/boostorg/mp11/archive/refs/tags/boost-1.82.0.zip;9bc9e01dffb64d9e0773b2e44d2f22c51aace063 neural_speed;https://github.com/intel/neural-speed/archive/refs/tags/bestlav0.1.1.zip;65b0f7a0d04f72f0d5a8d48af70f0366f2ab3939 onnx;https://github.com/onnx/onnx/archive/refs/tags/v1.15.0.zip;54c3f960a0541c5d8d3e60c2933e11f5d3688a11 -#use the commit of supporting all the plugins and TRT 8.6-GA (https://github.com/onnx/onnx-tensorrt/commit/0462dc31ae78f48744b6141ae376df1f96d3f459) -onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/a43ce67187bab219520fd80f21af8bbd4354bc8c.zip;572535aefef477050f86744dfab1fef840198035 +#use the commit of Final DDS removal. DDS output is now supported by ORT TRT. +onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/bacfaaa951653cd4e72efe727a543567cb38f7de.zip;26434329612e804164ab7baa6ae629ada56c1b26 protobuf;https://github.com/protocolbuffers/protobuf/archive/refs/tags/v21.12.zip;7cf2733949036c7d52fda017badcab093fe73bfa protoc_win64;https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-win64.zip;b4521f7ada5b260380f94c4bd7f1b7684c76969a protoc_win32;https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-win32.zip;3688010318192c46ce73213cdfb6b3e5656da874 @@ -55,4 +56,4 @@ tensorboard;https://github.com/tensorflow/tensorboard/archive/373eb09e4c5d2b3cc2 cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.1.0.zip;757f90a795034a89d4f48a79d1f009f7a04c8dee utf8_range;https://github.com/protocolbuffers/utf8_range/archive/72c943dea2b9240cd09efde15191e144bc7c7d38.zip;9925739c9debc0efa2adcb194d371a35b6a03156 extensions;https://github.com/microsoft/onnxruntime-extensions/archive/94142d8391c9791ec71c38336436319a2d4ac7a0.zip;4365ac5140338b4cb75a39944a4be276e3829b3c -composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/5356c4a943a35e74d7cdc69486afcb8703b9a59a.zip;522382c2af437e09124287e5879ab64af5b2e299 \ No newline at end of 
file +composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/5356c4a943a35e74d7cdc69486afcb8703b9a59a.zip;522382c2af437e09124287e5879ab64af5b2e299 diff --git a/cmake/deps_update_and_upload.py b/cmake/deps_update_and_upload.py index d357284d91225..63df3f6f03869 100644 --- a/cmake/deps_update_and_upload.py +++ b/cmake/deps_update_and_upload.py @@ -1,56 +1,109 @@ -# in case deps.txt is updated, run this file to update and upload the dependencies so that CI can use them. -# Before running the script, increase the version number found at: +# If deps.txt is updated, run this file to update and upload the dependencies so that CI can use them. +# +# Before running the script, find the latest version number at: # https://aiinfra.visualstudio.com/Lotus/_artifacts/feed/Lotus/UPack/onnxruntime_build_dependencies/versions +# Increment it to obtain a new version number to use. +# # Run without --do-upload once to verify downloading. Use --do-upload when you are ready to publish. -# python cmake/deps_update_and_upload.py --root-path C:/temp/onnxruntime_deps --version 1.0.82 --do-upload -# update version number in tools\ci_build\github\azure-pipelines\templates\download-deps.yml +# E.g.: +# python cmake/deps_update_and_upload.py --root-path C:/temp/onnxruntime_deps --version 1.0.82 +# # check contents of C:/temp/onnxruntime_deps +# python cmake/deps_update_and_upload.py --root-path C:/temp/onnxruntime_deps --version 1.0.82 --no-download --do-upload +# +# Next, update the version number in tools/ci_build/github/azure-pipelines/templates/download-deps.yml. 
+ +import argparse +import contextlib +import pathlib import re import subprocess -import os -import argparse import tempfile +script_dir = pathlib.Path(__file__).parent + parser = argparse.ArgumentParser(description="Update dependencies and publish to Azure Artifacts") parser.add_argument( - "--root-path", type=str, default=tempfile.gettempdir(), help="Target root path for downloaded files" + "--root-path", + type=pathlib.Path, + help="Target root path for downloaded files. If not provided, a temporary directory is used.", +) +parser.add_argument( + "--version", + type=str, + help="Package version to publish", +) +parser.add_argument( + "--do-upload", + action="store_true", + dest="upload", + help="Upload the package to Azure Artifacts", +) +parser.add_argument( + "--no-download", + action="store_false", + dest="download", + help="Skip downloading the dependency files. " + "Use with '--do-upload' and '--root-path' to upload the package from existing dependency files.", ) -parser.add_argument("--version", type=str, default="1.0.82", help="Package version to publish") -parser.add_argument("--do-upload", action="store_true", help="Upload the package to Azure Artifacts") args = parser.parse_args() -with open("cmake/deps.txt") as file: +if args.upload: + assert args.version is not None, "'--version' must be specified if uploading." + +if args.upload != args.download: + assert args.root_path is not None, "'--root-path' must be specified if only downloading or uploading." 
+ +deps_path = script_dir / "deps.txt" +with open(deps_path) as file: text = file.read() lines = [line for line in text.split("\n") if not line.startswith("#") and ";" in line] -root_path = args.root_path - -for line in lines: - url = re.sub("^[^;]+?;https://([^;]+?);.*", r"https://\1", line) - filename = re.sub("^[^;]+?;https://([^;]+?);.*", r"\1", line) - full_path = os.path.join(root_path, filename) - subprocess.run(["curl", "-sSL", "--create-dirs", "-o", full_path, url]) # noqa: PLW1510 - -package_name = "onnxruntime_build_dependencies" -version = args.version - -# Check if the user is logged in to Azure -result = subprocess.run("az account show", shell=True, capture_output=True, text=True) # noqa: PLW1510 -if "No subscriptions found" in result.stderr: - # Prompt the user to log in to Azure - print("You are not logged in to Azure. Please log in to continue.") - subprocess.run("az login", shell=True) # noqa: PLW1510 - -# Publish the package to Azure Artifacts if --no-upload is not specified - -cmd = f'az artifacts universal publish --organization https://dev.azure.com/onnxruntime --feed onnxruntime --name {package_name} --version {version} --description "onnxruntime build time dependencies" --path {root_path}' -if args.do_upload: - subprocess.run(cmd, shell=True) # noqa: PLW1510 -else: - print("would have run: " + cmd) - -cmd = f'az artifacts universal publish --organization https://dev.azure.com/aiinfra --feed Lotus --name {package_name} --version {version} --description "onnxruntime build time dependencies" --path {root_path}' -if args.do_upload: - subprocess.run(cmd, shell=True) # noqa: PLW1510 -else: - print("would have run: " + cmd) +with contextlib.ExitStack() as context_stack: + if args.root_path is not None: + root_path = args.root_path.resolve() + root_path.mkdir(parents=True, exist_ok=True) + else: + temp_dir_name = context_stack.enter_context(tempfile.TemporaryDirectory()) + root_path = pathlib.Path(temp_dir_name) + + if args.download: + 
print(f"Downloading dependencies to directory: {root_path}") + + dep_pattern = re.compile(r"^[^;]+;https://([^;]+);.*$") + + for line in lines: + match = dep_pattern.fullmatch(line) + if match is None: + continue + + dep_path = match[1] + url = f"https://{dep_path}" + full_path = root_path / dep_path + + subprocess.run(["curl", "-sSL", "--create-dirs", "-o", str(full_path), url], check=True) + + package_name = "onnxruntime_build_dependencies" + version = args.version if args.version is not None else "VERSION_PLACEHOLDER" + + if args.upload: + # Check if the user is logged in to Azure + result = subprocess.run("az account show", shell=True, capture_output=True, text=True, check=False) + if "No subscriptions found" in result.stderr: + # Prompt the user to log in to Azure + print("You are not logged in to Azure. Please log in to continue.") + subprocess.run("az login", shell=True, check=True) + + # Publish the package to Azure Artifacts if --do-upload is specified + + cmd = f'az artifacts universal publish --organization https://dev.azure.com/onnxruntime --feed onnxruntime --name {package_name} --version {version} --description "onnxruntime build time dependencies" --path {root_path}' + if args.upload: + subprocess.run(cmd, shell=True, check=True) + else: + print("would have run: " + cmd) + + cmd = f'az artifacts universal publish --organization https://dev.azure.com/aiinfra --feed Lotus --name {package_name} --version {version} --description "onnxruntime build time dependencies" --path {root_path}' + if args.upload: + subprocess.run(cmd, shell=True, check=True) + else: + print("would have run: " + cmd) diff --git a/cmake/external/dnnl.cmake b/cmake/external/dnnl.cmake index d7b70640781d0..9eb5fed7a1af6 100644 --- a/cmake/external/dnnl.cmake +++ b/cmake/external/dnnl.cmake @@ -2,7 +2,7 @@ include (ExternalProject) set(DNNL_URL https://github.com/oneapi-src/onednn.git) # If DNNL_TAG is updated, check if MKLML_VERSION and platform.cmake.patch need to be updated. 
-set(DNNL_TAG v3.0) +set(DNNL_TAG v3.0.1) if(WIN32) set(DNNL_SHARED_LIB dnnl.dll) diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index 403b4b2c4107a..ac1e187f357aa 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -37,8 +37,13 @@ if (onnxruntime_BUILD_UNIT_TESTS) set(gtest_disable_pthreads ON) endif() set(INSTALL_GTEST OFF CACHE BOOL "" FORCE) - if (CMAKE_SYSTEM_NAME STREQUAL "iOS") - # Needs to update onnxruntime/test/xctest/xcgtest.mm + if (IOS OR ANDROID) + # on mobile platforms the absl flags class discards the flag names (presumably to reduce binary size), which breaks passing + # any args to gtest executables, such as using --gtest_filter to debug a specific test. + # Processing of compile definitions: + # https://github.com/abseil/abseil-cpp/blob/8dc90ff07402cd027daec520bb77f46e51855889/absl/flags/config.h#L21 + # If set, this code throws away the flag and does nothing on registration, which results in no flags being known: + # https://github.com/abseil/abseil-cpp/blob/8dc90ff07402cd027daec520bb77f46e51855889/absl/flags/flag.h#L205-L217 set(GTEST_HAS_ABSL OFF CACHE BOOL "" FORCE) else() set(GTEST_HAS_ABSL ON CACHE BOOL "" FORCE) @@ -104,7 +109,7 @@ FetchContent_Declare( URL ${DEP_URL_flatbuffers} URL_HASH SHA1=${DEP_SHA1_flatbuffers} PATCH_COMMAND ${ONNXRUNTIME_FLATBUFFERS_PATCH_COMMAND} - FIND_PACKAGE_ARGS 1.12.0...<2.0.0 NAMES Flatbuffers + FIND_PACKAGE_ARGS 23.5.9 NAMES Flatbuffers ) # Download a protoc binary from Internet if needed @@ -224,8 +229,6 @@ FetchContent_Declare( URL_HASH SHA1=${DEP_SHA1_mp11} ) -set(JSON_BuildTests OFF CACHE INTERNAL "") -set(JSON_Install OFF CACHE INTERNAL "") set(JSON_BuildTests OFF CACHE INTERNAL "") set(JSON_Install OFF CACHE INTERNAL "") @@ -258,14 +261,7 @@ if (onnxruntime_ENABLE_CPUINFO) set(CPUINFO_SUPPORTED TRUE) endif() if (WIN32) - # Exclude Windows ARM build and Windows Store - if
(${onnxruntime_target_platform} MATCHES "^(ARM.*|arm.*)$" ) - message(WARNING "Cpuinfo not included for compilation problems with Windows ARM.") - set(CPUINFO_SUPPORTED FALSE) - elseif (WIN32 AND NOT CMAKE_CXX_STANDARD_LIBRARIES MATCHES kernel32.lib) - message(WARNING "Cpuinfo not included non-Desktop builds") - set(CPUINFO_SUPPORTED FALSE) - endif() + set(CPUINFO_SUPPORTED TRUE) elseif (NOT ${onnxruntime_target_platform} MATCHES "^(i[3-6]86|AMD64|x86(_64)?|armv[5-8].*|aarch64|arm64)$") message(WARNING "Target processor architecture \"${onnxruntime_target_platform}\" is not supported in cpuinfo. " @@ -309,13 +305,23 @@ if (CPUINFO_SUPPORTED) set(CPUINFO_BUILD_UNIT_TESTS OFF CACHE INTERNAL "") set(CPUINFO_BUILD_MOCK_TESTS OFF CACHE INTERNAL "") set(CPUINFO_BUILD_BENCHMARKS OFF CACHE INTERNAL "") - - FetchContent_Declare( - pytorch_cpuinfo - URL ${DEP_URL_pytorch_cpuinfo} - URL_HASH SHA1=${DEP_SHA1_pytorch_cpuinfo} - FIND_PACKAGE_ARGS NAMES cpuinfo - ) + if(onnxruntime_target_platform STREQUAL "ARM64EC") + message("Applying a patch for Windows ARM64EC in cpuinfo") + FetchContent_Declare( + pytorch_cpuinfo + URL ${DEP_URL_pytorch_cpuinfo} + URL_HASH SHA1=${DEP_SHA1_pytorch_cpuinfo} + PATCH_COMMAND ${Patch_EXECUTABLE} -p1 < ${PROJECT_SOURCE_DIR}/patches/cpuinfo/9bb12d342fd9479679d505d93a478a6f9cd50a47.patch + FIND_PACKAGE_ARGS NAMES cpuinfo + ) + else() + FetchContent_Declare( + pytorch_cpuinfo + URL ${DEP_URL_pytorch_cpuinfo} + URL_HASH SHA1=${DEP_SHA1_pytorch_cpuinfo} + FIND_PACKAGE_ARGS NAMES cpuinfo + ) + endif() set(ONNXRUNTIME_CPUINFO_PROJ pytorch_cpuinfo) endif() @@ -541,22 +547,32 @@ if(onnxruntime_ENABLE_TRAINING OR (onnxruntime_ENABLE_TRAINING_APIS AND onnxrunt onnxruntime_fetchcontent_makeavailable(cxxopts) endif() +if (onnxruntime_USE_COREML) + FetchContent_Declare( + coremltools + URL ${DEP_URL_coremltools} + URL_HASH SHA1=${DEP_SHA1_coremltools} + PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < 
${PROJECT_SOURCE_DIR}/patches/coremltools/crossplatformbuild.patch + ) + # we don't build directly so use Populate. selected files are built from onnxruntime_providers_coreml.cmake + FetchContent_Populate(coremltools) +endif() + message("Finished fetching external dependencies") set(onnxruntime_LINK_DIRS ) if (onnxruntime_USE_CUDA) #TODO: combine onnxruntime_CUDNN_HOME and onnxruntime_CUDA_HOME, assume they are the same + find_package(CUDAToolkit REQUIRED) if (WIN32) if(onnxruntime_CUDNN_HOME) list(APPEND onnxruntime_LINK_DIRS ${onnxruntime_CUDNN_HOME}/lib ${onnxruntime_CUDNN_HOME}/lib/x64) endif() - list(APPEND onnxruntime_LINK_DIRS ${onnxruntime_CUDA_HOME}/x64/lib64) else() if(onnxruntime_CUDNN_HOME) list(APPEND onnxruntime_LINK_DIRS ${onnxruntime_CUDNN_HOME}/lib ${onnxruntime_CUDNN_HOME}/lib64) endif() - list(APPEND onnxruntime_LINK_DIRS ${onnxruntime_CUDA_HOME}/lib64) endif() endif() diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake index c900f4d4b09a5..2ead13e554197 100644 --- a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ -189,7 +189,6 @@ set(onnxruntime_INTERNAL_LIBRARIES ${PROVIDERS_SNPE} ${PROVIDERS_TVM} ${PROVIDERS_RKNPU} - ${PROVIDERS_VITISAI} ${PROVIDERS_XNNPACK} ${PROVIDERS_WEBNN} ${PROVIDERS_AZURE} diff --git a/cmake/onnxruntime_common.cmake b/cmake/onnxruntime_common.cmake index 6b8c2560b1714..fb56e3f3445d4 100644 --- a/cmake/onnxruntime_common.cmake +++ b/cmake/onnxruntime_common.cmake @@ -201,10 +201,6 @@ endif() if (RISCV64 OR ARM64 OR ARM OR X86 OR X64 OR X86_64) - if((WIN32 AND NOT CMAKE_CXX_STANDARD_LIBRARIES MATCHES kernel32.lib) OR ((ARM64 OR ARM) AND MSVC)) - # msvc compiler report syntax error with cpuinfo arm source files - # and cpuinfo does not have code for getting arm uarch info under windows - else() # Link cpuinfo if supported # Using it mainly in ARM with Android. # Its functionality in detecting x86 cpu features are lacking, so is support for Windows. 
@@ -212,7 +208,6 @@ if (RISCV64 OR ARM64 OR ARM OR X86 OR X64 OR X86_64) onnxruntime_add_include_to_target(onnxruntime_common cpuinfo::cpuinfo) list(APPEND onnxruntime_EXTERNAL_LIBRARIES cpuinfo::cpuinfo ${ONNXRUNTIME_CLOG_TARGET_NAME}) endif() - endif() endif() if (NOT onnxruntime_BUILD_SHARED_LIB) diff --git a/cmake/onnxruntime_graph.cmake b/cmake/onnxruntime_graph.cmake index 3f532ec2c3261..4d51325b8414e 100644 --- a/cmake/onnxruntime_graph.cmake +++ b/cmake/onnxruntime_graph.cmake @@ -7,8 +7,26 @@ file(GLOB_RECURSE onnxruntime_graph_src CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/core/graph/*.cc" ) -# create empty list for any excludes +# start with empty training srcs list +set(orttraining_graph_src) + +if (onnxruntime_ENABLE_TRAINING_OPS AND NOT onnxruntime_ENABLE_TRAINING) + set(orttraining_graph_src + "${ORTTRAINING_SOURCE_DIR}/core/graph/training_op_defs.cc" + "${ORTTRAINING_SOURCE_DIR}/core/graph/training_op_defs.h" + ) +endif() + +if (onnxruntime_ENABLE_TRAINING) + file(GLOB_RECURSE orttraining_graph_src CONFIGURE_DEPENDS + "${ORTTRAINING_SOURCE_DIR}/core/graph/*.h" + "${ORTTRAINING_SOURCE_DIR}/core/graph/*.cc" + ) +endif() + +# create empty lists for any excludes set(onnxruntime_graph_src_exclude_patterns) +set(orttraining_graph_src_exclude_patterns) if (onnxruntime_MINIMAL_BUILD) # remove schema registration support @@ -22,11 +40,18 @@ if (onnxruntime_MINIMAL_BUILD) "${ONNXRUNTIME_ROOT}/core/graph/contrib_ops/onnx_function_util.cc" "${ONNXRUNTIME_ROOT}/core/graph/contrib_ops/shape_inference_functions.h" "${ONNXRUNTIME_ROOT}/core/graph/contrib_ops/shape_inference_functions.cc" + "${ONNXRUNTIME_ROOT}/core/graph/dml_ops/dml_defs.h" + "${ONNXRUNTIME_ROOT}/core/graph/dml_ops/dml_defs.cc" "${ONNXRUNTIME_ROOT}/core/graph/function_template.h" "${ONNXRUNTIME_ROOT}/core/graph/function_utils.h" "${ONNXRUNTIME_ROOT}/core/graph/function_utils.cc" ) + list(APPEND orttraining_graph_src_exclude_patterns + "${ORTTRAINING_SOURCE_DIR}/core/graph/training_op_defs.h" + 
"${ORTTRAINING_SOURCE_DIR}/core/graph/training_op_defs.cc" + ) + # no Function support initially list(APPEND onnxruntime_graph_src_exclude_patterns "${ONNXRUNTIME_ROOT}/core/graph/function*" @@ -64,30 +89,12 @@ endif() file(GLOB onnxruntime_graph_src_exclude ${onnxruntime_graph_src_exclude_patterns}) list(REMOVE_ITEM onnxruntime_graph_src ${onnxruntime_graph_src_exclude}) -file(GLOB_RECURSE onnxruntime_ir_defs_src CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/defs/*.cc" -) - -if (onnxruntime_ENABLE_TRAINING_OPS AND NOT onnxruntime_ENABLE_TRAINING) - set(orttraining_graph_src - "${ORTTRAINING_SOURCE_DIR}/core/graph/training_op_defs.cc" - "${ORTTRAINING_SOURCE_DIR}/core/graph/training_op_defs.h" - ) -endif() - -if (onnxruntime_ENABLE_TRAINING) - file(GLOB_RECURSE orttraining_graph_src CONFIGURE_DEPENDS - "${ORTTRAINING_SOURCE_DIR}/core/graph/*.h" - "${ORTTRAINING_SOURCE_DIR}/core/graph/*.cc" - ) -endif() - -set(onnxruntime_graph_lib_src ${onnxruntime_graph_src} ${onnxruntime_ir_defs_src}) if (onnxruntime_ENABLE_TRAINING_OPS) - list(APPEND onnxruntime_graph_lib_src ${orttraining_graph_src}) + file(GLOB orttraining_graph_src_exclude ${orttraining_graph_src_exclude_patterns}) + list(REMOVE_ITEM orttraining_graph_src ${orttraining_graph_src_exclude}) endif() -onnxruntime_add_static_library(onnxruntime_graph ${onnxruntime_graph_lib_src}) +onnxruntime_add_static_library(onnxruntime_graph ${onnxruntime_graph_src} ${orttraining_graph_src}) add_dependencies(onnxruntime_graph onnx_proto flatbuffers::flatbuffers) onnxruntime_add_include_to_target(onnxruntime_graph onnxruntime_common ${WIL_TARGET} onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers safeint_interface Boost::mp11) @@ -120,7 +127,7 @@ endif() set_target_properties(onnxruntime_graph PROPERTIES FOLDER "ONNXRuntime") set_target_properties(onnxruntime_graph PROPERTIES LINKER_LANGUAGE CXX) -source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_graph_src} ${onnxruntime_ir_defs_src}) +source_group(TREE ${REPO_ROOT} 
FILES ${onnxruntime_graph_src}) if (onnxruntime_ENABLE_TRAINING_OPS) source_group(TREE ${ORTTRAINING_ROOT} FILES ${orttraining_graph_src}) endif() diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index 8d3ea403fb74b..7e7819ac31a19 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -66,11 +66,7 @@ if(onnxruntime_USE_CUDA) set(PROVIDERS_CUDA onnxruntime_providers_cuda) endif() if(onnxruntime_USE_COREML) - if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") - set(PROVIDERS_COREML onnxruntime_providers_coreml onnxruntime_coreml_proto) - else() - set(PROVIDERS_COREML onnxruntime_providers_coreml) - endif() + set(PROVIDERS_COREML onnxruntime_providers_coreml coreml_proto) endif() if(onnxruntime_USE_NNAPI_BUILTIN) set(PROVIDERS_NNAPI onnxruntime_providers_nnapi) diff --git a/cmake/onnxruntime_providers_coreml.cmake b/cmake/onnxruntime_providers_coreml.cmake index aa8c35526b274..b8ebc4ca53239 100644 --- a/cmake/onnxruntime_providers_coreml.cmake +++ b/cmake/onnxruntime_providers_coreml.cmake @@ -1,107 +1,220 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. - if (onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_EXTENDED_MINIMAL_BUILD) - message(FATAL_ERROR "CoreML EP can not be used in a basic minimal build. Please build with '--minimal_build extended'") - endif() +if (onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_EXTENDED_MINIMAL_BUILD) + message(FATAL_ERROR "CoreML EP can not be used in a basic minimal build. 
Please build with '--minimal_build extended'") +endif() + +add_compile_definitions(USE_COREML=1) - add_compile_definitions(USE_COREML=1) - - # Compile CoreML proto definition to ${CMAKE_CURRENT_BINARY_DIR}/coreml - if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") - set(COREML_PROTO_ROOT ${PROJECT_SOURCE_DIR}/../onnxruntime/core/providers/coreml/mlmodel_format) - file(GLOB coreml_proto_srcs - "${COREML_PROTO_ROOT}/*.proto" - ) - onnxruntime_add_static_library(onnxruntime_coreml_proto ${coreml_proto_srcs}) - target_include_directories(onnxruntime_coreml_proto PUBLIC $ "${CMAKE_CURRENT_BINARY_DIR}") - target_compile_definitions(onnxruntime_coreml_proto PUBLIC $) - set_target_properties(onnxruntime_coreml_proto PROPERTIES COMPILE_FLAGS "-fvisibility=hidden") - set_target_properties(onnxruntime_coreml_proto PROPERTIES COMPILE_FLAGS "-fvisibility-inlines-hidden") - set(_src_sub_dir "coreml/") - onnxruntime_protobuf_generate( - APPEND_PATH - GEN_SRC_SUB_DIR ${_src_sub_dir} - IMPORT_DIRS ${COREML_PROTO_ROOT} - TARGET onnxruntime_coreml_proto - ) - - if (NOT onnxruntime_BUILD_SHARED_LIB) - install(TARGETS onnxruntime_coreml_proto - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} - FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR} - ) - endif() +# Check if we can build the coremltools code for creating an mlpackage with an mlprogram. +# The coremltools source requires std::filesystem::path which is only available from iOS 13 on. +set(_enable_ML_PROGRAM ON) +if (IOS AND CMAKE_OSX_DEPLOYMENT_TARGET VERSION_LESS 13.0) + message(WARNING "CoreML ML Program is not supported on iOS < 13.0. Excluding ML Program support from build.") + set(_enable_ML_PROGRAM OFF) +elseif(LINUX) + # uuid-dev is required. we don't bother installing on CIs as it's really for manual developer testing. 
+ find_library(LibUUID_LIBRARY NAMES uuid) + find_path(LibUUID_INCLUDE_DIR NAMES uuid/uuid.h) + if (NOT LibUUID_INCLUDE_DIR) + message(STATUS "uuid/uuid.h was not found as is required for ML Program support. " + "Run `sudo apt install uuid-dev` if you need to test ML Program related CoreML EP code. ") + set(_enable_ML_PROGRAM OFF) endif() +endif() + +if (_enable_ML_PROGRAM) + add_compile_definitions(COREML_ENABLE_MLPROGRAM=1) +endif() + +# Compile CoreML proto definition to ${CMAKE_CURRENT_BINARY_DIR}/coreml_proto +set(COREML_PROTO_ROOT ${coremltools_SOURCE_DIR}/mlmodel/format) +file(GLOB coreml_proto_srcs "${COREML_PROTO_ROOT}/*.proto") + +onnxruntime_add_static_library(coreml_proto ${coreml_proto_srcs}) +target_include_directories(coreml_proto + PUBLIC $ + "${CMAKE_CURRENT_BINARY_DIR}") +target_compile_definitions(coreml_proto + PUBLIC $) +set_target_properties(coreml_proto PROPERTIES COMPILE_FLAGS "-fvisibility=hidden") +set_target_properties(coreml_proto PROPERTIES COMPILE_FLAGS "-fvisibility-inlines-hidden") - # These are shared utils, - # TODO, move this to a separated lib when used by EPs other than NNAPI and CoreML - file(GLOB_RECURSE onnxruntime_providers_shared_utils_cc_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.h" - "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.cc" +set(_src_sub_dir "coreml_proto/") +onnxruntime_protobuf_generate( + APPEND_PATH + GEN_SRC_SUB_DIR ${_src_sub_dir} + IMPORT_DIRS ${COREML_PROTO_ROOT} + TARGET coreml_proto +) + +if (NOT onnxruntime_BUILD_SHARED_LIB) + install(TARGETS coreml_proto + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR} ) +endif() + +# Add the .proto and generated .cc/.h files to the External/coreml_proto folder in Visual Studio. 
+# Separate source_group for each as the .proto files are in the repo and the .cc/.h files are generated in the build +# output directory. +set_target_properties(coreml_proto PROPERTIES FOLDER "External") +source_group(TREE ${COREML_PROTO_ROOT} PREFIX coreml_proto FILES ${coreml_proto_srcs}) + +# filter to the generated .cc/.h files +get_target_property(coreml_proto_generated_srcs coreml_proto SOURCES) +list(FILTER coreml_proto_generated_srcs INCLUDE REGEX "\.pb\.(h|cc)$") +source_group(TREE ${CMAKE_CURRENT_BINARY_DIR} PREFIX coreml_proto_generated FILES ${coreml_proto_generated_srcs}) + +# These are shared utils, +# TODO, move this to a separate lib when used by EPs other than NNAPI and CoreML +file(GLOB onnxruntime_providers_shared_utils_cc_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.h" + "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.cc" +) +file(GLOB onnxruntime_providers_coreml_public_headers CONFIGURE_DEPENDS + "${ONNXRUNTIME_INCLUDE_DIR}/core/providers/coreml/*.h" +) + +file(GLOB + onnxruntime_providers_coreml_cc_srcs_top CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/coreml/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/coreml/*.cc" +) + +# Add builder source code +file(GLOB_RECURSE + onnxruntime_providers_coreml_cc_srcs_nested CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/coreml/builders/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/coreml/builders/*.cc" +) + +if(_enable_ML_PROGRAM) + # Add helpers to create mlpackage weights. limit to just the files we need to minimize the changes to make them + # build on Windows and Linux. 
file(GLOB - onnxruntime_providers_coreml_cc_srcs_top CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/coreml/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/coreml/*.cc" + onnxruntime_providers_coreml_milblob_cc_srcs CONFIGURE_DEPENDS + "${coremltools_SOURCE_DIR}/mlmodel/src/MILBlob/*.hpp" + "${coremltools_SOURCE_DIR}/mlmodel/src/MILBlob/*.cpp" + "${coremltools_SOURCE_DIR}/mlmodel/src/MILBlob/Util/*.hpp" + "${coremltools_SOURCE_DIR}/mlmodel/src/MILBlob/Blob/BlobDataType.hpp" + "${coremltools_SOURCE_DIR}/mlmodel/src/MILBlob/Blob/StorageFormat.hpp" + "${coremltools_SOURCE_DIR}/mlmodel/src/MILBlob/Blob/FileWriter.?pp" + "${coremltools_SOURCE_DIR}/mlmodel/src/MILBlob/Blob/StorageWriter.?pp" ) - # Add builder source code - file(GLOB_RECURSE - onnxruntime_providers_coreml_cc_srcs_nested CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/coreml/builders/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/coreml/builders/*.cc" + # Add helpers to create mlpackage + file(GLOB + onnxruntime_providers_coreml_modelpackage_cc_srcs CONFIGURE_DEPENDS + "${coremltools_SOURCE_DIR}/modelpackage/src/ModelPackage.?pp" + "${coremltools_SOURCE_DIR}/modelpackage/src/utils/JsonMap.?pp" ) - if (NOT CMAKE_SYSTEM_NAME STREQUAL "Darwin" AND NOT CMAKE_SYSTEM_NAME STREQUAL "iOS") - list(REMOVE_ITEM onnxruntime_providers_coreml_cc_srcs_nested - "${ONNXRUNTIME_ROOT}/core/providers/coreml/builders/model_builder.h" - "${ONNXRUNTIME_ROOT}/core/providers/coreml/builders/model_builder.cc" - ) - endif() - - # Add CoreML objective c++ source code - if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") - file(GLOB - onnxruntime_providers_coreml_objcc_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/model.h" - "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/model.mm" - "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/host_utils.h" - "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/host_utils.mm" - ) - endif() - set(onnxruntime_providers_coreml_cc_srcs - 
${onnxruntime_providers_coreml_cc_srcs_top} - ${onnxruntime_providers_coreml_cc_srcs_nested} - ${onnxruntime_providers_shared_utils_cc_srcs} + set(coremltools_srcs + ${onnxruntime_providers_coreml_milblob_cc_srcs} + ${onnxruntime_providers_coreml_modelpackage_cc_srcs} ) - source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_coreml_cc_srcs}) - onnxruntime_add_static_library(onnxruntime_providers_coreml - ${onnxruntime_providers_coreml_cc_srcs} ${onnxruntime_providers_coreml_objcc_srcs} + source_group(TREE ${coremltools_SOURCE_DIR} PREFIX coremltools FILES ${coremltools_srcs}) +endif() + +# Add CoreML objective c++ source code +if (APPLE) + file(GLOB + onnxruntime_providers_coreml_objcc_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/model.h" + "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/model.mm" + "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/host_utils.h" + "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/host_utils.mm" ) - onnxruntime_add_include_to_target(onnxruntime_providers_coreml - onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface +else() + # add the Model implementation that uses the protobuf types but excludes any actual CoreML dependencies + # by using stub implementations on non-Apple platforms. 
+ file(GLOB + onnxruntime_providers_coreml_objcc_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/host_utils.h" + "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/host_utils_stub.cc" + "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/model.h" + "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/model_stub.cc" ) - if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") - onnxruntime_add_include_to_target(onnxruntime_providers_coreml onnxruntime_coreml_proto) - target_link_libraries(onnxruntime_providers_coreml PRIVATE onnxruntime_coreml_proto "-framework Foundation" "-framework CoreML") - add_dependencies(onnxruntime_providers_coreml onnxruntime_coreml_proto) +endif() + +set(onnxruntime_providers_coreml_cc_srcs + ${onnxruntime_providers_coreml_cc_srcs_top} + ${onnxruntime_providers_coreml_cc_srcs_nested} + ${onnxruntime_providers_shared_utils_cc_srcs} + ${onnxruntime_providers_coreml_objcc_srcs} +) + +source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_providers_coreml_cc_srcs}) +source_group(TREE ${ONNXRUNTIME_INCLUDE_DIR} FILES ${onnxruntime_providers_coreml_public_headers}) + +onnxruntime_add_static_library(onnxruntime_providers_coreml + ${onnxruntime_providers_coreml_public_headers} + ${onnxruntime_providers_coreml_cc_srcs} + ${coremltools_srcs} +) + +onnxruntime_add_include_to_target(onnxruntime_providers_coreml + onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 + safeint_interface +) + +onnxruntime_add_include_to_target(onnxruntime_providers_coreml coreml_proto) +target_link_libraries(onnxruntime_providers_coreml PRIVATE coreml_proto) +add_dependencies(onnxruntime_providers_coreml coreml_proto) + +if (APPLE) + target_compile_definitions(onnxruntime_providers_coreml PRIVATE __APPLE__) +endif() + +if (_enable_ML_PROGRAM) + # Setup coremltools fp16 and json dependencies for creating an mlpackage. + # + # These are also used by external/xnnpack.cmake. 
fp16 depends on psimd + FetchContent_Declare(psimd URL ${DEP_URL_psimd} URL_HASH SHA1=${DEP_SHA1_psimd}) + onnxruntime_fetchcontent_makeavailable(psimd) + set(PSIMD_SOURCE_DIR ${psimd_SOURCE_DIR}) + FetchContent_Declare(fp16 URL ${DEP_URL_fp16} URL_HASH SHA1=${DEP_SHA1_fp16}) + set(FP16_BUILD_TESTS OFF CACHE INTERNAL "") + set(FP16_BUILD_BENCHMARKS OFF CACHE INTERNAL "") + onnxruntime_fetchcontent_makeavailable(fp16) + + # need to tweak the include paths to match what the coreml source code expects + target_include_directories(onnxruntime_providers_coreml PRIVATE + ${fp16_SOURCE_DIR}/include + ${nlohmann_json_SOURCE_DIR}/single_include/nlohmann + ${coremltools_SOURCE_DIR} + ${coremltools_SOURCE_DIR}/mlmodel/src/ + ${coremltools_SOURCE_DIR}/modelpackage/src/ + ) + + add_dependencies(onnxruntime_providers_coreml nlohmann_json::nlohmann_json fp16) + + if (LINUX) + target_link_libraries(onnxruntime_providers_coreml PRIVATE uuid) endif() - add_dependencies(onnxruntime_providers_coreml ${onnxruntime_EXTERNAL_DEPENDENCIES}) - - set_target_properties(onnxruntime_providers_coreml PROPERTIES CXX_STANDARD_REQUIRED ON) - set_target_properties(onnxruntime_providers_coreml PROPERTIES FOLDER "ONNXRuntime") - target_include_directories(onnxruntime_providers_coreml PRIVATE ${ONNXRUNTIME_ROOT} ${coreml_INCLUDE_DIRS}) - set_target_properties(onnxruntime_providers_coreml PROPERTIES LINKER_LANGUAGE CXX) - - if (NOT onnxruntime_BUILD_SHARED_LIB) - install(TARGETS onnxruntime_providers_coreml - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} - FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) - endif() \ No newline at end of file +endif() + +if (APPLE) + target_link_libraries(onnxruntime_providers_coreml PRIVATE "-framework Foundation" "-framework CoreML") +endif() + +add_dependencies(onnxruntime_providers_coreml ${onnxruntime_EXTERNAL_DEPENDENCIES}) + +set_target_properties(onnxruntime_providers_coreml 
PROPERTIES CXX_STANDARD_REQUIRED ON) +set_target_properties(onnxruntime_providers_coreml PROPERTIES FOLDER "ONNXRuntime") +target_include_directories(onnxruntime_providers_coreml PRIVATE ${ONNXRUNTIME_ROOT} ${coreml_INCLUDE_DIRS}) +set_target_properties(onnxruntime_providers_coreml PROPERTIES LINKER_LANGUAGE CXX) + +if (NOT onnxruntime_BUILD_SHARED_LIB) + install(TARGETS onnxruntime_providers_coreml + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) +endif() diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index 9887d615c92d7..aeeac10ead27d 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -141,18 +141,22 @@ if (HAS_GUARD_CF) target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler /guard:cf>") endif() + if (HAS_QSPECTRE) target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler /Qspectre>") endif() + foreach(ORT_FLAG ${ORT_WARNING_FLAGS}) target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler \"${ORT_FLAG}\">") endforeach() + # CUDA 11.3+ supports parallel compilation # https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#options-for-guiding-compiler-driver-threads if (CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 11.3) option(onnxruntime_NVCC_THREADS "Number of threads that NVCC can use for compilation." 
1) target_compile_options(${target} PRIVATE "$<$:SHELL:--threads \"${onnxruntime_NVCC_THREADS}\">") endif() + if (UNIX) target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler -Wno-reorder>" "$<$>:-Wno-reorder>") @@ -162,6 +166,13 @@ #mutex.cuh(91): warning C4834: discarding return value of function with 'nodiscard' attribute target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler /wd4834>") target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler /wd4127>") + if (MSVC) + # the VS warnings for 'Conditional Expression is Constant' are spurious as they don't handle multiple conditions + # e.g. `if (std::is_same_v && not_a_const)` will generate the warning even though constexpr cannot + # be used due to `&& not_a_const`. This affects too many places for it to be reasonable to disable at a finer + # granularity. + target_compile_options(${target} PRIVATE "$<$:/wd4127>") + endif() endif() onnxruntime_add_include_to_target(${target} onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers) @@ -178,9 +189,10 @@ add_dependencies(${target} onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES}) if(onnxruntime_CUDA_MINIMAL) target_compile_definitions(${target} PRIVATE USE_CUDA_MINIMAL) - target_link_libraries(${target} PRIVATE ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface) + target_link_libraries(${target} PRIVATE ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface CUDA::cudart) else() - target_link_libraries(${target} PRIVATE cublasLt cublas cudnn curand cufft ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface) + target_link_libraries(${target} PRIVATE CUDA::cublasLt CUDA::cublas cudnn CUDA::curand CUDA::cufft CUDA::cudart + ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface) if(onnxruntime_CUDNN_HOME) target_include_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/include) 
target_link_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/lib) @@ -196,25 +208,24 @@ target_include_directories(${target} PRIVATE ${triton_kernel_header_dir}) target_link_libraries(${target} PUBLIC -Wl,--whole-archive ${triton_kernel_obj_file} -Wl,--no-whole-archive) # lib cuda needed by cuLaunchKernel - target_link_libraries(${target} PRIVATE cuda) + target_link_libraries(${target} PRIVATE CUDA::cuda_driver) endif() include(cutlass) - target_include_directories(${target} PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples) + target_include_directories(${target} PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples ${cutlass_SOURCE_DIR}/tools/util/include) - target_include_directories(${target} PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} ${TVM_INCLUDES} PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) + target_include_directories(${target} PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} ${TVM_INCLUDES} + PUBLIC ${CUDAToolkit_INCLUDE_DIRS}) # ${CMAKE_CURRENT_BINARY_DIR} is so that #include "onnxruntime_config.h" inside tensor_shape.h is found set_target_properties(${target} PROPERTIES LINKER_LANGUAGE CUDA) set_target_properties(${target} PROPERTIES FOLDER "ONNXRuntime") if (onnxruntime_ENABLE_CUDA_PROFILING) # configure cupti for cuda profiling - target_include_directories(${target} PRIVATE ${onnxruntime_CUDA_HOME}/extras/CUPTI/include) - target_link_directories(${target} PRIVATE ${onnxruntime_CUDA_HOME}/extras/CUPTI/lib64) - target_link_libraries(${target} PRIVATE cupti) + target_link_libraries(${target} PRIVATE CUDA::cupti) endif() - if (onnxruntime_ENABLE_NVTX_PROFILE AND NOT WIN32) - target_link_libraries(${target} PRIVATE nvToolsExt) + if (onnxruntime_ENABLE_NVTX_PROFILE) + target_link_libraries(${target} PRIVATE CUDA::nvtx3) endif() if (onnxruntime_ENABLE_TRAINING_OPS) diff --git a/cmake/onnxruntime_providers_nnapi.cmake 
b/cmake/onnxruntime_providers_nnapi.cmake index 5ac25a3b76efb..b718a976eb26f 100644 --- a/cmake/onnxruntime_providers_nnapi.cmake +++ b/cmake/onnxruntime_providers_nnapi.cmake @@ -49,12 +49,10 @@ endif() # These are shared utils, - # TODO, move this to a separated lib when used by EPs other than NNAPI and CoreML + # TODO, move this to a separate lib when used by EPs other than NNAPI and CoreML list(APPEND onnxruntime_provider_nnapi_cc_src_patterns "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.h" "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.cc" - "${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.h" - "${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.cc" ) file(GLOB onnxruntime_providers_nnapi_cc_srcs CONFIGURE_DEPENDS ${onnxruntime_provider_nnapi_cc_src_patterns}) @@ -81,4 +79,4 @@ LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) - endif() \ No newline at end of file + endif() diff --git a/cmake/onnxruntime_providers_qnn.cmake b/cmake/onnxruntime_providers_qnn.cmake index a93a06e960c81..b68d84c23bb32 100644 --- a/cmake/onnxruntime_providers_qnn.cmake +++ b/cmake/onnxruntime_providers_qnn.cmake @@ -4,12 +4,10 @@ add_compile_definitions(USE_QNN=1) # These are shared utils, - # TODO, move this to a separated lib when used by EPs other than QNN, NNAPI and CoreML - file(GLOB_RECURSE onnxruntime_providers_shared_utils_cc_srcs CONFIGURE_DEPENDS + # TODO, move to a separate lib when used by EPs other than QNN, NNAPI and CoreML + file(GLOB onnxruntime_providers_shared_utils_cc_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.h" "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.cc" - "${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.h" - "${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.cc" ) file(GLOB_RECURSE @@ -42,4 +40,4 @@ # ignore the warning unknown-pragmas on "pragma region" 
if(NOT MSVC) target_compile_options(onnxruntime_providers_qnn PRIVATE "-Wno-unknown-pragmas") - endif() \ No newline at end of file + endif() diff --git a/cmake/onnxruntime_providers_tensorrt.cmake b/cmake/onnxruntime_providers_tensorrt.cmake index 686a993de3a4a..15ffc29e79ff4 100644 --- a/cmake/onnxruntime_providers_tensorrt.cmake +++ b/cmake/onnxruntime_providers_tensorrt.cmake @@ -8,7 +8,7 @@ set(BUILD_LIBRARY_ONLY 1) add_definitions("-DONNX_ML=1") add_definitions("-DONNX_NAMESPACE=onnx") - set(CUDA_INCLUDE_DIRS ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) + set(CUDA_INCLUDE_DIRS ${CUDAToolkit_INCLUDE_DIRS}) set(TENSORRT_ROOT ${onnxruntime_TENSORRT_HOME}) set(OLD_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS}) set(PROTOBUF_LIBRARY ${PROTOBUF_LIB}) @@ -58,7 +58,7 @@ URL_HASH SHA1=${DEP_SHA1_onnx_tensorrt} ) if (NOT CUDA_INCLUDE_DIR) - set(CUDA_INCLUDE_DIR ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) # onnx-tensorrt repo needs this variable to build + set(CUDA_INCLUDE_DIR ${CUDAToolkit_INCLUDE_DIRS}) # onnx-tensorrt repo needs this variable to build endif() # The onnx_tensorrt repo contains a test program, getSupportedAPITest, which doesn't support Windows. It uses # unistd.h. So we must exclude it from our build. onnxruntime_fetchcontent_makeavailable is for the purpose. 
@@ -102,11 +102,12 @@ onnxruntime_add_include_to_target(onnxruntime_providers_tensorrt onnxruntime_common onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface) add_dependencies(onnxruntime_providers_tensorrt onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES}) if (onnxruntime_USE_TENSORRT_BUILTIN_PARSER) - target_link_libraries(onnxruntime_providers_tensorrt PRIVATE ${trt_link_libs} cudart ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface ${ABSEIL_LIBS}) + target_link_libraries(onnxruntime_providers_tensorrt PRIVATE ${trt_link_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface ${ABSEIL_LIBS} PUBLIC CUDA::cudart) else() - target_link_libraries(onnxruntime_providers_tensorrt PRIVATE ${onnxparser_link_libs} ${trt_link_libs} cudart ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers ${ABSEIL_LIBS}) + target_link_libraries(onnxruntime_providers_tensorrt PRIVATE ${onnxparser_link_libs} ${trt_link_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers ${ABSEIL_LIBS} PUBLIC CUDA::cudart) endif() - target_include_directories(onnxruntime_providers_tensorrt PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) + target_include_directories(onnxruntime_providers_tensorrt PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} + PUBLIC ${CUDAToolkit_INCLUDE_DIRS}) if(onnxruntime_CUDNN_HOME) target_include_directories(onnxruntime_providers_tensorrt PRIVATE ${onnxruntime_CUDNN_HOME}/include) endif() diff --git a/cmake/onnxruntime_providers_vitisai.cmake b/cmake/onnxruntime_providers_vitisai.cmake index 0951c2d02664d..183a3e196af42 100644 --- a/cmake/onnxruntime_providers_vitisai.cmake +++ b/cmake/onnxruntime_providers_vitisai.cmake @@ -14,14 +14,19 @@ "${ONNXRUNTIME_ROOT}/core/providers/vitisai/*.h" 
"${ONNXRUNTIME_ROOT}/core/providers/vitisai/imp/*.cc" "${ONNXRUNTIME_ROOT}/core/providers/vitisai/imp/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.cc" ) source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_vitisai_cc_srcs}) - onnxruntime_add_static_library(onnxruntime_providers_vitisai ${onnxruntime_providers_vitisai_cc_srcs}) - onnxruntime_add_include_to_target(onnxruntime_providers_vitisai onnxruntime_common onnxruntime_framework onnx onnx_proto) - target_link_libraries(onnxruntime_providers_vitisai PRIVATE onnx protobuf::libprotobuf nlohmann_json::nlohmann_json) - if(NOT MSVC) - target_compile_options(onnxruntime_providers_vitisai PUBLIC $<$:-U_FORTIFY_SOURCE -D_FORTIFY_SOURCE=0>) - endif(NOT MSVC) + onnxruntime_add_shared_library(onnxruntime_providers_vitisai ${onnxruntime_providers_vitisai_cc_srcs}) + onnxruntime_add_include_to_target(onnxruntime_providers_vitisai ${ONNXRUNTIME_PROVIDERS_SHARED} nlohmann_json::nlohmann_json safeint_interface flatbuffers::flatbuffers) + target_link_libraries(onnxruntime_providers_vitisai PRIVATE ${ONNXRUNTIME_PROVIDERS_SHARED}) + if(MSVC) + onnxruntime_add_include_to_target(onnxruntime_providers_vitisai dbghelp) + set_property(TARGET onnxruntime_providers_vitisai APPEND_STRING PROPERTY LINK_FLAGS "-DEF:${ONNXRUNTIME_ROOT}/core/providers/vitisai/symbols.def") + else(MSVC) + set_property(TARGET onnxruntime_providers_vitisai APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/vitisai/version_script.lds -Xlinker --gc-sections") + endif(MSVC) target_include_directories(onnxruntime_providers_vitisai PRIVATE "${ONNXRUNTIME_ROOT}/core/providers/vitisai/include" ${XRT_INCLUDE_DIRS} ${CMAKE_CURRENT_BINARY_DIR}/VitisAI) if(MSVC) @@ -30,17 +35,18 @@ target_compile_options(onnxruntime_providers_vitisai PRIVATE "/wd4251") # for unused formal parameter target_compile_options(onnxruntime_providers_vitisai 
PRIVATE "/wd4100") + # for type name first seen using 'class' now seen using 'struct' + target_compile_options(onnxruntime_providers_vitisai PRIVATE "/wd4099") else(MSVC) + target_compile_options(onnxruntime_providers_vitisai PUBLIC $<$:-U_FORTIFY_SOURCE -D_FORTIFY_SOURCE=0>) target_compile_options(onnxruntime_providers_vitisai PRIVATE -Wno-unused-parameter) endif(MSVC) set_target_properties(onnxruntime_providers_vitisai PROPERTIES FOLDER "ONNXRuntime") set_target_properties(onnxruntime_providers_vitisai PROPERTIES LINKER_LANGUAGE CXX) - if (NOT onnxruntime_BUILD_SHARED_LIB) - install(TARGETS onnxruntime_providers_vitisai - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} - FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) - endif() + install(TARGETS onnxruntime_providers_vitisai + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) diff --git a/cmake/onnxruntime_providers_xnnpack.cmake b/cmake/onnxruntime_providers_xnnpack.cmake index 9c00703ca0846..796536ac9d12b 100644 --- a/cmake/onnxruntime_providers_xnnpack.cmake +++ b/cmake/onnxruntime_providers_xnnpack.cmake @@ -7,9 +7,6 @@ "${ONNXRUNTIME_INCLUDE_DIR}/core/providers/xnnpack/*.h" "${ONNXRUNTIME_ROOT}/core/providers/xnnpack/*.h" "${ONNXRUNTIME_ROOT}/core/providers/xnnpack/*.cc" - # utils for handling QDQ models - "${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.h" - "${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.cc" ) source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_providers_xnnpack_cc_srcs}) @@ -19,6 +16,12 @@ flatbuffers::flatbuffers Boost::mp11 safeint_interface ) + # TODO fix stringop-overflow warnings + # Add compile option to suppress stringop-overflow error in Flatbuffers. 
+ if (HAS_STRINGOP_OVERFLOW) + target_compile_options(onnxruntime_providers_xnnpack PRIVATE -Wno-error=stringop-overflow) + endif() + add_dependencies(onnxruntime_providers_xnnpack onnx ${onnxruntime_EXTERNAL_DEPENDENCIES}) set_target_properties(onnxruntime_providers_xnnpack PROPERTIES FOLDER "ONNXRuntime") diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index 2e3594f256f65..23c6e5e430875 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -170,7 +170,6 @@ target_link_libraries(onnxruntime_pybind11_state PRIVATE onnxruntime_session ${onnxruntime_libs} ${PROVIDERS_TVM} - ${PROVIDERS_VITISAI} ${PROVIDERS_NNAPI} ${PROVIDERS_XNNPACK} ${PROVIDERS_COREML} @@ -283,10 +282,7 @@ if (WIN32) get_filename_component(CUDNN_DLL_NAME ${CUDNN_DLL_PATH} NAME_WE) string(REPLACE "cudnn64_" "" CUDNN_VERSION "${CUDNN_DLL_NAME}") if(NOT onnxruntime_CUDA_VERSION) - message("Reading json file ${onnxruntime_CUDA_HOME}/version.json") - set(CUDA_SDK_JSON_FILE_PATH "${onnxruntime_CUDA_HOME}/version.json") - file(READ ${CUDA_SDK_JSON_FILE_PATH} CUDA_SDK_JSON_CONTENT) - string(JSON onnxruntime_CUDA_VERSION GET ${CUDA_SDK_JSON_CONTENT} "cuda" "version") + set(onnxruntime_CUDA_VERSION ${CUDAToolkit_VERSION}) message("onnxruntime_CUDA_VERSION=${onnxruntime_CUDA_VERSION}") endif() file(APPEND "${VERSION_INFO_FILE}" @@ -474,6 +470,9 @@ file(GLOB onnxruntime_python_transformers_models_llama_src CONFIGURE_DEPENDS file(GLOB onnxruntime_python_transformers_models_longformer_src CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/python/tools/transformers/models/longformer/*.py" ) +file(GLOB onnxruntime_python_transformers_models_phi2_src CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/python/tools/transformers/models/phi2/*.py" +) file(GLOB onnxruntime_python_transformers_models_stable_diffusion_src CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/python/tools/transformers/models/stable_diffusion/*.py" ) @@ -544,6 +543,7 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E 
make_directory $/onnxruntime/transformers/models/gpt2 COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/llama COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/longformer + COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/phi2 COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/stable_diffusion COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/t5 COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/whisper @@ -647,6 +647,9 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_transformers_models_longformer_src} $/onnxruntime/transformers/models/longformer/ + COMMAND ${CMAKE_COMMAND} -E copy + ${onnxruntime_python_transformers_models_phi2_src} + $/onnxruntime/transformers/models/phi2/ COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_transformers_models_stable_diffusion_src} $/onnxruntime/transformers/models/stable_diffusion/ @@ -852,6 +855,16 @@ if (onnxruntime_USE_DNNL) ) endif() +if (onnxruntime_USE_VITISAI) + add_custom_command( + TARGET onnxruntime_pybind11_state POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy + ${DNNL_DLL_PATH} $ + $ + $/onnxruntime/capi/ + ) +endif() + if (onnxruntime_USE_TENSORRT) add_custom_command( TARGET onnxruntime_pybind11_state POST_BUILD diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index d485abe6bb1a6..6f54943f09afe 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -20,10 +20,6 @@ set(contrib_ops_excluded_files "bert/fastertransformer_decoder_attention/*" "bert/multihead_attention.cc" "bert/multihead_attention.h" - "bert/fast_gelu_impl.cu" - "bert/fast_gelu_impl.h" - "bert/fast_gelu.cc" - "bert/fast_gelu.h" "bert/relative_attn_bias.cc" "bert/relative_attn_bias.h" "bert/relative_attn_bias_impl.cu" @@ -44,12 +40,7 @@ set(contrib_ops_excluded_files 
"bert/packed_multihead_attention.cc" "bert/packed_multihead_attention_impl.h" "bert/packed_multihead_attention_impl.cu" - "diffusion/group_norm.cc" "diffusion/group_norm_impl.cu" - "diffusion/group_norm_impl.h" - "diffusion/group_norm_impl_kernel.cuh" - "diffusion/group_norm_common_base.h" - "diffusion/group_norm_common_base.cc" "diffusion/nhwc_conv.cc" "math/gemm_float8.cc" "math/gemm_float8.cu" @@ -123,6 +114,10 @@ if (NOT onnxruntime_USE_NCCL) list(APPEND contrib_ops_excluded_files "collective/distributed_reduce.cc") list(APPEND contrib_ops_excluded_files "collective/distributed_unsqueeze.cc") list(APPEND contrib_ops_excluded_files "collective/distributed_squeeze.cc") +else() + # moe not supported for ROCm EP + list(APPEND contrib_ops_excluded_files "collective/sharded_moe.h") + list(APPEND contrib_ops_excluded_files "collective/sharded_moe.cc") endif() set(provider_excluded_files diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 0987d6d164dbd..1ffb838328643 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -1,6 +1,6 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -if (${CMAKE_SYSTEM_NAME} STREQUAL "iOS") +if (IOS) find_package(XCTest REQUIRED) endif() @@ -18,7 +18,7 @@ function(AddTest) cmake_parse_arguments(_UT "DYN" "TARGET" "LIBS;SOURCES;DEPENDS;TEST_ARGS" ${ARGN}) list(REMOVE_DUPLICATES _UT_SOURCES) - if (${CMAKE_SYSTEM_NAME} STREQUAL "iOS") + if (IOS) onnxruntime_add_executable(${_UT_TARGET} ${TEST_SRC_DIR}/xctest/orttestmain.m) else() onnxruntime_add_executable(${_UT_TARGET} ${_UT_SOURCES}) @@ -67,7 +67,7 @@ function(AddTest) if(onnxruntime_USE_CUDA) #XXX: we should not need to do this. onnxruntime_test_all.exe should not have direct dependency on CUDA DLLs, # otherwise it will impact when CUDA DLLs can be unloaded. 
- target_link_libraries(${_UT_TARGET} PRIVATE cudart) + target_link_libraries(${_UT_TARGET} PRIVATE CUDA::cudart) endif() target_link_libraries(${_UT_TARGET} PRIVATE ${_UT_LIBS} GTest::gtest GTest::gmock ${onnxruntime_EXTERNAL_LIBRARIES}) endif() @@ -111,7 +111,9 @@ function(AddTest) target_compile_options(${_UT_TARGET} PRIVATE ${DISABLED_WARNINGS_FOR_TVM}) target_compile_options(${_UT_TARGET} PRIVATE "$<$:SHELL:--compiler-options -Wno-error=sign-compare>" "$<$>:-Wno-error=sign-compare>") - target_compile_options(${_UT_TARGET} PRIVATE "-Wno-error=uninitialized") + if (${HAS_NOERROR}) + target_compile_options(${_UT_TARGET} PRIVATE "$<$:-Wno-error=uninitialized>") + endif() endif() set(TEST_ARGS ${_UT_TEST_ARGS}) @@ -127,7 +129,7 @@ function(AddTest) endif() endif(onnxruntime_GENERATE_TEST_REPORTS) - if (${CMAKE_SYSTEM_NAME} STREQUAL "iOS") + if (IOS) # target_sources(${_UT_TARGET} PRIVATE ${TEST_SRC_DIR}/xctest/orttestmain.m) set_target_properties(${_UT_TARGET} PROPERTIES FOLDER "ONNXRuntimeTest" MACOSX_BUNDLE_BUNDLE_NAME ${_UT_TARGET} @@ -565,11 +567,7 @@ if(onnxruntime_USE_ROCM) endif() if(onnxruntime_USE_COREML) - if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") - list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_coreml onnxruntime_coreml_proto) - else() - list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_coreml) - endif() + list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_coreml coreml_proto) endif() if(onnxruntime_USE_ACL) @@ -591,7 +589,6 @@ set(ONNXRUNTIME_TEST_LIBS # CUDA, ROCM, TENSORRT, MIGRAPHX, DNNL, and OpenVINO are dynamically loaded at runtime ${PROVIDERS_NNAPI} ${PROVIDERS_JS} - ${PROVIDERS_VITISAI} ${PROVIDERS_QNN} ${PROVIDERS_SNPE} ${PROVIDERS_RKNPU} @@ -675,15 +672,9 @@ endif() if(onnxruntime_USE_COREML) list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/coreml/*) - if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR 
CMAKE_SYSTEM_NAME STREQUAL "iOS") - list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_coreml onnxruntime_coreml_proto) - list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_coreml onnxruntime_coreml_proto) - list(APPEND onnxruntime_test_providers_libs onnxruntime_providers_coreml onnxruntime_coreml_proto) - else() - list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_coreml) - list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_coreml) - list(APPEND onnxruntime_test_providers_libs onnxruntime_providers_coreml) - endif() + list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_coreml coreml_proto) + list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_coreml coreml_proto) + list(APPEND onnxruntime_test_providers_libs onnxruntime_providers_coreml coreml_proto) endif() if(onnxruntime_USE_XNNPACK) @@ -743,34 +734,37 @@ target_include_directories(onnxruntime_test_utils PUBLIC "${TEST_SRC_DIR}/util/i set_target_properties(onnxruntime_test_utils PROPERTIES FOLDER "ONNXRuntimeTest") source_group(TREE ${TEST_SRC_DIR} FILES ${onnxruntime_test_utils_src}) -set(onnx_test_runner_src_dir ${TEST_SRC_DIR}/onnx) -file(GLOB onnx_test_runner_common_srcs CONFIGURE_DEPENDS - ${onnx_test_runner_src_dir}/*.h - ${onnx_test_runner_src_dir}/*.cc) +if(NOT IOS) + set(onnx_test_runner_src_dir ${TEST_SRC_DIR}/onnx) + file(GLOB onnx_test_runner_common_srcs CONFIGURE_DEPENDS + ${onnx_test_runner_src_dir}/*.h + ${onnx_test_runner_src_dir}/*.cc) -list(REMOVE_ITEM onnx_test_runner_common_srcs ${onnx_test_runner_src_dir}/main.cc) + list(REMOVE_ITEM onnx_test_runner_common_srcs ${onnx_test_runner_src_dir}/main.cc) -onnxruntime_add_static_library(onnx_test_runner_common ${onnx_test_runner_common_srcs}) -if(MSVC) - target_compile_options(onnx_test_runner_common PRIVATE "$<$:SHELL:--compiler-options /utf-8>" - "$<$>:/utf-8>") -else() - target_compile_definitions(onnx_test_runner_common PUBLIC 
-DNSYNC_ATOMIC_CPP11) - target_include_directories(onnx_test_runner_common PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT}) - onnxruntime_add_include_to_target(onnx_test_runner_common nsync::nsync_cpp) -endif() -if (MSVC AND NOT CMAKE_SIZEOF_VOID_P EQUAL 8) - #TODO: fix the warnings, they are dangerous - target_compile_options(onnx_test_runner_common PRIVATE "/wd4244") -endif() -onnxruntime_add_include_to_target(onnx_test_runner_common onnxruntime_common onnxruntime_framework - onnxruntime_test_utils onnx onnx_proto re2::re2 flatbuffers::flatbuffers Boost::mp11 safeint_interface) + onnxruntime_add_static_library(onnx_test_runner_common ${onnx_test_runner_common_srcs}) + if(MSVC) + target_compile_options(onnx_test_runner_common PRIVATE "$<$:SHELL:--compiler-options /utf-8>" + "$<$>:/utf-8>") + else() + target_compile_definitions(onnx_test_runner_common PUBLIC -DNSYNC_ATOMIC_CPP11) + target_include_directories(onnx_test_runner_common PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT}) + onnxruntime_add_include_to_target(onnx_test_runner_common nsync::nsync_cpp) + endif() + if (MSVC AND NOT CMAKE_SIZEOF_VOID_P EQUAL 8) + #TODO: fix the warnings, they are dangerous + target_compile_options(onnx_test_runner_common PRIVATE "/wd4244") + endif() + onnxruntime_add_include_to_target(onnx_test_runner_common onnxruntime_common onnxruntime_framework + onnxruntime_test_utils onnx onnx_proto re2::re2 flatbuffers::flatbuffers Boost::mp11 safeint_interface) -add_dependencies(onnx_test_runner_common onnx_test_data_proto ${onnxruntime_EXTERNAL_DEPENDENCIES}) -target_include_directories(onnx_test_runner_common PRIVATE ${eigen_INCLUDE_DIRS} - ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT}) + add_dependencies(onnx_test_runner_common onnx_test_data_proto ${onnxruntime_EXTERNAL_DEPENDENCIES}) + target_include_directories(onnx_test_runner_common PRIVATE ${eigen_INCLUDE_DIRS} + ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT}) -set_target_properties(onnx_test_runner_common 
PROPERTIES FOLDER "ONNXRuntimeTest") + set_target_properties(onnx_test_runner_common PROPERTIES FOLDER "ONNXRuntimeTest") + set(onnx_test_runner_common_lib onnx_test_runner_common) +endif() set(all_tests ${onnxruntime_test_common_src} ${onnxruntime_test_ir_src} ${onnxruntime_test_optimizer_src} ${onnxruntime_test_framework_src} ${onnxruntime_test_providers_src} ${onnxruntime_test_quantiztion_src}) @@ -783,6 +777,7 @@ if (onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS) onnxruntime_add_shared_library_module(onnxruntime_providers_cuda_ut ${onnxruntime_test_providers_cuda_ut_src} $) config_cuda_provider_shared_module(onnxruntime_providers_cuda_ut) onnxruntime_add_include_to_target(onnxruntime_providers_cuda_ut GTest::gtest GTest::gmock) + target_include_directories(onnxruntime_providers_cuda_ut PRIVATE ${ONNXRUNTIME_ROOT}/core/mickey) target_link_libraries(onnxruntime_providers_cuda_ut PRIVATE GTest::gtest GTest::gmock ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common) list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_cuda_ut) endif() @@ -824,6 +819,17 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") "${TEST_SRC_DIR}/providers/memcpy_test.cc" ) endif() + list(REMOVE_ITEM all_tests "${TEST_SRC_DIR}/providers/cpu/reduction/reduction_ops_test.cc" + "${TEST_SRC_DIR}/providers/cpu/tensor/grid_sample_test.cc") +endif() + +if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten" OR IOS) + # Because we do not run these model tests in our web or iOS CI build pipelines, and some test code uses C++17 + # filesystem functions that are not available in the iOS version we target. 
+ message("Disable model tests in onnxruntime_test_all") + list(REMOVE_ITEM all_tests + "${TEST_SRC_DIR}/providers/cpu/model_tests.cc" + ) endif() set(test_all_args) @@ -843,7 +849,7 @@ AddTest( TARGET onnxruntime_test_all SOURCES ${all_tests} ${onnxruntime_unittest_main_src} LIBS - onnx_test_runner_common ${onnxruntime_test_providers_libs} ${onnxruntime_test_common_libs} + ${onnx_test_runner_common_lib} ${onnxruntime_test_providers_libs} ${onnxruntime_test_common_libs} onnx_test_data_proto DEPENDS ${all_dependencies} TEST_ARGS ${test_all_args} @@ -881,7 +887,7 @@ endif() # the default logger tests conflict with the need to have an overall default logger # so skip in this type of target_compile_definitions(onnxruntime_test_all PUBLIC -DSKIP_DEFAULT_LOGGER_TESTS) -if (CMAKE_SYSTEM_NAME STREQUAL "iOS") +if (IOS) target_compile_definitions(onnxruntime_test_all_xc PUBLIC -DSKIP_DEFAULT_LOGGER_TESTS) endif() if(onnxruntime_RUN_MODELTEST_IN_DEBUG_MODE) @@ -906,7 +912,7 @@ if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) endif() if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") set_target_properties(onnxruntime_test_all PROPERTIES LINK_DEPENDS ${TEST_SRC_DIR}/wasm/onnxruntime_test_all_adapter.js) - set_target_properties(onnxruntime_test_all PROPERTIES LINK_FLAGS "-s STACK_SIZE=5242880 -s ALLOW_MEMORY_GROWTH=1 -s MAXIMUM_MEMORY=4294967296 -s INCOMING_MODULE_JS_API=[preRun,locateFile,arguments,onExit,wasmMemory,buffer,instantiateWasm] --pre-js \"${TEST_SRC_DIR}/wasm/onnxruntime_test_all_adapter.js\" -s \"EXPORTED_RUNTIME_METHODS=['FS']\" --preload-file ${CMAKE_CURRENT_BINARY_DIR}/testdata@/testdata -s EXIT_RUNTIME=1 -s DEMANGLE_SUPPORT=1") + set_target_properties(onnxruntime_test_all PROPERTIES LINK_FLAGS "-s STACK_SIZE=5242880 -s INITIAL_MEMORY=536870912 -s ALLOW_MEMORY_GROWTH=1 -s MAXIMUM_MEMORY=4294967296 -s INCOMING_MODULE_JS_API=[preRun,locateFile,arguments,onExit,wasmMemory,buffer,instantiateWasm] --pre-js \"${TEST_SRC_DIR}/wasm/onnxruntime_test_all_adapter.js\" -s 
\"EXPORTED_RUNTIME_METHODS=['FS']\" --preload-file ${CMAKE_CURRENT_BINARY_DIR}/testdata@/testdata -s EXIT_RUNTIME=1 -s DEMANGLE_SUPPORT=1") if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS) set_property(TARGET onnxruntime_test_all APPEND_STRING PROPERTY LINK_FLAGS " -s DEFAULT_PTHREAD_STACK_SIZE=131072 -s PROXY_TO_PTHREAD=1") endif() @@ -994,7 +1000,9 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) if (MSVC OR ${CMAKE_SYSTEM_NAME} STREQUAL "Linux") file(GLOB QNN_LIB_FILES LIST_DIRECTORIES false "${onnxruntime_QNN_HOME}/lib/${QNN_ARCH_ABI}/*.so" "${onnxruntime_QNN_HOME}/lib/${QNN_ARCH_ABI}/*.dll") if (${QNN_ARCH_ABI} STREQUAL "aarch64-windows-msvc") - file(GLOB EXTRA_HTP_LIB LIST_DIRECTORIES false "${onnxruntime_QNN_HOME}/lib/hexagon-v68/unsigned/libQnnHtpV68Skel.so" "${onnxruntime_QNN_HOME}/lib/hexagon-v73/unsigned/libQnnHtpV73Skel.so") + file(GLOB EXTRA_HTP_LIB LIST_DIRECTORIES false "${onnxruntime_QNN_HOME}/lib/hexagon-v68/unsigned/libQnnHtpV68Skel.so" + "${onnxruntime_QNN_HOME}/lib/hexagon-v73/unsigned/libQnnHtpV73Skel.so" + "${onnxruntime_QNN_HOME}/lib/hexagon-v73/unsigned/libqnnhtpv73.cat") list(APPEND QNN_LIB_FILES ${EXTRA_HTP_LIB}) endif() message(STATUS "QNN lib files: " ${QNN_LIB_FILES}) @@ -1052,45 +1060,42 @@ if (onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS) list(APPEND onnx_test_libs onnxruntime_language_interop onnxruntime_pyop) endif() -onnxruntime_add_executable(onnx_test_runner ${onnx_test_runner_src_dir}/main.cc) -if(MSVC) - target_compile_options(onnx_test_runner PRIVATE "$<$:SHELL:--compiler-options /utf-8>" - "$<$>:/utf-8>") -endif() -if(${CMAKE_SYSTEM_NAME} STREQUAL "iOS") - set_target_properties(onnx_test_runner PROPERTIES - XCODE_ATTRIBUTE_CODE_SIGNING_ALLOWED "NO" - ) -endif() -if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") - if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS) - set_target_properties(onnx_test_runner PROPERTIES LINK_FLAGS "-s NODERAWFS=1 -s ALLOW_MEMORY_GROWTH=1 -s PROXY_TO_PTHREAD=1 -s EXIT_RUNTIME=1") - else() - 
set_target_properties(onnx_test_runner PROPERTIES LINK_FLAGS "-s NODERAWFS=1 -s ALLOW_MEMORY_GROWTH=1") - endif() -endif() +if (NOT IOS) + onnxruntime_add_executable(onnx_test_runner ${onnx_test_runner_src_dir}/main.cc) + if(MSVC) + target_compile_options(onnx_test_runner PRIVATE "$<$:SHELL:--compiler-options /utf-8>" + "$<$>:/utf-8>") + endif() + if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS) + set_target_properties(onnx_test_runner PROPERTIES LINK_FLAGS "-s NODERAWFS=1 -s ALLOW_MEMORY_GROWTH=1 -s PROXY_TO_PTHREAD=1 -s EXIT_RUNTIME=1") + else() + set_target_properties(onnx_test_runner PROPERTIES LINK_FLAGS "-s NODERAWFS=1 -s ALLOW_MEMORY_GROWTH=1") + endif() + endif() -target_link_libraries(onnx_test_runner PRIVATE onnx_test_runner_common ${GETOPT_LIB_WIDE} ${onnx_test_libs} nlohmann_json::nlohmann_json) -target_include_directories(onnx_test_runner PRIVATE ${ONNXRUNTIME_ROOT}) -if (onnxruntime_USE_ROCM) - target_include_directories(onnx_test_runner PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/onnxruntime ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/orttraining) -endif() -if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) - target_link_libraries(onnx_test_runner PRIVATE Python::Python) -endif() -set_target_properties(onnx_test_runner PROPERTIES FOLDER "ONNXRuntimeTest") + target_link_libraries(onnx_test_runner PRIVATE onnx_test_runner_common ${GETOPT_LIB_WIDE} ${onnx_test_libs} nlohmann_json::nlohmann_json) + target_include_directories(onnx_test_runner PRIVATE ${ONNXRUNTIME_ROOT}) + if (onnxruntime_USE_ROCM) + target_include_directories(onnx_test_runner PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/onnxruntime ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/orttraining) + endif() + if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) + target_link_libraries(onnx_test_runner PRIVATE Python::Python) + endif() + set_target_properties(onnx_test_runner PROPERTIES FOLDER "ONNXRuntimeTest") -if (onnxruntime_USE_TVM) - if (WIN32) - 
target_link_options(onnx_test_runner PRIVATE "/STACK:4000000") - endif() -endif() + if (onnxruntime_USE_TVM) + if (WIN32) + target_link_options(onnx_test_runner PRIVATE "/STACK:4000000") + endif() + endif() -install(TARGETS onnx_test_runner - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - BUNDLE DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) + install(TARGETS onnx_test_runner + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + BUNDLE DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) +endif() if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) if(onnxruntime_BUILD_BENCHMARKS) @@ -1171,90 +1176,80 @@ endif() if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) - #perf test runner - set(onnxruntime_perf_test_src_dir ${TEST_SRC_DIR}/perftest) - set(onnxruntime_perf_test_src_patterns - "${onnxruntime_perf_test_src_dir}/*.cc" - "${onnxruntime_perf_test_src_dir}/*.h") + if(NOT IOS) + #perf test runner + set(onnxruntime_perf_test_src_dir ${TEST_SRC_DIR}/perftest) + set(onnxruntime_perf_test_src_patterns + "${onnxruntime_perf_test_src_dir}/*.cc" + "${onnxruntime_perf_test_src_dir}/*.h") - if(WIN32) - list(APPEND onnxruntime_perf_test_src_patterns - "${onnxruntime_perf_test_src_dir}/windows/*.cc" - "${onnxruntime_perf_test_src_dir}/windows/*.h" ) - else () - list(APPEND onnxruntime_perf_test_src_patterns - "${onnxruntime_perf_test_src_dir}/posix/*.cc" - "${onnxruntime_perf_test_src_dir}/posix/*.h" ) - endif() + if(WIN32) + list(APPEND onnxruntime_perf_test_src_patterns + "${onnxruntime_perf_test_src_dir}/windows/*.cc" + "${onnxruntime_perf_test_src_dir}/windows/*.h" ) + else () + list(APPEND onnxruntime_perf_test_src_patterns + "${onnxruntime_perf_test_src_dir}/posix/*.cc" + "${onnxruntime_perf_test_src_dir}/posix/*.h" ) + endif() - file(GLOB onnxruntime_perf_test_src CONFIGURE_DEPENDS - ${onnxruntime_perf_test_src_patterns} 
- ) - onnxruntime_add_executable(onnxruntime_perf_test ${onnxruntime_perf_test_src} ${ONNXRUNTIME_ROOT}/core/platform/path_lib.cc) - if(MSVC) - target_compile_options(onnxruntime_perf_test PRIVATE "$<$:SHELL:--compiler-options /utf-8>" + file(GLOB onnxruntime_perf_test_src CONFIGURE_DEPENDS + ${onnxruntime_perf_test_src_patterns} + ) + onnxruntime_add_executable(onnxruntime_perf_test ${onnxruntime_perf_test_src} ${ONNXRUNTIME_ROOT}/core/platform/path_lib.cc) + if(MSVC) + target_compile_options(onnxruntime_perf_test PRIVATE "$<$:SHELL:--compiler-options /utf-8>" "$<$>:/utf-8>") - endif() - target_include_directories(onnxruntime_perf_test PRIVATE ${onnx_test_runner_src_dir} ${ONNXRUNTIME_ROOT} + endif() + target_include_directories(onnxruntime_perf_test PRIVATE ${onnx_test_runner_src_dir} ${ONNXRUNTIME_ROOT} ${eigen_INCLUDE_DIRS} ${onnxruntime_graph_header} ${onnxruntime_exec_src_dir} ${CMAKE_CURRENT_BINARY_DIR}) - if (onnxruntime_USE_ROCM) - target_include_directories(onnxruntime_perf_test PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/onnxruntime ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/orttraining) - endif() - if (WIN32) - target_compile_options(onnxruntime_perf_test PRIVATE ${disabled_warnings}) - if (NOT DEFINED SYS_PATH_LIB) - set(SYS_PATH_LIB shlwapi) + if (onnxruntime_USE_ROCM) + target_include_directories(onnxruntime_perf_test PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/onnxruntime ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/orttraining) + endif() + if (WIN32) + target_compile_options(onnxruntime_perf_test PRIVATE ${disabled_warnings}) + if (NOT DEFINED SYS_PATH_LIB) + set(SYS_PATH_LIB shlwapi) + endif() endif() - endif() - if(${CMAKE_SYSTEM_NAME} STREQUAL "iOS") - set_target_properties(onnxruntime_perf_test PROPERTIES - XCODE_ATTRIBUTE_CODE_SIGNING_ALLOWED "NO" - ) - endif() - if (onnxruntime_BUILD_SHARED_LIB) - #It will dynamically link to onnxruntime. So please don't add onxruntime_graph/onxruntime_framework/... here. 
- #onnxruntime_common is kind of ok because it is thin, tiny and totally stateless. - set(onnxruntime_perf_test_libs + if (onnxruntime_BUILD_SHARED_LIB) + #It will dynamically link to onnxruntime. So please don't add onxruntime_graph/onxruntime_framework/... here. + #onnxruntime_common is kind of ok because it is thin, tiny and totally stateless. + set(onnxruntime_perf_test_libs onnx_test_runner_common onnxruntime_test_utils onnxruntime_common onnxruntime onnxruntime_flatbuffers onnx_test_data_proto ${onnxruntime_EXTERNAL_LIBRARIES} ${GETOPT_LIB_WIDE} ${SYS_PATH_LIB} ${CMAKE_DL_LIBS}) - if(NOT WIN32) - list(APPEND onnxruntime_perf_test_libs nsync::nsync_cpp) - if(onnxruntime_USE_SNPE) - list(APPEND onnxruntime_perf_test_libs onnxruntime_providers_snpe) + if(NOT WIN32) + list(APPEND onnxruntime_perf_test_libs nsync::nsync_cpp) + if(onnxruntime_USE_SNPE) + list(APPEND onnxruntime_perf_test_libs onnxruntime_providers_snpe) + endif() endif() + if (CMAKE_SYSTEM_NAME STREQUAL "Android") + list(APPEND onnxruntime_perf_test_libs ${android_shared_libs}) + endif() + target_link_libraries(onnxruntime_perf_test PRIVATE ${onnxruntime_perf_test_libs} Threads::Threads) + if(WIN32) + target_link_libraries(onnxruntime_perf_test PRIVATE debug dbghelp advapi32) + endif() + else() + target_link_libraries(onnxruntime_perf_test PRIVATE onnx_test_runner_common ${GETOPT_LIB_WIDE} ${onnx_test_libs}) endif() - if (CMAKE_SYSTEM_NAME STREQUAL "Android") - list(APPEND onnxruntime_perf_test_libs ${android_shared_libs}) - endif() - target_link_libraries(onnxruntime_perf_test PRIVATE ${onnxruntime_perf_test_libs} Threads::Threads) - if(WIN32) - target_link_libraries(onnxruntime_perf_test PRIVATE debug dbghelp advapi32) - endif() - if(tensorflow_C_PACKAGE_PATH) - target_include_directories(onnxruntime_perf_test PRIVATE ${tensorflow_C_PACKAGE_PATH}/include) - target_link_directories(onnxruntime_perf_test PRIVATE ${tensorflow_C_PACKAGE_PATH}/lib) - target_link_libraries(onnxruntime_perf_test PRIVATE 
tensorflow) - target_compile_definitions(onnxruntime_perf_test PRIVATE HAVE_TENSORFLOW) - endif() - else() - target_link_libraries(onnxruntime_perf_test PRIVATE onnx_test_runner_common ${GETOPT_LIB_WIDE} ${onnx_test_libs}) - endif() - set_target_properties(onnxruntime_perf_test PROPERTIES FOLDER "ONNXRuntimeTest") + set_target_properties(onnxruntime_perf_test PROPERTIES FOLDER "ONNXRuntimeTest") - if (onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS AND NOT onnxruntime_BUILD_SHARED_LIB) - target_link_libraries(onnxruntime_perf_test PRIVATE onnxruntime_language_interop onnxruntime_pyop) - endif() + if (onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS AND NOT onnxruntime_BUILD_SHARED_LIB) + target_link_libraries(onnxruntime_perf_test PRIVATE onnxruntime_language_interop onnxruntime_pyop) + endif() - if (onnxruntime_USE_TVM) - if (WIN32) - target_link_options(onnxruntime_perf_test PRIVATE "/STACK:4000000") + if (onnxruntime_USE_TVM) + if (WIN32) + target_link_options(onnxruntime_perf_test PRIVATE "/STACK:4000000") + endif() endif() endif() - # shared lib if (onnxruntime_BUILD_SHARED_LIB) onnxruntime_add_static_library(onnxruntime_mocked_allocator ${TEST_SRC_DIR}/util/test_allocator.cc) @@ -1275,7 +1270,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) list(APPEND onnxruntime_shared_lib_test_LIBS cpuinfo) endif() if (onnxruntime_USE_CUDA) - list(APPEND onnxruntime_shared_lib_test_LIBS cudart) + list(APPEND onnxruntime_shared_lib_test_LIBS CUDA::cudart) endif() if (onnxruntime_USE_ROCM) list(APPEND onnxruntime_shared_lib_test_LIBS hip::host) @@ -1309,7 +1304,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) target_compile_definitions(onnxruntime_shared_lib_test PRIVATE USE_DUMMY_EXA_DEMANGLE=1) endif() - if (CMAKE_SYSTEM_NAME STREQUAL "iOS") + if (IOS) add_custom_command( TARGET onnxruntime_shared_lib_test POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_directory @@ -1396,7 +1391,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) target_compile_options(onnxruntime_mlas_test 
PRIVATE "$<$:SHELL:--compiler-options /wd26426>" "$<$>:/wd26426>") endif() - if(${CMAKE_SYSTEM_NAME} STREQUAL "iOS") + if(IOS) set_target_properties(onnxruntime_mlas_test PROPERTIES XCODE_ATTRIBUTE_CODE_SIGNING_ALLOWED "NO" ) @@ -1597,7 +1592,7 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") DEPENDS ${all_dependencies} ) - if (CMAKE_SYSTEM_NAME STREQUAL "iOS") + if (IOS) add_custom_command( TARGET onnxruntime_customopregistration_test POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_directory diff --git a/cmake/patches/coremltools/crossplatformbuild.patch b/cmake/patches/coremltools/crossplatformbuild.patch new file mode 100644 index 0000000000000..7f2268f50c82e --- /dev/null +++ b/cmake/patches/coremltools/crossplatformbuild.patch @@ -0,0 +1,155 @@ +diff --git a/mlmodel/src/MILBlob/Blob/FileWriter.cpp b/mlmodel/src/MILBlob/Blob/FileWriter.cpp +index adc7bfcf..7b2bf9cc 100644 +--- a/mlmodel/src/MILBlob/Blob/FileWriter.cpp ++++ b/mlmodel/src/MILBlob/Blob/FileWriter.cpp +@@ -8,8 +8,12 @@ + + #include + #include ++ ++// ORT_EDIT: Exclude mmap on Windows. Not used in this file anyway. 
++#if !defined(_WIN32) + #include + #include ++#endif + + using namespace MILBlob; + using namespace MILBlob::Blob; +diff --git a/mlmodel/src/MILBlob/Fp16.cpp b/mlmodel/src/MILBlob/Fp16.cpp +index ae1e71a1..77a7161f 100644 +--- a/mlmodel/src/MILBlob/Fp16.cpp ++++ b/mlmodel/src/MILBlob/Fp16.cpp +@@ -5,6 +5,8 @@ + + #include "MILBlob/Fp16.hpp" + ++// ORT_EDIT: Exclude clang specific pragmas from other builds ++#if defined(__clang__) + // fp16 lib code has some conversion warnings we don't want to globally ignore + #pragma clang diagnostic push + #pragma clang diagnostic ignored "-Wincompatible-pointer-types" +@@ -12,6 +14,9 @@ + #pragma clang diagnostic ignored "-Wconversion" + #include "fp16/fp16.h" + #pragma clang diagnostic pop ++#else ++#include "fp16/fp16.h" ++#endif + + using namespace MILBlob; + +diff --git a/modelpackage/src/ModelPackage.cpp b/modelpackage/src/ModelPackage.cpp +index 8fee56b9..99e0d8d6 100644 +--- a/modelpackage/src/ModelPackage.cpp ++++ b/modelpackage/src/ModelPackage.cpp +@@ -26,7 +26,14 @@ namespace std { + #else + #error "missing required header " + #endif ++ ++// ORT_EDIT: Use UuidCreate on Windows. 
++#if defined(_WIN32) ++#pragma comment(lib, "rpcrt4.lib") // UuidCreate ++#include ++#else + #include ++#endif + #include + + #if defined(__cplusplus) +@@ -187,7 +194,10 @@ public: + ModelPackageItemInfo createFile(const std::string& name, const std::string& author, const std::string& description); + }; + ++// ORT_EDIT: pragma only available on APPLE platforms ++#if defined(__APPLE__) + #pragma mark ModelPackageImpl ++#endif + + ModelPackageImpl::ModelPackageImpl(const std::filesystem::path& path, bool createIfNecessary, bool readOnly) + : m_packagePath(path), +@@ -372,6 +382,20 @@ std::filesystem::path ModelPackageImpl::getItemPath(const std::string& name, con + } + + std::string ModelPackageImpl::generateIdentifier() const { ++// ORT_EDIT: Use built-in UUID generation on Windows ++#if defined(_WIN32) ++ UUID uuid; ++ UuidCreate(&uuid); ++ ++ RPC_CSTR uuidStr; ++ UuidToStringA(&uuid, &uuidStr); ++ ++ std::string uuidStrCpp(reinterpret_cast(uuidStr)); ++ ++ RpcStringFreeA(&uuidStr); ++ ++ return uuidStrCpp; ++#else + uuid_t uuid; + + // uuid_unparse generates a 36-character null-terminated string (37 bytes). 
+@@ -383,6 +407,7 @@ std::string ModelPackageImpl::generateIdentifier() const { + uuid_unparse(uuid, buf); + + return std::string(buf); ++#endif + } + + ModelPackageItemInfo ModelPackageImpl::createFile(const std::string& name, const std::string& author, const std::string& description) { +@@ -468,7 +493,13 @@ std::shared_ptr ModelPackageImpl::findItem(const std::stri + auto author = itemInfoEntry->getString(kModelPackageItemInfoAuthorKey); + auto description = itemInfoEntry->getString(kModelPackageItemInfoDescriptionKey); + ++// ORT_EDIT: need to use path.string() on Windows ++#if defined(_WIN32) ++ return std::make_shared(std::make_shared(identifier, path.string(), name, author, description)); ++ ++#else + return std::make_shared(std::make_shared(identifier, path, name, author, description)); ++#endif + } + + std::shared_ptr ModelPackageImpl::findItem(const std::string& name, const std::string& author) const +@@ -514,7 +545,9 @@ void ModelPackageImpl::removeItem(const std::string& identifier) + } + + auto path = m_packageDataDirPath / itemInfoEntry->getString(kModelPackageItemInfoPathKey); +- if (0 != std::remove(path.c_str())) { ++ // ORT_EDIT: std::remove doesn't work on Windows. Use std::filesystem::remove instead. 
++ // if (0 != std::remove(path.c_str())) { ++ if (!std::filesystem::remove(path)) { + throw std::runtime_error("Failed to remove file at path: " + path.string()); + } + +@@ -525,13 +558,16 @@ bool ModelPackageImpl::isValid(const std::filesystem::path& path) + { + try { + ModelPackageImpl(path, false, true); +- } catch (std::runtime_error& e) { ++ } catch (std::runtime_error& /*e*/) { // ORT_EDIT: comment out unused variable + return false; + } + return true; + } + ++// ORT_EDIT: pragma only available on APPLE platforms ++#if defined(__APPLE__) + #pragma mark ModelPackage ++#endif + + ModelPackage::ModelPackage(const std::string& packagePath, bool createIfNecessary, bool readOnly) + : m_modelPackageImpl(std::make_shared(packagePath, createIfNecessary, readOnly)) +@@ -544,7 +580,12 @@ ModelPackage::~ModelPackage() + + std::string ModelPackage::path() const + { ++// ORT_EDIT: Windows doesn't automatically convert to std::string as the native format could be char or wchar. ++#if defined(_WIN32) ++ return m_modelPackageImpl->path().string(); ++#else + return m_modelPackageImpl->path(); ++#endif + } + + std::string ModelPackage::setRootModel(const std::string& path, const std::string& name, const std::string& author, const std::string& description) diff --git a/cmake/patches/cpuinfo/9bb12d342fd9479679d505d93a478a6f9cd50a47.patch b/cmake/patches/cpuinfo/9bb12d342fd9479679d505d93a478a6f9cd50a47.patch new file mode 100644 index 0000000000000..afb19a45ce0f4 --- /dev/null +++ b/cmake/patches/cpuinfo/9bb12d342fd9479679d505d93a478a6f9cd50a47.patch @@ -0,0 +1,22 @@ +diff --git a/include/cpuinfo.h b/include/cpuinfo.h +index c46b65e..8b83a64 100644 +--- a/include/cpuinfo.h ++++ b/include/cpuinfo.h +@@ -18,7 +18,7 @@ + #define CPUINFO_ARCH_X86 1 + #endif + +-#if defined(__x86_64__) || defined(__x86_64) || defined(_M_X64) || defined(_M_AMD64) ++#if defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC)) || (defined(_M_AMD64) && !defined(_M_ARM64EC)) + #define 
CPUINFO_ARCH_X86_64 1 + #endif + +@@ -26,7 +26,7 @@ + #define CPUINFO_ARCH_ARM 1 + #endif + +-#if defined(__aarch64__) || defined(_M_ARM64) ++#if defined(__aarch64__) || defined(_M_ARM64) || defined(_M_ARM64EC) + #define CPUINFO_ARCH_ARM64 1 + #endif + diff --git a/cmake/patches/flatbuffers/flatbuffers.patch b/cmake/patches/flatbuffers/flatbuffers.patch index fb2678ef1bdce..fbe8db37ecb0e 100644 --- a/cmake/patches/flatbuffers/flatbuffers.patch +++ b/cmake/patches/flatbuffers/flatbuffers.patch @@ -2,35 +2,11 @@ diff --git a/CMakeLists.txt b/CMakeLists.txt index 3987eac9..5e5462f1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt -@@ -223,7 +223,7 @@ elseif(CMAKE_COMPILER_IS_GNUCXX) - "${CMAKE_CXX_FLAGS} -std=c++0x") - endif(CYGWIN) - set(CMAKE_CXX_FLAGS -- "${CMAKE_CXX_FLAGS} -Wall -pedantic -Werror -Wextra -Werror=shadow") -+ "${CMAKE_CXX_FLAGS} -Wall -pedantic -Werror -Wextra -Werror=shadow -Wno-error=stringop-overflow") - set(FLATBUFFERS_PRIVATE_CXX_FLAGS "-Wold-style-cast") - if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 4.4) - if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0) -diff --git a/src/idl_gen_rust.cpp b/src/idl_gen_rust.cpp -index 55b8439b..dc03e8a8 100644 ---- a/src/idl_gen_rust.cpp -+++ b/src/idl_gen_rust.cpp -@@ -406,7 +406,8 @@ class RustGenerator : public BaseGenerator { - // example: f(A, D::E) -> super::D::E - // does not include leaf object (typically a struct type). - -- size_t i = 0; -+ // fix unused but set variable warning -+ //size_t i = 0; - std::stringstream stream; - - auto s = src->components.begin(); -@@ -417,7 +418,7 @@ class RustGenerator : public BaseGenerator { - if (*s != *d) { break; } - ++s; - ++d; -- ++i; -+ //++i; - } - - for (; s != src->components.end(); ++s) { stream << "super::"; } +@@ -279,5 +279,5 @@ + # Append FLATBUFFERS_CXX_FLAGS to CMAKE_CXX_FLAGS. 
+ if(DEFINED FLATBUFFERS_CXX_FLAGS) + message(STATUS "extend CXX_FLAGS with ${FLATBUFFERS_CXX_FLAGS}") +- set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${FLATBUFFERS_CXX_FLAGS}") ++ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${FLATBUFFERS_CXX_FLAGS} -Wno-error=stringop-overflow") + endif() + message(STATUS "CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") diff --git a/cmake/wcos_rules_override.cmake b/cmake/wcos_rules_override.cmake index f3d8093629a42..ec2303b073d5e 100644 --- a/cmake/wcos_rules_override.cmake +++ b/cmake/wcos_rules_override.cmake @@ -1,2 +1,2 @@ -set(CMAKE_C_STANDARD_LIBRARIES_INIT onecoreuap_apiset.lib) -set(CMAKE_CXX_STANDARD_LIBRARIES_INIT onecoreuap_apiset.lib) +set(CMAKE_C_STANDARD_LIBRARIES_INIT onecoreuap.lib) +set(CMAKE_CXX_STANDARD_LIBRARIES_INIT onecoreuap.lib) diff --git a/cmake/winml.cmake b/cmake/winml.cmake index 268ee3960e75a..d74250b962628 100644 --- a/cmake/winml.cmake +++ b/cmake/winml.cmake @@ -836,6 +836,13 @@ if (winml_is_inbox) target_include_directories(${new_target} PRIVATE ${include_directories}) target_link_libraries(${new_target} PRIVATE ${link_libraries}) target_link_options(${new_target} PRIVATE ${link_options}) + + # Attempt to copy linker flags + get_target_property(link_flags ${target} LINK_FLAGS) + + if (NOT link_flags MATCHES ".*NOTFOUND") + set_property(TARGET ${new_target} PROPERTY LINK_FLAGS "${link_flags}") + endif() endfunction() if (WAI_ARCH STREQUAL x64 OR WAI_ARCH STREQUAL arm64) diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs index 4128524b30483..8a8426a0b3054 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs @@ -362,6 +362,7 @@ static NativeMethods() OrtDisableMemPattern = (DOrtDisableMemPattern)Marshal.GetDelegateForFunctionPointer(api_.DisableMemPattern, typeof(DOrtDisableMemPattern)); OrtEnableCpuMemArena = 
(DOrtEnableCpuMemArena)Marshal.GetDelegateForFunctionPointer(api_.EnableCpuMemArena, typeof(DOrtEnableCpuMemArena)); OrtDisableCpuMemArena = (DOrtDisableCpuMemArena)Marshal.GetDelegateForFunctionPointer(api_.DisableCpuMemArena, typeof(DOrtDisableCpuMemArena)); + OrtDisablePerSessionThreads = (DOrtDisablePerSessionThreads)Marshal.GetDelegateForFunctionPointer(api_.DisablePerSessionThreads, typeof(DOrtDisablePerSessionThreads)); OrtSetSessionLogId = (DOrtSetSessionLogId)Marshal.GetDelegateForFunctionPointer(api_.SetSessionLogId, typeof(DOrtSetSessionLogId)); OrtSetSessionLogVerbosityLevel = (DOrtSetSessionLogVerbosityLevel)Marshal.GetDelegateForFunctionPointer(api_.SetSessionLogVerbosityLevel, typeof(DOrtSetSessionLogVerbosityLevel)); OrtSetSessionLogSeverityLevel = (DOrtSetSessionLogSeverityLevel)Marshal.GetDelegateForFunctionPointer(api_.SetSessionLogSeverityLevel, typeof(DOrtSetSessionLogSeverityLevel)); @@ -992,6 +993,10 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca public delegate IntPtr /*(OrtStatus*)*/ DOrtDisableCpuMemArena(IntPtr /* OrtSessionOptions* */ options); public static DOrtDisableCpuMemArena OrtDisableCpuMemArena; + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /*(OrtStatus*)*/ DOrtDisablePerSessionThreads(IntPtr /* OrtSessionOptions* */ options); + public static DOrtDisablePerSessionThreads OrtDisablePerSessionThreads; + [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtSetSessionLogId(IntPtr /* OrtSessionOptions* */ options, byte[] /* const char* */ logId); public static DOrtSetSessionLogId OrtSetSessionLogId; diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs index 7a68246c9b67a..30d005b3c4236 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs @@ -696,6 
+696,15 @@ public bool EnableCpuMemArena } private bool _enableCpuMemArena = true; + /// + /// Disables the per session threads. Default is true. + /// This makes all sessions in the process use a global TP. + /// + public void DisablePerSessionThreads() + { + NativeApiStatus.VerifySuccess(NativeMethods.OrtDisablePerSessionThreads(handle)); + } + /// /// Log Id to be used for the session. Default is empty string. /// diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs index fd8feda359f90..d6a6b9627f418 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs @@ -55,6 +55,9 @@ public void TestSessionOptions() Assert.Equal(0, opt.InterOpNumThreads); Assert.Equal(GraphOptimizationLevel.ORT_ENABLE_ALL, opt.GraphOptimizationLevel); + // No get, so no verify + opt.DisablePerSessionThreads(); + // try setting options opt.ExecutionMode = ExecutionMode.ORT_PARALLEL; Assert.Equal(ExecutionMode.ORT_PARALLEL, opt.ExecutionMode); @@ -98,7 +101,7 @@ public void TestSessionOptions() Assert.Contains("[ErrorCode:InvalidArgument] Config key is empty", ex.Message); // SessionOptions.RegisterOrtExtensions can be manually tested by referencing the - // Microsoft.ML.OnnxRuntime.Extensions nuget package. After that is done, this should not throw. + // Microsoft.ML.OnnxRuntime.Extensions nuget package. After that is done, this should not throw. 
ex = Assert.Throws(() => { opt.RegisterOrtExtensions(); }); Assert.Contains("Microsoft.ML.OnnxRuntime.Extensions NuGet package must be referenced", ex.Message); diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs index 715aed7e1d64f..7f3d5d6624b07 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs @@ -145,7 +145,7 @@ private void TestCUDAProviderOptions() private void CanRunInferenceOnAModelWithTensorRT() { string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "squeezenet.onnx"); - + int deviceId = 0; string deviceIdStr = System.Environment.GetEnvironmentVariable("ONNXRUNTIME_TEST_GPU_DEVICE_ID"); if (!string.IsNullOrEmpty(deviceIdStr) && int.TryParse(deviceIdStr, out int parsedValue) && parsedValue >= 0) diff --git a/csharp/tools/MauiModelTester/MauiModelTester.csproj b/csharp/tools/MauiModelTester/MauiModelTester.csproj index a374c2933ce8f..39e688ce6c1b8 100644 --- a/csharp/tools/MauiModelTester/MauiModelTester.csproj +++ b/csharp/tools/MauiModelTester/MauiModelTester.csproj @@ -1,8 +1,8 @@  - net6.0-android;net6.0-ios - $(TargetFrameworks);net6.0-windows10.0.19041.0 + net8.0-ios;net8.0-android34.0 + $(TargetFrameworks);net8.0-windows10.0.19041.0 Exe MauiModelTester true @@ -21,7 +21,7 @@ 1 12.0 - 21.0 + 29.0 10.0.17763.0 10.0.17763.0 true diff --git a/csharp/tools/MauiModelTester/Platforms/Android/AndroidManifest.xml b/csharp/tools/MauiModelTester/Platforms/Android/AndroidManifest.xml index cc320dab474a0..2ef2296d7441f 100644 --- a/csharp/tools/MauiModelTester/Platforms/Android/AndroidManifest.xml +++ b/csharp/tools/MauiModelTester/Platforms/Android/AndroidManifest.xml @@ -4,5 +4,5 @@ - + \ No newline at end of file diff --git a/dockerfiles/Dockerfile.migraphx b/dockerfiles/Dockerfile.migraphx index 
bc513a8e8ba6d..c3541a8bd3425 100644 --- a/dockerfiles/Dockerfile.migraphx +++ b/dockerfiles/Dockerfile.migraphx @@ -5,57 +5,22 @@ # Dockerfile to run ONNXRuntime with MIGraphX integration #-------------------------------------------------------------------------- -FROM ubuntu:20.04 +FROM rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1 ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime ARG ONNXRUNTIME_BRANCH=main -ARG ROCM_VERSION=5.4 -# MIGraphX version should be the same as ROCm version -ARG MIGRAPHX_VERSION=rocm-5.4.0 -ENV DEBIAN_FRONTEND noninteractive -ENV MIGRAPHX_DISABLE_FAST_GELU=1 -RUN apt-get clean && apt-get update && apt-get install -y locales -RUN locale-gen en_US.UTF-8 -RUN update-locale LANG=en_US.UTF-8 -ENV LC_ALL C.UTF-8 -ENV LANG C.UTF-8 +ENV PATH /code/cmake-3.27.3-linux-x86_64/bin:${PATH} -# Install rocm -RUN apt-get update && apt-get install -y gnupg2 --no-install-recommends curl && \ - curl -sL http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - && \ - sh -c 'echo deb [arch=amd64] http://repo.radeon.com/rocm/apt/${ROCM_VERSION}/ ubuntu main > /etc/apt/sources.list.d/rocm.list' - -RUN apt-get update &&\ - apt-get install -y sudo git bash build-essential rocm-dev python3-dev python3-pip miopen-hip \ - rocblas half aria2 libnuma-dev pkg-config - -RUN aria2c -q -d /tmp -o cmake-3.27.3-linux-x86_64.tar.gz \ -https://github.com/Kitware/CMake/releases/download/v3.27.3/cmake-3.27.3-linux-x86_64.tar.gz &&\ -tar -zxf /tmp/cmake-3.27.3-linux-x86_64.tar.gz --strip=1 -C /usr - -# Install rbuild -RUN pip3 install https://github.com/RadeonOpenCompute/rbuild/archive/master.tar.gz numpy yapf==0.28.0 - -ENV PATH /opt/miniconda/bin:/code/cmake-3.27.3-linux-x86_64/bin:${PATH} - -# Install MIGraphX from source -RUN mkdir -p /migraphx -RUN cd /migraphx && git clone --depth=1 --branch ${MIGRAPHX_VERSION} https://github.com/ROCmSoftwarePlatform/AMDMIGraphX src -RUN cd /migraphx && rbuild package --cxx /opt/rocm/llvm/bin/clang++ -d /migraphx/deps 
-B /migraphx/build -S /migraphx/src/ -DPYTHON_EXECUTABLE=/usr/bin/python3 -RUN dpkg -i /migraphx/build/*.deb -RUN rm -rf /migraphx - -# Install rocm ep dependencies RUN apt-get update &&\ - apt-get install -y rocrand rccl hipsparse hipfft hipcub hipblas rocthrust + apt-get install -y migraphx WORKDIR /code # Prepare onnxruntime repository & build onnxruntime RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime &&\ /bin/sh onnxruntime/dockerfiles/scripts/install_common_deps.sh &&\ - cd onnxruntime &&\ + cd onnxruntime && pip install --upgrade pip &&\ /bin/sh ./build.sh --allow_running_as_root --cmake_extra_defines ONNXRUNTIME_VERSION=`cat ./VERSION_NUMBER` --config Release --parallel \ --skip_tests --build_wheel --use_rocm --rocm_version=${ROCM_VERSION} --rocm_home /opt/rocm --use_migraphx &&\ pip install /code/onnxruntime/build/Linux/Release/dist/*.whl diff --git a/dockerfiles/Dockerfile.rocm b/dockerfiles/Dockerfile.rocm index 35a676383337b..c242933f677f0 100644 --- a/dockerfiles/Dockerfile.rocm +++ b/dockerfiles/Dockerfile.rocm @@ -5,14 +5,14 @@ # Dockerfile to run ONNXRuntime with ROCm integration #-------------------------------------------------------------------------- -FROM rocm/pytorch:rocm5.4_ubuntu20.04_py3.7_pytorch_1.12.1 +FROM rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1 ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime ARG ONNXRUNTIME_BRANCH=main WORKDIR /code -ENV PATH /opt/miniconda/bin:/code/cmake-3.27.3-linux-x86_64/bin:${PATH} +ENV PATH /code/cmake-3.27.3-linux-x86_64/bin:${PATH} # Prepare onnxruntime repository & build onnxruntime RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime &&\ diff --git a/dockerfiles/README.md b/dockerfiles/README.md index f226ebfe8b193..a2e99d66d4654 100644 --- a/dockerfiles/README.md +++ b/dockerfiles/README.md @@ -277,7 +277,7 @@ Nothing else from ONNX Runtime source tree will be 
copied/installed to the image Note: When running the container you built in Docker, please either use 'nvidia-docker' command instead of 'docker', or use Docker command-line options to make sure NVIDIA runtime will be used and appropiate files mounted from host. Otherwise, CUDA libraries won't be found. You can also [set NVIDIA runtime as default in Docker](https://github.com/dusty-nv/jetson-containers#docker-default-runtime). ## MIGraphX -**Ubuntu 20.04, ROCm5.4, AMDMIGraphX v1.2** +**Ubuntu 20.04, ROCm6.0, MIGraphX** 1. Build the docker image from the Dockerfile in this repository. ``` @@ -291,7 +291,7 @@ Note: When running the container you built in Docker, please either use 'nvidia- ``` ## ROCm -**Ubuntu 20.04, ROCm5.4** +**Ubuntu 20.04, ROCm6.0** 1. Build the docker image from the Dockerfile in this repository. ``` diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index e7b537d6894c8..5f0100fad95a2 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -461,7 +461,7 @@ This version of the operator has been available since version 1 of the 'com.micr
repetition_penalty (optional) : T
The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)
vocab_mask (optional) : M
-
Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)
+
Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)
prefix_vocab_mask (optional) : M
Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)
attention_mask (optional) : I
@@ -2252,7 +2252,7 @@ This version of the operator has been available since version 1 of the 'com.micr
repetition_penalty (optional) : T
The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)
vocab_mask (optional) : I
-
Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)
+
Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)
prefix_vocab_mask (optional) : I
Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)
attention_mask (optional) : I
@@ -2795,7 +2795,7 @@ This version of the operator has been available since version 1 of the 'com.micr
Constrain input A data type to 8-bit integer tensor.
T2 : tensor(int8), tensor(uint8)
Constrain input B data type to 8-bit integer tensor.
-
T3 : tensor(float)
+
T3 : tensor(float), tensor(float16)
Constrain input a_scale, b_scale and output Y data type as float tensor.
@@ -2808,22 +2808,23 @@ This version of the operator has been available since version 1 of the 'com.micr And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,.. 3. Input B's scale and zero point are specified by input scales and zero_points. - Input B is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which: - - n_blocks_per_col = (K + block_size - 1) / block_size - - blob_size = block_size / 8 * bits + Input is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which: + - n_blocks_per_col = (K + block_size - 1) / block_size + - blob_size = CeilDiv(block_size * bits, bitsof(uint8_t)<8>) + For all bits from 2-8, a row of data is stored squeezely and represented by uint8_t. + - for 2,4,8 bits, 4x2bit,2x4bit,1x8bit are stored in one uint8_t. + 4bit example: + |.|.|.|.| .|.|.|.| =uint8_t (2x4bit) + - for 3,5,6,7 bits, 32x3bit,32x5bit,16x6bit,32x7bit are stored in 12xuint8_t,20xuint8_t,12xuint8_t,28xuint8_t separately. no bits are wasted. + 3bit example: + |.|.|. |.|.|. |.|.|. = 9bit, which across 2 uint8_t, the highest bit for the second uint8_t is used. + The last uint_8 may have some bits unused. - For a block blob. It is stored in format: - struct Blob { - uint8 one_bits[(bits & 0x1) * 1 * block_size / 8]; // highest 1 bit for 3, 5, 7 bits quantization - uint8 two_bits[(bits & 0x2) * 2 * block_size / 8]; // high 2 bits for 2, 6, 7 bits quantization - uint8 four_bits[(bits & 0x4) * 4 * block_size / 8]; // low 4 bits for 4, 5, 6 bits quantization - } Input scales is stored in same type as original type of B(float32, float16) with shape like: [N * n_blocks_per_col] - Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored as one unit8_t. If bits > 4, one zero point is stored with one unit8_t. 
Thus, its shape is: - - [(N * n_blocks_per_col + 1) / 2] if bits <=4 - - [N * n_blocks_per_col] if bits > 4 - + Input zero_points is stored as uint8_t or same as type(A). It has the same packing method as input B. + - [CeilDiv((N * n_blocks_per_col + 1) *bits, 8)] + If zero_points has same type as A, it's not packed and has the same shape as Scales. #### Version @@ -2844,17 +2845,19 @@ This version of the operator has been available since version 1 of the 'com.micr
number of groupsize used for weight quantization,(default 128). It needs to be a power of 2 and not smaller than 16.
-#### Inputs (3 - 4) +#### Inputs (3 - 5)
A : T1
The input tensor, not quantized
B : T2
-
1-dimensional data blob
+
1 or 2 dimensional data blob
scales : T1
quantization scale
-
zero_points (optional) : T2
+
zero_points (optional) : T3
quantization zero points
+
g_idx (optional) : T4
+
group_idx
#### Outputs @@ -2869,8 +2872,12 @@ This version of the operator has been available since version 1 of the 'com.micr
T1 : tensor(float), tensor(float16)
Constrain input and output types to float/half_float tensors.
-
T2 : tensor(uint8)
-
Constrain quantized weight types to uint8.
+
T2 : tensor(uint8), tensor(int32)
+
Constrain quantized weight types to uint8/int32.
+
T3 : tensor(uint8), tensor(int32), tensor(float16), tensor(float)
+
Constrain quantized zero point types to uint8/int32/float16/float.
+
T4 : tensor(int32)
+
the index tensor.
@@ -5154,7 +5161,7 @@ This version of the operator has been available since version 1 of the 'com.micr
repetition_penalty (optional) : T
The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)
vocab_mask (optional) : I
-
Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)
+
Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)
prefix_vocab_mask (optional) : I
Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)
attention_mask (optional) : I
@@ -5743,12 +5750,14 @@ This version of the operator has been available since version 1 of the 'com.micr #### Attributes
+
beginning_timestamp_token_id : int
+
The id of the first timestamp
decoder : graph (required)
Decoder subgraph to execute in a loop.
decoder_output_cross_qk : int
If nozero, decoder subgraph contains output Q*K from cross attentions. Default 0.
decoder_start_token_id : int
-
The id of the token that indicates decoding starts.
+
The id of the token that indicates decoding starts (i.e. the start of transcription token id)
early_stopping : int
early stop or not
encoder : graph
@@ -5761,10 +5770,18 @@ This version of the operator has been available since version 1 of the 'com.micr
Must be 2 for whisper
no_repeat_ngram_size : int
no repeat ngrams size
-
no_speech_token : int
+
no_speech_token_id : int
The token in whisper model that marks all sequence empty. With this model, whisper could output no_speech_prob after. Default -1.
+
no_timestamps_token_id : int
+
The id of the token that indicates no timestamps
pad_token_id : int (required)
The id of the padding token
+
start_of_lm_token_id : int
+
The id of the token that indicates LM starts
+
transcribe_token_id : int
+
The id of the transcribe task
+
translate_token_id : int
+
The id of the translate task
vocab_size : int
Size of the vocabulary. If not provided, it will be inferred from the decoder subgraph's output shape
@@ -5783,11 +5800,11 @@ This version of the operator has been available since version 1 of the 'com.micr
num_return_sequences : I
The number of returned sequences in the batch. Shape is (1)
length_penalty (optional) : T
-
Exponential penalty to the length. Default value 1.0 means no penalty.Value > 1.0 encourages longer sequences, while values < 1.0 produces shorter sequences.Shape is (1,)
+
Exponential penalty to the length. Default value 1.0 means no penalty. Value > 1.0 encourages longer sequences, while values < 1.0 produces shorter sequences. Shape is (1,)
repetition_penalty (optional) : T
The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)
vocab_mask (optional) : M
-
Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)
+
Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)
prefix_vocab_mask (optional) : M
Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)
attention_mask (optional) : I
@@ -5797,7 +5814,7 @@ This version of the operator has been available since version 1 of the 'com.micr
logits_processor (optional) : I
Specific logits processor for different types of beamsearch models. Default value 0 means no specific logit processor. Accepts value >= 0. Shape is (1)
cross_qk_layer_head (optional) : I
-
Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. Default collect allits shape is (number of (layer, head) to keep, 2), i.e., [[layer_id1, head_id1], [layer_id2, head_id2]......]
+
Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. By default, all are collected. Its shape is (number of (layer, head) to keep, 2), i.e., [[layer_id1, head_id1], [layer_id2, head_id2]......]
extra_decoding_ids (optional) : I
Part of the decoder_input_ids that we need cross qk for it. it is of shape (batch_size, extra_decoding_ids_len).In such case, we should remove this from the tail of the decoder_input_ids, and put it here. ids < 0 in it (for multiple batch) are treated as stop of the extra_decoding_ids for corresponding batch.
temperature (optional) : T
@@ -5812,11 +5829,11 @@ This version of the operator has been available since version 1 of the 'com.micr
sequences_scores (optional) : T
Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)
scores (optional) : T
-
Processed beam scores for each vocabulary token at each generation step.Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam.Shape is (max_length - sequence_length, batch_size, num_beams, vocab_size)
+
Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam. Shape is (max_length - sequence_length, batch_size, num_beams, vocab_size)
cross_qk (optional) : V
-
Output the accumulated stacked Q*K in cross attentions. Let H = number of Head of cross attention, F = the frames or kv-seq-len of the cross attention input, T = real decoded token length, L = number of layers,B = batch size, R = num_return_sequences. It then should return tensor of shape [B, R, L*H, T, F].If cross_qk_layer_head is given, shape is [B, R, cross_qk_layer_head.shape[0], T, F]
+
Output the accumulated stacked Q*K in cross attentions. Let H = number of Head of cross attention, F = the frames or kv-seq-len of the cross attention input, T = real decoded token length, L = number of layers, B = batch size, R = num_return_sequences. It then should return tensor of shape [B, R, L*H, T, F]. If cross_qk_layer_head is given, shape is [B, R, cross_qk_layer_head.shape[0], T, F]
non_speech_probs (optional) : T
-
For whisper model, output the probabilities from logits after encoder and context decoding for the no_speech_token.Currently we treat the last token's logits is what we need, in future extra graph logic may be add to the encoder/context-decoder subgraph.The prob is save before logits may be updated by extra-decoding-ids. The shape of non_speech_probs is [B]
+
For whisper model, output the probabilities from logits after encoder and context decoding for the no_speech_token_id. The shape of non_speech_probs is [B]
#### Type Constraints diff --git a/docs/Memory_Optimizer.md b/docs/Memory_Optimizer.md index 97f7e7ff2c14b..eaa48c9da0609 100644 --- a/docs/Memory_Optimizer.md +++ b/docs/Memory_Optimizer.md @@ -51,6 +51,7 @@ There are two modes to enable the memory optimizations: - Plan 8 : OFF : Cast+:2:-1 1 2,048 2.0*inputs_input_ids_dim0*inputs_input_ids_dim1 ``` 3. As shown above, `Config` is a string representative for a re-computable subgraph. All are enabled for recompute in this case. +4. By `export ORTMODULE_MEMORY_OPT_LEVEL=2`, all plans including compromised recomputable subgraphs will also be enabled. ### Mode 2 - Advanced Usage (User Selected Subgraph Recompute) diff --git a/docs/ORTModule_Convergence_Notes.md b/docs/ORTModule_Convergence_Notes.md index 791b6c32c9b48..2374e7b7c538a 100644 --- a/docs/ORTModule_Convergence_Notes.md +++ b/docs/ORTModule_Convergence_Notes.md @@ -89,7 +89,7 @@ The limitation of `GlobalSubscriberManager` is, only 'nn.Module's forward output dump the intermediate tensors in a `nn.Module`'s forward function, refer to the following example: ```diff -+ from onnxruntime.training.utils import inspect_activation ++ from onnxruntime.training.utils.hooks import inspect_activation class BloomForCausalLM(BloomPreTrainedModel): def __init__(self, config: BloomConfig): ... diff --git a/docs/ORTModule_Training_Guidelines.md b/docs/ORTModule_Training_Guidelines.md index 91057d3dfb120..54137937ad56d 100644 --- a/docs/ORTModule_Training_Guidelines.md +++ b/docs/ORTModule_Training_Guidelines.md @@ -246,7 +246,7 @@ to standard outputs. #### ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER - **Feature Area**: *ORTMODULE/Optimizations* -- **Description**: By default, this is disabled. This env var can be used for enabling or disabling the embedding input +- **Description**: By default, this is enabled. This env var can be used for enabling or disabling the embedding input data sparsity based performance optimizations.
```bash @@ -287,13 +287,16 @@ A classical usage of disabling the deep copy: when the deep copy before module e #### ORTMODULE_MEMORY_OPT_LEVEL - **Feature Area**: *ORTMODULE/Optimizations* -- **Description**: By default, the level is 0. This env var can be used for enabling recomputation for reducing memory peak requirement. Setting the level to be 0 means all detected subgraphs with each transformer-based model layer generating stashed activations will be recomputed. This is conceptually equivalent to PyTorch's gradient checkpoint. When level is not 0, check Check [Memory Optimizer for ONNX Runtime Training](Memory_Optimizer.md) for more details. +- **Description**: By default, the level is 0. This env var can be used for enabling recomputation for reducing memory peak requirement. + - Setting the level to be 1 means all detected recomputable subgraphs (NOT including compromised recomputable graphs) with each transformer-based model layer generating stashed activations will be recomputed. This is conceptually equivalent to PyTorch's gradient checkpoint. + - Setting the level to be 2 means all detected recomputable subgraphs (including compromised recomputable graphs) with each transformer-based model layer generating stashed activations will be recomputed. This is conceptually equivalent to PyTorch's gradient checkpoint. + - When the level is 0, check [Memory Optimizer for ONNX Runtime Training](Memory_Optimizer.md) for more details. ```bash export ORTMODULE_MEMORY_OPT_LEVEL=0 ``` -### ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT +#### ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT - **Feature Area**: *ORTMODULE/Optimizations* - **Description**: By default, the memory-efficient gradient management is turned off. The gradient after it is computed in ONNX Runtime, will trigger the corresponding parameter's backward function through `PythonOpGrad` operator.
This would help release the gradient buffer managed in ONNX Runtime, which originally is released once all backward computation finishes. diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 9d9b266355335..eddc3b7873d80 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -127,6 +127,7 @@ Do not modify directly.* |GatherND|*in* data:**T**
*in* indices:**tensor(int64)**
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**indices** = tensor(int64)| |||12|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**indices** = tensor(int64)| |||11|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**indices** = tensor(int64)| +|Gelu|*in* X:**T**
*out* Y:**T**|20+|**T** = tensor(float)| |Gemm|*in* A:**T**
*in* B:**T**
*in* C:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float)| |||[11, 12]|**T** = tensor(double), tensor(float)| |||[9, 10]|**T** = tensor(double), tensor(float)| @@ -159,9 +160,9 @@ Do not modify directly.* |||[1, 10]|**B** = tensor(bool)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |ImageScaler|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(float)| |InstanceNormalization|*in* input:**T**
*in* scale:**T**
*in* B:**T**
*out* output:**T**|6+|**T** = tensor(float)| -|IsInf|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(double), tensor(float), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**T2** = tensor(bool)| +|IsInf|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**T2** = tensor(bool)| |||[10, 19]|**T1** = tensor(double), tensor(float)
**T2** = tensor(bool)| -|IsNaN|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**T2** = tensor(bool)| +|IsNaN|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**T2** = tensor(bool)| |||[13, 19]|**T1** = tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)| |||[9, 12]|**T1** = tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)| |LRN|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(float)| @@ -469,7 +470,7 @@ Do not modify directly.* |MatMulFpQ4|*in* A:**T1**
*in* B:**T2**
*in* B_shape:**T3**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)
**T3** = tensor(int64)| |MatMulInteger16|*in* A:**T1**
*in* B:**T2**
*out* Y:**T3**|1+|**T1** = tensor(int16)
**T2** = tensor(int16)
**T3** = tensor(int32)| |MatMulIntegerToFloat|*in* A:**T1**
*in* B:**T2**
*in* a_scale:**T3**
*in* b_scale:**T3**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*in* bias:**T3**
*out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float)| -|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T2**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)| +|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)
**T3** = tensor(float), tensor(uint8)
**T4** = tensor(int32)| |MaxpoolWithMask|*in* X:**T**
*in* M:**tensor(int32)**
*out* Y:**T**|1+|**T** = tensor(float)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float)| |MurmurHash3|*in* X:**T1**
*out* Y:**T2**|1+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(string), tensor(uint32), tensor(uint64)
**T2** = tensor(int32), tensor(uint32)| @@ -606,6 +607,7 @@ Do not modify directly.* |GatherND|*in* data:**T**
*in* indices:**tensor(int64)**
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int64)
**indices** = tensor(int64)| |||12|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int64)
**indices** = tensor(int64)| |||11|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int64)
**indices** = tensor(int64)| +|Gelu|*in* X:**T**
*out* Y:**T**|20+|**T** = tensor(double), tensor(float), tensor(float16)| |Gemm|*in* A:**T**
*in* B:**T**
*in* C:**T**
*out* Y:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)| |||[9, 10]|**T** = tensor(double), tensor(float), tensor(float16)| @@ -617,6 +619,7 @@ Do not modify directly.* |||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)| |GreaterOrEqual|*in* A:**T**
*in* B:**T**
*out* C:**T1**|16+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T1** = tensor(bool)| |||[12, 15]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T1** = tensor(bool)| +|GridSample|*in* X:**T1**
*in* grid:**T2**
*out* Y:**T1**|16+|**T1** = tensor(float)
**T2** = tensor(float)| |HardSigmoid|*in* X:**T**
*out* Y:**T**|6+|**T** = tensor(double), tensor(float), tensor(float16)| |Identity|*in* input:**T**
*out* output:**T**

or

*in* input:**V**
*out* output:**V**|19+|**V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(float8e4m3fn)), seq(tensor(float8e4m3fnuz)), seq(tensor(float8e5m2)), seq(tensor(float8e5m2fnuz)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[14, 18]|**V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| @@ -628,6 +631,11 @@ Do not modify directly.* |||[1, 10]|**B** = tensor(bool)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |ImageScaler|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |InstanceNormalization|*in* input:**T**
*in* scale:**T**
*in* B:**T**
*out* output:**T**|6+|**T** = tensor(double), tensor(float), tensor(float16)| +|IsInf|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**T2** = tensor(bool)| +|||[10, 19]|**T1** = tensor(double), tensor(float)
**T2** = tensor(bool)| +|IsNaN|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**T2** = tensor(bool)| +|||[13, 19]|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)| +|||[9, 12]|**T1** = tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)| |LRN|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| |||[1, 12]|**T** = tensor(double), tensor(float), tensor(float16)| |LSTM|*in* X:**T**
*in* W:**T**
*in* R:**T**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*in* initial_c:**T**
*in* P:**T**
*out* Y:**T**
*out* Y_h:**T**
*out* Y_c:**T**|14+|**T** = tensor(double), tensor(float), tensor(float16)
**T1** = tensor(int32)| @@ -731,7 +739,8 @@ Do not modify directly.* |||13|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**shape** = tensor(int64)| |||[5, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**shape** = tensor(int64)| |||[1, 4]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|Resize|*in* X:**T**
*in* scales:**tensor(float)**
*out* Y:**T**

or

*in* X:**T1**
*in* roi:**T2**
*in* scales:**tensor(float)**
*in* sizes:**tensor(int64)**
*out* Y:**T1**|13+|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)| +|Resize|*in* X:**T**
*in* scales:**tensor(float)**
*out* Y:**T**

or

*in* X:**T1**
*in* roi:**T2**
*in* scales:**tensor(float)**
*in* sizes:**tensor(int64)**
*out* Y:**T1**|18+|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)| +|||[13, 17]|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)| |||[11, 12]|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)| |||10|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)| |ReverseSequence|*in* input:**T**
*in* sequence_lens:**tensor(int64)**
*out* Y:**T**|10+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| @@ -744,7 +753,9 @@ Do not modify directly.* |||[9, 10]|**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||8|**I** = tensor(int64)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Scatter|*in* data:**T**
*in* indices:**Tind**
*in* updates:**T**
*out* output:**T**|[9, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| -|ScatterElements|*in* data:**T**
*in* indices:**Tind**
*in* updates:**T**
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| +|ScatterElements|*in* data:**T**
*in* indices:**Tind**
*in* updates:**T**
*out* output:**T**|18+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| +|||[16, 17]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| +|||[13, 15]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| |||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| |ScatterND|*in* data:**T**
*in* indices:**tensor(int64)**
*in* updates:**T**
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| @@ -763,7 +774,7 @@ Do not modify directly.* |Sigmoid|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)| |Sign|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|SimplifiedLayerNormalization|*in* X:**T**
*in* scale:**V**
*out* Y:**V**
*out* inv_std_var:**U**|1+|**T** = tensor(double), tensor(float), tensor(float16)
**U** = tensor(double), tensor(float)
**V** = tensor(double), tensor(float), tensor(float16)| +|SimplifiedLayerNormalization|*in* X:**T**
*in* scale:**V**
*out* Y:**V**
*out* inv_std_var:**U**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**U** = tensor(double), tensor(float)
**V** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |Sin|*in* input:**T**
*out* output:**T**|7+|**T** = tensor(double), tensor(float), tensor(float16)| |Size|*in* data:**T**
*out* size:**T1**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |||[1, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| @@ -782,7 +793,7 @@ Do not modify directly.* |||[13, 17]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[2, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|Sqrt|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| +|Sqrt|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)| |Squeeze|*in* data:**T**
*in* axes:**tensor(int64)**
*out* squeezed:**T**

or

*in* data:**T**
*out* squeezed:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| @@ -849,7 +860,7 @@ Do not modify directly.* |Irfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)
**T2** = tensor(uint8)| -|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T2**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)| +|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)| |MoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T**
*in* fc2_experts_weights:**T**
*in* fc1_experts_bias:**T**
*in* fc2_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)| |NGramRepeatBlock|*in* input_ids:**Tid**
*in* scores:**T**
*out* scores_out:**T**|1+|**T** = tensor(float)
**Tid** = tensor(int64)| @@ -1257,13 +1268,16 @@ Do not modify directly.* |BiasSplitGelu|*in* X:**T**
*in* bias:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |ConvTransposeWithDynamicPads|*in* X:**T**
*in* W:**T**
*in* Pads:**tensor(int64)**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |DequantizeLinear|*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|1+|**T1** = tensor(int32), tensor(int8), tensor(uint8)
**T2** = tensor(float), tensor(float16)| +|DynamicQuantizeMatMul|*in* A:**T1**
*in* B:**T2**
*in* b_scale:**T1**
*in* b_zero_point:**T2**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(int8), tensor(uint8)| |EmbedLayerNormalization|*in* input_ids:**T1**
*in* segment_ids:**T1**
*in* word_embedding:**T**
*in* position_embedding:**T**
*in* segment_embedding:**T**
*in* gamma:**T**
*in* beta:**T**
*in* mask:**T1**
*in* position_ids:**T1**
*out* output:**T**
*out* mask_index:**T1**
*out* embedding_sum:**T**|1+|**T** = tensor(float), tensor(float16)| |FusedMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |FusedMatMulActivation|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |GroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*out* Y:**T**|1+|**M** = tensor(float), tensor(float16)
**T** = tensor(float), tensor(float16)| +|MatMulIntegerToFloat|*in* A:**T1**
*in* B:**T2**
*in* a_scale:**T3**
*in* b_scale:**T3**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*in* bias:**T3**
*out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float), tensor(float16)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| |NhwcConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| +|QAttention|*in* input:**T1**
*in* weight:**T2**
*in* bias:**T3**
*in* input_scale:**T3**
*in* weight_scale:**T3**
*in* mask_index:**T4**
*in* input_zero_point:**T1**
*in* weight_zero_point:**T2**
*in* past:**T3**
*out* output:**T3**
*out* present:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float), tensor(float16)
**T4** = tensor(int32)| |QLinearAdd|*in* A:**T**
*in* A_scale:**tensor(float)**
*in* A_zero_point:**T**
*in* B:**T**
*in* B_scale:**tensor(float)**
*in* B_zero_point:**T**
*in* C_scale:**tensor(float)**
*in* C_zero_point:**T**
*out* C:**T**|1+|**T** = tensor(int8), tensor(uint8)| |QLinearAveragePool|*in* X:**T**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T**
*out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)| |QLinearConcat|*in* Y_scale:**TF**
*in* Y_zero_point:**T8**
*in* inputs:**TV**
*out* Y:**T8**|1+|**T8** = tensor(int8), tensor(uint8)
**TF** = tensor(float)
**TV** = tensor(float), tensor(int8), tensor(uint8)| diff --git a/docs/python/conf.py b/docs/python/conf.py index 7ab2d42aa15e1..438c21570eaac 100644 --- a/docs/python/conf.py +++ b/docs/python/conf.py @@ -2,12 +2,10 @@ # Licensed under the MIT License. # pylint: disable=C0103 -# -*- coding: utf-8 -*- -# -# Configuration file for the Sphinx documentation builder. +"""Configuration file for the Sphinx documentation builder.""" import os -import shutil # noqa: F401 +import shutil import sys sys.path.append(os.path.join(os.path.dirname(__file__), "..", "_common")) @@ -127,7 +125,5 @@ def setup(app): urllib.request.urlretrieve(url, dest) loc = os.path.split(dest)[-1] if not os.path.exists(loc): - import shutil # noqa: F811 - shutil.copy(dest, loc) return app diff --git a/docs/python/examples/plot_train_convert_predict.py b/docs/python/examples/plot_train_convert_predict.py index dcbc84b20767a..44b6bb74c29df 100644 --- a/docs/python/examples/plot_train_convert_predict.py +++ b/docs/python/examples/plot_train_convert_predict.py @@ -134,7 +134,7 @@ def loop(X_test, fct, n=None): nrow = X_test.shape[0] if n is None: n = nrow - for i in range(0, n): + for i in range(n): im = i % nrow fct(X_test[im : im + 1]) diff --git a/docs/python/on_device_training/training_api.rst b/docs/python/on_device_training/training_api.rst index 64f81f3f18142..f4856b085b7fc 100644 --- a/docs/python/on_device_training/training_api.rst +++ b/docs/python/on_device_training/training_api.rst @@ -42,12 +42,32 @@ Sample usage: CheckpointState.save_checkpoint(state, path_to_the_checkpoint_artifact) +.. autoclass:: onnxruntime.training.api.checkpoint_state.Parameter + :members: + :show-inheritance: + :member-order: bysource + :inherited-members: + :special-members: __repr__ + +.. 
autoclass:: onnxruntime.training.api.checkpoint_state.Parameters + :members: + :show-inheritance: + :member-order: bysource + :inherited-members: + :special-members: __getitem__, __setitem__, __contains__, __iter__, __repr__, __len__ + +.. autoclass:: onnxruntime.training.api.checkpoint_state.Properties + :members: + :show-inheritance: + :member-order: bysource + :inherited-members: + :special-members: __getitem__, __setitem__, __contains__, __iter__, __repr__, __len__ + .. autoclass:: onnxruntime.training.api.CheckpointState :members: :show-inheritance: :member-order: bysource :inherited-members: - :special-members: __getitem__, __setitem__, __contains__ .. autoclass:: onnxruntime.training.api.Module :members: diff --git a/include/onnxruntime/core/framework/data_types_internal.h b/include/onnxruntime/core/framework/data_types_internal.h index fbeee8a2aedc5..3a3b5cb6888f2 100644 --- a/include/onnxruntime/core/framework/data_types_internal.h +++ b/include/onnxruntime/core/framework/data_types_internal.h @@ -305,7 +305,7 @@ class CallableDispatchableHelper { return 0; } - void CheckCalledOnce() { + void CheckCalledOnce() const { ORT_ENFORCE(called_ == 1, "Unsupported data type: ", dt_type_); } }; diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h index 31c988f500779..40ca96a19aef1 100644 --- a/include/onnxruntime/core/framework/execution_provider.h +++ b/include/onnxruntime/core/framework/execution_provider.h @@ -33,6 +33,8 @@ class Node; #include "core/framework/stream_handles.h" #include "core/framework/tuning_context.h" +struct OrtRunOptions; + namespace onnxruntime { /** @@ -51,6 +53,8 @@ struct NodeComputeInfo { DestroyFunctionStateFunc release_state_func; }; +using RunOptions = OrtRunOptions; + enum class DataLayout { NCHW, NHWC, @@ -184,7 +188,7 @@ class IExecutionProvider { Run may not be finished on device This function should be regarded as the point after which a new Run would 
start to submit commands from CPU */ - virtual common::Status OnRunStart() { return Status::OK(); } + virtual common::Status OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { return Status::OK(); } /** Called when InferenceSession::Run ended @@ -192,25 +196,27 @@ class IExecutionProvider { may not be finished on device This function should be regarded as the point that all commands of current Run has been submmited by CPU */ - virtual common::Status OnRunEnd(bool /*sync_stream*/) { return Status::OK(); } + virtual common::Status OnRunEnd(bool /*sync_stream*/, const onnxruntime::RunOptions& /*run_options*/) { + return Status::OK(); + } /** Indicate whether the graph capturing mode (e.g., cuda graph) is enabled for - the provider. Currently only CUDA execution provider supports it. + the provider. */ virtual bool IsGraphCaptureEnabled() const { return false; } /** - Indicate whether the graph has been captured and instantiated. Currently - only CUDA execution provider supports it. + Indicate whether the graph has been captured and instantiated. */ - virtual bool IsGraphCaptured() const { return false; } + virtual bool IsGraphCaptured(int /*graph_annotation_id*/) const { return false; } /** - Run the instantiated graph. Currently only CUDA execution provider supports - it. + Run the instantiated graph. 
*/ - virtual common::Status ReplayGraph() { return Status::OK(); } + virtual common::Status ReplayGraph(int /*graph_annotation_id*/) { + return Status::OK(); + } /** Called when session creation is complete diff --git a/include/onnxruntime/core/framework/stream_handles.h b/include/onnxruntime/core/framework/stream_handles.h index c235ee904762e..26d78133b52fc 100644 --- a/include/onnxruntime/core/framework/stream_handles.h +++ b/include/onnxruntime/core/framework/stream_handles.h @@ -100,6 +100,8 @@ class Stream { return nullptr; } + virtual WaitNotificationFn GetWaitNotificationFn() const { return nullptr; } + private: StreamHandle handle_; const OrtDevice& device_; diff --git a/include/onnxruntime/core/graph/constants.h b/include/onnxruntime/core/graph/constants.h index 9b26ba914c7dd..8e04050d089a0 100644 --- a/include/onnxruntime/core/graph/constants.h +++ b/include/onnxruntime/core/graph/constants.h @@ -31,6 +31,7 @@ constexpr size_t kMaxExecutionProviderNameLen = 30; constexpr const char* kCpuExecutionProvider = "CPUExecutionProvider"; constexpr const char* kCudaExecutionProvider = "CUDAExecutionProvider"; +constexpr const char* kCudaNHWCExecutionProvider = "CUDANHWCExecutionProvider"; constexpr const char* kDnnlExecutionProvider = "DnnlExecutionProvider"; constexpr const char* kOpenVINOExecutionProvider = "OpenVINOExecutionProvider"; constexpr const char* kVitisAIExecutionProvider = "VitisAIExecutionProvider"; diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 22827d43b200f..b16d52dbdab68 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -21,7 +21,7 @@ #pragma warning(pop) #endif -#include "flatbuffers/flatbuffers.h" +#include "core/common/flatbuffers.h" #include "core/common/gsl.h" @@ -753,7 +753,6 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi cannot be overridden at runtime. 
If the initializer is not found or is not constant, a nullptr is returned. @param check_outer_scope If true and the graph is a subgraph, check ancestor graph/s for 'name' if not found in 'graph'. - @remarks check_outer_scope of true is not supported in a minimal build */ const ONNX_NAMESPACE::TensorProto* GetConstantInitializer(const std::string& name, bool check_outer_scope) const; diff --git a/include/onnxruntime/core/graph/graph_viewer.h b/include/onnxruntime/core/graph/graph_viewer.h index 3cdbb07099cab..1023d50310181 100644 --- a/include/onnxruntime/core/graph/graph_viewer.h +++ b/include/onnxruntime/core/graph/graph_viewer.h @@ -165,7 +165,8 @@ class GraphViewer { if a const initializer is part of the underlying Graph but not part of this GraphViewer, it will still be returned instead of nullptr */ - const ONNX_NAMESPACE::TensorProto* GetConstantInitializer(const std::string& name, bool check_outer_scope) const; + const ONNX_NAMESPACE::TensorProto* GetConstantInitializer(const std::string& name, + bool check_outer_scope = true) const; /** Get the Node containing this Graph if IsSubgraph is true. Returns nullptr otherwise. */ const Node* ParentNode() const noexcept { return graph_->ParentNode(); } diff --git a/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h b/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h index 03715eb5b78b2..55abb90b981f5 100644 --- a/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h +++ b/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h @@ -28,9 +28,12 @@ enum COREMLFlags { // dynamic shapes. However, the performance may be negatively impacted if inputs have dynamic shapes. COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES = 0x008, + // Create an MLProgram. By default it will create a NeuralNetwork model. Requires Core ML 5 or later. 
+ COREML_FLAG_CREATE_MLPROGRAM = 0x010, + // Keep COREML_FLAG_LAST at the end of the enum definition // And assign the last COREMLFlag to it - COREML_FLAG_LAST = COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES, + COREML_FLAG_LAST = COREML_FLAG_CREATE_MLPROGRAM, }; #ifdef __cplusplus diff --git a/include/onnxruntime/core/providers/cuda/cuda_context.h b/include/onnxruntime/core/providers/cuda/cuda_context.h index 1370f5c4c5e10..7104e70c3a8a9 100644 --- a/include/onnxruntime/core/providers/cuda/cuda_context.h +++ b/include/onnxruntime/core/providers/cuda/cuda_context.h @@ -37,6 +37,7 @@ struct CudaContext : public CustomOpContext { bool cudnn_conv1d_pad_to_nc1d = false; bool enable_skip_layer_norm_strict_mode = false; bool prefer_nhwc = false; + bool use_tf32 = true; void Init(const OrtKernelContext& kernel_ctx) { cuda_stream = FetchResource(kernel_ctx, CudaResource::cuda_stream_t); @@ -52,11 +53,12 @@ struct CudaContext : public CustomOpContext { cudnn_conv1d_pad_to_nc1d = FetchResource(kernel_ctx, CudaResource::cudnn_conv1d_pad_to_nc1d_t); enable_skip_layer_norm_strict_mode = FetchResource(kernel_ctx, CudaResource::enable_skip_layer_norm_strict_mode_t); prefer_nhwc = FetchResource(kernel_ctx, CudaResource::prefer_nhwc_t); + use_tf32 = FetchResource(kernel_ctx, CudaResource::use_tf32_t); } template T FetchResource(const OrtKernelContext& kernel_ctx, CudaResource resource_type) { - if (sizeof(T) > sizeof(void*)) { + if constexpr (sizeof(T) > sizeof(void*)) { ORT_CXX_API_THROW("void* is not large enough to hold resource type: " + std::to_string(resource_type), OrtErrorCode::ORT_INVALID_ARGUMENT); } const auto& ort_api = Ort::GetApi(); diff --git a/include/onnxruntime/core/providers/cuda/cuda_provider_options.h b/include/onnxruntime/core/providers/cuda/cuda_provider_options.h index 82bb8ba83be4a..6d53760ab60b5 100644 --- a/include/onnxruntime/core/providers/cuda/cuda_provider_options.h +++ b/include/onnxruntime/core/providers/cuda/cuda_provider_options.h @@ -37,4 +37,5 @@ 
struct OrtCUDAProviderOptionsV2 { // The strict mode has better accuracy but lower performance. int prefer_nhwc = 0; // make the CUDA EP NHWC preferred int use_ep_level_unified_stream = 0; // flag specifying if ep level stream is used or not + int use_tf32 = 1; // use TF32 }; diff --git a/include/onnxruntime/core/providers/cuda/cuda_resource.h b/include/onnxruntime/core/providers/cuda/cuda_resource.h index c0e6328f27122..00e7dec5727d1 100644 --- a/include/onnxruntime/core/providers/cuda/cuda_resource.h +++ b/include/onnxruntime/core/providers/cuda/cuda_resource.h @@ -18,4 +18,5 @@ enum CudaResource : int { cudnn_conv1d_pad_to_nc1d_t, enable_skip_layer_norm_strict_mode_t, prefer_nhwc_t, -}; \ No newline at end of file + use_tf32_t, +}; diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 2ce9d361e8e56..41b034e9c1dcc 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -1837,14 +1837,28 @@ struct OrtApi { /** \brief Used for custom operators, get an input of a kernel * - * \see ::OrtCustomOp + * The function attempts to fetch the input of the kernel. If the input is optional + * and not present, the function returns success and out is set to nullptr. + * + * \param[in] context ::OrtKernelContext instance + * \param[in] input index. See KernelContext_GetInputCount for boundaries check. + * \param[in, out] returns a ptr to OrtValue if the input is present + * + * \snippet{doc} snippets.dox OrtStatus Return Value */ ORT_API2_STATUS(KernelContext_GetInput, _In_ const OrtKernelContext* context, _In_ size_t index, _Out_ const OrtValue** out); /** \brief Used for custom operators, get an output of a kernel * - * \see ::OrtCustomOp + * The function attempts to fetch the output of the kernel. If the output is optional + * and not present, the function returns success and out is set to nullptr.
+ * + * \param[in] context ::OrtKernelContext instance + * \param[in] output index. See KernelContext_GetOutputCount for boundaries check. + * \param[in, out] returns a ptr to OrtValue if the output is present + * + * \snippet{doc} snippets.dox OrtStatus Return Value */ ORT_API2_STATUS(KernelContext_GetOutput, _Inout_ OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count, _Outptr_ OrtValue** out); @@ -3619,6 +3633,10 @@ struct OrtApi { * - "73" * - "75" * "device_id": The ID of the device to use when setting 'htp_arch'. Defaults to "0" (for single device). + "enable_htp_fp16_precision": Only used for float32 model. + Enable the float32 model to be inferenced with fp16 precision. Otherwise, it will be fp32 precision. + - "0": Default. With fp32 precision. + - "1": With fp16 precision. * * SNPE supported keys: * "runtime": SNPE runtime engine, options: "CPU", "CPU_FLOAT32", "GPU", "GPU_FLOAT32_16_HYBRID", "GPU_FLOAT16", @@ -4569,6 +4587,43 @@ struct OrtApi { _In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys); + + /** \brief Append VitisAI provider to session options + * + * If VitisAI is not available (due to a non VitisAI enabled build, or if VitisAI is not installed on the system), this function will return failure. + * + * \param[in] options + * \param[in] provider_options_keys + * \param[in] provider_options_values + * \param[in] num_keys + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_VitisAI, + _In_ OrtSessionOptions* options, + _In_reads_(num_keys) const char* const* provider_options_keys, + _In_reads_(num_keys) const char* const* provider_options_values, + _In_ size_t num_keys); + + /** \brief Get scratch buffer from the corresponding allocator under the specific OrtMemoryInfo object.
+ * NOTE: callers are responsible to release this scratch buffer from the corresponding allocator + * \param[in] context OrtKernelContext instance + * \param[in] mem_info OrtMemoryInfo instance + * \param[in] count_or_bytes How many bytes is this scratch buffer + * \param[out] out A pointer to the scratch buffer + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out); + + /** \brief Get allocator from KernelInfo for a specific memory type. Please use C API ReleaseAllocator to release out object + * + * \param[in] info OrtKernelInfo instance + * \param[in] mem_type OrtMemType object + * \param[out] out A pointer to OrtAllocator + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(KernelInfoGetAllocator, _In_ const OrtKernelInfo* info, _In_ OrtMemType mem_type, _Outptr_ OrtAllocator** out); }; /* @@ -4666,6 +4721,13 @@ struct OrtCustomOp { // Get start range int(ORT_API_CALL* GetStartVersion)(_In_ const struct OrtCustomOp* op); int(ORT_API_CALL* GetEndVersion)(_In_ const struct OrtCustomOp* op); + + // Get the inplace_map that defines which output can reuse which input + // Callers will provide 2 raw int* and pass in their address, this function will fill these 2 arrays + // when return, output (*output_index)[i] may reuse the input (*input_index)[i]. + // The return value is the size of these 2 arrays. + // Callers are responsible to delete these 2 arrays after use.
+ size_t(ORT_API_CALL* GetMayInplace)(_Out_ int** input_index, _Out_ int** output_index); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 7a553f9f94006..60540514fbfa6 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -901,6 +901,9 @@ struct SessionOptionsImpl : ConstSessionOptionsImpl { SessionOptionsImpl& RegisterCustomOpsLibrary(const ORTCHAR_T* library_name, const CustomOpConfigs& custom_op_configs = {}); SessionOptionsImpl& RegisterCustomOpsUsingFunction(const char* function_name); ///< Wraps OrtApi::RegisterCustomOpsUsingFunction + + ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_VitisAI + SessionOptionsImpl& AppendExecutionProvider_VitisAI(const std::unordered_map& provider_options = {}); }; } // namespace detail @@ -2052,7 +2055,11 @@ struct KernelContext { explicit KernelContext(OrtKernelContext* context); size_t GetInputCount() const; size_t GetOutputCount() const; + // If input is optional and is not present, the method returns an empty ConstValue + // which can be compared to nullptr. ConstValue GetInput(size_t index) const; + // If output is optional and is not present, the method returns an empty UnownedValue + // which can be compared to nullptr.
UnownedValue GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const; UnownedValue GetOutput(size_t index, const std::vector& dims) const; void* GetGPUComputeStream() const; diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 957e849cf5d4d..23246adff254a 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -885,6 +885,25 @@ inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_Ope return *this; } +template +inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_VitisAI(const std::unordered_map& provider_options) { + auto num_entries = provider_options.size(); + std::vector keys, values; + if (num_entries > 0) { + keys.reserve(num_entries); + values.reserve(num_entries); + + for (const auto& entry : provider_options) { + keys.push_back(entry.first.c_str()); + values.push_back(entry.second.c_str()); + } + } + + ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_VitisAI(this->p_, keys.data(), values.data(), num_entries)); + + return *this; +} + template inline SessionOptionsImpl& SessionOptionsImpl::RegisterCustomOpsLibrary(const ORTCHAR_T* library_name, const CustomOpConfigs& custom_op_configs) { diff --git a/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h index 1f5fcd50e185c..c80b8c0c164b6 100644 --- a/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h @@ -30,3 +30,22 @@ static const char* const kOrtRunOptionsConfigEnableMemoryArenaShrinkage = "memor // Per default it will be set to '0' // Taking CUDA EP as an example, it omit triggering cudaStreamSynchronize on the compute stream. 
static const char* const kOrtRunOptionsConfigDisableSynchronizeExecutionProviders = "disable_synchronize_execution_providers"; + +// Set HTP performance mode for QNN HTP backend before session run. +// options for HTP performance mode: "burst", "balanced", "default", "high_performance", +// "high_power_saver", "low_balanced", "extreme_power_saver", "low_power_saver", "power_saver", +// "sustained_high_performance". Default to "default". +static const char* const kOrtRunOptionsConfigQnnPerfMode = "qnn.htp_perf_mode"; + +// Set HTP performance mode for QNN HTP backend post session run. +static const char* const kOrtRunOptionsConfigQnnPerfModePostRun = "qnn.htp_perf_mode_post_run"; + +// Set RPC control latency for QNN HTP backend +static const char* const kOrtRunOptionsConfigQnnRpcControlLatency = "qnn.rpc_control_latency"; + +// Set graph annotation id for CUDA EP. Use with enable_cuda_graph=true. +// The value should be an integer. If the value is not set, the default value is 0 and +// ORT session only captures one cuda graph before another capture is requested. +// If the value is set to -1, cuda graph capture/replay is disabled in that run. +// Users are not expected to set the value to 0 as it is reserved for internal use. +static const char* const kOrtRunOptionsConfigCudaGraphAnnotation = "gpu_graph_id"; diff --git a/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java b/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java index eb124decf75f3..cec3fadf446ca 100644 --- a/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java +++ b/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime.providers; @@ -14,7 +14,18 @@ public enum CoreMLFlags implements OrtFlags { /** Enables CoreML on subgraphs.
*/ ENABLE_ON_SUBGRAPH(2), // COREML_FLAG_ENABLE_ON_SUBGRAPH(0x002) /** Only enable usage of CoreML if the device has an Apple Neural Engine. */ - ONLY_ENABLE_DEVICE_WITH_ANE(4); // COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE(0x004), + ONLY_ENABLE_DEVICE_WITH_ANE(4), // COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE(0x004) + /** + * Only allow CoreML EP to take nodes with inputs with static shapes. By default it will also + * allow inputs with dynamic shapes. However, the performance may be negatively impacted if inputs + * have dynamic shapes. + */ + ONLY_ALLOW_STATIC_INPUT_SHAPES(8), // COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES(0x008) + /** + * Create an MLProgram. By default it will create a NeuralNetwork model. Requires Core ML 5 or + * later. + */ + CREATE_MLPROGRAM(16); // COREML_FLAG_CREATE_MLPROGRAM(0x010) /** The native value of the enum. */ public final int value; diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java index 7fef2dc784b7b..9925197e4507c 100644 --- a/java/src/test/java/ai/onnxruntime/InferenceTest.java +++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java @@ -673,7 +673,7 @@ private void runProvider(OrtProvider provider) throws OrtException { // CoreML gives slightly different answers on a 2020 13" M1 MBP assertArrayEquals(expectedOutput, resultArray, 1e-2f); } else { - assertArrayEquals(expectedOutput, resultArray, 1e-6f); + assertArrayEquals(expectedOutput, resultArray, 1e-5f); } } catch (OrtException e) { throw new IllegalStateException("Failed to execute a scoring operation", e); diff --git a/java/src/test/java/ai/onnxruntime/providers/ProviderOptionsTest.java b/java/src/test/java/ai/onnxruntime/providers/ProviderOptionsTest.java index 1ed883ace36e5..0e3bc15ba9c70 100644 --- a/java/src/test/java/ai/onnxruntime/providers/ProviderOptionsTest.java +++ b/java/src/test/java/ai/onnxruntime/providers/ProviderOptionsTest.java @@ -96,7 +96,7 @@ private static void 
runProvider(OrtProvider provider, OrtSession.SessionOptions OnnxValue resultTensor = result.get(0); float[] resultArray = TestHelpers.flattenFloat(resultTensor.getValue()); assertEquals(expectedOutput.length, resultArray.length); - assertArrayEquals(expectedOutput, resultArray, 1e-6f); + assertArrayEquals(expectedOutput, resultArray, 1e-5f); } catch (OrtException e) { throw new IllegalStateException("Failed to execute a scoring operation", e); } diff --git a/js/common/lib/backend-impl.ts b/js/common/lib/backend-impl.ts index 3e1e833addb91..e90efd7b97c29 100644 --- a/js/common/lib/backend-impl.ts +++ b/js/common/lib/backend-impl.ts @@ -2,6 +2,7 @@ // Licensed under the MIT License. import {Backend} from './backend.js'; +import {InferenceSession} from './inference-session.js'; interface BackendInfo { backend: Backend; @@ -10,6 +11,7 @@ interface BackendInfo { initPromise?: Promise; initialized?: boolean; aborted?: boolean; + error?: string; } const backends: Map = new Map(); @@ -60,43 +62,100 @@ export const registerBackend = (name: string, backend: Backend, priority: number }; /** - * Resolve backend by specified hints. + * Try to resolve and initialize a backend. * - * @param backendHints - a list of execution provider names to lookup. If omitted use registered backends as list. - * @returns a promise that resolves to the backend. + * @param backendName - the name of the backend. + * @returns the backend instance if resolved and initialized successfully, or an error message if failed. 
+ */ +const tryResolveAndInitializeBackend = async(backendName: string): Promise => { + const backendInfo = backends.get(backendName); + if (!backendInfo) { + return 'backend not found.'; + } + + if (backendInfo.initialized) { + return backendInfo.backend; + } else if (backendInfo.aborted) { + return backendInfo.error!; + } else { + const isInitializing = !!backendInfo.initPromise; + try { + if (!isInitializing) { + backendInfo.initPromise = backendInfo.backend.init(backendName); + } + await backendInfo.initPromise; + backendInfo.initialized = true; + return backendInfo.backend; + } catch (e) { + if (!isInitializing) { + backendInfo.error = `${e}`; + backendInfo.aborted = true; + } + return backendInfo.error!; + } finally { + delete backendInfo.initPromise; + } + } +}; + +/** + * Resolve execution providers from the specific session options. + * + * @param options - the session options object. + * @returns a promise that resolves to a tuple of an initialized backend instance and a session options object with + * filtered EP list. * * @ignore */ -export const resolveBackend = async(backendHints: readonly string[]): Promise => { - const backendNames = backendHints.length === 0 ? backendsSortedByPriority : backendHints; - const errors = []; - for (const backendName of backendNames) { - const backendInfo = backends.get(backendName); - if (backendInfo) { - if (backendInfo.initialized) { - return backendInfo.backend; - } else if (backendInfo.aborted) { - continue; // current backend is unavailable; try next - } +export const resolveBackendAndExecutionProviders = async(options: InferenceSession.SessionOptions): + Promise<[backend: Backend, options: InferenceSession.SessionOptions]> => { + // extract backend hints from session options + const eps = options.executionProviders || []; + const backendHints = eps.map(i => typeof i === 'string' ? i : i.name); + const backendNames = backendHints.length === 0 ? 
backendsSortedByPriority : backendHints; - const isInitializing = !!backendInfo.initPromise; - try { - if (!isInitializing) { - backendInfo.initPromise = backendInfo.backend.init(backendName); + // try to resolve and initialize all requested backends + let backend: Backend|undefined; + const errors = []; + const availableBackendNames = new Set(); + for (const backendName of backendNames) { + const resolveResult = await tryResolveAndInitializeBackend(backendName); + if (typeof resolveResult === 'string') { + errors.push({name: backendName, err: resolveResult}); + } else { + if (!backend) { + backend = resolveResult; + } + if (backend === resolveResult) { + availableBackendNames.add(backendName); + } } - await backendInfo.initPromise; - backendInfo.initialized = true; - return backendInfo.backend; - } catch (e) { - if (!isInitializing) { - errors.push({name: backendName, err: e}); + } + + // if no backend is available, throw error. + if (!backend) { + throw new Error(`no available backend found. ERR: ${errors.map(e => `[${e.name}] ${e.err}`).join(', ')}`); + } + + // for each explicitly requested backend, if it's not available, output warning message. + for (const {name, err} of errors) { + if (backendHints.includes(name)) { + // eslint-disable-next-line no-console + console.warn(`removing requested execution provider "${ + name}" from session options because it is not available: ${err}`); } - backendInfo.aborted = true; - } finally { - delete backendInfo.initPromise; } - } - } - throw new Error(`no available backend found. ERR: ${errors.map(e => `[${e.name}] ${e.err}`).join(', ')}`); -}; + const filteredEps = eps.filter(i => availableBackendNames.has(typeof i === 'string' ? 
i : i.name)); + + return [ + backend, new Proxy(options, { + get: (target, prop) => { + if (prop === 'executionProviders') { + return filteredEps; + } + return Reflect.get(target, prop); + } + }) + ]; + }; diff --git a/js/common/lib/backend.ts b/js/common/lib/backend.ts index 9bfcb12206057..8c07bdd5c5c4a 100644 --- a/js/common/lib/backend.ts +++ b/js/common/lib/backend.ts @@ -58,7 +58,7 @@ export interface TrainingSessionHandler extends SessionHandler { options: InferenceSession.RunOptions): Promise; getParametersSize(trainableOnly: boolean): Promise; - loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise; + loadParametersBuffer(buffer: Uint8Array, trainableOnly: boolean): Promise; getContiguousParameters(trainableOnly: boolean): Promise; } @@ -77,8 +77,8 @@ export interface Backend { Promise; createTrainingSessionHandler? - (checkpointStateUriOrBuffer: TrainingSession.URIorBuffer, trainModelUriOrBuffer: TrainingSession.URIorBuffer, - evalModelUriOrBuffer: TrainingSession.URIorBuffer, optimizerModelUriOrBuffer: TrainingSession.URIorBuffer, + (checkpointStateUriOrBuffer: TrainingSession.UriOrBuffer, trainModelUriOrBuffer: TrainingSession.UriOrBuffer, + evalModelUriOrBuffer: TrainingSession.UriOrBuffer, optimizerModelUriOrBuffer: TrainingSession.UriOrBuffer, options: InferenceSession.SessionOptions): Promise; } diff --git a/js/common/lib/env.ts b/js/common/lib/env.ts index 6299c26159400..b139c719e863f 100644 --- a/js/common/lib/env.ts +++ b/js/common/lib/env.ts @@ -36,6 +36,7 @@ export declare namespace Env { /** * set or get a boolean value indicating whether to enable trace. * + * @deprecated Use `env.trace` instead. If `env.trace` is set, this property will be ignored. * @defaultValue `false` */ trace?: boolean; @@ -142,13 +143,48 @@ export declare namespace Env { */ ondata?: (data: WebGpuProfilingData) => void; }; + /** + * Set or get the power preference. 
+ * + * Setting this property only has effect before the first WebGPU inference session is created. The value will be + * used as options for `navigator.gpu.requestAdapter()`. + * + * See {@link https://gpuweb.github.io/gpuweb/#dictdef-gpurequestadapteroptions} for more details. + * + * @defaultValue `undefined` + */ + powerPreference?: 'low-power'|'high-performance'; + /** + * Set or get the force fallback adapter flag. + * + * Setting this property only has effect before the first WebGPU inference session is created. The value will be + * used as options for `navigator.gpu.requestAdapter()`. + * + * See {@link https://gpuweb.github.io/gpuweb/#dictdef-gpurequestadapteroptions} for more details. + * + * @defaultValue `undefined` + */ + forceFallbackAdapter?: boolean; + /** + * Get the adapter for WebGPU. + * + * This property is only available after the first WebGPU inference session is created. + * + * When use with TypeScript, the type of this property is `GPUAdapter` defined in "@webgpu/types". + * Use `const adapter = env.webgpu.adapter as GPUAdapter;` in TypeScript to access this property with correct type. + * + * see comments on {@link Tensor.GpuBufferType} + */ + readonly adapter: unknown; /** * Get the device for WebGPU. * + * This property is only available after the first WebGPU inference session is created. + * * When use with TypeScript, the type of this property is `GPUDevice` defined in "@webgpu/types". * Use `const device = env.webgpu.device as GPUDevice;` in TypeScript to access this property with correct type. * - * see comments on {@link GpuBufferType} for more details about why not use types defined in "@webgpu/types". + * see comments on {@link Tensor.GpuBufferType} for more details about why not use types defined in "@webgpu/types". */ readonly device: unknown; /** @@ -167,6 +203,7 @@ export interface Env { * @defaultValue `'warning'` */ logLevel?: 'verbose'|'info'|'warning'|'error'|'fatal'; + /** * Indicate whether run in debug mode. 
* @@ -174,6 +211,13 @@ export interface Env { */ debug?: boolean; + /** + * set or get a boolean value indicating whether to enable trace. + * + * @defaultValue `false` + */ + trace?: boolean; + /** * Get version of the current package. */ diff --git a/js/common/lib/index.ts b/js/common/lib/index.ts index d7c98380f3fa4..3ed56b3c2e812 100644 --- a/js/common/lib/index.ts +++ b/js/common/lib/index.ts @@ -11,7 +11,7 @@ * - [onnxruntime-react-native](https://www.npmjs.com/package/onnxruntime-react-native) * * See also: - * - [Get Started](https://onnxruntime.ai/docs/get-started/with-javascript.html) + * - [Get Started](https://onnxruntime.ai/docs/get-started/with-javascript/) * - [Inference examples](https://github.com/microsoft/onnxruntime-inference-examples/tree/main/js) * * @packageDocumentation @@ -21,6 +21,9 @@ export * from './backend.js'; export * from './env.js'; export * from './inference-session.js'; export * from './tensor.js'; +export * from './tensor-conversion.js'; +export * from './tensor-factory.js'; export * from './trace.js'; +export * from './onnx-model.js'; export * from './onnx-value.js'; export * from './training-session.js'; diff --git a/js/common/lib/inference-session-impl.ts b/js/common/lib/inference-session-impl.ts index 55f40c8907a89..ab4c6a3e0c46b 100644 --- a/js/common/lib/inference-session-impl.ts +++ b/js/common/lib/inference-session-impl.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {resolveBackend} from './backend-impl.js'; +import {resolveBackendAndExecutionProviders} from './backend-impl.js'; import {InferenceSessionHandler} from './backend.js'; import {InferenceSession as InferenceSessionInterface} from './inference-session.js'; import {OnnxValue} from './onnx-value.js'; @@ -195,11 +195,9 @@ export class InferenceSession implements InferenceSessionInterface { throw new TypeError('Unexpected argument[0]: must be \'path\' or \'buffer\'.'); } - // get backend hints - const eps = options.executionProviders || []; - const backendHints = eps.map(i => typeof i === 'string' ? i : i.name); - const backend = await resolveBackend(backendHints); - const handler = await backend.createInferenceSessionHandler(filePathOrUint8Array, options); + // resolve backend, update session options with validated EPs, and create session handler + const [backend, optionsWithValidatedEPs] = await resolveBackendAndExecutionProviders(options); + const handler = await backend.createInferenceSessionHandler(filePathOrUint8Array, optionsWithValidatedEPs); TRACE_FUNC_END(); return new InferenceSession(handler); } diff --git a/js/common/lib/inference-session.ts b/js/common/lib/inference-session.ts index 1221b52cd4985..4f7fbdcdcf0ca 100644 --- a/js/common/lib/inference-session.ts +++ b/js/common/lib/inference-session.ts @@ -111,7 +111,7 @@ export declare namespace InferenceSession { optimizedModelFilePath?: string; /** - * Wether enable profiling. + * Whether enable profiling. * * This setting is a placeholder for a future use. */ @@ -154,6 +154,12 @@ export declare namespace InferenceSession { */ preferredOutputLocation?: OnnxValueDataLocation|{readonly [outputName: string]: OnnxValueDataLocation}; + /** + * Whether enable graph capture. + * This setting is available only in ONNXRuntime Web for WebGPU EP. + */ + enableGraphCapture?: boolean; + /** * Store configurations for a session. 
See * https://github.com/microsoft/onnxruntime/blob/main/include/onnxruntime/core/session/ @@ -180,22 +186,22 @@ export declare namespace InferenceSession { // #region execution providers // Currently, we have the following backends to support execution providers: - // Backend Node.js binding: supports 'cpu' and 'cuda'. + // Backend Node.js binding: supports 'cpu', 'dml' (win32), 'coreml' (macOS) and 'cuda' (linux). // Backend WebAssembly: supports 'cpu', 'wasm', 'webgpu' and 'webnn'. // Backend ONNX.js: supports 'webgl'. // Backend React Native: supports 'cpu', 'xnnpack', 'coreml' (iOS), 'nnapi' (Android). interface ExecutionProviderOptionMap { + coreml: CoreMLExecutionProviderOption; cpu: CpuExecutionProviderOption; - coreml: CoreMlExecutionProviderOption; cuda: CudaExecutionProviderOption; dml: DmlExecutionProviderOption; + nnapi: NnapiExecutionProviderOption; tensorrt: TensorRtExecutionProviderOption; wasm: WebAssemblyExecutionProviderOption; webgl: WebGLExecutionProviderOption; - xnnpack: XnnpackExecutionProviderOption; webgpu: WebGpuExecutionProviderOption; webnn: WebNNExecutionProviderOption; - nnapi: NnapiExecutionProviderOption; + xnnpack: XnnpackExecutionProviderOption; } type ExecutionProviderName = keyof ExecutionProviderOptionMap; @@ -213,10 +219,6 @@ export declare namespace InferenceSession { readonly name: 'cuda'; deviceId?: number; } - export interface CoreMlExecutionProviderOption extends ExecutionProviderOption { - readonly name: 'coreml'; - coreMlFlags?: number; - } export interface DmlExecutionProviderOption extends ExecutionProviderOption { readonly name: 'dml'; deviceId?: number; @@ -247,8 +249,39 @@ export declare namespace InferenceSession { } export interface CoreMLExecutionProviderOption extends ExecutionProviderOption { readonly name: 'coreml'; + /** + * The bit flags for CoreML execution provider. 
+ * + * ``` + * COREML_FLAG_USE_CPU_ONLY = 0x001 + * COREML_FLAG_ENABLE_ON_SUBGRAPH = 0x002 + * COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE = 0x004 + * COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES = 0x008 + * COREML_FLAG_CREATE_MLPROGRAM = 0x010 + * ``` + * + * See include/onnxruntime/core/providers/coreml/coreml_provider_factory.h for more details. + * + * This flag is available only in ONNXRuntime (Node.js binding). + */ + coreMlFlags?: number; + /** + * Specify whether to use CPU only in CoreML EP. + * + * This setting is available only in ONNXRuntime (react-native). + */ useCPUOnly?: boolean; + /** + * Specify whether to enable CoreML EP on subgraph. + * + * This setting is available only in ONNXRuntime (react-native). + */ enableOnSubgraph?: boolean; + /** + * Specify whether to only enable CoreML EP for Apple devices with ANE (Apple Neural Engine). + * + * This setting is available only in ONNXRuntime (react-native). + */ onlyEnableDeviceWithANE?: boolean; } export interface NnapiExecutionProviderOption extends ExecutionProviderOption { diff --git a/js/common/lib/onnx-value.ts b/js/common/lib/onnx-value.ts index a16a30d25d839..72369ce8b4209 100644 --- a/js/common/lib/onnx-value.ts +++ b/js/common/lib/onnx-value.ts @@ -3,7 +3,7 @@ import {Tensor} from './tensor.js'; -type NonTensorType = never; +export type NonTensorType = never; /** * Type OnnxValue Represents both tensors and non-tensors value for model's inputs/outputs. diff --git a/js/common/lib/tensor-factory.ts b/js/common/lib/tensor-factory.ts index 6e19d7fb898a3..431de4c3635c2 100644 --- a/js/common/lib/tensor-factory.ts +++ b/js/common/lib/tensor-factory.ts @@ -253,7 +253,7 @@ export interface TensorFactory { /** * create a tensor from an ImageBitmap object * - * @param bitMap - the ImageBitmap object to create tensor from + * @param bitmap - the ImageBitmap object to create tensor from * @param options - An optional object representing options for creating tensor from URL. 
* * The following default settings will be applied: diff --git a/js/common/lib/tensor-impl-type-mapping.ts b/js/common/lib/tensor-impl-type-mapping.ts index c4a43ea27fea1..b29cb8cbd6d35 100644 --- a/js/common/lib/tensor-impl-type-mapping.ts +++ b/js/common/lib/tensor-impl-type-mapping.ts @@ -14,7 +14,6 @@ export const NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP = new Map { - if (!isBigIntChecked) { - isBigIntChecked = true; - const isBigInt64ArrayAvailable = typeof BigInt64Array !== 'undefined' && typeof BigInt64Array.from === 'function'; - const isBigUint64ArrayAvailable = - typeof BigUint64Array !== 'undefined' && typeof BigUint64Array.from === 'function'; +// a dummy type declaration for Float16Array in case any polyfill is available. +declare global { + // eslint-disable-next-line @typescript-eslint/naming-convention, @typescript-eslint/no-explicit-any + const Float16Array: any; +} + +// the following code allows delaying execution of BigInt/Float16Array checking. This allows lazy initialization for +// NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP and NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP, which allows BigInt/Float16Array +// polyfill if available. 
+let isTypedArrayChecked = false; +export const checkTypedArray = () => { + if (!isTypedArrayChecked) { + isTypedArrayChecked = true; + const isBigInt64ArrayAvailable = typeof BigInt64Array !== 'undefined' && BigInt64Array.from; + const isBigUint64ArrayAvailable = typeof BigUint64Array !== 'undefined' && BigUint64Array.from; + const isFloat16ArrayAvailable = typeof Float16Array !== 'undefined' && Float16Array.from; if (isBigInt64ArrayAvailable) { NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('int64', BigInt64Array); @@ -53,5 +58,12 @@ export const checkBigInt = () => { NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('uint64', BigUint64Array); NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP.set(BigUint64Array, 'uint64'); } + if (isFloat16ArrayAvailable) { + NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('float16', Float16Array); + NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP.set(Float16Array, 'float16'); + } else { + // if Float16Array is not available, use 'Uint16Array' to store the data. + NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('float16', Uint16Array); + } } }; diff --git a/js/common/lib/tensor-impl.ts b/js/common/lib/tensor-impl.ts index e3e2b9c728556..56682ef98e117 100644 --- a/js/common/lib/tensor-impl.ts +++ b/js/common/lib/tensor-impl.ts @@ -5,7 +5,7 @@ import {tensorToDataURL, tensorToImageData} from './tensor-conversion-impl.js'; import {TensorToDataUrlOptions, TensorToImageDataOptions} from './tensor-conversion.js'; import {tensorFromGpuBuffer, tensorFromImage, tensorFromPinnedBuffer, tensorFromTexture} from './tensor-factory-impl.js'; import {CpuPinnedConstructorParameters, GpuBufferConstructorParameters, TensorFromGpuBufferOptions, TensorFromImageBitmapOptions, TensorFromImageDataOptions, TensorFromImageElementOptions, TensorFromTextureOptions, TensorFromUrlOptions, TextureConstructorParameters} from './tensor-factory.js'; -import {checkBigInt, NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP, NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP, SupportedTypedArray, SupportedTypedArrayConstructors} from 
'./tensor-impl-type-mapping.js'; +import {checkTypedArray, NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP, NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP, SupportedTypedArray, SupportedTypedArrayConstructors} from './tensor-impl-type-mapping.js'; import {calculateSize, tensorReshape} from './tensor-utils-impl.js'; import {Tensor as TensorInterface} from './tensor.js'; @@ -67,8 +67,8 @@ export class Tensor implements TensorInterface { arg0: TensorType|TensorDataType|readonly string[]|readonly boolean[]|CpuPinnedConstructorParameters| TextureConstructorParameters|GpuBufferConstructorParameters, arg1?: TensorDataType|readonly number[]|readonly string[]|readonly boolean[], arg2?: readonly number[]) { - // perform one-time check for BigInt support - checkBigInt(); + // perform one-time check for BigInt/Float16Array support + checkTypedArray(); let type: TensorType; let dims: readonly number[]; @@ -103,7 +103,7 @@ export class Tensor implements TensorInterface { } case 'gpu-buffer': { if ((type !== 'float32' && type !== 'float16' && type !== 'int32' && type !== 'int64' && type !== 'uint32' && - type !== 'bool')) { + type !== 'uint8' && type !== 'bool')) { throw new TypeError(`unsupported type "${type}" to create tensor from gpu buffer`); } this.gpuBufferData = arg0.gpuBuffer; @@ -142,7 +142,9 @@ export class Tensor implements TensorInterface { throw new TypeError(`Unsupported tensor type: ${arg0}.`); } if (Array.isArray(arg1)) { - if (arg0 === 'float16') { + if (arg0 === 'float16' && typedArrayConstructor === Uint16Array) { + // When no Float16Array polyfill is used, we cannot create 'float16' tensor from number array. + // // Throw error here because when user try to use number array as data, // e.g. new Tensor('float16', [1, 2, 3, 4], dims)), it will actually call // Uint16Array.from(arg1) which generates wrong data. 
diff --git a/js/common/lib/tensor.ts b/js/common/lib/tensor.ts index 6c08d1fe8e057..20319ebb800c2 100644 --- a/js/common/lib/tensor.ts +++ b/js/common/lib/tensor.ts @@ -135,7 +135,7 @@ export declare namespace Tensor { /** * supported data types for constructing a tensor from a WebGPU buffer */ - export type GpuBufferDataTypes = 'float32'|'float16'|'int32'|'int64'|'uint32'|'bool'; + export type GpuBufferDataTypes = 'float32'|'float16'|'int32'|'int64'|'uint32'|'uint8'|'bool'; /** * represent where the tensor data is stored @@ -160,7 +160,7 @@ export interface Tensor extends TypedTensorBase, TypedTensorUtils { - if (!env.wasm.trace) { + if (typeof env.trace === 'undefined' ? !env.wasm.trace : !env.trace) { return; } // eslint-disable-next-line no-console @@ -29,15 +32,21 @@ const TRACE_FUNC = (msg: string, extraMsg?: string) => { } }; +/** + * @ignore + */ export const TRACE_FUNC_BEGIN = (extraMsg?: string) => { - if (!env.wasm.trace) { + if (typeof env.trace === 'undefined' ? !env.wasm.trace : !env.trace) { return; } TRACE_FUNC('BEGIN', extraMsg); }; +/** + * @ignore + */ export const TRACE_FUNC_END = (extraMsg?: string) => { - if (!env.wasm.trace) { + if (typeof env.trace === 'undefined' ? !env.wasm.trace : !env.trace) { return; } TRACE_FUNC('END', extraMsg); diff --git a/js/common/lib/training-session-impl.ts b/js/common/lib/training-session-impl.ts index 23bd4421ae672..bae38b0dfda5a 100644 --- a/js/common/lib/training-session-impl.ts +++ b/js/common/lib/training-session-impl.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {resolveBackend} from './backend-impl.js'; +import {resolveBackendAndExecutionProviders} from './backend-impl.js'; import {SessionHandler, TrainingSessionHandler} from './backend.js'; import {InferenceSession as InferenceSession} from './inference-session.js'; import {OnnxValue} from './onnx-value.js'; @@ -55,13 +55,12 @@ export class TrainingSession implements TrainingSessionInterface { const optimizerModel: string|Uint8Array = trainingOptions.optimizerModel || ''; const options: SessionOptions = sessionOptions || {}; - // get backend hints - const eps = options.executionProviders || []; - const backendHints = eps.map(i => typeof i === 'string' ? i : i.name); - const backend = await resolveBackend(backendHints); + // resolve backend, update session options with validated EPs, and create session handler + const [backend, optionsWithValidatedEPs] = await resolveBackendAndExecutionProviders(options); if (backend.createTrainingSessionHandler) { const handler = await backend.createTrainingSessionHandler( - trainingOptions.checkpointState, trainingOptions.trainModel, evalModel, optimizerModel, options); + trainingOptions.checkpointState, trainingOptions.trainModel, evalModel, optimizerModel, + optionsWithValidatedEPs); return new TrainingSession(handler, !!trainingOptions.optimizerModel, !!trainingOptions.evalModel); } else { throw new Error(noBackendErrMsg); diff --git a/js/common/lib/training-session.ts b/js/common/lib/training-session.ts index e54aed90e702c..f9de77e3ac7d0 100644 --- a/js/common/lib/training-session.ts +++ b/js/common/lib/training-session.ts @@ -11,7 +11,7 @@ export declare namespace TrainingSession { /** * Either URI file path (string) or Uint8Array containing model or checkpoint information. 
*/ - type URIorBuffer = string|Uint8Array; + type UriOrBuffer = string|Uint8Array; } /** @@ -98,13 +98,13 @@ export interface TrainingSession { getParametersSize(trainableOnly: boolean): Promise; /** - * Copies parameter values from the given array to the training state. Currently, only supporting models with + * Copies parameter values from the given buffer to the training state. Currently, only supporting models with * parameters of type Float32. * - * @param buffer - Float32 buffer containing parameters converted to a Uint8Array. + * @param buffer - A Uint8Array representation of Float32 parameters. * @param trainableOnly - True if trainable parameters only to be modified, false otherwise. Default value is true. */ - loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise; + loadParametersBuffer(buffer: Uint8Array, trainableOnly: boolean): Promise; /** * Copies the model parameters to a contiguous buffer. Usually used in the context of Federated Learning. @@ -157,19 +157,19 @@ export interface TrainingSessionCreateOptions { /** * URI or buffer for a .ckpt file that contains the checkpoint for the training model. */ - checkpointState: TrainingSession.URIorBuffer; + checkpointState: TrainingSession.UriOrBuffer; /** * URI or buffer for the .onnx training file. */ - trainModel: TrainingSession.URIorBuffer; + trainModel: TrainingSession.UriOrBuffer; /** * Optional. URI or buffer for the .onnx optimizer model file. */ - optimizerModel?: TrainingSession.URIorBuffer; + optimizerModel?: TrainingSession.UriOrBuffer; /** * Optional. URI or buffer for the .onnx eval model file. 
*/ - evalModel?: TrainingSession.URIorBuffer; + evalModel?: TrainingSession.UriOrBuffer; } /** diff --git a/js/common/package-lock.json b/js/common/package-lock.json index a5ada877b916a..3988ac80707e0 100644 --- a/js/common/package-lock.json +++ b/js/common/package-lock.json @@ -9,13 +9,13 @@ "version": "1.18.0", "license": "MIT", "devDependencies": { - "typedoc": "^0.23.22" + "typedoc": "^0.25.7" } }, "node_modules/ansi-sequence-parser": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/ansi-sequence-parser/-/ansi-sequence-parser-1.1.0.tgz", - "integrity": "sha512-lEm8mt52to2fT8GhciPCGeCXACSz2UwIN4X2e2LJSnZ5uAbn2/dsYdOmUXq0AtWS5cpAupysIneExOgH0Vd2TQ==", + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/ansi-sequence-parser/-/ansi-sequence-parser-1.1.1.tgz", + "integrity": "sha512-vJXt3yiaUL4UU546s3rPXlsry/RnM730G1+HkpKE012AN0sx1eOrxSu95oKDIonskeLTijMgqWZ3uDEe3NFvyg==", "dev": true }, "node_modules/balanced-match": { @@ -34,9 +34,9 @@ } }, "node_modules/jsonc-parser": { - "version": "3.2.0", - "resolved": "https://registry.npmjs.org/jsonc-parser/-/jsonc-parser-3.2.0.tgz", - "integrity": "sha512-gfFQZrcTc8CnKXp6Y4/CBT3fTc0OVuDofpre4aEeEpSBPV5X5v4+Vmx+8snU7RLPrNHPKSgLxGo9YuQzz20o+w==", + "version": "3.2.1", + "resolved": "https://registry.npmjs.org/jsonc-parser/-/jsonc-parser-3.2.1.tgz", + "integrity": "sha512-AilxAyFOAcK5wA1+LeaySVBrHsGQvUFCDWXKpZjzaL0PqW+xfBOttn8GNtWKFWqneyMZj41MWF9Kl6iPWLwgOA==", "dev": true }, "node_modules/lunr": { @@ -46,9 +46,9 @@ "dev": true }, "node_modules/marked": { - "version": "4.2.12", - "resolved": "https://registry.npmjs.org/marked/-/marked-4.2.12.tgz", - "integrity": "sha512-yr8hSKa3Fv4D3jdZmtMMPghgVt6TWbk86WQaWhDloQjRSQhMMYCAro7jP7VDJrjjdV8pxVxMssXS8B8Y5DZ5aw==", + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/marked/-/marked-4.3.0.tgz", + "integrity": "sha512-PRsaiG84bK+AMvxziE/lCFss8juXjNaWzVbN5tXAm4XjeaS9NAHhop+PjQxz2A9h8Q4M/xGmzP8vqNwy6JeK0A==", "dev": true, "bin": { "marked": 
"bin/marked.js" @@ -58,24 +58,24 @@ } }, "node_modules/minimatch": { - "version": "7.4.2", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-7.4.2.tgz", - "integrity": "sha512-xy4q7wou3vUoC9k1xGTXc+awNdGaGVHtFUaey8tiX4H1QRc04DZ/rmDFwNm2EBsuYEhAZ6SgMmYf3InGY6OauA==", + "version": "9.0.3", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.3.tgz", + "integrity": "sha512-RHiac9mvaRw0x3AYRgDC1CxAP7HTcNrrECeA8YYJeWnpo+2Q5CegtZjaotWTWxDG3UeGA1coE05iH1mPjT/2mg==", "dev": true, "dependencies": { "brace-expansion": "^2.0.1" }, "engines": { - "node": ">=10" + "node": ">=16 || 14 >=14.17" }, "funding": { "url": "https://github.com/sponsors/isaacs" } }, "node_modules/shiki": { - "version": "0.14.1", - "resolved": "https://registry.npmjs.org/shiki/-/shiki-0.14.1.tgz", - "integrity": "sha512-+Jz4nBkCBe0mEDqo1eKRcCdjRtrCjozmcbTUjbPTX7OOJfEbTZzlUWlZtGe3Gb5oV1/jnojhG//YZc3rs9zSEw==", + "version": "0.14.7", + "resolved": "https://registry.npmjs.org/shiki/-/shiki-0.14.7.tgz", + "integrity": "sha512-dNPAPrxSc87ua2sKJ3H5dQ/6ZaY8RNnaAqK+t0eG7p0Soi2ydiqbGOTaZCqaYvA/uZYfS1LJnemt3Q+mSfcPCg==", "dev": true, "dependencies": { "ansi-sequence-parser": "^1.1.0", @@ -85,30 +85,30 @@ } }, "node_modules/typedoc": { - "version": "0.23.26", - "resolved": "https://registry.npmjs.org/typedoc/-/typedoc-0.23.26.tgz", - "integrity": "sha512-5m4KwR5tOLnk0OtMaRn9IdbeRM32uPemN9kur7YK9wFqx8U0CYrvO9aVq6ysdZSV1c824BTm+BuQl2Ze/k1HtA==", + "version": "0.25.7", + "resolved": "https://registry.npmjs.org/typedoc/-/typedoc-0.25.7.tgz", + "integrity": "sha512-m6A6JjQRg39p2ZVRIN3NKXgrN8vzlHhOS+r9ymUYtcUP/TIQPvWSq7YgE5ZjASfv5Vd5BW5xrir6Gm2XNNcOow==", "dev": true, "dependencies": { "lunr": "^2.3.9", - "marked": "^4.2.12", - "minimatch": "^7.1.3", - "shiki": "^0.14.1" + "marked": "^4.3.0", + "minimatch": "^9.0.3", + "shiki": "^0.14.7" }, "bin": { "typedoc": "bin/typedoc" }, "engines": { - "node": ">= 14.14" + "node": ">= 16" }, "peerDependencies": { - "typescript": "4.6.x || 4.7.x || 4.8.x 
|| 4.9.x" + "typescript": "4.6.x || 4.7.x || 4.8.x || 4.9.x || 5.0.x || 5.1.x || 5.2.x || 5.3.x" } }, "node_modules/typescript": { - "version": "4.9.5", - "resolved": "https://registry.npmjs.org/typescript/-/typescript-4.9.5.tgz", - "integrity": "sha512-1FXk9E2Hm+QzZQ7z+McJiHL4NW1F2EzMu9Nq9i3zAaGqibafqYwCVU6WyWAuyQRRzOlxou8xZSyXLEN8oKj24g==", + "version": "5.2.2", + "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.2.2.tgz", + "integrity": "sha512-mI4WrpHsbCIcwT9cF4FZvr80QUeKvsUsUvKDoR+X/7XHQH98xYD8YHZg7ANtz2GtZt/CBq2QJ0thkGJMHfqc1w==", "dev": true, "peer": true, "bin": { @@ -116,7 +116,7 @@ "tsserver": "bin/tsserver" }, "engines": { - "node": ">=4.2.0" + "node": ">=14.17" } }, "node_modules/vscode-oniguruma": { @@ -134,9 +134,9 @@ }, "dependencies": { "ansi-sequence-parser": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/ansi-sequence-parser/-/ansi-sequence-parser-1.1.0.tgz", - "integrity": "sha512-lEm8mt52to2fT8GhciPCGeCXACSz2UwIN4X2e2LJSnZ5uAbn2/dsYdOmUXq0AtWS5cpAupysIneExOgH0Vd2TQ==", + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/ansi-sequence-parser/-/ansi-sequence-parser-1.1.1.tgz", + "integrity": "sha512-vJXt3yiaUL4UU546s3rPXlsry/RnM730G1+HkpKE012AN0sx1eOrxSu95oKDIonskeLTijMgqWZ3uDEe3NFvyg==", "dev": true }, "balanced-match": { @@ -155,9 +155,9 @@ } }, "jsonc-parser": { - "version": "3.2.0", - "resolved": "https://registry.npmjs.org/jsonc-parser/-/jsonc-parser-3.2.0.tgz", - "integrity": "sha512-gfFQZrcTc8CnKXp6Y4/CBT3fTc0OVuDofpre4aEeEpSBPV5X5v4+Vmx+8snU7RLPrNHPKSgLxGo9YuQzz20o+w==", + "version": "3.2.1", + "resolved": "https://registry.npmjs.org/jsonc-parser/-/jsonc-parser-3.2.1.tgz", + "integrity": "sha512-AilxAyFOAcK5wA1+LeaySVBrHsGQvUFCDWXKpZjzaL0PqW+xfBOttn8GNtWKFWqneyMZj41MWF9Kl6iPWLwgOA==", "dev": true }, "lunr": { @@ -167,24 +167,24 @@ "dev": true }, "marked": { - "version": "4.2.12", - "resolved": "https://registry.npmjs.org/marked/-/marked-4.2.12.tgz", - "integrity": 
"sha512-yr8hSKa3Fv4D3jdZmtMMPghgVt6TWbk86WQaWhDloQjRSQhMMYCAro7jP7VDJrjjdV8pxVxMssXS8B8Y5DZ5aw==", + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/marked/-/marked-4.3.0.tgz", + "integrity": "sha512-PRsaiG84bK+AMvxziE/lCFss8juXjNaWzVbN5tXAm4XjeaS9NAHhop+PjQxz2A9h8Q4M/xGmzP8vqNwy6JeK0A==", "dev": true }, "minimatch": { - "version": "7.4.2", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-7.4.2.tgz", - "integrity": "sha512-xy4q7wou3vUoC9k1xGTXc+awNdGaGVHtFUaey8tiX4H1QRc04DZ/rmDFwNm2EBsuYEhAZ6SgMmYf3InGY6OauA==", + "version": "9.0.3", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.3.tgz", + "integrity": "sha512-RHiac9mvaRw0x3AYRgDC1CxAP7HTcNrrECeA8YYJeWnpo+2Q5CegtZjaotWTWxDG3UeGA1coE05iH1mPjT/2mg==", "dev": true, "requires": { "brace-expansion": "^2.0.1" } }, "shiki": { - "version": "0.14.1", - "resolved": "https://registry.npmjs.org/shiki/-/shiki-0.14.1.tgz", - "integrity": "sha512-+Jz4nBkCBe0mEDqo1eKRcCdjRtrCjozmcbTUjbPTX7OOJfEbTZzlUWlZtGe3Gb5oV1/jnojhG//YZc3rs9zSEw==", + "version": "0.14.7", + "resolved": "https://registry.npmjs.org/shiki/-/shiki-0.14.7.tgz", + "integrity": "sha512-dNPAPrxSc87ua2sKJ3H5dQ/6ZaY8RNnaAqK+t0eG7p0Soi2ydiqbGOTaZCqaYvA/uZYfS1LJnemt3Q+mSfcPCg==", "dev": true, "requires": { "ansi-sequence-parser": "^1.1.0", @@ -194,21 +194,21 @@ } }, "typedoc": { - "version": "0.23.26", - "resolved": "https://registry.npmjs.org/typedoc/-/typedoc-0.23.26.tgz", - "integrity": "sha512-5m4KwR5tOLnk0OtMaRn9IdbeRM32uPemN9kur7YK9wFqx8U0CYrvO9aVq6ysdZSV1c824BTm+BuQl2Ze/k1HtA==", + "version": "0.25.7", + "resolved": "https://registry.npmjs.org/typedoc/-/typedoc-0.25.7.tgz", + "integrity": "sha512-m6A6JjQRg39p2ZVRIN3NKXgrN8vzlHhOS+r9ymUYtcUP/TIQPvWSq7YgE5ZjASfv5Vd5BW5xrir6Gm2XNNcOow==", "dev": true, "requires": { "lunr": "^2.3.9", - "marked": "^4.2.12", - "minimatch": "^7.1.3", - "shiki": "^0.14.1" + "marked": "^4.3.0", + "minimatch": "^9.0.3", + "shiki": "^0.14.7" } }, "typescript": { - "version": "4.9.5", - 
"resolved": "https://registry.npmjs.org/typescript/-/typescript-4.9.5.tgz", - "integrity": "sha512-1FXk9E2Hm+QzZQ7z+McJiHL4NW1F2EzMu9Nq9i3zAaGqibafqYwCVU6WyWAuyQRRzOlxou8xZSyXLEN8oKj24g==", + "version": "5.2.2", + "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.2.2.tgz", + "integrity": "sha512-mI4WrpHsbCIcwT9cF4FZvr80QUeKvsUsUvKDoR+X/7XHQH98xYD8YHZg7ANtz2GtZt/CBq2QJ0thkGJMHfqc1w==", "dev": true, "peer": true }, diff --git a/js/common/package.json b/js/common/package.json index 64ab2736adbe3..cd2612aab4984 100644 --- a/js/common/package.json +++ b/js/common/package.json @@ -9,7 +9,7 @@ }, "author": "fs-eire", "scripts": { - "build:cjs": "tsc --module commonjs --outDir ./dist/cjs", + "build:cjs": "tsc --module commonjs --moduleResolution node10 --outDir ./dist/cjs", "build:esm": "tsc", "build:bundles": "webpack", "build": "node ./build.js", @@ -18,7 +18,7 @@ "test": "mocha ./test/**/*.js --timeout 30000" }, "devDependencies": { - "typedoc": "^0.23.22" + "typedoc": "^0.25.7" }, "main": "dist/cjs/index.js", "exports": { diff --git a/js/common/test/tsconfig.json b/js/common/test/tsconfig.json index 2e4927ac3b325..e9068ad837a81 100644 --- a/js/common/test/tsconfig.json +++ b/js/common/test/tsconfig.json @@ -2,7 +2,7 @@ "extends": "../../tsconfig.tools.json", "exclude": ["type-tests/**/*.ts"], "compilerOptions": { - "module": "ES2022", + "module": "Node16", "sourceMap": true } } diff --git a/js/node/lib/backend.ts b/js/node/lib/backend.ts index e8eb0e9babf5a..927953b4f1dd6 100644 --- a/js/node/lib/backend.ts +++ b/js/node/lib/backend.ts @@ -36,7 +36,7 @@ class OnnxruntimeSessionHandler implements InferenceSessionHandler { async run(feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType, options: InferenceSession.RunOptions): Promise { return new Promise((resolve, reject) => { - process.nextTick(() => { + setImmediate(() => { try { resolve(this.#inferenceSession.run(feeds, fetches, options)); } catch (e) { @@ -56,7 +56,7 @@ class 
OnnxruntimeBackend implements Backend { async createInferenceSessionHandler(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): Promise { return new Promise((resolve, reject) => { - process.nextTick(() => { + setImmediate(() => { try { resolve(new OnnxruntimeSessionHandler(pathOrBuffer, options || {})); } catch (e) { diff --git a/js/node/package-lock.json b/js/node/package-lock.json index 2d7c39c86097f..62b47698a1438 100644 --- a/js/node/package-lock.json +++ b/js/node/package-lock.json @@ -30,7 +30,7 @@ "version": "1.18.0", "license": "MIT", "devDependencies": { - "typedoc": "^0.23.22" + "typedoc": "^0.25.7" } }, "node_modules/@protobufjs/aspromise": { @@ -336,9 +336,9 @@ "dev": true }, "node_modules/follow-redirects": { - "version": "1.15.4", - "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.4.tgz", - "integrity": "sha512-Cr4D/5wlrb0z9dgERpUL3LrmPKVDsETIJhaCMeDfuFYcqa5bldGV6wBsAN6X/vxlXQtFBMrXdXxdL8CbDTGniw==", + "version": "1.15.6", + "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.6.tgz", + "integrity": "sha512-wWN62YITEaOpSK584EZXJafH1AGpO8RVgElfkuXbTOrPX4fIfOyEpW/CsiNd8JdYrAoOvafRTOEnvsO++qCqFA==", "dev": true, "funding": [ { @@ -1242,9 +1242,9 @@ "dev": true }, "follow-redirects": { - "version": "1.15.4", - "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.4.tgz", - "integrity": "sha512-Cr4D/5wlrb0z9dgERpUL3LrmPKVDsETIJhaCMeDfuFYcqa5bldGV6wBsAN6X/vxlXQtFBMrXdXxdL8CbDTGniw==", + "version": "1.15.6", + "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.6.tgz", + "integrity": "sha512-wWN62YITEaOpSK584EZXJafH1AGpO8RVgElfkuXbTOrPX4fIfOyEpW/CsiNd8JdYrAoOvafRTOEnvsO++qCqFA==", "dev": true }, "form-data": { @@ -1503,7 +1503,7 @@ "onnxruntime-common": { "version": "file:../common", "requires": { - "typedoc": "^0.23.22" + "typedoc": "^0.25.7" } }, "parse-json": { diff --git a/js/react_native/e2e/yarn.lock 
b/js/react_native/e2e/yarn.lock index 9e20a286c4e27..6f05faf046098 100644 --- a/js/react_native/e2e/yarn.lock +++ b/js/react_native/e2e/yarn.lock @@ -3351,9 +3351,9 @@ invariant@^2.2.4: loose-envify "^1.0.0" ip@^1.1.5: - version "1.1.8" - resolved "https://registry.yarnpkg.com/ip/-/ip-1.1.8.tgz#ae05948f6b075435ed3307acce04629da8cdbf48" - integrity sha512-PuExPYUiu6qMBQb4l06ecm6T6ujzhmh+MeJcW9wa89PoAz5pvd4zPgN5WJV104mb6S2T1AwNIAaB70JNrLQWhg== + version "1.1.9" + resolved "https://registry.yarnpkg.com/ip/-/ip-1.1.9.tgz#8dfbcc99a754d07f425310b86a99546b1151e396" + integrity sha512-cyRxvOEpNHNtchU3Ln9KC/auJgup87llfQpQ+t5ghoC/UhL16SWzbueiCsdTnWmqAWl7LadfuwhlqmtOaqMHdQ== is-accessor-descriptor@^0.1.6: version "0.1.6" diff --git a/js/react_native/yarn.lock b/js/react_native/yarn.lock index 4dca90d7415cf..bbb0c4f3d1e22 100644 --- a/js/react_native/yarn.lock +++ b/js/react_native/yarn.lock @@ -3701,9 +3701,9 @@ invariant@^2.2.4: loose-envify "^1.0.0" ip@^1.1.5: - version "1.1.8" - resolved "https://registry.yarnpkg.com/ip/-/ip-1.1.8.tgz#ae05948f6b075435ed3307acce04629da8cdbf48" - integrity sha512-PuExPYUiu6qMBQb4l06ecm6T6ujzhmh+MeJcW9wa89PoAz5pvd4zPgN5WJV104mb6S2T1AwNIAaB70JNrLQWhg== + version "1.1.9" + resolved "https://registry.yarnpkg.com/ip/-/ip-1.1.9.tgz#8dfbcc99a754d07f425310b86a99546b1151e396" + integrity sha512-cyRxvOEpNHNtchU3Ln9KC/auJgup87llfQpQ+t5ghoC/UhL16SWzbueiCsdTnWmqAWl7LadfuwhlqmtOaqMHdQ== is-absolute@^1.0.0: version "1.0.0" diff --git a/js/web/README.md b/js/web/README.md index c75a40ad6da28..906c78a1b7ec4 100644 --- a/js/web/README.md +++ b/js/web/README.md @@ -12,7 +12,7 @@ The [Open Neural Network Exchange](http://onnx.ai/) (ONNX) is an open standard f With ONNX Runtime Web, web developers can score models directly on browsers with various benefits including reducing server-client communication and protecting user privacy, as well as offering install-free and cross-platform in-browser ML experience. -ONNX Runtime Web can run on both CPU and GPU. 
On CPU side, [WebAssembly](https://developer.mozilla.org/en-US/docs/WebAssembly) is adopted to execute the model at near-native speed. ONNX Runtime Web complies the native ONNX Runtime CPU engine into WebAssembly backend by using Emscripten, so it supports most functionalities native ONNX Runtime offers, including full ONNX operator coverage, multi-threading, [ONNX Runtime Quantization](https://www.onnxruntime.ai/docs/how-to/quantization.html) as well as [ONNX Runtime Mobile](https://onnxruntime.ai/docs/tutorials/mobile/). For performance acceleration with GPUs, ONNX Runtime Web leverages WebGL, a popular standard for accessing GPU capabilities. We are keeping improving op coverage and optimizing performance in WebGL backend. +ONNX Runtime Web can run on both CPU and GPU. On CPU side, [WebAssembly](https://developer.mozilla.org/en-US/docs/WebAssembly) is adopted to execute the model at near-native speed. ONNX Runtime Web compiles the native ONNX Runtime CPU engine into WebAssembly backend by using Emscripten, so it supports most functionalities native ONNX Runtime offers, including full ONNX operator coverage, multi-threading, [ONNX Runtime Quantization](https://www.onnxruntime.ai/docs/how-to/quantization.html) as well as [ONNX Runtime Mobile](https://onnxruntime.ai/docs/tutorials/mobile/). For performance acceleration with GPUs, ONNX Runtime Web leverages WebGL, a popular standard for accessing GPU capabilities. We are keeping improving op coverage and optimizing performance in WebGL backend. See [Compatibility](#Compatibility) and [Operators Supported](#Operators) for a list of platforms and operators ONNX Runtime Web currently supports. 
@@ -22,7 +22,7 @@ Refer to [ONNX Runtime JavaScript examples](https://github.com/microsoft/onnxrun ## Documents -### Developement +### Development Refer to the following links for development information: diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index 2557971eb4ded..4a8c92bb97bfd 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -41,6 +41,7 @@ Do not modify directly.* | Erf | ai.onnx(9-12,13+) | | | Exp | ai.onnx(6-12,13+) | | | Expand | ai.onnx(8-12,13+) | | +| FastGelu | com.microsoft(1+) | | | Flatten | ai.onnx(1-8,9-10,11-12,13+) | | | Floor | ai.onnx(6-12,13+) | | | FusedConv | com.microsoft(1+) | | @@ -61,6 +62,7 @@ Do not modify directly.* | LessOrEqual | ai.onnx(12-15,16+) | | | Log | ai.onnx(6-12,13+) | | | MatMul | ai.onnx(1-12,13+) | | +| MatMulNBits | com.microsoft(1+) | | | MaxPool | ai.onnx(1-7,8-9,10,11,12+); com.ms.internal.nhwc(1-7,8-9,10,11,12+) | need perf optimization; need implementing activation | | MemcpyFromHost | ai.onnx(1+) | | | MemcpyToHost | ai.onnx(1+) | | diff --git a/js/web/karma.conf.js b/js/web/karma.conf.js index 8fce79843f617..9e44d9c0d9652 100644 --- a/js/web/karma.conf.js +++ b/js/web/karma.conf.js @@ -86,11 +86,11 @@ module.exports = function(config) { hostname, listenAddress, customLaunchers: { - // the following flags are used to make sure Edge on CI agents to initialize WebGPU correctly. 
+ // Chromium-based browsers EdgeTest: {base: 'Edge', flags: chromiumFlags}, ChromeTest: {base: 'Chrome', flags: chromiumFlags}, - ChromeTestHeadless: {base: 'ChromeHeadless', flags: chromiumFlags}, ChromeCanaryTest: {base: 'ChromeCanary', flags: chromiumFlags}, + // // ==== BrowserStack browsers ==== // diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts index 24d7062c85fcb..56925b728e9a3 100644 --- a/js/web/lib/wasm/binding/ort-wasm.d.ts +++ b/js/web/lib/wasm/binding/ort-wasm.d.ts @@ -13,20 +13,100 @@ export declare namespace JSEP { type ReleaseKernelFunction = (kernel: number) => void; type RunFunction = (kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array>) => number; + type CaptureBeginFunction = () => void; + type CaptureEndFunction = () => void; + type ReplayFunction = () => void; + + export interface Module extends WebGpuModule { + /** + * Mount the external data file to an internal map, which will be used during session initialization. + * + * @param externalDataFilePath - specify the relative path of the external data file. + * @param externalDataFileData - specify the content data. + */ + mountExternalData(externalDataFilePath: string, externalDataFileData: Uint8Array): void; + /** + * Unmount all external data files from the internal map. + */ + unmountExternalData(): void; + + /** + * This is the entry of JSEP initialization. This function is called once when initializing ONNX Runtime per + * backend. This function initializes Asyncify support. If name is 'webgpu', also initializes WebGPU backend and + * registers a few callbacks that will be called in C++ code. 
+ */ + jsepInit(name: 'webgpu', initParams: [ + backend: BackendType, alloc: AllocFunction, free: FreeFunction, upload: UploadFunction, + download: DownloadFunction, createKernel: CreateKernelFunction, releaseKernel: ReleaseKernelFunction, + run: RunFunction, captureBegin: CaptureBeginFunction, captureEnd: CaptureEndFunction, replay: ReplayFunction + ]): void; + jsepInit(name: 'webnn', initParams?: never): void; + } + + export interface WebGpuModule { + /** + * [exported from wasm] Specify a kernel's output when running OpKernel::Compute(). + * + * @param context - specify the kernel context pointer. + * @param index - specify the index of the output. + * @param data - specify the pointer to encoded data of type and dims. + */ + _JsepOutput(context: number, index: number, data: number): number; + /** + * [exported from wasm] Get name of an operator node. + * + * @param kernel - specify the kernel pointer. + * @returns the pointer to a C-style UTF8 encoded string representing the node name. + */ + _JsepGetNodeName(kernel: number): number; + + /** + * [exported from js_internal_api.js] Register a user GPU buffer for usage of a session's input or output. + * + * @param sessionId - specify the session ID. + * @param index - specify an integer to represent which input/output it is registering for. For input, it is the + * input_index corresponding to the session's inputNames. For output, it is the inputCount + output_index + * corresponding to the session's outputNames. + * @param buffer - specify the GPU buffer to register. + * @param size - specify the original data size in byte. + * @returns the GPU data ID for the registered GPU buffer. + */ + jsepRegisterBuffer: (sessionId: number, index: number, buffer: GPUBuffer, size: number) => number; + /** + * [exported from js_internal_api.js] Get the GPU buffer by GPU data ID. + * + * @param dataId - specify the GPU data ID + * @returns the GPU buffer. 
+ */ + jsepGetBuffer: (dataId: number) => GPUBuffer; + /** + * [exported from js_internal_api.js] Create a function to be used to create a GPU Tensor. + * + * @param gpuBuffer - specify the GPU buffer + * @param size - specify the original data size in byte. + * @param type - specify the tensor type. + * @returns the generated downloader function. + */ + jsepCreateDownloader: + (gpuBuffer: GPUBuffer, size: number, + type: Tensor.GpuBufferDataTypes) => () => Promise; + /** + * [exported from js_internal_api.js] Called when InferenceSession.run started. This function will be called before + * _OrtRun[WithBinding]() is called. + * @param sessionId - specify the session ID. + */ + jsepOnRunStart: (sessionId: number) => void; + /** + * [exported from js_internal_api.js] Release a session. This function will be called before _OrtReleaseSession() is + * called. + * @param sessionId - specify the session ID. + * @returns + */ + jsepOnReleaseSession: (sessionId: number) => void; + } } -export interface OrtWasmModule extends EmscriptenModule { - // #region emscripten functions - stackSave(): number; - stackRestore(stack: number): void; - stackAlloc(size: number): number; - - UTF8ToString(offset: number, maxBytesToRead?: number): string; - lengthBytesUTF8(str: string): number; - stringToUTF8(str: string, offset: number, maxBytes: number): void; - // #endregion - - // #region ORT APIs +export interface OrtInferenceAPIs { _OrtInit(numThreads: number, loggingLevel: number): number; _OrtGetLastError(errorCodeOffset: number, errorMessageOffset: number): void; @@ -71,122 +151,61 @@ export interface OrtWasmModule extends EmscriptenModule { _OrtReleaseRunOptions(runOptionsHandle: number): void; _OrtEndProfiling(sessionHandle: number): number; - // #endregion +} - // #region ORT Training APIs - _OrtTrainingLoadCheckpoint?(dataOffset: number, dataLength: number): number; +export interface OrtTrainingAPIs { + _OrtTrainingLoadCheckpoint(dataOffset: number, dataLength: number): number; - 
_OrtTrainingReleaseCheckpoint?(checkpointHandle: number): void; + _OrtTrainingReleaseCheckpoint(checkpointHandle: number): void; - _OrtTrainingCreateSession? - (sessionOptionsHandle: number, checkpointHandle: number, trainOffset: number, trainLength: number, - evalOffset: number, evalLength: number, optimizerOffset: number, optimizerLength: number): number; + _OrtTrainingCreateSession( + sessionOptionsHandle: number, checkpointHandle: number, trainOffset: number, trainLength: number, + evalOffset: number, evalLength: number, optimizerOffset: number, optimizerLength: number): number; - _OrtTrainingLazyResetGrad?(trainingHandle: number): number; + _OrtTrainingLazyResetGrad(trainingHandle: number): number; - _OrtTrainingRunTrainStep? - (trainingHandle: number, inputsOffset: number, inputCount: number, outputsOffset: number, outputCount: number, - runOptionsHandle: number): number; + _OrtTrainingRunTrainStep( + trainingHandle: number, inputsOffset: number, inputCount: number, outputsOffset: number, outputCount: number, + runOptionsHandle: number): number; - _OrtTrainingOptimizerStep?(trainingHandle: number, runOptionsHandle: number): number; + _OrtTrainingOptimizerStep(trainingHandle: number, runOptionsHandle: number): number; - _OrtTrainingEvalStep? - (trainingHandle: number, inputsOffset: number, inputCount: number, outputsOffset: number, outputCount: number, - runOptionsHandle: number): number; + _OrtTrainingEvalStep( + trainingHandle: number, inputsOffset: number, inputCount: number, outputsOffset: number, outputCount: number, + runOptionsHandle: number): number; - _OrtTrainingGetParametersSize?(trainingHandle: number, paramSizeT: number, trainableOnly: boolean): number; + _OrtTrainingGetParametersSize(trainingHandle: number, paramSizeT: number, trainableOnly: boolean): number; - _OrtTrainingCopyParametersToBuffer? 
- (trainingHandle: number, parametersBuffer: number, parameterCount: number, trainableOnly: boolean): number; + _OrtTrainingCopyParametersToBuffer( + trainingHandle: number, parametersBuffer: number, parameterCount: number, trainableOnly: boolean): number; - _OrtTrainingCopyParametersFromBuffer? - (trainingHandle: number, parametersBuffer: number, parameterCount: number, trainableOnly: boolean): number; + _OrtTrainingCopyParametersFromBuffer( + trainingHandle: number, parametersBuffer: number, parameterCount: number, trainableOnly: boolean): number; - _OrtTrainingGetModelInputOutputCount? - (trainingHandle: number, inputCount: number, outputCount: number, isEvalModel: boolean): number; - _OrtTrainingGetModelInputOutputName? - (trainingHandle: number, index: number, isInput: boolean, isEvalModel: boolean): number; + _OrtTrainingGetModelInputOutputCount( + trainingHandle: number, inputCount: number, outputCount: number, isEvalModel: boolean): number; + _OrtTrainingGetModelInputOutputName(trainingHandle: number, index: number, isInput: boolean, isEvalModel: boolean): + number; - _OrtTrainingReleaseSession?(trainingHandle: number): void; + _OrtTrainingReleaseSession(trainingHandle: number): void; +} + +export interface OrtWasmModule extends EmscriptenModule, OrtInferenceAPIs, Partial, + Partial { + // #region emscripten functions + stackSave(): number; + stackRestore(stack: number): void; + stackAlloc(size: number): number; + + UTF8ToString(offset: number, maxBytesToRead?: number): string; + lengthBytesUTF8(str: string): number; + stringToUTF8(str: string, offset: number, maxBytes: number): void; // #endregion // #region config numThreads?: number; mainScriptUrlOrBlob?: string|Blob; // #endregion - - // #region external data API - mountExternalData?(externalDataFilePath: string, externalDataFileData: Uint8Array): void; - unmountExternalData?(): void; - // #endregion - - // #region JSEP - /** - * This is the entry of JSEP initialization. 
This function is called once when initializing ONNX Runtime. - * This function initializes WebGPU backend and registers a few callbacks that will be called in C++ code. - */ - jsepInit? - (backend: JSEP.BackendType, alloc: JSEP.AllocFunction, free: JSEP.FreeFunction, upload: JSEP.UploadFunction, - download: JSEP.DownloadFunction, createKernel: JSEP.CreateKernelFunction, - releaseKernel: JSEP.ReleaseKernelFunction, run: JSEP.RunFunction): void; - - /** - * [exported from wasm] Specify a kernel's output when running OpKernel::Compute(). - * - * @param context - specify the kernel context pointer. - * @param index - specify the index of the output. - * @param data - specify the pointer to encoded data of type and dims. - */ - _JsepOutput(context: number, index: number, data: number): number; - /** - * [exported from wasm] Get name of an operator node. - * - * @param kernel - specify the kernel pointer. - * @returns the pointer to a C-style UTF8 encoded string representing the node name. - */ - _JsepGetNodeName(kernel: number): number; - - /** - * [exported from js_internal_api.js] Register a user GPU buffer for usage of a session's input or output. - * - * @param sessionId - specify the session ID. - * @param index - specify an integer to represent which input/output it is registering for. For input, it is the - * input_index corresponding to the session's inputNames. For output, it is the inputCount + output_index - * corresponding to the session's ouputNames. - * @param buffer - specify the GPU buffer to register. - * @param size - specify the original data size in byte. - * @returns the GPU data ID for the registered GPU buffer. - */ - jsepRegisterBuffer: (sessionId: number, index: number, buffer: GPUBuffer, size: number) => number; - /** - * [exported from js_internal_api.js] Unregister all user GPU buffers for a session. - * - * @param sessionId - specify the session ID. 
- */ - jsepUnregisterBuffers?: (sessionId: number) => void; - /** - * [exported from js_internal_api.js] Get the GPU buffer by GPU data ID. - * - * @param dataId - specify the GPU data ID - * @returns the GPU buffer. - */ - jsepGetBuffer: (dataId: number) => GPUBuffer; - /** - * [exported from js_internal_api.js] Create a function to be used to create a GPU Tensor. - * - * @param gpuBuffer - specify the GPU buffer - * @param size - specify the original data size in byte. - * @param type - specify the tensor type. - * @returns the generated downloader function. - */ - jsepCreateDownloader: - (gpuBuffer: GPUBuffer, size: number, - type: Tensor.GpuBufferDataTypes) => () => Promise; - /** - * [exported from js_internal_api.js] Called when InferenceSession.run started. - */ - jsepOnRunStart: () => void; - // #endregion } declare const moduleFactory: EmscriptenModuleFactory; diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index a48fe99570abf..d92b8ac68dbe7 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -3,14 +3,21 @@ import {Env, Tensor, TRACE, TRACE_FUNC_BEGIN, TRACE_FUNC_END} from 'onnxruntime-common'; -import {tensorDataTypeEnumToString} from '../wasm-common'; +import {DataType, tensorDataTypeEnumToString} from '../wasm-common'; import {configureLogger, LOG_DEBUG} from './log'; import {createView, TensorView} from './tensor-view'; import {createGpuDataManager, downloadGpuData, GpuDataManager} from './webgpu/gpu-data-manager'; import {RunFunction, WEBGPU_OP_RESOLVE_RULES} from './webgpu/op-resolve-rules'; import {ProgramManager} from './webgpu/program-manager'; -import {ComputeContext, GpuData, ProgramInfo, ProgramInputTensorInfoDependency, TimestampQuery} from './webgpu/types'; +import {AdapterInfo, ComputeContext, GpuArchitecture, GpuData, GpuVendor, ProgramInfo, ProgramInputTensorInfoDependency, SessionState, TimestampQuery} from './webgpu/types'; + +interface 
CommandInfo { + readonly kernelId: number; + readonly computePipeline: GPUComputePipeline; + readonly bindGroup: GPUBindGroup; + readonly dispatchGroup: [number, number, number]; +} interface KernelInfo { readonly kernelType: string; @@ -87,11 +94,32 @@ const getProgramInfoUniqueKey = return key; }; +class AdapterInfoImpl implements AdapterInfo { + readonly architecture?: string; + readonly vendor?: string; + + constructor(adapterInfo: GPUAdapterInfo) { + if (adapterInfo) { + this.architecture = adapterInfo.architecture; + this.vendor = adapterInfo.vendor; + } + } + + isArchitecture(architecture: GpuArchitecture): boolean { + return this.architecture === architecture; + } + + isVendor(vendor: GpuVendor): boolean { + return this.vendor === vendor; + } +} + /** * this class is designed to store status and being used as a singleton for JSEP. It will be passed to jsepInit() as * the first parameter so that it is stored for future use. */ export class WebGpuBackend { + adapterInfo: AdapterInfoImpl; device: GPUDevice; /** * an instance of GpuDataManager to manage a GpuDataId -> GpuBuffer mapping @@ -103,6 +131,13 @@ export class WebGpuBackend { */ programManager: ProgramManager; + /** + * representing the session ID of which is currently being run. + * `null` means no session is being run. + * only valid when session.run is executed. + */ + currentSessionId: number|null = null; + /** * representing the kernel ID of which is currently being computed (CPU code perspective). * `null` means no kernel is being computed. @@ -155,6 +190,16 @@ export class WebGpuBackend { queryType: TimestampQuery; env: Env; + sessionStatus: SessionState = 'default'; + /** + * a SessionID -> CommandInfo[] mapping. It's used to record all GPU commands for corresponding session. + */ + capturedCommandList: Map = new Map(); + + /** + * a SessionID -> PendingKernelInfo[] mapping for profiling. 
+ */ + private capturedPendingKernels: Map = new Map(); /** * a SessionID -> a Map of (InputOutputIndex -> [ID, GPUBuffer]) mapping. @@ -188,6 +233,7 @@ export class WebGpuBackend { } this.device = await adapter.requestDevice(deviceDescriptor); + this.adapterInfo = new AdapterInfoImpl(await adapter.requestAdapterInfo()); this.gpuDataManager = createGpuDataManager(this); this.programManager = new ProgramManager(this); this.kernels = new Map(); @@ -207,6 +253,7 @@ export class WebGpuBackend { }; Object.defineProperty(this.env.webgpu, 'device', {value: this.device}); + Object.defineProperty(this.env.webgpu, 'adapter', {value: adapter}); // init queryType, which is necessary for InferenceSession.create this.setQueryType(); @@ -228,6 +275,7 @@ export class WebGpuBackend { getComputePassEncoder(): GPUComputePassEncoder { if (!this.computePassEncoder) { + const commandEncoder = this.getCommandEncoder(); const computePassDescriptor: GPUComputePassDescriptor = {}; if (this.queryType === 'at-passes') { @@ -238,7 +286,7 @@ export class WebGpuBackend { }; } - this.computePassEncoder = this.getCommandEncoder().beginComputePass(computePassDescriptor); + this.computePassEncoder = commandEncoder.beginComputePass(computePassDescriptor); } return this.computePassEncoder; } @@ -360,11 +408,16 @@ export class WebGpuBackend { // create info for inputs const inputDatas: GpuData[] = []; for (let i = 0; i < inputTensorViews.length; ++i) { - const gpuData = this.gpuDataManager.get(inputTensorViews[i].data); + const data = inputTensorViews[i].data; + // if tensor view data is 0, it means the input is a zero-sized tensor, and there is no GPU data for it. 
+ if (data === 0) { + continue; + } + const gpuData = this.gpuDataManager.get(data); if (!gpuData) { - throw new Error(`no GPU data for input: ${inputTensorViews[i].data}`); + throw new Error(`no GPU data for input: ${data}`); } - inputDatas[i] = gpuData; + inputDatas.push(gpuData); } const {outputs, dispatchGroup, programUniforms} = program.getRunData(inputTensorViews); @@ -394,6 +447,11 @@ export class WebGpuBackend { const tensorView = (isTemporary || isPersistent) ? createIntermediateOutput(outputs[i].dataType, outputs[i].dims) : createKernelOutput(validatedOutputIndices[i], outputs[i].dataType, outputs[i].dims); + outputTensorViews.push(tensorView); + // if tensor view data is 0, it means the output is zero-sized tensor, and there is no GPU data for it. + if (tensorView.data === 0) { + continue; + } const gpuData = this.gpuDataManager.get(tensorView.data); if (!gpuData) { throw new Error(`no GPU data for output: ${tensorView.data}`); @@ -409,10 +467,24 @@ export class WebGpuBackend { } persistentData.push(gpuData); } - outputTensorViews.push(tensorView); outputDatas.push(gpuData); } + // when there are any zero-sized tensor in the inputs or outputs, we should report error unless all outputs are + // zero-sized tensors. + if (inputDatas.length !== inputTensorViews.length || outputDatas.length !== outputTensorViews.length) { + // if all outputs are zero-sized tensors, there is no need to run the program. + if (outputDatas.length === 0) { + TRACE_FUNC_END(program.name); + return outputTensorViews; + } + // if some outputs are zero-sized tensors, report an error. + // + // TODO: so far we don't see any use case that outputs include both zero-sized tensors and non-zero-sized tensors. + // If we see such use case, we need to make a change here to support it. + throw new Error( + `Program ${program.name} has zero-sized tensor(s) in inputs or outputs. This is not supported now.`); + } // load uniforms // TODO: add cache for uniform (is it necessary?) 
@@ -428,10 +500,10 @@ export class WebGpuBackend { return; } // https://www.w3.org/TR/WGSL/#alignof - const sizeOfElement = v.type === 'float16' ? 2 : 4; + const sizeOfElement = v.type === DataType.float16 ? 2 : 4; let sizeOfVecOrMat; let baseAlignment; - if (v.type === 'float16') { + if (v.type === DataType.float16) { baseAlignment = data.length > 4 ? 16 : (data.length > 2 ? 8 : data.length * sizeOfElement); sizeOfVecOrMat = data.length > 4 ? 16 : sizeOfElement * data.length; } else { @@ -445,7 +517,7 @@ export class WebGpuBackend { // SizeOf(vec4). For float16 type, when data.length > 4, the uniform variable is of type // array,N>, where N = Math.ceil(data.length / 8) and SizeOf(mat2x4) = 16. The total byte // length is N * SizeOf(mat2x4). - const elementPerVecOrMat = v.type === 'float16' ? 8 : 4; + const elementPerVecOrMat = v.type === DataType.float16 ? 8 : 4; currentOffset += data.length > 4 ? Math.ceil(data.length / elementPerVecOrMat) * sizeOfVecOrMat : data.length * sizeOfElement; }); @@ -458,15 +530,17 @@ export class WebGpuBackend { programUniforms.forEach((v, i) => { const offset = offsets[i]; const data = typeof v.data === 'number' ? [v.data] : v.data; - if (v.type === 'int32') { + if (v.type === DataType.int32) { new Int32Array(arrayBuffer, offset, data.length).set(data); - } else if (v.type === 'uint32') { + } else if (v.type === DataType.uint32) { new Uint32Array(arrayBuffer, offset, data.length).set(data); - } else if (v.type === 'float16') { + } else if (v.type === DataType.float16) { // TODO: use Float16Array. 
new Uint16Array(arrayBuffer, offset, data.length).set(data); - } else { + } else if (v.type === DataType.float) { new Float32Array(arrayBuffer, offset, data.length).set(data); + } else { + throw new Error(`Unsupported uniform type: ${tensorDataTypeEnumToString(v.type)}`); } }); @@ -494,7 +568,7 @@ export class WebGpuBackend { () => `[ProgramManager] run "${program.name}" (key=${key}) with ${normalizedDispatchGroup[0]}x${ normalizedDispatchGroup[1]}x${normalizedDispatchGroup[2]}`); - if (this.queryType !== 'none') { + if (this.queryType !== 'none' || this.sessionStatus === 'capturing') { const pendingKernelInfo: PendingKernelInfo = { kernelId: this.currentKernelId!, programName: artifact.programInfo.name, @@ -502,6 +576,11 @@ export class WebGpuBackend { outputTensorViews, }; this.pendingKernels.push(pendingKernelInfo); + + if (this.sessionStatus === 'capturing') { + const sessionPendingKernels = this.capturedPendingKernels.get(this.currentSessionId!); + sessionPendingKernels!.push(pendingKernelInfo); + } } this.programManager.run(artifact, inputDatas, outputDatas, normalizedDispatchGroup, uniformBufferBinding); @@ -654,7 +733,8 @@ export class WebGpuBackend { } setQueryType(): void { this.queryType = 'none'; - if (this.env.webgpu.profiling?.mode === 'default' || this.env.wasm.trace) { + if (this.env.webgpu.profiling?.mode === 'default' || + (typeof this.env.trace === 'undefined' ? 
this.env.wasm.trace : this.env.trace)) { if (this.device.features.has('chromium-experimental-timestamp-query-inside-passes')) { this.queryType = 'inside-passes'; } else if (this.device.features.has('timestamp-query')) { @@ -672,7 +752,69 @@ export class WebGpuBackend { } } } - onRunStart(): void { + + captureBegin(): void { + LOG_DEBUG('info', 'captureBegin'); + if (!this.capturedCommandList.get(this.currentSessionId!)) { + this.capturedCommandList.set(this.currentSessionId!, []); + } + if (!this.capturedPendingKernels.get(this.currentSessionId!)) { + this.capturedPendingKernels.set(this.currentSessionId!, []); + } + // flush the left commands before we change the status. + this.flush(); + this.sessionStatus = 'capturing'; + } + captureEnd(): void { + LOG_DEBUG('info', 'captureEnd'); + // flush the left commands before we change the status. + this.flush(); + this.sessionStatus = 'default'; + } + replay(): void { + LOG_DEBUG('info', 'replay'); + this.sessionStatus = 'replaying'; + const sessionCommandList = this.capturedCommandList.get(this.currentSessionId!); + const sessionPendingKernels = this.capturedPendingKernels.get(this.currentSessionId!); + const length = sessionCommandList!.length; + this.pendingKernels = []; + for (let i = 0; i < length; i++) { + const computePassEncoder = this.getComputePassEncoder(); + const command = sessionCommandList![i]; + this.writeTimestamp(this.pendingDispatchNumber * 2); + computePassEncoder.setPipeline(command.computePipeline); + computePassEncoder.setBindGroup(0, command.bindGroup); + computePassEncoder.dispatchWorkgroups(...command.dispatchGroup); + this.writeTimestamp(this.pendingDispatchNumber * 2 + 1); + this.pendingDispatchNumber++; + if (this.queryType !== 'none') { + this.pendingKernels.push(sessionPendingKernels![i]); + } + if (this.pendingDispatchNumber >= this.maxDispatchNumber || this.queryType === 'at-passes') { + this.endComputePass(); + } + if (this.pendingDispatchNumber >= this.maxDispatchNumber) { + 
this.flush(); + } + } + // flush the left commands before we change the status. + this.flush(); + this.sessionStatus = 'default'; + } + + onReleaseSession(sessionId: number): void { + this.unregisterBuffers(sessionId); + if (this.capturedCommandList.has(sessionId)) { + this.capturedCommandList.delete(sessionId); + } + if (this.capturedPendingKernels.has(sessionId)) { + this.capturedPendingKernels.delete(sessionId); + } + this.gpuDataManager.onReleaseSession(sessionId); + } + + onRunStart(sessionId: number): void { + this.currentSessionId = sessionId; this.setQueryType(); } } diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index f1794d71579bf..adcaa145cdca8 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -10,7 +10,7 @@ import {WebGpuBackend} from './backend-webgpu'; import {LOG_DEBUG} from './log'; import {TensorView} from './tensor-view'; import {ShapeUtil} from './util'; -import {ComputeContext, ComputeContextInputsOutputsMapping, ProgramInfo} from './webgpu/types'; +import {AdapterInfo, ComputeContext, ComputeContextInputsOutputsMapping, ProgramInfo} from './webgpu/types'; /* eslint-disable no-bitwise */ @@ -54,6 +54,7 @@ class TensorViewImpl implements TensorView { } class ComputeContextImpl implements ComputeContext { + readonly adapterInfo: AdapterInfo; readonly opKernelContext: number; readonly inputs: readonly TensorView[]; readonly outputCount: number; @@ -66,6 +67,7 @@ class ComputeContextImpl implements ComputeContext { private customDataOffset = 0; private customDataSize = 0; constructor(private module: OrtWasmModule, private backend: WebGpuBackend, contextDataOffset: number) { + this.adapterInfo = backend.adapterInfo; const heapU32 = module.HEAPU32; // extract context data @@ -104,7 +106,8 @@ class ComputeContextImpl implements ComputeContext { throw new Error(`Unsupported data type: ${dataType}`); } const bufferSize = elementSize * ShapeUtil.size(dims); - return new TensorViewImpl(this.module, 
dataType, this.backend.gpuDataManager.create(bufferSize).id, dims); + const gpuDataId = bufferSize > 0 ? this.backend.gpuDataManager.create(bufferSize).id : 0; + return new TensorViewImpl(this.module, dataType, gpuDataId, dims); }; return this.backend.run(program, mappedInputs, outputIndices, createKernelOutput, createTemporaryOutput); } @@ -118,7 +121,7 @@ class ComputeContextImpl implements ComputeContext { for (let i = 0; i < dims.length; i++) { this.module.HEAPU32[offset++] = dims[i]; } - return this.module._JsepOutput(this.opKernelContext, index, data); + return this.module._JsepOutput!(this.opKernelContext, index, data); } catch (e) { throw new Error( `Failed to generate kernel's output[${index}] with dims [${dims}]. ` + @@ -133,27 +136,39 @@ class ComputeContextImpl implements ComputeContext { /** * Initialize JSEP with WebGPU backend. * - * This function will be called only once after the WebAssembly module is loaded and initialized ("_OrtInit" is called). - * This function expects: + * This function will be called after the WebAssembly module is loaded and initialized ("_OrtInit" is called), once for + * each of the following EPs if they are specified: + * - "webgpu" + * - "webnn" + * + * For WebGPU, this function expects: * - WebGPU is enabled in build (BUILD_DEFS.DISABLE_WEBGPU === false). * - WebGPU is available in current environment. (a valid GPUAdapter is passed in) + * + * For WebNN, this function expects: + * - WebNN is enabled in build (BUILD_DEFS.DISABLE_WEBGPU === false). + * - WebNN is available in current environment. (navigator.ml is not undefined) + * * If the WebAssembly module is not built with JSEP support, this function will throw an error. This will invalidate - * 'webgpu' backend. + * 'webgpu'/'webnn' backend. 
* + * @param name - the name of the EP, either "webgpu" or "webnn" * @param module - the ORT WebAssembly module * @param env - the ORT environment variable (ort.env) * @param gpuAdapter - the pre-created GPU adapter */ -export const init = async(module: OrtWasmModule, env: Env, gpuAdapter: GPUAdapter): Promise => { +export const init = + async(name: 'webgpu'|'webnn', module: OrtWasmModule, env: Env, gpuAdapter?: GPUAdapter): Promise => { const jsepInit = module.jsepInit; if (!jsepInit) { throw new Error('Failed to initialize JSEP. The WebAssembly module is not built with JSEP support.'); } - const backend = new WebGpuBackend(); - await backend.initialize(env, gpuAdapter); + if (name === 'webgpu') { + const backend = new WebGpuBackend(); + await backend.initialize(env, gpuAdapter!); - jsepInit( + jsepInit('webgpu', [ // backend backend, @@ -187,8 +202,8 @@ export const init = async(module: OrtWasmModule, env: Env, gpuAdapter: GPUAdapte }, // jsepCreateKernel - (kernelType: string, kernelId: number, attribute: unknown) => - backend.createKernel(kernelType, kernelId, attribute, module.UTF8ToString(module._JsepGetNodeName(kernelId))), + (kernelType: string, kernelId: number, attribute: unknown) => backend.createKernel( + kernelType, kernelId, attribute, module.UTF8ToString(module._JsepGetNodeName!(kernelId))), // jsepReleaseKernel (kernel: number) => backend.releaseKernel(kernel), @@ -201,5 +216,15 @@ export const init = async(module: OrtWasmModule, env: Env, gpuAdapter: GPUAdapte contextDataOffset}`); const context = new ComputeContextImpl(module, backend, contextDataOffset); return backend.computeKernel(kernel, context, errors); - }); + }, + // jsepCaptureBegin + () => backend.captureBegin(), + // jsepCaptureEnd + () => backend.captureEnd(), + // jsepReplay + () => backend.replay() + ]); + } else { + jsepInit('webnn'); + } }; diff --git a/js/web/lib/wasm/jsep/util.ts b/js/web/lib/wasm/jsep/util.ts index 6922d7ff5df6e..9a1d5463f7843 100644 --- 
a/js/web/lib/wasm/jsep/util.ts +++ b/js/web/lib/wasm/jsep/util.ts @@ -56,7 +56,16 @@ export class BroadcastUtil { if (aLen !== bLen && aLen > 1 && bLen > 1) { return undefined; } - cdims[crank - i] = Math.max(aLen, bLen); + const max = Math.max(aLen, bLen); + if (aLen && bLen) { + cdims[crank - i] = Math.max(aLen, bLen); + } else { + // when either aLen or bLen is 0, the other should be either 0 or 1, otherwise it is not broadcastable. + if (max > 1) { + return undefined; + } + cdims[crank - i] = 0; + } } return cdims; @@ -92,6 +101,34 @@ export class ShapeUtil { return ShapeUtil.getSizeFromDimensionRange(dims, 0, dims.length); } + /** + * convert dims corresponding to type change to pack. ex. uint8 data to uint32 + */ + static convertShape(dims: readonly number[], size = 4): readonly number[] { + const rank = dims.length; + if (rank === 0) { + return []; + } + const newDims = new Array(rank); + let i = rank - 1; + while (i >= 0) { + if (dims[i] % size === 0) { + newDims[i] = dims[i] / size; + break; + } + if (size % dims[i] !== 0) { + throw new Error('cannot convert shape'); + } + newDims[i] = 1; + size /= dims[i]; + i--; + } + for (i--; i >= 0; i--) { + newDims[i] = dims[i]; + } + return newDims; + } + /** * calculate the size (number of elements) from the given axis (inclusive) */ diff --git a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts index 6f3d9a52d9f5d..c17bd1e1477ec 100644 --- a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts @@ -60,9 +60,15 @@ export interface GpuDataManager { unregisterExternalBuffer(buffer: GPUBuffer): void; /** - * destroy all gpu buffers. Call this when the session.release is called. + * destroy all gpu buffers. */ dispose(): void; + + /** + * release session related data. + * @param sessionId - specify the session ID. 
+ */ + onReleaseSession(sessionId: number): void; } interface StorageCacheValue { @@ -139,6 +145,10 @@ class GpuDataManagerImpl implements GpuDataManager { // The external buffers registered users for IO Binding. private externalBuffers: Map; + // The pendingBuffers for capture graph. + // a SessionID -> GPUBuffer[] mapping. + private capturedPendingBuffers: Map; + constructor(private backend: WebGpuBackend) { this.storageCache = new Map(); this.freeBuffers = new Map(); @@ -146,6 +156,7 @@ class GpuDataManagerImpl implements GpuDataManager { this.buffersForUploadingPending = []; this.buffersPending = []; this.externalBuffers = new Map(); + this.capturedPendingBuffers = new Map(); } upload(id: GpuDataId, data: Uint8Array): void { @@ -220,6 +231,9 @@ class GpuDataManagerImpl implements GpuDataManager { () => `[WebGPU] GpuDataManager.registerExternalBuffer(size=${originalSize}) => id=${ id}, buffer is the same, skip.`); return id; + } else if (this.backend.capturedCommandList.has(this.backend.currentSessionId!)) { + throw new Error(`Registering a different external buffer under graph capture mode is not supported yet. + Please use the previous external buffer!`); } this.externalBuffers.delete(previousBuffer); } else { @@ -312,20 +326,39 @@ class GpuDataManagerImpl implements GpuDataManager { buffer.destroy(); } this.buffersForUploadingPending = []; - for (const buffer of this.buffersPending) { - // eslint-disable-next-line no-bitwise - if ((buffer.usage & GPUBufferUsage.STORAGE) === GPUBufferUsage.STORAGE) { - // Put the pending buffer to freeBuffers list instead of really destroying it for buffer reusing. 
- this.freeBuffers.get(buffer.size)!.push(buffer); + + if (this.buffersPending.length === 0) { + return; + } + + if (this.backend.sessionStatus === 'default') { + for (const buffer of this.buffersPending) { // eslint-disable-next-line no-bitwise - } else if ((buffer.usage & GPUBufferUsage.UNIFORM) === GPUBufferUsage.UNIFORM) { - // Put the pending buffer to freeUniformBuffers list instead of really destroying it for buffer reusing. - this.freeUniformBuffers.get(buffer.size)!.push(buffer); - } else { - buffer.destroy(); + if ((buffer.usage & GPUBufferUsage.STORAGE) === GPUBufferUsage.STORAGE) { + // Put the pending buffer to freeBuffers list instead of really destroying it for buffer reusing. + this.freeBuffers.get(buffer.size)!.push(buffer); + // eslint-disable-next-line no-bitwise + } else if ((buffer.usage & GPUBufferUsage.UNIFORM) === GPUBufferUsage.UNIFORM) { + // Put the pending buffer to freeUniformBuffers list instead of really destroying it for buffer reusing. + this.freeUniformBuffers.get(buffer.size)!.push(buffer); + } else { + buffer.destroy(); + } + } + this.buffersPending = []; + } else { + // Don't release intermediate tensors in non-default mode. + // TODO: reuse the storage buffers in non-default mode. 
+ let capturedBuffers = this.capturedPendingBuffers.get(this.backend.currentSessionId!); + if (!capturedBuffers) { + capturedBuffers = []; + this.capturedPendingBuffers.set(this.backend.currentSessionId!, capturedBuffers); } + for (const buffer of this.buffersPending) { + capturedBuffers.push(buffer); + } + this.buffersPending = []; } - this.buffersPending = []; } dispose() { @@ -344,9 +377,26 @@ class GpuDataManagerImpl implements GpuDataManager { storage.gpuData.buffer.destroy(); }); + this.capturedPendingBuffers.forEach((buffers) => { + buffers.forEach(buffer => { + buffer.destroy(); + }); + }); this.storageCache = new Map(); this.freeBuffers = new Map(); this.freeUniformBuffers = new Map(); + this.capturedPendingBuffers = new Map(); + } + + onReleaseSession(sessionId: number) { + // release the captured pending buffers. + const pendingBuffers = this.capturedPendingBuffers.get(sessionId); + if (pendingBuffers) { + pendingBuffers.forEach(buffer => { + buffer.destroy(); + }); + this.capturedPendingBuffers.delete(sessionId); + } } } diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index d737a28654220..ba874c8dd0f80 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -13,12 +13,14 @@ import {convTranspose, parseConvTransposeAttributes} from './ops/conv-transpose' import {cumsum, parseCumSumAttributes} from './ops/cumsum'; import {einsum, parseEinsumAttributes} from './ops/einsum'; import {expand} from './ops/expand'; +import {fastGelu} from './ops/fast-gelu'; import {gather, parseGatherAttributes} from './ops/gather'; import {gatherElements, parseGatherElementsAttributes} from './ops/gather-elements'; import {gemm, parseGemmAttributes} from './ops/gemm'; import {instanceNorm} from './ops/instance-norm'; import {layerNorm} from './ops/layer-norm'; import {matMul} from './ops/matmul'; +import {matMulNBits, parseMatMulNBitsAttributes} from 
'./ops/matmulnbits'; import {multiHeadAttention, parseMultiHeadAttentionAttributes} from './ops/multi-head-attentiion'; import {pad} from './ops/pad'; import * as pool from './ops/pool'; @@ -72,6 +74,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Erf', [unaryOps.erf]], ['Exp', [unaryOps.exp]], ['Expand', [expand]], + ['FastGelu', [fastGelu]], ['Floor', [unaryOps.floor]], ['FusedConv', [conv, parseConvAttributes]], ['Gather', [gather, parseGatherAttributes]], @@ -90,6 +93,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['LessOrEqual', [binaryOps.lessOrEqual]], ['Log', [unaryOps.log]], ['MatMul', [matMul]], + ['MatMulNBits', [matMulNBits, parseMatMulNBitsAttributes]], // TODO: support new attributes for MaxPool-8 and MaxPool-10 ['MaxPool', [pool.maxPool, pool.parseMaxPoolAttributes]], ['Mul', [binaryOps.mul]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts index 1a03621512888..24006d393592a 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts @@ -19,12 +19,13 @@ // // modified to fit the needs of the project +import {DataType} from '../../../../wasm-common'; import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common'; import {ConvAttributes} from '../conv'; -import {getActivationSnippet} from '../fuse-utils'; +import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet} from '../fuse-utils'; import {biasSnippet, typeSnippet} from './activation_util'; import {utilFunctions} from './conv_util'; @@ -129,7 +130,7 @@ const conv2dCommonSnippet = isChannelsLast ? 
typeSnippet(innerElementSizeX, dataType) : typeSnippet(innerElementSizeW, dataType); const bType = isChannelsLast ? typeSnippet(innerElementSizeW, dataType) : typeSnippet(innerElementSizeX, dataType); - const applyActivation = getActivationSnippet(attributes, resType); + const applyActivation = getActivationSnippet(attributes, resType, dataType); const userCode = ` fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${aType} { ${isChannelsLast ? sampleX : sampleW} @@ -189,16 +190,12 @@ export const createConv2DMatMulProgramInfo = const elementsSize = isVec4 ? [innerElementSize, 4, 4] : [1, 1, 1]; const programUniforms: ProgramUniform[] = [ - {type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}, - {type: 'int32', data: [attributes.pads[0], attributes.pads[1]]}, {type: 'int32', data: attributes.strides}, - {type: 'int32', data: attributes.dilations} + {type: DataType.int32, data: dimAOuter}, {type: DataType.int32, data: dimBOuter}, + {type: DataType.int32, data: dimInner}, {type: DataType.int32, data: [attributes.pads[0], attributes.pads[1]]}, + {type: DataType.int32, data: attributes.strides}, {type: DataType.int32, data: attributes.dilations} ]; - if (attributes.activation === 'Clip') { - programUniforms.push( - {type: 'float32', data: attributes.clipMax!}, {type: 'float32', data: attributes.clipMin!}); - } - programUniforms.push( - ...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims)); + appendActivationUniformsData(attributes, programUniforms); + programUniforms.push(...createTensorShapeVariables(inputs[0].dims, inputs[1].dims)); const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; if (hasBias) { programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); @@ -212,9 +209,7 @@ export const createConv2DMatMulProgramInfo = {name: 'pad', type: 'i32', length: 2}, {name: 'stride', type: 'i32', length: 2}, {name: 'dilation', type: 'i32', length: 2} ]; 
- if (attributes.activation === 'Clip') { - uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'}); - } + appendActivationUniforms(attributes, uniforms); // TODO: support component 2, 3. const components = isVec4 ? 4 : 1; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts index 33e50a9a39cb9..11c8778b72335 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -19,20 +19,21 @@ // // modified to fit the needs of the project +import {DataType} from '../../../../wasm-common'; import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; -import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from '../common'; +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common'; import {ConvTransposeAttributes} from '../conv-transpose'; -import {getActivationSnippet} from '../fuse-utils'; +import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet} from '../fuse-utils'; -import {biasSnippet, typeSnippet} from './activation_util'; +import {biasSnippet} from './activation_util'; import {utilFunctions} from './conv_util'; import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packed_webgpu'; const conv2dTransposeCommonSnippet = - (isChannelsLast: boolean, addBias = false, attributes: ConvTransposeAttributes, innerElementSize = 4): string => { - const type = typeSnippet(innerElementSize, 'f32'); + (isChannelsLast: boolean, addBias = false, attributes: ConvTransposeAttributes, type: string, + innerElementSize = 4): string => { const getWSnippet = (innerElementSize: 
number) => { switch (innerElementSize) { case 1: @@ -46,7 +47,7 @@ const conv2dTransposeCommonSnippet = let v1 = w[getIndexFromCoords4D(coord1, vec4(uniforms.w_shape))]; let v2 = w[getIndexFromCoords4D(coord2, vec4(uniforms.w_shape))]; let v3 = w[getIndexFromCoords4D(coord3, vec4(uniforms.w_shape))]; - return vec4(v0, v1, v2, v3); + return ${type}(v0, v1, v2, v3); `; default: throw new Error(`innerElementSize ${innerElementSize} is not supported.`); @@ -197,16 +198,13 @@ export const createConv2DTransposeMatMulProgramInfo = ]; const programUniforms: ProgramUniform[] = [ - {type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}, - {type: 'int32', data: attributes.strides}, {type: 'int32', data: attributes.dilations}, - {type: 'int32', data: filterDims}, {type: 'int32', data: pads} + {type: DataType.int32, data: dimAOuter}, {type: DataType.int32, data: dimBOuter}, + {type: DataType.int32, data: dimInner}, {type: DataType.int32, data: attributes.strides}, + {type: DataType.int32, data: attributes.dilations}, {type: DataType.int32, data: filterDims}, + {type: DataType.int32, data: pads} ]; - if (attributes.activation === 'Clip') { - programUniforms.push( - {type: 'float32', data: attributes.clipMax!}, {type: 'float32', data: attributes.clipMin!}); - } - programUniforms.push( - ...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims)); + appendActivationUniformsData(attributes, programUniforms); + programUniforms.push(...createTensorShapeVariables(inputs[0].dims, inputs[1].dims)); const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; if (hasBias) { @@ -226,7 +224,7 @@ export const createConv2DTransposeMatMulProgramInfo = const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components); inputVariables.push(bias); declareFunctions += ` - fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? 
'vec4' : 'f32'} { + fn getBiasByOutputCoords(coords : vec4) -> ${bias.type.value} { return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}]; }`; } @@ -237,19 +235,21 @@ export const createConv2DTransposeMatMulProgramInfo = {name: 'filter_dims', type: 'i32', length: filterDims.length}, {name: 'pads', type: 'i32', length: pads.length} ]; - if (attributes.activation === 'Clip') { - uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'}); + appendActivationUniforms(attributes, uniforms); + const elemType = tensorTypeToWsglStorageType(inputs[0].dataType, 1); + if (elemType !== 'f16' && elemType !== 'f32') { + throw new Error(`elemType ${elemType} is not supported.`); } return ` ${utilFunctions('uniforms.result_strides')} ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)}; ${declareFunctions} - ${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, attributes, innerElementSize)} + ${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, attributes, x.type.value, innerElementSize)} ${ isVec4 ? 
makeMatMulPackedVec4Source( - elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner) : + elementsPerThread, workGroupSize, elemType, undefined, !isChannelsLast, tileInner) : makeMatMulPackedSource( - elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner, false, + elementsPerThread, workGroupSize, elemType, undefined, !isChannelsLast, tileInner, false, undefined, sequentialAccessByThreads)}`; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts index 380efc8bc577a..846ad49c5222b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts @@ -17,6 +17,7 @@ // sampled from [@tensorflow/tfjs] tfjs-backend-webgpu/src/conv_backprop_webgpu.ts +import {DataType} from '../../../../wasm-common'; import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; @@ -264,10 +265,11 @@ export const createConvTranspose2DProgramInfo = const outputChannelsPerGroup = wShape[1]; const programUniforms: ProgramUniform[] = [ - {type: 'int32', data: outputSize}, {type: 'uint32', data: strides}, {type: 'uint32', data: filterDims}, - {type: 'uint32', data: dilations}, {type: 'uint32', data: effectiveFilterDims}, {type: 'int32', data: pads}, - {type: 'uint32', data: inputChannelsPerGroup}, {type: 'uint32', data: outputChannelsPerGroup}, - ...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims) + {type: DataType.int32, data: outputSize}, {type: DataType.uint32, data: strides}, + {type: DataType.uint32, data: filterDims}, {type: DataType.uint32, data: dilations}, + {type: DataType.uint32, data: effectiveFilterDims}, {type: DataType.int32, data: pads}, + {type: DataType.uint32, data: inputChannelsPerGroup}, {type: DataType.uint32, data: outputChannelsPerGroup}, 
+ ...createTensorShapeVariables(inputs[0].dims, inputs[1].dims) ]; if (hasBias) { programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts index ee71110245252..29c7941e6bd30 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts @@ -19,11 +19,12 @@ // // modified to fit the needs of the project +import {DataType} from '../../../../wasm-common'; import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; import {createTensorShapeVariables, getBroadcastDims, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common'; -import {getActivationSnippet, InternalActivationAttributes} from '../fuse-utils'; +import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet, InternalActivationAttributes} from '../fuse-utils'; import {typeSnippet} from './activation_util'; @@ -443,20 +444,16 @@ export const createMatmulProgramInfo = const components = isVec4 ? 
4 : 1; const aShapeTemp = [...outerDimsA, dimAOuter, dimInner / components]; - const aShapeOrRank = aShapeTemp.length; + const aRank = aShapeTemp.length; const bShapeTemp = [...outerDimsB, dimInner, dimBOuter / components]; - const bShapeOrRank = bShapeTemp.length; + const bRank = bShapeTemp.length; const outputShapeTemp = [batchSize, dimAOuter, dimBOuter / components]; - const programUniforms: ProgramUniform[] = - [{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}]; - if (activationAttributes.activation === 'Clip') { - programUniforms.push( - {type: 'float32', data: activationAttributes.clipMax!}, - {type: 'float32', data: activationAttributes.clipMin!}); - } - programUniforms.push( - ...createTensorShapeVariables(outerDims), ...createTensorShapeVariables(aShapeTemp), - ...createTensorShapeVariables(bShapeTemp)); + const programUniforms: ProgramUniform[] = [ + {type: DataType.int32, data: dimAOuter}, {type: DataType.int32, data: dimBOuter}, + {type: DataType.int32, data: dimInner} + ]; + appendActivationUniformsData(activationAttributes, programUniforms); + programUniforms.push(...createTensorShapeVariables(outerDims, aShapeTemp, bShapeTemp)); const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; const hasBias = inputs.length > 2; @@ -467,12 +464,12 @@ export const createMatmulProgramInfo = programUniforms.push(...createTensorShapeVariables(outputShapeTemp)); const getShaderSource = (shaderHelper: ShaderHelper) => { - const batchShapeOrRank = outerDims.length; - const batchDims = internalVariable('batchDims', inputs[0].dataType, batchShapeOrRank, 1); + const batchRank = outerDims.length; + const batchDims = internalVariable('batchDims', inputs[0].dataType, batchRank, 1); const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); - const A = inputVariable('a', inputs[0].dataType, aShapeOrRank, components); - const B = inputVariable('b', inputs[1].dataType, bShapeOrRank, components); 
+ const A = inputVariable('a', inputs[0].dataType, aRank, components); + const B = inputVariable('b', inputs[1].dataType, bRank, components); const output = outputVariable('result', inputs[0].dataType, outputShapeTemp.length, components); const inputVariables = [A, B]; if (hasBias) { @@ -481,10 +478,9 @@ export const createMatmulProgramInfo = } const uniforms: UniformsArrayType = [{name: 'dim_a_outer', type: 'i32'}, {name: 'dim_b_outer', type: 'i32'}, {name: 'dim_inner', type: 'i32'}]; - if (activationAttributes.activation === 'Clip') { - uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'}); - } - const applyActivation = getActivationSnippet(activationAttributes, output.type.value); + appendActivationUniforms(activationAttributes, uniforms); + const baseType = tensorTypeToWsglStorageType(output.type.tensor); + const applyActivation = getActivationSnippet(activationAttributes, output.type.value, baseType); const declareFunctions = matMulReadWriteFnSource( components, hasBias, applyActivation, [batchDims, A, B, output], [outerDimsA, outerDimsB, outerDims], isChannelsLast); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts index ef8038dff487e..2cfe6356dd6e7 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {tensorDataTypeEnumToString} from '../../../wasm-common'; +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ComputeContext, GpuDataType, ProgramUniform} from '../types'; @@ -241,9 +241,10 @@ export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView WG = Math.ceil(dComp / 8); } const elementsPerWG = Math.ceil(d / components / WG); - const tensorDataType = tensorDataTypeEnumToString(input.dataType) as ProgramUniform['type']; - const programUniforms: ProgramUniform[] = - [{type: tensorDataType, data: 1 / d}, {type: 'uint32', data: dComp}, {type: 'uint32', data: elementsPerWG}]; + const programUniforms: ProgramUniform[] = [ + {type: input.dataType, data: 1 / d}, {type: DataType.uint32, data: dComp}, + {type: DataType.uint32, data: elementsPerWG} + ]; const dataType = tensorTypeToWsglStorageType(input.dataType, components); const getShaderSource = (shaderHelper: ShaderHelper) => { @@ -297,7 +298,7 @@ export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView if (sum == 0) { for (var i: u32 = 0; i < uniforms.elements_per_wg && i + localOffset < uniforms.d_comp; i++) { - x[offset + i] = ${fillVector('f32', components, 'uniforms.d_inv')}; + x[offset + i] = ${fillVector(elemValueType, components, 'uniforms.d_inv')}; } } else { for (var i: u32 = 0; i < uniforms.elements_per_wg && i + localOffset < uniforms.d_comp; i++) { @@ -336,11 +337,10 @@ const computeAttentionProbs = y: Math.ceil(parameters.sequenceLength / TILE_SIZE), z: parameters.batchSize * parameters.numHeads }; - const tensorDataType = tensorDataTypeEnumToString(q.dataType) as ProgramUniform['type']; const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: parameters.sequenceLength}, {type: 'uint32', data: vectorizedHeadSize}, - {type: 'uint32', data: parameters.totalSequenceLength}, {type: 'uint32', data: parameters.kvSequenceLength}, - {type: tensorDataType, data: alpha} + {type: DataType.uint32, 
data: parameters.sequenceLength}, {type: DataType.uint32, data: vectorizedHeadSize}, + {type: DataType.uint32, data: parameters.totalSequenceLength}, + {type: DataType.uint32, data: parameters.kvSequenceLength}, {type: q.dataType, data: alpha} ]; const inputs = [q, key]; @@ -430,9 +430,9 @@ const computeVxAttentionScore = z: params.batchSize * params.numHeads }; const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: params.sequenceLength}, {type: 'uint32', data: params.totalSequenceLength}, - {type: 'uint32', data: params.vHeadSize}, {type: 'uint32', data: params.numHeads}, - {type: 'uint32', data: params.vHiddenSize} + {type: DataType.uint32, data: params.sequenceLength}, {type: DataType.uint32, data: params.totalSequenceLength}, + {type: DataType.uint32, data: params.vHeadSize}, {type: DataType.uint32, data: params.numHeads}, + {type: DataType.uint32, data: params.vHiddenSize} ]; const getShaderSource = (shaderHelper: ShaderHelper) => { @@ -526,10 +526,10 @@ const prepare = (context: ComputeContext, parameters: AttentionParameters) => { }; const inputs = [context.inputs[0], context.inputs[1], context.inputs[2]]; const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: M}, {type: 'uint32', data: K}, {type: 'uint32', data: N}, - {type: 'uint32', data: parameters.numHeads}, {type: 'uint32', data: parameters.headSize}, - {type: 'uint32', data: parameters.hiddenSize}, - {type: 'uint32', data: parameters.hiddenSize + parameters.hiddenSize + parameters.vHiddenSize} + {type: DataType.uint32, data: M}, {type: DataType.uint32, data: K}, {type: DataType.uint32, data: N}, + {type: DataType.uint32, data: parameters.numHeads}, {type: DataType.uint32, data: parameters.headSize}, + {type: DataType.uint32, data: parameters.hiddenSize}, + {type: DataType.uint32, data: parameters.hiddenSize + parameters.hiddenSize + parameters.vHiddenSize} ]; const getShaderSource = (shaderHelper: ShaderHelper) => { diff --git 
a/js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts index 00a6ca75b34fa..39b932375891b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts @@ -3,12 +3,13 @@ import {env} from 'onnxruntime-common'; +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, ProgramInfo} from '../types'; -import {createTensorShapeVariables, enableShapesUniforms, getMaxComponents, inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper} from './common'; export interface BatchNormAttributes extends AttributeWithCacheKey { readonly epsilon: number; @@ -61,7 +62,7 @@ const createBatchNormInferenceProgramInfo = const cComponents = format === 'NHWC' && yShape.length > 1 ? components : 1; const outputSize = ShapeUtil.size(yShape) / components; // Only support uniforms for opset version >= 9 (spatial = true). - const useShapesUniforms = enableShapesUniforms(yShape.length) && spatial; + const useShapesUniforms = spatial; const shapeOrRank = useShapesUniforms ? yShape.length : yShape; const x = inputVariable('x', inputs[0].dataType, inputs[0].dims, components); const scale = inputVariable('scale', inputs[1].dataType, inputs[1].dims, cComponents); @@ -123,11 +124,11 @@ const createBatchNormInferenceProgramInfo = dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, programUniforms: useShapesUniforms ? 
[ - {type: 'uint32', data: outputSize}, + {type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(yShape), ] : [ - {type: 'uint32', data: outputSize}, + {type: DataType.uint32, data: outputSize}, ], }), }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts b/js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts index a81a7a8f1df5c..089fecd758e30 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts @@ -43,7 +43,7 @@ const createBiasSplitGeluProgramInfo = (inputs: readonly TensorView[]): ProgramI ${shaderHelper.declareVariables(input, bias, output)} - ${erfImpl(`vec4<${dataType}>`, dataType)} + ${erfImpl(dataType)} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} diff --git a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts index c033c0ba05356..a094fffe239c4 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts @@ -6,7 +6,7 @@ import {TensorView} from '../../tensor-view'; import {BroadcastUtil, ShapeUtil} from '../../util'; import {ComputeContext, ProgramInfo} from '../types'; -import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common'; type BuiltinFunctionName = string; type BinaryCustomExpression = (expressionA: string, expressionB: string) => string; @@ -18,8 +18,7 @@ type BinaryFunctionCall = BuiltinFunctionName|BinaryCustomExpression|{ const createBinaryOpProgramShader = (shaderHelper: ShaderHelper, dimsA: readonly number[], dimsB: readonly number[], dimsOutput: readonly number[], vectorize: boolean, doBroadcast: boolean, sharedDimensionDivisibleBy4: boolean, funcCall: BinaryFunctionCall, - typeA: number, typeB: number, typeOutput: number, useShapesUniforms: boolean, - 
additionalImplementation?: string) => { + typeA: number, typeB: number, typeOutput: number, additionalImplementation?: string) => { let expressionScalar: BinaryCustomExpression; let expressionVector: BinaryCustomExpression; if (typeof funcCall === 'string') { @@ -31,12 +30,9 @@ const createBinaryOpProgramShader = expressionVector = funcCall.vector; } - const inputAShapeOrRank = useShapesUniforms ? dimsA.length : dimsA; - const inputBShapeOrRank = useShapesUniforms ? dimsB.length : dimsB; - const outputShapeOrRank = useShapesUniforms ? dimsOutput.length : dimsOutput; - const output = outputVariable('outputData', typeOutput, outputShapeOrRank, 4); - const a = inputVariable('aData', typeA, inputAShapeOrRank, 4); - const b = inputVariable('bData', typeB, inputBShapeOrRank, 4); + const output = outputVariable('outputData', typeOutput, dimsOutput.length, 4); + const a = inputVariable('aData', typeA, dimsA.length, 4); + const b = inputVariable('bData', typeB, dimsB.length, 4); let assignment: string; if (vectorize) { @@ -169,30 +165,23 @@ const createBinaryOpProgramInfo = vectorize = true; } cacheKeyAux.push(vectorize); - const useShapesUniforms = enableShapesUniforms(a.dims.length) && enableShapesUniforms(b.dims.length) && - enableShapesUniforms(outputShape.length); + return { name, shaderCache: { hint: cacheKey + cacheKeyAux.map((x) => x.toString()).join('_'), - inputDependencies: useShapesUniforms ? 
['rank', 'rank'] : ['dims', 'dims'], + inputDependencies: ['rank', 'rank'], }, getShaderSource: (shaderHelper) => createBinaryOpProgramShader( shaderHelper, a.dims, b.dims, outputShape, vectorize, isBroadcast, sharedDimensionDivisibleBy4, funcCall, - a.dataType, b.dataType, outputDataType, useShapesUniforms, additionalImplementation), + a.dataType, b.dataType, outputDataType, additionalImplementation), getRunData: () => ({ outputs: [{dims: outputShape, dataType: outputDataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* component size */)}, - programUniforms: useShapesUniforms ? - [ - {type: 'uint32', data: Math.ceil(ShapeUtil.size(outputShape) / 4)}, - ...createTensorShapeVariables(a.dims), - ...createTensorShapeVariables(b.dims), - ...createTensorShapeVariables(outputShape), - ] : - [ - {type: 'uint32', data: Math.ceil(ShapeUtil.size(outputShape) / 4)}, - ], + programUniforms: [ + {type: DataType.uint32, data: Math.ceil(ShapeUtil.size(outputShape) / 4)}, + ...createTensorShapeVariables(a.dims, b.dims, outputShape) + ], }), }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index 643744108c0f4..516094d0ef87b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -259,8 +259,16 @@ export const tensorTypeToWsglValueType = (type: DataType, components: 1|2|3|4 = return typeof mappedType === 'string' ? mappedType : mappedType[1]; }; -export const createTensorShapeVariables = (dims: readonly number[]): ProgramUniform[] => - dims.length === 0 ? 
[] : [{type: 'uint32', data: dims}, {type: 'uint32', data: ShapeUtil.computeStrides(dims)}]; +export const createTensorShapeVariables = (...dims: ReadonlyArray): ProgramUniform[] => { + const programUniforms: ProgramUniform[] = []; + dims.forEach(dim => { + if (dim.length !== 0) { + programUniforms.push( + {type: DataType.uint32, data: dim}, {type: DataType.uint32, data: ShapeUtil.computeStrides(dim)}); + } + }); + return programUniforms; +}; /** * A helper function to get maximum vector size for specified data length @@ -922,6 +930,3 @@ export const getBroadcastDims = (inShape: readonly number[], outShape: readonly } return dims; }; - -// TODO: remove this when all related uses have been removed. -export const enableShapesUniforms = (_rank: number): boolean => true; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts index 43cc4a4c080bd..010ee589c44fa 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts @@ -1,36 +1,44 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
+import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; -import {createTensorShapeVariables, enableShapesUniforms, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; export interface ConcatAttributes extends AttributeWithCacheKey { readonly axis: number; } -const validateInputs = (inputs: readonly TensorView[]): void => { +const validateInputs = (inputs: readonly TensorView[], axis: number): void => { if (!inputs || inputs.length < 1) { throw new Error('too few inputs'); } - - const inputType = inputs[0].dataType; - const inputDimensionality = inputs[0].dims.length; - - for (const input of inputs) { + const referenceIndex = 0; + const referenceInput = inputs[referenceIndex]; + const inputType = referenceInput.dataType; + const inputRank = referenceInput.dims.length; + inputs.forEach((input, i) => { + if (i === referenceIndex) { + return; + } // make sure types of all inputs match if (input.dataType !== inputType) { throw new Error('input tensors should be one type'); } - // make sure the dimensionality of all inputs are the same - if (input.dims.length !== inputDimensionality) { + if (input.dims.length !== inputRank) { throw new Error('input tensors should have the same shape'); } - } + input.dims.forEach((dim, i) => { + if (i !== axis && dim !== referenceInput.dims[i]) { + throw new Error('non concat dimensions must match'); + } + }); + }); }; const calculateInputIndexImpl = (numberOfTensors: number, sizeInConcatAxisStr: string): string => ` @@ -63,75 +71,43 @@ const assignOutputData = (inputs: readonly IndicesHelper[], output: IndicesHelpe return 
codeLines.join('\n'); }; -const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): ProgramInfo => { - const inputShape = inputs[0].dims.slice(); - if (axis >= inputShape.length || axis < (-1 * inputShape.length)) { - throw new Error('axis specified for concat doesn\'t match input dimensionality'); - } - const adjustedAxis = (axis < 0) ? inputShape.length + axis : axis; - // ensure all of the non-concatenated axes match each other - // calculate the shape of the output tensor while we do that - const outputShape = inputShape.slice(0); - for (let i = 1; i < inputs.length; i++) { - const dataNShape = inputs[i].dims.slice(); - for (let axisIndex = 0; axisIndex < inputShape.length; axisIndex++) { - // add to the placeholder for computing output shape - if (axisIndex === adjustedAxis) { - outputShape[adjustedAxis] += dataNShape[axisIndex]; +const createConcatProgramInfo = + (inputs: readonly TensorView[], adjustedAxis: number, outputShape: number[], dataType: DataType): ProgramInfo => { + const outputSize = ShapeUtil.size(outputShape); + + const sizeInConcatAxis = new Array(inputs.length); + const inputVars = new Array(inputs.length); + + let previousSum = 0; + const inputDependencies: ProgramInputTensorInfoDependency[] = []; + const inputRanks = []; + const programUniforms: ProgramUniform[] = [{type: DataType.uint32, data: outputSize}]; + for (let i = 0; i < inputs.length; ++i) { + previousSum += inputs[i].dims[adjustedAxis]; + sizeInConcatAxis[i] = previousSum; + inputRanks.push(inputs[i].dims.length); + inputVars[i] = inputVariable(`input${i}`, dataType, inputRanks[i]); + inputDependencies.push('rank'); + programUniforms.push({type: DataType.uint32, data: sizeInConcatAxis[i]}); } - // ensure all non-cancatenated axes match each other - else if (inputShape[axisIndex] !== dataNShape[axisIndex]) { - throw new Error('non concat dimensions must match'); + for (let i = 0; i < inputs.length; ++i) { + 
programUniforms.push(...createTensorShapeVariables(inputs[i].dims)); } - } - } - - const outputSize = ShapeUtil.size(outputShape); - - const sizeInConcatAxis = new Array(inputs.length); - const inputVars = new Array(inputs.length); - const dataType = inputs[0].dataType; - - let previousSum = 0; - const inputDependencies: ProgramInputTensorInfoDependency[] = []; - const inputShapeOrRanks = []; - const enableInputShapesUniforms = []; - const programUniforms: ProgramUniform[] = [{type: 'uint32', data: outputSize}]; - for (let i = 0; i < inputs.length; ++i) { - previousSum += inputs[i].dims[adjustedAxis]; - sizeInConcatAxis[i] = previousSum; - enableInputShapesUniforms.push(enableShapesUniforms(inputs[i].dims.length)); - inputShapeOrRanks.push(enableInputShapesUniforms[i] ? inputs[i].dims.length : inputs[i].dims); - inputVars[i] = inputVariable(`input${i}`, dataType, inputShapeOrRanks[i]); - inputDependencies.push(enableInputShapesUniforms[i] ? 'rank' : 'dims'); - programUniforms.push({type: 'uint32', data: sizeInConcatAxis[i]}); - } - for (let i = 0; i < inputs.length; ++i) { - if (enableInputShapesUniforms[i]) { - programUniforms.push(...createTensorShapeVariables(inputs[i].dims)); - } - } + programUniforms.push(...createTensorShapeVariables(outputShape)); - const enableOutputShapesUniforms = enableShapesUniforms(outputShape.length); - if (enableOutputShapesUniforms) { - programUniforms.push(...createTensorShapeVariables(outputShape)); - } - - const outputShapeOrRank = enableOutputShapesUniforms ? 
outputShape.length : outputShape; - const output = outputVariable('output', dataType, outputShapeOrRank); - - const indicesAxis = output.indicesGet('indices', adjustedAxis); - const sizeInConcatAxisStr = - Array.from(Array(sizeInConcatAxis.length).keys()).map(i => `uniforms.sizeInConcatAxis${i}`).join(','); - const getShaderSource = (shaderHelper: ShaderHelper) => ` + const output = outputVariable('output', dataType, outputShape.length); + const indicesAxis = output.indicesGet('indices', adjustedAxis); + const sizeInConcatAxisStr = + Array.from(Array(sizeInConcatAxis.length).keys()).map(i => `uniforms.sizeInConcatAxis${i}`).join(','); + const getShaderSource = (shaderHelper: ShaderHelper) => ` ${(() => { - shaderHelper.registerUniform('outputSize', 'u32'); - for (let i = 0; i < inputs.length; i++) { - shaderHelper.registerUniform(`sizeInConcatAxis${i}`, 'u32'); - } - return shaderHelper.declareVariables(...inputVars, output); - })()} + shaderHelper.registerUniform('outputSize', 'u32'); + for (let i = 0; i < inputs.length; i++) { + shaderHelper.registerUniform(`sizeInConcatAxis${i}`, 'u32'); + } + return shaderHelper.declareVariables(...inputVars, output); + })()} ${calculateInputIndexImpl(sizeInConcatAxis.length, sizeInConcatAxisStr)} @@ -149,21 +125,30 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): P ${assignOutputData(inputVars, output)} }`; - return { - name: 'Concat', - shaderCache: {hint: `${axis}`, inputDependencies}, - getRunData: () => ({ - outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, - programUniforms, - }), - getShaderSource, - }; -}; + return { + name: 'Concat', + shaderCache: {hint: `${adjustedAxis}`, inputDependencies}, + getRunData: () => ({ + outputs: [{dims: outputShape, dataType}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms, + }), + getShaderSource, + }; + }; export const concat = 
(context: ComputeContext, attributes: ConcatAttributes): void => { - validateInputs(context.inputs); - context.compute(createConcatProgramInfo(context.inputs, attributes.axis)); + const inputs = context.inputs; + const inputShape = inputs[0].dims; + const adjustedAxis = ShapeUtil.normalizeAxis(attributes.axis, inputShape.length); + validateInputs(inputs, adjustedAxis); + const outputShape = inputShape.slice(); + outputShape[adjustedAxis] = + inputs.reduce((sum, input) => sum + (input.dims.length > adjustedAxis ? input.dims[adjustedAxis] : 0), 0); + // 0 length tensors are valid for concat, remove them + const nonEmptyInputs = inputs.filter(input => ShapeUtil.size(input.dims) > 0); + context.compute( + createConcatProgramInfo(nonEmptyInputs, adjustedAxis, outputShape, inputs[0].dataType), {inputs: nonEmptyInputs}); }; export const parseConcatAttributes = (attributes: Record): ConcatAttributes => diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts index f81d6577890c5..7d424305c715f 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts @@ -1,13 +1,14 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
+import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; -import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common'; +import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from './common'; import {calculateOutputShape, ConvAttributes} from './conv'; -import {getActivationSnippet} from './fuse-utils'; +import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet} from './fuse-utils'; /** * naive grouped conv implementation, supports 1d/2d conv @@ -28,17 +29,13 @@ export const createGroupedConvProgramInfo = const outputSize = ShapeUtil.size(outputShape); const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: outputSize}, {type: 'uint32', data: attributes.dilations}, - {type: 'uint32', data: [attributes.strides[0], attributes.strides[1]]}, - {type: 'uint32', data: [attributes.pads[0], attributes.pads[1]]}, {type: 'uint32', data: outputChannelsPerGroup} + {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: attributes.dilations}, + {type: DataType.uint32, data: [attributes.strides[0], attributes.strides[1]]}, + {type: DataType.uint32, data: [attributes.pads[0], attributes.pads[1]]}, + {type: DataType.uint32, data: outputChannelsPerGroup} ]; - if (attributes.activation === 'Clip') { - programUniforms.push( - {type: 'float32', data: attributes.clipMax!}, {type: 'float32', data: attributes.clipMin!}); - } - programUniforms.push( - ...createTensorShapeVariables(xShape), ...createTensorShapeVariables(wShape), - ...createTensorShapeVariables(outputShape)); + appendActivationUniformsData(attributes, programUniforms); + programUniforms.push(...createTensorShapeVariables(xShape, wShape, 
outputShape)); const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; if (hasBias) { programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); @@ -48,7 +45,8 @@ export const createGroupedConvProgramInfo = const getShaderSource = (shaderHelper: ShaderHelper) => { const output = outputVariable('output', inputs[0].dataType, outputShape.length); - const applyActivation = getActivationSnippet(attributes, output.type.value); + const baseType = tensorTypeToWsglStorageType(output.type.tensor); + const applyActivation = getActivationSnippet(attributes, output.type.value, baseType); const x = inputVariable('x', inputs[0].dataType, xShape.length); const w = inputVariable('w', inputs[1].dataType, wShape.length); const inputVars = [x, w]; @@ -61,9 +59,7 @@ export const createGroupedConvProgramInfo = {name: 'strides', type: 'u32', length: 2}, {name: 'pads', type: 'u32', length: 2}, {name: 'output_channels_per_group', type: 'u32'} ]; - if (attributes.activation === 'Clip') { - uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'}); - } + appendActivationUniforms(attributes, uniforms); return ` ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, output)} @@ -132,14 +128,17 @@ export const createGroupedConvVectorizeProgramInfo = const outputShapeInShader = [outputShape[0], outputShape[1], outputShape[2], outputShape[3] / components]; const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: outputSize}, {type: 'int32', data: attributes.strides}, - {type: 'int32', data: attributes.pads}, ...createTensorShapeVariables(xShape), - ...createTensorShapeVariables(wShape), ...createTensorShapeVariables(outputShapeInShader) + {type: DataType.uint32, data: outputSize}, + {type: DataType.int32, data: [attributes.strides[0], attributes.strides[1]]}, + {type: DataType.int32, data: [attributes.pads[0], attributes.pads[1]]} ]; + appendActivationUniformsData(attributes, programUniforms); + 
programUniforms.push(...createTensorShapeVariables(xShape, wShape, outputShapeInShader)); const xNumber = (outputNumber - 1) * attributes.strides[1] + wShape[1]; const getShaderSource = (shaderHelper: ShaderHelper) => { const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components); - const applyActivation = getActivationSnippet(attributes, output.type.value); + const baseType = tensorTypeToWsglStorageType(output.type.tensor); + const applyActivation = getActivationSnippet(attributes, output.type.value, baseType); const x = inputVariable('x', inputs[0].dataType, xShape.length, components); const w = inputVariable('w', inputs[1].dataType, wShape.length, components); const inputVars = [x, w]; @@ -147,13 +146,14 @@ export const createGroupedConvVectorizeProgramInfo = inputVars.push(inputVariable('b', inputs[2].dataType, inputs[2].dims, components)); } const processBias = hasBias ? 'value += b[output_channel];' : ''; - + const uniforms: UniformsArrayType = [ + {name: 'output_size', type: 'u32'}, + {name: 'strides', type: 'i32', length: 2}, + {name: 'pads', type: 'i32', length: 2}, + ]; + appendActivationUniforms(attributes, uniforms); return ` - ${ - shaderHelper.registerUniform('output_size', 'u32') - .registerUniform('strides', 'i32', 2) - .registerUniform('pads', 'i32', 2) - .declareVariables(...inputVars, output)} + ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, output)} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} let width0 = uniforms.output_shape[3]; @@ -173,7 +173,7 @@ export const createGroupedConvVectorizeProgramInfo = // Use constant instead of uniform can give better performance for w's height/width. 
for (var w_height: u32 = 0u; w_height < ${wShape[0]}; w_height++) { let x_height = x_corner.x + i32(w_height); - if (x_height >= 0 || u32(x_height) < uniforms.x_shape[1]) { + if (x_height >= 0 && u32(x_height) < uniforms.x_shape[1]) { for (var i = 0; i < ${xNumber}; i++) { let x_width = x_corner.y + i; if (x_width >= 0 && u32(x_width) < uniforms.x_shape[2]) { @@ -185,7 +185,7 @@ export const createGroupedConvVectorizeProgramInfo = for (var w_width: u32 = 0u; w_width < ${wShape[1]}; w_width++) { let w_val = ${w.get('w_height', 'w_width', '0', 'output_channel')}; for (var i = 0u; i < ${outputNumber}u; i++) { - values[i] = fma(x_vals[i * ${attributes.strides[1]}u + w_width], w_val, values[i]); + values[i] = fma(x_vals[i * u32(uniforms.strides[1]) + w_width], w_val, values[i]); } } } diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts index 5afec0389fac8..b68d4dcae4cb9 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts @@ -148,11 +148,12 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut // const hasPreluActivationWeights = false; /* TODO: add support for prelu activation weights */ const isChannelsLast = attributes.format === 'NHWC'; if (attributes.group !== 1) { - // Temporarily disable createGroupedConvVectorizeProgramInfo path due to bots failures with below two cases: + // NVIDIA GPU with ampere architecture fails with below 2 cases, but we couldn't repro them with any other + // GPUs. So just disable vectorize on NVIDIA ampere to ensure always correct outputs. 
// [webgpu]Conv - conv - vectorize group - B // [webgpu]Conv - conv - vectorize group - D - const disableGroupedConvVectorize = true; - if (!disableGroupedConvVectorize && isChannelsLast && inputs[1].dims[0] === attributes.group && + const enableGroupedConvVectorize = !context.adapterInfo.isArchitecture('ampere'); + if (enableGroupedConvVectorize && isChannelsLast && inputs[1].dims[0] === attributes.group && inputs[1].dims[1] === 1 && attributes.dilations[0] === 1 && attributes.dilations[1] === 1) { const outputShape = calculateOutputShape( inputs[0].dims, inputs[1].dims, attributes.dilations, adjustedAttributes.pads, attributes.strides, diff --git a/js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts b/js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts index 2ff909c30e62e..6080301d9946b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts @@ -54,8 +54,8 @@ const createCumsumProgramInfo = outputs: [{dims: inputShape, dataType: inputType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, programUniforms: [ - {type: 'uint32', data: outputSize}, {type: 'int32', data: axis}, - ...createTensorShapeVariables(inputShape), ...createTensorShapeVariables(inputShape) + {type: DataType.uint32, data: outputSize}, {type: DataType.int32, data: axis}, + ...createTensorShapeVariables(inputShape, inputShape) ] }), diff --git a/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts b/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts index 4db7c04ad67be..19a009c2eb79b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts @@ -1,13 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
+import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; -import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from './common'; - +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common'; export interface EinsumAttributes extends AttributeWithCacheKey { readonly equation: string; @@ -181,14 +181,12 @@ class EinsumEquation { const appendMax = (name: string): string => name + '_max'; const createEinsumProgramInfo = - (enableInputShapesUniforms: readonly boolean[], inputShapes: Array, dataType: number, - einsumEquation: EinsumEquation, outputShape: readonly number[]): ProgramInfo => { - const shapeOrRanks = inputShapes.map((dims, index) => enableInputShapesUniforms[index] ? dims.length : dims); - const inputVars = shapeOrRanks.map((shapeOrRank, index) => inputVariable(`input${index}`, dataType, shapeOrRank)); + (inputShapes: Array, dataType: number, einsumEquation: EinsumEquation, + outputShape: readonly number[]): ProgramInfo => { + const ranks = inputShapes.map((dims) => dims.length); + const inputVars = ranks.map((rank, index) => inputVariable(`input${index}`, dataType, rank)); const outputSize = ShapeUtil.size(outputShape); - const enableOutputShapesUniforms = enableShapesUniforms(outputShape.length); - const outputShapeOrRank = enableOutputShapesUniforms ? 
outputShape.length : outputShape; - const output = outputVariable('output', dataType, outputShapeOrRank); + const output = outputVariable('output', dataType, outputShape.length); const uniformsSymbols = [...einsumEquation.symbolToInfo.keys()].filter((symbol) => !einsumEquation.rhs.symbolToIndices.has(symbol)); const getShaderSource = (shaderHelper: ShaderHelper) => { @@ -269,24 +267,20 @@ const createEinsumProgramInfo = }; return { name: 'Einsum', - shaderCache: { - hint: einsumEquation.equation, - inputDependencies: enableInputShapesUniforms.map((enableShapeUniform) => enableShapeUniform ? 'rank' : 'dims') - }, + shaderCache: {hint: einsumEquation.equation, inputDependencies: inputShapes.map(() => 'rank')}, getRunData: () => { // The symbols from uniformSymbols array are guaranteed to exist in einsumEquations.symbolToInfo map. The // filter is added to make sure that dimValue is never 0. const programUniformsInit: ProgramUniform[] = uniformsSymbols.filter((symbol) => einsumEquation.symbolToInfo.has(symbol)) - .map((symbol) => ({type: 'uint32', data: einsumEquation.symbolToInfo.get(symbol)?.dimValue || 0})); - programUniformsInit.push({type: 'uint32', data: outputSize}); + .map( + (symbol) => + ({type: DataType.uint32, data: einsumEquation.symbolToInfo.get(symbol)?.dimValue || 0})); + programUniformsInit.push({type: DataType.uint32, data: outputSize}); const programUniforms: ProgramUniform[] = - inputShapes.filter((_, index) => enableInputShapesUniforms[index]) - .map((dims, _) => [...createTensorShapeVariables(dims)]) + inputShapes.map((dims, _) => [...createTensorShapeVariables(dims)]) .reduce((acc, inputProgramUniforms) => acc.concat(inputProgramUniforms), programUniformsInit); - if (enableOutputShapesUniforms) { - programUniforms.push(...createTensorShapeVariables(outputShape)); - } + programUniforms.push(...createTensorShapeVariables(outputShape)); return ({ outputs: [{dims: outputShape, dataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup 
size */)}, @@ -299,11 +293,9 @@ const createEinsumProgramInfo = export const einsum = (context: ComputeContext, attributes: EinsumAttributes): void => { const einsumEquation = new EinsumEquation(context.inputs, attributes.equation); - const enableInputShapesUniforms = context.inputs.map((input, _) => enableShapesUniforms(input.dims.length)); const outputShape = einsumEquation.outputDims; const inputShapes = context.inputs.map((input, _) => input.dims); - context.compute(createEinsumProgramInfo( - enableInputShapesUniforms, inputShapes, context.inputs[0].dataType, einsumEquation, outputShape)); + context.compute(createEinsumProgramInfo(inputShapes, context.inputs[0].dataType, einsumEquation, outputShape)); }; export const parseEinsumAttributes = (attributes: Record): EinsumAttributes => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts index 035d89755c7d7..80ee906423e19 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts @@ -6,7 +6,7 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; -import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common'; const validateInputs = (inputs: readonly TensorView[]): void => { if (!inputs || inputs.length !== 2) { @@ -49,15 +49,9 @@ const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => const components = dataType === DataType.bool ? 
4 : 1; const outputSize = Math.ceil(ShapeUtil.size(outputShape) / components); - const enableInputShapeUniform = enableShapesUniforms(inputShape.length); - const enableOutputShapeUniform = enableShapesUniforms(outputShape.length); - - const getShaderSource = (shaderHelper: ShaderHelper) => { - const inputShapeOrRank = enableInputShapeUniform ? inputShape.length : inputShape; - const outputShapeOrRank = enableOutputShapeUniform ? outputShape.length : outputShape; - const input = inputVariable('input', dataType, inputShapeOrRank, components); - const output = outputVariable('output', dataType, outputShapeOrRank, components); + const input = inputVariable('input', dataType, inputShape.length, components); + const output = outputVariable('output', dataType, outputShape.length, components); let assignment: string; if (dataType === DataType.bool) { const singleAssignment = (resStr: string, x: number, typeCast = '') => ` @@ -90,16 +84,11 @@ const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => ${assignment}`; }; - const programUniforms: ProgramUniform[] = [{type: 'uint32', data: outputSize}]; - if (enableInputShapeUniform) { - programUniforms.push(...createTensorShapeVariables(inputShape)); - } - if (enableOutputShapeUniform) { - programUniforms.push(...createTensorShapeVariables(outputShape)); - } + const programUniforms: ProgramUniform[] = + [{type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputShape, outputShape)]; return { name: 'Expand', - shaderCache: {hint: `${outputShape.length}`, inputDependencies: [enableInputShapeUniform ? 
'rank' : 'dims']}, + shaderCache: {hint: `${outputShape.length}`, inputDependencies: ['rank']}, getShaderSource, getRunData: () => ({ outputs: [{dims: outputShape, dataType: inputs[0].dataType}], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/fast-gelu.ts b/js/web/lib/wasm/jsep/webgpu/ops/fast-gelu.ts new file mode 100644 index 0000000000000..f50a6a3f011fe --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/fast-gelu.ts @@ -0,0 +1,69 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {DataType} from '../../../wasm-common'; +import {TensorView} from '../../tensor-view'; +import {ShapeUtil} from '../../util'; +import {ComputeContext, ProgramInfo} from '../types'; + +import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglValueType, UniformsArrayType, WORKGROUP_SIZE} from './common'; +import * as unary from './unary-op'; + +// GELU is defined as Y=0.5*X*(1+tanh(0.797885*X+0.035677*X*X*X)), where X may pre-add a bias. + +const createFastGeluProgramInfo = (inputTensors: readonly TensorView[]): ProgramInfo => { + const dataType = inputTensors[0].dataType; + const outputSize = ShapeUtil.size(inputTensors[0].dims); + const biasLength = ShapeUtil.size(inputTensors[1].dims); + // can only use vec4 when bias length is multiple of 4 + const useVec4 = biasLength % 4 === 0; + const getShaderSource = (shaderHelper: ShaderHelper): string => { + const x = inputVariable('x', dataType, [1], 4); + const bias = inputVariable('bias', dataType, [1], 4); + const y = outputVariable('y', dataType, [1], 4); + + const uniforms: UniformsArrayType = [{name: 'output_vec_size', type: 'u32'}, {name: 'bias_size', type: 'u32'}]; + + const singleElementBias = (i: 0|1|2|3) => ` + let bias${i}_offset: u32 = (global_idx * 4 + ${i}) % uniforms.bias_size; + let bias${i} = ${bias.getByOffset(`bias${i}_offset / 4`)}[bias${i}_offset % 4];`; + const biasGetExpression = useVec4 ? 
+ ` + let bias = ${bias.getByOffset('global_idx % (uniforms.bias_size / 4)')};` : + `${singleElementBias(0)}${singleElementBias(1)}${singleElementBias(2)}${singleElementBias(3)} + let bias = ${x.type.value}(bias0, bias1, bias2, bias3);`; + + return `${shaderHelper.registerUniforms(uniforms).declareVariables(x, bias, y)} + + ${unary.fastGeluImpl(tensorTypeToWsglValueType(dataType))} + + ${shaderHelper.mainStart(WORKGROUP_SIZE)} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_vec_size')} + + let x = ${x.getByOffset('global_idx')}; + ${biasGetExpression} + let x_in = x + bias; + ${y.setByOffset('global_idx', unary.fastGeluExpression('x_in'))} + }`; + }; + + return { + name: 'FastGeluWithBias', + shaderCache: {hint: `${useVec4}`, inputDependencies: ['type', 'type']}, + getShaderSource, + getRunData: (inputs) => ({ + outputs: [{dims: inputs[0].dims, dataType: inputs[0].dataType}], + programUniforms: + [{type: DataType.uint32, data: Math.ceil(outputSize / 4)}, {type: DataType.uint32, data: biasLength}], + dispatchGroup: {x: Math.ceil(outputSize / WORKGROUP_SIZE / 4)} + }) + }; +}; + +export const fastGelu = (context: ComputeContext): void => { + if (context.inputs.length < 2 || ShapeUtil.size(context.inputs[1].dims) === 0) { + unary.fastGelu(context); + } else { + context.compute(createFastGeluProgramInfo(context.inputs)); + } +}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts index 2e0aa33a957dc..6e66abacf3471 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts @@ -1,35 +1,78 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
+import {DataType} from '../../../wasm-common'; import {MAX_CLIP, MIN_CLIP} from '../../util'; +import {ProgramUniform} from '../types'; + +import {UniformsArrayType} from './common'; export interface InternalActivationAttributes { readonly activation: string; readonly clipMin?: number; readonly clipMax?: number; + readonly alpha?: number; + readonly beta?: number; } -export const getActivationSnippet = (attributes: InternalActivationAttributes, valueType: string): string => { - switch (attributes.activation) { - case 'Relu': - return `value = max(value, ${valueType}(0.0));`; - case 'Sigmoid': - return `value = (${valueType}(1.0) / (${valueType}(1.0) + exp(-value)));`; - case 'Clip': - return `value = clamp(value, ${valueType}(uniforms.clip_min), ${valueType}(uniforms.clip_max));`; - // TODO: adding other activations that can be fused. - default: - return ''; +export const getActivationSnippet = + (attributes: InternalActivationAttributes, valueType: string, baseType = 'f32'): string => { + switch (attributes.activation) { + case 'Relu': + return `value = max(value, ${valueType}(0.0));`; + case 'Sigmoid': + return `value = (${valueType}(1.0) / (${valueType}(1.0) + exp(-value)));`; + case 'Clip': + return `value = clamp(value, ${valueType}(${baseType}(uniforms.clip_min)), ${valueType}(${ + baseType}(uniforms.clip_max)));`; + case 'HardSigmoid': + return `value = max(${valueType}(0.0), min(${valueType}(1.0), ${baseType}(uniforms.alpha) * value + ${ + baseType}(uniforms.beta)));`; + case 'LeakyRelu': + return `value = select(${baseType}(uniforms.alpha) * value, value, value >= ${valueType}(0.0));`; + case '': + return ''; + // TODO: adding other activations that can be fused. 
+ default: + throw new Error(`Unsupported activation ${attributes.activation}`); + } + }; + +export const appendActivationUniformsData = + (attributes: InternalActivationAttributes, programUniform: ProgramUniform[]) => { + if (attributes.activation === 'Clip') { + programUniform.push( + {type: DataType.float, data: attributes.clipMax!}, {type: DataType.float, data: attributes.clipMin!}); + } else if (attributes.activation === 'HardSigmoid') { + programUniform.push( + {type: DataType.float, data: attributes.alpha!}, {type: DataType.float, data: attributes.beta!}); + } else if (attributes.activation === 'LeakyRelu') { + programUniform.push({type: DataType.float, data: attributes.alpha!}); + } + }; + +export const appendActivationUniforms = (attributes: InternalActivationAttributes, uniforms: UniformsArrayType) => { + if (attributes.activation === 'Clip') { + uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'}); + } else if (attributes.activation === 'HardSigmoid') { + uniforms.push({name: 'alpha', type: 'f32'}, {name: 'beta', type: 'f32'}); + } else if (attributes.activation === 'LeakyRelu') { + uniforms.push({name: 'alpha', type: 'f32'}); } }; export const parseInternalActivationAttributes = (attributes: Record|undefined): InternalActivationAttributes => { const activation = attributes?.activation as string || ''; - - if (activation === 'Clip') { + if (activation === 'HardSigmoid') { + const [alpha, beta] = attributes?.activation_params as [number, number] || [0.2, 0.5]; + return {activation, alpha, beta}; + } else if (activation === 'Clip') { const [clipMin, clipMax] = attributes?.activation_params as [number, number] || [MIN_CLIP, MAX_CLIP]; return {activation, clipMax, clipMin}; + } else if (activation === 'LeakyRelu') { + const [alpha] = attributes?.activation_params as [number] || [0.01]; + return {activation, alpha}; } return {activation}; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts 
b/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts index a945954adcaa4..4ab6c175a67e2 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -46,11 +47,11 @@ const createGatherElementsProgramInfo = const output = outputVariable('output', inputOutputDataType, outputShape.length); - const programUniforms: ProgramUniform[] = - [{type: 'uint32', data: outputSize}, {type: 'int32', data: axisDimLimit}, {type: 'uint32', data: axis}]; - programUniforms.push(...createTensorShapeVariables(inputShape)); - programUniforms.push(...createTensorShapeVariables(indicesShape)); - programUniforms.push(...createTensorShapeVariables(outputShape)); + const programUniforms: ProgramUniform[] = [ + {type: DataType.uint32, data: outputSize}, {type: DataType.int32, data: axisDimLimit}, + {type: DataType.uint32, data: axis} + ]; + programUniforms.push(...createTensorShapeVariables(inputShape, indicesShape, outputShape)); const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; // int64 indices would be treated as little endian i32 with assumption they fall in i32 limits diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts index 469249f92ff28..d48bb909f7f8f 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts @@ -5,9 +5,9 @@ import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, 
ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; +import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; -import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common'; export interface GatherAttributes extends AttributeWithCacheKey { axis: number; @@ -33,33 +33,15 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath const components = inputs[0].dataType === DataType.bool ? 4 : 1; const outputSize = Math.ceil(ShapeUtil.size(outputShape) / components); - const enableInputShapesUniforms = enableShapesUniforms(inputs[0].dims.length); - const inputShapeOrRank = enableInputShapesUniforms ? inputs[0].dims.length : inputs[0].dims; - const enableIndicesShapesUniforms = enableShapesUniforms(inputs[1].dims.length); - const indicesShapeOrRank = enableIndicesShapesUniforms ? inputs[1].dims.length : inputs[1].dims; - const enableOutputShapesUniforms = enableShapesUniforms(outputShape.length); - const outputShapeOrRank = enableOutputShapesUniforms ? outputShape.length : outputShape; - - const programUniforms: ProgramUniform[] = - [{type: 'uint32', data: outputSize}, {type: 'int32', data: axisDimLimit}, {type: 'uint32', data: axis}]; - if (enableInputShapesUniforms) { - programUniforms.push(...createTensorShapeVariables(inputs[0].dims)); - } - if (enableIndicesShapesUniforms) { - programUniforms.push(...createTensorShapeVariables(inputs[1].dims)); - } - if (enableOutputShapesUniforms) { - programUniforms.push(...createTensorShapeVariables(outputShape)); - } - - const inputDependencies: ProgramInputTensorInfoDependency[] = []; - inputDependencies.push(enableInputShapesUniforms ? 'rank' : 'dims'); - inputDependencies.push(enableIndicesShapesUniforms ? 
'rank' : 'dims'); + const programUniforms: ProgramUniform[] = [ + {type: DataType.uint32, data: outputSize}, {type: DataType.int32, data: axisDimLimit}, + {type: DataType.uint32, data: axis}, ...createTensorShapeVariables(inputs[0].dims, inputs[1].dims, outputShape) + ]; const getShaderSource = (shaderHelper: ShaderHelper) => { - const data = inputVariable('data', inputs[0].dataType, inputShapeOrRank, components); - const indices = inputVariable('inputIndices', inputs[1].dataType, indicesShapeOrRank); - const output = outputVariable('output', inputs[0].dataType, outputShapeOrRank, components); + const data = inputVariable('data', inputs[0].dataType, inputs[0].dims.length, components); + const indices = inputVariable('inputIndices', inputs[1].dataType, inputs[1].dims.length); + const output = outputVariable('output', inputs[0].dataType, outputShape.length, components); const calcDataIndices = (x: number|string): string => { const indicesRank = indicesShape.length; @@ -73,7 +55,7 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath if (idx${x} < 0) { idx${x} = idx${x} + uniforms.axisDimLimit; } - var dataIndices${x} = ${data.type.indices}(0); + var dataIndices${x} : ${data.type.indices}; `; for (let i = 0, j = 0; i < inputRank; i++) { if (i === axis) { @@ -127,7 +109,7 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath }; return { name: 'Gather', - shaderCache: {hint: attributes.cacheKey, inputDependencies}, + shaderCache: {hint: attributes.cacheKey, inputDependencies: ['rank', 'rank']}, getRunData: () => ({ outputs: [ {dims: outputShape, dataType: inputs[0].dataType}, diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts b/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts index a0d4021516bf7..76302e1af2e53 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
+import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {GemmUtil, ShapeUtil} from '../../util'; import {AttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -45,8 +46,9 @@ const createGemmProgramInfo = (inputs: readonly TensorView[], attributes: GemmAt } const outputSize = ShapeUtil.size(outputShape); const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: outputSize}, {type: 'uint32', data: M}, {type: 'uint32', data: N}, {type: 'uint32', data: K}, - {type: 'float32', data: attributes.alpha}, {type: 'float32', data: attributes.beta} + {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: M}, {type: DataType.uint32, data: N}, + {type: DataType.uint32, data: K}, {type: DataType.float, data: attributes.alpha}, + {type: DataType.float, data: attributes.beta} ]; const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type']; if (inputs.length === 3) { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts index a835c90bd5451..2f652dbd310ab 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts @@ -25,8 +25,8 @@ const createInstanceNormProgramInfo = const inputShape = [xShape[0], xShape[1], normPackedSize]; const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'type', 'type']; const programUniforms: ProgramUniform[] = - [{type: 'uint32', data: normSize}, {type: 'uint32', data: normPackedSize}]; - programUniforms.push(...createTensorShapeVariables(inputShape), ...createTensorShapeVariables(inputShape)); + [{type: DataType.uint32, data: normSize}, {type: DataType.uint32, data: normPackedSize}]; + programUniforms.push(...createTensorShapeVariables(inputShape, inputShape)); const getShaderSource = (shaderHelper: ShaderHelper) => { const x = inputVariable('x', inputs[0].dataType, inputShape.length, components); @@ -132,8 +132,9 @@ const 
computeMean = const meanInputDependencies: ProgramInputTensorInfoDependency[] = ['type']; const meanProgramUniforms: ProgramUniform[] = [ - {type: 'uint32', data: wgSize}, {type: 'uint32', data: h}, {type: 'uint32', data: Math.floor(c / components)}, - {type: 'uint32', data: Math.floor(h * c / components)} + {type: DataType.uint32, data: wgSize}, {type: DataType.uint32, data: h}, + {type: DataType.uint32, data: Math.floor(c / components)}, + {type: DataType.uint32, data: Math.floor(h * c / components)} ]; const getMeanShaderSource = (shaderHelper: ShaderHelper) => { @@ -182,8 +183,9 @@ const computeMean = {inputs: [input], outputs: [-1]})[0]; const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: unitsOfWork}, {type: 'uint32', data: h}, - {type: 'uint32', data: Math.floor(c / components)}, {type: 'uint32', data: Math.floor(WG * c / components)} + {type: DataType.uint32, data: unitsOfWork}, {type: DataType.uint32, data: h}, + {type: DataType.uint32, data: Math.floor(c / components)}, + {type: DataType.uint32, data: Math.floor(WG * c / components)} ]; const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type', 'type']; const getShaderSource = (shaderHelper: ShaderHelper) => { @@ -246,7 +248,7 @@ const createInstanceNormNHWCProgramInfo = const components = getMaxComponents(C); const outputSize = ShapeUtil.size(outputShape) / components; const programUniforms: ProgramUniform[] = - [{type: 'uint32', data: H}, {type: 'uint32', data: Math.floor(C / components)}]; + [{type: DataType.uint32, data: H}, {type: DataType.uint32, data: Math.floor(C / components)}]; const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type']; // first compute mean const channelScaleShift = computeMean(context, inputs[0], inputs[1], inputs[2], N, H, C, attributes.epsilon); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts index 3c9f6ce71bb67..d5f97213e49ce 100644 --- 
a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts @@ -49,8 +49,9 @@ const createLayerNormProgramInfo = const components = getMaxComponents(normSize); const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type']; const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: normCount}, {type: 'float32', data: normSize}, - {type: 'uint32', data: Math.floor(normSize / components)}, {type: 'float32', data: attributes.epsilon} + {type: DataType.uint32, data: normCount}, {type: DataType.float, data: normSize}, + {type: DataType.uint32, data: Math.floor(normSize / components)}, + {type: DataType.float, data: attributes.epsilon} ]; if (bias) { inputDependencies.push('type'); @@ -84,28 +85,28 @@ const createLayerNormProgramInfo = ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.norm_count')} let offset = global_idx * uniforms.norm_size_vectorized; - var meanVector = ${fillVector('f32', components)}; - var meanSquareVector = ${fillVector('f32', components)}; + var mean_vector = ${fillVector('f32', components)}; + var mean_square_vector = ${fillVector('f32', components)}; for (var h: u32 = 0u; h < uniforms.norm_size_vectorized; h++) { let value = ${castToF32(dataType, components, 'x[h + offset]')}; - meanVector += value; - meanSquareVector += value * value; + mean_vector += value; + mean_square_vector += value * value; } - let mean = ${sumVector('meanVector', components)} / uniforms.norm_size; - let invStdDev = - inverseSqrt(${sumVector('meanSquareVector', components)} / uniforms.norm_size - mean * mean + uniforms.epsilon); + let mean = ${sumVector('mean_vector', components)} / uniforms.norm_size; + let inv_std_dev = inverseSqrt(${ + sumVector('mean_square_vector', components)} / uniforms.norm_size - mean * mean + uniforms.epsilon); for (var j: u32 = 0; j < uniforms.norm_size_vectorized; j++) { let f32input = ${castToF32(dataType, components, 'x[j + offset]')}; 
let f32scale = ${castToF32(dataType, components, 'scale[j]')}; - output[j + offset] = ${variables[0].type.value}((f32input - mean) * invStdDev * f32scale + output[j + offset] = ${variables[0].type.value}((f32input - mean) * inv_std_dev * f32scale ${bias ? `+ ${castToF32(dataType, components, 'bias[j]')}` : ''} ); } ${hasMeanDataOutput ? 'mean_data_output[global_idx] = mean' : ''}; - ${hasInvStdOutput ? 'inv_std_output[global_idx] = invStdDev' : ''}; + ${hasInvStdOutput ? 'inv_std_output[global_idx] = inv_std_dev' : ''}; }`; }; const outputs = [{dims: outputShape, dataType: inputs[0].dataType}]; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts index c946ea6366123..1a92d861002fb 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts @@ -1,13 +1,14 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {BroadcastUtil, ShapeUtil} from '../../util'; import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; import {createMatmulProgramInfo} from './3rd-party/matmul_packed_webgpu'; -import {createTensorShapeVariables, getBroadcastDims, getMaxComponents, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, UniformsArrayType,} from './common'; -import {getActivationSnippet, InternalActivationAttributes} from './fuse-utils'; +import {createTensorShapeVariables, getBroadcastDims, getMaxComponents, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from './common'; +import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet, InternalActivationAttributes} from './fuse-utils'; export const createNaiveMatmulProgramInfo = (inputs: readonly TensorView[], activationAttributes: InternalActivationAttributes, 
outputShape: readonly number[], @@ -29,17 +30,11 @@ export const createNaiveMatmulProgramInfo = const outputShapeInShader = [batchSize, M, N]; const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: outputSize}, {type: 'uint32', data: M}, {type: 'uint32', data: N}, - {type: 'uint32', data: K} + {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: M}, {type: DataType.uint32, data: N}, + {type: DataType.uint32, data: K} ]; - if (activationAttributes.activation === 'Clip') { - programUniforms.push( - {type: 'float32', data: activationAttributes.clipMax!}, - {type: 'float32', data: activationAttributes.clipMin!}); - } - programUniforms.push( - ...createTensorShapeVariables(outerDims), ...createTensorShapeVariables(aShape), - ...createTensorShapeVariables(bShape)); + appendActivationUniformsData(activationAttributes, programUniforms); + programUniforms.push(...createTensorShapeVariables(outerDims, aShape, bShape)); if (hasBias) { programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); } @@ -50,7 +45,8 @@ export const createNaiveMatmulProgramInfo = const a = inputVariable('a', inputs[0].dataType, aShape.length, aComponents); const b = inputVariable('b', inputs[1].dataType, bShape.length, components); const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components); - const applyActivation = getActivationSnippet(activationAttributes, output.type.value); + const baseType = tensorTypeToWsglStorageType(output.type.tensor); + const applyActivation = getActivationSnippet(activationAttributes, output.type.value, baseType); const inputVariables = [a, b]; let processBias = ''; if (hasBias) { @@ -69,9 +65,7 @@ export const createNaiveMatmulProgramInfo = {name: 'output_size', type: 'u32'}, {name: 'M', type: 'u32'}, {name: 'N', type: 'u32'}, {name: 'K', type: 'u32'} ]; - if (activationAttributes.activation === 'Clip') { - uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'}); - 
} + appendActivationUniforms(activationAttributes, uniforms); const getIndices = (variable: IndicesHelper, broadCastDims: number[]) => { const rank = variable.rank; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts new file mode 100644 index 0000000000000..9bf5e4066139d --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts @@ -0,0 +1,250 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {DataType} from '../../../wasm-common'; +import {TensorView} from '../../tensor-view'; +import {ShapeUtil} from '../../util'; +import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; +import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; + +import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from './common'; + +// TODO support quantization bits not equal to 4 +export interface MatMulNBitsAttributes extends AttributeWithCacheKey { + k: number; + n: number; + accuracyLevel: number; + bits: number; + blockSize: number; +} + +const validateInputs = (inputs: readonly TensorView[], attributes: MatMulNBitsAttributes): void => { + if (inputs.length < 3 || inputs.length > 4) { + throw new Error('MatMulNBits requires 3 or 4 inputs'); + } + const a = inputs[0]; + const aRank = a.dims.length; + if (a.dims[aRank - 1] !== attributes.k) { + throw new Error('The last dim of input shape does not match the k value'); + } + const nBlocksPerCol = Math.floor((attributes.k + attributes.blockSize - 1) / attributes.blockSize); + const blobSize = attributes.blockSize / 8 * attributes.bits; + const b = inputs[1]; + if (!ShapeUtil.areEqual(b.dims, [attributes.n, nBlocksPerCol, blobSize])) { + throw new Error('The second inputs must be 3D tensor with shape N X nBlocksPerCol X blobSize'); + } + const scales = inputs[2]; + const 
scalesShape = scales.dims; + if (ShapeUtil.size(scalesShape) !== attributes.n * nBlocksPerCol) { + throw new Error('scales input size error.'); + } + if (inputs.length === 4) { + const zeroPoints = inputs[3]; + const zeroPointsShape = zeroPoints.dims; + const expectedZeroPointsSize = + attributes.bits > 4 ? (attributes.n * nBlocksPerCol) : attributes.n * Math.floor((nBlocksPerCol + 1) / 2); + if (ShapeUtil.size(zeroPointsShape) !== expectedZeroPointsSize) { + throw new Error('zeroPoints input size error.'); + } + } +}; + +export const createMatMulNBitsProgramInfo = + (inputs: readonly TensorView[], attributes: MatMulNBitsAttributes): ProgramInfo => { + const inputShape = inputs[0].dims; + const aRank = inputShape.length; + const outputShape = inputShape.slice(0, aRank - 1).concat(attributes.n); + const m = inputShape[aRank - 2]; + const blobSize = attributes.blockSize / 8 * attributes.bits; + const blobSizeInWords = blobSize / 4; + const outputNumber = getMaxComponents(m); + const components = getMaxComponents(attributes.n); + const aComponents = getMaxComponents(attributes.k); + const bComponents = getMaxComponents(blobSizeInWords); + const outputSize = ShapeUtil.size(outputShape) / components / outputNumber; + const programUniforms: ProgramUniform[] = [ + {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: attributes.k}, + {type: DataType.uint32, data: attributes.n}, {type: DataType.uint32, data: attributes.accuracyLevel}, + {type: DataType.uint32, data: attributes.bits}, {type: DataType.uint32, data: attributes.blockSize} + ]; + const aShape = inputShape.slice(); + aShape.splice(-1, 1, attributes.k / aComponents); + const bShape = ShapeUtil.convertShape(inputs[1].dims).slice(); + bShape.splice(-1, 1, blobSizeInWords / bComponents); + programUniforms.push(...createTensorShapeVariables(aShape)); + programUniforms.push(...createTensorShapeVariables(bShape)); + programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); + if 
(inputs.length === 4) { + programUniforms.push(...createTensorShapeVariables(ShapeUtil.convertShape(inputs[3].dims))); + } + const oShape = outputShape.slice(); + oShape.splice(-1, 1, attributes.n / components); + programUniforms.push(...createTensorShapeVariables(oShape)); + const getShaderSource = (shaderHelper: ShaderHelper) => { + const a = inputVariable('a', inputs[0].dataType, aShape.length, aComponents); + const b = inputVariable('b', DataType.uint32, bShape.length, bComponents); + const scales = inputVariable('scales', inputs[2].dataType, inputs[2].dims.length); + const inputVariables = [a, b, scales]; + const zeroPoints = + inputs.length === 4 ? inputVariable('zero_points', DataType.uint32, inputs[3].dims.length) : undefined; + if (zeroPoints) { + inputVariables.push(zeroPoints); + } + const output = outputVariable('output', inputs[0].dataType, outputShape.length, components); + const uniforms: UniformsArrayType = [ + {name: 'output_size', type: 'u32'}, {name: 'K', type: 'u32'}, {name: 'N', type: 'u32'}, + {name: 'accuracy_level', type: 'u32'}, {name: 'bits', type: 'u32'}, {name: 'block_size', type: 'u32'} + ]; + const nBlocksPerCol = Math.floor((attributes.k + attributes.blockSize - 1) / attributes.blockSize); + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + + const qDqDataType = (() => { + switch (aComponents) { + case 1: + return `array<${dataType}, 8>`; + case 2: + return `mat4x2<${dataType}>`; + case 4: + return `mat2x4<${dataType}>`; + default: + throw new Error(`${aComponents}-component is not supported.`); + } + })(); + + const dequantizeImpl = ` + fn dequantize(quantized: ${qDqDataType}, zero_point: ${dataType}, scale: ${dataType}) -> ${qDqDataType} { + ${(() => { + if (aComponents === 1) { + return `var dequantized = ${qDqDataType}(${ + Array.from({length: 8}, (_, i) => `(quantized[${i}] - zero_point) * scale`).join(', ')}); + return dequantized;`; + } else { + return `var zero_points: ${qDqDataType} = 
${qDqDataType}(${Array(8).fill('zero_point').join(',')}); + return (quantized - zero_points) * scale;`; + } + })()} + }`; + const ortUnpack8x4snormImpl = ` + fn ortUnpack8x4snorm(value: u32) -> ${qDqDataType} { + var quantized: ${qDqDataType}; + var offset: u32 = 0; + let count: u32 = 4; + for (var i: u32 = 0; i < 8u; i++) { + var result = ${dataType}(extractBits(value, offset, count)); + ${(() => { + switch (aComponents) { + case 1: + return 'quantized[i] = result;'; + case 2: + return 'quantized[i / 2][i % 2] = result;'; + case 4: + return 'quantized[i / 4][i % 4] = result;'; + default: + throw new Error(`${aComponents}-component is not supported.`); + } + })()} + offset += count; + } + return quantized; + }`; + + const updateZeroPointIndex = zeroPoints ? ` + zero_point_offset += 4; + if (zero_point_offset == 32) { + zero_point_offset = 0; + zero_point_index++; + zero_point_word = ${zeroPoints.getByOffset('zero_point_index')}; + }` : + ''; + + return ` + ${dequantizeImpl}; + ${ortUnpack8x4snormImpl}; + ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)} + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} + var output_values: array<${output.type.value}, ${outputNumber}>; + var output_indices = ${output.offsetToIndices('global_idx')}; + var n = ${output.indicesGet('output_indices', aRank - 1)}; + var m = ${output.indicesGet('output_indices', aRank - 2)}; + var a_indices: ${a.type.indices} = output_indices; + // Two zero points are packed into one byte because uniforms.bits <= 4. + // zero_point_offset is either 0 or 4. It is bit offset within one byte. + // TODO support zero_point_offset for bits > 4 + ${ + zeroPoints ? 
` + var zero_point_index: u32 = n * ${components} * ((${nBlocksPerCol} + 1) / 2) / 4; + var zero_point_word: u32 = ${zeroPoints.getByOffset('zero_point_index')}; + var zero_point_offset: u32 = 0;` : + ''} + var scale_index = n * ${nBlocksPerCol * components}; + var b_indices: ${b.type.indices}; + for (var c: u32 = 0; c < ${components}; c++) { + ${b.indicesSet('b_indices', '0', `n * ${components} + c`)}; + var block_offset: u32 = 0; + for (var block: u32 = 0; block < ${nBlocksPerCol}; block++) { + // The scale and zero points are computed per block. + let scale = ${scales.getByOffset('scale_index')}; + // The default zero point is 8 for unsigned 4-bit quantization. + let zero_point = ${dataType}(${zeroPoints ? 'extractBits(zero_point_word, zero_point_offset, 4)' : 8.0}); + ${b.indicesSet('b_indices', '1', 'block')}; + var word_offset: u32 = block_offset; + for (var word: u32 = 0; word < ${blobSizeInWords}; word += ${bComponents}) { + ${b.indicesSet('b_indices', '2', 'word')}; + let b_data = ${b.getByIndices('b_indices')}; + for (var i: u32 = 0; i < ${bComponents}; i++) { + let b_value = ${bComponents === 1 ? 'b_data' : 'b_data[word + i]'}; + let b_quantized_values: ${qDqDataType} = ortUnpack8x4snorm(b_value); + let b_dequantized_values = dequantize(b_quantized_values, zero_point, scale); + // Number of B elements per 32-bit word is 32/bits = 32/4 = 8 + var offset: u32 = word_offset; + for (var j: u32 = 0; j < 8/${aComponents}; j++) { + ${a.indicesSet('a_indices', aRank - 1, `offset/${aComponents}`)}; + for (var k: u32 = 0; k < ${outputNumber}u; k++) { + ${a.indicesSet('a_indices', aRank - 2, `m * ${outputNumber} + k`)}; + let a_data = ${a.getByIndices('a_indices')}; + output_values[k]${components > 1 ? '[c]' : ''} += ${ + aComponents === 1 ? 
'a_data * b_dequantized_values[j]' : 'dot(a_data, b_dequantized_values[j])'}; + } + offset += ${aComponents}; + } + word_offset += 8; + } + } + scale_index++; + ${updateZeroPointIndex} + block_offset += uniforms.block_size; + } + // Drop the trailing 4 bits if the zero_poit_offset is not a byte boundary to align with the next byte. + ${ + zeroPoints ? `if (zero_point_offset % 8 > 0) { + ${updateZeroPointIndex} + }` : + ''} + } + for (var k: u32 = 0u; k < ${outputNumber}u; k++) { + ${output.indicesSet('output_indices', aRank - 2, `${outputNumber + ' * m + k'}`)}; + ${output.setByIndices('output_indices', 'output_values[k]')} + } + }`; + }; + return { + name: 'MatMulNBits', + shaderCache: + {hint: `${attributes.cacheKey};${inputs.length}`, inputDependencies: Array(inputs.length).fill('rank')}, + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms + }), + getShaderSource + }; + }; + +export const matMulNBits = (context: ComputeContext, attributes: MatMulNBitsAttributes): void => { + validateInputs(context.inputs, attributes); + context.compute(createMatMulNBitsProgramInfo(context.inputs, attributes)); +}; + +export const parseMatMulNBitsAttributes = (attributes: Record): MatMulNBitsAttributes => + createAttributeWithCacheKey(attributes as Omit); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts b/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts index 6d22e3780efd9..5c5c849d99811 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
+import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {createAttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -238,8 +239,10 @@ const addBiasTranspose = hiddenSize: number, biasOffset: number) => { const outputShape = [batchSize, sequenceLength, hiddenSize]; const outputSize = ShapeUtil.size(outputShape); - const programUniforms: ProgramUniform[] = - [{type: 'uint32', data: outputSize}, {type: 'uint32', data: biasOffset}, {type: 'uint32', data: hiddenSize}]; + const programUniforms: ProgramUniform[] = [ + {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: biasOffset}, + {type: DataType.uint32, data: hiddenSize} + ]; const getShaderSource = (shaderHelper: ShaderHelper) => { const output = outputVariable('qkv_with_bias', qkv.dataType, outputShape); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/pad.ts b/js/web/lib/wasm/jsep/webgpu/ops/pad.ts index c65b741e1105a..236fc29fdf1ab 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/pad.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/pad.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {DataType, tensorDataTypeEnumToString} from '../../../wasm-common'; +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; @@ -153,13 +153,12 @@ const createPadProgramInfo = (inputs: readonly TensorView[], attributes: PadAttr const inputDims = inputs[0].dims; const outputSize = ShapeUtil.size(outputShape); const programUniforms: ProgramUniform[] = - [{type: 'uint32', data: outputSize}, {type: 'uint32', data: attributes.pads}]; + [{type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: attributes.pads}]; if (attributes.mode === 0) { - const tensorDataType = tensorDataTypeEnumToString(inputs[0].dataType) as ProgramUniform['type']; - programUniforms.push({type: tensorDataType, data: attributes.value}); + programUniforms.push({type: inputs[0].dataType, data: attributes.value}); } - programUniforms.push(...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(outputShape)); + programUniforms.push(...createTensorShapeVariables(inputs[0].dims, outputShape)); const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank']; const getShaderSource = (shaderHelper: ShaderHelper) => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts index 9e9b361c1af1c..4e933573b9137 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts @@ -3,6 +3,7 @@ import {env} from 'onnxruntime-common'; +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {PoolConvUtil, ShapeUtil} from '../../util'; import {AttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -56,7 +57,8 @@ const getUniformAndPadInfo = aBestValues : array<${output.type.storage}, ${workgroupSize}>; + var aBestValues : array; `; const getShaderSource = (shaderHelper: 
ShaderHelper) => ` @@ -145,10 +145,10 @@ export const createReduceSharedProgramInfo = let outputIndex = global_idx / ${workgroupSize}; let offset = outputIndex * uniforms.reduceSize; - var bestValue = ${output.type.storage}(${reduceInitValues[reduceType]}); + var bestValue = f32(${reduceInitValues[reduceType]}); let Length = uniforms.reduceSize; for (var k = local_idx; k < Length; k = k + ${workgroupSize}) { - let candidate = ${output.type.storage}(${input.getByOffset('offset + k')}); + let candidate = f32(${input.getByOffset('offset + k')}); bestValue = ${reduceOps[reduceType]}; } aBestValues[local_idx] = bestValue; @@ -172,8 +172,8 @@ export const createReduceSharedProgramInfo = output.setByOffset( 'outputIndex', `${ - reduceType === 'mean' ? `bestValue / ${output.type.storage}(uniforms.reduceSize)` : - `${reduceOutputValues[reduceType]}`}`)}; + reduceType === 'mean' ? `${output.type.storage}(bestValue / f32(uniforms.reduceSize))` : + `${output.type.storage}(${reduceOutputValues[reduceType]})`}`)}; } }`; @@ -185,7 +185,7 @@ export const createReduceSharedProgramInfo = getRunData: () => ({ outputs: [{dims: outputShape, dataType: outputDataType}], dispatchGroup: {x: outputSize}, - programUniforms: [{type: 'uint32', data: reduceSize}] + programUniforms: [{type: DataType.uint32, data: reduceSize}] }), }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts b/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts index e8851ac546942..e8205ba6fd928 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts @@ -100,10 +100,8 @@ export const createReduceProgramInfo = getRunData: () => ({ outputs: [{dims: outputShape, dataType: outputDataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, - programUniforms: [ - {type: 'uint32', data: outputSize}, ...createTensorShapeVariables(inputShape), - ...createTensorShapeVariables(outputShape) - ] + programUniforms: + [{type: DataType.uint32, data: outputSize}, 
...createTensorShapeVariables(inputShape, outputShape)] }), }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts index f68526acc0e63..2c6b537de1f00 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts @@ -2,6 +2,7 @@ // Licensed under the MIT License. +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -641,11 +642,8 @@ const createResizeProgramInfo = outputs: [{dims: outputShape, dataType: inputTensor.dataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, programUniforms: [ - {type: 'uint32', data: outputSize}, - {type: 'float32', data: scales}, - {type: 'float32', data: roi}, - ...createTensorShapeVariables(inputShape), - ...createTensorShapeVariables(outputShape), + {type: DataType.uint32, data: outputSize}, {type: DataType.float, data: scales}, + {type: DataType.float, data: roi}, ...createTensorShapeVariables(inputShape, outputShape) ] }) }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts index 509a722f4b52a..7be9ceec6bc65 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts @@ -88,10 +88,10 @@ const createSkipLayerNormProgramInfo = const components = getMaxComponents(hiddenSize); const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: outputSize}, - {type: 'uint32', data: components}, - {type: 'uint32', data: hiddenSize}, - {type: 'float32', data: attributes.epsilon}, + {type: DataType.uint32, data: outputSize}, + {type: DataType.uint32, data: components}, + {type: DataType.uint32, data: hiddenSize}, + {type: DataType.float, data: attributes.epsilon}, ]; const getShaderSource = (shaderHelper: ShaderHelper) => { 
const uniformsArray: UniformsArrayType = [ diff --git a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts index 5212c6475dce0..a5e71f30e5966 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts @@ -155,9 +155,9 @@ const createSliceProgramInfo = (inputs: readonly TensorView[], attributes: Slice ]; const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: outputSize}, {type: 'uint32', data: starts}, {type: 'int32', data: signs}, - {type: 'uint32', data: steps}, ...createTensorShapeVariables(inputs[0].dims), - ...createTensorShapeVariables(outputShape) + {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: starts}, + {type: DataType.int32, data: signs}, {type: DataType.uint32, data: steps}, + ...createTensorShapeVariables(inputs[0].dims, outputShape) ]; const getShaderSource = (shaderHelper: ShaderHelper) => ` diff --git a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts index 324dc3af1a710..6f8bfa08d7b62 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts @@ -5,6 +5,7 @@ // performance limitations when the reduced axis is long. Need to add // a optimized codepath for this. 
+import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -136,7 +137,7 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut getRunData: () => ({ outputs: [{dims: shape, dataType: input.dataType}], dispatchGroup: {x: rows}, - programUniforms: [{type: 'uint32', data: packedCols}] + programUniforms: [{type: DataType.uint32, data: packedCols}] }), getShaderSource, }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/split.ts b/js/web/lib/wasm/jsep/webgpu/ops/split.ts index b8582614fa214..a09ac78b17006 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/split.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/split.ts @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -67,24 +68,23 @@ const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: Split const dataType = inputs[0].dataType; const axis = ShapeUtil.normalizeAxis(attributes.axis, inputShape.length); const outputs = new Array(attributes.numOutputs); - const input = inputVariable('input', dataType, inputShape); + const input = inputVariable('input', dataType, inputShape.length); const sizeInSplitAxis = new Array(attributes.numOutputs); const outputsTensorInfo: TensorInfo[] = []; const outputShapes: number[][] = []; let previousSum = 0; - const programUniforms: ProgramUniform[] = [{type: 'uint32', data: inputSize}]; + const programUniforms: ProgramUniform[] = [{type: DataType.uint32, data: inputSize}]; for (let i = 0; i < attributes.numOutputs; i++) { previousSum += attributes.splitSizes[i]; sizeInSplitAxis[i] = previousSum; const 
outputShape = inputShape.slice(); outputShape[attributes.axis] = attributes.splitSizes[i]; outputShapes.push(outputShape); - outputs[i] = outputVariable(`output${i}`, dataType, outputShape); + outputs[i] = outputVariable(`output${i}`, dataType, outputShape.length); outputsTensorInfo.push({dims: outputShapes[i], dataType: inputs[0].dataType}); } - programUniforms.push({type: 'uint32', data: sizeInSplitAxis}); - programUniforms.push(...createTensorShapeVariables(inputShape)); - outputShapes.forEach((outputShape) => programUniforms.push(...createTensorShapeVariables(outputShape))); + programUniforms.push( + {type: DataType.uint32, data: sizeInSplitAxis}, ...createTensorShapeVariables(inputShape, ...outputShapes)); const getShaderSource = (shaderHelper: ShaderHelper) => ` ${ shaderHelper.registerUniform('input_size', 'u32') diff --git a/js/web/lib/wasm/jsep/webgpu/ops/tile.ts b/js/web/lib/wasm/jsep/webgpu/ops/tile.ts index 90a36a7bec2a9..f9728575fe072 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/tile.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/tile.ts @@ -79,10 +79,8 @@ export const createTileProgramInfo = (inputs: readonly TensorView[]): ProgramInf getRunData: () => ({ outputs: [{dims: outputShape, dataType: inputs[0].dataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, - programUniforms: [ - {type: 'uint32', data: outputSize}, ...createTensorShapeVariables(inputs[0].dims), - ...createTensorShapeVariables(outputShape) - ], + programUniforms: + [{type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputs[0].dims, outputShape)], }), getShaderSource, }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts index c4d43e9f466f5..7ae801222b875 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts @@ -1,12 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
+import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, ProgramInfo} from '../types'; -import {createTensorShapeVariables, enableShapesUniforms, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; export interface TransposeAttributes extends AttributeWithCacheKey { readonly perm: number[]; @@ -39,12 +40,9 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu const inputDataType = inputTensor.dataType; const inputRank = inputTensor.dims.length; const perm = getAdjustedPerm(inputRank, permAttr); - const useShapesUniforms = enableShapesUniforms(inputRank); const outputShape = getOutputShape(inputTensor.dims, perm); - const outShapeOrRank = useShapesUniforms ? outputShape.length : outputShape; - const inShapeOrRank = useShapesUniforms ? inputRank : inputTensor.dims; - const output = outputVariable('output', inputDataType, outShapeOrRank); - const input = inputVariable('a', inputDataType, inShapeOrRank); + const output = outputVariable('output', inputDataType, outputShape.length); + const input = inputVariable('a', inputDataType, inputRank); const getShaderSource = (shaderHelper: ShaderHelper) => ` ${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)} @@ -61,21 +59,14 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu }`; return { name: 'Transpose', - shaderCache: {hint: `${permAttr}`, inputDependencies: useShapesUniforms ? 
['rank'] : ['dims']}, + shaderCache: {hint: `${permAttr}`, inputDependencies: ['rank']}, getRunData: (inputs) => { const outputSize = ShapeUtil.size(outputShape); return { outputs: [{dims: outputShape, dataType: inputs[0].dataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, - programUniforms: useShapesUniforms ? - [ - {type: 'uint32', data: outputSize}, - ...createTensorShapeVariables(inputs[0].dims), - ...createTensorShapeVariables(outputShape), - ] : - [ - {type: 'uint32', data: outputSize}, - ], + programUniforms: + [{type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputs[0].dims, outputShape)], }; }, getShaderSource, diff --git a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts index 76929efb32537..5f105c745739e 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts @@ -53,7 +53,7 @@ const createElementwiseProgramInfo = dispatchGroup: {x: Math.ceil(ShapeUtil.size(inputTensors[0].dims) / 64 /* workgroup size */ / 4 /* vec size */)}, programUniforms: [ - {type: 'uint32', data: Math.ceil(ShapeUtil.size(input.dims) / 4)}, + {type: DataType.uint32, data: Math.ceil(ShapeUtil.size(input.dims) / 4)}, ], }) }); @@ -178,7 +178,7 @@ export const elu = (context: ComputeContext, attributes: AlphaAttributes): void attributes.cacheKey)); }; -export const erfImpl = (dataType: string, varType = 'f32') => ` +export const erfImpl = (varType = 'f32') => ` const r0: ${varType} = 0.3275911; const r1: ${varType} = 0.254829592; const r2: ${varType} = -0.284496736; @@ -186,7 +186,7 @@ const r3: ${varType} = 1.421413741; const r4: ${varType} = -1.453152027; const r5: ${varType} = 1.061405429; -fn erf_vf32(v: ${dataType}) -> ${dataType} { +fn erf_vf32(v: vec4<${varType}>) -> vec4<${varType}> { let absv = abs(v); let x = 1.0 / (1.0 + r0 * absv); return sign(v) * (1.0 - ((((r5 * x + r4) * x + r3) * x + r2) * x + r1) * x * exp(-absv * absv)); @@ 
-194,8 +194,7 @@ fn erf_vf32(v: ${dataType}) -> ${dataType} { export const erf = (context: ComputeContext): void => { const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType); - context.compute(createElementwiseProgramInfo( - context.inputs[0], 'Erf', a => `erf_vf32(${a})`, erfImpl(`vec4<${dataType}>`, dataType))); + context.compute(createElementwiseProgramInfo(context.inputs[0], 'Erf', a => `erf_vf32(${a})`, erfImpl(dataType))); }; export const exp = (context: ComputeContext): void => { @@ -209,8 +208,7 @@ export const floor = (context: ComputeContext): void => { export const gelu = (context: ComputeContext): void => { const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType); context.compute(createElementwiseProgramInfo( - context.inputs[0], 'Gelu', a => `0.5 * ${a} * (1.0 + erf_vf32(${a} * 0.7071067811865475))`, - erfImpl(`vec4<${dataType}>`, dataType))); + context.inputs[0], 'Gelu', a => `0.5 * ${a} * (1.0 + erf_vf32(${a} * 0.7071067811865475))`, erfImpl(dataType))); }; export const leakyRelu = (context: ComputeContext, attributes: AlphaAttributes): void => { @@ -278,10 +276,31 @@ export const tan = (context: ComputeContext): void => { context.compute(createElementwiseProgramInfo(context.inputs[0], 'Tan', 'tan')); }; +export const tanhExpression = (a: string) => `sign(${a}) * (1 - exp(-2 * abs(${a}))) / (1 + exp(-2 * abs(${a})))`; + export const tanh = (context: ComputeContext): void => { // TODO: revisit after https://github.com/gpuweb/gpuweb/issues/4458 is resolved + context.compute(createElementwiseProgramInfo(context.inputs[0], 'Tanh', tanhExpression)); +}; + +export const fastGeluImpl = (varType = 'f32') => ` +const fast_gelu_a: ${varType} = 0.5; +const fast_gelu_b: ${varType} = 0.7978845608028654; +const fast_gelu_c: ${varType} = 0.035677408136300125; + +fn tanh_v(v: vec4<${varType}>) -> vec4<${varType}> { + return ${tanhExpression('v')}; +} +`; + +export const fastGeluExpression = (x: string) => + `(fast_gelu_a + fast_gelu_a * 
tanh_v(${x} * (fast_gelu_c * ${x} * ${x} + fast_gelu_b))) * ${x}`; + +export const fastGelu = (context: ComputeContext): void => { + const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType); context.compute(createElementwiseProgramInfo( - context.inputs[0], 'Tanh', a => `sign(${a}) * (1 - exp(-2 * abs(${a}))) / (1 + exp(-2 * abs(${a})))`)); + context.inputs[0], 'FastGelu', fastGeluExpression, fastGeluImpl(dataType), undefined, + context.inputs[0].dataType)); }; export const thresholdedRelu = (context: ComputeContext, attributes: AlphaAttributes): number => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/where.ts b/js/web/lib/wasm/jsep/webgpu/ops/where.ts index 2ef9637bcda5e..a6375847fc42f 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/where.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/where.ts @@ -27,7 +27,7 @@ const createWhereOpProgramShader = const expressionA = `a_data[index_a${x}][component_a${x}]`; const expressionB = `b_data[index_b${x}][component_b${x}]`; // eslint-disable-next-line no-bitwise - const expressionC = `bool(c_data[index_c${x}] & ${0xff000000 >>> ((3 - x) * 8)}u)`; + const expressionC = `bool(c_data[index_c${x}] & (0xffu << (component_c${x} * 8)))`; return ` let output_indices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)}; let offset_a${x} = ${a.broadcastedIndicesToOffset(`output_indices${x}`, output)}; @@ -38,6 +38,7 @@ const createWhereOpProgramShader = let index_c${x} = offset_c${x} / 4u; let component_a${x} = offset_a${x} % 4u; let component_b${x} = offset_b${x} % 4u; + let component_c${x} = offset_c${x} % 4u; ${resStr}[${x}] = ${typeCast}(${expression(expressionA, expressionB, expressionC)}); `; }; @@ -97,10 +98,8 @@ const createWhereOpProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => getRunData: () => ({ outputs: [{dims: outputShape, dataType: outputDataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* vec size */)}, - programUniforms: [ - {type: 'uint32', data: vecSize}, 
...createTensorShapeVariables(dimsC), ...createTensorShapeVariables(dimsA), - ...createTensorShapeVariables(dimsB), ...createTensorShapeVariables(outputShape) - ], + programUniforms: + [{type: DataType.uint32, data: vecSize}, ...createTensorShapeVariables(dimsC, dimsA, dimsB, outputShape)], }), }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/program-manager.ts b/js/web/lib/wasm/jsep/webgpu/program-manager.ts index 72eb9713e26a8..9d05f607f817f 100644 --- a/js/web/lib/wasm/jsep/webgpu/program-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/program-manager.ts @@ -38,7 +38,6 @@ export class ProgramManager { const device = this.backend.device; const computePassEncoder = this.backend.getComputePassEncoder(); this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2); - computePassEncoder.setPipeline(buildArtifact.computePipeline); const entries = []; for (const input of inputs) { entries.push({binding: entries.length, resource: {buffer: input.buffer}}); @@ -51,8 +50,20 @@ export class ProgramManager { } const bindGroup = device.createBindGroup( {layout: buildArtifact.computePipeline.getBindGroupLayout(0), entries, label: buildArtifact.programInfo.name}); - computePassEncoder.setBindGroup(0, bindGroup); + if (this.backend.sessionStatus === 'capturing') { + const commandInfo = { + kernelId: this.backend.currentKernelId!, + computePipeline: buildArtifact.computePipeline, + bindGroup, + dispatchGroup + }; + const sessionCommandList = this.backend.capturedCommandList.get(this.backend.currentSessionId!); + sessionCommandList!.push(commandInfo); + } + + computePassEncoder.setPipeline(buildArtifact.computePipeline); + computePassEncoder.setBindGroup(0, bindGroup); computePassEncoder.dispatchWorkgroups(...dispatchGroup); this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2 + 1); this.backend.pendingDispatchNumber++; diff --git a/js/web/lib/wasm/jsep/webgpu/types.ts b/js/web/lib/wasm/jsep/webgpu/types.ts index 789ac70a6913a..48e0855f01a97 100644 --- 
a/js/web/lib/wasm/jsep/webgpu/types.ts +++ b/js/web/lib/wasm/jsep/webgpu/types.ts @@ -1,10 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {DataType} from '../../wasm-common'; import {TensorView} from '../tensor-view'; import {ShaderHelper} from './ops/common'; +export type SessionState = 'default'|'capturing'|'replaying'; + export enum GpuDataType { default = 0, upload = 1, @@ -12,6 +15,13 @@ export enum GpuDataType { } export type GpuDataId = number; +export type GpuArchitecture = 'ampere'; +export type GpuVendor = 'amd'|'intel'|'nvidia'; +export interface AdapterInfo { + isArchitecture: (architecture: GpuArchitecture) => boolean; + isVendor: (vendor: GpuVendor) => boolean; +} + export interface GpuData { type: GpuDataType; id: GpuDataId; @@ -24,7 +34,7 @@ export interface TensorInfo { } export interface ProgramUniform { - type: 'int32'|'float16'|'float32'|'uint32'; + type: DataType; data: number|readonly number[]; } @@ -143,6 +153,11 @@ export interface ComputeContextInputsOutputsMapping { * A ComputeContext instance carries the states that representing the current running of a kernel. 
*/ export interface ComputeContext { + /** + * gpu adapter info + */ + readonly adapterInfo: AdapterInfo; + /** * stores the pointer to OpKernelContext */ diff --git a/js/web/lib/wasm/proxy-worker/main.ts b/js/web/lib/wasm/proxy-worker/main.ts index 6cbd38c76ccc8..3ce37a2d6b652 100644 --- a/js/web/lib/wasm/proxy-worker/main.ts +++ b/js/web/lib/wasm/proxy-worker/main.ts @@ -103,7 +103,7 @@ self.onmessage = (ev: MessageEvent): void => { } else { postMessage( {type, out: outputs} as OrtWasmMessage, - extractTransferableBuffers(outputs as SerializableTensorMetadata[])); + extractTransferableBuffers([...inputs, ...outputs] as SerializableTensorMetadata[])); } }, err => { diff --git a/js/web/lib/wasm/proxy-wrapper.ts b/js/web/lib/wasm/proxy-wrapper.ts index 86017a4ec6904..6ff4e86b1235e 100644 --- a/js/web/lib/wasm/proxy-wrapper.ts +++ b/js/web/lib/wasm/proxy-wrapper.ts @@ -155,7 +155,7 @@ export const createSession = ensureWorker(); return new Promise((resolve, reject) => { enqueueCallbacks('create', [resolve, reject]); - const message: OrtWasmMessage = {type: 'create', in : {model, options}}; + const message: OrtWasmMessage = {type: 'create', in : {model, options: {...options}}}; const transferable: Transferable[] = []; if (model instanceof Uint8Array) { transferable.push(model.buffer); diff --git a/js/web/lib/wasm/session-options.ts b/js/web/lib/wasm/session-options.ts index 41ab2d52ca209..48eac57494726 100644 --- a/js/web/lib/wasm/session-options.ts +++ b/js/web/lib/wasm/session-options.ts @@ -168,6 +168,18 @@ export const setSessionOptions = (options?: InferenceSession.SessionOptions): [n setExecutionProviders(sessionOptionsHandle, sessionOptions.executionProviders, allocs); } + if (sessionOptions.enableGraphCapture !== undefined) { + if (typeof sessionOptions.enableGraphCapture !== 'boolean') { + throw new Error(`enableGraphCapture must be a boolean value: ${sessionOptions.enableGraphCapture}`); + } + const keyDataOffset = allocWasmString('enableGraphCapture', 
allocs); + const valueDataOffset = allocWasmString(sessionOptions.enableGraphCapture.toString(), allocs); + if (wasm._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) { + checkLastError( + `Can't set a session config entry: 'enableGraphCapture' - ${sessionOptions.enableGraphCapture}.`); + } + } + if (sessionOptions.freeDimensionOverrides) { for (const [name, value] of Object.entries(sessionOptions.freeDimensionOverrides)) { if (typeof name !== 'string') { diff --git a/js/web/lib/wasm/wasm-common.ts b/js/web/lib/wasm/wasm-common.ts index b9eff45e890c4..54eaf5e0c43cc 100644 --- a/js/web/lib/wasm/wasm-common.ts +++ b/js/web/lib/wasm/wasm-common.ts @@ -3,6 +3,12 @@ import {Tensor} from 'onnxruntime-common'; +// a dummy type declaration for Float16Array in case any polyfill is available. +declare global { + // eslint-disable-next-line @typescript-eslint/naming-convention, @typescript-eslint/no-explicit-any + const Float16Array: any; +} + // This file includes common definitions. They do NOT have dependency on the WebAssembly instance. /** @@ -117,7 +123,8 @@ export const tensorTypeToTypedArrayConstructor = (type: Tensor.Type): Float32Arr Uint8ArrayConstructor|Float64ArrayConstructor|Uint32ArrayConstructor|BigUint64ArrayConstructor => { switch (type) { case 'float16': - return Uint16Array; + // allow Float16Array polyfill. + return typeof Float16Array !== 'undefined' && Float16Array.from ? 
Float16Array : Uint16Array; case 'float32': return Float32Array; case 'uint8': @@ -169,7 +176,8 @@ export const logLevelStringToEnum = (logLevel?: 'verbose'|'info'|'warning'|'erro * Check whether the given tensor type is supported by GPU buffer */ export const isGpuBufferSupportedType = (type: Tensor.Type): type is Tensor.GpuBufferDataTypes => type === 'float32' || - type === 'int32' || type === 'int64' || type === 'bool' || type === 'float16' || type === 'uint32'; + type === 'float16' || type === 'int32' || type === 'int64' || type === 'uint32' || type === 'uint8' || + type === 'bool'; /** * Map string data location to integer value diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 046336dc9cac0..7019758be0efd 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -84,27 +84,44 @@ export const initRuntime = async(env: Env): Promise => { * @param epName */ export const initEp = async(env: Env, epName: string): Promise => { - if (!BUILD_DEFS.DISABLE_WEBGPU && (epName === 'webgpu' || epName === 'webnn')) { - // perform WebGPU availability check - if (typeof navigator === 'undefined' || !navigator.gpu) { - throw new Error('WebGPU is not supported in current environment'); - } - const adapter = await navigator.gpu.requestAdapter(); - if (!adapter) { - throw new Error( - 'Failed to get GPU adapter. You may need to enable flag "--enable-unsafe-webgpu" if you are using Chrome.'); - } + if (!BUILD_DEFS.DISABLE_WEBGPU) { + // eslint-disable-next-line @typescript-eslint/no-require-imports, @typescript-eslint/no-var-requires + const initJsep = require('./jsep/init').init; - if (!env.wasm.simd) { - throw new Error( - 'Not supported for WebGPU=ON and SIMD=OFF. 
Please set `env.wasm.simd` to true when using `webgpu` EP'); - } + if (epName === 'webgpu') { + // perform WebGPU availability check + if (typeof navigator === 'undefined' || !navigator.gpu) { + throw new Error('WebGPU is not supported in current environment'); + } + const powerPreference = env.webgpu?.powerPreference; + if (powerPreference !== undefined && powerPreference !== 'low-power' && powerPreference !== 'high-performance') { + throw new Error(`Invalid powerPreference setting: "${powerPreference}"`); + } + const forceFallbackAdapter = env.webgpu?.forceFallbackAdapter; + if (forceFallbackAdapter !== undefined && typeof forceFallbackAdapter !== 'boolean') { + throw new Error(`Invalid forceFallbackAdapter setting: "${forceFallbackAdapter}"`); + } + const adapter = await navigator.gpu.requestAdapter({powerPreference, forceFallbackAdapter}); + if (!adapter) { + throw new Error( + 'Failed to get GPU adapter. You may need to enable flag "--enable-unsafe-webgpu" if you are using Chrome.'); + } - // init JSEP if available + if (!env.wasm.simd) { + throw new Error( + 'Not supported for WebGPU=ON and SIMD=OFF. 
Please set `env.wasm.simd` to true when using `webgpu` EP'); + } - // eslint-disable-next-line @typescript-eslint/no-require-imports, @typescript-eslint/no-var-requires - const initJsep = require('./jsep/init').init; - await initJsep(getInstance(), env, adapter); + await initJsep('webgpu', getInstance(), env, adapter); + } + if (epName === 'webnn') { + // perform WebNN availability check + if (typeof navigator === 'undefined' || !(navigator as unknown as {ml: unknown}).ml) { + throw new Error('WebNN is not supported in current environment'); + } + + await initJsep('webnn', getInstance(), env); + } } }; @@ -139,7 +156,7 @@ type IOBindingState = { */ type SessionMetadata = [ inferenceSessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[], - bindingState: IOBindingState|null + bindingState: IOBindingState|null, enableGraphCapture: boolean, inputOutputBound: boolean ]; const activeSessions = new Map(); @@ -235,6 +252,8 @@ export const createSession = async( const [inputCount, outputCount] = getSessionInputOutputCount(sessionHandle); + const enableGraphCapture = !!options?.enableGraphCapture; + const inputNames = []; const outputNames = []; const outputPreferredLocations: SupportedTensorDataLocationForInputOutput[] = []; @@ -256,12 +275,20 @@ export const createSession = async( outputNames.push(nameString); if (!BUILD_DEFS.DISABLE_WEBGPU) { + if (enableGraphCapture && options?.preferredOutputLocation === undefined) { + outputPreferredLocations.push('gpu-buffer'); + continue; + } const location = typeof options?.preferredOutputLocation === 'string' ? options.preferredOutputLocation : options?.preferredOutputLocation?.[nameString] ?? 'cpu'; if (location !== 'cpu' && location !== 'cpu-pinned' && location !== 'gpu-buffer') { throw new Error(`Not supported preferred output location: ${location}.`); } + if (enableGraphCapture && location !== 'gpu-buffer') { + throw new Error(`Not supported preferred output location: ${ + location}. 
Only 'gpu-buffer' location is supported when enableGraphCapture is true.`); + } outputPreferredLocations.push(location); } } @@ -281,7 +308,9 @@ export const createSession = async( }; } - activeSessions.set(sessionHandle, [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, bindingState]); + activeSessions.set( + sessionHandle, + [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, bindingState, enableGraphCapture, false]); return [sessionHandle, inputNames, outputNames]; } catch (e) { inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); @@ -313,13 +342,16 @@ export const releaseSession = (sessionId: number): void => { if (!session) { throw new Error(`cannot release session. invalid session id: ${sessionId}`); } - const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState] = session; + const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState, enableGraphCapture] = session; if (ioBindingState) { + if (enableGraphCapture) { + wasm._OrtClearBoundOutputs(ioBindingState.handle); + } wasm._OrtReleaseBinding(ioBindingState.handle); } - wasm.jsepUnregisterBuffers?.(sessionId); + wasm.jsepOnReleaseSession?.(sessionId); inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); @@ -328,70 +360,80 @@ export const releaseSession = (sessionId: number): void => { }; export const prepareInputOutputTensor = - (tensor: TensorMetadata|null, tensorHandles: number[], allocs: number[], sessionId: number, index: number): - void => { - if (!tensor) { - tensorHandles.push(0); - return; - } + (tensor: TensorMetadata|null, tensorHandles: number[], allocs: number[], sessionId: number, index: number, + enableGraphCapture = false): void => { + if (!tensor) { + tensorHandles.push(0); + return; + } - const wasm = getInstance(); + const wasm = getInstance(); - const dataType = tensor[0]; - const dims = tensor[1]; - const location = tensor[3]; + const dataType = 
tensor[0]; + const dims = tensor[1]; + const location = tensor[3]; - let rawData: number; - let dataByteLength: number; + let rawData: number; + let dataByteLength: number; - if (dataType === 'string' && location === 'gpu-buffer') { - throw new Error('String tensor is not supported on GPU.'); - } + if (dataType === 'string' && location === 'gpu-buffer') { + throw new Error('String tensor is not supported on GPU.'); + } - if (location === 'gpu-buffer') { - const gpuBuffer = tensor[2].gpuBuffer as GPUBuffer; - const elementSizeInBytes = getTensorElementSize(tensorDataTypeStringToEnum(dataType))!; - dataByteLength = dims.reduce((a, b) => a * b, 1) * elementSizeInBytes; - rawData = wasm.jsepRegisterBuffer(sessionId, index, gpuBuffer, dataByteLength); - } else { - const data = tensor[2]; - - if (Array.isArray(data)) { - // string tensor - dataByteLength = 4 * data.length; - rawData = wasm._malloc(dataByteLength); - allocs.push(rawData); - let dataIndex = rawData / 4; - for (let i = 0; i < data.length; i++) { - if (typeof data[i] !== 'string') { - throw new TypeError(`tensor data at index ${i} is not a string`); - } - wasm.HEAPU32[dataIndex++] = allocWasmString(data[i], allocs); - } - } else { - dataByteLength = data.byteLength; - rawData = wasm._malloc(dataByteLength); - allocs.push(rawData); - wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, dataByteLength), rawData); - } - } + if (enableGraphCapture && location !== 'gpu-buffer') { + throw new Error( + `External buffer must be provided for input/output index ${index} when enableGraphCapture is true.`); + } + + if (location === 'gpu-buffer') { + const gpuBuffer = tensor[2].gpuBuffer as GPUBuffer; + const elementSizeInBytes = getTensorElementSize(tensorDataTypeStringToEnum(dataType))!; + dataByteLength = dims.reduce((a, b) => a * b, 1) * elementSizeInBytes; - const stack = wasm.stackSave(); - const dimsOffset = wasm.stackAlloc(4 * dims.length); - try { - let dimIndex = dimsOffset / 4; - dims.forEach(d => 
wasm.HEAP32[dimIndex++] = d); - const tensor = wasm._OrtCreateTensor( - tensorDataTypeStringToEnum(dataType), rawData, dataByteLength, dimsOffset, dims.length, - dataLocationStringToEnum(location)); - if (tensor === 0) { - checkLastError(`Can't create tensor for input/output. session=${sessionId}, index=${index}.`); + const registerBuffer = wasm.jsepRegisterBuffer; + if (!registerBuffer) { + throw new Error('Tensor location "gpu-buffer" is not supported without using WebGPU.'); + } + rawData = registerBuffer(sessionId, index, gpuBuffer, dataByteLength); + } else { + const data = tensor[2]; + + if (Array.isArray(data)) { + // string tensor + dataByteLength = 4 * data.length; + rawData = wasm._malloc(dataByteLength); + allocs.push(rawData); + let dataIndex = rawData / 4; + for (let i = 0; i < data.length; i++) { + if (typeof data[i] !== 'string') { + throw new TypeError(`tensor data at index ${i} is not a string`); } - tensorHandles.push(tensor); - } finally { - wasm.stackRestore(stack); + wasm.HEAPU32[dataIndex++] = allocWasmString(data[i], allocs); } - }; + } else { + dataByteLength = data.byteLength; + rawData = wasm._malloc(dataByteLength); + allocs.push(rawData); + wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, dataByteLength), rawData); + } + } + + const stack = wasm.stackSave(); + const dimsOffset = wasm.stackAlloc(4 * dims.length); + try { + let dimIndex = dimsOffset / 4; + dims.forEach(d => wasm.HEAP32[dimIndex++] = d); + const tensor = wasm._OrtCreateTensor( + tensorDataTypeStringToEnum(dataType), rawData, dataByteLength, dimsOffset, dims.length, + dataLocationStringToEnum(location)); + if (tensor === 0) { + checkLastError(`Can't create tensor for input/output. session=${sessionId}, index=${index}.`); + } + tensorHandles.push(tensor); + } finally { + wasm.stackRestore(stack); + } + }; /** * perform inference run @@ -404,7 +446,12 @@ export const run = async( if (!session) { throw new Error(`cannot run inference. 
invalid session id: ${sessionId}`); } - const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState] = session; + const sessionHandle = session[0]; + const inputNamesUTF8Encoded = session[1]; + const outputNamesUTF8Encoded = session[2]; + const ioBindingState = session[3]; + const enableGraphCapture = session[4]; + const inputOutputBound = session[5]; const inputCount = inputIndices.length; const outputCount = outputIndices.length; @@ -427,13 +474,15 @@ export const run = async( // create input tensors for (let i = 0; i < inputCount; i++) { - prepareInputOutputTensor(inputTensors[i], inputTensorHandles, inputOutputAllocs, sessionId, inputIndices[i]); + prepareInputOutputTensor( + inputTensors[i], inputTensorHandles, inputOutputAllocs, sessionId, inputIndices[i], enableGraphCapture); } // create output tensors for (let i = 0; i < outputCount; i++) { prepareInputOutputTensor( - outputTensors[i], outputTensorHandles, inputOutputAllocs, sessionId, inputCount + outputIndices[i]); + outputTensors[i], outputTensorHandles, inputOutputAllocs, sessionId, inputCount + outputIndices[i], + enableGraphCapture); } let inputValuesIndex = inputValuesOffset / 4; @@ -449,7 +498,7 @@ export const run = async( wasm.HEAPU32[outputNamesIndex++] = outputNamesUTF8Encoded[outputIndices[i]]; } - if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState) { + if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState && !inputOutputBound) { const {handle, outputPreferredLocations, outputPreferredLocationsEncoded} = ioBindingState; if (inputNamesUTF8Encoded.length !== inputCount) { @@ -486,9 +535,12 @@ export const run = async( } } } + activeSessions.set( + sessionId, + [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState, enableGraphCapture, true]); } - wasm.jsepOnRunStart?.(); + wasm.jsepOnRunStart?.(sessionHandle); let errorCode: number; if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState) { errorCode = await wasm._OrtRunWithBinding( @@ -557,7 +609,11 @@ export 
const run = async( // If a certain output's preferred location is GPU but the tensor is empty, we still need to create a CPU // tensor for it. There is no mapping GPU buffer for an empty tensor. if (preferredLocation === 'gpu-buffer' && size > 0) { - const gpuBuffer = wasm.jsepGetBuffer(dataOffset); + const getBuffer = wasm.jsepGetBuffer; + if (!getBuffer) { + throw new Error('preferredLocation "gpu-buffer" is not supported without using WebGPU.'); + } + const gpuBuffer = getBuffer(dataOffset); const elementSize = getTensorElementSize(dataType); if (elementSize === undefined || !isGpuBufferSupportedType(type)) { throw new Error(`Unsupported data type: ${type}`); @@ -569,7 +625,7 @@ export const run = async( output.push([ type, dims, { gpuBuffer, - download: wasm.jsepCreateDownloader(gpuBuffer, size * elementSize, type), + download: wasm.jsepCreateDownloader!(gpuBuffer, size * elementSize, type), dispose: () => { wasm._OrtReleaseTensor(tensor); } @@ -595,10 +651,12 @@ export const run = async( } } - if (ioBindingState) { + if (ioBindingState && !enableGraphCapture) { wasm._OrtClearBoundOutputs(ioBindingState.handle); + activeSessions.set( + sessionId, + [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState, enableGraphCapture, false]); } - return output; } finally { wasm.stackRestore(beforeRunStack); diff --git a/js/web/package-lock.json b/js/web/package-lock.json index 41c44aaa2679b..5c9113459ff06 100644 --- a/js/web/package-lock.json +++ b/js/web/package-lock.json @@ -52,7 +52,7 @@ "version": "1.18.0", "license": "MIT", "devDependencies": { - "typedoc": "^0.23.22" + "typedoc": "^0.25.7" } }, "node_modules/@chiragrupani/karma-chromium-edge-launcher": { @@ -1351,9 +1351,9 @@ "dev": true }, "node_modules/follow-redirects": { - "version": "1.15.4", - "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.4.tgz", - "integrity": 
"sha512-Cr4D/5wlrb0z9dgERpUL3LrmPKVDsETIJhaCMeDfuFYcqa5bldGV6wBsAN6X/vxlXQtFBMrXdXxdL8CbDTGniw==", + "version": "1.15.6", + "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.6.tgz", + "integrity": "sha512-wWN62YITEaOpSK584EZXJafH1AGpO8RVgElfkuXbTOrPX4fIfOyEpW/CsiNd8JdYrAoOvafRTOEnvsO++qCqFA==", "dev": true, "funding": [ { @@ -4595,9 +4595,9 @@ "dev": true }, "follow-redirects": { - "version": "1.15.4", - "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.4.tgz", - "integrity": "sha512-Cr4D/5wlrb0z9dgERpUL3LrmPKVDsETIJhaCMeDfuFYcqa5bldGV6wBsAN6X/vxlXQtFBMrXdXxdL8CbDTGniw==", + "version": "1.15.6", + "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.6.tgz", + "integrity": "sha512-wWN62YITEaOpSK584EZXJafH1AGpO8RVgElfkuXbTOrPX4fIfOyEpW/CsiNd8JdYrAoOvafRTOEnvsO++qCqFA==", "dev": true }, "from": { @@ -5503,7 +5503,7 @@ "onnxruntime-common": { "version": "file:../common", "requires": { - "typedoc": "^0.23.22" + "typedoc": "^0.25.7" } }, "p-cancelable": { diff --git a/js/web/package.json b/js/web/package.json index a502c2b6b032d..55c3a3238bafc 100644 --- a/js/web/package.json +++ b/js/web/package.json @@ -69,11 +69,14 @@ "exports": { ".": { "node": "./dist/ort.node.min.js", + "types": "./types.d.ts", "default": { "import": "./dist/esm/ort.min.js", "require": "./dist/cjs/ort.min.js", + "types": "./types.d.ts", "default": { "development": "./dist/ort.js", + "types": "./types.d.ts", "default": "./dist/ort.min.js" } } @@ -81,34 +84,41 @@ "./experimental": { "import": "./dist/esm/ort.all.min.js", "require": "./dist/cjs/ort.all.min.js", + "types": "./types.d.ts", "default": { "development": "./dist/ort.all.js", + "types": "./types.d.ts", "default": "./dist/ort.all.min.js" } }, "./wasm": { "import": "./dist/esm/ort.wasm.min.js", "require": "./dist/cjs/ort.wasm.min.js", + "types": "./types.d.ts", "default": "./dist/ort.wasm.min.js" }, "./wasm-core": { "import": 
"./dist/esm/ort.wasm-core.min.js", "require": "./dist/cjs/ort.wasm-core.min.js", + "types": "./types.d.ts", "default": "./dist/ort.wasm-core.min.js" }, "./webgl": { "import": "./dist/esm/ort.webgl.min.js", "require": "./dist/cjs/ort.webgl.min.js", + "types": "./types.d.ts", "default": "./dist/ort.webgl.min.js" }, "./webgpu": { "import": "./dist/esm/ort.webgpu.min.js", "require": "./dist/cjs/ort.webgpu.min.js", + "types": "./types.d.ts", "default": "./dist/ort.webgpu.min.js" }, "./training": { "import": "./dist/esm/ort.training.wasm.min.js", "require": "./dist/cjs/ort.training.wasm.min.js", + "types": "./types.d.ts", "default": "./dist/ort.training.wasm.min.js" } }, diff --git a/js/web/script/test-runner-cli-args.ts b/js/web/script/test-runner-cli-args.ts index ed4dd76a6e315..b2b212bdb9bc1 100644 --- a/js/web/script/test-runner-cli-args.ts +++ b/js/web/script/test-runner-cli-args.ts @@ -29,8 +29,10 @@ Options: *** General Options *** -h, --help Print this message. - -d, --debug Specify to run test runner in debug mode. - Debug mode outputs verbose log for test runner, sets up environment debug flag, and keeps karma not to exit after tests completed. + -d, --debug Specify to run test runner in debug mode. Debug mode does the following: + - outputs verbose log for test runner + - sets up environment debug flag (env.debug = true) + - opens Chromium debug port at 9333 and keeps karma not to exit after tests completed. -b=<...>, --backend=<...> Specify one or more backend(s) to run the test upon. Backends can be one or more of the following, splitted by comma: webgl @@ -47,38 +49,55 @@ Options: bs (for BrowserStack tests) -p, --profile Enable profiler. Profiler will generate extra logs which include the information of events time consumption + -t, --trace Enable trace. -P[=<...>], --perf[=<...>] Generate performance number. Cannot be used with flag --debug. This flag can be used with a number as value, specifying the total count of test cases to run. 
The test cases may be used multiple times. Default value is 10. -c, --file-cache Enable file cache. + +*** Session Options *** + -u=<...>, --optimized-model-file-path=<...> Specify whether to dump the optimized model. + -o=<...>, --graph-optimization-level=<...> Specify graph optimization level. + Default is 'all'. Valid values are 'disabled', 'basic', 'extended', 'all'. -i=<...>, --io-binding=<...> Specify the IO binding testing type. Should be one of the following: - none (default) + none (default) gpu-tensor use pre-allocated GPU tensors for inputs and outputs gpu-location use pre-allocated GPU tensors for inputs and set preferredOutputLocation to 'gpu-buffer' -*** Session Options *** - -u=<...>, --optimized-model-file-path=<...> Specify whether to dump the optimized model. - -o=<...>, --graph-optimization-level=<...> Specify graph optimization level. - Default is 'all'. Valid values are 'disabled', 'basic', 'extended', 'all'. *** Logging Options *** - --log-verbose=<...> Set log level to verbose - --log-info=<...> Set log level to info - --log-warning=<...> Set log level to warning - --log-error=<...> Set log level to error - The 4 flags above specify the logging configuration. Each flag allows to specify one or more category(s), splitted by comma. If use the flags without value, the log level will be applied to all category. + --log-verbose Set log level to verbose + --log-info Set log level to info + --log-warning Set log level to warning + --log-error Set log level to error + The 4 flags above specify the logging configuration. *** Backend Options *** + --wasm.<...>=<...> Set global environment flags for each backend. + --webgl.<...>=<...> These flags can be used multiple times to set multiple flags. 
For example: + --webgpu.<...>=<...> --webgpu.profiling.mode=default --wasm.numThreads=1 --wasm.simd=false + --webnn.<...>=<...> + + --webnn-device-type Set the WebNN device type (cpu/gpu) + -x, --wasm-number-threads Set the WebAssembly number of threads + ("--wasm-number-threads" is deprecated. use "--wasm.numThreads" or "-x" instead) --wasm-init-timeout Set the timeout for WebAssembly backend initialization, in milliseconds + (deprecated. use "--wasm.initTimeout" instead) --wasm-enable-simd Set whether to enable SIMD + (deprecated. use "--wasm.simd" instead) --wasm-enable-proxy Set whether to enable proxy worker + (deprecated. use "--wasm.proxy" instead) --webgl-context-id Set the WebGL context ID (webgl/webgl2) + (deprecated. use "--webgl.contextId" instead) --webgl-matmul-max-batch-size Set the WebGL matmulMaxBatchSize + (deprecated. use "--webgl.matmulMaxBatchSize" instead) --webgl-texture-cache-mode Set the WebGL texture cache mode (initializerOnly/full) + (deprecated. use "--webgl.textureCacheMode" instead) --webgl-texture-pack-mode Set the WebGL texture pack mode (true/false) + (deprecated. use "--webgl.pack" instead) --webgpu-profiling-mode Set the WebGPU profiling mode (off/default) - --webnn-device-type Set the WebNN device type (cpu/gpu) + (deprecated. 
use "--webgpu.profiling.mode" instead) *** Browser Options *** @@ -171,7 +190,6 @@ export interface TestRunnerCliArgs { cpuOptions?: InferenceSession.CpuExecutionProviderOption; cudaOptions?: InferenceSession.CudaExecutionProviderOption; - cudaFlags?: Record; wasmOptions?: InferenceSession.WebAssemblyExecutionProviderOption; webglOptions?: InferenceSession.WebGLExecutionProviderOption; webnnOptions?: InferenceSession.WebNNExecutionProviderOption; @@ -260,40 +278,29 @@ function parseCpuOptions(_args: minimist.ParsedArgs): InferenceSession.CpuExecut return {name: 'cpu'}; } -function parseCpuFlags(_args: minimist.ParsedArgs): Record { - return {}; -} - function parseWasmOptions(_args: minimist.ParsedArgs): InferenceSession.WebAssemblyExecutionProviderOption { return {name: 'wasm'}; } function parseWasmFlags(args: minimist.ParsedArgs): Env.WebAssemblyFlags { - const numThreads = args.x || args['wasm-number-threads']; + const wasm = args.wasm || {}; + const numThreads = wasm.numThreads = wasm.numThreads ?? (args.x ?? args['wasm-number-threads']); if (typeof numThreads !== 'undefined' && typeof numThreads !== 'number') { - throw new Error('Flag "x"/"wasm-number-threads" must be a number value'); + throw new Error('Flag "wasm.numThreads"/"x"/"wasm-number-threads" must be a number value'); } - const initTimeout = args['wasm-init-timeout']; + const initTimeout = wasm.initTimeout = wasm.initTimeout ?? 
args['wasm-init-timeout']; if (typeof initTimeout !== 'undefined' && typeof initTimeout !== 'number') { - throw new Error('Flag "wasm-init-timeout" must be a number value'); - } - let simd = args['wasm-enable-simd']; - if (simd === 'true') { - simd = true; - } else if (simd === 'false') { - simd = false; - } else if (typeof simd !== 'undefined' && typeof simd !== 'boolean') { - throw new Error('Flag "wasm-enable-simd" must be a boolean value'); - } - let proxy = args['wasm-enable-proxy']; - if (proxy === 'true') { - proxy = true; - } else if (proxy === 'false') { - proxy = false; - } else if (typeof proxy !== 'undefined' && typeof proxy !== 'boolean') { - throw new Error('Flag "wasm-enable-proxy" must be a boolean value'); - } - return {numThreads, initTimeout, simd, proxy}; + throw new Error('Flag "wasm.initTimeout"/"wasm-init-timeout" must be a number value'); + } + const simd = wasm.simd = parseBooleanArg(wasm.simd ?? args['wasm-enable-simd']); + if (typeof simd !== 'undefined' && typeof simd !== 'boolean') { + throw new Error('Flag "wasm.simd"/"wasm-enable-simd" must be a boolean value'); + } + const proxy = wasm.proxy = parseBooleanArg(wasm.proxy ?? args['wasm-enable-proxy']); + if (typeof proxy !== 'undefined' && typeof proxy !== 'boolean') { + throw new Error('Flag "wasm.proxy"/"wasm-enable-proxy" must be a boolean value'); + } + return wasm; } function parseWebglOptions(_args: minimist.ParsedArgs): InferenceSession.WebGLExecutionProviderOption { @@ -301,39 +308,43 @@ function parseWebglOptions(_args: minimist.ParsedArgs): InferenceSession.WebGLEx } function parseWebglFlags(args: minimist.ParsedArgs): Partial { - const contextId = args['webgl-context-id']; + const webgl = args.webgl || {}; + const contextId = webgl.contextId = webgl.contextId ?? 
args['webgl-context-id']; if (contextId !== undefined && contextId !== 'webgl' && contextId !== 'webgl2') { - throw new Error('Flag "webgl-context-id" is invalid'); + throw new Error('Flag "webgl.contextId"/"webgl-context-id" is invalid'); } - const matmulMaxBatchSize = args['webgl-matmul-max-batch-size']; + const matmulMaxBatchSize = webgl.matmulMaxBatchSize = webgl.matmulMaxBatchSize ?? args['webgl-matmul-max-batch-size']; if (matmulMaxBatchSize !== undefined && typeof matmulMaxBatchSize !== 'number') { - throw new Error('Flag "webgl-matmul-max-batch-size" must be a number value'); + throw new Error('Flag "webgl.matmulMaxBatchSize"/"webgl-matmul-max-batch-size" must be a number value'); } - const textureCacheMode = args['webgl-texture-cache-mode']; + const textureCacheMode = webgl.textureCacheMode = webgl.textureCacheMode ?? args['webgl-texture-cache-mode']; if (textureCacheMode !== undefined && textureCacheMode !== 'initializerOnly' && textureCacheMode !== 'full') { - throw new Error('Flag "webgl-texture-cache-mode" is invalid'); + throw new Error('Flag "webgl.textureCacheMode"/"webgl-texture-cache-mode" is invalid'); } - const pack = args['webgl-texture-pack-mode']; + const pack = webgl.pack = parseBooleanArg(webgl.pack ?? args['webgl-texture-pack-mode']); if (pack !== undefined && typeof pack !== 'boolean') { - throw new Error('Flag "webgl-texture-pack-mode" is invalid'); + throw new Error('Flag "webgl.pack"/"webgl-texture-pack-mode" is invalid'); } - const async = args['webgl-async']; + const async = webgl.async = parseBooleanArg(webgl.async ?? 
args['webgl-async']); if (async !== undefined && typeof async !== 'boolean') { - throw new Error('Flag "webgl-async" is invalid'); + throw new Error('Flag "webgl.async"/"webgl-async" is invalid'); } - return {contextId, matmulMaxBatchSize, textureCacheMode, pack}; + return webgl; } function parseWebgpuFlags(args: minimist.ParsedArgs): Partial { - const profilingMode = args['webgpu-profiling-mode']; + const webgpu = args.webgpu || {}; + const profilingMode = (webgpu.profiling = webgpu.profiling ?? {}).mode = + webgpu?.profiling?.mode ?? webgpu.profilingMode ?? args['webgpu-profiling-mode']; if (profilingMode !== undefined && profilingMode !== 'off' && profilingMode !== 'default') { throw new Error('Flag "webgpu-profiling-mode" is invalid'); } - const validateInputContent = args['webgpu-validate-input-content']; + const validateInputContent = webgpu.validateInputContent = + parseBooleanArg(webgpu.validateInputContent ?? args['webgpu-validate-input-content']); if (validateInputContent !== undefined && typeof validateInputContent !== 'boolean') { throw new Error('Flag "webgpu-validate-input-content" is invalid'); } - return {profilingMode, validateInputContent}; + return webgpu; } function parseWebNNOptions(args: minimist.ParsedArgs): InferenceSession.WebNNExecutionProviderOption { @@ -344,12 +355,11 @@ function parseWebNNOptions(args: minimist.ParsedArgs): InferenceSession.WebNNExe return {name: 'webnn', deviceType}; } -function parseGlobalEnvFlags(args: minimist.ParsedArgs): NonNullable { +function parseGlobalEnvFlags(args: minimist.ParsedArgs) { const wasm = parseWasmFlags(args); const webgl = parseWebglFlags(args); const webgpu = parseWebgpuFlags(args); - const cpuFlags = parseCpuFlags(args); - return {webgl, wasm, webgpu, ...cpuFlags}; + return {webgl, wasm, webgpu}; } export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs { @@ -394,15 +404,14 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs } } - const 
globalEnvFlags = parseGlobalEnvFlags(args); - // Options: // --log-verbose=<...> // --log-info=<...> // --log-warning=<...> // --log-error=<...> const logConfig = parseLogConfig(args); - globalEnvFlags.logLevel = logConfig[0]?.config.minimalSeverity; + let logLevel = logConfig[0]?.config.minimalSeverity; + // Option: -p, --profile const profile = (args.profile || args.p) ? true : false; if (profile) { @@ -410,9 +419,18 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs logConfig.push({category: 'Profiler.node', config: {minimalSeverity: 'verbose'}}); logConfig.push({category: 'Profiler.op', config: {minimalSeverity: 'verbose'}}); logConfig.push({category: 'Profiler.backend', config: {minimalSeverity: 'verbose'}}); - globalEnvFlags.logLevel = 'verbose'; + logLevel = 'verbose'; } + // Option: -t, --trace + const trace = parseBooleanArg(args.trace || args.t, false); + + // Options: + // --wasm.<...>=<...> + // --webgl.<...>=<...> + // --webgpu.<...>=<...> + const globalEnvFlags = {...parseGlobalEnvFlags(args), debug, trace, logLevel}; + // Option: -P[=<...>], --perf[=<...>] const perfArg = (args.perf || args.P); const perf = perfArg ? 
true : false; diff --git a/js/web/script/test-runner-cli.ts b/js/web/script/test-runner-cli.ts index d56792c6e3595..ace64e9532b12 100644 --- a/js/web/script/test-runner-cli.ts +++ b/js/web/script/test-runner-cli.ts @@ -12,6 +12,7 @@ import * as os from 'os'; import * as path from 'path'; import {inspect} from 'util'; +import {onnx} from '../lib/onnxjs/ort-schema/protobuf/onnx'; import {bufferToBase64} from '../test/test-shared'; import {Test} from '../test/test-types'; @@ -264,10 +265,12 @@ async function main() { let modelUrl: string|null = null; let cases: Test.ModelTestCase[] = []; + let externalData: Array<{data: string; path: string}>|undefined; npmlog.verbose('TestRunnerCli.Init.Model', `Start to prepare test data from folder: ${testDataRootFolder}`); try { + const maybeExternalDataFiles: Array<[fileNameWithoutExtension: string, size: number]> = []; for (const thisPath of fs.readdirSync(testDataRootFolder)) { const thisFullPath = path.join(testDataRootFolder, thisPath); const stat = fs.lstatSync(thisFullPath); @@ -282,6 +285,8 @@ async function main() { } else { throw new Error('there are multiple model files under the folder specified'); } + } else { + maybeExternalDataFiles.push([path.parse(thisPath).name, stat.size]); } } else if (stat.isDirectory()) { const dataFiles: string[] = []; @@ -307,6 +312,34 @@ async function main() { if (modelUrl === null) { throw new Error('there are no model file under the folder specified'); } + // for performance consideration, we do not parse every model. when we think it's likely to have external + // data, we will parse it. We think it's "likely" when one of the following conditions is met: + // 1. any file in the same folder has the similar file name as the model file + // (e.g., model file is "model_abc.onnx", and there is a file "model_abc.pb" or "model_abc.onnx.data") + // 2. 
the file size is larger than 1GB + const likelyToHaveExternalData = maybeExternalDataFiles.some( + ([fileNameWithoutExtension, size]) => + path.basename(modelUrl!).startsWith(fileNameWithoutExtension) || size >= 1 * 1024 * 1024 * 1024); + if (likelyToHaveExternalData) { + const model = onnx.ModelProto.decode(fs.readFileSync(path.join(testDataRootFolder, path.basename(modelUrl!)))); + const externalDataPathSet = new Set(); + for (const initializer of model.graph!.initializer!) { + if (initializer.externalData) { + for (const data of initializer.externalData) { + if (data.key === 'location') { + externalDataPathSet.add(data.value!); + } + } + } + } + externalData = []; + const externalDataPaths = [...externalDataPathSet]; + for (const dataPath of externalDataPaths) { + const fullPath = path.resolve(testDataRootFolder, dataPath); + const url = path.join(TEST_DATA_BASE, path.relative(TEST_ROOT, fullPath)); + externalData.push({data: url, path: dataPath}); + } + } } catch (e) { npmlog.error('TestRunnerCli.Init.Model', `Failed to prepare test data. 
Error: ${inspect(e)}`); throw e; @@ -340,9 +373,23 @@ async function main() { npmlog.verbose('TestRunnerCli.Init.Model', ` Model file: ${modelUrl}`); npmlog.verbose('TestRunnerCli.Init.Model', ` Backend: ${backend}`); npmlog.verbose('TestRunnerCli.Init.Model', ` Test set(s): ${cases.length} (${caseCount})`); + if (externalData) { + npmlog.verbose('TestRunnerCli.Init.Model', ` External data: ${externalData.length}`); + for (const data of externalData) { + npmlog.verbose('TestRunnerCli.Init.Model', ` - ${data.path}`); + } + } npmlog.verbose('TestRunnerCli.Init.Model', '==============================================================='); - return {name: path.basename(testDataRootFolder), platformCondition, modelUrl, backend, cases, ioBinding}; + return { + name: path.basename(testDataRootFolder), + platformCondition, + modelUrl, + backend, + cases, + ioBinding, + externalData + }; } function tryLocateModelTestFolder(searchPattern: string): string { @@ -495,14 +542,13 @@ async function main() { npmlog.info('TestRunnerCli.Run', '(4/4) Running karma to start test runner...'); const webgpu = args.backends.indexOf('webgpu') > -1; const webnn = args.backends.indexOf('webnn') > -1; - const browser = getBrowserNameFromEnv( - args.env, - args.bundleMode === 'perf' ? 'perf' : - args.debug ? 
'debug' : - 'test', - webgpu); + const browser = getBrowserNameFromEnv(args.env); const karmaArgs = ['karma', 'start', `--browsers ${browser}`]; const chromiumFlags = ['--enable-features=SharedArrayBuffer', ...args.chromiumFlags]; + if (args.bundleMode === 'dev' && !args.debug) { + // use headless for 'test' mode (when 'perf' and 'debug' are OFF) + chromiumFlags.push('--headless=new'); + } if (args.debug) { karmaArgs.push('--log-level info --timeout-mocha 9999999'); chromiumFlags.push('--remote-debugging-port=9333'); @@ -523,6 +569,9 @@ async function main() { if (webnn) { chromiumFlags.push('--enable-experimental-web-platform-features'); } + if (process.argv.includes('--karma-debug')) { + karmaArgs.push('--log-level debug'); + } karmaArgs.push(`--bundle-mode=${args.bundleMode}`); karmaArgs.push(...chromiumFlags.map(flag => `--chromium-flags=${flag}`)); if (browser.startsWith('Edge')) { @@ -615,10 +664,10 @@ async function main() { fs.writeJSONSync(path.join(TEST_ROOT, './testdata-config.json'), config); } - function getBrowserNameFromEnv(env: TestRunnerCliArgs['env'], mode: 'debug'|'perf'|'test', webgpu: boolean) { + function getBrowserNameFromEnv(env: TestRunnerCliArgs['env']) { switch (env) { case 'chrome': - return selectChromeBrowser(mode, webgpu); + return 'ChromeTest'; case 'edge': return 'EdgeTest'; case 'firefox': @@ -633,20 +682,6 @@ async function main() { throw new Error(`env "${env}" not supported.`); } } - - function selectChromeBrowser(mode: 'debug'|'perf'|'test', webgpu: boolean) { - if (webgpu) { - return 'ChromeTest'; - } else { - switch (mode) { - case 'debug': - case 'perf': - return 'ChromeTest'; - default: - return 'ChromeTestHeadless'; - } - } - } } void main(); diff --git a/js/web/test/data/ops/add_zero-sized.jsonc b/js/web/test/data/ops/add_zero-sized.jsonc new file mode 100644 index 0000000000000..37e08cd7f20ac --- /dev/null +++ b/js/web/test/data/ops/add_zero-sized.jsonc @@ -0,0 +1,31 @@ +[ + { + "name": "Add with no attributes", + 
"operator": "Add", + "attributes": [], + "cases": [ + { + "name": "T[2,0] T[2,1]", + "inputs": [ + { + "data": [], + "dims": [2, 0], + "type": "float32" + }, + { + "data": [1, 2], + "dims": [2, 1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [], + "dims": [2, 0], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/data/ops/concat_zero-sized.jsonc b/js/web/test/data/ops/concat_zero-sized.jsonc new file mode 100644 index 0000000000000..be9625145d157 --- /dev/null +++ b/js/web/test/data/ops/concat_zero-sized.jsonc @@ -0,0 +1,641 @@ +[ + { + "name": "Concat 2D axis=0", + "operator": "Concat", + "attributes": [{ "name": "axis", "data": -2, "type": "int" }], + "cases": [ + { + "name": "X", + "inputs": [ + { + "data": [], + "dims": [1, 4, 0, 64], + "type": "float32" + }, + { + "data": [ + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2 + ], + "dims": [1, 4, 36, 64], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2 + ], + "dims": [1, 4, 36, 64], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Concat 2D axis=1; Preserve dims", + "operator": "Concat", + "attributes": [ + { + "name": "axis", + "data": 0, + "type": "int" + } + ], + "cases": [ + { + "name": "Some but not all input tensors are zero-sized", + "inputs": [ + { + "data": [], + "dims": [0, 1], + "type": "float32" + }, + { + "data": [1], + "dims": [1, 1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1], + "dims": [1, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Concat 2D axis=1; Preserve dims", + "operator": "Concat", + "attributes": [ + { + "name": "axis", + "data": 1, + "type": "int" + } + ], + "cases": [ + { + "name": "All 
input tensors are zero-sized", + "inputs": [ + { + "data": [], + "dims": [0, 0], + "type": "float32" + }, + { + "data": [], + "dims": [0, 1], + "type": "float32" + }, + { + "data": [], + "dims": [0, 2], + "type": "float32" + }, + { + "data": [], + "dims": [0, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [], + "dims": [0, 6], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/data/ops/fast-gelu.jsonc b/js/web/test/data/ops/fast-gelu.jsonc new file mode 100644 index 0000000000000..2550173e95402 --- /dev/null +++ b/js/web/test/data/ops/fast-gelu.jsonc @@ -0,0 +1,211 @@ +[ + { + "name": "FastGelu test without bias", + "operator": "FastGelu", + "opset": { "domain": "com.microsoft", "version": 1 }, + "cases": [ + { + "name": "scalar", + "inputs": [ + { + "data": [1], + "dims": [], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0.841192], + "dims": [], + "type": "float32" + } + ] + }, + { + "name": "[2x4]", + "inputs": [ + { + "data": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], + "dims": [2, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0.0539828, 0.115851, 0.185371, 0.262161, 0.345714, 0.435415, 0.53057, 0.630432], + "dims": [2, 4], + "type": "float32" + } + ] + }, + { + "name": "[3x5]", + "inputs": [ + { + "data": [0.1, 0.2, 0.3, 0.4, 0.5, 1, 2, 3, 4, 5, 1.1, 1.2, 1.3, 1.4, 1.5], + "dims": [3, 5], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 0.0539828, 0.115851, 0.185371, 0.262161, 0.345714, 0.841192, 1.9546, 2.99636, 3.99993, 5, 0.950581, + 1.0617, 1.17393, 1.28671, 1.39957 + ], + "dims": [3, 5], + "type": "float32" + } + ] + } + ] + }, + { + "name": "FastGelu test with bias", + "operator": "FastGelu", + "opset": { "domain": "com.microsoft", "version": 1 }, + "cases": [ + { + "name": "scalar", + "inputs": [ + { + "data": [1], + "dims": [], + "type": "float32" + }, + { + "data": [0.5], + "dims": [], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1.39957], + "dims": [], + "type": 
"float32" + } + ] + }, + { + "name": "[2x4], [4]", + "inputs": [ + { + "data": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], + "dims": [2, 4], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0.950581, 2.16968, 3.29869, 4.39999, 1.39957, 2.58835, 3.69973, 4.8], + "dims": [2, 4], + "type": "float32" + } + ] + }, + { + "name": "[2x4], [3]", + "inputs": [ + { + "data": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], + "dims": [2, 4], + "type": "float32" + }, + { + "data": [1, 2, 3], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0.950581, 2.16968, 3.29869, 1.28671, 2.48492, 3.59959, 1.62411, 2.79331], + "dims": [2, 4], + "type": "float32" + } + ] + }, + { + "name": "[3x5], [2]", + "inputs": [ + { + "data": [0.1, 0.2, 0.3, 0.4, 0.5, 1, 2, 3, 4, 5, 1.1, 1.2, 1.3, 1.4, 1.5], + "dims": [3, 5], + "type": "float32" + }, + { + "data": [2, 3], + "dims": [2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 2.06267, 3.19813, 2.27567, 3.39909, 2.48492, 3.99993, 3.99993, 6, 6, 8, 3.09737, 4.19997, 3.29869, + 4.39999, 3.49938 + ], + "dims": [3, 5], + "type": "float32" + } + ] + }, + { + "name": "[3x5], [7]", + "inputs": [ + { + "data": [0.1, 0.2, 0.3, 0.4, 0.5, 1, 2, 3, 4, 5, 1.1, 1.2, 1.3, 1.4, 1.5], + "dims": [3, 5], + "type": "float32" + }, + { + "data": [2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7], + "dims": [7], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 2.16968, 2.38072, 2.58835, 2.79331, 2.99636, 3.59959, 4.7, 5.1, 6.2, 7.3, 3.49938, 3.69973, 3.89989, + 4.09996, 3.59959 + ], + "dims": [3, 5], + "type": "float32" + } + ] + }, + { + "name": "[4x4], [8]", + "inputs": [ + { + "data": [0.8, -0.5, 0.0, 1, 1.3, 2.1, -0.2, 1.1, 0.5, 0.2, 0.3, -0.6, 3.1, 2.2, -1.1, 0.0], + "dims": [4, 4], + "type": "float32" + }, + { + "data": [-0.5, 0.6, 1.2, 2.1, 1.3, -1, 0, 3.1], + "dims": [8], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 0.185371, 0.0539828, 
1.0617, 3.09737, 2.58835, 0.950581, -0.0841486, 4.19997, 0, 0.630432, 1.39957, + 1.39957, 4.39999, 1.0617, -0.149419, 3.09737 + ], + "dims": [4, 4], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/data/ops/fused-conv.jsonc b/js/web/test/data/ops/fused-conv.jsonc index ad1c0a72c11d3..6a10e3b96a26a 100644 --- a/js/web/test/data/ops/fused-conv.jsonc +++ b/js/web/test/data/ops/fused-conv.jsonc @@ -142,5 +142,293 @@ ] } ] + }, + { + "name": "fused conv with HardSigmoid", + "operator": "FusedConv", + "attributes": [ + { "name": "activation", "data": "HardSigmoid", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "activation_params", "data": [2.0, 5.0], "type": "floats" } + ], + "opset": { "domain": "com.microsoft", "version": 1 }, + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [10, 20, -30, -40, -50, -60, 70, 80, 90], + "dims": [1, 1, 3, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0, 0, 1, 1], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "NHWC conv with HardSigmoid", + "operator": "Conv", + "attributes": [ + { "name": "activation", "data": "HardSigmoid", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "activation_params", "data": [2.0, 5.0], "type": "floats" } + ], + "opset": { "domain": "com.ms.internal.nhwc", "version": 1 }, + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [10, 20, -30, -40, -50, -60, 70, 80, 90], + "dims": [1, 3, 3, 1], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0, 0, 1, 1], + "dims": [1, 2, 2, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "fused group-conv with HardSigmoid", + "operator": "FusedConv", + "attributes": [ + { "name": "activation", "data": "HardSigmoid", "type": 
"string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "group", "data": 3, "type": "int" }, + { "name": "activation_params", "data": [2.0, 5.0], "type": "floats" } + ], + "opset": { "domain": "com.microsoft", "version": 1 }, + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.0, 1.0, 2.0, -3.0, 4.0, -5.0, 6.0, 7.0, 8.0, -9.0, -10.0, 11.0, -12.0, 13.0, -14.0, 15.0, 16.0, 17.0, + 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0 + ], + "dims": [1, 3, 3, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + "dims": [3, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1], + "dims": [1, 3, 2, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "NHWC group-conv with HardSigmoid", + "operator": "Conv", + "attributes": [ + { "name": "activation", "data": "HardSigmoid", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "group", "data": 3, "type": "int" }, + { "name": "activation_params", "data": [2.0, 5.0], "type": "floats" } + ], + "opset": { "domain": "com.ms.internal.nhwc", "version": 1 }, + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.0, 1.0, 2.0, -3.0, 4.0, -5.0, 6.0, 7.0, 8.0, -9.0, -10.0, 11.0, -12.0, 13.0, -14.0, 15.0, 16.0, 17.0, + 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0 + ], + "dims": [1, 3, 3, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + "dims": [3, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1], + "dims": [1, 2, 2, 3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "fused group-conv with LeakyRelu", + "operator": "FusedConv", + "attributes": [ + { "name": "activation", "data": "LeakyRelu", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "group", "data": 3, "type": "int" }, + { "name": 
"activation_params", "data": [2.0], "type": "floats" } + ], + "opset": { "domain": "com.microsoft", "version": 1 }, + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.0, 1.0, 2.0, -3.0, 4.0, -5.0, 6.0, 7.0, 8.0, -9.0, -10.0, 11.0, -12.0, 13.0, -14.0, 15.0, 16.0, 17.0, + 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0 + ], + "dims": [1, 3, 3, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + "dims": [3, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [9, -6, 51, 47, -170, -10, 251, 229, 847, 889, 973, 1015], + "dims": [1, 3, 2, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "NHWC group-conv with LeakyRelu", + "operator": "Conv", + "attributes": [ + { "name": "activation", "data": "LeakyRelu", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "group", "data": 3, "type": "int" }, + { "name": "activation_params", "data": [2.0], "type": "floats" } + ], + "opset": { "domain": "com.ms.internal.nhwc", "version": 1 }, + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.0, 1.0, 2.0, -3.0, 4.0, -5.0, 6.0, 7.0, 8.0, -9.0, -10.0, 11.0, -12.0, 13.0, -14.0, 15.0, 16.0, 17.0, + 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0 + ], + "dims": [1, 3, 3, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + "dims": [3, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [-162, 63, -158, 33, 281, 85, 105, 337, 455, 177, 515, 609], + "dims": [1, 2, 2, 3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "fused conv with LeakyRelu", + "operator": "FusedConv", + "attributes": [ + { "name": "activation", "data": "LeakyRelu", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "activation_params", "data": [2.0], "type": "floats" } + ], + "opset": { "domain": "com.microsoft", "version": 1 }, + "cases": [ + { + "name": "T[0]", + "inputs": [ + { 
+ "data": [10, 20, -30, -40, -50, -60, 70, 80, 90], + "dims": [1, 1, 3, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [-540, -860, 390, 430], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "NHWC conv with LeakyRelu", + "operator": "Conv", + "attributes": [ + { "name": "activation", "data": "LeakyRelu", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "activation_params", "data": [2.0], "type": "floats" } + ], + "opset": { "domain": "com.ms.internal.nhwc", "version": 1 }, + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [10, 20, -30, -40, -50, -60, 70, 80, 90], + "dims": [1, 3, 3, 1], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [-540, -860, 390, 430], + "dims": [1, 2, 2, 1], + "type": "float32" + } + ] + } + ] } ] diff --git a/js/web/test/data/ops/matmulnbits.jsonc b/js/web/test/data/ops/matmulnbits.jsonc new file mode 100644 index 0000000000000..175be78cc0818 --- /dev/null +++ b/js/web/test/data/ops/matmulnbits.jsonc @@ -0,0 +1,1584 @@ +[ + { + "name": "MatMulNBits; K=16, N=16, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 16, "type": "int" }, + { "name": "N", "data": 8, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=16, N=16, block_size=16, bits=4; symmetric", + "inputs": [ + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 
73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, + 127 + ], + "dims": [8, 16], + "type": "float32" + }, + { + "dims": [8, 1, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64 + ] + }, + { + "dims": [8], + "type": "float32", + "data": [0, 1, 2, 3, 4, 5, 6, 7] + } + ], + "outputs": [ + { + "dims": [8, 8], + "type": "float32", + "data": [ + 0, -385, -1120, -963, -1984, -1285, -2592, -1351, 0, -1073, -3808, -2643, -6848, -3445, -9120, -3479, 0, + -1761, -6496, -4323, -11712, -5605, -15648, -5607, 0, -2449, -9184, -6003, -16576, -7765, -22176, -7735, + 0, -3137, -11872, -7683, -21440, -9925, -28704, -9863, 0, -3825, -14560, -9363, -26304, -12085, -35232, + -11991, 0, -4513, -17248, -11043, -31168, -14245, -41760, -14119, 0, -5201, -19936, -12723, -36032, + -16405, -48288, -16247 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=16, N=16, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 16, "type": "int" }, + { "name": "N", "data": 16, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=16, N=16, block_size=16, bits=4; symmetric", + "inputs": [ + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 
68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, + 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, + 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, + 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, + 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, + 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, + 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, + 253, 254, 255 + ], + "dims": [16, 16], + "type": "float32" + }, + { + "dims": [16, 1, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128 + ] + }, + { + "dims": [16], + "type": "float32", + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + } + ], + "outputs": [ + { + "dims": [16, 16], + "type": "float32", + "data": [ + 0, -385, -1120, -963, -1984, -1285, -2592, -1351, -2944, -1161, -3040, -715, -2880, -13, -2464, 945, 0, + -1073, -3808, -2643, -6848, -3445, -9120, -3479, -10624, -2745, -11360, -1243, -11328, 1027, -10528, 4065, + 0, -1761, 
-6496, -4323, -11712, -5605, -15648, -5607, -18304, -4329, -19680, -1771, -19776, 2067, -18592, + 7185, 0, -2449, -9184, -6003, -16576, -7765, -22176, -7735, -25984, -5913, -28000, -2299, -28224, 3107, + -26656, 10305, 0, -3137, -11872, -7683, -21440, -9925, -28704, -9863, -33664, -7497, -36320, -2827, + -36672, 4147, -34720, 13425, 0, -3825, -14560, -9363, -26304, -12085, -35232, -11991, -41344, -9081, + -44640, -3355, -45120, 5187, -42784, 16545, 0, -4513, -17248, -11043, -31168, -14245, -41760, -14119, + -49024, -10665, -52960, -3883, -53568, 6227, -50848, 19665, 0, -5201, -19936, -12723, -36032, -16405, + -48288, -16247, -56704, -12249, -61280, -4411, -62016, 7267, -58912, 22785, 0, -5889, -22624, -14403, + -40896, -18565, -54816, -18375, -64384, -13833, -69600, -4939, -70464, 8307, -66976, 25905, 0, -6577, + -25312, -16083, -45760, -20725, -61344, -20503, -72064, -15417, -77920, -5467, -78912, 9347, -75040, + 29025, 0, -7265, -28000, -17763, -50624, -22885, -67872, -22631, -79744, -17001, -86240, -5995, -87360, + 10387, -83104, 32145, 0, -7953, -30688, -19443, -55488, -25045, -74400, -24759, -87424, -18585, -94560, + -6523, -95808, 11427, -91168, 35265, 0, -8641, -33376, -21123, -60352, -27205, -80928, -26887, -95104, + -20169, -102880, -7051, -104256, 12467, -99232, 38385, 0, -9329, -36064, -22803, -65216, -29365, -87456, + -29015, -102784, -21753, -111200, -7579, -112704, 13507, -107296, 41505, 0, -10017, -38752, -24483, + -70080, -31525, -93984, -31143, -110464, -23337, -119520, -8107, -121152, 14547, -115360, 44625, 0, + -10705, -41440, -26163, -74944, -33685, -100512, -33271, -118144, -24921, -127840, -8635, -129600, 15587, + -123424, 47745 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=16, N=16, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 16, "type": "int" }, + { "name": "N", "data": 16, "type": "int" }, + { "name": "block_size", "data": 
16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=16, N=16, block_size=16, bits=4; asymmetric", + "inputs": [ + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, + 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, + 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, + 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, + 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, + 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, + 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, + 253, 254, 255 + ], + "dims": [16, 16], + "type": "float32" + }, + { + "dims": [16, 1, 8], + "type": "uint8", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 
121, 122, 123, 124, 125, 126, + 127 + ] + }, + { + "dims": [16], + "type": "float32", + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + }, + { + "dims": [16], + "type": "uint8", + "data": [128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128] + } + ], + "outputs": [ + { + "dims": [16, 16], + "type": "float32", + "data": [ + 0, 728, 688, 2376, 1632, 4280, 2832, 6440, 4288, 8856, 6000, 11528, 7968, 14456, 10192, 17640, 0, 2200, + 1840, 7176, 4448, 12920, 7824, 19432, 11968, 26712, 16880, 34760, 22560, 43576, 29008, 53160, 0, 3672, + 2992, 11976, 7264, 21560, 12816, 32424, 19648, 44568, 27760, 57992, 37152, 72696, 47824, 88680, 0, 5144, + 4144, 16776, 10080, 30200, 17808, 45416, 27328, 62424, 38640, 81224, 51744, 101816, 66640, 124200, 0, + 6616, 5296, 21576, 12896, 38840, 22800, 58408, 35008, 80280, 49520, 104456, 66336, 130936, 85456, 159720, + 0, 8088, 6448, 26376, 15712, 47480, 27792, 71400, 42688, 98136, 60400, 127688, 80928, 160056, 104272, + 195240, 0, 9560, 7600, 31176, 18528, 56120, 32784, 84392, 50368, 115992, 71280, 150920, 95520, 189176, + 123088, 230760, 0, 11032, 8752, 35976, 21344, 64760, 37776, 97384, 58048, 133848, 82160, 174152, 110112, + 218296, 141904, 266280, 0, 12504, 9904, 40776, 24160, 73400, 42768, 110376, 65728, 151704, 93040, 197384, + 124704, 247416, 160720, 301800, 0, 13976, 11056, 45576, 26976, 82040, 47760, 123368, 73408, 169560, + 103920, 220616, 139296, 276536, 179536, 337320, 0, 15448, 12208, 50376, 29792, 90680, 52752, 136360, + 81088, 187416, 114800, 243848, 153888, 305656, 198352, 372840, 0, 16920, 13360, 55176, 32608, 99320, + 57744, 149352, 88768, 205272, 125680, 267080, 168480, 334776, 217168, 408360, 0, 18392, 14512, 59976, + 35424, 107960, 62736, 162344, 96448, 223128, 136560, 290312, 183072, 363896, 235984, 443880, 0, 19864, + 15664, 64776, 38240, 116600, 67728, 175336, 104128, 240984, 147440, 313544, 197664, 393016, 254800, + 479400, 0, 21336, 16816, 69576, 41056, 125240, 72720, 
188328, 111808, 258840, 158320, 336776, 212256, + 422136, 273616, 514920, 0, 22808, 17968, 74376, 43872, 133880, 77712, 201320, 119488, 276696, 169200, + 360008, 226848, 451256, 292432, 550440 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=16, N=32, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 16, "type": "int" }, + { "name": "N", "data": 32, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=16, N=32, block_size=16, bits=4; symmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 
277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ], + "dims": [32, 16], + "type": "float32" + }, + { + "dims": [32, 1, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 
158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256 + ] + }, + { + "dims": [32], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 + ] + } + ], + "outputs": [ + { + "dims": [32, 32], + "type": "float32", + "data": [ + 0, -428, -1288, -1068, -2288, -1420, -3000, -1484, -3424, -1260, -3560, -748, -3408, 52, -2968, 1140, + -2272, 2516, -1224, 4180, 80, 6132, 1672, 8372, 3552, 10900, 5720, 13716, 8176, 16820, 10920, 12276, 0, + -1116, -3976, -2748, -7152, -3580, -9528, -3612, -11104, -2844, -11880, -1276, -11856, 1092, -11032, 4260, + -8160, 8228, -6984, 12996, -3760, 18564, 264, 24932, 5088, 32100, 10712, 40068, 17136, 48836, 24360, + 42532, 0, -1804, -6664, -4428, -12016, -5740, -16056, -5740, -18784, -4428, -20200, -1804, -20304, 2132, + -19096, 7380, -14048, 13940, -12744, 21812, -7600, 30996, -1144, 41492, 6624, 53300, 15704, 66420, 26096, + 80852, 37800, 72788, 0, -2492, -9352, -6108, -16880, -7900, -22584, -7868, -26464, -6012, -28520, -2332, + -28752, 3172, -27160, 10500, -19936, 19652, -18504, 30628, -11440, 43428, -2552, 58052, 8160, 74500, + 20696, 92772, 35056, 112868, 51240, 103044, 0, -3180, -12040, -7788, -21744, -10060, -29112, -9996, + -34144, -7596, -36840, -2860, -37200, 4212, -35224, 13620, -25824, 25364, -24264, 39444, -15280, 55860, + -3960, 74612, 9696, 95700, 25688, 119124, 44016, 144884, 64680, 133300, 0, -3868, -14728, -9468, -26608, + -12220, -35640, -12124, -41824, -9180, -45160, -3388, -45648, 
5252, -43288, 16740, -31712, 31076, -30024, + 48260, -19120, 68292, -5368, 91172, 11232, 116900, 30680, 145476, 52976, 176900, 78120, 163556, 0, -4556, + -17416, -11148, -31472, -14380, -42168, -14252, -49504, -10764, -53480, -3916, -54096, 6292, -51352, + 19860, -37600, 36788, -35784, 57076, -22960, 80724, -6776, 107732, 12768, 138100, 35672, 171828, 61936, + 208916, 91560, 193812, 0, -5244, -20104, -12828, -36336, -16540, -48696, -16380, -57184, -12348, -61800, + -4444, -62544, 7332, -59416, 22980, -43488, 42500, -41544, 65892, -26800, 93156, -8184, 124292, 14304, + 159300, 40664, 198180, 70896, 240932, 105000, 224068, 0, -5932, -22792, -14508, -41200, -18700, -55224, + -18508, -64864, -13932, -70120, -4972, -70992, 8372, -67480, 26100, -49376, 48212, -47304, 74708, -30640, + 105588, -9592, 140852, 15840, 180500, 45656, 224532, 79856, 272948, 118440, 254324, 0, -6620, -25480, + -16188, -46064, -20860, -61752, -20636, -72544, -15516, -78440, -5500, -79440, 9412, -75544, 29220, + -55264, 53924, -53064, 83524, -34480, 118020, -11000, 157412, 17376, 201700, 50648, 250884, 88816, 304964, + 131880, 284580, 0, -7308, -28168, -17868, -50928, -23020, -68280, -22764, -80224, -17100, -86760, -6028, + -87888, 10452, -83608, 32340, -61152, 59636, -58824, 92340, -38320, 130452, -12408, 173972, 18912, 222900, + 55640, 277236, 97776, 336980, 145320, 314836, 0, -7996, -30856, -19548, -55792, -25180, -74808, -24892, + -87904, -18684, -95080, -6556, -96336, 11492, -91672, 35460, -67040, 65348, -64584, 101156, -42160, + 142884, -13816, 190532, 20448, 244100, 60632, 303588, 106736, 368996, 158760, 345092, 0, -8684, -33544, + -21228, -60656, -27340, -81336, -27020, -95584, -20268, -103400, -7084, -104784, 12532, -99736, 38580, + -72928, 71060, -70344, 109972, -46000, 155316, -15224, 207092, 21984, 265300, 65624, 329940, 115696, + 401012, 172200, 375348, 0, -9372, -36232, -22908, -65520, -29500, -87864, -29148, -103264, -21852, + -111720, -7612, -113232, 13572, -107800, 41700, -78816, 
76772, -76104, 118788, -49840, 167748, -16632, + 223652, 23520, 286500, 70616, 356292, 124656, 433028, 185640, 405604, 0, -10060, -38920, -24588, -70384, + -31660, -94392, -31276, -110944, -23436, -120040, -8140, -121680, 14612, -115864, 44820, -84704, 82484, + -81864, 127604, -53680, 180180, -18040, 240212, 25056, 307700, 75608, 382644, 133616, 465044, 199080, + 435860, 0, -10748, -41608, -26268, -75248, -33820, -100920, -33404, -118624, -25020, -128360, -8668, + -130128, 15652, -123928, 47940, -90592, 88196, -87624, 136420, -57520, 192612, -19448, 256772, 26592, + 328900, 80600, 408996, 142576, 497060, 212520, 466116, 0, -11436, -44296, -27948, -80112, -35980, -107448, + -35532, -126304, -26604, -136680, -9196, -138576, 16692, -131992, 51060, -96480, 93908, -93384, 145236, + -61360, 205044, -20856, 273332, 28128, 350100, 85592, 435348, 151536, 529076, 225960, 496372, 0, -12124, + -46984, -29628, -84976, -38140, -113976, -37660, -133984, -28188, -145000, -9724, -147024, 17732, -140056, + 54180, -102368, 99620, -99144, 154052, -65200, 217476, -22264, 289892, 29664, 371300, 90584, 461700, + 160496, 561092, 239400, 526628, 0, -12812, -49672, -31308, -89840, -40300, -120504, -39788, -141664, + -29772, -153320, -10252, -155472, 18772, -148120, 57300, -108256, 105332, -104904, 162868, -69040, 229908, + -23672, 306452, 31200, 392500, 95576, 488052, 169456, 593108, 252840, 556884, 0, -13500, -52360, -32988, + -94704, -42460, -127032, -41916, -149344, -31356, -161640, -10780, -163920, 19812, -156184, 60420, + -114144, 111044, -110664, 171684, -72880, 242340, -25080, 323012, 32736, 413700, 100568, 514404, 178416, + 625124, 266280, 587140, 0, -14188, -55048, -34668, -99568, -44620, -133560, -44044, -157024, -32940, + -169960, -11308, -172368, 20852, -164248, 63540, -120032, 116756, -116424, 180500, -76720, 254772, -26488, + 339572, 34272, 434900, 105560, 540756, 187376, 657140, 279720, 617396, 0, -14876, -57736, -36348, -104432, + -46780, -140088, -46172, -164704, -34524, 
-178280, -11836, -180816, 21892, -172312, 66660, -125920, + 122468, -122184, 189316, -80560, 267204, -27896, 356132, 35808, 456100, 110552, 567108, 196336, 689156, + 293160, 647652, 0, -15564, -60424, -38028, -109296, -48940, -146616, -48300, -172384, -36108, -186600, + -12364, -189264, 22932, -180376, 69780, -131808, 128180, -127944, 198132, -84400, 279636, -29304, 372692, + 37344, 477300, 115544, 593460, 205296, 721172, 306600, 677908, 0, -16252, -63112, -39708, -114160, -51100, + -153144, -50428, -180064, -37692, -194920, -12892, -197712, 23972, -188440, 72900, -137696, 133892, + -133704, 206948, -88240, 292068, -30712, 389252, 38880, 498500, 120536, 619812, 214256, 753188, 320040, + 708164, 0, -16940, -65800, -41388, -119024, -53260, -159672, -52556, -187744, -39276, -203240, -13420, + -206160, 25012, -196504, 76020, -143584, 139604, -139464, 215764, -92080, 304500, -32120, 405812, 40416, + 519700, 125528, 646164, 223216, 785204, 333480, 738420, 0, -17628, -68488, -43068, -123888, -55420, + -166200, -54684, -195424, -40860, -211560, -13948, -214608, 26052, -204568, 79140, -149472, 145316, + -145224, 224580, -95920, 316932, -33528, 422372, 41952, 540900, 130520, 672516, 232176, 817220, 346920, + 768676, 0, -18316, -71176, -44748, -128752, -57580, -172728, -56812, -203104, -42444, -219880, -14476, + -223056, 27092, -212632, 82260, -155360, 151028, -150984, 233396, -99760, 329364, -34936, 438932, 43488, + 562100, 135512, 698868, 241136, 849236, 360360, 798932, 0, -19004, -73864, -46428, -133616, -59740, + -179256, -58940, -210784, -44028, -228200, -15004, -231504, 28132, -220696, 85380, -161248, 156740, + -156744, 242212, -103600, 341796, -36344, 455492, 45024, 583300, 140504, 725220, 250096, 881252, 373800, + 829188, 0, -19692, -76552, -48108, -138480, -61900, -185784, -61068, -218464, -45612, -236520, -15532, + -239952, 29172, -228760, 88500, -167136, 162452, -162504, 251028, -107440, 354228, -37752, 472052, 46560, + 604500, 145496, 751572, 259056, 913268, 
387240, 859444, 0, -20380, -79240, -49788, -143344, -64060, + -192312, -63196, -226144, -47196, -244840, -16060, -248400, 30212, -236824, 91620, -173024, 168164, + -168264, 259844, -111280, 366660, -39160, 488612, 48096, 625700, 150488, 777924, 268016, 945284, 400680, + 889700, 0, -21068, -81928, -51468, -148208, -66220, -198840, -65324, -233824, -48780, -253160, -16588, + -256848, 31252, -244888, 94740, -178912, 173876, -174024, 268660, -115120, 379092, -40568, 505172, 49632, + 646900, 155480, 804276, 276976, 977300, 414120, 919956, 0, -21756, -84616, -53148, -153072, -68380, + -205368, -67452, -241504, -50364, -261480, -17116, -265296, 32292, -252952, 97860, -184800, 179588, + -179784, 277476, -118960, 391524, -41976, 521732, 51168, 668100, 160472, 830628, 285936, 1009316, 427560, + 950212 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=16, N=32, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 16, "type": "int" }, + { "name": "N", "data": 32, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=16, N=32, block_size=16, bits=4; asymmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 
157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ], + "dims": [32, 16], + "type": "float32" + }, + { + "dims": [32, 1, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 
22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256 + ] + }, + { + "dims": [32], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 + ] + }, + { + "dims": [32], + "type": "uint8", + "data": [ + 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, + 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128 + ] + } + ], + "outputs": [ + { + "dims": [32, 32], + "type": "float32", + "data": [ + 0, 660, 888, 2196, 2064, 4020, 3528, 6132, 5280, 8532, 7320, 11220, 9648, 14196, 12264, 17460, 15136, + 21012, 18360, 24852, 21840, 28980, 25608, 33396, 29664, 38100, 34008, 43092, 38640, 48372, 43560, 46004, + 0, 2020, 2296, 6660, 5392, 12100, 9288, 18340, 13984, 25380, 19480, 33220, 25776, 41860, 32872, 51300, + 42016, 61540, 49464, 72580, 58960, 84420, 69256, 97060, 80352, 110500, 
92248, 124740, 104944, 139780, + 118440, 139748, 0, 3380, 3704, 11124, 8720, 20180, 15048, 30548, 22688, 42228, 31640, 55220, 41904, 69524, + 53480, 85140, 68896, 102068, 80568, 120308, 96080, 139860, 112904, 160724, 131040, 182900, 150488, 206388, + 171248, 231188, 193320, 233492, 0, 4740, 5112, 15588, 12048, 28260, 20808, 42756, 31392, 59076, 43800, + 77220, 58032, 97188, 74088, 118980, 95776, 142596, 111672, 168036, 133200, 195300, 156552, 224388, 181728, + 255300, 208728, 288036, 237552, 322596, 268200, 327236, 0, 6100, 6520, 20052, 15376, 36340, 26568, 54964, + 40096, 75924, 55960, 99220, 74160, 124852, 94696, 152820, 122656, 183124, 142776, 215764, 170320, 250740, + 200200, 288052, 232416, 327700, 266968, 369684, 303856, 414004, 343080, 420980, 0, 7460, 7928, 24516, + 18704, 44420, 32328, 67172, 48800, 92772, 68120, 121220, 90288, 152516, 115304, 186660, 149536, 223652, + 173880, 263492, 207440, 306180, 243848, 351716, 283104, 400100, 325208, 451332, 370160, 505412, 417960, + 514724, 0, 8820, 9336, 28980, 22032, 52500, 38088, 79380, 57504, 109620, 80280, 143220, 106416, 180180, + 135912, 220500, 176416, 264180, 204984, 311220, 244560, 361620, 287496, 415380, 333792, 472500, 383448, + 532980, 436464, 596820, 492840, 608468, 0, 10180, 10744, 33444, 25360, 60580, 43848, 91588, 66208, 126468, + 92440, 165220, 122544, 207844, 156520, 254340, 203296, 304708, 236088, 358948, 281680, 417060, 331144, + 479044, 384480, 544900, 441688, 614628, 502768, 688228, 567720, 702212, 0, 11540, 12152, 37908, 28688, + 68660, 49608, 103796, 74912, 143316, 104600, 187220, 138672, 235508, 177128, 288180, 230176, 345236, + 267192, 406676, 318800, 472500, 374792, 542708, 435168, 617300, 499928, 696276, 569072, 779636, 642600, + 795956, 0, 12900, 13560, 42372, 32016, 76740, 55368, 116004, 83616, 160164, 116760, 209220, 154800, + 263172, 197736, 322020, 257056, 385764, 298296, 454404, 355920, 527940, 418440, 606372, 485856, 689700, + 558168, 777924, 635376, 871044, 717480, 889700, 0, 
14260, 14968, 46836, 35344, 84820, 61128, 128212, + 92320, 177012, 128920, 231220, 170928, 290836, 218344, 355860, 283936, 426292, 329400, 502132, 393040, + 583380, 462088, 670036, 536544, 762100, 616408, 859572, 701680, 962452, 792360, 983444, 0, 15620, 16376, + 51300, 38672, 92900, 66888, 140420, 101024, 193860, 141080, 253220, 187056, 318500, 238952, 389700, + 310816, 466820, 360504, 549860, 430160, 638820, 505736, 733700, 587232, 834500, 674648, 941220, 767984, + 1053860, 867240, 1077188, 0, 16980, 17784, 55764, 42000, 100980, 72648, 152628, 109728, 210708, 153240, + 275220, 203184, 346164, 259560, 423540, 337696, 507348, 391608, 597588, 467280, 694260, 549384, 797364, + 637920, 906900, 732888, 1022868, 834288, 1145268, 942120, 1170932, 0, 18340, 19192, 60228, 45328, 109060, + 78408, 164836, 118432, 227556, 165400, 297220, 219312, 373828, 280168, 457380, 364576, 547876, 422712, + 645316, 504400, 749700, 593032, 861028, 688608, 979300, 791128, 1104516, 900592, 1236676, 1017000, + 1264676, 0, 19700, 20600, 64692, 48656, 117140, 84168, 177044, 127136, 244404, 177560, 319220, 235440, + 401492, 300776, 491220, 391456, 588404, 453816, 693044, 541520, 805140, 636680, 924692, 739296, 1051700, + 849368, 1186164, 966896, 1328084, 1091880, 1358420, 0, 21060, 22008, 69156, 51984, 125220, 89928, 189252, + 135840, 261252, 189720, 341220, 251568, 429156, 321384, 525060, 418336, 628932, 484920, 740772, 578640, + 860580, 680328, 988356, 789984, 1124100, 907608, 1267812, 1033200, 1419492, 1166760, 1452164, 0, 22420, + 23416, 73620, 55312, 133300, 95688, 201460, 144544, 278100, 201880, 363220, 267696, 456820, 341992, + 558900, 445216, 669460, 516024, 788500, 615760, 916020, 723976, 1052020, 840672, 1196500, 965848, 1349460, + 1099504, 1510900, 1241640, 1545908, 0, 23780, 24824, 78084, 58640, 141380, 101448, 213668, 153248, 294948, + 214040, 385220, 283824, 484484, 362600, 592740, 472096, 709988, 547128, 836228, 652880, 971460, 767624, + 1115684, 891360, 1268900, 1024088, 1431108, 
1165808, 1602308, 1316520, 1639652, 0, 25140, 26232, 82548, + 61968, 149460, 107208, 225876, 161952, 311796, 226200, 407220, 299952, 512148, 383208, 626580, 498976, + 750516, 578232, 883956, 690000, 1026900, 811272, 1179348, 942048, 1341300, 1082328, 1512756, 1232112, + 1693716, 1391400, 1733396, 0, 26500, 27640, 87012, 65296, 157540, 112968, 238084, 170656, 328644, 238360, + 429220, 316080, 539812, 403816, 660420, 525856, 791044, 609336, 931684, 727120, 1082340, 854920, 1243012, + 992736, 1413700, 1140568, 1594404, 1298416, 1785124, 1466280, 1827140, 0, 27860, 29048, 91476, 68624, + 165620, 118728, 250292, 179360, 345492, 250520, 451220, 332208, 567476, 424424, 694260, 552736, 831572, + 640440, 979412, 764240, 1137780, 898568, 1306676, 1043424, 1486100, 1198808, 1676052, 1364720, 1876532, + 1541160, 1920884, 0, 29220, 30456, 95940, 71952, 173700, 124488, 262500, 188064, 362340, 262680, 473220, + 348336, 595140, 445032, 728100, 579616, 872100, 671544, 1027140, 801360, 1193220, 942216, 1370340, + 1094112, 1558500, 1257048, 1757700, 1431024, 1967940, 1616040, 2014628, 0, 30580, 31864, 100404, 75280, + 181780, 130248, 274708, 196768, 379188, 274840, 495220, 364464, 622804, 465640, 761940, 606496, 912628, + 702648, 1074868, 838480, 1248660, 985864, 1434004, 1144800, 1630900, 1315288, 1839348, 1497328, 2059348, + 1690920, 2108372, 0, 31940, 33272, 104868, 78608, 189860, 136008, 286916, 205472, 396036, 287000, 517220, + 380592, 650468, 486248, 795780, 633376, 953156, 733752, 1122596, 875600, 1304100, 1029512, 1497668, + 1195488, 1703300, 1373528, 1920996, 1563632, 2150756, 1765800, 2202116, 0, 33300, 34680, 109332, 81936, + 197940, 141768, 299124, 214176, 412884, 299160, 539220, 396720, 678132, 506856, 829620, 660256, 993684, + 764856, 1170324, 912720, 1359540, 1073160, 1561332, 1246176, 1775700, 1431768, 2002644, 1629936, 2242164, + 1840680, 2295860, 0, 34660, 36088, 113796, 85264, 206020, 147528, 311332, 222880, 429732, 311320, 561220, + 412848, 705796, 527464, 863460, 
687136, 1034212, 795960, 1218052, 949840, 1414980, 1116808, 1624996, + 1296864, 1848100, 1490008, 2084292, 1696240, 2333572, 1915560, 2389604, 0, 36020, 37496, 118260, 88592, + 214100, 153288, 323540, 231584, 446580, 323480, 583220, 428976, 733460, 548072, 897300, 714016, 1074740, + 827064, 1265780, 986960, 1470420, 1160456, 1688660, 1347552, 1920500, 1548248, 2165940, 1762544, 2424980, + 1990440, 2483348, 0, 37380, 38904, 122724, 91920, 222180, 159048, 335748, 240288, 463428, 335640, 605220, + 445104, 761124, 568680, 931140, 740896, 1115268, 858168, 1313508, 1024080, 1525860, 1204104, 1752324, + 1398240, 1992900, 1606488, 2247588, 1828848, 2516388, 2065320, 2577092, 0, 38740, 40312, 127188, 95248, + 230260, 164808, 347956, 248992, 480276, 347800, 627220, 461232, 788788, 589288, 964980, 767776, 1155796, + 889272, 1361236, 1061200, 1581300, 1247752, 1815988, 1448928, 2065300, 1664728, 2329236, 1895152, 2607796, + 2140200, 2670836, 0, 40100, 41720, 131652, 98576, 238340, 170568, 360164, 257696, 497124, 359960, 649220, + 477360, 816452, 609896, 998820, 794656, 1196324, 920376, 1408964, 1098320, 1636740, 1291400, 1879652, + 1499616, 2137700, 1722968, 2410884, 1961456, 2699204, 2215080, 2764580, 0, 41460, 43128, 136116, 101904, + 246420, 176328, 372372, 266400, 513972, 372120, 671220, 493488, 844116, 630504, 1032660, 821536, 1236852, + 951480, 1456692, 1135440, 1692180, 1335048, 1943316, 1550304, 2210100, 1781208, 2492532, 2027760, 2790612, + 2289960, 2858324, 0, 42820, 44536, 140580, 105232, 254500, 182088, 384580, 275104, 530820, 384280, 693220, + 509616, 871780, 651112, 1066500, 848416, 1277380, 982584, 1504420, 1172560, 1747620, 1378696, 2006980, + 1600992, 2282500, 1839448, 2574180, 2094064, 2882020, 2364840, 2952068 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=32, N=16, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 32, "type": "int" }, + { "name": 
"N", "data": 16, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=32, N=16, block_size=16, bits=4; symmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 
364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ], + "dims": [16, 32], + "type": "float32" + }, + { + "dims": [16, 2, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 
245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256 + ] + }, + { + "dims": [32], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 + ] + } + ], + "outputs": [ + { + "dims": [16, 16], + "type": "float32", + "data": [ + -1116, -4036, -5868, -6612, -6268, -4836, -2316, 1292, 5956, 11772, 18644, 26604, 35652, 45788, 57012, + 53452, -2492, -12772, -19916, -23924, -24796, -22532, -17132, -8596, 5604, 17884, 35828, 56908, 81124, + 108476, 138964, 140844, -3868, -21508, -33964, -41236, -43324, -40228, -31948, -18484, 5252, 23996, 53012, + 87212, 126596, 171164, 220916, 228236, -5244, -30244, -48012, -58548, -61852, -57924, -46764, -28372, + 4900, 30108, 70196, 117516, 172068, 233852, 302868, 315628, -6620, -38980, -62060, -75860, -80380, -75620, + -61580, -38260, 4548, 36220, 87380, 147820, 217540, 296540, 384820, 403020, -7996, -47716, -76108, -93172, + -98908, -93316, -76396, -48148, 4196, 42332, 104564, 178124, 263012, 359228, 466772, 490412, -9372, + -56452, -90156, -110484, -117436, -111012, -91212, -58036, 3844, 48444, 121748, 208428, 308484, 421916, + 548724, 577804, -10748, -65188, -104204, -127796, -135964, -128708, -106028, -67924, 3492, 54556, 138932, + 238732, 353956, 484604, 630676, 665196, -12124, -73924, -118252, -145108, -154492, -146404, -120844, + -77812, 3140, 60668, 156116, 269036, 399428, 547292, 712628, 752588, -13500, -82660, -132300, -162420, + -173020, -164100, -135660, -87700, 2788, 66780, 173300, 299340, 444900, 609980, 794580, 839980, -14876, + -91396, -146348, -179732, -191548, -181796, -150476, -97588, 2436, 72892, 190484, 329644, 490372, 672668, + 876532, 927372, -16252, -100132, -160396, -197044, -210076, -199492, -165292, -107476, 2084, 79004, + 207668, 359948, 535844, 735356, 958484, 1014764, -17628, -108868, -174444, -214356, -228604, -217188, + -180108, -117364, 1732, 85116, 224852, 390252, 581316, 798044, 1040436, 1102156, 
-19004, -117604, -188492, + -231668, -247132, -234884, -194924, -127252, 1380, 91228, 242036, 420556, 626788, 860732, 1122388, + 1189548, -20380, -126340, -202540, -248980, -265660, -252580, -209740, -137140, 1028, 97340, 259220, + 450860, 672260, 923420, 1204340, 1276940, -21756, -135076, -216588, -266292, -284188, -270276, -224556, + -147028, 676, 103452, 276404, 481164, 717732, 986108, 1286292, 1364332 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=32, N=16, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 32, "type": "int" }, + { "name": "N", "data": 16, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=32, N=16, block_size=16, bits=4; asymmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 
234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ], + "dims": [16, 32], + "type": "float32" + }, + { + "dims": [16, 2, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 
115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256 + ] + }, + { + "dims": [32], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 + ] + }, + { + "dims": [16], + "type": "uint8", + "data": [128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128] + } + ], + "outputs": [ + { + "dims": [16, 16], + "type": "float32", + "data": [ + -1116, -1860, -1516, -84, 2436, 6044, 10740, 16524, 23364, 31356, 40404, 50540, 61764, 74076, 87476, + 86092, -2492, -2404, 820, 7180, 16676, 29308, 45076, 63980, 88548, 111196, 139508, 170956, 205540, 243260, + 284116, 296364, -3868, -2948, 3156, 14444, 30916, 52572, 79412, 111436, 153732, 191036, 238612, 291372, + 349316, 412444, 480756, 506636, -5244, -3492, 5492, 21708, 45156, 75836, 113748, 158892, 218916, 270876, + 337716, 411788, 493092, 581628, 677396, 716908, -6620, -4036, 7828, 28972, 59396, 99100, 148084, 206348, + 284100, 350716, 436820, 532204, 636868, 750812, 874036, 927180, -7996, -4580, 10164, 36236, 73636, 122364, + 182420, 253804, 349284, 430556, 535924, 652620, 780644, 919996, 1070676, 1137452, -9372, -5124, 12500, + 43500, 87876, 145628, 216756, 301260, 414468, 510396, 635028, 773036, 924420, 1089180, 1267316, 1347724, + -10748, -5668, 
14836, 50764, 102116, 168892, 251092, 348716, 479652, 590236, 734132, 893452, 1068196, + 1258364, 1463956, 1557996, -12124, -6212, 17172, 58028, 116356, 192156, 285428, 396172, 544836, 670076, + 833236, 1013868, 1211972, 1427548, 1660596, 1768268, -13500, -6756, 19508, 65292, 130596, 215420, 319764, + 443628, 610020, 749916, 932340, 1134284, 1355748, 1596732, 1857236, 1978540, -14876, -7300, 21844, 72556, + 144836, 238684, 354100, 491084, 675204, 829756, 1031444, 1254700, 1499524, 1765916, 2053876, 2188812, + -16252, -7844, 24180, 79820, 159076, 261948, 388436, 538540, 740388, 909596, 1130548, 1375116, 1643300, + 1935100, 2250516, 2399084, -17628, -8388, 26516, 87084, 173316, 285212, 422772, 585996, 805572, 989436, + 1229652, 1495532, 1787076, 2104284, 2447156, 2609356, -19004, -8932, 28852, 94348, 187556, 308476, 457108, + 633452, 870756, 1069276, 1328756, 1615948, 1930852, 2273468, 2643796, 2819628, -20380, -9476, 31188, + 101612, 201796, 331740, 491444, 680908, 935940, 1149116, 1427860, 1736364, 2074628, 2442652, 2840436, + 3029900, -21756, -10020, 33524, 108876, 216036, 355004, 525780, 728364, 1001124, 1228956, 1526964, + 1856780, 2218404, 2611836, 3037076, 3240172 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=32, N=32, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 32, "type": "int" }, + { "name": "N", "data": 32, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=32, N=32, block_size=16, bits=4; symmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 
77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 
474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526, + 527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, + 548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, + 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589, + 590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, + 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, + 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, + 653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 671, 672, 673, + 674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685, 686, 687, 688, 689, 690, 691, 692, 693, 694, + 695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 714, 715, + 716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735, 736, + 737, 738, 739, 740, 741, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 754, 755, 756, 757, + 758, 759, 760, 761, 762, 763, 764, 765, 766, 767, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778, + 779, 780, 781, 782, 783, 784, 785, 786, 787, 788, 789, 790, 791, 792, 793, 794, 795, 796, 797, 798, 799, + 800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 811, 812, 813, 814, 815, 816, 817, 818, 819, 820, + 821, 822, 823, 824, 825, 826, 827, 828, 829, 830, 831, 832, 833, 834, 835, 836, 837, 838, 839, 840, 841, + 842, 843, 844, 845, 846, 847, 848, 849, 850, 851, 852, 853, 854, 855, 856, 857, 858, 859, 860, 861, 862, + 863, 864, 865, 
866, 867, 868, 869, 870, 871, 872, 873, 874, 875, 876, 877, 878, 879, 880, 881, 882, 883, + 884, 885, 886, 887, 888, 889, 890, 891, 892, 893, 894, 895, 896, 897, 898, 899, 900, 901, 902, 903, 904, + 905, 906, 907, 908, 909, 910, 911, 912, 913, 914, 915, 916, 917, 918, 919, 920, 921, 922, 923, 924, 925, + 926, 927, 928, 929, 930, 931, 932, 933, 934, 935, 936, 937, 938, 939, 940, 941, 942, 943, 944, 945, 946, + 947, 948, 949, 950, 951, 952, 953, 954, 955, 956, 957, 958, 959, 960, 961, 962, 963, 964, 965, 966, 967, + 968, 969, 970, 971, 972, 973, 974, 975, 976, 977, 978, 979, 980, 981, 982, 983, 984, 985, 986, 987, 988, + 989, 990, 991, 992, 993, 994, 995, 996, 997, 998, 999, 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, + 1008, 1009, 1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023, 1024 + ], + "dims": [32, 32], + "type": "float32" + }, + { + "dims": [32, 2, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 
230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ] + }, + { + "dims": [64], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63 + ] + } + ], + "outputs": [ + { + "dims": [32, 32], + "type": "float32", + "data": [ + -1116, -4036, -5868, -6612, -6268, -4836, -2316, 1292, 5956, 11772, 18644, 26604, 35652, 45788, 57012, + 53452, -59740, -53956, -47084, -39124, -30076, 
-19940, -8716, 3596, 16996, 31484, 47060, 63724, 81476, + 100316, 120244, 109004, -2492, -12772, -19916, -23924, -24796, -22532, -17132, -8596, 5604, 17884, 35828, + 56908, 81124, 108476, 138964, 140844, -199356, -184548, -166604, -145524, -121308, -93956, -63468, -29844, + 6916, 46812, 89844, 136012, 185316, 237756, 293332, 287532, -3868, -21508, -33964, -41236, -43324, -40228, + -31948, -18484, 5252, 23996, 53012, 87212, 126596, 171164, 220916, 228236, -338972, -315140, -286124, + -251924, -212540, -167972, -118220, -63284, -3164, 62140, 132628, 208300, 289156, 375196, 466420, 466060, + -5244, -30244, -48012, -58548, -61852, -57924, -46764, -28372, 4900, 30108, 70196, 117516, 172068, 233852, + 302868, 315628, -478588, -445732, -405644, -358324, -303772, -241988, -172972, -96724, -13244, 77468, + 175412, 280588, 392996, 512636, 639508, 644588, -6620, -38980, -62060, -75860, -80380, -75620, -61580, + -38260, 4548, 36220, 87380, 147820, 217540, 296540, 384820, 403020, -618204, -576324, -525164, -464724, + -395004, -316004, -227724, -130164, -23324, 92796, 218196, 352876, 496836, 650076, 812596, 823116, -7996, + -47716, -76108, -93172, -98908, -93316, -76396, -48148, 4196, 42332, 104564, 178124, 263012, 359228, + 466772, 490412, -757820, -706916, -644684, -571124, -486236, -390020, -282476, -163604, -33404, 108124, + 260980, 425164, 600676, 787516, 985684, 1001644, -9372, -56452, -90156, -110484, -117436, -111012, -91212, + -58036, 3844, 48444, 121748, 208428, 308484, 421916, 548724, 577804, -897436, -837508, -764204, -677524, + -577468, -464036, -337228, -197044, -43484, 123452, 303764, 497452, 704516, 924956, 1158772, 1180172, + -10748, -65188, -104204, -127796, -135964, -128708, -106028, -67924, 3492, 54556, 138932, 238732, 353956, + 484604, 630676, 665196, -1037052, -968100, -883724, -783924, -668700, -538052, -391980, -230484, -53564, + 138780, 346548, 569740, 808356, 1062396, 1331860, 1358700, -12124, -73924, -118252, -145108, -154492, + -146404, -120844, 
-77812, 3140, 60668, 156116, 269036, 399428, 547292, 712628, 752588, -1176668, -1098692, + -1003244, -890324, -759932, -612068, -446732, -263924, -63644, 154108, 389332, 642028, 912196, 1199836, + 1504948, 1537228, -13500, -82660, -132300, -162420, -173020, -164100, -135660, -87700, 2788, 66780, + 173300, 299340, 444900, 609980, 794580, 839980, -1316284, -1229284, -1122764, -996724, -851164, -686084, + -501484, -297364, -73724, 169436, 432116, 714316, 1016036, 1337276, 1678036, 1715756, -14876, -91396, + -146348, -179732, -191548, -181796, -150476, -97588, 2436, 72892, 190484, 329644, 490372, 672668, 876532, + 927372, -1455900, -1359876, -1242284, -1103124, -942396, -760100, -556236, -330804, -83804, 184764, + 474900, 786604, 1119876, 1474716, 1851124, 1894284, -16252, -100132, -160396, -197044, -210076, -199492, + -165292, -107476, 2084, 79004, 207668, 359948, 535844, 735356, 958484, 1014764, -1595516, -1490468, + -1361804, -1209524, -1033628, -834116, -610988, -364244, -93884, 200092, 517684, 858892, 1223716, 1612156, + 2024212, 2072812, -17628, -108868, -174444, -214356, -228604, -217188, -180108, -117364, 1732, 85116, + 224852, 390252, 581316, 798044, 1040436, 1102156, -1735132, -1621060, -1481324, -1315924, -1124860, + -908132, -665740, -397684, -103964, 215420, 560468, 931180, 1327556, 1749596, 2197300, 2251340, -19004, + -117604, -188492, -231668, -247132, -234884, -194924, -127252, 1380, 91228, 242036, 420556, 626788, + 860732, 1122388, 1189548, -1874748, -1751652, -1600844, -1422324, -1216092, -982148, -720492, -431124, + -114044, 230748, 603252, 1003468, 1431396, 1887036, 2370388, 2429868, -20380, -126340, -202540, -248980, + -265660, -252580, -209740, -137140, 1028, 97340, 259220, 450860, 672260, 923420, 1204340, 1276940, + -2014364, -1882244, -1720364, -1528724, -1307324, -1056164, -775244, -464564, -124124, 246076, 646036, + 1075756, 1535236, 2024476, 2543476, 2608396, -21756, -135076, -216588, -266292, -284188, -270276, -224556, + -147028, 676, 
103452, 276404, 481164, 717732, 986108, 1286292, 1364332, -2153980, -2012836, -1839884, + -1635124, -1398556, -1130180, -829996, -498004, -134204, 261404, 688820, 1148044, 1639076, 2161916, + 2716564, 2786924, -23132, -143812, -230636, -283604, -302716, -287972, -239372, -156916, 324, 109564, + 293588, 511468, 763204, 1048796, 1368244, 1451724, -2293596, -2143428, -1959404, -1741524, -1489788, + -1204196, -884748, -531444, -144284, 276732, 731604, 1220332, 1742916, 2299356, 2889652, 2965452, -24508, + -152548, -244684, -300916, -321244, -305668, -254188, -166804, -28, 115676, 310772, 541772, 808676, + 1111484, 1450196, 1539116, -2433212, -2274020, -2078924, -1847924, -1581020, -1278212, -939500, -564884, + -154364, 292060, 774388, 1292620, 1846756, 2436796, 3062740, 3143980, -25884, -161284, -258732, -318228, + -339772, -323364, -269004, -176692, -380, 121788, 327956, 572076, 854148, 1174172, 1532148, 1626508, + -2572828, -2404612, -2198444, -1954324, -1672252, -1352228, -994252, -598324, -164444, 307388, 817172, + 1364908, 1950596, 2574236, 3235828, 3322508, -27260, -170020, -272780, -335540, -358300, -341060, -283820, + -186580, -732, 127900, 345140, 602380, 899620, 1236860, 1614100, 1713900, -2712444, -2535204, -2317964, + -2060724, -1763484, -1426244, -1049004, -631764, -174524, 322716, 859956, 1437196, 2054436, 2711676, + 3408916, 3501036, -28636, -178756, -286828, -352852, -376828, -358756, -298636, -196468, -1084, 134012, + 362324, 632684, 945092, 1299548, 1696052, 1801292, -2852060, -2665796, -2437484, -2167124, -1854716, + -1500260, -1103756, -665204, -184604, 338044, 902740, 1509484, 2158276, 2849116, 3582004, 3679564, -30012, + -187492, -300876, -370164, -395356, -376452, -313452, -206356, -1436, 140124, 379508, 662988, 990564, + 1362236, 1778004, 1888684, -2991676, -2796388, -2557004, -2273524, -1945948, -1574276, -1158508, -698644, + -194684, 353372, 945524, 1581772, 2262116, 2986556, 3755092, 3858092, -31388, -196228, -314924, -387476, + -413884, 
-394148, -328268, -216244, -1788, 146236, 396692, 693292, 1036036, 1424924, 1859956, 1976076, + -3131292, -2926980, -2676524, -2379924, -2037180, -1648292, -1213260, -732084, -204764, 368700, 988308, + 1654060, 2365956, 3123996, 3928180, 4036620, -32764, -204964, -328972, -404788, -432412, -411844, -343084, + -226132, -2140, 152348, 413876, 723596, 1081508, 1487612, 1941908, 2063468, -3270908, -3057572, -2796044, + -2486324, -2128412, -1722308, -1268012, -765524, -214844, 384028, 1031092, 1726348, 2469796, 3261436, + 4101268, 4215148, -34140, -213700, -343020, -422100, -450940, -429540, -357900, -236020, -2492, 158460, + 431060, 753900, 1126980, 1550300, 2023860, 2150860, -3410524, -3188164, -2915564, -2592724, -2219644, + -1796324, -1322764, -798964, -224924, 399356, 1073876, 1798636, 2573636, 3398876, 4274356, 4393676, + -35516, -222436, -357068, -439412, -469468, -447236, -372716, -245908, -2844, 164572, 448244, 784204, + 1172452, 1612988, 2105812, 2238252, -3550140, -3318756, -3035084, -2699124, -2310876, -1870340, -1377516, + -832404, -235004, 414684, 1116660, 1870924, 2677476, 3536316, 4447444, 4572204, -36892, -231172, -371116, + -456724, -487996, -464932, -387532, -255796, -3196, 170684, 465428, 814508, 1217924, 1675676, 2187764, + 2325644, -3689756, -3449348, -3154604, -2805524, -2402108, -1944356, -1432268, -865844, -245084, 430012, + 1159444, 1943212, 2781316, 3673756, 4620532, 4750732, -38268, -239908, -385164, -474036, -506524, -482628, + -402348, -265684, -3548, 176796, 482612, 844812, 1263396, 1738364, 2269716, 2413036, -3829372, -3579940, + -3274124, -2911924, -2493340, -2018372, -1487020, -899284, -255164, 445340, 1202228, 2015500, 2885156, + 3811196, 4793620, 4929260, -39644, -248644, -399212, -491348, -525052, -500324, -417164, -275572, -3900, + 182908, 499796, 875116, 1308868, 1801052, 2351668, 2500428, -3968988, -3710532, -3393644, -3018324, + -2584572, -2092388, -1541772, -932724, -265244, 460668, 1245012, 2087788, 2988996, 3948636, 4966708, + 
5107788, -41020, -257380, -413260, -508660, -543580, -518020, -431980, -285460, -4252, 189020, 516980, + 905420, 1354340, 1863740, 2433620, 2587820, -4108604, -3841124, -3513164, -3124724, -2675804, -2166404, + -1596524, -966164, -275324, 475996, 1287796, 2160076, 3092836, 4086076, 5139796, 5286316, -42396, -266116, + -427308, -525972, -562108, -535716, -446796, -295348, -4604, 195132, 534164, 935724, 1399812, 1926428, + 2515572, 2675212, -4248220, -3971716, -3632684, -3231124, -2767036, -2240420, -1651276, -999604, -285404, + 491324, 1330580, 2232364, 3196676, 4223516, 5312884, 5464844, -43772, -274852, -441356, -543284, -580636, + -553412, -461612, -305236, -4956, 201244, 551348, 966028, 1445284, 1989116, 2597524, 2762604, -4387836, + -4102308, -3752204, -3337524, -2858268, -2314436, -1706028, -1033044, -295484, 506652, 1373364, 2304652, + 3300516, 4360956, 5485972, 5643372 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=32, N=32, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 32, "type": "int" }, + { "name": "N", "data": 32, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=32, N=32, block_size=16, bits=4; asymmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 
140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526, + 527, 528, 529, 530, 531, 
532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, + 548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, + 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589, + 590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, + 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, + 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, + 653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 671, 672, 673, + 674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685, 686, 687, 688, 689, 690, 691, 692, 693, 694, + 695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 714, 715, + 716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735, 736, + 737, 738, 739, 740, 741, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 754, 755, 756, 757, + 758, 759, 760, 761, 762, 763, 764, 765, 766, 767, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778, + 779, 780, 781, 782, 783, 784, 785, 786, 787, 788, 789, 790, 791, 792, 793, 794, 795, 796, 797, 798, 799, + 800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 811, 812, 813, 814, 815, 816, 817, 818, 819, 820, + 821, 822, 823, 824, 825, 826, 827, 828, 829, 830, 831, 832, 833, 834, 835, 836, 837, 838, 839, 840, 841, + 842, 843, 844, 845, 846, 847, 848, 849, 850, 851, 852, 853, 854, 855, 856, 857, 858, 859, 860, 861, 862, + 863, 864, 865, 866, 867, 868, 869, 870, 871, 872, 873, 874, 875, 876, 877, 878, 879, 880, 881, 882, 883, + 884, 885, 886, 887, 888, 889, 890, 891, 892, 893, 894, 895, 896, 897, 898, 899, 900, 901, 902, 903, 904, + 905, 906, 907, 908, 909, 910, 911, 912, 913, 914, 915, 916, 917, 918, 919, 920, 921, 922, 923, 
924, 925, + 926, 927, 928, 929, 930, 931, 932, 933, 934, 935, 936, 937, 938, 939, 940, 941, 942, 943, 944, 945, 946, + 947, 948, 949, 950, 951, 952, 953, 954, 955, 956, 957, 958, 959, 960, 961, 962, 963, 964, 965, 966, 967, + 968, 969, 970, 971, 972, 973, 974, 975, 976, 977, 978, 979, 980, 981, 982, 983, 984, 985, 986, 987, 988, + 989, 990, 991, 992, 993, 994, 995, 996, 997, 998, 999, 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, + 1008, 1009, 1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023, 1024 + ], + "dims": [32, 32], + "type": "float32" + }, + { + "dims": [32, 2, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 
288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ] + }, + { + "dims": [64], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63 + ] + }, + { + "dims": [32], + "type": "uint8", + "data": [ + 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, + 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128 + ] + } + ], + "outputs": [ + { + "dims": [32, 32], + "type": "float32", + "data": [ + -1116, -1860, -1516, -84, 2436, 6044, 10740, 16524, 23364, 31356, 40404, 50540, 61764, 74076, 87476, + 86092, -24924, -16964, -7916, 2220, 13444, 25756, 39156, 53644, 69220, 85884, 103636, 122476, 142404, + 163420, 185524, 176460, 
-2492, -2404, 820, 7180, 16676, 29308, 45076, 63980, 88548, 111196, 139508, + 170956, 205540, 243260, 284116, 296364, -33468, -8292, 20020, 51468, 86052, 123772, 164628, 208620, + 255748, 306012, 359412, 415948, 475620, 538428, 604372, 608940, -3868, -2948, 3156, 14444, 30916, 52572, + 79412, 111436, 153732, 191036, 238612, 291372, 349316, 412444, 480756, 506636, -42012, 380, 47956, 100716, + 158660, 221788, 290100, 363596, 442276, 526140, 615188, 709420, 808836, 913436, 1023220, 1041420, -5244, + -3492, 5492, 21708, 45156, 75836, 113748, 158892, 218916, 270876, 337716, 411788, 493092, 581628, 677396, + 716908, -50556, 9052, 75892, 149964, 231268, 319804, 415572, 518572, 628804, 746268, 870964, 1002892, + 1142052, 1288444, 1442068, 1473900, -6620, -4036, 7828, 28972, 59396, 99100, 148084, 206348, 284100, + 350716, 436820, 532204, 636868, 750812, 874036, 927180, -59100, 17724, 103828, 199212, 303876, 417820, + 541044, 673548, 815332, 966396, 1126740, 1296364, 1475268, 1663452, 1860916, 1906380, -7996, -4580, 10164, + 36236, 73636, 122364, 182420, 253804, 349284, 430556, 535924, 652620, 780644, 919996, 1070676, 1137452, + -67644, 26396, 131764, 248460, 376484, 515836, 666516, 828524, 1001860, 1186524, 1382516, 1589836, + 1808484, 2038460, 2279764, 2338860, -9372, -5124, 12500, 43500, 87876, 145628, 216756, 301260, 414468, + 510396, 635028, 773036, 924420, 1089180, 1267316, 1347724, -76188, 35068, 159700, 297708, 449092, 613852, + 791988, 983500, 1188388, 1406652, 1638292, 1883308, 2141700, 2413468, 2698612, 2771340, -10748, -5668, + 14836, 50764, 102116, 168892, 251092, 348716, 479652, 590236, 734132, 893452, 1068196, 1258364, 1463956, + 1557996, -84732, 43740, 187636, 346956, 521700, 711868, 917460, 1138476, 1374916, 1626780, 1894068, + 2176780, 2474916, 2788476, 3117460, 3203820, -12124, -6212, 17172, 58028, 116356, 192156, 285428, 396172, + 544836, 670076, 833236, 1013868, 1211972, 1427548, 1660596, 1768268, -93276, 52412, 215572, 396204, + 594308, 809884, 
1042932, 1293452, 1561444, 1846908, 2149844, 2470252, 2808132, 3163484, 3536308, 3636300, + -13500, -6756, 19508, 65292, 130596, 215420, 319764, 443628, 610020, 749916, 932340, 1134284, 1355748, + 1596732, 1857236, 1978540, -101820, 61084, 243508, 445452, 666916, 907900, 1168404, 1448428, 1747972, + 2067036, 2405620, 2763724, 3141348, 3538492, 3955156, 4068780, -14876, -7300, 21844, 72556, 144836, + 238684, 354100, 491084, 675204, 829756, 1031444, 1254700, 1499524, 1765916, 2053876, 2188812, -110364, + 69756, 271444, 494700, 739524, 1005916, 1293876, 1603404, 1934500, 2287164, 2661396, 3057196, 3474564, + 3913500, 4374004, 4501260, -16252, -7844, 24180, 79820, 159076, 261948, 388436, 538540, 740388, 909596, + 1130548, 1375116, 1643300, 1935100, 2250516, 2399084, -118908, 78428, 299380, 543948, 812132, 1103932, + 1419348, 1758380, 2121028, 2507292, 2917172, 3350668, 3807780, 4288508, 4792852, 4933740, -17628, -8388, + 26516, 87084, 173316, 285212, 422772, 585996, 805572, 989436, 1229652, 1495532, 1787076, 2104284, 2447156, + 2609356, -127452, 87100, 327316, 593196, 884740, 1201948, 1544820, 1913356, 2307556, 2727420, 3172948, + 3644140, 4140996, 4663516, 5211700, 5366220, -19004, -8932, 28852, 94348, 187556, 308476, 457108, 633452, + 870756, 1069276, 1328756, 1615948, 1930852, 2273468, 2643796, 2819628, -135996, 95772, 355252, 642444, + 957348, 1299964, 1670292, 2068332, 2494084, 2947548, 3428724, 3937612, 4474212, 5038524, 5630548, 5798700, + -20380, -9476, 31188, 101612, 201796, 331740, 491444, 680908, 935940, 1149116, 1427860, 1736364, 2074628, + 2442652, 2840436, 3029900, -144540, 104444, 383188, 691692, 1029956, 1397980, 1795764, 2223308, 2680612, + 3167676, 3684500, 4231084, 4807428, 5413532, 6049396, 6231180, -21756, -10020, 33524, 108876, 216036, + 355004, 525780, 728364, 1001124, 1228956, 1526964, 1856780, 2218404, 2611836, 3037076, 3240172, -153084, + 113116, 411124, 740940, 1102564, 1495996, 1921236, 2378284, 2867140, 3387804, 3940276, 4524556, 5140644, + 
5788540, 6468244, 6663660, -23132, -10564, 35860, 116140, 230276, 378268, 560116, 775820, 1066308, + 1308796, 1626068, 1977196, 2362180, 2781020, 3233716, 3450444, -161628, 121788, 439060, 790188, 1175172, + 1594012, 2046708, 2533260, 3053668, 3607932, 4196052, 4818028, 5473860, 6163548, 6887092, 7096140, -24508, + -11108, 38196, 123404, 244516, 401532, 594452, 823276, 1131492, 1388636, 1725172, 2097612, 2505956, + 2950204, 3430356, 3660716, -170172, 130460, 466996, 839436, 1247780, 1692028, 2172180, 2688236, 3240196, + 3828060, 4451828, 5111500, 5807076, 6538556, 7305940, 7528620, -25884, -11652, 40532, 130668, 258756, + 424796, 628788, 870732, 1196676, 1468476, 1824276, 2218028, 2649732, 3119388, 3626996, 3870988, -178716, + 139132, 494932, 888684, 1320388, 1790044, 2297652, 2843212, 3426724, 4048188, 4707604, 5404972, 6140292, + 6913564, 7724788, 7961100, -27260, -12196, 42868, 137932, 272996, 448060, 663124, 918188, 1261860, + 1548316, 1923380, 2338444, 2793508, 3288572, 3823636, 4081260, -187260, 147804, 522868, 937932, 1392996, + 1888060, 2423124, 2998188, 3613252, 4268316, 4963380, 5698444, 6473508, 7288572, 8143636, 8393580, -28636, + -12740, 45204, 145196, 287236, 471324, 697460, 965644, 1327044, 1628156, 2022484, 2458860, 2937284, + 3457756, 4020276, 4291532, -195804, 156476, 550804, 987180, 1465604, 1986076, 2548596, 3153164, 3799780, + 4488444, 5219156, 5991916, 6806724, 7663580, 8562484, 8826060, -30012, -13284, 47540, 152460, 301476, + 494588, 731796, 1013100, 1392228, 1707996, 2121588, 2579276, 3081060, 3626940, 4216916, 4501804, -204348, + 165148, 578740, 1036428, 1538212, 2084092, 2674068, 3308140, 3986308, 4708572, 5474932, 6285388, 7139940, + 8038588, 8981332, 9258540, -31388, -13828, 49876, 159724, 315716, 517852, 766132, 1060556, 1457412, + 1787836, 2220692, 2699692, 3224836, 3796124, 4413556, 4712076, -212892, 173820, 606676, 1085676, 1610820, + 2182108, 2799540, 3463116, 4172836, 4928700, 5730708, 6578860, 7473156, 8413596, 9400180, 9691020, 
-32764, + -14372, 52212, 166988, 329956, 541116, 800468, 1108012, 1522596, 1867676, 2319796, 2820108, 3368612, + 3965308, 4610196, 4922348, -221436, 182492, 634612, 1134924, 1683428, 2280124, 2925012, 3618092, 4359364, + 5148828, 5986484, 6872332, 7806372, 8788604, 9819028, 10123500, -34140, -14916, 54548, 174252, 344196, + 564380, 834804, 1155468, 1587780, 1947516, 2418900, 2940524, 3512388, 4134492, 4806836, 5132620, -229980, + 191164, 662548, 1184172, 1756036, 2378140, 3050484, 3773068, 4545892, 5368956, 6242260, 7165804, 8139588, + 9163612, 10237876, 10555980, -35516, -15460, 56884, 181516, 358436, 587644, 869140, 1202924, 1652964, + 2027356, 2518004, 3060940, 3656164, 4303676, 5003476, 5342892, -238524, 199836, 690484, 1233420, 1828644, + 2476156, 3175956, 3928044, 4732420, 5589084, 6498036, 7459276, 8472804, 9538620, 10656724, 10988460, + -36892, -16004, 59220, 188780, 372676, 610908, 903476, 1250380, 1718148, 2107196, 2617108, 3181356, + 3799940, 4472860, 5200116, 5553164, -247068, 208508, 718420, 1282668, 1901252, 2574172, 3301428, 4083020, + 4918948, 5809212, 6753812, 7752748, 8806020, 9913628, 11075572, 11420940, -38268, -16548, 61556, 196044, + 386916, 634172, 937812, 1297836, 1783332, 2187036, 2716212, 3301772, 3943716, 4642044, 5396756, 5763436, + -255612, 217180, 746356, 1331916, 1973860, 2672188, 3426900, 4237996, 5105476, 6029340, 7009588, 8046220, + 9139236, 10288636, 11494420, 11853420, -39644, -17092, 63892, 203308, 401156, 657436, 972148, 1345292, + 1848516, 2266876, 2815316, 3422188, 4087492, 4811228, 5593396, 5973708, -264156, 225852, 774292, 1381164, + 2046468, 2770204, 3552372, 4392972, 5292004, 6249468, 7265364, 8339692, 9472452, 10663644, 11913268, + 12285900, -41020, -17636, 66228, 210572, 415396, 680700, 1006484, 1392748, 1913700, 2346716, 2914420, + 3542604, 4231268, 4980412, 5790036, 6183980, -272700, 234524, 802228, 1430412, 2119076, 2868220, 3677844, + 4547948, 5478532, 6469596, 7521140, 8633164, 9805668, 11038652, 12332116, 
12718380, -42396, -18180, 68564, + 217836, 429636, 703964, 1040820, 1440204, 1978884, 2426556, 3013524, 3663020, 4375044, 5149596, 5986676, + 6394252, -281244, 243196, 830164, 1479660, 2191684, 2966236, 3803316, 4702924, 5665060, 6689724, 7776916, + 8926636, 10138884, 11413660, 12750964, 13150860, -43772, -18724, 70900, 225100, 443876, 727228, 1075156, + 1487660, 2044068, 2506396, 3112628, 3783436, 4518820, 5318780, 6183316, 6604524, -289788, 251868, 858100, + 1528908, 2264292, 3064252, 3928788, 4857900, 5851588, 6909852, 8032692, 9220108, 10472100, 11788668, + 13169812, 13583340 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=32, N=32, block_size=32, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 32, "type": "int" }, + { "name": "N", "data": 32, "type": "int" }, + { "name": "block_size", "data": 32, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=32, N=32, block_size=32, bits=4; symmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 
199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526, + 527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, + 548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, + 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589, + 590, 
591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, + 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, + 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, + 653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 671, 672, 673, + 674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685, 686, 687, 688, 689, 690, 691, 692, 693, 694, + 695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 714, 715, + 716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735, 736, + 737, 738, 739, 740, 741, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 754, 755, 756, 757, + 758, 759, 760, 761, 762, 763, 764, 765, 766, 767, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778, + 779, 780, 781, 782, 783, 784, 785, 786, 787, 788, 789, 790, 791, 792, 793, 794, 795, 796, 797, 798, 799, + 800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 811, 812, 813, 814, 815, 816, 817, 818, 819, 820, + 821, 822, 823, 824, 825, 826, 827, 828, 829, 830, 831, 832, 833, 834, 835, 836, 837, 838, 839, 840, 841, + 842, 843, 844, 845, 846, 847, 848, 849, 850, 851, 852, 853, 854, 855, 856, 857, 858, 859, 860, 861, 862, + 863, 864, 865, 866, 867, 868, 869, 870, 871, 872, 873, 874, 875, 876, 877, 878, 879, 880, 881, 882, 883, + 884, 885, 886, 887, 888, 889, 890, 891, 892, 893, 894, 895, 896, 897, 898, 899, 900, 901, 902, 903, 904, + 905, 906, 907, 908, 909, 910, 911, 912, 913, 914, 915, 916, 917, 918, 919, 920, 921, 922, 923, 924, 925, + 926, 927, 928, 929, 930, 931, 932, 933, 934, 935, 936, 937, 938, 939, 940, 941, 942, 943, 944, 945, 946, + 947, 948, 949, 950, 951, 952, 953, 954, 955, 956, 957, 958, 959, 960, 961, 962, 963, 964, 965, 966, 967, + 968, 969, 970, 971, 972, 973, 974, 975, 976, 977, 978, 979, 980, 981, 982, 
983, 984, 985, 986, 987, 988, + 989, 990, 991, 992, 993, 994, 995, 996, 997, 998, 999, 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, + 1008, 1009, 1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023, 1024 + ], + "dims": [32, 32], + "type": "float32" + }, + { + "dims": [32, 1, 16], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 
347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ] + }, + { + "dims": [32], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 + ] + } + ], + "outputs": [ + { + "dims": [32, 32], + "type": "float32", + "data": [ + 0, -1560, -2576, -3048, -2976, -2360, -1200, 504, 2736, 5544, 8880, 12760, 17184, 22152, 27664, 26040, + -29312, -26520, -23184, -19304, -14880, -9912, -4400, 1656, 8256, 15400, 23088, 31320, 40096, 49416, + 59280, 53816, 0, -5368, -9168, -11400, -12064, -11160, -8688, -4648, 2224, 8136, 16880, 27192, 39072, + 52520, 67536, 68760, -98432, -91256, -82512, -72200, -60320, -46872, -31856, -15272, 2880, 22600, 43888, + 66744, 91168, 117160, 144720, 142104, 0, -9176, -15760, -19752, -21152, -19960, -16176, -9800, 1712, + 10728, 24880, 41624, 60960, 82888, 107408, 111480, -167552, -155992, -141840, -125096, -105760, -83832, + -59312, -32200, -2496, 29800, 64688, 102168, 142240, 184904, 230160, 230392, 0, -12984, -22352, -28104, + -30240, -28760, -23664, -14952, 1200, 13320, 32880, 56056, 82848, 113256, 147280, 154200, -236672, + -220728, -201168, -177992, -151200, -120792, 
-86768, -49128, -7872, 37000, 85488, 137592, 193312, 252648, + 315600, 318680, 0, -16792, -28944, -36456, -39328, -37560, -31152, -20104, 688, 15912, 40880, 70488, + 104736, 143624, 187152, 196920, -305792, -285464, -260496, -230888, -196640, -157752, -114224, -66056, + -13248, 44200, 106288, 173016, 244384, 320392, 401040, 406968, 0, -20600, -35536, -44808, -48416, -46360, + -38640, -25256, 176, 18504, 48880, 84920, 126624, 173992, 227024, 239640, -374912, -350200, -319824, + -283784, -242080, -194712, -141680, -82984, -18624, 51400, 127088, 208440, 295456, 388136, 486480, 495256, + 0, -24408, -42128, -53160, -57504, -55160, -46128, -30408, -336, 21096, 56880, 99352, 148512, 204360, + 266896, 282360, -444032, -414936, -379152, -336680, -287520, -231672, -169136, -99912, -24000, 58600, + 147888, 243864, 346528, 455880, 571920, 583544, 0, -28216, -48720, -61512, -66592, -63960, -53616, -35560, + -848, 23688, 64880, 113784, 170400, 234728, 306768, 325080, -513152, -479672, -438480, -389576, -332960, + -268632, -196592, -116840, -29376, 65800, 168688, 279288, 397600, 523624, 657360, 671832, 0, -32024, + -55312, -69864, -75680, -72760, -61104, -40712, -1360, 26280, 72880, 128216, 192288, 265096, 346640, + 367800, -582272, -544408, -497808, -442472, -378400, -305592, -224048, -133768, -34752, 73000, 189488, + 314712, 448672, 591368, 742800, 760120, 0, -35832, -61904, -78216, -84768, -81560, -68592, -45864, -1872, + 28872, 80880, 142648, 214176, 295464, 386512, 410520, -651392, -609144, -557136, -495368, -423840, + -342552, -251504, -150696, -40128, 80200, 210288, 350136, 499744, 659112, 828240, 848408, 0, -39640, + -68496, -86568, -93856, -90360, -76080, -51016, -2384, 31464, 88880, 157080, 236064, 325832, 426384, + 453240, -720512, -673880, -616464, -548264, -469280, -379512, -278960, -167624, -45504, 87400, 231088, + 385560, 550816, 726856, 913680, 936696, 0, -43448, -75088, -94920, -102944, -99160, -83568, -56168, -2896, + 34056, 96880, 171512, 257952, 356200, 
466256, 495960, -789632, -738616, -675792, -601160, -514720, + -416472, -306416, -184552, -50880, 94600, 251888, 420984, 601888, 794600, 999120, 1024984, 0, -47256, + -81680, -103272, -112032, -107960, -91056, -61320, -3408, 36648, 104880, 185944, 279840, 386568, 506128, + 538680, -858752, -803352, -735120, -654056, -560160, -453432, -333872, -201480, -56256, 101800, 272688, + 456408, 652960, 862344, 1084560, 1113272, 0, -51064, -88272, -111624, -121120, -116760, -98544, -66472, + -3920, 39240, 112880, 200376, 301728, 416936, 546000, 581400, -927872, -868088, -794448, -706952, -605600, + -490392, -361328, -218408, -61632, 109000, 293488, 491832, 704032, 930088, 1170000, 1201560, 0, -54872, + -94864, -119976, -130208, -125560, -106032, -71624, -4432, 41832, 120880, 214808, 323616, 447304, 585872, + 624120, -996992, -932824, -853776, -759848, -651040, -527352, -388784, -235336, -67008, 116200, 314288, + 527256, 755104, 997832, 1255440, 1289848, 0, -58680, -101456, -128328, -139296, -134360, -113520, -76776, + -4944, 44424, 128880, 229240, 345504, 477672, 625744, 666840, -1066112, -997560, -913104, -812744, + -696480, -564312, -416240, -252264, -72384, 123400, 335088, 562680, 806176, 1065576, 1340880, 1378136, 0, + -62488, -108048, -136680, -148384, -143160, -121008, -81928, -5456, 47016, 136880, 243672, 367392, 508040, + 665616, 709560, -1135232, -1062296, -972432, -865640, -741920, -601272, -443696, -269192, -77760, 130600, + 355888, 598104, 857248, 1133320, 1426320, 1466424, 0, -66296, -114640, -145032, -157472, -151960, -128496, + -87080, -5968, 49608, 144880, 258104, 389280, 538408, 705488, 752280, -1204352, -1127032, -1031760, + -918536, -787360, -638232, -471152, -286120, -83136, 137800, 376688, 633528, 908320, 1201064, 1511760, + 1554712, 0, -70104, -121232, -153384, -166560, -160760, -135984, -92232, -6480, 52200, 152880, 272536, + 411168, 568776, 745360, 795000, -1273472, -1191768, -1091088, -971432, -832800, -675192, -498608, -303048, + -88512, 145000, 
397488, 668952, 959392, 1268808, 1597200, 1643000, 0, -73912, -127824, -161736, -175648, + -169560, -143472, -97384, -6992, 54792, 160880, 286968, 433056, 599144, 785232, 837720, -1342592, + -1256504, -1150416, -1024328, -878240, -712152, -526064, -319976, -93888, 152200, 418288, 704376, 1010464, + 1336552, 1682640, 1731288, 0, -77720, -134416, -170088, -184736, -178360, -150960, -102536, -7504, 57384, + 168880, 301400, 454944, 629512, 825104, 880440, -1411712, -1321240, -1209744, -1077224, -923680, -749112, + -553520, -336904, -99264, 159400, 439088, 739800, 1061536, 1404296, 1768080, 1819576, 0, -81528, -141008, + -178440, -193824, -187160, -158448, -107688, -8016, 59976, 176880, 315832, 476832, 659880, 864976, 923160, + -1480832, -1385976, -1269072, -1130120, -969120, -786072, -580976, -353832, -104640, 166600, 459888, + 775224, 1112608, 1472040, 1853520, 1907864, 0, -85336, -147600, -186792, -202912, -195960, -165936, + -112840, -8528, 62568, 184880, 330264, 498720, 690248, 904848, 965880, -1549952, -1450712, -1328400, + -1183016, -1014560, -823032, -608432, -370760, -110016, 173800, 480688, 810648, 1163680, 1539784, 1938960, + 1996152, 0, -89144, -154192, -195144, -212000, -204760, -173424, -117992, -9040, 65160, 192880, 344696, + 520608, 720616, 944720, 1008600, -1619072, -1515448, -1387728, -1235912, -1060000, -859992, -635888, + -387688, -115392, 181000, 501488, 846072, 1214752, 1607528, 2024400, 2084440, 0, -92952, -160784, -203496, + -221088, -213560, -180912, -123144, -9552, 67752, 200880, 359128, 542496, 750984, 984592, 1051320, + -1688192, -1580184, -1447056, -1288808, -1105440, -896952, -663344, -404616, -120768, 188200, 522288, + 881496, 1265824, 1675272, 2109840, 2172728, 0, -96760, -167376, -211848, -230176, -222360, -188400, + -128296, -10064, 70344, 208880, 373560, 564384, 781352, 1024464, 1094040, -1757312, -1644920, -1506384, + -1341704, -1150880, -933912, -690800, -421544, -126144, 195400, 543088, 916920, 1316896, 1743016, 2195280, + 2261016, 
0, -100568, -173968, -220200, -239264, -231160, -195888, -133448, -10576, 72936, 216880, 387992, + 586272, 811720, 1064336, 1136760, -1826432, -1709656, -1565712, -1394600, -1196320, -970872, -718256, + -438472, -131520, 202600, 563888, 952344, 1367968, 1810760, 2280720, 2349304, 0, -104376, -180560, + -228552, -248352, -239960, -203376, -138600, -11088, 75528, 224880, 402424, 608160, 842088, 1104208, + 1179480, -1895552, -1774392, -1625040, -1447496, -1241760, -1007832, -745712, -455400, -136896, 209800, + 584688, 987768, 1419040, 1878504, 2366160, 2437592, 0, -108184, -187152, -236904, -257440, -248760, + -210864, -143752, -11600, 78120, 232880, 416856, 630048, 872456, 1144080, 1222200, -1964672, -1839128, + -1684368, -1500392, -1287200, -1044792, -773168, -472328, -142272, 217000, 605488, 1023192, 1470112, + 1946248, 2451600, 2525880, 0, -111992, -193744, -245256, -266528, -257560, -218352, -148904, -12112, + 80712, 240880, 431288, 651936, 902824, 1183952, 1264920, -2033792, -1903864, -1743696, -1553288, -1332640, + -1081752, -800624, -489256, -147648, 224200, 626288, 1058616, 1521184, 2013992, 2537040, 2614168, 0, + -115800, -200336, -253608, -275616, -266360, -225840, -154056, -12624, 83304, 248880, 445720, 673824, + 933192, 1223824, 1307640, -2102912, -1968600, -1803024, -1606184, -1378080, -1118712, -828080, -506184, + -153024, 231400, 647088, 1094040, 1572256, 2081736, 2622480, 2702456, 0, -119608, -206928, -261960, + -284704, -275160, -233328, -159208, -13136, 85896, 256880, 460152, 695712, 963560, 1263696, 1350360, + -2172032, -2033336, -1862352, -1659080, -1423520, -1155672, -855536, -523112, -158400, 238600, 667888, + 1129464, 1623328, 2149480, 2707920, 2790744 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=32, N=32, block_size=32, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 32, "type": "int" }, + { "name": "N", "data": 32, "type": "int" }, + { "name": 
"block_size", "data": 32, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=32, N=32, block_size=32, bits=4; asymmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 
373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526, + 527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, + 548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, + 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589, + 590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, + 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, + 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, + 653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 671, 672, 673, + 674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685, 686, 687, 688, 689, 690, 691, 692, 693, 694, + 695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 714, 715, + 716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735, 736, + 737, 738, 739, 740, 741, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 754, 755, 756, 757, + 758, 759, 760, 761, 762, 763, 764, 
765, 766, 767, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778, + 779, 780, 781, 782, 783, 784, 785, 786, 787, 788, 789, 790, 791, 792, 793, 794, 795, 796, 797, 798, 799, + 800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 811, 812, 813, 814, 815, 816, 817, 818, 819, 820, + 821, 822, 823, 824, 825, 826, 827, 828, 829, 830, 831, 832, 833, 834, 835, 836, 837, 838, 839, 840, 841, + 842, 843, 844, 845, 846, 847, 848, 849, 850, 851, 852, 853, 854, 855, 856, 857, 858, 859, 860, 861, 862, + 863, 864, 865, 866, 867, 868, 869, 870, 871, 872, 873, 874, 875, 876, 877, 878, 879, 880, 881, 882, 883, + 884, 885, 886, 887, 888, 889, 890, 891, 892, 893, 894, 895, 896, 897, 898, 899, 900, 901, 902, 903, 904, + 905, 906, 907, 908, 909, 910, 911, 912, 913, 914, 915, 916, 917, 918, 919, 920, 921, 922, 923, 924, 925, + 926, 927, 928, 929, 930, 931, 932, 933, 934, 935, 936, 937, 938, 939, 940, 941, 942, 943, 944, 945, 946, + 947, 948, 949, 950, 951, 952, 953, 954, 955, 956, 957, 958, 959, 960, 961, 962, 963, 964, 965, 966, 967, + 968, 969, 970, 971, 972, 973, 974, 975, 976, 977, 978, 979, 980, 981, 982, 983, 984, 985, 986, 987, 988, + 989, 990, 991, 992, 993, 994, 995, 996, 997, 998, 999, 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, + 1008, 1009, 1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023, 1024 + ], + "dims": [32, 32], + "type": "float32" + }, + { + "dims": [32, 1, 16], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 
29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ] + }, + { + "dims": [32], + "type": 
"float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 + ] + }, + { + "dims": [32], + "type": "uint8", + "data": [ + 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, + 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128 + ] + } + ], + "outputs": [ + { + "dims": [32, 32], + "type": "float32", + "data": [ + 0, 2664, 5872, 9624, 13920, 18760, 24144, 30072, 36528, 43560, 51120, 59224, 67872, 77064, 86800, 89400, + 38272, 45288, 52848, 60952, 69600, 78792, 88528, 98808, 109632, 121000, 132912, 145368, 158368, 171912, + 186000, 184760, 0, 7048, 15664, 25848, 37600, 50920, 65808, 82264, 101552, 119880, 141040, 163768, 188064, + 213928, 241360, 255000, 100224, 119816, 140976, 163704, 188000, 213864, 241296, 270296, 300864, 333000, + 366704, 401976, 438816, 477224, 517200, 527000, 0, 11432, 25456, 42072, 61280, 83080, 107472, 134456, + 166576, 196200, 230960, 268312, 308256, 350792, 395920, 420600, 162176, 194344, 229104, 266456, 306400, + 348936, 394064, 441784, 492096, 545000, 600496, 658584, 719264, 782536, 848400, 869240, 0, 15816, 35248, + 58296, 84960, 115240, 149136, 186648, 231600, 272520, 320880, 372856, 428448, 487656, 550480, 586200, + 224128, 268872, 317232, 369208, 424800, 484008, 546832, 613272, 683328, 757000, 834288, 915192, 999712, + 1087848, 1179600, 1211480, 0, 20200, 45040, 74520, 108640, 147400, 190800, 238840, 296624, 348840, 410800, + 477400, 548640, 624520, 705040, 751800, 286080, 343400, 405360, 471960, 543200, 619080, 699600, 784760, + 874560, 969000, 1068080, 1171800, 1280160, 1393160, 1510800, 1553720, 0, 24584, 54832, 90744, 132320, + 179560, 232464, 291032, 361648, 425160, 500720, 581944, 668832, 761384, 859600, 917400, 348032, 417928, + 493488, 574712, 661600, 754152, 852368, 956248, 1065792, 1181000, 1301872, 1428408, 1560608, 1698472, + 1842000, 1895960, 0, 28968, 64624, 106968, 156000, 211720, 
274128, 343224, 426672, 501480, 590640, 686488, + 789024, 898248, 1014160, 1083000, 409984, 492456, 581616, 677464, 780000, 889224, 1005136, 1127736, + 1257024, 1393000, 1535664, 1685016, 1841056, 2003784, 2173200, 2238200, 0, 33352, 74416, 123192, 179680, + 243880, 315792, 395416, 491696, 577800, 680560, 791032, 909216, 1035112, 1168720, 1248600, 471936, 566984, + 669744, 780216, 898400, 1024296, 1157904, 1299224, 1448256, 1605000, 1769456, 1941624, 2121504, 2309096, + 2504400, 2580440, 0, 37736, 84208, 139416, 203360, 276040, 357456, 447608, 556720, 654120, 770480, 895576, + 1029408, 1171976, 1323280, 1414200, 533888, 641512, 757872, 882968, 1016800, 1159368, 1310672, 1470712, + 1639488, 1817000, 2003248, 2198232, 2401952, 2614408, 2835600, 2922680, 0, 42120, 94000, 155640, 227040, + 308200, 399120, 499800, 621744, 730440, 860400, 1000120, 1149600, 1308840, 1477840, 1579800, 595840, + 716040, 846000, 985720, 1135200, 1294440, 1463440, 1642200, 1830720, 2029000, 2237040, 2454840, 2682400, + 2919720, 3166800, 3264920, 0, 46504, 103792, 171864, 250720, 340360, 440784, 551992, 686768, 806760, + 950320, 1104664, 1269792, 1445704, 1632400, 1745400, 657792, 790568, 934128, 1088472, 1253600, 1429512, + 1616208, 1813688, 2021952, 2241000, 2470832, 2711448, 2962848, 3225032, 3498000, 3607160, 0, 50888, + 113584, 188088, 274400, 372520, 482448, 604184, 751792, 883080, 1040240, 1209208, 1389984, 1582568, + 1786960, 1911000, 719744, 865096, 1022256, 1191224, 1372000, 1564584, 1768976, 1985176, 2213184, 2453000, + 2704624, 2968056, 3243296, 3530344, 3829200, 3949400, 0, 55272, 123376, 204312, 298080, 404680, 524112, + 656376, 816816, 959400, 1130160, 1313752, 1510176, 1719432, 1941520, 2076600, 781696, 939624, 1110384, + 1293976, 1490400, 1699656, 1921744, 2156664, 2404416, 2665000, 2938416, 3224664, 3523744, 3835656, + 4160400, 4291640, 0, 59656, 133168, 220536, 321760, 436840, 565776, 708568, 881840, 1035720, 1220080, + 1418296, 1630368, 1856296, 2096080, 2242200, 843648, 
1014152, 1198512, 1396728, 1608800, 1834728, 2074512, + 2328152, 2595648, 2877000, 3172208, 3481272, 3804192, 4140968, 4491600, 4633880, 0, 64040, 142960, 236760, + 345440, 469000, 607440, 760760, 946864, 1112040, 1310000, 1522840, 1750560, 1993160, 2250640, 2407800, + 905600, 1088680, 1286640, 1499480, 1727200, 1969800, 2227280, 2499640, 2786880, 3089000, 3406000, 3737880, + 4084640, 4446280, 4822800, 4976120, 0, 68424, 152752, 252984, 369120, 501160, 649104, 812952, 1011888, + 1188360, 1399920, 1627384, 1870752, 2130024, 2405200, 2573400, 967552, 1163208, 1374768, 1602232, 1845600, + 2104872, 2380048, 2671128, 2978112, 3301000, 3639792, 3994488, 4365088, 4751592, 5154000, 5318360, 0, + 72808, 162544, 269208, 392800, 533320, 690768, 865144, 1076912, 1264680, 1489840, 1731928, 1990944, + 2266888, 2559760, 2739000, 1029504, 1237736, 1462896, 1704984, 1964000, 2239944, 2532816, 2842616, + 3169344, 3513000, 3873584, 4251096, 4645536, 5056904, 5485200, 5660600, 0, 77192, 172336, 285432, 416480, + 565480, 732432, 917336, 1141936, 1341000, 1579760, 1836472, 2111136, 2403752, 2714320, 2904600, 1091456, + 1312264, 1551024, 1807736, 2082400, 2375016, 2685584, 3014104, 3360576, 3725000, 4107376, 4507704, + 4925984, 5362216, 5816400, 6002840, 0, 81576, 182128, 301656, 440160, 597640, 774096, 969528, 1206960, + 1417320, 1669680, 1941016, 2231328, 2540616, 2868880, 3070200, 1153408, 1386792, 1639152, 1910488, + 2200800, 2510088, 2838352, 3185592, 3551808, 3937000, 4341168, 4764312, 5206432, 5667528, 6147600, + 6345080, 0, 85960, 191920, 317880, 463840, 629800, 815760, 1021720, 1271984, 1493640, 1759600, 2045560, + 2351520, 2677480, 3023440, 3235800, 1215360, 1461320, 1727280, 2013240, 2319200, 2645160, 2991120, + 3357080, 3743040, 4149000, 4574960, 5020920, 5486880, 5972840, 6478800, 6687320, 0, 90344, 201712, 334104, + 487520, 661960, 857424, 1073912, 1337008, 1569960, 1849520, 2150104, 2471712, 2814344, 3178000, 3401400, + 1277312, 1535848, 1815408, 2115992, 2437600, 2780232, 
3143888, 3528568, 3934272, 4361000, 4808752, + 5277528, 5767328, 6278152, 6810000, 7029560, 0, 94728, 211504, 350328, 511200, 694120, 899088, 1126104, + 1402032, 1646280, 1939440, 2254648, 2591904, 2951208, 3332560, 3567000, 1339264, 1610376, 1903536, + 2218744, 2556000, 2915304, 3296656, 3700056, 4125504, 4573000, 5042544, 5534136, 6047776, 6583464, + 7141200, 7371800, 0, 99112, 221296, 366552, 534880, 726280, 940752, 1178296, 1467056, 1722600, 2029360, + 2359192, 2712096, 3088072, 3487120, 3732600, 1401216, 1684904, 1991664, 2321496, 2674400, 3050376, + 3449424, 3871544, 4316736, 4785000, 5276336, 5790744, 6328224, 6888776, 7472400, 7714040, 0, 103496, + 231088, 382776, 558560, 758440, 982416, 1230488, 1532080, 1798920, 2119280, 2463736, 2832288, 3224936, + 3641680, 3898200, 1463168, 1759432, 2079792, 2424248, 2792800, 3185448, 3602192, 4043032, 4507968, + 4997000, 5510128, 6047352, 6608672, 7194088, 7803600, 8056280, 0, 107880, 240880, 399000, 582240, 790600, + 1024080, 1282680, 1597104, 1875240, 2209200, 2568280, 2952480, 3361800, 3796240, 4063800, 1525120, + 1833960, 2167920, 2527000, 2911200, 3320520, 3754960, 4214520, 4699200, 5209000, 5743920, 6303960, + 6889120, 7499400, 8134800, 8398520, 0, 112264, 250672, 415224, 605920, 822760, 1065744, 1334872, 1662128, + 1951560, 2299120, 2672824, 3072672, 3498664, 3950800, 4229400, 1587072, 1908488, 2256048, 2629752, + 3029600, 3455592, 3907728, 4386008, 4890432, 5421000, 5977712, 6560568, 7169568, 7804712, 8466000, + 8740760, 0, 116648, 260464, 431448, 629600, 854920, 1107408, 1387064, 1727152, 2027880, 2389040, 2777368, + 3192864, 3635528, 4105360, 4395000, 1649024, 1983016, 2344176, 2732504, 3148000, 3590664, 4060496, + 4557496, 5081664, 5633000, 6211504, 6817176, 7450016, 8110024, 8797200, 9083000, 0, 121032, 270256, + 447672, 653280, 887080, 1149072, 1439256, 1792176, 2104200, 2478960, 2881912, 3313056, 3772392, 4259920, + 4560600, 1710976, 2057544, 2432304, 2835256, 3266400, 3725736, 4213264, 4728984, 5272896, 
5845000, + 6445296, 7073784, 7730464, 8415336, 9128400, 9425240, 0, 125416, 280048, 463896, 676960, 919240, 1190736, + 1491448, 1857200, 2180520, 2568880, 2986456, 3433248, 3909256, 4414480, 4726200, 1772928, 2132072, + 2520432, 2938008, 3384800, 3860808, 4366032, 4900472, 5464128, 6057000, 6679088, 7330392, 8010912, + 8720648, 9459600, 9767480, 0, 129800, 289840, 480120, 700640, 951400, 1232400, 1543640, 1922224, 2256840, + 2658800, 3091000, 3553440, 4046120, 4569040, 4891800, 1834880, 2206600, 2608560, 3040760, 3503200, + 3995880, 4518800, 5071960, 5655360, 6269000, 6912880, 7587000, 8291360, 9025960, 9790800, 10109720, 0, + 134184, 299632, 496344, 724320, 983560, 1274064, 1595832, 1987248, 2333160, 2748720, 3195544, 3673632, + 4182984, 4723600, 5057400, 1896832, 2281128, 2696688, 3143512, 3621600, 4130952, 4671568, 5243448, + 5846592, 6481000, 7146672, 7843608, 8571808, 9331272, 10122000, 10451960, 0, 138568, 309424, 512568, + 748000, 1015720, 1315728, 1648024, 2052272, 2409480, 2838640, 3300088, 3793824, 4319848, 4878160, 5223000, + 1958784, 2355656, 2784816, 3246264, 3740000, 4266024, 4824336, 5414936, 6037824, 6693000, 7380464, + 8100216, 8852256, 9636584, 10453200, 10794200 + ] + } + ] + } + ] + } +] diff --git a/js/web/test/data/ops/where.jsonc b/js/web/test/data/ops/where.jsonc index 047fd6fd7511b..990120dd3708e 100644 --- a/js/web/test/data/ops/where.jsonc +++ b/js/web/test/data/ops/where.jsonc @@ -168,5 +168,39 @@ ] } ] + }, + { + "name": "Where with no attributes", + "operator": "Where", + "attributes": [], + "cases": [ + { + "name": "T[1 1 2 1] T[1 4] T[1 1 2 4] float32 broadcast 1", + "inputs": [ + { + "data": [true, false], + "dims": [1, 1, 2, 1], + "type": "bool" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 4], + "type": "float32" + }, + { + "data": [5, 6, 7, 8, 9, 10, 11, 12], + "dims": [1, 1, 2, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 2, 3, 4, 9, 10, 11, 12], + "dims": [1, 1, 2, 4], + "type": "float32" + } + ] + } + ] } 
] diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 56db28b0a379c..e96a0aa045bc8 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1231,7 +1231,7 @@ "test_split_variable_parts_1d", "test_split_variable_parts_2d", "test_split_variable_parts_default_axis", - // // "test_split_zero_size_splits", + "test_split_zero_size_splits", "test_sqrt_example", "test_sqrt", "test_squeeze_negative_axes", @@ -1334,6 +1334,7 @@ "acos.jsonc", "add.jsonc", "add_int32.jsonc", + "add_zero-sized.jsonc", //"and.jsonc", "asin.jsonc", "attention.jsonc", @@ -1343,6 +1344,7 @@ "ceil.jsonc", "concat.jsonc", "concat_int32.jsonc", + "concat_zero-sized.jsonc", "cast.jsonc", "conv.jsonc", "cos.jsonc", @@ -1352,7 +1354,9 @@ "equal.jsonc", "exp.jsonc", "expand.jsonc", + "fast-gelu.jsonc", "floor.jsonc", + "fused-conv.jsonc", "gather-elements.jsonc", "gemm.jsonc", "global-average-pool.jsonc", @@ -1361,6 +1365,7 @@ "less.jsonc", "log.jsonc", "matmul.jsonc", + "matmulnbits.jsonc", "matmul-broadcast.jsonc", "mul.jsonc", "mul_int32.jsonc", diff --git a/js/web/test/test-main.ts b/js/web/test/test-main.ts index 2d83ce1e095ce..96e374f87aed1 100644 --- a/js/web/test/test-main.ts +++ b/js/web/test/test-main.ts @@ -19,49 +19,7 @@ if (ORT_WEB_TEST_CONFIG.model.some(testGroup => testGroup.tests.some(test => tes } // set flags -const options = ORT_WEB_TEST_CONFIG.options; -if (options.debug !== undefined) { - ort.env.debug = options.debug; -} -if (options.globalEnvFlags) { - const flags = options.globalEnvFlags; - if (flags.logLevel !== undefined) { - ort.env.logLevel = flags.logLevel; - } - if (flags.webgl?.contextId !== undefined) { - ort.env.webgl.contextId = flags.webgl.contextId; - } - if (flags.webgl?.matmulMaxBatchSize !== undefined) { - ort.env.webgl.matmulMaxBatchSize = flags.webgl.matmulMaxBatchSize; - } - if (flags.webgl?.textureCacheMode !== undefined) { - ort.env.webgl.textureCacheMode = flags.webgl.textureCacheMode; - } - 
if (flags.webgl?.pack !== undefined) { - ort.env.webgl.pack = flags.webgl.pack; - } - if (flags.webgl?.async !== undefined) { - ort.env.webgl.async = flags.webgl.async; - } - if (flags.wasm?.numThreads !== undefined) { - ort.env.wasm.numThreads = flags.wasm.numThreads; - } - if (flags.wasm?.simd !== undefined) { - ort.env.wasm.simd = flags.wasm.simd; - } - if (flags.wasm?.proxy !== undefined) { - ort.env.wasm.proxy = flags.wasm.proxy; - } - if (flags.wasm?.initTimeout !== undefined) { - ort.env.wasm.initTimeout = flags.wasm.initTimeout; - } - if (flags.webgpu?.profilingMode !== undefined) { - ort.env.webgpu.profiling = {mode: flags.webgpu.profilingMode}; - } - if (flags.webgpu?.validateInputContent !== undefined) { - ort.env.webgpu.validateInputContent = flags.webgpu.validateInputContent; - } -} +Object.assign(ort.env, ORT_WEB_TEST_CONFIG.options.globalEnvFlags); // Set logging configuration for (const logConfig of ORT_WEB_TEST_CONFIG.log) { diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts index 442cb1bcf1f34..7c03e5b915fd7 100644 --- a/js/web/test/test-runner.ts +++ b/js/web/test/test-runner.ts @@ -39,10 +39,6 @@ const ONNXRUNTIME_THRESHOLD_RELATIVE_ERROR = 1.00001; */ const now = (typeof performance !== 'undefined' && performance.now) ? 
() => performance.now() : Date.now; -function toInternalTensor(tensor: ort.Tensor): Tensor { - return new Tensor( - tensor.dims, tensor.type as Tensor.DataType, undefined, undefined, tensor.data as Tensor.NumberType); -} function fromInternalTensor(tensor: Tensor): ort.Tensor { return new ort.Tensor(tensor.type, tensor.data as ort.Tensor.DataType, tensor.dims); } @@ -138,8 +134,8 @@ async function loadTensors( async function initializeSession( modelFilePath: string, backendHint: ort.InferenceSession.ExecutionProviderConfig, ioBindingMode: Test.IOBindingMode, - profile: boolean, sessionOptions: ort.InferenceSession.SessionOptions, - fileCache?: FileCacheBuffer): Promise { + profile: boolean, externalData: ort.InferenceSession.SessionOptions['externalData'], + sessionOptions: ort.InferenceSession.SessionOptions, fileCache?: FileCacheBuffer): Promise { const preloadModelData: Uint8Array|undefined = fileCache && fileCache[modelFilePath] ? fileCache[modelFilePath] : undefined; Logger.verbose( @@ -153,7 +149,8 @@ async function initializeSession( executionProviders: [backendHint], profiler: profilerConfig, enableProfiling: profile, - preferredOutputLocation: ioBindingMode === 'gpu-location' ? ('gpu-buffer' as const) : undefined + preferredOutputLocation: ioBindingMode === 'gpu-location' ? ('gpu-buffer' as const) : undefined, + externalData }; let session: ort.InferenceSession; @@ -246,8 +243,8 @@ export class ModelTestContext { const executionProviderConfig = modelTest.backend === 'webnn' ? 
(testOptions?.webnnOptions || 'webnn') : modelTest.backend!; const session = await initializeSession( - modelTest.modelUrl, executionProviderConfig, modelTest.ioBinding, profile, testOptions?.sessionOptions || {}, - this.cache); + modelTest.modelUrl, executionProviderConfig, modelTest.ioBinding, profile, modelTest.externalData, + testOptions?.sessionOptions || {}, this.cache); const initEnd = now(); @@ -329,6 +326,10 @@ export class TensorResultValidator { } checkTensorResult(actual: Tensor[], expected: Tensor[]): void { + this.checkApiTensorResult(actual.map(fromInternalTensor), expected.map(fromInternalTensor)); + } + + checkApiTensorResult(actual: ort.Tensor[], expected: ort.Tensor[]): void { // check output size expect(actual.length, 'size of output tensors').to.equal(expected.length); @@ -346,10 +347,6 @@ export class TensorResultValidator { } } - checkApiTensorResult(actual: ort.Tensor[], expected: ort.Tensor[]): void { - this.checkTensorResult(actual.map(toInternalTensor), expected.map(toInternalTensor)); - } - checkNamedTensorResult(actual: Record, expected: Test.NamedTensor[]): void { // check output size expect(Object.getOwnPropertyNames(actual).length, 'size of output tensors').to.equal(expected.length); @@ -363,7 +360,7 @@ export class TensorResultValidator { } // This function check whether 2 tensors should be considered as 'match' or not - areEqual(actual: Tensor, expected: Tensor): boolean { + areEqual(actual: ort.Tensor, expected: ort.Tensor): boolean { if (!actual || !expected) { return false; } @@ -391,13 +388,13 @@ export class TensorResultValidator { switch (actualType) { case 'string': - return this.strictEqual(actual.stringData, expected.stringData); + return this.strictEqual(actual.data, expected.data); case 'float32': case 'float64': return this.floatEqual( - actual.numberData as number[] | Float32Array | Float64Array, - expected.numberData as number[] | Float32Array | Float64Array); + actual.data as number[] | Float32Array | Float64Array, + 
expected.data as number[] | Float32Array | Float64Array); case 'uint8': case 'int8': @@ -408,10 +405,8 @@ export class TensorResultValidator { case 'int64': case 'bool': return TensorResultValidator.integerEqual( - actual.numberData as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array | - Int32Array, - expected.numberData as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array | - Int32Array); + actual.data as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array | Int32Array, + expected.data as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array | Int32Array); default: throw new Error('type not implemented or not supported'); @@ -578,7 +573,9 @@ export async function sessionRun(options: { // replace the CPU tensors in feeds into GPU tensors for (const name in feeds) { if (Object.hasOwnProperty.call(feeds, name)) { - feeds[name] = createGpuTensorForInput(feeds[name]); + if (feeds[name].size > 0) { + feeds[name] = createGpuTensorForInput(feeds[name]); + } } } } @@ -587,7 +584,11 @@ export async function sessionRun(options: { for (const name in options.outputsMetaInfo) { if (Object.hasOwnProperty.call(options.outputsMetaInfo, name)) { const {type, dims} = options.outputsMetaInfo[name]; - fetches[name] = createGpuTensorForOutput(type, dims); + if (dims.some(d => d === 0)) { + fetches[name] = new ort.Tensor(type, [], dims); + } else { + fetches[name] = createGpuTensorForOutput(type, dims); + } } } } @@ -632,8 +633,8 @@ export async function runModelTestSet( try { const feeds: Record = {}; const outputsMetaInfo: Record = {}; - testCase.inputs!.forEach((tensor, i) => feeds[context.session.inputNames[i]] = tensor); - testCase.outputs!.forEach((tensor, i) => outputsMetaInfo[context.session.outputNames[i]] = tensor); + testCase.inputs!.forEach((tensor) => feeds[tensor.name] = tensor); + testCase.outputs!.forEach((tensor) => outputsMetaInfo[tensor.name] = tensor); const [start, end, 
outputs] = await sessionRun({session: context.session, feeds, outputsMetaInfo, ioBinding: context.ioBinding}); if (context.perfData.count === 0) { diff --git a/js/web/test/test-types.ts b/js/web/test/test-types.ts index cd008e82e570b..14b9fd7c005ab 100644 --- a/js/web/test/test-types.ts +++ b/js/web/test/test-types.ts @@ -65,6 +65,7 @@ export declare namespace Test { export interface ModelTest { name: string; modelUrl: string; + externalData?: InferenceSession.SessionOptions['externalData']; backend?: string; // value should be populated at build time ioBinding: IOBindingMode; platformCondition?: PlatformCondition; diff --git a/js/web/test/unittests/backends/webgl/test-conv-new.ts b/js/web/test/unittests/backends/webgl/test-conv-new.ts index 8c186b9b36451..014fc57f21558 100644 --- a/js/web/test/unittests/backends/webgl/test-conv-new.ts +++ b/js/web/test/unittests/backends/webgl/test-conv-new.ts @@ -893,7 +893,9 @@ describe('New Conv tests', () => { const expected = cpuConv( inputTensor, kernelTensor, biasTensor, testData.autoPad, testData.dilations, testData.pads, testData.strides); - if (!validator.areEqual(actual, expected)) { + try { + validator.checkTensorResult([actual], [expected]); + } catch { console.log(actual.dims, `[${actual.numberData.slice(0, 20).join(',')},...]`); console.log(expected.dims, `[${expected.numberData.slice(0, 20).join(',')},...]`); throw new Error('Expected and Actual did not match'); diff --git a/objectivec/include/ort_coreml_execution_provider.h b/objectivec/include/ort_coreml_execution_provider.h index a015b6fd60c8f..6ff18176ebeb2 100644 --- a/objectivec/include/ort_coreml_execution_provider.h +++ b/objectivec/include/ort_coreml_execution_provider.h @@ -41,6 +41,17 @@ NS_ASSUME_NONNULL_BEGIN */ @property BOOL onlyEnableForDevicesWithANE; +/** + * Only allow CoreML EP to take nodes with inputs with static shapes. By default it will also allow inputs with + * dynamic shapes. 
However, the performance may be negatively impacted if inputs have dynamic shapes. + */ +@property BOOL onlyAllowStaticInputShapes; + +/** + * Create an MLProgram. By default it will create a NeuralNetwork model. Requires Core ML 5 or later. + */ +@property BOOL createMLProgram; + @end @interface ORTSessionOptions (ORTSessionOptionsCoreMLEP) diff --git a/objectivec/ort_coreml_execution_provider.mm b/objectivec/ort_coreml_execution_provider.mm index 6340fdea1c3a7..58b47d68eea63 100644 --- a/objectivec/ort_coreml_execution_provider.mm +++ b/objectivec/ort_coreml_execution_provider.mm @@ -26,7 +26,10 @@ - (BOOL)appendCoreMLExecutionProviderWithOptions:(ORTCoreMLExecutionProviderOpti const uint32_t flags = (options.useCPUOnly ? COREML_FLAG_USE_CPU_ONLY : 0) | (options.enableOnSubgraphs ? COREML_FLAG_ENABLE_ON_SUBGRAPH : 0) | - (options.onlyEnableForDevicesWithANE ? COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE : 0); + (options.onlyEnableForDevicesWithANE ? COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE : 0) | + (options.onlyAllowStaticInputShapes ? COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES : 0) | + (options.createMLProgram ? 
COREML_FLAG_CREATE_MLPROGRAM : 0); + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML( [self CXXAPIOrtSessionOptions], flags)); return YES; diff --git a/objectivec/ort_value.mm b/objectivec/ort_value.mm index b9dc1a9885c61..c61a7ea809237 100644 --- a/objectivec/ort_value.mm +++ b/objectivec/ort_value.mm @@ -148,6 +148,9 @@ - (nullable ORTValueTypeInfo*)typeInfoWithError:(NSError**)error { - (nullable ORTTensorTypeAndShapeInfo*)tensorTypeAndShapeInfoWithError:(NSError**)error { try { const auto tensorTypeAndShapeInfo = _typeInfo->GetTensorTypeAndShapeInfo(); + if (!tensorTypeAndShapeInfo) { + ORT_CXX_API_THROW("ORTValue is not a tensor.", ORT_RUNTIME_EXCEPTION); + } return CXXAPIToPublicTensorTypeAndShapeInfo(tensorTypeAndShapeInfo); } ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error) @@ -156,6 +159,9 @@ - (nullable ORTTensorTypeAndShapeInfo*)tensorTypeAndShapeInfoWithError:(NSError* - (nullable NSMutableData*)tensorDataWithError:(NSError**)error { try { const auto tensorTypeAndShapeInfo = _typeInfo->GetTensorTypeAndShapeInfo(); + if (!tensorTypeAndShapeInfo) { + ORT_CXX_API_THROW("ORTValue is not a tensor.", ORT_RUNTIME_EXCEPTION); + } if (tensorTypeAndShapeInfo.GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) { ORT_CXX_API_THROW( "This ORTValue holds string data. 
Please call tensorStringDataWithError: " @@ -182,6 +188,9 @@ - (nullable NSMutableData*)tensorDataWithError:(NSError**)error { - (nullable NSArray*)tensorStringDataWithError:(NSError**)error { try { const auto tensorTypeAndShapeInfo = _typeInfo->GetTensorTypeAndShapeInfo(); + if (!tensorTypeAndShapeInfo) { + ORT_CXX_API_THROW("ORTValue is not a tensor.", ORT_RUNTIME_EXCEPTION); + } const size_t elementCount = tensorTypeAndShapeInfo.GetElementCount(); const size_t tensorStringDataLength = _value->GetStringTensorDataLength(); std::vector tensorStringData(tensorStringDataLength, '\0'); diff --git a/onnxruntime/contrib_ops/cpu/activations.cc b/onnxruntime/contrib_ops/cpu/activations.cc index 556699192d2eb..3e0533dd8b9e5 100644 --- a/onnxruntime/contrib_ops/cpu/activations.cc +++ b/onnxruntime/contrib_ops/cpu/activations.cc @@ -2,7 +2,7 @@ // Licensed under the MIT License. #include "core/providers/cpu/activation/activations.h" -#include "activations.h" +#include "contrib_ops/cpu/activations.h" namespace onnxruntime { namespace contrib { @@ -26,14 +26,6 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL( KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType()), ThresholdedRelu); -ONNX_OPERATOR_KERNEL_EX( - Gelu, - kMSDomain, - 1, - kCpuExecutionProvider, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Gelu); - ONNX_OPERATOR_KERNEL_EX( QuickGelu, kMSDomain, diff --git a/onnxruntime/contrib_ops/cpu/activations.h b/onnxruntime/contrib_ops/cpu/activations.h index aed4c2229215d..7e64235d3fc3d 100644 --- a/onnxruntime/contrib_ops/cpu/activations.h +++ b/onnxruntime/contrib_ops/cpu/activations.h @@ -54,47 +54,6 @@ namespace contrib { DEFINE_ELE_KERNEL(ScaledTanh); DEFINE_ELE_KERNEL(ParametricSoftplus); -template -class Gelu : public OpKernel { - public: - Gelu(const OpKernelInfo& info) : OpKernel(info) { - } - - Status Compute(OpKernelContext* context) const override { - const Tensor* input = context->Input(0); - const T* input_data 
= input->Data(); - - Tensor* output = context->Output(0, input->Shape()); - T* output_data = output->MutableData(); - - concurrency::ThreadPool* tp = context->GetOperatorThreadPool(); - int64_t elem_count = input->Shape().Size(); - constexpr int64_t length_per_task = 4096; // this number comes from FastGelu. - int64_t task_count = (elem_count + length_per_task - 1) / length_per_task; - concurrency::ThreadPool::TryBatchParallelFor( - tp, static_cast(task_count), - [&](ptrdiff_t task_idx) { - const auto start = task_idx * length_per_task; - const T* p_input = input_data + start; - T* p_output = output_data + start; - int64_t count = std::min(length_per_task, elem_count - start); - - for (int64_t i = 0; i < count; i++) { - T value = p_input[i]; - p_output[i] = value * static_cast(M_SQRT1_2); - } - - MlasComputeErf(p_output, p_output, narrow(count)); - - for (int64_t i = 0; i < count; i++) { - p_output[i] = 0.5f * p_input[i] * (p_output[i] + 1.0f); - } - }, - 0); - return Status::OK(); - } -}; - // Implement a new one instead of inheriting from ElementWiseRangedTransform so that we can call // MlasComputeLogistic instead of using Eigen for better perf. 
template diff --git a/onnxruntime/contrib_ops/cpu/aten_ops/aten_op_executor.h b/onnxruntime/contrib_ops/cpu/aten_ops/aten_op_executor.h index d72868cd8fa9f..56c8e2911e280 100644 --- a/onnxruntime/contrib_ops/cpu/aten_ops/aten_op_executor.h +++ b/onnxruntime/contrib_ops/cpu/aten_ops/aten_op_executor.h @@ -10,7 +10,7 @@ namespace onnxruntime { namespace contrib { namespace aten_ops { -typedef bool (*IsCpuArgumentFunc)(const char* op_name, const char* overload_name, size_t index, bool is_input); +typedef bool (*IsTensorArgumentFunc)(const char* op_name, const char* overload_name, size_t index, bool is_input); typedef void (*ExecuteATenOperatorFunc)(const char* op_name, const char* overload_name, size_t input_size, DLManagedTensor** dlpack_inputs, size_t output_size, DLManagedTensor** dlpack_outputs); @@ -22,17 +22,17 @@ class ATenOperatorExecutor { return instance; } - void Initialize(void* p_is_cpu_argument_func_raw, void* p_execute_aten_op_func_raw) { - ORT_ENFORCE(p_is_cpu_argument_func_raw && p_execute_aten_op_func_raw); - p_is_cpu_argument_func_ = reinterpret_cast(p_is_cpu_argument_func_raw); + void Initialize(void* p_is_tensor_argument_func_raw, void* p_execute_aten_op_func_raw) { + ORT_ENFORCE(p_is_tensor_argument_func_raw && p_execute_aten_op_func_raw); + p_is_tensor_argument_func_ = reinterpret_cast(p_is_tensor_argument_func_raw); p_execute_aten_op_func_ = reinterpret_cast(p_execute_aten_op_func_raw); } bool IsInitialized() { return p_execute_aten_op_func_ != nullptr; } - bool IsCpuArgument(const std::string& op_name, const std::string& overload_name, size_t index, bool is_input) { - ORT_ENFORCE(p_is_cpu_argument_func_, "ATenOperatorExecutor is not initialized."); - return p_is_cpu_argument_func_(op_name.c_str(), overload_name.c_str(), index, is_input); + bool IsTensorArgument(const std::string& op_name, const std::string& overload_name, size_t index, bool is_input) { + ORT_ENFORCE(p_is_tensor_argument_func_, "ATenOperatorExecutor is not initialized."); + 
return p_is_tensor_argument_func_(op_name.c_str(), overload_name.c_str(), index, is_input); } void operator()(const std::string& op_name, const std::string& overload_name, size_t input_size, @@ -43,7 +43,7 @@ class ATenOperatorExecutor { } private: - IsCpuArgumentFunc p_is_cpu_argument_func_ = nullptr; + IsTensorArgumentFunc p_is_tensor_argument_func_ = nullptr; ExecuteATenOperatorFunc p_execute_aten_op_func_ = nullptr; }; diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index 8afeb874750b4..5a0c3af05c9da 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -64,6 +64,7 @@ struct AttentionParameters { bool pass_past_in_kv; float mask_filter_value; float scale; + bool use_tf32; AttentionMaskType mask_type; AttentionQkvFormat qkv_format; }; @@ -82,6 +83,7 @@ struct PackedAttentionParameters { int token_count; bool has_relative_position_bias; bool broadcast_res_pos_bias; + bool use_tf32; }; // Parameters deduced from node attributes and inputs/outputs. 
@@ -96,6 +98,7 @@ struct GroupQueryAttentionParameters { int kv_hidden_size; int kv_num_heads; int num_splits; // number of splits for splitkv + int rotary_dim; // rotary embedding dimension bool is_unidirectional; // causal int local_window_size; bool kv_share_buffer; diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h index b761b1afd8529..c617533319a18 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h @@ -140,17 +140,6 @@ class AttentionCPUBase : public AttentionBase { if (mask_data != nullptr) { PrepareMask(mask_index, mask_index_dims, mask_data, causal, batch_size, sequence_length, past_sequence_length, mask_filter_value_); - } else { // no any mask - const int memset_loop_len = batch_size * num_heads_; - const double memset_cost = static_cast(sequence_length) * total_sequence_length; - - ThreadPool::TryParallelFor(tp, memset_loop_len, memset_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { - for (std::ptrdiff_t i = begin; i != end; ++i) { - const int output_offset = static_cast(i) * sequence_length * total_sequence_length; - T* output = attention_probs + output_offset; - memset(output, 0, static_cast(sequence_length) * total_sequence_length * sizeof(T)); - } - }); } const int loop_len = batch_size * num_heads_; @@ -188,7 +177,7 @@ class AttentionCPUBase : public AttentionBase { // B: K' (B x N x) T x H (B x N x) H x T H x T // C: attention_probs (B x N x) S x T (B x N x) S x T S x T math::Gemm(CblasNoTrans, CblasTrans, sequence_length, total_sequence_length, head_size, alpha, - Q + q_input_chunk_length * i, k, 1.0, + Q + q_input_chunk_length * i, k, mask_data != nullptr ? 
1.0f : 0.0f, output, nullptr); if (relative_position_bias_data != nullptr) { diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc index eb25d0fd7cc1e..c4e4b4ec707fb 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc @@ -58,11 +58,12 @@ Status Reshape_BSD_to_BSNH(Tensor* qkv, // Transpose Q/K/V from BxSxNxH to BxNxSxH Status Transpose_BSNH_to_BNSH(const Tensor* qkv, - OrtValue& qkv_transposed) { + OrtValue& qkv_transposed, + concurrency::ThreadPool* tp = nullptr) { std::vector permutations({0, 2, 1, 3}); gsl::span permutations_span{permutations}; size_t from = 2, to = 1; - SingleAxisTranspose(permutations_span, *qkv, *qkv_transposed.GetMutable(), from, to); + SingleAxisTranspose(permutations_span, *qkv, *qkv_transposed.GetMutable(), from, to, nullptr, tp); return Status::OK(); } @@ -143,7 +144,8 @@ Status AddBiasTranspose(const Tensor* qkv, // Input: Q/K/V dat ORT_RETURN_IF_ERROR(Reshape_BSD_to_BSNH(qkv_with_bias.GetMutable(), batch_size, sequence_length, num_heads, head_size)); // Transpose Q from BxSxNxH to BxNxSxH - ORT_RETURN_IF_ERROR(Transpose_BSNH_to_BNSH(qkv_with_bias.GetMutable(), qkv_with_bias_transposed)); + auto tp = context->GetOperatorThreadPool(); + ORT_RETURN_IF_ERROR(Transpose_BSNH_to_BNSH(qkv_with_bias.GetMutable(), qkv_with_bias_transposed, tp)); return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 72948c74d7877..602dd98d8c0d6 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -1,6 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
+#include "contrib_ops/cpu/quantization/matmul_nbits_impl.h" + +#include +#include + +#include "core/common/common.h" #include "core/common/narrow.h" #include "core/common/safeint.h" #include "core/framework/op_kernel.h" @@ -9,6 +15,7 @@ #include "core/mlas/inc/mlas_q4.h" #include "core/providers/cpu/math/matmul_helper.h" #include "core/providers/common.h" + #ifdef ORT_NEURAL_SPEED #include "contrib_ops/cpu/quantization/neural_speed_gemm.h" #endif @@ -16,6 +23,50 @@ namespace onnxruntime { namespace contrib { +namespace { +int64_t GetAccuracyLevel(size_t nbits, size_t block_size, int64_t accuracy_level_attr) { + const auto accuracy_level = std::clamp(accuracy_level_attr, + static_cast(CompMostAccurate), + static_cast(CompLeastAccurate)); + +#if defined(ORT_NEURAL_SPEED) + + ORT_UNUSED_PARAMETER(nbits); + ORT_UNUSED_PARAMETER(block_size); + + // Neural Speed APIs already expect a minimum accuracy level so just use the given value. + return accuracy_level; + +#else // defined(ORT_NEURAL_SPEED) + + // Find a supported accuracy level that is not less accurate than the one given. + // CompMostAccurate is always supported with the fallback implementation. + // Note: A higher numeric accuracy level value means lower accuracy, so the comparison order is reversed. 
+ int64_t effective_accuracy_level = accuracy_level; + for (; effective_accuracy_level > CompMostAccurate; --effective_accuracy_level) { + const auto compute_type = static_cast(effective_accuracy_level); + if (MlasIsSQNBitGemmAvailable(nbits, block_size, compute_type)) { + break; + } + } + + return effective_accuracy_level; + +#endif // defined(ORT_NEURAL_SPEED) +} +} // namespace + +bool GetType(const NodeArg& node_arg, int32_t& type) { + type = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED; + const auto* type_proto = node_arg.TypeAsProto(); + if (!type_proto || !type_proto->has_tensor_type() || !type_proto->tensor_type().has_elem_type()) { + return false; + } + + type = type_proto->tensor_type().elem_type(); + return true; +} + class MatMulNBits final : public OpKernel { public: MatMulNBits(const OpKernelInfo& info) @@ -24,7 +75,18 @@ class MatMulNBits final : public OpKernel { N_{narrow(info.GetAttr("N"))}, block_size_{narrow(info.GetAttr("block_size"))}, nbits_{narrow(info.GetAttr("bits"))}, - accuracy_level_{info.GetAttr("accuracy_level")} { + accuracy_level_{GetAccuracyLevel(nbits_, block_size_, info.GetAttr("accuracy_level"))} { + const auto& node = info.node(); + auto input_defs = node.InputDefs(); + // g_idx + if (input_defs.size() > 4) { + act_order_ = true; + } + int32_t type; + if (input_defs.size() > 3 && GetType(*input_defs[3], type)) { + zero_point_is_not_quant_ = type != ONNX_NAMESPACE::TensorProto_DataType_UINT8; + } + ORT_ENFORCE(nbits_ == 4, "Only 4b quantization is supported for MatMulNBits op, additional bits support is planned."); #ifdef ORT_NEURAL_SPEED @@ -54,21 +116,30 @@ class MatMulNBits final : public OpKernel { const size_t N_; const size_t block_size_; const size_t nbits_; + bool act_order_{false}; + bool zero_point_is_not_quant_{false}; const int64_t accuracy_level_; const bool column_wise_quant_{true}; IAllocatorUniquePtr packed_b_; size_t packed_b_size_{0}; -#ifdef ORT_NEURAL_SPEED + +#if defined(ORT_NEURAL_SPEED) + bool 
is_asym_{false}; bool all_constant_{false}; -#endif + +#endif // defined(ORT_NEURAL_SPEED) }; Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { is_packed = false; -#ifdef ORT_NEURAL_SPEED + if (act_order_ || zero_point_is_not_quant_) { + return Status::OK(); + } +#if defined(ORT_NEURAL_SPEED) + if (!all_constant_) { return Status::OK(); } @@ -116,11 +187,17 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat #else // defined(ORT_NEURAL_SPEED) if (input_idx == 1) { - packed_b_size_ = MlasSQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_); - if (packed_b_size_ == 0) return Status::OK(); + const auto compute_type = static_cast(accuracy_level_); + if (!MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type)) { + return Status::OK(); + } + packed_b_size_ = MlasSQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_, compute_type); + if (packed_b_size_ == 0) { + return Status::OK(); + } auto qptr = tensor.DataRaw(); packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size_, true); - MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, qptr, packed_b_.get()); + MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, qptr, packed_b_.get()); if (prepacked_weights) { prepacked_weights->buffers_.push_back(std::move(packed_b_)); prepacked_weights->buffer_sizes_.push_back(packed_b_size_); @@ -136,7 +213,9 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, int input_idx, /*out*/ bool& used_shared_buffers) { used_shared_buffers = false; -#ifdef ORT_NEURAL_SPEED + +#if defined(ORT_NEURAL_SPEED) + // Pack three tensors into one buffer if (input_idx == 1) { used_shared_buffers = true; @@ -159,16 +238,18 @@ Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& prep } #endif // 
defined(ORT_NEURAL_SPEED) + return Status::OK(); } Status MatMulNBits::Compute(OpKernelContext* ctx) const { concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); - const Tensor* a = ctx->Input(0); const auto* a_data = a->Data(); -#ifdef ORT_NEURAL_SPEED - if (packed_b_.get()) { + +#if defined(ORT_NEURAL_SPEED) + + if (packed_b_) { TensorShape b_shape({static_cast(N_), static_cast(K_)}); MatMulComputeHelper helper; @@ -207,11 +288,14 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { #endif // defined(ORT_NEURAL_SPEED) const Tensor* scales = ctx->Input(2); - const Tensor* zero_points = ctx->Input(3); + const Tensor* zero_points = ctx->InputCount() > 3 ? ctx->Input(3) : nullptr; + const Tensor* reorder_idx = ctx->InputCount() > 4 ? ctx->Input(4) : nullptr; + const auto* scales_data = scales->Data(); - const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->Data(); + const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw(); TensorShape b_shape({static_cast(N_), static_cast(K_)}); + const auto* reorder_idx_data = reorder_idx == nullptr ? 
nullptr : reorder_idx->Data(); MatMulComputeHelper helper; ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true)); @@ -231,40 +315,47 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { const size_t K = static_cast(helper.K()); const size_t lda = helper.Lda(false); - const bool has_single_b_matrix = std::all_of(helper.RightOffsets().begin(), helper.RightOffsets().end(), - [](size_t offset) { return offset == 0; }); - - if (has_single_b_matrix && packed_b_) { - for (int64_t accuracy_level = accuracy_level_; - accuracy_level >= static_cast(CompMostAccurate); - --accuracy_level) { - const auto compute_type = static_cast(accuracy_level); - if (MlasIsSQNBitGemmAvailable(M, N, K, nbits_, block_size_, compute_type)) { - IAllocatorUniquePtr workspace{}; - if (const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, batch_count, - nbits_, block_size_, compute_type); - workspace_size > 0) { - AllocatorPtr allocator; - ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator)); - workspace = IAllocator::MakeUniquePtr(allocator, workspace_size); - } + const bool has_single_b_matrix = + (!act_order_) && (!zero_point_is_not_quant_) && + std::all_of(helper.RightOffsets().begin(), helper.RightOffsets().end(), [](size_t offset) { return offset == 0; }); + + if (has_single_b_matrix) { + const auto compute_type = static_cast(accuracy_level_); + + if (MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type)) { + IAllocatorUniquePtr workspace{}; + if (const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, batch_count, + nbits_, block_size_, compute_type); + workspace_size > 0) { + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator)); + workspace = IAllocator::MakeUniquePtr(allocator, workspace_size); + } - InlinedVector data(batch_count); - for (size_t i = 0; i < batch_count; ++i) { - data[i].A = a_data + helper.LeftOffsets()[i]; - data[i].lda = lda; - data[i].QuantBData = packed_b_.get(); - 
data[i].QuantBScale = scales_data; - data[i].QuantBZeroPoint = zero_points_data; - data[i].C = y_data + helper.OutputOffsets()[i]; - data[i].ldc = N; + const void* b_data = [&]() -> const void* { + if (packed_b_) { + return packed_b_.get(); } - MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type, data.data(), workspace.get(), - thread_pool); - - return Status::OK(); + const Tensor* b = ctx->Input(1); + return b->DataRaw(); + }(); + + InlinedVector data(batch_count); + for (size_t i = 0; i < batch_count; ++i) { + data[i].A = a_data + helper.LeftOffsets()[i]; + data[i].lda = lda; + data[i].QuantBData = b_data; + data[i].QuantBScale = scales_data; + data[i].QuantBZeroPoint = zero_points_data; + data[i].C = y_data + helper.OutputOffsets()[i]; + data[i].ldc = N; } + + MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type, data.data(), workspace.get(), + thread_pool); + + return Status::OK(); } } @@ -272,22 +363,50 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { const uint8_t* b_data = b->Data(); const size_t ldb = helper.Ldb(true); - AllocatorPtr allocator; ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator)); auto tmp_b_data_ptr = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_); - // dequantize b, only 4b quantization is supported for now - MlasDequantizeBlockwise( - tmp_b_data_ptr.get(), // dequantized output - b_data, // quantized input - scales_data, // quantization scales - zero_points_data, // quantization zero points - static_cast(block_size_), // quantization block size - column_wise_quant_, // columnwise quantization or row-wise - static_cast(K_), // number of rows in quantized input - static_cast(N_), // number of columns in quantized input - thread_pool); - + if ((reorder_idx_data == nullptr) && (!zero_points || !zero_points->IsDataType())) { + // dequantize b, only 4b quantization is supported for now + MlasDequantizeBlockwise( + tmp_b_data_ptr.get(), // dequantized output + b_data, // 
quantized input + scales_data, // quantization scales + static_cast(zero_points_data), // quantization zero points + static_cast(block_size_), // quantization block size + column_wise_quant_, // columnwise quantization or row-wise + static_cast(K_), // number of rows in quantized input + static_cast(N_), // number of columns in quantized input + thread_pool); + } else { + ORT_ENFORCE(column_wise_quant_, "Row-wise quantization is not supported for now"); + // !!!!!!!!!!!!!! naive implementation, need to be optimized !!!!!!!!!!!!!! + if ((zero_points && zero_points->IsDataType())) { + DequantizeBlockwise( + tmp_b_data_ptr.get(), // dequantized output + b_data, // quantized input + scales_data, // quantization scales + static_cast(zero_points_data), // quantization zero points + reorder_idx_data, + static_cast(block_size_), // quantization block size + column_wise_quant_, // columnwise quantization or row-wise + static_cast(K_), // number of rows in quantized input + static_cast(N_), // number of columns in quantized input + thread_pool); + } else { + DequantizeBlockwise( + tmp_b_data_ptr.get(), // dequantized output + b_data, // quantized input + scales_data, // quantization scales + static_cast(zero_points_data), // quantization zero points + reorder_idx_data, + static_cast(block_size_), // quantization block size + column_wise_quant_, // columnwise quantization or row-wise + static_cast(K_), // number of rows in quantized input + static_cast(N_), // number of columns in quantized input + thread_pool); + } + } #if 0 // for debug auto tm_b_data_ptr_trans = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_); MlasTranspose(tmp_b_data_ptr.get(), tm_b_data_ptr_trans.get(), N_, K_); @@ -318,7 +437,9 @@ ONNX_OPERATOR_KERNEL_EX( kCpuExecutionProvider, KernelDefBuilder() .TypeConstraint("T1", DataTypeImpl::GetTensorType()) - .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) + .TypeConstraint("T3", 
{DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) + .TypeConstraint("T4", DataTypeImpl::GetTensorType()), MatMulNBits); } // namespace contrib diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc new file mode 100644 index 0000000000000..7e343d85f4048 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc @@ -0,0 +1,109 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#include "contrib_ops/cpu/quantization/matmul_nbits_impl.h" + +#include +#include +#include +#include +#include + +#include "core/common/common.h" +#include "core/framework/float16.h" +#include "core/providers/common.h" +#include "core/platform/threadpool.h" + +namespace onnxruntime { +namespace contrib { + +template +void Dequantize4BitsKernelReOrder( + T* output, const uint8_t* quant_data, const T* scale_data, + const zeroT* zero_points, const int32_t* reorder_idx, int block_size, + int groups_per_threadblock, int total_groups, int out_rows, int out_cols, + int blockIdx_x, int threadIdx_x) { + const int group_id = blockIdx_x * groups_per_threadblock + ((threadIdx_x * 8) / block_size); + if (group_id >= total_groups) { + return; + } + const int scales_shape_x = (out_cols + block_size - 1) / block_size; + const int zero_point_shape_x = (scales_shape_x + 1) / 2; + + int n_idx = group_id / scales_shape_x; + int kb_idx = group_id % scales_shape_x; + int element_offset = group_id * block_size + ((threadIdx_x * 8) & (block_size - 1)); + + const int out_x = element_offset % (scales_shape_x * block_size); + const int out_y = element_offset / (scales_shape_x * block_size); + if (out_y >= out_rows || out_x >= out_cols) { + return; + } + T* output_i = output + out_y * out_cols + out_x; + uint32_t quant_value = *(reinterpret_cast(quant_data + element_offset / 2)); + const int remain_x = std::min(8, out_cols - out_x); + const int32_t* 
reorder_idx_with_off = reorder_idx + kb_idx * block_size + ((threadIdx_x * 8) & (block_size - 1)); + for (int i = 0; i < remain_x; i++) { + int32_t rid = reorder_idx ? reorder_idx_with_off[i] : kb_idx; + T scale = *(scale_data + n_idx * scales_shape_x + rid); + float zp_f = 8; + if (zero_points) { + if constexpr (std::is_same_v) { + zp_f = *(zero_points + n_idx * scales_shape_x + rid); + } else { + uint8_t zp = 8; + zp = zero_points[n_idx * zero_point_shape_x + rid / 2]; + zp = (rid & 0x01) ? (zp >> 4) : (zp & 0x0f); + } + } + + if constexpr (std::is_same_v) { + T zp_adjust = -scale * MLFloat16(zp_f); + output_i[i] = static_cast((quant_value >> (4 * i)) & 0xF) * scale + zp_adjust; + } else { + T zp_adjust = -scale * zp_f; + output_i[i] = T((quant_value >> (4 * i)) & 0xF) * scale + zp_adjust; + } + } +} + +template +void DequantizeBlockwise( + inputT* output, // dequantized output + const uint8_t* quant_data, // quantized input + const inputT* scales_data, // quantization scales + const zeroT* zero_points, // quantization zero points + const int32_t* reorder_idx, // reorder_idx for groupwise quantization + int32_t block_size, // quantization block size + bool, // columnwise quantization or row-wise + int32_t K, // number of rows in quantized input + int32_t N, // number of columns in quantized input + onnxruntime::concurrency::ThreadPool* pool) { + auto ceildiv = [](int a, int b) { return (a + b - 1) / b; }; + constexpr int element_per_thread = 8; + int groups_per_threadblock = 256 * element_per_thread / block_size; + int groups_per_K = ceildiv(K, block_size); + int total_groups = N * groups_per_K; // total elemenets in quant_data + int blocks_per_grid = static_cast(ceildiv(total_groups, groups_per_threadblock)); + concurrency::ThreadPool::TrySimpleParallelFor( + pool, static_cast(blocks_per_grid), + [&](std::ptrdiff_t block_id) { + for (int j = 0; j < 256; j++) { + Dequantize4BitsKernelReOrder(output, quant_data, scales_data, zero_points, + reorder_idx, block_size, 
groups_per_threadblock, + total_groups, N, K, static_cast(block_id), j); + } + }); +} + +template void DequantizeBlockwise( + float* output, const uint8_t* quant_data, const float* scales_data, + const uint8_t* zero_points, const int32_t* reorder_idx, int32_t block_size, + bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool); + +template void DequantizeBlockwise( + float* output, const uint8_t* quant_data, const float* scales_data, + const float* zero_points, const int32_t* reorder_idx, int32_t block_size, + bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool); + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h new file mode 100644 index 0000000000000..5061ac5c800a6 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+#include "core/providers/common.h" +#include "core/platform/threadpool.h" + +namespace onnxruntime { +namespace contrib { + +template +void DequantizeBlockwise( + inputT* output, // dequantized output + const uint8_t* quant_data, // quantized input + const inputT* scales_data, // quantization scales + const zeroT* zero_points, // quantization zero points + const int32_t* reorder_idx, // quantization zero points + int32_t block_size, // quantization block size + bool, // columnwise quantization or row-wise + int32_t K, // number of rows in quantized input + int32_t N, // number of columns in quantized input + onnxruntime::concurrency::ThreadPool* thread_pool); + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h index dc72a038c3d58..b18e122980eda 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h @@ -258,7 +258,7 @@ Status BeamSearchGpt::Execute(const FeedsFetchesManager* init_run_feeds_fetch cpu_state.sequences.InitDevice(beam_state.sequences_device); ORT_RETURN_IF_ERROR(this->device_copy_int32_func_(beam_state.sequences_device.subspan(0, beam_state.sequences_device.size() / 2), cpu_state.sequences_space.subspan(0, cpu_state.sequences_space.size() / 2), - nullptr, + this->ort_stream_, DeviceCopyDirection::hostToDevice)); } diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h index cd891a9508019..8f5cdc97f27e5 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h @@ -214,7 +214,7 @@ Status BeamSearchT5::Execute(const FeedsFetchesManager& encoder_feeds_fetches cpu_state.sequences.InitDevice(beam_state.sequences_device); 
ORT_RETURN_IF_ERROR(this->device_copy_int32_func_(beam_state.sequences_device.subspan(0, beam_state.sequences_device.size() / 2), cpu_state.sequences_space.subspan(0, cpu_state.sequences_space.size() / 2), - nullptr, + this->ort_stream_, DeviceCopyDirection::hostToDevice)); } diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h index 4d6643c68a98b..af0904b7d6e4b 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h @@ -134,8 +134,8 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe TensorShape no_speech_probs_shape{parameters->batch_size}; Tensor* no_speech_probs = this->context_.Output(parameters->no_speech_probs_output_id, no_speech_probs_shape); if (no_speech_probs && no_speech_probs->MutableData()) { - ORT_ENFORCE(parameters->no_speech_token >= 0 && parameters->no_speech_token < parameters->vocab_size, - "no_speech_token id out of range, it is ", parameters->no_speech_token, + ORT_ENFORCE(parameters->no_speech_token_id >= 0 && parameters->no_speech_token_id < parameters->vocab_size, + "no_speech_token_id is out of range, it is ", parameters->no_speech_token_id, ", vocab_size is ", parameters->vocab_size); this->parameters_->no_speech_probs = (void*)no_speech_probs->MutableData(); } @@ -226,7 +226,7 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe cpu_state.sequences.InitDevice(beam_state.sequences_device); ORT_RETURN_IF_ERROR(this->device_copy_int32_func_(beam_state.sequences_device.subspan(0, beam_state.sequences_device.size() / 2), cpu_state.sequences_space.subspan(0, cpu_state.sequences_space.size() / 2), - nullptr, + this->ort_stream_, DeviceCopyDirection::hostToDevice)); } diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc 
b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc index bb6885c3216bc..93837e785b4a4 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc @@ -153,7 +153,13 @@ void WhisperBeamSearchParameters::ParseFromAttributes(const OpKernelInfo& info) model_type = static_cast(info.GetAttrOrDefault("model_type", IGenerationParameters::kModelTypeWhisper)); ORT_ENFORCE(model_type == IGenerationParameters::kModelTypeWhisper); - no_speech_token = static_cast(info.GetAttrOrDefault("no_speech_token", -1LL)); + // Token ids are defined below in the order that they appear in the tokenizer + translate_token_id = static_cast(info.GetAttrOrDefault("translate_token_id", -1LL)); + transcribe_token_id = static_cast(info.GetAttrOrDefault("transcribe_token_id", -1LL)); + start_of_lm_token_id = static_cast(info.GetAttrOrDefault("start_of_lm_token_id", -1LL)); + no_speech_token_id = static_cast(info.GetAttrOrDefault("no_speech_token_id", -1LL)); + no_timestamps_token_id = static_cast(info.GetAttrOrDefault("no_timestamps_token_id", -1LL)); + beginning_timestamp_token_id = static_cast(info.GetAttrOrDefault("beginning_timestamp_token_id", -1LL)); cross_qk_layer_head_input_id = 12; extra_decoding_ids_input_id = 13; cross_qk_output_id = 3; diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h index cb62e2f7bf4da..b1dd55eb20f34 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h @@ -183,7 +183,14 @@ struct IGenerationParameters { // Parameters for whisper model bool decoder_output_cross_qk = false; gsl::span extra_decoding_ids; - int32_t no_speech_token = -1; + + // Token ids are defined below in the order that they appear in the tokenizer + int32_t translate_token_id = -1; + int32_t transcribe_token_id = -1; + 
int32_t start_of_lm_token_id = -1; + int32_t no_speech_token_id = -1; + int32_t no_timestamps_token_id = -1; + int32_t beginning_timestamp_token_id = -1; void* no_speech_probs = nullptr; int cross_qk_layer_head_input_id = -1; diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc index f39f090c78b0c..c74e9160cc43f 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc @@ -17,14 +17,6 @@ namespace onnxruntime { namespace contrib { namespace transformers { -#ifdef DEBUG_GENERATION -template -void DumpScores(const char* name, const NextTokenScores& next_token_scores) { - std::cout << name << std::endl; - ORT_UNUSED_PARAMETER(next_token_scores); -} -#endif - // Interface for all scorers for beam search or beam sample. template MinLengthLogitsProcessor::MinLengthLogitsProcessor(int min_length, int eos_token_id) @@ -36,10 +28,6 @@ void MinLengthLogitsProcessor::Process(const ISequences* sequences, if (sequences->GetSequenceLength() < min_length_) { next_token_scores.SetScore(eos_token_id_, std::numeric_limits::lowest()); } - -#ifdef DEBUG_GENERATION - DumpScores("MinLengthLogitsProcessor", next_token_scores); -#endif } template @@ -68,10 +56,6 @@ void RepetitionPenaltyLogitsProcessor::Process(const ISequences* sequences, beam_token_scores[word_id] = (score < 0 ? 
score * penalty_ : score / penalty_); } } - -#ifdef DEBUG_GENERATION - DumpScores("RepetitionPenaltyLogitsProcessor", next_token_scores); -#endif } template @@ -109,10 +93,6 @@ void NoRepeatNGramLogitsProcessor::Process(const ISequences* sequences, beam_token_scores[word_id] = std::numeric_limits::lowest(); } } - -#ifdef DEBUG_GENERATION - DumpScores("NoRepeatNGramLogitsProcessor", next_token_scores); -#endif } template @@ -136,10 +116,6 @@ void VocabMaskLogitsProcessor::Process(const ISequences* /*sequences*/, } } } - -#ifdef DEBUG_GENERATION - DumpScores("VocabMaskLogitsProcessor", next_token_scores); -#endif } template @@ -171,10 +147,6 @@ void PrefixVocabMaskLogitsProcessor::Process(const ISequences* /*sequences*/, } } } - -#ifdef DEBUG_GENERATION - DumpScores("PrefixVocabMaskLogitsProcessor", next_token_scores); -#endif } template @@ -193,10 +165,6 @@ void TemperatureLogitsProcessor::Process(const ISequences* /*sequences*/, *p /= temperature_; ++p; } - -#ifdef DEBUG_GENERATION - DumpScores("TemperatureLogitsProcessor", next_token_scores); -#endif } template @@ -218,10 +186,6 @@ void PresencePenaltyLogitsProcessor::Process(const ISequences*, for (size_t i = 0; i < next_token_scores.scores.size(); i++) { *p -= presence_mask_[i] * presence_penalty_; } - -#ifdef DEBUG_GENERATION - DumpScores("PresencePenaltyLogitsProcessor", next_token_scores); -#endif } void LogitsProcessorList::Init(const BeamSearchParameters& parameters) { diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h index 4688ff272cee9..231eb17d1a947 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h @@ -10,6 +10,7 @@ #include "contrib_ops/cpu/transformers/greedy_search_parameters.h" #include "contrib_ops/cpu/transformers/sampling_parameters.h" #include "contrib_ops/cpu/transformers/generation_shared.h" +#include namespace onnxruntime 
{ namespace contrib { @@ -34,6 +35,14 @@ struct NextTokenScores { } }; +#ifdef DEBUG_GENERATION +template +void DumpScores(const char* name, const NextTokenScores& next_token_scores) { + std::cout << name << std::endl; + ORT_UNUSED_PARAMETER(next_token_scores); +} +#endif + // Interface for all scorers for beam search or beam sample. template class ILogitsProcessor { @@ -150,19 +159,25 @@ class PresencePenaltyLogitsProcessor : public ILogitsProcessor { template class TimestampLogitsProcessor : public ILogitsProcessor { public: - TimestampLogitsProcessor(int eos_token_id, int max_initial_timestamp_index) - : eos_token_id_(eos_token_id), max_initial_timestamp_index_(max_initial_timestamp_index) {} + TimestampLogitsProcessor(int end_of_text_token_id, // <|endoftext|> + int start_of_transcript_token_id, // <|startoftranscript|> + int translate_token_id, // <|translate|> + int transcribe_token_id, // <|transcribe|> + int start_of_lm_token_id, // <|startoflm|> + int no_timestamps_token_id, // <|notimestamps|> + int beginning_timestamp_token_id, // <|0.00|> + int max_initial_timestamp_index) + : end_of_text_token_id_(end_of_text_token_id), + start_of_transcript_token_id_(start_of_transcript_token_id), + translate_token_id_(translate_token_id), + transcribe_token_id_(transcribe_token_id), + start_of_lm_token_id_(start_of_lm_token_id), + no_timestamps_token_id_(no_timestamps_token_id), + beginning_timestamp_token_id_(beginning_timestamp_token_id), + max_initial_timestamp_index_(max_initial_timestamp_index) {} void Process(const ISequences* sequences, NextTokenScores& next_token_scores) override { - // TODO: translate_token_id_ and transcribe_token_id_ need to support both multilingual and English-only models. 
- const int beg_token_id_ = eos_token_id_ + 107; - const int not_token_id_ = eos_token_id_ + 106; - const int solm_token_id_ = eos_token_id_ + 105; - const int sot_token_id_ = eos_token_id_ + 1; - constexpr int translate_token_id_ = 50358; - constexpr int transcribe_token_id_ = 50359; - const int batch_beam_size = next_token_scores.batch_beam_size; const int vocab_size = next_token_scores.vocab_size; for (int i = 0; i < batch_beam_size; i++) { @@ -174,7 +189,7 @@ class TimestampLogitsProcessor : public ILogitsProcessor { size_t sample_begin = 0; for (size_t j = 0; j < seq_length; j++) { sample_begin++; - if (sequence[j] >= beg_token_id_) { + if (sequence[j] >= beginning_timestamp_token_id_) { break; } } @@ -182,30 +197,30 @@ class TimestampLogitsProcessor : public ILogitsProcessor { // Suppress tokens for (int j = 0; j < vocab_size; j++) { // Suppress notimestamps and solm tokens - if (j == not_token_id_ || j == solm_token_id_) { + if (j == no_timestamps_token_id_ || j == start_of_lm_token_id_) { beam_token_scores[j] = std::numeric_limits::lowest(); } // Suppress sot, translate and transcribe tokens if (seq_length > sample_begin) { - if (j == sot_token_id_ || j == translate_token_id_ || j == transcribe_token_id_) { + if (j == start_of_transcript_token_id_ || j == translate_token_id_ || j == transcribe_token_id_) { beam_token_scores[j] = std::numeric_limits::lowest(); } } } // Timestamps should be in pair except the first one - const bool last_was_timestamp = seq_length > 0 && sequence.back() >= beg_token_id_; - const bool penultimate_was_timestamp = seq_length <= sample_begin || sequence[seq_length - 2] >= beg_token_id_; + const bool last_was_timestamp = seq_length > 0 && sequence.back() >= beginning_timestamp_token_id_; + const bool penultimate_was_timestamp = seq_length <= sample_begin || sequence[seq_length - 2] >= beginning_timestamp_token_id_; if (last_was_timestamp) { if (penultimate_was_timestamp) { // If timestamps show up in pair, or it's the first 
timestamp, no more timestamp is generated - for (int j = beg_token_id_; j < vocab_size; j++) { + for (int j = beginning_timestamp_token_id_; j < vocab_size; j++) { beam_token_scores[j] = std::numeric_limits::lowest(); } } else { // If timestamp doesn't show up in pair, generate timestamp - for (int j = 0; j < eos_token_id_; j++) { + for (int j = 0; j < end_of_text_token_id_; j++) { beam_token_scores[j] = std::numeric_limits::lowest(); } } @@ -214,7 +229,7 @@ class TimestampLogitsProcessor : public ILogitsProcessor { // Find timestamp tokens std::vector timestamps; for (const auto& word_id : sequence) { - if (word_id >= beg_token_id_) { + if (word_id >= beginning_timestamp_token_id_) { timestamps.push_back(word_id); } } @@ -231,13 +246,13 @@ class TimestampLogitsProcessor : public ILogitsProcessor { timestamp_last = timestamps.back() + 1; } - for (int j = beg_token_id_; j < timestamp_last; j++) { + for (int j = beginning_timestamp_token_id_; j < timestamp_last; j++) { beam_token_scores[j] = std::numeric_limits::lowest(); } } if (seq_length == sample_begin) { - const int last_allowed = beg_token_id_ + max_initial_timestamp_index_; + const int last_allowed = beginning_timestamp_token_id_ + max_initial_timestamp_index_; for (int j = last_allowed + 1; j < vocab_size; j++) { beam_token_scores[j] = std::numeric_limits::lowest(); } @@ -247,8 +262,8 @@ class TimestampLogitsProcessor : public ILogitsProcessor { float timestamp_logprob = std::numeric_limits::lowest(); { float logsumexp = 0.0f; - const float logprob_max = *std::max_element(beam_token_scores.begin() + beg_token_id_, beam_token_scores.end()); - for (int j = beg_token_id_; j < vocab_size; ++j) { + const float logprob_max = *std::max_element(beam_token_scores.begin() + beginning_timestamp_token_id_, beam_token_scores.end()); + for (int j = beginning_timestamp_token_id_; j < vocab_size; ++j) { if (beam_token_scores[j] > std::numeric_limits::lowest()) { logsumexp += expf(beam_token_scores[j] - logprob_max); } @@ 
-258,21 +273,23 @@ class TimestampLogitsProcessor : public ILogitsProcessor { } } - const float max_text_token_logprob = *std::max_element(beam_token_scores.begin(), beam_token_scores.begin() + beg_token_id_); + const float max_text_token_logprob = *std::max_element(beam_token_scores.begin(), beam_token_scores.begin() + beginning_timestamp_token_id_); if (timestamp_logprob > max_text_token_logprob) { - for (int j = 0; j < beg_token_id_; ++j) { + for (int j = 0; j < beginning_timestamp_token_id_; ++j) { beam_token_scores[j] = std::numeric_limits::lowest(); } } } - -#ifdef DEBUG_GENERATION - DumpScores("TimestampLogitsProcessor", next_token_scores); -#endif } private: - int eos_token_id_; + int end_of_text_token_id_; + int start_of_transcript_token_id_; + int translate_token_id_; + int transcribe_token_id_; + int start_of_lm_token_id_; + int no_timestamps_token_id_; + int beginning_timestamp_token_id_; int max_initial_timestamp_index_; }; @@ -334,7 +351,15 @@ class LogitsProcessorList : public ILogitsProcessorList { // Add timestamp processor for whisper model if (parameters.model_type == IGenerationParameters::kModelTypeWhisper && parameters.logits_processor == IGenerationParameters::kLogitsProcessorTypeWhisper) { constexpr int max_initial_timestamp_index = 50; - timestamp_processor_ = std::make_unique>(parameters.eos_token_id, max_initial_timestamp_index); + // Token ids are passed below in the order that they appear in the tokenizer + timestamp_processor_ = std::make_unique>(parameters.eos_token_id, + parameters.decoder_start_token_id, + parameters.translate_token_id, + parameters.transcribe_token_id, + parameters.start_of_lm_token_id, + parameters.no_timestamps_token_id, + parameters.beginning_timestamp_token_id, + max_initial_timestamp_index); processor_list_.push_back(timestamp_processor_.get()); } diff --git a/onnxruntime/contrib_ops/cuda/activation/activations.cc b/onnxruntime/contrib_ops/cuda/activation/activations.cc index 1a86c5dbece5a..6303858b9bd48 
100644 --- a/onnxruntime/contrib_ops/cuda/activation/activations.cc +++ b/onnxruntime/contrib_ops/cuda/activation/activations.cc @@ -49,7 +49,6 @@ namespace cuda { UNARY_ACTIVATION_OP_HFD(Affine, 1, kOnnxDomain); UNARY_ACTIVATION_OP_HFD(ParametricSoftplus, 1, kOnnxDomain); UNARY_ACTIVATION_OP_HFD(ScaledTanh, 1, kOnnxDomain); -UNARY_ACTIVATION_OP_HFD(Gelu, 1, kMSDomain); UNARY_ACTIVATION_OP_HFD(QuickGelu, 1, kMSDomain); REGISTER_ACTIVATION_KERNEL(ThresholdedRelu, 1, kOnnxDomain, MLFloat16) diff --git a/onnxruntime/contrib_ops/cuda/activation/activations.h b/onnxruntime/contrib_ops/cuda/activation/activations.h index ab339f276c2bd..fc9a71b0b7fa1 100644 --- a/onnxruntime/contrib_ops/cuda/activation/activations.h +++ b/onnxruntime/contrib_ops/cuda/activation/activations.h @@ -66,17 +66,6 @@ class ScaledTanh final : public UnaryElementwise { float beta_; }; -template -class Gelu final : public UnaryElementwise { - public: - Gelu(const OpKernelInfo& info) : UnaryElementwise(info) {} - - Status ComputeInternal(OpKernelContext* context) const override; - - private: - MAKE_FUNC_CTX_NULL() -}; - template class QuickGelu final : public UnaryElementwise { public: diff --git a/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu b/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu index 0c856815fd437..36f33fbb24c18 100644 --- a/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu +++ b/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu @@ -36,20 +36,6 @@ struct OP_ScaledTanh : public CtxScaledTanh { } }; -template -struct OP_Gelu : public CtxGelu { - __device__ __inline__ T operator()(const T& a) const { - return _Gelu(a); - } -}; - -template <> -struct OP_Gelu : public CtxGelu { - __device__ __inline__ half operator()(const half& a) const { - return static_cast(_Gelu(static_cast(a))); - } -}; - template struct OP_QuickGelu : public CtxQuickGelu { __device__ __inline__ T operator()(const T& a) const { diff --git 
a/onnxruntime/contrib_ops/cuda/activation/activations_impl.h b/onnxruntime/contrib_ops/cuda/activation/activations_impl.h index 5d18283a395e3..782d4bf59a5ad 100644 --- a/onnxruntime/contrib_ops/cuda/activation/activations_impl.h +++ b/onnxruntime/contrib_ops/cuda/activation/activations_impl.h @@ -11,14 +11,12 @@ namespace cuda { typedef onnxruntime::cuda::CtxAlphaBeta CtxAffine; typedef onnxruntime::cuda::CtxAlphaBeta CtxParametricSoftplus; typedef onnxruntime::cuda::CtxAlphaBeta CtxScaledTanh; -typedef onnxruntime::cuda::CtxNull CtxGelu; typedef onnxruntime::cuda::CtxAlpha CtxQuickGelu; #define UNARY_CONTRIB_ACTIVATION_OPS() \ UNARY_ACTIVATION_OP_NAME(ScaledTanh) \ UNARY_ACTIVATION_OP_NAME(Affine) \ UNARY_ACTIVATION_OP_NAME(ParametricSoftplus) \ - UNARY_ACTIVATION_OP_NAME(Gelu) \ UNARY_ACTIVATION_OP_NAME(QuickGelu) #define UNARY_ACTIVATION_OP_NAME(name) UNARY_ACTIVATION_IMPL_DECLARATION(name); diff --git a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu index 1ea2540db486f..9e6752b451868 100644 --- a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu +++ b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu @@ -843,11 +843,11 @@ void InvokeAddBiasTransposeTrt( template <> void LaunchAddBiasTransposeTrt( - cudaStream_t stream, const int max_threads_per_block, - const int batch_size, const int sequence_length, - const int num_heads, const int head_size, - const float* biases, const float* query, const float* key, const float* value, float* output, - bool is_cross_attention, int kv_sequence_length) { + cudaStream_t /*stream*/, const int /*max_threads_per_block*/, + const int /*batch_size*/, const int /*sequence_length*/, + const int /*num_heads*/, const int /*head_size*/, + const float* /*biases*/, const float* /*query*/, const float* /*key*/, const float* /*value*/, float* /*output*/, + bool /*is_cross_attention*/, int /*kv_sequence_length*/) { ORT_ENFORCE(false, "Shall not call this since 
fused kernel does not support float input."); } diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index bf6431cf1afb2..7a807342ad685 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -84,6 +84,8 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { auto& device_prop = GetDeviceProp(); AttentionParameters parameters; + parameters.use_tf32 = UseTF32(); + // Use the second dimension from weight for bias to get q_hidden_size when bias is nullptr std::vector bias_dims{weights->Shape().GetDims()[1]}; const TensorShape bias_shape{bias_dims}; @@ -251,7 +253,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, reinterpret_cast(weights->Data()), n, reinterpret_cast(input->Data()), k, - &zero, reinterpret_cast(gemm_buffer.get()), n, device_prop)); + &zero, reinterpret_cast(gemm_buffer.get()), n, device_prop, UseTF32())); constexpr size_t element_size = sizeof(T); constexpr bool use_fused_cross_attention = false; diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index 54c9a5da1e9da..a93fdf74dc28c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -58,12 +58,12 @@ size_t AlignSize(size_t bytes) { return bytesAligned; } -void CumulatedSequenceLengthCache::Initialize(int32_t sequence_length, cudaStream_t stream) { - if (this->sequence_length != sequence_length) { +void CumulatedSequenceLengthCache::Initialize(int32_t seq_length, cudaStream_t stream) { + if (this->sequence_length != seq_length) { ORT_ENFORCE(buffer.get() != nullptr && this->max_batch_size > 0); LaunchTrtSequenceOffset(reinterpret_cast(buffer.get()), nullptr, - this->max_batch_size, sequence_length, stream); - this->sequence_length = sequence_length; + 
this->max_batch_size, seq_length, stream); + this->sequence_length = seq_length; } } @@ -213,9 +213,9 @@ Status FusedTrtCrossAttention( template <> Status FusedTrtCrossAttention( - cudaStream_t stream, - contrib::AttentionParameters& parameters, - AttentionData& data) { + cudaStream_t /*stream*/, + contrib::AttentionParameters& /*parameters*/, + AttentionData& /*data*/) { return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, "Trt fused cross attention does not support float tensor"); } @@ -276,9 +276,9 @@ Status FusedTrtSelfAttention( // Template Specialization for float type template <> Status FusedTrtSelfAttention( - cudaStream_t stream, - contrib::AttentionParameters& parameters, - AttentionData& data) { + cudaStream_t /*stream*/, + contrib::AttentionParameters& /*parameters*/, + AttentionData& /*data*/) { return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, "Trt fused attention does not support float tensor"); } @@ -461,7 +461,8 @@ Status UnfusedAttention( total_sequence_length, sequence_length, qk_head_size, &alpha, data.k, qk_head_size, present_size_per_batch_k, data.q, qk_head_size, sequence_length * qk_head_size, - &zero, data.scratch, total_sequence_length, sequence_length * total_sequence_length, batches, device_prop)); + &zero, data.scratch, total_sequence_length, sequence_length * total_sequence_length, batches, + device_prop, parameters.use_tf32)); DUMP_TENSOR_D("Q", data.q, batch_size, num_heads, sequence_length, qk_head_size); DUMP_TENSOR_D("K", data.k, batch_size, num_heads, qk_head_size, sequence_length); @@ -514,7 +515,7 @@ Status UnfusedAttention( v_head_size, sequence_length, total_sequence_length, &one, data.v, v_head_size, present_size_per_batch_v, scratch2, total_sequence_length, sequence_length * total_sequence_length, - &zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop)); + &zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop, parameters.use_tf32)); // 
Temp_output is BxNxSxH_v, transpose to output BxSxNxH_v Status result = LaunchTransCtx(stream, sequence_length, batch_size, v_head_size, num_heads, diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu index a513d9e8d2211..b843966d88e85 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu @@ -231,7 +231,7 @@ Status PrepareQkv_MHA_PackedQKV(contrib::AttentionParameters& parameters, AttentionData& data, cudaStream_t stream, int max_threads_per_block, - T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { + T* /*q*/, T* /*k*/, T* /*v*/, AttentionQkvFormat& qkv_format) { const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; const int num_heads = parameters.num_heads; @@ -279,7 +279,7 @@ Status PrepareQkv_MHA_PackedKV(contrib::AttentionParameters& parameters, AttentionData& data, cudaStream_t stream, int max_threads_per_block, - T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { + T* /*q*/, T* k, T* /*v*/, AttentionQkvFormat& qkv_format) { const int batch_size = parameters.batch_size; const int kv_sequence_length = parameters.kv_sequence_length; const int num_heads = parameters.num_heads; diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h index db78722cc0e4c..c12cb374d9adf 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h @@ -242,18 +242,18 @@ void DispatchIsAligned(const MemoryEfficientAttentionParams& params) { using AlignedAK = AttentionKernel; #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(push) -#pragma warning(disable : 6287) +#pragma warning(disable : 6287 4189) // kAligned is used via capture so 4189 warning seems incorrect #endif 
// Run a more efficient kernel with `isAligned=True` when memory is correctly aligned. bool is_aligned = params.qk_head_size % AlignedAK::kAlignmentQ == 0 && params.qk_head_size % AlignedAK::kAlignmentK == 0 && params.v_head_size % AlignedAK::kAlignmentV == 0; -#if defined(_MSC_VER) && !defined(__clang__) -#pragma warning(pop) -#endif DISPATCH_BOOL(is_aligned, kIsAligned, ([&]() { LaunchCutlassFmha(params); })); +#if defined(_MSC_VER) && !defined(__clang__) +#pragma warning(pop) +#endif } template diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc b/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc index 3f703ae3d05e6..ceee17c2a2d01 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc @@ -273,13 +273,13 @@ Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, 1, &one, reinterpret_cast(bias->Data()), n, GetConstOnes(m, Stream(context)), 1, - &zero, reinterpret_cast(gemm_query_buffer_p.get()), n, device_prop)); + &zero, reinterpret_cast(gemm_query_buffer_p.get()), n, device_prop, UseTF32())); // matmul: (h2, h1)*(h1, S*B) CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, reinterpret_cast(q_weights->Data()), n, reinterpret_cast(query->Data()), k, - &one, reinterpret_cast(gemm_query_buffer_p.get()), n, device_prop)); + &one, reinterpret_cast(gemm_query_buffer_p.get()), n, device_prop, UseTF32())); // gemm_query_buffer in col-base: (h2, S*B) // calcualte k, v @@ -298,13 +298,13 @@ Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, 1, &one, reinterpret_cast(bias->Data() + hidden_size), n, GetConstOnes(m, Stream(context)), 1, - &zero, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop)); + &zero, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop, UseTF32())); // matmul: (2*h2, h1)*(h1, T_S*B) 
CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, reinterpret_cast(kv_weights->Data()), n, reinterpret_cast(query->Data()), k, - &one, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop)); + &one, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop, UseTF32())); // gemm_kv_buffer in col-base: (2*h2, T_S*B) } else { gemm_kv_buffer_p = GetScratchBuffer(static_cast(batch_size) * 2 * key_sequence_length * hidden_size, @@ -318,13 +318,13 @@ Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, 1, &one, reinterpret_cast(bias->Data() + hidden_size), n, GetConstOnes(m, Stream(context)), 1, - &zero, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop)); + &zero, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop, UseTF32())); // matmul: (2*h2, h1)*(h1, T_S*B) CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, reinterpret_cast(kv_weights->Data()), n, reinterpret_cast(key->Data()), k, - &one, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop)); + &one, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop, UseTF32())); // gemm_kv_buffer in col-base: (2*h2, T_S*B) } } else { @@ -342,13 +342,13 @@ Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, 1, &one, reinterpret_cast(bias->Data() + hidden_size), n, GetConstOnes(m, Stream(context)), 1, - &zero, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop)); + &zero, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop, UseTF32())); // matmul: (2*h2, h1)*(h1, T_S*B) CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, reinterpret_cast(kv_weights->Data()), n, reinterpret_cast(query->Data()), k, - &one, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop)); + &one, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop, UseTF32())); // gemm_kv_buffer 
in col-base: (2*h2, T_S*B) } else { kv_sequence_length = cache_sequence_length; @@ -372,6 +372,8 @@ Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { device_prop, #ifdef USE_ROCM GetTuningContext(), +#else + UseTF32(), #endif context->GetComputeStream(), cublas, diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu index 1dc22a9c8ea98..c0b1996789183 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu @@ -17,7 +17,7 @@ Status DecoderQkvToContext( const cudaDeviceProp& device_prop, Stream* ort_stream, cublasHandle_t& cublas, - const size_t element_size, + const size_t /*element_size*/, const int batch_size, const int sequence_length, const int kv_sequence_length, @@ -37,7 +37,8 @@ Status DecoderQkvToContext( T* workspace_buffer, T* output, T* new_key_cache, - T* new_value_cache) { + T* new_value_cache, + bool use_tf32) { const int max_threads_per_block = device_prop.maxThreadsPerBlock; const int BN = batch_size * num_heads; const int BHN = BN * head_size; @@ -128,14 +129,14 @@ Status DecoderQkvToContext( kv_sequence_length, sequence_length, head_size, &alpha, key_cache, head_size, strideA, q, head_size, strideB, - &zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop)); + &zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop, use_tf32)); } else { CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( cublas, CUBLAS_OP_T, CUBLAS_OP_N, kv_sequence_length, sequence_length, head_size, &alpha, k, head_size, strideA, q, head_size, strideB, - &zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop)); + &zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop, use_tf32)); } constexpr bool is_unidirectional = false; @@ -163,14 +164,14 @@ Status DecoderQkvToContext( head_size, sequence_length, kv_sequence_length, &one, 
value_cache, head_size, strideA, scratch2, kv_sequence_length, temp_matrix_size, - &zero, scratch3, head_size, strideB, BN, device_prop)); + &zero, scratch3, head_size, strideB, BN, device_prop, use_tf32)); } else { CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( cublas, CUBLAS_OP_N, CUBLAS_OP_N, head_size, sequence_length, kv_sequence_length, &one, v, head_size, strideA, scratch2, kv_sequence_length, temp_matrix_size, - &zero, scratch3, head_size, strideB, BN, device_prop)); + &zero, scratch3, head_size, strideB, BN, device_prop, use_tf32)); } // scratch3 is BxNxSxH, transpose to output SxBxNxH @@ -180,6 +181,7 @@ Status DecoderQkvToContext( Status LaunchDecoderAttentionKernel( const cudaDeviceProp& device_prop, + bool use_tf32, Stream* stream, cublasHandle_t& cublas, const size_t element_size, @@ -228,7 +230,8 @@ Status LaunchDecoderAttentionKernel( reinterpret_cast(workspace_buffer), reinterpret_cast(output), reinterpret_cast(new_key_cache), - reinterpret_cast(new_value_cache)); + reinterpret_cast(new_value_cache), + use_tf32); } else { return DecoderQkvToContext( device_prop, @@ -254,7 +257,8 @@ Status LaunchDecoderAttentionKernel( reinterpret_cast(workspace_buffer), reinterpret_cast(output), reinterpret_cast(new_key_cache), - reinterpret_cast(new_value_cache)); + reinterpret_cast(new_value_cache), + use_tf32); } } diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.h index 9db9ccb45e330..f9667a613e648 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.h @@ -11,6 +11,7 @@ namespace cuda { Status LaunchDecoderAttentionKernel( const cudaDeviceProp& prop, // Device Properties + bool use_tf32, // Use TF32 Stream* stream, // ORT Stream cublasHandle_t& cublas, // Cublas handle const size_t element_size, // Element size of input tensor diff --git 
a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc index a9b60da0c96ca..66c0aceaed1e7 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc @@ -74,6 +74,7 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext* parameters.kv_data_in_flight = ParseEnvironmentVariableWithDefault( attention::kDecoderMaskedAttentionLoadKVDataInFlight, false); + bool is_unidirectional = false; bool is_dmmha_packing = (key == nullptr && value == nullptr); ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs(query, key, @@ -88,6 +89,7 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext* num_heads_, mask_filter_value_, scale_, + is_unidirectional, past_present_share_buffer_, is_dmmha_packing, // dmmha_packing device_prop.maxThreadsPerBlock)); diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_self_attention.cc b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_self_attention.cc index 72ede2e22b557..07a6fbd60e171 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_self_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_self_attention.cc @@ -143,7 +143,7 @@ Status DecoderMaskedSelfAttention::ComputeInternal(OpKernelContext* cont cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, reinterpret_cast(weights->Data()), n, reinterpret_cast(input->Data()), k, - &zero, reinterpret_cast(gemm_buffer.get()), n, device_prop)); + &zero, reinterpret_cast(gemm_buffer.get()), n, device_prop, UseTF32())); // Update the q, k, and v buffers parameters.q = gemm_buffer.get(); diff --git a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc index 892f5c181a607..8b8e4e267f895 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc +++ 
b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc @@ -4,9 +4,13 @@ #include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/cudnn_common.h" #include "fast_gelu.h" -#include "fast_gelu_impl.h" +#include "core/providers/cuda/tensor/gelu_impl.h" #include "contrib_ops/cpu/bert/bias_gelu_helper.h" -#include "transformer_common.h" +#ifdef USE_ROCM +#include "contrib_ops/rocm/bert/elementwise.h" +#else +#include "contrib_ops/cuda/bert/transformer_common.h" +#endif namespace onnxruntime { namespace contrib { @@ -31,8 +35,10 @@ using namespace ONNX_NAMESPACE; template FastGelu::FastGelu(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info) { +#ifndef USE_ROCM const TransformerOptions* options = TransformerOptions::GetInstance(); use_half2_ = !options->DisableHalf2(); +#endif } template @@ -50,6 +56,13 @@ Status FastGelu::ComputeInternal(OpKernelContext* context) const { int64_t bias_length = (nullptr == bias) ? 0 : bias->Shape().Size(); typedef typename ToCudaType::MappedType CudaT; +#ifdef USE_ROCM + return LaunchElementwiseKernel( + GetTuningContext(), context->GetComputeStream(), + reinterpret_cast(input->Data()), static_cast(input_length), + (nullptr != bias) ? reinterpret_cast(bias->Data()) : nullptr, static_cast(bias_length), + reinterpret_cast(output->MutableData())); +#else return LaunchFastGeluKernel(GetDeviceProp(), Stream(context), static_cast(input_length), @@ -58,6 +71,7 @@ Status FastGelu::ComputeInternal(OpKernelContext* context) const { (nullptr != bias) ? 
reinterpret_cast(bias->Data()) : nullptr, reinterpret_cast(output->MutableData()), use_half2_); +#endif } } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h index 3e642a70afef5..26f3bd5a03928 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h +++ b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h @@ -18,7 +18,9 @@ class FastGelu final : public CudaKernel { Status ComputeInternal(OpKernelContext* ctx) const override; private: +#ifndef USE_ROCM bool use_half2_; +#endif }; } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc index 2c296bf4f8483..0f58a74c4d2fd 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc @@ -371,6 +371,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, int seqlen_q, int seqlen_k, int seqlen_k_new, + int rotary_dim, const float softmax_scale, bool is_causal, bool is_bf16, @@ -448,7 +449,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, params.rotary_cos_ptr = rotary_cos; params.rotary_sin_ptr = rotary_sin; params.is_rotary_interleaved = is_rotary_interleaved; - params.rotary_dim = (head_size / 16) * 16; + params.rotary_dim = rotary_dim; } params.num_splits = num_splits; diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h index 387d1cf9d84fe..24891bcc4d499 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h @@ -96,6 +96,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, int seqlen_q, int seqlen_k, int seqlen_k_new, + int rotary_dim, const float softmax_scale, bool is_causal, bool is_bf16, diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc 
b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index fe56f84f0a886..814aa1fb3c8f0 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -110,6 +110,11 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { parameters.do_rotary = do_rotary_; parameters.rotary_interleaved = rotary_interleaved_; + if (do_rotary_ && (cos_cache == nullptr || sin_cache == nullptr)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "cos_cache and sin_cache must be passed to GroupQueryAttention when do_rotary = 1"); + } + TensorShapeVector output_shape(3); output_shape[0] = static_cast(parameters.batch_size); output_shape[1] = static_cast(sequence_length); diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h index 853e1a710cb24..1a7c3fcea3fa3 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h @@ -205,6 +205,7 @@ Status CheckInputs(const Tensor* query, int total_sequence_length = *((*total_seqlen).template Data()); int present_sequence_length = std::max(total_sequence_length, past_sequence_length); + int rotary_dim = 0; if (cos_cache != nullptr && sin_cache != nullptr) { const auto& cos_dims = cos_cache->Shape().GetDims(); const auto& sin_dims = sin_cache->Shape().GetDims(); @@ -214,22 +215,27 @@ Status CheckInputs(const Tensor* query, "head_size shall be a multiple of 16. 
Got head_size % 16 == ", head_size % 16); } - if (cos_dims[0] != present_sequence_length) { + if (cos_dims[0] < present_sequence_length) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "cos_cache dimension 0 must be of present_sequence_length."); + "cos_cache dimension 0 should be of max_sequence_length."); } - if (sin_dims[0] != present_sequence_length) { + if (sin_dims[0] < present_sequence_length) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "sin_cache dimension 0 must be of present_sequence_length."); + "sin_cache dimension 0 should be of max_sequence_length."); } - if (cos_dims[1] != (head_size / 16) * 8) { + if (cos_dims[1] > (head_size / 16) * 8 || cos_dims[1] % 8 != 0) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "cos_cache dimension 1 must be <= head_size / 2 and a multiple of 8."); } - if (sin_dims[1] != (head_size / 16) * 8) { + if (sin_dims[1] > (head_size / 16) * 8 || sin_dims[1] % 8 != 0) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "sin_cache dimension 1 must be <= head_size / 2 and a multiple of 8."); } + if (cos_dims[1] != sin_dims[1]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "cos_cache and sin_cache dimension 1 must be the same."); + } + rotary_dim = static_cast(cos_dims[1] * 2); } else if (cos_cache != nullptr || sin_cache != nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' and 'sin_cache' shall be both present or both absent."); @@ -248,6 +254,7 @@ Status CheckInputs(const Tensor* query, output_parameters->head_size = head_size; output_parameters->kv_hidden_size = kv_hidden_size; output_parameters->kv_num_heads = kv_num_heads; + output_parameters->rotary_dim = rotary_dim; output_parameters->is_packed_qkv = is_packed_qkv; output_parameters->is_unidirectional = true; output_parameters->is_prompt = is_prompt; diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu 
index d88e9a49fb5ee..afba83be34e2d 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -451,7 +451,7 @@ __global__ void PastToTotalSeqlen(int32_t* seqlens_k, // Convert Past to Total sequence length tensor Status LaunchGetSeqlenBuff(contrib::GroupQueryAttentionParameters& parameters, int32_t* seqlens_k, int32_t* seqlens_k_buff, bool is_total, cudaStream_t stream, - const int threads_per_block) { + const int /*threads_per_block*/) { if (parameters.is_prompt) { return Status::OK(); } @@ -530,7 +530,7 @@ Status FlashAttention( device_prop, stream, query, present_key, present_value, key, value, data.output, reinterpret_cast(data.softmax_lse), seqlens_k, cos_cache, sin_cache, batch_size, num_heads, kv_num_heads, head_size, sequence_length, - parameters.seqlen_present_kv_cache, kv_sequence_length, + parameters.seqlen_present_kv_cache, kv_sequence_length, parameters.rotary_dim, scale, is_causal, is_bf16, past_bsnh, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), reinterpret_cast(data.out_accum), parameters.local_window_size, parameters.rotary_interleaved, parameters.is_packed_qkv)); @@ -655,7 +655,7 @@ Status EfficientAttention( template Status QkvToContext( const cudaDeviceProp& device_prop, - cublasHandle_t& cublas, + cublasHandle_t& /*cublas*/, Stream* ort_stream, contrib::GroupQueryAttentionParameters& parameters, GroupQueryAttentionData& data) { diff --git a/onnxruntime/contrib_ops/cuda/bert/longformer_attention.cc b/onnxruntime/contrib_ops/cuda/bert/longformer_attention.cc index e556ae4a490e9..9c5d0e9834f6f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/longformer_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/longformer_attention.cc @@ -136,7 +136,7 @@ Status LongformerAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, weights_data, n, input_data, k, - &zero, 
reinterpret_cast(gemm_buffer.get()), n, device_prop)); + &zero, reinterpret_cast(gemm_buffer.get()), n, device_prop, UseTF32())); } else { // q const CudaT* q_weight = weights_data; @@ -145,7 +145,7 @@ Status LongformerAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, q_weight, n, input_data, k, - &zero, q_data, n, device_prop)); + &zero, q_data, n, device_prop, UseTF32())); // k const CudaT* k_weight = q_weight + static_cast(hidden_size) * hidden_size; CudaT* k_data = q_data + static_cast(batch_size) * sequence_length * hidden_size; @@ -153,7 +153,7 @@ Status LongformerAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, k_weight, n, input_data, k, - &zero, k_data, n, device_prop)); + &zero, k_data, n, device_prop, UseTF32())); // v const CudaT* v_weight = k_weight + static_cast(hidden_size) * hidden_size; @@ -162,7 +162,7 @@ Status LongformerAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, v_weight, n, input_data, k, - &zero, v_data, n, device_prop)); + &zero, v_data, n, device_prop, UseTF32())); } // Wait for async copy of batch_global_num @@ -195,7 +195,7 @@ Status LongformerAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, reinterpret_cast(global_weights->Data()), n, input_data, k, - &zero, global_gemm_buffer, n, device_prop)); + &zero, global_gemm_buffer, n, device_prop, UseTF32())); } else { // global q const CudaT* global_q_weight = global_weights_data; @@ -205,7 +205,7 @@ Status LongformerAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, global_q_weight, n, input_data, k, - &zero, global_q, n, device_prop)); + &zero, global_q, n, device_prop, UseTF32())); } else { CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( cublas, @@ -226,7 +226,8 @@ Status 
LongformerAttention::ComputeInternal(OpKernelContext* context) const { hidden_size, // ldc static_cast(max_num_global) * hidden_size, // strideC batch_size, // batch count - device_prop)); + device_prop, + UseTF32())); } // global k const CudaT* global_k_weight = global_weights_data + static_cast(hidden_size) * hidden_size; @@ -235,7 +236,7 @@ Status LongformerAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, global_k_weight, n, input_data, k, - &zero, global_k, n, device_prop)); + &zero, global_k, n, device_prop, UseTF32())); // global v const CudaT* global_v_weight = global_k_weight + static_cast(hidden_size) * hidden_size; @@ -244,7 +245,7 @@ Status LongformerAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, global_v_weight, n, input_data, k, - &zero, global_v, n, device_prop)); + &zero, global_v, n, device_prop, UseTF32())); } } diff --git a/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.cu index f00239460071b..c9c66b73b3e9d 100644 --- a/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.cu @@ -1005,7 +1005,6 @@ Status LaunchLongformerAttentionKernel( bool disable_compact_memory, bool use_merged_qkv_weights, bool use_half4) { - CublasMathModeSetter helper(device_prop, cublas, CUBLAS_TENSOR_OP_MATH); size_t softmax_workspace_size = GetLongformerSoftmaxWorkspaceSize(element_size, batch_size, num_heads, diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index f978f50c6851f..2ef011cdd9a21 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -94,6 +94,8 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { auto& 
device_prop = GetDeviceProp(); AttentionParameters parameters; + parameters.use_tf32 = UseTF32(); + ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs(query, key, value, diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc b/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc index ec8b1d051b3d9..55deed55dfd33 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc @@ -268,6 +268,7 @@ Status PackedAttention::ComputeInternal(OpKernelContext* context) const { const Tensor* relative_position_bias = context->Input(5); PackedAttentionParameters parameters; + parameters.use_tf32 = UseTF32(); ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(), weights->Shape(), bias->Shape(), @@ -308,12 +309,12 @@ Status PackedAttention::ComputeInternal(OpKernelContext* context) const { cublasHandle_t cublas = this->GetCublasHandle(context); // Gemm, note that CUDA assumes col-major, so result(N, M) = 1 * weights x input + 1 x bias - // The bias part is not included here since we fuse bias, transpose and output 3 matrice into one cuda kernel. + // The bias part is not included here since we fuse bias, transpose and output 3 matrices into one cuda kernel. 
CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, reinterpret_cast(weights->Data()), n, reinterpret_cast(input->Data()), k, - &zero, reinterpret_cast(gemm_buffer.get()), n, device_prop)); + &zero, reinterpret_cast(gemm_buffer.get()), n, device_prop, UseTF32())); constexpr size_t element_size = sizeof(T); constexpr bool no_qkv_workspace = false; // need workspace to add bias diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu index 3b52320839403..a84a310b46ca0 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu @@ -440,7 +440,7 @@ Status LaunchTransposeRemovePadding( template Status FusedScaledDotProductAttention( - const cudaDeviceProp& device_prop, + const cudaDeviceProp& /*device_prop*/, cudaStream_t stream, PackedAttentionParameters& parameters, PackedAttentionData& data) { @@ -596,7 +596,7 @@ Status UnfusedScaledDotProductAttention( q, qk_head_size, sequence_length * qk_head_size, &zero, scaled_qk, sequence_length, sequence_length * sequence_length, - batches, device_prop)); + batches, device_prop, parameters.use_tf32)); DUMP_TENSOR_D("PackedAttention unfused QK", scaled_qk, batch_size * num_heads, sequence_length, sequence_length); @@ -624,7 +624,7 @@ Status UnfusedScaledDotProductAttention( v_head_size, sequence_length, sequence_length, &one, v, v_head_size, sequence_length * v_head_size, attention_score, sequence_length, sequence_length * sequence_length, - &zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop)); + &zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop, parameters.use_tf32)); // Temp_output is BxNxSxH_v, transpose and remove padding to output token_countxNxH_v Status result = LaunchTransposeRemovePadding( diff --git 
a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc index 1b026e64778e3..b4a162989978c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc @@ -228,6 +228,7 @@ Status PackedMultiHeadAttention::ComputeInternal(OpKernelContext* context) co const Tensor* relative_position_bias = context->Input(6); PackedAttentionParameters parameters; + parameters.use_tf32 = UseTF32(); ORT_RETURN_IF_ERROR(CheckInputs(query->Shape(), key, value, diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu index 83af018a97ea6..982c7eaa2cb2c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu @@ -381,7 +381,7 @@ void InvokeTranspose( const T* query, const T* key, const T* value, const T* bias, T* output, const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size, const int v_head_size, - AttentionQkvFormat source_format, AttentionQkvFormat target_format, + [[maybe_unused]] AttentionQkvFormat source_format, AttentionQkvFormat target_format, const int32_t* token_offset, int32_t token_count, cudaStream_t stream) { if (key != nullptr && value != nullptr) { @@ -551,7 +551,7 @@ void LaunchTranspose( template Status FusedAttentionTrt( - const cudaDeviceProp& device_prop, + const cudaDeviceProp& /*device_prop*/, cudaStream_t stream, PackedAttentionParameters& parameters, PackedMultiHeadAttentionData& data) { @@ -775,7 +775,7 @@ Status UnfusedAttention( q, qk_head_size, sequence_length * qk_head_size, &zero, scaled_qk, sequence_length, sequence_length * sequence_length, - batches, device_prop)); + batches, device_prop, parameters.use_tf32)); // Q, K and V are ready now DUMP_TENSOR_INIT(); @@ -808,7 
+808,7 @@ Status UnfusedAttention( v_head_size, sequence_length, sequence_length, &one, v, v_head_size, sequence_length * v_head_size, attention_score, sequence_length, sequence_length * sequence_length, - &zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop)); + &zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop, parameters.use_tf32)); // Temp_output is BxNxSxH_v, transpose and remove padding to output TxNxH_v Status result = LaunchTransposeRemovePadding( diff --git a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc index 92ba808dd85c2..05f55d9106d0e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc +++ b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc @@ -200,7 +200,7 @@ Status GatedRelativePositionBias::ComputeInternal(OpKernelContext* context) c D, BNS, head_size, &one, reinterpret_cast(weight_tensor.template Data()), (int)D, reinterpret_cast(workspace.get()), (int)head_size, - &zero, gemm_output, ld_gemm_output, device_prop)); + &zero, gemm_output, ld_gemm_output, device_prop, UseTF32())); auto status = LaunchGatedRelativePositionBiasKernel( device_prop, stream, diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc index 9de7ba3885c3c..ab7479f2938fe 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc @@ -82,8 +82,6 @@ Status RotaryEmbedding::ComputeInternal(OpKernelContext* context) const { interleaved, device_prop.maxThreadsPerBlock, parameters.transposed); - - return Status::OK(); } } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu index c6637041f05bd..3a14161f29e9f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu +++ 
b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu @@ -93,7 +93,7 @@ Status LaunchRotaryEmbeddingKernel( const int num_heads, const int head_size, const int rotary_embedding_dim, - const int max_sequence_length, + const int /*max_sequence_length*/, const int position_ids_format, const bool interleaved, const int max_threads_per_block, diff --git a/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.cu b/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.cu index 8fb6575d27cc0..4a4e3eeecf642 100644 --- a/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.cu +++ b/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.cu @@ -53,9 +53,9 @@ class FusedMHARunnerFP16v2::mhaImpl { ~mhaImpl() {} - void setup(const int S, const int B) { + void setup(const int seq_len, const int B) { // For bert and vit, use flash attention when sequence length is larger than the threshold. - use_flash_attention = is_flash_attention(S); + use_flash_attention = is_flash_attention(seq_len); params.force_unroll = use_flash_attention; @@ -68,26 +68,26 @@ class FusedMHARunnerFP16v2::mhaImpl { warps_n = 1; } else { if (sm == 70) { - if (S == 64 || S == 96) { + if (seq_len == 64 || seq_len == 96) { warps_m = 2; warps_n = 2; - } else if (S == 128) { + } else if (seq_len == 128) { warps_m = 1; warps_n = 4; - } else if (S == 256 || S == 384) { + } else if (seq_len == 256 || seq_len == 384) { warps_m = 1; warps_n = 8; } else { ORT_ENFORCE(false, "Unsupported sequence length"); } } else { - if (S == 32 || S == 64 || S == 96 || S == 128) { + if (seq_len == 32 || seq_len == 64 || seq_len == 96 || seq_len == 128) { warps_m = 2; warps_n = 2; - } else if (S == 192 || S == 256) { + } else if (seq_len == 192 || seq_len == 256) { warps_m = 1; warps_n = 4; - } else if (S == 384) { + } else if (seq_len == 384) { warps_m = 1; warps_n = 8; } else { @@ -99,7 +99,7 @@ class 
FusedMHARunnerFP16v2::mhaImpl { // The number of threads per CTA. threads_per_cta = warps_m * warps_n * warps_k * 32; // The number of xmmas in the M dimension. We use one uint32_t per XMMA in the M dimension. - xmmas_m = (S + 16 * warps_m - 1) / (16 * warps_m); + xmmas_m = (seq_len + 16 * warps_m - 1) / (16 * warps_m); const float scale_bmm1 = interface->mScale; const float scale_softmax = 1.f; // Seems to be only required for int8 @@ -111,7 +111,7 @@ class FusedMHARunnerFP16v2::mhaImpl { params.b = B; params.h = interface->mNumHeads; - params.s = S; + params.s = seq_len; params.d = interface->mHeadSize; params.qkv_stride_in_bytes = 3 * interface->mNumHeads * interface->mHeadSize * sizeof(half); @@ -121,7 +121,7 @@ class FusedMHARunnerFP16v2::mhaImpl { has_causal_mask = false; } - void setup_causal_masked_fmha(const int S, const int B) { + void setup_causal_masked_fmha(const int seq_len, const int B) { const float scale_bmm1 = interface->mScale; const float scale_softmax = 1.f; // Seems to be only required for int8 const float scale_bmm2 = 1.f; @@ -132,7 +132,7 @@ class FusedMHARunnerFP16v2::mhaImpl { params.b = B; params.h = interface->mNumHeads; - params.s = S; + params.s = seq_len; params.d = interface->mHeadSize; params.qkv_stride_in_bytes = 3 * interface->mNumHeads * interface->mHeadSize * sizeof(half); @@ -182,30 +182,30 @@ class FusedMHARunnerFP16v2::mhaImpl { return max_seq_len; } - int S = max_seq_len; + int seq_len = max_seq_len; if (max_seq_len <= 32) { - S = (sm == 70) ? 64 : 32; + seq_len = (sm == 70) ? 64 : 32; } else if (max_seq_len <= 64) { - S = 64; + seq_len = 64; } else if (max_seq_len <= 96) { - S = 96; + seq_len = 96; } else if (max_seq_len <= 128) { - S = 128; + seq_len = 128; } else if (max_seq_len <= 192) { - S = (sm == 70) ? 256 : 192; + seq_len = (sm == 70) ? 
256 : 192; } else if (max_seq_len <= 256) { - S = 256; + seq_len = 256; } else if (max_seq_len <= 384) { - S = 384; + seq_len = 384; } - return S; + return seq_len; } protected: - bool is_flash_attention(const int S) const { + bool is_flash_attention(const int seq_len) const { ORT_ENFORCE(interface->mHasCausalMask == false); - return interface->mEnableFlashAttention && S >= kMinSequenceLengthFlashAttention; + return interface->mEnableFlashAttention && seq_len >= kMinSequenceLengthFlashAttention; } private: @@ -232,12 +232,12 @@ FusedMHARunnerFP16v2::FusedMHARunnerFP16v2(const int numHeads, pimpl(new mhaImpl(this)) { } -void FusedMHARunnerFP16v2::setup(const int S, const int B) { - MHARunner::setup(S, B); +void FusedMHARunnerFP16v2::setup(const int seq_len, const int B) { + MHARunner::setup(seq_len, B); if (mHasCausalMask) { - pimpl->setup_causal_masked_fmha(S, B); + pimpl->setup_causal_masked_fmha(seq_len, B); } else { - pimpl->setup(S, B); + pimpl->setup(seq_len, B); } } diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 8f368251f12c7..57e951d3a68ff 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -120,6 +120,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16_float_MLFloat16, SimplifiedLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float_float_MLFloat16, SimplifiedLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16_float_float, SimplifiedLayerNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, BFloat16_float_BFloat16, SimplifiedLayerNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, 
kMSDomain, 1, Inverse); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulNBits); @@ -202,6 +203,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedSqueeze); #endif +#ifdef ENABLE_CUDA_NHWC_OPS +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 16, float, GridSample); +#endif + template <> KernelCreateInfo BuildKernelCreateInfo() { KernelCreateInfo info; @@ -318,6 +323,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -406,6 +412,9 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, #endif +#ifdef ENABLE_CUDA_NHWC_OPS + BuildKernelCreateInfo, +#endif }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc index 87e88ac31c998..dea5391c7629b 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc @@ -24,7 +24,8 @@ namespace { template struct DispatchGroupNorm { - Status operator()(cudaStream_t stream, + Status operator()(CudaTuningContext* tuning_ctx, + Stream* ort_stream, Tensor* output, Tensor* add_out, const Tensor* input, @@ -44,7 +45,8 @@ struct DispatchGroupNorm { int channels_per_block) { typedef typename ToCudaType::MappedType CudaT; return LaunchGroupNormKernel( - stream, + tuning_ctx, + ort_stream, reinterpret_cast(output->MutableData()), add_out == nullptr ? 
nullptr : reinterpret_cast(add_out->MutableData()), reinterpret_cast(input->Data()), @@ -209,7 +211,8 @@ Status GroupNorm::ComputeInternal(OpKernelContext* context) const { context->GetComputeStream()); utils::MLTypeCallDispatcher dispatcher(input->GetElementType()); - return dispatcher.InvokeRet(Stream(context), output, add_out, input, skip, bias, + return dispatcher.InvokeRet(GetTuningContext(), + context->GetComputeStream(), output, add_out, input, skip, bias, gamma, beta, workspace.get(), epsilon_, batch_size, diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.h b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.h index 84f3403b8d5ae..a80584d3293a0 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.h +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.h @@ -126,7 +126,7 @@ struct GroupNormNHWCParams { const T* bias, const float* gamma, const float* beta, - void* workspace, + float* workspace, float epsilon, int batch_size, int num_channels, @@ -136,10 +136,10 @@ struct GroupNormNHWCParams { bool use_silu, bool broadcast_skip, int channels_per_block) { - int32_t channels_per_group = num_channels / num_groups; + int32_t channels_per_group_in = num_channels / num_groups; // channels_per_block is computed in PrePack. // If the gamma is not initializer, channels_per_block might be zero after PrePack. In that happens, compute it here. 
- if (channels_per_block < channels_per_group) { + if (channels_per_block < channels_per_group_in) { channels_per_block = GetChannelsPerBlock(num_channels, num_groups); } @@ -151,7 +151,7 @@ struct GroupNormNHWCParams { this->bias = bias; this->gamma = gamma; this->beta = beta; - this->group_sum_buffer = reinterpret_cast(workspace); + this->group_sum_buffer = workspace; this->n = batch_size; this->h = height; this->w = width; @@ -167,7 +167,7 @@ struct GroupNormNHWCParams { this->hw_per_block = DivUp(this->hw, blocks_per_hw); this->channels_per_block = channels_per_block; - this->channels_per_group = channels_per_group; + this->channels_per_group = channels_per_group_in; this->hwc = this->hw * this->c; this->inv_hw_channels_per_group = 1.F / (float)(this->hw * this->channels_per_group); this->groups_per_block = channels_per_block / this->channels_per_group; diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu index d7b2cc2379f4f..4909dc5e3897b 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu @@ -49,23 +49,26 @@ void GroupNormNHWCSum(GroupNormNHWCParams const& params, cudaStream_t stream) // The number of instances. grid.z = params.n; +#define LAUNCH_GROUPNORM_SUM(ThreadsPerBlock, VecSize) \ + GroupNormNHWCSumKernel \ + <<>>( \ + params.skip_workspace, params.group_sum_buffer, params.src, params.skip, params.bias, \ + params.channels_per_block, params.hw_per_block, params.hw, params.hwc, params.c, \ + params.channels_per_group, params.groups, params.groups_per_block, params.broadcast_skip); \ + break; + // Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2. 
switch (params.threads_per_block) { case 256: - GroupNormNHWCSumKernel<<>>(params); - break; + LAUNCH_GROUPNORM_SUM(256, CHANNELS_PER_THREAD) case 192: - GroupNormNHWCSumKernel<<>>(params); - break; + LAUNCH_GROUPNORM_SUM(192, CHANNELS_PER_THREAD) case 160: - GroupNormNHWCSumKernel<<>>(params); - break; + LAUNCH_GROUPNORM_SUM(160, CHANNELS_PER_THREAD) case 128: - GroupNormNHWCSumKernel<<>>(params); - break; + LAUNCH_GROUPNORM_SUM(128, CHANNELS_PER_THREAD) case 64: - GroupNormNHWCSumKernel<<>>(params); - break; + LAUNCH_GROUPNORM_SUM(64, CHANNELS_PER_THREAD) } } @@ -80,29 +83,34 @@ void GroupNormNHWCScale(GroupNormNHWCParams const& params, cudaStream_t strea // The number of instances. grid.z = params.n; +#define LAUNCH_GROUPNORM_SCALE(ThreadsPerBlock, VecSize) \ + GroupNormNHWCScaleKernel \ + <<>>( \ + params.dst, params.src, params.skip, params.gamma, params.beta, params.skip_workspace, \ + params.group_sum_buffer, params.epsilon, params.c, params.channels_per_block, params.channels_per_group, \ + params.groups, params.hwc, params.inv_hw_channels_per_group, params.hw, params.hw_per_block, \ + params.use_silu); \ + break; + // Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2. 
switch (params.threads_per_block) { case 256: - GroupNormNHWCScaleKernel<<>>(params); - break; + LAUNCH_GROUPNORM_SCALE(256, CHANNELS_PER_THREAD) case 192: - GroupNormNHWCScaleKernel<<>>(params); - break; + LAUNCH_GROUPNORM_SCALE(192, CHANNELS_PER_THREAD) case 160: - GroupNormNHWCScaleKernel<<>>(params); - break; + LAUNCH_GROUPNORM_SCALE(160, CHANNELS_PER_THREAD) case 128: - GroupNormNHWCScaleKernel<<>>(params); - break; + LAUNCH_GROUPNORM_SCALE(128, CHANNELS_PER_THREAD) case 64: - GroupNormNHWCScaleKernel<<>>(params); - break; + LAUNCH_GROUPNORM_SCALE(64, CHANNELS_PER_THREAD) } } template Status LaunchGroupNormKernel( - cudaStream_t stream, + CudaTuningContext* tuning_ctx, + Stream* ort_stream, T* output, T* add_out, const T* input, @@ -120,7 +128,11 @@ Status LaunchGroupNormKernel( bool use_silu, bool broadcast_skip, int channels_per_block) { - GroupNormNHWCParams params(output, add_out, input, skip, bias, gamma, beta, workspace, epsilon, + + // tuning_ctx only used for ROCm EP. + ORT_UNUSED_PARAMETER(tuning_ctx); + + GroupNormNHWCParams params(output, add_out, input, skip, bias, gamma, beta, reinterpret_cast(workspace), epsilon, batch_size, num_channels, height, width, num_groups, use_silu, broadcast_skip, channels_per_block); @@ -135,6 +147,7 @@ Status LaunchGroupNormKernel( " groups=", num_groups); } + auto stream = static_cast(ort_stream->GetHandle()); CUDA_RETURN_IF_ERROR(cudaMemsetAsync( params.group_sum_buffer, 0, GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups), stream)); @@ -150,14 +163,14 @@ Status LaunchGroupNormKernel( return Status::OK(); } -template Status LaunchGroupNormKernel(cudaStream_t stream, half* output, half* add_out, +template Status LaunchGroupNormKernel(CudaTuningContext* tuning_ctx, Stream* stream, half* output, half* add_out, const half* input, const half* skip, const half* bias, const float* gamma, const float* beta, void* workspace, float epsilon, int batch_size, int num_channels, int height, int width, int num_groups, bool 
silu, bool broadcast_skip, int channels_per_block); -template Status LaunchGroupNormKernel(cudaStream_t stream, float* output, float* add_out, +template Status LaunchGroupNormKernel(CudaTuningContext* tuning_ctx, Stream* stream, float* output, float* add_out, const float* input, const float* skip, const float* bias, const float* gamma, const float* beta, void* workspace, float epsilon, int batch_size, int num_channels, diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h index 9532aeecb2f57..98f38a1475eee 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h @@ -8,6 +8,8 @@ #include #include +#include "core/providers/cuda/tunable/cuda_tunable.h" + namespace onnxruntime { namespace contrib { namespace cuda { @@ -21,7 +23,8 @@ int GetChannelsPerBlock(int num_channels, int num_groups); template Status LaunchGroupNormKernel( - cudaStream_t stream, + CudaTuningContext* tuning_ctx, + Stream* ort_stream, T* output, // normalized output tensor. Shape is (n, h, w, c) T* add_out, // optional output tensor for element-wise sum of input + skip + bias. Shape is (n, h, w, c) const T* input, // input tensor. Shape is (n, h, w, c) diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl_kernel.cuh b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl_kernel.cuh index 081e9a3de578c..ecd06315e3708 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl_kernel.cuh +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl_kernel.cuh @@ -21,9 +21,9 @@ // Licensed under the MIT License. 
#pragma once #include +#include #include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/cu_inc/common.cuh" -#include "contrib_ops/cuda/diffusion/group_norm_impl.h" using namespace onnxruntime::cuda; @@ -54,11 +54,21 @@ struct GroupSumsOp { } }; -template -inline __device__ void UpdateSum(const T* src, int64_t offset, float& sum, float& sum_sq); +template +inline __device__ void UpdateSum(const T* src, int64_t offset, float& sum, float& sum_sq) { + using VecT = onnxruntime::cuda::aligned_vector; + const VecT input_v = *reinterpret_cast(src + offset); + +#pragma unroll + for (int i = 0; i < ILP; i++) { + const float val = static_cast(input_v.val[i]); + sum += val; + sum_sq += val * val; + } +} template <> -inline __device__ void UpdateSum(const half* src, int64_t offset, float& sum, float& sum_sq) { +inline __device__ void UpdateSum(const half* src, int64_t offset, float& sum, float& sum_sq) { // Fetch two channels per thread. __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); @@ -72,7 +82,7 @@ inline __device__ void UpdateSum(const half* src, int64_t offset, float& sum, fl } template <> -inline __device__ void UpdateSum(const float* src, int64_t offset, float& sum, float& sum_sq) { +inline __device__ void UpdateSum(const float* src, int64_t offset, float& sum, float& sum_sq) { // Fetch two channels per thread. 
float2 f2 = *reinterpret_cast(&src[offset]); @@ -84,13 +94,28 @@ inline __device__ void UpdateSum(const float* src, int64_t offset, float& sum, f } // Sum for SkipGroupNorm: add_out[offset] = src[offset] + skip[skip_offset] + bias[bias_offset] -template +template inline __device__ void AddSkipBias(T* add_out, const T* src, const T* skip, const T* bias, - int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq); + int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) { + using VecT = onnxruntime::cuda::aligned_vector; + const VecT input_v = *reinterpret_cast(src + offset); + const VecT skip_v = *reinterpret_cast(skip + skip_offset); + const VecT bias_v = *reinterpret_cast(bias + bias_offset); + VecT output_v = *reinterpret_cast(add_out + offset); + +#pragma unroll + for (int i = 0; i < ILP; i++) { + output_v.val[i] = input_v.val[i] + skip_v.val[i] + bias_v.val[i]; + const float val = static_cast(output_v.val[i]); + sum += val; + sum_sq += val * val; + } + *(reinterpret_cast(add_out + offset)) = output_v; +} template <> -inline __device__ void AddSkipBias(half* add_out, const half* src, const half* skip, const half* bias, - int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) { +inline __device__ void AddSkipBias(half* add_out, const half* src, const half* skip, const half* bias, + int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) { // Fetch two channels per thread. 
__half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); __half2 s = *reinterpret_cast<__half2 const*>(&skip[skip_offset]); @@ -106,8 +131,8 @@ inline __device__ void AddSkipBias(half* add_out, const half* src, const half* s } template <> -inline __device__ void AddSkipBias(float* add_out, const float* src, const float* skip, const float* bias, - int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) { +inline __device__ void AddSkipBias(float* add_out, const float* src, const float* skip, const float* bias, + int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) { float2 f2 = *reinterpret_cast(&src[offset]); float2 s = *reinterpret_cast(&skip[skip_offset]); float2 b = *reinterpret_cast(&bias[bias_offset]); @@ -121,13 +146,27 @@ inline __device__ void AddSkipBias(float* add_out, const float* src, const float } // Sum for SkipGroupNorm without bias: add_out[offset] = src[offset] + skip[skip_offset] -template +template inline __device__ void AddSkip(T* add_out, const T* src, const T* skip, - int64_t offset, int64_t skip_offset, float& sum, float& sum_sq); + int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) { + using VecT = onnxruntime::cuda::aligned_vector; + const VecT input_v = *reinterpret_cast(src + offset); + const VecT skip_v = *reinterpret_cast(skip + skip_offset); + VecT output_v = *reinterpret_cast(add_out + offset); + +#pragma unroll + for (int i = 0; i < ILP; i++) { + output_v.val[i] = input_v.val[i] + skip_v.val[i]; + const float val = static_cast(output_v.val[i]); + sum += val; + sum_sq += val * val; + } + *(reinterpret_cast(add_out + offset)) = output_v; +} template <> -inline __device__ void AddSkip(half* add_out, const half* src, const half* skip, - int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) { +inline __device__ void AddSkip(half* add_out, const half* src, const half* skip, + int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) { __half2 h2 
= *reinterpret_cast<__half2 const*>(&src[offset]); __half2 s = *reinterpret_cast<__half2 const*>(&skip[skip_offset]); h2 = h2 + s; @@ -140,8 +179,8 @@ inline __device__ void AddSkip(half* add_out, const half* src, const half* skip, } template <> -inline __device__ void AddSkip(float* add_out, const float* src, const float* skip, - int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) { +inline __device__ void AddSkip(float* add_out, const float* src, const float* skip, + int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) { float2 f2 = *reinterpret_cast(&src[offset]); float2 s = *reinterpret_cast(&skip[skip_offset]); f2.x += s.x; @@ -151,8 +190,10 @@ inline __device__ void AddSkip(float* add_out, const float* src, const float* sk sum_sq += f2.x * f2.x + f2.y * f2.y; } -template -__global__ void GroupNormNHWCSumKernel(GroupNormNHWCParams params) { +template +__global__ void GroupNormNHWCSumKernel(T* skip_workspace, float* group_sum_buffer, const T* src, const T* skip, const T* bias, + int32_t channels_per_block, int32_t hw_per_block, int32_t hw, int32_t hwc, int32_t c, + int32_t channels_per_group, int32_t groups, int32_t groups_per_block, bool broadcast_skip) { // The object in charge of doing the sums for the different blocks. typedef cub::BlockScan BlockScan; @@ -166,60 +207,60 @@ __global__ void GroupNormNHWCSumKernel(GroupNormNHWCParams params) { int32_t ni = blockIdx.z; // The channel loaded by that thread. - int32_t ci = blockIdx.x * params.channels_per_block + threadIdx.x * CHANNELS_PER_THREAD; + int32_t ci = blockIdx.x * channels_per_block + threadIdx.x * ILP; - if (ci >= params.c || threadIdx.x * CHANNELS_PER_THREAD >= params.channels_per_block) { + if (ci >= c || threadIdx.x * ILP >= channels_per_block) { return; } // The first activation loaded by that block. - int32_t hw_begin = blockIdx.y * params.hw_per_block; + int32_t hw_begin = blockIdx.y * hw_per_block; // The last activation loaded by that block. 
- int32_t hw_end = min(hw_begin + params.hw_per_block, params.hw); + int32_t hw_end = min(hw_begin + hw_per_block, hw); // The sums. float sum = 0.F; float sum_sq = 0.F; // Iterate over the activations to compute the sums. - int64_t offset = static_cast(ni) * params.hwc + static_cast(hw_begin) * params.c + ci; - if (params.skip != nullptr) { + int64_t offset = static_cast(ni) * hwc + static_cast(hw_begin) * c + ci; + if (skip != nullptr) { // SkipGroupNorm: skip is (n, h, w, c) or (n, 1, 1, c) or (n, c), bias is (c), and add_out is (n, h, w, c) const int64_t bias_offset = static_cast(ci); - T* add_out = params.skip_workspace; - if (params.broadcast_skip) { - const int64_t skip_offset = static_cast(ni) * params.c + ci; + T* add_out = skip_workspace; + if (broadcast_skip) { + const int64_t skip_offset = static_cast(ni) * c + ci; - if (params.bias != nullptr) { - for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { - AddSkipBias(add_out, params.src, params.skip, params.bias, offset, skip_offset, bias_offset, sum, sum_sq); + if (bias != nullptr) { + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += c) { + AddSkipBias(add_out, src, skip, bias, offset, skip_offset, bias_offset, sum, sum_sq); } } else { - for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { - AddSkip(add_out, params.src, params.skip, offset, skip_offset, sum, sum_sq); + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += c) { + AddSkip(add_out, src, skip, offset, skip_offset, sum, sum_sq); } } } else { - if (params.bias != nullptr) { - for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { - AddSkipBias(add_out, params.src, params.skip, params.bias, offset, offset, bias_offset, sum, sum_sq); + if (bias != nullptr) { + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += c) { + AddSkipBias(add_out, src, skip, bias, offset, offset, bias_offset, sum, sum_sq); } } else { - for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset 
+= params.c) { - AddSkip(add_out, params.src, params.skip, offset, offset, sum, sum_sq); + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += c) { + AddSkip(add_out, src, skip, offset, offset, sum, sum_sq); } } } } else { // GroupNorm - for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { - UpdateSum(params.src, offset, sum, sum_sq); + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += c) { + UpdateSum(src, offset, sum, sum_sq); } } // The group index relative to the first group within the same block. - int32_t gi = threadIdx.x * CHANNELS_PER_THREAD / params.channels_per_group; + int32_t gi = threadIdx.x * ILP / channels_per_group; // The channel in the group. - int32_t cj = ci % params.channels_per_group; + int32_t cj = ci % channels_per_group; // The data for the summations. GroupSums inp{cj == 0 ? 1 : 0, sum, sum_sq}; @@ -230,7 +271,7 @@ __global__ void GroupNormNHWCSumKernel(GroupNormNHWCParams params) { // Store the results for the groups in shared memory (to produce coalesced stores later). // For each group, only the last thread of that group is picked to save sum to shared memory. - if (cj == params.channels_per_group - CHANNELS_PER_THREAD) { + if (cj == channels_per_group - ILP) { smem[gi] = make_float2(out.sum, out.sum_sq); } @@ -238,20 +279,41 @@ __global__ void GroupNormNHWCSumKernel(GroupNormNHWCParams params) { __syncthreads(); // Threads that have nothing left to do, exit. - if (threadIdx.x >= params.groups_per_block) { + if (threadIdx.x >= groups_per_block) { return; } // The global group index. // Use neighboring threads for coalesced write. 
- int32_t gj = blockIdx.x * params.groups_per_block + threadIdx.x; + int32_t gj = blockIdx.x * groups_per_block + threadIdx.x; - if (gj < params.groups) { + if (gj < groups) { float2 sums = smem[threadIdx.x]; - const int index = (2 * ni) * params.groups + gj; - atomicAdd(¶ms.group_sum_buffer[index], sums.x); - atomicAdd(¶ms.group_sum_buffer[index + params.groups], sums.y); + const int index = (2 * ni) * groups + gj; + atomicAdd(&group_sum_buffer[index], sums.x); + atomicAdd(&group_sum_buffer[index + groups], sums.y); + } +} + +template +__device__ void computeGroupNormVec(const T* src, T* dst, int64_t offset, float mean, float inv_std_dev, + const float* gamma_v, const float* beta_v, bool silu) { + using VecT = onnxruntime::cuda::aligned_vector; + const VecT input_v = *reinterpret_cast(src + offset); + VecT output_v; + +#pragma unroll + for (int i = 0; i < ILP; i++) { + float val = static_cast(input_v.val[i]); + val = (val - mean) * inv_std_dev; + val = gamma_v[i] * val + beta_v[i]; + + if (silu) { + val = val * sigmoid(val); + } + output_v.val[i] = static_cast(val); } + *(reinterpret_cast(dst + offset)) = output_v; } template @@ -307,11 +369,51 @@ __device__ void ComputeGroupNorm(const float* src, float* dst, int64_t offset, f *reinterpret_cast(&dst[offset]) = f2; } -template -__global__ void GroupNormNHWCScaleKernel(GroupNormNHWCParams params) { +template +__device__ void ComputeGroupNormKernel(const T* input, T* dst, int64_t offset, float mean, float inv_std_dev, + const float* gamma, const float* beta, bool use_silu, int32_t c, int32_t ci, int32_t hw_begin, int32_t hw_end) { + using VecF = onnxruntime::cuda::aligned_vector; + + const VecF gamma_v = *reinterpret_cast(gamma + ci); + const VecF beta_v = *reinterpret_cast(beta + ci); + // Iterate over the activations to compute the sums. + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += c) { + // Fetch ILP channels per thread. 
+ computeGroupNormVec(input, dst, offset, mean, inv_std_dev, gamma_v.val, beta_v.val, use_silu); + } +} + +template <> +__device__ void ComputeGroupNormKernel(const float* input, float* dst, int64_t offset, float mean, float inv_std_dev, + const float* gamma, const float* beta, bool use_silu, int32_t c, int32_t ci, int32_t hw_begin, int32_t hw_end) { + // Load gamma/beta. Fetch two per thread. + float2 gamma_f2 = *reinterpret_cast(&gamma[ci]); + float2 beta_f2 = *reinterpret_cast(&beta[ci]); + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += c) { + ComputeGroupNorm(input, dst, offset, mean, inv_std_dev, gamma_f2, beta_f2, use_silu); + } +} + +template <> +__device__ void ComputeGroupNormKernel(const half* input, half* dst, int64_t offset, float mean, float inv_std_dev, + const float* gamma, const float* beta, bool use_silu, int32_t c, int32_t ci, int32_t hw_begin, int32_t hw_end) { + // Load gamma/beta. Fetch two per thread. + float2 gamma_f2 = *reinterpret_cast(&gamma[ci]); + float2 beta_f2 = *reinterpret_cast(&beta[ci]); + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += c) { + ComputeGroupNorm(input, dst, offset, mean, inv_std_dev, gamma_f2, beta_f2, use_silu); + } +} + +template +__global__ void GroupNormNHWCScaleKernel(T* dst, const T* src, const T* skip, const float* gamma, const float* beta, + const T* skip_workspace, const float* group_sum_buffer, float epsilon, + int32_t c, int32_t channels_per_block, int32_t channels_per_group, + int32_t groups, int32_t hwc, float inv_hw_channels_per_group, + int32_t hw, int32_t hw_per_block, bool use_silu) { // The channel loaded by that thread. 
- int32_t ci = blockIdx.x * params.channels_per_block + threadIdx.x * CHANNELS_PER_THREAD; - if (ci >= params.c || threadIdx.x * CHANNELS_PER_THREAD >= params.channels_per_block) { + int32_t ci = blockIdx.x * channels_per_block + threadIdx.x * ILP; + if (ci >= c || threadIdx.x * ILP >= channels_per_block) { return; } @@ -319,35 +421,29 @@ __global__ void GroupNormNHWCScaleKernel(GroupNormNHWCParams params) { int32_t ni = blockIdx.z; // The group that thread works on. - int32_t gi = ci / params.channels_per_group; + int32_t gi = ci / channels_per_group; // Load the sum and sum of squares for the group. float sum = 0.F, sum_sq = 0.F; - if (gi < params.groups) { - const int index = (2 * ni) * params.groups + gi; - sum = params.group_sum_buffer[index]; - sum_sq = params.group_sum_buffer[index + params.groups]; + if (gi < groups) { + const int index = (2 * ni) * groups + gi; + sum = group_sum_buffer[index]; + sum_sq = group_sum_buffer[index + groups]; } - // Load gamma/beta. Fetch two per thread. - float2 gamma_f2 = *reinterpret_cast(¶ms.gamma[ci]); - float2 beta_f2 = *reinterpret_cast(¶ms.beta[ci]); - // Compute the mean. - float mean = sum * params.inv_hw_channels_per_group; + float mean = sum * inv_hw_channels_per_group; // Compute the variance. - float var = sum_sq * params.inv_hw_channels_per_group - (mean * mean); + float var = sum_sq * inv_hw_channels_per_group - (mean * mean); // Compute the inverse of the stddev. - float inv_std_dev = rsqrtf(var + params.epsilon); + float inv_std_dev = rsqrtf(var + epsilon); - int32_t hw_begin = blockIdx.y * params.hw_per_block; - int32_t hw_end = min(hw_begin + params.hw_per_block, params.hw); + int32_t hw_begin = blockIdx.y * hw_per_block; + int32_t hw_end = min(hw_begin + hw_per_block, hw); - const T* input = (params.skip != nullptr) ? 
params.skip_workspace : params.src; - int64_t offset = static_cast(ni) * params.hwc + static_cast(hw_begin) * params.c + ci; - for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { - ComputeGroupNorm(input, params.dst, offset, mean, inv_std_dev, gamma_f2, beta_f2, params.use_silu); - } + const T* input = (skip != nullptr) ? skip_workspace : src; + int64_t offset = static_cast(ni) * hwc + static_cast(hw_begin) * c + ci; + ComputeGroupNormKernel(input, dst, offset, mean, inv_std_dev, gamma, beta, use_silu, c, ci, hw_begin, hw_end); } } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/grid_sample.cc b/onnxruntime/contrib_ops/cuda/grid_sample.cc index 4c2999c279e0a..2500de39d3536 100644 --- a/onnxruntime/contrib_ops/cuda/grid_sample.cc +++ b/onnxruntime/contrib_ops/cuda/grid_sample.cc @@ -9,22 +9,23 @@ namespace onnxruntime { namespace contrib { namespace cuda { -#define REGISTER_KERNEL_TYPED(T) \ +#define REGISTER_KERNEL_TYPED(T, VERSION, LAYOUT, DOMAIN) \ ONNX_OPERATOR_TYPED_KERNEL_EX( \ GridSample, \ - kMSDomain, \ - 1, \ + DOMAIN, \ + VERSION, \ T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()) \ .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ - GridSample); + onnxruntime::contrib::cuda::GridSample); -REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(float, 1, LAYOUT_NCHW, kMSDomain) +REGISTER_KERNEL_TYPED(float, 16, LAYOUT_NHWC, kMSInternalNHWCDomain) -template -GridSample::GridSample(const OpKernelInfo& info) : CudaKernel(info) { +template +GridSample::GridSample(const OpKernelInfo& info) : CudaKernel(info) { std::string mode_str = info.GetAttrOrDefault("mode", "bilinear"); std::string padding_mode_str = info.GetAttrOrDefault("padding_mode", "zeros"); align_corners_ = static_cast(info.GetAttrOrDefault("align_corners", 0)); @@ -48,8 +49,8 @@ GridSample::GridSample(const OpKernelInfo& info) : CudaKernel(info) { } } -template -Status 
GridSample::ComputeInternal(OpKernelContext* context) const { +template +Status GridSample::ComputeInternal(OpKernelContext* context) const { const Tensor* X = context->Input(0); const auto& dims_input = X->Shape().GetDims(); const Tensor* Grid = context->Input(1); @@ -61,11 +62,13 @@ Status GridSample::ComputeInternal(OpKernelContext* context) const { ORT_ENFORCE(dims_grid[0] == dims_input[0], "Grid batch size ", dims_grid[0], " does not match input batch size ", dims_input[0]); ORT_ENFORCE(dims_grid[3] == 2, "Last dimension of grid: ", dims_grid[3], ", expect 2"); + using Ch = Channels; + TensorShapeVector dims_output(4); - dims_output[0] = dims_input[0]; - dims_output[1] = dims_input[1]; - dims_output[2] = dims_grid[1]; - dims_output[3] = dims_grid[2]; + dims_output[Ch::N] = dims_input[Ch::N]; + dims_output[Ch::C] = dims_input[Ch::C]; + dims_output[Ch::H] = dims_grid[1 /* Grid::H */]; + dims_output[Ch::W] = dims_grid[2 /* Grid::W */]; Tensor* Y = context->Output(0, dims_output); // Return early if the output tensor is going to be of size 0 if (Y->Shape().Size() == 0) { @@ -74,7 +77,7 @@ Status GridSample::ComputeInternal(OpKernelContext* context) const { typedef typename ToCudaType::MappedType CudaT; CudaT* Y_data = reinterpret_cast(Y->MutableData()); - GridSampleImpl( + GridSampleImpl( Stream(context), reinterpret_cast(X->Data()), reinterpret_cast(Grid->Data()), @@ -89,4 +92,8 @@ Status GridSample::ComputeInternal(OpKernelContext* context) const { } } // namespace cuda } // namespace contrib + +namespace cuda { +REGISTER_KERNEL_TYPED(float, 16, LAYOUT_NCHW, kOnnxDomain) +} // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/grid_sample.h b/onnxruntime/contrib_ops/cuda/grid_sample.h index 08ca58c7cc458..16581bfe77482 100644 --- a/onnxruntime/contrib_ops/cuda/grid_sample.h +++ b/onnxruntime/contrib_ops/cuda/grid_sample.h @@ -12,7 +12,7 @@ namespace cuda { using namespace onnxruntime::cuda; -template +template class GridSample 
final : public CudaKernel { public: explicit GridSample(const OpKernelInfo& info); diff --git a/onnxruntime/contrib_ops/cuda/grid_sample_impl.cu b/onnxruntime/contrib_ops/cuda/grid_sample_impl.cu index 8a391eca7e86a..b23da635bc83d 100644 --- a/onnxruntime/contrib_ops/cuda/grid_sample_impl.cu +++ b/onnxruntime/contrib_ops/cuda/grid_sample_impl.cu @@ -50,28 +50,34 @@ __device__ T GsReflect(T x, float x_min, float x_max) { return static_cast(fx); } -template +template __device__ T PixelAtGrid(const T* input_data, int64_t bIdx, int64_t cIdx, int64_t y, int64_t x, - int64_t padding_mode, int64_t N, int64_t C, int64_t H, int64_t W, float border[4]) { + int64_t padding_mode, int64_t N, int64_t C, int64_t H, int64_t W, float border[4]) { T pixel = 0.0f; + + auto PixelOffset = [bIdx, cIdx, C, H, W](int64_t x, int64_t y) -> int64_t { + return Layout == LAYOUT_NCHW + ? (bIdx * C * H * W + cIdx * H * W + y * W + x) + : (bIdx * H * W * C + y * W * C + x * C + cIdx); + }; + if (padding_mode == 0) { // zeros if (x >= 0 && x < W && y >= 0 && y < H) { - pixel = input_data[bIdx * C * H * W + cIdx * H * W + y * W + x]; + pixel = input_data[PixelOffset(x, y)]; } - } else if (padding_mode == 1) { //border + } else if (padding_mode == 1) { // border x = max((int64_t)0, min((int64_t)W - 1, (int64_t)x)); y = max((int64_t)0, min((int64_t)H - 1, (int64_t)y)); - pixel = input_data[bIdx * C * H * W + cIdx * H * W + y * W + x]; + pixel = input_data[PixelOffset(x, y)]; } else { // Reflection - x = (int64_t) GsReflect(x, border[0], border[2]); - y = (int64_t) GsReflect(y, border[1], border[3]); - pixel = input_data[bIdx * C * H * W + cIdx * H * W + y * W + x]; + x = (int64_t)GsReflect(x, border[0], border[2]); + y = (int64_t)GsReflect(y, border[1], border[3]); + pixel = input_data[PixelOffset(x, y)]; } return pixel; } -__device__ void GsGetCubicCoeffs(float x, float coeffs[4]) -{ +__device__ void GsGetCubicCoeffs(float x, float coeffs[4]) { float cubic_alpha = -0.75f; x = abs(x); coeffs[0] = 
(((cubic_alpha * (x + 1) - 5 * cubic_alpha) * (x + 1) + 8 * cubic_alpha) * (x + 1) - 4 * cubic_alpha); @@ -93,7 +99,7 @@ __device__ T GsBicubicInterpolate(T p[4][4], float x, float y) { return pixel; } -template +template __global__ void _GridSampleKernel( const T* input_data, const T* grid_data, @@ -110,16 +116,32 @@ __global__ void _GridSampleKernel( { CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(idx, N * C * H_out * W_out); // extract batch index, channel index, y index, x index for current thread - int BIdx = idx / (C * H_out * W_out ); - int tmpBCnt = BIdx * (C * H_out * W_out); + int BIdx, yIdx, xIdx, cIdx; + if constexpr (Layout == LAYOUT_NCHW) { + BIdx = idx / (C * H_out * W_out); + int tmpBCnt = BIdx * (C * H_out * W_out); + + cIdx = (idx - tmpBCnt) / (H_out * W_out); + int tmpCCnt = tmpBCnt + cIdx * (H_out * W_out); - int cIdx = (idx - tmpBCnt) / (H_out * W_out); - int tmpCCnt = tmpBCnt + cIdx * (H_out * W_out); + yIdx = (idx - tmpCCnt) / W_out; + int tmpHCnt = tmpCCnt + yIdx * W_out; - int yIdx = (idx - tmpCCnt) / W_out; - int tmpHCnt = tmpCCnt + yIdx * W_out; + xIdx = (idx - tmpHCnt); + } else { + static_assert(Layout == LAYOUT_NHWC, "Unsupported layout"); - int xIdx = (idx - tmpHCnt); + BIdx = idx / (H_out * W_out * C); + int tmpBCnt = BIdx * (H_out * W_out * C); + + yIdx = (idx - tmpBCnt) / (W_out * C); + int tmpHCnt = tmpBCnt + yIdx * (W_out * C); + + xIdx = (idx - tmpHCnt) / C; + int tmpWCnt = tmpHCnt + xIdx * C; + + cIdx = (idx - tmpWCnt); + } int grid_idx = BIdx * H_out * W_out + yIdx * W_out + xIdx; T grid_X = grid_data[grid_idx * 2 + 0]; @@ -147,8 +169,9 @@ __global__ void _GridSampleKernel( if (grid_x_imgSpace < x_min || grid_x_imgSpace > x_max || grid_y_imgSpace < y_min || grid_y_imgSpace > y_max) { // out of bound if (padding_mode == 1) { // border - grid_x_imgSpace = max(0.0f, min(grid_x_imgSpace, W_in - 1.0f)); - grid_y_imgSpace = max(0.0f, min(grid_y_imgSpace, H_in - 1.0f)); + // Clamping must not be done here, see #10607 + // grid_x_imgSpace = 
max(0.0f, min(grid_x_imgSpace, W_in - 1.0f)); + // grid_y_imgSpace = max(0.0f, min(grid_y_imgSpace, H_in - 1.0f)); } else if (padding_mode == 2) { // reflection grid_x_imgSpace = GsReflect(grid_x_imgSpace, x_min, x_max); grid_y_imgSpace = GsReflect(grid_y_imgSpace, y_min, y_max); @@ -175,10 +198,10 @@ __global__ void _GridSampleKernel( w_lb = w_b * w_l; w_rb = w_b * w_r; - T lt_v = PixelAtGrid(input_data, BIdx, cIdx, y1, x1, padding_mode, N, C, H_in, W_in, border); - T rt_v = PixelAtGrid(input_data, BIdx, cIdx, y1, x2, padding_mode, N, C, H_in, W_in, border); - T lb_v = PixelAtGrid(input_data, BIdx, cIdx, y2, x1, padding_mode, N, C, H_in, W_in, border); - T rb_v = PixelAtGrid(input_data, BIdx, cIdx, y2, x2, padding_mode, N, C, H_in, W_in, border); + T lt_v = PixelAtGrid(input_data, BIdx, cIdx, y1, x1, padding_mode, N, C, H_in, W_in, border); + T rt_v = PixelAtGrid(input_data, BIdx, cIdx, y1, x2, padding_mode, N, C, H_in, W_in, border); + T lb_v = PixelAtGrid(input_data, BIdx, cIdx, y2, x1, padding_mode, N, C, H_in, W_in, border); + T rb_v = PixelAtGrid(input_data, BIdx, cIdx, y2, x2, padding_mode, N, C, H_in, W_in, border); T interpoV = w_lt * lt_v + w_rt * rt_v + w_lb * lb_v + w_rb * rb_v; output_data[outIdx] = interpoV; return; @@ -186,7 +209,8 @@ __global__ void _GridSampleKernel( if (mode == 1) { // nearest int x_n = grid_x_imgSpace; int y_n = grid_y_imgSpace; - output_data[outIdx] = PixelAtGrid(input_data, BIdx, cIdx, y_n, x_n, padding_mode, N, C, H_in, W_in, border); + output_data[outIdx] = + PixelAtGrid(input_data, BIdx, cIdx, y_n, x_n, padding_mode, N, C, H_in, W_in, border); return; } if (mode == 2) { // bicubic @@ -195,7 +219,8 @@ __global__ void _GridSampleKernel( T p[4][4] = {}; // [H][W] for (int64_t h = 0; h < 4; h++) { for (int64_t w = 0; w < 4; w++) { - p[h][w] = PixelAtGrid(input_data, BIdx, cIdx, h + y0, w + x0, padding_mode, N, C, H_in, W_in, border); + p[h][w] = + PixelAtGrid(input_data, BIdx, cIdx, h + y0, w + x0, padding_mode, N, C, H_in, 
W_in, border); } } T dx = grid_x_imgSpace - x0 - 1; @@ -204,7 +229,7 @@ __global__ void _GridSampleKernel( } } -template +template void GridSampleImpl( cudaStream_t stream, const T* input_data, @@ -216,17 +241,23 @@ void GridSampleImpl( const int64_t H_out, const int64_t W_out, T* output_data) { - int blocksPerGrid = (int)(ceil(static_cast(dims[0] * dims[1] * H_out * W_out) / GridDim::maxThreadsPerBlock)); - _GridSampleKernel<<>>( - input_data, grid_data, mode, padding_mode, align_corners, dims[0], dims[1], dims[2], dims[3], H_out, W_out, output_data); + using Ch = Channels; + + int blocksPerGrid = static_cast( + ceil(static_cast(dims[Ch::N] * dims[Ch::C] * H_out * W_out) / GridDim::maxThreadsPerBlock)); + _GridSampleKernel<<>>( + input_data, grid_data, mode, padding_mode, align_corners, + dims[Ch::N], dims[Ch::C], dims[Ch::H], dims[Ch::W], + H_out, W_out, output_data); } -#define SPECIALIZED_IMPL(T) \ - template void GridSampleImpl(cudaStream_t stream, const T* input_data, const T* grid_data, \ - const int64_t mode, const int64_t padding_mode, const int64_t align_corners, \ - const int64_t[4], const int64_t H_out, const int64_t W_out, T* output_data); +#define SPECIALIZED_IMPL(T, IsNHWC) \ + template void GridSampleImpl(cudaStream_t stream, const T* input_data, const T* grid_data, \ + const int64_t mode, const int64_t padding_mode, const int64_t align_corners, \ + const int64_t[4], const int64_t H_out, const int64_t W_out, T* output_data); -SPECIALIZED_IMPL(float) +SPECIALIZED_IMPL(float, false) // NCHW +SPECIALIZED_IMPL(float, true) // NHWC } // namespace cuda } // namespace contrib diff --git a/onnxruntime/contrib_ops/cuda/grid_sample_impl.h b/onnxruntime/contrib_ops/cuda/grid_sample_impl.h index 6df86ce161908..62cd66a48fa84 100644 --- a/onnxruntime/contrib_ops/cuda/grid_sample_impl.h +++ b/onnxruntime/contrib_ops/cuda/grid_sample_impl.h @@ -8,7 +8,7 @@ namespace onnxruntime { namespace contrib { namespace cuda { -template +template void GridSampleImpl( 
cudaStream_t stream, const T* input_data, diff --git a/onnxruntime/contrib_ops/cuda/inverse.cc b/onnxruntime/contrib_ops/cuda/inverse.cc index 81e161e60642c..9075dda26f86b 100644 --- a/onnxruntime/contrib_ops/cuda/inverse.cc +++ b/onnxruntime/contrib_ops/cuda/inverse.cc @@ -78,9 +78,9 @@ struct Inverse::ComputeImpl { cudaStream_t stream = ort_stream ? static_cast(ort_stream->GetHandle()) : nullptr; // Make a copy of the input which will serve as a workspace as well. - if (std::is_same::value || std::is_same::value) { + if constexpr (std::is_same::value || std::is_same::value) { IAllocatorUniquePtr input_workspace = inst->GetScratchBuffer(input_count, ort_stream); - if (std::is_same::value) { + if constexpr (std::is_same::value) { // Convert from MLFloat16(half) to float Impl_Cast(stream, reinterpret_cast(input.Data()), input_workspace.get(), input_count); } else { @@ -96,7 +96,7 @@ struct Inverse::ComputeImpl { // Need to compute ptrs for output buffers // Output for MLFloat IAllocatorUniquePtr output_ptrs = inst->GetScratchBuffer(n_batches, ort_stream); - if (std::is_same::value) { + if constexpr (std::is_same::value) { IAllocatorUniquePtr ml_float_output = inst->GetScratchBuffer(input_count, ort_stream); ORT_RETURN_IF_ERROR(ComputeMatrixOffsets(stream, ml_float_output.get(), num_batches, rows, output_ptrs)); // Do the inverse @@ -112,7 +112,7 @@ struct Inverse::ComputeImpl { ORT_RETURN_IF_ERROR(CheckForSingularity(stream, info, info_cpu, num_batches)); // We are done here } - } else if (std::is_same::value) { + } else if constexpr (std::is_same::value) { IAllocatorUniquePtr input_workspace = inst->GetScratchBuffer(static_cast(input_count), ort_stream); CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input_workspace.get(), input.Data(), sizeof(double) * input_count, cudaMemcpyDeviceToDevice, stream)); diff --git a/onnxruntime/contrib_ops/cuda/math/complex_mul_impl.cu b/onnxruntime/contrib_ops/cuda/math/complex_mul_impl.cu index ca94477114ee2..47a64502b3480 100644 --- 
a/onnxruntime/contrib_ops/cuda/math/complex_mul_impl.cu +++ b/onnxruntime/contrib_ops/cuda/math/complex_mul_impl.cu @@ -97,8 +97,8 @@ void ComplexMul_Impl( const TArray* rhs_padded_strides, const T* rhs_data, const TArray* fdm_output_strides, - const onnxruntime::cuda::fast_divmod& fdm_H, - const onnxruntime::cuda::fast_divmod& fdm_C, + const onnxruntime::cuda::fast_divmod& /*fdm_H*/, + const onnxruntime::cuda::fast_divmod& /*fdm_C*/, T* output_data, int64_t count, int64_t lhs_size, diff --git a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu index 064b6dd392437..28ab27ee33d10 100644 --- a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu +++ b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu @@ -174,7 +174,7 @@ Status GemmFloat8::ComputeGemm( int32_t dtype_A, int32_t dtype_B, int32_t dtype_C, int32_t dtype_Y, const TensorShape& shape_A, const TensorShape& shape_B, - const TensorShape& shape_C, const TensorShape& shape_Y, + const TensorShape& shape_C, const TensorShape& /*shape_Y*/, bool trans_A, bool trans_B, const void* p_input_a, const void* p_input_b, const void* p_input_c, const void* p_scale_a, const void* p_scale_b, const void* p_scale_y, void* p_output_y, int M, int N, int K, int lda, diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_cutlass_kernel.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_cutlass_kernel.h index bfe30b71170d8..cfe306c2482a5 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_cutlass_kernel.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_cutlass_kernel.h @@ -202,7 +202,7 @@ struct MoeFCGemm { total_rows_before_expert(total_rows_before_expert), gemm_n(gemm_n), gemm_k(gemm_k), - host_problem_sizes(nullptr) { + host_problem_sizes(host_problem_sizes) { if (platform::is_same::value || platform::is_same::value) { assert(weight_scales); } diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h 
b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h index 66950c9b65970..a3dcf0da16b98 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h @@ -20,6 +20,12 @@ #pragma GCC diagnostic ignored "-Wstrict-aliasing" #endif +// Ignore CUTLASS warning C4100: unreferenced formal parameter +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4100) +#endif + #include "cutlass/array.h" #include "cutlass/numeric_conversion.h" #include "cutlass/layout/matrix.h" @@ -36,6 +42,10 @@ #include "layout_traits_helper.h" #include "moe_cutlass_kernel.h" +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + #ifdef __GNUC__ #pragma GCC diagnostic pop #endif @@ -149,10 +159,10 @@ void generic_moe_gemm_kernelLauncher(const T* A, const WeightType* B, const T* w template struct dispatch_stages { - static void dispatch(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, - int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, int num_experts, - CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, - int* occupancy = nullptr) { + static void dispatch(const T* /*A*/, const WeightType* /*B*/, const T* /*weight_scales*/, const T* /*biases*/, + T* /*C*/, int64_t* /*total_rows_before_expert*/, int64_t /*gemm_n*/, int64_t /*gemm_k*/, + int /*num_experts*/, CutlassGemmConfig /*gemm_config*/, int /*multi_processor_count*/, + cudaStream_t /*stream*/, [[maybe_unused]] int* occupancy = nullptr) { std::string err_msg = "Cutlass fpA_intB gemm. 
Not instantiates for arch " + std::to_string(arch::kMinComputeCapability) + " with stages set to " + std::to_string(Stages); ORT_THROW("[FT Error][dispatch_stages::dispatch] " + err_msg); @@ -221,9 +231,10 @@ template < typename T, typename WeightType, typename arch, typename EpilogueTag, typename std::enable_if::value && std::is_same::value>::type* = nullptr> void dispatch_moe_gemm_to_cutlass(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, - int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, - int num_experts, CutlassGemmConfig gemm_config, int sm_version, - int multi_processor_count, cudaStream_t stream, int* occupancy = nullptr) { + int64_t* total_rows_before_expert, int64_t /*total_rows*/, + int64_t gemm_n, int64_t gemm_k, int num_experts, CutlassGemmConfig gemm_config, + int /*sm_version*/, int multi_processor_count, cudaStream_t stream, + int* occupancy = nullptr) { switch (gemm_config.tile_config) { case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: dispatch_gemm_config, @@ -300,8 +311,8 @@ void dispatch_moe_gemm_to_cutlass(const T* A, const WeightType* B, const T* weig template ::value>::type* = nullptr> void dispatch_moe_gemm_to_cutlass(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, - int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, - int num_experts, CutlassGemmConfig gemm_config, int sm_version, + int64_t* total_rows_before_expert, int64_t /*total_rows*/, int64_t gemm_n, int64_t gemm_k, + int num_experts, CutlassGemmConfig gemm_config, int /*sm_version*/, int multi_processor_count, cudaStream_t stream, int* occupancy = nullptr) { switch (gemm_config.tile_config) { case CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8: diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu index f4f2b49032d23..a5b47bcddefbc 100644 --- 
a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu @@ -370,7 +370,7 @@ struct TopkConstants { template void topk_gating_softmax_launcher_helper(const T* input, const bool* finished, T* output, int* indices, int* source_row, - int num_rows, int num_experts, int k, cudaStream_t stream) { + int num_rows, int /*num_experts*/, int k, cudaStream_t stream) { static constexpr unsigned long MAX_BYTES_PER_LDG = 16; static constexpr int BYTES_PER_LDG = std::min((int)MAX_BYTES_PER_LDG, (int)sizeof(T) * EXPERTS); @@ -599,7 +599,7 @@ void CutlassMoeFCRunner::run_moe_fc( static constexpr bool scales_required = std::is_same::value || std::is_same::value; - if (scales_required) { + if constexpr (scales_required) { if (fc1_scales == nullptr) { ORT_THROW("[FT Error][Run MoE FC] Scales expected but scale for first matmul is a null pointer"); } else if (fc2_scales == nullptr) { diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_problem_visitor.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_problem_visitor.h index 00f977c615df6..1de8f6b69642c 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_problem_visitor.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_problem_visitor.h @@ -276,13 +276,13 @@ struct MoeProblemVisitor::ComputeInternal(OpKernelContext* context) const { const Tensor* past_tensor = context->Input(8); AttentionParameters parameters; + parameters.use_tf32 = UseTF32(); + ORT_RETURN_IF_ERROR(CheckInputs(input, weights, bias, @@ -152,7 +154,7 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { CudaT dequant_scale; CudaT input_scale = *(reinterpret_cast(input_scale_tensor->Data())); CudaT weight_scale = *(reinterpret_cast(weight_scale_tensor->Data())); - if (sizeof(T) == 2) { + if constexpr (sizeof(T) == 2) { dequant_scale = __float2half(__half2float(input_scale) * __half2float(weight_scale)); } else { dequant_scale = input_scale * weight_scale; diff --git 
a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu index 6b66f1d84e221..265adf22eeb61 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu @@ -2,10 +2,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include #include #include #include #include +#include #include #include "core/providers/cuda/cu_inc/common.cuh" #include "core/providers/cuda/cuda_common.h" @@ -21,7 +23,7 @@ namespace cuda { __device__ __forceinline__ void DequantizeEightElements(uint32_t values_quant, half scale, half zp, half* output) { half2 scale_half2 = {scale, scale}; - half zp_adjust = -scale * __short2half_rn(zp); + half zp_adjust = -scale * zp; half2 zp_adjust2 = {zp_adjust, zp_adjust}; alignas(16) half2 results[4]; @@ -56,41 +58,95 @@ __device__ __forceinline__ void DequantizeEightElements(uint32_t values_quant, f } template -__global__ void Dequantize4BitsKernel( +__global__ void Dequantize4BitsKernelReOrder( T* output, const uint8_t* quant_data, const T* scale_data, const uint8_t* zero_points, + const int32_t* reorder_idx, int block_size, - int blocks_per_K, - int blocks_per_threadblock, - int total_blks, - int shift) { - int block_id = blockIdx.x * blocks_per_threadblock + ((threadIdx.x * 8) >> shift); - if (block_id >= total_blks) { + int groups_per_K, + int groups_per_threadblock, + int total_groups) { + int group_id = blockIdx.x * groups_per_threadblock + ((threadIdx.x * 8) / block_size); + if (group_id >= total_groups) { return; } - int n_idx = block_id / blocks_per_K; - int kb_idx = block_id % blocks_per_K; - int element_offset = block_id * block_size + ((threadIdx.x * 8) & ((1 << shift) - 1)); + // T __shared__ zero_points_after_reorder[];//K + // T __shared__ scales_after_reorder[]; // K + // const int num_r_per_thread = k / 256; + + const int 
zero_point_shape_x = (groups_per_K + 1) / 2; + const int scales_shape_x = groups_per_K; + int n_idx = group_id / scales_shape_x; + int kb_idx = group_id % scales_shape_x; + int element_offset = group_id * block_size + ((threadIdx.x * 8) & (block_size - 1)); + T* output_i = output + element_offset; + uint32_t quant_value = *(reinterpret_cast(quant_data + element_offset / 2)); + const int32_t* reorder_idx_with_off = reorder_idx + kb_idx * block_size + ((threadIdx.x * 8) & (block_size - 1)); + for (int i = 0; i < 8; i++) { + int32_t rid = reorder_idx_with_off[i]; + T scale = *(scale_data + n_idx * scales_shape_x + rid); + uint8_t zp = 8; + if (zero_points) { + zp = zero_points[n_idx * zero_point_shape_x + rid / 2]; + zp = (rid & 0x01) ? (zp >> 4) : (zp & 0x0f); + } + + if constexpr (std::is_same_v) { + T zp_adjust = -scale * __short2half_rn(zp); + output_i[i] = __uint2half_rn((quant_value >> (4 * i)) & 0xF) * scale + zp_adjust; + } else { + T zp_adjust = -scale * T(zp); + output_i[i] = T((quant_value >> (4 * i)) & 0xF) * scale + zp_adjust; + } + } +} + +template +__global__ void Dequantize4BitsKernel( + T* output, + const uint8_t* quant_data, + const T* scale_data, + const ZeroT* zero_points, + int block_size, + int groups_per_K, + int groups_per_threadblock, + int total_groups) { + int block_id = blockIdx.x * groups_per_threadblock + ((threadIdx.x * 8) / block_size); + if (block_id >= total_groups) { + return; + } + int element_offset = block_id * block_size + ((threadIdx.x * 8) & (block_size - 1)); uint32_t quant_value = *(reinterpret_cast(quant_data + element_offset / 2)); T scale = *(scale_data + block_id); - uint8_t zp = 8; - if (zero_points) { - zp = zero_points[n_idx * ((blocks_per_K + 1)/2) + kb_idx / 2]; - zp = (kb_idx & 0x01) ? 
(zp >> 4) : (zp & 0x0f); + T zero_point_value; + if constexpr (std::is_same_v) { + const int scales_shape_x = groups_per_K; + const int zero_point_shape_x = (groups_per_K + 1) / 2; + int kb_idx = block_id % scales_shape_x; + int n_idx = block_id / scales_shape_x; + uint8_t zp = 8; + if (zero_points) { + zp = zero_points[n_idx * zero_point_shape_x + kb_idx / 2]; + zp = (kb_idx & 0x01) ? (zp >> 4) : (zp & 0x0f); + } + zero_point_value = static_cast(zp); + } else { + zero_point_value = zero_points? *(zero_points + block_id):static_cast(8); } output = output + element_offset; - DequantizeEightElements(quant_value, scale, static_cast(zp), output); + DequantizeEightElements(quant_value, scale, zero_point_value, output); } -template +template Status Dequantize4Bits( T* output, const uint8_t* quant_data, const T* scales_data, - const uint8_t* zero_points, // shape: [N, (block_per_K + 1)/2] + const ZeroT* zero_points, // shape: [N, (block_per_K + 1)/2] + const int32_t* reorder_idx, int k, int n, int block_size, @@ -98,47 +154,79 @@ Status Dequantize4Bits( // k is padded and equal to block_per_K * block_size ORT_ENFORCE(k % block_size == 0, "k must be a multiplier of block_size"); constexpr int element_per_thread = 8; - int blocks_per_threadblock = GridDim::maxThreadsPerBlock * element_per_thread / block_size; - int blocks_per_K = k / block_size; - int total_blks = n * blocks_per_K; - int blocks_per_grid = static_cast(CeilDiv(n * blocks_per_K, blocks_per_threadblock)); - int shift = static_cast(log2f(float(block_size))); - - Dequantize4BitsKernel<<>>( - output, - quant_data, - scales_data, - zero_points, - block_size, - blocks_per_K, - blocks_per_threadblock, - total_blks, - shift); + int groups_per_threadblock = GridDim::maxThreadsPerBlock * element_per_thread / block_size; + int groups_per_K = k / block_size; + int total_groups = n * groups_per_K; // total elemenets in quant_data + int groups_per_grid = static_cast(CeilDiv(total_groups, groups_per_threadblock)); + if 
(!reorder_idx || std::is_same_v) { + Dequantize4BitsKernel<<>>( + output, + quant_data, + scales_data, + zero_points, + block_size, + groups_per_K, + groups_per_threadblock, + total_groups); + } else { + // static_assert(std::is_same_v, "ZeroT must be uint8_t"); + Dequantize4BitsKernelReOrder<<>>( + output, + quant_data, + scales_data, + (const uint8_t*)zero_points, + reorder_idx, + block_size, + groups_per_K, + groups_per_threadblock, + total_groups); + } return Status::OK(); } -template Status Dequantize4Bits( +template Status Dequantize4Bits( float* output, const uint8_t* quant_data, const float* scales_data, const uint8_t* zero_points, + const int32_t* reorder_idx, int k, int n, int block_size, cudaStream_t stream); -template Status Dequantize4Bits( +template Status Dequantize4Bits( half* output, const uint8_t* quant_data, const half* scales_data, const uint8_t* zero_points, + const int32_t* reorder_idx, + int k, + int n, + int block_size, + cudaStream_t stream); +template Status Dequantize4Bits( + float* output, + const uint8_t* quant_data, + const float* scales_data, + const float* zero_points, + const int32_t* reorder_idx, int k, int n, int block_size, cudaStream_t stream); - +template Status Dequantize4Bits( + half* output, + const uint8_t* quant_data, + const half* scales_data, + const half* zero_points, + const int32_t* reorder_idx, + int k, + int n, + int block_size, + cudaStream_t stream); /////////////////////////////////////////////////////////////////////////////// // A more general block-wise dequantization implementation that supports // different block sizes and block orientations (row-wise/column-wise). 
diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh index f9c09c55fd893..580b5087f3fa3 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh @@ -7,18 +7,18 @@ namespace onnxruntime { namespace contrib { namespace cuda { -template +template Status Dequantize4Bits( T* output, const uint8_t* quant_data, const T* scales_data, - const uint8_t* zero_points, + const ZeroT* zero_points, + const int32_t* reorder_idx, int k, int n, int block_size, cudaStream_t stream); - /** * @brief Dequantize a block-wise quantized matrix, and store the result in a * column major matrix for use in subsequent GEMM. This implementation supports diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc index bbcb7de99781f..0534ed6dc7fc0 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc @@ -117,7 +117,8 @@ Status MatMulBnb4::ComputeInternal(OpKernelContext* ctx) const { &zero, reinterpret_cast(Y->MutableData()), helper.Ldc(), - GetDeviceProp())); + GetDeviceProp(), + UseTF32())); } return Status::OK(); diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc index 5b0e61e197014..1cec6f6a12f1c 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc @@ -1,15 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-// -// This module define MatMulFp32Q4 operator, it is basically -// matmul float32 with right hand side being a 2-D matrix -// pre-packed and block-compacted into int4 -// - -#include "core/common/safeint.h" -#include "core/providers/cuda/cuda_kernel.h" -#include "core/providers/cuda/shared_inc/fpgeneric.h" +#include "contrib_ops/cuda/quantization/matmul_nbits.h" + +#include + +#include "core/common/status.h" +#include "core/framework/float16.h" #include "core/providers/cpu/math/matmul_helper.h" #include "matmul_nbits.cuh" #include "dequantize_blockwise.cuh" @@ -19,40 +16,19 @@ namespace contrib { namespace cuda { using namespace onnxruntime::cuda; -template -class MatMulNBits final : public CudaKernel { - public: - MatMulNBits(const OpKernelInfo& info) : CudaKernel(info) { - ORT_ENFORCE(Status::OK() == info.GetAttr("K", &K_)); - ORT_ENFORCE(Status::OK() == info.GetAttr("N", &N_)); - ORT_ENFORCE(Status::OK() == info.GetAttr("block_size", &block_size_)); - ORT_ENFORCE(Status::OK() == info.GetAttr("bits", &nbits_)); - ORT_ENFORCE(nbits_ == 4, - "Only 4b quantization is supported for MatMulNBits op," - " additional bits support is planned."); - } - - Status ComputeInternal(OpKernelContext* context) const override; - - private: - int64_t K_; - int64_t N_; - int64_t block_size_; - int64_t nbits_; - bool column_wise_quant_blk_{true}; -}; - template Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { const Tensor* a = ctx->Input(0); const Tensor* b = ctx->Input(1); const Tensor* scales = ctx->Input(2); const Tensor* zero_points = ctx->Input(3); + const Tensor* reorder_idx = ctx->Input(4); const auto* a_data = a->Data(); const uint8_t* blob_data = b->Data(); const auto* scales_data = scales->Data(); - const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->Data(); + const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw(); + const auto* reorder_idx_data = reorder_idx == nullptr ? 
nullptr : reorder_idx->Data(); typedef typename ToCudaType::MappedType CudaT; @@ -67,76 +43,99 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { // Bail out early if the output is going to be empty if (Y->Shape().Size() == 0) return Status::OK(); - bool is_4bit_done = TryMatMul4Bits( - reinterpret_cast(Y->MutableData()), - reinterpret_cast(a_data), - blob_data, - reinterpret_cast(scales_data), - zero_points_data, - SafeInt(helper.M()), - SafeInt(helper.N()), - SafeInt(helper.K()), - SafeInt(block_size_), - SafeInt(GetDeviceProp().sharedMemPerBlock), - static_cast(ctx->GetComputeStream()->GetHandle())); - if (!is_4bit_done) { - int64_t K_padded = (K_ + block_size_ - 1) / block_size_ * block_size_; - IAllocatorUniquePtr b_data_ptr = GetScratchBuffer(N_ * K_padded, ctx->GetComputeStream()); - auto* b_data = b_data_ptr.get(); - if (column_wise_quant_blk_) { - // column-wise block + bool is_4bit_done = (reorder_idx_data == nullptr) && + (!zero_points || !zero_points->IsDataType()) && + TryMatMul4Bits( + reinterpret_cast(Y->MutableData()), + reinterpret_cast(a_data), + blob_data, + reinterpret_cast(scales_data), + static_cast(zero_points_data), + SafeInt(helper.M()), + SafeInt(helper.N()), + SafeInt(helper.K()), + SafeInt(block_size_), + SafeInt(GetDeviceProp().sharedMemPerBlock), + static_cast(ctx->GetComputeStream()->GetHandle())); + + if (is_4bit_done) { + return Status::OK(); + } + + int64_t K_padded = (K_ + block_size_ - 1) / block_size_ * block_size_; + IAllocatorUniquePtr b_data_ptr = GetScratchBuffer(N_ * K_padded, ctx->GetComputeStream()); + auto* b_data = b_data_ptr.get(); + if (column_wise_quant_blk_) { + if (reorder_idx) { + ORT_ENFORCE(K_padded == reorder_idx->Shape()[0], "K_padded != g_idx->Shape()[0]"); + } + // column-wise block + if ((zero_points && zero_points->IsDataType())) { ORT_RETURN_IF_ERROR(Dequantize4Bits( reinterpret_cast(b_data), blob_data, reinterpret_cast(scales_data), - zero_points_data, + (const CudaT*)zero_points_data, 
+ reorder_idx_data, SafeInt(K_padded), SafeInt(N_), SafeInt(block_size_), static_cast(ctx->GetComputeStream()->GetHandle()))); } else { - // row-wise block - K_padded = K_; - - ORT_RETURN_IF_ERROR(DequantizeBlockwise4b( + ORT_RETURN_IF_ERROR(Dequantize4Bits( reinterpret_cast(b_data), blob_data, reinterpret_cast(scales_data), - zero_points_data, - SafeInt(block_size_), - column_wise_quant_blk_, - SafeInt(K_), + (const uint8_t*)zero_points_data, + reorder_idx_data, + SafeInt(K_padded), SafeInt(N_), + SafeInt(block_size_), static_cast(ctx->GetComputeStream()->GetHandle()))); } + } else { + // row-wise block + K_padded = K_; + + ORT_RETURN_IF_ERROR(DequantizeBlockwise4b( + reinterpret_cast(b_data), + blob_data, + reinterpret_cast(scales_data), + (const uint8_t*)zero_points_data, + SafeInt(block_size_), + column_wise_quant_blk_, + SafeInt(K_), + SafeInt(N_), + static_cast(ctx->GetComputeStream()->GetHandle()))); + } #if 0 - cudaStreamSynchronize(static_cast(ctx->GetComputeStream()->GetHandle())); - T* b_data_cpu = new T[K_ * N_]; - cudaMemcpy(b_data_cpu, b_data, K_ * N_ * sizeof(T), cudaMemcpyDeviceToHost); - delete[] b_data_cpu; +cudaStreamSynchronize(static_cast(ctx->GetComputeStream()->GetHandle())); +T* b_data_cpu = new T[K_ * N_]; +cudaMemcpy(b_data_cpu, b_data, K_ * N_ * sizeof(T), cudaMemcpyDeviceToHost); +delete[] b_data_cpu; #endif - const CudaT alpha = ToCudaType::FromFloat(1.f); - const CudaT zero = ToCudaType::FromFloat(0.f); - - if (helper.OutputOffsets().size() == 1) { - CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( - GetCublasHandle(ctx), - CUBLAS_OP_T, - CUBLAS_OP_N, - SafeInt(helper.N()), - SafeInt(helper.M()), - SafeInt(helper.K()), - &alpha, - reinterpret_cast(b_data), - SafeInt(K_padded), - reinterpret_cast(a_data), - helper.Lda(transa), - &zero, - reinterpret_cast(Y->MutableData()), - helper.Ldc(), - GetDeviceProp())); - } + const CudaT alpha = ToCudaType::FromFloat(1.f); + const CudaT zero = ToCudaType::FromFloat(0.f); + + if 
(helper.OutputOffsets().size() == 1) { + CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( + GetCublasHandle(ctx), + CUBLAS_OP_T, + CUBLAS_OP_N, + SafeInt(helper.N()), + SafeInt(helper.M()), + SafeInt(helper.K()), + &alpha, + reinterpret_cast(b_data), + SafeInt(K_padded), + reinterpret_cast(a_data), + helper.Lda(transa), + &zero, + reinterpret_cast(Y->MutableData()), + helper.Ldc(), + GetDeviceProp(), + UseTF32())); } return Status::OK(); diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h new file mode 100644 index 0000000000000..f5c2c6c4e4fdf --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// +// This module define MatMulNBits operator, it is basically +// matmul float with right hand side being a 2-D matrix +// pre-packed and block-compacted into int4 +// +#pragma once +#include "core/common/safeint.h" +#include "core/providers/cuda/cuda_kernel.h" +#include "core/providers/cuda/shared_inc/fpgeneric.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { +using namespace onnxruntime::cuda; + +template +class MatMulNBits final : public CudaKernel { + public: + MatMulNBits(const OpKernelInfo& info) : CudaKernel(info) { + ORT_ENFORCE(Status::OK() == info.GetAttr("K", &K_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("N", &N_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("block_size", &block_size_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("bits", &nbits_)); + } + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + int64_t K_; + int64_t N_; + int64_t block_size_; + int64_t nbits_; + bool column_wise_quant_blk_{true}; +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc 
b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc index 3cecebedae2f0..12835978536e1 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc @@ -142,7 +142,7 @@ inline void debug_print([[maybe_unused]] const T* arr, std::cout << "========" << name << std::endl; for (size_t i = 0; i < sz; i++) { if (i % w == 0) std::cout << std::endl; - if (std::is_same().value) { + if constexpr (std::is_same::value) { std::cout << (int)buf[i] << ", "; } else { std::cout << buf[i] << ", "; diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention_impl.cu b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention_impl.cu index f4d5a7b404a62..fd4b51f40fb4f 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention_impl.cu @@ -151,7 +151,7 @@ QOrderBatchInt8MatrixTransposeKernel(const int8_t* src, const int8_t* dst, const } } -Status QOrderBatchTransposeInt8Matrix(cudaStream_t stream, const cudaDeviceProp& device_prop, +Status QOrderBatchTransposeInt8Matrix(cudaStream_t stream, const cudaDeviceProp& /*device_prop*/, const int batch_size, const int rows, const int cols, const int8_t* input, int8_t* output) { ORT_ENFORCE(rows % 4 == 0 && cols % 4 == 0, "Matrix rows and cols must be divisible by 4!"); diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_qdq_impl.cu b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_qdq_impl.cu index baff8e76ec73b..e6ac0bc8a5171 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_qdq_impl.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_qdq_impl.cu @@ -389,7 +389,7 @@ QOrderDequantizeKernel_Strict(const int8_t* __restrict__ src, const __half* __re } } -Status
QOrderDequantize_Strict(cudaStream_t stream, const cudaDeviceProp& device_prop, +Status QOrderDequantize_Strict(cudaStream_t stream, const cudaDeviceProp& /*device_prop*/, const int8_t* src, __half* dst, float scale, size_t N) { ORT_RETURN_IF(N & 0x3LL, "N can not divide by 4!"); diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu index a39abefed9cd0..eb1943b59d976 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu @@ -1,11 +1,22 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. + +// cub.cuh includes device/dispatch_radix_sort.cuh which has assignment in conditional expressions +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4706) +#endif +#include +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + +#include + #include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/cu_inc/common.cuh" -#include "cub/util_type.cuh" -#include -#include + #include "contrib_ops/cuda/bert/utils.cuh" #include "contrib_ops/cuda/transformers/generation_cuda_impl.h" diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index bba30805ae1be..7adc2fe0a67ea 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -424,7 +424,7 @@ Status ProcessLogits(const OrtValue& logits, // const bool is_whisper_model = (parameters->model_type == onnxruntime::contrib::transformers::IGenerationParameters::kModelTypeWhisper); if (step == 1 && is_whisper_model && parameters->no_speech_probs) { cuda::LaunchSaveNoSpeechProbs( - (T*)parameters->no_speech_probs, Y_data, batch_size, num_beams, vocab_size, 
parameters->no_speech_token, cuda_stream); + (T*)parameters->no_speech_probs, Y_data, batch_size, num_beams, vocab_size, parameters->no_speech_token_id, cuda_stream); } // NOTE: currently we treat extra decoding ids are same @@ -469,7 +469,15 @@ Status ProcessLogits(const OrtValue& logits, // cudaMemcpyDeviceToHost, cuda_stream)); constexpr int max_initial_timestamp_index = 50; - onnxruntime::contrib::transformers::TimestampLogitsProcessor time_logit_processor(parameters->eos_token_id, max_initial_timestamp_index); + // Token ids are passed below in the order that they appear in the tokenizer + onnxruntime::contrib::transformers::TimestampLogitsProcessor time_logit_processor(parameters->eos_token_id, + parameters->decoder_start_token_id, + parameters->translate_token_id, + parameters->transcribe_token_id, + parameters->start_of_lm_token_id, + parameters->no_timestamps_token_id, + parameters->beginning_timestamp_token_id, + max_initial_timestamp_index); onnxruntime::contrib::transformers::NextTokenScores next_token_scores_timestamp({cpu_next_token_scores_span, batch_beam_size, vocab_size}); CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream)); diff --git a/onnxruntime/contrib_ops/js/fast_gelu.cc b/onnxruntime/contrib_ops/js/fast_gelu.cc new file mode 100644 index 0000000000000..62c538318160d --- /dev/null +++ b/onnxruntime/contrib_ops/js/fast_gelu.cc @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "fast_gelu.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsepSupportedFloatTypes; + +ONNX_OPERATOR_KERNEL_EX( + FastGelu, + kMSDomain, + 1, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedFloatTypes()), + FastGelu); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/fast_gelu.h b/onnxruntime/contrib_ops/js/fast_gelu.h new file mode 100644 index 0000000000000..68c7892741c66 --- /dev/null +++ b/onnxruntime/contrib_ops/js/fast_gelu.h @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsKernel; +JSEP_KERNEL_IMPL(FastGelu, FastGelu); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc index 498a9f5679eb5..25e7567a2e9fc 100644 --- a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc @@ -8,12 +8,14 @@ namespace contrib { namespace js { class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Attention); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasAdd); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasSplitGelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FastGelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FusedConv); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MatMulNBits); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, 
MultiHeadAttention); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasSplitGelu); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasAdd); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, SkipLayerNormalization); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FusedConv); template <> KernelCreateInfo BuildKernelCreateInfo() { @@ -24,13 +26,15 @@ KernelCreateInfo BuildKernelCreateInfo() { Status RegisterJsContribKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo}; + SkipLayerNormalization)>}; for (auto& function_table_entry : function_table) { KernelCreateInfo info = function_table_entry(); diff --git a/onnxruntime/contrib_ops/js/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/js/quantization/matmul_nbits.cc new file mode 100644 index 0000000000000..888db0fd161f2 --- /dev/null +++ b/onnxruntime/contrib_ops/js/quantization/matmul_nbits.cc @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "contrib_ops/js/quantization/matmul_nbits.h" +#include "core/providers/js/js_data_types.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsepSupportedFloatTypes; + +ONNX_OPERATOR_KERNEL_EX( + MatMulNBits, + kMSDomain, + 1, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", JsepSupportedFloatTypes()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + MatMulNBits); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/js/quantization/matmul_nbits.h new file mode 100644 index 0000000000000..cca2c4757765b --- /dev/null +++ b/onnxruntime/contrib_ops/js/quantization/matmul_nbits.h @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsKernel; + +class MatMulNBits final : public JsKernel { + public: + MatMulNBits(const OpKernelInfo& info) : JsKernel(info), + K_{narrow(info.GetAttr("K"))}, + N_{narrow(info.GetAttr("N"))}, + accuracy_level_{info.GetAttrOrDefault("accuracy_level", 0)}, + nbits_{narrow(info.GetAttr("bits"))}, + block_size_{narrow(info.GetAttr("block_size"))} { + ORT_ENFORCE(nbits_ == 4, + "Only 4b quantization is supported for MatMulNBits op, additional bits support is planned."); + ORT_ENFORCE(block_size_ >= 16 && !(block_size_ & (block_size_ - 1)), + "Block size must be a power of 2 and greater than or equal to 16."); + JSEP_INIT_KERNEL_ATTRIBUTE(MatMulNBits, ({ + "k" : $1, + "n" : $2, + "accuracyLevel" : $3, + "bits" : $4, + "blockSize" : $5 + }), + static_cast(K_), + static_cast(N_), + static_cast(accuracy_level_), + static_cast(nbits_), + static_cast(block_size_)); + } + + private: + const size_t K_; + const size_t N_; + const int64_t accuracy_level_; + const 
size_t nbits_; + const size_t block_size_; +}; + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh index 0599318a4022d..be8508670e4b1 100644 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh +++ b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh @@ -31,7 +31,7 @@ using MaskingSpecialization = ck::tensor_operation::device::MaskingSpecializatio using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute; // the interface +using ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute; // the interface using ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle; // the implementation static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; @@ -141,6 +141,35 @@ std::vector, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskDisabled>(); +template <> +std::vector, ck::Tuple<>, + PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, + MaskingSpecialization::MaskOutUpperTriangle>>> +GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>(); + +// fp16, biased, non-masked +template <> +std::vector, ck::Tuple<>, + PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, + MaskingSpecialization::MaskOutUpperTriangle>>> +GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>(); + +// fp16, biased, fp16 masked, basically, two bias +template <> +std::vector, ck::Tuple<>, + PassThrough, PassThrough, 
PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, + MaskingSpecialization::MaskOutUpperTriangle>>> +GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>(); + } // namespace internal } // namespace rocm } // namespace contrib diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16.cu b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16.cu index 181e47f012c99..2e32a6594d164 100644 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16.cu +++ b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16.cu @@ -32,6 +32,27 @@ GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< return instances; } +using NonBiasedNonmaskedCausal = DeviceBatchedGemmSoftmaxGemmPermute< + 2, 1, 1, 1, 1, + F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, + PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, + MaskingSpecialization::MaskOutUpperTriangle>; + +template <> +std::vector> +GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>() { + std::vector> instances; + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, + device_batched_gemm_softmax_gemm_permute_instances< + 2, 1, 1, 1, 1, + F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, + MaskingSpecialization::MaskOutUpperTriangle>{}); + + return instances; +} + } // namespace internal } // namespace rocm } // namespace contrib diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased.cu b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased.cu index 1577bdf397fa5..91da8d9e1f9a8 100644 --- 
a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased.cu +++ b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased.cu @@ -32,6 +32,27 @@ GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< return instances; } +using BiasedNonmaskedCausal = DeviceBatchedGemmSoftmaxGemmPermute< + 2, 1, 1, 1, 1, + F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, + PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, + MaskingSpecialization::MaskOutUpperTriangle>; + +template <> +std::vector> +GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>() { + std::vector> instances; + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, + device_batched_gemm_softmax_gemm_permute_instances< + 2, 1, 1, 1, 1, + F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, + MaskingSpecialization::MaskOutUpperTriangle>{}); + + return instances; +} + } // namespace internal } // namespace rocm } // namespace contrib diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased_biased.cu b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased_biased.cu index 14de59234356b..b08123be18977 100644 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased_biased.cu +++ b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased_biased.cu @@ -32,6 +32,27 @@ GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< return instances; } +using BiasedNonmaskedCausal = DeviceBatchedGemmSoftmaxGemmPermute< + 2, 1, 1, 1, 1, + F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, + PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, + MaskingSpecialization::MaskOutUpperTriangle>; + +template <> +std::vector> 
+GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>() { + std::vector> instances; + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, + device_batched_gemm_softmax_gemm_permute_instances< + 2, 1, 1, 1, 1, + F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, + MaskingSpecialization::MaskOutUpperTriangle>{}); + + return instances; +} + } // namespace internal } // namespace rocm } // namespace contrib diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh index 78983ac95e672..54dda4bfa6d2c 100644 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh +++ b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh @@ -732,122 +732,154 @@ class GemmSoftmaxGemmPermuteTunableOp : public tunable::TunableOp -auto GetCKGemmSoftmaxGemmPermuteTypeStringAndOps() { +template +auto GetArgAndRunInvoker(const U& impl, const V& invoker, const GemmSoftmaxGemmPermuteParams* params) { constexpr const int kNumBiasBuffer = static_cast(USE_BIAS) + static_cast(USE_MASK); using Nop = ck::tensor_operation::element_wise::PassThrough; using Acc0ElementOp = internal::PreSoftmaxAttentionScoreOp; + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + !GemmSoftmaxGemmPermuteTunableOp::IsSupportedMode(params->attention), + "attention mode is not supported, got ", params->attention->mode); + if constexpr (USE_BIAS) { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + params->bias_buffer == nullptr, "biased version only support input with bias"); + } else { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + params->bias_buffer != nullptr, "non-biased version only support input without bias"); + } + if constexpr (USE_MASK) { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + 
!GemmSoftmaxGemmPermuteTunableOp::IsSupportedMaskType(params->attention), + "mask type is not supported, got ", params->attention->mask_type); + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + params->mask_index_buffer == nullptr, "masked version only support input with mask"); + } else { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + params->mask_index_buffer != nullptr, "non-masked version only support input without mask"); + } + + auto attn = params->attention; + const int& G0 = attn->batch_size; + const int& G1 = attn->num_heads; + const int& M = attn->sequence_length; + const int& N = attn->total_sequence_length; + const int& K = attn->head_size; + const int& O = attn->v_head_size; + { + auto [m, n, k, o, batch] = params->GetGemmsMNKOBatch(); + ORT_ENFORCE(M == m && N == n && K == k && O == o && G0 * G1 == batch, "semantic mismatch"); + } + + auto [qs, ks, vs] = GetQkvStrides(attn); + std::vector q_buffer_lengths = {G0, G1, M, K}; + std::vector q_buffer_strides = qs.template ForBNSHCoord>(); + std::vector k_buffer_lengths = {G0, G1, N, K}; + std::vector k_buffer_strides = ks.template ForBNSHCoord>(); + std::vector v_buffer_lengths = {G0, G1, O, N}; + std::vector v_buffer_strides = vs.template ForBNHSCoord>(); + std::vector out_buffer_lengths = {G0, G1, M, O}; + std::vector out_buffer_strides = {M * G1 * O, O, G1 * O, 1}; // permute 0213 + + std::array bias_buffers{}; + std::array, kNumBiasBuffer> bias_lengths{}; + std::array, kNumBiasBuffer> bias_strides{}; + if constexpr (USE_BIAS) { + bias_buffers[0] = const_cast(params->bias_buffer); + bias_lengths[0] = {G0, G1, M, N}; // BN(G0*G1), S(M), T(N) + bias_strides[0] = {G1 * M * N, M * N, N, 1}; + } + if constexpr (USE_MASK) { + bias_buffers[kNumBiasBuffer - 1] = params->workspace_buffer; + bias_lengths[kNumBiasBuffer - 1] = {G0, G1, M, N}; // BN(G0*G1), S(M), T(N) + if (params->mask_index_dims.size() == 2) { // [B,T] + bias_strides[kNumBiasBuffer - 1] = {N, 0, 0, 1}; + } else if (params->mask_index_dims.size() == 
3) { // [B,S,T] + bias_strides[kNumBiasBuffer - 1] = {M * N, 0, N, 1}; + } else if (params->mask_index_dims.size() == 4) { // [B,1,max_seq_len,max_seq_len] -->convert--> [B,S,T] + bias_strides[kNumBiasBuffer - 1] = {M * N, 0, N, 1}; + } else { + ORT_ENFORCE(false, "Unreachable"); + } + } + + auto arg = impl->MakeArgumentPointer( + params->q_buffer, params->k_buffer, params->v_buffer, params->out_buffer, + bias_buffers, // Gemm1 bias, as attention mask + {}, // Gemm2 bias + q_buffer_lengths, q_buffer_strides, + k_buffer_lengths, k_buffer_strides, + v_buffer_lengths, v_buffer_strides, + out_buffer_lengths, out_buffer_strides, + bias_lengths, bias_strides, + {}, + {}, + Nop{}, + Nop{}, + Acc0ElementOp{params->scale}, + Nop{}, + Nop{}); + + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), + impl->GetTypeString(), " does not support the params"); + + if constexpr (USE_MASK) { + ORT_RETURN_IF_ERROR(GemmSoftmaxGemmPermuteTunableOp::LaunchConvertToFilledMaskValue(params)); + } + + invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); + return Status::OK(); +} + +template +auto GetCKGemmSoftmaxGemmPermuteTypeStringAndOps() { using CKDataType = typename CKDataTypeAdaptor::type; using D0DataType = typename ck::detail::tuple_concat< std::conditional_t, ck::Tuple<>>, std::conditional_t, ck::Tuple<>>>::type; - constexpr static auto MaskingSpec = + constexpr static auto MaskingSpecMaskDisabled = ck::tensor_operation::device::MaskingSpecialization::MaskDisabled; + constexpr static auto MaskingSpecMaskOutUpperTriangle = + ck::tensor_operation::device::MaskingSpecialization::MaskOutUpperTriangle; + + std::vector>>> + ret; - std::vector>>> ret; for (auto&& impl : internal::GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - CKDataType, D0DataType, internal::F32, internal::PreSoftmaxAttentionScoreOp, MaskingSpec>()) { + CKDataType, D0DataType, internal::F32, internal::PreSoftmaxAttentionScoreOp, MaskingSpecMaskDisabled>()) { auto type_string = 
impl->GetTypeString(); auto invoker = impl->MakeInvokerPointer(); auto op = [impl = std::move(impl), invoker = std::move(invoker)]( const GemmSoftmaxGemmPermuteParams* params) -> Status { TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !GemmSoftmaxGemmPermuteTunableOp::IsSupportedMode(params->attention), - "attention mode is not supported, got ", params->attention->mode); - if constexpr (USE_BIAS) { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->bias_buffer == nullptr, "biased version only support input with bias"); - } else { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->bias_buffer != nullptr, "non-biased version only support input without bias"); - } - if constexpr (USE_MASK) { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !GemmSoftmaxGemmPermuteTunableOp::IsSupportedMaskType(params->attention), - "mask type is not supported, got ", params->attention->mask_type); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->mask_index_buffer == nullptr, "masked version only support input with mask"); - } else { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->mask_index_buffer != nullptr, "non-masked version only support input without mask"); - } + params->attention->is_unidirectional, "unidirectional attention is not supported with MaskingSpecMaskDisabled"); - auto attn = params->attention; - const int& G0 = attn->batch_size; - const int& G1 = attn->num_heads; - const int& M = attn->sequence_length; - const int& N = attn->total_sequence_length; - const int& K = attn->head_size; - const int& O = attn->v_head_size; - { - auto [m, n, k, o, batch] = params->GetGemmsMNKOBatch(); - ORT_ENFORCE(M == m && N == n && K == k && O == o && G0 * G1 == batch, "semantic mismatch"); - } + return GetArgAndRunInvoker(impl, invoker, params); + }; + ret.emplace_back(std::make_pair(std::move(type_string), std::move(op))); + } - auto [qs, ks, vs] = GetQkvStrides(attn); - std::vector q_buffer_lengths = {G0, G1, M, K}; - std::vector q_buffer_strides = qs.template 
ForBNSHCoord>(); - std::vector k_buffer_lengths = {G0, G1, N, K}; - std::vector k_buffer_strides = ks.template ForBNSHCoord>(); - std::vector v_buffer_lengths = {G0, G1, O, N}; - std::vector v_buffer_strides = vs.template ForBNHSCoord>(); - std::vector out_buffer_lengths = {G0, G1, M, O}; - std::vector out_buffer_strides = {M * G1 * O, O, G1 * O, 1}; // permute 0213 - - std::array bias_buffers{}; - std::array, kNumBiasBuffer> bias_lengths{}; - std::array, kNumBiasBuffer> bias_strides{}; - if constexpr (USE_BIAS) { - bias_buffers[0] = const_cast(params->bias_buffer); - bias_lengths[0] = {G0, G1, M, N}; // BN(G0*G1), S(M), T(N) - bias_strides[0] = {G1 * M * N, M * N, N, 1}; - } - if constexpr (USE_MASK) { - bias_buffers[kNumBiasBuffer - 1] = params->workspace_buffer; - bias_lengths[kNumBiasBuffer - 1] = {G0, G1, M, N}; // BN(G0*G1), S(M), T(N) - if (params->mask_index_dims.size() == 2) { // [B,T] - bias_strides[kNumBiasBuffer - 1] = {N, 0, 0, 1}; - } else if (params->mask_index_dims.size() == 3) { // [B,S,T] - bias_strides[kNumBiasBuffer - 1] = {M * N, 0, N, 1}; - } else if (params->mask_index_dims.size() == 4) { // [B,1,max_seq_len,max_seq_len] -->convert--> [B,S,T] - bias_strides[kNumBiasBuffer - 1] = {M * N, 0, N, 1}; - } else { - ORT_ENFORCE(false, "Unreachable"); - } - } + for (auto&& impl : internal::GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + CKDataType, D0DataType, internal::F32, internal::PreSoftmaxAttentionScoreOp, MaskingSpecMaskOutUpperTriangle>()) { + auto type_string = impl->GetTypeString(); - auto arg = impl->MakeArgumentPointer( - params->q_buffer, params->k_buffer, params->v_buffer, params->out_buffer, - bias_buffers, // Gemm1 bias, as attention mask - {}, // Gemm2 bias - q_buffer_lengths, q_buffer_strides, - k_buffer_lengths, k_buffer_strides, - v_buffer_lengths, v_buffer_strides, - out_buffer_lengths, out_buffer_strides, - bias_lengths, bias_strides, - {}, - {}, - Nop{}, - Nop{}, - Acc0ElementOp{params->scale}, - Nop{}, - Nop{}); - - 
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support the params"); - - if constexpr (USE_MASK) { - ORT_RETURN_IF_ERROR(GemmSoftmaxGemmPermuteTunableOp::LaunchConvertToFilledMaskValue(params)); - } - invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); - return Status::OK(); + auto invoker = impl->MakeInvokerPointer(); + auto op = [impl = std::move(impl), invoker = std::move(invoker)]( + const GemmSoftmaxGemmPermuteParams* params) -> Status { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + !params->attention->is_unidirectional, "bidirectional attention is not supported with MaskingSpecMaskOutUpperTriangle"); + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + params->attention->sequence_length != params->attention->total_sequence_length, + "seqence_length != total_seqence_length is not supported with MaskingSpecMaskOutUpperTriangle"); + + return GetArgAndRunInvoker(impl, invoker, params); }; ret.emplace_back(std::make_pair(std::move(type_string), std::move(op))); } + return ret; } #endif // USE_COMPOSABLE_KERNEL diff --git a/onnxruntime/contrib_ops/rocm/bert/fast_gelu.cc b/onnxruntime/contrib_ops/rocm/bert/fast_gelu.cc deleted file mode 100644 index 9cb414e4e8980..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/fast_gelu.cc +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "contrib_ops/rocm/bert/fast_gelu.h" - -#include "core/providers/rocm/rocm_common.h" -#include "core/providers/rocm/miopen_common.h" -#include "contrib_ops/cpu/bert/bias_gelu_helper.h" -#include "contrib_ops/rocm/bert/elementwise.h" -#include "contrib_ops/rocm/bert/transformer_common.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - FastGelu, \ - kMSDomain, \ - 1, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - FastGelu); - -REGISTER_KERNEL_TYPED(float) -REGISTER_KERNEL_TYPED(MLFloat16) -REGISTER_KERNEL_TYPED(BFloat16) - -using namespace ONNX_NAMESPACE; - -template -Status FastGelu::ComputeInternal(OpKernelContext* context) const { - ORT_RETURN_IF_ERROR(bias_gelu_helper::CheckInputs(context)); - - const Tensor* input = context->Input(0); - const Tensor* bias = context->Input(1); - Tensor* output = context->Output(0, input->Shape()); - - int64_t input_length = input->Shape().Size(); - if (input_length == 0) { - return Status::OK(); - } - int64_t bias_length = (nullptr == bias) ? 0 : bias->Shape().Size(); - typedef typename ToHipType::MappedType HipT; - - const HipT* input_buffer = reinterpret_cast(input->Data()); - const HipT* bias_buffer = (nullptr != bias) ? reinterpret_cast(bias->Data()) : nullptr; - return LaunchElementwiseKernel( - GetTuningContext(), context->GetComputeStream(), - input_buffer, static_cast(input_length), - bias_buffer, static_cast(bias_length), - reinterpret_cast(output->MutableData())); -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/fast_gelu.h b/onnxruntime/contrib_ops/rocm/bert/fast_gelu.h deleted file mode 100644 index 42bfe5a0b0246..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/fast_gelu.h +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright (c) Microsoft Corporation. 
All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/common/common.h" -#include "core/providers/rocm/rocm_kernel.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -using namespace onnxruntime::rocm; - -template -class FastGelu final : public RocmKernel { - public: - FastGelu(const OpKernelInfo& op_kernel_info) : RocmKernel(op_kernel_info) {} - Status ComputeInternal(OpKernelContext* ctx) const override; -}; - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc b/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc deleted file mode 100644 index e82e15a304f4c..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc +++ /dev/null @@ -1,152 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/providers/rocm/rocm_common.h" -#include "contrib_ops/rocm/diffusion/group_norm.h" -#include "contrib_ops/rocm/diffusion/group_norm_impl.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -#define GROUP_NORM_TYPES float, MLFloat16 - -ONNX_OPERATOR_KERNEL_EX( - GroupNorm, kMSDomain, 1, kRocmExecutionProvider, - (*KernelDefBuilder::Create()).TypeConstraint("T", BuildKernelDefConstraints()), GroupNorm); - -using namespace ONNX_NAMESPACE; - -namespace { -template -struct DispatchGroupNorm { - Status operator()(RocmTuningContext* tuning_ctx, - Stream* stream, - Tensor* output, - const Tensor* input, - const Tensor* gamma, - const Tensor* beta, - void* workspace, - float epsilon, - int batch_size, - int num_channels, - int height, - int width, - int num_groups, - bool use_swish_activation) { - typedef typename ToHipType::MappedType HipT; - return LaunchGroupNormKernel( - tuning_ctx, - stream, - reinterpret_cast(output->MutableData()), - reinterpret_cast(input->Data()), - gamma->Data(), - beta->Data(), - workspace, - epsilon, - 
batch_size, - num_channels, - height, - width, - num_groups, - use_swish_activation); - } -}; - -} // namespace - -GroupNorm::GroupNorm(const OpKernelInfo& op_info) : RocmKernel(op_info) { - epsilon_ = op_info.GetAttrOrDefault("epsilon", 1e-5f); - ORT_ENFORCE(epsilon_ >= 0); - - int64_t num_groups; - ORT_ENFORCE(op_info.GetAttr("groups", &num_groups).IsOK()); - ORT_ENFORCE(num_groups >= 0); - num_groups_ = static_cast(num_groups); - - int64_t activation; - ORT_ENFORCE(op_info.GetAttr("activation", &activation).IsOK()); - ORT_ENFORCE(activation == 0 || activation == 1); // 0 is None, 1 is Swish - use_swish_activation_ = (activation == 1); - - channels_last_ = (op_info.GetAttrOrDefault("channels_last", static_cast(1)) != 0); -} - -Status GroupNorm::PrePack(const Tensor& /*tensor*/, int /*input_idx*/, AllocatorPtr /*alloc*/, - bool& is_packed, PrePackedWeights* /*prepacked_weights*/) { - is_packed = false; - return Status::OK(); -} - -Status GroupNorm::ComputeInternal(OpKernelContext* context) const { - const Tensor* input = context->Input(0); - const Tensor* gamma = context->Input(1); - const Tensor* beta = context->Input(2); - Tensor* output = context->Output(0, input->Shape()); - - if (!channels_last_) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "only the channels_last layout is supported"); - } - - const auto& input_dims = input->Shape().GetDims(); - if (input_dims.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "input is expected to have 4 dimensions, got ", input_dims.size()); - } - - const auto& gamma_dims = gamma->Shape().GetDims(); - if (gamma_dims.size() != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "gamma is expected to have 1 dimension, got ", gamma_dims.size()); - } - if (gamma_dims[0] != input_dims[3]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Number of channels in gamma and input does not match"); - } - - const auto& beta_dims = beta->Shape().GetDims(); - if 
(beta_dims.size() != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "beta is expected to have 1 dimension, got ", beta_dims.size()); - } - if (beta_dims[0] != input_dims[3]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Number of channels in beta and input does not match"); - } - - // Input and output format is NHWC - int batch_size = static_cast(input_dims[0]); - int num_channels = static_cast(input_dims[3]); - int height = static_cast(input_dims[1]); - int width = static_cast(input_dims[2]); - - if (num_channels % num_groups_ != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "number of channels should be divisible by num_groups"); - } - - if (context->GetUseDeterministicCompute()) { - static std::once_flag log_warning; - std::call_once(log_warning, []() { - LOGS_DEFAULT(WARNING) << "GroupNorm has no deterministic GPU kernel, its outputs may still be nondeterministic."; - }); - } - - auto workspace = GetScratchBuffer(GetGroupNormWorkspaceSizeInBytes(), context->GetComputeStream()); - - utils::MLTypeCallDispatcher dispatcher(input->GetElementType()); - return dispatcher.InvokeRet(GetTuningContext(), context->GetComputeStream(), - output, input, gamma, beta, workspace.get(), - epsilon_, - batch_size, - num_channels, - height, - width, - num_groups_, - use_swish_activation_); -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh index fb7091592c16e..d0a0d09fcbae3 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh @@ -26,13 +26,18 @@ namespace rocm { using onnxruntime::rocm::CKDataTypeAdaptor; -using Swish = ck::tensor_operation::element_wise::Swish; +// The SiLU function is a special case of Swish function, +// The Swish function is parametrized by b, which is set to 1.0 for SiLU. 
They are defined as: +// SiLU(x) = x * sigmoid(x) +// Swish(x) = x * sigmoid(bx) +// The default value of b is 1.0 in ck::tensor_operation::element_wise::Swish function. We treat them as the same function here. +using Silu = ck::tensor_operation::element_wise::Swish; using Pass = ck::tensor_operation::element_wise::PassThrough; constexpr int Rank = 5; constexpr int NumReduceDim = 3; -template +template auto GetCKGroupNormNHWCTypeStringAndOps() { using XDataType = typename CKDataTypeAdaptor::type; using YDataType = typename CKDataTypeAdaptor::type; @@ -40,26 +45,30 @@ auto GetCKGroupNormNHWCTypeStringAndOps() { using GammaDataType = float; using BetaDataType = float; - using Activation = std::conditional_t; + using Activation = std::conditional_t; - std::vector>>> ret; + std::vector>>> ret; for (auto&& impl : internal::GetDeviceGroupNormInstances()) { - std::string swish_suffix = WithSwish ? "_Swish" : "_Pass"; - auto type_string = onnxruntime::MakeString(impl->GetTypeString()) + swish_suffix; + std::string silu_suffix = WithSilu ? 
"_Silu" : "_Pass"; + auto type_string = onnxruntime::MakeString(impl->GetTypeString()) + silu_suffix; auto invoker = impl->MakeInvokerPointer(); - auto ck_group_norm_op = [impl = std::move(impl), invoker = std::move(invoker)](const GroupNormNHWCParams* params) -> Status { - if constexpr (WithSwish) { + auto ck_group_norm_op = [impl = std::move(impl), invoker = std::move(invoker)]( + const GroupNormNHWCTunableParams* params) -> Status { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF((params->skip != nullptr || params->bias != nullptr), + "Input skip or bias is not supported by composable kernel."); + if constexpr (WithSilu) { TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !params->withSwish, "Swish version only support groupnorm with swish"); + !params->use_silu, "Silu version only support groupnorm with silu"); } else { TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->withSwish, "Pass version only support groupnorm without swish"); + params->use_silu, "Pass version only support groupnorm without silu"); } - std::vector in_lengths{params->n, params->h, params->w, params->groups, params->cPerGroup}; - std::vector in_out_strides{params->h * params->w * params->c, params->w * params->c, params->c, params->cPerGroup, 1}; - std::vector gamma_beta_strides{0, 0, 0, params->cPerGroup, 1}; + std::vector in_lengths{params->n, params->h, params->w, params->groups, params->channels_per_group}; + std::vector in_out_strides{params->h * params->w * params->c, params->w * params->c, + params->c, params->channels_per_group, 1}; + std::vector gamma_beta_strides{0, 0, 0, params->channels_per_group, 1}; std::vector reduce_dims{1, 2, 4}; auto activation = Activation{}; diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh index 19b081881dcec..4cb371fdcf960 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh +++ 
b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh @@ -18,7 +18,7 @@ namespace internal { using F16 = ck::half_t; using F32 = float; -using Swish = ck::tensor_operation::element_wise::Swish; +using Silu = ck::tensor_operation::element_wise::Swish; using Pass = ck::tensor_operation::element_wise::PassThrough; using ck::tensor_operation::device::DeviceNormalizationFwd; // the interface @@ -101,9 +101,9 @@ GetDeviceGroupNormInstances() { template <> std::vector>> + F16, F32, F32, F16, F32, Silu, 5, 3>>> GetDeviceGroupNormInstances< - F16, F32, F32, F16, F32, Swish, 5, 3>(); + F16, F32, F32, F16, F32, Silu, 5, 3>(); template <> std::vector std::vector>> + F32, F32, F32, F32, F32, Silu, 5, 3>>> GetDeviceGroupNormInstances< - F32, F32, F32, F32, F32, Swish, 5, 3>(); + F32, F32, F32, F32, F32, Silu, 5, 3>(); template <> std::vector -std::vector>> -GetDeviceGroupNormInstances() { - std::vector>> instances; +std::vector>> +GetDeviceGroupNormInstances() { + std::vector>> instances; ck::tensor_operation::device::instance::add_device_operation_instances( instances, - device_normalization_f16_instances{}); + device_normalization_f16_instances{}); return instances; } diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu index 9b0ccab17b4c1..ceb53ed442abc 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu @@ -11,12 +11,12 @@ namespace rocm { namespace internal { template <> -std::vector>> -GetDeviceGroupNormInstances() { - std::vector>> instances; +std::vector>> +GetDeviceGroupNormInstances() { + std::vector>> instances; ck::tensor_operation::device::instance::add_device_operation_instances( instances, - device_normalization_f32_instances{}); + device_normalization_f32_instances{}); return instances; } diff --git 
a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h index 008ae20b0561f..7cff640db2f34 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h @@ -8,110 +8,47 @@ #include "core/providers/rocm/cu_inc/common.cuh" #include "core/providers/rocm/rocm_common.h" #include "core/providers/rocm/tunable/rocm_tunable.h" +#include "contrib_ops/rocm/diffusion/group_norm_common_base.h" namespace onnxruntime { namespace contrib { namespace rocm { -using onnxruntime::rocm::CeilDiv; - -int32_t findMaxDivisor(int32_t n, int32_t maxAllowedDivisor) { - int32_t maxDivisor = -1; - for (int32_t i = 1; i <= std::sqrt(n); i++) { - if (n % i == 0) { - int32_t divisor1 = n / i; - int32_t divisor2 = i; - - if (divisor1 > maxDivisor && divisor1 < maxAllowedDivisor) { - maxDivisor = divisor1; - } - if (divisor2 > maxDivisor && divisor2 < maxAllowedDivisor) { - maxDivisor = divisor2; - } - } - } - return maxDivisor; -} - template -struct GroupNormNHWCParams : OpParams { - GroupNormNHWCParams(RocmTuningContext* tuning_ctx, onnxruntime::Stream* stream, T* dst, float* redBuffer, const T* src, const float* gamma, - const float* beta, int32_t n, int32_t h, int32_t w, int32_t c, int32_t groups, float epsilon, bool withSwish) - : OpParams(tuning_ctx, stream), dst(dst), src(src), gamma(gamma), beta(beta), redBuffer(redBuffer), epsilon(epsilon), n(n), h(h), w(w), c(c), groups(groups), withSwish(withSwish) { - int32_t maxBlocksPerHW = 1024; - switch (c) { - case 960: - case 1920: - cPerBlock = 480; - break; - case 512: - case 256: - cPerBlock = 256; - break; - case 128: - cPerBlock = 128; - break; - default: - cPerBlock = 320; - } - - hw = h * w; - const int32_t blocksPerHW = findMaxDivisor(hw, maxBlocksPerHW); - hwPerBlock = CeilDiv(hw, blocksPerHW); - cPerGroup = c / groups; - hwc = hw * c; - invHWC = 1.F / (float)(hw * cPerGroup); - groupsPerBlock = 
cPerBlock / cPerGroup; - } +struct GroupNormNHWCTunableParams : OpParams, GroupNormNHWCParams { + GroupNormNHWCTunableParams(RocmTuningContext* tuning_ctx, + onnxruntime::Stream* ort_stream, + T* output, + T* add_out, + const T* input, + const T* skip, + const T* bias, + const float* gamma, + const float* beta, + float* workspace, + float epsilon, + int batch_size, + int num_channels, + int height, + int width, + int num_groups, + bool use_silu, + bool broadcast_skip, + int channels_per_block) + : OpParams(tuning_ctx, ort_stream), + GroupNormNHWCParams(output, add_out, input, skip, bias, gamma, beta, workspace, epsilon, batch_size, + num_channels, height, width, num_groups, use_silu, broadcast_skip, channels_per_block) {} std::string Signature() const override { - std::string swish_suffix = withSwish ? "_Swish" : "_Pass"; - std::string sig = std::to_string(n) + "_" + std::to_string(h * w) + "_" + std::to_string(c) + "_" + std::to_string(groups) + swish_suffix; + std::string silu_suffix = this->use_silu ? "_silu" : "_pass"; + std::string skip_suffix = this->skip != nullptr ? "_skip" : "_noskip"; + std::string broadcast_suffix = this->broadcast_skip ? "_broadcast" : "_nobroadcast"; + std::string bias_suffix = this->bias != nullptr ? "_bias" : "_nobias"; + std::string sig = std::to_string(this->n) + "_" + std::to_string(this->h * this->w) + "_" + + std::to_string(this->c) + "_" + std::to_string(this->groups) + silu_suffix + + skip_suffix + broadcast_suffix + bias_suffix; return sig; } - - // The output buffer. Layout NHWC. - T* dst; - // The input buffer. Layout NHWC. - T const* src; - // The gamma scaling factor. - float const* gamma; - // The beta term to add in GN. - float const* beta; - // The temporary buffer to do the global parallel reduction. Size: - // BLOCKS_PER_BATCH x C x 2. - float* redBuffer; - float epsilon; - - // The number of instances in the batch. - int32_t n; - // The height and width of each activation map. 
- int32_t h; - int32_t w; - // The number of channels. - int32_t c; - // The number of groups. - int32_t groups; - // Do we apply the Swish activation function? - bool withSwish; - - // Precomputed values and parameters to control the execution of the kernels. - - // The number of activations per instance (h * w) and the number of - // activations per block. - int32_t hw; - int32_t hwPerBlock; - // The number of channels per group and blocks per activation in the C - // dimension. - int32_t cPerBlock; - int32_t cPerGroup; - - // The precomputed stride between instances. - int32_t hwc; - // The inverse of hwc in floats (to compute mean/var). - float invHWC; - // The precomputed number of groups per block. - int32_t groupsPerBlock; }; } // namespace rocm diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu index dbd5009e63676..142aaf14e8d2d 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu @@ -15,9 +15,12 @@ namespace rocm { template Status LaunchGroupNormKernel( RocmTuningContext* tuning_ctx, - Stream* stream, + Stream* ort_stream, T* output, + T* add_out, const T* input, + const T* skip, + const T* bias, const float* gamma, const float* beta, void* workspace, @@ -27,19 +30,26 @@ Status LaunchGroupNormKernel( int height, int width, int num_groups, - bool use_swish_activation) { - if (batch_size > static_cast(kMaxGroupNormBatchSize)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, - "only support batch_size <= 32. 
Got", batch_size); - } + bool use_silu, + bool broadcast_skip, + int channels_per_block) { + GroupNormNHWCTunableParams params(tuning_ctx, ort_stream, output, add_out, input, skip, bias, gamma, beta, + reinterpret_cast(workspace), epsilon, batch_size, num_channels, + height, width, num_groups, use_silu, broadcast_skip, channels_per_block); - if (num_groups != static_cast(kGroupNormNumberOfGroups)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, - "only num_groups=32 is supported. Got", num_groups); + if (params.channels_per_block % params.channels_per_group != 0 || + params.channels_per_block > kMaxSize || + (params.channels_per_group % CHANNELS_PER_THREAD != 0)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "GroupNorm in ROCM does not support the input: n=", batch_size, + " h=", height, + " w=", width, + " c=", num_channels, + " groups=", num_groups); } - GroupNormNHWCParams params(tuning_ctx, stream, output, reinterpret_cast(workspace), input, gamma, beta, - batch_size, height, width, num_channels, num_groups, epsilon, use_swish_activation); + HIP_RETURN_IF_ERROR(hipMemsetAsync( + params.group_sum_buffer, 0, GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups), params.StreamHandle())); if (tuning_ctx->IsTunableOpEnabled()) { static GroupNormNHWCTunableOp op; @@ -50,14 +60,17 @@ Status LaunchGroupNormKernel( } template Status LaunchGroupNormKernel(RocmTuningContext* tuning_ctx, Stream* stream, half* output, - const half* input, const float* gamma, const float* beta, void* workspace, - float epsilon, int batch_size, int num_channels, - int height, int width, int num_groups, bool swish); + half* add_out, const half* input, const half* skip, const half* bias, + const float* gamma, const float* beta, void* workspace, float epsilon, + int batch_size, int num_channels, int height, int width, int num_groups, + bool use_silu, bool broadcast_skip, int channels_per_block); template Status LaunchGroupNormKernel(RocmTuningContext* 
tuning_ctx, Stream* stream, float* output, - const float* input, const float* gamma, const float* beta, void* workspace, - float epsilon, int batch_size, int num_channels, - int height, int width, int num_groups, bool swish); + float* add_out, const float* input, const float* skip, const float* bias, + const float* gamma, const float* beta, void* workspace, float epsilon, + int batch_size, int num_channels, int height, int width, int num_groups, + bool use_silu, bool broadcast_skip, int channels_per_block); + } // namespace rocm } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.h b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.h deleted file mode 100644 index a0f7e0aca5def..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.h +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include - -#include "core/common/common.h" -#include "core/common/status.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" - -using onnxruntime::rocm::tunable::RocmTuningContext; - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -constexpr size_t kMaxGroupNormBatchSize = 32; -constexpr size_t kGroupNormNumberOfGroups = 32; - -constexpr size_t GetGroupNormWorkspaceSizeInBytes() { - // Two buffers for sum and squared sum - return (sizeof(float) * 2) * kMaxGroupNormBatchSize * kGroupNormNumberOfGroups; -} - -template -Status LaunchGroupNormKernel( - RocmTuningContext* tuning_ctx, - Stream* stream, - T* output, // normalized output tensor - const T* input, // input tensor - const float* gamma, // gamma (also known as weight or scale) - const float* beta, // beta (also known as bias) - void* workspace, // Work space - float epsilon, // epsilon used normalization - int batch_size, // N - int num_channels, // C - int height, // H - int width, // W - int num_groups, // 
number of groups - bool use_swish_activation // Whether there is Swish activation after group normalization -); - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl_kernel.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl_kernel.cuh deleted file mode 100644 index d6322a12a9363..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl_kernel.cuh +++ /dev/null @@ -1,213 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -// The ROCm kernel is modified from TensorRT 8.5. -/* - * SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include "core/providers/rocm/cu_inc/common.cuh" -#include "core/providers/rocm/rocm_common.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -static inline __device__ __host__ float sigmoid(float x) { - return 1.F / (1.F + expf(-x)); -} - -struct GroupSums { - // Is it the 1st element of the group? - int32_t flag; - // The sum. - float sum; - // The sum of squares. - float sumSq; -}; - -struct GroupSumsOp { - inline __device__ GroupSums operator()(GroupSums const& a, GroupSums const& b) { - GroupSums dst; - dst.sum = b.flag ? 
b.sum : (a.sum + b.sum); - dst.sumSq = b.flag ? b.sumSq : (a.sumSq + b.sumSq); - dst.flag = a.flag + b.flag; - return dst; - } -}; - -template -inline __device__ void UpdateSum(const T* src, int64_t offset, U& sum, U& sumSq) { - using VecT = onnxruntime::rocm::aligned_vector; - const VecT input_v = *reinterpret_cast(src + offset); - -#pragma unroll - for (int i = 0; i < ILP; i++) { - const U val = static_cast(input_v.val[i]); - sum += val; - sumSq += val * val; - } -} - -template -__global__ void groupNormNHWCSumKernel(const T* src, float* redBuffer, int32_t cPerBlock, int32_t hwPerBlock, int32_t hw, - int32_t hwc, int32_t c, int32_t cPerGroup, int32_t groups, int32_t groupsPerBlock) { - // The object in charge of doing the sums for the different blocks. - typedef hipcub::BlockScan BlockScan; - - // Allocate shared memory for BlockScan. - __shared__ typename BlockScan::TempStorage tempStorage; - // Allocate shared memory for the groups. We could reduce the amount of shared - // memory reserved. - __shared__ float2 smem[ThreadsPerBlock]; - - // The instance in the batch. - int32_t ni = blockIdx.z; - // The channel loaded by that thread (ILP channels per thread). - int32_t ci = blockIdx.x * cPerBlock + threadIdx.x * ILP; - - // The first activation loaded by that block. - int32_t hwBegin = blockIdx.y * hwPerBlock; - // The last activation loaded by that block. - int32_t hwEnd = min(hwBegin + hwPerBlock, hw); - - // The sums. - float sum = 0.F; - float sumSq = 0.F; - - // Iterate over the activations to compute the sums. - if (ci < c) { - for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) { - // The offset. - int64_t offset = static_cast(ni) * hwc + static_cast(hwi) * c + ci; - UpdateSum(src, offset, sum, sumSq); - } - } - - // The group that thread works on and the channel in the group (modulus). - int32_t gi = threadIdx.x * ILP / cPerGroup; - int32_t cj = threadIdx.x * ILP - cPerGroup * gi; - - // The data for the summations. - GroupSums inp{cj == 0 ? 
1 : 0, sum, sumSq}; - - // Do the segmented scan. - GroupSums out; - BlockScan(tempStorage).InclusiveScan(inp, out, GroupSumsOp()); - - // Store the results for the groups in shared memory (to produce coalesced - // stores later). - if (cj == cPerGroup - ILP) { // ILP channels per thread - smem[gi] = make_float2(out.sum, out.sumSq); - } - - // Make sure the data is in shared memory. - __syncthreads(); - - // The global group index. - int32_t gj = blockIdx.x * groupsPerBlock + threadIdx.x; - - // Threads that have nothing left to do, exit. - if (threadIdx.x >= groupsPerBlock || gj >= groups) { - return; - } - - // The first threads (those storing to global memory, load the values). - float2 sums = smem[threadIdx.x]; - - // Store to global memory. - atomicAdd(&redBuffer[(2 * ni + 0) * groups + gj], sums.x); - atomicAdd(&redBuffer[(2 * ni + 1) * groups + gj], sums.y); -} - -template -__device__ void computeGroupNorm(const T* src, T* dst, int64_t offset, U mean, U invStdDev, - const U* gamma_v, const U* beta_v, bool swish) { - using VecT = onnxruntime::rocm::aligned_vector; - const VecT input_v = *reinterpret_cast(src + offset); - VecT output_v; - -#pragma unroll - for (int i = 0; i < ILP; i++) { - U val = static_cast(input_v.val[i]); - val = (val - mean) * invStdDev; - val = gamma_v[i] * val + beta_v[i]; - - if (swish) { - val = val * sigmoid(val); - } - output_v.val[i] = static_cast(val); - } - *(reinterpret_cast(dst + offset)) = output_v; -} - -template -__global__ void groupNormNHWCScaleKernel(T* dst, const T* src, const float* gamma, const float* beta, const float* redBuffer, float epsilon, int32_t c, int32_t cPerBlock, - int32_t cPerGroup, int32_t groups, int32_t hwc, float invHWC, int32_t hw, int32_t hwPerBlock, bool withSwish) { - // The channel loaded by that thread (ILP channels per thread for F16x2). - int32_t ci = blockIdx.x * cPerBlock + threadIdx.x * ILP; - if (ci >= c) { - return; - } - - // The instance in the batch. 
- int32_t ni = blockIdx.z; - - // The group that thread works on and the channel in the group (modulus). - int32_t gi = ci / cPerGroup; - - // Load the sum and sum of squares for the group. - float sum = 0.F, sumSq = 0.F; - if (gi < groups) { - sum = redBuffer[(2 * ni + 0) * groups + gi]; - sumSq = redBuffer[(2 * ni + 1) * groups + gi]; - } - - using VecF = onnxruntime::rocm::aligned_vector; - - const VecF gamma_v = *reinterpret_cast(gamma + ci); - const VecF beta_v = *reinterpret_cast(beta + ci); - - // Compute the mean. - float mean = sum * invHWC; - // Compute the variance. - float var = sumSq * invHWC - (mean * mean); - // Compute the inverse of the stddev. - float invStdDev = var <= 0.F ? 1.F : rsqrtf(var + epsilon); - - // The first activation loaded by that block. - int32_t hwBegin = blockIdx.y * hwPerBlock; - // The last activation loaded by that block. - int32_t hwEnd = min(hwBegin + hwPerBlock, hw); - - // Iterate over the activations to compute the sums. - for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) { - // The src/dst offset. - int64_t offset = (int64_t)ni * hwc + hwi * c + ci; - - // Fetch ILP channels per thread. - computeGroupNorm(src, dst, offset, mean, invStdDev, gamma_v.val, beta_v.val, withSwish); - } -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh index b7b9441ac997d..c6ca16bfdfc80 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh @@ -20,21 +20,21 @@ namespace rocm { namespace { -template +template std::string GetGroupNormTritonGroupName() { std::string ret = "GroupNormTriton_"; - std::string swish_suffix = WithSwish ? "Swish_" : "Pass_"; - ret += swish_suffix; + std::string silu_suffix = WithSilu ? 
"Silu_" : "Pass_"; + ret += silu_suffix; ret += GetDataTypeName(); return ret; } } // namespace -template +template auto GetTritonGroupNormNHWCTypeStringAndOps() { - std::vector>>> ret; - auto group_name = GetGroupNormTritonGroupName(); + std::vector>>> ret; + auto group_name = GetGroupNormTritonGroupName(); auto* kernel_list = GetOrtTritonKernelByGroup(group_name); if (kernel_list == nullptr) { return ret; @@ -45,36 +45,50 @@ auto GetTritonGroupNormNHWCTypeStringAndOps() { auto* metadata = GetOrtTritonKernelMetadata(i); auto block_size = metadata->constants.at("BLOCK_SIZE"); auto hw_size = metadata->constants.at("HW_SIZE"); - auto impl = [i, block_size, hw_size](const GroupNormNHWCParams* params) -> Status { + auto impl = [i, block_size, hw_size](const GroupNormNHWCTunableParams* params) -> Status { TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->cPerGroup > block_size || params->cPerGroup * 2 <= block_size, - "Arg block_size (", block_size, ") is not the next power of 2 of cPerGroup (", params->cPerGroup, ")."); + params->channels_per_group > block_size || params->channels_per_group * 2 <= block_size, + "Arg block_size (", block_size, ") is not the next power of 2 of channels_per_group (", + params->channels_per_group, ")."); TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( params->hw % hw_size != 0, "Arg hw_size (", hw_size, ") is not a divisor of hw (", params->hw, ")."); - if constexpr (WithSwish) { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!params->withSwish, "Swish version does not support GN w/o swish."); + if constexpr (WithSilu) { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!params->use_silu, "Silu version does not support GN w/o silu."); } else { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(params->withSwish, "Pass version does not support GN w/ swish."); + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(params->use_silu, "Pass version does not support GN w/ silu."); } // Construct args for launch kernel struct { - void* X; - void* Y; + const void* src; + 
const void* skip; + const void* bias; + void* out; + void* add_out; const void* gamma; const void* beta; int hw; int c; int c_per_group; float eps; + bool has_skip; + bool has_bias; + bool broadcast_skip; } args = { - (void*)params->src, + (const void*)params->src, + (const void*)params->skip, + (const void*)params->bias, (void*)params->dst, + (void*)params->skip_workspace, (const void*)params->gamma, (const void*)params->beta, params->hw, params->c, - params->cPerGroup, - params->epsilon}; + params->channels_per_group, + params->epsilon, + params->skip != nullptr, + params->bias != nullptr, + params->broadcast_skip, + }; // Grid dim is (batch_count, groups, 1) return LaunchTritonKernel(params->StreamHandle(), i, params->n, params->groups, 1, &args, sizeof(args)); diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py index 56b3a030b289e..5ba96ebc117f0 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py @@ -12,16 +12,22 @@ @triton.jit def group_norm_kernel( input_ptr, + skip_ptr, + bias_ptr, output_ptr, + add_out_ptr, gamma_ptr, beta_ptr, img_size, c, c_per_group, eps, + has_skip, + has_bias, + broadcast_skip, BLOCK_SIZE: tl.constexpr, HW_SIZE: tl.constexpr, - ACTIVATION_SWISH: tl.constexpr, + ACTIVATION_SILU: tl.constexpr, ): row_x = tl.program_id(0) row_y = tl.program_id(1) @@ -36,14 +42,35 @@ def group_norm_kernel( offsets = hw[:, None] * c + cols[None, :] mask = (cols < c_per_group)[None, :] + bias = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + if has_skip: + add_out_ptr += row_x * stride + row_y * c_per_group + if broadcast_skip: + broadcast_skip_ptr = skip_ptr + row_x * c + row_y * c_per_group + bias += tl.load(broadcast_skip_ptr + cols, mask=cols < c_per_group, other=0.0).to(tl.float32) + else: + skip_ptr += row_x * stride + row_y * c_per_group + if has_bias: + bias_ptr += row_y * c_per_group + 
bias += tl.load(bias_ptr + cols, mask=cols < c_per_group, other=0.0).to(tl.float32) + # Calculate mean and variance _sum = tl.zeros([HW_SIZE, BLOCK_SIZE], dtype=tl.float32) _square_sum = tl.zeros([HW_SIZE, BLOCK_SIZE], dtype=tl.float32) for i in range(tl.cdiv(img_size, HW_SIZE)): x_ptr = input_ptr + i * HW_SIZE * c a = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + if has_skip and not broadcast_skip: + s_ptr = skip_ptr + i * HW_SIZE * c + s = tl.load(s_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + a += s + if has_bias or broadcast_skip: + a += bias _sum += a _square_sum += a * a + if has_skip: + add_y_ptr = add_out_ptr + i * HW_SIZE * c + tl.store(add_y_ptr + offsets, a, mask=mask) # Set axis=None (or leave it unspecified) to reduce all axes. # TODO: In older Triton we have to reduce an axis at a time, but in our case @@ -57,12 +84,16 @@ def group_norm_kernel( gamma = tl.load(gamma_ptr + cols, mask=cols < c_per_group).to(tl.float32) beta = tl.load(beta_ptr + cols, mask=cols < c_per_group).to(tl.float32) for i in range(tl.cdiv(img_size, HW_SIZE)): - x_ptr = input_ptr + i * HW_SIZE * c y_ptr = output_ptr + i * HW_SIZE * c - x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + if has_skip: + add_y_ptr = add_out_ptr + i * HW_SIZE * c + x = tl.load(add_y_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + else: + x_ptr = input_ptr + i * HW_SIZE * c + x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32) x_hat = (x - group_mean) * rstd y = x_hat * gamma + beta - if ACTIVATION_SWISH: + if ACTIVATION_SILU: y *= tl.sigmoid(y) tl.store(y_ptr + offsets, y, mask=mask) @@ -71,27 +102,27 @@ def group_norm_kernel( # blocks = [16, 32, 64, 128, 256, 512] # hw_sizes = [8, 16, 32, 64, 128, 256, 512] # but this will result in too many functions and slow down the compilation. 
-with_swish = [True, False] +with_silu = [True, False] dtypes = ["fp32", "fp16"] blocks = [16, 32, 64, 128] hw_sizes = [8, 16, 32, 64, 128, 256] warps = [1, 2, 4, 8, 16] name_pattern = "GroupNormTriton_{}_{}_b{}_hw{}_w{}" -sig_pattern = "*{},*{},*fp32,*fp32,i32,i32,i32,fp32" +sig_pattern = "*{},*{},*{},*{},*{},*fp32,*fp32,i32,i32,i32,fp32,i1,i1,i1" group_pattern = "GroupNormTriton_{}_{}" def get_function_table(): func_table = [] - for swish, dtype, hw_size, warp, b in product(with_swish, dtypes, hw_sizes, warps, blocks): - swish_suffix = "Swish" if swish else "Pass" - name = name_pattern.format(swish_suffix, dtype, b, hw_size, warp) - group = group_pattern.format(swish_suffix, dtype) - sig = sig_pattern.format(dtype, dtype) + for silu, dtype, hw_size, warp, b in product(with_silu, dtypes, hw_sizes, warps, blocks): + silu_suffix = "Silu" if silu else "Pass" + name = name_pattern.format(silu_suffix, dtype, b, hw_size, warp) + group = group_pattern.format(silu_suffix, dtype) + sig = sig_pattern.format(dtype, dtype, dtype, dtype, dtype) kwargs = { "num_warps": warp, - "constants": {"BLOCK_SIZE": b, "HW_SIZE": hw_size, "ACTIVATION_SWISH": int(swish)}, + "constants": {"BLOCK_SIZE": b, "HW_SIZE": hw_size, "ACTIVATION_SILU": int(silu)}, } func_desc = {"name": name, "group": group, "func": group_norm_kernel, "sig": sig, "kwargs": kwargs} func_table.append(func_desc) diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h index 25d820f7ed326..e6831f764b418 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h @@ -20,115 +20,117 @@ namespace rocm { using onnxruntime::rocm::GPU_WARP_SIZE; template -void groupNormNHWCSum(const GroupNormNHWCParams* params) { - // Make sure the values are as we expect. 
- ORT_ENFORCE(params->c % params->cPerBlock == 0 && params->hw % params->hwPerBlock == 0); - // Make sure a group does not span multiple blocks. - ORT_ENFORCE(params->cPerBlock % params->cPerGroup == 0); - +void GroupNormNHWCSum(const GroupNormNHWCTunableParams* params) { dim3 grid; // The number of blocks to compute all the channels. - grid.x = params->c / params->cPerBlock; + grid.x = DivUp(params->c, params->channels_per_block); // The number of blocks to compute all the activations in a given instance. - grid.y = CeilDiv(params->hw, params->hwPerBlock); + grid.y = DivUp(params->hw, params->hw_per_block); // The number of instances. grid.z = params->n; -#define LAUNCH_GROUPNORM_SUM(ThreadsPerBlock, VecSize) \ - groupNormNHWCSumKernel \ - <<StreamHandle()>>>( \ - params->src, params->redBuffer, params->cPerBlock, \ - params->hwPerBlock, params->hw, params->hwc, params->c, \ - params->cPerGroup, params->groups, params->groupsPerBlock); \ +#define LAUNCH_GROUPNORM_SUM(ThreadsPerBlock, VecSize) \ + GroupNormNHWCSumKernel \ + <<StreamHandle()>>>( \ + params->skip_workspace, params->group_sum_buffer, params->src, params->skip, params->bias, \ + params->channels_per_block, params->hw_per_block, params->hw, params->hwc, params->c, \ + params->channels_per_group, params->groups, params->groups_per_block, params->broadcast_skip); \ break; - switch (params->cPerBlock) { - case 320: - LAUNCH_GROUPNORM_SUM(256, 2) - case 480: - LAUNCH_GROUPNORM_SUM(256, 2) + // Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2. 
+ switch (params->threads_per_block) { case 256: - LAUNCH_GROUPNORM_SUM(128, 2) + LAUNCH_GROUPNORM_SUM(256, CHANNELS_PER_THREAD) + case 192: + LAUNCH_GROUPNORM_SUM(192, CHANNELS_PER_THREAD) + case 160: + LAUNCH_GROUPNORM_SUM(160, CHANNELS_PER_THREAD) case 128: - LAUNCH_GROUPNORM_SUM(64, 2) + LAUNCH_GROUPNORM_SUM(128, CHANNELS_PER_THREAD) + case 64: + LAUNCH_GROUPNORM_SUM(64, CHANNELS_PER_THREAD) default: ORT_NOT_IMPLEMENTED("Not implemented"); } } template -Status GroupNormNHWCSumOp(const GroupNormNHWCParams* params) { +Status GroupNormNHWCSumOp(const GroupNormNHWCTunableParams* params) { dim3 grid; - grid.x = params->c / params->cPerBlock; - grid.y = CeilDiv(params->hw, params->hwPerBlock); + grid.x = DivUp(params->c, params->channels_per_block); + grid.y = DivUp(params->hw, params->hw_per_block); grid.z = params->n; - groupNormNHWCSumKernel + GroupNormNHWCSumKernel <<StreamHandle()>>>( - params->src, params->redBuffer, params->cPerBlock, params->hwPerBlock, - params->hw, params->hwc, params->c, params->cPerGroup, params->groups, params->groupsPerBlock); + params->skip_workspace, params->group_sum_buffer, params->src, params->skip, params->bias, + params->channels_per_block, params->hw_per_block, params->hw, params->hwc, params->c, + params->channels_per_group, params->groups, params->groups_per_block, params->broadcast_skip); return HIP_CALL(hipGetLastError()); } template -void groupNormNHWCScale(const GroupNormNHWCParams* params) { - // Make sure the dimensions are aligned with what we expect. - ORT_ENFORCE(params->c % params->cPerBlock == 0); - // Make sure a group does not span multiple blocks. - ORT_ENFORCE(params->cPerBlock % params->cPerGroup == 0); - +void GroupNormNHWCScale(const GroupNormNHWCTunableParams* params) { dim3 grid; // The number of blocks to compute all the channels. - grid.x = params->c / params->cPerBlock; + grid.x = DivUp(params->c, params->channels_per_block); // The number of blocks to compute all the activations in a given instance. 
- grid.y = CeilDiv(params->hw, params->hwPerBlock); + grid.y = DivUp(params->hw, params->hw_per_block); // The number of instances. grid.z = params->n; -#define LAUNCH_GROUPNORM_SCALE(ThreadsPerBlock, VecSize) \ - groupNormNHWCScaleKernel \ - <<StreamHandle()>>>( \ - params->dst, params->src, params->gamma, params->beta, \ - params->redBuffer, params->epsilon, params->c, params->cPerBlock, \ - params->cPerGroup, params->groups, params->hwc, params->invHWC, \ - params->hw, params->hwPerBlock, params->withSwish); \ +#define LAUNCH_GROUPNORM_SCALE(ThreadsPerBlock, VecSize) \ + GroupNormNHWCScaleKernel \ + <<StreamHandle()>>>( \ + params->dst, params->src, params->skip, params->gamma, params->beta, params->skip_workspace, \ + params->group_sum_buffer, params->epsilon, params->c, params->channels_per_block, \ + params->channels_per_group, params->groups, params->hwc, params->inv_hw_channels_per_group, \ + params->hw, params->hw_per_block, params->use_silu); \ break; - switch (params->cPerBlock) { - case 320: - LAUNCH_GROUPNORM_SCALE(256, 2) - case 480: - LAUNCH_GROUPNORM_SCALE(256, 2) + // Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2. 
+ switch (params->threads_per_block) { case 256: - LAUNCH_GROUPNORM_SCALE(128, 2) + LAUNCH_GROUPNORM_SCALE(256, CHANNELS_PER_THREAD) + case 192: + LAUNCH_GROUPNORM_SCALE(192, CHANNELS_PER_THREAD) + case 160: + LAUNCH_GROUPNORM_SCALE(160, CHANNELS_PER_THREAD) case 128: - LAUNCH_GROUPNORM_SCALE(64, 2) + LAUNCH_GROUPNORM_SCALE(128, CHANNELS_PER_THREAD) + case 64: + LAUNCH_GROUPNORM_SCALE(64, CHANNELS_PER_THREAD) default: ORT_NOT_IMPLEMENTED("Not implemented"); } } template -Status GroupNormNHWCScaleOp(const GroupNormNHWCParams* params) { +Status GroupNormNHWCScaleOp(const GroupNormNHWCTunableParams* params) { dim3 grid; - grid.x = params->c / params->cPerBlock; - grid.y = CeilDiv(params->hw, params->hwPerBlock); + grid.x = DivUp(params->c, params->channels_per_block); + grid.y = DivUp(params->hw, params->hw_per_block); grid.z = params->n; - groupNormNHWCScaleKernel + GroupNormNHWCScaleKernel <<StreamHandle()>>>( - params->dst, params->src, params->gamma, params->beta, params->redBuffer, params->epsilon, params->c, params->cPerBlock, - params->cPerGroup, params->groups, params->hwc, params->invHWC, params->hw, params->hwPerBlock, params->withSwish); + params->dst, params->src, params->skip, params->gamma, params->beta, params->skip_workspace, + params->group_sum_buffer, params->epsilon, params->c, params->channels_per_block, params->channels_per_group, + params->groups, params->hwc, params->inv_hw_channels_per_group, params->hw, params->hw_per_block, + params->use_silu); return HIP_CALL(hipGetLastError()); } template class GroupNormNHWCOp { public: - Status operator()(const GroupNormNHWCParams* params) { - HIP_RETURN_IF_ERROR(hipMemsetAsync(params->redBuffer, 0, GetGroupNormWorkspaceSizeInBytes(), params->StreamHandle())); + Status operator()(const GroupNormNHWCTunableParams* params) { + HIP_RETURN_IF_ERROR(hipMemsetAsync(params->group_sum_buffer, + 0, + GetGroupNormWorkspaceSizeInBytes(params->n, params->groups), + params->StreamHandle())); auto status = 
GroupNormNHWCSumOp(params); ORT_RETURN_IF_ERROR(status); HIP_RETURN_IF_ERROR(hipGetLastError()); @@ -138,29 +140,30 @@ class GroupNormNHWCOp { return Status::OK(); } - Status IsSupported(const GroupNormNHWCParams* params) { + Status IsSupported(const GroupNormNHWCTunableParams* params) { TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !(params->c % VecSize == 0 && params->cPerGroup % VecSize == 0), - "The number of channels (", params->c, ") or the number of channels per group (", params->cPerGroup, + !(params->c % VecSize == 0 && params->channels_per_group % VecSize == 0), + "The number of channels (", params->c, ") or the number of channels per group (", params->channels_per_group, ") isn't divisible by the number of vector size: ", VecSize); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!(params->cPerBlock % params->cPerGroup == 0 && - params->c % params->cPerBlock == 0 && params->hw % params->hwPerBlock == 0), - "The value of attributes don't meet the requirements."); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!(params->cPerBlock <= ThreadsPerBlock * VecSize && - params->cPerBlock > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize), + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!(params->channels_per_block <= ThreadsPerBlock * VecSize && + params->channels_per_block > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize), "Configuration: Threads (", ThreadsPerBlock, "), vector size (", - VecSize, ") is redundant for the number of channels per group: ", params->cPerBlock); + VecSize, ") is redundant for the number of channels per group: ", + params->channels_per_block); return Status::OK(); } }; template -Status GroupNormNHWCStaticSelection(const GroupNormNHWCParams* params) { - HIP_RETURN_IF_ERROR(hipMemsetAsync(params->redBuffer, 0, GetGroupNormWorkspaceSizeInBytes(), params->StreamHandle())); - groupNormNHWCSum(params); +Status GroupNormNHWCStaticSelection(const GroupNormNHWCTunableParams* params) { + HIP_RETURN_IF_ERROR(hipMemsetAsync(params->group_sum_buffer, + 0, + 
GetGroupNormWorkspaceSizeInBytes(params->n, params->groups), + params->StreamHandle())); + GroupNormNHWCSum(params); HIP_RETURN_IF_ERROR(hipGetLastError()); - groupNormNHWCScale(params); + GroupNormNHWCScale(params); HIP_RETURN_IF_ERROR(hipGetLastError()); return Status::OK(); } @@ -178,30 +181,30 @@ Status GroupNormNHWCStaticSelection(const GroupNormNHWCParams* params) { ADD_OP_FOR_ALL_VEC_SIZE(name, 320) template -class GroupNormNHWCTunableOp : public TunableOp> { +class GroupNormNHWCTunableOp : public TunableOp> { public: GroupNormNHWCTunableOp() { this->RegisterOp(GroupNormNHWCStaticSelection); ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(GroupNormNHWCOp) #ifdef USE_COMPOSABLE_KERNEL - for (auto&& [_, op] : GetCKGroupNormNHWCTypeStringAndOps()) { + for (auto&& [_, op] : GetCKGroupNormNHWCTypeStringAndOps()) { ORT_UNUSED_PARAMETER(_); this->RegisterOp(std::move(op)); } - for (auto&& [_, op] : GetCKGroupNormNHWCTypeStringAndOps()) { + for (auto&& [_, op] : GetCKGroupNormNHWCTypeStringAndOps()) { ORT_UNUSED_PARAMETER(_); this->RegisterOp(std::move(op)); } #endif // USE_COMPOSABLE_KERNEL #ifdef USE_TRITON_KERNEL - for (auto&& [_, op] : GetTritonGroupNormNHWCTypeStringAndOps()) { + for (auto&& [_, op] : GetTritonGroupNormNHWCTypeStringAndOps()) { ORT_UNUSED_PARAMETER(_); this->RegisterOp(std::move(op)); } - for (auto&& [_, op] : GetTritonGroupNormNHWCTypeStringAndOps()) { + for (auto&& [_, op] : GetTritonGroupNormNHWCTypeStringAndOps()) { ORT_UNUSED_PARAMETER(_); this->RegisterOp(std::move(op)); } diff --git a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc index 55cd6a1d112f5..382a3951f3a83 100644 --- a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc @@ -93,6 +93,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, Samp class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, 
ScaledTanh); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, ScaledTanh); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, SkipGroupNorm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization); @@ -246,6 +247,7 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/common/cpuid_info.cc b/onnxruntime/core/common/cpuid_info.cc index 711fd595e90fd..be881f6bc4bc2 100644 --- a/onnxruntime/core/common/cpuid_info.cc +++ b/onnxruntime/core/common/cpuid_info.cc @@ -52,6 +52,13 @@ #if defined(CPUINFO_SUPPORTED) #include +#if defined(CPUIDINFO_ARCH_ARM) +namespace onnxruntime { +// The following function is declared in "core/common/cpuid_uarch.h" but we cannot include the whole header file because +// some of its symbols are conflict with +void decodeMIDR(uint32_t midr, uint32_t uarch[1]); +} // namespace onnxruntime +#endif #else #include "core/common/cpuid_uarch.h" #endif // CPUINFO_SUPPORTED @@ -142,11 +149,6 @@ void CPUIDInfo::ArmLinuxInit() { // Pytorch CPUINFO only works on ARM linux or android // Assuming no hyper-threading, no NUMA groups #ifdef CPUINFO_SUPPORTED - pytorch_cpuinfo_init_ = cpuinfo_initialize(); - if (!pytorch_cpuinfo_init_) { - LOGS_DEFAULT(WARNING) << "Failed to init pytorch cpuinfo library, may cause CPU EP performance degradation due to undetected CPU features."; - 
return; - } is_hybrid_ = cpuinfo_get_uarchs_count() > 1; has_arm_neon_dot_ = cpuinfo_has_arm_neon_dot(); has_fp16_ = cpuinfo_has_arm_neon_fp16_arith(); @@ -239,52 +241,24 @@ void CPUIDInfo::ArmWindowsInit() { lastUarch = uarch; } } - - switch (lastUarch) { - case cpuinfo_uarch_cortex_a55: - case cpuinfo_uarch_cortex_a55r0: - case cpuinfo_uarch_cortex_a76: - case cpuinfo_uarch_neoverse_n1: - case cpuinfo_uarch_cortex_a77: - case cpuinfo_uarch_exynos_m4: - case cpuinfo_uarch_exynos_m5: - has_fp16_ = true; - break; - default: - break; - } - if (!has_fp16_) { - /* - * Detecting fp16 support. Different cores should have the same instruction set. - * So we just check the first ID_AA64PFR0_EL1 - * Op0(0b11), Op1(0b000), CRn(0b0000), CRm(0b0100), Op2(0b000), - */ - uint64_t ID_AA64PFR0_EL1; - unsigned long valsize = sizeof(uint64_t); - auto retCode = ::RegGetValueA( - HKEY_LOCAL_MACHINE, - "HARDWARE\\DESCRIPTION\\System\\CentralProcessor\\0", - "CP 4020", RRF_RT_REG_QWORD, nullptr, - &ID_AA64PFR0_EL1, &valsize); - if (retCode == ERROR_SUCCESS) { - // AdvSIMD, bits [23:20] - auto advSimd = ID_AA64PFR0_EL1 >> 20; - if ((advSimd & 0xfULL) == 1) { - has_fp16_ = true; - } - } - } #endif /* Application Family or OneCore Family */ has_arm_neon_dot_ = (IsProcessorFeaturePresent(PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE) != 0); #else has_arm_neon_dot_ = false; #endif - has_fp16_ |= has_arm_neon_dot_; - /* TODO: implement them when hw+sw is available for testing these features */ - has_arm_neon_i8mm_ = false; - has_arm_sve_i8mm_ = false; - has_arm_neon_bf16_ = false; + + if (pytorch_cpuinfo_init_) { + has_fp16_ = cpuinfo_has_arm_neon_fp16_arith(); + has_arm_neon_i8mm_ = cpuinfo_has_arm_i8mm(); + has_arm_sve_i8mm_ = cpuinfo_has_arm_sve() && cpuinfo_has_arm_i8mm(); + has_arm_neon_bf16_ = cpuinfo_has_arm_neon_bf16(); + } else { + has_fp16_ = false; + has_arm_neon_i8mm_ = false; + has_arm_sve_i8mm_ = false; + has_arm_neon_bf16_ = false; + } } #endif /* (arm or arm64) and windows */ @@ -304,5 
+278,21 @@ uint32_t CPUIDInfo::GetCurrentCoreIdx() const { return 0xFFFFFFFF; // don't know how to get core index #endif } - +CPUIDInfo::CPUIDInfo() { +#ifdef CPUIDINFO_ARCH_X86 + X86Init(); +#elif defined(CPUIDINFO_ARCH_ARM) +#if CPUINFO_SUPPORTED + pytorch_cpuinfo_init_ = cpuinfo_initialize(); + if (!pytorch_cpuinfo_init_) { + LOGS_DEFAULT(WARNING) << "Failed to init pytorch cpuinfo library, may cause CPU EP performance degradation due to undetected CPU features."; + } +#endif +#ifdef __linux__ + ArmLinuxInit(); +#elif defined(_WIN32) + ArmWindowsInit(); +#endif /* (arm or arm64) and windows */ +#endif +} } // namespace onnxruntime diff --git a/onnxruntime/core/common/cpuid_info.h b/onnxruntime/core/common/cpuid_info.h index 2f8041e39f680..a3936b4bd11a6 100644 --- a/onnxruntime/core/common/cpuid_info.h +++ b/onnxruntime/core/common/cpuid_info.h @@ -93,17 +93,7 @@ class CPUIDInfo { } private: - CPUIDInfo() { -#ifdef CPUIDINFO_ARCH_X86 - X86Init(); -#elif defined(CPUIDINFO_ARCH_ARM) -#ifdef __linux__ - ArmLinuxInit(); -#elif defined(_WIN32) - ArmWindowsInit(); -#endif /* (arm or arm64) and windows */ -#endif - } + CPUIDInfo(); bool has_amx_bf16_{false}; bool has_avx_{false}; bool has_avx2_{false}; @@ -131,11 +121,13 @@ class CPUIDInfo { #ifdef CPUIDINFO_ARCH_X86 void X86Init(); - #elif defined(CPUIDINFO_ARCH_ARM) + // Now the following var is only used in ARM build, but later one we may expand the usage. + bool pytorch_cpuinfo_init_{false}; +#endif + #ifdef __linux__ - bool pytorch_cpuinfo_init_{false}; void ArmLinuxInit(); #elif defined(_WIN32) @@ -143,7 +135,6 @@ class CPUIDInfo { void ArmWindowsInit(); #endif /* (arm or arm64) and windows */ -#endif }; } // namespace onnxruntime diff --git a/onnxruntime/core/common/flatbuffers.h b/onnxruntime/core/common/flatbuffers.h new file mode 100644 index 0000000000000..0d61e1038a82c --- /dev/null +++ b/onnxruntime/core/common/flatbuffers.h @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. 
+// Licensed under the MIT License. + +#pragma once +#if defined(__GNUC__) +#include "onnxruntime_config.h" +#pragma GCC diagnostic push + +#ifdef HAS_SHORTEN_64_TO_32 +#pragma GCC diagnostic ignored "-Wshorten-64-to-32" +#endif +#endif + +#include "flatbuffers/flatbuffers.h" + +#if defined(__GNUC__) +#pragma GCC diagnostic pop +#endif \ No newline at end of file diff --git a/onnxruntime/core/common/string_utils.h b/onnxruntime/core/common/string_utils.h index eca1221e84cb8..716eed1afec51 100644 --- a/onnxruntime/core/common/string_utils.h +++ b/onnxruntime/core/common/string_utils.h @@ -65,5 +65,24 @@ inline std::string TrimString(std::string s) { return s; } +/** + * @brief A consistent way to construct the full qualified op name. + */ +inline std::string GetFullQualifiedOpName(const std::string& op_type, const std::string& domain) { + return MakeString(domain, "::", op_type); +} + +/** + * Use this simple hash to generate unique int by given string input. + */ +inline uint32_t GetHashFromString(const std::string& str_value) { + uint32_t hash = 0; + for (char const& c : str_value) { + hash = hash * 101 + c; + } + + return hash; +} + } // namespace utils } // namespace onnxruntime diff --git a/onnxruntime/core/flatbuffers/checkpoint_version.h b/onnxruntime/core/flatbuffers/checkpoint_version.h index 6cad27c35024b..e6ee20bf508ce 100644 --- a/onnxruntime/core/flatbuffers/checkpoint_version.h +++ b/onnxruntime/core/flatbuffers/checkpoint_version.h @@ -13,7 +13,9 @@ namespace onnxruntime { // The format includes support for the ModuleState (stores the module parameters), OptimizerGroups // (stores the optimizer states), and PropertyBag // (stores custom user properties with support for int64, float and strings). -constexpr const int kCheckpointVersion = 1; +// Version 2: Introduces the On-Device Training nominal checkpoint state. +// Changes include the addition of the is_nominal_state field in the checkpoint's ModuleState. 
+constexpr const int kCheckpointVersion = 2; /** * @brief Check if the given checkpoint version is supported in this build diff --git a/onnxruntime/core/flatbuffers/flatbuffers_utils.h b/onnxruntime/core/flatbuffers/flatbuffers_utils.h index 55bde0b2df806..76860d6ab1db8 100644 --- a/onnxruntime/core/flatbuffers/flatbuffers_utils.h +++ b/onnxruntime/core/flatbuffers/flatbuffers_utils.h @@ -5,7 +5,7 @@ #include -#include "flatbuffers/flatbuffers.h" +#include "core/common/flatbuffers.h" #include "core/common/common.h" #include "core/common/path_string.h" diff --git a/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/ModuleState.py b/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/ModuleState.py index 2be826fee2cc3..19c6b1b6f2753 100644 --- a/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/ModuleState.py +++ b/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/ModuleState.py @@ -74,9 +74,17 @@ def FrozenParamsIsNone(self): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) return o == 0 -def ModuleStateStart(builder): builder.StartObject(2) + # ModuleState + def IsNominalState(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False + +def ModuleStateStart(builder): builder.StartObject(3) def ModuleStateAddRequiresGradParams(builder, requiresGradParams): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(requiresGradParams), 0) def ModuleStateStartRequiresGradParamsVector(builder, numElems): return builder.StartVector(4, numElems, 4) def ModuleStateAddFrozenParams(builder, frozenParams): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(frozenParams), 0) def ModuleStateStartFrozenParamsVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def ModuleStateAddIsNominalState(builder, isNominalState): 
builder.PrependBoolSlot(2, isNominalState, 0) def ModuleStateEnd(builder): return builder.EndObject() diff --git a/onnxruntime/core/flatbuffers/schema/README.md b/onnxruntime/core/flatbuffers/schema/README.md index 932478111ee68..96a2936c196ae 100644 --- a/onnxruntime/core/flatbuffers/schema/README.md +++ b/onnxruntime/core/flatbuffers/schema/README.md @@ -21,7 +21,7 @@ e.g. - /build/Linux/Debug/_deps/flatbuffers-build/flatc It is possible to use another flatc as well, e.g., from a separate installation. Note that ONNX Runtime uses -FlatBuffers 1.12. +FlatBuffers 23.5.26. To update the flatbuffers schemas and generated files: 1. Modify [the ORT file format schema](ort.fbs) or [training checkpoint schema](ort_training_checkpoint.fbs). diff --git a/onnxruntime/core/flatbuffers/schema/ort.fbs.h b/onnxruntime/core/flatbuffers/schema/ort.fbs.h index e0f5342c29621..dc8a471f2d81f 100644 --- a/onnxruntime/core/flatbuffers/schema/ort.fbs.h +++ b/onnxruntime/core/flatbuffers/schema/ort.fbs.h @@ -4,7 +4,7 @@ #ifndef FLATBUFFERS_GENERATED_ORT_ONNXRUNTIME_FBS_H_ #define FLATBUFFERS_GENERATED_ORT_ONNXRUNTIME_FBS_H_ -#include "flatbuffers/flatbuffers.h" +#include "core/common/flatbuffers.h" namespace onnxruntime { namespace fbs { @@ -562,8 +562,8 @@ struct DimensionValue FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_DIM_TYPE) && - VerifyField(verifier, VT_DIM_VALUE) && + VerifyField(verifier, VT_DIM_TYPE, 1) && + VerifyField(verifier, VT_DIM_VALUE, 8) && VerifyOffset(verifier, VT_DIM_PARAM) && verifier.VerifyString(dim_param()) && verifier.EndTable(); @@ -634,7 +634,7 @@ struct TensorTypeAndShape FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_ELEM_TYPE) && + VerifyField(verifier, VT_ELEM_TYPE, 4) && VerifyOffset(verifier, 
VT_SHAPE) && verifier.VerifyTable(shape()) && verifier.EndTable(); @@ -687,7 +687,7 @@ struct MapType FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_KEY_TYPE) && + VerifyField(verifier, VT_KEY_TYPE, 4) && VerifyOffset(verifier, VT_VALUE_TYPE) && verifier.VerifyTable(value_type()) && verifier.EndTable(); @@ -787,7 +787,7 @@ struct NodeEdge FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_NODE_INDEX) && + VerifyField(verifier, VT_NODE_INDEX, 4) && VerifyOffset(verifier, VT_INPUT_EDGES) && verifier.VerifyVector(input_edges()) && VerifyOffset(verifier, VT_OUTPUT_EDGES) && @@ -911,11 +911,11 @@ struct Node FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { verifier.VerifyString(doc_string()) && VerifyOffset(verifier, VT_DOMAIN) && verifier.VerifyString(domain()) && - VerifyField(verifier, VT_SINCE_VERSION) && - VerifyField(verifier, VT_INDEX) && + VerifyField(verifier, VT_SINCE_VERSION, 4) && + VerifyField(verifier, VT_INDEX, 4) && VerifyOffset(verifier, VT_OP_TYPE) && verifier.VerifyString(op_type()) && - VerifyField(verifier, VT_TYPE) && + VerifyField(verifier, VT_TYPE, 4) && VerifyOffset(verifier, VT_EXECUTION_PROVIDER_TYPE) && verifier.VerifyString(execution_provider_type()) && VerifyOffset(verifier, VT_INPUTS) && @@ -1174,7 +1174,7 @@ struct TypeInfo FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_DENOTATION) && verifier.VerifyString(denotation()) && - VerifyField(verifier, VT_VALUE_TYPE) && + VerifyField(verifier, VT_VALUE_TYPE, 1) && VerifyOffset(verifier, VT_VALUE) && VerifyTypeInfoValue(verifier, value(), value_type()) && verifier.EndTable(); @@ -1259,7 +1259,7 @@ struct OperatorSetId FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { 
return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_DOMAIN) && verifier.VerifyString(domain()) && - VerifyField(verifier, VT_VERSION) && + VerifyField(verifier, VT_VERSION, 8) && verifier.EndTable(); } }; @@ -1343,7 +1343,7 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { verifier.VerifyString(doc_string()) && VerifyOffset(verifier, VT_DIMS) && verifier.VerifyVector(dims()) && - VerifyField(verifier, VT_DATA_TYPE) && + VerifyField(verifier, VT_DATA_TYPE, 4) && VerifyOffset(verifier, VT_RAW_DATA) && verifier.VerifyVector(raw_data()) && VerifyOffset(verifier, VT_STRING_DATA) && @@ -1568,9 +1568,9 @@ struct Attribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { verifier.VerifyString(name()) && VerifyOffset(verifier, VT_DOC_STRING) && verifier.VerifyString(doc_string()) && - VerifyField(verifier, VT_TYPE) && - VerifyField(verifier, VT_F) && - VerifyField(verifier, VT_I) && + VerifyField(verifier, VT_TYPE, 4) && + VerifyField(verifier, VT_F, 4) && + VerifyField(verifier, VT_I, 8) && VerifyOffset(verifier, VT_S) && verifier.VerifyString(s()) && VerifyOffset(verifier, VT_T) && @@ -1759,12 +1759,12 @@ struct NodesToOptimizeIndices FLATBUFFERS_FINAL_CLASS : private flatbuffers::Tab return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_NODE_INDICES) && verifier.VerifyVector(node_indices()) && - VerifyField(verifier, VT_NUM_INPUTS) && - VerifyField(verifier, VT_NUM_OUTPUTS) && - VerifyField(verifier, VT_HAS_VARIADIC_INPUT) && - VerifyField(verifier, VT_HAS_VARIADIC_OUTPUT) && - VerifyField(verifier, VT_NUM_VARIADIC_INPUTS) && - VerifyField(verifier, VT_NUM_VARIADIC_OUTPUTS) && + VerifyField(verifier, VT_NUM_INPUTS, 4) && + VerifyField(verifier, VT_NUM_OUTPUTS, 4) && + VerifyField(verifier, VT_HAS_VARIADIC_INPUT, 1) && + VerifyField(verifier, VT_HAS_VARIADIC_OUTPUT, 1) && + VerifyField(verifier, VT_NUM_VARIADIC_INPUTS, 4) && + VerifyField(verifier, VT_NUM_VARIADIC_OUTPUTS, 4) && verifier.EndTable(); } }; @@ -1862,8 +1862,8 
@@ struct DeprecatedNodeIndexAndKernelDefHash FLATBUFFERS_FINAL_CLASS : private fla } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_NODE_INDEX) && - VerifyField(verifier, VT_KERNEL_DEF_HASH) && + VerifyField(verifier, VT_NODE_INDEX, 4) && + VerifyField(verifier, VT_KERNEL_DEF_HASH, 8) && verifier.EndTable(); } }; @@ -2161,7 +2161,7 @@ struct Graph FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VerifyOffset(verifier, VT_NODES) && verifier.VerifyVector(nodes()) && verifier.VerifyVectorOfTables(nodes()) && - VerifyField(verifier, VT_MAX_NODE_INDEX) && + VerifyField(verifier, VT_MAX_NODE_INDEX, 4) && VerifyOffset(verifier, VT_NODE_EDGES) && verifier.VerifyVector(node_edges()) && verifier.VerifyVectorOfTables(node_edges()) && @@ -2390,7 +2390,7 @@ struct Model FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_IR_VERSION) && + VerifyField(verifier, VT_IR_VERSION, 8) && VerifyOffset(verifier, VT_OPSET_IMPORT) && verifier.VerifyVector(opset_import()) && verifier.VerifyVectorOfTables(opset_import()) && @@ -2400,7 +2400,7 @@ struct Model FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { verifier.VerifyString(producer_version()) && VerifyOffset(verifier, VT_DOMAIN) && verifier.VerifyString(domain()) && - VerifyField(verifier, VT_MODEL_VERSION) && + VerifyField(verifier, VT_MODEL_VERSION, 8) && VerifyOffset(verifier, VT_DOC_STRING) && verifier.VerifyString(doc_string()) && VerifyOffset(verifier, VT_GRAPH) && @@ -2740,8 +2740,8 @@ struct ArgTypeAndIndex FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_ARG_TYPE) && - VerifyField(verifier, VT_INDEX) && + VerifyField(verifier, VT_ARG_TYPE, 1) && + VerifyField(verifier, VT_INDEX, 4) && 
verifier.EndTable(); } }; diff --git a/onnxruntime/core/flatbuffers/schema/ort_training_checkpoint.fbs b/onnxruntime/core/flatbuffers/schema/ort_training_checkpoint.fbs index c8244b0a426f3..94757fa6d5bf5 100644 --- a/onnxruntime/core/flatbuffers/schema/ort_training_checkpoint.fbs +++ b/onnxruntime/core/flatbuffers/schema/ort_training_checkpoint.fbs @@ -8,6 +8,10 @@ namespace onnxruntime.fbs; table ModuleState { requires_grad_params:[Tensor]; frozen_params:[Tensor]; + // Nominal state just means that the Tensors in the ModuleState + // are empty. i.e. The tensors are treated as named entities + // without any meaningful data. + is_nominal_state:bool; } table ParameterOptimizerState { diff --git a/onnxruntime/core/flatbuffers/schema/ort_training_checkpoint.fbs.h b/onnxruntime/core/flatbuffers/schema/ort_training_checkpoint.fbs.h index 48feebb197694..62e6cf74394e5 100644 --- a/onnxruntime/core/flatbuffers/schema/ort_training_checkpoint.fbs.h +++ b/onnxruntime/core/flatbuffers/schema/ort_training_checkpoint.fbs.h @@ -4,7 +4,7 @@ #ifndef FLATBUFFERS_GENERATED_ORTTRAININGCHECKPOINT_ONNXRUNTIME_FBS_H_ #define FLATBUFFERS_GENERATED_ORTTRAININGCHECKPOINT_ONNXRUNTIME_FBS_H_ -#include "flatbuffers/flatbuffers.h" +#include "core/common/flatbuffers.h" #include "ort.fbs.h" @@ -39,7 +39,8 @@ struct ModuleState FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef ModuleStateBuilder Builder; enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_REQUIRES_GRAD_PARAMS = 4, - VT_FROZEN_PARAMS = 6 + VT_FROZEN_PARAMS = 6, + VT_IS_NOMINAL_STATE = 8 }; const flatbuffers::Vector> *requires_grad_params() const { return GetPointer> *>(VT_REQUIRES_GRAD_PARAMS); @@ -47,6 +48,9 @@ struct ModuleState FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const flatbuffers::Vector> *frozen_params() const { return GetPointer> *>(VT_FROZEN_PARAMS); } + bool is_nominal_state() const { + return GetField(VT_IS_NOMINAL_STATE, 0) != 0; + } bool Verify(flatbuffers::Verifier 
&verifier) const { return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_REQUIRES_GRAD_PARAMS) && @@ -55,6 +59,7 @@ struct ModuleState FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VerifyOffset(verifier, VT_FROZEN_PARAMS) && verifier.VerifyVector(frozen_params()) && verifier.VerifyVectorOfTables(frozen_params()) && + VerifyField(verifier, VT_IS_NOMINAL_STATE, 1) && verifier.EndTable(); } }; @@ -69,6 +74,9 @@ struct ModuleStateBuilder { void add_frozen_params(flatbuffers::Offset>> frozen_params) { fbb_.AddOffset(ModuleState::VT_FROZEN_PARAMS, frozen_params); } + void add_is_nominal_state(bool is_nominal_state) { + fbb_.AddElement(ModuleState::VT_IS_NOMINAL_STATE, static_cast(is_nominal_state), 0); + } explicit ModuleStateBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -84,23 +92,27 @@ struct ModuleStateBuilder { inline flatbuffers::Offset CreateModuleState( flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::Offset>> requires_grad_params = 0, - flatbuffers::Offset>> frozen_params = 0) { + flatbuffers::Offset>> frozen_params = 0, + bool is_nominal_state = false) { ModuleStateBuilder builder_(_fbb); builder_.add_frozen_params(frozen_params); builder_.add_requires_grad_params(requires_grad_params); + builder_.add_is_nominal_state(is_nominal_state); return builder_.Finish(); } inline flatbuffers::Offset CreateModuleStateDirect( flatbuffers::FlatBufferBuilder &_fbb, const std::vector> *requires_grad_params = nullptr, - const std::vector> *frozen_params = nullptr) { + const std::vector> *frozen_params = nullptr, + bool is_nominal_state = false) { auto requires_grad_params__ = requires_grad_params ? _fbb.CreateVector>(*requires_grad_params) : 0; auto frozen_params__ = frozen_params ? 
_fbb.CreateVector>(*frozen_params) : 0; return onnxruntime::fbs::CreateModuleState( _fbb, requires_grad_params__, - frozen_params__); + frozen_params__, + is_nominal_state); } struct ParameterOptimizerState FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { @@ -194,8 +206,8 @@ struct OptimizerGroup FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_GROUP_NAME) && verifier.VerifyString(group_name()) && - VerifyField(verifier, VT_STEP) && - VerifyField(verifier, VT_INITIAL_LEARNING_RATE) && + VerifyField(verifier, VT_STEP, 8) && + VerifyField(verifier, VT_INITIAL_LEARNING_RATE, 4) && VerifyOffset(verifier, VT_OPTIMIZER_STATES) && verifier.VerifyVector(optimizer_states()) && verifier.VerifyVectorOfTables(optimizer_states()) && @@ -277,7 +289,7 @@ struct IntProperty FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_NAME) && verifier.VerifyString(name()) && - VerifyField(verifier, VT_VALUE) && + VerifyField(verifier, VT_VALUE, 8) && verifier.EndTable(); } }; @@ -341,7 +353,7 @@ struct FloatProperty FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_NAME) && verifier.VerifyString(name()) && - VerifyField(verifier, VT_VALUE) && + VerifyField(verifier, VT_VALUE, 4) && verifier.EndTable(); } }; @@ -560,7 +572,7 @@ struct Checkpoint FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_VERSION) && + VerifyField(verifier, VT_VERSION, 4) && VerifyOffset(verifier, VT_MODULE_STATE) && verifier.VerifyTable(module_state()) && VerifyOffset(verifier, VT_OPTIMIZER_GROUPS) && diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index 931997694e812..95e5380675df2 100644 --- 
a/onnxruntime/core/framework/allocation_planner.cc +++ b/onnxruntime/core/framework/allocation_planner.cc @@ -181,7 +181,6 @@ class PlannerImpl { // upstream_node_0 and upstream_node_1 are the immmediate upstream nodes of downstream_node // upstream_node_2 is the immediate nodes ahead of downstream_node in the same logic stream InlinedHashMap> dependence_graph_; - InlinedHashMap> value_consumer_map_; InlinedHashMap value_node_map_; // OrtValueInfo: Auxiliary information about an OrtValue used only during plan-generation: @@ -294,7 +293,7 @@ class PlannerImpl { } #endif - // Find if there exists some input tensor that we can use in-place for output_arg_num-th input in the node. + // Find if there exists some input tensor that we can use in-place for output_arg_num-th output in the node. bool FindReusableInput(const onnxruntime::Node& node, int output_arg_num, OrtValueIndex* reusable_input, bool* is_strided_tensor) { *is_strided_tensor = false; @@ -529,6 +528,7 @@ class PlannerImpl { // Initialize allocation plan: plan_.allocation_plan.resize(num_ml_values); + for (int i = 0; static_cast(i) < num_ml_values; i++) AllocPlan(i).reused_buffer = i; } bool HasExternalOutputs(const Node& node) const { @@ -1064,7 +1064,8 @@ class PlannerImpl { // build the consumer list for each value int num_ml_values = ort_value_name_idx_map_.MaxIdx() + 1; - value_consumer_map_.reserve(num_ml_values); + InlinedHashMap> value_consumer_map; + value_consumer_map.reserve(num_ml_values); // iterate each stream from back, so the first element is the last consumer in single stream case for (auto& stream : stream_nodes_) { @@ -1077,10 +1078,10 @@ class PlannerImpl { const auto& name = input.Name(); int value_idx; ORT_RETURN_IF_ERROR(ort_value_name_idx_map_.GetIdx(name, value_idx)); - auto origin = Buffer(value_idx); - if (origin != -1 && plan_.allocation_plan[origin].alloc_kind == AllocKind::kAllocate) { + auto origin = AllocPlan(value_idx).reused_buffer; + if (AllocPlan(origin).alloc_kind == 
AllocKind::kAllocate) { // add current node as consumer for origin buffer - value_consumer_map_[origin].insert(node_index); + value_consumer_map[origin].insert(node_index); } } return Status::OK(); @@ -1137,8 +1138,8 @@ class PlannerImpl { std::cout << p_input_arg->Name() << " reused by " << p_output_arg->Name() << " as input" << std::endl; allocation_plan[output_idx_global].alloc_kind = AllocKind::kReuse; allocation_plan[output_idx_global].reused_buffer = reusable_input; - value_consumer_map_[reusable_input].insert(value_consumer_map_[output_idx_global].begin(), - value_consumer_map_[output_idx_global].end()); + value_consumer_map[reusable_input].insert(value_consumer_map[output_idx_global].begin(), + value_consumer_map[output_idx_global].end()); reused.insert(reusable_input); found_reusable = true; break; @@ -1167,8 +1168,8 @@ class PlannerImpl { allocation_plan[reusable_input].alloc_kind == AllocKind::kAllocate) { allocation_plan[output_idx_global].alloc_kind = AllocKind::kReuse; allocation_plan[output_idx_global].reused_buffer = reusable_input; - value_consumer_map_[reusable_input].insert(value_consumer_map_[output_idx_global].begin(), - value_consumer_map_[output_idx_global].end()); + value_consumer_map[reusable_input].insert(value_consumer_map[output_idx_global].begin(), + value_consumer_map[output_idx_global].end()); reused.insert(reusable_input); continue; } // if @@ -1186,11 +1187,11 @@ class PlannerImpl { OrtValueIndex input_arg_index{}; if (value_map.GetIdx(p_input_arg->Name(), input_arg_index).IsOK() && allocation_plan[input_arg_index].alloc_kind == AllocKind::kAllocate) { - if (value_consumer_map_[input_arg_index].size() == 1 && SameSize(*p_input_arg, *p_output_arg)) { + if (value_consumer_map[input_arg_index].size() == 1 && SameSize(*p_input_arg, *p_output_arg)) { allocation_plan[output_idx_global].alloc_kind = AllocKind::kReuse; allocation_plan[output_idx_global].reused_buffer = input_arg_index; - 
value_consumer_map_[input_arg_index].insert(value_consumer_map_[output_idx_global].begin(), - value_consumer_map_[output_idx_global].end()); + value_consumer_map[input_arg_index].insert(value_consumer_map[output_idx_global].begin(), + value_consumer_map[output_idx_global].end()); reused.insert(input_arg_index); } } @@ -1265,7 +1266,7 @@ class PlannerImpl { } bool all_covered = true; - for (auto consumer : value_consumer_map_[output_idx_global]) { + for (auto consumer : value_consumer_map[output_idx_global]) { if (deps->find(consumer) == deps->end()) { all_covered = false; break; @@ -1276,9 +1277,9 @@ class PlannerImpl { allocation_plan[downstream_value].reused_buffer = output_idx_global; get_reused = true; // add new consumer for the value to be reused - value_consumer_map_[output_idx_global].insert(value_node_map_[downstream_value]); - value_consumer_map_[output_idx_global].insert(value_consumer_map_[downstream_value].begin(), - value_consumer_map_[downstream_value].end()); + value_consumer_map[output_idx_global].insert(value_node_map_[downstream_value]); + value_consumer_map[output_idx_global].insert(value_consumer_map[downstream_value].begin(), + value_consumer_map[downstream_value].end()); node_iter = size_iter->second.erase(node_iter); if (size_iter->second.empty()) { local_iter->second.erase(size_iter); @@ -1341,8 +1342,9 @@ class PlannerImpl { ort_value_usecount.reserve(ort_value_info_.size()); #endif for (size_t i = 0; i < stream_nodes_.size(); ++i) { - // compute use count first + // compute use count first. TODO(leca): call ComputeReuseCount() only once is enough! 
ORT_RETURN_IF_ERROR(ComputeReuseCount()); + for (int j = 0; static_cast(j) < ort_value_info_.size(); j++) Buffer(j) = j; #if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE) if (i == 0) { for (auto ort_value_info : ort_value_info_) { @@ -1692,8 +1694,8 @@ class PlannerImpl { const auto& name = input.Name(); int value_idx; ORT_RETURN_IF_ERROR(ort_value_name_idx_map_.GetIdx(name, value_idx)); - auto origin = Buffer(value_idx); - if (origin != -1 && plan_.allocation_plan[origin].alloc_kind == AllocKind::kAllocate) { + auto origin = AllocPlan(value_idx).reused_buffer; + if (AllocPlan(origin).alloc_kind == AllocKind::kAllocate) { // add current node as consumer for origin buffer value_consumers[origin].push_back(node_index); } @@ -1772,7 +1774,12 @@ class PlannerImpl { execution_plan.emplace_back(std::make_unique(node_device_mem_location)); // 2. add steps to the execution plan for (auto node_index : stream_nodes_[0]) { +#if defined(ORT_MINIMAL_BUILD) execution_plan[0]->steps_.emplace_back(std::make_unique(node_index)); +#else + execution_plan[0]->steps_.emplace_back(std::make_unique(node_index, + graph_viewer_.GetNode(node_index)->Name())); +#endif } } else { // graph with no nodes. e.g. subgraph of If might return the input as-is or a constant value from an initializer @@ -1888,7 +1895,7 @@ class PlannerImpl { // 2. the consumer is in the same stream(non-cpu device), but it consumes a CPU tensor from an non-shape op. // for example, a resize cuda kernel consumer a tensor from MemCpyToHost cuda kernel on the same stream. 
// in this case, the FIFO can't guarantee the cpu tensor is ready when resize kernel is launching - OrtDevice::DeviceType output_arg_device = plan_.allocation_plan[output_arg_idx].location.Type(); + OrtDevice::DeviceType output_arg_device = AllocPlan(output_arg_idx).location.Type(); WaitNotificationFn wait_handle = stream_handle_registry.GetWaitHandle(stream_device, output_arg_device); if ((plan_.node_stream_map_[it->Index()] != i || output_arg_device == OrtDevice::CPU) && wait_handle != nullptr) { if (node_to_notification.find(node_index) == node_to_notification.end()) { @@ -1977,8 +1984,12 @@ class PlannerImpl { // add dependency for model graph dependence_graph_[it->Index()].insert(node_index); } - // push launch kernel command +// push launch kernel command +#if defined(ORT_MINIMAL_BUILD) execution_plan[i]->steps_.emplace_back(std::make_unique(node_index)); +#else + execution_plan[i]->steps_.emplace_back(std::make_unique(node_index, graph_viewer_.GetNode(node_index)->Name())); +#endif // check if any notification generated by this node, if yes, push a activate auto notification_it = node_to_notification.find(node_index); if (notification_it != node_to_notification.end()) { diff --git a/onnxruntime/core/framework/bfc_arena.h b/onnxruntime/core/framework/bfc_arena.h index e16b90ded3381..5e4cd9f62f11b 100644 --- a/onnxruntime/core/framework/bfc_arena.h +++ b/onnxruntime/core/framework/bfc_arena.h @@ -482,7 +482,7 @@ class BFCArena : public IAllocator { Bin* BinForSize(size_t bytes) { return BinFromIndex(BinNumForSize(bytes)); } - char bins_space_[sizeof(Bin) * kNumBins]; + alignas(Bin) char bins_space_[sizeof(Bin) * kNumBins]; // The size of the current region allocation. 
SafeInt curr_region_allocation_bytes_; diff --git a/onnxruntime/core/framework/execution_frame.cc b/onnxruntime/core/framework/execution_frame.cc index 8c08152986cf6..32a5f749af084 100644 --- a/onnxruntime/core/framework/execution_frame.cc +++ b/onnxruntime/core/framework/execution_frame.cc @@ -204,6 +204,14 @@ AllocatorPtr IExecutionFrame::GetAllocator(const OrtDevice& info) const { Status IExecutionFrame::ReleaseMLValue(int ort_value_idx) { return ReleaseMLValueImpl(ort_value_idx); } +#ifdef ENABLE_TRAINING +void IExecutionFrame::ReleaseAllMLValues() { + for (size_t ort_value_idx = 0; ort_value_idx < all_values_.size(); ort_value_idx++) { + all_values_[ort_value_idx] = OrtValue(); + } +} +#endif + Status IExecutionFrame::ReleaseMLValueImpl(int ort_value_idx) { if (ort_value_idx == NodeIndexInfo::kInvalidEntry || static_cast(ort_value_idx) >= all_values_size_) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid index ", ort_value_idx); @@ -831,7 +839,20 @@ AllocatorPtr ExecutionFrame::GetAllocatorImpl(const OrtDevice& info) const { // This method is not thread safe! 
// Return S_OK and nullptr if index map to a value that is an unused optional input/output Status ExecutionFrame::CreateNodeOutputMLValueImpl(OrtValue& ort_value, int ort_value_idx, const TensorShape* shape) { +#ifdef ENABLE_TRAINING + try { + auto status = AllocateAsPerAllocationPlan(ort_value, ort_value_idx, shape); + return status; + } catch (const std::exception& e) { + LOGS(session_state_.Logger(), WARNING) + << "Exception caught when allocating memory for ort_value with index: " << ort_value_idx + << "so clean up all OrtValues"; + ReleaseAllMLValues(); + return Status(ONNXRUNTIME, FAIL, e.what()); + } +#else return AllocateAsPerAllocationPlan(ort_value, ort_value_idx, shape); +#endif } void ExecutionFrame::VerifyOutputSizes(int output_index, const Node& node, const TensorShape& output_shape) { diff --git a/onnxruntime/core/framework/execution_frame.h b/onnxruntime/core/framework/execution_frame.h index 1576c16684faa..18d210ffd48f7 100644 --- a/onnxruntime/core/framework/execution_frame.h +++ b/onnxruntime/core/framework/execution_frame.h @@ -67,6 +67,8 @@ class IExecutionFrame { const std::unordered_map& initializers); Status GetOutputs(gsl::span fetch_mlvalue_idxs, std::vector& fetches); + // if OOM happens, then release all values, so session can run next batch. 
+ void ReleaseAllMLValues(); #endif // TO DO: make it thread safe diff --git a/onnxruntime/core/framework/execution_providers.h b/onnxruntime/core/framework/execution_providers.h index 61147e4367876..dc45cad692b6e 100644 --- a/onnxruntime/core/framework/execution_providers.h +++ b/onnxruntime/core/framework/execution_providers.h @@ -3,7 +3,6 @@ #pragma once -// #include #include #include #include @@ -14,7 +13,9 @@ #include "core/common/logging/logging.h" #ifdef _WIN32 #include +#include #include "core/platform/tracing.h" +#include "core/platform/windows/telemetry.h" #endif namespace onnxruntime { @@ -44,6 +45,49 @@ class ExecutionProviders { exec_provider_options_[provider_id] = providerOptions; #ifdef _WIN32 + LogProviderOptions(provider_id, providerOptions, false); + + // Register callback for ETW capture state (rundown) + WindowsTelemetry::RegisterInternalCallback( + [this]( + LPCGUID SourceId, + ULONG IsEnabled, + UCHAR Level, + ULONGLONG MatchAnyKeyword, + ULONGLONG MatchAllKeyword, + PEVENT_FILTER_DESCRIPTOR FilterData, + PVOID CallbackContext) { + (void)SourceId; + (void)Level; + (void)MatchAnyKeyword; + (void)MatchAllKeyword; + (void)FilterData; + (void)CallbackContext; + + // Check if this callback is for capturing state + if ((IsEnabled == EVENT_CONTROL_CODE_CAPTURE_STATE) && + ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Session)) != 0)) { + for (size_t i = 0; i < exec_providers_.size(); ++i) { + const auto& provider_id = exec_provider_ids_[i]; + + auto it = exec_provider_options_.find(provider_id); + if (it != exec_provider_options_.end()) { + const auto& options = it->second; + + LogProviderOptions(provider_id, options, true); + } + } + } + }); +#endif + + exec_provider_ids_.push_back(provider_id); + exec_providers_.push_back(p_exec_provider); + return Status::OK(); + } + +#ifdef _WIN32 + void LogProviderOptions(const std::string& provider_id, const ProviderOptions& providerOptions, bool captureState) { for (const auto& 
config_pair : providerOptions) { TraceLoggingWrite( telemetry_provider_handle, @@ -52,14 +96,11 @@ class ExecutionProviders { TraceLoggingLevel(WINEVENT_LEVEL_INFO), TraceLoggingString(provider_id.c_str(), "ProviderId"), TraceLoggingString(config_pair.first.c_str(), "Key"), - TraceLoggingString(config_pair.second.c_str(), "Value")); + TraceLoggingString(config_pair.second.c_str(), "Value"), + TraceLoggingBool(captureState, "isCaptureState")); } -#endif - - exec_provider_ids_.push_back(provider_id); - exec_providers_.push_back(p_exec_provider); - return Status::OK(); } +#endif const IExecutionProvider* Get(const onnxruntime::Node& node) const { return Get(node.GetExecutionProviderType()); diff --git a/onnxruntime/core/framework/execution_steps.cc b/onnxruntime/core/framework/execution_steps.cc index df19236d037c0..b647833cfd373 100644 --- a/onnxruntime/core/framework/execution_steps.cc +++ b/onnxruntime/core/framework/execution_steps.cc @@ -1,8 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. + #include "core/framework/execution_steps.h" #include "core/framework/sequential_executor.h" + namespace onnxruntime { + BarrierStep::BarrierStep(size_t id, NodeIndex node_index) : SequentialExecutionPlan::ExecutionStep(node_index), barrier_id_{id} {} @@ -16,8 +19,8 @@ Status BarrierStep::Execute(StreamExecutionContext& ctx, } std::string BarrierStep::ToString() const { - return ::onnxruntime::MakeString("Set a barrier with id: ", - barrier_id_, ", count: ", 2, "."); + // Set a barrier with id: barrier_id_, count: 2. + return MakeString("Barrier - BarrierId: ", barrier_id_, ", Count: ", 2); } WaitOnEPStep::WaitOnEPStep(WaitNotificationFn handle, @@ -42,11 +45,17 @@ Status WaitOnEPStep::Execute(StreamExecutionContext& ctx, } std::string WaitOnEPStep::ToString() const { - return ::onnxruntime::MakeString("WaitOnEPStep: wait on notification with id: ", - notification_idx_, ". 
"); + // Wait on notification with notification_idx_ + return MakeString("WaitOnEP - NotificationId: ", notification_idx_); } -LaunchKernelStep::LaunchKernelStep(NodeIndex index) : SequentialExecutionPlan::ExecutionStep(index) {} +#if defined(ORT_MINIMAL_BUILD) +LaunchKernelStep::LaunchKernelStep(NodeIndex index) + : SequentialExecutionPlan::ExecutionStep(index) {} +#else +LaunchKernelStep::LaunchKernelStep(NodeIndex index, std::string_view node_name) + : SequentialExecutionPlan::ExecutionStep(index), node_name_(node_name) {} +#endif Status LaunchKernelStep::Execute(StreamExecutionContext& ctx, size_t stream_idx, @@ -61,13 +70,17 @@ Status LaunchKernelStep::Execute(StreamExecutionContext& ctx, return Status::OK(); } #endif - onnxruntime::Status status = ExecuteKernel(ctx, node_index_, stream_idx, terminate_flag, session_scope); + Status status = ExecuteKernel(ctx, node_index_, stream_idx, terminate_flag, session_scope); continue_flag = status.IsOK(); return status; } std::string LaunchKernelStep::ToString() const { - return ::onnxruntime::MakeString("Launch kernel with node id: ", node_index_, ". "); +#if defined(ORT_MINIMAL_BUILD) + return MakeString("LaunchKernel - ", "NodeIndex: ", node_index_); +#else + return MakeString("LaunchKernel - ", "NodeIndex: ", node_index_, ", Name: ", node_name_); +#endif } ActivateNotificationStep::ActivateNotificationStep( @@ -89,12 +102,12 @@ Status ActivateNotificationStep::Execute(StreamExecutionContext& ctx, } std::string ActivateNotificationStep::ToString() const { - return ::onnxruntime::MakeString("ActivateNotificationStep: activate notification with id: ", - notification_idx_, ". 
"); + // Activate notification with id: notification_idx_ + return MakeString("ActivateNotification - NotificationId: ", notification_idx_); } -TriggerDownstreamStep::TriggerDownstreamStep(size_t trigger_point_index, NodeIndex node_index) : SequentialExecutionPlan::ExecutionStep(node_index), - trigger_point_index_(trigger_point_index) {} +TriggerDownstreamStep::TriggerDownstreamStep(size_t trigger_point_index, NodeIndex node_index) + : SequentialExecutionPlan::ExecutionStep(node_index), trigger_point_index_(trigger_point_index) {} Status TriggerDownstreamStep::Execute(StreamExecutionContext& ctx, size_t /*stream_idx*/, @@ -107,7 +120,8 @@ Status TriggerDownstreamStep::Execute(StreamExecutionContext& ctx, } std::string TriggerDownstreamStep::ToString() const { - return ::onnxruntime::MakeString("TriggerDownstreamStep: trigger downstream of trigger point: ", - trigger_point_index_, "."); + // Trigger downstream of trigger point: trigger_point_index_. + return MakeString("TriggerDownstream - TriggerPointIndex: ", trigger_point_index_); } + } // namespace onnxruntime diff --git a/onnxruntime/core/framework/execution_steps.h b/onnxruntime/core/framework/execution_steps.h index b67b583900824..545dabc56b272 100644 --- a/onnxruntime/core/framework/execution_steps.h +++ b/onnxruntime/core/framework/execution_steps.h @@ -44,7 +44,11 @@ class WaitOnEPStep : public SequentialExecutionPlan::ExecutionStep { class LaunchKernelStep : public SequentialExecutionPlan::ExecutionStep { public: +#if defined(ORT_MINIMAL_BUILD) LaunchKernelStep(NodeIndex index); +#else + LaunchKernelStep(NodeIndex index, std::string_view node_name); +#endif Status Execute(StreamExecutionContext& ctx, size_t stream_idx, @@ -53,6 +57,11 @@ class LaunchKernelStep : public SequentialExecutionPlan::ExecutionStep { bool& continue_flag) override; std::string ToString() const override; + +#if !defined(ORT_MINIMAL_BUILD) + private: + std::string node_name_; +#endif }; class ActivateNotificationStep : public 
SequentialExecutionPlan::ExecutionStep { diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index 07b465c80745a..90ee8a46f66a9 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -645,6 +645,10 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers all_ep_context_nodes.insert(all_ep_context_nodes.begin(), ep_context_nodes.begin(), ep_context_nodes.end()); } + if (all_ep_context_nodes.size() < 1) { + return Status::OK(); + } + auto get_ep_context_node = [&all_ep_context_nodes](const std::string& node_name) -> std::pair { for (auto& node : all_ep_context_nodes) { if (node_name == node->Name()) { @@ -656,76 +660,70 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers onnxruntime::PathString context_cache_path; PathString model_pathstring = graph.ModelPath().ToPathString(); - if (all_ep_context_nodes.size() > 0) { - if (!ep_context_path.empty()) { - context_cache_path = ToPathString(ep_context_path); - } else if (!model_pathstring.empty()) { - context_cache_path = model_pathstring + ToPathString("_ctx.onnx"); - } - { + if (!ep_context_path.empty()) { + context_cache_path = ToPathString(ep_context_path); + } else if (!model_pathstring.empty()) { + context_cache_path = model_pathstring + ToPathString("_ctx.onnx"); + } + + { #ifdef _WIN32 - std::wifstream fs(context_cache_path); + std::wifstream fs(context_cache_path); #else - std::ifstream fs(context_cache_path); + std::ifstream fs(context_cache_path); #endif - ORT_RETURN_IF(fs.good(), "Failed to generate EP context model since the file exist already."); - } + ORT_RETURN_IF(fs.good(), "Failed to generate EP context model since the file exist already."); + } - Model ep_context_model(graph.Name(), false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), - graph.DomainToVersionMap(), {}, logger); - auto& ep_graph = 
ep_context_model.MainGraph(); - ep_graph.SetDescription(graph.Description()); - - // Set inputs outputs explicitly to make sure the order is same as the user model. - auto inputs = graph.GetInputs(); - auto outputs = graph.GetOutputs(); - - InlinedVector ep_graph_inputs; - ep_graph_inputs.reserve(inputs.size()); - for (auto& input : inputs) { - auto input_arg = graph.GetNodeArg(input->Name()); - auto& ep_graph_input_arg = ep_graph.GetOrCreateNodeArg(input_arg->Name(), input_arg->TypeAsProto()); - ep_graph_inputs.push_back(&ep_graph_input_arg); - } + Model ep_context_model(graph.Name(), false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), + graph.DomainToVersionMap(), {}, logger); + auto& ep_graph = ep_context_model.MainGraph(); + ep_graph.SetDescription(graph.Description()); - InlinedVector ep_graph_outputs; - ep_graph_outputs.reserve(outputs.size()); - for (auto& output : outputs) { - auto output_arg = graph.GetNodeArg(output->Name()); - auto& ep_graph_output_arg = ep_graph.GetOrCreateNodeArg(output_arg->Name(), output_arg->TypeAsProto()); - ep_graph_outputs.push_back(&ep_graph_output_arg); - } + // Set inputs outputs explicitly to make sure the order is same as the user model. 
+ auto inputs = graph.GetInputs(); + auto outputs = graph.GetOutputs(); - ep_graph.SetInputs(ep_graph_inputs); - ep_graph.SetOutputs(ep_graph_outputs); + InlinedVector ep_graph_inputs; + ep_graph_inputs.reserve(inputs.size()); + for (auto& input : inputs) { + auto input_arg = graph.GetNodeArg(input->Name()); + auto& ep_graph_input_arg = ep_graph.GetOrCreateNodeArg(input_arg->Name(), input_arg->TypeAsProto()); + ep_graph_inputs.push_back(&ep_graph_input_arg); + } - for (const auto& node : graph.Nodes()) { - // the fused node and EPContext node has same node name - auto ep_context_node = get_ep_context_node(node.Name()); - // Use EpContext node created by the EPs if name matched, otherwise use node from original model - if (ep_context_node.first) { - ep_graph.AddNode(*ep_context_node.second); - } else { - ep_graph.AddNode(node); - } - } + InlinedVector ep_graph_outputs; + ep_graph_outputs.reserve(outputs.size()); + for (auto& output : outputs) { + auto output_arg = graph.GetNodeArg(output->Name()); + auto& ep_graph_output_arg = ep_graph.GetOrCreateNodeArg(output_arg->Name(), output_arg->TypeAsProto()); + ep_graph_outputs.push_back(&ep_graph_output_arg); + } - // handle initializers - for (const auto& input : graph.GetInputsIncludingInitializers()) { - const ONNX_NAMESPACE::TensorProto* initializer = nullptr; - if (graph.GetInitializedTensor(input->Name(), initializer)) { - // There initializer could have duplicates so make sure we only add once - const ONNX_NAMESPACE::TensorProto* subgraph_initializer = nullptr; - if (!ep_graph.GetInitializedTensor(input->Name(), subgraph_initializer)) { - ep_graph.AddInitializedTensor(*initializer); - } - } + ep_graph.SetInputs(ep_graph_inputs); + ep_graph.SetOutputs(ep_graph_outputs); + + for (const auto& node : graph.Nodes()) { + // the fused node and EPContext node has same node name + auto ep_context_node = get_ep_context_node(node.Name()); + // Use EpContext node created by the EPs if name matched, otherwise use node from 
original model + if (ep_context_node.first) { + ep_graph.AddNode(*ep_context_node.second); + } else { + ep_graph.AddNode(node); } + } - ORT_RETURN_IF_ERROR(Model::Save(ep_context_model, context_cache_path)); + // handle initializers + for (const auto& initialized_tensor : graph.GetAllInitializedTensors()) { + if (ep_graph.GetNodeArg(initialized_tensor.first) != nullptr) { + ep_graph.AddInitializedTensor(*initialized_tensor.second); + } } + ORT_RETURN_IF_ERROR(Model::Save(ep_context_model, context_cache_path)); + return Status::OK(); } diff --git a/onnxruntime/core/framework/kernel_type_str_resolver.h b/onnxruntime/core/framework/kernel_type_str_resolver.h index 31a806dd52291..fea2a6ef3a439 100644 --- a/onnxruntime/core/framework/kernel_type_str_resolver.h +++ b/onnxruntime/core/framework/kernel_type_str_resolver.h @@ -7,7 +7,7 @@ #include #include -#include "flatbuffers/flatbuffers.h" +#include "core/common/flatbuffers.h" #if !defined(ORT_MINIMAL_BUILD) #include "core/graph/onnx_protobuf.h" diff --git a/onnxruntime/core/framework/kernel_type_str_resolver_utils.cc b/onnxruntime/core/framework/kernel_type_str_resolver_utils.cc index 4f5fa9910b5df..473e78c3f5e25 100644 --- a/onnxruntime/core/framework/kernel_type_str_resolver_utils.cc +++ b/onnxruntime/core/framework/kernel_type_str_resolver_utils.cc @@ -5,7 +5,7 @@ #include "core/framework/kernel_type_str_resolver_utils.h" -#include "flatbuffers/flatbuffers.h" +#include "core/common/flatbuffers.h" #include "core/common/common.h" #include "core/flatbuffers/schema/ort.fbs.h" diff --git a/onnxruntime/core/framework/node_unit.cc b/onnxruntime/core/framework/node_unit.cc new file mode 100644 index 0000000000000..4dee1c14b3761 --- /dev/null +++ b/onnxruntime/core/framework/node_unit.cc @@ -0,0 +1,351 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + +#include "node_unit.h" +#include "core/graph/graph_viewer.h" + +namespace onnxruntime { + +namespace { + +enum class QLinearOpType : uint8_t { + Unknown, // Unknown or not a linear quantized op + DequantizeLinear, + QuantizeLinear, + QLinearConv, + QLinearMatMul, + QLinearAdd, + QLinearSigmoid, + QLinearAveragePool, + QLinearMul, + QLinearReduceMean, + QLinearConcat, + QLinearGlobalAveragePool, + QLinearLeakyRelu, +}; + +QLinearOpType GetQLinearOpType(const onnxruntime::Node& node) { + const auto& op_type = node.OpType(); + if (op_type == "DequantizeLinear") + return QLinearOpType::DequantizeLinear; + else if (op_type == "QuantizeLinear") + return QLinearOpType::QuantizeLinear; + else if (op_type == "QLinearConv") + return QLinearOpType::QLinearConv; + else if (op_type == "QLinearMatMul") + return QLinearOpType::QLinearMatMul; + else if (op_type == "QLinearAdd") + return QLinearOpType::QLinearAdd; + else if (op_type == "QLinearSigmoid") + return QLinearOpType::QLinearSigmoid; + else if (op_type == "QLinearAveragePool") + return QLinearOpType::QLinearAveragePool; + else if (op_type == "QLinearMul") + return QLinearOpType::QLinearMul; + else if (op_type == "QLinearReduceMean") + return QLinearOpType::QLinearReduceMean; + else if (op_type == "QLinearConcat") + return QLinearOpType::QLinearConcat; + else if (op_type == "QLinearGlobalAveragePool") + return QLinearOpType::QLinearGlobalAveragePool; + else if (op_type == "QLinearLeakyRelu") + return QLinearOpType::QLinearLeakyRelu; + + return QLinearOpType::Unknown; +} + +// Ops have 1 input +bool IsUnaryQLinearOp(QLinearOpType type) { + return type == QLinearOpType::QLinearSigmoid || + type == QLinearOpType::QLinearAveragePool || + type == QLinearOpType::QLinearGlobalAveragePool || + type == QLinearOpType::QLinearLeakyRelu || + type == QLinearOpType::QLinearReduceMean; +} + +// Ops have 2 inputs +bool IsBinaryQLinearOp(QLinearOpType type) { + 
return type == QLinearOpType::QLinearConv || + type == QLinearOpType::QLinearMatMul || + type == QLinearOpType::QLinearAdd || + type == QLinearOpType::QLinearMul; +} + +// Ops have 1 or more inputs +bool IsVariadicQLinearOp(QLinearOpType type) { + return type == QLinearOpType::QLinearConcat; +} + +const std::vector GetQDQIONodes(const GraphViewer& graph_viewer, + const QDQ::NodeGroup& node_group, bool is_input) { + std::vector io_nodes; + const auto& src_nodes = is_input ? node_group.dq_nodes : node_group.q_nodes; + io_nodes.reserve(src_nodes.size()); + for (const auto& node_idx : src_nodes) { + io_nodes.push_back(graph_viewer.GetNode(node_idx)); + } + + return io_nodes; +} + +// Get the input or output NodeUnitIODef(s) for the given QDQ NodeGroup +std::vector GetQDQIODefs(const Node& target_node, const QDQ::NodeGroup& node_group, bool is_input) { + const auto& dq_or_q_nodes = is_input ? node_group.dq_nodes : node_group.q_nodes; + const auto target_node_io_defs = is_input ? target_node.InputDefs() : target_node.OutputDefs(); + const size_t target_node_io_defs_size = target_node_io_defs.size(); + + // Find all the quantized IO defs and indices (for the input/output of the target node) + std::unordered_map quantized_io_defs; + quantized_io_defs.reserve(target_node_io_defs_size); + + auto cur = is_input ? target_node.InputEdgesBegin() : target_node.OutputEdgesBegin(); + auto end = is_input ? target_node.InputEdgesEnd() : target_node.OutputEdgesEnd(); + + for (; cur != end; ++cur) { + const Node& node = cur->GetNode(); + + // If we can find the node index in the dq or q nodes this is a quantized input/output + if (std::find(dq_or_q_nodes.cbegin(), dq_or_q_nodes.cend(), node.Index()) != dq_or_q_nodes.cend()) { + const auto node_inputs = node.InputDefs(); + // quantization scale and zp are always the input[1, 2] + NodeUnitIODef::QuantParam quant_param{*node_inputs[1], node_inputs.size() == 3 ? 
node_inputs[2] : nullptr}; + + if (is_input) { + // DQ is input to the target node, use the DstArgIndex + auto idx = cur->GetDstArgIndex(); + // This is a DQ node, we are using x, x_scale, x_zp (input[0, 1, 2]) + quantized_io_defs.insert({idx, NodeUnitIODef{*node_inputs[0], quant_param}}); + } else { + // Q is output of the target node, use the SrcArgIndex + auto idx = cur->GetSrcArgIndex(); + // This is a Q node, we are using y (output[0]), y_scale, y_zp (input[1, 2]) + const auto node_outputs = node.OutputDefs(); + quantized_io_defs.insert({idx, NodeUnitIODef{*node_outputs[0], quant_param}}); + } + } + } + + // Construct the IODefs for this QDQ NodeGroup + std::vector io_defs; + io_defs.reserve(target_node_io_defs_size); + for (size_t i = 0; i < target_node_io_defs_size; i++) { + // If we can find the NodeUnitIODef for this index, this is a quantized input/output + if (quantized_io_defs.find(i) != quantized_io_defs.cend()) { + io_defs.push_back(std::move(quantized_io_defs.at(i))); + } else { + // This is a regular input + io_defs.push_back({*target_node_io_defs[i], std::nullopt}); + } + } + + return io_defs; +} + +} // namespace + +Status QDQ::NodeGroup::CanCreateNodeGroup(const GraphViewer& graph_viewer, + const Node& target_node, + gsl::span dq_nodes, + gsl::span q_nodes) { + // Within a QDQ node group, a target node input is the only consumer of each DQ. + // This should have been ensured by the EnsureUniqueDQForNodeUnit graph transformer, but other graph modifications + // may have happened since. Verify that this is still true. + for (const auto* dq_node : dq_nodes) { + const bool dq_produces_graph_output = graph_viewer.NodeProducesGraphOutput(*dq_node); + ORT_RETURN_IF(dq_produces_graph_output, + "QDQ node group cannot have DQ node that produces a graph output. 
DQ node: ", dq_node->Name(), + ", target node: ", target_node.Name()); + + const bool dq_has_single_output_edge_to_target = + dq_node->GetOutputEdgesCount() == 1 && + dq_node->OutputEdgesBegin()->GetNode().Index() == target_node.Index(); + ORT_RETURN_IF_NOT(dq_has_single_output_edge_to_target, + "QDQ node group cannot have DQ that doesn't have a single output edge to the target node. " + "DQ node: ", + dq_node->Name(), ", target node: ", target_node.Name()); + } + + // an output from the target node can have either Q consumers or direct consumers. it cannot have both. + // this must be checked on a per output basis. + // e.g. TopK produces values and indices. The indices output won't be quantized, so even if we replace the TopK QDQ + // node group with a quantized TopK, an int64_t indices value will be produced and can provide a graph output. + if (!q_nodes.empty()) { + auto cur_edge = target_node.OutputEdgesBegin(); + auto end_edge = target_node.OutputEdgesEnd(); + std::vector output_consumers(target_node.OutputDefs().size(), nullptr); + + for (; cur_edge != end_edge; ++cur_edge) { + auto output_idx = cur_edge->GetSrcArgIndex(); + const Node& this_consumer = cur_edge->GetNode(); + const Node* existing_consumer = output_consumers[output_idx]; + + if (existing_consumer != nullptr) { + // another edge for this output. either both are Q or both are not. + bool valid = true; + if (existing_consumer->OpType() == "QuantizeLinear") { + valid = this_consumer.OpType() == "QuantizeLinear"; + } else { + valid = this_consumer.OpType() != "QuantizeLinear"; + } + + ORT_RETURN_IF_NOT(valid, + "QDQ node group cannot have an output from the target node being consumed by a Q node and " + "a non-Q node. 
target node: ", + target_node.Name()); + } else { + output_consumers[output_idx] = &this_consumer; + } + } + + const auto& graph_outputs = graph_viewer.GetOutputs(); + for (size_t idx = 0, end = output_consumers.size(); idx < end; ++idx) { + // any output with a Q cannot be a graph output as it will disappear if the QDQ node unit is converted to + // a quantized op. + if (output_consumers[idx] != nullptr && output_consumers[idx]->OpType() == "QuantizeLinear") { + const auto& output_name = target_node.OutputDefs()[idx]->Name(); + bool is_graph_output = std::any_of(graph_outputs.begin(), graph_outputs.end(), + [&output_name](const NodeArg* node_arg) { + return node_arg->Name() == output_name; + }); + ORT_RETURN_IF(is_graph_output, + "QDQ node group cannot have an output from the target node that is consumed by a Q node and " + "a graph output. target node: ", + target_node.Name(), " output idx:", idx); + } + } + } + + return Status::OK(); +} +NodeUnit::NodeUnit(const Node& node) + : target_node_(node), + type_(Type::SingleNode), + input_edge_count_(node.GetInputEdgesCount()) { + InitForSingleNode(); +} + +NodeUnit::NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_group) + : dq_nodes_{GetQDQIONodes(graph_viewer, node_group, true /* is_input */)}, + target_node_(*graph_viewer.GetNode(node_group.target_node)), + q_nodes_{GetQDQIONodes(graph_viewer, node_group, false /* is_input */)}, + type_(Type::QDQGroup), + inputs_{GetQDQIODefs(target_node_, node_group, true /* is_input */)}, + outputs_{GetQDQIODefs(target_node_, node_group, false /* is_input */)} { + ORT_THROW_IF_ERROR(QDQ::NodeGroup::CanCreateNodeGroup(graph_viewer, target_node_, dq_nodes_, q_nodes_)); + + input_edge_count_ = std::accumulate(dq_nodes_.cbegin(), dq_nodes_.cend(), size_t(0), + [](size_t acc, const Node* node) { return acc + node->GetInputEdgesCount(); }); + + // add edges for inputs that are not from DQ nodes. there is one edge to each DQ node. 
+ // other inputs could come from initializers or graph inputs (no edges) or other nodes (edge). + input_edge_count_ += target_node_.GetInputEdgesCount() - dq_nodes_.size(); + + // create output edges. each target node output either goes to Q node/s or non-Q node/s. + // ValidateNodeGroupQDQNodes ensures this. + auto cur_edge = target_node_.OutputEdgesBegin(); + auto end_edge = target_node_.OutputEdgesEnd(); + for (; cur_edge != end_edge; ++cur_edge) { + const Node& node = cur_edge->GetNode(); + + // if node is in q_nodes we hide the Q node. + if (std::find(q_nodes_.cbegin(), q_nodes_.cend(), &node) != q_nodes_.cend()) { + auto src_idx = cur_edge->GetSrcArgIndex(); + auto q_cur_edge = node.OutputEdgesBegin(); + auto q_end_edge = node.OutputEdgesEnd(); + for (; q_cur_edge != q_end_edge; ++q_cur_edge) { + output_edges_.insert(Node::EdgeEnd{q_cur_edge->GetNode(), src_idx, q_cur_edge->GetDstArgIndex()}); + } + } else { + // non-Q node, or Q node that isn't in the QDQ node group (unexpected but may be possible). add as-is. 
+ output_edges_.insert(*cur_edge); + } + } +} + +const std::string& NodeUnit::Domain() const noexcept { return target_node_.Domain(); } +const std::string& NodeUnit::OpType() const noexcept { return target_node_.OpType(); } +const std::string& NodeUnit::Name() const noexcept { return target_node_.Name(); } +int NodeUnit::SinceVersion() const noexcept { return target_node_.SinceVersion(); } +NodeIndex NodeUnit::Index() const noexcept { return target_node_.Index(); } +const Path& NodeUnit::ModelPath() const noexcept { return target_node_.ModelPath(); } +ProviderType NodeUnit::GetExecutionProviderType() const noexcept { return target_node_.GetExecutionProviderType(); } + +void NodeUnit::InitForSingleNode() { + const auto& input_defs = target_node_.InputDefs(); + const auto& output_defs = target_node_.OutputDefs(); + auto qlinear_type = GetQLinearOpType(target_node_); + if (qlinear_type == QLinearOpType::Unknown || IsVariadicQLinearOp(qlinear_type)) { // TODO, add variadic support + // Not a Qlinear op, add all inputs / outputs + auto add_all_io = [](std::vector& defs, + const ConstPointerContainer>& node_defs) { + defs.reserve(node_defs.size()); + + for (const auto def : node_defs) { + defs.push_back(NodeUnitIODef{*def, std::nullopt}); + } + }; + + add_all_io(inputs_, input_defs); + add_all_io(outputs_, output_defs); + } else if (IsUnaryQLinearOp(qlinear_type)) { + // Unary QLinear Op has 5 inputs + // x, x_scale, x_zp, y_scale, y_zp (optional) + inputs_.push_back(NodeUnitIODef{*input_defs[0], NodeUnitIODef::QuantParam{*input_defs[1], input_defs[2]}}); + outputs_.push_back(NodeUnitIODef{*output_defs[0], + NodeUnitIODef::QuantParam{*input_defs[3], + input_defs.size() > 4 ? 
input_defs[4] : nullptr}}); + + } else if (IsBinaryQLinearOp(qlinear_type)) { + // Binary QLinear Op has 9 inputs + // x1, x1_scale, x1_zp, x2/w, x2_scale, x2_zp, y_scale , y_zp, B + inputs_.push_back(NodeUnitIODef{*input_defs[0], NodeUnitIODef::QuantParam{*input_defs[1], input_defs[2]}}); + inputs_.push_back(NodeUnitIODef{*input_defs[3], NodeUnitIODef::QuantParam{*input_defs[4], input_defs[5]}}); + + if (input_defs.size() == 9) { // has Bias + inputs_.push_back(NodeUnitIODef{*input_defs[8], std::nullopt}); // for Bias the scale and zp are optional + } + + outputs_.push_back(NodeUnitIODef{*output_defs[0], NodeUnitIODef::QuantParam{*input_defs[6], input_defs[7]}}); + + } else if (qlinear_type == QLinearOpType::DequantizeLinear) { + // DequantizeLinear has 3 inputs + // x, x_scale, x_zp + // output is not quantized + inputs_.push_back(NodeUnitIODef{*input_defs[0], NodeUnitIODef::QuantParam{*input_defs[1], input_defs.size() == 3 + ? input_defs[2] + : nullptr}}); + outputs_.push_back(NodeUnitIODef{*output_defs[0], std::nullopt}); + + } else if (qlinear_type == QLinearOpType::QuantizeLinear) { + // QuantizeLinear the input is not quantized and has 3 inputs + // x, y_scale, y_zp (optional) + // The output is quantized + inputs_.push_back(NodeUnitIODef{*input_defs[0], std::nullopt}); + outputs_.push_back(NodeUnitIODef{*output_defs[0], NodeUnitIODef::QuantParam{*input_defs[1], input_defs.size() == 3 + ? input_defs[2] + : nullptr}}); + } else { + ORT_THROW("The QLinear op [", static_cast(qlinear_type), "] is not supported"); + } +} + +Node::EdgeConstIterator NodeUnit::OutputEdgesBegin() const { + return (type_ == Type::SingleNode) ? target_node_.OutputEdgesBegin() : output_edges_.begin(); +} + +Node::EdgeConstIterator NodeUnit::OutputEdgesEnd() const { + return (type_ == Type::SingleNode) ? 
target_node_.OutputEdgesEnd() : output_edges_.end(); +} + +std::vector NodeUnit::GetAllNodesInGroup() const noexcept { + std::vector all_nodes = dq_nodes_; + all_nodes.push_back(&target_node_); + all_nodes.insert(all_nodes.end(), q_nodes_.begin(), q_nodes_.end()); + return all_nodes; +} + +} // namespace onnxruntime + +#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/core/providers/shared/node_unit/node_unit.h b/onnxruntime/core/framework/node_unit.h similarity index 54% rename from onnxruntime/core/providers/shared/node_unit/node_unit.h rename to onnxruntime/core/framework/node_unit.h index b47204ca3c42d..66afaec8ee1e2 100644 --- a/onnxruntime/core/providers/shared/node_unit/node_unit.h +++ b/onnxruntime/core/framework/node_unit.h @@ -3,6 +3,9 @@ #pragma once +// QDQ models require graph modification at runtime, so we know this infrastructure is not used in a minimal build +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + #include #include #include @@ -18,8 +21,21 @@ class NodeArg; class Path; namespace QDQ { -struct NodeGroup; -} +// Struct to represent a DequantizeLinear -> Op -> QuantizeLinear node group +struct NodeGroup { + std::vector dq_nodes; + std::vector q_nodes; + NodeIndex target_node; + + // Validator to check if the set of nodes can form a valid QDQ NodeGroup. + // Checks target node is only consumer of each DQ, and that the outputs remain valid if the QDQ node group was to + // be converted into a single node with a quantized operator. 
+ static Status CanCreateNodeGroup(const GraphViewer& graph_viewer, + const Node& target_node, + gsl::span dq_nodes, + gsl::span q_nodes); +}; +} // namespace QDQ // Definition of one input or output // If the optional quant_param is present, then this is a quantized input, @@ -69,26 +85,33 @@ class NodeUnit { const std::vector& GetQNodes() const noexcept { return q_nodes_; } std::vector GetAllNodesInGroup() const noexcept; - Node::EdgeConstIterator OutputEdgesBegin(size_t index) const; - Node::EdgeConstIterator OutputEdgesEnd(size_t index) const; + /// Number of input edges to the logical node. For a QDQ node this is the count of input edges to the DQ nodes + /// plus any other edges to the target node for inputs that are not via a DQ node. + size_t InputEdgeCount() const { return input_edge_count_; } + + // output edges. src index is for outputs of the target node. dest index and node is for consumer of node unit + // output. any Q nodes are hidden. + Node::EdgeConstIterator OutputEdgesBegin() const; + Node::EdgeConstIterator OutputEdgesEnd() const; private: - const std::vector q_nodes_; // q-nodes for this NodeUnit - const std::vector dq_nodes_; // dq nodes for this NodeUnit, not all inputs + // Initialization for a NodeUnit that contains a single node + void InitForSingleNode(); + + const std::vector dq_nodes_; // dq nodes for this NodeUnit, not necessarily all inputs const Node& target_node_; + const std::vector q_nodes_; // q-nodes for this NodeUnit. 
not necessarily all outputs const Type type_; std::vector inputs_; std::vector outputs_; - // Initializing for a single Node - void InitForSingleNode(); -}; + size_t input_edge_count_; // total number of input edges -// Get all the nodes in the given graph_viewer as NodeUnits (SingleNode or QDQGroup) -// And return a map to quick query the NodeUnit which contains the given Node, -// Note, the value of the map is owned by the vector of std::unique_ptr -std::pair>, std::unordered_map> -GetAllNodeUnits(const GraphViewer& graph_viewer); + // output edges, hiding any Q nodes involved. src_idx will be value from target node. only used for QDQ node group. + Node::EdgeSet output_edges_; +}; } // namespace onnxruntime + +#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/core/framework/sequential_executor.cc b/onnxruntime/core/framework/sequential_executor.cc index ea7f1397c961b..0cc7294a46495 100644 --- a/onnxruntime/core/framework/sequential_executor.cc +++ b/onnxruntime/core/framework/sequential_executor.cc @@ -306,18 +306,20 @@ class KernelScope { #endif #ifdef ENABLE_NVTX_PROFILE - auto& node = kernel_.Node(); - profile::NvtxRangeCreator& forward_range = session_scope_.forward_range_; - profile::NvtxRangeCreator& backward_range = session_scope_.backward_range_; - if (node.Description() != "Backward pass" && !forward_range.IsBeginCalled()) { - // Start timing forward pass when encountering the first forward node. - forward_range.Begin(); - } else if (node.Description() == "Backward pass" && !backward_range.IsBeginCalled() && - forward_range.IsBeginCalled()) { - // Start timing backward pass when encountering the first backward node. - // In the meanwhile, forward range ends. 
- forward_range.End(); - backward_range.Begin(); + { + auto& node = kernel_.Node(); + profile::NvtxRangeCreator& forward_range = session_scope_.forward_range_; + profile::NvtxRangeCreator& backward_range = session_scope_.backward_range_; + if (node.Description() != "Backward pass" && !forward_range.IsBeginCalled()) { + // Start timing forward pass when encountering the first forward node. + forward_range.Begin(); + } else if (node.Description() == "Backward pass" && !backward_range.IsBeginCalled() && + forward_range.IsBeginCalled()) { + // Start timing backward pass when encountering the first backward node. + // In the meanwhile, forward range ends. + forward_range.End(); + backward_range.Begin(); + } } #endif diff --git a/onnxruntime/core/framework/session_state.h b/onnxruntime/core/framework/session_state.h index 5bf229c5a3a19..e318c9a8238c7 100644 --- a/onnxruntime/core/framework/session_state.h +++ b/onnxruntime/core/framework/session_state.h @@ -8,7 +8,7 @@ #include #include -#include "flatbuffers/flatbuffers.h" +#include "core/common/flatbuffers.h" #include "core/common/gsl.h" diff --git a/onnxruntime/core/framework/stream_execution_context.cc b/onnxruntime/core/framework/stream_execution_context.cc index 875e7f395bfa8..dd7f4d35b34bd 100644 --- a/onnxruntime/core/framework/stream_execution_context.cc +++ b/onnxruntime/core/framework/stream_execution_context.cc @@ -181,11 +181,13 @@ void RunSince(size_t stream_idx, StreamExecutionContext& ctx, SessionScope& sess } #ifdef USE_CANN + // Leave it to CANN EP to fill the gap if they want to use run_options + static onnxruntime::RunOptions run_options; // For CANN EP, it is necessary to explicitly create a corresponding Context for each thread in the thread pool, // which is different from CUDA Runtime API, but similar to CUDA Driver API. 
auto& execution_providers = ctx.GetSessionState().GetExecutionProviders(); for (auto& xp : execution_providers) { - auto status = xp->OnRunStart(); + auto status = xp->OnRunStart(run_options); if (!status.IsOK()) { ctx.SetStatus(status); return; diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index 5c8854b13a4e7..f2c9aaa395be0 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -1019,9 +1019,19 @@ bool IsInputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index) } #ifdef ENABLE_ATEN + // For ATen node, we assume that all tensor inputs are on device, all non-tensor inputs are on CPU, + // except those specified in attribute cpu_input_args; if (node.GetExecutionProviderType() == kCudaExecutionProvider && node.OpType() == "ATen" && node.Domain() == kPytorchAtenDomain) { const auto& attrs = node.GetAttributes(); + if (auto entry = attrs.find("cpu_input_args"); entry != attrs.end()) { + const auto& attr = entry->second; + if (utils::HasInts(attr) && std::any_of(attr.ints().cbegin(), attr.ints().cend(), + [index](int64_t arg) { return static_cast(index) == arg; })) { + return true; + } + } + ORT_ENFORCE(utils::HasString(attrs.at("operator"))); std::string op_name = attrs.at("operator").s(); std::string overload_name = ""; @@ -1029,7 +1039,7 @@ bool IsInputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index) overload_name = attrs.at("overload_name").s(); } - return contrib::aten_ops::ATenOperatorExecutor::Instance().IsCpuArgument(op_name, overload_name, index, true); + return !contrib::aten_ops::ATenOperatorExecutor::Instance().IsTensorArgument(op_name, overload_name, index, true); } #else ORT_UNUSED_PARAMETER(node); @@ -1044,9 +1054,19 @@ bool IsOutputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index } #ifdef ENABLE_ATEN + // For ATen node, we assume that all tensor outputs are on device, all non-tensor outputs are on CPU, + // except those 
specified in attribute cpu_output_args; if (node.GetExecutionProviderType() == kCudaExecutionProvider && node.OpType() == "ATen" && node.Domain() == kPytorchAtenDomain) { const auto& attrs = node.GetAttributes(); + if (auto entry = attrs.find("cpu_output_args"); entry != attrs.end()) { + const auto& attr = entry->second; + if (utils::HasInts(attr) && std::any_of(attr.ints().cbegin(), attr.ints().cend(), + [index](int64_t arg) { return static_cast(index) == arg; })) { + return true; + } + } + ORT_ENFORCE(utils::HasString(attrs.at("operator"))); std::string op_name = attrs.at("operator").s(); std::string overload_name = ""; @@ -1054,7 +1074,7 @@ bool IsOutputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index overload_name = attrs.at("overload_name").s(); } - return contrib::aten_ops::ATenOperatorExecutor::Instance().IsCpuArgument(op_name, overload_name, index, false); + return !contrib::aten_ops::ATenOperatorExecutor::Instance().IsTensorArgument(op_name, overload_name, index, false); } #else ORT_UNUSED_PARAMETER(node); diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 8583474a1e391..8bf013ed009d5 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -259,6 +259,16 @@ void GroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& *output_shape.add_dim() = query_dims[1]; *output_shape.add_dim() = query_dims[2]; updateOutputShape(ctx, 0, output_shape); + } else { + ONNX_NAMESPACE::TensorShapeProto output_shape; + int64_t num_heads = getAttribute(ctx, "num_heads", 0); + int64_t kv_num_heads = getAttribute(ctx, "kv_num_heads", 0); + int64_t hidden_size = query_dims[2].dim_value(); + int64_t head_size = hidden_size / (num_heads + 2 * kv_num_heads); + *output_shape.add_dim() = query_dims[0]; + *output_shape.add_dim() = query_dims[1]; + output_shape.add_dim()->set_dim_value(head_size * num_heads); + 
updateOutputShape(ctx, 0, output_shape); } } diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 27c968a59eb91..6709398c788f0 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -1163,7 +1163,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(BeamSearch, 1, "Shape is (1,)", "T", OpSchema::Optional) .Input(6, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional) - .Input(7, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "M", OpSchema::Optional) + .Input(7, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)", "M", OpSchema::Optional) .Input(8, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "M", OpSchema::Optional) .Input(9, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional) .Input(10, "decoder_input_ids", "The forced input id sequence for the decoder subgraph. Shape is (batch_size, initial_sequence_length)", "I", OpSchema::Optional) @@ -1188,7 +1188,15 @@ ONNX_MS_OPERATOR_SET_SCHEMA(WhisperBeamSearch, 1, .SetDoc("Beam Search for whisper model, especiall with cross_qk features etc.") .Attr("eos_token_id", "The id of the end-of-sequence token", AttributeProto::INT) .Attr("pad_token_id", "The id of the padding token", AttributeProto::INT) - .Attr("decoder_start_token_id", "The id of the token that indicates decoding starts.", AttributeProto::INT, static_cast(-1)) + .Attr("decoder_start_token_id", "The id of the token that indicates decoding starts (i.e. 
the start of transcription token id)", AttributeProto::INT, static_cast(-1)) + .Attr("translate_token_id", "The id of the translate task", AttributeProto::INT, OPTIONAL_VALUE) + .Attr("transcribe_token_id", "The id of the transcribe task", AttributeProto::INT, OPTIONAL_VALUE) + .Attr("start_of_lm_token_id", "The id of the token that indicates LM starts", AttributeProto::INT, OPTIONAL_VALUE) + .Attr("no_speech_token_id", + "The token in whisper model that marks all sequence empty. With this model, whisper could output no_speech_prob after. Default -1.", + AttributeProto::INT, OPTIONAL_VALUE) + .Attr("no_timestamps_token_id", "The id of the token that indicates no timestamps", AttributeProto::INT, OPTIONAL_VALUE) + .Attr("beginning_timestamp_token_id", "The id of the first timestamp", AttributeProto::INT, OPTIONAL_VALUE) .Attr("no_repeat_ngram_size", "no repeat ngrams size", AttributeProto::INT, static_cast(0)) .Attr("early_stopping", "early stop or not", AttributeProto::INT, static_cast(0)) .Attr("model_type", "Must be 2 for whisper", AttributeProto::INT, static_cast(2)) @@ -1203,27 +1211,24 @@ ONNX_MS_OPERATOR_SET_SCHEMA(WhisperBeamSearch, 1, "If not provided, it will be inferred from the decoder subgraph's output shape", AttributeProto::INT, static_cast(-1)) .Attr("decoder_output_cross_qk", "If nozero, decoder subgraph contains output Q*K from cross attentions. Default 0.", AttributeProto::INT, OPTIONAL_VALUE) - .Attr("no_speech_token", - "The token in whisper model that marks all sequence empty. With this model, whisper could output no_speech_prob after. Default -1.", - AttributeProto::INT, OPTIONAL_VALUE) .Input(0, "input_ids", "The sequence used as a prompt for the generation in the encoder subgraph. Shape is (batch_size, sequence_length)", "F") .Input(1, "max_length", "The maximum length of the sequence to be generated. Shape is (1)", "I") .Input(2, "min_length", "The minimum length below which the score of eos_token_id is set to -Inf. 
Shape is (1)", "I", OpSchema::Optional) .Input(3, "num_beams", "Number of beams for beam search. 1 means no beam search. Shape is (1)", "I") .Input(4, "num_return_sequences", "The number of returned sequences in the batch. Shape is (1)", "I") .Input(5, "length_penalty", - "Exponential penalty to the length. Default value 1.0 means no penalty." - "Value > 1.0 encourages longer sequences, while values < 1.0 produces shorter sequences." + "Exponential penalty to the length. Default value 1.0 means no penalty. " + "Value > 1.0 encourages longer sequences, while values < 1.0 produces shorter sequences. " "Shape is (1,)", "T", OpSchema::Optional) .Input(6, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional) - .Input(7, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "M", OpSchema::Optional) + .Input(7, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)", "M", OpSchema::Optional) .Input(8, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "M", OpSchema::Optional) .Input(9, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional) .Input(10, "decoder_input_ids", "The forced input id sequence for the decoder subgraph. Shape is (batch_size, initial_sequence_length)", "I", OpSchema::Optional) .Input(11, "logits_processor", "Specific logits processor for different types of beamsearch models. Default value 0 means no specific logit processor. Accepts value >= 0. Shape is (1)", "I", OpSchema::Optional) .Input(12, "cross_qk_layer_head", - "Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. 
Default collect all" + "Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. Default collect all " "its shape is (number of (layer, head) to keep, 2), i.e., [[layer_id1, head_id1], [layer_id2, head_id2]......]", "I", OpSchema::Optional) .Input(13, "extra_decoding_ids", @@ -1235,20 +1240,19 @@ ONNX_MS_OPERATOR_SET_SCHEMA(WhisperBeamSearch, 1, .Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, num_return_sequences, max_sequence_length)", "I") .Output(1, "sequences_scores", "Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)", "T", OpSchema::Optional) .Output(2, "scores", - "Processed beam scores for each vocabulary token at each generation step." - "Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam." + "Processed beam scores for each vocabulary token at each generation step. " + "Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam. " "Shape is (max_length - sequence_length, batch_size, num_beams, vocab_size)", "T", OpSchema::Optional) .Output(3, "cross_qk", "Output the accumulated stacked Q*K in cross attentions. Let H = number of Head of cross attention, " - "F = the frames or kv-seq-len of the cross attention input, T = real decoded token length, L = number of layers," - "B = batch size, R = num_return_sequences. It then should return tensor of shape [B, R, L*H, T, F]." + "F = the frames or kv-seq-len of the cross attention input, T = real decoded token length, L = number of layers, " + "B = batch size, R = num_return_sequences. It then should return tensor of shape [B, R, L*H, T, F]. 
" "If cross_qk_layer_head is given, shape is [B, R, cross_qk_layer_head.shape[0], T, F]", "V", OpSchema::Optional) .Output(4, "non_speech_probs", - "For whisper model, output the probabilities from logits after encoder and context decoding for the no_speech_token." - "Currently we treat the last token's logits is what we need, in future extra graph logic may be add to the encoder/context-decoder subgraph." - "The prob is save before logits may be updated by extra-decoding-ids. The shape of non_speech_probs is [B]", + "For whisper model, output the probabilities from logits after encoder and context decoding for the no_speech_token_id. " + "The shape of non_speech_probs is [B]", "T", OpSchema::Optional) .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain to float tensors.") .TypeConstraint("F", {"tensor(float)", "tensor(int32)", "tensor(float16)"}, "Constrain input type to float or int tensors.") @@ -1322,7 +1326,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(GreedySearch, 1, .Input(1, "max_length", "The maximum length of the sequence to be generated. Shape is (1)", "I") .Input(2, "min_length", "The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)", "I", OpSchema::Optional) .Input(3, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional) - .Input(4, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "I", OpSchema::Optional) + .Input(4, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)", "I", OpSchema::Optional) .Input(5, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "I", OpSchema::Optional) .Input(6, "attention_mask", "Custom attention mask. 
Shape is (batch_size, sequence_length)", "I", OpSchema::Optional) .Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, max_sequence_length)", "I") @@ -1363,7 +1367,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(Sampling, 1, .Input(1, "max_length", "The maximum length of the sequence to be generated. Shape is (1)", "I") .Input(2, "min_length", "The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)", "I", OpSchema::Optional) .Input(3, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional) - .Input(4, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "I", OpSchema::Optional) + .Input(4, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)", "I", OpSchema::Optional) .Input(5, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "I", OpSchema::Optional) .Input(6, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional) .Input(7, "presence_mask", "Presence penalty mask. Shape is (batch_size, vocab_size)", "I", OpSchema::Optional) @@ -3339,22 +3343,23 @@ MatMulNBits is a MatMul with weight quantized with N bits(e.g., 2, 3, 4, 5, 6, 7 And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,.. 3. Input B's scale and zero point are specified by input scales and zero_points. 
-Input B is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which: -- n_blocks_per_col = (K + block_size - 1) / block_size -- blob_size = block_size / 8 * bits + Input is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which: + - n_blocks_per_col = (K + block_size - 1) / block_size + - blob_size = CeilDiv(block_size * bits, bitsof(uint8_t)<8>) + For all bits from 2-8, a row of data is stored squeezely and represented by uint8_t. + - for 2,4,8 bits, 4x2bit,2x4bit,1x8bit are stored in one uint8_t. + 4bit example: + |.|.|.|.| .|.|.|.| =uint8_t (2x4bit) + - for 3,5,6,7 bits, 32x3bit,32x5bit,16x6bit,32x7bit are stored in 12xuint8_t,20xuint8_t,12xuint8_t,28xuint8_t separately. no bits are wasted. + 3bit example: + |.|.|. |.|.|. |.|.|. = 9bit, which across 2 uint8_t, the highest bit for the second uint8_t is used. + The last uint_8 may have some bits unused. - For a block blob. It is stored in format: - struct Blob { - uint8 one_bits[(bits & 0x1) * 1 * block_size / 8]; // highest 1 bit for 3, 5, 7 bits quantization - uint8 two_bits[(bits & 0x2) * 2 * block_size / 8]; // high 2 bits for 2, 6, 7 bits quantization - uint8 four_bits[(bits & 0x4) * 4 * block_size / 8]; // low 4 bits for 4, 5, 6 bits quantization - } Input scales is stored in same type as original type of B(float32, float16) with shape like: [N * n_blocks_per_col] -Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored as one unit8_t. If bits > 4, one zero point is stored with one unit8_t. Thus, its shape is: - - [(N * n_blocks_per_col + 1) / 2] if bits <=4 - - [N * n_blocks_per_col] if bits > 4 - +Input zero_points is stored as uint8_t or same as type(A). It has the same packing method as input B. + - [CeilDiv((N * n_blocks_per_col + 1) *bits, 8)] + If zero_points has same type as A, it's not packed and has the same shape as Scales. )DOC"; ONNX_CONTRIB_OPERATOR_SCHEMA(MatMulNBits) @@ -3373,12 +3378,15 @@ Input zero_points is stored as uint8_t. 
If bits <= 4, two zero points are stored "type T1.", AttributeProto::INT, static_cast(0)) .Input(0, "A", "The input tensor, not quantized", "T1") - .Input(1, "B", "1-dimensional data blob", "T2") + .Input(1, "B", "1 or 2 dimensional data blob", "T2") .Input(2, "scales", "quantization scale", "T1") - .Input(3, "zero_points", "quantization zero points", "T2", OpSchema::Optional) + .Input(3, "zero_points", "quantization zero points", "T3", OpSchema::Optional) + .Input(4, "g_idx", "group_idx", "T4", OpSchema::Optional) .Output(0, "Y", "tensor. The output tensor has the same rank as the input. ", "T1") .TypeConstraint("T1", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float/half_float tensors.") - .TypeConstraint("T2", {"tensor(uint8)"}, "Constrain quantized weight types to uint8.") + .TypeConstraint("T2", {"tensor(uint8)", "tensor(int32)"}, "Constrain quantized weight types to uint8/int32.") + .TypeConstraint("T3", {"tensor(uint8)", "tensor(int32)", "tensor(float16)", "tensor(float)"}, "Constrain quantized zero point types to uint8/int32/float16/float.") + .TypeConstraint("T4", {"tensor(int32)"}, "the index tensor.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { // Type inference propagateElemTypeFromInputToOutput(ctx, 0, 0); @@ -3466,6 +3474,8 @@ MatMulBnb4 is a MatMul with weight quantized with 4 bits using either FP4 or NF4 /*min_arity*/ 1) .Attr("operator", "Name of ATen operator.", AttributeProto::STRING) .Attr("overload_name", "Overload name of ATen operator.", AttributeProto::STRING, false) + .Attr("cpu_input_args", "CPU input argument indices.", AttributeProto::INTS, false) + .Attr("cpu_output_args", "CPU output argument indices.", AttributeProto::INTS, false) .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Allow inputs and outputs to be any kind of tensor."); #endif diff --git a/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc 
b/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc index c8960578f9e3d..6bf19654a3ce9 100644 --- a/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc +++ b/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc @@ -106,6 +106,7 @@ void OpSet_Internal_NHWC_ONNX::ForEachSchema(const std::function from, } } +template +struct VisitorPriorityQueue { + using ComparatorType = std::function; + std::list list_; + const ComparatorType comparator_ = nullptr; + VisitorPriorityQueue(const ComparatorType& comp) : comparator_(comp) {} + + void push(T node) { + list_.insert( + std::upper_bound(list_.begin(), list_.end(), node, comparator_), + node); + } + bool empty() { return list_.empty(); } + T top() { return list_.back(); } + void pop() { list_.pop_back(); } +}; + #if !defined(ORT_MINIMAL_BUILD) void Graph::KahnsTopologicalSort(const std::function& enter, const std::function& comp) const { - std::unordered_map in_degree; - std::priority_queue, decltype(comp)> to_visit(comp); - std::vector topo_order; + InlinedVector in_degree(MaxNodeIndex(), 0); + InlinedVector topo_order; + VisitorPriorityQueue to_visit(comp); + + auto number_of_nodes = NumberOfNodes(); + topo_order.reserve(number_of_nodes); for (auto& node : Nodes()) { size_t input_edge_count = node.GetInputEdgesCount(); - in_degree.insert({node.Index(), input_edge_count}); + in_degree[node.Index()] = input_edge_count; if (input_edge_count == 0) { to_visit.push(&node); } @@ -1844,16 +1864,17 @@ void Graph::KahnsTopologicalSort(const std::function& enter, } for (auto node_it = current->OutputNodesBegin(); node_it != current->OutputNodesEnd(); ++node_it) { - in_degree[node_it->Index()]--; + auto& node_in_degree = in_degree[node_it->Index()]; + node_in_degree--; - if (in_degree[node_it->Index()] == 0) { + if (node_in_degree == 0) { to_visit.push(&*node_it); } } topo_order.push_back(current->Index()); } - if (NumberOfNodes() != static_cast(topo_order.size())) { + if (number_of_nodes != 
static_cast(topo_order.size())) { ORT_THROW("Some nodes are not included in the topological sort, graph have a cycle."); } } @@ -2798,12 +2819,13 @@ Status Graph::Resolve(const ResolveOptions& options) { graph.GraphProtoSyncNeeded(false); } + // set num_resolves_ here so the graph and any subgraphs all have the same value + ++graph.num_resolves_; + return Status::OK(); }; ORT_RETURN_IF_ERROR(ForThisAndAllSubgraphs(all_subgraphs, finalize_func)); - ++num_resolves_; - return Status::OK(); } @@ -2842,7 +2864,7 @@ void Graph::AddInitializedTensor(const TensorProto& tensor) { const gsl::not_null tensor_added{graph_proto_->add_initializer()}; *(tensor_added) = tensor; - name_to_initial_tensor_[tensor.name()] = tensor_added; + name_to_initial_tensor_.emplace(tensor.name(), tensor_added); SetGraphResolveNeeded(); if (!is_loaded_from_model_file_ && GetNodeArg(tensor.name()) == nullptr) { // make sure there is a NodeArg for the initializer as SetGraphInputsOutputs may add it to the graph inputs. diff --git a/onnxruntime/core/graph/graph_flatbuffers_utils.cc b/onnxruntime/core/graph/graph_flatbuffers_utils.cc index 8e962403556dd..2314a5228f83c 100644 --- a/onnxruntime/core/graph/graph_flatbuffers_utils.cc +++ b/onnxruntime/core/graph/graph_flatbuffers_utils.cc @@ -3,7 +3,7 @@ #include "graph_flatbuffers_utils.h" -#include "flatbuffers/flatbuffers.h" +#include "core/common/flatbuffers.h" #include "core/common/narrow.h" #include "core/flatbuffers/flatbuffers_utils.h" @@ -392,6 +392,14 @@ Status LoadOrtTensorOrtFormat(const fbs::Tensor& fbs_tensor, const AllocatorPtr ort_tensor = onnxruntime::Tensor( tensor_dtype, TensorShape(tensor_dims->data(), tensor_dims->size()), allocator); + if (fbs_tensor.raw_data()->size() == 0U) { + // Empty tensor. Nothing to unpack. + // This check is necessary because an empty ort tensor will return a size of 1. + // As a result, the following call to UnpackTensor will fail since the src and + // dst sizes do not match (0 and 1 elements). 
+ return Status::OK(); + } + // The tensor proto is used as a dummy here. The actual data is stored in the raw_data field of the flatbuffer. // The data is copied from the raw_data field to the ort_tensor. ONNX_NAMESPACE::TensorProto unused_tensor_proto; diff --git a/onnxruntime/core/graph/graph_flatbuffers_utils.h b/onnxruntime/core/graph/graph_flatbuffers_utils.h index b625cbf3ca492..9c55dad3c41ef 100644 --- a/onnxruntime/core/graph/graph_flatbuffers_utils.h +++ b/onnxruntime/core/graph/graph_flatbuffers_utils.h @@ -5,7 +5,7 @@ #include -#include "flatbuffers/flatbuffers.h" +#include "core/common/flatbuffers.h" #include "core/common/status.h" #include "core/graph/ort_format_load_options.h" diff --git a/onnxruntime/core/graph/graph_viewer.cc b/onnxruntime/core/graph/graph_viewer.cc index cf78040ea5ac6..119d420066a84 100644 --- a/onnxruntime/core/graph/graph_viewer.cc +++ b/onnxruntime/core/graph/graph_viewer.cc @@ -14,8 +14,8 @@ bool NodeCompare::operator()(const Node* n1, const Node* n2) const { struct PriorityNodeCompare { inline bool IsHighPri(const Node* n) const { // local statics so we can compare std::strings in the checks - static const std::string shape_op("Shape"); - static const std::string size_op("Size"); + static constexpr std::string_view shape_op("Shape"); + static constexpr std::string_view size_op("Size"); const auto& op_type = n->OpType(); return op_type == shape_op || op_type == size_op; @@ -26,15 +26,20 @@ struct PriorityNodeCompare { // If return true, n2 will be output first bool operator()(const Node* n1, const Node* n2) const { // nodes in global high priority list will be output first - if (IsHighPri(n1) != IsHighPri(n2)) { - return IsHighPri(n2); + const bool isN1HighPri = IsHighPri(n1); + const bool isN2HighPri = IsHighPri(n2); + if (isN1HighPri != isN2HighPri) { + return isN2HighPri; } // nodes with lower priority value will be output first - if (n1->Priority() != n2->Priority()) { - return n1->Priority() > n2->Priority(); + const auto 
n1_priority = n1->Priority(); + const auto n2_priority = n2->Priority(); + if (n1_priority != n2_priority) { + return n1_priority > n2_priority; } +#ifdef ENABLE_TRAINING // nodes of forward pass will be output first auto n1_attrs = n1->GetAttributes(); auto n2_attrs = n2->GetAttributes(); @@ -45,6 +50,7 @@ struct PriorityNodeCompare { if (n1_is_forward != n2_is_forward) { return n2_is_forward > n1_is_forward; } +#endif // otherwise, nodes with lower index will be output first return n1->Index() > n2->Index(); @@ -212,6 +218,8 @@ const std::string& GraphViewer::Description() const noexcept { bool GraphViewer::GetInitializedTensor(const std::string& tensor_name, const ONNX_NAMESPACE::TensorProto*& value) const { + value = nullptr; + // if we are using filtered subgraph, the initializer has to be part of the subgraph if (filter_info_ != nullptr && filtered_initializers_.find(tensor_name) == filtered_initializers_.cend()) return false; diff --git a/onnxruntime/core/graph/model.h b/onnxruntime/core/graph/model.h index 4ce6660b794bc..a774d5fe34461 100644 --- a/onnxruntime/core/graph/model.h +++ b/onnxruntime/core/graph/model.h @@ -8,7 +8,7 @@ #include #include -#include "flatbuffers/flatbuffers.h" +#include "core/common/flatbuffers.h" #include "core/common/path.h" #include "core/graph/graph_viewer.h" diff --git a/onnxruntime/core/graph/op_identifier_utils.h b/onnxruntime/core/graph/op_identifier_utils.h index 8a9351a2d0ddc..f7b1198c31972 100644 --- a/onnxruntime/core/graph/op_identifier_utils.h +++ b/onnxruntime/core/graph/op_identifier_utils.h @@ -3,7 +3,7 @@ #pragma once -#include "flatbuffers/flatbuffers.h" +#include "core/common/flatbuffers.h" #include "core/graph/op_identifier.h" diff --git a/onnxruntime/core/graph/runtime_optimization_record_container.h b/onnxruntime/core/graph/runtime_optimization_record_container.h index a28b19e786de0..75750c2b96987 100644 --- a/onnxruntime/core/graph/runtime_optimization_record_container.h +++ 
b/onnxruntime/core/graph/runtime_optimization_record_container.h @@ -9,7 +9,7 @@ #include #include -#include "flatbuffers/flatbuffers.h" +#include "core/common/flatbuffers.h" #include "core/common/common.h" #include "core/graph/runtime_optimization_record.h" diff --git a/onnxruntime/core/mickey/README.md b/onnxruntime/core/mickey/README.md index 7e8d30cd1805b..735ec4b80daf3 100644 --- a/onnxruntime/core/mickey/README.md +++ b/onnxruntime/core/mickey/README.md @@ -4,3 +4,7 @@ Playful name for a template library of high performance cuda code that are often shared by various AI operators. The intention is to make this header files only, with no binary impact unless it is instantiated where it is needed. + +Currently cuda code are scattered in multiple locations in the repo. +Hopefully this can be the starting point of consolidating all cuda +code. diff --git a/onnxruntime/core/mickey/blk_q4/f16_gemm_sm80.h b/onnxruntime/core/mickey/blk_q4/f16_gemm_sm80.h new file mode 100644 index 0000000000000..52bff7e40dbe3 --- /dev/null +++ b/onnxruntime/core/mickey/blk_q4/f16_gemm_sm80.h @@ -0,0 +1,208 @@ +/** + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + * + * Module Name: + * blk_q4/f16_gemm_sm80.h + * + * Abstract: + * Entry point for Q4F16 GEMM kernel for SM80 devices. 
+ */ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass_ext/q4gemm/device/quantb_gemm.h" + +namespace onnxruntime { +namespace cuda { + +// +// This is the implementation of the quantized GEMM kernel for 16b float x blocked quantized 4b data type +// +template < + typename ElementDequant_, // <- data type of dequantized elements for gemm, fp16 or bf16 + typename QuantBlocking_, // <- weights block per scale, cutlass::MatrixShape + bool SmallM, // <- true if M <= 16 + bool kHasQuantOffset> +struct BlkQ4F16GemmImpl { + // + // Type definitions + // + + using ElementDequant = ElementDequant_; + using QuantBlocking = QuantBlocking_; + + static_assert(sizeof(ElementDequant) == 2, "q4f16gemm kerenl only support 16b operands!"); + + // Data types that are fixed for this kernel + using ElementAccumulator = float; + using ElementComputeEpilogue = ElementAccumulator; + using ElementInputA = ElementDequant; + using ElementOutput = ElementDequant; + + using ElementW = uint8_t; // <- Weight is int4, uint8 for two of them + + // We pack 4 weights into one 16b element, so as to leverage cutlass tile iterators + // for async shared memory loading and minimize bank conflict + using ElementWPack = ElementDequant; + + using ElementQScale = ElementDequant; // <- data type of quantization scale + using ElementQOffset = uint8_t; + + using LayoutInputA = cutlass::layout::RowMajor; + using LayoutInputWPack = cutlass::layout::ColumnMajor; + using LayoutOutput = cutlass::layout::RowMajor; + + // Layout of quantization scale and offset, oriented to be loaded using less instructions + // in a warp tile + using LayoutInputQScale = + typename std::conditional::type; // <- layout of quantization scale + + using ShapeMMAThreadBlock = + typename std::conditional, + cutlass::gemm::GemmShape<128, 256, 64>>::type; + + static constexpr int MinN = QuantBlocking::kColumn > 32 ? 
QuantBlocking::kColumn : 32; + using ShapeMMAWarp = + typename std::conditional, + cutlass::gemm::GemmShape<64, 64, 64>>::type; + + using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 16>; + + // This code section describes how threadblocks are scheduled on GPU + using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? + + // This code section describes the epilogue part of the kernel + using EpilogueOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, // <- data type of output matrix + 128 / cutlass::sizeof_bits::value, // <- the number of elements per vectorized + // memory access. For a byte, it's 16 + // elements. This becomes the vector width of + // math instructions in the epilogue too + ElementAccumulator, // <- data type of accumulator + ElementComputeEpilogue>; // <- data type for alpha/beta in linear combination function + + // Number of pipelines you want to use + static constexpr int NumStages = 3; + + using Gemm = cutlass::gemm::device::QuantBGemm< + ElementInputA, + LayoutInputA, + ElementWPack, + LayoutInputWPack, + ElementQScale, + typename std::conditional::type, + LayoutInputQScale, + QuantBlocking, + ElementOutput, + LayoutOutput, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + ShapeMMAThreadBlock, + ShapeMMAWarp, + ShapeMMAOp, + EpilogueOp, + SwizzleThreadBlock, + NumStages>; + + using Arguments = typename Gemm::Arguments; + + // Invoke gemm kernel (the version with quantization offset) + static cutlass::Status run( + cudaStream_t stream, + const cutlass::gemm::GemmCoord& problem_size_, + cutlass::TensorRef ref_A_, + cutlass::TensorRef ref_B_, + cutlass::TensorRef ref_Qscale_, + cutlass::TensorRef ref_Qoffset_, + cutlass::TensorRef ref_C_, + cutlass::TensorRef ref_D_, + typename EpilogueOp::Params epilogue_ = typename EpilogueOp::Params()) { + if constexpr (!kHasQuantOffset) { + return cutlass::Status::kErrorNotSupported; + } else { + if constexpr 
(ShapeMMAThreadBlock::kM == 16) { + if (problem_size_.m() > 16) { + // For M > 16, the caller should have picked the + // kernel with bigger M + return cutlass::Status::kErrorNotSupported; + } + } + + // Construct Gemm arguments + Arguments args{ + problem_size_, + ref_A_, + ref_B_, + ref_Qscale_, + ref_Qoffset_, + ref_C_, + ref_D_, + epilogue_}; + + Gemm gemm_op; + + // Check if this GEMM can be run or not + cutlass::Status status = gemm_op.can_implement(args); + if (status != cutlass::Status::kSuccess) { + return status; + } + + // Launch the CUTLASS GEMM kernel. + return gemm_op(args, nullptr, stream); + } + } + + // Invoke gemm kernel (the version without quantization offset) + static cutlass::Status run( + cudaStream_t stream, + const cutlass::gemm::GemmCoord& problem_size_, + cutlass::TensorRef ref_A_, + cutlass::TensorRef ref_B_, + cutlass::TensorRef ref_Qscale_, + cutlass::TensorRef ref_C_, + cutlass::TensorRef ref_D_, + typename EpilogueOp::Params epilogue_ = typename EpilogueOp::Params()) { + if constexpr (kHasQuantOffset) { + return cutlass::Status::kErrorNotSupported; + } else { + if constexpr (ShapeMMAThreadBlock::kM == 16) { + if (problem_size_.m() > 16) { + // For M > 16, the caller should have picked the + // kernel with bigger M + return cutlass::Status::kErrorNotSupported; + } + } + + // Construct Gemm arguments + Arguments args{ + problem_size_, + ref_A_, + ref_B_, + ref_Qscale_, + ref_C_, + ref_D_, + epilogue_}; + + Gemm gemm_op; + + // Check if this GEMM can be run or not + cutlass::Status status = gemm_op.can_implement(args); + if (status != cutlass::Status::kSuccess) { + return status; + } + + // Launch the CUTLASS GEMM kernel. 
+ return gemm_op(args, nullptr, stream); + } + } +}; + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/mickey/blk_q4/prepack_sm80.h b/onnxruntime/core/mickey/blk_q4/f16_prepack_sm80.h similarity index 99% rename from onnxruntime/core/mickey/blk_q4/prepack_sm80.h rename to onnxruntime/core/mickey/blk_q4/f16_prepack_sm80.h index e291ab39e8aa3..a08cfb97eed4a 100644 --- a/onnxruntime/core/mickey/blk_q4/prepack_sm80.h +++ b/onnxruntime/core/mickey/blk_q4/f16_prepack_sm80.h @@ -3,7 +3,7 @@ * Licensed under the MIT License. * * Module Name: - * prepack_sm80.h + * blk_q4/f16_prepack_sm80.h * * Abstract: * Prepack weights and quantization parameters (scales and offsets) for diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/device/quantb_gemm.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/device/quantb_gemm.h new file mode 100644 index 0000000000000..38795291b0328 --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/device/quantb_gemm.h @@ -0,0 +1,481 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** + * Modifications Copyright (c) Microsoft. + * Licensed under the MIT license. + * + * @file quantb_gemm.h + * @brief Modified from cutlass/gemm/device/gemm.h, boilerplate code passing input pointers to the kernel. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/gemm/kernel/gemm.h" + +#include "cutlass_ext/q4gemm/kernel/default_quantb_gemm.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" + +#include "cutlass/layout/permute.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/*! A specialized GEMM operator for quantized B GEMM. + + It is modified from cutlass::gemm::device::Gemm. 
Both this class and the original Gemm class + are pretty much boilerplate code that construct the Gemm kernel class, and pass parameters + and controls to it. The only difference is that this class has a few more template parameters + to support quantization. + + This implementation pretty much follows the design of cutlass. But this class seems to be + just a wrapper of the Gemm kernel class. Consider combining them in future iterations. + +*/ +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for quant scales + typename ElementQScale_, + /// Element type for quant offsets + typename ElementQOffset_, + /// Layout type for quant scales and offsets + typename LayoutQMeta_, + /// Blocking dimensions for quantization + typename QuantBlocking_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator_ = ElementC_, + /// Operator class tag + typename OperatorClass_ = arch::OpClassSimt, + /// Tag indicating architecture to tune for + typename ArchTag_ = arch::Sm80, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + 
ElementAccumulator_>::InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_ = + typename threadblock::GemmIdentityThreadblockSwizzle<>, + /// Number of stages used in the pipelined mainloop + int Stages = + DefaultGemmConfiguration::kStages, + /// Access granularity of A matrix in units of elements + int AlignmentA = + DefaultGemmConfiguration::kAlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB = + DefaultGemmConfiguration::kAlignmentB, + /// If true, kernel supports split-K with serial reduction + bool SplitKSerial = false, + /// Operation performed by GEMM + typename Operator_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::Operator, + /// Gather operand A by using an index array + bool GatherA = false, + /// Gather operand B by using an index array + bool GatherB = false, + /// Scatter result D by using an index array + bool ScatterD = false, + /// Permute result D + typename PermuteDLayout = layout::NoPermute> +class QuantBGemm { + public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using TensorRefB = TensorRef; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + static int const 
kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static int const kAlignmentC = EpilogueOutputOp::kCount; + static bool const kSplitKSerial = SplitKSerial; + static ComplexTransform const kTransformA = ComplexTransform::kNone; + static ComplexTransform const kTransformB = ComplexTransform::kNone; + + // Quantization Parameters + static_assert(std::is_same::value, + "LayoutB, i.e. packed weights must appear ColumnMajor."); + static_assert(InstructionShape::kK == 16, + "InstructionShape::kK must be a multiple of 16 (2 tiles), required by 4b weight packing layout."); + using ElementQScale = ElementQScale_; + using ElementQOffset = ElementQOffset_; + using LayoutQMeta = LayoutQMeta_; + using QuantBlocking = QuantBlocking_; + static constexpr bool kHasQOffset = !(std::is_same::value); + + // TODO(chenfucn): consider moving to uint4_t or smaller for QOffset + static_assert(!kHasQOffset || std::is_same::value, "QOffset must be uint8_t"); + + /// Define the kernel + using GemmKernel = typename kernel::DefaultQuantBGemm< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementQScale, + ElementQOffset, + LayoutQMeta, + QuantBlocking, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + kStages, + kSplitKSerial, + Operator, + GatherA, + GatherB, + ScatterD, + PermuteDLayout + >::GemmKernel; + + /// Argument structure + struct Arguments { + // + // Data members + // + + GemmCoord problem_size; + TensorRef ref_A; + TensorRef ref_B; + TensorRef ref_C; + TensorRef ref_D; + TensorRef ref_Qscale; + TensorRef ref_Qoffset; + + typename EpilogueOutputOp::Params epilogue; + + // split-K parallelism (etc.) 
are not yet supported, keeping this for future extension + int split_k_slices{1}; + // For gather+scatter operations + int const *gather_A_indices{nullptr}; + int const *gather_B_indices{nullptr}; + int const *scatter_D_indices{nullptr}; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments(): problem_size(0, 0, 0) {} + + /// Constructs an Arguments structure + CUTLASS_HOST_DEVICE + Arguments( + GemmCoord problem_size_, + TensorRef ref_A_, + TensorRef ref_B_, + TensorRef ref_Qscale_, + TensorRef ref_C_, + TensorRef ref_D_, + typename EpilogueOutputOp::Params epilogue_ = + typename EpilogueOutputOp::Params()): + problem_size(problem_size_), + ref_A(ref_A_), + ref_B(ref_B_), + ref_Qscale(ref_Qscale_), + ref_C(ref_C_), + ref_D(ref_D_), + epilogue(epilogue_) { + assert(!kHasQOffset); + } + + CUTLASS_HOST_DEVICE + Arguments( + GemmCoord problem_size_, + TensorRef ref_A_, + TensorRef ref_B_, + TensorRef ref_Qscale_, + TensorRef ref_Qoffset_, + TensorRef ref_C_, + TensorRef ref_D_, + typename EpilogueOutputOp::Params epilogue_ = + typename EpilogueOutputOp::Params()): + problem_size(problem_size_), + ref_A(ref_A_), + ref_B(ref_B_), + ref_Qscale(ref_Qscale_), + ref_Qoffset(ref_Qoffset_), + ref_C(ref_C_), + ref_D(ref_D_), + epilogue(epilogue_) { + assert(kHasQOffset); + } + }; + + private: + /// Kernel parameters object + typename GemmKernel::Params params_; + + public: + /// Constructs the GEMM. + QuantBGemm() { } + + /// Determines whether the GEMM can execute the given problem. 
+ static Status can_implement(Arguments const &args) { + + if (!kSplitKSerial && args.split_k_slices > 1) { + return Status::kErrorInvalidProblem; + } + + Status status = GemmKernel::can_implement( + args.problem_size, + args.ref_A.non_const_ref(), + args.ref_B.non_const_ref(), + args.ref_Qscale.non_const_ref(), + args.ref_Qoffset.non_const_ref(), + args.ref_C.non_const_ref(), + args.ref_D + ); + + if (status != Status::kSuccess) { + return status; + } + + return Status::kSuccess; + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + size_t bytes = 0; + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.split_k_slices); + + if (kSplitKSerial && args.split_k_slices > 1) { + + bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); + } + + return bytes; + } + + /// Initializes GEMM state from arguments. 
+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.split_k_slices); + + if (kSplitKSerial) { + if (args.split_k_slices > 1) { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + + size_t bytes = get_workspace_size(args); + + cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + } else { + + if (args.split_k_slices > 1) { + return Status::kErrorInvalidProblem; + } + } + + // Initialize the Params structure + params_ = typename GemmKernel::Params{ + args.problem_size, + grid_shape, + args.ref_A.non_const_ref(), + args.ref_B.non_const_ref(), + args.ref_Qscale.non_const_ref(), + args.ref_Qoffset.non_const_ref(), + args.ref_C.non_const_ref(), + args.ref_D, + args.epilogue, + static_cast(workspace), + args.gather_A_indices, + args.gather_B_indices, + args.scatter_D_indices + }; + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + if (kSplitKSerial && args.split_k_slices > 1) { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + } + + params_.ref_A.reset(args.ref_A.non_const_ref().data()); + params_.ref_B.reset(args.ref_B.non_const_ref().data()); + params_.ref_Qscale.reset(args.ref_Qscale.non_const_ref().data()); + params_.ref_Qoffset.reset(args.ref_Qoffset.non_const_ref().data()); + params_.ref_C.reset(args.ref_C.non_const_ref().data()); + params_.ref_D.reset(args.ref_D.data()); + params_.output_op = args.epilogue; + params_.semaphore = static_cast(workspace); + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. 
+ Status run(cudaStream_t stream = nullptr) { + + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(GemmKernel::kThreadCount, 1, 1); + + cudaError_t result; + + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + + if (smem_size >= (48 << 10)) { + result = cudaFuncSetAttribute(Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + std::cerr << "Failed to obtain maximum shared memory size " << smem_size << " for kernel: " + << cudaGetErrorString(result) << "\n"; + return Status::kErrorInternal; + } + } + + cutlass::Kernel<<>>(params_); + + result = cudaGetLastError(); + + return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/default_quantb_gemm.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/default_quantb_gemm.h new file mode 100644 index 0000000000000..2f4460bb59e9f --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/default_quantb_gemm.h @@ -0,0 +1,255 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** + * Modifications Copyright (c) Microsoft. + * Licensed under the MIT license. + * + * @file default_quantb_gemm.h + * @brief Modified from cutlass/gemm/kernel/default_gemm.h. templates for combining + * threadblock-scoped matrix multiply-add with the appropriate + * threadblock-scoped epilogue. 
+ */ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/wmma.h" + +#include "cutlass/epilogue/threadblock/epilogue.h" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass_ext/q4gemm/kernel/quantb_gemm.h" +#include "cutlass/gemm/kernel/gemm_pipelined.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" +#include "cutlass_ext/q4gemm/threadblock/default_quantb_mma.h" +#include "cutlass/gemm/threadblock/default_mma_core_simt.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" + +#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" + +#include "cutlass/layout/permute.h" + +#if defined(CUTLASS_ARCH_WMMA_ENABLED) +#include "cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h" +#endif //CUTLASS_ARCH_WMMA_ENABLED + +//////////////////////////////////////////////////////////////////////////////// +namespace cutlass { +namespace gemm { +namespace kernel { + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for quant scales + typename ElementQScale_, + /// Element type for quant offsets + typename 
ElementQOffset_, + /// Layout type for quant scales and offsets + typename LayoutQMeta_, + /// Blocking dimensions for quantization + typename QuantBlocking_, + /// Access granularity of quant scales in units of elements + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// If true, kernel is configured to support serial reduction in the + /// epilogue + bool SplitKSerial, + /// Operation performed by GEMM + typename Operator, + /// Gather operand A by using an index array + bool GatherA = false, + /// Gather operand B by using an index array + bool GatherB = false, + /// Scatter result D by using an index array + bool ScatterD = false, + /// Permute result D + typename PermuteDLayout = layout::NoPermute, + /// Permute operand A + typename PermuteALayout = layout::NoPermute, + /// Permute operand B + typename PermuteBLayout = layout::NoPermute, + /// + typename Enable = void +> +struct DefaultQuantBGemm; + +//////////////////////////////////////////////////////////////////////////////// + + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for Ampere Architecture +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements 
+ int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of A matrix in units of elements + int kAlignmentB, + /// Element type for quant scales + typename ElementQScale, + /// Element type for quant offsets + typename ElementQOffset, + /// Layout type for quant scales + typename LayoutQMeta, + /// Blocking dimensions for quantization + typename QuantBlocking, + /// Access granularity of quant scales in units of elements + typename ElementC, + /// Layout type for C and D matrix operand + typename LayoutC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// If true, kernel is configured to support serial reduction in the + /// epilogue + bool SplitKSerial, + /// Operation performed by GEMM + typename Operator, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB, + /// Scatter result D by using an index array + bool ScatterD, + /// Permute result D + typename PermuteDLayout, + /// Permute operand A + typename PermuteALayout, + /// Permute operand B + typename PermuteBLayout +> +struct DefaultQuantBGemm { + + static_assert((platform::is_same::value + || platform::is_same>::value), + "Epilogue in the kernel level must be row major"); + + /// Define the threadblock-scoped matrix multiply-accumulate + using Mma = typename cutlass::gemm::threadblock::DefaultQuantBMma< + ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, + 
ElementQScale, ElementQOffset, LayoutQMeta, QuantBlocking, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm80, + ThreadblockShape, WarpShape, InstructionShape, Stages, + Operator, false, GatherA, GatherB, + PermuteALayout, PermuteBLayout>::ThreadblockMma; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + /// Define the epilogue + using RegularEpilogue = + typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, + EpilogueOutputOp::kCount, ScatterD, PermuteDLayout>::Epilogue; + + using Affine2Epilogue = + typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOpAffineRankN< + 2, ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, + EpilogueOutputOp::kCount>::Epilogue; + + using Epilogue = typename platform::conditional::value, + RegularEpilogue, + Affine2Epilogue>::type; + + /// Define the kernel-level GEMM operator. + using GemmKernel = kernel::QuantBGemm; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/quantb_gemm.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/quantb_gemm.h new file mode 100644 index 0000000000000..6e5ad8f406147 --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/quantb_gemm.h @@ -0,0 +1,462 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** + * Modifications Copyright (c) Microsoft. + * Licensed under the MIT license. + * + * @file quantb_gemm.h + * @brief Modified from cutlass/gemm/kernel/gemm.h. + * Template for a pipelined GEMM kernel. Does not compute batching or support split-K. 
+ */ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" +#include "cutlass/arch/arch.h" + +#include "cutlass/util/debug.h" +#include "cutlass/util/device_dump.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_, ///! Threadblock swizzling function + bool SplitKSerial ///! If true, code supporting split-K via serial reduction is enabled. +> +struct QuantBGemm { + + using Mma = Mma_; + using Epilogue = Epilogue_; + using OutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static bool const kSplitKSerial = SplitKSerial; + + static constexpr bool kHasQOffset = Mma::kHasQOffset; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + /// Parameters structure + struct Params { + cutlass::gemm::GemmCoord problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorA::TensorRef ref_A; + typename Mma::IteratorB::Params params_B; + typename Mma::IteratorB::TensorRef ref_B; + typename Mma::IteratorQScale::Params params_QScale; + typename Mma::IteratorQScale::TensorRef ref_QScale; + typename Mma::IteratorQOffset::Params params_QOffset; + typename Mma::IteratorQOffset::TensorRef ref_QOffset; + typename Epilogue::OutputTileIterator::Params params_C; + typename Epilogue::OutputTileIterator::TensorRef ref_C; + typename Epilogue::OutputTileIterator::Params params_D; + typename 
Epilogue::OutputTileIterator::TensorRef ref_D; + typename OutputOp::Params output_op; + int *semaphore; + int gemm_k_size; // how many k vectors are processed by this threadblock + // For gather+scatter operations + int const *gather_A_indices; + int const *gather_B_indices; + int const *scatter_D_indices; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): swizzle_log_tile(0), semaphore(0), gemm_k_size(0) { } + + CUTLASS_HOST_DEVICE + Params( + cutlass::gemm::GemmCoord const & problem_size, + cutlass::gemm::GemmCoord const & grid_tiled_shape, + typename Mma::IteratorA::TensorRef ref_A, + typename Mma::IteratorB::TensorRef ref_B, + typename Mma::IteratorQScale::TensorRef ref_QScale, + typename Mma::IteratorQOffset::TensorRef ref_QOffset, + typename Epilogue::OutputTileIterator::TensorRef ref_C, + typename Epilogue::OutputTileIterator::TensorRef ref_D, + typename OutputOp::Params output_op = typename OutputOp::Params(), + int *workspace = nullptr, + int const *gather_A_indices = nullptr, + int const *gather_B_indices = nullptr, + int const *scatter_D_indices = nullptr + ): + problem_size(problem_size), + grid_tiled_shape(grid_tiled_shape), + swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), + params_A(ref_A.layout()), + ref_A(ref_A), + params_B(ref_B.layout()), + ref_B(ref_B), + params_QScale(ref_QScale.layout()), + ref_QScale(ref_QScale), + params_QOffset(ref_QOffset.layout()), + ref_QOffset(ref_QOffset), + params_C(ref_C.layout()), + ref_C(ref_C), + params_D(ref_D.layout()), + ref_D(ref_D), + output_op(output_op), + gather_A_indices(gather_A_indices), + gather_B_indices(gather_B_indices), + scatter_D_indices(scatter_D_indices) { + int total_gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k(); + + gemm_k_size = gemm_k_iterations * Mma::Shape::kK; + + semaphore = workspace; + } + }; + + /// Shared memory 
storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + QuantBGemm() { } + + /// Determines whether kernel satisfies alignment + CUTLASS_HOST_DEVICE + static Status can_implement( + cutlass::gemm::GemmCoord const & problem_size, + typename Mma::IteratorA::TensorRef ref_A, + typename Mma::IteratorB::TensorRef ref_B, + typename Mma::IteratorQScale::TensorRef ref_QScale, + typename Mma::IteratorQOffset::TensorRef ref_QOffset, + typename Epilogue::OutputTileIterator::TensorRef ref_C, + typename Epilogue::OutputTileIterator::TensorRef ref_D) { + + // TODO check problem_size K, N must be multiple of QuantBlocking + + static int const kAlignmentA = (platform::is_same>::value) + ? 32 + : (platform::is_same>::value) + ? 64 + : Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = (platform::is_same>::value) + ? 32 + : (platform::is_same>::value) + ? 64 + : Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = (platform::is_same>::value) + ? 32 + : (platform::is_same>::value) + ? 64 + : Epilogue::OutputTileIterator::kElementsPerAccess; + + if (!TensorRef_aligned(ref_A, kAlignmentA)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_B, kAlignmentB)) { + return Status::kErrorMisalignedOperand; + } + + if (problem_size.k() % Mma::Shape::kK != 0) { + // Currently we don't support this case due to the way + // predicate iterator works, it loads the partial tile + // in the first iteration and then the full tile in the + // remaining iterations. This will cause the blockwise + // quantization parameters to go out of step with the + // weights. We can fix this by adding a predicate iterator + // that loads the full tile in the first iterations and + // then the partial tile in the last iteration. 
+ return Status::kErrorInvalidProblem; + } + + int qscale_k = problem_size.k() / Mma::QuantBlocking::kRow; + int qscale_n = problem_size.n() / Mma::QuantBlocking::kColumn; + if ((qscale_k == 0) || (qscale_k * Mma::QuantBlocking::kRow != problem_size.k())) { + // partial block not supported + return Status::kErrorInvalidProblem; + } + if ((qscale_n == 0) || (qscale_n * Mma::QuantBlocking::kColumn != problem_size.n())) { + // partial block not supported + return Status::kErrorInvalidProblem; + } + + if (!TensorRef_aligned(ref_QScale, Mma::IteratorQScale::AccessType::kElements)) { + return Status::kErrorMisalignedOperand; + } + + if constexpr(kHasQOffset) { + if (!TensorRef_aligned(ref_QOffset, Mma::IteratorQOffset::AccessType::kElements)) { + return Status::kErrorMisalignedOperand; + } + } + + if (!TensorRef_aligned(ref_C, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_D, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + return Status::kSuccess; + } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + + return; + } + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.k() * params.gemm_k_size, + }; + + cutlass::MatrixCoord tb_offset_B{ + (threadblock_tile_offset.k() * params.gemm_k_size) / 2, + (threadblock_tile_offset.n() * Mma::Shape::kN) / 2 + }; + + // Problem size is a function of threadblock index in the K dimension + int problem_size_k = min( + 
params.problem_size.k(), + (threadblock_tile_offset.k() + 1) * params.gemm_k_size); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.params_A, + params.ref_A.data(), + {params.problem_size.m(), problem_size_k}, + thread_idx, + tb_offset_A, + params.gather_A_indices); + + typename Mma::IteratorB iterator_B( + params.params_B, + params.ref_B.data(), + {problem_size_k/2, params.problem_size.n()/2}, + thread_idx, + tb_offset_B, + params.gather_B_indices); + + const int qscale_k = problem_size_k / Mma::QuantBlocking::kRow; + const int qscale_n = params.problem_size.n() / Mma::QuantBlocking::kColumn; + + // should have been verified by can_implement() + assert((qscale_k > 0) && (qscale_k * Mma::QuantBlocking::kRow == problem_size_k)); + assert((qscale_n > 0) && (qscale_n * Mma::QuantBlocking::kColumn == params.problem_size.n())); + + cutlass::MatrixCoord tb_offset_QScale{ + threadblock_tile_offset.k() * (params.gemm_k_size/Mma::QuantBlocking::kRow), + threadblock_tile_offset.n() * (Mma::Shape::kN/Mma::QuantBlocking::kColumn) + }; + + typename Mma::IteratorQScale iterator_QScale( + params.params_QScale, + params.ref_QScale.data(), + {qscale_k, qscale_n}, + thread_idx, + tb_offset_QScale, + nullptr); + + typename Mma::IteratorQOffset iterator_QOffset( + params.params_QOffset, + params.ref_QOffset.data(), + {qscale_k, qscale_n}, + thread_idx, + tb_offset_QScale); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. 
+ const int warp_idx = canonical_warp_idx(); + const int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + if (!kSplitKSerial || gemm_k_iterations > 0) { + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_QScale, iterator_QOffset, accumulators); + } + + // + // Epilogue + // + + OutputOp output_op(params.output_op); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + //assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.n() * Mma::Shape::kN + ); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + // Construct the semaphore. + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + // If performing a reduction via split-K, fetch the initial synchronization + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + + // Fetch the synchronization lock initially but do not block. + semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is currently updating + output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + } + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C( + params.params_C, + params.ref_C.data(), + params.problem_size.mn(), + thread_idx, + threadblock_offset, + params.scatter_D_indices + ); + + // Tile iterator writing to destination tensor. 
+ typename Epilogue::OutputTileIterator iterator_D( + params.params_D, + params.ref_D.data(), + params.problem_size.mn(), + thread_idx, + threadblock_offset, + params.scatter_D_indices + ); + + Epilogue epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Wait on the semaphore - this latency may have been covered by iterator construction + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. + if (threadblock_tile_offset.k()) { + iterator_C = iterator_D; + } + + semaphore.wait(threadblock_tile_offset.k()); + + } + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, iterator_D, accumulators, iterator_C); + + // + // Release the semaphore + // + + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { + + // The final threadblock resets the semaphore for subsequent grids. + lock = 0; + } + else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_offset.k() + 1; + } + + semaphore.release(lock); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma.h new file mode 100644 index 0000000000000..0af604f090e1f --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma.h @@ -0,0 +1,248 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** + * Modifications Copyright (c) Microsoft. + * Licensed under the MIT license. + * + * @file default_quantb_mma.h + * @brief Modified from cutlass/gemm/threadblock/default_mma.h. + * Defining global memory data layout and iterators, combining with mma core and + * pipelined GEMM kernel. 
+ */ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/wmma.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/permute.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" +#include "cutlass_ext/q4gemm/threadblock/optional_predicated_tile_access_iter.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass_ext/q4gemm/threadblock/default_quantb_mma_core.h" +#include "cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for quant scales + typename ElementQScale_, + /// Element type for quant offsets + typename ElementQOffset_, + /// Layout for quant scales and offsets + typename LayoutQMeta_, + /// Blocking size for quantization + typename QuantBlocking_, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Instruction-level tile size (concept: 
GemmShape) + typename InstructionShape_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation perfomed by GEMM + typename Operator, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. + bool AccumulatorsInRowMajor = false, + /// Gather operand A by using an index array + bool GatherA = false, + /// Gather operand B by using an index array + bool GatherB = false, + /// Permute operand A + typename PermuteALayout = layout::NoPermute, + /// Permute operand B + typename PermuteBLayout = layout::NoPermute + > +struct DefaultQuantBMma; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row-major output (OperatorClass TensorOp) +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for quant scales + typename ElementQScale, + /// Element type for quant offsets + typename ElementQOffset, + /// Layout for quant scales and offsets + typename LayoutQMeta, + /// Blocking size for quantization + typename QuantBlocking, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Layout type for C and D matrix operand + typename LayoutC, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Number of stages used in the multistage mainloop + int Stages, + /// Operation perfomed by GEMM + 
typename Operator, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB, + /// Permute operand A + typename PermuteALayout, + /// Permute operand B + typename PermuteBLayout + > +struct DefaultQuantBMma { + + static_assert(platform::is_same::value + || platform::is_same>::value, + "simt epilogue must be row major"); + + static cutlass::arch::CacheOperation::Kind const CacheOpA = + ((sizeof_bits::value * kAlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * kAlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultQuantBMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementQScale, ElementQOffset, LayoutQMeta, QuantBlocking, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, + Stages, Operator, false, CacheOpA, CacheOpB>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, ThreadMapA, AccessTypeA, GatherA, PermuteALayout>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, ThreadMapB, AccessTypeB, GatherB, PermuteBLayout>; + + // Define iterators over tiles from the quant scales + using ThreadMapQScale = typename MmaCore::IteratorThreadMapQScale; + using AccessTypeQScale = + cutlass::Array; + using 
IteratorQScale = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + typename MmaCore::ThreadblockQShape, + ElementQScale, LayoutQMeta, 0, ThreadMapQScale, AccessTypeQScale>; + + using ThreadMapQOffset = typename MmaCore::IteratorThreadMapQOffset; + using AccessTypeQOffset = + cutlass::Array; + using IteratorQOffset = + cutlass::transform::threadblock::OptionalPredicatedTileAccessIterator< + typename MmaCore::ThreadblockQShape, ElementQOffset, LayoutQMeta, + 0, ThreadMapQOffset, AccessTypeQOffset, MmaCore::kThreads>; + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::QuantBMmaMultistage< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, + MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, + MmaCore::kCacheOpB, IteratorQScale, typename MmaCore::SmemIteratorQScale, + cutlass::arch::CacheOperation::Global, IteratorQOffset, + typename MmaCore::SmemIteratorQOffset, cutlass::arch::CacheOperation::Global, + ElementAccumulator, LayoutC, + typename MmaCore::MmaPolicy, Stages>; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma_core.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma_core.h new file mode 100644 index 0000000000000..ad322f6505200 --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma_core.h @@ -0,0 +1,340 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** + * Modifications Copyright (c) Microsoft. + * Licensed under the MIT license. + * + * @file default_quantb_mma_core.h + * @brief Modified from cutlass/gemm/threadblock/default_mma_core.h. + * Defining data layout in shared memory, and its iterators. 
+ */ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" + +#include "cutlass/layout/tensor_op_multiplicand_sm75.h" +#include "cutlass/layout/tensor_op_multiplicand_sm80.h" + +#include "cutlass/gemm/warp/mma_simt_policy.h" +#include "cutlass/gemm/warp/mma_simt.h" +#include "cutlass_ext/q4gemm/warp/default_quantb_mma_tensor_op.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" + +#include "cutlass/gemm/threadblock/default_multistage_mma_complex_core.h" +#include "cutlass/gemm/threadblock/default_multistage_mma_complex_core_sm80.h" + +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h" +#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h" +#include "cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear.h" +#include "cutlass_ext/q4gemm/threadblock/optional_regular_tile_access_iter.h" + +#include "cutlass/util/debug.h" +#include "cutlass/util/device_dump.h" +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Template defining default matrix multiply operators inferred from threadblock tile size, +/// global memory data layout, and target math instruction. 
+template < + /// Shape of threadblock-scoped matrix multiply operator + typename Shape, + /// Shape of warp-level matrix multiply operator + typename WarpShape, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape, + /// Element data type of A operand + typename ElementA, + /// Layout of operand A + typename LayoutA, + /// Element data type of B operand + typename ElementB, + /// Layout of operand B + typename LayoutB, + /// Element data type of quant scale + typename ElementQScale, + /// Element data type of quant offset + typename ElementQOffset, + /// Layout of quant scale + typename LayoutQMeta, + /// Blocking dimensions for quantization + typename QuantBlocking, + /// Data type of accumulator + typename ElementC, + /// Layout of accumulator + typename LayoutC, + /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp) + typename OperatorClass, + /// Number of stages + int Stages = 2, + /// Operation performed by MMA + typename Operator = typename platform::conditional< + (platform::is_same::value) && + (platform::is_same::value || + platform::is_same::value || + platform::is_same::value || + platform::is_same::value), + cutlass::arch::OpMultiplyAddSaturate, + cutlass::arch::OpMultiplyAdd>::type, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. 
+ bool AccumulatorsInRowMajor = false, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA = + cutlass::arch::CacheOperation::Global, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB = + cutlass::arch::CacheOperation::Global, + /// per-element transformation for elements of A + ComplexTransform TransformA = ComplexTransform::kNone, + /// per-element transformation for elements of B + ComplexTransform TransformB = ComplexTransform::kNone, + bool IsComplex = false // (is_complex::value || is_complex::value) +> +struct DefaultQuantBMmaCore; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization: +/// +/// A: row-major +/// B: column-major +/// Operator: tensor op class +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Data type of A operand + typename ElementA_, + /// Data type of B operand + typename ElementB_, + /// Element data type of quant scale + typename ElementQScale_, + /// Element data type of quant offset + typename ElementQOffset_, + /// Layout of quant scale + typename LayoutQMeta_, + /// Blocking dimensions for quantization + typename QuantBlocking_, + /// Data type of accumulator + typename ElementC_, + /// Layout of accumulator + typename LayoutC_, + /// Number of stages + int Stages, + /// Operation performed by MMA + typename Operator_, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB> +struct DefaultQuantBMmaCore { + using Shape = Shape_; + using WarpShape = WarpShape_; + using 
InstructionShape = InstructionShape_; + using ElementA = ElementA_; + using LayoutA = layout::RowMajor; + using ElementB = ElementB_; + using LayoutB = layout::ColumnMajor; + + using ElementQScale = ElementQScale_; + using ElementQOffset = ElementQOffset_; + using LayoutQMeta = LayoutQMeta_; + using QuantBlocking = QuantBlocking_; + + using ElementC = ElementC_; + using LayoutC = LayoutC_; + static int const kStages = Stages; + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + /// Number of warps present + using WarpCount = GemmShape; + + // Divisility requirements + static_assert( + !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), + "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); + + /// Number of threads per warp + static int const kWarpSize = warp::WarpSize::value; + + /// Number of threads total + static int const kThreads = WarpCount::kCount * kWarpSize; + + /// Size of a threadblock-scoped access + static int const kAccessSizeInBits = 128; + + /// Default Operator + using Operator = Operator_; + + // Warp thread arrangement + static int const kWarpThreadArrangementContiguousA = + Shape::kK / (kAccessSizeInBits / sizeof_bits::value); + + static int const kWarpThreadArrangementStridedA = + kWarpSize / kWarpThreadArrangementContiguousA; + + static int const kWarpThreadArrangementContiguousB = + (Shape::kK / 2) / (kAccessSizeInBits / sizeof_bits::value); + + static int const kWarpThreadArrangementStridedB = + kWarpSize / kWarpThreadArrangementContiguousB; + + // + // Shared memory layouts + // + + using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise< + sizeof_bits::value, Shape::kK>; + + using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise< + sizeof_bits::value, Shape::kK/2>; + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = 
transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to A operand + using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementA, SmemLayoutA, 0, + IteratorThreadMapA>; + + /// ThreadMap of iterator B + using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to B operand + using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementB, SmemLayoutB, 1, + IteratorThreadMapB>; + + using SmemLayoutQScale = LayoutQMeta; + using SmemLayoutQOffset = LayoutQMeta; + + /// Threadblock-level quantization meta data shape + using ThreadblockQShape = MatrixShape; + static_assert(Shape::kK % QuantBlocking::kRow == 0, "K must be multiple of QuantBlocking::kRow"); + static_assert(Shape::kN % QuantBlocking::kColumn == 0, "N must be multiple of QuantBlocking::kColumn"); + static_assert(ThreadblockQShape::kCount > 0, "QuantBlocking too big to fit in a thread block!"); + static_assert(QuantBlocking::kRow == 1 || QuantBlocking::kColumn == 1, + "Only support single column or row quantize blocking!"); + static_assert(QuantBlocking::kColumn != 1 || std::is_same::value, + "Quant scale matrix's major dimension must have more elements, to facilitate fast loading!"); + + /// Threadblock-level quantization meta data shape in pitch-linear layout + using TBQPitchLinearShape = typename std::conditional< + std::is_same::value, + layout::PitchLinearShape, + layout::PitchLinearShape>::type; + + /// By default we would like to use 128b load. However, we can't load more than + /// a column at a time in a column major layout. + static int const kElementsPerAccessQScale = + (kAccessSizeInBits / sizeof_bits::value) > TBQPitchLinearShape::kContiguous + ? 
TBQPitchLinearShape::kContiguous + : (kAccessSizeInBits / sizeof_bits::value); + + /// quant scale is tiny. Not all threads are needed. + static int const kAccessCntQScale = ThreadblockQShape::kCount / kElementsPerAccessQScale; + static int const kThreadsQScale = (kAccessCntQScale > kThreads) ? kThreads : kAccessCntQScale; + + using IteratorThreadMapQScale = transform::PitchLinearStripminedThreadMap< + TBQPitchLinearShape, kThreadsQScale, kElementsPerAccessQScale>; + + using SmemIteratorQScale = transform::threadblock::RegularTileAccessIterator< + ThreadblockQShape, ElementQScale, SmemLayoutQScale, 1, IteratorThreadMapQScale>; + + static int const kElementsPerAccessQOffset = + (kAccessSizeInBits / sizeof_bits::value) > TBQPitchLinearShape::kContiguous + ? TBQPitchLinearShape::kContiguous + : (kAccessSizeInBits / sizeof_bits::value); + static int const kAccessCntQOffset = ThreadblockQShape::kCount / kElementsPerAccessQOffset; + static int const kThreadsQOffset = (kAccessCntQOffset > kThreads) ? 
kThreads : kAccessCntQOffset; + + using IteratorThreadMapQOffset = transform::PitchLinearStripminedThreadMap< + TBQPitchLinearShape, kThreadsQOffset, kElementsPerAccessQOffset>; + + using SmemIteratorQOffset = transform::threadblock::OptionalRegularTileAccessIterator< + ThreadblockQShape, ElementQOffset, SmemLayoutQOffset, 1, IteratorThreadMapQOffset, kThreads>; + + // + // Warp-level matrix multiply operator + // + + // Define the warp-level tensor op + using MmaTensorOp = typename cutlass::gemm::warp::DefaultQuantBMmaTensorOp< + WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, + ElementQScale, SmemLayoutQScale, ElementQOffset, SmemLayoutQScale, QuantBlocking, + ElementC, LayoutC, Operator, WarpCount::kK>::Type; + + /// Policy used to define MmaPipelined + using MmaPolicy = MmaPolicy, + MatrixShape<0, 0>, WarpCount::kK>; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_predicated_tile_access_iter.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_predicated_tile_access_iter.h new file mode 100644 index 0000000000000..6f27a692a3a2e --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_predicated_tile_access_iter.h @@ -0,0 +1,314 @@ +/** + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + * + * @file optional_predicated_tile_access_iter.h + * @brief Templates for loading and storing optional tiles of matrix data. + * This iterator is just a wrapper of PredicatedTileAccessIterator, with + * the option to turn it off at compile time and minimize its runtime + * footprint. Also, it utilize the higher numbered threads in the + * threadblock when the iterator can not utilize all the threads. 
+ */ + +#pragma once + +#include + +#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + + +//////////////////////////////////////////////////////////////////////////////// + +/// Optional 2-D matrix data loader, when element is std::monostate, the +/// iterator becomes no-op with minimal runtime footprint. Also, it utilize the +/// higher numbered threads in the threadblock when the iterator can not utilize +/// all the threads. +/// +template < + /// Tile shape of the iterator + typename Shape_, + /// Element data type of the iterator, no-op when it is std::monostate + typename Element_, + /// Layout of the source matrix + typename Layout_, + int AdvanceRank_, + typename ThreadMap_, + typename AccessType_, + /// Number of threads in the threadblock, when provided, the iterator + /// will utilize the higher numbered threads + int kThreadBlockSize_ = -1> +class OptionalPredicatedTileAccessIterator{ + public: + + using Shape = Shape_; + using Element = Element_; + using Layout = Layout_; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + static constexpr int kAdvanceRank = AdvanceRank_; + static constexpr int kThreadblockSize = kThreadBlockSize_; + + static_assert(!std::is_same::value, + "Disabled Iterator failed to match the specialized version below."); + static_assert(kThreadblockSize == -1 || kThreadblockSize >= ThreadMap::kThreads, + "kThreadblockSize must be no smaller than ThreadMap::kThreads"); + + using Base = PredicatedTileAccessIterator; + + using LongIndex = typename Base::LongIndex; + using Mask = typename Base::Mask; + using TensorCoord = typename Base::TensorCoord; + using TensorRef = typename Base::TensorRef; + using Params = typename Base::Params; + using Pointer = typename Base::Pointer; + + static constexpr int kAccessesPerVector = 
Base::kAccessesPerVector; + + CUTLASS_HOST_DEVICE + static int flip_thread_id(int thread_id){ + if constexpr (kThreadblockSize > 0) { + return kThreadblockSize - 1 - thread_id; + } + return thread_id; + } + + public: + Base base_; + + /// Default constructor + OptionalPredicatedTileAccessIterator(): base_() {}; + + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + OptionalPredicatedTileAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset) + : base_(params, pointer, extent, flip_thread_id(thread_id), threadblock_offset) {} + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + OptionalPredicatedTileAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id) + : OptionalPredicatedTileAccessIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + base_.set_iteration_index(index); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + base_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_DEVICE + void add_tile_offset( + TensorCoord const &tile_offset) { + base_.add_tile_offset(tile_offset); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return base_.get(); + } + + /// Increment and return an instance to self. 
+ CUTLASS_HOST_DEVICE + OptionalPredicatedTileAccessIterator &operator++() { + ++base_; + return *this; + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + OptionalPredicatedTileAccessIterator operator++(int) { + OptionalPredicatedTileAccessIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + base_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + base_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { + base_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { + base_.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return base_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for the disabled version +/// Reduce runtime overhead +/// +template < + /// Tile shape of the iterator + typename Shape_, + typename Layout_, + int AdvanceRank_, + typename ThreadMap_, + typename AccessType_, + int kThreadBlockSize_> +class OptionalPredicatedTileAccessIterator{ + public: + + using Shape = Shape_; + using Element = std::monostate; + using Layout = Layout_; + static int const kAdvanceRank = AdvanceRank_; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + static constexpr int kThreadblockSize = kThreadBlockSize_; + + using Base = PredicatedTileAccessIterator; + + using LongIndex = typename Base::LongIndex; + using Mask = typename Base::Mask; + using TensorCoord = typename Base::TensorCoord; + using TensorRef = typename Base::TensorRef; + using Params = typename Base::Params; + using Pointer = typename Base::Pointer; + + static constexpr int kAccessesPerVector = 
Base::kAccessesPerVector; + + public: + std::monostate base_; + + /// Default constructor + OptionalPredicatedTileAccessIterator(): base_() {}; + + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + OptionalPredicatedTileAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset) + : base_() {} + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + OptionalPredicatedTileAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id) + : base_() {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) {} + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_DEVICE + void add_tile_offset( + TensorCoord const &tile_offset) {} + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return nullptr; + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + OptionalPredicatedTileAccessIterator &operator++() { + return *this; + } + + /// Increment and return an instance to self. 
+ CUTLASS_HOST_DEVICE + OptionalPredicatedTileAccessIterator operator++(int) { + return *this; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) {} + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() {} + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) {} + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) {} + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() const { return false; } +}; + +//////////////////////////////////////////////////////////////////////////////// +} // namespace threadblock +} // namespace transform +} // namespace cutlass diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_regular_tile_access_iter.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_regular_tile_access_iter.h new file mode 100644 index 0000000000000..4b0ae5317f8bb --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_regular_tile_access_iter.h @@ -0,0 +1,224 @@ +/** + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + * + * @file optional_regular_tile_access_iter.h + * @brief Templates implementing the address computation of storing of tiles + * from pitch-linear rank=2 tensors. + * + * This iterator is just a wrapper of RegularTileAccessIterator, with the + * option to turn it off at compile time and minimize its runtime footprint. + * Also, it utilize the higher numbered threads in the threadblock when the + * iterator can not utilize all the threads. + * + * Must be used in conjunction with OptionalPredicatedTileAccessIterator, + * with the same template parameters. 
+ */ + +#pragma once + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/transform/threadblock/regular_tile_access_iterator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Optional 2-D tile iterator, when element is std::monostate, the iterator +/// becomes no-op with minimal runtime footprint. Also, it utilize the higher +/// numbered threads in the threadblock when the iterator can not utilize all +/// the threads. +/// +template < + /// Tile shape of the iterator + typename Shape_, + typename Element_, + typename Layout_, + int AdvanceRank, + typename ThreadMap_, + /// Number of threads in the threadblock, when not -1, the iterator + /// will utilize the higher numbered threads + int ThreadblockSize_ = -1, + int Alignment = + sizeof_bits::value * ThreadMap_::kElementsPerAccess / 8> +class OptionalRegularTileAccessIterator{ + public: + + using Shape = Shape_; + using Element = Element_; + using Layout = Layout_; + using ThreadMap = ThreadMap_; + static constexpr int kAlignment = Alignment; + static constexpr int kThreadblockSize = ThreadblockSize_; + + static_assert(!std::is_same::value, + "Disabled Iterator failed to match the specialized template"); + static_assert(kThreadblockSize == -1 || kThreadblockSize >= ThreadMap::kThreads, + "kThreadblockSize must be no smaller than ThreadMap::kThreads"); + + using Base = RegularTileAccessIterator; + + using LongIndex = typename Base::LongIndex; + using TensorRef = typename Base::TensorRef; + using TensorCoord = typename Base::TensorCoord; + using AccessType = typename Base::AccessType; + + CUTLASS_HOST_DEVICE + static int flip_thread_id(int thread_id){ + if constexpr (kThreadblockSize > 0) { + return kThreadblockSize - 1 - thread_id; + } + return thread_id; + } + + private: + 
+ Base base_; + + public: + /// Construct with zero threadblock offset; thread_id is remapped via flip_thread_id() + CUTLASS_HOST_DEVICE + OptionalRegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : base_(ref, flip_thread_id(thread_id)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + base_.set_iteration_index(index); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + base_.add_pointer_offset(pointer_offset); + } + + /// Returns a pointer + CUTLASS_DEVICE + AccessType *get() const { + return base_.get(); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + OptionalRegularTileAccessIterator &operator++() { + ++base_; + return *this; + } + + /// Post-increment: advances to the next tile, returns a copy of the prior state. + CUTLASS_HOST_DEVICE + OptionalRegularTileAccessIterator operator++(int) { + OptionalRegularTileAccessIterator prev(*this); + this->operator++(); + + return prev; + } + + /// Adds a tile offset in the unit of tile. + /// In GEMM/Conv implementation, this is used to move in the k dimension in the shared memory. + /// Below layouts are the shared memory layouts. Current SM50 SIMT kernels only use col major A and row major B. + /// For row major A operand, k dimension is contiguous dimension; + /// For col major A operand, k dimension is strided dimension; + /// For row major B operand, k dimension is strided dimension; + /// For col major B operand, k dimension is contiguous dimension. + /// Below two classes map col/row major to the pitch linear coordinates used + /// in this base class.
+ CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + base_.add_tile_offset(coord); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization when Element is std::monostate, the iterator becomes no-op +/// +template < + typename Shape_, + typename Layout_, + int AdvanceRank, + typename ThreadMap_, + int ThreadblockSize_, + int Alignment> +class OptionalRegularTileAccessIterator{ + public: + + using Shape = Shape_; + using Element = std::monostate; + using Layout = Layout_; + using ThreadMap = ThreadMap_; + static constexpr int kAlignment = Alignment; + static constexpr int kThreadblockSize = ThreadblockSize_; + + using Base = RegularTileAccessIterator; + + using LongIndex = typename Base::LongIndex; + using TensorRef = typename Base::TensorRef; + using TensorCoord = typename Base::TensorCoord; + using AccessType = typename Base::AccessType; + + private: + + std::monostate base_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + OptionalRegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : base_() {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) {} + + /// Returns a pointer + CUTLASS_DEVICE + AccessType *get() const { + return nullptr; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + OptionalRegularTileAccessIterator &operator++() { + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + OptionalRegularTileAccessIterator operator++(int) { + return *this; + } + + /// Adds a tile offset in the unit of tile. + /// In GEMM/Conv implementation, this is used to move in the k dimension in the shared memory. 
+ /// Below layouts are the shared memory layouts. Current SM50 SIMT kernels only use col major A and row major B. + /// For row major A operand, k dimension is contiguous dimension; + /// For col major A operand, k dimension is strided dimension; + /// For row major B operand, k dimension is strided dimension; + /// For col major B operand, k dimension is contiguous dimension. + /// Below two classes map col/row major to the pitch linear coordinates used + /// in this base class. + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) {} +}; + +} // namespace threadblock +} // namespace transform +} // namespace cutlass diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h new file mode 100644 index 0000000000000..8b6bac8c5099a --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h @@ -0,0 +1,1290 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** + * Modifications Copyright (c) Microsoft. + * Licensed under the MIT license. + * + * @file quantb_mma_multistage.h + * @brief Modified from cutlass/gemm/threadblock/mma_multistage.h. + * Added the quantized data memory pipeline, dequantization, and feeding + * to tensor cores. Mainloop pipeline is heavily modified. 
+ */ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/threadblock/mma_base.h" + +#include "cutlass/util/debug.h" +#include "cutlass/util/device_dump.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// +namespace{ + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Utilities for printing layout for the prepacked weights and quantization parameters +/// +template< + /// Data type of the prepacked weights + typename ElementWeight, + /// Data type of the quant scales + typename ElementQScale, + /// Data type of the quant offsets + typename ElementQOffset> +struct QuantBLayoutDebug{ + static constexpr bool debug_smem = true; + static constexpr bool debug_fragment = true; + ElementWeight* smem_b_ptr_; + ElementQScale* smem_qscale_ptr_; + ElementQOffset* smem_qoffset_ptr_; + int warp_id_; + int lane_id_; + int block_id_; + + template + CUTLASS_DEVICE + static void print_fragment(cutlass::Array const& frag, char label, int block_id, int warp_id, int lane_id){ + static_assert(Size % 4 == 0, "Size must be multiple of 4"); + if constexpr (debug_fragment){ + if (block_id == 1 && warp_id == 0){ + const Element* ptr = reinterpret_cast(&frag); + for (int i = 0; i < Size/4; i++, ptr+=4){ + if constexpr(std::is_integral::value){ + printf("T%.2d%c%d, %3d, %3d, %3d, %3d\n", + threadIdx.x, label, i, + ptr[0], ptr[1], ptr[2], ptr[3]); + } else { + printf("T%.2d%c%d, %.3f, %.3f, %.3f, %.3f\n", + threadIdx.x, label, i, + float(ptr[0]), float(ptr[1]), float(ptr[2]), float(ptr[3])); + } + } + } + } + } + + template + CUTLASS_DEVICE + static void print_as_int4(cutlass::Array const& frag, char label, int block_id, int warp_id, int lane_id){ + constexpr int I8Size = 
Size * cutlass::sizeof_bits::value / 8; + static_assert(I8Size % 2 == 0, "Size must be multiple of 4"); + if constexpr (debug_fragment){ + if (block_id == 1 && warp_id == 0){ + const uint8_t* ptr = reinterpret_cast(&frag); + for (int i = 0; i < I8Size/2; i++, ptr+=2){ + printf("T%.2dW%d, %d, %d, %d, %d\n", threadIdx.x, i, ptr[0] & 0x0f, ptr[0] >> 4, ptr[1] & 0x0f, ptr[1] >> 4); + } + } + } + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Dummy type when quant offset is not used, to avoid compilation error, +/// and reduce runtime footprint +/// +struct DummyType{ + std::monostate dummy_; + public: + DummyType() = default; + + CUTLASS_HOST_DEVICE + void* data() const { + return nullptr; + } + + CUTLASS_HOST_DEVICE + std::monostate& operator[](int idx) { + return dummy_; + } +}; + +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class QuantBMmaBase { + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + ///< Policy describing tuning details + using Policy = Policy_; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. 
+ using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = GemmShape; + + /// Number of warp-level GEMM oeprations + static int const kWarpGemmIterations = + (WarpGemm::kK / Operator::Policy::MmaShape::kK); + + /// Number of stages + static int const kStages = Stages; + + static constexpr bool kHasQOffset = !std::is_same::value; + + /// Tensor reference to the A operand + using TensorRefA = TensorRef; + + /// Tensor reference to the prepacked weights + using TensorRefB = TensorRef; + + static_assert(kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + static_assert((kWarpGemmIterations % 2) == 0, + "Inner loop iteration must be an even number."); + + // Tensor reference to the quantization scales + using TensorRefQScale = TensorRef; + using TensorRefQOffset = TensorRef; + + // Block size of the quantization (one set of quantization parameters per block of weights) + using QuantBlocking = typename Operator::QuantBlocking; + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + class SharedStorage { + public: + // + // Type definitions + // + + /// Shape of the A matrix operand in shared memory + using ShapeA = MatrixShape; + + /// Shape of the prepacked weights in shared memory + using ShapeB = + MatrixShape; + + /// Shape of the quantization parameter matrix in shared memory + /// Validation done in mma core class ThreadblockQShape + using ShapeQScale = + MatrixShape<(Shape::kK / QuantBlocking::kRow) * kStages, + Shape::kN / QuantBlocking::kColumn>; + + using BufTypeQOffset = std::conditional_t, + DummyType>; + public: + // + // Data members + // + + /// Buffer for A operand + AlignedBuffer operand_A; + + /// Buffer for prepacked weights + AlignedBuffer operand_B; + + /// Buffer for quantization scales + AlignedBuffer operand_QScale; + + /// Buffer for quantization offsets + BufTypeQOffset 
operand_QOffset; + + public: + + // + // Methods + // + + /// Returns a layout object for the A matrix + CUTLASS_DEVICE + static typename Operator::LayoutA LayoutA() { + return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); + } + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator::LayoutB LayoutB() { + return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); + } + + CUTLASS_HOST_DEVICE + static typename Operator::SmemLayoutQScale LayoutQMeta() { + return Operator::SmemLayoutQScale::packed({ShapeQScale::kRow, ShapeQScale::kColumn}); + } + + CUTLASS_HOST_DEVICE + static typename Operator::SmemLayoutQOffset LayoutQOffset() { + return Operator::SmemLayoutQOffset::packed({ShapeQScale::kRow, ShapeQScale::kColumn}); + } + + /// Returns a TensorRef to the A operand + CUTLASS_HOST_DEVICE + TensorRefA operand_A_ref() { + return TensorRefA{operand_A.data(), LayoutA()}; + } + + /// Returns a TensorRef to the prepacked weights + CUTLASS_HOST_DEVICE + TensorRefB operand_B_ref() { + return TensorRefB{operand_B.data(), LayoutB()}; + } + + /// Returns a TensorRef to the quantization scales + CUTLASS_HOST_DEVICE + TensorRefQScale operand_QScale_ref() { + return TensorRefQScale{operand_QScale.data(), LayoutQMeta()}; + } + + CUTLASS_HOST_DEVICE + TensorRefQOffset operand_QOffset_ref() { + if constexpr (!kHasQOffset){ + return TensorRefQOffset(); + } else { + return TensorRefQOffset{operand_QOffset.data(), LayoutQOffset()}; + } + } + }; + + protected: + + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A operand from shared memory + typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB warp_tile_iterator_B_; + + /// Iterator to load a warp-scoped tile of quant scales from shared memory + typename Operator::IteratorQMeta warp_tile_iterator_QScale_; + +public: + + /// Construct from tensor 
references + CUTLASS_DEVICE + QuantBMmaBase( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx + ): + warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), + warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx), + warp_tile_iterator_QScale_(shared_storage.operand_QScale_ref(), + shared_storage.operand_QOffset_ref(), lane_idx) + {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Iterators over tiles of quant scales in global memory + typename IteratorQScale_, + /// Iterators over tiles of quant scales in shared memory + typename SmemIteratorQScale_, + /// Cache operation for quant scales + cutlass::arch::CacheOperation::Kind CacheOpQScale, + /// Iterators over tiles of quant scales in global memory + typename 
IteratorQOffset_, + /// Iterators over tiles of quant scales in shared memory + typename SmemIteratorQOffset_, + /// Cache operation for quant scales + cutlass::arch::CacheOperation::Kind CacheOpQOffset, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class QuantBMmaMultistage : + public QuantBMmaBase { +public: + ///< Base class + using Base = QuantBMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + ///< Policy describing tuning details + using Policy = Policy_; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + using IteratorQScale = IteratorQScale_; + using IteratorQOffset = IteratorQOffset_; + using SmemIteratorQScale = SmemIteratorQScale_; + using SmemIteratorQOffset = SmemIteratorQOffset_; + using QuantBlocking = typename Base::QuantBlocking; + + static cutlass::arch::CacheOperation::Kind const kCacheOpQScale = CacheOpQScale; + static cutlass::arch::CacheOperation::Kind const kCacheOpQOffset = CacheOpQOffset; + static constexpr bool kHasQOffset = Base::kHasQOffset; + + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Minimum 
architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + /// Internal structure exposed for introspection. + struct Detail { + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = + IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of packed weights + static int const AsyncCopyIterationsPerStageB = + IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load one group of operand A + static int const kAccessesPerGroupA = + (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load one group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + static int const AsyncCopyIterationsPerStageQScale = + IteratorQScale::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one group of quant scale + static int const kAccessesPerGroupQScale = + (AsyncCopyIterationsPerStageQScale + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + static int const AsyncCopyIterationsPerStageQOffset = + IteratorQOffset::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one group of quant offset + static int const kAccessesPerGroupQOffset = + (AsyncCopyIterationsPerStageQOffset + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + // Optional staged-accumulation (e.g., tf32x3 kernels) for improved numerical + // accuracy, where each mainloop iteration first accumulates into a temporary + // set of
freshly-cleared accumulators, which are subsequently added to the + // final accumulator set. + static bool const kStagedAccumulation = arch::UseStagedAccumulation::value; + }; + + private: + + + // Structure encapsulating pipeline state live from one iteration to the next + struct PipeState { + + using WarpLoadedFragmentA = typename Operator::FragmentA; + using WarpLoadedFragmentB = typename Operator::FragmentB; + using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; + using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; + + /// Temporary accumulator to facilitate staged-accumulation + FragmentC tmp_accum_; + + /// Pair of A fragments used to overlap shared memory loads and math instructions + WarpLoadedFragmentA warp_loaded_frag_A_[2]; + + /// Pair of B fragments used to overlap shared memory loads and math instructions + WarpLoadedFragmentB warp_loaded_frag_B_; + WarpTransformedFragmentB warp_transformed_frag_B_[2]; + + using WarpLoadedFragmentQScale = typename Operator::FragmentQScale; + WarpLoadedFragmentQScale warp_loaded_frag_QScale_; + + using WarpLoadedFragmentQOffset = typename std::conditional::type; + WarpLoadedFragmentQOffset warp_loaded_frag_QOffset_; + }; + + + private: + + // + // Data members + // + + /// Warp-level MMA operator + Operator warp_mma_; + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to write threadblock-scoped tile of quant meta data to shared memory + SmemIteratorQScale smem_iterator_QScale_; + SmemIteratorQOffset smem_iterator_QOffset_; + + /// Shared memory write stage index + int smem_write_stage_idx_; + + /// Shared memory read stage index + int smem_read_stage_idx_; + + /// very small meta data tensor require less threads to load + bool const should_load_qscale_; + bool const 
should_load_qoffset_; + + /// Shared memory pointers for debug dumping + static constexpr bool debug_layout = false; + using LayoutDebugType = typename std::conditional, + std::monostate>::type; + LayoutDebugType layout_debug_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + QuantBMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx + ): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), + smem_iterator_QScale_(shared_storage.operand_QScale_ref(), thread_idx), + smem_iterator_QOffset_(shared_storage.operand_QOffset_ref(), thread_idx), + should_load_qscale_(thread_idx < IteratorQScale::ThreadMap::kThreads), + should_load_qoffset_(thread_idx >= IteratorQOffset::kThreadblockSize - IteratorQOffset::ThreadMap::kThreads), + smem_write_stage_idx_(0), + smem_read_stage_idx_(0) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + if constexpr(debug_layout){ + layout_debug_.smem_b_ptr_ = shared_storage.operand_B_ref().data(); + layout_debug_.smem_qscale_ptr_ = shared_storage.operand_QScale_ref().data(); + if constexpr(kHasQOffset){ + layout_debug_.smem_qoffset_ptr_ = shared_storage.operand_QOffset_ref().data(); + } else { + layout_debug_.smem_qoffset_ptr_ = nullptr; + } + layout_debug_.warp_id_ = warp_idx; + layout_debug_.lane_id_ = lane_idx; + layout_debug_.block_id_ = blockIdx.x + blockIdx.y * gridDim.x + gridDim.x * gridDim.y * 
blockIdx.z; + } + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + this->warp_tile_iterator_QScale_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + + /// Advance shared memory read-iterators to the next stage + CUTLASS_DEVICE + void advance_smem_read_stage() + { + ++smem_read_stage_idx_; + + if (smem_read_stage_idx_ == Base::kStages) { + // Wrap back around to the 'start' of the circular buffer in shared memory + this->warp_tile_iterator_A_.add_tile_offset({0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0}); + this->warp_tile_iterator_QScale_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0}); + + smem_read_stage_idx_ = 0; + } + } + + /// Advance global memory read-iterators and shared memory write-iterators to the stage + CUTLASS_DEVICE + void advance_smem_write_stage( + IteratorA &iterator_A, + IteratorB &iterator_B, + IteratorQScale &iterator_QScale, + IteratorQOffset &iterator_QOffset) + { + // Advance global iterators + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + iterator_QScale.add_tile_offset({1, 0}); + + // Advance shared iterators + smem_iterator_A_.add_tile_offset({0, 1}); + smem_iterator_B_.add_tile_offset({1, 0}); + smem_iterator_QScale_.add_tile_offset({1, 0}); + + if constexpr (kHasQOffset) { + iterator_QOffset.add_tile_offset({1, 0}); + 
smem_iterator_QOffset_.add_tile_offset({1, 0}); + } + + // Increment shared memory write stage index + ++smem_write_stage_idx_; + + if (smem_write_stage_idx_ == Base::kStages) { + // Wrap back around to the 'start' of the circular buffer in shared memory + smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + smem_iterator_QScale_.add_tile_offset({-Base::kStages, 0}); + if constexpr (kHasQOffset) { + smem_iterator_QOffset_.add_tile_offset({-Base::kStages, 0}); + } + smem_write_stage_idx_ = 0; + } + } + + CUTLASS_DEVICE + void copy_qscale_tiles(IteratorQScale &iterator_QScale){ + // Quant scale matrix is 1/block_size of the B matrix, for a 64x64 warp tile, + // it's only 64x64/block_size elements. For blocking size 16 ~ 64, it only + // takes 4 ~ 16 cp.async instructions to load. One warp has 32 threads, so + // it should be loaded in less than one cp.async instruction per thread. + // Even less for quant offset matrix. + static_assert(Detail::AsyncCopyIterationsPerStageQScale == 1, + "Quant scale should be loaded in one shot!"); + static_assert(IteratorQScale::kAccessesPerVector == 1, + "Quant scale should 1 access per vector!"); + + // Async Copy for quantization scale + typename IteratorQScale::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_QScale_.get()); + + constexpr int kSrcBytes = + sizeof_bits::value * + IteratorQScale::ThreadMap::kElementsPerAccess / 8; + + cutlass::arch::cp_async( + dst_ptr, iterator_QScale.get(), iterator_QScale.valid()); + } + + CUTLASS_DEVICE + void copy_qoffset_tiles(IteratorQOffset & iterator_QOffset) { + static_assert(Detail::AsyncCopyIterationsPerStageQOffset == 1, + "Quant offset should be loaded in one shot!"); + static_assert(IteratorQOffset::kAccessesPerVector == 1, + "Quant offset should 1 access per vector!"); + + if constexpr(kHasQOffset) { + // Async Copy for quantization offset + typename IteratorQOffset::AccessType *dst_ptr = + reinterpret_cast( 
+ this->smem_iterator_QOffset_.get()); + + constexpr int kSrcBytes = sizeof_bits::value * + IteratorQOffset::ThreadMap::kElementsPerAccess / 8; + + cutlass::arch::cp_async( + dst_ptr, iterator_QOffset.get(), iterator_QOffset.valid()); + } + } + + CUTLASS_DEVICE + void copy_tiles_and_advance(IteratorA &iterator_A, IteratorB &iterator_B, + int group_start = 0) { + auto group_start_A = group_start * Detail::kAccessesPerGroupA; + iterator_A.set_iteration_index(group_start_A * + IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A.get(); + + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + + auto group_start_B = group_start * Detail::kAccessesPerGroupB; + iterator_B.set_iteration_index(group_start_B * + IteratorB::kAccessesPerVector); + this->smem_iterator_B_.set_iteration_index(group_start_B); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B.get(); + + 
cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + + ++iterator_B; + } + ++this->smem_iterator_B_; + } + } + } + + /// GEMM prologue. Bootstrap the global->shared memory pipeline by fetching + /// the global fragments needed by the first kStages-1 threadblock mainloop iterations + CUTLASS_DEVICE + void prologue( + IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory + IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory + IteratorQScale &iterator_QScale, ///< [in|out] iterator over quant scales in global memory + IteratorQOffset &iterator_QOffset, ///< [in|out] iterator over quant offsets in global memory + int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining + { + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) { + + // Disable global fetching if done with global fetch iterations + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_QScale.clear_mask(gemm_k_iterations == 0 || !should_load_qscale_); + + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + int src_bytes = (iterator_A.valid() ? 
kSrcBytes : 0); + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + + iterator_B.set_iteration_index(0); + this->smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); + + ++iterator_B; + } + + ++this->smem_iterator_B_; + } + + // Async Copy for quantization scale + static_assert(Detail::AsyncCopyIterationsPerStageQScale == 1, "Quant scale should be loaded in one shot!"); + static_assert(IteratorQScale::kAccessesPerVector == 1, "Quant scale should 1 access per vector!"); + + typename IteratorQScale::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_QScale_.get()); + + constexpr int kSrcBytes = + sizeof_bits::value * + IteratorQScale::ThreadMap::kElementsPerAccess / 8; + + auto gmem_ptr = iterator_QScale.get(); + + cutlass::arch::cp_async( + dst_ptr, gmem_ptr, iterator_QScale.valid()); + + if constexpr (kHasQOffset) { + iterator_QOffset.clear_mask(gemm_k_iterations == 0 || !should_load_qoffset_); + + // Async Copy for quantization offset + static_assert(Detail::AsyncCopyIterationsPerStageQOffset == 1, "Quant offset should be loaded in one shot!"); + static_assert(IteratorQOffset::kAccessesPerVector == 1, "Quant offset should 1 access per vector!"); + typename IteratorQOffset::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_QOffset_.get()); + + constexpr int kSrcBytes = + sizeof_bits::value * + IteratorQOffset::ThreadMap::kElementsPerAccess / 
8; + + cutlass::arch::cp_async( + dst_ptr, iterator_QOffset.get(), iterator_QOffset.valid()); + } + + // Move to the next write stage + advance_smem_write_stage(iterator_A, iterator_B, iterator_QScale, iterator_QOffset); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + } + + + /// Wait until we have at least one completed global fetch stage + CUTLASS_DEVICE + void gmem_wait() + { + // Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - #committed) + cutlass::arch::cp_async_wait(); + __syncthreads(); + + if constexpr(debug_layout) { + if (LayoutDebugType::debug_smem && layout_debug_.block_id_ == 1) { + if (threadIdx.x == 0){ + printf("stage: %d\n", smem_write_stage_idx_); + } + cutlass::debug::dump_shmem(layout_debug_.smem_qscale_ptr_, Base::SharedStorage::ShapeQScale::kCount); + if constexpr(kHasQOffset){ + cutlass::debug::dump_shmem(layout_debug_.smem_qoffset_ptr_, Base::SharedStorage::ShapeQScale::kCount); + } + } + } + } + + /// Perform a threadblock mainloop iteration of matrix multiply-accumulate + CUTLASS_DEVICE + void mac_loop_iter( + PipeState &pipe_state, ///< [in|out] loop-carried pipeline state + FragmentC &accum, ///< [in|out] destination accumulator tile + IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory + IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory + IteratorQScale &iterator_QScale, ///< [in|out] iterator over quant scales in global memory + IteratorQOffset &iterator_QOffset, ///< [in|out] iterator over quant offsets in global memory + int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining + { + // Unroll the warp-level MMA tiles of a threadblock's mainloop iteration + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { + // Loading next warp-level tiles from shared memory. 
This can be skipped on the very + // last iteration where: + // (gemm_k_iterations == (1 - Base::kStages)) && (warp_mma_k == (Base::kWarpGemmIterations - 1)) + // However, evaluating this condition seems more expensive than simply loading the tiles + this->warp_tile_iterator_QScale_.load( + pipe_state.warp_loaded_frag_QScale_, + pipe_state.warp_loaded_frag_QOffset_); + ++this->warp_tile_iterator_QScale_; + + this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_); + ++this->warp_tile_iterator_B_; + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; + + // All warp-tiles issue their share of global->shared fragment copies + copy_tiles_and_advance( + iterator_A, + iterator_B, + (warp_mma_k + 1) % Base::kWarpGemmIterations); + + if constexpr(debug_layout) { + if (LayoutDebugType::debug_fragment && layout_debug_.block_id_ == 1 && layout_debug_.warp_id_ == 0 && layout_debug_.lane_id_ == 0){ + printf("LINE %d, warp_tile_B kgroup %d\n", __LINE__, warp_mma_k % Base::kWarpGemmIterations); + } + LayoutDebugType::print_as_int4(pipe_state.warp_loaded_frag_B_, 'W', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + LayoutDebugType::print_fragment(Operator::IteratorQScale::debug_expand(pipe_state.warp_loaded_frag_QScale_), 'Q', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + if constexpr(kHasQOffset){ + LayoutDebugType::print_fragment(Operator::IteratorQScale::debug_expand(pipe_state.warp_loaded_frag_QOffset_), 'O', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + } + } + + warp_mma_.transform( + pipe_state.warp_transformed_frag_B_[(warp_mma_k + 1) % 2], + pipe_state.warp_loaded_frag_B_, + pipe_state.warp_loaded_frag_QScale_, + 
pipe_state.warp_loaded_frag_QOffset_); + + if constexpr(debug_layout) { + LayoutDebugType::print_fragment(pipe_state.warp_transformed_frag_B_[(warp_mma_k + 1) % 2], 'B', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + } + + // Execute the current warp-tile of MMA operations + if (Detail::kStagedAccumulation) { + warp_mma_( + pipe_state.tmp_accum_, + pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], + pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], + pipe_state.tmp_accum_ + ); + + if (warp_mma_k == 0) { + plus plus_accum; + accum = plus_accum(accum, pipe_state.tmp_accum_); + pipe_state.tmp_accum_.clear(); + } + } else { + warp_mma_( + accum, + pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], + pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], + accum + ); + } + + if (warp_mma_k == 0) { + copy_qscale_tiles(iterator_QScale); + } + if (warp_mma_k == 1) { + copy_qoffset_tiles(iterator_QOffset); + } + + // The second-to-last warp-tile also moves to the next global fetch stage + if (warp_mma_k == Base::kWarpGemmIterations - 2) { + // Inserts a memory fence between stages of cp.async instructions. 
+ cutlass::arch::cp_async_fence(); + + // Move to the next global fetch stage + advance_smem_write_stage(iterator_A, iterator_B, iterator_QScale, iterator_QOffset); + advance_smem_read_stage(); + + // Disable global fetching when done with global fetch iterations + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_QScale.clear_mask(gemm_k_iterations == 0 || !should_load_qscale_); + if constexpr(kHasQOffset){ + iterator_QOffset.clear_mask(gemm_k_iterations == 0 || !should_load_qoffset_); + } + + // Wait until we have at least one completed global fetch stage + gmem_wait(); + } + + } + } + + /// Specialized mainloop iteration of matrix multiply-accumulate, for small M + CUTLASS_DEVICE + void mac_loop_iter_small_m( + PipeState &pipe_state, ///< [in|out] loop-carried pipeline state + FragmentC &accum, ///< [in|out] destination accumulator tile + IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory + IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory + IteratorQScale &iterator_QScale, ///< [in|out] iterator over quant scales in global memory + IteratorQOffset &iterator_QOffset, ///< [in|out] iterator over quant offsets in global memory + int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining + { + // Unroll the warp-level MMA tiles of a threadblock's mainloop iteration + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { + // In the case of small M, memory latency dominates. We try to move uses far + // from their definitions to hide latency. 
+ if constexpr(debug_layout) { + if (LayoutDebugType::debug_fragment && layout_debug_.block_id_ == 1 && layout_debug_.warp_id_ == 0 && layout_debug_.lane_id_ == 0){ + printf("LINE %d, warp_tile_B kgroup %d\n", __LINE__, warp_mma_k % Base::kWarpGemmIterations); + } + LayoutDebugType::print_as_int4(pipe_state.warp_loaded_frag_B_, 'W', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + LayoutDebugType::print_fragment(Operator::IteratorQScale::debug_expand(pipe_state.warp_loaded_frag_QScale_), 'Q', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + if constexpr(kHasQOffset){ + LayoutDebugType::print_fragment(Operator::IteratorQScale::debug_expand(pipe_state.warp_loaded_frag_QOffset_), 'O', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + } + } + + warp_mma_.transform( + pipe_state.warp_transformed_frag_B_[(warp_mma_k) % 2], + pipe_state.warp_loaded_frag_B_, + pipe_state.warp_loaded_frag_QScale_, + pipe_state.warp_loaded_frag_QOffset_); + + if constexpr(debug_layout) { + LayoutDebugType::print_fragment(pipe_state.warp_transformed_frag_B_[(warp_mma_k) % 2], 'B', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + } + + // Loading next warp-level tiles from shared memory. 
+ this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_); + ++this->warp_tile_iterator_B_; + + this->warp_tile_iterator_QScale_.load( + pipe_state.warp_loaded_frag_QScale_, + pipe_state.warp_loaded_frag_QOffset_); + ++this->warp_tile_iterator_QScale_; + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; + + // All warp-tiles issue their share of global->shared fragment copies + copy_tiles_and_advance( + iterator_A, + iterator_B, + (warp_mma_k + 1) % Base::kWarpGemmIterations); + + // Execute the current warp-tile of MMA operations + if (Detail::kStagedAccumulation) { + warp_mma_( + pipe_state.tmp_accum_, + pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], + pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], + pipe_state.tmp_accum_ + ); + + if (warp_mma_k == 0) { + plus plus_accum; + accum = plus_accum(accum, pipe_state.tmp_accum_); + pipe_state.tmp_accum_.clear(); + } + } else { + warp_mma_( + accum, + pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], + pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], + accum + ); + } + + // The second-to-last warp-tile also moves to the next global fetch stage + if (warp_mma_k == Base::kWarpGemmIterations - 2) { + // Inserts a memory fence between stages of cp.async instructions. 
+ cutlass::arch::cp_async_fence(); + + // Move to the next global fetch stage + advance_smem_write_stage(iterator_A, iterator_B, iterator_QScale, iterator_QOffset); + advance_smem_read_stage(); + + // Disable global fetching when done with global fetch iterations + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_QScale.clear_mask(gemm_k_iterations == 0 || !should_load_qscale_); + if constexpr(kHasQOffset){ + iterator_QOffset.clear_mask(gemm_k_iterations == 0 || !should_load_qoffset_); + } + + copy_qscale_tiles(iterator_QScale); + copy_qoffset_tiles(iterator_QOffset); + + // Wait until we have at least one completed global fetch stage + gmem_wait(); + } + + } + } + + + /// Perform the specified number of threadblock mainloop iterations of matrix + /// multiply-accumulate. Assumes prologue has been initiated. + CUTLASS_DEVICE + void gemm_iters( + int gemm_k_iterations, ///< number of threadblock mainloop iterations + FragmentC &accum, ///< [in|out] accumulator tile + IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory + IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory + IteratorQScale &iterator_QScale, ///< [in|out] iterator over QScale operand in global memory + IteratorQOffset &iterator_QOffset) ///< [in|out] iterator over QOffset operand in global memory + { + PipeState pipe_state; + + // Disable global fetching if done with global fetch iterations + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_QScale.clear_mask(gemm_k_iterations == 0 || !should_load_qscale_); + if constexpr(kHasQOffset) { + iterator_QOffset.clear_mask(gemm_k_iterations == 0 || !should_load_qoffset_); + } + + // Load first warp-tile's B fragment from shared memory + this->warp_tile_iterator_QScale_.load( + pipe_state.warp_loaded_frag_QScale_, + pipe_state.warp_loaded_frag_QOffset_); + 
++this->warp_tile_iterator_QScale_; + + this->warp_tile_iterator_B_.set_kgroup_index(0); + this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_); + ++this->warp_tile_iterator_B_; + + // Load first warp-tile's A fragment from shared memory + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[0]); + ++this->warp_tile_iterator_A_; + + copy_tiles_and_advance(iterator_A, iterator_B, 0); + + if constexpr(Shape::kM > 32) { + // the case of bigger m + if constexpr(debug_layout) { + if (LayoutDebugType::debug_fragment && layout_debug_.block_id_ == 1 && layout_debug_.warp_id_ == 0 && layout_debug_.lane_id_ == 0){ + printf("LINE %d, warp_tile_B kgroup %d\n", __LINE__, 0); + } + LayoutDebugType::print_as_int4(pipe_state.warp_loaded_frag_B_, 'W', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + LayoutDebugType::print_fragment(Operator::IteratorQScale::debug_expand(pipe_state.warp_loaded_frag_QScale_), 'Q', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + if constexpr(kHasQOffset){ + LayoutDebugType::print_fragment(Operator::IteratorQScale::debug_expand(pipe_state.warp_loaded_frag_QOffset_), 'O', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + } + } + + warp_mma_.transform( + pipe_state.warp_transformed_frag_B_[0], + pipe_state.warp_loaded_frag_B_, + pipe_state.warp_loaded_frag_QScale_, + pipe_state.warp_loaded_frag_QOffset_); + + if constexpr(debug_layout) { + LayoutDebugType::print_fragment(pipe_state.warp_transformed_frag_B_[0], 'B', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + } + } else { + // the case of small m + copy_qscale_tiles(iterator_QScale); + copy_qoffset_tiles(iterator_QOffset); + } + + if (Detail::kStagedAccumulation) { + pipe_state.tmp_accum_.clear(); + } + + // Mainloop + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) { + if constexpr(Shape::kM > 32) { 
+ mac_loop_iter( + pipe_state, + accum, + iterator_A, + iterator_B, + iterator_QScale, + iterator_QOffset, + gemm_k_iterations); + } else { + mac_loop_iter_small_m( + pipe_state, + accum, + iterator_A, + iterator_B, + iterator_QScale, + iterator_QOffset, + gemm_k_iterations); + } + } + + if (Detail::kStagedAccumulation) { + plus plus_accum; + accum = plus_accum(accum, pipe_state.tmp_accum_); + } + + // Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + + } + + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC &accum, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< iterator over quant scales in global memory + IteratorQScale iterator_QScale, + ///< Iterator over quant offsets in global memory + IteratorQOffset iterator_QOffset, + ///< initial value of accumulator + FragmentC const &src_accum) { + + // Prologue (start fetching iterations of global fragments into shared memory) + prologue(iterator_A, iterator_B, iterator_QScale, iterator_QOffset, gemm_k_iterations); + + // Wait until we have at least one completed global fetch stage + gmem_wait(); + + // Initialize destination accumulators with source accumulators + accum = src_accum; + + // Perform the MAC-iterations + gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B, iterator_QScale, iterator_QOffset); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git 
a/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/default_quantb_mma_tensor_op.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/default_quantb_mma_tensor_op.h new file mode 100644 index 0000000000000..2c49888c94504 --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/default_quantb_mma_tensor_op.h @@ -0,0 +1,112 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** + * Modifications Copyright (c) Microsoft. + * Licensed under the MIT license. + * + * @file default_quantb_mma_tensor_op.h + * @brief Modified from cutlass/gemm/warp/default_mma_tensor_op.h + * Default warp-level GEMM operators selected by data type, size, and layouts of operands. + */ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass_ext/q4gemm/warp/quantb_mma_tensor_op.h" + +namespace cutlass { +namespace gemm { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for m-by-n-by-kgroup +template < + /// Shape of one matrix production operation (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Data type of A elements + typename ElementA, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA, + /// Data type of B elements + typename ElementB, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB, + /// Data type of quant scales + typename ElementQScale, + /// Layout of quant scales (concept: MatrixLayout) + typename SmemLayoutQScale, + /// Data type of quant offsets + typename ElementQOffset, + /// Layout of quant offsets (concept: MatrixLayout) + typename SmemLayoutQOffset, + /// Blocking 
size of quantization + typename QuantBlocking, + /// Element type of C matrix + typename ElementC, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC, + /// Operator describing the tensor operation + typename Operator_ = arch::OpMultiplyAdd, + /// Number of partitions along K dimension + int PartitionsK = 1, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. + bool AccumulatorsInRowMajor = false> +struct DefaultQuantBMmaTensorOp { + using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< + cutlass::arch::Mma, + cutlass::MatrixShape<1, 1> >; + + // Define the warp-level tensor op + using Type = cutlass::gemm::warp::QuantBMmaTensorOp< + WarpShape_, ElementA, LayoutA, ElementB, LayoutB, ElementQScale, SmemLayoutQScale, + ElementQOffset, SmemLayoutQOffset, QuantBlocking, ElementC, LayoutC, + Policy, PartitionsK, AccumulatorsInRowMajor>; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "cutlass/gemm/warp/mma_complex_tensor_op_tile_iterator_sm80.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h new file mode 100644 index 0000000000000..4ba39dda3db8d --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h @@ -0,0 +1,883 @@ +/** + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. 
+ * + * @file quantb_meta_mma_tensor_op_tile_iterator.h + * @brief Templates for loading quantization meta data for operand B + * from shared memory to fragments. This is meant to be used in + * lock step with the operand B tile iterator. Containing logic + * to figure out the operand B layout in the tensor core, + * and deliver each meta data element to its corresponding + * operand B element for dequantization. + */ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/matrix_shape.h" + +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor_op_multiplicand_sm75.h" + +#include "cutlass/platform/platform.h" +#include "cutlass/fast_math.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace{ + +struct b32_pair{ + uint32_t a; + uint32_t b; +}; + +struct fp16_quad{ + cutlass::half_t a; + cutlass::half_t b; + cutlass::half_t c; + cutlass::half_t d; +}; + +struct b16_quad{ + int16_t a; + int16_t b; + int16_t c; + int16_t d; +}; + +union b64 { + uint64_t single; + b32_pair pair; + b16_quad quard; + fp16_quad fp16_quad; +}; + +static_assert(sizeof(b64) == 8, "b64 should be 64 bits"); + +/// Convert packed 4b weights into fp16(weight + 16) +/// Current bit hacking only supports fp16, need to add bf16 later. 
+/// +template +CUTLASS_DEVICE +void weights2Half(cutlass::Array const &weights, + cutlass::Array& dest) +{ + static_assert(Size % 8 == 0, "Weights should have been prepacked by 2x2 tiles, 2 weights per tile."); + uint32_t* dest_pair = reinterpret_cast(dest.data()); + const uint32_t* w_oct = reinterpret_cast(weights.data()); + + CUTLASS_PRAGMA_UNROLL + for (int oct_idx = 0; oct_idx < Size/8; oct_idx++, w_oct++, dest_pair += 4){ +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + + // static_cast(16 + weight) + // 4b weights are prepacked into [0, 2, 4, 6, 1, 3, 5, 7], so that adjacent weights + // are in different 16b half words, making it easier to convert to fp16. + asm volatile( + "{\n\t" + " shl.b32 %0, %4, 6;\n" + " shl.b32 %1, %4, 2;\n" + " shr.u32 %2, %4, 2;\n" + " shr.u32 %3, %4, 6;\n" + " lop3.b32 %0, %0, 0x03c003c0, 0x4c004c00, 0xea;\n" // a & 0x03c0 | 0x4c00 + " lop3.b32 %1, %1, 0x03c003c0, 0x4c004c00, 0xea;\n" + " lop3.b32 %2, %2, 0x03c003c0, 0x4c004c00, 0xea;\n" + " lop3.b32 %3, %3, 0x03c003c0, 0x4c004c00, 0xea;\n" + "}\n" + : "=r"(dest_pair[0]), "=r"(dest_pair[1]), + "=r"(dest_pair[2]), "=r"(dest_pair[3]) + : "r"(*w_oct)); +#else + assert(0); +#endif + } + +} + +} // namespace + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace warp { + +//////////////////////////////////////////////////////////////////////////////// + +// Traits to describe the layout of quantization meta data layout in a MMA fragment +// Since operand B is quantized on a per block basis, it's one meta data per block. + +template < + /// Shape of the operand B matrix to load in a warp (concept: MatrixShape) + typename WarpShapeB_, + /// Block dimensions of the blockwise quantization. 
So the actual meta data + /// warp shape is WarpShapeB_ / BlockingShape_ + typename BlockingShape_, + /// Underlying matrix multiply operator (concept: arch::Mma) + typename ArchMmaOperator_, + /// Number of threads participating in one matrix operation + int Threads> +class QuantBMetaMmaTile{ +public: + + using WarpShapeB = WarpShapeB_; + using BlockingShape = BlockingShape_; + using ArchMmaOperator = ArchMmaOperator_; + + static_assert(Threads == 32, "This iterator should work in a warp only."); + + /// Shape of the corresponding operand B tile iterator + using TileShapeB = MatrixShape; + + // Tensor core operand B layout is a column major 4x8 tile, divided + // into 32 threads (T0 ~ T31) as shown below. Each element of the tile is 32b, + // so for fp16 it becomes 8 x 8, and int8 it becomes 16 x 8. + // T0 | T4 | T8 | T12 | T16 | T20 | T24 | T28 + // T1 | T5 | T9 | T13 | T17 | T21 | T25 | T29 + // T2 | T6 | T10 | T14 | T18 | T22 | T26 | T30 + // T3 | T7 | T11 | T15 | T19 | T23 | T27 | T31 + using CoreTile = layout::PitchLinearShape<4, 8>; + + /// Each thread holds a 32b fragment per tile: for half precision, it's 2 elements, 4 elements for int8 + static int const kNumBsPerCoreTileFragement = 32 / sizeof_bits::value; + + /// Each mma instruction can process either 1 or 2 tensor core operand B tiles (stacked on the k dimension) + static int const kBTilesPerMma = + sizeof_bits::value * ArchMmaOperator::FragmentB::kElements / 32; + static_assert(kBTilesPerMma == 1 || kBTilesPerMma == 2, "Only support 1 or 2 operand B tiles per mma."); + + /// Each operand B tile iterator load covers a number of mma instructions + static int const kMmaIterationsB = WarpShapeB::kColumn / ArchMmaOperator::Shape::kN; + + /// Number of B elements a fragment of meta data should cover + static int const kExpandedSize = kNumBsPerCoreTileFragement * kBTilesPerMma * kMmaIterationsB; + + // Now we figure out how many meta data elements to load for each TileShapeB + + /// Number of meta elements
per CoreTile. + static int const kCoreTileFragementSize = (kNumBsPerCoreTileFragement + BlockingShape::kRow - 1) / BlockingShape::kRow; + + /// Number of core tiles per mma instruction, different from kBTilesPerMma when blocking size on K dimension + /// exceeds the tile depth, so two tiles share the same meta data + static int const kTilesPerMma = ((kBTilesPerMma == 2) && + (BlockingShape::kRow <= kNumBsPerCoreTileFragement * CoreTile::kContiguous)) + ? 2 : 1; + + /// stride to reach the meta data for the next CoreTile on the K dimension + static int const kKTileStride = (kNumBsPerCoreTileFragement * CoreTile::kContiguous + BlockingShape::kRow - 1) / BlockingShape::kRow; + + /// Stride on N dimension should be the tile width, shrunk by blocking size on this dimension. + static int const kNStride = (CoreTile::kStrided + BlockingShape::kColumn - 1) / BlockingShape::kColumn; + + /// On N dimension, how many tiles share the same meta data + static int const kNRepeats = (BlockingShape::kColumn + CoreTile::kStrided - 1) / CoreTile::kStrided; + + /// Each fragment should cover kMmaIterationsB number of mma instructions on the N dimension. + /// When blocking size on this dimension exceeds the tile width, multiple iterations + /// would share the same data. + static int const kMmaIterations = (kMmaIterationsB + kNRepeats - 1) / kNRepeats; + + static int const kFragementSize = kCoreTileFragementSize * kTilesPerMma * kMmaIterations; + + CUTLASS_DEVICE + static MatrixCoord lane_position(int lane_id) { + if constexpr(kNumBsPerCoreTileFragement == 2 + && kBTilesPerMma == 2 + && BlockingShape::kRow == 1){ + // Optimize for a special case of: + // 16b gemm (kNumBsPerCoreTileFragement == 2) + // 2 B operand tiles per mma (kBTilesPerMma == 2) + // (1,n) quantization blocking + // The scale and offset tensors are prepacked to reduce the number of load instructions.
+ return make_Coord((lane_id % CoreTile::kContiguous) * 4, + lane_id / CoreTile::kContiguous); + } else { + return make_Coord((lane_id % CoreTile::kContiguous) * kNumBsPerCoreTileFragement, + lane_id / CoreTile::kContiguous); + } + } +}; + + +//////////////////////////////////////////////////////////////////////////////// + +/// This tile iterator is to load quantization meta data for operand B from +/// shared memory to fragments (hopefully allocated to registers by compilers). +/// Examples of meta data include scale or offsets. The operand B matrix is +/// quantized on a per block basis, meaning one element of meta data per block. +/// +/// This is meant to be used in lock step with the operand B tile iterator. +/// So all parameters are logical positions in the operand B tiles. +/// The goal here is to deliver each meta data element to its corresponding +/// operand B element for dequantization. As a result, we need to figure +/// out the operand B layout in the tensor core. +/// +template < + /// Shape of the operand B matrix to load in a warp (concept: MatrixShape) + typename WarpShapeB_, + /// Block dimensions of the blockwise quantization. 
So the actual meta data + /// warp shape is WarpShapeB_ / BlockingShape_ + typename BlockingShape_, + /// Data type of the quant scales + typename ElementScale_, + /// Layout of the quant scales + typename LayoutScale_, + /// Data type of quant offsets + typename ElementOffset_, + /// Layout of quant offsets + typename LayoutOffset_, + /// Underlying matrix multiply operator (concept: arch::Mma) + typename ArchMmaOperator_, + /// Number of threads participating in one matrix operation + int Threads, + /// Number of partitions along K dimension + int PartitionsK_ = 1> +class QuantBMetaMmaTensorOpTileIterator; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for column major layout + +template < + /// Shape of the operand B matrix to load in a warp (concept: MatrixShape) + typename WarpShapeB_, + /// Block dimensions of the blockwise quantization. So the actual meta data + /// warp shape is WarpShapeB_ / BlockingShape_ + typename BlockingShape_, + /// Data type of the meta data elements + typename ElementScale_, + /// Data type of quant offsets + typename ElementOffset_, + /// Underlying matrix multiply operator (concept: arch::Mma) + typename ArchMmaOperator_, + /// Number of threads participating in one matrix operation + int Threads> +class QuantBMetaMmaTensorOpTileIterator{ +public: + + using WarpShapeB = WarpShapeB_; + using BlockingShape = BlockingShape_; + using ElementScale = ElementScale_; + using Layout = cutlass::layout::ColumnMajor; + using ElementOffset = ElementOffset_; + using ArchMmaOperator = ArchMmaOperator_; + + static constexpr bool kHasOffset = !(std::is_same::value); + + static_assert(BlockingShape::kRow == 1 && BlockingShape::kColumn > 1, + "Only support row blocking for column major layout"); + + using MetaTile = QuantBMetaMmaTile; + + /// Number of MMA instructions for this tile + static constexpr int kMmaIterationsB = MetaTile::kMmaIterationsB; + + /// Number of B elements per mma tile 
fragment (32b), 2 for half precision, 4 for int8 + static constexpr int kNumBsPerCoreTileFragement = MetaTile::kNumBsPerCoreTileFragement; + + /// Each mma instruction can process either 1 or 2 operand B tiles (stacked on the k dimension) + static constexpr int kBTilesPerMma = MetaTile::kBTilesPerMma; + + /// Number of B elements a fragment of meta data should cover + static constexpr int kExpandedSize = MetaTile::kExpandedSize; + + /// Number of meta elements per core tile fragment + static constexpr int kCoreTileFragementSize = MetaTile::kCoreTileFragementSize; + + /// stride for reaching the next core tile (if there is one) on the K dimension + static constexpr int kKTileStride = MetaTile::kKTileStride; + + /// do we need to load meta data for the next core tile on the K dimension? + static constexpr int kTilesPerMma = MetaTile::kTilesPerMma; + + static constexpr int kNStride = MetaTile::kNStride; + static constexpr int kNRepeats = MetaTile::kNRepeats; + static constexpr int kMmaIterations = MetaTile::kMmaIterations; + + using TensorRefScale = TensorRef; + using TensorRefOffset = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using StrideIndex = typename Layout::Stride::Index; + + using FragmentScale = Array; + using FragmentOffset = typename std::conditional, + std::monostate>::type; + + using AccessTypeScale = Array; + using AccessTypeOffset = Array; + +private: + + ElementScale *pointer_; + Layout layout_; + + ElementOffset *pointer_offset_; + Layout layout_offset_; + + TensorCoord lane_position_; + +public: + + CUTLASS_DEVICE + QuantBMetaMmaTensorOpTileIterator() { } + + CUTLASS_DEVICE + QuantBMetaMmaTensorOpTileIterator( + TensorRefScale const &ref, + TensorRefOffset const &ref_offset, + int lane_idx + ): + pointer_(ref.data()), + layout_(ref.layout()), + pointer_offset_(ref_offset.data()), + layout_offset_(ref_offset.layout()), + 
lane_position_(MetaTile::lane_position(lane_idx)){} + + /// Loads a fragment + CUTLASS_HOST_DEVICE + void load(FragmentScale &frag, FragmentOffset &frag_offset) { + if constexpr(kNumBsPerCoreTileFragement == 2 + && kBTilesPerMma == 2){ + // Optimize for a special case of: + // 16b gemm (kNumBsPerCoreTileFragement == 2) + // 2 B operand tiles per mma (kBTilesPerMma == 2) + // (1,n) quantization blocking (BlockingShape::kRow == 1) + // The scale and offset tensors are prepacked to reduce the number of load instructions needed + const int row = lane_position_.row(); + const int column = lane_position_.column() / BlockingShape::kColumn; + + Array *dst_ptr = reinterpret_cast*>(frag.data()); + CUTLASS_PRAGMA_UNROLL + for (int n_idx = 0, c = column; n_idx < kMmaIterations; n_idx++, c += kNStride){ + Array *src_ptr = reinterpret_cast*>(pointer_ + layout_({row, c})); + *dst_ptr = *src_ptr; + dst_ptr++; + } + + if constexpr(kHasOffset){ + Array *dst_ptr_offset = reinterpret_cast*>(frag_offset.data()); + CUTLASS_PRAGMA_UNROLL + for (int n_idx = 0, c = column; n_idx < kMmaIterations; n_idx++, c += kNStride){ + Array *src_ptr_offset = reinterpret_cast*>(pointer_offset_ + layout_offset_({row, c})); + *dst_ptr_offset = *src_ptr_offset; + dst_ptr_offset++; + } + } + + } else { + // Other cases, offsets and scales are not prepacked. 
+ + const int row = lane_position_.row() / BlockingShape::kRow; + const int column = lane_position_.column() / BlockingShape::kColumn; + + AccessTypeScale* dst_ptr = reinterpret_cast(frag.data()); + CUTLASS_PRAGMA_UNROLL + for (int n_idx = 0, c = column; n_idx < kMmaIterations; n_idx++, c += kNStride){ + CUTLASS_PRAGMA_UNROLL + for (int mma_tile_idx = 0, r = row; mma_tile_idx < kTilesPerMma; mma_tile_idx++, r += kKTileStride){ + AccessTypeScale* src_ptr = reinterpret_cast(pointer_ + layout_({r, c})); + *dst_ptr = *src_ptr; + dst_ptr++; + } + } + + if constexpr(kHasOffset){ + AccessTypeOffset* dst_ptr = reinterpret_cast(frag_offset.data()); + CUTLASS_PRAGMA_UNROLL + for (int n_idx = 0, c = column; n_idx < kMmaIterations; n_idx++, c += kNStride){ + CUTLASS_PRAGMA_UNROLL + for (int mma_tile_idx = 0, r = row; mma_tile_idx < kTilesPerMma; mma_tile_idx++, r += kKTileStride){ + AccessTypeOffset* src_ptr = reinterpret_cast(pointer_offset_ + layout_offset_({r, c})); + *dst_ptr = *src_ptr; + dst_ptr++; + } + } + } + } + } + + template + CUTLASS_HOST_DEVICE + static Array debug_expand(Array const &frag){ + Array ret; + int out_idx = 0; + CUTLASS_PRAGMA_UNROLL + for (int n_out = 0; n_out < kMmaIterationsB; n_out++){ + int n_idx = n_out / kNRepeats; + CUTLASS_PRAGMA_UNROLL + for (int mma_tile_out_idx = 0; mma_tile_out_idx < kBTilesPerMma; mma_tile_out_idx++){ + int mma_tile_idx = mma_tile_out_idx / (kBTilesPerMma / kTilesPerMma); + CUTLASS_PRAGMA_UNROLL + for (int elem_out_idx = 0; elem_out_idx < kNumBsPerCoreTileFragement; elem_out_idx++){ + int elem_idx = elem_out_idx / BlockingShape::kRow; + int idx = elem_idx + mma_tile_idx * kCoreTileFragementSize + n_idx * kCoreTileFragementSize * kTilesPerMma; + ret[out_idx] = frag[idx]; + out_idx++; + } + } + } + return ret; + } + + CUTLASS_HOST_DEVICE + static void dequant(FragmentScale const &scales, + FragmentOffset const &offsets, + Array const &weights, + Array& dest){ + static_assert(kNumBsPerCoreTileFragement == 2, "Only for 16b 
gemm."); + static_assert(kExpandedSize % 8 == 0, "Weights should have been prepacked by 2x2 tiles, 2 weights per tile."); + + // First convert 4b weight into fp16(weight + 16) + weights2Half(weights, dest); + + if constexpr(kBTilesPerMma == 2){ + // Optimize for a special case of: + // 2 B operand tiles per mma (kBTilesPerMma == 2) + // (1,n) quantization blocking (BlockingShape::kRow == 1) + + uint32_t* dest_pair = reinterpret_cast(dest.data()); + const b64* scales_ptr = reinterpret_cast(scales.data()); + const ElementOffset* offsets_ptr = nullptr; + if constexpr(kHasOffset) { offsets_ptr = offsets.data(); } + + CUTLASS_PRAGMA_UNROLL + for (int n_idx = 0; n_idx < kMmaIterations; n_idx++){ + // dequantize: d = scale * (weight - offset) + // to use FMA, d = scale * weight + (scale * (-offset)) + + b64 offsets; + if constexpr(kHasOffset){ + const uint32_t* p = reinterpret_cast(offsets_ptr); + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + asm volatile( + "{\n\t" + " .reg .b32 rb0, rb1;\n" // b32 regs for fp16x2 mul operands + + // static_cast(-16 - offset) + // input [d, b, c, a], + " shl.b32 rb0, %4, 6;\n" // rb0 = [x, b, x, a] << 6 + " shr.u32 rb1, %4, 2;\n" // rb1 = [x, d, x, c] << 6 + " lop3.b32 rb0, rb0, 0x03c003c0, 0xcc00cc00, 0xea;\n" // a & 0x03c0 | 0xcc00 + " lop3.b32 rb1, rb1, 0x03c003c0, 0xcc00cc00, 0xea;\n" + " mul.rn.f16x2 %0, %2, rb0;\n" // offset = scale * (-16 - offset) + " mul.rn.f16x2 %1, %3, rb1;\n" + "}\n" + : "=r"(offsets.pair.a), "=r"(offsets.pair.b) + : "r"(scales_ptr->pair.a), "r"(scales_ptr->pair.b), + "r"(p[0])); +#else + assert(0); +#endif + + offsets_ptr += 4; + } else { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + asm volatile( + "{\n\t" + " .reg .b32 rb0;\n" + " mov.u32 rb0, 0xce00ce00;\n" + " mul.rn.f16x2 %0, %2, rb0;\n" // offset = scale * (-16 - 8) + " mul.rn.f16x2 %1, %3, rb0;\n" + "}\n" + : "=r"(offsets.pair.a), "=r"(offsets.pair.b) + : "r"(scales_ptr->pair.a), "r"(scales_ptr->pair.b)); +#else + 
offsets.fp16_quad.a = scales_ptr->fp16_quad.a * static_cast(-16-8); + offsets.fp16_quad.b = scales_ptr->fp16_quad.b * static_cast(-16-8); + offsets.fp16_quad.c = scales_ptr->fp16_quad.c * static_cast(-16-8); + offsets.fp16_quad.d = scales_ptr->fp16_quad.d * static_cast(-16-8); +#endif + } + + CUTLASS_PRAGMA_UNROLL + for (int n_r = 0; n_r < kNRepeats; n_r++){ +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + asm volatile( + "{\n\t" + " fma.rn.f16x2 %0, %2, %0, %4;\n" // dest = scale * (16 + weight) + (scale * (-16 - offset)) + " fma.rn.f16x2 %1, %3, %1, %5;\n" + "}\n" + : "+r"(dest_pair[0]), "+r"(dest_pair[1]) + : "r"(scales_ptr->pair.a), "r"(scales_ptr->pair.b), + "r"(offsets.pair.a), "r"(offsets.pair.b)); +#else + assert(0); +#endif + dest_pair += 2; + } + scales_ptr++; + } + + } else { + // unoptimized path for other cases, very slow + int out_idx = 0; + ElementScale offset; + CUTLASS_PRAGMA_UNROLL + for (int n_out = 0; n_out < kMmaIterationsB; n_out++){ + int n_idx = n_out / kNRepeats; + CUTLASS_PRAGMA_UNROLL + for (int mma_tile_out_idx = 0; mma_tile_out_idx < kBTilesPerMma; mma_tile_out_idx++){ + int mma_tile_idx = mma_tile_out_idx / (kBTilesPerMma / kTilesPerMma); + CUTLASS_PRAGMA_UNROLL + for (int elem_out_idx = 0; elem_out_idx < kNumBsPerCoreTileFragement; elem_out_idx++){ + int elem_idx = elem_out_idx / BlockingShape::kRow; + int idx = elem_idx + mma_tile_idx * kCoreTileFragementSize + n_idx * kCoreTileFragementSize * kTilesPerMma; + ElementScale s = scales[idx]; + if constexpr(kHasOffset){ + offset = s * static_cast(-16 - int(offsets[idx])); + } else { + offset = s * static_cast(-16-8); + } + dest[out_idx] = s * dest[out_idx] + offset; + out_idx++; + } + } + } + + } + + } + + /// Advances the pointer + CUTLASS_HOST_DEVICE + QuantBMetaMmaTensorOpTileIterator &operator++() { + // This is for operand B, so advance on the K dimension + lane_position_ += make_Coord(MetaTile::TileShapeB::kRow, 0); + return *this; + } + + CUTLASS_DEVICE +
QuantBMetaMmaTensorOpTileIterator &add_tile_offset( + TensorCoord const &tile_offset) { + int rows = tile_offset.row() * MetaTile::TileShapeB::kRow; + int columns = tile_offset.column() * MetaTile::TileShapeB::kColumn; + lane_position_ += TensorCoord(rows, columns); + return *this; + } + +}; + + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row major layout + +template < + /// Shape of the operand B matrix to load in a warp (concept: MatrixShape) + typename WarpShapeB_, + /// Block dimensions of the blockwise quantization. So the actual meta data + /// warp shape is WarpShapeB_ / BlockingShape_ + typename BlockingShape_, + /// Data type of the meta data elements + typename ElementScale_, + /// Data type of quant offsets + typename ElementOffset_, + /// Underlying matrix multiply operator (concept: arch::Mma) + typename ArchMmaOperator_, + /// Number of threads participating in one matrix operation + int Threads> +class QuantBMetaMmaTensorOpTileIterator{ +public: + + using WarpShapeB = WarpShapeB_; + using BlockingShape = BlockingShape_; + using ElementScale = ElementScale_; + using ElementOffset = ElementOffset_; + using Layout = cutlass::layout::RowMajor; + using ArchMmaOperator = ArchMmaOperator_; + + static constexpr bool kHasOffset = !(std::is_same::value); + + static_assert(BlockingShape::kColumn == 1 && BlockingShape::kRow > 1, + "Only support column blocking for row major layout"); + + using MetaTile = QuantBMetaMmaTile; + + /// Number of MMA instructions for this tile + static constexpr int kMmaIterationsB = MetaTile::kMmaIterationsB; + + /// Number of B elements per mma tile fragment (32b), 2 for half precision, 4 for int8 + static constexpr int kNumBsPerCoreTileFragement = MetaTile::kNumBsPerCoreTileFragement; + + /// Each mma instruction can process either 1 or 2 operand B tiles (stacked on the k dimension) + static constexpr int kBTilesPerMma = MetaTile::kBTilesPerMma; + + /// Number of B 
elements a fragment of meta data should cover + static constexpr int kExpandedSize = MetaTile::kExpandedSize; + + /// Number of meta elements per core tile fragment + static constexpr int kCoreTileFragementSize = MetaTile::kCoreTileFragementSize; + + /// stride for reaching the next core tile (if there is one) on the K dimension + static constexpr int kKTileStride = MetaTile::kKTileStride; + + /// do we need to load meta data for the next core tile on the K dimension? + static constexpr int kTilesPerMma = MetaTile::kTilesPerMma; + + static constexpr int kNStride = MetaTile::kNStride; + static constexpr int kNRepeats = MetaTile::kNRepeats; + static constexpr int kMmaIterations = MetaTile::kMmaIterations; + + using TensorRefScale = TensorRef; + using TensorRefOffset = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using StrideIndex = typename Layout::Stride::Index; + + using FragmentScale = Array; + using FragmentOffset = typename std::conditional, + std::monostate>::type; + +private: + + ElementScale *pointer_; + Layout layout_; + + ElementOffset *pointer_offset_; + Layout layout_offset_; + + TensorCoord lane_position_; + +public: + + CUTLASS_DEVICE + QuantBMetaMmaTensorOpTileIterator() { } + + CUTLASS_DEVICE + QuantBMetaMmaTensorOpTileIterator( + TensorRefScale const &ref, + TensorRefOffset const &ref_offset, + int lane_idx + ): + pointer_(ref.data()), + layout_(ref.layout()), + pointer_offset_(ref_offset.data()), + layout_offset_(ref_offset.layout()), + lane_position_(MetaTile::lane_position(lane_idx)) + {} + + /// Loads a fragment + CUTLASS_HOST_DEVICE + void load(FragmentScale &frag, FragmentOffset &frag_offset) { + const int row = lane_position_.row() / BlockingShape::kRow; + const int column = lane_position_.column() / BlockingShape::kColumn; + static_assert(kTilesPerMma * kCoreTileFragementSize == 1, "Only support one meta data per core tile"); + + 
ElementScale* src_ptr = pointer_ + layout_({row, column}); + ElementScale* dst_ptr = frag.data(); + CUTLASS_PRAGMA_UNROLL + for (int n_idx = 0; n_idx < kMmaIterations; n_idx++){ + dst_ptr[n_idx] = src_ptr[n_idx * kNStride]; + } + + if constexpr(kHasOffset){ + ElementOffset* src_ptr_offset = pointer_offset_ + layout_offset_({row, column}); + ElementOffset* dst_ptr_offset = frag_offset.data(); + CUTLASS_PRAGMA_UNROLL + for (int n_idx = 0; n_idx < kMmaIterations; n_idx++){ + dst_ptr_offset[n_idx] = src_ptr_offset[n_idx * kNStride]; + } + } + } + + template + CUTLASS_HOST_DEVICE + static Array debug_expand(Array const &frag){ + Array ret; + + int out_idx = 0; + CUTLASS_PRAGMA_UNROLL + for (int n_out = 0; n_out < kMmaIterationsB; n_out++){ + int n_idx = n_out / kNRepeats; + CUTLASS_PRAGMA_UNROLL + for (int mma_tile_out_idx = 0; mma_tile_out_idx < kBTilesPerMma; mma_tile_out_idx++){ + int mma_tile_idx = mma_tile_out_idx / (kBTilesPerMma / kTilesPerMma); + CUTLASS_PRAGMA_UNROLL + for (int elem_out_idx = 0; elem_out_idx < kNumBsPerCoreTileFragement; elem_out_idx++){ + int elem_idx = elem_out_idx / BlockingShape::kRow; + int col = elem_idx + mma_tile_idx * kCoreTileFragementSize; + int idx = col * kMmaIterations + n_idx; + ret[out_idx] = frag[idx]; + out_idx++; + } + } + } + return ret; + } + + CUTLASS_HOST_DEVICE + static void dequant(FragmentScale const &scales, + FragmentOffset const &offsets, + Array const &weights, + Array& dest){ + static_assert(kNRepeats == 1, "This is implied by BlockingShape::kColumn == 1"); + static_assert(kNumBsPerCoreTileFragement == 2, "Only for 16b gemm now."); + + // First convert 4b weight into fp16(weight + 16) + weights2Half(weights, dest); + + ElementScale addon[kMmaIterationsB]; + if constexpr (kMmaIterationsB % 4 == 0) { + const b64* scales_ptr = reinterpret_cast(scales.data()); + uint32_t* addon_ptr = reinterpret_cast(addon); + if constexpr(kHasOffset){ + const uint32_t* p = reinterpret_cast(offsets.data()); + CUTLASS_PRAGMA_UNROLL + 
for (int n_idx = 0; n_idx < kMmaIterationsB; n_idx += 4){ +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + asm volatile( + "{\n\t" + " .reg .b32 rb0, rb1, rb2;\n" + + // offset from [d, c, b, a] --> [d, b, c, a] + " prmt.b32 rb2, %4, rb0, 0x3120;\n" + + // static_cast(-16 - offset) + // input [d, b, c, a], + " shl.b32 rb0, rb2, 6;\n" // rb0 = [x, b, x, a] << 6 + " shr.u32 rb1, rb2, 2;\n" // rb1 = [x, d, x, c] << 6 + " lop3.b32 rb0, rb0, 0x03c003c0, 0xcc00cc00, 0xea;\n" // a & 0x03c0 | 0xcc00 + " lop3.b32 rb1, rb1, 0x03c003c0, 0xcc00cc00, 0xea;\n" + " mul.rn.f16x2 %0, %2, rb0;\n" // offset = scale * (-16 - offset) + " mul.rn.f16x2 %1, %3, rb1;\n" + "}\n" + : "=r"(addon_ptr[0]), "=r"(addon_ptr[1]) + : "r"(scales_ptr->pair.a), "r"(scales_ptr->pair.b), + "r"(p[0])); +#else + assert(0); +#endif + scales_ptr++; + p++; + addon_ptr += 2; + } + } else { + CUTLASS_PRAGMA_UNROLL + for (int n_idx = 0; n_idx < kMmaIterationsB; n_idx += 4){ +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + asm volatile( + "{\n\t" + " .reg .b32 rb0;\n" + " mov.u32 rb0, 0xce00ce00;\n" + " mul.rn.f16x2 %0, %2, rb0;\n" // offset = scale * (-16 - 8) + " mul.rn.f16x2 %1, %3, rb0;\n" + "}\n" + : "=r"(addon_ptr[0]), "=r"(addon_ptr[1]) + : "r"(scales_ptr->pair.a), "r"(scales_ptr->pair.b)); +#else + assert(0); +#endif + scales_ptr++; + addon_ptr += 2; + } + } + } else if constexpr (kMmaIterationsB % 2 == 0) { + const uint32_t* scales_ptr = reinterpret_cast(scales.data()); + uint32_t* addon_ptr = reinterpret_cast(addon); + + if constexpr (kHasOffset){ + // possible buffer over read 2 bytes here. 
+ const uint32_t* p = reinterpret_cast(offsets.data()); +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + asm volatile( + "{\n\t" + " .reg .b32 rb0, rb1, rb2;\n" + + // offset from [?, ?, b, a] --> [?, b, ?, a] + " prmt.b32 rb2, %2, rb0, 0x3120;\n" + + // static_cast(-16 - offset) + // input [d, b, c, a], + " shl.b32 rb0, rb2, 6;\n" // rb0 = [x, b, x, a] << 6 + " lop3.b32 rb0, rb0, 0x03c003c0, 0xcc00cc00, 0xea;\n" // a & 0x03c0 | 0xcc00 + " mul.rn.f16x2 %0, %1, rb0;\n" // offset = scale * (-16 - offset) + "}\n" + : "=r"(addon_ptr[0]) + : "r"(scales_ptr[0]), + "r"(p[0])); +#else + assert(0); +#endif + } else { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + asm volatile( + "{\n\t" + " .reg .b32 rb0;\n" + " mov.u32 rb0, 0xce00ce00;\n" + " mul.rn.f16x2 %0, %1, rb0;\n" // offset = scale * (-16 - 8) + "}\n" + : "=r"(addon_ptr[0]) + : "r"(scales_ptr[0])); +#else + assert(0); +#endif + } + } else { + // kMmaIterationsB == 1: only addon[0] is computed in this branch + if constexpr(kHasOffset){ + uint8_t zp = offsets[0]; + addon[0] = scales[0] * static_cast(-16 - static_cast(zp)); + } else { + addon[0] = scales[0] * static_cast(-16-8); + } + } + + int out_idx = 0; + CUTLASS_PRAGMA_UNROLL + for (int n_out = 0; n_out < kMmaIterationsB; n_out++){ + CUTLASS_PRAGMA_UNROLL + for (int mma_tile_out_idx = 0; mma_tile_out_idx < kBTilesPerMma; mma_tile_out_idx++){ + dest[out_idx] = scales[n_out] * dest[out_idx] + addon[n_out]; + dest[out_idx + 1] = scales[n_out] * dest[out_idx + 1] + addon[n_out]; + out_idx += 2; + } + } + } + + /// Advances the pointer + CUTLASS_HOST_DEVICE + QuantBMetaMmaTensorOpTileIterator &operator++() { + // This is for operand B, so advance on the K dimension + lane_position_ += make_Coord(MetaTile::TileShapeB::kRow, 0); + return *this; + } + + CUTLASS_DEVICE + QuantBMetaMmaTensorOpTileIterator &add_tile_offset( + TensorCoord const &tile_offset) { + int rows = tile_offset.row() * MetaTile::TileShapeB::kRow; + int columns = tile_offset.column() * MetaTile::TileShapeB::kColumn; +
lane_position_ += TensorCoord(rows, columns); + return *this; + } + +}; + + +//////////////////////////////////////////////////////////////////////////////// +} // namespace warp +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_mma_tensor_op.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_mma_tensor_op.h new file mode 100644 index 0000000000000..f29cedf326a44 --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_mma_tensor_op.h @@ -0,0 +1,361 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** + * Modifications Copyright (c) Microsoft. + * Licensed under the MIT license. + * + * @file quantb_mma_tensor_op.h + * @brief Modified from cutlass/gemm/warp/mma_tensor_op.h + * Templates implementing warp-level matrix multiply-accumulate operations + * targeting tensor cores. + */ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/platform/platform.h" + +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" + +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/arch/mma_sm75.h" +#include "cutlass/arch/mma_sm80.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/warp/mma.h" +#include "cutlass/gemm/warp/mma_tensor_op_policy.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" + +#include "cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the warp-level matrix product targeting Tensor Cores. 
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Data type of A elements + typename ElementA_, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA_, + /// Data type of B elements + typename ElementB_, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB_, + /// Data type of quant scales + typename ElementQScale_, + /// Layout of quant scales (concept: MatrixLayout) + typename SmemLayoutQScale_, + /// Data type of quant offsets + typename ElementQOffset_, + /// Layout of quant offsets (concept: MatrixLayout) + typename SmemLayoutQOffset_, + /// Blocking dimensions of quantization + typename QuantBlocking_, + /// Element type of C matrix + typename ElementC_, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC_, + /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) + typename Policy_, + /// Number of partitions along K dimension + int PartitionsK_ = 1, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. 
+ bool AccumulatorsInRowMajor = false, + /// Used for partial specialization + typename Enable = bool +> +class QuantBMmaTensorOp { +public: + /// Shape of warp-level matrix operation (concept: GemmShape) + using Shape = Shape_; + + /// Data type of multiplicand A + using ElementA = ElementA_; + + /// Layout of multiplicand A + using LayoutA = LayoutA_; + + /// Data type of multiplicand B + using ElementB = ElementB_; + + /// Layout of multiplicand B + using LayoutB = LayoutB_; + + /// Data type of accumulator matrix C + using ElementC = ElementC_; + + /// Layout of accumulator matrix C + using LayoutC = LayoutC_; + + /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) + using Policy = Policy_; + + /// Underlying matrix multiply operator (concept: arch::Mma) + using ArchMmaOperator = typename Policy::Operator; + + /// Indicates math operator + using MathOperator = typename ArchMmaOperator::Operator; + + /// Architecture tag from underlying instruction + using ArchTag = typename ArchMmaOperator::ArchTag; + + /// Indicates class of matrix operator + using OperatorClass = arch::OpClassTensorOp; + + /// Shape of underlying instruction + using InstructionShape = typename ArchMmaOperator::Shape; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = ComplexTransform::kNone; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = ComplexTransform::kNone; + + /// Number of threads participating in warp-level matrix product + static int const kThreadCount = 32; + + /// Number of partitions along K dimension + static int const kPartitionsK = PartitionsK_; + +public: + + /// Iterates over the A operand in memory + using IteratorA = MmaTensorOpMultiplicandTileIterator< + MatrixShape, Operand::kA, ElementA, LayoutA, + MatrixShape, + Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; + + /// Storage for A tile + using FragmentA = typename IteratorA::Fragment; + + /// Storage for transformed A tile + 
using TransformedFragmentA = + Array; + + /// Iterates over the B operand in memory + using IteratorB = MmaTensorOpMultiplicandTileIterator< + MatrixShape, Operand::kB, ElementB, LayoutB, + MatrixShape, + Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; + // warp B MatrixShape<64, 64>, + // layout B cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<16, 64>, + // instruction op shape cutlass::MatrixShape<16, 8>, + // kPartitionsK 1 + // FragmentB::kElements 32 + + /// Storage for B tile + using FragmentB = typename IteratorB::Fragment; // cutlass::Array + + /// Storage for transformed B tile + /// When loading weights, we packed 4 int4 weights into one 2-byte-element, when expanded + /// we multiply the number of elements by 4. + /// TODO: make sure ArchMmaOperator::ElementB same as dequantized ElementB + /// and change the transform function below to perform dequantization + using TransformedFragmentB = + Array; + + /// Iterates over the C operand in memory + using IteratorC = MmaTensorOpAccumulatorTileIterator< + MatrixShape, ElementC, LayoutC, + typename ArchMmaOperator::Shape, typename Policy::OpDelta>; + + /// Storage for C tile + using FragmentC = typename IteratorC::Fragment; + + using ElementQScale = ElementQScale_; + using SmemLayoutQScale = SmemLayoutQScale_; + using QuantBlocking = QuantBlocking_; + + using ElementQOffset = ElementQOffset_; + using SmemLayoutQOffset = SmemLayoutQOffset_; + + /// Iterates over the quantization parameters in memory + using WarpQScaleShape = MatrixShape<(Shape::kK / QuantBlocking::kRow), (Shape::kN / QuantBlocking::kColumn)>; + static_assert(Shape::kK % QuantBlocking::kRow == 0, "K must be multiple of QuantBlocking::kRow"); + static_assert(Shape::kN % QuantBlocking::kColumn == 0, "N must be multiple of QuantBlocking::kColumn"); + static_assert(WarpQScaleShape::kCount > 0, "QuantBlocking too big to fit in a warp block!"); + + // TODO This is an expanding iterator, it needs to replicate the quantization parameters + 
// to all threads in the warp. + using IteratorQMeta = QuantBMetaMmaTensorOpTileIterator< + MatrixShape, QuantBlocking, ElementQScale, SmemLayoutQScale, + ElementQOffset, SmemLayoutQOffset, + ArchMmaOperator, kThreadCount, kPartitionsK>; + + using FragmentQScale = typename IteratorQMeta::FragmentScale; + using FragmentQOffset = typename IteratorQMeta::FragmentOffset; + + /// Number of mma operations performed + using MmaIterations = MatrixShape< + (Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM, + (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN + >; + +public: + + /// Underlying matrix multiply operator (concept: arch::Mma) + ArchMmaOperator mma; + +public: + + // + // Methods + // + + /// Ctor + CUTLASS_DEVICE + QuantBMmaTensorOp() {} + + /// Performs a warp-level matrix multiply-accumulate operation + CUTLASS_DEVICE + void operator()( + FragmentC &D, + TransformedFragmentA const &A, + TransformedFragmentB const &B, + FragmentC const &C + ) const { + + using MmaOperandA = typename ArchMmaOperator::FragmentA; + using MmaOperandB = typename ArchMmaOperator::FragmentB; + using MmaOperandC = typename ArchMmaOperator::FragmentC; + + D = C; + + MmaOperandA const *ptr_A = reinterpret_cast(&A); + MmaOperandB const *ptr_B = reinterpret_cast(&B); + MmaOperandC *ptr_D = reinterpret_cast(&D); + + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) + // Serpentine visitation order maximizing reuse of Rb + // The visitation order is like + // _ + // | | | | + // | | | | + // |_| |_| + // + // Down Up Down Up + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + + int m_serpentine = ((n % 2) ? 
(MmaIterations::kRow - 1 - m) : m); + + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma( + ptr_D[n + m_serpentine * MmaIterations::kColumn], + ptr_A[m_serpentine], + ptr_B[n], + ptr_D[n + m_serpentine * MmaIterations::kColumn]); + } else { + mma( + ptr_D[m_serpentine + n * MmaIterations::kRow], + ptr_A[m_serpentine], + ptr_B[n], + ptr_D[m_serpentine + n * MmaIterations::kRow]); + } + } + } + #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + // Serpentine visitation order maximizing reuse of Ra + // The visitation order is like + // _________ + // _________| + // |_________ + // __________| + // + // Right Left Right Left + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + + int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n); + + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma( + ptr_D[n_serpentine + m * MmaIterations::kColumn], + ptr_A[m], + ptr_B[n_serpentine], + ptr_D[n_serpentine + m * MmaIterations::kColumn]); + } else { + mma(ptr_D[m + n_serpentine * MmaIterations::kRow], + ptr_A[m], + ptr_B[n_serpentine], + ptr_D[m + n_serpentine * MmaIterations::kRow]); + } + } + } + #else + assert(0); + #endif + } + + /// Transform the mma operands to the required types + CUTLASS_DEVICE + void transform(TransformedFragmentB &dst_B, + FragmentB const &B, + FragmentQScale const &scales, + FragmentQOffset const &offsets) const { + + Array const *ptr_B = + reinterpret_cast const *>(&B); + IteratorQMeta::dequant(scales, offsets, *ptr_B, dst_B); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + +//#include "cutlass/gemm/warp/mma_tensor_op_fast_f32.h" + 
+///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h index 047011e70bd4d..32e9cc98106d5 100644 --- a/onnxruntime/core/mlas/inc/mlas_qnbit.h +++ b/onnxruntime/core/mlas/inc/mlas_qnbit.h @@ -37,9 +37,7 @@ typedef enum { CompMostAccurate = CompUndef, CompLeastAccurate = CompInt8, -} MLAS_SQNBIT_COMPUTE_TYPE; - -using MLAS_SQNBIT_GEMM_COMPUTE_TYPE = MLAS_SQNBIT_COMPUTE_TYPE; // TODO consolidate these +} MLAS_SQNBIT_GEMM_COMPUTE_TYPE; /** * @brief Data parameters for float/n-bit quantized int GEMM routine. @@ -102,18 +100,12 @@ MlasSQNBitGemmBatch( /** * @brief Determines whether a float32/quantized n-bit int GEMM implementation is available on the current platform. * - * @param[in] M row size of matrix A and C - * @param[in] N column size of matrix B and C - * @param[in] K column size of matrix A and row size of matrix B * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) * @param[in] BlkLen number of quantized values per block * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) */ bool MLASCALL MlasIsSQNBitGemmAvailable( - size_t M, - size_t N, - size_t K, size_t BlkBitWidth, size_t BlkLen, MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType @@ -153,13 +145,15 @@ MlasSQNBitGemmBatchWorkspaceSize( * @param[in] K column size of matrix A and row size of matrix B * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) * @param[in] BlkLen number of quantized values per block + * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) */ size_t MLASCALL MlasSQNBitGemmPackQuantBDataSize( size_t N, size_t K, size_t BlkBitWidth, - size_t BlkLen + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType ); /** @@ -169,6 +163,7 @@ MlasSQNBitGemmPackQuantBDataSize( * @param[in] K column size of matrix A and row size of matrix B * @param[in] 
BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) * @param[in] BlkLen number of quantized values per block + * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) * @param[in] QuantBData quantized B data * @param[out] PackedQuantBData packed quantized B data * @param[in] ThreadPool optional thread pool to use @@ -179,6 +174,7 @@ MlasSQNBitGemmPackQuantBData( size_t K, size_t BlkBitWidth, size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, const void* QuantBData, void* PackedQuantBData, MLAS_THREADPOOL* ThreadPool = nullptr diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp index 0d8a5692359a6..38c31c8841761 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp @@ -39,23 +39,17 @@ enum SQNBitGemmVariant { SQNBitGemmVariant GetSQNBitGemmVariant( - size_t M, - size_t N, - size_t K, size_t BlkBitWidth, size_t BlkLen, MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - MLAS_UNREFERENCED_PARAMETER(N); - MLAS_UNREFERENCED_PARAMETER(K); - if (BlkBitWidth == 4 && (BlkLen == 16 || BlkLen == 32 || BlkLen == 64 || BlkLen == 128 || BlkLen == 256)) { if (ComputeType == CompFp32 || ComputeType == CompUndef) { // treat CompUndef (undefined) as CompFp32 return SQNBitGemmVariant_BitWidth4_CompFp32; - } else if (ComputeType == CompInt8 && M == 1) { + } else if (ComputeType == CompInt8) { return SQNBitGemmVariant_BitWidth4_CompInt8; } } @@ -67,9 +61,6 @@ GetSQNBitGemmVariant( bool MLASCALL MlasIsSQNBitGemmAvailable( - size_t M, - size_t N, - size_t K, size_t BlkBitWidth, size_t BlkLen, MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType @@ -80,7 +71,7 @@ MlasIsSQNBitGemmAvailable( return false; } - const auto Variant = GetSQNBitGemmVariant(M, N, K, BlkBitWidth, BlkLen, ComputeType); + const auto Variant = GetSQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType); switch (Variant) { case SQNBitGemmVariant_BitWidth4_CompFp32: { @@ -164,7 +155,7 @@ 
MlasSQNBitGemmBatchWorkspaceSize( MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - const auto Variant = GetSQNBitGemmVariant(M, N, K, BlkBitWidth, BlkLen, ComputeType); + const auto Variant = GetSQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType); const size_t PerGemmWorkspaceStride = SQNBitGemmPerGemmWorkspaceStride(Variant, M, N, K, BlkLen); if (PerGemmWorkspaceStride == 0) { @@ -178,91 +169,24 @@ MlasSQNBitGemmBatchWorkspaceSize( return WorkspaceSize + Alignment - 1; } -namespace -{ - -void -SQ4BitGemmPackQuantBData( - size_t N, - size_t K, - size_t BlkLen, - const std::byte* QuantBDataBegin, - std::byte* PackedQuantBDataBegin, - MLAS_THREADPOOL* ThreadPool -) -{ - constexpr size_t BlkBitWidth = 4; - - assert(BlkLen % 16 == 0); - - const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - const size_t BlkDataSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - const size_t Iterations = N * BlockCountK; // one iteration per block - - MlasTrySimpleParallel( - ThreadPool, Iterations, - [&](ptrdiff_t tid) { - const size_t n = tid / BlockCountK; - const size_t k_blk = tid % BlockCountK; - - const size_t data_offset = n * BlockCountK * BlkDataSize + k_blk * BlkDataSize; - const std::byte* QuantBData = QuantBDataBegin + data_offset; - std::byte* PackedQuantBData = PackedQuantBDataBegin + data_offset; - - // - // Pack 16 4-bit values (8 bytes) at a time like this: - // - // src: | v0 v1 | v2 v3 | v4 v5 | v6 v7 | v8 v9 | vA vB | vC vD | vE vF | - // => - // dst: | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF | - // - for (size_t kk = 0; kk < BlkLen; kk += 16) { - for (size_t byte_pair_idx = 0; byte_pair_idx < 4; ++byte_pair_idx) { - const std::byte src0 = QuantBData[byte_pair_idx]; - const std::byte src1 = QuantBData[byte_pair_idx + 4]; - - std::byte& dst0 = PackedQuantBData[2 * byte_pair_idx]; - std::byte& dst1 = PackedQuantBData[2 * byte_pair_idx + 1]; - - dst0 = (src0 & std::byte{0x0F}) | ((src1 & std::byte{0x0F}) << 4); - dst1 = (src0 >> 4) | 
((src1 >> 4) << 4); - } - - QuantBData += 8; - PackedQuantBData += 8; - } - } - ); -} - -} // namespace - size_t MLASCALL MlasSQNBitGemmPackQuantBDataSize( size_t N, size_t K, size_t BlkBitWidth, - size_t BlkLen + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - // Ensure that a general implementation is available on this platform. - // For now, all implementations share the same packed format. - { - // Currently, there are implementations specific to M = 1, so pick a more general M > 1. - constexpr size_t M = 2; - // A CompUndef implementation should be available if any is available. - constexpr MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType = CompUndef; - const bool HasGeneralImplementation = - MlasIsSQNBitGemmAvailable(M, N, K, BlkBitWidth, BlkLen, ComputeType); - if (!HasGeneralImplementation) { - return 0; - } + const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch; + if (Dispatch == nullptr) { + return 0; } - if (BlkBitWidth == 4) { - const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - return PackedQuantBDataSize; + if (BlkBitWidth == 4 && Dispatch->SQ4BitGemmPackQuantBDataSize != nullptr) { + return Dispatch->SQ4BitGemmPackQuantBDataSize( + N, K, BlkLen, ComputeType + ); } return 0; @@ -274,20 +198,28 @@ MlasSQNBitGemmPackQuantBData( size_t K, size_t BlkBitWidth, size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, const void* QuantBData, void* PackedQuantBData, MLAS_THREADPOOL* ThreadPool ) { - if (BlkBitWidth == 4) { - SQ4BitGemmPackQuantBData( + const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch; + if (Dispatch == nullptr) { + return; + } + + if (BlkBitWidth == 4 && Dispatch->SQ4BitGemmPackQuantBData != nullptr) { + Dispatch->SQ4BitGemmPackQuantBData( N, K, BlkLen, + ComputeType, static_cast(QuantBData), static_cast(PackedQuantBData), ThreadPool ); + return; } } @@ -512,7 +444,37 @@ SQ4BitGemm_CompInt8( return; } - 
assert(false && "not implemented for M > 1"); + // This is a naive M > 1 implementation that repeatedly calls the M=1 kernel, post-processing each output row right after it is computed. + // TODO Replace it with an optimized implementation. + size_t CountN; + for (size_t n = 0; n < RangeCountN; n += CountN) { + CountN = std::min(RangeCountN - n, size_t{128}); + + const std::byte* a_row = QuantA; + const std::byte* b_col = QuantBData + n * ldb; + const float* b_col_scale = QuantBScale + n * k_blks; + const std::byte* b_col_zp = + (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; + float* c_blk = C + n; + const float* bias = (Bias == nullptr) ? nullptr : Bias + n; + + for (size_t m = 0; m < RangeCountM; ++m) { + GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompInt8( + BlkLen, + a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias + ); + + if (DataParams->PostProcessor != nullptr) { + DataParams->PostProcessor->Process( + DataParams->C, RangeStartM + m, RangeStartN + n, + 1, CountN, ldc + ); + } + + c_blk += ldc; + a_row += lda; + } + } } typedef void(InitializeWorkspaceFn)( @@ -594,7 +556,7 @@ MlasSQNBitGemmBatch( MLAS_THREADPOOL* ThreadPool ) { - const auto Variant = GetSQNBitGemmVariant(M, N, K, BlkBitWidth, BlkLen, ComputeType); + const auto Variant = GetSQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType); assert(Variant != SQNBitGemmVariantInvalid); // diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.h b/onnxruntime/core/mlas/lib/sqnbitgemm.h index a66db79dc290a..3992bc3e452a3 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.h @@ -99,6 +99,33 @@ Q8BlkAlignment() // struct MLAS_SQNBIT_GEMM_DISPATCH { + // + // Quantized B data packing function prototypes. + // + + /** Gets size of packed quantized B data containing 4-bit integers. See MlasSQNBitGemmPackQuantBDataSize(). 
*/ + typedef size_t(SQ4BitGemmPackQuantBDataSize_Fn)( + size_t N, + size_t K, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + ); + + SQ4BitGemmPackQuantBDataSize_Fn* SQ4BitGemmPackQuantBDataSize = nullptr; + + /** Packs quantized B data containing 4-bit integers. See MlasSQNBitGemmPackQuantBData(). */ + typedef void(SQ4BitGemmPackQuantBData_Fn)( + size_t N, + size_t K, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool + ); + + SQ4BitGemmPackQuantBData_Fn* SQ4BitGemmPackQuantBData = nullptr; + // // CompFp32 kernel function prototypes. // diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp index 69fd427fa574a..9d7b0ae06e220 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp @@ -15,14 +15,115 @@ Module Name: --*/ -#include "sqnbitgemm.h" - #include #include #include #include +#include "sqnbitgemm.h" + +// +// Quantized B data packing function implementation. 
+// + +namespace +{ + +size_t +SQ4BitGemmPackQuantBDataSize( + size_t N, + size_t K, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType +) +{ + MLAS_UNREFERENCED_PARAMETER(ComputeType); // same size regardless of ComputeType + + constexpr size_t BlkBitWidth = 4; + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + return PackedQuantBDataSize; +} + +void +SQ4BitGemmPackQuantBData( + size_t N, + size_t K, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool +) +{ + constexpr size_t BlkBitWidth = 4; + + assert(BlkLen >= 16 && BlkLen % 16 == 0); + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + const size_t BlkDataSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t Iterations = N * BlockCountK; // one iteration per block + + const size_t SubBlkLen = (ComputeType == CompInt8) + ? ((BlkLen == 16) ? 16 : 32) + : 16; + + const size_t SubBlkDataSize = SubBlkLen / 2; + const size_t SubBlkBytePairCount = SubBlkLen / 4; + + // + // For SubBlkLen == 16, pack 16 4-bit values (8 bytes) at a time like this: + // + // src: | v0 v1 | v2 v3 | v4 v5 | v6 v7 | v8 v9 | vA vB | vC vD | vE vF | + // => + // dst: | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF | + // + + // + // For SubBlkLen == 32, pack 32 4-bit values (16 bytes) at a time like this: + // + // src: | v0 v1 | v2 v3 | ... | v28 v29 | v30 v31 | + // => + // dst: | v0 v16 | v1 v17 | ... 
| v14 v30 | v15 v31 | + // + + MlasTrySimpleParallel( + ThreadPool, Iterations, + [&](ptrdiff_t tid) { + const size_t n = tid / BlockCountK; + const size_t k_blk = tid % BlockCountK; + + const size_t data_offset = n * BlockCountK * BlkDataSize + k_blk * BlkDataSize; + const std::byte* QuantBData = QuantBDataBegin + data_offset; + std::byte* PackedQuantBData = PackedQuantBDataBegin + data_offset; + + for (size_t kk = 0; kk < BlkLen; kk += SubBlkLen) { + for (size_t byte_pair_idx = 0; byte_pair_idx < SubBlkBytePairCount; ++byte_pair_idx) { + const std::byte src0 = QuantBData[byte_pair_idx]; + const std::byte src1 = QuantBData[byte_pair_idx + SubBlkDataSize / 2]; + + std::byte& dst0 = PackedQuantBData[2 * byte_pair_idx]; + std::byte& dst1 = PackedQuantBData[2 * byte_pair_idx + 1]; + + dst0 = (src0 & std::byte{0x0F}) | ((src1 & std::byte{0x0F}) << 4); + dst1 = (src0 >> 4) | ((src1 >> 4) << 4); + } + + QuantBData += SubBlkDataSize; + PackedQuantBData += SubBlkDataSize; + } + } + ); +} + +} // namespace + +// +// General helpers. +// + namespace { @@ -95,7 +196,16 @@ LoadFloatData(const float* src, size_t count, float32x4_t (&dst)[Capacity / 4]) } } -template +} // namespace + +// +// CompFp32 kernel implementation. 
+// + +namespace +{ + +template MLAS_FORCEINLINE void ComputeDotProducts_BlkBitWidth4_CompFp32( size_t BlkLen, @@ -112,11 +222,11 @@ ComputeDotProducts_BlkBitWidth4_CompFp32( ) { constexpr size_t BlkBitWidth = 4; + constexpr size_t SubBlkLen = 16; static_assert(NCols == 1 || NCols == 4, "NCols must be 1 or 4"); - constexpr size_t SubBlkLen = 16; // number of block elements to process in a sub-block iteration - assert(BlkLen % SubBlkLen == 0); + assert(BlkLen >= SubBlkLen && BlkLen % SubBlkLen == 0); const uint8x8_t LowMask = vdup_n_u8(0x0F); @@ -137,7 +247,8 @@ ComputeDotProducts_BlkBitWidth4_CompFp32( const std::byte* QuantBData = QuantBDataColPtr; const float* QuantBScale = QuantBScaleColPtr; - size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + // only used if HasZeroPoint == true for (size_t k = 0; k < CountK; k += BlkLen) { const size_t k_blk_len = std::min(CountK - k, BlkLen); @@ -147,8 +258,9 @@ ComputeDotProducts_BlkBitWidth4_CompFp32( [&](size_t i) { scale[i] = QuantBScale[i * StrideQuantBScale]; } ); - float offset[NCols]; // Includes zero point and float conversion offset of 16. - if (QuantBZeroPointColPtr != nullptr) { + [[maybe_unused]] float offset[NCols]; // Includes zero point and float conversion offset of 16. 
+ // only used if HasZeroPoint == true + if constexpr (HasZeroPoint) { UnrolledLoop([&](size_t i) { const std::byte zp_packed = QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; @@ -157,11 +269,6 @@ ComputeDotProducts_BlkBitWidth4_CompFp32( : (zp_packed & std::byte{0x0F}); offset[i] = 16.0f + std::to_integer(zp); }); - } else { - UnrolledLoop([&](size_t i) { - constexpr float zp = 8.0f; - offset[i] = 16.0f + zp; - }); } for (size_t k_idx_in_blk = 0; k_idx_in_blk < k_blk_len; k_idx_in_blk += SubBlkLen) { @@ -187,8 +294,6 @@ ComputeDotProducts_BlkBitWidth4_CompFp32( bv_u8[i][1] = vshr_n_u8(bv_packed[i], 4); }); - // dequantize B - // shift left 3 and widen to 16 bits uint16x8_t bv_u16[NCols][2]; UnrolledLoop([&](size_t i) { @@ -217,10 +322,17 @@ ComputeDotProducts_BlkBitWidth4_CompFp32( }); // subtract float conversion offset (16) and zero point - UnrolledLoop([&](size_t i) { - const float32x4_t offset_v = vdupq_n_f32(offset[i]); - UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); - }); + if constexpr (HasZeroPoint) { + UnrolledLoop([&](size_t i) { + const float32x4_t offset_v = vdupq_n_f32(offset[i]); + UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); + }); + } else { + const float32x4_t offset_v = vdupq_n_f32(16.0f + 8.0f); + UnrolledLoop([&](size_t i) { + UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); + }); + } // multiply by scale UnrolledLoop([&](size_t i) { @@ -237,7 +349,9 @@ ComputeDotProducts_BlkBitWidth4_CompFp32( // increment pointers to next block QuantBData += MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); QuantBScale += 1; - QuantBZeroPointIdx += 1; + if constexpr (HasZeroPoint) { + QuantBZeroPointIdx += 1; + } } if constexpr (NCols == 4) { @@ -258,8 +372,9 @@ ComputeDotProducts_BlkBitWidth4_CompFp32( } } -MLAS_FORCEINLINE void -SQ4BitGemmM1Kernel_CompFp32( +template +void +SQ4BitGemmM1Kernel_CompFp32_Impl( size_t BlkLen, const float* A, 
const std::byte* QuantBData, @@ -295,7 +410,7 @@ SQ4BitGemmM1Kernel_CompFp32( int64_t nblk = static_cast(CountN) - NCols; while (nblk >= 0) { - ComputeDotProducts_BlkBitWidth4_CompFp32( + ComputeDotProducts_BlkBitWidth4_CompFp32( BlkLen, ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, @@ -306,7 +421,7 @@ SQ4BitGemmM1Kernel_CompFp32( QuantBDataColPtr += NCols * StrideQuantBData; QuantBScaleColPtr += NCols * StrideQuantBScale; - if (QuantBZeroPointColPtr != nullptr) { + if constexpr (HasZeroPoint) { QuantBZeroPointColPtr += NCols * StrideQuantBZeroPoint; } @@ -319,7 +434,7 @@ SQ4BitGemmM1Kernel_CompFp32( // left over columns less than `NCols`? nblk += NCols; for (int64_t n = 0; n < nblk; ++n) { - ComputeDotProducts_BlkBitWidth4_CompFp32<1>( + ComputeDotProducts_BlkBitWidth4_CompFp32<1, HasZeroPoint>( BlkLen, ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, @@ -330,7 +445,7 @@ SQ4BitGemmM1Kernel_CompFp32( QuantBDataColPtr += StrideQuantBData; QuantBScaleColPtr += StrideQuantBScale; - if (QuantBZeroPointColPtr != nullptr) { + if constexpr (HasZeroPoint) { QuantBZeroPointColPtr += StrideQuantBZeroPoint; } @@ -339,6 +454,49 @@ SQ4BitGemmM1Kernel_CompFp32( } } +MLAS_FORCEINLINE void +SQ4BitGemmM1Kernel_CompFp32( + size_t BlkLen, + const float* A, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias +) +{ + if (QuantBZeroPoint != nullptr) { + SQ4BitGemmM1Kernel_CompFp32_Impl( + BlkLen, + A, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockStrideQuantB, + Bias + ); + } else { + SQ4BitGemmM1Kernel_CompFp32_Impl( + BlkLen, + A, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockStrideQuantB, + 
Bias + ); + } +} + MLAS_FORCEINLINE void Q4BitBlkDequantBForSgemm_CompFp32( size_t BlkLen, @@ -353,6 +511,7 @@ Q4BitBlkDequantBForSgemm_CompFp32( { auto impl0_reference = [&]() { constexpr size_t BlkBitWidth = 4; + constexpr size_t SubBlkLen = 16; float* Dst = FpData; @@ -378,11 +537,11 @@ Q4BitBlkDequantBForSgemm_CompFp32( : 8; for (size_t kk = 0; kk < kklen; ++kk) { - const size_t packed_idx = kk % 16; + const size_t packed_idx = kk % SubBlkLen; - const bool is_low_half = packed_idx < 8; - const size_t packed_byte_idx = packed_idx % 8; - const size_t packed_range_offset = (kk / 16) * 8; + const bool is_low_half = packed_idx < (SubBlkLen / 2); + const size_t packed_byte_idx = packed_idx % (SubBlkLen / 2); + const size_t packed_range_offset = (kk / SubBlkLen) * (SubBlkLen / 2); const std::byte b_packed = b_data[packed_range_offset + packed_byte_idx]; const std::byte b_byte = is_low_half ? (b_packed & std::byte{0x0F}) : (b_packed >> 4); @@ -415,7 +574,7 @@ Q4BitBlkDequantBForSgemm_CompFp32( } // -// CompInt8 kernel implementation and related helpers +// CompInt8 kernel implementation. // template @@ -431,8 +590,6 @@ QuantizeBlock( assert(BlkLen % SubBlkLen == 0); - constexpr size_t VectorCount = SubBlkLen / 4; - // // Scan block values first to determine scale. 
// @@ -443,16 +600,16 @@ QuantizeBlock( for (k = 0; k < ElementCount; k += SubBlkLen) { const size_t SubBlkElementCount = std::min(ElementCount - k, SubBlkLen); - float32x4_t a[VectorCount]{}; + float32x4_t a[SubBlkLen / 4]{}; LoadFloatData(A + k, SubBlkElementCount, a); - float32x4_t abs_a[VectorCount]; - UnrolledLoop([&](size_t i) { + float32x4_t abs_a[SubBlkLen / 4]; + UnrolledLoop([&](size_t i) { abs_a[i] = vabsq_f32(a[i]); }); // find amax of SubBlkLen elements - for (size_t interval = VectorCount / 2; interval > 0; interval /= 2) { + for (size_t interval = SubBlkLen / 4 / 2; interval > 0; interval /= 2) { for (size_t i = 0; i < interval; ++i) { abs_a[i] = vmaxq_f32(abs_a[i], abs_a[i + interval]); } @@ -477,19 +634,19 @@ QuantizeBlock( for (k = 0; k < ElementCount; k += SubBlkLen) { const size_t SubBlkElementCount = std::min(ElementCount - k, SubBlkLen); - float32x4_t a[VectorCount]{}; + float32x4_t a[SubBlkLen / 4]{}; LoadFloatData(A + k, SubBlkElementCount, a); - UnrolledLoop([&](size_t i) { + UnrolledLoop([&](size_t i) { a[i] = vmulq_n_f32(a[i], scale_reciprocal); }); - int32x4_t a_s32[VectorCount]; - UnrolledLoop([&](size_t i) { + int32x4_t a_s32[SubBlkLen / 4]; + UnrolledLoop([&](size_t i) { a_s32[i] = vcvtaq_s32_f32(a[i]); }); - UnrolledLoop([&](size_t i) { + UnrolledLoop([&](size_t i) { QuantAData[k + i * 4 + 0] = static_cast(vgetq_lane_s32(a_s32[i], 0)); QuantAData[k + i * 4 + 1] = static_cast(vgetq_lane_s32(a_s32[i], 1)); QuantAData[k + i * 4 + 2] = static_cast(vgetq_lane_s32(a_s32[i], 2)); @@ -530,136 +687,314 @@ QuantizeARow_CompInt8( } } -template -MLAS_FORCEINLINE void -ComputeDotProducts_BlkBitWidth4_CompInt8( - size_t BlkLen, - const std::byte* QuantARowPtr, - const std::byte* QuantBDataColPtr, - const float* QuantBScaleColPtr, - const std::byte* QuantBZeroPointColPtr, - float* SumPtr, - size_t CountK, - size_t StrideQuantBData, - size_t StrideQuantBScale, - size_t StrideQuantBZeroPoint, - const float* BiasPtr +template +void 
+SQ4BitGemmM1Kernel_CompInt8_Impl_BlkLen16( + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t BlockCountK, + const float* Bias ) { - static_assert(NCols == 1 || NCols == 4, "NCols must be 1 or 4"); - constexpr size_t BlkBitWidth = 4; + constexpr size_t BlkLen = 16; - constexpr size_t SubBlkLen = 16; // number of block elements to process in a sub-block iteration - assert(BlkLen % SubBlkLen == 0); + float* CRowPtr = C; - const uint8x8_t LowMask = vdup_n_u8(0x0F); + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); - const std::byte* QuantA = QuantARowPtr; + const float* BiasPtr = Bias; - const std::byte* QuantBData = QuantBDataColPtr; - const float* QuantBScale = QuantBScaleColPtr; - size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; - float32x4_t acc[NCols]{}; + float* SumPtr = CRowPtr; - for (size_t k = 0; k < CountK; k += BlkLen) { - const size_t k_blk_len = std::min(CountK - k, BlkLen); + const uint8x16_t LowMaskU8x16 = vdupq_n_u8(0x0F); + const uint8x8_t LowMaskU8x8 = vdup_n_u8(0x0F); - const float a_scale = Q8BlkScale(QuantA); - const int8_t* a_data = Q8BlkData(QuantA); + for (size_t n = 0; n < CountN; ++n) { + const std::byte* QuantAPtr = QuantA; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; - float b_scale[NCols]; - UnrolledLoop([&](size_t i) { b_scale[i] = QuantBScale[i * StrideQuantBScale]; }); + float32x4_t acc0{}, acc1{}; - int8_t b_zp[NCols]; - 
if (QuantBZeroPointColPtr != nullptr) { - UnrolledLoop([&](size_t i) { - const std::byte zp_packed = - QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; - b_zp[i] = ((QuantBZeroPointIdx & 1) == 1) - ? std::to_integer(zp_packed >> 4) - : std::to_integer(zp_packed & std::byte{0x0F}); - }); - } else { - UnrolledLoop([&](size_t i) { - b_zp[i] = 8; - }); + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= 2) { + const std::byte* QuantABlk0 = QuantAPtr; + const std::byte* QuantABlk1 = QuantABlk0 + Q8BlkSize(BlkLen); + + // compute combined scale + const float32x4_t scale0 = vdupq_n_f32(Q8BlkScale(QuantABlk0) * QuantBScalePtr[0]); + const float32x4_t scale1 = vdupq_n_f32(Q8BlkScale(QuantABlk1) * QuantBScalePtr[1]); + + // load B zero point + const int8x16_t bzp0 = vdupq_n_s8( + HasZeroPoint ? std::to_integer(QuantBZeroPointPtr[0] & std::byte{0x0F}) : 8 + ); + const int8x16_t bzp1 = vdupq_n_s8( + HasZeroPoint ? std::to_integer(QuantBZeroPointPtr[0] >> 4) : 8 + ); + + // load A + const int8x16_t av0 = vld1q_s8(Q8BlkData(QuantABlk0)); + const int8x16_t av1 = vld1q_s8(Q8BlkData(QuantABlk1)); + + // load B + const uint8x16_t bv_packed01 = vld1q_u8(reinterpret_cast(QuantBDataPtr)); + + const uint8x16_t bv_lo01 = vandq_u8(bv_packed01, LowMaskU8x16); + const uint8x16_t bv_hi01 = vshrq_n_u8(bv_packed01, 4); + + int8x16_t bv0 = vreinterpretq_s8_u8(vcombine_u8(vget_low_u8(bv_lo01), vget_low_u8(bv_hi01))); + int8x16_t bv1 = vreinterpretq_s8_u8(vcombine_u8(vget_high_u8(bv_lo01), vget_high_u8(bv_hi01))); + + // subtract B zero point + bv0 = vsubq_s8(bv0, bzp0); + bv1 = vsubq_s8(bv1, bzp1); + + // quantized dot product + const int32x4_t dot0 = vdotq_s32(vdupq_n_s32(0), av0, bv0); + const int32x4_t dot1 = vdotq_s32(vdupq_n_s32(0), av1, bv1); + + // convert to float + const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); + const float32x4_t dot_f32_1 = vcvtq_f32_s32(dot1); + + // multiply by scale and update accumulator + acc0 
= vfmaq_f32(acc0, dot_f32_0, scale0); + acc1 = vfmaq_f32(acc1, dot_f32_1, scale1); + + // increment block pointers + + QuantAPtr += Q8BlkSize(BlkLen) * 2; + QuantBDataPtr += 8 * 2; + QuantBScalePtr += 2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } } - for (size_t k_idx_in_blk = 0; k_idx_in_blk < k_blk_len; k_idx_in_blk += SubBlkLen) { - // load A row vector - int8x16_t av = vld1q_s8(a_data + k_idx_in_blk); + if (k_blks_remaining > 0) { + const std::byte* QuantABlk0 = QuantAPtr; - // load B column vectors - uint8x8_t bv_packed[NCols]; - const size_t b_data_block_offset = k_idx_in_blk * BlkBitWidth / 8; - UnrolledLoop([&](size_t i) { - bv_packed[i] = vld1_u8( - reinterpret_cast(QuantBData) + i * StrideQuantBData + b_data_block_offset - ); - }); + // compute combined scale + const float32x4_t scale0 = vdupq_n_f32(Q8BlkScale(QuantABlk0) * (*QuantBScalePtr)); - int8x16_t bv[NCols]; - UnrolledLoop([&](size_t i) { - const int8x8_t lo = vreinterpret_s8_u8(vand_u8(bv_packed[i], LowMask)); - const int8x8_t hi = vreinterpret_s8_u8(vshr_n_u8(bv_packed[i], 4)); - bv[i] = vcombine_s8(lo, hi); - }); + // load B zero point + const int8x16_t bzp0 = vdupq_n_s8( + HasZeroPoint ? 
std::to_integer(QuantBZeroPointPtr[0] & std::byte{0x0F}) : 8 + ); + + // load A + const int8x16_t av0 = vld1q_s8(Q8BlkData(QuantABlk0)); + + // load B + const uint8x8_t bv_packed0 = vld1_u8(reinterpret_cast(QuantBDataPtr)); + + const uint8x8_t bv_lo0 = vand_u8(bv_packed0, LowMaskU8x8); + const uint8x8_t bv_hi0 = vshr_n_u8(bv_packed0, 4); + + int8x16_t bv0 = vreinterpretq_s8_u8(vcombine_u8(bv_lo0, bv_hi0)); // subtract B zero point - UnrolledLoop([&](size_t i) { - const int8x16_t zp_v = vdupq_n_s8(b_zp[i]); - bv[i] = vsubq_s8(bv[i], zp_v); - }); + bv0 = vsubq_s8(bv0, bzp0); - // compute quantized dot product - int32x4_t dot[NCols]{}; - UnrolledLoop([&](size_t i) { - dot[i] = vdotq_s32(dot[i], av, bv[i]); - }); + // quantized dot product + const int32x4_t dot0 = vdotq_s32(vdupq_n_s32(0), av0, bv0); - // convert dot product result to float - float32x4_t dot_f32[NCols]; - UnrolledLoop([&](size_t i) { - dot_f32[i] = vcvtq_f32_s32(dot[i]); - }); + // convert to float + const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); - // multiply dot product result by scale and update accumulator - UnrolledLoop([&](size_t i) { - const float32x4_t scale_v = vdupq_n_f32(a_scale * b_scale[i]); - acc[i] = vfmaq_f32(acc[i], dot_f32[i], scale_v); - }); + // multiply by scale and update accumulator + acc0 = vfmaq_f32(acc0, dot_f32_0, scale0); } - // increment pointers to next block - QuantA += Q8BlkSize(BlkLen); - QuantBData += MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - QuantBScale += 1; - QuantBZeroPointIdx += 1; - } + *SumPtr = vaddvq_f32(acc0) + vaddvq_f32(acc1); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } - if constexpr (NCols == 4) { - float32x4_t sum = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + // move to next column - if (BiasPtr != nullptr) { - sum = vaddq_f32(sum, vld1q_f32(BiasPtr)); + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; } - 
vst1q_f32(SumPtr, sum); - } else { - for (size_t i = 0; i < NCols; ++i) { - SumPtr[i] = vaddvq_f32(acc[i]); - if (BiasPtr != nullptr) { - SumPtr[i] += BiasPtr[i]; + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } +} + +template +void +SQ4BitGemmM1Kernel_CompInt8_Impl_BlkLen32( + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t BlockCountK, + const float* Bias +) +{ + constexpr size_t BlkBitWidth = 4; + constexpr size_t BlkLen = 32; + + float* CRowPtr = C; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + const float* BiasPtr = Bias; + + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + + float* SumPtr = CRowPtr; + + const uint8x16_t LowMaskU8x16 = vdupq_n_u8(0x0F); + + for (size_t n = 0; n < CountN; ++n) { + const std::byte* QuantAPtr = QuantA; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + float32x4_t acc0{}, acc1{}; + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= 2) { + const std::byte* QuantABlk0 = QuantAPtr; + const std::byte* QuantABlk1 = QuantABlk0 + Q8BlkSize(BlkLen); + + // compute combined scale + const float32x4_t scale0 = vdupq_n_f32(Q8BlkScale(QuantABlk0) * QuantBScalePtr[0]); + const float32x4_t scale1 = vdupq_n_f32(Q8BlkScale(QuantABlk1) * QuantBScalePtr[1]); + + // load B zero point + const int8x16_t bzp0 = vdupq_n_s8( + HasZeroPoint ? std::to_integer((*QuantBZeroPointPtr) & std::byte{0x0F}) : 8 + ); + const int8x16_t bzp1 = vdupq_n_s8( + HasZeroPoint ? 
std::to_integer((*QuantBZeroPointPtr) >> 4) : 8 + ); + + // load A + const int8x16_t av_lo0 = vld1q_s8(Q8BlkData(QuantABlk0)); + const int8x16_t av_hi0 = vld1q_s8(Q8BlkData(QuantABlk0) + 16); + const int8x16_t av_lo1 = vld1q_s8(Q8BlkData(QuantABlk1)); + const int8x16_t av_hi1 = vld1q_s8(Q8BlkData(QuantABlk1) + 16); + + // load B + const uint8x16_t bv_packed0 = vld1q_u8(reinterpret_cast(QuantBDataPtr)); + const uint8x16_t bv_packed1 = vld1q_u8(reinterpret_cast(QuantBDataPtr) + 16); + + int8x16_t bv_lo0 = vreinterpretq_s8_u8(vandq_u8(bv_packed0, LowMaskU8x16)); + int8x16_t bv_hi0 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed0, 4)); + int8x16_t bv_lo1 = vreinterpretq_s8_u8(vandq_u8(bv_packed1, LowMaskU8x16)); + int8x16_t bv_hi1 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed1, 4)); + + // subtract B zero point + bv_lo0 = vsubq_s8(bv_lo0, bzp0); + bv_hi0 = vsubq_s8(bv_hi0, bzp0); + bv_lo1 = vsubq_s8(bv_lo1, bzp1); + bv_hi1 = vsubq_s8(bv_hi1, bzp1); + + // quantized dot product + int32x4_t dot0{}, dot1{}; + dot0 = vdotq_s32(vdotq_s32(dot0, av_lo0, bv_lo0), av_hi0, bv_hi0); + dot1 = vdotq_s32(vdotq_s32(dot1, av_lo1, bv_lo1), av_hi1, bv_hi1); + + // convert to float + const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); + const float32x4_t dot_f32_1 = vcvtq_f32_s32(dot1); + + // multiply by scale and update accumulator + acc0 = vfmaq_f32(acc0, dot_f32_0, scale0); + acc1 = vfmaq_f32(acc1, dot_f32_1, scale1); + + // increment block pointers + + QuantAPtr += Q8BlkSize(BlkLen) * 2; + QuantBDataPtr += 16 * 2; + QuantBScalePtr += 2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; } } + + if (k_blks_remaining > 0) { + const std::byte* QuantABlk0 = QuantAPtr; + + // compute combined scale + const float32x4_t scale0 = vdupq_n_f32(Q8BlkScale(QuantABlk0) * (*QuantBScalePtr)); + + // load B zero point + const int8x16_t bzp0 = vdupq_n_s8( + HasZeroPoint ? 
std::to_integer((*QuantBZeroPointPtr) & std::byte{0x0F}) : 8 // zero point of the current column/block, not the matrix base pointer + ); + + // load A + const int8x16_t av_lo0 = vld1q_s8(Q8BlkData(QuantABlk0)); + const int8x16_t av_hi0 = vld1q_s8(Q8BlkData(QuantABlk0) + 16); + + // load B + const uint8x16_t bv_packed0 = vld1q_u8(reinterpret_cast(QuantBDataPtr)); + + int8x16_t bv_lo0 = vreinterpretq_s8_u8(vandq_u8(bv_packed0, LowMaskU8x16)); + int8x16_t bv_hi0 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed0, 4)); + + // subtract B zero point + bv_lo0 = vsubq_s8(bv_lo0, bzp0); + bv_hi0 = vsubq_s8(bv_hi0, bzp0); + + // quantized dot product + int32x4_t dot0{}; + dot0 = vdotq_s32(vdotq_s32(dot0, av_lo0, bv_lo0), av_hi0, bv_hi0); + + // convert to float + const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); + + // multiply by scale and update accumulator + acc0 = vfmaq_f32(acc0, dot_f32_0, scale0); + } + + *SumPtr = vaddvq_f32(acc0) + vaddvq_f32(acc1); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; } } -MLAS_FORCEINLINE +template void -SQ4BitGemmM1Kernel_CompInt8( +SQ4BitGemmM1Kernel_CompInt8_Impl_BlkLenGreaterThan32( size_t BlkLen, const std::byte* QuantA, const std::byte* QuantBData, @@ -667,18 +1002,16 @@ SQ4BitGemmM1Kernel_CompInt8( const std::byte* QuantBZeroPoint, float* C, size_t CountN, - size_t CountK, - size_t BlockStrideQuantB, + size_t BlockCountK, const float* Bias ) { constexpr size_t BlkBitWidth = 4; - constexpr size_t NCols = 4; - const std::byte* QuantARowPtr = QuantA; - float* CRowPtr = C; + assert(BlkLen > 32); + assert(BlkLen % 32 == 0); - const size_t BlockCountK = BlockStrideQuantB; + float* CRowPtr = C; const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); const size_t StrideQuantBScale = BlockCountK; @@ -692,45 +1025,97 @@ SQ4BitGemmM1Kernel_CompInt8( float* SumPtr = CRowPtr; - int64_t nblk = static_cast(CountN) - NCols; + const uint8x16_t LowMaskU8x16 = vdupq_n_u8(0x0F); - while (nblk >= 0) { - ComputeDotProducts_BlkBitWidth4_CompInt8( - BlkLen, - QuantARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, - StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, - BiasPtr - ); + // process blocks in 32-element sub-blocks + const size_t SubBlksPerBlk = BlkLen / 32; - // move to next `NCols` columns + for (size_t n = 0; n < CountN; ++n) { + const std::byte* QuantAPtr = QuantA; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; - QuantBDataColPtr += NCols * StrideQuantBData; - QuantBScaleColPtr += NCols * StrideQuantBScale; - if (QuantBZeroPointColPtr != nullptr) { - QuantBZeroPointColPtr += NCols * StrideQuantBZeroPoint; - } + float32x4_t acc0{}, acc1{}; - BiasPtr += BiasPtr != nullptr ? 
NCols : 0; - SumPtr += NCols; + for (size_t k_blk_idx = 0; k_blk_idx < BlockCountK; ++k_blk_idx) { + // compute combined scale + const float32x4_t scale = vdupq_n_f32(Q8BlkScale(QuantAPtr) * (*QuantBScalePtr)); - nblk -= NCols; - } + // load B zero point + const int8x16_t bzp = [&]() -> int8x16_t { + if constexpr (HasZeroPoint) { + return vdupq_n_s8( + ((k_blk_idx & 1) == 0) ? std::to_integer((*QuantBZeroPointPtr) & std::byte{0x0F}) + : std::to_integer((*QuantBZeroPointPtr) >> 4) + ); + } else { + return vdupq_n_s8(8); + } + }(); + + const int8_t* QuantADataPtr = Q8BlkData(QuantAPtr); + + for (size_t sub_blk_idx = 0; sub_blk_idx < SubBlksPerBlk; sub_blk_idx += 2) { + // load A + const int8x16_t av0 = vld1q_s8(QuantADataPtr + 0); + const int8x16_t av1 = vld1q_s8(QuantADataPtr + 16); + const int8x16_t av2 = vld1q_s8(QuantADataPtr + 32); + const int8x16_t av3 = vld1q_s8(QuantADataPtr + 48); + + // load B + const uint8x16_t bv_packed0 = vld1q_u8(reinterpret_cast(QuantBDataPtr)); + const uint8x16_t bv_packed1 = vld1q_u8(reinterpret_cast(QuantBDataPtr) + 16); + + int8x16_t bv0 = vreinterpretq_s8_u8(vandq_u8(bv_packed0, LowMaskU8x16)); + int8x16_t bv1 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed0, 4)); + int8x16_t bv2 = vreinterpretq_s8_u8(vandq_u8(bv_packed1, LowMaskU8x16)); + int8x16_t bv3 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed1, 4)); + + // subtract B zero point + bv0 = vsubq_s8(bv0, bzp); + bv1 = vsubq_s8(bv1, bzp); + bv2 = vsubq_s8(bv2, bzp); + bv3 = vsubq_s8(bv3, bzp); + + // quantized dot product + int32x4_t dot0{}, dot1{}; + dot0 = vdotq_s32(vdotq_s32(dot0, av0, bv0), av1, bv1); + dot1 = vdotq_s32(vdotq_s32(dot1, av2, bv2), av3, bv3); + + // convert to float + const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); + const float32x4_t dot_f32_1 = vcvtq_f32_s32(dot1); + + // multiply by scale and update accumulator + acc0 = vfmaq_f32(acc0, dot_f32_0, scale); + acc1 = vfmaq_f32(acc1, dot_f32_1, scale); + + // increment block data pointers to next sub-block + 
QuantADataPtr += 16 * 4; + QuantBDataPtr += 16 * 2; + } - // left over columns less than `NCols`? - nblk += NCols; - for (int64_t n = 0; n < nblk; ++n) { - ComputeDotProducts_BlkBitWidth4_CompInt8<1>( - BlkLen, - QuantARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, - StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, - BiasPtr - ); + // increment other block pointers + + QuantAPtr += Q8BlkSize(BlkLen); + QuantBScalePtr += 1; + + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += ((k_blk_idx & 1) == 0) ? 0 : 1; + } + } + + *SumPtr = vaddvq_f32(acc0) + vaddvq_f32(acc1); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } // move to next column QuantBDataColPtr += StrideQuantBData; QuantBScaleColPtr += StrideQuantBScale; - if (QuantBZeroPointColPtr != nullptr) { + if constexpr (HasZeroPoint) { QuantBZeroPointColPtr += StrideQuantBZeroPoint; } @@ -739,6 +1124,99 @@ SQ4BitGemmM1Kernel_CompInt8( } } +template +MLAS_FORCEINLINE void +SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t BlockStrideQuantB, + const float* Bias +) +{ + if (BlkLen == 16) { + SQ4BitGemmM1Kernel_CompInt8_Impl_BlkLen16( + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + BlockStrideQuantB, + Bias + ); + } else if (BlkLen == 32) { + SQ4BitGemmM1Kernel_CompInt8_Impl_BlkLen32( + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + BlockStrideQuantB, + Bias + ); + } else { + SQ4BitGemmM1Kernel_CompInt8_Impl_BlkLenGreaterThan32( + BlkLen, + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + BlockStrideQuantB, + Bias + ); + } +} + +MLAS_FORCEINLINE +void +SQ4BitGemmM1Kernel_CompInt8( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, 
+ size_t CountN, + size_t /*CountK*/, + size_t BlockStrideQuantB, + const float* Bias +) +{ + if (QuantBZeroPoint != nullptr) { + SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen( + BlkLen, + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + BlockStrideQuantB, + Bias + ); + } else { + SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen( + BlkLen, + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + BlockStrideQuantB, + Bias + ); + } +} + } // namespace // @@ -748,8 +1226,12 @@ SQ4BitGemmM1Kernel_CompInt8( const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon = []() { MLAS_SQNBIT_GEMM_DISPATCH d; + d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; + d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; + d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32; d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32; + d.SQ4BitGemmM1Kernel_CompInt8 = SQ4BitGemmM1Kernel_CompInt8; d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8; diff --git a/onnxruntime/core/optimizer/common_subexpression_elimination.cc b/onnxruntime/core/optimizer/common_subexpression_elimination.cc index b2e7ef0b4f558..48df511d0c672 100644 --- a/onnxruntime/core/optimizer/common_subexpression_elimination.cc +++ b/onnxruntime/core/optimizer/common_subexpression_elimination.cc @@ -4,6 +4,7 @@ #include "common_subexpression_elimination.h" #include "core/optimizer/utils.h" #include "core/graph/graph_utils.h" +#include "core/framework/tensorprotoutils.h" #include #include @@ -170,6 +171,32 @@ bool AreRangesEqual(const Range& lhs, const Range& rhs) { std::equal(lhs.begin(), lhs.end(), rhs.begin()); } +// Check if two tensor attributes are equal scalar tensors, mainly to support ConstantOfShape Op. +// Currently support float, float16 and int64 data types, and requires the data are raw data in TensorProto. 
+bool AreScalarTensorAttributeEqual(const ONNX_NAMESPACE::TensorProto& lhs_t, const ONNX_NAMESPACE::TensorProto& rhs_t) { + if (!(utils::HasDataType(lhs_t) && utils::HasDataType(rhs_t) && lhs_t.data_type() == rhs_t.data_type() && + (lhs_t.data_type() == onnx::TensorProto_DataType_FLOAT || + lhs_t.data_type() == onnx::TensorProto_DataType_FLOAT16 || + lhs_t.data_type() == onnx::TensorProto_DataType_INT64) && + lhs_t.dims_size() == 1 && rhs_t.dims_size() == 1 && lhs_t.dims()[0] == 1 && rhs_t.dims()[0] == 1 && + utils::HasRawData(lhs_t) && utils::HasRawData(rhs_t))) { + return false; + } + const void* lhs_value = lhs_t.raw_data().data(); + const void* rhs_value = rhs_t.raw_data().data(); + switch (lhs_t.data_type()) { + case onnx::TensorProto_DataType_FLOAT: + return *reinterpret_cast(lhs_value) == *reinterpret_cast(rhs_value); + case onnx::TensorProto_DataType_FLOAT16: + return *reinterpret_cast(lhs_value) == *reinterpret_cast(rhs_value); + case onnx::TensorProto_DataType_INT64: + return *reinterpret_cast(lhs_value) == *reinterpret_cast(rhs_value); + default: + break; + } + return false; +} + bool AreEqual(const ONNX_NAMESPACE::AttributeProto& lhs, const ONNX_NAMESPACE::AttributeProto& rhs) { if (&lhs == &rhs) { return true; @@ -193,6 +220,7 @@ bool AreEqual(const ONNX_NAMESPACE::AttributeProto& lhs, const ONNX_NAMESPACE::A case onnx::AttributeProto_AttributeType_STRINGS: return AreRangesEqual(lhs.strings(), rhs.strings()); case onnx::AttributeProto_AttributeType_TENSOR: + return AreScalarTensorAttributeEqual(lhs.t(), rhs.t()); case onnx::AttributeProto_AttributeType_GRAPH: case onnx::AttributeProto_AttributeType_SPARSE_TENSOR: case onnx::AttributeProto_AttributeType_TYPE_PROTO: @@ -207,6 +235,31 @@ bool AreEqual(const ONNX_NAMESPACE::AttributeProto& lhs, const ONNX_NAMESPACE::A return false; } +// Support scalar float/int64/fp16 tensor attribute only for now, and requires data is raw data in TensorProto. 
+std::size_t GetTensorAttributeHash(const ONNX_NAMESPACE::TensorProto& attr_t) { + std::size_t hash = 0; + if (utils::HasDataType(attr_t) && attr_t.dims_size() == 1 && attr_t.dims()[0] == 1 && utils::HasRawData(attr_t)) { + int data_type = attr_t.data_type(); + switch (data_type) { + case onnx::TensorProto_DataType_FLOAT: + UpdateHash(data_type, hash); + UpdateHash(*reinterpret_cast(attr_t.raw_data().data()), hash); + break; + case onnx::TensorProto_DataType_FLOAT16: + UpdateHash(data_type, hash); + UpdateHash(static_cast(*reinterpret_cast(attr_t.raw_data().data())), hash); + break; + case onnx::TensorProto_DataType_INT64: + UpdateHash(data_type, hash); + UpdateHash(*reinterpret_cast(attr_t.raw_data().data()), hash); + break; + default: + break; + } + } + return hash; +} + std::size_t GetAttributeHash(const ONNX_NAMESPACE::AttributeProto& attr) { std::size_t hash = 0; UpdateHash( @@ -233,6 +286,8 @@ std::size_t GetAttributeHash(const ONNX_NAMESPACE::AttributeProto& attr) { UpdateHashWithContainer(attr.strings(), hash); break; case onnx::AttributeProto_AttributeType_TENSOR: + UpdateHash(attr.t(), &GetTensorAttributeHash, hash); + break; case onnx::AttributeProto_AttributeType_GRAPH: case onnx::AttributeProto_AttributeType_SPARSE_TENSOR: case onnx::AttributeProto_AttributeType_TYPE_PROTO: diff --git a/onnxruntime/core/optimizer/compute_optimizer/upstream_gather.cc b/onnxruntime/core/optimizer/compute_optimizer/upstream_gather.cc index 9c98ed6d3e114..1516fb37a7e9f 100644 --- a/onnxruntime/core/optimizer/compute_optimizer/upstream_gather.cc +++ b/onnxruntime/core/optimizer/compute_optimizer/upstream_gather.cc @@ -4,6 +4,7 @@ #ifdef ENABLE_TRAINING #include +#include "core/common/string_utils.h" #include "core/graph/graph_utils.h" #include "core/optimizer/initializer.h" #include "core/optimizer/utils.h" @@ -26,38 +27,38 @@ UpStreamGatherGraphTransformer::UpStreamGatherGraphTransformer( // 2. 
Whether the outputs have the same dim changes if the Gather node moves before that operator. // 3. Should all inputs be allowed when tracking back further (bottom-up); // if not, add the input index restriction as MatMul did. - {GetFullQualifiedOpName("Add", kOnnxDomain), + {utils::GetFullQualifiedOpName("Add", kOnnxDomain), OpPassThroughConfig(std::make_shared>(), opset_14_13_7_6_1)}, - {GetFullQualifiedOpName("BiasGelu", kMSDomain), + {utils::GetFullQualifiedOpName("BiasGelu", kMSDomain), OpPassThroughConfig(std::make_shared>(), opset_1)}, - {GetFullQualifiedOpName("Cast", kOnnxDomain), + {utils::GetFullQualifiedOpName("Cast", kOnnxDomain), OpPassThroughConfig(std::make_shared>(), opset_19_13_9_6_1)}, - {GetFullQualifiedOpName("Div", kOnnxDomain), + {utils::GetFullQualifiedOpName("Div", kOnnxDomain), OpPassThroughConfig(std::make_shared>(), opset_14_13_7_6_1)}, - {GetFullQualifiedOpName("Dropout", kOnnxDomain), + {utils::GetFullQualifiedOpName("Dropout", kOnnxDomain), OpPassThroughConfig(std::make_shared>(), opset_13_12_10_7_6_1)}, - {GetFullQualifiedOpName("Gelu", kMSDomain), + {utils::GetFullQualifiedOpName("Gelu", kMSDomain), OpPassThroughConfig(std::make_shared>(), opset_1)}, {// Be noted, this is our own implementation of ONNX domain op. 
- GetFullQualifiedOpName("LayerNormalization", kOnnxDomain), + utils::GetFullQualifiedOpName("LayerNormalization", kOnnxDomain), OpPassThroughConfig(std::make_shared(), opset_1)}, - {GetFullQualifiedOpName("MatMul", kOnnxDomain), + {utils::GetFullQualifiedOpName("MatMul", kOnnxDomain), OpPassThroughConfig(std::make_shared(), opset_13_9_1)}, - {GetFullQualifiedOpName("Reshape", kOnnxDomain), + {utils::GetFullQualifiedOpName("Reshape", kOnnxDomain), OpPassThroughConfig(std::make_shared(), opset_19_14_13_5_1)}, - {GetFullQualifiedOpName("Softmax", kOnnxDomain), + {utils::GetFullQualifiedOpName("Softmax", kOnnxDomain), OpPassThroughConfig(std::make_shared(), opset_13_11_1)}, - {GetFullQualifiedOpName("Transpose", kOnnxDomain), + {utils::GetFullQualifiedOpName("Transpose", kOnnxDomain), OpPassThroughConfig(std::make_shared(), opset_13_1)}, }); @@ -69,7 +70,7 @@ bool UpStreamGatherGraphTransformer::UpStreamInternal( const OpPassThroughConfig& pass_through_config, const logging::Logger& logger) const { Node& slice_node = *info.node_ptr; - const std::string op_type = GetFullQualifiedOpName(current_node.OpType(), current_node.Domain()); + const std::string op_type = utils::GetFullQualifiedOpName(current_node.OpType(), current_node.Domain()); std::unordered_map propagate_input_indices; std::unordered_map> all_input_cmp_rets; diff --git a/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape.cc b/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape.cc index f7b48de2caaf5..716988e93312c 100644 --- a/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape.cc +++ b/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape.cc @@ -4,6 +4,7 @@ #ifdef ENABLE_TRAINING #include "core/framework/tensorprotoutils.h" +#include "core/common/string_utils.h" #include "core/graph/graph_utils.h" #include "core/optimizer/utils.h" #include "core/optimizer/compute_optimizer/upstream_reshape_actors.h" @@ -21,23 +22,23 @@ 
UpStreamReshapeGraphTransformer::UpStreamReshapeGraphTransformer( // If optype is not enough to guarantee the equivalence, we need to add a customized pre-check function. // 2. Should all inputs be allowed when tracking back further (bottom-up); // if not, add the input index restriction. - {GetFullQualifiedOpName("Add", kOnnxDomain), + {utils::GetFullQualifiedOpName("Add", kOnnxDomain), OpPassThroughConfig( std::make_shared>(), opset_14_13_7_6_1)}, - {GetFullQualifiedOpName("BiasGelu", kMSDomain), + {utils::GetFullQualifiedOpName("BiasGelu", kMSDomain), OpPassThroughConfig( std::make_shared>(), opset_1)}, - {GetFullQualifiedOpName("Cast", kOnnxDomain), + {utils::GetFullQualifiedOpName("Cast", kOnnxDomain), OpPassThroughConfig( std::make_shared>(), opset_19_13_9_6_1)}, - {GetFullQualifiedOpName("Dropout", kOnnxDomain), + {utils::GetFullQualifiedOpName("Dropout", kOnnxDomain), OpPassThroughConfig( std::make_shared>(), opset_13_12_10_7_6_1)}, {// Be noted, this is our own implementation of ONNX domain op. 
- GetFullQualifiedOpName("LayerNormalization", kOnnxDomain), + utils::GetFullQualifiedOpName("LayerNormalization", kOnnxDomain), OpPassThroughConfig( std::make_shared(), opset_1)}, - {GetFullQualifiedOpName("MatMul", kOnnxDomain), + {utils::GetFullQualifiedOpName("MatMul", kOnnxDomain), OpPassThroughConfig( std::make_shared(), opset_13_9_1)}, }); @@ -47,7 +48,7 @@ bool UpStreamReshapeGraphTransformer::UpStreamInternal( Graph& graph, std::deque& queue, Node& current_node, ReshapeInfo& info, const OpPassThroughConfig& pass_through_config, const logging::Logger& logger) const { - const std::string op_type = GetFullQualifiedOpName(current_node.OpType(), current_node.Domain()); + const std::string op_type = utils::GetFullQualifiedOpName(current_node.OpType(), current_node.Domain()); std::vector propagate_input_indices; std::unordered_map> all_input_cmp_rets; diff --git a/onnxruntime/core/optimizer/compute_optimizer/upstream_transformer_base.cc b/onnxruntime/core/optimizer/compute_optimizer/upstream_transformer_base.cc index f08e37296d259..4582f26a7dc68 100644 --- a/onnxruntime/core/optimizer/compute_optimizer/upstream_transformer_base.cc +++ b/onnxruntime/core/optimizer/compute_optimizer/upstream_transformer_base.cc @@ -5,6 +5,7 @@ #include #include "core/common/safeint.h" +#include "core/common/string_utils.h" #include "core/graph/graph_utils.h" #include "core/optimizer/initializer.h" #include "core/optimizer/utils.h" @@ -130,7 +131,7 @@ template bool UpStreamGraphTransformerBase::Upstream(Graph& graph, std::deque& queue, Node& current_node, T1& info, const logging::Logger& logger) const { - const std::string op_type = GetFullQualifiedOpName(current_node.OpType(), current_node.Domain()); + const std::string op_type = utils::GetFullQualifiedOpName(current_node.OpType(), current_node.Domain()); if (allowed_passthrough_ops_.count(op_type)) { auto& pass_through_config = allowed_passthrough_ops_.at(op_type); LOG_DEBUG_INFO(logger, "Enter reorder handle for node " + 
current_node.Name() + "(" + op_type + ")"); diff --git a/onnxruntime/core/optimizer/compute_optimizer/upstream_transformer_base.h b/onnxruntime/core/optimizer/compute_optimizer/upstream_transformer_base.h index 6e22fc791ade3..d848a03c555bb 100644 --- a/onnxruntime/core/optimizer/compute_optimizer/upstream_transformer_base.h +++ b/onnxruntime/core/optimizer/compute_optimizer/upstream_transformer_base.h @@ -72,13 +72,6 @@ class UpStreamGraphTransformerBase : public GraphTransformer { const OpPassThroughConfig& pass_through_config, const logging::Logger& logger) const = 0; - /** - * @brief A consistent way to construct the full qualified op name. - */ - std::string GetFullQualifiedOpName(const std::string& op_type, const std::string& domain) const { - return domain + "::" + op_type; - } - std::unordered_map> allowed_passthrough_ops_; private: diff --git a/onnxruntime/core/optimizer/conv_activation_fusion.cc b/onnxruntime/core/optimizer/conv_activation_fusion.cc index d27603e4ab3a1..b7cb3ba488c62 100644 --- a/onnxruntime/core/optimizer/conv_activation_fusion.cc +++ b/onnxruntime/core/optimizer/conv_activation_fusion.cc @@ -111,7 +111,7 @@ class ConvActivationSelector : public NodeSelector { if (!graph_utils::IsSupportedOptypeVersionAndDomain(*next_node, "Relu", {6, 13, 14})) { return std::nullopt; } - } else if (node_ep.empty() || node_ep == kCpuExecutionProvider) { + } else if (node_ep.empty() || node_ep == kCpuExecutionProvider || node_ep == kJsExecutionProvider) { if (!is_supported_non_cuda_rocm_ep_activation(*next_node) && !graph_utils::IsSupportedOptypeVersionAndDomain(*next_node, "HardSigmoid", {6})) { return std::nullopt; diff --git a/onnxruntime/core/optimizer/gather_fusion.cc b/onnxruntime/core/optimizer/gather_fusion.cc index 4903bc1d6b961..90cabff88122c 100644 --- a/onnxruntime/core/optimizer/gather_fusion.cc +++ b/onnxruntime/core/optimizer/gather_fusion.cc @@ -9,55 +9,144 @@ namespace onnxruntime { -bool GatherToSplitFusion::IsSupportedGather(const Graph& 
graph, const Node& node, int64_t& index, int64_t& axis, - int64_t& indices_n_dims) const { - if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gather", {1, 11, 13}) || +namespace { +static int64_t GetGatherAxis(const Node& node, int64_t rank) { + int64_t axis = 0; + auto& attrs = node.GetAttributes(); + if (attrs.find("axis") != attrs.end()) { + auto& axis_attr = attrs.at("axis"); + if (utils::HasInt(axis_attr)) { + axis = axis_attr.i(); + if (axis < 0) axis += rank; + } + } + return axis; +} + +static bool GetScalarInt64Initializer(const Graph& graph, const NodeArg& node_arg, int64_t& value, int64_t& rank) { + if (!optimizer_utils::IsScalar(node_arg)) return false; + const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, node_arg.Name()); + if (!tensor_proto || tensor_proto->data_type() != ONNX_NAMESPACE::TensorProto::INT64) return false; + Initializer init_const{*tensor_proto, graph.ModelPath()}; + value = *(init_const.data()); + rank = tensor_proto->dims_size(); + return true; +} + +static bool GetSliceAxis(const Graph& graph, const Node& node, int64_t rank, int64_t& axis) { + if (node.InputDefs().size() < 4) return false; + int64_t unused = 0; + if (!GetScalarInt64Initializer(graph, *node.InputDefs()[3], axis, unused)) return false; + if (axis < 0) axis += rank; + return true; +} + +static bool GetAxis(const Graph& graph, const Node& node, int64_t rank, int64_t& axis) { + if (node.OpType() == "Gather") { + axis = GetGatherAxis(node, rank); + return true; + } + if (node.OpType() == "Slice") { + return GetSliceAxis(graph, node, rank, axis); + } + return false; +} + +} // namespace + +bool GatherSliceToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node, int64_t rank, + int64_t target_axis, int64_t dim_size, InlinedVector& consumed, + int64_t& start, bool& need_squeeze) const { + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gather", {13}) || !graph_utils::IsSupportedProvider(node, 
GetCompatibleExecutionProviders())) { return false; } - const NodeArg& input_arg = *(node.InputDefs()[1]); - if (!optimizer_utils::IsScalar(input_arg)) return false; - const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, input_arg.Name()); - if (!tensor_proto) return false; - if (tensor_proto->data_type() != ONNX_NAMESPACE::TensorProto_DataType_INT64) return false; - Initializer init_const{*tensor_proto, graph.ModelPath()}; - index = *(init_const.data()); - axis = 0; // Default value. - auto& attrs = node.GetAttributes(); - if (attrs.find("axis") != attrs.end()) { - auto& axis_attr = attrs.at("axis"); - if (utils::HasInt(axis_attr)) axis = axis_attr.i(); + if (GetGatherAxis(node, rank) != target_axis) return false; + // Require the indices input to be a scalar tensor for now. Normally if not, the exporter will choose Slice. + // We can relax this later if needed. + int64_t indices_n_dims = 0; + if (!GetScalarInt64Initializer(graph, *(node.InputDefs()[1]), start, indices_n_dims)) return false; + if (start < 0) start += dim_size; + if (start < 0 || start >= dim_size || consumed[static_cast(start)]) return false; + consumed[static_cast(start)] = true; + need_squeeze = indices_n_dims == 0; + return true; +} + +bool GatherSliceToSplitFusion::IsSupportedSlice(const Graph& graph, const Node& node, int64_t rank, int64_t target_axis, + int64_t dim_size, InlinedVector& consumed, int64_t& start, + int64_t& end) const { + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Slice", {13}) || + !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) { + return false; + } + + int64_t axis = 0; + if (!GetSliceAxis(graph, node, rank, axis) || axis != target_axis) return false; + int64_t unused = 0; + if (!GetScalarInt64Initializer(graph, *node.InputDefs()[1], start, unused) || + !GetScalarInt64Initializer(graph, *node.InputDefs()[2], end, unused)) { + return false; + } + // Handling start and end according to schema 
definition. + if (start < 0) start += dim_size; + if (end < 0) end += dim_size; + if (start < 0) + start = 0; + else if (start > dim_size) + start = dim_size; + if (end < 0) + end = 0; + else if (end > dim_size) + end = dim_size; + if (start >= end) return false; + if (node.InputDefs().size() >= 5) { + int64_t step = 0; + if (!GetScalarInt64Initializer(graph, *node.InputDefs()[4], step, unused) || step != 1) return false; + } + for (int64_t i = start; i < end; ++i) { + if (consumed[static_cast(i)]) return false; + consumed[static_cast(i)] = true; } - indices_n_dims = tensor_proto->dims_size(); return true; } /* -GatherToSplitFusion is to fuse: -Node -> Gather(index=0, axis=axis) - |-> Gather(index=1, axis=axis) - |-> Gather(index=2, axis=axis) +GatherSliceToSplitFusion is to fuse: +Node -> Gather(indices=0, axis=axis) + |-> Gather(indices=[1], axis=axis) + |-> Slice(start=2, end=3, axes=[axis]) |... To Node -> Split -> Squeeze(axis=axis) - |-> Squeeze(axis=axis) - |-> Squeeze(axis=axis) + |-> + |-> |... So that we can use one kernel to finish the job. +The fusion requires that the indices of Gather nodes and start/end of Slice nodes are not overlapping and cover +all the elements in the target axis. Step of Slice node should be 1. */ -Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, - const logging::Logger& logger) const { +Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, + const logging::Logger& logger) const { + // Squeeze, Gather, Slice and Split have different schemas before and after OpSet 13. + // To make code simple, support OpSet >= 13 only. 
+ int onnx_opset_version = -1; + if (graph.DomainToVersionMap().find(kOnnxDomain) != graph.DomainToVersionMap().end()) { + onnx_opset_version = graph.DomainToVersionMap().at(kOnnxDomain); + } + if (onnx_opset_version < 13) return Status::OK(); + GraphViewer graph_viewer(graph); const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); - InlinedVector node_args; + InlinedVector candidate_args; for (auto node_arg : graph.GetInputs()) { if (node_arg && graph.GetConsumerNodes(node_arg->Name()).size() > 1) { - node_args.push_back(node_arg); + candidate_args.push_back(node_arg); } } @@ -65,7 +154,7 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le if (graph.GetConsumerNodes(entry.first).size() > 1) { auto node_arg = graph.GetNodeArg(entry.first); if (node_arg) { - node_args.push_back(node_arg); + candidate_args.push_back(node_arg); } } } @@ -90,129 +179,108 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le size_t output_count = node.GetOutputEdgesCount(); if (output_count <= 1) continue; - node_args.push_back(node.OutputDefs()[0]); + candidate_args.push_back(node.OutputDefs()[0]); } - for (const NodeArg* node_arg : node_args) { + for (const NodeArg* node_arg : candidate_args) { auto shape = node_arg->Shape(); if (!shape) continue; int64_t rank = static_cast(shape->dim_size()); - - bool can_fuse = true; - bool first_edge = true; - int64_t split_axis = 0; - int64_t indices_n_dims = -1; auto consumers = graph.GetConsumerNodes(node_arg->Name()); - size_t consumer_count = consumers.size(); - InlinedVector gather_outputs(consumer_count, nullptr); - InlinedVector> nodes_to_fuse; + InlinedVector condidate_consumers; for (auto consumer : consumers) { - int64_t index, axis, dims; - if (!consumer || consumer->InputDefs()[0] != node_arg || - !IsSupportedGather(graph, *consumer, index, axis, dims)) { - can_fuse = false; - break; - } - if (indices_n_dims == -1) { - indices_n_dims = dims; - } else if 
(indices_n_dims != dims) { - // Not the same number of dimensions (0 or 1) for all scalar indices. - can_fuse = false; - break; + if (consumer && consumer->InputDefs()[0] == node_arg && + (consumer->OpType() == "Gather" || consumer->OpType() == "Slice")) { + condidate_consumers.emplace_back(consumer); } - if (axis < 0) axis += rank; - if (first_edge) { - auto dim = shape->dim(static_cast(axis)); - if (!utils::HasDimValue(dim) || dim.dim_value() != static_cast(consumer_count)) { - can_fuse = false; - break; - } - split_axis = axis; - first_edge = false; - } else if (axis != split_axis) { + } + if (condidate_consumers.size() < 2) continue; + int64_t axis = 0; + if (!GetAxis(graph, *condidate_consumers[0], rank, axis)) continue; + auto dim = shape->dim(static_cast(axis)); + if (!utils::HasDimValue(dim)) continue; + int64_t dim_size = dim.dim_value(); + InlinedVector consumed(static_cast(dim_size), false); + bool can_fuse = true; + InlinedVector> nodes_to_fuse; + InlinedVector starts; + InlinedHashMap> output_info_map; + for (auto consumer : condidate_consumers) { + if (!consumer || consumer->InputDefs()[0] != node_arg) { can_fuse = false; break; } - if (index < 0) index += static_cast(consumer_count); - if (index < 0 || index >= static_cast(consumer_count) || gather_outputs[static_cast(index)]) { + int64_t start = 0, end = 0; + bool need_squeeze = false; + if (IsSupportedGather(graph, *consumer, rank, axis, dim_size, consumed, start, need_squeeze)) { + Node& gather_node = *graph.GetNode(consumer->Index()); + nodes_to_fuse.emplace_back(gather_node); + starts.emplace_back(start); + output_info_map[start] = std::make_tuple(gather_node.MutableOutputDefs()[0], 1, need_squeeze); + } else if (IsSupportedSlice(graph, *consumer, rank, axis, dim_size, consumed, start, end)) { + Node& slice_node = *graph.GetNode(consumer->Index()); + nodes_to_fuse.emplace_back(slice_node); + starts.emplace_back(start); + output_info_map[start] = std::make_tuple(slice_node.MutableOutputDefs()[0], 
end - start, false); + } else { can_fuse = false; break; } - Node& gather_node = *graph.GetNode(consumer->Index()); - nodes_to_fuse.emplace_back(gather_node); - gather_outputs[static_cast(index)] = gather_node.MutableOutputDefs()[0]; - } - - if (!can_fuse) continue; - - ONNX_NAMESPACE::TypeProto split_output_type; - const ONNX_NAMESPACE::TensorProto_DataType element_type = - static_cast(node_arg->TypeAsProto()->tensor_type().elem_type()); - split_output_type.mutable_tensor_type()->set_elem_type(element_type); - for (int64_t i = 0; i < rank; ++i) { - if (i == split_axis) { - split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1LL); - } else { - *(split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()) = shape->dim(static_cast(i)); - } } + if (!can_fuse || std::find(consumed.begin(), consumed.end(), false) != consumed.end()) continue; + std::sort(starts.begin(), starts.end()); InlinedVector split_outputs; - bool add_squeeze_node = indices_n_dims == 0; - if (add_squeeze_node) { - for (size_t i = 0; i < consumer_count; ++i) { - split_outputs.emplace_back( - &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("split" + std::to_string(i)), &split_output_type)); - } - } - - Node& split_node = - graph.AddNode(graph.GenerateNodeName("Split"), "Split", "Split for Fused Gather nodes", - {graph.GetNodeArg(node_arg->Name())}, add_squeeze_node ? split_outputs : gather_outputs); - split_node.AddAttribute("axis", split_axis); - split_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType()); - - // Squeeze-11, Squeee-13, Split-13, Split-18 have different schemas. 
- int onnx_opset_version = -1; - if (graph.DomainToVersionMap().find(kOnnxDomain) != graph.DomainToVersionMap().end()) { - onnx_opset_version = graph.DomainToVersionMap().at(kOnnxDomain); - } - - if (onnx_opset_version < 13) { - if (add_squeeze_node) { - for (size_t i = 0; i < consumer_count; ++i) { - Node& squeeze_node = graph.AddNode(graph.GenerateNodeName("Squeeze" + std::to_string(i)), "Squeeze", - "Squeeze for Fused Gather nodes", {split_outputs[i]}, {gather_outputs[i]}); - squeeze_node.AddAttribute("axes", std::vector{split_axis}); - squeeze_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType()); + InlinedVector split_values; + for (int64_t start : starts) { + auto& output_info = output_info_map[start]; + NodeArg* original_output_arg = std::get<0>(output_info); + int64_t split_value = std::get<1>(output_info); + split_values.emplace_back(split_value); + if (std::get<2>(output_info)) { + ONNX_NAMESPACE::TypeProto split_output_type; + const ONNX_NAMESPACE::TensorProto_DataType element_type = + static_cast(node_arg->TypeAsProto()->tensor_type().elem_type()); + split_output_type.mutable_tensor_type()->set_elem_type(element_type); + for (int64_t i = 0; i < rank; ++i) { + if (i == axis) { + split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(split_value); + } else { + *(split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()) = shape->dim(static_cast(i)); + } } - } - } else { - if (onnx_opset_version >= 18) { - split_node.AddAttribute("num_outputs", static_cast(consumer_count)); - } - - if (add_squeeze_node) { + NodeArg* split_output_arg = + &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("split_output"), &split_output_type); ONNX_NAMESPACE::TensorProto axes_initializer_proto; - axes_initializer_proto.set_name(graph.GenerateNodeName("SqueezeAxesInitializer")); + axes_initializer_proto.set_name(graph.GenerateNodeName("squeeze_axes")); axes_initializer_proto.add_dims(static_cast(1)); 
axes_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); - InlinedVector axes_value{split_axis}; - axes_initializer_proto.set_raw_data(axes_value.data(), axes_value.size() * sizeof(int64_t)); + axes_initializer_proto.add_int64_data(axis); NodeArg* axes_arg = &graph_utils::AddInitializer(graph, axes_initializer_proto); - - for (size_t i = 0; i < consumer_count; ++i) { - Node& squeeze_node = - graph.AddNode(graph.GenerateNodeName("Squeeze" + std::to_string(i)), "Squeeze", - "Squeeze for Fused Gather nodes", {split_outputs[i], axes_arg}, {gather_outputs[i]}); - squeeze_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType()); - } + Node& squeeze_node = + graph.AddNode(graph.GenerateNodeName("Squeeze"), "Squeeze", "Squeeze for Fused Gather nodes", + {split_output_arg, axes_arg}, {original_output_arg}); + squeeze_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType()); + split_outputs.emplace_back(split_output_arg); + } else { + split_outputs.emplace_back(original_output_arg); } } - for (Node& n : nodes_to_fuse) { - graph_utils::RemoveNodeOutputEdges(graph, n); - graph.RemoveNode(n.Index()); + ONNX_NAMESPACE::TensorProto split_initializer_proto; + split_initializer_proto.set_name(graph.GenerateNodeName("splits")); + split_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + split_initializer_proto.add_dims(static_cast(split_values.size())); + split_initializer_proto.mutable_int64_data()->Add(split_values.begin(), split_values.end()); + NodeArg* split_initializer_arg = &graph_utils::AddInitializer(graph, split_initializer_proto); + Node& split_node = graph.AddNode(graph.GenerateNodeName("Split"), "Split", "Split for Fused Gather nodes", + {graph.GetNodeArg(node_arg->Name()), split_initializer_arg}, split_outputs); + split_node.AddAttribute("axis", axis); + split_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType()); + + for (Node& node : 
nodes_to_fuse) { + graph_utils::RemoveNodeOutputEdges(graph, node); + graph.RemoveNode(node.Index()); } modified = true; diff --git a/onnxruntime/core/optimizer/gather_fusion.h b/onnxruntime/core/optimizer/gather_fusion.h index 44c235915b6cc..098278a77dafe 100644 --- a/onnxruntime/core/optimizer/gather_fusion.h +++ b/onnxruntime/core/optimizer/gather_fusion.h @@ -8,19 +8,23 @@ namespace onnxruntime { /** -@Class GatherToSplitFusion +@Class GatherSliceToSplitFusion -Fuse multiple Gather nodes that comsuming one output to one Split node. +Fuse multiple Gather/Slice nodes that comsuming one output to one Split node. */ -class GatherToSplitFusion : public GraphTransformer { +class GatherSliceToSplitFusion : public GraphTransformer { public: - GatherToSplitFusion(const InlinedHashSet& compatible_execution_providers = {}) noexcept - : GraphTransformer("GatherToSplitFusion", compatible_execution_providers) {} + GatherSliceToSplitFusion(const InlinedHashSet& compatible_execution_providers = {}) noexcept + : GraphTransformer("GatherSliceToSplitFusion", compatible_execution_providers) {} Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; private: - bool IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis, int64_t& indices_n_dims) const; + bool IsSupportedGather(const Graph& graph, const Node& node, int64_t rank, int64_t target_axis, int64_t dim_size, + InlinedVector& consumed, int64_t& start, bool& need_squeeze) const; + + bool IsSupportedSlice(const Graph& graph, const Node& node, int64_t rank, int64_t target_axis, int64_t dim_size, + InlinedVector& consumed, int64_t& start, int64_t& end) const; }; /** diff --git a/onnxruntime/core/optimizer/gemm_activation_fusion.cc b/onnxruntime/core/optimizer/gemm_activation_fusion.cc index c62887da09fdc..50be2cbd48f7b 100644 --- a/onnxruntime/core/optimizer/gemm_activation_fusion.cc +++ b/onnxruntime/core/optimizer/gemm_activation_fusion.cc @@ 
-56,6 +56,13 @@ Status GemmActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l continue; } + NodeArg* node_output = node.MutableOutputDefs()[0]; + auto data_type = node_output->TypeAsProto()->tensor_type().elem_type(); + if (data_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + // FusedGemm is only registered for float data type in fused_gemm.cc! + continue; + } + const Node& next_node = *(node.OutputNodesBegin()); if (!IsFusableActivation(next_node) || next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) { continue; diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index cd3c49be15aa4..63612c47f9c56 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -69,6 +69,7 @@ #include "core/optimizer/reshape_fusion.h" #include "core/optimizer/rocm_blas_alt_impl.h" #include "core/optimizer/rule_based_graph_transformer.h" +#include "core/optimizer/shape_input_merge.h" #include "core/optimizer/skip_layer_norm_fusion.h" #include "core/optimizer/slice_elimination.h" #include "core/optimizer/transpose_optimizer.h" @@ -211,9 +212,9 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::make_unique()); } - // Put ConstantSharing before CommonSubexpressionElimination by intention as it can create more opportunities for - // CSE. For example, if A and B nodes both do Add operation with a same value but different initializers, by - // default, CSE will not merge them, because the different initializers are represented by different NodeArg. + // Put ConstantSharing and ShapeInputMerge before CommonSubexpressionElimination by intention as it can create + // more opportunities for CSE. For example, if A and B nodes consume same different args but produce same output + // or consume different initializers with same value, by default, CSE will not merge them. 
InlinedHashSet excluded_initializers; excluded_initializers.reserve(session_options.initializers_to_share_map.size()); for (const auto& p : session_options.initializers_to_share_map) { @@ -221,7 +222,7 @@ InlinedVector> GenerateTransformers( } const InlinedHashSet no_limit_empty_ep_list = {}; transformers.emplace_back(std::make_unique(no_limit_empty_ep_list, excluded_initializers)); - + transformers.emplace_back(std::make_unique()); transformers.emplace_back(std::make_unique()); transformers.emplace_back(std::make_unique(cpu_execution_provider, !disable_quant_qdq, session_options.config_options)); @@ -278,7 +279,8 @@ InlinedVector> GenerateTransformers( onnxruntime::kAclExecutionProvider, onnxruntime::kArmNNExecutionProvider, onnxruntime::kJsExecutionProvider}; - + const InlinedHashSet cpu_dml_eps = {onnxruntime::kCpuExecutionProvider, + onnxruntime::kDmlExecutionProvider}; #ifdef MLAS_TARGET_AMD64_IX86 const bool avx2_precision_mode = session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsAvx2PrecisionMode, "0") == "1" && MlasPlatformU8S8Overflow(); @@ -296,7 +298,7 @@ InlinedVector> GenerateTransformers( } transformers.emplace_back(std::make_unique(cpu_ep)); - transformers.emplace_back(std::make_unique(cpu_ep)); + transformers.emplace_back(std::make_unique(cpu_dml_eps)); transformers.emplace_back(std::make_unique(cpu_ep)); transformers.emplace_back(std::make_unique(cpu_cuda_rocm_acl_armnn_js_eps)); @@ -306,7 +308,7 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); - transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); + transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); diff --git 
a/onnxruntime/core/optimizer/layer_norm_fusion.cc b/onnxruntime/core/optimizer/layer_norm_fusion.cc index 159e3b23d1ab0..ce696154adb6d 100644 --- a/onnxruntime/core/optimizer/layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/layer_norm_fusion.cc @@ -13,7 +13,7 @@ using namespace onnxruntime::common; namespace onnxruntime { // LayerNorm supports limited data types. -static constexpr std::array supported_data_types{"tensor(float16)", "tensor(float)", "tensor(double)"}; +static constexpr std::array supported_data_types{"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"}; // Default epsilon static constexpr float DEFAULT_LAYERNORM_EPSILON = 1e-5f; @@ -447,6 +447,13 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, NodeArg* x_input = has_leading_cast ? graph.GetNode(p_reduce_mean_input_node->Index())->MutableInputDefs()[0] : reduce_mean_node.MutableInputDefs()[0]; + + // CPU doesn't support fp16 + if (reduce_mean_node.GetExecutionProviderType() == kCpuExecutionProvider && + x_input->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { + continue; + } + InlinedVector layer_norm_input_defs{x_input, scale, bias}; Node& layer_norm_node = graph.AddNode(graph.GenerateNodeName("LayerNormalization"), "LayerNormalization", @@ -689,6 +696,13 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr NodeArg* x_input = has_leading_cast ? 
graph.GetNode(p_pow_input_node->Index())->MutableInputDefs()[0] : pow_node.MutableInputDefs()[0]; + + // CPU doesn't support fp16 + if (reduce_mean_node.GetExecutionProviderType() == kCpuExecutionProvider && + x_input->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { + continue; + } + InlinedVector layer_norm_input_defs{x_input, scale}; Node& layer_norm_node = graph.AddNode(graph.GenerateNodeName("SimplifiedLayerNormalization"), "SimplifiedLayerNormalization", diff --git a/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc b/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc index 4505d4afdf1e0..7953cde6686c0 100644 --- a/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc +++ b/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc @@ -31,6 +31,7 @@ CostCheckResult PostLayoutTransformCostCheck(const api::GraphRef& graph, const a } #if defined(USE_CUDA) && ENABLE_CUDA_NHWC_OPS +// TODO(mtavenrath) generate list from registered kernels using nhwc domain const std::unordered_set& GetCUDALayoutSensitiveOps() { static std::unordered_set cuda_nhwc_ops = []() { return std::unordered_set{ @@ -41,7 +42,10 @@ const std::unordered_set& GetCUDALayoutSensitiveOps() { "MaxPool", "GlobalAveragePool", "AveragePool", - }; + "GridSample", + "DepthToSpace", + "SpaceToDepth", + "LRN"}; }(); return cuda_nhwc_ops; } diff --git a/onnxruntime/core/optimizer/matmul_integer_to_float.cc b/onnxruntime/core/optimizer/matmul_integer_to_float.cc index 56e51cb787931..4fee1a6ce224e 100644 --- a/onnxruntime/core/optimizer/matmul_integer_to_float.cc +++ b/onnxruntime/core/optimizer/matmul_integer_to_float.cc @@ -31,6 +31,24 @@ static bool CheckBiasShape(const TensorShapeProto* bias_shape) { return bias_last_dim > 1; } +bool HasElementDataType(const NodeArg& node_arg, int32_t data_type) { + if (!node_arg.Exists()) { + return false; + } + + const auto* type_proto = 
node_arg.TypeAsProto(); + if (!type_proto) { + return false; + } + + int32_t actual_data_type; + if (!utils::TryGetElementDataType(*type_proto, actual_data_type)) { + return false; + } + + return data_type == actual_data_type; +} + /** MatMulIntegerToFloatFusion will fuse subgraph like below into MatMulIntegerToFloat: @@ -63,9 +81,10 @@ Status MatMulIntegerToFloatFusion::ApplyImpl(Graph& graph, bool& modified, int g auto& mul_node = *node_ptr; ORT_RETURN_IF_ERROR(Recurse(mul_node, modified, graph_level, logger)); - + const bool is_dml_ep = node_ptr->GetExecutionProviderType() == kDmlExecutionProvider; if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7, 13, 14}) || - !graph_utils::IsSupportedProvider(mul_node, GetCompatibleExecutionProviders())) { + !graph_utils::IsSupportedProvider(mul_node, GetCompatibleExecutionProviders()) || + (!is_dml_ep && HasElementDataType(*mul_node.InputDefs()[0], ONNX_NAMESPACE::TensorProto_DataType_FLOAT16))) { continue; } diff --git a/onnxruntime/core/optimizer/noop_elimination.cc b/onnxruntime/core/optimizer/noop_elimination.cc index b3c2991d54b28..bba39b698a27a 100644 --- a/onnxruntime/core/optimizer/noop_elimination.cc +++ b/onnxruntime/core/optimizer/noop_elimination.cc @@ -42,49 +42,62 @@ bool NoopElimination::SatisfyCondition(const Graph& graph, const Node& node, con // if initializer_rank is bigger, the output is expected to be initializer_rank per broadcasting rule, // but it won't happen if the case is accepted, thus reject it - auto initializer_rank = initializer->dims().size(); + const auto& dims = initializer->dims(); + auto initializer_rank = dims.size(); const auto* other_input_shape = node.InputDefs()[input0_is_initializer ? 
1 : 0]->Shape(); if (other_input_shape == nullptr || initializer_rank > other_input_shape->dim_size()) { return false; } - int32_t data_type = initializer->data_type(); - Initializer add_init(*initializer, graph.ModelPath()); - if (add_init.size() > 1) { + int64_t tensor_size = 1; + for (auto i : dims) { + tensor_size *= i; + } + + if (tensor_size > 1) { return false; } + // handle edge case where the total size of the initializer is 0 - if (add_init.size() == 0) { + if (tensor_size == 0) { return true; } - float value = 0.0f; - switch (data_type) { - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: - value = *add_init.data(); - break; - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: - value = math::halfToFloat(add_init.data()->val); - break; - case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: - value = static_cast(*add_init.data()); - break; - case ONNX_NAMESPACE::TensorProto_DataType_INT32: - value = static_cast(*add_init.data()); - break; - case ONNX_NAMESPACE::TensorProto_DataType_INT64: - value = static_cast(*add_init.data()); - break; - default: + if (op_type == "Add" || + op_type == "Sub" || + op_type == "Mul" || + op_type == "Div") { + int32_t data_type = initializer->data_type(); + Initializer add_init(*initializer, graph.ModelPath()); + + float value = 0.0f; + switch (data_type) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + value = *add_init.data(); + break; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + value = math::halfToFloat(add_init.data()->val); + break; + case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: + value = static_cast(*add_init.data()); + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT32: + value = static_cast(*add_init.data()); + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT64: + value = static_cast(*add_init.data()); + break; + default: + return false; + } + + if (value != 0.0f && (op_type == "Add" || op_type == "Sub")) { return false; - } + } - if ((op_type == "Add" || op_type == "Sub") && value != 
0.0f) { - return false; - } - - if ((op_type == "Mul" || op_type == "Div") && value != 1.0f) { - return false; + if (value != 1.0f && (op_type == "Mul" || op_type == "Div")) { + return false; + } } // reject node output is graph output for now diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc index b1ab641a23256..4e3dff705bd41 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc @@ -76,6 +76,49 @@ bool IsQDQPairSupported( } } +bool IsDQQConversion( + const Node& dq_node, const Node& q_node, + const GetConstantInitializerFn& get_const_initializer, + const Path& model_path) { + ConstPointerContainer> dq_input_defs = dq_node.InputDefs(); + ConstPointerContainer> q_input_defs = q_node.InputDefs(); + + // Q/DQ contains optional input is not supported + // non-scalar Q/DQ scale and zero point needs are not supported + if (dq_input_defs.size() != InputIndex::TOTAL_COUNT || + q_input_defs.size() != InputIndex::TOTAL_COUNT || + !optimizer_utils::IsScalar(*q_input_defs[InputIndex::SCALE_ID]) || + !optimizer_utils::IsScalar(*q_input_defs[InputIndex::ZERO_POINT_ID]) || + !optimizer_utils::IsScalar(*dq_input_defs[InputIndex::SCALE_ID]) || + !optimizer_utils::IsScalar(*dq_input_defs[InputIndex::ZERO_POINT_ID])) { + return false; + } + + // if Q/DQ scale and zero point are not constant, return false + const ONNX_NAMESPACE::TensorProto* dq_scale_tensor_proto = + get_const_initializer(dq_input_defs[InputIndex::SCALE_ID]->Name()); + const ONNX_NAMESPACE::TensorProto* q_scale_tensor_proto = + get_const_initializer(q_input_defs[InputIndex::SCALE_ID]->Name()); + const ONNX_NAMESPACE::TensorProto* dq_zp_tensor_proto = + get_const_initializer(dq_input_defs[InputIndex::ZERO_POINT_ID]->Name()); + const ONNX_NAMESPACE::TensorProto* q_zp_tensor_proto = + get_const_initializer(q_input_defs[InputIndex::ZERO_POINT_ID]->Name()); + if (nullptr == 
q_zp_tensor_proto || + nullptr == dq_zp_tensor_proto || + nullptr == q_scale_tensor_proto || + nullptr == dq_scale_tensor_proto) { + return false; + } + + // check Q/DQ have same scale type and different zero point type + Initializer q_zp(*q_zp_tensor_proto, model_path); + Initializer q_scale(*q_scale_tensor_proto, model_path); + Initializer dq_zp(*dq_zp_tensor_proto, model_path); + Initializer dq_scale(*dq_scale_tensor_proto, model_path); + + return (dq_zp.data_type() != q_zp.data_type()) && (dq_scale.data_type() == q_scale.data_type()); +} + bool IsDQSupported(const Node& dq_node, const GetConstantInitializerFn& get_const_initializer) { bool zero_point_exists = false; if (!QOrDQNodeHasConstantScalarScaleAndZeroPoint(dq_node, get_const_initializer, zero_point_exists)) { diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h index bb0bf9438cfcb..8333168b0093f 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h @@ -38,6 +38,18 @@ bool IsQDQPairSupported( const GetConstantInitializerFn& get_const_initializer, const Path& model_path); +// Check if a DQ -> Q sequence represents a conversion in quantization data type. +// Example of uint8 to uint16: +// Dequantize (uint8 to float) -> Quantize (float to uint16) +// Requires: +// 1. Q/DQ doesn't have optional input. +// 2. scale and zero-point are constant scalars. +// 3. Q and DQ have the same scale *type* and different zero-point *types*. +bool IsDQQConversion( + const Node& dq_node, const Node& q_node, + const GetConstantInitializerFn& get_const_initializer, + const Path& model_path); + // Check if DQ is supported in extended level QDQ transformers. It requires: // 1. DQ doesn't have optional input. // 2. 
scale and zero point is constant scalar diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index 8535b8c9a944a..6b4f62ae1343d 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -58,8 +58,8 @@ bool NodeGroupSelector::CheckQDQNodes(const GraphViewer& graph_viewer, const Nod return false; } - if (const auto dq_validation_status = QDQ::ValidateNodeGroupDQNodes(graph_viewer, node, dq_nodes); - !dq_validation_status.IsOK()) { + if (const auto qdq_validation_status = NodeGroup::CanCreateNodeGroup(graph_viewer, node, dq_nodes, q_nodes); + !qdq_validation_status.IsOK()) { return false; } @@ -153,8 +153,8 @@ bool DropDQNodeGroupSelector::Check(const GraphViewer& graph_viewer, return false; } - if (const auto dq_validation_status = QDQ::ValidateNodeGroupDQNodes(graph_viewer, node, dq_nodes); - !dq_validation_status.IsOK()) { + if (const auto qdq_validation_status = NodeGroup::CanCreateNodeGroup(graph_viewer, node, dq_nodes, q_nodes); + !qdq_validation_status.IsOK()) { return false; } @@ -544,8 +544,8 @@ bool TopKNodeGroupSelector::Check(const GraphViewer& graph_viewer, return false; } - if (const auto dq_validation_status = QDQ::ValidateNodeGroupDQNodes(graph_viewer, node, dq_nodes); - !dq_validation_status.IsOK()) { + if (const auto qdq_validation_status = QDQ::NodeGroup::CanCreateNodeGroup(graph_viewer, node, dq_nodes, q_nodes); + !qdq_validation_status.IsOK()) { return false; } diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h index deee6e7f25f1a..c90a42a36483d 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h +++ 
b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h @@ -5,6 +5,7 @@ #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) +#include "core/framework/node_unit.h" #include "core/optimizer/selectors_actions/selector_action_transformer.h" namespace onnxruntime { @@ -13,13 +14,6 @@ class Node; namespace QDQ { -// Struct to represent a DQ->Op->Q node group -struct NodeGroup { - std::vector dq_nodes; - std::vector q_nodes; - NodeIndex target_node; -}; - class NodeGroupSelector { public: // This is a QDQ Selectors only function, will return QDQ::NodeGroup instead of NodesToOptimizeIndices diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc index 544fe82a268c8..1876f7826c968 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc @@ -13,6 +13,7 @@ #include #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" +#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" namespace onnxruntime { namespace QDQ { @@ -43,6 +44,7 @@ static const OpVersionsAndSelector::OpVersionsMap GetMiscOpVersionsMap() { {"Tile", {}}}; } +// These produce int64 indices output, which can't be quantized, so there's no downstream Q node. static const OpVersionsAndSelector::OpVersionsMap GetDropDQOpVersionsMap() { return {{"ArgMax", {}}, {"ArgMin", {}}}; @@ -324,28 +326,48 @@ std::vector SelectorManager::GetQDQSelections(const GraphViewer& grap return qdq_selections; } -Status ValidateNodeGroupDQNodes(const GraphViewer& graph_viewer, - const Node& target_node, - gsl::span dq_nodes) { - // Within a QDQ node group, a target node input is the only consumer of each DQ. 
- // This should have been ensured by the EnsureUniqueDQForNodeUnit graph transformer, but other graph modifications - // may have happened since. Verify that this is still true. - for (const auto* dq_node : dq_nodes) { - const bool dq_produces_graph_output = graph_viewer.NodeProducesGraphOutput(*dq_node); - ORT_RETURN_IF(dq_produces_graph_output, - "QDQ node group cannot have DQ node that produces a graph output. DQ node: ", dq_node->Name(), - ", target node: ", target_node.Name()); - - const bool dq_has_single_output_edge_to_target = - dq_node->GetOutputEdgesCount() == 1 && - dq_node->OutputEdgesBegin()->GetNode().Index() == target_node.Index(); - ORT_RETURN_IF_NOT(dq_has_single_output_edge_to_target, - "QDQ node group cannot have DQ that doesn't have a single output edge to the target node. " - "DQ node: ", - dq_node->Name(), ", target node: ", target_node.Name()); +std::pair>, std::unordered_map> +GetAllNodeUnits(const GraphViewer& graph_viewer) { + std::vector> node_unit_holder; + std::unordered_map node_unit_map; + + const auto add_node_unit_to_map = [&](const std::vector& node_indices, const NodeUnit* node_unit) { + for (const auto& node_idx : node_indices) { + const auto* node = graph_viewer.GetNode(node_idx); + node_unit_map.insert({node, node_unit}); + } + }; + + // Get QDQ NodeUnits first + QDQ::SelectorManager selector_mgr; + const auto qdq_selections = selector_mgr.GetQDQSelections(graph_viewer); + + for (const auto& qdq_selection : qdq_selections) { + auto qdq_unit = std::make_unique(graph_viewer, qdq_selection); + + // Fill the node to node_unit map for all nodes in the QDQ Group + add_node_unit_to_map(qdq_selection.dq_nodes, qdq_unit.get()); + add_node_unit_to_map(qdq_selection.q_nodes, qdq_unit.get()); + add_node_unit_to_map({qdq_selection.target_node}, qdq_unit.get()); + + node_unit_holder.push_back(std::move(qdq_unit)); + } + + // Get the left over SingleNode NodeUnits + const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder(); + for 
(const auto node_idx : node_indices) { + const auto* node(graph_viewer.GetNode(node_idx)); + + // This is already part of a QDQ NodeUnit + if (node_unit_map.find(node) != node_unit_map.cend()) + continue; + + auto node_unit = std::make_unique(*node); + node_unit_map[node] = node_unit.get(); + node_unit_holder.push_back(std::move(node_unit)); } - return Status::OK(); + return std::make_pair(std::move(node_unit_holder), std::move(node_unit_map)); } } // namespace QDQ diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h index 246f26c1760ec..de36202afff29 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h @@ -7,6 +7,7 @@ #include "core/common/common.h" #include "core/common/gsl.h" #include "core/common/inlined_containers.h" +#include "core/framework/node_unit.h" #include "core/graph/basic_types.h" #if !defined(ORT_MINIMAL_BUILD) @@ -78,11 +79,16 @@ class SelectorManager { ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(SelectorManager); }; -// Checks whether the provided DQ nodes are valid for forming a QDQ node group with the provided target node. -// Returns successful status if so, failed status with reason otherwise. -Status ValidateNodeGroupDQNodes(const GraphViewer& graph_viewer, - const Node& target_node, - gsl::span dq_nodes); +// Get all the nodes in the given graph_viewer as NodeUnits (SingleNode or QDQGroup) +// And return a map to quick query the NodeUnit which contains the given Node, +// Note, the value of the map is owned by the vector of std::unique_ptr +// +// TODO: The overall QDQ setup needs refactoring to separate out generic functionality from optimizer specific +// functionality. 
+// We currently have a bit of a mess with generic things like this to get all the node units being in the optimizer +// library whereas it should be able to be used by an EP with no dependency on optimizers. +std::pair>, std::unordered_map> +GetAllNodeUnits(const GraphViewer& graph_viewer); } // namespace QDQ } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/shape_input_merge.cc b/onnxruntime/core/optimizer/shape_input_merge.cc new file mode 100644 index 0000000000000..9f20520e3e3f4 --- /dev/null +++ b/onnxruntime/core/optimizer/shape_input_merge.cc @@ -0,0 +1,78 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/optimizer/shape_input_merge.h" + +#include "core/graph/graph_utils.h" + +namespace onnxruntime { + +namespace { +// Builds a lookup key from the static shape of input_arg, e.g. [2,'N',3]: concrete dims are written as numbers +// and symbolic dims as quoted names. Returns an empty string when the shape is missing or a dim has neither, +// which the caller treats as "do not merge". +std::string GetShapeString(const NodeArg* input_arg) { + auto shape = input_arg->Shape(); + if (!shape) return ""; + std::stringstream ss; + ss << "["; + for (int i = 0; i < shape->dim_size(); ++i) { + if (i != 0) ss << ","; + auto dim = shape->dim(i); + if (dim.has_dim_value()) { + ss << std::to_string(dim.dim_value()); + } else if (dim.has_dim_param()) { + ss << "'" << dim.dim_param() << "'"; + } else { + return ""; + } + } + ss << "]"; + return ss.str(); +} + +} // namespace + +// Groups the Shape nodes of this graph by the shape key of their first input and rewires every node of a group +// to read the input of the group's first node. Recurse() is invoked on every node before it is examined. +Status ShapeInputMerge::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { + GraphViewer graph_viewer(graph); + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + InlinedHashMap> input_hash_to_nodes; + for (auto node_index : node_topology_list) { + auto* p_node = graph.GetNode(node_index); + if (!p_node) continue; // we removed the node as part of an earlier fusion + ORT_RETURN_IF_ERROR(Recurse(*p_node, modified, graph_level, logger)); + if (!graph_utils::IsSupportedOptypeVersionAndDomain(*p_node, "Shape", {1, 13, 15, 19, 21}) || + !graph_utils::IsSupportedProvider(*p_node, GetCompatibleExecutionProviders())) { +
continue; + } + std::string shape_str = GetShapeString(p_node->InputDefs()[0]); + if (shape_str.empty()) continue; + if (input_hash_to_nodes.find(shape_str) == input_hash_to_nodes.end()) { + input_hash_to_nodes[shape_str] = InlinedVector(); + } + input_hash_to_nodes[shape_str].emplace_back(p_node); + } + + // All Shape nodes are processed in topological order, so we can safely merge the inputs to the first node's input. + for (auto& kv : input_hash_to_nodes) { + if (kv.second.size() < 2) continue; + NodeArg* first_input_arg = kv.second[0]->MutableInputDefs()[0]; + bool is_first_input_arg_graph_input = graph.IsInputsIncludingInitializers(first_input_arg); + for (size_t i = 1; i < kv.second.size(); ++i) { + Node* p_node = kv.second[i]; + const NodeArg* input_arg = p_node->InputDefs()[0]; + if (p_node->InputDefs()[0]->Name() == first_input_arg->Name()) continue; + if (!graph.IsInputsIncludingInitializers(input_arg)) { + // NOTE(review): assumes any input that is not a graph input/initializer has a producer edge to remove - confirm. + const Node::EdgeEnd& input_edge = *p_node->InputEdgesBegin(); + graph.RemoveEdge(input_edge.GetNode().Index(), p_node->Index(), input_edge.GetSrcArgIndex(), 0); + } + graph_utils::ReplaceNodeInput(*p_node, 0, *first_input_arg); + if (!is_first_input_arg_graph_input) { + const Node::EdgeEnd& first_input_edge = *kv.second[0]->InputEdgesBegin(); + graph.AddEdge(first_input_edge.GetNode().Index(), p_node->Index(), first_input_edge.GetSrcArgIndex(), 0); + } + modified = true; + } + } + + return Status::OK(); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/shape_input_merge.h b/onnxruntime/core/optimizer/shape_input_merge.h new file mode 100644 index 0000000000000..5cb943998487b --- /dev/null +++ b/onnxruntime/core/optimizer/shape_input_merge.h @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License.
+ +#pragma once + +#include "core/optimizer/graph_transformer.h" + +namespace onnxruntime { + +/** +@class ShapeInputMerge +Rewires Shape nodes whose inputs have the same statically known (concrete or symbolic) shape so that they all read a single input. +This does not change performance by itself, but it opens up opportunities for CSE (common subexpression elimination) to merge the Shape nodes. +*/ +class ShapeInputMerge : public GraphTransformer { + public: + ShapeInputMerge(const InlinedHashSet& compatible_execution_providers = {}) noexcept + : GraphTransformer("ShapeInputMerge", compatible_execution_providers) {} + + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/stft_decomposition.cc b/onnxruntime/core/optimizer/stft_decomposition.cc new file mode 100644 index 0000000000000..a54904ff15e1e --- /dev/null +++ b/onnxruntime/core/optimizer/stft_decomposition.cc @@ -0,0 +1,381 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License.
+ +#include + +#include "core/optimizer/stft_decomposition.h" +#include "core/optimizer/initializer.h" +#include "core/optimizer/utils.h" +#include "core/graph/graph_utils.h" +#include "core/optimizer/optimizer_execution_frame.h" +#include "core/optimizer/utils.h" +#include "core/framework/op_kernel.h" +#include "core/framework/tensorprotoutils.h" + +using namespace onnxruntime::common; + +namespace onnxruntime { + +STFTDecomposition::STFTDecomposition(const InlinedHashSet& compatible_execution_providers) noexcept + : GraphTransformer("STFTDecomposition", compatible_execution_providers) { +} + +/* Maps the template element type to its ONNX tensor data type (FLOAT, FLOAT16, DOUBLE or INT64); any other type reaches the throwing branch. */ +template +constexpr static ONNX_NAMESPACE::TensorProto_DataType GetDataType() { + if constexpr (std::is_same::value) { + return ONNX_NAMESPACE::TensorProto_DataType_FLOAT; + } else if constexpr (std::is_same::value) { + return ONNX_NAMESPACE::TensorProto_DataType_FLOAT16; + } else if constexpr (std::is_same::value) { + return ONNX_NAMESPACE::TensorProto_DataType_DOUBLE; + } else if constexpr (std::is_same::value) { + return ONNX_NAMESPACE::TensorProto_DataType_INT64; + } else { + throw std::logic_error("Invalid data type requested for STFT decomposition"); + } +} + +/* Adds an initializer to the graph under a generated name based on `name`, with dims `shape`, copying the element data from `begin` as raw bytes. Returns the NodeArg of the new initializer. */ +template +NodeArg* AddInitializer(Graph& graph, const char* name, const int64_t (&shape)[TDims], const TDataType* begin) { + ONNX_NAMESPACE::TensorProto proto; + proto.set_name(graph.GenerateNodeArgName(name)); + proto.set_data_type(GetDataType()); + int64_t element_count = 1; + for (size_t i = 0; i < TDims; i++) { + element_count *= shape[i]; + proto.add_dims(shape[i]); + } + proto.set_raw_data(begin, element_count * sizeof(TDataType)); + return &graph_utils::AddInitializer(graph, proto); +} + +/* Stores the values of `shape` themselves as a 1-D int64 initializer of TDims elements, e.g. to feed the shape input of a Reshape. */ +template +NodeArg* AddShapeInitializer(Graph& graph, const char* name, const int64_t (&shape)[TDims]) { + int64_t shape_shape[] = {TDims}; + return AddInitializer(graph, name, shape_shape, shape); +} + +/* Adds a node of `op_type` with the given inputs and a single, freshly named output, assigns it to `execution_provider_type`, and returns the node together with its output arg. */ +std::pair AddNode(Graph& graph, + const char* op_type, + ProviderType execution_provider_type, + gsl::span inputs) { +
auto def_name = graph.GenerateNodeArgName(op_type); + auto node_arg = &graph.GetOrCreateNodeArg(def_name, nullptr); + Node& node = graph.AddNode(graph.GenerateNodeName(op_type), + op_type, + "", + inputs, + {node_arg}); + node.SetExecutionProviderType(execution_provider_type); + return std::make_pair(&node, node_arg); +} + +/* Adds a Cast of `in` to `data_type` (set through the "to" attribute). The Cast node is always assigned to the CPU execution provider. Returns the node and its output arg. */ +std::pair AddNodeCast(Graph& graph, NodeArg* in, + ONNX_NAMESPACE::TensorProto_DataType data_type) { + auto def_name = graph.GenerateNodeArgName("Cast"); + auto node_arg = &graph.GetOrCreateNodeArg(def_name, nullptr); + Node& node = graph.AddNode(graph.GenerateNodeName("Cast"), + "Cast", + "", + {in}, + {node_arg}); + node.AddAttribute("to", static_cast(data_type)); + node.SetExecutionProviderType(kCpuExecutionProvider); + return std::make_pair(&node, node_arg); +} + +/* Loop helpers: skip the current iteration when `dim` has no concrete value, or when `x` is null. They expand to `continue`, so they are only valid inside a loop body. */ +#define CONTINUE_IF_NO_DIM_VALUE(dim) \ + if (!dim.has_dim_value()) { \ + continue; \ + } +#define CONTINUE_IF_NULL(x) \ + if (x == nullptr) { \ + continue; \ + } + +/* + This function decomposes a STFT node into a subgraph. + The decomposition requires that: + 1) The signal input is real valued and not complex valued! + 2) Both (frame_step) *and* either (window or frame_length) inputs must be constant. + Otherwise the transform will not be applied.
+ + Subgraph pattern 1: STFT with optional Window parameter set + [root]--(signal)--------------------+ + [root]--(frame_step)---------------+| + [root]--(window)------------------+|| + [root]--(frame_length) ----------+||| + |||| + vvvv + [STFT]--(output)--> + After Fusion: + [root]--(signal)-------------------------+ + [root] | + [root]--(window)--+ | + [root] | | + v v + (only for non-fp32) [Cast] +--[Reshape] + | | | + v | v + [Reshape]-->[Mul]---|-->[Conv]-------+ + | | | + | +-----| | + | v v + +------>[Mul]------>[Conv]-->[Concat]-->[Reshape]-->[Transpose]--(output)--> + + + Subgraph pattern 2: STFT without optional Window parameter set + [root]--(signal)-------------------+ + [root]--(frame_step)--------------+| + [root] | + [root]--(frame_length) ----------+|| + ||| + vvv + [STFT]--(output)--> + After Fusion: + [root]--(signal)-->[Reshape]-->[Conv] + [root] | | + [root] | v + [root] +------>[Conv]-->[Concat]-->[Reshape]-->[Transpose]--(output)--> +*/ +Status STFTDecomposition::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { + GraphViewer graph_viewer(graph); + auto& order = graph_viewer.GetNodesInTopologicalOrder(); + + for (NodeIndex i : order) { + auto node = graph.GetNode(i); + CONTINUE_IF_NULL(node); + ORT_RETURN_IF_ERROR(Recurse(*node, modified, graph_level, logger)); + + if (node->OpType() != "STFT") { + continue; + } + + Node& stft = *node; + // Inputs 2 (window) and 3 (frame_length) are optional and may be omitted from the node entirely; skip such + // nodes rather than index past the end of the input list. + if (stft.MutableInputDefs().size() < 4) { + continue; + } + auto signal = stft.MutableInputDefs()[0]; + auto frame_step = stft.MutableInputDefs()[1]; + auto window = stft.MutableInputDefs()[2]; + auto frame_length = stft.MutableInputDefs()[3]; + + // Without shape information (Shape() returns null) or with fewer than 3 dims the checks below cannot be made. + CONTINUE_IF_NULL(signal->Shape()); + if (signal->Shape()->dim_size() < 3) { + continue; + } + + // If the signal has free dimensions, do not transform... + auto batch_size_dim = signal->Shape()->dim(0); + auto signal_length_dim = signal->Shape()->dim(1); + auto signal_components_dim = signal->Shape()->dim(2); + CONTINUE_IF_NO_DIM_VALUE(signal_length_dim); + CONTINUE_IF_NO_DIM_VALUE(signal_components_dim); + + auto batch_size = batch_size_dim.has_dim_value() ?
batch_size_dim.dim_value() : static_cast(-1); + auto signal_length = signal_length_dim.dim_value(); + auto is_real = signal_components_dim.dim_value() == 1; + auto data_type = static_cast(signal->TypeAsProto()->tensor_type().elem_type()); + + auto frame_step_initializer = graph_utils::GetConstantInitializer(graph, frame_step->Name()); + auto window_initializer = graph_utils::GetConstantInitializer(graph, window->Name()); + auto frame_length_initializer = graph_utils::GetConstantInitializer(graph, frame_length->Name()); + CONTINUE_IF_NULL(frame_step_initializer); + if (!frame_length_initializer && !window_initializer) { + continue; + } + + auto read_int64_initializer = [](Graph& graph, const ONNX_NAMESPACE::TensorProto* initializer) { + return *Initializer(*initializer, graph.ModelPath()).data(); + }; + auto frame_step_value = read_int64_initializer(graph, frame_step_initializer); + + // Get DFT Size + int64_t dft_size = 0; + if (frame_length_initializer) { + dft_size = read_int64_initializer(graph, frame_length_initializer); + } + if (dft_size == 0 && window_initializer) { + auto window_length_dim = window->Shape()->dim(0); + CONTINUE_IF_NO_DIM_VALUE(window_length_dim); + dft_size = window_length_dim.dim_value(); + } + + bool is_onesided = true; + auto& attrs = stft.GetAttributes(); + if (attrs.find("onesided") != attrs.end()) { + auto& onesided_attr = attrs.at("onesided"); + if (utils::HasInt(onesided_attr)) { + is_onesided = static_cast(onesided_attr.i()); + } + } + + auto dft_unique_bins = is_onesided ?
((dft_size >> 1) + 1) : dft_size; + + Node* signal_recipient = nullptr; + Node* window_recipient = nullptr; + Node* stft_producer = nullptr; + if (is_real) { + // The output dims are read right below; skip (before anything is added to the graph) when they are unavailable. + auto output_shape = stft.MutableOutputDefs()[0]->Shape(); + if (output_shape == nullptr || output_shape->dim_size() < 3) { + continue; + } + auto output_num_frames = stft.MutableOutputDefs()[0]->Shape()->dim(1).dim_value(); + auto output_frame_length = stft.MutableOutputDefs()[0]->Shape()->dim(2).dim_value(); + auto weight_size = static_cast(dft_unique_bins * dft_size); + auto real_weights_data = std::vector(weight_size); + auto imag_weights_data = std::vector(weight_size); + + // Populate weights + for (size_t k = 0; k < static_cast(dft_unique_bins); k++) { + for (size_t n = 0; n < static_cast(dft_size); n++) { + auto index = static_cast(k * dft_size + n); + auto theta = -2 * M_PI * k * n / static_cast(dft_size); + real_weights_data[index] = static_cast(cos(theta)); + imag_weights_data[index] = static_cast(sin(theta)); + } + } + + const int64_t weight_shape[] = {dft_unique_bins, 1, 1, dft_size}; + auto real_weights = AddInitializer(graph, "stft_real_conv_weights", weight_shape, real_weights_data.data()); + auto imaginary_weights = AddInitializer(graph, "stft_imaginary_conv_weights", weight_shape, imag_weights_data.data()); + + const int64_t signal_reshaped[] = {batch_size, 1, 1, signal_length}; + auto signal_shape = AddShapeInitializer(graph, "stft_signal_shape", signal_reshaped); + + const int64_t unsqueezed_output_shape[] = {2, batch_size, output_frame_length, output_num_frames}; + auto unsqueezed_shape = AddShapeInitializer(graph, "stft_output_reshaped", unsqueezed_output_shape); + + NodeArg* signal_reshaped_inputs[] = {signal, signal_shape}; + Node* reshape_signal_node = nullptr; + NodeArg* reshape_output = nullptr; + std::tie(reshape_signal_node, reshape_output) = + AddNode(graph, "Reshape", stft.GetExecutionProviderType(), signal_reshaped_inputs); + + NodeArg* real_weights_final = real_weights; + NodeArg* imag_weights_final = imaginary_weights; + if (!window->Exists()) { + // When we are missing a window function + if
(real_weights_final->TypeAsProto()->tensor_type().elem_type() != data_type) { + std::tie(std::ignore, real_weights_final) = + AddNodeCast(graph, real_weights_final, data_type); + } + if (imag_weights_final->TypeAsProto()->tensor_type().elem_type() != data_type) { + std::tie(std::ignore, imag_weights_final) = + AddNodeCast(graph, imag_weights_final, data_type); + } + } else { + // When we have a window function + const int64_t window_reshaped_shape[] = {1, 1, 1, dft_size}; + auto window_shape = AddShapeInitializer(graph, "stft_window_shape", window_reshaped_shape); + + auto window_final = window; + if (window->TypeAsProto()->tensor_type().elem_type() != GetDataType()) { + Node* window_cast_node = nullptr; + std::tie(window_cast_node, window_final) = + AddNodeCast(graph, window, GetDataType()); + window_recipient = window_cast_node; + } + + NodeArg* window_reshaped_inputs[] = {window_final, window_shape}; + Node* window_reshape_node; + NodeArg* window_reshaped = nullptr; + std::tie(window_reshape_node, window_reshaped) = + AddNode(graph, "Reshape", kCpuExecutionProvider, window_reshaped_inputs); + if (!window_recipient) { + window_recipient = window_reshape_node; + } + + NodeArg* scale_real_weights_inputs[] = {real_weights, window_reshaped}; + NodeArg* windowed_real_weights_output = nullptr; + std::tie(std::ignore, windowed_real_weights_output) = + AddNode(graph, "Mul", kCpuExecutionProvider, scale_real_weights_inputs); + + NodeArg* scale_imag_weights_inputs[] = {imaginary_weights, window_reshaped}; + NodeArg* windowed_imag_weights_output = nullptr; + std::tie(std::ignore, windowed_imag_weights_output) = + AddNode(graph, "Mul", kCpuExecutionProvider, scale_imag_weights_inputs); + + std::tie(std::ignore, real_weights_final) = + AddNodeCast(graph, windowed_real_weights_output, data_type); + std::tie(std::ignore, imag_weights_final) = + AddNodeCast(graph, windowed_imag_weights_output, data_type); + } + + // Add Convolution (reals) + NodeArg* conv_real_inputs[] =
{reshape_output, real_weights_final}; + Node* real_conv_node = nullptr; + NodeArg* real_conv_output = nullptr; + std::tie(real_conv_node, real_conv_output) = + AddNode(graph, "Conv", stft.GetExecutionProviderType(), conv_real_inputs); + real_conv_node->AddAttribute("strides", std::vector{1, frame_step_value}); + + // Add Convolution (imaginary) + NodeArg* conv_imag_inputs[] = {reshape_output, imag_weights_final}; + Node* imag_conv_node = nullptr; + NodeArg* imag_conv_output = nullptr; + std::tie(imag_conv_node, imag_conv_output) = + AddNode(graph, "Conv", stft.GetExecutionProviderType(), conv_imag_inputs); + imag_conv_node->AddAttribute("strides", std::vector{1, frame_step_value}); + + // Concatenate + NodeArg* concatenate_inputs[] = {real_conv_output, imag_conv_output}; + Node* concat_node = nullptr; + NodeArg* concatenated_conv_output = nullptr; + std::tie(concat_node, concatenated_conv_output) = + AddNode(graph, "Concat", stft.GetExecutionProviderType(), concatenate_inputs); + concat_node->AddAttribute("axis", static_cast(0)); + + // Unsqueeze Reshape + NodeArg* unsqueeze_reshape_inputs[] = {concatenated_conv_output, unsqueezed_shape}; + NodeArg* unsqueezed_output = nullptr; + std::tie(std::ignore, unsqueezed_output) = + AddNode(graph, "Reshape", stft.GetExecutionProviderType(), unsqueeze_reshape_inputs); + + // Transpose + NodeArg* transpose_inputs[] = {unsqueezed_output}; + Node* transpose_node = nullptr; + NodeArg* transpose_output = nullptr; + std::tie(transpose_node, transpose_output) = + AddNode(graph, "Transpose", stft.GetExecutionProviderType(), transpose_inputs); + transpose_node->AddAttribute("perm", std::vector{1, 3, 2, 0}); + + signal_recipient = reshape_signal_node; + stft_producer = transpose_node; + } else { + continue; + } + + auto input_edges = graph_utils::GraphEdge::GetNodeInputEdges(stft); + auto output_edges = graph_utils::GraphEdge::GetNodeOutputEdges(stft); + + // Copy inputs + auto signal_target_idx = signal_recipient->Index(); + // window_recipient is only set when a window input exists (pattern 1); do not dereference it otherwise. + auto
window_target_idx = window_recipient ? window_recipient->Index() : NodeIndex{0}; + for (auto cur = input_edges.cbegin(), end = input_edges.cend(); cur != end; ++cur) { + const graph_utils::GraphEdge& edge = *cur; + NodeIndex target_idx = 0; + Node* recipient = nullptr; + switch (cur->dst_arg_index) { + case 0: + target_idx = signal_target_idx; + recipient = signal_recipient; + break; + case 2: + target_idx = window_target_idx; + recipient = window_recipient; + break; + } + + if (!recipient) { + continue; + } + + auto arg_index = graph_utils::GetNodeInputIndexFromInputName(*recipient, edge.arg_name); + graph.AddEdge(edge.src_node, target_idx, edge.src_arg_index, arg_index); + } + + // Copy STFT outputs to stft_producer + stft_producer->MutableOutputDefs() = stft.MutableOutputDefs(); + auto stft_producer_target_idx = stft_producer->Index(); + for (auto cur = output_edges.cbegin(), end = output_edges.cend(); cur != end; ++cur) { + graph.AddEdge(stft_producer_target_idx, cur->dst_node, cur->src_arg_index, cur->dst_arg_index); + } + + graph_utils::GraphEdge::RemoveGraphEdges(graph, input_edges); + graph_utils::GraphEdge::RemoveGraphEdges(graph, output_edges); + graph.RemoveNode(stft.Index()); + + modified = true; + } + return Status::OK(); +} +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/stft_decomposition.h b/onnxruntime/core/optimizer/stft_decomposition.h new file mode 100644 index 0000000000000..cac058474375e --- /dev/null +++ b/onnxruntime/core/optimizer/stft_decomposition.h @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/graph_transformer.h" +#include "core/framework/ort_value.h" +#include +#include "core/framework/execution_provider.h" + +namespace onnxruntime { + +/** +@class STFTDecomposition + +Transformer that traverses the graph top-down and decomposes +STFT into convolution. +*/ +class STFTDecomposition : public GraphTransformer { + public: + /*!
Creates the STFT decomposition transformer. + \param compatible_execution_providers Execution providers forwarded to the GraphTransformer base class. + */ + STFTDecomposition(const InlinedHashSet& compatible_execution_providers = {}) noexcept; + + private: + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc b/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc index d9f08ffe1171e..c532f56b3d3d9 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc @@ -115,7 +115,7 @@ class ApiGraph final : public api::GraphRef { const auto& graph_outputs = graph_.GetOutputs(); graph_outputs_.reserve(graph_outputs.size()); for (const auto* output : graph_outputs) { - graph_outputs_.insert(output->Name()); + graph_outputs_.emplace(output->Name()); } } diff --git a/onnxruntime/core/optimizer/utils.cc b/onnxruntime/core/optimizer/utils.cc index 7c3599a08ec7a..7055882961e17 100644 --- a/onnxruntime/core/optimizer/utils.cc +++ b/onnxruntime/core/optimizer/utils.cc @@ -272,7 +272,7 @@ int32_t IndexOfNodeOutput(const Node& node, const NodeArg& node_arg) { // We could also allow other known domains (kMSDomain, kMSNchwcDomain, kMSFeaturizersDomain), // as long as we verify which of their operations are non-deterministic and add them in the map below. constexpr std::array kOnnxDomainNonDeterministicOps{"RandomUniform", "RandomNormal", "RandomUniformLike", - "RandomNormalLike", "Multinomial"}; + "RandomNormalLike", "Multinomial", "Dropout"}; // List of deterministic MS domain operators. Currently used for constant folding and common subexpression elimination. // @@ -280,7 +280,8 @@ constexpr std::array kOnnxDomainNonDeterministicOps{"RandomUniform", "RandomNorm // with the above ONNX list.
With the current approach, only MS domain Q/DQ operators // (plus ShrunkenGather and ConcatTraining for training) are considered deterministic. #ifdef ENABLE_TRAINING_OPS -constexpr std::array kMSDomainDeterministicOps{"ShrunkenGather", "QuantizeLinear", "DequantizeLinear"}; +constexpr std::array kMSDomainDeterministicOps{"ShrunkenGather", "QuantizeLinear", "DequantizeLinear", + "ConcatTraining"}; #else constexpr std::array kMSDomainDeterministicOps{"QuantizeLinear", "DequantizeLinear"}; #endif diff --git a/onnxruntime/core/platform/windows/env.cc b/onnxruntime/core/platform/windows/env.cc index 1a0713db43db8..983cc6089bb4c 100644 --- a/onnxruntime/core/platform/windows/env.cc +++ b/onnxruntime/core/platform/windows/env.cc @@ -32,6 +32,9 @@ limitations under the License. #include "core/common/span_utils.h" #include "core/platform/env.h" #include "core/platform/scoped_resource.h" +#if defined(_M_X64) && !defined(_M_ARM64EC) && defined(ONNXRUNTIME_ENABLE_INTEL_METEOR_LAKE_MOBILE_PLATFORM_PERF_PATCH) +#include "core/platform/windows/hardware_core_enumerator.h" +#endif #include #include @@ -248,12 +251,53 @@ void WindowsEnv::SleepForMicroseconds(int64_t micros) const { Sleep(static_cast(micros) / 1000); } +// EIGEN_NO_CPUID is not defined in any C/C++ source code. It is a compile option. +#if defined(_M_X64) && !defined(_M_ARM64EC) && !defined(EIGEN_NO_CPUID) && defined(ONNXRUNTIME_ENABLE_INTEL_METEOR_LAKE_MOBILE_PLATFORM_PERF_PATCH) +static constexpr std::array kVendorID_Intel = {0x756e6547, 0x6c65746e, 0x49656e69}; // "GenuntelineI", i.e. "GenuineIntel" in the register order it is compared in (regs[1], regs[2], regs[3]) +#endif int WindowsEnv::DefaultNumCores() { return std::max(1, static_cast(std::thread::hardware_concurrency() / 2)); } +// Returns cores_.size() (or DefaultNumCores() when cores_ is empty). With the Meteor Lake perf patch compiled in, +// hybrid Intel CPUs return HardwareCoreEnumerator::DefaultIntraOpNumThreads() (at least 1) instead. int WindowsEnv::GetNumPhysicalCpuCores() const { - return cores_.empty() ? DefaultNumCores() : static_cast(cores_.size()); +// EIGEN_NO_CPUID is not defined in any C/C++ source code. It is a compile option.
+#if defined(_M_X64) && !defined(_M_ARM64EC) && !defined(EIGEN_NO_CPUID) && defined(ONNXRUNTIME_ENABLE_INTEL_METEOR_LAKE_MOBILE_PLATFORM_PERF_PATCH) + // The following code is a temporary fix for a perf problem on Intel's Meteor Lake CPUs. The Intel compute platform has + // a hybrid architecture in which some CPU cores run significantly slower than the others. If we distribute our compute work + // evenly to all CPU cores, the slowest CPU core will drag the performance down. So, instead, we reduce the total number + // of threads to exclude the slowest cores. + // The following code is based on the assumptions that: + // 1. All Intel hybrid CPUs should have 3 levels of cache. + // 2. If a CPU core is only associated with two levels of cache, it should be a low performance CPU core and should + // not be used. + // Since we don't know what the next Intel hybrid CPU would be like, later on we may need to rework the following code. + // However, no matter what, the code should not cause any crash. In the worst case it returns 1, so that + // thread pools will not be created, which is just a perf issue and does not impact usability. + // TODO: detect if CPUID instruction is available per instructions at https://wiki.osdev.org/CPUID#Checking_CPUID_availability + int regs[4]; + __cpuid(regs, 0); + bool bIsIntel = + (kVendorID_Intel[0] == regs[1]) && + (kVendorID_Intel[1] == regs[2]) && + (kVendorID_Intel[2] == regs[3]); + if (bIsIntel && regs[0] >= 7) { + // Query Structured Extended Feature Flags Enumeration Leaf + __cpuid(regs, 0x7); + // Bit 15 of EDX indicates if the processor is identified as a hybrid part. + bool ishybrid = regs[3] & (1 << 15); + if (ishybrid) { + // NOTE: even if ishybrid is true, it doesn't mean the processor must have P-cores and E-cores. + // On Intel CPUs we assume the HardwareCoreEnumerator::DefaultIntraOpNumThreads function would never fail. + // NOTE: due to resource restrictions, we cannot test this branch in our CI build pipelines.
+ return std::max(static_cast(1), HardwareCoreEnumerator::DefaultIntraOpNumThreads()); + } else { + return cores_.empty() ? DefaultNumCores() : static_cast(cores_.size()); + } + } else +#endif + { + return cores_.empty() ? DefaultNumCores() : static_cast(cores_.size()); + } } std::vector WindowsEnv::GetDefaultThreadAffinities() const { @@ -415,8 +459,8 @@ Status WindowsEnv::MapFileIntoMemory(_In_z_ const ORTCHAR_T* file_path, void* const mapped_base = MapViewOfFile(file_mapping_handle.get(), FILE_MAP_READ, - 0, - static_cast(mapped_offset), + static_cast((mapped_offset >> 32) & 0xFFFFFFFF), + static_cast(mapped_offset & 0xFFFFFFFF), mapped_length); GSL_SUPPRESS(r.11) mapped_memory = diff --git a/onnxruntime/core/platform/windows/hardware_core_enumerator.cc b/onnxruntime/core/platform/windows/hardware_core_enumerator.cc new file mode 100644 index 0000000000000..121c59808ae59 --- /dev/null +++ b/onnxruntime/core/platform/windows/hardware_core_enumerator.cc @@ -0,0 +1,89 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License.
+ +#include "hardware_core_enumerator.h" +#include +#include +#include + +namespace onnxruntime { + +struct LogicalProcessorInformation { + std::unique_ptr Buffer; + size_t Length; +}; + +struct CoreCounter { + uint32_t PhysicalCores = 0; + uint32_t SocDieCores = 0; +}; + +static LogicalProcessorInformation GetLogicalProcessorInfos(LOGICAL_PROCESSOR_RELATIONSHIP relationship) { + DWORD length = 0; + DWORD rc = GetLogicalProcessorInformationEx(relationship, nullptr, &length); + + assert(rc == FALSE); + + auto processorInformationBytes = std::make_unique(length); + + rc = GetLogicalProcessorInformationEx( + relationship, reinterpret_cast(processorInformationBytes.get()), &length); + + assert(rc == TRUE); + + return {std::move(processorInformationBytes), length}; +} + +uint32_t CountSetBits(DWORD input) { + uint32_t c; + for (c = 0; input; c++) { + input &= input - 1; + } + return c; +} + +static CoreCounter GetNumberOPhysicalAndEngineeringCores() { + auto logicalProcessorInformation = GetLogicalProcessorInfos(RelationAll); + + CoreCounter cores; + DWORD dwLevel2GroupMask = 0; + DWORD dwLevel3GroupMask = 0; + size_t read = 0; + PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX currentProcessorInfo = NULL; + + while ((read + FIELD_OFFSET(SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX, Processor)) < logicalProcessorInformation.Length) { + currentProcessorInfo = + reinterpret_cast(logicalProcessorInformation.Buffer.get() + read); + if ((read + currentProcessorInfo->Size) > logicalProcessorInformation.Length) { + break; + } + + switch (currentProcessorInfo->Relationship) { + case RelationProcessorCore: + cores.PhysicalCores++; + break; + case RelationCache: + if (currentProcessorInfo->Cache.Level == 2) { + dwLevel2GroupMask |= currentProcessorInfo->Cache.GroupMask.Mask; + } else if (currentProcessorInfo->Cache.Level == 3) { + dwLevel3GroupMask |= currentProcessorInfo->Cache.GroupMask.Mask; + } + break; + } + + read += currentProcessorInfo->Size; + } + + cores.SocDieCores = 
CountSetBits(dwLevel2GroupMask & ~dwLevel3GroupMask); + return cores; +} + +uint32_t HardwareCoreEnumerator::DefaultIntraOpNumThreads() { + // # of physical cores = # of P cores + # of E Cores + # of Soc Cores. + // # of logical cores = # of P cores x 2 (if hyper threading is enabled) + # of E cores + # of Soc Cores. + auto cores = GetNumberOPhysicalAndEngineeringCores(); + // We want to use the number of physical cores, but exclude soc cores + return cores.PhysicalCores - cores.SocDieCores; +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/platform/windows/hardware_core_enumerator.h b/onnxruntime/core/platform/windows/hardware_core_enumerator.h new file mode 100644 index 0000000000000..93b50f452afcd --- /dev/null +++ b/onnxruntime/core/platform/windows/hardware_core_enumerator.h @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include + +namespace onnxruntime { +struct HardwareCoreEnumerator { + HardwareCoreEnumerator() = delete; + static uint32_t DefaultIntraOpNumThreads(); +}; +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/platform/windows/telemetry.cc b/onnxruntime/core/platform/windows/telemetry.cc index a9849873fd060..654281d526e4d 100644 --- a/onnxruntime/core/platform/windows/telemetry.cc +++ b/onnxruntime/core/platform/windows/telemetry.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. 
#include "core/platform/windows/telemetry.h" +#include "core/platform/ort_mutex.h" #include "core/common/logging/logging.h" #include "onnxruntime_config.h" @@ -63,6 +64,8 @@ bool WindowsTelemetry::enabled_ = true; uint32_t WindowsTelemetry::projection_ = 0; UCHAR WindowsTelemetry::level_ = 0; UINT64 WindowsTelemetry::keyword_ = 0; +std::vector WindowsTelemetry::callbacks_; +OrtMutex WindowsTelemetry::callbacks_mutex_; WindowsTelemetry::WindowsTelemetry() { std::lock_guard lock(mutex_); @@ -104,6 +107,11 @@ UINT64 WindowsTelemetry::Keyword() const { // return etw_status_; // } +void WindowsTelemetry::RegisterInternalCallback(const EtwInternalCallback& callback) { + std::lock_guard lock(callbacks_mutex_); + callbacks_.push_back(callback); +} + void NTAPI WindowsTelemetry::ORT_TL_EtwEnableCallback( _In_ LPCGUID SourceId, _In_ ULONG IsEnabled, @@ -112,15 +120,21 @@ void NTAPI WindowsTelemetry::ORT_TL_EtwEnableCallback( _In_ ULONGLONG MatchAllKeyword, _In_opt_ PEVENT_FILTER_DESCRIPTOR FilterData, _In_opt_ PVOID CallbackContext) { - (void)SourceId; - (void)MatchAllKeyword; - (void)FilterData; - (void)CallbackContext; - std::lock_guard lock(provider_change_mutex_); enabled_ = (IsEnabled != 0); level_ = Level; keyword_ = MatchAnyKeyword; + + InvokeCallbacks(SourceId, IsEnabled, Level, MatchAnyKeyword, MatchAllKeyword, FilterData, CallbackContext); +} + +void WindowsTelemetry::InvokeCallbacks(LPCGUID SourceId, ULONG IsEnabled, UCHAR Level, ULONGLONG MatchAnyKeyword, + ULONGLONG MatchAllKeyword, PEVENT_FILTER_DESCRIPTOR FilterData, + PVOID CallbackContext) { + std::lock_guard lock(callbacks_mutex_); + for (const auto& callback : callbacks_) { + callback(SourceId, IsEnabled, Level, MatchAnyKeyword, MatchAllKeyword, FilterData, CallbackContext); + } } void WindowsTelemetry::EnableTelemetryEvents() const { diff --git a/onnxruntime/core/platform/windows/telemetry.h b/onnxruntime/core/platform/windows/telemetry.h index c3798943d491d..cdb186e9ed703 100644 --- 
a/onnxruntime/core/platform/windows/telemetry.h +++ b/onnxruntime/core/platform/windows/telemetry.h @@ -2,12 +2,14 @@ // Licensed under the MIT License. #pragma once +#include +#include + #include "core/platform/telemetry.h" #include #include #include "core/platform/ort_mutex.h" #include "core/platform/windows/TraceLoggingConfig.h" -#include namespace onnxruntime { @@ -58,16 +60,27 @@ class WindowsTelemetry : public Telemetry { void LogExecutionProviderEvent(LUID* adapterLuid) const override; + using EtwInternalCallback = std::function; + + static void RegisterInternalCallback(const EtwInternalCallback& callback); + private: static OrtMutex mutex_; static uint32_t global_register_count_; static bool enabled_; static uint32_t projection_; + static std::vector callbacks_; + static OrtMutex callbacks_mutex_; static OrtMutex provider_change_mutex_; static UCHAR level_; static ULONGLONG keyword_; + static void InvokeCallbacks(LPCGUID SourceId, ULONG IsEnabled, UCHAR Level, ULONGLONG MatchAnyKeyword, + ULONGLONG MatchAllKeyword, PEVENT_FILTER_DESCRIPTOR FilterData, PVOID CallbackContext); + static void NTAPI ORT_TL_EtwEnableCallback( _In_ LPCGUID SourceId, _In_ ULONG IsEnabled, diff --git a/onnxruntime/core/providers/cann/cann_execution_provider.cc b/onnxruntime/core/providers/cann/cann_execution_provider.cc index 752b742805a7c..9a242919665bb 100644 --- a/onnxruntime/core/providers/cann/cann_execution_provider.cc +++ b/onnxruntime/core/providers/cann/cann_execution_provider.cc @@ -1045,7 +1045,7 @@ CANNExecutionProvider::~CANNExecutionProvider() { } // All threads share the same context and stream -Status CANNExecutionProvider::OnRunStart() { +Status CANNExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { CANN_RETURN_IF_ERROR(aclrtSetDevice(info_.device_id)); return Status::OK(); diff --git a/onnxruntime/core/providers/cann/cann_execution_provider.h b/onnxruntime/core/providers/cann/cann_execution_provider.h index 63ae980869c65..d83bd88d6958f 
100644 --- a/onnxruntime/core/providers/cann/cann_execution_provider.h +++ b/onnxruntime/core/providers/cann/cann_execution_provider.h @@ -33,7 +33,7 @@ class CANNExecutionProvider : public IExecutionProvider { explicit CANNExecutionProvider(const CANNExecutionProviderInfo& info); virtual ~CANNExecutionProvider(); - Status OnRunStart() override; + Status OnRunStart(const onnxruntime::RunOptions& run_options) override; template Status Fill(Tensor* y, void* addr, aclrtStream stream) const { diff --git a/onnxruntime/core/providers/cann/cann_stream_handle.h b/onnxruntime/core/providers/cann/cann_stream_handle.h index 4d03fe5201209..5d822d23f966f 100644 --- a/onnxruntime/core/providers/cann/cann_stream_handle.h +++ b/onnxruntime/core/providers/cann/cann_stream_handle.h @@ -12,6 +12,7 @@ #include "core/providers/cann/cann_call.h" namespace onnxruntime { +void WaitCannNotificationOnDevice(Stream& stream, synchronize::Notification& notification); struct CannStream : Stream { CannStream(aclrtStream stream, const OrtDevice& device, bool own_flag); @@ -23,10 +24,11 @@ struct CannStream : Stream { void Flush() override; bool own_stream_{true}; + + WaitNotificationFn GetWaitNotificationFn() const override { return WaitCannNotificationOnDevice; } }; void RegisterCannStreamHandles(IStreamCommandHandleRegistry& stream_handle_registry, const OrtDevice::DeviceType device_type); -void WaitCannNotificationOnDevice(Stream& stream, synchronize::Notification& notification); } // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/builders/coreml_spec.h b/onnxruntime/core/providers/coreml/builders/coreml_spec.h index 631bb7e258303..9448f1167990e 100644 --- a/onnxruntime/core/providers/coreml/builders/coreml_spec.h +++ b/onnxruntime/core/providers/coreml/builders/coreml_spec.h @@ -3,12 +3,33 @@ #pragma once -// TODO come up with a more intuitive way of limiting this to Apple platform builds -// E.g., putting CoreML EP files that should be enabled iff `defined(__APPLE__)` 
in a separate directory. -#if !defined(__APPLE__) -#error "This file should only be included when building on Apple platforms." +#include "onnxruntime_config.h" + +#if defined(__GNUC__) +#pragma GCC diagnostic push + +// Disable warning from protobuf code. +// +// In file included from coreml_proto/Model.pb.h:30: +// In file included from _deps/protobuf-src/src/google/protobuf/extension_set.h:53: +// _deps/protobuf-src/src/google/protobuf/parse_context.h:328:47: +// error: implicit conversion loses integer precision: 'long' to 'int' [-Werror,-Wshorten-64-to-32] +#ifdef HAS_SHORTEN_64_TO_32 +#pragma GCC diagnostic ignored "-Wshorten-64-to-32" +#endif +#elif defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4244) // conversion from long to int #endif -#include "coreml/Model.pb.h" +// Model.pb.h is generated in the build output directory from the CoreML protobuf files in +// /_deps/coremltools-src/mlmodel/format +#include "coreml_proto/Model.pb.h" + +#if defined(__GNUC__) +#pragma GCC diagnostic pop +#elif defined(_MSC_VER) +#pragma warning(pop) +#endif namespace COREML_SPEC = CoreML::Specification; diff --git a/onnxruntime/core/providers/coreml/builders/helper.cc b/onnxruntime/core/providers/coreml/builders/helper.cc index 897856256cc79..b8ebbd05a2a20 100644 --- a/onnxruntime/core/providers/coreml/builders/helper.cc +++ b/onnxruntime/core/providers/coreml/builders/helper.cc @@ -22,22 +22,35 @@ namespace onnxruntime { namespace coreml { -OpBuilderInputParams MakeOpBuilderParams(const GraphViewer& graph_viewer, uint32_t coreml_flags) { +OpBuilderInputParams MakeOpBuilderParams(const GraphViewer& graph_viewer, + int32_t coreml_version, + uint32_t coreml_flags) { return OpBuilderInputParams{graph_viewer, - (coreml_flags & COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES) != 0}; + coreml_version, + (coreml_flags & COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES) != 0, + (coreml_flags & COREML_FLAG_CREATE_MLPROGRAM) != 0}; } -bool IsNodeSupported(const Node& node, 
const OpBuilderInputParams& input_params, const logging::Logger& logger) { +const IOpBuilder* GetOpBuilder(const Node& node) { const auto& op_builders = GetOpBuilders(); - if (Contains(op_builders, node.OpType())) { - const auto* op_builder = op_builders.at(node.OpType()); + const auto it = op_builders.find(node.OpType()); + if (it != op_builders.cend()) { + return it->second; + } + + return nullptr; +} + +bool IsNodeSupported(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) { + const auto* op_builder = GetOpBuilder(node); + if (op_builder) { return op_builder->IsOpSupported(node, input_params, logger); } else { return false; } } -bool IsInputSupported(const NodeArg& input, const std::string& parent_name, +bool IsInputSupported(const Node& node, const NodeArg& input, const OpBuilderInputParams& input_params, const logging::Logger& logger) { if (!input.Exists()) { // optional input that is not provided @@ -48,8 +61,8 @@ bool IsInputSupported(const NodeArg& input, const std::string& parent_name, std::vector shape; // We do not support input with no shape if (!GetShape(input, shape, logger)) { - LOGS(logger, VERBOSE) << "Input [" << input_name << "] of [" << parent_name - << "] has no shape"; + LOGS(logger, VERBOSE) << MakeString("Input [", input_name, "] of Node [", node.Name(), "] type [", node.OpType(), + "] has no shape"); return false; } @@ -63,11 +76,25 @@ bool IsInputSupported(const NodeArg& input, const std::string& parent_name, // For some undocumented reason, Apple CoreML framework will fail loading the model if the model // input has dimension > 16384 // See this issue, https://github.com/apple/coremltools/issues/1003 + // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf has maximum texture widths which may be the + // root cause. if (dim > 16384) { LOGS(logger, WARNING) << "CoreML does not support input dim > 16384. 
Input:" << input_name << ", shape: " << Shape2String(shape); return false; } + + if (dim == 0) { + if (node.OpType() == "Resize" && &input == node.InputDefs()[1]) { + // one special case. Resize 'roi' input was originally a required input but is rarely used. + // ROI is not supported in the CoreML implementation so we will ignore the value, but is often added + // (at least in the unit tests) as an initializer with shape {0}. + } else { + LOGS(logger, WARNING) << "CoreML does not support shapes with dimension values of 0. Input:" << input_name + << ", shape: " << Shape2String(shape); + return false; + } + } } // Limit input shape rank to 5. @@ -87,13 +114,6 @@ std::unordered_set GetSupportedNodes(const GraphViewer& graph_viewe const logging::Logger& logger) { std::unordered_set supported_nodes{}; -#ifdef __APPLE__ - if (!util::HasRequiredBaseOS()) { - LOGS(logger, WARNING) << "All ops will fallback to CPU EP, because we do not have supported OS"; - return supported_nodes; - } -#endif - for (const auto& node : graph_viewer.Nodes()) { const bool supported = IsNodeSupported(node, input_params, logger); LOGS(logger, VERBOSE) << "Operator type: [" << node.OpType() @@ -111,7 +131,7 @@ std::unordered_set GetSupportedNodes(const GraphViewer& graph_viewe bool CheckIsConstantInitializer(const NodeArg& node_arg, const GraphViewer& graph_viewer, const logging::Logger& logger, std::string_view input_description) { - if (graph_viewer.GetConstantInitializer(node_arg.Name(), true) == nullptr) { + if (graph_viewer.GetConstantInitializer(node_arg.Name()) == nullptr) { LOGS(logger, VERBOSE) << input_description << " (NodeArg name: '" << node_arg.Name() << "') is not a constant initializer tensor"; return false; @@ -149,7 +169,9 @@ bool HasNeuralEngine(const logging::Logger& logger) { #else // In this case, we are running the EP on non-apple platform, which means we are running the model // conversion with CoreML EP enabled, for this we always assume the target system has Neural 
Engine - LOGS(logger, VERBOSE) << "HasNeuralEngine running on non-Apple hardware for model conversion only"; + LOGS(logger, INFO) << "HasNeuralEngine running on non-Apple hardware. " + "Returning true to enable model conversion and local testing of CoreML EP implementation. " + "No CoreML model will be compiled or run."; has_neural_engine = true; #endif // #ifdef __APPLE__ diff --git a/onnxruntime/core/providers/coreml/builders/helper.h b/onnxruntime/core/providers/coreml/builders/helper.h index d8b27ac76ae73..300de2dedd122 100644 --- a/onnxruntime/core/providers/coreml/builders/helper.h +++ b/onnxruntime/core/providers/coreml/builders/helper.h @@ -23,10 +23,14 @@ class Logger; namespace coreml { -OpBuilderInputParams MakeOpBuilderParams(const GraphViewer& graph_viewer, uint32_t coreml_flags); +OpBuilderInputParams MakeOpBuilderParams(const GraphViewer& graph_viewer, + int32_t coreml_version, + uint32_t coreml_flags); -bool IsInputSupported(const NodeArg& node_arg, const std::string& parent_name, - const OpBuilderInputParams& input_params, const logging::Logger& logger); +const IOpBuilder* GetOpBuilder(const Node& node); + +bool IsInputSupported(const Node& node, const NodeArg& node_arg, const OpBuilderInputParams& input_params, + const logging::Logger& logger); bool IsNodeSupported(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger); diff --git a/onnxruntime/core/providers/coreml/builders/impl/LRN_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/LRN_op_builder.cc index 53f18b205880c..e9e520156576e 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/LRN_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/LRN_op_builder.cc @@ -3,39 +3,26 @@ #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include 
"core/providers/coreml/shape_utils.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime { namespace coreml { class LRNOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; }; -// Add operator related - -#ifdef __APPLE__ - Status LRNOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, - const logging::Logger& logger) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + const logging::Logger& /*logger*/) const { + std::unique_ptr layer = model_builder.CreateNNLayer(node); auto* coreml_lrn = layer->mutable_lrn(); @@ -56,9 +43,6 @@ Status LRNOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif - -// Operator support related bool LRNOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /*input_params*/, const logging::Logger& logger) const { diff --git a/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc index 88d6616b4e097..dee87ce3632a8 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc @@ -2,44 +2,32 @@ // Licensed under the MIT License. 
#include "core/common/narrow.h" +#include "core/framework/tensorprotoutils.h" #include "core/optimizer/initializer.h" #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/impl/builder_utils.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/framework/tensorprotoutils.h" -#include "core/providers/coreml/builders/impl/builder_utils.h" -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime { namespace coreml { class ActivationOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - public: void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; + int GetMinSupportedOpSet(const Node& node) const override; }; -// Add operator related - -#ifdef __APPLE__ void ActivationOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { const auto& op_type = node.OpType(); const auto& input_defs = node.InputDefs(); @@ -86,7 +74,7 @@ Status AddPReluWeight(ModelBuilder& model_builder, const Node& node, Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); const auto& op_type(node.OpType()); if (op_type == "Sigmoid") { 
@@ -115,14 +103,10 @@ Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif - -// Operator support related namespace { // assumes that node.OpType() == "PRelu" -bool IsPReluOpSupported(const Node& node, const OpBuilderInputParams& input_params, - const logging::Logger& logger) { +bool IsPReluOpSupported(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) { const auto& input_defs = node.InputDefs(); // X input rank must be 3 or 4 diff --git a/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc index 7a5d4a5af673b..e9a8176c8349b 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc @@ -1,37 +1,26 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ +#include "core/providers/coreml/builders/impl/base_op_builder.h" #include "core/providers/coreml/builders/model_builder.h" -#endif #include "core/providers/coreml/builders/op_builder_factory.h" - -#include "base_op_builder.h" +#include "core/providers/shared/utils/utils.h" namespace onnxruntime { namespace coreml { class ArgMaxOpBuilder : public BaseOpBuilder { - // Add operator related - private: -#ifdef __APPLE__ Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; }; -// Add operator related - -#ifdef __APPLE__ Status ArgMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& /* logger */) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); const auto& graph_viewer = model_builder.GetGraphViewer(); NodeAttrHelper helper(node); @@ -67,9 +56,6 @@ Status ArgMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif - -// Operator support related bool ArgMaxOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /*input_params*/, const logging::Logger& logger) const { diff --git a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc index 25d5bad14ceb6..83a572f4b60fa 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc @@ -1,21 +1,18 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#include "core/providers/coreml/builders/impl/base_op_builder.h" - #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" +#include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/providers/coreml/builders/model_builder.h" -#endif +using namespace CoreML::Specification; namespace onnxruntime { namespace coreml { -// Shared functions - +namespace { // TODO, move this to shared_library bool HasExternalInitializer(const InitializedTensorSet& initializers, const Node& node, const logging::Logger& logger) { @@ -37,93 +34,83 @@ bool HasExternalInitializer(const InitializedTensorSet& initializers, const Node return false; } +} // namespace -// Add operator related -#ifdef __APPLE__ Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const Node& node, - const OpBuilderInputParams& input_params, const logging::Logger& logger) const { - ORT_RETURN_IF_NOT( - IsOpSupported(node, input_params, logger), - "Unsupported operator ", - node.OpType()); - - ORT_RETURN_IF_ERROR(AddToModelBuilderImpl(model_builder, node, logger)); - LOGS(logger, VERBOSE) << "Operator name: [" << node.Name() - << "] type: [" << node.OpType() << "] was added"; - return Status::OK(); -} + Status status = AddToModelBuilderImpl(model_builder, node, logger); -/* static */ std::unique_ptr -BaseOpBuilder::CreateNNLayer(ModelBuilder& model_builder, const Node& node) { - auto layer_name = node.Name(); - if (layer_name.empty()) { - // CoreML requires layer has a name, while the node name is optional in ONNX - // In this case, create a unique name for the layer - layer_name = model_builder.GetUniqueName(MakeString("Node_", node.Index(), "_type_", node.OpType())); + if (status.IsOK()) { + LOGS(logger, VERBOSE) << "Operator name: [" << node.Name() << "] type: [" << node.OpType() << "] was added"; } - return 
CreateNNLayer(layer_name); -} -/* static */ std::unique_ptr -BaseOpBuilder::CreateNNLayer(const std::string& layer_name) { - std::unique_ptr layer = std::make_unique(); - layer->set_name(layer_name); - return layer; + return status; } -#endif - -// Operator support related bool BaseOpBuilder::IsOpSupported(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { - if (!HasSupportedInputs(node, input_params, logger)) + if (input_params.create_mlprogram && !SupportsMLProgram()) { + LOGS(logger, VERBOSE) << "Operator [" << node.OpType() << "] does not support MLProgram"; return false; + } - // We do not support external initializers for now - const auto& initializers = input_params.graph_viewer.GetAllInitializedTensors(); - if (HasExternalInitializer(initializers, node, logger)) + if (!HasSupportedOpSet(node, logger)) { return false; + } - if (!HasSupportedOpSet(node, logger)) + if (!HasSupportedInputs(node, input_params, logger)) { return false; + } + + // We do not support external initializers for now + const auto& initializers = input_params.graph_viewer.GetAllInitializedTensors(); + if (HasExternalInitializer(initializers, node, logger)) { + return false; + } return IsOpSupportedImpl(node, input_params, logger); } bool BaseOpBuilder::HasSupportedInputs(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { - const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]"); for (const auto* input : node.InputDefs()) { - if (!IsInputSupported(*input, node_name, input_params, logger)) { + if (!IsInputSupported(node, *input, input_params, logger)) { return false; } } - return HasSupportedInputsImpl(node, logger); + return HasSupportedInputsImpl(node, input_params, logger); } -bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const { - // We only check the type of input 0 by default - // specific op builder can 
override this - const auto& input = *node.InputDefs()[0]; - - int32_t input_type; - if (!GetType(input, input_type, logger)) +/* static */ +bool BaseOpBuilder::IsInputFloat(const Node& node, size_t idx, const OpBuilderInputParams& /*input_params*/, + const logging::Logger& logger) { + if (idx >= node.InputDefs().size()) { + LOGS(logger, VERBOSE) << "Input index [" << idx << "] is out of range"; return false; + } + + const auto& input = *node.InputDefs()[idx]; - if (input_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { - LOGS(logger, VERBOSE) << "[" << node.OpType() - << "] Input type: [" << input_type - << "] is not supported for now"; + int32_t input_type = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED; + + // currently only float is supported + if (!GetType(input, input_type, logger) || input_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + LOGS(logger, VERBOSE) << "[" << node.OpType() << "] Input type: [" << input_type << "] is not currently supported"; return false; } return true; } -bool BaseOpBuilder::HasSupportedOpSet(const Node& node, - const logging::Logger& logger) const { +bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const { + // We only check the type of input 0 by default + // specific op builder can override this + return IsInputFloat(node, 0, input_params, logger); +} + +bool BaseOpBuilder::HasSupportedOpSet(const Node& node, const logging::Logger& logger) const { auto since_version = node.SinceVersion(); if (since_version < GetMinSupportedOpSet(node) || since_version > GetMaxSupportedOpSet(node)) { LOGS(logger, VERBOSE) << node.OpType() << "is only supported for opset [" diff --git a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h index b4132d3b770ec..63f0b813d654c 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h +++ 
b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h @@ -3,11 +3,9 @@ #pragma once -#include "core/providers/coreml/builders/op_builder.h" - -#ifdef __APPLE__ +#include "core/common/span_utils.h" #include "core/providers/coreml/builders/coreml_spec.h" -#endif +#include "core/providers/coreml/builders/op_builder.h" namespace onnxruntime { namespace coreml { @@ -18,45 +16,40 @@ class BaseOpBuilder : public IOpBuilder { public: virtual ~BaseOpBuilder() = default; - // Add operator related + // does the operator implementation support creating an ML Program + bool SupportsMLProgram() const override { return false; } + + bool IsOpSupported(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const override final; -#ifdef __APPLE__ - public: - virtual void AddInitializersToSkip(ModelBuilder& /* model_builder */, const Node& /* node */) const override {} Status AddToModelBuilder(ModelBuilder& model_builder, const Node& node, - const OpBuilderInputParams& input_params, const logging::Logger& logger) const override final; - protected: - virtual Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, - const logging::Logger& logger) const = 0; - - static std::unique_ptr - CreateNNLayer(ModelBuilder& model_builder, const Node& node); - - static std::unique_ptr CreateNNLayer(const std::string& layer_name); -#endif - - // Operator support related - public: - bool IsOpSupported(const Node& node, const OpBuilderInputParams& input_params, - const logging::Logger& logger) const override final; + void AddInitializersToSkip(ModelBuilder& /*model_builder*/, const Node& /*node*/) const override {} protected: - virtual bool IsOpSupportedImpl(const Node& /* node */, const OpBuilderInputParams& /* input_params */, - const logging::Logger& /* logger */) const { + // currently we only support float + static bool IsInputFloat(const Node& node, size_t idx, const OpBuilderInputParams& input_params, + const 
logging::Logger& logger); + + private: + virtual bool IsOpSupportedImpl(const Node& /*node*/, const OpBuilderInputParams& /*input_params*/, + const logging::Logger& /*logger*/) const { return true; } - virtual bool HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const; + virtual bool HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const; - virtual int GetMinSupportedOpSet(const Node& /* node */) const { return 1; } - virtual int GetMaxSupportedOpSet(const Node& /* node */) const { return 20; } + virtual int GetMinSupportedOpSet(const Node& /*node*/) const { return 1; } + virtual int GetMaxSupportedOpSet(const Node& /*node*/) const { return 20; } - private: bool HasSupportedOpSet(const Node& node, const logging::Logger& logger) const; bool HasSupportedInputs(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const; + + virtual Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const = 0; }; } // namespace coreml diff --git a/onnxruntime/core/providers/coreml/builders/impl/batch_norm_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/batch_norm_op_builder.cc index 391b02eaec497..8da58f659acf1 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/batch_norm_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/batch_norm_op_builder.cc @@ -5,30 +5,20 @@ #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" #include "core/providers/coreml/builders/impl/builder_utils.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime { 
namespace coreml { class BatchNormalizationOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - public: void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; @@ -36,9 +26,6 @@ class BatchNormalizationOpBuilder : public BaseOpBuilder { int GetMinSupportedOpSet(const Node& /* node */) const override { return 7; } }; -// Add operator related - -#ifdef __APPLE__ void BatchNormalizationOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { // skip everything except input0 for BatchNormalization const auto& input_defs = node.InputDefs(); @@ -48,10 +35,9 @@ void BatchNormalizationOpBuilder::AddInitializersToSkip(ModelBuilder& model_buil model_builder.AddInitializerToSkip(input_defs[4]->Name()); // var } -Status BatchNormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, - const Node& node, +Status BatchNormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& /* logger */) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); const auto& input_defs = node.InputDefs(); const auto& initializers(model_builder.GetInitializerTensors()); @@ -81,9 +67,6 @@ Status BatchNormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_bu model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif - -// Operator support related bool BatchNormalizationOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { diff --git 
a/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc index 10c9b32d03f37..fb8e07633621f 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc @@ -1,35 +1,31 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "core/framework/tensorprotoutils.h" #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" +#include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/impl/builder_utils.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/framework/tensorprotoutils.h" -#include "core/providers/coreml/builders/model_builder.h" -#endif - -#include "base_op_builder.h" namespace onnxruntime { namespace coreml { - class BinaryOpBuilder : public BaseOpBuilder { - // Add operator related - private: -#ifdef __APPLE__ Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related + int GetMinSupportedOpSet(const Node& node) const override; - bool HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const override; + + bool SupportsMLProgram() const override { return true; } }; -#ifdef __APPLE__ -static bool CheckIfBothInputShapesMatch(const Node& node, const logging::Logger& logger) { +namespace { +bool CheckIfBothInputShapesMatch(const Node& node, const logging::Logger& logger) { const auto& input_defs = node.InputDefs(); const auto* x_shape_proto = 
input_defs[0]->Shape(); @@ -57,78 +53,94 @@ static bool CheckIfBothInputShapesMatch(const Node& node, const logging::Logger& y_shape_proto->dim().begin(), y_shape_proto->dim().end(), dim_eq); } - -// Add operator related +} // namespace Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { const auto& op_type(node.OpType()); const auto& input_defs(node.InputDefs()); - std::unique_ptr layer = CreateNNLayer(model_builder, node); - - if (op_type == "Add") { - // original mutable_add() has limited broadcasting support - // updated to use CoreML::AddBroadcastableLayerParams which has more general broadcasting support - if (CheckIfBothInputShapesMatch(node, logger)) { - layer->mutable_add(); +#if defined(COREML_ENABLE_MLPROGRAM) + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; + + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#module-coremltools.converters.mil.mil.ops.defs.iOS15.elementwise_binary + std::string_view coreml_op_type; + if (op_type == "Add") { + coreml_op_type = "add"; + } else if (op_type == "Mul") { + coreml_op_type = "mul"; + } else if (op_type == "Sub") { + coreml_op_type = "sub"; + } else if (op_type == "Div") { + // we only support fp32 currently. 
when we add support for integers we need to check the type and use + // "floor_div" or "real_div" accordingly + coreml_op_type = "real_div"; + } else if (op_type == "Pow") { + coreml_op_type = "pow"; } else { - layer->mutable_addbroadcastable(); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "BinaryOpBuilder::AddToModelBuilderImpl, unexpected op: ", op_type); } - } else if (op_type == "Mul") { - if (CheckIfBothInputShapesMatch(node, logger)) { - layer->mutable_multiply(); + + std::unique_ptr op = model_builder.CreateOperation(node, coreml_op_type); + AddOperationInput(*op, "x", input_defs[0]->Name()); + AddOperationInput(*op, "y", input_defs[1]->Name()); + AddOperationOutput(*op, *node.OutputDefs()[0]); + + model_builder.AddOperation(std::move(op)); + } else +#endif // defined (COREML_ENABLE_MLPROGRAM) + { + std::unique_ptr layer = model_builder.CreateNNLayer(node); + + if (op_type == "Add") { + // original mutable_add() has limited broadcasting support + // updated to use CoreML::AddBroadcastableLayerParams which has more general broadcasting support + if (CheckIfBothInputShapesMatch(node, logger)) { + layer->mutable_add(); + } else { + layer->mutable_addbroadcastable(); + } + } else if (op_type == "Mul") { + if (CheckIfBothInputShapesMatch(node, logger)) { + layer->mutable_multiply(); + } else { + layer->mutable_multiplybroadcastable(); + } + } else if (op_type == "Sub") { + layer->mutable_subtractbroadcastable(); + } else if (op_type == "Div") { + layer->mutable_dividebroadcastable(); + } else if (op_type == "Pow") { + layer->mutable_powbroadcastable(); } else { - layer->mutable_multiplybroadcastable(); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "BinaryOpBuilder::AddToModelBuilderImpl, unexpected op: ", op_type); } - } else if (op_type == "Sub") { - layer->mutable_subtractbroadcastable(); - } else if (op_type == "Div") { - layer->mutable_dividebroadcastable(); - } else if (op_type == "Pow") { - layer->mutable_powbroadcastable(); - } 
else { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "BinaryOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); - } - *layer->mutable_input()->Add() = input_defs[0]->Name(); - *layer->mutable_input()->Add() = input_defs[1]->Name(); - *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); + *layer->mutable_input()->Add() = input_defs[0]->Name(); + *layer->mutable_input()->Add() = input_defs[1]->Name(); + *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); + + model_builder.AddLayer(std::move(layer)); + } - model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif - -// Operator support related int BinaryOpBuilder::GetMinSupportedOpSet(const Node& /* node */) const { // Add/Sub/Mul/Div opset 6- has broadcast attributes we do not support now return 7; } -bool BinaryOpBuilder::HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const { - bool is_pow = node.OpType() == "Pow"; - if (!is_pow) { - return BaseOpBuilder::HasSupportedInputsImpl(node, logger); - } - - const auto& input_1 = *node.InputDefs()[0]; - const auto& input_2 = *node.InputDefs()[1]; - // Pow we only support both inputs as fp32 for now - int32_t input_type_1; - if (!GetType(input_1, input_type_1, logger)) - return false; - - int32_t input_type_2; - if (!GetType(input_2, input_type_2, logger)) - return false; - - if (input_type_1 != ONNX_NAMESPACE::TensorProto_DataType_FLOAT || input_type_1 != input_type_2) { - LOGS(logger, VERBOSE) << "Pow only supports fp32 inputs, actual input type" - << ", Input type 1: " << input_type_1 - << ", Input type 2: " << input_type_2; +bool BinaryOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const { + // Add/Sub/Mul/Div spec says inputs must be of the same type. + // Pow spec says inputs can be different types. + // We only support float for all of these inputs. 
+ if (!IsInputFloat(node, 0, input_params, logger) || + ((node.OpType() == "Pow") && !IsInputFloat(node, 1, input_params, logger))) { return false; } diff --git a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc index 3b7bd5c1840cc..cbea969904ed5 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc @@ -1,17 +1,17 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef __APPLE__ - #include "core/providers/coreml/builders/impl/builder_utils.h" #include "core/common/narrow.h" #include "core/framework/tensorprotoutils.h" +#include "core/providers/coreml/builders/coreml_spec.h" #include "core/providers/coreml/builders/helper.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/shared/utils/utils.h" #include "core/optimizer/initializer.h" -#include "coreml/NeuralNetwork.pb.h" +using namespace COREML_SPEC; namespace onnxruntime { namespace coreml { @@ -133,7 +133,249 @@ void CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, gsl::span> shape) { + tensor_type.set_datatype(data_type); + if (shape) { + tensor_type.set_rank(shape->size()); + for (const auto& dim : *shape) { + if (dim >= 0) { + tensor_type.add_dimensions()->mutable_constant()->set_size(narrow(dim)); + } else { + tensor_type.add_dimensions()->mutable_unknown()->set_variadic(false); + } + } + } +} + +void SetTensorTypeInfo(MILSpec::TensorType& tensor_type, MILSpec::DataType data_type, + const ONNX_NAMESPACE::TensorShapeProto* shape) { + tensor_type.set_datatype(data_type); + if (shape) { + tensor_type.set_rank(shape->dim_size()); + for (const auto& dim : shape->dim()) { + if (dim.has_dim_value()) { + tensor_type.add_dimensions()->mutable_constant()->set_size(narrow(dim.dim_value())); + } else { + 
tensor_type.add_dimensions()->mutable_unknown()->set_variadic(false); + } + } + } +} + +template +void CopyDataToTensorValue(MILSpec::TensorValue& tensor_value, gsl::span data) { + // need a 'false' that is dependent on the template types to make gcc happy and give a meaningful error message. + static_assert(false_for_T && false_for_T, "Unsupported data type"); // add specializations below as needed +} + +template <> +void CopyDataToTensorValue(MILSpec::TensorValue& tensor_value, gsl::span data) { + tensor_value.mutable_floats()->mutable_values()->Add(data.begin(), data.end()); +} + +template <> +void CopyDataToTensorValue(MILSpec::TensorValue& tensor_value, gsl::span data) { + tensor_value.mutable_ints()->mutable_values()->Add(data.begin(), data.end()); +} + +template <> +void CopyDataToTensorValue(MILSpec::TensorValue& tensor_value, gsl::span data) { + tensor_value.mutable_strings()->mutable_values()->Add(data.begin(), data.end()); +} + +// copy int64_t (used by ONNX for strides/indexes/etc.) 
to int32_t (used by CoreML) +template <> +void CopyDataToTensorValue(MILSpec::TensorValue& tensor_value, gsl::span data) { + auto& int32_out = *tensor_value.mutable_ints()->mutable_values(); + int32_out.Reserve(narrow(data.size())); + for (const int64_t v : data) { + int32_out.AddAlreadyReserved(narrow(v)); + } +} + +template <> +void CopyDataToTensorValue(MILSpec::TensorValue& tensor_value, gsl::span data) { + tensor_value.mutable_bools()->mutable_values()->Add(data.begin(), data.end()); +} + +} // namespace + +MILSpec::DataType OnnxDataTypeToMILSpec(int onnx_type) { + switch (static_cast(onnx_type)) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + return MILSpec::DataType::FLOAT32; + case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: + return MILSpec::DataType::FLOAT64; + case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: + return MILSpec::DataType::BFLOAT16; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + return MILSpec::DataType::FLOAT16; + + case ONNX_NAMESPACE::TensorProto_DataType_INT8: + return MILSpec::DataType::INT8; + case ONNX_NAMESPACE::TensorProto_DataType_INT16: + return MILSpec::DataType::INT16; + case ONNX_NAMESPACE::TensorProto_DataType_INT32: + return MILSpec::DataType::INT32; + case ONNX_NAMESPACE::TensorProto_DataType_INT64: + return MILSpec::DataType::INT64; + + case ONNX_NAMESPACE::TensorProto_DataType_UINT8: + return MILSpec::DataType::UINT8; + case ONNX_NAMESPACE::TensorProto_DataType_UINT16: + return MILSpec::DataType::UINT16; + case ONNX_NAMESPACE::TensorProto_DataType_UINT32: + return MILSpec::DataType::UINT32; + case ONNX_NAMESPACE::TensorProto_DataType_UINT64: + return MILSpec::DataType::UINT64; + + case ONNX_NAMESPACE::TensorProto_DataType_BOOL: + return MILSpec::DataType::BOOL; + case ONNX_NAMESPACE::TensorProto_DataType_STRING: + return MILSpec::DataType::STRING; + default: + ORT_THROW("Unsupported data type: ", onnx_type); + } +} + +template +MILSpec::Value CreateTensorValue(const gsl::span data, + std::optional> 
shape) { + MILSpec::Value value; + MILSpec::TensorType& tensor_type = *value.mutable_type()->mutable_tensortype(); + + if (shape) { + SetTensorTypeInfo(tensor_type, DataTypeToMILSpec(), *shape); + } else { + // infer as 1D shape + std::vector coreml_shape{narrow(data.size())}; + SetTensorTypeInfo(tensor_type, DataTypeToMILSpec(), coreml_shape); + } + + MILSpec::TensorValue& tensor_value = *value.mutable_immediatevalue()->mutable_tensor(); + CopyDataToTensorValue(tensor_value, data); + + return value; +} + +template +MILSpec::Value CreateScalarTensorValue(const T& data) { + gsl::span data_span{&data, 1}; + std::vector shape = {}; // empty for scalar + return CreateTensorValue(data_span, shape); +} + +// explicit specializations for types we handle so the implementation can be in the .cc file +template MILSpec::Value CreateTensorValue(gsl::span data, + std::optional> shape); + +template MILSpec::Value CreateScalarTensorValue(const float& data); +template MILSpec::Value CreateScalarTensorValue(const int32_t& data); +template MILSpec::Value CreateScalarTensorValue(const std::string& data); +template MILSpec::Value CreateScalarTensorValue(const bool& data); + +COREML_SPEC::MILSpec::NamedValueType CreateNamedTensorValueType(const NodeArg& node_arg) { + MILSpec::NamedValueType nvt; + nvt.set_name(node_arg.Name()); + MILSpec::TensorType& tensor_type = *nvt.mutable_type()->mutable_tensortype(); + + SetTensorTypeInfo(tensor_type, OnnxDataTypeToMILSpec(node_arg.TypeAsProto()->tensor_type().elem_type()), + node_arg.Shape()); + + return nvt; +} + +void AddOperationInput(MILSpec::Operation& op, std::string_view input_name, std::string_view value_name) { + MILSpec::Argument arg; + arg.mutable_arguments()->Add()->set_name(std::string(value_name)); + + (*op.mutable_inputs())[input_name] = std::move(arg); +} + +void AddOperationOutput(COREML_SPEC::MILSpec::Operation& op, const NodeArg& output) { + auto& outputs = *op.mutable_outputs(); + auto& output_arg = *outputs.Add(); + 
output_arg.set_name(output.Name()); + + MILSpec::ValueType& value = *output_arg.mutable_type(); + MILSpec::TensorType& tensor_type = *value.mutable_tensortype(); + + SetTensorTypeInfo(tensor_type, OnnxDataTypeToMILSpec(output.TypeAsProto()->tensor_type().elem_type()), + output.Shape()); +} + +void AddPadTypeAndPads(COREML_SPEC::MILSpec::Operation& op, ModelBuilder& model_builder, std::string_view op_type, + const NodeAttrHelper& helper, int num_spatial_dims) { + AutoPadType auto_pad_type = StringToAutoPadType(helper.Get("auto_pad", "NOTSET")); + + // pad type (string) + // valid - no pads (ONNX auto_pad VALID) + // custom - pads input (ONNX NOTSET) + // same - inferred to be `d_out[i] = ceil(d_in[i] / strides[i])` (assuming == ONNX SAME_UPPER) + // same_lower - as per same but any extra rows/cols are added at top/left if padding is odd (ONNX SAME_LOWER) + // + // TODO: See if we want to update HandleAutoPad to support 1D (and 3D) so we can infer if an autopad value + // can be used. TBD if that provides any performance benefit with ML Program though as CoreML could + // potentially do that same optimization internally. + switch (auto_pad_type) { + case AutoPadType::NOTSET: { + // use `pads` attribute. + auto onnx_pads = helper.GetInt64s("pads"); // 'pads' are used if auto_pad is NOTSET + if (onnx_pads) { + AddOperationInput(op, "pad_type", + model_builder.AddScalarConstant(op_type, "pad_type", std::string("custom"))); + + // need to re-order from x1_start, x2_start..., x1_end, x2_end... to + // x1_start, x1_end, x2_start, x2_end,... 
+ size_t num_pads = onnx_pads->size(); + size_t num_dims = num_pads / 2; + std::vector reordered_pads(num_pads, 0); + for (size_t i = 0; i < num_pads; ++i) { + auto cur_dim = i % num_dims; + if (i < num_dims) { // start values + reordered_pads[cur_dim * 2] = (*onnx_pads)[i]; + } else { // end values + reordered_pads[cur_dim * 2 + 1] = (*onnx_pads)[i]; + } + } + + AddOperationInput(op, "pad", model_builder.AddConstant(op_type, "pad", reordered_pads)); + + break; + } + + // fall through if explicit pads were not provided as the default value for `pads` is all zeros, + // which is the same as 'valid' padding. + [[fallthrough]]; + } + case AutoPadType::VALID: + AddOperationInput(op, "pad_type", + model_builder.AddScalarConstant(op_type, "pad_type", std::string("valid"))); + + break; + case AutoPadType::SAME_UPPER: + case AutoPadType::SAME_LOWER: { + const auto pad_type = (auto_pad_type == AutoPadType::SAME_UPPER ? "same" : "same_lower"); + AddOperationInput(op, "pad_type", + model_builder.AddScalarConstant(op_type, "pad_type", std::string(pad_type))); + + // despite what the spec says, a 'pad' input seems to be required. + // https://github.com/apple/coremltools/issues/2127 + // Provide the default value as that's what coremltools does for conv/avg_pool/max_pool. 
+ std::vector ignored_pads(num_spatial_dims * 2, 0); + AddOperationInput(op, "pad", model_builder.AddConstant(op_type, "pad", ignored_pads)); + + break; + } + } +} +#endif // defined(COREML_ENABLE_MLPROGRAM) } // namespace coreml } // namespace onnxruntime - -#endif diff --git a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h index 23b11928f7dc2..2804589065631 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h +++ b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h @@ -5,21 +5,20 @@ #pragma once -#ifdef __APPLE__ +#include #include "core/common/gsl.h" #include "core/common/status.h" #include "core/graph/basic_types.h" #include "core/providers/common.h" - -namespace CoreML { -namespace Specification { -class WeightParams; -} -} // namespace CoreML +#include "core/providers/coreml/builders/coreml_spec.h" +#include "core/providers/shared/utils/utils.h" namespace onnxruntime { +class NodeArg; + namespace coreml { +class ModelBuilder; // Try to see if we can map explicit padding to auto padding for Conv/Pool // Since usually use auto padding is more efficient @@ -32,6 +31,10 @@ Status HandleAutoPad(const std::vector input_shape, AutoPadType auto_pad_type, AutoPadType& auto_pad_type_out); +// +// NeuralNetwork utils +// + // Copy an onnx initializer data to a coreml weight Status CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, const ONNX_NAMESPACE::TensorProto& tensor); @@ -44,7 +47,103 @@ void CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, gsl::span data); +#if defined(COREML_ENABLE_MLPROGRAM) +// +// MLProgram utils +// + +// helper for static_assert where the value needs to be dependent on a template parameter +template +constexpr bool false_for_T = false; + +template +COREML_SPEC::MILSpec::DataType DataTypeToMILSpec() { + if constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::FLOAT32; + } else if 
constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::FLOAT64; + } else if constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::BFLOAT16; + } else if constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::FLOAT16; + + } else if constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::INT8; + } else if constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::INT16; + } else if constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::INT32; + } else if constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::INT64; + + } else if constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::UINT8; + } else if constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::UINT16; + } else if constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::UINT32; + } else if constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::UINT64; + + } else if constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::BOOL; + } else if constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::STRING; + } else { + static_assert(false_for_T, "Unsupported type."); + } +} + +// The TensorProto.data_type field is an int, but must be a valid TensorProto_DataType value. +// Use int for the arg so the caller can pass TensorProto.data_type() value and do the cast to enum internally +COREML_SPEC::MILSpec::DataType OnnxDataTypeToMILSpec(int onnx_type); + +/// +/// Create a CoreML MILSpec::TensorValue for the given input data. +/// +/// Original C++ data type +/// CoreML C++ data type +/// ONNX data +/// ONNX data shape. Inferred to be a 1D shape of `{data.size()}` if not specified. +/// TensorValue containing data. 
+template +COREML_SPEC::MILSpec::Value CreateTensorValue(gsl::span data, + std::optional> shape = std::nullopt); + +template +COREML_SPEC::MILSpec::Value CreateScalarTensorValue(const T& data); + +/// Create a NamedValueType from an ONNX tensor NodeArg. +/// Used to create inputs for the 'main' function in an ML Program. +COREML_SPEC::MILSpec::NamedValueType CreateNamedTensorValueType(const NodeArg& node_arg); + +/// +/// Add an input argument to a MILSpec::Operation +/// +/// Operation to update. +/// The input name defined by the spec for the operation. +/// The name of the value that is providing the input. +/// "https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html" +void AddOperationInput(COREML_SPEC::MILSpec::Operation& op, + std::string_view input_name, std::string_view value_name); + +/// +/// Add an output to a MILSpec::Operation. Name, data type and shape are used from the NodeArg. +/// +/// Operation to update. +/// NodeArg with details of output to add. +void AddOperationOutput(COREML_SPEC::MILSpec::Operation& op, const NodeArg& output); + +/// +/// Add pad_type and pad values. +/// +/// Operator to update +/// ModelBuilder to add constants with. +/// Operator type. +/// Node attribute helper. +/// Number of spatial dims in input. Generally rank - 2 (ignore N and C dims). 
+void AddPadTypeAndPads(COREML_SPEC::MILSpec::Operation& op, ModelBuilder& model_builder, std::string_view op_type, + const NodeAttrHelper& helper, int num_spatial_dims); +#endif // defined(COREML_ENABLE_MLPROGRAM) } // namespace coreml } // namespace onnxruntime - -#endif diff --git a/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc index 15ee1f0fc7284..70053c2c606a0 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc @@ -1,34 +1,25 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/providers/shared/utils/utils.h" #include "core/providers/coreml/builders/helper.h" -#ifdef __APPLE__ +#include "core/providers/coreml/builders/impl/base_op_builder.h" #include "core/providers/coreml/builders/model_builder.h" -#endif #include "core/providers/coreml/builders/op_builder_factory.h" - -#include "base_op_builder.h" +#include "core/providers/shared/utils/utils.h" namespace onnxruntime { namespace coreml { class CastOpBuilder : public BaseOpBuilder { - // Add operator related - private: -#ifdef __APPLE__ Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const override; -}; -// Add operator related + bool HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const override; +}; -#ifdef __APPLE__ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& /* model_builder */, const Node& /* node */, const logging::Logger& /* logger */) const { @@ -37,9 +28,6 
@@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& /* model_builder */, // Cast node is not provided in CoreML model, so we're skipping adding the Cast node here. return Status::OK(); } -#endif - -// Operator support related bool CastOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { @@ -84,7 +72,8 @@ bool CastOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPara return true; } -bool CastOpBuilder::HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const { +bool CastOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& /*input_params*/, + const logging::Logger& logger) const { // We only check the type of input 0 const auto& input = *node.InputDefs()[0]; diff --git a/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc index 3a3f89d24c7d8..41f4041ef1181 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc @@ -1,40 +1,48 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#ifdef __APPLE__ +#include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/impl/builder_utils.h" #include "core/providers/coreml/builders/model_builder.h" -#endif #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/shared/utils/utils.h" -#include "base_op_builder.h" - namespace onnxruntime { namespace coreml { class ClipOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - public: void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; -}; -// Add operator related + bool SupportsMLProgram() const override { return true; } +}; -#ifdef __APPLE__ void ClipOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { + bool skip = true; + + if (model_builder.CreateMLProgram()) { + float min, max; + ORT_IGNORE_RETURN_VALUE(GetClipMinMax(model_builder.GetGraphViewer(), node, min, max, model_builder.Logger())); + + bool has_min = min != std::numeric_limits::lowest(); + bool has_max = max != std::numeric_limits::max(); + if (has_min && has_max && min == 0.f && max == 6.f) { + // relu6 - skip both + } else if (has_min && min == 0.f && !has_max) { + // relu - skip both + } else { + // clip - we will use both + skip = false; + } + } + // Both min and max values will be injected into the layer, no need to add to the model - if (node.SinceVersion() >= 11) { + if (skip && node.SinceVersion() >= 11) { if (node.InputDefs().size() > 1) model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name()); @@ -48,92 +56,137 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const 
logging::Logger& logger) const { const auto& node_name = node.Name(); const auto& input_name = node.InputDefs()[0]->Name(); - const auto& output_name = node.OutputDefs()[0]->Name(); + const auto& output = *node.OutputDefs()[0]; + const auto& output_name = output.Name(); float min, max; - ORT_RETURN_IF_NOT(GetClipMinMax(model_builder.GetInitializerTensors(), node, min, max, logger), "GetClipMinMax failed"); + ORT_RETURN_IF_NOT(GetClipMinMax(model_builder.GetGraphViewer(), node, min, max, logger), "GetClipMinMax failed"); bool has_min = min != std::numeric_limits::lowest(); bool has_max = max != std::numeric_limits::max(); - if (!has_min && !has_max) { - // Clip without min/max is an identity node - // In CoreML we don't have identity, use ActivationLinear instead - std::unique_ptr layer = CreateNNLayer(model_builder, node); - layer->mutable_activation()->mutable_linear()->set_alpha(1.0f); - *layer->mutable_input()->Add() = input_name; - *layer->mutable_output()->Add() = output_name; - - model_builder.AddLayer(std::move(layer)); - } else { - // The implementation of clip(min, max) is done by - // 1. Clipping at min -> max(input, min) is handled by - // min_output = threshold(input, min) - // 2. Clipping at max -> min(min_output, max) is handled by - // output = -1 * (threshold(-min_output, -max)) - - // Now we have at least one or min or max is not default value - // Clipping at max will need take the output of clipping at min, or the node input, if min value is default - // If max value is default, the output of clipping at min will be the output of the node - std::string min_output_name = output_name; - if (has_max) { - min_output_name = has_min - ? model_builder.GetUniqueName(node_name + "min_output") - : input_name; +#if defined(COREML_ENABLE_MLPROGRAM) + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; + + std::unique_ptr op; + if (!has_min && !has_max) { + // Clip without min/max is an identity node. 
+ op = model_builder.CreateOperation(node, "identity"); + Operation& identity_op = *op; + AddOperationInput(identity_op, "x", input_name); + } else { + if (has_min && has_max && min == 0.f && max == 6.f) { + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.activation.relu6 + op = model_builder.CreateOperation(node, "relu6"); + Operation& relu6_op = *op; + AddOperationInput(relu6_op, "x", input_name); + } else if (has_min && min == 0.f && !has_max) { + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.activation.relu + op = model_builder.CreateOperation(node, "relu"); + Operation& relu_op = *op; + AddOperationInput(relu_op, "x", input_name); + } else { + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.elementwise_unary.clip + op = model_builder.CreateOperation(node, "clip"); + + Operation& clip_op = *op; + AddOperationInput(clip_op, "x", input_name); + + // if min and max were attributes we need to add initializers. otherwise we use the existing inputs + const bool min_max_attribs = node.SinceVersion() < 11; + std::string_view min_name = min_max_attribs ? model_builder.AddScalarConstant(clip_op.type(), "min", min) + : node.InputDefs()[1]->Name(); + + AddOperationInput(clip_op, "alpha", min_name); + + if (has_max) { + std::string_view max_name = min_max_attribs ? model_builder.AddScalarConstant(clip_op.type(), "max", max) + : node.InputDefs()[2]->Name(); + AddOperationInput(clip_op, "beta", max_name); + } + } } - // Handle clipping at min first - if (has_min) { - const auto clip_min_layer_name = model_builder.GetUniqueName(MakeString(node_name, "_Clip_min")); - std::unique_ptr min_layer = CreateNNLayer(clip_min_layer_name); - if (min == 0.0f) { // If min is 0. 
then this min will be handled by relu - min_layer->mutable_activation()->mutable_relu(); - } else { // otherwise, min will be handled by unary->threshold - min_layer->mutable_unary()->set_alpha(min); - min_layer->mutable_unary()->set_type(COREML_SPEC::UnaryFunctionLayerParams::THRESHOLD); + AddOperationOutput(*op, output); + model_builder.AddOperation(std::move(op)); + } else +#endif // defined(COREML_ENABLE_MLPROGRAM) + { + // TODO: CoreML has a Clip layer for NeuralNetwork. Added in CoreML 4. We could potentially use that if available + // to simplify. + // https://apple.github.io/coremltools/mlmodel/Format/NeuralNetwork.html#cliplayerparams + + if (!has_min && !has_max) { + // Clip without min/max is an identity node + // In CoreML we don't have identity, use ActivationLinear instead + std::unique_ptr layer = model_builder.CreateNNLayer(node); + layer->mutable_activation()->mutable_linear()->set_alpha(1.0f); + *layer->mutable_input()->Add() = input_name; + *layer->mutable_output()->Add() = output_name; + + model_builder.AddLayer(std::move(layer)); + } else { + // The implementation of clip(min, max) is done by + // 1. Clipping at min -> max(input, min) is handled by + // min_output = threshold(input, min) + // 2. Clipping at max -> min(min_output, max) is handled by + // output = -1 * (threshold(-min_output, -max)) + + // Now we have at least one or min or max is not default value + // Clipping at max will need take the output of clipping at min, or the node input, if min value is default + // If max value is default, the output of clipping at min will be the output of the node + std::string min_output_name = output_name; + if (has_max) { + min_output_name = has_min + ? 
model_builder.GetUniqueName(node_name + "min_output") + : input_name; } - *min_layer->mutable_input()->Add() = input_name; - *min_layer->mutable_output()->Add() = min_output_name; - model_builder.AddLayer(std::move(min_layer)); - } - - // Clipping at max is handled by -1 * (threshold (-min_output, -max)) - if (has_max) { - const auto threshold_output_name = model_builder.GetUniqueName(MakeString(node_name, "threshold_output")); - { // Add threshold layer, which is actually max( -1 * min_output, -max) - const auto clip_max_threshold_layer_name = - model_builder.GetUniqueName(MakeString(node_name, "_Clip_max_threshold")); - auto threshold_layer = CreateNNLayer(clip_max_threshold_layer_name); - threshold_layer->mutable_unary()->set_alpha(-max); - threshold_layer->mutable_unary()->set_scale(-1.0f); - threshold_layer->mutable_unary()->set_type(COREML_SPEC::UnaryFunctionLayerParams::THRESHOLD); - *threshold_layer->mutable_input()->Add() = min_output_name; - *threshold_layer->mutable_output()->Add() = threshold_output_name; - model_builder.AddLayer(std::move(threshold_layer)); + // Handle clipping at min first + if (has_min) { + std::unique_ptr min_layer = model_builder.CreateNNLayer(node, "_Clip_min"); + if (min == 0.0f) { // If min is 0. 
then this min will be handled by relu + min_layer->mutable_activation()->mutable_relu(); + } else { // otherwise, min will be handled by unary->threshold + min_layer->mutable_unary()->set_alpha(min); + min_layer->mutable_unary()->set_type(COREML_SPEC::UnaryFunctionLayerParams::THRESHOLD); + } + + *min_layer->mutable_input()->Add() = input_name; + *min_layer->mutable_output()->Add() = min_output_name; + model_builder.AddLayer(std::move(min_layer)); } - { // Add linear activation layer -1 * threshold_output - const auto clip_max_linear_layer_name = - model_builder.GetUniqueName(MakeString(node_name, "_Clip_max_linear")); - auto linear_layer = CreateNNLayer(clip_max_linear_layer_name); - linear_layer->mutable_activation()->mutable_linear()->set_alpha(-1.0f); - *linear_layer->mutable_input()->Add() = threshold_output_name; - *linear_layer->mutable_output()->Add() = output_name; - model_builder.AddLayer(std::move(linear_layer)); + + // Clipping at max is handled by -1 * (threshold (-min_output, -max)) + if (has_max) { + const auto threshold_output_name = model_builder.GetUniqueName(MakeString(node_name, "threshold_output")); + { // Add threshold layer, which is actually max( -1 * min_output, -max) + auto threshold_layer = model_builder.CreateNNLayer(node, "_Clip_max_threshold"); + threshold_layer->mutable_unary()->set_alpha(-max); + threshold_layer->mutable_unary()->set_scale(-1.0f); + threshold_layer->mutable_unary()->set_type(COREML_SPEC::UnaryFunctionLayerParams::THRESHOLD); + *threshold_layer->mutable_input()->Add() = min_output_name; + *threshold_layer->mutable_output()->Add() = threshold_output_name; + model_builder.AddLayer(std::move(threshold_layer)); + } + { // Add linear activation layer -1 * threshold_output + auto linear_layer = model_builder.CreateNNLayer(node, "_Clip_max_linear"); + linear_layer->mutable_activation()->mutable_linear()->set_alpha(-1.0f); + *linear_layer->mutable_input()->Add() = threshold_output_name; + 
*linear_layer->mutable_output()->Add() = output_name; + model_builder.AddLayer(std::move(linear_layer)); + } } } } return Status::OK(); } -#endif - -// Operator support related bool ClipOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { float min, max; - const auto& initializers = input_params.graph_viewer.GetAllInitializedTensors(); - return GetClipMinMax(initializers, node, min, max, logger); + return GetClipMinMax(input_params.graph_viewer, node, min, max, logger); } void CreateClipOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/coreml/builders/impl/concat_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/concat_op_builder.cc index b1e761024f5c9..34193318a0264 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/concat_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/concat_op_builder.cc @@ -4,37 +4,26 @@ #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime { namespace coreml { class ConcatOpBuilder : public BaseOpBuilder { - // Add operator related - private: -#ifdef __APPLE__ Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; }; -// Add operator related - -#ifdef __APPLE__ Status 
ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); layer->mutable_concat()->set_sequenceconcat(false); @@ -48,9 +37,7 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif -// Operator support related bool ConcatOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /* input_params */, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); diff --git a/onnxruntime/core/providers/coreml/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/conv_op_builder.cc index ff9dcbd9f8874..38125957bf481 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/conv_op_builder.cc @@ -4,39 +4,35 @@ #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" -#include "core/providers/coreml/builders/op_builder_factory.h" -#include "core/providers/shared/utils/utils.h" - -#ifdef __APPLE__ #include "core/providers/coreml/builders/impl/builder_utils.h" #include "core/providers/coreml/builders/model_builder.h" +#include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" -#endif +#include "core/providers/shared/utils/utils.h" + +using namespace CoreML::Specification; namespace onnxruntime { namespace coreml { class ConvOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - public: void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& 
logger) const override; -#endif - // Operator support related - private: bool IsOpSupportedImpl(const Node& /* node */, const OpBuilderInputParams& /* input_params */, const logging::Logger& /* logger */) const override; -}; -// Add operator related + bool SupportsMLProgram() const override { return true; } +}; -#ifdef __APPLE__ void ConvOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { + if (model_builder.CreateMLProgram()) { + // we add the initializers as 'const' operations via ModelBuilder::RegisterInitializers + return; + } + const auto& input_defs = node.InputDefs(); // skip the weight and bias (if has it) for conv as we will directly set those as part of the NN layer @@ -49,136 +45,177 @@ void ConvOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Nod Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); - const auto& input_defs = node.InputDefs(); const auto& output_defs = node.OutputDefs(); const auto& input_name = input_defs[0]->Name(); const auto& output_name = output_defs[0]->Name(); - const auto& weight_tensor = *model_builder.GetInitializerTensors().at(input_defs[1]->Name()); - std::vector weight_shape = {weight_tensor.dims().cbegin(), weight_tensor.dims().cend()}; + NodeAttrHelper helper(node); - const bool is_1d_conv = (weight_shape.size() == 3); +#if defined(COREML_ENABLE_MLPROGRAM) + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; - if (is_1d_conv) { - // weight_shape needs to be expanded from MXCXH->MXCXHx1 - weight_shape.push_back(1); - } + // https://github.com/apple/coremltools/blob/7.1/coremltools/converters/mil/mil/ops/defs/iOS15/conv.py - NodeAttrHelper helper(node); - auto strides = helper.Get("strides", std::vector{1, 1}); - auto dilations = helper.Get("dilations", std::vector{1, 1}); - auto onnx_pads = 
helper.Get("pads", std::vector{0, 0, 0, 0}); - // Strides/dilations for 1d conv is normally of length 1. Expand them by 1 - // to meet the required length 2 (for 2d conv it's normally 2) - // Similarly 1d conv normally has a length 2 padding. Expand it to length 4 by adding additional zeros. - if (is_1d_conv) { - if (strides.size() < 2) { - ORT_RETURN_IF_NOT(strides.size() == 1, "strides size does not equal 1 for Conv 1d"); - strides.push_back(1); + std::unique_ptr conv_op = model_builder.CreateOperation(node, "conv"); + + AddOperationInput(*conv_op, "x", input_name); + AddOperationInput(*conv_op, "weight", input_defs[1]->Name()); + + if (input_defs.size() > 2) { + AddOperationInput(*conv_op, "bias", input_defs[2]->Name()); } - if (dilations.size() < 2) { - ORT_RETURN_IF_NOT(dilations.size() == 1, "dilations size does not equal 1 for Conv 1d"); - dilations.push_back(1); + + // we know this input has a valid shape due to the check in IsOpSupportedImpl. ignore N and C dims. + const auto num_spatial_dims = input_defs[1]->Shape()->dim_size() - 2; + const auto& op_type = conv_op->type(); + + // Spec says strides and dilations are optional, but reality is they're required for at least the iOS15 target + // (CoreML5). 
+ const auto strides = helper.Get("strides", std::vector(num_spatial_dims, 1)); + auto dilations = helper.Get("dilations", std::vector(num_spatial_dims, 1)); + auto groups = helper.GetInt64("group"); + + AddOperationInput(*conv_op, "strides", model_builder.AddConstant(op_type, "strides", strides)); + AddOperationInput(*conv_op, "dilations", model_builder.AddConstant(op_type, "dilations", dilations)); + + if (groups) { + AddOperationInput(*conv_op, "groups", model_builder.AddScalarConstant(op_type, "groups", *groups)); } - if (onnx_pads.size() < 4) { - ORT_RETURN_IF_NOT(onnx_pads.size() == 2, "onnx_pads size does not equal 2 for Conv 1d"); - onnx_pads.insert(onnx_pads.begin() + 1, 0); - onnx_pads.push_back(0); + + AddPadTypeAndPads(*conv_op, model_builder, op_type, helper, num_spatial_dims); + + AddOperationOutput(*conv_op, *node.OutputDefs()[0]); + + model_builder.AddOperation(std::move(conv_op)); + } else +#endif // defined(COREML_ENABLE_MLPROGRAM) + { + std::unique_ptr layer = model_builder.CreateNNLayer(node); + + auto strides = helper.Get("strides", std::vector{1, 1}); + auto dilations = helper.Get("dilations", std::vector{1, 1}); + auto onnx_pads = helper.Get("pads", std::vector{0, 0, 0, 0}); + const auto group = helper.Get("group", static_cast(1)); + + std::vector input_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); + + const auto& weight_tensor = *model_builder.GetInitializerTensors().at(input_defs[1]->Name()); + std::vector weight_shape = {weight_tensor.dims().cbegin(), weight_tensor.dims().cend()}; + + const bool is_1d_conv = (weight_shape.size() == 3); + + // add dummy 'W' dim with value of 1 so we can use 2D conv. + if (is_1d_conv) { + input_shape.push_back(1); + weight_shape.push_back(1); + + // Strides/dilations for 1d conv is normally of length 1. 
Expand them by 1 + // to meet the required length 2 (for 2d conv it's normally 2) + if (strides.size() < 2) { + ORT_RETURN_IF_NOT(strides.size() == 1, "strides size does not equal 1 for Conv 1d"); + strides.push_back(1); + } + + if (dilations.size() < 2) { + ORT_RETURN_IF_NOT(dilations.size() == 1, "dilations size does not equal 1 for Conv 1d"); + dilations.push_back(1); + } + + // Similarly 1d conv normally has a length 2 padding. Expand it to length 4 by adding additional zeros. + if (onnx_pads.size() < 4) { + ORT_RETURN_IF_NOT(onnx_pads.size() == 2, "onnx_pads size does not equal 2 for Conv 1d"); + onnx_pads.insert(onnx_pads.begin() + 1, 0); + onnx_pads.push_back(0); + } } - } - const auto group = helper.Get("group", static_cast(1)); - - auto* coreml_conv = layer->mutable_convolution(); - - std::string expand_output_name = model_builder.GetUniqueName(node.Name() + "_expandDims"); - - if (is_1d_conv) { - const auto expand_layer_name = model_builder.GetUniqueName(MakeString(node.Name(), "_Conv_expand")); - std::unique_ptr expand_layer = CreateNNLayer(expand_layer_name); - // Add an expanddims layer here. CoreML only supports 2d convolution, so for 1d Conv case - // we need to add an additional dimension here to the input to make it "2d Conv" like. 
- // NxCxH -> NxCxHx1 - expand_layer->mutable_expanddims()->add_axes(-1); - *expand_layer->mutable_input()->Add() = input_name; - *expand_layer->mutable_output()->Add() = expand_output_name; - model_builder.AddLayer(std::move(expand_layer)); - } - coreml_conv->set_outputchannels(weight_shape[0]); // M - coreml_conv->set_kernelchannels(weight_shape[1]); // C/Group - coreml_conv->add_kernelsize(weight_shape[2]); // H - coreml_conv->add_kernelsize(weight_shape[3]); // W - coreml_conv->set_ngroups(group); - *coreml_conv->mutable_stride() = {strides.cbegin(), strides.cend()}; - *coreml_conv->mutable_dilationfactor() = {dilations.cbegin(), dilations.cend()}; - - coreml_conv->set_isdeconvolution(false); - - // Add Padding - // Usually using autopadding is more efficient than using explicit padding - // Try to see if we can map explicit padding to auto padding - std::vector input_shape; - ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); - AutoPadType auto_pad_type; - ORT_RETURN_IF_ERROR(HandleAutoPad(input_shape, weight_shape[2], weight_shape[3], - onnx_pads, strides, dilations, - StringToAutoPadType(helper.Get("auto_pad", "NOTSET")), - auto_pad_type)); - - if (AutoPadType::SAME_UPPER == auto_pad_type || AutoPadType::SAME_LOWER == auto_pad_type) { - auto* padding_type = coreml_conv->mutable_same(); - if (AutoPadType::SAME_LOWER == auto_pad_type) { // default is SAME_UPPER - padding_type->set_asymmetrymode(COREML_SPEC::SamePadding_SamePaddingMode_TOP_LEFT_HEAVY); + + auto* coreml_conv = layer->mutable_convolution(); + + std::string expand_output_name = model_builder.GetUniqueName(node.Name() + "_expandDims"); + + if (is_1d_conv) { + // Add an expanddims layer here. CoreML only supports 2d convolution, so for 1d Conv case + // we need to add an additional dimension here to the input to make it "2d Conv" like. 
+ // NxCxH -> NxCxHx1 + auto expand_layer = model_builder.CreateNNLayer(node, "_Conv_expand"); + expand_layer->mutable_expanddims()->add_axes(-1); + *expand_layer->mutable_input()->Add() = input_name; + *expand_layer->mutable_output()->Add() = expand_output_name; + model_builder.AddLayer(std::move(expand_layer)); } - } else { - auto* padding_type = coreml_conv->mutable_valid(); - if (AutoPadType::NOTSET == auto_pad_type && onnx_pads != std::vector{0, 0, 0, 0}) { - // NOTSET is adding the explicit padding to the ValidPadding.paddingAmounts - auto* height_border = padding_type->mutable_paddingamounts()->add_borderamounts(); - height_border->set_startedgesize(onnx_pads[0]); - height_border->set_endedgesize(onnx_pads[2]); - auto* width_border = padding_type->mutable_paddingamounts()->add_borderamounts(); - width_border->set_startedgesize(onnx_pads[1]); - width_border->set_endedgesize(onnx_pads[3]); + + coreml_conv->set_outputchannels(weight_shape[0]); // M + coreml_conv->set_kernelchannels(weight_shape[1]); // C/Group + coreml_conv->add_kernelsize(weight_shape[2]); // H + coreml_conv->add_kernelsize(weight_shape[3]); // W + coreml_conv->set_ngroups(group); + *coreml_conv->mutable_stride() = {strides.cbegin(), strides.cend()}; + *coreml_conv->mutable_dilationfactor() = {dilations.cbegin(), dilations.cend()}; + + coreml_conv->set_isdeconvolution(false); + + // Add Padding + // Usually using autopadding is more efficient than using explicit padding + // Try to see if we can map explicit padding to auto padding + AutoPadType auto_pad_type; + ORT_RETURN_IF_ERROR(HandleAutoPad(input_shape, weight_shape[2], weight_shape[3], + onnx_pads, strides, dilations, + StringToAutoPadType(helper.Get("auto_pad", "NOTSET")), + auto_pad_type)); + + if (AutoPadType::SAME_UPPER == auto_pad_type || AutoPadType::SAME_LOWER == auto_pad_type) { + auto* padding_type = coreml_conv->mutable_same(); + if (AutoPadType::SAME_LOWER == auto_pad_type) { // default is SAME_UPPER + 
padding_type->set_asymmetrymode(COREML_SPEC::SamePadding_SamePaddingMode_TOP_LEFT_HEAVY); + } + } else { + auto* padding_type = coreml_conv->mutable_valid(); + if (AutoPadType::NOTSET == auto_pad_type && onnx_pads != std::vector{0, 0, 0, 0}) { + // NOTSET is adding the explicit padding to the ValidPadding.paddingAmounts + auto* height_border = padding_type->mutable_paddingamounts()->add_borderamounts(); + height_border->set_startedgesize(onnx_pads[0]); + height_border->set_endedgesize(onnx_pads[2]); + auto* width_border = padding_type->mutable_paddingamounts()->add_borderamounts(); + width_border->set_startedgesize(onnx_pads[1]); + width_border->set_endedgesize(onnx_pads[3]); + } } - } - // Add weight - ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_conv->mutable_weights(), weight_tensor)); + // Add weight + ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_conv->mutable_weights(), weight_tensor)); - // Add bias if present - if (input_defs.size() > 2) { - coreml_conv->set_hasbias(true); - const auto& bias_tensor = *model_builder.GetInitializerTensors().at(input_defs[2]->Name()); - ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_conv->mutable_bias(), bias_tensor)); - } + // Add bias if present + if (input_defs.size() > 2) { + coreml_conv->set_hasbias(true); + const auto& bias_tensor = *model_builder.GetConstantInitializer(input_defs[2]->Name()); + ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_conv->mutable_bias(), bias_tensor)); + } - if (is_1d_conv) { - std::string conv_output_name = model_builder.GetUniqueName(node.Name() + "_conv_output"); - *layer->mutable_input()->Add() = expand_output_name; - *layer->mutable_output()->Add() = conv_output_name; - model_builder.AddLayer(std::move(layer)); - - // Add a squeeze layer here. Since CoreML only supports 2d conv and we expanded the dimension by 1 before, - // we need to squeeze it back from NxCxHx1->NxCxH. 
- const auto squeeze_layer_name = model_builder.GetUniqueName(MakeString(node.Name(), "_Conv_squeeze")); - std::unique_ptr squeeze_layer = CreateNNLayer(squeeze_layer_name); - squeeze_layer->mutable_squeeze()->add_axes(-1); - *squeeze_layer->mutable_input()->Add() = conv_output_name; - *squeeze_layer->mutable_output()->Add() = output_name; - model_builder.AddLayer(std::move(squeeze_layer)); - } else { - *layer->mutable_input()->Add() = input_name; - *layer->mutable_output()->Add() = output_name; - model_builder.AddLayer(std::move(layer)); + if (is_1d_conv) { + std::string conv_output_name = model_builder.GetUniqueName(node.Name() + "_conv_output"); + *layer->mutable_input()->Add() = expand_output_name; + *layer->mutable_output()->Add() = conv_output_name; + model_builder.AddLayer(std::move(layer)); + + // Add a squeeze layer here. Since CoreML only supports 2d conv and we expanded the dimension by 1 before, + // we need to squeeze it back from NxCxHx1->NxCxH. + auto squeeze_layer = model_builder.CreateNNLayer(node, "_Conv_squeeze"); + squeeze_layer->mutable_squeeze()->add_axes(-1); + *squeeze_layer->mutable_input()->Add() = conv_output_name; + *squeeze_layer->mutable_output()->Add() = output_name; + model_builder.AddLayer(std::move(squeeze_layer)); + } else { + *layer->mutable_input()->Add() = input_name; + *layer->mutable_output()->Add() = output_name; + model_builder.AddLayer(std::move(layer)); + } } return Status::OK(); } -#endif - -// Operator support related bool ConvOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { @@ -186,23 +223,73 @@ bool ConvOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPara const auto& input_defs = node.InputDefs(); const auto& weight_name = input_defs[1]->Name(); - const auto& initializers = input_params.graph_viewer.GetAllInitializedTensors(); - if (Contains(initializers, weight_name)) { - const auto& tensor = 
*initializers.at(weight_name); - if (tensor.dims().size() != 4 && tensor.dims().size() != 3) { - LOGS(logger, VERBOSE) << "Conv [" << name << "] dimension: " << tensor.dims().size() - << " Only conv 2d and conv 1d are supported."; + const auto* weight = input_params.graph_viewer.GetConstantInitializer(weight_name); + +#if defined(COREML_ENABLE_MLPROGRAM) + if (input_params.create_mlprogram) { + // ML Program supports non-const weight, 1D, 2D and 3D. + // keep to 1D and 2D for consistency with the NeuralNetwork implementation for now. + // add 3D support as/when needed. + } else +#endif // defined (COREML_ENABLE_MLPROGRAM) + { + if (!weight) { + LOGS(logger, VERBOSE) << "The weight of Conv [" << name << "] must be a constant initializer"; return false; } - } else { - LOGS(logger, VERBOSE) << "The weight of Conv [" << name << "] must be known"; + } + + // use the weight for the shape as it should always be known + const auto* weight_shape = input_defs[1]->Shape(); + int64_t num_dims = weight_shape ? weight_shape->dim_size() : -1; + + // ONNX spec requires N and C as first 2 dims + if (num_dims != 3 && num_dims != 4) { + LOGS(logger, VERBOSE) << "Conv [" << name << "] is " << num_dims - 2 << "D. " + << "Only 1D and 2D Conv are supported currently."; return false; } - if (input_defs.size() > 2) { - const auto& bias_name = input_defs[2]->Name(); - if (!Contains(initializers, bias_name)) { - LOGS(logger, VERBOSE) << "The bias of Conv [" << name << "] must be a constant initializer"; + if (input_defs.size() > 2 && !input_params.graph_viewer.GetConstantInitializer(input_defs[2]->Name())) { + LOGS(logger, VERBOSE) << "The bias of Conv [" << name << "] must be a constant initializer"; + return false; + } + + NodeAttrHelper helper(node); + +#if defined(COREML_ENABLE_MLPROGRAM) + // spec says same_lower is supported in CoreML 5. it lies. 
CoreML 6 is required otherwise you get + // `Unexpected value for parameter pad_type[0] "same_lower" not in ("custom", "same", "valid").` + // We _could_ manually calculate the pads, but not implementing that until we have a real use case to justify + // the effort as it's not clear how common usage of same_lower is. + if (input_params.create_mlprogram && input_params.coreml_version < 6) { + if (StringToAutoPadType(helper.Get("auto_pad", "NOTSET")) == AutoPadType::SAME_LOWER) { + LOGS(logger, VERBOSE) << "Pad type of SAME_LOWER [" << name << "] is not supported until CoreML 6." + << "Available version is CoreML " << input_params.coreml_version; + return false; + } + } +#endif + + // there's no equivalent to allow a manual kernel shape in CoreML. + // it's OK if a specified kernel_shape matches kH and kW dims of the weight input. + auto kernel_shape = helper.GetInt64s("kernel_shape"); + if (kernel_shape) { + bool valid = true; + if (static_cast(kernel_shape->size()) == num_dims - 2) { + for (int i = 0; i < num_dims - 2; ++i) { + // check the specified kernel shape matches the weight shape. skip the initial N and C dims in the latter. 
+ if ((*kernel_shape)[i] != weight_shape->dim()[i + 2].dim_value()) { + valid = false; + break; + } + } + } else { + valid = false; + } + + if (!valid) { + LOGS(logger, VERBOSE) << "Conv [" << name << "] kernel_shape attribute does not match the weight shape"; return false; } } diff --git a/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc index a4ad1c31b5027..1eba312b2577b 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc @@ -4,37 +4,26 @@ #include "core/common/safeint.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime { namespace coreml { class DepthToSpaceOpBuilder : public BaseOpBuilder { - // Add operator related - private: -#ifdef __APPLE__ Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; }; -// Add operator related - -#ifdef __APPLE__ Status DepthToSpaceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& /* logger */) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); const auto& input_defs = node.InputDefs(); const auto& output_defs = node.OutputDefs(); @@ -54,9 +43,6 @@ Status 
DepthToSpaceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif - -// Operator support related bool DepthToSpaceOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /*input_params*/, const logging::Logger& logger) const { diff --git a/onnxruntime/core/providers/coreml/builders/impl/flatten_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/flatten_op_builder.cc index b303fe7884cb1..f0adb70587bcf 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/flatten_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/flatten_op_builder.cc @@ -3,39 +3,26 @@ #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime { namespace coreml { class FlattenOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; }; -// Add operator related - -#ifdef __APPLE__ - Status FlattenOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, - const logging::Logger& logger) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + const logging::Logger& /*logger*/) const { + std::unique_ptr layer = model_builder.CreateNNLayer(node); // Note: ONNX Flatten corresponds to CoreML FlattenTo2DLayerParams auto* 
coreml_flatten = layer->mutable_flattento2d(); @@ -51,9 +38,6 @@ Status FlattenOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return Status::OK(); } -#endif - -// Operator support related bool FlattenOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /*input_params*/, const logging::Logger& logger) const { diff --git a/onnxruntime/core/providers/coreml/builders/impl/gather_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/gather_op_builder.cc index 9c7ec306ca093..7d32675e3e510 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/gather_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/gather_op_builder.cc @@ -2,34 +2,24 @@ // Licensed under the MIT License. #include "core/providers/coreml/builders/impl/base_op_builder.h" - #include "core/providers/coreml/builders/op_builder_factory.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/shared/utils/utils.h" -#if defined(__APPLE__) -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime::coreml { class GatherOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: - bool HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const override; + bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; }; -// Add operator related -#if defined(__APPLE__) namespace { int64_t GetAxisAttribute(const Node& node) { NodeAttrHelper node_attr_helper{node}; @@ -38,8 +28,8 @@ int64_t GetAxisAttribute(const Node& 
node) { } // namespace Status GatherOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, - const logging::Logger& logger) const { - auto layer = CreateNNLayer(model_builder, node); + const logging::Logger& /*logger*/) const { + auto layer = model_builder.CreateNNLayer(node); layer->mutable_gather()->set_axis(GetAxisAttribute(node)); *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); // data *layer->mutable_input()->Add() = node.InputDefs()[1]->Name(); // indices @@ -47,10 +37,9 @@ Status GatherOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif // defined(__APPLE__) -// Operator support related -bool GatherOpBuilder::HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const { +bool GatherOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& /*input_params*/, + const logging::Logger& logger) const { int32_t input_type; if (!GetType(*node.InputDefs()[0], input_type, logger)) return false; diff --git a/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc index 71b08db6d44d8..8daf64dc4a457 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc @@ -7,46 +7,66 @@ #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/impl/builder_utils.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/providers/coreml/builders/impl/builder_utils.h" -#include "core/providers/coreml/builders/model_builder.h" 
-#endif - namespace onnxruntime { namespace coreml { class GemmOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - public: void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: - bool IsOpSupportedImpl(const Node& /* node */, const OpBuilderInputParams& /* input_params */, - const logging::Logger& /* logger */) const override; -}; + bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const override; -// Add operator related + bool SupportsMLProgram() const override { return true; } +}; -#ifdef __APPLE__ void GemmOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { const auto& op = node.OpType(); const auto& input_defs(node.InputDefs()); - // We have already embedded the weights (matrix B and C(if any)) into the coreml layer - // No need to copy them later to reduce memory consumption - model_builder.AddInitializerToSkip(input_defs[1]->Name()); - if (op == "Gemm" && input_defs.size() > 2) { - model_builder.AddInitializerToSkip(input_defs[2]->Name()); + const bool is_gemm = op == "Gemm"; + +#if defined(COREML_ENABLE_MLPROGRAM) + if (model_builder.CreateMLProgram()) { + // we have to transpose the weight input of Gemm if transB is false, and potentially override the bias shape + if (is_gemm) { + NodeAttrHelper helper(node); + const auto transB = helper.Get("transB", 0); + if (transB == 0) { + model_builder.AddInitializerToSkip(input_defs[1]->Name()); + } + + if (input_defs.size() > 2) { + // ONNX spec requires B to be 2D and we required it to be a constant initializer so reading N this way is safe + // B is {K, N] by default. or {N, K} if transB is true + int N_dim = transB ? 
0 : 1; + int64_t N = input_defs[1]->Shape()->dim().at(N_dim).dim_value(); + + const auto& bias_name = input_defs[2]->Name(); + const auto& bias = *model_builder.GetConstantInitializer(bias_name); + if (bias.dims_size() != 1 || bias.dims(0) != N) { + // we have to override the shape/duplicate data to convert {}, {1} or {1, N} to 1D {N} + // when adding the Gemm operation so skip adding the original initializer + model_builder.AddInitializerToSkip(bias_name); + } + } + } + } else +#endif // defined(COREML_ENABLE_MLPROGRAM) + { + // We have already embedded the weights (matrix B and C(if any)) into the coreml layer + // No need to copy them later to reduce memory consumption + model_builder.AddInitializerToSkip(input_defs[1]->Name()); + if (is_gemm && input_defs.size() > 2) { + model_builder.AddInitializerToSkip(input_defs[2]->Name()); + } } } @@ -70,156 +90,258 @@ static Status GetTensorFloatDataTransposed(const ONNX_NAMESPACE::TensorProto& te } Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, - const logging::Logger& /* logger */) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + const logging::Logger& logger) const { + std::unique_ptr layer = model_builder.CreateNNLayer(node); const auto& op_type = node.OpType(); const auto& input_defs = node.InputDefs(); - const auto& b_tensor = *model_builder.GetInitializerTensors().at(input_defs[1]->Name()); - const auto& b_shape = b_tensor.dims(); - - auto* coreml_inner_product = layer->mutable_innerproduct(); - - // The coreml innerproduct weight (matrix B) is stored transposed - // - for MatMul and Gemm (transB = 0), the coreml weight is B' - // - for Gemm (transB = 1), the coreml weight is B - if (op_type == "MatMul") { - coreml_inner_product->set_inputchannels(b_shape[0]); - coreml_inner_product->set_outputchannels(b_shape[1]); - // Add weight (b of MatMul) - std::vector b_transposed; - ORT_RETURN_IF_ERROR(GetTensorFloatDataTransposed(b_tensor, b_transposed)); 
- CreateCoreMLWeight(*coreml_inner_product->mutable_weights(), b_transposed); - } else { // Gemm - NodeAttrHelper helper(node); - const auto transB = helper.Get("transB", 0); - if (transB == 0) { - coreml_inner_product->set_inputchannels(b_shape[0]); - coreml_inner_product->set_outputchannels(b_shape[1]); + const auto& a = *input_defs[0]; + const auto& b = *input_defs[1]; + const auto* b_initializer = model_builder.GetConstantInitializer(b.Name()); // MLProgram MatMul may not be constant + + const bool is_matmul = op_type == "MatMul"; + const bool is_gemm = op_type == "Gemm"; + + NodeAttrHelper helper(node); + const auto transB = is_gemm ? helper.Get("transB", 0) : 0; + + std::vector b_shape; + ORT_IGNORE_RETURN_VALUE(GetShape(b, b_shape, logger)); + int64_t b0 = -1, b1 = -1; + + // ML Program MatMul supports N-D input + if (model_builder.CreateMLProgram() && is_matmul) { + if (b_shape.size() == 1) { + // B is treated as {b_shape[0], 1} according to the numpy rules. + b0 = b_shape[0]; + b1 = 1; + } else { + // last 2 dims are used + b0 = b_shape[b_shape.size() - 2]; + b1 = b_shape[b_shape.size() - 1]; + } + } else { + // we only support 2D input + b0 = b_shape[0]; + b1 = b_shape[1]; + } + + // B is {K, N} in ONNX spec by default, or {N, K} in Gemm if transB is true + const auto K = transB ? b1 : b0; + const auto N = transB ? b0 : b1; + +#if defined(COREML_ENABLE_MLPROGRAM) + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; + + if (is_gemm) { + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.linear.linear + auto gemm_op = model_builder.CreateOperation(node, "linear"); + AddOperationInput(*gemm_op, "x", a.Name()); + + // CoreML takes weight input as {N, K} which is the reverse of ONNX. + // if transB is true the input weight is {N, K} so can be added directly. 
+ if (transB) { + AddOperationInput(*gemm_op, "weight", b.Name()); + } else { + // transpose from {K, N} to {N, K} + std::vector weight_nk; + std::vector weight_nk_shape = {N, K}; + ORT_RETURN_IF_ERROR(GetTensorFloatDataTransposed(*b_initializer, weight_nk)); + + AddOperationInput(*gemm_op, "weight", + model_builder.AddConstant(gemm_op->type(), b.Name() + "_t", weight_nk, weight_nk_shape)); + } + + if (input_defs.size() == 3) { + const auto& bias_arg = *input_defs[2]; + const auto& bias = *model_builder.GetConstantInitializer(bias_arg.Name()); + + // CoreML linear op requires bias to be 1D tensor of size N + if (bias.dims_size() == 1 && bias.dims().at(0) == N) { + // can use existing initializer + AddOperationInput(*gemm_op, "bias", bias_arg.Name()); + } else { + Initializer unpacked_tensor(bias); + auto bias_data = unpacked_tensor.DataAsSpan(); + std::string_view bias_data_name; + if (bias_data.size() == 1) { + // expand scalar to N + std::vector expanded_bias_data(N, bias_data[0]); + bias_data_name = model_builder.AddConstant(gemm_op->type(), "bias", expanded_bias_data); + } else { + // can use data as-is but need to adjust shape (inferred by AddConstant as {bias_data.size()}) + bias_data_name = model_builder.AddConstant(gemm_op->type(), "bias", bias_data); + } + + AddOperationInput(*gemm_op, "bias", bias_data_name); + } + } + + AddOperationOutput(*gemm_op, *node.OutputDefs()[0]); + model_builder.AddOperation(std::move(gemm_op)); + } else { + // CoreML implementation is the same as ONNX MatMul. + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.linear.matmul + auto matmul_op = model_builder.CreateOperation(node, "matmul"); + AddOperationInput(*matmul_op, "x", a.Name()); + AddOperationInput(*matmul_op, "y", b.Name()); + + // once again the spec lies and says transpose_y and transpose_x are optional... 
+ auto false_value_name = model_builder.AddScalarConstant(matmul_op->type(), "false", false); + AddOperationInput(*matmul_op, "transpose_x", false_value_name); + AddOperationInput(*matmul_op, "transpose_y", false_value_name); + + AddOperationOutput(*matmul_op, *node.OutputDefs()[0]); + model_builder.AddOperation(std::move(matmul_op)); + } + } else +#endif // defined(COREML_ENABLE_MLPROGRAM) + { + auto* coreml_inner_product = layer->mutable_innerproduct(); + + *layer->mutable_input()->Add() = a.Name(); + + coreml_inner_product->set_inputchannels(K); + coreml_inner_product->set_outputchannels(N); + + // CoreML takes weight input as {N, K} which is the reverse of ONNX. + // if Gemm's transB is true the input weight is {N, K} and can be added directly. + if (transB) { + ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_inner_product->mutable_weights(), *b_initializer)); + } else { std::vector b_transposed; - ORT_RETURN_IF_ERROR(GetTensorFloatDataTransposed(b_tensor, b_transposed)); + ORT_RETURN_IF_ERROR(GetTensorFloatDataTransposed(*b_initializer, b_transposed)); CreateCoreMLWeight(*coreml_inner_product->mutable_weights(), b_transposed); - } else { - coreml_inner_product->set_inputchannels(b_shape[1]); - coreml_inner_product->set_outputchannels(b_shape[0]); - // Add weight (b of MatMul) - ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_inner_product->mutable_weights(), b_tensor)); } - // Add bias if present - if (input_defs.size() > 2) { + if (is_gemm && input_defs.size() > 2) { + // Add bias coreml_inner_product->set_hasbias(true); - const auto& bias_tensor = *model_builder.GetInitializerTensors().at(input_defs[2]->Name()); - ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_inner_product->mutable_bias(), bias_tensor)); + const auto& bias_tensor = *model_builder.GetConstantInitializer(input_defs[2]->Name()); + + // if scalar, or single value expand to 1D tensor of size N + // IsOpSupportedImpl enforces it's scalar, {1}, {N}, or {1, N}. 
+ Initializer unpacked_tensor(bias_tensor); + auto bias_data = unpacked_tensor.DataAsSpan(); + if (bias_data.size() == 1 && N > 1) { + std::vector expanded_bias_data(N, bias_data[0]); + CreateCoreMLWeight(*coreml_inner_product->mutable_bias(), expanded_bias_data); + } else { + CreateCoreMLWeight(*coreml_inner_product->mutable_bias(), bias_data); + } } - } - *layer->mutable_input()->Add() = input_defs[0]->Name(); - *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); + *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); + model_builder.AddLayer(std::move(layer)); + } - model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif - -// Operator support related bool GemmOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { const auto& op_type = node.OpType(); const auto& input_defs(node.InputDefs()); + const bool is_matmul = op_type == "MatMul"; + const bool is_gemm = op_type == "Gemm"; + size_t a_idx = 0, b_idx = 1, c_idx = 2; // A*B+C - const auto& initializers = input_params.graph_viewer.GetAllInitializedTensors(); - if (!Contains(initializers, input_defs[b_idx]->Name())) { - LOGS(logger, VERBOSE) << "B of Gemm/Matmul must be an initializer tensor"; + std::vector a_shape; + if (!GetShape(*input_defs[a_idx], a_shape, logger)) { return false; } - std::vector a_shape; - { - if (!GetShape(*input_defs[a_idx], a_shape, logger)) - return false; - - if (a_shape.size() != 2) { - LOGS(logger, VERBOSE) << "A must be 2D"; - return false; - } + std::vector b_shape; + if (!GetShape(*input_defs[b_idx], b_shape, logger)) { + return false; + } - // TODO is it ok if the shape is dynamic and empty? 
- if (Product(a_shape) == 0) { - LOGS(logger, VERBOSE) << "A must be non-empty"; + if (!input_params.graph_viewer.GetConstantInitializer(input_defs[b_idx]->Name())) { + if (input_params.create_mlprogram && is_matmul) { + // ML Program MatMul allows non-constant B input + } else { + LOGS(logger, VERBOSE) << op_type << " B input must be a constant initializer"; return false; } } - std::vector b_shape; - { - if (!GetShape(*input_defs[b_idx], b_shape, logger)) - return false; - - if (b_shape.size() != 2) { - LOGS(logger, VERBOSE) << "B must be 2D"; - return false; - } + if (is_matmul) { + if (input_params.create_mlprogram) { + // ML Program matmul op has numpy semantics the same as the ONNX spec so we can use directly + } else { + // we could potentially support 1D and 3D if required. beyond 3D the dims that merge diverge. + // https://github.com/apple/coremltools/blob/1931758aae383c83daddfc56f11a24a9d2bf4b87/coremltools/converters/onnx/_operators.py#L1607 + // https://github.com/apple/coremltools/blob/1931758aae383c83daddfc56f11a24a9d2bf4b87/coremltools/converters/mil/backend/nn/op_mapping.py#L1374 + // https://apple.github.io/coremltools/mlmodel/Format/NeuralNetwork.html#innerproductlayerparams + if (a_shape.size() != 2 || b_shape.size() != 2) { + LOGS(logger, VERBOSE) << "a and b inputs must be 2D. "; + return false; + } - if (Product(b_shape) == 0) { - LOGS(logger, VERBOSE) << "B must be non-empty"; - return false; + if (input_defs.size() > 2) { + LOGS(logger, VERBOSE) << "MatMul with C input is not supported"; + return false; + } } } - if (op_type == "Gemm") { + if (is_gemm) { + // A and B are 2D due to the ONNX spec NodeAttrHelper helper(node); const auto transA = helper.Get("transA", 0); const auto transB = helper.Get("transB", 0); const auto alpha = helper.Get("alpha", 1.0f); const auto beta = helper.Get("beta", 1.0f); + + // TODO: We can support transA, alpha and beta by using multiple layers/operations if needed. 
if (!(transA == 0 && alpha == 1.f && beta == 1.f)) { - LOGS(logger, VERBOSE) << "Only transA == 0, alpha == 1.0 " - << "and beta == 1.0 is supported." + LOGS(logger, VERBOSE) << "Only support for transA == 0, alpha == 1.0 " + << "and beta == 1.0 is currently implemented." << " transA " << transA << " alpha " << alpha << " beta " << beta; return false; } - // C of Gemm - // For now we only support {n} or {1,n} tensor if (input_defs.size() == 3) { - if (!Contains(initializers, input_defs[c_idx]->Name())) { - LOGS(logger, VERBOSE) << "C of Gemm must be an initializer tensor"; + if (!input_params.graph_viewer.GetConstantInitializer(input_defs[c_idx]->Name())) { + LOGS(logger, VERBOSE) << "C of Gemm must be a constant initializer"; return false; } std::vector c_shape; - if (!GetShape(*input_defs[c_idx], c_shape, logger)) + if (!GetShape(*input_defs[c_idx], c_shape, logger)) { return false; + } - size_t c_dim = c_shape.size(); + // B is {K, N} in ONNX spec by default, or {N, K} in Gemm if transB is true + const auto N = transB ? b_shape[0] : b_shape[1]; - if (c_dim == 0) { - LOGS(logger, VERBOSE) << "C of Gemm cannot be a scalar"; - return false; - } + size_t c_rank = c_shape.size(); - if (c_dim != 1) { - // If C is a (2+)d tensor, it must have the format {1, 1, ..., 1, n} - // where every except the last dimension should be 1 - for (size_t i = 0; i < c_dim - 1; ++i) { - if (c_shape[i] != 1) { - LOGS(logger, VERBOSE) << "C of Gemm must be a vector or a tensor with only last dimension != 1"; - return false; + // allowed: scalar, or 1D where the value is 1 or N, 2D with shape {1, N} + bool c_valid = false; + switch (c_rank) { + case 0: + c_valid = true; + break; + case 1: + if (c_shape[0] == 1 || c_shape[0] == N) { + c_valid = true; } - } + break; + case 2: + if (c_shape[0] == 1 && c_shape[1] == N) { + c_valid = true; + } + break; } - auto c_size = c_shape[c_dim - 1]; - if (c_size != (transB == 0 ? 
b_shape[1] : b_shape[0])) { - LOGS(logger, VERBOSE) << "C of Gemm must be a vector of b_shape[" - << (transB == 0 ? "1" : "0") << "]" - << " b_shape: [" << b_shape[0] << ", " << b_shape[1] << "]" - << " c_size: " << c_size; + if (!c_valid) { + LOGS(logger, VERBOSE) << "Shape of C Gemm input must be {}, {1}, {N}, or {1, N}. N:" << N << " C shape:" + << Shape2String(c_shape); return false; } diff --git a/onnxruntime/core/providers/coreml/builders/impl/pad_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/pad_op_builder.cc index ba12600e8bc40..99d6f01cb8c5b 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/pad_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/pad_op_builder.cc @@ -7,30 +7,20 @@ #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime { namespace coreml { class PadOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - public: void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; @@ -64,9 +54,6 @@ static InlinedVector GetPaddingAxesData(const InitializedTensorSet& ini return axes_tensor_data; } -// Add operator related - -#ifdef __APPLE__ void PadOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { 
model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name()); // pads model_builder.AddInitializerToSkip(node.InputDefs()[2]->Name()); // constant_value @@ -78,7 +65,7 @@ void PadOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node Status PadOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); auto* coreml_pad = layer->mutable_padding(); auto* constant_padding_type = coreml_pad->mutable_constant(); // CoreML::Specification::PaddingLayerParams_PaddingConstant @@ -122,9 +109,6 @@ Status PadOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return Status::OK(); } -#endif - -// Operator support related bool PadOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { diff --git a/onnxruntime/core/providers/coreml/builders/impl/pool_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/pool_op_builder.cc index fd1c77c851e6f..17910ba6fd486 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/pool_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/pool_op_builder.cc @@ -4,132 +4,191 @@ #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/impl/builder_utils.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/providers/coreml/builders/impl/builder_utils.h" -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime { namespace coreml { class PoolOpBuilder : public BaseOpBuilder { - // 
Add operator related - private: -#ifdef __APPLE__ Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; -}; -// Add operator related + bool SupportsMLProgram() const override { return true; } +}; -#ifdef __APPLE__ Status PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); - - auto* coreml_pool = layer->mutable_pooling(); const auto& op_type = node.OpType(); const auto& input_defs = node.InputDefs(); - bool is_global_pooling = false; - if (op_type == "GlobalAveragePool") { - is_global_pooling = true; - coreml_pool->set_type(COREML_SPEC::PoolingLayerParams_PoolingType_AVERAGE); - } else if (op_type == "GlobalMaxPool") { - is_global_pooling = true; - coreml_pool->set_type(COREML_SPEC::PoolingLayerParams_PoolingType_MAX); - } else if (op_type == "AveragePool") { - coreml_pool->set_type(COREML_SPEC::PoolingLayerParams_PoolingType_AVERAGE); - } else if (op_type == "MaxPool") { - coreml_pool->set_type(COREML_SPEC::PoolingLayerParams_PoolingType_MAX); - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "PoolOpBuilder, unknown op: ", op_type); - } +#if defined(COREML_ENABLE_MLPROGRAM) + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; + + std::string_view coreml_op_type; + bool is_global = false; + bool is_avg_pool = false; + if (op_type == "GlobalAveragePool") { + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.reduction.reduce_mean + coreml_op_type = "reduce_mean"; + is_global = true; + } else if (op_type == "GlobalMaxPool") { + // 
https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.reduction.reduce_max + coreml_op_type = "reduce_max"; + is_global = true; + } else if (op_type == "AveragePool") { + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.pool.avg_pool + coreml_op_type = "avg_pool"; + is_avg_pool = true; + } else if (op_type == "MaxPool") { + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.pool.max_pool + coreml_op_type = "max_pool"; + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "PoolOpBuilder, unexpected op: ", op_type); + } - if (is_global_pooling) { - coreml_pool->set_globalpooling(true); - coreml_pool->mutable_valid(); - } else { // AveragePool or MaxPool - NodeAttrHelper helper(node); - const auto kernel_shape = helper.Get("kernel_shape", std::vector{0, 0}); - const auto strides = helper.Get("strides", std::vector{1, 1}); - const auto onnx_pads = helper.Get("pads", std::vector{0, 0, 0, 0}); - - coreml_pool->add_kernelsize(kernel_shape[0]); - coreml_pool->add_kernelsize(kernel_shape[1]); - coreml_pool->add_stride(strides[0]); - coreml_pool->add_stride(strides[1]); - coreml_pool->set_avgpoolexcludepadding(helper.Get("count_include_pad", 0) == 0); - coreml_pool->set_globalpooling(false); - - // Add Padding - // Usually using autopadding is more efficient than using explicit padding - // Try to see if we can map explicit padding to auto padding - std::vector input_shape; - ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); - AutoPadType auto_pad_type; - ORT_RETURN_IF_ERROR(HandleAutoPad(input_shape, kernel_shape[0], kernel_shape[1], - onnx_pads, strides, {1, 1} /* dilations */, - StringToAutoPadType(helper.Get("auto_pad", "NOTSET")), - auto_pad_type)); - - if 
(AutoPadType::SAME_UPPER == auto_pad_type || AutoPadType::SAME_LOWER == auto_pad_type) { - auto* padding_type = coreml_pool->mutable_same(); - if (AutoPadType::SAME_LOWER == auto_pad_type) { // default is SAME_UPPER - padding_type->set_asymmetrymode(COREML_SPEC::SamePadding_SamePaddingMode_TOP_LEFT_HEAVY); + std::unique_ptr op = model_builder.CreateOperation(node, coreml_op_type); + + AddOperationInput(*op, "x", input_defs[0]->Name()); + + if (is_global) { + // keep N and C dims, reduce the rest with keepdims=True. equivalent to the ONNX Global*Pool ops. + std::vector axes{2, 3}; // we only support 4D input currently. + AddOperationInput(*op, "axes", model_builder.AddConstant(op->type(), "axes", axes)); + AddOperationInput(*op, "keep_dims", model_builder.AddScalarConstant(op->type(), "keep_dims", true)); + } else { + NodeAttrHelper helper(node); + constexpr int num_spatial_dims = 2; // we only support 4D. -2 for N and C dims. + + AddPadTypeAndPads(*op, model_builder, op->type(), helper, num_spatial_dims); + + const auto kernel_shape = helper.GetInt64s("kernel_shape"); // required + AddOperationInput(*op, "kernel_sizes", model_builder.AddConstant(op->type(), "kernel_sizes", *kernel_shape)); + + // in theory all these values are optional according to the CoreML spec but simpler to just provide default + // values as the actual model compilation tends to require them. 
+ const auto strides = helper.Get("strides", std::vector(num_spatial_dims, 1)); + const bool ceil_mode = helper.Get("ceil_mode", int64_t(0)); // convert int64_t to bool + + AddOperationInput(*op, "strides", model_builder.AddConstant(op->type(), "strides", strides)); + AddOperationInput(*op, "ceil_mode", model_builder.AddScalarConstant(op->type(), "ceil_mode", ceil_mode)); + + if (is_avg_pool) { + const bool count_exclude_pad = helper.Get("count_include_pad", int64_t(0)) == 0; + AddOperationInput(*op, "exclude_padding_from_average", + model_builder.AddScalarConstant(op->type(), "count_exclude_pad", count_exclude_pad)); } + } + + AddOperationOutput(*op, *node.OutputDefs()[0]); + model_builder.AddOperation(std::move(op)); + + } else +#endif // defined(COREML_ENABLE_MLPROGRAM) + { + std::unique_ptr layer = model_builder.CreateNNLayer(node); + + auto* coreml_pool = layer->mutable_pooling(); + + bool is_global_pooling = false; + if (op_type == "GlobalAveragePool") { + is_global_pooling = true; + coreml_pool->set_type(COREML_SPEC::PoolingLayerParams_PoolingType_AVERAGE); + } else if (op_type == "GlobalMaxPool") { + is_global_pooling = true; + coreml_pool->set_type(COREML_SPEC::PoolingLayerParams_PoolingType_MAX); + } else if (op_type == "AveragePool") { + coreml_pool->set_type(COREML_SPEC::PoolingLayerParams_PoolingType_AVERAGE); + } else if (op_type == "MaxPool") { + coreml_pool->set_type(COREML_SPEC::PoolingLayerParams_PoolingType_MAX); } else { - auto* padding_type = coreml_pool->mutable_valid(); - if (AutoPadType::NOTSET == auto_pad_type && onnx_pads != std::vector{0, 0, 0, 0}) { - // NOTSET is adding the explicit padding to the ValidPadding.paddingAmounts - auto* height_border = padding_type->mutable_paddingamounts()->add_borderamounts(); - height_border->set_startedgesize(onnx_pads[0]); - height_border->set_endedgesize(onnx_pads[2]); - auto* width_border = padding_type->mutable_paddingamounts()->add_borderamounts(); - width_border->set_startedgesize(onnx_pads[1]); - 
width_border->set_endedgesize(onnx_pads[3]); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "PoolOpBuilder, unexpected op: ", op_type); + } + + if (is_global_pooling) { + coreml_pool->set_globalpooling(true); + coreml_pool->mutable_valid(); + } else { // AveragePool or MaxPool + NodeAttrHelper helper(node); + const auto kernel_shape = helper.Get("kernel_shape", std::vector{0, 0}); + const auto strides = helper.Get("strides", std::vector{1, 1}); + const auto onnx_pads = helper.Get("pads", std::vector{0, 0, 0, 0}); + + coreml_pool->add_kernelsize(kernel_shape[0]); + coreml_pool->add_kernelsize(kernel_shape[1]); + coreml_pool->add_stride(strides[0]); + coreml_pool->add_stride(strides[1]); + coreml_pool->set_avgpoolexcludepadding(helper.Get("count_include_pad", 0) == 0); + coreml_pool->set_globalpooling(false); + + // Add Padding + // Usually using autopadding is more efficient than using explicit padding + // Try to see if we can map explicit padding to auto padding + std::vector input_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); + AutoPadType auto_pad_type; + ORT_RETURN_IF_ERROR(HandleAutoPad(input_shape, kernel_shape[0], kernel_shape[1], + onnx_pads, strides, {1, 1} /* dilations */, + StringToAutoPadType(helper.Get("auto_pad", "NOTSET")), + auto_pad_type)); + + if (AutoPadType::SAME_UPPER == auto_pad_type || AutoPadType::SAME_LOWER == auto_pad_type) { + auto* padding_type = coreml_pool->mutable_same(); + if (AutoPadType::SAME_LOWER == auto_pad_type) { // default is SAME_UPPER + padding_type->set_asymmetrymode(COREML_SPEC::SamePadding_SamePaddingMode_TOP_LEFT_HEAVY); + } + } else { + auto* padding_type = coreml_pool->mutable_valid(); + if (AutoPadType::NOTSET == auto_pad_type && onnx_pads != std::vector{0, 0, 0, 0}) { + // NOTSET is adding the explicit padding to the ValidPadding.paddingAmounts + auto* height_border = padding_type->mutable_paddingamounts()->add_borderamounts(); + 
height_border->set_startedgesize(onnx_pads[0]); + height_border->set_endedgesize(onnx_pads[2]); + auto* width_border = padding_type->mutable_paddingamounts()->add_borderamounts(); + width_border->set_startedgesize(onnx_pads[1]); + width_border->set_endedgesize(onnx_pads[3]); + } } } - } - *layer->mutable_input()->Add() = input_defs[0]->Name(); - *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); + *layer->mutable_input()->Add() = input_defs[0]->Name(); + *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); + + model_builder.AddLayer(std::move(layer)); + } - model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif -// Operator support related -bool PoolOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /* input_params */, +bool PoolOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { const auto& op_type = node.OpType(); const auto& input_defs = node.InputDefs(); std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) + if (!GetShape(*input_defs[0], input_shape, logger)) { return false; + } + // TODO: ML Program supports 3D and 5D. Add if we have a use case for that. 
const auto input_size = input_shape.size(); if (input_size != 4) { - LOGS(logger, VERBOSE) - << op_type << " only supports rank-4 tensor, input [" - << input_defs[0]->Name() << "] has actual dim count " << input_size; + LOGS(logger, VERBOSE) << op_type << " only supports rank-4 tensor, input [" + << input_defs[0]->Name() << "] has actual dim count " << input_size; return false; } if (op_type == "AveragePool" || op_type == "MaxPool") { NodeAttrHelper helper(node); + const auto storage_order = helper.Get("storage_order", 0); if (storage_order == 1) { LOGS(logger, VERBOSE) << "storage_order == 1 is not supported"; @@ -141,12 +200,14 @@ bool PoolOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPara return false; } - // TODO, add support of the ceil_mode by adjusting the padding - // See https://stackoverflow.com/questions/59906456/in-pytorchs-maxpool2d-is-padding-added-depending-on-ceil-mode - // and https://github.com/apple/coremltools/blob/1931758aae383c83daddfc56f11a24a9d2bf4b87/coremltools/converters/mil/frontend/torch/ops.py#L621-L644 - if (helper.Get("ceil_mode", 0) == 1) { - LOGS(logger, VERBOSE) << "ceil_mode == 1 is not supported for pooling"; - return false; + if (!input_params.create_mlprogram) { + // TODO, add support of the ceil_mode by adjusting the padding + // See https://stackoverflow.com/questions/59906456/in-pytorchs-maxpool2d-is-padding-added-depending-on-ceil-mode + // and https://github.com/apple/coremltools/blob/1931758aae383c83daddfc56f11a24a9d2bf4b87/coremltools/converters/mil/frontend/torch/ops.py#L621-L644 + if (helper.Get("ceil_mode", 0) == 1) { + LOGS(logger, VERBOSE) << "ceil_mode == 1 is not supported for pooling"; + return false; + } } if (helper.Get("dilations", std::vector{1, 1}) != diff --git a/onnxruntime/core/providers/coreml/builders/impl/reduction_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/reduction_op_builder.cc index 6a2014e7952a2..32378b1f654d8 100644 --- 
a/onnxruntime/core/providers/coreml/builders/impl/reduction_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/reduction_op_builder.cc @@ -1,36 +1,27 @@ // Copyright (c) Shukant Pal. // Licensed under the MIT License. +#include "core/optimizer/initializer.h" #include "core/providers/common.h" -#include "core/providers/shared/utils/utils.h" - -#ifdef __APPLE__ -#include "core/providers/coreml/builders/model_builder.h" -#endif #include "core/providers/coreml/builders/helper.h" +#include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" -#include "core/optimizer/initializer.h" - -#include "base_op_builder.h" +#include "core/providers/shared/utils/utils.h" namespace onnxruntime { namespace coreml { class ReductionOpBuilder : public BaseOpBuilder { -#ifdef __APPLE__ - public: void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - private: + bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; }; -#ifdef __APPLE__ namespace { template void AddReductionParams(T* params, const std::vector& axes, bool keepdims, bool noop_with_empty_axes) { @@ -76,7 +67,7 @@ Status ReductionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, co const bool keepdims = helper.Get("keepdims", 1) != 0; const bool noop_with_empty_axes = helper.Get("noop_with_empty_axes", 0) != 0; - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); if (op_type == "ReduceSum") { AddReductionParams(layer->mutable_reducesum(), axes, keepdims, noop_with_empty_axes); @@ -93,7 +84,6 @@ Status ReductionOpBuilder::AddToModelBuilderImpl(ModelBuilder& 
model_builder, co model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif bool ReductionOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { @@ -124,4 +114,4 @@ void CreateReductionOpBuilder(const std::string& op_type, OpBuilderRegistrations } } // namespace coreml -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/builders/impl/reshape_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/reshape_op_builder.cc index 67aee73630cdb..27d24d9c21893 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/reshape_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/reshape_op_builder.cc @@ -1,90 +1,96 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/framework/tensorprotoutils.h" #include "core/optimizer/initializer.h" -#include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/impl/builder_utils.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/cpu/tensor/reshape_helper.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime { namespace coreml { class ReshapeOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - public: void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: bool 
IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; // Reshape opset 4- uses attributes for new shape which we do not support for now int GetMinSupportedOpSet(const Node& /* node */) const override { return 5; } -}; -// Add operator related + bool SupportsMLProgram() const override { return true; } +}; -#ifdef __APPLE__ void ReshapeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { + // Skip the second input which is the new shape as we always have to create a new version as the CoreML rules + // are different from ONNX. model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name()); } Status ReshapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); - const auto& input_defs = node.InputDefs(); - const auto& initializers(model_builder.GetInitializerTensors()); - const auto& target_shape_tensor = *initializers.at(input_defs[1]->Name()); - const int64_t* raw_target_shape = target_shape_tensor.int64_data().empty() - ? 
reinterpret_cast(target_shape_tensor.raw_data().data()) - : target_shape_tensor.int64_data().data(); - - const auto size = target_shape_tensor.dims()[0]; - TensorShapeVector target_shape{raw_target_shape, raw_target_shape + size}; std::vector input_shape; - ORT_RETURN_IF_NOT(GetStaticShape(*input_defs[0], input_shape, logger), "Cannot get shape"); - ReshapeHelper helper(TensorShape(input_shape), target_shape); - *layer->mutable_reshapestatic()->mutable_targetshape() = {target_shape.cbegin(), target_shape.cend()}; - *layer->mutable_input()->Add() = input_defs[0]->Name(); - *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); + ORT_RETURN_IF_NOT(GetStaticShape(*input_defs[0], input_shape, logger), "Cannot get shape of data"); + + const auto& data_name = input_defs[0]->Name(); + const auto& new_shape_name = input_defs[1]->Name(); + Initializer unpacked_tensor(*model_builder.GetConstantInitializer(new_shape_name)); + TensorShapeVector new_shape = ToShapeVector(unpacked_tensor.DataAsSpan()); + + // ReshapeHelper applies the ONNX rules to create the concrete output shape + ReshapeHelper helper(TensorShape(input_shape), new_shape); - model_builder.AddLayer(std::move(layer)); +#if defined(COREML_ENABLE_MLPROGRAM) + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; + + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.tensor_transformation.reshape + std::unique_ptr reshape_op = model_builder.CreateOperation(node, "reshape"); + + AddOperationInput(*reshape_op, "x", data_name); + AddOperationInput(*reshape_op, "shape", + model_builder.AddConstant(reshape_op->type(), "shape", ToConstSpan(new_shape))); + + AddOperationOutput(*reshape_op, *node.OutputDefs()[0]); + + model_builder.AddOperation(std::move(reshape_op)); + } else +#endif // defined(COREML_ENABLE_MLPROGRAM) + { + std::unique_ptr layer = model_builder.CreateNNLayer(node); + + 
*layer->mutable_reshapestatic()->mutable_targetshape() = {new_shape.cbegin(), new_shape.cend()}; + *layer->mutable_input()->Add() = data_name; + *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); + + model_builder.AddLayer(std::move(layer)); + } return Status::OK(); } -#endif - -// Operator support related bool ReshapeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& new_shape_name = input_defs[1]->Name(); - const auto& initializers = input_params.graph_viewer.GetAllInitializedTensors(); - if (!Contains(initializers, new_shape_name)) { + const auto* new_shape_tensor = input_params.graph_viewer.GetConstantInitializer(new_shape_name); + if (!new_shape_tensor) { + // ONNX has different rules around how -1 and 0 values are used/combined, and + // we can't check if those can be translated to CoreML if the shape is unknown. LOGS(logger, VERBOSE) << "New shape of reshape must be a constant initializer"; return false; } - const auto& new_shape_tensor = *initializers.at(new_shape_name); - Initializer unpacked_tensor(new_shape_tensor); + Initializer unpacked_tensor(*new_shape_tensor); auto new_shape = unpacked_tensor.DataAsSpan(); if (new_shape.empty()) { LOGS(logger, VERBOSE) << "New shape of reshape cannot be empty"; @@ -100,7 +106,7 @@ bool ReshapeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputP return false; } - // CoreML reshape doesn't support new shape with more than 5 dimensions + // CoreML reshape doesn't support new shape with more than 5 dimensions. if (new_shape.size() > 5) { LOGS(logger, VERBOSE) << "Reshape does not support new shape with rank greater than 5. 
Input shape: " << Shape2String(input_shape) << ", new shape: " << Shape2String(new_shape); @@ -109,7 +115,7 @@ bool ReshapeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputP // CoreML reshape does not support 0 as dimension NodeAttrHelper helper(node); - const bool allow_zero = helper.Get("allowzero ", 0) == 1; + const bool allow_zero = helper.Get("allowzero", 0) == 1; if (allow_zero) { if (std::find(new_shape.begin(), new_shape.end(), int64_t{0}) != new_shape.end()) { LOGS(logger, VERBOSE) << "Reshape does not support new shape with 0 as dimension when allowzero is enabled. " diff --git a/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc index 5f963dc30dd8f..6c2fcc2ace856 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc @@ -8,31 +8,21 @@ #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/cpu/tensor/reshape_helper.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime { namespace coreml { class ResizeOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - public: void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& 
logger) const override; @@ -41,7 +31,7 @@ class ResizeOpBuilder : public BaseOpBuilder { int GetMinSupportedOpSet(const Node& /* node */) const override { return 11; } }; -// Helper functions +namespace { bool GetResizeScales(const InitializedTensorSet& initializers, const Node& node, std::vector& scales, const logging::Logger&) { @@ -73,10 +63,8 @@ bool GetResizeOutputSizes(const InitializedTensorSet& initializers, sizes = std::vector(sizes_data.begin(), sizes_data.end()); return true; } +} // namespace -// Add operator related - -#ifdef __APPLE__ void ResizeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { // We don't really use ROI here, so add it to skipped list if it's an initializer tensor model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name()); // ROI @@ -96,7 +84,7 @@ void ResizeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const N Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); auto* coreml_upsample = layer->mutable_upsample(); NodeAttrHelper helper(node); @@ -110,7 +98,7 @@ Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const auto& input_defs = node.InputDefs(); const auto& initializers(model_builder.GetInitializerTensors()); - if (input_defs.size() == 3) { // use scales + if (input_defs.size() >= 3 && input_defs[2]->Exists()) { // use scales std::vector scales; ORT_RETURN_IF_NOT(GetResizeScales(initializers, node, scales, logger), "Error getting resize scales"); coreml_upsample->add_scalingfactor(static_cast(scales[2])); @@ -131,9 +119,6 @@ Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif - -// Operator support related bool ResizeOpBuilder::IsOpSupportedImpl(const 
Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { @@ -197,20 +182,24 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPa return false; } + bool using_scales = input_defs.size() >= 3 && input_defs[2]->Exists(); // scales - if (input_defs.size() == 3 && !Contains(initializers, input_defs[2]->Name())) { - LOGS(logger, VERBOSE) << "Input scales of Resize must be known"; + if (using_scales && !input_params.graph_viewer.GetConstantInitializer(input_defs[2]->Name())) { + LOGS(logger, VERBOSE) << "scales input of Resize must be a constant initializer"; return false; } // sizes - if (input_defs.size() > 3 && !Contains(initializers, input_defs[3]->Name())) { - LOGS(logger, VERBOSE) << "Input sizes of Resize must be known"; + if (!using_scales && + (input_defs.size() < 4 || + !input_defs[3]->Exists() || + !input_params.graph_viewer.GetConstantInitializer(input_defs[3]->Name()))) { + LOGS(logger, VERBOSE) << "sizes input of Resize must be a constant initializer"; return false; } // We want to check if the scales or sizes are not trying to resize on N/C channels here - if (input_defs.size() == 3) { // we are using scales + if (using_scales) { std::vector scales; if (!GetResizeScales(initializers, node, scales, logger)) return false; diff --git a/onnxruntime/core/providers/coreml/builders/impl/shape_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/shape_op_builder.cc index fd64153ffd283..a86e3d9538d87 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/shape_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/shape_op_builder.cc @@ -2,44 +2,30 @@ // Licensed under the MIT License. 
#include "core/providers/coreml/builders/impl/base_op_builder.h" - +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/shared/utils/utils.h" // for NodeAttrHelper -#if defined(__APPLE__) -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime::coreml { class ShapeOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; }; -// Add operator related -#if defined(__APPLE__) Status ShapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, - const logging::Logger& logger) const { - auto layer = CreateNNLayer(model_builder, node); + const logging::Logger& /*logger*/) const { + auto layer = model_builder.CreateNNLayer(node); layer->mutable_getshape(); *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif // defined(__APPLE__) -// Operator support related bool ShapeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /*input_params*/, const logging::Logger& logger) const { NodeAttrHelper node_attr_helper{node}; diff --git a/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc index 2c250b3cc9f5a..39bfbfe5bba1f 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc @@ -1,39 +1,31 @@ // Copyright (c) Microsoft Corporation. All rights reserved. 
// Licensed under the MIT License. -#include "core/providers/coreml/builders/impl/base_op_builder.h" - #include "core/optimizer/initializer.h" #include "core/providers/coreml/builders/helper.h" +#include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/cpu/tensor/slice_helper.h" #include "core/providers/shared/utils/utils.h" -#if defined(__APPLE__) -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime::coreml { class SliceOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - private: void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: int GetMinSupportedOpSet(const Node& /* node */) const override { // Before Slice-10, some inputs were attributes instead. We don't support that for now. 
return 10; } - bool HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const override; + bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& builder_params, const logging::Logger& logger) const override; }; @@ -62,7 +54,7 @@ Status PrepareSliceComputeMetadataFromConstantInitializers(const Node& slice_nod return Status::OK(); } - const auto* tensor_proto = graph_viewer.GetConstantInitializer(input_defs[input_idx]->Name(), true); + const auto* tensor_proto = graph_viewer.GetConstantInitializer(input_defs[input_idx]->Name()); ORT_RETURN_IF_NOT(tensor_proto, "Failed to get constant initializer."); Initializer unpacked_tensor(*tensor_proto, graph_viewer.ModelPath()); const auto data_type = unpacked_tensor.data_type(); @@ -107,9 +99,6 @@ bool ValidateSliceComputeMetadataForCoreML(const SliceOp::PrepareForComputeMetad } } // namespace -// Add operator related -#if defined(__APPLE__) - void SliceOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { const auto& input_defs = node.InputDefs(); @@ -132,7 +121,7 @@ Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const ORT_RETURN_IF_ERROR(PrepareSliceComputeMetadataFromConstantInitializers(node, model_builder.GetGraphViewer(), compute_metadata)); - auto layer = CreateNNLayer(model_builder, node); + auto layer = model_builder.CreateNNLayer(node); *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); auto* slice_static = layer->mutable_slicestatic(); @@ -163,10 +152,8 @@ Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const return Status::OK(); } -#endif // defined(__APPLE__) - -// Operator support related -bool SliceOpBuilder::HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) 
const { +bool SliceOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& /*input_params*/, + const logging::Logger& logger) const { int32_t input_type; if (!GetType(*node.InputDefs()[0], input_type, logger)) return false; diff --git a/onnxruntime/core/providers/coreml/builders/impl/softmax_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/softmax_op_builder.cc index c454a2a779f6e..d6584124c6aba 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/softmax_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/softmax_op_builder.cc @@ -1,43 +1,29 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/providers/coreml/builders/impl/base_op_builder.h" - #include "core/framework/tensorprotoutils.h" #include "core/providers/common.h" -#include "core/providers/coreml/shape_utils.h" -#include "core/providers/shared/utils/utils.h" - -#ifdef __APPLE__ +#include "core/providers/coreml/builders/impl/base_op_builder.h" #include "core/providers/coreml/builders/model_builder.h" -#endif #include "core/providers/coreml/builders/op_builder_factory.h" +#include "core/providers/coreml/shape_utils.h" +#include "core/providers/shared/utils/utils.h" namespace onnxruntime { namespace coreml { class SoftmaxOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; }; -// Add operator related - -#ifdef __APPLE__ - Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = 
model_builder.CreateNNLayer(node); const auto& input_name = node.InputDefs()[0]->Name(); const auto& output_name = node.OutputDefs()[0]->Name(); @@ -66,17 +52,15 @@ Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, target_shape.push_back(size_to_dimension); target_shape.push_back(size_from_dimension); - const auto reshape1_output_name = model_builder.GetUniqueName(MakeString(node.Name(), "reshape1_output")); + const auto reshape1_output_name = model_builder.GetUniqueName(node, "reshape1_output"); { // Add reshape layer - const auto softmax_reshape1_layer_name = - model_builder.GetUniqueName(MakeString(node.Name(), "_Softmax_reshape1")); - auto reshape_layer = CreateNNLayer(softmax_reshape1_layer_name); + auto reshape_layer = model_builder.CreateNNLayer(node, "_Softmax_reshape1"); *reshape_layer->mutable_reshapestatic()->mutable_targetshape() = {target_shape.cbegin(), target_shape.cend()}; *reshape_layer->mutable_input()->Add() = input_name; *reshape_layer->mutable_output()->Add() = reshape1_output_name; model_builder.AddLayer(std::move(reshape_layer)); } - const auto softmax_output_name = model_builder.GetUniqueName(MakeString(node.Name(), "softmax_output")); + const auto softmax_output_name = model_builder.GetUniqueName(node, "softmax_output"); { auto* coreml_softmaxnd = layer->mutable_softmaxnd(); coreml_softmaxnd->set_axis(-1); @@ -86,9 +70,7 @@ Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, } { // Add reshape back layer - const auto softmax_reshape2_layer_name = - model_builder.GetUniqueName(MakeString(node.Name(), "_Softmax_reshape2")); - auto reshape_layer = CreateNNLayer(softmax_reshape2_layer_name); + auto reshape_layer = model_builder.CreateNNLayer(node, "_Softmax_reshape2"); *reshape_layer->mutable_reshapestatic()->mutable_targetshape() = {data_shape.cbegin(), data_shape.cend()}; *reshape_layer->mutable_input()->Add() = softmax_output_name; *reshape_layer->mutable_output()->Add() = output_name; @@ 
-99,10 +81,6 @@ Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return Status::OK(); } -#endif - -// Operator support related - bool SoftmaxOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /* input_params */, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); diff --git a/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc index 56c87c883156b..0497357c45c54 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc @@ -1,35 +1,24 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/providers/coreml/builders/impl/base_op_builder.h" - #include "core/optimizer/initializer.h" #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" +#include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/shared/utils/utils.h" -#if defined(__APPLE__) -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime { namespace coreml { class SplitOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - private: void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; @@ -37,10 +26,6 @@ class SplitOpBuilder : public BaseOpBuilder { int 
GetMinSupportedOpSet(const Node& /* node */) const override { return 13; } }; -// Add operator related - -#ifdef __APPLE__ - void SplitOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { const auto& input_defs = node.InputDefs(); @@ -63,7 +48,7 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, // attribute introduced since opset 18 uint64_t num_outputs; - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); auto* coreml_splitnd = layer->mutable_splitnd(); coreml_splitnd->set_axis(axis); @@ -82,7 +67,7 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, coreml_splitnd->set_numsplits(num_outputs); } else { // note: for opset 18+ 'num_outputs' is a required attribute - num_outputs = narrow(helper.GetInt("num_outputs").value()); + num_outputs = narrow(helper.GetInt64("num_outputs").value()); // note: checked in IsOpSupportedImpl that ensures the dim value at splitting axis exists auto split_dim_size = data_shape[HandleNegativeAxis(axis, data_shape.size())]; uint64_t chunk_size = narrow((split_dim_size + num_outputs - 1) / num_outputs); @@ -111,10 +96,6 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return Status::OK(); } -#endif - -// Operator support related - bool SplitOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); @@ -159,7 +140,7 @@ bool SplitOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPar } } else { if (node.SinceVersion() >= 18) { - const auto num_outputs = helper.GetInt("num_outputs"); + const auto num_outputs = helper.GetInt64("num_outputs"); if (!num_outputs.has_value()) { LOGS(logger, VERBOSE) << "No 'num_outputs' provided. 
For split 18+, num_outputs is a required attribute."; return false; @@ -169,9 +150,10 @@ bool SplitOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPar << "CoreML SplitND requires at least 2 outputs. num_outputs: " << num_outputs.value(); return false; } - if (num_outputs.value() != static_cast(node.OutputDefs().size()) || num_outputs.value() > split_dims_at_axis) { - LOGS(logger, VERBOSE) << "Invalid num_outputs provided.\n." - << "The value should be smaller or equal to the size of dimension being split. num_outputs: " + if (num_outputs.value() != static_cast(node.OutputDefs().size()) || + num_outputs.value() > split_dims_at_axis) { + LOGS(logger, VERBOSE) << "Invalid num_outputs provided.\n. The value should be smaller or equal to the size " + "of dimension being split. num_outputs: " << num_outputs.value(); return false; } diff --git a/onnxruntime/core/providers/coreml/builders/impl/squeeze_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/squeeze_op_builder.cc index 2e14c85ce69c1..e9cc1c2dbf638 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/squeeze_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/squeeze_op_builder.cc @@ -1,48 +1,30 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#include + +#include "core/common/safeint.h" #include "core/framework/tensorprotoutils.h" #include "core/providers/common.h" -#include "core/providers/shared/utils/utils.h" -#include "core/optimizer/initializer.h" - -#ifdef __APPLE__ +#include "core/providers/coreml/builders/impl/base_op_builder.h" #include "core/providers/coreml/builders/model_builder.h" -#endif #include "core/providers/coreml/builders/op_builder_factory.h" - -#include "base_op_builder.h" +#include "core/providers/shared/utils/utils.h" +#include "core/optimizer/initializer.h" namespace onnxruntime { namespace coreml { class SqueezeOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - public: void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; }; -// Add operator related - -#ifdef __APPLE__ -void SqueezeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { - if (node.SinceVersion() > 12 && node.InputDefs().size() > 1) { - model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name()); - } -} - -/* static */ Status GetAxes(ModelBuilder& model_builder, const Node& node, std::vector& axes) { +namespace { +Status GetAxes(ModelBuilder& model_builder, const Node& node, std::vector& axes) { // Squeeze opset 13 use input as axes if (node.SinceVersion() > 12) { // If axes is not provided, return an empty axes as default to squeeze all @@ -62,11 +44,18 @@ void SqueezeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const return Status::OK(); } +} // namespace + +void SqueezeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { + if (node.SinceVersion() > 
12 && node.InputDefs().size() > 1) { + model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name()); + } +} Status SqueezeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& /* logger */) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); auto* coreml_squeeze = layer->mutable_squeeze(); std::vector axes; @@ -84,9 +73,6 @@ Status SqueezeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif - -// Operator support related bool SqueezeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& /*logger*/) const { diff --git a/onnxruntime/core/providers/coreml/builders/impl/transpose_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/transpose_op_builder.cc index 7d5018a19f74c..f6a61d55a3d63 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/transpose_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/transpose_op_builder.cc @@ -3,33 +3,23 @@ #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime { namespace coreml { class TransposeOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif }; -// Add operator related - -#ifdef __APPLE__ Status TransposeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, 
const logging::Logger& logger) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); NodeAttrHelper helper(node); std::vector perm = helper.Get("perm", std::vector()); @@ -51,7 +41,6 @@ Status TransposeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif void CreateTransposeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); diff --git a/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc index 660755b43c043..3403378d59114 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc @@ -3,32 +3,25 @@ #include "core/providers/common.h" -#ifdef __APPLE__ -#include "core/providers/coreml/builders/model_builder.h" -#endif #include "core/providers/coreml/builders/helper.h" +#include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" -#include "base_op_builder.h" - namespace onnxruntime { namespace coreml { class UnaryOpBuilder : public BaseOpBuilder { - private: -#ifdef __APPLE__ Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif }; -#ifdef __APPLE__ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& /* logger */) const { const auto& op_type(node.OpType()); const auto& input_defs(node.InputDefs()); - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); if (op_type == "Sqrt") { 
layer->mutable_unary()->set_type(COREML_SPEC::UnaryFunctionLayerParams::SQRT); @@ -45,9 +38,6 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif - -// Operator support related void CreateUnaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); @@ -55,4 +45,4 @@ void CreateUnaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op } } // namespace coreml -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/builders/model_builder.cc b/onnxruntime/core/providers/coreml/builders/model_builder.cc index 9c8b7bce507e4..eb4723a3b9746 100644 --- a/onnxruntime/core/providers/coreml/builders/model_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/model_builder.cc @@ -2,56 +2,675 @@ // Licensed under the MIT License. 
#include -#include - -#include "model_builder.h" -#include "helper.h" -#include "op_builder_factory.h" +#include "core/common/safeint.h" +#include "core/framework/tensorprotoutils.h" +#include "core/platform/env.h" #include "core/providers/common.h" +#include "core/providers/coreml/builders/model_builder.h" +#include "core/providers/coreml/builders/helper.h" +#include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/builders/impl/builder_utils.h" +#include "core/providers/coreml/coreml_provider_factory.h" #include "core/providers/coreml/model/host_utils.h" -#include "core/providers/coreml/model/model.h" #include "core/providers/coreml/shape_utils.h" +#if defined(COREML_ENABLE_MLPROGRAM) +// includes from coremltools-src in _deps +#include "modelpackage/src/ModelPackage.hpp" +#include "mlmodel/src/MILBlob/Blob/StorageWriter.hpp" +using MILBlob::Blob::StorageWriter; +#endif + +using namespace CoreML::Specification; + namespace onnxruntime { namespace coreml { -ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger, uint32_t coreml_flags) +namespace { +#if defined(COREML_ENABLE_MLPROGRAM) +// Should the initializer be written to file or kept as an immediate value +bool ShouldWriteInitializerToWeightsFile(const ONNX_NAMESPACE::TensorProto& tensor_proto) { + // https://github.com/apple/coremltools/blob/dbb0094fd0cb936469e35320bf37e866ef7a1da4/coremltools/converters/mil/backend/mil/load.py#L51-L57 + + bool use_weight_file = false; + + switch (tensor_proto.data_type()) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + case ONNX_NAMESPACE::TensorProto_DataType_UINT8: + case ONNX_NAMESPACE::TensorProto_DataType_INT8: { + auto num_elements = TensorShape(utils::GetTensorShapeFromTensorProto(tensor_proto)).Size(); + use_weight_file = num_elements >= 10; + break; + } + default: + break; + } + + return use_weight_file; +} + +// copy from the ONNX 
TensorProto to a CoreML field. +// T1 is the source type. T2 is the target type. If the types differ, T1 must be smaller than T2. +// e.g. uint32_t data can be written to RepeatedField +template +void CopyRawDataToRepeatedField(const ONNX_NAMESPACE::TensorProto& tensor_proto, + google::protobuf::RepeatedField& repeated_field) { + const auto& raw_data = tensor_proto.raw_data(); + const T1* data = reinterpret_cast(raw_data.data()); + const T1* data_end = data + (raw_data.size() / sizeof(T1)); + if constexpr (sizeof(T1) == sizeof(T2)) { + repeated_field.Add(data, data_end); + } else { + static_assert(sizeof(T1) < sizeof(T2)); + // we need to iterate over the data and copy to the repeated field, converting to T2 as we go. + repeated_field.Resize(data_end - data, T2(0)); + for (int i = 0; data != data_end; ++data, ++i) { + repeated_field[i] = static_cast(*data); + } + } +} + +// copy T data from the TensorProto.int32_t field to TensorValue.bytes +template +void CopyInt32DataToBytes(const ONNX_NAMESPACE::TensorProto& tensor_proto, MILSpec::TensorValue tensor_value) { + const int num_entries = tensor_proto.int32_data_size(); + std::string& bytes = *tensor_value.mutable_bytes()->mutable_values(); + bytes.resize(num_entries * sizeof(T)); + T* out = reinterpret_cast(bytes.data()); + + const int32_t* in = tensor_proto.int32_data().data(); + for (int i = 0; i < num_entries; ++i) { + out[i] = static_cast(in[i]); + } +} + +// copy T data from the TensorProto.uint64_data field to TensorValue.bytes +template +void CopyUInt64DataToBytes(const ONNX_NAMESPACE::TensorProto& tensor_proto, MILSpec::TensorValue tensor_value) { + const int num_entries = tensor_proto.uint64_data_size(); + std::string& bytes = *tensor_value.mutable_bytes()->mutable_values(); + bytes.resize(num_entries * sizeof(T)); + T* out = reinterpret_cast(bytes.data()); + + const uint64_t* in = tensor_proto.uint64_data().data(); + for (int i = 0; i < num_entries; ++i) { + out[i] = static_cast(in[i]); + } +} + +// NOTE: 
This supports all the ONNX data types. Weights in CoreML may not need all these +void CopyOnnxTensorToCoreMLTensor(const ONNX_NAMESPACE::TensorProto& tensor_proto, + MILSpec::TensorValue& tensor_value) { + bool has_raw_data = tensor_proto.has_raw_data(); + auto data_type = tensor_proto.data_type(); + + // handling based on + // ONNX TensorProto field usage + // https://github.com/onnx/onnx/blob/b86cc54efce19530fb953e4b21f57e6b3888534c/onnx/onnx.proto#L544-L572 + // CoreMLTools conversion implementation that maps data types to fields + // https://github.com/apple/coremltools/blob/dbb0094fd0cb936469e35320bf37e866ef7a1da4/coremltools/converters/mil/backend/mil/helper.py#L98 + // along with some special cased types that are stored in bytes + // https://github.com/apple/coremltools/blob/dbb0094fd0cb936469e35320bf37e866ef7a1da4/coremltools/converters/mil/backend/mil/helper.py#L23 + // IMMEDIATE_VALUE_TYPES_IN_BYTES = (types.fp16, types.int8, types.uint8, types.uint32) + + switch (data_type) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: { + // from: float_data/raw, to: floats + if (has_raw_data) { + CopyRawDataToRepeatedField(tensor_proto, *tensor_value.mutable_floats()->mutable_values()); + } else { + tensor_value.mutable_floats()->mutable_values()->CopyFrom(tensor_proto.float_data()); + } + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: { + // from: double_data/raw, to: doubles + if (has_raw_data) { + CopyRawDataToRepeatedField(tensor_proto, *tensor_value.mutable_doubles()->mutable_values()); + } else { + tensor_value.mutable_doubles()->mutable_values()->CopyFrom(tensor_proto.double_data()); + } + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_INT32: { + // from: int32_data/raw, to: ints + if (has_raw_data) { + CopyRawDataToRepeatedField(tensor_proto, *tensor_value.mutable_ints()->mutable_values()); + } else { + tensor_value.mutable_ints()->mutable_values()->CopyFrom(tensor_proto.int32_data()); + } + break; + } + case 
ONNX_NAMESPACE::TensorProto_DataType_INT64: { + // enable when this is proven to not be the case + ORT_THROW( + "INT64 is unexpected as CoreML uses 32-bit int for indices. " + "Most likely an initializer that should have been skipped was not."); + //// from: int64_data/raw, to: longints + // if (has_raw_data) { + // CopyRawDataToRepeatedField(tensor_proto, *tensor_value.mutable_longints()->mutable_values()); + + //} else { + // tensor_value.mutable_longints()->mutable_values()->CopyFrom(tensor_proto.int64_data()); + //} + // break; + } + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: { + // from: int32_data/raw, to: bytes + if (has_raw_data) { + *tensor_value.mutable_bytes()->mutable_values() = tensor_proto.raw_data(); + } else { + // iterate the int32_data, taking the 16-bits from each entry, and copying to the bytes. + // we use uint16_t as only the size of the data type matters + CopyInt32DataToBytes(tensor_proto, tensor_value); + } + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_INT8: + case ONNX_NAMESPACE::TensorProto_DataType_UINT8: { + // from: int32_data/raw, to: bytes + if (has_raw_data) { + *tensor_value.mutable_bytes()->mutable_values() = tensor_proto.raw_data(); + } else { + // copy from int32_data to bytes. uint8_t for both as only the size of the data type matters when copying + CopyInt32DataToBytes(tensor_proto, tensor_value); + } + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_UINT32: { + // from: uint64_data/raw, to: bytes + if (has_raw_data) { + *tensor_value.mutable_bytes()->mutable_values() = tensor_proto.raw_data(); + } else { + // copy uint32_t values from TensorProto.uint64_data + CopyUInt64DataToBytes(tensor_proto, tensor_value); + } + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_UINT64: { + // enable when this is proven to not be the case + ORT_THROW( + "UINT64 is unexpected as CoreML uses 32-bit int for indices. 
" + "Most likely an initializer that should have been skipped was not."); + //// from: uint64_data/raw, to: longints + // if (has_raw_data) { + // CopyRawDataToRepeatedField(tensor_proto, *tensor_value.mutable_longints()->mutable_values()); + // } else { + // // TODO: Is this safe? Need to check the CopyFrom implementation. As it's a straight copy of bytes this + // // hopefully can do it as one block instead of iterating and potentially doing a static_cast of each + // // individual value. + // tensor_value.mutable_longints()->mutable_values()->CopyFrom( + // reinterpret_cast&>(tensor_proto.uint64_data())); + // } + + // break; + } + case ONNX_NAMESPACE::TensorProto_DataType_BOOL: { + // from: int32_data/raw, to: bools + if (has_raw_data) { + CopyRawDataToRepeatedField(tensor_proto, *tensor_value.mutable_bools()->mutable_values()); + } else { + const auto& int32s = tensor_proto.int32_data(); + auto& bools = *tensor_value.mutable_bools()->mutable_values(); + const int num_entries = int32s.size(); + bools.Reserve(num_entries); + const int32_t* in = int32s.data(); + for (int i = 0; i < num_entries; ++i) { + *bools.AddAlreadyReserved() = *in++; + } + } + + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_STRING: { + // from: string_data (which is protobuf type bytes), to: strings (protobuf type string) + // due to the protobuf type mismatch we need to iterate and copy + auto& in = tensor_proto.string_data(); + auto& out = *tensor_value.mutable_strings()->mutable_values(); + out.Reserve(in.size()); + for (const auto& iter : in) { + *out.Add() = iter; + } + + break; + } + /* Not clear if there's an actual use-case for 16-bit int data currently, so leaving commented out + case ONNX_NAMESPACE::TensorProto_DataType_INT16: + case ONNX_NAMESPACE::TensorProto_DataType_UINT16: { + // from: int32_data/raw, to: ints + // WARNING: This may change to write to mutable_bytes + // 
https://github.com/apple/coremltools/blob/dbb0094fd0cb936469e35320bf37e866ef7a1da4/coremltools/converters/mil/backend/mil/helper.py#L113-L115 + if (has_raw_data) { + CopyRawDataToRepeatedField(tensor_proto, *tensor_value.mutable_ints()->mutable_values()); + } else { + tensor_value.mutable_ints()->mutable_values()->CopyFrom(tensor_proto.int32_data()); + } + break; + } */ + default: + ORT_THROW("AddTensorProtoDataToMILSpecTensorValue: Unsupported data type: ", data_type); + } +} + +template +uint64_t WriteRawDataUsingStorageWriter(const onnx::TensorProto& tensor_proto, + MILBlob::Blob::StorageWriter& writer) { + MILBlob::Util::Span data(reinterpret_cast(tensor_proto.raw_data().data()), + tensor_proto.raw_data().size() / sizeof(T)); + return writer.WriteData(data); +} + +// Write T1 data from the TensorProto.int32_data field using StorageWriter. +// Currently int32_data can have any of these data types: +// INT32, INT16, INT8, UINT16, UINT8, BOOL, FLOAT16, BFLOAT16, +// FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ +// T1 provides the size of the ONNX data type. T2 is the CoreML type. +// The sizes and layout of T1 and T2 must match as we simply cast the bytes to T2. +template +uint64_t WriteFromInt32DataUsingStorageWriter(const onnx::TensorProto& tensor_proto, + MILBlob::Blob::StorageWriter& writer) { + static_assert(sizeof(T1) == sizeof(T2), "Data sizes must match"); + + // need to copy to temporary data as we have to extract a subset of bytes from each int32_t entry. + // works better to extract the ONNX type first with static_cast, and reinterpret_cast to the CoreML type at the end. 
+ std::vector values; + const int num_values = tensor_proto.int32_data_size(); + values.resize(num_values); // resize so we're not updating the length inside the copy loop + + const int32_t* in = tensor_proto.int32_data().data(); + for (int i = 0; i < num_values; ++i) { + values[i] = static_cast(in[i]); + } + + MILBlob::Util::Span data(reinterpret_cast(values.data()), + num_values); + return writer.WriteData(data); +} + +// write the initializer to weight.bin and return the offset +// StorageWriter is currently limited to fp32, fp16, bfloat16, uint8/int8, uint16/int16. +// AFAIK we don't use bfloat16/int16/uint16 for weights in ONNX, so limit handling to fp32, fp16, uint8/int8 +uint64_t CopyOnnxTensorToCoreMLWeightsFile(const onnx::TensorProto& tensor_proto, + MILBlob::Blob::StorageWriter& writer) { + bool has_raw_data = tensor_proto.has_raw_data(); + auto data_type = tensor_proto.data_type(); + + uint64_t offset = 0; + + // See AddTensorProtoDataToMILSpecTensorValue for links to sources for info on where the different typed data is + // stored for ONNX and CoreML + + switch (data_type) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: { + // from: float_data/raw, to: floats + if (has_raw_data) { + offset = WriteRawDataUsingStorageWriter(tensor_proto, writer); + } else { + MILBlob::Util::Span data(tensor_proto.float_data().data(), tensor_proto.float_data().size()); + offset = writer.WriteData(data); + } + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: { + // from: int32_data/raw, to: bytes + if (has_raw_data) { + offset = WriteRawDataUsingStorageWriter(tensor_proto, writer); + } else { + offset = WriteFromInt32DataUsingStorageWriter(tensor_proto, writer); + } + + break; + } + + case ONNX_NAMESPACE::TensorProto_DataType_INT8: { + // from: int32_data/raw, to: bytes + if (has_raw_data) { + offset = WriteRawDataUsingStorageWriter(tensor_proto, writer); + } else { + offset = WriteFromInt32DataUsingStorageWriter(tensor_proto, writer); + } + break; + } 
+ case ONNX_NAMESPACE::TensorProto_DataType_UINT8: { + // from: int32_data/raw, to: bytes + if (has_raw_data) { + offset = WriteRawDataUsingStorageWriter(tensor_proto, writer); + + } else { + offset = WriteFromInt32DataUsingStorageWriter(tensor_proto, writer); + } + break; + } + default: + ORT_THROW("AddWeightToFile: Unsupported data type: ", data_type); + } + + return offset; +} + +MILSpec::Value OnnxTensorToCoreMLTensor(const ONNX_NAMESPACE::TensorProto& tensor_proto, + MILBlob::Blob::StorageWriter& weights_file_writer) { + MILSpec::Value value; + + // populate ValueType with tensor data type, dims and rank + MILSpec::ValueType& value_type = *value.mutable_type(); + MILSpec::TensorType& tensor_type = *value_type.mutable_tensortype(); + tensor_type.set_datatype(OnnxDataTypeToMILSpec(tensor_proto.data_type())); + + tensor_type.set_rank(tensor_proto.dims().size()); + for (const auto& dim : tensor_proto.dims()) { + tensor_type.add_dimensions()->mutable_constant()->set_size(dim); + } + + // add data to either weights.bin or as an immediate value + if (ShouldWriteInitializerToWeightsFile(tensor_proto)) { + uint64_t offset = CopyOnnxTensorToCoreMLWeightsFile(tensor_proto, weights_file_writer); + + auto* file_value = value.mutable_blobfilevalue(); + // Filename copied from + // https://github.com/apple/coremltools/blob/dbb0094fd0cb936469e35320bf37e866ef7a1da4/coremltools/converters/mil/backend/mil/helper.py#L329 + file_value->set_filename("@model_path/weights/weight.bin"); + file_value->set_offset(offset); + } else { + MILSpec::TensorValue& tensor_value = *value.mutable_immediatevalue()->mutable_tensor(); + CopyOnnxTensorToCoreMLTensor(tensor_proto, tensor_value); + } + + return value; +} + +void CreateEmptyFile(const std::string& filename) { + std::ofstream file(filename, std::ofstream::out | std::ofstream::binary); + ORT_ENFORCE(file.is_open(), "Failed to open file ", filename); +} + +#endif // defined(COREML_ENABLE_MLPROGRAM) + +std::string GetModelOutputPath(bool 
create_ml_program) { + // path is used to create the ML Package directory for ML Program, and for the model directly otherwise. + auto path = util::GetTemporaryFilePath(); + if (!create_ml_program) { + path += ".model.mlmodel"; + } + + return path; +} +} // namespace + +ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger, + int32_t coreml_version, uint32_t coreml_flags, + std::vector&& onnx_input_names, + std::vector&& onnx_output_names) : graph_viewer_(graph_viewer), logger_(logger), - coreml_flags_(coreml_flags) { -} + coreml_version_(coreml_version), + coreml_flags_(coreml_flags), + create_ml_program_((coreml_flags_ & COREML_FLAG_CREATE_MLPROGRAM) != 0), + model_output_path_(GetModelOutputPath(create_ml_program_)), + onnx_input_names_(std::move(onnx_input_names)), + onnx_output_names_(std::move(onnx_output_names)), + coreml_model_(std::make_unique()) { + if (create_ml_program_) { +#if defined(COREML_ENABLE_MLPROGRAM) + coreml_model_->set_specificationversion(CoreMLSpecVersion()); + MILSpec::Program& mlprogram = *coreml_model_->mutable_mlprogram(); + mlprogram.set_version(1); + mlprogram_main_fn_ = &(*mlprogram.mutable_functions())["main"]; -Status ModelBuilder::Initialize() { - coreml_model_ = std::make_unique(); - { // initialize CoreML model + const std::string coreml_opset = "CoreML" + std::to_string(CoreMLVersion()); + *mlprogram_main_fn_->mutable_opset() = coreml_opset; + mlprogram_main_block_ = &(*mlprogram_main_fn_->mutable_block_specializations())[coreml_opset]; + + // create the ModelPackage. this creates the output directory. + mlpackage_ = std::make_unique(model_output_path_, /* create */ true); + + // ModelPackage::addItem does a copy of the file. Due to this we 'add' an empty file first, + // and do the actual writes to the file created in the package. + // We can't use ModelPackage::createFile as we have to add a directory for the weights. 
+ std::string tmp_dir = model_output_path_ + "/tmp"; + ORT_THROW_IF_ERROR(Env::Default().CreateFolder(ToPathString(tmp_dir))); + CreateEmptyFile(tmp_dir + "/weight.bin"); + + std::string weights_id = mlpackage_->addItem(tmp_dir, "weights", "com.microsoft.OnnxRuntime", + "CoreML Model Weights"); + auto weights_info = mlpackage_->findItem(weights_id); + weights_file_writer_ = std::make_unique(weights_info->path() + "/weight.bin"); +#else + // should never happen due to handling in coreml_execution_provider.cc + // throw here so all other code in this class can assume create_ml_program_ is only ever true in a build + // where ML Program support is enabled. + ORT_THROW("ML Program is not enabled in this build"); +#endif + } else { // We support CorelML Specification Version 4 (Core ML 3) coreml_model_->set_specificationversion(4); auto* neural_network = coreml_model_->mutable_neuralnetwork(); - neural_network->set_arrayinputshapemapping(::CoreML::Specification::NeuralNetworkMultiArrayShapeMapping::EXACT_ARRAY_MAPPING); + neural_network->set_arrayinputshapemapping( + CoreML::Specification::NeuralNetworkMultiArrayShapeMapping::EXACT_ARRAY_MAPPING); } - PreprocessInitializers(); - ORT_RETURN_IF_ERROR(RegisterInitializers()); - ORT_RETURN_IF_ERROR(RegisterModelInputs()); - ORT_RETURN_IF_ERROR(AddOperations()); - ORT_RETURN_IF_ERROR(RegisterModelOutputs()); + // populate names. + const auto& initializers = graph_viewer_.GetAllInitializedTensors(); + const auto& inputs = graph_viewer_.GetInputs(); + // rough guess to try and avoid reallocs. most nodes produce one output but some have more so allow for that. 
+ // also need to convert attributes to constants so allow for that + unique_names_.reserve(initializers.size() + inputs.size() + size_t(graph_viewer_.NumberOfNodes() * 1.5)); + for (const auto& pair : initializers) { + unique_names_.insert(pair.first); + } - return Status::OK(); + for (const auto* input : inputs) { + unique_names_.insert(input->Name()); + } + + for (const auto& node : graph_viewer_.Nodes()) { + for (const auto& def : node.OutputDefs()) { + if (def->Exists()) { + unique_names_.insert(def->Name()); + } + } + } } -/* static */ const IOpBuilder* ModelBuilder::GetOpBuilder(const Node& node) { - const auto& op_builders = GetOpBuilders(); - const auto it = op_builders.find(node.OpType()); - if (it != op_builders.cend()) - return it->second; +ModelBuilder::~ModelBuilder() = default; - return nullptr; +/* + * NeuralNetwork related helpers + */ +std::unique_ptr ModelBuilder::CreateNNLayer(const Node& node, std::string_view suffix) { + auto layer_name = GetUniqueName(node, suffix); + + std::unique_ptr layer = std::make_unique(); + layer->set_name(layer_name); + return layer; +} + +void ModelBuilder::AddLayer(std::unique_ptr layer) { + auto* neural_network = coreml_model_->mutable_neuralnetwork(); + neural_network->mutable_layers()->AddAllocated(layer.release()); } +/* + * ML Program related helpers + */ +#if defined(COREML_ENABLE_MLPROGRAM) +const std::string& ModelBuilder::GetSafeName(const std::string& name) { + // Check the name is valid according to the MILSpec rules + // `Identifiers, generally used for names and keys, must match the regular expression [A-Za-z\_][A-Za-z0-9\_@]*.` + // + // There is a secondary list of reserved words that the coremltools python uses, but it's not clear if those are + // required here, or if we will ever hit a model that uses one of them. 
Due to that, skip checking them for now as + // it adds cost and code complexity + // https://github.com/apple/coremltools/blob/8b37641f243b1a3e81452feea311c6e30dcc9287/coremltools/converters/mil/mil/passes/defs/preprocess.py#L151C1-L175C10 + // static InlinedHashSet reserved_names = + // {"any", "bool", "program", "func", "tensor", "list", "dict", "tuple", "true", "false", + // "string", "bf16", "fp16", "fp32", "fp64", "int8", "int16", "int32", "int64", + // "uint8", "uint16", "uint32", "uint64"}; + + // handle empty name. shouldn't happen but code below assumes name is not empty + if (name.empty()) { + return name; + } + + // We don't need '@' or '\' even though they're allowed. Optimize for a good name that does not need to be changed. + + // has been sanitized and changed already + const auto entry = values_to_rename_.find(name); + if (entry != values_to_rename_.end()) { + return entry->second; + } + + // Replace anything but a good char with '_'. If first char is 0-9 we prefix with '_'; + bool changed = false; + std::string result = name; + + if (std::isdigit(result[0])) { + changed = true; + result = '_' + name; + } + + for (char& c : result) { + if (!std::isalnum(c) && c != '_') { + changed = true; + c = '_'; + } + } + + if (!changed) { + return name; // return original as the return value is a reference that must remain valid + } + + return (values_to_rename_[name] = GetUniqueName(result)); +} + +void ModelBuilder::SanitizeNames() { + // ML Model level inputs/outputs + auto* desc = coreml_model_->mutable_description(); + for (auto& input : *desc->mutable_input()) { + input.set_name(GetSafeName(input.name())); + } + + for (auto& output : *desc->mutable_output()) { + output.set_name(GetSafeName(output.name())); + } + + // main function inputs/outputs. 
+ for (auto& input : *mlprogram_main_fn_->mutable_inputs()) { + input.set_name(GetSafeName(input.name())); + } + + // outputs from block with operations for current coreml version + for (auto& output : *mlprogram_main_block_->mutable_outputs()) { + output = GetSafeName(output); + } + + // iterate operations changing input/output/node names + for (auto& op : *mlprogram_main_block_->mutable_operations()) { + for (auto& input : *op.mutable_inputs()) { + for (auto& arg : *input.second.mutable_arguments()) { + arg.set_name(GetSafeName(arg.name())); + } + } + + for (auto& output : *op.mutable_outputs()) { + output.set_name(GetSafeName(output.name())); + } + } +} + +std::unique_ptr ModelBuilder::CreateOperation(const Node& node, + std::string_view op_type, + std::string_view suffix) { + std::string operation_name = GetUniqueName(node, suffix); + + std::unique_ptr op = std::make_unique(); + op->set_type(std::string(op_type)); + (*op->mutable_attributes())["name"] = CreateScalarTensorValue(operation_name); + + return op; +} + +const std::string& ModelBuilder::AddConstantOperation(std::string_view name, MILSpec::Value&& coreml_tensor) { + // Replicates coremltools/converters/mil/backend/mil/load.py translate_const logic + MILSpec::Operation& const_op = *mlprogram_main_block_->mutable_operations()->Add(); + const_op.set_type("const"); + + MILSpec::NamedValueType& output = *const_op.mutable_outputs()->Add(); + output.set_name(std::string(name)); + *output.mutable_type() = coreml_tensor.type(); + + auto& attr_map = *const_op.mutable_attributes(); + // the operation name doesn't really matter as it isn't used elsewhere, so sanitize name now + attr_map["name"] = CreateScalarTensorValue(GetSafeName(output.name())); + attr_map["val"] = std::move(coreml_tensor); + + return output.name(); +} + +// Add operation to the Block for the main function in the ML Program +void ModelBuilder::AddOperation(std::unique_ptr operation) { + 
mlprogram_main_block_->mutable_operations()->AddAllocated(operation.release()); +} + +const std::string& ModelBuilder::AddTensorValueAsConstantOperation(std::string_view op_type, + std::string_view value_type, + MILSpec::Value&& input_value) { + auto unique_value_name = GetUniqueName(MakeString(op_type, "_", value_type)); + return AddConstantOperation(unique_value_name, std::move(input_value)); +} + +template +std::string_view ModelBuilder::AddConstantImpl(std::string_view op_type, std::string_view value_type, + gsl::span value, + std::optional> shape) { + // add specialization below + static_assert(false_for_T, "Missing specialization for value type"); + + return "ModelBuilder::AddConstant error"; // unreachable +} + +template <> +std::string_view ModelBuilder::AddConstantImpl(std::string_view op_type, std::string_view value_type, + gsl::span value, + std::optional> shape) { + auto input_value = CreateTensorValue(value, shape); + return AddTensorValueAsConstantOperation(op_type, value_type, std::move(input_value)); +} + +template <> +std::string_view ModelBuilder::AddConstantImpl(std::string_view op_type, std::string_view value_type, + gsl::span value, + std::optional> shape) { + auto input_value = CreateTensorValue(value, shape); // CoreML uses int32 + return AddTensorValueAsConstantOperation(op_type, value_type, std::move(input_value)); +} + +template <> +std::string_view ModelBuilder::AddConstantImpl(std::string_view op_type, std::string_view value_type, + gsl::span value, + std::optional> shape) { + auto input_value = CreateTensorValue(value, shape); + return AddTensorValueAsConstantOperation(op_type, value_type, std::move(input_value)); +} + +template <> +std::string_view ModelBuilder::AddConstantImpl(std::string_view op_type, std::string_view value_type, + gsl::span value, + std::optional> shape) { + auto input_value = CreateTensorValue(value, shape); + return AddTensorValueAsConstantOperation(op_type, value_type, std::move(input_value)); +} + +#endif // 
defined(COREML_ENABLE_MLPROGRAM) + +/* + * General implementation + */ void ModelBuilder::PreprocessInitializers() { - // TODO: We should be using GetConstantInitializer not GetAllInitializedTensors in all places + // TODO: We should be using GetConstantInitializer not GetAllInitializedTensors in all places. + // non-constant initializers need to be passed in as model inputs in case they're overridden at runtime. const auto& initializers = graph_viewer_.GetAllInitializedTensors(); const auto& node_indices = graph_viewer_.GetNodesInTopologicalOrder(); @@ -64,6 +683,7 @@ void ModelBuilder::PreprocessInitializers() { initializer_usage_[input->Name()]++; } } + if (const auto* op_builder = GetOpBuilder(node)) { op_builder->AddInitializersToSkip(*this, node); } @@ -77,27 +697,36 @@ Status ModelBuilder::RegisterInitializers() { // skip initializer if there is no remaining usage auto usage_count = initializer_usage_[name]; - if (usage_count == 0) + if (usage_count == 0) { continue; + } - std::unique_ptr layer = std::make_unique(); - layer->set_name(GetUniqueName("initializer_" + name)); +#if defined(COREML_ENABLE_MLPROGRAM) + if (create_ml_program_) { + MILSpec::Value coreml_tensor = OnnxTensorToCoreMLTensor(tensor, *weights_file_writer_); + ORT_IGNORE_RETURN_VALUE(AddConstantOperation(name, std::move(coreml_tensor))); + } else +#endif + { + std::unique_ptr layer = std::make_unique(); + layer->set_name(GetUniqueName("initializer_" + name)); - // TODO,look at using LoadConstantLayer instead of LoadConstantNDLayer - auto* constant_tensor = layer->mutable_loadconstantnd(); - const auto& shape = tensor.dims(); - if (shape.empty()) { - // This is a scalar initializer, CoreML constant layer requires a shape, make this a {1} tensor - constant_tensor->mutable_shape()->Add(1); - } else { - std::transform(shape.cbegin(), shape.cend(), - google::protobuf::RepeatedFieldBackInserter(constant_tensor->mutable_shape()), - [](int64_t dim) -> uint64_t { return SafeInt(dim); }); - } + // 
TODO,look at using LoadConstantLayer instead of LoadConstantNDLayer + auto* constant_tensor = layer->mutable_loadconstantnd(); + const auto& shape = tensor.dims(); + if (shape.empty()) { + // This is a scalar initializer, CoreML constant layer requires a shape, make this a {1} tensor + constant_tensor->mutable_shape()->Add(1); + } else { + std::transform(shape.cbegin(), shape.cend(), + google::protobuf::RepeatedFieldBackInserter(constant_tensor->mutable_shape()), + [](int64_t dim) -> uint64_t { return SafeInt(dim); }); + } - ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*constant_tensor->mutable_data(), tensor)); - *layer->mutable_output()->Add() = name; - AddLayer(std::move(layer)); + ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*constant_tensor->mutable_data(), tensor)); + *layer->mutable_output()->Add() = name; + AddLayer(std::move(layer)); + } } return Status::OK(); @@ -109,32 +738,33 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i if (is_input) { // input should not be an initializer - if (Contains(GetInitializerTensors(), name)) + if (Contains(GetInitializerTensors(), name)) { return Status::OK(); + } // This input will not be used - if (Contains(skipped_inputs_, name)) + if (Contains(skipped_inputs_, name)) { return Status::OK(); + } } auto* model_description = coreml_model_->mutable_description(); - auto& input_output = is_input - ? *model_description->mutable_input()->Add() - : *model_description->mutable_output()->Add(); + auto& input_output = is_input ? 
*model_description->mutable_input()->Add() + : *model_description->mutable_output()->Add(); input_output.set_name(name); + auto* multi_array = input_output.mutable_type()->mutable_multiarraytype(); std::vector shape; - ORT_RETURN_IF_NOT(GetShape(node_arg, shape, logger_), - "Unable to get shape for ", input_output_type, ": ", name); + ORT_RETURN_IF_NOT(GetShape(node_arg, shape, logger_), "Unable to get shape for ", input_output_type, ": ", name); if (shape.empty()) { - // If we have an empty shape, this is a scalar input, - // Since all the input output of CoreML EP is MultiArray, we will make the scalar input output as a {1} MultiArray + // If we have an empty shape, this is a scalar + // Since all the input/output of CoreML EP is MultiArray, we will make the scalar input/output a {1} MultiArray shape.push_back(1); - // we need to change the shapes of these scalar outputs back to {} when CoreML EP returns these values to ORT + // we need to change the shapes of scalar outputs back to {} when CoreML EP returns values to ORT if (!is_input) { AddScalarOutput(name); } @@ -179,15 +809,15 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i data_type = type_proto->tensor_type().elem_type(); switch (data_type) { case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: - multi_array->set_datatype(COREML_SPEC::ArrayFeatureType::FLOAT32); + multi_array->set_datatype(ArrayFeatureType::FLOAT32); break; case ONNX_NAMESPACE::TensorProto_DataType_INT32: - multi_array->set_datatype(COREML_SPEC::ArrayFeatureType::INT32); + multi_array->set_datatype(ArrayFeatureType::INT32); break; case ONNX_NAMESPACE::TensorProto_DataType_INT64: // If we have an int64 input/output type, since COREML_SPEC:ArrayFeatureType does not support INT64 // we assign it to be INT32 here - multi_array->set_datatype(COREML_SPEC::ArrayFeatureType::INT32); + multi_array->set_datatype(ArrayFeatureType::INT32); if (!is_input) { // Record the output names and we need to change them back to 
Int64 when CoreML EP returns these values to ORT AddInt64Output(name); @@ -204,6 +834,26 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i input_output_info_.emplace(name, OnnxTensorInfo{data_type, shape}); +#if defined(COREML_ENABLE_MLPROGRAM) + if (create_ml_program_) { + if (is_input) { + // the model inputs need to be wired up as args to the 'main' function. + auto tensor_value_type = CreateNamedTensorValueType(node_arg); + tensor_value_type.set_name(name); + if (node_arg.Shape()->dim_size() == 0) { + // update shape from {} to {1} (same change we made at the model input level above). + tensor_value_type.mutable_type()->mutable_tensortype()->set_rank(1); + tensor_value_type.mutable_type()->mutable_tensortype()->add_dimensions()->mutable_constant()->set_size(1); + } + + mlprogram_main_fn_->mutable_inputs()->Add(std::move(tensor_value_type)); + } else { + // the model outputs need to be set as outputs of the Block for the 'main' function + *mlprogram_main_block_->mutable_outputs()->Add() = name; + } + } +#endif // defined(COREML_ENABLE_MLPROGRAM) + return Status::OK(); } @@ -215,16 +865,16 @@ Status ModelBuilder::RegisterModelInputs() { return Status::OK(); } -Status ModelBuilder::AddOperations() { - const auto builder_params = MakeOpBuilderParams(graph_viewer_, coreml_flags_); - const auto& node_indices = graph_viewer_.GetNodesInTopologicalOrder(); - for (size_t i = 0; i < node_indices.size(); i++) { - const auto* node(graph_viewer_.GetNode(node_indices[i])); - if (const auto* op_builder = GetOpBuilder(*node)) { - ORT_RETURN_IF_ERROR(op_builder->AddToModelBuilder(*this, *node, builder_params, logger_)); +Status ModelBuilder::ProcessNodes() { + for (const auto node_idx : graph_viewer_.GetNodesInTopologicalOrder()) { + const auto& node = *graph_viewer_.GetNode(node_idx); + if (const auto* op_builder = GetOpBuilder(node)) { + ORT_RETURN_IF_ERROR(op_builder->AddToModelBuilder(*this, node, logger_)); } else { + // This shouldn't 
happen as this is called from CoreMLExecutionProvider::Compile and should only be processing + // nodes that we said were supported and were returned from CoreMLExecutionProvider::GetCapability. return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Node [", node->Name(), "], type [", node->OpType(), "] is not supported"); + "Node [", node.Name(), "], type [", node.OpType(), "] was not able to be processed"); } } @@ -239,29 +889,121 @@ Status ModelBuilder::RegisterModelOutputs() { return Status::OK(); } -Status ModelBuilder::Compile(std::unique_ptr& model, const std::string& path) { - ORT_RETURN_IF_ERROR(SaveCoreMLModel(path)); - model.reset(new Model(path, logger_, coreml_flags_)); - model->SetScalarOutputs(std::move(scalar_outputs_)); - model->SetInt64Outputs(std::move(int64_outputs_)); - model->SetInputOutputInfo(std::move(input_output_info_)); - return model->LoadModel(); +Status ModelBuilder::CreateModel() { + PreprocessInitializers(); + + ORT_RETURN_IF_ERROR(RegisterInitializers()); + ORT_RETURN_IF_ERROR(RegisterModelInputs()); + ORT_RETURN_IF_ERROR(ProcessNodes()); + ORT_RETURN_IF_ERROR(RegisterModelOutputs()); + +#if defined(COREML_ENABLE_MLPROGRAM) + if (create_ml_program_) { + SanitizeNames(); + } +#endif + + return Status::OK(); } -Status ModelBuilder::SaveCoreMLModel(const std::string& path) { - ORT_RETURN_IF_ERROR(Initialize()); - std::ofstream stream(path, std::ofstream::out | std::ofstream::binary); - ORT_RETURN_IF_NOT(coreml_model_->SerializeToOstream(&stream), "Save the CoreML model failed"); +Status ModelBuilder::SaveModel() { + std::string output_path = model_output_path_; + +#if defined(COREML_ENABLE_MLPROGRAM) + if (create_ml_program_) { + std::string tmp_model_path = model_output_path_ + "/tmp/model.mlmodel"; + CreateEmptyFile(tmp_model_path); + + std::string model_id = mlpackage_->setRootModel(tmp_model_path, "model.mlmodel", "com.microsoft.OnnxRuntime", + "CoreML Model Specification"); + auto model_info = mlpackage_->findItem(model_id); + 
output_path = model_info->path(); + } +#endif - // TODO, Delete, debug only - if (const char* path = std::getenv("ORT_COREML_EP_CONVERTED_MODEL_PATH")) { - std::ofstream temp_stream(path, std::ofstream::out | std::ofstream::binary); - ORT_RETURN_IF_NOT(coreml_model_->SerializeToOstream(&temp_stream), "Save the CoreML model failed"); + // scope this so the stream is closed and flushed by the ofstream dtor + { + LOGS(logger_, INFO) << "Writing CoreML Model to " << output_path; + std::ofstream stream(output_path, std::ofstream::out | std::ofstream::binary); + ORT_RETURN_IF_NOT(coreml_model_->SerializeToOstream(&stream), "Saving the CoreML model failed. Path=", output_path); } +#if defined(COREML_ENABLE_MLPROGRAM) + // need to delete the ModelPackage instance for it to write out the manifest. clear out the other ML Program + // related types as well. + mlprogram_main_block_ = nullptr; + mlpackage_.reset(); + weights_file_writer_.reset(); +#endif + return Status::OK(); } +Status ModelBuilder::LoadModel(std::unique_ptr& model) { +#if defined(COREML_ENABLE_MLPROGRAM) + if (create_ml_program_) { + // we need to provide the sanitized names for model inputs/outputs so that info is captured. + // the input/output matching when we execute the model from the CoreML EP is based on order, so the change + // to the names doesn't matter for that. 
+ auto get_sanitized_names = [this](std::vector&& names) -> std::vector { + std::vector output(std::move(names)); + + for (std::string& name : output) { + name = GetSafeName(name); + } + + return output; + }; + + // also need to update the keys in input_output_info_ + auto get_sanitized_io_info = [this](std::unordered_map&& info) { + std::unordered_map output; + output.reserve(info.size()); + + for (auto entry = info.begin(), end = info.end(); entry != end; ++entry) { + output.emplace(GetSafeName(entry->first), std::move(entry->second)); + } + + return output; + }; + + model = std::make_unique(model_output_path_, + get_sanitized_names(std::move(onnx_input_names_)), + get_sanitized_names(std::move(onnx_output_names_)), + get_sanitized_io_info(std::move(input_output_info_)), + std::move(scalar_outputs_), + std::move(int64_outputs_), + logger_, coreml_flags_); + } else +#endif + { + model = std::make_unique(model_output_path_, + std::move(onnx_input_names_), + std::move(onnx_output_names_), + std::move(input_output_info_), + std::move(scalar_outputs_), + std::move(int64_outputs_), + logger_, coreml_flags_); + } + + return model->LoadModel(); // load using CoreML API, including compilation +} + +// static +Status ModelBuilder::Build(const GraphViewer& graph_viewer, const logging::Logger& logger, + int32_t coreml_version, uint32_t coreml_flags, + std::vector&& onnx_input_names, + std::vector&& onnx_output_names, + std::unique_ptr& model) { + ModelBuilder builder(graph_viewer, logger, coreml_version, coreml_flags, + std::move(onnx_input_names), std::move(onnx_output_names)); + + ORT_RETURN_IF_ERROR(builder.CreateModel()); + ORT_RETURN_IF_ERROR(builder.SaveModel()); + + return builder.LoadModel(model); +} + void ModelBuilder::AddScalarOutput(const std::string& output_name) { scalar_outputs_.insert(output_name); } @@ -270,11 +1012,6 @@ void ModelBuilder::AddInt64Output(const std::string& output_name) { int64_outputs_.insert(output_name); } -void 
ModelBuilder::AddLayer(std::unique_ptr layer) { - auto* neural_network = coreml_model_->mutable_neuralnetwork(); - neural_network->mutable_layers()->AddAllocated(layer.release()); -} - void ModelBuilder::AddInitializerToSkip(const std::string& tensor_name) { // decrement usage count if this is a known initializer. // For simplicity the OpBuilder::AddInitializersToSkip implementations may call this for arbitrary input names @@ -289,16 +1026,34 @@ void ModelBuilder::AddInputToSkip(const std::string& input_name) { skipped_inputs_.insert(input_name); } -std::string ModelBuilder::GetUniqueName(const std::string& base_name) { +const std::string& ModelBuilder::GetUniqueName(const std::string& base_name) { + if (unique_names_.find(base_name) == unique_names_.end()) { + return *unique_names_.insert(base_name).first; + } + std::string unique_name; - do { - std::ostringstream os; - os << base_name << "_token_" << name_token_++; - unique_name = os.str(); - } while (Contains(unique_names_, unique_name)); + std::string suffix; + + // supports up to 1000 unique names without having to grow in the loop + unique_name.reserve(base_name.size() + 5); + unique_name = base_name; + + while (Contains(unique_names_, unique_name)) { + // assign followed by += to avoid creating temporary strings. 
+ unique_name = base_name; + unique_name += "__"; + unique_name += std::to_string(name_token_++); + } - return unique_name; + return *unique_names_.insert(unique_name).first; } +const std::string& ModelBuilder::GetUniqueName(const Node& node, std::string_view suffix) { + if (node.Name().empty()) { + return GetUniqueName(MakeString(node.OpType(), "_", node.Index(), suffix)); + } else { + return GetUniqueName(node.Name() + std::string(suffix)); + } +} } // namespace coreml } // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/builders/model_builder.h b/onnxruntime/core/providers/coreml/builders/model_builder.h index af2d5437be8d1..8f85ab2c09e7c 100644 --- a/onnxruntime/core/providers/coreml/builders/model_builder.h +++ b/onnxruntime/core/providers/coreml/builders/model_builder.h @@ -3,57 +3,175 @@ #pragma once +#include "core/common/span_utils.h" #include "core/graph/graph_viewer.h" #include "core/providers/coreml/builders/coreml_spec.h" +#include "core/providers/coreml/model/model.h" + +#if defined(COREML_ENABLE_MLPROGRAM) +// coremltools classes +namespace MPL { +class ModelPackage; +} + +namespace MILBlob { +namespace Blob { +class StorageWriter; +} +} // namespace MILBlob +#endif namespace onnxruntime { namespace coreml { class IOpBuilder; -class Model; -struct OnnxTensorInfo; class ModelBuilder { + private: + ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger, + int32_t coreml_version, uint32_t coreml_flags, + std::vector&& onnx_input_names, + std::vector&& onnx_output_names); + public: - ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger, uint32_t coreml_flags); - ~ModelBuilder() = default; + // Create the CoreML model, serialize to disk, load and compile using the CoreML API and return in `model` + static Status Build(const GraphViewer& graph_viewer, const logging::Logger& logger, + int32_t coreml_version, uint32_t coreml_flags, + std::vector&& onnx_input_names, + std::vector&& 
onnx_output_names, + std::unique_ptr& model); - Status Compile(std::unique_ptr& model, const std::string& path); - Status SaveCoreMLModel(const std::string& path); + ~ModelBuilder(); - // Accessors for members const GraphViewer& GetGraphViewer() const { return graph_viewer_; } const InitializedTensorSet& GetInitializerTensors() const { return graph_viewer_.GetAllInitializedTensors(); } - + const ONNX_NAMESPACE::TensorProto* GetConstantInitializer(const std::string& name) const { + return graph_viewer_.GetConstantInitializer(name, true); + } + + // Since CoreML 2 the spec version is +1 as CoreML 1.1 was spec version 2. + // We only support CoreML 3 and later so the spec version is always version + 1. + int32_t CoreMLVersion() const { return coreml_version_; } + int32_t CoreMLSpecVersion() const { return coreml_version_ + 1; } + + // Returns true if we are creating an ML Program + bool CreateMLProgram() const { +#if defined(COREML_ENABLE_MLPROGRAM) + return create_ml_program_; +#else + return false; +#endif + } + + /* + * NeuralNetworkLayer helpers + */ + + // Create a NeuralNetwork layer using the node name and optional suffix for the name. + // If Node has no name a unique name will be generated from the node index and operator. + std::unique_ptr CreateNNLayer(const Node& node, std::string_view suffix = ""); + + // Add layer to the Core ML NeuralNetwork model void AddLayer(std::unique_ptr layer); - // The initializer will be processed separately, skip it as an initializer +#if defined(COREML_ENABLE_MLPROGRAM) + /* + * MLProgram helpers + */ + + // Create Operation, set type and the unique name attribute. + std::unique_ptr CreateOperation(const Node& node, std::string_view op_type, + std::string_view suffix = ""); + + // + // Helpers for adding attributes from ONNX nodes as inputs to an ML Program Operation + // + + /// + /// Add a value as a 'const' operation, generating a unique name for the value from op_type and value_type. 
+ /// Use for values that were not initializers in the original ONNX model. e.g. attributes from ONNX nodes. + /// Add existing initializers using AddConstant with the TensorProto. + /// + /// e.g. adding the bias input of Gemm would have op_type='gemm' and value_type='bias'. + /// + /// Value type. + /// Typically MILSpec::Operation.type(). + /// Typically the input name of the operation that will consume the value. + /// Value to add. + /// Optional shape for the value. + /// If T is a primitive type `shape` is ignored and the value is treated as a scalar. + /// For a container type, if `shape` is not provided the shape is inferred to be 1-D of {value.size()}. + /// + /// Unique name generated for value. + template + std::string_view AddConstant(std::string_view op_type, std::string_view value_type, gsl::span value, + std::optional> shape = std::nullopt) { + static_assert(std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v, + // add specialization in AddConstantImpl for new types if needed + "AddConstant currently supports float, int64_t, std::string and bool."); + return AddConstantImpl(op_type, value_type, value, shape); + } + + template + std::string_view AddConstant(std::string_view op_type, std::string_view value_type, const std::vector& value, + std::optional> shape = std::nullopt) { + return AddConstant(op_type, value_type, AsSpan(value), shape); + } + + /// + /// Add a scalar value as a 'const' operation. See AddConstant for details. + /// + template + std::string_view AddScalarConstant(std::string_view op_type, std::string_view value_type, const T& value) { + return AddConstant(op_type, value_type, AsSpan({value}), AsSpan({})); + } + + // add the operation to the main function + void AddOperation(std::unique_ptr operation); +#endif + + /* + * General helpers + */ + + // The initializer is processed separately (e.g. layout is transformed) by the operator builder, + // so we don't do a copy of the original initializer into the model. 
void AddInitializerToSkip(const std::string& tensor_name); // There are some input which will not be used, add it to a list which will not // be added to CoreML model, since CoreML does not like input unused void AddInputToSkip(const std::string& input_name); - std::string GetUniqueName(const std::string& base_name); - - private: - const GraphViewer& graph_viewer_; - const logging::Logger& logger_; - uint32_t coreml_flags_; - - std::unique_ptr coreml_model_; - std::unordered_set scalar_outputs_; - std::unordered_set int64_outputs_; - std::unordered_map input_output_info_; + const std::string& GetUniqueName(const std::string& base_name); + const std::string& GetUniqueName(const Node& node, std::string_view suffix); - std::unordered_map initializer_usage_; - std::unordered_set skipped_inputs_; + const logging::Logger& Logger() const { return logger_; } - uint32_t name_token_{0}; - std::unordered_set unique_names_; - - // Convert the onnx model to CoreML::Specification::Model - Status Initialize(); + private: +#if defined(COREML_ENABLE_MLPROGRAM) + template + std::string_view AddConstantImpl(std::string_view op_type, std::string_view value_type, gsl::span value, + std::optional> shape = std::nullopt); + + // apply the CoreML naming rules and fix any invalid names. + const std::string& GetSafeName(const std::string& name); + // sanitize all the names in the ML Model + void SanitizeNames(); + + // add Value as a const operation. return value name in case sanitization changed it + const std::string& AddConstantOperation(std::string_view name, COREML_SPEC::MILSpec::Value&& initializer); + const std::string& AddTensorValueAsConstantOperation(std::string_view op_type, std::string_view value_type, + COREML_SPEC::MILSpec::Value&& input_value); +#endif + + // Convert the ONNX model in graph_viewer_ to a CoreML::Specification::Model and serialize to disk. + // We then load it using CoreML in order compile it. 
+ Status CreateModel(); + Status SaveModel(); + Status LoadModel(std::unique_ptr& model); // If a CoreML operation will use initializers directly, we will add the initializers to the skip list void PreprocessInitializers(); @@ -61,7 +179,7 @@ class ModelBuilder { // Copy and process all the initializers to CoreML model Status RegisterInitializers(); - Status AddOperations(); + Status ProcessNodes(); Status RegisterModelInputs(); Status RegisterModelOutputs(); Status RegisterModelInputOutput(const NodeArg& node_arg, bool is_input); @@ -72,7 +190,45 @@ class ModelBuilder { // Record the onnx int64 type output names void AddInt64Output(const std::string& output_name); - static const IOpBuilder* GetOpBuilder(const Node& node); + const GraphViewer& graph_viewer_; + const logging::Logger& logger_; + const int32_t coreml_version_; + const uint32_t coreml_flags_; + const bool create_ml_program_; // ML Program (CoreML5, iOS 15+, macOS 12+) or NeuralNetwork (old) + const std::string model_output_path_; // create_ml_program_ ? dir for mlpackage : filename for mlmodel + + std::vector onnx_input_names_; + std::vector onnx_output_names_; + + std::unique_ptr coreml_model_; + std::unordered_set scalar_outputs_; + std::unordered_set int64_outputs_; + std::unordered_map input_output_info_; + + std::unordered_map initializer_usage_; + std::unordered_set skipped_inputs_; + + uint32_t name_token_{0}; + std::unordered_set unique_names_; + +#if defined(COREML_ENABLE_MLPROGRAM) + // mlprogram_main_ is the main block of the CoreML ML Program. + // It is set in CreateModel to the CoreML Model.mlprogram.functions['main'].block_specializations['CoreML'] + // entry we create. 
+ COREML_SPEC::MILSpec::Function* mlprogram_main_fn_{nullptr}; // Function that contains a Block with the operations + COREML_SPEC::MILSpec::Block* mlprogram_main_block_{nullptr}; // Block that all the operations are added to + std::unique_ptr mlpackage_; + std::unique_ptr weights_file_writer_; + + // Values must start with [a-zA-A_] + // Additionally they can't be in a list of reserved words. + // If we need to sanitize an initializer name we do so during PreprocessInitializers and apply the change during + // RegisterInitializers. + // We also check inputs in AddOperation and apply the change there. + // This means an op builder author doesn't need to be aware of the renaming. + // https://github.com/apple/coremltools/blob/8b37641f243b1a3e81452feea311c6e30dcc9287/coremltools/converters/mil/mil/passes/defs/preprocess.py#L146-L149 + std::unordered_map values_to_rename_; +#endif }; } // namespace coreml diff --git a/onnxruntime/core/providers/coreml/builders/op_builder.h b/onnxruntime/core/providers/coreml/builders/op_builder.h index 79de6438c9700..0bb7f280c33e6 100644 --- a/onnxruntime/core/providers/coreml/builders/op_builder.h +++ b/onnxruntime/core/providers/coreml/builders/op_builder.h @@ -11,36 +11,39 @@ namespace coreml { class ModelBuilder; struct OpBuilderInputParams { - OpBuilderInputParams(const GraphViewer& graph_viewer, bool only_allow_static_input_shapes) + OpBuilderInputParams(const GraphViewer& graph_viewer, + int32_t coreml_version, + bool only_allow_static_input_shapes, + bool create_mlprogram) : graph_viewer(graph_viewer), - only_allow_static_input_shapes(only_allow_static_input_shapes) {} + coreml_version(coreml_version), + only_allow_static_input_shapes(only_allow_static_input_shapes), + create_mlprogram(create_mlprogram) {} const GraphViewer& graph_viewer; + const int32_t coreml_version; // required to determine which version of an operation can be used. 
const bool only_allow_static_input_shapes; + const bool create_mlprogram; // whether to create ML Program (Core ML 5+) or NeuralNetwork (Core ML 3+) }; class IOpBuilder { public: virtual ~IOpBuilder() = default; - // Add operator related -#ifdef __APPLE__ - public: // Check if the initializers of this operator need preprocess // which will not be copied virtual void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const = 0; // Add the operator to CoreML model virtual Status AddToModelBuilder(ModelBuilder& model_builder, const Node& node, - const OpBuilderInputParams& input_params, const logging::Logger& logger) const = 0; -#endif - // Operator support related - public: // Check if an operator is supported virtual bool IsOpSupported(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const = 0; + + // Does the builder implementation support creating an ML Program? + virtual bool SupportsMLProgram() const = 0; }; } // namespace coreml diff --git a/onnxruntime/core/providers/coreml/builders/op_builder_factory.h b/onnxruntime/core/providers/coreml/builders/op_builder_factory.h index d72420bcfff88..6469b4cefa5ea 100644 --- a/onnxruntime/core/providers/coreml/builders/op_builder_factory.h +++ b/onnxruntime/core/providers/coreml/builders/op_builder_factory.h @@ -3,7 +3,7 @@ #pragma once -#include "op_builder.h" +#include "core/providers/coreml/builders/op_builder.h" namespace onnxruntime { namespace coreml { diff --git a/onnxruntime/core/providers/coreml/coreml_execution_provider.cc b/onnxruntime/core/providers/coreml/coreml_execution_provider.cc index c133f7b82aba4..0ba715cc7c6d9 100644 --- a/onnxruntime/core/providers/coreml/coreml_execution_provider.cc +++ b/onnxruntime/core/providers/coreml/coreml_execution_provider.cc @@ -2,9 +2,11 @@ // Licensed under the MIT License. 
#include "core/providers/coreml/coreml_execution_provider.h" +#include "core/providers/coreml/coreml_provider_factory.h" // defines flags #include +#include "core/common/logging/logging.h" #include "core/framework/compute_capability.h" #include "core/framework/tensorprotoutils.h" #include "core/graph/graph_viewer.h" @@ -12,12 +14,10 @@ #include "core/providers/partitioning_utils.h" #include "core/session/onnxruntime_cxx_api.h" -#ifdef __APPLE__ #include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/model/host_utils.h" #include "core/providers/coreml/model/model.h" #include "core/providers/coreml/shape_utils.h" -#endif namespace onnxruntime { @@ -25,7 +25,24 @@ constexpr const char* COREML = "CoreML"; CoreMLExecutionProvider::CoreMLExecutionProvider(uint32_t coreml_flags) : IExecutionProvider{onnxruntime::kCoreMLExecutionProvider}, - coreml_flags_(coreml_flags) { + coreml_flags_(coreml_flags), + coreml_version_(coreml::util::CoreMLVersion()) { + if (coreml_version_ < MINIMUM_COREML_VERSION) { + LOGS_DEFAULT(ERROR) << "CoreML EP is not supported on this platform."; + } + +#if defined(COREML_ENABLE_MLPROGRAM) + if (coreml_version_ < MINIMUM_COREML_MLPROGRAM_VERSION && + (coreml_flags_ & COREML_FLAG_CREATE_MLPROGRAM) != 0) { + LOGS_DEFAULT(WARNING) << "ML Program is not supported on this OS version. Falling back to NeuralNetwork."; + coreml_flags_ ^= COREML_FLAG_CREATE_MLPROGRAM; + } +#else + if ((coreml_flags_ & COREML_FLAG_CREATE_MLPROGRAM) != 0) { + LOGS_DEFAULT(WARNING) << "ML Program is not supported in this build. 
Falling back to NeuralNetwork."; + coreml_flags_ ^= COREML_FLAG_CREATE_MLPROGRAM; + } +#endif } CoreMLExecutionProvider::~CoreMLExecutionProvider() {} @@ -35,28 +52,34 @@ CoreMLExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie const IKernelLookup& /*kernel_lookup*/) const { std::vector> result; - // We do not run CoreML EP on subgraph, instead we cover this in the control flow nodes - // TODO investigate whether we want to support subgraph using CoreML EP - if (graph_viewer.IsSubgraph() && !(coreml_flags_ & COREML_FLAG_ENABLE_ON_SUBGRAPH)) { + if (coreml_version_ < MINIMUM_COREML_VERSION) { return result; } const auto& logger = *GetLogger(); + // We do not run CoreML EP on subgraph, instead we cover this in the control flow nodes + // TODO investigate whether we want to support subgraph using CoreML EP. May simply require processing the + // implicit inputs of the control flow node that contains the subgraph as inputs to the CoreML model we generate. + if (graph_viewer.IsSubgraph() && !(coreml_flags_ & COREML_FLAG_ENABLE_ON_SUBGRAPH)) { + return result; + } + const bool has_neural_engine = coreml::HasNeuralEngine(logger); if ((coreml_flags_ & COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE) && !has_neural_engine) { - LOGS(logger, VERBOSE) << "The current system does not have Apple Neural Engine"; + LOGS(logger, WARNING) << "The current system does not have Apple Neural Engine. 
CoreML EP will not be used."; return result; } - const auto builder_params = coreml::MakeOpBuilderParams(graph_viewer, coreml_flags_); + const auto builder_params = coreml::MakeOpBuilderParams(graph_viewer, coreml_version_, coreml_flags_); const auto supported_nodes = coreml::GetSupportedNodes(graph_viewer, builder_params, logger); - const auto gen_metadef_name = [&]() { - HashValue model_hash; - int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash); - return MakeString(COREML, "_", model_hash, "_", metadef_id); - }; + const auto gen_metadef_name = + [&]() { + HashValue model_hash; + int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash); + return MakeString(COREML, "_", model_hash, "_", metadef_id); + }; result = utils::CreateSupportedPartitions(graph_viewer, supported_nodes, {}, gen_metadef_name, COREML, kCoreMLExecutionProvider); @@ -86,34 +109,32 @@ CoreMLExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie return result; } -#ifdef __APPLE__ +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) common::Status CoreMLExecutionProvider::Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) { for (const auto& fused_node_and_graph : fused_nodes_and_graphs) { Node& fused_node = fused_node_and_graph.fused_node; - const onnxruntime::GraphViewer& graph_viewer(fused_node_and_graph.filtered_graph); - coreml::ModelBuilder builder(graph_viewer, *GetLogger(), coreml_flags_); std::unique_ptr coreml_model; - const std::string coreml_model_file_path = coreml::util::GetTemporaryFilePath(); - ORT_RETURN_IF_ERROR(builder.Compile(coreml_model, coreml_model_file_path)); - { - const auto& input_defs = fused_node.InputDefs(); - std::vector onnx_input_names(input_defs.size()); - for (size_t i = 0, end = input_defs.size(); i < end; ++i) { - onnx_input_names[i] = input_defs[i]->Name(); - } - coreml_model->SetOnnxInputs(std::move(onnx_input_names)); - } + auto get_names = 
[](const ConstPointerContainer>& args) -> std::vector { + std::vector names; + names.reserve(args.size()); - { - const auto& output_defs = fused_node.OutputDefs(); - std::vector onnx_output_names(output_defs.size()); - for (size_t i = 0, end = output_defs.size(); i < end; ++i) { - onnx_output_names[i] = output_defs[i]->Name(); - } - coreml_model->SetOnnxOutputs(std::move(onnx_output_names)); + for (const NodeArg* def : args) { + names.push_back(def->Name()); + } + + return names; + }; + + std::vector onnx_input_names = get_names(fused_node.InputDefs()); + std::vector onnx_output_names = get_names(fused_node.OutputDefs()); + + const onnxruntime::GraphViewer& graph_viewer(fused_node_and_graph.filtered_graph); + ORT_RETURN_IF_ERROR(coreml::ModelBuilder::Build(graph_viewer, *GetLogger(), coreml_version_, coreml_flags_, + std::move(onnx_input_names), std::move(onnx_output_names), + coreml_model)); } coreml_models_.emplace(fused_node.Name(), std::move(coreml_model)); @@ -131,13 +152,14 @@ common::Status CoreMLExecutionProvider::Compile(const std::vector(state); - const auto& model_inputs = model->GetOnnxInputs(); - const auto& model_outputs = model->GetOnnxOutputs(); + + // input/output names used by the CoreML model in the order that matches the fused_node InputDefs/OutputDefs + const auto& model_inputs = model->GetOrderedInputs(); + const auto& model_outputs = model->GetOrderedOutputs(); ORT_RETURN_IF_NOT(model_inputs.size() <= num_inputs, "Inconsistent input sizes"); ORT_RETURN_IF_NOT(model_outputs.size() == num_outputs, "Inconsistent output sizes"); @@ -160,28 +182,25 @@ common::Status CoreMLExecutionProvider::Compile(const std::vectorshape; - ORT_RETURN_IF(!coreml::IsStaticShape(inferred_shape) && coreml::DoesShapeSpecifyZeroElements(shape), - "Input (", input_name, ") has a dynamic shape (", coreml::Shape2String(inferred_shape), - ") but the runtime shape (", coreml::Shape2String(shape), - ") has zero elements. 
This is not supported by the CoreML EP."); - } + const auto& inferred_shape = input_info->shape; + ORT_RETURN_IF(!coreml::IsStaticShape(inferred_shape) && coreml::DoesShapeSpecifyZeroElements(shape), + "Input (", input_name, ") has a dynamic shape (", coreml::Shape2String(inferred_shape), + ") but the runtime shape (", coreml::Shape2String(shape), + ") has zero elements. This is not supported by the CoreML EP."); // If we have an empty shape, this is a scalar input, // Since all the input output of CoreML EP is MultiArray, we will make the scalar input as a {1} MultiArray - if (shape.empty()) + if (shape.empty()) { shape.push_back(1); + } // CoreML MLMultiArray API expect input to be non-const // https://developer.apple.com/documentation/coreml/mlmultiarray/2881219-initwithdatapointer?language=objc void* inputBuffer = const_cast(input_tensor.GetTensorRawData()); - inputs.emplace( - input_name, - coreml::OnnxTensorData{ - coreml::OnnxTensorInfo{tensor_info.GetElementType(), shape}, - inputBuffer, - }); + inputs.emplace(input_name, coreml::OnnxTensorData{ + coreml::OnnxTensorInfo{tensor_info.GetElementType(), shape}, + inputBuffer, + }); } // From this point we will need to take the exclusive lock on the model until the Predict is @@ -193,14 +212,13 @@ common::Status CoreMLExecutionProvider::Compile(const std::vector static_shape) -> void* { + [&ctx, &model_outputs](const std::string& name, + int32_t requested_onnx_tensor_element_type, + gsl::span static_shape) -> void* { const auto model_output_it = std::find(model_outputs.begin(), model_outputs.end(), name); ORT_ENFORCE(model_output_it != model_outputs.end(), "Failed to find CoreML model output name: ", name); - const auto output_idx = gsl::narrow_cast(std::distance(model_outputs.begin(), model_output_it)); + const auto output_idx = gsl::narrow_cast(std::distance(model_outputs.begin(), model_output_it)); auto output_tensor = ctx.GetOutput(output_idx, static_shape.data(), static_shape.size()); const auto 
type_and_shape_info = output_tensor.GetTensorTypeAndShapeInfo(); @@ -221,13 +239,15 @@ common::Status CoreMLExecutionProvider::Compile(const std::vectorIsScalarOutput(output_name)) + if (model->IsScalarOutput(output_name)) { output_shape.clear(); + } // Since CoreML EP only accepts int32 output type and onnx requires int64 output, // We are going to set the model output (from int32) ->int64 - if (model->IsInt64Output(output_name)) + if (model->IsInt64Output(output_name)) { output_type = ONNX_NAMESPACE::TensorProto_DataType_INT64; + } outputs.emplace(output_name, coreml::OnnxTensorInfo{output_type, output_shape}); } @@ -241,22 +261,6 @@ common::Status CoreMLExecutionProvider::Compile(const std::vector& fused_nodes_and_graphs, - std::vector& node_compute_funcs) { - for (const auto& fused_node_and_graph : fused_nodes_and_graphs) { - ORT_UNUSED_PARAMETER(fused_node_and_graph); - NodeComputeInfo compute_info; - compute_info.create_state_func = [](ComputeContext* /*context*/, FunctionState* /*state*/) { return 0; }; - compute_info.release_state_func = [](FunctionState /*state*/) {}; - compute_info.compute_func = [](FunctionState /* state */, const OrtApi* /* api */, - OrtKernelContext* /* context */) { - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Compute is not supported in this build."); - }; - node_compute_funcs.push_back(compute_info); - } - return Status::OK(); -} -#endif //__APPLE__ +#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) } // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/coreml_execution_provider.h b/onnxruntime/core/providers/coreml/coreml_execution_provider.h index 0201739547dd1..24a001280eef5 100644 --- a/onnxruntime/core/providers/coreml/coreml_execution_provider.h +++ b/onnxruntime/core/providers/coreml/coreml_execution_provider.h @@ -3,9 +3,9 @@ #pragma once +#include "core/common/inlined_containers.h" #include "core/framework/execution_provider.h" #include 
"core/framework/model_metadef_id_generator.h" -#include "core/providers/coreml/coreml_provider_factory.h" namespace onnxruntime { namespace coreml { @@ -26,15 +26,14 @@ class CoreMLExecutionProvider : public IExecutionProvider { std::vector& node_compute_funcs) override; #endif + private: // The bit flags which define bool options for COREML EP, bits are defined as // COREMLFlags in include/onnxruntime/core/providers/coreml/coreml_provider_factory.h - const uint32_t coreml_flags_; - - private: -// > -#ifdef __APPLE__ - std::unordered_map> coreml_models_; -#endif + uint32_t coreml_flags_; + const int32_t coreml_version_; ModelMetadefIdGenerator metadef_id_generator_; + + // map of fused_node_name to compiled_coreml_model + InlinedHashMap> coreml_models_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/dump_mlprogram_model.py b/onnxruntime/core/providers/coreml/dump_mlprogram_model.py new file mode 100644 index 0000000000000..a3ceee70684dc --- /dev/null +++ b/onnxruntime/core/providers/coreml/dump_mlprogram_model.py @@ -0,0 +1,27 @@ +import sys + +import coremltools as ct + +if len(sys.argv) < 2: + print(f"Usage: {sys.argv[0]} ") + print("If generated by onnxruntime this will be /Data/com.microsoft.onnxruntime/model.mlmodel") + sys.exit(-1) + +model_path = sys.argv[1] +m = ct.models.MLModel(model_path) + +spec = m.get_spec() +print(spec) + +# Example code if you want to filter output or do more advanced things +# main = spec.mlProgram.functions["main"] +# block = main.block_specializations[main.opset] +# print(f"{len(block.operations)} operators") +# for op in block.operations: +# if op.type == 'const': +# if op.attributes["name"].immediateValue.tensor.strings.values[0] == "conv_0_pad_type_0": +# print(f"Conv pad_type={op.attributes['val'].immediateValue.tensor.strings.values}") +# +# if op.type == 'conv': +# #print(op) +# pass diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/ArrayFeatureExtractor.proto 
b/onnxruntime/core/providers/coreml/mlmodel_format/ArrayFeatureExtractor.proto deleted file mode 100644 index 2b83ccbe3574f..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/ArrayFeatureExtractor.proto +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright (c) 2017, Apple Inc. All rights reserved. -// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -package CoreML.Specification; - -/** - * An array feature extractor. - * - * Given an index, extracts the value at that index from its array input. - * Indexes are zero-based. - */ -message ArrayFeatureExtractor { - repeated uint64 extractIndex = 1; -} diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/BayesianProbitRegressor.proto b/onnxruntime/core/providers/coreml/mlmodel_format/BayesianProbitRegressor.proto deleted file mode 100644 index 9688d87ce48ba..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/BayesianProbitRegressor.proto +++ /dev/null @@ -1,139 +0,0 @@ -// Copyright (c) 2017, Apple Inc. All rights reserved. -// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -package CoreML.Specification; - -/** -* A Bayesian probit regressor. -* -* The probit regression model is superficially similar to the more commonly known -* logistic regression, with sampling distribution of the model given by -* -* P(y=+1|x,w) = Φ(/β) -* -* where w are the set of weights, -* x are the set of features for the given event, -* β is a model hyper-parameter, and -* Φ is the link function, defined to be the CDF of the normal distribution. 
-* The weights w[i,j] are Gaussian distributed, with mean μ[i,j] and precision 1/(σ[i,j])^2 -* (where i indexes over features and j indexes over the values for the feature). -* The parameter β scales the steepness of the inverse link function. -* -* (see https://en.wikipedia.org/wiki/Probit_model and https://en.wikipedia.org/wiki/Logistic_regression -* for more details on probit model and logistic regression, respectively) -* -* Input: X -* x represents a set of features, each taking on a discrete value (note that continuous values -* would first need to be discretized). x can be represented as a vector where the index i is -* the feature id and x[i] is the feature value. Alternatively, x can be represented as a matrix -* with 2 columns where the first column indicates the feature id and the second column contains -* the feature values, i.e. x[i,0] is the feature id and x[i,1] is the feature value. -* -* additional input features: -* - "optimism": apply a mean shift to the probability, i.e. shift regression mean by o*stdev, -* where o is the "optimism" parameter (see additional output features) -* - "samplingScale": for sampling from posterior, multiply standard deviation by this factor -* - "samplingTruncation": for sampling from posterior, truncate sampling distribution at given multiple of std from mean -* -* Output: Y -* probability P(y|x,w) -* -* additional output features: -* - mean (regression output before applying link function) -* - variance (regression output variance before applying link function) -* - pessimistic probability: P(y|x,w) with a mean shift parameterized by "optimism" feature -* - sampled probability: p ~ P(y|x,w) with standard deviation scaling parametrized by "samplingScale" feature -* and distribution truncated at multiple of standard deviation, -* where multiple parameterized by "samplingTruncation" feature. 
-* -*/ - -message BayesianProbitRegressor { - - /* - * Parameterization of a Gaussian distribution - */ - message Gaussian { - double mean = 1; - double precision = 2; // inverse of the variance - } - - /* - * Weight for a specific feature value - * The weight is represented as a Gaussian distribution - * with a mean and precision (1/variance) to capture - * uncertainty in the weight - */ - message FeatureValueWeight { - uint32 featureValue = 1; - Gaussian featureWeight = 2; - } - - /* - * Feature with associated weights (for different values) - * Each feature has a set of weights for the (discrete) values - * it can take - */ - message FeatureWeight { - uint32 featureId = 1; - repeated FeatureValueWeight weights = 2; - } - - uint32 numberOfFeatures = 1; - - Gaussian bias = 2; // bias term - - /* - * Set of features with associated weights - */ - repeated FeatureWeight features = 3; // feature weights - - /* - * Set this name to be the same as input feature of type multi-array (1D) - * in the model description you want to use as the regression input - */ - string regressionInputFeatureName = 10; - - /* - * Set this name to be the same as optional input feature of type double - * in the model description you want to use as the optimism input - */ - string optimismInputFeatureName = 11; - - /* - * Set this name to be the same as optional input feature of type double - * in the model description you want to use as the samplingScale input - */ - string samplingScaleInputFeatureName = 12; - - /* - * Set this name to be the same as optional input feature of type double - * in the model description you want to use as the samplingBounds input - */ - string samplingTruncationInputFeatureName = 13; - - /* - * name of 'mean' output feature - */ - string meanOutputFeatureName = 20; - - /* - * name of 'variance' output feature - */ - string varianceOutputFeatureName = 21; - - /* - * name of 'pessimistic' output feature - */ - string pessimisticProbabilityOutputFeatureName = 22; 
- - /* - * name of 'sampled' output feature: samples from the scaled posterior probability distribuiton - */ - string sampledProbabilityOutputFeatureName = 23; -} diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/CategoricalMapping.proto b/onnxruntime/core/providers/coreml/mlmodel_format/CategoricalMapping.proto deleted file mode 100644 index 23112d074213a..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/CategoricalMapping.proto +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) 2017, Apple Inc. All rights reserved. -// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -import public "DataStructures.proto"; - -package CoreML.Specification; - -/** - * A categorical mapping. - * - * This allows conversion from integers to strings, or from strings to integers. - */ -message CategoricalMapping { - oneof MappingType { - // Conversion from strings to integers - StringToInt64Map stringToInt64Map = 1; - - // Conversion from integer to string - Int64ToStringMap int64ToStringMap = 2; - } - - /** - * The value returned if an input is not contained in the map above. - * If one of these is not set, then an error is raised on an unknown input. - */ - oneof ValueOnUnknown { - // Default output when converting from an integer to a string. - string strValue = 101; - - // Default output when converting from a string to an integer. - int64 int64Value = 102; - } -} diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/CustomModel.proto b/onnxruntime/core/providers/coreml/mlmodel_format/CustomModel.proto deleted file mode 100644 index 9a6d36e009ada..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/CustomModel.proto +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (c) 2017, Apple Inc. All rights reserved. 
-// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -package CoreML.Specification; - -/** -* A parameterized model whose function is defined in code -*/ -message CustomModel { - - message CustomModelParamValue { - oneof value { - double doubleValue = 10; - string stringValue = 20; - int32 intValue = 30; - int64 longValue = 40; - bool boolValue = 50; - bytes bytesValue = 60; - } - } - - string className = 10; // The name of the class (conforming to MLCustomModel) corresponding to this model - map parameters = 30; - string description = 40; // An (optional) description provided by the model creator. This information is displayed when viewing the model, but does not affect the model's execution on device. -} diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/DataStructures.proto b/onnxruntime/core/providers/coreml/mlmodel_format/DataStructures.proto deleted file mode 100644 index 8b120c2d7d102..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/DataStructures.proto +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright (c) 2017, Apple Inc. All rights reserved. -// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -import public "FeatureTypes.proto"; - -package CoreML.Specification; - -/** - * A mapping from a string - * to a 64-bit integer. - */ -message StringToInt64Map { - map map = 1; -} - -/** - * A mapping from a 64-bit integer - * to a string. - */ -message Int64ToStringMap { - map map = 1; -} - -/** - * A mapping from a string - * to a double-precision floating point number. - */ -message StringToDoubleMap { - map map = 1; -} - -/** - * A mapping from a 64-bit integer - * to a double-precision floating point number. 
- */ -message Int64ToDoubleMap { - map map = 1; -} - -/** - * A vector of strings. - */ -message StringVector { - repeated string vector = 1; -} - -/** - * A vector of 64-bit integers. - */ -message Int64Vector { - repeated int64 vector = 1; -} - -/** - * A vector of floating point numbers. - */ -message FloatVector { - repeated float vector = 1; -} - -/** - * A vector of double-precision floating point numbers. - */ -message DoubleVector { - repeated double vector = 1; -} - -/** - * A range of int64 values - */ -message Int64Range { - int64 minValue = 1; - int64 maxValue = 2; -} - -/** - * A set of int64 values - */ -message Int64Set { - repeated int64 values = 1; -} - -/** - * A range of double values - */ -message DoubleRange { - double minValue = 1; - double maxValue = 2; -} - diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/DictVectorizer.proto b/onnxruntime/core/providers/coreml/mlmodel_format/DictVectorizer.proto deleted file mode 100644 index 3f94eeec1745c..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/DictVectorizer.proto +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright (c) 2017, Apple Inc. All rights reserved. -// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -import public "DataStructures.proto"; - -package CoreML.Specification; - -/** - * Uses an index mapping to convert a dictionary to an array. - * - * The output array will be equal in length to the index mapping vector parameter. - * All keys in the input dictionary must be present in the index mapping vector. - * - * For each item in the input dictionary, insert its value in the output array. - * The position of the insertion is determined by the position of the item's key - * in the index mapping. Any keys not present in the input dictionary, will be - * zero in the output array. 
- * - * For example: if the ``stringToIndex`` parameter is set to ``["a", "c", "b", "z"]``, - * then an input of ``{"a": 4, "c": 8}`` will produce an output of ``[4, 8, 0, 0]``. - * - */ -message DictVectorizer { - oneof Map { - /// String keys to indexes - StringVector stringToIndex = 1; - - /// Int keys to indexes - Int64Vector int64ToIndex = 2; - } -} diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/FeatureTypes.proto b/onnxruntime/core/providers/coreml/mlmodel_format/FeatureTypes.proto deleted file mode 100644 index 8711ac7de3026..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/FeatureTypes.proto +++ /dev/null @@ -1,224 +0,0 @@ -// Copyright (c) 2017, Apple Inc. All rights reserved. -// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -package CoreML.Specification; - -/** - * The 64-bit integer feature type. - */ -message Int64FeatureType {} - -/** - * The double-precision floating point number feature type. - */ -message DoubleFeatureType {} - -/** - * The string feature type. - */ -message StringFeatureType {} - - -message SizeRange { - uint64 lowerBound = 1; - int64 upperBound = 2; // negative value means unbound otherwise upperbound is included in range -} - -/** - * The image feature type. 
- */ -message ImageFeatureType { - // Assumes raw (decompressed) format - enum ColorSpace { - INVALID_COLOR_SPACE = 0; - GRAYSCALE = 10; // 8 bits per pixel - RGB = 20; // 32 bits per pixel: RGBA with A channel ignored - BGR = 30; // 32 bits per pixel: BGRA with A channel ignored - } - - message ImageSize { - uint64 width = 1; - uint64 height = 2; - } - - message EnumeratedImageSizes { - repeated ImageSize sizes = 1; - } - - message ImageSizeRange { - SizeRange widthRange = 1; - SizeRange heightRange = 2; - } - - // The required or default image size is width x height - // - // If specificationVersion <= 2 or SizeFlexibility is empty, - // width x height is the required fixed image size - // - // If SizeFlexibility is present, width x height indicate a "default" - // image size which must be consistent with the flexibilty specified - - int64 width = 1; - int64 height = 2; - - // For specification version >= 3 you can specify image size flexibility. - - oneof SizeFlexibility { - - // Use enumeratedSizes for a set of distinct fixed sizes - // e.g. portrait or landscape: [80 x 100, 100 x 8] - // - // If the width x height fields above are specified then they must be - // one of the sizes listed. - // - // If width and height are not specified above then the default width - // and height will be enumeratedSizes[0] - // - // Must be non-empty - - EnumeratedImageSizes enumeratedSizes = 21; - - // Use imageSizeRange to allow for ranges of values - // e.g. any image greater than 10 x 20: [10..= 3 you can specify image size flexibility. - - oneof ShapeFlexibility { - - // Use enumeratedShapes for a set of distinct fixed shapes - // - // If the shape field is specified then it must be - // one of the enumerated shapes. 
- /// - // If shape is not specifed, the "default" shape will be considered - // enumeratedShapes[0] - // - // Must be non-empty - - EnumeratedShapes enumeratedShapes = 21; - - // Use shapeRange to allow the size of each dimension vary within - // indpendently specified ranges - // - // If you specify shape above it must fall in the range - // specified in shapeRanges. It will be treated as the default shape. - // - // If you don't specify shape above then the default shape will - // have shape[d] = shapeRange.sizeRanges[d].lowerBound - - ShapeRange shapeRange = 31; - - } - - oneof defaultOptionalValue { - int32 intDefaultValue = 41; - float floatDefaultValue = 51; - double doubleDefaultValue = 61; - } - -} - -/** - * The dictionary feature type. - */ -message DictionaryFeatureType { - /** - * Key/value type tags, with the following restrictions: - * - ``keyType`` must be a hashable type - * - ``valueType`` is assumed to be a ``double`` - */ - oneof KeyType { - Int64FeatureType int64KeyType = 1; - StringFeatureType stringKeyType = 2; - } -} - -/** - * The Sequence feature type. - */ -message SequenceFeatureType { - - /** - * Currently only categorical int64 and String sequences are supported - */ - oneof Type { - Int64FeatureType int64Type = 1; - StringFeatureType stringType = 3; - } - - // Range of allowed size/length/count of sequence - SizeRange sizeRange = 101; -} - -/** - * A feature, which may be optional. 
- */ -message FeatureType { - oneof Type { - Int64FeatureType int64Type = 1; - DoubleFeatureType doubleType = 2; - StringFeatureType stringType = 3; - ImageFeatureType imageType = 4; - ArrayFeatureType multiArrayType = 5; - DictionaryFeatureType dictionaryType = 6; - SequenceFeatureType sequenceType = 7; - } - - bool isOptional = 1000; -} - diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/FeatureVectorizer.proto b/onnxruntime/core/providers/coreml/mlmodel_format/FeatureVectorizer.proto deleted file mode 100644 index 75eaf14b53669..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/FeatureVectorizer.proto +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright (c) 2017, Apple Inc. All rights reserved. -// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -package CoreML.Specification; - -/** - * A FeatureVectorizer puts one or more features into a single array. - * - * The ordering of features in the output array is determined by - * ``inputList``. - * - * ``inputDimensions`` is a zero based index. - */ -message FeatureVectorizer { - message InputColumn { - string inputColumn = 1; - uint64 inputDimensions = 2; - } - - repeated InputColumn inputList = 1; -} diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/GLMClassifier.proto b/onnxruntime/core/providers/coreml/mlmodel_format/GLMClassifier.proto deleted file mode 100644 index 47f6f4a3c7b8c..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/GLMClassifier.proto +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright (c) 2017, Apple Inc. All rights reserved. 
-// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -import public "DataStructures.proto"; - -package CoreML.Specification; - -/** - * A generalized linear model classifier. - */ -message GLMClassifier { - message DoubleArray { - repeated double value = 1; - } - - enum PostEvaluationTransform { - Logit = 0; - Probit = 1; /// Only binary classification is supported for probit - } - - enum ClassEncoding { - ReferenceClass = 0; /// First class is the reference class - OneVsRest = 1; /// Also called One vs All - } - - repeated DoubleArray weights = 1; - repeated double offset = 2; - PostEvaluationTransform postEvaluationTransform = 3; - ClassEncoding classEncoding = 4; - - /** - * Required class label mapping. - */ - oneof ClassLabels { - StringVector stringClassLabels = 100; - Int64Vector int64ClassLabels = 101; - } -} diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/GLMRegressor.proto b/onnxruntime/core/providers/coreml/mlmodel_format/GLMRegressor.proto deleted file mode 100644 index 64093c4f156a8..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/GLMRegressor.proto +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright (c) 2017, Apple Inc. All rights reserved. -// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -package CoreML.Specification; - -/** - * A generalized linear model regressor. 
- */ -message GLMRegressor { - message DoubleArray { - repeated double value = 1; - } - - enum PostEvaluationTransform { - NoTransform = 0; - Logit = 1; - Probit = 2; - } - - repeated DoubleArray weights = 1; - repeated double offset = 2; - PostEvaluationTransform postEvaluationTransform = 3; -} diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/Gazetteer.proto b/onnxruntime/core/providers/coreml/mlmodel_format/Gazetteer.proto deleted file mode 100644 index 6abbffaf623b9..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/Gazetteer.proto +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright (c) 2019, Apple Inc. All rights reserved. -// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -import public "DataStructures.proto"; - -package CoreML.Specification.CoreMLModels; - -/** -* A model which uses an efficient probabilistic representation -* for assigning labels to a set of strings. -*/ -message Gazetteer { - - /* - * Stores the revision number for the model, revision 2 is available on - * iOS, tvOS 13.0+, macOS 10.15+ - */ - uint32 revision = 1; - - /* - * Stores the language of the model, as specified in BCP-47 format, - * e.g. "en-US". See https://tools.ietf.org/html/bcp47 - */ - string language = 10; - - /* - * Natural Lanaguge framework's efficient representation of a gazetter. - */ - bytes modelParameterData = 100; - - /* - * Stores the set of output class labels - */ - oneof ClassLabels { - StringVector stringClassLabels = 200; - } - -} diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/Identity.proto b/onnxruntime/core/providers/coreml/mlmodel_format/Identity.proto deleted file mode 100644 index 123a15e59156d..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/Identity.proto +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (c) 2017, Apple Inc. 
All rights reserved. -// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -package CoreML.Specification; - -/** - * An identity model. - * - * This model returns given inputs as outputs, unchanged. - * Intended to be used for testing purposes. - */ -message Identity { -} diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/Imputer.proto b/onnxruntime/core/providers/coreml/mlmodel_format/Imputer.proto deleted file mode 100644 index 3de280b2f162d..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/Imputer.proto +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright (c) 2017, Apple Inc. All rights reserved. -// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -import public "DataStructures.proto"; - -package CoreML.Specification; - -/** - * A transformer that replaces missing values with a default value, - * such as a statistically-derived value. - * - * If ``ReplaceValue`` is set, then missing values of that type are - * replaced with the corresponding value. - * - * For example: if ``replaceDoubleValue`` is set to ``NaN`` - * and a single ``NaN`` double value is provided as input, - * then it is replaced by ``imputedDoubleValue``. However - * if the input is an array of doubles, then any instances - * of ``NaN`` in the array is replaced with the corresponding - * value in ``imputedDoubleArray``. 
- */ -message Imputer { - oneof ImputedValue { - double imputedDoubleValue = 1; - int64 imputedInt64Value = 2; - string imputedStringValue = 3; - DoubleVector imputedDoubleArray = 4; - Int64Vector imputedInt64Array = 5; - StringToDoubleMap imputedStringDictionary = 6; - Int64ToDoubleMap imputedInt64Dictionary = 7; - } - - oneof ReplaceValue { - double replaceDoubleValue = 11; - int64 replaceInt64Value = 12; - string replaceStringValue = 13; - } -} diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/ItemSimilarityRecommender.proto b/onnxruntime/core/providers/coreml/mlmodel_format/ItemSimilarityRecommender.proto deleted file mode 100644 index a5a8c11092d36..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/ItemSimilarityRecommender.proto +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright (c) 2017, Apple Inc. All rights reserved. -// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -/** - * Each tree is a collection of nodes, - * each of which is identified by a unique identifier. - * - * Each node is either a branch or a leaf node. - * A branch node evaluates a value according to a behavior; - * if true, the node identified by ``true_child_node_id`` is evaluated next, - * if false, the node identified by ``false_child_node_id`` is evaluated next. - * A leaf node adds the evaluation value to the base prediction value - * to get the final prediction. - * - * A tree must have exactly one root node, - * which has no parent node. - * A tree must not terminate on a branch node. - * All leaf nodes must be accessible - * by evaluating one or more branch nodes in sequence, - * starting from the root node. 
- */ - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -import public "DataStructures.proto"; - -package CoreML.Specification; - - -/** - * Item Similarity Recommender - * - * The Item Similarity recommender takes as input a list of items and scores, - * then uses that information and a table of item similarities to predict similarity - * scores for all items. By default, the items predicted are most similar to the given - * items but not part of that item set. - * - * The predicted score for a given item k is - * sum_(i in observed items) sim_(k,i) * (score_i - shift_k) - * - * Because only the most similar scores for each item i are stored, - * sim_(k,i) is often zero. - * - * For many models, the score adjustment parameter shift_j is zero -- it's occasionally used - * to counteract global biases for popular items. - * - * - * References: - */ -message ItemSimilarityRecommender { - - /** The items similar to a given base item. - */ - message ConnectedItem { - uint64 itemId = 1; - double similarityScore = 2; - } - - /** The formula for the score of a given model as given above, with shift_k - * parameter given by itemScoreAdjustment, and the similar item list filling in - * all the known sim(k,i) scores for i given by itemID and k given by the itemID parameter in - * the similarItemList. - */ - message SimilarItems { - uint64 itemId = 1; - repeated ConnectedItem similarItemList = 2; - double itemScoreAdjustment = 3; - } - - repeated SimilarItems itemItemSimilarities = 1; - - /** One or none of these are given. If none are given, then the items must number 0, 1, ..., num_items - 1. - * If either is given, the length must be exactly num_items. - */ - StringVector itemStringIds = 2; - Int64Vector itemInt64Ids = 3; - - /** Input parameter names specifying different possible inputs to the recommender. 
- */ - string itemInputFeatureName = 10; /* Required */ - string numRecommendationsInputFeatureName = 11; /* Optional; defaults to all items if not given.*/ - string itemRestrictionInputFeatureName = 12; /* Optional. */ - string itemExclusionInputFeatureName = 13; /* Optional; defaults to input item list if not given. */ - - /** The predicted outputs. At least one of these must be specified. - */ - string recommendedItemListOutputFeatureName = 20; - string recommendedItemScoreOutputFeatureName = 21; - -} diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/LinkedModel.proto b/onnxruntime/core/providers/coreml/mlmodel_format/LinkedModel.proto deleted file mode 100644 index b113000e80a8d..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/LinkedModel.proto +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright (c) 2019, Apple Inc. All rights reserved. -// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; -import public "Parameters.proto"; - -package CoreML.Specification; - -/** - * A model which wraps another (compiled) model external to this one - */ -message LinkedModel { - - oneof LinkType { - // A model located via a file system path - LinkedModelFile linkedModelFile = 1; - } -} - -// Model is referenced by a model file name and search path -message LinkedModelFile { - - // Model file name: e.g. "MyFetureExtractor.mlmodelc" - StringParameter linkedModelFileName = 1; - - // Search path to find the linked model file - // Multiple paths can be searched using the unix-style path separator ":" - // Each path can be relative (to this model) or absolute - // - // An empty string is the same as teh relative search path "." 
- // which searches in the same location as this model file - // - // There are some special paths which start with $ - // - $BUNDLE_MAIN - Indicates to look in the main bundle - // - $BUNDLE_IDENTIFIER(identifier) - Looks in Bunde with given identifer - StringParameter linkedModelSearchPath = 2; -} - - diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/Model.proto b/onnxruntime/core/providers/coreml/mlmodel_format/Model.proto deleted file mode 100644 index 737233f2e3fe7..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/Model.proto +++ /dev/null @@ -1,322 +0,0 @@ -// Copyright (c) 2017, Apple Inc. All rights reserved. -// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -/** - * A Core ML model consists of a specification version - * and a model description, - * and can be any one of the following types: - * - * Neural Networks - * - `NeuralNetwork` - * - * Regressors - * - ``GLMRegressor`` - * - ``SupportVectorRegressor`` - * - ``TreeEnsembleRegressor`` - * - ``NeuralNetworkRegressor`` - * - ``BayesianProbitRegressor`` - * - * Classifiers - * - `NeuralNetworkClassifier` - * - `TreeEnsembleClassifier` - * - `GLMClassifier` - * - `SupportVectorClassifier` - * - `KNearestNeighborsClassifier` - * - * Other models - * - `CustomModel` - * - `TextClassifier` - * - `WordTagger` - * - `Gazetteer` - * - `WordEmbedding` - * - `VisionFeaturePrint` - * - `LinkedModel` - * - `SoundAnalysisPreprocessing` - * - `ItemSimilarityRecommender` - * - * Feature Engineering - * - `Imputer` - * - `Scaler` - * - `Normalizer` - * - `OneHotEncoder` - * - `CategoricalMapping` - * - `FeatureVectorizer` - * - `DictVectorizer` - * - `ArrayFeatureExtractor` - * - `NonMaximumSuppression` - * - * Pipelines - * - `PipelineClassifier` - * - `PipelineRegressor` - * - `Pipeline` - * - * Simple Mathematical Functions - * - `Identity` - */ - -syntax = "proto3"; -option 
optimize_for = LITE_RUNTIME; - -import public "VisionFeaturePrint.proto"; -import public "TextClassifier.proto"; -import public "WordTagger.proto"; -import public "Gazetteer.proto"; -import public "WordEmbedding.proto"; -import public "ArrayFeatureExtractor.proto"; -import public "BayesianProbitRegressor.proto"; -import public "CategoricalMapping.proto"; -import public "CustomModel.proto"; -import public "DictVectorizer.proto"; -import public "FeatureTypes.proto"; -import public "FeatureVectorizer.proto"; -import public "GLMRegressor.proto"; -import public "GLMClassifier.proto"; -import public "NearestNeighbors.proto"; -import public "Identity.proto"; -import public "Imputer.proto"; -import public "NeuralNetwork.proto"; -import public "Normalizer.proto"; -import public "OneHotEncoder.proto"; -import public "Scaler.proto"; -import public "NonMaximumSuppression.proto"; -import public "SVM.proto"; -import public "TreeEnsemble.proto"; -import public "Parameters.proto"; -import public "ItemSimilarityRecommender.proto"; -import public "SoundAnalysisPreprocessing.proto"; -import public "LinkedModel.proto"; - -package CoreML.Specification; - -/** - * A pipeline consisting of one or more models. - */ -message Pipeline { - repeated Model models = 1; - - // Optional names given for each model - // If not supplied it defaults to ["model0",..., "model"(models.size()-1)] - // These names can be used to disambiguate the scope / domain of a parameter - repeated string names = 2; -} - -/** - * A classifier pipeline. - */ -message PipelineClassifier { - Pipeline pipeline = 1; -} - -/** - * A regressor pipeline. - */ -message PipelineRegressor { - Pipeline pipeline = 1; -} - -/** - * A feature description, - * consisting of a name, short description, and type. 
- */ -message FeatureDescription { - string name = 1; - string shortDescription = 2; - FeatureType type = 3; -} - -/** - * Model metadata, - * consisting of a short description, a version string, - * an author, a license, and any other user defined - * key/value meta data. - */ -message Metadata { - string shortDescription = 1; - string versionString = 2; - string author = 3; - string license = 4; - map userDefined = 100; -} - -/** - * A description of a model, - * consisting of descriptions of its input and output features. - * Both regressor and classifier models require the name of the - * primary predicted output feature (``predictedFeatureName``). - * Classifier models can specify the output feature containing - * probabilities for the predicted classes - * (``predictedProbabilitiesName``). - */ -message ModelDescription { - repeated FeatureDescription input = 1; - repeated FeatureDescription output = 10; - - // [Required for regressor and classifier models]: the name - // to give to an output feature containing the prediction. - string predictedFeatureName = 11; - - // [Optional for classifier models]: the name to give to an - // output feature containing a dictionary mapping class - // labels to their predicted probabilities. If not specified, - // the dictionary will not be returned by the model. - string predictedProbabilitiesName = 12; - - repeated FeatureDescription trainingInput = 50; - - Metadata metadata = 100; -} - -message SerializedModel { - // Identifier whose content describes the model type of the serialized protocol buffer message. - string identifier = 1; - - // Must be a valid serialized protocol buffer of the above specified type. - bytes model = 2; -} - -/** - * A Core ML model, - * consisting of a specification version, - * a model description, and a model type. 
- * - * Core ML model compatibility is indicated by - * a monotonically increasing specification version number, - * which is incremented anytime a backward-incompatible change is made - * (this is functionally equivalent to the MAJOR version number - * described by `Semantic Versioning 2.0.0 `_). - * - * Specification Versions : OS Availability (Core ML Version) - * - * 1 : iOS 11, macOS 10.13, tvOS 11, watchOS 4 (Core ML 1) - * - Feedforward & Recurrent Neural Networks - * - General Linear Models - * - Tree Ensembles - * - Support Vector Machines - * - Pipelines - * - Feature Engineering - * - * 2 : iOS 11.2, macOS 10.13.2, tvOS 11.2, watchOS 4.2 (Core ML 1.2) - * - Custom Layers for Neural Networks - * - Float 16 support for Neural Network layers - * - * 3 : iOS 12, macOS 10.14, tvOS 12, watchOS 5 (Core ML 2) - * - Flexible shapes and image sizes - * - Categorical sequences - * - Core ML Vision Feature Print, Text Classifier, Word Tagger - * - Non Max Suppression - * - Crop and Resize Bilinear NN layers - * - Custom Models - * - * 4 : iOS 13, macOS 10.15, tvOS 13, watchOS 6 (Core ML 3) - * - Updatable models - * - Exact shape / general rank mapping for neural networks - * - Large expansion of supported neural network layers - * - Generalized operations - * - Control flow - * - Dynamic layers - * - See NeuralNetwork.proto - * - Nearest Neighbor Classifier - * - Sound Analysis Prepreocessing - * - Recommender - * - Linked Model - * - NLP Gazeteer - * - NLP WordEmbedding - * - * 5 : iOS 14, macOS 11, tvOS 14, watchOS 7 (Core ML 4) - * - Model Deployment - * - Model Encryption - * - Unified converter API with PyTorch and Tensorflow 2 Support in coremltools 4 - * - MIL builder for neural networks and composite ops in coremltools 4 - * - New layers in neural network: - * - CumSum - * - OneHot - * - ClampedReLu - * - ArgSort - * - SliceBySize - * - Convolution3D - * - Pool3D - * - Bilinear Upsample with align corners and fractional factors - * - PixelShuffle - * - 
MatMul with int8 weights and int8 activations - * - Concat interleave - * - See NeuralNetwork.proto - * - Enhanced Xcode model view with interactive previews - * - Enhanced Xcode Playground support for Core ML models - * - */ -message Model { - int32 specificationVersion = 1; - ModelDescription description = 2; - - /* - * Following model types support on-device update: - * - * - NeuralNetworkClassifier - * - NeuralNetworkRegressor - * - NeuralNetwork - * - KNearestNeighborsClassifier - */ - bool isUpdatable = 10; - - // start at 200 here - // model specific parameters: - oneof Type { - // pipeline starts at 200 - PipelineClassifier pipelineClassifier = 200; - PipelineRegressor pipelineRegressor = 201; - Pipeline pipeline = 202; - - // regressors start at 300 - GLMRegressor glmRegressor = 300; - SupportVectorRegressor supportVectorRegressor = 301; - TreeEnsembleRegressor treeEnsembleRegressor = 302; - NeuralNetworkRegressor neuralNetworkRegressor = 303; - BayesianProbitRegressor bayesianProbitRegressor = 304; - - // classifiers start at 400 - GLMClassifier glmClassifier = 400; - SupportVectorClassifier supportVectorClassifier = 401; - TreeEnsembleClassifier treeEnsembleClassifier = 402; - NeuralNetworkClassifier neuralNetworkClassifier = 403; - KNearestNeighborsClassifier kNearestNeighborsClassifier = 404; - - // generic models start at 500 - NeuralNetwork neuralNetwork = 500; - ItemSimilarityRecommender itemSimilarityRecommender = 501; - - // Custom and linked models - CustomModel customModel = 555; - LinkedModel linkedModel = 556; - - // feature engineering starts at 600 - OneHotEncoder oneHotEncoder = 600; - Imputer imputer = 601; - FeatureVectorizer featureVectorizer = 602; - DictVectorizer dictVectorizer = 603; - Scaler scaler = 604; - CategoricalMapping categoricalMapping = 606; - Normalizer normalizer = 607; - ArrayFeatureExtractor arrayFeatureExtractor = 609; - NonMaximumSuppression nonMaximumSuppression = 610; - - - // simple mathematical functions used for 
testing start at 900 - Identity identity = 900; - - // reserved until 1000 - - // CoreML provided models - CoreMLModels.TextClassifier textClassifier = 2000; - CoreMLModels.WordTagger wordTagger = 2001; - CoreMLModels.VisionFeaturePrint visionFeaturePrint = 2002; - CoreMLModels.SoundAnalysisPreprocessing soundAnalysisPreprocessing = 2003; - CoreMLModels.Gazetteer gazetteer = 2004; - CoreMLModels.WordEmbedding wordEmbedding = 2005; - - // Reserved private messages start at 3000 - // These messages are subject to change with no notice or support. - SerializedModel serializedModel = 3000; - } -} diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/NearestNeighbors.proto b/onnxruntime/core/providers/coreml/mlmodel_format/NearestNeighbors.proto deleted file mode 100644 index 82acd8490374d..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/NearestNeighbors.proto +++ /dev/null @@ -1,132 +0,0 @@ -// Copyright (c) 2017, Apple Inc. All rights reserved. -// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -package CoreML.Specification; - -import public "DataStructures.proto"; -import public "Parameters.proto"; - -/** - * A k-Nearest-Neighbor classifier - */ -message KNearestNeighborsClassifier { - - /** - * The "core" nearest neighbor model attributes. - */ - NearestNeighborsIndex nearestNeighborsIndex = 1; - - /** - * Number of neighbors to use for classification. - */ - Int64Parameter numberOfNeighbors = 3; - - /** - * Type of labels supported by the model. Currently supports String or Int64 - * labels. 
- */ - oneof ClassLabels { - StringVector stringClassLabels = 100; - Int64Vector int64ClassLabels = 101; - } - - /** - * Default value of class label (useful when prediction is called on an empty kNN classifier) - */ - oneof DefaultClassLabel { - string defaultStringLabel = 110; - int64 defaultInt64Label = 111; - } - - /** - * Weighting scheme to be used when computing the majority label of a - * new data point. - */ - oneof WeightingScheme { - UniformWeighting uniformWeighting = 200; - InverseDistanceWeighting inverseDistanceWeighting = 210; - } -} - -/** - * The "core" attributes of a Nearest Neighbors model. - */ -message NearestNeighborsIndex { - - /** - * Number of dimensions of the input data. - */ - int32 numberOfDimensions = 1; - - /** - * Vector of floating point data that makes up the model. Each data point must have 'numberOfDimensions' - * dimensions. - */ - repeated FloatVector floatSamples = 2; - - /** - * Backing data structure for the Nearest Neighbors Index. Currently supports - * a linear index or a kd-tree index. - */ - oneof IndexType { - LinearIndex linearIndex = 100; - SingleKdTreeIndex singleKdTreeIndex = 110; - } - - /** - * Distance function to be used to find neighbors. Currently only Squared Euclidean - * Distance is supported. - */ - oneof DistanceFunction { - SquaredEuclideanDistance squaredEuclideanDistance = 200; - } - -} - -/** - * Specifies a uniform weighting scheme (i.e. each neighbor receives equal - * voting power). - */ -message UniformWeighting { -} - - -/** - * Specifies a inverse-distance weighting scheme (i.e. closest neighbors receives higher - * voting power). A nearest neighbor with highest sum of (1 / distance) is picked. - */ -message InverseDistanceWeighting { -} - - -/** - * Specifies a flat index of data points to be searched by brute force. - */ -message LinearIndex { -} - - -/** - * Specifies a kd-tree backend for the nearest neighbors model. 
- */ -message SingleKdTreeIndex { - - /** - * Number of data points contained within a leaf node of the kd-tree. - */ - int32 leafSize = 1; - -} - - -/** - * Specifies the Squared Euclidean Distance function. - */ -message SquaredEuclideanDistance { -} - diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/NeuralNetwork.proto b/onnxruntime/core/providers/coreml/mlmodel_format/NeuralNetwork.proto deleted file mode 100644 index 44a77c6e7f5f1..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/NeuralNetwork.proto +++ /dev/null @@ -1,6531 +0,0 @@ -// Copyright (c) 2017-2019, Apple Inc. All rights reserved. -// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -/** - * A neural network is defined through a collection of layers - * and represents a directed acyclic graph (DAG). - * Each layer has a name, a layer type, - * a list of input names, a list of output names, - * and a collection of parameters specific to the layer type. - * - * The graph structure and connectivity of the neural network - * is inferred from the input and output names. - * A neural network starts with the layer - * whose input name is equal to the value specified in - * ``Model.description.input.name``, - * and ends with the layer - * whose output name is equal to the value specified in - * ``Model.description.output.name``. - * Layers must have unique input and output names, - * and a layer may not have input or output names that - * refer to layers that are not yet defined. - * - * For Core ML specification version <=3, - * all inputs are mapped to static rank 5 tensors, with axis notations - * [Sequence, Batch, Channel, Height, Width]. 
- * - * From specification version 4 onwards (iOS >= 13, macOS >= 10.15), more options are available - * (see enums ``NeuralNetworkMultiArrayShapeMapping``, ``NeuralNetworkImageShapeMapping``) - * to map inputs to generic N-Dimensional (or N rank) tensors, where N >= 1. - * - * Each layer type may have specific constraints on the ranks of its inputs and outputs. - * - * Some of the layers (such as softmax, reduce, etc) have parameters that have been described in - * terms of notational axis "Channel", "Height", "Width" or "Sequence". They can be re-interpreted easily in - * the general ND setting by using the following rule: - * "width" is same as axis = -1 (i.e. the last axis from the end) - * "height" is same as axis = -2 (i.e. the second last axis from the end) - * "channel" is same as axis = -3 (i.e. the third last axis from the end) - * "sequence" is same as axis = -5 (i.e. the fifth last axis from the end) - * - * Several layers are available in 3 different variations, with the names ending - * in identifiers: ``like``, ``static`` and ``dynamic``. For instance, ``FillLike``, - * ``FillStatic`` and ``FillDynamic``. The ``static`` variation generally will have - * a property corresponding to the shape of the output. For instance, if the - * output of the ``FillStatic`` layer is desired to be of shape (10, 4), the - * property ``targetShape`` will have to be set to [10, 4]. In the ``dynamic`` case, - * the shape is an input, hence it can be changed at runtime. For instance, for - * a ``FillDynamic`` layer, the input would have to be an array containing the - * values 10 and 4, if the desired output is of shape (10, 4). Whereas in the - * ``like`` case, the additional input's shape is used as the output shape, ignoring - * its values. For instance, for a ``FillLike`` layer, for an input with shape - * (10, 4), the output generated will also be of shape (10, 4), values of the - * input will be ignored. 
- */ - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -import public "DataStructures.proto"; -import public "Parameters.proto"; - -package CoreML.Specification; - - -enum NeuralNetworkMultiArrayShapeMapping { - - /* - * Describes how the MultiArray shape for the inputs, - * provided in Features Types proto via model description, - * is mapped to construct tensors that are fed into the Neural Network layers. - */ - - /* - * Default legacy value. Only supported for Core ML Specification version <= 3. - * - * The default legacy shape mapping resolves all input shapes to a rank 5 equivalent - * with axis notation of [Seq, Batch, Channel, Height, Width]. - * - * When this enum value is selected, - * the repeated shape field in the message "ArrayFeatureType" in feature types proto, - * must be either length 1 or length 3. - * - * The following rule is used to map the values in the shape field to the actual tensor shape: - * rank 1 shape is mapped to shape [1,1,C,1,1] - * rank 3 shape is mapped to shape [1,1,C,H,W] - * At runtime, the first two dimensions (Seq or Batch) can be presented as well, with non-1 values. - * - * It is invalid to use this enum value if any of the layers added - * Specification version 4 (iOS >= 13, macOS >= 10.15) onwards are used in the network. - * Validator will raise an error in that case. - */ - RANK5_ARRAY_MAPPING = 0; - - /* - * The exact shape and rank (i.e. number of dimensions in the shape) of the input, - * as specified in the message "ArrayFeatureType", is passed through to the layers. - * Supported only for Specification version >= 4 (iOS >= 13, macOS >= 10.15). - */ - EXACT_ARRAY_MAPPING = 1; - -} - -enum NeuralNetworkImageShapeMapping { - - /* - * Describes how the shape of the input tensors is constructed from image inputs. - */ - - /* - * In this case, image input is mapped to a rank 5 tensor. - * For Color images, input tensor is shaped as [1,1,3,H,W]. - * For Gray images, input tensor is shaped as [1,1,1,H,W]. 
- */ - RANK5_IMAGE_MAPPING = 0; - - /* - * For Color images, input tensor is shaped as [1,3,H,W]. - * For Gray images, input tensor is shaped as [1,1,H,W]. - * Supported only for Specification version >= 4 (iOS >= 13, macOS >= 10.15). - */ - RANK4_IMAGE_MAPPING = 1; - -} - -/** - A neural network. - */ -message NeuralNetwork { - - repeated NeuralNetworkLayer layers = 1; - repeated NeuralNetworkPreprocessing preprocessing = 2; - - // use this enum value to determine the input tensor shapes to the neural network, for multiarray inputs - NeuralNetworkMultiArrayShapeMapping arrayInputShapeMapping = 5; - - // use this enum value to determine the input tensor shapes to the neural network, for image inputs - NeuralNetworkImageShapeMapping imageInputShapeMapping = 6; - - - NetworkUpdateParameters updateParams = 10; - -} - -/// Preprocessing -/// ------------- - -/** - * A neural network preprocessor that - * performs a scalar multiplication of an image - * followed by addition of scalar biases to the channels. - * - * Input: X - * An image in BGR or RGB format with shape ``[3, H, W]`` - * or in grayscale format with shape ``[1, H, W]``. - * Output: Y - * An image with format and shape corresponding to the input. - * - * If the input image is in BGR format: - * - * .. code:: - * - * Y[0, :, :] = channelScale * X[0, :, :] + blueBias - * Y[1, :, :] = channelScale * X[1, :, :] + greenBias - * Y[2, :, :] = channelScale * X[2, :, :] + redBias - * - * If the input image is in RGB format: - * - * .. code:: - * - * Y[0, :, :] = channelScale * X[0, :, :] + redBias - * Y[1, :, :] = channelScale * X[1, :, :] + greenBias - * Y[2, :, :] = channelScale * X[2, :, :] + blueBias - * - * If the input image is in grayscale format: - * - * .. code:: - * - * Y[0, :, :] = channelScale * X[0, :, :] + grayBias - */ -message NeuralNetworkImageScaler { - - float channelScale = 10; ///Scalar to be multiplied. - float blueBias = 20; ///Scalar blue bias to be added. 
- float greenBias = 21; ///Scalar green bias to be added. - float redBias = 22; ///Scalar red bias to be added. - float grayBias = 30; ///Scalar bias to be added for grayscale images. - -} - -/** - * A neural network preprocessor that - * subtracts the provided mean image from the input image. - * The mean image is subtracted from the input named - * ``NeuralNetworkPreprocessing.featureName``. - */ -message NeuralNetworkMeanImage { - - /** - * Mean image stored as a flattened array of floats, - * representing shape [Channel,Height,Width]. - */ - repeated float meanImage = 1; - -} - -/// Preprocessing parameters for image inputs. -message NeuralNetworkPreprocessing { - - string featureName = 1; /// must be equal to the input name to which the preprocessing is applied - oneof preprocessor { - NeuralNetworkImageScaler scaler = 10; - NeuralNetworkMeanImage meanImage = 11; - } - -} - -/// Activation Functions -/// -------------------- - -/** - * A rectified linear unit (ReLU) activation function. - * - * This function has the following formula: - * - * .. math:: - * f(x) = \text{max}(0, x) - */ -message ActivationReLU { - -} - -/** - * A leaky rectified linear unit (ReLU) activation function. - * - * This function has the following formula: - * - * .. math:: - * f(x) = \begin{cases} - * x & \text{if } x \geq 0 \\ - * \alpha x & \text{if } x < 0 - * \end{cases} - */ -message ActivationLeakyReLU { - - float alpha = 1; //negative slope value for leakyReLU - -} - -/** - * A hyperbolic tangent activation function. - * - * This function has the following formula: - * - * .. math:: - * f(x) = \dfrac{1 - e^{-2x}}{1 + e^{-2x}} - */ -message ActivationTanh { - -} - -/** - * A scaled hyperbolic tangent activation function. - * - * This function has the following formula: - * - * .. math:: - * f(x) = \alpha \tanh(\beta x) - */ -message ActivationScaledTanh { - - float alpha = 1; - float beta = 2; - -} - -/** - * A sigmoid activation function. 
- * - * This function has the following formula: - * - * .. math:: - * f(x) = \dfrac{1}{1 + e^{-x}} - */ -message ActivationSigmoid { - -} - -/** - * A linear activation function. - * - * This function has the following formula: - * - * .. math:: - * f(x) = \alpha x + \beta - */ -message ActivationLinear { - - float alpha = 1; - float beta = 2; - -} - -/** - * A hard sigmoid activation function. - * - * This function has the following formula: - * - * .. math:: - * f(x) = \text{min}(\text{max}(\alpha x + \beta, 0), 1) - */ -message ActivationSigmoidHard { - - float alpha = 1; - float beta = 2; - -} - -/** - * A parameterized rectified linear unit (PReLU) activation function. - * Input must be at least rank 3. Axis = -3 is denoted by "C", or channels. - * "alpha" parameter can be a vector of length C. - * - * This function has the following formula: - * - * .. math:: - * f(x_i) = \begin{cases} - * x_i & \text{if } x_i \geq 0 \\ - * \alpha_i x_i & \text{if } x_i < 0 - * \end{cases} \;,\;i=1,...,C - */ -message ActivationPReLU { - - // parameter of length C or 1. - // If length is 1, same value is used for all channels - WeightParams alpha = 1; - -} - -/** - * An exponential linear unit (ELU) activation function. - * - * This function has the following formula: - * - * .. math:: - * f(x) = \begin{cases} - * x & \text{if } x \geq 0 \\ - * \alpha (e^x - 1) & \text{if } x < 0 - * \end{cases} - */ -message ActivationELU { - - float alpha = 1; - -} - -/** - * A thresholded rectified linear unit (ReLU) activation function. - * - * This function has the following formula: - * - * .. math:: - * f(x) = \begin{cases} - * x & \text{if } x \geq \alpha \\ - * 0 & \text{if } x < \alpha - * \end{cases} - */ -message ActivationThresholdedReLU { - - float alpha = 1; - -} - -/** - * A softsign activation function. - * - * This function has the following formula: - * - * .. 
math:: - * f(x) = \dfrac{x}{1 + |x|} - */ -message ActivationSoftsign { - -} - -/** - * A softplus activation function. - * - * This function has the following formula: - * - * .. math:: - * f(x) = \text{log}(1 + e^x) - */ -message ActivationSoftplus { - -} - -/** - * A parametric softplus activation function. - * Input must be at least rank 3. axis = -3 is denoted by "C", or channels. - * "alpha"/"beta" parameter can be a vector of length C. - * - * This function has the following formula: - * - * .. math:: - * f(x_i) = \alpha_i \text{log}(1 + e^{\beta_i x_i}) \;,\;i=1,...,C - */ -message ActivationParametricSoftplus { - - // If length is 1, same value is used for all channels - WeightParams alpha = 1; //parameter of length C or 1 - WeightParams beta = 2; //parameter of length C or 1 - -} - -message ActivationParams { - - oneof NonlinearityType { - ActivationLinear linear = 5; - - ActivationReLU ReLU = 10; - ActivationLeakyReLU leakyReLU = 15; - ActivationThresholdedReLU thresholdedReLU = 20; - ActivationPReLU PReLU = 25; - - ActivationTanh tanh = 30; - ActivationScaledTanh scaledTanh = 31; - - ActivationSigmoid sigmoid = 40; - ActivationSigmoidHard sigmoidHard = 41; - - ActivationELU ELU = 50; - - ActivationSoftsign softsign = 60; - ActivationSoftplus softplus = 70; - ActivationParametricSoftplus parametricSoftplus = 71; - } - -} - -/** - * Representation of the intermediate tensors - */ -message Tensor { - - // Number of dimensions in the tensor shape - uint32 rank = 1; - // actual value of the tensor shape. - // must be of length "rank". Can contain -1s for unknown dimensions. - repeated int64 dimValue = 2; - -} - -/** - * A single neural network layer. 
- */ -message NeuralNetworkLayer { - - string name = 1; //descriptive name of the layer - repeated string input = 2; - repeated string output = 3; - - repeated Tensor inputTensor = 4; // must be the same length as the "input" field - repeated Tensor outputTensor = 5; // must be the same length as the "output" field - - // Must be set to true to mark the layer as updatable. - // If true, the weightParams in the layer's properties must also be set to updatable - // If false, the value of the isUpdatable parameter within the layer's weights are ignored - bool isUpdatable = 10; - - oneof layer { - - // Start at 100 here - ConvolutionLayerParams convolution = 100; - - PoolingLayerParams pooling = 120; - - ActivationParams activation = 130; - - InnerProductLayerParams innerProduct = 140; - EmbeddingLayerParams embedding = 150; - - // Normalization-related Layers - BatchnormLayerParams batchnorm = 160; - MeanVarianceNormalizeLayerParams mvn = 165; - L2NormalizeLayerParams l2normalize = 170; - SoftmaxLayerParams softmax = 175; - LRNLayerParams lrn = 180; - - CropLayerParams crop = 190; - PaddingLayerParams padding = 200; - UpsampleLayerParams upsample = 210; - - ResizeBilinearLayerParams resizeBilinear = 211; - CropResizeLayerParams cropResize = 212; - - UnaryFunctionLayerParams unary = 220; - - // Element-wise Operations - AddLayerParams add = 230; - MultiplyLayerParams multiply = 231; - - AverageLayerParams average = 240; - ScaleLayerParams scale = 245; - - BiasLayerParams bias = 250; - MaxLayerParams max = 260; - MinLayerParams min = 261; - - DotProductLayerParams dot = 270; - ReduceLayerParams reduce = 280; - LoadConstantLayerParams loadConstant = 290; - - // Data Reorganization - ReshapeLayerParams reshape = 300; - FlattenLayerParams flatten = 301; - PermuteLayerParams permute = 310; - ConcatLayerParams concat = 320; - SplitLayerParams split = 330; - SequenceRepeatLayerParams sequenceRepeat = 340; - - ReorganizeDataLayerParams reorganizeData = 345; - SliceLayerParams 
slice = 350; - - // Recurrent Layers - SimpleRecurrentLayerParams simpleRecurrent = 400; - GRULayerParams gru = 410; - UniDirectionalLSTMLayerParams uniDirectionalLSTM = 420; - BiDirectionalLSTMLayerParams biDirectionalLSTM = 430; - - // Custom (user-implemented) Layer - CustomLayerParams custom = 500; - - // Following layers are available only after Core ML Specification - // version >= 4 (iOS >= 13, macOS >= 10.15) - - // Control Flow related Layers - CopyLayerParams copy = 600; - BranchLayerParams branch = 605; - - LoopLayerParams loop = 615; - LoopBreakLayerParams loopBreak = 620; - LoopContinueLayerParams loopContinue = 625; - - RangeStaticLayerParams rangeStatic = 635; - RangeDynamicLayerParams rangeDynamic = 640; - - // Element-wise Unary Layers - ClipLayerParams clip = 660; - CeilLayerParams ceil = 665; - FloorLayerParams floor = 670; - - SignLayerParams sign = 680; - RoundLayerParams round = 685; - - Exp2LayerParams exp2 = 700; - - SinLayerParams sin = 710; - CosLayerParams cos = 715; - TanLayerParams tan = 720; - - AsinLayerParams asin = 730; - AcosLayerParams acos = 735; - AtanLayerParams atan = 740; - - SinhLayerParams sinh = 750; - CoshLayerParams cosh = 755; - TanhLayerParams tanh = 760; - - AsinhLayerParams asinh = 770; - AcoshLayerParams acosh = 775; - AtanhLayerParams atanh = 780; - - ErfLayerParams erf = 790; - GeluLayerParams gelu = 795; - - // Element-wise Binary with Broadcasting Support - EqualLayerParams equal = 815; - NotEqualLayerParams notEqual = 820; - LessThanLayerParams lessThan = 825; - LessEqualLayerParams lessEqual = 827; - GreaterThanLayerParams greaterThan = 830; - GreaterEqualLayerParams greaterEqual = 832; - - LogicalOrLayerParams logicalOr = 840; - LogicalXorLayerParams logicalXor = 845; - LogicalNotLayerParams logicalNot = 850; - LogicalAndLayerParams logicalAnd = 855; - - ModBroadcastableLayerParams modBroadcastable = 865; - MinBroadcastableLayerParams minBroadcastable = 870; - MaxBroadcastableLayerParams maxBroadcastable = 
875; - AddBroadcastableLayerParams addBroadcastable = 880; - PowBroadcastableLayerParams powBroadcastable = 885; - DivideBroadcastableLayerParams divideBroadcastable = 890; - FloorDivBroadcastableLayerParams floorDivBroadcastable = 895; - MultiplyBroadcastableLayerParams multiplyBroadcastable = 900; - SubtractBroadcastableLayerParams subtractBroadcastable = 905; - - // Tensor Manipulations - TileLayerParams tile = 920; - StackLayerParams stack = 925; - GatherLayerParams gather = 930; - ScatterLayerParams scatter = 935; - GatherNDLayerParams gatherND = 940; - ScatterNDLayerParams scatterND = 945; - SoftmaxNDLayerParams softmaxND = 950; - GatherAlongAxisLayerParams gatherAlongAxis = 952; - ScatterAlongAxisLayerParams scatterAlongAxis = 954; - - ReverseLayerParams reverse = 960; - ReverseSeqLayerParams reverseSeq = 965; - - SplitNDLayerParams splitND = 975; - ConcatNDLayerParams concatND = 980; - TransposeLayerParams transpose = 985; - - SliceStaticLayerParams sliceStatic = 995; - SliceDynamicLayerParams sliceDynamic = 1000; - SlidingWindowsLayerParams slidingWindows = 1005; - - TopKLayerParams topK = 1015; - ArgMinLayerParams argMin = 1020; - ArgMaxLayerParams argMax = 1025; - - EmbeddingNDLayerParams embeddingND = 1040; - BatchedMatMulLayerParams batchedMatmul = 1045; - - // Tensor Allocation / Reshape-related Operations - GetShapeLayerParams getShape = 1065; - LoadConstantNDLayerParams loadConstantND = 1070; - - FillLikeLayerParams fillLike = 1080; - FillStaticLayerParams fillStatic = 1085; - FillDynamicLayerParams fillDynamic = 1090; - - BroadcastToLikeLayerParams broadcastToLike = 1100; - BroadcastToStaticLayerParams broadcastToStatic = 1105; - BroadcastToDynamicLayerParams broadcastToDynamic = 1110; - - SqueezeLayerParams squeeze = 1120; - ExpandDimsLayerParams expandDims = 1125; - FlattenTo2DLayerParams flattenTo2D = 1130; - ReshapeLikeLayerParams reshapeLike = 1135; - ReshapeStaticLayerParams reshapeStatic = 1140; - ReshapeDynamicLayerParams reshapeDynamic = 
1145; - RankPreservingReshapeLayerParams rankPreservingReshape = 1150; - - ConstantPaddingLayerParams constantPad = 1155; - - // Random Distributions - RandomNormalLikeLayerParams randomNormalLike = 1170; - RandomNormalStaticLayerParams randomNormalStatic = 1175; - RandomNormalDynamicLayerParams randomNormalDynamic = 1180; - - RandomUniformLikeLayerParams randomUniformLike = 1190; - RandomUniformStaticLayerParams randomUniformStatic = 1195; - RandomUniformDynamicLayerParams randomUniformDynamic = 1200; - - RandomBernoulliLikeLayerParams randomBernoulliLike = 1210; - RandomBernoulliStaticLayerParams randomBernoulliStatic = 1215; - RandomBernoulliDynamicLayerParams randomBernoulliDynamic = 1220; - - CategoricalDistributionLayerParams categoricalDistribution = 1230; - - // Reduction-related Layers: - ReduceL1LayerParams reduceL1 = 1250; - ReduceL2LayerParams reduceL2 = 1255; - ReduceMaxLayerParams reduceMax = 1260; - ReduceMinLayerParams reduceMin = 1265; - ReduceSumLayerParams reduceSum = 1270; - ReduceProdLayerParams reduceProd = 1275; - ReduceMeanLayerParams reduceMean = 1280; - ReduceLogSumLayerParams reduceLogSum = 1285; - ReduceSumSquareLayerParams reduceSumSquare = 1290; - ReduceLogSumExpLayerParams reduceLogSumExp = 1295; - - // Masking / Selection Layers - WhereNonZeroLayerParams whereNonZero = 1313; - MatrixBandPartLayerParams matrixBandPart = 1315; - LowerTriangularLayerParams lowerTriangular = 1320; - UpperTriangularLayerParams upperTriangular = 1325; - WhereBroadcastableLayerParams whereBroadcastable = 1330; - - // Normalization Layers - LayerNormalizationLayerParams layerNormalization = 1350; - - NonMaximumSuppressionLayerParams NonMaximumSuppression = 1400; - - // Following layers are available only after Core ML Specification - // version >= 5 (iOS >= 14, macOS >= 11.0) - OneHotLayerParams oneHot = 1450; - CumSumLayerParams cumSum = 1455; - ClampedReLULayerParams clampedReLU = 1460; - ArgSortLayerParams argSort = 1461; - Pooling3DLayerParams pooling3d 
= 1465; - GlobalPooling3DLayerParams globalPooling3d = 1466; - SliceBySizeLayerParams sliceBySize = 1470; - Convolution3DLayerParams convolution3d = 1471; - - } - -} - -/** - * Branching Layer - * - * A layer that provides the functionality of branching or an If-Else block. - * - * Must have 1 input. There are no outputs as the execution is transferred to either the - * if or the else branch based on the value of the input. - * - * Input is the condition predicate. Must be a scalar (length 1 tensor). - * - */ -message BranchLayerParams { - - /** - * execute this graph if the absolute value of the input Tensor is greater than 1e-6 - * This must be present. - */ - NeuralNetwork ifBranch = 1; - /** - * execute this graph if the absolute value of the input Tensor is less than 1e-6 - * This is optional. - */ - NeuralNetwork elseBranch = 2; - -} - -/** - * Loop Layer - * - * A layer that provides the functionality of a "for" loop or a "while" loop. - * - * There are either no inputs or 1 input. When an input is present, it corresponds to the maximum loop count, - * in that case the value of the "maxLoopIterations" field is ignored. Input must be a scalar. - * (For description below, maxLoopIterations is assumed to be the value of the input, when its present) - * - * No outputs are produced. Blobs produced by the condition or the body network are visible in the scope of the overall network. - * - * "conditionNetwork" must produce a tensor with the name specified in the "conditionVar" field. - * - * There are 3 possible cases for determining the termination condition: - * - * Case 1: - * - * If there is no "conditionNetwork", in this case the layer corresponds to a pure for loop, which is run "maxLoopIterations" number of times. 
- * Equivalent pseudo-code: - * - * for loopIterator = 0 : maxLoopIterations - * bodyNetwork() - * - * - * Case 2: - * - * "conditionNetwork" is present, and "maxLoopIterations" is 0 and there is no input, - * in this case the layer corresponds to a while loop. Equivalent pseudo-code: - * - * conditionVar = conditionNetwork() - * while conditionVar: - * bodyNetwork() - * conditionVar = conditionNetwork() - * - * - * Case 3: - * - * "conditionNetwork" is provided, and "maxLoopIterations" is positive or there is an input, - * in this case the layer corresponds to a while loop with a joint condition. Equivalent pseudo-code: - * - * loopIterator = 0 - * conditionVar = conditionNetwork() - * while (conditionVar and loopIterator < maxLoopIterations): - * bodyNetwork() - * loopIterator = loopIterator + 1 - * conditionVar = conditionNetwork() - * - */ -message LoopLayerParams { - - /** - * maximum number of iterations. Ignored if input is present. - */ - uint64 maxLoopIterations = 1; - /** - * This field provides the name of the tensor which is produced by the conditionNetwork - * and whose value is checked to start/continue/terminate the loop. Value close to 0.0f is treated as False. - * This field is optional. - * Must be a non empty string if and only if "conditionNetwork" is present. - */ - string conditionVar = 2; - /** - * Must generate a tensor with the name provided in the "conditionVar" field. - * This field is optional. - * Must be present if and only if "conditionVar" field is a non empty string. - */ - NeuralNetwork conditionNetwork = 3; - /** - * Body of the loop. - * This field must be present. - */ - NeuralNetwork bodyNetwork = 4; - -} - -/** - * Loop break Layer - * - * Terminate the loop that has this layer. 
- * If present, it should always reside in the "bodyNetwork" of the loop layer - * - * No inputs/outputs - * - */ -message LoopBreakLayerParams { - -} - -/** - * Loop Continue Layer - * - * Stop the current loop iteration and continue on the next iteration. - * If present, it should always reside in the "bodyNetwork" of the loop layer - * - * No inputs/outputs - * - */ -message LoopContinueLayerParams { - -} - -/** - * Copy Layer - * - * A layer that copies its input tensor to the output tensor. - * Must have 1 input and 1 output, with distinct names. - * This is the only layer that is allowed to re-generate an output that is already present in the neural network prior to this layer, - * in which case it will overwrite the output tensor. - * - */ -message CopyLayerParams { - -} - -/** - * GreaterThan Layer - * - * Either 1 or 2 inputs. - * Produces 1 output. - * Perform elementwise greater than operation. - * - * Output is 1.0f if the condition is true otherwise 0.0f. - * - * .. code:: - * - * y = x1 > x2 - * or - * y = x1 > alpha, if only one input is provided - * - * Broadcasting is supported. - * - */ -message GreaterThanLayerParams { - - /** - * Compare to the scalar value provided here if there is 1 input - */ - float alpha = 2; - -} - -/** - * GreaterEqual Layer - * - * Either 1 or 2 inputs. - * Produces 1 output. - * Perform elementwise greater equal operation. - * - * Output is 1.0f if the condition is true otherwise 0.0f. - * - * .. code:: - * - * y = x1 >= x2 - * or - * y = x1 >= alpha, if only one input is provided - * - * Broadcasting is supported. - * - */ -message GreaterEqualLayerParams { - - /** - * Compare to the scalar value provided here if there is 1 input - */ - float alpha = 2; - -} - -/** - * LessThan Layer - * - * Either 1 or 2 inputs. - * Produces 1 output. - * Perform elementwise less than operation. - * - * Output is 1.0f if the condition is true otherwise 0.0f. - * - * .. 
code:: - * - * y = x1 < x2 - * or - * y = x1 < alpha, if only one input is provided - * - * Broadcasting is supported. - * - */ -message LessThanLayerParams { - - /** - * Compare to the scalar value provided here if there is 1 input - */ - float alpha = 2; - -} - -/** - * LessEqual Layer - * - * Either 1 or 2 inputs. - * Produces 1 output. - * Perform elementwise less equal operation. - * - * Output is 1.0f if the condition is true otherwise 0.0f. - * - * .. code:: - * - * y = x1 <= x2 - * or - * y = x1 <= alpha, if only one input is provided - * - * Broadcasting is supported. - * - */ -message LessEqualLayerParams { - - /** - * Compare to the scalar value provided here if there is 1 input - */ - float alpha = 2; - -} - -/** - * Equal Layer - * - * Either 1 or 2 inputs. - * Produces 1 output. - * Perform elementwise equal operation. - * - * Output is 1.0f if the condition is true otherwise 0.0f. - * - * .. code:: - * - * y = x1 == x2 - * or - * y = x1 == alpha, if only one input is provided - * - * Broadcasting is supported. - * - */ -message EqualLayerParams { - - /** - * Compare to the scalar value provided here if there is 1 input - */ - float alpha = 1; - -} - -/** - * NotEqual Layer - * - * Either 1 or 2 inputs. - * Produces 1 output. - * Perform elementwise not equal operation. - * - * Output is 1.0f if the condition is true otherwise 0.0f. - * - * .. code:: - * - * y = x1 != x2 - * or - * y = x1 != alpha, if only one input is provided - * - * Broadcasting is supported. - * - */ -message NotEqualLayerParams { - - /** - * Compare to the scalar value provided here if there is 1 input - */ - float alpha = 1; - -} - -/** - * LogicalAnd Layer - * - * Must have 2 inputs, produces 1 output. - * Perform elementwise logical AND operation. - * - * Input is considered False if equal to 0.0f otherwise True. - * Output is 1.0f if the condition is true otherwise 0.0f. - * - * .. code:: - * - * y = AND(x1, x2) - * - * Broadcasting is supported. 
- * - */ -message LogicalAndLayerParams { - -} - -/** - * LogicalOr Layer - * - * Must have 2 inputs, produces 1 output. - * Perform elementwise logical OR operation. - * - * Input is considered False if equal to 0.0f otherwise True. - * Output is 1.0f if the condition is true otherwise 0.0f. - * - * .. code:: - * - * y = OR(x1, x2) - * - * Broadcasting is supported. - * - */ -message LogicalOrLayerParams { - -} - -/** - * LogicalXor Layer - * - * Must have 2 inputs, produces 1 output. - * Perform elementwise logical XOR operation. - * - * Input is considered False if equal to 0.0f otherwise True. - * Output is 1.0f if the condition is true otherwise 0.0f. - * - * .. code:: - * - * y = XOR(x1, x2) - * - * Broadcasting is supported. - * - */ -message LogicalXorLayerParams { - -} - -/** - * LogicalNot Layer - * - * Must have 1 input, produces 1 output. - * Perform elementwise logical NOT operation. - * - * Input is considered False if equal to 0.0f otherwise True. - * Output is 1.0f if the condition is true otherwise 0.0f. - * - * .. code:: - * - * y = NOT(x) - * - * - */ -message LogicalNotLayerParams { - -} - -/// Border Amounts -/// -------------- - -/** - * Specifies the amount of spatial border to be either padded or cropped. - * - * For padding: - * - * .. code:: - * - * H_out = borderAmounts[0].startEdgeSize + H_in + borderAmounts[0].endEdgeSize - * W_out = borderAmounts[1].startEdgeSize + W_in + borderAmounts[1].endEdgeSize - * - * topPaddingAmount == Height startEdgeSize - * bottomPaddingAmount == Height endEdgeSize - * leftPaddingAmount == Width startEdgeSize - * rightPaddingAmount == Width endEdgeSize - * - * For cropping: - * - * .. 
code:: - * - * H_out = (-borderAmounts[0].startEdgeSize) + H_in + (-borderAmounts[0].endEdgeSize) - * W_out = (-borderAmounts[1].startEdgeSize) + W_in + (-borderAmounts[1].endEdgeSize) - * - * topCropAmount == Height startEdgeSize - * bottomCropAmount == Height endEdgeSize - * leftCropAmount == Width startEdgeSize - * rightCropAmount == Width endEdgeSize - */ -message BorderAmounts { - - message EdgeSizes { - /** - * The amount to be padded or cropped from the beginning. - */ - uint64 startEdgeSize = 1; - - /** - * The amount to be padded or cropped from the end. - */ - uint64 endEdgeSize = 2; - } - - /** - * The border amounts. - * This must be length 2 in the order ``[H, W]``. - */ - repeated EdgeSizes borderAmounts = 10; - -} - -/** - * Specifies the type of padding to be used with Convolution/Deconvolution and Pooling layers. - * After padding, input spatial shape: ``[H_in, W_in]``, gets modified to the - * output spatial shape ``[H_out, W_out]``. - * - * .. code:: - * - * topPaddingAmount == Height startEdgeSize == borderAmounts[0].startEdgeSize - * bottomPaddingAmount == Height endEdgeSize == borderAmounts[0].endEdgeSize - * leftPaddingAmount == Width startEdgeSize == borderAmounts[1].startEdgeSize - * rightPaddingAmount == Width endEdgeSize == borderAmounts[1].endEdgeSize - * - * With Convolution or Pooling: - * - * .. code:: - * - * H_out = int_division_round_down((H_in + topPaddingAmount + bottomPaddingAmount - KernelSize[0]),stride[0]) + 1 - * - * which is same as: - * - * .. code:: - * - * H_out = int_division_round_up((H_in + topPaddingAmount + bottomPaddingAmount - KernelSize[0] + 1),stride[0]) - * - * With Deconvolution: - * - * .. code:: - * - * H_out = (H_in-1) * stride[0] + kernelSize[0] - (topPaddingAmount + bottomPaddingAmount) - * - * - * The equivalent expressions hold true for ``W_out`` as well. - * - * - * By default, the values of ``paddingAmounts`` are set to ``0``, - * which results in a "true" valid padding. 
- * If non-zero values are provided for ``paddingAmounts``, - * "valid" convolution/pooling is performed within the spatially expanded input. - * - */ -message ValidPadding { - - BorderAmounts paddingAmounts = 1; - -} - -/** - * Specifies the type of padding to be used with Convolution/Deconvolution and pooling layers. - * After padding, input spatial shape: ``[H_in, W_in]``, gets modified to the - * output spatial shape ``[H_out, W_out]``. - * With Convolution or pooling: - * - * .. code:: - * - * H_out = int_division_round_up(H_in,stride[0]) - * W_out = int_division_round_up(W_in,stride[1]) - * - * This is achieved by using the following padding amounts: - * - * .. code:: - * - * totalPaddingHeight = max(0,(H_out-1) * stride[0] + KernelSize[0] - Hin) - * totalPaddingWidth = max(0,(W_out-1) * stride[1] + KernelSize[1] - Win) - * - * There are two modes of asymmetry: - * ``BOTTOM_RIGHT_HEAVY``, and ``TOP_LEFT_HEAVY``. - * - * If the mode is ``BOTTOM_RIGHT_HEAVY``: - * - * .. code:: - * - * topPaddingAmount = floor(totalPaddingHeight / 2) - * bottomPaddingAmount = totalPaddingHeight - topPaddingAmount - * leftPaddingAmount = floor(totalPaddingWidth / 2) - * rightPaddingAmount = totalPaddingWidth - leftPaddingAmount - * - * If the mode is ``TOP_LEFT_HEAVY``: - * - * .. code:: - * - * bottomPaddingAmount = floor(totalPaddingHeight / 2) - * topPaddingAmount = totalPaddingHeight - bottomPaddingAmount - * rightPaddingAmount = floor(totalPaddingWidth / 2) - * leftPaddingAmount = totalPaddingWidth - rightPaddingAmount - * - * - * With Deconvolution: - * - * .. code:: - * - * H_out = H_in * stride[0] - * W_out = W_in * stride[1] - */ -message SamePadding { - - enum SamePaddingMode { - - BOTTOM_RIGHT_HEAVY = 0; - TOP_LEFT_HEAVY = 1; - - } - SamePaddingMode asymmetryMode = 1; - -} - -/** - * Specifies how grid points are sampled from an interval. - * Without the loss of generality, assume the interval to be [0, X-1] from which N points are to be sampled. 
- * Here X may correspond to an input image's height or width. - * All the methods can be expressed in terms of numpy's linspace function, along with the constraint that grid points have to lie in the interval [0, X-1]. - * Note: numpy.linspace(start = start, end = end, num = N, endpoint = True) corresponds to sampling - * N points uniformly from the interval [start, end], endpoints included. - * The methods vary in how the ``start`` and ``end`` values are computed. - */ -message SamplingMode { - - enum Method { - - /** - * start = 0, end = X-1 - * grid points = numpy.linspace(start, end) - */ - STRICT_ALIGN_ENDPOINTS_MODE = 0; - - /** - * if N == 1: start = end = (X-1)/2 - * otherwise, start = 0, end = X-1 - * grid points = numpy.linspace(start, end) - */ - ALIGN_ENDPOINTS_MODE = 1; - - /** - * start = 0, end = X - X/N - * grid points = min(X-1, numpy.linspace(start, end)) - * This is same as the mode used in the upsample layer in this specification, when used with bilinear interpolation. In that case N/X = upsample ratio. - */ - UPSAMPLE_MODE = 2; - - /** - * spacing = max(1, X-1)/N - * start = 0.5 * spacing - * end = start + (N-1) * spacing - * grid points = min(X-1, numpy.linspace(start, end)) - */ - ROI_ALIGN_MODE = 3; - - } - - Method samplingMethod = 1; - -} - -/** - * Specifies the convention used to specify four bounding box coordinates for an image of size (Height, Width). - * The (0,0) coordinate corresponds to the top-left corner of the image. - */ -message BoxCoordinatesMode { - - enum Coordinates { - - /** - * [h_start, w_start, h_end, w_end] - */ - CORNERS_HEIGHT_FIRST = 0; - - /** - * [w_start, h_start, w_end, h_end] - */ - CORNERS_WIDTH_FIRST = 1; - - /** - * [h_center, w_center, box_height, box_width] - */ - CENTER_SIZE_HEIGHT_FIRST = 2; - - /** - * [w_center, h_center, box_width, box_height] - */ - CENTER_SIZE_WIDTH_FIRST = 3; - - } - - Coordinates boxMode = 1; - -} - -/** - * Weights for layer parameters. 
- * Weights are stored as repeated floating point numbers - * using row-major ordering - * and can represent 1-, 2-, 3-, or 4-dimensional data. - */ -message WeightParams { - - /** - * Values specified in single / float / FP32 precision. - */ - repeated float floatValue = 1; - - /** - * Values in 16-bit half precision floating point. - */ - bytes float16Value = 2; - - /** - * Raw value specification for quantized lower precisions. - * - * This field is interpreted as uintN, where N is the number of bits in quantization. - * E.g. if n=8, the field is interpreted as an array of UINT8. - * Use this field for quantized parameters unless specifically noted to use - * int8RawValue. - */ - bytes rawValue = 30; - - /** - * Field to be used if int8DynamicQuantize is set in the parent layer. - * Cannot be set if rawValue is also set. - * The values in this field are interpreted as INT8. - * - * If this field is set, following conditions must hold true: - * * QuantizationType == LinearQuantizationParams, such that - * * size of the "scale" field is 1 and "bias" field is empty in "LinearQuantizationParams" - */ - bytes int8RawValue = 31; - - /** - * Quantization related parameters. - */ - QuantizationParams quantization = 40; - - bool isUpdatable = 50; - -} - -/** - * Quantization parameters. - */ -message QuantizationParams { - - uint64 numberOfBits = 1; - oneof QuantizationType { - LinearQuantizationParams linearQuantization = 101; - LookUpTableQuantizationParams lookupTableQuantization = 102; - } - -} - -message LinearQuantizationParams { - - /** - * Stores scale and bias values corresponding to the quantized weights. - * Must be an array of 1 element, or an array of C elements, where C - * is number of output channels. For recurrent layers it is equal to - * the output vector size. 
- * - * Relationship between quantized weights, unquantized weights, scale and bias: - * - * W_unquantized = W_quantized * scale + bias - * - */ - repeated float scale = 1; - repeated float bias = 2; - -} - -message LookUpTableQuantizationParams { - - /* Stores look-up table quantization values. Must be an array of - (2^numberOfBits) Elements. - */ - repeated float floatValue = 1; - -} - -/// Layers -/// ------ - -/** - * A layer that performs spatial convolution or deconvolution. - * - * .. code:: - * - * y = ConvolutionLayer(x) - * - * Requires 1 or 2 inputs and produces 1 output. - * - * Input - * First Input: - * A blob with rank greater than or equal to 4. - * Rank 4 blob represents [Batch, channels, height, width]. - * For ranks greater than 4, the leading dimensions, starting from 0 to -4 (inclusive), are all treated as batch. - * - * From Core ML specification version 4 onwards (iOS >= 13, macOS >= 10.15). - * convolution layer can have 2 inputs, in which case the second input is - * the blob representing the weights. This is allowed when "isDeconvolution" = False. - * The weight blob should have shape - * ``[outputChannels, kernelChannels, kernelHeight, kernelWidth]``, - * where kernelChannels == inputChannels / nGroups. - * - * Output - * Rank is same as the input. e.g.: for rank 4 input, output shape is [B, C_out, H_out, W_out] - * - * - * If ``dilationFactor`` is not 1, effective kernel size is - * modified as follows: - * - * .. code:: - * - * KernelSize[0] <-- (kernelSize[0]-1) * dilationFactor[0] + 1 - * KernelSize[1] <-- (kernelSize[1]-1) * dilationFactor[1] + 1 - * - * Type of padding can be ``valid`` or ``same``. Output spatial dimensions depend on the - * the type of padding. For details, refer to the descriptions of the messages "ValidPadding" - * and "SamePadding". Padded values are all zeros. - * - * For Deconvolution, ``ConvolutionPaddingType`` (``valid`` or ``same``) is ignored when ``outputShape`` is set. 
- * - * - */ -message ConvolutionLayerParams { - - /** - * The number of kernels. - * Same as ``C_out`` used in the layer description. - */ - uint64 outputChannels = 1; - - /** - * Channel dimension of the kernels. - * Must be equal to ``inputChannels / nGroups``, if isDeconvolution == False - * Must be equal to ``inputChannels``, if isDeconvolution == True - */ - uint64 kernelChannels = 2; - - /** - * Group convolution, i.e. weight reuse along channel axis. - * Input and kernels are divided into g groups - * and convolution / deconvolution is applied within the groups independently. - * If not set or 0, it is set to the default value 1. - */ - uint64 nGroups = 10; - - /** - * Must be length 2 in the order ``[H, W]``. - * If not set, default value ``[3, 3]`` is used. - */ - repeated uint64 kernelSize = 20; - - /** - * Must be length 2 in the order ``[H, W]``. - * If not set, default value ``[1, 1]`` is used. - */ - repeated uint64 stride = 30; - - /** - * Must be length 2 in order ``[H, W]``. - * If not set, default value ``[1, 1]`` is used. - * It is ignored if ``isDeconvolution == true``. - */ - repeated uint64 dilationFactor = 40; - - /** - * The type of padding. - */ - oneof ConvolutionPaddingType { - ValidPadding valid = 50; - SamePadding same = 51; - } - - /** - * Flag to specify whether it is a deconvolution layer. - */ - bool isDeconvolution = 60; - - /** - * Flag to specify whether a bias is to be added or not. - */ - bool hasBias = 70; - - /** - * Weights associated with this layer. - * If convolution (``isDeconvolution == false``), weights have the shape - * ``[outputChannels, kernelChannels, kernelHeight, kernelWidth]``, where kernelChannels == inputChannels / nGroups - * If deconvolution (``isDeconvolution == true``) weights have the shape - * ``[kernelChannels, outputChannels / nGroups, kernelHeight, kernelWidth]``, where kernelChannels == inputChannels - */ - WeightParams weights = 90; - WeightParams bias = 91; /// Must be of size [outputChannels]. 
- - /** - * The output shape, which has length 2 ``[H_out, W_out]``. - * This is used only for deconvolution (``isDeconvolution == true``). - * If not set, the deconvolution output shape is calculated - * based on ``ConvolutionPaddingType``. - */ - repeated uint64 outputShape = 100; - -} - -/** - * A layer that performs a 3-dimensional convolution. - * - * .. code:: - * - * y = Convolution3DLayer(x) - * - * Input - * A blob of rank 5. - * The input blob's shape should be ``[batch, channels, depth, height, width]``. - * - * Fields - * The bias field, if set, should have shape of ``[channelsOut]``. - * - * Output - * A blob of rank 5. - * The output blob's shape is ``[batch, channelsOut, depthOut, heightOut, widthOut]``. - * - * Type of padding can be ``custom``, ``valid``, or ``same``. Padded values are all zeros. - * Output spatial dimensions depend on the the type of padding. For details, refer to the - * descriptions of the ``PaddingType`` field of this ``Convolution3DLayerParams`` message. - * - * Example - * For example, given an input of size ``[1, 3, 3, 8, 8]``, a stride of 2 in each dimension, - * a kernel of 3 in each dimension, 2 output channels, and ``same`` padding, this layer will - * compute the total padding applied in the depth, height, and width dimensions to be 2, 1, and 1, - * respectively. The depth padding is even and will be applied equally to both sides of the depth - * dimension. Since the height and width padding values are odd, they'll be applied to the - * bottom/right of the height/width dimensions. Thus, the padding applied to the input will be - * ``[1, 1, 0, 1, 0, 1]`` (front, back, top, bottom, left, right). Finally, the output produced - * will have size ``[1, 2, 2, 4, 4]``. - * - */ -message Convolution3DLayerParams { - - /** - * The number of channels in the output (channelsOut). Must be a positive integer. - */ - int32 outputChannels = 1; - - /** - * The number of channels in the input (channels). Must be a positive integer. 
- */ - int32 inputChannels = 2; - - /** - * Group convolution, i.e., weight reuse along the channel axis. - * It must evenly divide both the number of input and output channels and be at most the number - * of input channels (a depthwise convolution). - * Input and kernels are divided into g groups and convolution is applied within the groups - * independently. - */ - int32 nGroups = 10; - - /* Depth of the convolution kernel. Must be a positive integer. - */ - int32 kernelDepth = 20; - - /* Height of the convolution kernel. Must be a positive integer. - */ - int32 kernelHeight = 21; - - /* Width of the convolution kernel. Must be a positive integer. - */ - int32 kernelWidth = 22; - - /* Stride along the depth direction. Must be a positive integer. - */ - int32 strideDepth = 31; - - /* Stride along the height direction. Must be a positive integer. - */ - int32 strideHeight = 32; - - /* Stride along the width direction. Must be a positive integer. - */ - int32 strideWidth = 33; - - /* Dilation along the depth direction. Must be a positive integer. - */ - int32 dilationDepth = 40; - - /* Dilation along the height direction. Must be a positive integer. - */ - int32 dilationHeight = 41; - - /* Dilation along the width direction. Must be a positive integer. - */ - int32 dilationWidth = 42; - - /** - * Flag to specify whether a bias is to be added or not. - * If false, then no bias is added. - */ - bool hasBias = 50; - - /** - * Weights associated with this layer. - * Weights have the shape - * if deconvolution == False - * ``[outputChannels, kernelChannels, kernelDepth, kernelHeight, kernelWidth]``, where - * kernelChannels == inputChannels / nGroups - * else if deconvolution == True - * ``[outputChannels / nGroups, kernelChannels, kernelDepth, kernelHeight, kernelWidth]``, where - */ - WeightParams weights = 60; - - /** - * Must be of size ``[outputChannels]``. - */ - WeightParams bias = 61; - - - /** - * The type of padding. 
- * All padding types pad the input shape with zeros. - * CUSTOM padding will add the custom padding values specified below to their respective - * dimensions, e.g., `customPaddingFront` number of zeros will be added to one side of the - * input's depth dimension and `customPaddingBack` number of zeros will be added to the other - * side of the input's depth dimension. - * VALID padding adds no padding to any dimension. In this case, the last convolution along - * each dimension will be dropped if the input dimension and the kernel size, stride, and - * dilation do not match. - * SAME padding adds enough padding to each dimension such that the output of the convolution - * has size ``Ceiling(inputShape / stride)``. Padding is added evenly to both sides of each - * dimension unless the total padding to add is odd, in which case it is added to the - * back/bottom/right side of the respective dimension. For example, if the total padding needed - * in the depth dimension is 3, 1 zero will be added to the front side of the depth dimension - * and 2 zeros will be added to the back side. - */ - enum PaddingType { - CUSTOM = 0; - VALID = 1; - SAME = 2; - } - PaddingType paddingType = 70; - - /* Padding before the input in the depth direction. Must be zero or a positive integer. - * Used when the `PaddingType` is `CustomPadding`, otherwise ignored by other padding types. - */ - int32 customPaddingFront = 80; - - /* Padding after the input in the depth direction. Must be zero or a positive integer. - * Used when the `PaddingType` is `CustomPadding`, otherwise ignored by other padding types. - */ - int32 customPaddingBack = 81; - - /* Padding before the input in the height direction. Must be zero or a positive integer. - * Used when the `PaddingType` is `CustomPadding`, otherwise ignored by other padding types. - */ - int32 customPaddingTop = 82; - - /* Padding after the input in the height direction. Must be zero or a positive integer. 
- * Used when the `PaddingType` is `CustomPadding`, otherwise ignored by other padding types. - */ - int32 customPaddingBottom = 83; - - /* Padding before the input in the width direction. Must be zero or a positive integer. - * Used when the `PaddingType` is `CustomPadding`, otherwise ignored by other padding types. - */ - int32 customPaddingLeft = 84; - - /* Padding after the input in the width direction. Must be zero or a positive integer. - * Used when the `PaddingType` is `CustomPadding`, otherwise ignored by other padding types. - */ - int32 customPaddingRight = 85; - - /* Flag to specify if this is Convolution Transpose or not. - */ - bool isDeconvolution = 86; - - /* - * The output shape, which has length 3 ``[D_out, H_out, W_out]``. - * This is used only for deconvolution (``isDeconvolution == true``). - * If not set, the deconvolution output shape is calculated - * based on ``PaddingType``. - */ - repeated uint64 outputShape = 87; - -} - -/** - * A layer that performs a matrix-vector or matrix-matrix product. - * This is equivalent to a fully-connected, or dense layer. - * The weight parameters correspond to a matrix of dimensions (inputChannels, outputChannels) i.e. (C_in, C_out) - * - * .. code:: - * - * y = InnerProductLayer(x) - * - * Requires 1 input and produces 1 output. - * - * Input - * Input can have rank 1 to rank 5. This is how it is reshaped in to the matrix (for rank > 1): - * rank 1 (x1) : in this case, the layer corresponds to a matrix-vector product. x1 must be equal to C_in - * rank 2 (x1, x2): x2 must be equal to C_in - * rank 3 (x1, x2, x3) --> (x1 * x2, x3). x3 must be equal to C_in - * rank 4 (x1, x2, x3, x4) ---> (x1, x2 * x3 * x4). x2 * x3 * x4 must be equal to C_in - * rank 5 (x1, x2, x3, x4, x5) ---> (x1 * x2, x3 * x4 * x5). 
x3 * x4 * x5 must be equal to C_in - * - * Output - * Output rank is same as the input rank - * rank 1: (C_out) - * rank 2: (x1, C_out) - * rank 3: (x1, x2, C_out) - * rank 4: (x1, C_out, 1, 1) - * rank 5: (x1, x2, C_out, 1, 1) - * - */ -message InnerProductLayerParams { - - uint64 inputChannels = 1; /// Input size: C_in. - uint64 outputChannels = 2; /// Output size: C_out. - - bool hasBias = 10; /// Whether a bias is added or not. - - WeightParams weights = 20; /// Weight matrix [C_out, C_in]. - WeightParams bias = 21; /// Bias vector [C_out]. - - /** - * If set, this layer, at runtime, quantizes the floating point input blob to int8 before applying an - * inner product using INT8 weight matrix parameters, as provided in weights->int8RawValue. The - * result is then dequantized. - * Requires: - * * hasBias == false - * * QuantizationType == LinearQuantizationParams, such that - * * size of the "scale" field is 1 and "bias" field is empty in "LinearQuantizationParams" - * * numberOfBits == 8 - * * weights->rawValue_size to be empty - */ - bool int8DynamicQuantize = 22; - -} - -/** - * A layer that performs a matrix lookup and optionally adds a bias. - * The weights matrix is stored with dimensions [outputChannels, inputDim]. - * - * .. code:: - * - * y = EmbeddingLayer(x) - * - * Requires 1 input and produces 1 output. - * - * Input - * Input values must be in the range ``[0, inputDim - 1]``. - * - * Input must have rank equal to 4 or 5, such that the last 3 dimensions are all 1. - * rank 4: shape (x1, 1, 1, 1). x1 is effectively the batch/sequence length. - * rank 5: shape (x1, x2 , 1, 1, 1). x1 * x2 is effectively the combined batch/sequence length. - * - * Output - * Output rank is same as the input rank. Please see input description above. - * rank 4: shape (x1, outputChannels, 1, 1) - * rank 5: shape (x1, x2, outputChannels, 1, 1) - * - */ -message EmbeddingLayerParams { - - uint64 inputDim = 1; /// Size of the input dictionary. 
- uint64 outputChannels = 2; /// Size of the output vectors. - - bool hasBias = 10; /// Whether a bias is added or not. - - WeightParams weights = 20; /// 2-D weights of dimensions [outputChannels, inputDim]. - WeightParams bias = 21; /// Bias of size [outputChannels]. - -} - -/** - * A layer that performs a matrix lookup and optionally adds a bias. - * The weights matrix is stored with dimensions [embeddingSize, vocabSize]. - * - * .. code:: - * - * y = EmbeddingNDLayer(x) - * - * Requires 1 input and produces 1 output. - * - * Input - * Input values must be in the range ``[0, vocabSize - 1]``. - * Input must have rank at least 2. The last dimension must always be 1. - * rank 2: shape (x1, 1). x1 is the batch/sequence length. - * rank 3: shape (x1, x2, 1). x1 * x2 is effectively the combined batch/sequence length. - * rank 4: shape (x1, x2, x3, 1). x1 * x2 * x2 is effectively the combined batch/sequence length. - * rank 5: shape (x1, x2 , x3, x4, 1). x1 * x2 * x3 * x4 is effectively the combined batch/sequence length. - * - * Output - * Output rank is same as the input rank. Please see input description above. - * rank 2: shape (x1, embeddingSize) - * rank 3: shape (x1, x2, embeddingSize) - * rank 4: shape (x1, x2, x3, embeddingSize) - * rank 5: shape (x1, x2, x3, x4, embeddingSize) - * - */ -message EmbeddingNDLayerParams { - - uint64 vocabSize = 1; /// Size of the input dictionary. - uint64 embeddingSize = 2; /// Size of the output vectors. - bool hasBias = 3; /// Whether a bias is added or not. - WeightParams weights = 20; /// 2-D weights of dimensions [embeddingSize, vocabSize]. - WeightParams bias = 21; /// Bias of size [embeddingSize]. - -} - -/** - * A layer that performs batch normalization, - * which is performed along axis = -3, - * and repeated along the other axes, if present. - * - * .. code:: - * - * y = BatchnormLayer(x) - * - * Requires 1 input and produces 1 output. - * - * This operation is described by the following formula: - * - * .. 
math:: - * y_i = \gamma_i \dfrac{ (x_i - \mu_i)}{\sqrt{\sigma_i^2 + \epsilon}} + \beta_i \;,\;i=1,....,C - * - * Input - * A blob with rank greater than equal to 3. - * Example: Rank 4 blob represents [Batch, channels, height, width] - * For ranks greater than 3, the leading dimensions, starting from 0 to -4 (inclusive), are all treated as batch. - * - * Output - * A blob with the same shape as the input. - */ -message BatchnormLayerParams { - - uint64 channels = 1; /// Size of the channel dimension in the input. - - /** - * If ``computeMeanVar == true``, - * the mean and variance are calculated from either - * the single input instance, if ``instanceNormalization == true``, - * or the whole batch, if ``instanceNormalization = false``. - * and the values provided in parameters "mean" and "variance" are ignored. - */ - bool computeMeanVar = 5; - bool instanceNormalization = 6; - - /** - * A small constant to avoid division by 0 while normalizing by variance. - * Defaults to ``1e-5`` if not set or set to ``0``. - */ - float epsilon = 10; - - WeightParams gamma = 15; /// Parameter of length [channels] - WeightParams beta = 16; /// Parameter of length [channels] - WeightParams mean = 17; /// Parameter of length [channels] - WeightParams variance = 18; /// Parameter of length [channels] - -} - -/** - * A spatial pooling layer. - * - * .. code:: - * - * y = PoolingLayer(x) - * - * Requires 1 input and produces 1 output. - * - * Input - * A blob with rank greater than equal to 4. - * Rank 4 blob represents [Batch, channels, height, width] - * For ranks greater than 4, the leading dimensions, starting from 0 to -4 (inclusive), are all treated as batch. - * - * Output - * Rank is same as the input. 
e.g.: for rank 4 input, output shape is [B, C, H_out, W_out] - * - * Padding options are similar to ``ConvolutionLayerParams`` - * with the additional option of ``ValidCompletePadding`` (``includeLastPixel``), - * which ensures that the last application of the kernel - * always includes the last pixel of the input image, if there is padding. - * - * .. code:: - * - * H_out = ceil(float(H_in + 2 * paddingAmounts[0] - kernelSize[0])/float(Stride[0])) + 1 - * if (paddingAmounts[0] > 0 or paddingAmounts[1] > 0) - * if ((H_out - 1) * Stride >= H_in + paddingAmounts[0]) { - * H_out = H_out - 1 - * } - * } - * - * The equivalent expressions hold true for ``W_out`` as well. - * Only symmetric padding is supported with this option. - */ -message PoolingLayerParams { - - enum PoolingType { - - MAX = 0; - AVERAGE = 1; - L2 = 2; - - } - PoolingType type = 1; /// Type of pooling operation. - - /** - * Must be length 2 in the order ``[H, W]``. - * If not set, default value ``[3, 3]`` is used. - */ - repeated uint64 kernelSize = 10; - - /** - * Must be length 2 in the order ``[H, W]``. - * If not set, default value ``[1, 1]`` is used. - */ - repeated uint64 stride = 20; - - message ValidCompletePadding { - - /** - * Must be length 2 in order ``[H, W]``. - * If not set, value ``[0, 0]`` is used. - */ - repeated uint64 paddingAmounts = 10; - - } - - oneof PoolingPaddingType { - ValidPadding valid = 30; - SamePadding same = 31; - ValidCompletePadding includeLastPixel = 32; - } - - /** - * If true, padded values are excluded from the count (denominator) - * when computing average pooling. - */ - bool avgPoolExcludePadding = 50; - - /** - * If true, global pooling is performed. - * Kernel size is inferred from the input data spatial dimensions. - */ - bool globalPooling = 60; - -} - -/* - * A layer to pool three spatial dimensions - * - * Input - * A blob with rank equal to 5, representing [Batch, channels, depth, height, width]. 
- * - * Output - * Rank is same as the input: A blob with rank equal to 5, representing [Batch, channels, depth, height, width]. - * - * Requires 1 input and produces 1 output. - * - * For example, given an input of shape (1,1,2,3,3): - * +----+----+----+ - * / | 10 | 11 | 12 | - * / +----+----+----+ - * / | 13 | 14 | 15 | - * / +----+----+----+ - * / | 16 | 17 | 18 | - * / +----+----+----+ - * +----+----+----+ / - * | 1 | 2 | 3 | / - * +----+----+----+ / - * | 4 | 5 | 6 | / - * +----+----+----+ / - * | 7 | 8 | 9 | / - * +----+----+----+ - * - * And applying MAX pooling using: - * Kernel: 2x2x2 - * Stride: 1x1x1 - * Valid Padding - * We expect to get an output with shape: (1,1,1,2,2) and value: - * +----+----+ - * | 14 | 15 | - * +----+----+ - * | 17 | 18 | - * +----+----+ - */ -message Pooling3DLayerParams { - - enum PoolingType3D { - MAX = 0; - AVERAGE = 1; - } - - // Whether to use Max or Average - PoolingType3D type = 1; - - // Depth of the pooling region. - int32 kernelDepth = 2; - - // Height of the pooling region. - int32 kernelHeight = 3; - - // Width of the pooling region. - int32 kernelWidth = 4; - - // Stride along the depth direction - int32 strideDepth = 5; - - // Stride along the height direction - int32 strideHeight = 6; - - // Stride along the width direction - int32 strideWidth = 7; - - /** - * The type of padding. - * All padding types pad the input shape with zeros. - * CUSTOM padding will add the custom padding values specified below to their respective - * dimensions, e.g., `customPaddingFront` number of zeros will be added to one side of the - * input's depth dimension and `customPaddingBack` number of zeros will be added to the other - * side of the input's depth dimension. - * VALID padding adds no padding to any dimension. In this case, the last pool along - * each dimension will be dropped if the input dimension and the kernel size, and stride do not match. 
- * SAME padding adds enough padding to each dimension such that the output - * has the same spatial dimensions as the input. Padding is added evenly to both - * sides of each dimension unless the total padding to add is odd, in which case the extra padding - * is added to the back/bottom/right side of the respective dimension. For example, if the the - * total horizontal padding is 3, then there will be 1 padding on the left, and 2 padding on the right. - */ - enum Pooling3DPaddingType { - CUSTOM = 0; - VALID = 1; - SAME = 2; - } - Pooling3DPaddingType paddingType = 15; - - // Padding before the input in the depth direction. - int32 customPaddingFront = 8; - - // Padding after the input in the depth direction. - int32 customPaddingBack = 9; - - // Padding before the input in the height direction. - int32 customPaddingTop = 10; - - // Padding after the input in the height direction. - int32 customPaddingBottom = 11; - - // Padding before the input in the width direction. - int32 customPaddingLeft = 12; - - // Padding after the input in the width direction. - int32 customPaddingRight = 13; - - // If true, exclude zeros from padding in Average pooling. Meaningless in Max Pooling. - bool countExcludePadding = 14; -} - -/* - * A layer to pool three spatial dimensions down to one value. - * This behaves like a special case of Pooling3DLayerParams in which - * the Kernel is the size of the input and there is no padding. - * - * Input - * A blob with rank equal to 5, representing [Batch, channels, depth, height, width]. - * - * Output - * Rank is same as the input: A blob with rank equal to 5, representing [Batch, channels, depth, height, width]. - * Depth, height, and width of the output will always be 1. - * - * Requires 1 input and produces 1 output. 
- * - * For example, given an input of shape (1,1,2,3,3): - * +----+----+----+ - * / | 10 | 11 | 12 | - * / +----+----+----+ - * / | 13 | 14 | 15 | - * / +----+----+----+ - * / | 16 | 17 | 18 | - * / +----+----+----+ - * +----+----+----+ / - * | 1 | 2 | 3 | / - * +----+----+----+ / - * | 4 | 5 | 6 | / - * +----+----+----+ / - * | 7 | 8 | 9 | / - * +----+----+----+ - * - * And applying MAX global 3d pooling, we expect to get an output with shape: (1,1,1,1,1) and value: - * +----+ - * | 18 | - * +----+ - */ -message GlobalPooling3DLayerParams { - - enum GlobalPoolingType3D { - MAX = 0; - AVERAGE = 1; - } - - // Whether to use Max or Average - GlobalPoolingType3D type = 1; -} - -/** - * A layer that performs padding along spatial dimensions. - * - * .. code:: - * - * y = PaddingLayer(x) - * - * Requires 1 input and produces 1 output. - * - * Input - * A blob with rank at least 2. - * e.g.: blob with shape ``[H_in, W_in]``. - * For ranks greater than 2, the leading dimensions, starting from 0 to -4 (inclusive), are all treated as batch - * i.e. Padding is applied on last two dimensions. - * - * Output - * Same rank as the input. - * e.g.: blob with shape ``[H_out, W_out]``. - * - * Output dimensions are calculated as follows: - * - * .. code:: - * - * H_out = H_in + topPaddingAmount + bottomPaddingAmount - * W_out = W_in + leftPaddingAmount + rightPaddingAmount - * - * topPaddingAmount == Height startEdgeSize == borderAmounts[0].startEdgeSize - * bottomPaddingAmount == Height endEdgeSize == borderAmounts[0].endEdgeSize - * leftPaddingAmount == Width startEdgeSize == borderAmounts[1].startEdgeSize - * rightPaddingAmount == Width endEdgeSize == borderAmounts[1].endEdgeSize - * - * There are three types of padding: - * - * - ``PaddingConstant``, which fills a constant value at the border. - * - ``PaddingReflection``, which reflects the values at the border. - * - ``PaddingReplication``, which replicates the values at the border. 
- * - * Given the following input: - * - * .. code:: - * - * [1, 3, 4] : 1 2 3 4 - * 5 6 7 8 - * 9 10 11 12 - * - * Here is the output of applying the padding - * ``(top=2, left=2, bottom=0, right=0)`` - * with each of the supported types: - * - * - ``PaddingConstant`` (``value = 0``): - * .. code:: - * - * [1, 5, 6] : 0 0 0 0 0 0 - * 0 0 0 0 0 0 - * 0 0 1 2 3 4 - * 0 0 5 6 7 8 - * 0 0 9 10 11 12 - * - * - ``PaddingReflection``: - * .. code:: - * - * [1, 5, 6] : 11 10 9 10 11 12 - * 7 6 5 6 7 8 - * 3 2 1 2 3 4 - * 7 6 5 6 7 8 - * 11 10 9 10 11 12 - * - * - ``PaddingReplication``: - * .. code:: - * - * [1, 5, 6] : 1 1 1 2 3 4 - * 1 1 1 2 3 4 - * 1 1 1 2 3 4 - * 5 5 5 6 7 8 - * 9 9 9 10 11 12 - */ -message PaddingLayerParams { - - /** - * Fill a constant value in the padded region. - */ - message PaddingConstant { - float value = 1; - } - - /** - * Reflect the values at the border for padding. - */ - message PaddingReflection { - } - - /** - * Replicate the values at the border for padding. - */ - message PaddingReplication { - } - - oneof PaddingType { - PaddingConstant constant = 1; - PaddingReflection reflection = 2; - PaddingReplication replication = 3; - } - - BorderAmounts paddingAmounts = 10; /// Amounts to be padded to the input. - -} - -/** - * A layer that concatenates along the axis = -3 or -5. - * For general concatenation along any axis, see ConcatNDLayer. - * - * .. code:: - * - * y = ConcatLayer(x1,x2,....) - * - * Requires more than 1 input and produces 1 output. - * - * Input - * All input blobs must have same rank. - * If "sequenceConcat" = False, rank must be greater than equal to 3. In this case concatenation is along axis = -3 - * If "sequenceConcat" = True, rank must be greater than equal to 5. In this case concatenation is along axis = -5 - * - * Output - * Same rank as the input. - * - */ -message ConcatLayerParams { - - /** - * If true, concatenate along the axis = -5 instead of axis = -3. 
- */ - bool sequenceConcat = 100; - -} - -/** - * A layer that performs local response normalization (LRN). - * - * .. code:: - * - * y = LRNLayer(x) - * - * Requires 1 input and produces 1 output. - * - * Input - * A blob with rank greater than equal to 3. - * Example: Rank 4 blob represents [Batch, channels, height, width] - * For ranks greater than 3, the leading dimensions, starting from 0 to -4 (inclusive), are all treated as batch. - * Output - * A blob with the same shape as the input. - * - * This layer is described by the following formula: - * - * .. math:: - * x_i \leftarrow \dfrac{x_i}{\left ( k + \dfrac{\alpha}{C} \sum_j x_j^2 \right )^\beta} - * - * where the summation is done over a ``(localSize, 1, 1)`` neighborhood --- - * that is, over a window "across" channels in 1x1 spatial neighborhoods. - */ -message LRNLayerParams { - - float alpha = 1; - float beta = 2; - uint64 localSize = 3; /// Number of channels in the normalization window. - float k = 4; /// Defaults to 1 if not set or 0. Must be strictly positive. - -} - -/** - * Softmax Normalization Layer - * - * A layer that performs softmax normalization. - * Normalization is applied along axis = -3 or N-3 (where N is the rank of the input) - * For softmax layer that can operate on any axis, see SoftmaxNDLayer. - * - * - * .. code:: - * - * y = SoftmaxLayer(x) - * - * Requires 1 input and produces 1 output. - * - * Input - * Must be a blob with rank >= 3. - * Output - * A blob with the same shape as the input. - * - * This layer is described by the following formula: - * - * .. math:: - * x_i \leftarrow \dfrac{e^{x_i}}{\sum_i{e^{x_i}}} - */ -message SoftmaxLayerParams { - -} - -/** - * A layer that uniformly splits across axis = -3 to produce a specified number of outputs. - * For general split operation along any axis, see SplitNDLayer. - * - * .. code:: - * - * (y1,y2,...yN) = SplitLayer(x), where N = nOutputs - * - * Requires 1 input and produces multiple outputs. 
- * - * Input - * A blob with rank at least 3. - * e.g.: blob with shape ``[C, H, W]`` - * Output - * ``nOutputs`` blobs each with same rank as the input. - * e.g.: For input that is of shape ``[C, H, W]``, output shapes will be ``[C/nOutputs, H, W]`` - */ -message SplitLayerParams { - - uint64 nOutputs = 1; /// The number of outputs. - -} - -/** - * A layer that performs elementwise addition. - * This layer has limited broadcasting support. For general broadcasting see AddBroadcastableLayer. - * - * .. code:: - * - * y = AddLayer(x1,x2,...) - * - * Requires 1 or more than 1 input and produces 1 output. - * - * Input - * In general, there are no rank constraints. - * However, only certain set of shapes are broadcastable. For example: - * [B, 1, 1, 1], [B, C, 1, 1], [B, 1, H, W], [B, C, H, W] - * Output - * A blob with shape equal to the input blob. - * - * If only one input is provided, scalar addition is performed: - * - * .. math:: - * y = x + \alpha - * - */ -message AddLayerParams { - - /** - * Scalar to be added to the input. - * Only used if there is a single input. - */ - float alpha = 1; - -} - -/** - * A layer that performs elementwise multiplication. - * This layer has limited broadcasting support. For general broadcasting see MultiplyBroadcastableLayer. - * - * .. code:: - * - * y = MultiplyLayer(x1,x2,...) - * - * Requires 1 or more than 1 input and produces 1 output. - * - * Input - * In general, there are no rank constraints. - * However, only certain set of shapes are broadcastable. For example: - * [B, 1, 1, 1], [B, C, 1, 1], [B, 1, H, W], [B, C, H, W] - * Output - * A blob with shape equal to the first input blob. - * - * If only one input is provided, scalar multiplication is performed: - * - * .. math:: - * y = \alpha x - * - */ -message MultiplyLayerParams { - - /** - * Scalar to be multiplied with the input. - * Only used if there is a single input. - */ - float alpha = 1; - -} - -/** - * A layer that applies a unary function. - * - * .. 
code:: - * - * y = UnaryFunctionLayer(x) - * - * Requires 1 input and produces 1 output. - * - * Input - * A blob with no rank constraints. - * Output - * A blob with the same shape as the input. - * - * The input is first modified by shifting and scaling: - * - * .. math:: - * x \leftarrow \text{scale} \cdot x + \text{shift} - */ -message UnaryFunctionLayerParams { - - /** - * A unary operator. - * - * The following functions are supported: - * - * ``SQRT`` - * .. math:: f(x) = \sqrt{x} - * - * ``RSQRT`` - * .. math:: f(x) = \dfrac{1}{\sqrt{x + \epsilon}} - * - * ``INVERSE`` - * .. math:: f(x) = \dfrac{1}{x + \epsilon} - * - * ``POWER`` - * .. math:: f(x) = x^\alpha - * - * ``EXP`` - * .. math:: f(x) = e^x - * - * ``LOG`` - * .. math:: f(x) = \log x - * - * ``ABS`` - * .. math:: f(x) = |x| - * - * ``THRESHOLD`` - * .. math:: f(x) = \text{max}(\alpha, x) - */ - enum Operation { - SQRT = 0; - RSQRT = 1; - INVERSE = 2; - POWER = 3; - EXP = 4; - LOG = 5; - ABS = 6; - THRESHOLD = 7; - } - Operation type = 1; /// The type of unary function. - - /** - * A constant used in ``POWER`` and ``THRESHOLD`` functions. - */ - float alpha = 2; - - /** - * A small constant to avoid division by 0 while normalizing variance. - * Defaults to ``1e-6`` if not set or set to ``0``. - */ - float epsilon = 3; - - /** - * Input is shifted by this amount - * before the unary function is applied. - * Defaults to ``0.0`` if not set. - */ - float shift = 4; - - /** - * Input is scaled by this amount - * before the unary function is applied. - * Defaults to ``1.0`` if not set or set to ``0``. - */ - float scale = 5; - -} - -/** - * A layer that scales up spatial dimensions. - * It supports two modes: nearest neighbour (default) and bilinear. - * - * .. code:: - * - * y = UpsampleLayer(x) - * - * Requires 1 input and produces 1 output. - * - * Input - * A blob with rank at least 3. - * e.g.: blob with shape ``[C, H, W]``. 
- * For ranks greater than 3, the leading dimensions, starting from 0 to -4 (inclusive), are all treated as batch. - * - * Output - * Same rank as the input. - * e.g.: blob with shape ``[C, scalingFactor[0] * H, scalingFactor[1] * W]`` - */ -message UpsampleLayerParams { - - /** - * Scaling Factor. Mutually exclusive with fractionalScalingFactor. - * Must be length 2 in order ``[H, W]``. - * If not set, default value ``[1, 1]`` is used. - */ - repeated uint64 scalingFactor = 1; - - /** - * Fractional scaling factor. Mutually exclusive with scalingFactor. - * Must be length 2 in order ``[H, W]``. - * If not set, default value ``[1.0, 1.0]`` is used. - */ - repeated float fractionalScalingFactor = 7; - - /* - * Overall mode for interpolating new elements when upsampling. - * NN - Nearest Neighbors - simply pick the nearest true value for interpolated values. - * BILINEAR - Use bilinear interpolation. See LinearUpsamplingMode for behavior. - */ - enum InterpolationMode { - - NN = 0; /// Nearest Neighbour - BILINEAR = 1; /// Bilinear - - } - - InterpolationMode mode = 5; - - /** - * LinearUpsampleMode specifies the behavior for linear upsampling. Only valid when Interpolation Mode is BILINEAR. 
- * If input grid is [0, Xin-1] (corresponding to an input size of Xin), and if the output size is Xout, - * then the grid points are sampled in the following manner: - * DEFAULT: - * spacing = (Xin-Xin/Xout) / (Xout-1) - * grid_point[i] = min(Xin-1, max(0, i * spacing)), for i = 0,1,2,….,Xout-1 - * ALIGN_CORNERS_TRUE: - * spacing = (Xin-1) / (Xout-1) - * grid_point[i] = min(Xin-1, max(0, i * spacing)), for i = 0,1,2,….,Xout-1 - * ALIGN_CORNERS_FALSE: - * spacing = Xin / Xout - * grid_point[i] = min(Xin-1, max(0, i * spacing + 0.5 * spacing - 0.5)), for i = 0,1,2,….,Xout-1 - */ - enum LinearUpsampleMode { - - DEFAULT = 0; - ALIGN_CORNERS_TRUE = 1; - ALIGN_CORNERS_FALSE = 2; - - } - - LinearUpsampleMode linearUpsampleMode = 6; - -} - -/** -* A layer that resizes the input to a pre-specified spatial size using bilinear interpolation. -* -* .. code:: -* -* y = ResizeBilinearLayer(x) -* -* Requires 1 input and produces 1 output. -* -* Input -* A blob with rank at least 3. -* e.g.: blob with shape ``[C, H_in, W_in]``. -* For ranks greater than 3, the leading dimensions, starting from 0 to -4 (inclusive), are all treated as batch. -* -* Output -* Same rank as the input. -* e.g.: blob with shape ``[C, H_out, W_out]``. -* -*/ -message ResizeBilinearLayerParams { - - /** - * Target Spatial Size. - * Must be length 2 in order ``[Height, Width]``, i.e. ``[H_out, W_out]``. - * If not set, default value ``[1, 1]`` is used. - */ - repeated uint64 targetSize = 1; - - /** - * Mode used to compute the grid on which the spatial output values are evaluated. - * Same mode is applied to both the height and width axes. - */ - SamplingMode mode = 2; - -} - -/** -* A layer that extracts cropped spatial patches or RoIs (regions of interest) from the input and resizes them to a pre-specified size using -* bilinear interpolation. -* Note that RoI Align layer can be implemented with this layer followed by a pooling layer. -* -* .. 
code:: -* -* y = CropResizeLayer(x) -* -* Requires 2 inputs and produces 1 output. -* -* Input -* There are two inputs. -* First input represents an image feature map. -* Second input represents the bounding box coordinates for N patches or RoIs (region of interest). -* -* First input is rank 5: [1, Batch, C, H_in, W_in]. -* Second input is rank 5. Its shape can be either [N, 1, 4, 1, 1] or [N, 1, 5, 1, 1]. -* -* N: number of patches/RoIs to be extracted -* -* If RoI shape = ``[N, 1, 4, 1, 1]`` -* The axis=-3 corresponds to the four coordinates specifying the bounding box. -* All the N RoIs are extracted from all the batches of the input. -* -* If RoI shape = ``[N, 1, 5, 1, 1]`` -* The first element of the axis=-3 specifies the input batch id from which to extract the RoI and -* must be in the interval ``[0, Batch - 1]``. That is, n-th RoI is extracted from the RoI[n,0,0,0,0]-th -* input batch id. The last four elements of the axis=-3 specify the bounding box coordinates. -* -* Output -* A blob with rank 5. -* - Shape is [N, Batch, C, H_out, W_out] if input RoI shape is [N, 1, 4, 1, 1] -* - Shape is [N, 1, C, H_out, W_out] if input RoI shape is [N, 1, 5, 1, 1] -* -*/ -message CropResizeLayerParams { - - /** - * Target Spatial Size. - * Must be length 2 in order ``[Height, Width]``, i.e. ``[H_out, W_out]``. - * If not set, default value ``[1, 1]`` is used. - */ - repeated uint64 targetSize = 1; - - /** - * If true the bounding box coordinates must be in the interval [0, 1]. - * They are scaled by (H_in - 1), (W_in - 1), i.e. based on the input spatial dimensions. - * If false the bounding box coordinates must be in the interval - * [0, H_in -1] and [0, W_in - 1], respectively for height and width dimensions. - */ - bool normalizedCoordinates = 2; - - /** - * Mode used to compute the grid on which the spatial output values are evaluated. - * Same mode is applied to both the height and width axes. 
- */ - SamplingMode mode = 3; - - /** - * Representation used to express the bounding box coordinates. - * It determines how the values of the second input are interpreted. - */ - BoxCoordinatesMode boxIndicesMode = 4; - - /** - * Additional spatial scale that multiplies the bounding box coordinates. - * Generally used while implementing the RoI Align layer, - * which uses unnormalized RoI coordinates along with a spatial scale less than or equal to 1. - */ - float spatialScale = 5; - -} - -/** - * A layer that performs elementwise addition of a bias, - * which is broadcasted to match the input shape. - * - * .. code:: - * - * y = BiasLayer(x) - * - * Requires 1 input and produces 1 output. - * - * Input - * A blob with rank at least 3. - * e.g.: blob with shape ``[C, H, W]``. - * For ranks greater than 3, the leading dimensions, starting from 0 to -4 (inclusive), are all treated as batch. - * Output - * A blob with the same shape as the input. - */ -message BiasLayerParams { - - /** - * The shape of the bias. - * Must be one of the following: - * ``[1]``, ``[C]``, ``[1, H, W]`` or ``[C, H, W]``. - */ - repeated uint64 shape = 1; - - /** - * The bias values. - * The size must be equal to the product of the ``shape`` dimensions. - */ - WeightParams bias = 2; - -} - -/** - * A layer that performs elmentwise multiplication by a scale factor - * and optionally adds a bias; - * both the scale and bias are broadcasted to match the input shape. - * - * .. code:: - * - * y = ScaleLayer(x) - * - * Requires 1 input and produces 1 output. - * - * Input - * A blob with rank at least 3. - * e.g.: blob with shape ``[C, H, W]``. - * For ranks greater than 3, the leading dimensions, starting from 0 to -4 (inclusive), are all treated as batch. - * Output - * A blob with the same shape as the input. - */ -message ScaleLayerParams { - - /** - * The shape of the scale. - * Must be one of the following: - * ``[1]``, ``[C]``, ``[1, H, W]`` or ``[C, H, W]``. 
- */ - repeated uint64 shapeScale = 1; - - /** - * The scale values. - * The size must be equal to the product of the ``shape`` dimensions. - */ - WeightParams scale = 2; /// Scale values. Size must be equal to the product of dimensions specified in shapeScale. - - bool hasBias = 3; /// If true, a bias is added after scaling. - - /** - * The shape of the bias. - * Must be one of the following: - * ``[1]``, ``[C]``, ``[1, H, W]`` or ``[C, H, W]``. - */ - repeated uint64 shapeBias = 4; - - /** - * The bias values. - * The size must be equal to the product of the ``shape`` dimensions. - */ - WeightParams bias = 5; - -} - -/** - * A layer that loads data as a parameter and provides it as an output. - * The output is rank 5. For general rank, see LoadConstantNDLayer. - * - * .. code:: - * - * y = LoadConstantLayer() - * - * Requires no input and produces 1 output. - * - * Output: - * A blob with rank 5 and shape ``[1, 1, C, H, W]`` - */ -message LoadConstantLayerParams { - - /** - * The shape of the constant to be loaded, - * which must be``[C, H, W]``, that is length 3. - */ - repeated uint64 shape = 1; - - /** - * The data values, - * of size ``C * H * W``. - */ - WeightParams data = 2; - -} - -/** - * A layer that performs L2 normalization, i.e. divides by the - * the square root of the sum of squares of all elements of input. - * - * .. code:: - * - * y = L2NormalizeLayer(x) - * - * Requires 1 input and produces 1 output. - * - * Input - * A blob with rank greater than equal to 3. - * For ranks greater than 3, the leading dimensions, starting from 0 to -4 (inclusive), are all treated as batch. - * Output - * A blob with the same shape as the input. - * - * This layer is described by the following formula: - * - * .. math:: - * x_i \leftarrow \dfrac{x_i}{\sqrt{\sum{x_i^2} + \epsilon}} - */ -message L2NormalizeLayerParams { - - /** - * A small constant to avoid division by 0 while normalizing variance. - * Defaults to ``1e-6`` if not set or set to ``0``. 
- */ - float epsilon = 1; - -} - -/// Data Reorganization Layers -/// -------------------------- - -/** - * A layer that flattens the input. - * - * .. code:: - * - * y = FlattenLayer(x) - * - * Requires 1 input and produces 1 output. - * - * Input - * A blob with rank greater than equal to 3. - * e.g.: Rank 4 blob represents [Batch, C, H, W] - * For ranks greater than 3, the leading dimensions, starting from 0 to -4 (inclusive), are all treated as batch. - * Output - * Same rank as the input, such that last two dimensions are both 1. - * e.g.: For rank 4 input, output shape is ``[Batch, C * H * W, 1, 1]`` - * - * There are two X orders: ``CHANNEL_FIRST`` and ``CHANNEL_LAST``. - * ``CHANNEL_FIRST`` does not require data to be rearranged, - * because row major ordering is used by internal storage. - * ``CHANNEL_LAST`` requires data to be rearranged. - */ -message FlattenLayerParams { - - enum FlattenOrder { - - CHANNEL_FIRST = 0; - CHANNEL_LAST = 1; - - } - FlattenOrder mode = 1; - -} - -/** - * A layer that recasts the input into a new shape. - * - * .. code:: - * - * y = ReshapeLayer(x) - * - * Requires 1 input and produces 1 output. - * - * Input - * A blob with rank 5. - * e.g.: ``[1, 1, C, H, W]`` or ``[Seq, 1, C, H, W]``. - * Output - * A blob with rank 5. - * e.g.: ``[1, 1, C_out, H_out, W_out]`` or ``[Seq_out, 1, C_out, H_out, W_out]``. - * - * There are two reshape orders: ``CHANNEL_FIRST`` and ``CHANNEL_LAST``. - * ``CHANNEL_FIRST`` is equivalent to - * flattening the input to ``[Seq, 1, C * H * W, 1, 1]`` in channel first order - * and then reshaping it to the target shape; - * no data rearrangement is required. - * ``CHANNEL_LAST`` is equivalent to - * flattening the input to ``[Seq, 1, H * W * C, 1, 1]`` in channel last order, - * reshaping it to ``[Seq_out, 1, H_out, W_out, C_out]`` (it is now in "H_out-major"" order), - * and then permuting it to ``[C_out, H_out, W_out]``; - * both the flattening and permuting requires the data to be rearranged. 
- */ -message ReshapeLayerParams { - - /** - * The shape of the output. - * Must be of length 3 or 4. - * If set to 3, ``targetShape`` is interpreted as - * ``[1, 1, C_out, H_out, W_out]``, and sequence length of the input is preserved. - * If set to 4, ``targetShape`` is interpreted as - * ``[Seq_out, 1, C_out, H_out, W_out]``, - * where ``Seq_out`` is the new sequence length. - */ - repeated int64 targetShape = 1; - - enum ReshapeOrder { - - CHANNEL_FIRST = 0; - CHANNEL_LAST = 1; - - } - ReshapeOrder mode = 2; - -} - -/** - * A layer that rearranges the dimensions and data of an input. - * For generic transpose/permute operation see TransposeLayer. - * - * .. code:: - * - * y = PermuteLayer(x) - * - * Requires 1 input and produces 1 output. - * - * Input - * Must be a rank 5 blob. - * e.g.: shape ``[Seq, B, C, H, W]``. - * Output - * Rank 5 blob. Transposed version of the input, such that dimensions at axis=1 or axis=-4 is unchanged. - * - * - * Examples: - * - * Assume input shape is [Seq, B, C, H, W] - * - * - If ``axis`` is set to ``[0, 3, 1, 2]``, - * then the output has shape ``[Seq, B, W, C, H]`` - * - * - If ``axis`` is set to ``[3, 1, 2, 0]``, - * then the output has shape ``[W, B, C, H, Seq]`` - * - * - If ``axis`` is set to ``[0, 3, 2, 1]``, - * then the output has shape ``[Seq, B, W, H, C]`` - * - * - If ``axis`` is not set, or is set to ``[0, 1, 2, 3]``, - * the output is the same as the input. - */ -message PermuteLayerParams { - - /** - * The order in which to permute the dimensions. - * Must have length 4 and a permutation of ``[0, 1, 2, 3]``. - */ - repeated uint64 axis = 1; - -} - -/** - * A layer that reorganizes data in the input in specific ways. - * - * .. code:: - * - * y = ReorganizeDataLayer(x) - * - * Requires 1 input and produces 1 output. - * - * Input - * A blob with rank at least 3. - * e.g.: blob with shape ``[C, H, W]``. 
- * For ranks greater than 3, the leading dimensions, starting from 0 to -4 (inclusive), are all treated as batch. - * Output - * Same rank as the input. - * e.g.: blob with shape ``[C_out, H_out, W_out]``. - * - * mode == SPACE_TO_DEPTH - * ``[C_out, H_out, W_out]`` : ``[C * blockSize * blockSize, H/blockSize, W/blockSize]``. - * blockSize must divide H and W. - * Data is moved from the spatial dimensions to the channel dimension. Input is spatially divided into - * non-overlapping blocks of size blockSize X blockSize and data from each block is moved into the - * channel dimension. - * - * mode == DEPTH_TO_SPACE - * ``[C_out, H_out, W_out]`` : ``[C/(blockSize * blockSize), H * blockSize, W * blockSize]``. - * Square of blockSize must divide C. - * Reverse of SPACE_TO_DEPTH. Data is moved from the channel dimension to the spatial dimensions. - * - * mode == PIXEL_SHUFFLE - * ``[C_out, H_out, W_out]`` : ``[C/(blockSize * blockSize), H * blockSize, W * blockSize]``. - * Square of blockSize must divide C. - * Similar to DEPTH_TO_SPACE, but using the pixel-shuffle semantics for channel order in the output space. - * In both modes, elements along the channel dimension are collapsed into - * blocks in the spatial dimensions. The difference is in the arrangement of - * the input-channels' data in the output space. See below example for more - * detail. - * (Only available in Core ML Specification >= 5 (iOS >= 14, macOS >= 11.0) - * - * - * Examples: - * - * Assume input is the following [C = 8, H = 1, W = 2] tensor: - * - * .. code:: - * - * [[[1 2]] [[3 4]] [[5 6]] [[7 8]] [[9 10]] [[11 12]] [[13 14]] [[15 16]]] - * - * If block_size == 2 and mode == DEPTH_TO_SPACE, output will be the following - * [C = 2, H = 2, W = 4] tensor: - * - * .. code:: - * - * [[[ 1 5 2 6] - * [ 9 13 10 14]] - * - * [[ 3 7 4 8] - * [11 15 12 16]]] - * - * For mode == SPACE_TO_DEPTH, the behavior is the same as mode == - * DEPTH_TO_SPACE, but with the input and output swapped. 
- * - * If block_size == 2 and mode == PIXEL_SHUFFLE, output will be the following - * [C = 2, H = 2, W = 4] tensor: - * - * .. code:: - * - * [[[ 1 3 2 4] - * [ 5 7 6 8]] - * - * [[ 9 11 10 12] - * [13 15 14 16]]] - * - */ -message ReorganizeDataLayerParams { - - enum ReorganizationType { - - SPACE_TO_DEPTH = 0; - DEPTH_TO_SPACE = 1; - PIXEL_SHUFFLE = 2; - - } - ReorganizationType mode = 1; - uint64 blockSize = 2; /// must be greater than 1 - -} - -/** - * A layer that slices the input data along axis = -1 or -2 or -3. - * For general slice along any axis, please see SliceStaticLayer/SliceDynamicLayer. - * - * .. code:: - * - * y = SliceLayer(x) - * - * Requires 1 input and produces 1 output. - * - * Input - * A blob that can, in general, have any rank. However, depending on the value of "axis" , - * there may be additional rank constraints. - * Output - * A blob with the same rank as the input. - * - * Sliced section is taken from the interval ``[startIndex, endIndex)``, i.e. - * startIndex is inclusive while endIndex is exclusive. - * stride must be positive and represents the step size for slicing. - * Negative indexing is supported for startIndex and endIndex. - * -1 denotes N-1, -2 denotes N-2 and so on, where N is the length of the dimension to be sliced. - * - */ -message SliceLayerParams { - - int64 startIndex = 1; /// start of the sliced section. Inclusive. - int64 endIndex = 2; /// end of sliced section. Exclusive. - uint64 stride = 3; /// The step size. Must be positive. - - enum SliceAxis { - - CHANNEL_AXIS = 0; - HEIGHT_AXIS = 1; - WIDTH_AXIS = 2; - - } - // The following mapping is used for interpreting this parameter: - // CHANNEL_AXIS => axis = -3, input must have rank at least 3. - // HEIGHT_AXIS => axis = -2, input must have rank at least 2. - // WIDTH_AXIS => axis = -1 - SliceAxis axis = 4; - -} - -/** - * A layer that reduces the input using a specified operation. - * - * .. 
code:: - * - * y = ReduceLayer(x) - * - * Requires 1 input and produces 1 output. - * - * Input - * A blob that can, in general, have any rank. However, depending on the value of "axis" , - * there may be additional rank constraints. - * Output - * A blob with the same rank as the input, which has 1s on the dimensions specified in the parameter "axis" - * - * Values supported for axis are [-1], [-2], [-3], [-2,-1], [-3,-2,-1] - * and the equivalent positive values (depending on the rank of the input) - * For mode == 'ArgMax', axis must be [-1] or [-2] or [-3]. - */ -message ReduceLayerParams { - - /* - * The following reduction operations are supported - * and are applied on the specified axis of the input array: - * - * ``SUM`` - * Sum of all elements - * - * .. math:: \sum{x_i} - * - * ``AVG`` - * Sum of all elements divided by the number of elements - * - * .. math:: \dfrac{\sum^n{x_i}}{n} - * - * ``PROD`` - * Product of all elements - * - * .. math:: \prod{x_i} - * - * ``LOGSUM`` - * Sum of the natural logarithm of all elements - * - * .. math:: \sum{\ln{(x_i + \epsilon)}} - * - * ``SUMSQUARE`` - * Sum of squares of all elements - * - * .. math:: \sum{x^2} - * - * ``L1`` - * L1 normalization of all elements - * - * .. math:: ||x||_1 = \sum{|x_i|} - * - * ``L2`` - * L2 normalization of all elements - * - * .. math:: ||x||_2 = \sqrt{\sum{x_i^2}} - * - * ``MAX`` - * Maximum of all elements - * - * .. math:: \text{max}(x_i) - * - * ``MIN`` - * Minumum of all elements - * - * .. math:: \text{min}(x_i) - * - * ``ARGMAX`` - * Argument of the maximum of all elements - * - * .. math:: \text{argmax}(x_i) - * - */ - enum ReduceOperation { - - SUM = 0; - AVG = 1; - PROD = 2; - LOGSUM = 3; - SUMSQUARE = 4; - L1 = 5; - L2 = 6; - MAX = 7; - MIN = 8; - ARGMAX = 9; /// only supported with axis = C, H or W. - - } - ReduceOperation mode = 1; /// Specifies function used to reduce. - - /** - * Used if mode is ``LOGSUM``. - * Defaults to ``1e-6`` if not set or is set to ``0``. 
- */ - float epsilon = 2; - - enum ReduceAxis { - - CHW = 0; - HW = 1; - C = 2; - H = 3; - W = 4; - - } - - // The following mapping is used for interpreting this parameter: - // CHW = axis [-3, -2, -1], input must have rank at least 3. - // HW = axis [-2, -1], input must have rank at least 2. - // C = axis [-3] - // H = axis [-2] - // W = axis [-1] - ReduceAxis axis = 3; - -} - -/** - * A layer that crops the spatial dimensions of an input. - * If two inputs are provided, the shape of the second input is used as the reference shape. - * - * .. code:: - * - * y = CropLayer(x1) or y = CropLayer(x1,x2) - * - * Requires 1 or 2 inputs and produces 1 output. - * - * Input - * 1 or 2 tensors, each with rank at least 3, both inputs must have equal rank. - * Example: - * - 1 input case: A blob with shape ``[C, H_in, W_in]``. - * - 2 input case: 1st blob with shape ``[C, H_in, W_in]``, 2nd blob with shape ``[C, H_out, W_out]``. - * - * For ranks greater than 3, the leading dimensions, starting from 0 to -4 (inclusive), are all treated as batch. - * - * Output - * Same rank as the inputs. - * e.g.: A blob with shape ``[C, H_out, W_out]``. - * - * If one input is used, output is computed as follows: - * - * .. code:: - * - * y = x1[:, topCropAmount:H_in - bottomCropAmount, leftCropAmount:W_in - rightCropAmount] - * - * topCropAmount == Height startEdgeSize == borderAmounts[0].startEdgeSize - * bottomCropAmount == Height endEdgeSize == borderAmounts[0].endEdgeSize - * leftCropAmount == Width startEdgeSize == borderAmounts[1].startEdgeSize - * rightCropAmount == Width endEdgeSize == borderAmounts[1].endEdgeSize - * - * H_out = H_in - topCropAmount - bottomCropAmount - * W_out = W_in - leftCropAmount - rightCropAmount - * - * If two inputs are used, output is computed as follows: - * - * .. code:: - * - * y = x1[:, offset[0]:offset[0] + H_out, offset[1]:offset[1] + W_out] - */ -message CropLayerParams { - - /** - * The amounts to be cropped from the input. 
- * Used only if a single input is provided. - */ - BorderAmounts cropAmounts = 1; - - /** - * The offset amounts. - * Used only if two inputs are provided. - * Must be of length 2, in order ``[H, W]``. - */ - repeated uint64 offset = 5; - -} - -/** - * A layer that computes the elementwise average of the inputs. - * This layer has limited broadcasting support. For general broadcasting see AddBroadcastableLayer. - * - * .. code:: - * - * y = AverageLayer(x1,x2,...) - * - * Requires multiple inputs and produces 1 output. - * - * Input - * In general, there are no rank constraints. - * However, only certain set of shapes are broadcastable. For example: - * [B, 1, 1, 1], [B, C, 1, 1], [B, 1, H, W], [B, C, H, W] - * Output - * A blob with the same shape as each input. - */ -message AverageLayerParams { - -} - -/** - * A layer that computes the elementwise maximum over the inputs. - * - * .. code:: - * - * y = MaxLayer(x1,x2,...) - * - * Requires multiple inputs and produces 1 output. - * - * Input - * In general, there are no rank constraints. - * However, only certain set of shapes are broadcastable. For example: - * [B, C, 1, 1], [B, C, H, W] - * Output - * A blob with the same shape as each input. - */ -message MaxLayerParams { - -} - -/** - * A layer that computes the elementwise minimum over the inputs. - * - * .. code:: - * - * y = MinLayer(x1,x2,...) - * - * Requires multiple inputs and produces 1 output. - * - * Input - * In general, there are no rank constraints. - * However, only certain set of shapes are broadcastable. For example: - * [B, C, 1, 1], [B, C, H, W] - * Output - * A blob with the same shape as each input. - */ -message MinLayerParams { - -} - -/** - * A layer that computes the dot product of two vectors. - * - * .. code:: - * - * y = DotProductLayer(x1,x2) - * - * Requires 2 inputs and produces 1 output. - * - * Input - * Two blobs with rank at least 3, such that the last two dimensions must be 1. - * e.g.: blobs with shape ``[B, C, 1, 1]``. 
- * For ranks greater than 3, the leading dimensions, starting from 0 to -4 (inclusive), are all treated as batch. - * - * Output - * Same rank as the input. - * e.g. for rank 4 inputs, output shape: [B, 1, 1, 1] - */ -message DotProductLayerParams { - - /** - * If true, inputs are normalized first, - * thereby computing the cosine similarity. - */ - bool cosineSimilarity = 1; - -} - -/** - * A layer that performs mean variance normalization, along axis = -3. - * - * .. code:: - * - * y = MeanVarianceNormalizeLayer(x) - * - * Requires 1 input and produces 1 output. - * - * Input - * A blob with rank greater than equal to 3. - * Example: Rank 4 blob represents [Batch, channels, height, width] - * For ranks greater than 3, the leading dimensions, starting from 0 to -4 (inclusive), are all treated as batch. - * - * Output - * A blob with the same shape as the input. - * - * If ``acrossChannels == true`` - * normalization is performed on flattened input, i.e. the input is reshaped to (Batch,C), where "Batch" contains - * all dimensions from 0 to -4 (inclusive), and C contains dimensions -1, -2, -3. - * - * If ``acrossChannels == false`` - * normalization is performed within a channel, - * across spatial dimensions (i.e. last two dimensions). - */ -message MeanVarianceNormalizeLayerParams { - - /** - * If true, mean and variance are computed across channels. - */ - bool acrossChannels = 1; - - /** - * If false, only mean is subtracted. - */ - bool normalizeVariance = 2; - - /** - * A small constant to avoid division by 0 while normalizing variance. - * Defaults to ``1e-6`` if not set or set to ``0``. - */ - float epsilon = 3; - -} - -/** - * A layer that repeats a sequence or the dimension sitting at axis = -5 - * - * .. code:: - * - * y = SequenceRepeatLayer(x) - * - * Requires 1 input and produces 1 output. - * - * Input - * A blob with rank at least 5. - * e.g: shape ``[Seq, B, C, H, W]`` - * Output - * A blob with the same rank as the input. 
- * e.g.: for input shape ``[Seq, B, C, H, W]``, output shape is ``[nRepetitions * Seq, B, C, H, W]``. - */ -message SequenceRepeatLayerParams { - - /** - * Number of repetitions. - * Defaults to ``1`` if not set or set to ``0``. - */ - uint64 nRepetitions = 1; - -} - -/// Recurrent Layers -/// ---------------- - -/* - * The following activations are supported with recurrent layers: - * - Linear - * - Sigmoid - * - Tanh - * - ReLU - * - Scaled Hyperbolic Tangent: alpha * tanh(beta * x), currently only supported for alpha = 1.7159, beta = 2/3 - * - Hard Sigmoid: min(max(alpha * x + beta, 0), 1), currently only supported for alpha = 0.2, beta = 0.5 - */ - -/** - * A simple recurrent layer. - * - * .. code:: - * - * y_t = SimpleRecurrentLayer(x_t, y_{t-1}) - * - * Input - * A blob of rank 5, with shape `[Seq, Batch, inputVectorSize, 1, 1]``. - * This represents a sequence of vectors of size ``inputVectorSize``. - * Output - * Same rank as the input. - * Represents a vector of size ``outputVectorSize``. It is either the final output or a sequence of outputs at all time steps. - * - * - Output Shape: ``[1, Batch, outputVectorSize, 1, 1]`` , if ``sequenceOutput == false`` - * - Output Shape: ``[Seq, Batch, outputVectorSize, 1, 1]`` , if ``sequenceOutput == true`` - * - * This layer is described by the following equation: - * - * .. math:: - * \boldsymbol{y_t} = f(\mathrm{clip}(W \boldsymbol{x_t} + \ - * R \boldsymbol{y_{t-1}} + b)) - * - * - ``W`` is a 2-dimensional weight matrix - * (``[outputVectorSize, inputVectorSize]``, row-major) - * - ``R`` is a 2-dimensional recursion matrix - * (``[outputVectorSize, outputVectorSize]``, row-major) - * - ``b`` is a 1-dimensional bias vector (``[outputVectorSize]``) - * - ``f()`` is an activation - * - ``clip()`` is a function that constrains values between ``[-50.0, 50.0]`` - */ -message SimpleRecurrentLayerParams { - - uint64 inputVectorSize = 1; /// The size of the input vectors. 
- uint64 outputVectorSize = 2; /// The size of the output vectors. - - /** - * Activations supported are Linear, Sigmoid, Tanh, ReLU, Scaled Tanh (alpha = 1.71, beta = 2/3), Hard sigmoid (alpha = 0.2, beta = 0.5) - */ - ActivationParams activation = 10; /// The activation function. - - /** - If false output is just the result after final state update. - If true, output is a sequence, containing outputs at all time steps. - */ - bool sequenceOutput = 15; - - bool hasBiasVector = 20; /// If false, no bias is added. - - WeightParams weightMatrix = 30; /// Weight matrix W. - WeightParams recursionMatrix = 31; /// Recursion Weight matrix R. - WeightParams biasVector = 32; /// Bias vector b. - - bool reverseInput = 100; - // If true, then the node processes the input sequence from right to left - -} - -/** - * Gated-Recurrent Unit (GRU) Layer - * - * .. code:: - * - * y_t = GRULayer(x_t, y_{t-1}) - * - * Input - * A blob of rank 5, with shape `[Seq, Batch, inputVectorSize, 1, 1]``. - * This represents a sequence of vectors of size ``inputVectorSize``. - * Output - * Same rank as the input. - * Represents a vector of size ``outputVectorSize``. It is either the final output or a sequence of outputs at all time steps. - * - * - Output Shape: ``[1, Batch, outputVectorSize, 1, 1]`` , if ``sequenceOutput == false`` - * - Output Shape: ``[Seq, Batch, outputVectorSize, 1, 1]`` , if ``sequenceOutput == true`` - * - * This layer is described by the following equations: - * - * Update Gate - * .. math:: - * \boldsymbol{z_t} = \ - * f(\mathrm{clip}(W_z \boldsymbol{x_t} + \ - * R_z \boldsymbol{y_{t-1}} + b_z) - * - * Reset Gate - * .. math:: - * \boldsymbol{r_t} = \ - * f(\mathrm{clip}(W_r \boldsymbol{x_t} + \ - * R_r \boldsymbol{y_{t-1}} + b_r)) - * - * Cell Memory State - * .. math:: - * \boldsymbol{c_t} = \ - * \boldsymbol{y_{t-1}} \odot \boldsymbol{r_t} - * - * Output Gate - * .. 
math:: - * \boldsymbol{o_t} = \ - * g(\mathrm{clip}(W_o \boldsymbol{x_t} + \ - * R_o \boldsymbol{c_t} + b_o)) - * - * Output - * .. math:: - * \boldsymbol{y_t} = \ - * (1 - \boldsymbol{z_t}) \odot \boldsymbol{o_t} + \ - * \boldsymbol{z_t} \odot \boldsymbol{y_{t-1}} - * - * - ``W_z``, ``W_r``, ``W_o`` are 2-dimensional input weight matrices - * (``[outputVectorSize, inputVectorSize]``, row-major) - * - ``R_z``, ``R_r``, ``R_o`` are 2-dimensional recursion matrices - * (``[outputVectorSize, outputVectorSize]``, row-major) - * - ``b_z``, ``b_r``, ``b_o`` are 1-dimensional bias vectors - * (``[outputVectorSize]``) - * - ``f()``, ``g()`` are activations - * - ``clip()`` is a function that constrains values between ``[-50.0, 50.0]`` - * - ``⊙`` denotes the elementwise product of matrices - */ -message GRULayerParams { - - uint64 inputVectorSize = 1; /// Size of the input vectors. - uint64 outputVectorSize = 2; /// Size of the output vectors. - - /** - * 2 element array representing activations [f(), g()] in that order. - * Typical values used = [sigmoid, tanh]. - * Activations supported are Linear, Sigmoid, Tanh, ReLU, Scaled Tanh (alpha = 1.71, beta = 2/3), Hard sigmoid (alpha = 0.2, beta = 0.5) - */ - repeated ActivationParams activations = 10; - - /** - * If false output is just the result after final state update. - * If true, output is a sequence, containing outputs at all time steps. - */ - bool sequenceOutput = 15; - - /** - * If false, no biases (``b_z``, ``b_r``, ``b_o``) are added. - */ - bool hasBiasVectors = 20; - - WeightParams updateGateWeightMatrix = 30; /// Weight Matrix W_z. - WeightParams resetGateWeightMatrix = 31; /// Weight Matrix W_r. - WeightParams outputGateWeightMatrix = 32; /// Weight Matrix W_o. - - WeightParams updateGateRecursionMatrix = 50; /// Recursion Weight Matrix R_z. - WeightParams resetGateRecursionMatrix = 51; /// Recursion Weight Matrix R_r. - WeightParams outputGateRecursionMatrix = 52; /// Recursion Weight Matrix R_o. 
- - WeightParams updateGateBiasVector = 70; /// Bias vector b_z. - WeightParams resetGateBiasVector = 71; /// Bias vector b_r. - WeightParams outputGateBiasVector = 72; /// Bias vector b_o. - - /// If true, then the node processes the input sequence from right to left - bool reverseInput = 100; - -} - -/** - * Long short-term memory (LSTM) parameters. - * - * This is described by the following equations: - * - * Input Gate - * .. math:: - * \boldsymbol{i_t} = \ - * f(\mathrm{clip}(W_i \boldsymbol{x_t} + \ - * R_i \boldsymbol{y_{t-1}} + \ - * p_i \odot c_{t-1} + b_i)) - * - * Forget Gate - * .. math:: - * \boldsymbol{f_t} = \ - * f(\mathrm{clip}(W_f \boldsymbol{x_t} + \ - * R_f \boldsymbol{y_{t-1}} + \ - * p_f \odot c_{t-1} + b_f)) - * - * Block Input - * .. math:: - * \boldsymbol{z_t} = \ - * g(\mathrm{clip}(W_z \boldsymbol{x_t} + \ - * R_z \boldsymbol{y_{t-1}} + b_z)) - * - * Cell Memory State - * .. math:: - * \boldsymbol{c_t} = \ - * \boldsymbol{c_{t-1}} \odot \boldsymbol{f_t} + \ - * \boldsymbol{i_t} \odot \boldsymbol{z_t} - * - * Output Gate - * .. math:: - * \boldsymbol{o_t} = \ - * f(\mathrm{clip}(W_o \boldsymbol{x_t} + \ - * R_o \boldsymbol{y_{t-1}} + \ - * p_o \odot c_t + b_o)) - * - * Output - * .. 
math:: - * \boldsymbol{y_t} = \ - * h(\boldsymbol{c_t}) \odot \boldsymbol{o_t} - * - * - ``W_i``, ``W_f``, ``W_z``, ``W_o`` are 2-dimensional input weight matrices - * (``[outputVectorSize, inputVectorSize]``, row-major) - * - ``R_i``, ``R_f``, ``R_z``, ``R_o`` are 2-dimensional recursion matrices - * (``[outputVectorSize, outputVectorSize]``, row-major) - * - ``b_i``, ``b_f``, ``b_z``, ``b_o`` are 1-dimensional bias vectors - * (``[outputVectorSize]``) - * - ``p_``, ``p_f``, ``p_o`` are 1-dimensional peephole vectors - * (``[outputVectorSize]``) - * - ``f()``, ``g()``, ``h()`` are activations - * - ``clip()`` is a function that constrains values between ``[-50.0, 50.0]`` - * - ``⊙`` denotes the elementwise product of matrices - */ -message LSTMParams { - - /** - * If true, output is a sequence, containing outputs at all time steps. - * If false, output is just the result after final state update. - */ - bool sequenceOutput = 10; - - /** - * If false, no biases (``b_i``, ``b_f``, ``b_z``, ``b_o``) are added. - */ - bool hasBiasVectors = 20; - - /** - * If true, a vector of ``1`` values is added to ``b_f``. - */ - bool forgetBias = 30; - - /** - * If true, peephole vectors are included. - */ - bool hasPeepholeVectors = 40; - - /** - * If the coupled Input and Forget flag is on, the behaviour of - * ``c_t`` is changed to the following (i.e. forget gate is not used): - * - * .. math:: - * \boldsymbol{c_t} = \ - * \boldsymbol{c_{t-1}} \odot (1 - \boldsymbol{i_t}) + \ - * \boldsymbol{i_t} \odot \boldsymbol{z_t} - * - */ - bool coupledInputAndForgetGate = 50; - - /** - * Places a limit on the maximum and minimum values of ``c_t``. - * c_t = min(c_t, cellClipThreshold) - * c_t = max(c_t, -cellClipThreshold) - * If 0, it is set to its default value = 50.0. - */ - float cellClipThreshold = 60; - -} - -/** - * Weights for long short-term memory (LSTM) layers - */ -message LSTMWeightParams { - - WeightParams inputGateWeightMatrix = 1; /// Weight Matrix W_i. 
- WeightParams forgetGateWeightMatrix = 2; /// Weight Matrix W_f. - WeightParams blockInputWeightMatrix = 3; /// Weight Matrix W_z. - WeightParams outputGateWeightMatrix = 4; /// Weight Matrix W_o. - - WeightParams inputGateRecursionMatrix = 20; /// Recursion Weight Matrix R_i. - WeightParams forgetGateRecursionMatrix = 21; /// Recursion Weight Matrix R_f. - WeightParams blockInputRecursionMatrix = 22; /// Recursion Weight Matrix R_z. - WeightParams outputGateRecursionMatrix = 23; /// Recursion Weight Matrix R_o. - - //biases: - WeightParams inputGateBiasVector = 40; /// Bias vector b_i. - WeightParams forgetGateBiasVector = 41; /// Bias vector b_f. - WeightParams blockInputBiasVector = 42; /// Bias vector b_z. - WeightParams outputGateBiasVector = 43; /// Bias vector b_o. - - //peepholes: - WeightParams inputGatePeepholeVector = 60; /// Peephole vector p_i. - WeightParams forgetGatePeepholeVector = 61; /// Peephole vector p_f. - WeightParams outputGatePeepholeVector = 62; /// Peephole vector p_o. - -} - -/** - * A unidirectional long short-term memory (LSTM) layer. - * - * .. code:: - * - * (y_t, c_t) = UniDirectionalLSTMLayer(x_t, y_{t-1}, c_{t-1}) - * - * Input - * A blob of rank 5, with shape `[Seq, Batch, inputVectorSize, 1, 1]``. - * This represents a sequence of vectors of size ``inputVectorSize``. - * Output - * Same rank as the input. - * Represents a vector of size ``outputVectorSize``. It is either the final output or a sequence of outputs at all time steps. - * - * - Output Shape: ``[1, Batch, outputVectorSize, 1, 1]`` , if ``sequenceOutput == false`` - * - Output Shape: ``[Seq, Batch, outputVectorSize, 1, 1]`` , if ``sequenceOutput == true`` - * - */ -message UniDirectionalLSTMLayerParams { - - uint64 inputVectorSize = 1; /// Size of the input vectors. - uint64 outputVectorSize = 2; /// Size of the output vectors. - - /** - * 3 element array representing activations [f(),g(),h()] in that order. - * Typical values used = [sigmoid, tanh, tanh]. 
- * Activations supported are Linear, Sigmoid, Tanh, ReLU, Scaled Tanh (alpha = 1.71, beta = 2/3), Hard sigmoid (alpha = 0.2, beta = 0.5) - */ - repeated ActivationParams activations = 10; - - LSTMParams params = 15; - - LSTMWeightParams weightParams = 20; /// Weights, biases and peepholes. - - /// If true, then the node processes the input sequence from right to left - bool reverseInput = 100; - -} - -/** - * Bidirectional long short-term memory (LSTM) layer - * - * .. code:: - * - * (y_t, c_t, y_t_reverse, c_t_reverse) = BiDirectionalLSTMLayer(x_t, y_{t-1}, c_{t-1}, y_{t-1}_reverse, c_{t-1}_reverse) - * - * Input - * A blob of rank 5, with shape `[Seq, Batch, inputVectorSize, 1, 1]``. - * This represents a sequence of vectors of size ``inputVectorSize``. - * Output - * Same rank as the input. - * Represents a vector of size ``2 * outputVectorSize``. It is either the final output or a sequence of outputs at all time steps. - * - * - Output Shape: ``[1, Batch, 2 * outputVectorSize, 1, 1]`` , if ``sequenceOutput == false`` - * - Output Shape: ``[Seq, Batch, 2 * outputVectorSize, 1, 1]`` , if ``sequenceOutput == true`` - * - * - * The first LSTM operates on the input sequence in the forward direction. - * The second LSTM operates on the input sequence in the reverse direction. 
- * - * Example: given the input sequence ``[x_1, x_2, x_3]``, - * where ``x_i`` are vectors at time index ``i``: - * - * The forward LSTM output is ``[yf_1, yf_2, yf_3]``, - * - * where ``yf_i`` are vectors of size ``outputVectorSize``: - * - * - ``yf_1`` is the output at the end of sequence {``x_1``} - * - ``yf_2`` is the output at the end of sequence {``x_1``, ``x_2``} - * - ``yf_3`` is the output at the end of sequence {``x_1``, ``x_2``, ``x_3``} - * - * The backward LSTM output: ``[yb_1, yb_2, yb_3]``, - * - * where ``yb_i`` are vectors of size ``outputVectorSize``: - * - * - ``yb_1`` is the output at the end of sequence {``x_3``} - * - ``yb_2`` is the output at the end of sequence {``x_3``, ``x_2``} - * - ``yb_3`` is the output at the end of sequence {``x_3``, ``x_2``, ``x_1``} - * - * Output of the bi-dir layer: - * - * - if ``sequenceOutput = True`` : { ``[yf_1, yb_3]``, ``[yf_2, yb_2]``, ``[yf_3, yb_1]`` } - * - if ``sequenceOutput = False`` : { ``[yf_3, yb_3]`` } - */ -message BiDirectionalLSTMLayerParams { - - /** - * Size of the input vectors. - */ - uint64 inputVectorSize = 1; - /** - * Size of the outputs vectors. - * It is same for both forward and backward LSTMs. - */ - uint64 outputVectorSize = 2; - - /** - * 3 element array representing activations [f(),g(),h()] in that order. - * Typical values used = [sigmoid, tanh, tanh]. - * Activations supported are Linear, Sigmoid, Tanh, ReLU, Scaled Tanh (alpha = 1.71, beta = 2/3), Hard sigmoid (alpha = 0.2, beta = 0.5) - */ - repeated ActivationParams activationsForwardLSTM = 10; - /** - * Currently, backward LSTM activations - * must be same as the ones for the forward LSTM. - */ - repeated ActivationParams activationsBackwardLSTM = 11; - - /** - * Common parameters shared by the forward and backward LSTMs. - */ - LSTMParams params = 15; - - /** - * Weights and biases. - * Must be a length 2 message, - * for the forward and backward LSTM respectively. 
- */ - repeated LSTMWeightParams weightParams = 20; - -} - -message CustomLayerParams { - - message CustomLayerParamValue { - oneof value { - double doubleValue = 10; - string stringValue = 20; - int32 intValue = 30; - int64 longValue = 40; - bool boolValue = 50; - } - } - - string className = 10; // The name of the class (conforming to MLCustomLayer) corresponding to this layer - repeated WeightParams weights = 20; // Any weights -- these are serialized in binary format and memmapped at runtime - map parameters = 30; // these may be handled as strings, so this should not be large - string description = 40; // An (optional) description of the layer provided by the model creator. This information is displayed when viewing the model, but does not affect the model's execution on device. - -} - -/** - * A layer that rearranges the dimensions and data of an input. - * - * .. code:: - * - * y = TransposeLayer(x) - * - * Requires 1 input and produces 1 output. - * - * Input - * A N-Dimensional tensor. - * Output - * A N-Dimensional tensor of the same rank but with dimensions and data permuted according to axes. - * Shape: ``[InputShape[axis[0]], InputShape[axis[1]], ... , InputShape[axis[N-1]]]`` - * - * Examples: - * - * - If ``axes`` is set to ``[3, 1, 2, 0]`` and the input shape is ``[6,7,8,9]``, - * then the output has shape ``[9,7,8,6]`` - */ - -message TransposeLayerParams { - - /** - * Length of "axes" should match the rank of input & output tensor - * "axes" should be a permutation of "[0,1,2,...,N-1]" where N is the rank. - */ - repeated uint64 axes = 1; // - -} - -/** - * A layer that computes the matrix multiplication of two tensors with numpy-like broadcasting - * where the matrices reside in the last two indices of the tensor. - * - * .. code:: - * - * y = BatchedMatMul(a,b) - * - * Requires 1 or 2 inputs and produces 1 output. - * - * The first tensor, "a", must be provided as an input. 
The second tensor can either be an input or provided as a weight matrix parameter. - * - * Input - * - a: First N-Dimensional tensor - * - b: Second N-Dimensional tensor (either a rank-N input or a matrix, i.e. N=2, provided as a layer parameter) - * - * Output - * A tensor containing the matrix product of two tensors. - * When there are two inputs: rank is max(2, rank(a), rank(b)) - * When there is one input: rank is same as that of the input. - * - * This operation behaves as following: - * - * When there are two inputs: - * - If N >= 2 for both tensors, it is treated as a batch of matrices residing in the last two indices. - * All the indices, except for the last two, are broadcasted using conventional rules. - * - If the first tensor is 1-D, it is converted to a 2-D tensor by prepending a 1 to its shape. Eg. (D) -> (1,D) - * - If the second tensor is 1-D, it is converted to a 2-D tensor by appending a 1 to its shape. Eg. (D) -> (D,1) - * - * When there is one input: - * - The weight matrix corresponds to a matrix, of shape (X1, X2). Values of X1, X2 must be provided as layer parameters. - * - The input, "a", is reshaped into a matrix by combining all the leading dimensions, except the last, into a batch dimension. eg: - * - if "a" is rank 1 (X1,) --> (1, X1). Output shape will be (X2,) - * - if "a" is rank 2 (B1, X1) --> no need to reshape. Output shape will be (B1, X2) - * - if "a" is rank 3 (B1, B2, X1) --> (B1 * B2, X1). Output shape will be (B1, B2, X2) - * - etc - */ -message BatchedMatMulLayerParams { - - /** - * If transposeA is true, it transposes the left matrix on the fly before matrix multiplication. - * (is ignored when there is one input) - */ - bool transposeA = 1; - /** - * If transposeB is true, it transposes the right matrix on the fly before matrix multiplication. - * (is ignored when there is one input) - */ - bool transposeB = 2; - - /* - * Following parameters are ignored when there are two inputs. 
- */ - - uint64 weightMatrixFirstDimension = 5; /// X1: same as the last dimension of the input tensor - uint64 weightMatrixSecondDimension = 6; /// X2: same as the last dimension of the output tensor - - bool hasBias = 7; /// Whether a bias is added or not. Supported only when there is one input. - - /* - * Weight matrix representing shape [X1, X2]. - * Values are however stored in column major order, - * in the "repeated float" or "bytes" fields of the message "WeightParams" - */ - WeightParams weights = 8; - WeightParams bias = 9; /// Bias vector [X2]. Supported only when there is one input. - - /** - * If set, this layer, at runtime, quantizes the floating point input blob to int8 before applying the - * matrix multiplication using the INT8 weight parameters provided in weights->int8RawValue. The - * result is then dequantized. - * Requires: - * * number of inputs to be 1 - * * hasBias == false - * * QuantizationType == LinearQuantizationParams, such that - * * size of the "scale" field is 1 and "bias" field is empty in "LinearQuantizationParams" - * * numberOfBits == 8 - * * weights->rawValue_size to be empty - */ - bool int8DynamicQuantize = 10; - -} - -/** - * A layer that concatenates a list of tensors along a specified axis. - * - * .. code:: - * - * y = ConcatNDLayer(x1,x2,....) - * - * Requires at least 2 input and produces 1 output. - * - * Input - * The rank of the input tensors must match and all dimensions also must match, except for the dimension 'axis'. - * - * - * Output - * Same rank as the input. The dimension along "axis", is the sum of the dimensions of the inputs. 
- * - * example: - * - * in1 : shape (3, 2), value = [[1, 2], [3, 4], [5, 6]] - * in2 : shape (3, 2), value = [[7, 8], [9, 10], [11, 12]] - * axis = 0 - * - * if interleave = False (default) - * output : shape (6, 2) - * output[0:3, :] = in1 - * output[3:6, :] = in2 - * value = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]] - * - * if interleave = True - * output : shape (6, 2) - * output[0::2, :] = in1 - * output[1::2, :] = in2 - * value = [[1, 2], [7, 8], [3, 4], [9, 10], [5, 6], [11, 12]] - * - */ -message ConcatNDLayerParams { - - /** - * Dimension along which to concatenate. Supports negative values of the parameter 'axis'. - */ - int64 axis = 1; - - /** - * (Only available in Core ML Specification >= 5 (iOS >= 14, macOS >= 11.0) - * Interleave option. If True, concatenation is done via interleaving the inputs. - * This requires all inputs to have the exact same shape. - */ - bool interleave = 2; - - -} - -/** - * A layer that performs softmax normalization along a specified axis. - * - * .. code:: - * - * y = SoftmaxNDLayer(x) - * - * Requires 1 input and produces 1 output. - * - * Output shape is same as the input. - */ -message SoftmaxNDLayerParams { - - /** - * Dimension on which the softmax would be performed. Supports negative values of the parameter 'axis'. - */ - int64 axis = 1; - -} - -/** - * A layer that reverses specific dimensions of the input tensor. - * It is similar in functionality to the numpy.flip method. - * - * Requires 1 input and produces 1 output. - * Output shape is same as the input. - */ -message ReverseLayerParams { - - /** - * Reverses each dimension of the input tensor for which corresponding reverseDim is set to True. - * Requires len(reverseDim) == rank(inputTensor) - */ - repeated bool reverseDim = 1; - -} - -/** - * A layer that reverses variable length slices. - * - * Requires 2 inputs and produces 1 output. - * - * 2 inputs, in order are denoted by "data", "seq_lengths". - * "seq_lenghts" must be a rank 1 tensor, i.e. 
seq_lengths.shape = (B,) - * which contains the lengths of the amount of sequence to be reversed, for each element of the batch. - * Dimension "batchAxis" in "data" must be equal to B, i.e, - * data.shape[batchAxis] = B. - * - * According to the batch axis, input "data" is first divided into a batch of B inputs, - * each of which is flipped along the dimension "sequenceAxis", by the amount specified in - * "seq_lengths", the second input. - * - * e.g.: - * - * data [shape = (2,4)]: - * [0 1 2 3] - * [4 5 6 7] - * seq_lengths [shape = (2,)]: - * [3, 0] - * batchAxis = 0 - * sequenceAxis = 1 - * - * output [shape = (2,4)]: - * [2 1 0 3] - * [4 5 6 7] - * - * - * data [shape = (2,3,2)]: - * [0 1] - * [2 3] - * [4 5] (slice = 0) - * [6 7] - * [8 9] - * [10 11] (slice = 1) - * seq_lengths [shape = (2,)]: - * [2, 3] - * batchAxis = 0 - * sequenceAxis = 1 - * - * output [shape = (2,3,2)]: - * [2 3] - * [0 1] - * [4 5] (slice = 0) - * [10 11] - * [8 9] - * [6 7] (slice = 1) - * - * Output shape is same as the input. - */ -message ReverseSeqLayerParams { - - int64 batchAxis = 1; // batch axis has to be strictly less than seq_axis - int64 sequenceAxis = 2; - -} - -/** - * A layer that loads data as a parameter and provides it as an output. - * - * .. code:: - * - * y = LoadConstantNDLayer() - * - * Requires no input and produces 1 output. - * - * Output: A tensor with shape as provided in the parameter "shape" - */ -message LoadConstantNDLayerParams { - - /** - * The shape of the constant to be loaded. - */ - repeated uint64 shape = 1; - WeightParams data = 2; - -} - -/** - * A layer that generates an output tensor with a constant value. - * Input is only used to determine the shape of the output. - * This layer is used to allocate a tensor with a dynamic shape (that of the input) and constant value. - * - * Requires 1 input and produces 1 output. - * - * .. code:: - * - * y = FillLikeLayer(x) - * - * Input - * A N-Dimensional tensor, whose values are ignored. 
Only the shape is used to - * infer the shape of the output. - * - * Output - * A N-Dimensional tensor with the same shape as the input tensor. - * - */ -message FillLikeLayerParams { - - float value = 1; - -} - -/** - * A layer that generates an output tensor with a constant value. - * This layer is used to allocate a tensor with a static shape and constant value. - * - * Requires no input and produces 1 output. - * - * .. code:: - * - * y = FillStaticLayer(x) - * - * Output - * A N-Dimensional tensor of shape "targetShape". - * - */ -message FillStaticLayerParams { - - float value = 1; - repeated uint64 targetShape = 2; - -} - -/** - * A layer that generates an output tensor with a constant value. - * This layer is used to allocate a tensor with a dynamic shape (as specified by the input) and constant value. - * - * Requires 1 input and produces 1 output. - * - * .. code:: - * - * y = FillDynamicLayer(x) - * - * Input - * A rank 1 tensor specifying the shape of the output - * - * Output - * An N-Dimensional tensor with the shape specified by the values in the input tensor. - * - */ -message FillDynamicLayerParams { - - float value = 1; - -} - -/** - * A layer that returns the elements either from tensor x or tensor y, - * depending on the value in the condition tensor. - * It is similar in functionality to the numpy.where method with 3 inputs. - * - * Requires 3 inputs and produces 1 output. - * Inputs, in order, are the condition tensor, x and y. - * - * for each vector index (i,...,j): - * output[i,...,j] = x[i,...,j] if condition[i,...,j] = True - * y[i,...,j] if condition[i,...,j] = False - * - * All the 3 inputs are first broadcasted to a common shape. - * (the shapes must be broadcastable) - * - * output.rank = max(input[0].rank, input[1].rank, input[2].rank) - * - */ -message WhereBroadcastableLayerParams { - -} - -/** - * A layer that computes elementwise trigonometric sine function. - * - * - * .. 
code:: - * - * y = SinLayer(x) - * - * Requires 1 input and produces 1 output. - * Output shape is same as the input. - * - */ -message SinLayerParams { - -} - -/** - * A layer that computes elementwise trigonometric cosine function. - * - * - * .. code:: - * - * y = CosLayer(x) - * - * Requires 1 input and produces 1 output. - * Output shape is same as the input. - * - */ -message CosLayerParams { - -} - -/** - * A layer that computes elementwise trigonometric tangent function. - * - * - * .. code:: - * - * y = TanLayer(x) - * - * Requires 1 input and produces 1 output. - * Output shape is same as the input. - * - */ -message TanLayerParams { - -} - -/** - * A layer that computes elementwise trigonometric arcsine function. - * - * - * .. code:: - * - * y = AsinLayer(x) - * - * Requires 1 input and produces 1 output. - * Output shape is same as the input. - * - */ -message AsinLayerParams { - -} - -/** - * A layer that computes elementwise trigonometric arccosine function. - * - * - * .. code:: - * - * y = AcosLayer(x) - * - * Requires 1 input and produces 1 output. - * Output shape is same as the input. - * - */ -message AcosLayerParams { - -} - -/** - * A layer that computes elementwise trigonometric arctangent function. - * - * - * .. code:: - * - * y = AtanLayer(x) - * - * Requires 1 input and produces 1 output. - * Output shape is same as the input. - * - */ -message AtanLayerParams { - -} - -/** - * A layer that computes elementwise trigonometric hyperbolic sine function. - * - * - * .. code:: - * - * y = SinhLayer(x) - * - * Requires 1 input and produces 1 output. - * Output shape is same as the input. - * - */ -message SinhLayerParams { - -} - -/** - * A layer that computes elementwise trigonometric hyperbolic cosine function. - * - * - * .. code:: - * - * y = CoshLayer(x) - * - * Requires 1 input and produces 1 output. - * Output shape is same as the input. 
- * - */ -message CoshLayerParams { - -} - -/** - * A layer that computes elementwise trigonometric hyperbolic tangent function. - * - * - * .. code:: - * - * y = TanhLayer(x) - * - * Requires 1 input and produces 1 output. - * Output shape is same as the input. - * - */ -message TanhLayerParams { - -} - -/** - * A layer that computes elementwise trigonometric hyperbolic arcsine function. - * - * - * .. code:: - * - * y = AsinhLayer(x) - * - * Requires 1 input and produces 1 output. - * Output shape is same as the input. - * - */ -message AsinhLayerParams { - -} - -/** - * A layer that computes elementwise trigonometric hyperbolic arccosine function. - * - * - * .. code:: - * - * y = AcoshLayer(x) - * - * Requires 1 input and produces 1 output. - * Output shape is same as the input. - * - */ -message AcoshLayerParams { - -} - -/** - * A layer that computes elementwise trigonometric hyperbolic arctangent function. - * - * - * .. code:: - * - * y = AtanhLayer(x) - * - * Requires 1 input and produces 1 output. - * Output shape is same as the input. - * - */ -message AtanhLayerParams { - -} -/** - * A layer that raises each element in first tensor to the power of - * corresponding element in the second tensor. - * Supports conventional numpy-like broadcasting. - * - * .. code:: - * - * y = PowBroadcastableLayer(x) - * - * Requires 2 inputs and produces 1 output. - * - * Input - * - First N-Dimensional tensor - * - Second N-Dimensional tensor - * - * Output - * An N-Dimensional tensor with the broadcast shape. - * - */ -message PowBroadcastableLayerParams { - -} - -/** - * A layer that computes the exponential of all elements in the input tensor, with the base 2. - * - * - * .. code:: - * - * y = Exp2Layer(x) - * - * Requires 1 input and produces 1 output. - * Output shape is same as the input. - * - */ -message Exp2LayerParams { - -} - -/** - * A layer that returns a tensor containing the indices of all non-zero - * elements of input tensor. 
- * It is similar in functionality to the numpy.where method with 1 input. - * - * Requires 1 input and produces 1 output. - * Output is of rank 2, of shape (N,R), - * where N is the number of non-zero elements in the input and R is the rank of the input. - * - * Output contains indices represented in the multi-index form - * - * e.g.: - * input {shape = (4,)}: - * [0 1 0 2] - * output {shape = (2,1)}: - * [1] - * [3] - * - * - * input {shape = (3, 3)}: - * [1 2 1] - * [0 2 2] - * [2 1 0] - * output {shape = (7,1)}: - * [0. 0.] - * [0. 1.] - * [0. 2.] - * [1. 1.] - * [1. 2.] - * [2. 0.] - * [2. 1.] - * - */ -message WhereNonZeroLayerParams { - -} - -/** - * A layer that copies a tensor setting everything outside a central band in - * each inner-most matrix to zero. - * - * Requires 1 input and produces 1 output. - * - * Parameters for matrix_band_part layer - * band(m, n) = (num_lower < 0 || (m-n) <= num_lower) && (num_upper < 0 || (n-m) <= num_upper). - * output[i, j, k, ..., m, n] = band(m, n) * input[i, j, k, ..., m, n] - * - * - * Output shape is same as the input shape. - * Rank of the input must be at least 2. - * For rank higher than 2, the last 2 dimensions are treated as the matrix, while the rest are treated as batch. - */ -message MatrixBandPartLayerParams { - - int64 numLower = 1; - int64 numUpper = 2; - -} - -/** - * A layer that copies a tensor setting everything outside upper triangular to zero. - * - * Requires 1 input and produces 1 output. - * - * Output shape is same as the input shape. - * Rank of the input must be at least 2. - * For rank higher than 2, the last 2 dimensions are treated as the matrix, while the rest are treated as batch. - */ -message UpperTriangularLayerParams { - - int64 k = 1; // Diagonal below which to zero elements. k = 0 (the default) is the main diagonal, k < 0 is below it and k > 0 is above - -} - -/** - * A layer that copies a tensor setting everything outside lower triangular to zero. 
- * - * Requires 1 input and produces 1 output. - * - * Output shape is same as the input shape. - * Rank of the input must be at least 2. - * For rank higher than 2, the last 2 dimensions are treated as the matrix, while the rest are treated as batch. - */ -message LowerTriangularLayerParams { - - int64 k = 1; // Diagonal above which to zero elements. k = 0 (the default) is the main diagonal, k < 0 is below it and k > 0 is above - -} - -/** - * - * A layer that broadcasts a tensor to a new shape. - * - * Requires 2 inputs and produces 1 output. - * - * First input is broadcast to produce the output, while the second input is only - * used to determine the shape of the output. Values of second input are not used. - * - * Output is a tensor with the same shape as the second input. - * - */ -message BroadcastToLikeLayerParams { - -} - -/** - * - * A layer that broadcasts a tensor to a new shape. - * - * Requires 1 input and produces 1 output. - * - * Output tensor is the broadcasted version of the input and has shape as specified in the - * parameter "targetShape". - */ -message BroadcastToStaticLayerParams { - - repeated uint64 targetShape = 1; - -} - -/** - * - * A layer that broadcasts a tensor to a new shape. - * - * Requires 2 inputs and produces 1 output. - * - * First input is the one that is broadcasted to produce the output. - * Second input is a rank 1 tensor specifying the shape of the output. - * Output tensor has shape as specified by the values in the 2nd input tensor. - */ -message BroadcastToDynamicLayerParams { - -} - -/** - * A layer that performs element-wise addition operation with broadcast support. - * - * Requires 2 inputs and produces 1 output. - */ -message AddBroadcastableLayerParams { - -} - -/** - * A layer that performs element-wise maximum operation with broadcast support. - * - * Requires 2 inputs and produces 1 output. 
- */ -message MaxBroadcastableLayerParams { - -} - -/** - * A layer that performs element-wise minimum operation with broadcast support. - * - * Requires 2 inputs and produces 1 output. - */ -message MinBroadcastableLayerParams { - -} - -/** - * A layer that performs element-wise modular operation with broadcast support. - * - * Requires 2 inputs and produces 1 output. - */ -message ModBroadcastableLayerParams { - -} - -/** - * A layer that performs element-wise floor division operation with broadcast support. - * - * Requires 2 inputs and produces 1 output. - */ -message FloorDivBroadcastableLayerParams { - -} - -/** - * A layer that performs element-wise subtract operation with broadcast support. - * - * Requires 2 inputs and produces 1 output. - */ -message SubtractBroadcastableLayerParams { - -} - -/** - * A layer that performs element-wise multiply operation with broadcast support. - * - * Requires 2 inputs and produces 1 output. - */ -message MultiplyBroadcastableLayerParams { - -} - -/** - * A layer that performs element-wise division operation with broadcast support. - * - * Requires 2 inputs and produces 1 output. - */ -message DivideBroadcastableLayerParams { - -} - -/** - * Gather layer that gathers elements from the first input, along a specified axis, - * at indices specified in the second input. - * It is similar in functionality to the numpy.take method. - * - * Requires 2 inputs and produces 1 output. - * - * Given two inputs, 'data' and 'indices', gather the slices of 'data' - * and store into output. - * e.g. - * for i in [0, length(indices) - 1] - * output[i] = data[indices[i]] (1-D case, axis=0) - * - * if axis = 0: - * for each vector index (i,...,j) - * output[i,...,j,:,..,:] = data[indices[i,...,j],:,..,:] - * - * output.rank = (data.rank - 1) + indices.rank - * - * Negative indices and negative axis are supported. 
- * - * e.g: - * - * data shape = (2, 3) - * indices shape = (6, 8) - * axis = 0 - * output shape = (6, 8) + (3,) = (6, 8, 3) - * - * data shape = (2, 3, 5) - * indices shape = (6, 8) - * axis = 1 - * output shape = (2,) + (6, 8) + (5,) = (2, 6, 8, 5) - * - */ -message GatherLayerParams { - - int64 axis = 1; - -} - -/* - * Scatter accumulation mode. - */ -enum ScatterMode { - - SCATTER_UPDATE = 0; - SCATTER_ADD = 1; /// add - SCATTER_SUB = 2; /// subtract - SCATTER_MUL = 3; /// multiply - SCATTER_DIV = 4; /// divide - SCATTER_MAX = 5; /// maximum - SCATTER_MIN = 6; /// minimum - -} - -/* - * A layer that scatters data into a new tensor according to indices from the input. - * This is the inverse operation of Gather. - * - * Requires 3 inputs and produces 1 output. - * - * Output is initialized with the first input. - * Then updated with the values in the third input, at indices specified by the second input. - * - * An example when axis=0: - * Given three inputs, in order, "container", "indices", "updates", where - * - * - "container" is a rank R+1 tensor of shape [D_0, D_1, ..., D_R], which - * contains D_0 number of tensors, each with shape [D_1, ..., D_R]. - * - * - "indices" is a rank 1 tensor with shape [N], where N is the number of updates. - * The values in this tensor must be in the range [0, D_0 - 1]. (negative indexing is supported) - * - * - "updates" is a rank R+1 tensor with shape [N, D_1, ..., D_R], which represents - * a total number of N tensors, each of shape [D_1, ..., D_R]. 
- * - * The effect of this operation is as follows: - * - * output = container; - * For each i in 0, ..., N - 1 - * output[indices[i], :, ..., :] = updates[i, :, ..., :] // if mode == "SCATTER_UPDATE" - * - * or - * For each i in 0, ..., N - 1 - * output[indices[i], :, ..., :] += updates[i, :, ..., :] // if mode == "SCATTER_ADD" - * - * etc - * - * When "indices" is a tensor of rank greater than 1, the equation becomes (for axis=0): - * For each vector index (i,...,j) - * output[indices[i,...,j],...] -= updates[i,...,j,...] // if mode == "SCATTER_SUB" - * - * - * The output has the same shape as the first input. - * "indices" input must have rank less than or equal to the "updates" input and its shape - * must be a subset of the the shape of the "updates" input. - * - * e.g: - * - * container shape = (4, 3) - * indices shape = (5, 2, 3) - * updates shape = (4, 5, 2, 3) - * axis = 1 - * output shape = (4, 3) - * - * container shape = (4, 4, 3) - * indices shape = (6,) - * updates shape = (4, 6, 3) - * axis = -2 - * output shape = (4, 4, 3) - * - * container shape = (5,) - * indices shape = (5, 7, 5, 6) - * updates shape = (5, 7, 5, 6) - * axis = -1 - * output shape = (5,) - */ - -message ScatterLayerParams { - - int64 axis = 1; - ScatterMode mode = 2; /// mode of accumulation. - -} - -/** - * A layer that gathers elements from the first input, 'params', at the multi-indices specified - * by the second input, 'indices'. - * - * Requires 2 inputs and produces 1 output. - * - * 'params' = input[0], 'indices' = input[1] - * - * 'indices' is a rank K+1 tensor of shape [I_0, I_1, .., I_(K-1), I_K] which is viewed as a collection of - * indices of (I_0 * I_1 * ... * I_(K-1)) points in the I_K dimensional space. For instance, the multi-index of the first point - * is indices[0,0,...,0,:]. - * - * Here is how the output is constructed: - * - * for i = 0,1,...,(I_0-1) - * ... 
- * for j = 0,1,....,(I_(K-1)-1) - * output[i,....,j,:,:,..,:] = params[indices[i,...,j,:], :,:,..,:] - * - * Hence, output shape is [I_0, I_1,...,I(K-1)] + params.shape[I_K:] - * - * output.rank = indices.rank - 1 + params.rank - indices.shape[-1] - * - * e.g: - * - * input[0] shape = (4, 2, 3, 4) - * input[1] shape = (6, 2) - * output shape = (6,) + (3, 4) = (6, 3, 4) - * - * input[0] shape = (3, 3, 3, 4, 7) - * input[1] shape = (3, 5) - * output shape = (3,) + () = (3,) - * - * input[0] shape = (5, 3, 2, 5) - * input[1] shape = (2, 7, 3, 2) - * output shape = (2, 7, 3) + (2, 5) = (2, 7, 3, 2, 5) - * - */ -message GatherNDLayerParams { - -} - -/* - * A layer that scatters data into a new tensor according to multi-indices from the input. - * This is the inverse operation of GatherND. - * - * Requires 3 inputs and produces 1 output. - * 3 inputs, in order are denoted as "container", "indices", "updates". - * - * 'indices' is a rank K+1 tensor of shape [I_0, I_1, .., I_(K-1), I_K] which is viewed as a collection of - * indices of (I_0 * I_1 * ... * I_(K-1)) points in the I_K dimensional space. For instance, the multi-index of the first point - * is indices[0,0,...,0,:]. - * - * container.rank >= I_K - * updates.rank = K + (container.rank - I_K) - * shape of 'updates' = [I_0, I_1,...,I(K-1)] + container.shape[I_K:] - * - * output = container - * For each vector index (i,...,j) s.t. 0<=i shape: (3,) - * reps = N/A [Ignored] - * output shape = (2, 8, 12) - * - */ -message TileLayerParams { - - repeated uint64 reps = 1; - -} - -/** - * A layer that returns the shape of an input tensor. - * - * Requires 1 input and produces 1 output. - * - * Input: a tensor. - * Output: a vector of length R, where R is the rank of the input tensor - * Output is always a rank 1 tensor. - */ -message GetShapeLayerParams { - -} - -/** - * A layer that computes the Gauss error function, - * which is defined as: - * - * .. 
math:: - * f(x) = \dfrac{1}{\sqrt{\pi}}\int_{-x}^{x}{e^{-t^2}dt} - * - * Requires 1 input and produces 1 output. - * Output shape is same as the input. - */ -message ErfLayerParams { - -} - -/** - * A layer that evaluates the Gaussian Error Linear Unit (GELU) activation. - * Following equations are used to compute the activation based on the value of the "mode" parameter: - * - * mode == 'EXACT': - * .. math:: - * f(x) = 0.5x\left ( 1+\rm{erf}\left ( \frac{x}{\sqrt{2}} \right ) \right ) - * - * mode == 'TANH_APPROXIMATION': - * .. math:: - * f(x) = 0.5x\left ( 1+\rm{tanh}\left ( \sqrt{2/\pi}\left ( x + 0.044715x^3 \right ) \right ) \right ) - * - * mode == 'SIGMOID_APPROXIMATION': - * .. math:: - * f(x) = x*\rm{sigmoid}(1.702x) - * - * Requires 1 input and produces 1 output. - * Output shape is same as the input. - * - */ -message GeluLayerParams { - - enum GeluMode { - - EXACT = 0; - TANH_APPROXIMATION = 1; - SIGMOID_APPROXIMATION = 2; - - } - - GeluMode mode = 1; /// mode of GELU operation. - -} - -/** - * RangeStatic layer that returns a tensor that contains evenly spaced values. - * It is similar in functionality to the numpy.arange method. - * - * Requires no input and produces 1 output. - * Output is a rank 1 tensor. - */ -message RangeStaticLayerParams { - - float endValue = 1; - float startValue = 2; - float stepSizeValue = 3; - -} - -/** - * A layer that returns a tensor that contains evenly spaced values. - * Its functionality is similar to the numpy.arange method. - * - * Requires at least 1 input, up to a maximum of 3 inputs. - * Produces 1 output, which is a rank 1 tensor. - * - * Each input must be a scalar, or rank 1 and shape (1,). - * - * The first input represents the "endValue". - * The second input, if present, corresponds to "startValue". In this case the value of the "startValue" parameter is ignored. - * The third input, if present, corresponds to "stepSizeValue". In this case the value of the "stepSizeValue" parameter is ignored. 
- * - */ -message RangeDynamicLayerParams { - - float startValue = 2; - float stepSizeValue = 3; - -} - -/** - * A layer that returns a tensor containing all windows of size ``windowSize`` - * separated by ``step`` along the dimension ``axis``. - * - * .. code:: - * - * y = SlidingWindows(x) - * - * Requires 1 input and produces 1 output. - * - * Input - * An N-Dimensional tensor. - * - * Output - * An (N+1)-Dimensional tensor. - * - * This operation behaves as following: - * - if axis = 0 & input is rank 1 (L,). Output shape will be (M, W). - * - if axis = 1 & input is rank 3 (B1, L, C1). Output shape will be (B1, M, W, C1) - * - if axis = 2 & input is rank 5 (B1, B2, L, C1, C2) --> (B1 * B2, L, C1 * C2) --> (B1 * B2, M, W, C1 * C2). Output shape will be (B1, B2, M, W, C1, C2) - * - etc. - * where - * - L, C, B refer to input length, feature dimension length & batch size respectively - * - W is the window size. - * - M is the number of windows/slices calculated as M = (L - W) / step + 1 - */ -message SlidingWindowsLayerParams { - - int64 axis = 1; - uint64 windowSize = 2; - uint64 step = 3; - -} - -/** - * A layer that applies layer normalization over the input tensor. - * - * Requires 1 input and produces 1 output. - * - * output = gamma * (input - computed_mean) / (sqrt(computed_variance + eps)) + beta - * - * Parameters - * normalizedShape: subset of the input shape, along with layer norm is performed, rest of the input shape is treated as the batch dimension. The mean and variance are computed for the input, over the last few dimensions as specified by the normalizedShape parameter. - * gamma: must have shape = "normalizedShape" - * beta: must have shape = "normalizedShape" - * eps: small constant to avoid division by 0 - * - * Output shape is same as the input. 
- * - * e.g.: - * input shape = (10,5) - * normalized shape = (5,) or (10,5) - * - * input shape = (10,5,6,7) - * normalized shape = (7,) or (6,7) or (5,6,7) or (10,5,6,7) - */ -message LayerNormalizationLayerParams { - - repeated int64 normalizedShape = 1; - float eps = 2; - WeightParams gamma = 3; - WeightParams beta = 4; - -} - -/** - * Non maximum suppression (NMS) layer. - * Applies the non maximum suppression algorithm to input bounding box coordinates. - * The effect of this layer is similar to the functionality of the "NonMaximumSuppression" - * model type (for details please see NonMaximumSuppression.proto) with a couple of differences. - * One, this is a layer in a neural network model, whereas that is a different model type. Second, - * this layer supports a batch of bounding boxes. - * - * The NMS layer requires at least 2 inputs, and up to a maximum of 5 inputs. It produces 4 outputs. - * Following is the description of inputs and outputs: - * - * input 1, shape (B,N,4): coordinates of N boxes, for a batch size B. - * input 2, shape (B,N,C): class scores for each box. C can be 1 when there is only 1 score per box, i.e., no class specific score. - * - * input 3, optional, shape (1,): IoU threshold. When present, it overwrites the value provided in layer parameter "iouThreshold". - * input 4, optional, shape (1,): Score threshold. When present, it overwrites the value provided in layer parameter "scoreThreshold". - * input 5, optional, shape (1,): Maximum number of boxes. When present, it overwrites the value provided in layer parameter "maxBoxes". - * - * output 1, shape (B,maxBoxes,4): box coordinates, corresponding to the surviving boxes. - * output 2, shape (B,maxBoxes,C): box scores, corresponding to the surviving boxes. - * output 3, shape (B,maxBoxes): indices of the surviving boxes. Hence it will have values in the range [0,N-1], except for padding. - * output 4, shape (B,): number of boxes selected after the NMS algorithm, for each batch. 
- * - * When surviving boxes are less than "maxBoxes", the first 3 outputs are padded. - * For the first two outputs, the padding is done using values 0, whereas for the third output the - * padding value used is -1, since the output values represent indices. - * - * If no box survives, that is, all the scores are below the "scoreThreshold", - * then for that batch, number of boxes (value of the fourth output) will be 1. The first 3 outputs will - * correspond to the box with the highest score. This is to avoid generating an "empty" output. - * - * The four values that describe the box dimensions are (in order): - * - * - x (center location of the box along the horizontal axis) - * - y (center location of the box along the vertical axis) - * - width (size of box along the horizontal axis) - * - height (size of box on along the vertical axis) - * - * In each batch, - * the N scores for N boxes, used for suppression, are generated by taking the max of the matrix (N,C) - * along the columns. - * If "perClassSuppression" flag is false, suppression happens across all classes. - * If "perClassSuppression" flag is true, each box is assigned to the class with the highest - * score and then the suppression happens separately for boxes within the same class. - * - * Note that the 4th output can be used to dynamically slice the first 3 outputs, in case - * the padded outputs are not required. - * - */ -message NonMaximumSuppressionLayerParams { - /** - * The intersection over union (IoU) threshold over which boxes are suppressed. - */ - float iouThreshold = 1; - - /** - * Before IoU suppression is performed, boxes with class scores below this threshold are rejected. - */ - float scoreThreshold = 2; - - /** - * The maximum number of boxes to be given out as output. - * If the number of surviving boxes are less, output is padded up to this number. - */ - uint64 maxBoxes = 3; - - /** - * If true, suppression is performed independently within boxes of each class. 
- */ - bool perClassSuppression = 4; -} - -/** - * A layer that performs element-wise clamped ReLU operation. - * - * Requires 1 input and produces 1 output. - * - * This function has the following formula: - * - * .. math:: - * f(x) = \begin{cases} - * \text{min}(\text{beta},x) \;\; \text{if} \;\; x \geq 0\\ - * \text{min}(\text{beta} ,\text{alpha}\cdot x) \;\; \text{if} \;\; x<0 - * \end{cases} - * - * Output shape is same as the input. - * - * Available (iOS >= 14, macOS >= 11.0, watchOS >= 7) - */ -message ClampedReLULayerParams { - - float alpha = 1; - float beta = 2; - -} - -/** -* A layer that returns the indices that would sort the input tensor, along a specified axis. -* -* Requires 1 input and produces 1 output. -* -* Output has the same rank and shape as the input. -* -* Value of "axis" must be positive and less than the rank of the input. -* -* e.g.: -* -* input shape = (5,) -* axis = 0 -* input values = [3.1, 5.4, 32.9, 3.2, 77.0] -* output shape = (5,) -* output values = [0, 3, 1, 2, 4], descending = False -* output values = [4, 2, 1, 3, 0], descending = True -* -* input shape = (2,3) -* axis = 1 -* input values = [[3, 5, 32], [3, 77, 6]] -* output shape = (2,3) -* output values = [[0, 1, 2], [0, 2, 1]], descending = False -* output values = [[2, 1, 0], [1, 2, 0]], descending = True -* -*/ -message ArgSortLayerParams { - - int64 axis = 1; /// must be between [0, input_rank - 1] - bool descending = 2; - -} - -/** - * A layer that does slice operation by providing size to be extracted - * from the given input tensor. - * - * Requires 2 inputs and produces 1 output. - * Rank of the output is same as the rank of the first input. - * - * The 1st input represents the tensor to be sliced. - * The 2nd input represents the beginning index to be sliced from. 
- * - * Example: - * Input 1: x (x.shape = (2, 3, 4)) - * Input 2: begin - * size: 2 - * axis: 1 - * - * Output: x[:, begin:begin+2, :] - * - */ -message SliceBySizeLayerParams { - - int64 size = 2; - int64 axis = 3; - -} - - -/// Neural Network Specializations -/// ------------------------------ - -/** - * A neural network specialized as a classifier. - */ -message NeuralNetworkClassifier { - - repeated NeuralNetworkLayer layers = 1; - repeated NeuralNetworkPreprocessing preprocessing = 2; - - // use this enum value to determine the input tensor shapes to the neural network, for multiarray inputs - NeuralNetworkMultiArrayShapeMapping arrayInputShapeMapping = 5; - - // use this enum value to determine the input tensor shapes to the neural network, for image inputs - NeuralNetworkImageShapeMapping imageInputShapeMapping = 6; - - NetworkUpdateParameters updateParams = 10; - - // The set of labels for every possible class. - oneof ClassLabels { - StringVector stringClassLabels = 100; - Int64Vector int64ClassLabels = 101; - } - - // The name of the output blob containing the probability of each class. - // In other words, the score vector. Must be a 1-D tensor with the same - // number and order of elements as ClassLabels. - string labelProbabilityLayerName = 200; -} - - -/** - * A layer that computes the one hot representation of the input. - * - * Requires 1 or 2 inputs and produces 1 output. - * Rank of the output is one more than the first input. - * If the second input is present, it is used to determine the value of "oneHotVectorSize" and the parameter "oneHotVectorSize" is ignored. - * - * Input values correspond to indices and should typically be in the range [0,"oneHotVectorSize" -1]. If it is outside this range, a vector of all "offValue" will be chosen. - * - * Typically one hot vectors contain 0s everywhere, except 1 at the index that the input corresponds to. - * However, instead of 0, any float value could be generated by using the "offValue" parameter. 
- * Similarly, instead of 1, any other value can be used by employing the "onValue" parameter. - * - * e.g.: - * input shape: (10,), "oneHotVectorSize" : 32, axis=-1, then output shape will be (10,32) - * input shape: (10,23), "oneHotVectorSize" : 32, axis=1, then output shape will be (10,32,23) - * input shape: (10,), "oneHotVectorSize" : 32, axis=0, then output shape will be (32,10) - * - * input shape: (2,), "oneHotVectorSize" : 4, axis=-1, then output shape will be (2,4) - * say input values = [2, 0], and "onValue" = 5, and "offValue" = -1, then output will be: - * [-1, -1, 5, -1 - * 5, -1, -1, -1] - * - * say input values = [2, -1], and "onValue" = 5, and "offValue" = -1, then output will be: - * [-1, -1, 5, -1 - * -1, -1, -1, -1] - * - * Available (iOS >= 14, macOS >= 11.0, watchOS >= 7) - */ - -message OneHotLayerParams { - - uint64 oneHotVectorSize = 1; /// size of the one hot vector - int64 axis = 2; /// negative indexing is supported. It refers to the axis in the output tensor. - float onValue = 3; - float offValue = 4; -} - - -/** - * A layer that computes the cumsum values of the input along a given axis. - * - * Requires 1 or 2 inputs and produces 1 output. - * - * Output shape and rank is same as the first input. - * If the second input is present, it is used to determine the value of "axis" and the parameter "axis" is ignored. 
- * - * e.g.: - * Input shape = (3,), values it has: [4, 6, 7] - * - * Then output values will be: - * - * if "excludeFinalSum" = False and "reverse" = False: - * output values : [4, 10, 17] - * - * if "excludeFinalSum" = True and "reverse" = False: - * output values : [0, 4, 10] - * - * if "excludeFinalSum" = False and "reverse" = True: - * output values : [17, 13, 7] - * - * if "excludeFinalSum" = True and "reverse" = True: - * output values : [13, 7, 0] - * - * - * Available (iOS >= 14, macOS >= 11.0, watchOS >= 7) - */ - - -message CumSumLayerParams { - - int64 axis = 1; /// negative indexing is supported - - /// if true, the first element of the output is 0, and the last element contains the sum of the input up to the penultimate value - /// if false, the first element of the output is same as the input and the last element is the sum of all the input values - /// (this behavior is reversed when "reverse" flag is True) - bool excludeFinalSum = 2; - - bool reverse = 3; /// if true, cumsum is performed in the opposite direction -} - - -/** - * A neural network specialized as a regressor. 
- */ -message NeuralNetworkRegressor { - - repeated NeuralNetworkLayer layers = 1; - repeated NeuralNetworkPreprocessing preprocessing = 2; - - // use this enum value to determine the input tensor shapes to the neural network, for multiarray inputs - NeuralNetworkMultiArrayShapeMapping arrayInputShapeMapping = 5; - - // use this enum value to determine the input tensor shapes to the neural network, for image inputs - NeuralNetworkImageShapeMapping imageInputShapeMapping = 6; - - NetworkUpdateParameters updateParams = 10; - -} - -/// --------------------------------------------------------- -/// On-device Training related messages -/// --------------------------------------------------------- - -/** - * Details on how the network will be updated - */ -message NetworkUpdateParameters { - - repeated LossLayer lossLayers = 1; - Optimizer optimizer = 2; - Int64Parameter epochs = 3; - - /** - * Describes whether to shuffle the batch of data between epochs. - */ - BoolParameter shuffle = 10; - - /** - * The seed to be used in an associated random number generator. - */ - Int64Parameter seed = 20; -} - -/** - * Loss layer - categorical cross entropy and mean squared error are the only supported loss functions currently - */ -message LossLayer { - - string name = 1; - oneof LossLayerType { - - CategoricalCrossEntropyLossLayer categoricalCrossEntropyLossLayer = 10; - MeanSquaredErrorLossLayer meanSquaredErrorLossLayer = 11; - - } - -} - -/** - * Categorical cross entropy loss layer - * Categorical cross entropy is used for single label categorization (only one category is applicable for each data point). - * - * The input is a vector of length N representing the distribution over N categories. It must be the output of a softmax. - * - * The target is a single value representing the true category or class label. If the target is the predictedFeatureName of a neural network classifier it will be inverse mapped to the corresponding categorical index for you. 
- * - * math: - * Loss_{CCE}(input, target) = -\sum_{i=1}^{N} (target == i) log( input[i] ) = - log (input[target]) - */ -message CategoricalCrossEntropyLossLayer { - - string input = 1; - string target = 2; - -} - -/** - * Mean squared error loss layer, - * specifying input and target - */ -message MeanSquaredErrorLossLayer { - - string input = 1; - string target = 2; - -} - -/** - * Optimizer - stochastic gradient descent and adam are the only supported optimizers currently - */ -message Optimizer { - - oneof OptimizerType { - - SGDOptimizer sgdOptimizer = 10; - AdamOptimizer adamOptimizer = 11; - - } - -} - -/** - * Stochastic gradient descent optimizer, - * specifying configurable learning rate, mini batch size, and momentum - */ -message SGDOptimizer { - - DoubleParameter learningRate = 1; - Int64Parameter miniBatchSize = 2; - DoubleParameter momentum = 3; - -} - -/** - * Adam optimizer, - * specifying configurable learning rate, mini batch size, betas, and eps - */ -message AdamOptimizer { - - DoubleParameter learningRate = 1; - Int64Parameter miniBatchSize = 2; - DoubleParameter beta1 = 3; - DoubleParameter beta2 = 4; - DoubleParameter eps = 5; - -} diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/NonMaximumSuppression.proto b/onnxruntime/core/providers/coreml/mlmodel_format/NonMaximumSuppression.proto deleted file mode 100644 index c98949a0c2e21..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/NonMaximumSuppression.proto +++ /dev/null @@ -1,187 +0,0 @@ -// Copyright (c) 2018, Apple Inc. All rights reserved. -// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -import public "DataStructures.proto"; - -package CoreML.Specification; - -/* -* Non-maximum suppression of axis-aligned bounding boxes. 
-* -* This is used primarily for object detectors that tend to produce multiple -* boxes around a single object. This is a byproduct of the detector's -* robustness to spatial translation. If there are two or more bounding boxes -* that are very similar to one another, the algorithm should return only a -* single representative. -* -* Similarity between two bounding boxes is measured by intersection-over-union -* (IOU), the fraction between the area of intersection and area of the union. -* Here is an example where the areas can be calculated by hand by counting glyphs:: -* -* +-------+ +-------+ -* | | | | -* | +------+ +--+ | +---+ -* | | | | | | | | -* +-------+ | +--+ +----+ | -* | | | | -* +------+ +------+ -* Intersection Union -* IOU: 0.16 = 12 / 73 -* -* All IOU scores are fractions betwen 0.0 (fully disjoint) and 1.0 (perfect -* overlap). The standard algorithm (PickTop) is defined as follows: -* -* 1. Sort boxes by descending order of confidence -* 2. Take the top one and mark it as keep -* 3. Suppress (mark it as discard) all boxes within a fixed IOU radius of the -* keep box -* 4. Go to 2 and repeat on the subset of boxes not already kept or discarded -* 5. When all boxes are processed, output only the ones marked as keep -* -* Before the algorithm, boxes that fall below the confidence threshold are -* discarded. -*/ -message NonMaximumSuppression { - // Suppression methods: - /* - * Pick the bounding box of the top confidence, suppress all within a radius. - */ - message PickTop { - /* - * Suppression is only done among predictions with the same label - * (argmax of the confidence). - */ - bool perClass = 1; - } - - /* - * Choose which underlying suppression method to use - */ - oneof SuppressionMethod { - PickTop pickTop = 1; - } - - /* - * Optional class label mapping. - */ - oneof ClassLabels { - StringVector stringClassLabels = 100; - Int64Vector int64ClassLabels = 101; - } - - /* - * This defines the radius of suppression. 
A box is considered to be within - * the radius of another box if their IOU score is less than this value. - */ - double iouThreshold = 110; - - /* - * Remove bounding boxes below this threshold. The algorithm run-time is - * proportional to the square of the number of incoming bounding boxes - * (O(N^2)). This threshold is a way to reduce N to make the algorithm - * faster. The confidence threshold can be any non-negative value. Negative - * confidences are not allowed, since if the output shape is specified to be - * larger than boxes after suppression, the unused boxes are filled with - * zero confidence. If the prediction is handled by Core Vision, it is also - * important that confidences are defined with the following semantics: - * - * 1. Confidences should be between 0 and 1 - * 2. The sum of the confidences for a prediction should not exceed 1, but is - * allowed to be less than 1 - * 3. The sum of the confidences will be interpreted as the confidence of - * any object (e.g. if the confidences for two classes are 0.2 and 0.4, - it means there is a 60% (0.2 + 0.4) confidence that an object is - present) - */ - double confidenceThreshold = 111; - - /* - * Set the name of the confidence input. - * - * The input should be a multi-array of type double and shape N x C. N is - * the number of boxes and C the number of classes. Each row describes the - * confidences of each object category being present at that particular - * location. Confidences should be nonnegative, where 0.0 means the highest - * certainty the object is not present. - * - * Specifying shape is optional. - */ - string confidenceInputFeatureName = 200; - - /* - * Set the name of the coordinates input. - * - * The input should be a multi-array of type double and shape N x 4. The - * rows correspond to the rows of the confidence matrix. 
The four values - * describe (in order): - * - * - x (center location of the box along the horizontal axis) - * - y (center location of the box along the vertical axis) - * - width (size of box along the horizontal axis) - * - height (size of box on along the vertical axis) - * - * Specifying shape is optional. - */ - string coordinatesInputFeatureName = 201; - - /* - * The iouThreshold can be optionally overridden by specifying this string - * and providing a corresponding input of type double. This allows changing - * the value of the parameter during run-time. - * - * The input should be a scalar double between 0.0 and 1.0. Setting it to 1.0 - * means there will be no suppression based on IOU. - */ - string iouThresholdInputFeatureName = 202; - - /* - * The confidenceThreshold can be optionally overridden by specifying this - * string and providing a corresponding input. This allows changing the - * value of the parameter during run-time, which can aid setting it just - * right for a particular use case. - * - * The input should be a scalar double with nonnegative value. - */ - string confidenceThresholdInputFeatureName = 203; - - /* - * Set the name of the confidence output. The output will be the same type - * and shape as the corresponding input. The only difference is that the - * number of rows may have been reduced. - * - * Specifying shape is optional. One reason to specify shape is to limit - * the number of output boxes. This can be done is several ways: - * - * Fixed shape: - * The output can be pinned to a fixed set of boxes. If this number is larger - * than the number of boxes that would have been returned, the output is padded - * with zeros for both confidence and coordinates. Specifying a fixed shape - * can be done by setting either shape (deprecated) or allowedShapes set to - * fixedsize. - * - * Min/max: - * It is also possible to set both a minimum and a maximum. The same zero-padding - * as for fixed shape is applied when necessary. 
Setting min/max is done by defining - * two allowedShapes, where the first dimension uses a rangeofsizes defining lowerbound - * and upperbound. - */ - string confidenceOutputFeatureName = 210; - - /* - * Set the name of the coordinates output. The output will be the same type - * and shape as the corresponding input. The only difference is that the - * number of rows may have been reduced. - * - * Specifying shape is optional. See confidence output for a more detailed - * description. Note that to achieve either fixed shape output or a - * constraint range of boxes, only one of confidence or coordinates need to - * set a shape. Both shapes are allowed to be defined, but in such case they - * have to be consistent along dimension 0. - */ - string coordinatesOutputFeatureName = 211; -} diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/Normalizer.proto b/onnxruntime/core/providers/coreml/mlmodel_format/Normalizer.proto deleted file mode 100644 index 627f7e2e3afd7..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/Normalizer.proto +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) 2017, Apple Inc. All rights reserved. -// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -package CoreML.Specification; - -/** - * A normalization preprocessor. - */ -message Normalizer { - /** - * There are three normalization modes, - * which have the corresponding formulas: - * - * Max - * .. math:: - * max(x_i) - * - * L1 - * .. math:: - * z = ||x||_1 = \sum_{i=1}^{n} |x_i| - * - * L2 - * .. 
math:: - * z = ||x||_2 = \sqrt{\sum_{i=1}^{n} x_i^2} - */ - enum NormType { - LMax = 0; - L1 = 1; - L2 = 2; - } - - NormType normType = 1; -} diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/OneHotEncoder.proto b/onnxruntime/core/providers/coreml/mlmodel_format/OneHotEncoder.proto deleted file mode 100644 index f47cf28166222..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/OneHotEncoder.proto +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright (c) 2017, Apple Inc. All rights reserved. -// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -import public "DataStructures.proto"; - -package CoreML.Specification; - -/** - * Transforms a categorical feature into an array. The array will be all - * zeros expect a single entry of one. - * - * Each categorical value will map to an index, this mapping is given by - * either the ``stringCategories`` parameter or the ``int64Categories`` - * parameter. - */ -message OneHotEncoder { - enum HandleUnknown { - ErrorOnUnknown = 0; - IgnoreUnknown = 1; // Output will be all zeros for unknown values. - } - - /** - * Mapping to be used for the encoding. The position of the category in - * the below vector determines where the single one entry will be in the - * output. - */ - oneof CategoryType { - StringVector stringCategories = 1; - Int64Vector int64Categories = 2; - } - - // Output can be a dictionary with only one entry, instead of an array. 
- bool outputSparse = 10; - - HandleUnknown handleUnknown = 11; -} diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/Parameters.proto b/onnxruntime/core/providers/coreml/mlmodel_format/Parameters.proto deleted file mode 100644 index ed1ebe525181f..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/Parameters.proto +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright (c) 2017, Apple Inc. All rights reserved. -// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -import public "DataStructures.proto"; - -package CoreML.Specification; - -/** - * Int64 parameter, - * consisting of a default int64 value, and allowed range or set of values - * value is unbounded if AllowedValues is not set. - */ -message Int64Parameter { - int64 defaultValue = 1; - oneof AllowedValues { - Int64Range range = 10; - Int64Set set = 11; - } -} - -/** - * Double parameter, - * consisting of a default double value, and allowed range of values - * value is unbounded if AllowedValues is not set. - */ -message DoubleParameter { - double defaultValue = 1; - oneof AllowedValues { - DoubleRange range = 10; - } -} - -/** - * String parameter, - * A default string value must be provided - */ -message StringParameter { - string defaultValue = 1; -} - -/** - * String parameter, - * A default bool value must be provided - */ -message BoolParameter { - bool defaultValue = 1; -} diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/README.md b/onnxruntime/core/providers/coreml/mlmodel_format/README.md deleted file mode 100644 index e5eba65f982ad..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/README.md +++ /dev/null @@ -1,16 +0,0 @@ -# Core ML Model Format Specification -This directory contains the protobuf message definitions that comprise the Core ML model document (``.mlmodel``) format. 
- -The top-level message is ``Model``, which is defined in ``Model.proto``. -Other message types describe data structures, feature types, feature engineering model types, and predictive model types. - -# Update the Core ML Model Format Specification -Please do not modify protobuf message definitions, they are copied directly from [Core ML Tools](https://github.com/apple/coremltools) repository. - -To update the Core ML Model Format Schema schema files to a more recent version: -1. Delete all the protobuf message definitions (`.proto`) from this directory. -2. Copy the new version of protobuf message definitions (`.proto`) from the `mlmodel/format/` directory of preferred coremltools release branch. - -# Core ML Model Format Schema version history -## [coremltools 4.0](https://github.com/apple/coremltools/releases/tag/4.0) -[Core ML Model Format Specification](https://github.com/apple/coremltools/tree/4.0/mlmodel/format) diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/SVM.proto b/onnxruntime/core/providers/coreml/mlmodel_format/SVM.proto deleted file mode 100644 index 932a4ec216682..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/SVM.proto +++ /dev/null @@ -1,195 +0,0 @@ -// Copyright (c) 2017, Apple Inc. All rights reserved. -// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -import public "DataStructures.proto"; - -package CoreML.Specification; - -/// Kernel Definitions -/// ------------------ - -/** - * A linear kernel. - * - * This function has the following formula: - * - * .. math:: - * K(\boldsymbol{x}, \boldsymbol{x'}) = \boldsymbol{x}^T \boldsymbol{x'} - */ -message LinearKernel { -} - -/** - * A Gaussian radial basis function (RBF) kernel. - * - * This function has the following formula: - * - * .. 
math:: - * K(\boldsymbol{x}, \boldsymbol{x'}) = \ - * \exp(-\gamma || \boldsymbol{x} - \boldsymbol{x'} ||^2 ) - * - */ -message RBFKernel { - double gamma = 1; -} - -/** - * A polynomial kernel. - * - * This function has the following formula: - * - * .. math:: - * K(\boldsymbol{x}, \boldsymbol{x'}) = \ - * (\gamma \boldsymbol{x}^T \boldsymbol{x'} + c)^{degree} - */ -message PolyKernel { - int32 degree = 1; - double c = 2; - double gamma = 3; -} - -/** - * A sigmoid kernel. - * - * This function has the following formula: - * - * .. math:: - * K(\boldsymbol{x}, \boldsymbol{x'}) = \ - * \tanh(\gamma \boldsymbol{x}^T \boldsymbol{x'} + c) - */ -message SigmoidKernel { - double gamma = 1; - double c = 2; -} - -/** - * A kernel. - */ -message Kernel { - oneof kernel { - LinearKernel linearKernel = 1; - RBFKernel rbfKernel = 2; - PolyKernel polyKernel = 3; - SigmoidKernel sigmoidKernel = 4; - } -} - - -/// Support Vector Definitions -/// -------------------------- - -/** - * A sparse node. - */ -message SparseNode { - int32 index = 1; // 1-based indexes, like libsvm - double value = 2; -} - -/** - * A sparse vector. - */ -message SparseVector { - repeated SparseNode nodes = 1; -} - -/** - * One or more sparse support vectors. - */ -message SparseSupportVectors { - repeated SparseVector vectors = 1; -} - -/** - * A dense vector. - */ -message DenseVector { - repeated double values = 1; -} - -/** - * One or more dense support vectors. - */ -message DenseSupportVectors { - repeated DenseVector vectors = 1; -} - -/** - * One or more coefficients. - */ -message Coefficients { - repeated double alpha = 1; -} - -/** - * A support vector regressor. 
- */ -message SupportVectorRegressor { - Kernel kernel = 1; - - // Support vectors, either sparse or dense format - oneof supportVectors { - SparseSupportVectors sparseSupportVectors = 2; - DenseSupportVectors denseSupportVectors = 3; - } - - // Coefficients, one for each support vector - Coefficients coefficients = 4; - - double rho = 5; -} - -/** - * A support vector classifier - */ -message SupportVectorClassifier { - Kernel kernel = 1; - - /** - * The number of support vectors for each class. - */ - repeated int32 numberOfSupportVectorsPerClass = 2; - - /** - * The support vectors, in either sparse or dense format. - */ - oneof supportVectors { - SparseSupportVectors sparseSupportVectors = 3; - DenseSupportVectors denseSupportVectors = 4; - } - - /** - * The coefficients, essentially a two dimensional array of - * size: (numberOfClasses-1) by (total number of support vectors) - */ - repeated Coefficients coefficients = 5; - - /** - * Constants for decision function, - * with K*(K-1) / 2 elements, - * where K is the number of classes. - */ - repeated double rho = 6; - - /** - * Pairwise probability information for A vs B classifier. - * Total of K*(K-1)/2 elements where K is the number of classes. - * These fields are optional, - * and only required if you want probabilities or multi class predictions. - */ - repeated double probA = 7; - repeated double probB = 8; - - /** - * Class label mapping. - */ - oneof ClassLabels { - StringVector stringClassLabels = 100; - Int64Vector int64ClassLabels = 101; - } -} diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/Scaler.proto b/onnxruntime/core/providers/coreml/mlmodel_format/Scaler.proto deleted file mode 100644 index f0e13d54be2e8..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/Scaler.proto +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright (c) 2017, Apple Inc. All rights reserved. 
-// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -package CoreML.Specification; - -/** - * A scaling operation. - * - * This function has the following formula: - * - * .. math:: - * f(x) = scaleValue \cdot (x + shiftValue) - * - * If the ``scaleValue`` is not given, the default value 1 is used. - * If the ``shiftValue`` is not given, the default value 0 is used. - * - * If ``scaleValue`` and ``shiftValue`` are each a single value - * and the input is an array, then the scale and shift are applied - * to each element of the array. - * - * If the input is an integer, then it is converted to a double to - * perform the scaling operation. If the output type is an integer, - * then it is cast to an integer. If that cast is lossy, then an - * error is generated. - */ -message Scaler { - repeated double shiftValue = 1; - repeated double scaleValue = 2; -} diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/SoundAnalysisPreprocessing.proto b/onnxruntime/core/providers/coreml/mlmodel_format/SoundAnalysisPreprocessing.proto deleted file mode 100644 index 05bb744a9af94..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/SoundAnalysisPreprocessing.proto +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright (c) 2019, Apple Inc. All rights reserved. 
-// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -package CoreML.Specification.CoreMLModels; - -/** -* A model which takes audio signal samples as input and outputs an array of -* preprocessed samples according to the specified preprocessing types -*/ -message SoundAnalysisPreprocessing { - - // Specific preprocessing types for sound analysis - - /* Vggish preprocesses input audio samples and makes them ready to - be fed to Vggish feature extractor. - c.f. https://arxiv.org/pdf/1609.09430.pdf - - The preprocessing takes input a single channel (monophonic) audio samples - 975 miliseconds long, sampled at 16KHz, i.e., 15600 samples 1D multiarray - and produces preprocessed samples in multiarray of shape [1, 96, 64] - - (1) Splits the input audio samples into overlapping frames, where each - frame is 25 milliseconds long and hops forward by 10 milliseconds. - Any partial frames at the end are dropped. - - (2) Hann window: apply a periodic Hann with a window_length of - 25 milliseconds, which translates to 400 samples in 16KHz sampling rate - - w(n) = 0.5 - 0.5 * cos(2*pi*n/window_length_sample), - where 0 <= n <= window_lenth_samples - 1 and window_lenth_samples = 400 - - Then, the Hann window is applied to each frame as below - - windowed_frame(n) = frame(n) * w(n) - where 0 <= n <= window_lenth_samples - 1 and window_lenth_samples = 400 - - (3) Power spectrum: calculate short-time Fourier transfor magnitude, with - an FFT length of 512 - - (4) Log Mel filter bank: calculates a log magnitude mel-frequency - spectrogram minimum frequency of 125Hz and maximum frequency of 7500Hz, - number of mel bins is 64, log_offset is 0.01, number of spectrum bins - is 64. 
- */ - - message Vggish { - // no specific parameter - } - - // Vision feature print type - oneof SoundAnalysisPreprocessingType { - Vggish vggish = 20; - } - -} diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/TextClassifier.proto b/onnxruntime/core/providers/coreml/mlmodel_format/TextClassifier.proto deleted file mode 100644 index bf6d3c7f7f3e5..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/TextClassifier.proto +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright (c) 2018, Apple Inc. All rights reserved. -// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -import public "DataStructures.proto"; - -package CoreML.Specification.CoreMLModels; - -/** -* A model which takes a single input string and outputs a -* label for the input. -*/ -message TextClassifier { - - /* - * Stores the resivion number for the model, revision 1 is available on - * iOS, tvOS 12.0+, macoOS 10.14+ - */ - uint32 revision = 1; - - /* - * Stores the language of the model, as specified in BCP-47 format, - * e.g. "en-US". See https://tools.ietf.org/html/bcp47 - */ - string language = 10; - - /* - * Stores the byte representation of learned model parameters - */ - bytes modelParameterData = 100; - - /* - * Stores the set of output class labels - */ - oneof ClassLabels { - StringVector stringClassLabels = 200; - } - -} diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/TreeEnsemble.proto b/onnxruntime/core/providers/coreml/mlmodel_format/TreeEnsemble.proto deleted file mode 100644 index defebee98852c..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/TreeEnsemble.proto +++ /dev/null @@ -1,161 +0,0 @@ -// Copyright (c) 2017, Apple Inc. All rights reserved. 
-// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -/** - * Each tree is a collection of nodes, - * each of which is identified by a unique identifier. - * - * Each node is either a branch or a leaf node. - * A branch node evaluates a value according to a behavior; - * if true, the node identified by ``true_child_node_id`` is evaluated next, - * if false, the node identified by ``false_child_node_id`` is evaluated next. - * A leaf node adds the evaluation value to the base prediction value - * to get the final prediction. - * - * A tree must have exactly one root node, - * which has no parent node. - * A tree must not terminate on a branch node. - * All leaf nodes must be accessible - * by evaluating one or more branch nodes in sequence, - * starting from the root node. - */ - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -import public "DataStructures.proto"; - -package CoreML.Specification; - -/** - * A tree ensemble post-evaluation transform. - */ -enum TreeEnsemblePostEvaluationTransform { - NoTransform = 0; - Classification_SoftMax = 1; - Regression_Logistic = 2; - Classification_SoftMaxWithZeroClassReference = 3; -} - -/** - * Tree ensemble parameters. - */ -message TreeEnsembleParameters { - message TreeNode { - uint64 treeId = 1; - uint64 nodeId = 2; - - enum TreeNodeBehavior { - BranchOnValueLessThanEqual = 0; - BranchOnValueLessThan = 1; - BranchOnValueGreaterThanEqual = 2; - BranchOnValueGreaterThan = 3; - BranchOnValueEqual = 4; - BranchOnValueNotEqual = 5; - LeafNode = 6; - } - - /** - * The branch mode parameters. - * - * If branch is false, - * then the parameters in this section must be filled in - * to determine how the branching functions. - */ - TreeNodeBehavior nodeBehavior = 3; - - /** - * If the node behavior mode is a branch mode, - * then these values must be filled in. 
- */ - uint64 branchFeatureIndex = 10; - double branchFeatureValue = 11; - uint64 trueChildNodeId = 12; - uint64 falseChildNodeId = 13; - bool missingValueTracksTrueChild = 14; - - /** - * The leaf mode. - * - * If ``nodeBahavior`` == ``LeafNode``, - * then the evaluationValue is added to the base prediction value - * in order to get the final prediction. - * To support multiclass classification - * as well as regression and binary classification, - * the evaluation value is encoded here as a sparse vector, - * with evaluationIndex being the index of the base vector - * that evaluation value is added to. - * In the single class case, - * it is expected that evaluationIndex is exactly 0. - */ - message EvaluationInfo { - uint64 evaluationIndex = 1; - double evaluationValue = 2; - } - - repeated EvaluationInfo evaluationInfo = 20; - - /** - * The relative hit rate of a node for optimization purposes. - * - * This value has no effect on the accuracy of the result; - * it allows the tree to optimize for frequent branches. - * The value is relative, - * compared to the hit rates of other branch nodes. - * - * You typically use a proportion of training samples - * that reached this node - * or some similar metric to derive this value. - */ - double relativeHitRate = 30; - } - - repeated TreeNode nodes = 1; - - /** - * The number of prediction dimensions or classes in the model. - * - * All instances of ``evaluationIndex`` in a leaf node - * must be less than this value, - * and the number of values in the ``basePredictionValue`` field - * must be equal to this value. - * - * For regression, - * this is the dimension of the prediction. - * For classification, - * this is the number of classes. - */ - uint64 numPredictionDimensions = 2; - - /** - * The base prediction value. - * - * The number of values in this must match - * the default values of the tree model. - */ - repeated double basePredictionValue = 3; -} - -/** - * A tree ensemble classifier. 
- */ -message TreeEnsembleClassifier { - TreeEnsembleParameters treeEnsemble = 1; - TreeEnsemblePostEvaluationTransform postEvaluationTransform = 2; - - // Required class label mapping - oneof ClassLabels { - StringVector stringClassLabels = 100; - Int64Vector int64ClassLabels = 101; - } -} - -/** - * A tree ensemble regressor. - */ -message TreeEnsembleRegressor { - TreeEnsembleParameters treeEnsemble = 1; - TreeEnsemblePostEvaluationTransform postEvaluationTransform = 2; -} diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/VisionFeaturePrint.proto b/onnxruntime/core/providers/coreml/mlmodel_format/VisionFeaturePrint.proto deleted file mode 100644 index cd13d290e421e..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/VisionFeaturePrint.proto +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright (c) 2018, Apple Inc. All rights reserved. -// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -package CoreML.Specification.CoreMLModels; - -/** -* A model which takes an input image and outputs array(s) of features -* according to the specified feature types -*/ -message VisionFeaturePrint { - - // Specific vision feature print types - - // Scene extracts features useful for identifying contents of natural images - // in both indoor and outdoor environments - message Scene { - enum SceneVersion { - SCENE_VERSION_INVALID = 0; - // VERSION_1 is available on iOS,tvOS 12.0+, macOS 10.14+ - // It uses a 299x299 input image and yields a 2048 float feature vector - SCENE_VERSION_1 = 1; - } - - SceneVersion version = 1; - } - - // Objects extracts features useful for identifying and localizing - // objects in natural images - message Objects { - enum ObjectsVersion { - OBJECTS_VERSION_INVALID = 0; - // VERSION_1 is available on iOS,tvOS 14.0+, macOS 11.0+ - // It uses a 299x299 input image and 
yields two multiarray - // features: one at high resolution of shape (288, 35, 35) - // the other at low resolution of shape (768, 17, 17) - OBJECTS_VERSION_1 = 1; - } - - ObjectsVersion version = 1; - - /* - * Stores the names of the output features according to the - * order of them being computed from the neural network, i.e., - * the first element in the output is the earliest being - * computed, while the last is the latest being computed. In - * general, the order reflects the resolution of the feature. - * The earlier it is computed, the higher the feature resolution. - */ - repeated string output = 100; - } - - // Vision feature print type - oneof VisionFeaturePrintType { - Scene scene = 20; - Objects objects = 21; - } - -} diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/WordEmbedding.proto b/onnxruntime/core/providers/coreml/mlmodel_format/WordEmbedding.proto deleted file mode 100644 index ec11a67ca5294..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/WordEmbedding.proto +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright (c) 2019, Apple Inc. All rights reserved. -// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -import public "DataStructures.proto"; - -package CoreML.Specification.CoreMLModels; - -/** -* A model which maps a set of strings into a finite-dimensional real vector space. -*/ -message WordEmbedding { - - /* - * Stores the revision number for the model, revision 2 is available on - * iOS, tvOS 13.0+, macOS 10.15+ - */ - uint32 revision = 1; - - /* - * Stores the language of the model, as specified in BCP-47 format, - * e.g. "en-US". 
See https://tools.ietf.org/html/bcp47 - */ - string language = 10; - - /* - * Stores efficient representation of emebedding as encoded by the Natural Language Framework - */ - bytes modelParameterData = 100; - -} diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/WordTagger.proto b/onnxruntime/core/providers/coreml/mlmodel_format/WordTagger.proto deleted file mode 100644 index 8523e05df2c0b..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/WordTagger.proto +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright (c) 2018, Apple Inc. All rights reserved. -// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -import public "DataStructures.proto"; - -package CoreML.Specification.CoreMLModels; - -/** -* A model which takes a single input string and outputs a -* sequence of tokens, tags for tokens, along with their -* locations and lengths, in the original string. -*/ -message WordTagger { - - /* - * Stores the resivion number for the model, revision 1 is available on - * iOS, tvOS 12.0+, macoOS 10.14+ - */ - uint32 revision = 1; - - /* - * Stores the language of the model, as specified in BCP-47 format, - * e.g. "en-US". See https://tools.ietf.org/html/bcp47 - */ - string language = 10; - - /* - * Stores the name of tokens output. The output will be - * a sequence of strings that contains the tokens in the - * input string - */ - string tokensOutputFeatureName = 20; - - /* - * Stores the name of token tags output. The output will be - * a sequence of strings that contains the tags for each - * token in the input string - */ - string tokenTagsOutputFeatureName = 21; - - /* - * Stores the name of token locations output. 
The output will be - * a sequence of integers that contains the locations (indices) - * for each token in the input string, location starts from 0 - */ - string tokenLocationsOutputFeatureName = 22; - - /* - * Stores the name of token lengths output. The output will be - * a sequence of integers that contains the lengths for each - * token in the input string - */ - string tokenLengthsOutputFeatureName = 23; - - /* - * Stores the byte representation of learned model parameters - */ - bytes modelParameterData = 100; - - /* - * Stores the set of output tags - */ - oneof Tags { - StringVector stringTags = 200; - } - - - -} - diff --git a/onnxruntime/core/providers/coreml/model/host_utils.h b/onnxruntime/core/providers/coreml/model/host_utils.h index f7f45bce087bc..a9991ccb945ce 100644 --- a/onnxruntime/core/providers/coreml/model/host_utils.h +++ b/onnxruntime/core/providers/coreml/model/host_utils.h @@ -8,10 +8,50 @@ #include -#define API_AVAILABLE_OS_VERSIONS API_AVAILABLE(macos(10.15), ios(13)) +#if defined(__APPLE__) +// See https://apple.github.io/coremltools/mlmodel/Format/Model.html for the info on each CoreML specification version. +// See https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html for the list of ops +// in each CoreML specification version. 
-// Base requireed OS to run CoreML Specification Version 4 (Core ML 3) -#define HAS_VALID_BASE_OS_VERSION @available(macOS 10.15, iOS 13, *) +// Specification Versions : OS Availability(Core ML Version) +// +// 4 : iOS 13, macOS 10.15, tvOS 13, watchOS 6 (Core ML 3) +// - initial version of CoreML EP +// 5 : iOS 14, macOS 11, tvOS 14, watchOS 7 (Core ML 4) +// - additional layers in NeuralNetwork but currently none are implemented by the CoreML EP +// 6 : iOS 15, macOS 12, tvOS 15, watchOS 8 (Core ML 5) +// - adds MLProgram (MILSpec.Program) +// - iOS 15 ops +// 7 : iOS 16, macOS 13, tvOS 16, watchOS 9 (Core ML 6) +// - iOS 16 ops +// 8 : iOS 17, macOS 14, tvOS 17, watchOS 10 (Core ML 7) +// - iOS 17 ops +// +// **NOTE** We use the Core ML version not the spec version. +// +// e.g. iOS 13 has Core ML 3 (which is Core ML Specification version 4), and the related macros are +// API_AVAILABLE_COREML3, HAS_COREML3_OR_LATER and onnxruntime::coreml::util::CoreMLVersion() will return 3. + +// https://developer.apple.com/documentation/swift/marking-api-availability-in-objective-c +// API_AVAILABLE is used to decorate Objective-C APIs +#define API_AVAILABLE_COREML3 API_AVAILABLE(macos(10.15), ios(13)) +#define API_AVAILABLE_COREML4 API_AVAILABLE(macos(11), ios(14)) +#define API_AVAILABLE_COREML5 API_AVAILABLE(macos(12), ios(15)) +#define API_AVAILABLE_COREML6 API_AVAILABLE(macos(13), ios(16)) +#define API_AVAILABLE_COREML7 API_AVAILABLE(macos(14), ios(17)) + +// @available is used in implementation code +// Base required OS to run CoreML Specification Version 4 (Core ML 3) +#define HAS_COREML3_OR_LATER @available(macOS 10.15, iOS 13, *) +#define HAS_COREML4_OR_LATER @available(macOS 11, iOS 14, *) +#define HAS_COREML5_OR_LATER @available(macOS 12, iOS 15, *) +#define HAS_COREML6_OR_LATER @available(macOS 13, iOS 16, *) +#define HAS_COREML7_OR_LATER @available(macOS 14, iOS 17, *) + +#endif + +#define MINIMUM_COREML_VERSION 3 // first version we support +#define 
MINIMUM_COREML_MLPROGRAM_VERSION 5 // first version where ML Program was available namespace onnxruntime { namespace coreml { @@ -21,9 +61,18 @@ namespace util { // This corresponds to [CoreML Specification Version 4 (Core ML 3)] bool HasRequiredBaseOS(); +// Return the CoreML version if 3 or higher. Otherwise returns -1. +int CoreMLVersion(); + // Get a temporary macOS/iOS temp file path std::string GetTemporaryFilePath(); +#if !defined(NDEBUG) && defined(__APPLE__) +// Override location the model is written to so that a) it's easily found and b) it is not automatically deleted +// when the EP exits. Use to debug the model that is generated. +// See onnxruntime/core/providers/coreml/dump_mlprogram_model.py for a script to dump the ML Program. +constexpr const char* kOverrideModelOutputDirectoryEnvVar = "ORT_COREML_EP_MODEL_DIR"; +#endif } // namespace util } // namespace coreml } // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/model/host_utils.mm b/onnxruntime/core/providers/coreml/model/host_utils.mm index 4c394386cd37a..5487ea35388f5 100644 --- a/onnxruntime/core/providers/coreml/model/host_utils.mm +++ b/onnxruntime/core/providers/coreml/model/host_utils.mm @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
+#include "core/platform/env.h" #include "core/providers/coreml/model/host_utils.h" #import @@ -10,19 +11,42 @@ namespace util { bool HasRequiredBaseOS() { - // This may look strange, but it is required "@available(macOS ....)" to safe-guard some code - // otherwise the compiler will spit -Wunsupported-availability-guard - if (HAS_VALID_BASE_OS_VERSION) - return true; - else - return false; + return CoreMLVersion() >= MINIMUM_COREML_VERSION; +} + +int CoreMLVersion() { // checks newest first so the highest Core ML version available on this OS is returned + if (HAS_COREML7_OR_LATER) + return 7; + if (HAS_COREML6_OR_LATER) + return 6; + if (HAS_COREML5_OR_LATER) + return 5; + if (HAS_COREML4_OR_LATER) + return 4; + if (HAS_COREML3_OR_LATER) + return 3; + + return -1; // OS is older than Core ML 3 (macOS 10.15 / iOS 13) } std::string GetTemporaryFilePath() { - // Get temporary directory. + // Get temporary directory for user. NSURL* temporary_directory_url = [NSURL fileURLWithPath:NSTemporaryDirectory() isDirectory:YES]; + +#if !defined(NDEBUG) + std::string path_override = Env::Default().GetEnvironmentVar(kOverrideModelOutputDirectoryEnvVar); + if (!path_override.empty()) { + NSString* ns_path_override = [NSString stringWithUTF8String:path_override.c_str()]; + temporary_directory_url = [NSURL fileURLWithPath:ns_path_override isDirectory:YES]; + } +#endif + // Generate a Unique file name to use. NSString* temporary_filename = [[NSProcessInfo processInfo] globallyUniqueString]; + + // make it easy to see who generated it + temporary_filename = [@"onnxruntime-" stringByAppendingString:temporary_filename]; + // Create URL to that file. NSURL* temporary_file_url = [temporary_directory_url URLByAppendingPathComponent:temporary_filename]; diff --git a/onnxruntime/core/providers/coreml/model/host_utils_stub.cc b/onnxruntime/core/providers/coreml/model/host_utils_stub.cc new file mode 100644 index 0000000000000..5c383b0274e8c --- /dev/null +++ b/onnxruntime/core/providers/coreml/model/host_utils_stub.cc @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License.
+ +#include + +#include "core/platform/env.h" +#include "core/providers/coreml/model/host_utils.h" + +namespace onnxruntime { +namespace coreml { +namespace util { + +bool HasRequiredBaseOS() { + return true; +} + +int CoreMLVersion() { + return 7; // CoreML 7 is the latest we support. +} + +std::string GetTemporaryFilePath() { + static std::atomic counter = 0; + + // we want to avoid creating endless directories/names whilst avoiding clashes if tests run in parallel so cycle + // through 20 potential output names. + auto dir_name = "coreml_ep_test_run." + std::to_string(counter++ % 20); + + // to replicate the iOS/macOS host_utils.mm behavior where the output is <temporary directory>/<unique name> + // we want to return the name of something that does not exist. this is required for ML Package creation. + auto& env = Env::Default(); + if (env.FolderExists(dir_name)) { + ORT_THROW_IF_ERROR(env.DeleteFolder(ToPathString(dir_name))); + } + + return dir_name; +} + +} // namespace util +} // namespace coreml +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/model/model.h b/onnxruntime/core/providers/coreml/model/model.h index 105b6a0333b15..e3cd43d786fc3 100644 --- a/onnxruntime/core/providers/coreml/model/model.h +++ b/onnxruntime/core/providers/coreml/model/model.h @@ -33,59 +33,62 @@ using GetOutputTensorMutableRawDataFn = std::function static_shape)>; class Model { - friend class ModelBuilder; - public: + Model(const std::string& path, + std::vector&& model_input_names, + std::vector&& model_output_names, + std::unordered_map&& input_output_info, + std::unordered_set&& scalar_outputs, + std::unordered_set&& int64_outputs, + const logging::Logger& logger, uint32_t coreml_flags); + ~Model(); ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Model); + Status LoadModel(); + Status Predict(const std::unordered_map& inputs, const std::unordered_map& outputs, const GetOutputTensorMutableRawDataFn& get_output_tensor_mutable_raw_data_fn); - bool IsScalarOutput(const std::string& output_name)
const; + bool IsScalarOutput(const std::string& output_name) const { + return Contains(scalar_outputs_, output_name); + } - bool IsInt64Output(const std::string& output_name) const; + bool IsInt64Output(const std::string& output_name) const { + return Contains(int64_outputs_, output_name); + } // Mutex for exclusive lock to this model object OrtMutex& GetMutex() { return mutex_; } - // Input and output names in the onnx model's order - const std::vector& GetOnnxInputs() const { return onnx_inputs_; } - void SetOnnxInputs(std::vector&& inputs) { onnx_inputs_ = std::move(inputs); } + // Input and output names in the ORT fused node's order. + // Names may have been adjusted from the originals due to CoreML naming rules. + // We do inputs/outputs based on order at the ONNX level so this doesn't matter. + const std::vector& GetOrderedInputs() const { return model_input_names_; } + const std::vector& GetOrderedOutputs() const { return model_output_names_; } - const std::vector& GetOnnxOutputs() const { return onnx_outputs_; } - void SetOnnxOutputs(std::vector&& outputs) { onnx_outputs_ = std::move(outputs); } + const OnnxTensorInfo* TryGetInputOutputInfo(const std::string& name) const { + const auto info_it = input_output_info_.find(name); + return info_it != input_output_info_.end() ? 
&info_it->second : nullptr; + } - const OnnxTensorInfo* TryGetInputOutputInfo(const std::string& name) const; - const OnnxTensorInfo& GetInputOutputInfo(const std::string& name) const; + const OnnxTensorInfo& GetInputOutputInfo(const std::string& name) const { + const auto* info = TryGetInputOutputInfo(name); + ORT_ENFORCE(info != nullptr, "Failed to get info for input/output: ", name); + return *info; + } private: std::unique_ptr execution_; - std::unordered_set scalar_outputs_; - std::unordered_set int64_outputs_; - - std::vector onnx_inputs_; - std::vector onnx_outputs_; + std::vector model_input_names_; // input names in the order of the ORT fused node's inputs + std::vector model_output_names_; // output names in the order of the ORT fused node's outputs std::unordered_map input_output_info_; + std::unordered_set scalar_outputs_; + std::unordered_set int64_outputs_; OrtMutex mutex_; - - Model(const std::string& path, const logging::Logger& logger, uint32_t coreml_flags); - Status LoadModel(); - - void SetInputOutputInfo(std::unordered_map&& input_output_info) { - input_output_info_ = std::move(input_output_info); - } - - void SetScalarOutputs(std::unordered_set&& scalar_outputs) { - scalar_outputs_ = std::move(scalar_outputs); - } - - void SetInt64Outputs(std::unordered_set&& int64_outputs) { - int64_outputs_ = std::move(int64_outputs); - } }; } // namespace coreml diff --git a/onnxruntime/core/providers/coreml/model/model.mm b/onnxruntime/core/providers/coreml/model/model.mm index 155201ad4c39c..1434043e064f4 100644 --- a/onnxruntime/core/providers/coreml/model/model.mm +++ b/onnxruntime/core/providers/coreml/model/model.mm @@ -19,6 +19,7 @@ #include "core/common/narrow.h" #include "core/common/span_utils.h" #include "core/graph/onnx_protobuf.h" +#include "core/platform/env.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/coreml_provider_factory.h" #include "core/providers/coreml/model/host_utils.h" @@ -252,14 +253,14 @@ - 
(instancetype)initWithPath:(const std::string&)path coreml_flags:(uint32_t)coreml_flags; - (void)cleanup; - (void)dealloc; -- (Status)loadModel API_AVAILABLE_OS_VERSIONS; +- (Status)loadModel API_AVAILABLE_COREML3; - (Status)predict:(const std::unordered_map&)inputs outputs:(const std::unordered_map&)outputs getOutputTensorDataFn:(const GetOutputTensorMutableRawDataFn&) get_output_tensor_mutable_raw_data_fn - API_AVAILABLE_OS_VERSIONS; + API_AVAILABLE_COREML3; -@property(nullable) MLModel* model API_AVAILABLE_OS_VERSIONS; +@property(nullable) MLModel* model API_AVAILABLE_COREML3; @end @@ -287,6 +288,14 @@ - (void)cleanup { compiled_model_path_ = nil; } +#if !defined(NDEBUG) + std::string path_override = Env::Default().GetEnvironmentVar(util::kOverrideModelOutputDirectoryEnvVar); + if (!path_override.empty()) { + // don't cleanup + coreml_model_path_ = nil; + } +#endif + if (coreml_model_path_ != nil) { error = nil; [[NSFileManager defaultManager] removeItemAtPath:coreml_model_path_ error:&error]; @@ -308,6 +317,10 @@ - (Status)loadModel { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create model URL from path"); } + // TODO: Update this to version with callback handler as the API used here is deprecated. + // https://developer.apple.com/documentation/coreml/mlmodel/3929553-compilemodelaturl + // As we call loadModel during EP Compile there shouldn't be an issue letting the actual compile run in the + // background. We will have to check for completion in `predict` and block until it is done. 
NSError* error = nil; NSURL* compileUrl = [MLModel compileModelAtURL:modelUrl error:&error]; @@ -454,7 +467,7 @@ Status Predict(const std::unordered_map& inputs, return Status::OK(); } - if (HAS_VALID_BASE_OS_VERSION) { + if (HAS_COREML3_OR_LATER) { Status status{}; @autoreleasepool { status = [execution_ loadModel]; @@ -471,7 +484,7 @@ Status Predict(const std::unordered_map& inputs, const GetOutputTensorMutableRawDataFn& get_output_tensor_mutable_raw_data_fn) { ORT_RETURN_IF_NOT(model_loaded, "Execution::Predict requires Execution::LoadModel"); - if (HAS_VALID_BASE_OS_VERSION) { + if (HAS_COREML3_OR_LATER) { @autoreleasepool { return [execution_ predict:inputs outputs:outputs @@ -482,8 +495,20 @@ Status Predict(const std::unordered_map& inputs, return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Execution::Predict requires macos 10.15+ or ios 13+"); } -Model::Model(const std::string& path, const logging::Logger& logger, uint32_t coreml_flags) - : execution_(std::make_unique(path, logger, coreml_flags)) { +Model::Model(const std::string& path, + std::vector&& model_input_names, + std::vector&& model_output_names, + std::unordered_map&& input_output_info, + std::unordered_set&& scalar_outputs, + std::unordered_set&& int64_outputs, + const logging::Logger& logger, + uint32_t coreml_flags) + : execution_(std::make_unique(path, logger, coreml_flags)), + model_input_names_(std::move(model_input_names)), + model_output_names_(std::move(model_output_names)), + input_output_info_(std::move(input_output_info)), + scalar_outputs_(std::move(scalar_outputs)), + int64_outputs_(std::move(int64_outputs)) { } Model::~Model() {} @@ -497,25 +522,5 @@ Status Predict(const std::unordered_map& inputs, const GetOutputTensorMutableRawDataFn& get_output_tensor_mutable_raw_data_fn) { return execution_->Predict(inputs, outputs, get_output_tensor_mutable_raw_data_fn); } - -bool Model::IsScalarOutput(const std::string& output_name) const { - return Contains(scalar_outputs_, output_name); -} - -bool 
Model::IsInt64Output(const std::string& output_name) const { - return Contains(int64_outputs_, output_name); -} - -const OnnxTensorInfo* Model::TryGetInputOutputInfo(const std::string& name) const { - const auto info_it = input_output_info_.find(name); - return info_it != input_output_info_.end() ? &info_it->second : nullptr; -} - -const OnnxTensorInfo& Model::GetInputOutputInfo(const std::string& name) const { - const auto* info = TryGetInputOutputInfo(name); - ORT_ENFORCE(info != nullptr, "Failed to get info for input/output: ", name); - return *info; -} - } // namespace coreml } // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/model/model_stub.cc b/onnxruntime/core/providers/coreml/model/model_stub.cc new file mode 100644 index 0000000000000..c6f2e7401ea1e --- /dev/null +++ b/onnxruntime/core/providers/coreml/model/model_stub.cc @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/coreml/model/model.h" + +namespace onnxruntime { +namespace coreml { + +class Execution {}; + +Model::Model(const std::string& /*path*/, + std::vector&& model_input_names, + std::vector&& model_output_names, + std::unordered_map&& input_output_info, + std::unordered_set&& scalar_outputs, + std::unordered_set&& int64_outputs, + const logging::Logger& /*logger*/, + uint32_t /*coreml_flags*/) + : execution_(std::make_unique()), + model_input_names_(std::move(model_input_names)), + model_output_names_(std::move(model_output_names)), + input_output_info_(std::move(input_output_info)), + scalar_outputs_(std::move(scalar_outputs)), + int64_outputs_(std::move(int64_outputs)) { +} + +Model::~Model() { +} + +Status Model::LoadModel() { + // return OK so we hit more CoreML EP code. 
+ return Status::OK(); +} + +Status Model::Predict(const std::unordered_map& /*inputs*/, + const std::unordered_map& /*outputs*/, + const GetOutputTensorMutableRawDataFn& /*get_output_tensor_mutable_raw_data_fn*/) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Executing a CoreML model is not supported on this platform."); +} + +} // namespace coreml +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index cbdf79caf3afd..c3d5a51b636ef 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "core/providers/cpu/cpu_execution_provider.h" +#include #include "core/framework/op_kernel.h" #include "core/framework/kernel_registry.h" #include "core/mlas/inc/mlas.h" @@ -29,7 +30,7 @@ CPUExecutionProvider::CPUExecutionProvider(const CPUExecutionProviderInfo& info) std::vector CPUExecutionProvider::CreatePreferredAllocators() { bool create_arena = info_.create_arena; -#if defined(USE_JEMALLOC) || defined(USE_MIMALLOC) +#if defined(USE_JEMALLOC) || defined(USE_MIMALLOC) || defined(ABSL_HAVE_ADDRESS_SANITIZER) // JEMalloc/mimalloc already have memory pool, so just use device allocator. 
create_arena = false; #elif !(defined(__amd64__) || defined(_M_AMD64) || defined(__aarch64__) || defined(_M_ARM64)) @@ -142,9 +143,6 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Aco class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Atan); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, float, Gemm); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, double, Gemm); -#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, Gemm); -#endif class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, Hardmax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, LogSoftmax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, double, LogSoftmax); @@ -334,9 +332,6 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, Flatten); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, float, Gemm); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, double, Gemm); -#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, MLFloat16, Gemm); -#endif class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, float, MatMul); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, double, MatMul); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, int32_t, MatMul); @@ -496,9 +491,6 @@ class 
ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Sp class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, ScatterND); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, Gemm); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, Gemm); -#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, Gemm); -#endif class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, GatherElements); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint8_t, BitShift); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint32_t, BitShift); @@ -605,9 +597,6 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, string, Expand); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Gemm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Gemm); -#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, MLFloat16, Gemm); -#endif class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, MatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, MatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, MatMul); @@ -725,6 +714,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDoma class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, float, IsNaN); class 
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, double, IsNaN); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, MLFloat16, IsNaN); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, BFloat16, IsNaN); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, bool, NonZero); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, NonZero); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, NonZero); @@ -1034,6 +1024,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, IsNaN); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, IsNaN); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, MLFloat16, IsNaN); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, BFloat16, IsNaN); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Gelu); #if !defined(DISABLE_FLOAT8_TYPES) class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FN, IsNaN); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FNUZ, IsNaN); @@ -2006,8 +1998,10 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { Greater)>, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, #if !defined(DISABLE_FLOAT8_TYPES) BuildKernelCreateInfo, @@ -2612,15 +2609,6 @@ Status RegisterFp16Kernels(KernelRegistry& kernel_registry) { 
MLFloat16, LeakyRelu)>, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc index bf73c59fb78ca..c4a83efa01a91 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc @@ -25,6 +25,7 @@ #include "core/providers/cpu/tensor/tile.h" #include "core/providers/cpu/tensor/gather_elements.h" #include "core/providers/cpu/tensor/unsqueeze.h" +#include "core/providers/cpu/tensor/upsamplebase.h" #ifndef DISABLE_CONTRIB_OPS #include "contrib_ops/cpu/bert/attention_base.h" @@ -62,6 +63,7 @@ #endif #include "cpu_provider_shared.h" +#include namespace onnxruntime { // The suppressed warning is: "The type with a virtual function needs either public virtual or protected nonvirtual destructor." 
@@ -292,6 +294,12 @@ struct ProviderHostCPUImpl : ProviderHostCPU { Status Sampling__Compute(const contrib::transformers::Sampling* p, OpKernelContext* ctx) override { return p->contrib::transformers::Sampling::Compute(ctx); } Status Sampling__SetupSubgraphExecutionInfo(contrib::transformers::Sampling* p, const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) override { return p->contrib::transformers::Sampling::SetupSubgraphExecutionInfo(session_state, attribute_name, subgraph_session_state); } + void UpsampleBase__AdjustOutputSizeAsPolicy(const UpsampleBase* p, TensorShapeVector& output_dims, + gsl::span input_dims, + InlinedVector& scales) const override { + p->AdjustOutputSizeAsPolicy(output_dims, input_dims, scales); + } + #ifdef ENABLE_ATEN Status ATen__Compute(const contrib::ATen* p, OpKernelContext* p_ctx) override { return p->ATen::Compute(p_ctx); } #endif diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.h b/onnxruntime/core/providers/cpu/cpu_provider_shared.h index f33eec4b93e98..c0e674827e4d1 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.h +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.h @@ -24,6 +24,7 @@ class SliceOp__PrepareForComputeMetadata; // Directly maps to SliceOp::PrepareF class UnsqueezeBase__Prepare; // Directly maps to UnsqueezeBase::Prepare class contrib__AdamWOptimizerBase__Prepare; class contrib__SGDOptimizerV2Base__Prepare; +class UpsampleBase; using PadsVector = InlinedVector; @@ -202,6 +203,10 @@ struct ProviderHostCPU { virtual Status Sampling__Compute(const contrib::transformers::Sampling* p, OpKernelContext* ctx) = 0; virtual Status Sampling__SetupSubgraphExecutionInfo(contrib::transformers::Sampling* p, const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) = 0; + virtual void UpsampleBase__AdjustOutputSizeAsPolicy(const UpsampleBase* p, TensorShapeVector& output_dims, + 
gsl::span input_dims, + InlinedVector& scales) const = 0; + #ifdef ENABLE_ATEN virtual Status ATen__Compute(const contrib::ATen* p, OpKernelContext* p_ctx) = 0; #endif diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_helper.cc b/onnxruntime/core/providers/cpu/ml/tree_ensemble_helper.cc index 180b3153fbb34..e2981da3a6f25 100644 --- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_helper.cc +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_helper.cc @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#if !defined(ORT_MINIMAL_BUILD) + #include "core/providers/cpu/ml/tree_ensemble_helper.h" #include "core/common/common.h" #include "onnx/defs/tensor_proto_util.h" @@ -64,3 +66,5 @@ Status GetVectorAttrsOrDefault(const OpKernelInfo& info, const std::string& name } // namespace ml } // namespace onnxruntime + +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_helper.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_helper.h index 3c8a5a840bc76..33172c343a88e 100644 --- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_helper.h +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_helper.h @@ -2,6 +2,9 @@ // Licensed under the MIT License. 
#pragma once + +#if !defined(ORT_MINIMAL_BUILD) + #include "core/common/common.h" #include "core/framework/op_kernel.h" @@ -13,3 +16,5 @@ Status GetVectorAttrsOrDefault(const OpKernelInfo& info, const std::string& name } // namespace ml } // namespace onnxruntime + +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/providers/cpu/nn/batch_norm_helper.h b/onnxruntime/core/providers/cpu/nn/batch_norm_helper.h index a5d46aff83b50..ccecbabfa3db3 100644 --- a/onnxruntime/core/providers/cpu/nn/batch_norm_helper.h +++ b/onnxruntime/core/providers/cpu/nn/batch_norm_helper.h @@ -25,6 +25,8 @@ class BatchNormHelper { const Tensor* var, bool is_spatial = true, bool is_nhwc = false) { + // NHWC dependent shape: X + // All other shapes are assumed to be in NCHW layout? const auto& x_dims = X->Shape().GetDims(); // If x_dims size < 2, num_channels defaults to 1. @@ -48,16 +50,22 @@ class BatchNormHelper { // validate 'scales' shape const auto& scale_dims = scale->Shape().GetDims(); if (static_cast(scale_dims.size()) != kNumInputScaleDimensions) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input scale: NumDimensions() != ", kNumInputScaleDimensions); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Invalid input scale: NumDimensions() != ", kNumInputScaleDimensions); } if (scale_dims[0] != num_channels) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input scale: 0th dimension != ", num_channels); } + // N & C do not belong to features + // skip the first element for NHWC and the first two elements for NCHW. + int feature_offset = is_nhwc ? 
1 : 2; + // in non-spatial cases - the other dims of 'scale' must be validated if (!is_spatial) { for (int feature = 0; feature < num_feature_dims; ++feature) { - if (scale_dims[1 + feature] != x_dims[2 + feature]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input scale: ", (1 + feature), " dimension != ", x_dims[2 + feature]); + if (scale_dims[1 + feature] != x_dims[feature_offset + feature]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input scale: ", (1 + feature), + " dimension != ", x_dims[feature_offset + feature]); } } } @@ -65,7 +73,8 @@ class BatchNormHelper { // validate 'B' shape const auto& B_dims = B->Shape().GetDims(); if (static_cast(B_dims.size()) != kNumInputBiasDimensions) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input B: NumDimensions() != ", kNumInputBiasDimensions); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Invalid input B: NumDimensions() != ", kNumInputBiasDimensions); } if (B_dims[0] != num_channels) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input B: 0th dimension != ", num_channels); @@ -73,8 +82,9 @@ class BatchNormHelper { // in non-spatial cases - the other dims of 'B' must be validated if (!is_spatial) { for (int feature = 0; feature < num_feature_dims; ++feature) { - if (B_dims[1 + feature] != x_dims[2 + feature]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input B: ", (1 + feature), " dimension != ", x_dims[2 + feature]); + if (B_dims[1 + feature] != x_dims[feature_offset + feature]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input B: ", (1 + feature), + " dimension != ", x_dims[feature_offset + feature]); } } } @@ -82,16 +92,19 @@ class BatchNormHelper { // validate 'mean' shape const auto& mean_dims = mean->Shape().GetDims(); if (static_cast(mean_dims.size()) != kNumInputMeanDimensions) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input mean: 
NumDimensions() != ", kNumInputMeanDimensions); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Invalid input mean: NumDimensions() != ", kNumInputMeanDimensions); } if (mean_dims[0] != num_channels) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input mean: 0th dimension != ", num_channels); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Invalid input mean: 0th dimension != ", num_channels); } // in non-spatial cases - the other dims of 'mean' must be validated if (!is_spatial) { for (int feature = 0; feature < num_feature_dims; ++feature) { - if (mean_dims[1 + feature] != x_dims[2 + feature]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input mean: ", (1 + feature), " dimension != ", x_dims[2 + feature]); + if (mean_dims[1 + feature] != x_dims[feature_offset + feature]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input mean: ", (1 + feature), + " dimension != ", x_dims[feature_offset + feature]); } } } @@ -99,7 +112,8 @@ class BatchNormHelper { // validate 'var' shape const auto& var_dims = var->Shape().GetDims(); if (static_cast(var_dims.size()) != kNumInputVarianceDimensions) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input var: NumDimensions() != ", kNumInputVarianceDimensions); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Invalid input var: NumDimensions() != ", kNumInputVarianceDimensions); } if (var_dims[0] != num_channels) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input var: 0th dimension != ", num_channels); @@ -107,8 +121,9 @@ class BatchNormHelper { // in non-spatial cases - the other dims of 'var' must be validated if (!is_spatial) { for (int feature = 0; feature < num_feature_dims; ++feature) { - if (var_dims[1 + feature] != x_dims[2 + feature]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input var: ", (1 + feature), " dimension != ", x_dims[2 + feature]); + if (var_dims[1 + 
feature] != x_dims[feature_offset + feature]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input var: ", (1 + feature), + " dimension != ", x_dims[feature_offset + feature]); } } } diff --git a/onnxruntime/core/providers/cpu/signal/dft.cc b/onnxruntime/core/providers/cpu/signal/dft.cc index 15bf633579e5f..50fe7d1344eaf 100644 --- a/onnxruntime/core/providers/cpu/signal/dft.cc +++ b/onnxruntime/core/providers/cpu/signal/dft.cc @@ -506,7 +506,7 @@ static Status short_time_fourier_transform(OpKernelContext* ctx, bool is_oneside // Calculate the window size with preference to the window input. const auto window_size = window ? window->Shape()[0] : frame_length; - ORT_ENFORCE(window_size < signal_size, "Ensure that the dft size is smaller than the signal."); + ORT_ENFORCE(window_size <= signal_size, "Ensure that the dft size is smaller than the signal."); // Calculate the number of dfts to run const auto n_dfts = diff --git a/onnxruntime/core/providers/cpu/tensor/gelu.cc b/onnxruntime/core/providers/cpu/tensor/gelu.cc new file mode 100644 index 0000000000000..d55973eda180f --- /dev/null +++ b/onnxruntime/core/providers/cpu/tensor/gelu.cc @@ -0,0 +1,108 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/common.h" +#include "core/common/narrow.h" +#include "core/framework/op_kernel.h" +#include "core/util/math_cpuonly.h" +#include "core/mlas/inc/mlas.h" + +#include "core/platform/threadpool.h" +#include +#include "core/providers/cpu/element_wise_ranged_transform.h" +#include "core/providers/cpu/tensor/gelu.h" + +using onnxruntime::narrow; +using namespace onnxruntime::common; + +namespace onnxruntime { + +// May revisit the implementations to support inplace computation, if needed. 
+ +ONNX_CPU_OPERATOR_KERNEL( + Gelu, + 20, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + Gelu); + +#ifndef DISABLE_CONTRIB_OPS +namespace contrib { +ONNX_OPERATOR_KERNEL_EX( + Gelu, + kMSDomain, + 1, + kCpuExecutionProvider, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + Gelu); +} +#endif + +template +Status Gelu::Compute(OpKernelContext* context) const { + const Tensor* input = context->Input(0); + const T* input_data = input->Data(); + + Tensor* output = context->Output(0, input->Shape()); + T* output_data = output->MutableData(); + + concurrency::ThreadPool* tp = context->GetOperatorThreadPool(); + int64_t elem_count = input->Shape().Size(); + constexpr int64_t length_per_task = 4096; // this number comes from FastGelu. + int64_t task_count = (elem_count + length_per_task - 1) / length_per_task; + + if (approximation_algorithm_ == "tanh") { + // FastGelu allows optional bias. Here we split input data into chunks. Each chunk + // has N elements (except the last chunk), and use thread pool to parallel chunks. + // N = 4096 is selected based on performance test results on input shape 1x128x768. + // FastGelu uses approximation for Gelu. The formula is 0.5 * (1 + Tanh(x * (C * x * x + B))) * x. 
+ static constexpr float B = 0.7978845608028654f; // sqrt(2.0 / M_PI) + static constexpr float C = 0.035677408136300125f; // 0.044715 * sqrt(2.0 / M_PI) + + concurrency::ThreadPool::TryBatchParallelFor( + tp, static_cast(task_count), + [&](ptrdiff_t task_idx) { + const auto start = task_idx * length_per_task; + const T* p_input = input_data + start; + T* p_output = output_data + start; + int64_t count = std::min(length_per_task, elem_count - start); + + for (int64_t i = 0; i < count; i++) { + T value = p_input[i]; + p_output[i] = value * (static_cast(C) * value * value + static_cast(B)); + } + + MlasComputeTanh(p_output, p_output, narrow(count)); + + for (int64_t i = 0; i < count; i++) { + p_output[i] = 0.5f * p_input[i] * (p_output[i] + 1.0f); + } + }, + 0); + return Status::OK(); + } else if (approximation_algorithm_ == "none") { + concurrency::ThreadPool::TryBatchParallelFor( + tp, static_cast(task_count), + [&](ptrdiff_t task_idx) { + const auto start = task_idx * length_per_task; + const T* p_input = input_data + start; + T* p_output = output_data + start; + int64_t count = std::min(length_per_task, elem_count - start); + + for (int64_t i = 0; i < count; i++) { + T value = p_input[i]; + p_output[i] = value * static_cast(M_SQRT1_2); + } + + MlasComputeErf(p_output, p_output, narrow(count)); + + for (int64_t i = 0; i < count; i++) { + p_output[i] = 0.5f * p_input[i] * (p_output[i] + 1.0f); + } + }, + 0); + return Status::OK(); + } + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported approximation_algorithm: ", approximation_algorithm_); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/tensor/gelu.h b/onnxruntime/core/providers/cpu/tensor/gelu.h new file mode 100644 index 0000000000000..13238028d878a --- /dev/null +++ b/onnxruntime/core/providers/cpu/tensor/gelu.h @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +namespace onnxruntime { + +template +class Gelu final : public OpKernel { + public: + explicit Gelu(const OpKernelInfo& info) : OpKernel(info) { + approximation_algorithm_ = info.GetAttrOrDefault("approximate", "none"); + } + Status Compute(OpKernelContext* ctx) const override; + + private: + std::string approximation_algorithm_; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/tensor/isinf.cc b/onnxruntime/core/providers/cpu/tensor/isinf.cc index 1b449f46927a2..9d18d1fa62288 100644 --- a/onnxruntime/core/providers/cpu/tensor/isinf.cc +++ b/onnxruntime/core/providers/cpu/tensor/isinf.cc @@ -23,7 +23,9 @@ ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST( using IsInfTypesOpset20 = TypeList< float, - double + double, + MLFloat16, + BFloat16 #if !defined(DISABLE_FLOAT8_TYPES) , Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, Float8E5M2FNUZ @@ -76,10 +78,8 @@ ONNX_CPU_OPERATOR_KERNEL( IsInf); IsInf::IsInf(const OpKernelInfo& info) : OpKernel(info) { - Status status = info.GetAttr("detect_positive", &detect_positive_); - ORT_ENFORCE(status.IsOK(), "Failed to obtain detect_positive"); - status = info.GetAttr("detect_negative", &detect_negative_); - ORT_ENFORCE(status.IsOK(), "Failed to obtain detect_negative"); + detect_positive_ = info.GetAttrOrDefault("detect_positive", 1); + detect_negative_ = info.GetAttrOrDefault("detect_negative", 1); opset_ = info.node().SinceVersion(); } @@ -87,29 +87,67 @@ namespace isinf_internal { template struct ComputeDispatchTarget { void operator()(const Tensor& X, Tensor& Y, bool detect_positive, bool detect_negative) const { - const auto total_items = X.Shape().Size(); + auto input_data = X.DataAsSpan(); auto output_data = Y.MutableData(); if (detect_positive && detect_negative) { EigenMap(Y) = EigenMap(X).array().isInf(); } else if (detect_positive) { - auto input_data = X.Data(); - auto end_data = input_data + total_items; std::transform( - input_data, end_data, output_data, [](T v) { + input_data.begin(), 
input_data.end(), output_data, [](T v) { return (v == std::numeric_limits::infinity()); }); } else if (detect_negative) { - auto input_data = X.Data(); - auto end_data = input_data + total_items; std::transform( - input_data, end_data, output_data, [](T v) { + input_data.begin(), input_data.end(), output_data, [](T v) { return (v == -std::numeric_limits::infinity()); }); } else { // all false - memset(output_data, false, onnxruntime::narrow(total_items)); + memset(output_data, false, input_data.size()); + } + } +}; + +template <> +struct ComputeDispatchTarget { + void operator()(const Tensor& X, Tensor& Y, bool detect_positive, bool detect_negative) const { + auto output_data = Y.MutableData(); + auto input_data = X.DataAsSpan(); + if (detect_positive && detect_negative) { + std::transform(input_data.begin(), input_data.end(), output_data, + [](MLFloat16 v) { return v.IsInfinity(); }); + } else if (detect_positive) { + std::transform(input_data.begin(), input_data.end(), output_data, + [](MLFloat16 v) { return v.IsPositiveInfinity(); }); + } else if (detect_negative) { + std::transform(input_data.begin(), input_data.end(), output_data, + [](MLFloat16 v) { return v.IsNegativeInfinity(); }); + } else { + // all false + memset(output_data, false, input_data.size()); + } + } +}; + +template <> +struct ComputeDispatchTarget { + void operator()(const Tensor& X, Tensor& Y, bool detect_positive, bool detect_negative) const { + auto output_data = Y.MutableData(); + auto input_data = X.DataAsSpan(); + if (detect_positive && detect_negative) { + std::transform(input_data.begin(), input_data.end(), output_data, + [](BFloat16 v) { return v.IsInfinity(); }); + } else if (detect_positive) { + std::transform(input_data.begin(), input_data.end(), output_data, + [](BFloat16 v) { return v.IsPositiveInfinity(); }); + } else if (detect_negative) { + std::transform(input_data.begin(), input_data.end(), output_data, + [](BFloat16 v) { return v.IsNegativeInfinity(); }); + } else { + // 
all false + memset(output_data, false, input_data.size()); } } }; diff --git a/onnxruntime/core/providers/cpu/tensor/isnan.cc b/onnxruntime/core/providers/cpu/tensor/isnan.cc index 34495e382278a..0e15c64b126f3 100644 --- a/onnxruntime/core/providers/cpu/tensor/isnan.cc +++ b/onnxruntime/core/providers/cpu/tensor/isnan.cc @@ -46,9 +46,11 @@ ADD_TYPED_ISNAN_OP_9(MLFloat16); ADD_TYPED_ISNAN_OP_13(float); ADD_TYPED_ISNAN_OP_13(double); ADD_TYPED_ISNAN_OP_13(MLFloat16); +ADD_TYPED_ISNAN_OP_13(BFloat16); ADD_TYPED_ISNAN_OP(float); ADD_TYPED_ISNAN_OP(double); ADD_TYPED_ISNAN_OP(MLFloat16); +ADD_TYPED_ISNAN_OP(BFloat16); #if !defined(DISABLE_FLOAT8_TYPES) ADD_TYPED_ISNAN_OP(Float8E4M3FN); @@ -75,9 +77,7 @@ Status IsNaN::Compute(OpKernelContext* context) const { template <> Status IsNaN::Compute(OpKernelContext* context) const { const auto* X_ptr = context->Input(0); - if (!X_ptr) { - return Status(common::ONNXRUNTIME, common::FAIL, "Null input ptr"); - } + auto X_data = X_ptr->Data(); auto& dims = X_ptr->Shape(); auto shape_size = dims.Size(); @@ -91,6 +91,19 @@ Status IsNaN::Compute(OpKernelContext* context) const { return Status::OK(); } +template <> +Status IsNaN::Compute(OpKernelContext* context) const { + const auto* X_ptr = context->Input(0); + + auto X_data = X_ptr->DataAsSpan(); + auto& Y = *context->Output(0, X_ptr->Shape()); + + std::transform(X_data.begin(), X_data.end(), Y.MutableData(), + [](BFloat16 x) { return x.IsNaN(); }); + + return Status::OK(); +} + #if !defined(DISABLE_FLOAT8_TYPES) template <> Status IsNaN::Compute(OpKernelContext* context) const { diff --git a/onnxruntime/core/providers/cpu/tensor/reshape_helper.h b/onnxruntime/core/providers/cpu/tensor/reshape_helper.h index 5961686674424..d7ceda16e61ea 100644 --- a/onnxruntime/core/providers/cpu/tensor/reshape_helper.h +++ b/onnxruntime/core/providers/cpu/tensor/reshape_helper.h @@ -37,12 +37,14 @@ class ReshapeHelper { if (unknown_dim != -1) { // calculate unknown dimension ORT_ENFORCE(size != 0 
&& (input_shape_size % size) == 0, - "The input tensor cannot be reshaped to the requested shape. Input shape:", input_shape, ", requested shape:", TensorShape(requested_shape)); + "The input tensor cannot be reshaped to the requested shape. Input shape:", input_shape, + ", requested shape:", TensorShape(requested_shape)); requested_shape[unknown_dim] = input_shape_size / size; } else { // check if the output shape is valid. ORT_ENFORCE(input_shape_size == size, - "The input tensor cannot be reshaped to the requested shape. Input shape:", input_shape, ", requested shape:", TensorShape(requested_shape)); + "The input tensor cannot be reshaped to the requested shape. Input shape:", input_shape, + ", requested shape:", TensorShape(requested_shape)); } } }; diff --git a/onnxruntime/core/providers/cpu/tensor/scatter.cc b/onnxruntime/core/providers/cpu/tensor/scatter.cc index 8844b7e7a26c4..c7a2005924836 100644 --- a/onnxruntime/core/providers/cpu/tensor/scatter.cc +++ b/onnxruntime/core/providers/cpu/tensor/scatter.cc @@ -198,13 +198,6 @@ struct Func_Min { } }; -template <> -struct Func_Min { - void operator()(MLFloat16*, const MLFloat16*) const { - ORT_NOT_IMPLEMENTED("CPU execution provider: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'min'."); - } -}; - template <> struct Func_Min { void operator()(BFloat16*, const BFloat16*) const { @@ -233,13 +226,6 @@ struct Func_Max { } }; -template <> -struct Func_Max { - void operator()(MLFloat16*, const MLFloat16*) const { - ORT_NOT_IMPLEMENTED("CPU execution provider: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'max'."); - } -}; - template <> struct Func_Max { void operator()(BFloat16*, const BFloat16*) const { diff --git a/onnxruntime/core/providers/cpu/tensor/space_depth_ops.h b/onnxruntime/core/providers/cpu/tensor/space_depth_ops.h index 7d117317ba172..3218c8952d6ec 100644 --- a/onnxruntime/core/providers/cpu/tensor/space_depth_ops.h +++ 
b/onnxruntime/core/providers/cpu/tensor/space_depth_ops.h @@ -14,6 +14,7 @@ class SpaceDepthBase { "Attribute blocksize is not set."); } + template Status InputValidationsAndOutputDimsCalc(const Tensor& input, int64_t& batch, int64_t& input_depth, int64_t& input_height, int64_t& input_width, @@ -27,9 +28,15 @@ class SpaceDepthBase { } batch = input_shape[0]; - input_depth = input_shape[1]; - input_height = input_shape[2]; - input_width = input_shape[3]; + if constexpr (IsNHWC) { + input_depth = input_shape[3]; + input_height = input_shape[1]; + input_width = input_shape[2]; + } else { + input_depth = input_shape[1]; + input_height = input_shape[2]; + input_width = input_shape[3]; + } if (is_space_to_depth) { // SpaceToDepth op if ((input_height % this->blocksize_) != 0) { @@ -46,7 +53,8 @@ class SpaceDepthBase { } else { // DepthToSpace op if ((input_depth % (blocksize_ * blocksize_) != 0)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "DepthToSpace requires input depth to be a multiple of (block_size * blok_size)"); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "DepthToSpace requires input depth to be a multiple of (block_size * block_size)"); } output_depth = input_depth / blocksize_ / blocksize_; diff --git a/onnxruntime/core/providers/cpu/tensor/upsample.cc b/onnxruntime/core/providers/cpu/tensor/upsample.cc index fa69e144be554..babbac0b7be17 100644 --- a/onnxruntime/core/providers/cpu/tensor/upsample.cc +++ b/onnxruntime/core/providers/cpu/tensor/upsample.cc @@ -1,10 +1,15 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
+#include "core/providers/cpu/tensor/upsample.h" + +#include + +#include "core/common/inlined_containers.h" #include "core/common/safeint.h" #include "core/platform/threadpool.h" -#include "core/providers/cpu/tensor/upsample.h" #include "core/providers/cpu/tensor/upsample_antialias.h" + using namespace onnxruntime::common; using namespace std; using onnxruntime::narrow; @@ -30,6 +35,46 @@ REGISTER_VERSIONED_TYPED_KERNEL(int32_t, 9, 9); REGISTER_VERSIONED_TYPED_KERNEL(int8_t, 9, 9); REGISTER_VERSIONED_TYPED_KERNEL(uint8_t, 9, 9); +void UpsampleBase::AdjustOutputSizeAsPolicy(TensorShapeVector& output_dims, gsl::span input_dims, + InlinedVector& scales) const { + // AspectRatioPolicy::STRETCH is default policy when opset < 18 + if (keep_aspect_ratio_policy_ == AspectRatioPolicy::STRETCH) { + return; + } + + InlinedHashSet axes_set(axes_.begin(), axes_.end()); + + float scale_in_policy = 0.0f; + if (keep_aspect_ratio_policy_ == AspectRatioPolicy ::NOT_LARGER) { + scale_in_policy = std::numeric_limits::max(); + + for (size_t i = 0; i < scales.size(); i++) { + if (axes_set.empty() || axes_set.count(i) > 0) { + scale_in_policy = std::min(scale_in_policy, scales[i]); + } + } + } else if (keep_aspect_ratio_policy_ == AspectRatioPolicy ::NOT_SMALLER) { + scale_in_policy = std::numeric_limits::min(); + + for (size_t i = 0; i < scales.size(); i++) { + if (axes_set.empty() || axes_set.count(i) > 0) { + scale_in_policy = std::max(scale_in_policy, scales[i]); + } + } + } + + for (size_t i = 0; i < scales.size(); i++) { + // if axes is not specified (AKA axes_set.empty()), we apply the policy to all axes + if (axes_set.empty() || axes_set.count(i) > 0) { + scales[i] = scale_in_policy; + output_dims[i] = static_cast(std::round(scales[i] * input_dims[i])); + } else { + scales[i] = 1.0f; + output_dims[i] = input_dims[i]; + } + } +} + template void UpsampleNearest2x(int64_t batch_size, int64_t num_channels, @@ -94,8 +139,8 @@ UpsampleNearestSetupInputMappings(int64_t n_dim, const 
TensorShape& input_shape, const TensorShape& output_shape, const std::vector& input_dim_factor, - const vector& scales, - const vector& roi, + gsl::span scales, + gsl::span roi, bool extrapolation_enabled, const GetOriginalCoordinateFunc& get_original_coordinate, const GetNearestPixelFunc& get_nearest_pixel) { @@ -141,8 +186,8 @@ static Status UpsampleNearestImpl(const T* input, T* output, const TensorShape& input_shape, const TensorShape& output_shape, - const vector& scales, - const vector& roi, + gsl::span scales, + gsl::span roi, bool extrapolation_enabled, const T extrapolation_value, const GetOriginalCoordinateFunc& get_original_coordinate, @@ -285,8 +330,8 @@ static Status UpsampleNearest(const T* input, T* output, const TensorShape& input_shape, const TensorShape& output_shape, - const vector& scales, - const vector& roi, + gsl::span scales, + gsl::span roi, bool is_resize, bool extrapolation_enabled, T extrapolation_value, @@ -412,7 +457,7 @@ BilinearParams SetupUpsampleBilinear(const int32_t input_height, const int32_t output_width, const float height_scale, const float width_scale, - const std::vector& roi, + gsl::span roi, AllocatorPtr& alloc, const GetOriginalCoordinateFunc& get_original_coordinate, const bool is_nchw) { @@ -518,7 +563,7 @@ BilinearParamsInteger SetupUpsampleBilinearInteger(const int32_t input_height, const int32_t output_width, const float height_scale, const float width_scale, - const std::vector& roi, + gsl::span roi, AllocatorPtr& alloc, const GetOriginalCoordinateFunc& get_original_coordinate, const bool is_nchw) { @@ -650,7 +695,7 @@ static TrilinearParams SetupUpsampleTrilinear(int64_t input_depth, float depth_scale, float height_scale, float width_scale, - const std::vector& roi, + gsl::span roi, AllocatorPtr& alloc, const GetOriginalCoordinateFunc& get_original_coordinate) { TrilinearParams p; @@ -796,7 +841,7 @@ void UpsampleTrilinear(int64_t batch_size, float depth_scale, float height_scale, float width_scale, - const 
std::vector& roi, + gsl::span roi, bool use_extrapolation, float extrapolation_value, const T* XdataBase, @@ -929,7 +974,7 @@ void ResizeBiCubic(int64_t batch_size, bool use_extrapolation, float extrapolation_value, bool exclude_outside, - const std::vector& roi, + gsl::span roi, const T* Xdata, T* Ydata, const GetOriginalCoordinateFunc& get_original_coordinate) { @@ -1067,9 +1112,9 @@ void ResizeBiCubic(int64_t batch_size, template Status Upsample::BaseCompute(OpKernelContext* context, - const std::vector& roi, - const std::vector& scales, - const gsl::span& output_dims) const { + gsl::span roi, + gsl::span scales, + gsl::span output_dims) const { const auto* X = context->Input(0); auto dims = X->Shape().GetDims(); ORT_RETURN_IF_NOT(output_dims.size() == dims.size(), "Rank of input and output tensor should be same."); @@ -1327,7 +1372,7 @@ Status Upsample::Compute(OpKernelContext* context) const { // Initialize the roi array to all zeros as this will be the most common case // Roi data is needed only when coordinate transformation mode is set to tf_crop_and_resize // for all other cases we need a 0 initialized roi array - std::vector roi_array(roi_); + InlinedVector roi_array(roi_); if (!roi_cached_) { bool use_default_roi = true; @@ -1353,7 +1398,7 @@ Status Upsample::Compute(OpKernelContext* context) const { ComputeROIWithAxes(roi_array, input_dims.size()); // Get scales data - std::vector scales_array(input_dims.size()); + InlinedVector scales_array(input_dims.size()); if (OpKernel::Node().InputDefs().size() == 1) { // Compute output shape from scales and input dims diff --git a/onnxruntime/core/providers/cpu/tensor/upsample.h b/onnxruntime/core/providers/cpu/tensor/upsample.h index 3046ee4b8260d..8ff04781f6ad0 100644 --- a/onnxruntime/core/providers/cpu/tensor/upsample.h +++ b/onnxruntime/core/providers/cpu/tensor/upsample.h @@ -66,8 +66,8 @@ class Upsample : public UpsampleBase, public OpKernel { Status Compute(OpKernelContext* context) const override; - 
Status BaseCompute(OpKernelContext* context, const std::vector& roi, const std::vector& scales, - const gsl::span& output_dims) const; + Status BaseCompute(OpKernelContext* context, gsl::span roi, gsl::span scales, + gsl::span output_dims) const; }; BilinearParams SetupUpsampleBilinear(const int32_t input_height, @@ -76,7 +76,7 @@ BilinearParams SetupUpsampleBilinear(const int32_t input_height, const int32_t output_width, const float height_scale, const float width_scale, - const std::vector& roi, + gsl::span roi, AllocatorPtr& alloc, const GetOriginalCoordinateFunc& get_original_coordinate, const bool is_nchw); @@ -90,7 +90,7 @@ void UpsampleBilinear(const int32_t batch_size, const int32_t output_width, const float height_scale, const float width_scale, - const std::vector& roi, + gsl::span roi, const bool use_extrapolation, const float extrapolation_value, const T* const XdataBase, @@ -144,7 +144,7 @@ void NhwcUpsampleBilinear(const int32_t batch_size, const int32_t output_width, const float height_scale, const float width_scale, - const std::vector& roi, + gsl::span roi, const float extrapolation_value, const T* const XdataBase, T* const YdataBase, @@ -227,7 +227,7 @@ BilinearParamsInteger SetupUpsampleBilinearInteger(const int32_t input_height, const int32_t output_width, const float height_scale, const float width_scale, - const std::vector& roi, + gsl::span roi, AllocatorPtr& alloc, const GetOriginalCoordinateFunc& get_original_coordinate, const bool is_nchw); @@ -241,7 +241,7 @@ void NhwcUpsampleBilinearInteger(const int32_t batch_size, const int32_t output_width, const float height_scale, const float width_scale, - const std::vector& roi, + gsl::span roi, const float extrapolation_value, const T* const XdataBase, T* const YdataBase, diff --git a/onnxruntime/core/providers/cpu/tensor/upsample_antialias.h b/onnxruntime/core/providers/cpu/tensor/upsample_antialias.h index 59b512def619d..1e32b7e874b1a 100644 --- 
a/onnxruntime/core/providers/cpu/tensor/upsample_antialias.h +++ b/onnxruntime/core/providers/cpu/tensor/upsample_antialias.h @@ -21,32 +21,6 @@ namespace onnxruntime { -namespace ConstValue { -constexpr int32_t mag_factor = 1 << (22 - 1); -} - -namespace { -const uint8_t* GetLookupTableShared() { - // initialized once - static const auto* lookup_table = []() { - // if we have already initialized the lookup table, just return - // ideally we could have a global lookup table, but that account for too much space. - /* Handles values form -640 to 639. */ - static uint8_t table[1280] = {0}; - - // taken from https://github.com/python-pillow/Pillow/blob/66add095a50d76c35c7f58643461f2edf78a3f05/src/libImaging/Resample.c#L94 - // we need to handle negative values - // it's equivalent to :x = np.clip(x, 0, 255) where x \in [-640, 639] - // we will accept a negative x for (&table[640])[x] means table +640 -x - for (int i = 0; i < 1280; ++i) { - table[i] = static_cast(std::min(std::max(i - 640, 0), 255)); - } - return table; - }(); - return lookup_table; -} -} // namespace - template struct FilterParamsBaseAntiAlias { std::vector bound; @@ -57,15 +31,15 @@ struct FilterParamsBaseAntiAlias { template struct FilterParamsAntiAlias { - float support_size = 2.0f; - float cubic_coeff_a = -0.75f; + float support_size = antialias_constants::kSupportSize; + float cubic_coeff_a = antialias_constants::kCubicCoeffA; FilterParamsBaseAntiAlias dim_x; FilterParamsBaseAntiAlias dim_y; FilterParamsBaseAntiAlias dim_z; const uint8_t* GetClip8LookupTable() const { - return GetLookupTableShared(); + return UpsampleBase::GetLookupTableShared(); } virtual ~FilterParamsAntiAlias() = default; virtual float Filter(float x) const = 0; @@ -89,7 +63,7 @@ struct BilinearParamsAntiAlias : FilterParamsAntiAlias { template struct BiCubicParamsAntiAlias : FilterParamsAntiAlias { BiCubicParamsAntiAlias() { - this->support_size = 4.0f; + this->support_size = antialias_constants::kBiCubicSupportSize; } // 
taken from @@ -124,27 +98,6 @@ struct TriLinearParamsAntiAlias : FilterParamsAntiAlias { } }; -template -struct AccumulateType { - using type = int32_t; - using Dtype = T; -}; - -template <> -struct AccumulateType { - using type = float; -}; - -template <> -struct AccumulateType { - using type = float; -}; - -template <> -struct AccumulateType { - using type = double; -}; - // The following method supports a 3/4/5-D input in 'Linear mode, cubic mode' // that amounts to 'Bilinear,TriLinear, Bicubic/Tricubic' Upsampling/Resizing in the sense that it assumes // A N-D tensor has @@ -156,19 +109,20 @@ struct AccumulateType { // - [N, H, W, C] and the scales are [1.0, height_scale, width_scale, 1.0] template void SetupUpsampleFilterAntiAlias(FilterParamsAntiAlias& p, - const gsl::span input_h_w_c, - const gsl::span output_h_w_c, - const gsl::span scale_h_w_c, - const std::vector& roi, + gsl::span input_h_w_c, + gsl::span output_h_w_c, + gsl::span scale_h_w_c, + gsl::span roi, AllocatorPtr& alloc, const GetOriginalCoordinateFunc& get_original_coordinate, bool exclude_outside, const bool is_nchw) { - auto compute_weight_coefficients = [&alloc, &roi, &get_original_coordinate, exclude_outside](const FilterParamsAntiAlias& p, - const int64_t input_size, - const int64_t output_size, - size_t rindex, - FilterParamsBaseAntiAlias& param_base, - const float rscale) -> int64_t { + auto compute_weight_coefficients = [&alloc, roi, &get_original_coordinate, exclude_outside]( + const FilterParamsAntiAlias& p, + const int64_t input_size, + const int64_t output_size, + size_t rindex, + FilterParamsBaseAntiAlias& param_base, + const float rscale) -> int64_t { param_base.bound.reserve(static_cast(output_size) * 2); param_base.out_of_bound_idx.reserve(static_cast(output_size)); @@ -245,13 +199,14 @@ void SetupUpsampleFilterAntiAlias(FilterParamsAntiAlias& p, // normalize the scale to 1 << 22 for int8/uint8 if constexpr (std::is_same::value) { - scale_buffer_int[x] = 
static_cast(std::round(scale_buffer[x] * ConstValue::mag_factor * 2.f)); + scale_buffer_int[x] = static_cast(std::round(scale_buffer[x] * ConstValue::mag_factor_x_2)); } } /*for (; x < window_size; x++) { scale_buffer[x] = 0; }*/ } + return window_size; }; @@ -269,9 +224,6 @@ void SetupUpsampleFilterAntiAlias(FilterParamsAntiAlias& p, } } -template -inline constexpr bool is_8bit_v = std::is_same::value || std::is_same::value; - /** * @brief To compute interpolation along with the last axis. * For brief,we assume the input tensor has 3 dimensions and we all it CHW for each character represent a dim. @@ -398,6 +350,7 @@ void ComputeInterpolationAtLevel2(int64_t num_channels, int64_t input_height, in output += *Xdata_offset * (*weight_coeff_start++); Xdata_offset += output_width; } + if constexpr (is_8bit_v) { *Ydata_offset++ = static_cast(clip8_lookups[output >> 22]); } else if constexpr (std::is_same::value) { @@ -444,6 +397,7 @@ void ComputeInterpolationAtLevel2(int64_t num_channels, int64_t input_height, in output += *Xdata_offset * (*weight_coeff_start++); Xdata_offset += output_width; } + if constexpr (is_8bit_v) { *Ydata_offset++ = static_cast(clip8_lookups[output >> 22]); } else if constexpr (std::is_same::value) { @@ -515,6 +469,7 @@ void UpsampleBaseAntiAlias(FilterParamsAntiAlias& p, narrow(input_height * num_channels * input_width)); auto ydata_span = gsl::make_span(image_temp_buffer.get(), narrow(input_height * num_channels * output_width)); + // This computes only the width direction.Thus height keeps unchanged. 
ComputeInterpolationAtLevel1(num_channels, input_height, input_width, input_height, output_width, xdata_span, ydata_span, p, p.dim_x, tp); } @@ -546,7 +501,7 @@ void UpsampleBilinearAntiAlias(const int64_t batch_size, const int64_t output_width, const float height_scale, const float width_scale, - const std::vector& roi, + gsl::span roi, const bool use_extrapolation, const float extrapolation_value, bool exclude_outside, @@ -575,7 +530,7 @@ void NhwcUpsampleBilinearAntiAlias(const int64_t batch_size, const int64_t output_width, const float height_scale, const float width_scale, - const std::vector& roi, + gsl::span roi, const bool use_extrapolation, const float extrapolation_value, bool exclude_outside, @@ -608,7 +563,7 @@ void NhwcResizeBiCubicAntiAlias(const int64_t batch_size, bool use_extrapolation, float extrapolation_value, bool exclude_outside, - const std::vector& roi, + gsl::span roi, const Tensor* X, T* Ydata_base, AllocatorPtr& alloc, @@ -688,7 +643,7 @@ void ResizeBiCubicAntiAlias(int64_t batch_size, bool use_extrapolation, float extrapolation_value, bool exclude_outside, - const std::vector& roi, + gsl::span roi, const Tensor* X, T* Ydata_base, AllocatorPtr& alloc, @@ -700,7 +655,7 @@ void ResizeBiCubicAntiAlias(int64_t batch_size, BiCubicParamsAntiAlias::type> p; p.cubic_coeff_a = cubic_coeff_a; SetupUpsampleFilterAntiAlias(p, input_paras, output_paras, scale_paras, roi, - alloc, get_original_coordinate, exclude_outside, false); + alloc, get_original_coordinate, exclude_outside, true); return UpsampleBaseAntiAlias(p, batch_size, num_channels, input_height, input_width, output_height, output_width, use_extrapolation, extrapolation_value, @@ -719,7 +674,7 @@ void UpsampleTrilinearAntiAlias(int64_t batch_size, float depth_scale, float height_scale, float width_scale, - const std::vector& roi, + gsl::span roi, bool use_extrapolation, float extrapolation_value, bool exclude_outside, diff --git a/onnxruntime/core/providers/cpu/tensor/upsamplebase.h 
b/onnxruntime/core/providers/cpu/tensor/upsamplebase.h index a0e7ca1084fef..b768fedd8513a 100644 --- a/onnxruntime/core/providers/cpu/tensor/upsamplebase.h +++ b/onnxruntime/core/providers/cpu/tensor/upsamplebase.h @@ -3,11 +3,13 @@ #pragma once +#include #include #include #include #include -#include + +#include #include "core/common/status.h" #include #include @@ -58,7 +60,73 @@ enum class AspectRatioPolicy { NOT_SMALLER, }; +// Antialias types +template +struct AccumulateType { + using type = int32_t; + using Dtype = T; +}; + +template <> +struct AccumulateType { + using type = float; +}; + +template <> +struct AccumulateType { + using type = float; +}; + +template <> +struct AccumulateType { + using type = float; +}; + +template <> +struct AccumulateType { + using type = double; +}; + +namespace antialias_constants { +constexpr float kCubicCoeffA = -0.75f; +constexpr float kSupportSize = 2.0f; +constexpr float kBiCubicSupportSize = 4.0f; +} // namespace antialias_constants + +namespace ConstValue { +constexpr int32_t mag_factor = 1 << (22 - 1); +// We use to multiply by 2, let's make a constant which is twice as big +constexpr int32_t mag_factor_x_2 = 1 << 22; +} // namespace ConstValue + +template +inline constexpr bool is_8bit_v = std::is_same::value || std::is_same::value; + +template +void PrintAntiAliasBuffers(std::ostream& os, gsl::span bounds, gsl::span out_of_bounds, + gsl::span weight_coefficients) { + os << "#### Bounds: "; + std::copy(bounds.begin(), bounds.end(), std::ostream_iterator(os, " ")); + os << std::endl; + + os << "#### Out of Bounds: "; + std::copy(out_of_bounds.begin(), out_of_bounds.end(), + std::ostream_iterator(os, " ")); + os << std::endl; + + os << "#### Scale Buffer: "; + std::copy(weight_coefficients.begin(), weight_coefficients.end(), + std::ostream_iterator(os, " ")); + os << std::endl; +} + class UpsampleBase { + public: + // Make this available in other EP via provider bridge + // it works iff output_shape is specified + void 
AdjustOutputSizeAsPolicy(TensorShapeVector& output_dims, gsl::span input_dims, + InlinedVector& scales) const; + protected: explicit UpsampleBase(const OpKernelInfo& info) : scales_cached_(false), roi_cached_(false), use_extrapolation_(false) { @@ -69,23 +137,32 @@ class UpsampleBase { std::string mode; ORT_ENFORCE(info.GetAttr("mode", &mode).IsOK()); mode_ = StringToUpsampleMode(mode); - antialias_ = info.GetAttrOrDefault("antialias", 0) == 0 ? false : true; - if (antialias_) { - ORT_ENFORCE((UpsampleMode::LINEAR == mode_ || UpsampleMode::CUBIC == mode_), - "when anti-aliasing is set, Resize only supports mode `LINEAR` and `CUBIC`."); - } auto input_count = info.GetInputCount(); if (input_count == 1) { // opset < 10 - ORT_THROW_IF_ERROR(info.GetAttrs("scales", scales_)); - ORT_THROW_IF_ERROR(ScalesValidation(scales_, mode_)); + std::vector scales; + ORT_THROW_IF_ERROR(info.GetAttrs("scales", scales)); + ORT_THROW_IF_ERROR(ScalesValidation(scales, mode_)); + scales_.assign(scales.cbegin(), scales.cend()); scales_cached_ = true; } - std::string keep_aspect_ratio_policy = info.GetAttrOrDefault("keep_aspect_ratio_policy", "stretch"); - keep_aspect_ratio_policy_ = StringToKeepAspectRatioPolicy(keep_aspect_ratio_policy); + if (opset >= 18) { + antialias_ = info.GetAttrOrDefault("antialias", 0) == 0 ? false : true; + + if (antialias_) { + ORT_ENFORCE((UpsampleMode::LINEAR == mode_ || UpsampleMode::CUBIC == mode_), + "when anti-aliasing is set, Resize only supports mode `LINEAR` and `CUBIC`."); + } - axes_ = info.GetAttrsOrDefault("axes"); + // The attribute is absent in opset < 18, but the default value as if stretch. 
+ std::string keep_aspect_ratio_policy = info.GetAttrOrDefault("keep_aspect_ratio_policy", "stretch"); + keep_aspect_ratio_policy_ = StringToKeepAspectRatioPolicy(keep_aspect_ratio_policy); + + // guard against unit tests that can add an attribute + auto axes = info.GetAttrsOrDefault("axes"); + axes_.assign(axes.cbegin(), axes.cend()); + } extrapolation_value_ = info.GetAttrOrDefault("extrapolation_value", 0.0f); @@ -112,7 +189,7 @@ class UpsampleBase { nearest_mode_ = StringToNearestMode(nearest_mode_name); get_nearest_pixel_ = GetNearestPixelFromOriginal(nearest_mode_); - cubic_coeff_a_ = info.GetAttrOrDefault("cubic_coeff_a", -0.75f); + cubic_coeff_a_ = info.GetAttrOrDefault("cubic_coeff_a", antialias_constants::kCubicCoeffA); exclude_outside_ = info.GetAttrOrDefault("exclude_outside", 0) == 0 ? false : true; if ((exclude_outside_ == 1 && mode_ != CUBIC) && (antialias_ == false || mode_ != LINEAR)) { @@ -166,7 +243,7 @@ class UpsampleBase { ResizeCoordinateTransformationMode coordinate_transform_mode_; GetOriginalCoordinateFunc get_original_coordinate_; ResizeNearestMode nearest_mode_; - AspectRatioPolicy keep_aspect_ratio_policy_; + AspectRatioPolicy keep_aspect_ratio_policy_{AspectRatioPolicy::STRETCH}; GetNearestPixelFunc get_nearest_pixel_; float cubic_coeff_a_; bool exclude_outside_; @@ -174,9 +251,9 @@ class UpsampleBase { float extrapolation_value_; bool use_nearest2x_optimization_ = false; - std::vector scales_; - std::vector roi_; - std::vector axes_; + InlinedVector scales_; + InlinedVector roi_; + TensorShapeVector axes_; bool scales_cached_; bool roi_cached_; @@ -335,7 +412,7 @@ class UpsampleBase { } } - [[nodiscard]] Status ScalesValidation(const std::vector& scales, const UpsampleMode mode) const { + [[nodiscard]] Status ScalesValidation(gsl::span scales, const UpsampleMode mode) const { if (!is_resize_) { for (auto& scale : scales) { ORT_RETURN_IF_NOT(scale >= 1, "Scale value should be greater than or equal to 1."); @@ -372,7 +449,7 @@ class 
UpsampleBase { } [[nodiscard]] Status - ParseScalesData(const Tensor* scale, std::vector& scales, int64_t rank) const { + ParseScalesData(const Tensor* scale, InlinedVector& scales, int64_t rank) const { const auto* scale_data = scale->Data(); int64_t scales_size = scale->Shape().Size(); ORT_RETURN_IF_NOT(scales_size > 0, "scales size should be greater than 0."); @@ -387,19 +464,19 @@ class UpsampleBase { // in which case the other axes is ignored and use default scale of 1 // scales_size == axes_.size() should be guaranteed if axes is not empty if (rank > 0 && (scales_size != rank || axes_.size())) { - std::vector new_scales(size_t(rank), 1.0f); + InlinedVector new_scales(size_t(rank), 1.0f); ORT_RETURN_IF_NOT(*std::max_element(axes_.begin(), axes_.end()) < rank && (int64_t(axes_.size()) == scales_size), "all values in axes should be less than rank of the data"); for (size_t i = 0; i < axes_.size(); i++) { new_scales[static_cast(axes_[i])] = scales[i]; } - scales = new_scales; + scales.swap(new_scales); } return ScalesValidation(scales, mode_); } - void ParseRoiData(const Tensor* roi, std::vector& roi_array) const { + void ParseRoiData(const Tensor* roi, InlinedVector& roi_array) const { int64_t roi_size = roi->Shape().Size(); if (roi_size > 0) { roi_array.resize(onnxruntime::narrow(roi_size)); @@ -429,52 +506,11 @@ class UpsampleBase { return Status::OK(); } - // it works iff output_shape is specified - void AdjustOutputSizeAsPolicy(TensorShapeVector& output_dims, gsl::span input_dims, - std::vector& scales) const { - std::unordered_set axes_set(axes_.begin(), axes_.end()); - - // AspectRatioPolicy::STRETCH is default policy when opset < 18 - if (keep_aspect_ratio_policy_ == AspectRatioPolicy ::STRETCH) { - return; - } - - float scale_in_policy = 0.0f; - if (keep_aspect_ratio_policy_ == AspectRatioPolicy ::NOT_LARGER) { - scale_in_policy = std::numeric_limits::max(); - - for (size_t i = 0; i < scales.size(); i++) { - if (axes_set.empty() || axes_set.count(i) > 0) 
{ - scale_in_policy = std::min(scale_in_policy, scales[i]); - } - } - } else if (keep_aspect_ratio_policy_ == AspectRatioPolicy ::NOT_SMALLER) { - scale_in_policy = std::numeric_limits::min(); - - for (size_t i = 0; i < scales.size(); i++) { - if (axes_set.empty() || axes_set.count(i) > 0) { - scale_in_policy = std::max(scale_in_policy, scales[i]); - } - } - } - - for (size_t i = 0; i < scales.size(); i++) { - // if axes is not specified (AKA axes_set.empty()), we apply the policy to all axes - if (axes_set.empty() || axes_set.count(i) > 0) { - scales[i] = scale_in_policy; - output_dims[i] = static_cast(std::round(scales[i] * input_dims[i])); - } else { - scales[i] = 1.0f; - output_dims[i] = input_dims[i]; - } - } - } - // It's different in Opset 18 and before. // we will modify output_shape by sorts of policy even if it's specified [[nodiscard]] Status ParseScalesDataAndAdjustOutputSize(TensorShapeVector& output_dims, gsl::span input_dims, - std::vector& scales) const { + InlinedVector& scales) const { for (size_t i = 0, end = input_dims.size(); i < end; ++i) { // Handle corner case to avoid dividing by zero in the next step if (input_dims[i] == 0) { @@ -507,9 +543,9 @@ class UpsampleBase { // Roi is redefined in Opset-18, we have a concept of axes. // So we need to update it accordingly. 
- void ComputeROIWithAxes(std::vector& roi_array, size_t rank) const { + void ComputeROIWithAxes(InlinedVector& roi_array, size_t rank) const { if (axes_.size()) { - std::vector roi_tmp(rank * 2, 0); + InlinedVector roi_tmp(rank * 2, 0); for (size_t i = rank; i < rank * 2; ++i) { roi_tmp[i] = 1; } @@ -518,9 +554,32 @@ class UpsampleBase { roi_tmp[v_in_axes] = (roi_array[i]); roi_tmp[rank + v_in_axes] = (roi_array[axes_.size() + i]); } - roi_array = roi_tmp; + roi_array.swap(roi_tmp); } } + + public: + static constexpr size_t kLookupTableSize = 1280; + + static const uint8_t* GetLookupTableShared() { + // initialized once + static const auto* lookup_table = []() { + // if we have already initialized the lookup table, just return + // ideally we could have a global lookup table, but that account for too much space. + /* Handles values form -640 to 639. */ + static uint8_t table[kLookupTableSize] = {0}; + + // taken from https://github.com/python-pillow/Pillow/blob/66add095a50d76c35c7f58643461f2edf78a3f05/src/libImaging/Resample.c#L94 + // we need to handle negative values + // it's equivalent to :x = np.clip(x, 0, 255) where x \in [-640, 639] + // we will accept a negative x for (&table[640])[x] means table +640 -x + for (int i = 0; i < static_cast(kLookupTableSize); ++i) { + table[i] = static_cast(std::min(std::max(i - 640, 0), 255)); + } + return table; + }(); + return lookup_table; + } }; // UpsampleBase } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/atomic/common.cuh b/onnxruntime/core/providers/cuda/atomic/common.cuh index 14fa2d0706f73..170aa3a2d8d0c 100644 --- a/onnxruntime/core/providers/cuda/atomic/common.cuh +++ b/onnxruntime/core/providers/cuda/atomic/common.cuh @@ -122,5 +122,316 @@ __device__ __forceinline__ void AtomicAdd(half* start_addr, size_t index, #endif } +// Disable default template instantiation. +// For every type T, we need to define a specialization +// to select the right type for calling atomicCAS. 
+template +class AtomicCasType; + +template<> +class AtomicCasType { + public: + using type = unsigned short int; + static const unsigned int mask = 0xffu; +}; + +template<> +class AtomicCasType { + public: + using type = unsigned short int; + static const unsigned int mask = 0xffffu; +}; + +template<> +class AtomicCasType { + public: + using type = unsigned int; + static const unsigned int mask = 0xffffffffu; +}; + +template<> +class AtomicCasType { + public: + using type = unsigned long long int; + static const unsigned int mask = 0xffffffffu; +}; + +template<> +class AtomicCasType { + public: + using type = int; + static const unsigned int mask = 0xffffffffu; +}; + +template<> +class AtomicCasType { + public: + using type = unsigned long long int; + static const unsigned int mask = 0xffffffffu; +}; + +// Obtained from pytorch/aten/src/ATen/cuda/Atomic.cuh. +// +// This function compute 8-bit atomic binary operation using 32-bit atomicCAS. +// It accumulate `val` into the `address` using the `func`. +// The accumulation is atomic (i.e., thread-safe). +// +// E.g., Assume ValueType is +// int8_t +// and BinaryFunc is +// struct AddFunc { +// __device__ __forceinline__ int8_t operator()(int8_t a, int8_t b) const { +// return a + b; +// } +// This function becomes atomic_add for int8_t. +template +__device__ __forceinline__ void atomic_byte_func_with_unit32_cas(ValueType* address, ValueType val, BinaryFunc func) { + // Assert to ensure the following bit-wise manipulation is correct. + static_assert(sizeof(ValueType) == 1 | sizeof(ValueType) == 2 | sizeof(ValueType) == 4, + "ValueType must be 1-byte, 2-byte or 4-byte large."); + // Number of bytes to the lower 4-byte aligned address. + // If the current address is b1010"10", then offset = b10 = 2, + // which means the current address is 2 bytes away from + // the lower 4-byte aligned address b1010"00". 
+ size_t offset = (size_t)address & 3; + // Find an new 4-byte aligned address `address_as_ui` lower than + // or equal to `address`. Lower than `address` so that the actual + // int8_t byte is in the 4-byte word that we load. + // + // This address has the following properties: + // 1. It is 4-byte aligned. + // 2. It is lower than or equal to `address`. + // 3. De-referencing this address may return + // a uint32_t value that contains the same int8_t + // value indicated by `address`. + // + // E.g., + // address = b101010 + // offset = b101010 & b000011 = b10 = 2 + // (char*)address - offset => (char*)b101010 - b000010 => b1010"00", + // which is (32-bit aligned). + uint32_t * address_as_ui = (uint32_t*)((char*)address - offset); + uint32_t old = *address_as_ui; + // E.g., offset = 2. + // address_as_ui is an address 2 bytes lower than `address`. + // + // ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 ..... + // ^ ^ ^ + // | | | + // | address <--- offset * 8 (bit)-----> address_as_ui + // | ^ + // | | + // ------------------------- *address_as_ui ----------------------- + // + // This visualization shows + // 1. the 32-bit word at address_as_ui. + // 2. the gap between address_as_ui and address. + // 3. *address_as_ui contains the int8_t value at `address`. + uint32_t shift = offset * 8; + uint32_t old_byte; + uint32_t newval; + uint32_t assumed; + do { + assumed = old; + // Select 8-bit value from 32-bit word. Assume offset = 2 (byte), so + // we want to select the 3rd byte (byte 2 below) from the word. + // + // Journey of a 32-bit value: + // + // ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 ..... + // + // | + // | old >> offset * 8, where offset = 2. + // | Effectively, push lower two bytes + // | out of the word. + // V + // + // 00000000 | 00000000 | ..... byte 3 ..... | ..... byte 2 ..... 
+ // + // | apply bit-wise AND, + // | & 0xff (i.e., & b11111111), + // | so that we only keep + // | the byte of interest. + // | Otherwise, overflow may + // | happen when casting this + // | 32-bit value to int8_t. + // V + // + // 00000000 | 00000000 | 00000000 | ..... byte 2 ..... + old_byte = (old >> shift) & AtomicCasType::mask; + // Compute new int8_t value and store it to newrawvalue. + // Journey of a 32-bit value (cont'd): + // + // newrawvalue + // ... new byte 2 ... + auto newrawvalue = func(val, reinterpret_cast(old_byte)); + // Put the new int8_t value back to 32-bit word. + // Also ensure that bits not occupied by the int8_t value are 0s. + // + // Journey of a 32-bit value (cont'd): + // + // reinterpret_cast(newrawvalue) + // random values | random values | random values | ... new byte 2 ... + // + // reinterpret_cast(newrawvalue) & AtomicCasType::mask + // 00000000 | 00000000 | 00000000 | ... new byte 2 ... + newval = reinterpret_cast(newrawvalue) & AtomicCasType::mask; + // Journey of a 32-bit value (cont'd): + // + // old + // ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 ..... + // + // 0x000000ff + // 00000000 | 00000000 | 00000000 | 11111111 + // + // 0x000000ff << shift + // 00000000 | 11111111 | 00000000 | 00000000 + // + // ~(0x000000ff << shift) + // 11111111 | 00000000 | 11111111 | 11111111 + // + // old & ~(0x000000ff << shift) + // ..... byte 3 ..... | 00000000 | ..... byte 1 ..... | ..... byte 0 ..... + // + // newval << shift + // 00000000 | ... new byte 2 ... | 00000000 | 00000000 + // + // (old & ~(0x000000ff << shift)) | (newval << shift) + // ..... byte 3 ..... | ... new byte 2 ... | ..... byte 1 ..... | ..... byte 0 ..... + newval = (old & ~(AtomicCasType::mask << shift)) | (newval << shift); + old = atomicCAS(address_as_ui, assumed, newval); + } while (assumed != old); +} + +// It accumulates `val` into the `address` using the `func`. +// This function is thread-safe (i.e., atomic). 
+template +__device__ __forceinline__ void atomic_binary_func(ValueType* address, ValueType val, BinaryFunc func) { + ValueType observed = *address, assumed, new_value; + using CasType = typename AtomicCasType::type; + static_assert(sizeof(ValueType) == sizeof(CasType), + "ValueType and CasType must have the same size for calling atomicCAS."); + auto address_as_cas_type = reinterpret_cast(address); + do { + // Record the value used to compute new value. + assumed = observed; + + // Compute expected new value. + new_value = func(observed, val); + + // Cast to aribitrary 2-byte type to desired integer type supported by atomicCAS. + // 4 + // 8 + auto observed_as_cas_type = *reinterpret_cast(&observed); + auto new_value_as_cas_type = *reinterpret_cast(&new_value); + + // Call atomicCAS as if the 2-byte type variables are all unsigned short int. + // 4 unsigned int (or int) + // 8 unsigned long long int + auto cas_observed_as_cas_type = atomicCAS(address_as_cas_type, observed_as_cas_type, new_value_as_cas_type); + + // Cast the freshly observed value in memory back to the TwoByteType. + observed = *reinterpret_cast(&cas_observed_as_cas_type); + + // Two cases: + // 1. compare-and-swap success + // a. `address` holds `new_value` + // b. `observed` becomes the new value after the assignment. + // Thus, the following `observed != new_value` is false, + // and the loop terminates. + // 2. compare-and-swap fails + // a. `address` holds a value different from `observed`, thus, + // the `new_value` is stale. + // b. `observed` becomes the fresh value observed in `address`. + // Thus, the following (observed != new_value) is true, + // and the loop continues. In the next iteration, the + // `new_value` is computed again using the fresh `observed`. 
+ } while (observed != assumed); +} + +struct AddFunc { + template + __device__ __forceinline__ T operator()(T a, T b) const { + return a + b; + } +}; + +struct MulFunc { + template + __device__ __forceinline__ T operator()(T a, T b) const { + return a * b; + } +}; + +struct MaxFunc { + template + __device__ __forceinline__ T operator()(T a, T b) const { + return b > a ? b : a; + } +}; + +struct MinFunc { + template + __device__ __forceinline__ T operator()(T a, T b) const { + return b < a ? b : a; + } +}; + +__device__ __forceinline__ void atomic_add(int8_t* address, int8_t value) { + atomic_byte_func_with_unit32_cas(address, value, AddFunc()); +} +__device__ __forceinline__ void atomic_mul(int8_t* address, int8_t value) { + atomic_byte_func_with_unit32_cas(address, value, MulFunc()); +} +__device__ __forceinline__ void atomic_max(int8_t* address, int8_t value) { + atomic_byte_func_with_unit32_cas(address, value, MaxFunc()); +} +__device__ __forceinline__ void atomic_min(int8_t* address, int8_t value) { + atomic_byte_func_with_unit32_cas(address, value, MinFunc()); +} + +__device__ __forceinline__ void atomic_mul(half* address, half value) { +#if __CUDA_ARCH__ >= 700 + atomic_binary_func(address, value, MulFunc()); +#else + atomic_byte_func_with_unit32_cas(address, value, MulFunc()); +#endif +} +__device__ __forceinline__ void atomic_max(half* address, half value) { +#if __CUDA_ARCH__ >= 700 + atomic_binary_func(address, value, MaxFunc()); +#else + atomic_byte_func_with_unit32_cas(address, value, MaxFunc()); +#endif +} +__device__ __forceinline__ void atomic_min(half* address, half value) { +#if __CUDA_ARCH__ >= 700 + atomic_binary_func(address, value, MinFunc()); +#else + atomic_byte_func_with_unit32_cas(address, value, MinFunc()); +#endif +} + +__device__ __forceinline__ void atomic_mul(float* address, float value) { + atomic_binary_func(address, value, MulFunc()); +} +__device__ __forceinline__ void atomic_max(float* address, float value) { + 
atomic_binary_func(address, value, MaxFunc()); +} +__device__ __forceinline__ void atomic_min(float* address, float value) { + atomic_binary_func(address, value, MinFunc()); +} + +__device__ __forceinline__ void atomic_mul(double* address, double value) { + atomic_binary_func(address, value, MulFunc()); +} +__device__ __forceinline__ void atomic_max(double* address, double value) { + atomic_binary_func(address, value, MaxFunc()); +} +__device__ __forceinline__ void atomic_min(double* address, double value) { + atomic_binary_func(address, value, MinFunc()); +} + + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/cu_inc/common.cuh b/onnxruntime/core/providers/cuda/cu_inc/common.cuh index 0d9928baa86e0..bed2f677166d6 100644 --- a/onnxruntime/core/providers/cuda/cu_inc/common.cuh +++ b/onnxruntime/core/providers/cuda/cu_inc/common.cuh @@ -194,13 +194,13 @@ template <> __device__ __inline__ half _Ceil(half a) { return half(ceilf((float)a)); } template -__device__ __inline__ T _Floor(T a); +__device__ __host__ __inline__ T _Floor(T a); template <> -__device__ __inline__ float _Floor(float a) { return floorf(a); } +__device__ __host__ __inline__ float _Floor(float a) { return floorf(a); } template <> -__device__ __inline__ double _Floor(double a) { return floor(a); } +__device__ __host__ __inline__ double _Floor(double a) { return floor(a); } template <> __device__ __inline__ half _Floor(half a) { return half(floorf((float)a)); } @@ -230,13 +230,13 @@ template <> __device__ __inline__ half _Erf(half a) { return half(erff((float)a)); } template -__device__ __inline__ T _Round(T a); +__device__ __host__ __inline__ T _Round(T a); template <> -__device__ __inline__ float _Round(float a) { return rintf(a); } +__device__ __host__ __inline__ float _Round(float a) { return rintf(a); } template <> -__device__ __inline__ double _Round(double a) { return rint(a); } +__device__ __host__ __inline__ double _Round(double a) { return rint(a); } 
template <> __device__ __inline__ half _Round(half a) { @@ -438,6 +438,157 @@ __device__ __inline__ BFloat16 _Fmod(BFloat16 a, BFloat16 b) { return fmodf((float)a, (float)b); } +namespace isinf_details { +template +struct IsInfTyped { + static __device__ __inline__ bool IsInf(T a) { + // cast is needed because on non MS compilers, + // because there isinf() returns int + // and we want to avoid stupid warnings + return static_cast(isinf(a)); + } + static __device__ __inline__ bool IsInfPos(T a) { + return a == std::numeric_limits::infinity(); + } + static __device__ __inline__ bool IsInfNeg(T a) { + return a == -std::numeric_limits::infinity(); + } +}; + +template <> +struct IsInfTyped { + static __device__ __inline__ bool IsInf(half a) { + return MLFloat16::kPositiveInfinityBits == + static_cast(*reinterpret_cast(&a) & ~MLFloat16::kSignMask); + } + static __device__ __inline__ bool IsInfPos(half a) { + return MLFloat16::kPositiveInfinityBits == *reinterpret_cast(&a); + } + static __device__ __inline__ bool IsInfNeg(half a) { + return MLFloat16::kNegativeInfinityBits == *reinterpret_cast(&a); + } +}; + +template <> +struct IsInfTyped { + static __device__ __inline__ bool IsInf(BFloat16 a) { + return BFloat16::kPositiveInfinityBits == + static_cast(*reinterpret_cast(&a) & ~BFloat16::kSignMask); + } + static __device__ __inline__ bool IsInfPos(BFloat16 a) { + return BFloat16::kPositiveInfinityBits == *reinterpret_cast(&a); + } + static __device__ __inline__ bool IsInfNeg(BFloat16 a) { + return BFloat16::kNegativeInfinityBits == *reinterpret_cast(&a); + } +}; + +#if !defined(DISABLE_FLOAT8_TYPES) + +template +struct ReturnFalse { + constexpr static bool __device__ __inline__ IsInf(T) { return false; } + constexpr static bool __device__ __inline__ IsInfPos(T) { return false; } + constexpr static bool __device__ __inline__ IsInfNeg(T) { return false; } +}; + +template <> +struct IsInfTyped : ReturnFalse {}; + +template <> +struct IsInfTyped : ReturnFalse {}; + +template 
<> +struct IsInfTyped { + static __device__ __inline__ bool IsInf(Float8E5M2 a) { + return a.val == 0b01111100 || a.val == 0b11111100; + } + static __device__ __inline__ bool IsInfPos(Float8E5M2 a) { + return a.val == 0b01111100; + } + static __device__ __inline__ bool IsInfNeg(Float8E5M2 a) { + return a.val == 0b11111100; + } +}; + +template <> +struct IsInfTyped : ReturnFalse {}; + +#endif +} // namespace isinf_details + +template +struct _IsInf { + __device__ __inline__ bool operator()(T a) const { + if constexpr (detect_positive && detect_negative) { + return isinf_details::IsInfTyped::IsInf(a); + } else if constexpr (detect_positive) { + return isinf_details::IsInfTyped::IsInfPos(a); + } else if constexpr (detect_negative) { + return isinf_details::IsInfTyped::IsInfNeg(a); + } else { + return false; + } + } +}; + +// float and double +template +struct _IsNan { + __device__ __inline__ bool operator()(T a) const { + return isnan(a); + } +}; + +template <> +struct _IsNan { + __device__ __inline__ bool operator()(half a) const { + return static_cast(*reinterpret_cast(&a) & ~MLFloat16::kSignMask) + > MLFloat16::kPositiveInfinityBits; + } +}; + +template <> +struct _IsNan { + __device__ __inline__ bool operator()(BFloat16 a) const { + return static_cast(*reinterpret_cast(&a) & ~BFloat16::kSignMask) + > BFloat16::kPositiveInfinityBits; + } +}; + +#if !defined(DISABLE_FLOAT8_TYPES) + +template<> +struct _IsNan { + __device__ __inline__ bool operator()(Float8E4M3FN a) const { + return (*reinterpret_cast(&a) & 0x7f) == 0x7f; + } +}; + +template<> +struct _IsNan { + __device__ __inline__ bool operator()(Float8E4M3FNUZ a) const { + return *reinterpret_cast(&a) == 0x80; + } +}; + +template<> +struct _IsNan { + __device__ __inline__ bool operator()(Float8E5M2 a) const { + uint8_t c = *reinterpret_cast(&a); + return ((c & 0x7c) == 0x7c) && ((c & 0x03) != 0x00); + } +}; + +template<> +struct _IsNan { + __device__ __inline__ bool operator()(Float8E5M2FNUZ a) const { + return 
*reinterpret_cast(&a) == 0x80; + } +}; + +#endif + // We would like to use 64-bit integer to support large matrices. However, CUDA seems to support only 32-bit integer // For now, use int32_t to ensure that both Linux and Windows see this as 32 bit integer type. #ifndef CUDA_LONG diff --git a/onnxruntime/core/providers/cuda/cuda_common.h b/onnxruntime/core/providers/cuda/cuda_common.h index e9941ce743bc3..61da125b40953 100644 --- a/onnxruntime/core/providers/cuda/cuda_common.h +++ b/onnxruntime/core/providers/cuda/cuda_common.h @@ -70,6 +70,15 @@ class ToCudaType { } }; +template <> +class ToCudaType { + public: + typedef Float8E4M3FNUZ MappedType; + static MappedType FromFloat(float f) { + return MappedType(f); + } +}; + template <> class ToCudaType { public: @@ -79,6 +88,15 @@ class ToCudaType { } }; +template <> +class ToCudaType { + public: + typedef Float8E5M2FNUZ MappedType; + static MappedType FromFloat(float f) { + return MappedType(f); + } +}; + #endif inline bool CalculateFdmStrides(gsl::span p, const std::vector& dims) { @@ -141,8 +159,7 @@ class HalfGemmOptions { } #else cublasMath_t GetMathMode() const { - // CublasMathModeSetter will check whether device has tensor cores later. - return CUBLAS_TENSOR_OP_MATH; + return CUBLAS_DEFAULT_MATH; } cudaDataType GetComputeType() const { diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 3fc4ed355a12b..05d9f3b5a1e8f 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -3,6 +3,7 @@ // Licensed under the MIT License. 
#include "core/common/inlined_containers.h" +#include "core/common/parse_string.h" #include "core/providers/shared_library/provider_api.h" #include "core/platform/env_var_utils.h" #include "core/providers/cuda/cuda_execution_provider.h" @@ -11,6 +12,7 @@ #include "core/providers/cuda/cuda_fwd.h" #include "core/providers/cuda/gpu_data_transfer.h" #include "core/providers/cuda/cuda_profiler.h" +#include "core/session/onnxruntime_run_options_config_keys.h" #ifndef USE_CUDA_MINIMAL #ifndef DISABLE_CONTRIB_OPS @@ -190,31 +192,60 @@ CUDAExecutionProvider::PerThreadContext::~PerThreadContext() { #endif } -bool CUDAExecutionProvider::PerThreadContext::IsGraphCaptureAllowed() const { - return regular_run_count_before_graph_capture_ >= min_num_runs_before_cuda_graph_capture_; +bool CUDAExecutionProvider::PerThreadContext::IsGraphCaptureAllowed( + CudaGraphAnnotation_t cuda_graph_annotation_id) const { + if (!IsGraphCaptureAllowedOnRun(cuda_graph_annotation_id)) { + return false; + } + if (graph_id_to_run_count_.find(cuda_graph_annotation_id) == graph_id_to_run_count_.end()) { + return false; + } + return graph_id_to_run_count_.at(cuda_graph_annotation_id) >= min_num_runs_before_cuda_graph_capture_; +} + +bool CUDAExecutionProvider::PerThreadContext::IsGraphCaptureAllowedOnRun( + CudaGraphAnnotation_t cuda_graph_annotation_id) const { + return cuda_graph_.IsGraphCaptureAllowedOnRun(cuda_graph_annotation_id); +} + +CudaGraphAnnotation_t CUDAExecutionProvider::PerThreadContext::GetCudaGraphAnnotationId( + const onnxruntime::RunOptions& run_options) const { + auto graph_annotation_str = + run_options.GetConfigOptions().GetConfigEntry(kOrtRunOptionsConfigCudaGraphAnnotation); + // If graph annotation is not provided, fall back to the one cuda graph per session behavior + CudaGraphAnnotation_t cuda_graph_annotation_id = 0; + if (graph_annotation_str.has_value()) { + ORT_ENFORCE(TryParseStringWithClassicLocale(*graph_annotation_str, cuda_graph_annotation_id), + "Failed to parse the 
cuda graph annotation id: ", + *graph_annotation_str); + } + + return cuda_graph_annotation_id; } -void CUDAExecutionProvider::PerThreadContext::CaptureBegin() { - cuda_graph_.Reset(); - cuda_graph_.CaptureBegin(); +void CUDAExecutionProvider::PerThreadContext::CaptureBegin(CudaGraphAnnotation_t cuda_graph_annotation_id) { + cuda_graph_.CaptureBegin(cuda_graph_annotation_id); } -void CUDAExecutionProvider::PerThreadContext::CaptureEnd() { - cuda_graph_.CaptureEnd(); - is_graph_captured_ = true; +void CUDAExecutionProvider::PerThreadContext::CaptureEnd(CudaGraphAnnotation_t cuda_graph_annotation_id) { + cuda_graph_.CaptureEnd(cuda_graph_annotation_id); } -bool CUDAExecutionProvider::PerThreadContext::IsGraphCaptured() const { - return is_graph_captured_; +bool CUDAExecutionProvider::PerThreadContext::IsGraphCaptured(CudaGraphAnnotation_t graph_annotation_id) const { + return cuda_graph_.IsGraphCaptured(graph_annotation_id); } -Status CUDAExecutionProvider::PerThreadContext::ReplayGraph() { - ORT_ENFORCE(IsGraphCaptured()); - return cuda_graph_.Replay(); +Status CUDAExecutionProvider::PerThreadContext::ReplayGraph(CudaGraphAnnotation_t graph_annotation_id) { + return cuda_graph_.Replay(graph_annotation_id); } -void CUDAExecutionProvider::PerThreadContext::IncrementRegularRunCountBeforeGraphCapture() { - ++regular_run_count_before_graph_capture_; +void CUDAExecutionProvider::PerThreadContext::IncrementRegularRunCountBeforeGraphCapture( + CudaGraphAnnotation_t cuda_graph_annotation_id) { + if (graph_id_to_run_count_.find(cuda_graph_annotation_id) == graph_id_to_run_count_.end()) { + graph_id_to_run_count_[cuda_graph_annotation_id] = 1; + return; + } + graph_id_to_run_count_[cuda_graph_annotation_id]++; } void OverrideTunableOpInfoByEnv(CUDAExecutionProviderInfo& info) { @@ -386,25 +417,28 @@ Status CUDAExecutionProvider::Sync() const { return Status::OK(); } -Status CUDAExecutionProvider::OnRunStart() { +Status CUDAExecutionProvider::OnRunStart(const 
onnxruntime::RunOptions& run_options) { // always set CUDA device when session::Run() in case it runs in a worker thread CUDA_RETURN_IF_ERROR(cudaSetDevice(GetDeviceId())); - if (IsGraphCaptureEnabled() && GetPerThreadContext().IsGraphCaptureAllowed() && !GetPerThreadContext().IsGraphCaptured()) { + CudaGraphAnnotation_t cuda_graph_annotation_id = GetPerThreadContext().GetCudaGraphAnnotationId(run_options); + if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptured(cuda_graph_annotation_id) && + GetPerThreadContext().IsGraphCaptureAllowed(cuda_graph_annotation_id)) { LOGS(*GetLogger(), INFO) << "Capturing the cuda graph for this model"; - GetPerThreadContext().CaptureBegin(); + GetPerThreadContext().CaptureBegin(cuda_graph_annotation_id); } return Status::OK(); } -Status CUDAExecutionProvider::OnRunEnd(bool sync_stream) { - if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptured()) { - if (GetPerThreadContext().IsGraphCaptureAllowed()) { - GetPerThreadContext().CaptureEnd(); +Status CUDAExecutionProvider::OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) { + CudaGraphAnnotation_t cuda_graph_annotation_id = GetPerThreadContext().GetCudaGraphAnnotationId(run_options); + if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptured(cuda_graph_annotation_id)) { + if (GetPerThreadContext().IsGraphCaptureAllowed(cuda_graph_annotation_id)) { + GetPerThreadContext().CaptureEnd(cuda_graph_annotation_id); // CUDA work issued to a capturing stream doesn’t actually run on the GPU, // so run the captured graph here to actually execute the work. 
- ORT_RETURN_IF_ERROR(GetPerThreadContext().ReplayGraph()); + ORT_RETURN_IF_ERROR(GetPerThreadContext().ReplayGraph(cuda_graph_annotation_id)); } else { - GetPerThreadContext().IncrementRegularRunCountBeforeGraphCapture(); + GetPerThreadContext().IncrementRegularRunCountBeforeGraphCapture(cuda_graph_annotation_id); } } @@ -433,12 +467,12 @@ bool CUDAExecutionProvider::IsGraphCaptureEnabled() const { return info_.enable_cuda_graph; } -bool CUDAExecutionProvider::IsGraphCaptured() const { - return GetPerThreadContext().IsGraphCaptured(); +bool CUDAExecutionProvider::IsGraphCaptured(int graph_annotation_id) const { + return GetPerThreadContext().IsGraphCaptured(graph_annotation_id); } -Status CUDAExecutionProvider::ReplayGraph() { - return GetPerThreadContext().ReplayGraph(); +Status CUDAExecutionProvider::ReplayGraph(int graph_annotation_id) { + return GetPerThreadContext().ReplayGraph(graph_annotation_id); } namespace cuda { @@ -722,6 +756,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, uint32_t, Cast); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, uint64_t, Cast); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, bool, Cast); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, IsNaN); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 2, 10, float, Pad); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 2, 10, double, Pad); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 2, 10, MLFloat16, Pad); @@ -830,6 +865,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, 
kOnnxDomain, 10, MLFloat16, ThresholdedRelu); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 10, TopK); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 12, Mod); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 19, IsInf); // opset 11 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, Compress); @@ -913,7 +949,6 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDom // OpSet 12 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, Clip); - class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, float, MaxPool); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, double, MaxPool); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, MLFloat16, MaxPool); @@ -989,6 +1024,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Sqrt); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Sqrt); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Sqrt); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, BFloat16, Sqrt); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Log); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Log); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Log); @@ -1046,7 +1082,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDom class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 14, Shape); class 
ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Size); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Transpose); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, ScatterElements); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 15, ScatterElements); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, Slice); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, Slice); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Softmax); @@ -1061,6 +1097,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, U class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Concat); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Gather); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, GatherElements); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 19, IsNaN); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, MatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, MatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, MatMul); @@ -1108,11 +1145,11 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceSumSquare); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, GatherND); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Dropout); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Resize); -class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Resize); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Resize); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, Resize); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint8_t, Resize); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, float, Resize); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, double, Resize); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, MLFloat16, Resize); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, int32_t, Resize); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, uint8_t, Resize); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, If); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, Loop); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Flatten); @@ -1201,9 +1238,12 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, LSTM); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, LSTM); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, LSTM); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, float, BatchNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, double, BatchNormalization); -class 
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, MLFloat16, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME( + kCudaExecutionProvider, kOnnxDomain, 14, 14, float, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME( + kCudaExecutionProvider, kOnnxDomain, 14, 14, double, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME( + kCudaExecutionProvider, kOnnxDomain, 14, 14, MLFloat16, BatchNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, ReduceMin); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, ReduceMin); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceMin); @@ -1254,6 +1294,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, float, LessOrEqual); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, double, LessOrEqual); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterElements); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, float, GridSample); // Opset 17 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 17, float, LayerNormalization); @@ -1269,10 +1311,16 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceMax); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMax); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, 
kOnnxDomain, 18, int64_t, ReduceMax); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, ScatterElements); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, Pad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Pad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, bool, Pad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, Resize); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Resize); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Resize); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, Resize); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, uint8_t, Resize); // Opset 19 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, float, Cast); @@ -1326,6 +1374,13 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, S class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, Shape); #endif +// Opset 20 +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, float, Gelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, double, Gelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, MLFloat16, Gelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, IsInf); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, IsNaN); + template <> KernelCreateInfo BuildKernelCreateInfo() { return {}; @@ -1510,6 +1565,7 @@ static Status RegisterCudaKernels(KernelRegistry& 
kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1722,6 +1778,8 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, // opset 11 BuildKernelCreateInfo, @@ -1880,6 +1938,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1933,11 +1992,12 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1999,11 +2059,11 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2092,9 +2152,12 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2138,6 +2201,8 @@ static Status RegisterCudaKernels(KernelRegistry& 
kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 17 BuildKernelCreateInfo, @@ -2159,10 +2224,16 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 19 BuildKernelCreateInfo, @@ -2216,6 +2287,13 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + + // Opset 20 + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, #endif }; diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.h b/onnxruntime/core/providers/cuda/cuda_execution_provider.h index d0bb2321edf0a..f53779058a8af 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.h @@ -29,9 +29,9 @@ class CUDAExecutionProvider : public IExecutionProvider { Status Sync() const override; - Status OnRunStart() override; + Status OnRunStart(const onnxruntime::RunOptions& run_options) override; - Status OnRunEnd(bool sync_stream) override; + Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override; DataLayout GetPreferredLayout() const override; @@ -78,6 +78,7 @@ class CUDAExecutionProvider : public IExecutionProvider { bool GetCudnnConv1dPadToNc1d() const { return info_.cudnn_conv1d_pad_to_nc1d; } bool IsSkipLayerNormInStrictMode() const { return info_.enable_skip_layer_norm_strict_mode; } bool IsNHWCPreferred() const { return info_.prefer_nhwc; } + bool UseTF32() const { return info_.use_tf32; } ProviderOptions 
GetProviderOptions() const override { return CUDAExecutionProviderInfo::ToProviderOptions(info_); @@ -91,8 +92,8 @@ class CUDAExecutionProvider : public IExecutionProvider { std::unique_ptr GetProfiler() override; bool IsGraphCaptureEnabled() const override; - bool IsGraphCaptured() const override; - Status ReplayGraph() override; + bool IsGraphCaptured(CudaGraphAnnotation_t graph_annotation_id) const override; + Status ReplayGraph(CudaGraphAnnotation_t graph_annotation_id) override; void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const override; OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override; std::vector CreatePreferredAllocators() override; @@ -114,6 +115,7 @@ class CUDAExecutionProvider : public IExecutionProvider { PerThreadContext(OrtDevice::DeviceId device_id, cudaStream_t stream, size_t cuda_mem_limit, ArenaExtendStrategy arena_extend_strategy, CUDAExecutionProviderExternalAllocatorInfo external_alloc_info, OrtArenaCfg* arena_cfg); ~PerThreadContext(); + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(PerThreadContext); cublasHandle_t CublasHandle() const { return cublas_handle_; @@ -129,41 +131,33 @@ class CUDAExecutionProvider : public IExecutionProvider { template const T* GetConstOnes(size_t count, cudaStream_t stream) { - constexpr bool is_float = std::is_same::value; - constexpr bool is_double = std::is_same::value; - constexpr bool is_half = std::is_same::value; - constexpr bool is_BFloat16 = std::is_same::value; -#if !defined(DISABLE_FLOAT8_TYPES) - constexpr bool is_Float8E4M3FN = std::is_same::value; - constexpr bool is_Float8E5M2 = std::is_same::value; -#endif - if (is_float) { + if constexpr (std::is_same::value) { if (!constant_ones_float_) { constant_ones_float_ = cuda::CreateConstantOnes(); } return reinterpret_cast(constant_ones_float_->GetBuffer(stream, count)); - } else if (is_double) { + } else if constexpr (std::is_same::value) { if (!constant_ones_double_) { 
constant_ones_double_ = cuda::CreateConstantOnes(); } return reinterpret_cast(constant_ones_double_->GetBuffer(stream, count)); - } else if (is_half) { + } else if constexpr (std::is_same::value) { if (!constant_ones_half_) { constant_ones_half_ = cuda::CreateConstantOnes(); } return reinterpret_cast(constant_ones_half_->GetBuffer(stream, count)); - } else if (is_BFloat16) { + } else if constexpr (std::is_same::value) { if (!constant_ones_bfloat16_) { constant_ones_bfloat16_ = cuda::CreateConstantOnes(); } return reinterpret_cast(constant_ones_bfloat16_->GetBuffer(stream, count)); #if !defined(DISABLE_FLOAT8_TYPES) - } else if (is_Float8E4M3FN) { + } else if constexpr (std::is_same::value) { if (!constant_ones_float8e4m3fn_) { constant_ones_float8e4m3fn_ = cuda::CreateConstantOnes(); } return reinterpret_cast(constant_ones_float8e4m3fn_->GetBuffer(stream, count)); - } else if (is_Float8E5M2) { + } else if constexpr (std::is_same::value) { if (!constant_ones_float8e5m2_) { constant_ones_float8e5m2_ = cuda::CreateConstantOnes(); } @@ -174,12 +168,14 @@ class CUDAExecutionProvider : public IExecutionProvider { } } - bool IsGraphCaptureAllowed() const; - void CaptureBegin(); - void CaptureEnd(); - bool IsGraphCaptured() const; - Status ReplayGraph(); - void IncrementRegularRunCountBeforeGraphCapture(); + bool IsGraphCaptureAllowed(CudaGraphAnnotation_t cuda_graph_annotation_id) const; + bool IsGraphCaptureAllowedOnRun(CudaGraphAnnotation_t cuda_graph_annotation_id) const; + void CaptureBegin(CudaGraphAnnotation_t cuda_graph_annotation_id); + void CaptureEnd(CudaGraphAnnotation_t cuda_graph_annotation_id); + bool IsGraphCaptured(CudaGraphAnnotation_t cuda_graph_annotation_id) const; + CudaGraphAnnotation_t GetCudaGraphAnnotationId(const onnxruntime::RunOptions& run_options) const; + Status ReplayGraph(CudaGraphAnnotation_t cuda_graph_annotation_id); + void IncrementRegularRunCountBeforeGraphCapture(CudaGraphAnnotation_t cuda_graph_annotation_id); private: cublasHandle_t 
cublas_handle_ = nullptr; @@ -198,8 +194,8 @@ class CUDAExecutionProvider : public IExecutionProvider { // Cuda graph with multi threads will be supported in the future, so cuda_graph_ // is put under PerThreadContext. CUDAGraph cuda_graph_; - bool is_graph_captured_ = false; - int regular_run_count_before_graph_capture_ = 0; + // Map of graph id to regular_run_count_before_graph_capture + std::unordered_map graph_id_to_run_count_; // There is chance that the second regular run allocates GPU memory for causes like: // (1) memory pattern is enabled. (2) arena allocation for stream. diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc index 7b507296d5982..c96381e3e68b1 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc @@ -31,8 +31,10 @@ constexpr const char* kTunableOpEnable = "tunable_op_enable"; constexpr const char* kTunableOpTuningEnable = "tunable_op_tuning_enable"; constexpr const char* kTunableOpMaxTuningDurationMs = "tunable_op_max_tuning_duration_ms"; constexpr const char* kEnableSkipLayerNormStrictMode = "enable_skip_layer_norm_strict_mode"; -constexpr const char* kPreferNCHWMode = "prefer_nhwc"; -constexpr const char* KUseEPLevelUnifiedStream = "use_ep_level_unified_stream"; +constexpr const char* kPreferNHWCMode = "prefer_nhwc"; +constexpr const char* kUseEPLevelUnifiedStream = "use_ep_level_unified_stream"; +constexpr const char* kUseTF32 = "use_tf32"; + } // namespace provider_option_names } // namespace cuda @@ -112,8 +114,9 @@ CUDAExecutionProviderInfo CUDAExecutionProviderInfo::FromProviderOptions(const P .AddAssignmentToReference(cuda::provider_option_names::kEnableCudaGraph, info.enable_cuda_graph) .AddAssignmentToReference(cuda::provider_option_names::kCudnnConv1dPadToNc1d, info.cudnn_conv1d_pad_to_nc1d) 
.AddAssignmentToReference(cuda::provider_option_names::kEnableSkipLayerNormStrictMode, info.enable_skip_layer_norm_strict_mode) - .AddAssignmentToReference(cuda::provider_option_names::kPreferNCHWMode, info.prefer_nhwc) - .AddAssignmentToReference(cuda::provider_option_names::KUseEPLevelUnifiedStream, info.use_ep_level_unified_stream) + .AddAssignmentToReference(cuda::provider_option_names::kPreferNHWCMode, info.prefer_nhwc) + .AddAssignmentToReference(cuda::provider_option_names::kUseEPLevelUnifiedStream, info.use_ep_level_unified_stream) + .AddAssignmentToReference(cuda::provider_option_names::kUseTF32, info.use_tf32) .AddValueParser( cuda::provider_option_names::kTunableOpEnable, [&info](const std::string& value_str) -> Status { @@ -164,8 +167,9 @@ ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const CUDAExecution {cuda::provider_option_names::kTunableOpTuningEnable, MakeStringWithClassicLocale(info.tunable_op.tuning_enable)}, {cuda::provider_option_names::kTunableOpMaxTuningDurationMs, MakeStringWithClassicLocale(info.tunable_op.max_tuning_duration_ms)}, {cuda::provider_option_names::kEnableSkipLayerNormStrictMode, MakeStringWithClassicLocale(info.enable_skip_layer_norm_strict_mode)}, - {cuda::provider_option_names::kPreferNCHWMode, MakeStringWithClassicLocale(info.prefer_nhwc)}, - {cuda::provider_option_names::KUseEPLevelUnifiedStream, MakeStringWithClassicLocale(info.use_ep_level_unified_stream)}, + {cuda::provider_option_names::kPreferNHWCMode, MakeStringWithClassicLocale(info.prefer_nhwc)}, + {cuda::provider_option_names::kUseEPLevelUnifiedStream, MakeStringWithClassicLocale(info.use_ep_level_unified_stream)}, + {cuda::provider_option_names::kUseTF32, MakeStringWithClassicLocale(info.use_tf32)}, }; return options; @@ -185,8 +189,9 @@ ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const OrtCUDAProvid {cuda::provider_option_names::kTunableOpEnable, MakeStringWithClassicLocale(info.tunable_op_enable)}, 
{cuda::provider_option_names::kTunableOpTuningEnable, MakeStringWithClassicLocale(info.tunable_op_tuning_enable)}, {cuda::provider_option_names::kTunableOpMaxTuningDurationMs, MakeStringWithClassicLocale(info.tunable_op_max_tuning_duration_ms)}, - {cuda::provider_option_names::kPreferNCHWMode, MakeStringWithClassicLocale(info.prefer_nhwc)}, - {cuda::provider_option_names::KUseEPLevelUnifiedStream, MakeStringWithClassicLocale(info.use_ep_level_unified_stream)}, + {cuda::provider_option_names::kPreferNHWCMode, MakeStringWithClassicLocale(info.prefer_nhwc)}, + {cuda::provider_option_names::kUseEPLevelUnifiedStream, MakeStringWithClassicLocale(info.use_ep_level_unified_stream)}, + {cuda::provider_option_names::kUseTF32, MakeStringWithClassicLocale(info.use_tf32)}, }; return options; diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h index b286f5a9161b0..1cac3d1513698 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h @@ -76,6 +76,9 @@ struct CUDAExecutionProviderInfo { bool use_ep_level_unified_stream{false}; + // By default, enable TF32 to speed up float GEMM/MatMul or cuDNN convolution of float matrices. 
+ bool use_tf32{true}; + static CUDAExecutionProviderInfo FromProviderOptions(const ProviderOptions& options); static ProviderOptions ToProviderOptions(const CUDAExecutionProviderInfo& info); static ProviderOptions ToProviderOptions(const OrtCUDAProviderOptionsV2& info); @@ -83,12 +86,37 @@ struct CUDAExecutionProviderInfo { } // namespace onnxruntime template <> -struct std::hash<::onnxruntime::cuda::TunableOpInfo> { - size_t operator()(const ::onnxruntime::cuda::TunableOpInfo& info) const { - size_t seed_and_value{0xbc9f1d34}; - onnxruntime::HashCombine(info.enable, seed_and_value); - onnxruntime::HashCombine(info.tuning_enable, seed_and_value); - onnxruntime::HashCombine(info.max_tuning_duration_ms, seed_and_value); - return seed_and_value; +struct std::hash<::onnxruntime::CUDAExecutionProviderInfo> { + size_t operator()(const ::onnxruntime::CUDAExecutionProviderInfo& info) const { + size_t value{0xbc9f1d34}; // seed + + // Bits: device_id (16), arena_extend_strategy/cudnn_conv_algo_search (reserved 2), boolean options (1 each) + size_t data = static_cast(info.device_id) ^ + (static_cast(info.arena_extend_strategy) << 16) ^ + (static_cast(info.cudnn_conv_algo_search) << 18) ^ + (static_cast(info.do_copy_in_default_stream) << 20) ^ + (static_cast(info.has_user_compute_stream) << 21) ^ + (static_cast(info.cudnn_conv_use_max_workspace) << 22) ^ + (static_cast(info.enable_cuda_graph) << 23) ^ + (static_cast(info.tunable_op.enable) << 24) ^ + (static_cast(info.tunable_op.tuning_enable) << 25) ^ + (static_cast(info.cudnn_conv1d_pad_to_nc1d) << 26) ^ + (static_cast(info.enable_skip_layer_norm_strict_mode) << 27) ^ + (static_cast(info.prefer_nhwc) << 28) ^ + (static_cast(info.use_ep_level_unified_stream) << 29) ^ + (static_cast(info.use_tf32) << 30); + onnxruntime::HashCombine(data, value); + + onnxruntime::HashCombine(info.gpu_mem_limit, value); + onnxruntime::HashCombine(info.tunable_op.max_tuning_duration_ms, value); + + // Memory pointers + 
onnxruntime::HashCombine(reinterpret_cast(info.user_compute_stream), value); + onnxruntime::HashCombine(reinterpret_cast(info.external_allocator_info.alloc), value); + onnxruntime::HashCombine(reinterpret_cast(info.external_allocator_info.free), value); + onnxruntime::HashCombine(reinterpret_cast(info.external_allocator_info.empty_cache), value); + + // The default memory arena cfg is not used in hashing right now. + return value; } }; diff --git a/onnxruntime/core/providers/cuda/cuda_graph.cc b/onnxruntime/core/providers/cuda/cuda_graph.cc index 230d664391611..8353c654681fc 100644 --- a/onnxruntime/core/providers/cuda/cuda_graph.cc +++ b/onnxruntime/core/providers/cuda/cuda_graph.cc @@ -9,17 +9,44 @@ namespace onnxruntime { -CUDAGraph::CUDAGraph(cudaStream_t stream) : stream_(stream) { +CudaGraphSet::~CudaGraphSet() { + Clear(); } -void CUDAGraph::SetStream(cudaStream_t stream) { +void CudaGraphSet::Clear() { + for (auto& it : cuda_graphs_) { + CUDA_CALL_THROW(cudaGraphExecDestroy(it.second)); + } + cuda_graphs_.clear(); +} + +bool CudaGraphSet::Contains(CudaGraphAnnotation_t cuda_graph_annotation_id) const { + return cuda_graphs_.find(cuda_graph_annotation_id) != cuda_graphs_.end(); +} + +void CudaGraphSet::Put(CudaGraphAnnotation_t cuda_graph_annotation_id, cudaGraphExec_t graph_exec) { + ORT_ENFORCE(!Contains(cuda_graph_annotation_id)); + cuda_graphs_.emplace(cuda_graph_annotation_id, graph_exec); +} + +cudaGraphExec_t CudaGraphSet::Get(CudaGraphAnnotation_t cuda_graph_annotation_id) const { + ORT_ENFORCE(Contains(cuda_graph_annotation_id)); + return cuda_graphs_.at(cuda_graph_annotation_id); +} + +CUDAGraphManager::CUDAGraphManager(cudaStream_t stream) : stream_(stream) { +} + +void CUDAGraphManager::SetStream(cudaStream_t stream) { stream_ = stream; } -void CUDAGraph::CaptureBegin() { - ORT_ENFORCE(!has_graph_exec_, - "This cuda graph has already captured a graph. 
" - "Create a new instance to capture a new graph."); +void CUDAGraphManager::CaptureBegin(CudaGraphAnnotation_t cuda_graph_annotation_id) { + ORT_ENFORCE(IsGraphCaptureAllowedOnRun(cuda_graph_annotation_id)); + + ORT_ENFORCE(!cuda_graph_set_.Contains(cuda_graph_annotation_id), + "Trying to capture a graph with annotation id ", cuda_graph_annotation_id, + " that already used. Please use a different annotation id."); CUDA_CALL_THROW(cudaStreamSynchronize(stream_)); // For now cuda graph can only work with a single thread. In the future, we @@ -29,40 +56,48 @@ void CUDAGraph::CaptureBegin() { CUDA_CALL_THROW(cudaStreamBeginCapture(stream_, cudaStreamCaptureModeGlobal)); } -void CUDAGraph::CaptureEnd() { - CUDA_CALL_THROW(cudaStreamEndCapture(stream_, &graph_)); - if (graph_ == NULL) { +void CUDAGraphManager::CaptureEnd(CudaGraphAnnotation_t cuda_graph_annotation_id) { + cudaGraph_t graph = NULL; + CUDA_CALL_THROW(cudaStreamEndCapture(stream_, &graph)); + if (graph == NULL) { ORT_THROW("CUDAGraph::CaptureEnd: graph_ is NULL"); } - has_graph_ = true; - CUDA_CALL_THROW(cudaGraphInstantiate(&graph_exec_, graph_, NULL, NULL, 0)); - has_graph_exec_ = true; - CUDA_CALL_THROW(cudaGraphDestroy(graph_)); - has_graph_ = false; + cudaGraphExec_t graph_exec = NULL; + CUDA_CALL_THROW(cudaGraphInstantiate(&graph_exec, graph, NULL, NULL, 0)); + CUDA_CALL_THROW(cudaGraphDestroy(graph)); + + // Currently all the captured graphs will be tied to the session's lifecycle + // TODO(wy): Addd an interface to free captured graphs + cuda_graph_set_.Put(cuda_graph_annotation_id, graph_exec); } -Status CUDAGraph::Replay() { +Status CUDAGraphManager::Replay(CudaGraphAnnotation_t cuda_graph_annotation_id) { // Although this function is not thread safe, the lock is not needed here because // CUDA EP maintains a separate cuda graph per thread - LOGS_DEFAULT(INFO) << "Replaying CUDA graph on stream " << stream_; - CUDA_RETURN_IF_ERROR(cudaGraphLaunch(graph_exec_, stream_)); + LOGS_DEFAULT(INFO) << 
"Replaying CUDA graph on stream " << stream_ << " with cuda_graph_annotation_id " + << cuda_graph_annotation_id; + + cudaGraphExec_t graph_exec = cuda_graph_set_.Get(cuda_graph_annotation_id); + CUDA_RETURN_IF_ERROR(cudaGraphLaunch(graph_exec, stream_)); + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream_)); return Status::OK(); } -void CUDAGraph::Reset() { - if (has_graph_) { - CUDA_CALL_THROW(cudaGraphDestroy(graph_)); - has_graph_ = false; - } - if (has_graph_exec_) { - CUDA_CALL_THROW(cudaGraphExecDestroy(graph_exec_)); - has_graph_exec_ = false; - } +bool CUDAGraphManager::IsGraphCaptureAllowedOnRun(CudaGraphAnnotation_t cuda_graph_annotation_id) const { + return cuda_graph_annotation_id != kCudaGraphAnnotationSkip; +} + +bool CUDAGraphManager::IsGraphCaptured(CudaGraphAnnotation_t cuda_graph_annotation_id) const { + return cuda_graph_set_.Contains(cuda_graph_annotation_id); +} + +void CUDAGraphManager::Reset() { + cuda_graph_set_.Clear(); } -CUDAGraph::~CUDAGraph() { +CUDAGraphManager::~CUDAGraphManager() { Reset(); } diff --git a/onnxruntime/core/providers/cuda/cuda_graph.h b/onnxruntime/core/providers/cuda/cuda_graph.h index 9bcefcc64ea77..064994c1f14ae 100644 --- a/onnxruntime/core/providers/cuda/cuda_graph.h +++ b/onnxruntime/core/providers/cuda/cuda_graph.h @@ -3,33 +3,55 @@ #pragma once +#include + #include "core/common/common.h" #include "core/platform/ort_mutex.h" #include "core/providers/cuda/cuda_pch.h" namespace onnxruntime { -using CaptureId_t = unsigned long long; +using CudaGraphAnnotation_t = int; +using CudaGraphSet_t = std::unordered_map; + +constexpr CudaGraphAnnotation_t kCudaGraphAnnotationSkip = -1; +constexpr CudaGraphAnnotation_t kCudaGraphAnnotationDefault = 0; + +struct CudaGraphSet { + CudaGraphSet(){}; + ~CudaGraphSet(); -struct CUDAGraph { - CUDAGraph(){}; - CUDAGraph(cudaStream_t stream); - ~CUDAGraph(); + void Clear(); + bool Contains(CudaGraphAnnotation_t cuda_graph_annotation_id) const; + void Put(CudaGraphAnnotation_t 
cuda_graph_annotation_id, cudaGraphExec_t graph_exec); + cudaGraphExec_t Get(CudaGraphAnnotation_t cuda_graph_annotation_id) const; + + private: + CudaGraphSet_t cuda_graphs_; +}; + +struct CUDAGraphManager { + CUDAGraphManager(){}; + CUDAGraphManager(cudaStream_t stream); + ~CUDAGraphManager(); void SetStream(cudaStream_t stream); - void CaptureBegin(); - void CaptureEnd(); - Status Replay(); + void CaptureBegin(CudaGraphAnnotation_t cuda_graph_annotation_id); + void CaptureEnd(CudaGraphAnnotation_t cuda_graph_annotation_id); + Status Replay(CudaGraphAnnotation_t cuda_graph_annotation_id); + void Reset(); - private: - cudaGraph_t graph_ = NULL; - cudaGraphExec_t graph_exec_ = NULL; + bool IsGraphCaptureAllowedOnRun(CudaGraphAnnotation_t cuda_graph_annotation_id) const; + bool IsGraphCaptured(CudaGraphAnnotation_t cuda_graph_annotation_id) const; - bool has_graph_ = false; - bool has_graph_exec_ = false; + private: + CudaGraphSet cuda_graph_set_; + CudaGraphAnnotation_t cuda_graph_annotation_id_ = kCudaGraphAnnotationDefault; cudaStream_t stream_ = nullptr; // Does not own the stream }; +using CUDAGraph = CUDAGraphManager; + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/cuda_kernel.h b/onnxruntime/core/providers/cuda/cuda_kernel.h index e3106e41e77c8..288da23f35ec8 100644 --- a/onnxruntime/core/providers/cuda/cuda_kernel.h +++ b/onnxruntime/core/providers/cuda/cuda_kernel.h @@ -90,6 +90,10 @@ class CudaKernel : public OpKernel { return stream->cublas_handle_; } + bool UseTF32() const { + return provider_->UseTF32(); + } + tunable::CudaTuningContext* GetTuningContext() const { return static_cast(provider_->GetTuningContext()); } diff --git a/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc b/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc index f416caecd115f..8fdcaacdb0f29 100644 --- a/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc +++ b/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc @@ -18,10 +18,14 @@ namespace 
onnxruntime::cuda { class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, float, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, double, + BatchNormalization); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, MLFloat16, BatchNormalization); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, float, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, double, + BatchNormalization); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, MLFloat16, BatchNormalization); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 10, float, @@ -72,12 +76,30 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalN class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12, MLFloat16, MaxPool); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, float, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, double, + BatchNormalization); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, MLFloat16, BatchNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15, float, BatchNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15, double, + BatchNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15, MLFloat16, BatchNormalization); +class 
ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 10, DepthToSpace); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11, 12, DepthToSpace); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 13, DepthToSpace); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 12, SpaceToDepth); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 13, SpaceToDepth); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME( + kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 12, float, LRN); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME( + kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 12, double, LRN); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME( + kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 12, MLFloat16, LRN); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 13, float, LRN); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 13, double, LRN); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 13, MLFloat16, LRN); Status RegisterCudaNhwcKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn nhwc_function_table[] = { @@ -86,18 +108,26 @@ Status RegisterCudaNhwcKernels(KernelRegistry& kernel_registry) { kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, MLFloat16, BatchNormalization)>, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + 
BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : nhwc_function_table) { diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc index 892e8d5329eba..103c79c93b2ca 100644 --- a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc +++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc @@ -225,6 +225,7 @@ struct CUDA_Provider : Provider { info.tunable_op.max_tuning_duration_ms = params->tunable_op_max_tuning_duration_ms; info.enable_skip_layer_norm_strict_mode = params->enable_skip_layer_norm_strict_mode != 0; info.use_ep_level_unified_stream = params->use_ep_level_unified_stream != 0; + info.use_tf32 = params->use_tf32 != 0; return std::make_shared(info); } @@ -258,6 +259,7 @@ struct CUDA_Provider : Provider { cuda_options.enable_skip_layer_norm_strict_mode = internal_options.enable_skip_layer_norm_strict_mode; cuda_options.prefer_nhwc = internal_options.prefer_nhwc; cuda_options.use_ep_level_unified_stream = internal_options.use_ep_level_unified_stream; + cuda_options.use_tf32 = internal_options.use_tf32; } ProviderOptions GetProviderOptions(const void* provider_options) override { diff --git a/onnxruntime/core/providers/cuda/cuda_stream_handle.cc b/onnxruntime/core/providers/cuda/cuda_stream_handle.cc index 0a256394b7d99..3c0bf183362dd 100644 --- a/onnxruntime/core/providers/cuda/cuda_stream_handle.cc +++ b/onnxruntime/core/providers/cuda/cuda_stream_handle.cc @@ -212,6 +212,9 @@ void* CudaStream::GetResource(int version, int id) const { case CudaResource::prefer_nhwc_t: return reinterpret_cast(ep_info_.prefer_nhwc); break; + case CudaResource::use_tf32_t: + return reinterpret_cast(ep_info_.use_tf32); + break; default: break; } diff --git a/onnxruntime/core/providers/cuda/cuda_stream_handle.h 
b/onnxruntime/core/providers/cuda/cuda_stream_handle.h index b02c167e9e9ec..15e7a0553c84e 100644 --- a/onnxruntime/core/providers/cuda/cuda_stream_handle.h +++ b/onnxruntime/core/providers/cuda/cuda_stream_handle.h @@ -11,6 +11,7 @@ namespace onnxruntime { struct CudaStream; +void WaitCudaNotificationOnDevice(Stream& stream, synchronize::Notification& notification); struct DeferredCpuAllocator : public OrtAllocator { DeferredCpuAllocator(CudaStream&); @@ -47,6 +48,8 @@ struct CudaStream : Stream { onnxruntime::IAllocator* GetCpuAllocator() const { return cpu_allocator_.get(); } + WaitNotificationFn GetWaitNotificationFn() const override { return WaitCudaNotificationOnDevice; } + private: std::vector deferred_cpu_buffers_; AllocatorPtr cpu_allocator_; @@ -64,5 +67,4 @@ void RegisterCudaStreamHandles(IStreamCommandHandleRegistry& stream_handle_regis cudnnHandle_t external_cudnn_handle, cublasHandle_t external_cublass_handle, const CUDAExecutionProviderInfo& ep_info); -void WaitCudaNotificationOnDevice(Stream& stream, synchronize::Notification& notification); } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/cudnn_common.cc b/onnxruntime/core/providers/cuda/cudnn_common.cc index c850f7b583bfc..39b73163794f0 100644 --- a/onnxruntime/core/providers/cuda/cudnn_common.cc +++ b/onnxruntime/core/providers/cuda/cudnn_common.cc @@ -160,7 +160,6 @@ cudnnDataType_t CudnnTensor::GetDataType() { template <> cudnnDataType_t CudnnTensor::GetDataType() { ORT_THROW("cuDNN doesn't support BFloat16."); - return CUDNN_DATA_FLOAT; } template <> diff --git a/onnxruntime/core/providers/cuda/cudnn_common.h b/onnxruntime/core/providers/cuda/cudnn_common.h index fdd14dedad47e..2cbeb13696270 100644 --- a/onnxruntime/core/providers/cuda/cudnn_common.h +++ b/onnxruntime/core/providers/cuda/cudnn_common.h @@ -24,12 +24,12 @@ class CudnnTensor final { operator cudnnTensorDescriptor_t() const { return tensor_; } + Status CreateTensorIfNeeded(); + template static 
cudnnDataType_t GetDataType(); private: - Status CreateTensorIfNeeded(); - cudnnTensorDescriptor_t tensor_; }; diff --git a/onnxruntime/core/providers/cuda/math/einsum_utils/einsum_auxiliary_ops.cc b/onnxruntime/core/providers/cuda/math/einsum_utils/einsum_auxiliary_ops.cc index 3e50116eafd17..ee0334e552022 100644 --- a/onnxruntime/core/providers/cuda/math/einsum_utils/einsum_auxiliary_ops.cc +++ b/onnxruntime/core/providers/cuda/math/einsum_utils/einsum_auxiliary_ops.cc @@ -51,25 +51,27 @@ Status MatMul(const T* input_1_data, const T* input_2_data, T* output_data, CudaT one = cuda::ToCudaType::FromFloat(1.0f); CudaT zero = cuda::ToCudaType::FromFloat(0.0f); - CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper(static_cast(einsum_cuda_assets)->cublas_handle_, - CUBLAS_OP_N, - CUBLAS_OP_N, - static_cast(N), - static_cast(M), - static_cast(K), - &one, - reinterpret_cast(input_2_data), - static_cast(N), - static_cast(right_stride), - reinterpret_cast(input_1_data), - static_cast(K), - static_cast(left_stride), - &zero, - reinterpret_cast(output_data), - static_cast(N), - static_cast(output_stride), - static_cast(num_batches), - static_cast(einsum_cuda_assets)->cuda_ep_->GetDeviceProp())); + CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( + static_cast(einsum_cuda_assets)->cublas_handle_, + CUBLAS_OP_N, + CUBLAS_OP_N, + static_cast(N), + static_cast(M), + static_cast(K), + &one, + reinterpret_cast(input_2_data), + static_cast(N), + static_cast(right_stride), + reinterpret_cast(input_1_data), + static_cast(K), + static_cast(left_stride), + &zero, + reinterpret_cast(output_data), + static_cast(N), + static_cast(output_stride), + static_cast(num_batches), + static_cast(einsum_cuda_assets)->cuda_ep_->GetDeviceProp(), + static_cast(einsum_cuda_assets)->cuda_ep_->UseTF32())); return Status::OK(); } diff --git a/onnxruntime/core/providers/cuda/math/gemm.cc b/onnxruntime/core/providers/cuda/math/gemm.cc index 8fe23c9a036cc..4e61e0c8c69c6 100644 --- 
a/onnxruntime/core/providers/cuda/math/gemm.cc +++ b/onnxruntime/core/providers/cuda/math/gemm.cc @@ -118,7 +118,7 @@ Status Gemm::ComputeDefault(OpKernelContext* ctx, int M, int N, int K) const b_data, N, GetConstOnes(M, Stream(ctx)), 1, /*beta*/ &zero, - out_data, N, device_prop)); + out_data, N, device_prop, UseTF32())); } else if (b_shape.NumDimensions() == 2 && b_shape[1] == 1) { // B is (M, 1), broadcast using Y(N,M) = 1 * ones(N,1) x B(1,M) + 0 * Y CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( @@ -130,7 +130,7 @@ Status Gemm::ComputeDefault(OpKernelContext* ctx, int M, int N, int K) const GetConstOnes(N, Stream(ctx)), N, b_data, 1, /*beta*/ &zero, - out_data, N, device_prop)); + out_data, N, device_prop, UseTF32())); } else { // B is (M, N), no broadcast needed. CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(out_data, b_data, static_cast(M) * N * sizeof(T), cudaMemcpyDeviceToDevice, Stream(ctx))); @@ -153,7 +153,7 @@ Status Gemm::ComputeDefault(OpKernelContext* ctx, int M, int N, int K) const // ideally we need to set the output buffer contents to 0 if bias is missing, // but passing 0 for beta is cheaper and it will ignore any junk in the output buffer B != nullptr ? 
&beta : &zero, - out_data, N, device_prop)); + out_data, N, device_prop, UseTF32())); return Status::OK(); } diff --git a/onnxruntime/core/providers/cuda/math/matmul.cc b/onnxruntime/core/providers/cuda/math/matmul.cc index e4c37c52a1780..6e126fbeadce8 100644 --- a/onnxruntime/core/providers/cuda/math/matmul.cc +++ b/onnxruntime/core/providers/cuda/math/matmul.cc @@ -173,7 +173,8 @@ Status FuncMatMul( &cuda_zero, reinterpret_cast(Y->MutableData()), ldc, - device_prop)); + device_prop, + cuda_kernel->UseTF32())); return Status::OK(); } else if (CanUseStridedBatchedGemm(A->Shape(), B->Shape(), trans_A, trans_B, trans_batch_B, trans_batch_B, stride_A, stride_B, stride_C, batch_count)) { @@ -195,7 +196,8 @@ Status FuncMatMul( ldc, stride_C, static_cast(batch_count), - device_prop)); + device_prop, + cuda_kernel->UseTF32())); return Status::OK(); } @@ -213,12 +215,12 @@ Status FuncMatMul( ORT_RETURN_IF_ERROR(Y_arrays.CopyToGpu(ctx->GetComputeStream())); // TF32 provides a huge performance gain for training and inference while preserving FP32 levels of accuracy. - // It requires Ampere or newer GPU, and pointers of matrics shall be aligned (ideal alignment is 16-byte). + // It requires Ampere or newer GPU, and pointers of matrices shall be aligned (ideal alignment is 16-byte). // Assume that start memory of input/output tensor is aligned, we only check offsets of sub-matrix per batch here. - cublasMath_t mode = (std::is_same::value && device_prop.major >= 8 && helper.IsBatchedGemmAligned()) - ? 
CUBLAS_TF32_TENSOR_OP_MATH - : CUBLAS_DEFAULT_MATH; - CublasMathModeSetter math_mode_setter(device_prop, cuda_kernel->GetCublasHandle(ctx), mode); + bool use_tf32 = std::is_same::value && + cuda_kernel->UseTF32() && + device_prop.major >= 8 && + helper.IsBatchedGemmAligned(); // note that onnxruntime OrtValue is row major, while cublas is column major, // so swap left/right operands @@ -238,7 +240,8 @@ Status FuncMatMul( Y_arrays.GpuPtr(), ldc, static_cast(helper.OutputOffsets().size()), - device_prop)); + device_prop, + use_tf32)); return Status::OK(); } @@ -321,7 +324,8 @@ Status MatMul::ComputeDefault(OpKernelContext* ctx, MatMulComputeHelper& help &zero, reinterpret_cast(Y->MutableData()), ldc, - device_prop)); + device_prop, + UseTF32())); return Status::OK(); } else if (CanUseStridedBatchedGemm(left_X->Shape(), right_X->Shape(), transa, transb, trans_batch_a_, trans_batch_b_, stride_A, stride_B, stride_C, batch_count)) { @@ -343,7 +347,8 @@ Status MatMul::ComputeDefault(OpKernelContext* ctx, MatMulComputeHelper& help ldc, stride_C, static_cast(batch_count), - device_prop)); + device_prop, + UseTF32())); return Status::OK(); } @@ -361,12 +366,12 @@ Status MatMul::ComputeDefault(OpKernelContext* ctx, MatMulComputeHelper& help ORT_RETURN_IF_ERROR(output_arrays.CopyToGpu(ctx->GetComputeStream())); // TF32 provides a huge performance gain for training and inference while preserving FP32 levels of accuracy. - // It requires Ampere or newer GPU, and pointers of matrics shall be aligned (ideal alignment is 16-byte). + // It requires Ampere or newer GPU, and pointers of matrices shall be aligned (ideal alignment is 16-byte). // Assume that start memory of input/output tensor is aligned, we only check offsets of sub-matrix per batch here. - cublasMath_t mode = (std::is_same::value && device_prop.major >= 8 && helper.IsBatchedGemmAligned()) - ? 
CUBLAS_TF32_TENSOR_OP_MATH - : CUBLAS_DEFAULT_MATH; - CublasMathModeSetter math_mode_setter(device_prop, GetCublasHandle(ctx), mode); + bool use_tf32 = std::is_same::value && + this->UseTF32() && + device_prop.major >= 8 && + helper.IsBatchedGemmAligned(); // note that onnxruntime OrtValue is row major, while cublas is column major, // so swap left/right operands @@ -386,7 +391,8 @@ Status MatMul::ComputeDefault(OpKernelContext* ctx, MatMulComputeHelper& help output_arrays.GpuPtr(), ldc, static_cast(helper.OutputOffsets().size()), - device_prop)); + device_prop, + use_tf32)); return Status::OK(); } diff --git a/onnxruntime/core/providers/cuda/math/topk.cc b/onnxruntime/core/providers/cuda/math/topk.cc index d516537e25949..cf26e0acfa557 100644 --- a/onnxruntime/core/providers/cuda/math/topk.cc +++ b/onnxruntime/core/providers/cuda/math/topk.cc @@ -56,7 +56,7 @@ TopK::TopK(const OpKernelInfo& info) : CudaKernel(info) { info.GetAttrOrDefault("largest", &largest_, 1); info.GetAttrOrDefault("sorted", &sorted_, 1); if (!inputk) { - info.GetAttrOrDefault("k", &K_, 0); + info.GetAttrOrDefault("k", &attr_k_, 0); } } @@ -67,7 +67,7 @@ TopK::TopK(const OpKernelInfo& info) : CudaKernel(info) { static_cast(tensor_I->MutableDataRaw()), \ elem_nums_cuda, \ elem_nums.size(), \ - axis, K_, largest_, sorted_, N, dimension) + axis, k_value, largest_, sorted_, N, dimension) template Status TopK::ComputeInternal(OpKernelContext* ctx) const { @@ -77,19 +77,29 @@ Status TopK::ComputeInternal(OpKernelContext* ctx) const { int32_t axis = static_cast(axis_ < 0 ? 
rank + axis_ : axis_); ORT_ENFORCE(axis > -1 && axis < rank); + int64_t k_value = 0; if (inputk) { auto tensor_K = ctx->Input(1); ORT_ENFORCE(nullptr != tensor_K); - K_ = *tensor_K->Data(); - ORT_ENFORCE(K_ >= 0 && K_ <= tensor_X->Shape().GetDims()[axis]); + k_value = *tensor_K->Data(); + } else { // from attribute + k_value = attr_k_; } - auto output_shape = tensor_X->Shape(); - output_shape[axis] = K_; + // Now that we know the value of 'K' and the input shape, + // make a final validation before going to the implementation + const auto& input_shape = tensor_X->Shape(); + if ((k_value < 0) || (k_value > input_shape.GetDims()[axis])) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Value of K outside range. K value: ", k_value, + ". Input shape: ", input_shape, " . Axis: ", axis); + } + + auto output_shape = input_shape; + output_shape[axis] = k_value; auto tensor_V = ctx->Output(0, output_shape); auto tensor_I = ctx->Output(1, output_shape); - if (0 == K_) { + if (output_shape.Size() == 0) { // Bail out early if the output is going to be empty return Status::OK(); } diff --git a/onnxruntime/core/providers/cuda/math/topk.h b/onnxruntime/core/providers/cuda/math/topk.h index 9dec13ad2a930..5731df3130c5a 100644 --- a/onnxruntime/core/providers/cuda/math/topk.h +++ b/onnxruntime/core/providers/cuda/math/topk.h @@ -17,7 +17,7 @@ class TopK final : public CudaKernel { int64_t axis_; int64_t largest_; int64_t sorted_; - mutable int64_t K_; + int64_t attr_k_; }; } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc index 655877f425054..24593b255371c 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc @@ -71,6 +71,88 @@ Status UnaryElementwise::Prepare(OpKernelContext* context, UnaryElementwisePrepa return Status::OK(); \ } 
+ONNX_OPERATOR_VERSIONED_KERNEL_EX( + IsInf, + kOnnxDomain, + 10, + 19, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", BuildKernelDefConstraints()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + IsInf); + +ONNX_OPERATOR_KERNEL_EX( + IsInf, + kOnnxDomain, + 20, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", BuildKernelDefConstraints()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + IsInf); + +IsInf::IsInf(const OpKernelInfo& info) : UnaryElementwise(info) { + detect_positive_ = static_cast(info.GetAttrOrDefault("detect_positive", 1)); + detect_negative_ = static_cast(info.GetAttrOrDefault("detect_negative", 1)); + opset_ = info.node().SinceVersion(); +} + +Status IsInf::ComputeInternal(OpKernelContext* context) const { + UnaryElementwisePreparation p; + ORT_RETURN_IF_ERROR(UnaryElementwise::Prepare(context, &p)); + + Explicit_Impl_IsInf(Stream(context), opset_, detect_positive_, detect_negative_, + p.input_tensor->GetElementType(), p.input_tensor->DataRaw(), + p.output_tensor->MutableData(), + p.input_tensor->Shape().Size()); + return Status::OK(); +} + +// IsNan +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + IsNaN, + kOnnxDomain, + 9, + 12, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", BuildKernelDefConstraints()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + IsNaN); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + IsNaN, + kOnnxDomain, + 13, + 19, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", BuildKernelDefConstraints()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + IsNaN); + +ONNX_OPERATOR_KERNEL_EX( + IsNaN, + kOnnxDomain, + 20, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", BuildKernelDefConstraints()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + IsNaN); + +Status IsNaN::ComputeInternal(OpKernelContext* context) const { + 
UnaryElementwisePreparation p; + ORT_RETURN_IF_ERROR(UnaryElementwise::Prepare(context, &p)); + + Explicit_Impl_IsNan(Stream(context), p.input_tensor->GetElementType(), p.input_tensor->DataRaw(), + p.output_tensor->MutableData(), + p.input_tensor->Shape().Size()); + + return Status::OK(); +} + #define UNARY_OP_VERSIONED_TYPED(name, startver, endver, T) \ UNARY_ELEMENTWISE_REGISTER_VERSIONED_KERNEL(name, startver, endver, T) @@ -160,7 +242,7 @@ UNARY_OP_CSILHFD(Neg, 13) UNARY_OP_HFD(Floor, 13) UNARY_OP_HFD(Ceil, 13) UNARY_OP_HFD(Reciprocal, 13) -UNARY_OP_HFD(Sqrt, 13) +UNARY_OP_HFDX(Sqrt, 13) UNARY_OP_HFD(Log, 13) UNARY_OP_HFD(Exp, 13) UNARY_OP_HFD(Erf, 13) diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h index 775b78c43a736..95d68b5e1d534 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h @@ -2,6 +2,7 @@ // Licensed under the MIT License. 
#pragma once + #include "core/providers/cuda/cuda_kernel.h" namespace onnxruntime { @@ -119,5 +120,22 @@ class Sign final : public UnaryElementwise { Status ComputeInternal(OpKernelContext* context) const override; }; +class IsInf final : public UnaryElementwise { + public: + explicit IsInf(const OpKernelInfo& info); + Status ComputeInternal(OpKernelContext* context) const override; + + private: + bool detect_positive_{true}; + bool detect_negative_{true}; + int opset_; +}; + +class IsNaN : public UnaryElementwise { + public: + explicit IsNaN(const OpKernelInfo& info) : UnaryElementwise(info) {} + Status ComputeInternal(OpKernelContext* context) const override; +}; + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu index 5c3db4a499972..2cdfcda5be26a 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu @@ -11,6 +11,7 @@ #endif namespace onnxruntime { + namespace cuda { #define OP(name, expr) \ @@ -83,7 +84,7 @@ SPECIALIZED_UNARY_ELEMENTWISE_IMPL_CSILHFD(Neg) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Floor) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Ceil) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Reciprocal) -SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Sqrt) +SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(Sqrt) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(Log) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(Exp) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Erf) @@ -126,9 +127,10 @@ struct OP_Cast { UnaryElementWiseImpl(stream, input_data, output_data, OP_Cast(), count); \ } -#define IMPL_CAST_IMPL_THROW(InT, OutT) \ - void Explicit_Impl_Cast(cudaStream_t stream, const InT* input_data, OutT* output_data, size_t count) { \ - ORT_THROW("Cast from " #InT " to " #OutT " must define saturate."); \ +#define IMPL_CAST_IMPL_THROW(InT, OutT) \ + void 
Explicit_Impl_Cast(cudaStream_t /*stream*/, const InT* /*input_data*/, OutT* /*output_data*/, \ + size_t /*count*/) { \ + ORT_THROW("Cast from " #InT " to " #OutT " must define saturate."); \ } #if !defined(DISABLE_FLOAT8_TYPES) @@ -284,5 +286,62 @@ EXPLICIT_IMPL_CASTSAT(__nv_bfloat16, Float8E5M2) #endif +namespace isinf_details { +template +struct IsInf_DispFunc { + void operator()(cudaStream_t stream, const void* input_raw, bool* output_data, + bool detect_positive, bool detect_negative, size_t count) const { + using CudaType = typename ToCudaType::MappedType; + const auto* input_data = reinterpret_cast(input_raw); + if (detect_positive && detect_negative) { + UnaryElementWiseImpl(stream, input_data, output_data, _IsInf{}, count); + } else if (detect_positive) { + UnaryElementWiseImpl(stream, input_data, output_data, _IsInf{}, count); + } else if (detect_negative) { + UnaryElementWiseImpl(stream, input_data, output_data, _IsInf{}, count); + } else { + UnaryElementWiseImpl(stream, input_data, output_data, _IsInf{}, count); + } + } +}; + +} // namespace isinf_details + +void Explicit_Impl_IsInf(cudaStream_t stream, int op_set, + bool detect_positive, bool detect_negative, + int32_t input_data_type, + const void* input_raw, bool* output_data, + size_t count) { + if (op_set < 20) { + utils::MLTypeCallDispatcher dispatcher{input_data_type}; + dispatcher.Invoke(stream, input_raw, output_data, + detect_positive, detect_negative, count); + } else { + utils::MLTypeCallDispatcher dispatcher{input_data_type}; + dispatcher.Invoke(stream, input_raw, output_data, + detect_positive, detect_negative, count); + } +} + +// IsNan + +namespace isnan_details { +template +struct IsNan_Disp { + void operator()(cudaStream_t stream, const void* input_raw, bool* output_data, size_t count) const { + using CudaType = typename ToCudaType::MappedType; + const auto* input_data = reinterpret_cast(input_raw); + UnaryElementWiseImpl(stream, input_data, output_data, _IsNan{}, count); + } +}; +} // 
namespace isnan_details + +void Explicit_Impl_IsNan(cudaStream_t stream, int32_t input_data_type, + const void* input_raw, bool* output_data, size_t count) { + // KernelDef constraints would ensure only subset of datatypes is used. + utils::MLTypeCallDispatcher dispatcher{input_data_type}; + dispatcher.Invoke(stream, input_raw, output_data, count); +} + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h index 608a81a24cf4f..2588f56e32c12 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h @@ -137,5 +137,34 @@ void Impl_CastSat( #endif +// IsInf + +#if !defined(DISABLE_FLOAT8_TYPES) +#define ISINF_OPSET20_ALL_FLOATS float, double, MLFloat16, BFloat16, Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, \ + Float8E5M2FNUZ +#else +#define ISINF_OPSET20_ALL_FLOATS float, double, MLFloat16, BFloat16 +#endif + +void Explicit_Impl_IsInf(cudaStream_t stream, int op_set, + bool detect_positive, bool detect_negative, + int32_t input_data_type, + const void* input_raw, bool* output_data, + size_t count); + +// IsNan +#define ISNAN_OPSET9_FLOATS float, double, MLFloat16 +#define ISNAN_OPSET13_FLOATS float, double, MLFloat16, BFloat16 +#if !defined(DISABLE_FLOAT8_TYPES) +#define ISNAN_OPSET20_FLOATS float, double, MLFloat16, BFloat16, Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, \ + Float8E5M2FNUZ +#else +#define ISNAN_OPSET20_FLOATS ISNAN_OPSET13_FLOATS +#endif + +void Explicit_Impl_IsNan(cudaStream_t stream, int32_t input_data_type, + const void* input_raw, bool* output_data, size_t count); + } // namespace cuda + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/nn/batch_norm.cc b/onnxruntime/core/providers/cuda/nn/batch_norm.cc index c468971e1e426..02da1a2c99dfd 100644 --- a/onnxruntime/core/providers/cuda/nn/batch_norm.cc +++ 
b/onnxruntime/core/providers/cuda/nn/batch_norm.cc @@ -87,7 +87,7 @@ Status BatchNorm::ComputeInternal(OpKernelContext* p_op_kernel_context) CudnnTensor data_desc; vector new_dims; - BatchNormHelper::NormalizeDims(x_shape, new_dims); + BatchNormHelper::NormalizeDims(x_shape, new_dims, NHWC); ORT_RETURN_IF_ERROR(data_desc.Set(new_dims, CudnnTensor::GetDataType(), NHWC)); // For half data type, the alpha, beta, scale, B, mean, var need to be float type @@ -137,6 +137,12 @@ Status BatchNorm::ComputeInternal(OpKernelContext* p_op_kernel_context) auto saved_mean_data = reinterpret_cast(saved_mean->MutableData()); auto saved_inv_var_data = reinterpret_cast(saved_var->MutableData()); + auto stream = static_cast(p_op_kernel_context->GetComputeStream()->GetHandle()); + CUDA_RETURN_IF_ERROR( + cudaMemcpyAsync(running_mean_data, mean_data, mean->SizeInBytes(), cudaMemcpyDeviceToDevice, stream)); + CUDA_RETURN_IF_ERROR( + cudaMemcpyAsync(running_var_data, var_data, var->SizeInBytes(), cudaMemcpyDeviceToDevice, stream)); + CUDNN_RETURN_IF_ERROR(BatchNormalizationForwardTrainingHelper( GetCudnnHandle(p_op_kernel_context), cudnn_batch_norm_mode_, @@ -149,7 +155,7 @@ Status BatchNorm::ComputeInternal(OpKernelContext* p_op_kernel_context) bn_tensor_desc, scale_data, b_data, - momentum_, + 1.0 - momentum_, running_mean_data, running_var_data, epsilon_, @@ -186,6 +192,7 @@ SPECIALIZED_COMPUTE(MLFloat16, kOnnxDomain, false) #ifdef ENABLE_CUDA_NHWC_OPS SPECIALIZED_COMPUTE(float, kMSInternalNHWCDomain, true) +SPECIALIZED_COMPUTE(double, kMSInternalNHWCDomain, true) SPECIALIZED_COMPUTE(MLFloat16, kMSInternalNHWCDomain, true) #endif } // namespace cuda diff --git a/onnxruntime/core/providers/cuda/nn/conv.cc b/onnxruntime/core/providers/cuda/nn/conv.cc index 82f3503919237..e05786248cbcf 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.cc +++ b/onnxruntime/core/providers/cuda/nn/conv.cc @@ -97,11 +97,11 @@ Status SliceOutUnwantedOutputSection(cudaStream_t stream, template Status 
Conv::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool& is_packed, [[maybe_unused]] PrePackedWeights* prepacked_weights) { + bool& is_packed, PrePackedWeights* /*prepacked_weights*/) { is_packed = false; // only layout of weight input is adjusted via PrePack - if (NHWC && is_nhwc_domain_) { // InputTensors::IN_W - if (input_idx == 1) { + if constexpr (NHWC) { + if (is_nhwc_domain_ && input_idx == 1) { // InputTensors::IN_W // Transpose from {M, C/group, kH, kW} to {M, kH, kW, C/group} auto orig_shape = tensor.Shape(); @@ -123,6 +123,10 @@ Status Conv::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr CUDA_CALL_THROW(cudaStreamSynchronize(DefaultCudaStream())); is_packed = true; } + } else { + ORT_UNUSED_PARAMETER(tensor); + ORT_UNUSED_PARAMETER(input_idx); + ORT_UNUSED_PARAMETER(alloc); } return Status::OK(); @@ -149,8 +153,11 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) // Make sure input and weight are 4D for NHWC since we set 4D descriptor for NHWC. 
constexpr bool channels_last = NHWC; - if (channels_last && (x_shape.NumDimensions() != 4 || w_shape.NumDimensions() != 4)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Number of dimensions of X and W should be 4 for channels_last format (NHWC)"); + if constexpr (channels_last) { + if (x_shape.NumDimensions() != 4 || w_shape.NumDimensions() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Number of dimensions of X and W should be 4 for channels_last format (NHWC)"); + } } // set B @@ -326,7 +333,8 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) ORT_RETURN_IF_ERROR(s_.conv_desc.Set(kernel_shape.size(), pads, strides, dilations, gsl::narrow_cast(conv_attrs_.group), - CUDNN_CROSS_CORRELATION, CudnnTensor::GetDataType())); + CUDNN_CROSS_CORRELATION, CudnnTensor::GetDataType(), + UseTF32())); if (context->InputCount() >= 3) { const Tensor* B = context->Input(2); @@ -351,8 +359,13 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) if (!s_.cached_benchmark_results.contains(x_dims_cudnn)) { // set math type to tensor core before algorithm search - if constexpr (std::is_same::value) + if constexpr (std::is_same::value) { CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_TENSOR_OP_MATH)); + } else if constexpr (std::is_same::value) { + if (!UseTF32()) { + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_FMA_MATH)); + } + } cudnnConvolutionFwdAlgoPerf_t perf; int algo_count = 1; @@ -397,8 +410,11 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) default: perf.algo = kDefaultConvAlgo; CUDNN_RETURN_IF_ERROR(GetWorkspaceSize(GetCudnnHandle(context), s_, perf.algo, &perf.memory)); - if (std::is_same::value) { + + if constexpr (std::is_same::value) { perf.mathType = CUDNN_TENSOR_OP_MATH; + } else if (std::is_same::value && !UseTF32()) { + perf.mathType = CUDNN_FMA_MATH; } else { perf.mathType = CUDNN_DEFAULT_MATH; } @@ -480,7 +496,8 @@ 
Status CudnnConvolutionDescriptor::Set( const gsl::span& dilations, int groups, cudnnConvolutionMode_t mode, - cudnnDataType_t data_type) { + cudnnDataType_t data_type, + bool use_tf32) { if (!desc_) CUDNN_RETURN_IF_ERROR(cudnnCreateConvolutionDescriptor(&desc_)); @@ -513,6 +530,8 @@ Status CudnnConvolutionDescriptor::Set( CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(desc_, CUDNN_DEFAULT_MATH)); if (data_type == CUDNN_DATA_HALF) { CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(desc_, CUDNN_TENSOR_OP_MATH)); + } else if (data_type == CUDNN_DATA_FLOAT && !use_tf32) { + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(desc_, CUDNN_FMA_MATH)); } return Status::OK(); diff --git a/onnxruntime/core/providers/cuda/nn/conv.h b/onnxruntime/core/providers/cuda/nn/conv.h index bcaa4d855b81e..3aec654224e39 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.h +++ b/onnxruntime/core/providers/cuda/nn/conv.h @@ -29,7 +29,8 @@ class CudnnConvolutionDescriptor final { const gsl::span& dilations, int groups, cudnnConvolutionMode_t mode, - cudnnDataType_t data_type); + cudnnDataType_t data_type, + bool use_tf32); operator cudnnConvolutionDescriptor_t() const { return desc_; } @@ -194,7 +195,7 @@ class Conv : public CudaKernel { } Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool& is_packed, [[maybe_unused]] PrePackedWeights* prepacked_weights) override; + bool& is_packed, PrePackedWeights* prepacked_weights) override; Status ComputeInternal(OpKernelContext* context) const override; diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc index 55dceaa2698e8..939b9959af818 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc @@ -167,7 +167,8 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dy cudnnConvolutionMode_t mode = CUDNN_CROSS_CORRELATION; 
ORT_RETURN_IF_ERROR(s_.conv_desc.Set(p.kernel_shape.size(), p.pads, p.strides, p.dilations, gsl::narrow_cast(conv_transpose_attrs_.group), mode, - CudnnTensor::GetDataType())); + CudnnTensor::GetDataType(), + UseTF32())); if (has_bias) { const auto& b_shape = p.B->Shape(); @@ -187,8 +188,13 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dy GetScratchBuffer(AlgoSearchWorkspaceSize, context->GetComputeStream()); // set math type to tensor core before algorithm search - if constexpr (std::is_same::value) + if constexpr (std::is_same::value) { CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_TENSOR_OP_MATH)); + } else if constexpr (std::is_same::value) { + if (!UseTF32()) { + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_FMA_MATH)); + } + } cudnnConvolutionBwdDataAlgoPerf_t perf; int algo_count = 1; diff --git a/onnxruntime/core/providers/cuda/nn/layer_norm.h b/onnxruntime/core/providers/cuda/nn/layer_norm.h index ff231f4f1ad5c..c021d3ffe63a2 100644 --- a/onnxruntime/core/providers/cuda/nn/layer_norm.h +++ b/onnxruntime/core/providers/cuda/nn/layer_norm.h @@ -7,8 +7,6 @@ namespace onnxruntime { namespace cuda { -using namespace onnxruntime::cuda; - // NOTE: This was originally a contrib op with 3 type constraints. The ONNX spec merges 'T' and 'V'. // the kernel is templatized on all three for backwards compatibility, but in ONNX usage T == V. 
template diff --git a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu index 679b8b6b78886..b9e8b45307079 100644 --- a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu +++ b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu @@ -29,8 +29,6 @@ namespace onnxruntime { namespace cuda { -using namespace onnxruntime::cuda; - template __device__ void cuWelfordOnlineSum( const U curr, diff --git a/onnxruntime/core/providers/cuda/nn/lrn.cc b/onnxruntime/core/providers/cuda/nn/lrn.cc index 6fcdec74d84b5..788299b5eb8d6 100644 --- a/onnxruntime/core/providers/cuda/nn/lrn.cc +++ b/onnxruntime/core/providers/cuda/nn/lrn.cc @@ -6,37 +6,47 @@ namespace onnxruntime { namespace cuda { -#define REGISTER_KERNEL_VERSIONED_TYPED(START_VER, END_VER, T) \ +#define REGISTER_KERNEL_VERSIONED_TYPED(START_VER, END_VER, T, DOMAIN, LAYOUT) \ ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ LRN, \ - kOnnxDomain, \ + DOMAIN, \ START_VER, \ END_VER, \ T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - LRN); + LRN); -#define REGISTER_KERNEL_TYPED(VER, T) \ +#define REGISTER_KERNEL_TYPED(VER, T, DOMAIN, LAYOUT) \ ONNX_OPERATOR_TYPED_KERNEL_EX( \ LRN, \ - kOnnxDomain, \ + DOMAIN, \ VER, \ T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - LRN); + LRN); -REGISTER_KERNEL_VERSIONED_TYPED(1, 12, float) -REGISTER_KERNEL_VERSIONED_TYPED(1, 12, double) -REGISTER_KERNEL_VERSIONED_TYPED(1, 12, MLFloat16) +REGISTER_KERNEL_VERSIONED_TYPED(1, 12, float, kOnnxDomain, false) +REGISTER_KERNEL_VERSIONED_TYPED(1, 12, double, kOnnxDomain, false) +REGISTER_KERNEL_VERSIONED_TYPED(1, 12, MLFloat16, kOnnxDomain, false) -REGISTER_KERNEL_TYPED(13, float) -REGISTER_KERNEL_TYPED(13, double) -REGISTER_KERNEL_TYPED(13, MLFloat16) +REGISTER_KERNEL_TYPED(13, float, kOnnxDomain, false) +REGISTER_KERNEL_TYPED(13, double, kOnnxDomain, 
false) +REGISTER_KERNEL_TYPED(13, MLFloat16, kOnnxDomain, false) -template -LRN::LRN(const OpKernelInfo& info) : CudaKernel(info) { +#ifdef ENABLE_CUDA_NHWC_OPS +REGISTER_KERNEL_VERSIONED_TYPED(1, 12, float, kMSInternalNHWCDomain, true) +REGISTER_KERNEL_VERSIONED_TYPED(1, 12, double, kMSInternalNHWCDomain, true) +REGISTER_KERNEL_VERSIONED_TYPED(1, 12, MLFloat16, kMSInternalNHWCDomain, true) + +REGISTER_KERNEL_TYPED(13, float, kMSInternalNHWCDomain, true) +REGISTER_KERNEL_TYPED(13, double, kMSInternalNHWCDomain, true) +REGISTER_KERNEL_TYPED(13, MLFloat16, kMSInternalNHWCDomain, true) +#endif + +template +LRN::LRN(const OpKernelInfo& info) : CudaKernel(info) { int64_t size; ORT_ENFORCE(info.GetAttr("size", &size).IsOK()); ORT_ENFORCE(size > 0); @@ -58,8 +68,8 @@ LRN::LRN(const OpKernelInfo& info) : CudaKernel(info) { .IsOK()); } -template -Status LRN::ComputeInternal(OpKernelContext* context) const { +template +Status LRN::ComputeInternal(OpKernelContext* context) const { typedef typename ToCudaType::MappedType CudaT; const Tensor* X = context->Input(0); @@ -71,7 +81,7 @@ Status LRN::ComputeInternal(OpKernelContext* context) const { Tensor* Y = context->Output(0, X->Shape()); CudnnTensor x_tensor; - ORT_RETURN_IF_ERROR(x_tensor.Set(X->Shape().GetDims(), CudnnTensor::GetDataType())); + ORT_RETURN_IF_ERROR(x_tensor.Set(X->Shape().GetDims(), CudnnTensor::GetDataType(), Layout == NHWC)); const auto one = Consts::One; const auto zero = Consts::Zero; diff --git a/onnxruntime/core/providers/cuda/nn/lrn.h b/onnxruntime/core/providers/cuda/nn/lrn.h index 319e323c72a92..31b2819ccc52a 100644 --- a/onnxruntime/core/providers/cuda/nn/lrn.h +++ b/onnxruntime/core/providers/cuda/nn/lrn.h @@ -20,7 +20,7 @@ class CudnnLRNDescriptor final { cudnnLRNDescriptor_t desc_; }; -template +template class LRN : public CudaKernel { public: LRN(const OpKernelInfo& info); diff --git a/onnxruntime/core/providers/cuda/nvtx_profile.cc b/onnxruntime/core/providers/cuda/nvtx_profile.cc index 
6c7c594066b86..867e7c1f24584 100644 --- a/onnxruntime/core/providers/cuda/nvtx_profile.cc +++ b/onnxruntime/core/providers/cuda/nvtx_profile.cc @@ -4,13 +4,8 @@ #ifdef ENABLE_NVTX_PROFILE #include "nvtx_profile.h" #include "core/common/common.h" -#if defined(_WIN32) || defined(WIN32) || defined(__CYGWIN__) || defined(__MINGW32__) || defined(__BORLANDC__) #include #include -#else -#include -#include -#endif namespace onnxruntime { namespace profile { diff --git a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc index 99c1f48e21c74..6476364a211fd 100644 --- a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc +++ b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc @@ -9,40 +9,49 @@ namespace onnxruntime { namespace cuda { template -void CudnnRnnBase::SetWeightBias(const cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnn_desc, - const int pseudo_layer, - const cudnnTensorDescriptor_t x_desc, - const cudnnFilterDescriptor_t w_desc, - const cudnnFilterDescriptor_t filter_desc, - const void* reorganized_w_data, - const int lin_layer_id, - const T* pos, - int& offset, - bool is_matrix, - cudaStream_t cuda_stream) const { +Status CudnnRnnBase::SetWeightBias(const cudnnHandle_t handle, + const cudnnRNNDescriptor_t rnn_desc, + const int pseudo_layer, + size_t reorganized_w_data_size, + const void* reorganized_w_data, + const int lin_layer_id, + const T* pos, + int& offset, + bool is_matrix, + cudaStream_t cuda_stream) const { int numDims; - std::vector matDims(3); + std::array matDims; + std::array strideA; cudnnDataType_t dt; - cudnnTensorFormat_t tf; T* mem_offset; - if (is_matrix) { - cudnnGetRNNLinLayerMatrixParams(handle, rnn_desc, pseudo_layer, x_desc, w_desc, reorganized_w_data, lin_layer_id, filter_desc, (void**)&mem_offset); - } else { - cudnnGetRNNLinLayerBiasParams(handle, rnn_desc, pseudo_layer, x_desc, w_desc, reorganized_w_data, lin_layer_id, filter_desc, (void**)&mem_offset); - } + 
CudnnTensor tensor_desc_matrix, tensor_desc_bias; + ORT_RETURN_IF_ERROR(tensor_desc_bias.CreateTensorIfNeeded()); + ORT_RETURN_IF_ERROR(tensor_desc_matrix.CreateTensorIfNeeded()); - cudnnGetFilterNdDescriptor(filter_desc, 3, &dt, &tf, &numDims, matDims.data()); + T *mem_offset_matrix, *mem_offset_bias; + CUDNN_RETURN_IF_ERROR(cudnnGetRNNWeightParams( + handle, rnn_desc, pseudo_layer, reorganized_w_data_size, reorganized_w_data, + lin_layer_id, tensor_desc_matrix, (void**)&mem_offset_matrix, tensor_desc_bias, (void**)&mem_offset_bias)); + CUDNN_RETURN_IF_ERROR(cudnnGetTensorNdDescriptor( + is_matrix ? tensor_desc_matrix : tensor_desc_bias, 3, &dt, &numDims, matDims.data(), strideA.data())); + + mem_offset = is_matrix ? mem_offset_matrix : mem_offset_bias; int count = matDims[0] * matDims[1] * matDims[2]; + + if (strideA[0] != count) { + return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::INVALID_ARGUMENT, "Stride is not packed"); + } CUDA_CALL_THROW(cudaMemcpyAsync(mem_offset, pos + offset, count * sizeof(T), cudaMemcpyDeviceToDevice, cuda_stream)); + offset += count; + + return Status::OK(); } template Status CudnnRnnBase::SetCudnnRnnWeightBias(const cudnnHandle_t cudnn_handle, const cudnnRNNDescriptor_t rnn_desc, - const cudnnTensorDescriptor_t x_desc, - const cudnnFilterDescriptor_t w_desc, + size_t reorganized_w_data_size, void* reorganized_w_data, const T* W_data, const T* R_data, @@ -51,18 +60,22 @@ Status CudnnRnnBase::SetCudnnRnnWeightBias(const cudnnHandle_t cudnn_handle, int w_offset = 0; int r_offset = 0; int bias_offset = 0; - CudnnFilterDescriptor filter_desc; for (int layer = 0; layer < RNN_NUM_LAYERS * num_directions_; ++layer) { for (size_t idx = 0; idx < W_lin_layer_id_.size(); ++idx) { - SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, filter_desc, reorganized_w_data, W_lin_layer_id_[idx], W_data, w_offset, true, cuda_stream); + ORT_RETURN_IF_ERROR(SetWeightBias( + cudnn_handle, rnn_desc, layer, reorganized_w_data_size, reorganized_w_data, 
+ W_lin_layer_id_[idx], W_data, w_offset, true, cuda_stream)); if (B_data != nullptr) { - SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, filter_desc, reorganized_w_data, W_lin_layer_id_[idx], B_data, bias_offset, false, cuda_stream); + ORT_RETURN_IF_ERROR(SetWeightBias(cudnn_handle, rnn_desc, layer, reorganized_w_data_size, reorganized_w_data, + W_lin_layer_id_[idx], B_data, bias_offset, false, cuda_stream)); } } for (size_t idx = 0; idx < R_lin_layer_id_.size(); ++idx) { - SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, filter_desc, reorganized_w_data, R_lin_layer_id_[idx], R_data, r_offset, true, cuda_stream); + ORT_RETURN_IF_ERROR(SetWeightBias(cudnn_handle, rnn_desc, layer, reorganized_w_data_size, reorganized_w_data, + R_lin_layer_id_[idx], R_data, r_offset, true, cuda_stream)); if (B_data != nullptr) { - SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, filter_desc, reorganized_w_data, R_lin_layer_id_[idx], B_data, bias_offset, false, cuda_stream); + ORT_RETURN_IF_ERROR(SetWeightBias(cudnn_handle, rnn_desc, layer, reorganized_w_data_size, reorganized_w_data, + R_lin_layer_id_[idx], B_data, bias_offset, false, cuda_stream)); } } } @@ -72,6 +85,7 @@ Status CudnnRnnBase::SetCudnnRnnWeightBias(const cudnnHandle_t cudnn_handle, template Status CudnnRnnBase::ReorganizeWeights(const Tensor* W, const Tensor* R, const Tensor* B, + size_t& reorganized_w_data_size_in_bytes, IAllocatorUniquePtr& reorganized_w_data, CudnnFilterDescriptor& target_w_desc, CudnnRNN& rnn_desc, onnxruntime::Stream* ort_stream) const { @@ -91,19 +105,16 @@ Status CudnnRnnBase::ReorganizeWeights(const Tensor* W, const Tensor* R, cons TensorShapeVector dims_w({w_size, 1, 1}); ORT_RETURN_IF_ERROR(target_w_desc.Set(dims_w, CudnnTensor::GetDataType())); - TensorShapeVector fake_dims_x({1, input_size, 1}); - CudnnTensor fake_x_desc; - ORT_RETURN_IF_ERROR(fake_x_desc.Set(fake_dims_x, CudnnTensor::GetDataType())); - // Prepare the weight data - reorganized_w_data = 
GetScratchBuffer(w_size * sizeof(T), ort_stream); + reorganized_w_data_size_in_bytes = w_size * sizeof(T); + reorganized_w_data = GetScratchBuffer(reorganized_w_data_size_in_bytes, ort_stream); // In many cases, this allocation is bigger than needed, leaving part of - // the buffer unintialized. non-zero garbage data leads to wrong result + // the buffer uninitialized. non-zero garbage data leads to wrong result // in call to cudnnRNNForwardInference() // TODO! refine allocation size for each case. cudaStream_t cuda_stream = ort_stream ? static_cast(ort_stream->GetHandle()) : nullptr; - cudaMemsetAsync(reorganized_w_data.get(), 0, w_size * sizeof(T), cuda_stream); + CUDA_RETURN_IF_ERROR(cudaMemsetAsync(reorganized_w_data.get(), 0, reorganized_w_data_size_in_bytes, cuda_stream)); const T* W_data = W->Data(); const T* R_data = R->Data(); @@ -111,8 +122,9 @@ Status CudnnRnnBase::ReorganizeWeights(const Tensor* W, const Tensor* R, cons auto* ort_cuda_stream = dynamic_cast(ort_stream); cudnnHandle_t cudnn_handle = ort_cuda_stream ? 
ort_cuda_stream->cudnn_handle_ : DefaultCudnnHandle(); - ORT_RETURN_IF_ERROR(SetCudnnRnnWeightBias(cudnn_handle, rnn_desc, fake_x_desc, target_w_desc, - reorganized_w_data.get(), W_data, R_data, B_data, cuda_stream)); + ORT_RETURN_IF_ERROR(SetCudnnRnnWeightBias(cudnn_handle, rnn_desc, + reorganized_w_data_size_in_bytes, reorganized_w_data.get(), + W_data, R_data, B_data, cuda_stream)); return Status::OK(); } @@ -128,22 +140,31 @@ Status CudnnRnnBase::CacheCudnnRnnWeights(const OpKernelInfo& info) { bool get_R = info.TryGetConstantInput(RNN_Input_Index::R, &R); bool get_B = info.TryGetConstantInput(RNN_Input_Index::B, &B); + bool has_bias = B != nullptr; + if (get_W && get_R) { CudnnRNN tmp_rnn_desc; - ORT_RETURN_IF_ERROR(tmp_rnn_desc.Set(DefaultCudnnHandle(), + auto proj_size = hidden_size_; + ORT_RETURN_IF_ERROR(tmp_rnn_desc.Set(W->Shape()[2], // input_size hidden_size_, + proj_size, RNN_NUM_LAYERS, cudnn_dropout_desc_, cudnn_direction_mode_, rnn_mode_, - CudnnTensor::GetDataType(), - GetDeviceProp())); + has_bias, + CudnnTensor::GetDataType())); if (get_B) { - ORT_RETURN_IF_ERROR(ReorganizeWeights(W, R, B, w_data_cache_, w_desc_cache_, tmp_rnn_desc, nullptr)); + ORT_RETURN_IF_ERROR(ReorganizeWeights(W, R, B, + w_data_cache_size_in_bytes_, w_data_cache_, w_desc_cache_, + tmp_rnn_desc, nullptr)); } else { - ORT_RETURN_IF_ERROR(ReorganizeWeights(W, R, nullptr, w_data_cache_, w_desc_cache_, tmp_rnn_desc, nullptr)); + ORT_RETURN_IF_ERROR(ReorganizeWeights(W, R, nullptr, + w_data_cache_size_in_bytes_, w_data_cache_, w_desc_cache_, + tmp_rnn_desc, nullptr)); } cudaStreamSynchronize(nullptr); + weight_cached_ = true; } @@ -158,17 +179,72 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { ORT_ENFORCE(nullptr != X); // optional inputs - const Tensor* sequence_lens = ctx->Input(RNN_Input_Index::sequence_lens); // [batch_size] - const Tensor* initial_h = ctx->Input(RNN_Input_Index::initial_h); // initial hidden. 
[num_directions_, batch_size, hidden_size_] + // [batch_size] + const Tensor* sequence_lens = ctx->Input(RNN_Input_Index::sequence_lens); + // initial hidden. [num_directions_, batch_size, hidden_size_] + const Tensor* initial_h = ctx->Input(RNN_Input_Index::initial_h); const Tensor* initial_c(nullptr); if (rnn_mode_ == CUDNN_LSTM) { - initial_c = ctx->Input(RNN_Input_Index::initial_c); // initial cell. [num_directions_, batch_size, hidden_size_] + // initial cell. [num_directions_, batch_size, hidden_size_] + initial_c = ctx->Input(RNN_Input_Index::initial_c); } + size_t proj_size = hidden_size_; int64_t seq_length = X->Shape()[0]; int64_t batch_size = X->Shape()[1]; int64_t input_size = X->Shape()[2]; + // we thread a single input as sequence_lens of length 1, require to expand to [batch_size]? + std::vector sequence_lengths_temp; + if (!sequence_lens) { + sequence_lengths_temp.resize(batch_size, gsl::narrow_cast(seq_length)); + } + + const int32_t* sequence_lens_data = (sequence_lens == nullptr) + ? sequence_lengths_temp.data() + : sequence_lens->Data(); + + // cuDNN doesn't support 0 sequence inside the batch, find the 0 sequence and set it to 1 + // there's a ZeroMask kernel to reset the result to 0 for the 0 sequence + int64_t zero_seq_count = 0; + std::vector zero_seq_index_cache(batch_size, 0); + + CudaAsyncBuffer sequence_lens_buffer(this, batch_size); + int32_t* seq_len_array = sequence_lens_buffer.CpuPtr(); + + // 0-len sequences are not supported by cuDNN. 
+ // Replace them by sequences of len 1 and mask them out with SetZeroSequences + for (int i = 0; i < batch_size; ++i) { + if (0 == sequence_lens_data[i]) { + seq_len_array[i] = 1; + zero_seq_index_cache[zero_seq_count] = i; + ++zero_seq_count; + } else { + seq_len_array[i] = sequence_lens_data[i]; + } + } + + // Calculate the zero position cache for reverse direction if it's bidirectional + // The cache is for Y_h or Y_c, and the 1st sequence for Y, no need to do it for other sequence in Y since + // we hacked the 0 sequence to 1 + if (zero_seq_count && num_directions_ > 1) { + zero_seq_index_cache.resize(zero_seq_count * num_directions_); + for (int64_t i = 0; i < zero_seq_count; ++i) { + zero_seq_index_cache[static_cast(zero_seq_count) + i] = + static_cast(batch_size + zero_seq_index_cache[i]); + } + zero_seq_count *= num_directions_; + } + + // Prior to cuDNN 8.9.1 the sequence lens buffer must be passed to cudnnRNNForward and thus is must + // be copied to the GPU always. + ORT_RETURN_IF_ERROR(sequence_lens_buffer.CopyToGpu(ctx->GetComputeStream())); + // Starting with cuDNN 8.9.1 the sequence lens buffer is ignored by cudnnRNNForward and thus it must + // be copied to the GPU only for the ReverseBySequence kernels. 
+ // if (reverse_) { + // ORT_RETURN_IF_ERROR(sequence_lens_buffer.CopyToGpu(ctx->GetComputeStream())); + // } + // optional outputs TensorShapeVector dims_Y({seq_length, num_directions_, batch_size, hidden_size_}); TensorShapeVector dims_hxy({RNN_NUM_LAYERS * num_directions_, batch_size, hidden_size_}); @@ -177,25 +253,6 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { Tensor* Y_h = ctx->Output(Output_Index::Y_h, dims_hxy); Tensor* Y_c = ctx->Output(Output_Index::Y_c, dims_yc); - std::vector dims_x({batch_size, input_size, 1}); - std::vector dims_y({batch_size, hidden_size_ * num_directions_, 1}); - - CudnnTensor x_desc_temp; - ORT_RETURN_IF_ERROR(x_desc_temp.Set(dims_x, CudnnTensor::GetDataType())); - CudnnTensor y_desc_temp; - ORT_RETURN_IF_ERROR(y_desc_temp.Set(dims_y, CudnnTensor::GetDataType())); - std::vector x_desc(seq_length, x_desc_temp); - std::vector y_desc(seq_length, y_desc_temp); - - CudnnTensor hx_desc; - CudnnTensor cx_desc; - CudnnTensor y_h_desc; - CudnnTensor y_c_desc; - ORT_RETURN_IF_ERROR(hx_desc.Set(dims_hxy, CudnnTensor::GetDataType())); - ORT_RETURN_IF_ERROR(cx_desc.Set(dims_hxy, CudnnTensor::GetDataType())); - ORT_RETURN_IF_ERROR(y_h_desc.Set(dims_hxy, CudnnTensor::GetDataType())); - ORT_RETURN_IF_ERROR(y_c_desc.Set(dims_hxy, CudnnTensor::GetDataType())); - IAllocatorUniquePtr x_reversed_data; const T* x_data = X->Data(); if (reverse_) { @@ -203,6 +260,7 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { x_reversed_data = GetScratchBuffer(seq_length * batch_size * input_size, ctx->GetComputeStream()); ReverseBySequence(Stream(ctx), gsl::narrow_cast(seq_length), + sequence_lens_buffer.GpuPtr(), gsl::narrow_cast(batch_size), gsl::narrow_cast(input_size), reinterpret_cast(x_data), @@ -226,115 +284,81 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { y_data = y_alloc_data.get(); } - const int32_t* sequence_lens_data = (sequence_lens == nullptr) ? 
nullptr : sequence_lens->Data(); + const Tensor* B = ctx->Input(RNN_Input_Index::B); + bool has_bias = B != nullptr; CudnnRNN rnn_desc; - ORT_RETURN_IF_ERROR(rnn_desc.Set(GetCudnnHandle(ctx), + ORT_RETURN_IF_ERROR(rnn_desc.Set(input_size, hidden_size_, + proj_size, RNN_NUM_LAYERS, cudnn_dropout_desc_, cudnn_direction_mode_, rnn_mode_, - CudnnTensor::GetDataType(), - GetDeviceProp())); + has_bias, + CudnnTensor::GetDataType())); // Prepare the weight data + size_t w_data_size_in_bytes = 0; IAllocatorUniquePtr w_data; CudnnFilterDescriptor w_desc; if (!weight_cached_) { const Tensor& W = *ctx->Input(RNN_Input_Index::W); const Tensor& R = *ctx->Input(RNN_Input_Index::R); - const Tensor* B = ctx->Input(RNN_Input_Index::B); - ORT_RETURN_IF_ERROR(ReorganizeWeights(&W, &R, B, w_data, w_desc, rnn_desc, ctx->GetComputeStream())); + ORT_RETURN_IF_ERROR(ReorganizeWeights(&W, &R, B, w_data_size_in_bytes, w_data, w_desc, + rnn_desc, ctx->GetComputeStream())); } - // CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED works with CUDNN_RNN_PADDED_IO_ENABLED, so that it will auto fill 0 for the shorter sequences - CUDNN_RETURN_IF_ERROR(cudnnSetRNNPaddingMode(rnn_desc, CUDNN_RNN_PADDED_IO_ENABLED)); + CudnnDataTensor x_desc1; + ORT_RETURN_IF_ERROR(x_desc1.Set(CudnnTensor::GetDataType(), seq_length, batch_size, + input_size, seq_len_array)); + CudnnDataTensor y_desc1; + ORT_RETURN_IF_ERROR(y_desc1.Set(CudnnTensor::GetDataType(), seq_length, batch_size, + ((rnn_mode_ == CUDNN_LSTM) ? 
proj_size : hidden_size_) * num_directions_, + seq_len_array)); - size_t workspace_bytes; - CUDNN_RETURN_IF_ERROR(cudnnGetRNNWorkspaceSize(GetCudnnHandle(ctx), rnn_desc, gsl::narrow_cast(seq_length), x_desc.data(), &workspace_bytes)); - auto workspace_cuda = GetScratchBuffer(workspace_bytes, ctx->GetComputeStream()); - int64_t zero_seq_count = 0; - std::vector zero_seq_index_cache(batch_size, 0); - int64_t zero_seq_index_cache_size = 0; - - if (CUDNN_RNN_RELU == rnn_mode_ || CUDNN_RNN_TANH == rnn_mode_ || nullptr == sequence_lens_data) { - CUDNN_RETURN_IF_ERROR(cudnnRNNForwardInference(GetCudnnHandle(ctx), - rnn_desc, - gsl::narrow_cast(seq_length), - x_desc.data(), - x_data_input, - hx_desc, - hx_data, - cx_desc, - cx_data, - weight_cached_ ? w_desc_cache_ : w_desc, - weight_cached_ ? w_data_cache_.get() : w_data.get(), - y_desc.data(), - y_data, - y_h_desc, - y_h_data, - y_c_desc, - y_c_data, - workspace_cuda.get(), - workspace_bytes)); - } else { - // cudnn doesn't support 0 sequence inside the batch, find the 0 sequence and set it to 1 - // there's a ZeroMask kernel to reset the result to 0 for the 0 sequence - std::vector seq_len_array(sequence_lens_data, sequence_lens_data + batch_size); - for (int i = 0; i < batch_size; ++i) { - if (0 == seq_len_array[i]) { - seq_len_array[i] = 1; - zero_seq_index_cache[zero_seq_count] = i; - ++zero_seq_count; - } - } + CudnnTensor cx_desc; + ORT_RETURN_IF_ERROR(cx_desc.Set(dims_hxy, CudnnTensor::GetDataType())); - // Calculate the zero position cache for reverse direction if it's bidirectional - // The cache is for Y_h or Y_c, and the 1st sequence for Y, no need to do it for other sequence in Y since - // we hacked the 0 sequence to 1 - if (zero_seq_count && num_directions_ > 1) { - zero_seq_index_cache_size = zero_seq_count * num_directions_; - zero_seq_index_cache.resize(zero_seq_index_cache_size); - for (int64_t i = 0; i < zero_seq_count; ++i) { - zero_seq_index_cache[static_cast(zero_seq_count) + i] = 
static_cast(batch_size + zero_seq_index_cache[i]); - } - } + CudnnTensor hx_desc; + ORT_RETURN_IF_ERROR(hx_desc.Set(dims_hxy, CudnnTensor::GetDataType())); + + // reserveSpaceSize is not required cudnnRNNForward, but returned by cudnnGetRNNTempSpaceSizes + size_t workspace_bytes, reservespace_bytes; - CudnnDataTensor x_desc1; - ORT_RETURN_IF_ERROR(x_desc1.Set(CudnnTensor::GetDataType(), seq_length, batch_size, input_size, seq_len_array.data())); - CudnnDataTensor y_desc1; - ORT_RETURN_IF_ERROR(y_desc1.Set(CudnnTensor::GetDataType(), seq_length, batch_size, hidden_size_ * num_directions_, seq_len_array.data())); - - CUDNN_RETURN_IF_ERROR(cudnnRNNForwardInferenceEx(GetCudnnHandle(ctx), - rnn_desc, - x_desc1, - x_data_input, - hx_desc, - hx_data, - cx_desc, - cx_data, - weight_cached_ ? w_desc_cache_ : w_desc, - weight_cached_ ? w_data_cache_.get() : w_data.get(), - y_desc1, - y_data, - y_h_desc, - y_h_data, - y_c_desc, - y_c_data, - nullptr, nullptr, nullptr, nullptr, - nullptr, nullptr, nullptr, nullptr, - workspace_cuda.get(), - workspace_bytes)); - - // Early terminate for this case since Y data is not required, and Y_h is obtained correctly, no need the following code to retrive Y_h from Y data. - if (nullptr == Y) { + CUDNN_RETURN_IF_ERROR(cudnnGetRNNTempSpaceSizes(GetCudnnHandle(ctx), rnn_desc, CUDNN_FWD_MODE_INFERENCE, + x_desc1, &workspace_bytes, &reservespace_bytes)); + auto workspace_cuda = GetScratchBuffer(workspace_bytes, ctx->GetComputeStream()); + auto reservespace_cuda = GetScratchBuffer(reservespace_bytes, ctx->GetComputeStream()); + + CUDNN_RETURN_IF_ERROR(cudnnRNNForward(GetCudnnHandle(ctx), + rnn_desc, + CUDNN_FWD_MODE_INFERENCE, + sequence_lens_buffer.GpuPtr(), // should be zero starting with cudnn 8.9.1 + x_desc1, + x_data_input, + y_desc1, + y_data, // output + hx_desc, + hx_data, // input + y_h_data, // output + cx_desc, cx_data, y_c_data, + weight_cached_ ? w_data_cache_size_in_bytes_ : w_data_size_in_bytes, + weight_cached_ ? 
w_data_cache_.get() : w_data.get(), + workspace_bytes, + workspace_cuda.get(), + reservespace_bytes, + reservespace_cuda.get())); + + // Early terminate for this case since Y data is not required, and Y_h is obtained correctly, + // no need the following code to retrieve Y_h from Y data. + if (nullptr == Y) { + // Mask on output for 0 sequence batches + if (zero_seq_count > 0) { // Mask on output for 0 sequence batches - if (zero_seq_count > 0) { - SetZeroSequences(zero_seq_index_cache_size, zero_seq_index_cache, y_data, y_h_data, y_c_data, ctx->GetComputeStream()); - } - return Status::OK(); + SetZeroSequences(zero_seq_count, zero_seq_index_cache, y_data, y_h_data, y_c_data, ctx->GetComputeStream()); } + return Status::OK(); } IAllocatorUniquePtr y_reorganized_data; @@ -345,6 +369,7 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { // reverse output data ReverseBySequence(Stream(ctx), gsl::narrow_cast(seq_length), + sequence_lens_buffer.GpuPtr(), gsl::narrow_cast(batch_size), gsl::narrow_cast(hidden_size_), reinterpret_cast(y_data), @@ -361,8 +386,9 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { } if (Y != nullptr) { - // User specified this optional output, so need to copy the reversed data to orignial place - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(y_data, y_reorganized_data.get(), output_size * sizeof(T), cudaMemcpyDeviceToDevice, Stream(ctx))); + // User specified this optional output, so need to copy the reversed data to original place + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(y_data, y_reorganized_data.get(), output_size * sizeof(T), + cudaMemcpyDeviceToDevice, Stream(ctx))); } else { y_data = y_reorganized_data.get(); } @@ -370,23 +396,9 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { // Mask on output for 0 sequence batches if (zero_seq_count > 0) { - SetZeroSequences(zero_seq_index_cache_size, zero_seq_index_cache, y_data, y_h_data, y_c_data, ctx->GetComputeStream()); + 
SetZeroSequences(zero_seq_count, zero_seq_index_cache, y_data, y_h_data, y_c_data, ctx->GetComputeStream()); } - if ((CUDNN_RNN_RELU == rnn_mode_ || CUDNN_RNN_TANH == rnn_mode_) && sequence_lens_data != nullptr && y_h_data != nullptr && y_data != nullptr) { - CudaAsyncBuffer sequence_lens_buffer(this, batch_size); - memcpy(sequence_lens_buffer.CpuPtr(), sequence_lens_data, batch_size * sizeof(int32_t)); - ORT_RETURN_IF_ERROR(sequence_lens_buffer.CopyToGpu(ctx->GetComputeStream())); - RnnMaskImpl(Stream(ctx), - gsl::narrow_cast(num_directions_), - gsl::narrow_cast(seq_length), - gsl::narrow_cast(batch_size), - gsl::narrow_cast(hidden_size_), - sequence_lens_buffer.GpuPtr(), - reinterpret_cast(y_data), - reinterpret_cast(y_h_data), - output_size); - } return Status::OK(); } @@ -399,7 +411,8 @@ void CudnnRnnBase::SetZeroSequences(const int64_t zero_seq_index_cache_size, onnxruntime::Stream* ort_stream) const { typedef typename ToCudaType::MappedType CudaT; CudaAsyncBuffer zero_seq_index_cache_async_buffer(this, zero_seq_index_cache_size); - memcpy(zero_seq_index_cache_async_buffer.CpuPtr(), zero_seq_index_cache.data(), zero_seq_index_cache_size * sizeof(int32_t)); + memcpy(zero_seq_index_cache_async_buffer.CpuPtr(), zero_seq_index_cache.data(), + zero_seq_index_cache_size * sizeof(int32_t)); ORT_THROW_IF_ERROR(zero_seq_index_cache_async_buffer.CopyToGpu(ort_stream)); cudaStream_t cuda_stream = ort_stream ? 
static_cast(ort_stream->GetHandle()) : nullptr; MaskZeroSequences(cuda_stream, diff --git a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h index 1c9483b2afd38..0fa01d3486e99 100644 --- a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h +++ b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h @@ -38,26 +38,28 @@ class CudnnRNN { } } - Status Set(const cudnnHandle_t& cudnnHandle, int64_t hidden_size, int num_layers, + Status Set(int64_t input_size, int64_t hidden_size, int64_t proj_size, int num_layers, cudnnDropoutDescriptor_t cudnn_dropout_desc, cudnnDirectionMode_t cudnn_direction_model, - cudnnRNNMode_t rnn_mode, cudnnDataType_t dataType, const cudaDeviceProp& prop) { + cudnnRNNMode_t rnn_mode, bool has_bias, cudnnDataType_t dataType) { if (!cudnn_rnn_desc_) CUDNN_RETURN_IF_ERROR(cudnnCreateRNNDescriptor(&cudnn_rnn_desc_)); - CUDNN_RETURN_IF_ERROR(cudnnSetRNNDescriptor_v6(cudnnHandle, - cudnn_rnn_desc_, + CUDNN_RETURN_IF_ERROR(cudnnSetRNNDescriptor_v8(cudnn_rnn_desc_, + CUDNN_RNN_ALGO_STANDARD, // CUDNN_RNN_ALGO_PERSIST_STATIC, CUDNN_RNN_ALGO_PERSIST_DYNAMIC + rnn_mode, + has_bias ? CUDNN_RNN_DOUBLE_BIAS : CUDNN_RNN_NO_BIAS, + cudnn_direction_model, + CUDNN_LINEAR_INPUT, + dataType, + dataType, + dataType == CUDNN_DATA_HALF ? 
CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH, + gsl::narrow_cast(input_size), gsl::narrow_cast(hidden_size), + gsl::narrow_cast(proj_size), // projected size num_layers, cudnn_dropout_desc, - CUDNN_LINEAR_INPUT, // We can also skip the input matrix transformation - cudnn_direction_model, - rnn_mode, - CUDNN_RNN_ALGO_STANDARD, // CUDNN_RNN_ALGO_PERSIST_STATIC, CUDNN_RNN_ALGO_PERSIST_DYNAMIC - dataType)); - - if (prop.major >= 7 && dataType == CUDNN_DATA_HALF) { - cudnnSetRNNMatrixMathType(cudnn_rnn_desc_, CUDNN_TENSOR_OP_MATH); - } + // CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED works with CUDNN_RNN_PADDED_IO_ENABLED, so that it will auto fill 0 for the shorter sequences + CUDNN_RNN_PADDED_IO_ENABLED)); return Status::OK(); } @@ -119,8 +121,7 @@ class CudnnRnnBase : public CudaKernel { private: Status SetCudnnRnnWeightBias(const cudnnHandle_t cudnn_handle, const cudnnRNNDescriptor_t rnn_desc, - const cudnnTensorDescriptor_t x_desc, - const cudnnFilterDescriptor_t w_desc, + size_t w_data_size, void* w_data, const T* W_data, const T* R_data, @@ -128,23 +129,22 @@ class CudnnRnnBase : public CudaKernel { cudaStream_t cuda_stream) const; Status ReorganizeWeights(const Tensor* W, const Tensor* R, const Tensor* B, + size_t& target_w_data_size_in_bytes, IAllocatorUniquePtr& target_w_data, CudnnFilterDescriptor& target_w_desc, CudnnRNN& rnn_desc, onnxruntime::Stream* ort_stream) const; - void SetWeightBias(const cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnn_desc, - const int pseudo_layer, - const cudnnTensorDescriptor_t x_desc, - const cudnnFilterDescriptor_t w_desc, - const cudnnFilterDescriptor_t filter_desc, - const void* w_data, - const int lin_layer_id, - const T* pos, - int& offset, - bool is_matrix, - cudaStream_t cuda_stream) const; + Status SetWeightBias(const cudnnHandle_t handle, + const cudnnRNNDescriptor_t rnn_desc, + const int pseudo_layer, + size_t w_data_size, + const void* w_data, + const int lin_layer_id, + const T* pos, + int& offset, + bool 
is_matrix, + cudaStream_t cuda_stream) const; void SetZeroSequences(const int64_t zero_seq_index_cache_size, const std::vector zero_seq_index_cache, @@ -167,6 +167,7 @@ class CudnnRnnBase : public CudaKernel { cudnnRNNMode_t rnn_mode_; // w_desc_cache_ & w_data_cache_ are changed in Constructor if we can get the weights as constant input CudnnFilterDescriptor w_desc_cache_; + size_t w_data_cache_size_in_bytes_; IAllocatorUniquePtr w_data_cache_; bool weight_cached_; int64_t layout_; diff --git a/onnxruntime/core/providers/cuda/rnn/rnn.cc b/onnxruntime/core/providers/cuda/rnn/rnn.cc index 4bd22340ef2bb..ed8be63679707 100644 --- a/onnxruntime/core/providers/cuda/rnn/rnn.cc +++ b/onnxruntime/core/providers/cuda/rnn/rnn.cc @@ -1,8 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/providers/shared_library/provider_api.h" #include "rnn.h" + +#include "core/providers/shared_library/provider_api.h" #include "rnn_impl.h" #include "core/providers/cuda/cudnn_common.h" diff --git a/onnxruntime/core/providers/cuda/rnn/rnn.h b/onnxruntime/core/providers/cuda/rnn/rnn.h index e4e50046b3725..6221afb003b22 100644 --- a/onnxruntime/core/providers/cuda/rnn/rnn.h +++ b/onnxruntime/core/providers/cuda/rnn/rnn.h @@ -4,6 +4,7 @@ #pragma once #include "cudnn_rnn_base.h" + #include "core/providers/cuda/cuda_common.h" #include diff --git a/onnxruntime/core/providers/cuda/rnn/rnn_impl.cu b/onnxruntime/core/providers/cuda/rnn/rnn_impl.cu index d485855ddb417..94c8036be6cdf 100644 --- a/onnxruntime/core/providers/cuda/rnn/rnn_impl.cu +++ b/onnxruntime/core/providers/cuda/rnn/rnn_impl.cu @@ -8,22 +8,32 @@ namespace onnxruntime { namespace cuda { template -__global__ void _ReverseBySequenceKernel(const int32_t seq_length, +__global__ void _ReverseBySequenceKernel(const int32_t max_seq_length, + const int32_t* seq_lengths, const int32_t block_size, const fast_divmod div_batch_block, + const fast_divmod div_input_or_hidden_size, 
const T* data, T* reversed_data, const CUDA_LONG N) { CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); int seq_id, offset; div_batch_block.divmod(id, seq_id, offset); - int org_id = (seq_length - seq_id - 1) * block_size + offset; - reversed_data[id] = data[org_id]; + int batch, batch_offset; + div_input_or_hidden_size.divmod(offset, batch, batch_offset); + int seq_id_org = seq_lengths[batch] - seq_id - 1; + if (seq_id_org >= 0) { + int org_id = seq_id_org * block_size + offset; + reversed_data[id] = data[org_id]; + } else { + reversed_data[id] = T{}; + } } template void ReverseBySequence(cudaStream_t stream, - const int32_t seq_length, + const int32_t max_seq_length, + const int32_t *seq_lengths, const int32_t batch_size, const int32_t input_or_hidden_size, const T* data, @@ -32,9 +42,10 @@ void ReverseBySequence(cudaStream_t stream, // kerneral int32_t block_size = batch_size * input_or_hidden_size; fast_divmod div_batch_block(block_size); + fast_divmod div_input_or_hidden_size(input_or_hidden_size); int blocksPerGrid = (int)(ceil(static_cast(N) / GridDim::maxThreadsPerBlock)); _ReverseBySequenceKernel<<>>( - seq_length, block_size, div_batch_block, data, reversed_data, (CUDA_LONG)N); + max_seq_length, seq_lengths, block_size, div_batch_block, div_input_or_hidden_size, data, reversed_data, (CUDA_LONG)N); } template @@ -82,60 +93,6 @@ void ReorderBidirectionalDataInSequence(cudaStream_t stream, data, reordered_data, (CUDA_LONG)N); } -template -__global__ void _RnnMaskKernel(const int32_t seq_length, - const int32_t batch_size, - const int32_t hidden_size, - const int32_t* sequence_lens, - const fast_divmod div_seq_block, - const fast_divmod div_dir_block, - const fast_divmod div_batch_block, - T* y_output_data, - T* y_h_output_data, - const CUDA_LONG N) { - CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); - - int seq_id, direction_id, batch_id, offset; - div_seq_block.divmod(id, seq_id, offset); - div_dir_block.divmod(offset, direction_id, offset); - 
div_batch_block.divmod(offset, batch_id, offset); - int32_t batch_seq_length = sequence_lens[batch_id]; - - if (batch_id >= batch_size || batch_seq_length == seq_length) { - return; - } - - if (seq_id >= batch_seq_length) { - y_output_data[id] = 0; - return; - } - - if ((y_h_output_data != nullptr) && - ((direction_id == 0 && (seq_id + 1) == batch_seq_length) || (direction_id == 1 && seq_id == 0))) { - int hy_idx = direction_id * batch_size * hidden_size + batch_id * hidden_size + offset; - y_h_output_data[hy_idx] = y_output_data[id]; - } -} - -template -void RnnMaskImpl(cudaStream_t stream, - const int32_t num_directions, - const int32_t seq_length, - const int32_t batch_size, - const int32_t hidden_size, - const int32_t* sequence_lens, - T* y_output_data, - T* y_h_output_data, - const size_t N) { - fast_divmod div_seq_block(batch_size * hidden_size * num_directions); - fast_divmod div_dir_block(batch_size * hidden_size); - fast_divmod div_batch_block(hidden_size); - int blocksPerGrid = (int)(ceil(static_cast(N) / GridDim::maxThreadsPerBlock)); - _RnnMaskKernel<<>>( - seq_length, batch_size, hidden_size, sequence_lens, div_seq_block, - div_dir_block, div_batch_block, y_output_data, y_h_output_data, (CUDA_LONG)N); -} - template __global__ void _MaskZeroSequences(const int32_t hidden_size, T* y_output_data, @@ -180,17 +137,9 @@ void MaskZeroSequences(cudaStream_t stream, } #define SPECIALIZED_RNN_IMPL(T) \ - template void RnnMaskImpl(cudaStream_t stream, \ - const int32_t num_directions, \ - const int32_t seq_length, \ - const int32_t batch_size, \ - const int32_t hidden_size, \ - const int32_t* sequence_lens, \ - T* y_output_data, \ - T* y_h_output_data, \ - const size_t N); \ - template void ReverseBySequence(cudaStream_t stream, \ - const int32_t seq_length, \ + template void ReverseBySequence(cudaStream_t stream, \ + const int32_t max_seq_length, \ + const int32_t* seq_lengths, \ const int32_t batch_size, \ const int32_t hidden_size, \ const T* data, \ @@ -203,7 
+152,7 @@ void MaskZeroSequences(cudaStream_t stream, const T* data, \ T* reordered_data, \ const size_t N); \ -template void MaskZeroSequences(cudaStream_t stream, \ +template void MaskZeroSequences(cudaStream_t stream, \ const int32_t hidden_size, \ T* y_output_data, \ T* y_h_output_data, \ diff --git a/onnxruntime/core/providers/cuda/rnn/rnn_impl.h b/onnxruntime/core/providers/cuda/rnn/rnn_impl.h index 9844e04ff6ec5..ba876011f6b67 100644 --- a/onnxruntime/core/providers/cuda/rnn/rnn_impl.h +++ b/onnxruntime/core/providers/cuda/rnn/rnn_impl.h @@ -10,7 +10,8 @@ namespace cuda { template void ReverseBySequence(cudaStream_t stream, - const int32_t seq_length, + const int32_t max_seq_length, + const int32_t* seq_lengths, const int32_t batch_size, const int32_t input_or_hidden_size, const T* data, @@ -26,17 +27,6 @@ void ReorderBidirectionalDataInSequence(cudaStream_t stream, T* reordered_data, const size_t N); -template -void RnnMaskImpl(cudaStream_t stream, - const int32_t num_directions, - const int32_t seq_length, - const int32_t batch_size, - const int32_t hidden_size, - const int32_t* sequence_lens, - T* y_output_data, - T* y_h_output_data, - const size_t N); - template void MaskZeroSequences(cudaStream_t stream, const int32_t hidden_size, diff --git a/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h b/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h index fa987866c002f..54c024793ff0b 100644 --- a/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h +++ b/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h @@ -168,5 +168,31 @@ struct NumericLimits { } }; +// TODO Where to put this? 
good places might be +// core/framework/tensor_shape.h +// core/util/matrix_layout.h + +constexpr bool LAYOUT_NCHW = false; +constexpr bool LAYOUT_NHWC = true; + +template +struct Channels; + +template <> +struct Channels { + static constexpr size_t N = 0; + static constexpr size_t H = 1; + static constexpr size_t W = 2; + static constexpr size_t C = 3; +}; + +template <> +struct Channels { + static constexpr size_t N = 0; + static constexpr size_t C = 1; + static constexpr size_t H = 2; + static constexpr size_t W = 3; +}; + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/shared_inc/fpgeneric.h b/onnxruntime/core/providers/cuda/shared_inc/fpgeneric.h index 510cc5cfbb7dd..053c66ddcb34a 100644 --- a/onnxruntime/core/providers/cuda/shared_inc/fpgeneric.h +++ b/onnxruntime/core/providers/cuda/shared_inc/fpgeneric.h @@ -29,13 +29,15 @@ cublasGemmHelper(cublasHandle_t handle, const float* B, int ldb, const float* beta, float* C, int ldc, - const cudaDeviceProp& prop) { + const cudaDeviceProp& prop, + bool use_tf32) { #if defined(USE_CUDA) - // TF32 uses 10 bit mantissa which has sufficient margin of precision for most use cases. It gets 8x throughput than FP32 in A100. - // It can be overrided by setting environment variable NVIDIA_TF32_OVERRIDE = 0 to disable TF32 - onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, CUBLAS_TF32_TENSOR_OP_MATH); + // To disable TF32, set environment variable NVIDIA_TF32_OVERRIDE = 0 or set provider option use_tf32 = 0 + cublasMath_t mode = use_tf32 ? 
CUBLAS_TF32_TENSOR_OP_MATH : CUBLAS_DEFAULT_MATH; + onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, mode); #else ORT_UNUSED_PARAMETER(prop); + ORT_UNUSED_PARAMETER(use_tf32); #endif return cublasSgemm(handle, @@ -58,7 +60,8 @@ inline cublasStatus_t cublasGemmHelper(cublasHandle_t handle, const double* B, int ldb, const double* beta, double* C, int ldc, - const cudaDeviceProp& /*prop*/) { + const cudaDeviceProp& /*prop*/, + bool /*use_tf32*/) { return cublasDgemm(handle, transa, transb, @@ -79,7 +82,8 @@ inline cublasStatus_t cublasGemmHelper(cublasHandle_t handle, const half* B, int ldb, const half* beta, half* C, int ldc, - const cudaDeviceProp& prop) { + const cudaDeviceProp& prop, + bool /*use_tf32*/) { const HalfGemmOptions* half_options = HalfGemmOptions::GetInstance(); onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, half_options->GetMathMode()); if (half_options->IsCompute16F()) { @@ -121,7 +125,8 @@ inline cublasStatus_t cublasGemmHelper(cublasHandle_t handle, const half* B, int ldb, const float* beta, half* C, int ldc, - const cudaDeviceProp& prop) { + const cudaDeviceProp& prop, + bool /*use_tf32*/) { const HalfGemmOptions* half_options = HalfGemmOptions::GetInstance(); onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, half_options->GetMathMode()); if (half_options->IsCompute16F()) { @@ -155,10 +160,11 @@ inline cublasStatus_t cublasGemmHelper(cublasHandle_t handle, } #if defined(USE_CUDA) -inline cublasStatus_t cublasGemmHelper(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, - int n, int k, const BFloat16* alpha, const BFloat16* A, int lda, - const BFloat16* B, int ldb, const BFloat16* beta, BFloat16* C, int ldc, - const cudaDeviceProp& /*prop*/) { +inline cublasStatus_t cublasGemmHelper( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, + int n, int k, const BFloat16* alpha, const BFloat16* A, int lda, + const BFloat16* 
B, int ldb, const BFloat16* beta, BFloat16* C, int ldc, + const cudaDeviceProp& /*prop*/, bool /*use_tf32*/) { float h_a = alpha->ToFloat(); float h_b = beta->ToFloat(); @@ -169,7 +175,7 @@ inline cublasStatus_t cublasGemmHelper(cublasHandle_t handle, cublasOperation_t #else inline cublasStatus_t cublasGemmHelper(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const BFloat16*, const BFloat16*, int, const BFloat16*, int, const BFloat16*, - BFloat16*, int, const cudaDeviceProp&) { + BFloat16*, int, const cudaDeviceProp&, bool /*use_tf32*/) { return CUBLAS_STATUS_NOT_SUPPORTED; } #endif @@ -185,7 +191,17 @@ inline cublasStatus_t cublasGemmBatchedHelper(cublasHandle_t handle, const float* beta, float* Carray[], int ldc, int batch_count, - const cudaDeviceProp&) { + const cudaDeviceProp& prop, + bool use_tf32) { +// The caller shall check memory alignments of the matrices when use_tf32 is true. +#if defined(USE_CUDA) + cublasMath_t mode = use_tf32 ? CUBLAS_TF32_TENSOR_OP_MATH : CUBLAS_DEFAULT_MATH; + onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, mode); +#else + ORT_UNUSED_PARAMETER(prop); + ORT_UNUSED_PARAMETER(use_tf32); +#endif + return cublasSgemmBatched(handle, transa, transb, @@ -208,7 +224,8 @@ inline cublasStatus_t cublasGemmBatchedHelper(cublasHandle_t handle, const double* beta, double* Carray[], int ldc, int batch_count, - const cudaDeviceProp& /*prop*/) { + const cudaDeviceProp& /*prop*/, + bool /*use_tf32*/) { return cublasDgemmBatched(handle, transa, transb, @@ -231,7 +248,8 @@ inline cublasStatus_t cublasGemmBatchedHelper(cublasHandle_t handle, const half* beta, half* Carray[], int ldc, int batch_count, - const cudaDeviceProp& prop) { + const cudaDeviceProp& prop, + bool /*use_tf32*/) { const HalfGemmOptions* half_options = HalfGemmOptions::GetInstance(); onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, half_options->GetMathMode()); if (half_options->IsCompute16F()) { @@ -266,11 +284,12 @@ 
inline cublasStatus_t cublasGemmBatchedHelper(cublasHandle_t handle, } #if defined(USE_CUDA) -inline cublasStatus_t cublasGemmBatchedHelper(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, - int m, int n, int k, const BFloat16* alpha, const BFloat16* Aarray[], - int lda, const BFloat16* Barray[], int ldb, const BFloat16* beta, - BFloat16* Carray[], int ldc, int batch_count, - const cudaDeviceProp& /*prop*/) { +inline cublasStatus_t cublasGemmBatchedHelper( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const BFloat16* alpha, const BFloat16* Aarray[], + int lda, const BFloat16* Barray[], int ldb, const BFloat16* beta, + BFloat16* Carray[], int ldc, int batch_count, + const cudaDeviceProp& /*prop*/, bool /*use_tf32*/) { float h_a = alpha->ToFloat(); float h_b = beta->ToFloat(); @@ -282,7 +301,8 @@ inline cublasStatus_t cublasGemmBatchedHelper(cublasHandle_t handle, cublasOpera #else inline cublasStatus_t cublasGemmBatchedHelper(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const BFloat16*, const BFloat16*[], int, const BFloat16*[], int, - const BFloat16*, BFloat16*[], int, int, const cudaDeviceProp&) { + const BFloat16*, BFloat16*[], int, int, const cudaDeviceProp&, + bool /*use_tf32*/) { return CUBLAS_STATUS_NOT_SUPPORTED; } #endif @@ -301,15 +321,14 @@ inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t handle, float* C, int ldc, long long int strideC, int batch_count, - const cudaDeviceProp& prop) { -#ifdef ENABLE_TRAINING_OPS + const cudaDeviceProp& prop, + bool use_tf32) { #if defined(USE_CUDA) - onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, CUBLAS_TF32_TENSOR_OP_MATH); -#else - ORT_UNUSED_PARAMETER(prop); -#endif + cublasMath_t mode = use_tf32 ? 
CUBLAS_TF32_TENSOR_OP_MATH : CUBLAS_DEFAULT_MATH; + onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, mode); #else ORT_UNUSED_PARAMETER(prop); + ORT_UNUSED_PARAMETER(use_tf32); #endif return cublasSgemmStridedBatched(handle, @@ -337,7 +356,8 @@ inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t handle, double* C, int ldc, long long int strideC, int batch_count, - const cudaDeviceProp& /*prop*/) { + const cudaDeviceProp& /*prop*/, + bool /*use_tf32*/) { return cublasDgemmStridedBatched(handle, transa, transb, @@ -363,7 +383,8 @@ inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t handle, __half* C, int ldc, long long int strideC, int batch_count, - const cudaDeviceProp& prop) { + const cudaDeviceProp& prop, + bool /*use_tf32*/) { const HalfGemmOptions* half_options = HalfGemmOptions::GetInstance(); onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, half_options->GetMathMode()); if (half_options->IsCompute16F()) { @@ -411,7 +432,8 @@ inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t handle, __half* C, int ldc, long long int strideC, int batch_count, - const cudaDeviceProp& prop) { + const cudaDeviceProp& prop, + bool /*use_tf32*/) { const HalfGemmOptions* half_options = HalfGemmOptions::GetInstance(); onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, half_options->GetMathMode()); if (half_options->IsCompute16F()) { @@ -447,49 +469,66 @@ inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t handle, } #if defined(USE_CUDA) -inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t handle, cublasOperation_t transa, - cublasOperation_t transb, int m, int n, int k, - const BFloat16* alpha, const BFloat16* A, int lda, - long long int strideA, const BFloat16* B, int ldb, - long long int strideB, const BFloat16* beta, BFloat16* C, int ldc, - long long int strideC, int batch_count, - const cudaDeviceProp& /*prop*/) { +inline cublasStatus_t 
cublasGemmStridedBatchedHelper( + cublasHandle_t handle, cublasOperation_t transa, + cublasOperation_t transb, int m, int n, int k, + const BFloat16* alpha, const BFloat16* A, int lda, + long long int strideA, const BFloat16* B, int ldb, + long long int strideB, const BFloat16* beta, BFloat16* C, int ldc, + long long int strideC, int batch_count, + const cudaDeviceProp& /*prop*/, bool /*use_tf32*/) { float h_a = alpha->ToFloat(); float h_b = beta->ToFloat(); // accumulating in FP32 - return cublasGemmStridedBatchedEx(handle, transa, transb, m, n, k, &h_a, A, CUDA_R_16BF, lda, strideA, B, CUDA_R_16BF, - ldb, strideB, &h_b, C, CUDA_R_16BF, ldc, strideC, batch_count, CUDA_R_32F, - CUBLAS_GEMM_DEFAULT); + return cublasGemmStridedBatchedEx( + handle, transa, transb, m, n, k, &h_a, A, CUDA_R_16BF, lda, strideA, B, CUDA_R_16BF, + ldb, strideB, &h_b, C, CUDA_R_16BF, ldc, strideC, batch_count, CUDA_R_32F, + CUBLAS_GEMM_DEFAULT); } #else -inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, - int, const BFloat16*, const BFloat16*, int, long long int, - const BFloat16*, int, long long int, const BFloat16*, BFloat16*, - int, long long int, int, const cudaDeviceProp&) { +inline cublasStatus_t cublasGemmStridedBatchedHelper( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, + int, const BFloat16*, const BFloat16*, int, long long int, + const BFloat16*, int, long long int, const BFloat16*, BFloat16*, + int, long long int, int, const cudaDeviceProp&, bool /*use_tf32*/) { return CUBLAS_STATUS_NOT_SUPPORTED; } #endif // transpose using geam -inline cublasStatus_t cublasTransposeHelper(cudaStream_t, cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, const float* alpha, const float* A, int lda, const float* beta, const float* B, int ldb, float* C, int ldc) { +inline cublasStatus_t cublasTransposeHelper( + cudaStream_t, cublasHandle_t handle, cublasOperation_t transa, 
cublasOperation_t transb, + int m, int n, const float* alpha, const float* A, int lda, const float* beta, const float* B, int ldb, + float* C, int ldc) { return cublasSgeam(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc); } -inline cublasStatus_t cublasTransposeHelper(cudaStream_t, cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, const double* alpha, const double* A, int lda, const double* beta, const double* B, int ldb, double* C, int ldc) { +inline cublasStatus_t cublasTransposeHelper( + cudaStream_t, cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, const double* alpha, const double* A, int lda, const double* beta, const double* B, int ldb, + double* C, int ldc) { return cublasDgeam(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc); } bool CanUse_cublasTransposeHelper_MLFloat16(int m, int n); -cublasStatus_t cublasTransposeHelper(cudaStream_t, cublasHandle_t, cublasOperation_t, cublasOperation_t, int m, int n, const half*, const half* A, int, const half*, const half*, int, half* C, int); + +cublasStatus_t cublasTransposeHelper( + cudaStream_t, cublasHandle_t, cublasOperation_t, cublasOperation_t, + int m, int n, const half*, const half* A, int, const half*, const half*, int, half* C, int); // copy -inline cublasStatus_t cublasCopyHelper(cudaStream_t, cublasHandle_t handle, int n, const float* x, int incx, float* y, int incy) { +inline cublasStatus_t cublasCopyHelper( + cudaStream_t, cublasHandle_t handle, int n, const float* x, int incx, float* y, int incy) { return cublasScopy(handle, n, x, incx, y, incy); } -inline cublasStatus_t cublasCopyHelper(cudaStream_t, cublasHandle_t handle, int n, const double* x, int incx, double* y, int incy) { +inline cublasStatus_t cublasCopyHelper( + cudaStream_t, cublasHandle_t handle, int n, const double* x, int incx, double* y, int incy) { return cublasDcopy(handle, n, x, incx, y, incy); } -cublasStatus_t 
cublasCopyHelper(cudaStream_t stream, cublasHandle_t handle, int n, const half* x, int incx, half* y, int incy); -cublasStatus_t cublasCopyHelper(cudaStream_t stream, cublasHandle_t handle, int n, const BFloat16* x, int incx, BFloat16* y, int incy); +cublasStatus_t cublasCopyHelper( + cudaStream_t stream, cublasHandle_t handle, int n, const half* x, int incx, half* y, int incy); + +cublasStatus_t cublasCopyHelper( + cudaStream_t stream, cublasHandle_t handle, int n, const BFloat16* x, int incx, BFloat16* y, int incy); diff --git a/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.cu b/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.cu index 10c8625b39ef8..b710e8a1b48c2 100644 --- a/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.cu @@ -95,7 +95,37 @@ struct OffsetCalculatorFor2D { template struct FuncAssignment { - __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { start_addr[index] = value; } + __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { + start_addr[index] = value; + } +}; + +template +struct FuncAdd { + __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { + atomic_add(start_addr + index, value); + } +}; + +template +struct FuncMul { + __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { + atomic_mul(start_addr + index, value); + } +}; + +template +struct FuncMax { + __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { + atomic_max(start_addr + index, value); + } +}; + +template +struct FuncMin { + __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { + atomic_min(start_addr + index, value); + } }; template @@ -238,8 +268,24 @@ Status ScatterElementsImplInternal(cudaStream_t stream, const T* input_data, con template Status ScatterElementsImpl(cudaStream_t stream, const 
T* input_data, const TIndex* indices_data, const T* updates_data, T* output_data, const GatherScatterElementsArgs& args) { - return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args, - FuncAssignment()); + if (args.operation == GatherScatterElementsArgs::Operation::NONE) { + return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args, + FuncAssignment()); + } else if (args.operation == GatherScatterElementsArgs::Operation::ADD) { + return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args, + FuncAdd()); + } else if (args.operation == GatherScatterElementsArgs::Operation::MUL) { + return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args, + FuncMul()); + } else if (args.operation == GatherScatterElementsArgs::Operation::MAX) { + return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args, + FuncMax()); + } else if (args.operation == GatherScatterElementsArgs::Operation::MIN) { + return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args, + FuncMin()); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported reduction operator."); + } } #define GATHER_SCATTER_ELEMENTS_SPECIALIZED_TINDEX_IMPL(T, TIndex) \ diff --git a/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.h b/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.h index 631d0bf049c6f..7b1c88f1fc1cb 100644 --- a/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.h +++ b/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.h @@ -10,6 +10,14 @@ namespace onnxruntime { namespace cuda { struct GatherScatterElementsArgs { + enum class Operation { + NONE, + ADD, + MUL, + MAX, + MIN + }; + int64_t rank; int64_t axis; int64_t input_size; @@ -19,6 +27,9 @@ struct GatherScatterElementsArgs { TArray indices_fdms; TArray 
indices_strides; int64_t indices_size; + // operation used to combine values associated the same + // memory location in the output tensor. + Operation operation; }; template diff --git a/onnxruntime/core/providers/cuda/tensor/gelu.cc b/onnxruntime/core/providers/cuda/tensor/gelu.cc new file mode 100644 index 0000000000000..67b2fad373a7f --- /dev/null +++ b/onnxruntime/core/providers/cuda/tensor/gelu.cc @@ -0,0 +1,89 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/cudnn_common.h" +#include "core/providers/cuda/tensor/gelu.h" +#include "core/providers/cuda/tensor/gelu_impl.h" + +namespace onnxruntime { +namespace cuda { + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + Gelu, \ + kOnnxDomain, \ + 20, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .MayInplace(0, 0), \ + Gelu); + +REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) +REGISTER_KERNEL_TYPED(double) + +template +Status Gelu::ComputeInternal(OpKernelContext* context) const { + const Tensor* input = context->Input(0); + const auto& input_dims = input->Shape().GetDims(); + if (input_dims.size() < 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 0 is expected to have 1 or more dimensions, got ", input_dims.size()); + } + + Tensor* output = context->Output(0, input->Shape()); + + int64_t input_length = input->Shape().Size(); + if (input_length == 0) { + return Status::OK(); + } + + typedef typename ToCudaType::MappedType CudaT; + + if (approximation_algorithm_ == "tanh") { + return LaunchFastGeluKernel(GetDeviceProp(), + Stream(context), + static_cast(input_length), + 0 /* no bias */, + reinterpret_cast(input->Data()), + nullptr /* no bias */, + reinterpret_cast(output->MutableData()), + use_half2_); + } else if (approximation_algorithm_ == 
"none") { + return LaunchGeluKernel(Stream(context), + reinterpret_cast(input->Data()), + reinterpret_cast(output->MutableData()), + static_cast(input_length)); + } + + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported approximation_algorithm: ", approximation_algorithm_); +} + +} // namespace cuda + +#ifndef DISABLE_CONTRIB_OPS +namespace contrib::cuda { +#define REGISTER_CONTRIB_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + Gelu, \ + kMSDomain, \ + 1, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .MayInplace(0, 0), \ + onnxruntime::cuda::Gelu); + +REGISTER_CONTRIB_KERNEL_TYPED(float) +REGISTER_CONTRIB_KERNEL_TYPED(MLFloat16) +REGISTER_CONTRIB_KERNEL_TYPED(double) + +#undef REGISTER_CONTRIB_KERNEL_TYPED +} // namespace contrib::cuda +#endif + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/gelu.h b/onnxruntime/core/providers/cuda/tensor/gelu.h new file mode 100644 index 0000000000000..1c8189ab24121 --- /dev/null +++ b/onnxruntime/core/providers/cuda/tensor/gelu.h @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once +#include "core/common/common.h" +#include "core/providers/cuda/cuda_kernel.h" +#include "core/providers/cuda/math/unary_elementwise_ops.h" + +namespace onnxruntime { +namespace cuda { + +template +class Gelu final : public UnaryElementwise { + public: + Gelu(const OpKernelInfo& info) : UnaryElementwise(info) { + approximation_algorithm_ = info.GetAttrOrDefault("approximate", "none"); + } + + Status ComputeInternal(OpKernelContext* ctx) const override; + + private: + const bool use_half2_{true}; + + std::string approximation_algorithm_; +}; + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.cu b/onnxruntime/core/providers/cuda/tensor/gelu_approximate_impl.cu similarity index 84% rename from onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.cu rename to onnxruntime/core/providers/cuda/tensor/gelu_approximate_impl.cu index c9498eb1bcd7b..7a27b7af33137 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/gelu_approximate_impl.cu @@ -24,12 +24,9 @@ limitations under the License. 
#include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/cu_inc/common.cuh" #include "core/providers/cuda/shared_inc/cuda_call.h" -#include "contrib_ops/cuda/bert/fast_gelu_impl.h" - -using namespace onnxruntime::cuda; +#include "core/providers/cuda/tensor/gelu_impl.h" namespace onnxruntime { -namespace contrib { namespace cuda { // constants for approximating the normal cdf @@ -65,7 +62,7 @@ __global__ void FastGeluKernel2(const half2 a, const half2 b, const half2 c, int } template <> -Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length, +Status LaunchFastGeluKernel(const cudaDeviceProp& /*prop*/, cudaStream_t stream, int input_length, int bias_length, const float* input, const float* bias, float* output, bool /*use_half2*/) { constexpr int blockSize = 256; const int gridSize = (input_length + blockSize - 1) / blockSize; @@ -75,6 +72,17 @@ Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int return CUDA_CALL(cudaGetLastError()); } +template <> +Status LaunchFastGeluKernel(const cudaDeviceProp& /*prop*/, cudaStream_t stream, int input_length, int bias_length, + const double* input, const double* bias, double* output, bool /*use_half2*/) { + constexpr int blockSize = 256; + const int gridSize = (input_length + blockSize - 1) / blockSize; + FastGeluKernel<<>>(A, B, C, input_length, bias_length, + input, bias, output); + + return CUDA_CALL(cudaGetLastError()); +} + template <> Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length, const half* input, const half* bias, half* output, bool use_half2) { @@ -100,7 +108,7 @@ Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int } template <> -Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length, +Status LaunchFastGeluKernel(const cudaDeviceProp& /*prop*/, cudaStream_t stream, int 
input_length, int bias_length, const BFloat16* input, const BFloat16* bias, BFloat16* output, bool /*use_half2*/) { constexpr int blockSize = 256; @@ -114,5 +122,4 @@ Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int } } // namespace cuda -} // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/gelu_impl.cu b/onnxruntime/core/providers/cuda/tensor/gelu_impl.cu new file mode 100644 index 0000000000000..3f96da38b37bb --- /dev/null +++ b/onnxruntime/core/providers/cuda/tensor/gelu_impl.cu @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include "core/providers/cuda/tensor/gelu_impl.h" +#include "core/providers/cuda/cu_inc/common.cuh" +#include "core/providers/cuda/cu_inc/unary_elementwise_impl.cuh" + +namespace onnxruntime { +namespace cuda { + +template +struct OP_Gelu { + __device__ __inline__ T operator()(const T& a) const { + return _Gelu(a); + } +}; + +template <> +struct OP_Gelu { + __device__ __inline__ half operator()(const half& a) const { + return static_cast(_Gelu(static_cast(a))); + } +}; + +template +Status LaunchGeluKernel( + cudaStream_t stream, + const T* input_data, + T* output_data, + size_t count) { + UnaryElementWiseImpl(stream, input_data, output_data, OP_Gelu(), count); + + return CUDA_CALL(cudaGetLastError()); +} + +#define SPECIALIZED_GELU_IMPL(T) \ + template Status LaunchGeluKernel(cudaStream_t stream, const T* input_data, T* output_data, \ + size_t count); + +SPECIALIZED_GELU_IMPL(float); +SPECIALIZED_GELU_IMPL(half); +SPECIALIZED_GELU_IMPL(double); + +#undef SPECIALIZED_GELU_IMPL + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.h b/onnxruntime/core/providers/cuda/tensor/gelu_impl.h similarity index 80% rename from onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.h rename to onnxruntime/core/providers/cuda/tensor/gelu_impl.h 
index ba78310f5dfc2..2ea0d3441fda3 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.h +++ b/onnxruntime/core/providers/cuda/tensor/gelu_impl.h @@ -1,17 +1,18 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. - #pragma once + #include "core/common/common.h" namespace onnxruntime { -namespace contrib { namespace cuda { +template +Status LaunchGeluKernel(cudaStream_t stream, const T* input, T* output, size_t count); + template Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length, const T* input, const T* bias, T* output, bool use_half2); } // namespace cuda -} // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/resize.cc b/onnxruntime/core/providers/cuda/tensor/resize.cc index 764172a8d1fac..97d4eb71e970a 100644 --- a/onnxruntime/core/providers/cuda/tensor/resize.cc +++ b/onnxruntime/core/providers/cuda/tensor/resize.cc @@ -28,10 +28,22 @@ namespace cuda { .InputMemoryType(OrtMemTypeCPUInput, 3) \ .TypeConstraint("T1", DataTypeImpl::GetTensorType()), \ Resize); \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + Resize, \ + kOnnxDomain, \ + 13, 17, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .InputMemoryType(OrtMemTypeCPUInput, 1) \ + .InputMemoryType(OrtMemTypeCPUInput, 2) \ + .InputMemoryType(OrtMemTypeCPUInput, 3) \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), \ + Resize); \ ONNX_OPERATOR_TYPED_KERNEL_EX( \ Resize, \ kOnnxDomain, \ - 13, \ + 18, \ T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()) \ diff --git a/onnxruntime/core/providers/cuda/tensor/resize_antialias_impl.cu b/onnxruntime/core/providers/cuda/tensor/resize_antialias_impl.cu new file mode 100644 index 0000000000000..d56e4bc53874d --- /dev/null +++ b/onnxruntime/core/providers/cuda/tensor/resize_antialias_impl.cu @@ -0,0 +1,1179 @@ +// Copyright (c) Microsoft Corporation. 
All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/cuda/cu_inc/common.cuh" +#include "core/providers/cuda/tensor/resize_impl.h" + +#define FUNC_DEF __device__ + +namespace onnxruntime { +namespace cuda { + +using onnxruntime::ResizeCoordinateTransformationMode; +using onnxruntime::UpsampleMode; + +/// +/// Compute a buffer for bilinear data for CUDA antialias resizing. +/// +static std::tuple ComputeBilinearScaleBufferSize( + int64_t output_height, int64_t output_width, + float height_rscale, float width_rscale, + float support_value, + float& scaled_support_height, float& scaled_support_width, + int32_t& window_size_height, int32_t& window_size_width) { + scaled_support_height = ComputeScaledSupportValue(support_value, height_rscale); + scaled_support_width = ComputeScaledSupportValue(support_value, width_rscale); + window_size_height = ComputeWindowSize(scaled_support_height); + window_size_width = ComputeWindowSize(scaled_support_width); + + auto height_buffer_size = ComputeWeightedCoeffBufferSize(output_height, window_size_height); + auto width_buffer_size = ComputeWeightedCoeffBufferSize(output_width, window_size_width); + + return std::make_tuple(height_buffer_size, width_buffer_size); +} + +/// +/// Compute a buffer for btrilinear data for CUDA antialias resizing. 
+/// +static std::tuple ComputeTrilinearScaleBufferSize( + int64_t output_depth, int64_t output_height, int64_t output_width, + float depth_rscale, float height_rscale, float width_rscale, + float support_value, + float& scaled_support_depth, float& scaled_support_height, + float& scaled_support_width, int32_t& window_size_depth, + int32_t& window_size_height, int32_t& window_size_width) { + scaled_support_depth = ComputeScaledSupportValue(support_value, depth_rscale); + window_size_depth = ComputeWindowSize(scaled_support_depth); + auto depth_buffer_size = ComputeWeightedCoeffBufferSize(output_depth, window_size_depth); + + const auto [y_buffer_size, w_buffer_size] = ComputeBilinearScaleBufferSize(output_height, + output_width, height_rscale, + width_rscale, support_value, + scaled_support_height, + scaled_support_width, + window_size_height, window_size_width); + return std::make_tuple(depth_buffer_size, y_buffer_size, w_buffer_size); +} + +// Antialiasing filters +struct BilinearFilter { + __device__ __host__ float operator()(float x, float /* cubic_coeff_a */) const { + if (x < 0.0f) { + x = -x; + } + if (x < 1.0f) { + return 1.0f - x; + } + return 0.0f; + } +}; + +struct BiCubicFilter { + __device__ __host__ float operator()(float x, float cubic_coeff_a) const { + /* https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm + */ + if (x < 0.0f) { + x = -x; + } + if (x < 1.0f) { + return ((cubic_coeff_a + 2.0f) * x - (cubic_coeff_a + 3.0f)) * x * x + 1; + } + if (x < 2.0f) { + return (((x - 5.0f) * x + 8.f) * x - 4.f) * cubic_coeff_a; + } + return 0.0f; + } +}; + +struct TriLinearFilter { + __device__ __host__ float operator()(float x, float /* cubic_coeff_a */) const { + if (x < 0.0f) { + x = -x; + } + if (x < 1.0f) { + return 1.0f - x; + } + return 0.0f; + } +}; + +template +struct AccumTypeCaster { + static __device__ __host__ AccumType* cast(AccumType* p) { + return p; + } +}; + +template <> +struct AccumTypeCaster { + static 
__device__ __host__ float* cast(int32_t* p) { + return reinterpret_cast(p); + } +}; + +template +__global__ void _ComputeInterpolationAtLevel1( + int64_t num_channels, + int64_t input_height, int64_t input_width, + int64_t output_height, int64_t output_width, + const fast_divmod div_output_width, + const fast_divmod div_output_image, + int32_t window_size, + const uint8_t* clip8_table, + const int64_t* bound_data, + std::tuple outof_bounds_buffers, + const AccumType* weight_coefficients, + const T* Xdata, T* Ydata, + const int N) { + CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); + + // No need to do scale + if (output_width == input_width) { + Ydata[id] = Xdata[id]; + return; + } + + int bxc, output_image_index; + div_output_image.divmod(id, bxc, output_image_index); + + int output_y, output_x; + div_output_width.divmod(output_image_index, output_y, output_x); + + CUDA_LONG input_index = static_cast(bxc * num_channels * input_height * input_width); + CUDA_LONG output_index = static_cast(bxc * num_channels * output_height * output_width); + + auto* Ydata_offset = Ydata + output_index + output_width * output_y + output_x; + const auto* bound = bound_data; + + AccumType output = onnxruntime::is_8bit_v ? 
ConstValue::mag_factor : 0; + + const auto* weight_coeff = weight_coefficients + window_size * output_x; + int64_t xmin = bound[static_cast(output_x) * 2]; + int64_t xmax = bound[static_cast(output_x) * 2 + 1]; + + // Input window + const auto* Xdata_offset = Xdata + input_index + input_width * output_y + xmin; + + for (; xmin < xmax; ++xmin) { + if constexpr (std::is_same::value) { + // This cast is needed when we deal with half + output += static_cast((*Xdata_offset++)) * (*weight_coeff++); + } else { + output += (*Xdata_offset++) * (*weight_coeff++); + } + } + + if constexpr (onnxruntime::is_8bit_v) { + const uint8_t* clip8_lookups = &clip8_table[640]; + *Ydata_offset = static_cast(clip8_lookups[output >> 22]); + } else if constexpr (std::is_same::value) { + *Ydata_offset = static_cast(std::round(output)); + } else { + *Ydata_offset = static_cast(output); + } +} + +template +__global__ void _ComputeInterpolationAtLevel2( + int64_t num_channels, + int64_t input_height, int64_t input_width, + int64_t output_height, int64_t output_width, + const fast_divmod div_output_height, + const fast_divmod div_output_width, + const fast_divmod div_output_image, + int32_t window_size, + bool use_extrapolation, float extrapolation_value, + const uint8_t* clip8_table, + const int64_t* bound_data, + std::tuple outof_bounds_buffers, + const AccumType* weight_coefficients, + const T* Xdata, T* Ydata, int N) { + CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); + + // No need to do scale + if (output_height == input_height) { + Ydata[id] = Xdata[id]; + return; + } + + int bxc, output_image_index; + div_output_image.divmod(id, bxc, output_image_index); + + int output_z, output_y, output_x, temp; + div_output_height.divmod(output_image_index, output_z, temp); + div_output_width.divmod(temp, output_y, output_x); + + CUDA_LONG input_index = static_cast(bxc * num_channels * input_height * input_width + + output_z * input_height * input_width); + CUDA_LONG output_index = static_cast(bxc * 
num_channels * output_height * output_width + + output_z * output_height * output_width); + + auto* Ydata_offset = Ydata + output_index + output_width * output_y + output_x; + + if (use_extrapolation) { + const auto* w_outof_bounds = std::get<1>(outof_bounds_buffers); + // Extrapolate along the w dimension + if (w_outof_bounds[static_cast(output_x)] != -1) { + *Ydata_offset = static_cast(extrapolation_value); + return; + } + + // Extrapolate along the y dimension + const auto* y_outof_bounds = std::get<0>(outof_bounds_buffers); + if (y_outof_bounds[static_cast(output_y)] != -1) { + *Ydata_offset = static_cast(extrapolation_value); + return; + } + } + + const auto* bound = bound_data; + + AccumType output = onnxruntime::is_8bit_v ? ConstValue::mag_factor : 0; + + const auto* weight_coeff = weight_coefficients + window_size * output_y; + int64_t ymin = bound[static_cast(output_y) * 2]; + int64_t ymax = bound[static_cast(output_y) * 2 + 1]; + + const auto* Xdata_offset = Xdata + input_index + ymin * output_width + output_x; + + for (; ymin < ymax; ++ymin) { + if constexpr (std::is_same::value) { + // We cast to AccumType to resolve ambiguous call to operator* for half in CUDA + output += static_cast((*Xdata_offset)) * (*weight_coeff++); + } else { + output += (*Xdata_offset) * (*weight_coeff++); + } + Xdata_offset += input_width; + } + + if constexpr (onnxruntime::is_8bit_v) { + const uint8_t* clip8_lookups = &clip8_table[640]; + *Ydata_offset = static_cast(clip8_lookups[output >> 22]); + } else if constexpr (std::is_same::value) { + *Ydata_offset = static_cast(std::round(output)); + } else { + *Ydata_offset = output; + } +} + +template +__global__ void _ComputeInterpolationAtLevel3( + int64_t input_depth, + int64_t input_height, int64_t input_width, + int64_t output_depth, + int64_t output_height, int64_t output_width, + const fast_divmod div_output_height, + const fast_divmod div_output_width, + const fast_divmod div_output_image, + int32_t window_size, + bool 
use_extrapolation, float extrapolation_value, + const uint8_t* clip8_table, + const int64_t* bound_data, + std::tuple outof_bounds_buffers, + const AccumType* weight_coefficients, + const T* Xdata, T* Ydata, int N) { + CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); + + // No need to do scale + if (input_depth == output_depth) { + Ydata[id] = Xdata[id]; + return; + } + + int bxc, output_image_index; + div_output_image.divmod(id, bxc, output_image_index); + + int output_z, output_y, output_x, temp; + div_output_height.divmod(output_image_index, output_z, temp); + div_output_width.divmod(temp, output_y, output_x); + + CUDA_LONG input_index = static_cast(bxc * input_depth * input_height * input_width); + + auto* Ydata_offset = Ydata + id; + + if (use_extrapolation) { + const auto* w_outof_bounds = std::get<2>(outof_bounds_buffers); + // Extrapolate along the w dimension + if (w_outof_bounds[static_cast(output_x)] != -1) { + *Ydata_offset = static_cast(extrapolation_value); + return; + } + + // Extrapolate along the y dimension + const auto* y_outof_bounds = std::get<1>(outof_bounds_buffers); + if (y_outof_bounds[static_cast(output_y)] != -1) { + *Ydata_offset = static_cast(extrapolation_value); + return; + } + + // Extrapolate along the z dimension + const int64_t* z_outof_bounds = std::get<0>(outof_bounds_buffers); + if (z_outof_bounds != nullptr && z_outof_bounds[static_cast(output_z)] != -1) { + *Ydata_offset = static_cast(extrapolation_value); + return; + } + } + + const auto* bound = bound_data; + + AccumType output = onnxruntime::is_8bit_v ? 
ConstValue::mag_factor : 0; + + const auto* weight_coeff = weight_coefficients + window_size * output_z; + int64_t zmin = bound[static_cast(output_z) * 2]; + int64_t zmax = bound[static_cast(output_z) * 2 + 1]; + + const auto z_step = input_height * input_width; + const auto* Xdata_offset = Xdata + input_index + zmin * z_step + output_y * output_width + output_x; + + for (; zmin < zmax; ++zmin) { + if constexpr (std::is_same::value) { + // We cast to AccumType to resolve ambiguous call to operator* for half in CUDA + output += static_cast((*Xdata_offset)) * (*weight_coeff++); + } else { + output += (*Xdata_offset) * (*weight_coeff++); + } + Xdata_offset += z_step; + } + + if constexpr (onnxruntime::is_8bit_v) { + const uint8_t* clip8_lookups = &clip8_table[640]; + *Ydata_offset = static_cast(clip8_lookups[output >> 22]); + } else if constexpr (std::is_same::value) { + *Ydata_offset = static_cast(std::round(output)); + } else { + *Ydata_offset = output; + } +} + +/// +/// This function expects the following buffers to be pre-allocated on device +/// 1. bounds: int64_t[output_size * 2] +/// 2. out_of_bounds: int64_t[output_size] +/// 3. scale_data: T[output_size * window_size] +/// +/// Template parameter AccumType +/// +template +FUNC_DEF void SetupUpsampleFilterAnitAliasImpl( + int64_t i, + int64_t input_size, int64_t output_size, + float rscale, + float roi_start, float roi_end, + float scaled_support, int32_t window_size, bool exclude_outside, + float cubic_coeff_a, + int64_t* bounds, + int64_t* out_of_bounds, + AccumType* scale_data) { + Filter filter{}; + CudaFunctionOriginalCoordinate get_original_coordinate{}; + + const auto scale = 1.f / rscale; + const float inv_scale = (scale >= 1.0f) ? 
1.0f / scale : 1.0f; + + const float id = static_cast(i); + float center = 0.5f; + if (scale == 1.0f) { + center += id; + } else { + center += get_original_coordinate(id, rscale, + static_cast(output_size), + static_cast(input_size), + roi_start, roi_end); + } + + if (center - 0.5f < 0 || center - 0.5f > static_cast(input_size - 1)) { + out_of_bounds[i] = i; + } else { + out_of_bounds[i] = -1; + } + + float total_weight{0}; + + auto fmin = _Floor(center - scaled_support + 0.5f); + auto fmax = _Floor(center + scaled_support + 0.5f); + + int64_t min_real = static_cast(fmin); + int64_t max_real = static_cast(fmax); + int64_t min_cut = std::max(min_real, 0); + int64_t max_cut = std::min(max_real, input_size); + + int64_t min_val = exclude_outside ? min_cut : min_real; + int64_t max_val = exclude_outside ? max_cut : max_real; + bounds[i * 2] = min_cut; + bounds[i * 2 + 1] = max_cut; + + // This is done for int32_t case, when the final result is in int32_t, but + // we perform calculations in float. All other types as is. + auto* scale_buffer = AccumTypeCaster::cast(&scale_data[i * window_size]); + + max_val -= min_val; + for (int64_t x = 0; x < max_val; x++) { + const float arg = (x + min_val - center + 0.5f) * inv_scale; + const auto w = filter(arg, cubic_coeff_a); + scale_buffer[x] = w; + total_weight += w; + } + + if (!exclude_outside) { + int64_t neg_xsize = min_val < 0 ? -min_val : 0; + for (int64_t x = 0; x < neg_xsize; x++) { + scale_buffer[neg_xsize] += scale_buffer[x]; + } + + int64_t bound_size = + max_val + min_val > input_size ? max_val + min_val - input_size : 0; + for (int64_t x = max_val - bound_size; x < max_val; x++) { + scale_buffer[max_val - bound_size - 1] += + scale_buffer[x]; + } + + for (int64_t x = 0; (neg_xsize | bound_size) > 0 && x < max_cut - min_cut; x++) { + scale_buffer[x] = scale_buffer[x + neg_xsize]; + } + } + + const float total_weight_inv = (total_weight == 0) ? 
1.f : (1.f / total_weight); + if constexpr (std::is_same::value) { + auto* scale_buffer_int = reinterpret_cast(scale_buffer); + for (int64_t x = 0; x < max_cut - min_cut; x++) { + scale_buffer[x] *= total_weight_inv; + // normalize the scale to 1 << 22 for int8/uint8 + scale_buffer_int[x] = static_cast(_Round(scale_buffer[x] * ConstValue::mag_factor_x_2)); + } + } else { + for (int64_t x = 0; x < max_cut - min_cut; x++) { + scale_buffer[x] *= total_weight_inv; + } + } +} + +/// This kernel computes antialias filter for bilinear or bicubic upsampling. +/// The function expects the following buffers to be pre-allocated on device +/// 1. bounds: int64_t[output_size * 2] for each of the two dimensions +/// 2. out_of_bounds: int64_t[output_size] for each of the two dimensions +/// 3. scale_data: AccumType[output_size * window_size] for each of the two dimensions +/// Buffers layout [h_data, w_data] +template +__global__ void _SetupBilinearUpsampleFilterAntiAlias( + std::tuple input_dims, // h, w + std::tuple output_dims, // h, w + std::tuple inv_scale_vals, // h, w + std::tuple roi_start_vals, // h, w + std::tuple roi_end_vals, // h, w + std::tuple dim_scaled_support, // Pre-computed scaled support values h, w + std::tuple dim_window_size, // Pre-computed windows sizes h, w + float cubic_coeff_a, + bool exclude_outside, + int64_t* bounds, + int64_t* out_of_bounds, + std::tuple weighted_coefficients // y, h buffers +) { + const auto N = std::get<0>(output_dims) + std::get<1>(output_dims); + + CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); + + if (id < std::get<0>(output_dims)) { + // Setup for y + int64_t input_size = std::get<0>(input_dims); + int64_t output_size = std::get<0>(output_dims); + float inv_scale = std::get<0>(inv_scale_vals); + float roi_start = std::get<0>(roi_start_vals); + float roi_end = std::get<0>(roi_end_vals); + float scaled_support = std::get<0>(dim_scaled_support); + int32_t window_size = std::get<0>(dim_window_size); + + 
SetupUpsampleFilterAnitAliasImpl( + id, + input_size, output_size, + inv_scale, + roi_start, roi_end, + scaled_support, window_size, + exclude_outside, + cubic_coeff_a, + bounds, + out_of_bounds, + std::get<0>(weighted_coefficients)); + + } else { + // Setup for w + // w = id - output_height + + int64_t input_size = std::get<1>(input_dims); + int64_t output_size = std::get<1>(output_dims); + float inv_scale = std::get<1>(inv_scale_vals); + float roi_start = std::get<1>(roi_start_vals); + float roi_end = std::get<1>(roi_end_vals); + + float scaled_support = std::get<1>(dim_scaled_support); + int32_t window_size = std::get<1>(dim_window_size); + + // Adjust buffer positions + const auto y_output_size = std::get<0>(output_dims); + + auto i = id - y_output_size; + bounds += (y_output_size * 2); + out_of_bounds += y_output_size; + + SetupUpsampleFilterAnitAliasImpl( + i, + input_size, output_size, + inv_scale, + roi_start, roi_end, + scaled_support, window_size, + exclude_outside, + cubic_coeff_a, + bounds, + out_of_bounds, + std::get<1>(weighted_coefficients)); + } +} + +/// +/// Compute AntiAlias filter for trilinear upsampling, all in one go +/// The function expects the following buffers to be pre-allocated on device +/// 1. bounds: int64_t[output_size * 2] for each of the three dimensions +/// 2. out_of_bounds: int64_t[output_size] for each of the three dimensions +/// 3. scale_data: AccumType[output_size * window_size] for each of the three dimensions +/// Each kind of buffer contains data for all 3 dims. 
+/// Buffers layout [d_data, h_data, w_data] +/// +template +__global__ void _SetupTrilinerarUpsampleFilterAntiAlias( + std::tuple input_dims, // d, h, w + std::tuple output_dims, // d, h, w + std::tuple inv_scale_vals, // d, h, w + std::tuple roi_start_vals, // d, h, w + std::tuple roi_end_vals, // d, h, w + std::tuple dim_scaled_support, // Pre-computed scaled support values d, h, w + std::tuple dim_window_size, // Pre-computed windows sizes d, h, w + bool exclude_outisde, + int64_t* bounds, + int64_t* out_of_bounds, + std::tuple weighted_coefficients) { + const auto N = std::get<0>(output_dims) + std::get<1>(output_dims) + std::get<2>(output_dims); + + CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); + + if (id < std::get<0>(output_dims)) { + // Setup for d by default (id < output_depth) + int64_t input_size = std::get<0>(input_dims); + int64_t output_size = std::get<0>(output_dims); + float inv_scale = std::get<0>(inv_scale_vals); + float roi_start = std::get<0>(roi_start_vals); + float roi_end = std::get<0>(roi_end_vals); + float scaled_support = std::get<0>(dim_scaled_support); + int32_t window_size = std::get<0>(dim_window_size); + + SetupUpsampleFilterAnitAliasImpl( + id, + input_size, output_size, + inv_scale, + roi_start, roi_end, + scaled_support, window_size, + exclude_outisde, + onnxruntime::antialias_constants::kCubicCoeffA, // Default value for trilinear + bounds, + out_of_bounds, + std::get<0>(weighted_coefficients)); + + } else if (id >= std::get<0>(output_dims) && id < (std::get<0>(output_dims) + std::get<1>(output_dims))) { + int64_t input_size = std::get<1>(input_dims); + int64_t output_size = std::get<1>(output_dims); + float inv_scale = std::get<1>(inv_scale_vals); + float roi_start = std::get<1>(roi_start_vals); + float roi_end = std::get<1>(roi_end_vals); + + float scaled_support = std::get<1>(dim_scaled_support); + int32_t window_size = std::get<1>(dim_window_size); + + // Adjust buffer positions + const auto d_output_size = 
std::get<0>(output_dims); + + auto i = id - d_output_size; + bounds += d_output_size * 2; + out_of_bounds += d_output_size; + + SetupUpsampleFilterAnitAliasImpl( + i, + input_size, output_size, + inv_scale, + roi_start, roi_end, + scaled_support, window_size, + exclude_outisde, + onnxruntime::antialias_constants::kCubicCoeffA, // Default value for trilinear + bounds, + out_of_bounds, + std::get<1>(weighted_coefficients)); + } else { + int64_t input_size = std::get<2>(input_dims); + int64_t output_size = std::get<2>(output_dims); + float inv_scale = std::get<2>(inv_scale_vals); + float roi_start = std::get<2>(roi_start_vals); + float roi_end = std::get<2>(roi_end_vals); + float scaled_support = std::get<2>(dim_scaled_support); + int32_t window_size = std::get<2>(dim_window_size); + + // Adjust buffer positions + const auto d_y_output_size = std::get<0>(output_dims) + std::get<1>(output_dims); + + auto i = id - d_y_output_size; + bounds += (d_y_output_size * 2); + out_of_bounds += d_y_output_size; + + SetupUpsampleFilterAnitAliasImpl( + i, + input_size, output_size, + inv_scale, + roi_start, roi_end, + scaled_support, window_size, + exclude_outisde, + onnxruntime::antialias_constants::kCubicCoeffA, // Default value for trilinear + bounds, + out_of_bounds, + std::get<2>(weighted_coefficients)); + } +} + +#define CASEA_COORD_ANTIALIAS(coordinate_mode, TransformCoordType, ...) \ + case coordinate_mode: { \ + using coord_t = TransformCoordType; \ + return __VA_ARGS__(); \ + break; \ + } + +#define DISPATCH_ANTIALIAS_FILTER_SETUP(coord_enum, ...) 
\ + [&] { \ + const auto the_type = coord_enum; \ + switch (the_type) { \ + CASEA_COORD_ANTIALIAS(ResizeCoordinateTransformationMode::HALF_PIXEL, \ + TransformCoordinate_HALF_PIXEL, __VA_ARGS__) \ + CASEA_COORD_ANTIALIAS(ResizeCoordinateTransformationMode::ASYMMETRIC, \ + TransformCoordinate_ASYMMETRIC, __VA_ARGS__) \ + CASEA_COORD_ANTIALIAS(ResizeCoordinateTransformationMode::PYTORCH_HALF_PIXEL, \ + TransformCoordinate_PYTORCH_HALF_PIXEL, __VA_ARGS__) \ + CASEA_COORD_ANTIALIAS(ResizeCoordinateTransformationMode::ALIGN_CORNERS, \ + TransformCoordinate_ALIGN_CORNERS, __VA_ARGS__) \ + CASEA_COORD_ANTIALIAS(ResizeCoordinateTransformationMode::TF_HALF_PIXEL_FOR_NN, \ + TransformCoordinate_TF_HALF_PIXEL_FOR_NN, __VA_ARGS__) \ + CASEA_COORD_ANTIALIAS(ResizeCoordinateTransformationMode::TF_CROP_AND_RESIZE, \ + TransformCoordinate_TF_CROP_AND_RESIZE, __VA_ARGS__) \ + default: \ + ORT_THROW("unknown ResizeCoordinateTransformationMode"); \ + } \ + }() + +namespace { +template +IAllocatorUniquePtr AllocateTyped( + const TempSpaceAllocateFunc& alloc, + size_t elements) { + return alloc(elements * sizeof(T)); +} + +template +T* GetTyped(IAllocatorUniquePtr& bytes) { + return reinterpret_cast(bytes.get()); +} +} // namespace + +template +void ResizeTrilinearUpsample( + cudaStream_t stream, + int rank, + const UpsampleMode /*upsample_mode*/, + ResizeCoordinateTransformationMode coordinate_transform_mode, + gsl::span /*input_shape*/, + gsl::span /*output_shape*/, + int64_t batch_size, int64_t num_channels, + std::tuple inferred_input_dims, + std::tuple inferred_output_dims, + std::tuple inferred_dim_rscales, + const TArray& output_div_pitches, + gsl::span roi_vals, + const std::optional& extrapolation, + bool exclude_outside, + const TempSpaceAllocateFunc& allocate_temp_space, + const uint8_t* clip8_lookups, + const T* input_data, + T* output_data, + const size_t N) { + using AccumType = typename onnxruntime::AccumulateType::type; + + const bool use_extrapolation = 
extrapolation.has_value(); + const float extrapolation_value = use_extrapolation ? *extrapolation : 0.f; + + int64_t input_depth, input_height, input_width; + std::tie(input_depth, input_height, input_width) = inferred_input_dims; + + int64_t output_depth, output_height, output_width; + std::tie(output_depth, output_height, output_width) = inferred_output_dims; + + int blocksPerDimsMappingGrid = + static_cast(ceil((output_depth + output_height + output_width) / 32.0)); + + int blocksPerGrid = static_cast(ceil(static_cast(N) / GridDim::maxThreadsPerBlock)); + + constexpr float support_value = antialias_constants::kSupportSize; + float z_scale, h_scale, w_scale; + std::tie(z_scale, h_scale, w_scale) = inferred_dim_rscales; + + const auto& div_output_width = output_div_pitches[rank - 2]; + + SafeInt bounds_buffer_size = (SafeInt(output_depth) + output_height + output_width) * 2; + SafeInt out_of_bounds_buffer_size = (SafeInt(output_depth) + output_height + output_width); + + auto bounds_buffer_ptr = AllocateTyped(allocate_temp_space, bounds_buffer_size); + auto out_of_bounds_buffer_ptr = AllocateTyped(allocate_temp_space, out_of_bounds_buffer_size); + + int64_t* z_bounds_buffer = GetTyped(bounds_buffer_ptr); + int64_t* y_bounds_buffer = z_bounds_buffer + output_depth * 2; + int64_t* w_bounds_buffer = y_bounds_buffer + output_height * 2; + + int64_t* z_outof_bounds_buffer = GetTyped(out_of_bounds_buffer_ptr); + int64_t* y_outof_bounds_buffer = z_outof_bounds_buffer + output_depth; + int64_t* w_outof_bounds_buffer = y_outof_bounds_buffer + output_height; + + float z_scaled_support, h_scaled_support, w_scaled_support; + int32_t z_window_size, h_window_size, w_window_size; + const auto [z_buffer_size, y_buffer_size, w_buffer_size] = ComputeTrilinearScaleBufferSize( + output_depth, output_height, output_width, + z_scale, h_scale, w_scale, support_value, + z_scaled_support, h_scaled_support, w_scaled_support, + z_window_size, h_window_size, w_window_size); + + const int64_t 
weighted_buffer_size = SafeInt(z_buffer_size) + y_buffer_size + w_buffer_size; + + auto weighted_buffer_ptr = AllocateTyped(allocate_temp_space, weighted_buffer_size); + AccumType* z_weighted_buffer = GetTyped(weighted_buffer_ptr); + AccumType* y_weighted_buffer = z_weighted_buffer + z_buffer_size; + AccumType* w_weighted_buffer = y_weighted_buffer + y_buffer_size; + + const auto h_w_interpolate_temp_buf_size = SafeInt(batch_size) * num_channels * + input_depth * input_height * output_width; + auto h_w_interpolate_temp_buffer_ptr = AllocateTyped(allocate_temp_space, + narrow(h_w_interpolate_temp_buf_size)); + + const auto h_w_interpolate_result_buffer_size = SafeInt(batch_size) * num_channels * + input_depth * output_height * output_width; + auto h_w_interpolate_result_buffer_ptr = AllocateTyped(allocate_temp_space, h_w_interpolate_result_buffer_size); + + // clang-format off + DISPATCH_ANTIALIAS_FILTER_SETUP(coordinate_transform_mode, [&]() { + _SetupTrilinerarUpsampleFilterAntiAlias<<>>( + inferred_input_dims, + inferred_output_dims, + inferred_dim_rscales, + std::make_tuple(roi_vals[rank - 3], roi_vals[rank - 2], roi_vals[rank - 1]), // roi starts d, h, w + std::make_tuple(roi_vals[rank - 3 + rank], roi_vals[rank - 2 + rank], // roi ends d, h, w + roi_vals[rank - 1 + rank]), + std::make_tuple(z_scaled_support, h_scaled_support, w_scaled_support), + std::make_tuple(z_window_size, h_window_size, w_window_size), + exclude_outside, + GetTyped(bounds_buffer_ptr), + GetTyped(out_of_bounds_buffer_ptr), + std::make_tuple(z_weighted_buffer, y_weighted_buffer, w_weighted_buffer)); + }); + + // clang-format on + const fast_divmod div_w_image(narrow(num_channels * input_depth * input_height * output_width)); + // clang-format off + _ComputeInterpolationAtLevel1<<>>( + num_channels * input_depth, input_height, input_width, input_height, output_width, + div_output_width, + div_w_image, + w_window_size, + clip8_lookups, + w_bounds_buffer, + 
std::make_tuple(y_outof_bounds_buffer, w_outof_bounds_buffer), + w_weighted_buffer, input_data, + GetTyped(h_w_interpolate_temp_buffer_ptr), + narrow(h_w_interpolate_temp_buf_size)); + + // clang-format on + const fast_divmod div_output_height{narrow(output_height * output_width)}; + const fast_divmod div_h_w_image(narrow(num_channels * input_depth * output_height * output_width)); + // clang-format off + _ComputeInterpolationAtLevel2<<>>( + num_channels * input_depth, input_height, output_width, output_height, output_width, + div_output_height, + div_output_width, + div_h_w_image, + h_window_size, + false, 0.f, // No extrapolation + clip8_lookups, + y_bounds_buffer, + std::make_tuple(y_outof_bounds_buffer, w_outof_bounds_buffer), + y_weighted_buffer, GetTyped(h_w_interpolate_temp_buffer_ptr), + GetTyped(h_w_interpolate_result_buffer_ptr), + narrow(h_w_interpolate_result_buffer_size)); + + // clang-format on + const fast_divmod div_z_h_w_image(narrow(input_depth * output_height * output_width)); + // clang-format off + _ComputeInterpolationAtLevel3<<>>( + input_depth, output_height, output_width, + output_depth, output_height, output_width, + div_output_height, + div_output_width, + div_z_h_w_image, + z_window_size, + use_extrapolation, extrapolation_value, + clip8_lookups, + z_bounds_buffer, + std::make_tuple(z_outof_bounds_buffer, y_outof_bounds_buffer, w_outof_bounds_buffer), + z_weighted_buffer, GetTyped(h_w_interpolate_result_buffer_ptr), + output_data, + narrow(N)); + // clang-format on +} + +template +void ResizeBiLinearUpsample(cudaStream_t stream, + int rank, + const UpsampleMode /*upsample_mode*/, + ResizeCoordinateTransformationMode coordinate_transform_mode, + gsl::span /*input_shape*/, + gsl::span /*output_shape*/, + int64_t /*batch_size*/, int64_t num_channels, + std::tuple inferred_input_dims, + std::tuple inferred_output_dims, + std::tuple inferred_dim_rscales, + const TArray& output_div_pitches, + gsl::span roi_vals, + const std::optional& 
extrapolation, + bool exclude_outside, + const TempSpaceAllocateFunc& allocate_temp_space, + const uint8_t* clip8_lookups, + const T* input_data, + T* output_data, + const size_t N) { + using AccumType = typename onnxruntime::AccumulateType::type; + + const bool use_extrapolation = extrapolation.has_value(); + const float extrapolation_value = use_extrapolation ? *extrapolation : 0.f; + + int64_t input_depth, input_height, input_width; + std::tie(input_depth, input_height, input_width) = inferred_input_dims; + + int64_t output_depth, output_height, output_width; + std::tie(output_depth, output_height, output_width) = inferred_output_dims; + + int blocksPerDimsMappingGrid = + narrow(CeilDiv((output_depth + output_height + output_width), 32)); + + // rank 2 or 4 + const fast_divmod div_output_image = (rank > 2) ? output_div_pitches[rank - 4] + : fast_divmod(gsl::narrow_cast(N)); + const fast_divmod& div_output_width = output_div_pitches[rank - 2]; + + constexpr float support_value = antialias_constants::kSupportSize; + + float h_scale, w_scale; + std::tie(std::ignore, h_scale, w_scale) = inferred_dim_rscales; + + int blocksPerGrid = narrow(CeilDiv(N, GridDim::maxThreadsPerBlock)); + + SafeInt bounds_buffer_size = (SafeInt(output_height) + output_width) * 2; + SafeInt out_of_bounds_buffer_size = (SafeInt(output_height) + output_width); + + float h_scaled_support, w_scaled_support; + int32_t h_window_size, w_window_size; + const auto [weighted_y_size, weighted_w_size] = + ComputeBilinearScaleBufferSize(output_height, output_width, + h_scale, w_scale, support_value, + h_scaled_support, w_scaled_support, h_window_size, w_window_size); + + auto bounds_buffer_ptr = AllocateTyped(allocate_temp_space, bounds_buffer_size); + auto out_of_bounds_buffer_ptr = AllocateTyped(allocate_temp_space, out_of_bounds_buffer_size); + + int64_t* y_bounds_buffer = GetTyped(bounds_buffer_ptr); + int64_t* w_bounds_buffer = y_bounds_buffer + output_height * 2; + + int64_t* y_outof_bounds_buffer 
= GetTyped(out_of_bounds_buffer_ptr); + int64_t* w_outof_bounds_buffer = y_outof_bounds_buffer + output_height; + + const int64_t weighted_buffer_size = SafeInt(weighted_y_size) + weighted_w_size; + auto weighted_buffer_ptr = AllocateTyped(allocate_temp_space, narrow(weighted_buffer_size)); + + AccumType* y_weighted_buffer = GetTyped(weighted_buffer_ptr); + AccumType* w_weighted_buffer = y_weighted_buffer + weighted_y_size; + + const auto temp_buf_size = num_channels * input_height * output_width; + auto image_temp_buffer = AllocateTyped(allocate_temp_space, narrow(temp_buf_size)); + + // clang-format off + DISPATCH_ANTIALIAS_FILTER_SETUP(coordinate_transform_mode, [&]() { + // Data is d, h, w in tuples + + _SetupBilinearUpsampleFilterAntiAlias<<>>( + std::make_tuple(input_height, input_width), + std::make_tuple(output_height, output_width), + std::make_tuple(h_scale, w_scale), + std::make_tuple(roi_vals[rank - 2], roi_vals[rank - 1]), // roi starts h, w + std::make_tuple(roi_vals[rank - 2 + rank], roi_vals[rank - 1 + rank]), // roi ends h, w + std::make_tuple(h_scaled_support, w_scaled_support), + std::make_tuple(h_window_size, w_window_size), + onnxruntime::antialias_constants::kCubicCoeffA, exclude_outside, + GetTyped(bounds_buffer_ptr), + GetTyped(out_of_bounds_buffer_ptr), + std::make_tuple(y_weighted_buffer, w_weighted_buffer)); + }); + + // clang-format on + const fast_divmod div_step_image{narrow(num_channels * input_height * output_width)}; + // clang-format off + _ComputeInterpolationAtLevel1<<>>( + num_channels, input_height, input_width, input_height, output_width, + div_output_width, + div_step_image, + w_window_size, + clip8_lookups, + w_bounds_buffer, + std::make_tuple(y_outof_bounds_buffer, w_outof_bounds_buffer), + w_weighted_buffer, input_data, GetTyped(image_temp_buffer), + narrow(temp_buf_size)); + + // clang-format on + const fast_divmod div_output_height{narrow(output_height * output_width)}; + // clang-format off + 
_ComputeInterpolationAtLevel2<<>>( + num_channels, input_height, output_width, output_height, output_width, + div_output_height, + div_output_width, + div_output_image, + h_window_size, + use_extrapolation, extrapolation_value, + clip8_lookups, + y_bounds_buffer, + std::make_tuple(y_outof_bounds_buffer, w_outof_bounds_buffer), + y_weighted_buffer, GetTyped(image_temp_buffer), output_data, + narrow(N)); + + // clang-format on +} + +template +void ResizeBicubicUpsample(cudaStream_t stream, + int rank, + const UpsampleMode /*upsample_mode*/, + ResizeCoordinateTransformationMode coordinate_transform_mode, + gsl::span /*input_shape*/, + gsl::span /*output_shape*/, + int64_t batch_size, int64_t num_channels, + std::tuple inferred_input_dims, + std::tuple inferred_output_dims, + std::tuple inferred_dim_rscales, + // const TArray& input_strides, + const TArray& output_div_pitches, + gsl::span roi_vals, + const std::optional& extrapolation, + bool exclude_outside, + const TempSpaceAllocateFunc& allocate_temp_space, + const uint8_t* clip8_lookups, + const T* input_data, + T* output_data, + const size_t N) { + using AccumType = typename onnxruntime::AccumulateType::type; + + const bool use_extrapolation = extrapolation.has_value(); + const float extrapolation_value = use_extrapolation ? *extrapolation : 0.f; + + int blocksPerGrid = narrow(CeilDiv(N, GridDim::maxThreadsPerBlock)); + const fast_divmod div_output_image = (rank > 2) ? 
output_div_pitches[rank - 4] + : fast_divmod(gsl::narrow_cast(N)); + const fast_divmod& div_output_width = output_div_pitches[rank - 2]; + + constexpr float support_value = antialias_constants::kBiCubicSupportSize; + + int64_t input_depth, input_height, input_width; + std::tie(input_depth, input_height, input_width) = inferred_input_dims; + + int64_t output_depth, output_height, output_width; + std::tie(output_depth, output_height, output_width) = inferred_output_dims; + + int blocksPerDimsMappingGrid = + narrow(CeilDiv((output_depth + output_height + output_width), 32)); + + float h_scale, w_scale; + std::tie(std::ignore, h_scale, w_scale) = inferred_dim_rscales; + + SafeInt bounds_buffer_size = (SafeInt(output_height) + output_width) * 2; + SafeInt out_of_bounds_buffer_size = (SafeInt(output_height) + output_width); + + float h_scaled_support, w_scaled_support; + int32_t h_window_size, w_window_size; + const auto [weighted_y_size, weighted_w_size] = + ComputeBilinearScaleBufferSize(output_height, output_width, + h_scale, w_scale, support_value, + h_scaled_support, w_scaled_support, h_window_size, w_window_size); + + auto bounds_buffer_ptr = AllocateTyped(allocate_temp_space, bounds_buffer_size); + auto out_of_bounds_buffer_ptr = AllocateTyped(allocate_temp_space, out_of_bounds_buffer_size); + + int64_t* y_bounds_buffer = GetTyped(bounds_buffer_ptr); + int64_t* w_bounds_buffer = y_bounds_buffer + output_height * 2; + + int64_t* y_outof_bounds_buffer = GetTyped(out_of_bounds_buffer_ptr); + int64_t* w_outof_bounds_buffer = y_outof_bounds_buffer + output_height; + + const int64_t weighted_buffer_size = SafeInt(weighted_y_size) + + weighted_w_size; + auto weighted_buffer_ptr = AllocateTyped(allocate_temp_space, weighted_buffer_size); + + AccumType* y_weighted_buffer = GetTyped(weighted_buffer_ptr); + AccumType* w_weighted_buffer = y_weighted_buffer + weighted_y_size; + + const auto temp_buf_size = SafeInt(batch_size) * num_channels * input_height * output_width; + 
auto image_temp_buffer = AllocateTyped(allocate_temp_space, narrow(temp_buf_size)); + + // clang-format off + DISPATCH_ANTIALIAS_FILTER_SETUP(coordinate_transform_mode, [&]() { + _SetupBilinearUpsampleFilterAntiAlias<<>>( + std::make_tuple(input_height, input_width), + std::make_tuple(output_height, output_width), + std::make_tuple(h_scale, w_scale), + std::make_tuple(roi_vals[rank - 2], roi_vals[rank - 1]), // roi starts h, w + std::make_tuple(roi_vals[rank - 2 + rank], roi_vals[rank - 1 + rank]), // roi ends h, w + std::make_tuple(h_scaled_support, w_scaled_support), + std::make_tuple(h_window_size, w_window_size), + onnxruntime::antialias_constants::kCubicCoeffA, exclude_outside, + GetTyped(bounds_buffer_ptr), + GetTyped(out_of_bounds_buffer_ptr), + std::make_tuple(y_weighted_buffer, w_weighted_buffer)); + }); + // clang-format on + const fast_divmod div_step_image(narrow(num_channels * input_height * output_width)); + // clang-format off + _ComputeInterpolationAtLevel1<<>>( + num_channels, input_height, input_width, input_height, output_width, + div_output_width, + div_step_image, + w_window_size, + clip8_lookups, + w_bounds_buffer, + std::make_tuple(y_outof_bounds_buffer, w_outof_bounds_buffer), + w_weighted_buffer, input_data, GetTyped(image_temp_buffer), + narrow(temp_buf_size)); + // clang-format on + + const fast_divmod div_output_height{narrow(output_height * output_width)}; + // clang-format off + _ComputeInterpolationAtLevel2<<>>( + num_channels, input_height, output_width, output_height, output_width, + div_output_height, + div_output_width, + div_output_image, + h_window_size, + use_extrapolation, extrapolation_value, + clip8_lookups, + y_bounds_buffer, + std::make_tuple(y_outof_bounds_buffer, w_outof_bounds_buffer), + y_weighted_buffer, GetTyped(image_temp_buffer), output_data, + narrow(N)); + // clang-format on +} + +template +void ResizeAntiAliasImpl( + cudaStream_t stream, + int rank, + const UpsampleMode upsample_mode, + 
ResizeCoordinateTransformationMode coordinate_transform_mode, + gsl::span input_shape, + gsl::span output_shape, + int64_t batch_size, int64_t num_channels, + std::tuple inferred_input_dims, + std::tuple inferred_output_dims, + std::tuple inferred_dim_rscales, + const TArray& output_div_pitches, + gsl::span roi_vals, + const std::optional& extrapolation, + bool exclude_outside, + TempSpaceAllocateFunc allocate_temp_space, + const uint8_t* clip8_lookups, + const T* input_data, + T* output_data, + const size_t N) { + // We support a special case of bilinear or bicubic if the input data is 4D with the outer 2 scales being 1.0 + // We would have validated the outer scale values by the time execution reaches this + const bool is_2D = (rank == 2 || rank == 4); + + // We support a special case of trilinear or tricubic if the input data is 5D with the outer 2 scales being 1.0 + // We would have validated the outer scale values by the time execution reaches this + const bool is_3D = (rank == 3 || rank == 5); + + // Should not hit this as we have already validated input rank/scales and we provide verbose error messages + // to the user. 
+ ORT_ENFORCE(is_2D || is_3D, "Only bilinear/trilinear and bicubic modes are supported in Resize anti-alias mode"); + + switch (upsample_mode) { + case UpsampleMode::LINEAR: { + if (is_2D) { + ResizeBiLinearUpsample(stream, rank, upsample_mode, coordinate_transform_mode, + input_shape, output_shape, batch_size, num_channels, + inferred_input_dims, inferred_output_dims, inferred_dim_rscales, + output_div_pitches, roi_vals, extrapolation, exclude_outside, + allocate_temp_space, clip8_lookups, input_data, output_data, N); + } else if (is_3D) { + ResizeTrilinearUpsample(stream, rank, upsample_mode, coordinate_transform_mode, + input_shape, output_shape, batch_size, num_channels, + inferred_input_dims, inferred_output_dims, inferred_dim_rscales, + output_div_pitches, roi_vals, extrapolation, exclude_outside, + allocate_temp_space, clip8_lookups, input_data, output_data, N); + } else { + ORT_NOT_IMPLEMENTED("Resize supports only 2-D or 3-D in LINEAR mode."); + } + } break; + case CUBIC: { + if (is_2D) { + ResizeBicubicUpsample(stream, rank, upsample_mode, coordinate_transform_mode, + input_shape, output_shape, batch_size, num_channels, + inferred_input_dims, inferred_output_dims, inferred_dim_rscales, + output_div_pitches, roi_vals, extrapolation, exclude_outside, + allocate_temp_space, clip8_lookups, input_data, output_data, N); + } else { + ORT_NOT_IMPLEMENTED("Resize supports only 2-D in CUBIC mode."); + } + } break; + default: + ORT_NOT_IMPLEMENTED("Only bilinear/trilinear and bicubic modes are supported in Resize anti-alias mode"); + break; + } +} + +#define SPECIALIZED_ANTIALIAS_IMPL(T) \ + template void ResizeAntiAliasImpl( \ + cudaStream_t stream, \ + int rank, \ + const UpsampleMode upsample_mode, \ + ResizeCoordinateTransformationMode coordinate_transform_mode, \ + gsl::span input_shape, \ + gsl::span output_shape, \ + int64_t batch_size, int64_t num_channels, \ + std::tuple inferred_input_dims, \ + std::tuple inferred_output_dims, \ + std::tuple 
inferred_dim_rscales, \ + const TArray& output_div_pitches, \ + gsl::span roi_vals, \ + const std::optional& extrapolation_value, \ + bool exclude_outside, \ + TempSpaceAllocateFunc allocate_temp_space, \ + const uint8_t* clip8_lookups, \ + const T* input_data, \ + T* output_data, \ + const size_t N); + +SPECIALIZED_ANTIALIAS_IMPL(float) +SPECIALIZED_ANTIALIAS_IMPL(double) +SPECIALIZED_ANTIALIAS_IMPL(half) +SPECIALIZED_ANTIALIAS_IMPL(int32_t) +SPECIALIZED_ANTIALIAS_IMPL(uint8_t) + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/resize_impl.cu b/onnxruntime/core/providers/cuda/tensor/resize_impl.cu index 1a94c7705e913..e788f24052985 100644 --- a/onnxruntime/core/providers/cuda/tensor/resize_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/resize_impl.cu @@ -12,7 +12,7 @@ using onnxruntime::ResizeNearestMode; using onnxruntime::UpsampleMode; struct NearestPixel_SIMPLE { - __device__ __forceinline__ int operator() (float x_original, bool is_down_sampling) const { + __device__ __forceinline__ int operator()(float x_original, bool is_down_sampling) const { if (is_down_sampling) { return static_cast(_Ceil(x_original)); } @@ -21,7 +21,7 @@ struct NearestPixel_SIMPLE { }; struct NearestPixel_ROUND_PREFER_FLOOR { - __device__ __forceinline__ int operator() (float x_original, bool) const { + __device__ __forceinline__ int operator()(float x_original, bool) const { if (x_original == static_cast(x_original) + 0.5f) { return static_cast(_Floor(x_original)); } @@ -30,62 +30,23 @@ struct NearestPixel_ROUND_PREFER_FLOOR { }; struct NearestPixel_ROUND_PREFER_CEIL { - __device__ __forceinline__ int operator() (float x_original, bool) const { + __device__ __forceinline__ int operator()(float x_original, bool) const { return static_cast(roundf(x_original)); } }; struct NearestPixel_FLOOR { - __device__ __forceinline__ int operator() (float x_original, bool) const { + __device__ __forceinline__ int operator()(float x_original, 
bool) const { return static_cast(_Floor(x_original)); } }; struct NearestPixel_CEIL { - __device__ __forceinline__ int operator() (float x_original, bool) const { + __device__ __forceinline__ int operator()(float x_original, bool) const { return static_cast(_Ceil(x_original)); } }; -struct TransformCoordinate_ASYMMETRIC { - __device__ __forceinline__ float operator() (float x_resized, float x_scale, float, float, float, float) const { - return x_resized / x_scale; - } -}; - -struct TransformCoordinate_HALF_PIXEL { - __device__ __forceinline__ float operator() (float x_resized, float x_scale, float, float, float, float) const { - return ((x_resized + 0.5f) / x_scale) - 0.5f; - } -}; - -struct TransformCoordinate_PYTORCH_HALF_PIXEL { - __device__ __forceinline__ float operator() (float x_resized, float x_scale, float length_resized, float, float, float) const { - return length_resized > 1 ? (x_resized + 0.5f) / x_scale - 0.5f : 0.0f; - } -}; - -struct TransformCoordinate_TF_HALF_PIXEL_FOR_NN { - __device__ __forceinline__ float operator() (float x_resized, float x_scale, float, float, float, float) const { - return (x_resized + 0.5f) / x_scale; - } -}; - -struct TransformCoordinate_ALIGN_CORNERS { - __device__ __forceinline__ float operator() (float x_resized, float, float length_resized, float length_original, float, float) const { - return length_resized == 1 ? 0 : x_resized * (length_original - 1) / (length_resized - 1); - } -}; - -struct TransformCoordinate_TF_CROP_AND_RESIZE { - __device__ __forceinline__ float operator() (float x_resized, float, float length_resized, float length_original, float roi_start, float roi_end) const { - auto orig = length_resized > 1 - ? roi_start * (length_original - 1) + (x_resized * (roi_end - roi_start) * (length_original - 1)) / (length_resized - 1) - : 0.5 * (roi_start + roi_end) * (length_original - 1); - return static_cast(orig); - } -}; - #define CASE_TYPE_USING_HINT(enum_type, type, HINT, ...) 
\ case enum_type: { \ using HINT = type; \ @@ -95,20 +56,24 @@ struct TransformCoordinate_TF_CROP_AND_RESIZE { #define CASE_TYPE_COORD(enum_type, type, ...) \ CASE_TYPE_USING_HINT(enum_type, type, coord_t, __VA_ARGS__) -#define DISPATCH_RESIZE_COORDINATE_TRANSFORMATION_MODE(TYPE, ...) \ - [&] { \ - const auto& the_type = TYPE; \ - /* don't use TYPE again in case it is an expensive or side-effect op */ \ - switch (the_type) { \ - CASE_TYPE_COORD(ResizeCoordinateTransformationMode::HALF_PIXEL, TransformCoordinate_HALF_PIXEL, __VA_ARGS__) \ - CASE_TYPE_COORD(ResizeCoordinateTransformationMode::ASYMMETRIC, TransformCoordinate_ASYMMETRIC, __VA_ARGS__) \ - CASE_TYPE_COORD(ResizeCoordinateTransformationMode::PYTORCH_HALF_PIXEL, TransformCoordinate_PYTORCH_HALF_PIXEL, __VA_ARGS__) \ - CASE_TYPE_COORD(ResizeCoordinateTransformationMode::ALIGN_CORNERS, TransformCoordinate_ALIGN_CORNERS, __VA_ARGS__) \ - CASE_TYPE_COORD(ResizeCoordinateTransformationMode::TF_HALF_PIXEL_FOR_NN, TransformCoordinate_TF_HALF_PIXEL_FOR_NN, __VA_ARGS__) \ - CASE_TYPE_COORD(ResizeCoordinateTransformationMode::TF_CROP_AND_RESIZE, TransformCoordinate_TF_CROP_AND_RESIZE, __VA_ARGS__) \ - default: \ - ORT_THROW("unknown ResizeCoordinateTransformationMode"); \ - } \ +#define DISPATCH_RESIZE_COORDINATE_TRANSFORMATION_MODE(TYPE, ...) 
\ + [&] { \ + const auto& the_type = TYPE; \ + /* don't use TYPE again in case it is an expensive or side-effect op */ \ + switch (the_type) { \ + CASE_TYPE_COORD(ResizeCoordinateTransformationMode::HALF_PIXEL, TransformCoordinate_HALF_PIXEL, __VA_ARGS__) \ + CASE_TYPE_COORD(ResizeCoordinateTransformationMode::ASYMMETRIC, TransformCoordinate_ASYMMETRIC, __VA_ARGS__) \ + CASE_TYPE_COORD(ResizeCoordinateTransformationMode::PYTORCH_HALF_PIXEL, \ + TransformCoordinate_PYTORCH_HALF_PIXEL, __VA_ARGS__) \ + CASE_TYPE_COORD(ResizeCoordinateTransformationMode::ALIGN_CORNERS, \ + TransformCoordinate_ALIGN_CORNERS, __VA_ARGS__) \ + CASE_TYPE_COORD(ResizeCoordinateTransformationMode::TF_HALF_PIXEL_FOR_NN, \ + TransformCoordinate_TF_HALF_PIXEL_FOR_NN, __VA_ARGS__) \ + CASE_TYPE_COORD(ResizeCoordinateTransformationMode::TF_CROP_AND_RESIZE, \ + TransformCoordinate_TF_CROP_AND_RESIZE, __VA_ARGS__) \ + default: \ + ORT_THROW("unknown ResizeCoordinateTransformationMode"); \ + } \ }() #define CASE_TYPE_NEAREST(enum_type, type, ...) 
\ @@ -119,11 +84,11 @@ struct TransformCoordinate_TF_CROP_AND_RESIZE { const auto& the_type = TYPE; \ /* don't use TYPE again in case it is an expensive or side-effect op */ \ switch (the_type) { \ - CASE_TYPE_NEAREST(ResizeNearestMode::SIMPLE, NearestPixel_SIMPLE, __VA_ARGS__) \ + CASE_TYPE_NEAREST(ResizeNearestMode::SIMPLE, NearestPixel_SIMPLE, __VA_ARGS__) \ CASE_TYPE_NEAREST(ResizeNearestMode::ROUND_PREFER_FLOOR, NearestPixel_ROUND_PREFER_FLOOR, __VA_ARGS__) \ - CASE_TYPE_NEAREST(ResizeNearestMode::ROUND_PREFER_CEIL, NearestPixel_ROUND_PREFER_CEIL, __VA_ARGS__) \ - CASE_TYPE_NEAREST(ResizeNearestMode::FLOOR, NearestPixel_FLOOR, __VA_ARGS__) \ - CASE_TYPE_NEAREST(ResizeNearestMode::CEIL, NearestPixel_CEIL, __VA_ARGS__) \ + CASE_TYPE_NEAREST(ResizeNearestMode::ROUND_PREFER_CEIL, NearestPixel_ROUND_PREFER_CEIL, __VA_ARGS__) \ + CASE_TYPE_NEAREST(ResizeNearestMode::FLOOR, NearestPixel_FLOOR, __VA_ARGS__) \ + CASE_TYPE_NEAREST(ResizeNearestMode::CEIL, NearestPixel_CEIL, __VA_ARGS__) \ default: \ ORT_THROW("unknown ResizeNearestMode"); \ } \ @@ -151,10 +116,12 @@ __global__ void _ResizeNearestMappingKernel2D( // only apply co-ordinate transformation if scale != 1.0 if (scales_height == 1.0f) { - dims_mapping[id].extrapolate_ = 0; + dims_mapping[id].extrapolate_ = 0; } else { - float orig_coord = transform_coordinate(static_cast(dim), scales_height, static_cast(output_height), - static_cast(input_height), roi_start_height, roi_end_height); + float orig_coord = transform_coordinate(static_cast(dim), scales_height, + static_cast(output_height), + static_cast(input_height), + roi_start_height, roi_end_height); dims_mapping[id].extrapolate_ = static_cast( extrapolation_enabled && (orig_coord < 0.f || orig_coord > static_cast(input_height - 1))); dim = calc_nearest_pixel(orig_coord, scales_height < 1); @@ -210,9 +177,12 @@ __global__ void _ResizeNearestMappingKernel( if (scales[axis] == 1.0f) { dims_mapping[id].extrapolate_ = 0; } else { - float orig_coord = 
transform_coordinate(static_cast(dim), scales[axis], static_cast(output_shape[axis]), + float orig_coord = transform_coordinate(static_cast(dim), scales[axis], + static_cast(output_shape[axis]), static_cast(input_shape[axis]), roi[axis], roi[axis + rank]); - dims_mapping[id].extrapolate_ = static_cast(extrapolation_enabled && (orig_coord < 0.f || orig_coord > static_cast(input_shape[axis] - 1))); + dims_mapping[id].extrapolate_ = static_cast(extrapolation_enabled && + (orig_coord < 0.f || + orig_coord > static_cast(input_shape[axis] - 1))); dim = calc_nearest_pixel(orig_coord, scales[axis] < 1); if (dim >= input_shape[axis]) dim = input_shape[axis] - 1; if (dim < 0) dim = 0; @@ -293,21 +263,27 @@ __global__ void _ResizeBilinearCoordinateMapping( LinearMappingInfo* dims_mapping) { CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, SumHW); if (id < output_height) { // y = id - float input_y = scale_height == 1 ? static_cast(id) : - transform_coordinate(static_cast(id), scale_height, - static_cast(output_height), static_cast(input_height), - roi_height_start, roi_height_end); - dims_mapping[id].extrapolate_ = (int)(extrapolation_enabled && (input_y < 0 || input_y > static_cast(input_height - 1))); + float input_y = scale_height == 1 ? static_cast(id) + : transform_coordinate(static_cast(id), scale_height, + static_cast(output_height), + static_cast(input_height), + roi_height_start, roi_height_end); + dims_mapping[id].extrapolate_ = static_cast((extrapolation_enabled && + (input_y < 0 || + input_y > static_cast(input_height - 1)))); input_y = max(0.0f, min(input_y, static_cast(input_height - 1))); int y_int = static_cast(input_y); dims_mapping[id].origin_ = y_int; dims_mapping[id].weight_ = (y_int >= input_height - 1) ? 0.5f : input_y - y_int; - } else { //x = id - output_height - float input_x = scale_width == 1 ? 
static_cast(id - output_height) : - transform_coordinate(static_cast(id - output_height), scale_width, - static_cast(output_width), static_cast(input_width), - roi_width_start, roi_width_end); - dims_mapping[id].extrapolate_ = (int)(extrapolation_enabled && (input_x < 0 || input_x > static_cast(input_width - 1))); + } else { // x = id - output_height + float input_x = scale_width == 1 ? static_cast(id - output_height) + : transform_coordinate(static_cast(id - output_height), + scale_width, static_cast(output_width), + static_cast(input_width), roi_width_start, + roi_width_end); + dims_mapping[id].extrapolate_ = static_cast((extrapolation_enabled && + (input_x < 0 || + input_x > static_cast(input_width - 1)))); input_x = max(0.0f, min(input_x, static_cast(input_width - 1))); int x_int = static_cast(input_x); dims_mapping[id].origin_ = x_int; @@ -371,32 +347,40 @@ __global__ void _ResizeTrilinearCoordinateMapping( LinearMappingInfo* dims_mapping) { CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, SumDHW); if (id < output_depth) { // z = id - float input_z = scale_depth == 1 ? static_cast(id) : - transform_coordinate(static_cast(id), scale_depth, - static_cast(output_depth), static_cast(input_depth), - roi_depth_start, roi_depth_end); - dims_mapping[id].extrapolate_ = (int)(extrapolation_enabled && (input_z < 0 || input_z > static_cast(input_depth - 1))); + float input_z = scale_depth == 1 ? static_cast(id) + : transform_coordinate(static_cast(id), scale_depth, + static_cast(output_depth), + static_cast(input_depth), + roi_depth_start, roi_depth_end); + dims_mapping[id].extrapolate_ = static_cast((extrapolation_enabled && + (input_z < 0 || + input_z > static_cast(input_depth - 1)))); input_z = max(0.0f, min(input_z, static_cast(input_depth - 1))); int z_int = static_cast(input_z); dims_mapping[id].origin_ = z_int; dims_mapping[id].weight_ = (z_int >= input_depth - 1) ? 
0.5f : input_z - z_int; } else if (id >= output_depth && id < (output_depth + output_height)) { // y = id - output_depth - float input_y = scale_height == 1 ? static_cast(id - output_depth) : - transform_coordinate(static_cast(id - output_depth), scale_height, - static_cast(output_height), static_cast(input_height), - roi_height_start, roi_height_end); - - dims_mapping[id].extrapolate_ = (int)(extrapolation_enabled && (input_y < 0 || input_y > static_cast(input_height - 1))); + float input_y = scale_height == 1 ? static_cast(id - output_depth) + : transform_coordinate(static_cast(id - output_depth), + scale_height, static_cast(output_height), + static_cast(input_height), + roi_height_start, roi_height_end); + + dims_mapping[id].extrapolate_ = static_cast((extrapolation_enabled && + (input_y < 0 || + input_y > static_cast(input_height - 1)))); input_y = max(0.0f, min(input_y, static_cast(input_height - 1))); int y_int = static_cast(input_y); dims_mapping[id].origin_ = y_int; dims_mapping[id].weight_ = (y_int >= input_height - 1) ? 0.5f : input_y - y_int; - } else { //x = id - output_depth - output_height - float input_x = scale_width == 1 ? static_cast(id - output_depth - output_height) : - transform_coordinate(static_cast(id - output_depth - output_height), scale_width, - static_cast(output_width), static_cast(input_width), - roi_width_start, roi_width_end); - dims_mapping[id].extrapolate_ = (int)(extrapolation_enabled && (input_x < 0 || input_x > static_cast(input_width - 1))); + } else { // x = id - output_depth - output_height + float input_x = scale_width == 1 ? 
static_cast(id - output_depth - output_height) + : transform_coordinate(static_cast(id - output_depth - output_height), + scale_width, static_cast(output_width), + static_cast(input_width), + roi_width_start, roi_width_end); + dims_mapping[id].extrapolate_ = (int)(extrapolation_enabled && (input_x < 0 || + input_x > static_cast(input_width - 1))); input_x = max(0.0f, min(input_x, static_cast(input_width - 1))); int x_int = static_cast(input_x); dims_mapping[id].origin_ = x_int; @@ -513,21 +497,33 @@ __global__ void _ResizeCubicCoordinateMapping( int max_input_coord = static_cast(is_y_axis ? input_height : input_width); float scale = is_y_axis ? scale_height : scale_width; - float input_coordinat = scale == 1 ? (is_y_axis ? id : id - output_height) : - transform_coordinate( - static_cast(is_y_axis ? id : id - output_height), - scale, - static_cast(is_y_axis ? output_height : output_width), - static_cast(max_input_coord), - (is_y_axis ? roi_height_start : roi_width_start), - (is_y_axis ? roi_height_end : roi_width_end)); + float input_coordinat = scale == 1 ? (is_y_axis ? id : id - output_height) + : transform_coordinate( + static_cast(is_y_axis ? id : id - output_height), + scale, + static_cast(is_y_axis ? output_height : output_width), + static_cast(max_input_coord), + (is_y_axis ? roi_height_start : roi_width_start), + (is_y_axis ? 
roi_height_end : roi_width_end)); int coord_int = static_cast(_Floor(input_coordinat)); float s_coord = abs(input_coordinat - coord_int); float coeff_sum = 1.0f; - float coeff_0 = static_cast(((cubic_coeff_a * (s_coord + 1) - 5 * cubic_coeff_a) * (s_coord + 1) + 8 * cubic_coeff_a) * (s_coord + 1) - 4 * cubic_coeff_a); - float coeff_1 = static_cast(((cubic_coeff_a + 2) * s_coord - (cubic_coeff_a + 3)) * s_coord * s_coord + 1); - float coeff_2 = static_cast(((cubic_coeff_a + 2) * (1 - s_coord) - (cubic_coeff_a + 3)) * (1 - s_coord) * (1 - s_coord) + 1); - float coeff_3 = static_cast(((cubic_coeff_a * (2 - s_coord) - 5 * cubic_coeff_a) * (2 - s_coord) + 8 * cubic_coeff_a) * (2 - s_coord) - 4 * cubic_coeff_a); + float coeff_0 = static_cast(((cubic_coeff_a * (s_coord + 1) - 5 * cubic_coeff_a) * + (s_coord + 1) + + 8 * cubic_coeff_a) * + (s_coord + 1) - + 4 * cubic_coeff_a); + float coeff_1 = static_cast(((cubic_coeff_a + 2) * s_coord - (cubic_coeff_a + 3)) * + s_coord * s_coord + + 1); + float coeff_2 = static_cast(((cubic_coeff_a + 2) * (1 - s_coord) - (cubic_coeff_a + 3)) * + (1 - s_coord) * (1 - s_coord) + + 1); + float coeff_3 = static_cast(((cubic_coeff_a * (2 - s_coord) - 5 * cubic_coeff_a) * + (2 - s_coord) + + 8 * cubic_coeff_a) * + (2 - s_coord) - + 4 * cubic_coeff_a); if (exclude_outside) { coeff_0 = (coord_int - 1 < 0 || coord_int - 1 >= max_input_coord) ? 0.0 : coeff_0; coeff_1 = (coord_int + 0 < 0 || coord_int + 0 >= max_input_coord) ? 
0.0 : coeff_1; @@ -540,7 +536,8 @@ __global__ void _ResizeCubicCoordinateMapping( dm.coeff1_ = coeff_1 / coeff_sum; dm.coeff2_ = coeff_2 / coeff_sum; dm.coeff3_ = coeff_3 / coeff_sum; - dm.extrapolate_ = (int)(extrapolation_enabled && (input_coordinat < 0 || input_coordinat > static_cast(max_input_coord - 1))); + dm.extrapolate_ = (int)(extrapolation_enabled && (input_coordinat < 0 || + input_coordinat > static_cast(max_input_coord - 1))); } template @@ -569,21 +566,30 @@ __global__ void _ResizeBiCubicKernel( int x_int = x_info.origin_; int y_int = y_info.origin_; const T* image = input_data + input_index; - output_data[id] = y_info.coeff0_ * CubicInterpolationRowwise(image, x_int, y_int - 1, input_height, input_width, w0, w1, w2, w3) + - y_info.coeff1_ * CubicInterpolationRowwise(image, x_int, y_int, input_height, input_width, w0, w1, w2, w3) + - y_info.coeff2_ * CubicInterpolationRowwise(image, x_int, y_int + 1, input_height, input_width, w0, w1, w2, w3) + - y_info.coeff3_ * CubicInterpolationRowwise(image, x_int, y_int + 2, input_height, input_width, w0, w1, w2, w3); + output_data[id] = y_info.coeff0_ * + CubicInterpolationRowwise(image, x_int, y_int - 1, input_height, input_width, w0, w1, w2, w3) + + y_info.coeff1_ * + CubicInterpolationRowwise(image, x_int, y_int, input_height, input_width, w0, w1, w2, w3) + + y_info.coeff2_ * + CubicInterpolationRowwise(image, x_int, y_int + 1, input_height, input_width, w0, w1, w2, w3) + + y_info.coeff3_ * + CubicInterpolationRowwise(image, x_int, y_int + 2, input_height, input_width, w0, w1, w2, w3); } size_t CalcResizeBufferSize(const onnxruntime::UpsampleMode upsample_mode, const gsl::span& output_dims) { switch (upsample_mode) { case UpsampleMode::NN: - return sizeof(int64_t) * output_dims.size() + sizeof(NearestMappingInfo) * static_cast(std::accumulate(output_dims.begin(), output_dims.end(), (int64_t)0)); + return sizeof(int64_t) * output_dims.size() + + sizeof(NearestMappingInfo) * + 
static_cast(std::accumulate(output_dims.begin(), + output_dims.end(), (int64_t)0)); case UpsampleMode::LINEAR: - return sizeof(LinearMappingInfo) * static_cast(std::accumulate(output_dims.rbegin(), output_dims.rbegin() + 2, (int64_t)0)); + return sizeof(LinearMappingInfo) * + static_cast(std::accumulate(output_dims.rbegin(), output_dims.rbegin() + 2, (int64_t)0)); case UpsampleMode::CUBIC: - return sizeof(CubicMappingInfo) * static_cast(std::accumulate(output_dims.rbegin(), output_dims.rbegin() + 2, (int64_t)0)); + return sizeof(CubicMappingInfo) * + static_cast(std::accumulate(output_dims.rbegin(), output_dims.rbegin() + 2, (int64_t)0)); } return 0; } @@ -603,7 +609,7 @@ void ResizeNearestImpl( const size_t N, bool extrapolation_enabled, const T extrapolation_value, - float cubic_coeff_a, + float /*cubic_coeff_a*/, ResizeCoordinateTransformationMode transform_coordinate, ResizeNearestMode calc_nearest_pixel, int64_t* /* prefix_dim_sum */, @@ -616,7 +622,8 @@ void ResizeNearestImpl( if (could2d) { int64_t output_height = output_shape[rank - 2]; int64_t output_width = output_shape[rank - 1]; - fast_divmod div_output_image = (rank > 2) ? output_div_pitches[rank - 3] : fast_divmod(static_cast(output_height * output_width)); + fast_divmod div_output_image = (rank > 2) ? 
output_div_pitches[rank - 3] + : fast_divmod(static_cast(output_height * output_width)); int blocksPerDimsMappingGrid = static_cast(ceil((output_height + output_width) / 32.0)); DISPATCH_RESIZE_COORDINATE_TRANSFORMATION_MODE(transform_coordinate, [&]() { @@ -694,13 +701,6 @@ void ResizeImpl( ResizeCoordinateTransformationMode coordinate_transform_mode, ResizeNearestMode nearest_mode, void* dims_mapping) { - bool isSame = std::all_of(scales_vals.Data(), scales_vals.Data() + rank, [](float v) { return v == 1.0f; }) && - (coordinate_transform_mode != ResizeCoordinateTransformationMode::TF_CROP_AND_RESIZE); - if (isSame) { - CUDA_CALL_THROW(cudaMemcpyAsync(output_data, input_data, N * sizeof(T), cudaMemcpyDeviceToDevice, stream)); - return; - } - if (upsample_mode == UpsampleMode::NN) { ResizeNearestImpl( stream, rank, input_shape, output_shape, input_strides, output_div_pitches, @@ -761,7 +761,7 @@ void ResizeImpl( } else if (is_3D) { DISPATCH_RESIZE_COORDINATE_TRANSFORMATION_MODE(coordinate_transform_mode, [&]() { _ResizeTrilinearCoordinateMapping<<>>( - input_shape[rank - 3] , input_shape[rank - 2], input_shape[rank - 1], + input_shape[rank - 3], input_shape[rank - 2], input_shape[rank - 1], output_depth, output_height, output_width, scales_vals[rank - 3], scales_vals[rank - 2], scales_vals[rank - 1], roi_vals[rank - 3], roi_vals[rank - 3 + rank], @@ -778,7 +778,7 @@ void ResizeImpl( reinterpret_cast(dims_mapping)); return; } - ORT_THROW("Only bilinear/trilinear and bicubic modes are supported in Resize"); + ORT_THROW("Resize support 2-D and 3-D dimensions in LINEAR mode."); break; case UpsampleMode::CUBIC: if (is_2D) { @@ -801,7 +801,7 @@ void ResizeImpl( reinterpret_cast(dims_mapping)); return; } - ORT_THROW("Only bilinear/trilinear and bicubic modes are supported in Resize"); + ORT_THROW("Resize supports only 2-D in CUBIC mode."); case UpsampleMode::NN: ORT_THROW("Only bilinear/trilinear and bicubic modes are supported in Resize"); } @@ -809,7 +809,7 @@ void 
ResizeImpl( #define SPECIALIZED_IMPL(T) \ template void ResizeImpl( \ - cudaStream_t stream, \ + cudaStream_t stream, \ const UpsampleMode upsample_mode, \ const int rank, \ TArray& input_shape, \ diff --git a/onnxruntime/core/providers/cuda/tensor/resize_impl.h b/onnxruntime/core/providers/cuda/tensor/resize_impl.h index d459dbff18d3e..ad06eebb9efb1 100644 --- a/onnxruntime/core/providers/cuda/tensor/resize_impl.h +++ b/onnxruntime/core/providers/cuda/tensor/resize_impl.h @@ -2,15 +2,69 @@ // Licensed under the MIT License. #pragma once + #include + +#include + #include "core/providers/cuda/shared_inc/cuda_utils.h" #include "core/common/common.h" #include "core/providers/cpu/tensor/upsamplebase.h" #include "core/providers/cuda/cuda_common.h" namespace onnxruntime { +template <> +struct AccumulateType { + using type = float; +}; namespace cuda { +struct TransformCoordinate_ASYMMETRIC { + __device__ __host__ __forceinline__ float operator()(float x_resized, float x_scale, + float, float, float, float) const { + return x_resized / x_scale; + } +}; + +struct TransformCoordinate_HALF_PIXEL { + __device__ __host__ __forceinline__ float operator()(float x_resized, float x_scale, + float, float, float, float) const { + return ((x_resized + 0.5f) / x_scale) - 0.5f; + } +}; + +struct TransformCoordinate_PYTORCH_HALF_PIXEL { + __device__ __host__ __forceinline__ float operator()(float x_resized, float x_scale, float length_resized, float, + float, float) const { + return length_resized > 1 ? 
(x_resized + 0.5f) / x_scale - 0.5f : 0.0f; + } +}; + +struct TransformCoordinate_TF_HALF_PIXEL_FOR_NN { + __device__ __host__ __forceinline__ float operator()(float x_resized, float x_scale, + float, float, float, float) const { + return (x_resized + 0.5f) / x_scale; + } +}; + +struct TransformCoordinate_ALIGN_CORNERS { + __device__ __host__ __forceinline__ float operator()(float x_resized, float, float length_resized, + float length_original, float, float) const { + return length_resized == 1 ? 0 : x_resized * (length_original - 1) / (length_resized - 1); + } +}; + +struct TransformCoordinate_TF_CROP_AND_RESIZE { + __device__ __host__ __forceinline__ float operator()(float x_resized, float, float length_resized, + float length_original, float roi_start, float roi_end) const { + auto orig = length_resized > 1 + ? roi_start * (length_original - 1) + + (x_resized * (roi_end - roi_start) * (length_original - 1)) / (length_resized - 1) + : 0.5 * (roi_start + roi_end) * (length_original - 1); + return static_cast(orig); + } +}; + size_t CalcResizeBufferSize(const onnxruntime::UpsampleMode upsample_mode, const gsl::span& output_dims); @@ -36,5 +90,62 @@ void ResizeImpl( onnxruntime::ResizeNearestMode nearest_mode, void* dims_mapping); +using TempSpaceAllocateFunc = std::function(size_t buffer_size)>; + +template +void ResizeAntiAliasImpl( + cudaStream_t stream, + int rank, + const UpsampleMode upsample_mode, + ResizeCoordinateTransformationMode coordinate_transform_mode, + gsl::span input_shape, + gsl::span output_shape, + int64_t batch_size, int64_t num_channels, + std::tuple inferred_input_dims, + std::tuple inferred_output_dims, + std::tuple inferred_dim_rscales, + const TArray& output_div_pitches, + gsl::span roi_vals, // CPU + const std::optional& extrapolation_value, + bool exclude_outside, + TempSpaceAllocateFunc allocate_temp_space, + const uint8_t* clip8_lookups, + const T* input_data, + T* output_data, + const size_t N); + +/// +/// Compute scaled support 
value for a given dimension inverse scale +/// +/// Support value from parameters +/// inverse scale value comes from input/attr for +/// +inline float ComputeScaledSupportValue(float support_value, float rscale) { + const float scale = 1.0f / rscale; + float scaled_support = (scale >= 1.0f) ? (support_value * 0.5f) * scale : support_value * 0.5f; + return scaled_support; +} + +/// +/// Compute window size for a given dimension scaled support value. +/// +/// +/// +inline int32_t ComputeWindowSize(float scaled_support) { + SafeInt window_size(ceilf(scaled_support)); + return window_size * 2 + 1; +} + +/// +/// Computes scale buffer size in number of elements for allocation purposes. +/// +/// +/// +/// Number of elements to fit in the buffer +inline SafeInt ComputeWeightedCoeffBufferSize(int64_t output_size, int32_t window_size) { + SafeInt buffer_size(output_size); + return buffer_size * window_size; +} + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/scatter_elements.cc b/onnxruntime/core/providers/cuda/tensor/scatter_elements.cc index e4d145154971e..42a9f50001103 100755 --- a/onnxruntime/core/providers/cuda/tensor/scatter_elements.cc +++ b/onnxruntime/core/providers/cuda/tensor/scatter_elements.cc @@ -27,7 +27,23 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(ScatterElements, kOnnxDomain, 11, 12, kCudaExe DataTypeImpl::GetTensorType()}), ScatterElements); -ONNX_OPERATOR_KERNEL_EX(ScatterElements, kOnnxDomain, 13, kCudaExecutionProvider, +ONNX_OPERATOR_VERSIONED_KERNEL_EX(ScatterElements, kOnnxDomain, 13, 15, kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()) + .TypeConstraint("Tind", + std::vector{DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), + ScatterElements); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(ScatterElements, kOnnxDomain, 16, 17, kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", 
DataTypeImpl::AllFixedSizeTensorTypes()) + .TypeConstraint("Tind", + std::vector{DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), + ScatterElements); + +ONNX_OPERATOR_KERNEL_EX(ScatterElements, kOnnxDomain, 18, kCudaExecutionProvider, (*KernelDefBuilder::Create()) .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()) .TypeConstraint("Tind", std::vector{DataTypeImpl::GetTensorType(), @@ -106,6 +122,20 @@ Status ScatterElements::ComputeInternal(OpKernelContext* context) const { TensorShapeVector indices_shape_vec = indices_shape.AsShapeVector(); CoalesceDimensions(input_shape_vec, indices_shape_vec, nullptr, axis, args); + if (reduction_ == "none") { + args.operation = GatherScatterElementsArgs::Operation::NONE; + } else if (reduction_ == "add") { + args.operation = GatherScatterElementsArgs::Operation::ADD; + } else if (reduction_ == "mul") { + args.operation = GatherScatterElementsArgs::Operation::MUL; + } else if (reduction_ == "min") { + args.operation = GatherScatterElementsArgs::Operation::MIN; + } else if (reduction_ == "max") { + args.operation = GatherScatterElementsArgs::Operation::MAX; + } else { + ORT_THROW("Unsupported reduction type"); + } + // Use element size instead of concrete types so we can specialize less template functions to reduce binary size. 
int dtype = GetElementType(input_tensor->DataType()->Size()); if (dtype == ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED) { diff --git a/onnxruntime/core/providers/cuda/tensor/scatter_elements.h b/onnxruntime/core/providers/cuda/tensor/scatter_elements.h index 3e9e0ce041845..3884b716da308 100755 --- a/onnxruntime/core/providers/cuda/tensor/scatter_elements.h +++ b/onnxruntime/core/providers/cuda/tensor/scatter_elements.h @@ -14,6 +14,12 @@ class ScatterElements final : public CudaKernel { ScatterElements(const OpKernelInfo& info) : CudaKernel(info) { ORT_ENFORCE(info.GetAttr("axis", &axis_).IsOK(), "Missing/Invalid 'axis' attribute value"); + reduction_ = info.GetAttrOrDefault("reduction", "none"); + + ORT_ENFORCE(reduction_ == "none" || reduction_ == "add" || + reduction_ == "mul" || reduction_ == "max" || + reduction_ == "min", + "Invalid reduction attribute value of ", reduction_); } ~ScatterElements() = default; Status ComputeInternal(OpKernelContext* context) const override; @@ -23,6 +29,10 @@ class ScatterElements final : public CudaKernel { struct ComputeImpl; int64_t axis_; + // "reduction" attribute has been defined since opset 13 but + // we never implemented it. Let's try to support them starting + // with opset 18. 
+ std::string reduction_; }; } // namespace cuda diff --git a/onnxruntime/core/providers/cuda/tensor/space_depth_ops.cc b/onnxruntime/core/providers/cuda/tensor/space_depth_ops.cc index 407a2ef3981f1..aaaf3600b676e 100644 --- a/onnxruntime/core/providers/cuda/tensor/space_depth_ops.cc +++ b/onnxruntime/core/providers/cuda/tensor/space_depth_ops.cc @@ -20,7 +20,22 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), - SpaceToDepth); + SpaceToDepth); + +#ifdef ENABLE_CUDA_NHWC_OPS +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + SpaceToDepth, + kMSInternalNHWCDomain, + 1, + 12, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", + {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), + SpaceToDepth); +#endif ONNX_OPERATOR_KERNEL_EX( SpaceToDepth, @@ -32,7 +47,21 @@ ONNX_OPERATOR_KERNEL_EX( {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), - SpaceToDepth); + SpaceToDepth); + +#ifdef ENABLE_CUDA_NHWC_OPS +ONNX_OPERATOR_KERNEL_EX( + SpaceToDepth, + kMSInternalNHWCDomain, + 13, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", + {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), + SpaceToDepth); +#endif ONNX_OPERATOR_VERSIONED_KERNEL_EX( DepthToSpace, @@ -45,7 +74,22 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), - DepthToSpace); + DepthToSpace); + +#ifdef ENABLE_CUDA_NHWC_OPS +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + DepthToSpace, + kMSInternalNHWCDomain, + 1, + 10, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", + {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), + DepthToSpace); +#endif ONNX_OPERATOR_VERSIONED_KERNEL_EX( DepthToSpace, @@ 
-58,7 +102,22 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), - DepthToSpace); + DepthToSpace); + +#ifdef ENABLE_CUDA_NHWC_OPS +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + DepthToSpace, + kMSInternalNHWCDomain, + 11, + 12, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", + {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), + DepthToSpace); +#endif ONNX_OPERATOR_KERNEL_EX( DepthToSpace, @@ -70,23 +129,35 @@ ONNX_OPERATOR_KERNEL_EX( {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), - DepthToSpace); + DepthToSpace); + +#ifdef ENABLE_CUDA_NHWC_OPS +ONNX_OPERATOR_KERNEL_EX( + DepthToSpace, + kMSInternalNHWCDomain, + 13, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", + {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), + DepthToSpace); +#endif static Status SpaceDepthOpCudaImpl(const cudaDeviceProp& prop, cudaStream_t stream, const cublasHandle_t cublas_handle, const Tensor& input, Tensor& output, const std::vector& permutation, - const int64_t batch_size, - const int64_t in_dim1, const int64_t in_dim2, const int64_t in_dim3, - const int64_t in_dim4, const int64_t in_dim5, + const TensorShape& virtual_input_shape, const TensorShape& virtual_output_shape) { - TensorShape virtual_input_shape{batch_size, in_dim1, in_dim2, in_dim3, in_dim4, in_dim5}; return Transpose::DoTranspose(prop, stream, cublas_handle, permutation, input, output, &virtual_input_shape, &virtual_output_shape); } -Status SpaceToDepth::ComputeInternal(OpKernelContext* context) const { +template +Status SpaceToDepth::ComputeInternal(OpKernelContext* context) const { const auto* tensor_pointer = context->Input(0); if (tensor_pointer == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); const 
Tensor& input = *tensor_pointer; @@ -101,29 +172,44 @@ Status SpaceToDepth::ComputeInternal(OpKernelContext* context) const { int64_t output_height = -1; int64_t output_width = -1; - ORT_RETURN_IF_ERROR(InputValidationsAndOutputDimsCalc(input, - batch, - input_depth, input_height, input_width, - output_depth, output_height, output_width, - true)); + ORT_RETURN_IF_ERROR( + InputValidationsAndOutputDimsCalc(input, + batch, + input_depth, input_height, input_width, + output_depth, output_height, output_width, + true)); // We use the "actual" output shape to construct the output tensor - Tensor& output = *context->Output(0, {batch, output_depth, output_height, output_width}); + Tensor& output = (Layout == LAYOUT_NCHW) + ? *context->Output(0, {batch, output_depth, output_height, output_width}) + : *context->Output(0, {batch, output_height, output_width, output_depth}); + + TensorShape virtual_input_shape = (Layout == LAYOUT_NCHW) + ? TensorShape{batch, input_depth, input_height / blocksize_, + blocksize_, input_width / blocksize_, blocksize_} + : TensorShape{batch, input_height / blocksize_, blocksize_, + input_width / blocksize_, blocksize_, input_depth}; // We will pass in the "virtual" output shape to be used by DoTranspose() in SpaceDepthOpCudaImpl(...) - TensorShape virtual_output_shape{batch, blocksize_, blocksize_, input_depth, - input_height / blocksize_, input_width / blocksize_}; + TensorShape virtual_output_shape = (Layout == LAYOUT_NCHW) + ? TensorShape{batch, blocksize_, blocksize_, input_depth, + input_height / blocksize_, input_width / blocksize_} + : TensorShape{batch, input_height / blocksize_, input_width / blocksize_, + blocksize_, blocksize_, input_depth}; - std::vector permutation = {0, 3, 5, 1, 2, 4}; + std::vector permutation = (Layout == LAYOUT_NCHW) + ? 
std::vector{0, 3, 5, 1, 2, 4} + : std::vector{0, 1, 3, 2, 4, 5}; - ORT_RETURN_IF_ERROR(SpaceDepthOpCudaImpl(GetDeviceProp(), Stream(context), GetCublasHandle(context), input, output, permutation, batch, - input_depth, input_height / blocksize_, blocksize_, input_width / blocksize_, blocksize_, - virtual_output_shape)); + ORT_RETURN_IF_ERROR( + SpaceDepthOpCudaImpl(GetDeviceProp(), Stream(context), GetCublasHandle(context), input, output, permutation, + virtual_input_shape, virtual_output_shape)); return Status::OK(); } -Status DepthToSpace::ComputeInternal(OpKernelContext* context) const { +template +Status DepthToSpace::ComputeInternal(OpKernelContext* context) const { const auto* tensor_pointer = context->Input(0); if (tensor_pointer == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); const Tensor& input = *tensor_pointer; @@ -138,46 +224,56 @@ Status DepthToSpace::ComputeInternal(OpKernelContext* context) const { int64_t output_height = -1; int64_t output_width = -1; - ORT_RETURN_IF_ERROR(InputValidationsAndOutputDimsCalc(input, - batch, - input_depth, input_height, input_width, - output_depth, output_height, output_width, - false)); + ORT_RETURN_IF_ERROR( + InputValidationsAndOutputDimsCalc(input, + batch, + input_depth, input_height, input_width, + output_depth, output_height, output_width, + false)); // We use the "actual" output shape to construct the output tensor - Tensor& output = *context->Output(0, {batch, output_depth, output_height, output_width}); + Tensor& output = (Layout == LAYOUT_NCHW) + ? *context->Output(0, {batch, output_depth, output_height, output_width}) + : *context->Output(0, {batch, output_height, output_width, output_depth}); + + int64_t virtual_input_depth = input_depth / blocksize_ / blocksize_; + TensorShape virtual_input_shape; + + // cdr only here! + if (is_dcr_) { + virtual_input_shape = (Layout == LAYOUT_NCHW) + ? 
TensorShape{batch, blocksize_, blocksize_, + virtual_input_depth, input_height, input_width} + : TensorShape{batch, input_height, input_width, + blocksize_, blocksize_, virtual_input_depth}; + } else { + virtual_input_shape = (Layout == LAYOUT_NCHW) + ? TensorShape{batch, virtual_input_depth, blocksize_, + blocksize_, input_height, input_width} + : TensorShape{batch, input_height, input_width, + virtual_input_depth, blocksize_, blocksize_}; + } // We will pass in the "virtual" output shape to be used by DoTranspose() in SpaceDepthOpCudaImpl(...) - TensorShape virtual_output_shape{batch, input_depth / blocksize_ / blocksize_, - input_height, blocksize_, input_width, blocksize_}; + TensorShape virtual_output_shape = (Layout == LAYOUT_NCHW) + ? TensorShape{batch, virtual_input_depth, input_height, + blocksize_, input_width, blocksize_} + : TensorShape{batch, input_height, blocksize_, + input_width, blocksize_, virtual_input_depth}; std::vector permutation; - permutation.reserve(6); - permutation.push_back(0); if (is_dcr_) { - permutation.push_back(3); - permutation.push_back(4); - permutation.push_back(1); - permutation.push_back(5); - permutation.push_back(2); + permutation = (Layout == LAYOUT_NCHW) + ? std::vector({0, 3, 4, 1, 5, 2}) + : std::vector({0, 1, 3, 2, 4, 5}); } else { - permutation.push_back(1); - permutation.push_back(4); - permutation.push_back(2); - permutation.push_back(5); - permutation.push_back(3); + permutation = std::vector({0, 1, 4, 2, 5, 3}); } - int64_t dim1 = is_dcr_ ? blocksize_ : input_depth / blocksize_ / blocksize_; - int64_t dim3 = is_dcr_ ? 
input_depth / blocksize_ / blocksize_ : blocksize_; - ORT_RETURN_IF_ERROR(SpaceDepthOpCudaImpl(GetDeviceProp(), Stream(context), GetCublasHandle(context), input, output, - permutation, - batch, - dim1, blocksize_, dim3, input_height, input_width, - virtual_output_shape)); + permutation, virtual_input_shape, virtual_output_shape)); return Status::OK(); } diff --git a/onnxruntime/core/providers/cuda/tensor/space_depth_ops.h b/onnxruntime/core/providers/cuda/tensor/space_depth_ops.h index 57b85556f1dbe..8780d9b365005 100644 --- a/onnxruntime/core/providers/cuda/tensor/space_depth_ops.h +++ b/onnxruntime/core/providers/cuda/tensor/space_depth_ops.h @@ -9,6 +9,7 @@ namespace onnxruntime { namespace cuda { +template class SpaceToDepth final : public CudaKernel, SpaceDepthBase { public: explicit SpaceToDepth(const OpKernelInfo& info) : CudaKernel(info), SpaceDepthBase(info) { @@ -17,6 +18,7 @@ class SpaceToDepth final : public CudaKernel, SpaceDepthBase { Status ComputeInternal(OpKernelContext* context) const override; }; +template class DepthToSpace final : public CudaKernel, SpaceDepthBase { public: explicit DepthToSpace(const OpKernelInfo& info) : CudaKernel(info), SpaceDepthBase(info) { diff --git a/onnxruntime/core/providers/cuda/tensor/transpose_impl.cu b/onnxruntime/core/providers/cuda/tensor/transpose_impl.cu index 9f9c365d2a53d..6344845359b32 100644 --- a/onnxruntime/core/providers/cuda/tensor/transpose_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/transpose_impl.cu @@ -80,7 +80,7 @@ bool CanDoTranspose3D(const cudaDeviceProp& prop, size_t rank, const gsl::span& input_shape, - const TArray& input_strides, const void* input_data, void* output_data, int64_t N, + const TArray& input_strides, const void* input_data, void* output_data, int64_t /*N*/, const dim3& grid_size, const dim3& block_size) { switch (element_size) { HANDLE_TRANSPOSE_3D_TILE_DIM(int8_t); @@ -248,10 +248,10 @@ __global__ void Transpose4DKernelParallelizeOneElementPerThread( } bool 
CanDoTranspose4DParallelizeOneElementPerThread(const cudaDeviceProp& prop, - size_t element_size, + size_t /*element_size*/, int32_t rank, const gsl::span& input_dims, - const gsl::span& permutations, + const gsl::span& /*permutations*/, dim3& grid_size, dim3& block_size) { if (rank == 4) { // dims[3]: block.x diff --git a/onnxruntime/core/providers/cuda/tensor/upsample.cc b/onnxruntime/core/providers/cuda/tensor/upsample.cc index ae12ca328bc7c..17533eb3d9a72 100644 --- a/onnxruntime/core/providers/cuda/tensor/upsample.cc +++ b/onnxruntime/core/providers/cuda/tensor/upsample.cc @@ -2,6 +2,9 @@ // Licensed under the MIT License. #include "upsample.h" + +#include + #include "upsample_impl.h" #include "core/providers/cuda/tensor/resize_impl.h" #include "core/providers/cpu/tensor/utils.h" @@ -37,11 +40,23 @@ REGISTER_VERSIONED_TYPED_KERNEL(MLFloat16, 9, 9); REGISTER_VERSIONED_TYPED_KERNEL(int32_t, 9, 9); REGISTER_VERSIONED_TYPED_KERNEL(uint8_t, 9, 9); +template +Upsample::Upsample(const OpKernelInfo& info) : UpsampleBase(info), CudaKernel(info) { + if (UpsampleBase::antialias_) { + // Copy the table on DEVICE + const uint8_t* lookup_table = GetLookupTableShared(); + auto alloc = info.GetAllocator(OrtMemTypeDefault); + shared_lookup_table_ondevice_ = IAllocator::MakeUniquePtr(std::move(alloc), kLookupTableSize); + CUDA_CALL_THROW(cudaMemcpyAsync(shared_lookup_table_ondevice_.get(), lookup_table, kLookupTableSize, + cudaMemcpyHostToDevice, nullptr)); + } +} + template Status Upsample::BaseCompute(OpKernelContext* context, - const std::vector& roi, - const std::vector& scales, - const gsl::span& output_dims) const { + gsl::span roi, + gsl::span scales, + gsl::span output_dims) const { const Tensor* X = context->Input(0); auto X_dims = X->Shape().GetDims(); int32_t rank = static_cast(X_dims.size()); @@ -52,7 +67,8 @@ Status Upsample::BaseCompute(OpKernelContext* context, is_resize_ ? "Resize: input tensor cannot be scalar." 
: "Upsample: input tensor cannot be scalar."); if (rank != static_cast(scales.size())) return Status(ONNXRUNTIME, INVALID_ARGUMENT, - is_resize_ ? "Resize: input tensor's dimension does not match the scales." : "Upsample: input tensor's dimension does not match the scales."); + is_resize_ ? "Resize: input tensor's dimension does not match the scales." + : "Upsample: input tensor's dimension does not match the scales."); if (roi.size() != 2 * X_dims.size()) return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Resize: size of roi array should be 2 * N where N is the rank of input tensor X."); @@ -79,22 +95,194 @@ Status Upsample::BaseCompute(OpKernelContext* context, size_t output_count = Y->Shape().Size(); if (is_resize_) { - TArray input_shape(X_dims); - TArray output_shape(output_dims); - TArray roi_vals(roi); - TArray scales_vals(scales); - - size_t temp_buffer_size = CalcResizeBufferSize(mode_, output_dims); - auto dims_mapping_buffer = GetScratchBuffer(temp_buffer_size, context->GetComputeStream()); - void* dims_mapping = reinterpret_cast(dims_mapping_buffer.get()); - ResizeImpl(Stream(context), mode_, (int)rank, input_shape, output_shape, - input_strides, output_div_pitches, scales_vals, roi_vals, - reinterpret_cast(X->Data()), - reinterpret_cast(Y->MutableData()), - output_count, use_extrapolation_, ToCudaType::FromFloat(extrapolation_value_), - cubic_coeff_a_, exclude_outside_, - coordinate_transform_mode_, nearest_mode_, - dims_mapping); + const bool is_same = std::all_of(scales.begin(), scales.end(), [](float v) { return v == 1.0f; }) && + (coordinate_transform_mode_ != ResizeCoordinateTransformationMode::TF_CROP_AND_RESIZE); + if (is_same) { + CUDA_CALL_THROW(cudaMemcpyAsync(Y->MutableData(), X->Data(), + output_count * sizeof(T), cudaMemcpyDeviceToDevice, Stream(context))); + return Status::OK(); + } + + if (antialias_) { + TempSpaceAllocateFunc allocate_temp_space = [&](size_t bytes_size) { + return GetScratchBuffer(bytes_size, context->GetComputeStream()); + 
}; + + std::optional extrapolation_value; + if (use_extrapolation_) + extrapolation_value.emplace(extrapolation_value_); + + switch (mode_) { + case UpsampleMode::LINEAR: { + if (X_dims.size() == 2 || X_dims.size() == 4) { + const bool is_2D = X_dims.size() == 2; + + int64_t batch_size = 1; + int64_t num_channels = 1; + + int64_t input_height; + int64_t input_width; + + int64_t output_height; + int64_t output_width; + + float height_scale; + float width_scale; + + if (is_2D) { + input_height = X_dims[0]; + input_width = X_dims[1]; + + output_height = output_dims[0]; + output_width = output_dims[1]; + + height_scale = scales[0]; + width_scale = scales[1]; + } else { + if (scales[0] == 1.0f && scales[1] == 1.0f) { + batch_size = X_dims[Channels::N]; + num_channels = X_dims[Channels::C]; + input_height = X_dims[Channels::H]; + input_width = X_dims[Channels::W]; + + output_height = output_dims[Channels::H]; + output_width = output_dims[Channels::W]; + + height_scale = scales[2]; + width_scale = scales[3]; + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Resize", ": NHWC is not supported yet"); + } + } + + ResizeAntiAliasImpl(Stream(context), + rank, + mode_, + coordinate_transform_mode_, + X_dims, output_dims, + batch_size, num_channels, + std::make_tuple(0, input_height, input_width), + std::make_tuple(0, output_height, output_width), + std::make_tuple(0.f, height_scale, width_scale), + output_div_pitches, + roi, + extrapolation_value, + exclude_outside_, + allocate_temp_space, + shared_lookup_table_ondevice_.get(), + reinterpret_cast(X->Data()), + reinterpret_cast(Y->MutableData()), + output_count); + + } else if (X_dims.size() == 3 || X_dims.size() == 5) { + const bool is_3D = X_dims.size() == 3; + + if (!is_3D) { + if (!(scales[0] == 1.0f && scales[1] == 1.0f)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Resize", ": NDHWC is not supported yet"); + } + } + + const int64_t batch_size = is_3D ? 
1 : X_dims[0]; + const int64_t num_channels = is_3D ? 1 : X_dims[1]; + const int64_t input_depth = is_3D ? X_dims[0] : X_dims[2]; + const int64_t input_height = is_3D ? X_dims[1] : X_dims[3]; + const int64_t input_width = is_3D ? X_dims[2] : X_dims[4]; + + const int64_t output_depth = is_3D ? output_dims[0] : output_dims[2]; + const int64_t output_height = is_3D ? output_dims[1] : output_dims[3]; + const int64_t output_width = is_3D ? output_dims[2] : output_dims[4]; + + const float depth_scale = is_3D ? scales[0] : scales[2]; + const float height_scale = is_3D ? scales[1] : scales[3]; + const float width_scale = is_3D ? scales[2] : scales[4]; + + ResizeAntiAliasImpl(Stream(context), + rank, + mode_, + coordinate_transform_mode_, + X_dims, output_dims, + batch_size, num_channels, + std::make_tuple(input_depth, input_height, input_width), + std::make_tuple(output_depth, output_height, output_width), + std::make_tuple(depth_scale, height_scale, width_scale), + output_div_pitches, + roi, + extrapolation_value, + exclude_outside_, + allocate_temp_space, + shared_lookup_table_ondevice_.get(), + reinterpret_cast(X->Data()), + reinterpret_cast(Y->MutableData()), + output_count); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Resize", + ": 'Linear' mode only support 2-D inputs or 3-D inputs ('Bilinear', 'Trilinear') " + "or 4-D inputs or 5-D inputs with the corresponding outermost 2 scale values " + "being 1."); + } + } break; + case UpsampleMode::CUBIC: { + if (X_dims.size() != 2 && X_dims.size() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Resize", + ": 'Cubic' mode only support 2-D inputs ('Bicubic') or 4-D inputs " + "with the corresponding outermost 2 scale values being 1."); + } + + const bool is_2D = X_dims.size() == 2; + const bool is_nchw = is_2D ? 
true : (scales[1] == 1.0f && scales[1] == 1.0f); + + ORT_RETURN_IF_NOT(is_nchw, + "Resize 'Cubic' mode only supports NCWH layout " + " with 2-D or 4-D with leading dims equal to 1"); + + const int64_t batch_size = is_2D ? 1 : X_dims[Channels::N]; + const int64_t num_channels = is_2D ? 1 : X_dims[Channels::C]; + const int64_t input_height = is_2D ? X_dims[0] : X_dims[Channels::H]; + const int64_t input_width = is_2D ? X_dims[1] : X_dims[Channels::W]; + + const int64_t output_height = is_2D ? output_dims[0] : output_dims[Channels::H]; + const int64_t output_width = is_2D ? output_dims[1] : output_dims[Channels::W]; + const float height_scale = is_2D ? scales[0] : scales[2]; + const float width_scale = is_2D ? scales[1] : scales[3]; + + ResizeAntiAliasImpl(Stream(context), rank, mode_, coordinate_transform_mode_, + X_dims, output_dims, + batch_size, num_channels, + std::make_tuple(0, input_height, input_width), + std::make_tuple(0, output_height, output_width), + std::make_tuple(0.f, height_scale, width_scale), + output_div_pitches, + roi, + extrapolation_value, + exclude_outside_, + allocate_temp_space, + shared_lookup_table_ondevice_.get(), + reinterpret_cast(X->Data()), + reinterpret_cast(Y->MutableData()), + output_count); + } break; + default: + return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Resize: unexpected mode"); + } + } else { + TArray input_shape(X_dims); + TArray output_shape(output_dims); + TArray roi_vals(roi); + TArray scales_vals(scales); + + size_t temp_buffer_size = CalcResizeBufferSize(mode_, output_dims); + auto dims_mapping_buffer = GetScratchBuffer(temp_buffer_size, context->GetComputeStream()); + void* dims_mapping = reinterpret_cast(dims_mapping_buffer.get()); + ResizeImpl(Stream(context), mode_, rank, input_shape, output_shape, + input_strides, output_div_pitches, scales_vals, roi_vals, + reinterpret_cast(X->Data()), + reinterpret_cast(Y->MutableData()), + output_count, use_extrapolation_, ToCudaType::FromFloat(extrapolation_value_), + 
cubic_coeff_a_, exclude_outside_, + coordinate_transform_mode_, nearest_mode_, + dims_mapping); + } } else { TArray scales_div(rank); @@ -124,7 +312,7 @@ Status Upsample::ComputeInternal(OpKernelContext* context) const { auto input_dims = X->Shape().GetDims(); TensorShapeVector output_dims(input_dims.size()); - std::vector roi_array(input_dims.size() * 2, 0.0f); + InlinedVector roi_array(input_dims.size() * 2, 0.0f); if (!roi_cached_) { bool use_default_roi = true; if (need_roi_input_) { @@ -147,29 +335,37 @@ Status Upsample::ComputeInternal(OpKernelContext* context) const { } } - const std::vector& roi = roi_cached_ ? roi_ : roi_array; - std::vector scales_array = scales_; + ComputeROIWithAxes(roi_array, input_dims.size()); + InlinedVector scales_array(input_dims.size()); + // opset < 10 if (OpKernel::Node().InputDefs().size() == 1) { - // Compute output shape from scales and input dims + // Compute output shape from scales attributes and input dims + scales_array = scales_; + ComputeOutputShape(scales_array, input_dims, output_dims); - return BaseCompute(context, roi, scales_, output_dims); + return BaseCompute(context, roi_array, scales_, output_dims); } const Tensor* scales = context->Input(scales_input_idx_); const Tensor* sizes = context->Input(sizes_input_idx_); + // This is when scales are obtained and cached from a constant initializer if (scales_cached_) { - ORT_ENFORCE(sizes == nullptr, "Only one of scales or sizes must be provided as input."); + ORT_RETURN_IF_NOT(sizes == nullptr, "Only one of scales or sizes must be provided as input."); + scales_array = scales_; + // Compute output shape from scales and input dims ComputeOutputShape(scales_array, input_dims, output_dims); - return BaseCompute(context, roi, scales_, output_dims); + return BaseCompute(context, roi_array, scales_array, output_dims); } - scales_array.resize((input_dims.size())); + // Scales and sizes are input to the node if (scales != nullptr && scales->Shape().Size() != 0) { // use 
scales input data ORT_ENFORCE(sizes == nullptr, "Only one of scales or sizes must be provided as input."); ORT_RETURN_IF_ERROR(ParseScalesData(scales, scales_array, input_dims.size())); + + // Compute output shape from scales and input dims ComputeOutputShape(scales_array, input_dims, output_dims); } else { // When sizes input is available directly populate it into the output_dims array. @@ -179,7 +375,7 @@ Status Upsample::ComputeInternal(OpKernelContext* context) const { ORT_RETURN_IF_ERROR(ParseScalesDataAndAdjustOutputSize(output_dims, input_dims, scales_array)); } - return BaseCompute(context, roi, scales_array, output_dims); + return BaseCompute(context, roi_array, scales_array, output_dims); } } // namespace cuda diff --git a/onnxruntime/core/providers/cuda/tensor/upsample.h b/onnxruntime/core/providers/cuda/tensor/upsample.h index 7bf2a23ede399..50597e0fba1b9 100644 --- a/onnxruntime/core/providers/cuda/tensor/upsample.h +++ b/onnxruntime/core/providers/cuda/tensor/upsample.h @@ -13,12 +13,14 @@ namespace cuda { template class Upsample : public UpsampleBase, public CudaKernel { public: - Upsample(const OpKernelInfo& info) : UpsampleBase(info), CudaKernel(info) { - } + explicit Upsample(const OpKernelInfo& info); Status ComputeInternal(OpKernelContext* context) const override; - Status BaseCompute(OpKernelContext* context, const std::vector& roi, const std::vector& scales, - const gsl::span& output_dims) const; + Status BaseCompute(OpKernelContext* context, gsl::span roi, gsl::span scales, + gsl::span output_dims) const; + + private: + IAllocatorUniquePtr shared_lookup_table_ondevice_; }; } // namespace cuda diff --git a/onnxruntime/core/providers/cuda/triton_kernel.cu b/onnxruntime/core/providers/cuda/triton_kernel.cu index 6ffbf0420a15f..b42dbd0291b7a 100644 --- a/onnxruntime/core/providers/cuda/triton_kernel.cu +++ b/onnxruntime/core/providers/cuda/triton_kernel.cu @@ -130,27 +130,11 @@ void LoadOrtTritonKernel() { 
std::call_once(load_ort_triton_kernel_flag, TryToLoadKernel); } -Status LaunchTritonKernel(cudaStream_t stream, std::string fname, - int grid0, int grid1, int grid2, void* args, size_t args_size) { -#ifdef USE_TRITON_KERNEL - if (ort_triton_kernel_map.count(fname) == 0) { - // Return unsupported status if function name not found in registry. - // This error status will be used by TunableOp - std::ostringstream message_stream; - message_stream << "Can't find ort triton kernel name: " << fname; - std::string message = message_stream.str(); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(true, message); - } - auto idx = ort_triton_kernel_map[fname]; - return LaunchTritonKernel(stream, idx, grid0, grid1, grid2, args, args_size); -#else - return Status::OK(); -#endif -} -Status LaunchTritonKernel(cudaStream_t stream, size_t idx, - int grid0, int grid1, int grid2, void* args, size_t args_size) { + #ifdef USE_TRITON_KERNEL +Status LaunchTritonKernel(cudaStream_t stream, size_t idx, int grid0, int grid1, int grid2, + void* args, size_t args_size) { if (idx >= ort_triton_kernel_metadata.size()) { // Return unsupported status when idx exceeds the size of ort_triton_kernel_metadata. // This error status will be used by TunableOp @@ -181,11 +165,37 @@ Status LaunchTritonKernel(cudaStream_t stream, size_t idx, nullptr, (void**)&config), "Launching kernel failed."); -#endif return Status::OK(); } +Status LaunchTritonKernel(cudaStream_t stream, std::string fname, int grid0, int grid1, int grid2, + void* args, size_t args_size) { + if (ort_triton_kernel_map.count(fname) == 0) { + // Return unsupported status if function name not found in registry. 
+ // This error status will be used by TunableOp + std::ostringstream message_stream; + message_stream << "Can't find ort triton kernel name: " << fname; + std::string message = message_stream.str(); + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(true, message); + } + auto idx = ort_triton_kernel_map[fname]; + return LaunchTritonKernel(stream, idx, grid0, grid1, grid2, args, args_size); +} + +#else +Status LaunchTritonKernel(cudaStream_t /*stream*/, std::string /*fname*/, int /*grid0*/, int /*grid1*/, int /*grid2*/, + void* /*args*/, size_t /*args_size*/) { + return Status::OK(); +} + +Status LaunchTritonKernel(cudaStream_t /*stream*/, size_t /*idx*/, int /*grid0*/, int /*grid1*/, int /*grid2*/, + void* /*args*/, size_t /*args_size*/) { + return Status::OK(); +} +#endif + + const TritonKernelMetaData* GetOrtTritonKernelMetadata(size_t idx) { if (idx >= ort_triton_kernel_metadata.size()) { return nullptr; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h index f29cc3afc3cda..88e3dd487d427 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h @@ -80,15 +80,10 @@ namespace Windows::AI::MachineLearning::Adapter }; // This is the counterpart to the MLOperatorGraphDesc ABI struct which owns its memory and uses containers. - // Either nodesAsOperatorDesc or nodesAsIDMLOperator can have non-zero size. 
struct DmlGraphNodeCreateInfo { uint32_t nodeCount = 0; - std::vector> nodesAsOperatorDesc; - - // TODO (jeffbloo): Remove this - std::vector> nodesAsIDMLOperator; - + std::vector> nodes; std::vector inputEdges; std::vector outputEdges; std::vector intermediateEdges; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ApiTraits.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ApiTraits.cpp new file mode 100644 index 0000000000000..bf9800458102b --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ApiTraits.cpp @@ -0,0 +1,570 @@ +//--------------------------------------------------------------------------- +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// This file is automatically generated. Please do not edit it directly. +// To modify this file, edit the schema: dml/Tools/DirectMLSchema.json +// And run this script to regenerate: dml/Tools/GenerateSchema.ps1 +// +// #dml-new-operator-location +//--------------------------------------------------------------------------- + +#pragma once + +#include "precomp.h" + +template +T ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ +#ifndef WAI_BUILD_LINUX + // Clang will instantiate this template even if it isn't used, + // so this static_assert will always fire and break the build. 
+ static_assert(false, "Not implemented for this type"); +#endif +} + +template <> +DML_TENSOR_DATA_TYPE ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_TENSOR_DATA_TYPE_UNKNOWN", DML_TENSOR_DATA_TYPE_UNKNOWN}, + {"DML_TENSOR_DATA_TYPE_FLOAT32", DML_TENSOR_DATA_TYPE_FLOAT32}, + {"DML_TENSOR_DATA_TYPE_FLOAT16", DML_TENSOR_DATA_TYPE_FLOAT16}, + {"DML_TENSOR_DATA_TYPE_UINT32", DML_TENSOR_DATA_TYPE_UINT32}, + {"DML_TENSOR_DATA_TYPE_UINT16", DML_TENSOR_DATA_TYPE_UINT16}, + {"DML_TENSOR_DATA_TYPE_UINT8", DML_TENSOR_DATA_TYPE_UINT8}, + {"DML_TENSOR_DATA_TYPE_INT32", DML_TENSOR_DATA_TYPE_INT32}, + {"DML_TENSOR_DATA_TYPE_INT16", DML_TENSOR_DATA_TYPE_INT16}, + {"DML_TENSOR_DATA_TYPE_INT8", DML_TENSOR_DATA_TYPE_INT8}, + {"DML_TENSOR_DATA_TYPE_FLOAT64", DML_TENSOR_DATA_TYPE_FLOAT64}, + {"DML_TENSOR_DATA_TYPE_UINT64", DML_TENSOR_DATA_TYPE_UINT64}, + {"DML_TENSOR_DATA_TYPE_INT64", DML_TENSOR_DATA_TYPE_INT64}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_TENSOR_TYPE ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_TENSOR_TYPE_INVALID", DML_TENSOR_TYPE_INVALID}, + {"DML_TENSOR_TYPE_BUFFER", DML_TENSOR_TYPE_BUFFER}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_OPERATOR_TYPE ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_OPERATOR_INVALID", DML_OPERATOR_INVALID}, + {"DML_OPERATOR_ELEMENT_WISE_IDENTITY", DML_OPERATOR_ELEMENT_WISE_IDENTITY}, + {"DML_OPERATOR_ELEMENT_WISE_ABS", DML_OPERATOR_ELEMENT_WISE_ABS}, + {"DML_OPERATOR_ELEMENT_WISE_ACOS", DML_OPERATOR_ELEMENT_WISE_ACOS}, + 
{"DML_OPERATOR_ELEMENT_WISE_ADD", DML_OPERATOR_ELEMENT_WISE_ADD}, + {"DML_OPERATOR_ELEMENT_WISE_ASIN", DML_OPERATOR_ELEMENT_WISE_ASIN}, + {"DML_OPERATOR_ELEMENT_WISE_ATAN", DML_OPERATOR_ELEMENT_WISE_ATAN}, + {"DML_OPERATOR_ELEMENT_WISE_CEIL", DML_OPERATOR_ELEMENT_WISE_CEIL}, + {"DML_OPERATOR_ELEMENT_WISE_CLIP", DML_OPERATOR_ELEMENT_WISE_CLIP}, + {"DML_OPERATOR_ELEMENT_WISE_COS", DML_OPERATOR_ELEMENT_WISE_COS}, + {"DML_OPERATOR_ELEMENT_WISE_DIVIDE", DML_OPERATOR_ELEMENT_WISE_DIVIDE}, + {"DML_OPERATOR_ELEMENT_WISE_EXP", DML_OPERATOR_ELEMENT_WISE_EXP}, + {"DML_OPERATOR_ELEMENT_WISE_FLOOR", DML_OPERATOR_ELEMENT_WISE_FLOOR}, + {"DML_OPERATOR_ELEMENT_WISE_LOG", DML_OPERATOR_ELEMENT_WISE_LOG}, + {"DML_OPERATOR_ELEMENT_WISE_LOGICAL_AND", DML_OPERATOR_ELEMENT_WISE_LOGICAL_AND}, + {"DML_OPERATOR_ELEMENT_WISE_LOGICAL_EQUALS", DML_OPERATOR_ELEMENT_WISE_LOGICAL_EQUALS}, + {"DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN", DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN}, + {"DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN", DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN}, + {"DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL", DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL}, + {"DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL", DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL}, + {"DML_OPERATOR_ELEMENT_WISE_LOGICAL_NOT", DML_OPERATOR_ELEMENT_WISE_LOGICAL_NOT}, + {"DML_OPERATOR_ELEMENT_WISE_LOGICAL_OR", DML_OPERATOR_ELEMENT_WISE_LOGICAL_OR}, + {"DML_OPERATOR_ELEMENT_WISE_LOGICAL_XOR", DML_OPERATOR_ELEMENT_WISE_LOGICAL_XOR}, + {"DML_OPERATOR_ELEMENT_WISE_MAX", DML_OPERATOR_ELEMENT_WISE_MAX}, + {"DML_OPERATOR_ELEMENT_WISE_MEAN", DML_OPERATOR_ELEMENT_WISE_MEAN}, + {"DML_OPERATOR_ELEMENT_WISE_MIN", DML_OPERATOR_ELEMENT_WISE_MIN}, + {"DML_OPERATOR_ELEMENT_WISE_MULTIPLY", DML_OPERATOR_ELEMENT_WISE_MULTIPLY}, + {"DML_OPERATOR_ELEMENT_WISE_POW", DML_OPERATOR_ELEMENT_WISE_POW}, + {"DML_OPERATOR_ELEMENT_WISE_CONSTANT_POW", DML_OPERATOR_ELEMENT_WISE_CONSTANT_POW}, + 
{"DML_OPERATOR_ELEMENT_WISE_RECIP", DML_OPERATOR_ELEMENT_WISE_RECIP}, + {"DML_OPERATOR_ELEMENT_WISE_SIN", DML_OPERATOR_ELEMENT_WISE_SIN}, + {"DML_OPERATOR_ELEMENT_WISE_SQRT", DML_OPERATOR_ELEMENT_WISE_SQRT}, + {"DML_OPERATOR_ELEMENT_WISE_SUBTRACT", DML_OPERATOR_ELEMENT_WISE_SUBTRACT}, + {"DML_OPERATOR_ELEMENT_WISE_TAN", DML_OPERATOR_ELEMENT_WISE_TAN}, + {"DML_OPERATOR_ELEMENT_WISE_THRESHOLD", DML_OPERATOR_ELEMENT_WISE_THRESHOLD}, + {"DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR", DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR}, + {"DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR", DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR}, + {"DML_OPERATOR_ACTIVATION_ELU", DML_OPERATOR_ACTIVATION_ELU}, + {"DML_OPERATOR_ACTIVATION_CELU", DML_OPERATOR_ACTIVATION_CELU}, + {"DML_OPERATOR_ACTIVATION_HARDMAX", DML_OPERATOR_ACTIVATION_HARDMAX}, + {"DML_OPERATOR_ACTIVATION_HARDMAX1", DML_OPERATOR_ACTIVATION_HARDMAX1}, + {"DML_OPERATOR_ACTIVATION_HARD_SIGMOID", DML_OPERATOR_ACTIVATION_HARD_SIGMOID}, + {"DML_OPERATOR_ACTIVATION_IDENTITY", DML_OPERATOR_ACTIVATION_IDENTITY}, + {"DML_OPERATOR_ACTIVATION_LEAKY_RELU", DML_OPERATOR_ACTIVATION_LEAKY_RELU}, + {"DML_OPERATOR_ACTIVATION_LINEAR", DML_OPERATOR_ACTIVATION_LINEAR}, + {"DML_OPERATOR_ACTIVATION_LOG_SOFTMAX", DML_OPERATOR_ACTIVATION_LOG_SOFTMAX}, + {"DML_OPERATOR_ACTIVATION_LOG_SOFTMAX1", DML_OPERATOR_ACTIVATION_LOG_SOFTMAX1}, + {"DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU", DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU}, + {"DML_OPERATOR_ACTIVATION_PARAMETRIC_SOFTPLUS", DML_OPERATOR_ACTIVATION_PARAMETRIC_SOFTPLUS}, + {"DML_OPERATOR_ACTIVATION_RELU", DML_OPERATOR_ACTIVATION_RELU}, + {"DML_OPERATOR_ACTIVATION_SCALED_ELU", DML_OPERATOR_ACTIVATION_SCALED_ELU}, + {"DML_OPERATOR_ACTIVATION_SCALED_TANH", DML_OPERATOR_ACTIVATION_SCALED_TANH}, + {"DML_OPERATOR_ACTIVATION_SIGMOID", DML_OPERATOR_ACTIVATION_SIGMOID}, + {"DML_OPERATOR_ACTIVATION_SOFTMAX", DML_OPERATOR_ACTIVATION_SOFTMAX}, + {"DML_OPERATOR_ACTIVATION_SOFTMAX1", DML_OPERATOR_ACTIVATION_SOFTMAX1}, + 
{"DML_OPERATOR_ACTIVATION_SOFTPLUS", DML_OPERATOR_ACTIVATION_SOFTPLUS}, + {"DML_OPERATOR_ACTIVATION_SOFTSIGN", DML_OPERATOR_ACTIVATION_SOFTSIGN}, + {"DML_OPERATOR_ACTIVATION_TANH", DML_OPERATOR_ACTIVATION_TANH}, + {"DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU", DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU}, + {"DML_OPERATOR_CONVOLUTION", DML_OPERATOR_CONVOLUTION}, + {"DML_OPERATOR_GEMM", DML_OPERATOR_GEMM}, + {"DML_OPERATOR_REDUCE", DML_OPERATOR_REDUCE}, + {"DML_OPERATOR_AVERAGE_POOLING", DML_OPERATOR_AVERAGE_POOLING}, + {"DML_OPERATOR_AVERAGE_POOLING1", DML_OPERATOR_AVERAGE_POOLING1}, + {"DML_OPERATOR_LP_POOLING", DML_OPERATOR_LP_POOLING}, + {"DML_OPERATOR_LP_POOLING1", DML_OPERATOR_LP_POOLING1}, + {"DML_OPERATOR_MAX_POOLING", DML_OPERATOR_MAX_POOLING}, + {"DML_OPERATOR_ROI_POOLING", DML_OPERATOR_ROI_POOLING}, + {"DML_OPERATOR_SLICE", DML_OPERATOR_SLICE}, + {"DML_OPERATOR_CAST", DML_OPERATOR_CAST}, + {"DML_OPERATOR_SPLIT", DML_OPERATOR_SPLIT}, + {"DML_OPERATOR_JOIN", DML_OPERATOR_JOIN}, + {"DML_OPERATOR_PADDING", DML_OPERATOR_PADDING}, + {"DML_OPERATOR_PADDING1", DML_OPERATOR_PADDING1}, + {"DML_OPERATOR_VALUE_SCALE_2D", DML_OPERATOR_VALUE_SCALE_2D}, + {"DML_OPERATOR_UPSAMPLE_2D", DML_OPERATOR_UPSAMPLE_2D}, + {"DML_OPERATOR_GATHER", DML_OPERATOR_GATHER}, + {"DML_OPERATOR_SPACE_TO_DEPTH", DML_OPERATOR_SPACE_TO_DEPTH}, + {"DML_OPERATOR_DEPTH_TO_SPACE", DML_OPERATOR_DEPTH_TO_SPACE}, + {"DML_OPERATOR_TILE", DML_OPERATOR_TILE}, + {"DML_OPERATOR_TOP_K", DML_OPERATOR_TOP_K}, + {"DML_OPERATOR_BATCH_NORMALIZATION", DML_OPERATOR_BATCH_NORMALIZATION}, + {"DML_OPERATOR_BATCH_NORMALIZATION_TRAINING", DML_OPERATOR_BATCH_NORMALIZATION_TRAINING}, + {"DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION", DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION}, + {"DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION", DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION}, + {"DML_OPERATOR_LP_NORMALIZATION", DML_OPERATOR_LP_NORMALIZATION}, + {"DML_OPERATOR_RNN", DML_OPERATOR_RNN}, + {"DML_OPERATOR_LSTM", DML_OPERATOR_LSTM}, + 
{"DML_OPERATOR_GRU", DML_OPERATOR_GRU}, + {"DML_OPERATOR_ELEMENT_WISE_SIGN", DML_OPERATOR_ELEMENT_WISE_SIGN}, + {"DML_OPERATOR_ELEMENT_WISE_IS_NAN", DML_OPERATOR_ELEMENT_WISE_IS_NAN}, + {"DML_OPERATOR_ELEMENT_WISE_ERF", DML_OPERATOR_ELEMENT_WISE_ERF}, + {"DML_OPERATOR_ELEMENT_WISE_SINH", DML_OPERATOR_ELEMENT_WISE_SINH}, + {"DML_OPERATOR_ELEMENT_WISE_COSH", DML_OPERATOR_ELEMENT_WISE_COSH}, + {"DML_OPERATOR_ELEMENT_WISE_TANH", DML_OPERATOR_ELEMENT_WISE_TANH}, + {"DML_OPERATOR_ELEMENT_WISE_ASINH", DML_OPERATOR_ELEMENT_WISE_ASINH}, + {"DML_OPERATOR_ELEMENT_WISE_ACOSH", DML_OPERATOR_ELEMENT_WISE_ACOSH}, + {"DML_OPERATOR_ELEMENT_WISE_ATANH", DML_OPERATOR_ELEMENT_WISE_ATANH}, + {"DML_OPERATOR_ELEMENT_WISE_IF", DML_OPERATOR_ELEMENT_WISE_IF}, + {"DML_OPERATOR_ELEMENT_WISE_ADD1", DML_OPERATOR_ELEMENT_WISE_ADD1}, + {"DML_OPERATOR_ACTIVATION_SHRINK", DML_OPERATOR_ACTIVATION_SHRINK}, + {"DML_OPERATOR_MAX_POOLING1", DML_OPERATOR_MAX_POOLING1}, + {"DML_OPERATOR_MAX_UNPOOLING", DML_OPERATOR_MAX_UNPOOLING}, + {"DML_OPERATOR_DIAGONAL_MATRIX", DML_OPERATOR_DIAGONAL_MATRIX}, + {"DML_OPERATOR_SCATTER", DML_OPERATOR_SCATTER}, + {"DML_OPERATOR_ONE_HOT", DML_OPERATOR_ONE_HOT}, + {"DML_OPERATOR_RESAMPLE", DML_OPERATOR_RESAMPLE}, + {"DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_LEFT", DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_LEFT}, + {"DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_RIGHT", DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_RIGHT}, + {"DML_OPERATOR_ELEMENT_WISE_ROUND", DML_OPERATOR_ELEMENT_WISE_ROUND}, + {"DML_OPERATOR_ELEMENT_WISE_IS_INFINITY", DML_OPERATOR_ELEMENT_WISE_IS_INFINITY}, + {"DML_OPERATOR_ELEMENT_WISE_MODULUS_TRUNCATE", DML_OPERATOR_ELEMENT_WISE_MODULUS_TRUNCATE}, + {"DML_OPERATOR_ELEMENT_WISE_MODULUS_FLOOR", DML_OPERATOR_ELEMENT_WISE_MODULUS_FLOOR}, + {"DML_OPERATOR_FILL_VALUE_SEQUENCE", DML_OPERATOR_FILL_VALUE_SEQUENCE}, + {"DML_OPERATOR_FILL_VALUE_CONSTANT", DML_OPERATOR_FILL_VALUE_CONSTANT}, + {"DML_OPERATOR_CUMULATIVE_SUMMATION", DML_OPERATOR_CUMULATIVE_SUMMATION}, + 
{"DML_OPERATOR_REVERSE_SUBSEQUENCES", DML_OPERATOR_REVERSE_SUBSEQUENCES}, + {"DML_OPERATOR_GATHER_ELEMENTS", DML_OPERATOR_GATHER_ELEMENTS}, + {"DML_OPERATOR_GATHER_ND", DML_OPERATOR_GATHER_ND}, + {"DML_OPERATOR_SCATTER_ND", DML_OPERATOR_SCATTER_ND}, + {"DML_OPERATOR_MAX_POOLING2", DML_OPERATOR_MAX_POOLING2}, + {"DML_OPERATOR_SLICE1", DML_OPERATOR_SLICE1}, + {"DML_OPERATOR_TOP_K1", DML_OPERATOR_TOP_K1}, + {"DML_OPERATOR_DEPTH_TO_SPACE1", DML_OPERATOR_DEPTH_TO_SPACE1}, + {"DML_OPERATOR_SPACE_TO_DEPTH1", DML_OPERATOR_SPACE_TO_DEPTH1}, + {"DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1", DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1}, + {"DML_OPERATOR_RESAMPLE1", DML_OPERATOR_RESAMPLE1}, + {"DML_OPERATOR_MATRIX_MULTIPLY_INTEGER", DML_OPERATOR_MATRIX_MULTIPLY_INTEGER}, + {"DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY", DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY}, + {"DML_OPERATOR_CONVOLUTION_INTEGER", DML_OPERATOR_CONVOLUTION_INTEGER}, + {"DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION", DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION}, + {"DML_OPERATOR_ELEMENT_WISE_BIT_AND", DML_OPERATOR_ELEMENT_WISE_BIT_AND}, + {"DML_OPERATOR_ELEMENT_WISE_BIT_OR", DML_OPERATOR_ELEMENT_WISE_BIT_OR}, + {"DML_OPERATOR_ELEMENT_WISE_BIT_XOR", DML_OPERATOR_ELEMENT_WISE_BIT_XOR}, + {"DML_OPERATOR_ELEMENT_WISE_BIT_NOT", DML_OPERATOR_ELEMENT_WISE_BIT_NOT}, + {"DML_OPERATOR_ELEMENT_WISE_BIT_COUNT", DML_OPERATOR_ELEMENT_WISE_BIT_COUNT}, + {"DML_OPERATOR_ACTIVATION_RELU_GRAD", DML_OPERATOR_ACTIVATION_RELU_GRAD}, + {"DML_OPERATOR_AVERAGE_POOLING_GRAD", DML_OPERATOR_AVERAGE_POOLING_GRAD}, + {"DML_OPERATOR_MAX_POOLING_GRAD", DML_OPERATOR_MAX_POOLING_GRAD}, + {"DML_OPERATOR_RANDOM_GENERATOR", DML_OPERATOR_RANDOM_GENERATOR}, + {"DML_OPERATOR_NONZERO_COORDINATES", DML_OPERATOR_NONZERO_COORDINATES}, + {"DML_OPERATOR_RESAMPLE_GRAD", DML_OPERATOR_RESAMPLE_GRAD}, + {"DML_OPERATOR_SLICE_GRAD", DML_OPERATOR_SLICE_GRAD}, + {"DML_OPERATOR_ADAM_OPTIMIZER", DML_OPERATOR_ADAM_OPTIMIZER}, + {"DML_OPERATOR_ARGMIN", 
DML_OPERATOR_ARGMIN}, + {"DML_OPERATOR_ARGMAX", DML_OPERATOR_ARGMAX}, + {"DML_OPERATOR_ROI_ALIGN", DML_OPERATOR_ROI_ALIGN}, + {"DML_OPERATOR_GATHER_ND1", DML_OPERATOR_GATHER_ND1}, + {"DML_OPERATOR_ELEMENT_WISE_ATAN_YX", DML_OPERATOR_ELEMENT_WISE_ATAN_YX}, + {"DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD", DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD}, + {"DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE", DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE}, + {"DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD", DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD}, + {"DML_OPERATOR_CUMULATIVE_PRODUCT", DML_OPERATOR_CUMULATIVE_PRODUCT}, + {"DML_OPERATOR_BATCH_NORMALIZATION_GRAD", DML_OPERATOR_BATCH_NORMALIZATION_GRAD}, + {"DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD", DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD}, + {"DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD", DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD}, + {"DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR", DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR}, + {"DML_OPERATOR_ROI_ALIGN1", DML_OPERATOR_ROI_ALIGN1}, + {"DML_OPERATOR_ELEMENT_WISE_CLIP1", DML_OPERATOR_ELEMENT_WISE_CLIP1}, + {"DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1", DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1}, + {"DML_OPERATOR_ELEMENT_WISE_NEGATE", DML_OPERATOR_ELEMENT_WISE_NEGATE}, + {"DML_OPERATOR_ACTIVATION_GELU", DML_OPERATOR_ACTIVATION_GELU}, + {"DML_OPERATOR_ACTIVATION_SWISH", DML_OPERATOR_ACTIVATION_SWISH}, + {"DML_OPERATOR_ACTIVATION_HARD_SWISH", DML_OPERATOR_ACTIVATION_HARD_SWISH}, + {"DML_OPERATOR_RESAMPLE2", DML_OPERATOR_RESAMPLE2}, + {"DML_OPERATOR_RESAMPLE_GRAD1", DML_OPERATOR_RESAMPLE_GRAD1}, + {"DML_OPERATOR_DIAGONAL_MATRIX1", DML_OPERATOR_DIAGONAL_MATRIX1}, + {"DML_OPERATOR_MULTIHEAD_ATTENTION", DML_OPERATOR_MULTIHEAD_ATTENTION}, + {"DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING", DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING}, + {"DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT", DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if 
(!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_BINDING_TYPE ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_BINDING_TYPE_NONE", DML_BINDING_TYPE_NONE}, + {"DML_BINDING_TYPE_BUFFER", DML_BINDING_TYPE_BUFFER}, + {"DML_BINDING_TYPE_BUFFER_ARRAY", DML_BINDING_TYPE_BUFFER_ARRAY}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_REDUCE_FUNCTION ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_REDUCE_FUNCTION_ARGMAX", DML_REDUCE_FUNCTION_ARGMAX}, + {"DML_REDUCE_FUNCTION_ARGMIN", DML_REDUCE_FUNCTION_ARGMIN}, + {"DML_REDUCE_FUNCTION_AVERAGE", DML_REDUCE_FUNCTION_AVERAGE}, + {"DML_REDUCE_FUNCTION_L1", DML_REDUCE_FUNCTION_L1}, + {"DML_REDUCE_FUNCTION_L2", DML_REDUCE_FUNCTION_L2}, + {"DML_REDUCE_FUNCTION_LOG_SUM", DML_REDUCE_FUNCTION_LOG_SUM}, + {"DML_REDUCE_FUNCTION_LOG_SUM_EXP", DML_REDUCE_FUNCTION_LOG_SUM_EXP}, + {"DML_REDUCE_FUNCTION_MAX", DML_REDUCE_FUNCTION_MAX}, + {"DML_REDUCE_FUNCTION_MIN", DML_REDUCE_FUNCTION_MIN}, + {"DML_REDUCE_FUNCTION_MULTIPLY", DML_REDUCE_FUNCTION_MULTIPLY}, + {"DML_REDUCE_FUNCTION_SUM", DML_REDUCE_FUNCTION_SUM}, + {"DML_REDUCE_FUNCTION_SUM_SQUARE", DML_REDUCE_FUNCTION_SUM_SQUARE}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + +template <> +DML_MATRIX_TRANSFORM ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_MATRIX_TRANSFORM_NONE", DML_MATRIX_TRANSFORM_NONE}, + {"DML_MATRIX_TRANSFORM_TRANSPOSE", DML_MATRIX_TRANSFORM_TRANSPOSE}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + 
assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_CONVOLUTION_MODE ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_CONVOLUTION_MODE_CONVOLUTION", DML_CONVOLUTION_MODE_CONVOLUTION}, + {"DML_CONVOLUTION_MODE_CROSS_CORRELATION", DML_CONVOLUTION_MODE_CROSS_CORRELATION}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_CONVOLUTION_DIRECTION ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_CONVOLUTION_DIRECTION_FORWARD", DML_CONVOLUTION_DIRECTION_FORWARD}, + {"DML_CONVOLUTION_DIRECTION_BACKWARD", DML_CONVOLUTION_DIRECTION_BACKWARD}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + +template <> +DML_PADDING_MODE ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_PADDING_MODE_CONSTANT", DML_PADDING_MODE_CONSTANT}, + {"DML_PADDING_MODE_EDGE", DML_PADDING_MODE_EDGE}, + {"DML_PADDING_MODE_REFLECTION", DML_PADDING_MODE_REFLECTION}, + {"DML_PADDING_MODE_SYMMETRIC", DML_PADDING_MODE_SYMMETRIC}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_INTERPOLATION_MODE ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_INTERPOLATION_MODE_NEAREST_NEIGHBOR", DML_INTERPOLATION_MODE_NEAREST_NEIGHBOR}, + {"DML_INTERPOLATION_MODE_LINEAR", DML_INTERPOLATION_MODE_LINEAR}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return 
static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_RECURRENT_NETWORK_DIRECTION ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_RECURRENT_NETWORK_DIRECTION_FORWARD", DML_RECURRENT_NETWORK_DIRECTION_FORWARD}, + {"DML_RECURRENT_NETWORK_DIRECTION_BACKWARD", DML_RECURRENT_NETWORK_DIRECTION_BACKWARD}, + {"DML_RECURRENT_NETWORK_DIRECTION_BIDIRECTIONAL", DML_RECURRENT_NETWORK_DIRECTION_BIDIRECTIONAL}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_FEATURE ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_FEATURE_TENSOR_DATA_TYPE_SUPPORT", DML_FEATURE_TENSOR_DATA_TYPE_SUPPORT}, + {"DML_FEATURE_FEATURE_LEVELS", DML_FEATURE_FEATURE_LEVELS}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_FEATURE_LEVEL ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_FEATURE_LEVEL_1_0", DML_FEATURE_LEVEL_1_0}, + {"DML_FEATURE_LEVEL_2_0", DML_FEATURE_LEVEL_2_0}, + {"DML_FEATURE_LEVEL_2_1", DML_FEATURE_LEVEL_2_1}, + {"DML_FEATURE_LEVEL_3_0", DML_FEATURE_LEVEL_3_0}, + {"DML_FEATURE_LEVEL_3_1", DML_FEATURE_LEVEL_3_1}, + {"DML_FEATURE_LEVEL_4_0", DML_FEATURE_LEVEL_4_0}, + {"DML_FEATURE_LEVEL_4_1", DML_FEATURE_LEVEL_4_1}, + {"DML_FEATURE_LEVEL_5_0", DML_FEATURE_LEVEL_5_0}, + {"DML_FEATURE_LEVEL_5_1", DML_FEATURE_LEVEL_5_1}, + {"DML_FEATURE_LEVEL_5_2", DML_FEATURE_LEVEL_5_2}, + {"DML_FEATURE_LEVEL_6_0", DML_FEATURE_LEVEL_6_0}, + {"DML_FEATURE_LEVEL_6_1", DML_FEATURE_LEVEL_6_1}, + {"DML_FEATURE_LEVEL_6_2", DML_FEATURE_LEVEL_6_2}, + }; + auto index = StringUtil::MapToIndex(value, 
mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_IS_INFINITY_MODE ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_IS_INFINITY_MODE_EITHER", DML_IS_INFINITY_MODE_EITHER}, + {"DML_IS_INFINITY_MODE_POSITIVE", DML_IS_INFINITY_MODE_POSITIVE}, + {"DML_IS_INFINITY_MODE_NEGATIVE", DML_IS_INFINITY_MODE_NEGATIVE}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_DEPTH_SPACE_ORDER ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_DEPTH_SPACE_ORDER_DEPTH_COLUMN_ROW", DML_DEPTH_SPACE_ORDER_DEPTH_COLUMN_ROW}, + {"DML_DEPTH_SPACE_ORDER_COLUMN_ROW_DEPTH", DML_DEPTH_SPACE_ORDER_COLUMN_ROW_DEPTH}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_AXIS_DIRECTION ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_AXIS_DIRECTION_INCREASING", DML_AXIS_DIRECTION_INCREASING}, + {"DML_AXIS_DIRECTION_DECREASING", DML_AXIS_DIRECTION_DECREASING}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_ROUNDING_MODE ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_ROUNDING_MODE_HALVES_TO_NEAREST_EVEN", DML_ROUNDING_MODE_HALVES_TO_NEAREST_EVEN}, + {"DML_ROUNDING_MODE_TOWARD_ZERO", DML_ROUNDING_MODE_TOWARD_ZERO}, + {"DML_ROUNDING_MODE_TOWARD_INFINITY", DML_ROUNDING_MODE_TOWARD_INFINITY}, + }; + auto index = 
StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_RANDOM_GENERATOR_TYPE ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_RANDOM_GENERATOR_TYPE_PHILOX_4X32_10", DML_RANDOM_GENERATOR_TYPE_PHILOX_4X32_10}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_MULTIHEAD_ATTENTION_MASK_TYPE ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_MULTIHEAD_ATTENTION_MASK_TYPE_NONE", DML_MULTIHEAD_ATTENTION_MASK_TYPE_NONE}, + {"DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_LENGTH", DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_LENGTH}, + {"DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_END_START", DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_END_START}, + {"DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_QUERY_SEQUENCE_LENGTH_START_END", DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_QUERY_SEQUENCE_LENGTH_START_END}, + {"DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN", DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphDeserialization.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphDeserialization.cpp new file mode 100644 index 0000000000000..7d8ed17e7d925 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphDeserialization.cpp @@ -0,0 +1,554 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. 
+ +#pragma once +#include "precomp.h" + +OperatorFieldVariant CreateAttribute( + const DML_SCHEMA_FIELD* schemaField, + const dml::ir::operatorFieldTypes::AttributeDesc* attributeDesc); + +OperatorFieldVariant CreateActivation( + const dml::ir::operatorFieldTypes::Activation* activationDesc) +{ + DML_OPERATOR_TYPE activationOperatorType = ApiTraits::StringifyHelpers::FromString(activationDesc->type()->c_str()); + const DML_OPERATOR_SCHEMA& activationSchema = SchemaHelpers::GetSchema(activationOperatorType); + std::vector activationOperatorFields(activationSchema.FieldCount); + uint32_t attributeIndex = 0; + + for (uint32_t fieldIndex = 0; fieldIndex < activationSchema.FieldCount; fieldIndex++) + { + const DML_SCHEMA_FIELD* schemaField = &activationSchema.Fields[fieldIndex]; + OperatorFieldVariant field; + switch (schemaField->Kind) + { + case DML_SCHEMA_FIELD_KIND_INPUT_TENSOR: + case DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR: + { + if (schemaField->Type == DML_SCHEMA_FIELD_TYPE_TENSOR_DESC) + { + field = OperatorFieldTypes::TensorDesc(); + } + else if (schemaField->Type == DML_SCHEMA_FIELD_TYPE_TENSOR_DESC_ARRAY) + { + field = OperatorFieldTypes::TensorDescArray(); + } + break; + } + case DML_SCHEMA_FIELD_KIND_ATTRIBUTE: + { + const dml::ir::operatorFieldTypes::AttributeDesc* attributeDesc = + attributeIndex >= activationDesc->attributes()->size() ? 
+ nullptr : + activationDesc->attributes()->Get(attributeIndex++); + field = CreateAttribute(schemaField, attributeDesc); + break; + } + } + + activationOperatorFields[fieldIndex] = OperatorField(schemaField, std::move(field)); + } + + return AbstractOperatorDesc(&activationSchema, std::move(activationOperatorFields)); +} + +OperatorFieldVariant CreateActivations( + const dml::ir::operatorFieldTypes::ActivationArray* activationDescs) +{ + std::vector activations; + for (uint32_t index = 0; index < static_cast(activationDescs->data()->size()); index++) + { + OperatorFieldVariant activation = CreateActivation(activationDescs->data()->Get(index)); + activations.push_back(std::get(activation).value()); + } + return activations; +} + +OperatorFieldVariant CreateAttribute( + const DML_SCHEMA_FIELD* schemaField, + const dml::ir::operatorFieldTypes::AttributeDesc* attributeDesc) +{ + switch (schemaField->Type) + { + case DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC: + { + return attributeDesc != nullptr && attributeDesc->val_as_Activation() != nullptr ? + CreateActivation(attributeDesc->val_as_Activation()) : + OperatorFieldTypes::FusedActivationOperatorDesc(); + } + case DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC_ARRAY: + { + return attributeDesc != nullptr && attributeDesc->val_as_ActivationArray() != nullptr ? 
+ CreateActivations(attributeDesc->val_as_ActivationArray()) : + OperatorFieldTypes::FusedActivationOperatorDescArray(); + } + case DML_SCHEMA_FIELD_TYPE_UINT: + { + OperatorFieldTypes::UInt data; + if (attributeDesc != nullptr) + { + data = attributeDesc->val_as_UInt32()->data(); + } + return data; + } + case DML_SCHEMA_FIELD_TYPE_UINT64: + { + OperatorFieldTypes::UInt64 data; + if (attributeDesc != nullptr) + { + data = attributeDesc->val_as_UInt64()->data(); + } + return data; + } + case DML_SCHEMA_FIELD_TYPE_INT: + { + OperatorFieldTypes::Int data; + if (attributeDesc != nullptr) + { + data = attributeDesc->val_as_Int32()->data(); + } + return data; + } + case DML_SCHEMA_FIELD_TYPE_FLOAT: + { + OperatorFieldTypes::Float data; + if (attributeDesc != nullptr) + { + data = attributeDesc->val_as_Float32()->data(); + } + return data; + } + case DML_SCHEMA_FIELD_TYPE_UINT_ARRAY: + { + OperatorFieldTypes::UIntArray data; + if (attributeDesc != nullptr) + { + data.assign(attributeDesc->val_as_UIntArray()->data()->begin(), attributeDesc->val_as_UIntArray()->data()->end()); + } + return data; + } + case DML_SCHEMA_FIELD_TYPE_INT_ARRAY: + { + OperatorFieldTypes::IntArray data; + if (attributeDesc != nullptr) + { + data.assign(attributeDesc->val_as_IntArray()->data()->begin(), attributeDesc->val_as_IntArray()->data()->end()); + } + return data; + } + case DML_SCHEMA_FIELD_TYPE_FLOAT_ARRAY: + { + OperatorFieldTypes::FloatArray data; + if (attributeDesc != nullptr) + { + data.assign(attributeDesc->val_as_FloatArray()->data()->begin(), attributeDesc->val_as_FloatArray()->data()->end()); + } + return data; + } + case DML_SCHEMA_FIELD_TYPE_SCALE_BIAS: + { + OperatorFieldTypes::ScaleBias scaleBias; + const dml::ir::operatorFieldTypes::ScaleBias* scaleBiasAttribute = attributeDesc->val_as_ScaleBias(); + if (scaleBiasAttribute != nullptr) + { + scaleBias = {scaleBiasAttribute->scale(), scaleBiasAttribute->bias()}; + } + return scaleBias; + } + case DML_SCHEMA_FIELD_TYPE_SIZE_2D: + 
{ + OperatorFieldTypes::Size2D size2d = {}; + if (attributeDesc != nullptr) + { + size2d.Height = attributeDesc->val_as_Size2D()->height(); + size2d.Width = attributeDesc->val_as_Size2D()->width(); + } + return size2d; + } + case DML_SCHEMA_FIELD_TYPE_SCALAR_UNION: + { + DML_SCALAR_UNION scalarUnion; + if (attributeDesc != nullptr) + { + const dml::ir::operatorFieldTypes::ByteArray* byteArr = attributeDesc->val_as_ScalarUnionData()->data_as_ByteArray(); + std::copy(byteArr->data()->begin(), byteArr->data()->end(), scalarUnion.Bytes); + } + return scalarUnion; + } + case DML_SCHEMA_FIELD_TYPE_BOOL: + { + OperatorFieldTypes::Bool data; + if (attributeDesc != nullptr) + { + data = attributeDesc->val_as_Bool()->data(); + } + return data; + } + default: + { + throw std::invalid_argument("Invalid attribute type."); + } + } +} + +OperatorFieldTypes::TensorDesc CreateBufferTensorDesc( + const dml::ir::DmlBufferTensorDesc* tensorDesc, + const bool isConstantTensor = false) +{ + DmlBufferTensorDesc bufferTensorDesc = {}; + bufferTensorDesc.dataType = ApiTraits::StringifyHelpers::FromString(tensorDesc->dataType()->c_str()); + if (isConstantTensor) + { + bufferTensorDesc.flags = DML_TENSOR_FLAG_OWNED_BY_DML; + } + bufferTensorDesc.sizes.assign(tensorDesc->sizes()->begin(), tensorDesc->sizes()->end()); + if (flatbuffers::IsFieldPresent(tensorDesc, dml::ir::DmlBufferTensorDesc::VT_STRIDES)) + { + bufferTensorDesc.strides.emplace(tensorDesc->strides()->begin(), tensorDesc->strides()->end()); + } + bufferTensorDesc.totalTensorSizeInBytes = tensorDesc->totalTensorSizeInBytes(); + return bufferTensorDesc; +} + +AbstractOperatorDesc CreateAbstractOperatorDesc( + uint32_t nodeIndex, + const dml::ir::OperatorNodeDesc* flatbufferOperatorNodeDesc, + const ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>>* nodeInputNames, + const ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>>* nodeOutputNames, + const std::unordered_set& constantInputs) +{ + 
DML_OPERATOR_TYPE type = ApiTraits::StringifyHelpers::FromString(flatbufferOperatorNodeDesc->type()->c_str()); + if (type == DML_OPERATOR_INVALID) + { + throw std::invalid_argument("Graph operator node at index:" + std::to_string(nodeIndex) + + " either has empty or invalid operator type."); + } + const DML_OPERATOR_SCHEMA& schema = SchemaHelpers::GetSchema(type); + std::vector operatorFields(schema.FieldCount); + + auto inputNameItr = nodeInputNames->begin(); + uint32_t inputTensorDescIndex = 0; + + uint32_t outputTensorDescIndex = 0; + auto outputNameItr = nodeOutputNames->begin(); + + uint32_t attributeIndex = 0; + + + for (uint32_t fieldIndex = 0; fieldIndex < schema.FieldCount; fieldIndex++) + { + const DML_SCHEMA_FIELD* schemaField = &schema.Fields[fieldIndex]; + + OperatorFieldVariant field; + switch (schemaField->Kind) + { + case DML_SCHEMA_FIELD_KIND_INPUT_TENSOR: + { + if (inputNameItr == nodeInputNames->end()) + { + throw std::invalid_argument("Missing input names for node at index:" + std::to_string(nodeIndex)); + } + + if (schemaField->Type == DML_SCHEMA_FIELD_TYPE_TENSOR_DESC) + { + const flatbuffers::String* inputName = *inputNameItr; + inputNameItr++; + if (inputName->size() == 0) + { + field = OperatorFieldTypes::TensorDesc(); + break; + } + bool isConstantTensor = !constantInputs.empty() && constantInputs.find(inputName->c_str()) != constantInputs.end(); + + if (flatbufferOperatorNodeDesc->inputs()->size() <= inputTensorDescIndex) + { + throw std::invalid_argument("Expecting at least " + std::to_string(inputTensorDescIndex + 1) + + "input tensor desc for graph operator node at index:" + std::to_string(nodeIndex)); + } + const dml::ir::DmlBufferTensorDesc* tensorDesc = flatbufferOperatorNodeDesc->inputs()->Get(inputTensorDescIndex++); + field = CreateBufferTensorDesc(tensorDesc, isConstantTensor); + } + else if (schemaField->Type == DML_SCHEMA_FIELD_TYPE_TENSOR_DESC_ARRAY) + { + std::vector tensors; + while (inputTensorDescIndex < 
static_cast(flatbufferOperatorNodeDesc->inputs()->size())) + { + const flatbuffers::String* inputName = *inputNameItr; + inputNameItr++; + bool isConstantTensor = !constantInputs.empty() && constantInputs.find(inputName->c_str()) != constantInputs.end(); + + if (flatbufferOperatorNodeDesc->inputs()->size() <= inputTensorDescIndex) + { + throw std::invalid_argument("Expecting at least " + std::to_string(inputTensorDescIndex + 1) + + "input tensor desc for graph operator node at index:" + std::to_string(nodeIndex)); + } + const dml::ir::DmlBufferTensorDesc* tensorDesc = flatbufferOperatorNodeDesc->inputs()->Get(inputTensorDescIndex++); + tensors.push_back(CreateBufferTensorDesc(tensorDesc, isConstantTensor).value()); + } + field = tensors; + } + break; + } + case DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR: + { + if (outputNameItr == nodeOutputNames->end()) + { + throw std::invalid_argument("Missing output names for node at index:" + std::to_string(nodeIndex)); + } + + if (schemaField->Type == DML_SCHEMA_FIELD_TYPE_TENSOR_DESC) + { + const flatbuffers::String* outputName = *outputNameItr; + outputNameItr++; + + if (outputName->size() == 0) + { + field = OperatorFieldTypes::TensorDesc(); + break; + } + + if (flatbufferOperatorNodeDesc->outputs()->size() <= outputTensorDescIndex) + { + throw std::invalid_argument("Expecting at least " + std::to_string(outputTensorDescIndex + 1) + + "output tensor desc for graph operator node at index:" + std::to_string(nodeIndex)); + } + const dml::ir::DmlBufferTensorDesc* tensorDesc = flatbufferOperatorNodeDesc->outputs()->Get(outputTensorDescIndex++); + field = CreateBufferTensorDesc(tensorDesc); + } + else if (schemaField->Type == DML_SCHEMA_FIELD_TYPE_TENSOR_DESC_ARRAY) + { + std::vector tensors; + while (outputTensorDescIndex < static_cast(flatbufferOperatorNodeDesc->outputs()->size())) + { + if (flatbufferOperatorNodeDesc->outputs()->size() <= outputTensorDescIndex) + { + throw std::invalid_argument("Expecting at least " + 
std::to_string(outputTensorDescIndex + 1) + + "output tensor desc for graph operator node at index:" + std::to_string(nodeIndex)); + } + const dml::ir::DmlBufferTensorDesc* tensorDesc = flatbufferOperatorNodeDesc->outputs()->Get(outputTensorDescIndex++); + tensors.push_back(CreateBufferTensorDesc(tensorDesc).value()); + } + field = tensors; + } + break; + } + case DML_SCHEMA_FIELD_KIND_ATTRIBUTE: + { + if (flatbufferOperatorNodeDesc->attributes()->size() <= attributeIndex) + { + throw std::invalid_argument("Expecting at least " + std::to_string(attributeIndex + 1) + + "attributes for graph operator node at index:" + std::to_string(nodeIndex)); + } + const dml::ir::operatorFieldTypes::AttributeDesc* attributeDesc = + attributeIndex >= flatbufferOperatorNodeDesc->attributes()->size() ? + nullptr : + flatbufferOperatorNodeDesc->attributes()->Get(attributeIndex++); + field = CreateAttribute(schemaField, attributeDesc); + break; + } + } + + operatorFields[fieldIndex] = OperatorField(schemaField, std::move(field)); + } + + return AbstractOperatorDesc(&schema, std::move(operatorFields)); +} + +std::unordered_map ConvertToEdgeNameToIndexMap( + const ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>>* list) +{ + std::unordered_map nameToIndexMap; + for (uint32_t index = 0; index < list->size(); index++) + { + const flatbuffers::String* name = list->GetAsString(index); + if (name->size() == 0) + { + continue; + } + nameToIndexMap[name->string_view()] = index; + } + return nameToIndexMap; // NRVO will automatically move it. 
no need to use std::move +} + +template void PopulateEdges( + const uint32_t nodeIndex, + const ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>>* edgeNames, + const std::unordered_map& edgeNameToIndexMap, + /*out*/ std::vector& edges, + /*out*/ std::vector& intermediateEdges, + /*out*/ std::unordered_map& edgeToOutgoingNodeIndexMap) +{ + for (flatbuffers::uoffset_t edgeIndex = 0; edgeIndex < edgeNames->size(); edgeIndex++) + { + const flatbuffers::String* edgeName = edgeNames->Get(edgeIndex); + if (edgeName->size() == 0) + { + // This must be optional input/output + continue; + } + // edge can be graphInput or graphOutput + if (edgeNameToIndexMap.find(edgeName->string_view()) != edgeNameToIndexMap.end()) + { + EdgeType edge = {}; + edge.Name = edgeName->str(); + + if constexpr (std::is_same_v) + { + edge.GraphInputIndex = edgeNameToIndexMap.at(edgeName->string_view()); + edge.ToNodeIndex = nodeIndex; + edge.ToNodeInputIndex = edgeIndex; + } + else if constexpr (std::is_same_v) + { + edge.GraphOutputIndex = edgeNameToIndexMap.at(edgeName->string_view()); + edge.FromNodeIndex = nodeIndex; + edge.FromNodeOutputIndex = edgeIndex; + edgeToOutgoingNodeIndexMap[edgeName->string_view()] = {nodeIndex, edgeIndex}; + } + + edges.push_back(edge); + } + // edge is intermediate edge + else + { + if constexpr (std::is_same_v) + { + if (edgeToOutgoingNodeIndexMap.find(edgeName->string_view()) == edgeToOutgoingNodeIndexMap.end()) + { + throw std::range_error("Neither there is any graph input with name " + edgeName->str() + + "nor there is any node which has " + edgeName->str() + " as one of the output."); + } + auto& intermediateEdgeNodeIndex = edgeToOutgoingNodeIndexMap[edgeName->string_view()]; + DmlIntermediateSerializedGraphEdge intermediateEdge = {}; + intermediateEdge.Name = edgeName->str(); + intermediateEdge.FromNodeIndex = intermediateEdgeNodeIndex.nodeIndex; + intermediateEdge.FromNodeOutputIndex = intermediateEdgeNodeIndex.nodeOutputIndex; + 
intermediateEdge.ToNodeIndex = nodeIndex; + intermediateEdge.ToNodeInputIndex = edgeIndex; + intermediateEdges.push_back(std::move(intermediateEdge)); + } + else if constexpr (std::is_same_v) + { + edgeToOutgoingNodeIndexMap[edgeName->string_view()] = {nodeIndex, edgeIndex}; + } + } + } +} + +/* +* - Handling of empty optional input/output/attibute for non-constant node: +* input/output +* - and will have an null entry +* but the actual OperatorNodeDesc variant's +* and will not have any entry. +* attribute +* - will have null entry +*/ +DmlSerializedGraphDesc DeserializeDmlGraph( + const uint8_t* flatbufferGraphDescBlob, + /*out*/ std::vector>& rawData) +{ + if (flatbufferGraphDescBlob == nullptr) + { + throw std::invalid_argument("Given pointer to flatbuffer blob is null"); + } + const dml::ir::DmlGraphDesc* flatbufferGraphDesc = dml::ir::GetDmlGraphDesc(flatbufferGraphDescBlob); + + std::unordered_map graphInputEdgeToIndexMap = ConvertToEdgeNameToIndexMap(flatbufferGraphDesc->graphInputNames()); + std::unordered_map graphOutputEdgeToIndexMap = ConvertToEdgeNameToIndexMap(flatbufferGraphDesc->graphOutputNames()); + + std::unordered_map edgeToOutgoingNodeIndexMap; + std::unordered_set constantInputs; + + std::vector nodes(flatbufferGraphDesc->nodes()->size()); + std::vector inputEdges; + std::vector outputEdges; + std::vector intermediateEdges; + + for (uint32_t nodeIndex = 0; nodeIndex < flatbufferGraphDesc->nodes()->size(); nodeIndex++) + { + const dml::ir::DmlGraphNode* flatbufferNode = flatbufferGraphDesc->nodes()->Get(nodeIndex); + + PopulateEdges( + nodeIndex, + flatbufferNode->inputNames(), + graphInputEdgeToIndexMap, + inputEdges, + intermediateEdges, + edgeToOutgoingNodeIndexMap); + PopulateEdges( + nodeIndex, + flatbufferNode->outputNames(), + graphOutputEdgeToIndexMap, + outputEdges, + intermediateEdges, + edgeToOutgoingNodeIndexMap); + + DmlSerializedGraphNode node = {}; + if (flatbufferNode->name()->size() == 0) + { + throw 
std::invalid_argument("Graph node at index:" + std::to_string(nodeIndex) + " doesn't have any name"); + } + node.Name = flatbufferNode->name()->c_str(); + + if (flatbufferNode->desc_type() == dml::ir::NodeDesc_ConstantNodeDesc) + { + const dml::ir::ConstantNodeDesc* flatbufferConstantNode = flatbufferNode->desc_as_ConstantNodeDesc(); + if (flatbufferConstantNode->data_type() == dml::ir::ConstantNodeDescDetail_ConstantName) + { + if (flatbufferConstantNode->data_as_ConstantName()->name()->size() == 0) + { + throw std::invalid_argument("Constant node at index:" + std::to_string(nodeIndex) + + " doesn't have constant data name."); + } + + ConstantName constantNode = {flatbufferConstantNode->data_as_ConstantName()->name()->c_str()}; + node.Desc = constantNode; + // output of this node will part of constantInputs list + for (uint32_t outputIndex = 0; outputIndex < flatbufferNode->outputNames()->size(); outputIndex++) + { + constantInputs.insert(flatbufferNode->outputNames()->Get(outputIndex)->c_str()); + } + } + else if (flatbufferConstantNode->data_type() == dml::ir::ConstantNodeDescDetail_ConstantRawData) + { + + uint32_t rawDataSize = flatbufferConstantNode->data_as_ConstantRawData()->data()->size(); + rawData.push_back(std::make_unique(rawDataSize)); + std::transform( + flatbufferConstantNode->data_as_ConstantRawData()->data()->begin(), + flatbufferConstantNode->data_as_ConstantRawData()->data()->end(), + rawData.back().get(), + [](uint8_t b) {return static_cast(b);}); + + ConstantData constantData = {}; + constantData.dataSize = rawDataSize; + constantData.data = rawData.back().get(); + node.Desc = constantData; + } + + + } + else if (flatbufferNode->desc_type() == dml::ir::NodeDesc::NodeDesc_OperatorNodeDesc) + { + // convert dml::ir::OperatorNodeDesc to AbstractOperatorDesc + const dml::ir::OperatorNodeDesc* flatbufferOperatorNodeDesc = flatbufferNode->desc_as_OperatorNodeDesc(); + node.Desc = CreateAbstractOperatorDesc( + nodeIndex, + flatbufferOperatorNodeDesc, 
+ flatbufferNode->inputNames(), + flatbufferNode->outputNames(), + constantInputs); + } + + nodes[nodeIndex] = node; + } + + DmlSerializedGraphDesc graphDesc; + graphDesc.InputCount = flatbufferGraphDesc->graphInputNames()->size(); + graphDesc.OutputCount = flatbufferGraphDesc->graphOutputNames()->size(); + graphDesc.InputEdges = std::move(inputEdges); + graphDesc.IntermediateEdges = std::move(intermediateEdges); + graphDesc.OutputEdges = std::move(outputEdges); + graphDesc.Nodes = std::move(nodes); + return graphDesc; +} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp index 642d9aa03eeef..202b762d99e01 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp @@ -135,8 +135,10 @@ namespace DmlGraphFusionHelper void ProcessInputData( const ExecutionProviderImpl* providerImpl, + const bool graphSerializationEnabled, const std::vector& isInputsUploadedByDmlEP, - const std::vector& inputEdges, + const std::unordered_map* serializedGraphInputIndexToSubgraphInputIndex, + const std::unordered_map* serializedGraphLargeConstantNameToSubgraphInputIndex, const gsl::span subGraphInputArgNames, const std::unordered_map>& initializerNameToInitializerMap, onnxruntime::Graph& graph, @@ -162,8 +164,17 @@ namespace DmlGraphFusionHelper // Walk through each graph edge and mark used inputs inputsUsed.assign(fusedNodeInputCount, false); - for (const DML_INPUT_GRAPH_EDGE_DESC& edge : inputEdges) { - inputsUsed[edge.GraphInputIndex] = true; + for (auto it = serializedGraphInputIndexToSubgraphInputIndex->begin(); it != serializedGraphInputIndexToSubgraphInputIndex->end(); it++) { + inputsUsed[it->second] = true; + } + for (auto it = serializedGraphLargeConstantNameToSubgraphInputIndex->begin(); it != 
serializedGraphLargeConstantNameToSubgraphInputIndex->end(); it++) { + inputsUsed[it->second] = true; + } + + std::wstring modelName; + if (graphSerializationEnabled) + { + modelName = GetModelName(graph.ModelPath()); } for (uint32_t i = 0; i < initInputBindings.size(); i++) @@ -209,6 +220,10 @@ namespace DmlGraphFusionHelper // Tensor sizes in DML must be a multiple of 4 bytes large. tensorByteSize = AlignToPow2(tensorByteSize, 4); + if(graphSerializationEnabled) + { + WriteToFile(modelName, ConvertToWString(iter->first) + L".bin", reinterpret_cast(tensorPtr), tensorByteSize); + } if (inputRawData) { @@ -287,55 +302,158 @@ namespace DmlGraphFusionHelper return initializerPartitionMap; } + inline uint32_t GetConstantNodeGraphInputIndex( + const std::string& constantName, + const std::unordered_map* serializedGraphConstantNameToMainGraphInputIndex, + uint32_t& graphMaxInputIndex, + std::unordered_map& localConstantNameToIndexMap) + { + if (serializedGraphConstantNameToMainGraphInputIndex == nullptr) + { + if (localConstantNameToIndexMap.find(constantName) == localConstantNameToIndexMap.end()) + { + localConstantNameToIndexMap[constantName] = ++graphMaxInputIndex; + } + return localConstantNameToIndexMap[constantName]; + } + else + { + graphMaxInputIndex = std::max(graphMaxInputIndex, serializedGraphConstantNameToMainGraphInputIndex->at(constantName)); + return serializedGraphConstantNameToMainGraphInputIndex->at(constantName); + } + } + + template void ConvertGraphDesc( const Dml::GraphDescBuilder::GraphDesc& graphDesc, - _Out_ DML_GRAPH_DESC& dmlGraphDesc, const uint32_t inputCount, const uint32_t outputCount, - _Inout_ std::vector& dmlOperatorGraphNodes, - _Inout_ std::vector& dmlConstantGraphNodes, + IDMLDevice* device, + StackAllocator& allocator, + const std::unordered_map* serializedGraphInputIndexToSubgraphInputIndex, + const std::unordered_map* serializedGraphLargeConstantNameToSubgraphInputIndex, + _Out_ DML_GRAPH_DESC& dmlGraphDesc, + _Inout_ std::vector>& 
dmlOperators, _Inout_ std::vector& dmlGraphNodes, _Inout_ std::vector& dmlInputEdges, _Inout_ std::vector& dmlOutputEdges, _Inout_ std::vector& dmlIntermediateEdges) { - for (size_t i = 0; i < graphDesc.nodes.size(); ++i) + std::unordered_map oldNodeIndexToNewNodeIndexMap; + for (uint32_t index = 0; index < static_cast(graphDesc.Nodes.size()); index++) { - auto& nodeInfo = graphDesc.nodes[i]; - - if (std::holds_alternative>(nodeInfo.nodeDef)) + const DmlSerializedGraphNode& node = graphDesc.Nodes[index]; + if (std::holds_alternative(node.Desc)) { - dmlOperatorGraphNodes[i] = DML_OPERATOR_GRAPH_NODE_DESC{std::get>(nodeInfo.nodeDef).Get(), nodeInfo.name.data()}; - dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{DML_GRAPH_NODE_TYPE_OPERATOR, &dmlOperatorGraphNodes[i]}; + oldNodeIndexToNewNodeIndexMap[index] = static_cast(dmlGraphNodes.size()); + DML_OPERATOR_DESC dmlDesc = SchemaHelpers::ConvertOperatorDesc(std::get(node.Desc), &allocator); + ComPtr op; + ORT_THROW_IF_FAILED(device->CreateOperator(&dmlDesc, IID_PPV_ARGS(&op))); + dmlOperators.push_back(op); + DML_OPERATOR_GRAPH_NODE_DESC* dmlOperatorGraphNode = allocator.template Allocate(); + dmlOperatorGraphNode->Name = node.Name.data(); + dmlOperatorGraphNode->Operator = op.Get(); + dmlGraphNodes.push_back(DML_GRAPH_NODE_DESC{DML_GRAPH_NODE_TYPE_OPERATOR, dmlOperatorGraphNode}); } else { - auto& nodeDefinitionData = std::get>(nodeInfo.nodeDef); - dmlConstantGraphNodes[i] = DML_CONSTANT_DATA_GRAPH_NODE_DESC{ - nodeDefinitionData.data(), - nodeDefinitionData.size(), - nodeInfo.name.data() - }; - - // TODO: Change as new header is ingested - dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{static_cast(2), &dmlConstantGraphNodes[i]}; + auto& constantNodeVariant = std::get(node.Desc); + if (std::holds_alternative(constantNodeVariant)) + { + oldNodeIndexToNewNodeIndexMap[index] = static_cast(dmlGraphNodes.size()); + + auto& constantData = std::get(constantNodeVariant); + + DML_CONSTANT_DATA_GRAPH_NODE_DESC* constantNode = 
allocator.template Allocate(); + constantNode->Name = node.Name.data(); + constantNode->DataSize = constantData.dataSize; + constantNode->Data = constantData.data; + dmlGraphNodes.push_back(DML_GRAPH_NODE_DESC{DML_GRAPH_NODE_TYPE_CONSTANT, constantNode}); + } } } - for (size_t i = 0; i < graphDesc.inputEdges.size(); ++i) + uint32_t graphMaxInputIndex = 0; + + for (size_t i = 0; i < graphDesc.InputEdges.size(); ++i) { - dmlInputEdges[i] = DML_GRAPH_EDGE_DESC{DML_GRAPH_EDGE_TYPE_INPUT, &graphDesc.inputEdges[i]}; + DML_INPUT_GRAPH_EDGE_DESC* edge = allocator.template Allocate(); + // 1. If serializedGraphInputIndexToMainGraphInputIndex is not null: + // then use the corresponding main graph input index, because the caller will use corresponding + // main graph input index for extracting the actual input tensor from the main graph and + // the caller does not own the creation of dml bindings directly. + // Use Case: When the caller is ORT (DML EP) or DmlEngine. + // + // 2. If serializedGraphInputIndexToMainGraphInputIndex is null: + // then assign the sequential graph input index, because it owns the creation of dml bindings + // directly. + edge->GraphInputIndex = serializedGraphInputIndexToSubgraphInputIndex == nullptr ? 
+ graphDesc.InputEdges[i].GraphInputIndex : + serializedGraphInputIndexToSubgraphInputIndex->at(graphDesc.InputEdges[i].GraphInputIndex); + edge->ToNodeIndex = oldNodeIndexToNewNodeIndexMap[graphDesc.InputEdges[i].ToNodeIndex]; + edge->ToNodeInputIndex = graphDesc.InputEdges[i].ToNodeInputIndex; + edge->Name = graphDesc.InputEdges[i].Name.data(); + + graphMaxInputIndex = std::max(graphMaxInputIndex, edge->GraphInputIndex); + dmlInputEdges.push_back(DML_GRAPH_EDGE_DESC{DML_GRAPH_EDGE_TYPE_INPUT, edge}); } - for (size_t i = 0; i < graphDesc.outputEdges.size(); ++i) + for (size_t i = 0; i < graphDesc.OutputEdges.size(); ++i) { - dmlOutputEdges[i] = DML_GRAPH_EDGE_DESC{DML_GRAPH_EDGE_TYPE_OUTPUT, &graphDesc.outputEdges[i]}; + DML_OUTPUT_GRAPH_EDGE_DESC* edge = allocator.template Allocate(); + edge->GraphOutputIndex = graphDesc.OutputEdges[i].GraphOutputIndex; + edge->FromNodeIndex = oldNodeIndexToNewNodeIndexMap[graphDesc.OutputEdges[i].FromNodeIndex]; + edge->FromNodeOutputIndex = graphDesc.OutputEdges[i].FromNodeOutputIndex; + edge->Name = graphDesc.OutputEdges[i].Name.data(); + + dmlOutputEdges.push_back(DML_GRAPH_EDGE_DESC{DML_GRAPH_EDGE_TYPE_OUTPUT, edge}); } - for (size_t i = 0; i < graphDesc.intermediateEdges.size(); ++i) + std::unordered_map localConstantNameToIndexMap; + for (uint32_t i = 0; i < static_cast(graphDesc.IntermediateEdges.size()); ++i) { - dmlIntermediateEdges[i] = - DML_GRAPH_EDGE_DESC{DML_GRAPH_EDGE_TYPE_INTERMEDIATE, &graphDesc.intermediateEdges[i]}; + DmlSerializedGraphNodeDescVariant descVariant = graphDesc.Nodes[graphDesc.IntermediateEdges[i].FromNodeIndex].Desc; + bool isConstantEdge = std::holds_alternative(descVariant); + if (isConstantEdge) + { + auto& constantNodeVariant = std::get(descVariant); + if (std::holds_alternative(constantNodeVariant)) + { + DML_INTERMEDIATE_GRAPH_EDGE_DESC* edge = allocator.template Allocate(); + edge->FromNodeIndex = oldNodeIndexToNewNodeIndexMap[graphDesc.IntermediateEdges[i].FromNodeIndex]; + 
edge->FromNodeOutputIndex = graphDesc.IntermediateEdges[i].FromNodeOutputIndex; + edge->ToNodeIndex = oldNodeIndexToNewNodeIndexMap[graphDesc.IntermediateEdges[i].ToNodeIndex]; + edge->ToNodeInputIndex = graphDesc.IntermediateEdges[i].ToNodeInputIndex; + edge->Name = graphDesc.IntermediateEdges[i].Name.data(); + dmlIntermediateEdges.push_back(DML_GRAPH_EDGE_DESC{DML_GRAPH_EDGE_TYPE_INTERMEDIATE, edge}); + } + else + { + const std::string& constantName = graphDesc.Nodes[graphDesc.IntermediateEdges[i].FromNodeIndex].Name; + + DML_INPUT_GRAPH_EDGE_DESC* edge = allocator.template Allocate(); + edge->GraphInputIndex = GetConstantNodeGraphInputIndex( + constantName, + serializedGraphLargeConstantNameToSubgraphInputIndex, + graphMaxInputIndex, + localConstantNameToIndexMap); + edge->ToNodeIndex = oldNodeIndexToNewNodeIndexMap[graphDesc.IntermediateEdges[i].ToNodeIndex]; + edge->ToNodeInputIndex = graphDesc.IntermediateEdges[i].ToNodeInputIndex; + edge->Name = graphDesc.IntermediateEdges[i].Name.data(); + + dmlInputEdges.push_back({DML_GRAPH_EDGE_TYPE_INPUT, edge}); + } + } + else + { + DML_INTERMEDIATE_GRAPH_EDGE_DESC* edge = allocator.template Allocate(); + edge->FromNodeIndex = oldNodeIndexToNewNodeIndexMap[graphDesc.IntermediateEdges[i].FromNodeIndex]; + edge->FromNodeOutputIndex = graphDesc.IntermediateEdges[i].FromNodeOutputIndex; + edge->ToNodeIndex = oldNodeIndexToNewNodeIndexMap[graphDesc.IntermediateEdges[i].ToNodeIndex]; + edge->ToNodeInputIndex = graphDesc.IntermediateEdges[i].ToNodeInputIndex; + edge->Name = graphDesc.IntermediateEdges[i].Name.data(); + dmlIntermediateEdges.push_back(DML_GRAPH_EDGE_DESC{DML_GRAPH_EDGE_TYPE_INTERMEDIATE, edge}); + } } dmlGraphDesc.InputCount = inputCount; @@ -400,27 +518,34 @@ namespace DmlGraphFusionHelper Microsoft::WRL::ComPtr TryCreateCompiledOperator( const GraphDescBuilder::GraphDesc& graphDesc, const onnxruntime::IndexedSubGraph& indexedSubGraph, - const ExecutionProviderImpl* providerImpl) + const ExecutionProviderImpl* 
providerImpl, + const std::unordered_map* serializedGraphInputIndexToSubgraphInputIndex, + const std::unordered_map* serializedGraphLargeConstantNameToSubgraphInputIndex) { const uint32_t fusedNodeInputCount = gsl::narrow_cast(indexedSubGraph.GetMetaDef()->inputs.size()); const uint32_t fusedNodeOutputCount = gsl::narrow_cast(indexedSubGraph.GetMetaDef()->outputs.size()); // convert DML EP GraphDesc into DML_GRAPH_DESC and create IDMLCompiledOperator - DML_GRAPH_DESC dmlGraphDesc = {}; - std::vector dmlOperatorGraphNodes(graphDesc.nodes.size()); - std::vector dmlConstantGraphNodes(graphDesc.nodes.size()); + ComPtr device; + ORT_THROW_IF_FAILED(providerImpl->GetDmlDevice(device.GetAddressOf())); - std::vector dmlGraphNodes(graphDesc.nodes.size()); - std::vector dmlInputEdges(graphDesc.inputEdges.size()); - std::vector dmlOutputEdges(graphDesc.outputEdges.size()); - std::vector dmlIntermediateEdges(graphDesc.intermediateEdges.size()); + StackAllocator<1024> allocator; + DML_GRAPH_DESC dmlGraphDesc = {}; + std::vector> dmlOperators; + std::vector dmlGraphNodes; + std::vector dmlInputEdges; + std::vector dmlOutputEdges; + std::vector dmlIntermediateEdges; ConvertGraphDesc( graphDesc, - dmlGraphDesc, fusedNodeInputCount, fusedNodeOutputCount, - dmlOperatorGraphNodes, - dmlConstantGraphNodes, + device.Get(), + allocator, + serializedGraphInputIndexToSubgraphInputIndex, + serializedGraphLargeConstantNameToSubgraphInputIndex, + dmlGraphDesc, + dmlOperators, dmlGraphNodes, dmlInputEdges, dmlOutputEdges, @@ -438,8 +563,6 @@ namespace DmlGraphFusionHelper executionFlags |= DML_EXECUTION_FLAG_DISABLE_META_COMMANDS; } - ComPtr device; - ORT_THROW_IF_FAILED(providerImpl->GetDmlDevice(device.GetAddressOf())); ComPtr device1; ORT_THROW_IF_FAILED(device.As(&device1)); @@ -460,6 +583,7 @@ namespace DmlGraphFusionHelper } void FusePartitionAndRegisterKernel( + const uint32_t partitionIndex, onnxruntime::Graph& graph, onnxruntime::KernelRegistry* registryForPartitionKernels, const 
std::unordered_map>& initializerNameToInitializerMap, @@ -467,8 +591,43 @@ namespace DmlGraphFusionHelper const onnxruntime::IndexedSubGraph& indexedSubGraph, std::vector&& isInputsUploadedByDmlEP, const GraphDescBuilder::GraphDesc& graphDesc, - Microsoft::WRL::ComPtr compiledExecutionPlanOperator) + Microsoft::WRL::ComPtr compiledExecutionPlanOperator, + const bool graphSerializationEnabled, + const std::unordered_map* serializedGraphInputIndexToSubgraphInputIndex, + const std::unordered_map* serializedGraphLargeConstantNameToSubgraphInputIndex) { + if (graphSerializationEnabled) + { + + const std::wstring modelName = GetModelName(graph.ModelPath()); + auto buffer = SerializeDmlGraph(graphDesc); + + const std::wstring partitionName = + L"Partition_" + + std::to_wstring(partitionIndex) + + L".bin"; + WriteToFile(modelName, partitionName, buffer.data(), buffer.size()); + + std::vector> rawData; + DmlSerializedGraphDesc deserializedGraphDesc = DeserializeDmlGraph(buffer.data(), rawData); + GraphDescBuilder::GraphDesc deserializedDmlGraphDesc = {}; + deserializedDmlGraphDesc.InputCount = deserializedGraphDesc.InputCount; + deserializedDmlGraphDesc.InputEdges = std::move(deserializedGraphDesc.InputEdges); + deserializedDmlGraphDesc.IntermediateEdges = std::move(deserializedGraphDesc.IntermediateEdges); + deserializedDmlGraphDesc.Nodes = std::move(deserializedGraphDesc.Nodes); + deserializedDmlGraphDesc.OutputCount = deserializedGraphDesc.OutputCount; + deserializedDmlGraphDesc.OutputEdges = std::move(deserializedGraphDesc.OutputEdges); + deserializedDmlGraphDesc.reuseCommandList = graphDesc.reuseCommandList; + deserializedDmlGraphDesc.outputShapes = graphDesc.outputShapes; + + compiledExecutionPlanOperator = DmlGraphFusionHelper::TryCreateCompiledOperator( + deserializedDmlGraphDesc, + indexedSubGraph, + providerImpl, + serializedGraphInputIndexToSubgraphInputIndex, + serializedGraphLargeConstantNameToSubgraphInputIndex); + } + auto& fusedNode = 
graph.BeginFuseSubGraph(indexedSubGraph, indexedSubGraph.GetMetaDef()->name); fusedNode.SetExecutionProviderType(onnxruntime::kDmlExecutionProvider); @@ -482,8 +641,10 @@ namespace DmlGraphFusionHelper std::vector inputsUsed; ProcessInputData( providerImpl, + graphSerializationEnabled, isInputsUploadedByDmlEP, - graphDesc.inputEdges, + serializedGraphInputIndexToSubgraphInputIndex, + serializedGraphLargeConstantNameToSubgraphInputIndex, indexedSubGraph.GetMetaDef()->inputs, initializerNameToInitializerMap, graph, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h index f8f6162aaa1e0..f1e9654021196 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h @@ -45,12 +45,17 @@ namespace DmlGraphFusionHelper gsl::span> partitions ); + template void ConvertGraphDesc( const Dml::GraphDescBuilder::GraphDesc& graphDesc, - _Out_ DML_GRAPH_DESC& dmlGraphDesc, const uint32_t inputCount, const uint32_t outputCount, - _Inout_ std::vector& dmlOperatorGraphNodes, + IDMLDevice* device, + StackAllocator& allocator, + const std::unordered_map* serializedGraphInputIndexToSubgraphInputIndex, + const std::unordered_map* serializedGraphLargeConstantNameToSubgraphInputIndex, + _Out_ DML_GRAPH_DESC& dmlGraphDesc, + _Inout_ std::vector>& dmlOperators, _Inout_ std::vector& dmlGraphNodes, _Inout_ std::vector& dmlInputEdges, _Inout_ std::vector& dmlOutputEdges, @@ -69,9 +74,12 @@ namespace DmlGraphFusionHelper Microsoft::WRL::ComPtr TryCreateCompiledOperator( const GraphDescBuilder::GraphDesc& graphDesc, const onnxruntime::IndexedSubGraph& indexedSubGraph, - const ExecutionProviderImpl* providerImpl); + const ExecutionProviderImpl* providerImpl, + const std::unordered_map* serializedGraphInputIndexToSubgraphInputIndex, + const std::unordered_map* 
serializedGraphLargeConstantNameToSubgraphInputIndex); void FusePartitionAndRegisterKernel( + const uint32_t partitionIndex, onnxruntime::Graph& graph, onnxruntime::KernelRegistry* registryForPartitionKernels, const std::unordered_map>& initializerNameToInitializerMap, @@ -79,7 +87,10 @@ namespace DmlGraphFusionHelper const onnxruntime::IndexedSubGraph& indexedSubGraph, std::vector&& isInputsUploadedByDmlEP, const GraphDescBuilder::GraphDesc& graphDesc, - Microsoft::WRL::ComPtr compiledExecutionPlanOperator); + Microsoft::WRL::ComPtr compiledExecutionPlanOperator, + const bool graphSerializationEnabled, + const std::unordered_map* serializedGraphInputIndexToSubgraphInputIndex = nullptr, + const std::unordered_map* serializedGraphLargeConstantNameToSubgraphInputIndex = nullptr); void RegisterDynamicKernel( onnxruntime::Graph& graph, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp index 679738b639ec9..35a2c451a49a5 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp @@ -24,15 +24,20 @@ namespace Dml std::vector isInputsUploadedByDmlEP; GraphDescBuilder::GraphDesc graphDesc; std::unordered_map> isInitializerTransferable; + std::vector> smallConstantData; // Need to keep it alive for maintaining lifetime + std::unordered_map serializedGraphInputIndexToSubgraphInputIndex; + std::unordered_map serializedGraphLargeConstantNameToSubgraphInputIndex; }; } DmlGraphFusionTransformer::DmlGraphFusionTransformer( const std::string& name, - const onnxruntime::IExecutionProvider* provider + const onnxruntime::IExecutionProvider* provider, + const bool graphSerializationEnabled ) :onnxruntime::GraphTransformer(name), - m_providerImpl(static_cast(provider)->GetImpl()) + 
m_providerImpl(static_cast(provider)->GetImpl()), + graphSerializationEnabled(graphSerializationEnabled) { } @@ -227,23 +232,39 @@ namespace Dml ComPtr device; ORT_THROW_IF_FAILED(m_providerImpl->GetDmlDevice(device.GetAddressOf())); + // This map will be used to transfer the initializer to D3D12 system heap memory. + // 'serializedDmlGraphDesc' will have constant input as intermediate edges, that's why + // we need a mapping between intermediateEdgeIndex and indexedSubGraph's (a given partition) + // input arg index. + // For ex: Let's say intermediate edge index = idx, then + // indexedSubGraphInputArgIdx = constantEdgeIdxToSubgraphInputArgIdxMap[idx]; + // corresponding constant tensor = initializerNameToInitializerMap[indexedSubGraph.GetMetaDef()->inputs[indexedSubGraphInputArgIdx]] + // We are using intermediate edge index as a key because same constant tensor can be used by + // multiple nodes. + std::unordered_map serializedGraphInputIndexToSubgraphInputIndex; + std::unordered_map serializedGraphLargeConstantNameToSubgraphInputIndex; + std::vector> smallConstantData; GraphDescBuilder::GraphDesc graphDesc = GraphDescBuilder::BuildGraphDesc( isInputsUploadedByDmlEP.data(), isInputsUploadedByDmlEP.size(), isInitializerTransferable, partitionNodePropsMap, - device.Get(), m_providerImpl, modelPath, subgraphNodes, subgraphInputs, - subgraphOutputs); + subgraphOutputs, + serializedGraphInputIndexToSubgraphInputIndex, + serializedGraphLargeConstantNameToSubgraphInputIndex, + smallConstantData); // Compile the operator auto compiledPartition = DmlGraphFusionHelper::TryCreateCompiledOperator( graphDesc, indexedSubGraph, - m_providerImpl); + m_providerImpl, + &serializedGraphInputIndexToSubgraphInputIndex, + &serializedGraphLargeConstantNameToSubgraphInputIndex); if (!compiledPartition) { @@ -264,6 +285,9 @@ namespace Dml compiledPartitionInfo->isInputsUploadedByDmlEP = std::move(isInputsUploadedByDmlEP); compiledPartitionInfo->graphDesc = std::move(graphDesc); 
compiledPartitionInfo->isInitializerTransferable = std::move(isInitializerTransferable); + compiledPartitionInfo->smallConstantData = std::move(smallConstantData); + compiledPartitionInfo->serializedGraphInputIndexToSubgraphInputIndex = std::move(serializedGraphInputIndexToSubgraphInputIndex); + compiledPartitionInfo->serializedGraphLargeConstantNameToSubgraphInputIndex = std::move(serializedGraphLargeConstantNameToSubgraphInputIndex); compiledPartitionInfos[partitionIndex] = std::move(compiledPartitionInfo); } } @@ -271,12 +295,14 @@ namespace Dml } while (!additionalSplittingNodes.empty()); + uint32_t partitionIndex = 0; for (auto&& compiledPartitionInfo : compiledPartitionInfos) { // Null compiled operators were not DML partitions if (compiledPartitionInfo) { DmlGraphFusionHelper::FusePartitionAndRegisterKernel( + partitionIndex++, graph, m_providerImpl->GetKernelRegistry().get(), compiledPartitionInfo->isInitializerTransferable, @@ -284,7 +310,10 @@ namespace Dml compiledPartitionInfo->indexedSubGraph, std::move(compiledPartitionInfo->isInputsUploadedByDmlEP), compiledPartitionInfo->graphDesc, - compiledPartitionInfo->compiledOperator); + compiledPartitionInfo->compiledOperator, + graphSerializationEnabled, + &compiledPartitionInfo->serializedGraphInputIndexToSubgraphInputIndex, + &compiledPartitionInfo->serializedGraphLargeConstantNameToSubgraphInputIndex); } } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.h index 19dab0c89943c..b370f3ef9043c 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.h @@ -16,7 +16,8 @@ class DmlGraphFusionTransformer : public onnxruntime::GraphTransformer public: DmlGraphFusionTransformer( const std::string& name, - const onnxruntime::IExecutionProvider* provider + 
const onnxruntime::IExecutionProvider* provider, + const bool graphSerializationEnabled ); public: @@ -38,5 +39,6 @@ class DmlGraphFusionTransformer : public onnxruntime::GraphTransformer private: const ExecutionProviderImpl* m_providerImpl = nullptr; + const bool graphSerializationEnabled = false; }; } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphSerialization.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphSerialization.cpp new file mode 100644 index 0000000000000..5355964e8db74 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphSerialization.cpp @@ -0,0 +1,580 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. + +#pragma once +#include "precomp.h" + +template +T* ReadAs(uint8_t* base, size_t byteOffset) +{ + return reinterpret_cast(base + byteOffset); +} + +void SerializeAttributeDescs( + flatbuffers::FlatBufferBuilder& builder, + const AbstractOperatorDesc& operatorDesc, + /*out*/ std::vector>& attributeDescs); + +flatbuffers::Offset serializeActivation( + flatbuffers::FlatBufferBuilder& builder, + const AbstractOperatorDesc& activationOperatorDesc) +{ + std::vector> attributeDescs; + SerializeAttributeDescs(builder, activationOperatorDesc, attributeDescs); + + flatbuffers::Offset offset = dml::ir::operatorFieldTypes::CreateActivationDirect( + builder, + activationOperatorDesc.schema->OperatorName, + &attributeDescs); + return offset; +} + +void SerializeAttributeDescs( + flatbuffers::FlatBufferBuilder& builder, + const AbstractOperatorDesc& operatorDesc, + /*out*/ std::vector>& attributeDescs) +{ + for (const OperatorField& field : operatorDesc.fields) + { + if (field.GetSchema()->Kind == DML_SCHEMA_FIELD_KIND_INPUT_TENSOR || + field.GetSchema()->Kind == DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR) + { + continue; + } + + flatbuffers::Offset offset; + + if (std::holds_alternative(field.GetData())) + { + const OperatorFieldTypes::FusedActivationOperatorDesc& 
fusedActivation = field.AsFusedActivationOperatorDesc(); + if (!fusedActivation.has_value()) + { + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + nullptr, + dml::ir::operatorFieldTypes::AttributeFieldVariant_Activation); + } + else + { + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + field.GetSchema()->Name, + dml::ir::operatorFieldTypes::AttributeFieldVariant_Activation, + serializeActivation(builder, fusedActivation.value()).Union()); + } + } + else if (std::holds_alternative(field.GetData())) + { + const OperatorFieldTypes::FusedActivationOperatorDescArray& fusedActivations = + field.AsFusedActivationOperatorDescArray(); + if (!fusedActivations.has_value()) + { + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + nullptr, + dml::ir::operatorFieldTypes::AttributeFieldVariant_ActivationArray); + } + else + { + std::vector> fbActivations; + + for (AbstractOperatorDesc activationOpDesc : fusedActivations.value()) + { + flatbuffers::Offset fbActivation = + serializeActivation(builder, activationOpDesc); + fbActivations.push_back(fbActivation); + } + + flatbuffers::Offset activationOffset = + dml::ir::operatorFieldTypes::CreateActivationArrayDirect(builder, &fbActivations); + + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + field.GetSchema()->Name, + dml::ir::operatorFieldTypes::AttributeFieldVariant_ActivationArray, + activationOffset.Union()); + } + } + else if (std::holds_alternative(field.GetData())) + { + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + field.GetSchema()->Name, + dml::ir::operatorFieldTypes::AttributeFieldVariant_UInt32, + builder.CreateStruct(dml::ir::operatorFieldTypes::UInt32(field.AsUInt())).Union()); + } + else if (std::holds_alternative(field.GetData())) + { + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + field.GetSchema()->Name, + 
dml::ir::operatorFieldTypes::AttributeFieldVariant_UInt64, + builder.CreateStruct(dml::ir::operatorFieldTypes::UInt64(field.AsUInt64())).Union()); + } + else if (std::holds_alternative(field.GetData())) + { + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + field.GetSchema()->Name, + dml::ir::operatorFieldTypes::AttributeFieldVariant_Int32, + builder.CreateStruct(dml::ir::operatorFieldTypes::Int32(field.AsInt())).Union()); + } + else if (std::holds_alternative(field.GetData())) + { + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + field.GetSchema()->Name, + dml::ir::operatorFieldTypes::AttributeFieldVariant_Float32, + builder.CreateStruct(dml::ir::operatorFieldTypes::Float32(field.AsFloat())).Union()); + } + else if (std::holds_alternative(field.GetData())) + { + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + field.GetSchema()->Name, + dml::ir::operatorFieldTypes::AttributeFieldVariant_UIntArray, + dml::ir::operatorFieldTypes::CreateUIntArray(builder, builder.CreateVector(field.AsUIntArray())).Union()); + } + else if (std::holds_alternative(field.GetData())) + { + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + field.GetSchema()->Name, + dml::ir::operatorFieldTypes::AttributeFieldVariant_IntArray, + dml::ir::operatorFieldTypes::CreateIntArray(builder, builder.CreateVector(field.AsIntArray())).Union()); + } + else if (std::holds_alternative(field.GetData())) + { + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + field.GetSchema()->Name, + dml::ir::operatorFieldTypes::AttributeFieldVariant_FloatArray, + dml::ir::operatorFieldTypes::CreateFloatArray(builder, builder.CreateVector(field.AsFloatArray())).Union()); + } + else if (std::holds_alternative(field.GetData())) + { + const OperatorFieldTypes::ScaleBias& scaleBias = field.AsScaleBias(); + if (!scaleBias.has_value()) + { + offset = 
dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + nullptr, + dml::ir::operatorFieldTypes::AttributeFieldVariant_ScaleBias); + } + else + { + dml::ir::operatorFieldTypes::ScaleBias fbScaleBias(scaleBias.value().Scale, scaleBias.value().Bias); + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + field.GetSchema()->Name, + dml::ir::operatorFieldTypes::AttributeFieldVariant_ScaleBias, + builder.CreateStruct(fbScaleBias).Union()); + } + } + else if (std::holds_alternative(field.GetData())) + { + const DML_SIZE_2D size2d = field.AsSize2D(); + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + field.GetSchema()->Name, + dml::ir::operatorFieldTypes::AttributeFieldVariant_Size2D, + builder.CreateStruct(dml::ir::operatorFieldTypes::Size2D(size2d.Width, size2d.Height)).Union()); + } + else if (std::holds_alternative(field.GetData())) + { + OperatorFieldTypes::ScalarUnion scalarUnion = field.AsScalarUnion(); + dml::ir::operatorFieldTypes::ByteArray byteArr; + for (uint32_t index = 0; index < static_cast(sizeof(scalarUnion.Bytes)); index++) + { + byteArr.mutable_data()->Mutate(index, scalarUnion.Bytes[index]); + } + + flatbuffers::Offset scalarUnionOffset = + dml::ir::operatorFieldTypes::CreateScalarUnionData( + builder, + dml::ir::operatorFieldTypes::ScalarVariant_ByteArray, + builder.CreateStruct(byteArr).Union()); + + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + field.GetSchema()->Name, + dml::ir::operatorFieldTypes::AttributeFieldVariant_ScalarUnionData, + scalarUnionOffset.Union()); + } + else if (std::holds_alternative(field.GetData())) + { + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + field.GetSchema()->Name, + dml::ir::operatorFieldTypes::AttributeFieldVariant_Bool, + builder.CreateStruct(dml::ir::operatorFieldTypes::Bool(field.AsBool())).Union()); + } + else + { + continue; + } + + attributeDescs.push_back(offset); + } +} + 
+flatbuffers::Offset SerializeDmlTensorDesc( + flatbuffers::FlatBufferBuilder& builder, + const DmlBufferTensorDesc* tensorDesc) +{ + const std::vector *strides = nullptr; + if (tensorDesc->strides.has_value()) + { + strides = &tensorDesc->strides.value(); + } + + flatbuffers::Offset offset = dml::ir::CreateDmlBufferTensorDescDirect( + builder, + ApiTraits::StringifyHelpers::ToString(tensorDesc->dataType), + &tensorDesc->sizes, + strides, + tensorDesc->totalTensorSizeInBytes); + return offset; +} + +flatbuffers::Offset SerializeOperatorNodeDesc( + flatbuffers::FlatBufferBuilder& builder, + const AbstractOperatorDesc& operatorDesc) +{ + const DML_OPERATOR_SCHEMA* operatorSchema = operatorDesc.schema; + + std::vector> inputTensorDescs; + std::vector> outputTensorDescs; + + for (const DmlBufferTensorDesc* tensorDesc : operatorDesc.GetInputTensors()) + { + if (tensorDesc == nullptr) + { + continue; + } + flatbuffers::Offset serializedDmlTensorDesc = SerializeDmlTensorDesc(builder, tensorDesc); + inputTensorDescs.push_back(serializedDmlTensorDesc); + } + + for (const DmlBufferTensorDesc* tensorDesc : operatorDesc.GetOutputTensors()) + { + if (tensorDesc == nullptr) + { + continue; + } + flatbuffers::Offset serializedDmlTensorDesc = SerializeDmlTensorDesc(builder, tensorDesc); + outputTensorDescs.push_back(serializedDmlTensorDesc); + } + + std::vector> attributeDescs; + SerializeAttributeDescs(builder, operatorDesc, attributeDescs); + + flatbuffers::Offset offset = dml::ir::CreateOperatorNodeDesc( + builder, + builder.CreateString(operatorSchema->OperatorName), + builder.CreateVector(inputTensorDescs), + builder.CreateVector(outputTensorDescs), + builder.CreateVector(attributeDescs)); + return offset.Union(); +} + +flatbuffers::Offset SerializeConstantNodeDesc( + flatbuffers::FlatBufferBuilder& builder, + uint32_t nodeIndex, + const DmlSerializedGraphNodeConstantVariant& constantNodeDesc) +{ + flatbuffers::Offset offset; + + if (std::holds_alternative(constantNodeDesc)) 
+ { + auto& constantName = std::get(constantNodeDesc); + if (constantName.name.empty()) + { + throw std::invalid_argument("Graph constant node at index:" + std::to_string(nodeIndex) + + " doesn't have the constant data name."); + } + + flatbuffers::Offset constantNameOffset = dml::ir::CreateConstantName( + builder, + builder.CreateString(constantName.name)); + + offset = dml::ir::CreateConstantNodeDesc( + builder, + dml::ir::ConstantNodeDescDetail_ConstantName, + constantNameOffset.Union()); + } + else + { + auto& constantData = std::get(constantNodeDesc); + std::vector rawBytes; + std::transform(constantData.data, constantData.data + constantData.dataSize, + std::back_inserter(rawBytes), [](std::byte b) {return static_cast(b); }); + flatbuffers::Offset constantDataOffset = dml::ir::CreateConstantRawDataDirect( + builder, + &rawBytes); + + offset = dml::ir::CreateConstantNodeDesc( + builder, + dml::ir::ConstantNodeDescDetail_ConstantRawData, + constantDataOffset.Union()); + } + + return offset.Union(); +} + +flatbuffers::Offset SerializeNode( + flatbuffers::FlatBufferBuilder& builder, + const uint32_t nodeIndex, + const DmlSerializedGraphNode& graphNode, + const std::vector>& nodeInputNames, + const std::vector>& nodeOutputNames) +{ + if (graphNode.Name.empty()) + { + throw std::invalid_argument("Graph node at index:" + std::to_string(nodeIndex) + + " does not have any name."); + } + + flatbuffers::Offset offset; + if (std::holds_alternative(graphNode.Desc)) + { + auto& operatorNode = std::get(graphNode.Desc); + offset = dml::ir::CreateDmlGraphNode( + builder, + dml::ir::NodeDesc_OperatorNodeDesc, + SerializeOperatorNodeDesc(builder, operatorNode), + builder.CreateString(graphNode.Name), + builder.CreateVector(nodeInputNames), + builder.CreateVector(nodeOutputNames)); + } + else + { + auto& constantNodeVariant = std::get(graphNode.Desc); + offset = dml::ir::CreateDmlGraphNode( + builder, + dml::ir::NodeDesc_ConstantNodeDesc, + SerializeConstantNodeDesc(builder, 
nodeIndex, constantNodeVariant), + builder.CreateString(graphNode.Name), + builder.CreateVector(nodeInputNames), + builder.CreateVector(nodeOutputNames)); + } + return offset; +} + +/* +* validates input/output edges and throws exception if an edge +* does not have a name or if an edge has more than 1 names. +*/ +template +std::unordered_map> ConvertToEdgeIndexToNameMap( + const std::vector& edges, + flatbuffers::FlatBufferBuilder& builder) +{ + std::unordered_map> edgeIndexToNameMap; + for (auto& edge : edges) + { + uint32_t index; + if constexpr (std::is_same_v) + { + index = edge.GraphInputIndex; + } + else if constexpr (std::is_same_v) + { + index = edge.GraphOutputIndex; + } + + if (edge.Name.empty()) + { + throw std::invalid_argument("Graph input or output edge at index " + std::to_string(index) + " does not have name."); + } + + if (edgeIndexToNameMap.find(index) != edgeIndexToNameMap.end()) + { + flatbuffers::String* edgeName = ReadAs( + builder.GetCurrentBufferPointer(), + builder.GetSize() - edgeIndexToNameMap[index].o); + if (edge.Name != edgeName->str()) + { + throw std::invalid_argument("Graph input or output edge at index " + std::to_string(index) + " has more than 1 names."); + } + } + + edgeIndexToNameMap[index] = builder.CreateString(edge.Name); + } + return edgeIndexToNameMap; // NRVO will automatically move it. 
no need to use std::move +} + +void PopulateNonConstantNodeInputOutputCount( + const std::vector& nodes, + /*out*/ std::vector& nodeInputCounts, + /*out*/ std::vector& nodeOutputCounts) +{ + for (uint32_t nodeIndex = 0; nodeIndex < static_cast(nodes.size()); nodeIndex++) + { + auto& node = nodes[nodeIndex]; + if (std::holds_alternative(node.Desc)) + { + auto& operatorNode = std::get(node.Desc); + nodeInputCounts[nodeIndex] = std::max( + nodeInputCounts[nodeIndex], + static_cast(operatorNode.GetInputTensors().size())); + + nodeOutputCounts[nodeIndex] = std::max( + nodeOutputCounts[nodeIndex], + static_cast(operatorNode.GetOutputTensors().size())); + } + } +} + +void PopulateConstantNodeInputOutputCount( + const std::vector& edges, + /*out*/std::vector& maxInputIndexForNodes, + /*out*/std::vector& maxOutputIndexForNodes) +{ + for (auto& edge : edges) + { + maxInputIndexForNodes[edge.ToNodeIndex] = std::max(maxInputIndexForNodes[edge.ToNodeIndex], edge.ToNodeInputIndex + 1); + maxOutputIndexForNodes[edge.FromNodeIndex] = std::max(maxOutputIndexForNodes[edge.FromNodeIndex], edge.FromNodeOutputIndex + 1); + } +} + +/* +* validates intermediate edge and throws exception if an edge +* does not have a name or if an edge has more than 1 names. 
+*/ +void PopulateNodeInputOutputNames( + flatbuffers::FlatBufferBuilder& builder, + const DmlSerializedGraphDesc& graphDesc, + const std::unordered_map>& graphInputIndexToNameMap, + const std::unordered_map>& graphOutputIndexToNameMap, + /*out*/std::vector>>& nodeToInputNames, + /*out*/std::vector>>& nodeToOutputNames) +{ + for (auto& edge : graphDesc.InputEdges) + { + nodeToInputNames[edge.ToNodeIndex][edge.ToNodeInputIndex] = graphInputIndexToNameMap.at(edge.GraphInputIndex); + } + + for (auto& edge : graphDesc.OutputEdges) + { + nodeToOutputNames[edge.FromNodeIndex][edge.FromNodeOutputIndex] = graphOutputIndexToNameMap.at(edge.GraphOutputIndex); + } + + std::unordered_map>> intermediateEdgeNames; + for (uint32_t edgeIndex = 0; edgeIndex < static_cast(graphDesc.IntermediateEdges.size()); edgeIndex++) + { + auto& edge = graphDesc.IntermediateEdges[edgeIndex]; + if (edge.Name.empty()) + { + throw std::invalid_argument( + "Graph intermediate edge from nodeIndex:" + std::to_string(edge.FromNodeIndex) + + " & nodeOutputIndex:" + std::to_string(edge.FromNodeOutputIndex) + " doesn't have name."); + } + + if (intermediateEdgeNames.find(edge.FromNodeIndex) != intermediateEdgeNames.end() && + intermediateEdgeNames[edge.FromNodeIndex].find(edge.FromNodeOutputIndex) != intermediateEdgeNames[edge.FromNodeIndex].end()) + { + flatbuffers::Offset edgeNameOffset = intermediateEdgeNames[edge.FromNodeIndex][edge.FromNodeOutputIndex]; + flatbuffers::String* edgeName = ReadAs( + builder.GetCurrentBufferPointer(), + builder.GetSize() - edgeNameOffset.o); + + if (edgeName->str() != edge.Name) + { + throw std::invalid_argument( + "Graph intermediate edge from nodeIndex:" + std::to_string(edge.FromNodeIndex) + + " & nodeOutputIndex:" + std::to_string(edge.FromNodeOutputIndex) + " has more than 1 names."); + } + } + else + { + intermediateEdgeNames[edge.FromNodeIndex][edge.FromNodeOutputIndex] = builder.CreateString(edge.Name.c_str()); + } + 
nodeToInputNames[edge.ToNodeIndex][edge.ToNodeInputIndex] = intermediateEdgeNames[edge.FromNodeIndex][edge.FromNodeOutputIndex]; + nodeToOutputNames[edge.FromNodeIndex][edge.FromNodeOutputIndex] = intermediateEdgeNames[edge.FromNodeIndex][edge.FromNodeOutputIndex]; + } +} + + +/* +* - If an edge is connected to multiple nodes, then there will be multiple instances +* of input or intermediate edges, all with the same name. +* - The input will be validated incrementally throughout the execution +* of the method. +* - Handling of empty optional input/output/attibute for non-constant node: +* input/output +* - and will have an null entry +* but the actual OperatorNodeDesc variant's +* and will not have any entry. +* attribute +* - will have null entry +*/ +flatbuffers::DetachedBuffer SerializeDmlGraph(const DmlSerializedGraphDesc& graphDesc) +{ + + flatbuffers::FlatBufferBuilder builder(1024); + if (graphDesc.Nodes.empty()) + { + return builder.Release(); + } + + // create input/output edge index to name map + std::unordered_map> graphInputIndexToNameMap = + ConvertToEdgeIndexToNameMap(graphDesc.InputEdges, builder); + std::unordered_map> graphOutputIndexToNameMap = + ConvertToEdgeIndexToNameMap(graphDesc.OutputEdges, builder); + + /* + * - Calculate number of input/output for each operator to allocate + * appropriate amount of memory for each node to store input/output names. + * - Non-constant node's input/output count can be determined by the + * AbstractOperatorDesc. + * - Constant node will only have outgoing edges and those outgoing edges + * will be intermediate edges. + */ + std::vector nodeInputCounts(graphDesc.Nodes.size(), 0); + std::vector nodeOutputCounts(graphDesc.Nodes.size(), 0); + PopulateNonConstantNodeInputOutputCount(graphDesc.Nodes, nodeInputCounts, nodeOutputCounts); + PopulateConstantNodeInputOutputCount(graphDesc.IntermediateEdges, nodeInputCounts, nodeOutputCounts); + + // populate node input/output names. 
+ std::vector>> nodeToInputNames(graphDesc.Nodes.size()); + std::vector>> nodeToOutputNames(graphDesc.Nodes.size()); + for (uint32_t nodeIndex = 0; nodeIndex < static_cast(graphDesc.Nodes.size()); nodeIndex++) + { + nodeToInputNames[nodeIndex].assign(nodeInputCounts[nodeIndex], builder.CreateString(nullptr, 0)); + nodeToOutputNames[nodeIndex].assign(nodeOutputCounts[nodeIndex], builder.CreateString(nullptr, 0)); + } + PopulateNodeInputOutputNames(builder, graphDesc, graphInputIndexToNameMap, graphOutputIndexToNameMap, nodeToInputNames, nodeToOutputNames); + + // Create flatbuffer node objects + std::vector> nodes(graphDesc.Nodes.size()); + for (uint32_t nodeIndex = 0; nodeIndex < static_cast(graphDesc.Nodes.size()); nodeIndex++) + { + nodes[nodeIndex] = SerializeNode( + builder, + nodeIndex, + graphDesc.Nodes[nodeIndex], + nodeToInputNames[nodeIndex], + nodeToOutputNames[nodeIndex]); + } + + // Convert to std::vector to create the object. + std::vector> graphInputNames(graphDesc.InputCount, builder.CreateString(nullptr, 0)); + std::vector> graphOutputNames(graphDesc.OutputCount, builder.CreateString(nullptr, 0)); + for (const auto& [key, value] : graphInputIndexToNameMap) + { + graphInputNames[key] = value; + } + for (const auto& [key, value] : graphOutputIndexToNameMap) + { + graphOutputNames[key] = value; + } + + flatbuffers::Offset dmlGraphDescOffset = dml::ir::CreateDmlGraphDescDirect( + builder, + &nodes, + &graphInputNames, + &graphOutputNames); + builder.Finish(dmlGraphDescOffset); + return builder.Release(); +} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp index 5c7b7bff1e370..0f0d445a95bae 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp @@ -180,32 +180,50 @@ namespace Dml 
// Convert partitionONNXGraph into DML EP GraphDesc ComPtr device; ORT_THROW_IF_FAILED(providerImpl->GetDmlDevice(device.GetAddressOf())); + // This map will be used to transfer the initializer to D3D12 system heap memory. + // 'serializedDmlGraphDesc' will have constant input as intermediate edges, that's why + // we need a mapping between intermediateEdgeIndex and indexedSubGraph's (a given partition) + // input arg index. + // For ex: Let's say intermediate edge index = idx, then + // indexedSubGraphInputArgIdx = constantEdgeIdxToSubgraphInputArgIdxMap[idx]; + // corresponding constant tensor = initializerNameToInitializerMap[indexedSubGraph.GetMetaDef()->inputs[indexedSubGraphInputArgIdx]] + // We are using intermediate edge index as a key because same constant tensor can be used by + // multiple nodes. + std::unordered_map serializedGraphInputIndexToSubgraphInputIndex; + std::unordered_map serializedGraphLargeConstantNameToSubgraphInputIndex; + std::vector> smallConstantData; GraphDescBuilder::GraphDesc graphDesc = GraphDescBuilder::BuildGraphDesc( isInputsUploadedByDmlEP.data(), isInputsUploadedByDmlEP.size(), m_isInitializerTransferable, m_partitionNodePropsMap, - device.Get(), providerImpl, m_modelPath, m_subgraphNodePointers, m_subgraphInputs, - m_subgraphOutputs); + m_subgraphOutputs, + serializedGraphInputIndexToSubgraphInputIndex, + serializedGraphLargeConstantNameToSubgraphInputIndex, + smallConstantData); m_outputShapes = graphDesc.outputShapes; // Walk through each graph edge and mark used inputs m_inputsUsed.resize(fusedNodeInputCount, false); - for (const DML_INPUT_GRAPH_EDGE_DESC& edge : graphDesc.inputEdges) - { - m_inputsUsed[edge.GraphInputIndex] = true; + for (auto it = serializedGraphInputIndexToSubgraphInputIndex.begin(); it != serializedGraphInputIndexToSubgraphInputIndex.end(); it++) { + m_inputsUsed[it->second] = true; + } + for (auto it = serializedGraphLargeConstantNameToSubgraphInputIndex.begin(); it != 
serializedGraphLargeConstantNameToSubgraphInputIndex.end(); it++) { + m_inputsUsed[it->second] = true; } // Compile the operator m_compiledExecutionPlanOperator = DmlGraphFusionHelper::TryCreateCompiledOperator( graphDesc, *m_indexedSubGraph, - providerImpl); + providerImpl, + &serializedGraphInputIndexToSubgraphInputIndex, + &serializedGraphLargeConstantNameToSubgraphInputIndex); // Queue references to objects which must be kept alive until resulting GPU work completes m_winmlProvider->QueueReference(m_compiledExecutionPlanOperator.Get()); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp index 8a32d06534dda..6c347ebdca7c1 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp @@ -771,8 +771,14 @@ namespace Dml !native16BitShaderOpsSupported && IsCustomOpShader(node)) { - nodeContainsSupportedDataTypes = false; - return; + // STFT is a special case since it has a dml ep registered + // graph transformation that will decompose fp16 STFT into convolution + // and so it is OK to register for fp16. + if (strcmp("STFT", node.OpType().c_str()) != 0) + { + nodeContainsSupportedDataTypes = false; + return; + } } // Allow nodeArgs that are SequenceTensor when they are actually implemented by CPU Kernels. 
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h index 5617bc7bdcac6..841d6244a983e 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h @@ -270,7 +270,7 @@ namespace Dml return m_impl->OnSessionInitializationEnd(); } - virtual onnxruntime::Status Sync() const final override + onnxruntime::Status Sync() const final override { // Completely wait until the device has completed all preceding tasks. // The application could have called SynchronizeBoundOutputs(). @@ -278,7 +278,7 @@ namespace Dml return Status::OK(); } - virtual onnxruntime::Status OnRunEnd(bool /*sync_stream*/) final override + onnxruntime::Status OnRunEnd(bool /*sync_stream*/, const onnxruntime::RunOptions& /*run_options*/) final override { // Flush any pending work to the GPU, but don't block for completion, permitting it // to overlap other work. 
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h index a5415ba85f3d3..7c25755a7d09e 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h @@ -24,8 +24,8 @@ struct EnumTraits template <> struct EnumTraits { - static constexpr auto ValueCount = 161; - static constexpr size_t ActivationFunctionCount = 24; + static constexpr auto ValueCount = 168; + static constexpr size_t ActivationFunctionCount = 26; }; template <> @@ -62,7 +62,7 @@ struct EnumTraits template <> struct EnumTraits { - static constexpr auto ValueCount = 4; + static constexpr auto ValueCount = 5; }; template <> @@ -86,7 +86,7 @@ struct EnumTraits template <> struct EnumTraits { - static constexpr auto ValueCount = 8; + static constexpr auto ValueCount = 13; }; template <> @@ -119,6 +119,12 @@ struct EnumTraits static constexpr auto ValueCount = 1; }; +template <> +struct EnumTraits +{ + static constexpr auto ValueCount = 5; +}; + template constexpr auto EnumValueCount = EnumTraits::ValueCount; @@ -495,12 +501,6 @@ struct OperatorDescTraits static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ROI_POOLING; }; -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING; -}; - template <> struct OperatorDescTraits { @@ -879,6 +879,12 @@ struct OperatorDescTraits static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY; }; +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT; +}; + template <> struct OperatorDescTraits { @@ -1029,6 +1035,18 @@ struct OperatorDescTraits static constexpr DML_OPERATOR_TYPE Type 
= DML_OPERATOR_DIAGONAL_MATRIX1; }; +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_MULTIHEAD_ATTENTION; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING; +}; + template <> struct OperatorDescTraits { @@ -1174,9 +1192,15 @@ struct OperatorDescTraits }; template <> -struct OperatorDescTraits +struct OperatorDescTraits { - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_MULTIHEAD_ATTENTION; + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_SWISH; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_HARD_SWISH; }; template @@ -1502,12 +1526,6 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ROI_POOLING> using DescType = DML_ROI_POOLING_OPERATOR_DESC; }; -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING> -{ - using DescType = DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC; -}; - template <> struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_SLICE> { @@ -2036,6 +2054,24 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_DIAGONAL_MATRIX1> using DescType = DML_DIAGONAL_MATRIX1_OPERATOR_DESC; }; +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MULTIHEAD_ATTENTION> +{ + using DescType = DML_MULTIHEAD_ATTENTION_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING> +{ + using DescType = DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT> +{ + using DescType = DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC; +}; + template <> struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_ELU> { @@ -2181,14 +2217,20 @@ struct 
OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_GELU> }; template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MULTIHEAD_ATTENTION> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_SWISH> { - using DescType = DML_MULTIHEAD_ATTENTION_OPERATOR_DESC; + using DescType = DML_ACTIVATION_SWISH_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_HARD_SWISH> +{ + using DescType = DML_ACTIVATION_HARD_SWISH_OPERATOR_DESC; }; // Calls a visitor functor, supplying an empty operator desc corresponding to the given DML_OPERATOR_TYPE as // the first argument. -// +// // For example: // Visit(DML_OPERATOR_ELEMENT_WISE_IDENTITY, [](auto tag) { // using T = decltype(tag); // T is one of the DML_*_OPERATOR_DESC structs @@ -2485,6 +2527,10 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args return std::invoke(std::forward(visitor), DML_DIAGONAL_MATRIX1_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_MULTIHEAD_ATTENTION: return std::invoke(std::forward(visitor), DML_MULTIHEAD_ATTENTION_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING: + return std::invoke(std::forward(visitor), DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT: + return std::invoke(std::forward(visitor), DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_ACTIVATION_ELU: return std::invoke(std::forward(visitor), DML_ACTIVATION_ELU_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_ACTIVATION_CELU: @@ -2533,13 +2579,10 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... 
args return std::invoke(std::forward(visitor), DML_ACTIVATION_SHRINK_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_ACTIVATION_GELU: return std::invoke(std::forward(visitor), DML_ACTIVATION_GELU_OPERATOR_DESC{}, std::forward(args)...); - -#pragma warning(push) -#pragma warning(disable: 4063) - case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING: - return std::invoke(std::forward(visitor), DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC{}, std::forward(args)...); -#pragma warning(pop) - + case DML_OPERATOR_ACTIVATION_SWISH: + return std::invoke(std::forward(visitor), DML_ACTIVATION_SWISH_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ACTIVATION_HARD_SWISH: + return std::invoke(std::forward(visitor), DML_ACTIVATION_HARD_SWISH_OPERATOR_DESC{}, std::forward(args)...); default: ORT_THROW_HR(E_INVALIDARG); return std::invoke(std::forward(visitor), DML_ACTIVATION_RELU_OPERATOR_DESC{}, std::forward(args)...); @@ -2547,7 +2590,55 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args } #pragma warning(pop) +namespace StringifyHelpers +{ +template +inline gsl::czstring ToString(T value) +{ +#ifndef WAI_BUILD_LINUX + // Clang will instantiate this template even if it isn't used, + // so this static_assert will always fire and break the build. 
+ static_assert(false, "Not implemented for this type"); +#endif +} + +template <> +inline gsl::czstring ToString(DML_TENSOR_DATA_TYPE value) +{ + switch (value) + { + case DML_TENSOR_DATA_TYPE_UNKNOWN: return "DML_TENSOR_DATA_TYPE_UNKNOWN"; + case DML_TENSOR_DATA_TYPE_FLOAT32: return "DML_TENSOR_DATA_TYPE_FLOAT32"; + case DML_TENSOR_DATA_TYPE_FLOAT16: return "DML_TENSOR_DATA_TYPE_FLOAT16"; + case DML_TENSOR_DATA_TYPE_UINT32: return "DML_TENSOR_DATA_TYPE_UINT32"; + case DML_TENSOR_DATA_TYPE_UINT16: return "DML_TENSOR_DATA_TYPE_UINT16"; + case DML_TENSOR_DATA_TYPE_UINT8: return "DML_TENSOR_DATA_TYPE_UINT8"; + case DML_TENSOR_DATA_TYPE_INT32: return "DML_TENSOR_DATA_TYPE_INT32"; + case DML_TENSOR_DATA_TYPE_INT16: return "DML_TENSOR_DATA_TYPE_INT16"; + case DML_TENSOR_DATA_TYPE_INT8: return "DML_TENSOR_DATA_TYPE_INT8"; + case DML_TENSOR_DATA_TYPE_FLOAT64: return "DML_TENSOR_DATA_TYPE_FLOAT64"; + case DML_TENSOR_DATA_TYPE_UINT64: return "DML_TENSOR_DATA_TYPE_UINT64"; + case DML_TENSOR_DATA_TYPE_INT64: return "DML_TENSOR_DATA_TYPE_INT64"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_TENSOR_TYPE value) +{ + switch (value) + { + case DML_TENSOR_TYPE_INVALID: return "DML_TENSOR_TYPE_INVALID"; + case DML_TENSOR_TYPE_BUFFER: return "DML_TENSOR_TYPE_BUFFER"; + default: + assert(false); + return ""; + } +} +template <> inline gsl::czstring ToString(DML_OPERATOR_TYPE value) { switch (value) @@ -2561,9 +2652,6 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value) case DML_OPERATOR_ELEMENT_WISE_ATAN: return "DML_OPERATOR_ELEMENT_WISE_ATAN"; case DML_OPERATOR_ELEMENT_WISE_CEIL: return "DML_OPERATOR_ELEMENT_WISE_CEIL"; case DML_OPERATOR_ELEMENT_WISE_CLIP: return "DML_OPERATOR_ELEMENT_WISE_CLIP"; - case DML_OPERATOR_ELEMENT_WISE_CLIP1: return "DML_OPERATOR_ELEMENT_WISE_CLIP1"; - case DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD: return "DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD"; - case DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1: return 
"DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1"; case DML_OPERATOR_ELEMENT_WISE_COS: return "DML_OPERATOR_ELEMENT_WISE_COS"; case DML_OPERATOR_ELEMENT_WISE_DIVIDE: return "DML_OPERATOR_ELEMENT_WISE_DIVIDE"; case DML_OPERATOR_ELEMENT_WISE_EXP: return "DML_OPERATOR_ELEMENT_WISE_EXP"; @@ -2587,24 +2675,41 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value) case DML_OPERATOR_ELEMENT_WISE_RECIP: return "DML_OPERATOR_ELEMENT_WISE_RECIP"; case DML_OPERATOR_ELEMENT_WISE_SIN: return "DML_OPERATOR_ELEMENT_WISE_SIN"; case DML_OPERATOR_ELEMENT_WISE_SQRT: return "DML_OPERATOR_ELEMENT_WISE_SQRT"; - case DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE: return "DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE"; - case DML_OPERATOR_ELEMENT_WISE_ATAN_YX: return "DML_OPERATOR_ELEMENT_WISE_ATAN_YX"; case DML_OPERATOR_ELEMENT_WISE_SUBTRACT: return "DML_OPERATOR_ELEMENT_WISE_SUBTRACT"; case DML_OPERATOR_ELEMENT_WISE_TAN: return "DML_OPERATOR_ELEMENT_WISE_TAN"; case DML_OPERATOR_ELEMENT_WISE_THRESHOLD: return "DML_OPERATOR_ELEMENT_WISE_THRESHOLD"; case DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR: return "DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR"; case DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR: return "DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR"; + case DML_OPERATOR_ACTIVATION_ELU: return "DML_OPERATOR_ACTIVATION_ELU"; + case DML_OPERATOR_ACTIVATION_CELU: return "DML_OPERATOR_ACTIVATION_CELU"; + case DML_OPERATOR_ACTIVATION_HARDMAX: return "DML_OPERATOR_ACTIVATION_HARDMAX"; + case DML_OPERATOR_ACTIVATION_HARDMAX1: return "DML_OPERATOR_ACTIVATION_HARDMAX1"; + case DML_OPERATOR_ACTIVATION_HARD_SIGMOID: return "DML_OPERATOR_ACTIVATION_HARD_SIGMOID"; + case DML_OPERATOR_ACTIVATION_IDENTITY: return "DML_OPERATOR_ACTIVATION_IDENTITY"; + case DML_OPERATOR_ACTIVATION_LEAKY_RELU: return "DML_OPERATOR_ACTIVATION_LEAKY_RELU"; + case DML_OPERATOR_ACTIVATION_LINEAR: return "DML_OPERATOR_ACTIVATION_LINEAR"; + case DML_OPERATOR_ACTIVATION_LOG_SOFTMAX: return "DML_OPERATOR_ACTIVATION_LOG_SOFTMAX"; + case 
DML_OPERATOR_ACTIVATION_LOG_SOFTMAX1: return "DML_OPERATOR_ACTIVATION_LOG_SOFTMAX1"; + case DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU: return "DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU"; + case DML_OPERATOR_ACTIVATION_PARAMETRIC_SOFTPLUS: return "DML_OPERATOR_ACTIVATION_PARAMETRIC_SOFTPLUS"; + case DML_OPERATOR_ACTIVATION_RELU: return "DML_OPERATOR_ACTIVATION_RELU"; + case DML_OPERATOR_ACTIVATION_SCALED_ELU: return "DML_OPERATOR_ACTIVATION_SCALED_ELU"; + case DML_OPERATOR_ACTIVATION_SCALED_TANH: return "DML_OPERATOR_ACTIVATION_SCALED_TANH"; + case DML_OPERATOR_ACTIVATION_SIGMOID: return "DML_OPERATOR_ACTIVATION_SIGMOID"; + case DML_OPERATOR_ACTIVATION_SOFTMAX: return "DML_OPERATOR_ACTIVATION_SOFTMAX"; + case DML_OPERATOR_ACTIVATION_SOFTMAX1: return "DML_OPERATOR_ACTIVATION_SOFTMAX1"; + case DML_OPERATOR_ACTIVATION_SOFTPLUS: return "DML_OPERATOR_ACTIVATION_SOFTPLUS"; + case DML_OPERATOR_ACTIVATION_SOFTSIGN: return "DML_OPERATOR_ACTIVATION_SOFTSIGN"; + case DML_OPERATOR_ACTIVATION_TANH: return "DML_OPERATOR_ACTIVATION_TANH"; + case DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU: return "DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU"; case DML_OPERATOR_CONVOLUTION: return "DML_OPERATOR_CONVOLUTION"; case DML_OPERATOR_GEMM: return "DML_OPERATOR_GEMM"; case DML_OPERATOR_REDUCE: return "DML_OPERATOR_REDUCE"; - case DML_OPERATOR_ARGMIN: return "DML_OPERATOR_ARGMIN"; - case DML_OPERATOR_ARGMAX: return "DML_OPERATOR_ARGMAX"; case DML_OPERATOR_AVERAGE_POOLING: return "DML_OPERATOR_AVERAGE_POOLING"; case DML_OPERATOR_AVERAGE_POOLING1: return "DML_OPERATOR_AVERAGE_POOLING1"; case DML_OPERATOR_LP_POOLING: return "DML_OPERATOR_LP_POOLING"; case DML_OPERATOR_LP_POOLING1: return "DML_OPERATOR_LP_POOLING1"; case DML_OPERATOR_MAX_POOLING: return "DML_OPERATOR_MAX_POOLING"; - case DML_OPERATOR_MAX_POOLING1: return "DML_OPERATOR_MAX_POOLING1"; case DML_OPERATOR_ROI_POOLING: return "DML_OPERATOR_ROI_POOLING"; case DML_OPERATOR_SLICE: return "DML_OPERATOR_SLICE"; case DML_OPERATOR_CAST: return 
"DML_OPERATOR_CAST"; @@ -2620,18 +2725,15 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value) case DML_OPERATOR_TILE: return "DML_OPERATOR_TILE"; case DML_OPERATOR_TOP_K: return "DML_OPERATOR_TOP_K"; case DML_OPERATOR_BATCH_NORMALIZATION: return "DML_OPERATOR_BATCH_NORMALIZATION"; - case DML_OPERATOR_BATCH_NORMALIZATION_GRAD: return "DML_OPERATOR_BATCH_NORMALIZATION_GRAD"; - case DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD: return "DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD"; + case DML_OPERATOR_BATCH_NORMALIZATION_TRAINING: return "DML_OPERATOR_BATCH_NORMALIZATION_TRAINING"; case DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION: return "DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION"; case DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION: return "DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION"; - case DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD: return "DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD"; case DML_OPERATOR_LP_NORMALIZATION: return "DML_OPERATOR_LP_NORMALIZATION"; case DML_OPERATOR_RNN: return "DML_OPERATOR_RNN"; case DML_OPERATOR_LSTM: return "DML_OPERATOR_LSTM"; case DML_OPERATOR_GRU: return "DML_OPERATOR_GRU"; case DML_OPERATOR_ELEMENT_WISE_SIGN: return "DML_OPERATOR_ELEMENT_WISE_SIGN"; case DML_OPERATOR_ELEMENT_WISE_IS_NAN: return "DML_OPERATOR_ELEMENT_WISE_IS_NAN"; - case DML_OPERATOR_ELEMENT_WISE_NEGATE: return "DML_OPERATOR_ELEMENT_WISE_NEGATE"; case DML_OPERATOR_ELEMENT_WISE_ERF: return "DML_OPERATOR_ELEMENT_WISE_ERF"; case DML_OPERATOR_ELEMENT_WISE_SINH: return "DML_OPERATOR_ELEMENT_WISE_SINH"; case DML_OPERATOR_ELEMENT_WISE_COSH: return "DML_OPERATOR_ELEMENT_WISE_COSH"; @@ -2641,6 +2743,8 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value) case DML_OPERATOR_ELEMENT_WISE_ATANH: return "DML_OPERATOR_ELEMENT_WISE_ATANH"; case DML_OPERATOR_ELEMENT_WISE_IF: return "DML_OPERATOR_ELEMENT_WISE_IF"; case DML_OPERATOR_ELEMENT_WISE_ADD1: return "DML_OPERATOR_ELEMENT_WISE_ADD1"; + case DML_OPERATOR_ACTIVATION_SHRINK: return "DML_OPERATOR_ACTIVATION_SHRINK"; 
+ case DML_OPERATOR_MAX_POOLING1: return "DML_OPERATOR_MAX_POOLING1"; case DML_OPERATOR_MAX_UNPOOLING: return "DML_OPERATOR_MAX_UNPOOLING"; case DML_OPERATOR_DIAGONAL_MATRIX: return "DML_OPERATOR_DIAGONAL_MATRIX"; case DML_OPERATOR_SCATTER: return "DML_OPERATOR_SCATTER"; @@ -2652,10 +2756,9 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value) case DML_OPERATOR_ELEMENT_WISE_IS_INFINITY: return "DML_OPERATOR_ELEMENT_WISE_IS_INFINITY"; case DML_OPERATOR_ELEMENT_WISE_MODULUS_TRUNCATE: return "DML_OPERATOR_ELEMENT_WISE_MODULUS_TRUNCATE"; case DML_OPERATOR_ELEMENT_WISE_MODULUS_FLOOR: return "DML_OPERATOR_ELEMENT_WISE_MODULUS_FLOOR"; - case DML_OPERATOR_FILL_VALUE_CONSTANT: return "DML_OPERATOR_FILL_VALUE_CONSTANT"; case DML_OPERATOR_FILL_VALUE_SEQUENCE: return "DML_OPERATOR_FILL_VALUE_SEQUENCE"; + case DML_OPERATOR_FILL_VALUE_CONSTANT: return "DML_OPERATOR_FILL_VALUE_CONSTANT"; case DML_OPERATOR_CUMULATIVE_SUMMATION: return "DML_OPERATOR_CUMULATIVE_SUMMATION"; - case DML_OPERATOR_CUMULATIVE_PRODUCT: return "DML_OPERATOR_CUMULATIVE_PRODUCT"; case DML_OPERATOR_REVERSE_SUBSEQUENCES: return "DML_OPERATOR_REVERSE_SUBSEQUENCES"; case DML_OPERATOR_GATHER_ELEMENTS: return "DML_OPERATOR_GATHER_ELEMENTS"; case DML_OPERATOR_GATHER_ND: return "DML_OPERATOR_GATHER_ND"; @@ -2684,20 +2787,278 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value) case DML_OPERATOR_RESAMPLE_GRAD: return "DML_OPERATOR_RESAMPLE_GRAD"; case DML_OPERATOR_SLICE_GRAD: return "DML_OPERATOR_SLICE_GRAD"; case DML_OPERATOR_ADAM_OPTIMIZER: return "DML_OPERATOR_ADAM_OPTIMIZER"; + case DML_OPERATOR_ARGMIN: return "DML_OPERATOR_ARGMIN"; + case DML_OPERATOR_ARGMAX: return "DML_OPERATOR_ARGMAX"; case DML_OPERATOR_ROI_ALIGN: return "DML_OPERATOR_ROI_ALIGN"; - case DML_OPERATOR_ROI_ALIGN1: return "DML_OPERATOR_ROI_ALIGN1"; case DML_OPERATOR_GATHER_ND1: return "DML_OPERATOR_GATHER_ND1"; - case DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR: return "DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR"; + case 
DML_OPERATOR_ELEMENT_WISE_ATAN_YX: return "DML_OPERATOR_ELEMENT_WISE_ATAN_YX"; + case DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD: return "DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD"; + case DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE: return "DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE"; + case DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD: return "DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD"; + case DML_OPERATOR_CUMULATIVE_PRODUCT: return "DML_OPERATOR_CUMULATIVE_PRODUCT"; + case DML_OPERATOR_BATCH_NORMALIZATION_GRAD: return "DML_OPERATOR_BATCH_NORMALIZATION_GRAD"; + case DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD: return "DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD"; case DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD: return "DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD"; - case DML_OPERATOR_ROI_ALIGN_GRAD: return "DML_OPERATOR_ROI_ALIGN_GRAD"; - case DML_OPERATOR_BATCH_NORMALIZATION_TRAINING: return "DML_OPERATOR_BATCH_NORMALIZATION_TRAINING"; + case DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR: return "DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR"; + case DML_OPERATOR_ROI_ALIGN1: return "DML_OPERATOR_ROI_ALIGN1"; + case DML_OPERATOR_ELEMENT_WISE_CLIP1: return "DML_OPERATOR_ELEMENT_WISE_CLIP1"; + case DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1: return "DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1"; + case DML_OPERATOR_ELEMENT_WISE_NEGATE: return "DML_OPERATOR_ELEMENT_WISE_NEGATE"; + case DML_OPERATOR_ACTIVATION_GELU: return "DML_OPERATOR_ACTIVATION_GELU"; + case DML_OPERATOR_ACTIVATION_SWISH: return "DML_OPERATOR_ACTIVATION_SWISH"; + case DML_OPERATOR_ACTIVATION_HARD_SWISH: return "DML_OPERATOR_ACTIVATION_HARD_SWISH"; case DML_OPERATOR_RESAMPLE2: return "DML_OPERATOR_RESAMPLE2"; case DML_OPERATOR_RESAMPLE_GRAD1: return "DML_OPERATOR_RESAMPLE_GRAD1"; case DML_OPERATOR_DIAGONAL_MATRIX1: return "DML_OPERATOR_DIAGONAL_MATRIX1"; case DML_OPERATOR_MULTIHEAD_ATTENTION: return "DML_OPERATOR_MULTIHEAD_ATTENTION"; + case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING: return 
"DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING"; + case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT: return "DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_BINDING_TYPE value) +{ + switch (value) + { + case DML_BINDING_TYPE_NONE: return "DML_BINDING_TYPE_NONE"; + case DML_BINDING_TYPE_BUFFER: return "DML_BINDING_TYPE_BUFFER"; + case DML_BINDING_TYPE_BUFFER_ARRAY: return "DML_BINDING_TYPE_BUFFER_ARRAY"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_REDUCE_FUNCTION value) +{ + switch (value) + { + case DML_REDUCE_FUNCTION_ARGMAX: return "DML_REDUCE_FUNCTION_ARGMAX"; + case DML_REDUCE_FUNCTION_ARGMIN: return "DML_REDUCE_FUNCTION_ARGMIN"; + case DML_REDUCE_FUNCTION_AVERAGE: return "DML_REDUCE_FUNCTION_AVERAGE"; + case DML_REDUCE_FUNCTION_L1: return "DML_REDUCE_FUNCTION_L1"; + case DML_REDUCE_FUNCTION_L2: return "DML_REDUCE_FUNCTION_L2"; + case DML_REDUCE_FUNCTION_LOG_SUM: return "DML_REDUCE_FUNCTION_LOG_SUM"; + case DML_REDUCE_FUNCTION_LOG_SUM_EXP: return "DML_REDUCE_FUNCTION_LOG_SUM_EXP"; + case DML_REDUCE_FUNCTION_MAX: return "DML_REDUCE_FUNCTION_MAX"; + case DML_REDUCE_FUNCTION_MIN: return "DML_REDUCE_FUNCTION_MIN"; + case DML_REDUCE_FUNCTION_MULTIPLY: return "DML_REDUCE_FUNCTION_MULTIPLY"; + case DML_REDUCE_FUNCTION_SUM: return "DML_REDUCE_FUNCTION_SUM"; + case DML_REDUCE_FUNCTION_SUM_SQUARE: return "DML_REDUCE_FUNCTION_SUM_SQUARE"; default: assert(false); return ""; } } + +template <> +inline gsl::czstring ToString(DML_MATRIX_TRANSFORM value) +{ + switch (value) + { + case DML_MATRIX_TRANSFORM_NONE: return "DML_MATRIX_TRANSFORM_NONE"; + case DML_MATRIX_TRANSFORM_TRANSPOSE: return "DML_MATRIX_TRANSFORM_TRANSPOSE"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_CONVOLUTION_MODE value) +{ + switch (value) + { + case DML_CONVOLUTION_MODE_CONVOLUTION: return 
"DML_CONVOLUTION_MODE_CONVOLUTION"; + case DML_CONVOLUTION_MODE_CROSS_CORRELATION: return "DML_CONVOLUTION_MODE_CROSS_CORRELATION"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_CONVOLUTION_DIRECTION value) +{ + switch (value) + { + case DML_CONVOLUTION_DIRECTION_FORWARD: return "DML_CONVOLUTION_DIRECTION_FORWARD"; + case DML_CONVOLUTION_DIRECTION_BACKWARD: return "DML_CONVOLUTION_DIRECTION_BACKWARD"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_PADDING_MODE value) +{ + switch (value) + { + case DML_PADDING_MODE_CONSTANT: return "DML_PADDING_MODE_CONSTANT"; + case DML_PADDING_MODE_EDGE: return "DML_PADDING_MODE_EDGE"; + case DML_PADDING_MODE_REFLECTION: return "DML_PADDING_MODE_REFLECTION"; + case DML_PADDING_MODE_SYMMETRIC: return "DML_PADDING_MODE_SYMMETRIC"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_INTERPOLATION_MODE value) +{ + switch (value) + { + case DML_INTERPOLATION_MODE_NEAREST_NEIGHBOR: return "DML_INTERPOLATION_MODE_NEAREST_NEIGHBOR"; + case DML_INTERPOLATION_MODE_LINEAR: return "DML_INTERPOLATION_MODE_LINEAR"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_RECURRENT_NETWORK_DIRECTION value) +{ + switch (value) + { + case DML_RECURRENT_NETWORK_DIRECTION_FORWARD: return "DML_RECURRENT_NETWORK_DIRECTION_FORWARD"; + case DML_RECURRENT_NETWORK_DIRECTION_BACKWARD: return "DML_RECURRENT_NETWORK_DIRECTION_BACKWARD"; + case DML_RECURRENT_NETWORK_DIRECTION_BIDIRECTIONAL: return "DML_RECURRENT_NETWORK_DIRECTION_BIDIRECTIONAL"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_FEATURE value) +{ + switch (value) + { + case DML_FEATURE_TENSOR_DATA_TYPE_SUPPORT: return "DML_FEATURE_TENSOR_DATA_TYPE_SUPPORT"; + case DML_FEATURE_FEATURE_LEVELS: return "DML_FEATURE_FEATURE_LEVELS"; + default: + assert(false); + 
return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_FEATURE_LEVEL value) +{ + switch (value) + { + case DML_FEATURE_LEVEL_1_0: return "DML_FEATURE_LEVEL_1_0"; + case DML_FEATURE_LEVEL_2_0: return "DML_FEATURE_LEVEL_2_0"; + case DML_FEATURE_LEVEL_2_1: return "DML_FEATURE_LEVEL_2_1"; + case DML_FEATURE_LEVEL_3_0: return "DML_FEATURE_LEVEL_3_0"; + case DML_FEATURE_LEVEL_3_1: return "DML_FEATURE_LEVEL_3_1"; + case DML_FEATURE_LEVEL_4_0: return "DML_FEATURE_LEVEL_4_0"; + case DML_FEATURE_LEVEL_4_1: return "DML_FEATURE_LEVEL_4_1"; + case DML_FEATURE_LEVEL_5_0: return "DML_FEATURE_LEVEL_5_0"; + case DML_FEATURE_LEVEL_5_1: return "DML_FEATURE_LEVEL_5_1"; + case DML_FEATURE_LEVEL_5_2: return "DML_FEATURE_LEVEL_5_2"; + case DML_FEATURE_LEVEL_6_0: return "DML_FEATURE_LEVEL_6_0"; + case DML_FEATURE_LEVEL_6_1: return "DML_FEATURE_LEVEL_6_1"; + case DML_FEATURE_LEVEL_6_2: return "DML_FEATURE_LEVEL_6_2"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_IS_INFINITY_MODE value) +{ + switch (value) + { + case DML_IS_INFINITY_MODE_EITHER: return "DML_IS_INFINITY_MODE_EITHER"; + case DML_IS_INFINITY_MODE_POSITIVE: return "DML_IS_INFINITY_MODE_POSITIVE"; + case DML_IS_INFINITY_MODE_NEGATIVE: return "DML_IS_INFINITY_MODE_NEGATIVE"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_DEPTH_SPACE_ORDER value) +{ + switch (value) + { + case DML_DEPTH_SPACE_ORDER_DEPTH_COLUMN_ROW: return "DML_DEPTH_SPACE_ORDER_DEPTH_COLUMN_ROW"; + case DML_DEPTH_SPACE_ORDER_COLUMN_ROW_DEPTH: return "DML_DEPTH_SPACE_ORDER_COLUMN_ROW_DEPTH"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_AXIS_DIRECTION value) +{ + switch (value) + { + case DML_AXIS_DIRECTION_INCREASING: return "DML_AXIS_DIRECTION_INCREASING"; + case DML_AXIS_DIRECTION_DECREASING: return "DML_AXIS_DIRECTION_DECREASING"; + default: + assert(false); + return ""; + } +} + +template <> 
+inline gsl::czstring ToString(DML_ROUNDING_MODE value) +{ + switch (value) + { + case DML_ROUNDING_MODE_HALVES_TO_NEAREST_EVEN: return "DML_ROUNDING_MODE_HALVES_TO_NEAREST_EVEN"; + case DML_ROUNDING_MODE_TOWARD_ZERO: return "DML_ROUNDING_MODE_TOWARD_ZERO"; + case DML_ROUNDING_MODE_TOWARD_INFINITY: return "DML_ROUNDING_MODE_TOWARD_INFINITY"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_RANDOM_GENERATOR_TYPE value) +{ + switch (value) + { + case DML_RANDOM_GENERATOR_TYPE_PHILOX_4X32_10: return "DML_RANDOM_GENERATOR_TYPE_PHILOX_4X32_10"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_MULTIHEAD_ATTENTION_MASK_TYPE value) +{ + switch (value) + { + case DML_MULTIHEAD_ATTENTION_MASK_TYPE_NONE: return "DML_MULTIHEAD_ATTENTION_MASK_TYPE_NONE"; + case DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_LENGTH: return "DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_LENGTH"; + case DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_END_START: return "DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_END_START"; + case DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_QUERY_SEQUENCE_LENGTH_START_END: return "DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_QUERY_SEQUENCE_LENGTH_START_END"; + case DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN: return "DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN"; + default: + assert(false); + return ""; + } +} + + +template +T FromString(std::string_view value); + +} } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h index 2a82c12872a72..64ea5b7801a84 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h @@ -618,7 +618,7 @@ constexpr DML_OPERATOR_SCHEMA 
DML_ELEMENT_WISE_THRESHOLD_OPERATOR_SCHEMA { constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_SCHEMA_FIELDS[4] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ScaleTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ZeroPointTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ZeroPointTensor", true }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, }; @@ -633,7 +633,7 @@ constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_SCHEMA { constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_SCHEMA_FIELDS[4] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ScaleTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ZeroPointTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ZeroPointTensor", true }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, }; @@ -869,31 +869,6 @@ constexpr DML_OPERATOR_SCHEMA DML_ROI_POOLING_OPERATOR_SCHEMA { DML_ROI_POOLING_OPERATOR_SCHEMA_FIELDS, }; - -constexpr DML_SCHEMA_FIELD DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS[13] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputScaleTensor", false }, - DML_SCHEMA_FIELD { 
DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputZeroPointTensor", true }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputScaleTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputZeroPointTensor", true }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "WindowSize", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Dilations", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "IncludePadding", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA { - "DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING", - static_cast(DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING), - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, - 13, - DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS, -}; - constexpr DML_SCHEMA_FIELD DML_SLICE_OPERATOR_SCHEMA_FIELDS[6] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, @@ -1146,7 +1121,7 @@ constexpr DML_SCHEMA_FIELD DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_SCHEMA DML_SCHEMA_FIELD { 
DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputGradientTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputScaleGradientTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputBiasGradientTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Epsilon", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Epsilon", false }, }; constexpr DML_OPERATOR_SCHEMA DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_SCHEMA { @@ -1890,6 +1865,25 @@ constexpr DML_OPERATOR_SCHEMA DML_MATRIX_MULTIPLY_INTEGER_OPERATOR_SCHEMA { DML_MATRIX_MULTIPLY_INTEGER_OPERATOR_SCHEMA_FIELDS, }; +constexpr DML_SCHEMA_FIELD DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA_FIELDS[8] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AZeroPointTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BZeroPointTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BiasTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA { + "DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT", + 
static_cast(DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT), + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 8, + DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA_FIELDS, +}; + constexpr DML_SCHEMA_FIELD DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA_FIELDS[9] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AScaleTensor", false }, @@ -2312,7 +2306,7 @@ constexpr DML_OPERATOR_SCHEMA DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_SCHEMA { DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_SCHEMA_FIELDS, }; -constexpr DML_SCHEMA_FIELD DML_RESAMPLE2_OPERATOR_SCHEMA_FIELDS[8]{ +constexpr DML_SCHEMA_FIELD DML_RESAMPLE2_OPERATOR_SCHEMA_FIELDS[8] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "InterpolationMode", false }, @@ -2323,7 +2317,7 @@ constexpr DML_SCHEMA_FIELD DML_RESAMPLE2_OPERATOR_SCHEMA_FIELDS[8]{ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT_ARRAY, "OutputPixelOffsets", false }, }; -constexpr DML_OPERATOR_SCHEMA DML_RESAMPLE2_OPERATOR_SCHEMA{ +constexpr DML_OPERATOR_SCHEMA DML_RESAMPLE2_OPERATOR_SCHEMA { "DML_OPERATOR_RESAMPLE2", DML_OPERATOR_RESAMPLE2, DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, @@ -2342,7 +2336,7 @@ constexpr DML_SCHEMA_FIELD DML_RESAMPLE_GRAD1_OPERATOR_SCHEMA_FIELDS[8]{ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT_ARRAY, "OutputPixelOffsets", false }, }; -constexpr DML_OPERATOR_SCHEMA DML_RESAMPLE_GRAD1_OPERATOR_SCHEMA{ +constexpr DML_OPERATOR_SCHEMA DML_RESAMPLE_GRAD1_OPERATOR_SCHEMA { "DML_OPERATOR_RESAMPLE_GRAD1", DML_OPERATOR_RESAMPLE_GRAD1, 
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, @@ -2350,7 +2344,7 @@ constexpr DML_OPERATOR_SCHEMA DML_RESAMPLE_GRAD1_OPERATOR_SCHEMA{ DML_RESAMPLE_GRAD1_OPERATOR_SCHEMA_FIELDS, }; -constexpr DML_SCHEMA_FIELD DML_DIAGONAL_MATRIX1_OPERATOR_SCHEMA_FIELDS[6]{ +constexpr DML_SCHEMA_FIELD DML_DIAGONAL_MATRIX1_OPERATOR_SCHEMA_FIELDS[6] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", true }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "ValueDataType", false }, @@ -2359,7 +2353,7 @@ constexpr DML_SCHEMA_FIELD DML_DIAGONAL_MATRIX1_OPERATOR_SCHEMA_FIELDS[6]{ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_INT, "DiagonalFillEnd", false }, }; -constexpr DML_OPERATOR_SCHEMA DML_DIAGONAL_MATRIX1_OPERATOR_SCHEMA{ +constexpr DML_OPERATOR_SCHEMA DML_DIAGONAL_MATRIX1_OPERATOR_SCHEMA { "DML_OPERATOR_DIAGONAL_MATRIX1", DML_OPERATOR_DIAGONAL_MATRIX1, DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, @@ -2396,6 +2390,30 @@ constexpr DML_OPERATOR_SCHEMA DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA { DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA_FIELDS, }; +constexpr DML_SCHEMA_FIELD DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS[13] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputZeroPointTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputZeroPointTensor", true }, + DML_SCHEMA_FIELD { 
DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "WindowSize", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Dilations", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "IncludePadding", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA { + "DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING", + DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 13, + DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS, +}; + constexpr DML_SCHEMA_FIELD DML_ACTIVATION_ELU_OPERATOR_SCHEMA_FIELDS[3] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, @@ -2732,6 +2750,35 @@ constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_GELU_OPERATOR_SCHEMA { DML_ACTIVATION_GELU_OPERATOR_SCHEMA_FIELDS, }; +constexpr DML_SCHEMA_FIELD DML_ACTIVATION_SWISH_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, 
DML_SCHEMA_FIELD_TYPE_FLOAT, "SigmoidInputScale", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_SWISH_OPERATOR_SCHEMA { + "DML_OPERATOR_ACTIVATION_SWISH", + DML_OPERATOR_ACTIVATION_SWISH, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 3, + DML_ACTIVATION_SWISH_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ACTIVATION_HARD_SWISH_OPERATOR_SCHEMA_FIELDS[4] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Alpha", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Beta", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_HARD_SWISH_OPERATOR_SCHEMA { + "DML_OPERATOR_ACTIVATION_HARD_SWISH", + DML_OPERATOR_ACTIVATION_HARD_SWISH, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 4, + DML_ACTIVATION_HARD_SWISH_OPERATOR_SCHEMA_FIELDS, +}; + constexpr DML_SCHEMA_FIELD DML_RNN_ZERO_OPERATOR_SCHEMA_FIELDS[3] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "SequenceLengthsTensor", false }, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphDesc_generated.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphDesc_generated.h new file mode 100644 index 0000000000000..df485396f1e47 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphDesc_generated.h @@ -0,0 +1,850 @@ +// automatically generated by the FlatBuffers compiler, do not modify + + +#ifndef FLATBUFFERS_GENERATED_DMLGRAPHDESC_DML_IR_H_ +#define FLATBUFFERS_GENERATED_DMLGRAPHDESC_DML_IR_H_ + 
+#include "core/common/flatbuffers.h" + +// Ensure the included flatbuffers.h is the same version as when this file was +// generated, otherwise it may not be compatible. +static_assert(FLATBUFFERS_VERSION_MAJOR == 23 && + FLATBUFFERS_VERSION_MINOR == 5 && + FLATBUFFERS_VERSION_REVISION == 26, + "Non-compatible flatbuffers version included"); + +#include "OperatorFieldTypes_generated.h" + +namespace dml { +namespace ir { + +struct ConstantRawData; +struct ConstantRawDataBuilder; + +struct ConstantName; +struct ConstantNameBuilder; + +struct ConstantNodeDesc; +struct ConstantNodeDescBuilder; + +struct DmlBufferTensorDesc; +struct DmlBufferTensorDescBuilder; + +struct OperatorNodeDesc; +struct OperatorNodeDescBuilder; + +struct DmlGraphNode; +struct DmlGraphNodeBuilder; + +struct DmlGraphDesc; +struct DmlGraphDescBuilder; + +enum ConstantNodeDescDetail : uint8_t { + ConstantNodeDescDetail_NONE = 0, + ConstantNodeDescDetail_ConstantName = 1, + ConstantNodeDescDetail_ConstantRawData = 2, + ConstantNodeDescDetail_MIN = ConstantNodeDescDetail_NONE, + ConstantNodeDescDetail_MAX = ConstantNodeDescDetail_ConstantRawData +}; + +inline const ConstantNodeDescDetail (&EnumValuesConstantNodeDescDetail())[3] { + static const ConstantNodeDescDetail values[] = { + ConstantNodeDescDetail_NONE, + ConstantNodeDescDetail_ConstantName, + ConstantNodeDescDetail_ConstantRawData + }; + return values; +} + +inline const char * const *EnumNamesConstantNodeDescDetail() { + static const char * const names[4] = { + "NONE", + "ConstantName", + "ConstantRawData", + nullptr + }; + return names; +} + +inline const char *EnumNameConstantNodeDescDetail(ConstantNodeDescDetail e) { + if (::flatbuffers::IsOutRange(e, ConstantNodeDescDetail_NONE, ConstantNodeDescDetail_ConstantRawData)) return ""; + const size_t index = static_cast(e); + return EnumNamesConstantNodeDescDetail()[index]; +} + +template struct ConstantNodeDescDetailTraits { + static const ConstantNodeDescDetail enum_value = 
ConstantNodeDescDetail_NONE; +}; + +template<> struct ConstantNodeDescDetailTraits { + static const ConstantNodeDescDetail enum_value = ConstantNodeDescDetail_ConstantName; +}; + +template<> struct ConstantNodeDescDetailTraits { + static const ConstantNodeDescDetail enum_value = ConstantNodeDescDetail_ConstantRawData; +}; + +bool VerifyConstantNodeDescDetail(::flatbuffers::Verifier &verifier, const void *obj, ConstantNodeDescDetail type); +bool VerifyConstantNodeDescDetailVector(::flatbuffers::Verifier &verifier, const ::flatbuffers::Vector<::flatbuffers::Offset> *values, const ::flatbuffers::Vector *types); + +enum NodeDesc : uint8_t { + NodeDesc_NONE = 0, + NodeDesc_OperatorNodeDesc = 1, + NodeDesc_ConstantNodeDesc = 2, + NodeDesc_MIN = NodeDesc_NONE, + NodeDesc_MAX = NodeDesc_ConstantNodeDesc +}; + +inline const NodeDesc (&EnumValuesNodeDesc())[3] { + static const NodeDesc values[] = { + NodeDesc_NONE, + NodeDesc_OperatorNodeDesc, + NodeDesc_ConstantNodeDesc + }; + return values; +} + +inline const char * const *EnumNamesNodeDesc() { + static const char * const names[4] = { + "NONE", + "OperatorNodeDesc", + "ConstantNodeDesc", + nullptr + }; + return names; +} + +inline const char *EnumNameNodeDesc(NodeDesc e) { + if (::flatbuffers::IsOutRange(e, NodeDesc_NONE, NodeDesc_ConstantNodeDesc)) return ""; + const size_t index = static_cast(e); + return EnumNamesNodeDesc()[index]; +} + +template struct NodeDescTraits { + static const NodeDesc enum_value = NodeDesc_NONE; +}; + +template<> struct NodeDescTraits { + static const NodeDesc enum_value = NodeDesc_OperatorNodeDesc; +}; + +template<> struct NodeDescTraits { + static const NodeDesc enum_value = NodeDesc_ConstantNodeDesc; +}; + +bool VerifyNodeDesc(::flatbuffers::Verifier &verifier, const void *obj, NodeDesc type); +bool VerifyNodeDescVector(::flatbuffers::Verifier &verifier, const ::flatbuffers::Vector<::flatbuffers::Offset> *values, const ::flatbuffers::Vector *types); + +struct ConstantRawData 
FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef ConstantRawDataBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_DATA = 4 + }; + const ::flatbuffers::Vector *data() const { + return GetPointer *>(VT_DATA); + } + ::flatbuffers::Vector *mutable_data() { + return GetPointer<::flatbuffers::Vector *>(VT_DATA); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_DATA) && + verifier.VerifyVector(data()) && + verifier.EndTable(); + } +}; + +struct ConstantRawDataBuilder { + typedef ConstantRawData Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_data(::flatbuffers::Offset<::flatbuffers::Vector> data) { + fbb_.AddOffset(ConstantRawData::VT_DATA, data); + } + explicit ConstantRawDataBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateConstantRawData( + ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::Vector> data = 0) { + ConstantRawDataBuilder builder_(_fbb); + builder_.add_data(data); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset CreateConstantRawDataDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *data = nullptr) { + auto data__ = data ? 
_fbb.CreateVector(*data) : 0; + return dml::ir::CreateConstantRawData( + _fbb, + data__); +} + +struct ConstantName FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef ConstantNameBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_NAME = 4 + }; + const ::flatbuffers::String *name() const { + return GetPointer(VT_NAME); + } + ::flatbuffers::String *mutable_name() { + return GetPointer<::flatbuffers::String *>(VT_NAME); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_NAME) && + verifier.VerifyString(name()) && + verifier.EndTable(); + } +}; + +struct ConstantNameBuilder { + typedef ConstantName Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_name(::flatbuffers::Offset<::flatbuffers::String> name) { + fbb_.AddOffset(ConstantName::VT_NAME, name); + } + explicit ConstantNameBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateConstantName( + ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::String> name = 0) { + ConstantNameBuilder builder_(_fbb); + builder_.add_name(name); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset CreateConstantNameDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + const char *name = nullptr) { + auto name__ = name ? 
_fbb.CreateString(name) : 0; + return dml::ir::CreateConstantName( + _fbb, + name__); +} + +struct ConstantNodeDesc FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef ConstantNodeDescBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_DATA_TYPE = 4, + VT_DATA = 6 + }; + dml::ir::ConstantNodeDescDetail data_type() const { + return static_cast(GetField(VT_DATA_TYPE, 0)); + } + const void *data() const { + return GetPointer(VT_DATA); + } + template const T *data_as() const; + const dml::ir::ConstantName *data_as_ConstantName() const { + return data_type() == dml::ir::ConstantNodeDescDetail_ConstantName ? static_cast(data()) : nullptr; + } + const dml::ir::ConstantRawData *data_as_ConstantRawData() const { + return data_type() == dml::ir::ConstantNodeDescDetail_ConstantRawData ? static_cast(data()) : nullptr; + } + void *mutable_data() { + return GetPointer(VT_DATA); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_DATA_TYPE, 1) && + VerifyOffset(verifier, VT_DATA) && + VerifyConstantNodeDescDetail(verifier, data(), data_type()) && + verifier.EndTable(); + } +}; + +template<> inline const dml::ir::ConstantName *ConstantNodeDesc::data_as() const { + return data_as_ConstantName(); +} + +template<> inline const dml::ir::ConstantRawData *ConstantNodeDesc::data_as() const { + return data_as_ConstantRawData(); +} + +struct ConstantNodeDescBuilder { + typedef ConstantNodeDesc Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_data_type(dml::ir::ConstantNodeDescDetail data_type) { + fbb_.AddElement(ConstantNodeDesc::VT_DATA_TYPE, static_cast(data_type), 0); + } + void add_data(::flatbuffers::Offset data) { + fbb_.AddOffset(ConstantNodeDesc::VT_DATA, data); + } + explicit ConstantNodeDescBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + 
::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateConstantNodeDesc( + ::flatbuffers::FlatBufferBuilder &_fbb, + dml::ir::ConstantNodeDescDetail data_type = dml::ir::ConstantNodeDescDetail_NONE, + ::flatbuffers::Offset data = 0) { + ConstantNodeDescBuilder builder_(_fbb); + builder_.add_data(data); + builder_.add_data_type(data_type); + return builder_.Finish(); +} + +struct DmlBufferTensorDesc FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef DmlBufferTensorDescBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_DATATYPE = 4, + VT_SIZES = 6, + VT_STRIDES = 8, + VT_TOTALTENSORSIZEINBYTES = 10 + }; + const ::flatbuffers::String *dataType() const { + return GetPointer(VT_DATATYPE); + } + ::flatbuffers::String *mutable_dataType() { + return GetPointer<::flatbuffers::String *>(VT_DATATYPE); + } + const ::flatbuffers::Vector *sizes() const { + return GetPointer *>(VT_SIZES); + } + ::flatbuffers::Vector *mutable_sizes() { + return GetPointer<::flatbuffers::Vector *>(VT_SIZES); + } + const ::flatbuffers::Vector *strides() const { + return GetPointer *>(VT_STRIDES); + } + ::flatbuffers::Vector *mutable_strides() { + return GetPointer<::flatbuffers::Vector *>(VT_STRIDES); + } + uint64_t totalTensorSizeInBytes() const { + return GetField(VT_TOTALTENSORSIZEINBYTES, 0); + } + bool mutate_totalTensorSizeInBytes(uint64_t _totalTensorSizeInBytes = 0) { + return SetField(VT_TOTALTENSORSIZEINBYTES, _totalTensorSizeInBytes, 0); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_DATATYPE) && + verifier.VerifyString(dataType()) && + VerifyOffset(verifier, VT_SIZES) && + verifier.VerifyVector(sizes()) && + VerifyOffset(verifier, VT_STRIDES) && + verifier.VerifyVector(strides()) && + VerifyField(verifier, VT_TOTALTENSORSIZEINBYTES, 8) && 
+ verifier.EndTable(); + } +}; + +struct DmlBufferTensorDescBuilder { + typedef DmlBufferTensorDesc Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_dataType(::flatbuffers::Offset<::flatbuffers::String> dataType) { + fbb_.AddOffset(DmlBufferTensorDesc::VT_DATATYPE, dataType); + } + void add_sizes(::flatbuffers::Offset<::flatbuffers::Vector> sizes) { + fbb_.AddOffset(DmlBufferTensorDesc::VT_SIZES, sizes); + } + void add_strides(::flatbuffers::Offset<::flatbuffers::Vector> strides) { + fbb_.AddOffset(DmlBufferTensorDesc::VT_STRIDES, strides); + } + void add_totalTensorSizeInBytes(uint64_t totalTensorSizeInBytes) { + fbb_.AddElement(DmlBufferTensorDesc::VT_TOTALTENSORSIZEINBYTES, totalTensorSizeInBytes, 0); + } + explicit DmlBufferTensorDescBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateDmlBufferTensorDesc( + ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::String> dataType = 0, + ::flatbuffers::Offset<::flatbuffers::Vector> sizes = 0, + ::flatbuffers::Offset<::flatbuffers::Vector> strides = 0, + uint64_t totalTensorSizeInBytes = 0) { + DmlBufferTensorDescBuilder builder_(_fbb); + builder_.add_totalTensorSizeInBytes(totalTensorSizeInBytes); + builder_.add_strides(strides); + builder_.add_sizes(sizes); + builder_.add_dataType(dataType); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset CreateDmlBufferTensorDescDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + const char *dataType = nullptr, + const std::vector *sizes = nullptr, + const std::vector *strides = nullptr, + uint64_t totalTensorSizeInBytes = 0) { + auto dataType__ = dataType ? _fbb.CreateString(dataType) : 0; + auto sizes__ = sizes ? _fbb.CreateVector(*sizes) : 0; + auto strides__ = strides ? 
_fbb.CreateVector(*strides) : 0; + return dml::ir::CreateDmlBufferTensorDesc( + _fbb, + dataType__, + sizes__, + strides__, + totalTensorSizeInBytes); +} + +struct OperatorNodeDesc FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef OperatorNodeDescBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_TYPE = 4, + VT_INPUTS = 6, + VT_OUTPUTS = 8, + VT_ATTRIBUTES = 10 + }; + const ::flatbuffers::String *type() const { + return GetPointer(VT_TYPE); + } + ::flatbuffers::String *mutable_type() { + return GetPointer<::flatbuffers::String *>(VT_TYPE); + } + const ::flatbuffers::Vector<::flatbuffers::Offset> *inputs() const { + return GetPointer> *>(VT_INPUTS); + } + ::flatbuffers::Vector<::flatbuffers::Offset> *mutable_inputs() { + return GetPointer<::flatbuffers::Vector<::flatbuffers::Offset> *>(VT_INPUTS); + } + const ::flatbuffers::Vector<::flatbuffers::Offset> *outputs() const { + return GetPointer> *>(VT_OUTPUTS); + } + ::flatbuffers::Vector<::flatbuffers::Offset> *mutable_outputs() { + return GetPointer<::flatbuffers::Vector<::flatbuffers::Offset> *>(VT_OUTPUTS); + } + const ::flatbuffers::Vector<::flatbuffers::Offset> *attributes() const { + return GetPointer> *>(VT_ATTRIBUTES); + } + ::flatbuffers::Vector<::flatbuffers::Offset> *mutable_attributes() { + return GetPointer<::flatbuffers::Vector<::flatbuffers::Offset> *>(VT_ATTRIBUTES); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_TYPE) && + verifier.VerifyString(type()) && + VerifyOffset(verifier, VT_INPUTS) && + verifier.VerifyVector(inputs()) && + verifier.VerifyVectorOfTables(inputs()) && + VerifyOffset(verifier, VT_OUTPUTS) && + verifier.VerifyVector(outputs()) && + verifier.VerifyVectorOfTables(outputs()) && + VerifyOffset(verifier, VT_ATTRIBUTES) && + verifier.VerifyVector(attributes()) && + verifier.VerifyVectorOfTables(attributes()) && + verifier.EndTable(); + } +}; + +struct 
OperatorNodeDescBuilder { + typedef OperatorNodeDesc Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_type(::flatbuffers::Offset<::flatbuffers::String> type) { + fbb_.AddOffset(OperatorNodeDesc::VT_TYPE, type); + } + void add_inputs(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> inputs) { + fbb_.AddOffset(OperatorNodeDesc::VT_INPUTS, inputs); + } + void add_outputs(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> outputs) { + fbb_.AddOffset(OperatorNodeDesc::VT_OUTPUTS, outputs); + } + void add_attributes(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> attributes) { + fbb_.AddOffset(OperatorNodeDesc::VT_ATTRIBUTES, attributes); + } + explicit OperatorNodeDescBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateOperatorNodeDesc( + ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::String> type = 0, + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> inputs = 0, + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> outputs = 0, + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> attributes = 0) { + OperatorNodeDescBuilder builder_(_fbb); + builder_.add_attributes(attributes); + builder_.add_outputs(outputs); + builder_.add_inputs(inputs); + builder_.add_type(type); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset CreateOperatorNodeDescDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + const char *type = nullptr, + const std::vector<::flatbuffers::Offset> *inputs = nullptr, + const std::vector<::flatbuffers::Offset> *outputs = nullptr, + const std::vector<::flatbuffers::Offset> *attributes = nullptr) { + auto type__ = type ? 
_fbb.CreateString(type) : 0; + auto inputs__ = inputs ? _fbb.CreateVector<::flatbuffers::Offset>(*inputs) : 0; + auto outputs__ = outputs ? _fbb.CreateVector<::flatbuffers::Offset>(*outputs) : 0; + auto attributes__ = attributes ? _fbb.CreateVector<::flatbuffers::Offset>(*attributes) : 0; + return dml::ir::CreateOperatorNodeDesc( + _fbb, + type__, + inputs__, + outputs__, + attributes__); +} + +struct DmlGraphNode FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef DmlGraphNodeBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_DESC_TYPE = 4, + VT_DESC = 6, + VT_NAME = 8, + VT_INPUTNAMES = 10, + VT_OUTPUTNAMES = 12 + }; + dml::ir::NodeDesc desc_type() const { + return static_cast(GetField(VT_DESC_TYPE, 0)); + } + const void *desc() const { + return GetPointer(VT_DESC); + } + template const T *desc_as() const; + const dml::ir::OperatorNodeDesc *desc_as_OperatorNodeDesc() const { + return desc_type() == dml::ir::NodeDesc_OperatorNodeDesc ? static_cast(desc()) : nullptr; + } + const dml::ir::ConstantNodeDesc *desc_as_ConstantNodeDesc() const { + return desc_type() == dml::ir::NodeDesc_ConstantNodeDesc ? 
static_cast(desc()) : nullptr; + } + void *mutable_desc() { + return GetPointer(VT_DESC); + } + const ::flatbuffers::String *name() const { + return GetPointer(VT_NAME); + } + ::flatbuffers::String *mutable_name() { + return GetPointer<::flatbuffers::String *>(VT_NAME); + } + const ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>> *inputNames() const { + return GetPointer> *>(VT_INPUTNAMES); + } + ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>> *mutable_inputNames() { + return GetPointer<::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>> *>(VT_INPUTNAMES); + } + const ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>> *outputNames() const { + return GetPointer> *>(VT_OUTPUTNAMES); + } + ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>> *mutable_outputNames() { + return GetPointer<::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>> *>(VT_OUTPUTNAMES); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_DESC_TYPE, 1) && + VerifyOffset(verifier, VT_DESC) && + VerifyNodeDesc(verifier, desc(), desc_type()) && + VerifyOffset(verifier, VT_NAME) && + verifier.VerifyString(name()) && + VerifyOffset(verifier, VT_INPUTNAMES) && + verifier.VerifyVector(inputNames()) && + verifier.VerifyVectorOfStrings(inputNames()) && + VerifyOffset(verifier, VT_OUTPUTNAMES) && + verifier.VerifyVector(outputNames()) && + verifier.VerifyVectorOfStrings(outputNames()) && + verifier.EndTable(); + } +}; + +template<> inline const dml::ir::OperatorNodeDesc *DmlGraphNode::desc_as() const { + return desc_as_OperatorNodeDesc(); +} + +template<> inline const dml::ir::ConstantNodeDesc *DmlGraphNode::desc_as() const { + return desc_as_ConstantNodeDesc(); +} + +struct DmlGraphNodeBuilder { + typedef DmlGraphNode Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void 
add_desc_type(dml::ir::NodeDesc desc_type) { + fbb_.AddElement(DmlGraphNode::VT_DESC_TYPE, static_cast(desc_type), 0); + } + void add_desc(::flatbuffers::Offset desc) { + fbb_.AddOffset(DmlGraphNode::VT_DESC, desc); + } + void add_name(::flatbuffers::Offset<::flatbuffers::String> name) { + fbb_.AddOffset(DmlGraphNode::VT_NAME, name); + } + void add_inputNames(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>>> inputNames) { + fbb_.AddOffset(DmlGraphNode::VT_INPUTNAMES, inputNames); + } + void add_outputNames(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>>> outputNames) { + fbb_.AddOffset(DmlGraphNode::VT_OUTPUTNAMES, outputNames); + } + explicit DmlGraphNodeBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateDmlGraphNode( + ::flatbuffers::FlatBufferBuilder &_fbb, + dml::ir::NodeDesc desc_type = dml::ir::NodeDesc_NONE, + ::flatbuffers::Offset desc = 0, + ::flatbuffers::Offset<::flatbuffers::String> name = 0, + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>>> inputNames = 0, + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>>> outputNames = 0) { + DmlGraphNodeBuilder builder_(_fbb); + builder_.add_outputNames(outputNames); + builder_.add_inputNames(inputNames); + builder_.add_name(name); + builder_.add_desc(desc); + builder_.add_desc_type(desc_type); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset CreateDmlGraphNodeDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + dml::ir::NodeDesc desc_type = dml::ir::NodeDesc_NONE, + ::flatbuffers::Offset desc = 0, + const char *name = nullptr, + const std::vector<::flatbuffers::Offset<::flatbuffers::String>> *inputNames = nullptr, + 
const std::vector<::flatbuffers::Offset<::flatbuffers::String>> *outputNames = nullptr) { + auto name__ = name ? _fbb.CreateString(name) : 0; + auto inputNames__ = inputNames ? _fbb.CreateVector<::flatbuffers::Offset<::flatbuffers::String>>(*inputNames) : 0; + auto outputNames__ = outputNames ? _fbb.CreateVector<::flatbuffers::Offset<::flatbuffers::String>>(*outputNames) : 0; + return dml::ir::CreateDmlGraphNode( + _fbb, + desc_type, + desc, + name__, + inputNames__, + outputNames__); +} + +struct DmlGraphDesc FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef DmlGraphDescBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_NODES = 4, + VT_GRAPHINPUTNAMES = 6, + VT_GRAPHOUTPUTNAMES = 8 + }; + const ::flatbuffers::Vector<::flatbuffers::Offset> *nodes() const { + return GetPointer> *>(VT_NODES); + } + ::flatbuffers::Vector<::flatbuffers::Offset> *mutable_nodes() { + return GetPointer<::flatbuffers::Vector<::flatbuffers::Offset> *>(VT_NODES); + } + const ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>> *graphInputNames() const { + return GetPointer> *>(VT_GRAPHINPUTNAMES); + } + ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>> *mutable_graphInputNames() { + return GetPointer<::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>> *>(VT_GRAPHINPUTNAMES); + } + const ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>> *graphOutputNames() const { + return GetPointer> *>(VT_GRAPHOUTPUTNAMES); + } + ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>> *mutable_graphOutputNames() { + return GetPointer<::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>> *>(VT_GRAPHOUTPUTNAMES); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_NODES) && + verifier.VerifyVector(nodes()) && + verifier.VerifyVectorOfTables(nodes()) && + VerifyOffset(verifier, 
VT_GRAPHINPUTNAMES) && + verifier.VerifyVector(graphInputNames()) && + verifier.VerifyVectorOfStrings(graphInputNames()) && + VerifyOffset(verifier, VT_GRAPHOUTPUTNAMES) && + verifier.VerifyVector(graphOutputNames()) && + verifier.VerifyVectorOfStrings(graphOutputNames()) && + verifier.EndTable(); + } +}; + +struct DmlGraphDescBuilder { + typedef DmlGraphDesc Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_nodes(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> nodes) { + fbb_.AddOffset(DmlGraphDesc::VT_NODES, nodes); + } + void add_graphInputNames(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>>> graphInputNames) { + fbb_.AddOffset(DmlGraphDesc::VT_GRAPHINPUTNAMES, graphInputNames); + } + void add_graphOutputNames(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>>> graphOutputNames) { + fbb_.AddOffset(DmlGraphDesc::VT_GRAPHOUTPUTNAMES, graphOutputNames); + } + explicit DmlGraphDescBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateDmlGraphDesc( + ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> nodes = 0, + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>>> graphInputNames = 0, + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>>> graphOutputNames = 0) { + DmlGraphDescBuilder builder_(_fbb); + builder_.add_graphOutputNames(graphOutputNames); + builder_.add_graphInputNames(graphInputNames); + builder_.add_nodes(nodes); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset CreateDmlGraphDescDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + const 
std::vector<::flatbuffers::Offset> *nodes = nullptr, + const std::vector<::flatbuffers::Offset<::flatbuffers::String>> *graphInputNames = nullptr, + const std::vector<::flatbuffers::Offset<::flatbuffers::String>> *graphOutputNames = nullptr) { + auto nodes__ = nodes ? _fbb.CreateVector<::flatbuffers::Offset>(*nodes) : 0; + auto graphInputNames__ = graphInputNames ? _fbb.CreateVector<::flatbuffers::Offset<::flatbuffers::String>>(*graphInputNames) : 0; + auto graphOutputNames__ = graphOutputNames ? _fbb.CreateVector<::flatbuffers::Offset<::flatbuffers::String>>(*graphOutputNames) : 0; + return dml::ir::CreateDmlGraphDesc( + _fbb, + nodes__, + graphInputNames__, + graphOutputNames__); +} + +inline bool VerifyConstantNodeDescDetail(::flatbuffers::Verifier &verifier, const void *obj, ConstantNodeDescDetail type) { + switch (type) { + case ConstantNodeDescDetail_NONE: { + return true; + } + case ConstantNodeDescDetail_ConstantName: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case ConstantNodeDescDetail_ConstantRawData: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + default: return true; + } +} + +inline bool VerifyConstantNodeDescDetailVector(::flatbuffers::Verifier &verifier, const ::flatbuffers::Vector<::flatbuffers::Offset> *values, const ::flatbuffers::Vector *types) { + if (!values || !types) return !values && !types; + if (values->size() != types->size()) return false; + for (::flatbuffers::uoffset_t i = 0; i < values->size(); ++i) { + if (!VerifyConstantNodeDescDetail( + verifier, values->Get(i), types->GetEnum(i))) { + return false; + } + } + return true; +} + +inline bool VerifyNodeDesc(::flatbuffers::Verifier &verifier, const void *obj, NodeDesc type) { + switch (type) { + case NodeDesc_NONE: { + return true; + } + case NodeDesc_OperatorNodeDesc: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case NodeDesc_ConstantNodeDesc: { + auto ptr = reinterpret_cast(obj); 
+ return verifier.VerifyTable(ptr); + } + default: return true; + } +} + +inline bool VerifyNodeDescVector(::flatbuffers::Verifier &verifier, const ::flatbuffers::Vector<::flatbuffers::Offset> *values, const ::flatbuffers::Vector *types) { + if (!values || !types) return !values && !types; + if (values->size() != types->size()) return false; + for (::flatbuffers::uoffset_t i = 0; i < values->size(); ++i) { + if (!VerifyNodeDesc( + verifier, values->Get(i), types->GetEnum(i))) { + return false; + } + } + return true; +} + +inline const dml::ir::DmlGraphDesc *GetDmlGraphDesc(const void *buf) { + return ::flatbuffers::GetRoot(buf); +} + +inline const dml::ir::DmlGraphDesc *GetSizePrefixedDmlGraphDesc(const void *buf) { + return ::flatbuffers::GetSizePrefixedRoot(buf); +} + +inline DmlGraphDesc *GetMutableDmlGraphDesc(void *buf) { + return ::flatbuffers::GetMutableRoot(buf); +} + +inline dml::ir::DmlGraphDesc *GetMutableSizePrefixedDmlGraphDesc(void *buf) { + return ::flatbuffers::GetMutableSizePrefixedRoot(buf); +} + +inline bool VerifyDmlGraphDescBuffer( + ::flatbuffers::Verifier &verifier) { + return verifier.VerifyBuffer(nullptr); +} + +inline bool VerifySizePrefixedDmlGraphDescBuffer( + ::flatbuffers::Verifier &verifier) { + return verifier.VerifySizePrefixedBuffer(nullptr); +} + +inline void FinishDmlGraphDescBuffer( + ::flatbuffers::FlatBufferBuilder &fbb, + ::flatbuffers::Offset root) { + fbb.Finish(root); +} + +inline void FinishSizePrefixedDmlGraphDescBuffer( + ::flatbuffers::FlatBufferBuilder &fbb, + ::flatbuffers::Offset root) { + fbb.FinishSizePrefixed(root); +} + +} // namespace ir +} // namespace dml + +#endif // FLATBUFFERS_GENERATED_DMLGRAPHDESC_DML_IR_H_ diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphDeserialization.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphDeserialization.h new file mode 100644 index 0000000000000..9decf0dce1bb2 --- /dev/null +++ 
b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphDeserialization.h @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. + +#pragma once +#include "DmlSerializedGraphDesc.h" + +struct NodeIndex +{ + uint32_t nodeIndex; + uint32_t nodeOutputIndex; +}; + +DmlSerializedGraphDesc DeserializeDmlGraph( + const uint8_t* flatbufferGraphDescBlob, + /*out*/ std::vector>& rawData); \ No newline at end of file diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphSerialization.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphSerialization.h new file mode 100644 index 0000000000000..d8d069da906b7 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphSerialization.h @@ -0,0 +1,8 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. + +#pragma once +#include "DmlGraphDesc_generated.h" + +struct DmlSerializedGraphDesc; + +flatbuffers::DetachedBuffer SerializeDmlGraph(const DmlSerializedGraphDesc& graphDesc); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlSerializedGraphDesc.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlSerializedGraphDesc.h new file mode 100644 index 0000000000000..51c3d6c81244b --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlSerializedGraphDesc.h @@ -0,0 +1,73 @@ +//----------------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. 
+// +//----------------------------------------------------------------------------- + +#pragma once + +struct ConstantName +{ + std::string name; +}; + +struct ConstantData +{ + std::byte* data; + uint64_t dataSize; +}; + +using DmlSerializedGraphNodeConstantVariant = std::variant< + ConstantName, + ConstantData +>; + +using DmlSerializedGraphNodeDescVariant = std::variant< + AbstractOperatorDesc, + DmlSerializedGraphNodeConstantVariant +>; + +struct DmlSerializedGraphNode +{ + DmlSerializedGraphNodeDescVariant Desc; + std::string Name; +}; + +struct DmlInputSerializedGraphEdge +{ + uint32_t GraphInputIndex; + uint32_t ToNodeIndex; + uint32_t ToNodeInputIndex; + std::string Name; +}; + +struct DmlOutputSerializedGraphEdge +{ + uint32_t FromNodeIndex; + uint32_t FromNodeOutputIndex; + uint32_t GraphOutputIndex; + std::string Name; +}; + +struct DmlIntermediateSerializedGraphEdge +{ + uint32_t FromNodeIndex; + uint32_t FromNodeOutputIndex; + uint32_t ToNodeIndex; + uint32_t ToNodeInputIndex; + std::string Name; +}; + +struct DmlSerializedGraphDesc +{ + uint32_t InputCount; + uint32_t OutputCount; + // nodes must be present in topological order for deserialization to work + // because while creating a intermediate edge during deserialization, node (from + // which given intermediate edge is outputting) must be visited before than the node + // (to which given intermediate edge is inputting) + std::vector Nodes; + std::vector InputEdges; + std::vector OutputEdges; + std::vector IntermediateEdges; +}; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h index 99218c135f058..86c66d8cca26c 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h +++ 
b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h @@ -425,7 +425,6 @@ inline std::vector GetFields(const DML_AVERAGE_POOLING_OPERATOR_D OperatorField(&DML_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.IncludePadding))), }; } - inline std::vector GetFields(const DML_AVERAGE_POOLING1_OPERATOR_DESC& desc) { return { @@ -502,24 +501,6 @@ inline std::vector GetFields(const DML_ROI_POOLING_OPERATOR_DESC& OperatorField(&DML_ROI_POOLING_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.PooledSize))), }; } -inline std::vector GetFields(const DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.InputScaleTensor))), - OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.InputZeroPointTensor))), - OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.OutputScaleTensor))), - OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.OutputZeroPointTensor))), - OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.DimensionCount))), - OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.Strides), desc.DimensionCount)), - OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.WindowSize), desc.DimensionCount)), - 
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[9], ToOperatorFieldType(static_cast(desc.StartPadding), desc.DimensionCount)), - OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[10], ToOperatorFieldType(static_cast(desc.EndPadding), desc.DimensionCount)), - OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[11], ToOperatorFieldType(static_cast(desc.Dilations), desc.DimensionCount)), - OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[12], ToOperatorFieldType(static_cast(desc.IncludePadding))), - }; -} inline std::vector GetFields(const DML_SLICE_OPERATOR_DESC& desc) { return { @@ -1158,6 +1139,19 @@ inline std::vector GetFields(const DML_QUANTIZED_LINEAR_MATRIX_MU OperatorField(&DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.OutputTensor))), }; } +inline std::vector GetFields(const DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))), + OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.AScaleTensor))), + OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.AZeroPointTensor))), + OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.BTensor))), + OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.BScaleTensor))), + OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.BZeroPointTensor))), + OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.BiasTensor))), + 
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} inline std::vector GetFields(const DML_CONVOLUTION_INTEGER_OPERATOR_DESC& desc) { return { @@ -1488,6 +1482,24 @@ inline std::vector GetFields(const DML_MULTIHEAD_ATTENTION_OPERAT OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[17], ToOperatorFieldType(static_cast(desc.MaskType))), }; } +inline std::vector GetFields(const DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.InputScaleTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.InputZeroPointTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.OutputScaleTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.OutputZeroPointTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.DimensionCount))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.Strides), desc.DimensionCount)), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.WindowSize), desc.DimensionCount)), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[9], ToOperatorFieldType(static_cast(desc.StartPadding), desc.DimensionCount)), + 
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[10], ToOperatorFieldType(static_cast(desc.EndPadding), desc.DimensionCount)), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[11], ToOperatorFieldType(static_cast(desc.Dilations), desc.DimensionCount)), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[12], ToOperatorFieldType(static_cast(desc.IncludePadding))), + }; +} inline std::vector GetFields(const DML_ACTIVATION_ELU_OPERATOR_DESC& desc) { return { @@ -1680,6 +1692,23 @@ inline std::vector GetFields(const DML_ACTIVATION_GELU_OPERATOR_D OperatorField(&DML_ACTIVATION_GELU_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), }; } +inline std::vector GetFields(const DML_ACTIVATION_SWISH_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ACTIVATION_SWISH_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ACTIVATION_SWISH_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ACTIVATION_SWISH_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.SigmoidInputScale))), + }; +} +inline std::vector GetFields(const DML_ACTIVATION_HARD_SWISH_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ACTIVATION_HARD_SWISH_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ACTIVATION_HARD_SWISH_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ACTIVATION_HARD_SWISH_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.Alpha))), + OperatorField(&DML_ACTIVATION_HARD_SWISH_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Beta))), + }; +} inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType) { switch (operatorType) @@ -1800,6 +1829,7 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType) case 
DML_OPERATOR_RESAMPLE1: return DML_RESAMPLE1_OPERATOR_SCHEMA; case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER: return DML_MATRIX_MULTIPLY_INTEGER_OPERATOR_SCHEMA; case DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY: return DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA; + case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT: return DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA; case DML_OPERATOR_CONVOLUTION_INTEGER: return DML_CONVOLUTION_INTEGER_OPERATOR_SCHEMA; case DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION: return DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA; case DML_OPERATOR_ELEMENT_WISE_BIT_AND: return DML_ELEMENT_WISE_BIT_AND_OPERATOR_SCHEMA; @@ -1826,6 +1856,7 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType) case DML_OPERATOR_RESAMPLE_GRAD1: return DML_RESAMPLE_GRAD1_OPERATOR_SCHEMA; case DML_OPERATOR_DIAGONAL_MATRIX1: return DML_DIAGONAL_MATRIX1_OPERATOR_SCHEMA; case DML_OPERATOR_MULTIHEAD_ATTENTION: return DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA; + case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING: return DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_ELU: return DML_ACTIVATION_ELU_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_CELU: return DML_ACTIVATION_CELU_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_HARDMAX: return DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA; @@ -1850,6 +1881,8 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType) case DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU: return DML_ACTIVATION_THRESHOLDED_RELU_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_SHRINK: return DML_ACTIVATION_SHRINK_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_GELU: return DML_ACTIVATION_GELU_OPERATOR_SCHEMA; + case DML_OPERATOR_ACTIVATION_SWISH: return DML_ACTIVATION_SWISH_OPERATOR_SCHEMA; + case DML_OPERATOR_ACTIVATION_HARD_SWISH: return DML_ACTIVATION_HARD_SWISH_OPERATOR_SCHEMA; default: ORT_THROW_HR(E_INVALIDARG); @@ -2327,6 +2360,10 @@ inline AbstractOperatorDesc 
ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc) return AbstractOperatorDesc( &DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA, GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT: + return AbstractOperatorDesc( + &DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); case DML_OPERATOR_CONVOLUTION_INTEGER: return AbstractOperatorDesc( &DML_CONVOLUTION_INTEGER_OPERATOR_SCHEMA, @@ -2431,6 +2468,10 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc) return AbstractOperatorDesc( &DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA, GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING: + return AbstractOperatorDesc( + &DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); case DML_OPERATOR_ACTIVATION_ELU: return AbstractOperatorDesc( &DML_ACTIVATION_ELU_OPERATOR_SCHEMA, @@ -2527,13 +2568,14 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc) return AbstractOperatorDesc( &DML_ACTIVATION_GELU_OPERATOR_SCHEMA, GetFields(*static_cast(opDesc.Desc))); -#pragma warning(push) -#pragma warning(disable: 4063) - case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING: + case DML_OPERATOR_ACTIVATION_SWISH: return AbstractOperatorDesc( - &DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); -#pragma warning(pop) + &DML_ACTIVATION_SWISH_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ACTIVATION_HARD_SWISH: + return AbstractOperatorDesc( + &DML_ACTIVATION_HARD_SWISH_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); default: ORT_THROW_HR(E_INVALIDARG); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaTypes.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaTypes.h index 
25f0dd26c6067..a94bb67b68d36 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaTypes.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaTypes.h @@ -15,32 +15,34 @@ using ApiAttributeVariant = std::variant< const FLOAT*, const DML_SCALE_BIAS*, DML_SIZE_2D, - DML_SCALAR_UNION + DML_SCALAR_UNION, + BOOL >; namespace OperatorFieldTypes { using TensorDesc = std::optional; // DML_SCHEMA_FIELD_TYPE_TENSOR_DESC using TensorDescArray = std::optional>; // DML_SCHEMA_FIELD_TYPE_TENSOR_DESC_ARRAY - using OperatorDesc = std::optional; // DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC - using OperatorDescArray = std::optional>; // DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC_ARRAY + using FusedActivationOperatorDesc = std::optional; // DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC + using FusedActivationOperatorDescArray = std::optional>; // DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC_ARRAY using UInt = uint32_t; // DML_SCHEMA_FIELD_TYPE_UINT using UInt64 = uint64_t; // DML_SCHEMA_FIELD_TYPE_UINT64 using Int = int32_t; // DML_SCHEMA_FIELD_TYPE_INT using Float = float; // DML_SCHEMA_FIELD_TYPE_FLOAT - using UIntArray = std::optional>; // DML_SCHEMA_FIELD_TYPE_UINT_ARRAY - using IntArray = std::optional>; // DML_SCHEMA_FIELD_TYPE_INT_ARRAY - using FloatArray = std::optional>; // DML_SCHEMA_FIELD_TYPE_FLOAT_ARRAY + using UIntArray = std::vector; // DML_SCHEMA_FIELD_TYPE_UINT_ARRAY + using IntArray = std::vector; // DML_SCHEMA_FIELD_TYPE_INT_ARRAY + using FloatArray = std::vector; // DML_SCHEMA_FIELD_TYPE_FLOAT_ARRAY using ScaleBias = std::optional; // DML_SCHEMA_FIELD_TYPE_SCALE_BIAS using Size2D = DML_SIZE_2D; // DML_SCHEMA_FIELD_TYPE_SIZE_2D using ScalarUnion = DML_SCALAR_UNION; // DML_SCHEMA_FIELD_TYPE_SCALAR_UNION + using Bool = bool; // DML_SCHEMA_FIELD_TYPE_BOOL } using OperatorFieldVariant = std::variant< OperatorFieldTypes::TensorDesc, OperatorFieldTypes::TensorDescArray, - 
OperatorFieldTypes::OperatorDesc, - OperatorFieldTypes::OperatorDescArray, + OperatorFieldTypes::FusedActivationOperatorDesc, + OperatorFieldTypes::FusedActivationOperatorDescArray, OperatorFieldTypes::UInt, OperatorFieldTypes::UInt64, OperatorFieldTypes::Int, @@ -50,7 +52,8 @@ using OperatorFieldVariant = std::variant< OperatorFieldTypes::FloatArray, OperatorFieldTypes::ScaleBias, OperatorFieldTypes::Size2D, - OperatorFieldTypes::ScalarUnion + OperatorFieldTypes::ScalarUnion, + OperatorFieldTypes::Bool >; class OperatorField @@ -80,11 +83,11 @@ class OperatorField const OperatorFieldTypes::TensorDescArray& AsTensorDescArray() const { return std::get(m_data); } OperatorFieldTypes::TensorDescArray& AsTensorDescArray() { return std::get(m_data); } - const OperatorFieldTypes::OperatorDesc& AsOperatorDesc() const { return std::get(m_data); } - OperatorFieldTypes::OperatorDesc& AsOperatorDesc() { return std::get(m_data); } + const OperatorFieldTypes::FusedActivationOperatorDesc& AsFusedActivationOperatorDesc() const { return std::get(m_data); } + OperatorFieldTypes::FusedActivationOperatorDesc& AsFusedActivationOperatorDesc() { return std::get(m_data); } - const OperatorFieldTypes::OperatorDescArray& AsOperatorDescArray() const { return std::get(m_data); } - OperatorFieldTypes::OperatorDescArray& AsOperatorDescArray() { return std::get(m_data); } + const OperatorFieldTypes::FusedActivationOperatorDescArray& AsFusedActivationOperatorDescArray() const { return std::get(m_data); } + OperatorFieldTypes::FusedActivationOperatorDescArray& AsFusedActivationOperatorDescArray() { return std::get(m_data); } const OperatorFieldTypes::UInt& AsUInt() const { return std::get(m_data); } OperatorFieldTypes::UInt& AsUInt() { return std::get(m_data); } @@ -116,6 +119,9 @@ class OperatorField const OperatorFieldTypes::ScalarUnion& AsScalarUnion() const { return std::get(m_data); } OperatorFieldTypes::ScalarUnion& AsScalarUnion() { return std::get(m_data); } + const 
OperatorFieldTypes::Bool& AsBool() const { return std::get(m_data); } + OperatorFieldTypes::Bool& AsBool() { return std::get(m_data); } + private: const DML_SCHEMA_FIELD* m_schema; OperatorFieldVariant m_data; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/OperatorFieldTypes_generated.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/OperatorFieldTypes_generated.h new file mode 100644 index 0000000000000..639c31f0dc5c8 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/OperatorFieldTypes_generated.h @@ -0,0 +1,1323 @@ +// automatically generated by the FlatBuffers compiler, do not modify + + +#ifndef FLATBUFFERS_GENERATED_OPERATORFIELDTYPES_DML_IR_OPERATORFIELDTYPES_H_ +#define FLATBUFFERS_GENERATED_OPERATORFIELDTYPES_DML_IR_OPERATORFIELDTYPES_H_ + +#include "core/common/flatbuffers.h" + +// Ensure the included flatbuffers.h is the same version as when this file was +// generated, otherwise it may not be compatible. 
+static_assert(FLATBUFFERS_VERSION_MAJOR == 23 && + FLATBUFFERS_VERSION_MINOR == 5 && + FLATBUFFERS_VERSION_REVISION == 26, + "Non-compatible flatbuffers version included"); + +namespace dml { +namespace ir { +namespace operatorFieldTypes { + +struct AttributeDesc; +struct AttributeDescBuilder; + +struct Activation; +struct ActivationBuilder; + +struct ActivationArray; +struct ActivationArrayBuilder; + +struct UInt8; + +struct UInt16; + +struct UInt32; + +struct UInt64; + +struct Int8; + +struct Int16; + +struct Int32; + +struct Int64; + +struct Float32; + +struct Float64; + +struct UIntArray; +struct UIntArrayBuilder; + +struct IntArray; +struct IntArrayBuilder; + +struct FloatArray; +struct FloatArrayBuilder; + +struct ScaleBias; + +struct Size2D; + +struct ByteArray; + +struct ScalarUnionData; +struct ScalarUnionDataBuilder; + +struct Bool; + +enum AttributeFieldVariant : uint8_t { + AttributeFieldVariant_NONE = 0, + AttributeFieldVariant_Activation = 1, + AttributeFieldVariant_ActivationArray = 2, + AttributeFieldVariant_UInt32 = 3, + AttributeFieldVariant_UInt64 = 4, + AttributeFieldVariant_Int32 = 5, + AttributeFieldVariant_Float32 = 6, + AttributeFieldVariant_UIntArray = 7, + AttributeFieldVariant_IntArray = 8, + AttributeFieldVariant_FloatArray = 9, + AttributeFieldVariant_ScaleBias = 10, + AttributeFieldVariant_Size2D = 11, + AttributeFieldVariant_ScalarUnionData = 12, + AttributeFieldVariant_Bool = 13, + AttributeFieldVariant_MIN = AttributeFieldVariant_NONE, + AttributeFieldVariant_MAX = AttributeFieldVariant_Bool +}; + +inline const AttributeFieldVariant (&EnumValuesAttributeFieldVariant())[14] { + static const AttributeFieldVariant values[] = { + AttributeFieldVariant_NONE, + AttributeFieldVariant_Activation, + AttributeFieldVariant_ActivationArray, + AttributeFieldVariant_UInt32, + AttributeFieldVariant_UInt64, + AttributeFieldVariant_Int32, + AttributeFieldVariant_Float32, + AttributeFieldVariant_UIntArray, + AttributeFieldVariant_IntArray, + 
AttributeFieldVariant_FloatArray, + AttributeFieldVariant_ScaleBias, + AttributeFieldVariant_Size2D, + AttributeFieldVariant_ScalarUnionData, + AttributeFieldVariant_Bool + }; + return values; +} + +inline const char * const *EnumNamesAttributeFieldVariant() { + static const char * const names[15] = { + "NONE", + "Activation", + "ActivationArray", + "UInt32", + "UInt64", + "Int32", + "Float32", + "UIntArray", + "IntArray", + "FloatArray", + "ScaleBias", + "Size2D", + "ScalarUnionData", + "Bool", + nullptr + }; + return names; +} + +inline const char *EnumNameAttributeFieldVariant(AttributeFieldVariant e) { + if (::flatbuffers::IsOutRange(e, AttributeFieldVariant_NONE, AttributeFieldVariant_Bool)) return ""; + const size_t index = static_cast(e); + return EnumNamesAttributeFieldVariant()[index]; +} + +template struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = AttributeFieldVariant_NONE; +}; + +template<> struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = AttributeFieldVariant_Activation; +}; + +template<> struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = AttributeFieldVariant_ActivationArray; +}; + +template<> struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = AttributeFieldVariant_UInt32; +}; + +template<> struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = AttributeFieldVariant_UInt64; +}; + +template<> struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = AttributeFieldVariant_Int32; +}; + +template<> struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = AttributeFieldVariant_Float32; +}; + +template<> struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = AttributeFieldVariant_UIntArray; +}; + +template<> struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value 
= AttributeFieldVariant_IntArray; +}; + +template<> struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = AttributeFieldVariant_FloatArray; +}; + +template<> struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = AttributeFieldVariant_ScaleBias; +}; + +template<> struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = AttributeFieldVariant_Size2D; +}; + +template<> struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = AttributeFieldVariant_ScalarUnionData; +}; + +template<> struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = AttributeFieldVariant_Bool; +}; + +bool VerifyAttributeFieldVariant(::flatbuffers::Verifier &verifier, const void *obj, AttributeFieldVariant type); +bool VerifyAttributeFieldVariantVector(::flatbuffers::Verifier &verifier, const ::flatbuffers::Vector<::flatbuffers::Offset> *values, const ::flatbuffers::Vector *types); + +enum ScalarVariant : uint8_t { + ScalarVariant_NONE = 0, + ScalarVariant_ByteArray = 1, + ScalarVariant_Int8 = 2, + ScalarVariant_UInt8 = 3, + ScalarVariant_Int16 = 4, + ScalarVariant_UInt16 = 5, + ScalarVariant_Int32 = 6, + ScalarVariant_UInt32 = 7, + ScalarVariant_Int64 = 8, + ScalarVariant_UInt64 = 9, + ScalarVariant_Float32 = 10, + ScalarVariant_Float64 = 11, + ScalarVariant_MIN = ScalarVariant_NONE, + ScalarVariant_MAX = ScalarVariant_Float64 +}; + +inline const ScalarVariant (&EnumValuesScalarVariant())[12] { + static const ScalarVariant values[] = { + ScalarVariant_NONE, + ScalarVariant_ByteArray, + ScalarVariant_Int8, + ScalarVariant_UInt8, + ScalarVariant_Int16, + ScalarVariant_UInt16, + ScalarVariant_Int32, + ScalarVariant_UInt32, + ScalarVariant_Int64, + ScalarVariant_UInt64, + ScalarVariant_Float32, + ScalarVariant_Float64 + }; + return values; +} + +inline const char * const *EnumNamesScalarVariant() { + static const char * const names[13] = { + "NONE", 
+ "ByteArray", + "Int8", + "UInt8", + "Int16", + "UInt16", + "Int32", + "UInt32", + "Int64", + "UInt64", + "Float32", + "Float64", + nullptr + }; + return names; +} + +inline const char *EnumNameScalarVariant(ScalarVariant e) { + if (::flatbuffers::IsOutRange(e, ScalarVariant_NONE, ScalarVariant_Float64)) return ""; + const size_t index = static_cast(e); + return EnumNamesScalarVariant()[index]; +} + +template struct ScalarVariantTraits { + static const ScalarVariant enum_value = ScalarVariant_NONE; +}; + +template<> struct ScalarVariantTraits { + static const ScalarVariant enum_value = ScalarVariant_ByteArray; +}; + +template<> struct ScalarVariantTraits { + static const ScalarVariant enum_value = ScalarVariant_Int8; +}; + +template<> struct ScalarVariantTraits { + static const ScalarVariant enum_value = ScalarVariant_UInt8; +}; + +template<> struct ScalarVariantTraits { + static const ScalarVariant enum_value = ScalarVariant_Int16; +}; + +template<> struct ScalarVariantTraits { + static const ScalarVariant enum_value = ScalarVariant_UInt16; +}; + +template<> struct ScalarVariantTraits { + static const ScalarVariant enum_value = ScalarVariant_Int32; +}; + +template<> struct ScalarVariantTraits { + static const ScalarVariant enum_value = ScalarVariant_UInt32; +}; + +template<> struct ScalarVariantTraits { + static const ScalarVariant enum_value = ScalarVariant_Int64; +}; + +template<> struct ScalarVariantTraits { + static const ScalarVariant enum_value = ScalarVariant_UInt64; +}; + +template<> struct ScalarVariantTraits { + static const ScalarVariant enum_value = ScalarVariant_Float32; +}; + +template<> struct ScalarVariantTraits { + static const ScalarVariant enum_value = ScalarVariant_Float64; +}; + +bool VerifyScalarVariant(::flatbuffers::Verifier &verifier, const void *obj, ScalarVariant type); +bool VerifyScalarVariantVector(::flatbuffers::Verifier &verifier, const ::flatbuffers::Vector<::flatbuffers::Offset> *values, const ::flatbuffers::Vector *types); + 
+FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(1) UInt8 FLATBUFFERS_FINAL_CLASS { + private: + uint8_t data_; + + public: + UInt8() + : data_(0) { + } + UInt8(uint8_t _data) + : data_(::flatbuffers::EndianScalar(_data)) { + } + uint8_t data() const { + return ::flatbuffers::EndianScalar(data_); + } + void mutate_data(uint8_t _data) { + ::flatbuffers::WriteScalar(&data_, _data); + } +}; +FLATBUFFERS_STRUCT_END(UInt8, 1); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(2) UInt16 FLATBUFFERS_FINAL_CLASS { + private: + uint16_t data_; + + public: + UInt16() + : data_(0) { + } + UInt16(uint16_t _data) + : data_(::flatbuffers::EndianScalar(_data)) { + } + uint16_t data() const { + return ::flatbuffers::EndianScalar(data_); + } + void mutate_data(uint16_t _data) { + ::flatbuffers::WriteScalar(&data_, _data); + } +}; +FLATBUFFERS_STRUCT_END(UInt16, 2); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(4) UInt32 FLATBUFFERS_FINAL_CLASS { + private: + uint32_t data_; + + public: + UInt32() + : data_(0) { + } + UInt32(uint32_t _data) + : data_(::flatbuffers::EndianScalar(_data)) { + } + uint32_t data() const { + return ::flatbuffers::EndianScalar(data_); + } + void mutate_data(uint32_t _data) { + ::flatbuffers::WriteScalar(&data_, _data); + } +}; +FLATBUFFERS_STRUCT_END(UInt32, 4); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(8) UInt64 FLATBUFFERS_FINAL_CLASS { + private: + uint64_t data_; + + public: + UInt64() + : data_(0) { + } + UInt64(uint64_t _data) + : data_(::flatbuffers::EndianScalar(_data)) { + } + uint64_t data() const { + return ::flatbuffers::EndianScalar(data_); + } + void mutate_data(uint64_t _data) { + ::flatbuffers::WriteScalar(&data_, _data); + } +}; +FLATBUFFERS_STRUCT_END(UInt64, 8); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(1) Int8 FLATBUFFERS_FINAL_CLASS { + private: + int8_t data_; + + public: + Int8() + : data_(0) { + } + Int8(int8_t _data) + : data_(::flatbuffers::EndianScalar(_data)) { + } + int8_t data() const { + return ::flatbuffers::EndianScalar(data_); + } + void mutate_data(int8_t 
_data) { + ::flatbuffers::WriteScalar(&data_, _data); + } +}; +FLATBUFFERS_STRUCT_END(Int8, 1); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(2) Int16 FLATBUFFERS_FINAL_CLASS { + private: + int16_t data_; + + public: + Int16() + : data_(0) { + } + Int16(int16_t _data) + : data_(::flatbuffers::EndianScalar(_data)) { + } + int16_t data() const { + return ::flatbuffers::EndianScalar(data_); + } + void mutate_data(int16_t _data) { + ::flatbuffers::WriteScalar(&data_, _data); + } +}; +FLATBUFFERS_STRUCT_END(Int16, 2); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(4) Int32 FLATBUFFERS_FINAL_CLASS { + private: + int32_t data_; + + public: + Int32() + : data_(0) { + } + Int32(int32_t _data) + : data_(::flatbuffers::EndianScalar(_data)) { + } + int32_t data() const { + return ::flatbuffers::EndianScalar(data_); + } + void mutate_data(int32_t _data) { + ::flatbuffers::WriteScalar(&data_, _data); + } +}; +FLATBUFFERS_STRUCT_END(Int32, 4); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(8) Int64 FLATBUFFERS_FINAL_CLASS { + private: + int64_t data_; + + public: + Int64() + : data_(0) { + } + Int64(int64_t _data) + : data_(::flatbuffers::EndianScalar(_data)) { + } + int64_t data() const { + return ::flatbuffers::EndianScalar(data_); + } + void mutate_data(int64_t _data) { + ::flatbuffers::WriteScalar(&data_, _data); + } +}; +FLATBUFFERS_STRUCT_END(Int64, 8); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(4) Float32 FLATBUFFERS_FINAL_CLASS { + private: + float data_; + + public: + Float32() + : data_(0) { + } + Float32(float _data) + : data_(::flatbuffers::EndianScalar(_data)) { + } + float data() const { + return ::flatbuffers::EndianScalar(data_); + } + void mutate_data(float _data) { + ::flatbuffers::WriteScalar(&data_, _data); + } +}; +FLATBUFFERS_STRUCT_END(Float32, 4); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(8) Float64 FLATBUFFERS_FINAL_CLASS { + private: + double data_; + + public: + Float64() + : data_(0) { + } + Float64(double _data) + : data_(::flatbuffers::EndianScalar(_data)) { + } + double data() 
const { + return ::flatbuffers::EndianScalar(data_); + } + void mutate_data(double _data) { + ::flatbuffers::WriteScalar(&data_, _data); + } +}; +FLATBUFFERS_STRUCT_END(Float64, 8); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(4) ScaleBias FLATBUFFERS_FINAL_CLASS { + private: + float scale_; + float bias_; + + public: + ScaleBias() + : scale_(0), + bias_(0) { + } + ScaleBias(float _scale, float _bias) + : scale_(::flatbuffers::EndianScalar(_scale)), + bias_(::flatbuffers::EndianScalar(_bias)) { + } + float scale() const { + return ::flatbuffers::EndianScalar(scale_); + } + void mutate_scale(float _scale) { + ::flatbuffers::WriteScalar(&scale_, _scale); + } + float bias() const { + return ::flatbuffers::EndianScalar(bias_); + } + void mutate_bias(float _bias) { + ::flatbuffers::WriteScalar(&bias_, _bias); + } +}; +FLATBUFFERS_STRUCT_END(ScaleBias, 8); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(4) Size2D FLATBUFFERS_FINAL_CLASS { + private: + uint32_t width_; + uint32_t height_; + + public: + Size2D() + : width_(0), + height_(0) { + } + Size2D(uint32_t _width, uint32_t _height) + : width_(::flatbuffers::EndianScalar(_width)), + height_(::flatbuffers::EndianScalar(_height)) { + } + uint32_t width() const { + return ::flatbuffers::EndianScalar(width_); + } + void mutate_width(uint32_t _width) { + ::flatbuffers::WriteScalar(&width_, _width); + } + uint32_t height() const { + return ::flatbuffers::EndianScalar(height_); + } + void mutate_height(uint32_t _height) { + ::flatbuffers::WriteScalar(&height_, _height); + } +}; +FLATBUFFERS_STRUCT_END(Size2D, 8); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(1) ByteArray FLATBUFFERS_FINAL_CLASS { + private: + uint8_t data_[8]; + + public: + ByteArray() + : data_() { + } + ByteArray(::flatbuffers::span _data) { + ::flatbuffers::CastToArray(data_).CopyFromSpan(_data); + } + const ::flatbuffers::Array *data() const { + return &::flatbuffers::CastToArray(data_); + } + ::flatbuffers::Array *mutable_data() { + return 
&::flatbuffers::CastToArray(data_); + } +}; +FLATBUFFERS_STRUCT_END(ByteArray, 8); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(1) Bool FLATBUFFERS_FINAL_CLASS { + private: + uint8_t data_; + + public: + Bool() + : data_(0) { + } + Bool(bool _data) + : data_(::flatbuffers::EndianScalar(static_cast(_data))) { + } + bool data() const { + return ::flatbuffers::EndianScalar(data_) != 0; + } + void mutate_data(bool _data) { + ::flatbuffers::WriteScalar(&data_, static_cast(_data)); + } +}; +FLATBUFFERS_STRUCT_END(Bool, 1); + +struct AttributeDesc FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef AttributeDescBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_NAME = 4, + VT_VAL_TYPE = 6, + VT_VAL = 8 + }; + const ::flatbuffers::String *name() const { + return GetPointer(VT_NAME); + } + ::flatbuffers::String *mutable_name() { + return GetPointer<::flatbuffers::String *>(VT_NAME); + } + dml::ir::operatorFieldTypes::AttributeFieldVariant val_type() const { + return static_cast(GetField(VT_VAL_TYPE, 0)); + } + const void *val() const { + return GetPointer(VT_VAL); + } + template const T *val_as() const; + const dml::ir::operatorFieldTypes::Activation *val_as_Activation() const { + return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_Activation ? static_cast(val()) : nullptr; + } + const dml::ir::operatorFieldTypes::ActivationArray *val_as_ActivationArray() const { + return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_ActivationArray ? static_cast(val()) : nullptr; + } + const dml::ir::operatorFieldTypes::UInt32 *val_as_UInt32() const { + return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_UInt32 ? static_cast(val()) : nullptr; + } + const dml::ir::operatorFieldTypes::UInt64 *val_as_UInt64() const { + return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_UInt64 ? 
static_cast(val()) : nullptr; + } + const dml::ir::operatorFieldTypes::Int32 *val_as_Int32() const { + return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_Int32 ? static_cast(val()) : nullptr; + } + const dml::ir::operatorFieldTypes::Float32 *val_as_Float32() const { + return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_Float32 ? static_cast(val()) : nullptr; + } + const dml::ir::operatorFieldTypes::UIntArray *val_as_UIntArray() const { + return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_UIntArray ? static_cast(val()) : nullptr; + } + const dml::ir::operatorFieldTypes::IntArray *val_as_IntArray() const { + return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_IntArray ? static_cast(val()) : nullptr; + } + const dml::ir::operatorFieldTypes::FloatArray *val_as_FloatArray() const { + return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_FloatArray ? static_cast(val()) : nullptr; + } + const dml::ir::operatorFieldTypes::ScaleBias *val_as_ScaleBias() const { + return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_ScaleBias ? static_cast(val()) : nullptr; + } + const dml::ir::operatorFieldTypes::Size2D *val_as_Size2D() const { + return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_Size2D ? static_cast(val()) : nullptr; + } + const dml::ir::operatorFieldTypes::ScalarUnionData *val_as_ScalarUnionData() const { + return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_ScalarUnionData ? static_cast(val()) : nullptr; + } + const dml::ir::operatorFieldTypes::Bool *val_as_Bool() const { + return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_Bool ? 
static_cast(val()) : nullptr; + } + void *mutable_val() { + return GetPointer(VT_VAL); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_NAME) && + verifier.VerifyString(name()) && + VerifyField(verifier, VT_VAL_TYPE, 1) && + VerifyOffset(verifier, VT_VAL) && + VerifyAttributeFieldVariant(verifier, val(), val_type()) && + verifier.EndTable(); + } +}; + +template<> inline const dml::ir::operatorFieldTypes::Activation *AttributeDesc::val_as() const { + return val_as_Activation(); +} + +template<> inline const dml::ir::operatorFieldTypes::ActivationArray *AttributeDesc::val_as() const { + return val_as_ActivationArray(); +} + +template<> inline const dml::ir::operatorFieldTypes::UInt32 *AttributeDesc::val_as() const { + return val_as_UInt32(); +} + +template<> inline const dml::ir::operatorFieldTypes::UInt64 *AttributeDesc::val_as() const { + return val_as_UInt64(); +} + +template<> inline const dml::ir::operatorFieldTypes::Int32 *AttributeDesc::val_as() const { + return val_as_Int32(); +} + +template<> inline const dml::ir::operatorFieldTypes::Float32 *AttributeDesc::val_as() const { + return val_as_Float32(); +} + +template<> inline const dml::ir::operatorFieldTypes::UIntArray *AttributeDesc::val_as() const { + return val_as_UIntArray(); +} + +template<> inline const dml::ir::operatorFieldTypes::IntArray *AttributeDesc::val_as() const { + return val_as_IntArray(); +} + +template<> inline const dml::ir::operatorFieldTypes::FloatArray *AttributeDesc::val_as() const { + return val_as_FloatArray(); +} + +template<> inline const dml::ir::operatorFieldTypes::ScaleBias *AttributeDesc::val_as() const { + return val_as_ScaleBias(); +} + +template<> inline const dml::ir::operatorFieldTypes::Size2D *AttributeDesc::val_as() const { + return val_as_Size2D(); +} + +template<> inline const dml::ir::operatorFieldTypes::ScalarUnionData *AttributeDesc::val_as() const { + return val_as_ScalarUnionData(); +} 
+ +template<> inline const dml::ir::operatorFieldTypes::Bool *AttributeDesc::val_as() const { + return val_as_Bool(); +} + +struct AttributeDescBuilder { + typedef AttributeDesc Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_name(::flatbuffers::Offset<::flatbuffers::String> name) { + fbb_.AddOffset(AttributeDesc::VT_NAME, name); + } + void add_val_type(dml::ir::operatorFieldTypes::AttributeFieldVariant val_type) { + fbb_.AddElement(AttributeDesc::VT_VAL_TYPE, static_cast(val_type), 0); + } + void add_val(::flatbuffers::Offset val) { + fbb_.AddOffset(AttributeDesc::VT_VAL, val); + } + explicit AttributeDescBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateAttributeDesc( + ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::String> name = 0, + dml::ir::operatorFieldTypes::AttributeFieldVariant val_type = dml::ir::operatorFieldTypes::AttributeFieldVariant_NONE, + ::flatbuffers::Offset val = 0) { + AttributeDescBuilder builder_(_fbb); + builder_.add_val(val); + builder_.add_name(name); + builder_.add_val_type(val_type); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset CreateAttributeDescDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + const char *name = nullptr, + dml::ir::operatorFieldTypes::AttributeFieldVariant val_type = dml::ir::operatorFieldTypes::AttributeFieldVariant_NONE, + ::flatbuffers::Offset val = 0) { + auto name__ = name ? 
_fbb.CreateString(name) : 0; + return dml::ir::operatorFieldTypes::CreateAttributeDesc( + _fbb, + name__, + val_type, + val); +} + +struct Activation FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef ActivationBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_TYPE = 4, + VT_ATTRIBUTES = 6 + }; + const ::flatbuffers::String *type() const { + return GetPointer(VT_TYPE); + } + ::flatbuffers::String *mutable_type() { + return GetPointer<::flatbuffers::String *>(VT_TYPE); + } + const ::flatbuffers::Vector<::flatbuffers::Offset> *attributes() const { + return GetPointer> *>(VT_ATTRIBUTES); + } + ::flatbuffers::Vector<::flatbuffers::Offset> *mutable_attributes() { + return GetPointer<::flatbuffers::Vector<::flatbuffers::Offset> *>(VT_ATTRIBUTES); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_TYPE) && + verifier.VerifyString(type()) && + VerifyOffset(verifier, VT_ATTRIBUTES) && + verifier.VerifyVector(attributes()) && + verifier.VerifyVectorOfTables(attributes()) && + verifier.EndTable(); + } +}; + +struct ActivationBuilder { + typedef Activation Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_type(::flatbuffers::Offset<::flatbuffers::String> type) { + fbb_.AddOffset(Activation::VT_TYPE, type); + } + void add_attributes(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> attributes) { + fbb_.AddOffset(Activation::VT_ATTRIBUTES, attributes); + } + explicit ActivationBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateActivation( + ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::String> type = 0, + 
::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> attributes = 0) { + ActivationBuilder builder_(_fbb); + builder_.add_attributes(attributes); + builder_.add_type(type); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset CreateActivationDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + const char *type = nullptr, + const std::vector<::flatbuffers::Offset> *attributes = nullptr) { + auto type__ = type ? _fbb.CreateString(type) : 0; + auto attributes__ = attributes ? _fbb.CreateVector<::flatbuffers::Offset>(*attributes) : 0; + return dml::ir::operatorFieldTypes::CreateActivation( + _fbb, + type__, + attributes__); +} + +struct ActivationArray FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef ActivationArrayBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_DATA = 4 + }; + const ::flatbuffers::Vector<::flatbuffers::Offset> *data() const { + return GetPointer> *>(VT_DATA); + } + ::flatbuffers::Vector<::flatbuffers::Offset> *mutable_data() { + return GetPointer<::flatbuffers::Vector<::flatbuffers::Offset> *>(VT_DATA); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_DATA) && + verifier.VerifyVector(data()) && + verifier.VerifyVectorOfTables(data()) && + verifier.EndTable(); + } +}; + +struct ActivationArrayBuilder { + typedef ActivationArray Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_data(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> data) { + fbb_.AddOffset(ActivationArray::VT_DATA, data); + } + explicit ActivationArrayBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateActivationArray( + 
::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> data = 0) { + ActivationArrayBuilder builder_(_fbb); + builder_.add_data(data); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset CreateActivationArrayDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + const std::vector<::flatbuffers::Offset> *data = nullptr) { + auto data__ = data ? _fbb.CreateVector<::flatbuffers::Offset>(*data) : 0; + return dml::ir::operatorFieldTypes::CreateActivationArray( + _fbb, + data__); +} + +struct UIntArray FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef UIntArrayBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_DATA = 4 + }; + const ::flatbuffers::Vector *data() const { + return GetPointer *>(VT_DATA); + } + ::flatbuffers::Vector *mutable_data() { + return GetPointer<::flatbuffers::Vector *>(VT_DATA); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_DATA) && + verifier.VerifyVector(data()) && + verifier.EndTable(); + } +}; + +struct UIntArrayBuilder { + typedef UIntArray Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_data(::flatbuffers::Offset<::flatbuffers::Vector> data) { + fbb_.AddOffset(UIntArray::VT_DATA, data); + } + explicit UIntArrayBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateUIntArray( + ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::Vector> data = 0) { + UIntArrayBuilder builder_(_fbb); + builder_.add_data(data); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset CreateUIntArrayDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *data 
= nullptr) { + auto data__ = data ? _fbb.CreateVector(*data) : 0; + return dml::ir::operatorFieldTypes::CreateUIntArray( + _fbb, + data__); +} + +struct IntArray FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef IntArrayBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_DATA = 4 + }; + const ::flatbuffers::Vector *data() const { + return GetPointer *>(VT_DATA); + } + ::flatbuffers::Vector *mutable_data() { + return GetPointer<::flatbuffers::Vector *>(VT_DATA); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_DATA) && + verifier.VerifyVector(data()) && + verifier.EndTable(); + } +}; + +struct IntArrayBuilder { + typedef IntArray Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_data(::flatbuffers::Offset<::flatbuffers::Vector> data) { + fbb_.AddOffset(IntArray::VT_DATA, data); + } + explicit IntArrayBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateIntArray( + ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::Vector> data = 0) { + IntArrayBuilder builder_(_fbb); + builder_.add_data(data); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset CreateIntArrayDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *data = nullptr) { + auto data__ = data ? 
_fbb.CreateVector(*data) : 0; + return dml::ir::operatorFieldTypes::CreateIntArray( + _fbb, + data__); +} + +struct FloatArray FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef FloatArrayBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_DATA = 4 + }; + const ::flatbuffers::Vector *data() const { + return GetPointer *>(VT_DATA); + } + ::flatbuffers::Vector *mutable_data() { + return GetPointer<::flatbuffers::Vector *>(VT_DATA); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_DATA) && + verifier.VerifyVector(data()) && + verifier.EndTable(); + } +}; + +struct FloatArrayBuilder { + typedef FloatArray Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_data(::flatbuffers::Offset<::flatbuffers::Vector> data) { + fbb_.AddOffset(FloatArray::VT_DATA, data); + } + explicit FloatArrayBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateFloatArray( + ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::Vector> data = 0) { + FloatArrayBuilder builder_(_fbb); + builder_.add_data(data); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset CreateFloatArrayDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *data = nullptr) { + auto data__ = data ? 
_fbb.CreateVector(*data) : 0; + return dml::ir::operatorFieldTypes::CreateFloatArray( + _fbb, + data__); +} + +struct ScalarUnionData FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef ScalarUnionDataBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_DATA_TYPE = 4, + VT_DATA = 6 + }; + dml::ir::operatorFieldTypes::ScalarVariant data_type() const { + return static_cast(GetField(VT_DATA_TYPE, 0)); + } + const void *data() const { + return GetPointer(VT_DATA); + } + template const T *data_as() const; + const dml::ir::operatorFieldTypes::ByteArray *data_as_ByteArray() const { + return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_ByteArray ? static_cast(data()) : nullptr; + } + const dml::ir::operatorFieldTypes::Int8 *data_as_Int8() const { + return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_Int8 ? static_cast(data()) : nullptr; + } + const dml::ir::operatorFieldTypes::UInt8 *data_as_UInt8() const { + return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_UInt8 ? static_cast(data()) : nullptr; + } + const dml::ir::operatorFieldTypes::Int16 *data_as_Int16() const { + return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_Int16 ? static_cast(data()) : nullptr; + } + const dml::ir::operatorFieldTypes::UInt16 *data_as_UInt16() const { + return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_UInt16 ? static_cast(data()) : nullptr; + } + const dml::ir::operatorFieldTypes::Int32 *data_as_Int32() const { + return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_Int32 ? static_cast(data()) : nullptr; + } + const dml::ir::operatorFieldTypes::UInt32 *data_as_UInt32() const { + return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_UInt32 ? static_cast(data()) : nullptr; + } + const dml::ir::operatorFieldTypes::Int64 *data_as_Int64() const { + return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_Int64 ? 
static_cast(data()) : nullptr; + } + const dml::ir::operatorFieldTypes::UInt64 *data_as_UInt64() const { + return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_UInt64 ? static_cast(data()) : nullptr; + } + const dml::ir::operatorFieldTypes::Float32 *data_as_Float32() const { + return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_Float32 ? static_cast(data()) : nullptr; + } + const dml::ir::operatorFieldTypes::Float64 *data_as_Float64() const { + return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_Float64 ? static_cast(data()) : nullptr; + } + void *mutable_data() { + return GetPointer(VT_DATA); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_DATA_TYPE, 1) && + VerifyOffset(verifier, VT_DATA) && + VerifyScalarVariant(verifier, data(), data_type()) && + verifier.EndTable(); + } +}; + +template<> inline const dml::ir::operatorFieldTypes::ByteArray *ScalarUnionData::data_as() const { + return data_as_ByteArray(); +} + +template<> inline const dml::ir::operatorFieldTypes::Int8 *ScalarUnionData::data_as() const { + return data_as_Int8(); +} + +template<> inline const dml::ir::operatorFieldTypes::UInt8 *ScalarUnionData::data_as() const { + return data_as_UInt8(); +} + +template<> inline const dml::ir::operatorFieldTypes::Int16 *ScalarUnionData::data_as() const { + return data_as_Int16(); +} + +template<> inline const dml::ir::operatorFieldTypes::UInt16 *ScalarUnionData::data_as() const { + return data_as_UInt16(); +} + +template<> inline const dml::ir::operatorFieldTypes::Int32 *ScalarUnionData::data_as() const { + return data_as_Int32(); +} + +template<> inline const dml::ir::operatorFieldTypes::UInt32 *ScalarUnionData::data_as() const { + return data_as_UInt32(); +} + +template<> inline const dml::ir::operatorFieldTypes::Int64 *ScalarUnionData::data_as() const { + return data_as_Int64(); +} + +template<> inline const dml::ir::operatorFieldTypes::UInt64 
*ScalarUnionData::data_as() const { + return data_as_UInt64(); +} + +template<> inline const dml::ir::operatorFieldTypes::Float32 *ScalarUnionData::data_as() const { + return data_as_Float32(); +} + +template<> inline const dml::ir::operatorFieldTypes::Float64 *ScalarUnionData::data_as() const { + return data_as_Float64(); +} + +struct ScalarUnionDataBuilder { + typedef ScalarUnionData Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_data_type(dml::ir::operatorFieldTypes::ScalarVariant data_type) { + fbb_.AddElement(ScalarUnionData::VT_DATA_TYPE, static_cast(data_type), 0); + } + void add_data(::flatbuffers::Offset data) { + fbb_.AddOffset(ScalarUnionData::VT_DATA, data); + } + explicit ScalarUnionDataBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateScalarUnionData( + ::flatbuffers::FlatBufferBuilder &_fbb, + dml::ir::operatorFieldTypes::ScalarVariant data_type = dml::ir::operatorFieldTypes::ScalarVariant_NONE, + ::flatbuffers::Offset data = 0) { + ScalarUnionDataBuilder builder_(_fbb); + builder_.add_data(data); + builder_.add_data_type(data_type); + return builder_.Finish(); +} + +inline bool VerifyAttributeFieldVariant(::flatbuffers::Verifier &verifier, const void *obj, AttributeFieldVariant type) { + switch (type) { + case AttributeFieldVariant_NONE: { + return true; + } + case AttributeFieldVariant_Activation: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case AttributeFieldVariant_ActivationArray: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case AttributeFieldVariant_UInt32: { + return verifier.VerifyField(static_cast(obj), 0, 4); + } + case AttributeFieldVariant_UInt64: { + return verifier.VerifyField(static_cast(obj), 0, 
8); + } + case AttributeFieldVariant_Int32: { + return verifier.VerifyField(static_cast(obj), 0, 4); + } + case AttributeFieldVariant_Float32: { + return verifier.VerifyField(static_cast(obj), 0, 4); + } + case AttributeFieldVariant_UIntArray: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case AttributeFieldVariant_IntArray: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case AttributeFieldVariant_FloatArray: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case AttributeFieldVariant_ScaleBias: { + return verifier.VerifyField(static_cast(obj), 0, 4); + } + case AttributeFieldVariant_Size2D: { + return verifier.VerifyField(static_cast(obj), 0, 4); + } + case AttributeFieldVariant_ScalarUnionData: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case AttributeFieldVariant_Bool: { + return verifier.VerifyField(static_cast(obj), 0, 1); + } + default: return true; + } +} + +inline bool VerifyAttributeFieldVariantVector(::flatbuffers::Verifier &verifier, const ::flatbuffers::Vector<::flatbuffers::Offset> *values, const ::flatbuffers::Vector *types) { + if (!values || !types) return !values && !types; + if (values->size() != types->size()) return false; + for (::flatbuffers::uoffset_t i = 0; i < values->size(); ++i) { + if (!VerifyAttributeFieldVariant( + verifier, values->Get(i), types->GetEnum(i))) { + return false; + } + } + return true; +} + +inline bool VerifyScalarVariant(::flatbuffers::Verifier &verifier, const void *obj, ScalarVariant type) { + switch (type) { + case ScalarVariant_NONE: { + return true; + } + case ScalarVariant_ByteArray: { + return verifier.VerifyField(static_cast(obj), 0, 1); + } + case ScalarVariant_Int8: { + return verifier.VerifyField(static_cast(obj), 0, 1); + } + case ScalarVariant_UInt8: { + return verifier.VerifyField(static_cast(obj), 0, 1); + } + case ScalarVariant_Int16: { + return 
verifier.VerifyField(static_cast(obj), 0, 2); + } + case ScalarVariant_UInt16: { + return verifier.VerifyField(static_cast(obj), 0, 2); + } + case ScalarVariant_Int32: { + return verifier.VerifyField(static_cast(obj), 0, 4); + } + case ScalarVariant_UInt32: { + return verifier.VerifyField(static_cast(obj), 0, 4); + } + case ScalarVariant_Int64: { + return verifier.VerifyField(static_cast(obj), 0, 8); + } + case ScalarVariant_UInt64: { + return verifier.VerifyField(static_cast(obj), 0, 8); + } + case ScalarVariant_Float32: { + return verifier.VerifyField(static_cast(obj), 0, 4); + } + case ScalarVariant_Float64: { + return verifier.VerifyField(static_cast(obj), 0, 8); + } + default: return true; + } +} + +inline bool VerifyScalarVariantVector(::flatbuffers::Verifier &verifier, const ::flatbuffers::Vector<::flatbuffers::Offset> *values, const ::flatbuffers::Vector *types) { + if (!values || !types) return !values && !types; + if (values->size() != types->size()) return false; + for (::flatbuffers::uoffset_t i = 0; i < values->size(); ++i) { + if (!VerifyScalarVariant( + verifier, values->Get(i), types->GetEnum(i))) { + return false; + } + } + return true; +} + +} // namespace operatorFieldTypes +} // namespace ir +} // namespace dml + +#endif // FLATBUFFERS_GENERATED_OPERATORFIELDTYPES_DML_IR_OPERATORFIELDTYPES_H_ diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/SchemaHelpers.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/SchemaHelpers.h index 5285481485184..1bc694dfe90c2 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/SchemaHelpers.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/SchemaHelpers.h @@ -26,14 +26,14 @@ namespace SchemaHelpers return field; } - inline OperatorFieldTypes::OperatorDesc ToOperatorFieldType(const DML_OPERATOR_DESC* value) + inline 
OperatorFieldTypes::FusedActivationOperatorDesc ToOperatorFieldType(const DML_OPERATOR_DESC* value) { - return value ? OperatorFieldTypes::OperatorDesc(ConvertOperatorDesc(*value)) : std::nullopt; + return value ? OperatorFieldTypes::FusedActivationOperatorDesc(ConvertOperatorDesc(*value)) : std::nullopt; } - inline OperatorFieldTypes::OperatorDescArray ToOperatorFieldType(const DML_OPERATOR_DESC* values, uint32_t count) + inline OperatorFieldTypes::FusedActivationOperatorDescArray ToOperatorFieldType(const DML_OPERATOR_DESC* values, uint32_t count) { - OperatorFieldTypes::OperatorDescArray field; + OperatorFieldTypes::FusedActivationOperatorDescArray field; if (values && count != 0) { field.emplace(count); @@ -65,13 +65,17 @@ namespace SchemaHelpers return value; } + inline OperatorFieldTypes::Bool ToOperatorFieldType(bool value) + { + return value; + } + inline OperatorFieldTypes::UIntArray ToOperatorFieldType(const uint32_t* values, uint32_t count) { OperatorFieldTypes::UIntArray field; if (values && count != 0) { - field.emplace(count); - std::copy_n(values, count, field->begin()); + field.assign(values, values + count); } return field; } @@ -81,8 +85,7 @@ namespace SchemaHelpers OperatorFieldTypes::IntArray field; if (values && count != 0) { - field.emplace(count); - std::copy_n(values, count, field->begin()); + field.assign(values, values + count); } return field; } @@ -92,8 +95,7 @@ namespace SchemaHelpers OperatorFieldTypes::FloatArray field; if (values && count != 0) { - field.emplace(count); - std::copy_n(values, count, field->begin()); + field.assign(values, values + count); } return field; } @@ -237,7 +239,7 @@ namespace SchemaHelpers { DML_OPERATOR_DESC* desc = nullptr; - const auto& value = field.AsOperatorDesc(); + const auto& value = field.AsFusedActivationOperatorDesc(); if (value) { desc = allocator->template Allocate(); @@ -251,7 +253,7 @@ namespace SchemaHelpers { DML_OPERATOR_DESC* descs = nullptr; - const auto& values = 
field.AsOperatorDescArray(); + const auto& values = field.AsFusedActivationOperatorDescArray(); if (values) { descs = allocator->template Allocate(values->size()); @@ -288,16 +290,20 @@ namespace SchemaHelpers dst->Write(value); } break; + case DML_SCHEMA_FIELD_TYPE_BOOL: + { + // OperatorFieldTypes::Bool is a 'bool' (1 byte) but written as 'BOOL' in op descs (4 bytes). + BOOL value = static_cast(field.AsBool()); + dst->Write(value); + } break; + case DML_SCHEMA_FIELD_TYPE_UINT_ARRAY: { uint32_t* arrayPtr = nullptr; const auto& values = field.AsUIntArray(); - if (values) - { - arrayPtr = allocator->template Allocate(values->size()); - std::copy(values->begin(), values->end(), arrayPtr); - } + arrayPtr = allocator->template Allocate(values.size()); + std::copy(values.begin(), values.end(), arrayPtr); dst->Write(arrayPtr); } break; @@ -307,11 +313,8 @@ namespace SchemaHelpers int32_t* arrayPtr = nullptr; const auto& values = field.AsIntArray(); - if (values) - { - arrayPtr = allocator->template Allocate(values->size()); - std::copy(values->begin(), values->end(), arrayPtr); - } + arrayPtr = allocator->template Allocate(values.size()); + std::copy(values.begin(), values.end(), arrayPtr); dst->Write(arrayPtr); } break; @@ -321,11 +324,8 @@ namespace SchemaHelpers float* arrayPtr = nullptr; const auto& values = field.AsFloatArray(); - if (values) - { - arrayPtr = allocator->template Allocate(values->size()); - std::copy(values->begin(), values->end(), arrayPtr); - } + arrayPtr = allocator->template Allocate(values.size()); + std::copy(values.begin(), values.end(), arrayPtr); dst->Write(arrayPtr); } break; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp index c6a15e76f4736..e6f008af5c23f 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp +++ 
b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp @@ -33,10 +33,10 @@ namespace Dml::GraphDescBuilder #pragma warning(pop) static void RemoveUnconnectedNodes( - std::vector& graphNodes, - std::vector& graphInputEdges, - std::vector& graphIntermediateEdges, - std::vector& graphOutputEdges) + std::vector& graphNodes, + std::vector& graphInputEdges, + std::vector& graphIntermediateEdges, + std::vector& graphOutputEdges) { enum class NodeState { @@ -52,7 +52,7 @@ namespace Dml::GraphDescBuilder }; std::vector nodesData(graphNodes.size()); - for (const DML_INTERMEDIATE_GRAPH_EDGE_DESC& intermediateEdge : graphIntermediateEdges) + for (const DmlIntermediateSerializedGraphEdge& intermediateEdge : graphIntermediateEdges) { nodesData[intermediateEdge.ToNodeIndex].predecessorIndices.push_back(intermediateEdge.FromNodeIndex); } @@ -60,7 +60,7 @@ namespace Dml::GraphDescBuilder std::stack nodeIndicesToVisit; // Start from the outputs of the graph and traverse upwards - for (const DML_OUTPUT_GRAPH_EDGE_DESC& outputEdge : graphOutputEdges) + for (const DmlOutputSerializedGraphEdge& outputEdge : graphOutputEdges) { nodeIndicesToVisit.push(outputEdge.FromNodeIndex); } @@ -143,17 +143,44 @@ namespace Dml::GraphDescBuilder } } + + uint32_t SetAndGetDmlGraphNodeIndex( + const uint32_t operatorDmlGraphNodeIndex, + const std::string& nodeNamePrefix, + AbstractOperatorDesc& operatorDesc, + /*in_out*/std::unordered_map& operatorDmlGraphToDmlGraphNodeIndexMap, + /*in_out*/std::vector& dmlGraphNodes) + { + auto iter = operatorDmlGraphToDmlGraphNodeIndexMap.find(operatorDmlGraphNodeIndex); + if (iter != operatorDmlGraphToDmlGraphNodeIndexMap.end()) + { + return iter->second; + } + operatorDmlGraphToDmlGraphNodeIndexMap[operatorDmlGraphNodeIndex] = static_cast(dmlGraphNodes.size()); + dmlGraphNodes.push_back({operatorDesc, nodeNamePrefix + std::to_string(operatorDmlGraphNodeIndex)}); + return operatorDmlGraphToDmlGraphNodeIndexMap[operatorDmlGraphNodeIndex]; + } 
+ + // Terminology: + // Subgraph: partitioned ONNX graph from the original (main) ONNX graph + // DmlGraph: a graph in DML currency converted from subgraph. + // operatorDmlGraph: a graph in DML currency for a given node or operator + // Main Points to note: + // - GraphDesc will always has sequential indices for input and intermediate edges. + // - 1 onnx node can be converted to one or more dml nodes. GraphDesc BuildGraphDesc( const uint8_t* isConstGpuGraphInput, const size_t isConstGpuGraphInputCount, const std::unordered_map>& isInitializerTransferable, const std::unordered_map& graphNodePropertyMap, - IDMLDevice* device, const ExecutionProviderImpl* executionHandle, const onnxruntime::Path& modelPath, gsl::span subgraphNodes, gsl::span subgraphInputs, - gsl::span subgraphOutputs) + gsl::span subgraphOutputs, + /*out*/ std::unordered_map& serializedGraphInputIndexToSubgraphInputIndex, + /*out*/ std::unordered_map& serializedGraphLargeConstantNameToSubgraphInputIndex, + /*out*/ std::vector>& smallConstantData) { struct NodeAndIndex { @@ -161,19 +188,34 @@ namespace Dml::GraphDescBuilder uint32_t targetIndex; // The index of the input/output on the node (e.g. 1 for the second input on a node) }; - // Map from Lotus node argument names to the new node and index where it will be produced - std::unordered_map nameToNodeAndIndexMap; - std::unordered_map nodeOutputShapes; - // Map from Lotus node argument names to input indices of the fused kernel node. - std::unordered_map nameToDmlFusedNodeInputIndex; + // Map from ORT subgraph input names to indices + std::unordered_map subgraphInputNameToIndexMap; + + // - Map from ORT node's output names to DmlGraph . + // - Once a given ORT node (or operator) will be transformed into a operatorDmlGraph, + // then ORT node's output names will become output edges for the operatorDmlGraph. + // - This map will be populated for those output edges. 
+ std::unordered_map dmlGraphNodeOutputNameToNodeAndIndexMap; + + // This map will be used to re-index an subGraphInputIndex to sequential input index + // for DmlGraph + std::unordered_map subGraphInputIndexToDmlGraphInputIndex; + + // Iterate through each node and create a corresponding node in the new graph + // We can iterate the nodes in any order because the edge connectivity will take care of the topological order + std::unordered_map> inferredOutputShapes; + + std::vector dmlGraphNodes; + std::vector dmlGraphInputEdges; + std::vector dmlGraphIntermediateEdges; + std::vector dmlGraphOutputEdges; for (size_t inputIndex = 0; inputIndex < subgraphInputs.size(); ++inputIndex) { - const onnxruntime::NodeArg* graphInput = subgraphInputs[inputIndex]; - - if (!graphInput) + const onnxruntime::NodeArg* subgraphInput = subgraphInputs[inputIndex]; + if (!subgraphInput) { // This is a workaround for when node inputs get manipulated by transformers outside of our control, // which then causes them to have a different name. If that happens we can't figure out how to @@ -181,45 +223,21 @@ namespace Dml::GraphDescBuilder // just bail early. ORT_THROW_HR(E_UNEXPECTED); } - - nameToDmlFusedNodeInputIndex.emplace(graphInput->Name(), gsl::narrow_cast(inputIndex)); - } - - StackAllocator<1024> allocator; // Used for converting abstract operator descs into DML_OPERATOR_DESC - - std::vector graphNodes; - std::vector graphInputEdges; - std::vector graphIntermediateEdges; - std::vector graphOutputEdges; - - // Avoid using separate command lists for small graphs. This value can be reduced by tuning the - // flushing behavior of DmlCommandRecorder. Its current behavior is to assume that graphs contain - // enough GPU work to be worth flushing immediately. 
- const uint32_t minNodeCountToReuseCommandList = 5; - bool reuseCommandList = false; - - if (subgraphNodes.size() >= minNodeCountToReuseCommandList || executionHandle->IsMcdmDevice()) - { - reuseCommandList = true; + subgraphInputNameToIndexMap.emplace(subgraphInput->Name(), gsl::narrow_cast(inputIndex)); } auto constantCpuGraphInputGetter = [&isInitializerTransferable, &modelPath](const std::string& argName) { ComPtr tensorWrapper; - auto iter = isInitializerTransferable.find(argName); if (iter != isInitializerTransferable.end()) { // Using const_cast here is simpler than making surrounding code const correct. tensorWrapper = wil::MakeOrThrow(const_cast(iter->second.first), modelPath); } - return tensorWrapper; }; - // Iterate through each node and create a corresponding node in the new graph - // We can iterate the nodes in any order because the edge connectivity will take care of the topological order - std::unordered_map> inferredOutputShapes; for (const onnxruntime::Node* subgraphNode : subgraphNodes) { @@ -277,190 +295,206 @@ namespace Dml::GraphDescBuilder } EdgeShapes outputShapes; - DmlGraphNodeCreateInfo graphNodeCreateInfo; + DmlGraphNodeCreateInfo operatorDmlGraphCreateInfo; graphNodeProps.internalRegInfo->graphNodeFactoryRegistration->factory( node, constantCpuNodeInputGetter, executionHandle, &inputShapesOverrides, /*out*/ &outputShapes, - /*out*/ &graphNodeCreateInfo + /*out*/ &operatorDmlGraphCreateInfo ); ORT_THROW_HR_IF(E_UNEXPECTED, outputShapes.EdgeCount() != node.OutputDefs().size()); for (int i = 0; i < node.OutputDefs().size(); ++i) { inferredOutputShapes[node.OutputDefs()[i]->Name()] = outputShapes.GetShape(i); - } - - // Create a map between operatorGraphNodeIndex to mainGraphNodeIndex. 
- std::unordered_map operatorGraphNodeIndexToMainGraphNodeIndexMap; - uint32_t graphNodeCount = gsl::narrow_cast(graphNodes.size()); - const bool isNodeAsOpDesc = graphNodeCreateInfo.nodesAsOperatorDesc.size() > 0; - size_t firstOpDescGraphNodeIndex = graphNodes.size(); - - if (isNodeAsOpDesc) + } + + // Algorithm: + // 1. Create constant nodes by iterating through operatorDmlGraph's input edges and keep a map of it, + // because there would be an intermediate edge from the constantNode and source of the intermediate edge + // should come before the destination. + // 2. Again iterate through operatorDmlGraph's input edges to create mainGraph's input and intermediate edges. + // 3. Iterate through operatorDmlGraph's intermediate edges to create mainGraph's intermediate edges. + // 4. Iterate through operatorDmlGraph's output edges to populate outputEdgeNameToDmlGraphNodeAndIndex + // 5. While performing step 2, 3, and 4, insert operatorDmlGraphNode to the mainDmlGraphNode list. + + for (auto& operatorDmlGraphInputEdge : operatorDmlGraphCreateInfo.inputEdges) { - // Can't populate graphNodes vector at this point, because operatorDesc may get modified later. - for (uint32_t nodeIndex = 0; nodeIndex < graphNodeCreateInfo.nodeCount; nodeIndex++) + const onnxruntime::NodeArg* arg = node.InputDefs()[operatorDmlGraphInputEdge.GraphInputIndex]; + if (arg->Exists()) { - ORT_THROW_HR_IF(E_UNEXPECTED, !graphNodeCreateInfo.nodesAsOperatorDesc[nodeIndex]); - operatorGraphNodeIndexToMainGraphNodeIndexMap.emplace(nodeIndex, graphNodeCount++); - } + auto iter = subgraphInputNameToIndexMap.find(arg->Name()); + if (iter != subgraphInputNameToIndexMap.end() && + iter->second < isConstGpuGraphInputCount && + isConstGpuGraphInput[iter->second]) + { + DmlSerializedGraphNode constantNode = {}; + constantNode.Name = arg->Name(); + + // This is a highly inefficient approach to generating constant nodes. 
It duplicates constant data + // across the graph input as well as every consumer's unique constant node. However it is currently + // only used for small inputs. + auto& operatorDmlGraphInputNode = operatorDmlGraphCreateInfo.nodes[operatorDmlGraphInputEdge.ToNodeIndex]; + std::vector toNodeInputTensorDescs = operatorDmlGraphInputNode->GetInputTensors(); + DmlBufferTensorDesc* tensorDesc = toNodeInputTensorDescs[operatorDmlGraphInputEdge.ToNodeInputIndex]; + ComPtr constantInput; + + if (tensorDesc->totalTensorSizeInBytes < c_maxConstNodeDataSize) + { + constantInput = constantCpuGraphInputGetter(arg->Name()); + } - graphNodes.resize(graphNodes.size() + graphNodeCreateInfo.nodeCount); - } - else - { - for (uint32_t nodeIndex = 0; nodeIndex < graphNodeCreateInfo.nodeCount; nodeIndex++) - { - ORT_THROW_HR_IF(E_UNEXPECTED, !graphNodeCreateInfo.nodesAsIDMLOperator[nodeIndex].Get()); - operatorGraphNodeIndexToMainGraphNodeIndexMap.emplace(nodeIndex, graphNodeCount++); - NodeInfo nodeInfo = {}; - nodeInfo.nodeDef = std::move(graphNodeCreateInfo.nodesAsIDMLOperator[nodeIndex]); - graphNodes.push_back(std::move(nodeInfo)); + if (constantInput) + { + // The tensor description's size should be no larger than the constant input unless it was rounded to + // the required alignment. 
+ assert(((constantInput->GetTensorByteSize() + 3) & ~3) >= tensorDesc->totalTensorSizeInBytes); + size_t minimumConstantSize = std::min(constantInput->GetTensorByteSize(), gsl::narrow_cast(tensorDesc->totalTensorSizeInBytes)); + auto data = static_cast(constantInput->GetData()); + std::vector tensorData(data, data + minimumConstantSize); + + smallConstantData.push_back(std::make_unique(tensorData.size())); + std::transform(tensorData.begin(), tensorData.end(), smallConstantData.back().get(), [](uint8_t b) {return static_cast(b);}); + + ConstantData constantData = {smallConstantData.back().get(), tensorData.size()}; + constantNode.Desc = constantData; + } + else + { + ConstantName constantFileName = {GetSanitizedFileName(arg->Name())}; + constantNode.Desc = constantFileName; + } + dmlGraphNodeOutputNameToNodeAndIndexMap[arg->Name()] = {static_cast(dmlGraphNodes.size()), 0}; + dmlGraphNodes.push_back(constantNode); + } } } - // map operatorGraphInputEdge as either mainGraphInputEdge or mainGraphIntermediateEdge - for (auto& operatorGraphInputEdge : graphNodeCreateInfo.inputEdges) - { - // operatorGraphInputEdge.GraphInputIndex will be the ONNX input index. - const onnxruntime::NodeArg* arg = node.InputDefs()[operatorGraphInputEdge.GraphInputIndex]; + // Create a map between operatorGraphNodeIndex to dmlGraphNodeIndex. + std::unordered_map operatorDmlGraphToDmlGraphNodeIndexMap; + // map operatorDmlGraphInputEdge as either mainDmlGraphInputEdge or mainDmlGraphIntermediateEdge + for (auto& operatorDmlGraphInputEdge : operatorDmlGraphCreateInfo.inputEdges) + { + // operatorDmlGraphInputEdge.GraphInputIndex will be the ONNX input index. 
+ const onnxruntime::NodeArg* arg = node.InputDefs()[operatorDmlGraphInputEdge.GraphInputIndex]; if (arg->Exists()) { - auto iter = nameToDmlFusedNodeInputIndex.find(arg->Name()); - uint32_t mainGraphNodeIndex = operatorGraphNodeIndexToMainGraphNodeIndexMap[operatorGraphInputEdge.ToNodeIndex]; - - if (iter != nameToDmlFusedNodeInputIndex.end()) + uint32_t dmlGraphNodeIndex = SetAndGetDmlGraphNodeIndex( + operatorDmlGraphInputEdge.ToNodeIndex, + node.Name(), + *operatorDmlGraphCreateInfo.nodes[operatorDmlGraphInputEdge.ToNodeIndex], + operatorDmlGraphToDmlGraphNodeIndexMap, + dmlGraphNodes); + + auto iter = subgraphInputNameToIndexMap.find(arg->Name()); + if (iter != subgraphInputNameToIndexMap.end()) { - // This is a graph input - - const uint32_t dmlFusedNodeInputIndex = iter->second; - - // If this is a constant input, set the appropriate flags on the desc - if (isNodeAsOpDesc && - dmlFusedNodeInputIndex < isConstGpuGraphInputCount && - isConstGpuGraphInput[dmlFusedNodeInputIndex]) + const uint32_t subgraphInputIndex = iter->second; + + // Either this edge will be + // a constant input, then it will be an intermediate edge and + // set the OWNED_BY_DML flag if it is large constant + // or, + // a non-constant input, then it will be a mainDmlGraphInputEdge. + if (subgraphInputIndex < isConstGpuGraphInputCount && + isConstGpuGraphInput[subgraphInputIndex]) { - // This is a highly inefficient approach to generating constant nodes. It duplicates constant data - // across the graph input as well as every consumer's unique constant node. However it is currently - // only used for small inputs. 
- uint32_t c_maxConstNodeDataSize = 8; - - ComPtr constantInput = constantCpuGraphInputGetter(arg->Name()); - - auto& operatorGraphInputNode = graphNodeCreateInfo.nodesAsOperatorDesc[operatorGraphInputEdge.ToNodeIndex]; - std::vector toNodeInputTensorDescs = operatorGraphInputNode->GetInputTensors(); - DmlBufferTensorDesc* tensorDesc = toNodeInputTensorDescs[operatorGraphInputEdge.ToNodeInputIndex]; - - if (constantInput && tensorDesc->totalTensorSizeInBytes < c_maxConstNodeDataSize) - { - // The tensor description's size should be no larger than the constant input unless it was rounded to - // the required alignment. - assert(((constantInput->GetTensorByteSize() + 3) & ~3) >= tensorDesc->totalTensorSizeInBytes); - size_t minimumConstantSize = std::min(constantInput->GetTensorByteSize(), gsl::narrow_cast(tensorDesc->totalTensorSizeInBytes)); - auto data = static_cast(constantInput->GetData()); - std::vector tensorData(data, data + minimumConstantSize); - - NodeInfo nodeInfo = {}; - nodeInfo.nodeDef = std::move(tensorData); - graphNodes.push_back(std::move(nodeInfo)); - - DML_INTERMEDIATE_GRAPH_EDGE_DESC edge = {}; - edge.FromNodeIndex = static_cast(graphNodes.size() - 1); - edge.FromNodeOutputIndex = 0; - edge.ToNodeIndex = mainGraphNodeIndex; - edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex; - graphIntermediateEdges.push_back(edge); - } - else + const auto& constantNodeAndIndex = dmlGraphNodeOutputNameToNodeAndIndexMap.at(arg->Name()); + auto& constantNodeVariant = std::get(dmlGraphNodes[constantNodeAndIndex.nodeIndex].Desc); + if (std::holds_alternative(constantNodeVariant)) { - DML_INPUT_GRAPH_EDGE_DESC edge = {}; - edge.GraphInputIndex = dmlFusedNodeInputIndex; - edge.ToNodeIndex = mainGraphNodeIndex; - edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex; - graphInputEdges.push_back(edge); - + auto& mainDmlGraphNode = dmlGraphNodes[dmlGraphNodeIndex]; + AbstractOperatorDesc& abstractOperatorDesc = std::get(mainDmlGraphNode.Desc); + 
std::vector toNodeInputTensorDescs = abstractOperatorDesc.GetInputTensors(); + DmlBufferTensorDesc* tensorDesc = toNodeInputTensorDescs[operatorDmlGraphInputEdge.ToNodeInputIndex]; tensorDesc->flags |= DML_TENSOR_FLAG_OWNED_BY_DML; + serializedGraphLargeConstantNameToSubgraphInputIndex[arg->Name()] = subgraphInputIndex; } + + DmlIntermediateSerializedGraphEdge edge = {}; + edge.FromNodeIndex = constantNodeAndIndex.nodeIndex; + edge.FromNodeOutputIndex = constantNodeAndIndex.targetIndex; + edge.ToNodeIndex = dmlGraphNodeIndex; + edge.ToNodeInputIndex = operatorDmlGraphInputEdge.ToNodeInputIndex; + edge.Name = arg->Name() + "-nodeIdx:" + std::to_string(edge.FromNodeIndex) + "-outputIdx:" + std::to_string(edge.FromNodeOutputIndex); + dmlGraphIntermediateEdges.push_back(edge); } else { - DML_INPUT_GRAPH_EDGE_DESC edge = {}; - edge.GraphInputIndex = dmlFusedNodeInputIndex; - edge.ToNodeIndex = mainGraphNodeIndex; - edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex; - graphInputEdges.push_back(edge); + DmlInputSerializedGraphEdge edge = {}; + if (subGraphInputIndexToDmlGraphInputIndex.find(subgraphInputIndex) == subGraphInputIndexToDmlGraphInputIndex.end()) + { + subGraphInputIndexToDmlGraphInputIndex[subgraphInputIndex] = static_cast(subGraphInputIndexToDmlGraphInputIndex.size()); + } + + edge.GraphInputIndex = subGraphInputIndexToDmlGraphInputIndex[subgraphInputIndex]; + edge.ToNodeIndex = dmlGraphNodeIndex; + edge.ToNodeInputIndex = operatorDmlGraphInputEdge.ToNodeInputIndex; // ?? 
might need to point inputIndex + edge.Name = arg->Name(); + + serializedGraphInputIndexToSubgraphInputIndex[edge.GraphInputIndex] = subgraphInputIndex; + dmlGraphInputEdges.push_back(edge); } } else { - const auto& inputNodeAndIndex = nameToNodeAndIndexMap.at(arg->Name()); + const auto& inputNodeAndIndex = dmlGraphNodeOutputNameToNodeAndIndexMap.at(arg->Name()); - DML_INTERMEDIATE_GRAPH_EDGE_DESC edge = {}; + DmlIntermediateSerializedGraphEdge edge = {}; edge.FromNodeIndex = inputNodeAndIndex.nodeIndex; edge.FromNodeOutputIndex = inputNodeAndIndex.targetIndex; - edge.ToNodeIndex = mainGraphNodeIndex; - edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex; - graphIntermediateEdges.push_back(edge); + edge.ToNodeIndex = dmlGraphNodeIndex; + edge.ToNodeInputIndex = operatorDmlGraphInputEdge.ToNodeInputIndex; + edge.Name = arg->Name(); + dmlGraphIntermediateEdges.push_back(edge); } } } // map operatorGraphIntermediateEdges as mainGraphIntermediateEdge - for (auto& operatorGraphIntermediateEdge : graphNodeCreateInfo.intermediateEdges) + for (auto& operatorGraphIntermediateEdge : operatorDmlGraphCreateInfo.intermediateEdges) { - DML_INTERMEDIATE_GRAPH_EDGE_DESC edge = {}; - edge.FromNodeIndex = operatorGraphNodeIndexToMainGraphNodeIndexMap[operatorGraphIntermediateEdge.FromNodeIndex]; + DmlIntermediateSerializedGraphEdge edge = {}; + uint32_t shiftedFromNodeIndex = SetAndGetDmlGraphNodeIndex( + operatorGraphIntermediateEdge.FromNodeIndex, + node.Name(), + *operatorDmlGraphCreateInfo.nodes[operatorGraphIntermediateEdge.FromNodeIndex], + operatorDmlGraphToDmlGraphNodeIndexMap, + dmlGraphNodes); + uint32_t shiftedToNodeIndex = SetAndGetDmlGraphNodeIndex( + operatorGraphIntermediateEdge.ToNodeIndex, + node.Name(), + *operatorDmlGraphCreateInfo.nodes[operatorGraphIntermediateEdge.ToNodeIndex], + operatorDmlGraphToDmlGraphNodeIndexMap, + dmlGraphNodes); + + edge.FromNodeIndex = shiftedFromNodeIndex; edge.FromNodeOutputIndex = 
operatorGraphIntermediateEdge.FromNodeOutputIndex; - edge.ToNodeIndex = operatorGraphNodeIndexToMainGraphNodeIndexMap[operatorGraphIntermediateEdge.ToNodeIndex]; + edge.ToNodeIndex = shiftedToNodeIndex; edge.ToNodeInputIndex = operatorGraphIntermediateEdge.ToNodeInputIndex; - graphIntermediateEdges.push_back(edge); + edge.Name = "nodeIdx:" + std::to_string(shiftedFromNodeIndex) + "-outputIdx:" + std::to_string(operatorGraphIntermediateEdge.FromNodeOutputIndex); + dmlGraphIntermediateEdges.push_back(edge); } - + // populate nameToNodeAndIndexMap (which will be used by above loop) for operatorGraphOutputEdges - for (auto& operatorGraphOutputEdge : graphNodeCreateInfo.outputEdges) + for (auto& operatorGraphOutputEdge : operatorDmlGraphCreateInfo.outputEdges) { const onnxruntime::NodeArg* arg = node.OutputDefs()[operatorGraphOutputEdge.GraphOutputIndex]; if (arg->Exists()) { - nameToNodeAndIndexMap[arg->Name()] = NodeAndIndex { - operatorGraphNodeIndexToMainGraphNodeIndexMap[operatorGraphOutputEdge.FromNodeIndex], - operatorGraphOutputEdge.FromNodeOutputIndex - }; - + uint32_t shiftedNodeIndex = SetAndGetDmlGraphNodeIndex( + operatorGraphOutputEdge.FromNodeIndex, + node.Name(), + *operatorDmlGraphCreateInfo.nodes[operatorGraphOutputEdge.FromNodeIndex], + operatorDmlGraphToDmlGraphNodeIndexMap, + dmlGraphNodes); + dmlGraphNodeOutputNameToNodeAndIndexMap[arg->Name()] = {shiftedNodeIndex, operatorGraphOutputEdge.FromNodeOutputIndex}; nodeOutputShapes[arg->Name()] = outputShapes; } } - - if (isNodeAsOpDesc) - { - for (size_t i = 0; i < graphNodeCreateInfo.nodesAsOperatorDesc.size(); ++i) - { - auto& opDesc = graphNodeCreateInfo.nodesAsOperatorDesc[i]; - - DML_OPERATOR_DESC dmlDesc = SchemaHelpers::ConvertOperatorDesc(*opDesc, &allocator); - - // TODO: Change as new header is ingested - if (dmlDesc.Type == (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING) - dmlDesc.Type = (DML_OPERATOR_TYPE) 169; - - // TODO: Change as new header is ingested - if 
(dmlDesc.Type == (DML_OPERATOR_TYPE) DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT) - dmlDesc.Type = (DML_OPERATOR_TYPE) 170; - - ComPtr op; - ORT_THROW_IF_FAILED(device->CreateOperator(&dmlDesc, IID_PPV_ARGS(&op))); - allocator.Reset(); - - NodeInfo nodeInfo = {}; - nodeInfo.nodeDef = std::move(op); - nodeInfo.name = node.Name(); - graphNodes[firstOpDescGraphNodeIndex + i] = std::move(nodeInfo); - } - } } EdgeShapes graphOutputShapes(subgraphOutputs.size()); @@ -471,24 +505,27 @@ namespace Dml::GraphDescBuilder const onnxruntime::NodeArg* graphOutput = subgraphOutputs[outputIndex]; ORT_THROW_HR_IF_NULL_MSG(E_POINTER, graphOutput, "FusedNode's nodeArgList does not contain one of the nodeArg"); - const auto& outputNodeAndIndex = nameToNodeAndIndexMap.at(graphOutput->Name()); + const auto& outputNodeAndIndex = dmlGraphNodeOutputNameToNodeAndIndexMap.at(graphOutput->Name()); - DML_OUTPUT_GRAPH_EDGE_DESC edge = {}; + DmlOutputSerializedGraphEdge edge = {}; edge.FromNodeIndex = outputNodeAndIndex.nodeIndex; edge.FromNodeOutputIndex = outputNodeAndIndex.targetIndex; edge.GraphOutputIndex = gsl::narrow_cast(outputIndex); - graphOutputEdges.push_back(edge); + edge.Name = graphOutput->Name(); + dmlGraphOutputEdges.push_back(edge); graphOutputShapes.GetMutableShape(outputIndex) = nodeOutputShapes[graphOutput->Name()].GetShape(outputNodeAndIndex.targetIndex); } - RemoveUnconnectedNodes(graphNodes, graphInputEdges, graphIntermediateEdges, graphOutputEdges); + RemoveUnconnectedNodes(dmlGraphNodes, dmlGraphInputEdges, dmlGraphIntermediateEdges, dmlGraphOutputEdges); GraphDesc graphDesc{}; - graphDesc.nodes = std::move(graphNodes); - graphDesc.inputEdges = std::move(graphInputEdges); - graphDesc.outputEdges = std::move(graphOutputEdges); - graphDesc.intermediateEdges = std::move(graphIntermediateEdges); - graphDesc.reuseCommandList = reuseCommandList; + graphDesc.InputCount = static_cast(dmlGraphInputEdges.size()); + graphDesc.OutputCount = static_cast(subgraphOutputs.size()); + 
graphDesc.Nodes = std::move(dmlGraphNodes); + graphDesc.InputEdges = std::move(dmlGraphInputEdges); + graphDesc.OutputEdges = std::move(dmlGraphOutputEdges); + graphDesc.IntermediateEdges = std::move(dmlGraphIntermediateEdges); + graphDesc.reuseCommandList = (subgraphNodes.size() >= minNodeCountToReuseCommandList || executionHandle->IsMcdmDevice()); graphDesc.outputShapes = std::move(graphOutputShapes); return graphDesc; } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h index c95e89b45541b..4055984b40405 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h @@ -22,22 +22,15 @@ namespace Dml namespace GraphDescBuilder { + constexpr uint32_t minNodeCountToReuseCommandList = 5; + constexpr uint32_t c_maxConstNodeDataSize = 8; + // Gets a unique name for the node which survives recreation and graph manipulations between the point // that graph partitioning occurs and kernel creation happens const std::string& GetUniqueNodeName(const onnxruntime::Node& node); - struct NodeInfo - { - std::variant, std::vector> nodeDef; - std::string name; - }; - - struct GraphDesc + struct GraphDesc : DmlSerializedGraphDesc { - std::vector nodes; - std::vector inputEdges; - std::vector outputEdges; - std::vector intermediateEdges; bool reuseCommandList; Windows::AI::MachineLearning::Adapter::EdgeShapes outputShapes; }; @@ -47,11 +40,13 @@ namespace Dml const size_t isConstGpuGraphInputCount, const std::unordered_map>& isInitializerTransferable, const std::unordered_map& graphNodePropertyMap, - IDMLDevice* device, const ExecutionProviderImpl* executionHandle, const onnxruntime::Path& modelPath, gsl::span subgraphNodes, gsl::span subgraphInputs, - gsl::span subgraphOutputs); + gsl::span subgraphOutputs, + /*out*/ std::unordered_map& 
serializedGraphInputIndexToSubgraphInputIndex, + /*out*/ std::unordered_map& serializedGraphLargeConstantNameToSubgraphInputIndex, + /*out*/ std::vector>& smallConstantData); } } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp index dbd06abf82f72..f29fbc7a1a65b 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp @@ -1123,7 +1123,7 @@ namespace Windows::AI::MachineLearning::Adapter } ORT_CATCH_RETURN } - + template HRESULT STDMETHODCALLTYPE OpNodeInfoWrapper::GetConstantInputTensor(uint32_t inputIndex, IMLOperatorTensor** tensor) const noexcept { @@ -1168,7 +1168,7 @@ namespace Windows::AI::MachineLearning::Adapter m_requiredConstantCpuInputs.begin(), m_requiredConstantCpuInputs.end(), inputIndex) != m_requiredConstantCpuInputs.end(); - + // This shouldn't happen since kernel creation is deferred and repeated when required constant inputs are not present. ORT_THROW_HR_IF(E_UNEXPECTED, inputRequiredAsConstant); } @@ -1508,31 +1508,17 @@ namespace Windows::AI::MachineLearning::Adapter ORT_TRY { assert(operatorGraphDesc != nullptr); - // Either nodesAsOpDesc or nodesIDMLOperator can be present. 
- assert(operatorGraphDesc->nodeCount == 0 || (!operatorGraphDesc->nodesAsOpDesc ^ !operatorGraphDesc->nodesAsIDMLOperator)); + assert(operatorGraphDesc->nodeCount == 0 || operatorGraphDesc->nodes); - if (operatorGraphDesc->nodesAsOpDesc) - { - m_graphNodeCreateInfo->nodesAsOperatorDesc = std::vector>(); - for (uint32_t nodeIndex = 0; nodeIndex < operatorGraphDesc->nodeCount; nodeIndex++) - { - auto* node = operatorGraphDesc->nodesAsOpDesc[nodeIndex]; - assert(node != nullptr); - AbstractOperatorDesc abstractDesc = SchemaHelpers::ConvertOperatorDesc(*node); - m_graphNodeCreateInfo->nodesAsOperatorDesc.push_back(std::make_unique(std::move(abstractDesc))); - } - } - else + m_graphNodeCreateInfo->nodes = std::vector>(); + for (uint32_t nodeIndex = 0; nodeIndex < operatorGraphDesc->nodeCount; nodeIndex++) { - m_graphNodeCreateInfo->nodesAsIDMLOperator = std::vector>(); - for (uint32_t nodeIndex = 0; nodeIndex < operatorGraphDesc->nodeCount; nodeIndex++) - { - auto* node = operatorGraphDesc->nodesAsIDMLOperator[nodeIndex]; - assert(node != nullptr); - m_graphNodeCreateInfo->nodesAsIDMLOperator.push_back(node); - } + auto* node = operatorGraphDesc->nodes[nodeIndex]; + assert(node != nullptr); + AbstractOperatorDesc abstractDesc = SchemaHelpers::ConvertOperatorDesc(*node); + m_graphNodeCreateInfo->nodes.push_back(std::make_unique(std::move(abstractDesc))); } - + // There can be operators (or kernels) which don't require any input. assert(operatorGraphDesc->inputEdgeCount == 0 || operatorGraphDesc->inputEdges != nullptr); m_graphNodeCreateInfo->inputEdges.insert( @@ -1562,7 +1548,13 @@ namespace Windows::AI::MachineLearning::Adapter OnnxTensorWrapper::OnnxTensorWrapper(onnx::TensorProto* impl, const onnxruntime::Path& modelPath) : m_impl(impl) { // The tensor may be stored as raw data or in typed fields. 
- if (impl->has_raw_data()) + if (impl->data_location() == onnx::TensorProto_DataLocation_EXTERNAL) + { + THROW_IF_NOT_OK(onnxruntime::utils::UnpackInitializerData(*impl, modelPath, m_unpackedExternalTensor)); + m_dataPtr = reinterpret_cast(m_unpackedExternalTensor.data()); + m_tensorByteSize = m_unpackedExternalTensor.size(); + } + else if (impl->has_raw_data()) { m_dataPtr = reinterpret_cast(impl->mutable_raw_data()->data()); m_tensorByteSize = impl->raw_data().size(); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h index 6530d89d895e7..59e253e88457a 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h @@ -309,6 +309,7 @@ class OnnxTensorWrapper : public WRL::Base, public Closable private: size_t m_tensorByteSize = 0; std::unique_ptr m_unpackedTensor; + std::vector m_unpackedExternalTensor; std::byte* m_dataPtr = nullptr; // Lifetime is managed by the caller and guaranteed to outlive this class diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp index c3bb1a52210f5..287f1e5b6dfe7 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp @@ -53,7 +53,7 @@ namespace Dml MLOperatorGraphDesc operatorGraphDesc = {}; operatorGraphDesc.nodeCount = 1; const DML_OPERATOR_DESC* opDescs{&operatorDesc}; - operatorGraphDesc.nodesAsOpDesc = &opDescs; + operatorGraphDesc.nodes = &opDescs; std::vector inputEdges; for (uint32_t inputIndex = 0; inputIndex < m_kernelInputIndices.size(); inputIndex++) @@ -796,7 +796,7 @@ namespace Dml for (size_t i = 0; i < graphDesc.NodeCount; ++i) { // 
Create the operator. - ORT_THROW_IF_FAILED(m_dmlDevice->CreateOperator(operatorGraphDesc.nodesAsOpDesc[i], IID_PPV_ARGS(&dmlOperators[i]))); + ORT_THROW_IF_FAILED(m_dmlDevice->CreateOperator(operatorGraphDesc.nodes[i], IID_PPV_ARGS(&dmlOperators[i]))); dmlOperatorGraphNodes[i] = DML_OPERATOR_GRAPH_NODE_DESC{dmlOperators[i].Get()}; dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{DML_GRAPH_NODE_TYPE_OPERATOR, &dmlOperatorGraphNodes[i]}; } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp index c8ca6806e75f7..73c2d57e984af 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp @@ -531,7 +531,7 @@ class DmlOperatorAttention : public DmlOperator operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); operatorGraphDesc.outputEdges = outputEdges.data(); operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); - operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + operatorGraphDesc.nodes = opDescs.data(); SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBiasAdd.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBiasAdd.cpp index 1c851c94c4ddc..5aceebbdabfe3 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBiasAdd.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBiasAdd.cpp @@ -103,7 +103,7 @@ class DmlOperatorBiasAdd : public DmlOperator operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); operatorGraphDesc.outputEdges = outputEdges.data(); operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); - 
operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + operatorGraphDesc.nodes = opDescs.data(); SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext); } }; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBiasSplitGelu.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBiasSplitGelu.cpp index 501ce14f1fc08..1e10214ffd463 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBiasSplitGelu.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBiasSplitGelu.cpp @@ -137,7 +137,7 @@ class DmlOperatorBiasSplitGelu : public DmlOperator operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); operatorGraphDesc.outputEdges = outputEdges.data(); operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); - operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + operatorGraphDesc.nodes = opDescs.data(); SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext); } }; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorDynamicQuantizeMatMul.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorDynamicQuantizeMatMul.cpp new file mode 100644 index 0000000000000..c6a87da705a99 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorDynamicQuantizeMatMul.cpp @@ -0,0 +1,173 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "precomp.h" + +namespace Dml +{ +// DynamicQuantizeMatMul = MatrixMultiplyIntegerToFloat(DynamicQuantizeLinear(A), B) +class DmlOperatorDynamicQuantizeMatMul : public DmlOperator +{ + // This order matches the ONNX schema. 
+ enum OnnxInputIndex + { + A, // Input + B, + B_scale, + B_zero_point, + Bias, + Count, + }; + +public: + DmlOperatorDynamicQuantizeMatMul(const MLOperatorKernelCreationContext& kernelCreationContext) + : DmlOperator(kernelCreationContext) + { + DmlOperator::Initialize(kernelCreationContext); + + const bool hasBias = kernelCreationContext.IsInputValid(OnnxInputIndex::Bias); + const bool hasBZP = kernelCreationContext.IsInputValid(OnnxInputIndex::B_zero_point); + + // Broadcast Bias tensor to the shape of the output tensor. + if (hasBias) + { + m_inputTensorDescs[OnnxInputIndex::Bias] = CreateTensorDescFromInput( + kernelCreationContext, + OnnxInputIndex::Bias, + TensorAxis::DoNotCoerce, + TensorAxis::W, + TensorAxis::RightAligned, + kernelCreationContext.GetTensorShapeDescription().GetOutputTensorShape(0) + ); + } + MLOperatorTensorDataType BDatatype = kernelCreationContext.GetInputEdgeDescription(OnnxInputIndex::B).tensorDataType; + + std::vector ATensorShape = kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(OnnxInputIndex::A); + std::vector ExpectedAScaleTensorShape = {1, 1, 1, 1}; + std::vector ExpectedAZeroPointTensorShape = {1, 1, 1, 1}; + + // output edges between DynQL and MMItoFloat node + TensorDesc intermediateQuantizedATensorDesc = TensorDesc( + BDatatype, + gsl::make_span(ATensorShape), + gsl::make_span(ATensorShape), + TensorAxis::DoNotCoerce, + TensorAxis::W, + TensorAxis::RightAligned, + NchwDimensionCount, // minDimensionCount + 0 // guaranteedBaseOffsetAlignment + ); + + TensorDesc intermediateQuantizedAScaleTensorDesc = TensorDesc( + MLOperatorTensorDataType::Float, + gsl::make_span(ExpectedAScaleTensorShape), + gsl::make_span(ExpectedAScaleTensorShape), + TensorAxis::DoNotCoerce, + TensorAxis::W, + TensorAxis::RightAligned, + NchwDimensionCount, // minDimensionCount + 0 // guaranteedBaseOffsetAlignment + ); + + TensorDesc intermediateQuantizedAZeroPointTensorDesc = TensorDesc( + BDatatype, + 
gsl::make_span(ExpectedAZeroPointTensorShape), + gsl::make_span(ExpectedAZeroPointTensorShape), + TensorAxis::DoNotCoerce, + TensorAxis::W, + TensorAxis::RightAligned, + NchwDimensionCount, // minDimensionCount + 0 // guaranteedBaseOffsetAlignment + ); + + DML_TENSOR_DESC namedIntermediateQuantizedATensorDesc = intermediateQuantizedATensorDesc.GetDmlDesc(); + DML_TENSOR_DESC namedIntermediateQuantizedAScaleTensorDesc = intermediateQuantizedAScaleTensorDesc.GetDmlDesc(); + DML_TENSOR_DESC namedIntermediateQuantizedAZeroPointTensorDesc = intermediateQuantizedAZeroPointTensorDesc.GetDmlDesc(); + + std::vector inputDescs = GetDmlInputDescs(); + std::vector outputDescs = GetDmlOutputDescs(); + + DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_DESC dynamicQuantizeLinearOperatorDesc = {}; + dynamicQuantizeLinearOperatorDesc.InputTensor = &inputDescs[OnnxInputIndex::A]; + dynamicQuantizeLinearOperatorDesc.OutputTensor = &namedIntermediateQuantizedATensorDesc; + dynamicQuantizeLinearOperatorDesc.OutputScaleTensor = &namedIntermediateQuantizedAScaleTensorDesc; + dynamicQuantizeLinearOperatorDesc.OutputZeroPointTensor = &namedIntermediateQuantizedAZeroPointTensorDesc; + + const DML_OPERATOR_DESC opDesc1{DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR, &dynamicQuantizeLinearOperatorDesc}; + + DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC matrixMultiplyIntergerToFloatOperatorDesc = {}; + matrixMultiplyIntergerToFloatOperatorDesc.ATensor = dynamicQuantizeLinearOperatorDesc.OutputTensor; + matrixMultiplyIntergerToFloatOperatorDesc.AScaleTensor = dynamicQuantizeLinearOperatorDesc.OutputScaleTensor; + matrixMultiplyIntergerToFloatOperatorDesc.AZeroPointTensor = dynamicQuantizeLinearOperatorDesc.OutputZeroPointTensor; + matrixMultiplyIntergerToFloatOperatorDesc.BTensor = &inputDescs[OnnxInputIndex::B]; + matrixMultiplyIntergerToFloatOperatorDesc.BScaleTensor = &inputDescs[OnnxInputIndex::B_scale]; + matrixMultiplyIntergerToFloatOperatorDesc.BZeroPointTensor = hasBZP? 
&inputDescs[OnnxInputIndex::B_zero_point] : nullptr; + matrixMultiplyIntergerToFloatOperatorDesc.BiasTensor = hasBias? &inputDescs[OnnxInputIndex::Bias] : nullptr; + matrixMultiplyIntergerToFloatOperatorDesc.OutputTensor = &outputDescs[0]; + + const DML_OPERATOR_DESC opDesc2{ DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT, &matrixMultiplyIntergerToFloatOperatorDesc}; + + MLOperatorGraphDesc operatorGraphDesc = {}; + std::vector opDescs{&opDesc1, &opDesc2}; + operatorGraphDesc.nodeCount = static_cast(opDescs.size()); + operatorGraphDesc.nodes = opDescs.data(); + + // set input edges + std::pair nodeToNodeInputIndex[OnnxInputIndex::Count] {{0, 0}, {1, 3}, {1, 4}, {1, 5}, {1, 6}}; + std::vector inputEdges; + for (uint32_t inputIndex = 0; inputIndex < OnnxInputIndex::Count; inputIndex++) + { + if (inputIndex == OnnxInputIndex::B_zero_point && !hasBZP) continue; + if (inputIndex == OnnxInputIndex::Bias && !hasBias) continue; + DML_INPUT_GRAPH_EDGE_DESC inputEdge = {}; + inputEdge.GraphInputIndex = inputIndex; // OnnxInputIndex and DmlInputIndex are identity for DynamicQuantizeMatMul + inputEdge.ToNodeIndex = nodeToNodeInputIndex[inputIndex].first; + inputEdge.ToNodeInputIndex = nodeToNodeInputIndex[inputIndex].second; + inputEdges.push_back(inputEdge); + } + operatorGraphDesc.inputEdgeCount = gsl::narrow_cast(inputEdges.size()); + operatorGraphDesc.inputEdges = inputEdges.data(); + + // set intermediate edges + std::vector intermediateEdges; + + DML_INTERMEDIATE_GRAPH_EDGE_DESC dynQLToMMItofloatEdge1 = {}; + dynQLToMMItofloatEdge1.FromNodeIndex = 0; + dynQLToMMItofloatEdge1.FromNodeOutputIndex = 0; + dynQLToMMItofloatEdge1.ToNodeIndex = 1; + dynQLToMMItofloatEdge1.ToNodeInputIndex = 0; + intermediateEdges.push_back(dynQLToMMItofloatEdge1); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC dynQLToMMItofloatEdge2 = {}; + dynQLToMMItofloatEdge2.FromNodeIndex = 0; + dynQLToMMItofloatEdge2.FromNodeOutputIndex = 1; + dynQLToMMItofloatEdge2.ToNodeIndex = 1; + 
dynQLToMMItofloatEdge2.ToNodeInputIndex = 1; + intermediateEdges.push_back(dynQLToMMItofloatEdge2); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC dynQLToMMItofloatEdge3 = {}; + dynQLToMMItofloatEdge3.FromNodeIndex = 0; + dynQLToMMItofloatEdge3.FromNodeOutputIndex = 2; + dynQLToMMItofloatEdge3.ToNodeIndex = 1; + dynQLToMMItofloatEdge3.ToNodeInputIndex = 2; + intermediateEdges.push_back(dynQLToMMItofloatEdge3); + + operatorGraphDesc.intermediateEdgeCount = gsl::narrow_cast(intermediateEdges.size()); + operatorGraphDesc.intermediateEdges = intermediateEdges.data(); + + // set the output edges + std::vector outputEdges; + DML_OUTPUT_GRAPH_EDGE_DESC outputEdge = {}; + outputEdge.FromNodeIndex = 1; + outputEdge.FromNodeOutputIndex = 0; + outputEdge.GraphOutputIndex = 0; + outputEdges.push_back(outputEdge); + operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); + operatorGraphDesc.outputEdges = outputEdges.data(); + + SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext); + } +}; + +DML_OP_DEFINE_CREATION_FUNCTION(DynamicQuantizeMatMul, DmlOperatorDynamicQuantizeMatMul); +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorEmbedLayerNormalization.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorEmbedLayerNormalization.cpp index 6a8333cd72561..3c9458658c4d0 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorEmbedLayerNormalization.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorEmbedLayerNormalization.cpp @@ -484,7 +484,7 @@ class DmlOperatorEmbedLayerNormalization : public DmlOperator operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); operatorGraphDesc.outputEdges = outputEdges.data(); operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); - operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + operatorGraphDesc.nodes = opDescs.data(); 
SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGroupNorm.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGroupNorm.cpp index fed0e4645ffd8..8b275fc550f3e 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGroupNorm.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGroupNorm.cpp @@ -287,7 +287,7 @@ class DmlOperatorGroupNorm : public DmlOperator operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); operatorGraphDesc.outputEdges = outputEdges.data(); operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); - operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + operatorGraphDesc.nodes = opDescs.data(); SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext); } }; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorLayerNormalization.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorLayerNormalization.cpp index 5c64059f7caa9..80e6fefc2fb80 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorLayerNormalization.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorLayerNormalization.cpp @@ -247,7 +247,7 @@ class DmlOperatorLayerNormalization : public DmlOperator operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); operatorGraphDesc.outputEdges = outputEdges.data(); operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); - operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + operatorGraphDesc.nodes = opDescs.data(); SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMatMulIntegerToFloat.cpp 
b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMatMulIntegerToFloat.cpp new file mode 100644 index 0000000000000..b5a3dd0960b86 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMatMulIntegerToFloat.cpp @@ -0,0 +1,111 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "precomp.h" + +namespace Dml +{ + +class DmlOperatorMatMulIntegerToFloat : public DmlOperator +{ + enum OrtInputTensors : uint32_t + { + ortA, + ortB, + ortAScale, + ortBScale, + ortAZeroPoint, + ortBZeroPoint, + ortBias, + ortInputCount + }; + + enum DmlInputIndex : uint32_t + { + dmlA, + dmlAScale, + dmlAZeroPoint, + dmlB, + dmlBScale, + dmlBZeroPoint, + dmlBias, + dmlInputCount, + }; + +public: + DmlOperatorMatMulIntegerToFloat(const MLOperatorKernelCreationContext& kernelInfo) + : DmlOperator(kernelInfo) + { + std::vector> inputIndices = { OrtInputTensors::ortA, OrtInputTensors::ortAScale, OrtInputTensors::ortAZeroPoint, OrtInputTensors::ortB, OrtInputTensors::ortBScale, OrtInputTensors::ortBZeroPoint, OrtInputTensors::ortBias }; + DmlOperator::Initialize(kernelInfo, inputIndices); + + std::vector inputShape0 = kernelInfo.GetTensorShapeDescription().GetInputTensorShape(OrtInputTensors::ortA); + std::vector inputShape1 = kernelInfo.GetTensorShapeDescription().GetInputTensorShape(OrtInputTensors::ortB); + std::vector outputShape = kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0); + + OperatorHelper::MatMulShapeMapping(inputShape0, inputShape1, outputShape); + + // Initialize the input descriptions with broadcasting + m_inputTensorDescs[DmlInputIndex::dmlA] = CreateTensorDescFromInput(kernelInfo, OrtInputTensors::ortA, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, inputShape0); + m_inputTensorDescs[DmlInputIndex::dmlB] = CreateTensorDescFromInput(kernelInfo, OrtInputTensors::ortB, TensorAxis::DoNotCoerce, TensorAxis::W, 
TensorAxis::RightAligned, inputShape1); + + // Broadcast Bias tensor to the shape of the output tensor. + if(kernelInfo.IsInputValid(OrtInputTensors::ortBias)) { + m_inputTensorDescs[DmlInputIndex::dmlBias] = CreateTensorDescFromInput(kernelInfo, OrtInputTensors::ortBias, TensorAxis::DoNotCoerce, + TensorAxis::W, TensorAxis::RightAligned, outputShape); + } + + uint32_t dmlDimSize = m_inputTensorDescs[DmlInputIndex::dmlA].GetDimensionCount(); + // Resize the A Scale to be the same dimension as the input tensor. + // The 1D tensor needs to be moved to the H channel. + m_inputTensorDescs[DmlInputIndex::dmlAScale] = CreateTensorDescFromInput( + kernelInfo, + OrtInputTensors::ortAScale, + TensorAxis::DoNotCoerce, + TensorAxis::H, + TensorAxis::LeftAligned, + std::nullopt, + dmlDimSize + ); + + // Resize the A ZeroPoint to be the same dimension as the input tensor. + // The 1D tensor needs to be moved to the H channel. + if (kernelInfo.IsInputValid(OrtInputTensors::ortAZeroPoint)) + { + m_inputTensorDescs[DmlInputIndex::dmlAZeroPoint] = CreateTensorDescFromInput( + kernelInfo, + OrtInputTensors::ortAZeroPoint, + TensorAxis::DoNotCoerce, + TensorAxis::H, + TensorAxis::LeftAligned, + std::nullopt, + dmlDimSize + ); + } + + // B Zeropoint and BScale are already aligned in the W dimension so no need to align them + + // Initialize the output description while overriding the shape + m_outputTensorDescs[0] = CreateTensorDescFromOutput(kernelInfo, 0, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, outputShape); + + std::vector inputDescs = GetDmlInputDescs(); + std::vector outputDescs = GetDmlOutputDescs(); + + DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC matMulDesc = {}; + matMulDesc.ATensor = &inputDescs[DmlInputIndex::dmlA]; + matMulDesc.AScaleTensor = &inputDescs[DmlInputIndex::dmlAScale]; + matMulDesc.AZeroPointTensor = inputDescs[DmlInputIndex::dmlAZeroPoint].Desc != nullptr ? 
&inputDescs[DmlInputIndex::dmlAZeroPoint] : nullptr; + matMulDesc.BTensor = &inputDescs[DmlInputIndex::dmlB]; + matMulDesc.BScaleTensor = &inputDescs[DmlInputIndex::dmlBScale]; + matMulDesc.BZeroPointTensor = inputDescs[DmlInputIndex::dmlBZeroPoint].Desc != nullptr ? &inputDescs[DmlInputIndex::dmlBZeroPoint] : nullptr; + matMulDesc.BiasTensor = inputDescs[DmlInputIndex::dmlBias].Desc != nullptr ? &inputDescs[DmlInputIndex::dmlBias] : nullptr; + matMulDesc.OutputTensor = &outputDescs[0]; + + DML_OPERATOR_DESC opDesc = { (DML_OPERATOR_TYPE) DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT, &matMulDesc }; + SetDmlOperatorDesc(opDesc, kernelInfo); + } +}; + +DML_OP_DEFINE_CREATION_FUNCTION(MatMulIntegerToFloat, DmlOperatorMatMulIntegerToFloat); + +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp new file mode 100644 index 0000000000000..f9519b26bb4e3 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp @@ -0,0 +1,704 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "precomp.h" + +/* +Abbreviations: B is batch_size, S is sequence_length, W is hidden_size + N is number of attention heads, H is head size, and W=N*H + +Input, Weight, Bias, Mask Index and Past are Inputs + +Mask Index/Causal Input Weight Bias + | \ | / + | \ | / + | \ | / + | MatMulIntToFloat + | / | \ + | / | \ + | / | \ + | Slice Slice Slice + | | | | + | | | | + | Identity Identity Identity // The identities are used to transpose NCHW -> NHCW while + | | | | // keeping the GEMM strides as NCHW to better target metacommands + | | | | + | | | | Past + | | | | / \ + | | | | / \ + | | | | Slice Slice + | | | | | | + | | | | | | + | | | | | | + --------------------------MHA ----------- + / | \ + / | \ + / | \ + / | \ + / | \ + / | \ + / presentKey presentValue + / \ / + / \ / + / \ / + / Concat + / | + Output1 Output2 (present) + + This kernel creates a DML_GRAPH, as mentioned above. + For reference, refer to this Doc: + https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftqattention + */ + +namespace Dml +{ +class DmlOperatorQAttention : public DmlOperator +{ +public: + DmlOperatorQAttention(const MLOperatorKernelCreationContext& kernelCreationContext) + : DmlOperator(kernelCreationContext) + { + enum InputIndex : uint32_t + { + inputIndex, + weightsIndex, + biasIndex, + inputScaleIndex, + weightScaleIndex, + maskIndex, + inputZeroPointIndex, + weightZeroPointIndex, + pastIndex, + inputCount, + }; + + enum OutputIndex : uint32_t + { + outputIndex, + presentIndex, + outputCount, + }; + + enum MhaInputIndex : uint32_t + { + mhaQueryIndex, + mhaKeyIndex, + mhaValueIndex, + mhaStackedQueryKeyIndex, + mhaStackedKeyValueIndex, + mhaStackedQueryKeyValueIndex, + mhaBiasIndex, + mhaMaskIndex, + mhaRelativePositionBiasIndex, + mhaPastKeyIndex, + mhaPastValueIndex, + mhaInputCount, + }; + + enum MhaOutputIndex : uint32_t + { + mhaOutputIndex, + mhaPresentKeyIndex, + mhaPresentValueIndex, + mhaOutputCount, + }; + + 
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() >= 5); + ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() >= 1); + + const bool hasBias = kernelCreationContext.IsInputValid(biasIndex); + const bool hasMask = kernelCreationContext.IsInputValid(maskIndex); + const bool hasUnpaddedBounds = hasMask && kernelCreationContext.GetInputTensorDimensionCount(maskIndex) == 1; + const bool hasPast = kernelCreationContext.IsInputValid(pastIndex); + + DmlOperator::Initialize(kernelCreationContext, std::nullopt, std::nullopt, std::nullopt, std::nullopt, 1); + + const bool unidirectional = gsl::narrow_cast(kernelCreationContext.GetAttribute(AttrName::Unidirectional)); + const uint32_t numHeads = gsl::narrow_cast(kernelCreationContext.GetAttribute(AttrName::NumHeads)); + ML_CHECK_VALID_ARGUMENT(numHeads > 0); // to avoid process crash because of division by zero. + + auto inputTensorShape = m_inputTensorDescs[inputIndex].GetSizes(); + ML_CHECK_VALID_ARGUMENT(inputTensorShape.size() == 3); + + auto weightTensorShape = m_inputTensorDescs[weightsIndex].GetSizes(); + ML_CHECK_VALID_ARGUMENT(weightTensorShape.size() == 2); + ML_CHECK_VALID_ARGUMENT(weightTensorShape[0] == inputTensorShape[2]); + ML_CHECK_VALID_ARGUMENT(weightTensorShape[1] % 3 == 0); + + if (hasBias) + { + auto biasTensorShape = m_inputTensorDescs[biasIndex].GetSizes(); + ML_CHECK_VALID_ARGUMENT(biasTensorShape.size() == 1); + ML_CHECK_VALID_ARGUMENT(biasTensorShape[0] % 3 == 0); + ML_CHECK_VALID_ARGUMENT(weightTensorShape[1] == biasTensorShape[0]); + } + + if (hasPast) + { + ML_CHECK_VALID_ARGUMENT(kernelCreationContext.IsOutputValid(presentIndex)); + } + + const uint32_t hiddenSize = weightTensorShape[1] / 3; + const uint32_t headSize = hiddenSize / numHeads; + const uint32_t batchSize = inputTensorShape[0]; + const uint32_t sequenceLength = inputTensorShape[1]; + const uint32_t pastSequenceLength = hasPast ? 
m_inputTensorDescs[pastIndex].GetSizes()[3] : 0; + + uint32_t desiredWeightTensorShape[3] = {batchSize, weightTensorShape[0], 3 * hiddenSize}; + MLOperatorTensorDataType dataType = kernelCreationContext.GetOutputEdgeDescription(outputIndex).tensorDataType; + + m_inputTensorDescs[weightsIndex] = TensorDesc::ConstructBroadcastedTensorDesc( + kernelCreationContext.GetInputEdgeDescription(weightsIndex).tensorDataType, + desiredWeightTensorShape, + weightTensorShape); + + uint32_t desiredBiasTensorShape[3] = {batchSize, sequenceLength, 3 * hiddenSize}; + + if (hasBias) + { + auto biasTensorShape = m_inputTensorDescs[biasIndex].GetSizes(); + m_inputTensorDescs[biasIndex] = TensorDesc::ConstructBroadcastedTensorDesc(kernelCreationContext.GetInputEdgeDescription(biasIndex).tensorDataType, desiredBiasTensorShape, biasTensorShape); + } + + MLOperatorTensorDataType maskTensorDataType = MLOperatorTensorDataType::Undefined; + bool hasMaxSequenceMask = false; + DML_MULTIHEAD_ATTENTION_MASK_TYPE maskType = DML_MULTIHEAD_ATTENTION_MASK_TYPE_NONE; + if (hasMask) + { + if (hasUnpaddedBounds) + { + auto unpaddedKeyBoundsShape = m_inputTensorDescs[maskIndex].GetSizes(); + ML_CHECK_VALID_ARGUMENT(unpaddedKeyBoundsShape.size() == 1); + + const uint32_t batchGroupCount = unpaddedKeyBoundsShape[0] / batchSize; + ML_CHECK_VALID_ARGUMENT(batchGroupCount == 1 || batchGroupCount == 2); + + uint32_t desiredShape[2] = {batchGroupCount, batchSize}; + m_inputTensorDescs[maskIndex] = TensorDesc( + m_inputTensorDescs[maskIndex].GetDmlDataType(), + desiredShape); + + maskType = batchGroupCount == 1 + ? 
DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_LENGTH + : DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_END_START; + } + else + { + auto maskIndexTensorShape = m_inputTensorDescs[maskIndex].GetSizes(); + ML_CHECK_VALID_ARGUMENT(maskIndexTensorShape.size() > 1 && maskIndexTensorShape.size() <= 4); + + maskType = DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN; + std::vector reshapedMaskIndexTensorShape(maskIndexTensorShape.begin(), maskIndexTensorShape.end()); + if (maskIndexTensorShape.size() == 4 && maskIndexTensorShape[2] != sequenceLength) + { + hasMaxSequenceMask = true; + ML_CHECK_VALID_ARGUMENT(maskIndexTensorShape[2] == maskIndexTensorShape[3]); + const uint32_t maxSequenceLength = maskIndexTensorShape[2]; + uint32_t desiredMaskIndexShape[4] = {batchSize, numHeads, maxSequenceLength, maxSequenceLength}; + maskTensorDataType = kernelCreationContext.GetInputEdgeDescription(maskIndex).tensorDataType; + m_inputTensorDescs[maskIndex] = TensorDesc::ConstructBroadcastedTensorDesc(maskTensorDataType, desiredMaskIndexShape, reshapedMaskIndexTensorShape); + } + else + { + uint32_t maskIndexDimensionCount = gsl::narrow_cast(maskIndexTensorShape.size()); + reshapedMaskIndexTensorShape.insert(reshapedMaskIndexTensorShape.begin() + 1, 4 - maskIndexDimensionCount, 1); + uint32_t desiredMaskIndexShape[4] = {batchSize, numHeads, sequenceLength, sequenceLength}; + maskTensorDataType = kernelCreationContext.GetInputEdgeDescription(maskIndex).tensorDataType; + m_inputTensorDescs[maskIndex] = TensorDesc::ConstructBroadcastedTensorDesc(maskTensorDataType, desiredMaskIndexShape, reshapedMaskIndexTensorShape); + } + } + } + + MLOperatorTensorDataType pastTensorDataType = MLOperatorTensorDataType::Undefined; + MLOperatorTensorDataType presentTensorDataType = MLOperatorTensorDataType::Undefined; + if (hasPast) + { + pastTensorDataType = kernelCreationContext.GetInputEdgeDescription(pastIndex).tensorDataType; + presentTensorDataType = 
kernelCreationContext.GetOutputEdgeDescription(presentIndex).tensorDataType; + } + + TensorDesc matMulIntToFloatOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, desiredBiasTensorShape); + DML_TENSOR_DESC namedMatMulIntToFloatOutputTensorDesc = matMulIntToFloatOutputTensorDesc.GetDmlDesc(); + + std::vector inputDescs = GetDmlInputDescs(); + std::vector outputDescs = GetDmlOutputDescs(); + + DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC matMulIntToFloatOperatorDesc = {}; + matMulIntToFloatOperatorDesc.ATensor = &inputDescs[InputIndex::inputIndex]; + matMulIntToFloatOperatorDesc.AScaleTensor = &inputDescs[InputIndex::inputScaleIndex]; + matMulIntToFloatOperatorDesc.AZeroPointTensor = &inputDescs[InputIndex::inputZeroPointIndex]; + matMulIntToFloatOperatorDesc.BTensor = &inputDescs[InputIndex::weightsIndex]; + matMulIntToFloatOperatorDesc.BScaleTensor = &inputDescs[InputIndex::weightScaleIndex]; + matMulIntToFloatOperatorDesc.BZeroPointTensor = &inputDescs[InputIndex::weightZeroPointIndex]; + matMulIntToFloatOperatorDesc.BiasTensor = hasBias ? 
&inputDescs[InputIndex::biasIndex] : nullptr; + matMulIntToFloatOperatorDesc.OutputTensor = &namedMatMulIntToFloatOutputTensorDesc; + + const DML_OPERATOR_DESC matMulIntToFloatDesc = { DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT, &matMulIntToFloatOperatorDesc}; + + std::array queryKeySlicedTensorShape = {batchSize, sequenceLength, hiddenSize + hiddenSize}; + TensorDesc queryKeySlicedInputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, queryKeySlicedTensorShape); + DML_TENSOR_DESC namedQueryKeySlicedInputTensorDesc = queryKeySlicedInputTensorDesc.GetDmlDesc(); + + std::array valueSlicedTensorShape = {batchSize, sequenceLength, hiddenSize}; + TensorDesc valueSlicedInputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, valueSlicedTensorShape); + DML_TENSOR_DESC namedValueSlicedInputTensorDesc = valueSlicedInputTensorDesc.GetDmlDesc(); + + // Transpose slice QK from [batchSize, sequenceLength, 2, numHeads, headSize] to [batchSize, sequenceLength, numHeads, 2, headSize] + std::array queryKeyTransposedTensorShape = {batchSize, sequenceLength, numHeads, 2, headSize}; + std::array queryKeyTransposedStrides = { + sequenceLength * numHeads * 2 * headSize, + numHeads * 2 * headSize, + headSize, + numHeads * headSize, + 1, + }; + + TensorDesc queryKeyTransposedInputTensorDesc = TensorDesc( + GetDmlDataTypeFromMlDataType(dataType), + queryKeyTransposedTensorShape, + queryKeyTransposedStrides); + DML_TENSOR_DESC namedQueryKeyTransposedInputTensorDesc = queryKeyTransposedInputTensorDesc.GetDmlDesc(); + + TensorDesc queryKeyTransposedOutputTensorDesc = TensorDesc( + GetDmlDataTypeFromMlDataType(dataType), + queryKeyTransposedTensorShape); + DML_TENSOR_DESC namedQueryKeyTransposedOutputTensorDesc = queryKeyTransposedOutputTensorDesc.GetDmlDesc(); + + // Transpose QKV from [batchSize, sequenceLength, 3, numHeads, headSize] to [batchSize, sequenceLength, numHeads, 3, headSize] + std::array queryKeyValueTransposedTensorShape = {batchSize, sequenceLength, 
numHeads, 3, headSize}; + std::array queryKeyValueTransposedStrides = { + sequenceLength * numHeads * 3 * headSize, + numHeads * 3 * headSize, + headSize, + numHeads * headSize, + 1, + }; + + TensorDesc queryKeyValueTransposedInputTensorDesc = TensorDesc( + GetDmlDataTypeFromMlDataType(dataType), + queryKeyValueTransposedTensorShape, + queryKeyValueTransposedStrides); + DML_TENSOR_DESC namedQueryKeyValueTransposedInputTensorDesc = queryKeyValueTransposedInputTensorDesc.GetDmlDesc(); + + TensorDesc queryKeyValueTransposedOutputTensorDesc = TensorDesc( + GetDmlDataTypeFromMlDataType(dataType), + queryKeyValueTransposedTensorShape); + DML_TENSOR_DESC namedQueryKeyValueTransposedOutputTensorDesc = queryKeyValueTransposedOutputTensorDesc.GetDmlDesc(); + + std::array queryKeySliceOffset = {0, 0, 0}; + std::array queryKeySliceSize = {batchSize, sequenceLength, hiddenSize + hiddenSize}; + std::array queryKeySliceStrides = {1, 1, 1}; + + std::array valueSliceOffset = {0, 0, 2 * hiddenSize}; + std::array valueSliceSize = {batchSize, sequenceLength, hiddenSize}; + std::array valueSliceStrides = {1, 1, 1}; + + // When Q/K/V all have the same hidden size, we just have to transpose it before sending it to MHA + DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC transposeOperatorDesc = {}; + + transposeOperatorDesc.InputTensor = &namedQueryKeyValueTransposedInputTensorDesc; + transposeOperatorDesc.OutputTensor = &namedQueryKeyValueTransposedOutputTensorDesc; + + const DML_OPERATOR_DESC transposedDesc = { DML_OPERATOR_ELEMENT_WISE_IDENTITY, &transposeOperatorDesc}; + + std::array maskSliceOutputShape = {batchSize, numHeads, sequenceLength, sequenceLength}; + std::array maskSliceStrides = {1, 1, 1, 1}; + std::array maskSliceOffsets = {0, 0, 0, 0}; + TensorDesc maskSliceOutputTensorDesc; + DML_TENSOR_DESC namedMaskSliceOutputTensorDesc; + + DML_SLICE1_OPERATOR_DESC maskSlicedOperatorDesc = {}; + if (hasMaxSequenceMask) + { + maskSliceOutputTensorDesc = 
TensorDesc::ConstructDefaultTensorDesc(maskTensorDataType, maskSliceOutputShape); + namedMaskSliceOutputTensorDesc = maskSliceOutputTensorDesc.GetDmlDesc(); + maskSlicedOperatorDesc.InputTensor = &inputDescs[maskIndex]; + maskSlicedOperatorDesc.OutputTensor = &namedMaskSliceOutputTensorDesc; + maskSlicedOperatorDesc.DimensionCount = gsl::narrow_cast(maskSliceOutputShape.size()); + maskSlicedOperatorDesc.InputWindowOffsets = maskSliceOffsets.data(); + maskSlicedOperatorDesc.InputWindowSizes = maskSliceOutputShape.data(); + maskSlicedOperatorDesc.InputWindowStrides = maskSliceStrides.data(); + } + const DML_OPERATOR_DESC maskSlicedDesc = { DML_OPERATOR_SLICE1, &maskSlicedOperatorDesc}; + + // We need to slice Past to get PastValue and PastKey tensors for MHA + std::array pastKeyOutputShape = {1, batchSize, numHeads, pastSequenceLength, headSize}; + std::array pastKeyStrides = {1, 1, 1, 1, 1}; + std::array pastKeyOffsets = {0, 0, 0, 0, 0}; + TensorDesc pastKeyOutputTensorDesc; + DML_TENSOR_DESC namedPastKeyOutputTensorDesc; + + std::array pastValueOutputShape = {1, batchSize, numHeads, pastSequenceLength, headSize}; + std::array pastValueStrides = {1, 1, 1, 1, 1}; + std::array pastValueOffsets = {1, 0, 0, 0, 0}; + TensorDesc pastValueOutputTensorDesc; + DML_TENSOR_DESC namedPastValueOutputTensorDesc; + + DML_SLICE1_OPERATOR_DESC pastKeySlicedOperatorDesc = {}; + DML_SLICE1_OPERATOR_DESC pastValueSlicedOperatorDesc = {}; + + if (hasPast) + { + pastKeyOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(pastTensorDataType, pastKeyOutputShape); + namedPastKeyOutputTensorDesc = pastKeyOutputTensorDesc.GetDmlDesc(); + pastKeySlicedOperatorDesc.InputTensor = &inputDescs[pastIndex]; + pastKeySlicedOperatorDesc.OutputTensor = &namedPastKeyOutputTensorDesc; + pastKeySlicedOperatorDesc.DimensionCount = gsl::narrow_cast(pastKeyOutputShape.size()); + pastKeySlicedOperatorDesc.InputWindowOffsets = pastKeyOffsets.data(); + pastKeySlicedOperatorDesc.InputWindowSizes = 
pastKeyOutputShape.data(); + pastKeySlicedOperatorDesc.InputWindowStrides = pastKeyStrides.data(); + + pastValueOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(pastTensorDataType, pastValueOutputShape); + namedPastValueOutputTensorDesc = pastValueOutputTensorDesc.GetDmlDesc(); + pastValueSlicedOperatorDesc.InputTensor = &inputDescs[pastIndex]; + pastValueSlicedOperatorDesc.OutputTensor = &namedPastValueOutputTensorDesc; + pastValueSlicedOperatorDesc.DimensionCount = gsl::narrow_cast(pastValueOutputShape.size()); + pastValueSlicedOperatorDesc.InputWindowOffsets = pastValueOffsets.data(); + pastValueSlicedOperatorDesc.InputWindowSizes = pastValueOutputShape.data(); + pastValueSlicedOperatorDesc.InputWindowStrides = pastValueStrides.data(); + } + + const DML_OPERATOR_DESC pastKeySlicedDesc = { DML_OPERATOR_SLICE1, &pastKeySlicedOperatorDesc}; + const DML_OPERATOR_DESC pastValueSlicedDesc = { DML_OPERATOR_SLICE1, &pastValueSlicedOperatorDesc}; + + // Causal Mask: Upper Triangular Boolean Matrix + // Example: [[1, 0, 0, 0, 0], + // [1, 1, 0, 0, 0], + // [1, 1, 1, 0, 0], + // [1, 1, 1, 1, 0]] + // DML adds maskFilterValue to the "off" bits in the mask and sets the "on" bits to 0 + // passed to MHA as maskIndex Tensor when unidirectional == 1 + std::array causalMaskOutputShape = {1, 1, sequenceLength, pastSequenceLength + sequenceLength}; + TensorDesc causalMaskTensorDesc; + DML_DIAGONAL_MATRIX1_OPERATOR_DESC causalMaskOperatorDesc = {}; + DML_TENSOR_DESC namedcausalMaskTensorDesc; + + if (unidirectional && !hasMask) + { + causalMaskTensorDesc = TensorDesc::ConstructDefaultTensorDesc(MLOperatorTensorDataType::Int32, causalMaskOutputShape); + namedcausalMaskTensorDesc = causalMaskTensorDesc.GetDmlDesc(); + causalMaskOperatorDesc.ValueDataType = DML_TENSOR_DATA_TYPE_INT32; + causalMaskOperatorDesc.DiagonalFillBegin = INT32_MIN; + causalMaskOperatorDesc.DiagonalFillEnd = pastSequenceLength + 1; + causalMaskOperatorDesc.Value.Int32 = 1; + 
causalMaskOperatorDesc.OutputTensor = &namedcausalMaskTensorDesc; + maskType = DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN; + } + DML_OPERATOR_DESC causalMaskDesc = { DML_OPERATOR_DIAGONAL_MATRIX1, &causalMaskOperatorDesc }; + + DML_MULTIHEAD_ATTENTION_OPERATOR_DESC mhaOperatorDesc = {}; + std::array presentKeyOutputShape = {1, batchSize, numHeads, pastSequenceLength + sequenceLength, headSize}; + std::array presentValueOutputShape = {1, batchSize, numHeads, pastSequenceLength + sequenceLength, headSize}; + TensorDesc presentKeyTensorDesc; + TensorDesc presentValueTensorDesc; + DML_TENSOR_DESC namedPresentKeyOutputTensorDesc; + DML_TENSOR_DESC namedPresentValueOutputTensorDesc; + + mhaOperatorDesc.StackedQueryKeyValueTensor = &namedQueryKeyValueTransposedOutputTensorDesc; + + // Broadcast to MHA MaskTensor Shape + std::array mhaMaskTensorShape = {batchSize, numHeads, sequenceLength, pastSequenceLength + sequenceLength}; + TensorDesc broadcastedcausalMaskTensorDesc; + DML_TENSOR_DESC namedbroadcastedcausalMaskTensorDesc; + if (unidirectional && !hasMask) + { + broadcastedcausalMaskTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(MLOperatorTensorDataType::Int32, mhaMaskTensorShape, causalMaskOutputShape); + namedbroadcastedcausalMaskTensorDesc = broadcastedcausalMaskTensorDesc.GetDmlDesc(); + mhaOperatorDesc.MaskTensor = &namedbroadcastedcausalMaskTensorDesc; + } + else if (hasMaxSequenceMask) + { + mhaOperatorDesc.MaskTensor = &namedMaskSliceOutputTensorDesc; + } + else + { + mhaOperatorDesc.MaskTensor = hasMask ? &inputDescs[maskIndex] : nullptr; + } + + mhaOperatorDesc.RelativePositionBiasTensor = nullptr; + mhaOperatorDesc.OutputTensor = &outputDescs[outputIndex]; + mhaOperatorDesc.Scale = kernelCreationContext.GetOptionalAttribute(AttrName::Scale, gsl::narrow_cast(1.0f / std::sqrt(headSize))); + // Set MaskFilterValue to lowest float for Causal Mask + mhaOperatorDesc.MaskFilterValue = unidirectional ? 
std::numeric_limits::lowest() : + kernelCreationContext.GetOptionalAttribute(AttrName::MaskFilterValue, -10'000.0f); + mhaOperatorDesc.HeadCount = numHeads; + mhaOperatorDesc.MaskType = maskType; + if (hasPast) + { + presentKeyTensorDesc = TensorDesc::ConstructDefaultTensorDesc(presentTensorDataType, presentKeyOutputShape); + namedPresentKeyOutputTensorDesc = presentKeyTensorDesc.GetDmlDesc(); + presentValueTensorDesc = TensorDesc::ConstructDefaultTensorDesc(presentTensorDataType, presentValueOutputShape); + namedPresentValueOutputTensorDesc = presentValueTensorDesc.GetDmlDesc(); + mhaOperatorDesc.PastKeyTensor = &namedPastKeyOutputTensorDesc; + mhaOperatorDesc.PastValueTensor = &namedPastValueOutputTensorDesc; + mhaOperatorDesc.OutputPresentKeyTensor = &namedPresentKeyOutputTensorDesc; + mhaOperatorDesc.OutputPresentValueTensor = &namedPresentValueOutputTensorDesc; + } + + const DML_OPERATOR_DESC mhaDesc = { DML_OPERATOR_MULTIHEAD_ATTENTION, &mhaOperatorDesc }; + + DML_JOIN_OPERATOR_DESC presentKeyValueJoinOperatorDesc = {}; + std::vector joinInputDesc; + + if (hasPast) + { + joinInputDesc.push_back(namedPresentKeyOutputTensorDesc); + joinInputDesc.push_back(namedPresentValueOutputTensorDesc); + presentKeyValueJoinOperatorDesc.InputCount = gsl::narrow_cast(joinInputDesc.size()); + presentKeyValueJoinOperatorDesc.InputTensors = joinInputDesc.data(); + presentKeyValueJoinOperatorDesc.OutputTensor = &outputDescs[presentIndex]; + presentKeyValueJoinOperatorDesc.Axis = gsl::narrow_cast(0); + } + + DML_OPERATOR_DESC presentKeyValueJoinDesc = { DML_OPERATOR_JOIN, &presentKeyValueJoinOperatorDesc }; + + // Construct the graph + std::vector inputEdges; + std::vector intermediateEdges; + std::vector outputEdges; + + std::vector opDescs = { + &matMulIntToFloatDesc, + &mhaDesc, + }; + + uint32_t currentNodeIndex = 0; + const uint32_t matMulIntToFloatNodeIndex = currentNodeIndex++; + const uint32_t mhaNodeIndex = currentNodeIndex++; + + uint32_t 
queryKeyValueTransposedNodeIndex = 0; + + opDescs.push_back(&transposedDesc); + queryKeyValueTransposedNodeIndex = currentNodeIndex++; + + uint32_t maskSliceNodeIndex = 0; + if (hasMaxSequenceMask) + { + opDescs.push_back(&maskSlicedDesc); + maskSliceNodeIndex = currentNodeIndex++; + } + + uint32_t pastKeySliceNodeIndex = 0; + uint32_t pastValueSliceNodeIndex = 0; + uint32_t concatNodeIndex = 0; + if (hasPast) + { + opDescs.push_back(&pastKeySlicedDesc); + pastKeySliceNodeIndex = currentNodeIndex++; + opDescs.push_back(&pastValueSlicedDesc); + pastValueSliceNodeIndex = currentNodeIndex++; + opDescs.push_back(&presentKeyValueJoinDesc); + concatNodeIndex = currentNodeIndex++; + } + + uint32_t causalMaskNodeIndex = 0; + if (unidirectional && !hasMask) + { + opDescs.push_back(&causalMaskDesc); + causalMaskNodeIndex = currentNodeIndex++; + } + + DML_INPUT_GRAPH_EDGE_DESC inputToMatMulIntToFloatEdge = {}; + inputToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::inputIndex; + inputToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex; + inputToMatMulIntToFloatEdge.ToNodeInputIndex = 0; + inputEdges.push_back(inputToMatMulIntToFloatEdge); + + DML_INPUT_GRAPH_EDGE_DESC inputScaleToMatMulIntToFloatEdge = {}; + inputScaleToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::inputScaleIndex; + inputScaleToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex; + inputScaleToMatMulIntToFloatEdge.ToNodeInputIndex = 1; + inputEdges.push_back(inputScaleToMatMulIntToFloatEdge); + + DML_INPUT_GRAPH_EDGE_DESC inputZeroPointToMatMulIntToFloatEdge = {}; + inputZeroPointToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::inputZeroPointIndex; + inputZeroPointToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex; + inputZeroPointToMatMulIntToFloatEdge.ToNodeInputIndex = 2; + inputEdges.push_back(inputZeroPointToMatMulIntToFloatEdge); + + DML_INPUT_GRAPH_EDGE_DESC weightToMatMulIntToFloatEdge = {}; + weightToMatMulIntToFloatEdge.GraphInputIndex = 
InputIndex::weightsIndex; + weightToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex; + weightToMatMulIntToFloatEdge.ToNodeInputIndex = 3; + inputEdges.push_back(weightToMatMulIntToFloatEdge); + + DML_INPUT_GRAPH_EDGE_DESC weightScaleToMatMulIntToFloatEdge = {}; + weightScaleToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::weightScaleIndex; + weightScaleToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex; + weightScaleToMatMulIntToFloatEdge.ToNodeInputIndex = 4; + inputEdges.push_back(weightScaleToMatMulIntToFloatEdge); + + DML_INPUT_GRAPH_EDGE_DESC weightZeroPointToMatMulIntToFloatEdge = {}; + weightZeroPointToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::weightZeroPointIndex; + weightZeroPointToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex; + weightZeroPointToMatMulIntToFloatEdge.ToNodeInputIndex = 5; + inputEdges.push_back(weightZeroPointToMatMulIntToFloatEdge); + + if (hasBias) + { + DML_INPUT_GRAPH_EDGE_DESC biasToMatMulIntToFloatEdge = {}; + biasToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::biasIndex; + biasToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex; + biasToMatMulIntToFloatEdge.ToNodeInputIndex = 6; + inputEdges.push_back(biasToMatMulIntToFloatEdge); + } + + if (hasMask) + { + if (hasUnpaddedBounds) + { + DML_INPUT_GRAPH_EDGE_DESC maskToMhaEdge = {}; + maskToMhaEdge.GraphInputIndex = InputIndex::maskIndex; + maskToMhaEdge.ToNodeIndex = mhaNodeIndex; + maskToMhaEdge.ToNodeInputIndex = mhaMaskIndex; + inputEdges.push_back(maskToMhaEdge); + } + else if (hasMaxSequenceMask) + { + DML_INPUT_GRAPH_EDGE_DESC maskToMaskSliceEdge = {}; + maskToMaskSliceEdge.GraphInputIndex = InputIndex::maskIndex; + maskToMaskSliceEdge.ToNodeIndex = maskSliceNodeIndex; + maskToMaskSliceEdge.ToNodeInputIndex = 0; + inputEdges.push_back(maskToMaskSliceEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC maskSliceToMhaEdge = {}; + maskSliceToMhaEdge.FromNodeIndex = maskSliceNodeIndex; + 
maskSliceToMhaEdge.FromNodeOutputIndex = 0; + maskSliceToMhaEdge.ToNodeIndex = mhaNodeIndex; + maskSliceToMhaEdge.ToNodeInputIndex = mhaMaskIndex; + intermediateEdges.push_back(maskSliceToMhaEdge); + } + else + { + DML_INPUT_GRAPH_EDGE_DESC maskToMhaEdge = {}; + maskToMhaEdge.GraphInputIndex = InputIndex::maskIndex; + maskToMhaEdge.ToNodeIndex = mhaNodeIndex; + maskToMhaEdge.ToNodeInputIndex = mhaMaskIndex; + inputEdges.push_back(maskToMhaEdge); + } + } + else if (unidirectional) + { + DML_INTERMEDIATE_GRAPH_EDGE_DESC causalMaskToMhaEdge = {}; + causalMaskToMhaEdge.FromNodeIndex = causalMaskNodeIndex; + causalMaskToMhaEdge.FromNodeOutputIndex = 0; + causalMaskToMhaEdge.ToNodeIndex = mhaNodeIndex; + causalMaskToMhaEdge.ToNodeInputIndex = mhaMaskIndex; + intermediateEdges.push_back(causalMaskToMhaEdge); + } + + if (hasPast) + { + DML_INPUT_GRAPH_EDGE_DESC pastToPastKeySliceEdge = {}; + pastToPastKeySliceEdge.GraphInputIndex = InputIndex::pastIndex; + pastToPastKeySliceEdge.ToNodeIndex = pastKeySliceNodeIndex; + pastToPastKeySliceEdge.ToNodeInputIndex = 0; + inputEdges.push_back(pastToPastKeySliceEdge); + + DML_INPUT_GRAPH_EDGE_DESC pastToPastValueSliceEdge = {}; + pastToPastValueSliceEdge.GraphInputIndex = InputIndex::pastIndex; + pastToPastValueSliceEdge.ToNodeIndex = pastValueSliceNodeIndex; + pastToPastValueSliceEdge.ToNodeInputIndex = 0; + inputEdges.push_back(pastToPastValueSliceEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC pastKeyToMhaEdge = {}; + pastKeyToMhaEdge.FromNodeIndex = pastKeySliceNodeIndex; + pastKeyToMhaEdge.FromNodeOutputIndex = 0; + pastKeyToMhaEdge.ToNodeIndex = mhaNodeIndex; + pastKeyToMhaEdge.ToNodeInputIndex = mhaPastKeyIndex; + intermediateEdges.push_back(pastKeyToMhaEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC pastValueToMhaEdge = {}; + pastValueToMhaEdge.FromNodeIndex = pastValueSliceNodeIndex; + pastValueToMhaEdge.FromNodeOutputIndex = 0; + pastValueToMhaEdge.ToNodeIndex = mhaNodeIndex; + pastValueToMhaEdge.ToNodeInputIndex = 
mhaPastValueIndex; + intermediateEdges.push_back(pastValueToMhaEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC presentKeyToConcatEdge = {}; + presentKeyToConcatEdge.FromNodeIndex = mhaNodeIndex; + presentKeyToConcatEdge.FromNodeOutputIndex = mhaPresentKeyIndex; + presentKeyToConcatEdge.ToNodeIndex = concatNodeIndex; + presentKeyToConcatEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(presentKeyToConcatEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC presentValueToConcatEdge = {}; + presentValueToConcatEdge.FromNodeIndex = mhaNodeIndex; + presentValueToConcatEdge.FromNodeOutputIndex = mhaPresentValueIndex; + presentValueToConcatEdge.ToNodeIndex = concatNodeIndex; + presentValueToConcatEdge.ToNodeInputIndex = 1; + intermediateEdges.push_back(presentValueToConcatEdge); + } + + DML_INTERMEDIATE_GRAPH_EDGE_DESC matMulIntToFloatToQueryKeyValueTransposeEdge = {}; + matMulIntToFloatToQueryKeyValueTransposeEdge.FromNodeIndex = matMulIntToFloatNodeIndex; + matMulIntToFloatToQueryKeyValueTransposeEdge.FromNodeOutputIndex = 0; + matMulIntToFloatToQueryKeyValueTransposeEdge.ToNodeIndex = queryKeyValueTransposedNodeIndex; + matMulIntToFloatToQueryKeyValueTransposeEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(matMulIntToFloatToQueryKeyValueTransposeEdge); + + // All we need to do here is transpose the stacked QKV tensor into something DML supports + DML_INTERMEDIATE_GRAPH_EDGE_DESC queryKeyValueTransposedToMhaEdge = {}; + queryKeyValueTransposedToMhaEdge.FromNodeIndex = queryKeyValueTransposedNodeIndex; + queryKeyValueTransposedToMhaEdge.FromNodeOutputIndex = 0; + queryKeyValueTransposedToMhaEdge.ToNodeIndex = mhaNodeIndex; + queryKeyValueTransposedToMhaEdge.ToNodeInputIndex = mhaStackedQueryKeyValueIndex; + intermediateEdges.push_back(queryKeyValueTransposedToMhaEdge); + + DML_OUTPUT_GRAPH_EDGE_DESC mhaToOutputEdge = {}; + mhaToOutputEdge.FromNodeIndex = mhaNodeIndex; + mhaToOutputEdge.FromNodeOutputIndex = mhaOutputIndex; + mhaToOutputEdge.GraphOutputIndex = 
OutputIndex::outputIndex; + outputEdges.push_back(mhaToOutputEdge); + + if (hasPast) + { + DML_OUTPUT_GRAPH_EDGE_DESC concatToOutputEdge = {}; + concatToOutputEdge.FromNodeIndex = concatNodeIndex; + concatToOutputEdge.FromNodeOutputIndex = 0; + concatToOutputEdge.GraphOutputIndex = OutputIndex::presentIndex; + outputEdges.push_back(concatToOutputEdge); + } + + MLOperatorGraphDesc operatorGraphDesc = {}; + operatorGraphDesc.inputEdgeCount = gsl::narrow_cast(inputEdges.size()); + operatorGraphDesc.inputEdges = inputEdges.data(); + operatorGraphDesc.intermediateEdgeCount = gsl::narrow_cast(intermediateEdges.size()); + operatorGraphDesc.intermediateEdges = intermediateEdges.data(); + operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); + operatorGraphDesc.outputEdges = outputEdges.data(); + operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); + operatorGraphDesc.nodes = opDescs.data(); + + SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext); + } +}; + +void CALLBACK QueryQAttention(IMLOperatorSupportQueryContextPrivate* context, /*out*/ bool* isSupported) +{ + *isSupported = false; + + // isSupported was set to false above, so every early return below reports the node as unsupported. + // `unidirectional == 1` combined with a mask tensor (operator input 5) is not supported yet + MLOperatorAttributes attributes(context); + if (attributes.GetOptionalAttribute(AttrName::Unidirectional, 0) != 0 && context->IsInputValid(5)) + { + return; + } + + // `do_rotary == 1` is not supported yet + if (attributes.GetOptionalAttribute(AttrName::DoRotary, 0) != 0) + { + return; + } + + // `past_present_share_buffer == 1` is not supported yet + if (attributes.GetOptionalAttribute(AttrName::PastPresentShareBuffer, 0) != 0) + { + return; + } + + *isSupported = true; +} + +DML_OP_DEFINE_CREATION_FUNCTION(QAttention, DmlOperatorQAttention); +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp
index c97b03dc36b62..8727610ff3112 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp @@ -166,7 +166,7 @@ class DmlOperatorQLinearConcat : public DmlOperator, public QLinearConcatHelper MLOperatorGraphDesc operatorGraphDesc = {}; operatorGraphDesc.nodeCount = static_cast(opDescs.size()); - operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + operatorGraphDesc.nodes = opDescs.data(); uint32_t joinNodeIndex = operatorGraphDesc.nodeCount - 2; uint32_t quantizeNodeIndex = operatorGraphDesc.nodeCount - 1; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearSigmoid.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearSigmoid.cpp index 35f926d62c92a..bc0082fef3496 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearSigmoid.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearSigmoid.cpp @@ -113,7 +113,7 @@ class DmlOperatorQLinearSigmoid : public DmlOperator MLOperatorGraphDesc operatorGraphDesc = {}; operatorGraphDesc.nodeCount = 3; std::vector opDescs{&opDesc1, &opDesc2, &opDesc3}; - operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + operatorGraphDesc.nodes = opDescs.data(); // set input edges std::pair nodeToNodeInputIndex[5] {{0, 0}, {0, 1}, {0, 2}, {2, 1}, {2, 2}}; @@ -178,4 +178,4 @@ void CALLBACK QueryQLinearSigmoid(IMLOperatorSupportQueryContextPrivate* context } DML_OP_DEFINE_CREATION_FUNCTION(QLinearSigmoid, DmlOperatorQLinearSigmoid); -} // namespace Dml +} // namespace Dml \ No newline at end of file diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQuickGelu.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQuickGelu.cpp index 3683ab7b0b0b3..e62b7d707ba78 100644 --- 
a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQuickGelu.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQuickGelu.cpp @@ -123,7 +123,7 @@ class DmlOperatorQuickGelu : public DmlOperator operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); operatorGraphDesc.outputEdges = outputEdges.data(); operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); - operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + operatorGraphDesc.nodes = opDescs.data(); SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext); } }; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorResize.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorResize.cpp index f332fac9d3a09..b7cceb1d1d998 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorResize.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorResize.cpp @@ -9,11 +9,12 @@ namespace Dml constexpr NameAndIndex coordinateTransformationModes[] = { {"half_pixel", 0}, - {"pytorch_half_pixel", 1}, - {"align_corners", 2}, - {"asymmetric", 3}, - {"tf_half_pixel_for_nn", 4}, - {"tf_crop_and_resize", 5}, + {"half_pixel_symmetric", 1}, + {"pytorch_half_pixel", 2}, + {"align_corners", 3}, + {"asymmetric", 4}, + {"tf_half_pixel_for_nn", 5}, + {"tf_crop_and_resize", 6}, }; constexpr NameAndIndex nearestNeighborRoundingModes[] = @@ -50,7 +51,7 @@ void ComputePixelOffsetsAndScales( uint32_t coordinateTransformationModeValue = *optionalCoordinateTransformationModeValue; ML_CHECK_VALID_ARGUMENT( - !regionOfInterest.empty() || coordinateTransformationModeValue != 5 /*tf_crop_and_resize*/, + !regionOfInterest.empty() || coordinateTransformationModeValue != 6 /*tf_crop_and_resize*/, "Resize expects 'roi' tensor for 'tf_crop_and_resize' mode." 
); @@ -88,6 +89,18 @@ void ComputePixelOffsetsAndScales( break; case 1: + // coordinate_transformation_mode is "half_pixel_symmetric", + // adjustment = output_width_int / output_width + // center = input_width / 2 + // offset = center * (1 - adjustment) + // x_original = (x + 0.5) / scale - (0.5 - offset) + // x_original = (x + 0.5) / scale - (0.5 - [(input_width / 2) * (1 - (output_width_int / output_width))]) + // output_width can be fractional when calculated with scale factor + inputPixelOffset = 0.5f - float((inputDimensions[i] / 2.0f) * (1.0f - outputDimensions[i] / (scales[i] * inputDimensions[i]))); + outputPixelOffset = -0.5; + break; + + case 2: // if coordinate_transformation_mode is "pytorch_half_pixel", // x_original = length_resized > 1 ? (x_resized + 0.5) / scale - 0.5 : 0 if (inputDimensions[i] <= 1) @@ -104,7 +117,7 @@ void ComputePixelOffsetsAndScales( } break; - case 2: + case 3: // if coordinate_transformation_mode is "align_corners", // x_original = x_resized * (length_original - 1) / (length_resized - 1) inputPixelOffset = 0.0; @@ -121,7 +134,7 @@ void ComputePixelOffsetsAndScales( } break; - case 3: + case 4: // if coordinate_transformation_mode is "asymmetric", // x_original = x_resized / scale inputPixelOffset = 0.0; @@ -129,7 +142,7 @@ void ComputePixelOffsetsAndScales( // Keep existing scales. break; - case 4: + case 5: // if coordinate_transformation_mode is "tf_half_pixel_for_nn", // x_original = (x_resized + 0.5) / scale inputPixelOffset = 0.0; @@ -137,7 +150,7 @@ void ComputePixelOffsetsAndScales( // Keep existing scales. break; - case 5: + case 6: // if coordinate_transformation_mode is "tf_crop_and_resize", // x_original = length_resized > 1 ? 
start_x * (length_original - 1) + x_resized * (end_x - start_x) * (length_original - 1) / (length_resized - 1) // : 0.5 * (start_x + end_x) * (length_original - 1) @@ -177,7 +190,7 @@ class DmlOperatorResize : public DmlOperator, public ResizeHelper public: // Resample a multidimensional image to a new size. DmlOperatorResize(const MLOperatorKernelCreationContext& kernelCreationContext, uint32_t opsetVersion) - : DmlOperator(kernelCreationContext), + : DmlOperator(kernelCreationContext), ResizeHelper(kernelCreationContext, kernelCreationContext.GetTensorShapeDescription(), opsetVersion) { ML_CHECK_VALID_ARGUMENT(!m_scales.empty(), "Resize/Upsample expect scales, either a 2nd input tensors or 'scales' attribute."); @@ -250,6 +263,11 @@ class DmlOperatorResize : public DmlOperator, public ResizeHelper std::string mode = kernelCreationContext.GetOptionalAttribute(AttrName::Mode, "NEAREST"); DML_INTERPOLATION_MODE interpolationMode = Dml::MapStringToInteropolationMode(mode); + +#if DML_TARGET_VERSION >= 0x6300 + const int antialiased = kernelCreationContext.GetOptionalAttribute(AttrName::Antialiased, 0); +#endif + // Map ONNX to DML's mode using offsets and rounding direction. // These offsets are in addition to the coordinate transform offsets. 
DML_AXIS_DIRECTION roundingDirection = DML_AXIS_DIRECTION_DECREASING; @@ -289,7 +307,12 @@ class DmlOperatorResize : public DmlOperator, public ResizeHelper std::vector inputDescs = GetDmlInputDescs(); std::vector outputDescs = GetDmlOutputDescs(); +#if DML_TARGET_VERSION >= 0x6300 + DML_RESAMPLE3_OPERATOR_DESC operatorDesc = {}; + operatorDesc.Antialiased = static_cast(antialiased); +#else DML_RESAMPLE2_OPERATOR_DESC operatorDesc = {}; +#endif operatorDesc.InputTensor = inputDescs.data(); operatorDesc.OutputTensor = outputDescs.data(); operatorDesc.InterpolationMode = interpolationMode; @@ -298,8 +321,11 @@ class DmlOperatorResize : public DmlOperator, public ResizeHelper operatorDesc.DimensionCount = gsl::narrow_cast(paddedScales.size()); operatorDesc.InputPixelOffsets = inputPixelOffsets.data(); operatorDesc.OutputPixelOffsets = outputPixelOffsets.data(); - +#if DML_TARGET_VERSION >= 0x6300 + DML_OPERATOR_DESC opDesc = { DML_OPERATOR_RESAMPLE3, &operatorDesc }; +#else DML_OPERATOR_DESC opDesc = { DML_OPERATOR_RESAMPLE2, &operatorDesc }; +#endif SetDmlOperatorDesc(opDesc, kernelCreationContext); } }; @@ -342,6 +368,10 @@ void CALLBACK QueryResize(IMLOperatorSupportQueryContextPrivate* context, bool* DML_OP_DEFINE_CREATION_FUNCTION(Resize10, VersionedKernel); DML_OP_DEFINE_CREATION_FUNCTION(Resize11, VersionedKernel); DML_OP_DEFINE_CREATION_FUNCTION(Resize13, VersionedKernel); +#if DML_TARGET_VERSION >= 0x6300 +DML_OP_DEFINE_CREATION_FUNCTION(Resize18, VersionedKernel); +DML_OP_DEFINE_CREATION_FUNCTION(Resize19, VersionedKernel); +#endif DML_OP_DEFINE_CREATION_FUNCTION(Upsample7, VersionedKernel); DML_OP_DEFINE_CREATION_FUNCTION(Upsample9, VersionedKernel); DML_OP_DEFINE_CREATION_FUNCTION(Upsample10, VersionedKernel); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp index 30c339b845b36..0f15ebf342b3a 100644 
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp @@ -43,6 +43,10 @@ class DmlOperatorRotaryEmbedding : public DmlOperator ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 4); ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1); + // When the input is 4D, it has the shape [batchSize, numHeads, sequenceLength, headSize]. Otherwise, + // it has the shape [batchSize, sequenceLength, hiddenSize] + const bool inputIs4D = kernelInfo.GetInputTensorDimensionCount(inputDataIndex) == 4; + // When positionIds is a scalar, it represents the start offset for each sequence const bool positionIdsIsOffset = kernelInfo.GetInputTensorDimensionCount(positionIdsIndex) == 1; @@ -63,9 +67,9 @@ class DmlOperatorRotaryEmbedding : public DmlOperator // We resize the data to be of shape [batchSize, sequenceLength, numHeads, headSize] const auto inputDataSizes = m_inputTensorDescs[inputDataIndex].GetSizes(); - const uint32_t batchSize = inputDataSizes[1]; + const uint32_t batchSize = inputIs4D ? inputDataSizes[0] : inputDataSizes[1]; const uint32_t sequenceLength = inputDataSizes[2]; - const uint32_t numHeads = inputDataSizes[3] / headSize; + const uint32_t numHeads = inputIs4D ? 
inputDataSizes[1] : inputDataSizes[3] / headSize; const auto cosCacheSizes = m_inputTensorDescs[cosCacheIndex].GetSizes(); const uint32_t maxSequenceLength = cosCacheSizes[cosCacheSizes.size() - 2]; @@ -80,16 +84,24 @@ class DmlOperatorRotaryEmbedding : public DmlOperator std::vector inputDescs = GetDmlInputDescs(); const MLOperatorTensorDataType dataType = kernelInfo.GetInputEdgeDescription(inputDataIndex).tensorDataType; - // Splitting the hiddenSize into numHeads and headSize dimensions makes it easier for DML to handle const std::array inputOutputShape = {batchSize, sequenceLength, numHeads, headSize}; TensorDesc inputOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputOutputShape); + TensorDesc stridedInputOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputOutputShape); + + if (inputIs4D) + { + const std::array inputOutputStrides = {headSize * numHeads * sequenceLength, headSize, sequenceLength * headSize, 1}; + stridedInputOutputTensorDesc.SetStrides(inputOutputStrides); + } + const DML_TENSOR_DESC inputOutputDmlTensorDesc = inputOutputTensorDesc.GetDmlDesc(); + const DML_TENSOR_DESC stridedInputOutputDmlTensorDesc = stridedInputOutputTensorDesc.GetDmlDesc(); // Copy the input to preserve its real input shape in the graph without reshaping it. This will disappear during DML's graph compilation phase. 
DML_SCALE_BIAS scaleBias = {1.0f, 0.0f}; DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC copyInputDesc{}; - copyInputDesc.InputTensor = &inputOutputDmlTensorDesc; + copyInputDesc.InputTensor = &stridedInputOutputDmlTensorDesc; copyInputDesc.OutputTensor = &inputOutputDmlTensorDesc; copyInputDesc.ScaleBias = &scaleBias; const DML_OPERATOR_DESC copyInputDmlDesc = {DML_OPERATOR_ELEMENT_WISE_IDENTITY, ©InputDesc}; @@ -104,8 +116,12 @@ class DmlOperatorRotaryEmbedding : public DmlOperator : std::vector({batchSize, sequenceLength, numHeads, 1, headSize / 2}); TensorDesc inputDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputDataTensorShape); + const DML_TENSOR_DESC inputDataDmlTensorDesc = inputDataTensorDesc.GetDmlDesc(); + TensorDesc joinedDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputDataTensorShape); + const DML_TENSOR_DESC joinedDataDmlTensorDesc = joinedDataTensorDesc.GetDmlDesc(); + TensorDesc splitInputDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, splitInputDataTensorShape); const std::array splitInputDataDmlTensorDescs = {splitInputDataTensorDesc.GetDmlDesc(), splitInputDataTensorDesc.GetDmlDesc()}; @@ -122,7 +138,7 @@ class DmlOperatorRotaryEmbedding : public DmlOperator // Swap the 2 halves and join them together DML_JOIN_OPERATOR_DESC joinInputDesc{}; joinInputDesc.InputTensors = splitInputDataDmlTensorDescs.data(); - joinInputDesc.OutputTensor = &inputDataDmlTensorDesc; + joinInputDesc.OutputTensor = &joinedDataDmlTensorDesc; joinInputDesc.Axis = splitInputDesc.Axis; joinInputDesc.InputCount = gsl::narrow_cast(splitInputDataDmlTensorDescs.size()); const DML_OPERATOR_DESC joinInputDmlDesc = {DML_OPERATOR_JOIN, &joinInputDesc}; @@ -212,23 +228,23 @@ class DmlOperatorRotaryEmbedding : public DmlOperator const DML_TENSOR_DESC broadcastedSignDmlTensorDesc = broadcastedSignCosSinTensorDesc.GetDmlDesc(); DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC mulSignDesc{}; - mulSignDesc.ATensor = &inputDataDmlTensorDesc; + 
mulSignDesc.ATensor = &joinedDataDmlTensorDesc; mulSignDesc.BTensor = &broadcastedSignDmlTensorDesc; - mulSignDesc.OutputTensor = &inputDataDmlTensorDesc; + mulSignDesc.OutputTensor = &joinedDataDmlTensorDesc; const DML_OPERATOR_DESC mulSignDmlDesc = {DML_OPERATOR_ELEMENT_WISE_MULTIPLY, &mulSignDesc}; // Multiply the non-rotated data with the cos and the rotated data with the sin DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC mulCosSinDesc{}; - mulCosSinDesc.ATensor = &inputDataDmlTensorDesc; + mulCosSinDesc.ATensor = &joinedDataDmlTensorDesc; mulCosSinDesc.BTensor = &broadcastedCosSinDmlTensorDesc; - mulCosSinDesc.OutputTensor = &inputDataDmlTensorDesc; + mulCosSinDesc.OutputTensor = &joinedDataDmlTensorDesc; const DML_OPERATOR_DESC mulCosSinDmlDesc = {DML_OPERATOR_ELEMENT_WISE_MULTIPLY, &mulCosSinDesc}; // Add the multiplied cos and sin values together DML_ELEMENT_WISE_ADD_OPERATOR_DESC addDesc{}; addDesc.ATensor = &inputOutputDmlTensorDesc; addDesc.BTensor = &inputOutputDmlTensorDesc; - addDesc.OutputTensor = &inputOutputDmlTensorDesc; + addDesc.OutputTensor = &stridedInputOutputDmlTensorDesc; const DML_OPERATOR_DESC addDmlDesc = {DML_OPERATOR_ELEMENT_WISE_ADD, &addDesc}; // Construct the graph @@ -425,7 +441,7 @@ class DmlOperatorRotaryEmbedding : public DmlOperator operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); operatorGraphDesc.outputEdges = outputEdges.data(); operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); - operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + operatorGraphDesc.nodes = opDescs.data(); SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelInfo); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSkipLayerNormalization.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSkipLayerNormalization.cpp index 4dafd78f21ea8..094c45a0e38e5 100644 --- 
a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSkipLayerNormalization.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSkipLayerNormalization.cpp @@ -198,7 +198,7 @@ class DmlOperatorSkipLayerNormalization : public DmlOperator operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); operatorGraphDesc.outputEdges = outputEdges.data(); operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); - operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + operatorGraphDesc.nodes = opDescs.data(); SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/GeneratedShaders/stockham.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/GeneratedShaders/stockham.h index 9c03b7f6de639..1bfd6e6c6068d 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/GeneratedShaders/stockham.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/GeneratedShaders/stockham.h @@ -21,7 +21,7 @@ dcl_uav_structured u0, 4 dcl_uav_structured u1, 4 dcl_uav_structured u2, 4 dcl_input vThreadID.x -dcl_temps 6 +dcl_temps 5 dcl_thread_group 64, 1, 1 iadd r0.x, vThreadID.x, cb0[0].x ult r0.y, r0.x, cb0[0].y @@ -40,66 +40,57 @@ if_nz r0.y ieq r1.y, cb0[7].x, l(1) ult r1.z, r0.w, cb0[5].z and r1.z, r1.z, r1.y - if_nz r1.z - imul null, r1.z, r0.w, cb0[6].z - ld_structured_indexable(structured_buffer, stride=4)(mixed,mixed,mixed,mixed) r4.x, r1.z, l(0), u2.xxxx - imad r1.z, r0.w, cb0[6].z, cb0[6].w - ieq r1.w, cb0[5].w, l(2) - ld_structured_indexable(structured_buffer, stride=4)(mixed,mixed,mixed,mixed) r1.z, r1.z, l(0), u2.xxxx - and r4.y, r1.z, r1.w + imul null, r1.w, r0.w, cb0[6].z + ld_structured_indexable(structured_buffer, stride=4)(mixed,mixed,mixed,mixed) r4.x, r1.w, l(0), u2.xxxx + ieq r1.w, cb0[5].w, l(2) + if_nz r1.w + imad r2.y, r0.w, cb0[6].z, cb0[6].w 
+ ld_structured_indexable(structured_buffer, stride=4)(mixed,mixed,mixed,mixed) r4.y, r2.y, l(0), u2.xxxx else - mov r4.xy, l(1.000000,0,0,0) + mov r4.y, l(0) endif + movc r2.yz, r1.zzzz, r4.yyxy, l(0,0,1.000000,0) ult r1.z, r0.w, cb0[1].y - if_nz r1.z - imul null, r0.w, r0.w, cb0[2].y - imad r0.w, r1.x, cb0[2].x, r0.w - imad r0.w, r3.x, cb0[2].z, r0.w - ld_structured_indexable(structured_buffer, stride=4)(mixed,mixed,mixed,mixed) r5.x, r0.w, l(0), u0.xxxx - ieq r1.z, cb0[1].w, l(2) - if_nz r1.z - iadd r0.w, r0.w, cb0[2].w - ld_structured_indexable(structured_buffer, stride=4)(mixed,mixed,mixed,mixed) r5.y, r0.w, l(0), u0.xxxx - else - mov r5.y, l(0) - endif + imul null, r1.x, r1.x, cb0[2].x + imad r0.w, r0.w, cb0[2].y, r1.x + imad r0.w, r3.x, cb0[2].z, r0.w + ld_structured_indexable(structured_buffer, stride=4)(mixed,mixed,mixed,mixed) r4.x, r0.w, l(0), u0.xxxx + ieq r2.w, cb0[1].w, l(2) + if_nz r2.w + iadd r0.w, r0.w, cb0[2].w + ld_structured_indexable(structured_buffer, stride=4)(mixed,mixed,mixed,mixed) r4.y, r0.w, l(0), u0.xxxx else - mov r5.xy, l(0,0,0,0) + mov r4.y, l(0) endif - mul r0.w, r4.y, r5.y - mad r0.w, r5.x, r4.x, -r0.w - dp2 r1.z, r5.yxyy, r4.xyxx - ult r1.w, r0.y, cb0[5].z - and r1.y, r1.w, r1.y - if_nz r1.y - imul null, r1.y, r0.y, cb0[6].z - ld_structured_indexable(structured_buffer, stride=4)(mixed,mixed,mixed,mixed) r4.x, r1.y, l(0), u2.xxxx - imad r1.y, r0.y, cb0[6].z, cb0[6].w - ieq r1.w, cb0[5].w, l(2) - ld_structured_indexable(structured_buffer, stride=4)(mixed,mixed,mixed,mixed) r1.y, r1.y, l(0), u2.xxxx - and r4.y, r1.y, r1.w + and r3.yz, r1.zzzz, r4.xxyx + mul r0.w, r2.y, r3.z + mad r0.w, r3.y, r2.z, -r0.w + dp2 r1.z, r3.yzyy, r2.yzyy + ult r2.y, r0.y, cb0[5].z + and r1.y, r1.y, r2.y + imul null, r2.y, r0.y, cb0[6].z + ld_structured_indexable(structured_buffer, stride=4)(mixed,mixed,mixed,mixed) r4.x, r2.y, l(0), u2.xxxx + if_nz r1.w + imad r1.w, r0.y, cb0[6].z, cb0[6].w + ld_structured_indexable(structured_buffer, 
stride=4)(mixed,mixed,mixed,mixed) r4.y, r1.w, l(0), u2.xxxx else - mov r4.xy, l(1.000000,0,0,0) + mov r4.y, l(0) endif - ult r1.y, r0.y, cb0[1].y - if_nz r1.y - imul null, r0.y, r0.y, cb0[2].y - imad r0.y, r1.x, cb0[2].x, r0.y - imad r0.y, r3.x, cb0[2].z, r0.y - ld_structured_indexable(structured_buffer, stride=4)(mixed,mixed,mixed,mixed) r1.x, r0.y, l(0), u0.xxxx - ieq r1.w, cb0[1].w, l(2) - if_nz r1.w - iadd r0.y, r0.y, cb0[2].w - ld_structured_indexable(structured_buffer, stride=4)(mixed,mixed,mixed,mixed) r1.y, r0.y, l(0), u0.xxxx - else - mov r1.y, l(0) - endif + movc r1.yw, r1.yyyy, r4.yyyx, l(0,0,0,1.000000) + ult r2.y, r0.y, cb0[1].y + imad r0.y, r0.y, cb0[2].y, r1.x + imad r0.y, r3.x, cb0[2].z, r0.y + ld_structured_indexable(structured_buffer, stride=4)(mixed,mixed,mixed,mixed) r3.x, r0.y, l(0), u0.xxxx + if_nz r2.w + iadd r0.y, r0.y, cb0[2].w + ld_structured_indexable(structured_buffer, stride=4)(mixed,mixed,mixed,mixed) r3.y, r0.y, l(0), u0.xxxx else - mov r1.xy, l(0,0,0,0) + mov r3.y, l(0) endif - mul r0.y, r4.y, r1.y - mad r0.y, r1.x, r4.x, -r0.y - dp2 r1.x, r1.yxyy, r4.xyxx + and r2.yz, r2.yyyy, r3.xxyx + mul r0.y, r1.y, r2.z + mad r0.y, r2.y, r1.w, -r0.y + dp2 r1.x, r2.yzyy, r1.ywyy udiv null, r1.y, r2.x, r0.z ieq r1.w, cb0[0].w, l(1) movc r1.w, r1.w, l(6.283185), l(-6.283185) @@ -117,17 +108,22 @@ if_nz r0.y mad r0.y, r3.x, r1.x, r0.y add r0.y, r0.y, r1.z mul r0.yw, r0.yyyw, cb0[7].zzzz - ne r1.x, cb0[7].y, l(0.000000) - mul r1.y, r1.y, r1.y - mul r1.y, r1.y, l(3.141593) - div r1.y, r1.y, cb0[7].y - sincos r2.x, r3.x, r1.y - mov r2.y, r3.x - movc r1.xy, r1.xxxx, r2.xyxx, l(0,1.000000,0,0) - mul r1.zw, r0.yyyy, r1.xxxy - mad r0.y, r0.w, r1.y, -r1.z - store_structured u1.x, r0.z, l(0), r0.y - mad r0.y, r0.w, r1.x, r1.w + eq r1.x, cb0[7].y, l(0.000000) + if_nz r1.x + mov r1.x, r0.w + else + ne r1.z, cb0[7].y, l(0.000000) + mul r1.y, r1.y, r1.y + mul r1.y, r1.y, l(3.141593) + div r1.y, r1.y, cb0[7].y + sincos r2.x, r3.x, r1.y + mov r2.y, r3.x + movc 
r1.yz, r1.zzzz, r2.xxyx, l(0,0,1.000000,0) + mul r2.xy, r0.yyyy, r1.yzyy + mad r1.x, r0.w, r1.z, -r2.x + mad r0.y, r0.w, r1.y, r2.y + endif + store_structured u1.x, r0.z, l(0), r1.x store_structured u1.x, r0.x, l(0), r0.y endif ret @@ -136,11 +132,11 @@ ret const BYTE g_DFT[] = { - 68, 88, 66, 67, 222, 156, - 188, 133, 179, 57, 118, 25, - 122, 216, 102, 13, 91, 242, - 99, 27, 1, 0, 0, 0, - 172, 12, 0, 0, 3, 0, + 68, 88, 66, 67, 63, 188, + 200, 227, 206, 73, 64, 21, + 140, 126, 47, 226, 169, 81, + 175, 134, 1, 0, 0, 0, + 112, 12, 0, 0, 3, 0, 0, 0, 44, 0, 0, 0, 60, 0, 0, 0, 76, 0, 0, 0, 73, 83, 71, 78, @@ -149,8 +145,8 @@ const BYTE g_DFT[] = 79, 83, 71, 78, 8, 0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 83, 72, - 69, 88, 88, 12, 0, 0, - 80, 0, 5, 0, 22, 3, + 69, 88, 28, 12, 0, 0, + 80, 0, 5, 0, 7, 3, 0, 0, 106, 8, 0, 1, 89, 0, 0, 4, 70, 142, 32, 0, 0, 0, 0, 0, @@ -164,7 +160,7 @@ const BYTE g_DFT[] = 17, 0, 2, 0, 0, 0, 4, 0, 0, 0, 95, 0, 0, 2, 18, 0, 2, 0, - 104, 0, 0, 2, 6, 0, + 104, 0, 0, 2, 5, 0, 0, 0, 155, 0, 0, 4, 64, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, @@ -256,11 +252,9 @@ const BYTE g_DFT[] = 16, 0, 1, 0, 0, 0, 42, 0, 16, 0, 1, 0, 0, 0, 26, 0, 16, 0, - 1, 0, 0, 0, 31, 0, - 4, 3, 42, 0, 16, 0, 1, 0, 0, 0, 38, 0, 0, 9, 0, 208, 0, 0, - 66, 0, 16, 0, 1, 0, + 130, 0, 16, 0, 1, 0, 0, 0, 58, 0, 16, 0, 0, 0, 0, 0, 42, 128, 32, 0, 0, 0, 0, 0, @@ -268,221 +262,203 @@ const BYTE g_DFT[] = 0, 139, 2, 35, 0, 128, 131, 153, 25, 0, 18, 0, 16, 0, 4, 0, 0, 0, - 42, 0, 16, 0, 1, 0, + 58, 0, 16, 0, 1, 0, 0, 0, 1, 64, 0, 0, 0, 0, 0, 0, 6, 224, 17, 0, 2, 0, 0, 0, - 35, 0, 0, 11, 66, 0, + 32, 0, 0, 8, 130, 0, 16, 0, 1, 0, 0, 0, - 58, 0, 16, 0, 0, 0, - 0, 0, 42, 128, 32, 0, - 0, 0, 0, 0, 6, 0, - 0, 0, 58, 128, 32, 0, - 0, 0, 0, 0, 6, 0, - 0, 0, 32, 0, 0, 8, - 130, 0, 16, 0, 1, 0, - 0, 0, 58, 128, 32, 0, - 0, 0, 0, 0, 5, 0, - 0, 0, 1, 64, 0, 0, - 2, 0, 0, 0, 167, 0, + 58, 128, 32, 0, 0, 0, + 0, 0, 5, 0, 0, 0, + 1, 64, 0, 0, 2, 0, + 0, 0, 31, 0, 4, 3, + 58, 0, 16, 0, 1, 0, + 0, 0, 35, 0, 0, 11, + 
34, 0, 16, 0, 2, 0, + 0, 0, 58, 0, 16, 0, + 0, 0, 0, 0, 42, 128, + 32, 0, 0, 0, 0, 0, + 6, 0, 0, 0, 58, 128, + 32, 0, 0, 0, 0, 0, + 6, 0, 0, 0, 167, 0, 0, 139, 2, 35, 0, 128, - 131, 153, 25, 0, 66, 0, - 16, 0, 1, 0, 0, 0, - 42, 0, 16, 0, 1, 0, + 131, 153, 25, 0, 34, 0, + 16, 0, 4, 0, 0, 0, + 26, 0, 16, 0, 2, 0, 0, 0, 1, 64, 0, 0, 0, 0, 0, 0, 6, 224, 17, 0, 2, 0, 0, 0, - 1, 0, 0, 7, 34, 0, - 16, 0, 4, 0, 0, 0, - 42, 0, 16, 0, 1, 0, - 0, 0, 58, 0, 16, 0, - 1, 0, 0, 0, 18, 0, - 0, 1, 54, 0, 0, 8, - 50, 0, 16, 0, 4, 0, + 18, 0, 0, 1, 54, 0, + 0, 5, 34, 0, 16, 0, + 4, 0, 0, 0, 1, 64, + 0, 0, 0, 0, 0, 0, + 21, 0, 0, 1, 55, 0, + 0, 12, 98, 0, 16, 0, + 2, 0, 0, 0, 166, 10, + 16, 0, 1, 0, 0, 0, + 86, 4, 16, 0, 4, 0, 0, 0, 2, 64, 0, 0, - 0, 0, 128, 63, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 21, 0, - 0, 1, 79, 0, 0, 8, - 66, 0, 16, 0, 1, 0, - 0, 0, 58, 0, 16, 0, - 0, 0, 0, 0, 26, 128, - 32, 0, 0, 0, 0, 0, - 1, 0, 0, 0, 31, 0, - 4, 3, 42, 0, 16, 0, - 1, 0, 0, 0, 38, 0, - 0, 9, 0, 208, 0, 0, - 130, 0, 16, 0, 0, 0, - 0, 0, 58, 0, 16, 0, - 0, 0, 0, 0, 26, 128, - 32, 0, 0, 0, 0, 0, - 2, 0, 0, 0, 35, 0, - 0, 10, 130, 0, 16, 0, - 0, 0, 0, 0, 10, 0, + 0, 0, 0, 0, 128, 63, + 0, 0, 0, 0, 79, 0, + 0, 8, 66, 0, 16, 0, + 1, 0, 0, 0, 58, 0, + 16, 0, 0, 0, 0, 0, + 26, 128, 32, 0, 0, 0, + 0, 0, 1, 0, 0, 0, + 38, 0, 0, 9, 0, 208, + 0, 0, 18, 0, 16, 0, + 1, 0, 0, 0, 10, 0, 16, 0, 1, 0, 0, 0, 10, 128, 32, 0, 0, 0, 0, 0, 2, 0, 0, 0, + 35, 0, 0, 10, 130, 0, + 16, 0, 0, 0, 0, 0, 58, 0, 16, 0, 0, 0, - 0, 0, 35, 0, 0, 10, - 130, 0, 16, 0, 0, 0, + 0, 0, 26, 128, 32, 0, + 0, 0, 0, 0, 2, 0, 0, 0, 10, 0, 16, 0, - 3, 0, 0, 0, 42, 128, + 1, 0, 0, 0, 35, 0, + 0, 10, 130, 0, 16, 0, + 0, 0, 0, 0, 10, 0, + 16, 0, 3, 0, 0, 0, + 42, 128, 32, 0, 0, 0, + 0, 0, 2, 0, 0, 0, + 58, 0, 16, 0, 0, 0, + 0, 0, 167, 0, 0, 139, + 2, 35, 0, 128, 131, 153, + 25, 0, 18, 0, 16, 0, + 4, 0, 0, 0, 58, 0, + 16, 0, 0, 0, 0, 0, + 1, 64, 0, 0, 0, 0, + 0, 0, 6, 224, 17, 0, + 0, 0, 0, 0, 32, 0, + 0, 8, 130, 0, 16, 0, + 2, 0, 0, 0, 58, 
128, 32, 0, 0, 0, 0, 0, - 2, 0, 0, 0, 58, 0, + 1, 0, 0, 0, 1, 64, + 0, 0, 2, 0, 0, 0, + 31, 0, 4, 3, 58, 0, + 16, 0, 2, 0, 0, 0, + 30, 0, 0, 8, 130, 0, 16, 0, 0, 0, 0, 0, - 167, 0, 0, 139, 2, 35, - 0, 128, 131, 153, 25, 0, - 18, 0, 16, 0, 5, 0, - 0, 0, 58, 0, 16, 0, - 0, 0, 0, 0, 1, 64, - 0, 0, 0, 0, 0, 0, - 6, 224, 17, 0, 0, 0, - 0, 0, 32, 0, 0, 8, - 66, 0, 16, 0, 1, 0, + 58, 0, 16, 0, 0, 0, 0, 0, 58, 128, 32, 0, - 0, 0, 0, 0, 1, 0, - 0, 0, 1, 64, 0, 0, - 2, 0, 0, 0, 31, 0, - 4, 3, 42, 0, 16, 0, - 1, 0, 0, 0, 30, 0, - 0, 8, 130, 0, 16, 0, - 0, 0, 0, 0, 58, 0, + 0, 0, 0, 0, 2, 0, + 0, 0, 167, 0, 0, 139, + 2, 35, 0, 128, 131, 153, + 25, 0, 34, 0, 16, 0, + 4, 0, 0, 0, 58, 0, 16, 0, 0, 0, 0, 0, - 58, 128, 32, 0, 0, 0, - 0, 0, 2, 0, 0, 0, - 167, 0, 0, 139, 2, 35, - 0, 128, 131, 153, 25, 0, - 34, 0, 16, 0, 5, 0, - 0, 0, 58, 0, 16, 0, - 0, 0, 0, 0, 1, 64, - 0, 0, 0, 0, 0, 0, - 6, 224, 17, 0, 0, 0, - 0, 0, 18, 0, 0, 1, - 54, 0, 0, 5, 34, 0, - 16, 0, 5, 0, 0, 0, 1, 64, 0, 0, 0, 0, - 0, 0, 21, 0, 0, 1, - 18, 0, 0, 1, 54, 0, - 0, 8, 50, 0, 16, 0, - 5, 0, 0, 0, 2, 64, - 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, - 21, 0, 0, 1, 56, 0, - 0, 7, 130, 0, 16, 0, - 0, 0, 0, 0, 26, 0, - 16, 0, 4, 0, 0, 0, - 26, 0, 16, 0, 5, 0, - 0, 0, 50, 0, 0, 10, - 130, 0, 16, 0, 0, 0, - 0, 0, 10, 0, 16, 0, - 5, 0, 0, 0, 10, 0, + 0, 0, 6, 224, 17, 0, + 0, 0, 0, 0, 18, 0, + 0, 1, 54, 0, 0, 5, + 34, 0, 16, 0, 4, 0, + 0, 0, 1, 64, 0, 0, + 0, 0, 0, 0, 21, 0, + 0, 1, 1, 0, 0, 7, + 98, 0, 16, 0, 3, 0, + 0, 0, 166, 10, 16, 0, + 1, 0, 0, 0, 6, 1, 16, 0, 4, 0, 0, 0, - 58, 0, 16, 128, 65, 0, - 0, 0, 0, 0, 0, 0, - 15, 0, 0, 7, 66, 0, - 16, 0, 1, 0, 0, 0, - 22, 5, 16, 0, 5, 0, - 0, 0, 70, 0, 16, 0, - 4, 0, 0, 0, 79, 0, - 0, 8, 130, 0, 16, 0, + 56, 0, 0, 7, 130, 0, + 16, 0, 0, 0, 0, 0, + 26, 0, 16, 0, 2, 0, + 0, 0, 42, 0, 16, 0, + 3, 0, 0, 0, 50, 0, + 0, 10, 130, 0, 16, 0, + 0, 0, 0, 0, 26, 0, + 16, 0, 3, 0, 0, 0, + 42, 0, 16, 0, 2, 0, + 0, 0, 58, 0, 16, 128, + 65, 0, 0, 0, 0, 0, + 0, 0, 15, 0, 
0, 7, + 66, 0, 16, 0, 1, 0, + 0, 0, 150, 5, 16, 0, + 3, 0, 0, 0, 150, 5, + 16, 0, 2, 0, 0, 0, + 79, 0, 0, 8, 34, 0, + 16, 0, 2, 0, 0, 0, + 26, 0, 16, 0, 0, 0, + 0, 0, 42, 128, 32, 0, + 0, 0, 0, 0, 5, 0, + 0, 0, 1, 0, 0, 7, + 34, 0, 16, 0, 1, 0, + 0, 0, 26, 0, 16, 0, 1, 0, 0, 0, 26, 0, + 16, 0, 2, 0, 0, 0, + 38, 0, 0, 9, 0, 208, + 0, 0, 34, 0, 16, 0, + 2, 0, 0, 0, 26, 0, 16, 0, 0, 0, 0, 0, 42, 128, 32, 0, 0, 0, - 0, 0, 5, 0, 0, 0, - 1, 0, 0, 7, 34, 0, - 16, 0, 1, 0, 0, 0, - 58, 0, 16, 0, 1, 0, + 0, 0, 6, 0, 0, 0, + 167, 0, 0, 139, 2, 35, + 0, 128, 131, 153, 25, 0, + 18, 0, 16, 0, 4, 0, 0, 0, 26, 0, 16, 0, - 1, 0, 0, 0, 31, 0, - 4, 3, 26, 0, 16, 0, - 1, 0, 0, 0, 38, 0, - 0, 9, 0, 208, 0, 0, - 34, 0, 16, 0, 1, 0, + 2, 0, 0, 0, 1, 64, + 0, 0, 0, 0, 0, 0, + 6, 224, 17, 0, 2, 0, + 0, 0, 31, 0, 4, 3, + 58, 0, 16, 0, 1, 0, + 0, 0, 35, 0, 0, 11, + 130, 0, 16, 0, 1, 0, 0, 0, 26, 0, 16, 0, 0, 0, 0, 0, 42, 128, + 32, 0, 0, 0, 0, 0, + 6, 0, 0, 0, 58, 128, 32, 0, 0, 0, 0, 0, 6, 0, 0, 0, 167, 0, 0, 139, 2, 35, 0, 128, - 131, 153, 25, 0, 18, 0, + 131, 153, 25, 0, 34, 0, 16, 0, 4, 0, 0, 0, - 26, 0, 16, 0, 1, 0, + 58, 0, 16, 0, 1, 0, 0, 0, 1, 64, 0, 0, 0, 0, 0, 0, 6, 224, 17, 0, 2, 0, 0, 0, - 35, 0, 0, 11, 34, 0, - 16, 0, 1, 0, 0, 0, - 26, 0, 16, 0, 0, 0, - 0, 0, 42, 128, 32, 0, - 0, 0, 0, 0, 6, 0, - 0, 0, 58, 128, 32, 0, - 0, 0, 0, 0, 6, 0, - 0, 0, 32, 0, 0, 8, - 130, 0, 16, 0, 1, 0, - 0, 0, 58, 128, 32, 0, - 0, 0, 0, 0, 5, 0, - 0, 0, 1, 64, 0, 0, - 2, 0, 0, 0, 167, 0, - 0, 139, 2, 35, 0, 128, - 131, 153, 25, 0, 34, 0, + 18, 0, 0, 1, 54, 0, + 0, 5, 34, 0, 16, 0, + 4, 0, 0, 0, 1, 64, + 0, 0, 0, 0, 0, 0, + 21, 0, 0, 1, 55, 0, + 0, 12, 162, 0, 16, 0, + 1, 0, 0, 0, 86, 5, 16, 0, 1, 0, 0, 0, - 26, 0, 16, 0, 1, 0, - 0, 0, 1, 64, 0, 0, - 0, 0, 0, 0, 6, 224, - 17, 0, 2, 0, 0, 0, - 1, 0, 0, 7, 34, 0, - 16, 0, 4, 0, 0, 0, - 26, 0, 16, 0, 1, 0, - 0, 0, 58, 0, 16, 0, - 1, 0, 0, 0, 18, 0, - 0, 1, 54, 0, 0, 8, - 50, 0, 16, 0, 4, 0, + 86, 1, 16, 0, 4, 0, 0, 0, 2, 64, 0, 0, - 0, 0, 128, 63, 0, 0, 
0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 21, 0, - 0, 1, 79, 0, 0, 8, - 34, 0, 16, 0, 1, 0, - 0, 0, 26, 0, 16, 0, - 0, 0, 0, 0, 26, 128, - 32, 0, 0, 0, 0, 0, - 1, 0, 0, 0, 31, 0, - 4, 3, 26, 0, 16, 0, - 1, 0, 0, 0, 38, 0, - 0, 9, 0, 208, 0, 0, - 34, 0, 16, 0, 0, 0, - 0, 0, 26, 0, 16, 0, - 0, 0, 0, 0, 26, 128, - 32, 0, 0, 0, 0, 0, - 2, 0, 0, 0, 35, 0, + 0, 0, 0, 0, 0, 0, + 0, 0, 128, 63, 79, 0, + 0, 8, 34, 0, 16, 0, + 2, 0, 0, 0, 26, 0, + 16, 0, 0, 0, 0, 0, + 26, 128, 32, 0, 0, 0, + 0, 0, 1, 0, 0, 0, + 35, 0, 0, 10, 34, 0, + 16, 0, 0, 0, 0, 0, + 26, 0, 16, 0, 0, 0, + 0, 0, 26, 128, 32, 0, + 0, 0, 0, 0, 2, 0, + 0, 0, 10, 0, 16, 0, + 1, 0, 0, 0, 35, 0, 0, 10, 34, 0, 16, 0, 0, 0, 0, 0, 10, 0, - 16, 0, 1, 0, 0, 0, - 10, 128, 32, 0, 0, 0, + 16, 0, 3, 0, 0, 0, + 42, 128, 32, 0, 0, 0, 0, 0, 2, 0, 0, 0, 26, 0, 16, 0, 0, 0, - 0, 0, 35, 0, 0, 10, - 34, 0, 16, 0, 0, 0, - 0, 0, 10, 0, 16, 0, - 3, 0, 0, 0, 42, 128, - 32, 0, 0, 0, 0, 0, - 2, 0, 0, 0, 26, 0, + 0, 0, 167, 0, 0, 139, + 2, 35, 0, 128, 131, 153, + 25, 0, 18, 0, 16, 0, + 3, 0, 0, 0, 26, 0, 16, 0, 0, 0, 0, 0, - 167, 0, 0, 139, 2, 35, - 0, 128, 131, 153, 25, 0, - 18, 0, 16, 0, 1, 0, - 0, 0, 26, 0, 16, 0, - 0, 0, 0, 0, 1, 64, - 0, 0, 0, 0, 0, 0, - 6, 224, 17, 0, 0, 0, - 0, 0, 32, 0, 0, 8, - 130, 0, 16, 0, 1, 0, - 0, 0, 58, 128, 32, 0, - 0, 0, 0, 0, 1, 0, - 0, 0, 1, 64, 0, 0, - 2, 0, 0, 0, 31, 0, + 1, 64, 0, 0, 0, 0, + 0, 0, 6, 224, 17, 0, + 0, 0, 0, 0, 31, 0, 4, 3, 58, 0, 16, 0, - 1, 0, 0, 0, 30, 0, + 2, 0, 0, 0, 30, 0, 0, 8, 34, 0, 16, 0, 0, 0, 0, 0, 26, 0, 16, 0, 0, 0, 0, 0, @@ -490,39 +466,37 @@ const BYTE g_DFT[] = 0, 0, 2, 0, 0, 0, 167, 0, 0, 139, 2, 35, 0, 128, 131, 153, 25, 0, - 34, 0, 16, 0, 1, 0, + 34, 0, 16, 0, 3, 0, 0, 0, 26, 0, 16, 0, 0, 0, 0, 0, 1, 64, 0, 0, 0, 0, 0, 0, 6, 224, 17, 0, 0, 0, 0, 0, 18, 0, 0, 1, 54, 0, 0, 5, 34, 0, - 16, 0, 1, 0, 0, 0, + 16, 0, 3, 0, 0, 0, 1, 64, 0, 0, 0, 0, 0, 0, 21, 0, 0, 1, - 18, 0, 0, 1, 54, 0, - 0, 8, 50, 0, 16, 0, - 1, 0, 0, 0, 2, 64, - 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, - 0, 0, 
0, 0, 0, 0, - 21, 0, 0, 1, 56, 0, + 1, 0, 0, 7, 98, 0, + 16, 0, 2, 0, 0, 0, + 86, 5, 16, 0, 2, 0, + 0, 0, 6, 1, 16, 0, + 3, 0, 0, 0, 56, 0, 0, 7, 34, 0, 16, 0, 0, 0, 0, 0, 26, 0, - 16, 0, 4, 0, 0, 0, - 26, 0, 16, 0, 1, 0, + 16, 0, 1, 0, 0, 0, + 42, 0, 16, 0, 2, 0, 0, 0, 50, 0, 0, 10, 34, 0, 16, 0, 0, 0, - 0, 0, 10, 0, 16, 0, - 1, 0, 0, 0, 10, 0, - 16, 0, 4, 0, 0, 0, + 0, 0, 26, 0, 16, 0, + 2, 0, 0, 0, 58, 0, + 16, 0, 1, 0, 0, 0, 26, 0, 16, 128, 65, 0, 0, 0, 0, 0, 0, 0, 15, 0, 0, 7, 18, 0, 16, 0, 1, 0, 0, 0, - 22, 5, 16, 0, 1, 0, - 0, 0, 70, 0, 16, 0, - 4, 0, 0, 0, 78, 0, + 150, 5, 16, 0, 2, 0, + 0, 0, 214, 5, 16, 0, + 1, 0, 0, 0, 78, 0, 0, 8, 0, 208, 0, 0, 34, 0, 16, 0, 1, 0, 0, 0, 10, 0, 16, 0, @@ -610,65 +584,77 @@ const BYTE g_DFT[] = 16, 0, 0, 0, 0, 0, 166, 138, 32, 0, 0, 0, 0, 0, 7, 0, 0, 0, - 57, 0, 0, 8, 18, 0, + 24, 0, 0, 8, 18, 0, 16, 0, 1, 0, 0, 0, 26, 128, 32, 0, 0, 0, 0, 0, 7, 0, 0, 0, 1, 64, 0, 0, 0, 0, + 0, 0, 31, 0, 4, 3, + 10, 0, 16, 0, 1, 0, + 0, 0, 54, 0, 0, 5, + 18, 0, 16, 0, 1, 0, + 0, 0, 58, 0, 16, 0, + 0, 0, 0, 0, 18, 0, + 0, 1, 57, 0, 0, 8, + 66, 0, 16, 0, 1, 0, + 0, 0, 26, 128, 32, 0, + 0, 0, 0, 0, 7, 0, + 0, 0, 1, 64, 0, 0, + 0, 0, 0, 0, 56, 0, + 0, 7, 34, 0, 16, 0, + 1, 0, 0, 0, 26, 0, + 16, 0, 1, 0, 0, 0, + 26, 0, 16, 0, 1, 0, 0, 0, 56, 0, 0, 7, 34, 0, 16, 0, 1, 0, 0, 0, 26, 0, 16, 0, - 1, 0, 0, 0, 26, 0, - 16, 0, 1, 0, 0, 0, - 56, 0, 0, 7, 34, 0, + 1, 0, 0, 0, 1, 64, + 0, 0, 219, 15, 73, 64, + 14, 0, 0, 8, 34, 0, 16, 0, 1, 0, 0, 0, 26, 0, 16, 0, 1, 0, - 0, 0, 1, 64, 0, 0, - 219, 15, 73, 64, 14, 0, - 0, 8, 34, 0, 16, 0, - 1, 0, 0, 0, 26, 0, + 0, 0, 26, 128, 32, 0, + 0, 0, 0, 0, 7, 0, + 0, 0, 77, 0, 0, 7, + 18, 0, 16, 0, 2, 0, + 0, 0, 18, 0, 16, 0, + 3, 0, 0, 0, 26, 0, 16, 0, 1, 0, 0, 0, - 26, 128, 32, 0, 0, 0, - 0, 0, 7, 0, 0, 0, - 77, 0, 0, 7, 18, 0, + 54, 0, 0, 5, 34, 0, 16, 0, 2, 0, 0, 0, - 18, 0, 16, 0, 3, 0, - 0, 0, 26, 0, 16, 0, - 1, 0, 0, 0, 54, 0, - 0, 5, 34, 0, 16, 0, - 2, 0, 0, 0, 10, 0, - 16, 0, 3, 0, 0, 0, - 55, 0, 0, 12, 50, 
0, - 16, 0, 1, 0, 0, 0, - 6, 0, 16, 0, 1, 0, - 0, 0, 70, 0, 16, 0, - 2, 0, 0, 0, 2, 64, + 10, 0, 16, 0, 3, 0, + 0, 0, 55, 0, 0, 12, + 98, 0, 16, 0, 1, 0, + 0, 0, 166, 10, 16, 0, + 1, 0, 0, 0, 6, 1, + 16, 0, 2, 0, 0, 0, + 2, 64, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 128, 63, 0, 0, - 0, 0, 0, 0, 0, 0, - 56, 0, 0, 7, 194, 0, + 0, 0, 56, 0, 0, 7, + 50, 0, 16, 0, 2, 0, + 0, 0, 86, 5, 16, 0, + 0, 0, 0, 0, 150, 5, 16, 0, 1, 0, 0, 0, - 86, 5, 16, 0, 0, 0, - 0, 0, 6, 4, 16, 0, - 1, 0, 0, 0, 50, 0, - 0, 10, 34, 0, 16, 0, + 50, 0, 0, 10, 18, 0, + 16, 0, 1, 0, 0, 0, + 58, 0, 16, 0, 0, 0, + 0, 0, 42, 0, 16, 0, + 1, 0, 0, 0, 10, 0, + 16, 128, 65, 0, 0, 0, + 2, 0, 0, 0, 50, 0, + 0, 9, 34, 0, 16, 0, 0, 0, 0, 0, 58, 0, 16, 0, 0, 0, 0, 0, 26, 0, 16, 0, 1, 0, - 0, 0, 42, 0, 16, 128, - 65, 0, 0, 0, 1, 0, - 0, 0, 168, 0, 0, 9, + 0, 0, 26, 0, 16, 0, + 2, 0, 0, 0, 21, 0, + 0, 1, 168, 0, 0, 9, 18, 224, 17, 0, 1, 0, 0, 0, 42, 0, 16, 0, 0, 0, 0, 0, 1, 64, 0, 0, 0, 0, 0, 0, - 26, 0, 16, 0, 0, 0, - 0, 0, 50, 0, 0, 9, - 34, 0, 16, 0, 0, 0, - 0, 0, 58, 0, 16, 0, - 0, 0, 0, 0, 10, 0, - 16, 0, 1, 0, 0, 0, - 58, 0, 16, 0, 1, 0, + 10, 0, 16, 0, 1, 0, 0, 0, 168, 0, 0, 9, 18, 224, 17, 0, 1, 0, 0, 0, 10, 0, 16, 0, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/GeneratedShaders/stockham_fp16.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/GeneratedShaders/stockham_fp16.h index 988c0aa66ade2..56ce759875687 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/GeneratedShaders/stockham_fp16.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/GeneratedShaders/stockham_fp16.h @@ -15,7 +15,7 @@ ; Name Index Mask Register SysValue Format Used ; -------------------- ----- ------ -------- -------- ------- ------ ; no parameters -; shader hash: e08f21199c48b0db30bf21bd8c5b80dc +; shader hash: 6a1d88feb14177832f5ee49ca330c549 ; ; Pipeline Runtime Information: ; @@ -125,7 +125,7 @@ define void @DFT() { %47 = fpext 
half %46 to float %48 = extractvalue %dx.types.CBufRet.i32 %37, 3 %49 = icmp eq i32 %48, 2 - br i1 %49, label %50, label %56 + br i1 %49, label %50, label %56, !dx.controlflow.hints !15 ;